feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
993
backend/infra/impl/rdb/mysql.go
Normal file
993
backend/infra/impl/rdb/mysql.go
Normal file
@@ -0,0 +1,993 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package rdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
|
||||
entity2 "github.com/coze-dev/coze-studio/backend/infra/contract/rdb/entity"
|
||||
sqlparsercontract "github.com/coze-dev/coze-studio/backend/infra/contract/sqlparser"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/sqlparser"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
type mysqlService struct {
|
||||
db *gorm.DB
|
||||
generator idgen.IDGenerator
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, generator idgen.IDGenerator) rdb.RDB {
|
||||
return &mysqlService{db: db, generator: generator}
|
||||
}
|
||||
|
||||
// CreateTable create table
|
||||
func (m *mysqlService) CreateTable(ctx context.Context, req *rdb.CreateTableRequest) (*rdb.CreateTableResponse, error) {
|
||||
if req == nil || req.Table == nil {
|
||||
return nil, fmt.Errorf("invalid request")
|
||||
}
|
||||
|
||||
// build column definitions
|
||||
columnDefs := make([]string, 0, len(req.Table.Columns))
|
||||
for _, col := range req.Table.Columns {
|
||||
colDef := fmt.Sprintf("`%s` %s", col.Name, col.DataType)
|
||||
|
||||
if col.Length != nil {
|
||||
colDef += fmt.Sprintf("(%d)", *col.Length)
|
||||
} else if col.Length == nil && col.DataType == entity2.TypeVarchar {
|
||||
colDef += fmt.Sprintf("(%d)", 255)
|
||||
}
|
||||
|
||||
if col.NotNull {
|
||||
colDef += " NOT NULL"
|
||||
}
|
||||
if col.DefaultValue != nil {
|
||||
if col.DataType == entity2.TypeTimestamp {
|
||||
colDef += fmt.Sprintf(" DEFAULT %s", *col.DefaultValue)
|
||||
} else {
|
||||
colDef += fmt.Sprintf(" DEFAULT '%s'", *col.DefaultValue)
|
||||
}
|
||||
}
|
||||
if col.AutoIncrement {
|
||||
colDef += " AUTO_INCREMENT"
|
||||
}
|
||||
if col.Comment != nil {
|
||||
colDef += fmt.Sprintf(" COMMENT '%s'", *col.Comment)
|
||||
}
|
||||
|
||||
columnDefs = append(columnDefs, colDef)
|
||||
}
|
||||
|
||||
// build index definitions
|
||||
for _, idx := range req.Table.Indexes {
|
||||
var idxDef string
|
||||
switch idx.Type {
|
||||
case entity2.PrimaryKey:
|
||||
idxDef = fmt.Sprintf("PRIMARY KEY (`%s`)", strings.Join(idx.Columns, "`,`"))
|
||||
case entity2.UniqueKey:
|
||||
idxDef = fmt.Sprintf("UNIQUE KEY `%s` (`%s`)", idx.Name, strings.Join(idx.Columns, "`,`"))
|
||||
default:
|
||||
idxDef = fmt.Sprintf("KEY `%s` (`%s`)", idx.Name, strings.Join(idx.Columns, "`,`"))
|
||||
}
|
||||
columnDefs = append(columnDefs, idxDef)
|
||||
}
|
||||
|
||||
tableOptions := make([]string, 0)
|
||||
if req.Table.Options != nil {
|
||||
if req.Table.Options.Collate != nil {
|
||||
tableOptions = append(tableOptions, fmt.Sprintf("COLLATE=%s", *req.Table.Options.Collate))
|
||||
}
|
||||
if req.Table.Options.AutoIncrement != nil {
|
||||
tableOptions = append(tableOptions, fmt.Sprintf("AUTO_INCREMENT=%d", *req.Table.Options.AutoIncrement))
|
||||
}
|
||||
if req.Table.Options.Comment != nil {
|
||||
tableOptions = append(tableOptions, fmt.Sprintf("COMMENT='%s'", *req.Table.Options.Comment))
|
||||
}
|
||||
}
|
||||
|
||||
tableName := req.Table.Name
|
||||
if req.Table.Name == "" {
|
||||
genName, err := m.genTableName(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tableName = genName
|
||||
}
|
||||
|
||||
createSQL := fmt.Sprintf("CREATE TABLE IF NOT EXISTS `%s` (\n %s\n) %s",
|
||||
tableName,
|
||||
strings.Join(columnDefs, ",\n "),
|
||||
strings.Join(tableOptions, " "),
|
||||
)
|
||||
|
||||
logs.CtxInfof(ctx, "[CreateTable] execute sql is %s, req is %v", createSQL, req)
|
||||
|
||||
err := m.db.WithContext(ctx).Exec(createSQL).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create table: %v", err)
|
||||
}
|
||||
|
||||
resTable := req.Table
|
||||
resTable.Name = tableName
|
||||
return &rdb.CreateTableResponse{Table: resTable}, nil
|
||||
}
|
||||
|
||||
// AlterTable alter table
|
||||
func (m *mysqlService) AlterTable(ctx context.Context, req *rdb.AlterTableRequest) (*rdb.AlterTableResponse, error) {
|
||||
if req == nil || len(req.Operations) == 0 {
|
||||
return nil, fmt.Errorf("invalid request")
|
||||
}
|
||||
|
||||
alterSQL := fmt.Sprintf("ALTER TABLE `%s`", req.TableName)
|
||||
operations := make([]string, 0, len(req.Operations))
|
||||
|
||||
for _, op := range req.Operations {
|
||||
switch op.Action {
|
||||
case entity2.AddColumn:
|
||||
if op.Column == nil {
|
||||
return nil, fmt.Errorf("column is required for ADD COLUMN operation")
|
||||
}
|
||||
colDef := fmt.Sprintf("ADD COLUMN `%s` %s", op.Column.Name, op.Column.DataType)
|
||||
if op.Column.Length != nil {
|
||||
colDef += fmt.Sprintf("(%d)", *op.Column.Length)
|
||||
} else if op.Column.Length == nil && op.Column.DataType == entity2.TypeVarchar {
|
||||
colDef += fmt.Sprintf("(%d)", 255)
|
||||
}
|
||||
|
||||
if op.Column.NotNull {
|
||||
colDef += " NOT NULL"
|
||||
}
|
||||
|
||||
if op.Column.DefaultValue != nil {
|
||||
if op.Column.DataType == entity2.TypeTimestamp {
|
||||
colDef += fmt.Sprintf(" DEFAULT %s", *op.Column.DefaultValue)
|
||||
} else {
|
||||
colDef += fmt.Sprintf(" DEFAULT '%s'", *op.Column.DefaultValue)
|
||||
}
|
||||
}
|
||||
|
||||
operations = append(operations, colDef)
|
||||
|
||||
case entity2.DropColumn:
|
||||
if op.Column == nil {
|
||||
return nil, fmt.Errorf("column is required for DROP COLUMN operation")
|
||||
}
|
||||
operations = append(operations, fmt.Sprintf("DROP COLUMN `%s`", op.Column.Name))
|
||||
|
||||
case entity2.ModifyColumn:
|
||||
if op.Column == nil {
|
||||
return nil, fmt.Errorf("column is required for MODIFY COLUMN operation")
|
||||
}
|
||||
colDef := fmt.Sprintf("MODIFY COLUMN `%s` %s", op.Column.Name, op.Column.DataType)
|
||||
if op.Column.Length != nil {
|
||||
colDef += fmt.Sprintf("(%d)", *op.Column.Length)
|
||||
} else if op.Column.Length == nil && op.Column.DataType == entity2.TypeVarchar {
|
||||
colDef += fmt.Sprintf("(%d)", 255)
|
||||
}
|
||||
operations = append(operations, colDef)
|
||||
|
||||
case entity2.RenameColumn:
|
||||
if op.Column == nil || op.OldName == nil {
|
||||
return nil, fmt.Errorf("column and old name are required for RENAME COLUMN operation")
|
||||
}
|
||||
operations = append(operations, fmt.Sprintf("RENAME COLUMN `%s` TO `%s`", *op.OldName, op.Column.Name))
|
||||
|
||||
case entity2.AddIndex:
|
||||
if op.Index == nil {
|
||||
return nil, fmt.Errorf("index is required for ADD INDEX operation")
|
||||
}
|
||||
var idxDef string
|
||||
switch op.Index.Type {
|
||||
case entity2.PrimaryKey:
|
||||
idxDef = fmt.Sprintf("ADD PRIMARY KEY (`%s`)", strings.Join(op.Index.Columns, "`,`"))
|
||||
case entity2.UniqueKey:
|
||||
idxDef = fmt.Sprintf("ADD UNIQUE INDEX `%s` (`%s`)", op.Index.Name, strings.Join(op.Index.Columns, "`,`"))
|
||||
default:
|
||||
idxDef = fmt.Sprintf("ADD INDEX `%s` (`%s`)", op.Index.Name, strings.Join(op.Index.Columns, "`,`"))
|
||||
}
|
||||
operations = append(operations, idxDef)
|
||||
}
|
||||
}
|
||||
|
||||
alterSQL += " " + strings.Join(operations, ", ")
|
||||
|
||||
logs.CtxInfof(ctx, "[AlterTable] execute sql is %s, req is %v", alterSQL, req)
|
||||
|
||||
err := m.db.WithContext(ctx).Exec(alterSQL).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to alter table: %v", err)
|
||||
}
|
||||
|
||||
table, err := m.getTableInfo(ctx, req.TableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get table info: %v", err)
|
||||
}
|
||||
|
||||
return &rdb.AlterTableResponse{Table: table}, nil
|
||||
}
|
||||
|
||||
// DropTable drop table
|
||||
func (m *mysqlService) DropTable(ctx context.Context, req *rdb.DropTableRequest) (*rdb.DropTableResponse, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("invalid request")
|
||||
}
|
||||
|
||||
dropSQL := "DROP TABLE"
|
||||
if req.IfExists {
|
||||
dropSQL += " IF EXISTS"
|
||||
}
|
||||
dropSQL += fmt.Sprintf(" `%s`", req.TableName)
|
||||
|
||||
logs.CtxInfof(ctx, "[DropTable] execute sql is %s, req is %v", dropSQL, req)
|
||||
|
||||
err := m.db.WithContext(ctx).Exec(dropSQL).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to drop table: %v", err)
|
||||
}
|
||||
|
||||
return &rdb.DropTableResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
// GetTable get table schema info
|
||||
func (m *mysqlService) GetTable(ctx context.Context, req *rdb.GetTableRequest) (*rdb.GetTableResponse, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("invalid request")
|
||||
}
|
||||
|
||||
table, err := m.getTableInfo(ctx, req.TableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &rdb.GetTableResponse{Table: table}, nil
|
||||
}
|
||||
|
||||
func (m *mysqlService) InsertData(ctx context.Context, req *rdb.InsertDataRequest) (*rdb.InsertDataResponse, error) {
|
||||
if req == nil || len(req.Data) == 0 {
|
||||
return nil, fmt.Errorf("invalid request")
|
||||
}
|
||||
|
||||
fields := make([]string, 0)
|
||||
for field := range req.Data[0] {
|
||||
fields = append(fields, field)
|
||||
}
|
||||
|
||||
const batchSize = 1000
|
||||
var totalAffected int64
|
||||
|
||||
for i := 0; i < len(req.Data); i += batchSize {
|
||||
end := i + batchSize
|
||||
if end > len(req.Data) {
|
||||
end = len(req.Data)
|
||||
}
|
||||
|
||||
currentBatch := req.Data[i:end]
|
||||
|
||||
placeholderGroups := make([]string, 0, len(currentBatch))
|
||||
values := make([]interface{}, 0, len(currentBatch)*len(fields))
|
||||
|
||||
for _, row := range currentBatch {
|
||||
placeholders := make([]string, len(fields))
|
||||
for j := range placeholders {
|
||||
placeholders[j] = "?"
|
||||
}
|
||||
placeholderGroups = append(placeholderGroups, "("+strings.Join(placeholders, ",")+")")
|
||||
|
||||
for _, field := range fields {
|
||||
values = append(values, row[field])
|
||||
}
|
||||
}
|
||||
|
||||
insertSQL := fmt.Sprintf("INSERT INTO `%s` (`%s`) VALUES %s",
|
||||
req.TableName,
|
||||
strings.Join(fields, "`,`"),
|
||||
strings.Join(placeholderGroups, ","),
|
||||
)
|
||||
|
||||
logs.CtxInfof(ctx, "[InsertData] execute sql is %s, value is %v in batch %d", insertSQL, values, i)
|
||||
|
||||
result := m.db.WithContext(ctx).Exec(insertSQL, values...)
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
affected := result.RowsAffected
|
||||
totalAffected += affected
|
||||
}
|
||||
|
||||
return &rdb.InsertDataResponse{AffectedRows: totalAffected}, nil
|
||||
}
|
||||
|
||||
// UpdateData Update data
|
||||
func (m *mysqlService) UpdateData(ctx context.Context, req *rdb.UpdateDataRequest) (*rdb.UpdateDataResponse, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("invalid request")
|
||||
}
|
||||
|
||||
setClauses := make([]string, 0)
|
||||
values := make([]interface{}, 0)
|
||||
for field, value := range req.Data {
|
||||
setClauses = append(setClauses, fmt.Sprintf("`%s` = ?", field))
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
whereClause, whereValues, err := m.buildWhereClause(req.Where)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build where clause: %v", err)
|
||||
}
|
||||
values = append(values, whereValues...)
|
||||
|
||||
limitClause := ""
|
||||
if req.Limit != nil {
|
||||
limitClause = fmt.Sprintf(" LIMIT %d", *req.Limit)
|
||||
}
|
||||
|
||||
updateSQL := fmt.Sprintf("UPDATE `%s` SET %s%s%s",
|
||||
req.TableName,
|
||||
strings.Join(setClauses, ", "),
|
||||
whereClause,
|
||||
limitClause,
|
||||
)
|
||||
|
||||
logs.CtxInfof(ctx, "[UpdateData] execute sql is %s, value is %v, req is %v", updateSQL, values, req)
|
||||
|
||||
result := m.db.WithContext(ctx).Exec(updateSQL, values...)
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
affectedRows := result.RowsAffected
|
||||
|
||||
return &rdb.UpdateDataResponse{AffectedRows: affectedRows}, nil
|
||||
}
|
||||
|
||||
// DeleteData delete data
|
||||
func (m *mysqlService) DeleteData(ctx context.Context, req *rdb.DeleteDataRequest) (*rdb.DeleteDataResponse, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("invalid request")
|
||||
}
|
||||
|
||||
whereClause, whereValues, err := m.buildWhereClause(req.Where)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build where clause: %v", err)
|
||||
}
|
||||
|
||||
limitClause := ""
|
||||
if req.Limit != nil {
|
||||
limitClause = fmt.Sprintf(" LIMIT %d", *req.Limit)
|
||||
}
|
||||
|
||||
deleteSQL := fmt.Sprintf("DELETE FROM `%s`%s%s",
|
||||
req.TableName,
|
||||
whereClause,
|
||||
limitClause,
|
||||
)
|
||||
|
||||
logs.CtxInfof(ctx, "[DeleteData] execute sql is %s, value is %v, req is %v", deleteSQL, whereValues, req)
|
||||
|
||||
result := m.db.WithContext(ctx).Exec(deleteSQL, whereValues...)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to delete data: %v", result.Error)
|
||||
}
|
||||
|
||||
affectedRows := result.RowsAffected
|
||||
|
||||
return &rdb.DeleteDataResponse{AffectedRows: affectedRows}, nil
|
||||
}
|
||||
|
||||
// SelectData select data
|
||||
func (m *mysqlService) SelectData(ctx context.Context, req *rdb.SelectDataRequest) (*rdb.SelectDataResponse, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("invalid request")
|
||||
}
|
||||
|
||||
fields := "*"
|
||||
if len(req.Fields) > 0 {
|
||||
fields = strings.Join(req.Fields, ", ")
|
||||
}
|
||||
|
||||
whereClause := ""
|
||||
whereValues := make([]interface{}, 0)
|
||||
if req.Where != nil {
|
||||
clause, values, err := m.buildWhereClause(req.Where)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build where clause: %v", err)
|
||||
}
|
||||
whereClause = clause
|
||||
whereValues = values
|
||||
}
|
||||
|
||||
orderByClause := ""
|
||||
if len(req.OrderBy) > 0 {
|
||||
orders := make([]string, len(req.OrderBy))
|
||||
for i, order := range req.OrderBy {
|
||||
orders[i] = fmt.Sprintf("%s %s", order.Field, order.Direction)
|
||||
}
|
||||
orderByClause = " ORDER BY " + strings.Join(orders, ", ")
|
||||
}
|
||||
|
||||
limitClause := ""
|
||||
if req.Limit != nil {
|
||||
limitClause = fmt.Sprintf(" LIMIT %d", *req.Limit)
|
||||
if req.Offset != nil {
|
||||
limitClause += fmt.Sprintf(" OFFSET %d", *req.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
selectSQL := fmt.Sprintf("SELECT %s FROM `%s`%s%s%s",
|
||||
fields,
|
||||
req.TableName,
|
||||
whereClause,
|
||||
orderByClause,
|
||||
limitClause,
|
||||
)
|
||||
|
||||
logs.CtxInfof(ctx, "[SelectData] execute sql is %s, value is %v, req is %v", selectSQL, whereValues, req)
|
||||
|
||||
rows, err := m.db.WithContext(ctx).Raw(selectSQL, whereValues...).Rows()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute select: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get columns: %v", err)
|
||||
}
|
||||
|
||||
resultSet := &entity2.ResultSet{
|
||||
Columns: columns,
|
||||
Rows: make([]map[string]interface{}, 0),
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan row: %v", err)
|
||||
}
|
||||
|
||||
rowData := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
rowData[col] = values[i]
|
||||
}
|
||||
resultSet.Rows = append(resultSet.Rows, rowData)
|
||||
}
|
||||
|
||||
// get total count
|
||||
var total int64
|
||||
if whereClause != "" {
|
||||
countSQL := fmt.Sprintf("SELECT COUNT(*) FROM `%s`%s", req.TableName, whereClause)
|
||||
err = m.db.WithContext(ctx).Raw(countSQL, whereValues...).Scan(&total).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get total count: %v", err)
|
||||
}
|
||||
} else {
|
||||
total = int64(len(resultSet.Rows))
|
||||
}
|
||||
|
||||
return &rdb.SelectDataResponse{
|
||||
ResultSet: resultSet,
|
||||
Total: total,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpsertData upsert data
|
||||
func (m *mysqlService) UpsertData(ctx context.Context, req *rdb.UpsertDataRequest) (*rdb.UpsertDataResponse, error) {
|
||||
if req == nil || len(req.Data) == 0 {
|
||||
return nil, fmt.Errorf("invalid request: empty data")
|
||||
}
|
||||
|
||||
keys := req.Keys
|
||||
if len(keys) == 0 {
|
||||
primaryKeys, err := m.getTablePrimaryKeys(ctx, req.TableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get primary keys: %v", err)
|
||||
}
|
||||
|
||||
if len(primaryKeys) == 0 {
|
||||
return nil, fmt.Errorf("table %s has no primary key, keys are required for upsert operation", req.TableName)
|
||||
}
|
||||
|
||||
keys = primaryKeys
|
||||
}
|
||||
|
||||
fields := make([]string, 0)
|
||||
for field := range req.Data[0] {
|
||||
fields = append(fields, field)
|
||||
}
|
||||
|
||||
const batchSize = 1000
|
||||
var totalAffected, totalInserted, totalUpdated int64
|
||||
|
||||
for i := 0; i < len(req.Data); i += batchSize {
|
||||
end := i + batchSize
|
||||
if end > len(req.Data) {
|
||||
end = len(req.Data)
|
||||
}
|
||||
|
||||
currentBatch := req.Data[i:end]
|
||||
|
||||
placeholderGroups := make([]string, 0, len(currentBatch))
|
||||
values := make([]interface{}, 0, len(currentBatch)*len(fields))
|
||||
|
||||
for _, row := range currentBatch {
|
||||
placeholders := make([]string, len(fields))
|
||||
for j := range placeholders {
|
||||
placeholders[j] = "?"
|
||||
}
|
||||
placeholderGroups = append(placeholderGroups, "("+strings.Join(placeholders, ",")+")")
|
||||
|
||||
for _, field := range fields {
|
||||
values = append(values, row[field])
|
||||
}
|
||||
}
|
||||
|
||||
// ON DUPLICATE KEY UPDATE部分
|
||||
updateClauses := make([]string, 0, len(fields))
|
||||
for _, field := range fields {
|
||||
isKey := false
|
||||
for _, key := range keys {
|
||||
if field == key {
|
||||
isKey = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isKey {
|
||||
updateClauses = append(updateClauses, fmt.Sprintf("`%s`=VALUES(`%s`)", field, field))
|
||||
}
|
||||
}
|
||||
|
||||
upsertSQL := fmt.Sprintf(
|
||||
"INSERT INTO `%s` (`%s`) VALUES %s ON DUPLICATE KEY UPDATE %s",
|
||||
req.TableName,
|
||||
strings.Join(fields, "`,`"),
|
||||
strings.Join(placeholderGroups, ","),
|
||||
strings.Join(updateClauses, ","),
|
||||
)
|
||||
|
||||
logs.CtxInfof(ctx, "[UpsertData] execute sql is %s, value is %v, batch is %d", upsertSQL, values, i)
|
||||
|
||||
result := m.db.WithContext(ctx).Exec(upsertSQL, values...)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to upsert data: %v", result.Error)
|
||||
}
|
||||
|
||||
total, inserted, updated := calculateInsertedUpdated(result.RowsAffected, len(currentBatch))
|
||||
totalInserted += inserted
|
||||
totalUpdated += updated
|
||||
totalAffected += total
|
||||
}
|
||||
|
||||
return &rdb.UpsertDataResponse{
|
||||
AffectedRows: totalAffected,
|
||||
InsertedRows: totalInserted,
|
||||
UpdatedRows: totalUpdated,
|
||||
UnchangedRows: int64(len(req.Data)) - totalAffected,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mysqlService) getTablePrimaryKeys(ctx context.Context, tableName string) ([]string, error) {
|
||||
query := `
|
||||
SELECT COLUMN_NAME
|
||||
FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
|
||||
WHERE TABLE_SCHEMA = DATABASE()
|
||||
AND TABLE_NAME = ?
|
||||
AND CONSTRAINT_NAME = 'PRIMARY'
|
||||
ORDER BY ORDINAL_POSITION
|
||||
`
|
||||
|
||||
var primaryKeys []string
|
||||
rows, err := m.db.WithContext(ctx).Raw(query, tableName).Rows()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var columnName string
|
||||
if err := rows.Scan(&columnName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
primaryKeys = append(primaryKeys, columnName)
|
||||
}
|
||||
|
||||
return primaryKeys, nil
|
||||
}
|
||||
|
||||
// calculateInsertedUpdated 函数保持不变
|
||||
func calculateInsertedUpdated(affectedRows int64, batchSize int) (int64, int64, int64) {
|
||||
updated := int64(0)
|
||||
inserted := affectedRows
|
||||
if affectedRows > int64(batchSize) {
|
||||
updated = affectedRows - int64(batchSize)
|
||||
inserted = int64(batchSize) - updated
|
||||
}
|
||||
|
||||
return inserted + updated, inserted, updated
|
||||
}
|
||||
|
||||
// ExecuteSQL Execute SQL
|
||||
func (m *mysqlService) ExecuteSQL(ctx context.Context, req *rdb.ExecuteSQLRequest) (*rdb.ExecuteSQLResponse, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("invalid request")
|
||||
}
|
||||
|
||||
logs.CtxInfof(ctx, "[ExecuteSQL] req is %v", req)
|
||||
|
||||
var processedSQL string
|
||||
var processedParams []interface{}
|
||||
var err error
|
||||
|
||||
// Handle SQLType: if raw, do not process params
|
||||
if req.SQLType == entity2.SQLType_Raw {
|
||||
processedSQL = req.SQL
|
||||
processedParams = nil
|
||||
} else {
|
||||
processedSQL, processedParams, err = m.processSliceParams(req.SQL, req.Params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process parameters: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
operation, err := sqlparser.NewSQLParser().GetSQLOperation(processedSQL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if operation != sqlparsercontract.OperationTypeSelect {
|
||||
result := m.db.WithContext(ctx).Exec(processedSQL, processedParams...)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to execute SQL: %v", result.Error)
|
||||
}
|
||||
|
||||
resultSet := &entity2.ResultSet{
|
||||
Columns: []string{},
|
||||
Rows: []map[string]interface{}{},
|
||||
AffectedRows: result.RowsAffected,
|
||||
}
|
||||
|
||||
return &rdb.ExecuteSQLResponse{
|
||||
ResultSet: resultSet,
|
||||
}, nil
|
||||
}
|
||||
|
||||
rows, err := m.db.WithContext(ctx).Raw(processedSQL, processedParams...).Rows()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute SQL: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get columns: %v", err)
|
||||
}
|
||||
|
||||
resultSet := &entity2.ResultSet{
|
||||
Columns: columns,
|
||||
Rows: make([]map[string]interface{}, 0),
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan row: %v", err)
|
||||
}
|
||||
|
||||
rowData := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
rowData[col] = values[i]
|
||||
}
|
||||
resultSet.Rows = append(resultSet.Rows, rowData)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error while reading rows: %v", err)
|
||||
}
|
||||
|
||||
return &rdb.ExecuteSQLResponse{
|
||||
ResultSet: resultSet,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mysqlService) processSliceParams(sql string, params []interface{}) (string, []interface{}, error) {
|
||||
if len(params) == 0 {
|
||||
return sql, params, nil
|
||||
}
|
||||
|
||||
processedParams := make([]interface{}, 0)
|
||||
paramIndex := 0
|
||||
resultSQL := ""
|
||||
lastPos := 0
|
||||
|
||||
// get all ? positions
|
||||
for i := 0; i < len(sql); i++ {
|
||||
if sql[i] == '?' && paramIndex < len(params) {
|
||||
resultSQL += sql[lastPos:i]
|
||||
lastPos = i + 1
|
||||
|
||||
param := params[paramIndex]
|
||||
paramIndex++
|
||||
|
||||
if m.isSlice(param) {
|
||||
sliceValues, err := m.getSliceValues(param)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if len(sliceValues) == 0 {
|
||||
resultSQL += "(NULL)"
|
||||
} else {
|
||||
// (?, ?, ...)
|
||||
placeholders := make([]string, len(sliceValues))
|
||||
for j := range placeholders {
|
||||
placeholders[j] = "?"
|
||||
}
|
||||
resultSQL += "(" + strings.Join(placeholders, ", ") + ")"
|
||||
|
||||
processedParams = append(processedParams, sliceValues...)
|
||||
}
|
||||
} else {
|
||||
resultSQL += "?"
|
||||
processedParams = append(processedParams, param)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resultSQL += sql[lastPos:]
|
||||
|
||||
return resultSQL, processedParams, nil
|
||||
}
|
||||
|
||||
func (m *mysqlService) isSlice(param interface{}) bool {
|
||||
if param == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(param)
|
||||
return rv.Kind() == reflect.Slice && rv.Type().Elem().Kind() != reflect.Uint8 // exclude []byte
|
||||
}
|
||||
|
||||
func (m *mysqlService) getSliceValues(param interface{}) ([]interface{}, error) {
|
||||
rv := reflect.ValueOf(param)
|
||||
if rv.Kind() != reflect.Slice {
|
||||
return nil, fmt.Errorf("parameter is not a slice")
|
||||
}
|
||||
|
||||
length := rv.Len()
|
||||
values := make([]interface{}, length)
|
||||
|
||||
for i := 0; i < length; i++ {
|
||||
values[i] = rv.Index(i).Interface()
|
||||
}
|
||||
|
||||
return values, nil
|
||||
}
|
||||
|
||||
func (m *mysqlService) genTableName(ctx context.Context) (string, error) {
|
||||
id, err := m.generator.GenID(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return fmt.Sprintf("table_%d", id), nil
|
||||
}
|
||||
|
||||
func (m *mysqlService) getTableInfo(ctx context.Context, tableName string) (*entity2.Table, error) {
|
||||
tableInfoSQL := `
|
||||
SELECT
|
||||
TABLE_NAME,
|
||||
TABLE_COLLATION,
|
||||
AUTO_INCREMENT,
|
||||
TABLE_COMMENT
|
||||
FROM information_schema.TABLES
|
||||
WHERE TABLE_SCHEMA = DATABASE()
|
||||
AND TABLE_NAME = ?
|
||||
`
|
||||
|
||||
var (
|
||||
name string
|
||||
collation *string
|
||||
autoIncrement *int64
|
||||
comment *string
|
||||
)
|
||||
|
||||
err := m.db.WithContext(ctx).Raw(tableInfoSQL, tableName).Row().Scan(
|
||||
&name,
|
||||
&collation,
|
||||
&autoIncrement,
|
||||
&comment,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columnsSQL := `
|
||||
SELECT
|
||||
COLUMN_NAME,
|
||||
DATA_TYPE,
|
||||
CHARACTER_MAXIMUM_LENGTH,
|
||||
IS_NULLABLE,
|
||||
COLUMN_DEFAULT,
|
||||
EXTRA,
|
||||
COLUMN_COMMENT
|
||||
FROM information_schema.COLUMNS
|
||||
WHERE TABLE_SCHEMA = DATABASE()
|
||||
AND TABLE_NAME = ?
|
||||
ORDER BY ORDINAL_POSITION
|
||||
`
|
||||
|
||||
type columnInfo struct {
|
||||
ColumnName string `gorm:"column:COLUMN_NAME"`
|
||||
DataType string `gorm:"column:DATA_TYPE"`
|
||||
CharLength *int `gorm:"column:CHARACTER_MAXIMUM_LENGTH"`
|
||||
IsNullable string `gorm:"column:IS_NULLABLE"`
|
||||
DefaultValue *string `gorm:"column:COLUMN_DEFAULT"`
|
||||
Extra string `gorm:"column:EXTRA"`
|
||||
ColumnComment *string `gorm:"column:COLUMN_COMMENT"`
|
||||
}
|
||||
|
||||
var columnsData []columnInfo
|
||||
err = m.db.WithContext(ctx).Raw(columnsSQL, tableName).Scan(&columnsData).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns := make([]*entity2.Column, len(columnsData))
|
||||
for i, colData := range columnsData {
|
||||
column := &entity2.Column{
|
||||
Name: colData.ColumnName,
|
||||
DataType: entity2.DataType(colData.DataType),
|
||||
Length: colData.CharLength,
|
||||
NotNull: colData.IsNullable == "NO",
|
||||
DefaultValue: colData.DefaultValue,
|
||||
AutoIncrement: strings.Contains(colData.Extra, "auto_increment"),
|
||||
Comment: colData.ColumnComment,
|
||||
}
|
||||
columns[i] = column
|
||||
}
|
||||
|
||||
indexesSQL := `
|
||||
SELECT
|
||||
INDEX_NAME,
|
||||
NON_UNIQUE,
|
||||
INDEX_TYPE,
|
||||
GROUP_CONCAT(COLUMN_NAME ORDER BY SEQ_IN_INDEX)
|
||||
FROM information_schema.STATISTICS
|
||||
WHERE TABLE_SCHEMA = DATABASE()
|
||||
AND TABLE_NAME = ?
|
||||
GROUP BY INDEX_NAME, NON_UNIQUE, INDEX_TYPE
|
||||
`
|
||||
|
||||
type indexInfo struct {
|
||||
IndexName string `gorm:"column:INDEX_NAME"`
|
||||
NonUnique int `gorm:"column:NON_UNIQUE"`
|
||||
IndexType string `gorm:"column:INDEX_TYPE"`
|
||||
Columns string `gorm:"column:GROUP_CONCAT"`
|
||||
}
|
||||
|
||||
var indexesData []indexInfo
|
||||
err = m.db.WithContext(ctx).Raw(indexesSQL, tableName).Scan(&indexesData).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexes := make([]*entity2.Index, 0, len(indexesData))
|
||||
for _, idxData := range indexesData {
|
||||
index := &entity2.Index{
|
||||
Name: idxData.IndexName,
|
||||
Type: entity2.IndexType(idxData.IndexType),
|
||||
Columns: strings.Split(idxData.Columns, ","),
|
||||
}
|
||||
indexes = append(indexes, index)
|
||||
}
|
||||
|
||||
return &entity2.Table{
|
||||
Name: name,
|
||||
Columns: columns,
|
||||
Indexes: indexes,
|
||||
Options: &entity2.TableOption{
|
||||
Collate: collation,
|
||||
AutoIncrement: autoIncrement,
|
||||
Comment: comment,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mysqlService) buildWhereClause(condition *rdb.ComplexCondition) (string, []interface{}, error) {
|
||||
if condition == nil {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
if condition.Operator == "" {
|
||||
condition.Operator = entity2.AND
|
||||
}
|
||||
|
||||
var whereClause strings.Builder
|
||||
values := make([]interface{}, 0)
|
||||
|
||||
for i, cond := range condition.Conditions {
|
||||
if i > 0 {
|
||||
whereClause.WriteString(fmt.Sprintf(" %s ", condition.Operator))
|
||||
}
|
||||
|
||||
if cond.Operator == entity2.OperatorIn || cond.Operator == entity2.OperatorNotIn {
|
||||
if m.isSlice(cond.Value) {
|
||||
sliceValues, err := m.getSliceValues(cond.Value)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to process slice values: %v", err)
|
||||
}
|
||||
|
||||
if len(sliceValues) == 0 {
|
||||
whereClause.WriteString(fmt.Sprintf("`%s` %s (NULL)", cond.Field, string(cond.Operator)))
|
||||
} else {
|
||||
placeholders := make([]string, len(sliceValues))
|
||||
for i := range placeholders {
|
||||
placeholders[i] = "?"
|
||||
}
|
||||
whereClause.WriteString(fmt.Sprintf("`%s` %s (%s)", cond.Field, string(cond.Operator), strings.Join(placeholders, ",")))
|
||||
|
||||
values = append(values, sliceValues...)
|
||||
}
|
||||
} else {
|
||||
return "", nil, fmt.Errorf("IN operator requires a slice of values")
|
||||
}
|
||||
} else if cond.Operator == entity2.OperatorIsNull || cond.Operator == entity2.OperatorIsNotNull {
|
||||
whereClause.WriteString(fmt.Sprintf("`%s` %s", cond.Field, cond.Operator))
|
||||
} else {
|
||||
whereClause.WriteString(fmt.Sprintf("`%s` %s ?", cond.Field, cond.Operator))
|
||||
values = append(values, cond.Value)
|
||||
}
|
||||
}
|
||||
|
||||
if len(condition.NestedConditions) > 0 {
|
||||
whereClause.WriteString(" AND (")
|
||||
for i, nested := range condition.NestedConditions {
|
||||
if i > 0 {
|
||||
whereClause.WriteString(fmt.Sprintf(" %s ", nested.Operator))
|
||||
}
|
||||
nestedClause, nestedValues, err := m.buildWhereClause(nested)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
whereClause.WriteString(nestedClause)
|
||||
values = append(values, nestedValues...)
|
||||
}
|
||||
whereClause.WriteString(")")
|
||||
}
|
||||
|
||||
if whereClause.Len() > 0 {
|
||||
return " WHERE " + whereClause.String(), values, nil
|
||||
}
|
||||
return "", values, nil
|
||||
}
|
||||
891
backend/infra/impl/rdb/mysql_test.go
Normal file
891
backend/infra/impl/rdb/mysql_test.go
Normal file
@@ -0,0 +1,891 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package rdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
|
||||
entity2 "github.com/coze-dev/coze-studio/backend/infra/contract/rdb/entity"
|
||||
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
func setupTestDB(t *testing.T) (*gorm.DB, rdb.RDB) {
|
||||
dsn := "root:root@tcp(127.0.0.1:3306)/opencoze?charset=utf8mb4&parseTime=True&loc=Local"
|
||||
if os.Getenv("CI_JOB_NAME") != "" {
|
||||
dsn = strings.ReplaceAll(dsn, "127.0.0.1", "mysql")
|
||||
}
|
||||
db, err := gorm.Open(mysql.Open(dsn))
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
idGen := mock.NewMockIDGenerator(ctrl)
|
||||
idGen.EXPECT().GenID(gomock.Any()).Return(int64(123), nil).AnyTimes()
|
||||
|
||||
return db, NewService(db, idGen)
|
||||
}
|
||||
|
||||
func cleanupTestDB(t *testing.T, db *gorm.DB, tableNames ...string) {
|
||||
for _, tableName := range tableNames {
|
||||
db.WithContext(context.Background()).Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", tableName))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateTable(t *testing.T) {
|
||||
db, svc := setupTestDB(t)
|
||||
defer cleanupTestDB(t, db, "test_table")
|
||||
|
||||
length := 255
|
||||
req := &rdb.CreateTableRequest{
|
||||
Table: &entity2.Table{
|
||||
Name: "test_table",
|
||||
Columns: []*entity2.Column{
|
||||
{
|
||||
Name: "id",
|
||||
DataType: entity2.TypeInt,
|
||||
NotNull: true,
|
||||
},
|
||||
{
|
||||
Name: "name",
|
||||
DataType: entity2.TypeVarchar,
|
||||
Length: &length,
|
||||
NotNull: true,
|
||||
},
|
||||
{
|
||||
Name: "created_at",
|
||||
DataType: entity2.TypeTimestamp,
|
||||
NotNull: true,
|
||||
DefaultValue: func() *string {
|
||||
val := "CURRENT_TIMESTAMP"
|
||||
return &val
|
||||
}(),
|
||||
},
|
||||
{
|
||||
Name: "score",
|
||||
DataType: entity2.TypeDouble,
|
||||
NotNull: true,
|
||||
DefaultValue: ptr.Of("60.5"),
|
||||
},
|
||||
},
|
||||
Indexes: []*entity2.Index{
|
||||
{
|
||||
Name: "PRIMARY",
|
||||
Type: entity2.PrimaryKey,
|
||||
Columns: []string{"id"},
|
||||
},
|
||||
{
|
||||
Name: "idx_name",
|
||||
Type: entity2.NormalKey,
|
||||
Columns: []string{"name"},
|
||||
},
|
||||
},
|
||||
Options: &entity2.TableOption{
|
||||
Comment: func() *string {
|
||||
comment := "Test table created by unit test"
|
||||
return &comment
|
||||
}(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := svc.CreateTable(context.Background(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, req.Table.Name, resp.Table.Name)
|
||||
|
||||
var tableExists bool
|
||||
err = db.Raw("SELECT 1 FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = ?", "test_table").Scan(&tableExists).Error
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, tableExists)
|
||||
}
|
||||
|
||||
func TestAlterTable(t *testing.T) {
|
||||
|
||||
db, svc := setupTestDB(t)
|
||||
defer cleanupTestDB(t, db, "test_table")
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS test_table (
|
||||
id INT NOT NULL AUTO_INCREMENT,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
description VARCHAR(255) NOT NULL,
|
||||
droped VARCHAR(255) NOT NULL,
|
||||
PRIMARY KEY (id),
|
||||
INDEX idx_name (name)
|
||||
) COMMENT='Test table created by unit test'
|
||||
`).Error
|
||||
assert.NoError(t, err, "Failed to create test table")
|
||||
|
||||
length := 100
|
||||
req := &rdb.AlterTableRequest{
|
||||
TableName: "test_table",
|
||||
Operations: []*rdb.AlterTableOperation{
|
||||
{
|
||||
Action: entity2.AddColumn,
|
||||
Column: &entity2.Column{
|
||||
Name: "email",
|
||||
DataType: entity2.TypeVarchar,
|
||||
Length: &length,
|
||||
NotNull: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
Action: entity2.ModifyColumn,
|
||||
Column: &entity2.Column{
|
||||
Name: "description",
|
||||
DataType: entity2.TypeText,
|
||||
NotNull: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
Action: entity2.DropColumn,
|
||||
Column: &entity2.Column{
|
||||
Name: "droped",
|
||||
DataType: entity2.TypeVarchar,
|
||||
NotNull: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := svc.AlterTable(context.Background(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, "test_table", resp.Table.Name)
|
||||
|
||||
var columnExists bool
|
||||
err = db.Raw("SELECT 1 FROM information_schema.columns WHERE table_schema = DATABASE() AND table_name = ? AND column_name = ?", "test_table", "email").Scan(&columnExists).Error
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, columnExists)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetTable(t *testing.T) {
|
||||
db, svc := setupTestDB(t)
|
||||
defer cleanupTestDB(t, db, "test_info_table")
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS test_info_table (
|
||||
id INT NOT NULL AUTO_INCREMENT,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
description TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY (id),
|
||||
INDEX idx_name (name)
|
||||
) COMMENT='Table info test'
|
||||
`).Error
|
||||
assert.NoError(t, err, "Failed to create test table")
|
||||
|
||||
req := &rdb.GetTableRequest{
|
||||
TableName: "test_info_table",
|
||||
}
|
||||
|
||||
resp, err := svc.GetTable(context.Background(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, "test_info_table", resp.Table.Name)
|
||||
assert.Equal(t, len(resp.Table.Columns), 4)
|
||||
|
||||
columnMap := make(map[string]*entity2.Column)
|
||||
for _, col := range resp.Table.Columns {
|
||||
columnMap[col.Name] = col
|
||||
}
|
||||
|
||||
assert.Contains(t, columnMap, "id")
|
||||
assert.Contains(t, columnMap, "name")
|
||||
assert.Contains(t, columnMap, "created_at")
|
||||
})
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS test_info_table (
|
||||
id INT NOT NULL AUTO_INCREMENT,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
description TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY (id),
|
||||
INDEX idx_name (name)
|
||||
) COMMENT='Table info test'
|
||||
`).Error
|
||||
assert.NoError(t, err, "Failed to create test table")
|
||||
|
||||
req := &rdb.GetTableRequest{
|
||||
TableName: "test_info_table_error_name",
|
||||
}
|
||||
|
||||
resp, err := svc.GetTable(context.Background(), req)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, err, sql.ErrNoRows)
|
||||
assert.Nil(t, resp)
|
||||
})
|
||||
}
|
||||
|
||||
func TestInsertData(t *testing.T) {
|
||||
db, svc := setupTestDB(t)
|
||||
defer cleanupTestDB(t, db, "test_insert_table")
|
||||
|
||||
err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS test_insert_table (
|
||||
id INT NOT NULL AUTO_INCREMENT,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
age INT,
|
||||
PRIMARY KEY (id)
|
||||
)
|
||||
`).Error
|
||||
assert.NoError(t, err, "Failed to create test table")
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
req := &rdb.InsertDataRequest{
|
||||
TableName: "test_insert_table",
|
||||
Data: []map[string]interface{}{
|
||||
{
|
||||
"name": "John Doe",
|
||||
"age": 30,
|
||||
},
|
||||
{
|
||||
"name": "Jane Smith",
|
||||
"age": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := svc.InsertData(context.Background(), req)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, int64(2), resp.AffectedRows)
|
||||
|
||||
var count int
|
||||
err = db.Raw("SELECT COUNT(*) FROM test_insert_table").Scan(&count).Error
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
})
|
||||
|
||||
t.Run("table name error", func(t *testing.T) {
|
||||
req := &rdb.InsertDataRequest{
|
||||
TableName: "test_insert_table_error_name",
|
||||
Data: []map[string]interface{}{
|
||||
{
|
||||
"name": "John Doe",
|
||||
"age": 30,
|
||||
},
|
||||
{
|
||||
"name": "Jane Smith",
|
||||
"age": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := svc.InsertData(context.Background(), req)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, resp)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateData(t *testing.T) {
|
||||
db, svc := setupTestDB(t)
|
||||
defer cleanupTestDB(t, db, "test_update_table")
|
||||
|
||||
err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS test_update_table (
|
||||
id INT NOT NULL AUTO_INCREMENT,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
age INT,
|
||||
status VARCHAR(20) DEFAULT 'active',
|
||||
PRIMARY KEY (id)
|
||||
)
|
||||
`).Error
|
||||
assert.NoError(t, err, "Failed to create test table")
|
||||
|
||||
err = db.Exec("INSERT INTO test_update_table (name, age) VALUES (?, ?), (?, ?)",
|
||||
"John Doe", 30, "Jane Smith", 25).Error
|
||||
assert.NoError(t, err, "Failed to insert test data")
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
req := &rdb.UpdateDataRequest{
|
||||
TableName: "test_update_table",
|
||||
Data: map[string]interface{}{
|
||||
"age": 35,
|
||||
"status": "updated",
|
||||
},
|
||||
Where: &rdb.ComplexCondition{
|
||||
Conditions: []*rdb.Condition{
|
||||
{
|
||||
Field: "name",
|
||||
Operator: entity2.OperatorEqual,
|
||||
Value: "John Doe",
|
||||
},
|
||||
},
|
||||
Operator: entity2.AND,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := svc.UpdateData(context.Background(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, int64(1), resp.AffectedRows)
|
||||
|
||||
type Result struct {
|
||||
Age int
|
||||
Status string
|
||||
}
|
||||
var result Result
|
||||
err = db.Raw("SELECT age, status FROM test_update_table WHERE name = ?", "John Doe").Scan(&result).Error
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 35, result.Age)
|
||||
assert.Equal(t, "updated", result.Status)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteData(t *testing.T) {
|
||||
db, svc := setupTestDB(t)
|
||||
defer cleanupTestDB(t, db, "test_delete_table")
|
||||
|
||||
err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS test_delete_table (
|
||||
id INT NOT NULL AUTO_INCREMENT,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
age INT,
|
||||
PRIMARY KEY (id)
|
||||
)
|
||||
`).Error
|
||||
assert.NoError(t, err, "Failed to create test table")
|
||||
|
||||
err = db.Exec("INSERT INTO test_delete_table (name, age) VALUES (?, ?), (?, ?), (?, ?)",
|
||||
"John Doe", 30, "Jane Smith", 25, "Bob Johnson", 40).Error
|
||||
assert.NoError(t, err, "Failed to insert test data")
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
req := &rdb.DeleteDataRequest{
|
||||
TableName: "test_delete_table",
|
||||
Where: &rdb.ComplexCondition{
|
||||
Conditions: []*rdb.Condition{
|
||||
{
|
||||
Field: "age",
|
||||
Operator: entity2.OperatorGreaterEqual,
|
||||
Value: 30,
|
||||
},
|
||||
},
|
||||
Operator: entity2.AND,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := svc.DeleteData(context.Background(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, int64(2), resp.AffectedRows)
|
||||
|
||||
var count int
|
||||
err = db.Raw("SELECT COUNT(*) FROM test_delete_table").Scan(&count).Error
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, count)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSelectData(t *testing.T) {
|
||||
db, svc := setupTestDB(t)
|
||||
defer cleanupTestDB(t, db, "test_select_table")
|
||||
|
||||
err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS test_select_table (
|
||||
id INT NOT NULL AUTO_INCREMENT,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
age BIGINT,
|
||||
status VARCHAR(20) DEFAULT 'active',
|
||||
score FLOAT,
|
||||
score2 DOUBLE DEFAULT '90.5',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY (id)
|
||||
)
|
||||
`).Error
|
||||
assert.NoError(t, err, "Failed to create test table")
|
||||
|
||||
err = db.Exec(`
|
||||
INSERT INTO test_select_table (name, age, status, score, score2) VALUES
|
||||
(?, ?, ?, ?, ?), (?, ?, ?, ?, ?), (?, ?, ?, ?, ?), (?, ?, ?, ?, ?)`,
|
||||
"John Doe", 30, "active", 1.1, 89.55554444,
|
||||
"Jane Smith", 25, "active", 1.2, 90.55554444,
|
||||
"Bob Johnson", 40, "inactive", 1.3, 91.55554444,
|
||||
"Alice Brown", 35, "active", nil, 92.55554444).Error
|
||||
assert.NoError(t, err, "Failed to insert test data")
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
req := &rdb.SelectDataRequest{
|
||||
TableName: "test_select_table",
|
||||
Fields: []string{"id", "name", "age", "created_at", "score", "score2"},
|
||||
Where: &rdb.ComplexCondition{
|
||||
Conditions: []*rdb.Condition{
|
||||
{
|
||||
Field: "status",
|
||||
Operator: entity2.OperatorEqual,
|
||||
Value: "active",
|
||||
},
|
||||
{
|
||||
Field: "age",
|
||||
Operator: entity2.OperatorGreaterEqual,
|
||||
Value: 25,
|
||||
},
|
||||
},
|
||||
Operator: entity2.AND,
|
||||
},
|
||||
OrderBy: []*rdb.OrderBy{
|
||||
{
|
||||
Field: "age",
|
||||
Direction: entity2.SortDirectionDesc,
|
||||
},
|
||||
},
|
||||
Limit: func() *int { limit := 2; return &limit }(),
|
||||
Offset: func() *int { offset := 0; return &offset }(),
|
||||
}
|
||||
|
||||
resp, err := svc.SelectData(context.Background(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, int64(3), resp.Total)
|
||||
assert.Len(t, resp.ResultSet.Rows, 2)
|
||||
|
||||
if len(resp.ResultSet.Rows) > 0 {
|
||||
firstRow := resp.ResultSet.Rows[0]
|
||||
assert.Equal(t, "Alice Brown", string(firstRow["name"].([]uint8)))
|
||||
assert.Equal(t, int64(35), firstRow["age"])
|
||||
|
||||
assert.Equal(t, int64(35), firstRow["age"].(int64))
|
||||
assert.Equal(t, nil, firstRow["score"])
|
||||
timeR := firstRow["created_at"].(time.Time)
|
||||
assert.False(t, timeR.IsZero())
|
||||
assert.Nil(t, firstRow["score"])
|
||||
assert.Equal(t, 92.55554444, firstRow["score2"].(float64))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
req := &rdb.SelectDataRequest{
|
||||
TableName: "test_select_table",
|
||||
Fields: []string{"id", "name", "age", "created_at", "score"},
|
||||
Where: &rdb.ComplexCondition{
|
||||
Conditions: []*rdb.Condition{
|
||||
{
|
||||
Field: "age",
|
||||
Operator: entity2.OperatorIn,
|
||||
Value: []int{30, 25, 18},
|
||||
},
|
||||
},
|
||||
Operator: entity2.AND,
|
||||
},
|
||||
OrderBy: []*rdb.OrderBy{
|
||||
{
|
||||
Field: "age",
|
||||
Direction: entity2.SortDirectionDesc,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := svc.SelectData(context.Background(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, int64(2), resp.Total)
|
||||
assert.Len(t, resp.ResultSet.Rows, 2)
|
||||
|
||||
if len(resp.ResultSet.Rows) > 0 {
|
||||
firstRow := resp.ResultSet.Rows[0]
|
||||
assert.Equal(t, "John Doe", string(firstRow["name"].([]uint8)))
|
||||
assert.Equal(t, int64(30), firstRow["age"])
|
||||
|
||||
assert.Equal(t, float32(1.1), firstRow["score"].(float32))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExecuteSQL(t *testing.T) {
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
db, svc := setupTestDB(t)
|
||||
defer cleanupTestDB(t, db, "test_sql_table")
|
||||
|
||||
err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS test_sql_table (
|
||||
id INT NOT NULL AUTO_INCREMENT,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
age INT,
|
||||
PRIMARY KEY (id)
|
||||
)
|
||||
`).Error
|
||||
assert.NoError(t, err, "Failed to create test table")
|
||||
err = db.Exec("INSERT INTO test_sql_table (name, age) VALUES (?, ?), (?, ?)",
|
||||
"John Doe", 30, "Jane Smith", 25).Error
|
||||
assert.NoError(t, err, "Failed to insert test data")
|
||||
|
||||
req := &rdb.ExecuteSQLRequest{
|
||||
SQL: "SELECT id, name, age FROM test_sql_table WHERE age in ? and name in ? ORDER BY age DESC",
|
||||
Params: []interface{}{[]int{30, 25}, []string{"John Doe", "Jane Smith"}},
|
||||
}
|
||||
|
||||
resp, err := svc.ExecuteSQL(context.Background(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Len(t, resp.ResultSet.Rows, 2)
|
||||
assert.Equal(t, []string{"id", "name", "age"}, resp.ResultSet.Columns)
|
||||
|
||||
if len(resp.ResultSet.Rows) > 0 {
|
||||
firstRow := resp.ResultSet.Rows[0]
|
||||
assert.Equal(t, "John Doe", string(firstRow["name"].([]uint8)))
|
||||
}
|
||||
|
||||
rawReq := &rdb.ExecuteSQLRequest{
|
||||
SQL: "SELECT id, name, age FROM test_sql_table WHERE age in (30, 25) and name in (\"John Doe\", \"Jane Smith\") ORDER BY age DESC",
|
||||
}
|
||||
|
||||
rawResp, err := svc.ExecuteSQL(context.Background(), rawReq)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, rawResp)
|
||||
assert.Len(t, rawResp.ResultSet.Rows, 2)
|
||||
assert.Equal(t, []string{"id", "name", "age"}, rawResp.ResultSet.Columns)
|
||||
|
||||
if len(rawResp.ResultSet.Rows) > 0 {
|
||||
firstRow := rawResp.ResultSet.Rows[0]
|
||||
assert.Equal(t, "John Doe", string(firstRow["name"].([]uint8)))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
db, svc := setupTestDB(t)
|
||||
defer cleanupTestDB(t, db, "test_sql_table")
|
||||
|
||||
err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS test_sql_table (
|
||||
id INT NOT NULL AUTO_INCREMENT,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
age INT,
|
||||
PRIMARY KEY (id)
|
||||
)
|
||||
`).Error
|
||||
assert.NoError(t, err, "Failed to create test table")
|
||||
err = db.Exec("INSERT INTO test_sql_table (name, age) VALUES (?, ?), (?, ?)",
|
||||
"John Doe", 30, "Jane Smith", 25).Error
|
||||
assert.NoError(t, err, "Failed to insert test data")
|
||||
|
||||
req := &rdb.ExecuteSQLRequest{
|
||||
SQL: "SELECT id, name, age FROM test_sql_table WHERE age in (?, ?) and name in (?, ?) ORDER BY age DESC",
|
||||
Params: []interface{}{30, 25, "John Doe", "Jane Smith"},
|
||||
}
|
||||
|
||||
resp, err := svc.ExecuteSQL(context.Background(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, len(resp.ResultSet.Rows), 2)
|
||||
assert.Equal(t, []string{"id", "name", "age"}, resp.ResultSet.Columns)
|
||||
|
||||
if len(resp.ResultSet.Rows) > 0 {
|
||||
firstRow := resp.ResultSet.Rows[0]
|
||||
assert.Equal(t, "John Doe", string(firstRow["name"].([]uint8)))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpsertData(t *testing.T) {
|
||||
t.Run("insert new records", func(t *testing.T) {
|
||||
db, svc := setupTestDB(t)
|
||||
defer cleanupTestDB(t, db, "test_upsert_table")
|
||||
|
||||
err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS test_upsert_table (
|
||||
id INT NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
age INT,
|
||||
status VARCHAR(20) DEFAULT 'active',
|
||||
PRIMARY KEY (id),
|
||||
UNIQUE KEY idx_name (name)
|
||||
)
|
||||
`).Error
|
||||
assert.NoError(t, err, "Failed to create test table")
|
||||
|
||||
req := &rdb.UpsertDataRequest{
|
||||
TableName: "test_upsert_table",
|
||||
Data: []map[string]interface{}{
|
||||
{
|
||||
"id": 1,
|
||||
"name": "John Doe",
|
||||
"age": 30,
|
||||
"status": "active",
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"name": "Jane Smith",
|
||||
"age": 25,
|
||||
"status": "active",
|
||||
},
|
||||
},
|
||||
Keys: []string{"name"},
|
||||
}
|
||||
|
||||
resp, err := svc.UpsertData(context.Background(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, int64(2), resp.InsertedRows)
|
||||
assert.Equal(t, int64(0), resp.UpdatedRows)
|
||||
|
||||
var count int
|
||||
err = db.Raw("SELECT COUNT(*) FROM test_upsert_table WHERE name IN (?, ?)",
|
||||
"John Doe", "Jane Smith").Scan(&count).Error
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
})
|
||||
|
||||
t.Run("error", func(t *testing.T) {
|
||||
db, svc := setupTestDB(t)
|
||||
defer cleanupTestDB(t, db, "test_upsert_table")
|
||||
|
||||
err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS test_upsert_table (
|
||||
id INT NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
age INT,
|
||||
status VARCHAR(20) DEFAULT 'active',
|
||||
PRIMARY KEY (id),
|
||||
UNIQUE KEY idx_name (name)
|
||||
)
|
||||
`).Error
|
||||
assert.NoError(t, err, "Failed to create test table")
|
||||
|
||||
reqInsert := &rdb.UpsertDataRequest{
|
||||
TableName: "test_upsert_table",
|
||||
Data: []map[string]interface{}{
|
||||
{
|
||||
"id": 1,
|
||||
"name": "John Doe",
|
||||
"age": 30,
|
||||
"status": "active",
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"name": "Jane Smith",
|
||||
"age": 25,
|
||||
"status": "active",
|
||||
},
|
||||
},
|
||||
Keys: []string{"name"},
|
||||
}
|
||||
_, err = svc.UpsertData(context.Background(), reqInsert)
|
||||
assert.NoError(t, err)
|
||||
|
||||
req := &rdb.UpsertDataRequest{
|
||||
TableName: "test_upsert_table",
|
||||
Data: []map[string]interface{}{
|
||||
{
|
||||
"name": "New Person",
|
||||
"age": 40,
|
||||
"status": "active",
|
||||
},
|
||||
},
|
||||
Keys: []string{"name"},
|
||||
}
|
||||
|
||||
resp, err := svc.UpsertData(context.Background(), req)
|
||||
assert.Nil(t, resp)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("update existing records", func(t *testing.T) {
|
||||
db, svc := setupTestDB(t)
|
||||
defer cleanupTestDB(t, db, "test_upsert_table")
|
||||
|
||||
err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS test_upsert_table (
|
||||
id INT NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
age INT,
|
||||
status VARCHAR(20) DEFAULT 'active',
|
||||
PRIMARY KEY (id),
|
||||
UNIQUE KEY idx_name (name)
|
||||
)
|
||||
`).Error
|
||||
assert.NoError(t, err, "Failed to create test table")
|
||||
|
||||
reqInsert := &rdb.UpsertDataRequest{
|
||||
TableName: "test_upsert_table",
|
||||
Data: []map[string]interface{}{
|
||||
{
|
||||
"id": 1,
|
||||
"name": "John Doe",
|
||||
"age": 30,
|
||||
"status": "active",
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"name": "Jane Smith",
|
||||
"age": 25,
|
||||
"status": "active",
|
||||
},
|
||||
},
|
||||
Keys: []string{"name"},
|
||||
}
|
||||
_, err = svc.UpsertData(context.Background(), reqInsert)
|
||||
assert.NoError(t, err)
|
||||
|
||||
req := &rdb.UpsertDataRequest{
|
||||
TableName: "test_upsert_table",
|
||||
Data: []map[string]interface{}{
|
||||
{
|
||||
"id": 1,
|
||||
"name": "John Doe",
|
||||
"age": 35,
|
||||
"status": "updated",
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"name": "Jane Smith",
|
||||
"age": 25,
|
||||
"status": "updated",
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"name": "New Person",
|
||||
"age": 40,
|
||||
"status": "active",
|
||||
},
|
||||
{
|
||||
"id": 4,
|
||||
"name": "New Person 2",
|
||||
"age": 40,
|
||||
"status": "active",
|
||||
},
|
||||
{
|
||||
"id": 5,
|
||||
"name": "New Person 3",
|
||||
"age": 40,
|
||||
"status": "active",
|
||||
},
|
||||
},
|
||||
Keys: []string{"name"},
|
||||
}
|
||||
|
||||
resp, err := svc.UpsertData(context.Background(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, int64(3), resp.InsertedRows)
|
||||
assert.Equal(t, int64(2), resp.UpdatedRows)
|
||||
})
|
||||
|
||||
t.Run("use primary key when keys not specified", func(t *testing.T) {
|
||||
db, svc := setupTestDB(t)
|
||||
defer cleanupTestDB(t, db, "test_upsert_table")
|
||||
|
||||
err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS test_upsert_table (
|
||||
id INT NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
age INT,
|
||||
status VARCHAR(20) DEFAULT 'active',
|
||||
PRIMARY KEY (age),
|
||||
UNIQUE KEY idx_name (name)
|
||||
)
|
||||
`).Error
|
||||
assert.NoError(t, err, "Failed to create test table")
|
||||
|
||||
reqInsert := &rdb.UpsertDataRequest{
|
||||
TableName: "test_upsert_table",
|
||||
Data: []map[string]interface{}{
|
||||
{
|
||||
"id": 1,
|
||||
"name": "John Doe",
|
||||
"age": 30,
|
||||
"status": "active",
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"name": "Jane Smith",
|
||||
"age": 25,
|
||||
"status": "active",
|
||||
},
|
||||
},
|
||||
}
|
||||
_, err = svc.UpsertData(context.Background(), reqInsert)
|
||||
assert.NoError(t, err)
|
||||
|
||||
req := &rdb.UpsertDataRequest{
|
||||
TableName: "test_upsert_table",
|
||||
Data: []map[string]interface{}{
|
||||
{
|
||||
"id": 1,
|
||||
"name": "John Doe Updated",
|
||||
"age": 30,
|
||||
"status": "primary key updated",
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"name": "New Person",
|
||||
"age": 40,
|
||||
"status": "active",
|
||||
},
|
||||
{
|
||||
"id": 4,
|
||||
"name": "New Person 2",
|
||||
"age": 45,
|
||||
"status": "active",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := svc.UpsertData(context.Background(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, int64(2), resp.InsertedRows)
|
||||
assert.Equal(t, int64(1), resp.UpdatedRows)
|
||||
req = &rdb.UpsertDataRequest{
|
||||
TableName: "test_upsert_table",
|
||||
Data: []map[string]interface{}{
|
||||
{
|
||||
"id": 1,
|
||||
"name": "John Doe Updated",
|
||||
"age": 30,
|
||||
"status": "primary key updated",
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"name": "New Person",
|
||||
"age": 40,
|
||||
"status": "active",
|
||||
},
|
||||
{
|
||||
"id": 4,
|
||||
"name": "New Person 2",
|
||||
"age": 45,
|
||||
"status": "active update",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err = svc.UpsertData(context.Background(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, int64(2), resp.AffectedRows)
|
||||
assert.Equal(t, int64(1), resp.UnchangedRows)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user