coze-studio/backend/infra/impl/chatmodel/default_factory.go

/*
 * Copyright 2025 coze-dev Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package chatmodel

import (
	"context"
	"fmt"

	"github.com/cloudwego/eino-ext/components/model/ark"
	"github.com/cloudwego/eino-ext/components/model/claude"
	"github.com/cloudwego/eino-ext/components/model/deepseek"
	"github.com/cloudwego/eino-ext/components/model/gemini"
	"github.com/cloudwego/eino-ext/components/model/ollama"
	"github.com/cloudwego/eino-ext/components/model/openai"
	"github.com/cloudwego/eino-ext/components/model/qwen"
	"github.com/ollama/ollama/api"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
	"google.golang.org/genai"

	"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
	"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
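
// Builder constructs a ToolCallingChatModel for a single protocol from the
// shared chatmodel.Config.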
type Builder func(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error)
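
// NewDefaultFactory returns a Factory that uses only the built-in protocol builders.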
func NewDefaultFactory() chatmodel.Factory {
	return NewFactory(nil)
}
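
// NewFactory returns a Factory backed by the built-in protocol builders; any
// entries in customFactory override or extend that mapping per protocol.
//
// A minimal usage sketch (the key and model name below are illustrative
// placeholders, not defaults):
//
//	factory := NewFactory(nil)
//	cm, err := factory.CreateChatModel(ctx, chatmodel.ProtocolOpenAI, &chatmodel.Config{
//		APIKey: "sk-...",
//		Model:  "gpt-4o",
//	})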
func NewFactory(customFactory map[chatmodel.Protocol]Builder) chatmodel.Factory {
	protocol2Builder := map[chatmodel.Protocol]Builder{
		chatmodel.ProtocolOpenAI:   openAIBuilder,
		chatmodel.ProtocolClaude:   claudeBuilder,
		chatmodel.ProtocolDeepseek: deepseekBuilder,
		chatmodel.ProtocolArk:      arkBuilder,
		chatmodel.ProtocolGemini:   geminiBuilder,
		chatmodel.ProtocolOllama:   ollamaBuilder,
		chatmodel.ProtocolQwen:     qwenBuilder,
		chatmodel.ProtocolErnie:    nil,
	}

	for p := range customFactory {
		protocol2Builder[p] = customFactory[p]
	}

	return &defaultFactory{protocol2Builder: protocol2Builder}
}
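
// defaultFactory dispatches chat model construction to the Builder registered
// for each protocol.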
type defaultFactory struct {
	protocol2Builder map[chatmodel.Protocol]Builder
}
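
// CreateChatModel builds a chat model for the given protocol by delegating to
// the registered Builder with the provided config.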
func (f *defaultFactory) CreateChatModel(ctx context.Context, protocol chatmodel.Protocol, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
	if config == nil {
		return nil, fmt.Errorf("[CreateChatModel] config not provided")
	}

	builder, found := f.protocol2Builder[protocol]
	if !found || builder == nil { // a nil entry (e.g. ProtocolErnie) means no builder is implemented for the protocol
		return nil, fmt.Errorf("[CreateChatModel] protocol not supported, protocol=%s", protocol)
	}

	return builder(ctx, config)
}
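
// SupportProtocol reports whether a builder entry is registered for the given protocol.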
func (f *defaultFactory) SupportProtocol(protocol chatmodel.Protocol) bool {
	_, found := f.protocol2Builder[protocol]
	return found
}
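
// openAIBuilder maps the generic config onto the eino-ext OpenAI chat model,
// applying the Azure-specific options when config.OpenAI is set.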
func openAIBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
	cfg := &openai.ChatModelConfig{
		APIKey:           config.APIKey,
		Timeout:          config.Timeout,
		BaseURL:          config.BaseURL,
		Model:            config.Model,
		MaxTokens:        config.MaxTokens,
		Temperature:      config.Temperature,
		TopP:             config.TopP,
		Stop:             config.Stop,
		PresencePenalty:  config.PresencePenalty,
		FrequencyPenalty: config.FrequencyPenalty,
	}

	if config.OpenAI != nil {
		cfg.ByAzure = config.OpenAI.ByAzure
		cfg.APIVersion = config.OpenAI.APIVersion
		cfg.ResponseFormat = config.OpenAI.ResponseFormat
	}

	return openai.NewChatModel(ctx, cfg)
}
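
// claudeBuilder maps the generic config onto the eino-ext Claude chat model,
// including Bedrock credentials and extended-thinking settings when present.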
func claudeBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
	cfg := &claude.Config{
		APIKey:        config.APIKey,
		Model:         config.Model,
		Temperature:   config.Temperature,
		TopP:          config.TopP,
		StopSequences: config.Stop,
	}

	if config.BaseURL != "" {
		cfg.BaseURL = &config.BaseURL
	}
	if config.MaxTokens != nil {
		cfg.MaxTokens = *config.MaxTokens
	}
	if config.TopK != nil {
		cfg.TopK = ptr.Of(int32(*config.TopK))
	}

	if config.Claude != nil {
		cfg.ByBedrock = config.Claude.ByBedrock
		cfg.AccessKey = config.Claude.AccessKey
		cfg.SecretAccessKey = config.Claude.SecretAccessKey
		cfg.SessionToken = config.Claude.SessionToken
		cfg.Region = config.Claude.Region
	}

	if config.EnableThinking != nil {
		cfg.Thinking = &claude.Thinking{
			Enable: ptr.From(config.EnableThinking),
		}
		if config.Claude != nil && config.Claude.BudgetTokens != nil {
			cfg.Thinking.BudgetTokens = ptr.From(config.Claude.BudgetTokens)
		}
	}

	return claude.NewChatModel(ctx, cfg)
}
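
// deepseekBuilder maps the generic config onto the eino-ext DeepSeek chat
// model, dereferencing the optional sampling parameters it takes by value.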
func deepseekBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
	cfg := &deepseek.ChatModelConfig{
		APIKey:  config.APIKey,
		Timeout: config.Timeout,
		BaseURL: config.BaseURL,
		Model:   config.Model,
		Stop:    config.Stop,
	}

	if config.Temperature != nil {
		cfg.Temperature = *config.Temperature
	}
	if config.FrequencyPenalty != nil {
		cfg.FrequencyPenalty = *config.FrequencyPenalty
	}
	if config.PresencePenalty != nil {
		cfg.PresencePenalty = *config.PresencePenalty
	}
	if config.MaxTokens != nil {
		cfg.MaxTokens = *config.MaxTokens
	}
	if config.TopP != nil {
		cfg.TopP = *config.TopP
	}

	if config.Deepseek != nil {
		cfg.ResponseFormatType = config.Deepseek.ResponseFormatType
	}

	return deepseek.NewChatModel(ctx, cfg)
}
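
// arkBuilder maps the generic config onto the eino-ext Ark (Volcengine) chat
// model, translating EnableThinking into the Ark thinking type.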
func arkBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
	cfg := &ark.ChatModelConfig{
		BaseURL:          config.BaseURL,
		APIKey:           config.APIKey,
		Model:            config.Model,
		MaxTokens:        config.MaxTokens,
		Temperature:      config.Temperature,
		TopP:             config.TopP,
		Stop:             config.Stop,
		FrequencyPenalty: config.FrequencyPenalty,
		PresencePenalty:  config.PresencePenalty,
	}

	if config.Timeout != 0 {
		cfg.Timeout = &config.Timeout
	}

	if config.Ark != nil {
		cfg.Region = config.Ark.Region
		cfg.AccessKey = config.Ark.AccessKey
		cfg.SecretKey = config.Ark.SecretKey
		cfg.RetryTimes = config.Ark.RetryTimes
		cfg.CustomHeader = config.Ark.CustomHeader
	}

	if config.EnableThinking != nil {
		var arkThinkingType model.ThinkingType
		if ptr.From(config.EnableThinking) {
			arkThinkingType = model.ThinkingTypeEnabled
		} else {
			arkThinkingType = model.ThinkingTypeDisabled
		}
		cfg.Thinking = &model.Thinking{Type: arkThinkingType}
	}

	return ark.NewChatModel(ctx, cfg)
}
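
// ollamaBuilder maps the generic config onto the eino-ext Ollama chat model,
// passing the sampling parameters through api.Options.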
func ollamaBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
	cfg := &ollama.ChatModelConfig{
		BaseURL:    config.BaseURL,
		Timeout:    config.Timeout,
		HTTPClient: nil,
		Model:      config.Model,
		Format:     nil,
		KeepAlive:  nil,
		Options: &api.Options{
			TopK:             ptr.From(config.TopK),
			TopP:             ptr.From(config.TopP),
			Temperature:      ptr.From(config.Temperature),
			PresencePenalty:  ptr.From(config.PresencePenalty),
			FrequencyPenalty: ptr.From(config.FrequencyPenalty),
			Stop:             config.Stop,
		},
	}

	if config.EnableThinking != nil {
		cfg.Thinking = config.EnableThinking
	}

	return ollama.NewChatModel(ctx, cfg)
}
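
// qwenBuilder maps the generic config onto the eino-ext Qwen chat model.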
func qwenBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
	cfg := &qwen.ChatModelConfig{
		APIKey:           config.APIKey,
		Timeout:          config.Timeout,
		BaseURL:          config.BaseURL,
		Model:            config.Model,
		MaxTokens:        config.MaxTokens,
		Temperature:      config.Temperature,
		TopP:             config.TopP,
		Stop:             config.Stop,
		PresencePenalty:  config.PresencePenalty,
		FrequencyPenalty: config.FrequencyPenalty,
		EnableThinking:   config.EnableThinking,
	}

	if config.Qwen != nil {
		cfg.ResponseFormat = config.Qwen.ResponseFormat
	}

	return qwen.NewChatModel(ctx, cfg)
}
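
// geminiBuilder creates a genai client from the config and wraps it in the
// eino-ext Gemini chat model, taking thinking options from config.Gemini.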
func geminiBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
	gc := &genai.ClientConfig{
		APIKey: config.APIKey,
		HTTPOptions: genai.HTTPOptions{
			BaseURL: config.BaseURL,
		},
	}

	if config.Gemini != nil {
		gc.Backend = config.Gemini.Backend
		gc.Project = config.Gemini.Project
		gc.Location = config.Gemini.Location
		gc.HTTPOptions.APIVersion = config.Gemini.APIVersion
		gc.HTTPOptions.Headers = config.Gemini.Headers
	}

	client, err := genai.NewClient(ctx, gc)
	if err != nil {
		return nil, err
	}

	cfg := &gemini.Config{
		Client:      client,
		Model:       config.Model,
		MaxTokens:   config.MaxTokens,
		Temperature: config.Temperature,
		TopP:        config.TopP,
		ThinkingConfig: &genai.ThinkingConfig{
			IncludeThoughts: true,
			ThinkingBudget:  nil,
		},
	}

	if config.TopK != nil {
		cfg.TopK = ptr.Of(int32(ptr.From(config.TopK)))
	}
	if config.Gemini != nil && config.Gemini.IncludeThoughts != nil {
		cfg.ThinkingConfig.IncludeThoughts = ptr.From(config.Gemini.IncludeThoughts)
	}
	if config.Gemini != nil && config.Gemini.ThinkingBudget != nil {
		cfg.ThinkingConfig.ThinkingBudget = config.Gemini.ThinkingBudget
	}

	cm, err := gemini.NewChatModel(ctx, cfg)
	if err != nil {
		return nil, err
	}

	return cm, nil
}