feat: refactor model manager
* chore: mv model icon * fix: model icon * fix: model icon * feat: refactor model manager * fix: model icon * fix: model icon * feat: refactor model manager See merge request: !905
This commit is contained in:
@@ -47,7 +47,6 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatacopy"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossknowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossmessage"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossmodelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossuser"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossvariables"
|
||||
@@ -60,7 +59,6 @@ import (
|
||||
dataCopyImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/datacopy"
|
||||
knowledgeImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/knowledge"
|
||||
messageImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/message"
|
||||
modelmgrImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/modelmgr"
|
||||
pluginImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/plugin"
|
||||
searchImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/search"
|
||||
singleagentImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/singleagent"
|
||||
@@ -130,7 +128,6 @@ func Init(ctx context.Context) (err error) {
|
||||
crossconnector.SetDefaultSVC(connectorImpl.InitDomainService(basicServices.connectorSVC.DomainSVC))
|
||||
crossdatabase.SetDefaultSVC(databaseImpl.InitDomainService(primaryServices.memorySVC.DatabaseDomainSVC))
|
||||
crossknowledge.SetDefaultSVC(knowledgeImpl.InitDomainService(primaryServices.knowledgeSVC.DomainSVC))
|
||||
crossmodelmgr.SetDefaultSVC(modelmgrImpl.InitDomainService(basicServices.modelMgrSVC.DomainSVC))
|
||||
crossplugin.SetDefaultSVC(pluginImpl.InitDomainService(primaryServices.pluginSVC.DomainSVC))
|
||||
crossvariables.SetDefaultSVC(variablesImpl.InitDomainService(primaryServices.memorySVC.VariablesDomainSVC))
|
||||
crossworkflow.SetDefaultSVC(workflowImpl.InitDomainService(primaryServices.workflowSVC.DomainSVC))
|
||||
@@ -158,10 +155,7 @@ func initBasicServices(ctx context.Context, infra *appinfra.AppDependencies, e *
|
||||
upload.InitService(infra.TOSClient, infra.CacheCli)
|
||||
openAuthSVC := openauth.InitService(infra.DB, infra.IDGenSVC)
|
||||
promptSVC := prompt.InitService(infra.DB, infra.IDGenSVC, e.resourceEventBus)
|
||||
modelMgrSVC, err := modelmgr.InitService(infra.DB, infra.IDGenSVC, infra.TOSClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
modelMgrSVC := modelmgr.InitService(infra.ModelMgr, infra.TOSClient)
|
||||
connectorSVC := connector.InitService(infra.TOSClient)
|
||||
userSVC := user.InitService(ctx, infra.DB, infra.TOSClient, infra.IDGenSVC)
|
||||
templateSVC := template.InitService(ctx, &template.ServiceComponents{
|
||||
@@ -285,7 +279,7 @@ func (b *basicServices) toWorkflowServiceComponents(pluginSVC *plugin.PluginAppl
|
||||
VariablesDomainSVC: memorySVC.VariablesDomainSVC,
|
||||
PluginDomainSVC: pluginSVC.DomainSVC,
|
||||
KnowledgeDomainSVC: knowledgeSVC.DomainSVC,
|
||||
ModelManager: b.modelMgrSVC.DomainSVC,
|
||||
ModelManager: b.infra.ModelMgr,
|
||||
DomainNotifier: b.eventbus.resourceEventBus,
|
||||
CPStore: checkpoint.NewRedisStore(b.infra.CacheCli),
|
||||
}
|
||||
@@ -299,7 +293,7 @@ func (p *primaryServices) toSingleAgentServiceComponents() *singleagent.ServiceC
|
||||
Cache: p.basicServices.infra.CacheCli,
|
||||
TosClient: p.basicServices.infra.TOSClient,
|
||||
ImageX: p.basicServices.infra.ImageXClient,
|
||||
ModelMgrDomainSVC: p.basicServices.modelMgrSVC.DomainSVC,
|
||||
ModelMgr: p.infra.ModelMgr,
|
||||
UserDomainSVC: p.basicServices.userSVC.DomainSVC,
|
||||
EventBus: p.basicServices.eventbus.projectEventBus,
|
||||
DatabaseDomainSVC: p.memorySVC.DatabaseDomainSVC,
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/es"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
|
||||
@@ -43,6 +44,7 @@ type AppDependencies struct {
|
||||
TOSClient storage.Storage
|
||||
ResourceEventProducer eventbus.Producer
|
||||
AppEventProducer eventbus.Producer
|
||||
ModelMgr modelmgr.Manager
|
||||
}
|
||||
|
||||
func Init(ctx context.Context) (*AppDependencies, error) {
|
||||
@@ -86,6 +88,11 @@ func Init(ctx context.Context) (*AppDependencies, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
deps.ModelMgr, err = initModelMgr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return deps, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,21 +1,61 @@
|
||||
package modelmgr
|
||||
package appinfra
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/modelmgr/static"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
func initModelByEnv(wd, templatePath string) (metaSlice []*modelmgr.ModelMeta, entitySlice []*modelmgr.Model, err error) {
|
||||
metaRoot := filepath.Join(wd, templatePath, "meta")
|
||||
entityRoot := filepath.Join(wd, templatePath, "entity")
|
||||
func initModelMgr() (modelmgr.Manager, error) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
staticModel, err := initModelByTemplate(wd, "resources/conf/model")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
envModel, err := initModelByEnv(wd, "resources/conf/model/template")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
all := append(staticModel, envModel...)
|
||||
if err := fillModelContent(all); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mgr, err := static.NewModelMgr(all)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
func initModelByTemplate(wd, configPath string) ([]*modelmgr.Model, error) {
|
||||
configRoot := filepath.Join(wd, configPath)
|
||||
staticModel, err := readDirYaml[modelmgr.Model](configRoot)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return staticModel, nil
|
||||
}
|
||||
|
||||
func initModelByEnv(wd, templatePath string) (modelEntities []*modelmgr.Model, err error) {
|
||||
entityRoot := filepath.Join(wd, templatePath)
|
||||
|
||||
for i := -1; i < 1000; i++ {
|
||||
rawProtocol := os.Getenv(concatEnvKey(modelProtocolPrefix, i))
|
||||
@@ -35,7 +75,7 @@ func initModelByEnv(wd, templatePath string) (metaSlice []*modelmgr.ModelMeta, e
|
||||
|
||||
mapping, found := modelMapping[protocol]
|
||||
if !found {
|
||||
return nil, nil, fmt.Errorf("[initModelByEnv] unsupport protocol: %s", rawProtocol)
|
||||
return nil, fmt.Errorf("[initModelByEnv] unsupport protocol: %s", rawProtocol)
|
||||
}
|
||||
|
||||
switch protocol {
|
||||
@@ -44,41 +84,28 @@ func initModelByEnv(wd, templatePath string) (metaSlice []*modelmgr.ModelMeta, e
|
||||
if !foundTemplate {
|
||||
logs.Warnf("[initModelByEnv] unsupport model=%s, using default config", info.modelName)
|
||||
}
|
||||
modelMeta, err := readYaml[modelmgr.ModelMeta](filepath.Join(metaRoot, concatTemplateFileName("model_meta_template_ark", fileSuffix)))
|
||||
modelEntity, err := readYaml[modelmgr.Model](filepath.Join(entityRoot, concatTemplateFileName("model_template_ark", fileSuffix)))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
modelEntity, err := readYaml[modelmgr.Model](filepath.Join(entityRoot, concatTemplateFileName("model_entity_template_ark", fileSuffix)))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
id, err := strconv.ParseInt(info.id, 10, 64)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// meta 和 entity 用一个 id,有概率冲突
|
||||
modelMeta.ID = id
|
||||
modelMeta.ConnConfig.Model = info.modelID
|
||||
modelMeta.ConnConfig.APIKey = info.apiKey
|
||||
if info.baseURL != "" {
|
||||
modelMeta.ConnConfig.BaseURL = info.baseURL
|
||||
}
|
||||
modelEntity.ID = id
|
||||
modelEntity.Meta.ID = id
|
||||
if !foundTemplate {
|
||||
modelEntity.Name = info.modelName
|
||||
}
|
||||
|
||||
metaSlice = append(metaSlice, modelMeta)
|
||||
entitySlice = append(entitySlice, modelEntity)
|
||||
modelEntities = append(modelEntities, modelEntity)
|
||||
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("[initModelByEnv] unsupport protocol: %s", rawProtocol)
|
||||
return nil, fmt.Errorf("[initModelByEnv] unsupport protocol: %s", rawProtocol)
|
||||
}
|
||||
}
|
||||
|
||||
return metaSlice, entitySlice, nil
|
||||
return modelEntities, nil
|
||||
}
|
||||
|
||||
type envModelInfo struct {
|
||||
@@ -95,6 +122,32 @@ func getModelEnv(idx int) (info envModelInfo, valid bool) {
|
||||
return
|
||||
}
|
||||
|
||||
func readDirYaml[T any](dir string) ([]*T, error) {
|
||||
des, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp := make([]*T, 0, len(des))
|
||||
for _, file := range des {
|
||||
if file.IsDir() {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(file.Name(), ".yaml") || strings.HasSuffix(file.Name(), ".yml") {
|
||||
filePath := filepath.Join(dir, file.Name())
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var content T
|
||||
if err := yaml.Unmarshal(data, &content); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp = append(resp, &content)
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func readYaml[T any](fPath string) (*T, error) {
|
||||
data, err := os.ReadFile(fPath)
|
||||
if err != nil {
|
||||
@@ -146,3 +199,18 @@ var modelMapping = map[chatmodel.Protocol]map[string]string{
|
||||
"deepseek-v3": "volc_deepseek-v3",
|
||||
},
|
||||
}
|
||||
|
||||
func fillModelContent(items []*modelmgr.Model) error {
|
||||
for i := range items {
|
||||
item := items[i]
|
||||
if item.Meta.Status == modelmgr.StatusDefault {
|
||||
item.Meta.Status = modelmgr.StatusInUse
|
||||
}
|
||||
|
||||
if item.IconURI == "" && item.IconURL == "" {
|
||||
return fmt.Errorf("missing icon URI or icon URL, id=%d", item.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package modelmgr
|
||||
package appinfra
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -24,8 +24,7 @@ func TestInitByEnv(t *testing.T) {
|
||||
wd, err := os.Getwd()
|
||||
assert.NoError(t, err)
|
||||
|
||||
ms, es, err := initModelByEnv(wd, "../../conf/model/template")
|
||||
ms, err := initModelByEnv(wd, "../../../conf/model/template")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, ms, len(modelMapping[chatmodel.ProtocolArk]))
|
||||
assert.Len(t, es, len(modelMapping[chatmodel.ProtocolArk]))
|
||||
}
|
||||
@@ -1,198 +1,11 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package modelmgr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
"gorm.io/gorm"
|
||||
|
||||
crossmodelmgr "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/modelmgr/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/modelmgr/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/storage"
|
||||
)
|
||||
|
||||
func InitService(db *gorm.DB, idgen idgen.IDGenerator, oss storage.Storage) (*ModelmgrApplicationService, error) {
|
||||
svc := service.NewModelManager(db, idgen, oss)
|
||||
if err := loadStaticModelConfig(svc, oss); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ModelmgrApplicationSVC.DomainSVC = svc
|
||||
|
||||
return ModelmgrApplicationSVC, nil
|
||||
}
|
||||
|
||||
func loadStaticModelConfig(svc modelmgr.Manager, oss storage.Storage) error {
|
||||
ctx := context.Background()
|
||||
|
||||
id2Meta := make(map[int64]*entity.ModelMeta)
|
||||
var cursor *string
|
||||
for {
|
||||
req := &modelmgr.ListModelMetaRequest{
|
||||
Status: []entity.ModelMetaStatus{
|
||||
crossmodelmgr.StatusInUse,
|
||||
crossmodelmgr.StatusPending,
|
||||
crossmodelmgr.StatusDeleted,
|
||||
},
|
||||
Limit: 100,
|
||||
Cursor: cursor,
|
||||
}
|
||||
listMetaResp, err := svc.ListModelMeta(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, item := range listMetaResp.ModelMetaList {
|
||||
cpItem := item
|
||||
id2Meta[cpItem.ID] = cpItem
|
||||
}
|
||||
if !listMetaResp.HasMore {
|
||||
break
|
||||
}
|
||||
cursor = listMetaResp.NextCursor
|
||||
}
|
||||
|
||||
root, err := os.Getwd()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
envModelMeta, envModelEntity, err := initModelByEnv(root, "resources/conf/model/template")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filePath := filepath.Join(root, "resources/conf/model/meta")
|
||||
staticModelMeta, err := readDirYaml[crossmodelmgr.ModelMeta](filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
staticModelMeta = append(staticModelMeta, envModelMeta...)
|
||||
for _, modelMeta := range staticModelMeta {
|
||||
if _, found := id2Meta[modelMeta.ID]; !found {
|
||||
if modelMeta.IconURI == "" && modelMeta.IconURL == "" {
|
||||
return fmt.Errorf("missing icon URI or icon URL, id=%d", modelMeta.ID)
|
||||
} else if modelMeta.IconURL != "" {
|
||||
// do nothing
|
||||
} else if modelMeta.IconURI != "" {
|
||||
// try local path
|
||||
base := filepath.Base(modelMeta.IconURI)
|
||||
iconPath := filepath.Join("resources/conf/model/icon", base)
|
||||
if _, err = os.Stat(iconPath); err == nil {
|
||||
// try upload icon
|
||||
icon, err := os.ReadFile(iconPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key := fmt.Sprintf("icon_%s_%d", base, time.Now().Second())
|
||||
if err := oss.PutObject(ctx, key, icon); err != nil {
|
||||
return err
|
||||
}
|
||||
modelMeta.IconURI = key
|
||||
} else if errors.Is(err, os.ErrNotExist) {
|
||||
// try to get object from uri
|
||||
if _, err := oss.GetObject(ctx, modelMeta.IconURI); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
newMeta, err := svc.CreateModelMeta(ctx, modelMeta)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
logs.Infof("[loadStaticModelConfig] model meta conflict for id=%d, skip", newMeta.ID)
|
||||
}
|
||||
return err
|
||||
} else {
|
||||
logs.Infof("[loadStaticModelConfig] model meta create success, id=%d", newMeta.ID)
|
||||
}
|
||||
id2Meta[newMeta.ID] = newMeta
|
||||
} else {
|
||||
logs.Infof("[loadStaticModelConfig] model meta founded, skip create, id=%d", modelMeta.ID)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
filePath = filepath.Join(root, "resources/conf/model/entity")
|
||||
staticModel, err := readDirYaml[crossmodelmgr.Model](filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
staticModel = append(staticModel, envModelEntity...)
|
||||
for _, modelEntity := range staticModel {
|
||||
curModelEntities, err := svc.MGetModelByID(ctx, &modelmgr.MGetModelRequest{IDs: []int64{modelEntity.ID}})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(curModelEntities) > 0 {
|
||||
logs.Infof("[loadStaticModelConfig] model entity founded, skip create, id=%d", modelEntity.ID)
|
||||
continue
|
||||
}
|
||||
meta, found := id2Meta[modelEntity.Meta.ID]
|
||||
if !found {
|
||||
return fmt.Errorf("model meta not found for id=%d, model_id=%d", modelEntity.Meta.ID, modelEntity.ID)
|
||||
}
|
||||
modelEntity.Meta = *meta
|
||||
if _, err = svc.CreateModel(ctx, &entity.Model{Model: modelEntity}); err != nil {
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
logs.Infof("[loadStaticModelConfig] model entity conflict for id=%d, skip", modelEntity.ID)
|
||||
}
|
||||
return err
|
||||
} else {
|
||||
logs.Infof("[loadStaticModelConfig] model entity create success, id=%d", modelEntity.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func readDirYaml[T any](dir string) ([]*T, error) {
|
||||
des, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp := make([]*T, 0, len(des))
|
||||
for _, file := range des {
|
||||
if file.IsDir() {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(file.Name(), ".yaml") || strings.HasSuffix(file.Name(), ".yml") {
|
||||
filePath := filepath.Join(dir, file.Name())
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var content T
|
||||
if err := yaml.Unmarshal(data, &content); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp = append(resp, &content)
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
func InitService(mgr modelmgr.Manager, tosClient storage.Storage) *ModelmgrApplicationService {
|
||||
ModelmgrApplicationSVC = &ModelmgrApplicationService{mgr, tosClient}
|
||||
return ModelmgrApplicationSVC
|
||||
}
|
||||
|
||||
@@ -19,10 +19,9 @@ package modelmgr
|
||||
import (
|
||||
"context"
|
||||
|
||||
modelmgrEntity "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/modelmgr"
|
||||
modelEntity "github.com/coze-dev/coze-studio/backend/domain/modelmgr/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/i18n"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
|
||||
@@ -31,18 +30,19 @@ import (
|
||||
)
|
||||
|
||||
type ModelmgrApplicationService struct {
|
||||
DomainSVC modelmgr.Manager
|
||||
Mgr modelmgr.Manager
|
||||
TosClient storage.Storage
|
||||
}
|
||||
|
||||
var ModelmgrApplicationSVC = &ModelmgrApplicationService{}
|
||||
|
||||
func (m *ModelmgrApplicationService) GetModelList(ctx context.Context, req *developer_api.GetTypeListRequest) (
|
||||
func (m *ModelmgrApplicationService) GetModelList(ctx context.Context, _ *developer_api.GetTypeListRequest) (
|
||||
resp *developer_api.GetTypeListResponse, err error,
|
||||
) {
|
||||
// 一般不太可能同时配置这么多模型
|
||||
const modelMaxLimit = 300
|
||||
|
||||
modelResp, err := m.DomainSVC.ListModel(ctx, &modelmgr.ListModelRequest{
|
||||
modelResp, err := m.Mgr.ListModel(ctx, &modelmgr.ListModelRequest{
|
||||
Limit: modelMaxLimit,
|
||||
Cursor: nil,
|
||||
})
|
||||
@@ -51,9 +51,15 @@ func (m *ModelmgrApplicationService) GetModelList(ctx context.Context, req *deve
|
||||
}
|
||||
|
||||
locale := i18n.GetLocale(ctx)
|
||||
modelList, err := slices.TransformWithErrorCheck(modelResp.ModelList, func(m *modelEntity.Model) (*developer_api.Model, error) {
|
||||
logs.CtxInfof(ctx, "ChatModel DefaultParameters: %v", m.DefaultParameters)
|
||||
return modelDo2To(m, locale)
|
||||
modelList, err := slices.TransformWithErrorCheck(modelResp.ModelList, func(mm *modelmgr.Model) (*developer_api.Model, error) {
|
||||
logs.CtxInfof(ctx, "ChatModel DefaultParameters: %v", mm.DefaultParameters)
|
||||
if mm.IconURI != "" {
|
||||
iconUrl, err := m.TosClient.GetObjectUrl(ctx, mm.IconURI)
|
||||
if err == nil {
|
||||
mm.IconURL = iconUrl
|
||||
}
|
||||
}
|
||||
return modelDo2To(mm, locale)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -68,11 +74,11 @@ func (m *ModelmgrApplicationService) GetModelList(ctx context.Context, req *deve
|
||||
}, nil
|
||||
}
|
||||
|
||||
func modelDo2To(model *modelEntity.Model, locale i18n.Locale) (*developer_api.Model, error) {
|
||||
func modelDo2To(model *modelmgr.Model, locale i18n.Locale) (*developer_api.Model, error) {
|
||||
mm := model.Meta
|
||||
|
||||
mps := slices.Transform(model.DefaultParameters,
|
||||
func(param *modelmgrEntity.Parameter) *developer_api.ModelParameter {
|
||||
func(param *modelmgr.Parameter) *developer_api.ModelParameter {
|
||||
return parameterDo2To(param, locale)
|
||||
},
|
||||
)
|
||||
@@ -83,7 +89,7 @@ func modelDo2To(model *modelEntity.Model, locale i18n.Locale) (*developer_api.Mo
|
||||
Name: model.Name,
|
||||
ModelType: model.ID,
|
||||
ModelClass: mm.Protocol.TOModelClass(),
|
||||
ModelIcon: mm.IconURL,
|
||||
ModelIcon: model.IconURL,
|
||||
ModelInputPrice: 0,
|
||||
ModelOutputPrice: 0,
|
||||
ModelQuota: &developer_api.ModelQuota{
|
||||
@@ -102,19 +108,19 @@ func modelDo2To(model *modelEntity.Model, locale i18n.Locale) (*developer_api.Mo
|
||||
},
|
||||
ModelName: mm.Name,
|
||||
ModelClassName: mm.Protocol.TOModelClass().String(),
|
||||
IsOffline: mm.Status != modelmgrEntity.StatusInUse,
|
||||
IsOffline: mm.Status != modelmgr.StatusInUse,
|
||||
ModelParams: mps,
|
||||
ModelDesc: []*developer_api.ModelDescGroup{
|
||||
{
|
||||
GroupName: "Description",
|
||||
Desc: []string{model.Description},
|
||||
Desc: []string{model.Description.Read(locale)},
|
||||
},
|
||||
},
|
||||
FuncConfig: nil,
|
||||
EndpointName: nil,
|
||||
ModelTagList: nil,
|
||||
IsUpRequired: nil,
|
||||
ModelBriefDesc: mm.Description.Read(locale),
|
||||
ModelBriefDesc: model.Description.Read(locale),
|
||||
ModelSeries: &developer_api.ModelSeriesInfo{ // TODO: 替换为真实配置
|
||||
SeriesName: "热门模型",
|
||||
},
|
||||
@@ -122,16 +128,16 @@ func modelDo2To(model *modelEntity.Model, locale i18n.Locale) (*developer_api.Mo
|
||||
ModelAbility: &developer_api.ModelAbility{
|
||||
CotDisplay: ptr.Of(mm.Capability.Reasoning),
|
||||
FunctionCall: ptr.Of(mm.Capability.FunctionCall),
|
||||
ImageUnderstanding: ptr.Of(modalSet.Contains(modelmgrEntity.ModalImage)),
|
||||
VideoUnderstanding: ptr.Of(modalSet.Contains(modelmgrEntity.ModalVideo)),
|
||||
AudioUnderstanding: ptr.Of(modalSet.Contains(modelmgrEntity.ModalAudio)),
|
||||
ImageUnderstanding: ptr.Of(modalSet.Contains(modelmgr.ModalImage)),
|
||||
VideoUnderstanding: ptr.Of(modalSet.Contains(modelmgr.ModalVideo)),
|
||||
AudioUnderstanding: ptr.Of(modalSet.Contains(modelmgr.ModalAudio)),
|
||||
SupportMultiModal: ptr.Of(len(modalSet) > 1),
|
||||
PrefillResp: ptr.Of(mm.Capability.PrefillResponse),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parameterDo2To(param *modelmgrEntity.Parameter, locale i18n.Locale) *developer_api.ModelParameter {
|
||||
func parameterDo2To(param *modelmgr.Parameter, locale i18n.Locale) *developer_api.ModelParameter {
|
||||
if param == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -146,19 +152,19 @@ func parameterDo2To(param *modelmgrEntity.Parameter, locale i18n.Locale) *develo
|
||||
|
||||
var custom string
|
||||
var creative, balance, precise *string
|
||||
if val, ok := param.DefaultVal[modelmgrEntity.DefaultTypeDefault]; ok {
|
||||
if val, ok := param.DefaultVal[modelmgr.DefaultTypeDefault]; ok {
|
||||
custom = val
|
||||
}
|
||||
|
||||
if val, ok := param.DefaultVal[modelmgrEntity.DefaultTypeCreative]; ok {
|
||||
if val, ok := param.DefaultVal[modelmgr.DefaultTypeCreative]; ok {
|
||||
creative = ptr.Of(val)
|
||||
}
|
||||
|
||||
if val, ok := param.DefaultVal[modelmgrEntity.DefaultTypeBalance]; ok {
|
||||
if val, ok := param.DefaultVal[modelmgr.DefaultTypeBalance]; ok {
|
||||
balance = ptr.Of(val)
|
||||
}
|
||||
|
||||
if val, ok := param.DefaultVal[modelmgrEntity.DefaultTypePrecise]; ok {
|
||||
if val, ok := param.DefaultVal[modelmgr.DefaultTypePrecise]; ok {
|
||||
precise = ptr.Of(val)
|
||||
}
|
||||
|
||||
@@ -168,11 +174,11 @@ func parameterDo2To(param *modelmgrEntity.Parameter, locale i18n.Locale) *develo
|
||||
Desc: param.Desc.Read(locale),
|
||||
Type: func() developer_api.ModelParamType {
|
||||
switch param.Type {
|
||||
case modelmgrEntity.ValueTypeBoolean:
|
||||
case modelmgr.ValueTypeBoolean:
|
||||
return developer_api.ModelParamType_Boolean
|
||||
case modelmgrEntity.ValueTypeInt:
|
||||
case modelmgr.ValueTypeInt:
|
||||
return developer_api.ModelParamType_Int
|
||||
case modelmgrEntity.ValueTypeFloat:
|
||||
case modelmgr.ValueTypeFloat:
|
||||
return developer_api.ModelParamType_Float
|
||||
default:
|
||||
return developer_api.ModelParamType_String
|
||||
@@ -191,9 +197,9 @@ func parameterDo2To(param *modelmgrEntity.Parameter, locale i18n.Locale) *develo
|
||||
ParamClass: &developer_api.ModelParamClass{
|
||||
ClassID: func() int32 {
|
||||
switch param.Style.Widget {
|
||||
case modelmgrEntity.WidgetSlider:
|
||||
case modelmgr.WidgetSlider:
|
||||
return 1
|
||||
case modelmgrEntity.WidgetRadioButtons:
|
||||
case modelmgr.WidgetRadioButtons:
|
||||
return 2
|
||||
default:
|
||||
return 0
|
||||
|
||||
@@ -20,15 +20,14 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
modelmgrEntity "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
||||
intelligence "github.com/coze-dev/coze-studio/backend/api/model/intelligence/common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/modelmgr"
|
||||
searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
@@ -95,7 +94,7 @@ func (s *SingleAgentApplicationService) newDefaultSingleAgent(ctx context.Contex
|
||||
Plugin: []*bot_common.PluginInfo{},
|
||||
Knowledge: &bot_common.Knowledge{
|
||||
TopK: ptr.Of(int64(1)),
|
||||
MinScore: ptr.Of(float64(0.01)),
|
||||
MinScore: ptr.Of(0.01),
|
||||
SearchStrategy: ptr.Of(bot_common.SearchStrategy_SemanticSearch),
|
||||
RecallStrategy: &bot_common.RecallStrategy{
|
||||
UseNl2sql: ptr.Of(true),
|
||||
@@ -115,8 +114,8 @@ func (s *SingleAgentApplicationService) newDefaultSingleAgent(ctx context.Contex
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) defaultModelInfo(ctx context.Context) (*bot_common.ModelInfo, error) {
|
||||
modelResp, err := s.appContext.ModelMgrDomainSVC.ListModel(ctx, &modelmgr.ListModelRequest{
|
||||
Status: []modelmgrEntity.ModelEntityStatus{modelmgrEntity.ModelEntityStatusDefault, modelmgrEntity.ModelEntityStatusInUse},
|
||||
modelResp, err := s.appContext.ModelMgr.ListModel(ctx, &modelmgr.ListModelRequest{
|
||||
Status: []modelmgr.ModelStatus{modelmgr.StatusInUse},
|
||||
Limit: 1,
|
||||
Cursor: nil,
|
||||
})
|
||||
@@ -131,8 +130,8 @@ func (s *SingleAgentApplicationService) defaultModelInfo(ctx context.Context) (*
|
||||
dm := modelResp.ModelList[0]
|
||||
|
||||
var temperature *float64
|
||||
if tp, ok := dm.FindParameter(modelmgrEntity.Temperature); ok {
|
||||
t, err := tp.GetFloat(modelmgrEntity.DefaultTypeBalance)
|
||||
if tp, ok := dm.FindParameter(modelmgr.Temperature); ok {
|
||||
t, err := tp.GetFloat(modelmgr.DefaultTypeBalance)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -141,8 +140,8 @@ func (s *SingleAgentApplicationService) defaultModelInfo(ctx context.Context) (*
|
||||
}
|
||||
|
||||
var maxTokens *int32
|
||||
if tp, ok := dm.FindParameter(modelmgrEntity.MaxTokens); ok {
|
||||
t, err := tp.GetInt(modelmgrEntity.DefaultTypeBalance)
|
||||
if tp, ok := dm.FindParameter(modelmgr.MaxTokens); ok {
|
||||
t, err := tp.GetInt(modelmgr.DefaultTypeBalance)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -152,8 +151,8 @@ func (s *SingleAgentApplicationService) defaultModelInfo(ctx context.Context) (*
|
||||
}
|
||||
|
||||
var topP *float64
|
||||
if tp, ok := dm.FindParameter(modelmgrEntity.TopP); ok {
|
||||
t, err := tp.GetFloat(modelmgrEntity.DefaultTypeBalance)
|
||||
if tp, ok := dm.FindParameter(modelmgr.TopP); ok {
|
||||
t, err := tp.GetFloat(modelmgr.DefaultTypeBalance)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -161,8 +160,8 @@ func (s *SingleAgentApplicationService) defaultModelInfo(ctx context.Context) (*
|
||||
}
|
||||
|
||||
var topK *int32
|
||||
if tp, ok := dm.FindParameter(modelmgrEntity.TopK); ok {
|
||||
t, err := tp.GetInt(modelmgrEntity.DefaultTypeBalance)
|
||||
if tp, ok := dm.FindParameter(modelmgr.TopK); ok {
|
||||
t, err := tp.GetInt(modelmgr.DefaultTypeBalance)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -170,8 +169,8 @@ func (s *SingleAgentApplicationService) defaultModelInfo(ctx context.Context) (*
|
||||
}
|
||||
|
||||
var frequencyPenalty *float64
|
||||
if tp, ok := dm.FindParameter(modelmgrEntity.FrequencyPenalty); ok {
|
||||
t, err := tp.GetFloat(modelmgrEntity.DefaultTypeBalance)
|
||||
if tp, ok := dm.FindParameter(modelmgr.FrequencyPenalty); ok {
|
||||
t, err := tp.GetFloat(modelmgr.DefaultTypeBalance)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -179,8 +178,8 @@ func (s *SingleAgentApplicationService) defaultModelInfo(ctx context.Context) (*
|
||||
}
|
||||
|
||||
var presencePenalty *float64
|
||||
if tp, ok := dm.FindParameter(modelmgrEntity.PresencePenalty); ok {
|
||||
t, err := tp.GetFloat(modelmgrEntity.DefaultTypeBalance)
|
||||
if tp, ok := dm.FindParameter(modelmgr.PresencePenalty); ok {
|
||||
t, err := tp.GetFloat(modelmgr.DefaultTypeBalance)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -30,13 +30,12 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
knowledge "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/modelmgr"
|
||||
modelEntity "github.com/coze-dev/coze-studio/backend/domain/modelmgr/entity"
|
||||
pluginEntity "github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/service"
|
||||
shortcutCMDEntity "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity"
|
||||
workflowEntity "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
@@ -159,13 +158,13 @@ func (s *SingleAgentApplicationService) shortcutCMDDo2Vo(cmdDOs []*shortcutCMDEn
|
||||
})
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) fetchModelDetails(ctx context.Context, agentInfo *entity.SingleAgent) ([]*modelEntity.Model, error) {
|
||||
func (s *SingleAgentApplicationService) fetchModelDetails(ctx context.Context, agentInfo *entity.SingleAgent) ([]*modelmgr.Model, error) {
|
||||
if agentInfo.ModelInfo.ModelId == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
modelID := agentInfo.ModelInfo.GetModelId()
|
||||
modelInfos, err := s.appContext.ModelMgrDomainSVC.MGetModelByID(ctx, &modelmgr.MGetModelRequest{
|
||||
modelInfos, err := s.appContext.ModelMgr.MGetModelByID(ctx, &modelmgr.MGetModelRequest{
|
||||
IDs: []int64{modelID},
|
||||
})
|
||||
if err != nil {
|
||||
@@ -249,13 +248,13 @@ func (s *SingleAgentApplicationService) fetchWorkflowDetails(ctx context.Context
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func modelInfoDo2Vo(modelInfos []*modelEntity.Model) map[int64]*playground.ModelDetail {
|
||||
return slices.ToMap(modelInfos, func(e *modelEntity.Model) (int64, *playground.ModelDetail) {
|
||||
func modelInfoDo2Vo(modelInfos []*modelmgr.Model) map[int64]*playground.ModelDetail {
|
||||
return slices.ToMap(modelInfos, func(e *modelmgr.Model) (int64, *playground.ModelDetail) {
|
||||
return e.ID, toModelDetail(e)
|
||||
})
|
||||
}
|
||||
|
||||
func toModelDetail(m *modelEntity.Model) *playground.ModelDetail {
|
||||
func toModelDetail(m *modelmgr.Model) *playground.ModelDetail {
|
||||
mm := m.Meta
|
||||
|
||||
return &playground.ModelDetail{
|
||||
@@ -263,7 +262,7 @@ func toModelDetail(m *modelEntity.Model) *playground.ModelDetail {
|
||||
ModelName: ptr.Of(m.Meta.Name),
|
||||
ModelID: ptr.Of(m.ID),
|
||||
ModelFamily: ptr.Of(int64(mm.Protocol.TOModelClass())),
|
||||
ModelIconURL: ptr.Of(mm.IconURL),
|
||||
ModelIconURL: ptr.Of(m.IconURL),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -28,7 +28,6 @@ import (
|
||||
knowledge "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
|
||||
database "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||
variables "github.com/coze-dev/coze-studio/backend/domain/memory/variables/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/service"
|
||||
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
|
||||
shortcutCmd "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/service"
|
||||
@@ -36,6 +35,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/jsoncache"
|
||||
@@ -55,9 +55,9 @@ type ServiceComponents struct {
|
||||
ImageX imagex.ImageX
|
||||
EventBus search.ProjectEventBus
|
||||
CounterRepo repository.CounterRepository
|
||||
ModelMgr modelmgr.Manager
|
||||
|
||||
KnowledgeDomainSVC knowledge.Knowledge
|
||||
ModelMgrDomainSVC modelmgr.Manager
|
||||
PluginDomainSVC service.PluginService
|
||||
WorkflowDomainSVC workflow.Service
|
||||
UserDomainSVC user.User
|
||||
@@ -76,6 +76,7 @@ func InitService(c *ServiceComponents) (*SingleAgentApplicationService, error) {
|
||||
CounterRepo: repository.NewCounterRepo(c.Cache),
|
||||
CPStore: c.CPStore,
|
||||
ModelFactory: chatmodel.NewDefaultFactory(),
|
||||
ModelMgr: c.ModelMgr,
|
||||
}
|
||||
|
||||
singleAgentDomainSVC := singleagent.NewService(domainComponents)
|
||||
|
||||
@@ -30,7 +30,6 @@ import (
|
||||
knowledge "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
|
||||
dbservice "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||
variables "github.com/coze-dev/coze-studio/backend/domain/memory/variables/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/modelmgr"
|
||||
plugin "github.com/coze-dev/coze-studio/backend/domain/plugin/service"
|
||||
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
@@ -44,6 +43,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user