feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
413
backend/infra/impl/sqlparser/sql_parser.go
Normal file
413
backend/infra/impl/sqlparser/sql_parser.go
Normal file
@@ -0,0 +1,413 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sqlparser
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pingcap/tidb/pkg/parser"
|
||||
"github.com/pingcap/tidb/pkg/parser/ast"
|
||||
"github.com/pingcap/tidb/pkg/parser/format"
|
||||
"github.com/pingcap/tidb/pkg/parser/mysql"
|
||||
_ "github.com/pingcap/tidb/pkg/parser/test_driver"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/sqlparser"
|
||||
)
|
||||
|
||||
// Impl implements the SQLParser interface
|
||||
type Impl struct {
|
||||
parser *parser.Parser
|
||||
}
|
||||
|
||||
// NewSQLParser creates a new SQL parser
|
||||
func NewSQLParser() sqlparser.SQLParser {
|
||||
p := parser.New()
|
||||
return &Impl{
|
||||
parser: p,
|
||||
}
|
||||
}
|
||||
|
||||
// ParseAndModifySQL implements the SQLParser interface
|
||||
|
||||
func (p *Impl) ParseAndModifySQL(sql string, tableColumns map[string]sqlparser.TableColumn) (string, error) {
|
||||
if len(tableColumns) == 0 {
|
||||
return sql, nil
|
||||
}
|
||||
|
||||
// check tableColumns
|
||||
for originalTableName, tableColumn := range tableColumns {
|
||||
if originalTableName == "" {
|
||||
return "", fmt.Errorf("original TableName must be non-empty")
|
||||
}
|
||||
|
||||
// Check if ColumnMap is either empty or all key-value pairs are non-empty
|
||||
if tableColumn.ColumnMap != nil {
|
||||
for key, value := range tableColumn.ColumnMap {
|
||||
if (key == "") != (value == "") {
|
||||
return "", fmt.Errorf("ColumnMap key and value must be either both empty or both non-empty")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse SQL
|
||||
stmt, err := p.parser.ParseOneStmt(sql, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse SQL: %v", err)
|
||||
}
|
||||
|
||||
// First pass: collect all table aliases
|
||||
aliasCollector := NewAliasCollector()
|
||||
stmt.Accept(aliasCollector)
|
||||
|
||||
for originalTableName, _ := range tableColumns {
|
||||
if _, ok := aliasCollector.tableAliases[originalTableName]; ok {
|
||||
return "", fmt.Errorf("alisa table name should not equal with origin table name")
|
||||
}
|
||||
}
|
||||
|
||||
// Second pass: modify the AST with collected aliases
|
||||
modifier := NewSQLModifier(tableColumns, aliasCollector.tableAliases)
|
||||
stmt.Accept(modifier)
|
||||
|
||||
// Convert modified AST back to SQL
|
||||
var sb strings.Builder
|
||||
// Use single quotes for string values & remove charset prefix
|
||||
flags := format.RestoreStringSingleQuotes | format.RestoreStringWithoutCharset
|
||||
restoreCtx := format.NewRestoreCtx(flags, &sb)
|
||||
err = stmt.Restore(restoreCtx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to restore SQL: %v", err)
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
// AliasCollector collects table aliases in a first pass
|
||||
type AliasCollector struct {
|
||||
tableAliases map[string]string // key is alias, value is original table name
|
||||
}
|
||||
|
||||
// NewAliasCollector creates a new alias collector
|
||||
func NewAliasCollector() *AliasCollector {
|
||||
return &AliasCollector{
|
||||
tableAliases: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// Enter implements ast.Visitor interface
|
||||
func (c *AliasCollector) Enter(n ast.Node) (ast.Node, bool) {
|
||||
if node, ok := n.(*ast.TableSource); ok {
|
||||
if ts, nameOk := node.Source.(*ast.TableName); nameOk {
|
||||
if node.AsName.L != "" {
|
||||
c.tableAliases[node.AsName.L] = ts.Name.L
|
||||
}
|
||||
}
|
||||
}
|
||||
return n, false
|
||||
}
|
||||
|
||||
// Leave implements ast.Visitor interface
|
||||
func (c *AliasCollector) Leave(n ast.Node) (ast.Node, bool) {
|
||||
return n, true
|
||||
}
|
||||
|
||||
// SQLModifier is used to modify SQL AST
|
||||
type SQLModifier struct {
|
||||
tableMap map[string]string // key is original table name, value is new table name
|
||||
columnMap map[string]map[string]string // key is table name, value is column name mapping
|
||||
tableAliases map[string]string // key is table alias, value is original table name
|
||||
}
|
||||
|
||||
// NewSQLModifier creates a new SQL modifier with pre-collected aliases
|
||||
func NewSQLModifier(tableColumns map[string]sqlparser.TableColumn, tableAliases map[string]string) *SQLModifier {
|
||||
modifier := &SQLModifier{
|
||||
tableMap: make(map[string]string),
|
||||
columnMap: make(map[string]map[string]string),
|
||||
tableAliases: tableAliases,
|
||||
}
|
||||
|
||||
// Initialize table and column name mappings
|
||||
for originalTableName, tableColumn := range tableColumns {
|
||||
if tableColumn.NewTableName != nil && *tableColumn.NewTableName != "" {
|
||||
modifier.tableMap[originalTableName] = *tableColumn.NewTableName
|
||||
}
|
||||
modifier.columnMap[originalTableName] = tableColumn.ColumnMap
|
||||
}
|
||||
|
||||
return modifier
|
||||
}
|
||||
|
||||
// Enter implements ast.Visitor interface
|
||||
func (m *SQLModifier) Enter(n ast.Node) (ast.Node, bool) {
|
||||
switch node := n.(type) {
|
||||
case *ast.TableName:
|
||||
// Replace table name
|
||||
if newTableName, ok := m.tableMap[node.Name.L]; ok {
|
||||
// Modify all related fields of table name
|
||||
node.Name.L = newTableName
|
||||
node.Name.O = newTableName
|
||||
}
|
||||
case *ast.ColumnName:
|
||||
// Replace column name with the appropriate mapping
|
||||
if node.Table.L != "" {
|
||||
// Get the table name or alias
|
||||
tableRef := node.Table.L
|
||||
|
||||
// If this is an alias, look up the original table name for column mapping
|
||||
originalTable, isAlias := m.tableAliases[tableRef]
|
||||
|
||||
if isAlias {
|
||||
// For aliased tables, apply column mapping using the original table name
|
||||
if columnMap, ok := m.columnMap[originalTable]; ok {
|
||||
if newColName, colOk := columnMap[node.Name.L]; colOk {
|
||||
node.Name.L = newColName
|
||||
node.Name.O = newColName
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// For direct table references (not aliases)
|
||||
if newTableName, ok := m.tableMap[tableRef]; ok {
|
||||
node.Table.L = newTableName
|
||||
node.Table.O = newTableName
|
||||
}
|
||||
|
||||
if columnMap, ok := m.columnMap[tableRef]; ok {
|
||||
if newColName, colOk := columnMap[node.Name.L]; colOk {
|
||||
node.Name.L = newColName
|
||||
node.Name.O = newColName
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Handle columns without table qualifiers
|
||||
for _, columnMap := range m.columnMap {
|
||||
if newColName, ok := columnMap[node.Name.L]; ok {
|
||||
node.Name.L = newColName
|
||||
node.Name.O = newColName
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return n, false
|
||||
}
|
||||
|
||||
// Leave implements ast.Visitor interface
|
||||
func (m *SQLModifier) Leave(n ast.Node) (ast.Node, bool) {
|
||||
return n, true
|
||||
}
|
||||
|
||||
// GetSQLOperation implements the SQLParser interface
|
||||
func (p *Impl) GetSQLOperation(sql string) (sqlparser.OperationType, error) {
|
||||
if sql == "" {
|
||||
return sqlparser.OperationTypeUnknown, fmt.Errorf("empty SQL statement")
|
||||
}
|
||||
|
||||
// Parse SQL statement
|
||||
stmt, err := p.parser.ParseOneStmt(sql, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
|
||||
if err != nil {
|
||||
return sqlparser.OperationTypeUnknown, fmt.Errorf("failed to parse SQL: %v", err)
|
||||
}
|
||||
|
||||
// Identify the statement type
|
||||
switch stmt.(type) {
|
||||
case *ast.SelectStmt:
|
||||
return sqlparser.OperationTypeSelect, nil
|
||||
case *ast.InsertStmt:
|
||||
return sqlparser.OperationTypeInsert, nil
|
||||
case *ast.UpdateStmt:
|
||||
return sqlparser.OperationTypeUpdate, nil
|
||||
case *ast.DeleteStmt:
|
||||
return sqlparser.OperationTypeDelete, nil
|
||||
case *ast.CreateTableStmt:
|
||||
return sqlparser.OperationTypeCreate, nil
|
||||
case *ast.AlterTableStmt:
|
||||
return sqlparser.OperationTypeAlter, nil
|
||||
case *ast.DropTableStmt:
|
||||
return sqlparser.OperationTypeDrop, nil
|
||||
case *ast.TruncateTableStmt:
|
||||
return sqlparser.OperationTypeTruncate, nil
|
||||
default:
|
||||
// Handle other statement types if needed
|
||||
return sqlparser.OperationTypeUnknown, nil
|
||||
}
|
||||
}
|
||||
|
||||
// AddColumnsToInsertSQL takes an original insert SQL and columns to add (with values), returns the modified SQL.
|
||||
// addCols: a slice of ColumnValue, where each element represents a column and its value to be inserted for every row.
|
||||
// primaryKeyValue: a PrimaryKeyValue struct that contains the primary key column name and its values for every row, only supported for single primary key.
|
||||
// If isParam is true, placeholders (?) will be added as values, otherwise the actual values from addCols will be used.
|
||||
func (p *Impl) AddColumnsToInsertSQL(origSQL string, addCols []sqlparser.ColumnValue, primaryKeyValue *sqlparser.PrimaryKeyValue, isParam bool) (string, map[string]bool, error) {
|
||||
if len(addCols) == 0 {
|
||||
return origSQL, nil, nil
|
||||
}
|
||||
|
||||
stmt, err := parser.New().ParseOneStmt(origSQL, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to parse SQL: %v", err)
|
||||
}
|
||||
insertStmt, ok := stmt.(*ast.InsertStmt)
|
||||
if !ok {
|
||||
return "", nil, fmt.Errorf("not an INSERT statement")
|
||||
}
|
||||
|
||||
existingCols := make(map[string]bool)
|
||||
for _, col := range insertStmt.Columns {
|
||||
existingCols[col.Name.O] = true
|
||||
}
|
||||
|
||||
colsToAdd := make([]sqlparser.ColumnValue, 0, len(addCols))
|
||||
for _, colVal := range addCols {
|
||||
if !existingCols[colVal.ColName] {
|
||||
colsToAdd = append(colsToAdd, colVal)
|
||||
}
|
||||
}
|
||||
if len(colsToAdd) == 0 {
|
||||
return origSQL, existingCols, nil
|
||||
}
|
||||
|
||||
rowCount := len(insertStmt.Lists)
|
||||
if rowCount == 0 && insertStmt.Setlist {
|
||||
rowCount = 1
|
||||
}
|
||||
|
||||
for _, colVal := range colsToAdd {
|
||||
insertStmt.Columns = append(insertStmt.Columns, &ast.ColumnName{Name: ast.NewCIStr(colVal.ColName)})
|
||||
}
|
||||
|
||||
if primaryKeyValue != nil && !existingCols[primaryKeyValue.ColName] {
|
||||
insertStmt.Columns = append(insertStmt.Columns, &ast.ColumnName{Name: ast.NewCIStr(primaryKeyValue.ColName)})
|
||||
}
|
||||
|
||||
for i := 0; i < rowCount; i++ {
|
||||
paramCount := 0
|
||||
|
||||
for _, colVal := range colsToAdd {
|
||||
if isParam {
|
||||
valExpr := ast.NewParamMarkerExpr(paramCount)
|
||||
insertStmt.Lists[i] = append(insertStmt.Lists[i], valExpr)
|
||||
paramCount++
|
||||
} else {
|
||||
insertStmt.Lists[i] = append(insertStmt.Lists[i], ast.NewValueExpr(colVal.Value, "", ""))
|
||||
}
|
||||
}
|
||||
|
||||
if primaryKeyValue != nil && !existingCols[primaryKeyValue.ColName] {
|
||||
if isParam {
|
||||
valExpr := ast.NewParamMarkerExpr(paramCount)
|
||||
insertStmt.Lists[i] = append(insertStmt.Lists[i], valExpr)
|
||||
paramCount++
|
||||
} else {
|
||||
insertStmt.Lists[i] = append(insertStmt.Lists[i], ast.NewValueExpr(primaryKeyValue.Values[i], "", ""))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
flags := format.RestoreStringSingleQuotes | format.RestoreStringWithoutCharset
|
||||
restoreCtx := format.NewRestoreCtx(flags, &sb)
|
||||
err = insertStmt.Restore(restoreCtx)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to restore modified INSERT SQL: %v", err)
|
||||
}
|
||||
|
||||
return sb.String(), existingCols, nil
|
||||
}
|
||||
|
||||
// GetTableName extracts the table name from a SQL statement. Only supports single-table select/insert/update/delete.
|
||||
func (p *Impl) GetTableName(sql string) (string, error) {
|
||||
if sql == "" {
|
||||
return "", fmt.Errorf("empty SQL statement")
|
||||
}
|
||||
stmt, err := p.parser.ParseOneStmt(sql, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse SQL: %v", err)
|
||||
}
|
||||
|
||||
switch s := stmt.(type) {
|
||||
case *ast.SelectStmt:
|
||||
if s.From == nil || s.From.TableRefs == nil {
|
||||
return "", fmt.Errorf("no table found in SELECT")
|
||||
}
|
||||
tableSrc, ok := s.From.TableRefs.Left.(*ast.TableSource)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unsupported SELECT FROM structure")
|
||||
}
|
||||
tableName, ok := tableSrc.Source.(*ast.TableName)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unsupported SELECT FROM structure (not a table)")
|
||||
}
|
||||
return tableName.Name.O, nil
|
||||
case *ast.InsertStmt:
|
||||
if s.Table == nil || s.Table.TableRefs == nil {
|
||||
return "", fmt.Errorf("no table found in INSERT")
|
||||
}
|
||||
tableSrc, ok := s.Table.TableRefs.Left.(*ast.TableSource)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unsupported INSERT INTO structure")
|
||||
}
|
||||
tableName, ok := tableSrc.Source.(*ast.TableName)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unsupported INSERT INTO structure (not a table)")
|
||||
}
|
||||
return tableName.Name.O, nil
|
||||
case *ast.UpdateStmt:
|
||||
if s.TableRefs == nil {
|
||||
return "", fmt.Errorf("no table found in UPDATE")
|
||||
}
|
||||
tableSrc, ok := s.TableRefs.TableRefs.Left.(*ast.TableSource)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unsupported UPDATE structure")
|
||||
}
|
||||
tableName, ok := tableSrc.Source.(*ast.TableName)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unsupported UPDATE structure (not a table)")
|
||||
}
|
||||
return tableName.Name.O, nil
|
||||
case *ast.DeleteStmt:
|
||||
if s.TableRefs == nil {
|
||||
return "", fmt.Errorf("no table found in DELETE")
|
||||
}
|
||||
tableSrc, ok := s.TableRefs.TableRefs.Left.(*ast.TableSource)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unsupported DELETE structure")
|
||||
}
|
||||
tableName, ok := tableSrc.Source.(*ast.TableName)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unsupported DELETE structure (not a table)")
|
||||
}
|
||||
return tableName.Name.O, nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported SQL statement type for table name extraction")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Impl) GetInsertDataNums(sql string) (int, error) {
|
||||
stmt, err := p.parser.ParseOneStmt(sql, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
insert, ok := stmt.(*ast.InsertStmt)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("not an insert statement")
|
||||
}
|
||||
|
||||
return len(insert.Lists), nil
|
||||
}
|
||||
550
backend/infra/impl/sqlparser/sql_parser_test.go
Normal file
550
backend/infra/impl/sqlparser/sql_parser_test.go
Normal file
@@ -0,0 +1,550 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sqlparser
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/sqlparser"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
func TestSQLParser_ParseAndModifySQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sql string
|
||||
mappings map[string]sqlparser.TableColumn
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "sql parser error",
|
||||
sql: "SELECTS id, name FROM users WHERE age > 18",
|
||||
mappings: map[string]sqlparser.TableColumn{
|
||||
"users": {
|
||||
NewTableName: ptr.Of("new_users"),
|
||||
ColumnMap: map[string]string{
|
||||
"id": "user_id",
|
||||
"name": "user_name",
|
||||
"age": "user_age",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no new table name",
|
||||
sql: "SELECT id, name FROM users WHERE age > 18",
|
||||
mappings: map[string]sqlparser.TableColumn{
|
||||
"users": {
|
||||
ColumnMap: map[string]string{
|
||||
"id": "user_id",
|
||||
"name": "user_name",
|
||||
"age": "user_age",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: "SELECT user_id,user_name FROM users WHERE user_age>18",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "input parameters error",
|
||||
sql: "SELECT id, name FROM users WHERE age > 18",
|
||||
mappings: map[string]sqlparser.TableColumn{
|
||||
"users": {
|
||||
NewTableName: ptr.Of("new_users"),
|
||||
ColumnMap: map[string]string{
|
||||
"id": "",
|
||||
"": "user_name",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "select",
|
||||
sql: "SELECT id, name FROM users WHERE age > ?",
|
||||
mappings: map[string]sqlparser.TableColumn{
|
||||
"users": {
|
||||
NewTableName: ptr.Of("new_users"),
|
||||
ColumnMap: map[string]string{
|
||||
"id": "user_id",
|
||||
"name": "user_name",
|
||||
"age": "user_age",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: "SELECT user_id,user_name FROM new_users WHERE user_age>?",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "select",
|
||||
sql: "SELECT id, name FROM users WHERE age > 20",
|
||||
mappings: map[string]sqlparser.TableColumn{
|
||||
"users": {
|
||||
NewTableName: ptr.Of("new_users"),
|
||||
ColumnMap: map[string]string{
|
||||
"id": "user_id",
|
||||
"name": "user_name",
|
||||
"age": "user_age",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: "SELECT user_id,user_name FROM new_users WHERE user_age>20",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "alias",
|
||||
sql: "SELECT u.id, u.name, o.order_id FROM users as u JOIN orders as o ON u.id = o.user_id",
|
||||
mappings: map[string]sqlparser.TableColumn{
|
||||
"users": {
|
||||
NewTableName: ptr.Of("new_users"),
|
||||
ColumnMap: map[string]string{
|
||||
"id": "user_id",
|
||||
},
|
||||
},
|
||||
"orders": {
|
||||
NewTableName: ptr.Of("new_orders"),
|
||||
ColumnMap: map[string]string{
|
||||
"order_id": "id",
|
||||
"user_id": "customer_id",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: "SELECT u.user_id,u.name,o.id FROM new_users AS u JOIN new_orders AS o ON u.user_id=o.customer_id",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "alias",
|
||||
sql: "SELECT u.id, u.name, o.order_id FROM users as u JOIN orders as o ON u.id = o.user_id",
|
||||
mappings: map[string]sqlparser.TableColumn{
|
||||
"users": {
|
||||
NewTableName: ptr.Of("new_users"),
|
||||
ColumnMap: map[string]string{
|
||||
"id": "user_id",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: "SELECT u.user_id,u.name,o.order_id FROM new_users AS u JOIN orders AS o ON u.user_id=o.user_id",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "join query",
|
||||
sql: "SELECT users.id, users.name, orders.order_id FROM users JOIN orders ON users.id = orders.user_id",
|
||||
mappings: map[string]sqlparser.TableColumn{
|
||||
"users": {
|
||||
NewTableName: ptr.Of("new_users"),
|
||||
ColumnMap: map[string]string{
|
||||
"id": "user_id",
|
||||
},
|
||||
},
|
||||
"orders": {
|
||||
NewTableName: ptr.Of("new_orders"),
|
||||
ColumnMap: map[string]string{
|
||||
"order_id": "id",
|
||||
"user_id": "customer_id",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: "SELECT new_users.user_id,new_users.name,new_orders.id FROM new_users JOIN new_orders ON new_users.user_id=new_orders.customer_id",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "insert statement",
|
||||
sql: "INSERT INTO users (id, name, age) VALUES (1, 'John', ?)",
|
||||
mappings: map[string]sqlparser.TableColumn{
|
||||
"users": {
|
||||
NewTableName: ptr.Of("new_users"),
|
||||
ColumnMap: map[string]string{
|
||||
"id": "user_id",
|
||||
"name": "user_name",
|
||||
"age": "user_age",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: "INSERT INTO new_users (user_id,user_name,user_age) VALUES (1,'John',?)",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "update statement",
|
||||
sql: "UPDATE users SET name = 'John', age = 25 WHERE id = 1",
|
||||
mappings: map[string]sqlparser.TableColumn{
|
||||
"users": {
|
||||
NewTableName: ptr.Of("new_users"),
|
||||
ColumnMap: map[string]string{
|
||||
"id": "user_id",
|
||||
"name": "user_name",
|
||||
"age": "user_age",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: "UPDATE new_users SET user_name='John', user_age=25 WHERE user_id=1",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "only change table name",
|
||||
sql: "UPDATE users SET name = 'John', age = 25 WHERE id = 1",
|
||||
mappings: map[string]sqlparser.TableColumn{
|
||||
"users": {
|
||||
NewTableName: ptr.Of("new_users"),
|
||||
},
|
||||
},
|
||||
want: "UPDATE new_users SET name='John', age=25 WHERE id=1",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "alias error",
|
||||
sql: "SELECT u.id, u.name, o.order_id FROM (SELECT id, name FROM u) AS uu JOIN orders AS u ON uu.id = o.user_id;",
|
||||
mappings: map[string]sqlparser.TableColumn{
|
||||
"u": {
|
||||
NewTableName: ptr.Of("new_users"),
|
||||
ColumnMap: map[string]string{
|
||||
"id": "user_id",
|
||||
"name": "user_name",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "alias error",
|
||||
sql: "INSERT INTO database (name, age) VALUES ('Nick', 25);",
|
||||
mappings: map[string]sqlparser.TableColumn{
|
||||
"database": {
|
||||
NewTableName: ptr.Of("database_new"),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
parser := NewSQLParser()
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := parser.ParseAndModifySQL(tt.sql, tt.mappings)
|
||||
assert.Equal(t, tt.wantErr, err != nil)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLParser_GetSQLOperation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sql string
|
||||
want sqlparser.OperationType
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty sql",
|
||||
sql: "",
|
||||
want: sqlparser.OperationTypeUnknown,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid sql",
|
||||
sql: "SELECTS * FROM users",
|
||||
want: sqlparser.OperationTypeUnknown,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "select statement",
|
||||
sql: "SELECT id, name FROM users WHERE age > 18",
|
||||
want: sqlparser.OperationTypeSelect,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "insert statement",
|
||||
sql: "INSERT INTO users (id, name, age) VALUES (1, 'John', 25)",
|
||||
want: sqlparser.OperationTypeInsert,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "update statement",
|
||||
sql: "UPDATE users SET name = 'John', age = 25 WHERE id = 1",
|
||||
want: sqlparser.OperationTypeUpdate,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "delete statement",
|
||||
sql: "DELETE FROM users WHERE id = 1",
|
||||
want: sqlparser.OperationTypeDelete,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "create table statement",
|
||||
sql: "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(255), age INT)",
|
||||
want: sqlparser.OperationTypeCreate,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "alter table statement",
|
||||
sql: "ALTER TABLE users ADD COLUMN email VARCHAR(255)",
|
||||
want: sqlparser.OperationTypeAlter,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "drop table statement",
|
||||
sql: "DROP TABLE users",
|
||||
want: sqlparser.OperationTypeDrop,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "truncate table statement",
|
||||
sql: "TRUNCATE TABLE users",
|
||||
want: sqlparser.OperationTypeTruncate,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "complex select statement",
|
||||
sql: "SELECT u.id, u.name FROM users u JOIN orders o ON u.id = o.user_id WHERE u.age > 18 ORDER BY u.name",
|
||||
want: sqlparser.OperationTypeSelect,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "complex statement",
|
||||
sql: "UPDATE employees SET s = s * 1.15 WHERE d = ( SELECT id FROM departments WHERE name = 't')",
|
||||
want: sqlparser.OperationTypeUpdate,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
parser := NewSQLParser()
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := parser.GetSQLOperation(tt.sql)
|
||||
assert.Equal(t, tt.wantErr, err != nil)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestImpl_AddColumnsToInsertSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
origSQL string
|
||||
addCols []sqlparser.ColumnValue
|
||||
wantSQL string
|
||||
isParam bool
|
||||
primaryKeyValue *sqlparser.PrimaryKeyValue
|
||||
}{
|
||||
{
|
||||
name: "add new columns to single row insert",
|
||||
origSQL: "INSERT INTO users (id, name) VALUES (1, 'name')",
|
||||
addCols: []sqlparser.ColumnValue{
|
||||
{
|
||||
ColName: "age",
|
||||
Value: 18,
|
||||
},
|
||||
},
|
||||
wantSQL: "INSERT INTO users (id,name,age) VALUES (1, 'name',18)",
|
||||
},
|
||||
{
|
||||
name: "add new columns to multi-row insert",
|
||||
origSQL: "INSERT INTO users (id, name) VALUES (1, 'name'), (1, 'name')",
|
||||
addCols: []sqlparser.ColumnValue{
|
||||
{
|
||||
ColName: "age",
|
||||
Value: 18,
|
||||
},
|
||||
},
|
||||
primaryKeyValue: &sqlparser.PrimaryKeyValue{
|
||||
ColName: "pri_id",
|
||||
Values: []interface{}{1, 2},
|
||||
},
|
||||
wantSQL: "INSERT INTO users (id,name,age,pri_id) VALUES (1, 'name',18,1), (1, 'name',18,2)",
|
||||
},
|
||||
{
|
||||
name: "addCols is empty, no change",
|
||||
origSQL: "INSERT INTO users (id, name) VALUES (1, 'name')",
|
||||
addCols: []sqlparser.ColumnValue{},
|
||||
wantSQL: "INSERT INTO users (id, name) VALUES (1, 'name')",
|
||||
},
|
||||
{
|
||||
name: "column already exists, do not add",
|
||||
origSQL: "INSERT INTO users (id, name) VALUES (1, 'name')",
|
||||
addCols: []sqlparser.ColumnValue{{
|
||||
ColName: "name",
|
||||
Value: "abc",
|
||||
}},
|
||||
wantSQL: "INSERT INTO users (id, name) VALUES (1, 'name')",
|
||||
},
|
||||
{
|
||||
name: "add new columns to single row insert",
|
||||
origSQL: "INSERT INTO users (id, name) VALUES (? ,?)",
|
||||
addCols: []sqlparser.ColumnValue{
|
||||
{
|
||||
ColName: "age",
|
||||
},
|
||||
},
|
||||
wantSQL: "INSERT INTO users (id,name,age) VALUES (?, ?, ?)",
|
||||
isParam: true,
|
||||
},
|
||||
{
|
||||
name: "add new columns to single row insert",
|
||||
origSQL: "INSERT INTO users (id, name) VALUES (? ,?), (?, ?)",
|
||||
addCols: []sqlparser.ColumnValue{
|
||||
{
|
||||
ColName: "age",
|
||||
},
|
||||
},
|
||||
wantSQL: "INSERT INTO users (id,name,age) VALUES (?, ?, ?), (?, ?, ?)",
|
||||
isParam: true,
|
||||
},
|
||||
}
|
||||
|
||||
parser := NewSQLParser()
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, _, err := parser.AddColumnsToInsertSQL(tt.origSQL, tt.addCols, tt.primaryKeyValue, tt.isParam)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
gotNorm := strings.ReplaceAll(got, " ", "")
|
||||
wantNorm := strings.ReplaceAll(tt.wantSQL, " ", "")
|
||||
if gotNorm != wantNorm {
|
||||
t.Errorf("got SQL: %s, want: %s", got, tt.wantSQL)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestImpl_GetTableName(t *testing.T) {
|
||||
parser := NewSQLParser().(*Impl)
|
||||
tests := []struct {
|
||||
name string
|
||||
sql string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "select single table",
|
||||
sql: "SELECT * FROM users WHERE id = 1",
|
||||
want: "users",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "insert single table",
|
||||
sql: "INSERT INTO users (id, name) VALUES (1, 'a')",
|
||||
want: "users",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "update single table",
|
||||
sql: "UPDATE users SET name = 'b' WHERE id = 2",
|
||||
want: "users",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "delete single table",
|
||||
sql: "DELETE FROM users WHERE id = 3",
|
||||
want: "users",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "select join (unsupported)",
|
||||
sql: "SELECT * FROM users u JOIN orders o ON u.id = o.user_id",
|
||||
want: "users",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty sql",
|
||||
sql: "",
|
||||
want: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid sql",
|
||||
sql: "SELECTS * FROM users",
|
||||
want: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tbl, err := parser.GetTableName(tt.sql)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("expected error, got nil")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if tbl != tt.want {
|
||||
t.Errorf("got table: %s, want: %s", tbl, tt.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInsertDataNums(t *testing.T) {
|
||||
parser := NewSQLParser().(*Impl)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sql string
|
||||
want int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "single row insert",
|
||||
sql: "INSERT INTO users (name, age) VALUES ('Alice', 25);",
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "multi-row insert",
|
||||
sql: "INSERT INTO users (name, age) VALUES ('Alice', 25), ('Bob', 30);",
|
||||
want: 2,
|
||||
},
|
||||
{
|
||||
name: "multi-row insert",
|
||||
sql: "INSERT INTO users (name, age) VALUES (?, ?), (?, ?), (?, ?), (?, ?);",
|
||||
want: 4,
|
||||
},
|
||||
{
|
||||
name: "not an insert statement",
|
||||
sql: "SELECT * FROM users;",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid sql",
|
||||
sql: "INSERTT INTO users (name, age) VALUES ('Alice', 25);",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := parser.GetInsertDataNums(tt.sql)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user