feat: manually mirror opencoze's code from bytedance

Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
fanlv
2025-07-20 17:36:12 +08:00
commit 890153324f
14811 changed files with 1923430 additions and 0 deletions

View File

@@ -0,0 +1,198 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package modelmgr
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"gopkg.in/yaml.v3"
"gorm.io/gorm"
crossmodelmgr "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
"github.com/coze-dev/coze-studio/backend/domain/modelmgr"
"github.com/coze-dev/coze-studio/backend/domain/modelmgr/entity"
"github.com/coze-dev/coze-studio/backend/domain/modelmgr/service"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/infra/impl/idgen"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
func InitService(db *gorm.DB, idgen idgen.IDGenerator, oss storage.Storage) (*ModelmgrApplicationService, error) {
svc := service.NewModelManager(db, idgen, oss)
if err := loadStaticModelConfig(svc, oss); err != nil {
return nil, err
}
ModelmgrApplicationSVC.DomainSVC = svc
return ModelmgrApplicationSVC, nil
}
func loadStaticModelConfig(svc modelmgr.Manager, oss storage.Storage) error {
ctx := context.Background()
id2Meta := make(map[int64]*entity.ModelMeta)
var cursor *string
for {
req := &modelmgr.ListModelMetaRequest{
Status: []entity.ModelMetaStatus{
crossmodelmgr.StatusInUse,
crossmodelmgr.StatusPending,
crossmodelmgr.StatusDeleted,
},
Limit: 100,
Cursor: cursor,
}
listMetaResp, err := svc.ListModelMeta(ctx, req)
if err != nil {
return err
}
for _, item := range listMetaResp.ModelMetaList {
cpItem := item
id2Meta[cpItem.ID] = cpItem
}
if !listMetaResp.HasMore {
break
}
cursor = listMetaResp.NextCursor
}
root, err := os.Getwd()
if err != nil {
return err
}
envModelMeta, envModelEntity, err := initModelByEnv(root, "resources/conf/model/template")
if err != nil {
return err
}
filePath := filepath.Join(root, "resources/conf/model/meta")
staticModelMeta, err := readDirYaml[crossmodelmgr.ModelMeta](filePath)
if err != nil {
return err
}
staticModelMeta = append(staticModelMeta, envModelMeta...)
for _, modelMeta := range staticModelMeta {
if _, found := id2Meta[modelMeta.ID]; !found {
if modelMeta.IconURI == "" && modelMeta.IconURL == "" {
return fmt.Errorf("missing icon URI or icon URL, id=%d", modelMeta.ID)
} else if modelMeta.IconURL != "" {
// do nothing
} else if modelMeta.IconURI != "" {
// try local path
base := filepath.Base(modelMeta.IconURI)
iconPath := filepath.Join("resources/conf/model/icon", base)
if _, err = os.Stat(iconPath); err == nil {
// try upload icon
icon, err := os.ReadFile(iconPath)
if err != nil {
return err
}
key := fmt.Sprintf("icon_%s_%d", base, time.Now().Second())
if err := oss.PutObject(ctx, key, icon); err != nil {
return err
}
modelMeta.IconURI = key
} else if errors.Is(err, os.ErrNotExist) {
// try to get object from uri
if _, err := oss.GetObject(ctx, modelMeta.IconURI); err != nil {
return err
}
} else {
return err
}
}
newMeta, err := svc.CreateModelMeta(ctx, modelMeta)
if err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
logs.Infof("[loadStaticModelConfig] model meta conflict for id=%d, skip", newMeta.ID)
}
return err
} else {
logs.Infof("[loadStaticModelConfig] model meta create success, id=%d", newMeta.ID)
}
id2Meta[newMeta.ID] = newMeta
} else {
logs.Infof("[loadStaticModelConfig] model meta founded, skip create, id=%d", modelMeta.ID)
}
}
filePath = filepath.Join(root, "resources/conf/model/entity")
staticModel, err := readDirYaml[crossmodelmgr.Model](filePath)
if err != nil {
return err
}
staticModel = append(staticModel, envModelEntity...)
for _, modelEntity := range staticModel {
curModelEntities, err := svc.MGetModelByID(ctx, &modelmgr.MGetModelRequest{IDs: []int64{modelEntity.ID}})
if err != nil {
return err
}
if len(curModelEntities) > 0 {
logs.Infof("[loadStaticModelConfig] model entity founded, skip create, id=%d", modelEntity.ID)
continue
}
meta, found := id2Meta[modelEntity.Meta.ID]
if !found {
return fmt.Errorf("model meta not found for id=%d, model_id=%d", modelEntity.Meta.ID, modelEntity.ID)
}
modelEntity.Meta = *meta
if _, err = svc.CreateModel(ctx, &entity.Model{Model: modelEntity}); err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
logs.Infof("[loadStaticModelConfig] model entity conflict for id=%d, skip", modelEntity.ID)
}
return err
} else {
logs.Infof("[loadStaticModelConfig] model entity create success, id=%d", modelEntity.ID)
}
}
return nil
}
func readDirYaml[T any](dir string) ([]*T, error) {
des, err := os.ReadDir(dir)
if err != nil {
return nil, err
}
resp := make([]*T, 0, len(des))
for _, file := range des {
if file.IsDir() {
continue
}
if strings.HasSuffix(file.Name(), ".yaml") || strings.HasSuffix(file.Name(), ".yml") {
filePath := filepath.Join(dir, file.Name())
data, err := os.ReadFile(filePath)
if err != nil {
return nil, err
}
var content T
if err := yaml.Unmarshal(data, &content); err != nil {
return nil, err
}
resp = append(resp, &content)
}
}
return resp, nil
}

View File

@@ -0,0 +1,148 @@
package modelmgr
import (
"fmt"
"os"
"path/filepath"
"strconv"
"gopkg.in/yaml.v3"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
func initModelByEnv(wd, templatePath string) (metaSlice []*modelmgr.ModelMeta, entitySlice []*modelmgr.Model, err error) {
metaRoot := filepath.Join(wd, templatePath, "meta")
entityRoot := filepath.Join(wd, templatePath, "entity")
for i := -1; i < 1000; i++ {
rawProtocol := os.Getenv(concatEnvKey(modelProtocolPrefix, i))
if rawProtocol == "" {
if i < 0 {
continue
} else {
break
}
}
protocol := chatmodel.Protocol(rawProtocol)
info, valid := getModelEnv(i)
if !valid {
break
}
mapping, found := modelMapping[protocol]
if !found {
return nil, nil, fmt.Errorf("[initModelByEnv] unsupport protocol: %s", rawProtocol)
}
switch protocol {
case chatmodel.ProtocolArk:
fileSuffix, foundTemplate := mapping[info.modelName]
if !foundTemplate {
logs.Warnf("[initModelByEnv] unsupport model=%s, using default config", info.modelName)
}
modelMeta, err := readYaml[modelmgr.ModelMeta](filepath.Join(metaRoot, concatTemplateFileName("model_meta_template_ark", fileSuffix)))
if err != nil {
return nil, nil, err
}
modelEntity, err := readYaml[modelmgr.Model](filepath.Join(entityRoot, concatTemplateFileName("model_entity_template_ark", fileSuffix)))
if err != nil {
return nil, nil, err
}
id, err := strconv.ParseInt(info.id, 10, 64)
if err != nil {
return nil, nil, err
}
// meta 和 entity 用一个 id有概率冲突
modelMeta.ID = id
modelMeta.ConnConfig.Model = info.modelID
modelMeta.ConnConfig.APIKey = info.apiKey
if info.baseURL != "" {
modelMeta.ConnConfig.BaseURL = info.baseURL
}
modelEntity.ID = id
modelEntity.Meta.ID = id
if !foundTemplate {
modelEntity.Name = info.modelName
}
metaSlice = append(metaSlice, modelMeta)
entitySlice = append(entitySlice, modelEntity)
default:
return nil, nil, fmt.Errorf("[initModelByEnv] unsupport protocol: %s", rawProtocol)
}
}
return metaSlice, entitySlice, nil
}
type envModelInfo struct {
id, modelName, modelID, apiKey, baseURL string
}
func getModelEnv(idx int) (info envModelInfo, valid bool) {
info.id = os.Getenv(concatEnvKey(modelOpenCozeIDPrefix, idx))
info.modelName = os.Getenv(concatEnvKey(modelNamePrefix, idx))
info.modelID = os.Getenv(concatEnvKey(modelIDPrefix, idx))
info.apiKey = os.Getenv(concatEnvKey(modelApiKeyPrefix, idx))
info.baseURL = os.Getenv(concatEnvKey(modelBaseURLPrefix, idx))
valid = info.modelName != "" && info.modelID != "" && info.apiKey != ""
return
}
func readYaml[T any](fPath string) (*T, error) {
data, err := os.ReadFile(fPath)
if err != nil {
return nil, err
}
var content T
if err := yaml.Unmarshal(data, &content); err != nil {
return nil, err
}
return &content, nil
}
func concatEnvKey(prefix string, idx int) string {
if idx < 0 {
return prefix
}
return fmt.Sprintf("%s_%d", prefix, idx)
}
func concatTemplateFileName(prefix, suffix string) string {
if suffix == "" {
return prefix + ".yaml"
}
return prefix + "_" + suffix + ".yaml"
}
const (
modelProtocolPrefix = "MODEL_PROTOCOL" // model protocol
modelOpenCozeIDPrefix = "MODEL_OPENCOZE_ID" // opencoze model id
modelNamePrefix = "MODEL_NAME" // model name,
modelIDPrefix = "MODEL_ID" // model in conn config
modelApiKeyPrefix = "MODEL_API_KEY" // model api key
modelBaseURLPrefix = "MODEL_BASE_URL" // model base url
)
var modelMapping = map[chatmodel.Protocol]map[string]string{
chatmodel.ProtocolArk: {
"doubao-seed-1.6": "doubao-seed-1.6",
"doubao-seed-1.6-flash": "doubao-seed-1.6-flash",
"doubao-seed-1.6-thinking": "doubao-seed-1.6-thinking",
"doubao-1.5-thinking-vision-pro": "doubao-1.5-thinking-vision-pro",
"doubao-1.5-thinking-pro": "doubao-1.5-thinking-pro",
"doubao-1.5-vision-pro": "doubao-1.5-vision-pro",
"doubao-1.5-vision-lite": "doubao-1.5-vision-lite",
"doubao-1.5-pro-32k": "doubao-1.5-pro-32k",
"doubao-1.5-pro-256k": "doubao-1.5-pro-256k",
"doubao-1.5-lite": "doubao-1.5-lite",
"deepseek-r1": "volc_deepseek-r1",
"deepseek-v3": "volc_deepseek-v3",
},
}

View File

@@ -0,0 +1,31 @@
package modelmgr
import (
"fmt"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
)
func TestInitByEnv(t *testing.T) {
i := 0
for k := range modelMapping[chatmodel.ProtocolArk] {
_ = os.Setenv(concatEnvKey(modelProtocolPrefix, i), "ark")
_ = os.Setenv(concatEnvKey(modelOpenCozeIDPrefix, i), fmt.Sprintf("%d", 45678+i))
_ = os.Setenv(concatEnvKey(modelNamePrefix, i), k)
_ = os.Setenv(concatEnvKey(modelIDPrefix, i), k)
_ = os.Setenv(concatEnvKey(modelApiKeyPrefix, i), "mock_api_key")
i++
}
wd, err := os.Getwd()
assert.NoError(t, err)
ms, es, err := initModelByEnv(wd, "../../conf/model/template")
assert.NoError(t, err)
assert.Len(t, ms, len(modelMapping[chatmodel.ProtocolArk]))
assert.Len(t, es, len(modelMapping[chatmodel.ProtocolArk]))
}

View File

@@ -0,0 +1,205 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package modelmgr
import (
"context"
modelmgrEntity "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
"github.com/coze-dev/coze-studio/backend/domain/modelmgr"
modelEntity "github.com/coze-dev/coze-studio/backend/domain/modelmgr/entity"
"github.com/coze-dev/coze-studio/backend/pkg/i18n"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
type ModelmgrApplicationService struct {
DomainSVC modelmgr.Manager
}
var ModelmgrApplicationSVC = &ModelmgrApplicationService{}
func (m *ModelmgrApplicationService) GetModelList(ctx context.Context, req *developer_api.GetTypeListRequest) (
resp *developer_api.GetTypeListResponse, err error,
) {
// 一般不太可能同时配置这么多模型
const modelMaxLimit = 300
modelResp, err := m.DomainSVC.ListModel(ctx, &modelmgr.ListModelRequest{
Limit: modelMaxLimit,
Cursor: nil,
})
if err != nil {
return nil, err
}
locale := i18n.GetLocale(ctx)
modelList, err := slices.TransformWithErrorCheck(modelResp.ModelList, func(m *modelEntity.Model) (*developer_api.Model, error) {
logs.CtxInfof(ctx, "ChatModel DefaultParameters: %v", m.DefaultParameters)
return modelDo2To(m, locale)
})
if err != nil {
return nil, err
}
return &developer_api.GetTypeListResponse{
Code: 0,
Msg: "success",
Data: &developer_api.GetTypeListData{
ModelList: modelList,
},
}, nil
}
func modelDo2To(model *modelEntity.Model, locale i18n.Locale) (*developer_api.Model, error) {
mm := model.Meta
mps := slices.Transform(model.DefaultParameters,
func(param *modelmgrEntity.Parameter) *developer_api.ModelParameter {
return parameterDo2To(param, locale)
},
)
modalSet := sets.FromSlice(mm.Capability.InputModal)
return &developer_api.Model{
Name: model.Name,
ModelType: model.ID,
ModelClass: mm.Protocol.TOModelClass(),
ModelIcon: mm.IconURL,
ModelInputPrice: 0,
ModelOutputPrice: 0,
ModelQuota: &developer_api.ModelQuota{
TokenLimit: int32(mm.Capability.MaxTokens),
TokenResp: int32(mm.Capability.OutputTokens),
// TokenSystem: 0,
// TokenUserIn: 0,
// TokenToolsIn: 0,
// TokenToolsOut: 0,
// TokenData: 0,
// TokenHistory: 0,
// TokenCutSwitch: false,
PriceIn: 0,
PriceOut: 0,
SystemPromptLimit: nil,
},
ModelName: mm.Name,
ModelClassName: mm.Protocol.TOModelClass().String(),
IsOffline: mm.Status != modelmgrEntity.StatusInUse,
ModelParams: mps,
ModelDesc: []*developer_api.ModelDescGroup{
{
GroupName: "Description",
Desc: []string{model.Description},
},
},
FuncConfig: nil,
EndpointName: nil,
ModelTagList: nil,
IsUpRequired: nil,
ModelBriefDesc: mm.Description.Read(locale),
ModelSeries: &developer_api.ModelSeriesInfo{ // TODO: 替换为真实配置
SeriesName: "热门模型",
},
ModelStatusDetails: nil,
ModelAbility: &developer_api.ModelAbility{
CotDisplay: ptr.Of(mm.Capability.Reasoning),
FunctionCall: ptr.Of(mm.Capability.FunctionCall),
ImageUnderstanding: ptr.Of(modalSet.Contains(modelmgrEntity.ModalImage)),
VideoUnderstanding: ptr.Of(modalSet.Contains(modelmgrEntity.ModalVideo)),
AudioUnderstanding: ptr.Of(modalSet.Contains(modelmgrEntity.ModalAudio)),
SupportMultiModal: ptr.Of(len(modalSet) > 1),
PrefillResp: ptr.Of(mm.Capability.PrefillResponse),
},
}, nil
}
func parameterDo2To(param *modelmgrEntity.Parameter, locale i18n.Locale) *developer_api.ModelParameter {
if param == nil {
return nil
}
apiOptions := make([]*developer_api.Option, 0, len(param.Options))
for _, opt := range param.Options {
apiOptions = append(apiOptions, &developer_api.Option{
Label: opt.Label,
Value: opt.Value,
})
}
var custom string
var creative, balance, precise *string
if val, ok := param.DefaultVal[modelmgrEntity.DefaultTypeDefault]; ok {
custom = val
}
if val, ok := param.DefaultVal[modelmgrEntity.DefaultTypeCreative]; ok {
creative = ptr.Of(val)
}
if val, ok := param.DefaultVal[modelmgrEntity.DefaultTypeBalance]; ok {
balance = ptr.Of(val)
}
if val, ok := param.DefaultVal[modelmgrEntity.DefaultTypePrecise]; ok {
precise = ptr.Of(val)
}
return &developer_api.ModelParameter{
Name: string(param.Name),
Label: param.Label.Read(locale),
Desc: param.Desc.Read(locale),
Type: func() developer_api.ModelParamType {
switch param.Type {
case modelmgrEntity.ValueTypeBoolean:
return developer_api.ModelParamType_Boolean
case modelmgrEntity.ValueTypeInt:
return developer_api.ModelParamType_Int
case modelmgrEntity.ValueTypeFloat:
return developer_api.ModelParamType_Float
default:
return developer_api.ModelParamType_String
}
}(),
Min: param.Min,
Max: param.Max,
Precision: int32(param.Precision),
DefaultVal: &developer_api.ModelParamDefaultValue{
DefaultVal: custom,
Creative: creative,
Balance: balance,
Precise: precise,
},
Options: apiOptions,
ParamClass: &developer_api.ModelParamClass{
ClassID: func() int32 {
switch param.Style.Widget {
case modelmgrEntity.WidgetSlider:
return 1
case modelmgrEntity.WidgetRadioButtons:
return 2
default:
return 0
}
}(),
Label: param.Style.Label.Read(locale),
},
}
}