fix(knowledge): Fix the issue where knowledge cannot execute aggregated SQL (#794)

This commit is contained in:
liuyunchao-1998
2025-08-19 15:03:30 +08:00
committed by GitHub
parent f940edf585
commit 5e9740c047
8 changed files with 513 additions and 40 deletions

View File

@@ -217,9 +217,10 @@ type RetrieveContext struct {
}
type KnowledgeInfo struct {
DocumentIDs []int64
DocumentType knowledge.DocumentType
TableColumns []*entity.TableColumn
KnowledgeName string
DocumentIDs []int64
DocumentType knowledge.DocumentType
TableColumns []*entity.TableColumn
}
type AlterTableSchemaRequest struct {
DocumentID int64

View File

@@ -18,11 +18,9 @@ package service
import (
"context"
"errors"
"fmt"
"regexp"
"strconv"
"strings"
"sync"
"unicode/utf8"
@@ -50,6 +48,7 @@ import (
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
@@ -127,6 +126,7 @@ func (k *knowledgeSVC) newRetrieveContext(ctx context.Context, req *RetrieveRequ
knowledgeInfoMap[kn.ID] = &KnowledgeInfo{}
knowledgeInfoMap[kn.ID].DocumentType = knowledgeModel.DocumentType(kn.FormatType)
knowledgeInfoMap[kn.ID].DocumentIDs = []int64{}
knowledgeInfoMap[kn.ID].KnowledgeName = kn.Name
}
}
for _, doc := range enableDocs {
@@ -404,15 +404,26 @@ func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocum
return nil, err
}
for i := range resp.ResultSet.Rows {
d := &schema.Document{
Content: "",
MetaData: map[string]any{
"document_id": doc.ID,
"document_name": doc.Name,
"knowledge_id": doc.KnowledgeID,
"knowledge_name": retrieveCtx.KnowledgeInfoMap[doc.KnowledgeID].KnowledgeName,
},
}
id, ok := resp.ResultSet.Rows[i][consts.RDBFieldID].(int64)
if !ok {
logs.CtxWarnf(ctx, "convert id failed, row: %v", resp.ResultSet.Rows[i])
return nil, errors.New("convert id failed")
}
d := &schema.Document{
ID: strconv.FormatInt(id, 10),
Content: "",
MetaData: map[string]any{},
byteData, err := sonic.Marshal(resp.ResultSet.Rows)
if err != nil {
logs.CtxErrorf(ctx, "marshal sql resp failed: %v", err)
return nil, err
}
prefix := "sql:" + sql + ";result:"
d.Content = prefix + string(byteData)
} else {
d.ID = strconv.FormatInt(id, 10)
}
d.WithScore(1)
retrieveResult = append(retrieveResult, d)
@@ -423,29 +434,13 @@ func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocum
const pkID = "_knowledge_slice_id"
func addSliceIdColumn(originalSql string) string {
lowerSql := strings.ToLower(originalSql)
selectIndex := strings.Index(lowerSql, "select ")
if selectIndex == -1 {
sql, err := sqlparser.NewSQLParser().AddSelectFieldsToSelectSQL(originalSql, []string{pkID})
if err != nil {
logs.Errorf("add slice id column failed: %v", err)
return originalSql
}
result := originalSql[:selectIndex+len("select ")] // Keep selected part
remainder := originalSql[selectIndex+len("select "):]
lowerRemainder := strings.ToLower(remainder)
fromIndex := strings.Index(lowerRemainder, " from")
if fromIndex == -1 {
return originalSql
}
columns := strings.TrimSpace(remainder[:fromIndex])
if columns != "*" {
columns += ", " + pkID
}
result += columns + remainder[fromIndex:]
return result
return sql
}
func packNL2SqlRequest(doc *model.KnowledgeDocument) *document.TableSchema {
res := &document.TableSchema{}
if doc.TableInfo == nil {
@@ -561,18 +556,39 @@ func (k *knowledgeSVC) packResults(ctx context.Context, retrieveResult []*schema
sliceIDs := make(sets.Set[int64])
docIDs := make(sets.Set[int64])
knowledgeIDs := make(sets.Set[int64])
results = []*knowledgeModel.RetrieveSlice{}
documentMap := map[int64]*model.KnowledgeDocument{}
knowledgeMap := map[int64]*model.Knowledge{}
sliceScoreMap := map[int64]float64{}
for _, doc := range retrieveResult {
id, err := strconv.ParseInt(doc.ID, 10, 64)
if err != nil {
logs.CtxErrorf(ctx, "convert id failed: %v", err)
return nil, errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", "convert id failed"))
if len(doc.ID) == 0 {
results = append(results, &knowledgeModel.RetrieveSlice{
Slice: &knowledgeModel.Slice{
KnowledgeID: doc.MetaData["knowledge_id"].(int64),
DocumentID: doc.MetaData["document_id"].(int64),
DocumentName: doc.MetaData["document_name"].(string),
RawContent: []*knowledgeModel.SliceContent{
{
Type: knowledgeModel.SliceContentTypeText,
Text: ptr.Of(doc.Content),
},
},
Extra: map[string]string{
consts.KnowledgeName: doc.MetaData["knowledge_name"].(string),
consts.DocumentURL: "",
},
},
Score: 1,
})
} else {
id, err := strconv.ParseInt(doc.ID, 10, 64)
if err != nil {
logs.CtxErrorf(ctx, "convert id failed: %v", err)
return nil, errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", "convert id failed"))
}
sliceIDs[id] = struct{}{}
sliceScoreMap[id] = doc.Score()
}
sliceIDs[id] = struct{}{}
sliceScoreMap[id] = doc.Score()
}
slices, err := k.sliceRepo.MGetSlices(ctx, sliceIDs.ToSlice())
if err != nil {
@@ -625,7 +641,6 @@ func (k *knowledgeSVC) packResults(ctx context.Context, retrieveResult []*schema
return nil, err
}
}
results = []*knowledgeModel.RetrieveSlice{}
for i := range slices {
doc := documentMap[slices[i].DocumentID]
kn := knowledgeMap[slices[i].KnowledgeID]

View File

@@ -0,0 +1,236 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package service
import (
"context"
"errors"
"os"
"strings"
"testing"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
"github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/dal/model"
"github.com/coze-dev/coze-studio/backend/domain/knowledge/repository"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/nl2sql"
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
rdb_entity "github.com/coze-dev/coze-studio/backend/infra/contract/rdb/entity"
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/nl2sql"
mock_db "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/rdb"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
// TestAddSliceIdColumn runs addSliceIdColumn over a small table of SQL inputs.
//
// The well-formed SELECT is expected to come back with the
// `_knowledge_slice_id` column appended and identifiers back-quoted; the
// malformed SELECT (no column list) is expected to come back unchanged.
func TestAddSliceIdColumn(t *testing.T) {
	type testCase struct {
		name     string
		input    string
		expected string
	}

	cases := []testCase{
		{
			name:     "simple select",
			input:    "SELECT name, age FROM users",
			expected: "SELECT `name`,`age`,`_knowledge_slice_id` FROM `users`",
		},
		{
			name:     "select stmt wrong",
			input:    "SELECT FROM users",
			expected: "SELECT FROM users",
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			got := addSliceIdColumn(tc.input)
			if got == tc.expected {
				return
			}
			t.Errorf("AddSliceIdColumn() = %v, want %v", got, tc.expected)
		})
	}
}
func TestNL2sqlExec(t *testing.T) {
svc := knowledgeSVC{}
ctrl := gomock.NewController(t)
db := mock_db.NewMockRDB(ctrl)
nl2SQL := mock.NewMockNL2SQL(ctrl)
nl2SQL.EXPECT().NL2SQL(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) {
return "select count(*) from users", nil
})
db.EXPECT().ExecuteSQL(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, req *rdb.ExecuteSQLRequest) (*rdb.ExecuteSQLResponse, error) {
return &rdb.ExecuteSQLResponse{
ResultSet: &rdb_entity.ResultSet{Rows: []map[string]interface{}{
{
"count(*)": 100,
},
}},
}, nil
})
svc.nl2Sql = nl2SQL
svc.rdb = db
ctx := context.Background()
docu := model.KnowledgeDocument{
ID: 110,
KnowledgeID: 111,
Name: "users",
FileExtension: "xlsx",
DocumentType: 1,
CreatorID: 666,
SpaceID: 666,
Status: 1,
TableInfo: &entity.TableInfo{
VirtualTableName: "users",
PhysicalTableName: "table_111",
TableDesc: "user table",
Columns: []*entity.TableColumn{
{
ID: 1,
Name: "_knowledge_slice_id",
Type: document.TableColumnTypeInteger,
Description: "id",
Indexing: false,
Sequence: 1,
},
{
ID: 2,
Name: "name",
Type: document.TableColumnTypeString,
Description: "name",
Indexing: true,
Sequence: 2,
},
},
},
}
retrieveCtx := &RetrieveContext{
Ctx: ctx,
OriginQuery: "select count(*) from users",
KnowledgeIDs: sets.FromSlice[int64]([]int64{111}),
Documents: []*model.KnowledgeDocument{&docu},
KnowledgeInfoMap: map[int64]*KnowledgeInfo{
111: &KnowledgeInfo{
KnowledgeName: "users",
DocumentIDs: []int64{110},
DocumentType: 1,
TableColumns: []*entity.TableColumn{
{
ID: 1,
Name: "_knowledge_slice_id",
Type: document.TableColumnTypeInteger,
Description: "id",
Indexing: false,
Sequence: 1,
},
{
ID: 2,
Name: "name",
Type: document.TableColumnTypeString,
Description: "name",
Indexing: true,
Sequence: 2,
},
},
},
},
}
docs, err := svc.nl2SqlExec(ctx, &docu, retrieveCtx, nil)
assert.Equal(t, nil, err)
assert.Equal(t, 1, len(docs))
assert.Equal(t, "sql:select count(*) from users;result:[{\"count(*)\":100}]", docs[0].Content)
nl2SQL.EXPECT().NL2SQL(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) {
return "", errors.New("nl2sql error")
})
_, err = svc.nl2SqlExec(ctx, &docu, retrieveCtx, nil)
assert.Equal(t, "nl2sql error", err.Error())
db.EXPECT().ExecuteSQL(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, req *rdb.ExecuteSQLRequest) (*rdb.ExecuteSQLResponse, error) {
return nil, errors.New("rdb error")
})
nl2SQL.EXPECT().NL2SQL(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) {
return "select count(*) from users", nil
})
_, err = svc.nl2SqlExec(ctx, &docu, retrieveCtx, nil)
assert.Equal(t, "rdb error", err.Error())
db.EXPECT().ExecuteSQL(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, req *rdb.ExecuteSQLRequest) (*rdb.ExecuteSQLResponse, error) {
return &rdb.ExecuteSQLResponse{
ResultSet: &rdb_entity.ResultSet{Rows: []map[string]interface{}{
{
"name": "666",
"_knowledge_document_slice_id": int64(999),
},
}},
}, nil
})
nl2SQL.EXPECT().NL2SQL(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) {
return "select name from users", nil
})
docs, err = svc.nl2SqlExec(ctx, &docu, retrieveCtx, nil)
assert.Equal(t, nil, err)
assert.Equal(t, 1, len(docs))
assert.Equal(t, "999", docs[0].ID)
}
// TestPackResults exercises knowledgeSVC.packResults against a MySQL-backed
// set of repositories.
//
// It connects to the "opencoze" database on 127.0.0.1:3306, or on host
// "mysql" when the CI_JOB_NAME environment variable is set, and then checks:
//
//   - a document with an empty ID (an aggregate SQL result) is returned as one
//     slice whose first raw content text equals the document's Content;
//   - a document with ID "10000" yields no slices and no error.
//     NOTE(review): presumably no slice with that ID exists in the test
//     database — confirm against the fixtures.
func TestPackResults(t *testing.T) {
	svc := knowledgeSVC{}
	ctx := context.Background()

	// Call with an empty result list before any repository is wired. The
	// return values are deliberately discarded; presumably this is a smoke
	// check that the call completes — TODO confirm.
	_, _ = svc.packResults(ctx, []*schema.Document{})

	dsn := "root:root@tcp(127.0.0.1:3306)/opencoze?charset=utf8mb4&parseTime=True&loc=Local"
	if os.Getenv("CI_JOB_NAME") != "" {
		dsn = strings.ReplaceAll(dsn, "127.0.0.1", "mysql")
	}
	gormDB, err := gorm.Open(mysql.Open(dsn))
	if err != nil {
		// Stop here: continuing with an unusable handle would only surface
		// as a confusing failure inside the repositories.
		t.Fatalf("open mysql connection: %v", err)
	}
	svc.knowledgeRepo = repository.NewKnowledgeDAO(gormDB)
	svc.documentRepo = repository.NewKnowledgeDocumentDAO(gormDB)
	svc.sliceRepo = repository.NewKnowledgeDocumentSliceDAO(gormDB)

	// newMeta returns a fresh metadata map so the two inputs share no state.
	newMeta := func() map[string]any {
		return map[string]any{
			"knowledge_id":   int64(111),
			"document_id":    int64(110),
			"document_name":  "users",
			"knowledge_name": "users",
		}
	}

	// Empty ID: the document content is passed through as the slice text.
	docs := []*schema.Document{
		{
			ID:       "",
			Content:  "sql:select count(*) from users;result:[{\"count(*)\":100}]",
			MetaData: newMeta(),
		},
	}
	res, err := svc.packResults(ctx, docs)
	assert.NoError(t, err)
	// Only index into the result once its shape is confirmed; assert does not stop the test.
	if assert.Len(t, res, 1) && assert.NotEmpty(t, res[0].Slice.RawContent) {
		assert.Equal(t, "sql:select count(*) from users;result:[{\"count(*)\":100}]", ptr.From(res[0].Slice.RawContent[0].Text))
	}

	// Non-empty ID: the slice is looked up through the repositories.
	docs = []*schema.Document{
		{
			ID:       "10000",
			Content:  "",
			MetaData: newMeta(),
		},
	}
	res, err = svc.packResults(ctx, docs)
	assert.NoError(t, err)
	assert.Len(t, res, 0)
}