fix(plugin): setting the default value of the agent tool does not take effect (#341)
This commit is contained in:
		
							parent
							
								
									357da72a52
								
							
						
					
					
						commit
						9660a85454
					
				|  | @ -22,6 +22,8 @@ import ( | ||||||
| 
 | 
 | ||||||
| 	"github.com/getkin/kin-openapi/openapi3" | 	"github.com/getkin/kin-openapi/openapi3" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/coze-dev/coze-studio/backend/domain/plugin/entity" | ||||||
|  | 
 | ||||||
| 	"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin" | 	"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin" | ||||||
| 	common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common" | 	common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common" | ||||||
| 	"github.com/coze-dev/coze-studio/backend/domain/plugin/entity" | 	"github.com/coze-dev/coze-studio/backend/domain/plugin/entity" | ||||||
|  |  | ||||||
|  | @ -21,9 +21,10 @@ import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"math/rand" | 	"math/rand" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/jinzhu/copier" | ||||||
|  | 
 | ||||||
| 	"github.com/cloudwego/eino/compose" | 	"github.com/cloudwego/eino/compose" | ||||||
| 	"github.com/cloudwego/eino/schema" | 	"github.com/cloudwego/eino/schema" | ||||||
| 	"github.com/jinzhu/copier" |  | ||||||
| 
 | 
 | ||||||
| 	"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_common" | 	"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_common" | ||||||
| 	"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin" | 	"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin" | ||||||
|  |  | ||||||
|  | @ -24,6 +24,7 @@ import ( | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	"github.com/getkin/kin-openapi/openapi3" | 	"github.com/getkin/kin-openapi/openapi3" | ||||||
|  | 	"github.com/mohae/deepcopy" | ||||||
| 	"golang.org/x/mod/semver" | 	"golang.org/x/mod/semver" | ||||||
| 	"gopkg.in/yaml.v3" | 	"gopkg.in/yaml.v3" | ||||||
| 
 | 
 | ||||||
|  | @ -59,18 +60,26 @@ var ( | ||||||
| 
 | 
 | ||||||
| func GetToolProduct(toolID int64) (*ToolInfo, bool) { | func GetToolProduct(toolID int64) (*ToolInfo, bool) { | ||||||
| 	ti, ok := toolProducts[toolID] | 	ti, ok := toolProducts[toolID] | ||||||
| 	return ti, ok | 	if !ok { | ||||||
|  | 		return nil, false | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	ti_ := deepcopy.Copy(ti).(*ToolInfo) | ||||||
|  | 
 | ||||||
|  | 	return ti_, true | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func MGetToolProducts(toolIDs []int64) []*ToolInfo { | func MGetToolProducts(toolIDs []int64) []*ToolInfo { | ||||||
| 	tools := make([]*ToolInfo, 0, len(toolIDs)) | 	tools := make([]*ToolInfo, 0, len(toolIDs)) | ||||||
| 	for _, toolID := range toolIDs { | 	for _, toolID := range toolIDs { | ||||||
| 		ti, ok := toolProducts[toolID] | 		ti, ok := GetToolProduct(toolID) | ||||||
| 		if !ok { | 		if !ok { | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
|  | 
 | ||||||
| 		tools = append(tools, ti) | 		tools = append(tools, ti) | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
| 	return tools | 	return tools | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -20,7 +20,9 @@ import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 
 | 
 | ||||||
|  | 	"gorm.io/gen/field" | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
|  | 	"gorm.io/gorm/clause" | ||||||
| 
 | 
 | ||||||
| 	"github.com/coze-dev/coze-studio/backend/domain/plugin/entity" | 	"github.com/coze-dev/coze-studio/backend/domain/plugin/entity" | ||||||
| 	"github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/model" | 	"github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/model" | ||||||
|  | @ -55,6 +57,26 @@ func (a agentToolDraftPO) ToDO() *entity.ToolInfo { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (at *AgentToolDraftDAO) getSelected(opt *ToolSelectedOption) (selected []field.Expr) { | ||||||
|  | 	if opt == nil { | ||||||
|  | 		return selected | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	table := at.query.AgentToolDraft | ||||||
|  | 
 | ||||||
|  | 	if opt.ToolID { | ||||||
|  | 		selected = append(selected, table.ToolID) | ||||||
|  | 	} | ||||||
|  | 	if opt.ToolMethod { | ||||||
|  | 		selected = append(selected, table.Method) | ||||||
|  | 	} | ||||||
|  | 	if opt.ToolSubURL { | ||||||
|  | 		selected = append(selected, table.SubURL) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return selected | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (at *AgentToolDraftDAO) Get(ctx context.Context, agentID, toolID int64) (tool *entity.ToolInfo, exist bool, err error) { | func (at *AgentToolDraftDAO) Get(ctx context.Context, agentID, toolID int64) (tool *entity.ToolInfo, exist bool, err error) { | ||||||
| 	table := at.query.AgentToolDraft | 	table := at.query.AgentToolDraft | ||||||
| 	tl, err := table.WithContext(ctx). | 	tl, err := table.WithContext(ctx). | ||||||
|  | @ -120,13 +142,14 @@ func (at *AgentToolDraftDAO) MGet(ctx context.Context, agentID int64, toolIDs [] | ||||||
| 	return tools, nil | 	return tools, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (at *AgentToolDraftDAO) GetAll(ctx context.Context, agentID int64) (tools []*entity.ToolInfo, err error) { | func (at *AgentToolDraftDAO) GetAll(ctx context.Context, agentID int64, opt *ToolSelectedOption) (tools []*entity.ToolInfo, err error) { | ||||||
| 	const limit = 20 | 	const limit = 20 | ||||||
| 	table := at.query.AgentToolDraft | 	table := at.query.AgentToolDraft | ||||||
| 	cursor := int64(0) | 	cursor := int64(0) | ||||||
| 
 | 
 | ||||||
| 	for { | 	for { | ||||||
| 		tls, err := table.WithContext(ctx). | 		tls, err := table.WithContext(ctx). | ||||||
|  | 			Select(at.getSelected(opt)...). | ||||||
| 			Where( | 			Where( | ||||||
| 				table.AgentID.Eq(agentID), | 				table.AgentID.Eq(agentID), | ||||||
| 				table.ID.Gt(cursor), | 				table.ID.Gt(cursor), | ||||||
|  | @ -171,6 +194,16 @@ func (at *AgentToolDraftDAO) UpdateWithToolName(ctx context.Context, agentID int | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (at *AgentToolDraftDAO) BatchCreateWithTX(ctx context.Context, tx *query.QueryTx, agentID int64, tools []*entity.ToolInfo) (err error) { | func (at *AgentToolDraftDAO) BatchCreateWithTX(ctx context.Context, tx *query.QueryTx, agentID int64, tools []*entity.ToolInfo) (err error) { | ||||||
|  | 	return at.batchCreateWithTX(ctx, tx, agentID, tools, false) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (at *AgentToolDraftDAO) BatchCreateIgnoreConflictWithTX(ctx context.Context, tx *query.QueryTx, agentID int64, tools []*entity.ToolInfo) (err error) { | ||||||
|  | 	return at.batchCreateWithTX(ctx, tx, agentID, tools, true) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (at *AgentToolDraftDAO) batchCreateWithTX(ctx context.Context, tx *query.QueryTx, agentID int64, | ||||||
|  | 	tools []*entity.ToolInfo, ignoreConflict bool) (err error) { | ||||||
|  | 
 | ||||||
| 	tls := make([]*model.AgentToolDraft, 0, len(tools)) | 	tls := make([]*model.AgentToolDraft, 0, len(tools)) | ||||||
| 	for _, tl := range tools { | 	for _, tl := range tools { | ||||||
| 		id, err := at.idGen.GenID(ctx) | 		id, err := at.idGen.GenID(ctx) | ||||||
|  | @ -192,7 +225,13 @@ func (at *AgentToolDraftDAO) BatchCreateWithTX(ctx context.Context, tx *query.Qu | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	table := tx.AgentToolDraft | 	table := tx.AgentToolDraft | ||||||
|  | 
 | ||||||
|  | 	if ignoreConflict { | ||||||
|  | 		err = table.WithContext(ctx).Clauses(clause.OnConflict{DoNothing: true}). | ||||||
|  | 			CreateInBatches(tls, 20) | ||||||
|  | 	} else { | ||||||
| 		err = table.WithContext(ctx).CreateInBatches(tls, 20) | 		err = table.WithContext(ctx).CreateInBatches(tls, 20) | ||||||
|  | 	} | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  | @ -200,27 +239,6 @@ func (at *AgentToolDraftDAO) BatchCreateWithTX(ctx context.Context, tx *query.Qu | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (at *AgentToolDraftDAO) DeleteAll(ctx context.Context, agentID int64) (err error) { |  | ||||||
| 	const limit = 20 |  | ||||||
| 	table := at.query.AgentToolDraft |  | ||||||
| 
 |  | ||||||
| 	for { |  | ||||||
| 		info, err := table.WithContext(ctx). |  | ||||||
| 			Where(table.AgentID.Eq(agentID)). |  | ||||||
| 			Limit(limit). |  | ||||||
| 			Delete() |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if info.RowsAffected == 0 || info.RowsAffected < limit { |  | ||||||
| 			break |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (at *AgentToolDraftDAO) GetAllPluginIDs(ctx context.Context, agentID int64) (pluginIDs []int64, err error) { | func (at *AgentToolDraftDAO) GetAllPluginIDs(ctx context.Context, agentID int64) (pluginIDs []int64, err error) { | ||||||
| 	const size = 100 | 	const size = 100 | ||||||
| 	table := at.query.AgentToolDraft | 	table := at.query.AgentToolDraft | ||||||
|  | @ -254,22 +272,22 @@ func (at *AgentToolDraftDAO) GetAllPluginIDs(ctx context.Context, agentID int64) | ||||||
| 	return slices.Unique(pluginIDs), nil | 	return slices.Unique(pluginIDs), nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (at *AgentToolDraftDAO) DeleteAllWithTX(ctx context.Context, tx *query.QueryTx, agentID int64) (err error) { | func (at *AgentToolDraftDAO) DeleteWithTX(ctx context.Context, tx *query.QueryTx, agentID int64, toolIDs []int64) (err error) { | ||||||
| 	const limit = 20 | 	const limit = 20 | ||||||
| 	table := tx.AgentToolDraft | 	table := tx.AgentToolDraft | ||||||
| 
 | 
 | ||||||
| 	for { | 	chunks := slices.Chunks(toolIDs, limit) | ||||||
| 		info, err := table.WithContext(ctx). | 	for _, chunk := range chunks { | ||||||
| 			Where(table.AgentID.Eq(agentID)). | 		_, err = table.WithContext(ctx). | ||||||
|  | 			Where( | ||||||
|  | 				table.AgentID.Eq(agentID), | ||||||
|  | 				table.ToolID.In(chunk...), | ||||||
|  | 			). | ||||||
| 			Limit(limit). | 			Limit(limit). | ||||||
| 			Delete() | 			Delete() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 
 |  | ||||||
| 		if info.RowsAffected == 0 || info.RowsAffected < limit { |  | ||||||
| 			break |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -114,7 +114,6 @@ func (t *ToolDraftDAO) Create(ctx context.Context, tool *entity.ToolInfo) (toolI | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (t *ToolDraftDAO) genToolID(ctx context.Context) (id int64, err error) { | func (t *ToolDraftDAO) genToolID(ctx context.Context) (id int64, err error) { | ||||||
| 
 |  | ||||||
| 	retryTimes := 5 | 	retryTimes := 5 | ||||||
| 
 | 
 | ||||||
| 	for i := 0; i < retryTimes; i++ { | 	for i := 0; i < retryTimes; i++ { | ||||||
|  | @ -123,11 +122,13 @@ func (t *ToolDraftDAO) genToolID(ctx context.Context) (id int64, err error) { | ||||||
| 			return 0, err | 			return 0, err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if _, ok := conf.GetToolProduct(id); !ok { | 		_, ok := conf.GetToolProduct(id) | ||||||
|  | 		if !ok { | ||||||
| 			break | 			break | ||||||
| 		} | 		} | ||||||
|  | 
 | ||||||
| 		if i == retryTimes-1 { | 		if i == retryTimes-1 { | ||||||
| 			return 0, fmt.Errorf("id %d is confilict with product tool id.", id) | 			return 0, fmt.Errorf("id %d is confilict with product tool id", id) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -216,6 +216,7 @@ func (t *toolRepoImpl) GetOnlineTool(ctx context.Context, toolID int64) (tool *e | ||||||
| 
 | 
 | ||||||
| func (t *toolRepoImpl) MGetOnlineTools(ctx context.Context, toolIDs []int64, opts ...ToolSelectedOptions) (tools []*entity.ToolInfo, err error) { | func (t *toolRepoImpl) MGetOnlineTools(ctx context.Context, toolIDs []int64, opts ...ToolSelectedOptions) (tools []*entity.ToolInfo, err error) { | ||||||
| 	toolProducts := pluginConf.MGetToolProducts(toolIDs) | 	toolProducts := pluginConf.MGetToolProducts(toolIDs) | ||||||
|  | 
 | ||||||
| 	tools = slices.Transform(toolProducts, func(tool *pluginConf.ToolInfo) *entity.ToolInfo { | 	tools = slices.Transform(toolProducts, func(tool *pluginConf.ToolInfo) *entity.ToolInfo { | ||||||
| 		return tool.Info | 		return tool.Info | ||||||
| 	}) | 	}) | ||||||
|  | @ -269,13 +270,38 @@ func (t *toolRepoImpl) MGetVersionTools(ctx context.Context, versionTools []enti | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (t *toolRepoImpl) BindDraftAgentTools(ctx context.Context, agentID int64, toolIDs []int64) (err error) { | func (t *toolRepoImpl) BindDraftAgentTools(ctx context.Context, agentID int64, toolIDs []int64) (err error) { | ||||||
| 	onlineTools, err := t.MGetOnlineTools(ctx, toolIDs) | 	opt := &dal.ToolSelectedOption{ | ||||||
|  | 		ToolID: true, | ||||||
|  | 	} | ||||||
|  | 	draftAgentTools, err := t.agentToolDraftDAO.GetAll(ctx, agentID, opt) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if len(onlineTools) == 0 { | 	draftAgentToolIDMap := slices.ToMap(draftAgentTools, func(tool *entity.ToolInfo) (int64, bool) { | ||||||
| 		return t.agentToolDraftDAO.DeleteAll(ctx, agentID) | 		return tool.ID, true | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	bindToolIDMap := slices.ToMap(toolIDs, func(toolID int64) (int64, bool) { | ||||||
|  | 		return toolID, true | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	newBindToolIDs := make([]int64, 0, len(toolIDs)) | ||||||
|  | 	for _, toolID := range toolIDs { | ||||||
|  | 		_, ok := draftAgentToolIDMap[toolID] | ||||||
|  | 		if ok { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		newBindToolIDs = append(newBindToolIDs, toolID) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	removeToolIDs := make([]int64, 0, len(draftAgentTools)) | ||||||
|  | 	for toolID := range draftAgentToolIDMap { | ||||||
|  | 		_, ok := bindToolIDMap[toolID] | ||||||
|  | 		if ok { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		removeToolIDs = append(removeToolIDs, toolID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	tx := t.query.Begin() | 	tx := t.query.Begin() | ||||||
|  | @ -298,17 +324,29 @@ func (t *toolRepoImpl) BindDraftAgentTools(ctx context.Context, agentID int64, t | ||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
| 
 | 
 | ||||||
| 	err = t.agentToolDraftDAO.DeleteAllWithTX(ctx, tx, agentID) | 	onlineTools, err := t.MGetOnlineTools(ctx, newBindToolIDs) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err = t.agentToolDraftDAO.BatchCreateWithTX(ctx, tx, agentID, onlineTools) | 	if len(onlineTools) > 0 { | ||||||
|  | 		err = t.agentToolDraftDAO.BatchCreateIgnoreConflictWithTX(ctx, tx, agentID, onlineTools) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	err = t.agentToolDraftDAO.DeleteWithTX(ctx, tx, agentID, removeToolIDs) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return tx.Commit() | 	err = tx.Commit() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (t *toolRepoImpl) GetAgentPluginIDs(ctx context.Context, agentID int64) (pluginIDs []int64, err error) { | func (t *toolRepoImpl) GetAgentPluginIDs(ctx context.Context, agentID int64) (pluginIDs []int64, err error) { | ||||||
|  | @ -316,7 +354,7 @@ func (t *toolRepoImpl) GetAgentPluginIDs(ctx context.Context, agentID int64) (pl | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (t *toolRepoImpl) DuplicateDraftAgentTools(ctx context.Context, fromAgentID, toAgentID int64) (err error) { | func (t *toolRepoImpl) DuplicateDraftAgentTools(ctx context.Context, fromAgentID, toAgentID int64) (err error) { | ||||||
| 	tools, err := t.agentToolDraftDAO.GetAll(ctx, fromAgentID) | 	tools, err := t.agentToolDraftDAO.GetAll(ctx, fromAgentID, nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  | @ -370,7 +408,7 @@ func (t *toolRepoImpl) UpdateDraftAgentTool(ctx context.Context, req *UpdateDraf | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (t *toolRepoImpl) GetSpaceAllDraftAgentTools(ctx context.Context, agentID int64) (tools []*entity.ToolInfo, err error) { | func (t *toolRepoImpl) GetSpaceAllDraftAgentTools(ctx context.Context, agentID int64) (tools []*entity.ToolInfo, err error) { | ||||||
| 	return t.agentToolDraftDAO.GetAll(ctx, agentID) | 	return t.agentToolDraftDAO.GetAll(ctx, agentID, nil) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (t *toolRepoImpl) GetVersionAgentTool(ctx context.Context, agentID int64, vAgentTool entity.VersionAgentTool) (tool *entity.ToolInfo, exist bool, err error) { | func (t *toolRepoImpl) GetVersionAgentTool(ctx context.Context, agentID int64, vAgentTool entity.VersionAgentTool) (tool *entity.ToolInfo, exist bool, err error) { | ||||||
|  |  | ||||||
|  | @ -596,7 +596,7 @@ func genRequestString(req *http.Request, body []byte) (string, error) { | ||||||
| 		return "", fmt.Errorf("[genRequestString] marshal failed, err=%s", err) | 		return "", fmt.Errorf("[genRequestString] marshal failed, err=%s", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if body != nil { | 	if len(body) > 0 { | ||||||
| 		requestStr, err = sjson.SetRaw(requestStr, "body", string(body)) | 		requestStr, err = sjson.SetRaw(requestStr, "body", string(body)) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return "", fmt.Errorf("[genRequestString] set body failed, err=%s", err) | 			return "", fmt.Errorf("[genRequestString] set body failed, err=%s", err) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue