274 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			274 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			Go
		
	
	
	
| /*
 | |
|  * Copyright 2025 coze-dev Authors
 | |
|  *
 | |
|  * Licensed under the Apache License, Version 2.0 (the "License");
 | |
|  * you may not use this file except in compliance with the License.
 | |
|  * You may obtain a copy of the License at
 | |
|  *
 | |
|  *     http://www.apache.org/licenses/LICENSE-2.0
 | |
|  *
 | |
|  * Unless required by applicable law or agreed to in writing, software
 | |
|  * distributed under the License is distributed on an "AS IS" BASIS,
 | |
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
|  * See the License for the specific language governing permissions and
 | |
|  * limitations under the License.
 | |
|  */
 | |
| 
 | |
| package chatmodel
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"fmt"
 | |
| 
 | |
| 	"github.com/cloudwego/eino-ext/components/model/ark"
 | |
| 	"github.com/cloudwego/eino-ext/components/model/claude"
 | |
| 	"github.com/cloudwego/eino-ext/components/model/deepseek"
 | |
| 	"github.com/cloudwego/eino-ext/components/model/gemini"
 | |
| 	"github.com/cloudwego/eino-ext/components/model/ollama"
 | |
| 	"github.com/cloudwego/eino-ext/components/model/openai"
 | |
| 	"github.com/cloudwego/eino-ext/components/model/qwen"
 | |
| 	"github.com/ollama/ollama/api"
 | |
| 	"google.golang.org/genai"
 | |
| 
 | |
| 	"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
 | |
| 	"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
 | |
| )
 | |
| 
 | |
| type Builder func(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error)
 | |
| 
 | |
| func NewDefaultFactory() chatmodel.Factory {
 | |
| 	return NewFactory(nil)
 | |
| }
 | |
| 
 | |
| func NewFactory(customFactory map[chatmodel.Protocol]Builder) chatmodel.Factory {
 | |
| 	protocol2Builder := map[chatmodel.Protocol]Builder{
 | |
| 		chatmodel.ProtocolOpenAI:   openAIBuilder,
 | |
| 		chatmodel.ProtocolClaude:   claudeBuilder,
 | |
| 		chatmodel.ProtocolDeepseek: deepseekBuilder,
 | |
| 		chatmodel.ProtocolArk:      arkBuilder,
 | |
| 		chatmodel.ProtocolGemini:   geminiBuilder,
 | |
| 		chatmodel.ProtocolOllama:   ollamaBuilder,
 | |
| 		chatmodel.ProtocolQwen:     qwenBuilder,
 | |
| 		chatmodel.ProtocolErnie:    nil,
 | |
| 	}
 | |
| 
 | |
| 	for p := range customFactory {
 | |
| 		protocol2Builder[p] = customFactory[p]
 | |
| 	}
 | |
| 
 | |
| 	return &defaultFactory{protocol2Builder: protocol2Builder}
 | |
| }
 | |
| 
 | |
| type defaultFactory struct {
 | |
| 	protocol2Builder map[chatmodel.Protocol]Builder
 | |
| }
 | |
| 
 | |
| func (f *defaultFactory) CreateChatModel(ctx context.Context, protocol chatmodel.Protocol, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
 | |
| 	if config == nil {
 | |
| 		return nil, fmt.Errorf("[CreateChatModel] config not provided")
 | |
| 	}
 | |
| 
 | |
| 	builder, found := f.protocol2Builder[protocol]
 | |
| 	if !found {
 | |
| 		return nil, fmt.Errorf("[CreateChatModel] protocol not support, protocol=%s", protocol)
 | |
| 	}
 | |
| 
 | |
| 	return builder(ctx, config)
 | |
| }
 | |
| 
 | |
| func (f *defaultFactory) SupportProtocol(protocol chatmodel.Protocol) bool {
 | |
| 	_, found := f.protocol2Builder[protocol]
 | |
| 	return found
 | |
| }
 | |
| 
 | |
| func openAIBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
 | |
| 	cfg := &openai.ChatModelConfig{
 | |
| 		APIKey:           config.APIKey,
 | |
| 		Timeout:          config.Timeout,
 | |
| 		BaseURL:          config.BaseURL,
 | |
| 		Model:            config.Model,
 | |
| 		MaxTokens:        config.MaxTokens,
 | |
| 		Temperature:      config.Temperature,
 | |
| 		TopP:             config.TopP,
 | |
| 		Stop:             config.Stop,
 | |
| 		PresencePenalty:  config.PresencePenalty,
 | |
| 		FrequencyPenalty: config.FrequencyPenalty,
 | |
| 	}
 | |
| 	if config.OpenAI != nil {
 | |
| 		cfg.ByAzure = config.OpenAI.ByAzure
 | |
| 		cfg.APIVersion = config.OpenAI.APIVersion
 | |
| 		cfg.ResponseFormat = config.OpenAI.ResponseFormat
 | |
| 	}
 | |
| 	return openai.NewChatModel(ctx, cfg)
 | |
| }
 | |
| 
 | |
| func claudeBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
 | |
| 	cfg := &claude.Config{
 | |
| 		APIKey:        config.APIKey,
 | |
| 		Model:         config.Model,
 | |
| 		Temperature:   config.Temperature,
 | |
| 		TopP:          config.TopP,
 | |
| 		StopSequences: config.Stop,
 | |
| 	}
 | |
| 	if config.BaseURL != "" {
 | |
| 		cfg.BaseURL = &config.BaseURL
 | |
| 	}
 | |
| 	if config.MaxTokens != nil {
 | |
| 		cfg.MaxTokens = *config.MaxTokens
 | |
| 	}
 | |
| 	if config.TopK != nil {
 | |
| 		cfg.TopK = ptr.Of(int32(*config.TopK))
 | |
| 	}
 | |
| 	if config.Claude != nil {
 | |
| 		cfg.ByBedrock = config.Claude.ByBedrock
 | |
| 		cfg.AccessKey = config.Claude.AccessKey
 | |
| 		cfg.SecretAccessKey = config.Claude.SecretAccessKey
 | |
| 		cfg.SessionToken = config.Claude.SessionToken
 | |
| 		cfg.Region = config.Claude.Region
 | |
| 	}
 | |
| 	return claude.NewChatModel(ctx, cfg)
 | |
| }
 | |
| 
 | |
| func deepseekBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
 | |
| 	cfg := &deepseek.ChatModelConfig{
 | |
| 		APIKey:  config.APIKey,
 | |
| 		Timeout: config.Timeout,
 | |
| 		BaseURL: config.BaseURL,
 | |
| 		Model:   config.Model,
 | |
| 		Stop:    config.Stop,
 | |
| 	}
 | |
| 	if config.Temperature != nil {
 | |
| 		cfg.Temperature = *config.Temperature
 | |
| 	}
 | |
| 	if config.FrequencyPenalty != nil {
 | |
| 		cfg.FrequencyPenalty = *config.FrequencyPenalty
 | |
| 	}
 | |
| 	if config.PresencePenalty != nil {
 | |
| 		cfg.PresencePenalty = *config.PresencePenalty
 | |
| 	}
 | |
| 	if config.MaxTokens != nil {
 | |
| 		cfg.MaxTokens = *config.MaxTokens
 | |
| 	}
 | |
| 	if config.TopP != nil {
 | |
| 		cfg.TopP = *config.TopP
 | |
| 	}
 | |
| 	if config.Deepseek != nil {
 | |
| 		cfg.ResponseFormatType = config.Deepseek.ResponseFormatType
 | |
| 	}
 | |
| 	return deepseek.NewChatModel(ctx, cfg)
 | |
| }
 | |
| 
 | |
| func arkBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
 | |
| 	cfg := &ark.ChatModelConfig{
 | |
| 		BaseURL:          config.BaseURL,
 | |
| 		APIKey:           config.APIKey,
 | |
| 		Model:            config.Model,
 | |
| 		MaxTokens:        config.MaxTokens,
 | |
| 		Temperature:      config.Temperature,
 | |
| 		TopP:             config.TopP,
 | |
| 		Stop:             config.Stop,
 | |
| 		FrequencyPenalty: config.FrequencyPenalty,
 | |
| 		PresencePenalty:  config.PresencePenalty,
 | |
| 	}
 | |
| 	if config.Timeout != 0 {
 | |
| 		cfg.Timeout = &config.Timeout
 | |
| 	}
 | |
| 	if config.Ark != nil {
 | |
| 		cfg.Region = config.Ark.Region
 | |
| 		cfg.AccessKey = config.Ark.AccessKey
 | |
| 		cfg.SecretKey = config.Ark.SecretKey
 | |
| 		cfg.RetryTimes = config.Ark.RetryTimes
 | |
| 		cfg.CustomHeader = config.Ark.CustomHeader
 | |
| 	}
 | |
| 	return ark.NewChatModel(ctx, cfg)
 | |
| }
 | |
| 
 | |
| func ollamaBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
 | |
| 	cfg := &ollama.ChatModelConfig{
 | |
| 		BaseURL:    config.BaseURL,
 | |
| 		Timeout:    config.Timeout,
 | |
| 		HTTPClient: nil,
 | |
| 		Model:      config.Model,
 | |
| 		Format:     nil,
 | |
| 		KeepAlive:  nil,
 | |
| 		Options: &api.Options{
 | |
| 			TopK:             ptr.From(config.TopK),
 | |
| 			TopP:             ptr.From(config.TopP),
 | |
| 			Temperature:      ptr.From(config.Temperature),
 | |
| 			PresencePenalty:  ptr.From(config.PresencePenalty),
 | |
| 			FrequencyPenalty: ptr.From(config.FrequencyPenalty),
 | |
| 			Stop:             config.Stop,
 | |
| 		},
 | |
| 	}
 | |
| 	return ollama.NewChatModel(ctx, cfg)
 | |
| }
 | |
| 
 | |
| func qwenBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
 | |
| 	cfg := &qwen.ChatModelConfig{
 | |
| 		APIKey:           config.APIKey,
 | |
| 		Timeout:          config.Timeout,
 | |
| 		BaseURL:          config.BaseURL,
 | |
| 		Model:            config.Model,
 | |
| 		MaxTokens:        config.MaxTokens,
 | |
| 		Temperature:      config.Temperature,
 | |
| 		TopP:             config.TopP,
 | |
| 		Stop:             config.Stop,
 | |
| 		PresencePenalty:  config.PresencePenalty,
 | |
| 		FrequencyPenalty: config.FrequencyPenalty,
 | |
| 		EnableThinking:   config.EnableThinking,
 | |
| 	}
 | |
| 	if config.Qwen != nil {
 | |
| 		cfg.ResponseFormat = config.Qwen.ResponseFormat
 | |
| 	}
 | |
| 	return qwen.NewChatModel(ctx, cfg)
 | |
| }
 | |
| 
 | |
| func geminiBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
 | |
| 	gc := &genai.ClientConfig{
 | |
| 		APIKey: config.APIKey,
 | |
| 		HTTPOptions: genai.HTTPOptions{
 | |
| 			BaseURL: config.BaseURL,
 | |
| 		},
 | |
| 	}
 | |
| 	if config.Gemini != nil {
 | |
| 		gc.Backend = config.Gemini.Backend
 | |
| 		gc.Project = config.Gemini.Project
 | |
| 		gc.Location = config.Gemini.Location
 | |
| 		gc.HTTPOptions.APIVersion = config.Gemini.APIVersion
 | |
| 		gc.HTTPOptions.Headers = config.Gemini.Headers
 | |
| 	}
 | |
| 
 | |
| 	client, err := genai.NewClient(ctx, gc)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	cfg := &gemini.Config{
 | |
| 		Client:      client,
 | |
| 		Model:       config.Model,
 | |
| 		MaxTokens:   config.MaxTokens,
 | |
| 		Temperature: config.Temperature,
 | |
| 		TopP:        config.TopP,
 | |
| 		ThinkingConfig: &genai.ThinkingConfig{
 | |
| 			IncludeThoughts: true,
 | |
| 			ThinkingBudget:  nil,
 | |
| 		},
 | |
| 	}
 | |
| 	if config.TopK != nil {
 | |
| 		cfg.TopK = ptr.Of(int32(ptr.From(config.TopK)))
 | |
| 	}
 | |
| 	if config.Gemini != nil && config.Gemini.IncludeThoughts != nil {
 | |
| 		cfg.ThinkingConfig.IncludeThoughts = ptr.From(config.Gemini.IncludeThoughts)
 | |
| 	}
 | |
| 	if config.Gemini != nil && config.Gemini.ThinkingBudget != nil {
 | |
| 		cfg.ThinkingConfig.ThinkingBudget = config.Gemini.ThinkingBudget
 | |
| 	}
 | |
| 
 | |
| 	cm, err := gemini.NewChatModel(ctx, cfg)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return cm, nil
 | |
| }
 |