feat: refactor model manager
* chore: mv model icon * fix: model icon * fix: model icon * feat: refactor model manager * fix: model icon * fix: model icon * feat: refactor model manager See merge request: !905
This commit is contained in:
@@ -24,6 +24,7 @@ import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/es"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
|
||||
@@ -43,6 +44,7 @@ type AppDependencies struct {
|
||||
TOSClient storage.Storage
|
||||
ResourceEventProducer eventbus.Producer
|
||||
AppEventProducer eventbus.Producer
|
||||
ModelMgr modelmgr.Manager
|
||||
}
|
||||
|
||||
func Init(ctx context.Context) (*AppDependencies, error) {
|
||||
@@ -86,6 +88,11 @@ func Init(ctx context.Context) (*AppDependencies, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
deps.ModelMgr, err = initModelMgr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return deps, nil
|
||||
}
|
||||
|
||||
|
||||
216
backend/application/base/appinfra/modelmgr.go
Normal file
216
backend/application/base/appinfra/modelmgr.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package appinfra
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/modelmgr/static"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
func initModelMgr() (modelmgr.Manager, error) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
staticModel, err := initModelByTemplate(wd, "resources/conf/model")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
envModel, err := initModelByEnv(wd, "resources/conf/model/template")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
all := append(staticModel, envModel...)
|
||||
if err := fillModelContent(all); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mgr, err := static.NewModelMgr(all)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
func initModelByTemplate(wd, configPath string) ([]*modelmgr.Model, error) {
|
||||
configRoot := filepath.Join(wd, configPath)
|
||||
staticModel, err := readDirYaml[modelmgr.Model](configRoot)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return staticModel, nil
|
||||
}
|
||||
|
||||
func initModelByEnv(wd, templatePath string) (modelEntities []*modelmgr.Model, err error) {
|
||||
entityRoot := filepath.Join(wd, templatePath)
|
||||
|
||||
for i := -1; i < 1000; i++ {
|
||||
rawProtocol := os.Getenv(concatEnvKey(modelProtocolPrefix, i))
|
||||
if rawProtocol == "" {
|
||||
if i < 0 {
|
||||
continue
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
protocol := chatmodel.Protocol(rawProtocol)
|
||||
info, valid := getModelEnv(i)
|
||||
if !valid {
|
||||
break
|
||||
}
|
||||
|
||||
mapping, found := modelMapping[protocol]
|
||||
if !found {
|
||||
return nil, fmt.Errorf("[initModelByEnv] unsupport protocol: %s", rawProtocol)
|
||||
}
|
||||
|
||||
switch protocol {
|
||||
case chatmodel.ProtocolArk:
|
||||
fileSuffix, foundTemplate := mapping[info.modelName]
|
||||
if !foundTemplate {
|
||||
logs.Warnf("[initModelByEnv] unsupport model=%s, using default config", info.modelName)
|
||||
}
|
||||
modelEntity, err := readYaml[modelmgr.Model](filepath.Join(entityRoot, concatTemplateFileName("model_template_ark", fileSuffix)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id, err := strconv.ParseInt(info.id, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelEntity.ID = id
|
||||
if !foundTemplate {
|
||||
modelEntity.Name = info.modelName
|
||||
}
|
||||
|
||||
modelEntities = append(modelEntities, modelEntity)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("[initModelByEnv] unsupport protocol: %s", rawProtocol)
|
||||
}
|
||||
}
|
||||
|
||||
return modelEntities, nil
|
||||
}
|
||||
|
||||
type envModelInfo struct {
|
||||
id, modelName, modelID, apiKey, baseURL string
|
||||
}
|
||||
|
||||
func getModelEnv(idx int) (info envModelInfo, valid bool) {
|
||||
info.id = os.Getenv(concatEnvKey(modelOpenCozeIDPrefix, idx))
|
||||
info.modelName = os.Getenv(concatEnvKey(modelNamePrefix, idx))
|
||||
info.modelID = os.Getenv(concatEnvKey(modelIDPrefix, idx))
|
||||
info.apiKey = os.Getenv(concatEnvKey(modelApiKeyPrefix, idx))
|
||||
info.baseURL = os.Getenv(concatEnvKey(modelBaseURLPrefix, idx))
|
||||
valid = info.modelName != "" && info.modelID != "" && info.apiKey != ""
|
||||
return
|
||||
}
|
||||
|
||||
func readDirYaml[T any](dir string) ([]*T, error) {
|
||||
des, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp := make([]*T, 0, len(des))
|
||||
for _, file := range des {
|
||||
if file.IsDir() {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(file.Name(), ".yaml") || strings.HasSuffix(file.Name(), ".yml") {
|
||||
filePath := filepath.Join(dir, file.Name())
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var content T
|
||||
if err := yaml.Unmarshal(data, &content); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp = append(resp, &content)
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func readYaml[T any](fPath string) (*T, error) {
|
||||
data, err := os.ReadFile(fPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var content T
|
||||
if err := yaml.Unmarshal(data, &content); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &content, nil
|
||||
}
|
||||
|
||||
func concatEnvKey(prefix string, idx int) string {
|
||||
if idx < 0 {
|
||||
return prefix
|
||||
}
|
||||
return fmt.Sprintf("%s_%d", prefix, idx)
|
||||
}
|
||||
|
||||
func concatTemplateFileName(prefix, suffix string) string {
|
||||
if suffix == "" {
|
||||
return prefix + ".yaml"
|
||||
}
|
||||
return prefix + "_" + suffix + ".yaml"
|
||||
}
|
||||
|
||||
const (
|
||||
modelProtocolPrefix = "MODEL_PROTOCOL" // model protocol
|
||||
modelOpenCozeIDPrefix = "MODEL_OPENCOZE_ID" // opencoze model id
|
||||
modelNamePrefix = "MODEL_NAME" // model name,
|
||||
modelIDPrefix = "MODEL_ID" // model in conn config
|
||||
modelApiKeyPrefix = "MODEL_API_KEY" // model api key
|
||||
modelBaseURLPrefix = "MODEL_BASE_URL" // model base url
|
||||
)
|
||||
|
||||
var modelMapping = map[chatmodel.Protocol]map[string]string{
|
||||
chatmodel.ProtocolArk: {
|
||||
"doubao-seed-1.6": "doubao-seed-1.6",
|
||||
"doubao-seed-1.6-flash": "doubao-seed-1.6-flash",
|
||||
"doubao-seed-1.6-thinking": "doubao-seed-1.6-thinking",
|
||||
"doubao-1.5-thinking-vision-pro": "doubao-1.5-thinking-vision-pro",
|
||||
"doubao-1.5-thinking-pro": "doubao-1.5-thinking-pro",
|
||||
"doubao-1.5-vision-pro": "doubao-1.5-vision-pro",
|
||||
"doubao-1.5-vision-lite": "doubao-1.5-vision-lite",
|
||||
"doubao-1.5-pro-32k": "doubao-1.5-pro-32k",
|
||||
"doubao-1.5-pro-256k": "doubao-1.5-pro-256k",
|
||||
"doubao-1.5-lite": "doubao-1.5-lite",
|
||||
"deepseek-r1": "volc_deepseek-r1",
|
||||
"deepseek-v3": "volc_deepseek-v3",
|
||||
},
|
||||
}
|
||||
|
||||
func fillModelContent(items []*modelmgr.Model) error {
|
||||
for i := range items {
|
||||
item := items[i]
|
||||
if item.Meta.Status == modelmgr.StatusDefault {
|
||||
item.Meta.Status = modelmgr.StatusInUse
|
||||
}
|
||||
|
||||
if item.IconURI == "" && item.IconURL == "" {
|
||||
return fmt.Errorf("missing icon URI or icon URL, id=%d", item.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
30
backend/application/base/appinfra/modelmgr_test.go
Normal file
30
backend/application/base/appinfra/modelmgr_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package appinfra
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
)
|
||||
|
||||
func TestInitByEnv(t *testing.T) {
|
||||
i := 0
|
||||
for k := range modelMapping[chatmodel.ProtocolArk] {
|
||||
_ = os.Setenv(concatEnvKey(modelProtocolPrefix, i), "ark")
|
||||
_ = os.Setenv(concatEnvKey(modelOpenCozeIDPrefix, i), fmt.Sprintf("%d", 45678+i))
|
||||
_ = os.Setenv(concatEnvKey(modelNamePrefix, i), k)
|
||||
_ = os.Setenv(concatEnvKey(modelIDPrefix, i), k)
|
||||
_ = os.Setenv(concatEnvKey(modelApiKeyPrefix, i), "mock_api_key")
|
||||
i++
|
||||
}
|
||||
|
||||
wd, err := os.Getwd()
|
||||
assert.NoError(t, err)
|
||||
|
||||
ms, err := initModelByEnv(wd, "../../../conf/model/template")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, ms, len(modelMapping[chatmodel.ProtocolArk]))
|
||||
}
|
||||
Reference in New Issue
Block a user