diff --git a/backend/domain/memory/database/service/database_impl.go b/backend/domain/memory/database/service/database_impl.go index 651d78f0..4808f433 100644 --- a/backend/domain/memory/database/service/database_impl.go +++ b/backend/domain/memory/database/service/database_impl.go @@ -1075,7 +1075,16 @@ func (d databaseService) executeCustomSQL(ctx context.Context, req *ExecuteSQLRe if err != nil { return nil, fmt.Errorf("parse sql failed: %v", err) } - + // add rw mode + if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite && len(req.UserID) != 0 { + switch operation { + case sqlparsercontract.OperationTypeSelect, sqlparsercontract.OperationTypeUpdate, sqlparsercontract.OperationTypeDelete: + parsedSQL, err = sqlparser.NewSQLParser().AppendSQLFilter(parsedSQL, sqlparsercontract.SQLFilterOpAnd, fmt.Sprintf("%s = '%s'", database.DefaultUidColName, req.UserID)) + if err != nil { + return nil, fmt.Errorf("append sql filter failed: %v", err) + } + } + } insertResult := make([]map[string]interface{}, 0) if operation == sqlparsercontract.OperationTypeInsert { cid := consts.CozeConnectorID diff --git a/backend/infra/contract/sqlparser/sql_parser.go b/backend/infra/contract/sqlparser/sql_parser.go index b6695387..07ddcded 100644 --- a/backend/infra/contract/sqlparser/sql_parser.go +++ b/backend/infra/contract/sqlparser/sql_parser.go @@ -48,6 +48,13 @@ const ( OperationTypeUnknown OperationType = "UNKNOWN" ) +type SQLFilterOp string + +const ( + SQLFilterOpAnd SQLFilterOp = "AND" + SQLFilterOpOr SQLFilterOp = "OR" +) + // SQLParser defines the interface for parsing and modifying SQL statements type SQLParser interface { // ParseAndModifySQL parses SQL and replaces table/column names according to the provided message @@ -64,4 +71,7 @@ type SQLParser interface { // GetInsertDataNums extracts the number of rows to be inserted from a SQL statement. Only supports single-table insert. GetInsertDataNums(sql string) (int, error) + + // AppendSQLFilter appends a filter condition to the SQL statement. + AppendSQLFilter(sql string, op SQLFilterOp, filter string) (string, error) } diff --git a/backend/infra/impl/sqlparser/sql_parser.go b/backend/infra/impl/sqlparser/sql_parser.go index 645ff262..b8d1711a 100644 --- a/backend/infra/impl/sqlparser/sql_parser.go +++ b/backend/infra/impl/sqlparser/sql_parser.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/format" "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/opcode" _ "github.com/pingcap/tidb/pkg/parser/test_driver" "github.com/coze-dev/coze-studio/backend/infra/contract/sqlparser" @@ -411,3 +412,81 @@ func (p *Impl) GetInsertDataNums(sql string) (int, error) { return len(insert.Lists), nil } +func (p *Impl) AppendSQLFilter(sql string, op sqlparser.SQLFilterOp, filter string) (string, error) { + if sql == "" { + return "", fmt.Errorf("empty SQL statement") + } + if op == "" || (op != sqlparser.SQLFilterOpAnd && op != sqlparser.SQLFilterOpOr) { + return "", fmt.Errorf("invalid filter operator: %s", op) + } + if filter == "" { + return "", fmt.Errorf("empty filter condition") + } + stmtNode, err := p.parser.ParseOneStmt(sql, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation) + if err != nil { + return "", fmt.Errorf("failed to parse SQL: %v", err) + } + // extract WHERE clause + var originalWhere ast.ExprNode + switch stmt := stmtNode.(type) { + case *ast.SelectStmt: + originalWhere = stmt.Where + case *ast.UpdateStmt: + originalWhere = stmt.Where + case *ast.DeleteStmt: + originalWhere = stmt.Where + default: + return "", fmt.Errorf("append filter condition failed: only support SELECT/UPDATE/DELETE") + } + tmpSQL := fmt.Sprintf("SELECT * FROM tmp WHERE %s", filter) + tmpNode, err := p.parser.ParseOneStmt(tmpSQL, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation) + if err != nil { + return "", fmt.Errorf("parse filter condition failed: %v", err) + } + newExpr := tmpNode.(*ast.SelectStmt).Where + mergedExpr := mergeExpr(originalWhere, newExpr, op) + // update AST + switch stmt := stmtNode.(type) { + case *ast.SelectStmt: + stmt.Where = mergedExpr + case *ast.UpdateStmt: + stmt.Where = mergedExpr + case *ast.DeleteStmt: + stmt.Where = mergedExpr + } + + // regenerate SQL + var sb strings.Builder + flags := format.RestoreStringSingleQuotes | format.RestoreStringWithoutCharset | format.RestoreNameBackQuotes + restoreCtx := format.NewRestoreCtx(flags, &sb) + if err := stmtNode.Restore(restoreCtx); err != nil { + return "", fmt.Errorf("gen SQL failed: %v", err) + } + return sb.String(), nil +} + +func mergeExpr(left, right ast.ExprNode, op sqlparser.SQLFilterOp) ast.ExprNode { + if left == nil { + return right + } + if right == nil { + return left + } + + switch op { + case sqlparser.SQLFilterOpAnd: + return &ast.BinaryOperationExpr{ + Op: opcode.LogicAnd, + L: left, + R: right, + } + case sqlparser.SQLFilterOpOr: + return &ast.BinaryOperationExpr{ + Op: opcode.LogicOr, + L: left, + R: right, + } + default: + return nil + } +} diff --git a/backend/infra/impl/sqlparser/sql_parser_test.go b/backend/infra/impl/sqlparser/sql_parser_test.go index 010b216f..8053a171 100644 --- a/backend/infra/impl/sqlparser/sql_parser_test.go +++ b/backend/infra/impl/sqlparser/sql_parser_test.go @@ -548,3 +548,206 @@ func TestGetInsertDataNums(t *testing.T) { }) } } + +func TestAppendSQLFilter(t *testing.T) { + parser := NewSQLParser().(*Impl) + tests := []struct { + name string + sql string + condition string + connector string + want string + wantErr bool + errContains string + }{ + // tset - SELECT + { + name: "SELECT - add AND to existing WHERE", + sql: "SELECT * FROM users WHERE age > 18", + condition: "status = 'active'", + connector: "AND", + want: " select * from `users` where `age`>18 and `status`='active'", + }, + { + name: "SELECT - add OR to existing WHERE", + sql: "SELECT * FROM products WHERE price < 100", + condition: "category = 'electronics'", + connector: "OR", + want: "select * from `products` where `price`<100 or `category`='electronics'", + }, + { + name: "SELECT - add AND to multiple conditions", + sql: "SELECT * FROM orders WHERE total > 50 AND status = 'completed'", + condition: "customer_id = 123", + connector: "AND", + want: "select * from `orders` where `total`>50 and `status`='completed' and `customer_id`=123", + }, + { + name: "SELECT - add condition without WHERE", + sql: "SELECT id, name FROM customers", + condition: "is_verified = 1", + connector: "AND", + want: "select `id`,`name` from `customers` where `is_verified`=1", + }, + + // tset - UPDATE + { + name: "UPDATE - add AND condition", + sql: "UPDATE users SET last_login = NOW() WHERE id = 42", + condition: "is_active = true", + connector: "AND", + want: "update `users` set `last_login`=now() where `id`=42 and `is_active`=true", + }, + { + name: "UPDATE - add OR condition without WHERE", + sql: "UPDATE products SET discount = 0.1", + condition: "inventory > 100", + connector: "OR", + want: "update `products` set `discount`=0.1 where `inventory`>100", + }, + + // tset - DELETE + { + name: "DELETE - add AND condition", + sql: "DELETE FROM logs WHERE created_at < '2023-01-01'", + condition: "severity = 'DEBUG'", + connector: "AND", + want: "delete from `logs` where `created_at`<'2023-01-01' and `severity`='debug'", + }, + { + name: "DELETE - add OR condition", + sql: "DELETE FROM sessions WHERE expires_at < NOW()", + condition: "invalid = true", + connector: "OR", + want: "delete from `sessions` where `expires_at` 100 OR priority = 1)", + connector: "AND", + want: "select * from `orders` where `status`='shipped' and (`total`>100 or `priority`=1)", + }, + { + name: "Add condition to existing parentheses", + sql: "SELECT * FROM users WHERE (age > 18 OR parent_consent = true) AND country = 'US'", + condition: "is_verified = 1", + connector: "AND", + want: "select * from `users` where (`age`>18 or `parent_consent`=true) and `country`='us' and `is_verified`=1", + }, + + // tset - JOIN + { + name: "SELECT with JOIN", + sql: "SELECT u.name, o.total FROM users u JOIN orders o ON u.id = o.user_id WHERE u.country = 'US'", + condition: "o.status = 'completed'", + connector: "AND", + want: "select `u`.`name`,`o`.`total` from `users` as `u` join `orders` as `o` on `u`.`id`=`o`.`user_id` where `u`.`country`='us' and `o`.`status`='completed'", + }, + { + name: "SELECT with multiple joins", + sql: "SELECT p.name, c.category_name FROM products p JOIN categories c ON p.category_id = c.id WHERE p.price < 50", + condition: "c.parent_id = 1", + connector: "AND", + want: "select `p`.`name`,`c`.`category_name` from `products` as `p` join `categories` as `c` on `p`.`category_id`=`c`.`id` where `p`.`price`<50 and `c`.`parent_id`=1", + }, + + // test - case sensitive + { + name: "Mixed case connector", + sql: "SELECT * FROM users WHERE age > 18", + condition: "status = 'active'", + connector: "aNd", + want: "", + wantErr: true, + errContains: "invalid filter operator", + }, + { + name: "Mixed case condition", + sql: "SELECT * FROM products", + condition: "CaTegorY = 'ELECTRONICS'", + connector: "AND", + want: "select * from `products` where `category`='electronics'", + }, + + // test - error case + { + name: "Empty SQL", + sql: "", + condition: "id = 1", + connector: "AND", + wantErr: true, + errContains: "empty SQL statement", + }, + { + name: "Empty condition", + sql: "SELECT * FROM users", + condition: "", + connector: "AND", + wantErr: true, + errContains: "empty filter condition", + }, + { + name: "Invalid connector", + sql: "SELECT * FROM users", + condition: "is_active = true", + connector: "", + wantErr: true, + errContains: "invalid filter operator", + }, + { + name: "Unsupported statement type", + sql: "CREATE TABLE users (id INT, name VARCHAR(255))", + condition: "id > 0", + connector: "AND", + wantErr: true, + errContains: "only support SELECT/UPDATE/DELETE", + }, + { + name: "Malformed SQL", + sql: "SELECTZ * FRON users", + condition: "id = 1", + connector: "AND", + wantErr: true, + errContains: "failed to parse SQL", + }, + { + name: "Malformed condition", + sql: "SELECT * FROM users", + condition: "id ==", + connector: "AND", + wantErr: true, + errContains: "parse filter condition failed", + }, + } + + // run case + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parser.AppendSQLFilter(tt.sql, sqlparser.SQLFilterOp(tt.connector), tt.condition) + + if tt.wantErr { + if err == nil { + t.Fatal("Expected error, got nil") + } + if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("Expected error to contain %q, got %q", tt.errContains, err.Error()) + } + return + } + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + normalizedResult := strings.ToLower(strings.Join(strings.Fields(result), " ")) + normalizedWant := strings.ToLower(strings.Join(strings.Fields(tt.want), " ")) + if !strings.EqualFold(normalizedResult, normalizedWant) { + t.Errorf("Result mismatch:\nWant: %s\nGot: %s", normalizedWant, normalizedResult) + } + }) + } + +}