fix(knowledge): Fix the issue where knowledge cannot execute aggregated SQL (#794)
This commit is contained in:
@@ -490,3 +490,72 @@ func mergeExpr(left, right ast.ExprNode, op sqlparser.SQLFilterOp) ast.ExprNode
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Impl) AddSelectFieldsToSelectSQL(origSQL string, cols []string) (string, error) {
|
||||
if origSQL == "" {
|
||||
return "", fmt.Errorf("empty SQL statement")
|
||||
}
|
||||
stmtNode, err := p.parser.ParseOneStmt(origSQL, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse SQL: %v", err)
|
||||
}
|
||||
stmt, ok := stmtNode.(*ast.SelectStmt)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("not a select statement")
|
||||
}
|
||||
if containsAggregate(stmt) || isSelectAll(stmt) {
|
||||
return origSQL, nil
|
||||
}
|
||||
for _, col := range cols {
|
||||
stmt.Fields.Fields = append(stmt.Fields.Fields, &ast.SelectField{
|
||||
Expr: &ast.ColumnNameExpr{
|
||||
Name: &ast.ColumnName{
|
||||
Name: ast.CIStr{O: col, L: strings.ToLower(col)},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
var sb strings.Builder
|
||||
restoreCtx := format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)
|
||||
if err := stmt.Restore(restoreCtx); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
type aggregateVisitor struct {
|
||||
hasAggregate bool
|
||||
}
|
||||
|
||||
func (v *aggregateVisitor) Enter(n ast.Node) (node ast.Node, skipChildren bool) {
|
||||
switch n.(type) {
|
||||
case *ast.AggregateFuncExpr:
|
||||
v.hasAggregate = true
|
||||
return n, true
|
||||
}
|
||||
return n, false
|
||||
}
|
||||
|
||||
func (v *aggregateVisitor) Leave(n ast.Node) (node ast.Node, ok bool) {
|
||||
return n, true
|
||||
}
|
||||
|
||||
func containsAggregate(stmt *ast.SelectStmt) bool {
|
||||
visitor := &aggregateVisitor{}
|
||||
stmt.Accept(visitor)
|
||||
return visitor.hasAggregate
|
||||
}
|
||||
func isSelectAll(stmt *ast.SelectStmt) bool {
|
||||
for _, field := range stmt.Fields.Fields {
|
||||
if field.WildCard != nil {
|
||||
return true
|
||||
}
|
||||
if columnExpr, ok := field.Expr.(*ast.ColumnNameExpr); ok {
|
||||
if columnExpr.Name.Name.L == "*" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -751,3 +751,71 @@ func TestAppendSQLFilter(t *testing.T) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestAddSliceIdColumn(t *testing.T) {
|
||||
parser := NewSQLParser().(*Impl)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple select",
|
||||
input: "SELECT name, age FROM users",
|
||||
expected: "SELECT `name`,`age`,`pk_id` FROM `users`",
|
||||
},
|
||||
{
|
||||
name: "select star",
|
||||
input: "SELECT * FROM users",
|
||||
expected: "SELECT * FROM users",
|
||||
},
|
||||
{
|
||||
name: "count function",
|
||||
input: "SELECT COUNT(*) FROM users",
|
||||
expected: "SELECT COUNT(*) FROM users",
|
||||
},
|
||||
{
|
||||
name: "complex aggregate",
|
||||
input: "SELECT AVG(age), MAX(score) FROM users",
|
||||
expected: "SELECT AVG(age), MAX(score) FROM users",
|
||||
},
|
||||
{
|
||||
name: "with alias",
|
||||
input: "SELECT u.name, u.age FROM users u",
|
||||
expected: "SELECT `u`.`name`,`u`.`age`,`pk_id` FROM `users` AS `u`",
|
||||
},
|
||||
{
|
||||
name: "sql is empty",
|
||||
input: "",
|
||||
expected: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "sql is not select",
|
||||
input: "INSERT INTO users (name, age) VALUES ('Alice', 30)",
|
||||
expected: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "sql is wrong",
|
||||
input: "SELECT * users WHERE name = 'Alice'",
|
||||
expected: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := parser.AddSelectFieldsToSelectSQL(tt.input, []string{"pk_id"})
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("addSliceIdColumn() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(got), strings.TrimSpace(tt.expected)) {
|
||||
t.Errorf("addSliceIdColumn() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user