From 9660a85454055dd90dd353dba6d4399589433be3 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Wed, 30 Jul 2025 22:51:16 +0800 Subject: [PATCH] fix(plugin): setting the default value of the agent tool does not take effect (#341) --- backend/application/base/pluginutil/api.go | 2 + .../singleagent/service/single_agent_impl.go | 3 +- backend/domain/plugin/conf/load_plugin.go | 13 ++- ...getn_tool_draft.go => agent_tool_draft.go} | 80 ++++++++++++------- .../domain/plugin/internal/dal/tool_draft.go | 7 +- backend/domain/plugin/repository/tool_impl.go | 54 +++++++++++-- backend/domain/plugin/service/exec_tool.go | 2 +- 7 files changed, 115 insertions(+), 46 deletions(-) rename backend/domain/plugin/internal/dal/{agetn_tool_draft.go => agent_tool_draft.go} (79%) diff --git a/backend/application/base/pluginutil/api.go b/backend/application/base/pluginutil/api.go index 17690183..ec866777 100644 --- a/backend/application/base/pluginutil/api.go +++ b/backend/application/base/pluginutil/api.go @@ -22,6 +22,8 @@ import ( "github.com/getkin/kin-openapi/openapi3" + "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" + "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin" common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" diff --git a/backend/domain/agent/singleagent/service/single_agent_impl.go b/backend/domain/agent/singleagent/service/single_agent_impl.go index 53dfb5cd..6735f93e 100644 --- a/backend/domain/agent/singleagent/service/single_agent_impl.go +++ b/backend/domain/agent/singleagent/service/single_agent_impl.go @@ -21,9 +21,10 @@ import ( "fmt" "math/rand" + "github.com/jinzhu/copier" + "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" - "github.com/jinzhu/copier" "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_common" "github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin" diff --git a/backend/domain/plugin/conf/load_plugin.go b/backend/domain/plugin/conf/load_plugin.go index 9eedb8f1..9b31ee2a 100644 --- a/backend/domain/plugin/conf/load_plugin.go +++ b/backend/domain/plugin/conf/load_plugin.go @@ -24,6 +24,7 @@ import ( "strings" "github.com/getkin/kin-openapi/openapi3" + "github.com/mohae/deepcopy" "golang.org/x/mod/semver" "gopkg.in/yaml.v3" @@ -59,18 +60,26 @@ var ( func GetToolProduct(toolID int64) (*ToolInfo, bool) { ti, ok := toolProducts[toolID] - return ti, ok + if !ok { + return nil, false + } + + ti_ := deepcopy.Copy(ti).(*ToolInfo) + + return ti_, true } func MGetToolProducts(toolIDs []int64) []*ToolInfo { tools := make([]*ToolInfo, 0, len(toolIDs)) for _, toolID := range toolIDs { - ti, ok := toolProducts[toolID] + ti, ok := GetToolProduct(toolID) if !ok { continue } + tools = append(tools, ti) } + return tools } diff --git a/backend/domain/plugin/internal/dal/agetn_tool_draft.go b/backend/domain/plugin/internal/dal/agent_tool_draft.go similarity index 79% rename from backend/domain/plugin/internal/dal/agetn_tool_draft.go rename to backend/domain/plugin/internal/dal/agent_tool_draft.go index 27a297c6..bacb21a5 100644 --- a/backend/domain/plugin/internal/dal/agetn_tool_draft.go +++ b/backend/domain/plugin/internal/dal/agent_tool_draft.go @@ -20,7 +20,9 @@ import ( "context" "errors" + "gorm.io/gen/field" "gorm.io/gorm" + "gorm.io/gorm/clause" "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" "github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/model" @@ -55,6 +57,26 @@ func (a agentToolDraftPO) ToDO() *entity.ToolInfo { } } +func (at *AgentToolDraftDAO) getSelected(opt *ToolSelectedOption) (selected []field.Expr) { + if opt == nil { + return selected + } + + table := at.query.AgentToolDraft + + if opt.ToolID { + selected = append(selected, table.ToolID) + } + if opt.ToolMethod { + selected = append(selected, table.Method) + } + if opt.ToolSubURL { + selected = append(selected, table.SubURL) + } + + return selected +} + func (at *AgentToolDraftDAO) Get(ctx context.Context, agentID, toolID int64) (tool *entity.ToolInfo, exist bool, err error) { table := at.query.AgentToolDraft tl, err := table.WithContext(ctx). @@ -120,13 +142,14 @@ func (at *AgentToolDraftDAO) MGet(ctx context.Context, agentID int64, toolIDs [] return tools, nil } -func (at *AgentToolDraftDAO) GetAll(ctx context.Context, agentID int64) (tools []*entity.ToolInfo, err error) { +func (at *AgentToolDraftDAO) GetAll(ctx context.Context, agentID int64, opt *ToolSelectedOption) (tools []*entity.ToolInfo, err error) { const limit = 20 table := at.query.AgentToolDraft cursor := int64(0) for { tls, err := table.WithContext(ctx). + Select(at.getSelected(opt)...). Where( table.AgentID.Eq(agentID), table.ID.Gt(cursor), @@ -171,6 +194,16 @@ func (at *AgentToolDraftDAO) UpdateWithToolName(ctx context.Context, agentID int } func (at *AgentToolDraftDAO) BatchCreateWithTX(ctx context.Context, tx *query.QueryTx, agentID int64, tools []*entity.ToolInfo) (err error) { + return at.batchCreateWithTX(ctx, tx, agentID, tools, false) +} + +func (at *AgentToolDraftDAO) BatchCreateIgnoreConflictWithTX(ctx context.Context, tx *query.QueryTx, agentID int64, tools []*entity.ToolInfo) (err error) { + return at.batchCreateWithTX(ctx, tx, agentID, tools, true) +} + +func (at *AgentToolDraftDAO) batchCreateWithTX(ctx context.Context, tx *query.QueryTx, agentID int64, + tools []*entity.ToolInfo, ignoreConflict bool) (err error) { + tls := make([]*model.AgentToolDraft, 0, len(tools)) for _, tl := range tools { id, err := at.idGen.GenID(ctx) @@ -192,7 +225,13 @@ func (at *AgentToolDraftDAO) BatchCreateWithTX(ctx context.Context, tx *query.Qu } table := tx.AgentToolDraft - err = table.WithContext(ctx).CreateInBatches(tls, 20) + + if ignoreConflict { + err = table.WithContext(ctx).Clauses(clause.OnConflict{DoNothing: true}). + CreateInBatches(tls, 20) + } else { + err = table.WithContext(ctx).CreateInBatches(tls, 20) + } if err != nil { return err } @@ -200,27 +239,6 @@ func (at *AgentToolDraftDAO) BatchCreateWithTX(ctx context.Context, tx *query.Qu return nil } -func (at *AgentToolDraftDAO) DeleteAll(ctx context.Context, agentID int64) (err error) { - const limit = 20 - table := at.query.AgentToolDraft - - for { - info, err := table.WithContext(ctx). - Where(table.AgentID.Eq(agentID)). - Limit(limit). - Delete() - if err != nil { - return err - } - - if info.RowsAffected == 0 || info.RowsAffected < limit { - break - } - } - - return nil -} - func (at *AgentToolDraftDAO) GetAllPluginIDs(ctx context.Context, agentID int64) (pluginIDs []int64, err error) { const size = 100 table := at.query.AgentToolDraft @@ -254,22 +272,22 @@ func (at *AgentToolDraftDAO) GetAllPluginIDs(ctx context.Context, agentID int64) return slices.Unique(pluginIDs), nil } -func (at *AgentToolDraftDAO) DeleteAllWithTX(ctx context.Context, tx *query.QueryTx, agentID int64) (err error) { +func (at *AgentToolDraftDAO) DeleteWithTX(ctx context.Context, tx *query.QueryTx, agentID int64, toolIDs []int64) (err error) { const limit = 20 table := tx.AgentToolDraft - for { - info, err := table.WithContext(ctx). - Where(table.AgentID.Eq(agentID)). + chunks := slices.Chunks(toolIDs, limit) + for _, chunk := range chunks { + _, err = table.WithContext(ctx). + Where( + table.AgentID.Eq(agentID), + table.ToolID.In(chunk...), + ). Limit(limit). Delete() if err != nil { return err } - - if info.RowsAffected == 0 || info.RowsAffected < limit { - break - } } return nil diff --git a/backend/domain/plugin/internal/dal/tool_draft.go b/backend/domain/plugin/internal/dal/tool_draft.go index a7bbd737..dfb06355 100644 --- a/backend/domain/plugin/internal/dal/tool_draft.go +++ b/backend/domain/plugin/internal/dal/tool_draft.go @@ -114,7 +114,6 @@ func (t *ToolDraftDAO) Create(ctx context.Context, tool *entity.ToolInfo) (toolI } func (t *ToolDraftDAO) genToolID(ctx context.Context) (id int64, err error) { - retryTimes := 5 for i := 0; i < retryTimes; i++ { @@ -123,11 +122,13 @@ func (t *ToolDraftDAO) genToolID(ctx context.Context) (id int64, err error) { return 0, err } - if _, ok := conf.GetToolProduct(id); !ok { + _, ok := conf.GetToolProduct(id) + if !ok { break } + if i == retryTimes-1 { - return 0, fmt.Errorf("id %d is confilict with product tool id.", id) + return 0, fmt.Errorf("id %d is confilict with product tool id", id) } } diff --git a/backend/domain/plugin/repository/tool_impl.go b/backend/domain/plugin/repository/tool_impl.go index f75d7a7a..b5baad46 100644 --- a/backend/domain/plugin/repository/tool_impl.go +++ b/backend/domain/plugin/repository/tool_impl.go @@ -216,6 +216,7 @@ func (t *toolRepoImpl) GetOnlineTool(ctx context.Context, toolID int64) (tool *e func (t *toolRepoImpl) MGetOnlineTools(ctx context.Context, toolIDs []int64, opts ...ToolSelectedOptions) (tools []*entity.ToolInfo, err error) { toolProducts := pluginConf.MGetToolProducts(toolIDs) + tools = slices.Transform(toolProducts, func(tool *pluginConf.ToolInfo) *entity.ToolInfo { return tool.Info }) @@ -269,13 +270,38 @@ func (t *toolRepoImpl) MGetVersionTools(ctx context.Context, versionTools []enti } func (t *toolRepoImpl) BindDraftAgentTools(ctx context.Context, agentID int64, toolIDs []int64) (err error) { - onlineTools, err := t.MGetOnlineTools(ctx, toolIDs) + opt := &dal.ToolSelectedOption{ + ToolID: true, + } + draftAgentTools, err := t.agentToolDraftDAO.GetAll(ctx, agentID, opt) if err != nil { return err } - if len(onlineTools) == 0 { - return t.agentToolDraftDAO.DeleteAll(ctx, agentID) + draftAgentToolIDMap := slices.ToMap(draftAgentTools, func(tool *entity.ToolInfo) (int64, bool) { + return tool.ID, true + }) + + bindToolIDMap := slices.ToMap(toolIDs, func(toolID int64) (int64, bool) { + return toolID, true + }) + + newBindToolIDs := make([]int64, 0, len(toolIDs)) + for _, toolID := range toolIDs { + _, ok := draftAgentToolIDMap[toolID] + if ok { + continue + } + newBindToolIDs = append(newBindToolIDs, toolID) + } + + removeToolIDs := make([]int64, 0, len(draftAgentTools)) + for toolID := range draftAgentToolIDMap { + _, ok := bindToolIDMap[toolID] + if ok { + continue + } + removeToolIDs = append(removeToolIDs, toolID) } tx := t.query.Begin() @@ -298,17 +324,29 @@ func (t *toolRepoImpl) BindDraftAgentTools(ctx context.Context, agentID int64, t } }() - err = t.agentToolDraftDAO.DeleteAllWithTX(ctx, tx, agentID) + onlineTools, err := t.MGetOnlineTools(ctx, newBindToolIDs) if err != nil { return err } - err = t.agentToolDraftDAO.BatchCreateWithTX(ctx, tx, agentID, onlineTools) + if len(onlineTools) > 0 { + err = t.agentToolDraftDAO.BatchCreateIgnoreConflictWithTX(ctx, tx, agentID, onlineTools) + if err != nil { + return err + } + } + + err = t.agentToolDraftDAO.DeleteWithTX(ctx, tx, agentID, removeToolIDs) if err != nil { return err } - return tx.Commit() + err = tx.Commit() + if err != nil { + return err + } + + return nil } func (t *toolRepoImpl) GetAgentPluginIDs(ctx context.Context, agentID int64) (pluginIDs []int64, err error) { @@ -316,7 +354,7 @@ func (t *toolRepoImpl) GetAgentPluginIDs(ctx context.Context, agentID int64) (pl } func (t *toolRepoImpl) DuplicateDraftAgentTools(ctx context.Context, fromAgentID, toAgentID int64) (err error) { - tools, err := t.agentToolDraftDAO.GetAll(ctx, fromAgentID) + tools, err := t.agentToolDraftDAO.GetAll(ctx, fromAgentID, nil) if err != nil { return err } @@ -370,7 +408,7 @@ func (t *toolRepoImpl) UpdateDraftAgentTool(ctx context.Context, req *UpdateDraf } func (t *toolRepoImpl) GetSpaceAllDraftAgentTools(ctx context.Context, agentID int64) (tools []*entity.ToolInfo, err error) { - return t.agentToolDraftDAO.GetAll(ctx, agentID) + return t.agentToolDraftDAO.GetAll(ctx, agentID, nil) } func (t *toolRepoImpl) GetVersionAgentTool(ctx context.Context, agentID int64, vAgentTool entity.VersionAgentTool) (tool *entity.ToolInfo, exist bool, err error) { diff --git a/backend/domain/plugin/service/exec_tool.go b/backend/domain/plugin/service/exec_tool.go index 75553e1e..ce35a247 100644 --- a/backend/domain/plugin/service/exec_tool.go +++ b/backend/domain/plugin/service/exec_tool.go @@ -596,7 +596,7 @@ func genRequestString(req *http.Request, body []byte) (string, error) { return "", fmt.Errorf("[genRequestString] marshal failed, err=%s", err) } - if body != nil { + if len(body) > 0 { requestStr, err = sjson.SetRaw(requestStr, "body", string(body)) if err != nil { return "", fmt.Errorf("[genRequestString] set body failed, err=%s", err)