274 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			274 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			Go
		
	
	
	
/*
 * Copyright 2025 coze-dev Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
 | 
						|
 | 
						|
package chatmodel
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"fmt"
 | 
						|
 | 
						|
	"github.com/cloudwego/eino-ext/components/model/ark"
 | 
						|
	"github.com/cloudwego/eino-ext/components/model/claude"
 | 
						|
	"github.com/cloudwego/eino-ext/components/model/deepseek"
 | 
						|
	"github.com/cloudwego/eino-ext/components/model/gemini"
 | 
						|
	"github.com/cloudwego/eino-ext/components/model/ollama"
 | 
						|
	"github.com/cloudwego/eino-ext/components/model/openai"
 | 
						|
	"github.com/cloudwego/eino-ext/components/model/qwen"
 | 
						|
	"github.com/ollama/ollama/api"
 | 
						|
	"google.golang.org/genai"
 | 
						|
 | 
						|
	"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
 | 
						|
)
 | 
						|
 | 
						|
// Builder constructs a tool-calling chat model from config. NewFactory keeps
// one Builder per chatmodel.Protocol and CreateChatModel dispatches to it.
type Builder func(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error)
 | 
						|
 | 
						|
func NewDefaultFactory() chatmodel.Factory {
 | 
						|
	return NewFactory(nil)
 | 
						|
}
 | 
						|
 | 
						|
func NewFactory(customFactory map[chatmodel.Protocol]Builder) chatmodel.Factory {
 | 
						|
	protocol2Builder := map[chatmodel.Protocol]Builder{
 | 
						|
		chatmodel.ProtocolOpenAI:   openAIBuilder,
 | 
						|
		chatmodel.ProtocolClaude:   claudeBuilder,
 | 
						|
		chatmodel.ProtocolDeepseek: deepseekBuilder,
 | 
						|
		chatmodel.ProtocolArk:      arkBuilder,
 | 
						|
		chatmodel.ProtocolGemini:   geminiBuilder,
 | 
						|
		chatmodel.ProtocolOllama:   ollamaBuilder,
 | 
						|
		chatmodel.ProtocolQwen:     qwenBuilder,
 | 
						|
		chatmodel.ProtocolErnie:    nil,
 | 
						|
	}
 | 
						|
 | 
						|
	for p := range customFactory {
 | 
						|
		protocol2Builder[p] = customFactory[p]
 | 
						|
	}
 | 
						|
 | 
						|
	return &defaultFactory{protocol2Builder: protocol2Builder}
 | 
						|
}
 | 
						|
 | 
						|
// defaultFactory is the chatmodel.Factory returned by NewFactory. It looks up
// the Builder registered for a protocol in protocol2Builder and delegates to it.
type defaultFactory struct {
	// protocol2Builder maps each registered protocol to its Builder. An entry
	// may hold a nil Builder (NewFactory registers ProtocolErnie that way).
	protocol2Builder map[chatmodel.Protocol]Builder
}
 | 
						|
 | 
						|
func (f *defaultFactory) CreateChatModel(ctx context.Context, protocol chatmodel.Protocol, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
 | 
						|
	if config == nil {
 | 
						|
		return nil, fmt.Errorf("[CreateChatModel] config not provided")
 | 
						|
	}
 | 
						|
 | 
						|
	builder, found := f.protocol2Builder[protocol]
 | 
						|
	if !found {
 | 
						|
		return nil, fmt.Errorf("[CreateChatModel] protocol not support, protocol=%s", protocol)
 | 
						|
	}
 | 
						|
 | 
						|
	return builder(ctx, config)
 | 
						|
}
 | 
						|
 | 
						|
func (f *defaultFactory) SupportProtocol(protocol chatmodel.Protocol) bool {
 | 
						|
	_, found := f.protocol2Builder[protocol]
 | 
						|
	return found
 | 
						|
}
 | 
						|
 | 
						|
func openAIBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
 | 
						|
	cfg := &openai.ChatModelConfig{
 | 
						|
		APIKey:           config.APIKey,
 | 
						|
		Timeout:          config.Timeout,
 | 
						|
		BaseURL:          config.BaseURL,
 | 
						|
		Model:            config.Model,
 | 
						|
		MaxTokens:        config.MaxTokens,
 | 
						|
		Temperature:      config.Temperature,
 | 
						|
		TopP:             config.TopP,
 | 
						|
		Stop:             config.Stop,
 | 
						|
		PresencePenalty:  config.PresencePenalty,
 | 
						|
		FrequencyPenalty: config.FrequencyPenalty,
 | 
						|
	}
 | 
						|
	if config.OpenAI != nil {
 | 
						|
		cfg.ByAzure = config.OpenAI.ByAzure
 | 
						|
		cfg.APIVersion = config.OpenAI.APIVersion
 | 
						|
		cfg.ResponseFormat = config.OpenAI.ResponseFormat
 | 
						|
	}
 | 
						|
	return openai.NewChatModel(ctx, cfg)
 | 
						|
}
 | 
						|
 | 
						|
func claudeBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
 | 
						|
	cfg := &claude.Config{
 | 
						|
		APIKey:        config.APIKey,
 | 
						|
		Model:         config.Model,
 | 
						|
		Temperature:   config.Temperature,
 | 
						|
		TopP:          config.TopP,
 | 
						|
		StopSequences: config.Stop,
 | 
						|
	}
 | 
						|
	if config.BaseURL != "" {
 | 
						|
		cfg.BaseURL = &config.BaseURL
 | 
						|
	}
 | 
						|
	if config.MaxTokens != nil {
 | 
						|
		cfg.MaxTokens = *config.MaxTokens
 | 
						|
	}
 | 
						|
	if config.TopK != nil {
 | 
						|
		cfg.TopK = ptr.Of(int32(*config.TopK))
 | 
						|
	}
 | 
						|
	if config.Claude != nil {
 | 
						|
		cfg.ByBedrock = config.Claude.ByBedrock
 | 
						|
		cfg.AccessKey = config.Claude.AccessKey
 | 
						|
		cfg.SecretAccessKey = config.Claude.SecretAccessKey
 | 
						|
		cfg.SessionToken = config.Claude.SessionToken
 | 
						|
		cfg.Region = config.Claude.Region
 | 
						|
	}
 | 
						|
	return claude.NewChatModel(ctx, cfg)
 | 
						|
}
 | 
						|
 | 
						|
func deepseekBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
 | 
						|
	cfg := &deepseek.ChatModelConfig{
 | 
						|
		APIKey:  config.APIKey,
 | 
						|
		Timeout: config.Timeout,
 | 
						|
		BaseURL: config.BaseURL,
 | 
						|
		Model:   config.Model,
 | 
						|
		Stop:    config.Stop,
 | 
						|
	}
 | 
						|
	if config.Temperature != nil {
 | 
						|
		cfg.Temperature = *config.Temperature
 | 
						|
	}
 | 
						|
	if config.FrequencyPenalty != nil {
 | 
						|
		cfg.FrequencyPenalty = *config.FrequencyPenalty
 | 
						|
	}
 | 
						|
	if config.PresencePenalty != nil {
 | 
						|
		cfg.PresencePenalty = *config.PresencePenalty
 | 
						|
	}
 | 
						|
	if config.MaxTokens != nil {
 | 
						|
		cfg.MaxTokens = *config.MaxTokens
 | 
						|
	}
 | 
						|
	if config.TopP != nil {
 | 
						|
		cfg.TopP = *config.TopP
 | 
						|
	}
 | 
						|
	if config.Deepseek != nil {
 | 
						|
		cfg.ResponseFormatType = config.Deepseek.ResponseFormatType
 | 
						|
	}
 | 
						|
	return deepseek.NewChatModel(ctx, cfg)
 | 
						|
}
 | 
						|
 | 
						|
func arkBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
 | 
						|
	cfg := &ark.ChatModelConfig{
 | 
						|
		BaseURL:          config.BaseURL,
 | 
						|
		APIKey:           config.APIKey,
 | 
						|
		Model:            config.Model,
 | 
						|
		MaxTokens:        config.MaxTokens,
 | 
						|
		Temperature:      config.Temperature,
 | 
						|
		TopP:             config.TopP,
 | 
						|
		Stop:             config.Stop,
 | 
						|
		FrequencyPenalty: config.FrequencyPenalty,
 | 
						|
		PresencePenalty:  config.PresencePenalty,
 | 
						|
	}
 | 
						|
	if config.Timeout != 0 {
 | 
						|
		cfg.Timeout = &config.Timeout
 | 
						|
	}
 | 
						|
	if config.Ark != nil {
 | 
						|
		cfg.Region = config.Ark.Region
 | 
						|
		cfg.AccessKey = config.Ark.AccessKey
 | 
						|
		cfg.SecretKey = config.Ark.SecretKey
 | 
						|
		cfg.RetryTimes = config.Ark.RetryTimes
 | 
						|
		cfg.CustomHeader = config.Ark.CustomHeader
 | 
						|
	}
 | 
						|
	return ark.NewChatModel(ctx, cfg)
 | 
						|
}
 | 
						|
 | 
						|
func ollamaBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
 | 
						|
	cfg := &ollama.ChatModelConfig{
 | 
						|
		BaseURL:    config.BaseURL,
 | 
						|
		Timeout:    config.Timeout,
 | 
						|
		HTTPClient: nil,
 | 
						|
		Model:      config.Model,
 | 
						|
		Format:     nil,
 | 
						|
		KeepAlive:  nil,
 | 
						|
		Options: &api.Options{
 | 
						|
			TopK:             ptr.From(config.TopK),
 | 
						|
			TopP:             ptr.From(config.TopP),
 | 
						|
			Temperature:      ptr.From(config.Temperature),
 | 
						|
			PresencePenalty:  ptr.From(config.PresencePenalty),
 | 
						|
			FrequencyPenalty: ptr.From(config.FrequencyPenalty),
 | 
						|
			Stop:             config.Stop,
 | 
						|
		},
 | 
						|
	}
 | 
						|
	return ollama.NewChatModel(ctx, cfg)
 | 
						|
}
 | 
						|
 | 
						|
func qwenBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
 | 
						|
	cfg := &qwen.ChatModelConfig{
 | 
						|
		APIKey:           config.APIKey,
 | 
						|
		Timeout:          config.Timeout,
 | 
						|
		BaseURL:          config.BaseURL,
 | 
						|
		Model:            config.Model,
 | 
						|
		MaxTokens:        config.MaxTokens,
 | 
						|
		Temperature:      config.Temperature,
 | 
						|
		TopP:             config.TopP,
 | 
						|
		Stop:             config.Stop,
 | 
						|
		PresencePenalty:  config.PresencePenalty,
 | 
						|
		FrequencyPenalty: config.FrequencyPenalty,
 | 
						|
		EnableThinking:   config.EnableThinking,
 | 
						|
	}
 | 
						|
	if config.Qwen != nil {
 | 
						|
		cfg.ResponseFormat = config.Qwen.ResponseFormat
 | 
						|
	}
 | 
						|
	return qwen.NewChatModel(ctx, cfg)
 | 
						|
}
 | 
						|
 | 
						|
func geminiBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
 | 
						|
	gc := &genai.ClientConfig{
 | 
						|
		APIKey: config.APIKey,
 | 
						|
		HTTPOptions: genai.HTTPOptions{
 | 
						|
			BaseURL: config.BaseURL,
 | 
						|
		},
 | 
						|
	}
 | 
						|
	if config.Gemini != nil {
 | 
						|
		gc.Backend = config.Gemini.Backend
 | 
						|
		gc.Project = config.Gemini.Project
 | 
						|
		gc.Location = config.Gemini.Location
 | 
						|
		gc.HTTPOptions.APIVersion = config.Gemini.APIVersion
 | 
						|
		gc.HTTPOptions.Headers = config.Gemini.Headers
 | 
						|
	}
 | 
						|
 | 
						|
	client, err := genai.NewClient(ctx, gc)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	cfg := &gemini.Config{
 | 
						|
		Client:      client,
 | 
						|
		Model:       config.Model,
 | 
						|
		MaxTokens:   config.MaxTokens,
 | 
						|
		Temperature: config.Temperature,
 | 
						|
		TopP:        config.TopP,
 | 
						|
		ThinkingConfig: &genai.ThinkingConfig{
 | 
						|
			IncludeThoughts: true,
 | 
						|
			ThinkingBudget:  nil,
 | 
						|
		},
 | 
						|
	}
 | 
						|
	if config.TopK != nil {
 | 
						|
		cfg.TopK = ptr.Of(int32(ptr.From(config.TopK)))
 | 
						|
	}
 | 
						|
	if config.Gemini != nil && config.Gemini.IncludeThoughts != nil {
 | 
						|
		cfg.ThinkingConfig.IncludeThoughts = ptr.From(config.Gemini.IncludeThoughts)
 | 
						|
	}
 | 
						|
	if config.Gemini != nil && config.Gemini.ThinkingBudget != nil {
 | 
						|
		cfg.ThinkingConfig.ThinkingBudget = config.Gemini.ThinkingBudget
 | 
						|
	}
 | 
						|
 | 
						|
	cm, err := gemini.NewChatModel(ctx, cfg)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return cm, nil
 | 
						|
}
 |