fix(plugin): setting the default value of the agent tool does not take effect (#341)

This commit is contained in:
mrh997
2025-07-30 22:51:16 +08:00
committed by GitHub
parent 357da72a52
commit 9660a85454
7 changed files with 115 additions and 46 deletions

View File

@@ -20,7 +20,9 @@ import (
"context"
"errors"
"gorm.io/gen/field"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
"github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/model"
@@ -55,6 +57,26 @@ func (a agentToolDraftPO) ToDO() *entity.ToolInfo {
}
}
func (at *AgentToolDraftDAO) getSelected(opt *ToolSelectedOption) (selected []field.Expr) {
if opt == nil {
return selected
}
table := at.query.AgentToolDraft
if opt.ToolID {
selected = append(selected, table.ToolID)
}
if opt.ToolMethod {
selected = append(selected, table.Method)
}
if opt.ToolSubURL {
selected = append(selected, table.SubURL)
}
return selected
}
func (at *AgentToolDraftDAO) Get(ctx context.Context, agentID, toolID int64) (tool *entity.ToolInfo, exist bool, err error) {
table := at.query.AgentToolDraft
tl, err := table.WithContext(ctx).
@@ -120,13 +142,14 @@ func (at *AgentToolDraftDAO) MGet(ctx context.Context, agentID int64, toolIDs []
return tools, nil
}
func (at *AgentToolDraftDAO) GetAll(ctx context.Context, agentID int64) (tools []*entity.ToolInfo, err error) {
func (at *AgentToolDraftDAO) GetAll(ctx context.Context, agentID int64, opt *ToolSelectedOption) (tools []*entity.ToolInfo, err error) {
const limit = 20
table := at.query.AgentToolDraft
cursor := int64(0)
for {
tls, err := table.WithContext(ctx).
Select(at.getSelected(opt)...).
Where(
table.AgentID.Eq(agentID),
table.ID.Gt(cursor),
@@ -171,6 +194,16 @@ func (at *AgentToolDraftDAO) UpdateWithToolName(ctx context.Context, agentID int
}
func (at *AgentToolDraftDAO) BatchCreateWithTX(ctx context.Context, tx *query.QueryTx, agentID int64, tools []*entity.ToolInfo) (err error) {
return at.batchCreateWithTX(ctx, tx, agentID, tools, false)
}
func (at *AgentToolDraftDAO) BatchCreateIgnoreConflictWithTX(ctx context.Context, tx *query.QueryTx, agentID int64, tools []*entity.ToolInfo) (err error) {
return at.batchCreateWithTX(ctx, tx, agentID, tools, true)
}
func (at *AgentToolDraftDAO) batchCreateWithTX(ctx context.Context, tx *query.QueryTx, agentID int64,
tools []*entity.ToolInfo, ignoreConflict bool) (err error) {
tls := make([]*model.AgentToolDraft, 0, len(tools))
for _, tl := range tools {
id, err := at.idGen.GenID(ctx)
@@ -192,7 +225,13 @@ func (at *AgentToolDraftDAO) BatchCreateWithTX(ctx context.Context, tx *query.Qu
}
table := tx.AgentToolDraft
err = table.WithContext(ctx).CreateInBatches(tls, 20)
if ignoreConflict {
err = table.WithContext(ctx).Clauses(clause.OnConflict{DoNothing: true}).
CreateInBatches(tls, 20)
} else {
err = table.WithContext(ctx).CreateInBatches(tls, 20)
}
if err != nil {
return err
}
@@ -200,27 +239,6 @@ func (at *AgentToolDraftDAO) BatchCreateWithTX(ctx context.Context, tx *query.Qu
return nil
}
func (at *AgentToolDraftDAO) DeleteAll(ctx context.Context, agentID int64) (err error) {
const limit = 20
table := at.query.AgentToolDraft
for {
info, err := table.WithContext(ctx).
Where(table.AgentID.Eq(agentID)).
Limit(limit).
Delete()
if err != nil {
return err
}
if info.RowsAffected == 0 || info.RowsAffected < limit {
break
}
}
return nil
}
func (at *AgentToolDraftDAO) GetAllPluginIDs(ctx context.Context, agentID int64) (pluginIDs []int64, err error) {
const size = 100
table := at.query.AgentToolDraft
@@ -254,22 +272,22 @@ func (at *AgentToolDraftDAO) GetAllPluginIDs(ctx context.Context, agentID int64)
return slices.Unique(pluginIDs), nil
}
func (at *AgentToolDraftDAO) DeleteAllWithTX(ctx context.Context, tx *query.QueryTx, agentID int64) (err error) {
func (at *AgentToolDraftDAO) DeleteWithTX(ctx context.Context, tx *query.QueryTx, agentID int64, toolIDs []int64) (err error) {
const limit = 20
table := tx.AgentToolDraft
for {
info, err := table.WithContext(ctx).
Where(table.AgentID.Eq(agentID)).
chunks := slices.Chunks(toolIDs, limit)
for _, chunk := range chunks {
_, err = table.WithContext(ctx).
Where(
table.AgentID.Eq(agentID),
table.ToolID.In(chunk...),
).
Limit(limit).
Delete()
if err != nil {
return err
}
if info.RowsAffected == 0 || info.RowsAffected < limit {
break
}
}
return nil

View File

@@ -114,7 +114,6 @@ func (t *ToolDraftDAO) Create(ctx context.Context, tool *entity.ToolInfo) (toolI
}
func (t *ToolDraftDAO) genToolID(ctx context.Context) (id int64, err error) {
retryTimes := 5
for i := 0; i < retryTimes; i++ {
@@ -123,11 +122,13 @@ func (t *ToolDraftDAO) genToolID(ctx context.Context) (id int64, err error) {
return 0, err
}
if _, ok := conf.GetToolProduct(id); !ok {
_, ok := conf.GetToolProduct(id)
if !ok {
break
}
if i == retryTimes-1 {
return 0, fmt.Errorf("id %d is confilict with product tool id.", id)
return 0, fmt.Errorf("id %d is confilict with product tool id", id)
}
}