996 lines
27 KiB
Go
996 lines
27 KiB
Go
/*
|
|
* Copyright 2025 coze-dev Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package rdb
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
|
|
"gorm.io/gorm"
|
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
|
|
entity2 "github.com/coze-dev/coze-studio/backend/infra/contract/rdb/entity"
|
|
sqlparsercontract "github.com/coze-dev/coze-studio/backend/infra/contract/sqlparser"
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/sqlparser"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
|
)
|
|
|
|
type mysqlService struct {
|
|
db *gorm.DB
|
|
generator idgen.IDGenerator
|
|
}
|
|
|
|
func NewService(db *gorm.DB, generator idgen.IDGenerator) rdb.RDB {
|
|
return &mysqlService{db: db, generator: generator}
|
|
}
|
|
|
|
// CreateTable create table
|
|
func (m *mysqlService) CreateTable(ctx context.Context, req *rdb.CreateTableRequest) (*rdb.CreateTableResponse, error) {
|
|
if req == nil || req.Table == nil {
|
|
return nil, fmt.Errorf("invalid request")
|
|
}
|
|
|
|
// build column definitions
|
|
columnDefs := make([]string, 0, len(req.Table.Columns))
|
|
for _, col := range req.Table.Columns {
|
|
colDef := fmt.Sprintf("`%s` %s", col.Name, col.DataType)
|
|
|
|
if col.Length != nil {
|
|
colDef += fmt.Sprintf("(%d)", *col.Length)
|
|
} else if col.Length == nil && col.DataType == entity2.TypeVarchar {
|
|
colDef += fmt.Sprintf("(%d)", 255)
|
|
}
|
|
|
|
if col.NotNull {
|
|
colDef += " NOT NULL"
|
|
}
|
|
if col.DefaultValue != nil {
|
|
if col.DataType == entity2.TypeTimestamp {
|
|
colDef += fmt.Sprintf(" DEFAULT %s", *col.DefaultValue)
|
|
} else if col.DataType == entity2.TypeText {
|
|
// do nothing
|
|
} else {
|
|
colDef += fmt.Sprintf(" DEFAULT '%s'", *col.DefaultValue)
|
|
}
|
|
}
|
|
if col.AutoIncrement {
|
|
colDef += " AUTO_INCREMENT"
|
|
}
|
|
if col.Comment != nil {
|
|
colDef += fmt.Sprintf(" COMMENT '%s'", *col.Comment)
|
|
}
|
|
|
|
columnDefs = append(columnDefs, colDef)
|
|
}
|
|
|
|
// build index definitions
|
|
for _, idx := range req.Table.Indexes {
|
|
var idxDef string
|
|
switch idx.Type {
|
|
case entity2.PrimaryKey:
|
|
idxDef = fmt.Sprintf("PRIMARY KEY (`%s`)", strings.Join(idx.Columns, "`,`"))
|
|
case entity2.UniqueKey:
|
|
idxDef = fmt.Sprintf("UNIQUE KEY `%s` (`%s`)", idx.Name, strings.Join(idx.Columns, "`,`"))
|
|
default:
|
|
idxDef = fmt.Sprintf("KEY `%s` (`%s`)", idx.Name, strings.Join(idx.Columns, "`,`"))
|
|
}
|
|
columnDefs = append(columnDefs, idxDef)
|
|
}
|
|
|
|
tableOptions := make([]string, 0)
|
|
if req.Table.Options != nil {
|
|
if req.Table.Options.Collate != nil {
|
|
tableOptions = append(tableOptions, fmt.Sprintf("COLLATE=%s", *req.Table.Options.Collate))
|
|
}
|
|
if req.Table.Options.AutoIncrement != nil {
|
|
tableOptions = append(tableOptions, fmt.Sprintf("AUTO_INCREMENT=%d", *req.Table.Options.AutoIncrement))
|
|
}
|
|
if req.Table.Options.Comment != nil {
|
|
tableOptions = append(tableOptions, fmt.Sprintf("COMMENT='%s'", *req.Table.Options.Comment))
|
|
}
|
|
}
|
|
|
|
tableName := req.Table.Name
|
|
if req.Table.Name == "" {
|
|
genName, err := m.genTableName(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
tableName = genName
|
|
}
|
|
|
|
createSQL := fmt.Sprintf("CREATE TABLE IF NOT EXISTS `%s` (\n %s\n) %s",
|
|
tableName,
|
|
strings.Join(columnDefs, ",\n "),
|
|
strings.Join(tableOptions, " "),
|
|
)
|
|
|
|
logs.CtxInfof(ctx, "[CreateTable] execute sql is %s, req is %v", createSQL, req)
|
|
|
|
err := m.db.WithContext(ctx).Exec(createSQL).Error
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create table: %v", err)
|
|
}
|
|
|
|
resTable := req.Table
|
|
resTable.Name = tableName
|
|
return &rdb.CreateTableResponse{Table: resTable}, nil
|
|
}
|
|
|
|
// AlterTable alter table
|
|
func (m *mysqlService) AlterTable(ctx context.Context, req *rdb.AlterTableRequest) (*rdb.AlterTableResponse, error) {
|
|
if req == nil || len(req.Operations) == 0 {
|
|
return nil, fmt.Errorf("invalid request")
|
|
}
|
|
|
|
alterSQL := fmt.Sprintf("ALTER TABLE `%s`", req.TableName)
|
|
operations := make([]string, 0, len(req.Operations))
|
|
|
|
for _, op := range req.Operations {
|
|
switch op.Action {
|
|
case entity2.AddColumn:
|
|
if op.Column == nil {
|
|
return nil, fmt.Errorf("column is required for ADD COLUMN operation")
|
|
}
|
|
colDef := fmt.Sprintf("ADD COLUMN `%s` %s", op.Column.Name, op.Column.DataType)
|
|
if op.Column.Length != nil {
|
|
colDef += fmt.Sprintf("(%d)", *op.Column.Length)
|
|
} else if op.Column.Length == nil && op.Column.DataType == entity2.TypeVarchar {
|
|
colDef += fmt.Sprintf("(%d)", 255)
|
|
}
|
|
|
|
if op.Column.NotNull {
|
|
colDef += " NOT NULL"
|
|
}
|
|
|
|
if op.Column.DefaultValue != nil {
|
|
if op.Column.DataType == entity2.TypeTimestamp {
|
|
colDef += fmt.Sprintf(" DEFAULT %s", *op.Column.DefaultValue)
|
|
} else {
|
|
colDef += fmt.Sprintf(" DEFAULT '%s'", *op.Column.DefaultValue)
|
|
}
|
|
}
|
|
|
|
operations = append(operations, colDef)
|
|
|
|
case entity2.DropColumn:
|
|
if op.Column == nil {
|
|
return nil, fmt.Errorf("column is required for DROP COLUMN operation")
|
|
}
|
|
operations = append(operations, fmt.Sprintf("DROP COLUMN `%s`", op.Column.Name))
|
|
|
|
case entity2.ModifyColumn:
|
|
if op.Column == nil {
|
|
return nil, fmt.Errorf("column is required for MODIFY COLUMN operation")
|
|
}
|
|
colDef := fmt.Sprintf("MODIFY COLUMN `%s` %s", op.Column.Name, op.Column.DataType)
|
|
if op.Column.Length != nil {
|
|
colDef += fmt.Sprintf("(%d)", *op.Column.Length)
|
|
} else if op.Column.Length == nil && op.Column.DataType == entity2.TypeVarchar {
|
|
colDef += fmt.Sprintf("(%d)", 255)
|
|
}
|
|
operations = append(operations, colDef)
|
|
|
|
case entity2.RenameColumn:
|
|
if op.Column == nil || op.OldName == nil {
|
|
return nil, fmt.Errorf("column and old name are required for RENAME COLUMN operation")
|
|
}
|
|
operations = append(operations, fmt.Sprintf("RENAME COLUMN `%s` TO `%s`", *op.OldName, op.Column.Name))
|
|
|
|
case entity2.AddIndex:
|
|
if op.Index == nil {
|
|
return nil, fmt.Errorf("index is required for ADD INDEX operation")
|
|
}
|
|
var idxDef string
|
|
switch op.Index.Type {
|
|
case entity2.PrimaryKey:
|
|
idxDef = fmt.Sprintf("ADD PRIMARY KEY (`%s`)", strings.Join(op.Index.Columns, "`,`"))
|
|
case entity2.UniqueKey:
|
|
idxDef = fmt.Sprintf("ADD UNIQUE INDEX `%s` (`%s`)", op.Index.Name, strings.Join(op.Index.Columns, "`,`"))
|
|
default:
|
|
idxDef = fmt.Sprintf("ADD INDEX `%s` (`%s`)", op.Index.Name, strings.Join(op.Index.Columns, "`,`"))
|
|
}
|
|
operations = append(operations, idxDef)
|
|
}
|
|
}
|
|
|
|
alterSQL += " " + strings.Join(operations, ", ")
|
|
|
|
logs.CtxInfof(ctx, "[AlterTable] execute sql is %s, req is %v", alterSQL, req)
|
|
|
|
err := m.db.WithContext(ctx).Exec(alterSQL).Error
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to alter table: %v", err)
|
|
}
|
|
|
|
table, err := m.getTableInfo(ctx, req.TableName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get table info: %v", err)
|
|
}
|
|
|
|
return &rdb.AlterTableResponse{Table: table}, nil
|
|
}
|
|
|
|
// DropTable drop table
|
|
func (m *mysqlService) DropTable(ctx context.Context, req *rdb.DropTableRequest) (*rdb.DropTableResponse, error) {
|
|
if req == nil {
|
|
return nil, fmt.Errorf("invalid request")
|
|
}
|
|
|
|
dropSQL := "DROP TABLE"
|
|
if req.IfExists {
|
|
dropSQL += " IF EXISTS"
|
|
}
|
|
dropSQL += fmt.Sprintf(" `%s`", req.TableName)
|
|
|
|
logs.CtxInfof(ctx, "[DropTable] execute sql is %s, req is %v", dropSQL, req)
|
|
|
|
err := m.db.WithContext(ctx).Exec(dropSQL).Error
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to drop table: %v", err)
|
|
}
|
|
|
|
return &rdb.DropTableResponse{Success: true}, nil
|
|
}
|
|
|
|
// GetTable get table schema info
|
|
func (m *mysqlService) GetTable(ctx context.Context, req *rdb.GetTableRequest) (*rdb.GetTableResponse, error) {
|
|
if req == nil {
|
|
return nil, fmt.Errorf("invalid request")
|
|
}
|
|
|
|
table, err := m.getTableInfo(ctx, req.TableName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &rdb.GetTableResponse{Table: table}, nil
|
|
}
|
|
|
|
func (m *mysqlService) InsertData(ctx context.Context, req *rdb.InsertDataRequest) (*rdb.InsertDataResponse, error) {
|
|
if req == nil || len(req.Data) == 0 {
|
|
return nil, fmt.Errorf("invalid request")
|
|
}
|
|
|
|
fields := make([]string, 0)
|
|
for field := range req.Data[0] {
|
|
fields = append(fields, field)
|
|
}
|
|
|
|
const batchSize = 1000
|
|
var totalAffected int64
|
|
|
|
for i := 0; i < len(req.Data); i += batchSize {
|
|
end := i + batchSize
|
|
if end > len(req.Data) {
|
|
end = len(req.Data)
|
|
}
|
|
|
|
currentBatch := req.Data[i:end]
|
|
|
|
placeholderGroups := make([]string, 0, len(currentBatch))
|
|
values := make([]interface{}, 0, len(currentBatch)*len(fields))
|
|
|
|
for _, row := range currentBatch {
|
|
placeholders := make([]string, len(fields))
|
|
for j := range placeholders {
|
|
placeholders[j] = "?"
|
|
}
|
|
placeholderGroups = append(placeholderGroups, "("+strings.Join(placeholders, ",")+")")
|
|
|
|
for _, field := range fields {
|
|
values = append(values, row[field])
|
|
}
|
|
}
|
|
|
|
insertSQL := fmt.Sprintf("INSERT INTO `%s` (`%s`) VALUES %s",
|
|
req.TableName,
|
|
strings.Join(fields, "`,`"),
|
|
strings.Join(placeholderGroups, ","),
|
|
)
|
|
|
|
logs.CtxInfof(ctx, "[InsertData] execute sql is %s, value is %v in batch %d", insertSQL, values, i)
|
|
|
|
result := m.db.WithContext(ctx).Exec(insertSQL, values...)
|
|
if result.Error != nil {
|
|
return nil, result.Error
|
|
}
|
|
|
|
affected := result.RowsAffected
|
|
totalAffected += affected
|
|
}
|
|
|
|
return &rdb.InsertDataResponse{AffectedRows: totalAffected}, nil
|
|
}
|
|
|
|
// UpdateData Update data
|
|
func (m *mysqlService) UpdateData(ctx context.Context, req *rdb.UpdateDataRequest) (*rdb.UpdateDataResponse, error) {
|
|
if req == nil {
|
|
return nil, fmt.Errorf("invalid request")
|
|
}
|
|
|
|
setClauses := make([]string, 0)
|
|
values := make([]interface{}, 0)
|
|
for field, value := range req.Data {
|
|
setClauses = append(setClauses, fmt.Sprintf("`%s` = ?", field))
|
|
values = append(values, value)
|
|
}
|
|
|
|
whereClause, whereValues, err := m.buildWhereClause(req.Where)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to build where clause: %v", err)
|
|
}
|
|
values = append(values, whereValues...)
|
|
|
|
limitClause := ""
|
|
if req.Limit != nil {
|
|
limitClause = fmt.Sprintf(" LIMIT %d", *req.Limit)
|
|
}
|
|
|
|
updateSQL := fmt.Sprintf("UPDATE `%s` SET %s%s%s",
|
|
req.TableName,
|
|
strings.Join(setClauses, ", "),
|
|
whereClause,
|
|
limitClause,
|
|
)
|
|
|
|
logs.CtxInfof(ctx, "[UpdateData] execute sql is %s, value is %v, req is %v", updateSQL, values, req)
|
|
|
|
result := m.db.WithContext(ctx).Exec(updateSQL, values...)
|
|
if result.Error != nil {
|
|
return nil, result.Error
|
|
}
|
|
|
|
affectedRows := result.RowsAffected
|
|
|
|
return &rdb.UpdateDataResponse{AffectedRows: affectedRows}, nil
|
|
}
|
|
|
|
// DeleteData delete data
|
|
func (m *mysqlService) DeleteData(ctx context.Context, req *rdb.DeleteDataRequest) (*rdb.DeleteDataResponse, error) {
|
|
if req == nil {
|
|
return nil, fmt.Errorf("invalid request")
|
|
}
|
|
|
|
whereClause, whereValues, err := m.buildWhereClause(req.Where)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to build where clause: %v", err)
|
|
}
|
|
|
|
limitClause := ""
|
|
if req.Limit != nil {
|
|
limitClause = fmt.Sprintf(" LIMIT %d", *req.Limit)
|
|
}
|
|
|
|
deleteSQL := fmt.Sprintf("DELETE FROM `%s`%s%s",
|
|
req.TableName,
|
|
whereClause,
|
|
limitClause,
|
|
)
|
|
|
|
logs.CtxInfof(ctx, "[DeleteData] execute sql is %s, value is %v, req is %v", deleteSQL, whereValues, req)
|
|
|
|
result := m.db.WithContext(ctx).Exec(deleteSQL, whereValues...)
|
|
if result.Error != nil {
|
|
return nil, fmt.Errorf("failed to delete data: %v", result.Error)
|
|
}
|
|
|
|
affectedRows := result.RowsAffected
|
|
|
|
return &rdb.DeleteDataResponse{AffectedRows: affectedRows}, nil
|
|
}
|
|
|
|
// SelectData select data
|
|
func (m *mysqlService) SelectData(ctx context.Context, req *rdb.SelectDataRequest) (*rdb.SelectDataResponse, error) {
|
|
if req == nil {
|
|
return nil, fmt.Errorf("invalid request")
|
|
}
|
|
|
|
fields := "*"
|
|
if len(req.Fields) > 0 {
|
|
fields = strings.Join(req.Fields, ", ")
|
|
}
|
|
|
|
whereClause := ""
|
|
whereValues := make([]interface{}, 0)
|
|
if req.Where != nil {
|
|
clause, values, err := m.buildWhereClause(req.Where)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to build where clause: %v", err)
|
|
}
|
|
whereClause = clause
|
|
whereValues = values
|
|
}
|
|
|
|
orderByClause := ""
|
|
if len(req.OrderBy) > 0 {
|
|
orders := make([]string, len(req.OrderBy))
|
|
for i, order := range req.OrderBy {
|
|
orders[i] = fmt.Sprintf("%s %s", order.Field, order.Direction)
|
|
}
|
|
orderByClause = " ORDER BY " + strings.Join(orders, ", ")
|
|
}
|
|
|
|
limitClause := ""
|
|
if req.Limit != nil {
|
|
limitClause = fmt.Sprintf(" LIMIT %d", *req.Limit)
|
|
if req.Offset != nil {
|
|
limitClause += fmt.Sprintf(" OFFSET %d", *req.Offset)
|
|
}
|
|
}
|
|
|
|
selectSQL := fmt.Sprintf("SELECT %s FROM `%s`%s%s%s",
|
|
fields,
|
|
req.TableName,
|
|
whereClause,
|
|
orderByClause,
|
|
limitClause,
|
|
)
|
|
|
|
logs.CtxInfof(ctx, "[SelectData] execute sql is %s, value is %v, req is %v", selectSQL, whereValues, req)
|
|
|
|
rows, err := m.db.WithContext(ctx).Raw(selectSQL, whereValues...).Rows()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to execute select: %v", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
columns, err := rows.Columns()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get columns: %v", err)
|
|
}
|
|
|
|
resultSet := &entity2.ResultSet{
|
|
Columns: columns,
|
|
Rows: make([]map[string]interface{}, 0),
|
|
}
|
|
|
|
for rows.Next() {
|
|
values := make([]interface{}, len(columns))
|
|
valuePtrs := make([]interface{}, len(columns))
|
|
for i := range values {
|
|
valuePtrs[i] = &values[i]
|
|
}
|
|
|
|
if err := rows.Scan(valuePtrs...); err != nil {
|
|
return nil, fmt.Errorf("failed to scan row: %v", err)
|
|
}
|
|
|
|
rowData := make(map[string]interface{})
|
|
for i, col := range columns {
|
|
rowData[col] = values[i]
|
|
}
|
|
resultSet.Rows = append(resultSet.Rows, rowData)
|
|
}
|
|
|
|
// get total count
|
|
var total int64
|
|
if whereClause != "" {
|
|
countSQL := fmt.Sprintf("SELECT COUNT(*) FROM `%s`%s", req.TableName, whereClause)
|
|
err = m.db.WithContext(ctx).Raw(countSQL, whereValues...).Scan(&total).Error
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get total count: %v", err)
|
|
}
|
|
} else {
|
|
total = int64(len(resultSet.Rows))
|
|
}
|
|
|
|
return &rdb.SelectDataResponse{
|
|
ResultSet: resultSet,
|
|
Total: total,
|
|
}, nil
|
|
}
|
|
|
|
// UpsertData upsert data
|
|
func (m *mysqlService) UpsertData(ctx context.Context, req *rdb.UpsertDataRequest) (*rdb.UpsertDataResponse, error) {
|
|
if req == nil || len(req.Data) == 0 {
|
|
return nil, fmt.Errorf("invalid request: empty data")
|
|
}
|
|
|
|
keys := req.Keys
|
|
if len(keys) == 0 {
|
|
primaryKeys, err := m.getTablePrimaryKeys(ctx, req.TableName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get primary keys: %v", err)
|
|
}
|
|
|
|
if len(primaryKeys) == 0 {
|
|
return nil, fmt.Errorf("table %s has no primary key, keys are required for upsert operation", req.TableName)
|
|
}
|
|
|
|
keys = primaryKeys
|
|
}
|
|
|
|
fields := make([]string, 0)
|
|
for field := range req.Data[0] {
|
|
fields = append(fields, field)
|
|
}
|
|
|
|
const batchSize = 1000
|
|
var totalAffected, totalInserted, totalUpdated int64
|
|
|
|
for i := 0; i < len(req.Data); i += batchSize {
|
|
end := i + batchSize
|
|
if end > len(req.Data) {
|
|
end = len(req.Data)
|
|
}
|
|
|
|
currentBatch := req.Data[i:end]
|
|
|
|
placeholderGroups := make([]string, 0, len(currentBatch))
|
|
values := make([]interface{}, 0, len(currentBatch)*len(fields))
|
|
|
|
for _, row := range currentBatch {
|
|
placeholders := make([]string, len(fields))
|
|
for j := range placeholders {
|
|
placeholders[j] = "?"
|
|
}
|
|
placeholderGroups = append(placeholderGroups, "("+strings.Join(placeholders, ",")+")")
|
|
|
|
for _, field := range fields {
|
|
values = append(values, row[field])
|
|
}
|
|
}
|
|
|
|
// ON DUPLICATE KEY UPDATE部分
|
|
updateClauses := make([]string, 0, len(fields))
|
|
for _, field := range fields {
|
|
isKey := false
|
|
for _, key := range keys {
|
|
if field == key {
|
|
isKey = true
|
|
break
|
|
}
|
|
}
|
|
if !isKey {
|
|
updateClauses = append(updateClauses, fmt.Sprintf("`%s`=VALUES(`%s`)", field, field))
|
|
}
|
|
}
|
|
|
|
upsertSQL := fmt.Sprintf(
|
|
"INSERT INTO `%s` (`%s`) VALUES %s ON DUPLICATE KEY UPDATE %s",
|
|
req.TableName,
|
|
strings.Join(fields, "`,`"),
|
|
strings.Join(placeholderGroups, ","),
|
|
strings.Join(updateClauses, ","),
|
|
)
|
|
|
|
logs.CtxInfof(ctx, "[UpsertData] execute sql is %s, value is %v, batch is %d", upsertSQL, values, i)
|
|
|
|
result := m.db.WithContext(ctx).Exec(upsertSQL, values...)
|
|
if result.Error != nil {
|
|
return nil, fmt.Errorf("failed to upsert data: %v", result.Error)
|
|
}
|
|
|
|
total, inserted, updated := calculateInsertedUpdated(result.RowsAffected, len(currentBatch))
|
|
totalInserted += inserted
|
|
totalUpdated += updated
|
|
totalAffected += total
|
|
}
|
|
|
|
return &rdb.UpsertDataResponse{
|
|
AffectedRows: totalAffected,
|
|
InsertedRows: totalInserted,
|
|
UpdatedRows: totalUpdated,
|
|
UnchangedRows: int64(len(req.Data)) - totalAffected,
|
|
}, nil
|
|
}
|
|
|
|
func (m *mysqlService) getTablePrimaryKeys(ctx context.Context, tableName string) ([]string, error) {
|
|
query := `
|
|
SELECT COLUMN_NAME
|
|
FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
|
|
WHERE TABLE_SCHEMA = DATABASE()
|
|
AND TABLE_NAME = ?
|
|
AND CONSTRAINT_NAME = 'PRIMARY'
|
|
ORDER BY ORDINAL_POSITION
|
|
`
|
|
|
|
var primaryKeys []string
|
|
rows, err := m.db.WithContext(ctx).Raw(query, tableName).Rows()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var columnName string
|
|
if err := rows.Scan(&columnName); err != nil {
|
|
return nil, err
|
|
}
|
|
primaryKeys = append(primaryKeys, columnName)
|
|
}
|
|
|
|
return primaryKeys, nil
|
|
}
|
|
|
|
// calculateInsertedUpdated estimates how the rows of one
// INSERT ... ON DUPLICATE KEY UPDATE batch split into inserts and updates,
// given the statement's affected-row count. MySQL counts 1 per inserted row,
// 2 per updated row and 0 per row left unchanged.
//
//   - affectedRows <= batchSize: every affected row is attributed to an
//     insert.
//   - Otherwise: the surplus over batchSize is taken as the number of
//     updates, i.e. the batch is assumed to contain no unchanged rows.
//
// This is a heuristic: a batch mixing inserts, updates and unchanged rows
// cannot be decomposed uniquely from a single number.
//
// It returns (inserted+updated, inserted, updated).
func calculateInsertedUpdated(affectedRows int64, batchSize int) (int64, int64, int64) {
	rowsInBatch := int64(batchSize)
	if affectedRows <= rowsInBatch {
		return affectedRows, affectedRows, 0
	}
	updatedCount := affectedRows - rowsInBatch
	insertedCount := rowsInBatch - updatedCount
	return insertedCount + updatedCount, insertedCount, updatedCount
}
|
|
|
|
// ExecuteSQL Execute SQL
|
|
func (m *mysqlService) ExecuteSQL(ctx context.Context, req *rdb.ExecuteSQLRequest) (*rdb.ExecuteSQLResponse, error) {
|
|
if req == nil {
|
|
return nil, fmt.Errorf("invalid request")
|
|
}
|
|
|
|
logs.CtxInfof(ctx, "[ExecuteSQL] req is %v", req)
|
|
|
|
var processedSQL string
|
|
var processedParams []interface{}
|
|
var err error
|
|
|
|
// Handle SQLType: if raw, do not process params
|
|
if req.SQLType == entity2.SQLType_Raw {
|
|
processedSQL = req.SQL
|
|
processedParams = nil
|
|
} else {
|
|
processedSQL, processedParams, err = m.processSliceParams(req.SQL, req.Params)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to process parameters: %v", err)
|
|
}
|
|
}
|
|
|
|
operation, err := sqlparser.NewSQLParser().GetSQLOperation(processedSQL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if operation != sqlparsercontract.OperationTypeSelect {
|
|
result := m.db.WithContext(ctx).Exec(processedSQL, processedParams...)
|
|
if result.Error != nil {
|
|
return nil, fmt.Errorf("failed to execute SQL: %v", result.Error)
|
|
}
|
|
|
|
resultSet := &entity2.ResultSet{
|
|
Columns: []string{},
|
|
Rows: []map[string]interface{}{},
|
|
AffectedRows: result.RowsAffected,
|
|
}
|
|
|
|
return &rdb.ExecuteSQLResponse{
|
|
ResultSet: resultSet,
|
|
}, nil
|
|
}
|
|
|
|
rows, err := m.db.WithContext(ctx).Raw(processedSQL, processedParams...).Rows()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to execute SQL: %v", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
columns, err := rows.Columns()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get columns: %v", err)
|
|
}
|
|
|
|
resultSet := &entity2.ResultSet{
|
|
Columns: columns,
|
|
Rows: make([]map[string]interface{}, 0),
|
|
}
|
|
|
|
for rows.Next() {
|
|
values := make([]interface{}, len(columns))
|
|
valuePtrs := make([]interface{}, len(columns))
|
|
for i := range values {
|
|
valuePtrs[i] = &values[i]
|
|
}
|
|
|
|
if err := rows.Scan(valuePtrs...); err != nil {
|
|
return nil, fmt.Errorf("failed to scan row: %v", err)
|
|
}
|
|
|
|
rowData := make(map[string]interface{})
|
|
for i, col := range columns {
|
|
rowData[col] = values[i]
|
|
}
|
|
resultSet.Rows = append(resultSet.Rows, rowData)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("error while reading rows: %v", err)
|
|
}
|
|
|
|
return &rdb.ExecuteSQLResponse{
|
|
ResultSet: resultSet,
|
|
}, nil
|
|
}
|
|
|
|
func (m *mysqlService) processSliceParams(sql string, params []interface{}) (string, []interface{}, error) {
|
|
if len(params) == 0 {
|
|
return sql, params, nil
|
|
}
|
|
|
|
processedParams := make([]interface{}, 0)
|
|
paramIndex := 0
|
|
resultSQL := ""
|
|
lastPos := 0
|
|
|
|
// get all ? positions
|
|
for i := 0; i < len(sql); i++ {
|
|
if sql[i] == '?' && paramIndex < len(params) {
|
|
resultSQL += sql[lastPos:i]
|
|
lastPos = i + 1
|
|
|
|
param := params[paramIndex]
|
|
paramIndex++
|
|
|
|
if m.isSlice(param) {
|
|
sliceValues, err := m.getSliceValues(param)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
|
|
if len(sliceValues) == 0 {
|
|
resultSQL += "(NULL)"
|
|
} else {
|
|
// (?, ?, ...)
|
|
placeholders := make([]string, len(sliceValues))
|
|
for j := range placeholders {
|
|
placeholders[j] = "?"
|
|
}
|
|
resultSQL += "(" + strings.Join(placeholders, ", ") + ")"
|
|
|
|
processedParams = append(processedParams, sliceValues...)
|
|
}
|
|
} else {
|
|
resultSQL += "?"
|
|
processedParams = append(processedParams, param)
|
|
}
|
|
}
|
|
}
|
|
|
|
resultSQL += sql[lastPos:]
|
|
|
|
return resultSQL, processedParams, nil
|
|
}
|
|
|
|
func (m *mysqlService) isSlice(param interface{}) bool {
|
|
if param == nil {
|
|
return false
|
|
}
|
|
|
|
rv := reflect.ValueOf(param)
|
|
return rv.Kind() == reflect.Slice && rv.Type().Elem().Kind() != reflect.Uint8 // exclude []byte
|
|
}
|
|
|
|
func (m *mysqlService) getSliceValues(param interface{}) ([]interface{}, error) {
|
|
rv := reflect.ValueOf(param)
|
|
if rv.Kind() != reflect.Slice {
|
|
return nil, fmt.Errorf("parameter is not a slice")
|
|
}
|
|
|
|
length := rv.Len()
|
|
values := make([]interface{}, length)
|
|
|
|
for i := 0; i < length; i++ {
|
|
values[i] = rv.Index(i).Interface()
|
|
}
|
|
|
|
return values, nil
|
|
}
|
|
|
|
func (m *mysqlService) genTableName(ctx context.Context) (string, error) {
|
|
id, err := m.generator.GenID(ctx)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return fmt.Sprintf("table_%d", id), nil
|
|
}
|
|
|
|
func (m *mysqlService) getTableInfo(ctx context.Context, tableName string) (*entity2.Table, error) {
|
|
tableInfoSQL := `
|
|
SELECT
|
|
TABLE_NAME,
|
|
TABLE_COLLATION,
|
|
AUTO_INCREMENT,
|
|
TABLE_COMMENT
|
|
FROM information_schema.TABLES
|
|
WHERE TABLE_SCHEMA = DATABASE()
|
|
AND TABLE_NAME = ?
|
|
`
|
|
|
|
var (
|
|
name string
|
|
collation *string
|
|
autoIncrement *int64
|
|
comment *string
|
|
)
|
|
|
|
err := m.db.WithContext(ctx).Raw(tableInfoSQL, tableName).Row().Scan(
|
|
&name,
|
|
&collation,
|
|
&autoIncrement,
|
|
&comment,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
columnsSQL := `
|
|
SELECT
|
|
COLUMN_NAME,
|
|
DATA_TYPE,
|
|
CHARACTER_MAXIMUM_LENGTH,
|
|
IS_NULLABLE,
|
|
COLUMN_DEFAULT,
|
|
EXTRA,
|
|
COLUMN_COMMENT
|
|
FROM information_schema.COLUMNS
|
|
WHERE TABLE_SCHEMA = DATABASE()
|
|
AND TABLE_NAME = ?
|
|
ORDER BY ORDINAL_POSITION
|
|
`
|
|
|
|
type columnInfo struct {
|
|
ColumnName string `gorm:"column:COLUMN_NAME"`
|
|
DataType string `gorm:"column:DATA_TYPE"`
|
|
CharLength *int `gorm:"column:CHARACTER_MAXIMUM_LENGTH"`
|
|
IsNullable string `gorm:"column:IS_NULLABLE"`
|
|
DefaultValue *string `gorm:"column:COLUMN_DEFAULT"`
|
|
Extra string `gorm:"column:EXTRA"`
|
|
ColumnComment *string `gorm:"column:COLUMN_COMMENT"`
|
|
}
|
|
|
|
var columnsData []columnInfo
|
|
err = m.db.WithContext(ctx).Raw(columnsSQL, tableName).Scan(&columnsData).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
columns := make([]*entity2.Column, len(columnsData))
|
|
for i, colData := range columnsData {
|
|
column := &entity2.Column{
|
|
Name: colData.ColumnName,
|
|
DataType: entity2.DataType(colData.DataType),
|
|
Length: colData.CharLength,
|
|
NotNull: colData.IsNullable == "NO",
|
|
DefaultValue: colData.DefaultValue,
|
|
AutoIncrement: strings.Contains(colData.Extra, "auto_increment"),
|
|
Comment: colData.ColumnComment,
|
|
}
|
|
columns[i] = column
|
|
}
|
|
|
|
indexesSQL := `
|
|
SELECT
|
|
INDEX_NAME,
|
|
NON_UNIQUE,
|
|
INDEX_TYPE,
|
|
GROUP_CONCAT(COLUMN_NAME ORDER BY SEQ_IN_INDEX)
|
|
FROM information_schema.STATISTICS
|
|
WHERE TABLE_SCHEMA = DATABASE()
|
|
AND TABLE_NAME = ?
|
|
GROUP BY INDEX_NAME, NON_UNIQUE, INDEX_TYPE
|
|
`
|
|
|
|
type indexInfo struct {
|
|
IndexName string `gorm:"column:INDEX_NAME"`
|
|
NonUnique int `gorm:"column:NON_UNIQUE"`
|
|
IndexType string `gorm:"column:INDEX_TYPE"`
|
|
Columns string `gorm:"column:GROUP_CONCAT"`
|
|
}
|
|
|
|
var indexesData []indexInfo
|
|
err = m.db.WithContext(ctx).Raw(indexesSQL, tableName).Scan(&indexesData).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
indexes := make([]*entity2.Index, 0, len(indexesData))
|
|
for _, idxData := range indexesData {
|
|
index := &entity2.Index{
|
|
Name: idxData.IndexName,
|
|
Type: entity2.IndexType(idxData.IndexType),
|
|
Columns: strings.Split(idxData.Columns, ","),
|
|
}
|
|
indexes = append(indexes, index)
|
|
}
|
|
|
|
return &entity2.Table{
|
|
Name: name,
|
|
Columns: columns,
|
|
Indexes: indexes,
|
|
Options: &entity2.TableOption{
|
|
Collate: collation,
|
|
AutoIncrement: autoIncrement,
|
|
Comment: comment,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func (m *mysqlService) buildWhereClause(condition *rdb.ComplexCondition) (string, []interface{}, error) {
|
|
if condition == nil {
|
|
return "", nil, nil
|
|
}
|
|
|
|
if condition.Operator == "" {
|
|
condition.Operator = entity2.AND
|
|
}
|
|
|
|
var whereClause strings.Builder
|
|
values := make([]interface{}, 0)
|
|
|
|
for i, cond := range condition.Conditions {
|
|
if i > 0 {
|
|
whereClause.WriteString(fmt.Sprintf(" %s ", condition.Operator))
|
|
}
|
|
|
|
if cond.Operator == entity2.OperatorIn || cond.Operator == entity2.OperatorNotIn {
|
|
if m.isSlice(cond.Value) {
|
|
sliceValues, err := m.getSliceValues(cond.Value)
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to process slice values: %v", err)
|
|
}
|
|
|
|
if len(sliceValues) == 0 {
|
|
whereClause.WriteString(fmt.Sprintf("`%s` %s (NULL)", cond.Field, string(cond.Operator)))
|
|
} else {
|
|
placeholders := make([]string, len(sliceValues))
|
|
for i := range placeholders {
|
|
placeholders[i] = "?"
|
|
}
|
|
whereClause.WriteString(fmt.Sprintf("`%s` %s (%s)", cond.Field, string(cond.Operator), strings.Join(placeholders, ",")))
|
|
|
|
values = append(values, sliceValues...)
|
|
}
|
|
} else {
|
|
return "", nil, fmt.Errorf("IN operator requires a slice of values")
|
|
}
|
|
} else if cond.Operator == entity2.OperatorIsNull || cond.Operator == entity2.OperatorIsNotNull {
|
|
whereClause.WriteString(fmt.Sprintf("`%s` %s", cond.Field, cond.Operator))
|
|
} else {
|
|
whereClause.WriteString(fmt.Sprintf("`%s` %s ?", cond.Field, cond.Operator))
|
|
values = append(values, cond.Value)
|
|
}
|
|
}
|
|
|
|
if len(condition.NestedConditions) > 0 {
|
|
whereClause.WriteString(" AND (")
|
|
for i, nested := range condition.NestedConditions {
|
|
if i > 0 {
|
|
whereClause.WriteString(fmt.Sprintf(" %s ", nested.Operator))
|
|
}
|
|
nestedClause, nestedValues, err := m.buildWhereClause(nested)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
whereClause.WriteString(nestedClause)
|
|
values = append(values, nestedValues...)
|
|
}
|
|
whereClause.WriteString(")")
|
|
}
|
|
|
|
if whereClause.Len() > 0 {
|
|
return " WHERE " + whereClause.String(), values, nil
|
|
}
|
|
return "", values, nil
|
|
}
|