236 lines
		
	
	
		
			6.3 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			236 lines
		
	
	
		
			6.3 KiB
		
	
	
	
		
			Go
		
	
	
	
/*
 | 
						|
 * Copyright 2025 coze-dev Authors
 | 
						|
 *
 | 
						|
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
 * you may not use this file except in compliance with the License.
 | 
						|
 * You may obtain a copy of the License at
 | 
						|
 *
 | 
						|
 *     http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
 *
 | 
						|
 * Unless required by applicable law or agreed to in writing, software
 | 
						|
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
 * See the License for the specific language governing permissions and
 | 
						|
 * limitations under the License.
 | 
						|
 */
 | 
						|
 | 
						|
package appinfra
 | 
						|
 | 
						|
import (
 | 
						|
	"fmt"
 | 
						|
	"os"
 | 
						|
	"path/filepath"
 | 
						|
	"strconv"
 | 
						|
	"strings"
 | 
						|
 | 
						|
	"gopkg.in/yaml.v3"
 | 
						|
 | 
						|
	"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/infra/impl/modelmgr/static"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/logs"
 | 
						|
)
 | 
						|
 | 
						|
func initModelMgr() (modelmgr.Manager, error) {
 | 
						|
	wd, err := os.Getwd()
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	staticModel, err := initModelByTemplate(wd, "resources/conf/model")
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	envModel, err := initModelByEnv(wd, "resources/conf/model/template")
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	all := append(staticModel, envModel...)
 | 
						|
	if err := fillModelContent(all); err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	mgr, err := static.NewModelMgr(all)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return mgr, nil
 | 
						|
}
 | 
						|
 | 
						|
func initModelByTemplate(wd, configPath string) ([]*modelmgr.Model, error) {
 | 
						|
	configRoot := filepath.Join(wd, configPath)
 | 
						|
	staticModel, err := readDirYaml[modelmgr.Model](configRoot)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return staticModel, nil
 | 
						|
}
 | 
						|
 | 
						|
func initModelByEnv(wd, templatePath string) (modelEntities []*modelmgr.Model, err error) {
 | 
						|
	entityRoot := filepath.Join(wd, templatePath)
 | 
						|
 | 
						|
	for i := -1; i < 1000; i++ {
 | 
						|
		rawProtocol := os.Getenv(concatEnvKey(modelProtocolPrefix, i))
 | 
						|
		if rawProtocol == "" {
 | 
						|
			if i < 0 {
 | 
						|
				continue
 | 
						|
			} else {
 | 
						|
				break
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		protocol := chatmodel.Protocol(rawProtocol)
 | 
						|
		info, valid := getModelEnv(i)
 | 
						|
		if !valid {
 | 
						|
			break
 | 
						|
		}
 | 
						|
 | 
						|
		mapping, found := modelMapping[protocol]
 | 
						|
		if !found {
 | 
						|
			return nil, fmt.Errorf("[initModelByEnv] unsupport protocol: %s", rawProtocol)
 | 
						|
		}
 | 
						|
 | 
						|
		switch protocol {
 | 
						|
		case chatmodel.ProtocolArk:
 | 
						|
			fileSuffix, foundTemplate := mapping[info.modelName]
 | 
						|
			if !foundTemplate {
 | 
						|
				logs.Warnf("[initModelByEnv] unsupport model=%s, using default config", info.modelName)
 | 
						|
			}
 | 
						|
			modelEntity, err := readYaml[modelmgr.Model](filepath.Join(entityRoot, concatTemplateFileName("model_template_ark", fileSuffix)))
 | 
						|
			if err != nil {
 | 
						|
				return nil, err
 | 
						|
			}
 | 
						|
			id, err := strconv.ParseInt(info.id, 10, 64)
 | 
						|
			if err != nil {
 | 
						|
				return nil, err
 | 
						|
			}
 | 
						|
 | 
						|
			modelEntity.ID = id
 | 
						|
			if !foundTemplate {
 | 
						|
				modelEntity.Name = info.modelName
 | 
						|
			}
 | 
						|
			modelEntity.Meta.ConnConfig.Model = info.modelID
 | 
						|
			modelEntity.Meta.ConnConfig.APIKey = info.apiKey
 | 
						|
			modelEntity.Meta.ConnConfig.BaseURL = info.baseURL
 | 
						|
 | 
						|
			modelEntities = append(modelEntities, modelEntity)
 | 
						|
 | 
						|
		default:
 | 
						|
			return nil, fmt.Errorf("[initModelByEnv] unsupport protocol: %s", rawProtocol)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return modelEntities, nil
 | 
						|
}
 | 
						|
 | 
						|
// envModelInfo holds the raw, unparsed settings of one env-configured model
// slot as read by getModelEnv. The "_<n>" suffix is absent for the
// unsuffixed slot.
type envModelInfo struct {
	id        string // MODEL_OPENCOZE_ID[_<n>]; parsed as int64 by initModelByEnv
	modelName string // MODEL_NAME[_<n>]; selects the template via modelMapping
	modelID   string // MODEL_ID[_<n>]; copied into the connection config's Model
	apiKey    string // MODEL_API_KEY[_<n>]; copied into the connection config
	baseURL   string // MODEL_BASE_URL[_<n>]; optional, copied into the connection config
}
 | 
						|
 | 
						|
func getModelEnv(idx int) (info envModelInfo, valid bool) {
 | 
						|
	info.id = os.Getenv(concatEnvKey(modelOpenCozeIDPrefix, idx))
 | 
						|
	info.modelName = os.Getenv(concatEnvKey(modelNamePrefix, idx))
 | 
						|
	info.modelID = os.Getenv(concatEnvKey(modelIDPrefix, idx))
 | 
						|
	info.apiKey = os.Getenv(concatEnvKey(modelApiKeyPrefix, idx))
 | 
						|
	info.baseURL = os.Getenv(concatEnvKey(modelBaseURLPrefix, idx))
 | 
						|
	valid = info.modelName != "" && info.modelID != "" && info.apiKey != ""
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
func readDirYaml[T any](dir string) ([]*T, error) {
 | 
						|
	des, err := os.ReadDir(dir)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	resp := make([]*T, 0, len(des))
 | 
						|
	for _, file := range des {
 | 
						|
		if file.IsDir() {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
		if strings.HasSuffix(file.Name(), ".yaml") || strings.HasSuffix(file.Name(), ".yml") {
 | 
						|
			filePath := filepath.Join(dir, file.Name())
 | 
						|
			data, err := os.ReadFile(filePath)
 | 
						|
			if err != nil {
 | 
						|
				return nil, err
 | 
						|
			}
 | 
						|
			var content T
 | 
						|
			if err := yaml.Unmarshal(data, &content); err != nil {
 | 
						|
				return nil, err
 | 
						|
			}
 | 
						|
			resp = append(resp, &content)
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return resp, nil
 | 
						|
}
 | 
						|
 | 
						|
func readYaml[T any](fPath string) (*T, error) {
 | 
						|
	data, err := os.ReadFile(fPath)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	var content T
 | 
						|
	if err := yaml.Unmarshal(data, &content); err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	return &content, nil
 | 
						|
}
 | 
						|
 | 
						|
// concatEnvKey returns the environment variable name for slot idx:
// "<prefix>_<idx>" for idx >= 0, or the bare prefix for a negative idx.
func concatEnvKey(prefix string, idx int) string {
	if idx >= 0 {
		return fmt.Sprintf("%s_%d", prefix, idx)
	}
	// Negative index: the unsuffixed, single-model form.
	return prefix
}
 | 
						|
 | 
						|
// concatTemplateFileName builds a template file name: "<prefix>_<suffix>.yaml",
// or "<prefix>.yaml" when suffix is empty (the protocol's default template).
func concatTemplateFileName(prefix, suffix string) string {
	base := prefix
	if suffix != "" {
		base += "_" + suffix
	}
	return base + ".yaml"
}
 | 
						|
 | 
						|
// Environment variable prefixes for env-configured models. concatEnvKey turns
// each prefix into a concrete key: "<prefix>_<n>" for indexed slots, or the
// bare prefix for the unsuffixed slot.
const (
	// Model protocol; must be a key of modelMapping.
	modelProtocolPrefix = "MODEL_PROTOCOL"

	// OpenCoze model ID, parsed as a base-10 int64.
	modelOpenCozeIDPrefix = "MODEL_OPENCOZE_ID"

	// Model name; selects the template and is required.
	modelNamePrefix = "MODEL_NAME"

	// Model identifier placed in the connection config; required.
	modelIDPrefix = "MODEL_ID"

	// API key placed in the connection config; required.
	modelApiKeyPrefix = "MODEL_API_KEY"

	// Base URL placed in the connection config; optional.
	modelBaseURLPrefix = "MODEL_BASE_URL"
)
 | 
						|
 | 
						|
var modelMapping = map[chatmodel.Protocol]map[string]string{
 | 
						|
	chatmodel.ProtocolArk: {
 | 
						|
		"doubao-seed-1.6":                "doubao-seed-1.6",
 | 
						|
		"doubao-seed-1.6-flash":          "doubao-seed-1.6-flash",
 | 
						|
		"doubao-seed-1.6-thinking":       "doubao-seed-1.6-thinking",
 | 
						|
		"doubao-1.5-thinking-vision-pro": "doubao-1.5-thinking-vision-pro",
 | 
						|
		"doubao-1.5-thinking-pro":        "doubao-1.5-thinking-pro",
 | 
						|
		"doubao-1.5-vision-pro":          "doubao-1.5-vision-pro",
 | 
						|
		"doubao-1.5-vision-lite":         "doubao-1.5-vision-lite",
 | 
						|
		"doubao-1.5-pro-32k":             "doubao-1.5-pro-32k",
 | 
						|
		"doubao-1.5-pro-256k":            "doubao-1.5-pro-256k",
 | 
						|
		"doubao-1.5-lite":                "doubao-1.5-lite",
 | 
						|
		"deepseek-r1":                    "volc_deepseek-r1",
 | 
						|
		"deepseek-v3":                    "volc_deepseek-v3",
 | 
						|
	},
 | 
						|
}
 | 
						|
 | 
						|
func fillModelContent(items []*modelmgr.Model) error {
 | 
						|
	for i := range items {
 | 
						|
		item := items[i]
 | 
						|
		if item.Meta.Status == modelmgr.StatusDefault {
 | 
						|
			item.Meta.Status = modelmgr.StatusInUse
 | 
						|
		}
 | 
						|
 | 
						|
		if item.IconURI == "" && item.IconURL == "" {
 | 
						|
			return fmt.Errorf("missing icon URI or icon URL, id=%d", item.ID)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 |