feat: refactor model manager
* chore: mv model icon * fix: model icon * fix: model icon * feat: refactor model manager * fix: model icon * fix: model icon * feat: refactor model manager See merge request: !905
This commit is contained in:
@@ -25,9 +25,9 @@ import (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
BaseURL string `json:"base_url,omitempty" yaml:"base_url,omitempty"`
|
||||
APIKey string `json:"api_key,omitempty" yaml:"api_key,omitempty"`
|
||||
Timeout time.Duration `json:"timeout,omitempty" yaml:"timeout,omitempty"`
|
||||
BaseURL string `json:"base_url,omitempty" yaml:"base_url"`
|
||||
APIKey string `json:"api_key,omitempty" yaml:"api_key"`
|
||||
Timeout time.Duration `json:"timeout,omitempty" yaml:"timeout"`
|
||||
|
||||
Model string `json:"model" yaml:"model"`
|
||||
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
|
||||
60
backend/infra/contract/modelmgr/const.go
Normal file
60
backend/infra/contract/modelmgr/const.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package modelmgr
|
||||
|
||||
// ParameterName is the string key that identifies a model parameter.
type ParameterName string

// Parameter names declared by this package.
const (
	Temperature      ParameterName = "temperature"
	TopP             ParameterName = "top_p"
	TopK             ParameterName = "top_k"
	MaxTokens        ParameterName = "max_tokens"
	RespFormat       ParameterName = "response_format"
	FrequencyPenalty ParameterName = "frequency_penalty"
	PresencePenalty  ParameterName = "presence_penalty"
)
|
||||
|
||||
// ValueType names the data type of a parameter's value.
type ValueType string

// Value types declared by this package.
const (
	ValueTypeInt     ValueType = "int"
	ValueTypeFloat   ValueType = "float"
	ValueTypeBoolean ValueType = "boolean"
	ValueTypeString  ValueType = "string"
)
|
||||
|
||||
// DefaultType names a preset under which a parameter's default value is
// stored (it is the key type of DefaultValue).
type DefaultType string

// Presets declared by this package.
const (
	DefaultTypeDefault  DefaultType = "default_val"
	DefaultTypeCreative DefaultType = "creative"
	DefaultTypeBalance  DefaultType = "balance"
	DefaultTypePrecise  DefaultType = "precise"
)
|
||||
|
||||
// Scenario is the usage scenario of a model entity.
//
// Deprecated: do not use Scenario in new code.
type Scenario int64
|
||||
|
||||
// Modal names a content modality; Capability lists the modalities a model
// accepts as input and produces as output.
type Modal string

// Modalities declared by this package.
const (
	ModalText  Modal = "text"
	ModalImage Modal = "image"
	ModalFile  Modal = "file"
	ModalAudio Modal = "audio"
	ModalVideo Modal = "video"
)
|
||||
|
||||
// ModelStatus is the lifecycle status of a model.
type ModelStatus int64

const (
	// StatusDefault is the status of a model whose status was not
	// configured; it behaves the same as StatusInUse.
	StatusDefault ModelStatus = 0
	// StatusInUse marks a model that is live: it can be used and can be
	// selected for new creations.
	StatusInUse ModelStatus = 1
	// StatusPending marks a model that is about to go offline: it can still
	// be used but cannot be selected for new creations.
	StatusPending ModelStatus = 5
	// StatusDeleted marks a model that is offline: it can neither be used
	// nor selected for new creations.
	StatusDeleted ModelStatus = 10
)
|
||||
|
||||
// Widget names the display widget referenced by DisplayStyle.
type Widget string

// Widgets declared by this package.
const (
	WidgetSlider       Widget = "slider"
	WidgetRadioButtons Widget = "radio_buttons"
)
|
||||
169
backend/infra/contract/modelmgr/desc.go
Normal file
169
backend/infra/contract/modelmgr/desc.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package modelmgr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/i18n"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
ID int64 `yaml:"id"`
|
||||
Name string `yaml:"name"`
|
||||
IconURI string `yaml:"icon_uri"`
|
||||
IconURL string `yaml:"icon_url"`
|
||||
Description *MultilingualText `yaml:"description"`
|
||||
DefaultParameters []*Parameter `yaml:"default_parameters"`
|
||||
Meta ModelMeta `yaml:"meta"`
|
||||
}
|
||||
|
||||
func (m *Model) FindParameter(name ParameterName) (*Parameter, bool) {
|
||||
if len(m.DefaultParameters) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
for _, param := range m.DefaultParameters {
|
||||
if param.Name == name {
|
||||
return param, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
type Parameter struct {
|
||||
Name ParameterName `json:"name" yaml:"name"`
|
||||
Label *MultilingualText `json:"label,omitempty" yaml:"label,omitempty"`
|
||||
Desc *MultilingualText `json:"desc" yaml:"desc"`
|
||||
Type ValueType `json:"type" yaml:"type"`
|
||||
Min string `json:"min" yaml:"min"`
|
||||
Max string `json:"max" yaml:"max"`
|
||||
DefaultVal DefaultValue `json:"default_val" yaml:"default_val"`
|
||||
Precision int `json:"precision,omitempty" yaml:"precision,omitempty"` // float precision, default 2
|
||||
Options []*ParamOption `json:"options" yaml:"options"` // enum options
|
||||
Style DisplayStyle `json:"param_class" yaml:"style"`
|
||||
}
|
||||
|
||||
func (p *Parameter) GetFloat(tp DefaultType) (float64, error) {
|
||||
if p.Type != ValueTypeFloat {
|
||||
return 0, fmt.Errorf("unexpected paramerter type, name=%v, expect=%v, given=%v",
|
||||
p.Name, ValueTypeFloat, p.Type)
|
||||
}
|
||||
|
||||
if tp != DefaultTypeDefault && p.DefaultVal[tp] == "" {
|
||||
tp = DefaultTypeDefault
|
||||
}
|
||||
|
||||
val, ok := p.DefaultVal[tp]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("unexpected default type, name=%v, type=%v", p.Name, tp)
|
||||
}
|
||||
|
||||
return strconv.ParseFloat(val, 64)
|
||||
}
|
||||
|
||||
func (p *Parameter) GetInt(tp DefaultType) (int64, error) {
|
||||
if p.Type != ValueTypeInt {
|
||||
return 0, fmt.Errorf("unexpected paramerter type, name=%v, expect=%v, given=%v",
|
||||
p.Name, ValueTypeInt, p.Type)
|
||||
}
|
||||
|
||||
if tp != DefaultTypeDefault && p.DefaultVal[tp] == "" {
|
||||
tp = DefaultTypeDefault
|
||||
}
|
||||
val, ok := p.DefaultVal[tp]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("unexpected default type, name=%v, type=%v", p.Name, tp)
|
||||
}
|
||||
return strconv.ParseInt(val, 10, 64)
|
||||
}
|
||||
|
||||
func (p *Parameter) GetBool(tp DefaultType) (bool, error) {
|
||||
if p.Type != ValueTypeBoolean {
|
||||
return false, fmt.Errorf("unexpected paramerter type, name=%v, expect=%v, given=%v",
|
||||
p.Name, ValueTypeBoolean, p.Type)
|
||||
}
|
||||
if tp != DefaultTypeDefault && p.DefaultVal[tp] == "" {
|
||||
tp = DefaultTypeDefault
|
||||
}
|
||||
val, ok := p.DefaultVal[tp]
|
||||
if !ok {
|
||||
return false, fmt.Errorf("unexpected default type, name=%v, type=%v", p.Name, tp)
|
||||
}
|
||||
return strconv.ParseBool(val)
|
||||
}
|
||||
|
||||
func (p *Parameter) GetString(tp DefaultType) (string, error) {
|
||||
if tp != DefaultTypeDefault && p.DefaultVal[tp] == "" {
|
||||
tp = DefaultTypeDefault
|
||||
}
|
||||
|
||||
val, ok := p.DefaultVal[tp]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unexpected default type, name=%v, type=%v", p.Name, tp)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
type ModelMeta struct {
|
||||
Name string `yaml:"name"`
|
||||
Protocol chatmodel.Protocol `yaml:"protocol"` // 模型通信协议
|
||||
Capability *Capability `yaml:"capability"` // 模型能力
|
||||
ConnConfig *chatmodel.Config `yaml:"conn_config"` // 模型连接配置
|
||||
Status ModelStatus `yaml:"status"` // 模型状态
|
||||
}
|
||||
|
||||
// DefaultValue maps a preset (DefaultType) to the default value configured
// for that preset, stored as a string.
type DefaultValue map[DefaultType]string
|
||||
|
||||
type DisplayStyle struct {
|
||||
Widget Widget `json:"class_id" yaml:"widget"`
|
||||
Label *MultilingualText `json:"label" yaml:"label"`
|
||||
}
|
||||
|
||||
// ParamOption is one selectable option of an enum-style parameter.
//
// The yaml tags spell out the keys that YAML decoding already used by
// default (the lower-cased field names), matching the explicit tagging of
// the other types in this file.
type ParamOption struct {
	// Label is the text shown for the option.
	Label string `json:"label" yaml:"label"`
	// Value is the value the option stands for.
	Value string `json:"value" yaml:"value"`
}
|
||||
|
||||
type Capability struct {
|
||||
// Model supports function calling
|
||||
FunctionCall bool `json:"function_call" yaml:"function_call" mapstructure:"function_call"`
|
||||
// Input modals
|
||||
InputModal []Modal `json:"input_modal,omitempty" yaml:"input_modal,omitempty" mapstructure:"input_modal,omitempty"`
|
||||
// Input tokens
|
||||
InputTokens int `json:"input_tokens" yaml:"input_tokens" mapstructure:"input_tokens"`
|
||||
// Model supports json mode
|
||||
JSONMode bool `json:"json_mode" yaml:"json_mode" mapstructure:"json_mode"`
|
||||
// Max tokens
|
||||
MaxTokens int `json:"max_tokens" yaml:"max_tokens" mapstructure:"max_tokens"`
|
||||
// Output modals
|
||||
OutputModal []Modal `json:"output_modal,omitempty" yaml:"output_modal,omitempty" mapstructure:"output_modal,omitempty"`
|
||||
// Output tokens
|
||||
OutputTokens int `json:"output_tokens" yaml:"output_tokens" mapstructure:"output_tokens"`
|
||||
// Model supports prefix caching
|
||||
PrefixCaching bool `json:"prefix_caching" yaml:"prefix_caching" mapstructure:"prefix_caching"`
|
||||
// Model supports reasoning
|
||||
Reasoning bool `json:"reasoning" yaml:"reasoning" mapstructure:"reasoning"`
|
||||
// Model supports prefill response
|
||||
PrefillResponse bool `json:"prefill_response" yaml:"prefill_response" mapstructure:"prefill_response"`
|
||||
}
|
||||
|
||||
// MultilingualText holds one text in Chinese (ZH) and English (EN). Empty
// variants are omitted when the value is encoded.
type MultilingualText struct {
	ZH string `json:"zh,omitempty" yaml:"zh,omitempty"`
	EN string `json:"en,omitempty" yaml:"en,omitempty"`
}
|
||||
|
||||
func (m *MultilingualText) Read(locale i18n.Locale) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
switch locale {
|
||||
case i18n.LocaleZH:
|
||||
return m.ZH
|
||||
case i18n.LocaleEN:
|
||||
return m.EN
|
||||
default:
|
||||
return m.EN
|
||||
}
|
||||
}
|
||||
27
backend/infra/contract/modelmgr/modelmgr.go
Normal file
27
backend/infra/contract/modelmgr/modelmgr.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package modelmgr
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
ListModel(ctx context.Context, req *ListModelRequest) (*ListModelResponse, error)
|
||||
MGetModelByID(ctx context.Context, req *MGetModelRequest) ([]*Model, error)
|
||||
}
|
||||
|
||||
type ListModelRequest struct {
|
||||
FuzzyModelName *string
|
||||
Status []ModelStatus // default is default and in_use status
|
||||
Limit int
|
||||
Cursor *string
|
||||
}
|
||||
|
||||
type ListModelResponse struct {
|
||||
ModelList []*Model
|
||||
HasMore bool
|
||||
NextCursor *string
|
||||
}
|
||||
|
||||
// MGetModelRequest lists the IDs of the models to fetch with
// Manager.MGetModelByID.
type MGetModelRequest struct {
	IDs []int64
}
|
||||
80
backend/infra/impl/modelmgr/static/modelmgr.go
Normal file
80
backend/infra/impl/modelmgr/static/modelmgr.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package static
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
|
||||
)
|
||||
|
||||
func NewModelMgr(staticModels []*modelmgr.Model) (modelmgr.Manager, error) {
|
||||
mapping := make(map[int64]*modelmgr.Model, len(staticModels))
|
||||
for i := range staticModels {
|
||||
mapping[staticModels[i].ID] = staticModels[i]
|
||||
}
|
||||
return &staticModelManager{
|
||||
models: staticModels,
|
||||
modelMapping: mapping,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type staticModelManager struct {
|
||||
models []*modelmgr.Model
|
||||
modelMapping map[int64]*modelmgr.Model
|
||||
}
|
||||
|
||||
func (s *staticModelManager) ListModel(_ context.Context, req *modelmgr.ListModelRequest) (*modelmgr.ListModelResponse, error) {
|
||||
startIdx := 0
|
||||
if req.Cursor != nil {
|
||||
start, err := strconv.ParseInt(*req.Cursor, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
startIdx = int(start)
|
||||
}
|
||||
|
||||
limit := req.Limit
|
||||
if limit == 0 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
var (
|
||||
i int
|
||||
respList []*modelmgr.Model
|
||||
statSet = sets.FromSlice(req.Status)
|
||||
)
|
||||
|
||||
for i = startIdx; i < len(s.models) && len(respList) < limit; i++ {
|
||||
m := s.models[i]
|
||||
if req.FuzzyModelName != nil && !strings.Contains(m.Name, *req.FuzzyModelName) {
|
||||
continue
|
||||
}
|
||||
if len(statSet) > 0 && !statSet.Contains(m.Meta.Status) {
|
||||
continue
|
||||
}
|
||||
respList = append(respList, m)
|
||||
}
|
||||
|
||||
resp := &modelmgr.ListModelResponse{
|
||||
ModelList: respList,
|
||||
}
|
||||
resp.HasMore = i != len(s.models)
|
||||
if resp.HasMore {
|
||||
resp.NextCursor = ptr.Of(strconv.FormatInt(int64(i), 10))
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *staticModelManager) MGetModelByID(_ context.Context, req *modelmgr.MGetModelRequest) ([]*modelmgr.Model, error) {
|
||||
resp := make([]*modelmgr.Model, 0, len(s.models))
|
||||
for _, id := range req.IDs {
|
||||
if m, found := s.modelMapping[id]; found {
|
||||
resp = append(resp, m)
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
Reference in New Issue
Block a user