fix(plugin): setting the default value of the agent tool does not take effect (#341)
This commit is contained in:
parent
357da72a52
commit
9660a85454
|
|
@ -22,6 +22,8 @@ import (
|
||||||
|
|
||||||
"github.com/getkin/kin-openapi/openapi3"
|
"github.com/getkin/kin-openapi/openapi3"
|
||||||
|
|
||||||
|
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||||
|
|
||||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||||
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
||||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||||
|
|
|
||||||
|
|
@ -21,9 +21,10 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
|
||||||
|
"github.com/jinzhu/copier"
|
||||||
|
|
||||||
"github.com/cloudwego/eino/compose"
|
"github.com/cloudwego/eino/compose"
|
||||||
"github.com/cloudwego/eino/schema"
|
"github.com/cloudwego/eino/schema"
|
||||||
"github.com/jinzhu/copier"
|
|
||||||
|
|
||||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_common"
|
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_common"
|
||||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin"
|
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin"
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/getkin/kin-openapi/openapi3"
|
"github.com/getkin/kin-openapi/openapi3"
|
||||||
|
"github.com/mohae/deepcopy"
|
||||||
"golang.org/x/mod/semver"
|
"golang.org/x/mod/semver"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
|
|
||||||
|
|
@ -59,18 +60,26 @@ var (
|
||||||
|
|
||||||
func GetToolProduct(toolID int64) (*ToolInfo, bool) {
|
func GetToolProduct(toolID int64) (*ToolInfo, bool) {
|
||||||
ti, ok := toolProducts[toolID]
|
ti, ok := toolProducts[toolID]
|
||||||
return ti, ok
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
ti_ := deepcopy.Copy(ti).(*ToolInfo)
|
||||||
|
|
||||||
|
return ti_, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func MGetToolProducts(toolIDs []int64) []*ToolInfo {
|
func MGetToolProducts(toolIDs []int64) []*ToolInfo {
|
||||||
tools := make([]*ToolInfo, 0, len(toolIDs))
|
tools := make([]*ToolInfo, 0, len(toolIDs))
|
||||||
for _, toolID := range toolIDs {
|
for _, toolID := range toolIDs {
|
||||||
ti, ok := toolProducts[toolID]
|
ti, ok := GetToolProduct(toolID)
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
tools = append(tools, ti)
|
tools = append(tools, ti)
|
||||||
}
|
}
|
||||||
|
|
||||||
return tools
|
return tools
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,9 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
|
"gorm.io/gen/field"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
|
||||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/model"
|
"github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/model"
|
||||||
|
|
@ -55,6 +57,26 @@ func (a agentToolDraftPO) ToDO() *entity.ToolInfo {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (at *AgentToolDraftDAO) getSelected(opt *ToolSelectedOption) (selected []field.Expr) {
|
||||||
|
if opt == nil {
|
||||||
|
return selected
|
||||||
|
}
|
||||||
|
|
||||||
|
table := at.query.AgentToolDraft
|
||||||
|
|
||||||
|
if opt.ToolID {
|
||||||
|
selected = append(selected, table.ToolID)
|
||||||
|
}
|
||||||
|
if opt.ToolMethod {
|
||||||
|
selected = append(selected, table.Method)
|
||||||
|
}
|
||||||
|
if opt.ToolSubURL {
|
||||||
|
selected = append(selected, table.SubURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
return selected
|
||||||
|
}
|
||||||
|
|
||||||
func (at *AgentToolDraftDAO) Get(ctx context.Context, agentID, toolID int64) (tool *entity.ToolInfo, exist bool, err error) {
|
func (at *AgentToolDraftDAO) Get(ctx context.Context, agentID, toolID int64) (tool *entity.ToolInfo, exist bool, err error) {
|
||||||
table := at.query.AgentToolDraft
|
table := at.query.AgentToolDraft
|
||||||
tl, err := table.WithContext(ctx).
|
tl, err := table.WithContext(ctx).
|
||||||
|
|
@ -120,13 +142,14 @@ func (at *AgentToolDraftDAO) MGet(ctx context.Context, agentID int64, toolIDs []
|
||||||
return tools, nil
|
return tools, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (at *AgentToolDraftDAO) GetAll(ctx context.Context, agentID int64) (tools []*entity.ToolInfo, err error) {
|
func (at *AgentToolDraftDAO) GetAll(ctx context.Context, agentID int64, opt *ToolSelectedOption) (tools []*entity.ToolInfo, err error) {
|
||||||
const limit = 20
|
const limit = 20
|
||||||
table := at.query.AgentToolDraft
|
table := at.query.AgentToolDraft
|
||||||
cursor := int64(0)
|
cursor := int64(0)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
tls, err := table.WithContext(ctx).
|
tls, err := table.WithContext(ctx).
|
||||||
|
Select(at.getSelected(opt)...).
|
||||||
Where(
|
Where(
|
||||||
table.AgentID.Eq(agentID),
|
table.AgentID.Eq(agentID),
|
||||||
table.ID.Gt(cursor),
|
table.ID.Gt(cursor),
|
||||||
|
|
@ -171,6 +194,16 @@ func (at *AgentToolDraftDAO) UpdateWithToolName(ctx context.Context, agentID int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (at *AgentToolDraftDAO) BatchCreateWithTX(ctx context.Context, tx *query.QueryTx, agentID int64, tools []*entity.ToolInfo) (err error) {
|
func (at *AgentToolDraftDAO) BatchCreateWithTX(ctx context.Context, tx *query.QueryTx, agentID int64, tools []*entity.ToolInfo) (err error) {
|
||||||
|
return at.batchCreateWithTX(ctx, tx, agentID, tools, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (at *AgentToolDraftDAO) BatchCreateIgnoreConflictWithTX(ctx context.Context, tx *query.QueryTx, agentID int64, tools []*entity.ToolInfo) (err error) {
|
||||||
|
return at.batchCreateWithTX(ctx, tx, agentID, tools, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (at *AgentToolDraftDAO) batchCreateWithTX(ctx context.Context, tx *query.QueryTx, agentID int64,
|
||||||
|
tools []*entity.ToolInfo, ignoreConflict bool) (err error) {
|
||||||
|
|
||||||
tls := make([]*model.AgentToolDraft, 0, len(tools))
|
tls := make([]*model.AgentToolDraft, 0, len(tools))
|
||||||
for _, tl := range tools {
|
for _, tl := range tools {
|
||||||
id, err := at.idGen.GenID(ctx)
|
id, err := at.idGen.GenID(ctx)
|
||||||
|
|
@ -192,7 +225,13 @@ func (at *AgentToolDraftDAO) BatchCreateWithTX(ctx context.Context, tx *query.Qu
|
||||||
}
|
}
|
||||||
|
|
||||||
table := tx.AgentToolDraft
|
table := tx.AgentToolDraft
|
||||||
|
|
||||||
|
if ignoreConflict {
|
||||||
|
err = table.WithContext(ctx).Clauses(clause.OnConflict{DoNothing: true}).
|
||||||
|
CreateInBatches(tls, 20)
|
||||||
|
} else {
|
||||||
err = table.WithContext(ctx).CreateInBatches(tls, 20)
|
err = table.WithContext(ctx).CreateInBatches(tls, 20)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -200,27 +239,6 @@ func (at *AgentToolDraftDAO) BatchCreateWithTX(ctx context.Context, tx *query.Qu
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (at *AgentToolDraftDAO) DeleteAll(ctx context.Context, agentID int64) (err error) {
|
|
||||||
const limit = 20
|
|
||||||
table := at.query.AgentToolDraft
|
|
||||||
|
|
||||||
for {
|
|
||||||
info, err := table.WithContext(ctx).
|
|
||||||
Where(table.AgentID.Eq(agentID)).
|
|
||||||
Limit(limit).
|
|
||||||
Delete()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if info.RowsAffected == 0 || info.RowsAffected < limit {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (at *AgentToolDraftDAO) GetAllPluginIDs(ctx context.Context, agentID int64) (pluginIDs []int64, err error) {
|
func (at *AgentToolDraftDAO) GetAllPluginIDs(ctx context.Context, agentID int64) (pluginIDs []int64, err error) {
|
||||||
const size = 100
|
const size = 100
|
||||||
table := at.query.AgentToolDraft
|
table := at.query.AgentToolDraft
|
||||||
|
|
@ -254,22 +272,22 @@ func (at *AgentToolDraftDAO) GetAllPluginIDs(ctx context.Context, agentID int64)
|
||||||
return slices.Unique(pluginIDs), nil
|
return slices.Unique(pluginIDs), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (at *AgentToolDraftDAO) DeleteAllWithTX(ctx context.Context, tx *query.QueryTx, agentID int64) (err error) {
|
func (at *AgentToolDraftDAO) DeleteWithTX(ctx context.Context, tx *query.QueryTx, agentID int64, toolIDs []int64) (err error) {
|
||||||
const limit = 20
|
const limit = 20
|
||||||
table := tx.AgentToolDraft
|
table := tx.AgentToolDraft
|
||||||
|
|
||||||
for {
|
chunks := slices.Chunks(toolIDs, limit)
|
||||||
info, err := table.WithContext(ctx).
|
for _, chunk := range chunks {
|
||||||
Where(table.AgentID.Eq(agentID)).
|
_, err = table.WithContext(ctx).
|
||||||
|
Where(
|
||||||
|
table.AgentID.Eq(agentID),
|
||||||
|
table.ToolID.In(chunk...),
|
||||||
|
).
|
||||||
Limit(limit).
|
Limit(limit).
|
||||||
Delete()
|
Delete()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if info.RowsAffected == 0 || info.RowsAffected < limit {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
@ -114,7 +114,6 @@ func (t *ToolDraftDAO) Create(ctx context.Context, tool *entity.ToolInfo) (toolI
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ToolDraftDAO) genToolID(ctx context.Context) (id int64, err error) {
|
func (t *ToolDraftDAO) genToolID(ctx context.Context) (id int64, err error) {
|
||||||
|
|
||||||
retryTimes := 5
|
retryTimes := 5
|
||||||
|
|
||||||
for i := 0; i < retryTimes; i++ {
|
for i := 0; i < retryTimes; i++ {
|
||||||
|
|
@ -123,11 +122,13 @@ func (t *ToolDraftDAO) genToolID(ctx context.Context) (id int64, err error) {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := conf.GetToolProduct(id); !ok {
|
_, ok := conf.GetToolProduct(id)
|
||||||
|
if !ok {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if i == retryTimes-1 {
|
if i == retryTimes-1 {
|
||||||
return 0, fmt.Errorf("id %d is confilict with product tool id.", id)
|
return 0, fmt.Errorf("id %d is confilict with product tool id", id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -216,6 +216,7 @@ func (t *toolRepoImpl) GetOnlineTool(ctx context.Context, toolID int64) (tool *e
|
||||||
|
|
||||||
func (t *toolRepoImpl) MGetOnlineTools(ctx context.Context, toolIDs []int64, opts ...ToolSelectedOptions) (tools []*entity.ToolInfo, err error) {
|
func (t *toolRepoImpl) MGetOnlineTools(ctx context.Context, toolIDs []int64, opts ...ToolSelectedOptions) (tools []*entity.ToolInfo, err error) {
|
||||||
toolProducts := pluginConf.MGetToolProducts(toolIDs)
|
toolProducts := pluginConf.MGetToolProducts(toolIDs)
|
||||||
|
|
||||||
tools = slices.Transform(toolProducts, func(tool *pluginConf.ToolInfo) *entity.ToolInfo {
|
tools = slices.Transform(toolProducts, func(tool *pluginConf.ToolInfo) *entity.ToolInfo {
|
||||||
return tool.Info
|
return tool.Info
|
||||||
})
|
})
|
||||||
|
|
@ -269,13 +270,38 @@ func (t *toolRepoImpl) MGetVersionTools(ctx context.Context, versionTools []enti
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *toolRepoImpl) BindDraftAgentTools(ctx context.Context, agentID int64, toolIDs []int64) (err error) {
|
func (t *toolRepoImpl) BindDraftAgentTools(ctx context.Context, agentID int64, toolIDs []int64) (err error) {
|
||||||
onlineTools, err := t.MGetOnlineTools(ctx, toolIDs)
|
opt := &dal.ToolSelectedOption{
|
||||||
|
ToolID: true,
|
||||||
|
}
|
||||||
|
draftAgentTools, err := t.agentToolDraftDAO.GetAll(ctx, agentID, opt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(onlineTools) == 0 {
|
draftAgentToolIDMap := slices.ToMap(draftAgentTools, func(tool *entity.ToolInfo) (int64, bool) {
|
||||||
return t.agentToolDraftDAO.DeleteAll(ctx, agentID)
|
return tool.ID, true
|
||||||
|
})
|
||||||
|
|
||||||
|
bindToolIDMap := slices.ToMap(toolIDs, func(toolID int64) (int64, bool) {
|
||||||
|
return toolID, true
|
||||||
|
})
|
||||||
|
|
||||||
|
newBindToolIDs := make([]int64, 0, len(toolIDs))
|
||||||
|
for _, toolID := range toolIDs {
|
||||||
|
_, ok := draftAgentToolIDMap[toolID]
|
||||||
|
if ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newBindToolIDs = append(newBindToolIDs, toolID)
|
||||||
|
}
|
||||||
|
|
||||||
|
removeToolIDs := make([]int64, 0, len(draftAgentTools))
|
||||||
|
for toolID := range draftAgentToolIDMap {
|
||||||
|
_, ok := bindToolIDMap[toolID]
|
||||||
|
if ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
removeToolIDs = append(removeToolIDs, toolID)
|
||||||
}
|
}
|
||||||
|
|
||||||
tx := t.query.Begin()
|
tx := t.query.Begin()
|
||||||
|
|
@ -298,17 +324,29 @@ func (t *toolRepoImpl) BindDraftAgentTools(ctx context.Context, agentID int64, t
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = t.agentToolDraftDAO.DeleteAllWithTX(ctx, tx, agentID)
|
onlineTools, err := t.MGetOnlineTools(ctx, newBindToolIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = t.agentToolDraftDAO.BatchCreateWithTX(ctx, tx, agentID, onlineTools)
|
if len(onlineTools) > 0 {
|
||||||
|
err = t.agentToolDraftDAO.BatchCreateIgnoreConflictWithTX(ctx, tx, agentID, onlineTools)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = t.agentToolDraftDAO.DeleteWithTX(ctx, tx, agentID, removeToolIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return tx.Commit()
|
err = tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *toolRepoImpl) GetAgentPluginIDs(ctx context.Context, agentID int64) (pluginIDs []int64, err error) {
|
func (t *toolRepoImpl) GetAgentPluginIDs(ctx context.Context, agentID int64) (pluginIDs []int64, err error) {
|
||||||
|
|
@ -316,7 +354,7 @@ func (t *toolRepoImpl) GetAgentPluginIDs(ctx context.Context, agentID int64) (pl
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *toolRepoImpl) DuplicateDraftAgentTools(ctx context.Context, fromAgentID, toAgentID int64) (err error) {
|
func (t *toolRepoImpl) DuplicateDraftAgentTools(ctx context.Context, fromAgentID, toAgentID int64) (err error) {
|
||||||
tools, err := t.agentToolDraftDAO.GetAll(ctx, fromAgentID)
|
tools, err := t.agentToolDraftDAO.GetAll(ctx, fromAgentID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -370,7 +408,7 @@ func (t *toolRepoImpl) UpdateDraftAgentTool(ctx context.Context, req *UpdateDraf
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *toolRepoImpl) GetSpaceAllDraftAgentTools(ctx context.Context, agentID int64) (tools []*entity.ToolInfo, err error) {
|
func (t *toolRepoImpl) GetSpaceAllDraftAgentTools(ctx context.Context, agentID int64) (tools []*entity.ToolInfo, err error) {
|
||||||
return t.agentToolDraftDAO.GetAll(ctx, agentID)
|
return t.agentToolDraftDAO.GetAll(ctx, agentID, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *toolRepoImpl) GetVersionAgentTool(ctx context.Context, agentID int64, vAgentTool entity.VersionAgentTool) (tool *entity.ToolInfo, exist bool, err error) {
|
func (t *toolRepoImpl) GetVersionAgentTool(ctx context.Context, agentID int64, vAgentTool entity.VersionAgentTool) (tool *entity.ToolInfo, exist bool, err error) {
|
||||||
|
|
|
||||||
|
|
@ -596,7 +596,7 @@ func genRequestString(req *http.Request, body []byte) (string, error) {
|
||||||
return "", fmt.Errorf("[genRequestString] marshal failed, err=%s", err)
|
return "", fmt.Errorf("[genRequestString] marshal failed, err=%s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if body != nil {
|
if len(body) > 0 {
|
||||||
requestStr, err = sjson.SetRaw(requestStr, "body", string(body))
|
requestStr, err = sjson.SetRaw(requestStr, "body", string(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("[genRequestString] set body failed, err=%s", err)
|
return "", fmt.Errorf("[genRequestString] set body failed, err=%s", err)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue