refactor(workflow): Move the database component in the Workflow package into the common crossdomain package (#704)
This commit is contained in:
parent
e7011f2549
commit
9ff065cebd
|
|
@ -59,7 +59,10 @@ import (
|
|||
appplugin "github.com/coze-dev/coze-studio/backend/application/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/application/user"
|
||||
appworkflow "github.com/coze-dev/coze-studio/backend/application/workflow"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/database/databasemock"
|
||||
crossuser "github.com/coze-dev/coze-studio/backend/crossdomain/contract/user"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/impl/code"
|
||||
plugin3 "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/plugin"
|
||||
entity4 "github.com/coze-dev/coze-studio/backend/domain/memory/database/entity"
|
||||
entity2 "github.com/coze-dev/coze-studio/backend/domain/openauth/openapiauth/entity"
|
||||
|
|
@ -67,9 +70,6 @@ import (
|
|||
entity5 "github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
userentity "github.com/coze-dev/coze-studio/backend/domain/user/entity"
|
||||
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge/knowledgemock"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
|
|
@ -122,7 +122,7 @@ type wfTestRunner struct {
|
|||
plugin *mockPlugin.MockPluginService
|
||||
tos *storageMock.MockStorage
|
||||
knowledge *knowledgemock.MockKnowledgeOperator
|
||||
database *databasemock.MockDatabaseOperator
|
||||
database *databasemock.MockDatabase
|
||||
pluginSrv *pluginmock.MockService
|
||||
internalModel *testutil.UTChatModel
|
||||
ctx context.Context
|
||||
|
|
@ -291,8 +291,8 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
|
|||
m4 := mockey.Mock(ctxutil.MustGetUIDFromCtx).Return(int64(1)).Build()
|
||||
m5 := mockey.Mock(ctxutil.GetUIDFromCtx).Return(ptr.Of(int64(1))).Build()
|
||||
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
database.SetDatabaseOperator(mockDatabaseOperator)
|
||||
mockDatabaseOperator := databasemock.NewMockDatabase(ctrl)
|
||||
crossdatabase.SetDefaultSVC(mockDatabaseOperator)
|
||||
|
||||
mockPluginSrv := pluginmock.NewMockService(ctrl)
|
||||
plugin.SetPluginService(mockPluginSrv)
|
||||
|
|
|
|||
|
|
@ -204,3 +204,110 @@ type GetAllDatabaseByAppIDRequest struct {
|
|||
type GetAllDatabaseByAppIDResponse struct {
|
||||
Databases []*Database // online databases
|
||||
}
|
||||
|
||||
type SQLParam struct {
|
||||
Value string
|
||||
IsNull bool
|
||||
}
|
||||
type CustomSQLRequest struct {
|
||||
DatabaseInfoID int64
|
||||
SQL string
|
||||
Params []SQLParam
|
||||
IsDebugRun bool
|
||||
UserID string
|
||||
ConnectorID int64
|
||||
}
|
||||
|
||||
type Object = map[string]any
|
||||
|
||||
type Response struct {
|
||||
RowNumber *int64
|
||||
Objects []Object
|
||||
}
|
||||
|
||||
type Operator string
|
||||
type ClauseRelation string
|
||||
|
||||
const (
|
||||
ClauseRelationAND ClauseRelation = "and"
|
||||
ClauseRelationOR ClauseRelation = "or"
|
||||
)
|
||||
|
||||
const (
|
||||
OperatorEqual Operator = "="
|
||||
OperatorNotEqual Operator = "!="
|
||||
OperatorGreater Operator = ">"
|
||||
OperatorLesser Operator = "<"
|
||||
OperatorGreaterOrEqual Operator = ">="
|
||||
OperatorLesserOrEqual Operator = "<="
|
||||
OperatorIn Operator = "in"
|
||||
OperatorNotIn Operator = "not_in"
|
||||
OperatorIsNull Operator = "is_null"
|
||||
OperatorIsNotNull Operator = "is_not_null"
|
||||
OperatorLike Operator = "like"
|
||||
OperatorNotLike Operator = "not_like"
|
||||
)
|
||||
|
||||
type ClauseGroup struct {
|
||||
Single *Clause
|
||||
Multi *MultiClause
|
||||
}
|
||||
type Clause struct {
|
||||
Left string
|
||||
Operator Operator
|
||||
}
|
||||
type MultiClause struct {
|
||||
Clauses []*Clause
|
||||
Relation ClauseRelation
|
||||
}
|
||||
|
||||
type ConditionStr struct {
|
||||
Left string
|
||||
Operator Operator
|
||||
Right any
|
||||
}
|
||||
|
||||
type ConditionGroup struct {
|
||||
Conditions []*ConditionStr
|
||||
Relation ClauseRelation
|
||||
}
|
||||
|
||||
type DeleteRequest struct {
|
||||
DatabaseInfoID int64
|
||||
ConditionGroup *ConditionGroup
|
||||
IsDebugRun bool
|
||||
UserID string
|
||||
ConnectorID int64
|
||||
}
|
||||
|
||||
type QueryRequest struct {
|
||||
DatabaseInfoID int64
|
||||
SelectFields []string
|
||||
Limit int64
|
||||
ConditionGroup *ConditionGroup
|
||||
OrderClauses []*OrderClause
|
||||
IsDebugRun bool
|
||||
UserID string
|
||||
ConnectorID int64
|
||||
}
|
||||
|
||||
type OrderClause struct {
|
||||
FieldID string
|
||||
IsAsc bool
|
||||
}
|
||||
type UpdateRequest struct {
|
||||
DatabaseInfoID int64
|
||||
ConditionGroup *ConditionGroup
|
||||
Fields map[string]any
|
||||
IsDebugRun bool
|
||||
UserID string
|
||||
ConnectorID int64
|
||||
}
|
||||
|
||||
type InsertRequest struct {
|
||||
DatabaseInfoID int64
|
||||
Fields map[string]any
|
||||
IsDebugRun bool
|
||||
UserID string
|
||||
ConnectorID int64
|
||||
}
|
||||
|
|
|
|||
|
|
@ -126,7 +126,6 @@ func Init(ctx context.Context) (err error) {
|
|||
if err != nil {
|
||||
return fmt.Errorf("Init - initVitalServices failed, err: %v", err)
|
||||
}
|
||||
|
||||
crossconnector.SetDefaultSVC(connectorImpl.InitDomainService(basicServices.connectorSVC.DomainSVC))
|
||||
crossdatabase.SetDefaultSVC(databaseImpl.InitDomainService(primaryServices.memorySVC.DatabaseDomainSVC))
|
||||
crossknowledge.SetDefaultSVC(knowledgeImpl.InitDomainService(primaryServices.knowledgeSVC.DomainSVC))
|
||||
|
|
|
|||
|
|
@ -19,14 +19,12 @@ package workflow
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/application/internal"
|
||||
|
||||
wfdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/database"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/impl/code"
|
||||
wfknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/knowledge"
|
||||
wfmodel "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/model"
|
||||
wfplugin "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/plugin"
|
||||
|
|
@ -38,8 +36,6 @@ import (
|
|||
plugin "github.com/coze-dev/coze-studio/backend/domain/plugin/service"
|
||||
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
crosscode "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
crossplugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
|
||||
|
|
@ -88,13 +84,12 @@ func InitService(ctx context.Context, components *ServiceComponents) (*Applicati
|
|||
workflow.SetRepository(workflowRepo)
|
||||
|
||||
workflowDomainSVC := service.NewWorkflowService(workflowRepo)
|
||||
crossdatabase.SetDatabaseOperator(wfdatabase.NewDatabaseRepository(components.DatabaseDomainSVC))
|
||||
crossvariable.SetVariableHandler(variable.NewVariableHandler(components.VariablesDomainSVC))
|
||||
crossvariable.SetVariablesMetaGetter(variable.NewVariablesMetaGetter(components.VariablesDomainSVC))
|
||||
crossplugin.SetPluginService(wfplugin.NewPluginService(components.PluginDomainSVC, components.Tos))
|
||||
crossknowledge.SetKnowledgeOperator(wfknowledge.NewKnowledgeRepository(components.KnowledgeDomainSVC, components.IDGen))
|
||||
crossmodel.SetManager(wfmodel.NewModelManager(components.ModelManager, nil))
|
||||
crosscode.SetCodeRunner(components.CodeRunner)
|
||||
code.SetCodeRunner(components.CodeRunner)
|
||||
crosssearch.SetNotifier(wfsearch.NewNotify(components.DomainNotifier))
|
||||
callbacks.AppendGlobalHandlers(workflowservice.GetTokenCallbackHandler())
|
||||
|
||||
|
|
|
|||
|
|
@ -25,9 +25,8 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
xmaps "golang.org/x/exp/maps"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
xmaps "golang.org/x/exp/maps"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination databasemock/database_mock.go --package databasemock -source database.go
|
||||
type Database interface {
|
||||
ExecuteSQL(ctx context.Context, req *database.ExecuteSQLRequest) (*database.ExecuteSQLResponse, error)
|
||||
PublishDatabase(ctx context.Context, req *database.PublishDatabaseRequest) (resp *database.PublishDatabaseResponse, err error)
|
||||
|
|
@ -30,6 +31,12 @@ type Database interface {
|
|||
UnBindDatabase(ctx context.Context, req *database.UnBindDatabaseToAgentRequest) error
|
||||
MGetDatabase(ctx context.Context, req *database.MGetDatabaseRequest) (*database.MGetDatabaseResponse, error)
|
||||
GetAllDatabaseByAppID(ctx context.Context, req *database.GetAllDatabaseByAppIDRequest) (*database.GetAllDatabaseByAppIDResponse, error)
|
||||
|
||||
Execute(ctx context.Context, request *database.CustomSQLRequest) (*database.Response, error)
|
||||
Query(ctx context.Context, request *database.QueryRequest) (*database.Response, error)
|
||||
Update(context.Context, *database.UpdateRequest) (*database.Response, error)
|
||||
Insert(ctx context.Context, request *database.InsertRequest) (*database.Response, error)
|
||||
Delete(context.Context, *database.DeleteRequest) (*database.Response, error)
|
||||
}
|
||||
|
||||
var defaultSVC Database
|
||||
|
|
|
|||
|
|
@ -0,0 +1,219 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: database.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination databasemock/database_mock.go --package databasemock -source database.go
|
||||
//
|
||||
|
||||
// Package databasemock is a generated GoMock package.
|
||||
package databasemock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
database "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockDatabase is a mock of Database interface.
|
||||
type MockDatabase struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockDatabaseMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockDatabaseMockRecorder is the mock recorder for MockDatabase.
|
||||
type MockDatabaseMockRecorder struct {
|
||||
mock *MockDatabase
|
||||
}
|
||||
|
||||
// NewMockDatabase creates a new mock instance.
|
||||
func NewMockDatabase(ctrl *gomock.Controller) *MockDatabase {
|
||||
mock := &MockDatabase{ctrl: ctrl}
|
||||
mock.recorder = &MockDatabaseMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockDatabase) EXPECT() *MockDatabaseMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// BindDatabase mocks base method.
|
||||
func (m *MockDatabase) BindDatabase(ctx context.Context, req *database.BindDatabaseToAgentRequest) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BindDatabase", ctx, req)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// BindDatabase indicates an expected call of BindDatabase.
|
||||
func (mr *MockDatabaseMockRecorder) BindDatabase(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BindDatabase", reflect.TypeOf((*MockDatabase)(nil).BindDatabase), ctx, req)
|
||||
}
|
||||
|
||||
// Delete mocks base method.
|
||||
func (m *MockDatabase) Delete(arg0 context.Context, arg1 *database.DeleteRequest) (*database.Response, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Delete", arg0, arg1)
|
||||
ret0, _ := ret[0].(*database.Response)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete.
|
||||
func (mr *MockDatabaseMockRecorder) Delete(arg0, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockDatabase)(nil).Delete), arg0, arg1)
|
||||
}
|
||||
|
||||
// DeleteDatabase mocks base method.
|
||||
func (m *MockDatabase) DeleteDatabase(ctx context.Context, req *database.DeleteDatabaseRequest) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteDatabase", ctx, req)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteDatabase indicates an expected call of DeleteDatabase.
|
||||
func (mr *MockDatabaseMockRecorder) DeleteDatabase(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteDatabase", reflect.TypeOf((*MockDatabase)(nil).DeleteDatabase), ctx, req)
|
||||
}
|
||||
|
||||
// Execute mocks base method.
|
||||
func (m *MockDatabase) Execute(ctx context.Context, request *database.CustomSQLRequest) (*database.Response, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Execute", ctx, request)
|
||||
ret0, _ := ret[0].(*database.Response)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Execute indicates an expected call of Execute.
|
||||
func (mr *MockDatabaseMockRecorder) Execute(ctx, request any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Execute", reflect.TypeOf((*MockDatabase)(nil).Execute), ctx, request)
|
||||
}
|
||||
|
||||
// ExecuteSQL mocks base method.
|
||||
func (m *MockDatabase) ExecuteSQL(ctx context.Context, req *database.ExecuteSQLRequest) (*database.ExecuteSQLResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ExecuteSQL", ctx, req)
|
||||
ret0, _ := ret[0].(*database.ExecuteSQLResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ExecuteSQL indicates an expected call of ExecuteSQL.
|
||||
func (mr *MockDatabaseMockRecorder) ExecuteSQL(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecuteSQL", reflect.TypeOf((*MockDatabase)(nil).ExecuteSQL), ctx, req)
|
||||
}
|
||||
|
||||
// GetAllDatabaseByAppID mocks base method.
|
||||
func (m *MockDatabase) GetAllDatabaseByAppID(ctx context.Context, req *database.GetAllDatabaseByAppIDRequest) (*database.GetAllDatabaseByAppIDResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAllDatabaseByAppID", ctx, req)
|
||||
ret0, _ := ret[0].(*database.GetAllDatabaseByAppIDResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAllDatabaseByAppID indicates an expected call of GetAllDatabaseByAppID.
|
||||
func (mr *MockDatabaseMockRecorder) GetAllDatabaseByAppID(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllDatabaseByAppID", reflect.TypeOf((*MockDatabase)(nil).GetAllDatabaseByAppID), ctx, req)
|
||||
}
|
||||
|
||||
// Insert mocks base method.
|
||||
func (m *MockDatabase) Insert(ctx context.Context, request *database.InsertRequest) (*database.Response, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Insert", ctx, request)
|
||||
ret0, _ := ret[0].(*database.Response)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Insert indicates an expected call of Insert.
|
||||
func (mr *MockDatabaseMockRecorder) Insert(ctx, request any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockDatabase)(nil).Insert), ctx, request)
|
||||
}
|
||||
|
||||
// MGetDatabase mocks base method.
|
||||
func (m *MockDatabase) MGetDatabase(ctx context.Context, req *database.MGetDatabaseRequest) (*database.MGetDatabaseResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MGetDatabase", ctx, req)
|
||||
ret0, _ := ret[0].(*database.MGetDatabaseResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// MGetDatabase indicates an expected call of MGetDatabase.
|
||||
func (mr *MockDatabaseMockRecorder) MGetDatabase(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetDatabase", reflect.TypeOf((*MockDatabase)(nil).MGetDatabase), ctx, req)
|
||||
}
|
||||
|
||||
// PublishDatabase mocks base method.
|
||||
func (m *MockDatabase) PublishDatabase(ctx context.Context, req *database.PublishDatabaseRequest) (*database.PublishDatabaseResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "PublishDatabase", ctx, req)
|
||||
ret0, _ := ret[0].(*database.PublishDatabaseResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// PublishDatabase indicates an expected call of PublishDatabase.
|
||||
func (mr *MockDatabaseMockRecorder) PublishDatabase(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishDatabase", reflect.TypeOf((*MockDatabase)(nil).PublishDatabase), ctx, req)
|
||||
}
|
||||
|
||||
// Query mocks base method.
|
||||
func (m *MockDatabase) Query(ctx context.Context, request *database.QueryRequest) (*database.Response, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Query", ctx, request)
|
||||
ret0, _ := ret[0].(*database.Response)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Query indicates an expected call of Query.
|
||||
func (mr *MockDatabaseMockRecorder) Query(ctx, request any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockDatabase)(nil).Query), ctx, request)
|
||||
}
|
||||
|
||||
// UnBindDatabase mocks base method.
|
||||
func (m *MockDatabase) UnBindDatabase(ctx context.Context, req *database.UnBindDatabaseToAgentRequest) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UnBindDatabase", ctx, req)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UnBindDatabase indicates an expected call of UnBindDatabase.
|
||||
func (mr *MockDatabaseMockRecorder) UnBindDatabase(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnBindDatabase", reflect.TypeOf((*MockDatabase)(nil).UnBindDatabase), ctx, req)
|
||||
}
|
||||
|
||||
// Update mocks base method.
|
||||
func (m *MockDatabase) Update(arg0 context.Context, arg1 *database.UpdateRequest) (*database.Response, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Update", arg0, arg1)
|
||||
ret0, _ := ret[0].(*database.Response)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Update indicates an expected call of Update.
|
||||
func (mr *MockDatabaseMockRecorder) Update(arg0, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockDatabase)(nil).Update), arg0, arg1)
|
||||
}
|
||||
|
|
@ -18,10 +18,20 @@ package database
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/data/database/table"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||
database "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
)
|
||||
|
||||
var defaultSVC crossdatabase.Database
|
||||
|
|
@ -65,3 +75,406 @@ func (c *databaseImpl) MGetDatabase(ctx context.Context, req *model.MGetDatabase
|
|||
func (c *databaseImpl) GetAllDatabaseByAppID(ctx context.Context, req *model.GetAllDatabaseByAppIDRequest) (*model.GetAllDatabaseByAppIDResponse, error) {
|
||||
return c.DomainSVC.GetAllDatabaseByAppID(ctx, req)
|
||||
}
|
||||
|
||||
func (d *databaseImpl) Execute(ctx context.Context, request *model.CustomSQLRequest) (*model.Response, error) {
|
||||
var (
|
||||
err error
|
||||
databaseInfoID = request.DatabaseInfoID
|
||||
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
|
||||
)
|
||||
|
||||
if request.IsDebugRun {
|
||||
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
req := &service.ExecuteSQLRequest{
|
||||
DatabaseID: databaseInfoID,
|
||||
OperateType: model.OperateType_Custom,
|
||||
SQL: &request.SQL,
|
||||
TableType: tableType,
|
||||
UserID: request.UserID,
|
||||
ConnectorID: ptr.Of(request.ConnectorID),
|
||||
}
|
||||
|
||||
req.SQLParams = make([]*model.SQLParamVal, 0, len(request.Params))
|
||||
for i := range request.Params {
|
||||
param := request.Params[i]
|
||||
req.SQLParams = append(req.SQLParams, &model.SQLParamVal{
|
||||
ValueType: table.FieldItemType_Text,
|
||||
Value: ¶m.Value,
|
||||
ISNull: param.IsNull,
|
||||
})
|
||||
}
|
||||
response, err := d.DomainSVC.ExecuteSQL(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// if rows affected is nil use 0 instead
|
||||
if response.RowsAffected == nil {
|
||||
response.RowsAffected = ptr.Of(int64(0))
|
||||
}
|
||||
return toNodeDateBaseResponse(response), nil
|
||||
}
|
||||
|
||||
func (d *databaseImpl) Delete(ctx context.Context, request *model.DeleteRequest) (*model.Response, error) {
|
||||
var (
|
||||
err error
|
||||
databaseInfoID = request.DatabaseInfoID
|
||||
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
|
||||
)
|
||||
|
||||
if request.IsDebugRun {
|
||||
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
req := &service.ExecuteSQLRequest{
|
||||
DatabaseID: databaseInfoID,
|
||||
OperateType: model.OperateType_Delete,
|
||||
TableType: tableType,
|
||||
UserID: request.UserID,
|
||||
ConnectorID: ptr.Of(request.ConnectorID),
|
||||
}
|
||||
|
||||
if request.ConditionGroup != nil {
|
||||
req.Condition, req.SQLParams, err = buildComplexCondition(request.ConditionGroup)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
response, err := d.DomainSVC.ExecuteSQL(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toNodeDateBaseResponse(response), nil
|
||||
}
|
||||
|
||||
func (d *databaseImpl) Query(ctx context.Context, request *model.QueryRequest) (*model.Response, error) {
|
||||
var (
|
||||
err error
|
||||
databaseInfoID = request.DatabaseInfoID
|
||||
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
|
||||
)
|
||||
|
||||
if request.IsDebugRun {
|
||||
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
req := &service.ExecuteSQLRequest{
|
||||
DatabaseID: databaseInfoID,
|
||||
OperateType: model.OperateType_Select,
|
||||
TableType: tableType,
|
||||
UserID: request.UserID,
|
||||
ConnectorID: ptr.Of(request.ConnectorID),
|
||||
}
|
||||
|
||||
req.SelectFieldList = &model.SelectFieldList{FieldID: make([]string, 0, len(request.SelectFields))}
|
||||
for i := range request.SelectFields {
|
||||
req.SelectFieldList.FieldID = append(req.SelectFieldList.FieldID, request.SelectFields[i])
|
||||
}
|
||||
|
||||
req.OrderByList = make([]model.OrderBy, 0)
|
||||
for i := range request.OrderClauses {
|
||||
clause := request.OrderClauses[i]
|
||||
req.OrderByList = append(req.OrderByList, model.OrderBy{
|
||||
Field: clause.FieldID,
|
||||
Direction: toOrderDirection(clause.IsAsc),
|
||||
})
|
||||
}
|
||||
|
||||
if request.ConditionGroup != nil {
|
||||
req.Condition, req.SQLParams, err = buildComplexCondition(request.ConditionGroup)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
req.Limit = &limit
|
||||
|
||||
response, err := d.DomainSVC.ExecuteSQL(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toNodeDateBaseResponse(response), nil
|
||||
}
|
||||
|
||||
func (d *databaseImpl) Update(ctx context.Context, request *model.UpdateRequest) (*model.Response, error) {
|
||||
|
||||
var (
|
||||
err error
|
||||
condition *model.ComplexCondition
|
||||
params []*model.SQLParamVal
|
||||
databaseInfoID = request.DatabaseInfoID
|
||||
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
|
||||
)
|
||||
|
||||
if request.IsDebugRun {
|
||||
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
req := &service.ExecuteSQLRequest{
|
||||
DatabaseID: databaseInfoID,
|
||||
OperateType: model.OperateType_Update,
|
||||
SQLParams: make([]*model.SQLParamVal, 0),
|
||||
TableType: tableType,
|
||||
UserID: request.UserID,
|
||||
ConnectorID: ptr.Of(request.ConnectorID),
|
||||
}
|
||||
|
||||
uid := ctxutil.GetUIDFromCtx(ctx)
|
||||
if uid != nil {
|
||||
req.UserID = conv.Int64ToStr(*uid)
|
||||
}
|
||||
req.UpsertRows, req.SQLParams, err = resolveUpsertRow(request.Fields)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if request.ConditionGroup != nil {
|
||||
condition, params, err = buildComplexCondition(request.ConditionGroup)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Condition = condition
|
||||
req.SQLParams = append(req.SQLParams, params...)
|
||||
}
|
||||
|
||||
response, err := d.DomainSVC.ExecuteSQL(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toNodeDateBaseResponse(response), nil
|
||||
}
|
||||
|
||||
func (d *databaseImpl) Insert(ctx context.Context, request *model.InsertRequest) (*model.Response, error) {
|
||||
var (
|
||||
err error
|
||||
databaseInfoID = request.DatabaseInfoID
|
||||
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
|
||||
)
|
||||
|
||||
if request.IsDebugRun {
|
||||
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
req := &service.ExecuteSQLRequest{
|
||||
DatabaseID: databaseInfoID,
|
||||
OperateType: model.OperateType_Insert,
|
||||
TableType: tableType,
|
||||
UserID: request.UserID,
|
||||
ConnectorID: ptr.Of(request.ConnectorID),
|
||||
}
|
||||
|
||||
req.UpsertRows, req.SQLParams, err = resolveUpsertRow(request.Fields)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response, err := d.DomainSVC.ExecuteSQL(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toNodeDateBaseResponse(response), nil
|
||||
}
|
||||
|
||||
func (d *databaseImpl) getDraftTableID(ctx context.Context, onlineID int64) (int64, error) {
|
||||
resp, err := d.DomainSVC.GetDraftDatabaseByOnlineID(ctx, &service.GetDraftDatabaseByOnlineIDRequest{OnlineID: onlineID})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return resp.Database.ID, nil
|
||||
}
|
||||
|
||||
func buildComplexCondition(conditionGroup *model.ConditionGroup) (*model.ComplexCondition, []*model.SQLParamVal, error) {
|
||||
condition := &model.ComplexCondition{}
|
||||
logic, err := toLogic(conditionGroup.Relation)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
condition.Logic = logic
|
||||
|
||||
params := make([]*model.SQLParamVal, 0)
|
||||
for i := range conditionGroup.Conditions {
|
||||
var (
|
||||
nCond = conditionGroup.Conditions[i]
|
||||
vals []*model.SQLParamVal
|
||||
dCond = &model.Condition{
|
||||
Left: nCond.Left,
|
||||
}
|
||||
)
|
||||
opt, err := toOperation(nCond.Operator)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
dCond.Operation = opt
|
||||
|
||||
if isNullOrNotNull(opt) {
|
||||
condition.Conditions = append(condition.Conditions, dCond)
|
||||
continue
|
||||
}
|
||||
dCond.Right, vals, err = resolveRightValue(opt, nCond.Right)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
condition.Conditions = append(condition.Conditions, dCond)
|
||||
|
||||
params = append(params, vals...)
|
||||
|
||||
}
|
||||
return condition, params, nil
|
||||
}
|
||||
|
||||
func toMapStringAny(m map[string]string) map[string]any {
|
||||
ret := make(map[string]any, len(m))
|
||||
for k, v := range m {
|
||||
ret[k] = v
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func toOperation(operator model.Operator) (model.Operation, error) {
|
||||
switch operator {
|
||||
case model.OperatorEqual:
|
||||
return model.Operation_EQUAL, nil
|
||||
case model.OperatorNotEqual:
|
||||
return model.Operation_NOT_EQUAL, nil
|
||||
case model.OperatorGreater:
|
||||
return model.Operation_GREATER_THAN, nil
|
||||
case model.OperatorGreaterOrEqual:
|
||||
return model.Operation_GREATER_EQUAL, nil
|
||||
case model.OperatorLesser:
|
||||
return model.Operation_LESS_THAN, nil
|
||||
case model.OperatorLesserOrEqual:
|
||||
return model.Operation_LESS_EQUAL, nil
|
||||
case model.OperatorIn:
|
||||
return model.Operation_IN, nil
|
||||
case model.OperatorNotIn:
|
||||
return model.Operation_NOT_IN, nil
|
||||
case model.OperatorIsNotNull:
|
||||
return model.Operation_IS_NOT_NULL, nil
|
||||
case model.OperatorIsNull:
|
||||
return model.Operation_IS_NULL, nil
|
||||
case model.OperatorLike:
|
||||
return model.Operation_LIKE, nil
|
||||
case model.OperatorNotLike:
|
||||
return model.Operation_NOT_LIKE, nil
|
||||
default:
|
||||
return model.Operation(0), fmt.Errorf("invalid operator %v", operator)
|
||||
}
|
||||
}
|
||||
|
||||
func resolveRightValue(operator model.Operation, right any) (string, []*model.SQLParamVal, error) {
|
||||
|
||||
if isInOrNotIn(operator) {
|
||||
var (
|
||||
vals = make([]*model.SQLParamVal, 0)
|
||||
anyVals = make([]any, 0)
|
||||
commas = make([]string, 0, len(anyVals))
|
||||
)
|
||||
|
||||
anyVals = right.([]any)
|
||||
for i := range anyVals {
|
||||
v := cast.ToString(anyVals[i])
|
||||
vals = append(vals, &model.SQLParamVal{ValueType: table.FieldItemType_Text, Value: &v})
|
||||
commas = append(commas, "?")
|
||||
}
|
||||
value := "(" + strings.Join(commas, ",") + ")"
|
||||
return value, vals, nil
|
||||
}
|
||||
|
||||
rightValue, err := cast.ToStringE(right)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if isLikeOrNotLike(operator) {
|
||||
var (
|
||||
value = "?"
|
||||
v = "%s" + rightValue + "%s"
|
||||
)
|
||||
return value, []*model.SQLParamVal{{ValueType: table.FieldItemType_Text, Value: &v}}, nil
|
||||
}
|
||||
|
||||
return "?", []*model.SQLParamVal{{ValueType: table.FieldItemType_Text, Value: &rightValue}}, nil
|
||||
}
|
||||
|
||||
func resolveUpsertRow(fields map[string]any) ([]*model.UpsertRow, []*model.SQLParamVal, error) {
|
||||
upsertRow := &model.UpsertRow{Records: make([]*model.Record, 0, len(fields))}
|
||||
params := make([]*model.SQLParamVal, 0)
|
||||
for key, value := range fields {
|
||||
val, err := cast.ToStringE(value)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
record := &model.Record{
|
||||
FieldId: key,
|
||||
FieldValue: "?",
|
||||
}
|
||||
upsertRow.Records = append(upsertRow.Records, record)
|
||||
params = append(params, &model.SQLParamVal{
|
||||
ValueType: table.FieldItemType_Text,
|
||||
Value: &val,
|
||||
})
|
||||
}
|
||||
return []*model.UpsertRow{upsertRow}, params, nil
|
||||
}
|
||||
|
||||
func isNullOrNotNull(opt model.Operation) bool {
|
||||
return opt == model.Operation_IS_NOT_NULL || opt == model.Operation_IS_NULL
|
||||
}
|
||||
|
||||
func isLikeOrNotLike(opt model.Operation) bool {
|
||||
return opt == model.Operation_LIKE || opt == model.Operation_NOT_LIKE
|
||||
}
|
||||
|
||||
func isInOrNotIn(opt model.Operation) bool {
|
||||
return opt == model.Operation_IN || opt == model.Operation_NOT_IN
|
||||
}
|
||||
|
||||
func toOrderDirection(isAsc bool) table.SortDirection {
|
||||
if isAsc {
|
||||
return table.SortDirection_ASC
|
||||
}
|
||||
return table.SortDirection_Desc
|
||||
}
|
||||
|
||||
func toLogic(relation model.ClauseRelation) (model.Logic, error) {
|
||||
switch relation {
|
||||
case model.ClauseRelationOR:
|
||||
return model.Logic_Or, nil
|
||||
case model.ClauseRelationAND:
|
||||
return model.Logic_And, nil
|
||||
default:
|
||||
return model.Logic(0), fmt.Errorf("invalid relation %v", relation)
|
||||
}
|
||||
}
|
||||
|
||||
func toNodeDateBaseResponse(response *service.ExecuteSQLResponse) *model.Response {
|
||||
objects := make([]model.Object, 0, len(response.Records))
|
||||
for i := range response.Records {
|
||||
objects = append(objects, response.Records[i])
|
||||
}
|
||||
return &model.Response{
|
||||
Objects: objects,
|
||||
RowNumber: response.RowsAffected,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,447 +0,0 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/data/database/table"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||
nodedatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
)
|
||||
|
||||
type DatabaseRepository struct {
|
||||
client service.Database
|
||||
}
|
||||
|
||||
func NewDatabaseRepository(client service.Database) *DatabaseRepository {
|
||||
return &DatabaseRepository{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DatabaseRepository) Execute(ctx context.Context, request *nodedatabase.CustomSQLRequest) (*nodedatabase.Response, error) {
|
||||
var (
|
||||
err error
|
||||
databaseInfoID = request.DatabaseInfoID
|
||||
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
|
||||
)
|
||||
|
||||
if request.IsDebugRun {
|
||||
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
req := &service.ExecuteSQLRequest{
|
||||
DatabaseID: databaseInfoID,
|
||||
OperateType: database.OperateType_Custom,
|
||||
SQL: &request.SQL,
|
||||
TableType: tableType,
|
||||
UserID: request.UserID,
|
||||
ConnectorID: ptr.Of(request.ConnectorID),
|
||||
}
|
||||
|
||||
req.SQLParams = make([]*database.SQLParamVal, 0, len(request.Params))
|
||||
for i := range request.Params {
|
||||
param := request.Params[i]
|
||||
req.SQLParams = append(req.SQLParams, &database.SQLParamVal{
|
||||
ValueType: table.FieldItemType_Text,
|
||||
Value: ¶m.Value,
|
||||
ISNull: param.IsNull,
|
||||
})
|
||||
}
|
||||
response, err := d.client.ExecuteSQL(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// if rows affected is nil use 0 instead
|
||||
if response.RowsAffected == nil {
|
||||
response.RowsAffected = ptr.Of(int64(0))
|
||||
}
|
||||
return toNodeDateBaseResponse(response), nil
|
||||
}
|
||||
|
||||
func (d *DatabaseRepository) Delete(ctx context.Context, request *nodedatabase.DeleteRequest) (*nodedatabase.Response, error) {
|
||||
var (
|
||||
err error
|
||||
databaseInfoID = request.DatabaseInfoID
|
||||
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
|
||||
)
|
||||
|
||||
if request.IsDebugRun {
|
||||
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
req := &service.ExecuteSQLRequest{
|
||||
DatabaseID: databaseInfoID,
|
||||
OperateType: database.OperateType_Delete,
|
||||
TableType: tableType,
|
||||
UserID: request.UserID,
|
||||
ConnectorID: ptr.Of(request.ConnectorID),
|
||||
}
|
||||
|
||||
if request.ConditionGroup != nil {
|
||||
req.Condition, req.SQLParams, err = buildComplexCondition(request.ConditionGroup)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
response, err := d.client.ExecuteSQL(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toNodeDateBaseResponse(response), nil
|
||||
}
|
||||
|
||||
func (d *DatabaseRepository) Query(ctx context.Context, request *nodedatabase.QueryRequest) (*nodedatabase.Response, error) {
|
||||
var (
|
||||
err error
|
||||
databaseInfoID = request.DatabaseInfoID
|
||||
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
|
||||
)
|
||||
|
||||
if request.IsDebugRun {
|
||||
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
req := &service.ExecuteSQLRequest{
|
||||
DatabaseID: databaseInfoID,
|
||||
OperateType: database.OperateType_Select,
|
||||
TableType: tableType,
|
||||
UserID: request.UserID,
|
||||
ConnectorID: ptr.Of(request.ConnectorID),
|
||||
}
|
||||
|
||||
req.SelectFieldList = &database.SelectFieldList{FieldID: make([]string, 0, len(request.SelectFields))}
|
||||
for i := range request.SelectFields {
|
||||
req.SelectFieldList.FieldID = append(req.SelectFieldList.FieldID, request.SelectFields[i])
|
||||
}
|
||||
|
||||
req.OrderByList = make([]database.OrderBy, 0)
|
||||
for i := range request.OrderClauses {
|
||||
clause := request.OrderClauses[i]
|
||||
req.OrderByList = append(req.OrderByList, database.OrderBy{
|
||||
Field: clause.FieldID,
|
||||
Direction: toOrderDirection(clause.IsAsc),
|
||||
})
|
||||
}
|
||||
|
||||
if request.ConditionGroup != nil {
|
||||
req.Condition, req.SQLParams, err = buildComplexCondition(request.ConditionGroup)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
req.Limit = &limit
|
||||
|
||||
response, err := d.client.ExecuteSQL(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toNodeDateBaseResponse(response), nil
|
||||
}
|
||||
|
||||
func (d *DatabaseRepository) Update(ctx context.Context, request *nodedatabase.UpdateRequest) (*nodedatabase.Response, error) {
|
||||
|
||||
var (
|
||||
err error
|
||||
condition *database.ComplexCondition
|
||||
params []*database.SQLParamVal
|
||||
databaseInfoID = request.DatabaseInfoID
|
||||
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
|
||||
)
|
||||
|
||||
if request.IsDebugRun {
|
||||
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
req := &service.ExecuteSQLRequest{
|
||||
DatabaseID: databaseInfoID,
|
||||
OperateType: database.OperateType_Update,
|
||||
SQLParams: make([]*database.SQLParamVal, 0),
|
||||
TableType: tableType,
|
||||
UserID: request.UserID,
|
||||
ConnectorID: ptr.Of(request.ConnectorID),
|
||||
}
|
||||
|
||||
uid := ctxutil.GetUIDFromCtx(ctx)
|
||||
if uid != nil {
|
||||
req.UserID = conv.Int64ToStr(*uid)
|
||||
}
|
||||
req.UpsertRows, req.SQLParams, err = resolveUpsertRow(request.Fields)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if request.ConditionGroup != nil {
|
||||
condition, params, err = buildComplexCondition(request.ConditionGroup)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Condition = condition
|
||||
req.SQLParams = append(req.SQLParams, params...)
|
||||
}
|
||||
|
||||
response, err := d.client.ExecuteSQL(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toNodeDateBaseResponse(response), nil
|
||||
}
|
||||
|
||||
func (d *DatabaseRepository) Insert(ctx context.Context, request *nodedatabase.InsertRequest) (*nodedatabase.Response, error) {
|
||||
var (
|
||||
err error
|
||||
databaseInfoID = request.DatabaseInfoID
|
||||
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
|
||||
)
|
||||
|
||||
if request.IsDebugRun {
|
||||
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
req := &service.ExecuteSQLRequest{
|
||||
DatabaseID: databaseInfoID,
|
||||
OperateType: database.OperateType_Insert,
|
||||
TableType: tableType,
|
||||
UserID: request.UserID,
|
||||
ConnectorID: ptr.Of(request.ConnectorID),
|
||||
}
|
||||
|
||||
req.UpsertRows, req.SQLParams, err = resolveUpsertRow(request.Fields)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response, err := d.client.ExecuteSQL(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toNodeDateBaseResponse(response), nil
|
||||
}
|
||||
|
||||
func (d *DatabaseRepository) getDraftTableID(ctx context.Context, onlineID int64) (int64, error) {
|
||||
resp, err := d.client.GetDraftDatabaseByOnlineID(ctx, &service.GetDraftDatabaseByOnlineIDRequest{OnlineID: onlineID})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return resp.Database.ID, nil
|
||||
|
||||
}
|
||||
|
||||
func buildComplexCondition(conditionGroup *nodedatabase.ConditionGroup) (*database.ComplexCondition, []*database.SQLParamVal, error) {
|
||||
condition := &database.ComplexCondition{}
|
||||
logic, err := toLogic(conditionGroup.Relation)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
condition.Logic = logic
|
||||
|
||||
params := make([]*database.SQLParamVal, 0)
|
||||
for i := range conditionGroup.Conditions {
|
||||
var (
|
||||
nCond = conditionGroup.Conditions[i]
|
||||
vals []*database.SQLParamVal
|
||||
dCond = &database.Condition{
|
||||
Left: nCond.Left,
|
||||
}
|
||||
)
|
||||
opt, err := toOperation(nCond.Operator)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
dCond.Operation = opt
|
||||
|
||||
if isNullOrNotNull(opt) {
|
||||
condition.Conditions = append(condition.Conditions, dCond)
|
||||
continue
|
||||
}
|
||||
dCond.Right, vals, err = resolveRightValue(opt, nCond.Right)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
condition.Conditions = append(condition.Conditions, dCond)
|
||||
|
||||
params = append(params, vals...)
|
||||
|
||||
}
|
||||
return condition, params, nil
|
||||
}
|
||||
|
||||
func toMapStringAny(m map[string]string) map[string]any {
|
||||
ret := make(map[string]any, len(m))
|
||||
for k, v := range m {
|
||||
ret[k] = v
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func toOperation(operator nodedatabase.Operator) (database.Operation, error) {
|
||||
switch operator {
|
||||
case nodedatabase.OperatorEqual:
|
||||
return database.Operation_EQUAL, nil
|
||||
case nodedatabase.OperatorNotEqual:
|
||||
return database.Operation_NOT_EQUAL, nil
|
||||
case nodedatabase.OperatorGreater:
|
||||
return database.Operation_GREATER_THAN, nil
|
||||
case nodedatabase.OperatorGreaterOrEqual:
|
||||
return database.Operation_GREATER_EQUAL, nil
|
||||
case nodedatabase.OperatorLesser:
|
||||
return database.Operation_LESS_THAN, nil
|
||||
case nodedatabase.OperatorLesserOrEqual:
|
||||
return database.Operation_LESS_EQUAL, nil
|
||||
case nodedatabase.OperatorIn:
|
||||
return database.Operation_IN, nil
|
||||
case nodedatabase.OperatorNotIn:
|
||||
return database.Operation_NOT_IN, nil
|
||||
case nodedatabase.OperatorIsNotNull:
|
||||
return database.Operation_IS_NOT_NULL, nil
|
||||
case nodedatabase.OperatorIsNull:
|
||||
return database.Operation_IS_NULL, nil
|
||||
case nodedatabase.OperatorLike:
|
||||
return database.Operation_LIKE, nil
|
||||
case nodedatabase.OperatorNotLike:
|
||||
return database.Operation_NOT_LIKE, nil
|
||||
default:
|
||||
return database.Operation(0), fmt.Errorf("invalid operator %v", operator)
|
||||
}
|
||||
}
|
||||
|
||||
func resolveRightValue(operator database.Operation, right any) (string, []*database.SQLParamVal, error) {
|
||||
|
||||
if isInOrNotIn(operator) {
|
||||
var (
|
||||
vals = make([]*database.SQLParamVal, 0)
|
||||
anyVals = make([]any, 0)
|
||||
commas = make([]string, 0, len(anyVals))
|
||||
)
|
||||
|
||||
anyVals = right.([]any)
|
||||
for i := range anyVals {
|
||||
v := cast.ToString(anyVals[i])
|
||||
vals = append(vals, &database.SQLParamVal{ValueType: table.FieldItemType_Text, Value: &v})
|
||||
commas = append(commas, "?")
|
||||
}
|
||||
value := "(" + strings.Join(commas, ",") + ")"
|
||||
return value, vals, nil
|
||||
}
|
||||
|
||||
rightValue, err := cast.ToStringE(right)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if isLikeOrNotLike(operator) {
|
||||
var (
|
||||
value = "?"
|
||||
v = "%s" + rightValue + "%s"
|
||||
)
|
||||
return value, []*database.SQLParamVal{{ValueType: table.FieldItemType_Text, Value: &v}}, nil
|
||||
}
|
||||
|
||||
return "?", []*database.SQLParamVal{{ValueType: table.FieldItemType_Text, Value: &rightValue}}, nil
|
||||
}
|
||||
|
||||
func resolveUpsertRow(fields map[string]any) ([]*database.UpsertRow, []*database.SQLParamVal, error) {
|
||||
upsertRow := &database.UpsertRow{Records: make([]*database.Record, 0, len(fields))}
|
||||
params := make([]*database.SQLParamVal, 0)
|
||||
for key, value := range fields {
|
||||
val, err := cast.ToStringE(value)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
record := &database.Record{
|
||||
FieldId: key,
|
||||
FieldValue: "?",
|
||||
}
|
||||
upsertRow.Records = append(upsertRow.Records, record)
|
||||
params = append(params, &database.SQLParamVal{
|
||||
ValueType: table.FieldItemType_Text,
|
||||
Value: &val,
|
||||
})
|
||||
}
|
||||
return []*database.UpsertRow{upsertRow}, params, nil
|
||||
}
|
||||
|
||||
func isNullOrNotNull(opt database.Operation) bool {
|
||||
return opt == database.Operation_IS_NOT_NULL || opt == database.Operation_IS_NULL
|
||||
}
|
||||
|
||||
func isLikeOrNotLike(opt database.Operation) bool {
|
||||
return opt == database.Operation_LIKE || opt == database.Operation_NOT_LIKE
|
||||
}
|
||||
|
||||
func isInOrNotIn(opt database.Operation) bool {
|
||||
return opt == database.Operation_IN || opt == database.Operation_NOT_IN
|
||||
}
|
||||
|
||||
func toOrderDirection(isAsc bool) table.SortDirection {
|
||||
if isAsc {
|
||||
return table.SortDirection_ASC
|
||||
}
|
||||
return table.SortDirection_Desc
|
||||
}
|
||||
|
||||
func toLogic(relation nodedatabase.ClauseRelation) (database.Logic, error) {
|
||||
switch relation {
|
||||
case nodedatabase.ClauseRelationOR:
|
||||
return database.Logic_Or, nil
|
||||
case nodedatabase.ClauseRelationAND:
|
||||
return database.Logic_And, nil
|
||||
default:
|
||||
return database.Logic(0), fmt.Errorf("invalid relation %v", relation)
|
||||
}
|
||||
}
|
||||
|
||||
func toNodeDateBaseResponse(response *service.ExecuteSQLResponse) *nodedatabase.Response {
|
||||
objects := make([]nodedatabase.Object, 0, len(response.Records))
|
||||
for i := range response.Records {
|
||||
objects = append(objects, response.Records[i])
|
||||
}
|
||||
return &nodedatabase.Response{
|
||||
Objects: objects,
|
||||
RowNumber: response.RowsAffected,
|
||||
}
|
||||
}
|
||||
|
|
@ -1,224 +0,0 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||
nodedatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
mockDatabase "github.com/coze-dev/coze-studio/backend/internal/mock/domain/memory/database"
|
||||
)
|
||||
|
||||
func mockExecuteSQL(t *testing.T) func(ctx context.Context, request *service.ExecuteSQLRequest) (*service.ExecuteSQLResponse, error) {
|
||||
return func(ctx context.Context, request *service.ExecuteSQLRequest) (*service.ExecuteSQLResponse, error) {
|
||||
if request.OperateType == database.OperateType_Custom {
|
||||
assert.Equal(t, *request.SQL, "select * from table where v1=? and v2=?")
|
||||
rs := make([]string, 0)
|
||||
for idx := range request.SQLParams {
|
||||
rs = append(rs, *request.SQLParams[idx].Value)
|
||||
}
|
||||
assert.Equal(t, rs, []string{"1", "2"})
|
||||
return &service.ExecuteSQLResponse{
|
||||
Records: []map[string]any{
|
||||
{"v1": "1", "v2": "2"},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
if request.OperateType == database.OperateType_Select {
|
||||
sFields := []string{"v1", "v2", "v3", "v4"}
|
||||
assert.Equal(t, request.SelectFieldList.FieldID, sFields)
|
||||
cond := request.Condition.Conditions[1] // in
|
||||
assert.Equal(t, "(?,?)", cond.Right)
|
||||
assert.Equal(t, database.Operation_IN, cond.Operation)
|
||||
assert.Equal(t, "v2_1", *request.SQLParams[1].Value)
|
||||
assert.Equal(t, "v2_2", *request.SQLParams[2].Value)
|
||||
assert.Equal(t, "%sv4%s", *request.SQLParams[3].Value)
|
||||
rowsAffected := int64(10)
|
||||
return &service.ExecuteSQLResponse{
|
||||
Records: []map[string]any{
|
||||
{"v1": "1", "v2": "2", "v3": "3", "v4": "4"},
|
||||
},
|
||||
RowsAffected: &rowsAffected,
|
||||
}, nil
|
||||
|
||||
}
|
||||
if request.OperateType == database.OperateType_Delete {
|
||||
cond := request.Condition.Conditions[1] // in
|
||||
assert.Equal(t, "(?,?)", cond.Right)
|
||||
assert.Equal(t, database.Operation_NOT_IN, cond.Operation)
|
||||
assert.Equal(t, "v2_1", *request.SQLParams[1].Value)
|
||||
assert.Equal(t, "v2_2", *request.SQLParams[2].Value)
|
||||
assert.Equal(t, "%sv4%s", *request.SQLParams[3].Value)
|
||||
rowsAffected := int64(10)
|
||||
return &service.ExecuteSQLResponse{
|
||||
Records: []map[string]any{
|
||||
{"v1": "1", "v2": "2", "v3": "3", "v4": "4"},
|
||||
},
|
||||
RowsAffected: &rowsAffected,
|
||||
}, nil
|
||||
}
|
||||
if request.OperateType == database.OperateType_Insert {
|
||||
records := request.UpsertRows[0].Records
|
||||
ret := map[string]interface{}{
|
||||
"v1": "1",
|
||||
"v2": 2,
|
||||
"v3": 33,
|
||||
"v4": "44aacc",
|
||||
}
|
||||
for idx := range records {
|
||||
assert.Equal(t, *request.SQLParams[idx].Value, cast.ToString(ret[records[idx].FieldId]))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if request.OperateType == database.OperateType_Update {
|
||||
|
||||
records := request.UpsertRows[0].Records
|
||||
ret := map[string]interface{}{
|
||||
"v1": "1",
|
||||
"v2": 2,
|
||||
"v3": 33,
|
||||
"v4": "aabbcc",
|
||||
}
|
||||
for idx := range records {
|
||||
assert.Equal(t, *request.SQLParams[idx].Value, cast.ToString(ret[records[idx].FieldId]))
|
||||
}
|
||||
|
||||
request.SQLParams = request.SQLParams[len(records):]
|
||||
cond := request.Condition.Conditions[1] // in
|
||||
assert.Equal(t, "(?,?)", cond.Right)
|
||||
assert.Equal(t, database.Operation_IN, cond.Operation)
|
||||
assert.Equal(t, "v2_1", *request.SQLParams[1].Value)
|
||||
assert.Equal(t, "v2_2", *request.SQLParams[2].Value)
|
||||
assert.Equal(t, "%sv4%s", *request.SQLParams[3].Value)
|
||||
|
||||
}
|
||||
return &service.ExecuteSQLResponse{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabase_Database(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mockClient := mockDatabase.NewMockDatabase(ctrl)
|
||||
defer ctrl.Finish()
|
||||
ds := DatabaseRepository{
|
||||
client: mockClient,
|
||||
}
|
||||
mockClient.EXPECT().ExecuteSQL(gomock.Any(), gomock.Any()).DoAndReturn(mockExecuteSQL(t)).AnyTimes()
|
||||
|
||||
t.Run("execute", func(t *testing.T) {
|
||||
response, err := ds.Execute(context.Background(), &nodedatabase.CustomSQLRequest{
|
||||
DatabaseInfoID: 1,
|
||||
SQL: "select * from table where v1=? and v2=?",
|
||||
Params: []nodedatabase.SQLParam{
|
||||
nodedatabase.SQLParam{Value: "1"},
|
||||
nodedatabase.SQLParam{Value: "2"},
|
||||
},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, response.Objects, []nodedatabase.Object{
|
||||
{"v1": "1", "v2": "2"},
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("select", func(t *testing.T) {
|
||||
req := &nodedatabase.QueryRequest{
|
||||
DatabaseInfoID: 1,
|
||||
SelectFields: []string{"v1", "v2", "v3", "v4"},
|
||||
Limit: 10,
|
||||
OrderClauses: []*nodedatabase.OrderClause{
|
||||
{FieldID: "v1", IsAsc: true},
|
||||
{FieldID: "v2", IsAsc: false},
|
||||
},
|
||||
ConditionGroup: &nodedatabase.ConditionGroup{
|
||||
Conditions: []*nodedatabase.Condition{
|
||||
{Left: "v1", Operator: nodedatabase.OperatorEqual, Right: "1"},
|
||||
{Left: "v2", Operator: nodedatabase.OperatorIn, Right: []any{"v2_1", "v2_2"}},
|
||||
{Left: "v3", Operator: nodedatabase.OperatorIsNull},
|
||||
{Left: "v4", Operator: nodedatabase.OperatorLike, Right: "v4"},
|
||||
},
|
||||
Relation: nodedatabase.ClauseRelationOR,
|
||||
},
|
||||
}
|
||||
|
||||
response, err := ds.Query(context.Background(), req)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, *response.RowNumber, int64(10))
|
||||
})
|
||||
|
||||
t.Run("delete", func(t *testing.T) {
|
||||
req := &nodedatabase.DeleteRequest{
|
||||
DatabaseInfoID: 1,
|
||||
ConditionGroup: &nodedatabase.ConditionGroup{
|
||||
Conditions: []*nodedatabase.Condition{
|
||||
{Left: "v1", Operator: nodedatabase.OperatorEqual, Right: "1"},
|
||||
{Left: "v2", Operator: nodedatabase.OperatorNotIn, Right: []any{"v2_1", "v2_2"}},
|
||||
{Left: "v3", Operator: nodedatabase.OperatorIsNotNull},
|
||||
{Left: "v4", Operator: nodedatabase.OperatorNotLike, Right: "v4"},
|
||||
},
|
||||
Relation: nodedatabase.ClauseRelationOR,
|
||||
},
|
||||
}
|
||||
response, err := ds.Delete(context.Background(), req)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, *response.RowNumber, int64(10))
|
||||
})
|
||||
|
||||
t.Run("insert", func(t *testing.T) {
|
||||
req := &nodedatabase.InsertRequest{
|
||||
DatabaseInfoID: 1,
|
||||
Fields: map[string]interface{}{
|
||||
"v1": "1",
|
||||
"v2": 2,
|
||||
"v3": 33,
|
||||
"v4": "44aacc",
|
||||
},
|
||||
}
|
||||
_, err := ds.Insert(context.Background(), req)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
|
||||
t.Run("update", func(t *testing.T) {
|
||||
req := &nodedatabase.UpdateRequest{
|
||||
DatabaseInfoID: 1,
|
||||
ConditionGroup: &nodedatabase.ConditionGroup{
|
||||
Conditions: []*nodedatabase.Condition{
|
||||
{Left: "v1", Operator: nodedatabase.OperatorEqual, Right: "1"},
|
||||
{Left: "v2", Operator: nodedatabase.OperatorIn, Right: []any{"v2_1", "v2_2"}},
|
||||
{Left: "v3", Operator: nodedatabase.OperatorIsNull},
|
||||
{Left: "v4", Operator: nodedatabase.OperatorLike, Right: "v4"},
|
||||
},
|
||||
Relation: nodedatabase.ClauseRelationOR,
|
||||
},
|
||||
Fields: map[string]interface{}{
|
||||
"v1": "1",
|
||||
"v2": 2,
|
||||
"v3": 33,
|
||||
"v4": "aabbcc",
|
||||
},
|
||||
}
|
||||
_, err := ds.Update(context.Background(), req)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
}
|
||||
|
|
@ -1,61 +0,0 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package conversation
|
||||
|
||||
import "context"
|
||||
|
||||
type ClearMessageRequest struct {
|
||||
Name string
|
||||
}
|
||||
type ClearMessageResponse struct {
|
||||
IsSuccess bool
|
||||
}
|
||||
type CreateConversationRequest struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
type CreateConversationResponse struct {
|
||||
Result map[string]any
|
||||
}
|
||||
|
||||
type ListMessageRequest struct {
|
||||
ConversationName string
|
||||
Limit *int
|
||||
BeforeID *string
|
||||
AfterID *string
|
||||
}
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
Role string `json:"role"`
|
||||
ContentType string `json:"contentType"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ListMessageResponse struct {
|
||||
Messages []*Message
|
||||
FirstID string
|
||||
LastID string
|
||||
HasMore bool
|
||||
}
|
||||
|
||||
var ConversationManagerImpl ConversationManager
|
||||
|
||||
type ConversationManager interface {
|
||||
ClearMessage(context.Context, *ClearMessageRequest) (*ClearMessageResponse, error)
|
||||
CreateConversation(ctx context.Context, c *CreateConversationRequest) (*CreateConversationResponse, error)
|
||||
MessageList(ctx context.Context, req *ListMessageRequest) (*ListMessageResponse, error)
|
||||
}
|
||||
|
|
@ -1,148 +0,0 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type SQLParam struct {
|
||||
Value string
|
||||
IsNull bool
|
||||
}
|
||||
type CustomSQLRequest struct {
|
||||
DatabaseInfoID int64
|
||||
SQL string
|
||||
Params []SQLParam
|
||||
IsDebugRun bool
|
||||
UserID string
|
||||
ConnectorID int64
|
||||
}
|
||||
|
||||
type Object = map[string]any
|
||||
|
||||
type Response struct {
|
||||
RowNumber *int64
|
||||
Objects []Object
|
||||
}
|
||||
|
||||
type Operator string
|
||||
type ClauseRelation string
|
||||
|
||||
const (
|
||||
ClauseRelationAND ClauseRelation = "and"
|
||||
ClauseRelationOR ClauseRelation = "or"
|
||||
)
|
||||
|
||||
const (
|
||||
OperatorEqual Operator = "="
|
||||
OperatorNotEqual Operator = "!="
|
||||
OperatorGreater Operator = ">"
|
||||
OperatorLesser Operator = "<"
|
||||
OperatorGreaterOrEqual Operator = ">="
|
||||
OperatorLesserOrEqual Operator = "<="
|
||||
OperatorIn Operator = "in"
|
||||
OperatorNotIn Operator = "not_in"
|
||||
OperatorIsNull Operator = "is_null"
|
||||
OperatorIsNotNull Operator = "is_not_null"
|
||||
OperatorLike Operator = "like"
|
||||
OperatorNotLike Operator = "not_like"
|
||||
)
|
||||
|
||||
type ClauseGroup struct {
|
||||
Single *Clause
|
||||
Multi *MultiClause
|
||||
}
|
||||
type Clause struct {
|
||||
Left string
|
||||
Operator Operator
|
||||
}
|
||||
type MultiClause struct {
|
||||
Clauses []*Clause
|
||||
Relation ClauseRelation
|
||||
}
|
||||
|
||||
type Condition struct {
|
||||
Left string
|
||||
Operator Operator
|
||||
Right any
|
||||
}
|
||||
|
||||
type ConditionGroup struct {
|
||||
Conditions []*Condition
|
||||
Relation ClauseRelation
|
||||
}
|
||||
|
||||
type DeleteRequest struct {
|
||||
DatabaseInfoID int64
|
||||
ConditionGroup *ConditionGroup
|
||||
IsDebugRun bool
|
||||
UserID string
|
||||
ConnectorID int64
|
||||
}
|
||||
|
||||
type QueryRequest struct {
|
||||
DatabaseInfoID int64
|
||||
SelectFields []string
|
||||
Limit int64
|
||||
ConditionGroup *ConditionGroup
|
||||
OrderClauses []*OrderClause
|
||||
IsDebugRun bool
|
||||
UserID string
|
||||
ConnectorID int64
|
||||
}
|
||||
|
||||
type OrderClause struct {
|
||||
FieldID string
|
||||
IsAsc bool
|
||||
}
|
||||
type UpdateRequest struct {
|
||||
DatabaseInfoID int64
|
||||
ConditionGroup *ConditionGroup
|
||||
Fields map[string]any
|
||||
IsDebugRun bool
|
||||
UserID string
|
||||
ConnectorID int64
|
||||
}
|
||||
|
||||
type InsertRequest struct {
|
||||
DatabaseInfoID int64
|
||||
Fields map[string]any
|
||||
IsDebugRun bool
|
||||
UserID string
|
||||
ConnectorID int64
|
||||
}
|
||||
|
||||
func GetDatabaseOperator() DatabaseOperator {
|
||||
return databaseOperatorImpl
|
||||
}
|
||||
func SetDatabaseOperator(d DatabaseOperator) {
|
||||
databaseOperatorImpl = d
|
||||
}
|
||||
|
||||
var (
|
||||
databaseOperatorImpl DatabaseOperator
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination databasemock/database_mock.go --package databasemock -source database.go
|
||||
type DatabaseOperator interface {
|
||||
Execute(ctx context.Context, request *CustomSQLRequest) (*Response, error)
|
||||
Query(ctx context.Context, request *QueryRequest) (*Response, error)
|
||||
Update(context.Context, *UpdateRequest) (*Response, error)
|
||||
Insert(ctx context.Context, request *InsertRequest) (*Response, error)
|
||||
Delete(context.Context, *DeleteRequest) (*Response, error)
|
||||
}
|
||||
|
|
@ -1,133 +0,0 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: database.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination databasemock/database_mock.go --package databasemock -source database.go
|
||||
//
|
||||
|
||||
// Package databasemock is a generated GoMock package.
|
||||
package databasemock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
database "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockDatabaseOperator is a mock of DatabaseOperator interface.
|
||||
type MockDatabaseOperator struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockDatabaseOperatorMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockDatabaseOperatorMockRecorder is the mock recorder for MockDatabaseOperator.
|
||||
type MockDatabaseOperatorMockRecorder struct {
|
||||
mock *MockDatabaseOperator
|
||||
}
|
||||
|
||||
// NewMockDatabaseOperator creates a new mock instance.
|
||||
func NewMockDatabaseOperator(ctrl *gomock.Controller) *MockDatabaseOperator {
|
||||
mock := &MockDatabaseOperator{ctrl: ctrl}
|
||||
mock.recorder = &MockDatabaseOperatorMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockDatabaseOperator) EXPECT() *MockDatabaseOperatorMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Delete mocks base method.
|
||||
func (m *MockDatabaseOperator) Delete(arg0 context.Context, arg1 *database.DeleteRequest) (*database.Response, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Delete", arg0, arg1)
|
||||
ret0, _ := ret[0].(*database.Response)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete.
|
||||
func (mr *MockDatabaseOperatorMockRecorder) Delete(arg0, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockDatabaseOperator)(nil).Delete), arg0, arg1)
|
||||
}
|
||||
|
||||
// Execute mocks base method.
|
||||
func (m *MockDatabaseOperator) Execute(ctx context.Context, request *database.CustomSQLRequest) (*database.Response, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Execute", ctx, request)
|
||||
ret0, _ := ret[0].(*database.Response)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Execute indicates an expected call of Execute.
|
||||
func (mr *MockDatabaseOperatorMockRecorder) Execute(ctx, request any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Execute", reflect.TypeOf((*MockDatabaseOperator)(nil).Execute), ctx, request)
|
||||
}
|
||||
|
||||
// Insert mocks base method.
|
||||
func (m *MockDatabaseOperator) Insert(ctx context.Context, request *database.InsertRequest) (*database.Response, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Insert", ctx, request)
|
||||
ret0, _ := ret[0].(*database.Response)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Insert indicates an expected call of Insert.
|
||||
func (mr *MockDatabaseOperatorMockRecorder) Insert(ctx, request any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockDatabaseOperator)(nil).Insert), ctx, request)
|
||||
}
|
||||
|
||||
// Query mocks base method.
|
||||
func (m *MockDatabaseOperator) Query(ctx context.Context, request *database.QueryRequest) (*database.Response, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Query", ctx, request)
|
||||
ret0, _ := ret[0].(*database.Response)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Query indicates an expected call of Query.
|
||||
func (mr *MockDatabaseOperatorMockRecorder) Query(ctx, request any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockDatabaseOperator)(nil).Query), ctx, request)
|
||||
}
|
||||
|
||||
// Update mocks base method.
|
||||
func (m *MockDatabaseOperator) Update(arg0 context.Context, arg1 *database.UpdateRequest) (*database.Response, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Update", arg0, arg1)
|
||||
ret0, _ := ret[0].(*database.Response)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Update indicates an expected call of Update.
|
||||
func (mr *MockDatabaseOperatorMockRecorder) Update(arg0, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockDatabaseOperator)(nil).Update), arg0, arg1)
|
||||
}
|
||||
|
|
@ -32,13 +32,12 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
|
||||
|
||||
crossmodel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/database/databasemock"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/impl/code"
|
||||
userentity "github.com/coze-dev/coze-studio/backend/domain/user/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge/knowledgemock"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
|
|
@ -50,6 +49,7 @@ import (
|
|||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
|
||||
mockWorkflow "github.com/coze-dev/coze-studio/backend/internal/mock/domain/workflow"
|
||||
mockcode "github.com/coze-dev/coze-studio/backend/internal/mock/domain/workflow/crossdomain/code"
|
||||
"github.com/coze-dev/coze-studio/backend/internal/testutil"
|
||||
|
|
@ -105,10 +105,10 @@ func TestIntentDetectorAndDatabase(t *testing.T) {
|
|||
}
|
||||
mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).AnyTimes()
|
||||
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator := databasemock.NewMockDatabase(ctrl)
|
||||
n := int64(2)
|
||||
resp := &crossdatabase.Response{
|
||||
Objects: []crossdatabase.Object{
|
||||
resp := &crossmodel.Response{
|
||||
Objects: []crossmodel.Object{
|
||||
{
|
||||
"v2": "123",
|
||||
},
|
||||
|
|
@ -119,7 +119,7 @@ func TestIntentDetectorAndDatabase(t *testing.T) {
|
|||
RowNumber: &n,
|
||||
}
|
||||
mockDatabaseOperator.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(resp, nil).AnyTimes()
|
||||
mockey.Mock(crossdatabase.GetDatabaseOperator).Return(mockDatabaseOperator).Build()
|
||||
crossdatabase.SetDefaultSVC(mockDatabaseOperator)
|
||||
|
||||
workflowSC, err := CanvasToWorkflowSchema(ctx, c)
|
||||
assert.NoError(t, err)
|
||||
|
|
@ -144,44 +144,44 @@ func TestIntentDetectorAndDatabase(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func mockUpdate(t *testing.T) func(context.Context, *crossdatabase.UpdateRequest) (*crossdatabase.Response, error) {
|
||||
return func(ctx context.Context, req *crossdatabase.UpdateRequest) (*crossdatabase.Response, error) {
|
||||
assert.Equal(t, req.ConditionGroup.Conditions[0], &crossdatabase.Condition{
|
||||
func mockUpdate(t *testing.T) func(context.Context, *crossmodel.UpdateRequest) (*crossmodel.Response, error) {
|
||||
return func(ctx context.Context, req *crossmodel.UpdateRequest) (*crossmodel.Response, error) {
|
||||
assert.Equal(t, req.ConditionGroup.Conditions[0], &crossmodel.ConditionStr{
|
||||
Left: "v2",
|
||||
Operator: "=",
|
||||
Right: int64(1),
|
||||
})
|
||||
|
||||
assert.Equal(t, req.ConditionGroup.Conditions[1], &crossdatabase.Condition{
|
||||
assert.Equal(t, req.ConditionGroup.Conditions[1], &crossmodel.ConditionStr{
|
||||
Left: "v1",
|
||||
Operator: "=",
|
||||
Right: "abc",
|
||||
})
|
||||
assert.Equal(t, req.ConditionGroup.Relation, crossdatabase.ClauseRelationAND)
|
||||
assert.Equal(t, req.ConditionGroup.Relation, crossmodel.ClauseRelationAND)
|
||||
assert.Equal(t, req.Fields, map[string]interface{}{
|
||||
"1783392627713": int64(123),
|
||||
})
|
||||
|
||||
return &crossdatabase.Response{}, nil
|
||||
return &crossmodel.Response{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func mockInsert(t *testing.T) func(ctx context.Context, request *crossdatabase.InsertRequest) (*crossdatabase.Response, error) {
|
||||
return func(ctx context.Context, req *crossdatabase.InsertRequest) (*crossdatabase.Response, error) {
|
||||
func mockInsert(t *testing.T) func(ctx context.Context, request *crossmodel.InsertRequest) (*crossmodel.Response, error) {
|
||||
return func(ctx context.Context, req *crossmodel.InsertRequest) (*crossmodel.Response, error) {
|
||||
v := req.Fields["1785960530945"]
|
||||
assert.Equal(t, v, int64(123))
|
||||
vs := req.Fields["1783122026497"]
|
||||
assert.Equal(t, vs, "input for database curd")
|
||||
n := int64(10)
|
||||
return &crossdatabase.Response{
|
||||
return &crossmodel.Response{
|
||||
RowNumber: &n,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func mockQuery(t *testing.T) func(ctx context.Context, request *crossdatabase.QueryRequest) (*crossdatabase.Response, error) {
|
||||
return func(ctx context.Context, req *crossdatabase.QueryRequest) (*crossdatabase.Response, error) {
|
||||
assert.Equal(t, req.ConditionGroup.Conditions[0], &crossdatabase.Condition{
|
||||
func mockQuery(t *testing.T) func(ctx context.Context, request *crossmodel.QueryRequest) (*crossmodel.Response, error) {
|
||||
return func(ctx context.Context, req *crossmodel.QueryRequest) (*crossmodel.Response, error) {
|
||||
assert.Equal(t, req.ConditionGroup.Conditions[0], &crossmodel.ConditionStr{
|
||||
Left: "v1",
|
||||
Operator: "=",
|
||||
Right: "abc",
|
||||
|
|
@ -191,26 +191,26 @@ func mockQuery(t *testing.T) func(ctx context.Context, request *crossdatabase.Qu
|
|||
"1783122026497", "1784288924673", "1783392627713",
|
||||
})
|
||||
n := int64(10)
|
||||
return &crossdatabase.Response{
|
||||
return &crossmodel.Response{
|
||||
RowNumber: &n,
|
||||
Objects: []crossdatabase.Object{
|
||||
Objects: []crossmodel.Object{
|
||||
{"v1": "vv"},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func mockDelete(t *testing.T) func(context.Context, *crossdatabase.DeleteRequest) (*crossdatabase.Response, error) {
|
||||
return func(ctx context.Context, req *crossdatabase.DeleteRequest) (*crossdatabase.Response, error) {
|
||||
func mockDelete(t *testing.T) func(context.Context, *crossmodel.DeleteRequest) (*crossmodel.Response, error) {
|
||||
return func(ctx context.Context, req *crossmodel.DeleteRequest) (*crossmodel.Response, error) {
|
||||
nn := int64(10)
|
||||
assert.Equal(t, req.ConditionGroup.Conditions[0], &crossdatabase.Condition{
|
||||
assert.Equal(t, req.ConditionGroup.Conditions[0], &crossmodel.ConditionStr{
|
||||
Left: "v2",
|
||||
Operator: "=",
|
||||
Right: nn,
|
||||
})
|
||||
|
||||
n := int64(1)
|
||||
return &crossdatabase.Response{
|
||||
return &crossmodel.Response{
|
||||
RowNumber: &n,
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -228,8 +228,8 @@ func TestDatabaseCURD(t *testing.T) {
|
|||
_ = ctx
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockey.Mock(crossdatabase.GetDatabaseOperator).Return(mockDatabaseOperator).Build()
|
||||
mockDatabaseOperator := databasemock.NewMockDatabase(ctrl)
|
||||
mockey.Mock(crossdatabase.DefaultSVC).Return(mockDatabaseOperator).Build()
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery(t))
|
||||
mockDatabaseOperator.EXPECT().Update(gomock.Any(), gomock.Any()).DoAndReturn(mockUpdate(t))
|
||||
mockDatabaseOperator.EXPECT().Insert(gomock.Any(), gomock.Any()).DoAndReturn(mockInsert(t))
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ import (
|
|||
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
code2 "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
|
||||
code2 "github.com/coze-dev/coze-studio/backend/crossdomain/impl/code"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
|
|
|
|||
|
|
@ -24,10 +24,9 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
|
||||
mockcode "github.com/coze-dev/coze-studio/backend/internal/mock/domain/workflow/crossdomain/code"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,64 +0,0 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
)
|
||||
|
||||
type ClearMessageConfig struct {
|
||||
Clearer conversation.ConversationManager
|
||||
}
|
||||
|
||||
type MessageClear struct {
|
||||
config *ClearMessageConfig
|
||||
}
|
||||
|
||||
func NewClearMessage(ctx context.Context, cfg *ClearMessageConfig) (*MessageClear, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
}
|
||||
if cfg.Clearer == nil {
|
||||
return nil, errors.New("clearer is required")
|
||||
}
|
||||
|
||||
return &MessageClear{
|
||||
config: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *MessageClear) Clear(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
name, ok := nodes.TakeMapValue(input, compose.FieldPath{"ConversationName"})
|
||||
if !ok {
|
||||
return nil, errors.New("input map should contains 'ConversationName' key ")
|
||||
}
|
||||
response, err := c.config.Clearer.ClearMessage(ctx, &conversation.ClearMessageRequest{
|
||||
Name: name.(string),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return map[string]any{
|
||||
"isSuccess": response.IsSuccess,
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -1,62 +0,0 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
)
|
||||
|
||||
type CreateConversationConfig struct {
|
||||
Creator conversation.ConversationManager
|
||||
}
|
||||
|
||||
type CreateConversation struct {
|
||||
config *CreateConversationConfig
|
||||
}
|
||||
|
||||
func NewCreateConversation(ctx context.Context, cfg *CreateConversationConfig) (*CreateConversation, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
}
|
||||
if cfg.Creator == nil {
|
||||
return nil, errors.New("creator is required")
|
||||
}
|
||||
return &CreateConversation{
|
||||
config: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *CreateConversation) Create(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
name, ok := nodes.TakeMapValue(input, compose.FieldPath{"ConversationName"})
|
||||
if !ok {
|
||||
return nil, errors.New("input map should contains 'ConversationName' key ")
|
||||
}
|
||||
response, err := c.config.Creator.CreateConversation(ctx, &conversation.CreateConversationRequest{
|
||||
Name: name.(string),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return response.Result, nil
|
||||
|
||||
}
|
||||
|
|
@ -1,108 +0,0 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
)
|
||||
|
||||
type MessageListConfig struct {
|
||||
Lister conversation.ConversationManager
|
||||
}
|
||||
type MessageList struct {
|
||||
config *MessageListConfig
|
||||
}
|
||||
|
||||
type Param struct {
|
||||
ConversationName string
|
||||
Limit *int
|
||||
BeforeID *string
|
||||
AfterID *string
|
||||
}
|
||||
|
||||
func NewMessageList(ctx context.Context, cfg *MessageListConfig) (*MessageList, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
}
|
||||
|
||||
if cfg.Lister == nil {
|
||||
return nil, errors.New("lister is required")
|
||||
}
|
||||
|
||||
return &MessageList{
|
||||
config: cfg,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func (m *MessageList) List(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
param := &Param{}
|
||||
name, ok := nodes.TakeMapValue(input, compose.FieldPath{"ConversationName"})
|
||||
if !ok {
|
||||
return nil, errors.New("ConversationName is required")
|
||||
}
|
||||
param.ConversationName = name.(string)
|
||||
limit, ok := nodes.TakeMapValue(input, compose.FieldPath{"Limit"})
|
||||
if ok {
|
||||
limit := limit.(int)
|
||||
param.Limit = &limit
|
||||
}
|
||||
beforeID, ok := nodes.TakeMapValue(input, compose.FieldPath{"BeforeID"})
|
||||
if ok {
|
||||
beforeID := beforeID.(string)
|
||||
param.BeforeID = &beforeID
|
||||
}
|
||||
afterID, ok := nodes.TakeMapValue(input, compose.FieldPath{"AfterID"})
|
||||
if ok {
|
||||
afterID := afterID.(string)
|
||||
param.BeforeID = &afterID
|
||||
}
|
||||
r, err := m.config.Lister.MessageList(ctx, &conversation.ListMessageRequest{
|
||||
ConversationName: param.ConversationName,
|
||||
Limit: param.Limit,
|
||||
BeforeID: param.BeforeID,
|
||||
AfterID: param.AfterID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make(map[string]any)
|
||||
objects := make([]any, 0, len(r.Messages))
|
||||
for _, msg := range r.Messages {
|
||||
object := make(map[string]any)
|
||||
bs, _ := json.Marshal(msg)
|
||||
err := json.Unmarshal(bs, &object)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
objects = append(objects, object)
|
||||
}
|
||||
|
||||
result["messageList"] = objects
|
||||
result["firstId"] = r.FirstID
|
||||
result["hasMore"] = r.HasMore
|
||||
return result, nil
|
||||
|
||||
}
|
||||
|
|
@ -21,7 +21,7 @@ import (
|
|||
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ import (
|
|||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
|
|
@ -349,7 +349,7 @@ func convertClauseGroupToConditionGroup(_ context.Context, clauseGroup *database
|
|||
)
|
||||
|
||||
conditionGroup := &database.ConditionGroup{
|
||||
Conditions: make([]*database.Condition, 0),
|
||||
Conditions: make([]*database.ConditionStr, 0),
|
||||
Relation: database.ClauseRelationAND,
|
||||
}
|
||||
|
||||
|
|
@ -362,7 +362,7 @@ func convertClauseGroupToConditionGroup(_ context.Context, clauseGroup *database
|
|||
}
|
||||
}
|
||||
|
||||
conditionGroup.Conditions = append(conditionGroup.Conditions, &database.Condition{
|
||||
conditionGroup.Conditions = append(conditionGroup.Conditions, &database.ConditionStr{
|
||||
Left: clause.Left,
|
||||
Operator: clause.Operator,
|
||||
Right: rightValue,
|
||||
|
|
@ -373,7 +373,7 @@ func convertClauseGroupToConditionGroup(_ context.Context, clauseGroup *database
|
|||
if clauseGroup.Multi != nil {
|
||||
conditionGroup.Relation = clauseGroup.Multi.Relation
|
||||
|
||||
conditionGroup.Conditions = make([]*database.Condition, len(clauseGroup.Multi.Clauses))
|
||||
conditionGroup.Conditions = make([]*database.ConditionStr, len(clauseGroup.Multi.Clauses))
|
||||
multiSelect := clauseGroup.Multi
|
||||
for idx, clause := range multiSelect.Clauses {
|
||||
if !notNeedTakeMapValue(clause.Operator) {
|
||||
|
|
@ -382,7 +382,7 @@ func convertClauseGroupToConditionGroup(_ context.Context, clauseGroup *database
|
|||
return nil, fmt.Errorf("cannot take multi clause from input")
|
||||
}
|
||||
}
|
||||
conditionGroup.Conditions[idx] = &database.Condition{
|
||||
conditionGroup.Conditions[idx] = &database.ConditionStr{
|
||||
Left: clause.Left,
|
||||
Operator: clause.Operator,
|
||||
Right: rightValue,
|
||||
|
|
|
|||
|
|
@ -24,7 +24,8 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
|
|
@ -88,7 +89,6 @@ func (c *CustomSQLConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...s
|
|||
databaseInfoID: c.DatabaseInfoID,
|
||||
sqlTemplate: c.SQLTemplate,
|
||||
outputTypes: ns.OutputTypes,
|
||||
customSQLExecutor: database.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -96,7 +96,6 @@ type CustomSQL struct {
|
|||
databaseInfoID int64
|
||||
sqlTemplate string
|
||||
outputTypes map[string]*vo.TypeInfo
|
||||
customSQLExecutor database.DatabaseOperator
|
||||
}
|
||||
|
||||
func (c *CustomSQL) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
|
|
@ -155,7 +154,7 @@ func (c *CustomSQL) Invoke(ctx context.Context, input map[string]any) (map[strin
|
|||
templateSQL = strings.Replace(templateSQL, "`?`", "?", -1)
|
||||
req.SQL = templateSQL
|
||||
req.Params = sqlParams
|
||||
response, err := c.customSQLExecutor.Execute(ctx, req)
|
||||
response, err := crossdatabase.DefaultSVC().Execute(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,8 +24,9 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/database/databasemock"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
|
|
@ -78,10 +79,9 @@ func TestCustomSQL_Execute(t *testing.T) {
|
|||
},
|
||||
}).Build().UnPatch()
|
||||
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator := databasemock.NewMockDatabase(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Execute(gomock.Any(), gomock.Any()).DoAndReturn(mockSQLer.Execute()).AnyTimes()
|
||||
|
||||
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
|
||||
crossdatabase.SetDefaultSVC(mockDatabaseOperator)
|
||||
|
||||
cfg := &CustomSQLConfig{
|
||||
DatabaseInfoID: 111,
|
||||
|
|
|
|||
|
|
@ -22,7 +22,8 @@ import (
|
|||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
|
|
@ -87,7 +88,6 @@ func (d *DeleteConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...sche
|
|||
databaseInfoID: d.DatabaseInfoID,
|
||||
clauseGroup: d.ClauseGroup,
|
||||
outputTypes: ns.OutputTypes,
|
||||
deleter: database.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -95,7 +95,6 @@ type Delete struct {
|
|||
databaseInfoID int64
|
||||
clauseGroup *database.ClauseGroup
|
||||
outputTypes map[string]*vo.TypeInfo
|
||||
deleter database.DatabaseOperator
|
||||
}
|
||||
|
||||
func (d *Delete) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
|
|
@ -111,7 +110,7 @@ func (d *Delete) Invoke(ctx context.Context, in map[string]any) (map[string]any,
|
|||
ConnectorID: getConnectorID(ctx),
|
||||
}
|
||||
|
||||
response, err := d.deleter.Delete(ctx, request)
|
||||
response, err := crossdatabase.DefaultSVC().Delete(ctx, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,7 +22,8 @@ import (
|
|||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
|
|
@ -73,14 +74,12 @@ func (i *InsertConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...sche
|
|||
return &Insert{
|
||||
databaseInfoID: i.DatabaseInfoID,
|
||||
outputTypes: ns.OutputTypes,
|
||||
inserter: database.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Insert struct {
|
||||
databaseInfoID int64
|
||||
outputTypes map[string]*vo.TypeInfo
|
||||
inserter database.DatabaseOperator
|
||||
}
|
||||
|
||||
func (is *Insert) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
|
|
@ -93,7 +92,7 @@ func (is *Insert) Invoke(ctx context.Context, input map[string]any) (map[string]
|
|||
ConnectorID: getConnectorID(ctx),
|
||||
}
|
||||
|
||||
response, err := is.inserter.Insert(ctx, req)
|
||||
response, err := crossdatabase.DefaultSVC().Insert(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,7 +22,8 @@ import (
|
|||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
|
|
@ -114,7 +115,6 @@ func (q *QueryConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schem
|
|||
outputTypes: ns.OutputTypes,
|
||||
clauseGroup: q.ClauseGroup,
|
||||
limit: q.Limit,
|
||||
op: database.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -125,7 +125,6 @@ type Query struct {
|
|||
outputTypes map[string]*vo.TypeInfo
|
||||
clauseGroup *database.ClauseGroup
|
||||
limit int64
|
||||
op database.DatabaseOperator
|
||||
}
|
||||
|
||||
func (ds *Query) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
|
|
@ -146,7 +145,7 @@ func (ds *Query) Invoke(ctx context.Context, in map[string]any) (map[string]any,
|
|||
|
||||
req.ConditionGroup = conditionGroup
|
||||
|
||||
response, err := ds.op.Query(ctx, req)
|
||||
response, err := crossdatabase.DefaultSVC().Query(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,8 +26,9 @@ import (
|
|||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/database/databasemock"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
|
|
@ -95,10 +96,9 @@ func TestDataset_Query(t *testing.T) {
|
|||
assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator)
|
||||
|
||||
}}
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator := databasemock.NewMockDatabase(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query())
|
||||
|
||||
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
|
||||
crossdatabase.SetDefaultSVC(mockDatabaseOperator)
|
||||
|
||||
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
|
|
@ -159,10 +159,9 @@ func TestDataset_Query(t *testing.T) {
|
|||
assert.Equal(t, cGroup.Relation, cfg.ClauseGroup.Multi.Relation)
|
||||
|
||||
}}
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator := databasemock.NewMockDatabase(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
|
||||
|
||||
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
|
||||
crossdatabase.SetDefaultSVC(mockDatabaseOperator)
|
||||
|
||||
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
|
|
@ -219,10 +218,9 @@ func TestDataset_Query(t *testing.T) {
|
|||
assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator)
|
||||
|
||||
}}
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator := databasemock.NewMockDatabase(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
|
||||
|
||||
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
|
||||
crossdatabase.SetDefaultSVC(mockDatabaseOperator)
|
||||
|
||||
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
|
|
@ -278,10 +276,9 @@ func TestDataset_Query(t *testing.T) {
|
|||
assert.Equal(t, cGroup.Conditions[0].Left, cfg.ClauseGroup.Single.Left)
|
||||
assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator)
|
||||
}}
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator := databasemock.NewMockDatabase(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
|
||||
|
||||
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
|
||||
crossdatabase.SetDefaultSVC(mockDatabaseOperator)
|
||||
|
||||
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
|
|
@ -347,10 +344,9 @@ func TestDataset_Query(t *testing.T) {
|
|||
assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator)
|
||||
|
||||
}}
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator := databasemock.NewMockDatabase(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
|
||||
|
||||
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
|
||||
crossdatabase.SetDefaultSVC(mockDatabaseOperator)
|
||||
|
||||
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
|
|
@ -425,10 +421,9 @@ func TestDataset_Query(t *testing.T) {
|
|||
assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator)
|
||||
|
||||
}}
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator := databasemock.NewMockDatabase(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
|
||||
|
||||
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
|
||||
crossdatabase.SetDefaultSVC(mockDatabaseOperator)
|
||||
|
||||
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
|
|
|
|||
|
|
@ -22,7 +22,8 @@ import (
|
|||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
|
|
@ -89,7 +90,6 @@ func (u *UpdateConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...sche
|
|||
databaseInfoID: u.DatabaseInfoID,
|
||||
clauseGroup: u.ClauseGroup,
|
||||
outputTypes: ns.OutputTypes,
|
||||
updater: database.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -97,7 +97,6 @@ type Update struct {
|
|||
databaseInfoID int64
|
||||
clauseGroup *database.ClauseGroup
|
||||
outputTypes map[string]*vo.TypeInfo
|
||||
updater database.DatabaseOperator
|
||||
}
|
||||
|
||||
type updateInventory struct {
|
||||
|
|
@ -126,7 +125,7 @@ func (u *Update) Invoke(ctx context.Context, in map[string]any) (map[string]any,
|
|||
ConnectorID: getConnectorID(ctx),
|
||||
}
|
||||
|
||||
response, err := u.updater.Update(ctx, req)
|
||||
response, err := crossdatabase.DefaultSVC().Update(ctx, req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
Loading…
Reference in New Issue