feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
@@ -0,0 +1,302 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package agentflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/components/tool/utils"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/table"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatabase"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/sqlparser"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
)
|
||||
|
||||
const (
|
||||
TimeFormat = "2006-01-02 15:04:05"
|
||||
)
|
||||
|
||||
type databaseConfig struct {
|
||||
agentIdentity *entity.AgentIdentity
|
||||
userID string
|
||||
spaceID int64
|
||||
|
||||
databaseConf []*bot_common.Database
|
||||
}
|
||||
|
||||
type databaseTool struct {
|
||||
agentIdentity *entity.AgentIdentity
|
||||
connectorUID string
|
||||
spaceID int64
|
||||
|
||||
databaseID int64
|
||||
|
||||
name string
|
||||
promptDisabled bool
|
||||
}
|
||||
|
||||
type ExecuteSQLRequest struct {
|
||||
SQL string `json:"sql" jsonschema:"description=SQL query to execute against the database. You can use standard SQL syntax like SELECT, INSERT, UPDATE, DELETE."`
|
||||
}
|
||||
|
||||
func (d *databaseTool) Invoke(ctx context.Context, req ExecuteSQLRequest) (string, error) {
|
||||
if req.SQL == "" {
|
||||
return "", fmt.Errorf("sql is empty")
|
||||
}
|
||||
if d.promptDisabled {
|
||||
return "the tool to be called is not available", nil
|
||||
}
|
||||
|
||||
tableType := table.TableType_OnlineTable
|
||||
if d.agentIdentity.IsDraft {
|
||||
tableType = table.TableType_DraftTable
|
||||
}
|
||||
|
||||
tableName, err := sqlparser.NewSQLParser().GetTableName(req.SQL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if tableName != d.name {
|
||||
return "", fmt.Errorf("sql table name %s not match database %s", tableName, d.name)
|
||||
}
|
||||
|
||||
eReq := &service.ExecuteSQLRequest{
|
||||
SQL: &req.SQL,
|
||||
DatabaseID: d.databaseID,
|
||||
SQLType: database.SQLType_Raw,
|
||||
UserID: d.connectorUID,
|
||||
SpaceID: d.spaceID,
|
||||
ConnectorID: ptr.Of(d.agentIdentity.ConnectorID),
|
||||
TableType: tableType,
|
||||
}
|
||||
|
||||
sqlResult, err := crossdatabase.DefaultSVC().ExecuteSQL(ctx, eReq)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return formatDatabaseResult(sqlResult), nil
|
||||
}
|
||||
|
||||
func newDatabaseTools(ctx context.Context, conf *databaseConfig) ([]tool.InvokableTool, error) {
|
||||
if conf == nil || len(conf.databaseConf) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
dbInfos := conf.databaseConf
|
||||
|
||||
d := &databaseTool{
|
||||
spaceID: conf.spaceID,
|
||||
connectorUID: conf.userID,
|
||||
agentIdentity: conf.agentIdentity,
|
||||
}
|
||||
|
||||
tools := make([]tool.InvokableTool, 0, len(dbInfos))
|
||||
for _, dbInfo := range dbInfos {
|
||||
tID, err := strconv.ParseInt(dbInfo.GetTableId(), 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d.databaseID = tID
|
||||
d.promptDisabled = dbInfo.GetPromptDisabled()
|
||||
d.name = dbInfo.GetTableName()
|
||||
dbTool, err := utils.InferTool(dbInfo.GetTableName(), buildDatabaseToolDescription(dbInfo), d.Invoke)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tools = append(tools, dbTool)
|
||||
}
|
||||
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
func buildDatabaseToolDescription(tableInfo *bot_common.Database) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(fmt.Sprintf("Mysql query tool. Table name is '%s'.", tableInfo.GetTableName()))
|
||||
if tableInfo.GetTableDesc() != "" {
|
||||
sb.WriteString(fmt.Sprintf(" This table's desc is %s.", tableInfo.GetTableDesc()))
|
||||
}
|
||||
sb.WriteString("\n\nTable structure:\n")
|
||||
|
||||
for _, field := range tableInfo.FieldList {
|
||||
if field.Name == nil || field.Type == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
fieldType := getFieldTypeString(*field.Type)
|
||||
sb.WriteString(fmt.Sprintf("- %s (%s)", *field.Name, fieldType))
|
||||
|
||||
if field.Desc != nil && *field.Desc != "" {
|
||||
sb.WriteString(fmt.Sprintf(": %s", *field.Desc))
|
||||
}
|
||||
|
||||
if field.MustRequired != nil && *field.MustRequired {
|
||||
sb.WriteString(" (required)")
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
sb.WriteString("\nUse SQL to query this table. You can write SQL statements directly to operate.")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func getFieldTypeString(fieldType bot_common.FieldItemType) string {
|
||||
switch fieldType {
|
||||
case bot_common.FieldItemType_Text:
|
||||
return "text"
|
||||
case bot_common.FieldItemType_Number:
|
||||
return "number"
|
||||
case bot_common.FieldItemType_Date:
|
||||
return "date"
|
||||
case bot_common.FieldItemType_Float:
|
||||
return "float"
|
||||
case bot_common.FieldItemType_Boolean:
|
||||
return "bool"
|
||||
default:
|
||||
return "invalid"
|
||||
}
|
||||
}
|
||||
|
||||
func formatDatabaseResult(result *service.ExecuteSQLResponse) string {
|
||||
var sb strings.Builder
|
||||
|
||||
if len(result.Records) == 0 {
|
||||
if result.RowsAffected == nil {
|
||||
return "result is empty"
|
||||
} else {
|
||||
sb.WriteString("Rows affected: " + strconv.FormatInt(*result.RowsAffected, 10))
|
||||
return sb.String()
|
||||
}
|
||||
}
|
||||
|
||||
var headers []string
|
||||
if len(result.Records) > 0 {
|
||||
firstRecord := result.Records[0]
|
||||
for key := range firstRecord {
|
||||
headers = append(headers, key)
|
||||
}
|
||||
sort.Strings(headers)
|
||||
}
|
||||
|
||||
if len(headers) == 0 {
|
||||
return "no fields found in result"
|
||||
}
|
||||
|
||||
sb.WriteString("| ")
|
||||
for _, header := range headers {
|
||||
sb.WriteString(header + " | ")
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
|
||||
sb.WriteString("| ")
|
||||
for range headers {
|
||||
sb.WriteString("--- | ")
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
|
||||
fieldMap := slices.ToMap(result.FieldList, func(f *database.FieldItem) (string, *database.FieldItem) {
|
||||
return f.Name, f
|
||||
})
|
||||
for _, record := range result.Records {
|
||||
sb.WriteString("| ")
|
||||
for _, header := range headers {
|
||||
value, exists := record[header]
|
||||
if !exists {
|
||||
value = ""
|
||||
}
|
||||
fieldType := table.FieldItemType_Text
|
||||
fValue, fExists := fieldMap[header]
|
||||
if fExists {
|
||||
fieldType = fValue.Type
|
||||
}
|
||||
sb.WriteString(convertDBValueToString(value, fieldType) + " | ")
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
if result.RowsAffected != nil {
|
||||
sb.WriteString("\nRows affected: " + strconv.FormatInt(*result.RowsAffected, 10))
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func convertDBValueToString(value interface{}, fieldType table.FieldItemType) string {
|
||||
switch fieldType {
|
||||
case table.FieldItemType_Text:
|
||||
if byteArray, ok := value.([]uint8); ok {
|
||||
return string(byteArray)
|
||||
}
|
||||
|
||||
case table.FieldItemType_Number:
|
||||
switch v := value.(type) {
|
||||
case int64:
|
||||
return strconv.FormatInt(v, 10)
|
||||
case []uint8:
|
||||
return string(v)
|
||||
}
|
||||
|
||||
case table.FieldItemType_Float:
|
||||
switch v := value.(type) {
|
||||
case float64:
|
||||
return strconv.FormatFloat(v, 'f', -1, 64)
|
||||
case []uint8:
|
||||
return string(v)
|
||||
}
|
||||
|
||||
case table.FieldItemType_Boolean:
|
||||
switch v := value.(type) {
|
||||
case bool:
|
||||
return strconv.FormatBool(v)
|
||||
case int64:
|
||||
return strconv.FormatBool(v != 0)
|
||||
case []uint8:
|
||||
boolStr := string(v)
|
||||
if boolStr == "1" || boolStr == "true" {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
}
|
||||
|
||||
case table.FieldItemType_Date:
|
||||
switch v := value.(type) {
|
||||
case time.Time:
|
||||
return v.Format(TimeFormat)
|
||||
case []uint8:
|
||||
return string(v)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%v", value)
|
||||
}
|
||||
Reference in New Issue
Block a user