From 5e9740c04727f13e5ee9918a7bf6d1a3ce0c532a Mon Sep 17 00:00:00 2001 From: liuyunchao-1998 Date: Tue, 19 Aug 2025 15:03:30 +0800 Subject: [PATCH] fix(knowledge): Fix the issue where knowledge cannot execute aggregated SQL (#794) --- backend/domain/knowledge/service/interface.go | 7 +- backend/domain/knowledge/service/retrieve.go | 89 ++++--- .../domain/knowledge/service/retrieve_test.go | 236 ++++++++++++++++++ .../infra/contract/document/nl2sql/nl2sql.go | 1 + .../infra/contract/sqlparser/sql_parser.go | 3 + backend/infra/impl/sqlparser/sql_parser.go | 69 +++++ .../infra/impl/sqlparser/sql_parser_test.go | 68 +++++ .../mock/infra/contract/nl2sql/nl2sql_mock.go | 80 ++++++ 8 files changed, 513 insertions(+), 40 deletions(-) create mode 100644 backend/domain/knowledge/service/retrieve_test.go create mode 100644 backend/internal/mock/infra/contract/nl2sql/nl2sql_mock.go diff --git a/backend/domain/knowledge/service/interface.go b/backend/domain/knowledge/service/interface.go index f2c9364f..aec51491 100644 --- a/backend/domain/knowledge/service/interface.go +++ b/backend/domain/knowledge/service/interface.go @@ -217,9 +217,10 @@ type RetrieveContext struct { } type KnowledgeInfo struct { - DocumentIDs []int64 - DocumentType knowledge.DocumentType - TableColumns []*entity.TableColumn + KnowledgeName string + DocumentIDs []int64 + DocumentType knowledge.DocumentType + TableColumns []*entity.TableColumn } type AlterTableSchemaRequest struct { DocumentID int64 diff --git a/backend/domain/knowledge/service/retrieve.go b/backend/domain/knowledge/service/retrieve.go index f653d38e..9e0cc0f2 100644 --- a/backend/domain/knowledge/service/retrieve.go +++ b/backend/domain/knowledge/service/retrieve.go @@ -18,11 +18,9 @@ package service import ( "context" - "errors" "fmt" "regexp" "strconv" - "strings" "sync" "unicode/utf8" @@ -50,6 +48,7 @@ import ( "github.com/coze-dev/coze-studio/backend/pkg/lang/sets" "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" "github.com/coze-dev/coze-studio/backend/pkg/logs" + "github.com/coze-dev/coze-studio/backend/pkg/sonic" "github.com/coze-dev/coze-studio/backend/types/errno" ) @@ -127,6 +126,7 @@ func (k *knowledgeSVC) newRetrieveContext(ctx context.Context, req *RetrieveRequ knowledgeInfoMap[kn.ID] = &KnowledgeInfo{} knowledgeInfoMap[kn.ID].DocumentType = knowledgeModel.DocumentType(kn.FormatType) knowledgeInfoMap[kn.ID].DocumentIDs = []int64{} + knowledgeInfoMap[kn.ID].KnowledgeName = kn.Name } } for _, doc := range enableDocs { @@ -404,15 +404,26 @@ func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocum return nil, err } for i := range resp.ResultSet.Rows { + d := &schema.Document{ + Content: "", + MetaData: map[string]any{ + "document_id": doc.ID, + "document_name": doc.Name, + "knowledge_id": doc.KnowledgeID, + "knowledge_name": retrieveCtx.KnowledgeInfoMap[doc.KnowledgeID].KnowledgeName, + }, + } id, ok := resp.ResultSet.Rows[i][consts.RDBFieldID].(int64) if !ok { - logs.CtxWarnf(ctx, "convert id failed, row: %v", resp.ResultSet.Rows[i]) - return nil, errors.New("convert id failed") - } - d := &schema.Document{ - ID: strconv.FormatInt(id, 10), - Content: "", - MetaData: map[string]any{}, + byteData, err := sonic.Marshal(resp.ResultSet.Rows) + if err != nil { + logs.CtxErrorf(ctx, "marshal sql resp failed: %v", err) + return nil, err + } + prefix := "sql:" + sql + ";result:" + d.Content = prefix + string(byteData) + } else { + d.ID = strconv.FormatInt(id, 10) } d.WithScore(1) retrieveResult = append(retrieveResult, d) @@ -423,29 +434,13 @@ func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocum const pkID = "_knowledge_slice_id" func addSliceIdColumn(originalSql string) string { - lowerSql := strings.ToLower(originalSql) - selectIndex := strings.Index(lowerSql, "select ") - if selectIndex == -1 { + sql, err := sqlparser.NewSQLParser().AddSelectFieldsToSelectSQL(originalSql, []string{pkID}) + if err != nil { + logs.Errorf("add slice id column failed: %v", err) return originalSql } - result := originalSql[:selectIndex+len("select ")] // Keep selected part - remainder := originalSql[selectIndex+len("select "):] - - lowerRemainder := strings.ToLower(remainder) - fromIndex := strings.Index(lowerRemainder, " from") - if fromIndex == -1 { - return originalSql - } - - columns := strings.TrimSpace(remainder[:fromIndex]) - if columns != "*" { - columns += ", " + pkID - } - - result += columns + remainder[fromIndex:] - return result + return sql } - func packNL2SqlRequest(doc *model.KnowledgeDocument) *document.TableSchema { res := &document.TableSchema{} if doc.TableInfo == nil { @@ -561,18 +556,39 @@ func (k *knowledgeSVC) packResults(ctx context.Context, retrieveResult []*schema sliceIDs := make(sets.Set[int64]) docIDs := make(sets.Set[int64]) knowledgeIDs := make(sets.Set[int64]) - + results = []*knowledgeModel.RetrieveSlice{} documentMap := map[int64]*model.KnowledgeDocument{} knowledgeMap := map[int64]*model.Knowledge{} sliceScoreMap := map[int64]float64{} for _, doc := range retrieveResult { - id, err := strconv.ParseInt(doc.ID, 10, 64) - if err != nil { - logs.CtxErrorf(ctx, "convert id failed: %v", err) - return nil, errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", "convert id failed")) + if len(doc.ID) == 0 { + results = append(results, &knowledgeModel.RetrieveSlice{ + Slice: &knowledgeModel.Slice{ + KnowledgeID: doc.MetaData["knowledge_id"].(int64), + DocumentID: doc.MetaData["document_id"].(int64), + DocumentName: doc.MetaData["document_name"].(string), + RawContent: []*knowledgeModel.SliceContent{ + { + Type: knowledgeModel.SliceContentTypeText, + Text: ptr.Of(doc.Content), + }, + }, + Extra: map[string]string{ + consts.KnowledgeName: doc.MetaData["knowledge_name"].(string), + consts.DocumentURL: "", + }, + }, + Score: 1, + }) + } else { + id, err := strconv.ParseInt(doc.ID, 10, 64) + if err != nil { + logs.CtxErrorf(ctx, "convert id failed: %v", err) + return nil, errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", "convert id failed")) + } + sliceIDs[id] = struct{}{} + sliceScoreMap[id] = doc.Score() } - sliceIDs[id] = struct{}{} - sliceScoreMap[id] = doc.Score() } slices, err := k.sliceRepo.MGetSlices(ctx, sliceIDs.ToSlice()) if err != nil { @@ -625,7 +641,6 @@ func (k *knowledgeSVC) packResults(ctx context.Context, retrieveResult []*schema return nil, err } } - results = []*knowledgeModel.RetrieveSlice{} for i := range slices { doc := documentMap[slices[i].DocumentID] kn := knowledgeMap[slices[i].KnowledgeID] diff --git a/backend/domain/knowledge/service/retrieve_test.go b/backend/domain/knowledge/service/retrieve_test.go new file mode 100644 index 00000000..8166e0b2 --- /dev/null +++ b/backend/domain/knowledge/service/retrieve_test.go @@ -0,0 +1,236 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package service + +import ( + "context" + "errors" + "os" + "strings" + "testing" + + "github.com/cloudwego/eino/schema" + "github.com/coze-dev/coze-studio/backend/domain/knowledge/entity" + "github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/dal/model" + "github.com/coze-dev/coze-studio/backend/domain/knowledge/repository" + "github.com/coze-dev/coze-studio/backend/infra/contract/document" + "github.com/coze-dev/coze-studio/backend/infra/contract/document/nl2sql" + "github.com/coze-dev/coze-studio/backend/infra/contract/rdb" + rdb_entity "github.com/coze-dev/coze-studio/backend/infra/contract/rdb/entity" + mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/nl2sql" + mock_db "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/rdb" + "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" + "github.com/coze-dev/coze-studio/backend/pkg/lang/sets" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + "gorm.io/driver/mysql" + "gorm.io/gorm" +) + +func TestAddSliceIdColumn(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "simple select", + input: "SELECT name, age FROM users", + expected: "SELECT `name`,`age`,`_knowledge_slice_id` FROM `users`", + }, + { + name: "select stmt wrong", + input: "SELECT FROM users", + expected: "SELECT FROM users", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := addSliceIdColumn(tt.input) + if actual != tt.expected { + t.Errorf("AddSliceIdColumn() = %v, want %v", actual, tt.expected) + } + }) + } +} + +func TestNL2sqlExec(t *testing.T) { + svc := knowledgeSVC{} + ctrl := gomock.NewController(t) + db := mock_db.NewMockRDB(ctrl) + nl2SQL := mock.NewMockNL2SQL(ctrl) + nl2SQL.EXPECT().NL2SQL(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) { + return "select count(*) from users", nil + }) + db.EXPECT().ExecuteSQL(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, req *rdb.ExecuteSQLRequest) (*rdb.ExecuteSQLResponse, error) { + return &rdb.ExecuteSQLResponse{ + ResultSet: &rdb_entity.ResultSet{Rows: []map[string]interface{}{ + { + "count(*)": 100, + }, + }}, + }, nil + }) + svc.nl2Sql = nl2SQL + svc.rdb = db + ctx := context.Background() + docu := model.KnowledgeDocument{ + ID: 110, + KnowledgeID: 111, + Name: "users", + FileExtension: "xlsx", + DocumentType: 1, + CreatorID: 666, + SpaceID: 666, + Status: 1, + TableInfo: &entity.TableInfo{ + VirtualTableName: "users", + PhysicalTableName: "table_111", + TableDesc: "user table", + Columns: []*entity.TableColumn{ + { + ID: 1, + Name: "_knowledge_slice_id", + Type: document.TableColumnTypeInteger, + Description: "id", + Indexing: false, + Sequence: 1, + }, + { + ID: 2, + Name: "name", + Type: document.TableColumnTypeString, + Description: "name", + Indexing: true, + Sequence: 2, + }, + }, + }, + } + retrieveCtx := &RetrieveContext{ + Ctx: ctx, + OriginQuery: "select count(*) from users", + KnowledgeIDs: sets.FromSlice[int64]([]int64{111}), + Documents: []*model.KnowledgeDocument{&docu}, + KnowledgeInfoMap: map[int64]*KnowledgeInfo{ + 111: &KnowledgeInfo{ + KnowledgeName: "users", + DocumentIDs: []int64{110}, + DocumentType: 1, + TableColumns: []*entity.TableColumn{ + { + ID: 1, + Name: "_knowledge_slice_id", + Type: document.TableColumnTypeInteger, + Description: "id", + Indexing: false, + Sequence: 1, + }, + { + ID: 2, + Name: "name", + Type: document.TableColumnTypeString, + Description: "name", + Indexing: true, + Sequence: 2, + }, + }, + }, + }, + } + docs, err := svc.nl2SqlExec(ctx, &docu, retrieveCtx, nil) + assert.Equal(t, nil, err) + assert.Equal(t, 1, len(docs)) + assert.Equal(t, "sql:select count(*) from users;result:[{\"count(*)\":100}]", docs[0].Content) + nl2SQL.EXPECT().NL2SQL(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) { + return "", errors.New("nl2sql error") + }) + _, err = svc.nl2SqlExec(ctx, &docu, retrieveCtx, nil) + assert.Equal(t, "nl2sql error", err.Error()) + db.EXPECT().ExecuteSQL(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, req *rdb.ExecuteSQLRequest) (*rdb.ExecuteSQLResponse, error) { + return nil, errors.New("rdb error") + }) + nl2SQL.EXPECT().NL2SQL(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) { + return "select count(*) from users", nil + }) + _, err = svc.nl2SqlExec(ctx, &docu, retrieveCtx, nil) + assert.Equal(t, "rdb error", err.Error()) + db.EXPECT().ExecuteSQL(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, req *rdb.ExecuteSQLRequest) (*rdb.ExecuteSQLResponse, error) { + return &rdb.ExecuteSQLResponse{ + ResultSet: &rdb_entity.ResultSet{Rows: []map[string]interface{}{ + { + "name": "666", + "_knowledge_document_slice_id": int64(999), + }, + }}, + }, nil + }) + nl2SQL.EXPECT().NL2SQL(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) { + return "select name from users", nil + }) + docs, err = svc.nl2SqlExec(ctx, &docu, retrieveCtx, nil) + assert.Equal(t, nil, err) + assert.Equal(t, 1, len(docs)) + assert.Equal(t, "999", docs[0].ID) + +} + +func TestPackResults(t *testing.T) { + svc := knowledgeSVC{} + ctx := context.Background() + svc.packResults(ctx, []*schema.Document{}) + dsn := "root:root@tcp(127.0.0.1:3306)/opencoze?charset=utf8mb4&parseTime=True&loc=Local" + if os.Getenv("CI_JOB_NAME") != "" { + dsn = strings.ReplaceAll(dsn, "127.0.0.1", "mysql") + } + gormDB, err := gorm.Open(mysql.Open(dsn)) + assert.Equal(t, nil, err) + svc.knowledgeRepo = repository.NewKnowledgeDAO(gormDB) + svc.documentRepo = repository.NewKnowledgeDocumentDAO(gormDB) + svc.sliceRepo = repository.NewKnowledgeDocumentSliceDAO(gormDB) + docs := []*schema.Document{ + { + ID: "", + Content: "sql:select count(*) from users;result:[{\"count(*)\":100}]", + MetaData: map[string]any{ + "knowledge_id": int64(111), + "document_id": int64(110), + "document_name": "users", + "knowledge_name": "users", + }, + }, + } + res, err := svc.packResults(ctx, docs) + assert.Equal(t, nil, err) + assert.Equal(t, 1, len(res)) + assert.Equal(t, "sql:select count(*) from users;result:[{\"count(*)\":100}]", ptr.From(res[0].Slice.RawContent[0].Text)) + docs = []*schema.Document{ + { + ID: "10000", + Content: "", + MetaData: map[string]any{ + "knowledge_id": int64(111), + "document_id": int64(110), + "document_name": "users", + "knowledge_name": "users", + }, + }, + } + res, err = svc.packResults(ctx, docs) + assert.Equal(t, 0, len(res)) + assert.Equal(t, nil, err) +} diff --git a/backend/infra/contract/document/nl2sql/nl2sql.go b/backend/infra/contract/document/nl2sql/nl2sql.go index 9d45a5f3..d0939a11 100644 --- a/backend/infra/contract/document/nl2sql/nl2sql.go +++ b/backend/infra/contract/document/nl2sql/nl2sql.go @@ -24,6 +24,7 @@ import ( "github.com/coze-dev/coze-studio/backend/infra/contract/document" ) +//go:generate mockgen -destination ../../../internal/mock/infra/contract/nl2sql_mock/nl2sql_mock.go -package mock -source nl2sql.go Factory type NL2SQL interface { NL2SQL(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...Option) (sql string, err error) } diff --git a/backend/infra/contract/sqlparser/sql_parser.go b/backend/infra/contract/sqlparser/sql_parser.go index 07ddcded..796c1c25 100644 --- a/backend/infra/contract/sqlparser/sql_parser.go +++ b/backend/infra/contract/sqlparser/sql_parser.go @@ -74,4 +74,7 @@ type SQLParser interface { // AppendSQLFilter appends a filter condition to the SQL statement. AppendSQLFilter(sql string, op SQLFilterOp, filter string) (string, error) + + // AddSelectFieldsToSelectSQL add select fields to select sql + AddSelectFieldsToSelectSQL(origSQL string, cols []string) (string, error) } diff --git a/backend/infra/impl/sqlparser/sql_parser.go b/backend/infra/impl/sqlparser/sql_parser.go index b8d1711a..052e62dc 100644 --- a/backend/infra/impl/sqlparser/sql_parser.go +++ b/backend/infra/impl/sqlparser/sql_parser.go @@ -490,3 +490,72 @@ func mergeExpr(left, right ast.ExprNode, op sqlparser.SQLFilterOp) ast.ExprNode return nil } } + +func (p *Impl) AddSelectFieldsToSelectSQL(origSQL string, cols []string) (string, error) { + if origSQL == "" { + return "", fmt.Errorf("empty SQL statement") + } + stmtNode, err := p.parser.ParseOneStmt(origSQL, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation) + if err != nil { + return "", fmt.Errorf("failed to parse SQL: %v", err) + } + stmt, ok := stmtNode.(*ast.SelectStmt) + if !ok { + return "", fmt.Errorf("not a select statement") + } + if containsAggregate(stmt) || isSelectAll(stmt) { + return origSQL, nil + } + for _, col := range cols { + stmt.Fields.Fields = append(stmt.Fields.Fields, &ast.SelectField{ + Expr: &ast.ColumnNameExpr{ + Name: &ast.ColumnName{ + Name: ast.CIStr{O: col, L: strings.ToLower(col)}, + }, + }, + }) + } + var sb strings.Builder + restoreCtx := format.NewRestoreCtx(format.DefaultRestoreFlags, &sb) + if err := stmt.Restore(restoreCtx); err != nil { + return "", err + } + + return sb.String(), nil +} + +type aggregateVisitor struct { + hasAggregate bool +} + +func (v *aggregateVisitor) Enter(n ast.Node) (node ast.Node, skipChildren bool) { + switch n.(type) { + case *ast.AggregateFuncExpr: + v.hasAggregate = true + return n, true + } + return n, false +} + +func (v *aggregateVisitor) Leave(n ast.Node) (node ast.Node, ok bool) { + return n, true +} + +func containsAggregate(stmt *ast.SelectStmt) bool { + visitor := &aggregateVisitor{} + stmt.Accept(visitor) + return visitor.hasAggregate +} +func isSelectAll(stmt *ast.SelectStmt) bool { + for _, field := range stmt.Fields.Fields { + if field.WildCard != nil { + return true + } + if columnExpr, ok := field.Expr.(*ast.ColumnNameExpr); ok { + if columnExpr.Name.Name.L == "*" { + return true + } + } + } + return false +} diff --git a/backend/infra/impl/sqlparser/sql_parser_test.go b/backend/infra/impl/sqlparser/sql_parser_test.go index 8053a171..2338ee24 100644 --- a/backend/infra/impl/sqlparser/sql_parser_test.go +++ b/backend/infra/impl/sqlparser/sql_parser_test.go @@ -751,3 +751,71 @@ func TestAppendSQLFilter(t *testing.T) { } } + +func TestAddSliceIdColumn(t *testing.T) { + parser := NewSQLParser().(*Impl) + + tests := []struct { + name string + input string + expected string + wantErr bool + }{ + { + name: "simple select", + input: "SELECT name, age FROM users", + expected: "SELECT `name`,`age`,`pk_id` FROM `users`", + }, + { + name: "select star", + input: "SELECT * FROM users", + expected: "SELECT * FROM users", + }, + { + name: "count function", + input: "SELECT COUNT(*) FROM users", + expected: "SELECT COUNT(*) FROM users", + }, + { + name: "complex aggregate", + input: "SELECT AVG(age), MAX(score) FROM users", + expected: "SELECT AVG(age), MAX(score) FROM users", + }, + { + name: "with alias", + input: "SELECT u.name, u.age FROM users u", + expected: "SELECT `u`.`name`,`u`.`age`,`pk_id` FROM `users` AS `u`", + }, + { + name: "sql is empty", + input: "", + expected: "", + wantErr: true, + }, + { + name: "sql is not select", + input: "INSERT INTO users (name, age) VALUES ('Alice', 30)", + expected: "", + wantErr: true, + }, + { + name: "sql is wrong", + input: "SELECT * users WHERE name = 'Alice'", + expected: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parser.AddSelectFieldsToSelectSQL(tt.input, []string{"pk_id"}) + if (err != nil) != tt.wantErr { + t.Errorf("addSliceIdColumn() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !strings.EqualFold(strings.TrimSpace(got), strings.TrimSpace(tt.expected)) { + t.Errorf("addSliceIdColumn() = %v, want %v", got, tt.expected) + } + }) + } +} diff --git a/backend/internal/mock/infra/contract/nl2sql/nl2sql_mock.go b/backend/internal/mock/infra/contract/nl2sql/nl2sql_mock.go new file mode 100644 index 00000000..ae74f968 --- /dev/null +++ b/backend/internal/mock/infra/contract/nl2sql/nl2sql_mock.go @@ -0,0 +1,80 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Code generated by MockGen. DO NOT EDIT. +// Source: nl2sql.go +// +// Generated by this command: +// +// mockgen -destination ../../../internal/mock/infra/contract/nl2sql_mock/nl2sql_mock.go -package mock -source nl2sql.go Factory +// + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + reflect "reflect" + + schema "github.com/cloudwego/eino/schema" + document "github.com/coze-dev/coze-studio/backend/infra/contract/document" + nl2sql "github.com/coze-dev/coze-studio/backend/infra/contract/document/nl2sql" + gomock "go.uber.org/mock/gomock" +) + +// MockNL2SQL is a mock of NL2SQL interface. +type MockNL2SQL struct { + ctrl *gomock.Controller + recorder *MockNL2SQLMockRecorder + isgomock struct{} +} + +// MockNL2SQLMockRecorder is the mock recorder for MockNL2SQL. +type MockNL2SQLMockRecorder struct { + mock *MockNL2SQL +} + +// NewMockNL2SQL creates a new mock instance. +func NewMockNL2SQL(ctrl *gomock.Controller) *MockNL2SQL { + mock := &MockNL2SQL{ctrl: ctrl} + mock.recorder = &MockNL2SQLMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockNL2SQL) EXPECT() *MockNL2SQLMockRecorder { + return m.recorder +} + +// NL2SQL mocks base method. +func (m *MockNL2SQL) NL2SQL(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (string, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, messages, tables} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "NL2SQL", varargs...) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NL2SQL indicates an expected call of NL2SQL. +func (mr *MockNL2SQLMockRecorder) NL2SQL(ctx, messages, tables any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, messages, tables}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NL2SQL", reflect.TypeOf((*MockNL2SQL)(nil).NL2SQL), varargs...) +}