refactor(workflow): Move the database component in the Workflow package into the common crossdomain package (#704)

This commit is contained in:
Ryo 2025-08-12 15:42:58 +08:00 committed by GitHub
parent e7011f2549
commit 9ff065cebd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
29 changed files with 834 additions and 1353 deletions

View File

@ -59,7 +59,10 @@ import (
appplugin "github.com/coze-dev/coze-studio/backend/application/plugin" appplugin "github.com/coze-dev/coze-studio/backend/application/plugin"
"github.com/coze-dev/coze-studio/backend/application/user" "github.com/coze-dev/coze-studio/backend/application/user"
appworkflow "github.com/coze-dev/coze-studio/backend/application/workflow" appworkflow "github.com/coze-dev/coze-studio/backend/application/workflow"
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/database/databasemock"
crossuser "github.com/coze-dev/coze-studio/backend/crossdomain/contract/user" crossuser "github.com/coze-dev/coze-studio/backend/crossdomain/contract/user"
"github.com/coze-dev/coze-studio/backend/crossdomain/impl/code"
plugin3 "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/plugin" plugin3 "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/plugin"
entity4 "github.com/coze-dev/coze-studio/backend/domain/memory/database/entity" entity4 "github.com/coze-dev/coze-studio/backend/domain/memory/database/entity"
entity2 "github.com/coze-dev/coze-studio/backend/domain/openauth/openapiauth/entity" entity2 "github.com/coze-dev/coze-studio/backend/domain/openauth/openapiauth/entity"
@ -67,9 +70,6 @@ import (
entity5 "github.com/coze-dev/coze-studio/backend/domain/plugin/entity" entity5 "github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
userentity "github.com/coze-dev/coze-studio/backend/domain/user/entity" userentity "github.com/coze-dev/coze-studio/backend/domain/user/entity"
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow" workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge/knowledgemock" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge/knowledgemock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
@ -122,7 +122,7 @@ type wfTestRunner struct {
plugin *mockPlugin.MockPluginService plugin *mockPlugin.MockPluginService
tos *storageMock.MockStorage tos *storageMock.MockStorage
knowledge *knowledgemock.MockKnowledgeOperator knowledge *knowledgemock.MockKnowledgeOperator
database *databasemock.MockDatabaseOperator database *databasemock.MockDatabase
pluginSrv *pluginmock.MockService pluginSrv *pluginmock.MockService
internalModel *testutil.UTChatModel internalModel *testutil.UTChatModel
ctx context.Context ctx context.Context
@ -291,8 +291,8 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
m4 := mockey.Mock(ctxutil.MustGetUIDFromCtx).Return(int64(1)).Build() m4 := mockey.Mock(ctxutil.MustGetUIDFromCtx).Return(int64(1)).Build()
m5 := mockey.Mock(ctxutil.GetUIDFromCtx).Return(ptr.Of(int64(1))).Build() m5 := mockey.Mock(ctxutil.GetUIDFromCtx).Return(ptr.Of(int64(1))).Build()
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator := databasemock.NewMockDatabase(ctrl)
database.SetDatabaseOperator(mockDatabaseOperator) crossdatabase.SetDefaultSVC(mockDatabaseOperator)
mockPluginSrv := pluginmock.NewMockService(ctrl) mockPluginSrv := pluginmock.NewMockService(ctrl)
plugin.SetPluginService(mockPluginSrv) plugin.SetPluginService(mockPluginSrv)

View File

@ -204,3 +204,110 @@ type GetAllDatabaseByAppIDRequest struct {
type GetAllDatabaseByAppIDResponse struct { type GetAllDatabaseByAppIDResponse struct {
Databases []*Database // online databases Databases []*Database // online databases
} }
type SQLParam struct {
Value string
IsNull bool
}
type CustomSQLRequest struct {
DatabaseInfoID int64
SQL string
Params []SQLParam
IsDebugRun bool
UserID string
ConnectorID int64
}
type Object = map[string]any
type Response struct {
RowNumber *int64
Objects []Object
}
type Operator string
type ClauseRelation string
const (
ClauseRelationAND ClauseRelation = "and"
ClauseRelationOR ClauseRelation = "or"
)
const (
OperatorEqual Operator = "="
OperatorNotEqual Operator = "!="
OperatorGreater Operator = ">"
OperatorLesser Operator = "<"
OperatorGreaterOrEqual Operator = ">="
OperatorLesserOrEqual Operator = "<="
OperatorIn Operator = "in"
OperatorNotIn Operator = "not_in"
OperatorIsNull Operator = "is_null"
OperatorIsNotNull Operator = "is_not_null"
OperatorLike Operator = "like"
OperatorNotLike Operator = "not_like"
)
type ClauseGroup struct {
Single *Clause
Multi *MultiClause
}
type Clause struct {
Left string
Operator Operator
}
type MultiClause struct {
Clauses []*Clause
Relation ClauseRelation
}
type ConditionStr struct {
Left string
Operator Operator
Right any
}
type ConditionGroup struct {
Conditions []*ConditionStr
Relation ClauseRelation
}
type DeleteRequest struct {
DatabaseInfoID int64
ConditionGroup *ConditionGroup
IsDebugRun bool
UserID string
ConnectorID int64
}
type QueryRequest struct {
DatabaseInfoID int64
SelectFields []string
Limit int64
ConditionGroup *ConditionGroup
OrderClauses []*OrderClause
IsDebugRun bool
UserID string
ConnectorID int64
}
type OrderClause struct {
FieldID string
IsAsc bool
}
type UpdateRequest struct {
DatabaseInfoID int64
ConditionGroup *ConditionGroup
Fields map[string]any
IsDebugRun bool
UserID string
ConnectorID int64
}
type InsertRequest struct {
DatabaseInfoID int64
Fields map[string]any
IsDebugRun bool
UserID string
ConnectorID int64
}

View File

@ -126,7 +126,6 @@ func Init(ctx context.Context) (err error) {
if err != nil { if err != nil {
return fmt.Errorf("Init - initVitalServices failed, err: %v", err) return fmt.Errorf("Init - initVitalServices failed, err: %v", err)
} }
crossconnector.SetDefaultSVC(connectorImpl.InitDomainService(basicServices.connectorSVC.DomainSVC)) crossconnector.SetDefaultSVC(connectorImpl.InitDomainService(basicServices.connectorSVC.DomainSVC))
crossdatabase.SetDefaultSVC(databaseImpl.InitDomainService(primaryServices.memorySVC.DatabaseDomainSVC)) crossdatabase.SetDefaultSVC(databaseImpl.InitDomainService(primaryServices.memorySVC.DatabaseDomainSVC))
crossknowledge.SetDefaultSVC(knowledgeImpl.InitDomainService(primaryServices.knowledgeSVC.DomainSVC)) crossknowledge.SetDefaultSVC(knowledgeImpl.InitDomainService(primaryServices.knowledgeSVC.DomainSVC))

View File

@ -19,14 +19,12 @@ package workflow
import ( import (
"context" "context"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/compose"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/cloudwego/eino/callbacks"
"github.com/coze-dev/coze-studio/backend/application/internal" "github.com/coze-dev/coze-studio/backend/application/internal"
"github.com/coze-dev/coze-studio/backend/crossdomain/impl/code"
wfdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/database"
wfknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/knowledge" wfknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/knowledge"
wfmodel "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/model" wfmodel "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/model"
wfplugin "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/plugin" wfplugin "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/plugin"
@ -38,8 +36,6 @@ import (
plugin "github.com/coze-dev/coze-studio/backend/domain/plugin/service" plugin "github.com/coze-dev/coze-studio/backend/domain/plugin/service"
search "github.com/coze-dev/coze-studio/backend/domain/search/service" search "github.com/coze-dev/coze-studio/backend/domain/search/service"
"github.com/coze-dev/coze-studio/backend/domain/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow"
crosscode "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
crossdatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge" crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model" crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
crossplugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin" crossplugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
@ -88,13 +84,12 @@ func InitService(ctx context.Context, components *ServiceComponents) (*Applicati
workflow.SetRepository(workflowRepo) workflow.SetRepository(workflowRepo)
workflowDomainSVC := service.NewWorkflowService(workflowRepo) workflowDomainSVC := service.NewWorkflowService(workflowRepo)
crossdatabase.SetDatabaseOperator(wfdatabase.NewDatabaseRepository(components.DatabaseDomainSVC))
crossvariable.SetVariableHandler(variable.NewVariableHandler(components.VariablesDomainSVC)) crossvariable.SetVariableHandler(variable.NewVariableHandler(components.VariablesDomainSVC))
crossvariable.SetVariablesMetaGetter(variable.NewVariablesMetaGetter(components.VariablesDomainSVC)) crossvariable.SetVariablesMetaGetter(variable.NewVariablesMetaGetter(components.VariablesDomainSVC))
crossplugin.SetPluginService(wfplugin.NewPluginService(components.PluginDomainSVC, components.Tos)) crossplugin.SetPluginService(wfplugin.NewPluginService(components.PluginDomainSVC, components.Tos))
crossknowledge.SetKnowledgeOperator(wfknowledge.NewKnowledgeRepository(components.KnowledgeDomainSVC, components.IDGen)) crossknowledge.SetKnowledgeOperator(wfknowledge.NewKnowledgeRepository(components.KnowledgeDomainSVC, components.IDGen))
crossmodel.SetManager(wfmodel.NewModelManager(components.ModelManager, nil)) crossmodel.SetManager(wfmodel.NewModelManager(components.ModelManager, nil))
crosscode.SetCodeRunner(components.CodeRunner) code.SetCodeRunner(components.CodeRunner)
crosssearch.SetNotifier(wfsearch.NewNotify(components.DomainNotifier)) crosssearch.SetNotifier(wfsearch.NewNotify(components.DomainNotifier))
callbacks.AppendGlobalHandlers(workflowservice.GetTokenCallbackHandler()) callbacks.AppendGlobalHandlers(workflowservice.GetTokenCallbackHandler())

View File

@ -25,9 +25,8 @@ import (
"strings" "strings"
"time" "time"
xmaps "golang.org/x/exp/maps"
"github.com/cloudwego/eino/schema" "github.com/cloudwego/eino/schema"
xmaps "golang.org/x/exp/maps"
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common" "github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge" model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"

View File

@ -22,6 +22,7 @@ import (
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
) )
//go:generate mockgen -destination databasemock/database_mock.go --package databasemock -source database.go
type Database interface { type Database interface {
ExecuteSQL(ctx context.Context, req *database.ExecuteSQLRequest) (*database.ExecuteSQLResponse, error) ExecuteSQL(ctx context.Context, req *database.ExecuteSQLRequest) (*database.ExecuteSQLResponse, error)
PublishDatabase(ctx context.Context, req *database.PublishDatabaseRequest) (resp *database.PublishDatabaseResponse, err error) PublishDatabase(ctx context.Context, req *database.PublishDatabaseRequest) (resp *database.PublishDatabaseResponse, err error)
@ -30,6 +31,12 @@ type Database interface {
UnBindDatabase(ctx context.Context, req *database.UnBindDatabaseToAgentRequest) error UnBindDatabase(ctx context.Context, req *database.UnBindDatabaseToAgentRequest) error
MGetDatabase(ctx context.Context, req *database.MGetDatabaseRequest) (*database.MGetDatabaseResponse, error) MGetDatabase(ctx context.Context, req *database.MGetDatabaseRequest) (*database.MGetDatabaseResponse, error)
GetAllDatabaseByAppID(ctx context.Context, req *database.GetAllDatabaseByAppIDRequest) (*database.GetAllDatabaseByAppIDResponse, error) GetAllDatabaseByAppID(ctx context.Context, req *database.GetAllDatabaseByAppIDRequest) (*database.GetAllDatabaseByAppIDResponse, error)
Execute(ctx context.Context, request *database.CustomSQLRequest) (*database.Response, error)
Query(ctx context.Context, request *database.QueryRequest) (*database.Response, error)
Update(context.Context, *database.UpdateRequest) (*database.Response, error)
Insert(ctx context.Context, request *database.InsertRequest) (*database.Response, error)
Delete(context.Context, *database.DeleteRequest) (*database.Response, error)
} }
var defaultSVC Database var defaultSVC Database

View File

@ -0,0 +1,219 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: database.go
//
// Generated by this command:
//
// mockgen -destination databasemock/database_mock.go --package databasemock -source database.go
//
// Package databasemock is a generated GoMock package.
package databasemock
import (
context "context"
reflect "reflect"
database "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
gomock "go.uber.org/mock/gomock"
)
// MockDatabase is a mock of Database interface.
type MockDatabase struct {
ctrl *gomock.Controller
recorder *MockDatabaseMockRecorder
isgomock struct{}
}
// MockDatabaseMockRecorder is the mock recorder for MockDatabase.
type MockDatabaseMockRecorder struct {
mock *MockDatabase
}
// NewMockDatabase creates a new mock instance.
func NewMockDatabase(ctrl *gomock.Controller) *MockDatabase {
mock := &MockDatabase{ctrl: ctrl}
mock.recorder = &MockDatabaseMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockDatabase) EXPECT() *MockDatabaseMockRecorder {
return m.recorder
}
// BindDatabase mocks base method.
func (m *MockDatabase) BindDatabase(ctx context.Context, req *database.BindDatabaseToAgentRequest) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BindDatabase", ctx, req)
ret0, _ := ret[0].(error)
return ret0
}
// BindDatabase indicates an expected call of BindDatabase.
func (mr *MockDatabaseMockRecorder) BindDatabase(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BindDatabase", reflect.TypeOf((*MockDatabase)(nil).BindDatabase), ctx, req)
}
// Delete mocks base method.
func (m *MockDatabase) Delete(arg0 context.Context, arg1 *database.DeleteRequest) (*database.Response, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", arg0, arg1)
ret0, _ := ret[0].(*database.Response)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Delete indicates an expected call of Delete.
func (mr *MockDatabaseMockRecorder) Delete(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockDatabase)(nil).Delete), arg0, arg1)
}
// DeleteDatabase mocks base method.
func (m *MockDatabase) DeleteDatabase(ctx context.Context, req *database.DeleteDatabaseRequest) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteDatabase", ctx, req)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteDatabase indicates an expected call of DeleteDatabase.
func (mr *MockDatabaseMockRecorder) DeleteDatabase(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteDatabase", reflect.TypeOf((*MockDatabase)(nil).DeleteDatabase), ctx, req)
}
// Execute mocks base method.
func (m *MockDatabase) Execute(ctx context.Context, request *database.CustomSQLRequest) (*database.Response, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Execute", ctx, request)
ret0, _ := ret[0].(*database.Response)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Execute indicates an expected call of Execute.
func (mr *MockDatabaseMockRecorder) Execute(ctx, request any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Execute", reflect.TypeOf((*MockDatabase)(nil).Execute), ctx, request)
}
// ExecuteSQL mocks base method.
func (m *MockDatabase) ExecuteSQL(ctx context.Context, req *database.ExecuteSQLRequest) (*database.ExecuteSQLResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ExecuteSQL", ctx, req)
ret0, _ := ret[0].(*database.ExecuteSQLResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ExecuteSQL indicates an expected call of ExecuteSQL.
func (mr *MockDatabaseMockRecorder) ExecuteSQL(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecuteSQL", reflect.TypeOf((*MockDatabase)(nil).ExecuteSQL), ctx, req)
}
// GetAllDatabaseByAppID mocks base method.
func (m *MockDatabase) GetAllDatabaseByAppID(ctx context.Context, req *database.GetAllDatabaseByAppIDRequest) (*database.GetAllDatabaseByAppIDResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAllDatabaseByAppID", ctx, req)
ret0, _ := ret[0].(*database.GetAllDatabaseByAppIDResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAllDatabaseByAppID indicates an expected call of GetAllDatabaseByAppID.
func (mr *MockDatabaseMockRecorder) GetAllDatabaseByAppID(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllDatabaseByAppID", reflect.TypeOf((*MockDatabase)(nil).GetAllDatabaseByAppID), ctx, req)
}
// Insert mocks base method.
func (m *MockDatabase) Insert(ctx context.Context, request *database.InsertRequest) (*database.Response, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Insert", ctx, request)
ret0, _ := ret[0].(*database.Response)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Insert indicates an expected call of Insert.
func (mr *MockDatabaseMockRecorder) Insert(ctx, request any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockDatabase)(nil).Insert), ctx, request)
}
// MGetDatabase mocks base method.
func (m *MockDatabase) MGetDatabase(ctx context.Context, req *database.MGetDatabaseRequest) (*database.MGetDatabaseResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MGetDatabase", ctx, req)
ret0, _ := ret[0].(*database.MGetDatabaseResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// MGetDatabase indicates an expected call of MGetDatabase.
func (mr *MockDatabaseMockRecorder) MGetDatabase(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetDatabase", reflect.TypeOf((*MockDatabase)(nil).MGetDatabase), ctx, req)
}
// PublishDatabase mocks base method.
func (m *MockDatabase) PublishDatabase(ctx context.Context, req *database.PublishDatabaseRequest) (*database.PublishDatabaseResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PublishDatabase", ctx, req)
ret0, _ := ret[0].(*database.PublishDatabaseResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// PublishDatabase indicates an expected call of PublishDatabase.
func (mr *MockDatabaseMockRecorder) PublishDatabase(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishDatabase", reflect.TypeOf((*MockDatabase)(nil).PublishDatabase), ctx, req)
}
// Query mocks base method.
func (m *MockDatabase) Query(ctx context.Context, request *database.QueryRequest) (*database.Response, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Query", ctx, request)
ret0, _ := ret[0].(*database.Response)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Query indicates an expected call of Query.
func (mr *MockDatabaseMockRecorder) Query(ctx, request any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockDatabase)(nil).Query), ctx, request)
}
// UnBindDatabase mocks base method.
func (m *MockDatabase) UnBindDatabase(ctx context.Context, req *database.UnBindDatabaseToAgentRequest) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnBindDatabase", ctx, req)
ret0, _ := ret[0].(error)
return ret0
}
// UnBindDatabase indicates an expected call of UnBindDatabase.
func (mr *MockDatabaseMockRecorder) UnBindDatabase(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnBindDatabase", reflect.TypeOf((*MockDatabase)(nil).UnBindDatabase), ctx, req)
}
// Update mocks base method.
func (m *MockDatabase) Update(arg0 context.Context, arg1 *database.UpdateRequest) (*database.Response, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Update", arg0, arg1)
ret0, _ := ret[0].(*database.Response)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Update indicates an expected call of Update.
func (mr *MockDatabaseMockRecorder) Update(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockDatabase)(nil).Update), arg0, arg1)
}

View File

@ -18,10 +18,20 @@ package database
import ( import (
"context" "context"
"fmt"
"strings"
"github.com/spf13/cast"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database" model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/api/model/data/database/table"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database" crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
"github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
database "github.com/coze-dev/coze-studio/backend/domain/memory/database/service" database "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
) )
var defaultSVC crossdatabase.Database var defaultSVC crossdatabase.Database
@ -65,3 +75,406 @@ func (c *databaseImpl) MGetDatabase(ctx context.Context, req *model.MGetDatabase
func (c *databaseImpl) GetAllDatabaseByAppID(ctx context.Context, req *model.GetAllDatabaseByAppIDRequest) (*model.GetAllDatabaseByAppIDResponse, error) { func (c *databaseImpl) GetAllDatabaseByAppID(ctx context.Context, req *model.GetAllDatabaseByAppIDRequest) (*model.GetAllDatabaseByAppIDResponse, error) {
return c.DomainSVC.GetAllDatabaseByAppID(ctx, req) return c.DomainSVC.GetAllDatabaseByAppID(ctx, req)
} }
func (d *databaseImpl) Execute(ctx context.Context, request *model.CustomSQLRequest) (*model.Response, error) {
var (
err error
databaseInfoID = request.DatabaseInfoID
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
)
if request.IsDebugRun {
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
if err != nil {
return nil, err
}
}
req := &service.ExecuteSQLRequest{
DatabaseID: databaseInfoID,
OperateType: model.OperateType_Custom,
SQL: &request.SQL,
TableType: tableType,
UserID: request.UserID,
ConnectorID: ptr.Of(request.ConnectorID),
}
req.SQLParams = make([]*model.SQLParamVal, 0, len(request.Params))
for i := range request.Params {
param := request.Params[i]
req.SQLParams = append(req.SQLParams, &model.SQLParamVal{
ValueType: table.FieldItemType_Text,
Value: &param.Value,
ISNull: param.IsNull,
})
}
response, err := d.DomainSVC.ExecuteSQL(ctx, req)
if err != nil {
return nil, err
}
// if rows affected is nil use 0 instead
if response.RowsAffected == nil {
response.RowsAffected = ptr.Of(int64(0))
}
return toNodeDateBaseResponse(response), nil
}
func (d *databaseImpl) Delete(ctx context.Context, request *model.DeleteRequest) (*model.Response, error) {
var (
err error
databaseInfoID = request.DatabaseInfoID
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
)
if request.IsDebugRun {
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
if err != nil {
return nil, err
}
}
req := &service.ExecuteSQLRequest{
DatabaseID: databaseInfoID,
OperateType: model.OperateType_Delete,
TableType: tableType,
UserID: request.UserID,
ConnectorID: ptr.Of(request.ConnectorID),
}
if request.ConditionGroup != nil {
req.Condition, req.SQLParams, err = buildComplexCondition(request.ConditionGroup)
if err != nil {
return nil, err
}
}
response, err := d.DomainSVC.ExecuteSQL(ctx, req)
if err != nil {
return nil, err
}
return toNodeDateBaseResponse(response), nil
}
func (d *databaseImpl) Query(ctx context.Context, request *model.QueryRequest) (*model.Response, error) {
var (
err error
databaseInfoID = request.DatabaseInfoID
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
)
if request.IsDebugRun {
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
if err != nil {
return nil, err
}
}
req := &service.ExecuteSQLRequest{
DatabaseID: databaseInfoID,
OperateType: model.OperateType_Select,
TableType: tableType,
UserID: request.UserID,
ConnectorID: ptr.Of(request.ConnectorID),
}
req.SelectFieldList = &model.SelectFieldList{FieldID: make([]string, 0, len(request.SelectFields))}
for i := range request.SelectFields {
req.SelectFieldList.FieldID = append(req.SelectFieldList.FieldID, request.SelectFields[i])
}
req.OrderByList = make([]model.OrderBy, 0)
for i := range request.OrderClauses {
clause := request.OrderClauses[i]
req.OrderByList = append(req.OrderByList, model.OrderBy{
Field: clause.FieldID,
Direction: toOrderDirection(clause.IsAsc),
})
}
if request.ConditionGroup != nil {
req.Condition, req.SQLParams, err = buildComplexCondition(request.ConditionGroup)
if err != nil {
return nil, err
}
}
limit := request.Limit
req.Limit = &limit
response, err := d.DomainSVC.ExecuteSQL(ctx, req)
if err != nil {
return nil, err
}
return toNodeDateBaseResponse(response), nil
}
func (d *databaseImpl) Update(ctx context.Context, request *model.UpdateRequest) (*model.Response, error) {
var (
err error
condition *model.ComplexCondition
params []*model.SQLParamVal
databaseInfoID = request.DatabaseInfoID
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
)
if request.IsDebugRun {
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
if err != nil {
return nil, err
}
}
req := &service.ExecuteSQLRequest{
DatabaseID: databaseInfoID,
OperateType: model.OperateType_Update,
SQLParams: make([]*model.SQLParamVal, 0),
TableType: tableType,
UserID: request.UserID,
ConnectorID: ptr.Of(request.ConnectorID),
}
uid := ctxutil.GetUIDFromCtx(ctx)
if uid != nil {
req.UserID = conv.Int64ToStr(*uid)
}
req.UpsertRows, req.SQLParams, err = resolveUpsertRow(request.Fields)
if err != nil {
return nil, err
}
if request.ConditionGroup != nil {
condition, params, err = buildComplexCondition(request.ConditionGroup)
if err != nil {
return nil, err
}
req.Condition = condition
req.SQLParams = append(req.SQLParams, params...)
}
response, err := d.DomainSVC.ExecuteSQL(ctx, req)
if err != nil {
return nil, err
}
return toNodeDateBaseResponse(response), nil
}
func (d *databaseImpl) Insert(ctx context.Context, request *model.InsertRequest) (*model.Response, error) {
var (
err error
databaseInfoID = request.DatabaseInfoID
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
)
if request.IsDebugRun {
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
if err != nil {
return nil, err
}
}
req := &service.ExecuteSQLRequest{
DatabaseID: databaseInfoID,
OperateType: model.OperateType_Insert,
TableType: tableType,
UserID: request.UserID,
ConnectorID: ptr.Of(request.ConnectorID),
}
req.UpsertRows, req.SQLParams, err = resolveUpsertRow(request.Fields)
if err != nil {
return nil, err
}
response, err := d.DomainSVC.ExecuteSQL(ctx, req)
if err != nil {
return nil, err
}
return toNodeDateBaseResponse(response), nil
}
func (d *databaseImpl) getDraftTableID(ctx context.Context, onlineID int64) (int64, error) {
resp, err := d.DomainSVC.GetDraftDatabaseByOnlineID(ctx, &service.GetDraftDatabaseByOnlineIDRequest{OnlineID: onlineID})
if err != nil {
return 0, err
}
return resp.Database.ID, nil
}
func buildComplexCondition(conditionGroup *model.ConditionGroup) (*model.ComplexCondition, []*model.SQLParamVal, error) {
condition := &model.ComplexCondition{}
logic, err := toLogic(conditionGroup.Relation)
if err != nil {
return nil, nil, err
}
condition.Logic = logic
params := make([]*model.SQLParamVal, 0)
for i := range conditionGroup.Conditions {
var (
nCond = conditionGroup.Conditions[i]
vals []*model.SQLParamVal
dCond = &model.Condition{
Left: nCond.Left,
}
)
opt, err := toOperation(nCond.Operator)
if err != nil {
return nil, nil, err
}
dCond.Operation = opt
if isNullOrNotNull(opt) {
condition.Conditions = append(condition.Conditions, dCond)
continue
}
dCond.Right, vals, err = resolveRightValue(opt, nCond.Right)
if err != nil {
return nil, nil, err
}
condition.Conditions = append(condition.Conditions, dCond)
params = append(params, vals...)
}
return condition, params, nil
}
func toMapStringAny(m map[string]string) map[string]any {
ret := make(map[string]any, len(m))
for k, v := range m {
ret[k] = v
}
return ret
}
func toOperation(operator model.Operator) (model.Operation, error) {
switch operator {
case model.OperatorEqual:
return model.Operation_EQUAL, nil
case model.OperatorNotEqual:
return model.Operation_NOT_EQUAL, nil
case model.OperatorGreater:
return model.Operation_GREATER_THAN, nil
case model.OperatorGreaterOrEqual:
return model.Operation_GREATER_EQUAL, nil
case model.OperatorLesser:
return model.Operation_LESS_THAN, nil
case model.OperatorLesserOrEqual:
return model.Operation_LESS_EQUAL, nil
case model.OperatorIn:
return model.Operation_IN, nil
case model.OperatorNotIn:
return model.Operation_NOT_IN, nil
case model.OperatorIsNotNull:
return model.Operation_IS_NOT_NULL, nil
case model.OperatorIsNull:
return model.Operation_IS_NULL, nil
case model.OperatorLike:
return model.Operation_LIKE, nil
case model.OperatorNotLike:
return model.Operation_NOT_LIKE, nil
default:
return model.Operation(0), fmt.Errorf("invalid operator %v", operator)
}
}
func resolveRightValue(operator model.Operation, right any) (string, []*model.SQLParamVal, error) {
if isInOrNotIn(operator) {
var (
vals = make([]*model.SQLParamVal, 0)
anyVals = make([]any, 0)
commas = make([]string, 0, len(anyVals))
)
anyVals = right.([]any)
for i := range anyVals {
v := cast.ToString(anyVals[i])
vals = append(vals, &model.SQLParamVal{ValueType: table.FieldItemType_Text, Value: &v})
commas = append(commas, "?")
}
value := "(" + strings.Join(commas, ",") + ")"
return value, vals, nil
}
rightValue, err := cast.ToStringE(right)
if err != nil {
return "", nil, err
}
if isLikeOrNotLike(operator) {
var (
value = "?"
v = "%s" + rightValue + "%s"
)
return value, []*model.SQLParamVal{{ValueType: table.FieldItemType_Text, Value: &v}}, nil
}
return "?", []*model.SQLParamVal{{ValueType: table.FieldItemType_Text, Value: &rightValue}}, nil
}
func resolveUpsertRow(fields map[string]any) ([]*model.UpsertRow, []*model.SQLParamVal, error) {
upsertRow := &model.UpsertRow{Records: make([]*model.Record, 0, len(fields))}
params := make([]*model.SQLParamVal, 0)
for key, value := range fields {
val, err := cast.ToStringE(value)
if err != nil {
return nil, nil, err
}
record := &model.Record{
FieldId: key,
FieldValue: "?",
}
upsertRow.Records = append(upsertRow.Records, record)
params = append(params, &model.SQLParamVal{
ValueType: table.FieldItemType_Text,
Value: &val,
})
}
return []*model.UpsertRow{upsertRow}, params, nil
}
func isNullOrNotNull(opt model.Operation) bool {
return opt == model.Operation_IS_NOT_NULL || opt == model.Operation_IS_NULL
}
func isLikeOrNotLike(opt model.Operation) bool {
return opt == model.Operation_LIKE || opt == model.Operation_NOT_LIKE
}
func isInOrNotIn(opt model.Operation) bool {
return opt == model.Operation_IN || opt == model.Operation_NOT_IN
}
func toOrderDirection(isAsc bool) table.SortDirection {
if isAsc {
return table.SortDirection_ASC
}
return table.SortDirection_Desc
}
func toLogic(relation model.ClauseRelation) (model.Logic, error) {
switch relation {
case model.ClauseRelationOR:
return model.Logic_Or, nil
case model.ClauseRelationAND:
return model.Logic_And, nil
default:
return model.Logic(0), fmt.Errorf("invalid relation %v", relation)
}
}
func toNodeDateBaseResponse(response *service.ExecuteSQLResponse) *model.Response {
objects := make([]model.Object, 0, len(response.Records))
for i := range response.Records {
objects = append(objects, response.Records[i])
}
return &model.Response{
Objects: objects,
RowNumber: response.RowsAffected,
}
}

View File

@ -1,447 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package database
import (
"context"
"fmt"
"strings"
"github.com/spf13/cast"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/api/model/data/database/table"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
"github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
nodedatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
)
type DatabaseRepository struct {
client service.Database
}
func NewDatabaseRepository(client service.Database) *DatabaseRepository {
return &DatabaseRepository{
client: client,
}
}
func (d *DatabaseRepository) Execute(ctx context.Context, request *nodedatabase.CustomSQLRequest) (*nodedatabase.Response, error) {
var (
err error
databaseInfoID = request.DatabaseInfoID
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
)
if request.IsDebugRun {
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
if err != nil {
return nil, err
}
}
req := &service.ExecuteSQLRequest{
DatabaseID: databaseInfoID,
OperateType: database.OperateType_Custom,
SQL: &request.SQL,
TableType: tableType,
UserID: request.UserID,
ConnectorID: ptr.Of(request.ConnectorID),
}
req.SQLParams = make([]*database.SQLParamVal, 0, len(request.Params))
for i := range request.Params {
param := request.Params[i]
req.SQLParams = append(req.SQLParams, &database.SQLParamVal{
ValueType: table.FieldItemType_Text,
Value: &param.Value,
ISNull: param.IsNull,
})
}
response, err := d.client.ExecuteSQL(ctx, req)
if err != nil {
return nil, err
}
// if rows affected is nil use 0 instead
if response.RowsAffected == nil {
response.RowsAffected = ptr.Of(int64(0))
}
return toNodeDateBaseResponse(response), nil
}
func (d *DatabaseRepository) Delete(ctx context.Context, request *nodedatabase.DeleteRequest) (*nodedatabase.Response, error) {
var (
err error
databaseInfoID = request.DatabaseInfoID
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
)
if request.IsDebugRun {
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
if err != nil {
return nil, err
}
}
req := &service.ExecuteSQLRequest{
DatabaseID: databaseInfoID,
OperateType: database.OperateType_Delete,
TableType: tableType,
UserID: request.UserID,
ConnectorID: ptr.Of(request.ConnectorID),
}
if request.ConditionGroup != nil {
req.Condition, req.SQLParams, err = buildComplexCondition(request.ConditionGroup)
if err != nil {
return nil, err
}
}
response, err := d.client.ExecuteSQL(ctx, req)
if err != nil {
return nil, err
}
return toNodeDateBaseResponse(response), nil
}
func (d *DatabaseRepository) Query(ctx context.Context, request *nodedatabase.QueryRequest) (*nodedatabase.Response, error) {
var (
err error
databaseInfoID = request.DatabaseInfoID
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
)
if request.IsDebugRun {
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
if err != nil {
return nil, err
}
}
req := &service.ExecuteSQLRequest{
DatabaseID: databaseInfoID,
OperateType: database.OperateType_Select,
TableType: tableType,
UserID: request.UserID,
ConnectorID: ptr.Of(request.ConnectorID),
}
req.SelectFieldList = &database.SelectFieldList{FieldID: make([]string, 0, len(request.SelectFields))}
for i := range request.SelectFields {
req.SelectFieldList.FieldID = append(req.SelectFieldList.FieldID, request.SelectFields[i])
}
req.OrderByList = make([]database.OrderBy, 0)
for i := range request.OrderClauses {
clause := request.OrderClauses[i]
req.OrderByList = append(req.OrderByList, database.OrderBy{
Field: clause.FieldID,
Direction: toOrderDirection(clause.IsAsc),
})
}
if request.ConditionGroup != nil {
req.Condition, req.SQLParams, err = buildComplexCondition(request.ConditionGroup)
if err != nil {
return nil, err
}
}
limit := request.Limit
req.Limit = &limit
response, err := d.client.ExecuteSQL(ctx, req)
if err != nil {
return nil, err
}
return toNodeDateBaseResponse(response), nil
}
func (d *DatabaseRepository) Update(ctx context.Context, request *nodedatabase.UpdateRequest) (*nodedatabase.Response, error) {
var (
err error
condition *database.ComplexCondition
params []*database.SQLParamVal
databaseInfoID = request.DatabaseInfoID
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
)
if request.IsDebugRun {
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
if err != nil {
return nil, err
}
}
req := &service.ExecuteSQLRequest{
DatabaseID: databaseInfoID,
OperateType: database.OperateType_Update,
SQLParams: make([]*database.SQLParamVal, 0),
TableType: tableType,
UserID: request.UserID,
ConnectorID: ptr.Of(request.ConnectorID),
}
uid := ctxutil.GetUIDFromCtx(ctx)
if uid != nil {
req.UserID = conv.Int64ToStr(*uid)
}
req.UpsertRows, req.SQLParams, err = resolveUpsertRow(request.Fields)
if err != nil {
return nil, err
}
if request.ConditionGroup != nil {
condition, params, err = buildComplexCondition(request.ConditionGroup)
if err != nil {
return nil, err
}
req.Condition = condition
req.SQLParams = append(req.SQLParams, params...)
}
response, err := d.client.ExecuteSQL(ctx, req)
if err != nil {
return nil, err
}
return toNodeDateBaseResponse(response), nil
}
func (d *DatabaseRepository) Insert(ctx context.Context, request *nodedatabase.InsertRequest) (*nodedatabase.Response, error) {
var (
err error
databaseInfoID = request.DatabaseInfoID
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
)
if request.IsDebugRun {
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
if err != nil {
return nil, err
}
}
req := &service.ExecuteSQLRequest{
DatabaseID: databaseInfoID,
OperateType: database.OperateType_Insert,
TableType: tableType,
UserID: request.UserID,
ConnectorID: ptr.Of(request.ConnectorID),
}
req.UpsertRows, req.SQLParams, err = resolveUpsertRow(request.Fields)
if err != nil {
return nil, err
}
response, err := d.client.ExecuteSQL(ctx, req)
if err != nil {
return nil, err
}
return toNodeDateBaseResponse(response), nil
}
func (d *DatabaseRepository) getDraftTableID(ctx context.Context, onlineID int64) (int64, error) {
resp, err := d.client.GetDraftDatabaseByOnlineID(ctx, &service.GetDraftDatabaseByOnlineIDRequest{OnlineID: onlineID})
if err != nil {
return 0, err
}
return resp.Database.ID, nil
}
func buildComplexCondition(conditionGroup *nodedatabase.ConditionGroup) (*database.ComplexCondition, []*database.SQLParamVal, error) {
condition := &database.ComplexCondition{}
logic, err := toLogic(conditionGroup.Relation)
if err != nil {
return nil, nil, err
}
condition.Logic = logic
params := make([]*database.SQLParamVal, 0)
for i := range conditionGroup.Conditions {
var (
nCond = conditionGroup.Conditions[i]
vals []*database.SQLParamVal
dCond = &database.Condition{
Left: nCond.Left,
}
)
opt, err := toOperation(nCond.Operator)
if err != nil {
return nil, nil, err
}
dCond.Operation = opt
if isNullOrNotNull(opt) {
condition.Conditions = append(condition.Conditions, dCond)
continue
}
dCond.Right, vals, err = resolveRightValue(opt, nCond.Right)
if err != nil {
return nil, nil, err
}
condition.Conditions = append(condition.Conditions, dCond)
params = append(params, vals...)
}
return condition, params, nil
}
func toMapStringAny(m map[string]string) map[string]any {
ret := make(map[string]any, len(m))
for k, v := range m {
ret[k] = v
}
return ret
}
func toOperation(operator nodedatabase.Operator) (database.Operation, error) {
switch operator {
case nodedatabase.OperatorEqual:
return database.Operation_EQUAL, nil
case nodedatabase.OperatorNotEqual:
return database.Operation_NOT_EQUAL, nil
case nodedatabase.OperatorGreater:
return database.Operation_GREATER_THAN, nil
case nodedatabase.OperatorGreaterOrEqual:
return database.Operation_GREATER_EQUAL, nil
case nodedatabase.OperatorLesser:
return database.Operation_LESS_THAN, nil
case nodedatabase.OperatorLesserOrEqual:
return database.Operation_LESS_EQUAL, nil
case nodedatabase.OperatorIn:
return database.Operation_IN, nil
case nodedatabase.OperatorNotIn:
return database.Operation_NOT_IN, nil
case nodedatabase.OperatorIsNotNull:
return database.Operation_IS_NOT_NULL, nil
case nodedatabase.OperatorIsNull:
return database.Operation_IS_NULL, nil
case nodedatabase.OperatorLike:
return database.Operation_LIKE, nil
case nodedatabase.OperatorNotLike:
return database.Operation_NOT_LIKE, nil
default:
return database.Operation(0), fmt.Errorf("invalid operator %v", operator)
}
}
func resolveRightValue(operator database.Operation, right any) (string, []*database.SQLParamVal, error) {
if isInOrNotIn(operator) {
var (
vals = make([]*database.SQLParamVal, 0)
anyVals = make([]any, 0)
commas = make([]string, 0, len(anyVals))
)
anyVals = right.([]any)
for i := range anyVals {
v := cast.ToString(anyVals[i])
vals = append(vals, &database.SQLParamVal{ValueType: table.FieldItemType_Text, Value: &v})
commas = append(commas, "?")
}
value := "(" + strings.Join(commas, ",") + ")"
return value, vals, nil
}
rightValue, err := cast.ToStringE(right)
if err != nil {
return "", nil, err
}
if isLikeOrNotLike(operator) {
var (
value = "?"
v = "%s" + rightValue + "%s"
)
return value, []*database.SQLParamVal{{ValueType: table.FieldItemType_Text, Value: &v}}, nil
}
return "?", []*database.SQLParamVal{{ValueType: table.FieldItemType_Text, Value: &rightValue}}, nil
}
func resolveUpsertRow(fields map[string]any) ([]*database.UpsertRow, []*database.SQLParamVal, error) {
upsertRow := &database.UpsertRow{Records: make([]*database.Record, 0, len(fields))}
params := make([]*database.SQLParamVal, 0)
for key, value := range fields {
val, err := cast.ToStringE(value)
if err != nil {
return nil, nil, err
}
record := &database.Record{
FieldId: key,
FieldValue: "?",
}
upsertRow.Records = append(upsertRow.Records, record)
params = append(params, &database.SQLParamVal{
ValueType: table.FieldItemType_Text,
Value: &val,
})
}
return []*database.UpsertRow{upsertRow}, params, nil
}
func isNullOrNotNull(opt database.Operation) bool {
return opt == database.Operation_IS_NOT_NULL || opt == database.Operation_IS_NULL
}
func isLikeOrNotLike(opt database.Operation) bool {
return opt == database.Operation_LIKE || opt == database.Operation_NOT_LIKE
}
func isInOrNotIn(opt database.Operation) bool {
return opt == database.Operation_IN || opt == database.Operation_NOT_IN
}
func toOrderDirection(isAsc bool) table.SortDirection {
if isAsc {
return table.SortDirection_ASC
}
return table.SortDirection_Desc
}
func toLogic(relation nodedatabase.ClauseRelation) (database.Logic, error) {
switch relation {
case nodedatabase.ClauseRelationOR:
return database.Logic_Or, nil
case nodedatabase.ClauseRelationAND:
return database.Logic_And, nil
default:
return database.Logic(0), fmt.Errorf("invalid relation %v", relation)
}
}
func toNodeDateBaseResponse(response *service.ExecuteSQLResponse) *nodedatabase.Response {
objects := make([]nodedatabase.Object, 0, len(response.Records))
for i := range response.Records {
objects = append(objects, response.Records[i])
}
return &nodedatabase.Response{
Objects: objects,
RowNumber: response.RowsAffected,
}
}

View File

@ -1,224 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package database
import (
"context"
"testing"
"github.com/spf13/cast"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
nodedatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
mockDatabase "github.com/coze-dev/coze-studio/backend/internal/mock/domain/memory/database"
)
func mockExecuteSQL(t *testing.T) func(ctx context.Context, request *service.ExecuteSQLRequest) (*service.ExecuteSQLResponse, error) {
return func(ctx context.Context, request *service.ExecuteSQLRequest) (*service.ExecuteSQLResponse, error) {
if request.OperateType == database.OperateType_Custom {
assert.Equal(t, *request.SQL, "select * from table where v1=? and v2=?")
rs := make([]string, 0)
for idx := range request.SQLParams {
rs = append(rs, *request.SQLParams[idx].Value)
}
assert.Equal(t, rs, []string{"1", "2"})
return &service.ExecuteSQLResponse{
Records: []map[string]any{
{"v1": "1", "v2": "2"},
},
}, nil
}
if request.OperateType == database.OperateType_Select {
sFields := []string{"v1", "v2", "v3", "v4"}
assert.Equal(t, request.SelectFieldList.FieldID, sFields)
cond := request.Condition.Conditions[1] // in
assert.Equal(t, "(?,?)", cond.Right)
assert.Equal(t, database.Operation_IN, cond.Operation)
assert.Equal(t, "v2_1", *request.SQLParams[1].Value)
assert.Equal(t, "v2_2", *request.SQLParams[2].Value)
assert.Equal(t, "%sv4%s", *request.SQLParams[3].Value)
rowsAffected := int64(10)
return &service.ExecuteSQLResponse{
Records: []map[string]any{
{"v1": "1", "v2": "2", "v3": "3", "v4": "4"},
},
RowsAffected: &rowsAffected,
}, nil
}
if request.OperateType == database.OperateType_Delete {
cond := request.Condition.Conditions[1] // in
assert.Equal(t, "(?,?)", cond.Right)
assert.Equal(t, database.Operation_NOT_IN, cond.Operation)
assert.Equal(t, "v2_1", *request.SQLParams[1].Value)
assert.Equal(t, "v2_2", *request.SQLParams[2].Value)
assert.Equal(t, "%sv4%s", *request.SQLParams[3].Value)
rowsAffected := int64(10)
return &service.ExecuteSQLResponse{
Records: []map[string]any{
{"v1": "1", "v2": "2", "v3": "3", "v4": "4"},
},
RowsAffected: &rowsAffected,
}, nil
}
if request.OperateType == database.OperateType_Insert {
records := request.UpsertRows[0].Records
ret := map[string]interface{}{
"v1": "1",
"v2": 2,
"v3": 33,
"v4": "44aacc",
}
for idx := range records {
assert.Equal(t, *request.SQLParams[idx].Value, cast.ToString(ret[records[idx].FieldId]))
}
}
if request.OperateType == database.OperateType_Update {
records := request.UpsertRows[0].Records
ret := map[string]interface{}{
"v1": "1",
"v2": 2,
"v3": 33,
"v4": "aabbcc",
}
for idx := range records {
assert.Equal(t, *request.SQLParams[idx].Value, cast.ToString(ret[records[idx].FieldId]))
}
request.SQLParams = request.SQLParams[len(records):]
cond := request.Condition.Conditions[1] // in
assert.Equal(t, "(?,?)", cond.Right)
assert.Equal(t, database.Operation_IN, cond.Operation)
assert.Equal(t, "v2_1", *request.SQLParams[1].Value)
assert.Equal(t, "v2_2", *request.SQLParams[2].Value)
assert.Equal(t, "%sv4%s", *request.SQLParams[3].Value)
}
return &service.ExecuteSQLResponse{}, nil
}
}
func TestDatabase_Database(t *testing.T) {
ctrl := gomock.NewController(t)
mockClient := mockDatabase.NewMockDatabase(ctrl)
defer ctrl.Finish()
ds := DatabaseRepository{
client: mockClient,
}
mockClient.EXPECT().ExecuteSQL(gomock.Any(), gomock.Any()).DoAndReturn(mockExecuteSQL(t)).AnyTimes()
t.Run("execute", func(t *testing.T) {
response, err := ds.Execute(context.Background(), &nodedatabase.CustomSQLRequest{
DatabaseInfoID: 1,
SQL: "select * from table where v1=? and v2=?",
Params: []nodedatabase.SQLParam{
nodedatabase.SQLParam{Value: "1"},
nodedatabase.SQLParam{Value: "2"},
},
})
assert.Nil(t, err)
assert.Equal(t, response.Objects, []nodedatabase.Object{
{"v1": "1", "v2": "2"},
})
})
t.Run("select", func(t *testing.T) {
req := &nodedatabase.QueryRequest{
DatabaseInfoID: 1,
SelectFields: []string{"v1", "v2", "v3", "v4"},
Limit: 10,
OrderClauses: []*nodedatabase.OrderClause{
{FieldID: "v1", IsAsc: true},
{FieldID: "v2", IsAsc: false},
},
ConditionGroup: &nodedatabase.ConditionGroup{
Conditions: []*nodedatabase.Condition{
{Left: "v1", Operator: nodedatabase.OperatorEqual, Right: "1"},
{Left: "v2", Operator: nodedatabase.OperatorIn, Right: []any{"v2_1", "v2_2"}},
{Left: "v3", Operator: nodedatabase.OperatorIsNull},
{Left: "v4", Operator: nodedatabase.OperatorLike, Right: "v4"},
},
Relation: nodedatabase.ClauseRelationOR,
},
}
response, err := ds.Query(context.Background(), req)
assert.Nil(t, err)
assert.Equal(t, *response.RowNumber, int64(10))
})
t.Run("delete", func(t *testing.T) {
req := &nodedatabase.DeleteRequest{
DatabaseInfoID: 1,
ConditionGroup: &nodedatabase.ConditionGroup{
Conditions: []*nodedatabase.Condition{
{Left: "v1", Operator: nodedatabase.OperatorEqual, Right: "1"},
{Left: "v2", Operator: nodedatabase.OperatorNotIn, Right: []any{"v2_1", "v2_2"}},
{Left: "v3", Operator: nodedatabase.OperatorIsNotNull},
{Left: "v4", Operator: nodedatabase.OperatorNotLike, Right: "v4"},
},
Relation: nodedatabase.ClauseRelationOR,
},
}
response, err := ds.Delete(context.Background(), req)
assert.Nil(t, err)
assert.Equal(t, *response.RowNumber, int64(10))
})
t.Run("insert", func(t *testing.T) {
req := &nodedatabase.InsertRequest{
DatabaseInfoID: 1,
Fields: map[string]interface{}{
"v1": "1",
"v2": 2,
"v3": 33,
"v4": "44aacc",
},
}
_, err := ds.Insert(context.Background(), req)
assert.Nil(t, err)
})
t.Run("update", func(t *testing.T) {
req := &nodedatabase.UpdateRequest{
DatabaseInfoID: 1,
ConditionGroup: &nodedatabase.ConditionGroup{
Conditions: []*nodedatabase.Condition{
{Left: "v1", Operator: nodedatabase.OperatorEqual, Right: "1"},
{Left: "v2", Operator: nodedatabase.OperatorIn, Right: []any{"v2_1", "v2_2"}},
{Left: "v3", Operator: nodedatabase.OperatorIsNull},
{Left: "v4", Operator: nodedatabase.OperatorLike, Right: "v4"},
},
Relation: nodedatabase.ClauseRelationOR,
},
Fields: map[string]interface{}{
"v1": "1",
"v2": 2,
"v3": 33,
"v4": "aabbcc",
},
}
_, err := ds.Update(context.Background(), req)
assert.Nil(t, err)
})
}

View File

@ -1,61 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package conversation
import "context"
type ClearMessageRequest struct {
Name string
}
type ClearMessageResponse struct {
IsSuccess bool
}
type CreateConversationRequest struct {
Name string
}
type CreateConversationResponse struct {
Result map[string]any
}
type ListMessageRequest struct {
ConversationName string
Limit *int
BeforeID *string
AfterID *string
}
type Message struct {
ID string `json:"id"`
Role string `json:"role"`
ContentType string `json:"contentType"`
Content string `json:"content"`
}
type ListMessageResponse struct {
Messages []*Message
FirstID string
LastID string
HasMore bool
}
var ConversationManagerImpl ConversationManager
type ConversationManager interface {
ClearMessage(context.Context, *ClearMessageRequest) (*ClearMessageResponse, error)
CreateConversation(ctx context.Context, c *CreateConversationRequest) (*CreateConversationResponse, error)
MessageList(ctx context.Context, req *ListMessageRequest) (*ListMessageResponse, error)
}

View File

@ -1,148 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package database
import (
"context"
)
type SQLParam struct {
Value string
IsNull bool
}
type CustomSQLRequest struct {
DatabaseInfoID int64
SQL string
Params []SQLParam
IsDebugRun bool
UserID string
ConnectorID int64
}
type Object = map[string]any
type Response struct {
RowNumber *int64
Objects []Object
}
type Operator string
type ClauseRelation string
const (
ClauseRelationAND ClauseRelation = "and"
ClauseRelationOR ClauseRelation = "or"
)
const (
OperatorEqual Operator = "="
OperatorNotEqual Operator = "!="
OperatorGreater Operator = ">"
OperatorLesser Operator = "<"
OperatorGreaterOrEqual Operator = ">="
OperatorLesserOrEqual Operator = "<="
OperatorIn Operator = "in"
OperatorNotIn Operator = "not_in"
OperatorIsNull Operator = "is_null"
OperatorIsNotNull Operator = "is_not_null"
OperatorLike Operator = "like"
OperatorNotLike Operator = "not_like"
)
type ClauseGroup struct {
Single *Clause
Multi *MultiClause
}
type Clause struct {
Left string
Operator Operator
}
type MultiClause struct {
Clauses []*Clause
Relation ClauseRelation
}
type Condition struct {
Left string
Operator Operator
Right any
}
type ConditionGroup struct {
Conditions []*Condition
Relation ClauseRelation
}
type DeleteRequest struct {
DatabaseInfoID int64
ConditionGroup *ConditionGroup
IsDebugRun bool
UserID string
ConnectorID int64
}
type QueryRequest struct {
DatabaseInfoID int64
SelectFields []string
Limit int64
ConditionGroup *ConditionGroup
OrderClauses []*OrderClause
IsDebugRun bool
UserID string
ConnectorID int64
}
type OrderClause struct {
FieldID string
IsAsc bool
}
type UpdateRequest struct {
DatabaseInfoID int64
ConditionGroup *ConditionGroup
Fields map[string]any
IsDebugRun bool
UserID string
ConnectorID int64
}
type InsertRequest struct {
DatabaseInfoID int64
Fields map[string]any
IsDebugRun bool
UserID string
ConnectorID int64
}
func GetDatabaseOperator() DatabaseOperator {
return databaseOperatorImpl
}
func SetDatabaseOperator(d DatabaseOperator) {
databaseOperatorImpl = d
}
var (
databaseOperatorImpl DatabaseOperator
)
//go:generate mockgen -destination databasemock/database_mock.go --package databasemock -source database.go
type DatabaseOperator interface {
Execute(ctx context.Context, request *CustomSQLRequest) (*Response, error)
Query(ctx context.Context, request *QueryRequest) (*Response, error)
Update(context.Context, *UpdateRequest) (*Response, error)
Insert(ctx context.Context, request *InsertRequest) (*Response, error)
Delete(context.Context, *DeleteRequest) (*Response, error)
}

View File

@ -1,133 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Code generated by MockGen. DO NOT EDIT.
// Source: database.go
//
// Generated by this command:
//
// mockgen -destination databasemock/database_mock.go --package databasemock -source database.go
//
// Package databasemock is a generated GoMock package.
package databasemock
import (
context "context"
reflect "reflect"
database "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
gomock "go.uber.org/mock/gomock"
)
// MockDatabaseOperator is a mock of DatabaseOperator interface.
type MockDatabaseOperator struct {
ctrl *gomock.Controller
recorder *MockDatabaseOperatorMockRecorder
isgomock struct{}
}
// MockDatabaseOperatorMockRecorder is the mock recorder for MockDatabaseOperator.
type MockDatabaseOperatorMockRecorder struct {
mock *MockDatabaseOperator
}
// NewMockDatabaseOperator creates a new mock instance.
func NewMockDatabaseOperator(ctrl *gomock.Controller) *MockDatabaseOperator {
mock := &MockDatabaseOperator{ctrl: ctrl}
mock.recorder = &MockDatabaseOperatorMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockDatabaseOperator) EXPECT() *MockDatabaseOperatorMockRecorder {
return m.recorder
}
// Delete mocks base method.
func (m *MockDatabaseOperator) Delete(arg0 context.Context, arg1 *database.DeleteRequest) (*database.Response, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", arg0, arg1)
ret0, _ := ret[0].(*database.Response)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Delete indicates an expected call of Delete.
func (mr *MockDatabaseOperatorMockRecorder) Delete(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockDatabaseOperator)(nil).Delete), arg0, arg1)
}
// Execute mocks base method.
func (m *MockDatabaseOperator) Execute(ctx context.Context, request *database.CustomSQLRequest) (*database.Response, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Execute", ctx, request)
ret0, _ := ret[0].(*database.Response)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Execute indicates an expected call of Execute.
func (mr *MockDatabaseOperatorMockRecorder) Execute(ctx, request any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Execute", reflect.TypeOf((*MockDatabaseOperator)(nil).Execute), ctx, request)
}
// Insert mocks base method.
func (m *MockDatabaseOperator) Insert(ctx context.Context, request *database.InsertRequest) (*database.Response, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Insert", ctx, request)
ret0, _ := ret[0].(*database.Response)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Insert indicates an expected call of Insert.
func (mr *MockDatabaseOperatorMockRecorder) Insert(ctx, request any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockDatabaseOperator)(nil).Insert), ctx, request)
}
// Query mocks base method.
func (m *MockDatabaseOperator) Query(ctx context.Context, request *database.QueryRequest) (*database.Response, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Query", ctx, request)
ret0, _ := ret[0].(*database.Response)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Query indicates an expected call of Query.
func (mr *MockDatabaseOperatorMockRecorder) Query(ctx, request any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockDatabaseOperator)(nil).Query), ctx, request)
}
// Update mocks base method.
func (m *MockDatabaseOperator) Update(arg0 context.Context, arg1 *database.UpdateRequest) (*database.Response, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Update", arg0, arg1)
ret0, _ := ret[0].(*database.Response)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Update indicates an expected call of Update.
func (mr *MockDatabaseOperatorMockRecorder) Update(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockDatabaseOperator)(nil).Update), arg0, arg1)
}

View File

@ -32,13 +32,12 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner" crossmodel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/database/databasemock"
"github.com/coze-dev/coze-studio/backend/crossdomain/impl/code"
userentity "github.com/coze-dev/coze-studio/backend/domain/user/entity" userentity "github.com/coze-dev/coze-studio/backend/domain/user/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
crossdatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge/knowledgemock" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge/knowledgemock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
@ -50,6 +49,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
mockWorkflow "github.com/coze-dev/coze-studio/backend/internal/mock/domain/workflow" mockWorkflow "github.com/coze-dev/coze-studio/backend/internal/mock/domain/workflow"
mockcode "github.com/coze-dev/coze-studio/backend/internal/mock/domain/workflow/crossdomain/code" mockcode "github.com/coze-dev/coze-studio/backend/internal/mock/domain/workflow/crossdomain/code"
"github.com/coze-dev/coze-studio/backend/internal/testutil" "github.com/coze-dev/coze-studio/backend/internal/testutil"
@ -105,10 +105,10 @@ func TestIntentDetectorAndDatabase(t *testing.T) {
} }
mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).AnyTimes() mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).AnyTimes()
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator := databasemock.NewMockDatabase(ctrl)
n := int64(2) n := int64(2)
resp := &crossdatabase.Response{ resp := &crossmodel.Response{
Objects: []crossdatabase.Object{ Objects: []crossmodel.Object{
{ {
"v2": "123", "v2": "123",
}, },
@ -119,7 +119,7 @@ func TestIntentDetectorAndDatabase(t *testing.T) {
RowNumber: &n, RowNumber: &n,
} }
mockDatabaseOperator.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(resp, nil).AnyTimes() mockDatabaseOperator.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(resp, nil).AnyTimes()
mockey.Mock(crossdatabase.GetDatabaseOperator).Return(mockDatabaseOperator).Build() crossdatabase.SetDefaultSVC(mockDatabaseOperator)
workflowSC, err := CanvasToWorkflowSchema(ctx, c) workflowSC, err := CanvasToWorkflowSchema(ctx, c)
assert.NoError(t, err) assert.NoError(t, err)
@ -144,44 +144,44 @@ func TestIntentDetectorAndDatabase(t *testing.T) {
}) })
} }
func mockUpdate(t *testing.T) func(context.Context, *crossdatabase.UpdateRequest) (*crossdatabase.Response, error) { func mockUpdate(t *testing.T) func(context.Context, *crossmodel.UpdateRequest) (*crossmodel.Response, error) {
return func(ctx context.Context, req *crossdatabase.UpdateRequest) (*crossdatabase.Response, error) { return func(ctx context.Context, req *crossmodel.UpdateRequest) (*crossmodel.Response, error) {
assert.Equal(t, req.ConditionGroup.Conditions[0], &crossdatabase.Condition{ assert.Equal(t, req.ConditionGroup.Conditions[0], &crossmodel.ConditionStr{
Left: "v2", Left: "v2",
Operator: "=", Operator: "=",
Right: int64(1), Right: int64(1),
}) })
assert.Equal(t, req.ConditionGroup.Conditions[1], &crossdatabase.Condition{ assert.Equal(t, req.ConditionGroup.Conditions[1], &crossmodel.ConditionStr{
Left: "v1", Left: "v1",
Operator: "=", Operator: "=",
Right: "abc", Right: "abc",
}) })
assert.Equal(t, req.ConditionGroup.Relation, crossdatabase.ClauseRelationAND) assert.Equal(t, req.ConditionGroup.Relation, crossmodel.ClauseRelationAND)
assert.Equal(t, req.Fields, map[string]interface{}{ assert.Equal(t, req.Fields, map[string]interface{}{
"1783392627713": int64(123), "1783392627713": int64(123),
}) })
return &crossdatabase.Response{}, nil return &crossmodel.Response{}, nil
} }
} }
func mockInsert(t *testing.T) func(ctx context.Context, request *crossdatabase.InsertRequest) (*crossdatabase.Response, error) { func mockInsert(t *testing.T) func(ctx context.Context, request *crossmodel.InsertRequest) (*crossmodel.Response, error) {
return func(ctx context.Context, req *crossdatabase.InsertRequest) (*crossdatabase.Response, error) { return func(ctx context.Context, req *crossmodel.InsertRequest) (*crossmodel.Response, error) {
v := req.Fields["1785960530945"] v := req.Fields["1785960530945"]
assert.Equal(t, v, int64(123)) assert.Equal(t, v, int64(123))
vs := req.Fields["1783122026497"] vs := req.Fields["1783122026497"]
assert.Equal(t, vs, "input for database curd") assert.Equal(t, vs, "input for database curd")
n := int64(10) n := int64(10)
return &crossdatabase.Response{ return &crossmodel.Response{
RowNumber: &n, RowNumber: &n,
}, nil }, nil
} }
} }
func mockQuery(t *testing.T) func(ctx context.Context, request *crossdatabase.QueryRequest) (*crossdatabase.Response, error) { func mockQuery(t *testing.T) func(ctx context.Context, request *crossmodel.QueryRequest) (*crossmodel.Response, error) {
return func(ctx context.Context, req *crossdatabase.QueryRequest) (*crossdatabase.Response, error) { return func(ctx context.Context, req *crossmodel.QueryRequest) (*crossmodel.Response, error) {
assert.Equal(t, req.ConditionGroup.Conditions[0], &crossdatabase.Condition{ assert.Equal(t, req.ConditionGroup.Conditions[0], &crossmodel.ConditionStr{
Left: "v1", Left: "v1",
Operator: "=", Operator: "=",
Right: "abc", Right: "abc",
@ -191,26 +191,26 @@ func mockQuery(t *testing.T) func(ctx context.Context, request *crossdatabase.Qu
"1783122026497", "1784288924673", "1783392627713", "1783122026497", "1784288924673", "1783392627713",
}) })
n := int64(10) n := int64(10)
return &crossdatabase.Response{ return &crossmodel.Response{
RowNumber: &n, RowNumber: &n,
Objects: []crossdatabase.Object{ Objects: []crossmodel.Object{
{"v1": "vv"}, {"v1": "vv"},
}, },
}, nil }, nil
} }
} }
func mockDelete(t *testing.T) func(context.Context, *crossdatabase.DeleteRequest) (*crossdatabase.Response, error) { func mockDelete(t *testing.T) func(context.Context, *crossmodel.DeleteRequest) (*crossmodel.Response, error) {
return func(ctx context.Context, req *crossdatabase.DeleteRequest) (*crossdatabase.Response, error) { return func(ctx context.Context, req *crossmodel.DeleteRequest) (*crossmodel.Response, error) {
nn := int64(10) nn := int64(10)
assert.Equal(t, req.ConditionGroup.Conditions[0], &crossdatabase.Condition{ assert.Equal(t, req.ConditionGroup.Conditions[0], &crossmodel.ConditionStr{
Left: "v2", Left: "v2",
Operator: "=", Operator: "=",
Right: nn, Right: nn,
}) })
n := int64(1) n := int64(1)
return &crossdatabase.Response{ return &crossmodel.Response{
RowNumber: &n, RowNumber: &n,
}, nil }, nil
} }
@ -228,8 +228,8 @@ func TestDatabaseCURD(t *testing.T) {
_ = ctx _ = ctx
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator := databasemock.NewMockDatabase(ctrl)
mockey.Mock(crossdatabase.GetDatabaseOperator).Return(mockDatabaseOperator).Build() mockey.Mock(crossdatabase.DefaultSVC).Return(mockDatabaseOperator).Build()
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery(t)) mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery(t))
mockDatabaseOperator.EXPECT().Update(gomock.Any(), gomock.Any()).DoAndReturn(mockUpdate(t)) mockDatabaseOperator.EXPECT().Update(gomock.Any(), gomock.Any()).DoAndReturn(mockUpdate(t))
mockDatabaseOperator.EXPECT().Insert(gomock.Any(), gomock.Any()).DoAndReturn(mockInsert(t)) mockDatabaseOperator.EXPECT().Insert(gomock.Any(), gomock.Any()).DoAndReturn(mockInsert(t))

View File

@ -25,7 +25,7 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
code2 "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code" code2 "github.com/coze-dev/coze-studio/backend/crossdomain/impl/code"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"

View File

@ -24,10 +24,9 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
mockcode "github.com/coze-dev/coze-studio/backend/internal/mock/domain/workflow/crossdomain/code" mockcode "github.com/coze-dev/coze-studio/backend/internal/mock/domain/workflow/crossdomain/code"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache" "github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
) )

View File

@ -1,64 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package conversation
import (
"context"
"errors"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
)
type ClearMessageConfig struct {
Clearer conversation.ConversationManager
}
type MessageClear struct {
config *ClearMessageConfig
}
func NewClearMessage(ctx context.Context, cfg *ClearMessageConfig) (*MessageClear, error) {
if cfg == nil {
return nil, errors.New("config is required")
}
if cfg.Clearer == nil {
return nil, errors.New("clearer is required")
}
return &MessageClear{
config: cfg,
}, nil
}
func (c *MessageClear) Clear(ctx context.Context, input map[string]any) (map[string]any, error) {
name, ok := nodes.TakeMapValue(input, compose.FieldPath{"ConversationName"})
if !ok {
return nil, errors.New("input map should contains 'ConversationName' key ")
}
response, err := c.config.Clearer.ClearMessage(ctx, &conversation.ClearMessageRequest{
Name: name.(string),
})
if err != nil {
return nil, err
}
return map[string]any{
"isSuccess": response.IsSuccess,
}, nil
}

View File

@ -1,62 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package conversation
import (
"context"
"errors"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
)
type CreateConversationConfig struct {
Creator conversation.ConversationManager
}
type CreateConversation struct {
config *CreateConversationConfig
}
func NewCreateConversation(ctx context.Context, cfg *CreateConversationConfig) (*CreateConversation, error) {
if cfg == nil {
return nil, errors.New("config is required")
}
if cfg.Creator == nil {
return nil, errors.New("creator is required")
}
return &CreateConversation{
config: cfg,
}, nil
}
func (c *CreateConversation) Create(ctx context.Context, input map[string]any) (map[string]any, error) {
name, ok := nodes.TakeMapValue(input, compose.FieldPath{"ConversationName"})
if !ok {
return nil, errors.New("input map should contains 'ConversationName' key ")
}
response, err := c.config.Creator.CreateConversation(ctx, &conversation.CreateConversationRequest{
Name: name.(string),
})
if err != nil {
return nil, err
}
return response.Result, nil
}

View File

@ -1,108 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package conversation
import (
"context"
"encoding/json"
"errors"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
)
type MessageListConfig struct {
Lister conversation.ConversationManager
}
type MessageList struct {
config *MessageListConfig
}
type Param struct {
ConversationName string
Limit *int
BeforeID *string
AfterID *string
}
func NewMessageList(ctx context.Context, cfg *MessageListConfig) (*MessageList, error) {
if cfg == nil {
return nil, errors.New("config is required")
}
if cfg.Lister == nil {
return nil, errors.New("lister is required")
}
return &MessageList{
config: cfg,
}, nil
}
func (m *MessageList) List(ctx context.Context, input map[string]any) (map[string]any, error) {
param := &Param{}
name, ok := nodes.TakeMapValue(input, compose.FieldPath{"ConversationName"})
if !ok {
return nil, errors.New("ConversationName is required")
}
param.ConversationName = name.(string)
limit, ok := nodes.TakeMapValue(input, compose.FieldPath{"Limit"})
if ok {
limit := limit.(int)
param.Limit = &limit
}
beforeID, ok := nodes.TakeMapValue(input, compose.FieldPath{"BeforeID"})
if ok {
beforeID := beforeID.(string)
param.BeforeID = &beforeID
}
afterID, ok := nodes.TakeMapValue(input, compose.FieldPath{"AfterID"})
if ok {
afterID := afterID.(string)
param.BeforeID = &afterID
}
r, err := m.config.Lister.MessageList(ctx, &conversation.ListMessageRequest{
ConversationName: param.ConversationName,
Limit: param.Limit,
BeforeID: param.BeforeID,
AfterID: param.AfterID,
})
if err != nil {
return nil, err
}
result := make(map[string]any)
objects := make([]any, 0, len(r.Messages))
for _, msg := range r.Messages {
object := make(map[string]any)
bs, _ := json.Marshal(msg)
err := json.Unmarshal(bs, &object)
if err != nil {
return nil, err
}
objects = append(objects, object)
}
result["messageList"] = objects
result["firstId"] = r.FirstID
result["hasMore"] = r.HasMore
return result, nil
}

View File

@ -21,7 +21,7 @@ import (
einoCompose "github.com/cloudwego/eino/compose" einoCompose "github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"

View File

@ -25,7 +25,7 @@ import (
"github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
@ -349,7 +349,7 @@ func convertClauseGroupToConditionGroup(_ context.Context, clauseGroup *database
) )
conditionGroup := &database.ConditionGroup{ conditionGroup := &database.ConditionGroup{
Conditions: make([]*database.Condition, 0), Conditions: make([]*database.ConditionStr, 0),
Relation: database.ClauseRelationAND, Relation: database.ClauseRelationAND,
} }
@ -362,7 +362,7 @@ func convertClauseGroupToConditionGroup(_ context.Context, clauseGroup *database
} }
} }
conditionGroup.Conditions = append(conditionGroup.Conditions, &database.Condition{ conditionGroup.Conditions = append(conditionGroup.Conditions, &database.ConditionStr{
Left: clause.Left, Left: clause.Left,
Operator: clause.Operator, Operator: clause.Operator,
Right: rightValue, Right: rightValue,
@ -373,7 +373,7 @@ func convertClauseGroupToConditionGroup(_ context.Context, clauseGroup *database
if clauseGroup.Multi != nil { if clauseGroup.Multi != nil {
conditionGroup.Relation = clauseGroup.Multi.Relation conditionGroup.Relation = clauseGroup.Multi.Relation
conditionGroup.Conditions = make([]*database.Condition, len(clauseGroup.Multi.Clauses)) conditionGroup.Conditions = make([]*database.ConditionStr, len(clauseGroup.Multi.Clauses))
multiSelect := clauseGroup.Multi multiSelect := clauseGroup.Multi
for idx, clause := range multiSelect.Clauses { for idx, clause := range multiSelect.Clauses {
if !notNeedTakeMapValue(clause.Operator) { if !notNeedTakeMapValue(clause.Operator) {
@ -382,7 +382,7 @@ func convertClauseGroupToConditionGroup(_ context.Context, clauseGroup *database
return nil, fmt.Errorf("cannot take multi clause from input") return nil, fmt.Errorf("cannot take multi clause from input")
} }
} }
conditionGroup.Conditions[idx] = &database.Condition{ conditionGroup.Conditions[idx] = &database.ConditionStr{
Left: clause.Left, Left: clause.Left,
Operator: clause.Operator, Operator: clause.Operator,
Right: rightValue, Right: rightValue,

View File

@ -24,7 +24,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
@ -85,18 +86,16 @@ func (c *CustomSQLConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...s
} }
return &CustomSQL{ return &CustomSQL{
databaseInfoID: c.DatabaseInfoID, databaseInfoID: c.DatabaseInfoID,
sqlTemplate: c.SQLTemplate, sqlTemplate: c.SQLTemplate,
outputTypes: ns.OutputTypes, outputTypes: ns.OutputTypes,
customSQLExecutor: database.GetDatabaseOperator(),
}, nil }, nil
} }
type CustomSQL struct { type CustomSQL struct {
databaseInfoID int64 databaseInfoID int64
sqlTemplate string sqlTemplate string
outputTypes map[string]*vo.TypeInfo outputTypes map[string]*vo.TypeInfo
customSQLExecutor database.DatabaseOperator
} }
func (c *CustomSQL) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) { func (c *CustomSQL) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
@ -155,7 +154,7 @@ func (c *CustomSQL) Invoke(ctx context.Context, input map[string]any) (map[strin
templateSQL = strings.Replace(templateSQL, "`?`", "?", -1) templateSQL = strings.Replace(templateSQL, "`?`", "?", -1)
req.SQL = templateSQL req.SQL = templateSQL
req.Params = sqlParams req.Params = sqlParams
response, err := c.customSQLExecutor.Execute(ctx, req) response, err := crossdatabase.DefaultSVC().Execute(ctx, req)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -24,8 +24,9 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock" crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/database/databasemock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
@ -78,10 +79,9 @@ func TestCustomSQL_Execute(t *testing.T) {
}, },
}).Build().UnPatch() }).Build().UnPatch()
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator := databasemock.NewMockDatabase(ctrl)
mockDatabaseOperator.EXPECT().Execute(gomock.Any(), gomock.Any()).DoAndReturn(mockSQLer.Execute()).AnyTimes() mockDatabaseOperator.EXPECT().Execute(gomock.Any(), gomock.Any()).DoAndReturn(mockSQLer.Execute()).AnyTimes()
crossdatabase.SetDefaultSVC(mockDatabaseOperator)
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
cfg := &CustomSQLConfig{ cfg := &CustomSQLConfig{
DatabaseInfoID: 111, DatabaseInfoID: 111,

View File

@ -22,7 +22,8 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
@ -87,7 +88,6 @@ func (d *DeleteConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...sche
databaseInfoID: d.DatabaseInfoID, databaseInfoID: d.DatabaseInfoID,
clauseGroup: d.ClauseGroup, clauseGroup: d.ClauseGroup,
outputTypes: ns.OutputTypes, outputTypes: ns.OutputTypes,
deleter: database.GetDatabaseOperator(),
}, nil }, nil
} }
@ -95,7 +95,6 @@ type Delete struct {
databaseInfoID int64 databaseInfoID int64
clauseGroup *database.ClauseGroup clauseGroup *database.ClauseGroup
outputTypes map[string]*vo.TypeInfo outputTypes map[string]*vo.TypeInfo
deleter database.DatabaseOperator
} }
func (d *Delete) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) { func (d *Delete) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
@ -111,7 +110,7 @@ func (d *Delete) Invoke(ctx context.Context, in map[string]any) (map[string]any,
ConnectorID: getConnectorID(ctx), ConnectorID: getConnectorID(ctx),
} }
response, err := d.deleter.Delete(ctx, request) response, err := crossdatabase.DefaultSVC().Delete(ctx, request)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -22,7 +22,8 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
@ -73,14 +74,12 @@ func (i *InsertConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...sche
return &Insert{ return &Insert{
databaseInfoID: i.DatabaseInfoID, databaseInfoID: i.DatabaseInfoID,
outputTypes: ns.OutputTypes, outputTypes: ns.OutputTypes,
inserter: database.GetDatabaseOperator(),
}, nil }, nil
} }
type Insert struct { type Insert struct {
databaseInfoID int64 databaseInfoID int64
outputTypes map[string]*vo.TypeInfo outputTypes map[string]*vo.TypeInfo
inserter database.DatabaseOperator
} }
func (is *Insert) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) { func (is *Insert) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
@ -93,7 +92,7 @@ func (is *Insert) Invoke(ctx context.Context, input map[string]any) (map[string]
ConnectorID: getConnectorID(ctx), ConnectorID: getConnectorID(ctx),
} }
response, err := is.inserter.Insert(ctx, req) response, err := crossdatabase.DefaultSVC().Insert(ctx, req)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -22,7 +22,8 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
@ -114,7 +115,6 @@ func (q *QueryConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schem
outputTypes: ns.OutputTypes, outputTypes: ns.OutputTypes,
clauseGroup: q.ClauseGroup, clauseGroup: q.ClauseGroup,
limit: q.Limit, limit: q.Limit,
op: database.GetDatabaseOperator(),
}, nil }, nil
} }
@ -125,7 +125,6 @@ type Query struct {
outputTypes map[string]*vo.TypeInfo outputTypes map[string]*vo.TypeInfo
clauseGroup *database.ClauseGroup clauseGroup *database.ClauseGroup
limit int64 limit int64
op database.DatabaseOperator
} }
func (ds *Query) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) { func (ds *Query) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
@ -146,7 +145,7 @@ func (ds *Query) Invoke(ctx context.Context, in map[string]any) (map[string]any,
req.ConditionGroup = conditionGroup req.ConditionGroup = conditionGroup
response, err := ds.op.Query(ctx, req) response, err := crossdatabase.DefaultSVC().Query(ctx, req)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -26,8 +26,9 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock" crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/database/databasemock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
@ -95,10 +96,9 @@ func TestDataset_Query(t *testing.T) {
assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator) assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator)
}} }}
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator := databasemock.NewMockDatabase(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()) mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query())
crossdatabase.SetDefaultSVC(mockDatabaseOperator)
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{ ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{
@ -159,10 +159,9 @@ func TestDataset_Query(t *testing.T) {
assert.Equal(t, cGroup.Relation, cfg.ClauseGroup.Multi.Relation) assert.Equal(t, cGroup.Relation, cfg.ClauseGroup.Multi.Relation)
}} }}
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator := databasemock.NewMockDatabase(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes() mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
crossdatabase.SetDefaultSVC(mockDatabaseOperator)
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{ ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{
@ -219,10 +218,9 @@ func TestDataset_Query(t *testing.T) {
assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator) assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator)
}} }}
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator := databasemock.NewMockDatabase(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes() mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
crossdatabase.SetDefaultSVC(mockDatabaseOperator)
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{ ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{
@ -278,10 +276,9 @@ func TestDataset_Query(t *testing.T) {
assert.Equal(t, cGroup.Conditions[0].Left, cfg.ClauseGroup.Single.Left) assert.Equal(t, cGroup.Conditions[0].Left, cfg.ClauseGroup.Single.Left)
assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator) assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator)
}} }}
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator := databasemock.NewMockDatabase(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes() mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
crossdatabase.SetDefaultSVC(mockDatabaseOperator)
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{ ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{
@ -347,10 +344,9 @@ func TestDataset_Query(t *testing.T) {
assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator) assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator)
}} }}
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator := databasemock.NewMockDatabase(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes() mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
crossdatabase.SetDefaultSVC(mockDatabaseOperator)
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{ ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{
@ -425,10 +421,9 @@ func TestDataset_Query(t *testing.T) {
assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator) assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator)
}} }}
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator := databasemock.NewMockDatabase(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes() mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
crossdatabase.SetDefaultSVC(mockDatabaseOperator)
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{ ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{

View File

@ -22,7 +22,8 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
@ -89,7 +90,6 @@ func (u *UpdateConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...sche
databaseInfoID: u.DatabaseInfoID, databaseInfoID: u.DatabaseInfoID,
clauseGroup: u.ClauseGroup, clauseGroup: u.ClauseGroup,
outputTypes: ns.OutputTypes, outputTypes: ns.OutputTypes,
updater: database.GetDatabaseOperator(),
}, nil }, nil
} }
@ -97,7 +97,6 @@ type Update struct {
databaseInfoID int64 databaseInfoID int64
clauseGroup *database.ClauseGroup clauseGroup *database.ClauseGroup
outputTypes map[string]*vo.TypeInfo outputTypes map[string]*vo.TypeInfo
updater database.DatabaseOperator
} }
type updateInventory struct { type updateInventory struct {
@ -126,7 +125,7 @@ func (u *Update) Invoke(ctx context.Context, in map[string]any) (map[string]any,
ConnectorID: getConnectorID(ctx), ConnectorID: getConnectorID(ctx),
} }
response, err := u.updater.Update(ctx, req) response, err := crossdatabase.DefaultSVC().Update(ctx, req)
if err != nil { if err != nil {
return nil, err return nil, err