2207 lines
63 KiB
Go
2207 lines
63 KiB
Go
/*
|
|
* Copyright 2025 coze-dev Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package service
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"math/rand"
|
|
"runtime/debug"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/tealeg/xlsx/v3"
|
|
"gorm.io/gorm"
|
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
|
|
|
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
|
|
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
|
"github.com/coze-dev/coze-studio/backend/api/model/data/database/table"
|
|
crossvariables "github.com/coze-dev/coze-studio/backend/crossdomain/contract/variables"
|
|
entity2 "github.com/coze-dev/coze-studio/backend/domain/memory/database/entity"
|
|
"github.com/coze-dev/coze-studio/backend/domain/memory/database/internal/convertor"
|
|
"github.com/coze-dev/coze-studio/backend/domain/memory/database/internal/dal/query"
|
|
"github.com/coze-dev/coze-studio/backend/domain/memory/database/internal/physicaltable"
|
|
"github.com/coze-dev/coze-studio/backend/domain/memory/database/internal/sheet"
|
|
"github.com/coze-dev/coze-studio/backend/domain/memory/database/repository"
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
|
|
entity3 "github.com/coze-dev/coze-studio/backend/infra/contract/rdb/entity"
|
|
sqlparsercontract "github.com/coze-dev/coze-studio/backend/infra/contract/sqlparser"
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/sqlparser"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
|
"github.com/coze-dev/coze-studio/backend/types/consts"
|
|
"github.com/coze-dev/coze-studio/backend/types/errno"
|
|
)
|
|
|
|
type databaseService struct {
|
|
rdb rdb.RDB
|
|
db *gorm.DB
|
|
generator idgen.IDGenerator
|
|
draftDAO repository.DraftDAO
|
|
onlineDAO repository.OnlineDAO
|
|
agentToDatabaseDAO repository.AgentToDatabaseDAO
|
|
storage storage.Storage
|
|
cache cache.Cmdable
|
|
}
|
|
|
|
func NewService(rdb rdb.RDB, db *gorm.DB, generator idgen.IDGenerator, storage storage.Storage, cacheCli cache.Cmdable) Database {
|
|
return &databaseService{
|
|
rdb: rdb,
|
|
db: db,
|
|
generator: generator,
|
|
draftDAO: repository.NewDraftDatabaseDAO(db, generator),
|
|
onlineDAO: repository.NewOnlineDatabaseDAO(db, generator),
|
|
agentToDatabaseDAO: repository.NewAgentToDatabaseDAO(db, generator),
|
|
storage: storage,
|
|
cache: cacheCli,
|
|
}
|
|
}
|
|
|
|
func (d databaseService) CreateDatabase(ctx context.Context, req *CreateDatabaseRequest) (*CreateDatabaseResponse, error) {
|
|
draftEntity, onlineEntity := req.Database, req.Database
|
|
fieldItems, columns := physicaltable.CreateFieldInfo(req.Database.FieldList)
|
|
|
|
// create physical draft table
|
|
draftEntity.FieldList = fieldItems
|
|
|
|
draftPhysicalTableRes, err := physicaltable.CreatePhysicalTable(ctx, d.rdb, columns)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if draftPhysicalTableRes.Table == nil {
|
|
return nil, fmt.Errorf("create draft table failed, columns info is %v", columns)
|
|
}
|
|
|
|
draftID, err := d.generator.GenID(ctx)
|
|
if err != nil {
|
|
return nil, errorx.WrapByCode(err, errno.ErrMemoryIDGenFailCode, errorx.KV("msg", "CreateDatabase"))
|
|
}
|
|
|
|
// create physical online table
|
|
onlineEntity.FieldList = fieldItems
|
|
|
|
onlinePhysicalTableRes, err := physicaltable.CreatePhysicalTable(ctx, d.rdb, columns)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if onlinePhysicalTableRes.Table == nil {
|
|
return nil, fmt.Errorf("create online table failed, columns info is %v", columns)
|
|
}
|
|
|
|
onlineID, err := d.generator.GenID(ctx)
|
|
if err != nil {
|
|
return nil, errorx.WrapByCode(err, errno.ErrMemoryIDGenFailCode, errorx.KV("msg", "CreateDatabase"))
|
|
}
|
|
|
|
// insert draft and online database info
|
|
tx := query.Use(d.db).Begin()
|
|
if tx.Error != nil {
|
|
return nil, fmt.Errorf("start transaction failed, %v", tx.Error)
|
|
}
|
|
|
|
if draftEntity.IconURI == "" {
|
|
draftEntity.IconURI = consts.DefaultDatabaseIcon
|
|
}
|
|
if onlineEntity.IconURI == "" {
|
|
onlineEntity.IconURI = consts.DefaultDatabaseIcon
|
|
}
|
|
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
e := tx.Rollback()
|
|
if e != nil {
|
|
logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
|
|
}
|
|
|
|
err = fmt.Errorf("catch panic: %v\nstack=%s", r, string(debug.Stack()))
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
e := tx.Rollback()
|
|
if e != nil {
|
|
logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
|
|
}
|
|
}
|
|
}()
|
|
|
|
_, err = d.draftDAO.CreateWithTX(ctx, tx, draftEntity, draftID, onlineID, draftPhysicalTableRes.Table.Name)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
onlineEntity, err = d.onlineDAO.CreateWithTX(ctx, tx, onlineEntity, draftID, onlineID, onlinePhysicalTableRes.Table.Name)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
onlineEntity.ActualTableName = onlinePhysicalTableRes.Table.Name
|
|
onlineEntity.ID = onlineID
|
|
onlineEntity.DraftID = ptr.Of(draftID)
|
|
objURL, uRrr := d.storage.GetObjectUrl(ctx, onlineEntity.IconURI)
|
|
if uRrr == nil {
|
|
onlineEntity.IconURL = objURL
|
|
}
|
|
|
|
return &CreateDatabaseResponse{
|
|
Database: onlineEntity,
|
|
}, nil
|
|
}
|
|
|
|
func (d databaseService) UpdateDatabase(ctx context.Context, req *UpdateDatabaseRequest) (*UpdateDatabaseResponse, error) {
|
|
// req.Database.ID is the id of online database
|
|
input := req.Database
|
|
onlineInfo, err := d.onlineDAO.Get(ctx, req.Database.ID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get online database info failed: %v", err)
|
|
}
|
|
|
|
draftInfo, err := d.draftDAO.Get(ctx, onlineInfo.GetDraftID())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get draft database info failed: %v", err)
|
|
}
|
|
|
|
draftEntity, onlineEntity := *input, *input
|
|
|
|
draftEntity.ID = draftInfo.ID
|
|
onlineEntity.ID = onlineInfo.ID
|
|
|
|
fieldItems, columns, droppedColumns, err := physicaltable.UpdateFieldInfo(input.FieldList, onlineInfo.FieldList)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
draftEntity.FieldList = fieldItems
|
|
onlineEntity.FieldList = fieldItems
|
|
|
|
// get draft and online physical table info
|
|
draftPhysicalTable, err := d.rdb.GetTable(ctx, &rdb.GetTableRequest{
|
|
TableName: draftInfo.ActualTableName,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get physical table info failed: %v", err)
|
|
}
|
|
|
|
onlinePhysicalTable, err := d.rdb.GetTable(ctx, &rdb.GetTableRequest{
|
|
TableName: onlineInfo.ActualTableName,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get physical table info failed: %v", err)
|
|
}
|
|
|
|
err = physicaltable.UpdatePhysicalTableWithDrops(ctx, d.rdb, draftPhysicalTable.Table, columns, droppedColumns, draftInfo.ActualTableName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("update draft physical table failed: %v", err)
|
|
}
|
|
|
|
err = physicaltable.UpdatePhysicalTableWithDrops(ctx, d.rdb, onlinePhysicalTable.Table, columns, droppedColumns, onlineInfo.ActualTableName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("update online physical table failed: %v", err)
|
|
}
|
|
|
|
tx := query.Use(d.db).Begin()
|
|
if tx.Error != nil {
|
|
return nil, fmt.Errorf("start transaction failed, %v", tx.Error)
|
|
}
|
|
|
|
if draftEntity.IconURI == "" {
|
|
draftEntity.IconURI = consts.DefaultDatabaseIcon
|
|
}
|
|
if onlineEntity.IconURI == "" {
|
|
onlineEntity.IconURI = consts.DefaultDatabaseIcon
|
|
}
|
|
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
e := tx.Rollback()
|
|
if e != nil {
|
|
logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
|
|
}
|
|
|
|
err = fmt.Errorf("catch panic: %v\nstack=%s", r, string(debug.Stack()))
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
e := tx.Rollback()
|
|
if e != nil {
|
|
logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
|
|
}
|
|
}
|
|
}()
|
|
|
|
_, err = d.draftDAO.UpdateWithTX(ctx, tx, &draftEntity)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("update draft database info failed: %v", err)
|
|
}
|
|
|
|
onlineEntityUpdated, err := d.onlineDAO.UpdateWithTX(ctx, tx, &onlineEntity)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("update online database info failed: %v", err)
|
|
}
|
|
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("commit transaction failed: %v", err)
|
|
}
|
|
|
|
return &UpdateDatabaseResponse{
|
|
Database: onlineEntityUpdated,
|
|
}, nil
|
|
}
|
|
|
|
func (d databaseService) DeleteDatabase(ctx context.Context, req *DeleteDatabaseRequest) error {
|
|
onlineInfo, err := d.onlineDAO.Get(ctx, req.ID)
|
|
if err != nil {
|
|
return fmt.Errorf("get online database info failed: %v", err)
|
|
}
|
|
|
|
draftInfo, err := d.draftDAO.Get(ctx, onlineInfo.GetDraftID())
|
|
if err != nil {
|
|
return fmt.Errorf("get draft database info failed: %v", err)
|
|
}
|
|
|
|
tx := query.Use(d.db).Begin()
|
|
if tx.Error != nil {
|
|
return fmt.Errorf("start transaction failed, %v", tx.Error)
|
|
}
|
|
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
e := tx.Rollback()
|
|
if e != nil {
|
|
logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
|
|
}
|
|
|
|
err = fmt.Errorf("catch panic: %v\nstack=%s", r, string(debug.Stack()))
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
e := tx.Rollback()
|
|
if e != nil {
|
|
logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
|
|
}
|
|
}
|
|
}()
|
|
|
|
err = d.draftDAO.DeleteWithTX(ctx, tx, draftInfo.ID)
|
|
if err != nil {
|
|
return fmt.Errorf("delete draft database info failed: %v", err)
|
|
}
|
|
|
|
err = d.onlineDAO.DeleteWithTX(ctx, tx, onlineInfo.ID)
|
|
if err != nil {
|
|
return fmt.Errorf("delete online database info failed: %v", err)
|
|
}
|
|
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
return fmt.Errorf("commit transaction failed: %v", err)
|
|
}
|
|
|
|
// delete draft physical table
|
|
if draftInfo.ActualTableName != "" {
|
|
_, err = d.rdb.DropTable(ctx, &rdb.DropTableRequest{
|
|
TableName: draftInfo.ActualTableName,
|
|
})
|
|
if err != nil {
|
|
logs.Errorf("drop draft physical table failed: %v, table_name=%s", err, draftInfo.ActualTableName)
|
|
}
|
|
}
|
|
|
|
// delete online physical table
|
|
if onlineInfo.ActualTableName != "" {
|
|
_, err = d.rdb.DropTable(ctx, &rdb.DropTableRequest{
|
|
TableName: onlineInfo.ActualTableName,
|
|
})
|
|
if err != nil {
|
|
logs.Errorf("drop online physical table failed: %v, table_name=%s", err, onlineInfo.ActualTableName)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d databaseService) MGetDatabase(ctx context.Context, req *MGetDatabaseRequest) (*MGetDatabaseResponse, error) {
|
|
if len(req.Basics) == 0 {
|
|
return &MGetDatabaseResponse{
|
|
Databases: []*entity2.Database{},
|
|
}, nil
|
|
}
|
|
|
|
onlineID2NeedSysFields := make(map[int64]bool)
|
|
draftID2NeedSysFields := make(map[int64]bool)
|
|
|
|
uniqueOnlineIDs := make([]int64, 0)
|
|
uniqueDraftIDs := make([]int64, 0)
|
|
idMap := make(map[int64]bool)
|
|
for _, basic := range req.Basics {
|
|
if !idMap[basic.ID] {
|
|
idMap[basic.ID] = true
|
|
if basic.TableType == table.TableType_OnlineTable {
|
|
uniqueOnlineIDs = append(uniqueOnlineIDs, basic.ID)
|
|
onlineID2NeedSysFields[basic.ID] = basic.NeedSysFields
|
|
} else {
|
|
uniqueDraftIDs = append(uniqueDraftIDs, basic.ID)
|
|
draftID2NeedSysFields[basic.ID] = basic.NeedSysFields
|
|
}
|
|
}
|
|
}
|
|
|
|
onlineDatabases, err := d.onlineDAO.MGet(ctx, uniqueOnlineIDs)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("batch get database info failed: %v", err)
|
|
}
|
|
|
|
draftDatabases, err := d.draftDAO.MGet(ctx, uniqueDraftIDs)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("batch get database info failed: %v", err)
|
|
}
|
|
|
|
for _, onlineDatabase := range onlineDatabases {
|
|
if needSys, ok := onlineID2NeedSysFields[onlineDatabase.ID]; ok && needSys {
|
|
if onlineDatabase.FieldList == nil {
|
|
onlineDatabase.FieldList = make([]*database.FieldItem, 0, 3)
|
|
}
|
|
onlineDatabase.FieldList = append(onlineDatabase.FieldList, physicaltable.GetDisplayCreateTimeField(), physicaltable.GetDisplayUidField(), physicaltable.GetDisplayIDField())
|
|
}
|
|
if onlineDatabase.IconURI != "" {
|
|
objURL, uRrr := d.storage.GetObjectUrl(ctx, onlineDatabase.IconURI)
|
|
if uRrr == nil {
|
|
onlineDatabase.IconURL = objURL
|
|
}
|
|
}
|
|
}
|
|
for _, draftDatabase := range draftDatabases {
|
|
if needSys, ok := draftID2NeedSysFields[draftDatabase.ID]; ok && needSys {
|
|
if draftDatabase.FieldList == nil {
|
|
draftDatabase.FieldList = make([]*database.FieldItem, 0, 3)
|
|
}
|
|
draftDatabase.FieldList = append(draftDatabase.FieldList, physicaltable.GetDisplayCreateTimeField(), physicaltable.GetDisplayUidField(), physicaltable.GetDisplayIDField())
|
|
}
|
|
if draftDatabase.IconURI != "" {
|
|
objURL, uRrr := d.storage.GetObjectUrl(ctx, draftDatabase.IconURI)
|
|
if uRrr == nil {
|
|
draftDatabase.IconURL = objURL
|
|
}
|
|
}
|
|
}
|
|
|
|
databases := make([]*entity2.Database, 0)
|
|
databases = append(databases, onlineDatabases...)
|
|
databases = append(databases, draftDatabases...)
|
|
|
|
return &MGetDatabaseResponse{
|
|
Databases: databases,
|
|
}, nil
|
|
}
|
|
|
|
func (d databaseService) ListDatabase(ctx context.Context, req *ListDatabaseRequest) (*ListDatabaseResponse, error) {
|
|
filter := &entity2.DatabaseFilter{
|
|
CreatorID: req.CreatorID,
|
|
SpaceID: req.SpaceID,
|
|
TableName: req.TableName,
|
|
AppID: &req.AppID,
|
|
}
|
|
|
|
page := &entity2.Pagination{
|
|
Limit: req.Limit,
|
|
Offset: req.Offset,
|
|
}
|
|
|
|
var databases []*entity2.Database
|
|
var err error
|
|
var count int64
|
|
if req.TableType == table.TableType_OnlineTable {
|
|
databases, count, err = d.onlineDAO.List(ctx, filter, page, req.OrderBy)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list database failed: %v", err)
|
|
}
|
|
} else {
|
|
databases, count, err = d.draftDAO.List(ctx, filter, page, req.OrderBy)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list database failed: %v", err)
|
|
}
|
|
}
|
|
|
|
for _, database := range databases {
|
|
if database.IconURI != "" {
|
|
objURL, uRrr := d.storage.GetObjectUrl(ctx, database.IconURI)
|
|
if uRrr == nil {
|
|
database.IconURL = objURL
|
|
}
|
|
}
|
|
}
|
|
|
|
var hasMore bool
|
|
if count <= int64(req.Limit)+int64(req.Offset) {
|
|
hasMore = false
|
|
} else {
|
|
hasMore = true
|
|
}
|
|
|
|
return &ListDatabaseResponse{
|
|
Databases: databases,
|
|
HasMore: hasMore,
|
|
TotalCount: count,
|
|
}, nil
|
|
}
|
|
|
|
func (d databaseService) AddDatabaseRecord(ctx context.Context, req *AddDatabaseRecordRequest) error {
|
|
var tableInfo *entity2.Database
|
|
var err error
|
|
|
|
if req.TableType == table.TableType_OnlineTable {
|
|
tableInfo, err = d.onlineDAO.Get(ctx, req.DatabaseID)
|
|
} else {
|
|
tableInfo, err = d.draftDAO.Get(ctx, req.DatabaseID)
|
|
}
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("get table info failed: %v", err)
|
|
}
|
|
|
|
if tableInfo.RwMode == table.BotTableRWMode_ReadOnly {
|
|
return errorx.New(errno.ErrMemoryDatabaseCannotAddData)
|
|
}
|
|
|
|
physicalTableName := tableInfo.ActualTableName
|
|
if physicalTableName == "" {
|
|
return fmt.Errorf("physical table name is empty")
|
|
}
|
|
|
|
fieldList := append(tableInfo.FieldList, physicaltable.GetCreateTimeField(), physicaltable.GetUidField(), physicaltable.GetIDField(), physicaltable.GetConnectIDField())
|
|
fieldMap := slices.ToMap(fieldList, func(e *database.FieldItem) (string, *database.FieldItem) {
|
|
return e.Name, e
|
|
})
|
|
|
|
convertedRecords := make([]map[string]interface{}, 0, len(req.Records))
|
|
ids, err := d.generator.GenMultiIDs(ctx, len(req.Records))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for index, recordMap := range req.Records {
|
|
convertedRecord := make(map[string]interface{})
|
|
|
|
cid := consts.CozeConnectorID
|
|
if req.ConnectorID != nil {
|
|
cid = *req.ConnectorID
|
|
}
|
|
convertedRecord[database.DefaultUidColName] = req.UserID
|
|
convertedRecord[database.DefaultCidColName] = cid
|
|
convertedRecord[database.DefaultCreateTimeColName] = time.Now()
|
|
convertedRecord[database.DefaultIDColName] = ids[index]
|
|
|
|
if _, ok := recordMap[database.DefaultIDColName]; ok {
|
|
delete(recordMap, database.DefaultIDColName)
|
|
}
|
|
|
|
for fieldName, value := range recordMap {
|
|
if _, fOk := fieldMap[fieldName]; !fOk {
|
|
return errorx.New(errno.ErrMemoryDatabaseFieldNotFoundCode, errorx.KV("msg", fmt.Sprintf("field %s not found in table definition", fieldName)))
|
|
}
|
|
|
|
fieldInfo, _ := fieldMap[fieldName]
|
|
if value == "" && fieldInfo.MustRequired {
|
|
return fmt.Errorf("field %s's value is required", fieldName)
|
|
}
|
|
|
|
physicalFieldName := fieldInfo.PhysicalName
|
|
convertedValue, err := convertor.ConvertValueByType(value, fieldInfo.Type)
|
|
if err != nil {
|
|
return fmt.Errorf("convert value failed for field %s: %v, using original value", fieldName, err)
|
|
}
|
|
|
|
convertedRecord[physicalFieldName] = convertedValue
|
|
}
|
|
|
|
convertedRecords = append(convertedRecords, convertedRecord)
|
|
}
|
|
|
|
_, err = d.rdb.InsertData(ctx, &rdb.InsertDataRequest{
|
|
TableName: physicalTableName,
|
|
Data: convertedRecords,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("insert data failed: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d databaseService) UpdateDatabaseRecord(ctx context.Context, req *UpdateDatabaseRecordRequest) error {
|
|
var tableInfo *database.Database
|
|
var err error
|
|
|
|
if req.TableType == table.TableType_OnlineTable {
|
|
tableInfo, err = d.onlineDAO.Get(ctx, req.DatabaseID)
|
|
} else {
|
|
tableInfo, err = d.draftDAO.Get(ctx, req.DatabaseID)
|
|
}
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("get table info failed: %v", err)
|
|
}
|
|
|
|
if tableInfo.RwMode == table.BotTableRWMode_ReadOnly {
|
|
return errorx.New(errno.ErrMemoryDatabaseCannotAddData)
|
|
}
|
|
|
|
physicalTableName := tableInfo.ActualTableName
|
|
if physicalTableName == "" {
|
|
return fmt.Errorf("physical table name is empty")
|
|
}
|
|
|
|
fieldList := append(tableInfo.FieldList, physicaltable.GetCreateTimeField(), physicaltable.GetUidField(), physicaltable.GetIDField(), physicaltable.GetConnectIDField())
|
|
fieldMap := slices.ToMap(fieldList, func(e *database.FieldItem) (string, *database.FieldItem) {
|
|
return e.Name, e
|
|
})
|
|
|
|
for _, record := range req.Records {
|
|
idStr, exists := record[database.DefaultIDColName]
|
|
if !exists {
|
|
return fmt.Errorf("record must contain %s field for update", database.DefaultIDColName)
|
|
}
|
|
|
|
id, err := strconv.ParseInt(idStr, 10, 64)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid ID format: %v", err)
|
|
}
|
|
|
|
updateData := make(map[string]interface{})
|
|
|
|
for fieldName, valueStr := range record {
|
|
if fieldName == database.DefaultIDColName {
|
|
continue
|
|
}
|
|
|
|
if _, fOk := fieldMap[fieldName]; !fOk {
|
|
return errorx.New(errno.ErrMemoryDatabaseFieldNotFoundCode, errorx.KV("msg", fmt.Sprintf("field %s not found in table definition", fieldName)))
|
|
}
|
|
|
|
fieldInfo, _ := fieldMap[fieldName]
|
|
if valueStr == "" && fieldInfo.MustRequired {
|
|
return fmt.Errorf("field %s's value is required", fieldName)
|
|
}
|
|
|
|
physicalFieldName := fieldInfo.PhysicalName
|
|
convertedValue, err := convertor.ConvertValueByType(valueStr, fieldInfo.Type)
|
|
if err != nil {
|
|
logs.Warnf("convert value failed for field %s: %v, using original value", fieldName, err)
|
|
convertedValue = valueStr
|
|
}
|
|
updateData[physicalFieldName] = convertedValue
|
|
}
|
|
|
|
if len(updateData) == 0 {
|
|
continue
|
|
}
|
|
|
|
condition := &rdb.ComplexCondition{
|
|
Conditions: []*rdb.Condition{
|
|
{
|
|
Field: database.DefaultIDColName,
|
|
Operator: entity3.OperatorEqual,
|
|
Value: id,
|
|
},
|
|
},
|
|
}
|
|
|
|
if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite {
|
|
cond := &rdb.Condition{
|
|
Field: database.DefaultUidColName,
|
|
Operator: entity3.OperatorEqual,
|
|
Value: strconv.FormatInt(req.UserID, 10),
|
|
}
|
|
|
|
condition.Conditions = append(condition.Conditions, cond)
|
|
}
|
|
|
|
_, err = d.rdb.UpdateData(ctx, &rdb.UpdateDataRequest{
|
|
TableName: physicalTableName,
|
|
Data: updateData,
|
|
Where: condition,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("update data failed for ID %d: %v", id, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d databaseService) DeleteDatabaseRecord(ctx context.Context, req *DeleteDatabaseRecordRequest) error {
|
|
var tableInfo *entity2.Database
|
|
var err error
|
|
|
|
if req.TableType == table.TableType_OnlineTable {
|
|
tableInfo, err = d.onlineDAO.Get(ctx, req.DatabaseID)
|
|
} else {
|
|
tableInfo, err = d.draftDAO.Get(ctx, req.DatabaseID)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if tableInfo.RwMode == table.BotTableRWMode_ReadOnly {
|
|
return errorx.New(errno.ErrMemoryDatabaseCannotAddData)
|
|
}
|
|
|
|
physicalTableName := tableInfo.ActualTableName
|
|
if physicalTableName == "" {
|
|
return fmt.Errorf("physical table name is empty")
|
|
}
|
|
|
|
var ids []interface{}
|
|
for _, record := range req.Records {
|
|
idStr, exists := record[database.DefaultIDColName]
|
|
if !exists {
|
|
return fmt.Errorf("record must contain %s field for deletion", database.DefaultIDColName)
|
|
}
|
|
|
|
id, err := strconv.ParseInt(idStr, 10, 64)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid ID format: %v", err)
|
|
}
|
|
|
|
ids = append(ids, id)
|
|
}
|
|
|
|
condition := &rdb.ComplexCondition{
|
|
Conditions: []*rdb.Condition{
|
|
{
|
|
Field: database.DefaultIDColName,
|
|
Operator: entity3.OperatorIn,
|
|
Value: ids,
|
|
},
|
|
},
|
|
}
|
|
|
|
if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite {
|
|
cond := &rdb.Condition{
|
|
Field: database.DefaultUidColName,
|
|
Operator: entity3.OperatorEqual,
|
|
Value: strconv.FormatInt(req.UserID, 10),
|
|
}
|
|
|
|
condition.Conditions = append(condition.Conditions, cond)
|
|
}
|
|
|
|
_, err = d.rdb.DeleteData(ctx, &rdb.DeleteDataRequest{
|
|
TableName: physicalTableName,
|
|
Where: condition,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("delete data failed: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d databaseService) ListDatabaseRecord(ctx context.Context, req *ListDatabaseRecordRequest) (*ListDatabaseRecordResponse, error) {
|
|
var tableInfo *entity2.Database
|
|
var err error
|
|
|
|
if req.TableType == table.TableType_OnlineTable {
|
|
tableInfo, err = d.onlineDAO.Get(ctx, req.DatabaseID)
|
|
} else {
|
|
tableInfo, err = d.draftDAO.Get(ctx, req.DatabaseID)
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get table info failed: %v", err)
|
|
}
|
|
|
|
physicalTableName := tableInfo.ActualTableName
|
|
if physicalTableName == "" {
|
|
return nil, fmt.Errorf("physical table name is empty")
|
|
}
|
|
|
|
fieldNameToPhysical := make(map[string]string)
|
|
physicalToFieldName := make(map[string]string)
|
|
physicalToFieldType := make(map[string]table.FieldItemType)
|
|
|
|
for _, field := range tableInfo.FieldList {
|
|
if field.AlterID > 0 {
|
|
physicalName := physicaltable.GetFieldPhysicsName(field.AlterID)
|
|
fieldNameToPhysical[field.Name] = physicalName
|
|
physicalToFieldName[physicalName] = field.Name
|
|
physicalToFieldType[physicalName] = field.Type
|
|
}
|
|
}
|
|
|
|
var complexCondition *rdb.ComplexCondition
|
|
|
|
if req.ConnectorID != nil && *req.ConnectorID > 0 {
|
|
cond := &rdb.Condition{
|
|
Field: database.DefaultCidColName,
|
|
Operator: entity3.OperatorEqual,
|
|
Value: *req.ConnectorID,
|
|
}
|
|
|
|
complexCondition = &rdb.ComplexCondition{
|
|
Conditions: []*rdb.Condition{cond},
|
|
}
|
|
}
|
|
if req.TableType == table.TableType_DraftTable {
|
|
if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite {
|
|
cond := &rdb.Condition{
|
|
Field: database.DefaultUidColName,
|
|
Operator: entity3.OperatorEqual,
|
|
Value: strconv.FormatInt(req.UserID, 10),
|
|
}
|
|
|
|
if complexCondition == nil {
|
|
complexCondition = &rdb.ComplexCondition{
|
|
Conditions: []*rdb.Condition{cond},
|
|
}
|
|
} else {
|
|
complexCondition.Conditions = append(complexCondition.Conditions, cond)
|
|
}
|
|
}
|
|
}
|
|
|
|
limit := 50
|
|
if req.Limit > 0 {
|
|
limit = req.Limit
|
|
}
|
|
|
|
orderBy := []*rdb.OrderBy{
|
|
{
|
|
Field: database.DefaultCreateTimeColName,
|
|
Direction: entity3.SortDirectionDesc,
|
|
},
|
|
}
|
|
|
|
selectResp, err := d.rdb.SelectData(ctx, &rdb.SelectDataRequest{
|
|
TableName: physicalTableName,
|
|
Fields: []string{}, // Null means query all fields
|
|
Where: complexCondition,
|
|
OrderBy: orderBy,
|
|
Limit: &limit,
|
|
Offset: &req.Offset,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("select data failed: %v", err)
|
|
}
|
|
|
|
if selectResp.ResultSet == nil {
|
|
return &ListDatabaseRecordResponse{}, nil
|
|
}
|
|
|
|
records := convertor.ConvertResultSetToString(selectResp.ResultSet, physicalToFieldName, physicalToFieldType)
|
|
|
|
var hasMore bool
|
|
if selectResp.Total <= int64(req.Limit)+int64(req.Offset) {
|
|
hasMore = false
|
|
} else {
|
|
hasMore = true
|
|
}
|
|
|
|
return &ListDatabaseRecordResponse{
|
|
Records: records,
|
|
FieldList: tableInfo.FieldList,
|
|
HasMore: hasMore,
|
|
TotalCount: selectResp.Total,
|
|
}, nil
|
|
}
|
|
|
|
func (d databaseService) GetDatabaseTemplate(ctx context.Context, req *GetDatabaseTemplateRequest) (*GetDatabaseTemplateResponse, error) {
|
|
items := req.FieldItems
|
|
tableName := req.TableName
|
|
|
|
file := xlsx.NewFile()
|
|
sheet, err := file.AddSheet("Sheet1")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// add header
|
|
header := sheet.AddRow()
|
|
headerTitles := make([]string, 0)
|
|
for i := range items {
|
|
headerTitles = append(headerTitles, items[i].GetName())
|
|
}
|
|
for _, title := range headerTitles {
|
|
cell := header.AddCell()
|
|
cell.Value = title
|
|
}
|
|
|
|
row := sheet.AddRow()
|
|
for _, item := range items {
|
|
row.AddCell().Value = physicaltable.GetTemplateTypeMap()[item.GetType()]
|
|
}
|
|
var buffer bytes.Buffer
|
|
err = file.Write(&buffer)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
binaryData := buffer.Bytes()
|
|
url, err := d.uploadFile(ctx, req.UserID, string(binaryData), tableName, "xlsx", nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &GetDatabaseTemplateResponse{
|
|
Url: url,
|
|
}, nil
|
|
}
|
|
|
|
func (d databaseService) uploadFile(ctx context.Context, UserId int64, content string, bizType, fileType string, suffix *string) (string, error) {
|
|
secret := createSecret(UserId, fileType)
|
|
fileName := fmt.Sprintf("%d_%d_%s.%s", UserId, time.Now().UnixNano(), secret, fileType)
|
|
if suffix != nil {
|
|
fileName = fmt.Sprintf("%d_%d_%s_%s.%s", UserId, time.Now().UnixNano(), secret, *suffix, fileType)
|
|
}
|
|
|
|
objectName := fmt.Sprintf("%s/%s", bizType, fileName)
|
|
err := d.storage.PutObject(ctx, objectName, []byte(content))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
url, err := d.storage.GetObjectUrl(ctx, objectName)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return url, nil
|
|
}
|
|
|
|
// baseWord is the 62-character alphabet (digits, upper- and lower-case ASCII
// letters) that createSecret maps its output into.
const baseWord = "1Aa2Bb3Cc4Dd5Ee6Ff7Gg8Hh9Ii0JjKkLlMmNnOoPpQqRrSsTtUuVvWwXxYyZz"

// createSecret returns a 10-character token drawn from baseWord. It is derived
// by SHA-256 hashing a string built from uid, the current Unix second,
// fileType and a math/rand number, then base64-encoding the digest and mapping
// each of the first 10 characters into baseWord.
//
// It only makes upload file names hard to guess casually; math/rand is not a
// cryptographic source.
func createSecret(uid int64, fileType string) string {
	const secretLen = 10

	seed := fmt.Sprintf("upload_%d_Ma*9)fhi_%d_gou_%s_rand_%d", uid, time.Now().Unix(), fileType, rand.Intn(100000))
	digest := sha256.Sum256([]byte(seed))
	encoded := base64.StdEncoding.EncodeToString(digest[:])
	if len(encoded) > secretLen {
		encoded = encoded[:secretLen]
	}

	// Standard base64 output is pure ASCII, so byte indexing matches rune iteration.
	out := make([]byte, 0, len(encoded))
	for i := 0; i < len(encoded); i++ {
		out = append(out, baseWord[int(encoded[i])%len(baseWord)])
	}
	return string(out)
}
|
|
|
|
func (d databaseService) ExecuteSQL(ctx context.Context, req *ExecuteSQLRequest) (*ExecuteSQLResponse, error) {
|
|
var tableInfo *entity2.Database
|
|
var err error
|
|
|
|
if req.TableType == table.TableType_OnlineTable {
|
|
tableInfo, err = d.onlineDAO.Get(ctx, req.DatabaseID)
|
|
} else {
|
|
tableInfo, err = d.draftDAO.Get(ctx, req.DatabaseID)
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get table info failed: %v", err)
|
|
}
|
|
|
|
if tableInfo.RwMode == table.BotTableRWMode_ReadOnly &&
|
|
(req.OperateType == database.OperateType_Insert || req.OperateType == database.OperateType_Update ||
|
|
req.OperateType == database.OperateType_Delete) {
|
|
return nil, errorx.New(errno.ErrMemoryDatabaseCannotAddData)
|
|
}
|
|
|
|
physicalTableName := tableInfo.ActualTableName
|
|
if physicalTableName == "" {
|
|
return nil, fmt.Errorf("physical table name is empty")
|
|
}
|
|
|
|
fieldNameToPhysical := make(map[string]string)
|
|
physicalToFieldName := make(map[string]string)
|
|
physicalToFieldType := make(map[string]table.FieldItemType)
|
|
|
|
for _, field := range tableInfo.FieldList {
|
|
if field.AlterID > 0 {
|
|
physicalName := physicaltable.GetFieldPhysicsName(field.AlterID)
|
|
fieldNameToPhysical[field.Name] = physicalName
|
|
physicalToFieldName[physicalName] = field.Name
|
|
physicalToFieldType[physicalName] = field.Type
|
|
}
|
|
}
|
|
fieldNameToPhysical[database.DefaultIDDisplayColName] = database.DefaultIDColName
|
|
fieldNameToPhysical[database.DefaultUidDisplayColName] = database.DefaultUidColName
|
|
fieldNameToPhysical[database.DefaultCreateTimeDisplayColName] = database.DefaultCreateTimeColName
|
|
|
|
var resultSet *entity3.ResultSet
|
|
var rowsAffected int64
|
|
|
|
switch req.OperateType {
|
|
case database.OperateType_Custom:
|
|
resultSet, err = d.executeCustomSQL(ctx, req, physicalTableName, tableInfo, fieldNameToPhysical)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
case database.OperateType_Select:
|
|
resultSet, err = d.executeSelectSQL(ctx, req, physicalTableName, tableInfo, fieldNameToPhysical)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
case database.OperateType_Insert:
|
|
resultSet, err = d.executeInsertSQL(ctx, req, physicalTableName, tableInfo)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
case database.OperateType_Update:
|
|
rowsAffected, err = d.executeUpdateSQL(ctx, req, physicalTableName, tableInfo, fieldNameToPhysical)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
case database.OperateType_Delete:
|
|
rowsAffected, err = d.executeDeleteSQL(ctx, req, physicalTableName, tableInfo, fieldNameToPhysical)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
default:
|
|
return nil, fmt.Errorf("unsupported operation type: %v", req.OperateType)
|
|
}
|
|
|
|
response := &ExecuteSQLResponse{
|
|
FieldList: tableInfo.FieldList,
|
|
}
|
|
|
|
if resultSet != nil && len(resultSet.Rows) > 0 {
|
|
response.Records = convertor.ConvertResultSet(resultSet, physicalToFieldName, physicalToFieldType)
|
|
} else {
|
|
response.Records = make([]map[string]interface{}, 0)
|
|
}
|
|
|
|
// process special system fields
|
|
for _, record := range response.Records {
|
|
if val, ok := record[database.DefaultUidColName]; ok {
|
|
delete(record, database.DefaultUidColName)
|
|
record[database.DefaultUidDisplayColName] = val
|
|
}
|
|
if val, ok := record[database.DefaultCreateTimeColName]; ok {
|
|
delete(record, database.DefaultCreateTimeColName)
|
|
record[database.DefaultCreateTimeDisplayColName] = val
|
|
}
|
|
if val, ok := record[database.DefaultIDColName]; ok {
|
|
delete(record, database.DefaultIDColName)
|
|
record[database.DefaultIDDisplayColName] = val
|
|
}
|
|
if _, ok := record[database.DefaultCidColName]; ok {
|
|
delete(record, database.DefaultCidColName)
|
|
}
|
|
}
|
|
|
|
if resultSet != nil && resultSet.AffectedRows > 0 {
|
|
response.RowsAffected = &resultSet.AffectedRows
|
|
}
|
|
|
|
if rowsAffected > 0 {
|
|
response.RowsAffected = &rowsAffected
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
func (d databaseService) executeCustomSQL(ctx context.Context, req *ExecuteSQLRequest, physicalTableName string, tableInfo *entity2.Database, fieldNameToPhysical map[string]string) (*entity3.ResultSet, error) {
|
|
var params []interface{}
|
|
if req.SQL == nil || *req.SQL == "" {
|
|
return nil, fmt.Errorf("SQL is empty")
|
|
}
|
|
|
|
operation, err := sqlparser.NewSQLParser().GetSQLOperation(*req.SQL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if tableInfo.RwMode == table.BotTableRWMode_ReadOnly && (operation == sqlparsercontract.OperationTypeInsert || operation == sqlparsercontract.OperationTypeUpdate || operation == sqlparsercontract.OperationTypeDelete) {
|
|
return nil, fmt.Errorf("unsupported operation type: %v", operation)
|
|
}
|
|
|
|
if req.SQLParams != nil {
|
|
params = make([]interface{}, 0, len(req.SQLParams))
|
|
for _, param := range req.SQLParams {
|
|
value := param.Value
|
|
if param.ISNull {
|
|
value = nil
|
|
}
|
|
params = append(params, value)
|
|
}
|
|
}
|
|
|
|
tableColumnMapping := map[string]sqlparsercontract.TableColumn{
|
|
tableInfo.TableName: {
|
|
NewTableName: &physicalTableName,
|
|
ColumnMap: fieldNameToPhysical,
|
|
},
|
|
}
|
|
|
|
parsedSQL, err := sqlparser.NewSQLParser().ParseAndModifySQL(*req.SQL, tableColumnMapping)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse sql failed: %v", err)
|
|
}
|
|
// add rw mode
|
|
if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite && len(req.UserID) != 0 {
|
|
switch operation {
|
|
case sqlparsercontract.OperationTypeSelect, sqlparsercontract.OperationTypeUpdate, sqlparsercontract.OperationTypeDelete:
|
|
parsedSQL, err = sqlparser.NewSQLParser().AppendSQLFilter(parsedSQL, sqlparsercontract.SQLFilterOpAnd, fmt.Sprintf("%s = '%s'", database.DefaultUidColName, req.UserID))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("append sql filter failed: %v", err)
|
|
}
|
|
}
|
|
}
|
|
insertResult := make([]map[string]interface{}, 0)
|
|
if operation == sqlparsercontract.OperationTypeInsert {
|
|
cid := consts.CozeConnectorID
|
|
if req.ConnectorID != nil {
|
|
cid = *req.ConnectorID
|
|
}
|
|
nums, err := sqlparser.NewSQLParser().GetInsertDataNums(parsedSQL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ids, err := d.generator.GenMultiIDs(ctx, nums)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, id := range ids {
|
|
insertResult = append(insertResult, map[string]interface{}{
|
|
database.DefaultIDColName: id,
|
|
})
|
|
}
|
|
|
|
existingCols := make(map[string]bool)
|
|
if req.SQLType == database.SQLType_Raw {
|
|
iIDs := make([]interface{}, len(ids))
|
|
for i, id := range ids {
|
|
iIDs[i] = id
|
|
}
|
|
parsedSQL, _, err = sqlparser.NewSQLParser().AddColumnsToInsertSQL(parsedSQL, []sqlparsercontract.ColumnValue{
|
|
{
|
|
ColName: database.DefaultCidColName,
|
|
Value: cid,
|
|
},
|
|
{
|
|
ColName: database.DefaultUidColName,
|
|
Value: req.UserID,
|
|
},
|
|
}, &sqlparsercontract.PrimaryKeyValue{ColName: database.DefaultIDColName, Values: iIDs}, false)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("add columns to insert sql failed: %v", err)
|
|
}
|
|
} else if req.SQLType == database.SQLType_Parameterized {
|
|
parsedSQL, existingCols, err = sqlparser.NewSQLParser().AddColumnsToInsertSQL(parsedSQL, []sqlparsercontract.ColumnValue{
|
|
{
|
|
ColName: database.DefaultCidColName,
|
|
},
|
|
{
|
|
ColName: database.DefaultUidColName,
|
|
},
|
|
}, &sqlparsercontract.PrimaryKeyValue{ColName: database.DefaultIDColName}, true)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("add columns to insert sql failed: %v", err)
|
|
}
|
|
|
|
if nums > 0 {
|
|
if len(params)%nums != 0 {
|
|
return nil, fmt.Errorf("number of params is not a multiple of number of rows")
|
|
}
|
|
paramsPerRow := len(params) / nums
|
|
newParams := make([]interface{}, 0)
|
|
for i := 0; i < nums; i++ {
|
|
newParams = append(newParams, params[i*paramsPerRow:(i+1)*paramsPerRow]...)
|
|
if !existingCols[database.DefaultCidColName] {
|
|
newParams = append(newParams, cid)
|
|
}
|
|
if !existingCols[database.DefaultUidColName] {
|
|
newParams = append(newParams, req.UserID)
|
|
}
|
|
if !existingCols[database.DefaultIDColName] {
|
|
newParams = append(newParams, ids[i])
|
|
}
|
|
}
|
|
params = newParams
|
|
}
|
|
}
|
|
}
|
|
|
|
execResp, err := d.rdb.ExecuteSQL(ctx, &rdb.ExecuteSQLRequest{
|
|
SQL: parsedSQL,
|
|
Params: params,
|
|
|
|
SQLType: entity3.SQLType(req.SQLType),
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("execute SQL failed: %v", err)
|
|
}
|
|
|
|
if operation == sqlparsercontract.OperationTypeInsert {
|
|
if execResp.ResultSet == nil {
|
|
execResp.ResultSet = &entity3.ResultSet{
|
|
Rows: insertResult,
|
|
}
|
|
} else {
|
|
execResp.ResultSet.Rows = insertResult
|
|
}
|
|
}
|
|
return execResp.ResultSet, nil
|
|
}
|
|
|
|
func (d databaseService) executeSelectSQL(ctx context.Context, req *ExecuteSQLRequest, physicalTableName string, tableInfo *entity2.Database, fieldNameToPhysical map[string]string) (*entity3.ResultSet, error) {
|
|
selectReq := &rdb.SelectDataRequest{
|
|
TableName: physicalTableName,
|
|
Limit: int64PtrToIntPtr(req.Limit),
|
|
Offset: int64PtrToIntPtr(req.Offset),
|
|
}
|
|
|
|
fieldList := append(tableInfo.FieldList, physicaltable.GetCreateTimeField(), physicaltable.GetUidField(), physicaltable.GetIDField(), physicaltable.GetConnectIDField())
|
|
fieldMap := slices.ToMap(fieldList, func(e *database.FieldItem) (string, *database.FieldItem) {
|
|
return strconv.FormatInt(e.AlterID, 10), e
|
|
})
|
|
|
|
if req.SelectFieldList != nil && !req.SelectFieldList.IsDistinct && len(req.SelectFieldList.FieldID) > 0 {
|
|
fields := make([]string, 0, len(req.SelectFieldList.FieldID))
|
|
for _, fieldID := range req.SelectFieldList.FieldID {
|
|
if _, exists := fieldMap[fieldID]; !exists {
|
|
return nil, fmt.Errorf("fieldID %s does not exist", fieldID)
|
|
}
|
|
|
|
field, _ := fieldMap[fieldID]
|
|
fields = append(fields, field.PhysicalName)
|
|
}
|
|
selectReq.Fields = fields
|
|
}
|
|
|
|
var complexCond *rdb.ComplexCondition
|
|
var err error
|
|
if req.Condition != nil {
|
|
complexCond, err = convertCondition(ctx, req.Condition, fieldNameToPhysical, req.SQLParams)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("convert condition failed: %v", err)
|
|
}
|
|
}
|
|
|
|
// add rw mode
|
|
if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite && req.UserID != "" {
|
|
cond := &rdb.Condition{
|
|
Field: database.DefaultUidColName,
|
|
Operator: entity3.OperatorEqual,
|
|
Value: req.UserID,
|
|
}
|
|
|
|
if complexCond == nil {
|
|
complexCond = &rdb.ComplexCondition{
|
|
Conditions: []*rdb.Condition{cond},
|
|
}
|
|
} else {
|
|
complexCond.Conditions = append(complexCond.Conditions, cond)
|
|
}
|
|
}
|
|
|
|
if complexCond != nil {
|
|
selectReq.Where = complexCond
|
|
}
|
|
|
|
if len(req.OrderByList) > 0 {
|
|
orderBy := make([]*rdb.OrderBy, 0, len(req.OrderByList))
|
|
for _, order := range req.OrderByList {
|
|
physicalField := order.Field
|
|
if mapped, exists := fieldNameToPhysical[order.Field]; exists {
|
|
physicalField = mapped
|
|
}
|
|
|
|
orderBy = append(orderBy, &rdb.OrderBy{
|
|
Field: physicalField,
|
|
Direction: convertSortDirection(order.Direction),
|
|
})
|
|
}
|
|
selectReq.OrderBy = orderBy
|
|
}
|
|
|
|
selectResp, err := d.rdb.SelectData(ctx, selectReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("select data failed: %v", err)
|
|
}
|
|
|
|
return selectResp.ResultSet, nil
|
|
}
|
|
|
|
func (d databaseService) executeInsertSQL(ctx context.Context, req *ExecuteSQLRequest, physicalTableName string, tableInfo *entity2.Database) (*entity3.ResultSet, error) {
|
|
if len(req.UpsertRows) == 0 {
|
|
return nil, fmt.Errorf("no data to insert")
|
|
}
|
|
|
|
insertData := make([]map[string]interface{}, 0, len(req.UpsertRows))
|
|
ids, err := d.generator.GenMultiIDs(ctx, len(req.UpsertRows))
|
|
if err != nil {
|
|
return nil, errorx.WrapByCode(err, errno.ErrMemoryIDGenFailCode, errorx.KV("msg", "executeInsertSQL"))
|
|
}
|
|
|
|
fieldList := append(tableInfo.FieldList, physicaltable.GetCreateTimeField(), physicaltable.GetUidField(), physicaltable.GetIDField(), physicaltable.GetConnectIDField())
|
|
fieldMap := slices.ToMap(fieldList, func(e *database.FieldItem) (string, *database.FieldItem) {
|
|
return strconv.FormatInt(e.AlterID, 10), e
|
|
})
|
|
|
|
sqlParams := req.SQLParams
|
|
i := 0
|
|
|
|
insertResult := make([]map[string]interface{}, 0, len(req.UpsertRows))
|
|
for index, upsertRow := range req.UpsertRows {
|
|
rowData := make(map[string]interface{})
|
|
|
|
cid := consts.CozeConnectorID
|
|
if req.ConnectorID != nil {
|
|
cid = *req.ConnectorID
|
|
}
|
|
|
|
if req.UserID != "" {
|
|
rowData[database.DefaultUidColName] = req.UserID
|
|
}
|
|
rowData[database.DefaultCidColName] = cid
|
|
rowData[database.DefaultCreateTimeColName] = time.Now()
|
|
rowData[database.DefaultIDColName] = ids[index]
|
|
|
|
for _, record := range upsertRow.Records {
|
|
field, exists := fieldMap[record.FieldId]
|
|
if !exists {
|
|
return nil, errorx.New(errno.ErrMemoryDatabaseFieldNotFoundCode)
|
|
}
|
|
|
|
fieldVal := sqlParams[i].Value
|
|
if sqlParams[i].ISNull || fieldVal == nil {
|
|
rowData[field.PhysicalName] = nil
|
|
i++
|
|
continue
|
|
}
|
|
|
|
convertedValue, err := convertor.ConvertValueByType(*fieldVal, field.Type)
|
|
if err != nil {
|
|
logs.Warnf("convert value failed: %v, using original value", err)
|
|
rowData[field.PhysicalName] = *fieldVal
|
|
} else {
|
|
rowData[field.PhysicalName] = convertedValue
|
|
}
|
|
i++
|
|
}
|
|
|
|
insertData = append(insertData, rowData)
|
|
insertResult = append(insertResult, map[string]interface{}{
|
|
database.DefaultIDColName: ids[index],
|
|
})
|
|
}
|
|
|
|
insertResp, err := d.rdb.InsertData(ctx, &rdb.InsertDataRequest{
|
|
TableName: physicalTableName,
|
|
Data: insertData,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("insert data failed: %v", err)
|
|
}
|
|
|
|
return &entity3.ResultSet{
|
|
Rows: insertResult,
|
|
AffectedRows: insertResp.AffectedRows,
|
|
}, nil
|
|
}
|
|
|
|
func (d databaseService) executeUpdateSQL(ctx context.Context, req *ExecuteSQLRequest, physicalTableName string, tableInfo *entity2.Database, fieldNameToPhysical map[string]string) (int64, error) {
|
|
if len(req.UpsertRows) == 0 || req.Condition == nil {
|
|
return -1, fmt.Errorf("missing update data or condition")
|
|
}
|
|
|
|
fieldList := append(tableInfo.FieldList, physicaltable.GetCreateTimeField(), physicaltable.GetUidField(), physicaltable.GetIDField(), physicaltable.GetConnectIDField())
|
|
fieldMap := slices.ToMap(fieldList, func(e *database.FieldItem) (string, *database.FieldItem) {
|
|
return strconv.FormatInt(e.AlterID, 10), e
|
|
})
|
|
|
|
updateData := make(map[string]interface{})
|
|
index := 0
|
|
for _, record := range req.UpsertRows[0].Records {
|
|
field, exists := fieldMap[record.FieldId]
|
|
if !exists {
|
|
return -1, errorx.New(errno.ErrMemoryDatabaseFieldNotFoundCode)
|
|
}
|
|
|
|
param := req.SQLParams[index]
|
|
fieldVal := param.Value
|
|
index++
|
|
if param.ISNull || fieldVal == nil {
|
|
updateData[field.PhysicalName] = nil
|
|
continue
|
|
}
|
|
|
|
convertedValue, err := convertor.ConvertValueByType(*fieldVal, field.Type)
|
|
if err != nil {
|
|
logs.Warnf("convert value failed: %v, using original value", err)
|
|
updateData[field.PhysicalName] = *fieldVal
|
|
} else {
|
|
updateData[field.PhysicalName] = convertedValue
|
|
}
|
|
}
|
|
|
|
condParams := req.SQLParams[index:]
|
|
complexCond, err := convertCondition(ctx, req.Condition, fieldNameToPhysical, condParams)
|
|
if err != nil {
|
|
return -1, fmt.Errorf("convert condition failed: %v", err)
|
|
}
|
|
|
|
// add rw mode
|
|
if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite && req.UserID != "" {
|
|
cond := &rdb.Condition{
|
|
Field: database.DefaultUidColName,
|
|
Operator: entity3.OperatorEqual,
|
|
Value: req.UserID,
|
|
}
|
|
|
|
if complexCond == nil {
|
|
complexCond = &rdb.ComplexCondition{
|
|
Conditions: []*rdb.Condition{cond},
|
|
}
|
|
} else {
|
|
complexCond.Conditions = append(complexCond.Conditions, cond)
|
|
}
|
|
}
|
|
|
|
updateResp, err := d.rdb.UpdateData(ctx, &rdb.UpdateDataRequest{
|
|
TableName: physicalTableName,
|
|
Data: updateData,
|
|
Where: complexCond,
|
|
Limit: int64PtrToIntPtr(req.Limit),
|
|
})
|
|
if err != nil {
|
|
return -1, fmt.Errorf("update data failed: %v", err)
|
|
}
|
|
|
|
return updateResp.AffectedRows, nil
|
|
}
|
|
|
|
func (d databaseService) executeDeleteSQL(ctx context.Context, req *ExecuteSQLRequest, physicalTableName string, tableInfo *entity2.Database, fieldNameToPhysical map[string]string) (int64, error) {
|
|
if req.Condition == nil {
|
|
return -1, fmt.Errorf("missing delete condition")
|
|
}
|
|
|
|
complexCond, err := convertCondition(ctx, req.Condition, fieldNameToPhysical, req.SQLParams)
|
|
if err != nil {
|
|
return -1, fmt.Errorf("convert condition failed: %v", err)
|
|
}
|
|
|
|
// add rw mode
|
|
if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite && req.UserID != "" {
|
|
cond := &rdb.Condition{
|
|
Field: database.DefaultUidColName,
|
|
Operator: entity3.OperatorEqual,
|
|
Value: req.UserID,
|
|
}
|
|
|
|
if complexCond == nil {
|
|
complexCond = &rdb.ComplexCondition{
|
|
Conditions: []*rdb.Condition{cond},
|
|
}
|
|
} else {
|
|
complexCond.Conditions = append(complexCond.Conditions, cond)
|
|
}
|
|
}
|
|
|
|
deleteResp, err := d.rdb.DeleteData(ctx, &rdb.DeleteDataRequest{
|
|
TableName: physicalTableName,
|
|
Where: complexCond,
|
|
Limit: int64PtrToIntPtr(req.Limit),
|
|
})
|
|
if err != nil {
|
|
return -1, fmt.Errorf("delete data failed: %v", err)
|
|
}
|
|
|
|
return deleteResp.AffectedRows, nil
|
|
}
|
|
|
|
// int64PtrToIntPtr converts an optional int64 into an optional int.
// A nil input yields nil. Otherwise a pointer to a new int holding int(*i64ptr)
// is returned, so the result never aliases the input.
func int64PtrToIntPtr(i64ptr *int64) *int {
	if i64ptr != nil {
		converted := int(*i64ptr)
		return &converted
	}
	return nil
}
|
|
|
|
func convertSortDirection(direction table.SortDirection) entity3.SortDirection {
|
|
if direction == table.SortDirection_Desc {
|
|
return entity3.SortDirectionDesc
|
|
}
|
|
return entity3.SortDirectionAsc
|
|
}
|
|
|
|
func convertCondition(ctx context.Context, cond *database.ComplexCondition, fieldMap map[string]string, params []*database.SQLParamVal) (*rdb.ComplexCondition, error) {
|
|
if cond == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
result := &rdb.ComplexCondition{
|
|
Operator: convertor.ConvertLogicOperator(cond.Logic),
|
|
}
|
|
|
|
index := 0
|
|
if len(cond.Conditions) > 0 {
|
|
conditions := make([]*rdb.Condition, 0, len(cond.Conditions))
|
|
for _, c := range cond.Conditions {
|
|
leftField := c.Left
|
|
if mapped, exists := fieldMap[c.Left]; exists {
|
|
leftField = mapped
|
|
}
|
|
|
|
if c.Operation == database.Operation_IS_NULL || c.Operation == database.Operation_IS_NOT_NULL {
|
|
conditions = append(conditions, &rdb.Condition{
|
|
Field: leftField,
|
|
Operator: convertor.ConvertOperator(c.Operation),
|
|
})
|
|
continue
|
|
}
|
|
|
|
if c.Operation == database.Operation_IN || c.Operation == database.Operation_NOT_IN {
|
|
// c.Right: example: (?,?)
|
|
qCount := 0
|
|
for i := 0; i < len(c.Right); i++ {
|
|
if c.Right[i] == '?' {
|
|
qCount++
|
|
}
|
|
}
|
|
if qCount == 0 {
|
|
return nil, fmt.Errorf("IN/NOT_IN condition right side must contain ? placeholders")
|
|
}
|
|
vals := make([]interface{}, 0, qCount)
|
|
for j := 0; j < qCount; j++ {
|
|
if index >= len(params) {
|
|
return nil, fmt.Errorf("not enough params for IN/NOT_IN condition")
|
|
}
|
|
if params[index].ISNull || params[index].Value == nil {
|
|
index++
|
|
continue
|
|
}
|
|
vals = append(vals, decryptSysUUIDKey(ctx, leftField, *params[index].Value))
|
|
index++
|
|
}
|
|
conditions = append(conditions, &rdb.Condition{
|
|
Field: leftField,
|
|
Operator: convertor.ConvertOperator(c.Operation),
|
|
Value: vals,
|
|
})
|
|
continue
|
|
}
|
|
|
|
if params[index].ISNull || params[index].Value == nil {
|
|
index++
|
|
continue
|
|
}
|
|
|
|
conditions = append(conditions, &rdb.Condition{
|
|
Field: leftField,
|
|
Operator: convertor.ConvertOperator(c.Operation),
|
|
Value: decryptSysUUIDKey(ctx, leftField, *params[index].Value),
|
|
})
|
|
index++
|
|
}
|
|
result.Conditions = conditions
|
|
}
|
|
// if cond.NestedConditions != nil {
|
|
// nested, err := convertCondition(cond.NestedConditions, fieldMap, params)
|
|
// if err != nil {
|
|
// return nil, err
|
|
// }
|
|
// result.NestedConditions = []*rdb.ComplexCondition{nested}
|
|
// }
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func decryptSysUUIDKey(ctx context.Context, leftField, value string) string {
|
|
if leftField == database.DefaultUidDisplayColName || leftField == database.DefaultUidColName {
|
|
decryptVal := crossvariables.DefaultSVC().DecryptSysUUIDKey(ctx, value)
|
|
if decryptVal != nil {
|
|
value = decryptVal.ConnectorUID
|
|
}
|
|
}
|
|
|
|
return value
|
|
}
|
|
|
|
func (d databaseService) BindDatabase(ctx context.Context, req *BindDatabaseToAgentRequest) error {
|
|
draft, err := d.MGetDatabase(ctx, &MGetDatabaseRequest{
|
|
Basics: []*database.DatabaseBasic{
|
|
{
|
|
ID: req.DraftDatabaseID,
|
|
TableType: table.TableType_DraftTable,
|
|
},
|
|
},
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(draft.Databases) == 0 {
|
|
return fmt.Errorf("online table not found, id: %d", req.DraftDatabaseID)
|
|
}
|
|
|
|
onlineID := draft.Databases[0].GetOnlineID()
|
|
relations := []*database.AgentToDatabase{
|
|
{
|
|
AgentID: req.AgentID,
|
|
DatabaseID: onlineID,
|
|
TableType: table.TableType_OnlineTable,
|
|
},
|
|
{
|
|
AgentID: req.AgentID,
|
|
DatabaseID: req.DraftDatabaseID,
|
|
TableType: table.TableType_DraftTable,
|
|
},
|
|
}
|
|
|
|
_, err = d.agentToDatabaseDAO.BatchCreate(ctx, relations)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to bind databases to agent: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d databaseService) UnBindDatabase(ctx context.Context, req *UnBindDatabaseToAgentRequest) error {
|
|
draft, err := d.MGetDatabase(ctx, &MGetDatabaseRequest{
|
|
Basics: []*database.DatabaseBasic{
|
|
{
|
|
ID: req.DraftDatabaseID,
|
|
TableType: table.TableType_DraftTable,
|
|
},
|
|
},
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(draft.Databases) == 0 {
|
|
return fmt.Errorf("online table not found, id: %d", req.DraftDatabaseID)
|
|
}
|
|
|
|
onlineID := draft.Databases[0].GetOnlineID()
|
|
relations := []*database.AgentToDatabaseBasic{
|
|
{
|
|
AgentID: req.AgentID,
|
|
DatabaseID: onlineID,
|
|
},
|
|
{
|
|
AgentID: req.AgentID,
|
|
DatabaseID: req.DraftDatabaseID,
|
|
},
|
|
}
|
|
|
|
err = d.agentToDatabaseDAO.BatchDelete(ctx, relations)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to unbind databases from agent: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d databaseService) MGetDatabaseByAgentID(ctx context.Context, req *MGetDatabaseByAgentIDRequest) (*MGetDatabaseByAgentIDResponse, error) {
|
|
if req == nil {
|
|
return nil, fmt.Errorf("invalid request: request is nil")
|
|
}
|
|
|
|
relations, err := d.agentToDatabaseDAO.ListByAgentID(ctx, req.AgentID, req.TableType)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
mGetBasics := make([]*database.DatabaseBasic, 0, len(relations))
|
|
for _, relation := range relations {
|
|
mGetBasics = append(mGetBasics, &database.DatabaseBasic{
|
|
ID: relation.DatabaseID,
|
|
TableType: req.TableType,
|
|
NeedSysFields: req.NeedSysFields,
|
|
})
|
|
}
|
|
databases, err := d.MGetDatabase(ctx, &MGetDatabaseRequest{Basics: mGetBasics})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &MGetDatabaseByAgentIDResponse{
|
|
Databases: databases.Databases,
|
|
}, nil
|
|
}
|
|
|
|
// PublishDatabase return online database according to draft database info
|
|
func (d databaseService) PublishDatabase(ctx context.Context, req *PublishDatabaseRequest) (*PublishDatabaseResponse, error) {
|
|
if req == nil {
|
|
return nil, fmt.Errorf("invalid request: request is nil")
|
|
}
|
|
|
|
relationResp, err := d.MGetRelationsByAgentID(ctx, &MGetRelationsByAgentIDRequest{
|
|
AgentID: req.AgentID,
|
|
TableType: table.TableType_DraftTable,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(relationResp.Relations) == 0 {
|
|
return &PublishDatabaseResponse{}, nil
|
|
}
|
|
|
|
dBasics := make([]*database.DatabaseBasic, 0, len(relationResp.Relations))
|
|
for _, draftR := range relationResp.Relations {
|
|
dBasics = append(dBasics, &database.DatabaseBasic{
|
|
ID: draftR.DatabaseID,
|
|
TableType: table.TableType_DraftTable,
|
|
NeedSysFields: false,
|
|
})
|
|
}
|
|
|
|
draftDatabaseResp, err := d.MGetDatabase(ctx, &MGetDatabaseRequest{
|
|
Basics: dBasics,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
oBasics := make([]*database.DatabaseBasic, 0, len(draftDatabaseResp.Databases))
|
|
for _, draft := range draftDatabaseResp.Databases {
|
|
oBasics = append(oBasics, &database.DatabaseBasic{
|
|
ID: draft.GetOnlineID(),
|
|
TableType: table.TableType_OnlineTable,
|
|
NeedSysFields: false,
|
|
})
|
|
}
|
|
|
|
onlineDatabaseResp, err := d.MGetDatabase(ctx, &MGetDatabaseRequest{
|
|
Basics: oBasics,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
results := make([]*bot_common.Database, 0, len(onlineDatabaseResp.Databases))
|
|
for _, online := range onlineDatabaseResp.Databases {
|
|
fields := make([]*bot_common.FieldItem, 0, len(online.FieldList))
|
|
for _, field := range online.FieldList {
|
|
fields = append(fields, &bot_common.FieldItem{
|
|
Name: ptr.Of(field.Name),
|
|
Desc: ptr.Of(field.Desc),
|
|
Type: ptr.Of(bot_common.FieldItemType(field.Type)),
|
|
MustRequired: ptr.Of(field.MustRequired),
|
|
AlterId: ptr.Of(field.AlterID),
|
|
Id: ptr.Of(int64(0)),
|
|
})
|
|
}
|
|
|
|
results = append(results, &bot_common.Database{
|
|
TableId: ptr.Of(strconv.FormatInt(online.ID, 10)),
|
|
TableName: ptr.Of(online.TableName),
|
|
TableDesc: ptr.Of(online.TableDesc),
|
|
FieldList: fields,
|
|
RWMode: ptr.Of(bot_common.BotTableRWMode(online.RwMode)),
|
|
})
|
|
}
|
|
|
|
return &PublishDatabaseResponse{
|
|
OnlineDatabases: results,
|
|
}, nil
|
|
}
|
|
|
|
func (d databaseService) MGetRelationsByAgentID(ctx context.Context, req *MGetRelationsByAgentIDRequest) (*MGetRelationsByAgentIDResponse, error) {
|
|
if req == nil {
|
|
return nil, fmt.Errorf("invalid request: request is nil")
|
|
}
|
|
|
|
relations, err := d.agentToDatabaseDAO.ListByAgentID(ctx, req.AgentID, req.TableType)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &MGetRelationsByAgentIDResponse{
|
|
Relations: relations,
|
|
}, nil
|
|
}
|
|
|
|
func (d databaseService) GetDatabaseTableSchema(ctx context.Context, req *GetDatabaseTableSchemaRequest) (*GetDatabaseTableSchemaResponse, error) {
|
|
parser := &sheet.TosTableParser{
|
|
UserID: req.UserID,
|
|
DocumentSource: database.DocumentSourceType_Document,
|
|
TosURI: req.TosURL,
|
|
TosServ: d.storage,
|
|
}
|
|
|
|
res, extra, err := parser.GetTableDataBySheetIDx(ctx, entity2.TableReaderMeta{
|
|
TosMaxLine: 100000,
|
|
HeaderLineIdx: req.TableSheet.HeaderLineIdx,
|
|
SheetId: req.TableSheet.SheetID,
|
|
StartLineIdx: req.TableSheet.StartLineIdx,
|
|
ReaderMethod: database.TableReadDataMethodHead,
|
|
ReadLineCnt: 20,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
res.Columns, err = parser.PredictColumnType(res.Columns, res.SampleData, req.TableSheet.SheetID, req.TableSheet.StartLineIdx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
resp := &GetDatabaseTableSchemaResponse{}
|
|
if req.TableDataType == table.TableDataType_AllData || req.TableDataType == table.TableDataType_OnlyPreview {
|
|
previewData, tErr := parser.TransferPreviewData(ctx, res.Columns, res.SampleData, 20)
|
|
if tErr != nil {
|
|
return resp, tErr
|
|
}
|
|
resp.PreviewData = previewData
|
|
}
|
|
resp.TableMeta = res.Columns
|
|
resp.SheetList = extra.Sheets
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
func (d databaseService) ValidateDatabaseTableSchema(ctx context.Context, req *ValidateDatabaseTableSchemaRequest) (*ValidateDatabaseTableSchemaResponse, error) {
|
|
parser := &sheet.TosTableParser{
|
|
UserID: req.UserID,
|
|
DocumentSource: database.DocumentSourceType_Document,
|
|
TosURI: req.TosURL,
|
|
TosServ: d.storage,
|
|
}
|
|
|
|
res, sheetRes, err := parser.GetTableDataBySheetIDx(ctx, entity2.TableReaderMeta{
|
|
TosMaxLine: 100000,
|
|
HeaderLineIdx: req.TableSheet.HeaderLineIdx,
|
|
SheetId: req.TableSheet.SheetID,
|
|
StartLineIdx: req.TableSheet.StartLineIdx,
|
|
ReaderMethod: database.TableReadDataMethodAll,
|
|
ReadLineCnt: 20,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
valid, invalidMsg := sheet.CheckSheetIsValid(req.Fields, res.Columns, sheetRes)
|
|
return &ValidateDatabaseTableSchemaResponse{
|
|
Valid: valid,
|
|
InvalidMsg: invalidMsg,
|
|
}, nil
|
|
}
|
|
|
|
func (d databaseService) SubmitDatabaseInsertTask(ctx context.Context, req *SubmitDatabaseInsertTaskRequest) error {
|
|
var err error
|
|
failKey := onlineFailReasonKey
|
|
if req.TableType == table.TableType_DraftTable {
|
|
failKey = draftFailReasonKey
|
|
}
|
|
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
errMsg := fmt.Sprintf("panic: %v", r)
|
|
d.cache.Set(ctx, fmt.Sprintf(failKey, req.DatabaseID, req.UserID), errMsg, redisKeyTimeOut)
|
|
err = fmt.Errorf("panic: %v", r)
|
|
return
|
|
}
|
|
if err != nil {
|
|
d.cache.Set(ctx, fmt.Sprintf(failKey, req.DatabaseID, req.UserID), err.Error(), redisKeyTimeOut)
|
|
}
|
|
}()
|
|
|
|
parser := &sheet.TosTableParser{
|
|
UserID: req.UserID,
|
|
DocumentSource: database.DocumentSourceType_Document,
|
|
TosURI: req.FileURI,
|
|
TosServ: d.storage,
|
|
}
|
|
parseData, extra, err := parser.GetTableDataBySheetIDx(ctx, entity2.TableReaderMeta{
|
|
TosMaxLine: 100000,
|
|
SheetId: req.TableSheet.SheetID,
|
|
HeaderLineIdx: req.TableSheet.HeaderLineIdx,
|
|
StartLineIdx: req.TableSheet.StartLineIdx,
|
|
ReaderMethod: database.TableReadDataMethodAll,
|
|
},
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = d.initializeCache(ctx, req, parseData, extra)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
columns := parseData.Columns
|
|
|
|
records := make([]map[string]string, 0, len(parseData.SampleData))
|
|
for _, data := range parseData.SampleData {
|
|
record := make(map[string]string)
|
|
for i, column := range columns {
|
|
record[column.ColumnName] = data[i]
|
|
}
|
|
records = append(records, record)
|
|
}
|
|
|
|
batchSize := 20
|
|
go func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
errMsg := fmt.Sprintf("panic: %v", r)
|
|
d.cache.Set(ctx, fmt.Sprintf(failKey, req.DatabaseID, req.UserID), errMsg, redisKeyTimeOut)
|
|
}
|
|
}()
|
|
|
|
for i := 0; i < len(records); i += batchSize {
|
|
end := i + batchSize
|
|
if end > len(records) {
|
|
end = len(records)
|
|
}
|
|
batchRecords := records[i:end]
|
|
err = d.AddDatabaseRecord(ctx, &AddDatabaseRecordRequest{
|
|
DatabaseID: req.DatabaseID,
|
|
TableType: req.TableType,
|
|
ConnectorID: req.ConnectorID,
|
|
UserID: req.UserID,
|
|
Records: batchRecords,
|
|
})
|
|
if err != nil {
|
|
d.cache.Set(ctx, fmt.Sprintf(failKey, req.DatabaseID, req.UserID), err.Error(), redisKeyTimeOut)
|
|
return
|
|
}
|
|
|
|
err = d.increaseProgress(ctx, req, int64(len(batchRecords)))
|
|
if err != nil {
|
|
d.cache.Set(ctx, fmt.Sprintf(failKey, req.DatabaseID, req.UserID), err.Error(), redisKeyTimeOut)
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d databaseService) GetDatabaseFileProgressData(ctx context.Context, req *GetDatabaseFileProgressDataRequest) (*GetDatabaseFileProgressDataResponse, error) {
|
|
totalKey := onlineTotalCountKey
|
|
if req.TableType == table.TableType_DraftTable {
|
|
totalKey = draftTotalCountKey
|
|
}
|
|
progressKey := onlineProgressKey
|
|
if req.TableType == table.TableType_DraftTable {
|
|
progressKey = draftProgressKey
|
|
}
|
|
failKey := onlineFailReasonKey
|
|
if req.TableType == table.TableType_DraftTable {
|
|
failKey = draftFailReasonKey
|
|
}
|
|
currentFileName := onlineCurrentFileName
|
|
if req.TableType == table.TableType_DraftTable {
|
|
currentFileName = draftCurrentFileName
|
|
}
|
|
totalNum, err := d.cache.Get(ctx, fmt.Sprintf(totalKey, req.DatabaseID, req.UserID)).Int64()
|
|
if err != nil && !errors.Is(err, cache.Nil) {
|
|
return nil, err
|
|
}
|
|
|
|
progressNum, err := d.cache.Get(ctx, fmt.Sprintf(progressKey, req.DatabaseID, req.UserID)).Int64()
|
|
if err != nil && !errors.Is(err, cache.Nil) {
|
|
return nil, err
|
|
}
|
|
|
|
failReason, err := d.cache.Get(ctx, fmt.Sprintf(failKey, req.DatabaseID, req.UserID)).Result()
|
|
if err != nil && !errors.Is(err, cache.Nil) {
|
|
return nil, err
|
|
}
|
|
|
|
fileName, err := d.cache.Get(ctx, fmt.Sprintf(currentFileName, req.DatabaseID, req.UserID)).Result()
|
|
if err != nil && !errors.Is(err, cache.Nil) {
|
|
return nil, err
|
|
}
|
|
|
|
resp := &GetDatabaseFileProgressDataResponse{}
|
|
if totalNum == 0 {
|
|
resp.FileName = ""
|
|
resp.Progress = 100
|
|
} else {
|
|
resp.FileName = fileName
|
|
resp.Progress = int32(float32(progressNum) / float32(totalNum) * 100)
|
|
resp.StatusDescript = ptr.Of(failReason)
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
// Cache key formats used to track asynchronous sheet imports. Each format
// takes the database id and the user id (both via %d). Draft and online tables
// use separate keys.
const (
	// Total number of rows to import.
	draftTotalCountKey  = "database_file_%d_%d_draft_total"
	onlineTotalCountKey = "database_file_%d_%d_online_total"

	// Number of rows imported so far.
	draftProgressKey  = "database_file_%d_%d_draft_progress"
	onlineProgressKey = "database_file_%d_%d_online_progress"

	// Failure message of the import; initialized to "".
	draftFailReasonKey  = "database_file_%d_%d_draft_fail_reason"
	onlineFailReasonKey = "database_file_%d_%d_online_fail_reason"

	// Name of the sheet being imported.
	draftCurrentFileName  = "database_file_%d_%d_draft_file_name"
	onlineCurrentFileName = "database_file_%d_%d_online_file_name"
)

// redisKeyTimeOut is the TTL applied to the import-tracking cache keys.
const redisKeyTimeOut = 12 * time.Hour
|
|
|
|
func (d databaseService) initializeCache(ctx context.Context, req *SubmitDatabaseInsertTaskRequest, parseData *entity2.TableReaderSheetData, extra *entity2.ExcelExtraInfo) error {
|
|
tableType := req.TableType
|
|
userID := req.UserID
|
|
databaseID := req.DatabaseID
|
|
|
|
totalKey := onlineTotalCountKey
|
|
if tableType == table.TableType_DraftTable {
|
|
totalKey = draftTotalCountKey
|
|
}
|
|
currentFileName := onlineCurrentFileName
|
|
if tableType == table.TableType_DraftTable {
|
|
currentFileName = draftCurrentFileName
|
|
}
|
|
progressKey := onlineProgressKey
|
|
if tableType == table.TableType_DraftTable {
|
|
progressKey = draftProgressKey
|
|
}
|
|
failKey := onlineFailReasonKey
|
|
if tableType == table.TableType_DraftTable {
|
|
failKey = draftFailReasonKey
|
|
}
|
|
|
|
_, err := d.cache.Set(ctx, fmt.Sprintf(totalKey, databaseID, userID), fmt.Sprintf("%d", len(parseData.SampleData)), redisKeyTimeOut).Result()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = d.cache.Set(ctx, fmt.Sprintf(progressKey, databaseID, userID), int64(0), redisKeyTimeOut).Result()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = d.cache.Set(ctx, fmt.Sprintf(failKey, databaseID, userID), "", redisKeyTimeOut).Result()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = d.cache.Set(ctx, fmt.Sprintf(currentFileName, databaseID, userID), extra.Sheets[req.TableSheet.SheetID].SheetName, redisKeyTimeOut).Result()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d databaseService) increaseProgress(ctx context.Context, req *SubmitDatabaseInsertTaskRequest, successNum int64) error {
|
|
tableType := req.TableType
|
|
userID := req.UserID
|
|
databaseID := req.DatabaseID
|
|
|
|
progressKey := onlineProgressKey
|
|
if tableType == table.TableType_DraftTable {
|
|
progressKey = draftProgressKey
|
|
}
|
|
|
|
_, err := d.cache.IncrBy(ctx, fmt.Sprintf(progressKey, databaseID, userID), successNum).Result()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d databaseService) GetDraftDatabaseByOnlineID(ctx context.Context, req *GetDraftDatabaseByOnlineIDRequest) (*GetDraftDatabaseByOnlineIDResponse, error) {
|
|
online, err := d.MGetDatabase(ctx, &MGetDatabaseRequest{
|
|
Basics: []*database.DatabaseBasic{
|
|
{
|
|
ID: req.OnlineID,
|
|
TableType: table.TableType_OnlineTable,
|
|
},
|
|
},
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(online.Databases) == 0 {
|
|
return nil, fmt.Errorf("online table not found, id: %d", req.OnlineID)
|
|
}
|
|
|
|
draftID := online.Databases[0].GetDraftID()
|
|
|
|
draftResp, err := d.MGetDatabase(ctx, &MGetDatabaseRequest{
|
|
Basics: []*database.DatabaseBasic{
|
|
{
|
|
ID: draftID,
|
|
TableType: table.TableType_DraftTable,
|
|
},
|
|
},
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(draftResp.Databases) == 0 {
|
|
return nil, fmt.Errorf("online table not found, id: %d", req.OnlineID)
|
|
}
|
|
|
|
return &GetDraftDatabaseByOnlineIDResponse{
|
|
Database: draftResp.Databases[0],
|
|
}, nil
|
|
}
|
|
|
|
// DeleteDatabaseByAppID delete all records and all physical tables by app id
|
|
func (d databaseService) DeleteDatabaseByAppID(ctx context.Context, req *DeleteDatabaseByAppIDRequest) (*DeleteDatabaseByAppIDResponse, error) {
|
|
onlineDBInfos, err := d.listDatabasesByAppID(ctx, req.AppID, table.TableType_OnlineTable)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
draftDBInfos, err := d.listDatabasesByAppID(ctx, req.AppID, table.TableType_DraftTable)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
tx := query.Use(d.db).Begin()
|
|
if tx.Error != nil {
|
|
return nil, fmt.Errorf("start transaction failed, %v", tx.Error)
|
|
}
|
|
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
e := tx.Rollback()
|
|
if e != nil {
|
|
logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
|
|
}
|
|
|
|
err = fmt.Errorf("catch panic: %v\nstack=%s", r, string(debug.Stack()))
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
e := tx.Rollback()
|
|
if e != nil {
|
|
logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
|
|
}
|
|
}
|
|
}()
|
|
|
|
onlineIDs := make([]int64, 0, len(onlineDBInfos))
|
|
for _, db := range onlineDBInfos {
|
|
onlineIDs = append(onlineIDs, db.ID)
|
|
}
|
|
|
|
draftIDs := make([]int64, 0, len(draftDBInfos))
|
|
for _, db := range draftDBInfos {
|
|
draftIDs = append(draftIDs, db.ID)
|
|
}
|
|
|
|
if err = d.onlineDAO.BatchDeleteWithTX(ctx, tx, onlineIDs); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err = d.draftDAO.BatchDeleteWithTX(ctx, tx, draftIDs); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("commit transaction failed: %v", err)
|
|
}
|
|
|
|
// delete draft and online physical table
|
|
onlinePhysicals := make([]string, 0, len(onlineDBInfos))
|
|
for _, db := range onlineDBInfos {
|
|
onlinePhysicals = append(onlinePhysicals, db.ActualTableName)
|
|
}
|
|
|
|
draftPhysicals := make([]string, 0, len(draftDBInfos))
|
|
for _, db := range draftDBInfos {
|
|
draftPhysicals = append(draftPhysicals, db.ActualTableName)
|
|
}
|
|
|
|
for _, physical := range onlinePhysicals {
|
|
_, err = d.rdb.DropTable(ctx, &rdb.DropTableRequest{
|
|
TableName: physical,
|
|
})
|
|
if err != nil {
|
|
logs.Errorf("drop online physical table failed: %v, table_name=%s", err, physical)
|
|
}
|
|
}
|
|
for _, physical := range draftPhysicals {
|
|
_, err = d.rdb.DropTable(ctx, &rdb.DropTableRequest{
|
|
TableName: physical,
|
|
})
|
|
if err != nil {
|
|
logs.Errorf("drop draft physical table failed: %v, table_name=%s", err, physical)
|
|
}
|
|
}
|
|
|
|
return &DeleteDatabaseByAppIDResponse{
|
|
DeletedDatabaseIDs: onlineIDs,
|
|
}, nil
|
|
}
|
|
|
|
func (d databaseService) listDatabasesByAppID(ctx context.Context, appID int64, tableType table.TableType) ([]*entity2.Database, error) {
|
|
const batchSize = 100
|
|
offset := 0
|
|
dbInfos := make([]*entity2.Database, 0)
|
|
for {
|
|
resp, err := d.ListDatabase(ctx, &ListDatabaseRequest{
|
|
AppID: appID,
|
|
TableType: tableType,
|
|
Limit: batchSize,
|
|
Offset: offset,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, db := range resp.Databases {
|
|
dbInfos = append(dbInfos, db)
|
|
}
|
|
|
|
if !resp.HasMore {
|
|
break
|
|
}
|
|
|
|
offset += batchSize
|
|
}
|
|
|
|
return dbInfos, nil
|
|
}
|
|
|
|
func (d databaseService) GetAllDatabaseByAppID(ctx context.Context, req *GetAllDatabaseByAppIDRequest) (*GetAllDatabaseByAppIDResponse, error) {
|
|
onlineDBs, err := d.listDatabasesByAppID(ctx, req.AppID, table.TableType_OnlineTable)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &GetAllDatabaseByAppIDResponse{
|
|
Databases: onlineDBs,
|
|
}, nil
|
|
}
|