fix(knowledge): Fix the issue where knowledge cannot execute aggregated SQL (#794)

This commit is contained in:
liuyunchao-1998
2025-08-19 15:03:30 +08:00
committed by GitHub
parent f940edf585
commit 5e9740c047
8 changed files with 513 additions and 40 deletions

View File

@@ -24,6 +24,7 @@ import (
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
)
//go:generate mockgen -destination ../../../internal/mock/infra/contract/nl2sql_mock/nl2sql_mock.go -package mock -source nl2sql.go Factory
type NL2SQL interface {
NL2SQL(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...Option) (sql string, err error)
}

View File

@@ -74,4 +74,7 @@ type SQLParser interface {
// AppendSQLFilter appends a filter condition to the SQL statement.
AppendSQLFilter(sql string, op SQLFilterOp, filter string) (string, error)
// AddSelectFieldsToSelectSQL add select fields to select sql
AddSelectFieldsToSelectSQL(origSQL string, cols []string) (string, error)
}

View File

@@ -490,3 +490,72 @@ func mergeExpr(left, right ast.ExprNode, op sqlparser.SQLFilterOp) ast.ExprNode
return nil
}
}
func (p *Impl) AddSelectFieldsToSelectSQL(origSQL string, cols []string) (string, error) {
if origSQL == "" {
return "", fmt.Errorf("empty SQL statement")
}
stmtNode, err := p.parser.ParseOneStmt(origSQL, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
if err != nil {
return "", fmt.Errorf("failed to parse SQL: %v", err)
}
stmt, ok := stmtNode.(*ast.SelectStmt)
if !ok {
return "", fmt.Errorf("not a select statement")
}
if containsAggregate(stmt) || isSelectAll(stmt) {
return origSQL, nil
}
for _, col := range cols {
stmt.Fields.Fields = append(stmt.Fields.Fields, &ast.SelectField{
Expr: &ast.ColumnNameExpr{
Name: &ast.ColumnName{
Name: ast.CIStr{O: col, L: strings.ToLower(col)},
},
},
})
}
var sb strings.Builder
restoreCtx := format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)
if err := stmt.Restore(restoreCtx); err != nil {
return "", err
}
return sb.String(), nil
}
type aggregateVisitor struct {
hasAggregate bool
}
func (v *aggregateVisitor) Enter(n ast.Node) (node ast.Node, skipChildren bool) {
switch n.(type) {
case *ast.AggregateFuncExpr:
v.hasAggregate = true
return n, true
}
return n, false
}
func (v *aggregateVisitor) Leave(n ast.Node) (node ast.Node, ok bool) {
return n, true
}
func containsAggregate(stmt *ast.SelectStmt) bool {
visitor := &aggregateVisitor{}
stmt.Accept(visitor)
return visitor.hasAggregate
}
func isSelectAll(stmt *ast.SelectStmt) bool {
for _, field := range stmt.Fields.Fields {
if field.WildCard != nil {
return true
}
if columnExpr, ok := field.Expr.(*ast.ColumnNameExpr); ok {
if columnExpr.Name.Name.L == "*" {
return true
}
}
}
return false
}

View File

@@ -751,3 +751,71 @@ func TestAppendSQLFilter(t *testing.T) {
}
}
func TestAddSliceIdColumn(t *testing.T) {
parser := NewSQLParser().(*Impl)
tests := []struct {
name string
input string
expected string
wantErr bool
}{
{
name: "simple select",
input: "SELECT name, age FROM users",
expected: "SELECT `name`,`age`,`pk_id` FROM `users`",
},
{
name: "select star",
input: "SELECT * FROM users",
expected: "SELECT * FROM users",
},
{
name: "count function",
input: "SELECT COUNT(*) FROM users",
expected: "SELECT COUNT(*) FROM users",
},
{
name: "complex aggregate",
input: "SELECT AVG(age), MAX(score) FROM users",
expected: "SELECT AVG(age), MAX(score) FROM users",
},
{
name: "with alias",
input: "SELECT u.name, u.age FROM users u",
expected: "SELECT `u`.`name`,`u`.`age`,`pk_id` FROM `users` AS `u`",
},
{
name: "sql is empty",
input: "",
expected: "",
wantErr: true,
},
{
name: "sql is not select",
input: "INSERT INTO users (name, age) VALUES ('Alice', 30)",
expected: "",
wantErr: true,
},
{
name: "sql is wrong",
input: "SELECT * users WHERE name = 'Alice'",
expected: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := parser.AddSelectFieldsToSelectSQL(tt.input, []string{"pk_id"})
if (err != nil) != tt.wantErr {
t.Errorf("addSliceIdColumn() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !strings.EqualFold(strings.TrimSpace(got), strings.TrimSpace(tt.expected)) {
t.Errorf("addSliceIdColumn() = %v, want %v", got, tt.expected)
}
})
}
}