345 lines
		
	
	
		
			9.9 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			345 lines
		
	
	
		
			9.9 KiB
		
	
	
	
		
			Go
		
	
	
	
/*
 | 
						|
 * Copyright 2025 coze-dev Authors
 | 
						|
 *
 | 
						|
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
 * you may not use this file except in compliance with the License.
 | 
						|
 * You may obtain a copy of the License at
 | 
						|
 *
 | 
						|
 *     http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
 *
 | 
						|
 * Unless required by applicable law or agreed to in writing, software
 | 
						|
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
 * See the License for the specific language governing permissions and
 | 
						|
 * limitations under the License.
 | 
						|
 */
 | 
						|
 | 
						|
package physicaltable
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"fmt"
 | 
						|
 | 
						|
	"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
 | 
						|
	model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/api/model/table"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/memory/database/internal/convertor"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
 | 
						|
	entity3 "github.com/coze-dev/coze-studio/backend/infra/contract/rdb/entity"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
 | 
						|
)
 | 
						|
 | 
						|
func CreatePhysicalTable(ctx context.Context, db rdb.RDB, columns []*entity3.Column) (*rdb.CreateTableResponse, error) {
 | 
						|
	table := &entity3.Table{
 | 
						|
		Columns: columns,
 | 
						|
	}
 | 
						|
	// get indexes
 | 
						|
	indexes := make([]*entity3.Index, 0)
 | 
						|
	indexes = append(indexes, &entity3.Index{
 | 
						|
		Name:    "PRIMARY",
 | 
						|
		Type:    entity3.PrimaryKey,
 | 
						|
		Columns: []string{database.DefaultIDColName},
 | 
						|
	})
 | 
						|
	indexes = append(indexes, &entity3.Index{
 | 
						|
		Name:    "idx_uid",
 | 
						|
		Type:    entity3.NormalKey,
 | 
						|
		Columns: []string{database.DefaultUidColName, database.DefaultCidColName},
 | 
						|
	})
 | 
						|
	table.Indexes = indexes
 | 
						|
 | 
						|
	physicalTableRes, err := db.CreateTable(ctx, &rdb.CreateTableRequest{Table: table})
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return physicalTableRes, nil
 | 
						|
}
 | 
						|
 | 
						|
func CreateFieldInfo(fieldItems []*model.FieldItem) ([]*model.FieldItem, []*entity3.Column) {
 | 
						|
	columns := make([]*entity3.Column, len(fieldItems))
 | 
						|
 | 
						|
	fieldID := int64(1)
 | 
						|
	for i, field := range fieldItems {
 | 
						|
		field.AlterID = fieldID
 | 
						|
		field.PhysicalName = GetFieldPhysicsName(fieldID)
 | 
						|
 | 
						|
		columns[i] = &entity3.Column{
 | 
						|
			Name:     GetFieldPhysicsName(fieldID),
 | 
						|
			DataType: convertor.SwitchToDataType(field.Type),
 | 
						|
			NotNull:  field.MustRequired,
 | 
						|
			Comment:  &field.Desc,
 | 
						|
		}
 | 
						|
 | 
						|
		if field.Type == table.FieldItemType_Text && !field.MustRequired {
 | 
						|
			columns[i].DefaultValue = ptr.Of("")
 | 
						|
		}
 | 
						|
 | 
						|
		fieldID++ // field is incremented begin from 1
 | 
						|
	}
 | 
						|
 | 
						|
	columns = append(columns, GetDefaultColumns()...)
 | 
						|
 | 
						|
	return fieldItems, columns
 | 
						|
}
 | 
						|
 | 
						|
func GetDefaultColumns() []*entity3.Column {
 | 
						|
	return getDefaultColumns()
 | 
						|
}
 | 
						|
 | 
						|
func getDefaultColumns() []*entity3.Column {
 | 
						|
	return []*entity3.Column{
 | 
						|
		{
 | 
						|
			Name:          database.DefaultIDColName,
 | 
						|
			DataType:      entity3.TypeBigInt,
 | 
						|
			NotNull:       true,
 | 
						|
			AutoIncrement: true,
 | 
						|
		},
 | 
						|
		{
 | 
						|
			Name:     database.DefaultUidColName,
 | 
						|
			DataType: entity3.TypeVarchar,
 | 
						|
			NotNull:  true,
 | 
						|
		},
 | 
						|
		{
 | 
						|
			Name:     database.DefaultCidColName,
 | 
						|
			DataType: entity3.TypeVarchar,
 | 
						|
			NotNull:  true,
 | 
						|
		},
 | 
						|
		{
 | 
						|
			Name:         database.DefaultCreateTimeColName,
 | 
						|
			DataType:     entity3.TypeTimestamp,
 | 
						|
			NotNull:      true,
 | 
						|
			DefaultValue: ptr.Of("CURRENT_TIMESTAMP"),
 | 
						|
		},
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func GetTablePhysicsName(tableID int64) string {
 | 
						|
	return fmt.Sprintf("table_%d", tableID)
 | 
						|
}
 | 
						|
 | 
						|
func GetFieldPhysicsName(fieldID int64) string {
 | 
						|
	return fmt.Sprintf("f_%d", fieldID)
 | 
						|
}
 | 
						|
 | 
						|
// UpdateFieldInfo handles field information updates.
 | 
						|
// 1. If alterID exists, use alterID to update existing fields.
 | 
						|
// 2. If alterID does not exist, add new fields.
 | 
						|
// 3. Delete fields that have alterIDs not present in the new list.
 | 
						|
func UpdateFieldInfo(newFieldItems []*database.FieldItem, existingFieldItems []*database.FieldItem) ([]*database.FieldItem, []*entity3.Column, []string, error) {
 | 
						|
	existingFieldMap := make(map[int64]*database.FieldItem)
 | 
						|
	maxAlterID := int64(-1)
 | 
						|
	for _, field := range existingFieldItems {
 | 
						|
		if field.AlterID > 0 {
 | 
						|
			existingFieldMap[field.AlterID] = field
 | 
						|
			maxAlterID = max(maxAlterID, field.AlterID)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	newFieldIDs := make(map[int64]bool)
 | 
						|
 | 
						|
	updatedColumns := make([]*entity3.Column, 0, len(newFieldItems))
 | 
						|
	updatedFieldItems := make([]*database.FieldItem, 0, len(newFieldItems))
 | 
						|
 | 
						|
	for _, field := range newFieldItems {
 | 
						|
		if field.AlterID > 0 {
 | 
						|
			// update field
 | 
						|
			newFieldIDs[field.AlterID] = true
 | 
						|
			field.PhysicalName = GetFieldPhysicsName(field.AlterID)
 | 
						|
			updatedFieldItems = append(updatedFieldItems, field)
 | 
						|
 | 
						|
			updatedColumns = append(updatedColumns, &entity3.Column{
 | 
						|
				Name:     GetFieldPhysicsName(field.AlterID),
 | 
						|
				DataType: convertor.SwitchToDataType(field.Type),
 | 
						|
				NotNull:  field.MustRequired,
 | 
						|
				Comment:  &field.Desc,
 | 
						|
			})
 | 
						|
		} else {
 | 
						|
			fieldID := maxAlterID + 1 // auto increment begin from existing maxAlterID
 | 
						|
			maxAlterID++
 | 
						|
			field.AlterID = fieldID
 | 
						|
			field.PhysicalName = GetFieldPhysicsName(fieldID)
 | 
						|
			updatedFieldItems = append(updatedFieldItems, field)
 | 
						|
 | 
						|
			updatedColumns = append(updatedColumns, &entity3.Column{
 | 
						|
				Name:     GetFieldPhysicsName(fieldID),
 | 
						|
				DataType: convertor.SwitchToDataType(field.Type),
 | 
						|
				NotNull:  field.MustRequired,
 | 
						|
				Comment:  &field.Desc,
 | 
						|
			})
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	droppedColumns := make([]string, 0, len(existingFieldMap))
 | 
						|
	// get dropped columns
 | 
						|
	for alterID := range existingFieldMap {
 | 
						|
		if !newFieldIDs[alterID] {
 | 
						|
			droppedColumns = append(droppedColumns, GetFieldPhysicsName(alterID))
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return updatedFieldItems, updatedColumns, droppedColumns, nil
 | 
						|
}
 | 
						|
 | 
						|
// UpdatePhysicalTableWithDrops 更新物理表结构,包括显式指定要删除的列
 | 
						|
func UpdatePhysicalTableWithDrops(ctx context.Context, db rdb.RDB, existingTable *entity3.Table, newColumns []*entity3.Column, droppedColumns []string, tableName string) error {
 | 
						|
	// 创建列名到列的映射
 | 
						|
	existingColumnMap := make(map[string]*entity3.Column)
 | 
						|
	for _, col := range existingTable.Columns {
 | 
						|
		existingColumnMap[col.Name] = col
 | 
						|
	}
 | 
						|
 | 
						|
	// 收集要添加和修改的列
 | 
						|
	var columnsToAdd, columnsToModify []*entity3.Column
 | 
						|
 | 
						|
	// 查找要添加和修改的列
 | 
						|
	for _, newCol := range newColumns {
 | 
						|
		if _, exists := existingColumnMap[newCol.Name]; exists {
 | 
						|
			columnsToModify = append(columnsToModify, newCol)
 | 
						|
		} else {
 | 
						|
			columnsToAdd = append(columnsToAdd, newCol)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	// 应用变更到物理表
 | 
						|
	if len(columnsToAdd) > 0 || len(columnsToModify) > 0 || len(droppedColumns) > 0 {
 | 
						|
		// build AlterTableRequest
 | 
						|
		alterReq := &rdb.AlterTableRequest{
 | 
						|
			TableName:  tableName,
 | 
						|
			Operations: getOperation(columnsToAdd, columnsToModify, droppedColumns),
 | 
						|
		}
 | 
						|
 | 
						|
		// 执行表结构变更
 | 
						|
		_, err := db.AlterTable(ctx, alterReq)
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// getOperation 将列的添加、修改和删除操作转换为 AlterTableOperation 数组
 | 
						|
func getOperation(columnsToAdd, columnsToModify []*entity3.Column, droppedColumns []string) []*rdb.AlterTableOperation {
 | 
						|
	operations := make([]*rdb.AlterTableOperation, 0)
 | 
						|
 | 
						|
	// 处理添加列操作
 | 
						|
	for _, column := range columnsToAdd {
 | 
						|
		operations = append(operations, &rdb.AlterTableOperation{
 | 
						|
			Action: entity3.AddColumn,
 | 
						|
			Column: column,
 | 
						|
		})
 | 
						|
	}
 | 
						|
 | 
						|
	// 处理修改列操作
 | 
						|
	for _, column := range columnsToModify {
 | 
						|
		operations = append(operations, &rdb.AlterTableOperation{
 | 
						|
			Action: entity3.ModifyColumn,
 | 
						|
			Column: column,
 | 
						|
		})
 | 
						|
	}
 | 
						|
 | 
						|
	// 处理删除列操作
 | 
						|
	for _, columnName := range droppedColumns {
 | 
						|
		operations = append(operations, &rdb.AlterTableOperation{
 | 
						|
			Action: entity3.DropColumn,
 | 
						|
			Column: &entity3.Column{Name: columnName},
 | 
						|
		})
 | 
						|
	}
 | 
						|
 | 
						|
	return operations
 | 
						|
}
 | 
						|
 | 
						|
func GetTemplateTypeMap() map[table.FieldItemType]string {
 | 
						|
	return map[table.FieldItemType]string{
 | 
						|
		table.FieldItemType_Boolean: "false",
 | 
						|
		table.FieldItemType_Number:  "0",
 | 
						|
		table.FieldItemType_Date:    "0001-01-01 00:00:00",
 | 
						|
		table.FieldItemType_Text:    "",
 | 
						|
		table.FieldItemType_Float:   "0",
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func GetCreateTimeField() *database.FieldItem {
 | 
						|
	return &database.FieldItem{
 | 
						|
		Name:          database.DefaultCreateTimeColName,
 | 
						|
		Desc:          "create time",
 | 
						|
		Type:          table.FieldItemType_Date,
 | 
						|
		MustRequired:  false,
 | 
						|
		IsSystemField: true,
 | 
						|
		AlterID:       103,
 | 
						|
		PhysicalName:  database.DefaultCreateTimeColName,
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func GetUidField() *database.FieldItem {
 | 
						|
	return &database.FieldItem{
 | 
						|
		Name:          database.DefaultUidColName,
 | 
						|
		Desc:          "user id",
 | 
						|
		Type:          table.FieldItemType_Text,
 | 
						|
		MustRequired:  false,
 | 
						|
		IsSystemField: true,
 | 
						|
		AlterID:       101,
 | 
						|
		PhysicalName:  database.DefaultUidColName,
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func GetConnectIDField() *database.FieldItem {
 | 
						|
	return &database.FieldItem{
 | 
						|
		Name:          database.DefaultCidColName,
 | 
						|
		Desc:          "connector id",
 | 
						|
		Type:          table.FieldItemType_Text,
 | 
						|
		MustRequired:  false,
 | 
						|
		IsSystemField: true,
 | 
						|
		AlterID:       104,
 | 
						|
		PhysicalName:  database.DefaultCidColName,
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func GetIDField() *database.FieldItem {
 | 
						|
	return &database.FieldItem{
 | 
						|
		Name:          database.DefaultIDColName,
 | 
						|
		Desc:          "primary_key",
 | 
						|
		Type:          table.FieldItemType_Number,
 | 
						|
		MustRequired:  false,
 | 
						|
		IsSystemField: true,
 | 
						|
		AlterID:       102,
 | 
						|
		PhysicalName:  database.DefaultIDColName,
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func GetDisplayCreateTimeField() *database.FieldItem {
 | 
						|
	return &database.FieldItem{
 | 
						|
		Name:          database.DefaultCreateTimeDisplayColName,
 | 
						|
		Desc:          "create time",
 | 
						|
		Type:          table.FieldItemType_Date,
 | 
						|
		MustRequired:  false,
 | 
						|
		IsSystemField: true,
 | 
						|
		AlterID:       103,
 | 
						|
		PhysicalName:  database.DefaultCreateTimeDisplayColName,
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func GetDisplayUidField() *database.FieldItem {
 | 
						|
	return &database.FieldItem{
 | 
						|
		Name:          database.DefaultUidDisplayColName,
 | 
						|
		Desc:          "user id",
 | 
						|
		Type:          table.FieldItemType_Text,
 | 
						|
		MustRequired:  false,
 | 
						|
		IsSystemField: true,
 | 
						|
		AlterID:       101,
 | 
						|
		PhysicalName:  database.DefaultUidDisplayColName,
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func GetDisplayIDField() *database.FieldItem {
 | 
						|
	return &database.FieldItem{
 | 
						|
		Name:          database.DefaultIDDisplayColName,
 | 
						|
		Desc:          "primary_key",
 | 
						|
		Type:          table.FieldItemType_Number,
 | 
						|
		MustRequired:  false,
 | 
						|
		IsSystemField: true,
 | 
						|
		AlterID:       102,
 | 
						|
		PhysicalName:  database.DefaultIDDisplayColName,
 | 
						|
	}
 | 
						|
}
 |