feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
273
backend/infra/impl/chatmodel/default_factory.go
Normal file
273
backend/infra/impl/chatmodel/default_factory.go
Normal file
@@ -0,0 +1,273 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package chatmodel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
"github.com/cloudwego/eino-ext/components/model/claude"
|
||||
"github.com/cloudwego/eino-ext/components/model/deepseek"
|
||||
"github.com/cloudwego/eino-ext/components/model/gemini"
|
||||
"github.com/cloudwego/eino-ext/components/model/ollama"
|
||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino-ext/components/model/qwen"
|
||||
"github.com/ollama/ollama/api"
|
||||
"google.golang.org/genai"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
type Builder func(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error)
|
||||
|
||||
func NewDefaultFactory() chatmodel.Factory {
|
||||
return NewFactory(nil)
|
||||
}
|
||||
|
||||
func NewFactory(customFactory map[chatmodel.Protocol]Builder) chatmodel.Factory {
|
||||
protocol2Builder := map[chatmodel.Protocol]Builder{
|
||||
chatmodel.ProtocolOpenAI: openAIBuilder,
|
||||
chatmodel.ProtocolClaude: claudeBuilder,
|
||||
chatmodel.ProtocolDeepseek: deepseekBuilder,
|
||||
chatmodel.ProtocolArk: arkBuilder,
|
||||
chatmodel.ProtocolGemini: geminiBuilder,
|
||||
chatmodel.ProtocolOllama: ollamaBuilder,
|
||||
chatmodel.ProtocolQwen: qwenBuilder,
|
||||
chatmodel.ProtocolErnie: nil,
|
||||
}
|
||||
|
||||
for p := range customFactory {
|
||||
protocol2Builder[p] = customFactory[p]
|
||||
}
|
||||
|
||||
return &defaultFactory{protocol2Builder: protocol2Builder}
|
||||
}
|
||||
|
||||
type defaultFactory struct {
|
||||
protocol2Builder map[chatmodel.Protocol]Builder
|
||||
}
|
||||
|
||||
func (f *defaultFactory) CreateChatModel(ctx context.Context, protocol chatmodel.Protocol, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("[CreateChatModel] config not provided")
|
||||
}
|
||||
|
||||
builder, found := f.protocol2Builder[protocol]
|
||||
if !found {
|
||||
return nil, fmt.Errorf("[CreateChatModel] protocol not support, protocol=%s", protocol)
|
||||
}
|
||||
|
||||
return builder(ctx, config)
|
||||
}
|
||||
|
||||
func (f *defaultFactory) SupportProtocol(protocol chatmodel.Protocol) bool {
|
||||
_, found := f.protocol2Builder[protocol]
|
||||
return found
|
||||
}
|
||||
|
||||
func openAIBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
|
||||
cfg := &openai.ChatModelConfig{
|
||||
APIKey: config.APIKey,
|
||||
Timeout: config.Timeout,
|
||||
BaseURL: config.BaseURL,
|
||||
Model: config.Model,
|
||||
MaxTokens: config.MaxTokens,
|
||||
Temperature: config.Temperature,
|
||||
TopP: config.TopP,
|
||||
Stop: config.Stop,
|
||||
PresencePenalty: config.PresencePenalty,
|
||||
FrequencyPenalty: config.FrequencyPenalty,
|
||||
}
|
||||
if config.OpenAI != nil {
|
||||
cfg.ByAzure = config.OpenAI.ByAzure
|
||||
cfg.APIVersion = config.OpenAI.APIVersion
|
||||
cfg.ResponseFormat = config.OpenAI.ResponseFormat
|
||||
}
|
||||
return openai.NewChatModel(ctx, cfg)
|
||||
}
|
||||
|
||||
func claudeBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
|
||||
cfg := &claude.Config{
|
||||
APIKey: config.APIKey,
|
||||
Model: config.Model,
|
||||
Temperature: config.Temperature,
|
||||
TopP: config.TopP,
|
||||
StopSequences: config.Stop,
|
||||
}
|
||||
if config.BaseURL != "" {
|
||||
cfg.BaseURL = &config.BaseURL
|
||||
}
|
||||
if config.MaxTokens != nil {
|
||||
cfg.MaxTokens = *config.MaxTokens
|
||||
}
|
||||
if config.TopK != nil {
|
||||
cfg.TopK = ptr.Of(int32(*config.TopK))
|
||||
}
|
||||
if config.Claude != nil {
|
||||
cfg.ByBedrock = config.Claude.ByBedrock
|
||||
cfg.AccessKey = config.Claude.AccessKey
|
||||
cfg.SecretAccessKey = config.Claude.SecretAccessKey
|
||||
cfg.SessionToken = config.Claude.SessionToken
|
||||
cfg.Region = config.Claude.Region
|
||||
}
|
||||
return claude.NewChatModel(ctx, cfg)
|
||||
}
|
||||
|
||||
func deepseekBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
|
||||
cfg := &deepseek.ChatModelConfig{
|
||||
APIKey: config.APIKey,
|
||||
Timeout: config.Timeout,
|
||||
BaseURL: config.BaseURL,
|
||||
Model: config.Model,
|
||||
Stop: config.Stop,
|
||||
}
|
||||
if config.Temperature != nil {
|
||||
cfg.Temperature = *config.Temperature
|
||||
}
|
||||
if config.FrequencyPenalty != nil {
|
||||
cfg.FrequencyPenalty = *config.FrequencyPenalty
|
||||
}
|
||||
if config.PresencePenalty != nil {
|
||||
cfg.PresencePenalty = *config.PresencePenalty
|
||||
}
|
||||
if config.MaxTokens != nil {
|
||||
cfg.MaxTokens = *config.MaxTokens
|
||||
}
|
||||
if config.TopP != nil {
|
||||
cfg.TopP = *config.TopP
|
||||
}
|
||||
if config.Deepseek != nil {
|
||||
cfg.ResponseFormatType = config.Deepseek.ResponseFormatType
|
||||
}
|
||||
return deepseek.NewChatModel(ctx, cfg)
|
||||
}
|
||||
|
||||
func arkBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
|
||||
cfg := &ark.ChatModelConfig{
|
||||
BaseURL: config.BaseURL,
|
||||
APIKey: config.APIKey,
|
||||
Model: config.Model,
|
||||
MaxTokens: config.MaxTokens,
|
||||
Temperature: config.Temperature,
|
||||
TopP: config.TopP,
|
||||
Stop: config.Stop,
|
||||
FrequencyPenalty: config.FrequencyPenalty,
|
||||
PresencePenalty: config.PresencePenalty,
|
||||
}
|
||||
if config.Timeout != 0 {
|
||||
cfg.Timeout = &config.Timeout
|
||||
}
|
||||
if config.Ark != nil {
|
||||
cfg.Region = config.Ark.Region
|
||||
cfg.AccessKey = config.Ark.AccessKey
|
||||
cfg.SecretKey = config.Ark.SecretKey
|
||||
cfg.RetryTimes = config.Ark.RetryTimes
|
||||
cfg.CustomHeader = config.Ark.CustomHeader
|
||||
}
|
||||
return ark.NewChatModel(ctx, cfg)
|
||||
}
|
||||
|
||||
func ollamaBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
|
||||
cfg := &ollama.ChatModelConfig{
|
||||
BaseURL: config.BaseURL,
|
||||
Timeout: config.Timeout,
|
||||
HTTPClient: nil,
|
||||
Model: config.Model,
|
||||
Format: nil,
|
||||
KeepAlive: nil,
|
||||
Options: &api.Options{
|
||||
TopK: ptr.From(config.TopK),
|
||||
TopP: ptr.From(config.TopP),
|
||||
Temperature: ptr.From(config.Temperature),
|
||||
PresencePenalty: ptr.From(config.PresencePenalty),
|
||||
FrequencyPenalty: ptr.From(config.FrequencyPenalty),
|
||||
Stop: config.Stop,
|
||||
},
|
||||
}
|
||||
return ollama.NewChatModel(ctx, cfg)
|
||||
}
|
||||
|
||||
func qwenBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
|
||||
cfg := &qwen.ChatModelConfig{
|
||||
APIKey: config.APIKey,
|
||||
Timeout: config.Timeout,
|
||||
BaseURL: config.BaseURL,
|
||||
Model: config.Model,
|
||||
MaxTokens: config.MaxTokens,
|
||||
Temperature: config.Temperature,
|
||||
TopP: config.TopP,
|
||||
Stop: config.Stop,
|
||||
PresencePenalty: config.PresencePenalty,
|
||||
FrequencyPenalty: config.FrequencyPenalty,
|
||||
EnableThinking: config.EnableThinking,
|
||||
}
|
||||
if config.Qwen != nil {
|
||||
cfg.ResponseFormat = config.Qwen.ResponseFormat
|
||||
}
|
||||
return qwen.NewChatModel(ctx, cfg)
|
||||
}
|
||||
|
||||
func geminiBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
|
||||
gc := &genai.ClientConfig{
|
||||
APIKey: config.APIKey,
|
||||
HTTPOptions: genai.HTTPOptions{
|
||||
BaseURL: config.BaseURL,
|
||||
},
|
||||
}
|
||||
if config.Gemini != nil {
|
||||
gc.Backend = config.Gemini.Backend
|
||||
gc.Project = config.Gemini.Project
|
||||
gc.Location = config.Gemini.Location
|
||||
gc.HTTPOptions.APIVersion = config.Gemini.APIVersion
|
||||
gc.HTTPOptions.Headers = config.Gemini.Headers
|
||||
}
|
||||
|
||||
client, err := genai.NewClient(ctx, gc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := &gemini.Config{
|
||||
Client: client,
|
||||
Model: config.Model,
|
||||
MaxTokens: config.MaxTokens,
|
||||
Temperature: config.Temperature,
|
||||
TopP: config.TopP,
|
||||
ThinkingConfig: &genai.ThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
ThinkingBudget: nil,
|
||||
},
|
||||
}
|
||||
if config.TopK != nil {
|
||||
cfg.TopK = ptr.Of(int32(ptr.From(config.TopK)))
|
||||
}
|
||||
if config.Gemini != nil && config.Gemini.IncludeThoughts != nil {
|
||||
cfg.ThinkingConfig.IncludeThoughts = ptr.From(config.Gemini.IncludeThoughts)
|
||||
}
|
||||
if config.Gemini != nil && config.Gemini.ThinkingBudget != nil {
|
||||
cfg.ThinkingConfig.ThinkingBudget = config.Gemini.ThinkingBudget
|
||||
}
|
||||
|
||||
cm, err := gemini.NewChatModel(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cm, nil
|
||||
}
|
||||
38
backend/infra/impl/chatmodel/singleton.go
Normal file
38
backend/infra/impl/chatmodel/singleton.go
Normal file
@@ -0,0 +1,38 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package chatmodel
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
)
|
||||
|
||||
var (
|
||||
once sync.Once
|
||||
singletonFactory chatmodel.Factory
|
||||
)
|
||||
|
||||
func InitSingletonFactory(factory chatmodel.Factory) {
|
||||
once.Do(func() {
|
||||
singletonFactory = factory
|
||||
})
|
||||
}
|
||||
|
||||
func GetSingletonFactory() chatmodel.Factory {
|
||||
return singletonFactory
|
||||
}
|
||||
Reference in New Issue
Block a user