302 lines
		
	
	
		
			7.4 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			302 lines
		
	
	
		
			7.4 KiB
		
	
	
	
		
			Go
		
	
	
	
/*
 * Copyright 2025 coze-dev Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
 | 
						|
 | 
						|
package agentflow
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"fmt"
 | 
						|
	"sort"
 | 
						|
	"strconv"
 | 
						|
	"strings"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/cloudwego/eino/components/tool"
 | 
						|
	"github.com/cloudwego/eino/components/tool/utils"
 | 
						|
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
 | 
						|
 | 
						|
	"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_common"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/api/model/table"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatabase"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/infra/impl/sqlparser"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
 | 
						|
)
 | 
						|
 | 
						|
// TimeFormat is the layout, written with Go's reference time, that is used
// to render time.Time values as "YYYY-MM-DD hh:mm:ss" text.
const TimeFormat = "2006-01-02 15:04:05"
 | 
						|
 | 
						|
type databaseConfig struct {
 | 
						|
	agentIdentity *entity.AgentIdentity
 | 
						|
	userID        string
 | 
						|
	spaceID       int64
 | 
						|
 | 
						|
	databaseConf []*bot_common.Database
 | 
						|
}
 | 
						|
 | 
						|
type databaseTool struct {
 | 
						|
	agentIdentity *entity.AgentIdentity
 | 
						|
	connectorUID  string
 | 
						|
	spaceID       int64
 | 
						|
 | 
						|
	databaseID int64
 | 
						|
 | 
						|
	name           string
 | 
						|
	promptDisabled bool
 | 
						|
}
 | 
						|
 | 
						|
// ExecuteSQLRequest is the argument object of a database tool call.
//
// The jsonschema description deliberately contains no commas: jsonschema
// struct tags are comma-separated key lists, so a bare comma inside the
// description would cut the text short at that point.
type ExecuteSQLRequest struct {
	// SQL is the raw statement to run against the tool's table.
	SQL string `json:"sql" jsonschema:"description=SQL query to execute against the database. You can use standard SQL syntax like SELECT/INSERT/UPDATE/DELETE."`
}
 | 
						|
 | 
						|
func (d *databaseTool) Invoke(ctx context.Context, req ExecuteSQLRequest) (string, error) {
 | 
						|
	if req.SQL == "" {
 | 
						|
		return "", fmt.Errorf("sql is empty")
 | 
						|
	}
 | 
						|
	if d.promptDisabled {
 | 
						|
		return "the tool to be called is not available", nil
 | 
						|
	}
 | 
						|
 | 
						|
	tableType := table.TableType_OnlineTable
 | 
						|
	if d.agentIdentity.IsDraft {
 | 
						|
		tableType = table.TableType_DraftTable
 | 
						|
	}
 | 
						|
 | 
						|
	tableName, err := sqlparser.NewSQLParser().GetTableName(req.SQL)
 | 
						|
	if err != nil {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
	if tableName != d.name {
 | 
						|
		return "", fmt.Errorf("sql table name %s not match database %s", tableName, d.name)
 | 
						|
	}
 | 
						|
 | 
						|
	eReq := &service.ExecuteSQLRequest{
 | 
						|
		SQL:         &req.SQL,
 | 
						|
		DatabaseID:  d.databaseID,
 | 
						|
		SQLType:     database.SQLType_Raw,
 | 
						|
		UserID:      d.connectorUID,
 | 
						|
		SpaceID:     d.spaceID,
 | 
						|
		ConnectorID: ptr.Of(d.agentIdentity.ConnectorID),
 | 
						|
		TableType:   tableType,
 | 
						|
	}
 | 
						|
 | 
						|
	sqlResult, err := crossdatabase.DefaultSVC().ExecuteSQL(ctx, eReq)
 | 
						|
	if err != nil {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
 | 
						|
	return formatDatabaseResult(sqlResult), nil
 | 
						|
}
 | 
						|
 | 
						|
func newDatabaseTools(ctx context.Context, conf *databaseConfig) ([]tool.InvokableTool, error) {
 | 
						|
	if conf == nil || len(conf.databaseConf) == 0 {
 | 
						|
		return nil, nil
 | 
						|
	}
 | 
						|
 | 
						|
	dbInfos := conf.databaseConf
 | 
						|
 | 
						|
	tools := make([]tool.InvokableTool, 0, len(dbInfos))
 | 
						|
	for _, dbInfo := range dbInfos {
 | 
						|
		tID, err := strconv.ParseInt(dbInfo.GetTableId(), 10, 64)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
		d := &databaseTool{
 | 
						|
			spaceID:        conf.spaceID,
 | 
						|
			connectorUID:   conf.userID,
 | 
						|
			agentIdentity:  conf.agentIdentity,
 | 
						|
			promptDisabled: dbInfo.GetPromptDisabled(),
 | 
						|
			name:           dbInfo.GetTableName(),
 | 
						|
			databaseID:     tID,
 | 
						|
		}
 | 
						|
 | 
						|
		dbTool, err := utils.InferTool(dbInfo.GetTableName(), buildDatabaseToolDescription(dbInfo), d.Invoke)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
 | 
						|
		tools = append(tools, dbTool)
 | 
						|
	}
 | 
						|
 | 
						|
	return tools, nil
 | 
						|
}
 | 
						|
 | 
						|
func buildDatabaseToolDescription(tableInfo *bot_common.Database) string {
 | 
						|
	var sb strings.Builder
 | 
						|
 | 
						|
	sb.WriteString(fmt.Sprintf("Mysql query tool. Table name is '%s'.", tableInfo.GetTableName()))
 | 
						|
	if tableInfo.GetTableDesc() != "" {
 | 
						|
		sb.WriteString(fmt.Sprintf(" This table's desc is %s.", tableInfo.GetTableDesc()))
 | 
						|
	}
 | 
						|
	sb.WriteString("\n\nTable structure:\n")
 | 
						|
 | 
						|
	for _, field := range tableInfo.FieldList {
 | 
						|
		if field.Name == nil || field.Type == nil {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		fieldType := getFieldTypeString(*field.Type)
 | 
						|
		sb.WriteString(fmt.Sprintf("- %s (%s)", *field.Name, fieldType))
 | 
						|
 | 
						|
		if field.Desc != nil && *field.Desc != "" {
 | 
						|
			sb.WriteString(fmt.Sprintf(": %s", *field.Desc))
 | 
						|
		}
 | 
						|
 | 
						|
		if field.MustRequired != nil && *field.MustRequired {
 | 
						|
			sb.WriteString(" (required)")
 | 
						|
		}
 | 
						|
 | 
						|
		sb.WriteString("\n")
 | 
						|
	}
 | 
						|
 | 
						|
	sb.WriteString("\nUse SQL to query this table. You can write SQL statements directly to operate.")
 | 
						|
	return sb.String()
 | 
						|
}
 | 
						|
 | 
						|
func getFieldTypeString(fieldType bot_common.FieldItemType) string {
 | 
						|
	switch fieldType {
 | 
						|
	case bot_common.FieldItemType_Text:
 | 
						|
		return "text"
 | 
						|
	case bot_common.FieldItemType_Number:
 | 
						|
		return "number"
 | 
						|
	case bot_common.FieldItemType_Date:
 | 
						|
		return "date"
 | 
						|
	case bot_common.FieldItemType_Float:
 | 
						|
		return "float"
 | 
						|
	case bot_common.FieldItemType_Boolean:
 | 
						|
		return "bool"
 | 
						|
	default:
 | 
						|
		return "invalid"
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func formatDatabaseResult(result *service.ExecuteSQLResponse) string {
 | 
						|
	var sb strings.Builder
 | 
						|
 | 
						|
	if len(result.Records) == 0 {
 | 
						|
		if result.RowsAffected == nil {
 | 
						|
			return "result is empty"
 | 
						|
		} else {
 | 
						|
			sb.WriteString("Rows affected: " + strconv.FormatInt(*result.RowsAffected, 10))
 | 
						|
			return sb.String()
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	var headers []string
 | 
						|
	if len(result.Records) > 0 {
 | 
						|
		firstRecord := result.Records[0]
 | 
						|
		for key := range firstRecord {
 | 
						|
			headers = append(headers, key)
 | 
						|
		}
 | 
						|
		sort.Strings(headers)
 | 
						|
	}
 | 
						|
 | 
						|
	if len(headers) == 0 {
 | 
						|
		return "no fields found in result"
 | 
						|
	}
 | 
						|
 | 
						|
	sb.WriteString("| ")
 | 
						|
	for _, header := range headers {
 | 
						|
		sb.WriteString(header + " | ")
 | 
						|
	}
 | 
						|
	sb.WriteString("\n")
 | 
						|
 | 
						|
	sb.WriteString("| ")
 | 
						|
	for range headers {
 | 
						|
		sb.WriteString("--- | ")
 | 
						|
	}
 | 
						|
	sb.WriteString("\n")
 | 
						|
 | 
						|
	fieldMap := slices.ToMap(result.FieldList, func(f *database.FieldItem) (string, *database.FieldItem) {
 | 
						|
		return f.Name, f
 | 
						|
	})
 | 
						|
	for _, record := range result.Records {
 | 
						|
		sb.WriteString("| ")
 | 
						|
		for _, header := range headers {
 | 
						|
			value, exists := record[header]
 | 
						|
			if !exists {
 | 
						|
				value = ""
 | 
						|
			}
 | 
						|
			fieldType := table.FieldItemType_Text
 | 
						|
			fValue, fExists := fieldMap[header]
 | 
						|
			if fExists {
 | 
						|
				fieldType = fValue.Type
 | 
						|
			}
 | 
						|
			sb.WriteString(convertDBValueToString(value, fieldType) + " | ")
 | 
						|
		}
 | 
						|
		sb.WriteString("\n")
 | 
						|
	}
 | 
						|
 | 
						|
	if result.RowsAffected != nil {
 | 
						|
		sb.WriteString("\nRows affected: " + strconv.FormatInt(*result.RowsAffected, 10))
 | 
						|
	}
 | 
						|
 | 
						|
	return sb.String()
 | 
						|
}
 | 
						|
 | 
						|
func convertDBValueToString(value interface{}, fieldType table.FieldItemType) string {
 | 
						|
	switch fieldType {
 | 
						|
	case table.FieldItemType_Text:
 | 
						|
		if byteArray, ok := value.([]uint8); ok {
 | 
						|
			return string(byteArray)
 | 
						|
		}
 | 
						|
 | 
						|
	case table.FieldItemType_Number:
 | 
						|
		switch v := value.(type) {
 | 
						|
		case int64:
 | 
						|
			return strconv.FormatInt(v, 10)
 | 
						|
		case []uint8:
 | 
						|
			return string(v)
 | 
						|
		}
 | 
						|
 | 
						|
	case table.FieldItemType_Float:
 | 
						|
		switch v := value.(type) {
 | 
						|
		case float64:
 | 
						|
			return strconv.FormatFloat(v, 'f', -1, 64)
 | 
						|
		case []uint8:
 | 
						|
			return string(v)
 | 
						|
		}
 | 
						|
 | 
						|
	case table.FieldItemType_Boolean:
 | 
						|
		switch v := value.(type) {
 | 
						|
		case bool:
 | 
						|
			return strconv.FormatBool(v)
 | 
						|
		case int64:
 | 
						|
			return strconv.FormatBool(v != 0)
 | 
						|
		case []uint8:
 | 
						|
			boolStr := string(v)
 | 
						|
			if boolStr == "1" || boolStr == "true" {
 | 
						|
				return "true"
 | 
						|
			}
 | 
						|
			return "false"
 | 
						|
		}
 | 
						|
 | 
						|
	case table.FieldItemType_Date:
 | 
						|
		switch v := value.(type) {
 | 
						|
		case time.Time:
 | 
						|
			return v.Format(TimeFormat)
 | 
						|
		case []uint8:
 | 
						|
			return string(v)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return fmt.Sprintf("%v", value)
 | 
						|
}
 |