feat: refactor model manager
* chore: mv model icon * fix: model icon * fix: model icon * feat: refactor model manager * fix: model icon * fix: model icon * feat: refactor model manager See merge request: !905
This commit is contained in:
@@ -18,7 +18,6 @@ package coze
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -49,7 +48,6 @@ import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
modelknowledge "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
crossmodelmgr "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
plugin2 "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
pluginmodel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/playground"
|
||||
@@ -85,6 +83,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/checkpoint"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner"
|
||||
mockCrossUser "github.com/coze-dev/coze-studio/backend/internal/mock/crossdomain/crossuser"
|
||||
@@ -1503,7 +1502,7 @@ func TestNestedSubWorkflowWithInterrupt(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *crossmodelmgr.Model, error) {
|
||||
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) {
|
||||
if params.ModelType == 1737521813 {
|
||||
return chatModel1, nil, nil
|
||||
} else {
|
||||
@@ -1972,7 +1971,7 @@ func TestReturnDirectlyStreamableTool(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *crossmodelmgr.Model, error) {
|
||||
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) {
|
||||
if params.ModelType == 1706077826 {
|
||||
innerModel.ModelType = strconv.FormatInt(params.ModelType, 10)
|
||||
return innerModel, nil, nil
|
||||
@@ -2161,7 +2160,7 @@ func TestStreamableToolWithMultipleInterrupts(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *crossmodelmgr.Model, error) {
|
||||
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) {
|
||||
if params.ModelType == 1706077827 {
|
||||
outerModel.ModelType = strconv.FormatInt(params.ModelType, 10)
|
||||
return outerModel, nil, nil
|
||||
@@ -2455,7 +2454,7 @@ func TestAggregateStreamVariables(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *crossmodelmgr.Model, error) {
|
||||
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) {
|
||||
if params.ModelType == 1737521813 {
|
||||
cm1.ModelType = strconv.FormatInt(params.ModelType, 10)
|
||||
return cm1, nil, nil
|
||||
@@ -2598,7 +2597,7 @@ func TestParallelInterrupts(t *testing.T) {
|
||||
}
|
||||
},
|
||||
}
|
||||
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *crossmodelmgr.Model, error) {
|
||||
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) {
|
||||
if params.ModelType == 1737521813 {
|
||||
return chatModel1, nil, nil
|
||||
} else {
|
||||
@@ -3871,7 +3870,7 @@ func TestLLMException(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *crossmodelmgr.Model, error) {
|
||||
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) {
|
||||
if params.ModelType == 1737521813 {
|
||||
return mainChatModel, nil, nil
|
||||
} else {
|
||||
@@ -3938,7 +3937,7 @@ func TestLLMExceptionThenThrow(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *crossmodelmgr.Model, error) {
|
||||
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) {
|
||||
if params.ModelType == 1737521813 {
|
||||
return mainChatModel, nil, nil
|
||||
} else {
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
package modelmgr
|
||||
|
||||
type ParameterName string
|
||||
|
||||
const (
|
||||
Temperature ParameterName = "temperature"
|
||||
TopP ParameterName = "top_p"
|
||||
TopK ParameterName = "top_k"
|
||||
MaxTokens ParameterName = "max_tokens"
|
||||
RespFormat ParameterName = "response_format"
|
||||
FrequencyPenalty ParameterName = "frequency_penalty"
|
||||
PresencePenalty ParameterName = "presence_penalty"
|
||||
)
|
||||
|
||||
type ValueType string
|
||||
|
||||
const (
|
||||
ValueTypeInt ValueType = "int"
|
||||
ValueTypeFloat ValueType = "float"
|
||||
ValueTypeBoolean ValueType = "boolean"
|
||||
ValueTypeString ValueType = "string"
|
||||
)
|
||||
|
||||
type DefaultType string
|
||||
|
||||
const (
|
||||
DefaultTypeDefault DefaultType = "default_val"
|
||||
DefaultTypeCreative DefaultType = "creative"
|
||||
DefaultTypeBalance DefaultType = "balance"
|
||||
DefaultTypePrecise DefaultType = "precise"
|
||||
)
|
||||
|
||||
// Deprecated
|
||||
type Scenario int64 // 模型实体使用场景
|
||||
|
||||
type Modal string
|
||||
|
||||
const (
|
||||
ModalText Modal = "text"
|
||||
ModalImage Modal = "image"
|
||||
ModalFile Modal = "file"
|
||||
ModalAudio Modal = "audio"
|
||||
ModalVideo Modal = "video"
|
||||
)
|
||||
|
||||
type ModelMetaStatus int64 // 模型实体状态
|
||||
|
||||
const (
|
||||
StatusInUse ModelMetaStatus = 1 // 应用中,可使用可新建
|
||||
StatusPending ModelMetaStatus = 5 // 待下线,可使用不可新建
|
||||
StatusDeleted ModelMetaStatus = 10 // 已下线,不可使用不可新建
|
||||
)
|
||||
|
||||
type Widget string
|
||||
|
||||
const (
|
||||
WidgetSlider Widget = "slider"
|
||||
WidgetRadioButtons Widget = "radio_buttons"
|
||||
)
|
||||
|
||||
type ModelEntityStatus int64
|
||||
|
||||
const (
|
||||
ModelEntityStatusDefault ModelEntityStatus = 0
|
||||
ModelEntityStatusInUse ModelEntityStatus = 1
|
||||
ModelEntityStatusPending ModelEntityStatus = 5
|
||||
ModelEntityStatusDeleted ModelEntityStatus = 10
|
||||
)
|
||||
@@ -1,171 +0,0 @@
|
||||
package modelmgr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/i18n"
|
||||
)
|
||||
|
||||
type MGetModelRequest struct {
|
||||
IDs []int64
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
ID int64 `yaml:"id"`
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
DefaultParameters []*Parameter `yaml:"default_parameters"`
|
||||
|
||||
CreatedAtMs int64
|
||||
UpdatedAtMs int64
|
||||
DeletedAtMs int64
|
||||
|
||||
Meta ModelMeta `yaml:"meta"`
|
||||
}
|
||||
|
||||
type Parameter struct {
|
||||
Name ParameterName `json:"name" yaml:"name"`
|
||||
Label *MultilingualText `json:"label,omitempty" yaml:"label,omitempty"`
|
||||
Desc *MultilingualText `json:"desc" yaml:"desc"`
|
||||
Type ValueType `json:"type" yaml:"type"`
|
||||
Min string `json:"min" yaml:"min"`
|
||||
Max string `json:"max" yaml:"max"`
|
||||
DefaultVal DefaultValue `json:"default_val" yaml:"default_val"`
|
||||
Precision int `json:"precision,omitempty" yaml:"precision,omitempty"` // float precision, default 2
|
||||
Options []*ParamOption `json:"options" yaml:"options"` // enum options
|
||||
Style DisplayStyle `json:"param_class" yaml:"style"`
|
||||
}
|
||||
|
||||
func (p *Parameter) GetFloat(tp DefaultType) (float64, error) {
|
||||
if p.Type != ValueTypeFloat {
|
||||
return 0, fmt.Errorf("unexpected paramerter type, name=%v, expect=%v, given=%v",
|
||||
p.Name, ValueTypeFloat, p.Type)
|
||||
}
|
||||
|
||||
if tp != DefaultTypeDefault && p.DefaultVal[tp] == "" {
|
||||
tp = DefaultTypeDefault
|
||||
}
|
||||
|
||||
val, ok := p.DefaultVal[tp]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("unexpected default type, name=%v, type=%v", p.Name, tp)
|
||||
}
|
||||
|
||||
return strconv.ParseFloat(val, 64)
|
||||
}
|
||||
|
||||
func (p *Parameter) GetInt(tp DefaultType) (int64, error) {
|
||||
if p.Type != ValueTypeInt {
|
||||
return 0, fmt.Errorf("unexpected paramerter type, name=%v, expect=%v, given=%v",
|
||||
p.Name, ValueTypeInt, p.Type)
|
||||
}
|
||||
|
||||
if tp != DefaultTypeDefault && p.DefaultVal[tp] == "" {
|
||||
tp = DefaultTypeDefault
|
||||
}
|
||||
val, ok := p.DefaultVal[tp]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("unexpected default type, name=%v, type=%v", p.Name, tp)
|
||||
}
|
||||
return strconv.ParseInt(val, 10, 64)
|
||||
}
|
||||
|
||||
func (p *Parameter) GetBool(tp DefaultType) (bool, error) {
|
||||
if p.Type != ValueTypeBoolean {
|
||||
return false, fmt.Errorf("unexpected paramerter type, name=%v, expect=%v, given=%v",
|
||||
p.Name, ValueTypeBoolean, p.Type)
|
||||
}
|
||||
if tp != DefaultTypeDefault && p.DefaultVal[tp] == "" {
|
||||
tp = DefaultTypeDefault
|
||||
}
|
||||
val, ok := p.DefaultVal[tp]
|
||||
if !ok {
|
||||
return false, fmt.Errorf("unexpected default type, name=%v, type=%v", p.Name, tp)
|
||||
}
|
||||
return strconv.ParseBool(val)
|
||||
}
|
||||
|
||||
func (p *Parameter) GetString(tp DefaultType) (string, error) {
|
||||
if tp != DefaultTypeDefault && p.DefaultVal[tp] == "" {
|
||||
tp = DefaultTypeDefault
|
||||
}
|
||||
|
||||
val, ok := p.DefaultVal[tp]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unexpected default type, name=%v, type=%v", p.Name, tp)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
type ModelMeta struct {
|
||||
ID int64 `yaml:"id"`
|
||||
Name string `yaml:"name"`
|
||||
IconURI string `yaml:"icon_uri"`
|
||||
IconURL string `yaml:"icon_url"`
|
||||
Description *MultilingualText `yaml:"description"`
|
||||
|
||||
CreatedAtMs int64
|
||||
UpdatedAtMs int64
|
||||
DeletedAtMs int64
|
||||
|
||||
Protocol chatmodel.Protocol `yaml:"protocol"` // 模型通信协议
|
||||
Capability *Capability `yaml:"capability"` // 模型能力
|
||||
ConnConfig *chatmodel.Config `yaml:"conn_config"` // 模型连接配置
|
||||
Status ModelMetaStatus `yaml:"status"` // 模型状态
|
||||
}
|
||||
|
||||
type DefaultValue map[DefaultType]string
|
||||
|
||||
type DisplayStyle struct {
|
||||
Widget Widget `json:"class_id" yaml:"widget"`
|
||||
Label *MultilingualText `json:"label" yaml:"label"`
|
||||
}
|
||||
|
||||
type ParamOption struct {
|
||||
Label string `json:"label"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
type Capability struct {
|
||||
// Model supports function calling
|
||||
FunctionCall bool `json:"function_call" yaml:"function_call" mapstructure:"function_call"`
|
||||
// Input modals
|
||||
InputModal []Modal `json:"input_modal,omitempty" yaml:"input_modal,omitempty" mapstructure:"input_modal,omitempty"`
|
||||
// Input tokens
|
||||
InputTokens int `json:"input_tokens" yaml:"input_tokens" mapstructure:"input_tokens"`
|
||||
// Model supports json mode
|
||||
JSONMode bool `json:"json_mode" yaml:"json_mode" mapstructure:"json_mode"`
|
||||
// Max tokens
|
||||
MaxTokens int `json:"max_tokens" yaml:"max_tokens" mapstructure:"max_tokens"`
|
||||
// Output modals
|
||||
OutputModal []Modal `json:"output_modal,omitempty" yaml:"output_modal,omitempty" mapstructure:"output_modal,omitempty"`
|
||||
// Output tokens
|
||||
OutputTokens int `json:"output_tokens" yaml:"output_tokens" mapstructure:"output_tokens"`
|
||||
// Model supports prefix caching
|
||||
PrefixCaching bool `json:"prefix_caching" yaml:"prefix_caching" mapstructure:"prefix_caching"`
|
||||
// Model supports reasoning
|
||||
Reasoning bool `json:"reasoning" yaml:"reasoning" mapstructure:"reasoning"`
|
||||
// Model supports prefill response
|
||||
PrefillResponse bool `json:"prefill_response" yaml:"prefill_response" mapstructure:"prefill_response"`
|
||||
}
|
||||
|
||||
type MultilingualText struct {
|
||||
ZH string `json:"zh,omitempty" yaml:"zh,omitempty"`
|
||||
EN string `json:"en,omitempty" yaml:"en,omitempty"`
|
||||
}
|
||||
|
||||
func (m *MultilingualText) Read(locale i18n.Locale) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
switch locale {
|
||||
case i18n.LocaleZH:
|
||||
return m.ZH
|
||||
case i18n.LocaleEN:
|
||||
return m.EN
|
||||
default:
|
||||
return m.EN
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user