feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
23
backend/infra/contract/cache/cache.go
vendored
Normal file
23
backend/infra/contract/cache/cache.go
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// Cmdable aliases redis.Cmdable so consumers depend on this contract
// package instead of importing go-redis directly.
type Cmdable = redis.Cmdable
|
||||
35
backend/infra/contract/chatmodel/chat_model.go
Normal file
35
backend/infra/contract/chatmodel/chat_model.go
Normal file
@@ -0,0 +1,35 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package chatmodel
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination ../../../internal/mock/infra/contract/chatmodel/base_model_mock.go -package mock -source ${GOPATH}/src/github.com/cloudwego/eino/components/model/interface.go BaseChatModel
|
||||
type BaseChatModel = model.BaseChatModel
|
||||
|
||||
//go:generate mockgen -destination ../../../internal/mock/infra/contract/chatmodel/toolcalling_model_mock.go -package mock -source ${GOPATH}/src/github.com/cloudwego/eino/components/model/interface.go ToolCallingChatModel
|
||||
type ToolCallingChatModel = model.ToolCallingChatModel
|
||||
|
||||
//go:generate mockgen -destination ../../../internal/mock/infra/contract/chatmodel/chat_model_factory_mock.go -package mock -source chat_model.go Factory
|
||||
// Factory creates chat models for a given wire protocol.
type Factory interface {
	// CreateChatModel builds a tool-calling chat model for the protocol
	// using the supplied configuration.
	CreateChatModel(ctx context.Context, protocol Protocol, config *Config) (ToolCallingChatModel, error)
	// SupportProtocol reports whether this factory can build models for
	// the given protocol.
	SupportProtocol(protocol Protocol) bool
}
|
||||
94
backend/infra/contract/chatmodel/config.go
Normal file
94
backend/infra/contract/chatmodel/config.go
Normal file
@@ -0,0 +1,94 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package chatmodel
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/deepseek"
|
||||
"github.com/cloudwego/eino-ext/libs/acl/openai"
|
||||
"google.golang.org/genai"
|
||||
)
|
||||
|
||||
// Config is the protocol-agnostic chat model configuration. Common
// sampling parameters live at the top level; provider-specific settings
// go in the per-provider sub-configs below. Pointer fields mean "unset,
// use the provider default".
type Config struct {
	// Endpoint/connection settings.
	BaseURL string        `json:"base_url,omitempty" yaml:"base_url,omitempty"`
	APIKey  string        `json:"api_key,omitempty" yaml:"api_key,omitempty"`
	Timeout time.Duration `json:"timeout,omitempty" yaml:"timeout,omitempty"`

	// Common generation parameters shared across providers.
	// NOTE(review): the yaml tags for TopP/TopK/Stop lack ",omitempty"
	// while their json tags have it, and OpenAI is json:"open_ai" but
	// yaml:"openai" — confirm whether these asymmetries are intentional.
	Model            string   `json:"model" yaml:"model"`
	Temperature      *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
	FrequencyPenalty *float32 `json:"frequency_penalty,omitempty" yaml:"frequency_penalty,omitempty"`
	PresencePenalty  *float32 `json:"presence_penalty,omitempty" yaml:"presence_penalty,omitempty"`
	MaxTokens        *int     `json:"max_tokens,omitempty" yaml:"max_tokens,omitempty"`
	TopP             *float32 `json:"top_p,omitempty" yaml:"top_p"`
	TopK             *int     `json:"top_k,omitempty" yaml:"top_k"`
	Stop             []string `json:"stop,omitempty" yaml:"stop"`
	EnableThinking   *bool    `json:"enable_thinking,omitempty" yaml:"enable_thinking,omitempty"`

	// Provider-specific configuration; only the section matching the
	// chosen Protocol is consulted.
	OpenAI   *OpenAIConfig   `json:"open_ai,omitempty" yaml:"openai"`
	Claude   *ClaudeConfig   `json:"claude,omitempty" yaml:"claude"`
	Ark      *ArkConfig      `json:"ark,omitempty" yaml:"ark"`
	Deepseek *DeepseekConfig `json:"deepseek,omitempty" yaml:"deepseek"`
	Qwen     *QwenConfig     `json:"qwen,omitempty" yaml:"qwen"`
	Gemini   *GeminiConfig   `json:"gemini,omitempty" yaml:"gemini"`

	// Custom carries free-form key/value settings for implementations
	// not covered by the typed sections above.
	Custom map[string]string `json:"custom,omitempty" yaml:"custom"`
}
|
||||
|
||||
// OpenAIConfig holds OpenAI-specific (and Azure OpenAI) settings.
type OpenAIConfig struct {
	// ByAzure switches to the Azure OpenAI endpoint style; APIVersion is
	// the Azure API version string.
	ByAzure    bool   `json:"by_azure,omitempty" yaml:"by_azure"`
	APIVersion string `json:"api_version,omitempty" yaml:"api_version"`

	// ResponseFormat forces a response format (e.g. JSON mode).
	ResponseFormat *openai.ChatCompletionResponseFormat `json:"response_format,omitempty" yaml:"response_format"`
}

// ClaudeConfig holds Anthropic Claude settings, optionally routed
// through AWS Bedrock.
type ClaudeConfig struct {
	ByBedrock bool `json:"by_bedrock" yaml:"by_bedrock"`
	// bedrock config — AWS credentials and region used when ByBedrock is set.
	AccessKey       string `json:"access_key,omitempty" yaml:"access_key"`
	SecretAccessKey string `json:"secret_access_key,omitempty" yaml:"secret_access_key"`
	SessionToken    string `json:"session_token,omitempty" yaml:"session_token"`
	Region          string `json:"region,omitempty" yaml:"region"`
}

// ArkConfig holds Volcengine Ark settings.
type ArkConfig struct {
	Region       string            `json:"region" yaml:"region"`
	AccessKey    string            `json:"access_key,omitempty" yaml:"access_key"`
	SecretKey    string            `json:"secret_key,omitempty" yaml:"secret_key"`
	RetryTimes   *int              `json:"retry_times,omitempty" yaml:"retry_times"`
	CustomHeader map[string]string `json:"custom_header,omitempty" yaml:"custom_header"`
}

// DeepseekConfig holds DeepSeek-specific settings.
type DeepseekConfig struct {
	ResponseFormatType deepseek.ResponseFormatType `json:"response_format_type" yaml:"response_format_type"`
}

// QwenConfig holds Qwen-specific settings (OpenAI-compatible response format).
type QwenConfig struct {
	ResponseFormat *openai.ChatCompletionResponseFormat `json:"response_format,omitempty" yaml:"response_format"`
}

// GeminiConfig holds Google Gemini settings.
type GeminiConfig struct {
	Backend    genai.Backend       `json:"backend,omitempty" yaml:"backend"`
	Project    string              `json:"project,omitempty" yaml:"project"`
	Location   string              `json:"location,omitempty" yaml:"location"`
	APIVersion string              `json:"api_version,omitempty" yaml:"api_version"`
	Headers    map[string][]string `json:"headers,omitempty" yaml:"headers"`
	TimeoutMs  int64               `json:"timeout_ms,omitempty" yaml:"timeout_ms"`

	// Thinking controls.
	IncludeThoughts *bool  `json:"include_thoughts,omitempty" yaml:"include_thoughts"` // default true
	ThinkingBudget  *int32 `json:"thinking_budget,omitempty" yaml:"thinking_budget"`   // default nil
}
|
||||
55
backend/infra/contract/chatmodel/protocol.go
Normal file
55
backend/infra/contract/chatmodel/protocol.go
Normal file
@@ -0,0 +1,55 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package chatmodel
|
||||
|
||||
import "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
|
||||
|
||||
// Protocol identifies the provider family / wire protocol used to talk
// to a chat model endpoint.
type Protocol string

// Supported chat model protocols.
const (
	ProtocolOpenAI   Protocol = "openai"
	ProtocolClaude   Protocol = "claude"
	ProtocolDeepseek Protocol = "deepseek"
	ProtocolGemini   Protocol = "gemini"
	ProtocolArk      Protocol = "ark"
	ProtocolOllama   Protocol = "ollama"
	ProtocolQwen     Protocol = "qwen"
	ProtocolErnie    Protocol = "ernie"
)
|
||||
|
||||
// TOModelClass maps a Protocol to the corresponding developer_api model
// class, falling back to ModelClass_Other for unrecognized protocols.
//
// NOTE(review): the method name carries a historical "TO" capitalization
// (ToModelClass would be conventional); renaming would break callers.
// ModelClass_DeekSeek is the spelling defined by the generated API package.
func (p Protocol) TOModelClass() developer_api.ModelClass {
	switch p {
	case ProtocolArk:
		return developer_api.ModelClass_SEED
	case ProtocolOpenAI:
		return developer_api.ModelClass_GPT
	case ProtocolDeepseek:
		return developer_api.ModelClass_DeekSeek
	case ProtocolClaude:
		return developer_api.ModelClass_Claude
	case ProtocolGemini:
		return developer_api.ModelClass_Gemini
	case ProtocolOllama:
		return developer_api.ModelClass_Llama
	case ProtocolQwen:
		return developer_api.ModelClass_QWen
	case ProtocolErnie:
		return developer_api.ModelClass_Ernie
	default:
		return developer_api.ModelClass_Other
	}
}
|
||||
126
backend/infra/contract/document/extra.go
Normal file
126
backend/infra/contract/document/extra.go
Normal file
@@ -0,0 +1,126 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package document
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
const (
|
||||
MetaDataKeyColumns = "table_columns" // val: []*Column
|
||||
MetaDataKeyColumnData = "table_column_data" // val: []*ColumnData
|
||||
MetaDataKeyColumnsOnly = "table_columns_only" // val: struct{}, which means table has no data, only header.
|
||||
|
||||
MetaDataKeyCreatorID = "creator_id" // val: int64
|
||||
MetaDataKeyExternalStorage = "external_storage" // val: map[string]any
|
||||
)
|
||||
|
||||
func GetDocumentColumns(doc *schema.Document) ([]*Column, error) {
|
||||
if doc == nil || doc.MetaData == nil {
|
||||
return nil, fmt.Errorf("invalid document")
|
||||
}
|
||||
|
||||
columns, ok := doc.MetaData[MetaDataKeyColumns].([]*Column)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid document columns")
|
||||
}
|
||||
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func WithDocumentColumns(doc *schema.Document, columns []*Column) *schema.Document {
|
||||
doc.MetaData[MetaDataKeyColumns] = columns
|
||||
return doc
|
||||
}
|
||||
|
||||
func GetDocumentColumnData(doc *schema.Document) ([]*ColumnData, error) {
|
||||
if doc == nil || doc.MetaData == nil {
|
||||
return nil, fmt.Errorf("invalid document")
|
||||
}
|
||||
|
||||
data, ok := doc.MetaData[MetaDataKeyColumnData].([]*ColumnData)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid document column data")
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func WithDocumentColumnData(doc *schema.Document, data []*ColumnData) *schema.Document {
|
||||
doc.MetaData[MetaDataKeyColumnData] = data
|
||||
return doc
|
||||
}
|
||||
|
||||
func WithDocumentColumnsOnly(doc *schema.Document) *schema.Document {
|
||||
doc.MetaData[MetaDataKeyColumnsOnly] = struct{}{}
|
||||
return doc
|
||||
}
|
||||
|
||||
func GetDocumentColumnsOnly(doc *schema.Document) (bool, error) {
|
||||
if doc == nil || doc.MetaData == nil {
|
||||
return false, fmt.Errorf("invalid document")
|
||||
}
|
||||
|
||||
_, ok := doc.MetaData[MetaDataKeyColumnsOnly].(struct{})
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func GetDocumentsColumnsOnly(docs []*schema.Document) (bool, error) {
|
||||
if len(docs) != 1 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return GetDocumentColumnsOnly(docs[0])
|
||||
}
|
||||
|
||||
func GetDocumentCreatorID(doc *schema.Document) (int64, error) {
|
||||
if doc == nil || doc.MetaData == nil {
|
||||
return 0, fmt.Errorf("invalid document")
|
||||
}
|
||||
|
||||
creatorID, ok := doc.MetaData[MetaDataKeyCreatorID].(int64)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("invalid document creator id")
|
||||
}
|
||||
|
||||
return creatorID, nil
|
||||
}
|
||||
|
||||
func WithDocumentCreatorID(doc *schema.Document, creatorID int64) *schema.Document {
|
||||
doc.MetaData[MetaDataKeyCreatorID] = creatorID
|
||||
return doc
|
||||
}
|
||||
|
||||
func GetDocumentExternalStorage(doc *schema.Document) (map[string]any, error) {
|
||||
if doc == nil || doc.MetaData == nil {
|
||||
return nil, fmt.Errorf("invalid document")
|
||||
}
|
||||
|
||||
data, ok := doc.MetaData[MetaDataKeyExternalStorage].(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid document external storage")
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func WithDocumentExternalStorage(doc *schema.Document, externalStorage map[string]any) *schema.Document {
|
||||
doc.MetaData[MetaDataKeyExternalStorage] = externalStorage
|
||||
return doc
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package imageunderstand
|
||||
|
||||
import "context"
|
||||
|
||||
// ImageUnderstand produces a textual description of an image.
type ImageUnderstand interface {
	// ImageUnderstand returns text content describing the raw image bytes.
	ImageUnderstand(ctx context.Context, image []byte) (content string, err error)
}
|
||||
29
backend/infra/contract/document/nl2sql/nl2sql.go
Normal file
29
backend/infra/contract/document/nl2sql/nl2sql.go
Normal file
@@ -0,0 +1,29 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package nl2sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
)
|
||||
|
||||
// NL2SQL converts a natural-language conversation into a SQL statement
// over the given table schemas.
type NL2SQL interface {
	// NL2SQL generates a SQL string from chat messages and table schemas;
	// behavior can be customized via Options.
	NL2SQL(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...Option) (sql string, err error)
}
|
||||
31
backend/infra/contract/document/nl2sql/options.go
Normal file
31
backend/infra/contract/document/nl2sql/options.go
Normal file
@@ -0,0 +1,31 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package nl2sql
|
||||
|
||||
import "github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
|
||||
// Option mutates Options; pass Options to NL2SQL calls to customize behavior.
type Option func(o *Options)

// Options carries optional per-call settings for NL2SQL implementations.
type Options struct {
	// ChatModel overrides the chat model used to generate SQL.
	ChatModel chatmodel.BaseChatModel
}

// WithChatModel selects the chat model used for SQL generation.
func WithChatModel(cm chatmodel.BaseChatModel) Option {
	return func(o *Options) {
		o.ChatModel = cm
	}
}
|
||||
24
backend/infra/contract/document/ocr/ocr.go
Normal file
24
backend/infra/contract/document/ocr/ocr.go
Normal file
@@ -0,0 +1,24 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package ocr
|
||||
|
||||
import "context"
|
||||
|
||||
// OCR extracts text fragments from images.
type OCR interface {
	// FromBase64 runs OCR on a base64-encoded image.
	FromBase64(ctx context.Context, b64 string) (texts []string, err error)
	// FromURL runs OCR on an image fetched from the given URL.
	FromURL(ctx context.Context, url string) (texts []string, err error)
}
|
||||
128
backend/infra/contract/document/parser/manager.go
Normal file
128
backend/infra/contract/document/parser/manager.go
Normal file
@@ -0,0 +1,128 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package parser
|
||||
|
||||
import (
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
|
||||
)
|
||||
|
||||
// Manager resolves a Parser implementation for a given parse configuration.
type Manager interface {
	GetParser(config *Config) (Parser, error)
}

// Config selects a parser by file extension and carries the parsing and
// chunking strategies to apply.
type Config struct {
	FileExtension    FileExtension
	ParsingStrategy  *ParsingStrategy
	ChunkingStrategy *ChunkingStrategy
}
|
||||
|
||||
// ParsingStrategy for document parse before indexing
type ParsingStrategy struct {
	// Doc
	ExtractImage bool  `json:"extract_image"` // extract image elements
	ExtractTable bool  `json:"extract_table"` // extract table elements
	ImageOCR     bool  `json:"image_ocr"`     // run OCR on images
	FilterPages  []int `json:"filter_pages"`  // page filter; first page = 1

	// Sheet
	SheetID             *int               `json:"sheet_id"`        // xlsx sheet id
	HeaderLine          int                `json:"header_line"`     // header row index
	DataStartLine       int                `json:"data_start_line"` // first data row index
	RowsCount           int                `json:"rows_count"`      // number of data rows to read
	IsAppend            bool               `json:"-"`               // row append mode
	Columns             []*document.Column `json:"-"`               // align sheet against these header columns
	IgnoreColumnTypeErr bool               `json:"-"`               // when true, ignore column-type/value mismatches; the value is left empty

	// Image
	ImageAnnotationType ImageAnnotationType `json:"image_annotation_type"` // how image content is annotated
}
|
||||
|
||||
// ChunkingStrategy controls how parsed content is split into chunks
// before indexing.
type ChunkingStrategy struct {
	ChunkType ChunkType `json:"chunk_type"`

	// custom config
	ChunkSize       int64  `json:"chunk_size"` // maximum chunk length
	Separator       string `json:"separator"`  // chunk separator marker
	Overlap         int64  `json:"overlap"`    // overlap ratio between adjacent chunks
	TrimSpace       bool   `json:"trim_space"`
	TrimURLAndEmail bool   `json:"trim_url_and_email"`

	// leveled config
	MaxDepth  int64 `json:"max_depth"`  // maximum depth when chunking by heading level
	SaveTitle bool  `json:"save_title"` // keep level titles in chunks
}
|
||||
|
||||
// ChunkType selects the chunking algorithm.
type ChunkType int64

const (
	ChunkTypeDefault ChunkType = 0 // automatic chunking
	ChunkTypeCustom  ChunkType = 1 // chunking by custom rules
	ChunkTypeLeveled ChunkType = 2 // hierarchical (heading-level) chunking
)

// ImageAnnotationType selects how image content is annotated.
type ImageAnnotationType int64

const (
	ImageAnnotationTypeModel  ImageAnnotationType = 0 // automatic annotation by a model
	ImageAnnotationTypeManual ImageAnnotationType = 1 // manual annotation
)
|
||||
|
||||
// FileExtension is a lowercase file suffix (without the dot) identifying
// a parseable file format.
type FileExtension string

const (
	// document
	FileExtensionPDF      FileExtension = "pdf"
	FileExtensionTXT      FileExtension = "txt"
	FileExtensionDoc      FileExtension = "doc"
	FileExtensionDocx     FileExtension = "docx"
	FileExtensionMarkdown FileExtension = "md"

	// sheet
	FileExtensionCSV      FileExtension = "csv"
	FileExtensionXLSX     FileExtension = "xlsx"
	FileExtensionJSON     FileExtension = "json"
	FileExtensionJsonMaps FileExtension = "json_maps" // json of []map[string]string

	// image
	FileExtensionJPG  FileExtension = "jpg"
	FileExtensionJPEG FileExtension = "jpeg"
	FileExtensionPNG  FileExtension = "png"
)
|
||||
|
||||
func ValidateFileExtension(fileSuffix string) (ext FileExtension, support bool) {
|
||||
fileExtension := FileExtension(fileSuffix)
|
||||
_, ok := fileExtensionSet[fileExtension]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return fileExtension, true
|
||||
}
|
||||
|
||||
var fileExtensionSet = sets.Set[FileExtension]{
|
||||
FileExtensionPDF: {},
|
||||
FileExtensionTXT: {},
|
||||
FileExtensionDoc: {},
|
||||
FileExtensionDocx: {},
|
||||
FileExtensionMarkdown: {},
|
||||
FileExtensionCSV: {},
|
||||
FileExtensionJSON: {},
|
||||
FileExtensionJsonMaps: {},
|
||||
FileExtensionJPG: {},
|
||||
FileExtensionJPEG: {},
|
||||
FileExtensionPNG: {},
|
||||
}
|
||||
21
backend/infra/contract/document/parser/parser.go
Normal file
21
backend/infra/contract/document/parser/parser.go
Normal file
@@ -0,0 +1,21 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package parser
|
||||
|
||||
import "github.com/cloudwego/eino/components/document/parser"
|
||||
|
||||
// Parser aliases eino's document parser interface so other packages can
// depend on this contract package rather than eino directly.
type Parser = parser.Parser
|
||||
26
backend/infra/contract/document/progressbar/interface.go
Normal file
26
backend/infra/contract/document/progressbar/interface.go
Normal file
@@ -0,0 +1,26 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package progressbar
|
||||
|
||||
import "context"
|
||||
|
||||
// ProgressBar is the interface for the progress bar.
type ProgressBar interface {
	// AddN advances the progress by n units.
	AddN(n int) error
	// ReportError records err so it can be surfaced by GetProgress.
	ReportError(err error) error
	// GetProgress reports the completion percentage, the estimated
	// remaining seconds, and any recorded error message.
	GetProgress(ctx context.Context) (percent int, remainSec int, errMsg string)
}
|
||||
43
backend/infra/contract/document/rerank/rerank.go
Normal file
43
backend/infra/contract/document/rerank/rerank.go
Normal file
@@ -0,0 +1,43 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package rerank
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// Reranker re-scores retrieved documents against a query.
type Reranker interface {
	Rerank(ctx context.Context, req *Request) (*Response, error)
}

// Request is a rerank request: candidate document groups scored against
// Query, optionally truncated to TopN results.
type Request struct {
	Query string
	Data  [][]*Data
	TopN  *int64
}

// Response carries the rerank result.
type Response struct {
	SortedData []*Data // highest score first
	TokenUsage *int64
}

// Data pairs a document with its rerank score.
type Data struct {
	Document *schema.Document
	Score    float64
}
|
||||
54
backend/infra/contract/document/searchstore/dsl.go
Normal file
54
backend/infra/contract/document/searchstore/dsl.go
Normal file
@@ -0,0 +1,54 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package searchstore
|
||||
|
||||
import "fmt"
|
||||
|
||||
// DSL is one node of a simple query expression tree: either a field
// comparison (Op applied to Field/Value) or a logical combinator whose
// Value holds child nodes.
type DSL struct {
	Op    Op
	Field string
	Value interface{} // builtin types / []*DSL
}

// Op enumerates the supported DSL operators.
type Op string

const (
	// field comparison operators
	OpEq   Op = "eq"
	OpNe   Op = "ne"
	OpLike Op = "like"
	OpIn   Op = "in"

	// logical combinators over child DSL nodes
	OpAnd Op = "and"
	OpOr  Op = "or"
)

// DSL wraps the node in the map envelope understood by LoadDSL.
func (q *DSL) DSL() map[string]any {
	return map[string]any{"dsl": q}
}

// LoadDSL extracts a *DSL previously stored via (*DSL).DSL. A nil map
// yields (nil, nil); a map without a valid "dsl" entry is an error.
func LoadDSL(src map[string]any) (*DSL, error) {
	if src == nil {
		return nil, nil
	}

	node, ok := src["dsl"].(*DSL)
	if !ok {
		return nil, fmt.Errorf("load dsl failed")
	}

	return node, nil
}
|
||||
82
backend/infra/contract/document/searchstore/manager.go
Normal file
82
backend/infra/contract/document/searchstore/manager.go
Normal file
@@ -0,0 +1,82 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package searchstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Manager administers search store collections for one backend type.
type Manager interface {
	// Create provisions a collection with the given fields.
	Create(ctx context.Context, req *CreateRequest) error

	// Drop removes a collection.
	Drop(ctx context.Context, req *DropRequest) error

	// GetType reports which backend this manager serves (vector or text).
	GetType() SearchStoreType

	// GetSearchStore opens the search store for an existing collection.
	GetSearchStore(ctx context.Context, collectionName string) (SearchStore, error)
}

// CreateRequest describes a collection to provision.
type CreateRequest struct {
	CollectionName string
	Fields         []*Field
	CollectionMeta map[string]string
}

// DropRequest names the collection to remove.
type DropRequest struct {
	CollectionName string
}

// GetSearchStoreRequest names the collection to open.
type GetSearchStoreRequest struct {
	CollectionName string
}

// Field describes one column of a collection schema.
type Field struct {
	Name        FieldName
	Type        FieldType
	Description string

	Nullable  bool
	IsPrimary bool

	// Indexing marks the field as searchable.
	Indexing bool
}

// SearchStoreType distinguishes backend families.
type SearchStoreType string

const (
	TypeVectorStore SearchStoreType = "vector"
	TypeTextStore   SearchStoreType = "text"
)

type FieldName = string

// built-in field names
const (
	FieldID          FieldName = "id"           // int64
	FieldCreatorID   FieldName = "creator_id"   // int64
	FieldTextContent FieldName = "text_content" // string
)

// FieldType enumerates supported field value types.
type FieldType int64

const (
	FieldTypeUnknown      FieldType = 0
	FieldTypeInt64        FieldType = 1
	FieldTypeText         FieldType = 2
	FieldTypeDenseVector  FieldType = 3
	FieldTypeSparseVector FieldType = 4
)
|
||||
87
backend/infra/contract/document/searchstore/options.go
Normal file
87
backend/infra/contract/document/searchstore/options.go
Normal file
@@ -0,0 +1,87 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package searchstore
|
||||
|
||||
import (
|
||||
"github.com/cloudwego/eino/components/indexer"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/progressbar"
|
||||
)
|
||||
|
||||
// IndexerOptions carries implementation-specific options applied when
// indexing documents into a search store.
type IndexerOptions struct {
	// PartitionKey is the name of the partition key field.
	PartitionKey *string
	// Partition is the storage partition the documents are written into.
	Partition *string
	// IndexingFields lists the fields that should be indexed.
	IndexingFields []string
	// ProgressBar, when set, receives progress updates for the indexing job.
	ProgressBar progressbar.ProgressBar
}

// RetrieverOptions carries implementation-specific options applied when
// retrieving documents from a search store.
type RetrieverOptions struct {
	// MultiMatch configures a query across multiple fields.
	MultiMatch *MultiMatch
	// PartitionKey is the name of the partition key field.
	PartitionKey *string
	// Partitions restricts the query to the given partitions.
	Partitions []string
}

// MultiMatch describes a query that matches one query string against
// multiple fields.
type MultiMatch struct {
	Fields []string
	Query  string
}
|
||||
|
||||
func WithIndexerPartitionKey(key string) indexer.Option {
|
||||
return indexer.WrapImplSpecificOptFn(func(o *IndexerOptions) {
|
||||
o.PartitionKey = &key
|
||||
})
|
||||
}
|
||||
|
||||
func WithPartition(partition string) indexer.Option {
|
||||
return indexer.WrapImplSpecificOptFn(func(o *IndexerOptions) {
|
||||
o.Partition = &partition
|
||||
})
|
||||
}
|
||||
|
||||
func WithIndexingFields(fields []string) indexer.Option {
|
||||
return indexer.WrapImplSpecificOptFn(func(o *IndexerOptions) {
|
||||
o.IndexingFields = fields
|
||||
})
|
||||
}
|
||||
|
||||
func WithProgressBar(progressBar progressbar.ProgressBar) indexer.Option {
|
||||
return indexer.WrapImplSpecificOptFn(func(o *IndexerOptions) {
|
||||
o.ProgressBar = progressBar
|
||||
})
|
||||
}
|
||||
|
||||
func WithMultiMatch(fields []string, query string) retriever.Option {
|
||||
return retriever.WrapImplSpecificOptFn(func(o *RetrieverOptions) {
|
||||
o.MultiMatch = &MultiMatch{
|
||||
Fields: fields,
|
||||
Query: query,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func WithRetrieverPartitionKey(key string) retriever.Option {
|
||||
return retriever.WrapImplSpecificOptFn(func(o *RetrieverOptions) {
|
||||
o.PartitionKey = &key
|
||||
})
|
||||
}
|
||||
|
||||
func WithPartitions(partitions []string) retriever.Option {
|
||||
return retriever.WrapImplSpecificOptFn(func(o *RetrieverOptions) {
|
||||
o.Partitions = partitions
|
||||
})
|
||||
}
|
||||
32
backend/infra/contract/document/searchstore/searchstore.go
Normal file
32
backend/infra/contract/document/searchstore/searchstore.go
Normal file
@@ -0,0 +1,32 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package searchstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/components/indexer"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
)
|
||||
|
||||
// SearchStore is a document store that supports both indexing and retrieval,
// plus deletion by document id.
type SearchStore interface {
	// Indexer stores documents into the search store.
	indexer.Indexer

	// Retriever queries documents from the search store.
	retriever.Retriever

	// Delete removes the documents with the given ids.
	Delete(ctx context.Context, ids []string) error
}
|
||||
155
backend/infra/contract/document/table.go
Normal file
155
backend/infra/contract/document/table.go
Normal file
@@ -0,0 +1,155 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package document
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
// TableSchema describes a table: its name, comment, and columns.
type TableSchema struct {
	Name    string
	Comment string
	Columns []*Column
}

// Column describes a single table column.
type Column struct {
	ID          int64
	Name        string
	Type        TableColumnType
	Description string
	Nullable    bool
	IsPrimary   bool
	// Sequence is the ordering number of the column within the table.
	Sequence int
}
|
||||
|
||||
// TableColumnType enumerates the value types a table column can hold.
type TableColumnType int

const (
	TableColumnTypeUnknown TableColumnType = 0
	TableColumnTypeString  TableColumnType = 1
	TableColumnTypeInteger TableColumnType = 2
	TableColumnTypeTime    TableColumnType = 3
	TableColumnTypeNumber  TableColumnType = 4
	TableColumnTypeBoolean TableColumnType = 5
	TableColumnTypeImage   TableColumnType = 6
)

// String returns the SQL-style type name for t ("varchar", "bigint", ...),
// or "unknown" for any unrecognized value.
func (t TableColumnType) String() string {
	names := map[TableColumnType]string{
		TableColumnTypeString:  "varchar",
		TableColumnTypeInteger: "bigint",
		TableColumnTypeTime:    "timestamp",
		TableColumnTypeNumber:  "double",
		TableColumnTypeBoolean: "boolean",
		TableColumnTypeImage:   "image",
	}
	if name, ok := names[t]; ok {
		return name
	}
	return "unknown"
}
|
||||
|
||||
// ColumnData is a single cell value: the column identity plus the Val*
// pointer matching Type carrying the value.
type ColumnData struct {
	ColumnID   int64
	ColumnName string
	Type       TableColumnType
	ValString  *string
	ValInteger *int64
	ValTime    *time.Time
	ValNumber  *float64
	ValBoolean *bool
	ValImage   *string // base64 / url
}
|
||||
|
||||
// GetValue returns the value pointer matching the column's declared Type,
// or nil for an unknown type.
//
// NOTE(review): the returned interface wraps a typed pointer, so it compares
// non-nil even when the underlying pointer is nil (e.g. d.ValString == nil).
// Callers must nil-check the concrete pointer, not the interface.
func (d *ColumnData) GetValue() interface{} {
	switch d.Type {
	case TableColumnTypeString:
		return d.ValString
	case TableColumnTypeInteger:
		return d.ValInteger
	case TableColumnTypeTime:
		return d.ValTime
	case TableColumnTypeNumber:
		return d.ValNumber
	case TableColumnTypeBoolean:
		return d.ValBoolean
	case TableColumnTypeImage:
		return d.ValImage
	default:
		return nil
	}
}
|
||||
|
||||
// GetStringValue renders the column value as a string according to its Type.
//
// NOTE(review): nil pointers are passed through ptr.From, which presumably
// yields the zero value — so a nil ValInteger renders as "0", a nil
// ValBoolean as "false", etc. (confirm ptr.From semantics). Use
// GetNullableStringValue to render nil values as "" instead.
func (d *ColumnData) GetStringValue() string {
	switch d.Type {
	case TableColumnTypeString:
		return ptr.From(d.ValString)
	case TableColumnTypeInteger:
		return strconv.FormatInt(ptr.From(d.ValInteger), 10)
	case TableColumnTypeTime:
		// "2006-01-02 15:04:05" layout.
		return ptr.From(d.ValTime).Format(time.DateTime)
	case TableColumnTypeNumber:
		// Fixed-point notation with 20 fractional digits.
		return strconv.FormatFloat(ptr.From(d.ValNumber), 'f', 20, 64)
	case TableColumnTypeBoolean:
		return strconv.FormatBool(ptr.From(d.ValBoolean))
	case TableColumnTypeImage:
		return ptr.From(d.ValImage)
	default:
		// Unknown types fall back to the string field.
		return ptr.From(d.ValString)
	}
}
|
||||
|
||||
// GetNullableStringValue renders the column value as a string, returning ""
// when the value pointer for the declared Type is nil. Unlike GetStringValue,
// nil integers/times/numbers/booleans/images render as "" rather than as a
// formatted zero value.
func (d *ColumnData) GetNullableStringValue() string {
	switch d.Type {
	case TableColumnTypeString:
		return ptr.From(d.ValString)
	case TableColumnTypeInteger:
		if d.ValInteger == nil {
			return ""
		}
		return strconv.FormatInt(ptr.From(d.ValInteger), 10)
	case TableColumnTypeTime:
		if d.ValTime == nil {
			return ""
		}
		// "2006-01-02 15:04:05" layout, matching GetStringValue.
		return ptr.From(d.ValTime).Format(time.DateTime)
	case TableColumnTypeNumber:
		if d.ValNumber == nil {
			return ""
		}
		// 20 fractional digits, matching GetStringValue's formatting.
		return strconv.FormatFloat(ptr.From(d.ValNumber), 'f', 20, 64)
	case TableColumnTypeBoolean:
		if d.ValBoolean == nil {
			return ""
		}
		return strconv.FormatBool(ptr.From(d.ValBoolean))
	case TableColumnTypeImage:
		if d.ValImage == nil {
			return ""
		}
		return ptr.From(d.ValImage)
	default:
		// Unknown types fall back to the string field.
		return ptr.From(d.ValString)
	}
}
|
||||
39
backend/infra/contract/dynconf/provider.go
Normal file
39
backend/infra/contract/dynconf/provider.go
Normal file
@@ -0,0 +1,39 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package dynconf
|
||||
|
||||
import "context"
|
||||
|
||||
// Provider abstracts a dynamic-configuration backend (e.g. zookeeper, etcd,
// nacos).
type Provider interface {
	// Initialize connects to the backend and returns a client scoped to the
	// given namespace and group.
	Initialize(ctx context.Context, namespace, group string, opts ...Option) (DynamicClient, error)
}

// DynamicClient reads configuration values and watches them for changes.
type DynamicClient interface {
	// AddListener registers callback to be invoked when key changes.
	AddListener(key string, callback func(value string, err error)) error
	// RemoveListener unregisters the listener previously added for key.
	RemoveListener(key string) error
	// Get returns the current value of key.
	Get(ctx context.Context, key string) (string, error)
}

// options holds resolved generic option values; currently empty.
type options struct{}

// Option configures Initialize. It carries either a generic apply function
// or an implementation-specific option value.
type Option struct {
	apply func(opts *options)

	implSpecificOptFn any
}
|
||||
37
backend/infra/contract/embedding/embedding.go
Normal file
37
backend/infra/contract/embedding/embedding.go
Normal file
@@ -0,0 +1,37 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package embedding
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/components/embedding"
|
||||
)
|
||||
|
||||
// Embedder extends eino's embedding.Embedder with hybrid embedding and model
// metadata.
type Embedder interface {
	embedding.Embedder
	// EmbedStringsHybrid performs hybrid embedding: it returns a dense vector
	// and a sparse map (presumably index -> weight; confirm with
	// implementations) for each input text.
	EmbedStringsHybrid(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, []map[int]float64, error)
	// Dimensions returns the dense vector dimension.
	Dimensions() int64
	// SupportStatus reports which embedding kinds this model supports.
	SupportStatus() SupportStatus
}

// SupportStatus describes which embedding kinds an Embedder provides.
type SupportStatus int

const (
	// SupportDense: dense vectors only.
	SupportDense SupportStatus = 1
	// SupportDenseAndSparse: both dense and sparse vectors.
	SupportDenseAndSparse SupportStatus = 3
)
|
||||
44
backend/infra/contract/es/es.go
Normal file
44
backend/infra/contract/es/es.go
Normal file
@@ -0,0 +1,44 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package es
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Client is a minimal abstraction over an Elasticsearch-like document store.
type Client interface {
	// Create indexes document under id in index.
	Create(ctx context.Context, index, id string, document any) error
	// Update replaces document id in index.
	Update(ctx context.Context, index, id string, document any) error
	// Delete removes document id from index.
	Delete(ctx context.Context, index, id string) error
	// Search executes req against index.
	Search(ctx context.Context, index string, req *Request) (*Response, error)
	// Exists reports whether index exists.
	Exists(ctx context.Context, index string) (bool, error)
	// CreateIndex creates index with the given field property mappings.
	CreateIndex(ctx context.Context, index string, properties map[string]any) error
	// DeleteIndex drops index.
	DeleteIndex(ctx context.Context, index string) error
	// Types exposes backend-specific mapping property constructors.
	Types() Types
	// NewBulkIndexer returns a bulk indexer bound to index.
	NewBulkIndexer(index string) (BulkIndexer, error)
}

// Types constructs backend-specific mapping property values used with
// CreateIndex.
type Types interface {
	NewLongNumberProperty() any
	NewTextProperty() any
	NewUnsignedLongNumberProperty() any
}

// BulkIndexer batches document operations. Close releases the indexer;
// implementations presumably flush pending items on Close — confirm.
type BulkIndexer interface {
	Add(ctx context.Context, item BulkIndexerItem) error
	Close(ctx context.Context) error
}
|
||||
73
backend/infra/contract/es/model.go
Normal file
73
backend/infra/contract/es/model.go
Normal file
@@ -0,0 +1,73 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package es
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
|
||||
"github.com/elastic/go-elasticsearch/v8/typedapi/types/enums/totalhitsrelation"
|
||||
)
|
||||
|
||||
// BulkIndexerItem is a single operation submitted to a BulkIndexer.
type BulkIndexerItem struct {
	Index           string
	Action          string
	DocumentID      string
	Routing         string
	Version         *int64
	VersionType     string
	Body            io.ReadSeeker
	RetryOnConflict *int
}

// Request is a search request.
type Request struct {
	Size        *int
	Query       *Query
	MinScore    *float64
	Sort        []SortFiled
	SearchAfter []any
	From        *int
}

// SortFiled is a sort clause: field name plus sort direction.
//
// NOTE(review): the name looks like a typo of "SortField"; kept as-is since
// renaming an exported type would break existing callers.
type SortFiled struct {
	Field string
	Asc   bool
}

// Response is a search response.
type Response struct {
	Hits     HitsMetadata `json:"hits"`
	MaxScore *float64     `json:"max_score,omitempty"`
}

// HitsMetadata holds the matched documents and scoring metadata.
type HitsMetadata struct {
	Hits     []Hit    `json:"hits"`
	MaxScore *float64 `json:"max_score,omitempty"`
	// Total Total hit count information, present only if `track_total_hits` wasn't
	// `false` in the search request.
	Total *TotalHits `json:"total,omitempty"`
}

// Hit is a single matched document.
type Hit struct {
	Id_     *string         `json:"_id,omitempty"`
	Score_  *float64        `json:"_score,omitempty"`
	Source_ json.RawMessage `json:"_source,omitempty"`
}

// TotalHits is the total hit count and whether it is exact or a lower bound.
type TotalHits struct {
	Relation totalhitsrelation.TotalHitsRelation `json:"relation"`
	Value    int64                               `json:"value"`
}
|
||||
111
backend/infra/contract/es/query.go
Normal file
111
backend/infra/contract/es/query.go
Normal file
@@ -0,0 +1,111 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package es
|
||||
|
||||
// Supported query types; see Query.Type.
const (
	QueryTypeEqual      = "equal"
	QueryTypeMatch      = "match"
	QueryTypeMultiMatch = "multi_match"
	QueryTypeNotExists  = "not_exists"
	QueryTypeContains   = "contains"
	QueryTypeIn         = "in"
)

// KV is the key/value operand of a query.
type KV struct {
	Key   string
	Value any
}

// QueryType identifies how a Query is interpreted.
type QueryType string

// Query is a single search condition, optionally combined via Bool.
type Query struct {
	KV              KV
	Type            QueryType
	MultiMatchQuery MultiMatchQuery
	Bool            *BoolQuery
}

// BoolQuery combines sub-queries with boolean semantics.
type BoolQuery struct {
	Filter             []Query
	Must               []Query
	MustNot            []Query
	Should             []Query
	MinimumShouldMatch *int
}

// MultiMatchQuery matches one query string against several fields.
type MultiMatchQuery struct {
	Fields   []string
	Type     string // best_fields
	Query    string
	Operator string
}

// Operators for MultiMatchQuery.Operator.
const (
	Or  = "or"
	And = "and"
)

// NewEqualQuery builds an exact-match query on field k.
func NewEqualQuery(k string, v any) Query {
	return Query{Type: QueryTypeEqual, KV: KV{Key: k, Value: v}}
}

// NewMatchQuery builds a full-text match query on field k.
func NewMatchQuery(k string, v any) Query {
	return Query{Type: QueryTypeMatch, KV: KV{Key: k, Value: v}}
}

// NewMultiMatchQuery builds a query matching `query` against several fields.
// typeStr is the multi-match strategy (e.g. "best_fields"); operator is Or/And.
func NewMultiMatchQuery(fields []string, query, typeStr, operator string) Query {
	mm := MultiMatchQuery{
		Fields:   fields,
		Type:     typeStr,
		Query:    query,
		Operator: operator,
	}
	return Query{Type: QueryTypeMultiMatch, MultiMatchQuery: mm}
}

// NewNotExistsQuery builds a query matching documents where field k is absent.
func NewNotExistsQuery(k string) Query {
	return Query{Type: QueryTypeNotExists, KV: KV{Key: k}}
}

// NewContainsQuery builds a containment query on field k.
func NewContainsQuery(k string, v any) Query {
	return Query{Type: QueryTypeContains, KV: KV{Key: k, Value: v}}
}

// NewInQuery builds a membership query: field k must equal one of v.
func NewInQuery[T any](k string, v []T) Query {
	vals := make([]any, len(v))
	for i, item := range v {
		vals[i] = item
	}
	return Query{Type: QueryTypeIn, KV: KV{Key: k, Value: vals}}
}
|
||||
30
backend/infra/contract/eventbus/consume_option.go
Normal file
30
backend/infra/contract/eventbus/consume_option.go
Normal file
@@ -0,0 +1,30 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package eventbus
|
||||
|
||||
// ConsumerOpt mutates a ConsumerOption.
type ConsumerOpt func(option *ConsumerOption)

// ConsumerOption holds optional settings for an event-bus consumer.
type ConsumerOption struct {
	// Orderly, when set, requests in-order message consumption.
	Orderly *bool
	// ConsumeFromWhere
}

// WithConsumerOrderly sets whether the consumer processes messages in order.
func WithConsumerOrderly(orderly bool) ConsumerOpt {
	return func(opt *ConsumerOption) {
		opt.Orderly = &orderly
	}
}
|
||||
36
backend/infra/contract/eventbus/eventbus.go
Normal file
36
backend/infra/contract/eventbus/eventbus.go
Normal file
@@ -0,0 +1,36 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package eventbus
|
||||
|
||||
import "context"
|
||||
|
||||
//go:generate mockgen -destination ../../../internal/mock/infra/contract/eventbus/eventbus_mock.go -package mock -source eventbus.go Factory

// Producer publishes messages to the event bus.
type Producer interface {
	// Send publishes a single message body.
	Send(ctx context.Context, body []byte, opts ...SendOpt) error
	// BatchSend publishes multiple message bodies in one call.
	BatchSend(ctx context.Context, bodyArr [][]byte, opts ...SendOpt) error
}

// Consumer is a marker interface for event-bus consumers.
type Consumer interface{}

// ConsumerHandler processes messages delivered to a consumer.
type ConsumerHandler interface {
	HandleMessage(ctx context.Context, msg *Message) error
}

// Message is a single event-bus message.
type Message struct {
	Topic string
	Group string
	Body  []byte
}
|
||||
29
backend/infra/contract/eventbus/send_option.go
Normal file
29
backend/infra/contract/eventbus/send_option.go
Normal file
@@ -0,0 +1,29 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package eventbus
|
||||
|
||||
// SendOpt mutates a SendOption.
type SendOpt func(option *SendOption)

// SendOption holds optional settings for Producer.Send / BatchSend.
type SendOption struct {
	// ShardingKey is an optional key the backend may use to partition
	// messages.
	ShardingKey *string
}

// WithShardingKey sets the sharding key attached to the message.
func WithShardingKey(key string) SendOpt {
	return func(opt *SendOption) {
		opt.ShardingKey = &key
	}
}
|
||||
27
backend/infra/contract/idgen/idgen.go
Normal file
27
backend/infra/contract/idgen/idgen.go
Normal file
@@ -0,0 +1,27 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package idgen
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination ../../../internal/mock/infra/contract/idgen/idgen_mock.go --package mock -source idgen.go

// IDGenerator produces unique int64 identifiers.
type IDGenerator interface {
	// GenID returns a single new id.
	GenID(ctx context.Context) (int64, error)
	// GenMultiIDs returns counts new ids in one call. suggest batch size <= 200
	GenMultiIDs(ctx context.Context, counts int) ([]int64, error)
}
|
||||
50
backend/infra/contract/imagex/get_resource_opt.go
Normal file
50
backend/infra/contract/imagex/get_resource_opt.go
Normal file
@@ -0,0 +1,50 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package imagex
|
||||
|
||||
// GetResourceOpt mutates a GetResourceOption.
type GetResourceOpt func(option *GetResourceOption)

// GetResourceOption holds optional parameters for building a resource URL.
type GetResourceOption struct {
	Format   string
	Template string
	Proto    string
	Expire   int
}

// WithResourceFormat sets the output format of the resource.
func WithResourceFormat(format string) GetResourceOpt {
	return func(opt *GetResourceOption) {
		opt.Format = format
	}
}

// WithResourceTemplate sets the processing template of the resource.
func WithResourceTemplate(template string) GetResourceOpt {
	return func(opt *GetResourceOption) {
		opt.Template = template
	}
}

// WithResourceProto sets the URL protocol/scheme of the resource.
func WithResourceProto(proto string) GetResourceOpt {
	return func(opt *GetResourceOption) {
		opt.Proto = proto
	}
}

// WithResourceExpire sets the expiry of the resource URL.
func WithResourceExpire(expire int) GetResourceOpt {
	return func(opt *GetResourceOption) {
		opt.Expire = expire
	}
}
|
||||
71
backend/infra/contract/imagex/imagex.go
Normal file
71
backend/infra/contract/imagex/imagex.go
Normal file
@@ -0,0 +1,71 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package imagex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination ../../../internal/mock/infra/contract/imagex/imagex_mock.go --package imagex -source imagex.go

// ImageX is an image upload and access service abstraction.
type ImageX interface {
	// GetUploadAuth returns a temporary credential for client-side uploads.
	GetUploadAuth(ctx context.Context, opt ...UploadAuthOpt) (*SecurityToken, error)
	// GetUploadAuthWithExpire is GetUploadAuth with an explicit credential TTL.
	GetUploadAuthWithExpire(ctx context.Context, expire time.Duration, opt ...UploadAuthOpt) (*SecurityToken, error)
	// GetResourceURL builds an access URL for the resource identified by uri.
	GetResourceURL(ctx context.Context, uri string, opts ...GetResourceOpt) (*ResourceURL, error)
	// Upload uploads data directly and returns the upload result.
	Upload(ctx context.Context, data []byte, opts ...UploadAuthOpt) (*UploadResult, error)
	// GetServerID returns the service identifier.
	GetServerID() string
	// GetUploadHost returns the host that uploads should be sent to.
	GetUploadHost(ctx context.Context) string
}

// SecurityToken is a temporary upload credential.
type SecurityToken struct {
	AccessKeyID     string `thrift:"access_key_id,1" frugal:"1,default,string" json:"access_key_id"`
	SecretAccessKey string `thrift:"secret_access_key,2" frugal:"2,default,string" json:"secret_access_key"`
	SessionToken    string `thrift:"session_token,3" frugal:"3,default,string" json:"session_token"`
	ExpiredTime     string `thrift:"expired_time,4" frugal:"4,default,string" json:"expired_time"`
	CurrentTime     string `thrift:"current_time,5" frugal:"5,default,string" json:"current_time"`
	HostScheme      string `thrift:"host_scheme,6" frugal:"6,default,string" json:"host_scheme"`
}

// ResourceURL holds the access addresses of the resulting image.
type ResourceURL struct {
	// REQUIRED; compact access URL of the result image — the default URL
	// without the bucket segment.
	CompactURL string `json:"CompactURL"`
	// REQUIRED; default access URL of the result image.
	URL string `json:"URL"`
}

// UploadResult is the response of an Upload call.
type UploadResult struct {
	Result    *Result   `json:"Results"`
	RequestId string    `json:"RequestId"`
	FileInfo  *FileInfo `json:"PluginResult"`
}

// Result carries the stored resource uri and its upload status code.
type Result struct {
	Uri       string `json:"Uri"`
	UriStatus int    `json:"UriStatus"` // 2000 means the upload succeeded
}

// FileInfo describes the uploaded media file.
type FileInfo struct {
	Name        string `json:"FileName"`
	Uri         string `json:"ImageUri"`
	ImageWidth  int    `json:"ImageWidth"`
	ImageHeight int    `json:"ImageHeight"`
	Md5         string `json:"ImageMd5"`
	ImageFormat string `json:"ImageFormat"`
	ImageSize   int    `json:"ImageSize"`
	FrameCnt    int    `json:"FrameCnt"`
	Duration    int    `json:"Duration"`
}
|
||||
72
backend/infra/contract/imagex/upload_auth_opt.go
Normal file
72
backend/infra/contract/imagex/upload_auth_opt.go
Normal file
@@ -0,0 +1,72 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package imagex
|
||||
|
||||
// UploadAuthOpt mutates an UploadAuthOption.
type UploadAuthOpt func(option *UploadAuthOption)

// UploadAuthOption holds optional constraints attached to an upload
// credential or direct upload.
type UploadAuthOption struct {
	ContentTypeBlackList []string
	ContentTypeWhiteList []string
	FileSizeUpLimit      *string
	FileSizeBottomLimit  *string
	KeyPtn               *string
	UploadOverWrite      *bool
	conditions           map[string]string
	StoreKey             *string
}

// WithStoreKey sets the exact storage key for the uploaded object.
func WithStoreKey(key string) UploadAuthOpt {
	return func(opt *UploadAuthOption) {
		opt.StoreKey = &key
	}
}

// WithUploadKeyPtn sets the pattern that uploaded keys must match.
func WithUploadKeyPtn(ptn string) UploadAuthOpt {
	return func(opt *UploadAuthOption) {
		opt.KeyPtn = &ptn
	}
}

// WithUploadOverwrite sets whether an existing object may be overwritten.
func WithUploadOverwrite(overwrite bool) UploadAuthOpt {
	return func(opt *UploadAuthOption) {
		opt.UploadOverWrite = &overwrite
	}
}

// WithUploadContentTypeBlackList sets the rejected content types.
func WithUploadContentTypeBlackList(blackList []string) UploadAuthOpt {
	return func(opt *UploadAuthOption) {
		opt.ContentTypeBlackList = blackList
	}
}

// WithUploadContentTypeWhiteList sets the accepted content types.
func WithUploadContentTypeWhiteList(whiteList []string) UploadAuthOpt {
	return func(opt *UploadAuthOption) {
		opt.ContentTypeWhiteList = whiteList
	}
}

// WithUploadFileSizeUpLimit sets the maximum allowed file size.
func WithUploadFileSizeUpLimit(limit string) UploadAuthOpt {
	return func(opt *UploadAuthOption) {
		opt.FileSizeUpLimit = &limit
	}
}

// WithUploadFileSizeBottomLimit sets the minimum allowed file size.
func WithUploadFileSizeBottomLimit(limit string) UploadAuthOpt {
	return func(opt *UploadAuthOption) {
		opt.FileSizeBottomLimit = &limit
	}
}
|
||||
27
backend/infra/contract/messages2query/messages_to_query.go
Normal file
27
backend/infra/contract/messages2query/messages_to_query.go
Normal file
@@ -0,0 +1,27 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package messages2query
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// MessagesToQuery rewrites a chat-message history into a single standalone
// query string.
type MessagesToQuery interface {
	// MessagesToQuery returns the rewritten query derived from messages.
	MessagesToQuery(ctx context.Context, messages []*schema.Message, opts ...Option) (newQuery string, err error)
}
|
||||
31
backend/infra/contract/messages2query/options.go
Normal file
31
backend/infra/contract/messages2query/options.go
Normal file
@@ -0,0 +1,31 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package messages2query
|
||||
|
||||
import "github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
|
||||
// Option mutates the Options of a MessagesToQuery call.
type Option func(o *Options)

// Options carries optional per-call settings for MessagesToQuery.
type Options struct {
	// ChatModel is an optional chat model supplied by the caller via
	// WithChatModel; nil when not provided.
	ChatModel chatmodel.BaseChatModel
}
|
||||
|
||||
func WithChatModel(cm chatmodel.BaseChatModel) Option {
|
||||
return func(o *Options) {
|
||||
o.ChatModel = cm
|
||||
}
|
||||
}
|
||||
23
backend/infra/contract/orm/database.go
Normal file
23
backend/infra/contract/orm/database.go
Normal file
@@ -0,0 +1,23 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type DB = gorm.DB
|
||||
91
backend/infra/contract/rdb/entity/const.go
Normal file
91
backend/infra/contract/rdb/entity/const.go
Normal file
@@ -0,0 +1,91 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package entity
|
||||
|
||||
// DataType enumerates the column data types supported by the RDB contract.
// Values are the literal SQL type keywords.
type DataType string

const (
	TypeInt       DataType = "INT"
	TypeVarchar   DataType = "VARCHAR"
	TypeText      DataType = "TEXT"
	TypeBoolean   DataType = "BOOLEAN"
	TypeJson      DataType = "JSON"
	TypeTimestamp DataType = "TIMESTAMP"
	TypeFloat     DataType = "FLOAT"
	TypeBigInt    DataType = "BIGINT"
	TypeDouble    DataType = "DOUBLE"
)

// IndexType enumerates the kinds of table indexes. Values are the literal
// SQL index keywords.
type IndexType string

const (
	PrimaryKey IndexType = "PRIMARY KEY"
	UniqueKey  IndexType = "UNIQUE KEY"
	NormalKey  IndexType = "KEY"
)
|
||||
|
||||
// AlterTableAction defines the kind of ALTER TABLE operation to perform.
type AlterTableAction string

const (
	AddColumn    AlterTableAction = "ADD COLUMN"
	DropColumn   AlterTableAction = "DROP COLUMN"
	ModifyColumn AlterTableAction = "MODIFY COLUMN"
	RenameColumn AlterTableAction = "RENAME COLUMN"
	AddIndex     AlterTableAction = "ADD INDEX"
)

// LogicalOperator joins multiple conditions in a WHERE clause.
type LogicalOperator string

const (
	AND LogicalOperator = "AND"
	OR  LogicalOperator = "OR"
)
|
||||
|
||||
// Operator enumerates the comparison operators usable in a query condition.
type Operator string

const (
	OperatorEqual        Operator = "="
	OperatorNotEqual     Operator = "!="
	OperatorGreater      Operator = ">"
	OperatorGreaterEqual Operator = ">="
	OperatorLess         Operator = "<"
	OperatorLessEqual    Operator = "<="

	OperatorLike    Operator = "LIKE"
	OperatorNotLike Operator = "NOT LIKE"

	OperatorIn    Operator = "IN"
	OperatorNotIn Operator = "NOT IN"

	// Null checks take no right-hand value.
	OperatorIsNull    Operator = "IS NULL"
	OperatorIsNotNull Operator = "IS NOT NULL"
)

// SortDirection specifies the ordering of a sorted column.
type SortDirection string

const (
	SortDirectionAsc  SortDirection = "ASC"  // ascending
	SortDirectionDesc SortDirection = "DESC" // descending
)

// SQLType indicates how the SQL text in a request should be interpreted.
type SQLType int32

const (
	SQLType_Parameterized SQLType = 0 // SQL with placeholders; bind values supplied separately
	SQLType_Raw           SQLType = 1 // Complete/raw SQL
)
|
||||
54
backend/infra/contract/rdb/entity/rdb.go
Normal file
54
backend/infra/contract/rdb/entity/rdb.go
Normal file
@@ -0,0 +1,54 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package entity
|
||||
|
||||
// Column describes a single column of a table.
type Column struct {
	Name          string // must be unique within the table
	DataType      DataType
	Length        *int // optional type length — presumably for VARCHAR-like types; confirm with implementations
	NotNull       bool
	DefaultValue  *string
	AutoIncrement bool // whether this column's value auto-increments
	Comment       *string
}

// Index describes a table index over one or more columns.
type Index struct {
	Name    string
	Type    IndexType
	Columns []string // indexed column names
}

// TableOption carries optional table-level settings.
type TableOption struct {
	Collate       *string
	AutoIncrement *int64 // initial auto-increment value for the table
	Comment       *string
}

// Table is the full definition of a table.
type Table struct {
	Name      string // must be unique
	Columns   []*Column
	Indexes   []*Index
	Options   *TableOption
	CreatedAt int64
	UpdatedAt int64
}

// ResultSet holds the outcome of a query or mutation.
type ResultSet struct {
	Columns      []string                 // column names of the returned rows
	Rows         []map[string]interface{} // each row maps column name to value
	AffectedRows int64
}
|
||||
189
backend/infra/contract/rdb/rdb.go
Normal file
189
backend/infra/contract/rdb/rdb.go
Normal file
@@ -0,0 +1,189 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package rdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb/entity"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination ../../../internal/mock/infra/contract/rdb/rdb_mock.go --package rdb -source rdb.go

// RDB is the contract for a relational-database service: schema management
// (DDL), row-level data manipulation (DML), and raw SQL execution.
type RDB interface {
	// Schema (DDL) operations.
	CreateTable(ctx context.Context, req *CreateTableRequest) (*CreateTableResponse, error)
	AlterTable(ctx context.Context, req *AlterTableRequest) (*AlterTableResponse, error)
	DropTable(ctx context.Context, req *DropTableRequest) (*DropTableResponse, error)
	GetTable(ctx context.Context, req *GetTableRequest) (*GetTableResponse, error)

	// Row (DML) operations.
	InsertData(ctx context.Context, req *InsertDataRequest) (*InsertDataResponse, error)
	UpdateData(ctx context.Context, req *UpdateDataRequest) (*UpdateDataResponse, error)
	DeleteData(ctx context.Context, req *DeleteDataRequest) (*DeleteDataResponse, error)
	SelectData(ctx context.Context, req *SelectDataRequest) (*SelectDataResponse, error)
	UpsertData(ctx context.Context, req *UpsertDataRequest) (*UpsertDataResponse, error)

	// ExecuteSQL runs an arbitrary SQL statement (parameterized or raw).
	ExecuteSQL(ctx context.Context, req *ExecuteSQLRequest) (*ExecuteSQLResponse, error)
}
|
||||
|
||||
// CreateTableRequest asks to create the given table.
type CreateTableRequest struct {
	Table *entity.Table
}

// CreateTableResponse returns the table as created.
type CreateTableResponse struct {
	Table *entity.Table
}

// AlterTableOperation is one step of an ALTER TABLE request. Which fields
// are consulted depends on Action (e.g. Column for ADD/MODIFY COLUMN,
// OldName for RENAME COLUMN, Index for ADD INDEX) — confirm with the
// implementation.
type AlterTableOperation struct {
	Action       entity.AlterTableAction
	Column       *entity.Column
	OldName      *string
	Index        *entity.Index
	IndexName    *string
	NewTableName *string
}

// AlterTableRequest applies one or more operations to an existing table.
type AlterTableRequest struct {
	TableName  string
	Operations []*AlterTableOperation
}

// AlterTableResponse returns the table definition after the change.
type AlterTableResponse struct {
	Table *entity.Table
}
|
||||
|
||||
// DropTableRequest asks to drop a table.
type DropTableRequest struct {
	TableName string
	IfExists  bool // when true, request DROP TABLE IF EXISTS semantics
}

// DropTableResponse reports whether the drop succeeded.
type DropTableResponse struct {
	Success bool
}

// GetTableRequest asks for the definition of a table.
type GetTableRequest struct {
	TableName string
}

// GetTableResponse returns the requested table definition.
type GetTableResponse struct {
	Table *entity.Table
}
|
||||
|
||||
// InsertDataRequest inserts one or more rows into a table.
type InsertDataRequest struct {
	TableName string
	Data      []map[string]interface{} // one map per row: column name -> value
}

// InsertDataResponse reports how many rows were affected by the insert.
type InsertDataResponse struct {
	AffectedRows int64
}

// Condition is a single comparison in a WHERE clause: Field Operator Value.
type Condition struct {
	Field    string
	Operator entity.Operator
	Value    interface{}
}

// ComplexCondition is a (possibly nested) boolean combination of conditions.
type ComplexCondition struct {
	Conditions []*Condition
	// NestedConditions is mutually exclusive with Conditions.
	// Example: WHERE (age >= 18 AND status = 'active') OR (age >= 21 AND status = 'pending')
	NestedConditions []*ComplexCondition
	Operator         entity.LogicalOperator // how the listed members are joined
}
|
||||
|
||||
// UpdateDataRequest updates rows matching Where with the values in Data.
type UpdateDataRequest struct {
	TableName string
	Data      map[string]interface{} // column name -> new value
	Where     *ComplexCondition
	Limit     *int // optional cap on the number of rows to update
}

// UpdateDataResponse reports how many rows were updated.
type UpdateDataResponse struct {
	AffectedRows int64
}

// DeleteDataRequest deletes rows matching Where.
type DeleteDataRequest struct {
	TableName string
	Where     *ComplexCondition
	Limit     *int // optional cap on the number of rows to delete
}

// DeleteDataResponse reports how many rows were deleted.
type DeleteDataResponse struct {
	AffectedRows int64
}

// OrderBy names a sort field and its direction.
type OrderBy struct {
	Field     string               // field to sort by
	Direction entity.SortDirection // sort direction
}

// SelectDataRequest queries rows from a table.
type SelectDataRequest struct {
	TableName string
	Fields    []string // fields to return; empty means all fields
	Where     *ComplexCondition
	OrderBy   []*OrderBy // sort order
	Limit     *int       // maximum number of rows to return
	Offset    *int       // number of rows to skip
}

// SelectDataResponse returns the matched rows.
type SelectDataResponse struct {
	ResultSet *entity.ResultSet
	Total     int64 // total matching rows, ignoring pagination (Limit/Offset)
}
|
||||
|
||||
// UpsertDataRequest inserts rows, updating any that already exist.
type UpsertDataRequest struct {
	TableName string
	Data      []map[string]interface{} // rows to insert or update
	Keys      []string                 // columns identifying a unique row; defaults to the primary key when empty
}

// UpsertDataResponse breaks down the effect of an upsert.
type UpsertDataResponse struct {
	AffectedRows  int64 // total rows affected
	InsertedRows  int64 // rows newly inserted
	UpdatedRows   int64 // rows updated
	UnchangedRows int64 // rows left untouched (neither inserted nor updated)
}

// ExecuteSQLRequest runs an arbitrary SQL statement.
type ExecuteSQLRequest struct {
	SQL    string
	Params []interface{} // bind values for a parameterized query

	// SQLType indicates the type of SQL: parameterized or raw SQL. It takes effect if OperateType is 0.
	SQLType entity.SQLType
}

// ExecuteSQLResponse returns the statement's result set.
type ExecuteSQLResponse struct {
	ResultSet *entity.ResultSet
}
|
||||
67
backend/infra/contract/sqlparser/sql_parser.go
Normal file
67
backend/infra/contract/sqlparser/sql_parser.go
Normal file
@@ -0,0 +1,67 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sqlparser
|
||||
|
||||
// TableColumn represents table and column name mapping
type TableColumn struct {
	NewTableName *string           // if nil, not replace table name
	ColumnMap    map[string]string // Column name mapping: key is original column name, value is new column name
}

// ColumnValue pairs a column name with a value for that column.
type ColumnValue struct {
	ColName string
	Value   interface{}
}

// PrimaryKeyValue carries a key column name and a set of values for it —
// presumably one value per inserted row; confirm against
// AddColumnsToInsertSQL implementations.
type PrimaryKeyValue struct {
	ColName string
	Values  []interface{}
}

// OperationType represents the type of SQL operation
type OperationType string

// SQL operation types
const (
	OperationTypeSelect   OperationType = "SELECT"
	OperationTypeInsert   OperationType = "INSERT"
	OperationTypeUpdate   OperationType = "UPDATE"
	OperationTypeDelete   OperationType = "DELETE"
	OperationTypeCreate   OperationType = "CREATE"
	OperationTypeAlter    OperationType = "ALTER"
	OperationTypeDrop     OperationType = "DROP"
	OperationTypeTruncate OperationType = "TRUNCATE"
	OperationTypeUnknown  OperationType = "UNKNOWN" // statement type not recognized
)

// SQLParser defines the interface for parsing and modifying SQL statements
type SQLParser interface {
	// ParseAndModifySQL parses SQL and replaces table/column names according to the provided message
	ParseAndModifySQL(sql string, tableColumns map[string]TableColumn) (string, error) // tableColumns Original table name -> new TableInfo

	// GetSQLOperation identifies the operation type in the SQL statement
	GetSQLOperation(sql string) (OperationType, error)

	// AddColumnsToInsertSQL adds columns to the INSERT SQL statement.
	AddColumnsToInsertSQL(origSQL string, addCols []ColumnValue, colVals *PrimaryKeyValue, isParam bool) (string, map[string]bool, error)

	// GetTableName extracts the table name from a SQL statement. Only supports single-table select/insert/update/delete. If it has multiple tables, return first table name.
	GetTableName(sql string) (string, error)

	// GetInsertDataNums extracts the number of rows to be inserted from a SQL statement. Only supports single-table insert.
	GetInsertDataNums(sql string) (int, error)
}
|
||||
27
backend/infra/contract/sse/sse.go
Normal file
27
backend/infra/contract/sse/sse.go
Normal file
@@ -0,0 +1,27 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sse
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/hertz-contrib/sse"
|
||||
)
|
||||
|
||||
// SSender abstracts sending a single server-sent event over a Hertz
// sse.Stream.
type SSender interface {
	// Send writes event to the stream s.
	Send(ctx context.Context, s *sse.Stream, event *sse.Event) error
}
|
||||
73
backend/infra/contract/storage/option.go
Normal file
73
backend/infra/contract/storage/option.go
Normal file
@@ -0,0 +1,73 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package storage
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// GetOptFn mutates a GetOption.
type GetOptFn func(option *GetOption)

// GetOption holds optional settings for object-retrieval calls.
type GetOption struct {
	Expire int64 // seconds
}

// WithExpire returns a GetOptFn that sets the expiry, in seconds.
func WithExpire(expire int64) GetOptFn {
	return func(opt *GetOption) {
		opt.Expire = expire
	}
}
|
||||
|
||||
// PutOption holds optional object metadata for upload calls. Nil fields are
// left to the storage backend's defaults.
type PutOption struct {
	ContentType        *string
	ContentEncoding    *string
	ContentDisposition *string
	ContentLanguage    *string
	Expires            *time.Time
}

// PutOptFn mutates a PutOption.
type PutOptFn func(option *PutOption)

// WithContentType sets the object's content type.
func WithContentType(v string) PutOptFn {
	return func(opt *PutOption) { opt.ContentType = &v }
}

// WithContentEncoding sets the object's content encoding.
func WithContentEncoding(v string) PutOptFn {
	return func(opt *PutOption) { opt.ContentEncoding = &v }
}

// WithContentDisposition sets the object's content disposition.
func WithContentDisposition(v string) PutOptFn {
	return func(opt *PutOption) { opt.ContentDisposition = &v }
}

// WithContentLanguage sets the object's content language.
func WithContentLanguage(v string) PutOptFn {
	return func(opt *PutOption) { opt.ContentLanguage = &v }
}

// WithExpires sets the object's expiry time.
func WithExpires(v time.Time) PutOptFn {
	return func(opt *PutOption) { opt.Expires = &v }
}
|
||||
35
backend/infra/contract/storage/storage.go
Normal file
35
backend/infra/contract/storage/storage.go
Normal file
@@ -0,0 +1,35 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package storage
|
||||
|
||||
import "context"
|
||||
|
||||
//go:generate mockgen -destination ../../../internal/mock/infra/contract/storage/storage_mock.go -package mock -source storage.go Factory

// Storage is the contract for a blob/object store keyed by object key.
type Storage interface {
	// PutObject stores content under objectKey, applying optional metadata
	// (content type, encoding, expiry, ...) from opts.
	PutObject(ctx context.Context, objectKey string, content []byte, opts ...PutOptFn) error
	// GetObject returns the raw bytes stored under objectKey.
	GetObject(ctx context.Context, objectKey string) ([]byte, error)
	// DeleteObject removes the object stored under objectKey.
	DeleteObject(ctx context.Context, objectKey string) error
	// GetObjectUrl returns a URL for objectKey; WithExpire controls its
	// validity in seconds.
	GetObjectUrl(ctx context.Context, objectKey string, opts ...GetOptFn) (string, error)
}

// SecurityToken bundles credential fields (access key, secret, session
// token) with their validity window; tags cover thrift/frugal/json wire
// formats.
type SecurityToken struct {
	AccessKeyID     string `thrift:"access_key_id,1" frugal:"1,default,string" json:"access_key_id"`
	SecretAccessKey string `thrift:"secret_access_key,2" frugal:"2,default,string" json:"secret_access_key"`
	SessionToken    string `thrift:"session_token,3" frugal:"3,default,string" json:"session_token"`
	ExpiredTime     string `thrift:"expired_time,4" frugal:"4,default,string" json:"expired_time"`
	CurrentTime     string `thrift:"current_time,5" frugal:"5,default,string" json:"current_time"`
}
|
||||
46
backend/infra/impl/cache/redis/redis.go
vendored
Normal file
46
backend/infra/impl/cache/redis/redis.go
vendored
Normal file
@@ -0,0 +1,46 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// Client aliases the go-redis client type so callers can avoid importing
// go-redis directly.
type Client = redis.Client
|
||||
func New() *redis.Client {
|
||||
addr := os.Getenv("REDIS_ADDR")
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: addr, // Redis地址
|
||||
DB: 0, // 默认数据库
|
||||
// 连接池配置
|
||||
PoolSize: 100, // 最大连接数(建议设置为CPU核心数*10)
|
||||
MinIdleConns: 10, // 最小空闲连接
|
||||
MaxIdleConns: 30, // 最大空闲连接
|
||||
ConnMaxIdleTime: 5 * time.Minute, // 空闲连接超时时间
|
||||
|
||||
// 超时配置
|
||||
DialTimeout: 5 * time.Second, // 连接建立超时
|
||||
ReadTimeout: 3 * time.Second, // 读操作超时
|
||||
WriteTimeout: 3 * time.Second, // 写操作超时
|
||||
})
|
||||
|
||||
return rdb
|
||||
}
|
||||
273
backend/infra/impl/chatmodel/default_factory.go
Normal file
273
backend/infra/impl/chatmodel/default_factory.go
Normal file
@@ -0,0 +1,273 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package chatmodel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
"github.com/cloudwego/eino-ext/components/model/claude"
|
||||
"github.com/cloudwego/eino-ext/components/model/deepseek"
|
||||
"github.com/cloudwego/eino-ext/components/model/gemini"
|
||||
"github.com/cloudwego/eino-ext/components/model/ollama"
|
||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino-ext/components/model/qwen"
|
||||
"github.com/ollama/ollama/api"
|
||||
"google.golang.org/genai"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
// Builder constructs a ToolCallingChatModel for one protocol from the
// generic Config.
type Builder func(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error)
|
||||
// NewDefaultFactory returns a Factory with only the built-in protocol
// builders.
func NewDefaultFactory() chatmodel.Factory {
	return NewFactory(nil)
}
|
||||
|
||||
func NewFactory(customFactory map[chatmodel.Protocol]Builder) chatmodel.Factory {
|
||||
protocol2Builder := map[chatmodel.Protocol]Builder{
|
||||
chatmodel.ProtocolOpenAI: openAIBuilder,
|
||||
chatmodel.ProtocolClaude: claudeBuilder,
|
||||
chatmodel.ProtocolDeepseek: deepseekBuilder,
|
||||
chatmodel.ProtocolArk: arkBuilder,
|
||||
chatmodel.ProtocolGemini: geminiBuilder,
|
||||
chatmodel.ProtocolOllama: ollamaBuilder,
|
||||
chatmodel.ProtocolQwen: qwenBuilder,
|
||||
chatmodel.ProtocolErnie: nil,
|
||||
}
|
||||
|
||||
for p := range customFactory {
|
||||
protocol2Builder[p] = customFactory[p]
|
||||
}
|
||||
|
||||
return &defaultFactory{protocol2Builder: protocol2Builder}
|
||||
}
|
||||
|
||||
// defaultFactory dispatches chat-model construction to per-protocol
// builders.
type defaultFactory struct {
	protocol2Builder map[chatmodel.Protocol]Builder
}
|
||||
|
||||
func (f *defaultFactory) CreateChatModel(ctx context.Context, protocol chatmodel.Protocol, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("[CreateChatModel] config not provided")
|
||||
}
|
||||
|
||||
builder, found := f.protocol2Builder[protocol]
|
||||
if !found {
|
||||
return nil, fmt.Errorf("[CreateChatModel] protocol not support, protocol=%s", protocol)
|
||||
}
|
||||
|
||||
return builder(ctx, config)
|
||||
}
|
||||
|
||||
func (f *defaultFactory) SupportProtocol(protocol chatmodel.Protocol) bool {
|
||||
_, found := f.protocol2Builder[protocol]
|
||||
return found
|
||||
}
|
||||
|
||||
func openAIBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
|
||||
cfg := &openai.ChatModelConfig{
|
||||
APIKey: config.APIKey,
|
||||
Timeout: config.Timeout,
|
||||
BaseURL: config.BaseURL,
|
||||
Model: config.Model,
|
||||
MaxTokens: config.MaxTokens,
|
||||
Temperature: config.Temperature,
|
||||
TopP: config.TopP,
|
||||
Stop: config.Stop,
|
||||
PresencePenalty: config.PresencePenalty,
|
||||
FrequencyPenalty: config.FrequencyPenalty,
|
||||
}
|
||||
if config.OpenAI != nil {
|
||||
cfg.ByAzure = config.OpenAI.ByAzure
|
||||
cfg.APIVersion = config.OpenAI.APIVersion
|
||||
cfg.ResponseFormat = config.OpenAI.ResponseFormat
|
||||
}
|
||||
return openai.NewChatModel(ctx, cfg)
|
||||
}
|
||||
|
||||
func claudeBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
|
||||
cfg := &claude.Config{
|
||||
APIKey: config.APIKey,
|
||||
Model: config.Model,
|
||||
Temperature: config.Temperature,
|
||||
TopP: config.TopP,
|
||||
StopSequences: config.Stop,
|
||||
}
|
||||
if config.BaseURL != "" {
|
||||
cfg.BaseURL = &config.BaseURL
|
||||
}
|
||||
if config.MaxTokens != nil {
|
||||
cfg.MaxTokens = *config.MaxTokens
|
||||
}
|
||||
if config.TopK != nil {
|
||||
cfg.TopK = ptr.Of(int32(*config.TopK))
|
||||
}
|
||||
if config.Claude != nil {
|
||||
cfg.ByBedrock = config.Claude.ByBedrock
|
||||
cfg.AccessKey = config.Claude.AccessKey
|
||||
cfg.SecretAccessKey = config.Claude.SecretAccessKey
|
||||
cfg.SessionToken = config.Claude.SessionToken
|
||||
cfg.Region = config.Claude.Region
|
||||
}
|
||||
return claude.NewChatModel(ctx, cfg)
|
||||
}
|
||||
|
||||
func deepseekBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
|
||||
cfg := &deepseek.ChatModelConfig{
|
||||
APIKey: config.APIKey,
|
||||
Timeout: config.Timeout,
|
||||
BaseURL: config.BaseURL,
|
||||
Model: config.Model,
|
||||
Stop: config.Stop,
|
||||
}
|
||||
if config.Temperature != nil {
|
||||
cfg.Temperature = *config.Temperature
|
||||
}
|
||||
if config.FrequencyPenalty != nil {
|
||||
cfg.FrequencyPenalty = *config.FrequencyPenalty
|
||||
}
|
||||
if config.PresencePenalty != nil {
|
||||
cfg.PresencePenalty = *config.PresencePenalty
|
||||
}
|
||||
if config.MaxTokens != nil {
|
||||
cfg.MaxTokens = *config.MaxTokens
|
||||
}
|
||||
if config.TopP != nil {
|
||||
cfg.TopP = *config.TopP
|
||||
}
|
||||
if config.Deepseek != nil {
|
||||
cfg.ResponseFormatType = config.Deepseek.ResponseFormatType
|
||||
}
|
||||
return deepseek.NewChatModel(ctx, cfg)
|
||||
}
|
||||
|
||||
func arkBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
|
||||
cfg := &ark.ChatModelConfig{
|
||||
BaseURL: config.BaseURL,
|
||||
APIKey: config.APIKey,
|
||||
Model: config.Model,
|
||||
MaxTokens: config.MaxTokens,
|
||||
Temperature: config.Temperature,
|
||||
TopP: config.TopP,
|
||||
Stop: config.Stop,
|
||||
FrequencyPenalty: config.FrequencyPenalty,
|
||||
PresencePenalty: config.PresencePenalty,
|
||||
}
|
||||
if config.Timeout != 0 {
|
||||
cfg.Timeout = &config.Timeout
|
||||
}
|
||||
if config.Ark != nil {
|
||||
cfg.Region = config.Ark.Region
|
||||
cfg.AccessKey = config.Ark.AccessKey
|
||||
cfg.SecretKey = config.Ark.SecretKey
|
||||
cfg.RetryTimes = config.Ark.RetryTimes
|
||||
cfg.CustomHeader = config.Ark.CustomHeader
|
||||
}
|
||||
return ark.NewChatModel(ctx, cfg)
|
||||
}
|
||||
|
||||
func ollamaBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
|
||||
cfg := &ollama.ChatModelConfig{
|
||||
BaseURL: config.BaseURL,
|
||||
Timeout: config.Timeout,
|
||||
HTTPClient: nil,
|
||||
Model: config.Model,
|
||||
Format: nil,
|
||||
KeepAlive: nil,
|
||||
Options: &api.Options{
|
||||
TopK: ptr.From(config.TopK),
|
||||
TopP: ptr.From(config.TopP),
|
||||
Temperature: ptr.From(config.Temperature),
|
||||
PresencePenalty: ptr.From(config.PresencePenalty),
|
||||
FrequencyPenalty: ptr.From(config.FrequencyPenalty),
|
||||
Stop: config.Stop,
|
||||
},
|
||||
}
|
||||
return ollama.NewChatModel(ctx, cfg)
|
||||
}
|
||||
|
||||
func qwenBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
|
||||
cfg := &qwen.ChatModelConfig{
|
||||
APIKey: config.APIKey,
|
||||
Timeout: config.Timeout,
|
||||
BaseURL: config.BaseURL,
|
||||
Model: config.Model,
|
||||
MaxTokens: config.MaxTokens,
|
||||
Temperature: config.Temperature,
|
||||
TopP: config.TopP,
|
||||
Stop: config.Stop,
|
||||
PresencePenalty: config.PresencePenalty,
|
||||
FrequencyPenalty: config.FrequencyPenalty,
|
||||
EnableThinking: config.EnableThinking,
|
||||
}
|
||||
if config.Qwen != nil {
|
||||
cfg.ResponseFormat = config.Qwen.ResponseFormat
|
||||
}
|
||||
return qwen.NewChatModel(ctx, cfg)
|
||||
}
|
||||
|
||||
func geminiBuilder(ctx context.Context, config *chatmodel.Config) (chatmodel.ToolCallingChatModel, error) {
|
||||
gc := &genai.ClientConfig{
|
||||
APIKey: config.APIKey,
|
||||
HTTPOptions: genai.HTTPOptions{
|
||||
BaseURL: config.BaseURL,
|
||||
},
|
||||
}
|
||||
if config.Gemini != nil {
|
||||
gc.Backend = config.Gemini.Backend
|
||||
gc.Project = config.Gemini.Project
|
||||
gc.Location = config.Gemini.Location
|
||||
gc.HTTPOptions.APIVersion = config.Gemini.APIVersion
|
||||
gc.HTTPOptions.Headers = config.Gemini.Headers
|
||||
}
|
||||
|
||||
client, err := genai.NewClient(ctx, gc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := &gemini.Config{
|
||||
Client: client,
|
||||
Model: config.Model,
|
||||
MaxTokens: config.MaxTokens,
|
||||
Temperature: config.Temperature,
|
||||
TopP: config.TopP,
|
||||
ThinkingConfig: &genai.ThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
ThinkingBudget: nil,
|
||||
},
|
||||
}
|
||||
if config.TopK != nil {
|
||||
cfg.TopK = ptr.Of(int32(ptr.From(config.TopK)))
|
||||
}
|
||||
if config.Gemini != nil && config.Gemini.IncludeThoughts != nil {
|
||||
cfg.ThinkingConfig.IncludeThoughts = ptr.From(config.Gemini.IncludeThoughts)
|
||||
}
|
||||
if config.Gemini != nil && config.Gemini.ThinkingBudget != nil {
|
||||
cfg.ThinkingConfig.ThinkingBudget = config.Gemini.ThinkingBudget
|
||||
}
|
||||
|
||||
cm, err := gemini.NewChatModel(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cm, nil
|
||||
}
|
||||
38
backend/infra/impl/chatmodel/singleton.go
Normal file
38
backend/infra/impl/chatmodel/singleton.go
Normal file
@@ -0,0 +1,38 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package chatmodel
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
)
|
||||
|
||||
var (
|
||||
once sync.Once
|
||||
singletonFactory chatmodel.Factory
|
||||
)
|
||||
|
||||
func InitSingletonFactory(factory chatmodel.Factory) {
|
||||
once.Do(func() {
|
||||
singletonFactory = factory
|
||||
})
|
||||
}
|
||||
|
||||
func GetSingletonFactory() chatmodel.Factory {
|
||||
return singletonFactory
|
||||
}
|
||||
49
backend/infra/impl/checkpoint/mem.go
Normal file
49
backend/infra/impl/checkpoint/mem.go
Normal file
@@ -0,0 +1,49 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package checkpoint
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
type inMemoryStore struct {
|
||||
m map[string][]byte
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func (i *inMemoryStore) Get(_ context.Context, checkPointID string) ([]byte, bool, error) {
|
||||
i.mu.RLock()
|
||||
v, ok := i.m[checkPointID]
|
||||
i.mu.RUnlock()
|
||||
return v, ok, nil
|
||||
}
|
||||
|
||||
func (i *inMemoryStore) Set(_ context.Context, checkPointID string, checkPoint []byte) error {
|
||||
i.mu.Lock()
|
||||
i.m[checkPointID] = checkPoint
|
||||
i.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewInMemoryStore() compose.CheckPointStore {
|
||||
return &inMemoryStore{
|
||||
m: make(map[string][]byte),
|
||||
}
|
||||
}
|
||||
55
backend/infra/impl/checkpoint/redis.go
Normal file
55
backend/infra/impl/checkpoint/redis.go
Normal file
@@ -0,0 +1,55 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package checkpoint
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
type redisStore struct {
|
||||
client *redis.Client
|
||||
}
|
||||
|
||||
const (
|
||||
checkpointKeyTpl = "checkpoint_key:%s"
|
||||
checkpointExpire = 24 * 7 * 3600 * time.Second
|
||||
)
|
||||
|
||||
func (r *redisStore) Get(ctx context.Context, checkPointID string) ([]byte, bool, error) {
|
||||
v, err := r.client.Get(ctx, fmt.Sprintf(checkpointKeyTpl, checkPointID)).Bytes()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, false, nil
|
||||
}
|
||||
return nil, false, err
|
||||
}
|
||||
return v, true, nil
|
||||
}
|
||||
|
||||
func (r *redisStore) Set(ctx context.Context, checkPointID string, checkPoint []byte) error {
|
||||
return r.client.Set(ctx, fmt.Sprintf(checkpointKeyTpl, checkPointID), checkPoint, checkpointExpire).Err()
|
||||
}
|
||||
|
||||
func NewRedisStore(client *redis.Client) compose.CheckPointStore {
|
||||
return &redisStore{client: client}
|
||||
}
|
||||
97
backend/infra/impl/coderunner/runner.go
Normal file
97
backend/infra/impl/coderunner/runner.go
Normal file
@@ -0,0 +1,97 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package coderunner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/goutil"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
// pythonCode is the wrapper template for user-supplied Python snippets.
// The single %s verb is replaced by the user code, which must define an
// async `main(args)`. The wrapper decodes params from argv[1], runs main
// via asyncio, and prints the JSON-encoded result on stdout; any exception
// is reported on stderr with exit code 1.
var pythonCode = `
import asyncio
import json
import sys

class Args:
    def __init__(self, params):
        self.params = params

class Output(dict):
    pass

%s

try:
    result = asyncio.run(main( Args(json.loads(sys.argv[1]))))
    print(json.dumps(result))
except Exception as e:
    print(f"{type(e).__name__}: {str(e)}", file=sys.stderr)
    sys.exit(1)

`
||||
|
||||
// Runner executes user code snippets via local interpreters. It is
// stateless and safe to share.
type Runner struct{}

// NewRunner constructs a Runner.
func NewRunner() *Runner {
	r := &Runner{}
	return r
}
||||
|
||||
func (r *Runner) Run(ctx context.Context, request *code.RunRequest) (*code.RunResponse, error) {
|
||||
var (
|
||||
params = request.Params
|
||||
c = request.Code
|
||||
)
|
||||
if request.Language == code.Python {
|
||||
ret, err := r.pythonCmdRun(ctx, c, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &code.RunResponse{
|
||||
Result: ret,
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported language: %s", request.Language)
|
||||
}
|
||||
|
||||
func (r *Runner) pythonCmdRun(_ context.Context, code string, params map[string]any) (map[string]any, error) {
|
||||
bs, _ := sonic.Marshal(params)
|
||||
cmd := exec.Command(goutil.GetPython3Path(), "-c", fmt.Sprintf(pythonCode, code), string(bs)) //ignore_security_alert RCE
|
||||
stdout := new(bytes.Buffer)
|
||||
stderr := new(bytes.Buffer)
|
||||
cmd.Stdout = stdout
|
||||
cmd.Stderr = stderr
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to run python script err: %s, std err: %s", err.Error(), stderr.String())
|
||||
}
|
||||
|
||||
if stderr.String() != "" {
|
||||
return nil, fmt.Errorf("failed to run python script err: %s", stderr.String())
|
||||
}
|
||||
ret := make(map[string]any)
|
||||
err = sonic.Unmarshal(stdout.Bytes(), &ret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
62
backend/infra/impl/coderunner/script/python_script.py
Normal file
62
backend/infra/impl/coderunner/script/python_script.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# Sandboxed Python runner: executes user code under RestrictedPython with a
# curated builtin set. Invoked as: python python_script.py <code> <params-json>.
import json
import sys
import os
import asyncio
import time
import random

# RestrictedPython supplies the restricted builtin tables; fail fast with an
# actionable message when the dependency is missing.
try:
    from RestrictedPython import safe_builtins, limited_builtins, utility_builtins
except ModuleNotFoundError:
    print("RestrictedPython module required, please run pip install RestrictedPython",file=sys.stderr)
    sys.exit(1)

# Start from RestrictedPython's safe builtins and selectively re-expose a few
# modules user code commonly needs.
custom_builtins = safe_builtins.copy()

# NOTE(review): exposing the real __import__ lets sandboxed code import any
# installed module, which largely defeats the RestrictedPython sandbox —
# confirm this is intended.
custom_builtins['__import__'] = __import__
custom_builtins['asyncio'] = asyncio
custom_builtins['json'] = json
custom_builtins['time'] = time
custom_builtins['random'] = random

# Global namespace handed to exec(): the restricted builtins plus a few
# whitelisted types and helpers.
restricted_globals = {
    '__builtins__': custom_builtins,
    '_utility_builtins': utility_builtins,
    '_limited_builtins': limited_builtins,
    '__name__': '__main__',
    'dict': dict,
    'list': list,
    'print': print,
    'set': set,

}
|
||||
|
||||
class Args:
    # Thin wrapper giving user code attribute-style access to its params.
    def __init__(self, params):
        self.params = params


# DefaultCode is prepended to the user snippet so the Args/Output names it
# may reference are always defined inside the exec'd namespace.
DefaultCode = """
class Args:
    def __init__(self, params):
        self.params = params
class Output(dict):
    pass
"""


async def run_main(app_code, params):
    # Execute the combined code under the restricted globals, then await the
    # user-defined async `main(args)` and return its result. Any exception is
    # reported on stderr and terminates the process with exit code 1.
    try:
        complete_code = DefaultCode + app_code
        locals_dict = {"args": Args(params=params)}
        exec(complete_code, restricted_globals, locals_dict)  # ignore_security_alert
        main_func = locals_dict['main']
        ret = await main_func(locals_dict['args'])
    except Exception as e:
        print(f"{type(e).__name__}: {str(e)}", file=sys.stderr)
        sys.exit(1)
    return ret


# Entry point: argv[1] is the user code, argv[2] the JSON-encoded params.
code = sys.argv[1]
result = asyncio.run(run_main(code, params=json.loads(sys.argv[2])))
print(json.dumps(result))
|
||||
118
backend/infra/impl/document/nl2sql/builtin/nl2sql.go
Normal file
118
backend/infra/impl/document/nl2sql/builtin/nl2sql.go
Normal file
@@ -0,0 +1,118 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/components/prompt"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/nl2sql"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
const (
	// defaultTableFmt renders one table header: name, description, and the
	// markdown column-table heading row.
	defaultTableFmt = "table name: %s.\ntable describe: %s.\n\n| field name | description | field type | is required |\n"
	// defaultColumnFmt renders one column row of that markdown table.
	defaultColumnFmt = "| %s | %s | %s | %t |\n\n"
)

// NewNL2SQL builds an NL2SQL implementation that prompts cm through tpl to
// translate natural-language requests plus table metadata into SQL.
func NewNL2SQL(_ context.Context, cm chatmodel.BaseChatModel, tpl prompt.ChatTemplate) (nl2sql.NL2SQL, error) {
	return &n2s{cm: cm, tpl: tpl}, nil
}

// n2s implements nl2sql.NL2SQL via an eino chain.
type n2s struct {
	// NOTE(review): ch and runnable are never assigned or read in this file;
	// the chain is rebuilt and recompiled on every NL2SQL call. Confirm
	// whether these fields were meant to cache the compiled chain.
	ch       *compose.Chain[*nl2sqlInput, string]
	runnable compose.Runnable[*nl2sqlInput, string]

	cm  chatmodel.BaseChatModel // default chat model, overridable per call via options
	tpl prompt.ChatTemplate     // prompt template fed with messages + table schema
}
|
||||
|
||||
// NL2SQL renders the table schemas and chat history into the prompt
// template, invokes the chat model, and parses the model's JSON reply into a
// SQL string. Options may override the chat model per call; if neither the
// option nor the struct default is set, an error is returned.
//
// NOTE(review): the chain is constructed and compiled on every invocation —
// presumably because the chat model can differ per call; confirm whether
// compilation should be cached in n2s.runnable instead.
func (n *n2s) NL2SQL(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) {
	o := &nl2sql.Options{ChatModel: n.cm}
	for _, opt := range opts {
		opt(o)
	}

	if o.ChatModel == nil {
		return "", fmt.Errorf("[NL2SQL] chat model not configured")
	}

	// Stage 1: render table metadata into a markdown description and build
	// the template variables.
	c := compose.NewChain[*nl2sqlInput, string]().
		AppendLambda(compose.InvokableLambda(func(ctx context.Context, input *nl2sqlInput) (output map[string]any, err error) {
			if len(input.tables) == 0 {
				return nil, errors.New("table meta is empty")
			}
			tableDesc := strings.Builder{}
			for _, table := range input.tables {
				tableDesc.WriteString(fmt.Sprintf(defaultTableFmt, table.Name, table.Comment))
				for _, column := range table.Columns {
					// "is required" is the inverse of Nullable.
					tableDesc.WriteString(fmt.Sprintf(defaultColumnFmt, column.Name, column.Description, column.Type.String(), !column.Nullable))
				}
			}
			//logs.CtxInfof(ctx, "table schema: %s", tableDesc.String())
			return map[string]interface{}{
				"messages":     input.messages,
				"table_schema": tableDesc.String(),
			}, nil
		})).
		// Stages 2-3: fill the prompt template, then call the chat model.
		AppendChatTemplate(n.tpl).
		AppendChatModel(o.ChatModel).
		// Stage 4: parse the model reply ({"sql","err_code","err_msg"}).
		AppendLambda(compose.InvokableLambda(func(ctx context.Context, msg *schema.Message) (sql string, err error) {
			var promptResp *promptResponse
			if err := json.Unmarshal([]byte(msg.Content), &promptResp); err != nil {
				logs.CtxWarnf(ctx, "unmarshal failed: %v", err)
				return "", err
			}
			// An empty SQL field means the model declined; surface its
			// err_msg as the error.
			if promptResp.SQL == "" {
				logs.CtxInfof(ctx, "no sql generated, err_code: %v, err_msg: %v", promptResp.ErrCode, promptResp.ErrMsg)
				return "", errors.New(promptResp.ErrMsg)
			}
			return promptResp.SQL, nil
		}))

	r, err := c.Compile(ctx)
	if err != nil {
		return "", err
	}

	input := &nl2sqlInput{
		messages: messages,
		tables:   tables,
	}

	return r.Invoke(ctx, input)
}
|
||||
|
||||
// nl2sqlInput bundles the chat history and table metadata fed into the chain.
type nl2sqlInput struct {
	messages []*schema.Message
	tables   []*document.TableSchema
}

// promptResponse mirrors the JSON contract the prompt instructs the model to
// emit: the generated SQL plus an error code/message when generation fails.
type promptResponse struct {
	SQL     string `json:"sql"`
	ErrCode int    `json:"err_code"`
	ErrMsg  string `json:"err_msg"`
}
|
||||
139
backend/infra/impl/document/nl2sql/builtin/nl2sql_test.go
Normal file
139
backend/infra/impl/document/nl2sql/builtin/nl2sql_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/components/prompt"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
)
|
||||
|
||||
// TestNL2SQL covers the three outcomes of the NL2SQL chain: missing table
// metadata, a model reply that fails JSON parsing, and a successful parse.
func TestNL2SQL(t *testing.T) {
	ctx := context.Background()

	t.Run("test table meta not provided", func(t *testing.T) {
		impl, err := NewNL2SQL(ctx, &mockChatModel{"mock resp"}, prompt.FromMessages(schema.Jinja2,
			schema.SystemMessage("system message 123"),
			schema.UserMessage("{{messages}}, {{table_meta}}"),
		))
		assert.NoError(t, err)

		// nil table schemas must surface an error and an empty SQL string.
		sql, err := impl.NL2SQL(ctx, []*schema.Message{schema.UserMessage("hello")}, nil)
		assert.Error(t, err)
		assert.Equal(t, "", sql)
	})

	t.Run("test parse failed", func(t *testing.T) {
		impl, err := NewNL2SQL(ctx, &mockChatModel{"mock resp"}, prompt.FromMessages(schema.Jinja2,
			schema.SystemMessage("system message 123"),
			schema.UserMessage("{{messages}}, {{table_meta}}"),
		))
		assert.NoError(t, err)

		// "mock resp" is not valid JSON, so the final parsing stage errors.
		sql, err := impl.NL2SQL(ctx, []*schema.Message{schema.UserMessage("hello")}, []*document.TableSchema{
			{
				Name:    "mock_table_1",
				Comment: "hello",
				Columns: []*document.Column{
					{
						ID:          121,
						Name:        "id",
						Type:        document.TableColumnTypeInteger,
						Description: "test",
						Nullable:    false,
						IsPrimary:   true,
						Sequence:    0,
					},
					{
						ID:          123,
						Name:        "col_1",
						Type:        document.TableColumnTypeString,
						Description: "column_1",
						Nullable:    true,
						IsPrimary:   false,
						Sequence:    1,
					},
				},
			},
		})
		assert.Error(t, err)
		assert.Equal(t, "", sql)
	})

	t.Run("test success", func(t *testing.T) {
		impl, err := NewNL2SQL(ctx, &mockChatModel{`{"sql":"mock sql","err_code":0,"err_msg":""}`}, prompt.FromMessages(schema.Jinja2,
			schema.SystemMessage("system message 123"),
			schema.UserMessage("{{messages}}, {{table_meta}}"),
		))
		assert.NoError(t, err)

		// A well-formed JSON reply yields the embedded SQL string.
		sql, err := impl.NL2SQL(ctx, []*schema.Message{schema.UserMessage("hello")}, []*document.TableSchema{
			{
				Name:    "mock_table_1",
				Comment: "hello",
				Columns: []*document.Column{
					{
						ID:          121,
						Name:        "id",
						Type:        document.TableColumnTypeInteger,
						Description: "test",
						Nullable:    false,
						IsPrimary:   true,
						Sequence:    0,
					},
					{
						ID:          123,
						Name:        "col_1",
						Type:        document.TableColumnTypeString,
						Description: "column_1",
						Nullable:    true,
						IsPrimary:   false,
						Sequence:    1,
					},
				},
			},
		})
		assert.NoError(t, err)
		assert.Equal(t, "mock sql", sql)
	})

}
|
||||
|
||||
// mockChatModel is a canned chat model: Generate always answers with the
// fixed content; Stream and BindTools are no-ops. It lets tests exercise the
// NL2SQL parsing stages without a real model.
type mockChatModel struct {
	content string
}

// Generate returns the canned content as an assistant message.
func (m mockChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
	return schema.AssistantMessage(m.content, nil), nil
}

// Stream is unimplemented; the tests above only use Generate.
func (m mockChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
	return nil, nil
}

// BindTools is a no-op.
func (m mockChatModel) BindTools(tools []*schema.ToolInfo) error {
	return nil
}
|
||||
|
||||
// sys and usr are reference system/user prompts for the NL2SQL task.
// NOTE(review): neither constant is used by the tests in this file —
// confirm whether they are kept for an integration test or can be removed.
const sys = "# Role: NL2SQL Consultant\n\n## Goals\nTranslate natural language statements into SQL queries in MySQL standard. Follow the Constraints and return only a JSON always.\n\n## Format\n- JSON format only. JSON contains field \"sql\" for generated SQL, filed \"err_code\" for reason type, field \"err_msg\" for detail reason (prefer more than 10 words)\n- Don't use \"```json\" markdown format\n\n## Skills\n- Good at Translate natural language statements into SQL queries in MySQL standard.\n\n## Define\n\"err_code\" Reason Type Define:\n- 0 means you generated a SQL\n- 3002 means you cannot generate a SQL because of timeout\n- 3003 means you cannot generate a SQL because of table schema missing\n- 3005 means you cannot generate a SQL because of some term is ambiguous\n\n## Example\nQ: Help me implement NL2SQL.\n.table schema description: CREATE TABLE `sales_records` (\\n `sales_id` bigint(20) unsigned NOT NULL COMMENT 'id of sales person',\\n `product_id` bigint(64) COMMENT 'id of product',\\n `sale_date` datetime(3) COMMENT 'sold date and time',\\n `quantity_sold` int(11) COMMENT 'sold amount',\\n PRIMARY KEY (`sales_id`)\\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='销售记录表';\n.natural language description of the SQL requirement: 查询上月的销量总额第一名的销售员和他的销售总额\nA: {\n \"sql\":\"SELECT sales_id, SUM(quantity_sold) AS total_sales FROM sales_records WHERE MONTH(sale_date) = MONTH(CURRENT_DATE - INTERVAL 1 MONTH) AND YEAR(sale_date) = YEAR(CURRENT_DATE - INTERVAL 1 MONTH) GROUP BY sales_id ORDER BY total_sales DESC LIMIT 1\",\n \"err_code\":0,\n \"err_msg\":\"SQL query generated successfully\"\n}"

const usr = "help me implement NL2SQL.\ntable schema description:{{tableSchema}}\nnatural language description of the SQL requirement: {{chat_history}}."
|
||||
96
backend/infra/impl/document/ocr/veocr/ve_ocr.go
Normal file
96
backend/infra/impl/document/ocr/veocr/ve_ocr.go
Normal file
@@ -0,0 +1,96 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package veocr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/volcengine/volc-sdk-golang/service/visual"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
|
||||
)
|
||||
|
||||
// Config holds the volcengine visual client plus optional OCR tuning knobs.
// Nil option fields fall back to the service defaults noted inline.
type Config struct {
	Client *visual.Visual

	// see: https://www.volcengine.com/docs/6790/117730
	ApproximatePixel *int    // default: 0
	Mode             *string // default: "text_block"
	FilterThresh     *int    // default: 80
	HalfToFull       *bool   // default: false
}
|
||||
|
||||
func NewOCR(config *Config) ocr.OCR {
|
||||
return &ocrImpl{config}
|
||||
}
|
||||
|
||||
type ocrImpl struct {
|
||||
config *Config
|
||||
}
|
||||
|
||||
func (o *ocrImpl) FromBase64(ctx context.Context, b64 string) ([]string, error) {
|
||||
form := o.newForm()
|
||||
form.Add("image_base64", b64)
|
||||
|
||||
resp, statusCode, err := o.config.Client.OCRNormal(form)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("[FromBase64] failed, status code=%d", statusCode)
|
||||
}
|
||||
|
||||
return resp.Data.LineTexts, nil
|
||||
}
|
||||
|
||||
func (o *ocrImpl) FromURL(ctx context.Context, url string) ([]string, error) {
|
||||
form := o.newForm()
|
||||
form.Add("image_url", url)
|
||||
|
||||
resp, statusCode, err := o.config.Client.OCRNormal(form)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("[FromBase64] failed, status code=%d", statusCode)
|
||||
}
|
||||
|
||||
return resp.Data.LineTexts, nil
|
||||
}
|
||||
|
||||
func (o *ocrImpl) newForm() url.Values {
|
||||
form := url.Values{}
|
||||
if o.config.ApproximatePixel != nil {
|
||||
form.Add("approximate_pixel", strconv.FormatInt(int64(*o.config.ApproximatePixel), 10))
|
||||
}
|
||||
if o.config.Mode != nil {
|
||||
form.Add("mode", *o.config.Mode)
|
||||
} else {
|
||||
form.Add("mode", "text_block")
|
||||
}
|
||||
if o.config.FilterThresh != nil {
|
||||
form.Add("filter_thresh", strconv.FormatInt(int64(*o.config.FilterThresh), 10))
|
||||
}
|
||||
if o.config.HalfToFull != nil {
|
||||
form.Add("half_to_full", strconv.FormatBool(*o.config.HalfToFull))
|
||||
}
|
||||
return form
|
||||
}
|
||||
37
backend/infra/impl/document/parser/builtin/align_schema.go
Normal file
37
backend/infra/impl/document/parser/builtin/align_schema.go
Normal file
@@ -0,0 +1,37 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
	"fmt"

	"github.com/coze-dev/coze-studio/backend/infra/contract/document"
)
|
||||
|
||||
func alignTableSliceValue(schema []*document.Column, row []*document.ColumnData) (err error) {
|
||||
for i, col := range row {
|
||||
var newCol *document.ColumnData
|
||||
newCol, err = assertValAs(schema[i].Type, col.GetStringValue())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newCol.ColumnID = col.ColumnID
|
||||
newCol.ColumnName = col.ColumnName
|
||||
row[i] = newCol
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
142
backend/infra/impl/document/parser/builtin/align_schema_test.go
Normal file
142
backend/infra/impl/document/parser/builtin/align_schema_test.go
Normal file
@@ -0,0 +1,142 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "github.com/bytedance/mockey"
|
||||
"github.com/smartystreets/goconvey/convey"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
// TestAssertVal checks that assertVal infers the most specific column type
// from a raw string — boolean, integer, number, time, plain string — with
// the empty string mapping to the unknown type.
func TestAssertVal(t *testing.T) {
	PatchConvey("test assertVal", t, func() {
		convey.So(assertVal(""), convey.ShouldEqual, document.ColumnData{
			Type:      document.TableColumnTypeUnknown,
			ValString: ptr.Of(""),
		})
		convey.So(assertVal("true"), convey.ShouldEqual, document.ColumnData{
			Type:       document.TableColumnTypeBoolean,
			ValBoolean: ptr.Of(true),
		})
		convey.So(assertVal("10"), convey.ShouldEqual, document.ColumnData{
			Type:       document.TableColumnTypeInteger,
			ValInteger: ptr.Of(int64(10)),
		})
		convey.So(assertVal("1.0"), convey.ShouldEqual, document.ColumnData{
			Type:      document.TableColumnTypeNumber,
			ValNumber: ptr.Of(1.0),
		})
		// Round-trip now through timeFormat so the expected value carries
		// exactly the precision the parser preserves.
		ts := time.Now().Format(timeFormat)
		now, err := time.Parse(timeFormat, ts)
		convey.So(err, convey.ShouldBeNil)
		convey.So(assertVal(ts), convey.ShouldEqual, document.ColumnData{
			Type:    document.TableColumnTypeTime,
			ValTime: ptr.Of(now),
		})
		convey.So(assertVal("hello"), convey.ShouldEqual, document.ColumnData{
			Type:      document.TableColumnTypeString,
			ValString: ptr.Of("hello"),
		})
	})
}
|
||||
|
||||
// TestAssertValAs table-drives assertValAs: for each target column type it
// checks one parseable value and one unparseable one ("hello"), expecting a
// typed ColumnData on success and a non-nil error (with nil data) otherwise.
// The unknown type is always an error.
func TestAssertValAs(t *testing.T) {
	PatchConvey("test assertValAs", t, func() {
		type testCase struct {
			typ   document.TableColumnType
			val   string
			isErr bool
			data  *document.ColumnData
		}

		// Round-trip now through timeFormat so the expected time matches the
		// parsed precision exactly.
		ts := time.Now().Format(timeFormat)
		now, _ := time.Parse(timeFormat, ts)
		cases := []testCase{
			{
				typ:   document.TableColumnTypeString,
				val:   "hello",
				isErr: false,
				data:  &document.ColumnData{Type: document.TableColumnTypeString, ValString: ptr.Of("hello")},
			},
			{
				typ:   document.TableColumnTypeInteger,
				val:   "1",
				isErr: false,
				data:  &document.ColumnData{Type: document.TableColumnTypeInteger, ValInteger: ptr.Of(int64(1))},
			},
			{
				typ:   document.TableColumnTypeInteger,
				val:   "hello",
				isErr: true,
			},
			{
				typ:   document.TableColumnTypeTime,
				val:   ts,
				isErr: false,
				data:  &document.ColumnData{Type: document.TableColumnTypeTime, ValTime: ptr.Of(now)},
			},
			{
				typ:   document.TableColumnTypeTime,
				val:   "hello",
				isErr: true,
			},
			{
				typ:   document.TableColumnTypeNumber,
				val:   "1.0",
				isErr: false,
				data:  &document.ColumnData{Type: document.TableColumnTypeNumber, ValNumber: ptr.Of(1.0)},
			},
			{
				typ:   document.TableColumnTypeNumber,
				val:   "hello",
				isErr: true,
			},
			{
				typ:   document.TableColumnTypeBoolean,
				val:   "true",
				isErr: false,
				data:  &document.ColumnData{Type: document.TableColumnTypeBoolean, ValBoolean: ptr.Of(true)},
			},
			{
				typ:   document.TableColumnTypeBoolean,
				val:   "hello",
				isErr: true,
			},
			{
				typ:   document.TableColumnTypeUnknown,
				val:   "hello",
				isErr: true,
			},
		}

		for _, c := range cases {
			v, err := assertValAs(c.typ, c.val)
			if c.isErr {
				convey.So(err, convey.ShouldNotBeNil)
				convey.So(v, convey.ShouldBeNil)
			} else {
				convey.So(err, convey.ShouldBeNil)
				convey.So(v, convey.ShouldEqual, c.data)
			}
		}
	})
}
|
||||
116
backend/infra/impl/document/parser/builtin/chunk_custom.go
Normal file
116
backend/infra/impl/document/parser/builtin/chunk_custom.go
Normal file
@@ -0,0 +1,116 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/components/document/parser"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
)
|
||||
|
||||
var (
	// spaceRegex collapses runs of whitespace when TrimSpace is enabled.
	spaceRegex = regexp.MustCompile(`\s+`)
	// urlRegex and emailRegex strip links and addresses when
	// TrimURLAndEmail is enabled.
	urlRegex   = regexp.MustCompile(`https?://\S+|www\.\S+`)
	emailRegex = regexp.MustCompile(`[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}`)
)
|
||||
|
||||
// chunkCustom splits text into documents per the configured chunking
// strategy: split on cs.Separator, optionally strip URLs/emails and collapse
// whitespace, then pack the parts into chunks of at most cs.ChunkSize runes,
// carrying an overlap (a percentage of the chunk size) between consecutive
// chunks. Every produced document receives a copy of the options' ExtraMeta.
// Overlap must be strictly smaller than ChunkSize.
func chunkCustom(_ context.Context, text string, config *contract.Config, opts ...parser.Option) (docs []*schema.Document, err error) {
	cs := config.ChunkingStrategy
	if cs.Overlap >= cs.ChunkSize {
		return nil, fmt.Errorf("[chunkCustom] invalid param, overlap >= chunk_size")
	}

	var (
		parts         = strings.Split(text, cs.Separator)
		buffer        []rune // runes accumulated for the chunk being built
		currentLength int64  // rune length of buffer
		options       = parser.GetCommonOptions(&parser.Options{ExtraMeta: map[string]any{}}, opts...)
	)

	// trim applies the configured clean-ups to a single part.
	trim := func(text string) string {
		if cs.TrimURLAndEmail {
			text = urlRegex.ReplaceAllString(text, "")
			text = emailRegex.ReplaceAllString(text, "")
		}

		if cs.TrimSpace {
			text = strings.TrimSpace(text)
			text = spaceRegex.ReplaceAllString(text, " ")
		}

		return text
	}

	// add flushes the buffer into a new document; no-op when empty.
	// Note it resets buffer but not currentLength.
	add := func() {
		if len(buffer) == 0 {
			return
		}
		doc := &schema.Document{
			Content:  string(buffer),
			MetaData: map[string]any{},
		}
		for k, v := range options.ExtraMeta {
			doc.MetaData[k] = v
		}
		docs = append(docs, doc)
		buffer = []rune{}
	}

	// processPart packs one separator-delimited part into chunks, flushing
	// whenever the buffer reaches cs.ChunkSize and seeding the next buffer
	// with the overlap tail of the chunk just emitted.
	processPart := func(part string) {
		runes := []rune(part)
		for partLength := int64(len(runes)); partLength > 0; partLength = int64(len(runes)) {
			// Take as much of the part as fits in the remaining chunk room.
			pos := min(partLength, cs.ChunkSize-currentLength)
			buffer = append(buffer, runes[:pos]...)
			currentLength = int64(len(buffer))

			if currentLength >= cs.ChunkSize {
				add()
				if cs.Overlap > 0 {
					buffer = getOverlap([]rune(docs[len(docs)-1].Content), cs.Overlap, cs.ChunkSize)
					currentLength = int64(len(buffer))
				} else {
					currentLength = 0
				}
			}
			runes = runes[pos:]
		}

		// NOTE(review): this flush emits a chunk per part and leaves
		// currentLength stale (non-zero) for the next part; the stale value
		// only shortens the first append of the next part, after which
		// currentLength is recomputed from len(buffer) — confirm intended.
		add()
	}

	for _, part := range parts {
		processPart(trim(part))
	}

	// Final flush is a no-op in practice (processPart already flushed),
	// kept as a safety net.
	add()

	return docs, nil
}
|
||||
|
||||
// getOverlap returns the tail of runes that should be carried over into the
// next chunk. overlapRatio is a percentage of chunkSize; when runes is no
// longer than the computed overlap, the whole slice is returned unchanged.
func getOverlap(runes []rune, overlapRatio int64, chunkSize int64) []rune {
	keep := int(float64(chunkSize) * float64(overlapRatio) / 100)
	if n := len(runes); n > keep {
		return runes[n-keep:]
	}
	return runes
}
|
||||
@@ -0,0 +1,47 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
)
|
||||
|
||||
// TestChunkCustom exercises chunkCustom end to end with newline-separated
// input and no overlap: each input line is well under ChunkSize, so every
// line should become exactly one output slice.
func TestChunkCustom(t *testing.T) {
	ctx := context.Background()
	t.Run("test \n no overlap", func(t *testing.T) {
		// Ten numbered lines; the separator is "\n", so 10 parts are expected.
		text := "1. Eiffel Tower: Located in Paris, France, it is one of the most famous landmarks in the world, designed by Gustave Eiffel and built in 1889.\n2. The Great Wall: Located in China, it is one of the Seven Wonders of the World, built from the Qin Dynasty to the Ming Dynasty, with a total length of over 20000 kilometers.\n3. Grand Canyon National Park: Located in Arizona, USA, it is famous for its deep canyons and magnificent scenery, which are cut by the Colorado River.\n4. The Colosseum: Located in Rome, Italy, built between 70-80 AD, it was the largest circular arena in the ancient Roman Empire.\n5. Taj Mahal: Located in Agra, India, it was completed by Mughal Emperor Shah Jahan in 1653 to commemorate his wife and is one of the New Seven Wonders of the World.\n6. Sydney Opera House: Located in Sydney Harbour, Australia, it is one of the most iconic buildings of the 20th century, renowned for its unique sailboat design.\n7. Louvre Museum: Located in Paris, France, it is one of the largest museums in the world with a rich collection, including Leonardo da Vinci's Mona Lisa and Greece's Venus de Milo.\n8. Niagara Falls: located at the border of the United States and Canada, consisting of three main waterfalls, its spectacular scenery attracts millions of tourists every year.\n9. St. Sophia Cathedral: located in Istanbul, Türkiye, originally built in 537 A.D., it used to be an Orthodox cathedral and mosque, and now it is a museum.\n10. Machu Picchu: an ancient Inca site located on the plateau of the Andes Mountains in Peru, one of the New Seven Wonders of the World, with an altitude of over 2400 meters."

		// Trimming is enabled; the fixture contains no URLs/emails, so it
		// only normalizes whitespace.
		cs := &parser.ChunkingStrategy{
			ChunkType:       parser.ChunkTypeCustom,
			ChunkSize:       1000,
			Separator:       "\n",
			Overlap:         0,
			TrimSpace:       true,
			TrimURLAndEmail: true,
		}

		slices, err := chunkCustom(ctx, text, &parser.Config{ChunkingStrategy: cs})

		assert.NoError(t, err)
		assert.Len(t, slices, 10)
	})
}
|
||||
213
backend/infra/impl/document/parser/builtin/convert.go
Normal file
213
backend/infra/impl/document/parser/builtin/convert.go
Normal file
@@ -0,0 +1,213 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
const (
|
||||
timeFormat = "2006-01-02 15:04:05"
|
||||
)
|
||||
|
||||
func assertValAs(typ document.TableColumnType, val string) (*document.ColumnData, error) {
|
||||
if val == "" {
|
||||
return &document.ColumnData{
|
||||
Type: typ,
|
||||
}, nil
|
||||
}
|
||||
|
||||
switch typ {
|
||||
case document.TableColumnTypeString:
|
||||
return &document.ColumnData{
|
||||
Type: document.TableColumnTypeString,
|
||||
ValString: &val,
|
||||
}, nil
|
||||
|
||||
case document.TableColumnTypeInteger:
|
||||
i, err := strconv.ParseInt(val, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &document.ColumnData{
|
||||
Type: document.TableColumnTypeInteger,
|
||||
ValInteger: &i,
|
||||
}, nil
|
||||
|
||||
case document.TableColumnTypeTime:
|
||||
if val == "" {
|
||||
var emptyTime time.Time
|
||||
return &document.ColumnData{
|
||||
Type: document.TableColumnTypeTime,
|
||||
ValTime: ptr.Of(emptyTime),
|
||||
}, nil
|
||||
}
|
||||
// 支持时间戳和时间字符串
|
||||
i, err := strconv.ParseInt(val, 10, 64)
|
||||
if err == nil {
|
||||
t := time.Unix(i, 0)
|
||||
return &document.ColumnData{
|
||||
Type: document.TableColumnTypeTime,
|
||||
ValTime: &t,
|
||||
}, nil
|
||||
|
||||
}
|
||||
t, err := time.Parse(timeFormat, val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &document.ColumnData{
|
||||
Type: document.TableColumnTypeTime,
|
||||
ValTime: &t,
|
||||
}, nil
|
||||
|
||||
case document.TableColumnTypeNumber:
|
||||
f, err := strconv.ParseFloat(val, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &document.ColumnData{
|
||||
Type: document.TableColumnTypeNumber,
|
||||
ValNumber: &f,
|
||||
}, nil
|
||||
|
||||
case document.TableColumnTypeBoolean:
|
||||
t, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &document.ColumnData{
|
||||
Type: document.TableColumnTypeBoolean,
|
||||
ValBoolean: &t,
|
||||
}, nil
|
||||
case document.TableColumnTypeImage:
|
||||
return &document.ColumnData{
|
||||
Type: document.TableColumnTypeImage,
|
||||
ValImage: &val,
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("[assertValAs] type not support, type=%d, val=%s", typ, val)
|
||||
}
|
||||
}
|
||||
|
||||
func assertValAsForce(typ document.TableColumnType, val string, nullable bool) *document.ColumnData {
|
||||
cd := &document.ColumnData{
|
||||
Type: typ,
|
||||
}
|
||||
switch typ {
|
||||
case document.TableColumnTypeString:
|
||||
cd.ValString = &val
|
||||
case document.TableColumnTypeInteger:
|
||||
if i, err := strconv.ParseInt(val, 10, 64); err == nil {
|
||||
cd.ValInteger = ptr.Of(i)
|
||||
} else if !nullable {
|
||||
cd.ValInteger = ptr.Of(int64(0))
|
||||
}
|
||||
case document.TableColumnTypeTime:
|
||||
if t, err := time.Parse(timeFormat, val); err == nil {
|
||||
cd.ValTime = ptr.Of(t)
|
||||
} else if !nullable {
|
||||
cd.ValTime = ptr.Of(time.Time{})
|
||||
}
|
||||
case document.TableColumnTypeNumber:
|
||||
if f, err := strconv.ParseFloat(val, 64); err == nil {
|
||||
cd.ValNumber = ptr.Of(f)
|
||||
} else if !nullable {
|
||||
cd.ValNumber = ptr.Of(0.0)
|
||||
}
|
||||
case document.TableColumnTypeBoolean:
|
||||
if t, err := strconv.ParseBool(val); err == nil {
|
||||
cd.ValBoolean = ptr.Of(t)
|
||||
} else if !nullable {
|
||||
cd.ValBoolean = ptr.Of(false)
|
||||
}
|
||||
case document.TableColumnTypeImage:
|
||||
cd.ValImage = ptr.Of(val)
|
||||
default:
|
||||
cd.ValString = &val
|
||||
}
|
||||
|
||||
return cd
|
||||
}
|
||||
|
||||
func assertVal(val string) document.ColumnData {
|
||||
// TODO: 先不处理 image
|
||||
if val == "" {
|
||||
return document.ColumnData{
|
||||
Type: document.TableColumnTypeUnknown,
|
||||
ValString: &val,
|
||||
}
|
||||
}
|
||||
if t, err := strconv.ParseBool(val); err == nil {
|
||||
return document.ColumnData{
|
||||
Type: document.TableColumnTypeBoolean,
|
||||
ValBoolean: &t,
|
||||
}
|
||||
}
|
||||
if i, err := strconv.ParseInt(val, 10, 64); err == nil {
|
||||
return document.ColumnData{
|
||||
Type: document.TableColumnTypeInteger,
|
||||
ValInteger: &i,
|
||||
}
|
||||
}
|
||||
if f, err := strconv.ParseFloat(val, 64); err == nil {
|
||||
return document.ColumnData{
|
||||
Type: document.TableColumnTypeNumber,
|
||||
ValNumber: &f,
|
||||
}
|
||||
}
|
||||
if t, err := time.Parse(timeFormat, val); err == nil {
|
||||
return document.ColumnData{
|
||||
Type: document.TableColumnTypeTime,
|
||||
ValTime: &t,
|
||||
}
|
||||
}
|
||||
return document.ColumnData{
|
||||
Type: document.TableColumnTypeString,
|
||||
ValString: &val,
|
||||
}
|
||||
}
|
||||
|
||||
func transformColumnType(src, dst document.TableColumnType) document.TableColumnType {
|
||||
if src == document.TableColumnTypeUnknown {
|
||||
return dst
|
||||
}
|
||||
if dst == document.TableColumnTypeUnknown {
|
||||
return src
|
||||
}
|
||||
if dst == document.TableColumnTypeString {
|
||||
return dst
|
||||
}
|
||||
if src == dst {
|
||||
return dst
|
||||
}
|
||||
if src == document.TableColumnTypeInteger && dst == document.TableColumnTypeNumber {
|
||||
return dst
|
||||
}
|
||||
return document.TableColumnTypeString
|
||||
}
|
||||
|
||||
// charCount reports the number of unicode code points (runes) in text,
// as opposed to len(text), which counts bytes.
func charCount(text string) int64 {
	var n int64
	for range text {
		n++
	}
	return n
}
|
||||
36
backend/infra/impl/document/parser/builtin/image.go
Normal file
36
backend/infra/impl/document/parser/builtin/image.go
Normal file
@@ -0,0 +1,36 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
)
|
||||
|
||||
func putImageObject(ctx context.Context, st storage.Storage, imgExt string, uid int64, img []byte) (format string, err error) {
|
||||
secret := createSecret(uid, imgExt)
|
||||
fileName := fmt.Sprintf("%d_%d_%s.%s", uid, time.Now().UnixNano(), secret, imgExt)
|
||||
objectName := fmt.Sprintf("%s/%s", knowledgePrefix, fileName)
|
||||
if err := st.PutObject(ctx, objectName, img); err != nil {
|
||||
return "", err
|
||||
}
|
||||
imgSrc := fmt.Sprintf(imgSrcFormat, objectName)
|
||||
return imgSrc, nil
|
||||
}
|
||||
77
backend/infra/impl/document/parser/builtin/manager.go
Normal file
77
backend/infra/impl/document/parser/builtin/manager.go
Normal file
@@ -0,0 +1,77 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/goutil"
|
||||
)
|
||||
|
||||
func NewManager(storage storage.Storage, ocr ocr.OCR, imageAnnotationModel chatmodel.BaseChatModel) parser.Manager {
|
||||
return &manager{
|
||||
storage: storage,
|
||||
ocr: ocr,
|
||||
model: imageAnnotationModel,
|
||||
}
|
||||
}
|
||||
|
||||
// manager is the builtin implementation of parser.Manager; it holds the
// shared dependencies handed to each concrete parser built by GetParser.
type manager struct {
	// ocr extracts text from images; passed to the pdf/docx/markdown parsers.
	ocr ocr.OCR
	// storage persists images extracted from parsed documents.
	storage storage.Storage
	// model describes images when ImageAnnotationTypeModel is configured.
	model chatmodel.BaseChatModel
}
|
||||
|
||||
// GetParser returns a parser implementation for config.FileExtension.
//
// Line-number defaults: when both HeaderLine and DataStartLine are 0, the
// data is assumed to start on line 1 (line 0 being the header); otherwise
// the header must come strictly before the data start, or an error is
// returned. NOTE(review): the defaulting branch mutates the caller's
// config in place — confirm callers do not reuse the config elsewhere.
func (m *manager) GetParser(config *parser.Config) (parser.Parser, error) {
	var pFn parseFn

	if config.ParsingStrategy.HeaderLine == 0 && config.ParsingStrategy.DataStartLine == 0 {
		config.ParsingStrategy.DataStartLine = 1
	} else if config.ParsingStrategy.HeaderLine >= config.ParsingStrategy.DataStartLine {
		return nil, fmt.Errorf("[GetParser] invalid header line and data start line, header=%d, data_start=%d",
			config.ParsingStrategy.HeaderLine, config.ParsingStrategy.DataStartLine)
	}

	// Dispatch on the file extension. PDF and docx are delegated to python
	// helper scripts; jpg/jpeg/png go through the image-annotation model.
	switch config.FileExtension {
	case parser.FileExtensionPDF:
		pFn = parseByPython(config, m.storage, m.ocr, goutil.GetPython3Path(), goutil.GetPythonFilePath("parse_pdf.py"))
	case parser.FileExtensionTXT:
		pFn = parseText(config)
	case parser.FileExtensionMarkdown:
		pFn = parseMarkdown(config, m.storage, m.ocr)
	case parser.FileExtensionDocx:
		pFn = parseByPython(config, m.storage, m.ocr, goutil.GetPython3Path(), goutil.GetPythonFilePath("parse_docx.py"))
	case parser.FileExtensionCSV:
		pFn = parseCSV(config)
	case parser.FileExtensionXLSX:
		pFn = parseXLSX(config)
	case parser.FileExtensionJSON:
		pFn = parseJSON(config)
	case parser.FileExtensionJsonMaps:
		pFn = parseJSONMaps(config)
	case parser.FileExtensionJPG, parser.FileExtensionJPEG, parser.FileExtensionPNG:
		pFn = parseImage(config, m.model)
	default:
		return nil, fmt.Errorf("[Parse] document type not support, type=%s", config.FileExtension)
	}

	return &p{parseFn: pFn}, nil
}
|
||||
53
backend/infra/impl/document/parser/builtin/parse_csv.go
Normal file
53
backend/infra/impl/document/parser/builtin/parse_csv.go
Normal file
@@ -0,0 +1,53 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/csv"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/cloudwego/eino/components/document/parser"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/dimchansky/utfbom"
|
||||
|
||||
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
)
|
||||
|
||||
func parseCSV(config *contract.Config) parseFn {
|
||||
return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
|
||||
iter := &csvIterator{csv.NewReader(utfbom.SkipOnly(reader))}
|
||||
return parseByRowIterator(iter, config, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
type csvIterator struct {
|
||||
reader *csv.Reader
|
||||
}
|
||||
|
||||
func (c *csvIterator) NextRow() (row []string, end bool, err error) {
|
||||
row, e := c.reader.Read()
|
||||
if e != nil {
|
||||
if errors.Is(e, io.EOF) {
|
||||
return nil, true, nil
|
||||
}
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
return row, false, nil
|
||||
}
|
||||
200
backend/infra/impl/document/parser/builtin/parse_csv_test.go
Normal file
200
backend/infra/impl/document/parser/builtin/parse_csv_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/components/document/parser"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
)
|
||||
|
||||
// TestParseCSV parses the CSV fixture twice: once with auto-inferred column
// types (no Columns supplied), then again with an explicit column schema,
// checking in both cases that each produced document carries the expected
// metadata (see assertSheet).
func TestParseCSV(t *testing.T) {
	ctx := context.Background()
	b, err := os.ReadFile("./test_data/test_csv.csv")
	assert.NoError(t, err)

	// Pass 1: no Columns — the parser infers the schema from the data.
	r1 := bytes.NewReader(b)
	c1 := &contract.Config{
		FileExtension: contract.FileExtensionCSV,
		ParsingStrategy: &contract.ParsingStrategy{
			HeaderLine:    0,
			DataStartLine: 1,
			RowsCount:     20,
		},
		ChunkingStrategy: nil,
	}
	p1 := parseCSV(c1)
	docs, err := p1(ctx, r1, parser.WithExtraMeta(map[string]any{
		"document_id":  int64(123),
		"knowledge_id": int64(456),
	}))
	assert.NoError(t, err)
	for i, doc := range docs {
		assertSheet(t, i, doc)
	}

	// parse
	// Pass 2: explicit column schema covering all six fixture columns.
	r2 := bytes.NewReader(b)
	c2 := &contract.Config{
		FileExtension: contract.FileExtensionCSV,
		ParsingStrategy: &contract.ParsingStrategy{
			HeaderLine:    0,
			DataStartLine: 1,
			RowsCount:     10,
			Columns: []*document.Column{
				{
					ID:       0,
					Name:     "col_string_indexing",
					Type:     document.TableColumnTypeString,
					Nullable: false,
					Sequence: 0,
				},
				{
					ID:       0,
					Name:     "col_string",
					Type:     document.TableColumnTypeString,
					Nullable: false,
					Sequence: 1,
				},
				{
					ID:       0,
					Name:     "col_int",
					Type:     document.TableColumnTypeInteger,
					Nullable: false,
					Sequence: 2,
				},
				{
					ID:       0,
					Name:     "col_number",
					Type:     document.TableColumnTypeNumber,
					Nullable: true,
					Sequence: 3,
				},
				{
					ID:       0,
					Name:     "col_bool",
					Type:     document.TableColumnTypeBoolean,
					Nullable: true,
					Sequence: 4,
				},
				{
					ID:       0,
					Name:     "col_time",
					Type:     document.TableColumnTypeTime,
					Nullable: true,
					Sequence: 5,
				},
			},
		},
		ChunkingStrategy: nil,
	}
	p2 := parseCSV(c2)
	docs, err = p2(ctx, r2, parser.WithExtraMeta(map[string]any{
		"document_id":  int64(123),
		"knowledge_id": int64(456),
	}))
	assert.NoError(t, err)
	for i, doc := range docs {
		assertSheet(t, i, doc)
	}
}
|
||||
|
||||
// TestParseCSVBadCases covers degenerate CSV input: a fixture containing
// empty cells. With IgnoreColumnTypeErr set, the parse must succeed, and
// once a column is flipped to non-nullable, its cells must be force-filled
// with a zero value rather than left nil.
func TestParseCSVBadCases(t *testing.T) {
	t.Run("test nil row", func(t *testing.T) {
		ctx := context.Background()
		f, err := os.Open("test_data/test_csv_badcase_1.csv")
		assert.NoError(t, err)
		b, err := io.ReadAll(f)
		assert.NoError(t, err)

		// First parse: infer the schema from the data (Columns == nil).
		pfn := parseCSV(&contract.Config{
			FileExtension: "csv",
			ParsingStrategy: &contract.ParsingStrategy{
				ExtractImage:        true,
				ExtractTable:        true,
				ImageOCR:            false,
				SheetID:             nil,
				HeaderLine:          0,
				DataStartLine:       1,
				RowsCount:           0,
				IsAppend:            false,
				Columns:             nil,
				IgnoreColumnTypeErr: true,
				ImageAnnotationType: 0,
			},
		})

		resp, err := pfn(ctx, bytes.NewReader(b))
		assert.NoError(t, err)
		assert.True(t, len(resp) > 0)
		cols, err := document.GetDocumentColumns(resp[0])
		assert.NoError(t, err)
		// Second parse: reuse the inferred schema but mark column 5 as
		// non-nullable so empty cells must be force-filled.
		cols[5].Nullable = false
		npfn := parseCSV(&contract.Config{
			FileExtension: "csv",
			ParsingStrategy: &contract.ParsingStrategy{
				ExtractImage:        true,
				ExtractTable:        true,
				ImageOCR:            false,
				SheetID:             nil,
				HeaderLine:          0,
				DataStartLine:       1,
				RowsCount:           0,
				IsAppend:            false,
				Columns:             cols,
				IgnoreColumnTypeErr: true,
				ImageAnnotationType: 0,
			},
		})
		resp, err = npfn(ctx, bytes.NewReader(b))
		assert.NoError(t, err)
		assert.True(t, len(resp) > 0)
		for _, item := range resp {
			data, err := document.GetDocumentColumnData(item)
			assert.NoError(t, err)
			// Non-nullable column must always carry a value.
			assert.NotNil(t, data[5].GetValue())
		}
	})
}
|
||||
|
||||
// assertSheet checks the invariants every parsed sheet document must hold:
// column schema and per-row column data present in MetaData with the right
// dynamic types, plus the ExtraMeta keys propagated by the parser. Cell
// values are printed for debugging.
func assertSheet(t *testing.T, i int, doc *schema.Document) {
	fmt.Printf("sheet[%d]:\n", i)
	assert.NotNil(t, doc.MetaData)
	assert.NotNil(t, doc.MetaData[document.MetaDataKeyColumns])
	cols, ok := doc.MetaData[document.MetaDataKeyColumns].([]*document.Column)
	assert.True(t, ok)
	assert.NotNil(t, doc.MetaData[document.MetaDataKeyColumnData])
	row, ok := doc.MetaData[document.MetaDataKeyColumnData].([]*document.ColumnData)
	assert.True(t, ok)
	assert.Equal(t, int64(123), doc.MetaData["document_id"].(int64))
	assert.Equal(t, int64(456), doc.MetaData["knowledge_id"].(int64))
	for j := range row {
		col := cols[j]
		val := row[j]
		fmt.Printf("row[%d]: %v=%v\n", j, col.Name, val.GetStringValue())
	}
}
|
||||
172
backend/infra/impl/document/parser/builtin/parse_docx.py
Normal file
172
backend/infra/impl/document/parser/builtin/parse_docx.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import io
|
||||
import os
|
||||
import json
|
||||
import sys
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC
|
||||
from typing import List, IO
|
||||
|
||||
from docx import ImagePart
|
||||
from docx.oxml import CT_P, CT_Tbl
|
||||
from docx.table import Table
|
||||
from docx.text.paragraph import Paragraph
|
||||
from docx import Document
|
||||
from PIL import Image
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DocxLoader(ABC):
    """Extracts the content of a .docx stream, in document order, into a list
    of typed dicts:

    - ``{"type": "text", "content": <str>}``
    - ``{"type": "image", "content": <base64-encoded PNG str>}``
    - ``{"type": "table", "table": <list[list[str]]>}``
    """

    def __init__(
        self,
        file_content: IO[bytes],
        extract_images: bool = True,
        extract_tables: bool = True,
    ):
        # Raw docx byte stream plus feature toggles mirroring the Go-side
        # parsing strategy (extract_images / extract_tables).
        self.file_content = file_content
        self.extract_images = extract_images
        self.extract_tables = extract_tables

    def load(self) -> List[dict]:
        """Walk the document body in order, accumulating paragraph text and
        flushing it as a text item whenever an image or table interrupts the
        flow, so output order matches document order."""
        result = []
        doc = Document(self.file_content)
        it = iter(doc.element.body)
        text = ""

        for part in it:
            blocks = self.parse_part(part, doc)
            if blocks is None or len(blocks) == 0:
                continue
            for block in blocks:
                # A list block is a group of ImageParts from one paragraph
                # (see parse_part / parse_run).
                if self.extract_images and isinstance(block, list):
                    for b in block:
                        image = io.BytesIO()
                        try:
                            # Re-encode every image to PNG so downstream
                            # handling is uniform regardless of source format.
                            Image.open(io.BytesIO(b.image.blob)).save(image, format="png")
                        except Exception as e:
                            logging.error(f"load image failed, time={time.asctime()}, err:{e}")
                            raise RuntimeError("ExtractImageError")

                        # Flush pending text before emitting the image so the
                        # original ordering is preserved.
                        if len(text) > 0:
                            result.append(
                                {
                                    "content": text,
                                    "type": "text",
                                }
                            )
                            text = ""

                        result.append(
                            {
                                "content": base64.b64encode(image.getvalue()).decode('utf-8'),
                                "type": "image",
                            }
                        )

                if isinstance(block, Paragraph):
                    text += block.text

                if self.extract_tables and isinstance(block, Table):
                    rows = block.rows
                    # Flush pending text before the table, same as for images.
                    if len(text) > 0:
                        result.append(
                            {
                                "content": text,
                                "type": "text",
                            }
                        )
                        text = ""
                    table = self.convert_table(rows)
                    result.append(
                        {
                            "table": table,
                            "type": "table",
                        }
                    )
            # Separate consecutive body parts that contributed text.
            if text:
                text += "\n\n"
        # Flush whatever text is left at the end of the document.
        if len(text) > 0:
            result.append(
                {
                    "content": text,
                    "type": "text",
                }
            )

        return result

    def parse_part(self, block, doc: Document):
        """Classify one body element: a CT_P (paragraph) yields Paragraphs
        and/or image-part lists; a CT_Tbl yields a single Table. Returns None
        for anything else."""
        if isinstance(block, CT_P):
            blocks = []
            para = Paragraph(block, doc)
            image_part = self.get_image_part(para, doc)
            if image_part and para.text:
                # Mixed paragraph: split into runs to keep text/image order.
                blocks.extend(self.parse_run(para))
            elif image_part:
                blocks.append(image_part)
            elif para.text:
                blocks.append(para)
            return blocks
        elif isinstance(block, CT_Tbl):
            return [Table(block, doc)]

    def parse_run(self, para: Paragraph):
        """Split a mixed text+image paragraph into per-run Paragraphs and
        image-part lists, preserving in-paragraph order."""
        runs = para.runs
        paras = []
        if runs is None or len(runs) == 0:
            return paras
        for run in runs:
            if run is None or run.element is None:
                continue
            p = Paragraph(run.element, para)
            # NOTE(review): `para` (a Paragraph) is passed where
            # get_image_part annotates a Document; it only uses `.part`,
            # which Paragraph also exposes — confirm this is intended.
            image_part = self.get_image_part(p, para)
            if image_part:
                paras.append(image_part)
            else:
                paras.append(p)
        return paras

    @staticmethod
    def get_image_part(graph: Paragraph, doc: Document):
        """Collect the ImageParts embedded in a paragraph by resolving each
        pic's r:embed relationship id against the document's related parts."""
        images = graph._element.xpath(".//pic:pic")
        image_parts = []
        for image in images:
            for img_id in image.xpath(".//a:blip/@r:embed"):
                part = doc.part.related_parts[img_id]
                if isinstance(part, ImagePart):
                    image_parts.append(part)
        return image_parts

    @staticmethod
    def convert_table(rows) -> List[List[str]]:
        """Flatten docx table rows into a plain list of lists of cell text;
        missing cells become empty strings."""
        resp_rows = []
        for i, row in enumerate(rows):
            resp_row = []
            for j, cell in enumerate(row.cells):
                resp_row.append(cell.text if cell is not None else '')
            resp_rows.append(resp_row)

        return resp_rows
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # IPC protocol with the Go parent process (see parseByPython on the Go
    # side): the JSON request arrives on fd 4, the raw docx bytes on stdin,
    # and the JSON response is written to fd 3. print() output goes to
    # stdout and serves only as progress logging.
    w = os.fdopen(3, "wb", )
    r = os.fdopen(4, "rb", )

    try:
        req = json.load(r)
        ei, et = req['extract_images'], req['extract_tables']
        loader = DocxLoader(file_content=io.BytesIO(sys.stdin.buffer.read()), extract_images=ei, extract_tables=et)
        resp = loader.load()
        print(f"Extracted {len(resp)} items")
        result = json.dumps({"content": resp}, ensure_ascii=False)
        w.write(str.encode(result))
        w.flush()
        w.close()
        print("Docx parse done")
    except Exception as e:
        # Report any failure back to the parent as {"error": ...} so the Go
        # side can surface it instead of hanging on an empty pipe.
        print("Docx parse error", e)
        w.write(str.encode(json.dumps({"error": str(e)})))
        w.flush()
        w.close()
|
||||
91
backend/infra/impl/document/parser/builtin/parse_image.go
Normal file
91
backend/infra/impl/document/parser/builtin/parse_image.go
Normal file
@@ -0,0 +1,91 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/cloudwego/eino/components/document/parser"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
// parseImage builds a parseFn for jpg/jpeg/png inputs. Depending on the
// parsing strategy, the image is either described by the chat model
// (ImageAnnotationTypeModel) or left with empty content for later manual
// annotation (ImageAnnotationTypeManual). The image bytes themselves are
// not stored here; only the generated description becomes the document
// content.
func parseImage(config *contract.Config, model chatmodel.BaseChatModel) parseFn {
	return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
		options := parser.GetCommonOptions(&parser.Options{}, opts...)
		doc := &schema.Document{
			MetaData: map[string]any{},
		}
		// Propagate caller-supplied metadata (e.g. document/knowledge ids).
		for k, v := range options.ExtraMeta {
			doc.MetaData[k] = v
		}

		switch config.ParsingStrategy.ImageAnnotationType {
		case contract.ImageAnnotationTypeModel:
			if model == nil {
				return nil, errorx.New(errno.ErrKnowledgeNonRetryableCode, errorx.KV("reason", "model is not provided"))
			}

			// NOTE(review): reads the whole image into memory; fine for
			// typical upload sizes but unbounded in principle.
			bytes, err := io.ReadAll(reader)
			if err != nil {
				return nil, err
			}

			// Inline the image as a data URL so the model call needs no
			// externally reachable storage.
			b64 := base64.StdEncoding.EncodeToString(bytes)
			mime := fmt.Sprintf("image/%s", config.FileExtension)
			url := fmt.Sprintf("data:%s;base64,%s", mime, b64)

			input := &schema.Message{
				Role: schema.User,
				MultiContent: []schema.ChatMessagePart{
					{
						Type: schema.ChatMessagePartTypeText,
						//Text: "Give a short description of the image.", // TODO: prompt in current language
						Text: "简短描述下这张图片",
					},
					{
						Type: schema.ChatMessagePartTypeImageURL,
						ImageURL: &schema.ChatMessageImageURL{
							URL:      url,
							MIMEType: mime,
						},
					},
				},
			}

			output, err := model.Generate(ctx, []*schema.Message{input})
			if err != nil {
				return nil, fmt.Errorf("[parseImage] model generate failed: %w", err)
			}

			doc.Content = output.Content
		case contract.ImageAnnotationTypeManual:
			// do nothing
		default:
			return nil, fmt.Errorf("[parseImage] unknown image annotation type=%d", config.ParsingStrategy.ImageAnnotationType)
		}

		return []*schema.Document{doc}, nil
	}
}
|
||||
170
backend/infra/impl/document/parser/builtin/parse_iter.go
Normal file
170
backend/infra/impl/document/parser/builtin/parse_iter.go
Normal file
@@ -0,0 +1,170 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/cloudwego/eino/components/document/parser"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
)
|
||||
|
||||
// rowIterator abstracts row-by-row access to a tabular source (xlsx sheet,
// CSV, JSON array, ...) so parseByRowIterator can share a single parsing loop
// across all table formats.
type rowIterator interface {
	// NextRow returns the next row of cell values. end is true once the
	// source is exhausted, in which case row is nil.
	NextRow() (row []string, end bool, err error)
}
|
||||
|
||||
func parseByRowIterator(iter rowIterator, config *contract.Config, opts ...parser.Option) (
|
||||
docs []*schema.Document, err error) {
|
||||
|
||||
ps := config.ParsingStrategy
|
||||
options := parser.GetCommonOptions(&parser.Options{}, opts...)
|
||||
i := 0
|
||||
columnsProvides := ps.IsAppend || len(ps.Columns) > 0
|
||||
rev := make(map[int]*document.Column)
|
||||
|
||||
var (
|
||||
expColumns []*document.Column
|
||||
expData [][]*document.ColumnData
|
||||
)
|
||||
|
||||
for {
|
||||
row, end, err := iter.NextRow()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if end {
|
||||
break
|
||||
}
|
||||
if i == ps.HeaderLine {
|
||||
if columnsProvides {
|
||||
expColumns = ps.Columns
|
||||
} else {
|
||||
for j, col := range row {
|
||||
expColumns = append(expColumns, &document.Column{
|
||||
Name: col,
|
||||
Type: document.TableColumnTypeUnknown,
|
||||
Sequence: j,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for j := range expColumns {
|
||||
tc := expColumns[j]
|
||||
rev[tc.Sequence] = tc
|
||||
}
|
||||
}
|
||||
|
||||
if i >= ps.DataStartLine {
|
||||
var rowData []*document.ColumnData
|
||||
for j := range row {
|
||||
colSchema, found := rev[j]
|
||||
if !found { // 列裁剪
|
||||
continue
|
||||
}
|
||||
|
||||
val := row[j]
|
||||
|
||||
if columnsProvides {
|
||||
var data *document.ColumnData
|
||||
if config.ParsingStrategy.IgnoreColumnTypeErr {
|
||||
data = assertValAsForce(colSchema.Type, val, colSchema.Nullable)
|
||||
} else {
|
||||
data, err = assertValAs(colSchema.Type, val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
data.ColumnID = colSchema.ID
|
||||
data.ColumnName = colSchema.Name
|
||||
rowData = append(rowData, data)
|
||||
} else {
|
||||
exp := assertVal(val)
|
||||
colSchema.Type = transformColumnType(colSchema.Type, exp.Type)
|
||||
rowData = append(rowData, &document.ColumnData{
|
||||
ColumnID: colSchema.ID,
|
||||
ColumnName: colSchema.Name,
|
||||
Type: document.TableColumnTypeUnknown,
|
||||
ValString: &val,
|
||||
})
|
||||
}
|
||||
}
|
||||
if rowData != nil {
|
||||
expData = append(expData, rowData)
|
||||
}
|
||||
}
|
||||
|
||||
i++
|
||||
if ps.RowsCount != 0 && len(docs) == ps.RowsCount {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !columnsProvides {
|
||||
// align data type when columns are provided
|
||||
for _, col := range expColumns {
|
||||
if col.Type == document.TableColumnTypeUnknown {
|
||||
col.Type = document.TableColumnTypeString
|
||||
}
|
||||
}
|
||||
for _, row := range expData {
|
||||
if err = alignTableSliceValue(expColumns, row); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(expData) == 0 {
|
||||
// return a special document with columns only if there is no data
|
||||
doc := &schema.Document{
|
||||
MetaData: map[string]any{
|
||||
document.MetaDataKeyColumns: expColumns,
|
||||
document.MetaDataKeyColumnsOnly: struct{}{},
|
||||
},
|
||||
}
|
||||
for k, v := range options.ExtraMeta {
|
||||
doc.MetaData[k] = v
|
||||
}
|
||||
return []*schema.Document{doc}, nil
|
||||
}
|
||||
|
||||
for j := range expData {
|
||||
contentMapping := make(map[string]string)
|
||||
for _, col := range expData[j] {
|
||||
contentMapping[col.ColumnName] = col.GetStringValue()
|
||||
}
|
||||
b, err := json.Marshal(contentMapping)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
doc := &schema.Document{
|
||||
Content: string(b), // set for tables in text
|
||||
MetaData: map[string]any{
|
||||
document.MetaDataKeyColumns: expColumns,
|
||||
document.MetaDataKeyColumnData: expData[j],
|
||||
},
|
||||
}
|
||||
for k, v := range options.ExtraMeta {
|
||||
doc.MetaData[k] = v
|
||||
}
|
||||
docs = append(docs, doc)
|
||||
}
|
||||
|
||||
return docs, nil
|
||||
}
|
||||
97
backend/infra/impl/document/parser/builtin/parse_json.go
Normal file
97
backend/infra/impl/document/parser/builtin/parse_json.go
Normal file
@@ -0,0 +1,97 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/cloudwego/eino/components/document/parser"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
)
|
||||
|
||||
func parseJSON(config *contract.Config) parseFn {
|
||||
return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
|
||||
b, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var rawSlices []map[string]string
|
||||
if err = json.Unmarshal(b, &rawSlices); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(rawSlices) == 0 {
|
||||
return nil, fmt.Errorf("[parseJSON] json data is empty")
|
||||
}
|
||||
|
||||
var header []string
|
||||
if config.ParsingStrategy.IsAppend {
|
||||
for _, col := range config.ParsingStrategy.Columns {
|
||||
header = append(header, col.Name)
|
||||
}
|
||||
} else {
|
||||
for k := range rawSlices[0] {
|
||||
// init 取首个 json item 中 key 的随机顺序
|
||||
header = append(header, k)
|
||||
}
|
||||
}
|
||||
|
||||
iter := &jsonIterator{
|
||||
header: header,
|
||||
rows: rawSlices,
|
||||
i: 0,
|
||||
}
|
||||
|
||||
return parseByRowIterator(iter, config, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
// jsonIterator adapts a decoded JSON array of objects to the rowIterator
// interface. The first emitted row is the header; subsequent rows are each
// object's values in header order, with missing keys yielding empty strings.
type jsonIterator struct {
	header []string            // column names, emitted as row 0
	rows   []map[string]string // data objects, one per row
	i      int                 // next position: 0 = header, k = rows[k-1]
}

// NextRow implements rowIterator.
func (j *jsonIterator) NextRow() (row []string, end bool, err error) {
	if j.i == 0 {
		j.i++
		return j.header, false, nil
	}

	// Fix: use >= instead of ==. The original only reported end on the exact
	// call after the last row; any further call indexed j.rows out of range
	// and panicked.
	if j.i >= len(j.rows)+1 {
		return nil, true, nil
	}

	raw := j.rows[j.i-1]
	j.i++
	row = make([]string, 0, len(j.header))
	for _, h := range j.header {
		// A missing key yields the map zero value "", keeping row width equal
		// to the header width.
		row = append(row, raw[h])
	}

	return row, false, nil
}
|
||||
130
backend/infra/impl/document/parser/builtin/parse_json_maps.go
Normal file
130
backend/infra/impl/document/parser/builtin/parse_json_maps.go
Normal file
@@ -0,0 +1,130 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/cloudwego/eino/components/document/parser"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
)
|
||||
|
||||
func parseJSONMaps(config *contract.Config) parseFn {
|
||||
return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
|
||||
b, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var customContent []map[string]string
|
||||
if err = json.Unmarshal(b, &customContent); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.ParsingStrategy == nil {
|
||||
config.ParsingStrategy = &contract.ParsingStrategy{
|
||||
HeaderLine: 0,
|
||||
DataStartLine: 1,
|
||||
RowsCount: 0,
|
||||
}
|
||||
}
|
||||
|
||||
iter := &customContentContainer{
|
||||
i: 0,
|
||||
colIdx: nil,
|
||||
customContent: customContent,
|
||||
curColumns: config.ParsingStrategy.Columns,
|
||||
}
|
||||
|
||||
newConfig := &contract.Config{
|
||||
FileExtension: config.FileExtension,
|
||||
ParsingStrategy: &contract.ParsingStrategy{
|
||||
SheetID: config.ParsingStrategy.SheetID,
|
||||
HeaderLine: 0,
|
||||
DataStartLine: 1,
|
||||
RowsCount: 0,
|
||||
IsAppend: config.ParsingStrategy.IsAppend,
|
||||
Columns: config.ParsingStrategy.Columns,
|
||||
},
|
||||
ChunkingStrategy: config.ChunkingStrategy,
|
||||
}
|
||||
|
||||
return parseByRowIterator(iter, newConfig, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
type customContentContainer struct {
|
||||
i int
|
||||
colIdx map[string]int
|
||||
customContent []map[string]string
|
||||
curColumns []*document.Column
|
||||
}
|
||||
|
||||
func (c *customContentContainer) NextRow() (row []string, end bool, err error) {
|
||||
if c.i == 0 && c.colIdx == nil {
|
||||
if len(c.customContent) == 0 {
|
||||
return nil, false, fmt.Errorf("[customContentContainer] data is nil")
|
||||
}
|
||||
|
||||
headerRow := c.customContent[0]
|
||||
founded := make(map[string]struct{})
|
||||
colIdx := make(map[string]int, len(headerRow))
|
||||
|
||||
for _, col := range c.curColumns {
|
||||
name := col.Name
|
||||
if _, found := headerRow[name]; found {
|
||||
founded[name] = struct{}{}
|
||||
colIdx[name] = len(colIdx)
|
||||
row = append(row, name)
|
||||
}
|
||||
}
|
||||
for name := range headerRow {
|
||||
if _, found := founded[name]; !found {
|
||||
colIdx[name] = len(colIdx)
|
||||
row = append(row, name)
|
||||
}
|
||||
}
|
||||
|
||||
c.colIdx = colIdx
|
||||
return row, false, nil
|
||||
}
|
||||
|
||||
if c.i >= len(c.customContent) {
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
content := c.customContent[c.i]
|
||||
c.i++
|
||||
row = make([]string, len(content))
|
||||
|
||||
for k, v := range content {
|
||||
idx, found := c.colIdx[k]
|
||||
if !found {
|
||||
return nil, false, fmt.Errorf("[customContentContainer] column not found, name=%s", k)
|
||||
}
|
||||
|
||||
row[idx] = v
|
||||
}
|
||||
|
||||
return row, false, nil
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/components/document/parser"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
)
|
||||
|
||||
// TestParseTableCustomContent runs parseJSONMaps end-to-end against a JSON
// array of string maps with a fully declared, typed column schema, then
// checks each produced document via the shared assertSheet helper.
func TestParseTableCustomContent(t *testing.T) {
	ctx := context.Background()
	b := []byte(`[{"col_string_indexing":"hello","col_string":"asd","col_int":"1","col_number":"1","col_bool":"true","col_time":"2006-01-02 15:04:05"},{"col_string_indexing":"bye","col_string":"","col_int":"2","col_number":"2.0","col_bool":"false","col_time":""}]`)
	reader := bytes.NewReader(b)
	config := &contract.Config{
		FileExtension: contract.FileExtensionJsonMaps,
		ParsingStrategy: &contract.ParsingStrategy{
			HeaderLine:    0,
			DataStartLine: 1,
			RowsCount:     10, // larger than the fixture's 2 rows: no truncation
			// One typed column per key in the fixture above.
			Columns: []*document.Column{
				{
					ID:       0,
					Name:     "col_string_indexing",
					Type:     document.TableColumnTypeString,
					Nullable: false,
					Sequence: 0,
				},
				{
					ID:       0,
					Name:     "col_string",
					Type:     document.TableColumnTypeString,
					Nullable: false,
					Sequence: 1,
				},
				{
					ID:       0,
					Name:     "col_int",
					Type:     document.TableColumnTypeInteger,
					Nullable: false,
					Sequence: 2,
				},
				{
					ID:       0,
					Name:     "col_number",
					Type:     document.TableColumnTypeNumber,
					Nullable: true,
					Sequence: 3,
				},
				{
					ID:       0,
					Name:     "col_bool",
					Type:     document.TableColumnTypeBoolean,
					Nullable: true,
					Sequence: 4,
				},
				{
					ID:       0,
					Name:     "col_time",
					Type:     document.TableColumnTypeTime,
					Nullable: true,
					Sequence: 5,
				},
			},
		},
	}

	pfn := parseJSONMaps(config)
	docs, err := pfn(ctx, reader, parser.WithExtraMeta(map[string]any{
		"document_id":  int64(123),
		"knowledge_id": int64(456),
	}))
	assert.NoError(t, err)
	for i, doc := range docs {
		assertSheet(t, i, doc)
	}
}
|
||||
133
backend/infra/impl/document/parser/builtin/parse_json_test.go
Normal file
133
backend/infra/impl/document/parser/builtin/parse_json_test.go
Normal file
@@ -0,0 +1,133 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/components/document/parser"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
)
|
||||
|
||||
// TestParseJSON parses a JSON array without a pre-declared column schema, so
// the column set is inferred from the first object's keys (in random map
// iteration order) and the types are inferred from the values.
func TestParseJSON(t *testing.T) {
	b := []byte(`[
	{
		"department": "心血管科",
		"title": "高血压患者能吃党参吗?",
		"question": "我有高血压这两天女婿来的时候给我拿了些党参泡水喝,您好高血压可以吃党参吗?",
		"answer": "高血压病人可以口服党参的。党参有降血脂,降血压的作用,可以彻底消除血液中的垃圾,从而对冠心病以及心血管疾病的患者都有一定的稳定预防工作作用,因此平时口服党参能远离三高的危害。另外党参除了益气养血,降低中枢神经作用,调整消化系统功能,健脾补肺的功能。感谢您的进行咨询,期望我的解释对你有所帮助。"
	},
	{
		"department": "消化科",
		"title": "哪家医院能治胃反流",
		"question": "烧心,打隔,咳嗽低烧,以有4年多",
		"answer": "建议你用奥美拉唑同时,加用吗丁啉或莫沙必利或援生力维,另外还可以加用达喜片"
	}
]`)

	reader := bytes.NewReader(b)

	config := &contract.Config{
		FileExtension: contract.FileExtensionJSON,
		ParsingStrategy: &contract.ParsingStrategy{
			HeaderLine:    0,
			DataStartLine: 1,
			RowsCount:     2,
		},
		ChunkingStrategy: nil, // table parsing does not chunk
	}
	pfn := parseJSON(config)
	docs, err := pfn(context.Background(), reader, parser.WithExtraMeta(map[string]any{
		"document_id":  int64(123),
		"knowledge_id": int64(456),
	}))
	assert.NoError(t, err)
	for i, doc := range docs {
		assertSheet(t, i, doc)
	}
}
|
||||
|
||||
// TestParseJSONWithSchema parses the same JSON fixture as TestParseJSON, but
// with an explicit typed column schema so no type inference takes place.
func TestParseJSONWithSchema(t *testing.T) {
	b := []byte(`[
	{
		"department": "心血管科",
		"title": "高血压患者能吃党参吗?",
		"question": "我有高血压这两天女婿来的时候给我拿了些党参泡水喝,您好高血压可以吃党参吗?",
		"answer": "高血压病人可以口服党参的。党参有降血脂,降血压的作用,可以彻底消除血液中的垃圾,从而对冠心病以及心血管疾病的患者都有一定的稳定预防工作作用,因此平时口服党参能远离三高的危害。另外党参除了益气养血,降低中枢神经作用,调整消化系统功能,健脾补肺的功能。感谢您的进行咨询,期望我的解释对你有所帮助。"
	},
	{
		"department": "消化科",
		"title": "哪家医院能治胃反流",
		"question": "烧心,打隔,咳嗽低烧,以有4年多",
		"answer": "建议你用奥美拉唑同时,加用吗丁啉或莫沙必利或援生力维,另外还可以加用达喜片"
	}
]`)

	reader := bytes.NewReader(b)
	config := &contract.Config{
		FileExtension: contract.FileExtensionJSON,
		ParsingStrategy: &contract.ParsingStrategy{
			HeaderLine:    0,
			DataStartLine: 1,
			RowsCount:     2,
			// One string column per fixture key, with fixed ordering.
			Columns: []*document.Column{
				{
					ID:       101,
					Name:     "department",
					Type:     document.TableColumnTypeString,
					Nullable: false,
					Sequence: 0,
				},
				{
					ID:       102,
					Name:     "title",
					Type:     document.TableColumnTypeString,
					Nullable: false,
					Sequence: 1,
				},
				{
					ID:       103,
					Name:     "question",
					Type:     document.TableColumnTypeString,
					Nullable: false,
					Sequence: 2,
				},
				{
					ID:       104,
					Name:     "answer",
					Type:     document.TableColumnTypeString,
					Nullable: false,
					Sequence: 3,
				},
			},
		},
	}
	pfn := parseJSON(config)
	docs, err := pfn(context.Background(), reader, parser.WithExtraMeta(map[string]any{
		"document_id":  int64(123),
		"knowledge_id": int64(456),
	}))
	assert.NoError(t, err)
	for i, doc := range docs {
		assertSheet(t, i, doc)
	}
}
|
||||
221
backend/infra/impl/document/parser/builtin/parse_markdown.go
Normal file
221
backend/infra/impl/document/parser/builtin/parse_markdown.go
Normal file
@@ -0,0 +1,221 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/components/document/parser"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/yuin/goldmark"
|
||||
"github.com/yuin/goldmark/ast"
|
||||
"github.com/yuin/goldmark/text"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
|
||||
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
// parseMarkdown builds a parseFn that walks a markdown AST (goldmark) and
// packs its plain text into chunks of at most cs.ChunkSize runes, split on
// cs.Separator, with optional overlap between consecutive chunks. When
// ps.ExtractImage is set, images are downloaded, re-uploaded to storage, and
// optionally OCR'd, with the resulting URI/text appended to the chunks.
func parseMarkdown(config *contract.Config, storage storage.Storage, ocr ocr.OCR) parseFn {
	return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
		options := parser.GetCommonOptions(&parser.Options{}, opts...)
		mdParser := goldmark.DefaultParser()
		b, err := io.ReadAll(reader)
		if err != nil {
			return nil, err
		}

		node := mdParser.Parse(text.NewReader(b))
		cs := config.ChunkingStrategy
		ps := config.ParsingStrategy

		if cs.ChunkType != contract.ChunkTypeCustom && cs.ChunkType != contract.ChunkTypeDefault {
			return nil, fmt.Errorf("[parseMarkdown] chunk type not support, chunk type=%d", cs.ChunkType)
		}

		// Chunking state shared by the closures below: last is the chunk
		// currently being filled; emptySlice marks it as freshly created.
		var (
			last       *schema.Document
			emptySlice bool
		)

		// addSliceContent appends text to the current chunk.
		addSliceContent := func(content string) {
			emptySlice = false
			last.Content += content
		}

		// newSlice starts a fresh chunk, copying ExtraMeta into it and, when
		// requested, seeding it with an overlap tail of the previous chunk.
		newSlice := func(needOverlap bool) {
			last = &schema.Document{
				MetaData: map[string]any{},
			}

			for k, v := range options.ExtraMeta {
				last.MetaData[k] = v
			}

			if needOverlap && cs.Overlap > 0 && len(docs) > 0 {
				overlap := getOverlap([]rune(docs[len(docs)-1].Content), cs.Overlap, cs.ChunkSize)
				addSliceContent(string(overlap))
			}

			// Mark as fresh even when overlap was added, so a chunk holding
			// only overlap text is not emitted on its own.
			emptySlice = true
		}

		// pushSlice finalizes the current chunk (if it has content) and
		// opens the next one with overlap.
		pushSlice := func() {
			if !emptySlice && last.Content != "" {
				docs = append(docs, last)
				newSlice(true)
			}
		}

		// trim applies the configured URL/email stripping and whitespace
		// normalization to one text fragment.
		trim := func(text string) string {
			if cs.TrimURLAndEmail {
				text = urlRegex.ReplaceAllString(text, "")
				text = emailRegex.ReplaceAllString(text, "")
			}
			if cs.TrimSpace {
				text = strings.TrimSpace(text)
				text = spaceRegex.ReplaceAllString(text, " ")
			}
			return text
		}

		// downloadImage fetches an image body with a 5s client timeout,
		// honoring ctx cancellation.
		downloadImage := func(ctx context.Context, url string) ([]byte, error) {
			client := &http.Client{Timeout: 5 * time.Second}
			req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
			if err != nil {
				return nil, fmt.Errorf("failed to create HTTP request: %w", err)
			}

			resp, err := client.Do(req)
			if err != nil {
				return nil, fmt.Errorf("failed to download image: %w", err)
			}
			defer resp.Body.Close()

			if resp.StatusCode != http.StatusOK {
				return nil, fmt.Errorf("failed to download image, status code: %d", resp.StatusCode)
			}

			data, err := io.ReadAll(resp.Body)
			if err != nil {
				return nil, fmt.Errorf("failed to read image content: %w", err)
			}

			return data, nil
		}

		// walker visits every AST node, accumulating text chunks and
		// handling images.
		walker := func(n ast.Node, entering bool) (ast.WalkStatus, error) {
			if !entering {
				return ast.WalkContinue, nil
			}

			switch n.Kind() {
			case ast.KindText:
				// Only leaf text nodes carry content; containers are covered
				// by their children.
				if n.HasChildren() {
					break
				}
				textNode := n.(*ast.Text)
				plainText := trim(string(textNode.Segment.Value(b)))

				// Split on the separator, then pack each piece into chunks of
				// at most cs.ChunkSize runes, pushing a chunk whenever full.
				for _, part := range strings.Split(plainText, cs.Separator) {
					runes := []rune(part)
					for partLength := int64(len(runes)); partLength > 0; partLength = int64(len(runes)) {
						pos := min(partLength, cs.ChunkSize-charCount(last.Content))
						chunk := runes[:pos]
						addSliceContent(string(chunk))
						runes = runes[pos:]
						if charCount(last.Content) >= cs.ChunkSize {
							pushSlice()
						}
					}
				}

			case ast.KindImage:
				if !ps.ExtractImage {
					break
				}

				imageNode := n.(*ast.Image)

				// NOTE(review): this inner check is always true — the guard
				// above already broke out when ExtractImage is false.
				if ps.ExtractImage {
					imageURL := string(imageNode.Destination)
					if _, err = url.ParseRequestURI(imageURL); err == nil {
						// Take the extension from the last dot-separated
						// segment of the URL. NOTE(review): a query string
						// would leak into ext here — confirm inputs.
						sp := strings.Split(imageURL, ".")
						if len(sp) == 0 {
							return ast.WalkStop, fmt.Errorf("failed to extract image extension, url=%s", imageURL)
						}
						ext := sp[len(sp)-1]

						img, err := downloadImage(ctx, imageURL)
						if err != nil {
							return ast.WalkStop, fmt.Errorf("failed to download image: %w", err)
						}

						// Re-host the image in our storage and embed its URI
						// in the chunk stream.
						imgSrc, err := putImageObject(ctx, storage, ext, getCreatorIDFromExtraMeta(options.ExtraMeta), img)
						if err != nil {
							return ast.WalkStop, err
						}

						// Images start a fresh chunk.
						if !emptySlice && last.Content != "" {
							pushSlice()
						} else {
							newSlice(false)
						}

						addSliceContent(fmt.Sprintf("\n%s\n", imgSrc))

						// Optionally append OCR'd text below the image URI.
						if ps.ImageOCR && ocr != nil {
							texts, err := ocr.FromBase64(ctx, base64.StdEncoding.EncodeToString(img))
							if err != nil {
								return ast.WalkStop, fmt.Errorf("failed to perform OCR on image: %w", err)
							}
							addSliceContent(strings.Join(texts, "\n"))
						}

						if charCount(last.Content) >= cs.ChunkSize {
							pushSlice()
						}
					} else {
						logs.CtxInfof(ctx, "[parseMarkdown] not a valid image url, skip, got=%s", imageURL)
					}
				}
			}

			return ast.WalkContinue, nil
		}

		newSlice(false)

		if err = ast.Walk(node, walker); err != nil {
			return nil, err
		}

		// Flush the trailing, partially filled chunk.
		if !emptySlice {
			pushSlice()
		}

		return docs, nil
	}
}
|
||||
@@ -0,0 +1,75 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/components/document/parser"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
ms "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/storage"
|
||||
)
|
||||
|
||||
// TestParseMarkdown runs the markdown parser end-to-end against a fixture
// file, with storage mocked (uploads always succeed) and OCR disabled (nil),
// then checks every produced chunk via assertDoc.
func TestParseMarkdown(t *testing.T) {
	ctx := context.Background()
	ctrl := gomock.NewController(t)
	mockStorage := ms.NewMockStorage(ctrl)
	mockStorage.EXPECT().PutObject(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()

	pfn := parseMarkdown(&contract.Config{
		FileExtension: contract.FileExtensionMarkdown,
		ParsingStrategy: &contract.ParsingStrategy{
			ExtractImage: true,
			ExtractTable: true,
			ImageOCR:     true, // flag is set, but the ocr dependency below is nil
		},
		ChunkingStrategy: &contract.ChunkingStrategy{
			ChunkType:       contract.ChunkTypeCustom,
			ChunkSize:       800,
			Separator:       "\n",
			Overlap:         10,
			TrimSpace:       true,
			TrimURLAndEmail: true,
		},
	}, mockStorage, nil)

	f, err := os.Open("test_data/test_markdown.md")
	assert.NoError(t, err)
	docs, err := pfn(ctx, f, parser.WithExtraMeta(map[string]any{
		"document_id":  int64(123),
		"knowledge_id": int64(456),
	}))
	assert.NoError(t, err)
	for _, doc := range docs {
		assertDoc(t, doc)
	}
}
|
||||
|
||||
// assertDoc checks that a parsed chunk has non-empty content and carries the
// extra metadata attached via parser.WithExtraMeta in the test setup.
func assertDoc(t *testing.T, doc *schema.Document) {
	assert.NotZero(t, doc.Content)
	fmt.Println(doc.Content) // echo chunk content for manual inspection
	assert.NotNil(t, doc.MetaData)
	assert.Equal(t, int64(123), doc.MetaData["document_id"].(int64))
	assert.Equal(t, int64(456), doc.MetaData["knowledge_id"].(int64))
}
|
||||
152
backend/infra/impl/document/parser/builtin/parse_pdf.py
Normal file
152
backend/infra/impl/document/parser/builtin/parse_pdf.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import base64
|
||||
|
||||
from typing import Literal
|
||||
import pdfplumber
|
||||
from PIL import Image, ImageChops
|
||||
from pdfminer.pdfcolor import (
|
||||
LITERAL_DEVICE_CMYK,
|
||||
)
|
||||
from pdfminer.pdftypes import (
|
||||
LITERALS_DCT_DECODE,
|
||||
LITERALS_FLATE_DECODE,
|
||||
)
|
||||
|
||||
def bbox_overlap(bbox1, bbox2):
    """Return the overlap ratio of two (x0, y0, x1, y1) bounding boxes.

    The ratio is the intersection area divided by the smaller box's area, so
    a box fully contained in another scores 1.0. Returns 0 when either box
    is degenerate (zero area).
    """
    ax0, ay0, ax1, ay1 = bbox1
    bx0, by0, bx1, by1 = bbox2

    inter_w = max(0, min(ax1, bx1) - max(ax0, bx0))
    inter_h = max(0, min(ay1, by1) - max(ay0, by0))
    inter_area = inter_w * inter_h

    area_a = (ax1 - ax0) * (ay1 - ay0)
    area_b = (bx1 - bx0) * (by1 - by0)
    if area_a == 0 or area_b == 0:
        return 0

    return inter_area / min(area_a, area_b)
|
||||
|
||||
|
||||
def is_structured_table(table):
    """Heuristic filter: a table counts as "structured" when it has at least
    two rows and at least one row with two or more columns. Empty or falsy
    tables are never structured."""
    if not table:
        return False
    rows = len(table)
    cols = max(len(r) for r in table)
    return rows >= 2 and cols >= 2
|
||||
|
||||
|
||||
def extract_pdf_content(pdf_data: bytes, extract_images, extract_tables: bool, filter_pages: []):
    """Extract text, images, and tables from a PDF (via pdfplumber).

    pdf_data: raw PDF bytes.
    extract_images / extract_tables: toggle the respective extractors.
    filter_pages: 1-based page numbers to skip, or None to process all pages.

    Returns a list of dicts with keys 'type' ('text' / 'image' / 'table'),
    'page', 'bbox', and 'content' (text or base64 PNG) or 'table' (cell rows),
    sorted by page then reading position, with text items that overlap a
    table region filtered out.
    """
    with pdfplumber.open(io.BytesIO(pdf_data)) as pdf:
        content = []

        for page_num, page in enumerate(pdf.pages):
            # filter_pages holds 1-based page numbers to skip.
            if filter_pages is not None and page_num + 1 in filter_pages:
                print(f"Skip page {page_num + 1}...")
                continue
            print(f"Processing page {page_num + 1}...")
            text = page.extract_text(x_tolerance=2)
            content.append({
                'type': 'text',
                'content': text,
                'page': page_num + 1,
                'bbox': page.bbox
            })

            if extract_images:
                images = page.images
                for img_index, img in enumerate(images):
                    try:
                        filters = img['stream'].get_filters()
                        data = img['stream'].get_data()
                        buffered = io.BytesIO()

                        # JPEG-encoded streams: CMYK JPEGs in PDFs are stored
                        # inverted, so invert and convert to RGB before saving;
                        # other JPEGs pass through as-is.
                        if filters[-1][0] in LITERALS_DCT_DECODE:
                            if LITERAL_DEVICE_CMYK in img['colorspace']:
                                i = Image.open(io.BytesIO(data))
                                i = ImageChops.invert(i)
                                i = i.convert("RGB")
                                i.save(buffered, format="PNG")
                            else:
                                buffered.write(data)

                        # Flate-encoded raw bitmaps: rebuild a PIL image from
                        # the raw samples, picking a mode from bit depth and
                        # channel count.
                        elif len(filters) == 1 and filters[0][0] in LITERALS_FLATE_DECODE:
                            width, height = img['srcsize']
                            channels = len(img['stream'].get_data()) / width / height / (img['bits'] / 8)
                            mode: Literal["1", "L", "RGB", "CMYK"]
                            if img['bits'] == 1:
                                mode = "1"
                            elif img['bits'] == 8 and channels == 1:
                                mode = "L"
                            elif img['bits'] == 8 and channels == 3:
                                mode = "RGB"
                            elif img['bits'] == 8 and channels == 4:
                                mode = "CMYK"
                            # NOTE(review): mode is unbound for other
                            # bits/channels combinations; the resulting
                            # UnboundLocalError is swallowed by the except
                            # below, skipping the image — confirm intended.
                            i = Image.frombytes(mode, img['srcsize'], data, "raw")
                            i.save(buffered, format="PNG")
                        else:
                            # Unknown filter chain: emit the raw stream bytes.
                            buffered.write(data)
                        content.append({
                            'type': 'image',
                            'content': base64.b64encode(buffered.getvalue()).decode('utf-8'),
                            'page': page_num + 1,
                            'bbox': (img['x0'], img['top'], img['x1'], img['bottom'])
                        })
                    except Exception as err:
                        # Best-effort: an undecodable image is skipped, not fatal.
                        print(f"Skipping an unsupported image on page {page_num + 1}, error message: {err}")

            if extract_tables:
                tables = page.extract_tables()
                for table in tables:
                    content.append({
                        'type': 'table',
                        'table': table,
                        'page': page_num + 1,
                        'bbox': page.bbox
                    })

        # Sort into reading order: page, then top coordinate, then left.
        content.sort(key=lambda x: (x['page'], x['bbox'][1], x['bbox'][0]))

        # Keep structured tables; drop any other item whose bbox mostly
        # overlaps an already-kept text item (de-duplicates text that was
        # extracted twice).
        filtered_content = []
        for item in content:
            if item['type'] == 'table':
                if is_structured_table(item['table']):
                    filtered_content.append(item)
                continue
            overlap_found = False
            for existing_item in filtered_content:
                if existing_item['type'] == 'text' and bbox_overlap(item['bbox'], existing_item['bbox']) > 0.8:
                    overlap_found = True
                    break
            if overlap_found:
                continue
            filtered_content.append(item)

        return filtered_content
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Protocol with the Go parent process: the request JSON arrives on fd 4,
    # the PDF bytes on stdin, and the JSON result is written back on fd 3.
    # Anything printed to stdout is treated as progress/diagnostic output.
    w = os.fdopen(3, "wb", )
    r = os.fdopen(4, "rb", )
    pdf_data = sys.stdin.buffer.read()
    print(f"Read {len(pdf_data)} bytes of PDF data")

    try:
        req = json.load(r)
        ei, et, fp = req['extract_images'], req['extract_tables'], req['filter_pages']
        extracted_content = extract_pdf_content(pdf_data, ei, et, fp)
        print(f"Extracted {len(extracted_content)} items")
        result = json.dumps({"content": extracted_content}, ensure_ascii=False)
        w.write(str.encode(result))
        w.flush()
        w.close()
        print("Pdf parse done")
    except Exception as e:
        # Report failures to the parent as {"error": ...} instead of dying
        # silently, so the Go side can surface the message.
        print("Pdf parse error", e)
        w.write(str.encode(json.dumps({"error": str(e)})))
        w.flush()
        w.close()
|
||||
49
backend/infra/impl/document/parser/builtin/parse_text.go
Normal file
49
backend/infra/impl/document/parser/builtin/parse_text.go
Normal file
@@ -0,0 +1,49 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/cloudwego/eino/components/document/parser"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
)
|
||||
|
||||
func parseText(config *contract.Config) parseFn {
|
||||
return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
|
||||
content, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch config.ChunkingStrategy.ChunkType {
|
||||
case contract.ChunkTypeCustom, contract.ChunkTypeDefault:
|
||||
docs, err = chunkCustom(ctx, string(content), config, opts...)
|
||||
default:
|
||||
return nil, fmt.Errorf("[parseText] chunk type not support, type=%d", config.ChunkingStrategy.ChunkType)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return docs, nil
|
||||
}
|
||||
}
|
||||
78
backend/infra/impl/document/parser/builtin/parse_xlsx.go
Normal file
78
backend/infra/impl/document/parser/builtin/parse_xlsx.go
Normal file
@@ -0,0 +1,78 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/cloudwego/eino/components/document/parser"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/xuri/excelize/v2"
|
||||
|
||||
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
)
|
||||
|
||||
func parseXLSX(config *contract.Config) parseFn {
|
||||
return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
|
||||
f, err := excelize.OpenReader(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sheetID := 0
|
||||
if config.ParsingStrategy.SheetID != nil {
|
||||
sheetID = *config.ParsingStrategy.SheetID
|
||||
}
|
||||
|
||||
rows, err := f.Rows(f.GetSheetName(sheetID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iter := &xlsxIterator{rows, 0}
|
||||
|
||||
return parseByRowIterator(iter, config, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
type xlsxIterator struct {
|
||||
rows *excelize.Rows
|
||||
firstRowSize int
|
||||
}
|
||||
|
||||
func (x *xlsxIterator) NextRow() (row []string, end bool, err error) {
|
||||
end = !x.rows.Next()
|
||||
if end {
|
||||
return nil, end, nil
|
||||
}
|
||||
|
||||
row, err = x.rows.Columns()
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
if x.firstRowSize == 0 {
|
||||
x.firstRowSize = len(row)
|
||||
} else if x.firstRowSize > len(row) {
|
||||
row = append(row, make([]string, x.firstRowSize-len(row))...)
|
||||
} else if x.firstRowSize < len(row) {
|
||||
row = row[:x.firstRowSize]
|
||||
}
|
||||
|
||||
return row, false, nil
|
||||
}
|
||||
171
backend/infra/impl/document/parser/builtin/parse_xlsx_test.go
Normal file
171
backend/infra/impl/document/parser/builtin/parse_xlsx_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/components/document/parser"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
)
|
||||
|
||||
// TestParseXLSX parses the bundled test workbook with an explicit six-column
// schema whose declared types match the sheet data, then checks every
// resulting document via the shared assertSheet helper.
func TestParseXLSX(t *testing.T) {
	ctx := context.Background()
	b, err := os.ReadFile("./test_data/test_xlsx.xlsx")
	assert.NoError(t, err)
	reader := bytes.NewReader(b)
	// Column types here mirror the fixture's actual cell contents, so no
	// type-conversion fallbacks should be exercised in this test.
	config := &contract.Config{
		FileExtension: contract.FileExtensionXLSX,
		ParsingStrategy: &contract.ParsingStrategy{
			HeaderLine:    0,
			DataStartLine: 1,
			RowsCount:     10,
			Columns: []*document.Column{
				{
					ID:       0,
					Name:     "col_string_indexing",
					Type:     document.TableColumnTypeString,
					Nullable: false,
					Sequence: 0,
				},
				{
					ID:       0,
					Name:     "col_string",
					Type:     document.TableColumnTypeString,
					Nullable: true,
					Sequence: 1,
				},
				{
					ID:       0,
					Name:     "col_int",
					Type:     document.TableColumnTypeInteger,
					Nullable: false,
					Sequence: 2,
				},
				{
					ID:       0,
					Name:     "col_number",
					Type:     document.TableColumnTypeNumber,
					Nullable: true,
					Sequence: 3,
				},
				{
					ID:       0,
					Name:     "col_bool",
					Type:     document.TableColumnTypeBoolean,
					Nullable: true,
					Sequence: 4,
				},
				{
					ID:       0,
					Name:     "col_time",
					Type:     document.TableColumnTypeTime,
					Nullable: true,
					Sequence: 5,
				},
			},
		},
		ChunkingStrategy: nil,
	}

	pfn := parseXLSX(config)
	docs, err := pfn(ctx, reader, parser.WithExtraMeta(map[string]any{
		"document_id":  int64(123),
		"knowledge_id": int64(456),
	}))
	assert.NoError(t, err)
	for i, doc := range docs {
		assertSheet(t, i, doc)
	}
}
|
||||
|
||||
// TestParseXLSXConvertColumnType parses the same workbook but with column
// types that deliberately mismatch the sheet data, exercising the
// IgnoreColumnTypeErr conversion path (mismatches become nulls or string
// conversions instead of hard errors).
func TestParseXLSXConvertColumnType(t *testing.T) {
	ctx := context.Background()
	b, err := os.ReadFile("./test_data/test_xlsx.xlsx")
	assert.NoError(t, err)
	reader := bytes.NewReader(b)
	config := &contract.Config{
		FileExtension: contract.FileExtensionXLSX,
		ParsingStrategy: &contract.ParsingStrategy{
			HeaderLine:          0,
			DataStartLine:       1,
			RowsCount:           10,
			IgnoreColumnTypeErr: true,
			Columns: []*document.Column{
				{
					ID:       0,
					Name:     "col_string_indexing",
					Type:     document.TableColumnTypeString,
					Nullable: false,
					Sequence: 0,
				},
				{
					ID:       0,
					Name:     "col_string",
					Type:     document.TableColumnTypeInteger, // string -> int: null
					Nullable: true,
					Sequence: 1,
				},
				{
					ID:       0,
					Name:     "col_int",
					Type:     document.TableColumnTypeString, // int -> string: strconv
					Nullable: false,
					Sequence: 2,
				},
				{
					ID:       0,
					Name:     "col_number",
					Type:     document.TableColumnTypeString, // float -> string: strconv
					Nullable: true,
					Sequence: 3,
				},
				// The bool and time columns are intentionally left out of this
				// conversion test; kept here for reference.
				//{
				//	ID:       0,
				//	Name:     "col_bool",
				//	Type:     document.TableColumnTypeBoolean, // trim
				//	Nullable: true,
				//	Sequence: 4,
				//},
				//{
				//	ID:       0,
				//	Name:     "col_time",
				//	Type:     document.TableColumnTypeTime, // trim
				//	Nullable: true,
				//	Sequence: 5,
				//},
			},
		},
		ChunkingStrategy: nil,
	}

	pfn := parseXLSX(config)
	docs, err := pfn(ctx, reader, parser.WithExtraMeta(map[string]any{
		"document_id":  int64(123),
		"knowledge_id": int64(456),
	}))
	assert.NoError(t, err)
	for i, doc := range docs {
		assertSheet(t, i, doc)
	}
}
|
||||
35
backend/infra/impl/document/parser/builtin/parser.go
Normal file
35
backend/infra/impl/document/parser/builtin/parser.go
Normal file
@@ -0,0 +1,35 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/cloudwego/eino/components/document/parser"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// p adapts a bare parseFn to the eino parser.Parser interface via embedding.
type p struct {
	parseFn
}

// Parse implements parser.Parser by delegating directly to the wrapped parseFn.
func (p p) Parse(ctx context.Context, reader io.Reader, opts ...parser.Option) ([]*schema.Document, error) {
	return p.parseFn(ctx, reader, opts...)
}

// parseFn is the common signature shared by every builtin format parser in
// this package (text, xlsx, python-backed PDF, ...).
type parseFn func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error)
|
||||
270
backend/infra/impl/document/parser/builtin/py_parser_protocol.go
Normal file
270
backend/infra/impl/document/parser/builtin/py_parser_protocol.go
Normal file
@@ -0,0 +1,270 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/components/document/parser"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
|
||||
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
)
|
||||
|
||||
// Content type discriminators used in the JSON exchanged with the Python
// PDF-extraction helper script.
const (
	contentTypeText  = "text"
	contentTypeImage = "image"
	contentTypeTable = "table"
)

// pyParseRequest is the request payload written to the Python helper,
// selecting which optional extractions to run and which pages to keep.
type pyParseRequest struct {
	ExtractImages bool  `json:"extract_images"`
	ExtractTables bool  `json:"extract_tables"`
	FilterPages   []int `json:"filter_pages"`
}

// pyParseResult is the response payload read back from the Python helper.
// Error is non-empty when the script failed; otherwise Content holds the
// extracted items in page order.
type pyParseResult struct {
	Error   string            `json:"error"`
	Content []*pyParseContent `json:"content"`
}

// pyParseContent is one extracted item. Type is one of the contentType*
// constants; Content carries text (or base64 image data), Table carries
// table cells, and Page is the 1-based source page number.
type pyParseContent struct {
	Type    string     `json:"type"`
	Content string     `json:"content"`
	Table   [][]string `json:"table"`
	Page    int        `json:"page"`
}
|
||||
|
||||
// pyPDFTableIterator walks the rows of a table extracted by the Python
// helper, yielding one row per NextRow call.
type pyPDFTableIterator struct {
	i    int        // index of the next row to hand out
	rows [][]string // all extracted table rows
}

// NextRow returns the next table row; end is true once every row has been
// consumed. It never returns an error.
func (p *pyPDFTableIterator) NextRow() ([]string, bool, error) {
	if p.i >= len(p.rows) {
		return nil, true, nil
	}

	next := p.rows[p.i]
	p.i++
	return next, false, nil
}
|
||||
|
||||
// parseByPython builds a parseFn that delegates PDF extraction to an external
// Python script. Protocol: the PDF bytes are streamed over the child's stdin;
// the JSON request is written into a pipe exposed to the child as fd 4, and
// the JSON result is read from a pipe the child writes on fd 3 (the fd
// numbers follow the order of cmd.ExtraFiles below). Extracted text is
// chunked, images are uploaded to storage (optionally OCR'd), and tables are
// re-fed through the row-iterator pipeline as CSV-like data.
func parseByPython(config *contract.Config, storage storage.Storage, ocr ocr.OCR, pyPath, scriptPath string) parseFn {
	return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
		// pr/pw: request pipe (parent writes, child reads on fd 4).
		pr, pw, err := os.Pipe()
		if err != nil {
			return nil, fmt.Errorf("[parseByPython] create rpipe failed, %w", err)
		}
		// r/w: result pipe (child writes on fd 3, parent reads).
		r, w, err := os.Pipe()
		if err != nil {
			return nil, fmt.Errorf("[parseByPython] create pipe failed: %w", err)
		}
		options := parser.GetCommonOptions(&parser.Options{ExtraMeta: map[string]any{}}, opts...)

		reqb, err := json.Marshal(pyParseRequest{
			ExtractImages: config.ParsingStrategy.ExtractImage,
			ExtractTables: config.ParsingStrategy.ExtractTable,
			FilterPages:   config.ParsingStrategy.FilterPages,
		})
		if err != nil {
			return nil, fmt.Errorf("[parseByPython] create parse request failed, %w", err)
		}
		// Write the full request before starting the child; closing pw lets the
		// child's json.load see EOF.
		if _, err = pw.Write(reqb); err != nil {
			return nil, fmt.Errorf("[parseByPython] write parse request bytes failed, %w", err)
		}
		if err = pw.Close(); err != nil {
			return nil, fmt.Errorf("[parseByPython] close write request pipe failed, %w", err)
		}

		cmd := exec.Command(pyPath, scriptPath)
		cmd.Stdin = reader
		cmd.Stdout = os.Stdout
		// ExtraFiles[0] becomes child fd 3 (result write end), ExtraFiles[1]
		// becomes child fd 4 (request read end).
		cmd.ExtraFiles = []*os.File{w, pr}
		if err = cmd.Start(); err != nil {
			return nil, fmt.Errorf("[parseByPython] failed to start Python script: %w", err)
		}
		// Close the parent's copy of the result write end so the decoder below
		// observes EOF when the child finishes.
		if err = w.Close(); err != nil {
			return nil, fmt.Errorf("[parseByPython] failed to close write pipe: %w", err)
		}

		result := &pyParseResult{}

		if err = json.NewDecoder(r).Decode(result); err != nil {
			return nil, fmt.Errorf("[parseByPython] failed to decode result: %w", err)
		}
		if err = cmd.Wait(); err != nil {
			return nil, fmt.Errorf("[parseByPython] cmd wait err: %w", err)
		}

		if result.Error != "" {
			return nil, fmt.Errorf("[parseByPython] python execution failed: %s", result.Error)
		}

		for i, item := range result.Content {
			switch item.Type {
			case contentTypeText:
				partDocs, err := chunkCustom(ctx, item.Content, config, opts...)
				if err != nil {
					return nil, fmt.Errorf("[parseByPython] chunk text failed, %w", err)
				}
				docs = append(docs, partDocs...)
			case contentTypeImage:
				if !config.ParsingStrategy.ExtractImage {
					continue
				}
				// item.Content is base64-encoded image bytes.
				image, err := base64.StdEncoding.DecodeString(item.Content)
				if err != nil {
					return nil, fmt.Errorf("[parseByPython] decode image failed, %w", err)
				}
				imgSrc, err := putImageObject(ctx, storage, "png", getCreatorIDFromExtraMeta(options.ExtraMeta), image)
				if err != nil {
					return nil, err
				}
				label := fmt.Sprintf("\n%s", imgSrc)
				if config.ParsingStrategy.ImageOCR && ocr != nil {
					texts, err := ocr.FromBase64(ctx, item.Content)
					if err != nil {
						return nil, fmt.Errorf("[parseByPython] FromBase64 failed, %w", err)
					}
					label += strings.Join(texts, "\n")
				}

				// If the next item is text, prepend the image label to it so the
				// image stays adjacent to its surrounding text; otherwise emit
				// the label as its own document.
				if i == len(result.Content)-1 || result.Content[i+1].Type != "text" {
					doc := &schema.Document{
						Content:  label,
						MetaData: map[string]any{},
					}
					for k, v := range options.ExtraMeta {
						doc.MetaData[k] = v
					}
					docs = append(docs, doc)
				} else {
					// TODO: minor issue here — the img label may be truncated by
					// a chunk size smaller than the label itself.
					result.Content[i+1].Content = label + result.Content[i+1].Content
				}
			case contentTypeTable:
				if !config.ParsingStrategy.ExtractTable {
					continue
				}
				// Feed the table through the row-iterator pipeline as if it were
				// CSV: first row is the header, the rest is data.
				iterator := &pyPDFTableIterator{i: 0, rows: item.Table}
				rawTableDocs, err := parseByRowIterator(iterator, &contract.Config{
					FileExtension: contract.FileExtensionCSV,
					ParsingStrategy: &contract.ParsingStrategy{
						HeaderLine:    0,
						DataStartLine: 1,
						RowsCount:     0,
					},
					ChunkingStrategy: config.ChunkingStrategy,
				}, opts...)
				if err != nil {
					return nil, fmt.Errorf("[parseByPython] parse table failed, %w", err)
				}
				fmtTableDocs, err := formatTablesInDocument(rawTableDocs)
				if err != nil {
					return nil, fmt.Errorf("[parseByPython] format table failed, %w", err)
				}
				docs = append(docs, fmtTableDocs...)
			default:
				return nil, fmt.Errorf("[parseByPython] invalid content type: %s", item.Type)
			}
		}

		return docs, nil
	}
}
|
||||
|
||||
// formatTablesInDocument renders row documents produced by the table pipeline
// into HTML <table> documents. The header row (taken from the first input
// document's column metadata) is written first, then one <tr> per row. Output
// is split into multiple documents whenever the accumulated HTML would exceed
// maxSize bytes. Metadata is copied from the first document, minus the
// per-row column data key.
func formatTablesInDocument(input []*schema.Document) (output []*schema.Document, err error) {
	const (
		maxSize              = 65535
		tableStart, tableEnd = "<table>", "</table>"
	)

	var (
		buffer   strings.Builder
		firstDoc *schema.Document
	)

	endSize := len(tableEnd)
	buffer.WriteString(tableStart)

	// push finalizes the current buffer into a new document and starts a
	// fresh <table>. It relies on firstDoc being set, which holds whenever
	// input is non-empty (push is only reached after at least one write).
	push := func() {
		newDoc := &schema.Document{
			Content:  buffer.String() + tableEnd,
			MetaData: map[string]any{},
		}
		for k, v := range firstDoc.MetaData {
			// Row-level column data would be misleading on a merged table doc.
			if k == document.MetaDataKeyColumnData {
				continue
			}
			newDoc.MetaData[k] = v
		}
		output = append(output, newDoc)
		buffer.Reset()
		buffer.WriteString(tableStart)
	}

	// write appends one <tr> row and flushes the buffer if adding the closing
	// tag would push it past maxSize.
	write := func(contents []string) {
		row := fmt.Sprintf("<tr><td>%s</td></tr>", strings.Join(contents, "</td><td>"))
		buffer.WriteString(row)
		if buffer.Len()+endSize >= maxSize {
			push()
		}
	}

	for i := range input {
		doc := input[i]

		if i == 0 {
			// First document supplies both the metadata template and the
			// column names used as the table header row.
			firstDoc = doc
			cols, err := document.GetDocumentColumns(doc)
			if err != nil {
				return nil, fmt.Errorf("[formatTablesInDocument] invalid table columns, %w", err)
			}
			values := make([]string, 0, len(cols))
			for _, col := range cols {
				values = append(values, col.Name)
			}
			write(values)
		}

		data, err := document.GetDocumentColumnData(doc)
		if err != nil {
			return nil, fmt.Errorf("[formatTablesInDocument] invalid table data, %w", err)
		}
		values := make([]string, 0, len(data))
		for _, col := range data {
			values = append(values, col.GetNullableStringValue())
		}
		write(values)
	}

	// Flush any rows accumulated since the last push; skip when the buffer
	// holds only the opening tag (empty input or an exact-boundary flush).
	if buffer.String() != tableStart {
		push()
	}

	return
}
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 65 KiB |
BIN
backend/infra/impl/document/parser/builtin/test_data/logo.png
Normal file
BIN
backend/infra/impl/document/parser/builtin/test_data/logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 38 KiB |
@@ -0,0 +1,3 @@
|
||||
col_string_indexing,col_string,col_int,col_number,col_bool,col_time
|
||||
hello,asd,1,1.0,TRUE,2006-01-02 15:04:02
|
||||
bye,,2,2.0,TRUE,
|
||||
|
File diff suppressed because one or more lines are too long
@@ -0,0 +1 @@
|
||||
col_string_indexing,col_string,col_int,col_number,col_bool,col_time
|
||||
|
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,272 @@
|
||||
# 1. 欢迎使用 Cmd Markdown 编辑阅读器
|
||||
<!-- TOC -->
|
||||
|
||||
- [1. 欢迎使用 Cmd Markdown 编辑阅读器](#1-欢迎使用-cmd-markdown-编辑阅读器)
|
||||
- [1.1. markdown扩展需求](#11-markdown扩展需求)
|
||||
- [1.1.1. 一、各种流程图](#111-一各种流程图)
|
||||
- [1.1.2. [Windows/Mac/Linux 全平台客户端](https://www.zybuluo.com/cmd/)](#112-windowsmaclinux-全平台客户端httpswwwzybuluocomcmd)
|
||||
- [1.2. 什么是 Markdown](#12-什么是-markdown)
|
||||
- [1.2.1. 制作一份待办事宜 [Todo 列表](https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown#13-待办事宜-todo-列表)](#121-制作一份待办事宜-todo-列表httpswwwzybuluocommdeditorurlhttpswwwzybuluocomstaticeditormd-helpmarkdown13-待办事宜-todo-列表)
|
||||
- [1.2.2. 书写一个质能守恒公式[^LaTeX]](#122-书写一个质能守恒公式^latex)
|
||||
- [1.2.3. 高亮一段代码[^code]](#123-高亮一段代码^code)
|
||||
- [1.2.4. 高效绘制 [流程图](https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown#7-流程图)](#124-高效绘制-流程图httpswwwzybuluocommdeditorurlhttpswwwzybuluocomstaticeditormd-helpmarkdown7-流程图)
|
||||
- [1.2.5. 高效绘制 [序列图](https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown#8-序列图)](#125-高效绘制-序列图httpswwwzybuluocommdeditorurlhttpswwwzybuluocomstaticeditormd-helpmarkdown8-序列图)
|
||||
- [1.2.6. 绘制表格](#126-绘制表格)
|
||||
- [1.2.7. 更详细语法说明](#127-更详细语法说明)
|
||||
- [1.3. 什么是 Cmd Markdown](#13-什么是-cmd-markdown)
|
||||
- [1.3.1. 实时同步预览](#131-实时同步预览)
|
||||
- [1.3.2. 编辑工具栏](#132-编辑工具栏)
|
||||
- [1.3.3. 编辑模式](#133-编辑模式)
|
||||
- [1.3.4. 实时的云端文稿](#134-实时的云端文稿)
|
||||
- [1.3.5. 离线模式](#135-离线模式)
|
||||
- [1.3.6. 管理工具栏](#136-管理工具栏)
|
||||
- [1.3.7. 阅读工具栏](#137-阅读工具栏)
|
||||
- [1.3.8. 阅读模式](#138-阅读模式)
|
||||
- [1.3.9. 标签、分类和搜索](#139-标签分类和搜索)
|
||||
- [1.3.10. 文稿发布和分享](#1310-文稿发布和分享)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
[ ] dddd
|
||||
[x] xxxx
|
||||
第一行
|
||||
第二行
|
||||
------
|
||||
> 一个快速笔记工具,可生成网页快速分享。
|
||||
## 1.1. markdown扩展需求
|
||||
1. 目录
|
||||
2. 表情
|
||||
3. 粘贴截图
|
||||
4. 流程图、时序图
|
||||
5. 数学公式
|
||||
6. 标签
|
||||
7. 简单动画
|
||||
|
||||
|
||||
|
||||
### 1.1.1. 一、各种流程图
|
||||
1. 时序图
|
||||
|
||||
```seq
|
||||
Alice->Bob: Hello Bob, how are you?
|
||||
Note right of Bob: Bob thinks
|
||||
Bob-->Alice: I am good thanks!
|
||||
```
|
||||
|
||||
2. 流程图
|
||||
|
||||
```flow
|
||||
st=>start: Start
|
||||
op=>operation: Your Operation
|
||||
cond=>condition: Yes or No?
|
||||
e=>end
|
||||
|
||||
st->op->cond
|
||||
cond(yes)->e
|
||||
cond(no)->op
|
||||
```
|
||||
|
||||
3. 甘特图
|
||||
|
||||
```gantt
|
||||
title 项目开发流程
|
||||
section 项目确定
|
||||
需求分析 :a1, 2016-06-22, 3d
|
||||
可行性报告 :after a1, 5d
|
||||
概念验证 : 5d
|
||||
section 项目实施
|
||||
概要设计 :2016-07-05, 5d
|
||||
详细设计 :2016-07-08, 10d
|
||||
编码 :2016-07-15, 10d
|
||||
测试 :2016-07-22, 5d
|
||||
section 发布验收
|
||||
发布: 2d
|
||||
验收: 3d
|
||||
```
|
||||
|
||||
4. Mermaid 流程图
|
||||
|
||||
```graphLR
|
||||
A[Hard edge] -->|Link text| B(Round edge)
|
||||
B --> C{Decision}
|
||||
C -->|One| D[Result one]
|
||||
C -->|Two| E[Result two]
|
||||
```
|
||||
|
||||
5. Mermaid 序列图
|
||||
|
||||
```sequence
|
||||
Alice->John: Hello John, how are you?
|
||||
loop every minute
|
||||
John-->Alice: Great!
|
||||
end
|
||||
```
|
||||
|
||||
我们理解您需要更便捷更高效的工具记录思想,整理笔记、知识,并将其中承载的价值传播给他人,**Cmd Markdown** 是我们给出的答案 —— 我们为记录思想和分享知识提供更专业的工具。 您可以使用 Cmd Markdown:
|
||||
|
||||
> * 整理知识,学习笔记
|
||||
|
||||
> * 发布日记,杂文,所见所想
|
||||
> * 撰写发布技术文稿(代码支持)
|
||||
> * 撰写发布学术论文(LaTeX 公式支持)
|
||||
|
||||

|
||||
|
||||
除了您现在看到的这个 Cmd Markdown 在线版本,您还可以前往以下网址下载:
|
||||
|
||||
### 1.1.2. [Windows/Mac/Linux 全平台客户端](https://www.zybuluo.com/cmd/)
|
||||
|
||||
> 请保留此份 Cmd Markdown 的欢迎稿兼使用说明,如需撰写新稿件,点击顶部工具栏右侧的 <i class="icon-file"></i> **新文稿** 或者使用快捷键 `Ctrl+Alt+N`。
|
||||
|
||||
------
|
||||
|
||||
## 1.2. 什么是 Markdown
|
||||
|
||||
Markdown 是一种方便记忆、书写的纯文本标记语言,用户可以使用这些标记符号以最小的输入代价生成极富表现力的文档:譬如您正在阅读的这份文档。它使用简单的符号标记不同的标题,分割不同的段落,**粗体** 或者 *斜体* 某些文字,更棒的是,它还可以
|
||||
|
||||
### 1.2.1. 制作一份待办事宜 [Todo 列表](https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown#13-待办事宜-todo-列表)
|
||||
|
||||
- [ ] 支持以 PDF 格式导出文稿
|
||||
- [ ] 改进 Cmd 渲染算法,使用局部渲染技术提高渲染效率
|
||||
- [x] 新增 Todo 列表功能
|
||||
- [x] 修复 LaTex 公式渲染问题
|
||||
- [x] 新增 LaTex 公式编号功能
|
||||
|
||||
### 1.2.2. 书写一个质能守恒公式[^LaTeX]
|
||||
|
||||
$$E=mc^2$$
|
||||
|
||||
### 1.2.3. 高亮一段代码[^code]
|
||||
|
||||
```python
|
||||
@requires_authorization
|
||||
class SomeClass:
|
||||
pass
|
||||
|
||||
if __name__ == '__main__':
|
||||
# A comment
|
||||
print 'hello world'
|
||||
```
|
||||
|
||||
### 1.2.4. 高效绘制 [流程图](https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown#7-流程图)
|
||||
|
||||
```flow
|
||||
st=>start: Start
|
||||
op=>operation: Your Operation
|
||||
cond=>condition: Yes or No?
|
||||
e=>end
|
||||
|
||||
st->op->cond
|
||||
cond(yes)->e
|
||||
cond(no)->op
|
||||
```
|
||||
|
||||
### 1.2.5. 高效绘制 [序列图](https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown#8-序列图)
|
||||
|
||||
```seq
|
||||
Alice->Bob: Hello Bob, how are you?
|
||||
Note right of Bob: Bob thinks
|
||||
Bob-->Alice: I am good thanks!
|
||||
```
|
||||
|
||||
### 1.2.6. 绘制表格
|
||||
|
||||
| 项目 | 价格 | 数量 |
|
||||
| -------- | -----: | :----: |
|
||||
| 计算机 | \$1600 | 5 |
|
||||
| 手机 | \$12 | 12 |
|
||||
| 管线 | \$1 | 234 |
|
||||
|
||||
### 1.2.7. 更详细语法说明
|
||||
|
||||
想要查看更详细的语法说明,可以参考我们准备的 [Cmd Markdown 简明语法手册][1],进阶用户可以参考 [Cmd Markdown 高阶语法手册][2] 了解更多高级功能。
|
||||
|
||||
总而言之,不同于其它 *所见即所得* 的编辑器:你只需使用键盘专注于书写文本内容,就可以生成印刷级的排版格式,省却在键盘和工具栏之间来回切换,调整内容和格式的麻烦。**Markdown 在流畅的书写和印刷级的阅读体验之间找到了平衡。** 目前它已经成为世界上最大的技术分享网站 GitHub 和 技术问答网站 StackOverFlow 的御用书写格式。
|
||||
|
||||
---
|
||||
|
||||
## 1.3. 什么是 Cmd Markdown
|
||||
|
||||
您可以使用很多工具书写 Markdown,但是 Cmd Markdown 是这个星球上我们已知的、最好的 Markdown 工具——没有之一 :)因为深信文字的力量,所以我们和你一样,对流畅书写,分享思想和知识,以及阅读体验有极致的追求,我们把对于这些诉求的回应整合在 Cmd Markdown,并且一次,两次,三次,乃至无数次地提升这个工具的体验,最终将它演化成一个 **编辑/发布/阅读** Markdown 的在线平台——您可以在任何地方,任何系统/设备上管理这里的文字。
|
||||
|
||||
### 1.3.1. 实时同步预览
|
||||
|
||||
我们将 Cmd Markdown 的主界面一分为二,左边为**编辑区**,右边为**预览区**,在编辑区的操作会实时地渲染到预览区方便查看最终的版面效果,并且如果你在其中一个区拖动滚动条,我们有一个巧妙的算法把另一个区的滚动条同步到等价的位置,超酷!
|
||||
|
||||
### 1.3.2. 编辑工具栏
|
||||
|
||||
也许您还是一个 Markdown 语法的新手,在您完全熟悉它之前,我们在 **编辑区** 的顶部放置了一个如下图所示的工具栏,您可以使用鼠标在工具栏上调整格式,不过我们仍旧鼓励你使用键盘标记格式,提高书写的流畅度。
|
||||
|
||||

|
||||
|
||||
### 1.3.3. 编辑模式
|
||||
|
||||
完全心无旁骛的方式编辑文字:点击 **编辑工具栏** 最右测的拉伸按钮或者按下 `Ctrl + M`,将 Cmd Markdown 切换到独立的编辑模式,这是一个极度简洁的写作环境,所有可能会引起分心的元素都已经被挪除,超清爽!
|
||||
|
||||
### 1.3.4. 实时的云端文稿
|
||||
|
||||
为了保障数据安全,Cmd Markdown 会将您每一次击键的内容保存至云端,同时在 **编辑工具栏** 的最右侧提示 `已保存` 的字样。无需担心浏览器崩溃,机器掉电或者地震,海啸——在编辑的过程中随时关闭浏览器或者机器,下一次回到 Cmd Markdown 的时候继续写作。
|
||||
|
||||
### 1.3.5. 离线模式
|
||||
|
||||
在网络环境不稳定的情况下记录文字一样很安全!在您写作的时候,如果电脑突然失去网络连接,Cmd Markdown 会智能切换至离线模式,将您后续键入的文字保存在本地,直到网络恢复再将他们传送至云端,即使在网络恢复前关闭浏览器或者电脑,一样没有问题,等到下次开启 Cmd Markdown 的时候,她会提醒您将离线保存的文字传送至云端。简而言之,我们尽最大的努力保障您文字的安全。
|
||||
|
||||
### 1.3.6. 管理工具栏
|
||||
|
||||
为了便于管理您的文稿,在 **预览区** 的顶部放置了如下所示的 **管理工具栏**:
|
||||
|
||||
通过管理工具栏可以:
|
||||
|
||||
<i class="icon-share"></i> 发布:将当前的文稿生成固定链接,在网络上发布,分享
|
||||
<i class="icon-file"></i> 新建:开始撰写一篇新的文稿
|
||||
<i class="icon-trash"></i> 删除:删除当前的文稿
|
||||
<i class="icon-cloud"></i> 导出:将当前的文稿转化为 Markdown 文本或者 Html 格式,并导出到本地
|
||||
<i class="icon-reorder"></i> 列表:所有新增和过往的文稿都可以在这里查看、操作
|
||||
<i class="icon-pencil"></i> 模式:切换 普通/Vim/Emacs 编辑模式
|
||||
|
||||
### 1.3.7. 阅读工具栏
|
||||
|
||||
通过 **预览区** 右上角的 **阅读工具栏**,可以查看当前文稿的目录并增强阅读体验。
|
||||
|
||||
工具栏上的五个图标依次为:
|
||||
|
||||
<i class="icon-list"></i> 目录:快速导航当前文稿的目录结构以跳转到感兴趣的段落
|
||||
<i class="icon-chevron-sign-left"></i> 视图:互换左边编辑区和右边预览区的位置
|
||||
<i class="icon-adjust"></i> 主题:内置了黑白两种模式的主题,试试 **黑色主题**,超炫!
|
||||
<i class="icon-desktop"></i> 阅读:心无旁骛的阅读模式提供超一流的阅读体验
|
||||
<i class="icon-fullscreen"></i> 全屏:简洁,简洁,再简洁,一个完全沉浸式的写作和阅读环境
|
||||
|
||||
### 1.3.8. 阅读模式
|
||||
|
||||
在 **阅读工具栏** 点击 <i class="icon-desktop"></i> 或者按下 `Ctrl+Alt+M` 随即进入独立的阅读模式界面,我们在版面渲染上的每一个细节:字体,字号,行间距,前背景色都倾注了大量的时间,努力提升阅读的体验和品质。
|
||||
|
||||
### 1.3.9. 标签、分类和搜索
|
||||
|
||||
在编辑区任意行首位置输入以下格式的文字可以标签当前文档:
|
||||
|
||||
标签: 未分类
|
||||
|
||||
标签以后的文稿在【文件列表】(Ctrl+Alt+F)里会按照标签分类,用户可以同时使用键盘或者鼠标浏览查看,或者在【文件列表】的搜索文本框内搜索标题关键字过滤文稿,如下图所示:
|
||||
|
||||

|
||||
|
||||
### 1.3.10. 文稿发布和分享
|
||||
|
||||
在您使用 Cmd Markdown 记录,创作,整理,阅读文稿的同时,我们不仅希望它是一个有力的工具,更希望您的思想和知识通过这个平台,连同优质的阅读体验,将他们分享给有相同志趣的人,进而鼓励更多的人来到这里记录分享他们的思想和知识,尝试点击 <i class="icon-share"></i> (Ctrl+Alt+P) 发布这份文档给好友吧!
|
||||
|
||||
------
|
||||
|
||||
再一次感谢您花费时间阅读这份欢迎稿,点击 <i class="icon-file"></i> (Ctrl+Alt+N) 开始撰写新的文稿吧!祝您在这里记录、阅读、分享愉快!
|
||||
|
||||
作者 [@ghosert][3]
|
||||
2015 年 06月 15日
|
||||
|
||||
[^LaTeX]: 支持 **LaTeX** 编辑显示支持,例如:$\sum_{i=1}^n a_i=0$, 访问 [MathJax][4] 参考更多使用方法。
|
||||
|
||||
[^code]: 代码高亮功能支持包括 Java, Python, JavaScript 在内的,**四十一**种主流编程语言。
|
||||
|
||||
[1]: https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown
|
||||
[2]: https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown#cmd-markdown-高阶语法手册
|
||||
[3]: http://weibo.com/ghosert
|
||||
[4]: http://meta.math.stackexchange.com/questions/5020/mathjax-basic-tutorial-and-quick-reference
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
After Width: | Height: | Size: 8.6 KiB |
74
backend/infra/impl/document/parser/builtin/util.go
Normal file
74
backend/infra/impl/document/parser/builtin/util.go
Normal file
@@ -0,0 +1,74 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package builtin
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
)
|
||||
|
||||
// baseWord is the 62-character alphabet used to map hash bytes onto printable
// characters in createSecret.
const baseWord = "1Aa2Bb3Cc4Dd5Ee6Ff7Gg8Hh9Ii0JjKkLlMmNnOoPpQqRrSsTtUuVvWwXxYyZz"

// knowledgePrefix namespaces object keys created by this package.
const knowledgePrefix = "BIZ_KNOWLEDGE"

// imgSrcFormat is the <img> tag template embedded into parsed documents; the
// data-tos-key attribute carries the storage key of the uploaded image.
const imgSrcFormat = `<img src="" data-tos-key="%s">`

// createSecret derives a short pseudo-random token for an upload by hashing
// the uid, file type, current time and a random number, then mapping the
// first 10 characters of the base64 hash onto the baseWord alphabet. The
// result is non-deterministic by design.
func createSecret(uid int64, fileType string) string {
	const secretLen = 10
	input := fmt.Sprintf("upload_%d_Ma*9)fhi_%d_gou_%s_rand_%d", uid, time.Now().Unix(), fileType, rand.Intn(100000))
	// Hash the input and keep only the first secretLen base64 characters.
	// (Fixed: the previous code wrapped input in a redundant fmt.Sprintf("%s", ...).)
	hash := sha256.Sum256([]byte(input))
	hashString := base64.StdEncoding.EncodeToString(hash[:])
	if len(hashString) > secretLen {
		hashString = hashString[:secretLen]
	}

	// Map every character onto baseWord; a Builder avoids the quadratic
	// string-concatenation of the previous loop.
	var b strings.Builder
	b.Grow(len(hashString))
	for _, char := range hashString {
		b.WriteByte(baseWord[int(char)%62])
	}
	return b.String()
}
// getExtension reports the file extension of uri without the leading dot,
// or "" when uri is empty or its final path element has no extension.
func getExtension(uri string) string {
	if uri == "" {
		return ""
	}
	// path.Ext yields ".png"-style suffixes (or "" when there is none).
	if ext := path.Ext(path.Base(uri)); ext != "" {
		return ext[1:]
	}
	return ""
}
|
||||
|
||||
func getCreatorIDFromExtraMeta(extraMeta map[string]any) int64 {
|
||||
if extraMeta == nil {
|
||||
return 0
|
||||
}
|
||||
if uid, ok := extraMeta[document.MetaDataKeyCreatorID]; ok {
|
||||
if uidInt, ok := uid.(int64); ok {
|
||||
return uidInt
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
151
backend/infra/impl/document/progressbar/impl.go
Normal file
151
backend/infra/impl/document/progressbar/impl.go
Normal file
@@ -0,0 +1,151 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package progressbar
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/progressbar"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
// ProgressBarImpl tracks knowledge-ingestion progress in a redis-compatible
// cache, keyed by the primary key id of the tracked entity.
type ProgressBarImpl struct {
	CacheCli     cache.Cmdable // cache client used for all counter reads/writes
	PrimaryKeyID int64         // id of the entity whose progress is tracked
	Total        int64         // total number of items expected
	ErrMsg       string        // last locally reported error; non-empty blocks AddN
}

const (
	// ttl is the lifetime of every progress key written to the cache.
	ttl = time.Hour * 2

	// Cache key templates; each is formatted with the primary key id.
	ProgressBarStartTimeRedisKey    = "RedisBiz.Knowledge_ProgressBar_StartTime_%d"
	ProgressBarErrMsgRedisKey       = "RedisBiz.Knowledge_ProgressBar_ErrMsg_%d"
	ProgressBarTotalNumRedisKey     = "RedisBiz.Knowledge_ProgressBar_TotalNum_%d"
	ProgressBarProcessedNumRedisKey = "RedisBiz.Knowledge_ProgressBar_ProcessedNum_%d"

	DefaultProcessTime = 300 // fallback remaining-seconds estimate
	ProcessDone        = 100 // percentage reported for finished/failed jobs
	ProcessInit        = 0   // percentage reported before any work is counted
)
|
||||
|
||||
func NewProgressBar(ctx context.Context, pkID int64, total int64, CacheCli cache.Cmdable, needInit bool) progressbar.ProgressBar {
|
||||
if needInit {
|
||||
CacheCli.Set(ctx, fmt.Sprintf(ProgressBarTotalNumRedisKey, pkID), total, ttl)
|
||||
CacheCli.Set(ctx, fmt.Sprintf(ProgressBarProcessedNumRedisKey, pkID), 0, ttl)
|
||||
CacheCli.Set(ctx, fmt.Sprintf(ProgressBarErrMsgRedisKey, pkID), "", ttl)
|
||||
CacheCli.Set(ctx, fmt.Sprintf(ProgressBarStartTimeRedisKey, pkID), time.Now().Unix(), ttl)
|
||||
}
|
||||
return &ProgressBarImpl{
|
||||
PrimaryKeyID: pkID,
|
||||
Total: total,
|
||||
CacheCli: CacheCli,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProgressBarImpl) AddN(n int) error {
|
||||
if p.ErrMsg != "" {
|
||||
return errors.New(p.ErrMsg)
|
||||
}
|
||||
_, err := p.CacheCli.IncrBy(context.Background(), fmt.Sprintf(ProgressBarProcessedNumRedisKey, p.PrimaryKeyID), int64(n)).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProgressBarImpl) ReportError(err error) error {
|
||||
p.ErrMsg = err.Error()
|
||||
_, err = p.CacheCli.Set(context.Background(), fmt.Sprintf(ProgressBarErrMsgRedisKey, p.PrimaryKeyID), err.Error(), ttl).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProgressBarImpl) GetProgress(ctx context.Context) (percent int, remainSec int, errMsg string) {
|
||||
var (
|
||||
totalNum *int64
|
||||
processedNum *int64
|
||||
startTime *int64
|
||||
err error
|
||||
)
|
||||
errMsg, err = p.CacheCli.Get(ctx, fmt.Sprintf(ProgressBarErrMsgRedisKey, p.PrimaryKeyID)).Result()
|
||||
if err == redis.Nil {
|
||||
errMsg = ""
|
||||
} else if err != nil {
|
||||
return ProcessDone, 0, err.Error()
|
||||
}
|
||||
if len(errMsg) != 0 {
|
||||
return ProcessDone, 0, errMsg
|
||||
}
|
||||
totalNumStr, err := p.CacheCli.Get(ctx, fmt.Sprintf(ProgressBarTotalNumRedisKey, p.PrimaryKeyID)).Result()
|
||||
if err == redis.Nil || len(totalNumStr) == 0 {
|
||||
totalNum = ptr.Of(int64(0))
|
||||
} else if err != nil {
|
||||
return ProcessDone, 0, err.Error()
|
||||
} else {
|
||||
num, err := conv.StrToInt64(totalNumStr)
|
||||
if err != nil {
|
||||
totalNum = ptr.Of(int64(0))
|
||||
} else {
|
||||
totalNum = ptr.Of(num)
|
||||
}
|
||||
}
|
||||
processedNumStr, err := p.CacheCli.Get(ctx, fmt.Sprintf(ProgressBarProcessedNumRedisKey, p.PrimaryKeyID)).Result()
|
||||
if err == redis.Nil || len(processedNumStr) == 0 {
|
||||
processedNum = ptr.Of(int64(0))
|
||||
} else if err != nil {
|
||||
return ProcessDone, 0, err.Error()
|
||||
} else {
|
||||
num, err := conv.StrToInt64(processedNumStr)
|
||||
if err != nil {
|
||||
processedNum = ptr.Of(int64(0))
|
||||
} else {
|
||||
processedNum = ptr.Of(num)
|
||||
}
|
||||
}
|
||||
if ptr.From(totalNum) == 0 {
|
||||
return ProcessInit, DefaultProcessTime, ""
|
||||
}
|
||||
startTimeStr, err := p.CacheCli.Get(ctx, fmt.Sprintf(ProgressBarStartTimeRedisKey, p.PrimaryKeyID)).Result()
|
||||
if err == redis.Nil || len(startTimeStr) == 0 {
|
||||
startTime = ptr.Of(int64(0))
|
||||
} else if err != nil {
|
||||
return ProcessDone, 0, err.Error()
|
||||
} else {
|
||||
num, err := conv.StrToInt64(startTimeStr)
|
||||
if err != nil {
|
||||
startTime = ptr.Of(int64(0))
|
||||
} else {
|
||||
startTime = ptr.Of(num)
|
||||
}
|
||||
}
|
||||
percent = int(float64(ptr.From(processedNum)) / float64(ptr.From(totalNum)) * 100)
|
||||
if ptr.From(startTime) == 0 {
|
||||
remainSec = DefaultProcessTime
|
||||
} else {
|
||||
usedSec := time.Now().Unix() - ptr.From(startTime)
|
||||
remainSec = int(float64(ptr.From(totalNum)-ptr.From(processedNum)) / float64(ptr.From(processedNum)) * float64(usedSec))
|
||||
}
|
||||
return
|
||||
}
|
||||
70
backend/infra/impl/document/rerank/rrf/rrf.go
Normal file
70
backend/infra/impl/document/rerank/rrf/rrf.go
Normal file
@@ -0,0 +1,70 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package rrf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/rerank"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
func NewRRFReranker(k int64) rerank.Reranker {
|
||||
if k == 0 {
|
||||
k = 60
|
||||
}
|
||||
return &rrfReranker{k}
|
||||
}
|
||||
|
||||
type rrfReranker struct {
|
||||
k int64
|
||||
}
|
||||
|
||||
func (r *rrfReranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Response, error) {
|
||||
if req == nil || req.Data == nil || len(req.Data) == 0 {
|
||||
return nil, fmt.Errorf("invalid request: no data provided")
|
||||
}
|
||||
id2Score := make(map[string]float64)
|
||||
id2Data := make(map[string]*rerank.Data)
|
||||
for _, resultList := range req.Data {
|
||||
for rank := range resultList {
|
||||
result := resultList[rank]
|
||||
if result != nil && result.Document != nil {
|
||||
score := 1.0 / (float64(rank) + float64(r.k))
|
||||
if score > id2Score[result.Document.ID] {
|
||||
id2Score[result.Document.ID] = score
|
||||
id2Data[result.Document.ID] = result
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
var sorted []*rerank.Data
|
||||
for _, data := range id2Data {
|
||||
sorted = append(sorted, data)
|
||||
}
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return id2Score[sorted[i].Document.ID] > id2Score[sorted[j].Document.ID]
|
||||
})
|
||||
topN := int64(len(sorted))
|
||||
if req.TopN != nil && ptr.From(req.TopN) != 0 && ptr.From(req.TopN) < topN {
|
||||
topN = ptr.From(req.TopN)
|
||||
}
|
||||
|
||||
return &rerank.Response{SortedData: sorted[:topN]}, nil
|
||||
}
|
||||
161
backend/infra/impl/document/rerank/vikingdb/vikingdb.go
Normal file
161
backend/infra/impl/document/rerank/vikingdb/vikingdb.go
Normal file
@@ -0,0 +1,161 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package vikingdb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
|
||||
"github.com/volcengine/volc-sdk-golang/base"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/rerank"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
// Config holds the volcengine credentials and region used to sign rerank
// requests.
type Config struct {
	AK string // access key id
	SK string // secret access key

	Region string // default cn-north-1
}

// NewReranker builds a volcengine knowledge-base reranker, defaulting the
// signing region to cn-north-1 when unset.
func NewReranker(config *Config) rerank.Reranker {
	if config.Region == "" {
		config.Region = "cn-north-1"
	}
	return &reranker{config: config}
}

const (
	// domain is the knowledge-base service endpoint.
	// NOTE(review): the host is a cn-beijing endpoint while the default
	// signing region is cn-north-1 — confirm these are meant to differ.
	domain       = "api-knowledgebase.mlp.cn-beijing.volces.com"
	defaultModel = "base-multilingual-rerank"
)

// reranker calls the remote rerank HTTP API; see Rerank.
type reranker struct {
	config *Config
}
|
||||
|
||||
// rerankReq is the JSON payload sent to the rerank endpoint.
type rerankReq struct {
	Datas       []rerankData `json:"datas"`
	RerankModel string       `json:"rerank_model"`
}

// rerankData is one query/document pair to be scored by the service.
type rerankData struct {
	Query   string  `json:"query"`
	Content string  `json:"content"`
	Title   *string `json:"title,omitempty"`
}

// rerankResp mirrors the service response; a non-zero Code indicates a
// service-side error described by Message.
type rerankResp struct {
	Code    int64  `json:"code"`
	Message string `json:"message"`
	Data    struct {
		Scores     []float64 `json:"scores"` // one score per submitted data item, same order
		TokenUsage int64     `json:"token_usage"`
	} `json:"data"`
}
|
||||
|
||||
func (r *reranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Response, error) {
|
||||
rReq := &rerankReq{
|
||||
Datas: make([]rerankData, 0, len(req.Data)),
|
||||
RerankModel: defaultModel,
|
||||
}
|
||||
|
||||
var flat []*rerank.Data
|
||||
for _, channel := range req.Data {
|
||||
flat = append(flat, channel...)
|
||||
}
|
||||
|
||||
for _, item := range flat {
|
||||
rReq.Datas = append(rReq.Datas, rerankData{
|
||||
Query: req.Query,
|
||||
Content: item.Document.Content,
|
||||
})
|
||||
}
|
||||
|
||||
body, err := json.Marshal(rReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(r.prepareRequest(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rResp := rerankResp{}
|
||||
if err = json.Unmarshal(respBody, &rResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rResp.Code != 0 {
|
||||
return nil, fmt.Errorf("[Rerank] failed, code=%d, msg=%v", rResp.Code, rResp.Message)
|
||||
}
|
||||
|
||||
sorted := make([]*rerank.Data, 0, len(rResp.Data.Scores))
|
||||
for i, score := range rResp.Data.Scores {
|
||||
sorted = append(sorted, &rerank.Data{
|
||||
Document: flat[i].Document,
|
||||
Score: score,
|
||||
})
|
||||
}
|
||||
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return sorted[i].Score > sorted[j].Score
|
||||
})
|
||||
|
||||
right := len(sorted)
|
||||
if req.TopN != nil {
|
||||
right = min(right, int(*req.TopN))
|
||||
}
|
||||
|
||||
return &rerank.Response{
|
||||
SortedData: sorted[:right],
|
||||
TokenUsage: ptr.Of(rResp.Data.TokenUsage),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// prepareRequest builds the signed POST request for the rerank endpoint.
func (r *reranker) prepareRequest(body []byte) *http.Request {
	u := url.URL{
		Scheme: "https",
		Host:   domain,
		Path:   "/api/knowledge/service/rerank",
	}
	// The error is ignored deliberately: the method and URL are
	// compile-time constants that cannot fail to parse.
	req, _ := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body))
	req.Header.Add("Accept", "application/json")
	req.Header.Add("Content-Type", "application/json")
	req.Header.Add("Host", domain)
	// Sign with volcengine "air" service credentials so the gateway
	// authenticates the call.
	credential := base.Credentials{
		AccessKeyID:     r.config.AK,
		SecretAccessKey: r.config.SK,
		Service:         "air",
		Region:          r.config.Region,
	}
	req = credential.Sign(req)
	return req
}
|
||||
48
backend/infra/impl/document/rerank/vikingdb/vikingdb_test.go
Normal file
48
backend/infra/impl/document/rerank/vikingdb/vikingdb_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package vikingdb
|
||||
|
||||
//func TestRun(t *testing.T) {
|
||||
// AK := os.Getenv("test_ak")
|
||||
// SK := os.Getenv("test_sk")
|
||||
//
|
||||
// r := NewReranker(&Config{
|
||||
// AK: AK,
|
||||
// SK: SK,
|
||||
// })
|
||||
// resp, err := r.Rerank(context.Background(), &rerank.Request{
|
||||
// Data: [][]*knowledge.RetrieveSlice{
|
||||
// {
|
||||
// {Slice: &entity.Slice{PlainText: "吉尼斯世界纪录网站数据显示,蓝鲸是目前已知世界上最大的动物,体长可达30米,相当于一架波音737飞机的长度"}},
|
||||
// {Slice: &entity.Slice{PlainText: "一头成年雌性弓头鲸可以长到22米长,而一头雄性鲸鱼可以长到18米长"}},
|
||||
// },
|
||||
// },
|
||||
// Query: "世界上最大的鲸鱼是什么?",
|
||||
// TopN: nil,
|
||||
// })
|
||||
// assert.NoError(t, err)
|
||||
//
|
||||
// for _, item := range resp.Sorted {
|
||||
// fmt.Println(item.Slice.PlainText, item.Score)
|
||||
// }
|
||||
// // 吉尼斯世界纪录网站数据显示,蓝鲸是目前已知世界上最大的动物,体长可达30米,相当于一架波音737飞机的长度 0.6209664529733573
|
||||
// // 一头成年雌性弓头鲸可以长到22米长,而一头雄性鲸鱼可以长到18米长 0.4269785303456468
|
||||
//
|
||||
// fmt.Println(resp.TokenUsage)
|
||||
// // 95
|
||||
//
|
||||
//}
|
||||
@@ -0,0 +1,21 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package elasticsearch
|
||||
|
||||
const (
	// topK is the default number of hits returned by Retrieve when the
	// caller does not set retriever.Options.TopK.
	topK = 10
)
|
||||
@@ -0,0 +1,116 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package elasticsearch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/es"
|
||||
)
|
||||
|
||||
// ManagerConfig carries the dependencies of the elasticsearch searchstore
// manager.
type ManagerConfig struct {
	Client es.Client // elasticsearch client used for all index operations
}

// NewManager builds a full-text searchstore manager backed by
// elasticsearch.
func NewManager(config *ManagerConfig) searchstore.Manager {
	return &esManager{config: config}
}

// esManager implements searchstore.Manager on top of an es.Client.
type esManager struct {
	config *ManagerConfig
}
|
||||
|
||||
func (e *esManager) Create(ctx context.Context, req *searchstore.CreateRequest) error {
|
||||
cli := e.config.Client
|
||||
index := req.CollectionName
|
||||
indexExists, err := cli.Exists(ctx, index)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if indexExists { // exists
|
||||
return nil
|
||||
}
|
||||
|
||||
properties := make(map[string]any)
|
||||
var foundID, foundCreatorID, foundTextContent bool
|
||||
for _, field := range req.Fields {
|
||||
switch field.Name {
|
||||
case searchstore.FieldID:
|
||||
foundID = true
|
||||
case searchstore.FieldCreatorID:
|
||||
foundCreatorID = true
|
||||
case searchstore.FieldTextContent:
|
||||
foundTextContent = true
|
||||
default:
|
||||
|
||||
}
|
||||
|
||||
var property any
|
||||
switch field.Type {
|
||||
case searchstore.FieldTypeInt64:
|
||||
property = cli.Types().NewLongNumberProperty()
|
||||
case searchstore.FieldTypeText:
|
||||
property = cli.Types().NewTextProperty()
|
||||
default:
|
||||
return fmt.Errorf("[Create] es unsupported field type: %d", field.Type)
|
||||
}
|
||||
|
||||
properties[field.Name] = property
|
||||
}
|
||||
|
||||
if !foundID {
|
||||
properties[searchstore.FieldID] = cli.Types().NewLongNumberProperty()
|
||||
}
|
||||
if !foundCreatorID {
|
||||
properties[searchstore.FieldCreatorID] = cli.Types().NewUnsignedLongNumberProperty()
|
||||
}
|
||||
if !foundTextContent {
|
||||
properties[searchstore.FieldTextContent] = cli.Types().NewTextProperty()
|
||||
}
|
||||
|
||||
err = cli.CreateIndex(ctx, index, properties)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Drop deletes the index backing req.CollectionName.
func (e *esManager) Drop(ctx context.Context, req *searchstore.DropRequest) error {
	cli := e.config.Client
	index := req.CollectionName

	return cli.DeleteIndex(ctx, index)
}

// GetType identifies this manager as a full-text (non-vector) store.
func (e *esManager) GetType() searchstore.SearchStoreType {
	return searchstore.TypeTextStore
}

// GetSearchStore returns a search store bound to the given index name.
// It never fails; the error return satisfies the interface.
func (e *esManager) GetSearchStore(ctx context.Context, collectionName string) (searchstore.SearchStore, error) {
	return &esSearchStore{
		config:    e.config,
		indexName: collectionName,
	}, nil
}

// GetEmbedding returns nil: the text store performs no vectorization.
func (e *esManager) GetEmbedding() embedding.Embedder {
	return nil
}
|
||||
@@ -0,0 +1,329 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package elasticsearch
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
||||
"github.com/cloudwego/eino/components/indexer"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/es"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
// esSearchStore reads and writes the documents of a single elasticsearch
// index.
type esSearchStore struct {
	config    *ManagerConfig // shared client configuration
	indexName string         // index this store operates on
}
|
||||
|
||||
func (e *esSearchStore) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
|
||||
implSpecOptions := indexer.GetImplSpecificOptions(&searchstore.IndexerOptions{}, opts...)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if implSpecOptions.ProgressBar != nil {
|
||||
implSpecOptions.ProgressBar.ReportError(err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
cli := e.config.Client
|
||||
index := e.indexName
|
||||
bi, err := cli.NewBulkIndexer(index)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = make([]string, 0, len(docs))
|
||||
for _, doc := range docs {
|
||||
fieldMapping, err := e.fromDocument(doc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body, err := json.Marshal(fieldMapping)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = bi.Add(ctx, es.BulkIndexerItem{
|
||||
Index: e.indexName,
|
||||
Action: "index",
|
||||
DocumentID: doc.ID,
|
||||
Body: bytes.NewReader(body),
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, doc.ID)
|
||||
if implSpecOptions.ProgressBar != nil {
|
||||
if err = implSpecOptions.ProgressBar.AddN(1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err = bi.Close(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// Retrieve runs a full-text query against the index and converts the hits
// into schema documents. Honored options: TopK (default topK),
// ScoreThreshold (mapped to min_score), DSLInfo filters, and the
// implementation-specific MultiMatch field list.
func (e *esSearchStore) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
	var (
		cli   = e.config.Client
		index = e.indexName

		options         = retriever.GetCommonOptions(&retriever.Options{TopK: ptr.Of(topK)}, opts...)
		implSpecOptions = retriever.GetImplSpecificOptions(&searchstore.RetrieverOptions{}, opts...)
		req             = &es.Request{
			Query: &es.Query{
				Bool: &es.BoolQuery{},
			},
			Size: options.TopK,
		}
	)

	// Without explicit multi-match fields, match against the canonical
	// text-content field only.
	if implSpecOptions.MultiMatch == nil {
		req.Query.Bool.Must = append(req.Query.Bool.Must,
			es.NewMatchQuery(searchstore.FieldTextContent, query))
	} else {
		req.Query.Bool.Must = append(req.Query.Bool.Must,
			es.NewMultiMatchQuery(implSpecOptions.MultiMatch.Fields, query,
				"best_fields", es.Or))
	}

	// Translate caller-provided DSL filters into bool sub-queries.
	dsl, err := searchstore.LoadDSL(options.DSLInfo)
	if err != nil {
		return nil, err
	}

	if err = e.travDSL(req.Query, dsl); err != nil {
		return nil, err
	}

	if options.ScoreThreshold != nil {
		req.MinScore = options.ScoreThreshold
	}

	resp, err := cli.Search(ctx, index, req)
	if err != nil {
		return nil, err
	}

	docs, err := e.parseSearchResult(resp)
	if err != nil {
		return nil, err
	}

	return docs, nil
}
|
||||
|
||||
func (e *esSearchStore) Delete(ctx context.Context, ids []string) error {
|
||||
bi, err := e.config.Client.NewBulkIndexer(e.indexName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
if err = bi.Add(ctx, es.BulkIndexerItem{
|
||||
Index: e.indexName,
|
||||
Action: "delete",
|
||||
DocumentID: id,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return bi.Close(ctx)
|
||||
}
|
||||
|
||||
func (e *esSearchStore) travDSL(query *es.Query, dsl *searchstore.DSL) error {
|
||||
if dsl == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch dsl.Op {
|
||||
case searchstore.OpEq, searchstore.OpNe:
|
||||
arr := stringifyValue(dsl.Value)
|
||||
v := dsl.Value
|
||||
if len(arr) > 0 {
|
||||
v = arr[0]
|
||||
}
|
||||
|
||||
if dsl.Op == searchstore.OpEq {
|
||||
query.Bool.Must = append(query.Bool.Must,
|
||||
es.NewEqualQuery(dsl.Field, v))
|
||||
} else {
|
||||
query.Bool.MustNot = append(query.Bool.MustNot,
|
||||
es.NewEqualQuery(dsl.Field, v))
|
||||
}
|
||||
case searchstore.OpLike:
|
||||
s, ok := dsl.Value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("[travDSL] OpLike value should be string, but got %v", dsl.Value)
|
||||
}
|
||||
query.Bool.Must = append(query.Bool.Must, es.NewMatchQuery(dsl.Field, s))
|
||||
|
||||
case searchstore.OpIn:
|
||||
query.Bool.Must = append(query.Bool.MustNot,
|
||||
es.NewInQuery(dsl.Field, stringifyValue(dsl.Value)))
|
||||
|
||||
case searchstore.OpAnd, searchstore.OpOr:
|
||||
conds, ok := dsl.Value.([]*searchstore.DSL)
|
||||
if !ok {
|
||||
return fmt.Errorf("[travDSL] value type assertion failed for or")
|
||||
}
|
||||
|
||||
for _, cond := range conds {
|
||||
sub := &es.Query{}
|
||||
if err := e.travDSL(sub, cond); err != nil {
|
||||
return err
|
||||
}
|
||||
if dsl.Op == searchstore.OpOr {
|
||||
query.Bool.Should = append(query.Bool.Should, *sub)
|
||||
} else {
|
||||
query.Bool.Must = append(query.Bool.Must, *sub)
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("[trav] unknown op %s", dsl.Op)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *esSearchStore) parseSearchResult(resp *es.Response) (docs []*schema.Document, err error) {
|
||||
docs = make([]*schema.Document, 0, len(resp.Hits.Hits))
|
||||
firstScore := 0.0
|
||||
for i, hit := range resp.Hits.Hits {
|
||||
var src map[string]any
|
||||
d := json.NewDecoder(bytes.NewReader(hit.Source_))
|
||||
d.UseNumber()
|
||||
if err = d.Decode(&src); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ext := make(map[string]any)
|
||||
doc := &schema.Document{MetaData: map[string]any{document.MetaDataKeyExternalStorage: ext}}
|
||||
|
||||
for field, val := range src {
|
||||
ok := true
|
||||
switch field {
|
||||
case searchstore.FieldTextContent:
|
||||
doc.Content, ok = val.(string)
|
||||
case searchstore.FieldCreatorID:
|
||||
var jn json.Number
|
||||
jn, ok = val.(json.Number)
|
||||
if ok {
|
||||
doc.MetaData[document.MetaDataKeyCreatorID], ok = assertJSONNumber(jn).(int64)
|
||||
}
|
||||
default:
|
||||
if jn, jok := val.(json.Number); jok {
|
||||
ext[field] = assertJSONNumber(jn)
|
||||
} else {
|
||||
ext[field] = val
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[parseSearchResult] type assertion failed, field=%s, val=%v", field, val)
|
||||
}
|
||||
}
|
||||
if hit.Id_ != nil {
|
||||
doc.ID = *hit.Id_
|
||||
}
|
||||
if hit.Score_ == nil { // unexpected
|
||||
return nil, fmt.Errorf("[parseSearchResult] es retrieve score not found")
|
||||
}
|
||||
score := float64(ptr.From(hit.Score_))
|
||||
if i == 0 {
|
||||
firstScore = score
|
||||
}
|
||||
doc.WithScore(score / firstScore)
|
||||
|
||||
docs = append(docs, doc)
|
||||
}
|
||||
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
func (e *esSearchStore) fromDocument(doc *schema.Document) (map[string]any, error) {
|
||||
if doc.MetaData == nil {
|
||||
return nil, fmt.Errorf("[fromDocument] es document meta data is nil")
|
||||
}
|
||||
|
||||
creatorID, ok := doc.MetaData[searchstore.FieldCreatorID].(int64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[fromDocument] creator id not found or type invalid")
|
||||
}
|
||||
|
||||
fieldMapping := map[string]any{
|
||||
searchstore.FieldTextContent: doc.Content,
|
||||
searchstore.FieldCreatorID: creatorID,
|
||||
}
|
||||
|
||||
if ext, ok := doc.MetaData[document.MetaDataKeyExternalStorage].(map[string]any); ok {
|
||||
for k, v := range ext {
|
||||
fieldMapping[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return fieldMapping, nil
|
||||
}
|
||||
|
||||
func stringifyValue(dslValue any) []any {
|
||||
value := reflect.ValueOf(dslValue)
|
||||
switch value.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
length := value.Len()
|
||||
slice := make([]any, 0, length)
|
||||
for i := 0; i < length; i++ {
|
||||
elem := value.Index(i)
|
||||
switch elem.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
slice = append(slice, strconv.FormatInt(elem.Int(), 10))
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
slice = append(slice, strconv.FormatUint(elem.Uint(), 10))
|
||||
case reflect.Float32, reflect.Float64:
|
||||
slice = append(slice, strconv.FormatFloat(elem.Float(), 'f', -1, 64))
|
||||
case reflect.String:
|
||||
slice = append(slice, elem.String())
|
||||
default:
|
||||
slice = append(slice, elem) // do nothing
|
||||
}
|
||||
}
|
||||
return slice
|
||||
default:
|
||||
return []any{fmt.Sprintf("%v", value)}
|
||||
}
|
||||
}
|
||||
|
||||
func assertJSONNumber(f json.Number) any {
|
||||
if i64, err := f.Int64(); err == nil {
|
||||
return i64
|
||||
}
|
||||
if f64, err := f.Float64(); err == nil {
|
||||
return f64
|
||||
}
|
||||
return f.String()
|
||||
}
|
||||
22
backend/infra/impl/document/searchstore/milvus/consts.go
Normal file
22
backend/infra/impl/document/searchstore/milvus/consts.go
Normal file
@@ -0,0 +1,22 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package milvus
|
||||
|
||||
const (
	// batchSize bounds items per milvus batch operation — presumably the
	// insert batch size; usage is outside this chunk, confirm there.
	batchSize = 100
	// topK is the default retrieval result count for this package —
	// usage is outside this chunk, confirm there.
	topK = 4
)
|
||||
119
backend/infra/impl/document/searchstore/milvus/convert.go
Normal file
119
backend/infra/impl/document/searchstore/milvus/convert.go
Normal file
@@ -0,0 +1,119 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package milvus
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
)
|
||||
|
||||
func denseFieldName(name string) string {
|
||||
return fmt.Sprintf("dense_%s", name)
|
||||
}
|
||||
|
||||
// denseIndexName returns the milvus index name used for the dense vector
// field derived from name.
func denseIndexName(name string) string {
	return "index_dense_" + name
}
|
||||
|
||||
// sparseFieldName returns the collection field name that holds the sparse
// vector derived from the field called name.
func sparseFieldName(name string) string {
	return "sparse_" + name
}
|
||||
|
||||
// sparseIndexName returns the milvus index name used for the sparse vector
// field derived from name.
func sparseIndexName(name string) string {
	return "index_sparse_" + name
}
|
||||
|
||||
func convertFieldType(typ searchstore.FieldType) (entity.FieldType, error) {
|
||||
switch typ {
|
||||
case searchstore.FieldTypeInt64:
|
||||
return entity.FieldTypeInt64, nil
|
||||
case searchstore.FieldTypeText:
|
||||
return entity.FieldTypeVarChar, nil
|
||||
case searchstore.FieldTypeDenseVector:
|
||||
return entity.FieldTypeFloatVector, nil
|
||||
case searchstore.FieldTypeSparseVector:
|
||||
return entity.FieldTypeSparseVector, nil
|
||||
default:
|
||||
return entity.FieldTypeNone, fmt.Errorf("[convertFieldType] unknown field type: %v", typ)
|
||||
}
|
||||
}
|
||||
|
||||
// convertDense narrows a batch of float64 embedding vectors to the float32
// representation milvus dense-vector columns expect.
func convertDense(dense [][]float64) [][]float32 {
	out := make([][]float32, len(dense))
	for i, vec := range dense {
		row := make([]float32, len(vec))
		for j, v := range vec {
			row[j] = float32(v)
		}
		out[i] = row
	}
	return out
}
|
||||
|
||||
func convertMilvusDenseVector(dense [][]float64) []entity.Vector {
|
||||
return slices.Transform(dense, func(a []float64) entity.Vector {
|
||||
r := make([]float32, len(a))
|
||||
for i := 0; i < len(a); i++ {
|
||||
r[i] = float32(a[i])
|
||||
}
|
||||
return entity.FloatVector(r)
|
||||
})
|
||||
}
|
||||
|
||||
func convertSparse(sparse []map[int]float64) ([]entity.SparseEmbedding, error) {
|
||||
r := make([]entity.SparseEmbedding, 0, len(sparse))
|
||||
for _, s := range sparse {
|
||||
ks := make([]uint32, 0, len(s))
|
||||
vs := make([]float32, 0, len(s))
|
||||
for k, v := range s {
|
||||
ks = append(ks, uint32(k))
|
||||
vs = append(vs, float32(v))
|
||||
}
|
||||
|
||||
se, err := entity.NewSliceSparseEmbedding(ks, vs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r = append(r, se)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func convertMilvusSparseVector(sparse []map[int]float64) ([]entity.Vector, error) {
|
||||
r := make([]entity.Vector, 0, len(sparse))
|
||||
for _, s := range sparse {
|
||||
ks := make([]uint32, 0, len(s))
|
||||
vs := make([]float32, 0, len(s))
|
||||
for k, v := range s {
|
||||
ks = append(ks, uint32(k))
|
||||
vs = append(vs, float32(v))
|
||||
}
|
||||
|
||||
se, err := entity.NewSliceSparseEmbedding(ks, vs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r = append(r, se)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
334
backend/infra/impl/document/searchstore/milvus/milvus_manager.go
Normal file
334
backend/infra/impl/document/searchstore/milvus/milvus_manager.go
Normal file
@@ -0,0 +1,334 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package milvus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
mentity "github.com/milvus-io/milvus/client/v2/entity"
|
||||
mindex "github.com/milvus-io/milvus/client/v2/index"
|
||||
client "github.com/milvus-io/milvus/client/v2/milvusclient"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
// ManagerConfig bundles the dependencies and tuning options used to build a
// milvus-backed searchstore.Manager. Optional fields left at their zero value
// are filled with the documented defaults by NewManager.
type ManagerConfig struct {
	Client    *client.Client     // required: milvus v2 client
	Embedding embedding.Embedder // required: embedder used for vector fields

	EnableHybrid *bool              // optional: default Embedding.SupportStatus() == embedding.SupportDenseAndSparse
	DenseIndex   mindex.Index       // optional: default HNSW, M=30, efConstruction=360
	DenseMetric  mentity.MetricType // optional: default IP
	SparseIndex  mindex.Index       // optional: default SPARSE_INVERTED_INDEX, drop_ratio=0.2
	SparseMetric mentity.MetricType // optional: default IP
	ShardNum     int                // optional: default 1
	BatchSize    int                // optional: default 100
}
|
||||
|
||||
func NewManager(config *ManagerConfig) (searchstore.Manager, error) {
|
||||
if config.Client == nil {
|
||||
return nil, fmt.Errorf("[NewManager] milvus client not provided")
|
||||
}
|
||||
if config.Embedding == nil {
|
||||
return nil, fmt.Errorf("[NewManager] milvus embedder not provided")
|
||||
}
|
||||
|
||||
enableSparse := config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse
|
||||
if config.EnableHybrid == nil {
|
||||
config.EnableHybrid = ptr.Of(enableSparse)
|
||||
} else if !enableSparse && ptr.From(config.EnableHybrid) {
|
||||
logs.Warnf("[NewManager] milvus embedding not support sparse, so hybrid search is disabled.")
|
||||
config.EnableHybrid = ptr.Of(false)
|
||||
}
|
||||
if config.DenseMetric == "" {
|
||||
config.DenseMetric = mentity.IP
|
||||
}
|
||||
if config.DenseIndex == nil {
|
||||
config.DenseIndex = mindex.NewHNSWIndex(config.DenseMetric, 30, 360)
|
||||
}
|
||||
if config.SparseMetric == "" {
|
||||
config.SparseMetric = mentity.IP
|
||||
}
|
||||
if config.SparseIndex == nil {
|
||||
config.SparseIndex = mindex.NewSparseInvertedIndex(config.SparseMetric, 0.2)
|
||||
}
|
||||
if config.ShardNum == 0 {
|
||||
config.ShardNum = 1
|
||||
}
|
||||
if config.BatchSize == 0 {
|
||||
config.BatchSize = 100
|
||||
}
|
||||
|
||||
return &milvusManager{config: config}, nil
|
||||
}
|
||||
|
||||
// milvusManager implements searchstore.Manager on top of a milvus instance,
// handling collection creation, index building and loading.
type milvusManager struct {
	config *ManagerConfig
}
|
||||
|
||||
func (m *milvusManager) Create(ctx context.Context, req *searchstore.CreateRequest) error {
|
||||
if err := m.createCollection(ctx, req); err != nil {
|
||||
return fmt.Errorf("[Create] create collection failed, %w", err)
|
||||
}
|
||||
|
||||
if err := m.createIndexes(ctx, req); err != nil {
|
||||
return fmt.Errorf("[Create] create indexes failed, %w", err)
|
||||
}
|
||||
|
||||
if exists, err := m.loadCollection(ctx, req.CollectionName); err != nil {
|
||||
return fmt.Errorf("[Create] load collection failed, %w", err)
|
||||
} else if !exists {
|
||||
return fmt.Errorf("[Create] load collection failed, collection=%v does not exist", req.CollectionName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *milvusManager) Drop(ctx context.Context, req *searchstore.DropRequest) error {
|
||||
return m.config.Client.DropCollection(ctx, client.NewDropCollectionOption(req.CollectionName))
|
||||
}
|
||||
|
||||
// GetType reports this manager as a vector-store style search store.
func (m *milvusManager) GetType() searchstore.SearchStoreType {
	return searchstore.TypeVectorStore
}
|
||||
|
||||
func (m *milvusManager) GetSearchStore(ctx context.Context, collectionName string) (searchstore.SearchStore, error) {
|
||||
if exists, err := m.loadCollection(ctx, collectionName); err != nil {
|
||||
return nil, err
|
||||
} else if !exists {
|
||||
return nil, errorx.New(errno.ErrKnowledgeNonRetryableCode,
|
||||
errorx.KVf("reason", "[GetSearchStore] collection=%v does not exist", collectionName))
|
||||
}
|
||||
|
||||
return &milvusSearchStore{
|
||||
config: m.config,
|
||||
collectionName: collectionName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *milvusManager) createCollection(ctx context.Context, req *searchstore.CreateRequest) error {
|
||||
if req.CollectionName == "" || len(req.Fields) == 0 {
|
||||
return fmt.Errorf("[createCollection] invalid request params")
|
||||
}
|
||||
|
||||
cli := m.config.Client
|
||||
collectionName := req.CollectionName
|
||||
has, err := cli.HasCollection(ctx, client.NewHasCollectionOption(collectionName))
|
||||
if err != nil {
|
||||
return fmt.Errorf("[createCollection] HasCollection failed, %w", err)
|
||||
}
|
||||
if has {
|
||||
return nil
|
||||
}
|
||||
|
||||
fields, err := m.convertFields(req.Fields)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
opt := client.NewCreateCollectionOption(collectionName, &mentity.Schema{
|
||||
CollectionName: collectionName,
|
||||
Description: fmt.Sprintf("created by coze"),
|
||||
AutoID: false,
|
||||
Fields: fields,
|
||||
EnableDynamicField: false,
|
||||
}).WithShardNum(int32(m.config.ShardNum))
|
||||
|
||||
for k, v := range req.CollectionMeta {
|
||||
opt.WithProperty(k, v)
|
||||
}
|
||||
|
||||
if err = cli.CreateCollection(ctx, opt); err != nil {
|
||||
return fmt.Errorf("[createCollection] CreateCollection failed, %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *milvusManager) createIndexes(ctx context.Context, req *searchstore.CreateRequest) error {
|
||||
collectionName := req.CollectionName
|
||||
indexes, err := m.config.Client.ListIndexes(ctx, client.NewListIndexOption(req.CollectionName))
|
||||
if err != nil {
|
||||
if !strings.Contains(err.Error(), "index not found") {
|
||||
return fmt.Errorf("[createIndexes] ListIndexes failed, %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
createdIndexes := sets.FromSlice(indexes)
|
||||
|
||||
var ops []func() error
|
||||
for i := range req.Fields {
|
||||
f := req.Fields[i]
|
||||
if !f.Indexing {
|
||||
continue
|
||||
}
|
||||
|
||||
ops = append(ops, m.tryCreateIndex(ctx, collectionName, denseFieldName(f.Name), denseIndexName(f.Name), m.config.DenseIndex, createdIndexes))
|
||||
if m.config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse {
|
||||
ops = append(ops, m.tryCreateIndex(ctx, collectionName, sparseFieldName(f.Name), sparseIndexName(f.Name), m.config.SparseIndex, createdIndexes))
|
||||
}
|
||||
}
|
||||
|
||||
for _, op := range ops {
|
||||
if err := op(); err != nil {
|
||||
return fmt.Errorf("[createIndexes] failed, %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *milvusManager) tryCreateIndex(ctx context.Context, collectionName, fieldName, indexName string, idx mindex.Index, createdIndexes sets.Set[string]) func() error {
|
||||
return func() error {
|
||||
if _, found := createdIndexes[indexName]; found {
|
||||
logs.CtxInfof(ctx, "[tryCreateIndex] index exists, so skip, collectionName=%s, fieldName=%s, idx=%v, type=%s\n",
|
||||
collectionName, fieldName, indexName, idx.IndexType())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
cli := m.config.Client
|
||||
|
||||
task, err := cli.CreateIndex(ctx, client.NewCreateIndexOption(collectionName, fieldName, idx).WithIndexName(indexName))
|
||||
if err != nil {
|
||||
return fmt.Errorf("[tryCreateIndex] CreateIndex failed, %w", err)
|
||||
}
|
||||
|
||||
if err = task.Await(ctx); err != nil {
|
||||
return fmt.Errorf("[tryCreateIndex] await failed, %w", err)
|
||||
}
|
||||
|
||||
logs.CtxInfof(ctx, "[tryCreateIndex] CreateIndex success, collectionName=%s, fieldName=%s, idx=%v, type=%s\n",
|
||||
collectionName, fieldName, indexName, idx.IndexType())
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *milvusManager) loadCollection(ctx context.Context, collectionName string) (exists bool, err error) {
|
||||
cli := m.config.Client
|
||||
|
||||
stat, err := cli.GetLoadState(ctx, client.NewGetLoadStateOption(collectionName))
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("[loadCollection] GetLoadState failed, %w", err)
|
||||
}
|
||||
|
||||
switch stat.State {
|
||||
case mentity.LoadStateNotLoad:
|
||||
task, err := cli.LoadCollection(ctx, client.NewLoadCollectionOption(collectionName))
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("[loadCollection] LoadCollection failed, collection=%v, %w", collectionName, err)
|
||||
}
|
||||
if err = task.Await(ctx); err != nil {
|
||||
return false, fmt.Errorf("[loadCollection] await failed, collection=%v, %w", collectionName, err)
|
||||
}
|
||||
return true, nil
|
||||
case mentity.LoadStateLoaded:
|
||||
return true, nil
|
||||
case mentity.LoadStateLoading:
|
||||
return true, fmt.Errorf("[loadCollection] collection is unloading, retry later, collection=%v", collectionName)
|
||||
case mentity.LoadStateUnloading:
|
||||
return false, nil
|
||||
default:
|
||||
return false, fmt.Errorf("[loadCollection] load state unexpected, state=%d", stat)
|
||||
}
|
||||
}
|
||||
|
||||
// convertFields translates searchstore field definitions into milvus schema
// fields.
//
// For an indexing field (text only) it emits the derived dense-vector field
// (and a sparse-vector field when the embedder supports sparse output); the
// original varchar column is kept only for the content field. Non-indexing
// fields are mapped one-to-one via convertFieldType. Mandatory id and
// creator_id int64 fields are appended if the caller did not supply them.
func (m *milvusManager) convertFields(fields []*searchstore.Field) ([]*mentity.Field, error) {
	var foundID, foundCreatorID bool
	resp := make([]*mentity.Field, 0, len(fields))
	for _, f := range fields {
		// Track whether the caller provided the mandatory built-in fields.
		switch f.Name {
		case searchstore.FieldID:
			foundID = true
		case searchstore.FieldCreatorID:
			foundCreatorID = true
		default:
		}

		if f.Indexing {
			if f.Type != searchstore.FieldTypeText {
				return nil, fmt.Errorf("[convertFields] milvus only support text field indexing, field=%s, type=%d", f.Name, f.Type)
			}
			// When indexing, only the content field keeps a column with the original text.
			if f.Name == searchstore.FieldTextContent {
				resp = append(resp, mentity.NewField().
					WithName(f.Name).
					WithDescription(f.Description).
					WithIsPrimaryKey(f.IsPrimary).
					WithNullable(f.Nullable).
					WithDataType(mentity.FieldTypeVarChar).
					WithMaxLength(65535))
			}
			// Dense vector column derived from this field's text.
			resp = append(resp, mentity.NewField().
				WithName(denseFieldName(f.Name)).
				WithDataType(mentity.FieldTypeFloatVector).
				WithDim(m.config.Embedding.Dimensions()))
			// Sparse vector column only when the embedder can produce it.
			if m.config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse {
				resp = append(resp, mentity.NewField().
					WithName(sparseFieldName(f.Name)).
					WithDataType(mentity.FieldTypeSparseVector))
			}
		} else {
			mf := mentity.NewField().
				WithName(f.Name).
				WithDescription(f.Description).
				WithIsPrimaryKey(f.IsPrimary).
				WithNullable(f.Nullable)
			typ, err := convertFieldType(f.Type)
			if err != nil {
				return nil, err
			}
			mf.WithDataType(typ)
			if typ == mentity.FieldTypeVarChar {
				mf.WithMaxLength(65535)
			} else if typ == mentity.FieldTypeFloatVector {
				mf.WithDim(m.config.Embedding.Dimensions())
			}
			resp = append(resp, mf)
		}
	}

	// Ensure the primary key field exists.
	if !foundID {
		resp = append(resp, mentity.NewField().
			WithName(searchstore.FieldID).
			WithDataType(mentity.FieldTypeInt64).
			WithIsPrimaryKey(true).
			WithNullable(false))
	}

	// Ensure the creator id field exists.
	if !foundCreatorID {
		resp = append(resp, mentity.NewField().
			WithName(searchstore.FieldCreatorID).
			WithDataType(mentity.FieldTypeInt64).
			WithNullable(false))
	}

	return resp, nil
}
|
||||
|
||||
// GetEmbedding exposes the embedder this manager was configured with.
func (m *milvusManager) GetEmbedding() embedding.Embedder {
	return m.config.Embedding
}
|
||||
@@ -0,0 +1,600 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package milvus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/components/indexer"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/milvus-io/milvus/client/v2/column"
|
||||
mentity "github.com/milvus-io/milvus/client/v2/entity"
|
||||
mindex "github.com/milvus-io/milvus/client/v2/index"
|
||||
client "github.com/milvus-io/milvus/client/v2/milvusclient"
|
||||
"github.com/slongfield/pyfmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
)
|
||||
|
||||
// milvusSearchStore implements searchstore.SearchStore over a single milvus
// collection; instances are handed out by milvusManager.GetSearchStore.
type milvusSearchStore struct {
	config *ManagerConfig

	collectionName string
}
|
||||
|
||||
// Store upserts docs into the collection in batches of batchSize and returns
// the stored primary-key ids as decimal strings, in batch order.
//
// Indexer options honored: IndexingFields selects which text fields get
// embedded, Partition routes rows into a (created-on-demand) partition, and
// ProgressBar is advanced per batch and notified of any terminal error via
// the deferred ReportError.
func (m *milvusSearchStore) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
	if len(docs) == 0 {
		return nil, nil
	}
	implSpecOptions := indexer.GetImplSpecificOptions(&searchstore.IndexerOptions{}, opts...)
	// Report the named return err to the progress bar on any failure path.
	defer func() {
		if err != nil {
			if implSpecOptions.ProgressBar != nil {
				implSpecOptions.ProgressBar.ReportError(err)
			}
		}
	}()
	indexingFields := make(sets.Set[string])
	for _, field := range implSpecOptions.IndexingFields {
		indexingFields[field] = struct{}{}
	}

	// Create the target partition lazily when one was requested.
	if implSpecOptions.Partition != nil {
		partition := *implSpecOptions.Partition
		hasPartition, err := m.config.Client.HasPartition(ctx, client.NewHasPartitionOption(m.collectionName, partition))
		if err != nil {
			return nil, fmt.Errorf("[Store] HasPartition failed, %w", err)
		}
		if !hasPartition {
			if err = m.config.Client.CreatePartition(ctx, client.NewCreatePartitionOption(m.collectionName, partition)); err != nil {
				return nil, fmt.Errorf("[Store] CreatePartition failed, %w", err)
			}
		}
	}

	for _, part := range slices.Chunks(docs, batchSize) {
		columns, err := m.documents2Columns(ctx, part, indexingFields)
		if err != nil {
			return nil, err
		}

		createReq := client.NewColumnBasedInsertOption(m.collectionName, columns...)
		if implSpecOptions.Partition != nil {
			createReq.WithPartition(*implSpecOptions.Partition)
		}

		result, err := m.config.Client.Upsert(ctx, createReq)
		if err != nil {
			return nil, fmt.Errorf("[Store] upsert failed, %w", err)
		}

		// Collect the returned primary keys as strings, whatever their
		// underlying milvus type is.
		partIDs := result.IDs
		for i := 0; i < partIDs.Len(); i++ {
			var sid string
			if partIDs.Type() == mentity.FieldTypeInt64 {
				id, err := partIDs.GetAsInt64(i)
				if err != nil {
					return nil, err
				}
				sid = strconv.FormatInt(id, 10)
			} else {
				sid, err = partIDs.GetAsString(i)
				if err != nil {
					return nil, err
				}
			}
			ids = append(ids, sid)
		}
		if implSpecOptions.ProgressBar != nil {
			if err = implSpecOptions.ProgressBar.AddN(len(part)); err != nil {
				return nil, err
			}
		}
	}

	return ids, nil
}
|
||||
|
||||
// Retrieve embeds query and runs an ANN search over the collection.
//
// When the collection has a sparse-vector field, hybrid search is enabled in
// config, and the embedder supports sparse output, a hybrid (dense+sparse)
// search with RRF reranking is used; otherwise a single dense search is run
// and scores are normalized by metric type in resultSet2Document. An optional
// DSL filter from the retriever options is translated to a milvus expression.
func (m *milvusSearchStore) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
	cli := m.config.Client
	emb := m.config.Embedding
	options := retriever.GetCommonOptions(&retriever.Options{TopK: ptr.Of(topK)}, opts...)
	implSpecOptions := retriever.GetImplSpecificOptions(&searchstore.RetrieverOptions{}, opts...)

	desc, err := cli.DescribeCollection(ctx, client.NewDescribeCollectionOption(m.collectionName))
	if err != nil {
		return nil, err
	}

	var (
		dense  [][]float64
		sparse []map[int]float64
		expr   string
		result []client.ResultSet

		fields       = desc.Schema.Fields
		outputFields []string
		enableSparse = m.enableSparse(fields)
	)

	if options.DSLInfo != nil {
		expr, err = m.dsl2Expr(options.DSLInfo)
		if err != nil {
			return nil, err
		}
	}

	// Embed the query once, hybrid or dense-only depending on capability.
	if enableSparse {
		dense, sparse, err = emb.EmbedStringsHybrid(ctx, []string{query})
		if err != nil {
			return nil, fmt.Errorf("[Retrieve] EmbedStringsHybrid failed, %w", err)
		}
	} else {
		dense, err = emb.EmbedStrings(ctx, []string{query})
		if err != nil {
			return nil, fmt.Errorf("[Retrieve] EmbedStrings failed, %w", err)
		}
	}

	dv := convertMilvusDenseVector(dense)
	sv, err := convertMilvusSparseVector(sparse)
	if err != nil {
		return nil, err
	}

	for _, field := range fields {
		outputFields = append(outputFields, field.Name)
	}

	// Set only on the dense-only path; hybrid scores come pre-fused by RRF.
	var scoreNormType *mindex.MetricType

	if enableSparse {
		var annRequests []*client.AnnRequest
		// NOTE(review): this loops over ALL schema fields; for a non-vector
		// field both branches are skipped and an AnnRequest with an empty
		// vector slice is still appended — confirm intended.
		for _, field := range fields {
			var (
				vector      []mentity.Vector
				metricsType mindex.MetricType
			)
			if field.DataType == mentity.FieldTypeFloatVector {
				vector = dv
				metricsType, err = m.getIndexMetricsType(ctx, denseIndexName(field.Name))
			} else if field.DataType == mentity.FieldTypeSparseVector {
				vector = sv
				metricsType, err = m.getIndexMetricsType(ctx, sparseIndexName(field.Name))
			}
			if err != nil {
				return nil, err
			}
			annRequests = append(annRequests,
				client.NewAnnRequest(field.Name, ptr.From(options.TopK), vector...).
					WithSearchParam(mindex.MetricTypeKey, string(metricsType)).
					WithFilter(expr),
			)
		}

		// WithPartitons is the SDK's own (misspelled) method name.
		searchOption := client.NewHybridSearchOption(m.collectionName, ptr.From(options.TopK), annRequests...).
			WithPartitons(implSpecOptions.Partitions...).
			WithReranker(client.NewRRFReranker()).
			WithOutputFields(outputFields...)

		result, err = cli.HybridSearch(ctx, searchOption)
		if err != nil {
			return nil, fmt.Errorf("[Retrieve] HybridSearch failed, %w", err)
		}
	} else {
		indexes, err := cli.ListIndexes(ctx, client.NewListIndexOption(m.collectionName))
		if err != nil {
			return nil, fmt.Errorf("[Retrieve] ListIndexes failed, %w", err)
		}
		// Dense-only search assumes exactly one index on the collection.
		if len(indexes) != 1 {
			return nil, fmt.Errorf("[Retrieve] restrict single index ann search, but got %d, collection=%s",
				len(indexes), m.collectionName)
		}
		metricsType, err := m.getIndexMetricsType(ctx, indexes[0])
		if err != nil {
			return nil, err
		}
		scoreNormType = &metricsType
		searchOption := client.NewSearchOption(m.collectionName, ptr.From(options.TopK), dv).
			WithPartitions(implSpecOptions.Partitions...).
			WithFilter(expr).
			WithOutputFields(outputFields...).
			WithSearchParam(mindex.MetricTypeKey, string(metricsType))
		result, err = cli.Search(ctx, searchOption)
		if err != nil {
			return nil, fmt.Errorf("[Retrieve] Search failed, %w", err)
		}
	}

	docs, err := m.resultSet2Document(result, scoreNormType)
	if err != nil {
		return nil, fmt.Errorf("[Retrieve] resultSet2Document failed, %w", err)
	}

	return docs, nil
}
|
||||
|
||||
func (m *milvusSearchStore) Delete(ctx context.Context, ids []string) error {
|
||||
int64IDs := make([]int64, 0, len(ids))
|
||||
for _, sid := range ids {
|
||||
id, err := strconv.ParseInt(sid, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
int64IDs = append(int64IDs, id)
|
||||
}
|
||||
_, err := m.config.Client.Delete(ctx,
|
||||
client.NewDeleteOption(m.collectionName).WithInt64IDs(searchstore.FieldID, int64IDs))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *milvusSearchStore) documents2Columns(ctx context.Context, docs []*schema.Document, indexingFields sets.Set[string]) (
|
||||
cols []column.Column, err error) {
|
||||
|
||||
var (
|
||||
ids []int64
|
||||
contents []string
|
||||
creatorIDs []int64
|
||||
emptyContents = true
|
||||
)
|
||||
|
||||
colMapping := map[string]any{}
|
||||
colTypeMapping := map[string]searchstore.FieldType{
|
||||
searchstore.FieldID: searchstore.FieldTypeInt64,
|
||||
searchstore.FieldCreatorID: searchstore.FieldTypeInt64,
|
||||
searchstore.FieldTextContent: searchstore.FieldTypeText,
|
||||
}
|
||||
for _, doc := range docs {
|
||||
if doc.MetaData == nil {
|
||||
return nil, fmt.Errorf("[documents2Columns] meta data is nil")
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(doc.ID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("[documents2Columns] parse id failed, %w", err)
|
||||
}
|
||||
ids = append(ids, id)
|
||||
contents = append(contents, doc.Content)
|
||||
if doc.Content != "" {
|
||||
emptyContents = false
|
||||
}
|
||||
|
||||
creatorID, err := document.GetDocumentCreatorID(doc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("[documents2Columns] creator_id not found or type invalid., %w", err)
|
||||
}
|
||||
creatorIDs = append(creatorIDs, creatorID)
|
||||
|
||||
ext, ok := doc.MetaData[document.MetaDataKeyExternalStorage].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
for field := range ext {
|
||||
val := ext[field]
|
||||
container := colMapping[field]
|
||||
switch t := val.(type) {
|
||||
case uint, uint8, uint16, uint32, uint64, uintptr:
|
||||
var c []int64
|
||||
if container == nil {
|
||||
colTypeMapping[field] = searchstore.FieldTypeInt64
|
||||
} else {
|
||||
c, ok = container.([]int64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[documents2Columns] container type not int64")
|
||||
}
|
||||
}
|
||||
c = append(c, int64(reflect.ValueOf(t).Uint()))
|
||||
colMapping[field] = c
|
||||
case int, int8, int16, int32, int64:
|
||||
var c []int64
|
||||
if container == nil {
|
||||
colTypeMapping[field] = searchstore.FieldTypeInt64
|
||||
} else {
|
||||
c, ok = container.([]int64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[documents2Columns] container type not int64")
|
||||
}
|
||||
}
|
||||
c = append(c, reflect.ValueOf(t).Int())
|
||||
colMapping[field] = c
|
||||
case string:
|
||||
var c []string
|
||||
if container == nil {
|
||||
colTypeMapping[field] = searchstore.FieldTypeText
|
||||
} else {
|
||||
c, ok = container.([]string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[documents2Columns] container type not int64")
|
||||
}
|
||||
}
|
||||
c = append(c, t)
|
||||
colMapping[field] = c
|
||||
case []float64:
|
||||
var c [][]float64
|
||||
if container == nil {
|
||||
container = c
|
||||
colTypeMapping[field] = searchstore.FieldTypeDenseVector
|
||||
} else {
|
||||
c, ok = container.([][]float64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[documents2Columns] container type not int64")
|
||||
}
|
||||
}
|
||||
c = append(c, t)
|
||||
colMapping[field] = c
|
||||
case map[int]float64:
|
||||
var c []map[int]float64
|
||||
if container == nil {
|
||||
container = c
|
||||
colTypeMapping[field] = searchstore.FieldTypeSparseVector
|
||||
} else {
|
||||
c, ok = container.([]map[int]float64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[documents2Columns] container type not int64")
|
||||
}
|
||||
}
|
||||
c = append(c, t)
|
||||
colMapping[field] = c
|
||||
default:
|
||||
return nil, fmt.Errorf("[documents2Columns] val type not support, val=%v", val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
colMapping[searchstore.FieldID] = ids
|
||||
colMapping[searchstore.FieldCreatorID] = creatorIDs
|
||||
colMapping[searchstore.FieldTextContent] = contents
|
||||
|
||||
for fieldName, container := range colMapping {
|
||||
colType := colTypeMapping[fieldName]
|
||||
switch colType {
|
||||
case searchstore.FieldTypeInt64:
|
||||
c, ok := container.([]int64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[documents2Columns] container type not int64")
|
||||
}
|
||||
cols = append(cols, column.NewColumnInt64(fieldName, c))
|
||||
case searchstore.FieldTypeText:
|
||||
c, ok := container.([]string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[documents2Columns] container type not string")
|
||||
}
|
||||
|
||||
if _, indexing := indexingFields[fieldName]; indexing {
|
||||
if fieldName == searchstore.FieldTextContent && !emptyContents {
|
||||
cols = append(cols, column.NewColumnVarChar(fieldName, c))
|
||||
}
|
||||
|
||||
var (
|
||||
emb = m.config.Embedding
|
||||
dense [][]float64
|
||||
sparse []map[int]float64
|
||||
)
|
||||
if emb.SupportStatus() == embedding.SupportDenseAndSparse {
|
||||
dense, sparse, err = emb.EmbedStringsHybrid(ctx, c)
|
||||
} else {
|
||||
dense, err = emb.EmbedStrings(ctx, c)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("[slices2Columns] embed failed, %w", err)
|
||||
}
|
||||
|
||||
cols = append(cols, column.NewColumnFloatVector(denseFieldName(fieldName), int(emb.Dimensions()), convertDense(dense)))
|
||||
|
||||
if emb.SupportStatus() == embedding.SupportDenseAndSparse {
|
||||
s, err := convertSparse(sparse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cols = append(cols, column.NewColumnSparseVectors(sparseFieldName(fieldName), s))
|
||||
}
|
||||
} else {
|
||||
cols = append(cols, column.NewColumnVarChar(fieldName, c))
|
||||
}
|
||||
|
||||
case searchstore.FieldTypeDenseVector:
|
||||
c, ok := container.([][]float64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[documents2Columns] container type not []float64")
|
||||
}
|
||||
cols = append(cols, column.NewColumnFloatVector(fieldName, int(m.config.Embedding.Dimensions()), convertDense(c)))
|
||||
case searchstore.FieldTypeSparseVector:
|
||||
c, ok := container.([]map[int]float64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[documents2Columns] container type not map[int]float64")
|
||||
}
|
||||
sparse, err := convertSparse(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cols = append(cols, column.NewColumnSparseVectors(fieldName, sparse))
|
||||
default:
|
||||
return nil, fmt.Errorf("[documents2Columns] column type not support, type=%d", colType)
|
||||
}
|
||||
}
|
||||
|
||||
return cols, nil
|
||||
}
|
||||
|
||||
// resultSet2Document flattens milvus result sets into schema.Documents,
// sorts them by descending raw score, and — for dense-only (non-hybrid)
// searches with a known metric — normalizes scores into [0, 1]:
// L2 distances are min-max inverted (then the order is flipped so best comes
// first), IP/COSINE similarities are shifted from [-1, 1] to [0, 1].
// Built-in columns populate doc.ID, doc.Content and the creator-id metadata;
// all other columns go into the external-storage metadata map.
func (m *milvusSearchStore) resultSet2Document(result []client.ResultSet, metricsType *mindex.MetricType) (docs []*schema.Document, err error) {
	docs = make([]*schema.Document, 0, len(result))
	// NOTE(review): maxScore starts at 0.0, so an all-negative score batch
	// (possible with raw IP) keeps maxScore at 0 — confirm intended.
	minScore := math.MaxFloat64
	maxScore := 0.0

	for _, r := range result {
		for i := 0; i < r.ResultCount; i++ {
			ext := make(map[string]any)
			doc := &schema.Document{MetaData: map[string]any{document.MetaDataKeyExternalStorage: ext}}
			score := float64(r.Scores[i])
			minScore = min(minScore, score)
			maxScore = max(maxScore, score)
			doc.WithScore(score)

			for _, field := range r.Fields {
				switch field.Name() {
				case searchstore.FieldID:
					id, err := field.GetAsInt64(i)
					if err != nil {
						return nil, err
					}
					doc.ID = strconv.FormatInt(id, 10)
				case searchstore.FieldTextContent:
					doc.Content, err = field.GetAsString(i)
				case searchstore.FieldCreatorID:
					doc.MetaData[document.MetaDataKeyCreatorID], err = field.GetAsInt64(i)
				default:
					ext[field.Name()], err = field.Get(i)
				}
				// Catches the named-err assignments from the cases above.
				if err != nil {
					return nil, err
				}
			}

			docs = append(docs, doc)
		}
	}

	sort.Slice(docs, func(i, j int) bool {
		return docs[i].Score() > docs[j].Score()
	})

	// norm score
	// Hybrid results are already RRF-fused; without a metric there is no
	// defined normalization — return raw scores in both cases.
	if (m.config.EnableHybrid != nil && *m.config.EnableHybrid) || metricsType == nil {
		return docs, nil
	}

	switch *metricsType {
	case mentity.L2:
		// Min-max normalize distances and invert so 1.0 is the best match.
		base := maxScore - minScore
		for i := range docs {
			if base == 0 {
				docs[i].WithScore(1.0)
			} else {
				docs[i].WithScore(1.0 - (docs[i].Score()-minScore)/base)
			}
		}
		// L2 sorted ascending-by-quality after inversion; flip to best-first.
		// (slices here is the project helper, whose Reverse returns a slice.)
		docs = slices.Reverse(docs)
	case mentity.IP, mentity.COSINE:
		// Shift similarity from [-1, 1] into [0, 1].
		for i := range docs {
			docs[i].WithScore((docs[i].Score() + 1) / 2)
		}
	default:

	}

	return docs, nil
}
|
||||
|
||||
func (m *milvusSearchStore) enableSparse(fields []*mentity.Field) bool {
|
||||
found := false
|
||||
for _, field := range fields {
|
||||
if field.DataType == mentity.FieldTypeSparseVector {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return found && *m.config.EnableHybrid && m.config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse
|
||||
}
|
||||
|
||||
// dsl2Expr renders a searchstore DSL tree into a Milvus boolean expression
// string. A nil input yields an empty expression (no filtering).
func (m *milvusSearchStore) dsl2Expr(src map[string]interface{}) (string, error) {
	if src == nil {
		return "", nil
	}

	dsl, err := searchstore.LoadDSL(src)
	if err != nil {
		return "", err
	}

	// travDSL recursively formats one DSL node; declared separately so the
	// closure can reference itself for the AND/OR branches.
	var travDSL func(dsl *searchstore.DSL) (string, error)
	travDSL = func(dsl *searchstore.DSL) (string, error) {
		kv := map[string]interface{}{
			"field": dsl.Field,
			"val":   dsl.Value,
		}

		switch dsl.Op {
		case searchstore.OpEq:
			// NOTE(review): values are interpolated verbatim — presumably
			// callers only pass numeric values; string values would need
			// quoting for Milvus. Confirm against the call sites.
			return pyfmt.Fmt("{field} == {val}", kv)
		case searchstore.OpNe:
			return pyfmt.Fmt("{field} != {val}", kv)
		case searchstore.OpLike:
			return pyfmt.Fmt("{field} LIKE {val}", kv)
		case searchstore.OpIn:
			// IN takes a JSON-style list literal, e.g. `field IN [1,2,3]`.
			b, err := json.Marshal(dsl.Value)
			if err != nil {
				return "", err
			}
			kv["val"] = string(b)
			return pyfmt.Fmt("{field} IN {val}", kv)
		case searchstore.OpAnd, searchstore.OpOr:
			// Composite nodes carry their children in Value.
			sub, ok := dsl.Value.([]*searchstore.DSL)
			if !ok {
				return "", fmt.Errorf("[dsl2Expr] invalid sub dsl")
			}
			var items []string
			for _, s := range sub {
				str, err := travDSL(s)
				if err != nil {
					return "", fmt.Errorf("[dsl2Expr] parse sub failed, %w", err)
				}
				items = append(items, str)
			}

			if dsl.Op == searchstore.OpAnd {
				return strings.Join(items, " AND "), nil
			} else {
				return strings.Join(items, " OR "), nil
			}
		default:
			return "", fmt.Errorf("[dsl2Expr] unknown op type=%s", dsl.Op)
		}
	}

	return travDSL(dsl)
}
|
||||
|
||||
func (m *milvusSearchStore) getIndexMetricsType(ctx context.Context, indexName string) (mindex.MetricType, error) {
|
||||
index, err := m.config.Client.DescribeIndex(ctx, client.NewDescribeIndexOption(m.collectionName, indexName))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("[getIndexMetricsType] describe index failed, collection=%s, index=%s, %w",
|
||||
m.collectionName, indexName, err)
|
||||
}
|
||||
|
||||
typ, found := index.Params()[mindex.MetricTypeKey]
|
||||
if !found { // unexpected
|
||||
return "", fmt.Errorf("[getIndexMetricsType] invalid index params, collection=%s, index=%s", m.collectionName, indexName)
|
||||
}
|
||||
|
||||
return mindex.MetricType(typ), nil
|
||||
}
|
||||
122
backend/infra/impl/document/searchstore/vikingdb/consts.go
Normal file
122
backend/infra/impl/document/searchstore/vikingdb/consts.go
Normal file
@@ -0,0 +1,122 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package vikingdb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
|
||||
|
||||
embcontract "github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
// VikingEmbeddingModelName identifies a VikingDB-hosted embedding model.
type VikingEmbeddingModelName string

// Supported server-side embedding models.
const (
	ModelNameDoubaoEmbedding       VikingEmbeddingModelName = "doubao-embedding"
	ModelNameDoubaoEmbeddingLarge  VikingEmbeddingModelName = "doubao-embedding-large"
	ModelNameDoubaoEmbeddingVision VikingEmbeddingModelName = "doubao-embedding-vision"
	ModelNameBGELargeZH            VikingEmbeddingModelName = "bge-large-zh"
	ModelNameBGEM3                 VikingEmbeddingModelName = "bge-m3"
	ModelNameBGEVisualizedM3       VikingEmbeddingModelName = "bge-visualized-m3"

	// Combined dense+sparse variants, currently unused.
	//ModelNameDoubaoEmbeddingAndM3 VikingEmbeddingModelName = "doubao-embedding-and-m3"
	//ModelNameDoubaoEmbeddingLargeAndM3 VikingEmbeddingModelName = "doubao-embedding-large-and-m3"
	//ModelNameBGELargeZHAndM3 VikingEmbeddingModelName = "bge-large-zh-and-m3"
)
|
||||
|
||||
func (v VikingEmbeddingModelName) Dimensions() int64 {
|
||||
switch v {
|
||||
case ModelNameDoubaoEmbedding, ModelNameDoubaoEmbeddingVision:
|
||||
return 2048
|
||||
case ModelNameDoubaoEmbeddingLarge:
|
||||
return 4096
|
||||
case ModelNameBGELargeZH, ModelNameBGEM3, ModelNameBGEVisualizedM3:
|
||||
return 1024
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func (v VikingEmbeddingModelName) ModelVersion() *string {
|
||||
switch v {
|
||||
case ModelNameDoubaoEmbedding:
|
||||
return ptr.Of("240515")
|
||||
case ModelNameDoubaoEmbeddingLarge:
|
||||
return ptr.Of("240915")
|
||||
case ModelNameDoubaoEmbeddingVision:
|
||||
return ptr.Of("250328")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (v VikingEmbeddingModelName) SupportStatus() embcontract.SupportStatus {
|
||||
switch v {
|
||||
case ModelNameDoubaoEmbedding, ModelNameDoubaoEmbeddingLarge, ModelNameDoubaoEmbeddingVision, ModelNameBGELargeZH, ModelNameBGEVisualizedM3:
|
||||
return embcontract.SupportDense
|
||||
case ModelNameBGEM3:
|
||||
return embcontract.SupportDenseAndSparse
|
||||
default:
|
||||
return embcontract.SupportDense
|
||||
}
|
||||
}
|
||||
|
||||
// IndexType selects the VikingDB vector index algorithm.
type IndexType string

const (
	IndexTypeHNSW       IndexType = vikingdb.HNSW
	IndexTypeHNSWHybrid IndexType = vikingdb.HNSW_HYBRID
	IndexTypeFlat       IndexType = vikingdb.FLAT
	IndexTypeIVF        IndexType = vikingdb.IVF
	IndexTypeDiskANN    IndexType = vikingdb.DiskANN
)
|
||||
|
||||
// IndexDistance selects the distance metric used by the vector index.
type IndexDistance string

const (
	IndexDistanceIP     IndexDistance = vikingdb.IP
	IndexDistanceL2     IndexDistance = vikingdb.L2
	IndexDistanceCosine IndexDistance = vikingdb.COSINE
)
|
||||
|
||||
// IndexQuant selects the quantization scheme applied to stored vectors.
type IndexQuant string

const (
	IndexQuantInt8  IndexQuant = vikingdb.Int8
	IndexQuantFloat IndexQuant = vikingdb.Float
	IndexQuantFix16 IndexQuant = vikingdb.Fix16
	IndexQuantPQ    IndexQuant = vikingdb.PQ
)
|
||||
|
||||
const (
	// Request flags and response field names used by the VikingDB
	// embedding API.
	vikingEmbeddingUseDense           = "return_dense"
	vikingEmbeddingUseSparse          = "return_sparse"
	vikingEmbeddingRespSentenceDense  = "sentence_dense_embedding"
	vikingEmbeddingRespSentenceSparse = "sentence_sparse_embedding"
	// vikingIndexName is the single index this package manages per collection.
	vikingIndexName = "opencoze_index"
)
|
||||
|
||||
const (
	// Substrings matched against VikingDB SDK error messages to detect
	// "not found" conditions (the SDK exposes no typed errors for these).
	errCollectionNotFound = "collection not found"
	errIndexNotFound      = "index not found"
)
|
||||
|
||||
// denseFieldName derives the name of the companion dense-vector column for
// a source field.
func denseFieldName(name string) string {
	const prefix = "dense_"
	return fmt.Sprintf("%s%s", prefix, name)
}
|
||||
@@ -0,0 +1,331 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package vikingdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
// ManagerConfig wires a VikingDB service client together with the indexing
// and embedding settings used by the searchstore manager.
type ManagerConfig struct {
	// Service is the VikingDB SDK client; required.
	Service *vikingdb.VikingDBService

	// IndexingConfig controls index creation; nil is replaced with defaults.
	IndexingConfig *VikingIndexingConfig
	// EmbeddingConfig controls how documents are vectorized; required.
	EmbeddingConfig *VikingEmbeddingConfig

	// TODO: cache viking collection & index client
}
|
||||
|
||||
// VikingIndexingConfig holds vector index parameters. Zero/nil fields are
// filled with the documented defaults by NewManager.
type VikingIndexingConfig struct {
	// vector index config
	Type     IndexType      // default: hnsw / hnsw_hybrid
	Distance *IndexDistance // default: ip
	Quant    *IndexQuant    // default: int8
	HnswM    *int64         // default: 20
	HnswCef  *int64         // default: 400
	HnswSef  *int64         // default: 800

	// others
	CpuQuota   int64 // default: 2
	ShardCount int64 // default: 1
}
|
||||
|
||||
// VikingEmbeddingConfig selects between VikingDB server-side embedding and
// a caller-supplied builtin embedder.
type VikingEmbeddingConfig struct {
	// UseVikingEmbedding vectorizes documents server-side when true.
	UseVikingEmbedding bool
	// EnableHybrid enables dense+sparse hybrid search (viking embedding only).
	EnableHybrid bool

	// viking embedding config
	ModelName    VikingEmbeddingModelName
	ModelVersion *string
	DenseWeight  *float64

	// builtin embedding config
	BuiltinEmbedding embedding.Embedder
}
|
||||
|
||||
func NewManager(config *ManagerConfig) (searchstore.Manager, error) {
|
||||
if config.Service == nil {
|
||||
return nil, fmt.Errorf("[NewManager] vikingdb service is nil")
|
||||
}
|
||||
if config.EmbeddingConfig == nil {
|
||||
return nil, fmt.Errorf("[NewManager] vikingdb embedding config is nil")
|
||||
}
|
||||
if !config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.BuiltinEmbedding == nil {
|
||||
return nil, fmt.Errorf("[NewManager] vikingdb built embedding not provided")
|
||||
}
|
||||
if !config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.EnableHybrid {
|
||||
return nil, fmt.Errorf("[NewManager] vikingdb hybrid not support for builtin embedding")
|
||||
}
|
||||
if config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.ModelName == "" {
|
||||
return nil, fmt.Errorf("[NewManager] vikingdb model name is empty")
|
||||
}
|
||||
if config.EmbeddingConfig.UseVikingEmbedding &&
|
||||
config.EmbeddingConfig.EnableHybrid &&
|
||||
config.EmbeddingConfig.ModelName.SupportStatus() != embedding.SupportDenseAndSparse {
|
||||
return nil, fmt.Errorf("[NewManager] vikingdb embedding model not support sparse embedding, model=%v", config.EmbeddingConfig.ModelName)
|
||||
}
|
||||
if config.IndexingConfig == nil {
|
||||
config.IndexingConfig = &VikingIndexingConfig{}
|
||||
}
|
||||
if config.IndexingConfig.Type == "" {
|
||||
if !config.EmbeddingConfig.UseVikingEmbedding || !config.EmbeddingConfig.EnableHybrid {
|
||||
config.IndexingConfig.Type = IndexTypeHNSW
|
||||
} else {
|
||||
config.IndexingConfig.Type = IndexTypeHNSWHybrid
|
||||
}
|
||||
}
|
||||
if config.IndexingConfig.Distance == nil {
|
||||
config.IndexingConfig.Distance = ptr.Of(IndexDistanceIP)
|
||||
}
|
||||
if config.IndexingConfig.Quant == nil {
|
||||
config.IndexingConfig.Quant = ptr.Of(IndexQuantInt8)
|
||||
}
|
||||
if config.IndexingConfig.HnswM == nil {
|
||||
config.IndexingConfig.HnswM = ptr.Of(int64(20))
|
||||
}
|
||||
if config.IndexingConfig.HnswCef == nil {
|
||||
config.IndexingConfig.HnswCef = ptr.Of(int64(400))
|
||||
}
|
||||
if config.IndexingConfig.HnswSef == nil {
|
||||
config.IndexingConfig.HnswSef = ptr.Of(int64(800))
|
||||
}
|
||||
if config.IndexingConfig.CpuQuota == 0 {
|
||||
config.IndexingConfig.CpuQuota = 2
|
||||
}
|
||||
if config.IndexingConfig.ShardCount == 0 {
|
||||
config.IndexingConfig.ShardCount = 1
|
||||
}
|
||||
|
||||
return &manager{
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// manager implements searchstore.Manager on top of VikingDB.
type manager struct {
	config *ManagerConfig
}
|
||||
|
||||
func (m *manager) Create(ctx context.Context, req *searchstore.CreateRequest) error {
|
||||
if err := m.createCollection(ctx, req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.createIndex(ctx, req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *manager) Drop(_ context.Context, req *searchstore.DropRequest) error {
|
||||
if err := m.config.Service.DropIndex(req.CollectionName, vikingIndexName); err != nil {
|
||||
if !strings.Contains(err.Error(), errIndexNotFound) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := m.config.Service.DropCollection(req.CollectionName); err != nil {
|
||||
if !strings.Contains(err.Error(), errCollectionNotFound) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetType reports this manager's store category: a vector store.
func (m *manager) GetType() searchstore.SearchStoreType {
	return searchstore.TypeVectorStore
}
|
||||
|
||||
func (m *manager) GetSearchStore(_ context.Context, collectionName string) (searchstore.SearchStore, error) {
|
||||
collection, err := m.config.Service.GetCollection(collectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &vkSearchStore{manager: m, collection: collection}, nil
|
||||
}
|
||||
|
||||
func (m *manager) createCollection(ctx context.Context, req *searchstore.CreateRequest) error {
|
||||
svc := m.config.Service
|
||||
|
||||
collection, err := svc.GetCollection(req.CollectionName)
|
||||
if err != nil {
|
||||
if !strings.Contains(err.Error(), errCollectionNotFound) {
|
||||
return err
|
||||
}
|
||||
} else if collection != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
fields, vopts, err := m.mapFields(req.Fields)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if vopts != nil {
|
||||
_, err = svc.CreateCollection(req.CollectionName, fields, "", vopts)
|
||||
} else {
|
||||
_, err = svc.CreateCollection(req.CollectionName, fields, "")
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logs.CtxInfof(ctx, "[vikingdb] Create collection success, collection=%s", req.CollectionName)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *manager) createIndex(ctx context.Context, req *searchstore.CreateRequest) error {
|
||||
svc := m.config.Service
|
||||
index, err := svc.GetIndex(req.CollectionName, vikingIndexName)
|
||||
if err != nil {
|
||||
if !strings.Contains(err.Error(), errIndexNotFound) {
|
||||
return err
|
||||
}
|
||||
} else if index != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
vectorIndex := &vikingdb.VectorIndexParams{
|
||||
IndexType: string(m.config.IndexingConfig.Type),
|
||||
Distance: string(ptr.From(m.config.IndexingConfig.Distance)),
|
||||
Quant: string(ptr.From(m.config.IndexingConfig.Quant)),
|
||||
HnswM: ptr.From(m.config.IndexingConfig.HnswM),
|
||||
HnswCef: ptr.From(m.config.IndexingConfig.HnswCef),
|
||||
HnswSef: ptr.From(m.config.IndexingConfig.HnswSef),
|
||||
}
|
||||
|
||||
opts := vikingdb.NewIndexOptions().
|
||||
SetVectorIndex(vectorIndex).
|
||||
SetCpuQuota(m.config.IndexingConfig.CpuQuota).
|
||||
SetShardCount(m.config.IndexingConfig.ShardCount)
|
||||
|
||||
_, err = svc.CreateIndex(req.CollectionName, vikingIndexName, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logs.CtxInfof(ctx, "[vikingdb] Create index success, collection=%s, index=%s", req.CollectionName, vikingIndexName)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// mapFields translates searchstore field definitions into VikingDB fields
// and (when viking-side embedding is enabled) vectorization options.
//
// For every field marked Indexing (text only), it either registers a
// vectorize tuple (viking embedding) or adds a companion dense-vector
// column (builtin embedding). Primary-key id and creator-id columns are
// appended if the caller did not declare them.
func (m *manager) mapFields(srcFields []*searchstore.Field) ([]vikingdb.Field, []*vikingdb.VectorizeTuple, error) {
	var (
		foundID        bool
		foundCreatorID bool
		dstFields      = make([]vikingdb.Field, 0, len(srcFields))
		vectorizeOpts  []*vikingdb.VectorizeTuple
		embConfig      = m.config.EmbeddingConfig
	)

	for _, srcField := range srcFields {
		// Record whether the caller already declared the reserved columns.
		switch srcField.Name {
		case searchstore.FieldID:
			foundID = true
		case searchstore.FieldCreatorID:
			foundCreatorID = true
		default:
		}

		if srcField.Indexing {
			if srcField.Type != searchstore.FieldTypeText {
				return nil, nil, fmt.Errorf("[mapFields] currently only support text field indexing, field=%s", srcField.Name)
			}
			if embConfig.UseVikingEmbedding {
				// Server-side vectorization: dense always, sparse when hybrid.
				vt := vikingdb.NewVectorizeTuple().SetDense(m.newVectorizeModelConf(srcField.Name))
				if embConfig.EnableHybrid {
					vt = vt.SetSparse(m.newVectorizeModelConf(srcField.Name))
				}
				vectorizeOpts = append(vectorizeOpts, vt)
			} else {
				// Client-side embedding: store the vector in a companion
				// "dense_<name>" column alongside the original text field.
				dstFields = append(dstFields, vikingdb.Field{
					FieldName:  denseFieldName(srcField.Name),
					FieldType:  vikingdb.Vector,
					DefaultVal: nil,
					Dim:        m.getDims(),
				})
			}

		}

		dstField := vikingdb.Field{
			FieldName:    srcField.Name,
			IsPrimaryKey: srcField.IsPrimary,
		}
		switch srcField.Type {
		case searchstore.FieldTypeInt64:
			dstField.FieldType = vikingdb.Int64
		case searchstore.FieldTypeText:
			dstField.FieldType = vikingdb.Text
		case searchstore.FieldTypeDenseVector:
			dstField.FieldType = vikingdb.Vector
			dstField.Dim = m.getDims()
		case searchstore.FieldTypeSparseVector:
			dstField.FieldType = vikingdb.Sparse_Vector
		default:
			return nil, nil, fmt.Errorf("unknown field type: %v", srcField.Type)
		}
		dstFields = append(dstFields, dstField)
	}

	if !foundID {
		dstFields = append(dstFields, vikingdb.Field{
			FieldName:    searchstore.FieldID,
			FieldType:    vikingdb.Int64,
			IsPrimaryKey: true,
		})
	}

	if !foundCreatorID {
		dstFields = append(dstFields, vikingdb.Field{
			FieldName: searchstore.FieldCreatorID,
			FieldType: vikingdb.Int64,
		})
	}

	return dstFields, vectorizeOpts, nil
}
|
||||
|
||||
func (m *manager) newVectorizeModelConf(fieldName string) *vikingdb.VectorizeModelConf {
|
||||
embConfig := m.config.EmbeddingConfig
|
||||
vmc := vikingdb.NewVectorizeModelConf().
|
||||
SetTextField(fieldName).
|
||||
SetModelName(string(embConfig.ModelName)).
|
||||
SetDim(m.getDims())
|
||||
if embConfig.ModelVersion != nil {
|
||||
vmc = vmc.SetModelVersion(ptr.From(embConfig.ModelVersion))
|
||||
}
|
||||
return vmc
|
||||
}
|
||||
|
||||
func (m *manager) getDims() int64 {
|
||||
if m.config.EmbeddingConfig.UseVikingEmbedding {
|
||||
return m.config.EmbeddingConfig.ModelName.Dimensions()
|
||||
}
|
||||
|
||||
return m.config.EmbeddingConfig.BuiltinEmbedding.Dimensions()
|
||||
}
|
||||
@@ -0,0 +1,388 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package vikingdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
||||
"github.com/cloudwego/eino/components/indexer"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
// vkSearchStore implements searchstore.SearchStore for a single VikingDB
// collection, reusing the parent manager's configuration.
type vkSearchStore struct {
	*manager
	collection *vikingdb.Collection
	// index is an optional pre-resolved index client; Retrieve fetches one
	// lazily when nil.
	index *vikingdb.Index
}
|
||||
|
||||
// Store upserts documents into the collection in batches of 100, embedding
// the configured indexing fields client-side when viking embedding is off.
// On success it returns the input documents' IDs.
func (v *vkSearchStore) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
	if len(docs) == 0 {
		return nil, nil
	}

	implSpecOptions := indexer.GetImplSpecificOptions(&searchstore.IndexerOptions{}, opts...)

	// Surface any failure to the optional progress bar before returning.
	defer func() {
		if err != nil {
			if implSpecOptions.ProgressBar != nil {
				_ = implSpecOptions.ProgressBar.ReportError(err)
			}
		}
	}()

	docsWithoutVector, err := slices.TransformWithErrorCheck(docs, v.document2DataWithoutVector)
	if err != nil {
		return nil, fmt.Errorf("[Store] vikingdb failed to transform documents, %w", err)
	}

	indexingFields := sets.FromSlice(implSpecOptions.IndexingFields)
	for _, part := range slices.Chunks(docsWithoutVector, 100) {
		// Fill dense-vector columns (no-op under viking-side embedding).
		docsWithVector, err := v.addEmbedding(ctx, part, indexingFields)
		if err != nil {
			return nil, err
		}

		if err := v.collection.UpsertData(docsWithVector); err != nil {
			return nil, err
		}
	}

	ids = slices.Transform(docs, func(a *schema.Document) string { return a.ID })

	return
}
|
||||
|
||||
// Retrieve searches the collection for query, returning up to TopK
// (default 4) documents. With viking embedding the query text is vectorized
// server-side; otherwise the builtin embedder produces the query vector.
func (v *vkSearchStore) Retrieve(ctx context.Context, query string, opts ...retriever.Option) (docs []*schema.Document, err error) {
	indexClient := v.index
	if indexClient == nil {
		// Verify the managed index exists before fetching a client for it.
		foundIndex := false
		for _, index := range v.collection.Indexes {
			if index.IndexName == vikingIndexName {
				foundIndex = true
				break
			}
		}
		if !foundIndex {
			return nil, fmt.Errorf("[Retrieve] vikingdb index not found, name=%s", vikingIndexName)
		}

		indexClient, err = v.config.Service.GetIndex(v.collection.CollectionName, vikingIndexName)
		if err != nil {
			return nil, fmt.Errorf("[Retrieve] vikingdb failed to get index, %w", err)
		}
	}

	options := retriever.GetCommonOptions(&retriever.Options{TopK: ptr.Of(4)}, opts...)
	implSpecOptions := retriever.GetImplSpecificOptions(&searchstore.RetrieverOptions{}, opts...)

	searchOpts := vikingdb.NewSearchOptions().
		SetLimit(int64(ptr.From(options.TopK))).
		SetText(query).
		SetRetry(true)

	filter, err := v.genFilter(ctx, options, implSpecOptions)
	if err != nil {
		return nil, fmt.Errorf("[Retrieve] vikingdb failed to build filter, %w", err)
	}
	if filter != nil {
		// Cross-partition retrieval is not supported; emulate it with a filter.
		searchOpts = searchOpts.SetFilter(filter)
	}

	var data []*vikingdb.Data

	if v.config.EmbeddingConfig.UseVikingEmbedding {
		data, err = indexClient.SearchWithMultiModal(searchOpts)
	} else {
		// Client-side embedding: a single query yields exactly one vector.
		var dense [][]float64
		dense, err = v.config.EmbeddingConfig.BuiltinEmbedding.EmbedStrings(ctx, []string{query})
		if err != nil {
			return nil, fmt.Errorf("[Retrieve] embed failed, %w", err)
		}
		if len(dense) != 1 {
			return nil, fmt.Errorf("[Retrieve] unexpected dense vector size, expected=1, got=%d", len(dense))
		}
		data, err = indexClient.SearchByVector(dense[0], searchOpts)
	}
	if err != nil {
		return nil, fmt.Errorf("[Retrieve] vikingdb search failed, %w", err)
	}

	docs, err = v.parseSearchResult(data)
	if err != nil {
		return nil, err
	}

	return
}
|
||||
|
||||
func (v *vkSearchStore) Delete(ctx context.Context, ids []string) error {
|
||||
for _, part := range slices.Chunks(ids, 100) {
|
||||
if err := v.collection.DeleteData(part); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *vkSearchStore) document2DataWithoutVector(doc *schema.Document) (data vikingdb.Data, err error) {
|
||||
creatorID, err := document.GetDocumentCreatorID(doc)
|
||||
if err != nil {
|
||||
return data, err
|
||||
}
|
||||
|
||||
docID, err := strconv.ParseInt(doc.ID, 10, 64)
|
||||
if err != nil {
|
||||
return data, err
|
||||
}
|
||||
|
||||
fields := map[string]interface{}{
|
||||
searchstore.FieldID: docID,
|
||||
searchstore.FieldCreatorID: creatorID,
|
||||
searchstore.FieldTextContent: doc.Content,
|
||||
}
|
||||
|
||||
if ext, err := document.GetDocumentExternalStorage(doc); err == nil { // try load
|
||||
for key, val := range ext {
|
||||
fields[key] = val
|
||||
}
|
||||
}
|
||||
return vikingdb.Data{
|
||||
Id: doc.ID,
|
||||
Fields: fields,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// addEmbedding fills the dense-vector columns of rows for every indexing
// field using the builtin embedder. It is a no-op when viking-side
// embedding is enabled (the server vectorizes on upsert). rows are mutated
// in place and returned.
func (v *vkSearchStore) addEmbedding(ctx context.Context, rows []vikingdb.Data, indexingFields map[string]struct{}) ([]vikingdb.Data, error) {
	if v.config.EmbeddingConfig.UseVikingEmbedding {
		return rows, nil
	}

	emb := v.config.EmbeddingConfig.BuiltinEmbedding

	for indexingField := range indexingFields {
		// Collect this field's text value from every row, batching the
		// embedding call per field rather than per row.
		values := make([]string, len(rows))
		for i, row := range rows {
			val, found := row.Fields[indexingField]
			if !found {
				return nil, fmt.Errorf("[addEmbedding] indexing field not found in document, field=%s", indexingField)
			}

			strVal, ok := val.(string)
			if !ok {
				return nil, fmt.Errorf("[addEmbedding] val not string, field=%s, val=%v", indexingField, val)
			}

			values[i] = strVal
		}

		dense, err := emb.EmbedStrings(ctx, values)
		if err != nil {
			return nil, fmt.Errorf("[addEmbedding] failed to embed, %w", err)
		}
		if len(dense) != len(values) {
			return nil, fmt.Errorf("[addEmbedding] unexpected dense vector size, expected=%d, got=%d", len(values), len(dense))
		}

		// Write each vector into the row's companion dense column.
		df := denseFieldName(indexingField)
		for i := range dense {
			rows[i].Fields[df] = dense[i]
		}
	}

	return rows, nil
}
|
||||
|
||||
// parseSearchResult converts raw VikingDB hits into schema.Documents.
// Reserved fields (id, creator id, text content) map onto the document;
// everything else is copied into the external-storage metadata map, with
// json.Number values narrowed to int64/float64 when possible.
func (v *vkSearchStore) parseSearchResult(result []*vikingdb.Data) ([]*schema.Document, error) {
	docs := make([]*schema.Document, 0, len(result))
	for _, data := range result {
		ext := make(map[string]any)
		doc := document.WithDocumentExternalStorage(&schema.Document{MetaData: map[string]any{}}, ext).
			WithScore(data.Score)

		for field, val := range data.Fields {
			switch field {
			case searchstore.FieldID:
				// Numeric fields arrive as json.Number from the SDK.
				jn, ok := val.(json.Number)
				if !ok {
					return nil, fmt.Errorf("[parseSearchResult] id type assertion failed, val=%v", val)
				}
				doc.ID = jn.String()
			case searchstore.FieldCreatorID:
				jn, ok := val.(json.Number)
				if !ok {
					return nil, fmt.Errorf("[parseSearchResult] creator_id type assertion failed, val=%v", val)
				}
				creatorID, err := jn.Int64()
				if err != nil {
					return nil, fmt.Errorf("[parseSearchResult] creator_id value not int64, val=%v", jn.String())
				}
				doc = document.WithDocumentCreatorID(doc, creatorID)
			case searchstore.FieldTextContent:
				text, ok := val.(string)
				if !ok {
					return nil, fmt.Errorf("[parseSearchResult] content value not string, val=%v", val)
				}
				doc.Content = text
			default:
				// Narrow json.Number: prefer int64, then float64, else keep
				// the raw string representation.
				switch t := val.(type) {
				case json.Number:
					if i64, err := t.Int64(); err == nil {
						ext[field] = i64
					} else if f64, err := t.Float64(); err == nil {
						ext[field] = f64
					} else {
						ext[field] = t.String()
					}
				default:
					ext[field] = val
				}
			}
		}

		docs = append(docs, doc)
	}

	return docs, nil
}
|
||||
|
||||
func (v *vkSearchStore) genFilter(ctx context.Context, co *retriever.Options, ro *searchstore.RetrieverOptions) (map[string]any, error) {
|
||||
filter, err := v.dsl2Filter(ctx, co.DSLInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ro.PartitionKey != nil && len(ro.Partitions) > 0 {
|
||||
var (
|
||||
key = ptr.From(ro.PartitionKey)
|
||||
fieldType = ""
|
||||
conds any
|
||||
)
|
||||
for _, field := range v.collection.Fields {
|
||||
if field.FieldName == key {
|
||||
fieldType = field.FieldType
|
||||
}
|
||||
}
|
||||
if fieldType == "" {
|
||||
return nil, fmt.Errorf("[Retrieve] partition key not found, key=%s", key)
|
||||
}
|
||||
|
||||
switch fieldType {
|
||||
case vikingdb.Int64:
|
||||
c := make([]int64, 0, len(ro.Partitions))
|
||||
for _, item := range ro.Partitions {
|
||||
i64, err := strconv.ParseInt(item, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("[Retrieve] partition value parse error, key=%s, val=%v, err=%v", key, item, err)
|
||||
}
|
||||
c = append(c, i64)
|
||||
}
|
||||
conds = c
|
||||
case vikingdb.String:
|
||||
conds = ro.Partitions
|
||||
default:
|
||||
return nil, fmt.Errorf("[Retrieve] invalid field type for partition, key=%s, type=%s", key, fieldType)
|
||||
}
|
||||
|
||||
op := map[string]any{"op": "must", "field": key, "conds": conds}
|
||||
|
||||
if filter != nil {
|
||||
filter = op
|
||||
} else {
|
||||
filter = map[string]any{
|
||||
"op": "and",
|
||||
"conds": []map[string]any{op, filter},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return filter, nil
|
||||
}
|
||||
|
||||
// dsl2Filter recursively converts a searchstore DSL map into a VikingDB
// filter expression. It returns (nil, nil) when src is empty, when the op
// is OpLike (unsupported by VikingDB — logged and skipped), or when the op
// is unrecognized.
// NOTE(review): unknown ops fall through silently to a nil filter; confirm
// that is intended rather than an error.
func (v *vkSearchStore) dsl2Filter(ctx context.Context, src map[string]any) (map[string]any, error) {
	dsl, err := searchstore.LoadDSL(src)
	if err != nil {
		return nil, err
	}
	if dsl == nil {
		return nil, nil
	}

	// VikingDB "conds" always takes a list; wrap scalar values.
	toSliceValue := func(val any) any {
		if reflect.TypeOf(val).Kind() == reflect.Slice {
			return val
		}
		return []any{val}
	}

	var filter map[string]any

	switch dsl.Op {
	case searchstore.OpEq, searchstore.OpIn:
		filter = map[string]any{
			"op":    "must",
			"field": dsl.Field,
			"conds": toSliceValue(dsl.Value),
		}
	case searchstore.OpNe:
		filter = map[string]any{
			"op":    "must_not",
			"field": dsl.Field,
			"conds": toSliceValue(dsl.Value),
		}
	case searchstore.OpLike:
		logs.CtxWarnf(ctx, "[dsl2Filter] vikingdb invalid dsl type, skip, type=%s", dsl.Op)
	case searchstore.OpAnd, searchstore.OpOr:
		// Composite: each child is itself a DSL map, converted recursively.
		var conds []map[string]any
		sub, ok := dsl.Value.([]map[string]any)
		if !ok {
			return nil, fmt.Errorf("[dsl2Filter] invalid value for and/or, should be []map[string]any")
		}
		for _, subDSL := range sub {
			cond, err := v.dsl2Filter(ctx, subDSL)
			if err != nil {
				return nil, err
			}
			conds = append(conds, cond)
		}
		op := "and"
		if dsl.Op == searchstore.OpOr {
			op = "or"
		}
		filter = map[string]any{
			"op":    op,
			"field": dsl.Field,
			"conds": conds,
		}
	}

	return filter, nil
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user