feat: manually mirror opencoze's code from bytedance

Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
fanlv
2025-07-20 17:36:12 +08:00
commit 890153324f
14811 changed files with 1923430 additions and 0 deletions

23
backend/infra/contract/cache/cache.go vendored Normal file
View File

@@ -0,0 +1,23 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package cache
import (
"github.com/redis/go-redis/v9"
)
type Cmdable = redis.Cmdable

View File

@@ -0,0 +1,35 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package chatmodel
import (
"context"
"github.com/cloudwego/eino/components/model"
)
//go:generate mockgen -destination ../../../internal/mock/infra/contract/chatmodel/base_model_mock.go -package mock -source ${GOPATH}/src/github.com/cloudwego/eino/components/model/interface.go BaseChatModel
type BaseChatModel = model.BaseChatModel
//go:generate mockgen -destination ../../../internal/mock/infra/contract/chatmodel/toolcalling_model_mock.go -package mock -source ${GOPATH}/src/github.com/cloudwego/eino/components/model/interface.go ToolCallingChatModel
type ToolCallingChatModel = model.ToolCallingChatModel
//go:generate mockgen -destination ../../../internal/mock/infra/contract/chatmodel/chat_model_factory_mock.go -package mock -source chat_model.go Factory
type Factory interface {
CreateChatModel(ctx context.Context, protocol Protocol, config *Config) (ToolCallingChatModel, error)
SupportProtocol(protocol Protocol) bool
}

View File

@@ -0,0 +1,94 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package chatmodel
import (
"time"
"github.com/cloudwego/eino-ext/components/model/deepseek"
"github.com/cloudwego/eino-ext/libs/acl/openai"
"google.golang.org/genai"
)
type Config struct {
BaseURL string `json:"base_url,omitempty" yaml:"base_url,omitempty"`
APIKey string `json:"api_key,omitempty" yaml:"api_key,omitempty"`
Timeout time.Duration `json:"timeout,omitempty" yaml:"timeout,omitempty"`
Model string `json:"model" yaml:"model"`
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
FrequencyPenalty *float32 `json:"frequency_penalty,omitempty" yaml:"frequency_penalty,omitempty"`
PresencePenalty *float32 `json:"presence_penalty,omitempty" yaml:"presence_penalty,omitempty"`
MaxTokens *int `json:"max_tokens,omitempty" yaml:"max_tokens,omitempty"`
TopP *float32 `json:"top_p,omitempty" yaml:"top_p"`
TopK *int `json:"top_k,omitempty" yaml:"top_k"`
Stop []string `json:"stop,omitempty" yaml:"stop"`
EnableThinking *bool `json:"enable_thinking,omitempty" yaml:"enable_thinking,omitempty"`
OpenAI *OpenAIConfig `json:"open_ai,omitempty" yaml:"openai"`
Claude *ClaudeConfig `json:"claude,omitempty" yaml:"claude"`
Ark *ArkConfig `json:"ark,omitempty" yaml:"ark"`
Deepseek *DeepseekConfig `json:"deepseek,omitempty" yaml:"deepseek"`
Qwen *QwenConfig `json:"qwen,omitempty" yaml:"qwen"`
Gemini *GeminiConfig `json:"gemini,omitempty" yaml:"gemini"`
Custom map[string]string `json:"custom,omitempty" yaml:"custom"`
}
type OpenAIConfig struct {
ByAzure bool `json:"by_azure,omitempty" yaml:"by_azure"`
APIVersion string `json:"api_version,omitempty" yaml:"api_version"`
ResponseFormat *openai.ChatCompletionResponseFormat `json:"response_format,omitempty" yaml:"response_format"`
}
type ClaudeConfig struct {
ByBedrock bool `json:"by_bedrock" yaml:"by_bedrock"`
// bedrock config
AccessKey string `json:"access_key,omitempty" yaml:"access_key"`
SecretAccessKey string `json:"secret_access_key,omitempty" yaml:"secret_access_key"`
SessionToken string `json:"session_token,omitempty" yaml:"session_token"`
Region string `json:"region,omitempty" yaml:"region"`
}
type ArkConfig struct {
Region string `json:"region" yaml:"region"`
AccessKey string `json:"access_key,omitempty" yaml:"access_key"`
SecretKey string `json:"secret_key,omitempty" yaml:"secret_key"`
RetryTimes *int `json:"retry_times,omitempty" yaml:"retry_times"`
CustomHeader map[string]string `json:"custom_header,omitempty" yaml:"custom_header"`
}
type DeepseekConfig struct {
ResponseFormatType deepseek.ResponseFormatType `json:"response_format_type" yaml:"response_format_type"`
}
type QwenConfig struct {
ResponseFormat *openai.ChatCompletionResponseFormat `json:"response_format,omitempty" yaml:"response_format"`
}
type GeminiConfig struct {
Backend genai.Backend `json:"backend,omitempty" yaml:"backend"`
Project string `json:"project,omitempty" yaml:"project"`
Location string `json:"location,omitempty" yaml:"location"`
APIVersion string `json:"api_version,omitempty" yaml:"api_version"`
Headers map[string][]string `json:"headers,omitempty" yaml:"headers"`
TimeoutMs int64 `json:"timeout_ms,omitempty" yaml:"timeout_ms"`
IncludeThoughts *bool `json:"include_thoughts,omitempty" yaml:"include_thoughts"` // default true
ThinkingBudget *int32 `json:"thinking_budget,omitempty" yaml:"thinking_budget"` // default nil
}

View File

@@ -0,0 +1,55 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package chatmodel
import "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
type Protocol string
const (
ProtocolOpenAI Protocol = "openai"
ProtocolClaude Protocol = "claude"
ProtocolDeepseek Protocol = "deepseek"
ProtocolGemini Protocol = "gemini"
ProtocolArk Protocol = "ark"
ProtocolOllama Protocol = "ollama"
ProtocolQwen Protocol = "qwen"
ProtocolErnie Protocol = "ernie"
)
func (p Protocol) TOModelClass() developer_api.ModelClass {
switch p {
case ProtocolArk:
return developer_api.ModelClass_SEED
case ProtocolOpenAI:
return developer_api.ModelClass_GPT
case ProtocolDeepseek:
return developer_api.ModelClass_DeekSeek
case ProtocolClaude:
return developer_api.ModelClass_Claude
case ProtocolGemini:
return developer_api.ModelClass_Gemini
case ProtocolOllama:
return developer_api.ModelClass_Llama
case ProtocolQwen:
return developer_api.ModelClass_QWen
case ProtocolErnie:
return developer_api.ModelClass_Ernie
default:
return developer_api.ModelClass_Other
}
}

View File

@@ -0,0 +1,126 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package document
import (
"fmt"
"github.com/cloudwego/eino/schema"
)
const (
MetaDataKeyColumns = "table_columns" // val: []*Column
MetaDataKeyColumnData = "table_column_data" // val: []*ColumnData
MetaDataKeyColumnsOnly = "table_columns_only" // val: struct{}, which means table has no data, only header.
MetaDataKeyCreatorID = "creator_id" // val: int64
MetaDataKeyExternalStorage = "external_storage" // val: map[string]any
)
func GetDocumentColumns(doc *schema.Document) ([]*Column, error) {
if doc == nil || doc.MetaData == nil {
return nil, fmt.Errorf("invalid document")
}
columns, ok := doc.MetaData[MetaDataKeyColumns].([]*Column)
if !ok {
return nil, fmt.Errorf("invalid document columns")
}
return columns, nil
}
func WithDocumentColumns(doc *schema.Document, columns []*Column) *schema.Document {
doc.MetaData[MetaDataKeyColumns] = columns
return doc
}
func GetDocumentColumnData(doc *schema.Document) ([]*ColumnData, error) {
if doc == nil || doc.MetaData == nil {
return nil, fmt.Errorf("invalid document")
}
data, ok := doc.MetaData[MetaDataKeyColumnData].([]*ColumnData)
if !ok {
return nil, fmt.Errorf("invalid document column data")
}
return data, nil
}
func WithDocumentColumnData(doc *schema.Document, data []*ColumnData) *schema.Document {
doc.MetaData[MetaDataKeyColumnData] = data
return doc
}
func WithDocumentColumnsOnly(doc *schema.Document) *schema.Document {
doc.MetaData[MetaDataKeyColumnsOnly] = struct{}{}
return doc
}
func GetDocumentColumnsOnly(doc *schema.Document) (bool, error) {
if doc == nil || doc.MetaData == nil {
return false, fmt.Errorf("invalid document")
}
_, ok := doc.MetaData[MetaDataKeyColumnsOnly].(struct{})
return ok, nil
}
func GetDocumentsColumnsOnly(docs []*schema.Document) (bool, error) {
if len(docs) != 1 {
return false, nil
}
return GetDocumentColumnsOnly(docs[0])
}
func GetDocumentCreatorID(doc *schema.Document) (int64, error) {
if doc == nil || doc.MetaData == nil {
return 0, fmt.Errorf("invalid document")
}
creatorID, ok := doc.MetaData[MetaDataKeyCreatorID].(int64)
if !ok {
return 0, fmt.Errorf("invalid document creator id")
}
return creatorID, nil
}
func WithDocumentCreatorID(doc *schema.Document, creatorID int64) *schema.Document {
doc.MetaData[MetaDataKeyCreatorID] = creatorID
return doc
}
func GetDocumentExternalStorage(doc *schema.Document) (map[string]any, error) {
if doc == nil || doc.MetaData == nil {
return nil, fmt.Errorf("invalid document")
}
data, ok := doc.MetaData[MetaDataKeyExternalStorage].(map[string]any)
if !ok {
return nil, fmt.Errorf("invalid document external storage")
}
return data, nil
}
func WithDocumentExternalStorage(doc *schema.Document, externalStorage map[string]any) *schema.Document {
doc.MetaData[MetaDataKeyExternalStorage] = externalStorage
return doc
}

View File

@@ -0,0 +1,23 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package imageunderstand
import "context"
type ImageUnderstand interface {
ImageUnderstand(ctx context.Context, image []byte) (content string, err error)
}

View File

@@ -0,0 +1,29 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nl2sql
import (
"context"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
)
type NL2SQL interface {
NL2SQL(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...Option) (sql string, err error)
}

View File

@@ -0,0 +1,31 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nl2sql
import "github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
type Option func(o *Options)
type Options struct {
ChatModel chatmodel.BaseChatModel
}
func WithChatModel(cm chatmodel.BaseChatModel) Option {
return func(o *Options) {
o.ChatModel = cm
}
}

View File

@@ -0,0 +1,24 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ocr
import "context"
type OCR interface {
FromBase64(ctx context.Context, b64 string) (texts []string, err error)
FromURL(ctx context.Context, url string) (texts []string, err error)
}

View File

@@ -0,0 +1,128 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package parser
import (
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
)
type Manager interface {
GetParser(config *Config) (Parser, error)
}
type Config struct {
FileExtension FileExtension
ParsingStrategy *ParsingStrategy
ChunkingStrategy *ChunkingStrategy
}
// ParsingStrategy for document parse before indexing
type ParsingStrategy struct {
// Doc
ExtractImage bool `json:"extract_image"` // 提取图片元素
ExtractTable bool `json:"extract_table"` // 提取表格元素
ImageOCR bool `json:"image_ocr"` // 图片 ocr
FilterPages []int `json:"filter_pages"` // 页过滤, 第一页=1
// Sheet
SheetID *int `json:"sheet_id"` // xlsx sheet id
HeaderLine int `json:"header_line"` // 表头行
DataStartLine int `json:"data_start_line"` // 数据起始行
RowsCount int `json:"rows_count"` // 读取数据行数
IsAppend bool `json:"-"` // 行插入
Columns []*document.Column `json:"-"` // sheet 对齐表头
IgnoreColumnTypeErr bool `json:"-"` // true 时忽略 column type 与 value 未对齐的问题,此时 value 为空
// Image
ImageAnnotationType ImageAnnotationType `json:"image_annotation_type"` // 图片内容标注类型
}
type ChunkingStrategy struct {
ChunkType ChunkType `json:"chunk_type"`
// custom config
ChunkSize int64 `json:"chunk_size"` // 分段最大长度
Separator string `json:"separator"` // 分段标识符
Overlap int64 `json:"overlap"` // 分段重叠比例
TrimSpace bool `json:"trim_space"`
TrimURLAndEmail bool `json:"trim_url_and_email"`
// leveled config
MaxDepth int64 `json:"max_depth"` // 按层级分段时的最大层级
SaveTitle bool `json:"save_title"` // 保留层级标题
}
type ChunkType int64
const (
ChunkTypeDefault ChunkType = 0 // 自动分片
ChunkTypeCustom ChunkType = 1 // 自定义规则分片
ChunkTypeLeveled ChunkType = 2 // 层级分片
)
type ImageAnnotationType int64
const (
ImageAnnotationTypeModel ImageAnnotationType = 0 // 模型自动标注
ImageAnnotationTypeManual ImageAnnotationType = 1 // 人工标注
)
type FileExtension string
const (
// document
FileExtensionPDF FileExtension = "pdf"
FileExtensionTXT FileExtension = "txt"
FileExtensionDoc FileExtension = "doc"
FileExtensionDocx FileExtension = "docx"
FileExtensionMarkdown FileExtension = "md"
// sheet
FileExtensionCSV FileExtension = "csv"
FileExtensionXLSX FileExtension = "xlsx"
FileExtensionJSON FileExtension = "json"
FileExtensionJsonMaps FileExtension = "json_maps" // json of []map[string]string
// image
FileExtensionJPG FileExtension = "jpg"
FileExtensionJPEG FileExtension = "jpeg"
FileExtensionPNG FileExtension = "png"
)
func ValidateFileExtension(fileSuffix string) (ext FileExtension, support bool) {
fileExtension := FileExtension(fileSuffix)
_, ok := fileExtensionSet[fileExtension]
if !ok {
return "", false
}
return fileExtension, true
}
var fileExtensionSet = sets.Set[FileExtension]{
FileExtensionPDF: {},
FileExtensionTXT: {},
FileExtensionDoc: {},
FileExtensionDocx: {},
FileExtensionMarkdown: {},
FileExtensionCSV: {},
FileExtensionJSON: {},
FileExtensionJsonMaps: {},
FileExtensionJPG: {},
FileExtensionJPEG: {},
FileExtensionPNG: {},
}

View File

@@ -0,0 +1,21 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package parser
import "github.com/cloudwego/eino/components/document/parser"
type Parser = parser.Parser

View File

@@ -0,0 +1,26 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package progressbar
import "context"
// ProgressBar is the interface for the progress bar.
type ProgressBar interface {
AddN(n int) error
ReportError(err error) error
GetProgress(ctx context.Context) (percent int, remainSec int, errMsg string)
}

View File

@@ -0,0 +1,43 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package rerank
import (
"context"
"github.com/cloudwego/eino/schema"
)
type Reranker interface {
Rerank(ctx context.Context, req *Request) (*Response, error)
}
type Request struct {
Query string
Data [][]*Data
TopN *int64
}
type Response struct {
SortedData []*Data // 高分在前
TokenUsage *int64
}
type Data struct {
Document *schema.Document
Score float64
}

View File

@@ -0,0 +1,54 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package searchstore
import "fmt"
type DSL struct {
Op Op
Field string
Value interface{} // builtin types / []*DSL
}
type Op string
const (
OpEq Op = "eq"
OpNe Op = "ne"
OpLike Op = "like"
OpIn Op = "in"
OpAnd Op = "and"
OpOr Op = "or"
)
func (d *DSL) DSL() map[string]any {
return map[string]any{"dsl": d}
}
func LoadDSL(src map[string]any) (*DSL, error) {
if src == nil {
return nil, nil
}
dsl, ok := src["dsl"].(*DSL)
if !ok {
return nil, fmt.Errorf("load dsl failed")
}
return dsl, nil
}

View File

@@ -0,0 +1,82 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package searchstore
import (
"context"
)
type Manager interface {
Create(ctx context.Context, req *CreateRequest) error
Drop(ctx context.Context, req *DropRequest) error
GetType() SearchStoreType
GetSearchStore(ctx context.Context, collectionName string) (SearchStore, error)
}
type CreateRequest struct {
CollectionName string
Fields []*Field
CollectionMeta map[string]string
}
type DropRequest struct {
CollectionName string
}
type GetSearchStoreRequest struct {
CollectionName string
}
type Field struct {
Name FieldName
Type FieldType
Description string
Nullable bool
IsPrimary bool
Indexing bool
}
type SearchStoreType string
const (
TypeVectorStore SearchStoreType = "vector"
TypeTextStore SearchStoreType = "text"
)
type FieldName = string
// 内置 field name
const (
FieldID FieldName = "id" // int64
FieldCreatorID FieldName = "creator_id" // int64
FieldTextContent FieldName = "text_content" // string
)
type FieldType int64
const (
FieldTypeUnknown FieldType = 0
FieldTypeInt64 FieldType = 1
FieldTypeText FieldType = 2
FieldTypeDenseVector FieldType = 3
FieldTypeSparseVector FieldType = 4
)

View File

@@ -0,0 +1,87 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package searchstore
import (
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/components/retriever"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/progressbar"
)
type IndexerOptions struct {
PartitionKey *string
Partition *string // 存储分片映射
IndexingFields []string
ProgressBar progressbar.ProgressBar
}
type RetrieverOptions struct {
MultiMatch *MultiMatch // 多 field 查询
PartitionKey *string
Partitions []string // 查询分片映射
}
type MultiMatch struct {
Fields []string
Query string
}
func WithIndexerPartitionKey(key string) indexer.Option {
return indexer.WrapImplSpecificOptFn(func(o *IndexerOptions) {
o.PartitionKey = &key
})
}
func WithPartition(partition string) indexer.Option {
return indexer.WrapImplSpecificOptFn(func(o *IndexerOptions) {
o.Partition = &partition
})
}
func WithIndexingFields(fields []string) indexer.Option {
return indexer.WrapImplSpecificOptFn(func(o *IndexerOptions) {
o.IndexingFields = fields
})
}
func WithProgressBar(progressBar progressbar.ProgressBar) indexer.Option {
return indexer.WrapImplSpecificOptFn(func(o *IndexerOptions) {
o.ProgressBar = progressBar
})
}
func WithMultiMatch(fields []string, query string) retriever.Option {
return retriever.WrapImplSpecificOptFn(func(o *RetrieverOptions) {
o.MultiMatch = &MultiMatch{
Fields: fields,
Query: query,
}
})
}
func WithRetrieverPartitionKey(key string) retriever.Option {
return retriever.WrapImplSpecificOptFn(func(o *RetrieverOptions) {
o.PartitionKey = &key
})
}
func WithPartitions(partitions []string) retriever.Option {
return retriever.WrapImplSpecificOptFn(func(o *RetrieverOptions) {
o.Partitions = partitions
})
}

View File

@@ -0,0 +1,32 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package searchstore
import (
"context"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/components/retriever"
)
type SearchStore interface {
indexer.Indexer
retriever.Retriever
Delete(ctx context.Context, ids []string) error
}

View File

@@ -0,0 +1,155 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package document
import (
"strconv"
"time"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
type TableSchema struct {
Name string
Comment string
Columns []*Column
}
type Column struct {
ID int64
Name string
Type TableColumnType
Description string
Nullable bool
IsPrimary bool
Sequence int // 排序编号
}
type TableColumnType int
const (
TableColumnTypeUnknown TableColumnType = 0
TableColumnTypeString TableColumnType = 1
TableColumnTypeInteger TableColumnType = 2
TableColumnTypeTime TableColumnType = 3
TableColumnTypeNumber TableColumnType = 4
TableColumnTypeBoolean TableColumnType = 5
TableColumnTypeImage TableColumnType = 6
)
func (t TableColumnType) String() string {
switch t {
case TableColumnTypeString:
return "varchar"
case TableColumnTypeInteger:
return "bigint"
case TableColumnTypeTime:
return "timestamp"
case TableColumnTypeNumber:
return "double"
case TableColumnTypeBoolean:
return "boolean"
case TableColumnTypeImage:
return "image"
default:
return "unknown"
}
}
type ColumnData struct {
ColumnID int64
ColumnName string
Type TableColumnType
ValString *string
ValInteger *int64
ValTime *time.Time
ValNumber *float64
ValBoolean *bool
ValImage *string // base64 / url
}
func (d *ColumnData) GetValue() interface{} {
switch d.Type {
case TableColumnTypeString:
return d.ValString
case TableColumnTypeInteger:
return d.ValInteger
case TableColumnTypeTime:
return d.ValTime
case TableColumnTypeNumber:
return d.ValNumber
case TableColumnTypeBoolean:
return d.ValBoolean
case TableColumnTypeImage:
return d.ValImage
default:
return nil
}
}
func (d *ColumnData) GetStringValue() string {
switch d.Type {
case TableColumnTypeString:
return ptr.From(d.ValString)
case TableColumnTypeInteger:
return strconv.FormatInt(ptr.From(d.ValInteger), 10)
case TableColumnTypeTime:
return ptr.From(d.ValTime).Format(time.DateTime)
case TableColumnTypeNumber:
return strconv.FormatFloat(ptr.From(d.ValNumber), 'f', 20, 64)
case TableColumnTypeBoolean:
return strconv.FormatBool(ptr.From(d.ValBoolean))
case TableColumnTypeImage:
return ptr.From(d.ValImage)
default:
return ptr.From(d.ValString)
}
}
func (d *ColumnData) GetNullableStringValue() string {
switch d.Type {
case TableColumnTypeString:
return ptr.From(d.ValString)
case TableColumnTypeInteger:
if d.ValInteger == nil {
return ""
}
return strconv.FormatInt(ptr.From(d.ValInteger), 10)
case TableColumnTypeTime:
if d.ValTime == nil {
return ""
}
return ptr.From(d.ValTime).Format(time.DateTime)
case TableColumnTypeNumber:
if d.ValNumber == nil {
return ""
}
return strconv.FormatFloat(ptr.From(d.ValNumber), 'f', 20, 64)
case TableColumnTypeBoolean:
if d.ValBoolean == nil {
return ""
}
return strconv.FormatBool(ptr.From(d.ValBoolean))
case TableColumnTypeImage:
if d.ValImage == nil {
return ""
}
return ptr.From(d.ValImage)
default:
return ptr.From(d.ValString)
}
}

View File

@@ -0,0 +1,39 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dynconf
import "context"
// zookeeper, etcd, nacos
type Provider interface {
Initialize(ctx context.Context, namespace, group string, opts ...Option) (DynamicClient, error)
}
type DynamicClient interface {
AddListener(key string, callback func(value string, err error)) error
RemoveListener(key string) error
Get(ctx context.Context, key string) (string, error)
}
type options struct{}
type Option struct {
apply func(opts *options)
implSpecificOptFn any
}

View File

@@ -0,0 +1,37 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package embedding
import (
"context"
"github.com/cloudwego/eino/components/embedding"
)
type Embedder interface {
embedding.Embedder
EmbedStringsHybrid(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, []map[int]float64, error) // hybrid embedding
Dimensions() int64
SupportStatus() SupportStatus
}
type SupportStatus int
const (
SupportDense SupportStatus = 1
SupportDenseAndSparse SupportStatus = 3
)

View File

@@ -0,0 +1,44 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package es
import (
"context"
)
type Client interface {
Create(ctx context.Context, index, id string, document any) error
Update(ctx context.Context, index, id string, document any) error
Delete(ctx context.Context, index, id string) error
Search(ctx context.Context, index string, req *Request) (*Response, error)
Exists(ctx context.Context, index string) (bool, error)
CreateIndex(ctx context.Context, index string, properties map[string]any) error
DeleteIndex(ctx context.Context, index string) error
Types() Types
NewBulkIndexer(index string) (BulkIndexer, error)
}
type Types interface {
NewLongNumberProperty() any
NewTextProperty() any
NewUnsignedLongNumberProperty() any
}
type BulkIndexer interface {
Add(ctx context.Context, item BulkIndexerItem) error
Close(ctx context.Context) error
}

View File

@@ -0,0 +1,73 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package es
import (
"encoding/json"
"io"
"github.com/elastic/go-elasticsearch/v8/typedapi/types/enums/totalhitsrelation"
)
type BulkIndexerItem struct {
Index string
Action string
DocumentID string
Routing string
Version *int64
VersionType string
Body io.ReadSeeker
RetryOnConflict *int
}
type Request struct {
Size *int
Query *Query
MinScore *float64
Sort []SortFiled
SearchAfter []any
From *int
}
type SortFiled struct {
Field string
Asc bool
}
type Response struct {
Hits HitsMetadata `json:"hits"`
MaxScore *float64 `json:"max_score,omitempty"`
}
type HitsMetadata struct {
Hits []Hit `json:"hits"`
MaxScore *float64 `json:"max_score,omitempty"`
// Total Total hit count information, present only if `track_total_hits` wasn't
// `false` in the search request.
Total *TotalHits `json:"total,omitempty"`
}
type Hit struct {
Id_ *string `json:"_id,omitempty"`
Score_ *float64 `json:"_score,omitempty"`
Source_ json.RawMessage `json:"_source,omitempty"`
}
type TotalHits struct {
Relation totalhitsrelation.TotalHitsRelation `json:"relation"`
Value int64 `json:"value"`
}

View File

@@ -0,0 +1,111 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package es
const (
QueryTypeEqual = "equal"
QueryTypeMatch = "match"
QueryTypeMultiMatch = "multi_match"
QueryTypeNotExists = "not_exists"
QueryTypeContains = "contains"
QueryTypeIn = "in"
)
type KV struct {
Key string
Value any
}
type QueryType string
type Query struct {
KV KV
Type QueryType
MultiMatchQuery MultiMatchQuery
Bool *BoolQuery
}
type BoolQuery struct {
Filter []Query
Must []Query
MustNot []Query
Should []Query
MinimumShouldMatch *int
}
type MultiMatchQuery struct {
Fields []string
Type string // best_fields
Query string
Operator string
}
const (
Or = "or"
And = "and"
)
func NewEqualQuery(k string, v any) Query {
return Query{
KV: KV{Key: k, Value: v},
Type: QueryTypeEqual,
}
}
func NewMatchQuery(k string, v any) Query {
return Query{
KV: KV{Key: k, Value: v},
Type: QueryTypeMatch,
}
}
func NewMultiMatchQuery(fields []string, query, typeStr, operator string) Query {
return Query{
Type: QueryTypeMultiMatch,
MultiMatchQuery: MultiMatchQuery{
Fields: fields,
Query: query,
Operator: operator,
Type: typeStr,
},
}
}
func NewNotExistsQuery(k string) Query {
return Query{
KV: KV{Key: k},
Type: QueryTypeNotExists,
}
}
func NewContainsQuery(k string, v any) Query {
return Query{
KV: KV{Key: k, Value: v},
Type: QueryTypeContains,
}
}
func NewInQuery[T any](k string, v []T) Query {
arr := make([]any, 0, len(v))
for _, item := range v {
arr = append(arr, item)
}
return Query{
KV: KV{Key: k, Value: arr},
Type: QueryTypeIn,
}
}

View File

@@ -0,0 +1,30 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package eventbus
type ConsumerOpt func(option *ConsumerOption)
type ConsumerOption struct {
Orderly *bool
// ConsumeFromWhere
}
func WithConsumerOrderly(orderly bool) ConsumerOpt {
return func(option *ConsumerOption) {
option.Orderly = &orderly
}
}

View File

@@ -0,0 +1,36 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package eventbus
import "context"
//go:generate mockgen -destination ../../../internal/mock/infra/contract/eventbus/eventbus_mock.go -package mock -source eventbus.go Factory
type Producer interface {
Send(ctx context.Context, body []byte, opts ...SendOpt) error
BatchSend(ctx context.Context, bodyArr [][]byte, opts ...SendOpt) error
}
type Consumer interface{}
type ConsumerHandler interface {
HandleMessage(ctx context.Context, msg *Message) error
}
type Message struct {
Topic string
Group string
Body []byte
}

View File

@@ -0,0 +1,29 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package eventbus
type SendOpt func(option *SendOption)
type SendOption struct {
ShardingKey *string
}
func WithShardingKey(key string) SendOpt {
return func(o *SendOption) {
o.ShardingKey = &key
}
}

View File

@@ -0,0 +1,27 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package idgen
import (
"context"
)
//go:generate mockgen -destination ../../../internal/mock/infra/contract/idgen/idgen_mock.go --package mock -source idgen.go
type IDGenerator interface {
GenID(ctx context.Context) (int64, error)
GenMultiIDs(ctx context.Context, counts int) ([]int64, error) // suggest batch size <= 200
}

View File

@@ -0,0 +1,50 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package imagex
type GetResourceOpt func(option *GetResourceOption)
type GetResourceOption struct {
Format string
Template string
Proto string
Expire int
}
func WithResourceFormat(format string) GetResourceOpt {
return func(o *GetResourceOption) {
o.Format = format
}
}
func WithResourceTemplate(template string) GetResourceOpt {
return func(o *GetResourceOption) {
o.Template = template
}
}
func WithResourceProto(proto string) GetResourceOpt {
return func(o *GetResourceOption) {
o.Proto = proto
}
}
func WithResourceExpire(expire int) GetResourceOpt {
return func(o *GetResourceOption) {
o.Expire = expire
}
}

View File

@@ -0,0 +1,71 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package imagex
import (
"context"
"time"
)
//go:generate mockgen -destination ../../../internal/mock/infra/contract/imagex/imagex_mock.go --package imagex -source imagex.go
type ImageX interface {
GetUploadAuth(ctx context.Context, opt ...UploadAuthOpt) (*SecurityToken, error)
GetUploadAuthWithExpire(ctx context.Context, expire time.Duration, opt ...UploadAuthOpt) (*SecurityToken, error)
GetResourceURL(ctx context.Context, uri string, opts ...GetResourceOpt) (*ResourceURL, error)
Upload(ctx context.Context, data []byte, opts ...UploadAuthOpt) (*UploadResult, error)
GetServerID() string
GetUploadHost(ctx context.Context) string
}
type SecurityToken struct {
AccessKeyID string `thrift:"access_key_id,1" frugal:"1,default,string" json:"access_key_id"`
SecretAccessKey string `thrift:"secret_access_key,2" frugal:"2,default,string" json:"secret_access_key"`
SessionToken string `thrift:"session_token,3" frugal:"3,default,string" json:"session_token"`
ExpiredTime string `thrift:"expired_time,4" frugal:"4,default,string" json:"expired_time"`
CurrentTime string `thrift:"current_time,5" frugal:"5,default,string" json:"current_time"`
HostScheme string `thrift:"host_scheme,6" frugal:"6,default,string" json:"host_scheme"`
}
type ResourceURL struct {
// REQUIRED; 结果图访问精简地址,与默认地址相比缺少 Bucket 部分。
CompactURL string `json:"CompactURL"`
// REQUIRED; 结果图访问默认地址。
URL string `json:"URL"`
}
type UploadResult struct {
Result *Result `json:"Results"`
RequestId string `json:"RequestId"`
FileInfo *FileInfo `json:"PluginResult"`
}
type Result struct {
Uri string `json:"Uri"`
UriStatus int `json:"UriStatus"` // 2000表示上传成功
}
type FileInfo struct {
Name string `json:"FileName"`
Uri string `json:"ImageUri"`
ImageWidth int `json:"ImageWidth"`
ImageHeight int `json:"ImageHeight"`
Md5 string `json:"ImageMd5"`
ImageFormat string `json:"ImageFormat"`
ImageSize int `json:"ImageSize"`
FrameCnt int `json:"FrameCnt"`
Duration int `json:"Duration"`
}

View File

@@ -0,0 +1,72 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package imagex
type UploadAuthOpt func(option *UploadAuthOption)
type UploadAuthOption struct {
ContentTypeBlackList []string
ContentTypeWhiteList []string
FileSizeUpLimit *string
FileSizeBottomLimit *string
KeyPtn *string
UploadOverWrite *bool
conditions map[string]string
StoreKey *string
}
func WithStoreKey(key string) UploadAuthOpt {
return func(o *UploadAuthOption) {
o.StoreKey = &key
}
}
func WithUploadKeyPtn(ptn string) UploadAuthOpt {
return func(o *UploadAuthOption) {
o.KeyPtn = &ptn
}
}
func WithUploadOverwrite(overwrite bool) UploadAuthOpt {
return func(op *UploadAuthOption) {
op.UploadOverWrite = &overwrite
}
}
func WithUploadContentTypeBlackList(blackList []string) UploadAuthOpt {
return func(op *UploadAuthOption) {
op.ContentTypeBlackList = blackList
}
}
func WithUploadContentTypeWhiteList(whiteList []string) UploadAuthOpt {
return func(op *UploadAuthOption) {
op.ContentTypeWhiteList = whiteList
}
}
func WithUploadFileSizeUpLimit(limit string) UploadAuthOpt {
return func(op *UploadAuthOption) {
op.FileSizeUpLimit = &limit
}
}
func WithUploadFileSizeBottomLimit(limit string) UploadAuthOpt {
return func(op *UploadAuthOption) {
op.FileSizeBottomLimit = &limit
}
}

View File

@@ -0,0 +1,27 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package messages2query
import (
"context"
"github.com/cloudwego/eino/schema"
)
type MessagesToQuery interface {
MessagesToQuery(ctx context.Context, messages []*schema.Message, opts ...Option) (newQuery string, err error)
}

View File

@@ -0,0 +1,31 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package messages2query
import "github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
type Option func(o *Options)
type Options struct {
ChatModel chatmodel.BaseChatModel
}
func WithChatModel(cm chatmodel.BaseChatModel) Option {
return func(o *Options) {
o.ChatModel = cm
}
}

View File

@@ -0,0 +1,23 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package orm
import (
"gorm.io/gorm"
)
type DB = gorm.DB

View File

@@ -0,0 +1,91 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package entity
type DataType string
const (
TypeInt DataType = "INT"
TypeVarchar DataType = "VARCHAR"
TypeText DataType = "TEXT"
TypeBoolean DataType = "BOOLEAN"
TypeJson DataType = "JSON"
TypeTimestamp DataType = "TIMESTAMP"
TypeFloat DataType = "FLOAT"
TypeBigInt DataType = "BIGINT"
TypeDouble DataType = "DOUBLE"
)
type IndexType string
const (
PrimaryKey IndexType = "PRIMARY KEY"
UniqueKey IndexType = "UNIQUE KEY"
NormalKey IndexType = "KEY"
)
// AlterTableAction 定义修改表的动作类型
type AlterTableAction string
const (
AddColumn AlterTableAction = "ADD COLUMN"
DropColumn AlterTableAction = "DROP COLUMN"
ModifyColumn AlterTableAction = "MODIFY COLUMN"
RenameColumn AlterTableAction = "RENAME COLUMN"
AddIndex AlterTableAction = "ADD INDEX"
)
type LogicalOperator string
const (
AND LogicalOperator = "AND"
OR LogicalOperator = "OR"
)
type Operator string
const (
OperatorEqual Operator = "="
OperatorNotEqual Operator = "!="
OperatorGreater Operator = ">"
OperatorGreaterEqual Operator = ">="
OperatorLess Operator = "<"
OperatorLessEqual Operator = "<="
OperatorLike Operator = "LIKE"
OperatorNotLike Operator = "NOT LIKE"
OperatorIn Operator = "IN"
OperatorNotIn Operator = "NOT IN"
OperatorIsNull Operator = "IS NULL"
OperatorIsNotNull Operator = "IS NOT NULL"
)
type SortDirection string
const (
SortDirectionAsc SortDirection = "ASC" // 升序
SortDirectionDesc SortDirection = "DESC" // 降序
)
type SQLType int32
const (
SQLType_Parameterized SQLType = 0
SQLType_Raw SQLType = 1 // Complete/raw SQL
)

View File

@@ -0,0 +1,54 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package entity
type Column struct {
Name string // 保证唯一性
DataType DataType
Length *int
NotNull bool
DefaultValue *string
AutoIncrement bool // 表示该列是否为自动递增
Comment *string
}
type Index struct {
Name string
Type IndexType
Columns []string
}
type TableOption struct {
Collate *string
AutoIncrement *int64 // 设置表的自动递增初始值
Comment *string
}
type Table struct {
Name string // 保证唯一性
Columns []*Column
Indexes []*Index
Options *TableOption
CreatedAt int64
UpdatedAt int64
}
type ResultSet struct {
Columns []string
Rows []map[string]interface{}
AffectedRows int64
}

View File

@@ -0,0 +1,189 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package rdb
import (
"context"
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb/entity"
)
//go:generate mockgen -destination ../../../internal/mock/infra/contract/rdb/rdb_mock.go --package rdb -source rdb.go
type RDB interface {
CreateTable(ctx context.Context, req *CreateTableRequest) (*CreateTableResponse, error)
AlterTable(ctx context.Context, req *AlterTableRequest) (*AlterTableResponse, error)
DropTable(ctx context.Context, req *DropTableRequest) (*DropTableResponse, error)
GetTable(ctx context.Context, req *GetTableRequest) (*GetTableResponse, error)
InsertData(ctx context.Context, req *InsertDataRequest) (*InsertDataResponse, error)
UpdateData(ctx context.Context, req *UpdateDataRequest) (*UpdateDataResponse, error)
DeleteData(ctx context.Context, req *DeleteDataRequest) (*DeleteDataResponse, error)
SelectData(ctx context.Context, req *SelectDataRequest) (*SelectDataResponse, error)
UpsertData(ctx context.Context, req *UpsertDataRequest) (*UpsertDataResponse, error)
ExecuteSQL(ctx context.Context, req *ExecuteSQLRequest) (*ExecuteSQLResponse, error)
}
// CreateTableRequest 创建表请求
type CreateTableRequest struct {
Table *entity.Table
}
// CreateTableResponse 创建表响应
type CreateTableResponse struct {
Table *entity.Table
}
// AlterTableOperation 修改表操作
type AlterTableOperation struct {
Action entity.AlterTableAction
Column *entity.Column
OldName *string
Index *entity.Index
IndexName *string
NewTableName *string
}
// AlterTableRequest 修改表请求
type AlterTableRequest struct {
TableName string
Operations []*AlterTableOperation
}
// AlterTableResponse 修改表响应
type AlterTableResponse struct {
Table *entity.Table
}
// DropTableRequest 删除表请求
type DropTableRequest struct {
TableName string
IfExists bool
}
// DropTableResponse 删除表响应
type DropTableResponse struct {
Success bool
}
// GetTableRequest 获取表信息请求
type GetTableRequest struct {
TableName string
}
// GetTableResponse 获取表信息响应
type GetTableResponse struct {
Table *entity.Table
}
// InsertDataRequest 插入数据请求
type InsertDataRequest struct {
TableName string
Data []map[string]interface{}
}
// InsertDataResponse 插入数据响应
type InsertDataResponse struct {
AffectedRows int64
}
// Condition 定义查询条件
type Condition struct {
Field string
Operator entity.Operator
Value interface{}
}
// ComplexCondition 复杂条件
type ComplexCondition struct {
Conditions []*Condition
NestedConditions []*ComplexCondition // 与 Conditions互斥 example: WHERE (age >= 18 AND status = 'active') OR (age >= 21 AND status = 'pending')
Operator entity.LogicalOperator
}
// UpdateDataRequest 更新数据请求
type UpdateDataRequest struct {
TableName string
Data map[string]interface{}
Where *ComplexCondition
Limit *int
}
// UpdateDataResponse 更新数据响应
type UpdateDataResponse struct {
AffectedRows int64
}
// DeleteDataRequest 删除数据请求
type DeleteDataRequest struct {
TableName string
Where *ComplexCondition
Limit *int
}
// DeleteDataResponse 删除数据响应
type DeleteDataResponse struct {
AffectedRows int64
}
type OrderBy struct {
Field string // 排序字段
Direction entity.SortDirection // 排序方向
}
// SelectDataRequest 查询数据请求
type SelectDataRequest struct {
TableName string
Fields []string // 要查询的字段,如果为空则查询所有字段
Where *ComplexCondition
OrderBy []*OrderBy // 排序条件
Limit *int // 限制返回行数
Offset *int // 偏移量
}
// SelectDataResponse 查询数据响应
type SelectDataResponse struct {
ResultSet *entity.ResultSet
Total int64 // 符合条件的总记录数(不考虑分页)
}
type UpsertDataRequest struct {
TableName string
Data []map[string]interface{} // 要更新或插入的数据
Keys []string // 用于标识唯一记录的列名,为空的话默认使用主键
}
type UpsertDataResponse struct {
AffectedRows int64 // 受影响的行数
InsertedRows int64 // 新插入的行数
UpdatedRows int64 // 更新的行数
UnchangedRows int64 // 不变的行数(没有插入或更新的行数)
}
// ExecuteSQLRequest 执行SQL请求
type ExecuteSQLRequest struct {
SQL string
Params []interface{} // 用于参数化查询
// SQLType indicates the type of SQL: parameterized or raw SQL. It takes effect if OperateType is 0.
SQLType entity.SQLType
}
// ExecuteSQLResponse 执行SQL响应
type ExecuteSQLResponse struct {
ResultSet *entity.ResultSet
}

View File

@@ -0,0 +1,67 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sqlparser
// TableColumn represents table and column name mapping
type TableColumn struct {
NewTableName *string // if nil, not replace table name
ColumnMap map[string]string // Column name mapping: key is original column name, value is new column name
}
type ColumnValue struct {
ColName string
Value interface{}
}
type PrimaryKeyValue struct {
ColName string
Values []interface{}
}
// OperationType represents the type of SQL operation
type OperationType string
// SQL operation types
const (
OperationTypeSelect OperationType = "SELECT"
OperationTypeInsert OperationType = "INSERT"
OperationTypeUpdate OperationType = "UPDATE"
OperationTypeDelete OperationType = "DELETE"
OperationTypeCreate OperationType = "CREATE"
OperationTypeAlter OperationType = "ALTER"
OperationTypeDrop OperationType = "DROP"
OperationTypeTruncate OperationType = "TRUNCATE"
OperationTypeUnknown OperationType = "UNKNOWN"
)
// SQLParser defines the interface for parsing and modifying SQL statements
type SQLParser interface {
// ParseAndModifySQL parses SQL and replaces table/column names according to the provided message
ParseAndModifySQL(sql string, tableColumns map[string]TableColumn) (string, error) // tableColumns Original table name -> new TableInfo
// GetSQLOperation identifies the operation type in the SQL statement
GetSQLOperation(sql string) (OperationType, error)
// AddColumnsToInsertSQL adds columns to the INSERT SQL statement.
AddColumnsToInsertSQL(origSQL string, addCols []ColumnValue, colVals *PrimaryKeyValue, isParam bool) (string, map[string]bool, error)
// GetTableName extracts the table name from a SQL statement. Only supports single-table select/insert/update/delete. If it has multiple tables, return first table name.
GetTableName(sql string) (string, error)
// GetInsertDataNums extracts the number of rows to be inserted from a SQL statement. Only supports single-table insert.
GetInsertDataNums(sql string) (int, error)
}

View File

@@ -0,0 +1,27 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sse
import (
"context"
"github.com/hertz-contrib/sse"
)
type SSender interface {
Send(ctx context.Context, s *sse.Stream, event *sse.Event) error
}

View File

@@ -0,0 +1,73 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package storage
import (
"time"
)
type GetOptFn func(option *GetOption)
type GetOption struct {
Expire int64 // seconds
}
func WithExpire(expire int64) GetOptFn {
return func(o *GetOption) {
o.Expire = expire
}
}
type PutOption struct {
ContentType *string
ContentEncoding *string
ContentDisposition *string
ContentLanguage *string
Expires *time.Time
}
type PutOptFn func(option *PutOption)
func WithContentType(v string) PutOptFn {
return func(o *PutOption) {
o.ContentType = &v
}
}
func WithContentEncoding(v string) PutOptFn {
return func(o *PutOption) {
o.ContentEncoding = &v
}
}
func WithContentDisposition(v string) PutOptFn {
return func(o *PutOption) {
o.ContentDisposition = &v
}
}
func WithContentLanguage(v string) PutOptFn {
return func(o *PutOption) {
o.ContentLanguage = &v
}
}
func WithExpires(v time.Time) PutOptFn {
return func(o *PutOption) {
o.Expires = &v
}
}

View File

@@ -0,0 +1,35 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package storage
import "context"
//go:generate mockgen -destination ../../../internal/mock/infra/contract/storage/storage_mock.go -package mock -source storage.go Factory
type Storage interface {
PutObject(ctx context.Context, objectKey string, content []byte, opts ...PutOptFn) error
GetObject(ctx context.Context, objectKey string) ([]byte, error)
DeleteObject(ctx context.Context, objectKey string) error
GetObjectUrl(ctx context.Context, objectKey string, opts ...GetOptFn) (string, error)
}
type SecurityToken struct {
AccessKeyID string `thrift:"access_key_id,1" frugal:"1,default,string" json:"access_key_id"`
SecretAccessKey string `thrift:"secret_access_key,2" frugal:"2,default,string" json:"secret_access_key"`
SessionToken string `thrift:"session_token,3" frugal:"3,default,string" json:"session_token"`
ExpiredTime string `thrift:"expired_time,4" frugal:"4,default,string" json:"expired_time"`
CurrentTime string `thrift:"current_time,5" frugal:"5,default,string" json:"current_time"`
}

46
backend/infra/impl/cache/redis/redis.go vendored Normal file
View File

@@ -0,0 +1,46 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package redis
import (
"os"
"time"
"github.com/redis/go-redis/v9"
)
type Client = redis.Client
func New() *redis.Client {
addr := os.Getenv("REDIS_ADDR")
rdb := redis.NewClient(&redis.Options{
Addr: addr, // Redis地址
DB: 0, // 默认数据库
// 连接池配置
PoolSize: 100, // 最大连接数建议设置为CPU核心数*10
MinIdleConns: 10, // 最小空闲连接
MaxIdleConns: 30, // 最大空闲连接
ConnMaxIdleTime: 5 * time.Minute, // 空闲连接超时时间
// 超时配置
DialTimeout: 5 * time.Second, // 连接建立超时
ReadTimeout: 3 * time.Second, // 读操作超时
WriteTimeout: 3 * time.Second, // 写操作超时
})
return rdb
}

View File

@@ -0,0 +1,273 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package chatmodel
import (
"context"
"fmt"
"github.com/cloudwego/eino-ext/components/model/ark"
"github.com/cloudwego/eino-ext/components/model/claude"
"github.com/cloudwego/eino-ext/components/model/deepseek"
"github.com/cloudwego/eino-ext/components/model/gemini"
"github.com/cloudwego/eino-ext/components/model/ollama"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino-ext/components/model/qwen"
"github.com/ollama/ollama/api"
"google.golang.org/genai"
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
type Builder func(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error)
func NewDefaultFactory() chatmodel.Factory {
return NewFactory(nil)
}
func NewFactory(customFactory map[chatmodel.Protocol]Builder) chatmodel.Factory {
protocol2Builder := map[chatmodel.Protocol]Builder{
chatmodel.ProtocolOpenAI: openAIBuilder,
chatmodel.ProtocolClaude: claudeBuilder,
chatmodel.ProtocolDeepseek: deepseekBuilder,
chatmodel.ProtocolArk: arkBuilder,
chatmodel.ProtocolGemini: geminiBuilder,
chatmodel.ProtocolOllama: ollamaBuilder,
chatmodel.ProtocolQwen: qwenBuilder,
chatmodel.ProtocolErnie: nil,
}
for p := range customFactory {
protocol2Builder[p] = customFactory[p]
}
return &defaultFactory{protocol2Builder: protocol2Builder}
}
type defaultFactory struct {
protocol2Builder map[chatmodel.Protocol]Builder
}
func (f *defaultFactory) CreateChatModel(ctx context.Context, protocol chatmodel.Protocol, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
if config == nil {
return nil, fmt.Errorf("[CreateChatModel] config not provided")
}
builder, found := f.protocol2Builder[protocol]
if !found {
return nil, fmt.Errorf("[CreateChatModel] protocol not support, protocol=%s", protocol)
}
return builder(ctx, config)
}
func (f *defaultFactory) SupportProtocol(protocol chatmodel.Protocol) bool {
_, found := f.protocol2Builder[protocol]
return found
}
func openAIBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
cfg := &openai.ChatModelConfig{
APIKey: config.APIKey,
Timeout: config.Timeout,
BaseURL: config.BaseURL,
Model: config.Model,
MaxTokens: config.MaxTokens,
Temperature: config.Temperature,
TopP: config.TopP,
Stop: config.Stop,
PresencePenalty: config.PresencePenalty,
FrequencyPenalty: config.FrequencyPenalty,
}
if config.OpenAI != nil {
cfg.ByAzure = config.OpenAI.ByAzure
cfg.APIVersion = config.OpenAI.APIVersion
cfg.ResponseFormat = config.OpenAI.ResponseFormat
}
return openai.NewChatModel(ctx, cfg)
}
func claudeBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
cfg := &claude.Config{
APIKey: config.APIKey,
Model: config.Model,
Temperature: config.Temperature,
TopP: config.TopP,
StopSequences: config.Stop,
}
if config.BaseURL != "" {
cfg.BaseURL = &config.BaseURL
}
if config.MaxTokens != nil {
cfg.MaxTokens = *config.MaxTokens
}
if config.TopK != nil {
cfg.TopK = ptr.Of(int32(*config.TopK))
}
if config.Claude != nil {
cfg.ByBedrock = config.Claude.ByBedrock
cfg.AccessKey = config.Claude.AccessKey
cfg.SecretAccessKey = config.Claude.SecretAccessKey
cfg.SessionToken = config.Claude.SessionToken
cfg.Region = config.Claude.Region
}
return claude.NewChatModel(ctx, cfg)
}
func deepseekBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
cfg := &deepseek.ChatModelConfig{
APIKey: config.APIKey,
Timeout: config.Timeout,
BaseURL: config.BaseURL,
Model: config.Model,
Stop: config.Stop,
}
if config.Temperature != nil {
cfg.Temperature = *config.Temperature
}
if config.FrequencyPenalty != nil {
cfg.FrequencyPenalty = *config.FrequencyPenalty
}
if config.PresencePenalty != nil {
cfg.PresencePenalty = *config.PresencePenalty
}
if config.MaxTokens != nil {
cfg.MaxTokens = *config.MaxTokens
}
if config.TopP != nil {
cfg.TopP = *config.TopP
}
if config.Deepseek != nil {
cfg.ResponseFormatType = config.Deepseek.ResponseFormatType
}
return deepseek.NewChatModel(ctx, cfg)
}
func arkBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
cfg := &ark.ChatModelConfig{
BaseURL: config.BaseURL,
APIKey: config.APIKey,
Model: config.Model,
MaxTokens: config.MaxTokens,
Temperature: config.Temperature,
TopP: config.TopP,
Stop: config.Stop,
FrequencyPenalty: config.FrequencyPenalty,
PresencePenalty: config.PresencePenalty,
}
if config.Timeout != 0 {
cfg.Timeout = &config.Timeout
}
if config.Ark != nil {
cfg.Region = config.Ark.Region
cfg.AccessKey = config.Ark.AccessKey
cfg.SecretKey = config.Ark.SecretKey
cfg.RetryTimes = config.Ark.RetryTimes
cfg.CustomHeader = config.Ark.CustomHeader
}
return ark.NewChatModel(ctx, cfg)
}
func ollamaBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
cfg := &ollama.ChatModelConfig{
BaseURL: config.BaseURL,
Timeout: config.Timeout,
HTTPClient: nil,
Model: config.Model,
Format: nil,
KeepAlive: nil,
Options: &api.Options{
TopK: ptr.From(config.TopK),
TopP: ptr.From(config.TopP),
Temperature: ptr.From(config.Temperature),
PresencePenalty: ptr.From(config.PresencePenalty),
FrequencyPenalty: ptr.From(config.FrequencyPenalty),
Stop: config.Stop,
},
}
return ollama.NewChatModel(ctx, cfg)
}
func qwenBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
cfg := &qwen.ChatModelConfig{
APIKey: config.APIKey,
Timeout: config.Timeout,
BaseURL: config.BaseURL,
Model: config.Model,
MaxTokens: config.MaxTokens,
Temperature: config.Temperature,
TopP: config.TopP,
Stop: config.Stop,
PresencePenalty: config.PresencePenalty,
FrequencyPenalty: config.FrequencyPenalty,
EnableThinking: config.EnableThinking,
}
if config.Qwen != nil {
cfg.ResponseFormat = config.Qwen.ResponseFormat
}
return qwen.NewChatModel(ctx, cfg)
}
func geminiBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
gc := &genai.ClientConfig{
APIKey: config.APIKey,
HTTPOptions: genai.HTTPOptions{
BaseURL: config.BaseURL,
},
}
if config.Gemini != nil {
gc.Backend = config.Gemini.Backend
gc.Project = config.Gemini.Project
gc.Location = config.Gemini.Location
gc.HTTPOptions.APIVersion = config.Gemini.APIVersion
gc.HTTPOptions.Headers = config.Gemini.Headers
}
client, err := genai.NewClient(ctx, gc)
if err != nil {
return nil, err
}
cfg := &gemini.Config{
Client: client,
Model: config.Model,
MaxTokens: config.MaxTokens,
Temperature: config.Temperature,
TopP: config.TopP,
ThinkingConfig: &genai.ThinkingConfig{
IncludeThoughts: true,
ThinkingBudget: nil,
},
}
if config.TopK != nil {
cfg.TopK = ptr.Of(int32(ptr.From(config.TopK)))
}
if config.Gemini != nil && config.Gemini.IncludeThoughts != nil {
cfg.ThinkingConfig.IncludeThoughts = ptr.From(config.Gemini.IncludeThoughts)
}
if config.Gemini != nil && config.Gemini.ThinkingBudget != nil {
cfg.ThinkingConfig.ThinkingBudget = config.Gemini.ThinkingBudget
}
cm, err := gemini.NewChatModel(ctx, cfg)
if err != nil {
return nil, err
}
return cm, nil
}

View File

@@ -0,0 +1,38 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package chatmodel
import (
"sync"
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
)
var (
once sync.Once
singletonFactory chatmodel.Factory
)
func InitSingletonFactory(factory chatmodel.Factory) {
once.Do(func() {
singletonFactory = factory
})
}
func GetSingletonFactory() chatmodel.Factory {
return singletonFactory
}

View File

@@ -0,0 +1,49 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package checkpoint
import (
"context"
"sync"
"github.com/cloudwego/eino/compose"
)
type inMemoryStore struct {
m map[string][]byte
mu sync.RWMutex
}
func (i *inMemoryStore) Get(_ context.Context, checkPointID string) ([]byte, bool, error) {
i.mu.RLock()
v, ok := i.m[checkPointID]
i.mu.RUnlock()
return v, ok, nil
}
func (i *inMemoryStore) Set(_ context.Context, checkPointID string, checkPoint []byte) error {
i.mu.Lock()
i.m[checkPointID] = checkPoint
i.mu.Unlock()
return nil
}
func NewInMemoryStore() compose.CheckPointStore {
return &inMemoryStore{
m: make(map[string][]byte),
}
}

View File

@@ -0,0 +1,55 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package checkpoint
import (
"context"
"errors"
"fmt"
"time"
"github.com/cloudwego/eino/compose"
"github.com/redis/go-redis/v9"
)
type redisStore struct {
client *redis.Client
}
const (
checkpointKeyTpl = "checkpoint_key:%s"
checkpointExpire = 24 * 7 * 3600 * time.Second
)
func (r *redisStore) Get(ctx context.Context, checkPointID string) ([]byte, bool, error) {
v, err := r.client.Get(ctx, fmt.Sprintf(checkpointKeyTpl, checkPointID)).Bytes()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, false, nil
}
return nil, false, err
}
return v, true, nil
}
func (r *redisStore) Set(ctx context.Context, checkPointID string, checkPoint []byte) error {
return r.client.Set(ctx, fmt.Sprintf(checkpointKeyTpl, checkPointID), checkPoint, checkpointExpire).Err()
}
func NewRedisStore(client *redis.Client) compose.CheckPointStore {
return &redisStore{client: client}
}

View File

@@ -0,0 +1,97 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package coderunner
import (
"bytes"
"context"
"fmt"
"os/exec"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
"github.com/coze-dev/coze-studio/backend/pkg/goutil"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
var pythonCode = `
import asyncio
import json
import sys
class Args:
def __init__(self, params):
self.params = params
class Output(dict):
pass
%s
try:
result = asyncio.run(main( Args(json.loads(sys.argv[1]))))
print(json.dumps(result))
except Exception as e:
print(f"{type(e).__name__}: {str(e)}", file=sys.stderr)
sys.exit(1)
`
type Runner struct{}
func NewRunner() *Runner {
return &Runner{}
}
func (r *Runner) Run(ctx context.Context, request *code.RunRequest) (*code.RunResponse, error) {
var (
params = request.Params
c = request.Code
)
if request.Language == code.Python {
ret, err := r.pythonCmdRun(ctx, c, params)
if err != nil {
return nil, err
}
return &code.RunResponse{
Result: ret,
}, nil
}
return nil, fmt.Errorf("unsupported language: %s", request.Language)
}
func (r *Runner) pythonCmdRun(_ context.Context, code string, params map[string]any) (map[string]any, error) {
bs, _ := sonic.Marshal(params)
cmd := exec.Command(goutil.GetPython3Path(), "-c", fmt.Sprintf(pythonCode, code), string(bs)) //ignore_security_alert RCE
stdout := new(bytes.Buffer)
stderr := new(bytes.Buffer)
cmd.Stdout = stdout
cmd.Stderr = stderr
err := cmd.Run()
if err != nil {
return nil, fmt.Errorf("failed to run python script err: %s, std err: %s", err.Error(), stderr.String())
}
if stderr.String() != "" {
return nil, fmt.Errorf("failed to run python script err: %s", stderr.String())
}
ret := make(map[string]any)
err = sonic.Unmarshal(stdout.Bytes(), &ret)
if err != nil {
return nil, err
}
return ret, nil
}

View File

@@ -0,0 +1,62 @@
import json
import sys
import os
import asyncio
import time
import random
try:
from RestrictedPython import safe_builtins, limited_builtins, utility_builtins
except ModuleNotFoundError:
print("RestrictedPython module required, please run pip install RestrictedPython",file=sys.stderr)
sys.exit(1)
custom_builtins = safe_builtins.copy()
custom_builtins['__import__'] = __import__
custom_builtins['asyncio'] = asyncio
custom_builtins['json'] = json
custom_builtins['time'] = time
custom_builtins['random'] = random
restricted_globals = {
'__builtins__': custom_builtins,
'_utility_builtins': utility_builtins,
'_limited_builtins': limited_builtins,
'__name__': '__main__',
'dict': dict,
'list': list,
'print': print,
'set': set,
}
class Args:
def __init__(self, params):
self.params = params
DefaultCode = """
class Args:
def __init__(self, params):
self.params = params
class Output(dict):
pass
"""
async def run_main(app_code, params):
try:
complete_code = DefaultCode + app_code
locals_dict = {"args": Args(params=params)}
exec(complete_code, restricted_globals, locals_dict) # ignore_security_alert
main_func = locals_dict['main']
ret = await main_func(locals_dict['args'])
except Exception as e:
print(f"{type(e).__name__}: {str(e)}", file=sys.stderr)
sys.exit(1)
return ret
code = sys.argv[1]
result = asyncio.run(run_main(code, params=json.loads(sys.argv[2])))
print(json.dumps(result))

View File

@@ -0,0 +1,118 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/nl2sql"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
const (
defaultTableFmt = "table name: %s.\ntable describe: %s.\n\n| field name | description | field type | is required |\n"
defaultColumnFmt = "| %s | %s | %s | %t |\n\n"
)
func NewNL2SQL(_ context.Context, cm chatmodel.BaseChatModel, tpl prompt.ChatTemplate) (nl2sql.NL2SQL, error) {
return &n2s{cm: cm, tpl: tpl}, nil
}
type n2s struct {
ch *compose.Chain[*nl2sqlInput, string]
runnable compose.Runnable[*nl2sqlInput, string]
cm chatmodel.BaseChatModel
tpl prompt.ChatTemplate
}
func (n *n2s) NL2SQL(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) {
o := &nl2sql.Options{ChatModel: n.cm}
for _, opt := range opts {
opt(o)
}
if o.ChatModel == nil {
return "", fmt.Errorf("[NL2SQL] chat model not configured")
}
c := compose.NewChain[*nl2sqlInput, string]().
AppendLambda(compose.InvokableLambda(func(ctx context.Context, input *nl2sqlInput) (output map[string]any, err error) {
if len(input.tables) == 0 {
return nil, errors.New("table meta is empty")
}
tableDesc := strings.Builder{}
for _, table := range input.tables {
tableDesc.WriteString(fmt.Sprintf(defaultTableFmt, table.Name, table.Comment))
for _, column := range table.Columns {
tableDesc.WriteString(fmt.Sprintf(defaultColumnFmt, column.Name, column.Description, column.Type.String(), !column.Nullable))
}
}
//logs.CtxInfof(ctx, "table schema: %s", tableDesc.String())
return map[string]interface{}{
"messages": input.messages,
"table_schema": tableDesc.String(),
}, nil
})).
AppendChatTemplate(n.tpl).
AppendChatModel(o.ChatModel).
AppendLambda(compose.InvokableLambda(func(ctx context.Context, msg *schema.Message) (sql string, err error) {
var promptResp *promptResponse
if err := json.Unmarshal([]byte(msg.Content), &promptResp); err != nil {
logs.CtxWarnf(ctx, "unmarshal failed: %v", err)
return "", err
}
if promptResp.SQL == "" {
logs.CtxInfof(ctx, "no sql generated, err_code: %v, err_msg: %v", promptResp.ErrCode, promptResp.ErrMsg)
return "", errors.New(promptResp.ErrMsg)
}
return promptResp.SQL, nil
}))
r, err := c.Compile(ctx)
if err != nil {
return "", err
}
input := &nl2sqlInput{
messages: messages,
tables: tables,
}
return r.Invoke(ctx, input)
}
type nl2sqlInput struct {
messages []*schema.Message
tables []*document.TableSchema
}
type promptResponse struct {
SQL string `json:"sql"`
ErrCode int `json:"err_code"`
ErrMsg string `json:"err_msg"`
}

View File

@@ -0,0 +1,139 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"testing"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/schema"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
)
func TestNL2SQL(t *testing.T) {
ctx := context.Background()
t.Run("test table meta not provided", func(t *testing.T) {
impl, err := NewNL2SQL(ctx, &mockChatModel{"mock resp"}, prompt.FromMessages(schema.Jinja2,
schema.SystemMessage("system message 123"),
schema.UserMessage("{{messages}}, {{table_meta}}"),
))
assert.NoError(t, err)
sql, err := impl.NL2SQL(ctx, []*schema.Message{schema.UserMessage("hello")}, nil)
assert.Error(t, err)
assert.Equal(t, "", sql)
})
t.Run("test parse failed", func(t *testing.T) {
impl, err := NewNL2SQL(ctx, &mockChatModel{"mock resp"}, prompt.FromMessages(schema.Jinja2,
schema.SystemMessage("system message 123"),
schema.UserMessage("{{messages}}, {{table_meta}}"),
))
assert.NoError(t, err)
sql, err := impl.NL2SQL(ctx, []*schema.Message{schema.UserMessage("hello")}, []*document.TableSchema{
{
Name: "mock_table_1",
Comment: "hello",
Columns: []*document.Column{
{
ID: 121,
Name: "id",
Type: document.TableColumnTypeInteger,
Description: "test",
Nullable: false,
IsPrimary: true,
Sequence: 0,
},
{
ID: 123,
Name: "col_1",
Type: document.TableColumnTypeString,
Description: "column_1",
Nullable: true,
IsPrimary: false,
Sequence: 1,
},
},
},
})
assert.Error(t, err)
assert.Equal(t, "", sql)
})
t.Run("test success", func(t *testing.T) {
impl, err := NewNL2SQL(ctx, &mockChatModel{`{"sql":"mock sql","err_code":0,"err_msg":""}`}, prompt.FromMessages(schema.Jinja2,
schema.SystemMessage("system message 123"),
schema.UserMessage("{{messages}}, {{table_meta}}"),
))
assert.NoError(t, err)
sql, err := impl.NL2SQL(ctx, []*schema.Message{schema.UserMessage("hello")}, []*document.TableSchema{
{
Name: "mock_table_1",
Comment: "hello",
Columns: []*document.Column{
{
ID: 121,
Name: "id",
Type: document.TableColumnTypeInteger,
Description: "test",
Nullable: false,
IsPrimary: true,
Sequence: 0,
},
{
ID: 123,
Name: "col_1",
Type: document.TableColumnTypeString,
Description: "column_1",
Nullable: true,
IsPrimary: false,
Sequence: 1,
},
},
},
})
assert.NoError(t, err)
assert.Equal(t, "mock sql", sql)
})
}
type mockChatModel struct {
content string
}
func (m mockChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
return schema.AssistantMessage(m.content, nil), nil
}
func (m mockChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
return nil, nil
}
func (m mockChatModel) BindTools(tools []*schema.ToolInfo) error {
return nil
}
const sys = "# Role: NL2SQL Consultant\n\n## Goals\nTranslate natural language statements into SQL queries in MySQL standard. Follow the Constraints and return only a JSON always.\n\n## Format\n- JSON format only. JSON contains field \"sql\" for generated SQL, filed \"err_code\" for reason type, field \"err_msg\" for detail reason (prefer more than 10 words)\n- Don't use \"```json\" markdown format\n\n## Skills\n- Good at Translate natural language statements into SQL queries in MySQL standard.\n\n## Define\n\"err_code\" Reason Type Define:\n- 0 means you generated a SQL\n- 3002 means you cannot generate a SQL because of timeout\n- 3003 means you cannot generate a SQL because of table schema missing\n- 3005 means you cannot generate a SQL because of some term is ambiguous\n\n## Example\nQ: Help me implement NL2SQL.\n.table schema description: CREATE TABLE `sales_records` (\\n `sales_id` bigint(20) unsigned NOT NULL COMMENT 'id of sales person',\\n `product_id` bigint(64) COMMENT 'id of product',\\n `sale_date` datetime(3) COMMENT 'sold date and time',\\n `quantity_sold` int(11) COMMENT 'sold amount',\\n PRIMARY KEY (`sales_id`)\\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='销售记录表';\n.natural language description of the SQL requirement: ​​​​查询上月的销量总额第一名的销售员和他的销售总额\nA: {\n \"sql\":\"SELECT sales_id, SUM(quantity_sold) AS total_sales FROM sales_records WHERE MONTH(sale_date) = MONTH(CURRENT_DATE - INTERVAL 1 MONTH) AND YEAR(sale_date) = YEAR(CURRENT_DATE - INTERVAL 1 MONTH) GROUP BY sales_id ORDER BY total_sales DESC LIMIT 1\",\n \"err_code\":0,\n \"err_msg\":\"SQL query generated successfully\"\n}"
const usr = "help me implement NL2SQL.\ntable schema description:{{tableSchema}}\nnatural language description of the SQL requirement: {{chat_history}}."

View File

@@ -0,0 +1,96 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package veocr
import (
"context"
"fmt"
"net/http"
"net/url"
"strconv"
"github.com/volcengine/volc-sdk-golang/service/visual"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
)
type Config struct {
Client *visual.Visual
// see: https://www.volcengine.com/docs/6790/117730
ApproximatePixel *int // default: 0
Mode *string // default: "text_block"
FilterThresh *int // default: 80
HalfToFull *bool // default: false
}
func NewOCR(config *Config) ocr.OCR {
return &ocrImpl{config}
}
type ocrImpl struct {
config *Config
}
func (o *ocrImpl) FromBase64(ctx context.Context, b64 string) ([]string, error) {
form := o.newForm()
form.Add("image_base64", b64)
resp, statusCode, err := o.config.Client.OCRNormal(form)
if err != nil {
return nil, err
}
if statusCode != http.StatusOK {
return nil, fmt.Errorf("[FromBase64] failed, status code=%d", statusCode)
}
return resp.Data.LineTexts, nil
}
func (o *ocrImpl) FromURL(ctx context.Context, url string) ([]string, error) {
form := o.newForm()
form.Add("image_url", url)
resp, statusCode, err := o.config.Client.OCRNormal(form)
if err != nil {
return nil, err
}
if statusCode != http.StatusOK {
return nil, fmt.Errorf("[FromBase64] failed, status code=%d", statusCode)
}
return resp.Data.LineTexts, nil
}
func (o *ocrImpl) newForm() url.Values {
form := url.Values{}
if o.config.ApproximatePixel != nil {
form.Add("approximate_pixel", strconv.FormatInt(int64(*o.config.ApproximatePixel), 10))
}
if o.config.Mode != nil {
form.Add("mode", *o.config.Mode)
} else {
form.Add("mode", "text_block")
}
if o.config.FilterThresh != nil {
form.Add("filter_thresh", strconv.FormatInt(int64(*o.config.FilterThresh), 10))
}
if o.config.HalfToFull != nil {
form.Add("half_to_full", strconv.FormatBool(*o.config.HalfToFull))
}
return form
}

View File

@@ -0,0 +1,37 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
)
func alignTableSliceValue(schema []*document.Column, row []*document.ColumnData) (err error) {
for i, col := range row {
var newCol *document.ColumnData
newCol, err = assertValAs(schema[i].Type, col.GetStringValue())
if err != nil {
return err
}
newCol.ColumnID = col.ColumnID
newCol.ColumnName = col.ColumnName
row[i] = newCol
}
return nil
}

View File

@@ -0,0 +1,142 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"testing"
"time"
. "github.com/bytedance/mockey"
"github.com/smartystreets/goconvey/convey"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
func TestAssertVal(t *testing.T) {
PatchConvey("test assertVal", t, func() {
convey.So(assertVal(""), convey.ShouldEqual, document.ColumnData{
Type: document.TableColumnTypeUnknown,
ValString: ptr.Of(""),
})
convey.So(assertVal("true"), convey.ShouldEqual, document.ColumnData{
Type: document.TableColumnTypeBoolean,
ValBoolean: ptr.Of(true),
})
convey.So(assertVal("10"), convey.ShouldEqual, document.ColumnData{
Type: document.TableColumnTypeInteger,
ValInteger: ptr.Of(int64(10)),
})
convey.So(assertVal("1.0"), convey.ShouldEqual, document.ColumnData{
Type: document.TableColumnTypeNumber,
ValNumber: ptr.Of(1.0),
})
ts := time.Now().Format(timeFormat)
now, err := time.Parse(timeFormat, ts)
convey.So(err, convey.ShouldBeNil)
convey.So(assertVal(ts), convey.ShouldEqual, document.ColumnData{
Type: document.TableColumnTypeTime,
ValTime: ptr.Of(now),
})
convey.So(assertVal("hello"), convey.ShouldEqual, document.ColumnData{
Type: document.TableColumnTypeString,
ValString: ptr.Of("hello"),
})
})
}
func TestAssertValAs(t *testing.T) {
PatchConvey("test assertValAs", t, func() {
type testCase struct {
typ document.TableColumnType
val string
isErr bool
data *document.ColumnData
}
ts := time.Now().Format(timeFormat)
now, _ := time.Parse(timeFormat, ts)
cases := []testCase{
{
typ: document.TableColumnTypeString,
val: "hello",
isErr: false,
data: &document.ColumnData{Type: document.TableColumnTypeString, ValString: ptr.Of("hello")},
},
{
typ: document.TableColumnTypeInteger,
val: "1",
isErr: false,
data: &document.ColumnData{Type: document.TableColumnTypeInteger, ValInteger: ptr.Of(int64(1))},
},
{
typ: document.TableColumnTypeInteger,
val: "hello",
isErr: true,
},
{
typ: document.TableColumnTypeTime,
val: ts,
isErr: false,
data: &document.ColumnData{Type: document.TableColumnTypeTime, ValTime: ptr.Of(now)},
},
{
typ: document.TableColumnTypeTime,
val: "hello",
isErr: true,
},
{
typ: document.TableColumnTypeNumber,
val: "1.0",
isErr: false,
data: &document.ColumnData{Type: document.TableColumnTypeNumber, ValNumber: ptr.Of(1.0)},
},
{
typ: document.TableColumnTypeNumber,
val: "hello",
isErr: true,
},
{
typ: document.TableColumnTypeBoolean,
val: "true",
isErr: false,
data: &document.ColumnData{Type: document.TableColumnTypeBoolean, ValBoolean: ptr.Of(true)},
},
{
typ: document.TableColumnTypeBoolean,
val: "hello",
isErr: true,
},
{
typ: document.TableColumnTypeUnknown,
val: "hello",
isErr: true,
},
}
for _, c := range cases {
v, err := assertValAs(c.typ, c.val)
if c.isErr {
convey.So(err, convey.ShouldNotBeNil)
convey.So(v, convey.ShouldBeNil)
} else {
convey.So(err, convey.ShouldBeNil)
convey.So(v, convey.ShouldEqual, c.data)
}
}
})
}

View File

@@ -0,0 +1,116 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"fmt"
"regexp"
"strings"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
var (
spaceRegex = regexp.MustCompile(`\s+`)
urlRegex = regexp.MustCompile(`https?://\S+|www\.\S+`)
emailRegex = regexp.MustCompile(`[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}`)
)
func chunkCustom(_ context.Context, text string, config *contract.Config, opts ...parser.Option) (docs []*schema.Document, err error) {
cs := config.ChunkingStrategy
if cs.Overlap >= cs.ChunkSize {
return nil, fmt.Errorf("[chunkCustom] invalid param, overlap >= chunk_size")
}
var (
parts = strings.Split(text, cs.Separator)
buffer []rune
currentLength int64
options = parser.GetCommonOptions(&parser.Options{ExtraMeta: map[string]any{}}, opts...)
)
trim := func(text string) string {
if cs.TrimURLAndEmail {
text = urlRegex.ReplaceAllString(text, "")
text = emailRegex.ReplaceAllString(text, "")
}
if cs.TrimSpace {
text = strings.TrimSpace(text)
text = spaceRegex.ReplaceAllString(text, " ")
}
return text
}
add := func() {
if len(buffer) == 0 {
return
}
doc := &schema.Document{
Content: string(buffer),
MetaData: map[string]any{},
}
for k, v := range options.ExtraMeta {
doc.MetaData[k] = v
}
docs = append(docs, doc)
buffer = []rune{}
}
processPart := func(part string) {
runes := []rune(part)
for partLength := int64(len(runes)); partLength > 0; partLength = int64(len(runes)) {
pos := min(partLength, cs.ChunkSize-currentLength)
buffer = append(buffer, runes[:pos]...)
currentLength = int64(len(buffer))
if currentLength >= cs.ChunkSize {
add()
if cs.Overlap > 0 {
buffer = getOverlap([]rune(docs[len(docs)-1].Content), cs.Overlap, cs.ChunkSize)
currentLength = int64(len(buffer))
} else {
currentLength = 0
}
}
runes = runes[pos:]
}
add()
}
for _, part := range parts {
processPart(trim(part))
}
add()
return docs, nil
}
func getOverlap(runes []rune, overlapRatio int64, chunkSize int64) []rune {
overlap := int64(float64(chunkSize) * float64(overlapRatio) / 100)
if int64(len(runes)) <= overlap {
return runes
}
return runes[len(runes)-int(overlap):]
}

View File

@@ -0,0 +1,47 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
func TestChunkCustom(t *testing.T) {
ctx := context.Background()
t.Run("test \n no overlap", func(t *testing.T) {
text := "1. Eiffel Tower: Located in Paris, France, it is one of the most famous landmarks in the world, designed by Gustave Eiffel and built in 1889.\n2. The Great Wall: Located in China, it is one of the Seven Wonders of the World, built from the Qin Dynasty to the Ming Dynasty, with a total length of over 20000 kilometers.\n3. Grand Canyon National Park: Located in Arizona, USA, it is famous for its deep canyons and magnificent scenery, which are cut by the Colorado River.\n4. The Colosseum: Located in Rome, Italy, built between 70-80 AD, it was the largest circular arena in the ancient Roman Empire.\n5. Taj Mahal: Located in Agra, India, it was completed by Mughal Emperor Shah Jahan in 1653 to commemorate his wife and is one of the New Seven Wonders of the World.\n6. Sydney Opera House: Located in Sydney Harbour, Australia, it is one of the most iconic buildings of the 20th century, renowned for its unique sailboat design.\n7. Louvre Museum: Located in Paris, France, it is one of the largest museums in the world with a rich collection, including Leonardo da Vinci's Mona Lisa and Greece's Venus de Milo.\n8. Niagara Falls: located at the border of the United States and Canada, consisting of three main waterfalls, its spectacular scenery attracts millions of tourists every year.\n9. St. Sophia Cathedral: located in Istanbul, Türkiye, originally built in 537 A.D., it used to be an Orthodox cathedral and mosque, and now it is a museum.\n10. Machu Picchu: an ancient Inca site located on the plateau of the Andes Mountains in Peru, one of the New Seven Wonders of the World, with an altitude of over 2400 meters."
cs := &parser.ChunkingStrategy{
ChunkType: parser.ChunkTypeCustom,
ChunkSize: 1000,
Separator: "\n",
Overlap: 0,
TrimSpace: true,
TrimURLAndEmail: true,
}
slices, err := chunkCustom(ctx, text, &parser.Config{ChunkingStrategy: cs})
assert.NoError(t, err)
assert.Len(t, slices, 10)
})
}

View File

@@ -0,0 +1,213 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"fmt"
"strconv"
"time"
"unicode/utf8"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
const (
timeFormat = "2006-01-02 15:04:05"
)
func assertValAs(typ document.TableColumnType, val string) (*document.ColumnData, error) {
if val == "" {
return &document.ColumnData{
Type: typ,
}, nil
}
switch typ {
case document.TableColumnTypeString:
return &document.ColumnData{
Type: document.TableColumnTypeString,
ValString: &val,
}, nil
case document.TableColumnTypeInteger:
i, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, err
}
return &document.ColumnData{
Type: document.TableColumnTypeInteger,
ValInteger: &i,
}, nil
case document.TableColumnTypeTime:
if val == "" {
var emptyTime time.Time
return &document.ColumnData{
Type: document.TableColumnTypeTime,
ValTime: ptr.Of(emptyTime),
}, nil
}
// 支持时间戳和时间字符串
i, err := strconv.ParseInt(val, 10, 64)
if err == nil {
t := time.Unix(i, 0)
return &document.ColumnData{
Type: document.TableColumnTypeTime,
ValTime: &t,
}, nil
}
t, err := time.Parse(timeFormat, val)
if err != nil {
return nil, err
}
return &document.ColumnData{
Type: document.TableColumnTypeTime,
ValTime: &t,
}, nil
case document.TableColumnTypeNumber:
f, err := strconv.ParseFloat(val, 64)
if err != nil {
return nil, err
}
return &document.ColumnData{
Type: document.TableColumnTypeNumber,
ValNumber: &f,
}, nil
case document.TableColumnTypeBoolean:
t, err := strconv.ParseBool(val)
if err != nil {
return nil, err
}
return &document.ColumnData{
Type: document.TableColumnTypeBoolean,
ValBoolean: &t,
}, nil
case document.TableColumnTypeImage:
return &document.ColumnData{
Type: document.TableColumnTypeImage,
ValImage: &val,
}, nil
default:
return nil, fmt.Errorf("[assertValAs] type not support, type=%d, val=%s", typ, val)
}
}
func assertValAsForce(typ document.TableColumnType, val string, nullable bool) *document.ColumnData {
cd := &document.ColumnData{
Type: typ,
}
switch typ {
case document.TableColumnTypeString:
cd.ValString = &val
case document.TableColumnTypeInteger:
if i, err := strconv.ParseInt(val, 10, 64); err == nil {
cd.ValInteger = ptr.Of(i)
} else if !nullable {
cd.ValInteger = ptr.Of(int64(0))
}
case document.TableColumnTypeTime:
if t, err := time.Parse(timeFormat, val); err == nil {
cd.ValTime = ptr.Of(t)
} else if !nullable {
cd.ValTime = ptr.Of(time.Time{})
}
case document.TableColumnTypeNumber:
if f, err := strconv.ParseFloat(val, 64); err == nil {
cd.ValNumber = ptr.Of(f)
} else if !nullable {
cd.ValNumber = ptr.Of(0.0)
}
case document.TableColumnTypeBoolean:
if t, err := strconv.ParseBool(val); err == nil {
cd.ValBoolean = ptr.Of(t)
} else if !nullable {
cd.ValBoolean = ptr.Of(false)
}
case document.TableColumnTypeImage:
cd.ValImage = ptr.Of(val)
default:
cd.ValString = &val
}
return cd
}
func assertVal(val string) document.ColumnData {
// TODO: 先不处理 image
if val == "" {
return document.ColumnData{
Type: document.TableColumnTypeUnknown,
ValString: &val,
}
}
if t, err := strconv.ParseBool(val); err == nil {
return document.ColumnData{
Type: document.TableColumnTypeBoolean,
ValBoolean: &t,
}
}
if i, err := strconv.ParseInt(val, 10, 64); err == nil {
return document.ColumnData{
Type: document.TableColumnTypeInteger,
ValInteger: &i,
}
}
if f, err := strconv.ParseFloat(val, 64); err == nil {
return document.ColumnData{
Type: document.TableColumnTypeNumber,
ValNumber: &f,
}
}
if t, err := time.Parse(timeFormat, val); err == nil {
return document.ColumnData{
Type: document.TableColumnTypeTime,
ValTime: &t,
}
}
return document.ColumnData{
Type: document.TableColumnTypeString,
ValString: &val,
}
}
func transformColumnType(src, dst document.TableColumnType) document.TableColumnType {
if src == document.TableColumnTypeUnknown {
return dst
}
if dst == document.TableColumnTypeUnknown {
return src
}
if dst == document.TableColumnTypeString {
return dst
}
if src == dst {
return dst
}
if src == document.TableColumnTypeInteger && dst == document.TableColumnTypeNumber {
return dst
}
return document.TableColumnTypeString
}
func charCount(text string) int64 {
return int64(utf8.RuneCountInString(text))
}

View File

@@ -0,0 +1,36 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"fmt"
"time"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
)
func putImageObject(ctx context.Context, st storage.Storage, imgExt string, uid int64, img []byte) (format string, err error) {
secret := createSecret(uid, imgExt)
fileName := fmt.Sprintf("%d_%d_%s.%s", uid, time.Now().UnixNano(), secret, imgExt)
objectName := fmt.Sprintf("%s/%s", knowledgePrefix, fileName)
if err := st.PutObject(ctx, objectName, img); err != nil {
return "", err
}
imgSrc := fmt.Sprintf(imgSrcFormat, objectName)
return imgSrc, nil
}

View File

@@ -0,0 +1,77 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"fmt"
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/pkg/goutil"
)
func NewManager(storage storage.Storage, ocr ocr.OCR, imageAnnotationModel chatmodel.BaseChatModel) parser.Manager {
return &manager{
storage: storage,
ocr: ocr,
model: imageAnnotationModel,
}
}
type manager struct {
ocr ocr.OCR
storage storage.Storage
model chatmodel.BaseChatModel
}
func (m *manager) GetParser(config *parser.Config) (parser.Parser, error) {
var pFn parseFn
if config.ParsingStrategy.HeaderLine == 0 && config.ParsingStrategy.DataStartLine == 0 {
config.ParsingStrategy.DataStartLine = 1
} else if config.ParsingStrategy.HeaderLine >= config.ParsingStrategy.DataStartLine {
return nil, fmt.Errorf("[GetParser] invalid header line and data start line, header=%d, data_start=%d",
config.ParsingStrategy.HeaderLine, config.ParsingStrategy.DataStartLine)
}
switch config.FileExtension {
case parser.FileExtensionPDF:
pFn = parseByPython(config, m.storage, m.ocr, goutil.GetPython3Path(), goutil.GetPythonFilePath("parse_pdf.py"))
case parser.FileExtensionTXT:
pFn = parseText(config)
case parser.FileExtensionMarkdown:
pFn = parseMarkdown(config, m.storage, m.ocr)
case parser.FileExtensionDocx:
pFn = parseByPython(config, m.storage, m.ocr, goutil.GetPython3Path(), goutil.GetPythonFilePath("parse_docx.py"))
case parser.FileExtensionCSV:
pFn = parseCSV(config)
case parser.FileExtensionXLSX:
pFn = parseXLSX(config)
case parser.FileExtensionJSON:
pFn = parseJSON(config)
case parser.FileExtensionJsonMaps:
pFn = parseJSONMaps(config)
case parser.FileExtensionJPG, parser.FileExtensionJPEG, parser.FileExtensionPNG:
pFn = parseImage(config, m.model)
default:
return nil, fmt.Errorf("[Parse] document type not support, type=%s", config.FileExtension)
}
return &p{parseFn: pFn}, nil
}

View File

@@ -0,0 +1,53 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"encoding/csv"
"errors"
"io"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
"github.com/dimchansky/utfbom"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
func parseCSV(config *contract.Config) parseFn {
return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
iter := &csvIterator{csv.NewReader(utfbom.SkipOnly(reader))}
return parseByRowIterator(iter, config, opts...)
}
}
type csvIterator struct {
reader *csv.Reader
}
func (c *csvIterator) NextRow() (row []string, end bool, err error) {
row, e := c.reader.Read()
if e != nil {
if errors.Is(e, io.EOF) {
return nil, true, nil
}
return nil, false, err
}
return row, false, nil
}

View File

@@ -0,0 +1,200 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"bytes"
"context"
"fmt"
"io"
"os"
"testing"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
func TestParseCSV(t *testing.T) {
ctx := context.Background()
b, err := os.ReadFile("./test_data/test_csv.csv")
assert.NoError(t, err)
r1 := bytes.NewReader(b)
c1 := &contract.Config{
FileExtension: contract.FileExtensionCSV,
ParsingStrategy: &contract.ParsingStrategy{
HeaderLine: 0,
DataStartLine: 1,
RowsCount: 20,
},
ChunkingStrategy: nil,
}
p1 := parseCSV(c1)
docs, err := p1(ctx, r1, parser.WithExtraMeta(map[string]any{
"document_id": int64(123),
"knowledge_id": int64(456),
}))
assert.NoError(t, err)
for i, doc := range docs {
assertSheet(t, i, doc)
}
// parse
r2 := bytes.NewReader(b)
c2 := &contract.Config{
FileExtension: contract.FileExtensionCSV,
ParsingStrategy: &contract.ParsingStrategy{
HeaderLine: 0,
DataStartLine: 1,
RowsCount: 10,
Columns: []*document.Column{
{
ID: 0,
Name: "col_string_indexing",
Type: document.TableColumnTypeString,
Nullable: false,
Sequence: 0,
},
{
ID: 0,
Name: "col_string",
Type: document.TableColumnTypeString,
Nullable: false,
Sequence: 1,
},
{
ID: 0,
Name: "col_int",
Type: document.TableColumnTypeInteger,
Nullable: false,
Sequence: 2,
},
{
ID: 0,
Name: "col_number",
Type: document.TableColumnTypeNumber,
Nullable: true,
Sequence: 3,
},
{
ID: 0,
Name: "col_bool",
Type: document.TableColumnTypeBoolean,
Nullable: true,
Sequence: 4,
},
{
ID: 0,
Name: "col_time",
Type: document.TableColumnTypeTime,
Nullable: true,
Sequence: 5,
},
},
},
ChunkingStrategy: nil,
}
p2 := parseCSV(c2)
docs, err = p2(ctx, r2, parser.WithExtraMeta(map[string]any{
"document_id": int64(123),
"knowledge_id": int64(456),
}))
assert.NoError(t, err)
for i, doc := range docs {
assertSheet(t, i, doc)
}
}
func TestParseCSVBadCases(t *testing.T) {
t.Run("test nil row", func(t *testing.T) {
ctx := context.Background()
f, err := os.Open("test_data/test_csv_badcase_1.csv")
assert.NoError(t, err)
b, err := io.ReadAll(f)
assert.NoError(t, err)
pfn := parseCSV(&contract.Config{
FileExtension: "csv",
ParsingStrategy: &contract.ParsingStrategy{
ExtractImage: true,
ExtractTable: true,
ImageOCR: false,
SheetID: nil,
HeaderLine: 0,
DataStartLine: 1,
RowsCount: 0,
IsAppend: false,
Columns: nil,
IgnoreColumnTypeErr: true,
ImageAnnotationType: 0,
},
})
resp, err := pfn(ctx, bytes.NewReader(b))
assert.NoError(t, err)
assert.True(t, len(resp) > 0)
cols, err := document.GetDocumentColumns(resp[0])
assert.NoError(t, err)
cols[5].Nullable = false
npfn := parseCSV(&contract.Config{
FileExtension: "csv",
ParsingStrategy: &contract.ParsingStrategy{
ExtractImage: true,
ExtractTable: true,
ImageOCR: false,
SheetID: nil,
HeaderLine: 0,
DataStartLine: 1,
RowsCount: 0,
IsAppend: false,
Columns: cols,
IgnoreColumnTypeErr: true,
ImageAnnotationType: 0,
},
})
resp, err = npfn(ctx, bytes.NewReader(b))
assert.NoError(t, err)
assert.True(t, len(resp) > 0)
for _, item := range resp {
data, err := document.GetDocumentColumnData(item)
assert.NoError(t, err)
assert.NotNil(t, data[5].GetValue())
}
})
}
func assertSheet(t *testing.T, i int, doc *schema.Document) {
fmt.Printf("sheet[%d]:\n", i)
assert.NotNil(t, doc.MetaData)
assert.NotNil(t, doc.MetaData[document.MetaDataKeyColumns])
cols, ok := doc.MetaData[document.MetaDataKeyColumns].([]*document.Column)
assert.True(t, ok)
assert.NotNil(t, doc.MetaData[document.MetaDataKeyColumnData])
row, ok := doc.MetaData[document.MetaDataKeyColumnData].([]*document.ColumnData)
assert.True(t, ok)
assert.Equal(t, int64(123), doc.MetaData["document_id"].(int64))
assert.Equal(t, int64(456), doc.MetaData["knowledge_id"].(int64))
for j := range row {
col := cols[j]
val := row[j]
fmt.Printf("row[%d]: %v=%v\n", j, col.Name, val.GetStringValue())
}
}

View File

@@ -0,0 +1,172 @@
import io
import os
import json
import sys
import base64
import logging
import time
from abc import ABC
from typing import List, IO
from docx import ImagePart
from docx.oxml import CT_P, CT_Tbl
from docx.table import Table
from docx.text.paragraph import Paragraph
from docx import Document
from PIL import Image
logger = logging.getLogger(__name__)
class DocxLoader(ABC):
def __init__(
self,
file_content: IO[bytes],
extract_images: bool = True,
extract_tables: bool = True,
):
self.file_content = file_content
self.extract_images = extract_images
self.extract_tables = extract_tables
def load(self) -> List[dict]:
result = []
doc = Document(self.file_content)
it = iter(doc.element.body)
text = ""
for part in it:
blocks = self.parse_part(part, doc)
if blocks is None or len(blocks) == 0:
continue
for block in blocks:
if self.extract_images and isinstance(block, list):
for b in block:
image = io.BytesIO()
try:
Image.open(io.BytesIO(b.image.blob)).save(image, format="png")
except Exception as e:
logging.error(f"load image failed, time={time.asctime()}, err:{e}")
raise RuntimeError("ExtractImageError")
if len(text) > 0:
result.append(
{
"content": text,
"type": "text",
}
)
text = ""
result.append(
{
"content": base64.b64encode(image.getvalue()).decode('utf-8'),
"type": "image",
}
)
if isinstance(block, Paragraph):
text += block.text
if self.extract_tables and isinstance(block, Table):
rows = block.rows
if len(text) > 0:
result.append(
{
"content": text,
"type": "text",
}
)
text = ""
table = self.convert_table(rows)
result.append(
{
"table": table,
"type": "table",
}
)
if text:
text += "\n\n"
if len(text) > 0:
result.append(
{
"content": text,
"type": "text",
}
)
return result
def parse_part(self, block, doc: Document):
if isinstance(block, CT_P):
blocks = []
para = Paragraph(block, doc)
image_part = self.get_image_part(para, doc)
if image_part and para.text:
blocks.extend(self.parse_run(para))
elif image_part:
blocks.append(image_part)
elif para.text:
blocks.append(para)
return blocks
elif isinstance(block, CT_Tbl):
return [Table(block, doc)]
def parse_run(self, para: Paragraph):
runs = para.runs
paras = []
if runs is None or len(runs) == 0:
return paras
for run in runs:
if run is None or run.element is None:
continue
p = Paragraph(run.element, para)
image_part = self.get_image_part(p, para)
if image_part:
paras.append(image_part)
else:
paras.append(p)
return paras
@staticmethod
def get_image_part(graph: Paragraph, doc: Document):
images = graph._element.xpath(".//pic:pic")
image_parts = []
for image in images:
for img_id in image.xpath(".//a:blip/@r:embed"):
part = doc.part.related_parts[img_id]
if isinstance(part, ImagePart):
image_parts.append(part)
return image_parts
@staticmethod
def convert_table(rows) -> List[List[str]]:
resp_rows = []
for i, row in enumerate(rows):
resp_row = []
for j, cell in enumerate(row.cells):
resp_row.append(cell.text if cell is not None else '')
resp_rows.append(resp_row)
return resp_rows
if __name__ == "__main__":
w = os.fdopen(3, "wb", )
r = os.fdopen(4, "rb", )
try:
req = json.load(r)
ei, et = req['extract_images'], req['extract_tables']
loader = DocxLoader(file_content=io.BytesIO(sys.stdin.buffer.read()), extract_images=ei, extract_tables=et)
resp = loader.load()
print(f"Extracted {len(resp)} items")
result = json.dumps({"content": resp}, ensure_ascii=False)
w.write(str.encode(result))
w.flush()
w.close()
print("Docx parse done")
except Exception as e:
print("Docx parse error", e)
w.write(str.encode(json.dumps({"error": str(e)})))
w.flush()
w.close()

View File

@@ -0,0 +1,91 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"encoding/base64"
"fmt"
"io"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
func parseImage(config *contract.Config, model chatmodel.BaseChatModel) parseFn {
return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
options := parser.GetCommonOptions(&parser.Options{}, opts...)
doc := &schema.Document{
MetaData: map[string]any{},
}
for k, v := range options.ExtraMeta {
doc.MetaData[k] = v
}
switch config.ParsingStrategy.ImageAnnotationType {
case contract.ImageAnnotationTypeModel:
if model == nil {
return nil, errorx.New(errno.ErrKnowledgeNonRetryableCode, errorx.KV("reason", "model is not provided"))
}
bytes, err := io.ReadAll(reader)
if err != nil {
return nil, err
}
b64 := base64.StdEncoding.EncodeToString(bytes)
mime := fmt.Sprintf("image/%s", config.FileExtension)
url := fmt.Sprintf("data:%s;base64,%s", mime, b64)
input := &schema.Message{
Role: schema.User,
MultiContent: []schema.ChatMessagePart{
{
Type: schema.ChatMessagePartTypeText,
//Text: "Give a short description of the image.", // TODO: prompt in current language
Text: "简短描述下这张图片",
},
{
Type: schema.ChatMessagePartTypeImageURL,
ImageURL: &schema.ChatMessageImageURL{
URL: url,
MIMEType: mime,
},
},
},
}
output, err := model.Generate(ctx, []*schema.Message{input})
if err != nil {
return nil, fmt.Errorf("[parseImage] model generate failed: %w", err)
}
doc.Content = output.Content
case contract.ImageAnnotationTypeManual:
// do nothing
default:
return nil, fmt.Errorf("[parseImage] unknown image annotation type=%d", config.ParsingStrategy.ImageAnnotationType)
}
return []*schema.Document{doc}, nil
}
}

View File

@@ -0,0 +1,170 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"encoding/json"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
type rowIterator interface {
NextRow() (row []string, end bool, err error)
}
func parseByRowIterator(iter rowIterator, config *contract.Config, opts ...parser.Option) (
docs []*schema.Document, err error) {
ps := config.ParsingStrategy
options := parser.GetCommonOptions(&parser.Options{}, opts...)
i := 0
columnsProvides := ps.IsAppend || len(ps.Columns) > 0
rev := make(map[int]*document.Column)
var (
expColumns []*document.Column
expData [][]*document.ColumnData
)
for {
row, end, err := iter.NextRow()
if err != nil {
return nil, err
}
if end {
break
}
if i == ps.HeaderLine {
if columnsProvides {
expColumns = ps.Columns
} else {
for j, col := range row {
expColumns = append(expColumns, &document.Column{
Name: col,
Type: document.TableColumnTypeUnknown,
Sequence: j,
})
}
}
for j := range expColumns {
tc := expColumns[j]
rev[tc.Sequence] = tc
}
}
if i >= ps.DataStartLine {
var rowData []*document.ColumnData
for j := range row {
colSchema, found := rev[j]
if !found { // 列裁剪
continue
}
val := row[j]
if columnsProvides {
var data *document.ColumnData
if config.ParsingStrategy.IgnoreColumnTypeErr {
data = assertValAsForce(colSchema.Type, val, colSchema.Nullable)
} else {
data, err = assertValAs(colSchema.Type, val)
if err != nil {
return nil, err
}
}
data.ColumnID = colSchema.ID
data.ColumnName = colSchema.Name
rowData = append(rowData, data)
} else {
exp := assertVal(val)
colSchema.Type = transformColumnType(colSchema.Type, exp.Type)
rowData = append(rowData, &document.ColumnData{
ColumnID: colSchema.ID,
ColumnName: colSchema.Name,
Type: document.TableColumnTypeUnknown,
ValString: &val,
})
}
}
if rowData != nil {
expData = append(expData, rowData)
}
}
i++
if ps.RowsCount != 0 && len(docs) == ps.RowsCount {
break
}
}
if !columnsProvides {
// align data type when columns are provided
for _, col := range expColumns {
if col.Type == document.TableColumnTypeUnknown {
col.Type = document.TableColumnTypeString
}
}
for _, row := range expData {
if err = alignTableSliceValue(expColumns, row); err != nil {
return nil, err
}
}
}
if len(expData) == 0 {
// return a special document with columns only if there is no data
doc := &schema.Document{
MetaData: map[string]any{
document.MetaDataKeyColumns: expColumns,
document.MetaDataKeyColumnsOnly: struct{}{},
},
}
for k, v := range options.ExtraMeta {
doc.MetaData[k] = v
}
return []*schema.Document{doc}, nil
}
for j := range expData {
contentMapping := make(map[string]string)
for _, col := range expData[j] {
contentMapping[col.ColumnName] = col.GetStringValue()
}
b, err := json.Marshal(contentMapping)
if err != nil {
return nil, err
}
doc := &schema.Document{
Content: string(b), // set for tables in text
MetaData: map[string]any{
document.MetaDataKeyColumns: expColumns,
document.MetaDataKeyColumnData: expData[j],
},
}
for k, v := range options.ExtraMeta {
doc.MetaData[k] = v
}
docs = append(docs, doc)
}
return docs, nil
}

View File

@@ -0,0 +1,97 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"encoding/json"
"fmt"
"io"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
func parseJSON(config *contract.Config) parseFn {
return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
b, err := io.ReadAll(reader)
if err != nil {
return nil, err
}
var rawSlices []map[string]string
if err = json.Unmarshal(b, &rawSlices); err != nil {
return nil, err
}
if len(rawSlices) == 0 {
return nil, fmt.Errorf("[parseJSON] json data is empty")
}
var header []string
if config.ParsingStrategy.IsAppend {
for _, col := range config.ParsingStrategy.Columns {
header = append(header, col.Name)
}
} else {
for k := range rawSlices[0] {
// init 取首个 json item 中 key 的随机顺序
header = append(header, k)
}
}
iter := &jsonIterator{
header: header,
rows: rawSlices,
i: 0,
}
return parseByRowIterator(iter, config, opts...)
}
}
type jsonIterator struct {
header []string
rows []map[string]string
i int
}
func (j *jsonIterator) NextRow() (row []string, end bool, err error) {
if j.i == 0 {
j.i++
return j.header, false, nil
}
if j.i == len(j.rows)+1 {
return nil, true, nil
}
raw := j.rows[j.i-1]
j.i++
for _, h := range j.header {
val, found := raw[h]
if !found {
row = append(row, "")
} else {
row = append(row, val)
}
}
return row, false, nil
}

View File

@@ -0,0 +1,130 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"encoding/json"
"fmt"
"io"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
func parseJSONMaps(config *contract.Config) parseFn {
return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
b, err := io.ReadAll(reader)
if err != nil {
return nil, err
}
var customContent []map[string]string
if err = json.Unmarshal(b, &customContent); err != nil {
return nil, err
}
if config.ParsingStrategy == nil {
config.ParsingStrategy = &contract.ParsingStrategy{
HeaderLine: 0,
DataStartLine: 1,
RowsCount: 0,
}
}
iter := &customContentContainer{
i: 0,
colIdx: nil,
customContent: customContent,
curColumns: config.ParsingStrategy.Columns,
}
newConfig := &contract.Config{
FileExtension: config.FileExtension,
ParsingStrategy: &contract.ParsingStrategy{
SheetID: config.ParsingStrategy.SheetID,
HeaderLine: 0,
DataStartLine: 1,
RowsCount: 0,
IsAppend: config.ParsingStrategy.IsAppend,
Columns: config.ParsingStrategy.Columns,
},
ChunkingStrategy: config.ChunkingStrategy,
}
return parseByRowIterator(iter, newConfig, opts...)
}
}
type customContentContainer struct {
i int
colIdx map[string]int
customContent []map[string]string
curColumns []*document.Column
}
func (c *customContentContainer) NextRow() (row []string, end bool, err error) {
if c.i == 0 && c.colIdx == nil {
if len(c.customContent) == 0 {
return nil, false, fmt.Errorf("[customContentContainer] data is nil")
}
headerRow := c.customContent[0]
founded := make(map[string]struct{})
colIdx := make(map[string]int, len(headerRow))
for _, col := range c.curColumns {
name := col.Name
if _, found := headerRow[name]; found {
founded[name] = struct{}{}
colIdx[name] = len(colIdx)
row = append(row, name)
}
}
for name := range headerRow {
if _, found := founded[name]; !found {
colIdx[name] = len(colIdx)
row = append(row, name)
}
}
c.colIdx = colIdx
return row, false, nil
}
if c.i >= len(c.customContent) {
return nil, true, nil
}
content := c.customContent[c.i]
c.i++
row = make([]string, len(content))
for k, v := range content {
idx, found := c.colIdx[k]
if !found {
return nil, false, fmt.Errorf("[customContentContainer] column not found, name=%s", k)
}
row[idx] = v
}
return row, false, nil
}

View File

@@ -0,0 +1,97 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"bytes"
"context"
"testing"
"github.com/cloudwego/eino/components/document/parser"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
func TestParseTableCustomContent(t *testing.T) {
ctx := context.Background()
b := []byte(`[{"col_string_indexing":"hello","col_string":"asd","col_int":"1","col_number":"1","col_bool":"true","col_time":"2006-01-02 15:04:05"},{"col_string_indexing":"bye","col_string":"","col_int":"2","col_number":"2.0","col_bool":"false","col_time":""}]`)
reader := bytes.NewReader(b)
config := &contract.Config{
FileExtension: contract.FileExtensionJsonMaps,
ParsingStrategy: &contract.ParsingStrategy{
HeaderLine: 0,
DataStartLine: 1,
RowsCount: 10,
Columns: []*document.Column{
{
ID: 0,
Name: "col_string_indexing",
Type: document.TableColumnTypeString,
Nullable: false,
Sequence: 0,
},
{
ID: 0,
Name: "col_string",
Type: document.TableColumnTypeString,
Nullable: false,
Sequence: 1,
},
{
ID: 0,
Name: "col_int",
Type: document.TableColumnTypeInteger,
Nullable: false,
Sequence: 2,
},
{
ID: 0,
Name: "col_number",
Type: document.TableColumnTypeNumber,
Nullable: true,
Sequence: 3,
},
{
ID: 0,
Name: "col_bool",
Type: document.TableColumnTypeBoolean,
Nullable: true,
Sequence: 4,
},
{
ID: 0,
Name: "col_time",
Type: document.TableColumnTypeTime,
Nullable: true,
Sequence: 5,
},
},
},
}
pfn := parseJSONMaps(config)
docs, err := pfn(ctx, reader, parser.WithExtraMeta(map[string]any{
"document_id": int64(123),
"knowledge_id": int64(456),
}))
assert.NoError(t, err)
for i, doc := range docs {
assertSheet(t, i, doc)
}
}

View File

@@ -0,0 +1,133 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"bytes"
"context"
"testing"
"github.com/cloudwego/eino/components/document/parser"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
func TestParseJSON(t *testing.T) {
b := []byte(`[
{
"department": "心血管科",
"title": "高血压患者能吃党参吗?",
"question": "我有高血压这两天女婿来的时候给我拿了些党参泡水喝,您好高血压可以吃党参吗?",
"answer": "高血压病人可以口服党参的。党参有降血脂,降血压的作用,可以彻底消除血液中的垃圾,从而对冠心病以及心血管疾病的患者都有一定的稳定预防工作作用,因此平时口服党参能远离三高的危害。另外党参除了益气养血,降低中枢神经作用,调整消化系统功能,健脾补肺的功能。感谢您的进行咨询,期望我的解释对你有所帮助。"
},
{
"department": "消化科",
"title": "哪家医院能治胃反流",
"question": "烧心打隔咳嗽低烧以有4年多",
"answer": "建议你用奥美拉唑同时,加用吗丁啉或莫沙必利或援生力维,另外还可以加用达喜片"
}
]`)
reader := bytes.NewReader(b)
config := &contract.Config{
FileExtension: contract.FileExtensionJSON,
ParsingStrategy: &contract.ParsingStrategy{
HeaderLine: 0,
DataStartLine: 1,
RowsCount: 2,
},
ChunkingStrategy: nil,
}
pfn := parseJSON(config)
docs, err := pfn(context.Background(), reader, parser.WithExtraMeta(map[string]any{
"document_id": int64(123),
"knowledge_id": int64(456),
}))
assert.NoError(t, err)
for i, doc := range docs {
assertSheet(t, i, doc)
}
}
func TestParseJSONWithSchema(t *testing.T) {
b := []byte(`[
{
"department": "心血管科",
"title": "高血压患者能吃党参吗?",
"question": "我有高血压这两天女婿来的时候给我拿了些党参泡水喝,您好高血压可以吃党参吗?",
"answer": "高血压病人可以口服党参的。党参有降血脂,降血压的作用,可以彻底消除血液中的垃圾,从而对冠心病以及心血管疾病的患者都有一定的稳定预防工作作用,因此平时口服党参能远离三高的危害。另外党参除了益气养血,降低中枢神经作用,调整消化系统功能,健脾补肺的功能。感谢您的进行咨询,期望我的解释对你有所帮助。"
},
{
"department": "消化科",
"title": "哪家医院能治胃反流",
"question": "烧心打隔咳嗽低烧以有4年多",
"answer": "建议你用奥美拉唑同时,加用吗丁啉或莫沙必利或援生力维,另外还可以加用达喜片"
}
]`)
reader := bytes.NewReader(b)
config := &contract.Config{
FileExtension: contract.FileExtensionJSON,
ParsingStrategy: &contract.ParsingStrategy{
HeaderLine: 0,
DataStartLine: 1,
RowsCount: 2,
Columns: []*document.Column{
{
ID: 101,
Name: "department",
Type: document.TableColumnTypeString,
Nullable: false,
Sequence: 0,
},
{
ID: 102,
Name: "title",
Type: document.TableColumnTypeString,
Nullable: false,
Sequence: 1,
},
{
ID: 103,
Name: "question",
Type: document.TableColumnTypeString,
Nullable: false,
Sequence: 2,
},
{
ID: 104,
Name: "answer",
Type: document.TableColumnTypeString,
Nullable: false,
Sequence: 3,
},
},
},
}
pfn := parseJSON(config)
docs, err := pfn(context.Background(), reader, parser.WithExtraMeta(map[string]any{
"document_id": int64(123),
"knowledge_id": int64(456),
}))
assert.NoError(t, err)
for i, doc := range docs {
assertSheet(t, i, doc)
}
}

View File

@@ -0,0 +1,221 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
"github.com/yuin/goldmark"
"github.com/yuin/goldmark/ast"
"github.com/yuin/goldmark/text"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
func parseMarkdown(config *contract.Config, storage storage.Storage, ocr ocr.OCR) parseFn {
return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
options := parser.GetCommonOptions(&parser.Options{}, opts...)
mdParser := goldmark.DefaultParser()
b, err := io.ReadAll(reader)
if err != nil {
return nil, err
}
node := mdParser.Parse(text.NewReader(b))
cs := config.ChunkingStrategy
ps := config.ParsingStrategy
if cs.ChunkType != contract.ChunkTypeCustom && cs.ChunkType != contract.ChunkTypeDefault {
return nil, fmt.Errorf("[parseMarkdown] chunk type not support, chunk type=%d", cs.ChunkType)
}
var (
last *schema.Document
emptySlice bool
)
addSliceContent := func(content string) {
emptySlice = false
last.Content += content
}
newSlice := func(needOverlap bool) {
last = &schema.Document{
MetaData: map[string]any{},
}
for k, v := range options.ExtraMeta {
last.MetaData[k] = v
}
if needOverlap && cs.Overlap > 0 && len(docs) > 0 {
overlap := getOverlap([]rune(docs[len(docs)-1].Content), cs.Overlap, cs.ChunkSize)
addSliceContent(string(overlap))
}
emptySlice = true
}
pushSlice := func() {
if !emptySlice && last.Content != "" {
docs = append(docs, last)
newSlice(true)
}
}
trim := func(text string) string {
if cs.TrimURLAndEmail {
text = urlRegex.ReplaceAllString(text, "")
text = emailRegex.ReplaceAllString(text, "")
}
if cs.TrimSpace {
text = strings.TrimSpace(text)
text = spaceRegex.ReplaceAllString(text, " ")
}
return text
}
downloadImage := func(ctx context.Context, url string) ([]byte, error) {
client := &http.Client{Timeout: 5 * time.Second}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to download image: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to download image, status code: %d", resp.StatusCode)
}
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read image content: %w", err)
}
return data, nil
}
walker := func(n ast.Node, entering bool) (ast.WalkStatus, error) {
if !entering {
return ast.WalkContinue, nil
}
switch n.Kind() {
case ast.KindText:
if n.HasChildren() {
break
}
textNode := n.(*ast.Text)
plainText := trim(string(textNode.Segment.Value(b)))
for _, part := range strings.Split(plainText, cs.Separator) {
runes := []rune(part)
for partLength := int64(len(runes)); partLength > 0; partLength = int64(len(runes)) {
pos := min(partLength, cs.ChunkSize-charCount(last.Content))
chunk := runes[:pos]
addSliceContent(string(chunk))
runes = runes[pos:]
if charCount(last.Content) >= cs.ChunkSize {
pushSlice()
}
}
}
case ast.KindImage:
if !ps.ExtractImage {
break
}
imageNode := n.(*ast.Image)
if ps.ExtractImage {
imageURL := string(imageNode.Destination)
if _, err = url.ParseRequestURI(imageURL); err == nil {
sp := strings.Split(imageURL, ".")
if len(sp) == 0 {
return ast.WalkStop, fmt.Errorf("failed to extract image extension, url=%s", imageURL)
}
ext := sp[len(sp)-1]
img, err := downloadImage(ctx, imageURL)
if err != nil {
return ast.WalkStop, fmt.Errorf("failed to download image: %w", err)
}
imgSrc, err := putImageObject(ctx, storage, ext, getCreatorIDFromExtraMeta(options.ExtraMeta), img)
if err != nil {
return ast.WalkStop, err
}
if !emptySlice && last.Content != "" {
pushSlice()
} else {
newSlice(false)
}
addSliceContent(fmt.Sprintf("\n%s\n", imgSrc))
if ps.ImageOCR && ocr != nil {
texts, err := ocr.FromBase64(ctx, base64.StdEncoding.EncodeToString(img))
if err != nil {
return ast.WalkStop, fmt.Errorf("failed to perform OCR on image: %w", err)
}
addSliceContent(strings.Join(texts, "\n"))
}
if charCount(last.Content) >= cs.ChunkSize {
pushSlice()
}
} else {
logs.CtxInfof(ctx, "[parseMarkdown] not a valid image url, skip, got=%s", imageURL)
}
}
}
return ast.WalkContinue, nil
}
newSlice(false)
if err = ast.Walk(node, walker); err != nil {
return nil, err
}
if !emptySlice {
pushSlice()
}
return docs, nil
}
}

View File

@@ -0,0 +1,75 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"fmt"
"os"
"testing"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
ms "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/storage"
)
func TestParseMarkdown(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
mockStorage := ms.NewMockStorage(ctrl)
mockStorage.EXPECT().PutObject(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
pfn := parseMarkdown(&contract.Config{
FileExtension: contract.FileExtensionMarkdown,
ParsingStrategy: &contract.ParsingStrategy{
ExtractImage: true,
ExtractTable: true,
ImageOCR: true,
},
ChunkingStrategy: &contract.ChunkingStrategy{
ChunkType: contract.ChunkTypeCustom,
ChunkSize: 800,
Separator: "\n",
Overlap: 10,
TrimSpace: true,
TrimURLAndEmail: true,
},
}, mockStorage, nil)
f, err := os.Open("test_data/test_markdown.md")
assert.NoError(t, err)
docs, err := pfn(ctx, f, parser.WithExtraMeta(map[string]any{
"document_id": int64(123),
"knowledge_id": int64(456),
}))
assert.NoError(t, err)
for _, doc := range docs {
assertDoc(t, doc)
}
}
func assertDoc(t *testing.T, doc *schema.Document) {
assert.NotZero(t, doc.Content)
fmt.Println(doc.Content)
assert.NotNil(t, doc.MetaData)
assert.Equal(t, int64(123), doc.MetaData["document_id"].(int64))
assert.Equal(t, int64(456), doc.MetaData["knowledge_id"].(int64))
}

View File

@@ -0,0 +1,152 @@
import io
import json
import os
import sys
import base64
from typing import Literal
import pdfplumber
from PIL import Image, ImageChops
from pdfminer.pdfcolor import (
LITERAL_DEVICE_CMYK,
)
from pdfminer.pdftypes import (
LITERALS_DCT_DECODE,
LITERALS_FLATE_DECODE,
)
def bbox_overlap(bbox1, bbox2):
x0_1, y0_1, x1_1, y1_1 = bbox1
x0_2, y0_2, x1_2, y1_2 = bbox2
x_overlap = max(0, min(x1_1, x1_2) - max(x0_1, x0_2))
y_overlap = max(0, min(y1_1, y1_2) - max(y0_1, y0_2))
overlap_area = x_overlap * y_overlap
bbox1_area = (x1_1 - x0_1) * (y1_1 - y0_1)
bbox2_area = (x1_2 - x0_2) * (y1_2 - y0_2)
if bbox1_area == 0 or bbox2_area == 0:
return 0
return overlap_area / min(bbox1_area, bbox2_area)
def is_structured_table(table):
if not table:
return False
row_count = len(table)
col_count = max(len(row) for row in table)
return row_count >= 2 and col_count >= 2
def extract_pdf_content(pdf_data: bytes, extract_images, extract_tables: bool, filter_pages: []):
with pdfplumber.open(io.BytesIO(pdf_data)) as pdf:
content = []
for page_num, page in enumerate(pdf.pages):
if filter_pages is not None and page_num + 1 in filter_pages:
print(f"Skip page {page_num + 1}...")
continue
print(f"Processing page {page_num + 1}...")
text = page.extract_text(x_tolerance=2)
content.append({
'type': 'text',
'content': text,
'page': page_num + 1,
'bbox': page.bbox
})
if extract_images:
images = page.images
for img_index, img in enumerate(images):
try:
filters = img['stream'].get_filters()
data = img['stream'].get_data()
buffered = io.BytesIO()
if filters[-1][0] in LITERALS_DCT_DECODE:
if LITERAL_DEVICE_CMYK in img['colorspace']:
i = Image.open(io.BytesIO(data))
i = ImageChops.invert(i)
i = i.convert("RGB")
i.save(buffered, format="PNG")
else:
buffered.write(data)
elif len(filters) == 1 and filters[0][0] in LITERALS_FLATE_DECODE:
width, height = img['srcsize']
channels = len(img['stream'].get_data()) / width / height / (img['bits'] / 8)
mode: Literal["1", "L", "RGB", "CMYK"]
if img['bits'] == 1:
mode = "1"
elif img['bits'] == 8 and channels == 1:
mode = "L"
elif img['bits'] == 8 and channels == 3:
mode = "RGB"
elif img['bits'] == 8 and channels == 4:
mode = "CMYK"
i = Image.frombytes(mode, img['srcsize'], data, "raw")
i.save(buffered, format="PNG")
else:
buffered.write(data)
content.append({
'type': 'image',
'content': base64.b64encode(buffered.getvalue()).decode('utf-8'),
'page': page_num + 1,
'bbox': (img['x0'], img['top'], img['x1'], img['bottom'])
})
except Exception as err:
print(f"Skipping an unsupported image on page {page_num + 1}, error message: {err}")
if extract_tables:
tables = page.extract_tables()
for table in tables:
content.append({
'type': 'table',
'table': table,
'page': page_num + 1,
'bbox': page.bbox
})
content.sort(key=lambda x: (x['page'], x['bbox'][1], x['bbox'][0]))
filtered_content = []
for item in content:
if item['type'] == 'table':
if is_structured_table(item['table']):
filtered_content.append(item)
continue
overlap_found = False
for existing_item in filtered_content:
if existing_item['type'] == 'text' and bbox_overlap(item['bbox'], existing_item['bbox']) > 0.8:
overlap_found = True
break
if overlap_found:
continue
filtered_content.append(item)
return filtered_content
if __name__ == "__main__":
w = os.fdopen(3, "wb", )
r = os.fdopen(4, "rb", )
pdf_data = sys.stdin.buffer.read()
print(f"Read {len(pdf_data)} bytes of PDF data")
try:
req = json.load(r)
ei, et, fp = req['extract_images'], req['extract_tables'], req['filter_pages']
extracted_content = extract_pdf_content(pdf_data, ei, et, fp)
print(f"Extracted {len(extracted_content)} items")
result = json.dumps({"content": extracted_content}, ensure_ascii=False)
w.write(str.encode(result))
w.flush()
w.close()
print("Pdf parse done")
except Exception as e:
print("Pdf parse error", e)
w.write(str.encode(json.dumps({"error": str(e)})))
w.flush()
w.close()

View File

@@ -0,0 +1,49 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"fmt"
"io"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
func parseText(config *contract.Config) parseFn {
return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
content, err := io.ReadAll(reader)
if err != nil {
return nil, err
}
switch config.ChunkingStrategy.ChunkType {
case contract.ChunkTypeCustom, contract.ChunkTypeDefault:
docs, err = chunkCustom(ctx, string(content), config, opts...)
default:
return nil, fmt.Errorf("[parseText] chunk type not support, type=%d", config.ChunkingStrategy.ChunkType)
}
if err != nil {
return nil, err
}
return docs, nil
}
}

View File

@@ -0,0 +1,78 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"io"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
"github.com/xuri/excelize/v2"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
func parseXLSX(config *contract.Config) parseFn {
return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
f, err := excelize.OpenReader(reader)
if err != nil {
return nil, err
}
sheetID := 0
if config.ParsingStrategy.SheetID != nil {
sheetID = *config.ParsingStrategy.SheetID
}
rows, err := f.Rows(f.GetSheetName(sheetID))
if err != nil {
return nil, err
}
iter := &xlsxIterator{rows, 0}
return parseByRowIterator(iter, config, opts...)
}
}
type xlsxIterator struct {
rows *excelize.Rows
firstRowSize int
}
func (x *xlsxIterator) NextRow() (row []string, end bool, err error) {
end = !x.rows.Next()
if end {
return nil, end, nil
}
row, err = x.rows.Columns()
if err != nil {
return nil, false, err
}
if x.firstRowSize == 0 {
x.firstRowSize = len(row)
} else if x.firstRowSize > len(row) {
row = append(row, make([]string, x.firstRowSize-len(row))...)
} else if x.firstRowSize < len(row) {
row = row[:x.firstRowSize]
}
return row, false, nil
}

View File

@@ -0,0 +1,171 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"bytes"
"context"
"os"
"testing"
"github.com/cloudwego/eino/components/document/parser"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
func TestParseXLSX(t *testing.T) {
ctx := context.Background()
b, err := os.ReadFile("./test_data/test_xlsx.xlsx")
assert.NoError(t, err)
reader := bytes.NewReader(b)
config := &contract.Config{
FileExtension: contract.FileExtensionXLSX,
ParsingStrategy: &contract.ParsingStrategy{
HeaderLine: 0,
DataStartLine: 1,
RowsCount: 10,
Columns: []*document.Column{
{
ID: 0,
Name: "col_string_indexing",
Type: document.TableColumnTypeString,
Nullable: false,
Sequence: 0,
},
{
ID: 0,
Name: "col_string",
Type: document.TableColumnTypeString,
Nullable: true,
Sequence: 1,
},
{
ID: 0,
Name: "col_int",
Type: document.TableColumnTypeInteger,
Nullable: false,
Sequence: 2,
},
{
ID: 0,
Name: "col_number",
Type: document.TableColumnTypeNumber,
Nullable: true,
Sequence: 3,
},
{
ID: 0,
Name: "col_bool",
Type: document.TableColumnTypeBoolean,
Nullable: true,
Sequence: 4,
},
{
ID: 0,
Name: "col_time",
Type: document.TableColumnTypeTime,
Nullable: true,
Sequence: 5,
},
},
},
ChunkingStrategy: nil,
}
pfn := parseXLSX(config)
docs, err := pfn(ctx, reader, parser.WithExtraMeta(map[string]any{
"document_id": int64(123),
"knowledge_id": int64(456),
}))
assert.NoError(t, err)
for i, doc := range docs {
assertSheet(t, i, doc)
}
}
func TestParseXLSXConvertColumnType(t *testing.T) {
ctx := context.Background()
b, err := os.ReadFile("./test_data/test_xlsx.xlsx")
assert.NoError(t, err)
reader := bytes.NewReader(b)
config := &contract.Config{
FileExtension: contract.FileExtensionXLSX,
ParsingStrategy: &contract.ParsingStrategy{
HeaderLine: 0,
DataStartLine: 1,
RowsCount: 10,
IgnoreColumnTypeErr: true,
Columns: []*document.Column{
{
ID: 0,
Name: "col_string_indexing",
Type: document.TableColumnTypeString,
Nullable: false,
Sequence: 0,
},
{
ID: 0,
Name: "col_string",
Type: document.TableColumnTypeInteger, // string -> int: null
Nullable: true,
Sequence: 1,
},
{
ID: 0,
Name: "col_int",
Type: document.TableColumnTypeString, // int -> string: strconv
Nullable: false,
Sequence: 2,
},
{
ID: 0,
Name: "col_number",
Type: document.TableColumnTypeString, // float -> string: strconv
Nullable: true,
Sequence: 3,
},
//{
// ID: 0,
// Name: "col_bool",
// Type: document.TableColumnTypeBoolean, // trim
// Nullable: true,
// Sequence: 4,
//},
//{
// ID: 0,
// Name: "col_time",
// Type: document.TableColumnTypeTime, // trim
// Nullable: true,
// Sequence: 5,
//},
},
},
ChunkingStrategy: nil,
}
pfn := parseXLSX(config)
docs, err := pfn(ctx, reader, parser.WithExtraMeta(map[string]any{
"document_id": int64(123),
"knowledge_id": int64(456),
}))
assert.NoError(t, err)
for i, doc := range docs {
assertSheet(t, i, doc)
}
}

View File

@@ -0,0 +1,35 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"io"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
)
type p struct {
parseFn
}
func (p p) Parse(ctx context.Context, reader io.Reader, opts ...parser.Option) ([]*schema.Document, error) {
return p.parseFn(ctx, reader, opts...)
}
type parseFn func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error)

View File

@@ -0,0 +1,270 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"os"
"os/exec"
"strings"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
)
const (
contentTypeText = "text"
contentTypeImage = "image"
contentTypeTable = "table"
)
type pyParseRequest struct {
ExtractImages bool `json:"extract_images"`
ExtractTables bool `json:"extract_tables"`
FilterPages []int `json:"filter_pages"`
}
type pyParseResult struct {
Error string `json:"error"`
Content []*pyParseContent `json:"content"`
}
type pyParseContent struct {
Type string `json:"type"`
Content string `json:"content"`
Table [][]string `json:"table"`
Page int `json:"page"`
}
type pyPDFTableIterator struct {
i int
rows [][]string
}
func (p *pyPDFTableIterator) NextRow() (row []string, end bool, err error) {
if p.i >= len(p.rows) {
return nil, true, nil
}
row = p.rows[p.i]
p.i++
return row, false, nil
}
func parseByPython(config *contract.Config, storage storage.Storage, ocr ocr.OCR, pyPath, scriptPath string) parseFn {
return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
pr, pw, err := os.Pipe()
if err != nil {
return nil, fmt.Errorf("[parseByPython] create rpipe failed, %w", err)
}
r, w, err := os.Pipe()
if err != nil {
return nil, fmt.Errorf("[parseByPython] create pipe failed: %w", err)
}
options := parser.GetCommonOptions(&parser.Options{ExtraMeta: map[string]any{}}, opts...)
reqb, err := json.Marshal(pyParseRequest{
ExtractImages: config.ParsingStrategy.ExtractImage,
ExtractTables: config.ParsingStrategy.ExtractTable,
FilterPages: config.ParsingStrategy.FilterPages,
})
if err != nil {
return nil, fmt.Errorf("[parseByPython] create parse request failed, %w", err)
}
if _, err = pw.Write(reqb); err != nil {
return nil, fmt.Errorf("[parseByPython] write parse request bytes failed, %w", err)
}
if err = pw.Close(); err != nil {
return nil, fmt.Errorf("[parseByPython] close write request pipe failed, %w", err)
}
cmd := exec.Command(pyPath, scriptPath)
cmd.Stdin = reader
cmd.Stdout = os.Stdout
cmd.ExtraFiles = []*os.File{w, pr}
if err = cmd.Start(); err != nil {
return nil, fmt.Errorf("[parseByPython] failed to start Python script: %w", err)
}
if err = w.Close(); err != nil {
return nil, fmt.Errorf("[parseByPython] failed to close write pipe: %w", err)
}
result := &pyParseResult{}
if err = json.NewDecoder(r).Decode(result); err != nil {
return nil, fmt.Errorf("[parseByPython] failed to decode result: %w", err)
}
if err = cmd.Wait(); err != nil {
return nil, fmt.Errorf("[parseByPython] cmd wait err: %w", err)
}
if result.Error != "" {
return nil, fmt.Errorf("[parseByPython] python execution failed: %s", result.Error)
}
for i, item := range result.Content {
switch item.Type {
case contentTypeText:
partDocs, err := chunkCustom(ctx, item.Content, config, opts...)
if err != nil {
return nil, fmt.Errorf("[parseByPython] chunk text failed, %w", err)
}
docs = append(docs, partDocs...)
case contentTypeImage:
if !config.ParsingStrategy.ExtractImage {
continue
}
image, err := base64.StdEncoding.DecodeString(item.Content)
if err != nil {
return nil, fmt.Errorf("[parseByPython] decode image failed, %w", err)
}
imgSrc, err := putImageObject(ctx, storage, "png", getCreatorIDFromExtraMeta(options.ExtraMeta), image)
if err != nil {
return nil, err
}
label := fmt.Sprintf("\n%s", imgSrc)
if config.ParsingStrategy.ImageOCR && ocr != nil {
texts, err := ocr.FromBase64(ctx, item.Content)
if err != nil {
return nil, fmt.Errorf("[parseByPython] FromBase64 failed, %w", err)
}
label += strings.Join(texts, "\n")
}
if i == len(result.Content)-1 || result.Content[i+1].Type != "text" {
doc := &schema.Document{
Content: label,
MetaData: map[string]any{},
}
for k, v := range options.ExtraMeta {
doc.MetaData[k] = v
}
docs = append(docs, doc)
} else {
// TODO: 这里有点问题img label 可能被较短的 chunk size 截断
result.Content[i+1].Content = label + result.Content[i+1].Content
}
case contentTypeTable:
if !config.ParsingStrategy.ExtractTable {
continue
}
iterator := &pyPDFTableIterator{i: 0, rows: item.Table}
rawTableDocs, err := parseByRowIterator(iterator, &contract.Config{
FileExtension: contract.FileExtensionCSV,
ParsingStrategy: &contract.ParsingStrategy{
HeaderLine: 0,
DataStartLine: 1,
RowsCount: 0,
},
ChunkingStrategy: config.ChunkingStrategy,
}, opts...)
if err != nil {
return nil, fmt.Errorf("[parseByPython] parse table failed, %w", err)
}
fmtTableDocs, err := formatTablesInDocument(rawTableDocs)
if err != nil {
return nil, fmt.Errorf("[parseByPython] format table failed, %w", err)
}
docs = append(docs, fmtTableDocs...)
default:
return nil, fmt.Errorf("[parseByPython] invalid content type: %s", item.Type)
}
}
return docs, nil
}
}
func formatTablesInDocument(input []*schema.Document) (output []*schema.Document, err error) {
const (
maxSize = 65535
tableStart, tableEnd = "<table>", "</table>"
)
var (
buffer strings.Builder
firstDoc *schema.Document
)
endSize := len(tableEnd)
buffer.WriteString(tableStart)
push := func() {
newDoc := &schema.Document{
Content: buffer.String() + tableEnd,
MetaData: map[string]any{},
}
for k, v := range firstDoc.MetaData {
if k == document.MetaDataKeyColumnData {
continue
}
newDoc.MetaData[k] = v
}
output = append(output, newDoc)
buffer.Reset()
buffer.WriteString(tableStart)
}
write := func(contents []string) {
row := fmt.Sprintf("<tr><td>%s</td></tr>", strings.Join(contents, "</td><td>"))
buffer.WriteString(row)
if buffer.Len()+endSize >= maxSize {
push()
}
}
for i := range input {
doc := input[i]
if i == 0 {
firstDoc = doc
cols, err := document.GetDocumentColumns(doc)
if err != nil {
return nil, fmt.Errorf("[formatTablesInDocument] invalid table columns, %w", err)
}
values := make([]string, 0, len(cols))
for _, col := range cols {
values = append(values, col.Name)
}
write(values)
}
data, err := document.GetDocumentColumnData(doc)
if err != nil {
return nil, fmt.Errorf("[formatTablesInDocument] invalid table data, %w", err)
}
values := make([]string, 0, len(data))
for _, col := range data {
values = append(values, col.GetNullableStringValue())
}
write(values)
}
if buffer.String() != tableStart {
push()
}
return
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 65 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 38 KiB

View File

@@ -0,0 +1,3 @@
col_string_indexing,col_string,col_int,col_number,col_bool,col_time
hello,asd,1,1.0,TRUE,2006-01-02 15:04:02
bye,,2,2.0,TRUE,
1 col_string_indexing col_string col_int col_number col_bool col_time
2 hello asd 1 1.0 TRUE 2006-01-02 15:04:02
3 bye 2 2.0 TRUE

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1 @@
col_string_indexing,col_string,col_int,col_number,col_bool,col_time
1 col_string_indexing col_string col_int col_number col_bool col_time

View File

@@ -0,0 +1,272 @@
# 1. 欢迎使用 Cmd Markdown 编辑阅读器
<!-- TOC -->
- [1. 欢迎使用 Cmd Markdown 编辑阅读器](#1-欢迎使用-cmd-markdown-编辑阅读器)
- [1.1. markdown扩展需求](#11-markdown扩展需求)
- [1.1.1. 一、各种流程图](#111-一各种流程图)
- [1.1.2. [Windows/Mac/Linux 全平台客户端](https://www.zybuluo.com/cmd/)](#112-windowsmaclinux-全平台客户端httpswwwzybuluocomcmd)
- [1.2. 什么是 Markdown](#12-什么是-markdown)
- [1.2.1. 制作一份待办事宜 [Todo 列表](https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown#13-待办事宜-todo-列表)](#121-制作一份待办事宜-todo-列表httpswwwzybuluocommdeditorurlhttpswwwzybuluocomstaticeditormd-helpmarkdown13-待办事宜-todo-列表)
- [1.2.2. 书写一个质能守恒公式[^LaTeX]](#122-书写一个质能守恒公式^latex)
- [1.2.3. 高亮一段代码[^code]](#123-高亮一段代码^code)
- [1.2.4. 高效绘制 [流程图](https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown#7-流程图)](#124-高效绘制-流程图httpswwwzybuluocommdeditorurlhttpswwwzybuluocomstaticeditormd-helpmarkdown7-流程图)
- [1.2.5. 高效绘制 [序列图](https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown#8-序列图)](#125-高效绘制-序列图httpswwwzybuluocommdeditorurlhttpswwwzybuluocomstaticeditormd-helpmarkdown8-序列图)
- [1.2.6. 绘制表格](#126-绘制表格)
- [1.2.7. 更详细语法说明](#127-更详细语法说明)
- [1.3. 什么是 Cmd Markdown](#13-什么是-cmd-markdown)
- [1.3.1. 实时同步预览](#131-实时同步预览)
- [1.3.2. 编辑工具栏](#132-编辑工具栏)
- [1.3.3. 编辑模式](#133-编辑模式)
- [1.3.4. 实时的云端文稿](#134-实时的云端文稿)
- [1.3.5. 离线模式](#135-离线模式)
- [1.3.6. 管理工具栏](#136-管理工具栏)
- [1.3.7. 阅读工具栏](#137-阅读工具栏)
- [1.3.8. 阅读模式](#138-阅读模式)
- [1.3.9. 标签、分类和搜索](#139-标签分类和搜索)
- [1.3.10. 文稿发布和分享](#1310-文稿发布和分享)
<!-- /TOC -->
[ ] dddd
[x] xxxx
第一行
第二行
------
> 一个快速笔记工具,可生成网页快速分享。
## 1.1. markdown扩展需求
1. 目录
2. 表情
3. 粘贴截图
4. 流程图、时序图
5. 数学公式
6. 标签
7. 简单动画
### 1.1.1. 一、各种流程图
1. 时序图
```seq
Alice->Bob: Hello Bob, how are you?
Note right of Bob: Bob thinks
Bob-->Alice: I am good thanks!
```
2. 流程图
```flow
st=>start: Start
op=>operation: Your Operation
cond=>condition: Yes or No?
e=>end
st->op->cond
cond(yes)->e
cond(no)->op
```
3. 甘特图
```gantt
title 项目开发流程
section 项目确定
需求分析 :a1, 2016-06-22, 3d
可行性报告 :after a1, 5d
概念验证 : 5d
section 项目实施
概要设计 :2016-07-05, 5d
详细设计 :2016-07-08, 10d
编码 :2016-07-15, 10d
测试 :2016-07-22, 5d
section 发布验收
发布: 2d
验收: 3d
```
4. Mermaid 流程图
```graphLR
A[Hard edge] -->|Link text| B(Round edge)
B --> C{Decision}
C -->|One| D[Result one]
C -->|Two| E[Result two]
```
5. Mermaid 序列图
```sequence
Alice->John: Hello John, how are you?
loop every minute
John-->Alice: Great!
end
```
我们理解您需要更便捷更高效的工具记录思想,整理笔记、知识,并将其中承载的价值传播给他人,**Cmd Markdown** 是我们给出的答案 —— 我们为记录思想和分享知识提供更专业的工具。 您可以使用 Cmd Markdown
> * 整理知识,学习笔记
> * 发布日记,杂文,所见所想
> * 撰写发布技术文稿(代码支持)
> * 撰写发布学术论文LaTeX 公式支持)
![cmd-markdown-logo](logo.png)
除了您现在看到的这个 Cmd Markdown 在线版本,您还可以前往以下网址下载:
### 1.1.2. [Windows/Mac/Linux 全平台客户端](https://www.zybuluo.com/cmd/)
> 请保留此份 Cmd Markdown 的欢迎稿兼使用说明,如需撰写新稿件,点击顶部工具栏右侧的 <i class="icon-file"></i> **新文稿** 或者使用快捷键 `Ctrl+Alt+N`。
------
## 1.2. 什么是 Markdown
Markdown 是一种方便记忆、书写的纯文本标记语言,用户可以使用这些标记符号以最小的输入代价生成极富表现力的文档:譬如您正在阅读的这份文档。它使用简单的符号标记不同的标题,分割不同的段落,**粗体** 或者 *斜体* 某些文字,更棒的是,它还可以
### 1.2.1. 制作一份待办事宜 [Todo 列表](https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown#13-待办事宜-todo-列表)
- [ ] 支持以 PDF 格式导出文稿
- [ ] 改进 Cmd 渲染算法,使用局部渲染技术提高渲染效率
- [x] 新增 Todo 列表功能
- [x] 修复 LaTex 公式渲染问题
- [x] 新增 LaTex 公式编号功能
### 1.2.2. 书写一个质能守恒公式[^LaTeX]
$$E=mc^2$$
### 1.2.3. 高亮一段代码[^code]
```python
@requires_authorization
class SomeClass:
pass
if __name__ == '__main__':
# A comment
print 'hello world'
```
### 1.2.4. 高效绘制 [流程图](https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown#7-流程图)
```flow
st=>start: Start
op=>operation: Your Operation
cond=>condition: Yes or No?
e=>end
st->op->cond
cond(yes)->e
cond(no)->op
```
### 1.2.5. 高效绘制 [序列图](https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown#8-序列图)
```seq
Alice->Bob: Hello Bob, how are you?
Note right of Bob: Bob thinks
Bob-->Alice: I am good thanks!
```
### 1.2.6. 绘制表格
| 项目 | 价格 | 数量 |
| -------- | -----: | :----: |
| 计算机 | \$1600 | 5 |
| 手机 | \$12 | 12 |
| 管线 | \$1 | 234 |
### 1.2.7. 更详细语法说明
想要查看更详细的语法说明,可以参考我们准备的 [Cmd Markdown 简明语法手册][1],进阶用户可以参考 [Cmd Markdown 高阶语法手册][2] 了解更多高级功能。
总而言之,不同于其它 *所见即所得* 的编辑器:你只需使用键盘专注于书写文本内容,就可以生成印刷级的排版格式,省却在键盘和工具栏之间来回切换,调整内容和格式的麻烦。**Markdown 在流畅的书写和印刷级的阅读体验之间找到了平衡。** 目前它已经成为世界上最大的技术分享网站 GitHub 和 技术问答网站 StackOverFlow 的御用书写格式。
---
## 1.3. 什么是 Cmd Markdown
您可以使用很多工具书写 Markdown但是 Cmd Markdown 是这个星球上我们已知的、最好的 Markdown 工具——没有之一 :)因为深信文字的力量,所以我们和你一样,对流畅书写,分享思想和知识,以及阅读体验有极致的追求,我们把对于这些诉求的回应整合在 Cmd Markdown并且一次两次三次乃至无数次地提升这个工具的体验最终将它演化成一个 **编辑/发布/阅读** Markdown 的在线平台——您可以在任何地方,任何系统/设备上管理这里的文字。
### 1.3.1. 实时同步预览
我们将 Cmd Markdown 的主界面一分为二,左边为**编辑区**,右边为**预览区**,在编辑区的操作会实时地渲染到预览区方便查看最终的版面效果,并且如果你在其中一个区拖动滚动条,我们有一个巧妙的算法把另一个区的滚动条同步到等价的位置,超酷!
### 1.3.2. 编辑工具栏
也许您还是一个 Markdown 语法的新手,在您完全熟悉它之前,我们在 **编辑区** 的顶部放置了一个如下图所示的工具栏,您可以使用鼠标在工具栏上调整格式,不过我们仍旧鼓励你使用键盘标记格式,提高书写的流畅度。
![tool-editor](toolbar-editor.png)
### 1.3.3. 编辑模式
完全心无旁骛的方式编辑文字:点击 **编辑工具栏** 最右测的拉伸按钮或者按下 `Ctrl + M`,将 Cmd Markdown 切换到独立的编辑模式,这是一个极度简洁的写作环境,所有可能会引起分心的元素都已经被挪除,超清爽!
### 1.3.4. 实时的云端文稿
为了保障数据安全Cmd Markdown 会将您每一次击键的内容保存至云端,同时在 **编辑工具栏** 的最右侧提示 `已保存` 的字样。无需担心浏览器崩溃,机器掉电或者地震,海啸——在编辑的过程中随时关闭浏览器或者机器,下一次回到 Cmd Markdown 的时候继续写作。
### 1.3.5. 离线模式
在网络环境不稳定的情况下记录文字一样很安全在您写作的时候如果电脑突然失去网络连接Cmd Markdown 会智能切换至离线模式,将您后续键入的文字保存在本地,直到网络恢复再将他们传送至云端,即使在网络恢复前关闭浏览器或者电脑,一样没有问题,等到下次开启 Cmd Markdown 的时候,她会提醒您将离线保存的文字传送至云端。简而言之,我们尽最大的努力保障您文字的安全。
### 1.3.6. 管理工具栏
为了便于管理您的文稿,在 **预览区** 的顶部放置了如下所示的 **管理工具栏**
通过管理工具栏可以:
<i class="icon-share"></i> 发布:将当前的文稿生成固定链接,在网络上发布,分享
<i class="icon-file"></i> 新建:开始撰写一篇新的文稿
<i class="icon-trash"></i> 删除:删除当前的文稿
<i class="icon-cloud"></i> 导出:将当前的文稿转化为 Markdown 文本或者 Html 格式,并导出到本地
<i class="icon-reorder"></i> 列表:所有新增和过往的文稿都可以在这里查看、操作
<i class="icon-pencil"></i> 模式:切换 普通/Vim/Emacs 编辑模式
### 1.3.7. 阅读工具栏
通过 **预览区** 右上角的 **阅读工具栏**,可以查看当前文稿的目录并增强阅读体验。
工具栏上的五个图标依次为:
<i class="icon-list"></i> 目录:快速导航当前文稿的目录结构以跳转到感兴趣的段落
<i class="icon-chevron-sign-left"></i> 视图:互换左边编辑区和右边预览区的位置
<i class="icon-adjust"></i> 主题:内置了黑白两种模式的主题,试试 **黑色主题**,超炫!
<i class="icon-desktop"></i> 阅读:心无旁骛的阅读模式提供超一流的阅读体验
<i class="icon-fullscreen"></i> 全屏:简洁,简洁,再简洁,一个完全沉浸式的写作和阅读环境
### 1.3.8. 阅读模式
**阅读工具栏** 点击 <i class="icon-desktop"></i> 或者按下 `Ctrl+Alt+M` 随即进入独立的阅读模式界面,我们在版面渲染上的每一个细节:字体,字号,行间距,前背景色都倾注了大量的时间,努力提升阅读的体验和品质。
### 1.3.9. 标签、分类和搜索
在编辑区任意行首位置输入以下格式的文字可以标签当前文档:
标签: 未分类
标签以后的文稿在【文件列表】Ctrl+Alt+F里会按照标签分类用户可以同时使用键盘或者鼠标浏览查看或者在【文件列表】的搜索文本框内搜索标题关键字过滤文稿如下图所示
![file-list](file-list.png)
### 1.3.10. 文稿发布和分享
在您使用 Cmd Markdown 记录,创作,整理,阅读文稿的同时,我们不仅希望它是一个有力的工具,更希望您的思想和知识通过这个平台,连同优质的阅读体验,将他们分享给有相同志趣的人,进而鼓励更多的人来到这里记录分享他们的思想和知识,尝试点击 <i class="icon-share"></i> (Ctrl+Alt+P) 发布这份文档给好友吧!
------
再一次感谢您花费时间阅读这份欢迎稿,点击 <i class="icon-file"></i> (Ctrl+Alt+N) 开始撰写新的文稿吧!祝您在这里记录、阅读、分享愉快!
作者 [@ghosert][3]
2015 年 06月 15日
[^LaTeX]: 支持 **LaTeX** 编辑显示支持,例如:$\sum_{i=1}^n a_i=0$ 访问 [MathJax][4] 参考更多使用方法。
[^code]: 代码高亮功能支持包括 Java, Python, JavaScript 在内的,**四十一**种主流编程语言。
[1]: https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown
[2]: https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown#cmd-markdown-高阶语法手册
[3]: http://weibo.com/ghosert
[4]: http://meta.math.stackexchange.com/questions/5020/mathjax-basic-tutorial-and-quick-reference

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.6 KiB

View File

@@ -0,0 +1,74 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"crypto/sha256"
"encoding/base64"
"fmt"
"math/rand"
"path"
"strings"
"time"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
)
const baseWord = "1Aa2Bb3Cc4Dd5Ee6Ff7Gg8Hh9Ii0JjKkLlMmNnOoPpQqRrSsTtUuVvWwXxYyZz"
const knowledgePrefix = "BIZ_KNOWLEDGE"
const imgSrcFormat = `<img src="" data-tos-key="%s">`
func createSecret(uid int64, fileType string) string {
num := 10
input := fmt.Sprintf("upload_%d_Ma*9)fhi_%d_gou_%s_rand_%d", uid, time.Now().Unix(), fileType, rand.Intn(100000))
// 做md5取前20个,// mapIntToBase62 把数字映射到 Base62
hash := sha256.Sum256([]byte(fmt.Sprintf("%s", input)))
hashString := base64.StdEncoding.EncodeToString(hash[:])
if len(hashString) > num {
hashString = hashString[:num]
}
result := ""
for _, char := range hashString {
index := int(char) % 62
result += string(baseWord[index])
}
return result
}
func getExtension(uri string) string {
if uri == "" {
return ""
}
fileExtension := path.Base(uri)
ext := path.Ext(fileExtension)
if ext != "" {
return strings.TrimPrefix(ext, ".")
}
return ""
}
func getCreatorIDFromExtraMeta(extraMeta map[string]any) int64 {
if extraMeta == nil {
return 0
}
if uid, ok := extraMeta[document.MetaDataKeyCreatorID]; ok {
if uidInt, ok := uid.(int64); ok {
return uidInt
}
}
return 0
}

View File

@@ -0,0 +1,151 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package progressbar
import (
"context"
"errors"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/progressbar"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
type ProgressBarImpl struct {
CacheCli cache.Cmdable
PrimaryKeyID int64
Total int64
ErrMsg string
}
const (
ttl = time.Hour * 2
ProgressBarStartTimeRedisKey = "RedisBiz.Knowledge_ProgressBar_StartTime_%d"
ProgressBarErrMsgRedisKey = "RedisBiz.Knowledge_ProgressBar_ErrMsg_%d"
ProgressBarTotalNumRedisKey = "RedisBiz.Knowledge_ProgressBar_TotalNum_%d"
ProgressBarProcessedNumRedisKey = "RedisBiz.Knowledge_ProgressBar_ProcessedNum_%d"
DefaultProcessTime = 300
ProcessDone = 100
ProcessInit = 0
)
func NewProgressBar(ctx context.Context, pkID int64, total int64, CacheCli cache.Cmdable, needInit bool) progressbar.ProgressBar {
if needInit {
CacheCli.Set(ctx, fmt.Sprintf(ProgressBarTotalNumRedisKey, pkID), total, ttl)
CacheCli.Set(ctx, fmt.Sprintf(ProgressBarProcessedNumRedisKey, pkID), 0, ttl)
CacheCli.Set(ctx, fmt.Sprintf(ProgressBarErrMsgRedisKey, pkID), "", ttl)
CacheCli.Set(ctx, fmt.Sprintf(ProgressBarStartTimeRedisKey, pkID), time.Now().Unix(), ttl)
}
return &ProgressBarImpl{
PrimaryKeyID: pkID,
Total: total,
CacheCli: CacheCli,
}
}
func (p *ProgressBarImpl) AddN(n int) error {
if p.ErrMsg != "" {
return errors.New(p.ErrMsg)
}
_, err := p.CacheCli.IncrBy(context.Background(), fmt.Sprintf(ProgressBarProcessedNumRedisKey, p.PrimaryKeyID), int64(n)).Result()
if err != nil {
return err
}
return nil
}
func (p *ProgressBarImpl) ReportError(err error) error {
p.ErrMsg = err.Error()
_, err = p.CacheCli.Set(context.Background(), fmt.Sprintf(ProgressBarErrMsgRedisKey, p.PrimaryKeyID), err.Error(), ttl).Result()
if err != nil {
return err
}
return nil
}
func (p *ProgressBarImpl) GetProgress(ctx context.Context) (percent int, remainSec int, errMsg string) {
var (
totalNum *int64
processedNum *int64
startTime *int64
err error
)
errMsg, err = p.CacheCli.Get(ctx, fmt.Sprintf(ProgressBarErrMsgRedisKey, p.PrimaryKeyID)).Result()
if err == redis.Nil {
errMsg = ""
} else if err != nil {
return ProcessDone, 0, err.Error()
}
if len(errMsg) != 0 {
return ProcessDone, 0, errMsg
}
totalNumStr, err := p.CacheCli.Get(ctx, fmt.Sprintf(ProgressBarTotalNumRedisKey, p.PrimaryKeyID)).Result()
if err == redis.Nil || len(totalNumStr) == 0 {
totalNum = ptr.Of(int64(0))
} else if err != nil {
return ProcessDone, 0, err.Error()
} else {
num, err := conv.StrToInt64(totalNumStr)
if err != nil {
totalNum = ptr.Of(int64(0))
} else {
totalNum = ptr.Of(num)
}
}
processedNumStr, err := p.CacheCli.Get(ctx, fmt.Sprintf(ProgressBarProcessedNumRedisKey, p.PrimaryKeyID)).Result()
if err == redis.Nil || len(processedNumStr) == 0 {
processedNum = ptr.Of(int64(0))
} else if err != nil {
return ProcessDone, 0, err.Error()
} else {
num, err := conv.StrToInt64(processedNumStr)
if err != nil {
processedNum = ptr.Of(int64(0))
} else {
processedNum = ptr.Of(num)
}
}
if ptr.From(totalNum) == 0 {
return ProcessInit, DefaultProcessTime, ""
}
startTimeStr, err := p.CacheCli.Get(ctx, fmt.Sprintf(ProgressBarStartTimeRedisKey, p.PrimaryKeyID)).Result()
if err == redis.Nil || len(startTimeStr) == 0 {
startTime = ptr.Of(int64(0))
} else if err != nil {
return ProcessDone, 0, err.Error()
} else {
num, err := conv.StrToInt64(startTimeStr)
if err != nil {
startTime = ptr.Of(int64(0))
} else {
startTime = ptr.Of(num)
}
}
percent = int(float64(ptr.From(processedNum)) / float64(ptr.From(totalNum)) * 100)
if ptr.From(startTime) == 0 {
remainSec = DefaultProcessTime
} else {
usedSec := time.Now().Unix() - ptr.From(startTime)
remainSec = int(float64(ptr.From(totalNum)-ptr.From(processedNum)) / float64(ptr.From(processedNum)) * float64(usedSec))
}
return
}

View File

@@ -0,0 +1,70 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package rrf
import (
"context"
"fmt"
"sort"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/rerank"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
func NewRRFReranker(k int64) rerank.Reranker {
if k == 0 {
k = 60
}
return &rrfReranker{k}
}
type rrfReranker struct {
k int64
}
func (r *rrfReranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Response, error) {
if req == nil || req.Data == nil || len(req.Data) == 0 {
return nil, fmt.Errorf("invalid request: no data provided")
}
id2Score := make(map[string]float64)
id2Data := make(map[string]*rerank.Data)
for _, resultList := range req.Data {
for rank := range resultList {
result := resultList[rank]
if result != nil && result.Document != nil {
score := 1.0 / (float64(rank) + float64(r.k))
if score > id2Score[result.Document.ID] {
id2Score[result.Document.ID] = score
id2Data[result.Document.ID] = result
}
}
}
}
var sorted []*rerank.Data
for _, data := range id2Data {
sorted = append(sorted, data)
}
sort.Slice(sorted, func(i, j int) bool {
return id2Score[sorted[i].Document.ID] > id2Score[sorted[j].Document.ID]
})
topN := int64(len(sorted))
if req.TopN != nil && ptr.From(req.TopN) != 0 && ptr.From(req.TopN) < topN {
topN = ptr.From(req.TopN)
}
return &rerank.Response{SortedData: sorted[:topN]}, nil
}

View File

@@ -0,0 +1,161 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vikingdb
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"github.com/volcengine/volc-sdk-golang/base"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/rerank"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
type Config struct {
AK string
SK string
Region string // default cn-north-1
}
func NewReranker(config *Config) rerank.Reranker {
if config.Region == "" {
config.Region = "cn-north-1"
}
return &reranker{config: config}
}
const (
domain = "api-knowledgebase.mlp.cn-beijing.volces.com"
defaultModel = "base-multilingual-rerank"
)
type reranker struct {
config *Config
}
type rerankReq struct {
Datas []rerankData `json:"datas"`
RerankModel string `json:"rerank_model"`
}
type rerankData struct {
Query string `json:"query"`
Content string `json:"content"`
Title *string `json:"title,omitempty"`
}
type rerankResp struct {
Code int64 `json:"code"`
Message string `json:"message"`
Data struct {
Scores []float64 `json:"scores"`
TokenUsage int64 `json:"token_usage"`
} `json:"data"`
}
func (r *reranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Response, error) {
rReq := &rerankReq{
Datas: make([]rerankData, 0, len(req.Data)),
RerankModel: defaultModel,
}
var flat []*rerank.Data
for _, channel := range req.Data {
flat = append(flat, channel...)
}
for _, item := range flat {
rReq.Datas = append(rReq.Datas, rerankData{
Query: req.Query,
Content: item.Document.Content,
})
}
body, err := json.Marshal(rReq)
if err != nil {
return nil, err
}
resp, err := http.DefaultClient.Do(r.prepareRequest(body))
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
rResp := rerankResp{}
if err = json.Unmarshal(respBody, &rResp); err != nil {
return nil, err
}
if rResp.Code != 0 {
return nil, fmt.Errorf("[Rerank] failed, code=%d, msg=%v", rResp.Code, rResp.Message)
}
sorted := make([]*rerank.Data, 0, len(rResp.Data.Scores))
for i, score := range rResp.Data.Scores {
sorted = append(sorted, &rerank.Data{
Document: flat[i].Document,
Score: score,
})
}
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].Score > sorted[j].Score
})
right := len(sorted)
if req.TopN != nil {
right = min(right, int(*req.TopN))
}
return &rerank.Response{
SortedData: sorted[:right],
TokenUsage: ptr.Of(rResp.Data.TokenUsage),
}, nil
}
func (r *reranker) prepareRequest(body []byte) *http.Request {
u := url.URL{
Scheme: "https",
Host: domain,
Path: "/api/knowledge/service/rerank",
}
req, _ := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body))
req.Header.Add("Accept", "application/json")
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Host", domain)
credential := base.Credentials{
AccessKeyID: r.config.AK,
SecretAccessKey: r.config.SK,
Service: "air",
Region: r.config.Region,
}
req = credential.Sign(req)
return req
}

View File

@@ -0,0 +1,48 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vikingdb
//func TestRun(t *testing.T) {
// AK := os.Getenv("test_ak")
// SK := os.Getenv("test_sk")
//
// r := NewReranker(&Config{
// AK: AK,
// SK: SK,
// })
// resp, err := r.Rerank(context.Background(), &rerank.Request{
// Data: [][]*knowledge.RetrieveSlice{
// {
// {Slice: &entity.Slice{PlainText: "吉尼斯世界纪录网站数据显示蓝鲸是目前已知世界上最大的动物体长可达30米相当于一架波音737飞机的长度"}},
// {Slice: &entity.Slice{PlainText: "一头成年雌性弓头鲸可以长到22米长而一头雄性鲸鱼可以长到18米长"}},
// },
// },
// Query: "世界上最大的鲸鱼是什么?",
// TopN: nil,
// })
// assert.NoError(t, err)
//
// for _, item := range resp.Sorted {
// fmt.Println(item.Slice.PlainText, item.Score)
// }
// // 吉尼斯世界纪录网站数据显示蓝鲸是目前已知世界上最大的动物体长可达30米相当于一架波音737飞机的长度 0.6209664529733573
// // 一头成年雌性弓头鲸可以长到22米长而一头雄性鲸鱼可以长到18米长 0.4269785303456468
//
// fmt.Println(resp.TokenUsage)
// // 95
//
//}

View File

@@ -0,0 +1,21 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package elasticsearch
const (
topK = 10
)

View File

@@ -0,0 +1,116 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package elasticsearch
import (
"context"
"fmt"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
"github.com/coze-dev/coze-studio/backend/infra/contract/es"
)
type ManagerConfig struct {
Client es.Client
}
func NewManager(config *ManagerConfig) searchstore.Manager {
return &esManager{config: config}
}
type esManager struct {
config *ManagerConfig
}
func (e *esManager) Create(ctx context.Context, req *searchstore.CreateRequest) error {
cli := e.config.Client
index := req.CollectionName
indexExists, err := cli.Exists(ctx, index)
if err != nil {
return err
}
if indexExists { // exists
return nil
}
properties := make(map[string]any)
var foundID, foundCreatorID, foundTextContent bool
for _, field := range req.Fields {
switch field.Name {
case searchstore.FieldID:
foundID = true
case searchstore.FieldCreatorID:
foundCreatorID = true
case searchstore.FieldTextContent:
foundTextContent = true
default:
}
var property any
switch field.Type {
case searchstore.FieldTypeInt64:
property = cli.Types().NewLongNumberProperty()
case searchstore.FieldTypeText:
property = cli.Types().NewTextProperty()
default:
return fmt.Errorf("[Create] es unsupported field type: %d", field.Type)
}
properties[field.Name] = property
}
if !foundID {
properties[searchstore.FieldID] = cli.Types().NewLongNumberProperty()
}
if !foundCreatorID {
properties[searchstore.FieldCreatorID] = cli.Types().NewUnsignedLongNumberProperty()
}
if !foundTextContent {
properties[searchstore.FieldTextContent] = cli.Types().NewTextProperty()
}
err = cli.CreateIndex(ctx, index, properties)
if err != nil {
return err
}
return err
}
func (e *esManager) Drop(ctx context.Context, req *searchstore.DropRequest) error {
cli := e.config.Client
index := req.CollectionName
return cli.DeleteIndex(ctx, index)
}
func (e *esManager) GetType() searchstore.SearchStoreType {
return searchstore.TypeTextStore
}
func (e *esManager) GetSearchStore(ctx context.Context, collectionName string) (searchstore.SearchStore, error) {
return &esSearchStore{
config: e.config,
indexName: collectionName,
}, nil
}
func (e *esManager) GetEmbedding() embedding.Embedder {
return nil
}

View File

@@ -0,0 +1,329 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package elasticsearch
import (
"bytes"
"context"
"encoding/json"
"fmt"
"reflect"
"strconv"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/contract/es"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
type esSearchStore struct {
config *ManagerConfig
indexName string
}
func (e *esSearchStore) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
implSpecOptions := indexer.GetImplSpecificOptions(&searchstore.IndexerOptions{}, opts...)
defer func() {
if err != nil {
if implSpecOptions.ProgressBar != nil {
implSpecOptions.ProgressBar.ReportError(err)
}
}
}()
cli := e.config.Client
index := e.indexName
bi, err := cli.NewBulkIndexer(index)
if err != nil {
return nil, err
}
ids = make([]string, 0, len(docs))
for _, doc := range docs {
fieldMapping, err := e.fromDocument(doc)
if err != nil {
return nil, err
}
body, err := json.Marshal(fieldMapping)
if err != nil {
return nil, err
}
if err = bi.Add(ctx, es.BulkIndexerItem{
Index: e.indexName,
Action: "index",
DocumentID: doc.ID,
Body: bytes.NewReader(body),
}); err != nil {
return nil, err
}
ids = append(ids, doc.ID)
if implSpecOptions.ProgressBar != nil {
if err = implSpecOptions.ProgressBar.AddN(1); err != nil {
return nil, err
}
}
}
if err = bi.Close(ctx); err != nil {
return nil, err
}
return ids, nil
}
func (e *esSearchStore) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
var (
cli = e.config.Client
index = e.indexName
options = retriever.GetCommonOptions(&retriever.Options{TopK: ptr.Of(topK)}, opts...)
implSpecOptions = retriever.GetImplSpecificOptions(&searchstore.RetrieverOptions{}, opts...)
req = &es.Request{
Query: &es.Query{
Bool: &es.BoolQuery{},
},
Size: options.TopK,
}
)
if implSpecOptions.MultiMatch == nil {
req.Query.Bool.Must = append(req.Query.Bool.Must,
es.NewMatchQuery(searchstore.FieldTextContent, query))
} else {
req.Query.Bool.Must = append(req.Query.Bool.Must,
es.NewMultiMatchQuery(implSpecOptions.MultiMatch.Fields, query,
"best_fields", es.Or))
}
dsl, err := searchstore.LoadDSL(options.DSLInfo)
if err != nil {
return nil, err
}
if err = e.travDSL(req.Query, dsl); err != nil {
return nil, err
}
if options.ScoreThreshold != nil {
req.MinScore = options.ScoreThreshold
}
resp, err := cli.Search(ctx, index, req)
if err != nil {
return nil, err
}
docs, err := e.parseSearchResult(resp)
if err != nil {
return nil, err
}
return docs, nil
}
func (e *esSearchStore) Delete(ctx context.Context, ids []string) error {
bi, err := e.config.Client.NewBulkIndexer(e.indexName)
if err != nil {
return err
}
for _, id := range ids {
if err = bi.Add(ctx, es.BulkIndexerItem{
Index: e.indexName,
Action: "delete",
DocumentID: id,
}); err != nil {
return err
}
}
return bi.Close(ctx)
}
func (e *esSearchStore) travDSL(query *es.Query, dsl *searchstore.DSL) error {
if dsl == nil {
return nil
}
switch dsl.Op {
case searchstore.OpEq, searchstore.OpNe:
arr := stringifyValue(dsl.Value)
v := dsl.Value
if len(arr) > 0 {
v = arr[0]
}
if dsl.Op == searchstore.OpEq {
query.Bool.Must = append(query.Bool.Must,
es.NewEqualQuery(dsl.Field, v))
} else {
query.Bool.MustNot = append(query.Bool.MustNot,
es.NewEqualQuery(dsl.Field, v))
}
case searchstore.OpLike:
s, ok := dsl.Value.(string)
if !ok {
return fmt.Errorf("[travDSL] OpLike value should be string, but got %v", dsl.Value)
}
query.Bool.Must = append(query.Bool.Must, es.NewMatchQuery(dsl.Field, s))
case searchstore.OpIn:
query.Bool.Must = append(query.Bool.MustNot,
es.NewInQuery(dsl.Field, stringifyValue(dsl.Value)))
case searchstore.OpAnd, searchstore.OpOr:
conds, ok := dsl.Value.([]*searchstore.DSL)
if !ok {
return fmt.Errorf("[travDSL] value type assertion failed for or")
}
for _, cond := range conds {
sub := &es.Query{}
if err := e.travDSL(sub, cond); err != nil {
return err
}
if dsl.Op == searchstore.OpOr {
query.Bool.Should = append(query.Bool.Should, *sub)
} else {
query.Bool.Must = append(query.Bool.Must, *sub)
}
}
default:
return fmt.Errorf("[trav] unknown op %s", dsl.Op)
}
return nil
}
func (e *esSearchStore) parseSearchResult(resp *es.Response) (docs []*schema.Document, err error) {
docs = make([]*schema.Document, 0, len(resp.Hits.Hits))
firstScore := 0.0
for i, hit := range resp.Hits.Hits {
var src map[string]any
d := json.NewDecoder(bytes.NewReader(hit.Source_))
d.UseNumber()
if err = d.Decode(&src); err != nil {
return nil, err
}
ext := make(map[string]any)
doc := &schema.Document{MetaData: map[string]any{document.MetaDataKeyExternalStorage: ext}}
for field, val := range src {
ok := true
switch field {
case searchstore.FieldTextContent:
doc.Content, ok = val.(string)
case searchstore.FieldCreatorID:
var jn json.Number
jn, ok = val.(json.Number)
if ok {
doc.MetaData[document.MetaDataKeyCreatorID], ok = assertJSONNumber(jn).(int64)
}
default:
if jn, jok := val.(json.Number); jok {
ext[field] = assertJSONNumber(jn)
} else {
ext[field] = val
}
}
if !ok {
return nil, fmt.Errorf("[parseSearchResult] type assertion failed, field=%s, val=%v", field, val)
}
}
if hit.Id_ != nil {
doc.ID = *hit.Id_
}
if hit.Score_ == nil { // unexpected
return nil, fmt.Errorf("[parseSearchResult] es retrieve score not found")
}
score := float64(ptr.From(hit.Score_))
if i == 0 {
firstScore = score
}
doc.WithScore(score / firstScore)
docs = append(docs, doc)
}
return docs, nil
}
func (e *esSearchStore) fromDocument(doc *schema.Document) (map[string]any, error) {
if doc.MetaData == nil {
return nil, fmt.Errorf("[fromDocument] es document meta data is nil")
}
creatorID, ok := doc.MetaData[searchstore.FieldCreatorID].(int64)
if !ok {
return nil, fmt.Errorf("[fromDocument] creator id not found or type invalid")
}
fieldMapping := map[string]any{
searchstore.FieldTextContent: doc.Content,
searchstore.FieldCreatorID: creatorID,
}
if ext, ok := doc.MetaData[document.MetaDataKeyExternalStorage].(map[string]any); ok {
for k, v := range ext {
fieldMapping[k] = v
}
}
return fieldMapping, nil
}
func stringifyValue(dslValue any) []any {
value := reflect.ValueOf(dslValue)
switch value.Kind() {
case reflect.Slice, reflect.Array:
length := value.Len()
slice := make([]any, 0, length)
for i := 0; i < length; i++ {
elem := value.Index(i)
switch elem.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
slice = append(slice, strconv.FormatInt(elem.Int(), 10))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
slice = append(slice, strconv.FormatUint(elem.Uint(), 10))
case reflect.Float32, reflect.Float64:
slice = append(slice, strconv.FormatFloat(elem.Float(), 'f', -1, 64))
case reflect.String:
slice = append(slice, elem.String())
default:
slice = append(slice, elem) // do nothing
}
}
return slice
default:
return []any{fmt.Sprintf("%v", value)}
}
}
func assertJSONNumber(f json.Number) any {
if i64, err := f.Int64(); err == nil {
return i64
}
if f64, err := f.Float64(); err == nil {
return f64
}
return f.String()
}

View File

@@ -0,0 +1,22 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package milvus
const (
batchSize = 100
topK = 4
)

View File

@@ -0,0 +1,119 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package milvus
import (
"fmt"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
)
func denseFieldName(name string) string {
return fmt.Sprintf("dense_%s", name)
}
func denseIndexName(name string) string {
return fmt.Sprintf("index_dense_%s", name)
}
func sparseFieldName(name string) string {
return fmt.Sprintf("sparse_%s", name)
}
func sparseIndexName(name string) string {
return fmt.Sprintf("index_sparse_%s", name)
}
func convertFieldType(typ searchstore.FieldType) (entity.FieldType, error) {
switch typ {
case searchstore.FieldTypeInt64:
return entity.FieldTypeInt64, nil
case searchstore.FieldTypeText:
return entity.FieldTypeVarChar, nil
case searchstore.FieldTypeDenseVector:
return entity.FieldTypeFloatVector, nil
case searchstore.FieldTypeSparseVector:
return entity.FieldTypeSparseVector, nil
default:
return entity.FieldTypeNone, fmt.Errorf("[convertFieldType] unknown field type: %v", typ)
}
}
func convertDense(dense [][]float64) [][]float32 {
return slices.Transform(dense, func(a []float64) []float32 {
r := make([]float32, len(a))
for i := 0; i < len(a); i++ {
r[i] = float32(a[i])
}
return r
})
}
func convertMilvusDenseVector(dense [][]float64) []entity.Vector {
return slices.Transform(dense, func(a []float64) entity.Vector {
r := make([]float32, len(a))
for i := 0; i < len(a); i++ {
r[i] = float32(a[i])
}
return entity.FloatVector(r)
})
}
func convertSparse(sparse []map[int]float64) ([]entity.SparseEmbedding, error) {
r := make([]entity.SparseEmbedding, 0, len(sparse))
for _, s := range sparse {
ks := make([]uint32, 0, len(s))
vs := make([]float32, 0, len(s))
for k, v := range s {
ks = append(ks, uint32(k))
vs = append(vs, float32(v))
}
se, err := entity.NewSliceSparseEmbedding(ks, vs)
if err != nil {
return nil, err
}
r = append(r, se)
}
return r, nil
}
func convertMilvusSparseVector(sparse []map[int]float64) ([]entity.Vector, error) {
r := make([]entity.Vector, 0, len(sparse))
for _, s := range sparse {
ks := make([]uint32, 0, len(s))
vs := make([]float32, 0, len(s))
for k, v := range s {
ks = append(ks, uint32(k))
vs = append(vs, float32(v))
}
se, err := entity.NewSliceSparseEmbedding(ks, vs)
if err != nil {
return nil, err
}
r = append(r, se)
}
return r, nil
}

View File

@@ -0,0 +1,334 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package milvus
import (
"context"
"fmt"
"strings"
mentity "github.com/milvus-io/milvus/client/v2/entity"
mindex "github.com/milvus-io/milvus/client/v2/index"
client "github.com/milvus-io/milvus/client/v2/milvusclient"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type ManagerConfig struct {
Client *client.Client // required
Embedding embedding.Embedder // required
EnableHybrid *bool // optional: default Embedding.SupportStatus() == embedding.SupportDenseAndSparse
DenseIndex mindex.Index // optional: default HNSW, M=30, efConstruction=360
DenseMetric mentity.MetricType // optional: default IP
SparseIndex mindex.Index // optional: default SPARSE_INVERTED_INDEX, drop_ratio=0.2
SparseMetric mentity.MetricType // optional: default IP
ShardNum int // optional: default 1
BatchSize int // optional: default 100
}
func NewManager(config *ManagerConfig) (searchstore.Manager, error) {
if config.Client == nil {
return nil, fmt.Errorf("[NewManager] milvus client not provided")
}
if config.Embedding == nil {
return nil, fmt.Errorf("[NewManager] milvus embedder not provided")
}
enableSparse := config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse
if config.EnableHybrid == nil {
config.EnableHybrid = ptr.Of(enableSparse)
} else if !enableSparse && ptr.From(config.EnableHybrid) {
logs.Warnf("[NewManager] milvus embedding not support sparse, so hybrid search is disabled.")
config.EnableHybrid = ptr.Of(false)
}
if config.DenseMetric == "" {
config.DenseMetric = mentity.IP
}
if config.DenseIndex == nil {
config.DenseIndex = mindex.NewHNSWIndex(config.DenseMetric, 30, 360)
}
if config.SparseMetric == "" {
config.SparseMetric = mentity.IP
}
if config.SparseIndex == nil {
config.SparseIndex = mindex.NewSparseInvertedIndex(config.SparseMetric, 0.2)
}
if config.ShardNum == 0 {
config.ShardNum = 1
}
if config.BatchSize == 0 {
config.BatchSize = 100
}
return &milvusManager{config: config}, nil
}
type milvusManager struct {
config *ManagerConfig
}
func (m *milvusManager) Create(ctx context.Context, req *searchstore.CreateRequest) error {
if err := m.createCollection(ctx, req); err != nil {
return fmt.Errorf("[Create] create collection failed, %w", err)
}
if err := m.createIndexes(ctx, req); err != nil {
return fmt.Errorf("[Create] create indexes failed, %w", err)
}
if exists, err := m.loadCollection(ctx, req.CollectionName); err != nil {
return fmt.Errorf("[Create] load collection failed, %w", err)
} else if !exists {
return fmt.Errorf("[Create] load collection failed, collection=%v does not exist", req.CollectionName)
}
return nil
}
func (m *milvusManager) Drop(ctx context.Context, req *searchstore.DropRequest) error {
return m.config.Client.DropCollection(ctx, client.NewDropCollectionOption(req.CollectionName))
}
func (m *milvusManager) GetType() searchstore.SearchStoreType {
return searchstore.TypeVectorStore
}
func (m *milvusManager) GetSearchStore(ctx context.Context, collectionName string) (searchstore.SearchStore, error) {
if exists, err := m.loadCollection(ctx, collectionName); err != nil {
return nil, err
} else if !exists {
return nil, errorx.New(errno.ErrKnowledgeNonRetryableCode,
errorx.KVf("reason", "[GetSearchStore] collection=%v does not exist", collectionName))
}
return &milvusSearchStore{
config: m.config,
collectionName: collectionName,
}, nil
}
func (m *milvusManager) createCollection(ctx context.Context, req *searchstore.CreateRequest) error {
if req.CollectionName == "" || len(req.Fields) == 0 {
return fmt.Errorf("[createCollection] invalid request params")
}
cli := m.config.Client
collectionName := req.CollectionName
has, err := cli.HasCollection(ctx, client.NewHasCollectionOption(collectionName))
if err != nil {
return fmt.Errorf("[createCollection] HasCollection failed, %w", err)
}
if has {
return nil
}
fields, err := m.convertFields(req.Fields)
if err != nil {
return err
}
opt := client.NewCreateCollectionOption(collectionName, &mentity.Schema{
CollectionName: collectionName,
Description: fmt.Sprintf("created by coze"),
AutoID: false,
Fields: fields,
EnableDynamicField: false,
}).WithShardNum(int32(m.config.ShardNum))
for k, v := range req.CollectionMeta {
opt.WithProperty(k, v)
}
if err = cli.CreateCollection(ctx, opt); err != nil {
return fmt.Errorf("[createCollection] CreateCollection failed, %w", err)
}
return nil
}
func (m *milvusManager) createIndexes(ctx context.Context, req *searchstore.CreateRequest) error {
collectionName := req.CollectionName
indexes, err := m.config.Client.ListIndexes(ctx, client.NewListIndexOption(req.CollectionName))
if err != nil {
if !strings.Contains(err.Error(), "index not found") {
return fmt.Errorf("[createIndexes] ListIndexes failed, %w", err)
}
}
createdIndexes := sets.FromSlice(indexes)
var ops []func() error
for i := range req.Fields {
f := req.Fields[i]
if !f.Indexing {
continue
}
ops = append(ops, m.tryCreateIndex(ctx, collectionName, denseFieldName(f.Name), denseIndexName(f.Name), m.config.DenseIndex, createdIndexes))
if m.config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse {
ops = append(ops, m.tryCreateIndex(ctx, collectionName, sparseFieldName(f.Name), sparseIndexName(f.Name), m.config.SparseIndex, createdIndexes))
}
}
for _, op := range ops {
if err := op(); err != nil {
return fmt.Errorf("[createIndexes] failed, %w", err)
}
}
return nil
}
func (m *milvusManager) tryCreateIndex(ctx context.Context, collectionName, fieldName, indexName string, idx mindex.Index, createdIndexes sets.Set[string]) func() error {
return func() error {
if _, found := createdIndexes[indexName]; found {
logs.CtxInfof(ctx, "[tryCreateIndex] index exists, so skip, collectionName=%s, fieldName=%s, idx=%v, type=%s\n",
collectionName, fieldName, indexName, idx.IndexType())
return nil
}
cli := m.config.Client
task, err := cli.CreateIndex(ctx, client.NewCreateIndexOption(collectionName, fieldName, idx).WithIndexName(indexName))
if err != nil {
return fmt.Errorf("[tryCreateIndex] CreateIndex failed, %w", err)
}
if err = task.Await(ctx); err != nil {
return fmt.Errorf("[tryCreateIndex] await failed, %w", err)
}
logs.CtxInfof(ctx, "[tryCreateIndex] CreateIndex success, collectionName=%s, fieldName=%s, idx=%v, type=%s\n",
collectionName, fieldName, indexName, idx.IndexType())
return nil
}
}
func (m *milvusManager) loadCollection(ctx context.Context, collectionName string) (exists bool, err error) {
cli := m.config.Client
stat, err := cli.GetLoadState(ctx, client.NewGetLoadStateOption(collectionName))
if err != nil {
return false, fmt.Errorf("[loadCollection] GetLoadState failed, %w", err)
}
switch stat.State {
case mentity.LoadStateNotLoad:
task, err := cli.LoadCollection(ctx, client.NewLoadCollectionOption(collectionName))
if err != nil {
return false, fmt.Errorf("[loadCollection] LoadCollection failed, collection=%v, %w", collectionName, err)
}
if err = task.Await(ctx); err != nil {
return false, fmt.Errorf("[loadCollection] await failed, collection=%v, %w", collectionName, err)
}
return true, nil
case mentity.LoadStateLoaded:
return true, nil
case mentity.LoadStateLoading:
return true, fmt.Errorf("[loadCollection] collection is unloading, retry later, collection=%v", collectionName)
case mentity.LoadStateUnloading:
return false, nil
default:
return false, fmt.Errorf("[loadCollection] load state unexpected, state=%d", stat)
}
}
func (m *milvusManager) convertFields(fields []*searchstore.Field) ([]*mentity.Field, error) {
var foundID, foundCreatorID bool
resp := make([]*mentity.Field, 0, len(fields))
for _, f := range fields {
switch f.Name {
case searchstore.FieldID:
foundID = true
case searchstore.FieldCreatorID:
foundCreatorID = true
default:
}
if f.Indexing {
if f.Type != searchstore.FieldTypeText {
return nil, fmt.Errorf("[convertFields] milvus only support text field indexing, field=%s, type=%d", f.Name, f.Type)
}
// indexing 时只有 content 存储原文
if f.Name == searchstore.FieldTextContent {
resp = append(resp, mentity.NewField().
WithName(f.Name).
WithDescription(f.Description).
WithIsPrimaryKey(f.IsPrimary).
WithNullable(f.Nullable).
WithDataType(mentity.FieldTypeVarChar).
WithMaxLength(65535))
}
resp = append(resp, mentity.NewField().
WithName(denseFieldName(f.Name)).
WithDataType(mentity.FieldTypeFloatVector).
WithDim(m.config.Embedding.Dimensions()))
if m.config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse {
resp = append(resp, mentity.NewField().
WithName(sparseFieldName(f.Name)).
WithDataType(mentity.FieldTypeSparseVector))
}
} else {
mf := mentity.NewField().
WithName(f.Name).
WithDescription(f.Description).
WithIsPrimaryKey(f.IsPrimary).
WithNullable(f.Nullable)
typ, err := convertFieldType(f.Type)
if err != nil {
return nil, err
}
mf.WithDataType(typ)
if typ == mentity.FieldTypeVarChar {
mf.WithMaxLength(65535)
} else if typ == mentity.FieldTypeFloatVector {
mf.WithDim(m.config.Embedding.Dimensions())
}
resp = append(resp, mf)
}
}
if !foundID {
resp = append(resp, mentity.NewField().
WithName(searchstore.FieldID).
WithDataType(mentity.FieldTypeInt64).
WithIsPrimaryKey(true).
WithNullable(false))
}
if !foundCreatorID {
resp = append(resp, mentity.NewField().
WithName(searchstore.FieldCreatorID).
WithDataType(mentity.FieldTypeInt64).
WithNullable(false))
}
return resp, nil
}
func (m *milvusManager) GetEmbedding() embedding.Embedder {
return m.config.Embedding
}

View File

@@ -0,0 +1,600 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package milvus
import (
"context"
"encoding/json"
"fmt"
"math"
"reflect"
"sort"
"strconv"
"strings"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/milvus-io/milvus/client/v2/column"
mentity "github.com/milvus-io/milvus/client/v2/entity"
mindex "github.com/milvus-io/milvus/client/v2/index"
client "github.com/milvus-io/milvus/client/v2/milvusclient"
"github.com/slongfield/pyfmt"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
)
type milvusSearchStore struct {
config *ManagerConfig
collectionName string
}
func (m *milvusSearchStore) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
if len(docs) == 0 {
return nil, nil
}
implSpecOptions := indexer.GetImplSpecificOptions(&searchstore.IndexerOptions{}, opts...)
defer func() {
if err != nil {
if implSpecOptions.ProgressBar != nil {
implSpecOptions.ProgressBar.ReportError(err)
}
}
}()
indexingFields := make(sets.Set[string])
for _, field := range implSpecOptions.IndexingFields {
indexingFields[field] = struct{}{}
}
if implSpecOptions.Partition != nil {
partition := *implSpecOptions.Partition
hasPartition, err := m.config.Client.HasPartition(ctx, client.NewHasPartitionOption(m.collectionName, partition))
if err != nil {
return nil, fmt.Errorf("[Store] HasPartition failed, %w", err)
}
if !hasPartition {
if err = m.config.Client.CreatePartition(ctx, client.NewCreatePartitionOption(m.collectionName, partition)); err != nil {
return nil, fmt.Errorf("[Store] CreatePartition failed, %w", err)
}
}
}
for _, part := range slices.Chunks(docs, batchSize) {
columns, err := m.documents2Columns(ctx, part, indexingFields)
if err != nil {
return nil, err
}
createReq := client.NewColumnBasedInsertOption(m.collectionName, columns...)
if implSpecOptions.Partition != nil {
createReq.WithPartition(*implSpecOptions.Partition)
}
result, err := m.config.Client.Upsert(ctx, createReq)
if err != nil {
return nil, fmt.Errorf("[Store] upsert failed, %w", err)
}
partIDs := result.IDs
for i := 0; i < partIDs.Len(); i++ {
var sid string
if partIDs.Type() == mentity.FieldTypeInt64 {
id, err := partIDs.GetAsInt64(i)
if err != nil {
return nil, err
}
sid = strconv.FormatInt(id, 10)
} else {
sid, err = partIDs.GetAsString(i)
if err != nil {
return nil, err
}
}
ids = append(ids, sid)
}
if implSpecOptions.ProgressBar != nil {
if err = implSpecOptions.ProgressBar.AddN(len(part)); err != nil {
return nil, err
}
}
}
return ids, nil
}
func (m *milvusSearchStore) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
cli := m.config.Client
emb := m.config.Embedding
options := retriever.GetCommonOptions(&retriever.Options{TopK: ptr.Of(topK)}, opts...)
implSpecOptions := retriever.GetImplSpecificOptions(&searchstore.RetrieverOptions{}, opts...)
desc, err := cli.DescribeCollection(ctx, client.NewDescribeCollectionOption(m.collectionName))
if err != nil {
return nil, err
}
var (
dense [][]float64
sparse []map[int]float64
expr string
result []client.ResultSet
fields = desc.Schema.Fields
outputFields []string
enableSparse = m.enableSparse(fields)
)
if options.DSLInfo != nil {
expr, err = m.dsl2Expr(options.DSLInfo)
if err != nil {
return nil, err
}
}
if enableSparse {
dense, sparse, err = emb.EmbedStringsHybrid(ctx, []string{query})
if err != nil {
return nil, fmt.Errorf("[Retrieve] EmbedStringsHybrid failed, %w", err)
}
} else {
dense, err = emb.EmbedStrings(ctx, []string{query})
if err != nil {
return nil, fmt.Errorf("[Retrieve] EmbedStrings failed, %w", err)
}
}
dv := convertMilvusDenseVector(dense)
sv, err := convertMilvusSparseVector(sparse)
if err != nil {
return nil, err
}
for _, field := range fields {
outputFields = append(outputFields, field.Name)
}
var scoreNormType *mindex.MetricType
if enableSparse {
var annRequests []*client.AnnRequest
for _, field := range fields {
var (
vector []mentity.Vector
metricsType mindex.MetricType
)
if field.DataType == mentity.FieldTypeFloatVector {
vector = dv
metricsType, err = m.getIndexMetricsType(ctx, denseIndexName(field.Name))
} else if field.DataType == mentity.FieldTypeSparseVector {
vector = sv
metricsType, err = m.getIndexMetricsType(ctx, sparseIndexName(field.Name))
}
if err != nil {
return nil, err
}
annRequests = append(annRequests,
client.NewAnnRequest(field.Name, ptr.From(options.TopK), vector...).
WithSearchParam(mindex.MetricTypeKey, string(metricsType)).
WithFilter(expr),
)
}
searchOption := client.NewHybridSearchOption(m.collectionName, ptr.From(options.TopK), annRequests...).
WithPartitons(implSpecOptions.Partitions...).
WithReranker(client.NewRRFReranker()).
WithOutputFields(outputFields...)
result, err = cli.HybridSearch(ctx, searchOption)
if err != nil {
return nil, fmt.Errorf("[Retrieve] HybridSearch failed, %w", err)
}
} else {
indexes, err := cli.ListIndexes(ctx, client.NewListIndexOption(m.collectionName))
if err != nil {
return nil, fmt.Errorf("[Retrieve] ListIndexes failed, %w", err)
}
if len(indexes) != 1 {
return nil, fmt.Errorf("[Retrieve] restrict single index ann search, but got %d, collection=%s",
len(indexes), m.collectionName)
}
metricsType, err := m.getIndexMetricsType(ctx, indexes[0])
if err != nil {
return nil, err
}
scoreNormType = &metricsType
searchOption := client.NewSearchOption(m.collectionName, ptr.From(options.TopK), dv).
WithPartitions(implSpecOptions.Partitions...).
WithFilter(expr).
WithOutputFields(outputFields...).
WithSearchParam(mindex.MetricTypeKey, string(metricsType))
result, err = cli.Search(ctx, searchOption)
if err != nil {
return nil, fmt.Errorf("[Retrieve] Search failed, %w", err)
}
}
docs, err := m.resultSet2Document(result, scoreNormType)
if err != nil {
return nil, fmt.Errorf("[Retrieve] resultSet2Document failed, %w", err)
}
return docs, nil
}
func (m *milvusSearchStore) Delete(ctx context.Context, ids []string) error {
int64IDs := make([]int64, 0, len(ids))
for _, sid := range ids {
id, err := strconv.ParseInt(sid, 10, 64)
if err != nil {
return err
}
int64IDs = append(int64IDs, id)
}
_, err := m.config.Client.Delete(ctx,
client.NewDeleteOption(m.collectionName).WithInt64IDs(searchstore.FieldID, int64IDs))
return err
}
func (m *milvusSearchStore) documents2Columns(ctx context.Context, docs []*schema.Document, indexingFields sets.Set[string]) (
cols []column.Column, err error) {
var (
ids []int64
contents []string
creatorIDs []int64
emptyContents = true
)
colMapping := map[string]any{}
colTypeMapping := map[string]searchstore.FieldType{
searchstore.FieldID: searchstore.FieldTypeInt64,
searchstore.FieldCreatorID: searchstore.FieldTypeInt64,
searchstore.FieldTextContent: searchstore.FieldTypeText,
}
for _, doc := range docs {
if doc.MetaData == nil {
return nil, fmt.Errorf("[documents2Columns] meta data is nil")
}
id, err := strconv.ParseInt(doc.ID, 10, 64)
if err != nil {
return nil, fmt.Errorf("[documents2Columns] parse id failed, %w", err)
}
ids = append(ids, id)
contents = append(contents, doc.Content)
if doc.Content != "" {
emptyContents = false
}
creatorID, err := document.GetDocumentCreatorID(doc)
if err != nil {
return nil, fmt.Errorf("[documents2Columns] creator_id not found or type invalid., %w", err)
}
creatorIDs = append(creatorIDs, creatorID)
ext, ok := doc.MetaData[document.MetaDataKeyExternalStorage].(map[string]any)
if !ok {
continue
}
for field := range ext {
val := ext[field]
container := colMapping[field]
switch t := val.(type) {
case uint, uint8, uint16, uint32, uint64, uintptr:
var c []int64
if container == nil {
colTypeMapping[field] = searchstore.FieldTypeInt64
} else {
c, ok = container.([]int64)
if !ok {
return nil, fmt.Errorf("[documents2Columns] container type not int64")
}
}
c = append(c, int64(reflect.ValueOf(t).Uint()))
colMapping[field] = c
case int, int8, int16, int32, int64:
var c []int64
if container == nil {
colTypeMapping[field] = searchstore.FieldTypeInt64
} else {
c, ok = container.([]int64)
if !ok {
return nil, fmt.Errorf("[documents2Columns] container type not int64")
}
}
c = append(c, reflect.ValueOf(t).Int())
colMapping[field] = c
case string:
var c []string
if container == nil {
colTypeMapping[field] = searchstore.FieldTypeText
} else {
c, ok = container.([]string)
if !ok {
return nil, fmt.Errorf("[documents2Columns] container type not int64")
}
}
c = append(c, t)
colMapping[field] = c
case []float64:
var c [][]float64
if container == nil {
container = c
colTypeMapping[field] = searchstore.FieldTypeDenseVector
} else {
c, ok = container.([][]float64)
if !ok {
return nil, fmt.Errorf("[documents2Columns] container type not int64")
}
}
c = append(c, t)
colMapping[field] = c
case map[int]float64:
var c []map[int]float64
if container == nil {
container = c
colTypeMapping[field] = searchstore.FieldTypeSparseVector
} else {
c, ok = container.([]map[int]float64)
if !ok {
return nil, fmt.Errorf("[documents2Columns] container type not int64")
}
}
c = append(c, t)
colMapping[field] = c
default:
return nil, fmt.Errorf("[documents2Columns] val type not support, val=%v", val)
}
}
}
colMapping[searchstore.FieldID] = ids
colMapping[searchstore.FieldCreatorID] = creatorIDs
colMapping[searchstore.FieldTextContent] = contents
for fieldName, container := range colMapping {
colType := colTypeMapping[fieldName]
switch colType {
case searchstore.FieldTypeInt64:
c, ok := container.([]int64)
if !ok {
return nil, fmt.Errorf("[documents2Columns] container type not int64")
}
cols = append(cols, column.NewColumnInt64(fieldName, c))
case searchstore.FieldTypeText:
c, ok := container.([]string)
if !ok {
return nil, fmt.Errorf("[documents2Columns] container type not string")
}
if _, indexing := indexingFields[fieldName]; indexing {
if fieldName == searchstore.FieldTextContent && !emptyContents {
cols = append(cols, column.NewColumnVarChar(fieldName, c))
}
var (
emb = m.config.Embedding
dense [][]float64
sparse []map[int]float64
)
if emb.SupportStatus() == embedding.SupportDenseAndSparse {
dense, sparse, err = emb.EmbedStringsHybrid(ctx, c)
} else {
dense, err = emb.EmbedStrings(ctx, c)
}
if err != nil {
return nil, fmt.Errorf("[slices2Columns] embed failed, %w", err)
}
cols = append(cols, column.NewColumnFloatVector(denseFieldName(fieldName), int(emb.Dimensions()), convertDense(dense)))
if emb.SupportStatus() == embedding.SupportDenseAndSparse {
s, err := convertSparse(sparse)
if err != nil {
return nil, err
}
cols = append(cols, column.NewColumnSparseVectors(sparseFieldName(fieldName), s))
}
} else {
cols = append(cols, column.NewColumnVarChar(fieldName, c))
}
case searchstore.FieldTypeDenseVector:
c, ok := container.([][]float64)
if !ok {
return nil, fmt.Errorf("[documents2Columns] container type not []float64")
}
cols = append(cols, column.NewColumnFloatVector(fieldName, int(m.config.Embedding.Dimensions()), convertDense(c)))
case searchstore.FieldTypeSparseVector:
c, ok := container.([]map[int]float64)
if !ok {
return nil, fmt.Errorf("[documents2Columns] container type not map[int]float64")
}
sparse, err := convertSparse(c)
if err != nil {
return nil, err
}
cols = append(cols, column.NewColumnSparseVectors(fieldName, sparse))
default:
return nil, fmt.Errorf("[documents2Columns] column type not support, type=%d", colType)
}
}
return cols, nil
}
func (m *milvusSearchStore) resultSet2Document(result []client.ResultSet, metricsType *mindex.MetricType) (docs []*schema.Document, err error) {
docs = make([]*schema.Document, 0, len(result))
minScore := math.MaxFloat64
maxScore := 0.0
for _, r := range result {
for i := 0; i < r.ResultCount; i++ {
ext := make(map[string]any)
doc := &schema.Document{MetaData: map[string]any{document.MetaDataKeyExternalStorage: ext}}
score := float64(r.Scores[i])
minScore = min(minScore, score)
maxScore = max(maxScore, score)
doc.WithScore(score)
for _, field := range r.Fields {
switch field.Name() {
case searchstore.FieldID:
id, err := field.GetAsInt64(i)
if err != nil {
return nil, err
}
doc.ID = strconv.FormatInt(id, 10)
case searchstore.FieldTextContent:
doc.Content, err = field.GetAsString(i)
case searchstore.FieldCreatorID:
doc.MetaData[document.MetaDataKeyCreatorID], err = field.GetAsInt64(i)
default:
ext[field.Name()], err = field.Get(i)
}
if err != nil {
return nil, err
}
}
docs = append(docs, doc)
}
}
sort.Slice(docs, func(i, j int) bool {
return docs[i].Score() > docs[j].Score()
})
// norm score
if (m.config.EnableHybrid != nil && *m.config.EnableHybrid) || metricsType == nil {
return docs, nil
}
switch *metricsType {
case mentity.L2:
base := maxScore - minScore
for i := range docs {
if base == 0 {
docs[i].WithScore(1.0)
} else {
docs[i].WithScore(1.0 - (docs[i].Score()-minScore)/base)
}
}
docs = slices.Reverse(docs)
case mentity.IP, mentity.COSINE:
for i := range docs {
docs[i].WithScore((docs[i].Score() + 1) / 2)
}
default:
}
return docs, nil
}
func (m *milvusSearchStore) enableSparse(fields []*mentity.Field) bool {
found := false
for _, field := range fields {
if field.DataType == mentity.FieldTypeSparseVector {
found = true
break
}
}
return found && *m.config.EnableHybrid && m.config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse
}
func (m *milvusSearchStore) dsl2Expr(src map[string]interface{}) (string, error) {
if src == nil {
return "", nil
}
dsl, err := searchstore.LoadDSL(src)
if err != nil {
return "", err
}
var travDSL func(dsl *searchstore.DSL) (string, error)
travDSL = func(dsl *searchstore.DSL) (string, error) {
kv := map[string]interface{}{
"field": dsl.Field,
"val": dsl.Value,
}
switch dsl.Op {
case searchstore.OpEq:
return pyfmt.Fmt("{field} == {val}", kv)
case searchstore.OpNe:
return pyfmt.Fmt("{field} != {val}", kv)
case searchstore.OpLike:
return pyfmt.Fmt("{field} LIKE {val}", kv)
case searchstore.OpIn:
b, err := json.Marshal(dsl.Value)
if err != nil {
return "", err
}
kv["val"] = string(b)
return pyfmt.Fmt("{field} IN {val}", kv)
case searchstore.OpAnd, searchstore.OpOr:
sub, ok := dsl.Value.([]*searchstore.DSL)
if !ok {
return "", fmt.Errorf("[dsl2Expr] invalid sub dsl")
}
var items []string
for _, s := range sub {
str, err := travDSL(s)
if err != nil {
return "", fmt.Errorf("[dsl2Expr] parse sub failed, %w", err)
}
items = append(items, str)
}
if dsl.Op == searchstore.OpAnd {
return strings.Join(items, " AND "), nil
} else {
return strings.Join(items, " OR "), nil
}
default:
return "", fmt.Errorf("[dsl2Expr] unknown op type=%s", dsl.Op)
}
}
return travDSL(dsl)
}
func (m *milvusSearchStore) getIndexMetricsType(ctx context.Context, indexName string) (mindex.MetricType, error) {
index, err := m.config.Client.DescribeIndex(ctx, client.NewDescribeIndexOption(m.collectionName, indexName))
if err != nil {
return "", fmt.Errorf("[getIndexMetricsType] describe index failed, collection=%s, index=%s, %w",
m.collectionName, indexName, err)
}
typ, found := index.Params()[mindex.MetricTypeKey]
if !found { // unexpected
return "", fmt.Errorf("[getIndexMetricsType] invalid index params, collection=%s, index=%s", m.collectionName, indexName)
}
return mindex.MetricType(typ), nil
}

View File

@@ -0,0 +1,122 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vikingdb
import (
"fmt"
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
embcontract "github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
type VikingEmbeddingModelName string
const (
ModelNameDoubaoEmbedding VikingEmbeddingModelName = "doubao-embedding"
ModelNameDoubaoEmbeddingLarge VikingEmbeddingModelName = "doubao-embedding-large"
ModelNameDoubaoEmbeddingVision VikingEmbeddingModelName = "doubao-embedding-vision"
ModelNameBGELargeZH VikingEmbeddingModelName = "bge-large-zh"
ModelNameBGEM3 VikingEmbeddingModelName = "bge-m3"
ModelNameBGEVisualizedM3 VikingEmbeddingModelName = "bge-visualized-m3"
//ModelNameDoubaoEmbeddingAndM3 VikingEmbeddingModelName = "doubao-embedding-and-m3"
//ModelNameDoubaoEmbeddingLargeAndM3 VikingEmbeddingModelName = "doubao-embedding-large-and-m3"
//ModelNameBGELargeZHAndM3 VikingEmbeddingModelName = "bge-large-zh-and-m3"
)
func (v VikingEmbeddingModelName) Dimensions() int64 {
switch v {
case ModelNameDoubaoEmbedding, ModelNameDoubaoEmbeddingVision:
return 2048
case ModelNameDoubaoEmbeddingLarge:
return 4096
case ModelNameBGELargeZH, ModelNameBGEM3, ModelNameBGEVisualizedM3:
return 1024
default:
return 0
}
}
func (v VikingEmbeddingModelName) ModelVersion() *string {
switch v {
case ModelNameDoubaoEmbedding:
return ptr.Of("240515")
case ModelNameDoubaoEmbeddingLarge:
return ptr.Of("240915")
case ModelNameDoubaoEmbeddingVision:
return ptr.Of("250328")
default:
return nil
}
}
func (v VikingEmbeddingModelName) SupportStatus() embcontract.SupportStatus {
switch v {
case ModelNameDoubaoEmbedding, ModelNameDoubaoEmbeddingLarge, ModelNameDoubaoEmbeddingVision, ModelNameBGELargeZH, ModelNameBGEVisualizedM3:
return embcontract.SupportDense
case ModelNameBGEM3:
return embcontract.SupportDenseAndSparse
default:
return embcontract.SupportDense
}
}
type IndexType string
const (
IndexTypeHNSW IndexType = vikingdb.HNSW
IndexTypeHNSWHybrid IndexType = vikingdb.HNSW_HYBRID
IndexTypeFlat IndexType = vikingdb.FLAT
IndexTypeIVF IndexType = vikingdb.IVF
IndexTypeDiskANN IndexType = vikingdb.DiskANN
)
type IndexDistance string
const (
IndexDistanceIP IndexDistance = vikingdb.IP
IndexDistanceL2 IndexDistance = vikingdb.L2
IndexDistanceCosine IndexDistance = vikingdb.COSINE
)
type IndexQuant string
const (
IndexQuantInt8 IndexQuant = vikingdb.Int8
IndexQuantFloat IndexQuant = vikingdb.Float
IndexQuantFix16 IndexQuant = vikingdb.Fix16
IndexQuantPQ IndexQuant = vikingdb.PQ
)
const (
vikingEmbeddingUseDense = "return_dense"
vikingEmbeddingUseSparse = "return_sparse"
vikingEmbeddingRespSentenceDense = "sentence_dense_embedding"
vikingEmbeddingRespSentenceSparse = "sentence_sparse_embedding"
vikingIndexName = "opencoze_index"
)
const (
errCollectionNotFound = "collection not found"
errIndexNotFound = "index not found"
)
func denseFieldName(name string) string {
return fmt.Sprintf("dense_%s", name)
}

View File

@@ -0,0 +1,331 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vikingdb
import (
"context"
"fmt"
"strings"
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
type ManagerConfig struct {
Service *vikingdb.VikingDBService
IndexingConfig *VikingIndexingConfig
EmbeddingConfig *VikingEmbeddingConfig
// TODO: cache viking collection & index client
}
type VikingIndexingConfig struct {
// vector index config
Type IndexType // default: hnsw / hnsw_hybrid
Distance *IndexDistance // default: ip
Quant *IndexQuant // default: int8
HnswM *int64 // default: 20
HnswCef *int64 // default: 400
HnswSef *int64 // default: 800
// others
CpuQuota int64 // default: 2
ShardCount int64 // default: 1
}
type VikingEmbeddingConfig struct {
UseVikingEmbedding bool
EnableHybrid bool
// viking embedding config
ModelName VikingEmbeddingModelName
ModelVersion *string
DenseWeight *float64
// builtin embedding config
BuiltinEmbedding embedding.Embedder
}
func NewManager(config *ManagerConfig) (searchstore.Manager, error) {
if config.Service == nil {
return nil, fmt.Errorf("[NewManager] vikingdb service is nil")
}
if config.EmbeddingConfig == nil {
return nil, fmt.Errorf("[NewManager] vikingdb embedding config is nil")
}
if !config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.BuiltinEmbedding == nil {
return nil, fmt.Errorf("[NewManager] vikingdb built embedding not provided")
}
if !config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.EnableHybrid {
return nil, fmt.Errorf("[NewManager] vikingdb hybrid not support for builtin embedding")
}
if config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.ModelName == "" {
return nil, fmt.Errorf("[NewManager] vikingdb model name is empty")
}
if config.EmbeddingConfig.UseVikingEmbedding &&
config.EmbeddingConfig.EnableHybrid &&
config.EmbeddingConfig.ModelName.SupportStatus() != embedding.SupportDenseAndSparse {
return nil, fmt.Errorf("[NewManager] vikingdb embedding model not support sparse embedding, model=%v", config.EmbeddingConfig.ModelName)
}
if config.IndexingConfig == nil {
config.IndexingConfig = &VikingIndexingConfig{}
}
if config.IndexingConfig.Type == "" {
if !config.EmbeddingConfig.UseVikingEmbedding || !config.EmbeddingConfig.EnableHybrid {
config.IndexingConfig.Type = IndexTypeHNSW
} else {
config.IndexingConfig.Type = IndexTypeHNSWHybrid
}
}
if config.IndexingConfig.Distance == nil {
config.IndexingConfig.Distance = ptr.Of(IndexDistanceIP)
}
if config.IndexingConfig.Quant == nil {
config.IndexingConfig.Quant = ptr.Of(IndexQuantInt8)
}
if config.IndexingConfig.HnswM == nil {
config.IndexingConfig.HnswM = ptr.Of(int64(20))
}
if config.IndexingConfig.HnswCef == nil {
config.IndexingConfig.HnswCef = ptr.Of(int64(400))
}
if config.IndexingConfig.HnswSef == nil {
config.IndexingConfig.HnswSef = ptr.Of(int64(800))
}
if config.IndexingConfig.CpuQuota == 0 {
config.IndexingConfig.CpuQuota = 2
}
if config.IndexingConfig.ShardCount == 0 {
config.IndexingConfig.ShardCount = 1
}
return &manager{
config: config,
}, nil
}
type manager struct {
config *ManagerConfig
}
func (m *manager) Create(ctx context.Context, req *searchstore.CreateRequest) error {
if err := m.createCollection(ctx, req); err != nil {
return err
}
if err := m.createIndex(ctx, req); err != nil {
return err
}
return nil
}
func (m *manager) Drop(_ context.Context, req *searchstore.DropRequest) error {
if err := m.config.Service.DropIndex(req.CollectionName, vikingIndexName); err != nil {
if !strings.Contains(err.Error(), errIndexNotFound) {
return err
}
}
if err := m.config.Service.DropCollection(req.CollectionName); err != nil {
if !strings.Contains(err.Error(), errCollectionNotFound) {
return err
}
}
return nil
}
func (m *manager) GetType() searchstore.SearchStoreType {
return searchstore.TypeVectorStore
}
func (m *manager) GetSearchStore(_ context.Context, collectionName string) (searchstore.SearchStore, error) {
collection, err := m.config.Service.GetCollection(collectionName)
if err != nil {
return nil, err
}
return &vkSearchStore{manager: m, collection: collection}, nil
}
func (m *manager) createCollection(ctx context.Context, req *searchstore.CreateRequest) error {
svc := m.config.Service
collection, err := svc.GetCollection(req.CollectionName)
if err != nil {
if !strings.Contains(err.Error(), errCollectionNotFound) {
return err
}
} else if collection != nil {
return nil
}
fields, vopts, err := m.mapFields(req.Fields)
if err != nil {
return err
}
if vopts != nil {
_, err = svc.CreateCollection(req.CollectionName, fields, "", vopts)
} else {
_, err = svc.CreateCollection(req.CollectionName, fields, "")
}
if err != nil {
return err
}
logs.CtxInfof(ctx, "[vikingdb] Create collection success, collection=%s", req.CollectionName)
return nil
}
func (m *manager) createIndex(ctx context.Context, req *searchstore.CreateRequest) error {
svc := m.config.Service
index, err := svc.GetIndex(req.CollectionName, vikingIndexName)
if err != nil {
if !strings.Contains(err.Error(), errIndexNotFound) {
return err
}
} else if index != nil {
return nil
}
vectorIndex := &vikingdb.VectorIndexParams{
IndexType: string(m.config.IndexingConfig.Type),
Distance: string(ptr.From(m.config.IndexingConfig.Distance)),
Quant: string(ptr.From(m.config.IndexingConfig.Quant)),
HnswM: ptr.From(m.config.IndexingConfig.HnswM),
HnswCef: ptr.From(m.config.IndexingConfig.HnswCef),
HnswSef: ptr.From(m.config.IndexingConfig.HnswSef),
}
opts := vikingdb.NewIndexOptions().
SetVectorIndex(vectorIndex).
SetCpuQuota(m.config.IndexingConfig.CpuQuota).
SetShardCount(m.config.IndexingConfig.ShardCount)
_, err = svc.CreateIndex(req.CollectionName, vikingIndexName, opts)
if err != nil {
return err
}
logs.CtxInfof(ctx, "[vikingdb] Create index success, collection=%s, index=%s", req.CollectionName, vikingIndexName)
return nil
}
func (m *manager) mapFields(srcFields []*searchstore.Field) ([]vikingdb.Field, []*vikingdb.VectorizeTuple, error) {
var (
foundID bool
foundCreatorID bool
dstFields = make([]vikingdb.Field, 0, len(srcFields))
vectorizeOpts []*vikingdb.VectorizeTuple
embConfig = m.config.EmbeddingConfig
)
for _, srcField := range srcFields {
switch srcField.Name {
case searchstore.FieldID:
foundID = true
case searchstore.FieldCreatorID:
foundCreatorID = true
default:
}
if srcField.Indexing {
if srcField.Type != searchstore.FieldTypeText {
return nil, nil, fmt.Errorf("[mapFields] currently only support text field indexing, field=%s", srcField.Name)
}
if embConfig.UseVikingEmbedding {
vt := vikingdb.NewVectorizeTuple().SetDense(m.newVectorizeModelConf(srcField.Name))
if embConfig.EnableHybrid {
vt = vt.SetSparse(m.newVectorizeModelConf(srcField.Name))
}
vectorizeOpts = append(vectorizeOpts, vt)
} else {
dstFields = append(dstFields, vikingdb.Field{
FieldName: denseFieldName(srcField.Name),
FieldType: vikingdb.Vector,
DefaultVal: nil,
Dim: m.getDims(),
})
}
}
dstField := vikingdb.Field{
FieldName: srcField.Name,
IsPrimaryKey: srcField.IsPrimary,
}
switch srcField.Type {
case searchstore.FieldTypeInt64:
dstField.FieldType = vikingdb.Int64
case searchstore.FieldTypeText:
dstField.FieldType = vikingdb.Text
case searchstore.FieldTypeDenseVector:
dstField.FieldType = vikingdb.Vector
dstField.Dim = m.getDims()
case searchstore.FieldTypeSparseVector:
dstField.FieldType = vikingdb.Sparse_Vector
default:
return nil, nil, fmt.Errorf("unknown field type: %v", srcField.Type)
}
dstFields = append(dstFields, dstField)
}
if !foundID {
dstFields = append(dstFields, vikingdb.Field{
FieldName: searchstore.FieldID,
FieldType: vikingdb.Int64,
IsPrimaryKey: true,
})
}
if !foundCreatorID {
dstFields = append(dstFields, vikingdb.Field{
FieldName: searchstore.FieldCreatorID,
FieldType: vikingdb.Int64,
})
}
return dstFields, vectorizeOpts, nil
}
func (m *manager) newVectorizeModelConf(fieldName string) *vikingdb.VectorizeModelConf {
embConfig := m.config.EmbeddingConfig
vmc := vikingdb.NewVectorizeModelConf().
SetTextField(fieldName).
SetModelName(string(embConfig.ModelName)).
SetDim(m.getDims())
if embConfig.ModelVersion != nil {
vmc = vmc.SetModelVersion(ptr.From(embConfig.ModelVersion))
}
return vmc
}
func (m *manager) getDims() int64 {
if m.config.EmbeddingConfig.UseVikingEmbedding {
return m.config.EmbeddingConfig.ModelName.Dimensions()
}
return m.config.EmbeddingConfig.BuiltinEmbedding.Dimensions()
}

View File

@@ -0,0 +1,388 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vikingdb
import (
"context"
"encoding/json"
"fmt"
"reflect"
"strconv"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
type vkSearchStore struct {
*manager
collection *vikingdb.Collection
index *vikingdb.Index
}
func (v *vkSearchStore) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
if len(docs) == 0 {
return nil, nil
}
implSpecOptions := indexer.GetImplSpecificOptions(&searchstore.IndexerOptions{}, opts...)
defer func() {
if err != nil {
if implSpecOptions.ProgressBar != nil {
_ = implSpecOptions.ProgressBar.ReportError(err)
}
}
}()
docsWithoutVector, err := slices.TransformWithErrorCheck(docs, v.document2DataWithoutVector)
if err != nil {
return nil, fmt.Errorf("[Store] vikingdb failed to transform documents, %w", err)
}
indexingFields := sets.FromSlice(implSpecOptions.IndexingFields)
for _, part := range slices.Chunks(docsWithoutVector, 100) {
docsWithVector, err := v.addEmbedding(ctx, part, indexingFields)
if err != nil {
return nil, err
}
if err := v.collection.UpsertData(docsWithVector); err != nil {
return nil, err
}
}
ids = slices.Transform(docs, func(a *schema.Document) string { return a.ID })
return
}
func (v *vkSearchStore) Retrieve(ctx context.Context, query string, opts ...retriever.Option) (docs []*schema.Document, err error) {
indexClient := v.index
if indexClient == nil {
foundIndex := false
for _, index := range v.collection.Indexes {
if index.IndexName == vikingIndexName {
foundIndex = true
break
}
}
if !foundIndex {
return nil, fmt.Errorf("[Retrieve] vikingdb index not found, name=%s", vikingIndexName)
}
indexClient, err = v.config.Service.GetIndex(v.collection.CollectionName, vikingIndexName)
if err != nil {
return nil, fmt.Errorf("[Retrieve] vikingdb failed to get index, %w", err)
}
}
options := retriever.GetCommonOptions(&retriever.Options{TopK: ptr.Of(4)}, opts...)
implSpecOptions := retriever.GetImplSpecificOptions(&searchstore.RetrieverOptions{}, opts...)
searchOpts := vikingdb.NewSearchOptions().
SetLimit(int64(ptr.From(options.TopK))).
SetText(query).
SetRetry(true)
filter, err := v.genFilter(ctx, options, implSpecOptions)
if err != nil {
return nil, fmt.Errorf("[Retrieve] vikingdb failed to build filter, %w", err)
}
if filter != nil {
// 不支持跨 partition 召回,使用 filter 替代
searchOpts = searchOpts.SetFilter(filter)
}
var data []*vikingdb.Data
if v.config.EmbeddingConfig.UseVikingEmbedding {
data, err = indexClient.SearchWithMultiModal(searchOpts)
} else {
var dense [][]float64
dense, err = v.config.EmbeddingConfig.BuiltinEmbedding.EmbedStrings(ctx, []string{query})
if err != nil {
return nil, fmt.Errorf("[Retrieve] embed failed, %w", err)
}
if len(dense) != 1 {
return nil, fmt.Errorf("[Retrieve] unexpected dense vector size, expected=1, got=%d", len(dense))
}
data, err = indexClient.SearchByVector(dense[0], searchOpts)
}
if err != nil {
return nil, fmt.Errorf("[Retrieve] vikingdb search failed, %w", err)
}
docs, err = v.parseSearchResult(data)
if err != nil {
return nil, err
}
return
}
func (v *vkSearchStore) Delete(ctx context.Context, ids []string) error {
for _, part := range slices.Chunks(ids, 100) {
if err := v.collection.DeleteData(part); err != nil {
return err
}
}
return nil
}
func (v *vkSearchStore) document2DataWithoutVector(doc *schema.Document) (data vikingdb.Data, err error) {
creatorID, err := document.GetDocumentCreatorID(doc)
if err != nil {
return data, err
}
docID, err := strconv.ParseInt(doc.ID, 10, 64)
if err != nil {
return data, err
}
fields := map[string]interface{}{
searchstore.FieldID: docID,
searchstore.FieldCreatorID: creatorID,
searchstore.FieldTextContent: doc.Content,
}
if ext, err := document.GetDocumentExternalStorage(doc); err == nil { // try load
for key, val := range ext {
fields[key] = val
}
}
return vikingdb.Data{
Id: doc.ID,
Fields: fields,
}, nil
}
func (v *vkSearchStore) addEmbedding(ctx context.Context, rows []vikingdb.Data, indexingFields map[string]struct{}) ([]vikingdb.Data, error) {
if v.config.EmbeddingConfig.UseVikingEmbedding {
return rows, nil
}
emb := v.config.EmbeddingConfig.BuiltinEmbedding
for indexingField := range indexingFields {
values := make([]string, len(rows))
for i, row := range rows {
val, found := row.Fields[indexingField]
if !found {
return nil, fmt.Errorf("[addEmbedding] indexing field not found in document, field=%s", indexingField)
}
strVal, ok := val.(string)
if !ok {
return nil, fmt.Errorf("[addEmbedding] val not string, field=%s, val=%v", indexingField, val)
}
values[i] = strVal
}
dense, err := emb.EmbedStrings(ctx, values)
if err != nil {
return nil, fmt.Errorf("[addEmbedding] failed to embed, %w", err)
}
if len(dense) != len(values) {
return nil, fmt.Errorf("[addEmbedding] unexpected dense vector size, expected=%d, got=%d", len(values), len(dense))
}
df := denseFieldName(indexingField)
for i := range dense {
rows[i].Fields[df] = dense[i]
}
}
return rows, nil
}
func (v *vkSearchStore) parseSearchResult(result []*vikingdb.Data) ([]*schema.Document, error) {
docs := make([]*schema.Document, 0, len(result))
for _, data := range result {
ext := make(map[string]any)
doc := document.WithDocumentExternalStorage(&schema.Document{MetaData: map[string]any{}}, ext).
WithScore(data.Score)
for field, val := range data.Fields {
switch field {
case searchstore.FieldID:
jn, ok := val.(json.Number)
if !ok {
return nil, fmt.Errorf("[parseSearchResult] id type assertion failed, val=%v", val)
}
doc.ID = jn.String()
case searchstore.FieldCreatorID:
jn, ok := val.(json.Number)
if !ok {
return nil, fmt.Errorf("[parseSearchResult] creator_id type assertion failed, val=%v", val)
}
creatorID, err := jn.Int64()
if err != nil {
return nil, fmt.Errorf("[parseSearchResult] creator_id value not int64, val=%v", jn.String())
}
doc = document.WithDocumentCreatorID(doc, creatorID)
case searchstore.FieldTextContent:
text, ok := val.(string)
if !ok {
return nil, fmt.Errorf("[parseSearchResult] content value not string, val=%v", val)
}
doc.Content = text
default:
switch t := val.(type) {
case json.Number:
if i64, err := t.Int64(); err == nil {
ext[field] = i64
} else if f64, err := t.Float64(); err == nil {
ext[field] = f64
} else {
ext[field] = t.String()
}
default:
ext[field] = val
}
}
}
docs = append(docs, doc)
}
return docs, nil
}
func (v *vkSearchStore) genFilter(ctx context.Context, co *retriever.Options, ro *searchstore.RetrieverOptions) (map[string]any, error) {
filter, err := v.dsl2Filter(ctx, co.DSLInfo)
if err != nil {
return nil, err
}
if ro.PartitionKey != nil && len(ro.Partitions) > 0 {
var (
key = ptr.From(ro.PartitionKey)
fieldType = ""
conds any
)
for _, field := range v.collection.Fields {
if field.FieldName == key {
fieldType = field.FieldType
}
}
if fieldType == "" {
return nil, fmt.Errorf("[Retrieve] partition key not found, key=%s", key)
}
switch fieldType {
case vikingdb.Int64:
c := make([]int64, 0, len(ro.Partitions))
for _, item := range ro.Partitions {
i64, err := strconv.ParseInt(item, 10, 64)
if err != nil {
return nil, fmt.Errorf("[Retrieve] partition value parse error, key=%s, val=%v, err=%v", key, item, err)
}
c = append(c, i64)
}
conds = c
case vikingdb.String:
conds = ro.Partitions
default:
return nil, fmt.Errorf("[Retrieve] invalid field type for partition, key=%s, type=%s", key, fieldType)
}
op := map[string]any{"op": "must", "field": key, "conds": conds}
if filter != nil {
filter = op
} else {
filter = map[string]any{
"op": "and",
"conds": []map[string]any{op, filter},
}
}
}
return filter, nil
}
func (v *vkSearchStore) dsl2Filter(ctx context.Context, src map[string]any) (map[string]any, error) {
dsl, err := searchstore.LoadDSL(src)
if err != nil {
return nil, err
}
if dsl == nil {
return nil, nil
}
toSliceValue := func(val any) any {
if reflect.TypeOf(val).Kind() == reflect.Slice {
return val
}
return []any{val}
}
var filter map[string]any
switch dsl.Op {
case searchstore.OpEq, searchstore.OpIn:
filter = map[string]any{
"op": "must",
"field": dsl.Field,
"conds": toSliceValue(dsl.Value),
}
case searchstore.OpNe:
filter = map[string]any{
"op": "must_not",
"field": dsl.Field,
"conds": toSliceValue(dsl.Value),
}
case searchstore.OpLike:
logs.CtxWarnf(ctx, "[dsl2Filter] vikingdb invalid dsl type, skip, type=%s", dsl.Op)
case searchstore.OpAnd, searchstore.OpOr:
var conds []map[string]any
sub, ok := dsl.Value.([]map[string]any)
if !ok {
return nil, fmt.Errorf("[dsl2Filter] invalid value for and/or, should be []map[string]any")
}
for _, subDSL := range sub {
cond, err := v.dsl2Filter(ctx, subDSL)
if err != nil {
return nil, err
}
conds = append(conds, cond)
}
op := "and"
if dsl.Op == searchstore.OpOr {
op = "or"
}
filter = map[string]any{
"op": op,
"field": dsl.Field,
"conds": conds,
}
}
return filter, nil
}

Some files were not shown because too many files have changed in this diff Show More