198 lines
		
	
	
		
			5.7 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			198 lines
		
	
	
		
			5.7 KiB
		
	
	
	
		
			Go
		
	
	
	
/*
 | 
						|
 * Copyright 2025 coze-dev Authors
 | 
						|
 *
 | 
						|
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
 * you may not use this file except in compliance with the License.
 | 
						|
 * You may obtain a copy of the License at
 | 
						|
 *
 | 
						|
 *     http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
 *
 | 
						|
 * Unless required by applicable law or agreed to in writing, software
 | 
						|
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
 * See the License for the specific language governing permissions and
 | 
						|
 * limitations under the License.
 | 
						|
 */
 | 
						|
 | 
						|
package dao
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"errors"
 | 
						|
	"strconv"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"gorm.io/gorm"
 | 
						|
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/dal/model"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/dal/query"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
 | 
						|
)
 | 
						|
 | 
						|
type KnowledgeDocumentDAO struct {
 | 
						|
	DB    *gorm.DB
 | 
						|
	Query *query.Query
 | 
						|
}
 | 
						|
 | 
						|
func (dao *KnowledgeDocumentDAO) Create(ctx context.Context, document *model.KnowledgeDocument) error {
 | 
						|
	return dao.Query.KnowledgeDocument.WithContext(ctx).Create(document)
 | 
						|
}
 | 
						|
 | 
						|
func (dao *KnowledgeDocumentDAO) Update(ctx context.Context, document *model.KnowledgeDocument) error {
 | 
						|
	document.UpdatedAt = time.Now().UnixMilli()
 | 
						|
	err := dao.Query.KnowledgeDocument.WithContext(ctx).Save(document)
 | 
						|
	return err
 | 
						|
}
 | 
						|
 | 
						|
func (dao *KnowledgeDocumentDAO) Delete(ctx context.Context, id int64) error {
 | 
						|
	k := dao.Query.KnowledgeDocument
 | 
						|
	_, err := k.WithContext(ctx).Where(k.ID.Eq(id)).Delete()
 | 
						|
	return err
 | 
						|
}
 | 
						|
 | 
						|
func (dao *KnowledgeDocumentDAO) MGetByID(ctx context.Context, ids []int64) ([]*model.KnowledgeDocument, error) {
 | 
						|
	if len(ids) == 0 {
 | 
						|
		return nil, nil
 | 
						|
	}
 | 
						|
 | 
						|
	k := dao.Query.KnowledgeDocument
 | 
						|
	pos, err := k.WithContext(ctx).Where(k.ID.In(ids...)).Find()
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return pos, err
 | 
						|
}
 | 
						|
 | 
						|
func (dao *KnowledgeDocumentDAO) fromCursor(cursor string) (id int64, err error) {
 | 
						|
	id, err = strconv.ParseInt(cursor, 10, 64)
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
func (dao *KnowledgeDocumentDAO) FindDocumentByCondition(ctx context.Context, opts *entity.WhereDocumentOpt) ([]*model.KnowledgeDocument, int64, error) {
 | 
						|
	k := dao.Query.KnowledgeDocument
 | 
						|
	do := k.WithContext(ctx)
 | 
						|
	if opts == nil {
 | 
						|
		return nil, 0, nil
 | 
						|
	}
 | 
						|
	if len(opts.IDs) == 0 && len(opts.KnowledgeIDs) == 0 {
 | 
						|
		return nil, 0, errors.New("need ids or knowledge_ids")
 | 
						|
	}
 | 
						|
	if opts.CreatorID > 0 {
 | 
						|
		do = do.Where(k.CreatorID.Eq(opts.CreatorID))
 | 
						|
	}
 | 
						|
	if len(opts.IDs) > 0 {
 | 
						|
		do = do.Where(k.ID.In(opts.IDs...))
 | 
						|
	}
 | 
						|
	if len(opts.KnowledgeIDs) > 0 {
 | 
						|
		do = do.Where(k.KnowledgeID.In(opts.KnowledgeIDs...))
 | 
						|
	}
 | 
						|
	if len(opts.StatusIn) > 0 {
 | 
						|
		do = do.Where(k.Status.In(opts.StatusIn...))
 | 
						|
	}
 | 
						|
	if len(opts.StatusNotIn) > 0 {
 | 
						|
		do = do.Where(k.Status.NotIn(opts.StatusNotIn...))
 | 
						|
	}
 | 
						|
	if opts.SelectAll {
 | 
						|
		do = do.Limit(-1)
 | 
						|
	} else {
 | 
						|
		if opts.Limit != 0 {
 | 
						|
			do = do.Limit(opts.Limit)
 | 
						|
		}
 | 
						|
		if opts.Offset != nil {
 | 
						|
			do = do.Offset(ptr.From(opts.Offset))
 | 
						|
		}
 | 
						|
	}
 | 
						|
	if opts.Cursor != nil {
 | 
						|
		id, err := dao.fromCursor(ptr.From(opts.Cursor))
 | 
						|
		if err != nil {
 | 
						|
			return nil, 0, err
 | 
						|
		}
 | 
						|
		do = do.Where(k.ID.Lt(id)).Order(k.ID.Desc())
 | 
						|
	}
 | 
						|
	resp, err := do.Find()
 | 
						|
	if err != nil {
 | 
						|
		return nil, 0, err
 | 
						|
	}
 | 
						|
	total, err := do.Limit(-1).Offset(-1).Count()
 | 
						|
	if err != nil {
 | 
						|
		return nil, 0, err
 | 
						|
	}
 | 
						|
	return resp, total, nil
 | 
						|
}
 | 
						|
 | 
						|
func (dao *KnowledgeDocumentDAO) DeleteDocuments(ctx context.Context, ids []int64) error {
 | 
						|
	tx := dao.DB.Begin()
 | 
						|
	var err error
 | 
						|
	defer func() {
 | 
						|
		if err != nil {
 | 
						|
			tx.Rollback()
 | 
						|
		} else {
 | 
						|
			tx.Commit()
 | 
						|
		}
 | 
						|
	}()
 | 
						|
	// 删除document
 | 
						|
	err = tx.WithContext(ctx).Model(&model.KnowledgeDocument{}).Where("id in ?", ids).Delete(&model.KnowledgeDocument{}).Error
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	// 删除document_slice
 | 
						|
	err = tx.WithContext(ctx).Model(&model.KnowledgeDocumentSlice{}).Where("document_id in?", ids).Delete(&model.KnowledgeDocumentSlice{}).Error
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (dao *KnowledgeDocumentDAO) SetStatus(ctx context.Context, documentID int64, status int32, reason string) error {
 | 
						|
	k := dao.Query.KnowledgeDocument
 | 
						|
	d := &model.KnowledgeDocument{Status: status, FailReason: reason, UpdatedAt: time.Now().UnixMilli()}
 | 
						|
	_, err := k.WithContext(ctx).Debug().Where(k.ID.Eq(documentID)).Updates(d)
 | 
						|
	return err
 | 
						|
}
 | 
						|
 | 
						|
func (dao *KnowledgeDocumentDAO) CreateWithTx(ctx context.Context, tx *gorm.DB, documents []*model.KnowledgeDocument) error {
 | 
						|
	if len(documents) == 0 {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
	tx = tx.WithContext(ctx).Debug().CreateInBatches(documents, len(documents))
 | 
						|
	return tx.Error
 | 
						|
}
 | 
						|
 | 
						|
func (dao *KnowledgeDocumentDAO) GetByID(ctx context.Context, id int64) (*model.KnowledgeDocument, error) {
 | 
						|
	k := dao.Query.KnowledgeDocument
 | 
						|
	document, err := k.WithContext(ctx).Where(k.ID.Eq(id)).First()
 | 
						|
	if err != nil {
 | 
						|
		if errors.Is(err, gorm.ErrRecordNotFound) {
 | 
						|
			return nil, nil
 | 
						|
		}
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	return document, nil
 | 
						|
}
 | 
						|
 | 
						|
func (dao *KnowledgeDocumentDAO) UpdateDocumentSliceInfo(ctx context.Context, documentID int64) error {
 | 
						|
	s := dao.Query.KnowledgeDocumentSlice
 | 
						|
	var err error
 | 
						|
	var sliceCount int64
 | 
						|
	var totalSize *int64
 | 
						|
	sliceCount, err = s.WithContext(ctx).Debug().Where(s.DocumentID.Eq(documentID)).Count()
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	err = dao.DB.Raw("SELECT SUM(CHAR_LENGTH(content)) FROM knowledge_document_slice WHERE document_id = ? AND deleted_at IS NULL", documentID).Scan(&totalSize).Error
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	k := dao.Query.KnowledgeDocument
 | 
						|
	updates := map[string]any{}
 | 
						|
	updates[k.SliceCount.ColumnName().String()] = sliceCount
 | 
						|
	if totalSize != nil {
 | 
						|
		updates[k.Size.ColumnName().String()] = ptr.From(totalSize)
 | 
						|
	}
 | 
						|
	updates[k.UpdatedAt.ColumnName().String()] = time.Now().UnixMilli()
 | 
						|
	_, err = k.WithContext(ctx).Debug().Where(k.ID.Eq(documentID)).Updates(updates)
 | 
						|
	return err
 | 
						|
}
 |