/* * Copyright 2025 coze-dev Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package service import ( "bytes" "context" "crypto/sha256" "encoding/base64" "errors" "fmt" "math/rand" "runtime/debug" "strconv" "time" "github.com/redis/go-redis/v9" "github.com/tealeg/xlsx/v3" "gorm.io/gorm" "github.com/coze-dev/coze-studio/backend/infra/contract/cache" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database" "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_common" "github.com/coze-dev/coze-studio/backend/api/model/table" "github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossvariables" entity2 "github.com/coze-dev/coze-studio/backend/domain/memory/database/entity" "github.com/coze-dev/coze-studio/backend/domain/memory/database/internal/convertor" "github.com/coze-dev/coze-studio/backend/domain/memory/database/internal/dal/query" "github.com/coze-dev/coze-studio/backend/domain/memory/database/internal/physicaltable" "github.com/coze-dev/coze-studio/backend/domain/memory/database/internal/sheet" "github.com/coze-dev/coze-studio/backend/domain/memory/database/repository" "github.com/coze-dev/coze-studio/backend/infra/contract/idgen" "github.com/coze-dev/coze-studio/backend/infra/contract/rdb" entity3 "github.com/coze-dev/coze-studio/backend/infra/contract/rdb/entity" sqlparsercontract "github.com/coze-dev/coze-studio/backend/infra/contract/sqlparser" "github.com/coze-dev/coze-studio/backend/infra/contract/storage" 
"github.com/coze-dev/coze-studio/backend/infra/impl/sqlparser" "github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" "github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/types/consts" "github.com/coze-dev/coze-studio/backend/types/errno" ) type databaseService struct { rdb rdb.RDB db *gorm.DB generator idgen.IDGenerator draftDAO repository.DraftDAO onlineDAO repository.OnlineDAO agentToDatabaseDAO repository.AgentToDatabaseDAO storage storage.Storage cache cache.Cmdable } func NewService(rdb rdb.RDB, db *gorm.DB, generator idgen.IDGenerator, storage storage.Storage, cacheCli cache.Cmdable) Database { return &databaseService{ rdb: rdb, db: db, generator: generator, draftDAO: repository.NewDraftDatabaseDAO(db, generator), onlineDAO: repository.NewOnlineDatabaseDAO(db, generator), agentToDatabaseDAO: repository.NewAgentToDatabaseDAO(db, generator), storage: storage, cache: cacheCli, } } func (d databaseService) CreateDatabase(ctx context.Context, req *CreateDatabaseRequest) (*CreateDatabaseResponse, error) { draftEntity, onlineEntity := req.Database, req.Database fieldItems, columns := physicaltable.CreateFieldInfo(req.Database.FieldList) // create physical draft table draftEntity.FieldList = fieldItems draftPhysicalTableRes, err := physicaltable.CreatePhysicalTable(ctx, d.rdb, columns) if err != nil { return nil, err } if draftPhysicalTableRes.Table == nil { return nil, fmt.Errorf("create draft table failed, columns info is %v", columns) } draftID, err := d.generator.GenID(ctx) if err != nil { return nil, errorx.WrapByCode(err, errno.ErrMemoryIDGenFailCode, errorx.KV("msg", "CreateDatabase")) } // create physical online table onlineEntity.FieldList = fieldItems onlinePhysicalTableRes, err := physicaltable.CreatePhysicalTable(ctx, d.rdb, columns) if err != nil { return nil, err } if onlinePhysicalTableRes.Table == nil { 
return nil, fmt.Errorf("create online table failed, columns info is %v", columns) } onlineID, err := d.generator.GenID(ctx) if err != nil { return nil, errorx.WrapByCode(err, errno.ErrMemoryIDGenFailCode, errorx.KV("msg", "CreateDatabase")) } // insert draft and online database info tx := query.Use(d.db).Begin() if tx.Error != nil { return nil, fmt.Errorf("start transaction failed, %v", tx.Error) } if draftEntity.IconURI == "" { draftEntity.IconURI = consts.DefaultDatabaseIcon } if onlineEntity.IconURI == "" { onlineEntity.IconURI = consts.DefaultDatabaseIcon } defer func() { if r := recover(); r != nil { e := tx.Rollback() if e != nil { logs.CtxErrorf(ctx, "rollback failed, err=%v", e) } err = fmt.Errorf("catch panic: %v\nstack=%s", r, string(debug.Stack())) return } if err != nil { e := tx.Rollback() if e != nil { logs.CtxErrorf(ctx, "rollback failed, err=%v", e) } } }() _, err = d.draftDAO.CreateWithTX(ctx, tx, draftEntity, draftID, onlineID, draftPhysicalTableRes.Table.Name) if err != nil { return nil, err } onlineEntity, err = d.onlineDAO.CreateWithTX(ctx, tx, onlineEntity, draftID, onlineID, onlinePhysicalTableRes.Table.Name) if err != nil { return nil, err } err = tx.Commit() if err != nil { return nil, err } onlineEntity.ActualTableName = onlinePhysicalTableRes.Table.Name onlineEntity.ID = onlineID onlineEntity.DraftID = ptr.Of(draftID) objURL, uRrr := d.storage.GetObjectUrl(ctx, onlineEntity.IconURI) if uRrr == nil { onlineEntity.IconURL = objURL } return &CreateDatabaseResponse{ Database: onlineEntity, }, nil } func (d databaseService) UpdateDatabase(ctx context.Context, req *UpdateDatabaseRequest) (*UpdateDatabaseResponse, error) { // req.Database.ID is the id of online database input := req.Database onlineInfo, err := d.onlineDAO.Get(ctx, req.Database.ID) if err != nil { return nil, fmt.Errorf("get online database info failed: %v", err) } draftInfo, err := d.draftDAO.Get(ctx, onlineInfo.GetDraftID()) if err != nil { return nil, fmt.Errorf("get draft 
database info failed: %v", err) } draftEntity, onlineEntity := *input, *input draftEntity.ID = draftInfo.ID onlineEntity.ID = onlineInfo.ID fieldItems, columns, droppedColumns, err := physicaltable.UpdateFieldInfo(input.FieldList, onlineInfo.FieldList) if err != nil { return nil, err } draftEntity.FieldList = fieldItems onlineEntity.FieldList = fieldItems // get draft and online physical table info draftPhysicalTable, err := d.rdb.GetTable(ctx, &rdb.GetTableRequest{ TableName: draftInfo.ActualTableName, }) if err != nil { return nil, fmt.Errorf("get physical table info failed: %v", err) } onlinePhysicalTable, err := d.rdb.GetTable(ctx, &rdb.GetTableRequest{ TableName: onlineInfo.ActualTableName, }) if err != nil { return nil, fmt.Errorf("get physical table info failed: %v", err) } err = physicaltable.UpdatePhysicalTableWithDrops(ctx, d.rdb, draftPhysicalTable.Table, columns, droppedColumns, draftInfo.ActualTableName) if err != nil { return nil, fmt.Errorf("update draft physical table failed: %v", err) } err = physicaltable.UpdatePhysicalTableWithDrops(ctx, d.rdb, onlinePhysicalTable.Table, columns, droppedColumns, onlineInfo.ActualTableName) if err != nil { return nil, fmt.Errorf("update online physical table failed: %v", err) } tx := query.Use(d.db).Begin() if tx.Error != nil { return nil, fmt.Errorf("start transaction failed, %v", tx.Error) } if draftEntity.IconURI == "" { draftEntity.IconURI = consts.DefaultDatabaseIcon } if onlineEntity.IconURI == "" { onlineEntity.IconURI = consts.DefaultDatabaseIcon } defer func() { if r := recover(); r != nil { e := tx.Rollback() if e != nil { logs.CtxErrorf(ctx, "rollback failed, err=%v", e) } err = fmt.Errorf("catch panic: %v\nstack=%s", r, string(debug.Stack())) return } if err != nil { e := tx.Rollback() if e != nil { logs.CtxErrorf(ctx, "rollback failed, err=%v", e) } } }() _, err = d.draftDAO.UpdateWithTX(ctx, tx, &draftEntity) if err != nil { return nil, fmt.Errorf("update draft database info failed: %v", err) } 
onlineEntityUpdated, err := d.onlineDAO.UpdateWithTX(ctx, tx, &onlineEntity) if err != nil { return nil, fmt.Errorf("update online database info failed: %v", err) } err = tx.Commit() if err != nil { return nil, fmt.Errorf("commit transaction failed: %v", err) } return &UpdateDatabaseResponse{ Database: onlineEntityUpdated, }, nil } func (d databaseService) DeleteDatabase(ctx context.Context, req *DeleteDatabaseRequest) error { onlineInfo, err := d.onlineDAO.Get(ctx, req.ID) if err != nil { return fmt.Errorf("get online database info failed: %v", err) } draftInfo, err := d.draftDAO.Get(ctx, onlineInfo.GetDraftID()) if err != nil { return fmt.Errorf("get draft database info failed: %v", err) } tx := query.Use(d.db).Begin() if tx.Error != nil { return fmt.Errorf("start transaction failed, %v", tx.Error) } defer func() { if r := recover(); r != nil { e := tx.Rollback() if e != nil { logs.CtxErrorf(ctx, "rollback failed, err=%v", e) } err = fmt.Errorf("catch panic: %v\nstack=%s", r, string(debug.Stack())) return } if err != nil { e := tx.Rollback() if e != nil { logs.CtxErrorf(ctx, "rollback failed, err=%v", e) } } }() err = d.draftDAO.DeleteWithTX(ctx, tx, draftInfo.ID) if err != nil { return fmt.Errorf("delete draft database info failed: %v", err) } err = d.onlineDAO.DeleteWithTX(ctx, tx, onlineInfo.ID) if err != nil { return fmt.Errorf("delete online database info failed: %v", err) } err = tx.Commit() if err != nil { return fmt.Errorf("commit transaction failed: %v", err) } // delete draft physical table if draftInfo.ActualTableName != "" { _, err = d.rdb.DropTable(ctx, &rdb.DropTableRequest{ TableName: draftInfo.ActualTableName, }) if err != nil { logs.Errorf("drop draft physical table failed: %v, table_name=%s", err, draftInfo.ActualTableName) } } // delete online physical table if onlineInfo.ActualTableName != "" { _, err = d.rdb.DropTable(ctx, &rdb.DropTableRequest{ TableName: onlineInfo.ActualTableName, }) if err != nil { logs.Errorf("drop online physical table 
failed: %v, table_name=%s", err, onlineInfo.ActualTableName) } } return nil } func (d databaseService) MGetDatabase(ctx context.Context, req *MGetDatabaseRequest) (*MGetDatabaseResponse, error) { if len(req.Basics) == 0 { return &MGetDatabaseResponse{ Databases: []*entity2.Database{}, }, nil } onlineID2NeedSysFields := make(map[int64]bool) draftID2NeedSysFields := make(map[int64]bool) uniqueOnlineIDs := make([]int64, 0) uniqueDraftIDs := make([]int64, 0) idMap := make(map[int64]bool) for _, basic := range req.Basics { if !idMap[basic.ID] { idMap[basic.ID] = true if basic.TableType == table.TableType_OnlineTable { uniqueOnlineIDs = append(uniqueOnlineIDs, basic.ID) onlineID2NeedSysFields[basic.ID] = basic.NeedSysFields } else { uniqueDraftIDs = append(uniqueDraftIDs, basic.ID) draftID2NeedSysFields[basic.ID] = basic.NeedSysFields } } } onlineDatabases, err := d.onlineDAO.MGet(ctx, uniqueOnlineIDs) if err != nil { return nil, fmt.Errorf("batch get database info failed: %v", err) } draftDatabases, err := d.draftDAO.MGet(ctx, uniqueDraftIDs) if err != nil { return nil, fmt.Errorf("batch get database info failed: %v", err) } for _, onlineDatabase := range onlineDatabases { if needSys, ok := onlineID2NeedSysFields[onlineDatabase.ID]; ok && needSys { if onlineDatabase.FieldList == nil { onlineDatabase.FieldList = make([]*database.FieldItem, 0, 3) } onlineDatabase.FieldList = append(onlineDatabase.FieldList, physicaltable.GetDisplayCreateTimeField(), physicaltable.GetDisplayUidField(), physicaltable.GetDisplayIDField()) } if onlineDatabase.IconURI != "" { objURL, uRrr := d.storage.GetObjectUrl(ctx, onlineDatabase.IconURI) if uRrr == nil { onlineDatabase.IconURL = objURL } } } for _, draftDatabase := range draftDatabases { if needSys, ok := draftID2NeedSysFields[draftDatabase.ID]; ok && needSys { if draftDatabase.FieldList == nil { draftDatabase.FieldList = make([]*database.FieldItem, 0, 3) } draftDatabase.FieldList = append(draftDatabase.FieldList, 
physicaltable.GetDisplayCreateTimeField(), physicaltable.GetDisplayUidField(), physicaltable.GetDisplayIDField()) } if draftDatabase.IconURI != "" { objURL, uRrr := d.storage.GetObjectUrl(ctx, draftDatabase.IconURI) if uRrr == nil { draftDatabase.IconURL = objURL } } } databases := make([]*entity2.Database, 0) databases = append(databases, onlineDatabases...) databases = append(databases, draftDatabases...) return &MGetDatabaseResponse{ Databases: databases, }, nil } func (d databaseService) ListDatabase(ctx context.Context, req *ListDatabaseRequest) (*ListDatabaseResponse, error) { filter := &entity2.DatabaseFilter{ CreatorID: req.CreatorID, SpaceID: req.SpaceID, TableName: req.TableName, AppID: &req.AppID, } page := &entity2.Pagination{ Limit: req.Limit, Offset: req.Offset, } var databases []*entity2.Database var err error var count int64 if req.TableType == table.TableType_OnlineTable { databases, count, err = d.onlineDAO.List(ctx, filter, page, req.OrderBy) if err != nil { return nil, fmt.Errorf("list database failed: %v", err) } } else { databases, count, err = d.draftDAO.List(ctx, filter, page, req.OrderBy) if err != nil { return nil, fmt.Errorf("list database failed: %v", err) } } for _, database := range databases { if database.IconURI != "" { objURL, uRrr := d.storage.GetObjectUrl(ctx, database.IconURI) if uRrr == nil { database.IconURL = objURL } } } var hasMore bool if count <= int64(req.Limit)+int64(req.Offset) { hasMore = false } else { hasMore = true } return &ListDatabaseResponse{ Databases: databases, HasMore: hasMore, TotalCount: count, }, nil } func (d databaseService) AddDatabaseRecord(ctx context.Context, req *AddDatabaseRecordRequest) error { var tableInfo *entity2.Database var err error if req.TableType == table.TableType_OnlineTable { tableInfo, err = d.onlineDAO.Get(ctx, req.DatabaseID) } else { tableInfo, err = d.draftDAO.Get(ctx, req.DatabaseID) } if err != nil { return fmt.Errorf("get table info failed: %v", err) } if tableInfo.RwMode == 
table.BotTableRWMode_ReadOnly { return errorx.New(errno.ErrMemoryDatabaseCannotAddData) } physicalTableName := tableInfo.ActualTableName if physicalTableName == "" { return fmt.Errorf("physical table name is empty") } fieldList := append(tableInfo.FieldList, physicaltable.GetCreateTimeField(), physicaltable.GetUidField(), physicaltable.GetIDField(), physicaltable.GetConnectIDField()) fieldMap := slices.ToMap(fieldList, func(e *database.FieldItem) (string, *database.FieldItem) { return e.Name, e }) convertedRecords := make([]map[string]interface{}, 0, len(req.Records)) ids, err := d.generator.GenMultiIDs(ctx, len(req.Records)) if err != nil { return err } for index, recordMap := range req.Records { convertedRecord := make(map[string]interface{}) cid := consts.CozeConnectorID if req.ConnectorID != nil { cid = *req.ConnectorID } convertedRecord[database.DefaultUidColName] = req.UserID convertedRecord[database.DefaultCidColName] = cid convertedRecord[database.DefaultCreateTimeColName] = time.Now() convertedRecord[database.DefaultIDColName] = ids[index] if _, ok := recordMap[database.DefaultIDColName]; ok { delete(recordMap, database.DefaultIDColName) } for fieldName, value := range recordMap { if _, fOk := fieldMap[fieldName]; !fOk { return errorx.New(errno.ErrMemoryDatabaseFieldNotFoundCode, errorx.KV("msg", fmt.Sprintf("field %s not found in table definition", fieldName))) } fieldInfo, _ := fieldMap[fieldName] if value == "" && fieldInfo.MustRequired { return fmt.Errorf("field %s's value is required", fieldName) } physicalFieldName := fieldInfo.PhysicalName convertedValue, err := convertor.ConvertValueByType(value, fieldInfo.Type) if err != nil { return fmt.Errorf("convert value failed for field %s: %v, using original value", fieldName, err) } convertedRecord[physicalFieldName] = convertedValue } convertedRecords = append(convertedRecords, convertedRecord) } _, err = d.rdb.InsertData(ctx, &rdb.InsertDataRequest{ TableName: physicalTableName, Data: convertedRecords, 
}) if err != nil { return fmt.Errorf("insert data failed: %v", err) } return nil } func (d databaseService) UpdateDatabaseRecord(ctx context.Context, req *UpdateDatabaseRecordRequest) error { var tableInfo *database.Database var err error if req.TableType == table.TableType_OnlineTable { tableInfo, err = d.onlineDAO.Get(ctx, req.DatabaseID) } else { tableInfo, err = d.draftDAO.Get(ctx, req.DatabaseID) } if err != nil { return fmt.Errorf("get table info failed: %v", err) } if tableInfo.RwMode == table.BotTableRWMode_ReadOnly { return errorx.New(errno.ErrMemoryDatabaseCannotAddData) } physicalTableName := tableInfo.ActualTableName if physicalTableName == "" { return fmt.Errorf("physical table name is empty") } fieldList := append(tableInfo.FieldList, physicaltable.GetCreateTimeField(), physicaltable.GetUidField(), physicaltable.GetIDField(), physicaltable.GetConnectIDField()) fieldMap := slices.ToMap(fieldList, func(e *database.FieldItem) (string, *database.FieldItem) { return e.Name, e }) for _, record := range req.Records { idStr, exists := record[database.DefaultIDColName] if !exists { return fmt.Errorf("record must contain %s field for update", database.DefaultIDColName) } id, err := strconv.ParseInt(idStr, 10, 64) if err != nil { return fmt.Errorf("invalid ID format: %v", err) } updateData := make(map[string]interface{}) for fieldName, valueStr := range record { if fieldName == database.DefaultIDColName { continue } if _, fOk := fieldMap[fieldName]; !fOk { return errorx.New(errno.ErrMemoryDatabaseFieldNotFoundCode, errorx.KV("msg", fmt.Sprintf("field %s not found in table definition", fieldName))) } fieldInfo, _ := fieldMap[fieldName] if valueStr == "" && fieldInfo.MustRequired { return fmt.Errorf("field %s's value is required", fieldName) } physicalFieldName := fieldInfo.PhysicalName convertedValue, err := convertor.ConvertValueByType(valueStr, fieldInfo.Type) if err != nil { logs.Warnf("convert value failed for field %s: %v, using original value", fieldName, 
err) convertedValue = valueStr } updateData[physicalFieldName] = convertedValue } if len(updateData) == 0 { continue } condition := &rdb.ComplexCondition{ Conditions: []*rdb.Condition{ { Field: database.DefaultIDColName, Operator: entity3.OperatorEqual, Value: id, }, }, } if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite { cond := &rdb.Condition{ Field: database.DefaultUidColName, Operator: entity3.OperatorEqual, Value: strconv.FormatInt(req.UserID, 10), } condition.Conditions = append(condition.Conditions, cond) } _, err = d.rdb.UpdateData(ctx, &rdb.UpdateDataRequest{ TableName: physicalTableName, Data: updateData, Where: condition, }) if err != nil { return fmt.Errorf("update data failed for ID %d: %v", id, err) } } return nil } func (d databaseService) DeleteDatabaseRecord(ctx context.Context, req *DeleteDatabaseRecordRequest) error { var tableInfo *entity2.Database var err error if req.TableType == table.TableType_OnlineTable { tableInfo, err = d.onlineDAO.Get(ctx, req.DatabaseID) } else { tableInfo, err = d.draftDAO.Get(ctx, req.DatabaseID) } if err != nil { return err } if tableInfo.RwMode == table.BotTableRWMode_ReadOnly { return errorx.New(errno.ErrMemoryDatabaseCannotAddData) } physicalTableName := tableInfo.ActualTableName if physicalTableName == "" { return fmt.Errorf("physical table name is empty") } var ids []interface{} for _, record := range req.Records { idStr, exists := record[database.DefaultIDColName] if !exists { return fmt.Errorf("record must contain %s field for deletion", database.DefaultIDColName) } id, err := strconv.ParseInt(idStr, 10, 64) if err != nil { return fmt.Errorf("invalid ID format: %v", err) } ids = append(ids, id) } condition := &rdb.ComplexCondition{ Conditions: []*rdb.Condition{ { Field: database.DefaultIDColName, Operator: entity3.OperatorIn, Value: ids, }, }, } if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite { cond := &rdb.Condition{ Field: database.DefaultUidColName, Operator: entity3.OperatorEqual, 
Value: strconv.FormatInt(req.UserID, 10), } condition.Conditions = append(condition.Conditions, cond) } _, err = d.rdb.DeleteData(ctx, &rdb.DeleteDataRequest{ TableName: physicalTableName, Where: condition, }) if err != nil { return fmt.Errorf("delete data failed: %v", err) } return nil } func (d databaseService) ListDatabaseRecord(ctx context.Context, req *ListDatabaseRecordRequest) (*ListDatabaseRecordResponse, error) { var tableInfo *entity2.Database var err error if req.TableType == table.TableType_OnlineTable { tableInfo, err = d.onlineDAO.Get(ctx, req.DatabaseID) } else { tableInfo, err = d.draftDAO.Get(ctx, req.DatabaseID) } if err != nil { return nil, fmt.Errorf("get table info failed: %v", err) } physicalTableName := tableInfo.ActualTableName if physicalTableName == "" { return nil, fmt.Errorf("physical table name is empty") } fieldNameToPhysical := make(map[string]string) physicalToFieldName := make(map[string]string) physicalToFieldType := make(map[string]table.FieldItemType) for _, field := range tableInfo.FieldList { if field.AlterID > 0 { physicalName := physicaltable.GetFieldPhysicsName(field.AlterID) fieldNameToPhysical[field.Name] = physicalName physicalToFieldName[physicalName] = field.Name physicalToFieldType[physicalName] = field.Type } } var complexCondition *rdb.ComplexCondition if req.ConnectorID != nil && *req.ConnectorID > 0 { cond := &rdb.Condition{ Field: database.DefaultCidColName, Operator: entity3.OperatorEqual, Value: *req.ConnectorID, } complexCondition = &rdb.ComplexCondition{ Conditions: []*rdb.Condition{cond}, } } if req.TableType == table.TableType_DraftTable { if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite { cond := &rdb.Condition{ Field: database.DefaultUidColName, Operator: entity3.OperatorEqual, Value: strconv.FormatInt(req.UserID, 10), } if complexCondition == nil { complexCondition = &rdb.ComplexCondition{ Conditions: []*rdb.Condition{cond}, } } else { complexCondition.Conditions = 
append(complexCondition.Conditions, cond) } } } limit := 50 if req.Limit > 0 { limit = req.Limit } orderBy := []*rdb.OrderBy{ { Field: database.DefaultCreateTimeColName, Direction: entity3.SortDirectionDesc, }, } selectResp, err := d.rdb.SelectData(ctx, &rdb.SelectDataRequest{ TableName: physicalTableName, Fields: []string{}, // Null means query all fields Where: complexCondition, OrderBy: orderBy, Limit: &limit, Offset: &req.Offset, }) if err != nil { return nil, fmt.Errorf("select data failed: %v", err) } if selectResp.ResultSet == nil { return &ListDatabaseRecordResponse{}, nil } records := convertor.ConvertResultSetToString(selectResp.ResultSet, physicalToFieldName, physicalToFieldType) var hasMore bool if selectResp.Total <= int64(req.Limit)+int64(req.Offset) { hasMore = false } else { hasMore = true } return &ListDatabaseRecordResponse{ Records: records, FieldList: tableInfo.FieldList, HasMore: hasMore, TotalCount: selectResp.Total, }, nil } func (d databaseService) GetDatabaseTemplate(ctx context.Context, req *GetDatabaseTemplateRequest) (*GetDatabaseTemplateResponse, error) { items := req.FieldItems tableName := req.TableName file := xlsx.NewFile() sheet, err := file.AddSheet("Sheet1") if err != nil { return nil, err } // add header header := sheet.AddRow() headerTitles := make([]string, 0) for i := range items { headerTitles = append(headerTitles, items[i].GetName()) } for _, title := range headerTitles { cell := header.AddCell() cell.Value = title } row := sheet.AddRow() for _, item := range items { row.AddCell().Value = physicaltable.GetTemplateTypeMap()[item.GetType()] } var buffer bytes.Buffer err = file.Write(&buffer) if err != nil { return nil, err } binaryData := buffer.Bytes() url, err := d.uploadFile(ctx, req.UserID, string(binaryData), tableName, "xlsx", nil) if err != nil { return nil, err } return &GetDatabaseTemplateResponse{ Url: url, }, nil } func (d databaseService) uploadFile(ctx context.Context, UserId int64, content string, bizType, 
fileType string, suffix *string) (string, error) { secret := createSecret(UserId, fileType) fileName := fmt.Sprintf("%d_%d_%s.%s", UserId, time.Now().UnixNano(), secret, fileType) if suffix != nil { fileName = fmt.Sprintf("%d_%d_%s_%s.%s", UserId, time.Now().UnixNano(), secret, *suffix, fileType) } objectName := fmt.Sprintf("%s/%s", bizType, fileName) err := d.storage.PutObject(ctx, objectName, []byte(content)) if err != nil { return "", err } url, err := d.storage.GetObjectUrl(ctx, objectName) if err != nil { return "", err } return url, nil } const baseWord = "1Aa2Bb3Cc4Dd5Ee6Ff7Gg8Hh9Ii0JjKkLlMmNnOoPpQqRrSsTtUuVvWwXxYyZz" func createSecret(uid int64, fileType string) string { num := 10 input := fmt.Sprintf("upload_%d_Ma*9)fhi_%d_gou_%s_rand_%d", uid, time.Now().Unix(), fileType, rand.Intn(100000)) hash := sha256.Sum256([]byte(fmt.Sprintf("%s", input))) hashString := base64.StdEncoding.EncodeToString(hash[:]) if len(hashString) > num { hashString = hashString[:num] } result := "" for _, char := range hashString { index := int(char) % 62 result += string(baseWord[index]) } return result } func (d databaseService) ExecuteSQL(ctx context.Context, req *ExecuteSQLRequest) (*ExecuteSQLResponse, error) { var tableInfo *entity2.Database var err error if req.TableType == table.TableType_OnlineTable { tableInfo, err = d.onlineDAO.Get(ctx, req.DatabaseID) } else { tableInfo, err = d.draftDAO.Get(ctx, req.DatabaseID) } if err != nil { return nil, fmt.Errorf("get table info failed: %v", err) } if tableInfo.RwMode == table.BotTableRWMode_ReadOnly && (req.OperateType == database.OperateType_Insert || req.OperateType == database.OperateType_Update || req.OperateType == database.OperateType_Delete) { return nil, errorx.New(errno.ErrMemoryDatabaseCannotAddData) } physicalTableName := tableInfo.ActualTableName if physicalTableName == "" { return nil, fmt.Errorf("physical table name is empty") } fieldNameToPhysical := make(map[string]string) physicalToFieldName := 
make(map[string]string) physicalToFieldType := make(map[string]table.FieldItemType) for _, field := range tableInfo.FieldList { if field.AlterID > 0 { physicalName := physicaltable.GetFieldPhysicsName(field.AlterID) fieldNameToPhysical[field.Name] = physicalName physicalToFieldName[physicalName] = field.Name physicalToFieldType[physicalName] = field.Type } } fieldNameToPhysical[database.DefaultIDDisplayColName] = database.DefaultIDColName fieldNameToPhysical[database.DefaultUidDisplayColName] = database.DefaultUidColName fieldNameToPhysical[database.DefaultCreateTimeDisplayColName] = database.DefaultCreateTimeColName var resultSet *entity3.ResultSet var rowsAffected int64 switch req.OperateType { case database.OperateType_Custom: resultSet, err = d.executeCustomSQL(ctx, req, physicalTableName, tableInfo, fieldNameToPhysical) if err != nil { return nil, err } case database.OperateType_Select: resultSet, err = d.executeSelectSQL(ctx, req, physicalTableName, tableInfo, fieldNameToPhysical) if err != nil { return nil, err } case database.OperateType_Insert: resultSet, err = d.executeInsertSQL(ctx, req, physicalTableName, tableInfo) if err != nil { return nil, err } case database.OperateType_Update: rowsAffected, err = d.executeUpdateSQL(ctx, req, physicalTableName, tableInfo, fieldNameToPhysical) if err != nil { return nil, err } case database.OperateType_Delete: rowsAffected, err = d.executeDeleteSQL(ctx, req, physicalTableName, tableInfo, fieldNameToPhysical) if err != nil { return nil, err } default: return nil, fmt.Errorf("unsupported operation type: %v", req.OperateType) } response := &ExecuteSQLResponse{ FieldList: tableInfo.FieldList, } if resultSet != nil && len(resultSet.Rows) > 0 { response.Records = convertor.ConvertResultSet(resultSet, physicalToFieldName, physicalToFieldType) } else { response.Records = make([]map[string]interface{}, 0) } // process special system fields for _, record := range response.Records { if val, ok := 
record[database.DefaultUidColName]; ok { delete(record, database.DefaultUidColName) record[database.DefaultUidDisplayColName] = val } if val, ok := record[database.DefaultCreateTimeColName]; ok { delete(record, database.DefaultCreateTimeColName) record[database.DefaultCreateTimeDisplayColName] = val } if val, ok := record[database.DefaultIDColName]; ok { delete(record, database.DefaultIDColName) record[database.DefaultIDDisplayColName] = val } if _, ok := record[database.DefaultCidColName]; ok { delete(record, database.DefaultCidColName) } } if resultSet != nil && resultSet.AffectedRows > 0 { response.RowsAffected = &resultSet.AffectedRows } if rowsAffected > 0 { response.RowsAffected = &rowsAffected } return response, nil } func (d databaseService) executeCustomSQL(ctx context.Context, req *ExecuteSQLRequest, physicalTableName string, tableInfo *entity2.Database, fieldNameToPhysical map[string]string) (*entity3.ResultSet, error) { var params []interface{} if req.SQL == nil || *req.SQL == "" { return nil, fmt.Errorf("SQL is empty") } operation, err := sqlparser.NewSQLParser().GetSQLOperation(*req.SQL) if err != nil { return nil, err } if tableInfo.RwMode == table.BotTableRWMode_ReadOnly && (operation == sqlparsercontract.OperationTypeInsert || operation == sqlparsercontract.OperationTypeUpdate || operation == sqlparsercontract.OperationTypeDelete) { return nil, fmt.Errorf("unsupported operation type: %v", operation) } if req.SQLParams != nil { params = make([]interface{}, 0, len(req.SQLParams)) for _, param := range req.SQLParams { value := param.Value if param.ISNull { value = nil } params = append(params, value) } } tableColumnMapping := map[string]sqlparsercontract.TableColumn{ tableInfo.TableName: { NewTableName: &physicalTableName, ColumnMap: fieldNameToPhysical, }, } parsedSQL, err := sqlparser.NewSQLParser().ParseAndModifySQL(*req.SQL, tableColumnMapping) if err != nil { return nil, fmt.Errorf("parse sql failed: %v", err) } insertResult := 
make([]map[string]interface{}, 0) if operation == sqlparsercontract.OperationTypeInsert { cid := consts.CozeConnectorID if req.ConnectorID != nil { cid = *req.ConnectorID } nums, err := sqlparser.NewSQLParser().GetInsertDataNums(parsedSQL) if err != nil { return nil, err } ids, err := d.generator.GenMultiIDs(ctx, nums) if err != nil { return nil, err } for _, id := range ids { insertResult = append(insertResult, map[string]interface{}{ database.DefaultIDColName: id, }) } existingCols := make(map[string]bool) if req.SQLType == database.SQLType_Raw { iIDs := make([]interface{}, len(ids)) for i, id := range ids { iIDs[i] = id } parsedSQL, _, err = sqlparser.NewSQLParser().AddColumnsToInsertSQL(parsedSQL, []sqlparsercontract.ColumnValue{ { ColName: database.DefaultCidColName, Value: cid, }, { ColName: database.DefaultUidColName, Value: req.UserID, }, }, &sqlparsercontract.PrimaryKeyValue{ColName: database.DefaultIDColName, Values: iIDs}, false) if err != nil { return nil, fmt.Errorf("add columns to insert sql failed: %v", err) } } else if req.SQLType == database.SQLType_Parameterized { parsedSQL, existingCols, err = sqlparser.NewSQLParser().AddColumnsToInsertSQL(parsedSQL, []sqlparsercontract.ColumnValue{ { ColName: database.DefaultCidColName, }, { ColName: database.DefaultUidColName, }, }, &sqlparsercontract.PrimaryKeyValue{ColName: database.DefaultIDColName}, true) if err != nil { return nil, fmt.Errorf("add columns to insert sql failed: %v", err) } if nums > 0 { if len(params)%nums != 0 { return nil, fmt.Errorf("number of params is not a multiple of number of rows") } paramsPerRow := len(params) / nums newParams := make([]interface{}, 0) for i := 0; i < nums; i++ { newParams = append(newParams, params[i*paramsPerRow:(i+1)*paramsPerRow]...) 
if !existingCols[database.DefaultCidColName] {
						newParams = append(newParams, cid)
					}
					if !existingCols[database.DefaultUidColName] {
						newParams = append(newParams, req.UserID)
					}
					if !existingCols[database.DefaultIDColName] {
						newParams = append(newParams, ids[i])
					}
				}
				params = newParams
			}
		}
	}

	execResp, err := d.rdb.ExecuteSQL(ctx, &rdb.ExecuteSQLRequest{
		SQL:     parsedSQL,
		Params:  params,
		SQLType: entity3.SQLType(req.SQLType),
	})
	if err != nil {
		return nil, fmt.Errorf("execute SQL failed: %v", err)
	}

	if operation == sqlparsercontract.OperationTypeInsert {
		if execResp.ResultSet == nil {
			execResp.ResultSet = &entity3.ResultSet{
				Rows: insertResult,
			}
		} else {
			execResp.ResultSet.Rows = insertResult
		}
	}

	return execResp.ResultSet, nil
}

// fieldMapWithSysFields indexes the given user fields plus the four system
// fields (create time, uid, id, connector id) by their decimal AlterID.
//
// The fields are copied into a fresh slice first, so appending the system
// fields can never write into spare capacity of the caller's slice.
func fieldMapWithSysFields(fields []*database.FieldItem) map[string]*database.FieldItem {
	all := make([]*database.FieldItem, 0, len(fields)+4)
	all = append(all, fields...)
	all = append(all,
		physicaltable.GetCreateTimeField(),
		physicaltable.GetUidField(),
		physicaltable.GetIDField(),
		physicaltable.GetConnectIDField(),
	)

	return slices.ToMap(all, func(e *database.FieldItem) (string, *database.FieldItem) {
		return strconv.FormatInt(e.AlterID, 10), e
	})
}

// restrictToOwner adds a "uid column = req.UserID" condition when the table is
// in limited read/write mode and the request carries a user id. In every other
// case cond is returned untouched. A nil cond is replaced by a new condition
// holding only the uid filter.
//
// NOTE(review): the uid filter is appended to cond.Conditions and therefore
// shares cond.Operator with the caller's conditions. If that operator can be
// OR, the filter would widen rather than narrow the result - confirm against
// the rdb layer.
func restrictToOwner(cond *rdb.ComplexCondition, tableInfo *entity2.Database, req *ExecuteSQLRequest) *rdb.ComplexCondition {
	if tableInfo.RwMode != table.BotTableRWMode_LimitedReadWrite || req.UserID == "" {
		return cond
	}

	owner := &rdb.Condition{
		Field:    database.DefaultUidColName,
		Operator: entity3.OperatorEqual,
		Value:    req.UserID,
	}
	if cond == nil {
		return &rdb.ComplexCondition{
			Conditions: []*rdb.Condition{owner},
		}
	}

	cond.Conditions = append(cond.Conditions, owner)
	return cond
}

// executeSelectSQL builds and runs a SELECT on physicalTableName.
//
// Requested field ids are resolved to physical column names through the
// table's field list (system fields included); an unknown id is an error.
// req.Condition is converted with req.SQLParams, the rw-mode owner filter is
// applied, and order-by fields are mapped through fieldNameToPhysical when a
// mapping exists.
func (d databaseService) executeSelectSQL(ctx context.Context, req *ExecuteSQLRequest, physicalTableName string, tableInfo *entity2.Database, fieldNameToPhysical map[string]string) (*entity3.ResultSet, error) {
	selectReq := &rdb.SelectDataRequest{
		TableName: physicalTableName,
		Limit:     int64PtrToIntPtr(req.Limit),
		Offset:    int64PtrToIntPtr(req.Offset),
	}

	fieldMap := fieldMapWithSysFields(tableInfo.FieldList)

	// An explicit, non-distinct field list narrows the selected columns.
	if sel := req.SelectFieldList; sel != nil && !sel.IsDistinct && len(sel.FieldID) > 0 {
		fields := make([]string, 0, len(sel.FieldID))
		for _, fieldID := range sel.FieldID {
			field, exists := fieldMap[fieldID]
			if !exists {
				return nil, fmt.Errorf("fieldID %s does not exist", fieldID)
			}
			fields = append(fields, field.PhysicalName)
		}
		selectReq.Fields = fields
	}

	var complexCond *rdb.ComplexCondition
	if req.Condition != nil {
		var err error
		complexCond, err = convertCondition(ctx, req.Condition, fieldNameToPhysical, req.SQLParams)
		if err != nil {
			return nil, fmt.Errorf("convert condition failed: %v", err)
		}
	}

	// add rw mode
	complexCond = restrictToOwner(complexCond, tableInfo, req)
	if complexCond != nil {
		selectReq.Where = complexCond
	}

	if len(req.OrderByList) > 0 {
		orderBy := make([]*rdb.OrderBy, 0, len(req.OrderByList))
		for _, order := range req.OrderByList {
			physicalField := order.Field
			if mapped, exists := fieldNameToPhysical[order.Field]; exists {
				physicalField = mapped
			}
			orderBy = append(orderBy, &rdb.OrderBy{
				Field:     physicalField,
				Direction: convertSortDirection(order.Direction),
			})
		}
		selectReq.OrderBy = orderBy
	}

	selectResp, err := d.rdb.SelectData(ctx, selectReq)
	if err != nil {
		return nil, fmt.Errorf("select data failed: %v", err)
	}

	return selectResp.ResultSet, nil
}

// executeInsertSQL inserts req.UpsertRows into physicalTableName.
//
// One id is generated per row. Every row gets the connector id, creation time
// and generated id, plus the user id when req.UserID is non-empty. Record
// values are taken from req.SQLParams in order, one param per record across
// all rows, and converted to the field's type; on conversion failure the raw
// value is stored and a warning is logged.
//
// The returned ResultSet lists the generated ids and the affected row count.
// An error is returned when there are no rows, a field id is unknown, or
// req.SQLParams holds fewer params than there are records.
func (d databaseService) executeInsertSQL(ctx context.Context, req *ExecuteSQLRequest, physicalTableName string, tableInfo *entity2.Database) (*entity3.ResultSet, error) {
	if len(req.UpsertRows) == 0 {
		return nil, fmt.Errorf("no data to insert")
	}

	ids, err := d.generator.GenMultiIDs(ctx, len(req.UpsertRows))
	if err != nil {
		return nil, errorx.WrapByCode(err, errno.ErrMemoryIDGenFailCode, errorx.KV("msg", "executeInsertSQL"))
	}

	fieldMap := fieldMapWithSysFields(tableInfo.FieldList)

	// The connector id is identical for every row, so resolve it once.
	cid := consts.CozeConnectorID
	if req.ConnectorID != nil {
		cid = *req.ConnectorID
	}

	sqlParams := req.SQLParams
	paramIdx := 0
	insertData := make([]map[string]interface{}, 0, len(req.UpsertRows))
	insertResult := make([]map[string]interface{}, 0, len(req.UpsertRows))

	for index, upsertRow := range req.UpsertRows {
		rowData := make(map[string]interface{})
		if req.UserID != "" {
			rowData[database.DefaultUidColName] = req.UserID
		}
		rowData[database.DefaultCidColName] = cid
		rowData[database.DefaultCreateTimeColName] = time.Now()
		rowData[database.DefaultIDColName] = ids[index]

		for _, record := range upsertRow.Records {
			field, exists := fieldMap[record.FieldId]
			if !exists {
				return nil, errorx.New(errno.ErrMemoryDatabaseFieldNotFoundCode)
			}

			// Fail with an error instead of an index-out-of-range panic when
			// the caller supplied fewer params than records.
			if paramIdx >= len(sqlParams) {
				return nil, fmt.Errorf("not enough params for insert: got %d", len(sqlParams))
			}
			param := sqlParams[paramIdx]
			paramIdx++

			fieldVal := param.Value
			if param.ISNull || fieldVal == nil {
				rowData[field.PhysicalName] = nil
				continue
			}

			convertedValue, convErr := convertor.ConvertValueByType(*fieldVal, field.Type)
			if convErr != nil {
				logs.Warnf("convert value failed: %v, using original value", convErr)
				rowData[field.PhysicalName] = *fieldVal
			} else {
				rowData[field.PhysicalName] = convertedValue
			}
		}

		insertData = append(insertData, rowData)
		insertResult = append(insertResult, map[string]interface{}{
			database.DefaultIDColName: ids[index],
		})
	}

	insertResp, err := d.rdb.InsertData(ctx, &rdb.InsertDataRequest{
		TableName: physicalTableName,
		Data:      insertData,
	})
	if err != nil {
		return nil, fmt.Errorf("insert data failed: %v", err)
	}

	return &entity3.ResultSet{
		Rows:         insertResult,
		AffectedRows: insertResp.AffectedRows,
	}, nil
}

// executeUpdateSQL updates rows of physicalTableName that match req.Condition.
//
// Only the first entry of req.UpsertRows is used. Its records consume the
// leading req.SQLParams, one param per record; the remaining params are passed
// to the condition. The rw-mode owner filter is added to the condition.
//
// It returns the affected row count, or -1 and an error when update data or
// the condition is missing, a field id is unknown, or there are fewer params
// than records.
func (d databaseService) executeUpdateSQL(ctx context.Context, req *ExecuteSQLRequest, physicalTableName string, tableInfo *entity2.Database, fieldNameToPhysical map[string]string) (int64, error) {
	if len(req.UpsertRows) == 0 || req.Condition == nil {
		return -1, fmt.Errorf("missing update data or condition")
	}

	fieldMap := fieldMapWithSysFields(tableInfo.FieldList)

	updateData := make(map[string]interface{})
	index := 0
	for _, record := range req.UpsertRows[0].Records {
		field, exists := fieldMap[record.FieldId]
		if !exists {
			return -1, errorx.New(errno.ErrMemoryDatabaseFieldNotFoundCode)
		}

		// Fail with an error instead of an index-out-of-range panic when the
		// caller supplied fewer params than records.
		if index >= len(req.SQLParams) {
			return -1, fmt.Errorf("not enough params for update: got %d", len(req.SQLParams))
		}
		param := req.SQLParams[index]
		index++

		fieldVal := param.Value
		if param.ISNull || fieldVal == nil {
			updateData[field.PhysicalName] = nil
			continue
		}

		convertedValue, convErr := convertor.ConvertValueByType(*fieldVal, field.Type)
		if convErr != nil {
			logs.Warnf("convert value failed: %v, using original value", convErr)
			updateData[field.PhysicalName] = *fieldVal
		} else {
			updateData[field.PhysicalName] = convertedValue
		}
	}

	// Whatever the SET clause did not consume belongs to the WHERE clause.
	condParams := req.SQLParams[index:]
	complexCond, err := convertCondition(ctx, req.Condition, fieldNameToPhysical, condParams)
	if err != nil {
		return -1, fmt.Errorf("convert condition failed: %v", err)
	}

	// add rw mode
	complexCond = restrictToOwner(complexCond, tableInfo, req)

	updateResp, err := d.rdb.UpdateData(ctx, &rdb.UpdateDataRequest{
		TableName: physicalTableName,
		Data:      updateData,
		Where:     complexCond,
		Limit:     int64PtrToIntPtr(req.Limit),
	})
	if err != nil {
		return -1, fmt.Errorf("update data failed: %v", err)
	}

	return updateResp.AffectedRows, nil
}

// executeDeleteSQL deletes rows of physicalTableName that match req.Condition,
// with the rw-mode owner filter added. A nil condition is rejected. It returns
// the affected row count, or -1 and an error.
func (d databaseService) executeDeleteSQL(ctx context.Context, req *ExecuteSQLRequest, physicalTableName string, tableInfo *entity2.Database, fieldNameToPhysical map[string]string) (int64, error) {
	if req.Condition == nil {
		return -1, fmt.Errorf("missing delete condition")
	}

	complexCond, err := convertCondition(ctx, req.Condition, fieldNameToPhysical, req.SQLParams)
	if err != nil {
		return -1, fmt.Errorf("convert condition failed: %v", err)
	}

	// add rw mode
	complexCond = restrictToOwner(complexCond, tableInfo, req)

	deleteResp, err := d.rdb.DeleteData(ctx, &rdb.DeleteDataRequest{
		TableName: physicalTableName,
		Where:     complexCond,
		Limit:     int64PtrToIntPtr(req.Limit),
	})
	if err != nil {
		return -1, fmt.Errorf("delete data failed: %v", err)
	}

	return deleteResp.AffectedRows, nil
}

// int64PtrToIntPtr converts an optional int64 to an optional int; nil stays nil.
func int64PtrToIntPtr(i64ptr *int64) *int {
	if i64ptr == nil {
		return nil
	}

	i := int(*i64ptr)
	return &i
}

// convertSortDirection maps the API sort direction to the rdb one. Anything
// other than Desc is treated as ascending.
func convertSortDirection(direction table.SortDirection) entity3.SortDirection {
	if direction == table.SortDirection_Desc {
		return entity3.SortDirectionDesc
	}

	return entity3.SortDirectionAsc
}

// convertCondition turns an API condition into an rdb condition.
//
// Left-hand field names are mapped through fieldMap when a mapping exists.
// params are consumed in order:
//   - IS NULL / IS NOT NULL consume no param;
//   - IN / NOT IN consume one param per '?' found in the right-hand side, and
//     null params are skipped;
//   - every other operation consumes one param, and a null param drops the
//     whole condition.
//
// Values compared against a uid column are passed through decryptSysUUIDKey.
// A nil cond yields (nil, nil). Running out of params is reported as an error.
//
// NOTE(review): dropping a condition whose param is null makes the resulting
// filter less restrictive than the one requested - confirm this is intended
// for update and delete.
func convertCondition(ctx context.Context, cond *database.ComplexCondition, fieldMap map[string]string, params []*database.SQLParamVal) (*rdb.ComplexCondition, error) {
	if cond == nil {
		return nil, nil
	}

	result := &rdb.ComplexCondition{
		Operator: convertor.ConvertLogicOperator(cond.Logic),
	}

	index := 0
	if len(cond.Conditions) > 0 {
		conditions := make([]*rdb.Condition, 0, len(cond.Conditions))
		for _, c := range cond.Conditions {
			leftField := c.Left
			if mapped, exists := fieldMap[c.Left]; exists {
				leftField = mapped
			}

			switch c.Operation {
			case database.Operation_IS_NULL, database.Operation_IS_NOT_NULL:
				conditions = append(conditions, &rdb.Condition{
					Field:    leftField,
					Operator: convertor.ConvertOperator(c.Operation),
				})
				continue

			case database.Operation_IN, database.Operation_NOT_IN:
				// c.Right: example: (?,?)
				qCount := 0
				for i := 0; i < len(c.Right); i++ {
					if c.Right[i] == '?' {
						qCount++
					}
				}
				if qCount == 0 {
					return nil, fmt.Errorf("IN/NOT_IN condition right side must contain ? placeholders")
				}

				vals := make([]interface{}, 0, qCount)
				for j := 0; j < qCount; j++ {
					if index >= len(params) {
						return nil, fmt.Errorf("not enough params for IN/NOT_IN condition")
					}
					if params[index].ISNull || params[index].Value == nil {
						index++
						continue
					}
					vals = append(vals, decryptSysUUIDKey(ctx, leftField, *params[index].Value))
					index++
				}
				conditions = append(conditions, &rdb.Condition{
					Field:    leftField,
					Operator: convertor.ConvertOperator(c.Operation),
					Value:    vals,
				})
				continue
			}

			// Same guard as the IN/NOT_IN branch: report a short param list as
			// an error rather than panicking on the index below.
			if index >= len(params) {
				return nil, fmt.Errorf("not enough params for condition on field %s", leftField)
			}
			if params[index].ISNull || params[index].Value == nil {
				index++
				continue
			}
			conditions = append(conditions, &rdb.Condition{
				Field:    leftField,
				Operator: convertor.ConvertOperator(c.Operation),
				Value:    decryptSysUUIDKey(ctx, leftField, *params[index].Value),
			})
			index++
		}
		result.Conditions = conditions
	}

	// if cond.NestedConditions != nil {
	// 	nested, err := convertCondition(cond.NestedConditions, fieldMap, params)
	// 	if err != nil {
	// 		return nil, err
	// 	}
	// 	result.NestedConditions = []*rdb.ComplexCondition{nested}
	// }

	return result, nil
}

// decryptSysUUIDKey returns value unchanged unless leftField is one of the uid
// columns. For those it asks the variables service to decrypt value and, when
// that yields a result, returns its ConnectorUID instead.
func decryptSysUUIDKey(ctx context.Context, leftField, value string) string {
	if leftField == database.DefaultUidDisplayColName || leftField == database.DefaultUidColName {
		decryptVal := crossvariables.DefaultSVC().DecryptSysUUIDKey(ctx, value)
		if decryptVal != nil {
			value = decryptVal.ConnectorUID
		}
	}
	return value
}

// BindDatabase links both the draft database req.DraftDatabaseID and its
// online counterpart to the agent req.AgentID. It fails when the draft
// database cannot be found.
func (d databaseService) BindDatabase(ctx context.Context, req *BindDatabaseToAgentRequest) error {
	draft, err := d.MGetDatabase(ctx, &MGetDatabaseRequest{
		Basics: []*database.DatabaseBasic{
			{
				ID:        req.DraftDatabaseID,
				TableType: table.TableType_DraftTable,
			},
		},
	})
	if err != nil {
		return err
	}
	if len(draft.Databases) == 0 {
		// The lookup above is for the draft table, so say so.
		return fmt.Errorf("draft table not found, id: %d", req.DraftDatabaseID)
	}

	onlineID := draft.Databases[0].GetOnlineID()
	relations := []*database.AgentToDatabase{
		{
			AgentID:    req.AgentID,
			DatabaseID: onlineID,
			TableType:  table.TableType_OnlineTable,
		},
		{
			AgentID:    req.AgentID,
			DatabaseID: req.DraftDatabaseID,
			TableType:  table.TableType_DraftTable,
		},
	}

	if _, err = d.agentToDatabaseDAO.BatchCreate(ctx, relations); err != nil {
		return fmt.Errorf("failed to bind databases to agent: %w", err)
	}

	return nil
}

// UnBindDatabase removes the links between the agent req.AgentID and both the
// draft database req.DraftDatabaseID and its online counterpart. It fails when
// the draft database cannot be found.
func (d databaseService) UnBindDatabase(ctx context.Context, req *UnBindDatabaseToAgentRequest) error {
	draft, err := d.MGetDatabase(ctx, &MGetDatabaseRequest{
		Basics: []*database.DatabaseBasic{
			{
				ID:        req.DraftDatabaseID,
				TableType: table.TableType_DraftTable,
			},
		},
	})
	if err != nil {
		return err
	}
	if len(draft.Databases) == 0 {
		// The lookup above is for the draft table, so say so.
		return fmt.Errorf("draft table not found, id: %d", req.DraftDatabaseID)
	}

	onlineID := draft.Databases[0].GetOnlineID()
	relations := []*database.AgentToDatabaseBasic{
		{
			AgentID:    req.AgentID,
			DatabaseID: onlineID,
		},
		{
			AgentID:    req.AgentID,
			DatabaseID: req.DraftDatabaseID,
		},
	}

	if err = d.agentToDatabaseDAO.BatchDelete(ctx, relations); err != nil {
		return fmt.Errorf("failed to unbind databases from agent: %w", err)
	}

	return nil
}

// MGetDatabaseByAgentID returns the databases of req.TableType that are bound
// to the agent req.AgentID. A nil request is rejected.
func (d databaseService) MGetDatabaseByAgentID(ctx context.Context, req *MGetDatabaseByAgentIDRequest) (*MGetDatabaseByAgentIDResponse, error) {
	if req == nil {
		return nil, fmt.Errorf("invalid request: request is nil")
	}

	relations, err := d.agentToDatabaseDAO.ListByAgentID(ctx, req.AgentID, req.TableType)
	if err != nil {
		return nil, err
	}

	mGetBasics := make([]*database.DatabaseBasic, 0, len(relations))
	for _, relation := range relations {
		mGetBasics = append(mGetBasics, &database.DatabaseBasic{
			ID:            relation.DatabaseID,
			TableType:     req.TableType,
			NeedSysFields: req.NeedSysFields,
		})
	}

	databases, err := d.MGetDatabase(ctx, &MGetDatabaseRequest{Basics: mGetBasics})
	if err != nil {
		return nil, err
	}

	return &MGetDatabaseByAgentIDResponse{
		Databases: databases.Databases,
	}, nil
}

// PublishDatabase returns the online databases that correspond to the draft
// databases bound to the agent req.AgentID. The response is empty when the
// agent has no draft relations. A nil request is rejected.
func (d databaseService) PublishDatabase(ctx context.Context, req *PublishDatabaseRequest) (*PublishDatabaseResponse, error) {
	if req == nil {
		return nil, fmt.Errorf("invalid request: request is nil")
	}

	relationResp, err := d.MGetRelationsByAgentID(ctx, &MGetRelationsByAgentIDRequest{
		AgentID:   req.AgentID,
		TableType: table.TableType_DraftTable,
	})
	if err != nil {
		return nil, err
	}
	if len(relationResp.Relations) == 0 {
		return &PublishDatabaseResponse{}, nil
	}

	// Step 1: load the bound draft databases.
	dBasics := make([]*database.DatabaseBasic, 0, len(relationResp.Relations))
	for _, draftR := range relationResp.Relations {
		dBasics = append(dBasics, &database.DatabaseBasic{
			ID:            draftR.DatabaseID,
			TableType:     table.TableType_DraftTable,
			NeedSysFields: false,
		})
	}
	draftDatabaseResp, err := d.MGetDatabase(ctx, &MGetDatabaseRequest{
		Basics: dBasics,
	})
	if err != nil {
		return nil, err
	}

	// Step 2: follow each draft to its online database.
	oBasics := make([]*database.DatabaseBasic, 0, len(draftDatabaseResp.Databases))
	for _, draft := range draftDatabaseResp.Databases {
		oBasics = append(oBasics, &database.DatabaseBasic{
			ID:            draft.GetOnlineID(),
			TableType:     table.TableType_OnlineTable,
			NeedSysFields: false,
		})
	}
	onlineDatabaseResp, err := d.MGetDatabase(ctx, &MGetDatabaseRequest{
		Basics: oBasics,
	})
	if err != nil {
		return nil, err
	}

	// Step 3: convert to the bot_common representation.
	results := make([]*bot_common.Database, 0, len(onlineDatabaseResp.Databases))
	for _, online := range onlineDatabaseResp.Databases {
		fields := make([]*bot_common.FieldItem, 0, len(online.FieldList))
		for _, field := range online.FieldList {
			fields = append(fields, &bot_common.FieldItem{
				Name:         ptr.Of(field.Name),
				Desc:         ptr.Of(field.Desc),
				Type:         ptr.Of(bot_common.FieldItemType(field.Type)),
				MustRequired: ptr.Of(field.MustRequired),
				AlterId:      ptr.Of(field.AlterID),
				Id:           ptr.Of(int64(0)),
			})
		}

		results = append(results, &bot_common.Database{
			TableId:   ptr.Of(strconv.FormatInt(online.ID, 10)),
			TableName: ptr.Of(online.TableName),
			TableDesc: ptr.Of(online.TableDesc),
			FieldList: fields,
			RWMode:    ptr.Of(bot_common.BotTableRWMode(online.RwMode)),
		})
	}

	return &PublishDatabaseResponse{
		OnlineDatabases: results,
	}, nil
}

// MGetRelationsByAgentID lists the agent-to-database relations of
// req.TableType for the agent req.AgentID. A nil request is rejected.
func (d databaseService) MGetRelationsByAgentID(ctx context.Context, req *MGetRelationsByAgentIDRequest) (*MGetRelationsByAgentIDResponse, error) {
	if req == nil {
		return nil, fmt.Errorf("invalid request: request is nil")
	}

	relations, err := d.agentToDatabaseDAO.ListByAgentID(ctx, req.AgentID, req.TableType)
	if err != nil {
		return nil, err
	}

	return &MGetRelationsByAgentIDResponse{
		Relations: relations,
	}, nil
}

// GetDatabaseTableSchema reads the head of the sheet stored at req.TosURL,
// predicts the column types, and returns the column metadata and sheet list.
// Preview rows are included when req.TableDataType is AllData or OnlyPreview.
//
// When building the preview fails, the still-empty response is returned
// together with the error.
func (d databaseService) GetDatabaseTableSchema(ctx context.Context, req *GetDatabaseTableSchemaRequest) (*GetDatabaseTableSchemaResponse, error) {
	parser := &sheet.TosTableParser{
		UserID:         req.UserID,
		DocumentSource: database.DocumentSourceType_Document,
		TosURI:         req.TosURL,
		TosServ:        d.storage,
	}

	res, extra, err := parser.GetTableDataBySheetIDx(ctx, entity2.TableReaderMeta{
		TosMaxLine:    100000,
		HeaderLineIdx: req.TableSheet.HeaderLineIdx,
		SheetId:       req.TableSheet.SheetID,
		StartLineIdx:  req.TableSheet.StartLineIdx,
		ReaderMethod:  database.TableReadDataMethodHead,
		ReadLineCnt:   20,
	})
	if err != nil {
		return nil, err
	}

	res.Columns, err = parser.PredictColumnType(res.Columns, res.SampleData, req.TableSheet.SheetID, req.TableSheet.StartLineIdx)
	if err != nil {
		return nil, err
	}

	resp := &GetDatabaseTableSchemaResponse{}
	wantPreview := req.TableDataType == table.TableDataType_AllData ||
		req.TableDataType == table.TableDataType_OnlyPreview
	if wantPreview {
		previewData, tErr := parser.TransferPreviewData(ctx, res.Columns, res.SampleData, 20)
		if tErr != nil {
			return resp, tErr
		}
		resp.PreviewData = previewData
	}

	resp.TableMeta = res.Columns
	resp.SheetList = extra.Sheets

	return resp, nil
}

// ValidateDatabaseTableSchema reads the sheet stored at req.TosURL and checks
// it against req.Fields with sheet.CheckSheetIsValid, returning the verdict
// and the message produced by that check.
func (d databaseService) ValidateDatabaseTableSchema(ctx context.Context, req *ValidateDatabaseTableSchemaRequest) (*ValidateDatabaseTableSchemaResponse, error) {
	parser := &sheet.TosTableParser{
		UserID:         req.UserID,
		DocumentSource: database.DocumentSourceType_Document,
		TosURI:         req.TosURL,
		TosServ:        d.storage,
	}

	res, sheetRes, err := parser.GetTableDataBySheetIDx(ctx, entity2.TableReaderMeta{
		TosMaxLine:    100000,
		HeaderLineIdx: req.TableSheet.HeaderLineIdx,
		SheetId:       req.TableSheet.SheetID,
		StartLineIdx:  req.TableSheet.StartLineIdx,
		ReaderMethod:  database.TableReadDataMethodAll,
		ReadLineCnt:   20,
	})
	if err != nil {
		return nil, err
	}

	valid, invalidMsg := sheet.CheckSheetIsValid(req.Fields, res.Columns, sheetRes)

	return &ValidateDatabaseTableSchemaResponse{
		Valid:      valid,
		InvalidMsg: invalidMsg,
	}, nil
}

// SubmitDatabaseInsertTask parses the sheet at req.FileURI, initializes the
// progress keys in the cache, and then imports the rows in the background in
// batches of 20, bumping the progress counter after each batch.
//
// Errors (including recovered panics) that happen before the background import
// starts are returned and also written to the fail-reason cache key. Failures
// inside the background import are only written to that key.
//
// The result is named so that the deferred recover can turn a panic into a
// returned error.
func (d databaseService) SubmitDatabaseInsertTask(ctx context.Context, req *SubmitDatabaseInsertTaskRequest) (err error) {
	failKey := onlineFailReasonKey
	if req.TableType == table.TableType_DraftTable {
		failKey = draftFailReasonKey
	}
	failCacheKey := fmt.Sprintf(failKey, req.DatabaseID, req.UserID)

	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("panic: %v", r)
		}
		if err != nil {
			d.cache.Set(ctx, failCacheKey, err.Error(), redisKeyTimeOut)
		}
	}()

	parser := &sheet.TosTableParser{
		UserID:         req.UserID,
		DocumentSource: database.DocumentSourceType_Document,
		TosURI:         req.FileURI,
		TosServ:        d.storage,
	}

	parseData, extra, err := parser.GetTableDataBySheetIDx(ctx, entity2.TableReaderMeta{
		TosMaxLine:    100000,
		SheetId:       req.TableSheet.SheetID,
		HeaderLineIdx: req.TableSheet.HeaderLineIdx,
		StartLineIdx:  req.TableSheet.StartLineIdx,
		ReaderMethod:  database.TableReadDataMethodAll,
	})
	if err != nil {
		return err
	}

	if err = d.initializeCache(ctx, req, parseData, extra); err != nil {
		return err
	}

	columns := parseData.Columns
	records := make([]map[string]string, 0, len(parseData.SampleData))
	for _, data := range parseData.SampleData {
		record := make(map[string]string)
		for i, column := range columns {
			record[column.ColumnName] = data[i]
		}
		records = append(records, record)
	}

	const batchSize = 20

	// NOTE(review): the background import keeps using the caller's ctx; if
	// that ctx is cancelled once the request finishes, the import stops early
	// - confirm the lifetime of ctx at the call sites.
	go func() {
		defer func() {
			if r := recover(); r != nil {
				d.cache.Set(ctx, failCacheKey, fmt.Sprintf("panic: %v", r), redisKeyTimeOut)
			}
		}()

		for start := 0; start < len(records); start += batchSize {
			end := start + batchSize
			if end > len(records) {
				end = len(records)
			}
			batchRecords := records[start:end]

			// Errors stay local to the goroutine: writing the enclosing
			// function's err here would race with its deferred handler.
			addErr := d.AddDatabaseRecord(ctx, &AddDatabaseRecordRequest{
				DatabaseID:  req.DatabaseID,
				TableType:   req.TableType,
				ConnectorID: req.ConnectorID,
				UserID:      req.UserID,
				Records:     batchRecords,
			})
			if addErr != nil {
				d.cache.Set(ctx, failCacheKey, addErr.Error(), redisKeyTimeOut)
				return
			}

			if incErr := d.increaseProgress(ctx, req, int64(len(batchRecords))); incErr != nil {
				d.cache.Set(ctx, failCacheKey, incErr.Error(), redisKeyTimeOut)
				return
			}
		}
	}()

	return nil
}

// GetDatabaseFileProgressData reports the state of the file import for
// (req.DatabaseID, req.UserID) as stored in the cache.
//
// Missing cache keys are not an error. When the total count is zero or absent
// the import is reported as 100% with an empty file name; otherwise progress
// is progress/total in percent, together with the file name and fail reason.
func (d databaseService) GetDatabaseFileProgressData(ctx context.Context, req *GetDatabaseFileProgressDataRequest) (*GetDatabaseFileProgressDataResponse, error) {
	totalKey := onlineTotalCountKey
	progressKey := onlineProgressKey
	failKey := onlineFailReasonKey
	currentFileName := onlineCurrentFileName
	if req.TableType == table.TableType_DraftTable {
		totalKey = draftTotalCountKey
		progressKey = draftProgressKey
		failKey = draftFailReasonKey
		currentFileName = draftCurrentFileName
	}

	totalNum, err := d.cache.Get(ctx, fmt.Sprintf(totalKey, req.DatabaseID, req.UserID)).Int64()
	if err != nil && !errors.Is(err, redis.Nil) {
		return nil, err
	}

	progressNum, err := d.cache.Get(ctx, fmt.Sprintf(progressKey, req.DatabaseID, req.UserID)).Int64()
	if err != nil && !errors.Is(err, redis.Nil) {
		return nil, err
	}

	failReason, err := d.cache.Get(ctx, fmt.Sprintf(failKey, req.DatabaseID, req.UserID)).Result()
	if err != nil && !errors.Is(err, redis.Nil) {
		return nil, err
	}

	fileName, err := d.cache.Get(ctx, fmt.Sprintf(currentFileName, req.DatabaseID, req.UserID)).Result()
	if err != nil && !errors.Is(err, redis.Nil) {
		return nil, err
	}

	resp := &GetDatabaseFileProgressDataResponse{}
	if totalNum == 0 {
		resp.FileName = ""
		resp.Progress = 100
	} else {
		resp.FileName = fileName
		resp.Progress = int32(float32(progressNum) / float32(totalNum) * 100)
		resp.StatusDescript = ptr.Of(failReason)
	}

	return resp, nil
}

// Cache key templates for file-import bookkeeping. Each template is formatted
// with the database id followed by the user id; draft and online tables use
// separate keys.
const (
	draftTotalCountKey    = "database_file_%d_%d_draft_total"
	onlineTotalCountKey   = "database_file_%d_%d_online_total"
	draftProgressKey      = "database_file_%d_%d_draft_progress"
	onlineProgressKey     = "database_file_%d_%d_online_progress"
	draftFailReasonKey    = "database_file_%d_%d_draft_fail_reason"
	onlineFailReasonKey   = "database_file_%d_%d_online_fail_reason"
	draftCurrentFileName  = "database_file_%d_%d_draft_file_name"
	onlineCurrentFileName = "database_file_%d_%d_online_file_name"

	// redisKeyTimeOut is the expiry applied to every key above.
	redisKeyTimeOut = time.Hour * 12
)

// initializeCache resets the import bookkeeping for (req.DatabaseID,
// req.UserID): total = number of parsed rows, progress = 0, fail reason = "",
// file name = name of the selected sheet. The first failing cache write aborts
// and is returned.
func (d databaseService) initializeCache(ctx context.Context, req *SubmitDatabaseInsertTaskRequest, parseData *entity2.TableReaderSheetData, extra *entity2.ExcelExtraInfo) error {
	userID := req.UserID
	databaseID := req.DatabaseID

	totalKey := onlineTotalCountKey
	currentFileName := onlineCurrentFileName
	progressKey := onlineProgressKey
	failKey := onlineFailReasonKey
	if req.TableType == table.TableType_DraftTable {
		totalKey = draftTotalCountKey
		currentFileName = draftCurrentFileName
		progressKey = draftProgressKey
		failKey = draftFailReasonKey
	}

	if _, err := d.cache.Set(ctx, fmt.Sprintf(totalKey, databaseID, userID), strconv.Itoa(len(parseData.SampleData)), redisKeyTimeOut).Result(); err != nil {
		return err
	}

	if _, err := d.cache.Set(ctx, fmt.Sprintf(progressKey, databaseID, userID), int64(0), redisKeyTimeOut).Result(); err != nil {
		return err
	}

	if _, err := d.cache.Set(ctx, fmt.Sprintf(failKey, databaseID, userID), "", redisKeyTimeOut).Result(); err != nil {
		return err
	}

	sheetName := extra.Sheets[req.TableSheet.SheetID].SheetName
	if _, err := d.cache.Set(ctx, fmt.Sprintf(currentFileName, databaseID, userID), sheetName, redisKeyTimeOut).Result(); err != nil {
		return err
	}

	return nil
}

// increaseProgress adds successNum to the import progress counter of
// (req.DatabaseID, req.UserID).
func (d databaseService) increaseProgress(ctx context.Context, req *SubmitDatabaseInsertTaskRequest, successNum int64) error {
	progressKey := onlineProgressKey
	if req.TableType == table.TableType_DraftTable {
		progressKey = draftProgressKey
	}

	_, err := d.cache.IncrBy(ctx, fmt.Sprintf(progressKey, req.DatabaseID, req.UserID), successNum).Result()
	return err
}

// GetDraftDatabaseByOnlineID resolves the online database req.OnlineID to its
// draft database. It fails when either of the two cannot be found.
func (d databaseService) GetDraftDatabaseByOnlineID(ctx context.Context, req *GetDraftDatabaseByOnlineIDRequest) (*GetDraftDatabaseByOnlineIDResponse, error) {
	online, err := d.MGetDatabase(ctx, &MGetDatabaseRequest{
		Basics: []*database.DatabaseBasic{
			{
				ID:        req.OnlineID,
				TableType: table.TableType_OnlineTable,
			},
		},
	})
	if err != nil {
		return nil, err
	}
	if len(online.Databases) == 0 {
		return nil, fmt.Errorf("online table not found, id: %d", req.OnlineID)
	}

	draftID := online.Databases[0].GetDraftID()
	draftResp, err := d.MGetDatabase(ctx, &MGetDatabaseRequest{
		Basics: []*database.DatabaseBasic{
			{
				ID:        draftID,
				TableType: table.TableType_DraftTable,
			},
		},
	})
	if err != nil {
		return nil, err
	}
	if len(draftResp.Databases) == 0 {
		// This lookup is for the draft table; report it with its own id.
		return nil, fmt.Errorf("draft table not found, id: %d", draftID)
	}

	return &GetDraftDatabaseByOnlineIDResponse{
		Database: draftResp.Databases[0],
	}, nil
}

// DeleteDatabaseByAppID deletes all database records of the app req.AppID
// (online and draft) in one transaction and then drops their physical tables.
//
// Dropping a physical table is best effort: a failure is logged and does not
// fail the call. The response lists the ids of the deleted online databases.
//
// The results are named so that the deferred handler can report a recovered
// panic as an error.
func (d databaseService) DeleteDatabaseByAppID(ctx context.Context, req *DeleteDatabaseByAppIDRequest) (resp *DeleteDatabaseByAppIDResponse, err error) {
	onlineDBInfos, err := d.listDatabasesByAppID(ctx, req.AppID, table.TableType_OnlineTable)
	if err != nil {
		return nil, err
	}

	draftDBInfos, err := d.listDatabasesByAppID(ctx, req.AppID, table.TableType_DraftTable)
	if err != nil {
		return nil, err
	}

	tx := query.Use(d.db).Begin()
	if tx.Error != nil {
		return nil, fmt.Errorf("start transaction failed, %v", tx.Error)
	}

	defer func() {
		if r := recover(); r != nil {
			if e := tx.Rollback(); e != nil {
				logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
			}
			resp = nil
			err = fmt.Errorf("catch panic: %v\nstack=%s", r, string(debug.Stack()))
			return
		}

		if err != nil {
			if e := tx.Rollback(); e != nil {
				logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
			}
		}
	}()

	onlineIDs := make([]int64, 0, len(onlineDBInfos))
	for _, db := range onlineDBInfos {
		onlineIDs = append(onlineIDs, db.ID)
	}

	draftIDs := make([]int64, 0, len(draftDBInfos))
	for _, db := range draftDBInfos {
		draftIDs = append(draftIDs, db.ID)
	}

	if err = d.onlineDAO.BatchDeleteWithTX(ctx, tx, onlineIDs); err != nil {
		return nil, err
	}

	if err = d.draftDAO.BatchDeleteWithTX(ctx, tx, draftIDs); err != nil {
		return nil, err
	}

	err = tx.Commit()
	if err != nil {
		return nil, fmt.Errorf("commit transaction failed: %v", err)
	}

	// delete draft and online physical table
	//
	// Drop errors go into a local variable: assigning them to err would make
	// the deferred handler roll back the already committed transaction.
	for _, db := range onlineDBInfos {
		physical := db.ActualTableName
		if _, dropErr := d.rdb.DropTable(ctx, &rdb.DropTableRequest{TableName: physical}); dropErr != nil {
			logs.Errorf("drop online physical table failed: %v, table_name=%s", dropErr, physical)
		}
	}

	for _, db := range draftDBInfos {
		physical := db.ActualTableName
		if _, dropErr := d.rdb.DropTable(ctx, &rdb.DropTableRequest{TableName: physical}); dropErr != nil {
			logs.Errorf("drop draft physical table failed: %v, table_name=%s", dropErr, physical)
		}
	}

	return &DeleteDatabaseByAppIDResponse{
		DeletedDatabaseIDs: onlineIDs,
	}, nil
}

// listDatabasesByAppID collects every database of tableType that belongs to
// appID, paging through ListDatabase 100 entries at a time until it reports
// that nothing more is available.
func (d databaseService) listDatabasesByAppID(ctx context.Context, appID int64, tableType table.TableType) ([]*entity2.Database, error) {
	const batchSize = 100

	dbInfos := make([]*entity2.Database, 0)
	for offset := 0; ; offset += batchSize {
		resp, err := d.ListDatabase(ctx, &ListDatabaseRequest{
			AppID:     appID,
			TableType: tableType,
			Limit:     batchSize,
			Offset:    offset,
		})
		if err != nil {
			return nil, err
		}

		dbInfos = append(dbInfos, resp.Databases...)

		if !resp.HasMore {
			break
		}
	}

	return dbInfos, nil
}

// GetAllDatabaseByAppID returns every online database that belongs to the app
// req.AppID.
func (d databaseService) GetAllDatabaseByAppID(ctx context.Context, req *GetAllDatabaseByAppIDRequest) (*GetAllDatabaseByAppIDResponse, error) {
	onlineDBs, err := d.listDatabasesByAppID(ctx, req.AppID, table.TableType_OnlineTable)
	if err != nil {
		return nil, err
	}

	return &GetAllDatabaseByAppIDResponse{
		Databases: onlineDBs,
	}, nil
}