fix(knowledge): Fix the issue where knowledge cannot execute aggregated SQL (#794)
This commit is contained in:
parent
f940edf585
commit
5e9740c047
|
|
@ -217,6 +217,7 @@ type RetrieveContext struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type KnowledgeInfo struct {
|
type KnowledgeInfo struct {
|
||||||
|
KnowledgeName string
|
||||||
DocumentIDs []int64
|
DocumentIDs []int64
|
||||||
DocumentType knowledge.DocumentType
|
DocumentType knowledge.DocumentType
|
||||||
TableColumns []*entity.TableColumn
|
TableColumns []*entity.TableColumn
|
||||||
|
|
|
||||||
|
|
@ -18,11 +18,9 @@ package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
|
|
@ -50,6 +48,7 @@ import (
|
||||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
|
||||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||||
|
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -127,6 +126,7 @@ func (k *knowledgeSVC) newRetrieveContext(ctx context.Context, req *RetrieveRequ
|
||||||
knowledgeInfoMap[kn.ID] = &KnowledgeInfo{}
|
knowledgeInfoMap[kn.ID] = &KnowledgeInfo{}
|
||||||
knowledgeInfoMap[kn.ID].DocumentType = knowledgeModel.DocumentType(kn.FormatType)
|
knowledgeInfoMap[kn.ID].DocumentType = knowledgeModel.DocumentType(kn.FormatType)
|
||||||
knowledgeInfoMap[kn.ID].DocumentIDs = []int64{}
|
knowledgeInfoMap[kn.ID].DocumentIDs = []int64{}
|
||||||
|
knowledgeInfoMap[kn.ID].KnowledgeName = kn.Name
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, doc := range enableDocs {
|
for _, doc := range enableDocs {
|
||||||
|
|
@ -404,15 +404,26 @@ func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocum
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
for i := range resp.ResultSet.Rows {
|
for i := range resp.ResultSet.Rows {
|
||||||
|
d := &schema.Document{
|
||||||
|
Content: "",
|
||||||
|
MetaData: map[string]any{
|
||||||
|
"document_id": doc.ID,
|
||||||
|
"document_name": doc.Name,
|
||||||
|
"knowledge_id": doc.KnowledgeID,
|
||||||
|
"knowledge_name": retrieveCtx.KnowledgeInfoMap[doc.KnowledgeID].KnowledgeName,
|
||||||
|
},
|
||||||
|
}
|
||||||
id, ok := resp.ResultSet.Rows[i][consts.RDBFieldID].(int64)
|
id, ok := resp.ResultSet.Rows[i][consts.RDBFieldID].(int64)
|
||||||
if !ok {
|
if !ok {
|
||||||
logs.CtxWarnf(ctx, "convert id failed, row: %v", resp.ResultSet.Rows[i])
|
byteData, err := sonic.Marshal(resp.ResultSet.Rows)
|
||||||
return nil, errors.New("convert id failed")
|
if err != nil {
|
||||||
|
logs.CtxErrorf(ctx, "marshal sql resp failed: %v", err)
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
d := &schema.Document{
|
prefix := "sql:" + sql + ";result:"
|
||||||
ID: strconv.FormatInt(id, 10),
|
d.Content = prefix + string(byteData)
|
||||||
Content: "",
|
} else {
|
||||||
MetaData: map[string]any{},
|
d.ID = strconv.FormatInt(id, 10)
|
||||||
}
|
}
|
||||||
d.WithScore(1)
|
d.WithScore(1)
|
||||||
retrieveResult = append(retrieveResult, d)
|
retrieveResult = append(retrieveResult, d)
|
||||||
|
|
@ -423,29 +434,13 @@ func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocum
|
||||||
const pkID = "_knowledge_slice_id"
|
const pkID = "_knowledge_slice_id"
|
||||||
|
|
||||||
func addSliceIdColumn(originalSql string) string {
|
func addSliceIdColumn(originalSql string) string {
|
||||||
lowerSql := strings.ToLower(originalSql)
|
sql, err := sqlparser.NewSQLParser().AddSelectFieldsToSelectSQL(originalSql, []string{pkID})
|
||||||
selectIndex := strings.Index(lowerSql, "select ")
|
if err != nil {
|
||||||
if selectIndex == -1 {
|
logs.Errorf("add slice id column failed: %v", err)
|
||||||
return originalSql
|
return originalSql
|
||||||
}
|
}
|
||||||
result := originalSql[:selectIndex+len("select ")] // Keep selected part
|
return sql
|
||||||
remainder := originalSql[selectIndex+len("select "):]
|
|
||||||
|
|
||||||
lowerRemainder := strings.ToLower(remainder)
|
|
||||||
fromIndex := strings.Index(lowerRemainder, " from")
|
|
||||||
if fromIndex == -1 {
|
|
||||||
return originalSql
|
|
||||||
}
|
|
||||||
|
|
||||||
columns := strings.TrimSpace(remainder[:fromIndex])
|
|
||||||
if columns != "*" {
|
|
||||||
columns += ", " + pkID
|
|
||||||
}
|
|
||||||
|
|
||||||
result += columns + remainder[fromIndex:]
|
|
||||||
return result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func packNL2SqlRequest(doc *model.KnowledgeDocument) *document.TableSchema {
|
func packNL2SqlRequest(doc *model.KnowledgeDocument) *document.TableSchema {
|
||||||
res := &document.TableSchema{}
|
res := &document.TableSchema{}
|
||||||
if doc.TableInfo == nil {
|
if doc.TableInfo == nil {
|
||||||
|
|
@ -561,11 +556,31 @@ func (k *knowledgeSVC) packResults(ctx context.Context, retrieveResult []*schema
|
||||||
sliceIDs := make(sets.Set[int64])
|
sliceIDs := make(sets.Set[int64])
|
||||||
docIDs := make(sets.Set[int64])
|
docIDs := make(sets.Set[int64])
|
||||||
knowledgeIDs := make(sets.Set[int64])
|
knowledgeIDs := make(sets.Set[int64])
|
||||||
|
results = []*knowledgeModel.RetrieveSlice{}
|
||||||
documentMap := map[int64]*model.KnowledgeDocument{}
|
documentMap := map[int64]*model.KnowledgeDocument{}
|
||||||
knowledgeMap := map[int64]*model.Knowledge{}
|
knowledgeMap := map[int64]*model.Knowledge{}
|
||||||
sliceScoreMap := map[int64]float64{}
|
sliceScoreMap := map[int64]float64{}
|
||||||
for _, doc := range retrieveResult {
|
for _, doc := range retrieveResult {
|
||||||
|
if len(doc.ID) == 0 {
|
||||||
|
results = append(results, &knowledgeModel.RetrieveSlice{
|
||||||
|
Slice: &knowledgeModel.Slice{
|
||||||
|
KnowledgeID: doc.MetaData["knowledge_id"].(int64),
|
||||||
|
DocumentID: doc.MetaData["document_id"].(int64),
|
||||||
|
DocumentName: doc.MetaData["document_name"].(string),
|
||||||
|
RawContent: []*knowledgeModel.SliceContent{
|
||||||
|
{
|
||||||
|
Type: knowledgeModel.SliceContentTypeText,
|
||||||
|
Text: ptr.Of(doc.Content),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Extra: map[string]string{
|
||||||
|
consts.KnowledgeName: doc.MetaData["knowledge_name"].(string),
|
||||||
|
consts.DocumentURL: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Score: 1,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
id, err := strconv.ParseInt(doc.ID, 10, 64)
|
id, err := strconv.ParseInt(doc.ID, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logs.CtxErrorf(ctx, "convert id failed: %v", err)
|
logs.CtxErrorf(ctx, "convert id failed: %v", err)
|
||||||
|
|
@ -574,6 +589,7 @@ func (k *knowledgeSVC) packResults(ctx context.Context, retrieveResult []*schema
|
||||||
sliceIDs[id] = struct{}{}
|
sliceIDs[id] = struct{}{}
|
||||||
sliceScoreMap[id] = doc.Score()
|
sliceScoreMap[id] = doc.Score()
|
||||||
}
|
}
|
||||||
|
}
|
||||||
slices, err := k.sliceRepo.MGetSlices(ctx, sliceIDs.ToSlice())
|
slices, err := k.sliceRepo.MGetSlices(ctx, sliceIDs.ToSlice())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logs.CtxErrorf(ctx, "mget slices failed: %v", err)
|
logs.CtxErrorf(ctx, "mget slices failed: %v", err)
|
||||||
|
|
@ -625,7 +641,6 @@ func (k *knowledgeSVC) packResults(ctx context.Context, retrieveResult []*schema
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
results = []*knowledgeModel.RetrieveSlice{}
|
|
||||||
for i := range slices {
|
for i := range slices {
|
||||||
doc := documentMap[slices[i].DocumentID]
|
doc := documentMap[slices[i].DocumentID]
|
||||||
kn := knowledgeMap[slices[i].KnowledgeID]
|
kn := knowledgeMap[slices[i].KnowledgeID]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,236 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 coze-dev Authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
|
||||||
|
"github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/dal/model"
|
||||||
|
"github.com/coze-dev/coze-studio/backend/domain/knowledge/repository"
|
||||||
|
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||||
|
"github.com/coze-dev/coze-studio/backend/infra/contract/document/nl2sql"
|
||||||
|
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
|
||||||
|
rdb_entity "github.com/coze-dev/coze-studio/backend/infra/contract/rdb/entity"
|
||||||
|
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/nl2sql"
|
||||||
|
mock_db "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/rdb"
|
||||||
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||||
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"go.uber.org/mock/gomock"
|
||||||
|
"gorm.io/driver/mysql"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAddSliceIdColumn(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple select",
|
||||||
|
input: "SELECT name, age FROM users",
|
||||||
|
expected: "SELECT `name`,`age`,`_knowledge_slice_id` FROM `users`",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "select stmt wrong",
|
||||||
|
input: "SELECT FROM users",
|
||||||
|
expected: "SELECT FROM users",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
actual := addSliceIdColumn(tt.input)
|
||||||
|
if actual != tt.expected {
|
||||||
|
t.Errorf("AddSliceIdColumn() = %v, want %v", actual, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNL2sqlExec(t *testing.T) {
|
||||||
|
svc := knowledgeSVC{}
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
db := mock_db.NewMockRDB(ctrl)
|
||||||
|
nl2SQL := mock.NewMockNL2SQL(ctrl)
|
||||||
|
nl2SQL.EXPECT().NL2SQL(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) {
|
||||||
|
return "select count(*) from users", nil
|
||||||
|
})
|
||||||
|
db.EXPECT().ExecuteSQL(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, req *rdb.ExecuteSQLRequest) (*rdb.ExecuteSQLResponse, error) {
|
||||||
|
return &rdb.ExecuteSQLResponse{
|
||||||
|
ResultSet: &rdb_entity.ResultSet{Rows: []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"count(*)": 100,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
svc.nl2Sql = nl2SQL
|
||||||
|
svc.rdb = db
|
||||||
|
ctx := context.Background()
|
||||||
|
docu := model.KnowledgeDocument{
|
||||||
|
ID: 110,
|
||||||
|
KnowledgeID: 111,
|
||||||
|
Name: "users",
|
||||||
|
FileExtension: "xlsx",
|
||||||
|
DocumentType: 1,
|
||||||
|
CreatorID: 666,
|
||||||
|
SpaceID: 666,
|
||||||
|
Status: 1,
|
||||||
|
TableInfo: &entity.TableInfo{
|
||||||
|
VirtualTableName: "users",
|
||||||
|
PhysicalTableName: "table_111",
|
||||||
|
TableDesc: "user table",
|
||||||
|
Columns: []*entity.TableColumn{
|
||||||
|
{
|
||||||
|
ID: 1,
|
||||||
|
Name: "_knowledge_slice_id",
|
||||||
|
Type: document.TableColumnTypeInteger,
|
||||||
|
Description: "id",
|
||||||
|
Indexing: false,
|
||||||
|
Sequence: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: 2,
|
||||||
|
Name: "name",
|
||||||
|
Type: document.TableColumnTypeString,
|
||||||
|
Description: "name",
|
||||||
|
Indexing: true,
|
||||||
|
Sequence: 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
retrieveCtx := &RetrieveContext{
|
||||||
|
Ctx: ctx,
|
||||||
|
OriginQuery: "select count(*) from users",
|
||||||
|
KnowledgeIDs: sets.FromSlice[int64]([]int64{111}),
|
||||||
|
Documents: []*model.KnowledgeDocument{&docu},
|
||||||
|
KnowledgeInfoMap: map[int64]*KnowledgeInfo{
|
||||||
|
111: &KnowledgeInfo{
|
||||||
|
KnowledgeName: "users",
|
||||||
|
DocumentIDs: []int64{110},
|
||||||
|
DocumentType: 1,
|
||||||
|
TableColumns: []*entity.TableColumn{
|
||||||
|
{
|
||||||
|
ID: 1,
|
||||||
|
Name: "_knowledge_slice_id",
|
||||||
|
Type: document.TableColumnTypeInteger,
|
||||||
|
Description: "id",
|
||||||
|
Indexing: false,
|
||||||
|
Sequence: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: 2,
|
||||||
|
Name: "name",
|
||||||
|
Type: document.TableColumnTypeString,
|
||||||
|
Description: "name",
|
||||||
|
Indexing: true,
|
||||||
|
Sequence: 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
docs, err := svc.nl2SqlExec(ctx, &docu, retrieveCtx, nil)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.Equal(t, 1, len(docs))
|
||||||
|
assert.Equal(t, "sql:select count(*) from users;result:[{\"count(*)\":100}]", docs[0].Content)
|
||||||
|
nl2SQL.EXPECT().NL2SQL(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) {
|
||||||
|
return "", errors.New("nl2sql error")
|
||||||
|
})
|
||||||
|
_, err = svc.nl2SqlExec(ctx, &docu, retrieveCtx, nil)
|
||||||
|
assert.Equal(t, "nl2sql error", err.Error())
|
||||||
|
db.EXPECT().ExecuteSQL(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, req *rdb.ExecuteSQLRequest) (*rdb.ExecuteSQLResponse, error) {
|
||||||
|
return nil, errors.New("rdb error")
|
||||||
|
})
|
||||||
|
nl2SQL.EXPECT().NL2SQL(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) {
|
||||||
|
return "select count(*) from users", nil
|
||||||
|
})
|
||||||
|
_, err = svc.nl2SqlExec(ctx, &docu, retrieveCtx, nil)
|
||||||
|
assert.Equal(t, "rdb error", err.Error())
|
||||||
|
db.EXPECT().ExecuteSQL(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, req *rdb.ExecuteSQLRequest) (*rdb.ExecuteSQLResponse, error) {
|
||||||
|
return &rdb.ExecuteSQLResponse{
|
||||||
|
ResultSet: &rdb_entity.ResultSet{Rows: []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"name": "666",
|
||||||
|
"_knowledge_document_slice_id": int64(999),
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
nl2SQL.EXPECT().NL2SQL(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) {
|
||||||
|
return "select name from users", nil
|
||||||
|
})
|
||||||
|
docs, err = svc.nl2SqlExec(ctx, &docu, retrieveCtx, nil)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.Equal(t, 1, len(docs))
|
||||||
|
assert.Equal(t, "999", docs[0].ID)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPackResults(t *testing.T) {
|
||||||
|
svc := knowledgeSVC{}
|
||||||
|
ctx := context.Background()
|
||||||
|
svc.packResults(ctx, []*schema.Document{})
|
||||||
|
dsn := "root:root@tcp(127.0.0.1:3306)/opencoze?charset=utf8mb4&parseTime=True&loc=Local"
|
||||||
|
if os.Getenv("CI_JOB_NAME") != "" {
|
||||||
|
dsn = strings.ReplaceAll(dsn, "127.0.0.1", "mysql")
|
||||||
|
}
|
||||||
|
gormDB, err := gorm.Open(mysql.Open(dsn))
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
svc.knowledgeRepo = repository.NewKnowledgeDAO(gormDB)
|
||||||
|
svc.documentRepo = repository.NewKnowledgeDocumentDAO(gormDB)
|
||||||
|
svc.sliceRepo = repository.NewKnowledgeDocumentSliceDAO(gormDB)
|
||||||
|
docs := []*schema.Document{
|
||||||
|
{
|
||||||
|
ID: "",
|
||||||
|
Content: "sql:select count(*) from users;result:[{\"count(*)\":100}]",
|
||||||
|
MetaData: map[string]any{
|
||||||
|
"knowledge_id": int64(111),
|
||||||
|
"document_id": int64(110),
|
||||||
|
"document_name": "users",
|
||||||
|
"knowledge_name": "users",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
res, err := svc.packResults(ctx, docs)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.Equal(t, 1, len(res))
|
||||||
|
assert.Equal(t, "sql:select count(*) from users;result:[{\"count(*)\":100}]", ptr.From(res[0].Slice.RawContent[0].Text))
|
||||||
|
docs = []*schema.Document{
|
||||||
|
{
|
||||||
|
ID: "10000",
|
||||||
|
Content: "",
|
||||||
|
MetaData: map[string]any{
|
||||||
|
"knowledge_id": int64(111),
|
||||||
|
"document_id": int64(110),
|
||||||
|
"document_name": "users",
|
||||||
|
"knowledge_name": "users",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
res, err = svc.packResults(ctx, docs)
|
||||||
|
assert.Equal(t, 0, len(res))
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
}
|
||||||
|
|
@ -24,6 +24,7 @@ import (
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
//go:generate mockgen -destination ../../../internal/mock/infra/contract/nl2sql_mock/nl2sql_mock.go -package mock -source nl2sql.go Factory
|
||||||
type NL2SQL interface {
|
type NL2SQL interface {
|
||||||
NL2SQL(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...Option) (sql string, err error)
|
NL2SQL(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...Option) (sql string, err error)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -74,4 +74,7 @@ type SQLParser interface {
|
||||||
|
|
||||||
// AppendSQLFilter appends a filter condition to the SQL statement.
|
// AppendSQLFilter appends a filter condition to the SQL statement.
|
||||||
AppendSQLFilter(sql string, op SQLFilterOp, filter string) (string, error)
|
AppendSQLFilter(sql string, op SQLFilterOp, filter string) (string, error)
|
||||||
|
|
||||||
|
// AddSelectFieldsToSelectSQL add select fields to select sql
|
||||||
|
AddSelectFieldsToSelectSQL(origSQL string, cols []string) (string, error)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -490,3 +490,72 @@ func mergeExpr(left, right ast.ExprNode, op sqlparser.SQLFilterOp) ast.ExprNode
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Impl) AddSelectFieldsToSelectSQL(origSQL string, cols []string) (string, error) {
|
||||||
|
if origSQL == "" {
|
||||||
|
return "", fmt.Errorf("empty SQL statement")
|
||||||
|
}
|
||||||
|
stmtNode, err := p.parser.ParseOneStmt(origSQL, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to parse SQL: %v", err)
|
||||||
|
}
|
||||||
|
stmt, ok := stmtNode.(*ast.SelectStmt)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("not a select statement")
|
||||||
|
}
|
||||||
|
if containsAggregate(stmt) || isSelectAll(stmt) {
|
||||||
|
return origSQL, nil
|
||||||
|
}
|
||||||
|
for _, col := range cols {
|
||||||
|
stmt.Fields.Fields = append(stmt.Fields.Fields, &ast.SelectField{
|
||||||
|
Expr: &ast.ColumnNameExpr{
|
||||||
|
Name: &ast.ColumnName{
|
||||||
|
Name: ast.CIStr{O: col, L: strings.ToLower(col)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
var sb strings.Builder
|
||||||
|
restoreCtx := format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)
|
||||||
|
if err := stmt.Restore(restoreCtx); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// aggregateVisitor is an AST visitor whose hasAggregate flag is set once an
// aggregate function expression has been seen during traversal.
type aggregateVisitor struct {
	hasAggregate bool
}
|
||||||
|
|
||||||
|
func (v *aggregateVisitor) Enter(n ast.Node) (node ast.Node, skipChildren bool) {
|
||||||
|
switch n.(type) {
|
||||||
|
case *ast.AggregateFuncExpr:
|
||||||
|
v.hasAggregate = true
|
||||||
|
return n, true
|
||||||
|
}
|
||||||
|
return n, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *aggregateVisitor) Leave(n ast.Node) (node ast.Node, ok bool) {
|
||||||
|
return n, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsAggregate(stmt *ast.SelectStmt) bool {
|
||||||
|
visitor := &aggregateVisitor{}
|
||||||
|
stmt.Accept(visitor)
|
||||||
|
return visitor.hasAggregate
|
||||||
|
}
|
||||||
|
func isSelectAll(stmt *ast.SelectStmt) bool {
|
||||||
|
for _, field := range stmt.Fields.Fields {
|
||||||
|
if field.WildCard != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if columnExpr, ok := field.Expr.(*ast.ColumnNameExpr); ok {
|
||||||
|
if columnExpr.Name.Name.L == "*" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -751,3 +751,71 @@ func TestAppendSQLFilter(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAddSliceIdColumn(t *testing.T) {
|
||||||
|
parser := NewSQLParser().(*Impl)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple select",
|
||||||
|
input: "SELECT name, age FROM users",
|
||||||
|
expected: "SELECT `name`,`age`,`pk_id` FROM `users`",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "select star",
|
||||||
|
input: "SELECT * FROM users",
|
||||||
|
expected: "SELECT * FROM users",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "count function",
|
||||||
|
input: "SELECT COUNT(*) FROM users",
|
||||||
|
expected: "SELECT COUNT(*) FROM users",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex aggregate",
|
||||||
|
input: "SELECT AVG(age), MAX(score) FROM users",
|
||||||
|
expected: "SELECT AVG(age), MAX(score) FROM users",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with alias",
|
||||||
|
input: "SELECT u.name, u.age FROM users u",
|
||||||
|
expected: "SELECT `u`.`name`,`u`.`age`,`pk_id` FROM `users` AS `u`",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sql is empty",
|
||||||
|
input: "",
|
||||||
|
expected: "",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sql is not select",
|
||||||
|
input: "INSERT INTO users (name, age) VALUES ('Alice', 30)",
|
||||||
|
expected: "",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sql is wrong",
|
||||||
|
input: "SELECT * users WHERE name = 'Alice'",
|
||||||
|
expected: "",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := parser.AddSelectFieldsToSelectSQL(tt.input, []string{"pk_id"})
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("addSliceIdColumn() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(got), strings.TrimSpace(tt.expected)) {
|
||||||
|
t.Errorf("addSliceIdColumn() = %v, want %v", got, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,80 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 coze-dev Authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: nl2sql.go
|
||||||
|
//
|
||||||
|
// Generated by this command:
|
||||||
|
//
|
||||||
|
// mockgen -destination ../../../internal/mock/infra/contract/nl2sql_mock/nl2sql_mock.go -package mock -source nl2sql.go Factory
|
||||||
|
//
|
||||||
|
|
||||||
|
// Package mock is a generated GoMock package.
|
||||||
|
package mock
|
||||||
|
|
||||||
|
import (
|
||||||
|
context "context"
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
|
schema "github.com/cloudwego/eino/schema"
|
||||||
|
document "github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||||
|
nl2sql "github.com/coze-dev/coze-studio/backend/infra/contract/document/nl2sql"
|
||||||
|
gomock "go.uber.org/mock/gomock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockNL2SQL is a mock of NL2SQL interface.
|
||||||
|
type MockNL2SQL struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockNL2SQLMockRecorder
|
||||||
|
isgomock struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockNL2SQLMockRecorder is the mock recorder for MockNL2SQL.
|
||||||
|
type MockNL2SQLMockRecorder struct {
|
||||||
|
mock *MockNL2SQL
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockNL2SQL creates a new mock instance.
|
||||||
|
func NewMockNL2SQL(ctrl *gomock.Controller) *MockNL2SQL {
|
||||||
|
mock := &MockNL2SQL{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockNL2SQLMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
|
func (m *MockNL2SQL) EXPECT() *MockNL2SQLMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// NL2SQL mocks base method.
|
||||||
|
func (m *MockNL2SQL) NL2SQL(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (string, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
varargs := []any{ctx, messages, tables}
|
||||||
|
for _, a := range opts {
|
||||||
|
varargs = append(varargs, a)
|
||||||
|
}
|
||||||
|
ret := m.ctrl.Call(m, "NL2SQL", varargs...)
|
||||||
|
ret0, _ := ret[0].(string)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// NL2SQL indicates an expected call of NL2SQL.
|
||||||
|
func (mr *MockNL2SQLMockRecorder) NL2SQL(ctx, messages, tables any, opts ...any) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
varargs := append([]any{ctx, messages, tables}, opts...)
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NL2SQL", reflect.TypeOf((*MockNL2SQL)(nil).NL2SQL), varargs...)
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue