493 lines
15 KiB
Go
493 lines
15 KiB
Go
/*
|
|
* Copyright 2025 coze-dev Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package sqlparser
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/pingcap/tidb/pkg/parser"
|
|
"github.com/pingcap/tidb/pkg/parser/ast"
|
|
"github.com/pingcap/tidb/pkg/parser/format"
|
|
"github.com/pingcap/tidb/pkg/parser/mysql"
|
|
"github.com/pingcap/tidb/pkg/parser/opcode"
|
|
_ "github.com/pingcap/tidb/pkg/parser/test_driver"
|
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/sqlparser"
|
|
)
|
|
|
|
// Impl implements the SQLParser interface
|
|
type Impl struct {
|
|
parser *parser.Parser
|
|
}
|
|
|
|
// NewSQLParser creates a new SQL parser
|
|
func NewSQLParser() sqlparser.SQLParser {
|
|
p := parser.New()
|
|
return &Impl{
|
|
parser: p,
|
|
}
|
|
}
|
|
|
|
// ParseAndModifySQL implements the SQLParser interface
|
|
|
|
func (p *Impl) ParseAndModifySQL(sql string, tableColumns map[string]sqlparser.TableColumn) (string, error) {
|
|
if len(tableColumns) == 0 {
|
|
return sql, nil
|
|
}
|
|
|
|
// check tableColumns
|
|
for originalTableName, tableColumn := range tableColumns {
|
|
if originalTableName == "" {
|
|
return "", fmt.Errorf("original TableName must be non-empty")
|
|
}
|
|
|
|
// Check if ColumnMap is either empty or all key-value pairs are non-empty
|
|
if tableColumn.ColumnMap != nil {
|
|
for key, value := range tableColumn.ColumnMap {
|
|
if (key == "") != (value == "") {
|
|
return "", fmt.Errorf("ColumnMap key and value must be either both empty or both non-empty")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Parse SQL
|
|
stmt, err := p.parser.ParseOneStmt(sql, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to parse SQL: %v", err)
|
|
}
|
|
|
|
// First pass: collect all table aliases
|
|
aliasCollector := NewAliasCollector()
|
|
stmt.Accept(aliasCollector)
|
|
|
|
for originalTableName, _ := range tableColumns {
|
|
if _, ok := aliasCollector.tableAliases[originalTableName]; ok {
|
|
return "", fmt.Errorf("alisa table name should not equal with origin table name")
|
|
}
|
|
}
|
|
|
|
// Second pass: modify the AST with collected aliases
|
|
modifier := NewSQLModifier(tableColumns, aliasCollector.tableAliases)
|
|
stmt.Accept(modifier)
|
|
|
|
// Convert modified AST back to SQL
|
|
var sb strings.Builder
|
|
// Use single quotes for string values & remove charset prefix
|
|
flags := format.RestoreStringSingleQuotes | format.RestoreStringWithoutCharset
|
|
restoreCtx := format.NewRestoreCtx(flags, &sb)
|
|
err = stmt.Restore(restoreCtx)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to restore SQL: %v", err)
|
|
}
|
|
|
|
return sb.String(), nil
|
|
}
|
|
|
|
// AliasCollector is an AST visitor that records the table aliases declared in
// a statement so a later pass can resolve alias-qualified columns.
type AliasCollector struct {
	// tableAliases maps an alias to the name of the table it refers to.
	tableAliases map[string]string
}
|
|
|
|
// NewAliasCollector creates a new alias collector
|
|
func NewAliasCollector() *AliasCollector {
|
|
return &AliasCollector{
|
|
tableAliases: make(map[string]string),
|
|
}
|
|
}
|
|
|
|
// Enter implements ast.Visitor interface
|
|
func (c *AliasCollector) Enter(n ast.Node) (ast.Node, bool) {
|
|
if node, ok := n.(*ast.TableSource); ok {
|
|
if ts, nameOk := node.Source.(*ast.TableName); nameOk {
|
|
if node.AsName.L != "" {
|
|
c.tableAliases[node.AsName.L] = ts.Name.L
|
|
}
|
|
}
|
|
}
|
|
return n, false
|
|
}
|
|
|
|
// Leave implements ast.Visitor interface
|
|
func (c *AliasCollector) Leave(n ast.Node) (ast.Node, bool) {
|
|
return n, true
|
|
}
|
|
|
|
// SQLModifier is an AST visitor that renames tables and columns in place.
type SQLModifier struct {
	// tableMap maps an original table name to its replacement.
	tableMap map[string]string
	// columnMap maps an original table name to that table's
	// old-column -> new-column mapping.
	columnMap map[string]map[string]string
	// tableAliases maps a table alias to the original table name.
	tableAliases map[string]string
}
|
|
|
|
// NewSQLModifier creates a new SQL modifier with pre-collected aliases
|
|
func NewSQLModifier(tableColumns map[string]sqlparser.TableColumn, tableAliases map[string]string) *SQLModifier {
|
|
modifier := &SQLModifier{
|
|
tableMap: make(map[string]string),
|
|
columnMap: make(map[string]map[string]string),
|
|
tableAliases: tableAliases,
|
|
}
|
|
|
|
// Initialize table and column name mappings
|
|
for originalTableName, tableColumn := range tableColumns {
|
|
if tableColumn.NewTableName != nil && *tableColumn.NewTableName != "" {
|
|
modifier.tableMap[originalTableName] = *tableColumn.NewTableName
|
|
}
|
|
modifier.columnMap[originalTableName] = tableColumn.ColumnMap
|
|
}
|
|
|
|
return modifier
|
|
}
|
|
|
|
// Enter implements ast.Visitor interface
|
|
func (m *SQLModifier) Enter(n ast.Node) (ast.Node, bool) {
|
|
switch node := n.(type) {
|
|
case *ast.TableName:
|
|
// Replace table name
|
|
if newTableName, ok := m.tableMap[node.Name.L]; ok {
|
|
// Modify all related fields of table name
|
|
node.Name.L = newTableName
|
|
node.Name.O = newTableName
|
|
}
|
|
case *ast.ColumnName:
|
|
// Replace column name with the appropriate mapping
|
|
if node.Table.L != "" {
|
|
// Get the table name or alias
|
|
tableRef := node.Table.L
|
|
|
|
// If this is an alias, look up the original table name for column mapping
|
|
originalTable, isAlias := m.tableAliases[tableRef]
|
|
|
|
if isAlias {
|
|
// For aliased tables, apply column mapping using the original table name
|
|
if columnMap, ok := m.columnMap[originalTable]; ok {
|
|
if newColName, colOk := columnMap[node.Name.L]; colOk {
|
|
node.Name.L = newColName
|
|
node.Name.O = newColName
|
|
}
|
|
}
|
|
} else {
|
|
// For direct table references (not aliases)
|
|
if newTableName, ok := m.tableMap[tableRef]; ok {
|
|
node.Table.L = newTableName
|
|
node.Table.O = newTableName
|
|
}
|
|
|
|
if columnMap, ok := m.columnMap[tableRef]; ok {
|
|
if newColName, colOk := columnMap[node.Name.L]; colOk {
|
|
node.Name.L = newColName
|
|
node.Name.O = newColName
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
// Handle columns without table qualifiers
|
|
for _, columnMap := range m.columnMap {
|
|
if newColName, ok := columnMap[node.Name.L]; ok {
|
|
node.Name.L = newColName
|
|
node.Name.O = newColName
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return n, false
|
|
}
|
|
|
|
// Leave implements ast.Visitor interface
|
|
func (m *SQLModifier) Leave(n ast.Node) (ast.Node, bool) {
|
|
return n, true
|
|
}
|
|
|
|
// GetSQLOperation implements the SQLParser interface
|
|
func (p *Impl) GetSQLOperation(sql string) (sqlparser.OperationType, error) {
|
|
if sql == "" {
|
|
return sqlparser.OperationTypeUnknown, fmt.Errorf("empty SQL statement")
|
|
}
|
|
|
|
// Parse SQL statement
|
|
stmt, err := p.parser.ParseOneStmt(sql, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
|
|
if err != nil {
|
|
return sqlparser.OperationTypeUnknown, fmt.Errorf("failed to parse SQL: %v", err)
|
|
}
|
|
|
|
// Identify the statement type
|
|
switch stmt.(type) {
|
|
case *ast.SelectStmt:
|
|
return sqlparser.OperationTypeSelect, nil
|
|
case *ast.InsertStmt:
|
|
return sqlparser.OperationTypeInsert, nil
|
|
case *ast.UpdateStmt:
|
|
return sqlparser.OperationTypeUpdate, nil
|
|
case *ast.DeleteStmt:
|
|
return sqlparser.OperationTypeDelete, nil
|
|
case *ast.CreateTableStmt:
|
|
return sqlparser.OperationTypeCreate, nil
|
|
case *ast.AlterTableStmt:
|
|
return sqlparser.OperationTypeAlter, nil
|
|
case *ast.DropTableStmt:
|
|
return sqlparser.OperationTypeDrop, nil
|
|
case *ast.TruncateTableStmt:
|
|
return sqlparser.OperationTypeTruncate, nil
|
|
default:
|
|
// Handle other statement types if needed
|
|
return sqlparser.OperationTypeUnknown, nil
|
|
}
|
|
}
|
|
|
|
// AddColumnsToInsertSQL takes an original insert SQL and columns to add (with values), returns the modified SQL.
|
|
// addCols: a slice of ColumnValue, where each element represents a column and its value to be inserted for every row.
|
|
// primaryKeyValue: a PrimaryKeyValue struct that contains the primary key column name and its values for every row, only supported for single primary key.
|
|
// If isParam is true, placeholders (?) will be added as values, otherwise the actual values from addCols will be used.
|
|
func (p *Impl) AddColumnsToInsertSQL(origSQL string, addCols []sqlparser.ColumnValue, primaryKeyValue *sqlparser.PrimaryKeyValue, isParam bool) (string, map[string]bool, error) {
|
|
if len(addCols) == 0 {
|
|
return origSQL, nil, nil
|
|
}
|
|
|
|
stmt, err := parser.New().ParseOneStmt(origSQL, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to parse SQL: %v", err)
|
|
}
|
|
insertStmt, ok := stmt.(*ast.InsertStmt)
|
|
if !ok {
|
|
return "", nil, fmt.Errorf("not an INSERT statement")
|
|
}
|
|
|
|
existingCols := make(map[string]bool)
|
|
for _, col := range insertStmt.Columns {
|
|
existingCols[col.Name.O] = true
|
|
}
|
|
|
|
colsToAdd := make([]sqlparser.ColumnValue, 0, len(addCols))
|
|
for _, colVal := range addCols {
|
|
if !existingCols[colVal.ColName] {
|
|
colsToAdd = append(colsToAdd, colVal)
|
|
}
|
|
}
|
|
if len(colsToAdd) == 0 {
|
|
return origSQL, existingCols, nil
|
|
}
|
|
|
|
rowCount := len(insertStmt.Lists)
|
|
if rowCount == 0 && insertStmt.Setlist {
|
|
rowCount = 1
|
|
}
|
|
|
|
for _, colVal := range colsToAdd {
|
|
insertStmt.Columns = append(insertStmt.Columns, &ast.ColumnName{Name: ast.NewCIStr(colVal.ColName)})
|
|
}
|
|
|
|
if primaryKeyValue != nil && !existingCols[primaryKeyValue.ColName] {
|
|
insertStmt.Columns = append(insertStmt.Columns, &ast.ColumnName{Name: ast.NewCIStr(primaryKeyValue.ColName)})
|
|
}
|
|
|
|
for i := 0; i < rowCount; i++ {
|
|
paramCount := 0
|
|
|
|
for _, colVal := range colsToAdd {
|
|
if isParam {
|
|
valExpr := ast.NewParamMarkerExpr(paramCount)
|
|
insertStmt.Lists[i] = append(insertStmt.Lists[i], valExpr)
|
|
paramCount++
|
|
} else {
|
|
insertStmt.Lists[i] = append(insertStmt.Lists[i], ast.NewValueExpr(colVal.Value, "", ""))
|
|
}
|
|
}
|
|
|
|
if primaryKeyValue != nil && !existingCols[primaryKeyValue.ColName] {
|
|
if isParam {
|
|
valExpr := ast.NewParamMarkerExpr(paramCount)
|
|
insertStmt.Lists[i] = append(insertStmt.Lists[i], valExpr)
|
|
paramCount++
|
|
} else {
|
|
insertStmt.Lists[i] = append(insertStmt.Lists[i], ast.NewValueExpr(primaryKeyValue.Values[i], "", ""))
|
|
}
|
|
}
|
|
}
|
|
|
|
var sb strings.Builder
|
|
flags := format.RestoreStringSingleQuotes | format.RestoreStringWithoutCharset
|
|
restoreCtx := format.NewRestoreCtx(flags, &sb)
|
|
err = insertStmt.Restore(restoreCtx)
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to restore modified INSERT SQL: %v", err)
|
|
}
|
|
|
|
return sb.String(), existingCols, nil
|
|
}
|
|
|
|
// GetTableName extracts the table name from a SQL statement. Only supports single-table select/insert/update/delete.
|
|
func (p *Impl) GetTableName(sql string) (string, error) {
|
|
if sql == "" {
|
|
return "", fmt.Errorf("empty SQL statement")
|
|
}
|
|
stmt, err := p.parser.ParseOneStmt(sql, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to parse SQL: %v", err)
|
|
}
|
|
|
|
switch s := stmt.(type) {
|
|
case *ast.SelectStmt:
|
|
if s.From == nil || s.From.TableRefs == nil {
|
|
return "", fmt.Errorf("no table found in SELECT")
|
|
}
|
|
tableSrc, ok := s.From.TableRefs.Left.(*ast.TableSource)
|
|
if !ok {
|
|
return "", fmt.Errorf("unsupported SELECT FROM structure")
|
|
}
|
|
tableName, ok := tableSrc.Source.(*ast.TableName)
|
|
if !ok {
|
|
return "", fmt.Errorf("unsupported SELECT FROM structure (not a table)")
|
|
}
|
|
return tableName.Name.O, nil
|
|
case *ast.InsertStmt:
|
|
if s.Table == nil || s.Table.TableRefs == nil {
|
|
return "", fmt.Errorf("no table found in INSERT")
|
|
}
|
|
tableSrc, ok := s.Table.TableRefs.Left.(*ast.TableSource)
|
|
if !ok {
|
|
return "", fmt.Errorf("unsupported INSERT INTO structure")
|
|
}
|
|
tableName, ok := tableSrc.Source.(*ast.TableName)
|
|
if !ok {
|
|
return "", fmt.Errorf("unsupported INSERT INTO structure (not a table)")
|
|
}
|
|
return tableName.Name.O, nil
|
|
case *ast.UpdateStmt:
|
|
if s.TableRefs == nil {
|
|
return "", fmt.Errorf("no table found in UPDATE")
|
|
}
|
|
tableSrc, ok := s.TableRefs.TableRefs.Left.(*ast.TableSource)
|
|
if !ok {
|
|
return "", fmt.Errorf("unsupported UPDATE structure")
|
|
}
|
|
tableName, ok := tableSrc.Source.(*ast.TableName)
|
|
if !ok {
|
|
return "", fmt.Errorf("unsupported UPDATE structure (not a table)")
|
|
}
|
|
return tableName.Name.O, nil
|
|
case *ast.DeleteStmt:
|
|
if s.TableRefs == nil {
|
|
return "", fmt.Errorf("no table found in DELETE")
|
|
}
|
|
tableSrc, ok := s.TableRefs.TableRefs.Left.(*ast.TableSource)
|
|
if !ok {
|
|
return "", fmt.Errorf("unsupported DELETE structure")
|
|
}
|
|
tableName, ok := tableSrc.Source.(*ast.TableName)
|
|
if !ok {
|
|
return "", fmt.Errorf("unsupported DELETE structure (not a table)")
|
|
}
|
|
return tableName.Name.O, nil
|
|
default:
|
|
return "", fmt.Errorf("unsupported SQL statement type for table name extraction")
|
|
}
|
|
}
|
|
|
|
func (p *Impl) GetInsertDataNums(sql string) (int, error) {
|
|
stmt, err := p.parser.ParseOneStmt(sql, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
insert, ok := stmt.(*ast.InsertStmt)
|
|
if !ok {
|
|
return 0, fmt.Errorf("not an insert statement")
|
|
}
|
|
|
|
return len(insert.Lists), nil
|
|
}
|
|
func (p *Impl) AppendSQLFilter(sql string, op sqlparser.SQLFilterOp, filter string) (string, error) {
|
|
if sql == "" {
|
|
return "", fmt.Errorf("empty SQL statement")
|
|
}
|
|
if op == "" || (op != sqlparser.SQLFilterOpAnd && op != sqlparser.SQLFilterOpOr) {
|
|
return "", fmt.Errorf("invalid filter operator: %s", op)
|
|
}
|
|
if filter == "" {
|
|
return "", fmt.Errorf("empty filter condition")
|
|
}
|
|
stmtNode, err := p.parser.ParseOneStmt(sql, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to parse SQL: %v", err)
|
|
}
|
|
// extract WHERE clause
|
|
var originalWhere ast.ExprNode
|
|
switch stmt := stmtNode.(type) {
|
|
case *ast.SelectStmt:
|
|
originalWhere = stmt.Where
|
|
case *ast.UpdateStmt:
|
|
originalWhere = stmt.Where
|
|
case *ast.DeleteStmt:
|
|
originalWhere = stmt.Where
|
|
default:
|
|
return "", fmt.Errorf("append filter condition failed: only support SELECT/UPDATE/DELETE")
|
|
}
|
|
tmpSQL := fmt.Sprintf("SELECT * FROM tmp WHERE %s", filter)
|
|
tmpNode, err := p.parser.ParseOneStmt(tmpSQL, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
|
|
if err != nil {
|
|
return "", fmt.Errorf("parse filter condition failed: %v", err)
|
|
}
|
|
newExpr := tmpNode.(*ast.SelectStmt).Where
|
|
mergedExpr := mergeExpr(originalWhere, newExpr, op)
|
|
// update AST
|
|
switch stmt := stmtNode.(type) {
|
|
case *ast.SelectStmt:
|
|
stmt.Where = mergedExpr
|
|
case *ast.UpdateStmt:
|
|
stmt.Where = mergedExpr
|
|
case *ast.DeleteStmt:
|
|
stmt.Where = mergedExpr
|
|
}
|
|
|
|
// regenerate SQL
|
|
var sb strings.Builder
|
|
flags := format.RestoreStringSingleQuotes | format.RestoreStringWithoutCharset | format.RestoreNameBackQuotes
|
|
restoreCtx := format.NewRestoreCtx(flags, &sb)
|
|
if err := stmtNode.Restore(restoreCtx); err != nil {
|
|
return "", fmt.Errorf("gen SQL failed: %v", err)
|
|
}
|
|
return sb.String(), nil
|
|
}
|
|
|
|
func mergeExpr(left, right ast.ExprNode, op sqlparser.SQLFilterOp) ast.ExprNode {
|
|
if left == nil {
|
|
return right
|
|
}
|
|
if right == nil {
|
|
return left
|
|
}
|
|
|
|
switch op {
|
|
case sqlparser.SQLFilterOpAnd:
|
|
return &ast.BinaryOperationExpr{
|
|
Op: opcode.LogicAnd,
|
|
L: left,
|
|
R: right,
|
|
}
|
|
case sqlparser.SQLFilterOpOr:
|
|
return &ast.BinaryOperationExpr{
|
|
Op: opcode.LogicOr,
|
|
L: left,
|
|
R: right,
|
|
}
|
|
default:
|
|
return nil
|
|
}
|
|
}
|