feat: manually mirror opencoze's code from bytedance

Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
fanlv
2025-07-20 17:36:12 +08:00
commit 890153324f
14811 changed files with 1923430 additions and 0 deletions

View File

@@ -0,0 +1,993 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package rdb
import (
"context"
"fmt"
"reflect"
"strings"
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
entity2 "github.com/coze-dev/coze-studio/backend/infra/contract/rdb/entity"
sqlparsercontract "github.com/coze-dev/coze-studio/backend/infra/contract/sqlparser"
"github.com/coze-dev/coze-studio/backend/infra/impl/sqlparser"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
type mysqlService struct {
db *gorm.DB
generator idgen.IDGenerator
}
func NewService(db *gorm.DB, generator idgen.IDGenerator) rdb.RDB {
return &mysqlService{db: db, generator: generator}
}
// CreateTable create table
func (m *mysqlService) CreateTable(ctx context.Context, req *rdb.CreateTableRequest) (*rdb.CreateTableResponse, error) {
if req == nil || req.Table == nil {
return nil, fmt.Errorf("invalid request")
}
// build column definitions
columnDefs := make([]string, 0, len(req.Table.Columns))
for _, col := range req.Table.Columns {
colDef := fmt.Sprintf("`%s` %s", col.Name, col.DataType)
if col.Length != nil {
colDef += fmt.Sprintf("(%d)", *col.Length)
} else if col.Length == nil && col.DataType == entity2.TypeVarchar {
colDef += fmt.Sprintf("(%d)", 255)
}
if col.NotNull {
colDef += " NOT NULL"
}
if col.DefaultValue != nil {
if col.DataType == entity2.TypeTimestamp {
colDef += fmt.Sprintf(" DEFAULT %s", *col.DefaultValue)
} else {
colDef += fmt.Sprintf(" DEFAULT '%s'", *col.DefaultValue)
}
}
if col.AutoIncrement {
colDef += " AUTO_INCREMENT"
}
if col.Comment != nil {
colDef += fmt.Sprintf(" COMMENT '%s'", *col.Comment)
}
columnDefs = append(columnDefs, colDef)
}
// build index definitions
for _, idx := range req.Table.Indexes {
var idxDef string
switch idx.Type {
case entity2.PrimaryKey:
idxDef = fmt.Sprintf("PRIMARY KEY (`%s`)", strings.Join(idx.Columns, "`,`"))
case entity2.UniqueKey:
idxDef = fmt.Sprintf("UNIQUE KEY `%s` (`%s`)", idx.Name, strings.Join(idx.Columns, "`,`"))
default:
idxDef = fmt.Sprintf("KEY `%s` (`%s`)", idx.Name, strings.Join(idx.Columns, "`,`"))
}
columnDefs = append(columnDefs, idxDef)
}
tableOptions := make([]string, 0)
if req.Table.Options != nil {
if req.Table.Options.Collate != nil {
tableOptions = append(tableOptions, fmt.Sprintf("COLLATE=%s", *req.Table.Options.Collate))
}
if req.Table.Options.AutoIncrement != nil {
tableOptions = append(tableOptions, fmt.Sprintf("AUTO_INCREMENT=%d", *req.Table.Options.AutoIncrement))
}
if req.Table.Options.Comment != nil {
tableOptions = append(tableOptions, fmt.Sprintf("COMMENT='%s'", *req.Table.Options.Comment))
}
}
tableName := req.Table.Name
if req.Table.Name == "" {
genName, err := m.genTableName(ctx)
if err != nil {
return nil, err
}
tableName = genName
}
createSQL := fmt.Sprintf("CREATE TABLE IF NOT EXISTS `%s` (\n %s\n) %s",
tableName,
strings.Join(columnDefs, ",\n "),
strings.Join(tableOptions, " "),
)
logs.CtxInfof(ctx, "[CreateTable] execute sql is %s, req is %v", createSQL, req)
err := m.db.WithContext(ctx).Exec(createSQL).Error
if err != nil {
return nil, fmt.Errorf("failed to create table: %v", err)
}
resTable := req.Table
resTable.Name = tableName
return &rdb.CreateTableResponse{Table: resTable}, nil
}
// AlterTable alter table
func (m *mysqlService) AlterTable(ctx context.Context, req *rdb.AlterTableRequest) (*rdb.AlterTableResponse, error) {
if req == nil || len(req.Operations) == 0 {
return nil, fmt.Errorf("invalid request")
}
alterSQL := fmt.Sprintf("ALTER TABLE `%s`", req.TableName)
operations := make([]string, 0, len(req.Operations))
for _, op := range req.Operations {
switch op.Action {
case entity2.AddColumn:
if op.Column == nil {
return nil, fmt.Errorf("column is required for ADD COLUMN operation")
}
colDef := fmt.Sprintf("ADD COLUMN `%s` %s", op.Column.Name, op.Column.DataType)
if op.Column.Length != nil {
colDef += fmt.Sprintf("(%d)", *op.Column.Length)
} else if op.Column.Length == nil && op.Column.DataType == entity2.TypeVarchar {
colDef += fmt.Sprintf("(%d)", 255)
}
if op.Column.NotNull {
colDef += " NOT NULL"
}
if op.Column.DefaultValue != nil {
if op.Column.DataType == entity2.TypeTimestamp {
colDef += fmt.Sprintf(" DEFAULT %s", *op.Column.DefaultValue)
} else {
colDef += fmt.Sprintf(" DEFAULT '%s'", *op.Column.DefaultValue)
}
}
operations = append(operations, colDef)
case entity2.DropColumn:
if op.Column == nil {
return nil, fmt.Errorf("column is required for DROP COLUMN operation")
}
operations = append(operations, fmt.Sprintf("DROP COLUMN `%s`", op.Column.Name))
case entity2.ModifyColumn:
if op.Column == nil {
return nil, fmt.Errorf("column is required for MODIFY COLUMN operation")
}
colDef := fmt.Sprintf("MODIFY COLUMN `%s` %s", op.Column.Name, op.Column.DataType)
if op.Column.Length != nil {
colDef += fmt.Sprintf("(%d)", *op.Column.Length)
} else if op.Column.Length == nil && op.Column.DataType == entity2.TypeVarchar {
colDef += fmt.Sprintf("(%d)", 255)
}
operations = append(operations, colDef)
case entity2.RenameColumn:
if op.Column == nil || op.OldName == nil {
return nil, fmt.Errorf("column and old name are required for RENAME COLUMN operation")
}
operations = append(operations, fmt.Sprintf("RENAME COLUMN `%s` TO `%s`", *op.OldName, op.Column.Name))
case entity2.AddIndex:
if op.Index == nil {
return nil, fmt.Errorf("index is required for ADD INDEX operation")
}
var idxDef string
switch op.Index.Type {
case entity2.PrimaryKey:
idxDef = fmt.Sprintf("ADD PRIMARY KEY (`%s`)", strings.Join(op.Index.Columns, "`,`"))
case entity2.UniqueKey:
idxDef = fmt.Sprintf("ADD UNIQUE INDEX `%s` (`%s`)", op.Index.Name, strings.Join(op.Index.Columns, "`,`"))
default:
idxDef = fmt.Sprintf("ADD INDEX `%s` (`%s`)", op.Index.Name, strings.Join(op.Index.Columns, "`,`"))
}
operations = append(operations, idxDef)
}
}
alterSQL += " " + strings.Join(operations, ", ")
logs.CtxInfof(ctx, "[AlterTable] execute sql is %s, req is %v", alterSQL, req)
err := m.db.WithContext(ctx).Exec(alterSQL).Error
if err != nil {
return nil, fmt.Errorf("failed to alter table: %v", err)
}
table, err := m.getTableInfo(ctx, req.TableName)
if err != nil {
return nil, fmt.Errorf("failed to get table info: %v", err)
}
return &rdb.AlterTableResponse{Table: table}, nil
}
// DropTable drop table
func (m *mysqlService) DropTable(ctx context.Context, req *rdb.DropTableRequest) (*rdb.DropTableResponse, error) {
if req == nil {
return nil, fmt.Errorf("invalid request")
}
dropSQL := "DROP TABLE"
if req.IfExists {
dropSQL += " IF EXISTS"
}
dropSQL += fmt.Sprintf(" `%s`", req.TableName)
logs.CtxInfof(ctx, "[DropTable] execute sql is %s, req is %v", dropSQL, req)
err := m.db.WithContext(ctx).Exec(dropSQL).Error
if err != nil {
return nil, fmt.Errorf("failed to drop table: %v", err)
}
return &rdb.DropTableResponse{Success: true}, nil
}
// GetTable get table schema info
func (m *mysqlService) GetTable(ctx context.Context, req *rdb.GetTableRequest) (*rdb.GetTableResponse, error) {
if req == nil {
return nil, fmt.Errorf("invalid request")
}
table, err := m.getTableInfo(ctx, req.TableName)
if err != nil {
return nil, err
}
return &rdb.GetTableResponse{Table: table}, nil
}
func (m *mysqlService) InsertData(ctx context.Context, req *rdb.InsertDataRequest) (*rdb.InsertDataResponse, error) {
if req == nil || len(req.Data) == 0 {
return nil, fmt.Errorf("invalid request")
}
fields := make([]string, 0)
for field := range req.Data[0] {
fields = append(fields, field)
}
const batchSize = 1000
var totalAffected int64
for i := 0; i < len(req.Data); i += batchSize {
end := i + batchSize
if end > len(req.Data) {
end = len(req.Data)
}
currentBatch := req.Data[i:end]
placeholderGroups := make([]string, 0, len(currentBatch))
values := make([]interface{}, 0, len(currentBatch)*len(fields))
for _, row := range currentBatch {
placeholders := make([]string, len(fields))
for j := range placeholders {
placeholders[j] = "?"
}
placeholderGroups = append(placeholderGroups, "("+strings.Join(placeholders, ",")+")")
for _, field := range fields {
values = append(values, row[field])
}
}
insertSQL := fmt.Sprintf("INSERT INTO `%s` (`%s`) VALUES %s",
req.TableName,
strings.Join(fields, "`,`"),
strings.Join(placeholderGroups, ","),
)
logs.CtxInfof(ctx, "[InsertData] execute sql is %s, value is %v in batch %d", insertSQL, values, i)
result := m.db.WithContext(ctx).Exec(insertSQL, values...)
if result.Error != nil {
return nil, result.Error
}
affected := result.RowsAffected
totalAffected += affected
}
return &rdb.InsertDataResponse{AffectedRows: totalAffected}, nil
}
// UpdateData Update data
func (m *mysqlService) UpdateData(ctx context.Context, req *rdb.UpdateDataRequest) (*rdb.UpdateDataResponse, error) {
if req == nil {
return nil, fmt.Errorf("invalid request")
}
setClauses := make([]string, 0)
values := make([]interface{}, 0)
for field, value := range req.Data {
setClauses = append(setClauses, fmt.Sprintf("`%s` = ?", field))
values = append(values, value)
}
whereClause, whereValues, err := m.buildWhereClause(req.Where)
if err != nil {
return nil, fmt.Errorf("failed to build where clause: %v", err)
}
values = append(values, whereValues...)
limitClause := ""
if req.Limit != nil {
limitClause = fmt.Sprintf(" LIMIT %d", *req.Limit)
}
updateSQL := fmt.Sprintf("UPDATE `%s` SET %s%s%s",
req.TableName,
strings.Join(setClauses, ", "),
whereClause,
limitClause,
)
logs.CtxInfof(ctx, "[UpdateData] execute sql is %s, value is %v, req is %v", updateSQL, values, req)
result := m.db.WithContext(ctx).Exec(updateSQL, values...)
if result.Error != nil {
return nil, result.Error
}
affectedRows := result.RowsAffected
return &rdb.UpdateDataResponse{AffectedRows: affectedRows}, nil
}
// DeleteData delete data
func (m *mysqlService) DeleteData(ctx context.Context, req *rdb.DeleteDataRequest) (*rdb.DeleteDataResponse, error) {
if req == nil {
return nil, fmt.Errorf("invalid request")
}
whereClause, whereValues, err := m.buildWhereClause(req.Where)
if err != nil {
return nil, fmt.Errorf("failed to build where clause: %v", err)
}
limitClause := ""
if req.Limit != nil {
limitClause = fmt.Sprintf(" LIMIT %d", *req.Limit)
}
deleteSQL := fmt.Sprintf("DELETE FROM `%s`%s%s",
req.TableName,
whereClause,
limitClause,
)
logs.CtxInfof(ctx, "[DeleteData] execute sql is %s, value is %v, req is %v", deleteSQL, whereValues, req)
result := m.db.WithContext(ctx).Exec(deleteSQL, whereValues...)
if result.Error != nil {
return nil, fmt.Errorf("failed to delete data: %v", result.Error)
}
affectedRows := result.RowsAffected
return &rdb.DeleteDataResponse{AffectedRows: affectedRows}, nil
}
// SelectData select data
func (m *mysqlService) SelectData(ctx context.Context, req *rdb.SelectDataRequest) (*rdb.SelectDataResponse, error) {
if req == nil {
return nil, fmt.Errorf("invalid request")
}
fields := "*"
if len(req.Fields) > 0 {
fields = strings.Join(req.Fields, ", ")
}
whereClause := ""
whereValues := make([]interface{}, 0)
if req.Where != nil {
clause, values, err := m.buildWhereClause(req.Where)
if err != nil {
return nil, fmt.Errorf("failed to build where clause: %v", err)
}
whereClause = clause
whereValues = values
}
orderByClause := ""
if len(req.OrderBy) > 0 {
orders := make([]string, len(req.OrderBy))
for i, order := range req.OrderBy {
orders[i] = fmt.Sprintf("%s %s", order.Field, order.Direction)
}
orderByClause = " ORDER BY " + strings.Join(orders, ", ")
}
limitClause := ""
if req.Limit != nil {
limitClause = fmt.Sprintf(" LIMIT %d", *req.Limit)
if req.Offset != nil {
limitClause += fmt.Sprintf(" OFFSET %d", *req.Offset)
}
}
selectSQL := fmt.Sprintf("SELECT %s FROM `%s`%s%s%s",
fields,
req.TableName,
whereClause,
orderByClause,
limitClause,
)
logs.CtxInfof(ctx, "[SelectData] execute sql is %s, value is %v, req is %v", selectSQL, whereValues, req)
rows, err := m.db.WithContext(ctx).Raw(selectSQL, whereValues...).Rows()
if err != nil {
return nil, fmt.Errorf("failed to execute select: %v", err)
}
defer rows.Close()
columns, err := rows.Columns()
if err != nil {
return nil, fmt.Errorf("failed to get columns: %v", err)
}
resultSet := &entity2.ResultSet{
Columns: columns,
Rows: make([]map[string]interface{}, 0),
}
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range values {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
return nil, fmt.Errorf("failed to scan row: %v", err)
}
rowData := make(map[string]interface{})
for i, col := range columns {
rowData[col] = values[i]
}
resultSet.Rows = append(resultSet.Rows, rowData)
}
// get total count
var total int64
if whereClause != "" {
countSQL := fmt.Sprintf("SELECT COUNT(*) FROM `%s`%s", req.TableName, whereClause)
err = m.db.WithContext(ctx).Raw(countSQL, whereValues...).Scan(&total).Error
if err != nil {
return nil, fmt.Errorf("failed to get total count: %v", err)
}
} else {
total = int64(len(resultSet.Rows))
}
return &rdb.SelectDataResponse{
ResultSet: resultSet,
Total: total,
}, nil
}
// UpsertData upsert data
func (m *mysqlService) UpsertData(ctx context.Context, req *rdb.UpsertDataRequest) (*rdb.UpsertDataResponse, error) {
if req == nil || len(req.Data) == 0 {
return nil, fmt.Errorf("invalid request: empty data")
}
keys := req.Keys
if len(keys) == 0 {
primaryKeys, err := m.getTablePrimaryKeys(ctx, req.TableName)
if err != nil {
return nil, fmt.Errorf("failed to get primary keys: %v", err)
}
if len(primaryKeys) == 0 {
return nil, fmt.Errorf("table %s has no primary key, keys are required for upsert operation", req.TableName)
}
keys = primaryKeys
}
fields := make([]string, 0)
for field := range req.Data[0] {
fields = append(fields, field)
}
const batchSize = 1000
var totalAffected, totalInserted, totalUpdated int64
for i := 0; i < len(req.Data); i += batchSize {
end := i + batchSize
if end > len(req.Data) {
end = len(req.Data)
}
currentBatch := req.Data[i:end]
placeholderGroups := make([]string, 0, len(currentBatch))
values := make([]interface{}, 0, len(currentBatch)*len(fields))
for _, row := range currentBatch {
placeholders := make([]string, len(fields))
for j := range placeholders {
placeholders[j] = "?"
}
placeholderGroups = append(placeholderGroups, "("+strings.Join(placeholders, ",")+")")
for _, field := range fields {
values = append(values, row[field])
}
}
// ON DUPLICATE KEY UPDATE部分
updateClauses := make([]string, 0, len(fields))
for _, field := range fields {
isKey := false
for _, key := range keys {
if field == key {
isKey = true
break
}
}
if !isKey {
updateClauses = append(updateClauses, fmt.Sprintf("`%s`=VALUES(`%s`)", field, field))
}
}
upsertSQL := fmt.Sprintf(
"INSERT INTO `%s` (`%s`) VALUES %s ON DUPLICATE KEY UPDATE %s",
req.TableName,
strings.Join(fields, "`,`"),
strings.Join(placeholderGroups, ","),
strings.Join(updateClauses, ","),
)
logs.CtxInfof(ctx, "[UpsertData] execute sql is %s, value is %v, batch is %d", upsertSQL, values, i)
result := m.db.WithContext(ctx).Exec(upsertSQL, values...)
if result.Error != nil {
return nil, fmt.Errorf("failed to upsert data: %v", result.Error)
}
total, inserted, updated := calculateInsertedUpdated(result.RowsAffected, len(currentBatch))
totalInserted += inserted
totalUpdated += updated
totalAffected += total
}
return &rdb.UpsertDataResponse{
AffectedRows: totalAffected,
InsertedRows: totalInserted,
UpdatedRows: totalUpdated,
UnchangedRows: int64(len(req.Data)) - totalAffected,
}, nil
}
func (m *mysqlService) getTablePrimaryKeys(ctx context.Context, tableName string) ([]string, error) {
query := `
SELECT COLUMN_NAME
FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
WHERE TABLE_SCHEMA = DATABASE()
AND TABLE_NAME = ?
AND CONSTRAINT_NAME = 'PRIMARY'
ORDER BY ORDINAL_POSITION
`
var primaryKeys []string
rows, err := m.db.WithContext(ctx).Raw(query, tableName).Rows()
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var columnName string
if err := rows.Scan(&columnName); err != nil {
return nil, err
}
primaryKeys = append(primaryKeys, columnName)
}
return primaryKeys, nil
}
// calculateInsertedUpdated 函数保持不变
func calculateInsertedUpdated(affectedRows int64, batchSize int) (int64, int64, int64) {
updated := int64(0)
inserted := affectedRows
if affectedRows > int64(batchSize) {
updated = affectedRows - int64(batchSize)
inserted = int64(batchSize) - updated
}
return inserted + updated, inserted, updated
}
// ExecuteSQL Execute SQL
func (m *mysqlService) ExecuteSQL(ctx context.Context, req *rdb.ExecuteSQLRequest) (*rdb.ExecuteSQLResponse, error) {
if req == nil {
return nil, fmt.Errorf("invalid request")
}
logs.CtxInfof(ctx, "[ExecuteSQL] req is %v", req)
var processedSQL string
var processedParams []interface{}
var err error
// Handle SQLType: if raw, do not process params
if req.SQLType == entity2.SQLType_Raw {
processedSQL = req.SQL
processedParams = nil
} else {
processedSQL, processedParams, err = m.processSliceParams(req.SQL, req.Params)
if err != nil {
return nil, fmt.Errorf("failed to process parameters: %v", err)
}
}
operation, err := sqlparser.NewSQLParser().GetSQLOperation(processedSQL)
if err != nil {
return nil, err
}
if operation != sqlparsercontract.OperationTypeSelect {
result := m.db.WithContext(ctx).Exec(processedSQL, processedParams...)
if result.Error != nil {
return nil, fmt.Errorf("failed to execute SQL: %v", result.Error)
}
resultSet := &entity2.ResultSet{
Columns: []string{},
Rows: []map[string]interface{}{},
AffectedRows: result.RowsAffected,
}
return &rdb.ExecuteSQLResponse{
ResultSet: resultSet,
}, nil
}
rows, err := m.db.WithContext(ctx).Raw(processedSQL, processedParams...).Rows()
if err != nil {
return nil, fmt.Errorf("failed to execute SQL: %v", err)
}
defer rows.Close()
columns, err := rows.Columns()
if err != nil {
return nil, fmt.Errorf("failed to get columns: %v", err)
}
resultSet := &entity2.ResultSet{
Columns: columns,
Rows: make([]map[string]interface{}, 0),
}
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range values {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
return nil, fmt.Errorf("failed to scan row: %v", err)
}
rowData := make(map[string]interface{})
for i, col := range columns {
rowData[col] = values[i]
}
resultSet.Rows = append(resultSet.Rows, rowData)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error while reading rows: %v", err)
}
return &rdb.ExecuteSQLResponse{
ResultSet: resultSet,
}, nil
}
func (m *mysqlService) processSliceParams(sql string, params []interface{}) (string, []interface{}, error) {
if len(params) == 0 {
return sql, params, nil
}
processedParams := make([]interface{}, 0)
paramIndex := 0
resultSQL := ""
lastPos := 0
// get all ? positions
for i := 0; i < len(sql); i++ {
if sql[i] == '?' && paramIndex < len(params) {
resultSQL += sql[lastPos:i]
lastPos = i + 1
param := params[paramIndex]
paramIndex++
if m.isSlice(param) {
sliceValues, err := m.getSliceValues(param)
if err != nil {
return "", nil, err
}
if len(sliceValues) == 0 {
resultSQL += "(NULL)"
} else {
// (?, ?, ...)
placeholders := make([]string, len(sliceValues))
for j := range placeholders {
placeholders[j] = "?"
}
resultSQL += "(" + strings.Join(placeholders, ", ") + ")"
processedParams = append(processedParams, sliceValues...)
}
} else {
resultSQL += "?"
processedParams = append(processedParams, param)
}
}
}
resultSQL += sql[lastPos:]
return resultSQL, processedParams, nil
}
func (m *mysqlService) isSlice(param interface{}) bool {
if param == nil {
return false
}
rv := reflect.ValueOf(param)
return rv.Kind() == reflect.Slice && rv.Type().Elem().Kind() != reflect.Uint8 // exclude []byte
}
func (m *mysqlService) getSliceValues(param interface{}) ([]interface{}, error) {
rv := reflect.ValueOf(param)
if rv.Kind() != reflect.Slice {
return nil, fmt.Errorf("parameter is not a slice")
}
length := rv.Len()
values := make([]interface{}, length)
for i := 0; i < length; i++ {
values[i] = rv.Index(i).Interface()
}
return values, nil
}
func (m *mysqlService) genTableName(ctx context.Context) (string, error) {
id, err := m.generator.GenID(ctx)
if err != nil {
return "", err
}
return fmt.Sprintf("table_%d", id), nil
}
func (m *mysqlService) getTableInfo(ctx context.Context, tableName string) (*entity2.Table, error) {
tableInfoSQL := `
SELECT
TABLE_NAME,
TABLE_COLLATION,
AUTO_INCREMENT,
TABLE_COMMENT
FROM information_schema.TABLES
WHERE TABLE_SCHEMA = DATABASE()
AND TABLE_NAME = ?
`
var (
name string
collation *string
autoIncrement *int64
comment *string
)
err := m.db.WithContext(ctx).Raw(tableInfoSQL, tableName).Row().Scan(
&name,
&collation,
&autoIncrement,
&comment,
)
if err != nil {
return nil, err
}
columnsSQL := `
SELECT
COLUMN_NAME,
DATA_TYPE,
CHARACTER_MAXIMUM_LENGTH,
IS_NULLABLE,
COLUMN_DEFAULT,
EXTRA,
COLUMN_COMMENT
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA = DATABASE()
AND TABLE_NAME = ?
ORDER BY ORDINAL_POSITION
`
type columnInfo struct {
ColumnName string `gorm:"column:COLUMN_NAME"`
DataType string `gorm:"column:DATA_TYPE"`
CharLength *int `gorm:"column:CHARACTER_MAXIMUM_LENGTH"`
IsNullable string `gorm:"column:IS_NULLABLE"`
DefaultValue *string `gorm:"column:COLUMN_DEFAULT"`
Extra string `gorm:"column:EXTRA"`
ColumnComment *string `gorm:"column:COLUMN_COMMENT"`
}
var columnsData []columnInfo
err = m.db.WithContext(ctx).Raw(columnsSQL, tableName).Scan(&columnsData).Error
if err != nil {
return nil, err
}
columns := make([]*entity2.Column, len(columnsData))
for i, colData := range columnsData {
column := &entity2.Column{
Name: colData.ColumnName,
DataType: entity2.DataType(colData.DataType),
Length: colData.CharLength,
NotNull: colData.IsNullable == "NO",
DefaultValue: colData.DefaultValue,
AutoIncrement: strings.Contains(colData.Extra, "auto_increment"),
Comment: colData.ColumnComment,
}
columns[i] = column
}
indexesSQL := `
SELECT
INDEX_NAME,
NON_UNIQUE,
INDEX_TYPE,
GROUP_CONCAT(COLUMN_NAME ORDER BY SEQ_IN_INDEX)
FROM information_schema.STATISTICS
WHERE TABLE_SCHEMA = DATABASE()
AND TABLE_NAME = ?
GROUP BY INDEX_NAME, NON_UNIQUE, INDEX_TYPE
`
type indexInfo struct {
IndexName string `gorm:"column:INDEX_NAME"`
NonUnique int `gorm:"column:NON_UNIQUE"`
IndexType string `gorm:"column:INDEX_TYPE"`
Columns string `gorm:"column:GROUP_CONCAT"`
}
var indexesData []indexInfo
err = m.db.WithContext(ctx).Raw(indexesSQL, tableName).Scan(&indexesData).Error
if err != nil {
return nil, err
}
indexes := make([]*entity2.Index, 0, len(indexesData))
for _, idxData := range indexesData {
index := &entity2.Index{
Name: idxData.IndexName,
Type: entity2.IndexType(idxData.IndexType),
Columns: strings.Split(idxData.Columns, ","),
}
indexes = append(indexes, index)
}
return &entity2.Table{
Name: name,
Columns: columns,
Indexes: indexes,
Options: &entity2.TableOption{
Collate: collation,
AutoIncrement: autoIncrement,
Comment: comment,
},
}, nil
}
func (m *mysqlService) buildWhereClause(condition *rdb.ComplexCondition) (string, []interface{}, error) {
if condition == nil {
return "", nil, nil
}
if condition.Operator == "" {
condition.Operator = entity2.AND
}
var whereClause strings.Builder
values := make([]interface{}, 0)
for i, cond := range condition.Conditions {
if i > 0 {
whereClause.WriteString(fmt.Sprintf(" %s ", condition.Operator))
}
if cond.Operator == entity2.OperatorIn || cond.Operator == entity2.OperatorNotIn {
if m.isSlice(cond.Value) {
sliceValues, err := m.getSliceValues(cond.Value)
if err != nil {
return "", nil, fmt.Errorf("failed to process slice values: %v", err)
}
if len(sliceValues) == 0 {
whereClause.WriteString(fmt.Sprintf("`%s` %s (NULL)", cond.Field, string(cond.Operator)))
} else {
placeholders := make([]string, len(sliceValues))
for i := range placeholders {
placeholders[i] = "?"
}
whereClause.WriteString(fmt.Sprintf("`%s` %s (%s)", cond.Field, string(cond.Operator), strings.Join(placeholders, ",")))
values = append(values, sliceValues...)
}
} else {
return "", nil, fmt.Errorf("IN operator requires a slice of values")
}
} else if cond.Operator == entity2.OperatorIsNull || cond.Operator == entity2.OperatorIsNotNull {
whereClause.WriteString(fmt.Sprintf("`%s` %s", cond.Field, cond.Operator))
} else {
whereClause.WriteString(fmt.Sprintf("`%s` %s ?", cond.Field, cond.Operator))
values = append(values, cond.Value)
}
}
if len(condition.NestedConditions) > 0 {
whereClause.WriteString(" AND (")
for i, nested := range condition.NestedConditions {
if i > 0 {
whereClause.WriteString(fmt.Sprintf(" %s ", nested.Operator))
}
nestedClause, nestedValues, err := m.buildWhereClause(nested)
if err != nil {
return "", nil, err
}
whereClause.WriteString(nestedClause)
values = append(values, nestedValues...)
}
whereClause.WriteString(")")
}
if whereClause.Len() > 0 {
return " WHERE " + whereClause.String(), values, nil
}
return "", values, nil
}

View File

@@ -0,0 +1,891 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package rdb
import (
"context"
"database/sql"
"fmt"
"os"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
entity2 "github.com/coze-dev/coze-studio/backend/infra/contract/rdb/entity"
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
func setupTestDB(t *testing.T) (*gorm.DB, rdb.RDB) {
dsn := "root:root@tcp(127.0.0.1:3306)/opencoze?charset=utf8mb4&parseTime=True&loc=Local"
if os.Getenv("CI_JOB_NAME") != "" {
dsn = strings.ReplaceAll(dsn, "127.0.0.1", "mysql")
}
db, err := gorm.Open(mysql.Open(dsn))
assert.NoError(t, err)
ctrl := gomock.NewController(t)
idGen := mock.NewMockIDGenerator(ctrl)
idGen.EXPECT().GenID(gomock.Any()).Return(int64(123), nil).AnyTimes()
return db, NewService(db, idGen)
}
func cleanupTestDB(t *testing.T, db *gorm.DB, tableNames ...string) {
for _, tableName := range tableNames {
db.WithContext(context.Background()).Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", tableName))
}
}
func TestCreateTable(t *testing.T) {
db, svc := setupTestDB(t)
defer cleanupTestDB(t, db, "test_table")
length := 255
req := &rdb.CreateTableRequest{
Table: &entity2.Table{
Name: "test_table",
Columns: []*entity2.Column{
{
Name: "id",
DataType: entity2.TypeInt,
NotNull: true,
},
{
Name: "name",
DataType: entity2.TypeVarchar,
Length: &length,
NotNull: true,
},
{
Name: "created_at",
DataType: entity2.TypeTimestamp,
NotNull: true,
DefaultValue: func() *string {
val := "CURRENT_TIMESTAMP"
return &val
}(),
},
{
Name: "score",
DataType: entity2.TypeDouble,
NotNull: true,
DefaultValue: ptr.Of("60.5"),
},
},
Indexes: []*entity2.Index{
{
Name: "PRIMARY",
Type: entity2.PrimaryKey,
Columns: []string{"id"},
},
{
Name: "idx_name",
Type: entity2.NormalKey,
Columns: []string{"name"},
},
},
Options: &entity2.TableOption{
Comment: func() *string {
comment := "Test table created by unit test"
return &comment
}(),
},
},
}
resp, err := svc.CreateTable(context.Background(), req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, req.Table.Name, resp.Table.Name)
var tableExists bool
err = db.Raw("SELECT 1 FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = ?", "test_table").Scan(&tableExists).Error
assert.NoError(t, err)
assert.True(t, tableExists)
}
func TestAlterTable(t *testing.T) {
db, svc := setupTestDB(t)
defer cleanupTestDB(t, db, "test_table")
t.Run("success", func(t *testing.T) {
err := db.Exec(`
CREATE TABLE IF NOT EXISTS test_table (
id INT NOT NULL AUTO_INCREMENT,
name VARCHAR(255) NOT NULL,
description VARCHAR(255) NOT NULL,
droped VARCHAR(255) NOT NULL,
PRIMARY KEY (id),
INDEX idx_name (name)
) COMMENT='Test table created by unit test'
`).Error
assert.NoError(t, err, "Failed to create test table")
length := 100
req := &rdb.AlterTableRequest{
TableName: "test_table",
Operations: []*rdb.AlterTableOperation{
{
Action: entity2.AddColumn,
Column: &entity2.Column{
Name: "email",
DataType: entity2.TypeVarchar,
Length: &length,
NotNull: false,
},
},
{
Action: entity2.ModifyColumn,
Column: &entity2.Column{
Name: "description",
DataType: entity2.TypeText,
NotNull: false,
},
},
{
Action: entity2.DropColumn,
Column: &entity2.Column{
Name: "droped",
DataType: entity2.TypeVarchar,
NotNull: false,
},
},
},
}
resp, err := svc.AlterTable(context.Background(), req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, "test_table", resp.Table.Name)
var columnExists bool
err = db.Raw("SELECT 1 FROM information_schema.columns WHERE table_schema = DATABASE() AND table_name = ? AND column_name = ?", "test_table", "email").Scan(&columnExists).Error
assert.NoError(t, err)
assert.True(t, columnExists)
})
}
func TestGetTable(t *testing.T) {
db, svc := setupTestDB(t)
defer cleanupTestDB(t, db, "test_info_table")
t.Run("success", func(t *testing.T) {
err := db.Exec(`
CREATE TABLE IF NOT EXISTS test_info_table (
id INT NOT NULL AUTO_INCREMENT,
name VARCHAR(255) NOT NULL,
description TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id),
INDEX idx_name (name)
) COMMENT='Table info test'
`).Error
assert.NoError(t, err, "Failed to create test table")
req := &rdb.GetTableRequest{
TableName: "test_info_table",
}
resp, err := svc.GetTable(context.Background(), req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, "test_info_table", resp.Table.Name)
assert.Equal(t, len(resp.Table.Columns), 4)
columnMap := make(map[string]*entity2.Column)
for _, col := range resp.Table.Columns {
columnMap[col.Name] = col
}
assert.Contains(t, columnMap, "id")
assert.Contains(t, columnMap, "name")
assert.Contains(t, columnMap, "created_at")
})
t.Run("not found", func(t *testing.T) {
err := db.Exec(`
CREATE TABLE IF NOT EXISTS test_info_table (
id INT NOT NULL AUTO_INCREMENT,
name VARCHAR(255) NOT NULL,
description TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id),
INDEX idx_name (name)
) COMMENT='Table info test'
`).Error
assert.NoError(t, err, "Failed to create test table")
req := &rdb.GetTableRequest{
TableName: "test_info_table_error_name",
}
resp, err := svc.GetTable(context.Background(), req)
assert.Error(t, err)
assert.Equal(t, err, sql.ErrNoRows)
assert.Nil(t, resp)
})
}
func TestInsertData(t *testing.T) {
db, svc := setupTestDB(t)
defer cleanupTestDB(t, db, "test_insert_table")
err := db.Exec(`
CREATE TABLE IF NOT EXISTS test_insert_table (
id INT NOT NULL AUTO_INCREMENT,
name VARCHAR(255) NOT NULL,
age INT,
PRIMARY KEY (id)
)
`).Error
assert.NoError(t, err, "Failed to create test table")
t.Run("success", func(t *testing.T) {
req := &rdb.InsertDataRequest{
TableName: "test_insert_table",
Data: []map[string]interface{}{
{
"name": "John Doe",
"age": 30,
},
{
"name": "Jane Smith",
"age": nil,
},
},
}
resp, err := svc.InsertData(context.Background(), req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, int64(2), resp.AffectedRows)
var count int
err = db.Raw("SELECT COUNT(*) FROM test_insert_table").Scan(&count).Error
assert.NoError(t, err)
assert.Equal(t, 2, count)
})
t.Run("table name error", func(t *testing.T) {
req := &rdb.InsertDataRequest{
TableName: "test_insert_table_error_name",
Data: []map[string]interface{}{
{
"name": "John Doe",
"age": 30,
},
{
"name": "Jane Smith",
"age": nil,
},
},
}
resp, err := svc.InsertData(context.Background(), req)
assert.Error(t, err)
assert.Nil(t, resp)
})
}
func TestUpdateData(t *testing.T) {
db, svc := setupTestDB(t)
defer cleanupTestDB(t, db, "test_update_table")
err := db.Exec(`
CREATE TABLE IF NOT EXISTS test_update_table (
id INT NOT NULL AUTO_INCREMENT,
name VARCHAR(255) NOT NULL,
age INT,
status VARCHAR(20) DEFAULT 'active',
PRIMARY KEY (id)
)
`).Error
assert.NoError(t, err, "Failed to create test table")
err = db.Exec("INSERT INTO test_update_table (name, age) VALUES (?, ?), (?, ?)",
"John Doe", 30, "Jane Smith", 25).Error
assert.NoError(t, err, "Failed to insert test data")
t.Run("success", func(t *testing.T) {
req := &rdb.UpdateDataRequest{
TableName: "test_update_table",
Data: map[string]interface{}{
"age": 35,
"status": "updated",
},
Where: &rdb.ComplexCondition{
Conditions: []*rdb.Condition{
{
Field: "name",
Operator: entity2.OperatorEqual,
Value: "John Doe",
},
},
Operator: entity2.AND,
},
}
resp, err := svc.UpdateData(context.Background(), req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, int64(1), resp.AffectedRows)
type Result struct {
Age int
Status string
}
var result Result
err = db.Raw("SELECT age, status FROM test_update_table WHERE name = ?", "John Doe").Scan(&result).Error
assert.NoError(t, err)
assert.Equal(t, 35, result.Age)
assert.Equal(t, "updated", result.Status)
})
}
func TestDeleteData(t *testing.T) {
db, svc := setupTestDB(t)
defer cleanupTestDB(t, db, "test_delete_table")
err := db.Exec(`
CREATE TABLE IF NOT EXISTS test_delete_table (
id INT NOT NULL AUTO_INCREMENT,
name VARCHAR(255) NOT NULL,
age INT,
PRIMARY KEY (id)
)
`).Error
assert.NoError(t, err, "Failed to create test table")
err = db.Exec("INSERT INTO test_delete_table (name, age) VALUES (?, ?), (?, ?), (?, ?)",
"John Doe", 30, "Jane Smith", 25, "Bob Johnson", 40).Error
assert.NoError(t, err, "Failed to insert test data")
t.Run("success", func(t *testing.T) {
req := &rdb.DeleteDataRequest{
TableName: "test_delete_table",
Where: &rdb.ComplexCondition{
Conditions: []*rdb.Condition{
{
Field: "age",
Operator: entity2.OperatorGreaterEqual,
Value: 30,
},
},
Operator: entity2.AND,
},
}
resp, err := svc.DeleteData(context.Background(), req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, int64(2), resp.AffectedRows)
var count int
err = db.Raw("SELECT COUNT(*) FROM test_delete_table").Scan(&count).Error
assert.NoError(t, err)
assert.Equal(t, 1, count)
})
}
func TestSelectData(t *testing.T) {
db, svc := setupTestDB(t)
defer cleanupTestDB(t, db, "test_select_table")
err := db.Exec(`
CREATE TABLE IF NOT EXISTS test_select_table (
id INT NOT NULL AUTO_INCREMENT,
name VARCHAR(255) NOT NULL,
age BIGINT,
status VARCHAR(20) DEFAULT 'active',
score FLOAT,
score2 DOUBLE DEFAULT '90.5',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id)
)
`).Error
assert.NoError(t, err, "Failed to create test table")
err = db.Exec(`
INSERT INTO test_select_table (name, age, status, score, score2) VALUES
(?, ?, ?, ?, ?), (?, ?, ?, ?, ?), (?, ?, ?, ?, ?), (?, ?, ?, ?, ?)`,
"John Doe", 30, "active", 1.1, 89.55554444,
"Jane Smith", 25, "active", 1.2, 90.55554444,
"Bob Johnson", 40, "inactive", 1.3, 91.55554444,
"Alice Brown", 35, "active", nil, 92.55554444).Error
assert.NoError(t, err, "Failed to insert test data")
t.Run("success", func(t *testing.T) {
req := &rdb.SelectDataRequest{
TableName: "test_select_table",
Fields: []string{"id", "name", "age", "created_at", "score", "score2"},
Where: &rdb.ComplexCondition{
Conditions: []*rdb.Condition{
{
Field: "status",
Operator: entity2.OperatorEqual,
Value: "active",
},
{
Field: "age",
Operator: entity2.OperatorGreaterEqual,
Value: 25,
},
},
Operator: entity2.AND,
},
OrderBy: []*rdb.OrderBy{
{
Field: "age",
Direction: entity2.SortDirectionDesc,
},
},
Limit: func() *int { limit := 2; return &limit }(),
Offset: func() *int { offset := 0; return &offset }(),
}
resp, err := svc.SelectData(context.Background(), req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, int64(3), resp.Total)
assert.Len(t, resp.ResultSet.Rows, 2)
if len(resp.ResultSet.Rows) > 0 {
firstRow := resp.ResultSet.Rows[0]
assert.Equal(t, "Alice Brown", string(firstRow["name"].([]uint8)))
assert.Equal(t, int64(35), firstRow["age"])
assert.Equal(t, int64(35), firstRow["age"].(int64))
assert.Equal(t, nil, firstRow["score"])
timeR := firstRow["created_at"].(time.Time)
assert.False(t, timeR.IsZero())
assert.Nil(t, firstRow["score"])
assert.Equal(t, 92.55554444, firstRow["score2"].(float64))
}
})
t.Run("success", func(t *testing.T) {
req := &rdb.SelectDataRequest{
TableName: "test_select_table",
Fields: []string{"id", "name", "age", "created_at", "score"},
Where: &rdb.ComplexCondition{
Conditions: []*rdb.Condition{
{
Field: "age",
Operator: entity2.OperatorIn,
Value: []int{30, 25, 18},
},
},
Operator: entity2.AND,
},
OrderBy: []*rdb.OrderBy{
{
Field: "age",
Direction: entity2.SortDirectionDesc,
},
},
}
resp, err := svc.SelectData(context.Background(), req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, int64(2), resp.Total)
assert.Len(t, resp.ResultSet.Rows, 2)
if len(resp.ResultSet.Rows) > 0 {
firstRow := resp.ResultSet.Rows[0]
assert.Equal(t, "John Doe", string(firstRow["name"].([]uint8)))
assert.Equal(t, int64(30), firstRow["age"])
assert.Equal(t, float32(1.1), firstRow["score"].(float32))
}
})
}
func TestExecuteSQL(t *testing.T) {
t.Run("success", func(t *testing.T) {
db, svc := setupTestDB(t)
defer cleanupTestDB(t, db, "test_sql_table")
err := db.Exec(`
CREATE TABLE IF NOT EXISTS test_sql_table (
id INT NOT NULL AUTO_INCREMENT,
name VARCHAR(255) NOT NULL,
age INT,
PRIMARY KEY (id)
)
`).Error
assert.NoError(t, err, "Failed to create test table")
err = db.Exec("INSERT INTO test_sql_table (name, age) VALUES (?, ?), (?, ?)",
"John Doe", 30, "Jane Smith", 25).Error
assert.NoError(t, err, "Failed to insert test data")
req := &rdb.ExecuteSQLRequest{
SQL: "SELECT id, name, age FROM test_sql_table WHERE age in ? and name in ? ORDER BY age DESC",
Params: []interface{}{[]int{30, 25}, []string{"John Doe", "Jane Smith"}},
}
resp, err := svc.ExecuteSQL(context.Background(), req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Len(t, resp.ResultSet.Rows, 2)
assert.Equal(t, []string{"id", "name", "age"}, resp.ResultSet.Columns)
if len(resp.ResultSet.Rows) > 0 {
firstRow := resp.ResultSet.Rows[0]
assert.Equal(t, "John Doe", string(firstRow["name"].([]uint8)))
}
rawReq := &rdb.ExecuteSQLRequest{
SQL: "SELECT id, name, age FROM test_sql_table WHERE age in (30, 25) and name in (\"John Doe\", \"Jane Smith\") ORDER BY age DESC",
}
rawResp, err := svc.ExecuteSQL(context.Background(), rawReq)
assert.NoError(t, err)
assert.NotNil(t, rawResp)
assert.Len(t, rawResp.ResultSet.Rows, 2)
assert.Equal(t, []string{"id", "name", "age"}, rawResp.ResultSet.Columns)
if len(rawResp.ResultSet.Rows) > 0 {
firstRow := rawResp.ResultSet.Rows[0]
assert.Equal(t, "John Doe", string(firstRow["name"].([]uint8)))
}
})
t.Run("success", func(t *testing.T) {
db, svc := setupTestDB(t)
defer cleanupTestDB(t, db, "test_sql_table")
err := db.Exec(`
CREATE TABLE IF NOT EXISTS test_sql_table (
id INT NOT NULL AUTO_INCREMENT,
name VARCHAR(255) NOT NULL,
age INT,
PRIMARY KEY (id)
)
`).Error
assert.NoError(t, err, "Failed to create test table")
err = db.Exec("INSERT INTO test_sql_table (name, age) VALUES (?, ?), (?, ?)",
"John Doe", 30, "Jane Smith", 25).Error
assert.NoError(t, err, "Failed to insert test data")
req := &rdb.ExecuteSQLRequest{
SQL: "SELECT id, name, age FROM test_sql_table WHERE age in (?, ?) and name in (?, ?) ORDER BY age DESC",
Params: []interface{}{30, 25, "John Doe", "Jane Smith"},
}
resp, err := svc.ExecuteSQL(context.Background(), req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, len(resp.ResultSet.Rows), 2)
assert.Equal(t, []string{"id", "name", "age"}, resp.ResultSet.Columns)
if len(resp.ResultSet.Rows) > 0 {
firstRow := resp.ResultSet.Rows[0]
assert.Equal(t, "John Doe", string(firstRow["name"].([]uint8)))
}
})
}
func TestUpsertData(t *testing.T) {
t.Run("insert new records", func(t *testing.T) {
db, svc := setupTestDB(t)
defer cleanupTestDB(t, db, "test_upsert_table")
err := db.Exec(`
CREATE TABLE IF NOT EXISTS test_upsert_table (
id INT NOT NULL,
name VARCHAR(255) NOT NULL,
age INT,
status VARCHAR(20) DEFAULT 'active',
PRIMARY KEY (id),
UNIQUE KEY idx_name (name)
)
`).Error
assert.NoError(t, err, "Failed to create test table")
req := &rdb.UpsertDataRequest{
TableName: "test_upsert_table",
Data: []map[string]interface{}{
{
"id": 1,
"name": "John Doe",
"age": 30,
"status": "active",
},
{
"id": 2,
"name": "Jane Smith",
"age": 25,
"status": "active",
},
},
Keys: []string{"name"},
}
resp, err := svc.UpsertData(context.Background(), req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, int64(2), resp.InsertedRows)
assert.Equal(t, int64(0), resp.UpdatedRows)
var count int
err = db.Raw("SELECT COUNT(*) FROM test_upsert_table WHERE name IN (?, ?)",
"John Doe", "Jane Smith").Scan(&count).Error
assert.NoError(t, err)
assert.Equal(t, 2, count)
})
t.Run("error", func(t *testing.T) {
db, svc := setupTestDB(t)
defer cleanupTestDB(t, db, "test_upsert_table")
err := db.Exec(`
CREATE TABLE IF NOT EXISTS test_upsert_table (
id INT NOT NULL,
name VARCHAR(255) NOT NULL,
age INT,
status VARCHAR(20) DEFAULT 'active',
PRIMARY KEY (id),
UNIQUE KEY idx_name (name)
)
`).Error
assert.NoError(t, err, "Failed to create test table")
reqInsert := &rdb.UpsertDataRequest{
TableName: "test_upsert_table",
Data: []map[string]interface{}{
{
"id": 1,
"name": "John Doe",
"age": 30,
"status": "active",
},
{
"id": 2,
"name": "Jane Smith",
"age": 25,
"status": "active",
},
},
Keys: []string{"name"},
}
_, err = svc.UpsertData(context.Background(), reqInsert)
assert.NoError(t, err)
req := &rdb.UpsertDataRequest{
TableName: "test_upsert_table",
Data: []map[string]interface{}{
{
"name": "New Person",
"age": 40,
"status": "active",
},
},
Keys: []string{"name"},
}
resp, err := svc.UpsertData(context.Background(), req)
assert.Nil(t, resp)
assert.Error(t, err)
})
t.Run("update existing records", func(t *testing.T) {
db, svc := setupTestDB(t)
defer cleanupTestDB(t, db, "test_upsert_table")
err := db.Exec(`
CREATE TABLE IF NOT EXISTS test_upsert_table (
id INT NOT NULL,
name VARCHAR(255) NOT NULL,
age INT,
status VARCHAR(20) DEFAULT 'active',
PRIMARY KEY (id),
UNIQUE KEY idx_name (name)
)
`).Error
assert.NoError(t, err, "Failed to create test table")
reqInsert := &rdb.UpsertDataRequest{
TableName: "test_upsert_table",
Data: []map[string]interface{}{
{
"id": 1,
"name": "John Doe",
"age": 30,
"status": "active",
},
{
"id": 2,
"name": "Jane Smith",
"age": 25,
"status": "active",
},
},
Keys: []string{"name"},
}
_, err = svc.UpsertData(context.Background(), reqInsert)
assert.NoError(t, err)
req := &rdb.UpsertDataRequest{
TableName: "test_upsert_table",
Data: []map[string]interface{}{
{
"id": 1,
"name": "John Doe",
"age": 35,
"status": "updated",
},
{
"id": 2,
"name": "Jane Smith",
"age": 25,
"status": "updated",
},
{
"id": 3,
"name": "New Person",
"age": 40,
"status": "active",
},
{
"id": 4,
"name": "New Person 2",
"age": 40,
"status": "active",
},
{
"id": 5,
"name": "New Person 3",
"age": 40,
"status": "active",
},
},
Keys: []string{"name"},
}
resp, err := svc.UpsertData(context.Background(), req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, int64(3), resp.InsertedRows)
assert.Equal(t, int64(2), resp.UpdatedRows)
})
t.Run("use primary key when keys not specified", func(t *testing.T) {
db, svc := setupTestDB(t)
defer cleanupTestDB(t, db, "test_upsert_table")
err := db.Exec(`
CREATE TABLE IF NOT EXISTS test_upsert_table (
id INT NOT NULL,
name VARCHAR(255) NOT NULL,
age INT,
status VARCHAR(20) DEFAULT 'active',
PRIMARY KEY (age),
UNIQUE KEY idx_name (name)
)
`).Error
assert.NoError(t, err, "Failed to create test table")
reqInsert := &rdb.UpsertDataRequest{
TableName: "test_upsert_table",
Data: []map[string]interface{}{
{
"id": 1,
"name": "John Doe",
"age": 30,
"status": "active",
},
{
"id": 2,
"name": "Jane Smith",
"age": 25,
"status": "active",
},
},
}
_, err = svc.UpsertData(context.Background(), reqInsert)
assert.NoError(t, err)
req := &rdb.UpsertDataRequest{
TableName: "test_upsert_table",
Data: []map[string]interface{}{
{
"id": 1,
"name": "John Doe Updated",
"age": 30,
"status": "primary key updated",
},
{
"id": 3,
"name": "New Person",
"age": 40,
"status": "active",
},
{
"id": 4,
"name": "New Person 2",
"age": 45,
"status": "active",
},
},
}
resp, err := svc.UpsertData(context.Background(), req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, int64(2), resp.InsertedRows)
assert.Equal(t, int64(1), resp.UpdatedRows)
req = &rdb.UpsertDataRequest{
TableName: "test_upsert_table",
Data: []map[string]interface{}{
{
"id": 1,
"name": "John Doe Updated",
"age": 30,
"status": "primary key updated",
},
{
"id": 3,
"name": "New Person",
"age": 40,
"status": "active",
},
{
"id": 4,
"name": "New Person 2",
"age": 45,
"status": "active update",
},
},
}
resp, err = svc.UpsertData(context.Background(), req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, int64(2), resp.AffectedRows)
assert.Equal(t, int64(1), resp.UnchangedRows)
})
}