443 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			443 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
/*
 | 
						|
 * Copyright 2025 coze-dev Authors
 | 
						|
 *
 | 
						|
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
 * you may not use this file except in compliance with the License.
 | 
						|
 * You may obtain a copy of the License at
 | 
						|
 *
 | 
						|
 *     http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
 *
 | 
						|
 * Unless required by applicable law or agreed to in writing, software
 | 
						|
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
 * See the License for the specific language governing permissions and
 | 
						|
 * limitations under the License.
 | 
						|
 */
 | 
						|
 | 
						|
package database
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"fmt"
 | 
						|
	"strconv"
 | 
						|
	"strings"
 | 
						|
 | 
						|
	"github.com/spf13/cast"
 | 
						|
 | 
						|
	"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/api/model/table"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
 | 
						|
	nodedatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
 | 
						|
)
 | 
						|
 | 
						|
type DatabaseRepository struct {
 | 
						|
	client service.Database
 | 
						|
}
 | 
						|
 | 
						|
func NewDatabaseRepository(client service.Database) *DatabaseRepository {
 | 
						|
	return &DatabaseRepository{
 | 
						|
		client: client,
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// Execute runs a caller-supplied custom SQL statement against the database
// identified by request.DatabaseInfoID.
//
// For debug runs the draft copy of the database is targeted: its ID is looked
// up via getDraftTableID and the draft table type is selected. Otherwise the
// online table is used. Every entry of request.Params is bound as a text
// parameter with its IsNull flag passed through. A nil RowsAffected in the
// service response is normalized to 0 before conversion.
func (d *DatabaseRepository) Execute(ctx context.Context, request *nodedatabase.CustomSQLRequest) (*nodedatabase.Response, error) {
	var (
		err            error
		databaseInfoID = request.DatabaseInfoID
		tableType      = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
	)

	if request.IsDebugRun {
		databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
		if err != nil {
			return nil, err
		}
	}

	req := &service.ExecuteSQLRequest{
		DatabaseID:  databaseInfoID,
		OperateType: database.OperateType_Custom,
		SQL:         &request.SQL,
		TableType:   tableType,
		UserID:      strconv.FormatInt(request.UserID, 10),
	}

	req.SQLParams = make([]*database.SQLParamVal, 0, len(request.Params))
	for i := range request.Params {
		// param is declared inside the loop, so each SQLParamVal gets a pointer
		// to its own element's Value.
		param := request.Params[i]
		req.SQLParams = append(req.SQLParams, &database.SQLParamVal{
			ValueType: table.FieldItemType_Text,
			Value:     &param.Value,
			ISNull:    param.IsNull,
		})
	}

	response, err := d.client.ExecuteSQL(ctx, req)
	if err != nil {
		return nil, err
	}

	// If rows affected is nil, use 0 instead.
	if response.RowsAffected == nil {
		response.RowsAffected = ptr.Of(int64(0))
	}
	return toNodeDateBaseResponse(response), nil
}
 | 
						|
 | 
						|
func (d *DatabaseRepository) Delete(ctx context.Context, request *nodedatabase.DeleteRequest) (*nodedatabase.Response, error) {
 | 
						|
	var (
 | 
						|
		err            error
 | 
						|
		databaseInfoID = request.DatabaseInfoID
 | 
						|
		tableType      = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
 | 
						|
	)
 | 
						|
 | 
						|
	if request.IsDebugRun {
 | 
						|
		databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	req := &service.ExecuteSQLRequest{
 | 
						|
		DatabaseID:  databaseInfoID,
 | 
						|
		OperateType: database.OperateType_Delete,
 | 
						|
		TableType:   tableType,
 | 
						|
		UserID:      strconv.FormatInt(request.UserID, 10),
 | 
						|
	}
 | 
						|
 | 
						|
	if request.ConditionGroup != nil {
 | 
						|
		req.Condition, req.SQLParams, err = buildComplexCondition(request.ConditionGroup)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	response, err := d.client.ExecuteSQL(ctx, req)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	return toNodeDateBaseResponse(response), nil
 | 
						|
}
 | 
						|
 | 
						|
func (d *DatabaseRepository) Query(ctx context.Context, request *nodedatabase.QueryRequest) (*nodedatabase.Response, error) {
 | 
						|
	var (
 | 
						|
		err            error
 | 
						|
		databaseInfoID = request.DatabaseInfoID
 | 
						|
		tableType      = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
 | 
						|
	)
 | 
						|
 | 
						|
	if request.IsDebugRun {
 | 
						|
		databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	req := &service.ExecuteSQLRequest{
 | 
						|
		DatabaseID:  databaseInfoID,
 | 
						|
		OperateType: database.OperateType_Select,
 | 
						|
		TableType:   tableType,
 | 
						|
		UserID:      strconv.FormatInt(request.UserID, 10),
 | 
						|
	}
 | 
						|
 | 
						|
	req.SelectFieldList = &database.SelectFieldList{FieldID: make([]string, 0, len(request.SelectFields))}
 | 
						|
	for i := range request.SelectFields {
 | 
						|
		req.SelectFieldList.FieldID = append(req.SelectFieldList.FieldID, request.SelectFields[i])
 | 
						|
	}
 | 
						|
 | 
						|
	req.OrderByList = make([]database.OrderBy, 0)
 | 
						|
	for i := range request.OrderClauses {
 | 
						|
		clause := request.OrderClauses[i]
 | 
						|
		req.OrderByList = append(req.OrderByList, database.OrderBy{
 | 
						|
			Field:     clause.FieldID,
 | 
						|
			Direction: toOrderDirection(clause.IsAsc),
 | 
						|
		})
 | 
						|
	}
 | 
						|
 | 
						|
	if request.ConditionGroup != nil {
 | 
						|
		req.Condition, req.SQLParams, err = buildComplexCondition(request.ConditionGroup)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	limit := request.Limit
 | 
						|
	req.Limit = &limit
 | 
						|
 | 
						|
	response, err := d.client.ExecuteSQL(ctx, req)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	return toNodeDateBaseResponse(response), nil
 | 
						|
}
 | 
						|
 | 
						|
func (d *DatabaseRepository) Update(ctx context.Context, request *nodedatabase.UpdateRequest) (*nodedatabase.Response, error) {
 | 
						|
 | 
						|
	var (
 | 
						|
		err            error
 | 
						|
		condition      *database.ComplexCondition
 | 
						|
		params         []*database.SQLParamVal
 | 
						|
		databaseInfoID = request.DatabaseInfoID
 | 
						|
		tableType      = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
 | 
						|
	)
 | 
						|
 | 
						|
	if request.IsDebugRun {
 | 
						|
		databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	req := &service.ExecuteSQLRequest{
 | 
						|
		DatabaseID:  databaseInfoID,
 | 
						|
		OperateType: database.OperateType_Update,
 | 
						|
		SQLParams:   make([]*database.SQLParamVal, 0),
 | 
						|
		TableType:   tableType,
 | 
						|
	}
 | 
						|
 | 
						|
	uid := ctxutil.GetUIDFromCtx(ctx)
 | 
						|
	if uid != nil {
 | 
						|
		req.UserID = conv.Int64ToStr(*uid)
 | 
						|
	}
 | 
						|
	req.UpsertRows, req.SQLParams, err = resolveUpsertRow(request.Fields)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	if request.ConditionGroup != nil {
 | 
						|
		condition, params, err = buildComplexCondition(request.ConditionGroup)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
 | 
						|
		req.Condition = condition
 | 
						|
		req.SQLParams = append(req.SQLParams, params...)
 | 
						|
	}
 | 
						|
 | 
						|
	response, err := d.client.ExecuteSQL(ctx, req)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	return toNodeDateBaseResponse(response), nil
 | 
						|
}
 | 
						|
 | 
						|
func (d *DatabaseRepository) Insert(ctx context.Context, request *nodedatabase.InsertRequest) (*nodedatabase.Response, error) {
 | 
						|
	var (
 | 
						|
		err            error
 | 
						|
		databaseInfoID = request.DatabaseInfoID
 | 
						|
		tableType      = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
 | 
						|
	)
 | 
						|
 | 
						|
	if request.IsDebugRun {
 | 
						|
		databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	req := &service.ExecuteSQLRequest{
 | 
						|
		DatabaseID:  databaseInfoID,
 | 
						|
		OperateType: database.OperateType_Insert,
 | 
						|
		TableType:   tableType,
 | 
						|
		UserID:      strconv.FormatInt(request.UserID, 10),
 | 
						|
	}
 | 
						|
 | 
						|
	req.UpsertRows, req.SQLParams, err = resolveUpsertRow(request.Fields)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	response, err := d.client.ExecuteSQL(ctx, req)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	return toNodeDateBaseResponse(response), nil
 | 
						|
}
 | 
						|
 | 
						|
// getDraftTableID resolves the ID of the draft database that corresponds to
// the online database onlineID.
//
// It returns an error, instead of panicking, when the service reports no error
// but also returns no draft database.
func (d *DatabaseRepository) getDraftTableID(ctx context.Context, onlineID int64) (int64, error) {
	resp, err := d.client.GetDraftDatabaseByOnlineID(ctx, &service.GetDraftDatabaseByOnlineIDRequest{OnlineID: onlineID})
	if err != nil {
		return 0, err
	}
	if resp == nil || resp.Database == nil {
		return 0, fmt.Errorf("draft database not found for online database %d", onlineID)
	}
	return resp.Database.ID, nil
}
 | 
						|
 | 
						|
func buildComplexCondition(conditionGroup *nodedatabase.ConditionGroup) (*database.ComplexCondition, []*database.SQLParamVal, error) {
 | 
						|
	condition := &database.ComplexCondition{}
 | 
						|
	logic, err := toLogic(conditionGroup.Relation)
 | 
						|
	if err != nil {
 | 
						|
		return nil, nil, err
 | 
						|
	}
 | 
						|
	condition.Logic = logic
 | 
						|
 | 
						|
	params := make([]*database.SQLParamVal, 0)
 | 
						|
	for i := range conditionGroup.Conditions {
 | 
						|
		var (
 | 
						|
			nCond = conditionGroup.Conditions[i]
 | 
						|
			vals  []*database.SQLParamVal
 | 
						|
			dCond = &database.Condition{
 | 
						|
				Left: nCond.Left,
 | 
						|
			}
 | 
						|
		)
 | 
						|
		opt, err := toOperation(nCond.Operator)
 | 
						|
		if err != nil {
 | 
						|
			return nil, nil, err
 | 
						|
		}
 | 
						|
		dCond.Operation = opt
 | 
						|
 | 
						|
		if isNullOrNotNull(opt) {
 | 
						|
			condition.Conditions = append(condition.Conditions, dCond)
 | 
						|
			continue
 | 
						|
		}
 | 
						|
		dCond.Right, vals, err = resolveRightValue(opt, nCond.Right)
 | 
						|
		if err != nil {
 | 
						|
			return nil, nil, err
 | 
						|
		}
 | 
						|
		condition.Conditions = append(condition.Conditions, dCond)
 | 
						|
 | 
						|
		params = append(params, vals...)
 | 
						|
 | 
						|
	}
 | 
						|
	return condition, params, nil
 | 
						|
}
 | 
						|
 | 
						|
// toMapStringAny copies m into a new map[string]any with the same keys and
// values. The result is always non-nil, even when m is nil.
func toMapStringAny(m map[string]string) map[string]any {
	out := make(map[string]any, len(m))
	for key := range m {
		out[key] = m[key]
	}
	return out
}
 | 
						|
 | 
						|
func toOperation(operator nodedatabase.Operator) (database.Operation, error) {
 | 
						|
	switch operator {
 | 
						|
	case nodedatabase.OperatorEqual:
 | 
						|
		return database.Operation_EQUAL, nil
 | 
						|
	case nodedatabase.OperatorNotEqual:
 | 
						|
		return database.Operation_NOT_EQUAL, nil
 | 
						|
	case nodedatabase.OperatorGreater:
 | 
						|
		return database.Operation_GREATER_THAN, nil
 | 
						|
	case nodedatabase.OperatorGreaterOrEqual:
 | 
						|
		return database.Operation_GREATER_EQUAL, nil
 | 
						|
	case nodedatabase.OperatorLesser:
 | 
						|
		return database.Operation_LESS_THAN, nil
 | 
						|
	case nodedatabase.OperatorLesserOrEqual:
 | 
						|
		return database.Operation_LESS_EQUAL, nil
 | 
						|
	case nodedatabase.OperatorIn:
 | 
						|
		return database.Operation_IN, nil
 | 
						|
	case nodedatabase.OperatorNotIn:
 | 
						|
		return database.Operation_NOT_IN, nil
 | 
						|
	case nodedatabase.OperatorIsNotNull:
 | 
						|
		return database.Operation_IS_NOT_NULL, nil
 | 
						|
	case nodedatabase.OperatorIsNull:
 | 
						|
		return database.Operation_IS_NULL, nil
 | 
						|
	case nodedatabase.OperatorLike:
 | 
						|
		return database.Operation_LIKE, nil
 | 
						|
	case nodedatabase.OperatorNotLike:
 | 
						|
		return database.Operation_NOT_LIKE, nil
 | 
						|
	default:
 | 
						|
		return database.Operation(0), fmt.Errorf("invalid operator %v", operator)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func resolveRightValue(operator database.Operation, right any) (string, []*database.SQLParamVal, error) {
 | 
						|
 | 
						|
	if isInOrNotIn(operator) {
 | 
						|
		var (
 | 
						|
			vals    = make([]*database.SQLParamVal, 0)
 | 
						|
			anyVals = make([]any, 0)
 | 
						|
			commas  = make([]string, 0, len(anyVals))
 | 
						|
		)
 | 
						|
 | 
						|
		anyVals = right.([]any)
 | 
						|
		for i := range anyVals {
 | 
						|
			v := cast.ToString(anyVals[i])
 | 
						|
			vals = append(vals, &database.SQLParamVal{ValueType: table.FieldItemType_Text, Value: &v})
 | 
						|
			commas = append(commas, "?")
 | 
						|
		}
 | 
						|
		value := "(" + strings.Join(commas, ",") + ")"
 | 
						|
		return value, vals, nil
 | 
						|
	}
 | 
						|
 | 
						|
	rightValue, err := cast.ToStringE(right)
 | 
						|
	if err != nil {
 | 
						|
		return "", nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	if isLikeOrNotLike(operator) {
 | 
						|
		var (
 | 
						|
			value = "?"
 | 
						|
			v     = "%s" + rightValue + "%s"
 | 
						|
		)
 | 
						|
		return value, []*database.SQLParamVal{{ValueType: table.FieldItemType_Text, Value: &v}}, nil
 | 
						|
	}
 | 
						|
 | 
						|
	return "?", []*database.SQLParamVal{{ValueType: table.FieldItemType_Text, Value: &rightValue}}, nil
 | 
						|
}
 | 
						|
 | 
						|
func resolveUpsertRow(fields map[string]any) ([]*database.UpsertRow, []*database.SQLParamVal, error) {
 | 
						|
	upsertRow := &database.UpsertRow{Records: make([]*database.Record, 0, len(fields))}
 | 
						|
	params := make([]*database.SQLParamVal, 0)
 | 
						|
	for key, value := range fields {
 | 
						|
		val, err := cast.ToStringE(value)
 | 
						|
		if err != nil {
 | 
						|
			return nil, nil, err
 | 
						|
		}
 | 
						|
		record := &database.Record{
 | 
						|
			FieldId:    key,
 | 
						|
			FieldValue: "?",
 | 
						|
		}
 | 
						|
		upsertRow.Records = append(upsertRow.Records, record)
 | 
						|
		params = append(params, &database.SQLParamVal{
 | 
						|
			ValueType: table.FieldItemType_Text,
 | 
						|
			Value:     &val,
 | 
						|
		})
 | 
						|
	}
 | 
						|
	return []*database.UpsertRow{upsertRow}, params, nil
 | 
						|
}
 | 
						|
 | 
						|
func isNullOrNotNull(opt database.Operation) bool {
 | 
						|
	return opt == database.Operation_IS_NOT_NULL || opt == database.Operation_IS_NULL
 | 
						|
}
 | 
						|
 | 
						|
func isLikeOrNotLike(opt database.Operation) bool {
 | 
						|
	return opt == database.Operation_LIKE || opt == database.Operation_NOT_LIKE
 | 
						|
}
 | 
						|
 | 
						|
func isInOrNotIn(opt database.Operation) bool {
 | 
						|
	return opt == database.Operation_IN || opt == database.Operation_NOT_IN
 | 
						|
}
 | 
						|
 | 
						|
func toOrderDirection(isAsc bool) table.SortDirection {
 | 
						|
	if isAsc {
 | 
						|
		return table.SortDirection_ASC
 | 
						|
	}
 | 
						|
	return table.SortDirection_Desc
 | 
						|
}
 | 
						|
 | 
						|
// toLogic maps a workflow clause relation (AND / OR) to the service's logic
// value. Any other relation yields Logic(0) and an error.
func toLogic(relation nodedatabase.ClauseRelation) (database.Logic, error) {
	if relation == nodedatabase.ClauseRelationAND {
		return database.Logic_And, nil
	}
	if relation == nodedatabase.ClauseRelationOR {
		return database.Logic_Or, nil
	}
	return database.Logic(0), fmt.Errorf("invalid relation %v", relation)
}
 | 
						|
 | 
						|
func toNodeDateBaseResponse(response *service.ExecuteSQLResponse) *nodedatabase.Response {
 | 
						|
	objects := make([]nodedatabase.Object, 0, len(response.Records))
 | 
						|
	for i := range response.Records {
 | 
						|
		objects = append(objects, response.Records[i])
 | 
						|
	}
 | 
						|
	return &nodedatabase.Response{
 | 
						|
		Objects:   objects,
 | 
						|
		RowNumber: response.RowsAffected,
 | 
						|
	}
 | 
						|
}
 |