996 lines
		
	
	
		
			27 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			996 lines
		
	
	
		
			27 KiB
		
	
	
	
		
			Go
		
	
	
	
/*
 * Copyright 2025 coze-dev Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
 | |
| 
 | |
| package rdb
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"fmt"
 | |
| 	"reflect"
 | |
| 	"strings"
 | |
| 
 | |
| 	"gorm.io/gorm"
 | |
| 
 | |
| 	"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
 | |
| 	"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
 | |
| 	entity2 "github.com/coze-dev/coze-studio/backend/infra/contract/rdb/entity"
 | |
| 	sqlparsercontract "github.com/coze-dev/coze-studio/backend/infra/contract/sqlparser"
 | |
| 	"github.com/coze-dev/coze-studio/backend/infra/impl/sqlparser"
 | |
| 	"github.com/coze-dev/coze-studio/backend/pkg/logs"
 | |
| )
 | |
| 
 | |
| type mysqlService struct {
 | |
| 	db        *gorm.DB
 | |
| 	generator idgen.IDGenerator
 | |
| }
 | |
| 
 | |
| func NewService(db *gorm.DB, generator idgen.IDGenerator) rdb.RDB {
 | |
| 	return &mysqlService{db: db, generator: generator}
 | |
| }
 | |
| 
 | |
| // CreateTable create table
 | |
| func (m *mysqlService) CreateTable(ctx context.Context, req *rdb.CreateTableRequest) (*rdb.CreateTableResponse, error) {
 | |
| 	if req == nil || req.Table == nil {
 | |
| 		return nil, fmt.Errorf("invalid request")
 | |
| 	}
 | |
| 
 | |
| 	// build column definitions
 | |
| 	columnDefs := make([]string, 0, len(req.Table.Columns))
 | |
| 	for _, col := range req.Table.Columns {
 | |
| 		colDef := fmt.Sprintf("`%s` %s", col.Name, col.DataType)
 | |
| 
 | |
| 		if col.Length != nil {
 | |
| 			colDef += fmt.Sprintf("(%d)", *col.Length)
 | |
| 		} else if col.Length == nil && col.DataType == entity2.TypeVarchar {
 | |
| 			colDef += fmt.Sprintf("(%d)", 255)
 | |
| 		}
 | |
| 
 | |
| 		if col.NotNull {
 | |
| 			colDef += " NOT NULL"
 | |
| 		}
 | |
| 		if col.DefaultValue != nil {
 | |
| 			if col.DataType == entity2.TypeTimestamp {
 | |
| 				colDef += fmt.Sprintf(" DEFAULT %s", *col.DefaultValue)
 | |
| 			} else if col.DataType == entity2.TypeText {
 | |
| 				// do nothing
 | |
| 			} else {
 | |
| 				colDef += fmt.Sprintf(" DEFAULT '%s'", *col.DefaultValue)
 | |
| 			}
 | |
| 		}
 | |
| 		if col.AutoIncrement {
 | |
| 			colDef += " AUTO_INCREMENT"
 | |
| 		}
 | |
| 		if col.Comment != nil {
 | |
| 			colDef += fmt.Sprintf(" COMMENT '%s'", *col.Comment)
 | |
| 		}
 | |
| 
 | |
| 		columnDefs = append(columnDefs, colDef)
 | |
| 	}
 | |
| 
 | |
| 	// build index definitions
 | |
| 	for _, idx := range req.Table.Indexes {
 | |
| 		var idxDef string
 | |
| 		switch idx.Type {
 | |
| 		case entity2.PrimaryKey:
 | |
| 			idxDef = fmt.Sprintf("PRIMARY KEY (`%s`)", strings.Join(idx.Columns, "`,`"))
 | |
| 		case entity2.UniqueKey:
 | |
| 			idxDef = fmt.Sprintf("UNIQUE KEY `%s` (`%s`)", idx.Name, strings.Join(idx.Columns, "`,`"))
 | |
| 		default:
 | |
| 			idxDef = fmt.Sprintf("KEY `%s` (`%s`)", idx.Name, strings.Join(idx.Columns, "`,`"))
 | |
| 		}
 | |
| 		columnDefs = append(columnDefs, idxDef)
 | |
| 	}
 | |
| 
 | |
| 	tableOptions := make([]string, 0)
 | |
| 	if req.Table.Options != nil {
 | |
| 		if req.Table.Options.Collate != nil {
 | |
| 			tableOptions = append(tableOptions, fmt.Sprintf("COLLATE=%s", *req.Table.Options.Collate))
 | |
| 		}
 | |
| 		if req.Table.Options.AutoIncrement != nil {
 | |
| 			tableOptions = append(tableOptions, fmt.Sprintf("AUTO_INCREMENT=%d", *req.Table.Options.AutoIncrement))
 | |
| 		}
 | |
| 		if req.Table.Options.Comment != nil {
 | |
| 			tableOptions = append(tableOptions, fmt.Sprintf("COMMENT='%s'", *req.Table.Options.Comment))
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	tableName := req.Table.Name
 | |
| 	if req.Table.Name == "" {
 | |
| 		genName, err := m.genTableName(ctx)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 
 | |
| 		tableName = genName
 | |
| 	}
 | |
| 
 | |
| 	createSQL := fmt.Sprintf("CREATE TABLE IF NOT EXISTS `%s` (\n  %s\n) %s",
 | |
| 		tableName,
 | |
| 		strings.Join(columnDefs, ",\n  "),
 | |
| 		strings.Join(tableOptions, " "),
 | |
| 	)
 | |
| 
 | |
| 	logs.CtxInfof(ctx, "[CreateTable] execute sql is %s, req is %v", createSQL, req)
 | |
| 
 | |
| 	err := m.db.WithContext(ctx).Exec(createSQL).Error
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("failed to create table: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	resTable := req.Table
 | |
| 	resTable.Name = tableName
 | |
| 	return &rdb.CreateTableResponse{Table: resTable}, nil
 | |
| }
 | |
| 
 | |
| // AlterTable alter table
 | |
| func (m *mysqlService) AlterTable(ctx context.Context, req *rdb.AlterTableRequest) (*rdb.AlterTableResponse, error) {
 | |
| 	if req == nil || len(req.Operations) == 0 {
 | |
| 		return nil, fmt.Errorf("invalid request")
 | |
| 	}
 | |
| 
 | |
| 	alterSQL := fmt.Sprintf("ALTER TABLE `%s`", req.TableName)
 | |
| 	operations := make([]string, 0, len(req.Operations))
 | |
| 
 | |
| 	for _, op := range req.Operations {
 | |
| 		switch op.Action {
 | |
| 		case entity2.AddColumn:
 | |
| 			if op.Column == nil {
 | |
| 				return nil, fmt.Errorf("column is required for ADD COLUMN operation")
 | |
| 			}
 | |
| 			colDef := fmt.Sprintf("ADD COLUMN `%s` %s", op.Column.Name, op.Column.DataType)
 | |
| 			if op.Column.Length != nil {
 | |
| 				colDef += fmt.Sprintf("(%d)", *op.Column.Length)
 | |
| 			} else if op.Column.Length == nil && op.Column.DataType == entity2.TypeVarchar {
 | |
| 				colDef += fmt.Sprintf("(%d)", 255)
 | |
| 			}
 | |
| 
 | |
| 			if op.Column.NotNull {
 | |
| 				colDef += " NOT NULL"
 | |
| 			}
 | |
| 
 | |
| 			if op.Column.DefaultValue != nil {
 | |
| 				if op.Column.DataType == entity2.TypeTimestamp {
 | |
| 					colDef += fmt.Sprintf(" DEFAULT %s", *op.Column.DefaultValue)
 | |
| 				} else {
 | |
| 					colDef += fmt.Sprintf(" DEFAULT '%s'", *op.Column.DefaultValue)
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			operations = append(operations, colDef)
 | |
| 
 | |
| 		case entity2.DropColumn:
 | |
| 			if op.Column == nil {
 | |
| 				return nil, fmt.Errorf("column is required for DROP COLUMN operation")
 | |
| 			}
 | |
| 			operations = append(operations, fmt.Sprintf("DROP COLUMN `%s`", op.Column.Name))
 | |
| 
 | |
| 		case entity2.ModifyColumn:
 | |
| 			if op.Column == nil {
 | |
| 				return nil, fmt.Errorf("column is required for MODIFY COLUMN operation")
 | |
| 			}
 | |
| 			colDef := fmt.Sprintf("MODIFY COLUMN `%s` %s", op.Column.Name, op.Column.DataType)
 | |
| 			if op.Column.Length != nil {
 | |
| 				colDef += fmt.Sprintf("(%d)", *op.Column.Length)
 | |
| 			} else if op.Column.Length == nil && op.Column.DataType == entity2.TypeVarchar {
 | |
| 				colDef += fmt.Sprintf("(%d)", 255)
 | |
| 			}
 | |
| 			operations = append(operations, colDef)
 | |
| 
 | |
| 		case entity2.RenameColumn:
 | |
| 			if op.Column == nil || op.OldName == nil {
 | |
| 				return nil, fmt.Errorf("column and old name are required for RENAME COLUMN operation")
 | |
| 			}
 | |
| 			operations = append(operations, fmt.Sprintf("RENAME COLUMN `%s` TO `%s`", *op.OldName, op.Column.Name))
 | |
| 
 | |
| 		case entity2.AddIndex:
 | |
| 			if op.Index == nil {
 | |
| 				return nil, fmt.Errorf("index is required for ADD INDEX operation")
 | |
| 			}
 | |
| 			var idxDef string
 | |
| 			switch op.Index.Type {
 | |
| 			case entity2.PrimaryKey:
 | |
| 				idxDef = fmt.Sprintf("ADD PRIMARY KEY (`%s`)", strings.Join(op.Index.Columns, "`,`"))
 | |
| 			case entity2.UniqueKey:
 | |
| 				idxDef = fmt.Sprintf("ADD UNIQUE INDEX `%s` (`%s`)", op.Index.Name, strings.Join(op.Index.Columns, "`,`"))
 | |
| 			default:
 | |
| 				idxDef = fmt.Sprintf("ADD INDEX `%s` (`%s`)", op.Index.Name, strings.Join(op.Index.Columns, "`,`"))
 | |
| 			}
 | |
| 			operations = append(operations, idxDef)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	alterSQL += " " + strings.Join(operations, ", ")
 | |
| 
 | |
| 	logs.CtxInfof(ctx, "[AlterTable] execute sql is %s, req is %v", alterSQL, req)
 | |
| 
 | |
| 	err := m.db.WithContext(ctx).Exec(alterSQL).Error
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("failed to alter table: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	table, err := m.getTableInfo(ctx, req.TableName)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("failed to get table info: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	return &rdb.AlterTableResponse{Table: table}, nil
 | |
| }
 | |
| 
 | |
| // DropTable drop table
 | |
| func (m *mysqlService) DropTable(ctx context.Context, req *rdb.DropTableRequest) (*rdb.DropTableResponse, error) {
 | |
| 	if req == nil {
 | |
| 		return nil, fmt.Errorf("invalid request")
 | |
| 	}
 | |
| 
 | |
| 	dropSQL := "DROP TABLE"
 | |
| 	if req.IfExists {
 | |
| 		dropSQL += " IF EXISTS"
 | |
| 	}
 | |
| 	dropSQL += fmt.Sprintf(" `%s`", req.TableName)
 | |
| 
 | |
| 	logs.CtxInfof(ctx, "[DropTable] execute sql is %s, req is %v", dropSQL, req)
 | |
| 
 | |
| 	err := m.db.WithContext(ctx).Exec(dropSQL).Error
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("failed to drop table: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	return &rdb.DropTableResponse{Success: true}, nil
 | |
| }
 | |
| 
 | |
| // GetTable get table schema info
 | |
| func (m *mysqlService) GetTable(ctx context.Context, req *rdb.GetTableRequest) (*rdb.GetTableResponse, error) {
 | |
| 	if req == nil {
 | |
| 		return nil, fmt.Errorf("invalid request")
 | |
| 	}
 | |
| 
 | |
| 	table, err := m.getTableInfo(ctx, req.TableName)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &rdb.GetTableResponse{Table: table}, nil
 | |
| }
 | |
| 
 | |
| func (m *mysqlService) InsertData(ctx context.Context, req *rdb.InsertDataRequest) (*rdb.InsertDataResponse, error) {
 | |
| 	if req == nil || len(req.Data) == 0 {
 | |
| 		return nil, fmt.Errorf("invalid request")
 | |
| 	}
 | |
| 
 | |
| 	fields := make([]string, 0)
 | |
| 	for field := range req.Data[0] {
 | |
| 		fields = append(fields, field)
 | |
| 	}
 | |
| 
 | |
| 	const batchSize = 1000
 | |
| 	var totalAffected int64
 | |
| 
 | |
| 	for i := 0; i < len(req.Data); i += batchSize {
 | |
| 		end := i + batchSize
 | |
| 		if end > len(req.Data) {
 | |
| 			end = len(req.Data)
 | |
| 		}
 | |
| 
 | |
| 		currentBatch := req.Data[i:end]
 | |
| 
 | |
| 		placeholderGroups := make([]string, 0, len(currentBatch))
 | |
| 		values := make([]interface{}, 0, len(currentBatch)*len(fields))
 | |
| 
 | |
| 		for _, row := range currentBatch {
 | |
| 			placeholders := make([]string, len(fields))
 | |
| 			for j := range placeholders {
 | |
| 				placeholders[j] = "?"
 | |
| 			}
 | |
| 			placeholderGroups = append(placeholderGroups, "("+strings.Join(placeholders, ",")+")")
 | |
| 
 | |
| 			for _, field := range fields {
 | |
| 				values = append(values, row[field])
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		insertSQL := fmt.Sprintf("INSERT INTO `%s` (`%s`) VALUES %s",
 | |
| 			req.TableName,
 | |
| 			strings.Join(fields, "`,`"),
 | |
| 			strings.Join(placeholderGroups, ","),
 | |
| 		)
 | |
| 
 | |
| 		logs.CtxInfof(ctx, "[InsertData] execute sql is %s, value is %v in batch %d", insertSQL, values, i)
 | |
| 
 | |
| 		result := m.db.WithContext(ctx).Exec(insertSQL, values...)
 | |
| 		if result.Error != nil {
 | |
| 			return nil, result.Error
 | |
| 		}
 | |
| 
 | |
| 		affected := result.RowsAffected
 | |
| 		totalAffected += affected
 | |
| 	}
 | |
| 
 | |
| 	return &rdb.InsertDataResponse{AffectedRows: totalAffected}, nil
 | |
| }
 | |
| 
 | |
| // UpdateData Update data
 | |
| func (m *mysqlService) UpdateData(ctx context.Context, req *rdb.UpdateDataRequest) (*rdb.UpdateDataResponse, error) {
 | |
| 	if req == nil {
 | |
| 		return nil, fmt.Errorf("invalid request")
 | |
| 	}
 | |
| 
 | |
| 	setClauses := make([]string, 0)
 | |
| 	values := make([]interface{}, 0)
 | |
| 	for field, value := range req.Data {
 | |
| 		setClauses = append(setClauses, fmt.Sprintf("`%s` = ?", field))
 | |
| 		values = append(values, value)
 | |
| 	}
 | |
| 
 | |
| 	whereClause, whereValues, err := m.buildWhereClause(req.Where)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("failed to build where clause: %v", err)
 | |
| 	}
 | |
| 	values = append(values, whereValues...)
 | |
| 
 | |
| 	limitClause := ""
 | |
| 	if req.Limit != nil {
 | |
| 		limitClause = fmt.Sprintf(" LIMIT %d", *req.Limit)
 | |
| 	}
 | |
| 
 | |
| 	updateSQL := fmt.Sprintf("UPDATE `%s` SET %s%s%s",
 | |
| 		req.TableName,
 | |
| 		strings.Join(setClauses, ", "),
 | |
| 		whereClause,
 | |
| 		limitClause,
 | |
| 	)
 | |
| 
 | |
| 	logs.CtxInfof(ctx, "[UpdateData] execute sql is %s, value is %v, req is %v", updateSQL, values, req)
 | |
| 
 | |
| 	result := m.db.WithContext(ctx).Exec(updateSQL, values...)
 | |
| 	if result.Error != nil {
 | |
| 		return nil, result.Error
 | |
| 	}
 | |
| 
 | |
| 	affectedRows := result.RowsAffected
 | |
| 
 | |
| 	return &rdb.UpdateDataResponse{AffectedRows: affectedRows}, nil
 | |
| }
 | |
| 
 | |
| // DeleteData delete data
 | |
| func (m *mysqlService) DeleteData(ctx context.Context, req *rdb.DeleteDataRequest) (*rdb.DeleteDataResponse, error) {
 | |
| 	if req == nil {
 | |
| 		return nil, fmt.Errorf("invalid request")
 | |
| 	}
 | |
| 
 | |
| 	whereClause, whereValues, err := m.buildWhereClause(req.Where)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("failed to build where clause: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	limitClause := ""
 | |
| 	if req.Limit != nil {
 | |
| 		limitClause = fmt.Sprintf(" LIMIT %d", *req.Limit)
 | |
| 	}
 | |
| 
 | |
| 	deleteSQL := fmt.Sprintf("DELETE FROM `%s`%s%s",
 | |
| 		req.TableName,
 | |
| 		whereClause,
 | |
| 		limitClause,
 | |
| 	)
 | |
| 
 | |
| 	logs.CtxInfof(ctx, "[DeleteData] execute sql is %s, value is %v, req is %v", deleteSQL, whereValues, req)
 | |
| 
 | |
| 	result := m.db.WithContext(ctx).Exec(deleteSQL, whereValues...)
 | |
| 	if result.Error != nil {
 | |
| 		return nil, fmt.Errorf("failed to delete data: %v", result.Error)
 | |
| 	}
 | |
| 
 | |
| 	affectedRows := result.RowsAffected
 | |
| 
 | |
| 	return &rdb.DeleteDataResponse{AffectedRows: affectedRows}, nil
 | |
| }
 | |
| 
 | |
| // SelectData select data
 | |
| func (m *mysqlService) SelectData(ctx context.Context, req *rdb.SelectDataRequest) (*rdb.SelectDataResponse, error) {
 | |
| 	if req == nil {
 | |
| 		return nil, fmt.Errorf("invalid request")
 | |
| 	}
 | |
| 
 | |
| 	fields := "*"
 | |
| 	if len(req.Fields) > 0 {
 | |
| 		fields = strings.Join(req.Fields, ", ")
 | |
| 	}
 | |
| 
 | |
| 	whereClause := ""
 | |
| 	whereValues := make([]interface{}, 0)
 | |
| 	if req.Where != nil {
 | |
| 		clause, values, err := m.buildWhereClause(req.Where)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("failed to build where clause: %v", err)
 | |
| 		}
 | |
| 		whereClause = clause
 | |
| 		whereValues = values
 | |
| 	}
 | |
| 
 | |
| 	orderByClause := ""
 | |
| 	if len(req.OrderBy) > 0 {
 | |
| 		orders := make([]string, len(req.OrderBy))
 | |
| 		for i, order := range req.OrderBy {
 | |
| 			orders[i] = fmt.Sprintf("%s %s", order.Field, order.Direction)
 | |
| 		}
 | |
| 		orderByClause = " ORDER BY " + strings.Join(orders, ", ")
 | |
| 	}
 | |
| 
 | |
| 	limitClause := ""
 | |
| 	if req.Limit != nil {
 | |
| 		limitClause = fmt.Sprintf(" LIMIT %d", *req.Limit)
 | |
| 		if req.Offset != nil {
 | |
| 			limitClause += fmt.Sprintf(" OFFSET %d", *req.Offset)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	selectSQL := fmt.Sprintf("SELECT %s FROM `%s`%s%s%s",
 | |
| 		fields,
 | |
| 		req.TableName,
 | |
| 		whereClause,
 | |
| 		orderByClause,
 | |
| 		limitClause,
 | |
| 	)
 | |
| 
 | |
| 	logs.CtxInfof(ctx, "[SelectData] execute sql is %s, value is %v, req is %v", selectSQL, whereValues, req)
 | |
| 
 | |
| 	rows, err := m.db.WithContext(ctx).Raw(selectSQL, whereValues...).Rows()
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("failed to execute select: %v", err)
 | |
| 	}
 | |
| 	defer rows.Close()
 | |
| 
 | |
| 	columns, err := rows.Columns()
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("failed to get columns: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	resultSet := &entity2.ResultSet{
 | |
| 		Columns: columns,
 | |
| 		Rows:    make([]map[string]interface{}, 0),
 | |
| 	}
 | |
| 
 | |
| 	for rows.Next() {
 | |
| 		values := make([]interface{}, len(columns))
 | |
| 		valuePtrs := make([]interface{}, len(columns))
 | |
| 		for i := range values {
 | |
| 			valuePtrs[i] = &values[i]
 | |
| 		}
 | |
| 
 | |
| 		if err := rows.Scan(valuePtrs...); err != nil {
 | |
| 			return nil, fmt.Errorf("failed to scan row: %v", err)
 | |
| 		}
 | |
| 
 | |
| 		rowData := make(map[string]interface{})
 | |
| 		for i, col := range columns {
 | |
| 			rowData[col] = values[i]
 | |
| 		}
 | |
| 		resultSet.Rows = append(resultSet.Rows, rowData)
 | |
| 	}
 | |
| 
 | |
| 	// get total count
 | |
| 	var total int64
 | |
| 	if whereClause != "" {
 | |
| 		countSQL := fmt.Sprintf("SELECT COUNT(*) FROM `%s`%s", req.TableName, whereClause)
 | |
| 		err = m.db.WithContext(ctx).Raw(countSQL, whereValues...).Scan(&total).Error
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("failed to get total count: %v", err)
 | |
| 		}
 | |
| 	} else {
 | |
| 		total = int64(len(resultSet.Rows))
 | |
| 	}
 | |
| 
 | |
| 	return &rdb.SelectDataResponse{
 | |
| 		ResultSet: resultSet,
 | |
| 		Total:     total,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| // UpsertData upsert data
 | |
| func (m *mysqlService) UpsertData(ctx context.Context, req *rdb.UpsertDataRequest) (*rdb.UpsertDataResponse, error) {
 | |
| 	if req == nil || len(req.Data) == 0 {
 | |
| 		return nil, fmt.Errorf("invalid request: empty data")
 | |
| 	}
 | |
| 
 | |
| 	keys := req.Keys
 | |
| 	if len(keys) == 0 {
 | |
| 		primaryKeys, err := m.getTablePrimaryKeys(ctx, req.TableName)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("failed to get primary keys: %v", err)
 | |
| 		}
 | |
| 
 | |
| 		if len(primaryKeys) == 0 {
 | |
| 			return nil, fmt.Errorf("table %s has no primary key, keys are required for upsert operation", req.TableName)
 | |
| 		}
 | |
| 
 | |
| 		keys = primaryKeys
 | |
| 	}
 | |
| 
 | |
| 	fields := make([]string, 0)
 | |
| 	for field := range req.Data[0] {
 | |
| 		fields = append(fields, field)
 | |
| 	}
 | |
| 
 | |
| 	const batchSize = 1000
 | |
| 	var totalAffected, totalInserted, totalUpdated int64
 | |
| 
 | |
| 	for i := 0; i < len(req.Data); i += batchSize {
 | |
| 		end := i + batchSize
 | |
| 		if end > len(req.Data) {
 | |
| 			end = len(req.Data)
 | |
| 		}
 | |
| 
 | |
| 		currentBatch := req.Data[i:end]
 | |
| 
 | |
| 		placeholderGroups := make([]string, 0, len(currentBatch))
 | |
| 		values := make([]interface{}, 0, len(currentBatch)*len(fields))
 | |
| 
 | |
| 		for _, row := range currentBatch {
 | |
| 			placeholders := make([]string, len(fields))
 | |
| 			for j := range placeholders {
 | |
| 				placeholders[j] = "?"
 | |
| 			}
 | |
| 			placeholderGroups = append(placeholderGroups, "("+strings.Join(placeholders, ",")+")")
 | |
| 
 | |
| 			for _, field := range fields {
 | |
| 				values = append(values, row[field])
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		// ON DUPLICATE KEY UPDATE部分
 | |
| 		updateClauses := make([]string, 0, len(fields))
 | |
| 		for _, field := range fields {
 | |
| 			isKey := false
 | |
| 			for _, key := range keys {
 | |
| 				if field == key {
 | |
| 					isKey = true
 | |
| 					break
 | |
| 				}
 | |
| 			}
 | |
| 			if !isKey {
 | |
| 				updateClauses = append(updateClauses, fmt.Sprintf("`%s`=VALUES(`%s`)", field, field))
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		upsertSQL := fmt.Sprintf(
 | |
| 			"INSERT INTO `%s` (`%s`) VALUES %s ON DUPLICATE KEY UPDATE %s",
 | |
| 			req.TableName,
 | |
| 			strings.Join(fields, "`,`"),
 | |
| 			strings.Join(placeholderGroups, ","),
 | |
| 			strings.Join(updateClauses, ","),
 | |
| 		)
 | |
| 
 | |
| 		logs.CtxInfof(ctx, "[UpsertData] execute sql is %s, value is %v, batch is %d", upsertSQL, values, i)
 | |
| 
 | |
| 		result := m.db.WithContext(ctx).Exec(upsertSQL, values...)
 | |
| 		if result.Error != nil {
 | |
| 			return nil, fmt.Errorf("failed to upsert data: %v", result.Error)
 | |
| 		}
 | |
| 
 | |
| 		total, inserted, updated := calculateInsertedUpdated(result.RowsAffected, len(currentBatch))
 | |
| 		totalInserted += inserted
 | |
| 		totalUpdated += updated
 | |
| 		totalAffected += total
 | |
| 	}
 | |
| 
 | |
| 	return &rdb.UpsertDataResponse{
 | |
| 		AffectedRows:  totalAffected,
 | |
| 		InsertedRows:  totalInserted,
 | |
| 		UpdatedRows:   totalUpdated,
 | |
| 		UnchangedRows: int64(len(req.Data)) - totalAffected,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func (m *mysqlService) getTablePrimaryKeys(ctx context.Context, tableName string) ([]string, error) {
 | |
| 	query := `
 | |
|         SELECT COLUMN_NAME
 | |
|         FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
 | |
|         WHERE TABLE_SCHEMA = DATABASE()
 | |
|           AND TABLE_NAME = ?
 | |
|           AND CONSTRAINT_NAME = 'PRIMARY'
 | |
|         ORDER BY ORDINAL_POSITION
 | |
|     `
 | |
| 
 | |
| 	var primaryKeys []string
 | |
| 	rows, err := m.db.WithContext(ctx).Raw(query, tableName).Rows()
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	defer rows.Close()
 | |
| 
 | |
| 	for rows.Next() {
 | |
| 		var columnName string
 | |
| 		if err := rows.Scan(&columnName); err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		primaryKeys = append(primaryKeys, columnName)
 | |
| 	}
 | |
| 
 | |
| 	return primaryKeys, nil
 | |
| }
 | |
| 
 | |
// calculateInsertedUpdated splits the affected-row count of one upsert batch
// into (inserted+updated, inserted, updated).
//
// When affectedRows does not exceed batchSize, every affected row is counted
// as an insert. Otherwise the surplus over batchSize is counted as updates
// and the remainder of the batch as inserts, so the first result equals
// batchSize.
// NOTE(review): this presumably relies on the server reporting 1 per inserted
// row and 2 per updated row for ON DUPLICATE KEY UPDATE — confirm against the
// connection's affected-rows settings.
func calculateInsertedUpdated(affectedRows int64, batchSize int) (int64, int64, int64) {
	rowCount := int64(batchSize)
	if affectedRows <= rowCount {
		return affectedRows, affectedRows, 0
	}

	surplus := affectedRows - rowCount
	fresh := rowCount - surplus
	return fresh + surplus, fresh, surplus
}
 | |
| 
 | |
| // ExecuteSQL Execute SQL
 | |
| func (m *mysqlService) ExecuteSQL(ctx context.Context, req *rdb.ExecuteSQLRequest) (*rdb.ExecuteSQLResponse, error) {
 | |
| 	if req == nil {
 | |
| 		return nil, fmt.Errorf("invalid request")
 | |
| 	}
 | |
| 
 | |
| 	logs.CtxInfof(ctx, "[ExecuteSQL] req is %v", req)
 | |
| 
 | |
| 	var processedSQL string
 | |
| 	var processedParams []interface{}
 | |
| 	var err error
 | |
| 
 | |
| 	// Handle SQLType: if raw, do not process params
 | |
| 	if req.SQLType == entity2.SQLType_Raw {
 | |
| 		processedSQL = req.SQL
 | |
| 		processedParams = nil
 | |
| 	} else {
 | |
| 		processedSQL, processedParams, err = m.processSliceParams(req.SQL, req.Params)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("failed to process parameters: %v", err)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	operation, err := sqlparser.NewSQLParser().GetSQLOperation(processedSQL)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	if operation != sqlparsercontract.OperationTypeSelect {
 | |
| 		result := m.db.WithContext(ctx).Exec(processedSQL, processedParams...)
 | |
| 		if result.Error != nil {
 | |
| 			return nil, fmt.Errorf("failed to execute SQL: %v", result.Error)
 | |
| 		}
 | |
| 
 | |
| 		resultSet := &entity2.ResultSet{
 | |
| 			Columns:      []string{},
 | |
| 			Rows:         []map[string]interface{}{},
 | |
| 			AffectedRows: result.RowsAffected,
 | |
| 		}
 | |
| 
 | |
| 		return &rdb.ExecuteSQLResponse{
 | |
| 			ResultSet: resultSet,
 | |
| 		}, nil
 | |
| 	}
 | |
| 
 | |
| 	rows, err := m.db.WithContext(ctx).Raw(processedSQL, processedParams...).Rows()
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("failed to execute SQL: %v", err)
 | |
| 	}
 | |
| 	defer rows.Close()
 | |
| 
 | |
| 	columns, err := rows.Columns()
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("failed to get columns: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	resultSet := &entity2.ResultSet{
 | |
| 		Columns: columns,
 | |
| 		Rows:    make([]map[string]interface{}, 0),
 | |
| 	}
 | |
| 
 | |
| 	for rows.Next() {
 | |
| 		values := make([]interface{}, len(columns))
 | |
| 		valuePtrs := make([]interface{}, len(columns))
 | |
| 		for i := range values {
 | |
| 			valuePtrs[i] = &values[i]
 | |
| 		}
 | |
| 
 | |
| 		if err := rows.Scan(valuePtrs...); err != nil {
 | |
| 			return nil, fmt.Errorf("failed to scan row: %v", err)
 | |
| 		}
 | |
| 
 | |
| 		rowData := make(map[string]interface{})
 | |
| 		for i, col := range columns {
 | |
| 			rowData[col] = values[i]
 | |
| 		}
 | |
| 		resultSet.Rows = append(resultSet.Rows, rowData)
 | |
| 	}
 | |
| 
 | |
| 	if err := rows.Err(); err != nil {
 | |
| 		return nil, fmt.Errorf("error while reading rows: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	return &rdb.ExecuteSQLResponse{
 | |
| 		ResultSet: resultSet,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func (m *mysqlService) processSliceParams(sql string, params []interface{}) (string, []interface{}, error) {
 | |
| 	if len(params) == 0 {
 | |
| 		return sql, params, nil
 | |
| 	}
 | |
| 
 | |
| 	processedParams := make([]interface{}, 0)
 | |
| 	paramIndex := 0
 | |
| 	resultSQL := ""
 | |
| 	lastPos := 0
 | |
| 
 | |
| 	// get all ? positions
 | |
| 	for i := 0; i < len(sql); i++ {
 | |
| 		if sql[i] == '?' && paramIndex < len(params) {
 | |
| 			resultSQL += sql[lastPos:i]
 | |
| 			lastPos = i + 1
 | |
| 
 | |
| 			param := params[paramIndex]
 | |
| 			paramIndex++
 | |
| 
 | |
| 			if m.isSlice(param) {
 | |
| 				sliceValues, err := m.getSliceValues(param)
 | |
| 				if err != nil {
 | |
| 					return "", nil, err
 | |
| 				}
 | |
| 
 | |
| 				if len(sliceValues) == 0 {
 | |
| 					resultSQL += "(NULL)"
 | |
| 				} else {
 | |
| 					// (?, ?, ...)
 | |
| 					placeholders := make([]string, len(sliceValues))
 | |
| 					for j := range placeholders {
 | |
| 						placeholders[j] = "?"
 | |
| 					}
 | |
| 					resultSQL += "(" + strings.Join(placeholders, ", ") + ")"
 | |
| 
 | |
| 					processedParams = append(processedParams, sliceValues...)
 | |
| 				}
 | |
| 			} else {
 | |
| 				resultSQL += "?"
 | |
| 				processedParams = append(processedParams, param)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	resultSQL += sql[lastPos:]
 | |
| 
 | |
| 	return resultSQL, processedParams, nil
 | |
| }
 | |
| 
 | |
| func (m *mysqlService) isSlice(param interface{}) bool {
 | |
| 	if param == nil {
 | |
| 		return false
 | |
| 	}
 | |
| 
 | |
| 	rv := reflect.ValueOf(param)
 | |
| 	return rv.Kind() == reflect.Slice && rv.Type().Elem().Kind() != reflect.Uint8 // exclude []byte
 | |
| }
 | |
| 
 | |
| func (m *mysqlService) getSliceValues(param interface{}) ([]interface{}, error) {
 | |
| 	rv := reflect.ValueOf(param)
 | |
| 	if rv.Kind() != reflect.Slice {
 | |
| 		return nil, fmt.Errorf("parameter is not a slice")
 | |
| 	}
 | |
| 
 | |
| 	length := rv.Len()
 | |
| 	values := make([]interface{}, length)
 | |
| 
 | |
| 	for i := 0; i < length; i++ {
 | |
| 		values[i] = rv.Index(i).Interface()
 | |
| 	}
 | |
| 
 | |
| 	return values, nil
 | |
| }
 | |
| 
 | |
| func (m *mysqlService) genTableName(ctx context.Context) (string, error) {
 | |
| 	id, err := m.generator.GenID(ctx)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 
 | |
| 	return fmt.Sprintf("table_%d", id), nil
 | |
| }
 | |
| 
 | |
| func (m *mysqlService) getTableInfo(ctx context.Context, tableName string) (*entity2.Table, error) {
 | |
| 	tableInfoSQL := `
 | |
|         SELECT 
 | |
|             TABLE_NAME,
 | |
|             TABLE_COLLATION,
 | |
|             AUTO_INCREMENT,
 | |
|             TABLE_COMMENT
 | |
|         FROM information_schema.TABLES 
 | |
|         WHERE TABLE_SCHEMA = DATABASE() 
 | |
|         AND TABLE_NAME = ?
 | |
|     `
 | |
| 
 | |
| 	var (
 | |
| 		name          string
 | |
| 		collation     *string
 | |
| 		autoIncrement *int64
 | |
| 		comment       *string
 | |
| 	)
 | |
| 
 | |
| 	err := m.db.WithContext(ctx).Raw(tableInfoSQL, tableName).Row().Scan(
 | |
| 		&name,
 | |
| 		&collation,
 | |
| 		&autoIncrement,
 | |
| 		&comment,
 | |
| 	)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	columnsSQL := `
 | |
|         SELECT 
 | |
|             COLUMN_NAME,
 | |
|             DATA_TYPE,
 | |
|             CHARACTER_MAXIMUM_LENGTH,
 | |
|             IS_NULLABLE,
 | |
|             COLUMN_DEFAULT,
 | |
|             EXTRA,
 | |
|             COLUMN_COMMENT
 | |
|         FROM information_schema.COLUMNS 
 | |
|         WHERE TABLE_SCHEMA = DATABASE() 
 | |
|         AND TABLE_NAME = ?
 | |
|         ORDER BY ORDINAL_POSITION
 | |
|     `
 | |
| 
 | |
| 	type columnInfo struct {
 | |
| 		ColumnName    string  `gorm:"column:COLUMN_NAME"`
 | |
| 		DataType      string  `gorm:"column:DATA_TYPE"`
 | |
| 		CharLength    *int    `gorm:"column:CHARACTER_MAXIMUM_LENGTH"`
 | |
| 		IsNullable    string  `gorm:"column:IS_NULLABLE"`
 | |
| 		DefaultValue  *string `gorm:"column:COLUMN_DEFAULT"`
 | |
| 		Extra         string  `gorm:"column:EXTRA"`
 | |
| 		ColumnComment *string `gorm:"column:COLUMN_COMMENT"`
 | |
| 	}
 | |
| 
 | |
| 	var columnsData []columnInfo
 | |
| 	err = m.db.WithContext(ctx).Raw(columnsSQL, tableName).Scan(&columnsData).Error
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	columns := make([]*entity2.Column, len(columnsData))
 | |
| 	for i, colData := range columnsData {
 | |
| 		column := &entity2.Column{
 | |
| 			Name:          colData.ColumnName,
 | |
| 			DataType:      entity2.DataType(colData.DataType),
 | |
| 			Length:        colData.CharLength,
 | |
| 			NotNull:       colData.IsNullable == "NO",
 | |
| 			DefaultValue:  colData.DefaultValue,
 | |
| 			AutoIncrement: strings.Contains(colData.Extra, "auto_increment"),
 | |
| 			Comment:       colData.ColumnComment,
 | |
| 		}
 | |
| 		columns[i] = column
 | |
| 	}
 | |
| 
 | |
| 	indexesSQL := `
 | |
|         SELECT 
 | |
|             INDEX_NAME,
 | |
|             NON_UNIQUE,
 | |
|             INDEX_TYPE,
 | |
|             GROUP_CONCAT(COLUMN_NAME ORDER BY SEQ_IN_INDEX)
 | |
|         FROM information_schema.STATISTICS 
 | |
|         WHERE TABLE_SCHEMA = DATABASE() 
 | |
|         AND TABLE_NAME = ?
 | |
|         GROUP BY INDEX_NAME, NON_UNIQUE, INDEX_TYPE
 | |
|     `
 | |
| 
 | |
| 	type indexInfo struct {
 | |
| 		IndexName string `gorm:"column:INDEX_NAME"`
 | |
| 		NonUnique int    `gorm:"column:NON_UNIQUE"`
 | |
| 		IndexType string `gorm:"column:INDEX_TYPE"`
 | |
| 		Columns   string `gorm:"column:GROUP_CONCAT"`
 | |
| 	}
 | |
| 
 | |
| 	var indexesData []indexInfo
 | |
| 	err = m.db.WithContext(ctx).Raw(indexesSQL, tableName).Scan(&indexesData).Error
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	indexes := make([]*entity2.Index, 0, len(indexesData))
 | |
| 	for _, idxData := range indexesData {
 | |
| 		index := &entity2.Index{
 | |
| 			Name:    idxData.IndexName,
 | |
| 			Type:    entity2.IndexType(idxData.IndexType),
 | |
| 			Columns: strings.Split(idxData.Columns, ","),
 | |
| 		}
 | |
| 		indexes = append(indexes, index)
 | |
| 	}
 | |
| 
 | |
| 	return &entity2.Table{
 | |
| 		Name:    name,
 | |
| 		Columns: columns,
 | |
| 		Indexes: indexes,
 | |
| 		Options: &entity2.TableOption{
 | |
| 			Collate:       collation,
 | |
| 			AutoIncrement: autoIncrement,
 | |
| 			Comment:       comment,
 | |
| 		},
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func (m *mysqlService) buildWhereClause(condition *rdb.ComplexCondition) (string, []interface{}, error) {
 | |
| 	if condition == nil {
 | |
| 		return "", nil, nil
 | |
| 	}
 | |
| 
 | |
| 	if condition.Operator == "" {
 | |
| 		condition.Operator = entity2.AND
 | |
| 	}
 | |
| 
 | |
| 	var whereClause strings.Builder
 | |
| 	values := make([]interface{}, 0)
 | |
| 
 | |
| 	for i, cond := range condition.Conditions {
 | |
| 		if i > 0 {
 | |
| 			whereClause.WriteString(fmt.Sprintf(" %s ", condition.Operator))
 | |
| 		}
 | |
| 
 | |
| 		if cond.Operator == entity2.OperatorIn || cond.Operator == entity2.OperatorNotIn {
 | |
| 			if m.isSlice(cond.Value) {
 | |
| 				sliceValues, err := m.getSliceValues(cond.Value)
 | |
| 				if err != nil {
 | |
| 					return "", nil, fmt.Errorf("failed to process slice values: %v", err)
 | |
| 				}
 | |
| 
 | |
| 				if len(sliceValues) == 0 {
 | |
| 					whereClause.WriteString(fmt.Sprintf("`%s` %s (NULL)", cond.Field, string(cond.Operator)))
 | |
| 				} else {
 | |
| 					placeholders := make([]string, len(sliceValues))
 | |
| 					for i := range placeholders {
 | |
| 						placeholders[i] = "?"
 | |
| 					}
 | |
| 					whereClause.WriteString(fmt.Sprintf("`%s` %s (%s)", cond.Field, string(cond.Operator), strings.Join(placeholders, ",")))
 | |
| 
 | |
| 					values = append(values, sliceValues...)
 | |
| 				}
 | |
| 			} else {
 | |
| 				return "", nil, fmt.Errorf("IN operator requires a slice of values")
 | |
| 			}
 | |
| 		} else if cond.Operator == entity2.OperatorIsNull || cond.Operator == entity2.OperatorIsNotNull {
 | |
| 			whereClause.WriteString(fmt.Sprintf("`%s` %s", cond.Field, cond.Operator))
 | |
| 		} else {
 | |
| 			whereClause.WriteString(fmt.Sprintf("`%s` %s ?", cond.Field, cond.Operator))
 | |
| 			values = append(values, cond.Value)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if len(condition.NestedConditions) > 0 {
 | |
| 		whereClause.WriteString(" AND (")
 | |
| 		for i, nested := range condition.NestedConditions {
 | |
| 			if i > 0 {
 | |
| 				whereClause.WriteString(fmt.Sprintf(" %s ", nested.Operator))
 | |
| 			}
 | |
| 			nestedClause, nestedValues, err := m.buildWhereClause(nested)
 | |
| 			if err != nil {
 | |
| 				return "", nil, err
 | |
| 			}
 | |
| 			whereClause.WriteString(nestedClause)
 | |
| 			values = append(values, nestedValues...)
 | |
| 		}
 | |
| 		whereClause.WriteString(")")
 | |
| 	}
 | |
| 
 | |
| 	if whereClause.Len() > 0 {
 | |
| 		return " WHERE " + whereClause.String(), values, nil
 | |
| 	}
 | |
| 	return "", values, nil
 | |
| }
 |