2207 lines
		
	
	
		
			63 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			2207 lines
		
	
	
		
			63 KiB
		
	
	
	
		
			Go
		
	
	
	
| /*
 | |
|  * Copyright 2025 coze-dev Authors
 | |
|  *
 | |
|  * Licensed under the Apache License, Version 2.0 (the "License");
 | |
|  * you may not use this file except in compliance with the License.
 | |
|  * You may obtain a copy of the License at
 | |
|  *
 | |
|  *     http://www.apache.org/licenses/LICENSE-2.0
 | |
|  *
 | |
|  * Unless required by applicable law or agreed to in writing, software
 | |
|  * distributed under the License is distributed on an "AS IS" BASIS,
 | |
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
|  * See the License for the specific language governing permissions and
 | |
|  * limitations under the License.
 | |
|  */
 | |
| 
 | |
| package service
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"context"
 | |
| 	"crypto/sha256"
 | |
| 	"encoding/base64"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"math/rand"
 | |
| 	"runtime/debug"
 | |
| 	"strconv"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/tealeg/xlsx/v3"
 | |
| 	"gorm.io/gorm"
 | |
| 
 | |
| 	"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
 | |
| 
 | |
| 	"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
 | |
| 	"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_common"
 | |
| 	"github.com/coze-dev/coze-studio/backend/api/model/table"
 | |
| 	"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossvariables"
 | |
| 	entity2 "github.com/coze-dev/coze-studio/backend/domain/memory/database/entity"
 | |
| 	"github.com/coze-dev/coze-studio/backend/domain/memory/database/internal/convertor"
 | |
| 	"github.com/coze-dev/coze-studio/backend/domain/memory/database/internal/dal/query"
 | |
| 	"github.com/coze-dev/coze-studio/backend/domain/memory/database/internal/physicaltable"
 | |
| 	"github.com/coze-dev/coze-studio/backend/domain/memory/database/internal/sheet"
 | |
| 	"github.com/coze-dev/coze-studio/backend/domain/memory/database/repository"
 | |
| 	"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
 | |
| 	"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
 | |
| 	entity3 "github.com/coze-dev/coze-studio/backend/infra/contract/rdb/entity"
 | |
| 	sqlparsercontract "github.com/coze-dev/coze-studio/backend/infra/contract/sqlparser"
 | |
| 	"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
 | |
| 	"github.com/coze-dev/coze-studio/backend/infra/impl/sqlparser"
 | |
| 	"github.com/coze-dev/coze-studio/backend/pkg/errorx"
 | |
| 	"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
 | |
| 	"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
 | |
| 	"github.com/coze-dev/coze-studio/backend/pkg/logs"
 | |
| 	"github.com/coze-dev/coze-studio/backend/types/consts"
 | |
| 	"github.com/coze-dev/coze-studio/backend/types/errno"
 | |
| )
 | |
| 
 | |
| type databaseService struct {
 | |
| 	rdb                rdb.RDB
 | |
| 	db                 *gorm.DB
 | |
| 	generator          idgen.IDGenerator
 | |
| 	draftDAO           repository.DraftDAO
 | |
| 	onlineDAO          repository.OnlineDAO
 | |
| 	agentToDatabaseDAO repository.AgentToDatabaseDAO
 | |
| 	storage            storage.Storage
 | |
| 	cache              cache.Cmdable
 | |
| }
 | |
| 
 | |
| func NewService(rdb rdb.RDB, db *gorm.DB, generator idgen.IDGenerator, storage storage.Storage, cacheCli cache.Cmdable) Database {
 | |
| 	return &databaseService{
 | |
| 		rdb:                rdb,
 | |
| 		db:                 db,
 | |
| 		generator:          generator,
 | |
| 		draftDAO:           repository.NewDraftDatabaseDAO(db, generator),
 | |
| 		onlineDAO:          repository.NewOnlineDatabaseDAO(db, generator),
 | |
| 		agentToDatabaseDAO: repository.NewAgentToDatabaseDAO(db, generator),
 | |
| 		storage:            storage,
 | |
| 		cache:              cacheCli,
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (d databaseService) CreateDatabase(ctx context.Context, req *CreateDatabaseRequest) (*CreateDatabaseResponse, error) {
 | |
| 	draftEntity, onlineEntity := req.Database, req.Database
 | |
| 	fieldItems, columns := physicaltable.CreateFieldInfo(req.Database.FieldList)
 | |
| 
 | |
| 	// create physical draft table
 | |
| 	draftEntity.FieldList = fieldItems
 | |
| 
 | |
| 	draftPhysicalTableRes, err := physicaltable.CreatePhysicalTable(ctx, d.rdb, columns)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	if draftPhysicalTableRes.Table == nil {
 | |
| 		return nil, fmt.Errorf("create draft table failed, columns info is %v", columns)
 | |
| 	}
 | |
| 
 | |
| 	draftID, err := d.generator.GenID(ctx)
 | |
| 	if err != nil {
 | |
| 		return nil, errorx.WrapByCode(err, errno.ErrMemoryIDGenFailCode, errorx.KV("msg", "CreateDatabase"))
 | |
| 	}
 | |
| 
 | |
| 	// create physical online table
 | |
| 	onlineEntity.FieldList = fieldItems
 | |
| 
 | |
| 	onlinePhysicalTableRes, err := physicaltable.CreatePhysicalTable(ctx, d.rdb, columns)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	if onlinePhysicalTableRes.Table == nil {
 | |
| 		return nil, fmt.Errorf("create online table failed, columns info is %v", columns)
 | |
| 	}
 | |
| 
 | |
| 	onlineID, err := d.generator.GenID(ctx)
 | |
| 	if err != nil {
 | |
| 		return nil, errorx.WrapByCode(err, errno.ErrMemoryIDGenFailCode, errorx.KV("msg", "CreateDatabase"))
 | |
| 	}
 | |
| 
 | |
| 	// insert draft and online database info
 | |
| 	tx := query.Use(d.db).Begin()
 | |
| 	if tx.Error != nil {
 | |
| 		return nil, fmt.Errorf("start transaction failed, %v", tx.Error)
 | |
| 	}
 | |
| 
 | |
| 	if draftEntity.IconURI == "" {
 | |
| 		draftEntity.IconURI = consts.DefaultDatabaseIcon
 | |
| 	}
 | |
| 	if onlineEntity.IconURI == "" {
 | |
| 		onlineEntity.IconURI = consts.DefaultDatabaseIcon
 | |
| 	}
 | |
| 
 | |
| 	defer func() {
 | |
| 		if r := recover(); r != nil {
 | |
| 			e := tx.Rollback()
 | |
| 			if e != nil {
 | |
| 				logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
 | |
| 			}
 | |
| 
 | |
| 			err = fmt.Errorf("catch panic: %v\nstack=%s", r, string(debug.Stack()))
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		if err != nil {
 | |
| 			e := tx.Rollback()
 | |
| 			if e != nil {
 | |
| 				logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
 | |
| 			}
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	_, err = d.draftDAO.CreateWithTX(ctx, tx, draftEntity, draftID, onlineID, draftPhysicalTableRes.Table.Name)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	onlineEntity, err = d.onlineDAO.CreateWithTX(ctx, tx, onlineEntity, draftID, onlineID, onlinePhysicalTableRes.Table.Name)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	err = tx.Commit()
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	onlineEntity.ActualTableName = onlinePhysicalTableRes.Table.Name
 | |
| 	onlineEntity.ID = onlineID
 | |
| 	onlineEntity.DraftID = ptr.Of(draftID)
 | |
| 	objURL, uRrr := d.storage.GetObjectUrl(ctx, onlineEntity.IconURI)
 | |
| 	if uRrr == nil {
 | |
| 		onlineEntity.IconURL = objURL
 | |
| 	}
 | |
| 
 | |
| 	return &CreateDatabaseResponse{
 | |
| 		Database: onlineEntity,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) UpdateDatabase(ctx context.Context, req *UpdateDatabaseRequest) (*UpdateDatabaseResponse, error) {
 | |
| 	// req.Database.ID is the id of online database
 | |
| 	input := req.Database
 | |
| 	onlineInfo, err := d.onlineDAO.Get(ctx, req.Database.ID)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("get online database info failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	draftInfo, err := d.draftDAO.Get(ctx, onlineInfo.GetDraftID())
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("get draft database info failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	draftEntity, onlineEntity := *input, *input
 | |
| 
 | |
| 	draftEntity.ID = draftInfo.ID
 | |
| 	onlineEntity.ID = onlineInfo.ID
 | |
| 
 | |
| 	fieldItems, columns, droppedColumns, err := physicaltable.UpdateFieldInfo(input.FieldList, onlineInfo.FieldList)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	draftEntity.FieldList = fieldItems
 | |
| 	onlineEntity.FieldList = fieldItems
 | |
| 
 | |
| 	// get draft and online physical table info
 | |
| 	draftPhysicalTable, err := d.rdb.GetTable(ctx, &rdb.GetTableRequest{
 | |
| 		TableName: draftInfo.ActualTableName,
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("get physical table info failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	onlinePhysicalTable, err := d.rdb.GetTable(ctx, &rdb.GetTableRequest{
 | |
| 		TableName: onlineInfo.ActualTableName,
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("get physical table info failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	err = physicaltable.UpdatePhysicalTableWithDrops(ctx, d.rdb, draftPhysicalTable.Table, columns, droppedColumns, draftInfo.ActualTableName)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("update draft physical table failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	err = physicaltable.UpdatePhysicalTableWithDrops(ctx, d.rdb, onlinePhysicalTable.Table, columns, droppedColumns, onlineInfo.ActualTableName)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("update online physical table failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	tx := query.Use(d.db).Begin()
 | |
| 	if tx.Error != nil {
 | |
| 		return nil, fmt.Errorf("start transaction failed, %v", tx.Error)
 | |
| 	}
 | |
| 
 | |
| 	if draftEntity.IconURI == "" {
 | |
| 		draftEntity.IconURI = consts.DefaultDatabaseIcon
 | |
| 	}
 | |
| 	if onlineEntity.IconURI == "" {
 | |
| 		onlineEntity.IconURI = consts.DefaultDatabaseIcon
 | |
| 	}
 | |
| 
 | |
| 	defer func() {
 | |
| 		if r := recover(); r != nil {
 | |
| 			e := tx.Rollback()
 | |
| 			if e != nil {
 | |
| 				logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
 | |
| 			}
 | |
| 
 | |
| 			err = fmt.Errorf("catch panic: %v\nstack=%s", r, string(debug.Stack()))
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		if err != nil {
 | |
| 			e := tx.Rollback()
 | |
| 			if e != nil {
 | |
| 				logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
 | |
| 			}
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	_, err = d.draftDAO.UpdateWithTX(ctx, tx, &draftEntity)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("update draft database info failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	onlineEntityUpdated, err := d.onlineDAO.UpdateWithTX(ctx, tx, &onlineEntity)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("update online database info failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	err = tx.Commit()
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("commit transaction failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	return &UpdateDatabaseResponse{
 | |
| 		Database: onlineEntityUpdated,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) DeleteDatabase(ctx context.Context, req *DeleteDatabaseRequest) error {
 | |
| 	onlineInfo, err := d.onlineDAO.Get(ctx, req.ID)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("get online database info failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	draftInfo, err := d.draftDAO.Get(ctx, onlineInfo.GetDraftID())
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("get draft database info failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	tx := query.Use(d.db).Begin()
 | |
| 	if tx.Error != nil {
 | |
| 		return fmt.Errorf("start transaction failed, %v", tx.Error)
 | |
| 	}
 | |
| 
 | |
| 	defer func() {
 | |
| 		if r := recover(); r != nil {
 | |
| 			e := tx.Rollback()
 | |
| 			if e != nil {
 | |
| 				logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
 | |
| 			}
 | |
| 
 | |
| 			err = fmt.Errorf("catch panic: %v\nstack=%s", r, string(debug.Stack()))
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		if err != nil {
 | |
| 			e := tx.Rollback()
 | |
| 			if e != nil {
 | |
| 				logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
 | |
| 			}
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	err = d.draftDAO.DeleteWithTX(ctx, tx, draftInfo.ID)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("delete draft database info failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	err = d.onlineDAO.DeleteWithTX(ctx, tx, onlineInfo.ID)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("delete online database info failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	err = tx.Commit()
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("commit transaction failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	// delete draft physical table
 | |
| 	if draftInfo.ActualTableName != "" {
 | |
| 		_, err = d.rdb.DropTable(ctx, &rdb.DropTableRequest{
 | |
| 			TableName: draftInfo.ActualTableName,
 | |
| 		})
 | |
| 		if err != nil {
 | |
| 			logs.Errorf("drop draft physical table failed: %v, table_name=%s", err, draftInfo.ActualTableName)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// delete online physical table
 | |
| 	if onlineInfo.ActualTableName != "" {
 | |
| 		_, err = d.rdb.DropTable(ctx, &rdb.DropTableRequest{
 | |
| 			TableName: onlineInfo.ActualTableName,
 | |
| 		})
 | |
| 		if err != nil {
 | |
| 			logs.Errorf("drop online physical table failed: %v, table_name=%s", err, onlineInfo.ActualTableName)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) MGetDatabase(ctx context.Context, req *MGetDatabaseRequest) (*MGetDatabaseResponse, error) {
 | |
| 	if len(req.Basics) == 0 {
 | |
| 		return &MGetDatabaseResponse{
 | |
| 			Databases: []*entity2.Database{},
 | |
| 		}, nil
 | |
| 	}
 | |
| 
 | |
| 	onlineID2NeedSysFields := make(map[int64]bool)
 | |
| 	draftID2NeedSysFields := make(map[int64]bool)
 | |
| 
 | |
| 	uniqueOnlineIDs := make([]int64, 0)
 | |
| 	uniqueDraftIDs := make([]int64, 0)
 | |
| 	idMap := make(map[int64]bool)
 | |
| 	for _, basic := range req.Basics {
 | |
| 		if !idMap[basic.ID] {
 | |
| 			idMap[basic.ID] = true
 | |
| 			if basic.TableType == table.TableType_OnlineTable {
 | |
| 				uniqueOnlineIDs = append(uniqueOnlineIDs, basic.ID)
 | |
| 				onlineID2NeedSysFields[basic.ID] = basic.NeedSysFields
 | |
| 			} else {
 | |
| 				uniqueDraftIDs = append(uniqueDraftIDs, basic.ID)
 | |
| 				draftID2NeedSysFields[basic.ID] = basic.NeedSysFields
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	onlineDatabases, err := d.onlineDAO.MGet(ctx, uniqueOnlineIDs)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("batch get database info failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	draftDatabases, err := d.draftDAO.MGet(ctx, uniqueDraftIDs)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("batch get database info failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	for _, onlineDatabase := range onlineDatabases {
 | |
| 		if needSys, ok := onlineID2NeedSysFields[onlineDatabase.ID]; ok && needSys {
 | |
| 			if onlineDatabase.FieldList == nil {
 | |
| 				onlineDatabase.FieldList = make([]*database.FieldItem, 0, 3)
 | |
| 			}
 | |
| 			onlineDatabase.FieldList = append(onlineDatabase.FieldList, physicaltable.GetDisplayCreateTimeField(), physicaltable.GetDisplayUidField(), physicaltable.GetDisplayIDField())
 | |
| 		}
 | |
| 		if onlineDatabase.IconURI != "" {
 | |
| 			objURL, uRrr := d.storage.GetObjectUrl(ctx, onlineDatabase.IconURI)
 | |
| 			if uRrr == nil {
 | |
| 				onlineDatabase.IconURL = objURL
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 	for _, draftDatabase := range draftDatabases {
 | |
| 		if needSys, ok := draftID2NeedSysFields[draftDatabase.ID]; ok && needSys {
 | |
| 			if draftDatabase.FieldList == nil {
 | |
| 				draftDatabase.FieldList = make([]*database.FieldItem, 0, 3)
 | |
| 			}
 | |
| 			draftDatabase.FieldList = append(draftDatabase.FieldList, physicaltable.GetDisplayCreateTimeField(), physicaltable.GetDisplayUidField(), physicaltable.GetDisplayIDField())
 | |
| 		}
 | |
| 		if draftDatabase.IconURI != "" {
 | |
| 			objURL, uRrr := d.storage.GetObjectUrl(ctx, draftDatabase.IconURI)
 | |
| 			if uRrr == nil {
 | |
| 				draftDatabase.IconURL = objURL
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	databases := make([]*entity2.Database, 0)
 | |
| 	databases = append(databases, onlineDatabases...)
 | |
| 	databases = append(databases, draftDatabases...)
 | |
| 
 | |
| 	return &MGetDatabaseResponse{
 | |
| 		Databases: databases,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) ListDatabase(ctx context.Context, req *ListDatabaseRequest) (*ListDatabaseResponse, error) {
 | |
| 	filter := &entity2.DatabaseFilter{
 | |
| 		CreatorID: req.CreatorID,
 | |
| 		SpaceID:   req.SpaceID,
 | |
| 		TableName: req.TableName,
 | |
| 		AppID:     &req.AppID,
 | |
| 	}
 | |
| 
 | |
| 	page := &entity2.Pagination{
 | |
| 		Limit:  req.Limit,
 | |
| 		Offset: req.Offset,
 | |
| 	}
 | |
| 
 | |
| 	var databases []*entity2.Database
 | |
| 	var err error
 | |
| 	var count int64
 | |
| 	if req.TableType == table.TableType_OnlineTable {
 | |
| 		databases, count, err = d.onlineDAO.List(ctx, filter, page, req.OrderBy)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("list database failed: %v", err)
 | |
| 		}
 | |
| 	} else {
 | |
| 		databases, count, err = d.draftDAO.List(ctx, filter, page, req.OrderBy)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("list database failed: %v", err)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	for _, database := range databases {
 | |
| 		if database.IconURI != "" {
 | |
| 			objURL, uRrr := d.storage.GetObjectUrl(ctx, database.IconURI)
 | |
| 			if uRrr == nil {
 | |
| 				database.IconURL = objURL
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	var hasMore bool
 | |
| 	if count <= int64(req.Limit)+int64(req.Offset) {
 | |
| 		hasMore = false
 | |
| 	} else {
 | |
| 		hasMore = true
 | |
| 	}
 | |
| 
 | |
| 	return &ListDatabaseResponse{
 | |
| 		Databases:  databases,
 | |
| 		HasMore:    hasMore,
 | |
| 		TotalCount: count,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) AddDatabaseRecord(ctx context.Context, req *AddDatabaseRecordRequest) error {
 | |
| 	var tableInfo *entity2.Database
 | |
| 	var err error
 | |
| 
 | |
| 	if req.TableType == table.TableType_OnlineTable {
 | |
| 		tableInfo, err = d.onlineDAO.Get(ctx, req.DatabaseID)
 | |
| 	} else {
 | |
| 		tableInfo, err = d.draftDAO.Get(ctx, req.DatabaseID)
 | |
| 	}
 | |
| 
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("get table info failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	if tableInfo.RwMode == table.BotTableRWMode_ReadOnly {
 | |
| 		return errorx.New(errno.ErrMemoryDatabaseCannotAddData)
 | |
| 	}
 | |
| 
 | |
| 	physicalTableName := tableInfo.ActualTableName
 | |
| 	if physicalTableName == "" {
 | |
| 		return fmt.Errorf("physical table name is empty")
 | |
| 	}
 | |
| 
 | |
| 	fieldList := append(tableInfo.FieldList, physicaltable.GetCreateTimeField(), physicaltable.GetUidField(), physicaltable.GetIDField(), physicaltable.GetConnectIDField())
 | |
| 	fieldMap := slices.ToMap(fieldList, func(e *database.FieldItem) (string, *database.FieldItem) {
 | |
| 		return e.Name, e
 | |
| 	})
 | |
| 
 | |
| 	convertedRecords := make([]map[string]interface{}, 0, len(req.Records))
 | |
| 	ids, err := d.generator.GenMultiIDs(ctx, len(req.Records))
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	for index, recordMap := range req.Records {
 | |
| 		convertedRecord := make(map[string]interface{})
 | |
| 
 | |
| 		cid := consts.CozeConnectorID
 | |
| 		if req.ConnectorID != nil {
 | |
| 			cid = *req.ConnectorID
 | |
| 		}
 | |
| 		convertedRecord[database.DefaultUidColName] = req.UserID
 | |
| 		convertedRecord[database.DefaultCidColName] = cid
 | |
| 		convertedRecord[database.DefaultCreateTimeColName] = time.Now()
 | |
| 		convertedRecord[database.DefaultIDColName] = ids[index]
 | |
| 
 | |
| 		if _, ok := recordMap[database.DefaultIDColName]; ok {
 | |
| 			delete(recordMap, database.DefaultIDColName)
 | |
| 		}
 | |
| 
 | |
| 		for fieldName, value := range recordMap {
 | |
| 			if _, fOk := fieldMap[fieldName]; !fOk {
 | |
| 				return errorx.New(errno.ErrMemoryDatabaseFieldNotFoundCode, errorx.KV("msg", fmt.Sprintf("field %s not found in table definition", fieldName)))
 | |
| 			}
 | |
| 
 | |
| 			fieldInfo, _ := fieldMap[fieldName]
 | |
| 			if value == "" && fieldInfo.MustRequired {
 | |
| 				return fmt.Errorf("field %s's value is required", fieldName)
 | |
| 			}
 | |
| 
 | |
| 			physicalFieldName := fieldInfo.PhysicalName
 | |
| 			convertedValue, err := convertor.ConvertValueByType(value, fieldInfo.Type)
 | |
| 			if err != nil {
 | |
| 				return fmt.Errorf("convert value failed for field %s: %v, using original value", fieldName, err)
 | |
| 			}
 | |
| 
 | |
| 			convertedRecord[physicalFieldName] = convertedValue
 | |
| 		}
 | |
| 
 | |
| 		convertedRecords = append(convertedRecords, convertedRecord)
 | |
| 	}
 | |
| 
 | |
| 	_, err = d.rdb.InsertData(ctx, &rdb.InsertDataRequest{
 | |
| 		TableName: physicalTableName,
 | |
| 		Data:      convertedRecords,
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("insert data failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) UpdateDatabaseRecord(ctx context.Context, req *UpdateDatabaseRecordRequest) error {
 | |
| 	var tableInfo *database.Database
 | |
| 	var err error
 | |
| 
 | |
| 	if req.TableType == table.TableType_OnlineTable {
 | |
| 		tableInfo, err = d.onlineDAO.Get(ctx, req.DatabaseID)
 | |
| 	} else {
 | |
| 		tableInfo, err = d.draftDAO.Get(ctx, req.DatabaseID)
 | |
| 	}
 | |
| 
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("get table info failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	if tableInfo.RwMode == table.BotTableRWMode_ReadOnly {
 | |
| 		return errorx.New(errno.ErrMemoryDatabaseCannotAddData)
 | |
| 	}
 | |
| 
 | |
| 	physicalTableName := tableInfo.ActualTableName
 | |
| 	if physicalTableName == "" {
 | |
| 		return fmt.Errorf("physical table name is empty")
 | |
| 	}
 | |
| 
 | |
| 	fieldList := append(tableInfo.FieldList, physicaltable.GetCreateTimeField(), physicaltable.GetUidField(), physicaltable.GetIDField(), physicaltable.GetConnectIDField())
 | |
| 	fieldMap := slices.ToMap(fieldList, func(e *database.FieldItem) (string, *database.FieldItem) {
 | |
| 		return e.Name, e
 | |
| 	})
 | |
| 
 | |
| 	for _, record := range req.Records {
 | |
| 		idStr, exists := record[database.DefaultIDColName]
 | |
| 		if !exists {
 | |
| 			return fmt.Errorf("record must contain %s field for update", database.DefaultIDColName)
 | |
| 		}
 | |
| 
 | |
| 		id, err := strconv.ParseInt(idStr, 10, 64)
 | |
| 		if err != nil {
 | |
| 			return fmt.Errorf("invalid ID format: %v", err)
 | |
| 		}
 | |
| 
 | |
| 		updateData := make(map[string]interface{})
 | |
| 
 | |
| 		for fieldName, valueStr := range record {
 | |
| 			if fieldName == database.DefaultIDColName {
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			if _, fOk := fieldMap[fieldName]; !fOk {
 | |
| 				return errorx.New(errno.ErrMemoryDatabaseFieldNotFoundCode, errorx.KV("msg", fmt.Sprintf("field %s not found in table definition", fieldName)))
 | |
| 			}
 | |
| 
 | |
| 			fieldInfo, _ := fieldMap[fieldName]
 | |
| 			if valueStr == "" && fieldInfo.MustRequired {
 | |
| 				return fmt.Errorf("field %s's value is required", fieldName)
 | |
| 			}
 | |
| 
 | |
| 			physicalFieldName := fieldInfo.PhysicalName
 | |
| 			convertedValue, err := convertor.ConvertValueByType(valueStr, fieldInfo.Type)
 | |
| 			if err != nil {
 | |
| 				logs.Warnf("convert value failed for field %s: %v, using original value", fieldName, err)
 | |
| 				convertedValue = valueStr
 | |
| 			}
 | |
| 			updateData[physicalFieldName] = convertedValue
 | |
| 		}
 | |
| 
 | |
| 		if len(updateData) == 0 {
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		condition := &rdb.ComplexCondition{
 | |
| 			Conditions: []*rdb.Condition{
 | |
| 				{
 | |
| 					Field:    database.DefaultIDColName,
 | |
| 					Operator: entity3.OperatorEqual,
 | |
| 					Value:    id,
 | |
| 				},
 | |
| 			},
 | |
| 		}
 | |
| 
 | |
| 		if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite {
 | |
| 			cond := &rdb.Condition{
 | |
| 				Field:    database.DefaultUidColName,
 | |
| 				Operator: entity3.OperatorEqual,
 | |
| 				Value:    strconv.FormatInt(req.UserID, 10),
 | |
| 			}
 | |
| 
 | |
| 			condition.Conditions = append(condition.Conditions, cond)
 | |
| 		}
 | |
| 
 | |
| 		_, err = d.rdb.UpdateData(ctx, &rdb.UpdateDataRequest{
 | |
| 			TableName: physicalTableName,
 | |
| 			Data:      updateData,
 | |
| 			Where:     condition,
 | |
| 		})
 | |
| 		if err != nil {
 | |
| 			return fmt.Errorf("update data failed for ID %d: %v", id, err)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) DeleteDatabaseRecord(ctx context.Context, req *DeleteDatabaseRecordRequest) error {
 | |
| 	var tableInfo *entity2.Database
 | |
| 	var err error
 | |
| 
 | |
| 	if req.TableType == table.TableType_OnlineTable {
 | |
| 		tableInfo, err = d.onlineDAO.Get(ctx, req.DatabaseID)
 | |
| 	} else {
 | |
| 		tableInfo, err = d.draftDAO.Get(ctx, req.DatabaseID)
 | |
| 	}
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	if tableInfo.RwMode == table.BotTableRWMode_ReadOnly {
 | |
| 		return errorx.New(errno.ErrMemoryDatabaseCannotAddData)
 | |
| 	}
 | |
| 
 | |
| 	physicalTableName := tableInfo.ActualTableName
 | |
| 	if physicalTableName == "" {
 | |
| 		return fmt.Errorf("physical table name is empty")
 | |
| 	}
 | |
| 
 | |
| 	var ids []interface{}
 | |
| 	for _, record := range req.Records {
 | |
| 		idStr, exists := record[database.DefaultIDColName]
 | |
| 		if !exists {
 | |
| 			return fmt.Errorf("record must contain %s field for deletion", database.DefaultIDColName)
 | |
| 		}
 | |
| 
 | |
| 		id, err := strconv.ParseInt(idStr, 10, 64)
 | |
| 		if err != nil {
 | |
| 			return fmt.Errorf("invalid ID format: %v", err)
 | |
| 		}
 | |
| 
 | |
| 		ids = append(ids, id)
 | |
| 	}
 | |
| 
 | |
| 	condition := &rdb.ComplexCondition{
 | |
| 		Conditions: []*rdb.Condition{
 | |
| 			{
 | |
| 				Field:    database.DefaultIDColName,
 | |
| 				Operator: entity3.OperatorIn,
 | |
| 				Value:    ids,
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite {
 | |
| 		cond := &rdb.Condition{
 | |
| 			Field:    database.DefaultUidColName,
 | |
| 			Operator: entity3.OperatorEqual,
 | |
| 			Value:    strconv.FormatInt(req.UserID, 10),
 | |
| 		}
 | |
| 
 | |
| 		condition.Conditions = append(condition.Conditions, cond)
 | |
| 	}
 | |
| 
 | |
| 	_, err = d.rdb.DeleteData(ctx, &rdb.DeleteDataRequest{
 | |
| 		TableName: physicalTableName,
 | |
| 		Where:     condition,
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("delete data failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) ListDatabaseRecord(ctx context.Context, req *ListDatabaseRecordRequest) (*ListDatabaseRecordResponse, error) {
 | |
| 	var tableInfo *entity2.Database
 | |
| 	var err error
 | |
| 
 | |
| 	if req.TableType == table.TableType_OnlineTable {
 | |
| 		tableInfo, err = d.onlineDAO.Get(ctx, req.DatabaseID)
 | |
| 	} else {
 | |
| 		tableInfo, err = d.draftDAO.Get(ctx, req.DatabaseID)
 | |
| 	}
 | |
| 
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("get table info failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	physicalTableName := tableInfo.ActualTableName
 | |
| 	if physicalTableName == "" {
 | |
| 		return nil, fmt.Errorf("physical table name is empty")
 | |
| 	}
 | |
| 
 | |
| 	fieldNameToPhysical := make(map[string]string)
 | |
| 	physicalToFieldName := make(map[string]string)
 | |
| 	physicalToFieldType := make(map[string]table.FieldItemType)
 | |
| 
 | |
| 	for _, field := range tableInfo.FieldList {
 | |
| 		if field.AlterID > 0 {
 | |
| 			physicalName := physicaltable.GetFieldPhysicsName(field.AlterID)
 | |
| 			fieldNameToPhysical[field.Name] = physicalName
 | |
| 			physicalToFieldName[physicalName] = field.Name
 | |
| 			physicalToFieldType[physicalName] = field.Type
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	var complexCondition *rdb.ComplexCondition
 | |
| 
 | |
| 	if req.ConnectorID != nil && *req.ConnectorID > 0 {
 | |
| 		cond := &rdb.Condition{
 | |
| 			Field:    database.DefaultCidColName,
 | |
| 			Operator: entity3.OperatorEqual,
 | |
| 			Value:    *req.ConnectorID,
 | |
| 		}
 | |
| 
 | |
| 		complexCondition = &rdb.ComplexCondition{
 | |
| 			Conditions: []*rdb.Condition{cond},
 | |
| 		}
 | |
| 	}
 | |
| 	if req.TableType == table.TableType_DraftTable {
 | |
| 		if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite {
 | |
| 			cond := &rdb.Condition{
 | |
| 				Field:    database.DefaultUidColName,
 | |
| 				Operator: entity3.OperatorEqual,
 | |
| 				Value:    strconv.FormatInt(req.UserID, 10),
 | |
| 			}
 | |
| 
 | |
| 			if complexCondition == nil {
 | |
| 				complexCondition = &rdb.ComplexCondition{
 | |
| 					Conditions: []*rdb.Condition{cond},
 | |
| 				}
 | |
| 			} else {
 | |
| 				complexCondition.Conditions = append(complexCondition.Conditions, cond)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	limit := 50
 | |
| 	if req.Limit > 0 {
 | |
| 		limit = req.Limit
 | |
| 	}
 | |
| 
 | |
| 	orderBy := []*rdb.OrderBy{
 | |
| 		{
 | |
| 			Field:     database.DefaultCreateTimeColName,
 | |
| 			Direction: entity3.SortDirectionDesc,
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	selectResp, err := d.rdb.SelectData(ctx, &rdb.SelectDataRequest{
 | |
| 		TableName: physicalTableName,
 | |
| 		Fields:    []string{}, // Null means query all fields
 | |
| 		Where:     complexCondition,
 | |
| 		OrderBy:   orderBy,
 | |
| 		Limit:     &limit,
 | |
| 		Offset:    &req.Offset,
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("select data failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	if selectResp.ResultSet == nil {
 | |
| 		return &ListDatabaseRecordResponse{}, nil
 | |
| 	}
 | |
| 
 | |
| 	records := convertor.ConvertResultSetToString(selectResp.ResultSet, physicalToFieldName, physicalToFieldType)
 | |
| 
 | |
| 	var hasMore bool
 | |
| 	if selectResp.Total <= int64(req.Limit)+int64(req.Offset) {
 | |
| 		hasMore = false
 | |
| 	} else {
 | |
| 		hasMore = true
 | |
| 	}
 | |
| 
 | |
| 	return &ListDatabaseRecordResponse{
 | |
| 		Records:    records,
 | |
| 		FieldList:  tableInfo.FieldList,
 | |
| 		HasMore:    hasMore,
 | |
| 		TotalCount: selectResp.Total,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) GetDatabaseTemplate(ctx context.Context, req *GetDatabaseTemplateRequest) (*GetDatabaseTemplateResponse, error) {
 | |
| 	items := req.FieldItems
 | |
| 	tableName := req.TableName
 | |
| 
 | |
| 	file := xlsx.NewFile()
 | |
| 	sheet, err := file.AddSheet("Sheet1")
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	// add header
 | |
| 	header := sheet.AddRow()
 | |
| 	headerTitles := make([]string, 0)
 | |
| 	for i := range items {
 | |
| 		headerTitles = append(headerTitles, items[i].GetName())
 | |
| 	}
 | |
| 	for _, title := range headerTitles {
 | |
| 		cell := header.AddCell()
 | |
| 		cell.Value = title
 | |
| 	}
 | |
| 
 | |
| 	row := sheet.AddRow()
 | |
| 	for _, item := range items {
 | |
| 		row.AddCell().Value = physicaltable.GetTemplateTypeMap()[item.GetType()]
 | |
| 	}
 | |
| 	var buffer bytes.Buffer
 | |
| 	err = file.Write(&buffer)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	binaryData := buffer.Bytes()
 | |
| 	url, err := d.uploadFile(ctx, req.UserID, string(binaryData), tableName, "xlsx", nil)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &GetDatabaseTemplateResponse{
 | |
| 		Url: url,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) uploadFile(ctx context.Context, UserId int64, content string, bizType, fileType string, suffix *string) (string, error) {
 | |
| 	secret := createSecret(UserId, fileType)
 | |
| 	fileName := fmt.Sprintf("%d_%d_%s.%s", UserId, time.Now().UnixNano(), secret, fileType)
 | |
| 	if suffix != nil {
 | |
| 		fileName = fmt.Sprintf("%d_%d_%s_%s.%s", UserId, time.Now().UnixNano(), secret, *suffix, fileType)
 | |
| 	}
 | |
| 
 | |
| 	objectName := fmt.Sprintf("%s/%s", bizType, fileName)
 | |
| 	err := d.storage.PutObject(ctx, objectName, []byte(content))
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 
 | |
| 	url, err := d.storage.GetObjectUrl(ctx, objectName)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 
 | |
| 	return url, nil
 | |
| }
 | |
| 
 | |
// baseWord is the 62-character alphabet (digits plus upper- and lower-case
// ASCII letters) into which createSecret maps its output characters.
const baseWord = "1Aa2Bb3Cc4Dd5Ee6Ff7Gg8Hh9Ii0JjKkLlMmNnOoPpQqRrSsTtUuVvWwXxYyZz"

// createSecret returns a short alphanumeric token used to make uploaded object
// names hard to guess.
//
// It SHA-256 hashes a string built from uid, the current Unix time, fileType
// and a math/rand number, base64-encodes the digest, keeps the first 10
// characters and maps each one into baseWord. The result therefore never
// contains the '+', '/' or '=' characters of standard base64.
//
// The token is not a cryptographic secret: its only entropy is math/rand and
// the wall clock.
func createSecret(uid int64, fileType string) string {
	const secretLen = 10

	input := fmt.Sprintf("upload_%d_Ma*9)fhi_%d_gou_%s_rand_%d", uid, time.Now().Unix(), fileType, rand.Intn(100000))
	sum := sha256.Sum256([]byte(input))
	encoded := base64.StdEncoding.EncodeToString(sum[:])
	if len(encoded) > secretLen {
		encoded = encoded[:secretLen]
	}

	// Base64 output is pure ASCII, so indexing bytes is equivalent to
	// iterating runes here. Fill a byte slice instead of concatenating strings.
	out := make([]byte, len(encoded))
	for i := 0; i < len(encoded); i++ {
		out[i] = baseWord[int(encoded[i])%len(baseWord)]
	}
	return string(out)
}
 | |
| 
 | |
| func (d databaseService) ExecuteSQL(ctx context.Context, req *ExecuteSQLRequest) (*ExecuteSQLResponse, error) {
 | |
| 	var tableInfo *entity2.Database
 | |
| 	var err error
 | |
| 
 | |
| 	if req.TableType == table.TableType_OnlineTable {
 | |
| 		tableInfo, err = d.onlineDAO.Get(ctx, req.DatabaseID)
 | |
| 	} else {
 | |
| 		tableInfo, err = d.draftDAO.Get(ctx, req.DatabaseID)
 | |
| 	}
 | |
| 
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("get table info failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	if tableInfo.RwMode == table.BotTableRWMode_ReadOnly &&
 | |
| 		(req.OperateType == database.OperateType_Insert || req.OperateType == database.OperateType_Update ||
 | |
| 			req.OperateType == database.OperateType_Delete) {
 | |
| 		return nil, errorx.New(errno.ErrMemoryDatabaseCannotAddData)
 | |
| 	}
 | |
| 
 | |
| 	physicalTableName := tableInfo.ActualTableName
 | |
| 	if physicalTableName == "" {
 | |
| 		return nil, fmt.Errorf("physical table name is empty")
 | |
| 	}
 | |
| 
 | |
| 	fieldNameToPhysical := make(map[string]string)
 | |
| 	physicalToFieldName := make(map[string]string)
 | |
| 	physicalToFieldType := make(map[string]table.FieldItemType)
 | |
| 
 | |
| 	for _, field := range tableInfo.FieldList {
 | |
| 		if field.AlterID > 0 {
 | |
| 			physicalName := physicaltable.GetFieldPhysicsName(field.AlterID)
 | |
| 			fieldNameToPhysical[field.Name] = physicalName
 | |
| 			physicalToFieldName[physicalName] = field.Name
 | |
| 			physicalToFieldType[physicalName] = field.Type
 | |
| 		}
 | |
| 	}
 | |
| 	fieldNameToPhysical[database.DefaultIDDisplayColName] = database.DefaultIDColName
 | |
| 	fieldNameToPhysical[database.DefaultUidDisplayColName] = database.DefaultUidColName
 | |
| 	fieldNameToPhysical[database.DefaultCreateTimeDisplayColName] = database.DefaultCreateTimeColName
 | |
| 
 | |
| 	var resultSet *entity3.ResultSet
 | |
| 	var rowsAffected int64
 | |
| 
 | |
| 	switch req.OperateType {
 | |
| 	case database.OperateType_Custom:
 | |
| 		resultSet, err = d.executeCustomSQL(ctx, req, physicalTableName, tableInfo, fieldNameToPhysical)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 
 | |
| 	case database.OperateType_Select:
 | |
| 		resultSet, err = d.executeSelectSQL(ctx, req, physicalTableName, tableInfo, fieldNameToPhysical)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 
 | |
| 	case database.OperateType_Insert:
 | |
| 		resultSet, err = d.executeInsertSQL(ctx, req, physicalTableName, tableInfo)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 
 | |
| 	case database.OperateType_Update:
 | |
| 		rowsAffected, err = d.executeUpdateSQL(ctx, req, physicalTableName, tableInfo, fieldNameToPhysical)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 
 | |
| 	case database.OperateType_Delete:
 | |
| 		rowsAffected, err = d.executeDeleteSQL(ctx, req, physicalTableName, tableInfo, fieldNameToPhysical)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 
 | |
| 	default:
 | |
| 		return nil, fmt.Errorf("unsupported operation type: %v", req.OperateType)
 | |
| 	}
 | |
| 
 | |
| 	response := &ExecuteSQLResponse{
 | |
| 		FieldList: tableInfo.FieldList,
 | |
| 	}
 | |
| 
 | |
| 	if resultSet != nil && len(resultSet.Rows) > 0 {
 | |
| 		response.Records = convertor.ConvertResultSet(resultSet, physicalToFieldName, physicalToFieldType)
 | |
| 	} else {
 | |
| 		response.Records = make([]map[string]interface{}, 0)
 | |
| 	}
 | |
| 
 | |
| 	// process special system fields
 | |
| 	for _, record := range response.Records {
 | |
| 		if val, ok := record[database.DefaultUidColName]; ok {
 | |
| 			delete(record, database.DefaultUidColName)
 | |
| 			record[database.DefaultUidDisplayColName] = val
 | |
| 		}
 | |
| 		if val, ok := record[database.DefaultCreateTimeColName]; ok {
 | |
| 			delete(record, database.DefaultCreateTimeColName)
 | |
| 			record[database.DefaultCreateTimeDisplayColName] = val
 | |
| 		}
 | |
| 		if val, ok := record[database.DefaultIDColName]; ok {
 | |
| 			delete(record, database.DefaultIDColName)
 | |
| 			record[database.DefaultIDDisplayColName] = val
 | |
| 		}
 | |
| 		if _, ok := record[database.DefaultCidColName]; ok {
 | |
| 			delete(record, database.DefaultCidColName)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if resultSet != nil && resultSet.AffectedRows > 0 {
 | |
| 		response.RowsAffected = &resultSet.AffectedRows
 | |
| 	}
 | |
| 
 | |
| 	if rowsAffected > 0 {
 | |
| 		response.RowsAffected = &rowsAffected
 | |
| 	}
 | |
| 
 | |
| 	return response, nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) executeCustomSQL(ctx context.Context, req *ExecuteSQLRequest, physicalTableName string, tableInfo *entity2.Database, fieldNameToPhysical map[string]string) (*entity3.ResultSet, error) {
 | |
| 	var params []interface{}
 | |
| 	if req.SQL == nil || *req.SQL == "" {
 | |
| 		return nil, fmt.Errorf("SQL is empty")
 | |
| 	}
 | |
| 
 | |
| 	operation, err := sqlparser.NewSQLParser().GetSQLOperation(*req.SQL)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	if tableInfo.RwMode == table.BotTableRWMode_ReadOnly && (operation == sqlparsercontract.OperationTypeInsert || operation == sqlparsercontract.OperationTypeUpdate || operation == sqlparsercontract.OperationTypeDelete) {
 | |
| 		return nil, fmt.Errorf("unsupported operation type: %v", operation)
 | |
| 	}
 | |
| 
 | |
| 	if req.SQLParams != nil {
 | |
| 		params = make([]interface{}, 0, len(req.SQLParams))
 | |
| 		for _, param := range req.SQLParams {
 | |
| 			value := param.Value
 | |
| 			if param.ISNull {
 | |
| 				value = nil
 | |
| 			}
 | |
| 			params = append(params, value)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	tableColumnMapping := map[string]sqlparsercontract.TableColumn{
 | |
| 		tableInfo.TableName: {
 | |
| 			NewTableName: &physicalTableName,
 | |
| 			ColumnMap:    fieldNameToPhysical,
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	parsedSQL, err := sqlparser.NewSQLParser().ParseAndModifySQL(*req.SQL, tableColumnMapping)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("parse sql failed: %v", err)
 | |
| 	}
 | |
| 	// add rw mode
 | |
| 	if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite && len(req.UserID) != 0 {
 | |
| 		switch operation {
 | |
| 		case sqlparsercontract.OperationTypeSelect, sqlparsercontract.OperationTypeUpdate, sqlparsercontract.OperationTypeDelete:
 | |
| 			parsedSQL, err = sqlparser.NewSQLParser().AppendSQLFilter(parsedSQL, sqlparsercontract.SQLFilterOpAnd, fmt.Sprintf("%s = '%s'", database.DefaultUidColName, req.UserID))
 | |
| 			if err != nil {
 | |
| 				return nil, fmt.Errorf("append sql filter failed: %v", err)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 	insertResult := make([]map[string]interface{}, 0)
 | |
| 	if operation == sqlparsercontract.OperationTypeInsert {
 | |
| 		cid := consts.CozeConnectorID
 | |
| 		if req.ConnectorID != nil {
 | |
| 			cid = *req.ConnectorID
 | |
| 		}
 | |
| 		nums, err := sqlparser.NewSQLParser().GetInsertDataNums(parsedSQL)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 
 | |
| 		ids, err := d.generator.GenMultiIDs(ctx, nums)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 
 | |
| 		for _, id := range ids {
 | |
| 			insertResult = append(insertResult, map[string]interface{}{
 | |
| 				database.DefaultIDColName: id,
 | |
| 			})
 | |
| 		}
 | |
| 
 | |
| 		existingCols := make(map[string]bool)
 | |
| 		if req.SQLType == database.SQLType_Raw {
 | |
| 			iIDs := make([]interface{}, len(ids))
 | |
| 			for i, id := range ids {
 | |
| 				iIDs[i] = id
 | |
| 			}
 | |
| 			parsedSQL, _, err = sqlparser.NewSQLParser().AddColumnsToInsertSQL(parsedSQL, []sqlparsercontract.ColumnValue{
 | |
| 				{
 | |
| 					ColName: database.DefaultCidColName,
 | |
| 					Value:   cid,
 | |
| 				},
 | |
| 				{
 | |
| 					ColName: database.DefaultUidColName,
 | |
| 					Value:   req.UserID,
 | |
| 				},
 | |
| 			}, &sqlparsercontract.PrimaryKeyValue{ColName: database.DefaultIDColName, Values: iIDs}, false)
 | |
| 			if err != nil {
 | |
| 				return nil, fmt.Errorf("add columns to insert sql failed: %v", err)
 | |
| 			}
 | |
| 		} else if req.SQLType == database.SQLType_Parameterized {
 | |
| 			parsedSQL, existingCols, err = sqlparser.NewSQLParser().AddColumnsToInsertSQL(parsedSQL, []sqlparsercontract.ColumnValue{
 | |
| 				{
 | |
| 					ColName: database.DefaultCidColName,
 | |
| 				},
 | |
| 				{
 | |
| 					ColName: database.DefaultUidColName,
 | |
| 				},
 | |
| 			}, &sqlparsercontract.PrimaryKeyValue{ColName: database.DefaultIDColName}, true)
 | |
| 			if err != nil {
 | |
| 				return nil, fmt.Errorf("add columns to insert sql failed: %v", err)
 | |
| 			}
 | |
| 
 | |
| 			if nums > 0 {
 | |
| 				if len(params)%nums != 0 {
 | |
| 					return nil, fmt.Errorf("number of params is not a multiple of number of rows")
 | |
| 				}
 | |
| 				paramsPerRow := len(params) / nums
 | |
| 				newParams := make([]interface{}, 0)
 | |
| 				for i := 0; i < nums; i++ {
 | |
| 					newParams = append(newParams, params[i*paramsPerRow:(i+1)*paramsPerRow]...)
 | |
| 					if !existingCols[database.DefaultCidColName] {
 | |
| 						newParams = append(newParams, cid)
 | |
| 					}
 | |
| 					if !existingCols[database.DefaultUidColName] {
 | |
| 						newParams = append(newParams, req.UserID)
 | |
| 					}
 | |
| 					if !existingCols[database.DefaultIDColName] {
 | |
| 						newParams = append(newParams, ids[i])
 | |
| 					}
 | |
| 				}
 | |
| 				params = newParams
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	execResp, err := d.rdb.ExecuteSQL(ctx, &rdb.ExecuteSQLRequest{
 | |
| 		SQL:    parsedSQL,
 | |
| 		Params: params,
 | |
| 
 | |
| 		SQLType: entity3.SQLType(req.SQLType),
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("execute SQL failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	if operation == sqlparsercontract.OperationTypeInsert {
 | |
| 		if execResp.ResultSet == nil {
 | |
| 			execResp.ResultSet = &entity3.ResultSet{
 | |
| 				Rows: insertResult,
 | |
| 			}
 | |
| 		} else {
 | |
| 			execResp.ResultSet.Rows = insertResult
 | |
| 		}
 | |
| 	}
 | |
| 	return execResp.ResultSet, nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) executeSelectSQL(ctx context.Context, req *ExecuteSQLRequest, physicalTableName string, tableInfo *entity2.Database, fieldNameToPhysical map[string]string) (*entity3.ResultSet, error) {
 | |
| 	selectReq := &rdb.SelectDataRequest{
 | |
| 		TableName: physicalTableName,
 | |
| 		Limit:     int64PtrToIntPtr(req.Limit),
 | |
| 		Offset:    int64PtrToIntPtr(req.Offset),
 | |
| 	}
 | |
| 
 | |
| 	fieldList := append(tableInfo.FieldList, physicaltable.GetCreateTimeField(), physicaltable.GetUidField(), physicaltable.GetIDField(), physicaltable.GetConnectIDField())
 | |
| 	fieldMap := slices.ToMap(fieldList, func(e *database.FieldItem) (string, *database.FieldItem) {
 | |
| 		return strconv.FormatInt(e.AlterID, 10), e
 | |
| 	})
 | |
| 
 | |
| 	if req.SelectFieldList != nil && !req.SelectFieldList.IsDistinct && len(req.SelectFieldList.FieldID) > 0 {
 | |
| 		fields := make([]string, 0, len(req.SelectFieldList.FieldID))
 | |
| 		for _, fieldID := range req.SelectFieldList.FieldID {
 | |
| 			if _, exists := fieldMap[fieldID]; !exists {
 | |
| 				return nil, fmt.Errorf("fieldID %s does not exist", fieldID)
 | |
| 			}
 | |
| 
 | |
| 			field, _ := fieldMap[fieldID]
 | |
| 			fields = append(fields, field.PhysicalName)
 | |
| 		}
 | |
| 		selectReq.Fields = fields
 | |
| 	}
 | |
| 
 | |
| 	var complexCond *rdb.ComplexCondition
 | |
| 	var err error
 | |
| 	if req.Condition != nil {
 | |
| 		complexCond, err = convertCondition(ctx, req.Condition, fieldNameToPhysical, req.SQLParams)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("convert condition failed: %v", err)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// add rw mode
 | |
| 	if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite && req.UserID != "" {
 | |
| 		cond := &rdb.Condition{
 | |
| 			Field:    database.DefaultUidColName,
 | |
| 			Operator: entity3.OperatorEqual,
 | |
| 			Value:    req.UserID,
 | |
| 		}
 | |
| 
 | |
| 		if complexCond == nil {
 | |
| 			complexCond = &rdb.ComplexCondition{
 | |
| 				Conditions: []*rdb.Condition{cond},
 | |
| 			}
 | |
| 		} else {
 | |
| 			complexCond.Conditions = append(complexCond.Conditions, cond)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if complexCond != nil {
 | |
| 		selectReq.Where = complexCond
 | |
| 	}
 | |
| 
 | |
| 	if len(req.OrderByList) > 0 {
 | |
| 		orderBy := make([]*rdb.OrderBy, 0, len(req.OrderByList))
 | |
| 		for _, order := range req.OrderByList {
 | |
| 			physicalField := order.Field
 | |
| 			if mapped, exists := fieldNameToPhysical[order.Field]; exists {
 | |
| 				physicalField = mapped
 | |
| 			}
 | |
| 
 | |
| 			orderBy = append(orderBy, &rdb.OrderBy{
 | |
| 				Field:     physicalField,
 | |
| 				Direction: convertSortDirection(order.Direction),
 | |
| 			})
 | |
| 		}
 | |
| 		selectReq.OrderBy = orderBy
 | |
| 	}
 | |
| 
 | |
| 	selectResp, err := d.rdb.SelectData(ctx, selectReq)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("select data failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	return selectResp.ResultSet, nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) executeInsertSQL(ctx context.Context, req *ExecuteSQLRequest, physicalTableName string, tableInfo *entity2.Database) (*entity3.ResultSet, error) {
 | |
| 	if len(req.UpsertRows) == 0 {
 | |
| 		return nil, fmt.Errorf("no data to insert")
 | |
| 	}
 | |
| 
 | |
| 	insertData := make([]map[string]interface{}, 0, len(req.UpsertRows))
 | |
| 	ids, err := d.generator.GenMultiIDs(ctx, len(req.UpsertRows))
 | |
| 	if err != nil {
 | |
| 		return nil, errorx.WrapByCode(err, errno.ErrMemoryIDGenFailCode, errorx.KV("msg", "executeInsertSQL"))
 | |
| 	}
 | |
| 
 | |
| 	fieldList := append(tableInfo.FieldList, physicaltable.GetCreateTimeField(), physicaltable.GetUidField(), physicaltable.GetIDField(), physicaltable.GetConnectIDField())
 | |
| 	fieldMap := slices.ToMap(fieldList, func(e *database.FieldItem) (string, *database.FieldItem) {
 | |
| 		return strconv.FormatInt(e.AlterID, 10), e
 | |
| 	})
 | |
| 
 | |
| 	sqlParams := req.SQLParams
 | |
| 	i := 0
 | |
| 
 | |
| 	insertResult := make([]map[string]interface{}, 0, len(req.UpsertRows))
 | |
| 	for index, upsertRow := range req.UpsertRows {
 | |
| 		rowData := make(map[string]interface{})
 | |
| 
 | |
| 		cid := consts.CozeConnectorID
 | |
| 		if req.ConnectorID != nil {
 | |
| 			cid = *req.ConnectorID
 | |
| 		}
 | |
| 
 | |
| 		if req.UserID != "" {
 | |
| 			rowData[database.DefaultUidColName] = req.UserID
 | |
| 		}
 | |
| 		rowData[database.DefaultCidColName] = cid
 | |
| 		rowData[database.DefaultCreateTimeColName] = time.Now()
 | |
| 		rowData[database.DefaultIDColName] = ids[index]
 | |
| 
 | |
| 		for _, record := range upsertRow.Records {
 | |
| 			field, exists := fieldMap[record.FieldId]
 | |
| 			if !exists {
 | |
| 				return nil, errorx.New(errno.ErrMemoryDatabaseFieldNotFoundCode)
 | |
| 			}
 | |
| 
 | |
| 			fieldVal := sqlParams[i].Value
 | |
| 			if sqlParams[i].ISNull || fieldVal == nil {
 | |
| 				rowData[field.PhysicalName] = nil
 | |
| 				i++
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			convertedValue, err := convertor.ConvertValueByType(*fieldVal, field.Type)
 | |
| 			if err != nil {
 | |
| 				logs.Warnf("convert value failed: %v, using original value", err)
 | |
| 				rowData[field.PhysicalName] = *fieldVal
 | |
| 			} else {
 | |
| 				rowData[field.PhysicalName] = convertedValue
 | |
| 			}
 | |
| 			i++
 | |
| 		}
 | |
| 
 | |
| 		insertData = append(insertData, rowData)
 | |
| 		insertResult = append(insertResult, map[string]interface{}{
 | |
| 			database.DefaultIDColName: ids[index],
 | |
| 		})
 | |
| 	}
 | |
| 
 | |
| 	insertResp, err := d.rdb.InsertData(ctx, &rdb.InsertDataRequest{
 | |
| 		TableName: physicalTableName,
 | |
| 		Data:      insertData,
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("insert data failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	return &entity3.ResultSet{
 | |
| 		Rows:         insertResult,
 | |
| 		AffectedRows: insertResp.AffectedRows,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) executeUpdateSQL(ctx context.Context, req *ExecuteSQLRequest, physicalTableName string, tableInfo *entity2.Database, fieldNameToPhysical map[string]string) (int64, error) {
 | |
| 	if len(req.UpsertRows) == 0 || req.Condition == nil {
 | |
| 		return -1, fmt.Errorf("missing update data or condition")
 | |
| 	}
 | |
| 
 | |
| 	fieldList := append(tableInfo.FieldList, physicaltable.GetCreateTimeField(), physicaltable.GetUidField(), physicaltable.GetIDField(), physicaltable.GetConnectIDField())
 | |
| 	fieldMap := slices.ToMap(fieldList, func(e *database.FieldItem) (string, *database.FieldItem) {
 | |
| 		return strconv.FormatInt(e.AlterID, 10), e
 | |
| 	})
 | |
| 
 | |
| 	updateData := make(map[string]interface{})
 | |
| 	index := 0
 | |
| 	for _, record := range req.UpsertRows[0].Records {
 | |
| 		field, exists := fieldMap[record.FieldId]
 | |
| 		if !exists {
 | |
| 			return -1, errorx.New(errno.ErrMemoryDatabaseFieldNotFoundCode)
 | |
| 		}
 | |
| 
 | |
| 		param := req.SQLParams[index]
 | |
| 		fieldVal := param.Value
 | |
| 		index++
 | |
| 		if param.ISNull || fieldVal == nil {
 | |
| 			updateData[field.PhysicalName] = nil
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		convertedValue, err := convertor.ConvertValueByType(*fieldVal, field.Type)
 | |
| 		if err != nil {
 | |
| 			logs.Warnf("convert value failed: %v, using original value", err)
 | |
| 			updateData[field.PhysicalName] = *fieldVal
 | |
| 		} else {
 | |
| 			updateData[field.PhysicalName] = convertedValue
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	condParams := req.SQLParams[index:]
 | |
| 	complexCond, err := convertCondition(ctx, req.Condition, fieldNameToPhysical, condParams)
 | |
| 	if err != nil {
 | |
| 		return -1, fmt.Errorf("convert condition failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	// add rw mode
 | |
| 	if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite && req.UserID != "" {
 | |
| 		cond := &rdb.Condition{
 | |
| 			Field:    database.DefaultUidColName,
 | |
| 			Operator: entity3.OperatorEqual,
 | |
| 			Value:    req.UserID,
 | |
| 		}
 | |
| 
 | |
| 		if complexCond == nil {
 | |
| 			complexCond = &rdb.ComplexCondition{
 | |
| 				Conditions: []*rdb.Condition{cond},
 | |
| 			}
 | |
| 		} else {
 | |
| 			complexCond.Conditions = append(complexCond.Conditions, cond)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	updateResp, err := d.rdb.UpdateData(ctx, &rdb.UpdateDataRequest{
 | |
| 		TableName: physicalTableName,
 | |
| 		Data:      updateData,
 | |
| 		Where:     complexCond,
 | |
| 		Limit:     int64PtrToIntPtr(req.Limit),
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return -1, fmt.Errorf("update data failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	return updateResp.AffectedRows, nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) executeDeleteSQL(ctx context.Context, req *ExecuteSQLRequest, physicalTableName string, tableInfo *entity2.Database, fieldNameToPhysical map[string]string) (int64, error) {
 | |
| 	if req.Condition == nil {
 | |
| 		return -1, fmt.Errorf("missing delete condition")
 | |
| 	}
 | |
| 
 | |
| 	complexCond, err := convertCondition(ctx, req.Condition, fieldNameToPhysical, req.SQLParams)
 | |
| 	if err != nil {
 | |
| 		return -1, fmt.Errorf("convert condition failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	// add rw mode
 | |
| 	if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite && req.UserID != "" {
 | |
| 		cond := &rdb.Condition{
 | |
| 			Field:    database.DefaultUidColName,
 | |
| 			Operator: entity3.OperatorEqual,
 | |
| 			Value:    req.UserID,
 | |
| 		}
 | |
| 
 | |
| 		if complexCond == nil {
 | |
| 			complexCond = &rdb.ComplexCondition{
 | |
| 				Conditions: []*rdb.Condition{cond},
 | |
| 			}
 | |
| 		} else {
 | |
| 			complexCond.Conditions = append(complexCond.Conditions, cond)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	deleteResp, err := d.rdb.DeleteData(ctx, &rdb.DeleteDataRequest{
 | |
| 		TableName: physicalTableName,
 | |
| 		Where:     complexCond,
 | |
| 		Limit:     int64PtrToIntPtr(req.Limit),
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return -1, fmt.Errorf("delete data failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	return deleteResp.AffectedRows, nil
 | |
| }
 | |
| 
 | |
// int64PtrToIntPtr converts an optional int64 into an optional int.
// A nil input yields nil. Otherwise the result points to a new int holding
// int(*i64ptr), so the input is never aliased. The conversion truncates
// exactly as int(x) does.
func int64PtrToIntPtr(i64ptr *int64) *int {
	if i64ptr != nil {
		v := int(*i64ptr)
		return &v
	}
	return nil
}
 | |
| 
 | |
| func convertSortDirection(direction table.SortDirection) entity3.SortDirection {
 | |
| 	if direction == table.SortDirection_Desc {
 | |
| 		return entity3.SortDirectionDesc
 | |
| 	}
 | |
| 	return entity3.SortDirectionAsc
 | |
| }
 | |
| 
 | |
| func convertCondition(ctx context.Context, cond *database.ComplexCondition, fieldMap map[string]string, params []*database.SQLParamVal) (*rdb.ComplexCondition, error) {
 | |
| 	if cond == nil {
 | |
| 		return nil, nil
 | |
| 	}
 | |
| 
 | |
| 	result := &rdb.ComplexCondition{
 | |
| 		Operator: convertor.ConvertLogicOperator(cond.Logic),
 | |
| 	}
 | |
| 
 | |
| 	index := 0
 | |
| 	if len(cond.Conditions) > 0 {
 | |
| 		conditions := make([]*rdb.Condition, 0, len(cond.Conditions))
 | |
| 		for _, c := range cond.Conditions {
 | |
| 			leftField := c.Left
 | |
| 			if mapped, exists := fieldMap[c.Left]; exists {
 | |
| 				leftField = mapped
 | |
| 			}
 | |
| 
 | |
| 			if c.Operation == database.Operation_IS_NULL || c.Operation == database.Operation_IS_NOT_NULL {
 | |
| 				conditions = append(conditions, &rdb.Condition{
 | |
| 					Field:    leftField,
 | |
| 					Operator: convertor.ConvertOperator(c.Operation),
 | |
| 				})
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			if c.Operation == database.Operation_IN || c.Operation == database.Operation_NOT_IN {
 | |
| 				// c.Right: example: (?,?)
 | |
| 				qCount := 0
 | |
| 				for i := 0; i < len(c.Right); i++ {
 | |
| 					if c.Right[i] == '?' {
 | |
| 						qCount++
 | |
| 					}
 | |
| 				}
 | |
| 				if qCount == 0 {
 | |
| 					return nil, fmt.Errorf("IN/NOT_IN condition right side must contain ? placeholders")
 | |
| 				}
 | |
| 				vals := make([]interface{}, 0, qCount)
 | |
| 				for j := 0; j < qCount; j++ {
 | |
| 					if index >= len(params) {
 | |
| 						return nil, fmt.Errorf("not enough params for IN/NOT_IN condition")
 | |
| 					}
 | |
| 					if params[index].ISNull || params[index].Value == nil {
 | |
| 						index++
 | |
| 						continue
 | |
| 					}
 | |
| 					vals = append(vals, decryptSysUUIDKey(ctx, leftField, *params[index].Value))
 | |
| 					index++
 | |
| 				}
 | |
| 				conditions = append(conditions, &rdb.Condition{
 | |
| 					Field:    leftField,
 | |
| 					Operator: convertor.ConvertOperator(c.Operation),
 | |
| 					Value:    vals,
 | |
| 				})
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			if params[index].ISNull || params[index].Value == nil {
 | |
| 				index++
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			conditions = append(conditions, &rdb.Condition{
 | |
| 				Field:    leftField,
 | |
| 				Operator: convertor.ConvertOperator(c.Operation),
 | |
| 				Value:    decryptSysUUIDKey(ctx, leftField, *params[index].Value),
 | |
| 			})
 | |
| 			index++
 | |
| 		}
 | |
| 		result.Conditions = conditions
 | |
| 	}
 | |
| 	// if cond.NestedConditions != nil {
 | |
| 	//	nested, err := convertCondition(cond.NestedConditions, fieldMap, params)
 | |
| 	//	if err != nil {
 | |
| 	//		return nil, err
 | |
| 	//	}
 | |
| 	//	result.NestedConditions = []*rdb.ComplexCondition{nested}
 | |
| 	// }
 | |
| 
 | |
| 	return result, nil
 | |
| }
 | |
| 
 | |
| func decryptSysUUIDKey(ctx context.Context, leftField, value string) string {
 | |
| 	if leftField == database.DefaultUidDisplayColName || leftField == database.DefaultUidColName {
 | |
| 		decryptVal := crossvariables.DefaultSVC().DecryptSysUUIDKey(ctx, value)
 | |
| 		if decryptVal != nil {
 | |
| 			value = decryptVal.ConnectorUID
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return value
 | |
| }
 | |
| 
 | |
| func (d databaseService) BindDatabase(ctx context.Context, req *BindDatabaseToAgentRequest) error {
 | |
| 	draft, err := d.MGetDatabase(ctx, &MGetDatabaseRequest{
 | |
| 		Basics: []*database.DatabaseBasic{
 | |
| 			{
 | |
| 				ID:        req.DraftDatabaseID,
 | |
| 				TableType: table.TableType_DraftTable,
 | |
| 			},
 | |
| 		},
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	if len(draft.Databases) == 0 {
 | |
| 		return fmt.Errorf("online table not found, id: %d", req.DraftDatabaseID)
 | |
| 	}
 | |
| 
 | |
| 	onlineID := draft.Databases[0].GetOnlineID()
 | |
| 	relations := []*database.AgentToDatabase{
 | |
| 		{
 | |
| 			AgentID:    req.AgentID,
 | |
| 			DatabaseID: onlineID,
 | |
| 			TableType:  table.TableType_OnlineTable,
 | |
| 		},
 | |
| 		{
 | |
| 			AgentID:    req.AgentID,
 | |
| 			DatabaseID: req.DraftDatabaseID,
 | |
| 			TableType:  table.TableType_DraftTable,
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	_, err = d.agentToDatabaseDAO.BatchCreate(ctx, relations)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("failed to bind databases to agent: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) UnBindDatabase(ctx context.Context, req *UnBindDatabaseToAgentRequest) error {
 | |
| 	draft, err := d.MGetDatabase(ctx, &MGetDatabaseRequest{
 | |
| 		Basics: []*database.DatabaseBasic{
 | |
| 			{
 | |
| 				ID:        req.DraftDatabaseID,
 | |
| 				TableType: table.TableType_DraftTable,
 | |
| 			},
 | |
| 		},
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	if len(draft.Databases) == 0 {
 | |
| 		return fmt.Errorf("online table not found, id: %d", req.DraftDatabaseID)
 | |
| 	}
 | |
| 
 | |
| 	onlineID := draft.Databases[0].GetOnlineID()
 | |
| 	relations := []*database.AgentToDatabaseBasic{
 | |
| 		{
 | |
| 			AgentID:    req.AgentID,
 | |
| 			DatabaseID: onlineID,
 | |
| 		},
 | |
| 		{
 | |
| 			AgentID:    req.AgentID,
 | |
| 			DatabaseID: req.DraftDatabaseID,
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	err = d.agentToDatabaseDAO.BatchDelete(ctx, relations)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("failed to unbind databases from agent: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) MGetDatabaseByAgentID(ctx context.Context, req *MGetDatabaseByAgentIDRequest) (*MGetDatabaseByAgentIDResponse, error) {
 | |
| 	if req == nil {
 | |
| 		return nil, fmt.Errorf("invalid request: request is nil")
 | |
| 	}
 | |
| 
 | |
| 	relations, err := d.agentToDatabaseDAO.ListByAgentID(ctx, req.AgentID, req.TableType)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	mGetBasics := make([]*database.DatabaseBasic, 0, len(relations))
 | |
| 	for _, relation := range relations {
 | |
| 		mGetBasics = append(mGetBasics, &database.DatabaseBasic{
 | |
| 			ID:            relation.DatabaseID,
 | |
| 			TableType:     req.TableType,
 | |
| 			NeedSysFields: req.NeedSysFields,
 | |
| 		})
 | |
| 	}
 | |
| 	databases, err := d.MGetDatabase(ctx, &MGetDatabaseRequest{Basics: mGetBasics})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &MGetDatabaseByAgentIDResponse{
 | |
| 		Databases: databases.Databases,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| // PublishDatabase return online database according to draft database info
 | |
| func (d databaseService) PublishDatabase(ctx context.Context, req *PublishDatabaseRequest) (*PublishDatabaseResponse, error) {
 | |
| 	if req == nil {
 | |
| 		return nil, fmt.Errorf("invalid request: request is nil")
 | |
| 	}
 | |
| 
 | |
| 	relationResp, err := d.MGetRelationsByAgentID(ctx, &MGetRelationsByAgentIDRequest{
 | |
| 		AgentID:   req.AgentID,
 | |
| 		TableType: table.TableType_DraftTable,
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	if len(relationResp.Relations) == 0 {
 | |
| 		return &PublishDatabaseResponse{}, nil
 | |
| 	}
 | |
| 
 | |
| 	dBasics := make([]*database.DatabaseBasic, 0, len(relationResp.Relations))
 | |
| 	for _, draftR := range relationResp.Relations {
 | |
| 		dBasics = append(dBasics, &database.DatabaseBasic{
 | |
| 			ID:            draftR.DatabaseID,
 | |
| 			TableType:     table.TableType_DraftTable,
 | |
| 			NeedSysFields: false,
 | |
| 		})
 | |
| 	}
 | |
| 
 | |
| 	draftDatabaseResp, err := d.MGetDatabase(ctx, &MGetDatabaseRequest{
 | |
| 		Basics: dBasics,
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	oBasics := make([]*database.DatabaseBasic, 0, len(draftDatabaseResp.Databases))
 | |
| 	for _, draft := range draftDatabaseResp.Databases {
 | |
| 		oBasics = append(oBasics, &database.DatabaseBasic{
 | |
| 			ID:            draft.GetOnlineID(),
 | |
| 			TableType:     table.TableType_OnlineTable,
 | |
| 			NeedSysFields: false,
 | |
| 		})
 | |
| 	}
 | |
| 
 | |
| 	onlineDatabaseResp, err := d.MGetDatabase(ctx, &MGetDatabaseRequest{
 | |
| 		Basics: oBasics,
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	results := make([]*bot_common.Database, 0, len(onlineDatabaseResp.Databases))
 | |
| 	for _, online := range onlineDatabaseResp.Databases {
 | |
| 		fields := make([]*bot_common.FieldItem, 0, len(online.FieldList))
 | |
| 		for _, field := range online.FieldList {
 | |
| 			fields = append(fields, &bot_common.FieldItem{
 | |
| 				Name:         ptr.Of(field.Name),
 | |
| 				Desc:         ptr.Of(field.Desc),
 | |
| 				Type:         ptr.Of(bot_common.FieldItemType(field.Type)),
 | |
| 				MustRequired: ptr.Of(field.MustRequired),
 | |
| 				AlterId:      ptr.Of(field.AlterID),
 | |
| 				Id:           ptr.Of(int64(0)),
 | |
| 			})
 | |
| 		}
 | |
| 
 | |
| 		results = append(results, &bot_common.Database{
 | |
| 			TableId:   ptr.Of(strconv.FormatInt(online.ID, 10)),
 | |
| 			TableName: ptr.Of(online.TableName),
 | |
| 			TableDesc: ptr.Of(online.TableDesc),
 | |
| 			FieldList: fields,
 | |
| 			RWMode:    ptr.Of(bot_common.BotTableRWMode(online.RwMode)),
 | |
| 		})
 | |
| 	}
 | |
| 
 | |
| 	return &PublishDatabaseResponse{
 | |
| 		OnlineDatabases: results,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) MGetRelationsByAgentID(ctx context.Context, req *MGetRelationsByAgentIDRequest) (*MGetRelationsByAgentIDResponse, error) {
 | |
| 	if req == nil {
 | |
| 		return nil, fmt.Errorf("invalid request: request is nil")
 | |
| 	}
 | |
| 
 | |
| 	relations, err := d.agentToDatabaseDAO.ListByAgentID(ctx, req.AgentID, req.TableType)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &MGetRelationsByAgentIDResponse{
 | |
| 		Relations: relations,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) GetDatabaseTableSchema(ctx context.Context, req *GetDatabaseTableSchemaRequest) (*GetDatabaseTableSchemaResponse, error) {
 | |
| 	parser := &sheet.TosTableParser{
 | |
| 		UserID:         req.UserID,
 | |
| 		DocumentSource: database.DocumentSourceType_Document,
 | |
| 		TosURI:         req.TosURL,
 | |
| 		TosServ:        d.storage,
 | |
| 	}
 | |
| 
 | |
| 	res, extra, err := parser.GetTableDataBySheetIDx(ctx, entity2.TableReaderMeta{
 | |
| 		TosMaxLine:    100000,
 | |
| 		HeaderLineIdx: req.TableSheet.HeaderLineIdx,
 | |
| 		SheetId:       req.TableSheet.SheetID,
 | |
| 		StartLineIdx:  req.TableSheet.StartLineIdx,
 | |
| 		ReaderMethod:  database.TableReadDataMethodHead,
 | |
| 		ReadLineCnt:   20,
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	res.Columns, err = parser.PredictColumnType(res.Columns, res.SampleData, req.TableSheet.SheetID, req.TableSheet.StartLineIdx)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	resp := &GetDatabaseTableSchemaResponse{}
 | |
| 	if req.TableDataType == table.TableDataType_AllData || req.TableDataType == table.TableDataType_OnlyPreview {
 | |
| 		previewData, tErr := parser.TransferPreviewData(ctx, res.Columns, res.SampleData, 20)
 | |
| 		if tErr != nil {
 | |
| 			return resp, tErr
 | |
| 		}
 | |
| 		resp.PreviewData = previewData
 | |
| 	}
 | |
| 	resp.TableMeta = res.Columns
 | |
| 	resp.SheetList = extra.Sheets
 | |
| 
 | |
| 	return resp, nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) ValidateDatabaseTableSchema(ctx context.Context, req *ValidateDatabaseTableSchemaRequest) (*ValidateDatabaseTableSchemaResponse, error) {
 | |
| 	parser := &sheet.TosTableParser{
 | |
| 		UserID:         req.UserID,
 | |
| 		DocumentSource: database.DocumentSourceType_Document,
 | |
| 		TosURI:         req.TosURL,
 | |
| 		TosServ:        d.storage,
 | |
| 	}
 | |
| 
 | |
| 	res, sheetRes, err := parser.GetTableDataBySheetIDx(ctx, entity2.TableReaderMeta{
 | |
| 		TosMaxLine:    100000,
 | |
| 		HeaderLineIdx: req.TableSheet.HeaderLineIdx,
 | |
| 		SheetId:       req.TableSheet.SheetID,
 | |
| 		StartLineIdx:  req.TableSheet.StartLineIdx,
 | |
| 		ReaderMethod:  database.TableReadDataMethodAll,
 | |
| 		ReadLineCnt:   20,
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	valid, invalidMsg := sheet.CheckSheetIsValid(req.Fields, res.Columns, sheetRes)
 | |
| 	return &ValidateDatabaseTableSchemaResponse{
 | |
| 		Valid:      valid,
 | |
| 		InvalidMsg: invalidMsg,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) SubmitDatabaseInsertTask(ctx context.Context, req *SubmitDatabaseInsertTaskRequest) error {
 | |
| 	var err error
 | |
| 	failKey := onlineFailReasonKey
 | |
| 	if req.TableType == table.TableType_DraftTable {
 | |
| 		failKey = draftFailReasonKey
 | |
| 	}
 | |
| 
 | |
| 	defer func() {
 | |
| 		if r := recover(); r != nil {
 | |
| 			errMsg := fmt.Sprintf("panic: %v", r)
 | |
| 			d.cache.Set(ctx, fmt.Sprintf(failKey, req.DatabaseID, req.UserID), errMsg, redisKeyTimeOut)
 | |
| 			err = fmt.Errorf("panic: %v", r)
 | |
| 			return
 | |
| 		}
 | |
| 		if err != nil {
 | |
| 			d.cache.Set(ctx, fmt.Sprintf(failKey, req.DatabaseID, req.UserID), err.Error(), redisKeyTimeOut)
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	parser := &sheet.TosTableParser{
 | |
| 		UserID:         req.UserID,
 | |
| 		DocumentSource: database.DocumentSourceType_Document,
 | |
| 		TosURI:         req.FileURI,
 | |
| 		TosServ:        d.storage,
 | |
| 	}
 | |
| 	parseData, extra, err := parser.GetTableDataBySheetIDx(ctx, entity2.TableReaderMeta{
 | |
| 		TosMaxLine:    100000,
 | |
| 		SheetId:       req.TableSheet.SheetID,
 | |
| 		HeaderLineIdx: req.TableSheet.HeaderLineIdx,
 | |
| 		StartLineIdx:  req.TableSheet.StartLineIdx,
 | |
| 		ReaderMethod:  database.TableReadDataMethodAll,
 | |
| 	},
 | |
| 	)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	err = d.initializeCache(ctx, req, parseData, extra)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	columns := parseData.Columns
 | |
| 
 | |
| 	records := make([]map[string]string, 0, len(parseData.SampleData))
 | |
| 	for _, data := range parseData.SampleData {
 | |
| 		record := make(map[string]string)
 | |
| 		for i, column := range columns {
 | |
| 			record[column.ColumnName] = data[i]
 | |
| 		}
 | |
| 		records = append(records, record)
 | |
| 	}
 | |
| 
 | |
| 	batchSize := 20
 | |
| 	go func() {
 | |
| 		defer func() {
 | |
| 			if r := recover(); r != nil {
 | |
| 				errMsg := fmt.Sprintf("panic: %v", r)
 | |
| 				d.cache.Set(ctx, fmt.Sprintf(failKey, req.DatabaseID, req.UserID), errMsg, redisKeyTimeOut)
 | |
| 			}
 | |
| 		}()
 | |
| 
 | |
| 		for i := 0; i < len(records); i += batchSize {
 | |
| 			end := i + batchSize
 | |
| 			if end > len(records) {
 | |
| 				end = len(records)
 | |
| 			}
 | |
| 			batchRecords := records[i:end]
 | |
| 			err = d.AddDatabaseRecord(ctx, &AddDatabaseRecordRequest{
 | |
| 				DatabaseID:  req.DatabaseID,
 | |
| 				TableType:   req.TableType,
 | |
| 				ConnectorID: req.ConnectorID,
 | |
| 				UserID:      req.UserID,
 | |
| 				Records:     batchRecords,
 | |
| 			})
 | |
| 			if err != nil {
 | |
| 				d.cache.Set(ctx, fmt.Sprintf(failKey, req.DatabaseID, req.UserID), err.Error(), redisKeyTimeOut)
 | |
| 				return
 | |
| 			}
 | |
| 
 | |
| 			err = d.increaseProgress(ctx, req, int64(len(batchRecords)))
 | |
| 			if err != nil {
 | |
| 				d.cache.Set(ctx, fmt.Sprintf(failKey, req.DatabaseID, req.UserID), err.Error(), redisKeyTimeOut)
 | |
| 				return
 | |
| 			}
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) GetDatabaseFileProgressData(ctx context.Context, req *GetDatabaseFileProgressDataRequest) (*GetDatabaseFileProgressDataResponse, error) {
 | |
| 	totalKey := onlineTotalCountKey
 | |
| 	if req.TableType == table.TableType_DraftTable {
 | |
| 		totalKey = draftTotalCountKey
 | |
| 	}
 | |
| 	progressKey := onlineProgressKey
 | |
| 	if req.TableType == table.TableType_DraftTable {
 | |
| 		progressKey = draftProgressKey
 | |
| 	}
 | |
| 	failKey := onlineFailReasonKey
 | |
| 	if req.TableType == table.TableType_DraftTable {
 | |
| 		failKey = draftFailReasonKey
 | |
| 	}
 | |
| 	currentFileName := onlineCurrentFileName
 | |
| 	if req.TableType == table.TableType_DraftTable {
 | |
| 		currentFileName = draftCurrentFileName
 | |
| 	}
 | |
| 	totalNum, err := d.cache.Get(ctx, fmt.Sprintf(totalKey, req.DatabaseID, req.UserID)).Int64()
 | |
| 	if err != nil && !errors.Is(err, cache.Nil) {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	progressNum, err := d.cache.Get(ctx, fmt.Sprintf(progressKey, req.DatabaseID, req.UserID)).Int64()
 | |
| 	if err != nil && !errors.Is(err, cache.Nil) {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	failReason, err := d.cache.Get(ctx, fmt.Sprintf(failKey, req.DatabaseID, req.UserID)).Result()
 | |
| 	if err != nil && !errors.Is(err, cache.Nil) {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	fileName, err := d.cache.Get(ctx, fmt.Sprintf(currentFileName, req.DatabaseID, req.UserID)).Result()
 | |
| 	if err != nil && !errors.Is(err, cache.Nil) {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	resp := &GetDatabaseFileProgressDataResponse{}
 | |
| 	if totalNum == 0 {
 | |
| 		resp.FileName = ""
 | |
| 		resp.Progress = 100
 | |
| 	} else {
 | |
| 		resp.FileName = fileName
 | |
| 		resp.Progress = int32(float32(progressNum) / float32(totalNum) * 100)
 | |
| 		resp.StatusDescript = ptr.Of(failReason)
 | |
| 	}
 | |
| 	return resp, nil
 | |
| }
 | |
| 
 | |
// Cache key templates for a draft table's file-import progress. Each template
// is formatted with (databaseID, userID).
const (
	draftTotalCountKey   = "database_file_%d_%d_draft_total"
	draftProgressKey     = "database_file_%d_%d_draft_progress"
	draftFailReasonKey   = "database_file_%d_%d_draft_fail_reason"
	draftCurrentFileName = "database_file_%d_%d_draft_file_name"
)

// Cache key templates for an online table's file-import progress. Each
// template is formatted with (databaseID, userID).
const (
	onlineTotalCountKey   = "database_file_%d_%d_online_total"
	onlineProgressKey     = "database_file_%d_%d_online_progress"
	onlineFailReasonKey   = "database_file_%d_%d_online_fail_reason"
	onlineCurrentFileName = "database_file_%d_%d_online_file_name"
)

// redisKeyTimeOut is the expiration passed with every import-progress cache write.
const redisKeyTimeOut = 12 * time.Hour
 | |
| 
 | |
| func (d databaseService) initializeCache(ctx context.Context, req *SubmitDatabaseInsertTaskRequest, parseData *entity2.TableReaderSheetData, extra *entity2.ExcelExtraInfo) error {
 | |
| 	tableType := req.TableType
 | |
| 	userID := req.UserID
 | |
| 	databaseID := req.DatabaseID
 | |
| 
 | |
| 	totalKey := onlineTotalCountKey
 | |
| 	if tableType == table.TableType_DraftTable {
 | |
| 		totalKey = draftTotalCountKey
 | |
| 	}
 | |
| 	currentFileName := onlineCurrentFileName
 | |
| 	if tableType == table.TableType_DraftTable {
 | |
| 		currentFileName = draftCurrentFileName
 | |
| 	}
 | |
| 	progressKey := onlineProgressKey
 | |
| 	if tableType == table.TableType_DraftTable {
 | |
| 		progressKey = draftProgressKey
 | |
| 	}
 | |
| 	failKey := onlineFailReasonKey
 | |
| 	if tableType == table.TableType_DraftTable {
 | |
| 		failKey = draftFailReasonKey
 | |
| 	}
 | |
| 
 | |
| 	_, err := d.cache.Set(ctx, fmt.Sprintf(totalKey, databaseID, userID), fmt.Sprintf("%d", len(parseData.SampleData)), redisKeyTimeOut).Result()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	_, err = d.cache.Set(ctx, fmt.Sprintf(progressKey, databaseID, userID), int64(0), redisKeyTimeOut).Result()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	_, err = d.cache.Set(ctx, fmt.Sprintf(failKey, databaseID, userID), "", redisKeyTimeOut).Result()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	_, err = d.cache.Set(ctx, fmt.Sprintf(currentFileName, databaseID, userID), extra.Sheets[req.TableSheet.SheetID].SheetName, redisKeyTimeOut).Result()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) increaseProgress(ctx context.Context, req *SubmitDatabaseInsertTaskRequest, successNum int64) error {
 | |
| 	tableType := req.TableType
 | |
| 	userID := req.UserID
 | |
| 	databaseID := req.DatabaseID
 | |
| 
 | |
| 	progressKey := onlineProgressKey
 | |
| 	if tableType == table.TableType_DraftTable {
 | |
| 		progressKey = draftProgressKey
 | |
| 	}
 | |
| 
 | |
| 	_, err := d.cache.IncrBy(ctx, fmt.Sprintf(progressKey, databaseID, userID), successNum).Result()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) GetDraftDatabaseByOnlineID(ctx context.Context, req *GetDraftDatabaseByOnlineIDRequest) (*GetDraftDatabaseByOnlineIDResponse, error) {
 | |
| 	online, err := d.MGetDatabase(ctx, &MGetDatabaseRequest{
 | |
| 		Basics: []*database.DatabaseBasic{
 | |
| 			{
 | |
| 				ID:        req.OnlineID,
 | |
| 				TableType: table.TableType_OnlineTable,
 | |
| 			},
 | |
| 		},
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	if len(online.Databases) == 0 {
 | |
| 		return nil, fmt.Errorf("online table not found, id: %d", req.OnlineID)
 | |
| 	}
 | |
| 
 | |
| 	draftID := online.Databases[0].GetDraftID()
 | |
| 
 | |
| 	draftResp, err := d.MGetDatabase(ctx, &MGetDatabaseRequest{
 | |
| 		Basics: []*database.DatabaseBasic{
 | |
| 			{
 | |
| 				ID:        draftID,
 | |
| 				TableType: table.TableType_DraftTable,
 | |
| 			},
 | |
| 		},
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	if len(draftResp.Databases) == 0 {
 | |
| 		return nil, fmt.Errorf("online table not found, id: %d", req.OnlineID)
 | |
| 	}
 | |
| 
 | |
| 	return &GetDraftDatabaseByOnlineIDResponse{
 | |
| 		Database: draftResp.Databases[0],
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| // DeleteDatabaseByAppID delete all records and all physical tables by app id
 | |
| func (d databaseService) DeleteDatabaseByAppID(ctx context.Context, req *DeleteDatabaseByAppIDRequest) (*DeleteDatabaseByAppIDResponse, error) {
 | |
| 	onlineDBInfos, err := d.listDatabasesByAppID(ctx, req.AppID, table.TableType_OnlineTable)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	draftDBInfos, err := d.listDatabasesByAppID(ctx, req.AppID, table.TableType_DraftTable)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	tx := query.Use(d.db).Begin()
 | |
| 	if tx.Error != nil {
 | |
| 		return nil, fmt.Errorf("start transaction failed, %v", tx.Error)
 | |
| 	}
 | |
| 
 | |
| 	defer func() {
 | |
| 		if r := recover(); r != nil {
 | |
| 			e := tx.Rollback()
 | |
| 			if e != nil {
 | |
| 				logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
 | |
| 			}
 | |
| 
 | |
| 			err = fmt.Errorf("catch panic: %v\nstack=%s", r, string(debug.Stack()))
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		if err != nil {
 | |
| 			e := tx.Rollback()
 | |
| 			if e != nil {
 | |
| 				logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
 | |
| 			}
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	onlineIDs := make([]int64, 0, len(onlineDBInfos))
 | |
| 	for _, db := range onlineDBInfos {
 | |
| 		onlineIDs = append(onlineIDs, db.ID)
 | |
| 	}
 | |
| 
 | |
| 	draftIDs := make([]int64, 0, len(draftDBInfos))
 | |
| 	for _, db := range draftDBInfos {
 | |
| 		draftIDs = append(draftIDs, db.ID)
 | |
| 	}
 | |
| 
 | |
| 	if err = d.onlineDAO.BatchDeleteWithTX(ctx, tx, onlineIDs); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	if err = d.draftDAO.BatchDeleteWithTX(ctx, tx, draftIDs); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	err = tx.Commit()
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("commit transaction failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	// delete draft and online physical table
 | |
| 	onlinePhysicals := make([]string, 0, len(onlineDBInfos))
 | |
| 	for _, db := range onlineDBInfos {
 | |
| 		onlinePhysicals = append(onlinePhysicals, db.ActualTableName)
 | |
| 	}
 | |
| 
 | |
| 	draftPhysicals := make([]string, 0, len(draftDBInfos))
 | |
| 	for _, db := range draftDBInfos {
 | |
| 		draftPhysicals = append(draftPhysicals, db.ActualTableName)
 | |
| 	}
 | |
| 
 | |
| 	for _, physical := range onlinePhysicals {
 | |
| 		_, err = d.rdb.DropTable(ctx, &rdb.DropTableRequest{
 | |
| 			TableName: physical,
 | |
| 		})
 | |
| 		if err != nil {
 | |
| 			logs.Errorf("drop online physical table failed: %v, table_name=%s", err, physical)
 | |
| 		}
 | |
| 	}
 | |
| 	for _, physical := range draftPhysicals {
 | |
| 		_, err = d.rdb.DropTable(ctx, &rdb.DropTableRequest{
 | |
| 			TableName: physical,
 | |
| 		})
 | |
| 		if err != nil {
 | |
| 			logs.Errorf("drop draft physical table failed: %v, table_name=%s", err, physical)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return &DeleteDatabaseByAppIDResponse{
 | |
| 		DeletedDatabaseIDs: onlineIDs,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) listDatabasesByAppID(ctx context.Context, appID int64, tableType table.TableType) ([]*entity2.Database, error) {
 | |
| 	const batchSize = 100
 | |
| 	offset := 0
 | |
| 	dbInfos := make([]*entity2.Database, 0)
 | |
| 	for {
 | |
| 		resp, err := d.ListDatabase(ctx, &ListDatabaseRequest{
 | |
| 			AppID:     appID,
 | |
| 			TableType: tableType,
 | |
| 			Limit:     batchSize,
 | |
| 			Offset:    offset,
 | |
| 		})
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 
 | |
| 		for _, db := range resp.Databases {
 | |
| 			dbInfos = append(dbInfos, db)
 | |
| 		}
 | |
| 
 | |
| 		if !resp.HasMore {
 | |
| 			break
 | |
| 		}
 | |
| 
 | |
| 		offset += batchSize
 | |
| 	}
 | |
| 
 | |
| 	return dbInfos, nil
 | |
| }
 | |
| 
 | |
| func (d databaseService) GetAllDatabaseByAppID(ctx context.Context, req *GetAllDatabaseByAppIDRequest) (*GetAllDatabaseByAppIDResponse, error) {
 | |
| 	onlineDBs, err := d.listDatabasesByAppID(ctx, req.AppID, table.TableType_OnlineTable)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &GetAllDatabaseByAppIDResponse{
 | |
| 		Databases: onlineDBs,
 | |
| 	}, nil
 | |
| }
 |