fix: When agents use data tables, different users influence each other (#565)
This commit is contained in:
parent
8b91a640b9
commit
60285ca014
|
|
@ -1075,7 +1075,16 @@ func (d databaseService) executeCustomSQL(ctx context.Context, req *ExecuteSQLRe
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("parse sql failed: %v", err)
|
||||
}
|
||||
|
||||
// add rw mode
|
||||
if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite && len(req.UserID) != 0 {
|
||||
switch operation {
|
||||
case sqlparsercontract.OperationTypeSelect, sqlparsercontract.OperationTypeUpdate, sqlparsercontract.OperationTypeDelete:
|
||||
parsedSQL, err = sqlparser.NewSQLParser().AppendSQLFilter(parsedSQL, sqlparsercontract.SQLFilterOpAnd, fmt.Sprintf("%s = '%s'", database.DefaultUidColName, req.UserID))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("append sql filter failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
insertResult := make([]map[string]interface{}, 0)
|
||||
if operation == sqlparsercontract.OperationTypeInsert {
|
||||
cid := consts.CozeConnectorID
|
||||
|
|
|
|||
|
|
@ -48,6 +48,13 @@ const (
|
|||
OperationTypeUnknown OperationType = "UNKNOWN"
|
||||
)
|
||||
|
||||
type SQLFilterOp string
|
||||
|
||||
const (
|
||||
SQLFilterOpAnd SQLFilterOp = "AND"
|
||||
SQLFilterOpOr SQLFilterOp = "OR"
|
||||
)
|
||||
|
||||
// SQLParser defines the interface for parsing and modifying SQL statements
|
||||
type SQLParser interface {
|
||||
// ParseAndModifySQL parses SQL and replaces table/column names according to the provided message
|
||||
|
|
@ -64,4 +71,7 @@ type SQLParser interface {
|
|||
|
||||
// GetInsertDataNums extracts the number of rows to be inserted from a SQL statement. Only supports single-table insert.
|
||||
GetInsertDataNums(sql string) (int, error)
|
||||
|
||||
// AppendSQLFilter appends a filter condition to the SQL statement.
|
||||
AppendSQLFilter(sql string, op SQLFilterOp, filter string) (string, error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import (
|
|||
"github.com/pingcap/tidb/pkg/parser/ast"
|
||||
"github.com/pingcap/tidb/pkg/parser/format"
|
||||
"github.com/pingcap/tidb/pkg/parser/mysql"
|
||||
"github.com/pingcap/tidb/pkg/parser/opcode"
|
||||
_ "github.com/pingcap/tidb/pkg/parser/test_driver"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/sqlparser"
|
||||
|
|
@ -411,3 +412,81 @@ func (p *Impl) GetInsertDataNums(sql string) (int, error) {
|
|||
|
||||
return len(insert.Lists), nil
|
||||
}
|
||||
func (p *Impl) AppendSQLFilter(sql string, op sqlparser.SQLFilterOp, filter string) (string, error) {
|
||||
if sql == "" {
|
||||
return "", fmt.Errorf("empty SQL statement")
|
||||
}
|
||||
if op == "" || (op != sqlparser.SQLFilterOpAnd && op != sqlparser.SQLFilterOpOr) {
|
||||
return "", fmt.Errorf("invalid filter operator: %s", op)
|
||||
}
|
||||
if filter == "" {
|
||||
return "", fmt.Errorf("empty filter condition")
|
||||
}
|
||||
stmtNode, err := p.parser.ParseOneStmt(sql, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse SQL: %v", err)
|
||||
}
|
||||
// extract WHERE clause
|
||||
var originalWhere ast.ExprNode
|
||||
switch stmt := stmtNode.(type) {
|
||||
case *ast.SelectStmt:
|
||||
originalWhere = stmt.Where
|
||||
case *ast.UpdateStmt:
|
||||
originalWhere = stmt.Where
|
||||
case *ast.DeleteStmt:
|
||||
originalWhere = stmt.Where
|
||||
default:
|
||||
return "", fmt.Errorf("append filter condition failed: only support SELECT/UPDATE/DELETE")
|
||||
}
|
||||
tmpSQL := fmt.Sprintf("SELECT * FROM tmp WHERE %s", filter)
|
||||
tmpNode, err := p.parser.ParseOneStmt(tmpSQL, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse filter condition failed: %v", err)
|
||||
}
|
||||
newExpr := tmpNode.(*ast.SelectStmt).Where
|
||||
mergedExpr := mergeExpr(originalWhere, newExpr, op)
|
||||
// update AST
|
||||
switch stmt := stmtNode.(type) {
|
||||
case *ast.SelectStmt:
|
||||
stmt.Where = mergedExpr
|
||||
case *ast.UpdateStmt:
|
||||
stmt.Where = mergedExpr
|
||||
case *ast.DeleteStmt:
|
||||
stmt.Where = mergedExpr
|
||||
}
|
||||
|
||||
// regenerate SQL
|
||||
var sb strings.Builder
|
||||
flags := format.RestoreStringSingleQuotes | format.RestoreStringWithoutCharset | format.RestoreNameBackQuotes
|
||||
restoreCtx := format.NewRestoreCtx(flags, &sb)
|
||||
if err := stmtNode.Restore(restoreCtx); err != nil {
|
||||
return "", fmt.Errorf("gen SQL failed: %v", err)
|
||||
}
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
func mergeExpr(left, right ast.ExprNode, op sqlparser.SQLFilterOp) ast.ExprNode {
|
||||
if left == nil {
|
||||
return right
|
||||
}
|
||||
if right == nil {
|
||||
return left
|
||||
}
|
||||
|
||||
switch op {
|
||||
case sqlparser.SQLFilterOpAnd:
|
||||
return &ast.BinaryOperationExpr{
|
||||
Op: opcode.LogicAnd,
|
||||
L: left,
|
||||
R: right,
|
||||
}
|
||||
case sqlparser.SQLFilterOpOr:
|
||||
return &ast.BinaryOperationExpr{
|
||||
Op: opcode.LogicOr,
|
||||
L: left,
|
||||
R: right,
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -548,3 +548,206 @@ func TestGetInsertDataNums(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendSQLFilter(t *testing.T) {
|
||||
parser := NewSQLParser().(*Impl)
|
||||
tests := []struct {
|
||||
name string
|
||||
sql string
|
||||
condition string
|
||||
connector string
|
||||
want string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
// tset - SELECT
|
||||
{
|
||||
name: "SELECT - add AND to existing WHERE",
|
||||
sql: "SELECT * FROM users WHERE age > 18",
|
||||
condition: "status = 'active'",
|
||||
connector: "AND",
|
||||
want: " select * from `users` where `age`>18 and `status`='active'",
|
||||
},
|
||||
{
|
||||
name: "SELECT - add OR to existing WHERE",
|
||||
sql: "SELECT * FROM products WHERE price < 100",
|
||||
condition: "category = 'electronics'",
|
||||
connector: "OR",
|
||||
want: "select * from `products` where `price`<100 or `category`='electronics'",
|
||||
},
|
||||
{
|
||||
name: "SELECT - add AND to multiple conditions",
|
||||
sql: "SELECT * FROM orders WHERE total > 50 AND status = 'completed'",
|
||||
condition: "customer_id = 123",
|
||||
connector: "AND",
|
||||
want: "select * from `orders` where `total`>50 and `status`='completed' and `customer_id`=123",
|
||||
},
|
||||
{
|
||||
name: "SELECT - add condition without WHERE",
|
||||
sql: "SELECT id, name FROM customers",
|
||||
condition: "is_verified = 1",
|
||||
connector: "AND",
|
||||
want: "select `id`,`name` from `customers` where `is_verified`=1",
|
||||
},
|
||||
|
||||
// tset - UPDATE
|
||||
{
|
||||
name: "UPDATE - add AND condition",
|
||||
sql: "UPDATE users SET last_login = NOW() WHERE id = 42",
|
||||
condition: "is_active = true",
|
||||
connector: "AND",
|
||||
want: "update `users` set `last_login`=now() where `id`=42 and `is_active`=true",
|
||||
},
|
||||
{
|
||||
name: "UPDATE - add OR condition without WHERE",
|
||||
sql: "UPDATE products SET discount = 0.1",
|
||||
condition: "inventory > 100",
|
||||
connector: "OR",
|
||||
want: "update `products` set `discount`=0.1 where `inventory`>100",
|
||||
},
|
||||
|
||||
// tset - DELETE
|
||||
{
|
||||
name: "DELETE - add AND condition",
|
||||
sql: "DELETE FROM logs WHERE created_at < '2023-01-01'",
|
||||
condition: "severity = 'DEBUG'",
|
||||
connector: "AND",
|
||||
want: "delete from `logs` where `created_at`<'2023-01-01' and `severity`='debug'",
|
||||
},
|
||||
{
|
||||
name: "DELETE - add OR condition",
|
||||
sql: "DELETE FROM sessions WHERE expires_at < NOW()",
|
||||
condition: "invalid = true",
|
||||
connector: "OR",
|
||||
want: "delete from `sessions` where `expires_at`<now() or `invalid`=true",
|
||||
},
|
||||
|
||||
// tset - complex expr
|
||||
{
|
||||
name: "Complex condition with parentheses",
|
||||
sql: "SELECT * FROM orders WHERE `status` = 'shipped'",
|
||||
condition: "(total > 100 OR priority = 1)",
|
||||
connector: "AND",
|
||||
want: "select * from `orders` where `status`='shipped' and (`total`>100 or `priority`=1)",
|
||||
},
|
||||
{
|
||||
name: "Add condition to existing parentheses",
|
||||
sql: "SELECT * FROM users WHERE (age > 18 OR parent_consent = true) AND country = 'US'",
|
||||
condition: "is_verified = 1",
|
||||
connector: "AND",
|
||||
want: "select * from `users` where (`age`>18 or `parent_consent`=true) and `country`='us' and `is_verified`=1",
|
||||
},
|
||||
|
||||
// tset - JOIN
|
||||
{
|
||||
name: "SELECT with JOIN",
|
||||
sql: "SELECT u.name, o.total FROM users u JOIN orders o ON u.id = o.user_id WHERE u.country = 'US'",
|
||||
condition: "o.status = 'completed'",
|
||||
connector: "AND",
|
||||
want: "select `u`.`name`,`o`.`total` from `users` as `u` join `orders` as `o` on `u`.`id`=`o`.`user_id` where `u`.`country`='us' and `o`.`status`='completed'",
|
||||
},
|
||||
{
|
||||
name: "SELECT with multiple joins",
|
||||
sql: "SELECT p.name, c.category_name FROM products p JOIN categories c ON p.category_id = c.id WHERE p.price < 50",
|
||||
condition: "c.parent_id = 1",
|
||||
connector: "AND",
|
||||
want: "select `p`.`name`,`c`.`category_name` from `products` as `p` join `categories` as `c` on `p`.`category_id`=`c`.`id` where `p`.`price`<50 and `c`.`parent_id`=1",
|
||||
},
|
||||
|
||||
// test - case sensitive
|
||||
{
|
||||
name: "Mixed case connector",
|
||||
sql: "SELECT * FROM users WHERE age > 18",
|
||||
condition: "status = 'active'",
|
||||
connector: "aNd",
|
||||
want: "",
|
||||
wantErr: true,
|
||||
errContains: "invalid filter operator",
|
||||
},
|
||||
{
|
||||
name: "Mixed case condition",
|
||||
sql: "SELECT * FROM products",
|
||||
condition: "CaTegorY = 'ELECTRONICS'",
|
||||
connector: "AND",
|
||||
want: "select * from `products` where `category`='electronics'",
|
||||
},
|
||||
|
||||
// test - error case
|
||||
{
|
||||
name: "Empty SQL",
|
||||
sql: "",
|
||||
condition: "id = 1",
|
||||
connector: "AND",
|
||||
wantErr: true,
|
||||
errContains: "empty SQL statement",
|
||||
},
|
||||
{
|
||||
name: "Empty condition",
|
||||
sql: "SELECT * FROM users",
|
||||
condition: "",
|
||||
connector: "AND",
|
||||
wantErr: true,
|
||||
errContains: "empty filter condition",
|
||||
},
|
||||
{
|
||||
name: "Invalid connector",
|
||||
sql: "SELECT * FROM users",
|
||||
condition: "is_active = true",
|
||||
connector: "",
|
||||
wantErr: true,
|
||||
errContains: "invalid filter operator",
|
||||
},
|
||||
{
|
||||
name: "Unsupported statement type",
|
||||
sql: "CREATE TABLE users (id INT, name VARCHAR(255))",
|
||||
condition: "id > 0",
|
||||
connector: "AND",
|
||||
wantErr: true,
|
||||
errContains: "only support SELECT/UPDATE/DELETE",
|
||||
},
|
||||
{
|
||||
name: "Malformed SQL",
|
||||
sql: "SELECTZ * FRON users",
|
||||
condition: "id = 1",
|
||||
connector: "AND",
|
||||
wantErr: true,
|
||||
errContains: "failed to parse SQL",
|
||||
},
|
||||
{
|
||||
name: "Malformed condition",
|
||||
sql: "SELECT * FROM users",
|
||||
condition: "id ==",
|
||||
connector: "AND",
|
||||
wantErr: true,
|
||||
errContains: "parse filter condition failed",
|
||||
},
|
||||
}
|
||||
|
||||
// run case
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := parser.AppendSQLFilter(tt.sql, sqlparser.SQLFilterOp(tt.connector), tt.condition)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
|
||||
t.Errorf("Expected error to contain %q, got %q", tt.errContains, err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
normalizedResult := strings.ToLower(strings.Join(strings.Fields(result), " "))
|
||||
normalizedWant := strings.ToLower(strings.Join(strings.Fields(tt.want), " "))
|
||||
if !strings.EqualFold(normalizedResult, normalizedWant) {
|
||||
t.Errorf("Result mismatch:\nWant: %s\nGot: %s", normalizedWant, normalizedResult)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue