feat: manually mirror opencoze's code from bytedance

Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
fanlv
2025-07-20 17:36:12 +08:00
commit 890153324f
14811 changed files with 1923430 additions and 0 deletions

View File

@@ -0,0 +1,118 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/nl2sql"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
// Prompt formatting templates used to render a table schema as markdown
// for the NL2SQL model.
const (
// defaultTableFmt renders a table header: name, comment, and the
// markdown header row for the column table.
defaultTableFmt = "table name: %s.\ntable describe: %s.\n\n| field name | description | field type | is required |\n"
// defaultColumnFmt renders one column row: name, description, type,
// and the "is required" flag (the inverse of nullable).
defaultColumnFmt = "| %s | %s | %s | %t |\n\n"
)
// NewNL2SQL builds the builtin NL2SQL implementation backed by the given
// chat model and prompt template. The context argument is unused.
func NewNL2SQL(_ context.Context, cm chatmodel.BaseChatModel, tpl prompt.ChatTemplate) (nl2sql.NL2SQL, error) {
	impl := &n2s{
		cm:  cm,
		tpl: tpl,
	}
	return impl, nil
}
// n2s is the builtin NL2SQL implementation driven by an eino compose chain.
type n2s struct {
// NOTE(review): ch and runnable appear unused — NL2SQL rebuilds and
// compiles the chain on every call; consider caching or removing them.
ch *compose.Chain[*nl2sqlInput, string]
runnable compose.Runnable[*nl2sqlInput, string]
// cm is the default chat model; it can be overridden per call via options.
cm chatmodel.BaseChatModel
// tpl is the prompt template fed with the messages and table schema.
tpl prompt.ChatTemplate
}
// NL2SQL renders the conversation and table schemas into a prompt, invokes
// the chat model, and parses the model's JSON reply into a SQL string.
//
// Options may override the chat model; an error is returned when no model is
// configured, when tables is empty, or when the reply carries no SQL.
//
// NOTE(review): the chain is rebuilt and compiled on every call; consider
// compiling once at construction time if this is a hot path.
func (n *n2s) NL2SQL(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) {
o := &nl2sql.Options{ChatModel: n.cm}
for _, opt := range opts {
opt(o)
}
if o.ChatModel == nil {
return "", fmt.Errorf("[NL2SQL] chat model not configured")
}
c := compose.NewChain[*nl2sqlInput, string]().
// Step 1: render the table schemas into a markdown description and
// build the template variables.
AppendLambda(compose.InvokableLambda(func(ctx context.Context, input *nl2sqlInput) (output map[string]any, err error) {
if len(input.tables) == 0 {
return nil, errors.New("table meta is empty")
}
tableDesc := strings.Builder{}
for _, table := range input.tables {
tableDesc.WriteString(fmt.Sprintf(defaultTableFmt, table.Name, table.Comment))
for _, column := range table.Columns {
// "is required" is the inverse of nullable.
tableDesc.WriteString(fmt.Sprintf(defaultColumnFmt, column.Name, column.Description, column.Type.String(), !column.Nullable))
}
}
//logs.CtxInfof(ctx, "table schema: %s", tableDesc.String())
return map[string]interface{}{
"messages": input.messages,
"table_schema": tableDesc.String(),
}, nil
})).
// Step 2: fill the prompt template and call the chat model.
AppendChatTemplate(n.tpl).
AppendChatModel(o.ChatModel).
// Step 3: decode the model's JSON reply into promptResponse and
// extract the SQL, surfacing the model-reported error otherwise.
AppendLambda(compose.InvokableLambda(func(ctx context.Context, msg *schema.Message) (sql string, err error) {
var promptResp *promptResponse
if err := json.Unmarshal([]byte(msg.Content), &promptResp); err != nil {
logs.CtxWarnf(ctx, "unmarshal failed: %v", err)
return "", err
}
if promptResp.SQL == "" {
logs.CtxInfof(ctx, "no sql generated, err_code: %v, err_msg: %v", promptResp.ErrCode, promptResp.ErrMsg)
return "", errors.New(promptResp.ErrMsg)
}
return promptResp.SQL, nil
}))
r, err := c.Compile(ctx)
if err != nil {
return "", err
}
input := &nl2sqlInput{
messages: messages,
tables: tables,
}
return r.Invoke(ctx, input)
}
// nl2sqlInput bundles the chain input: the conversation plus the table
// schemas the generated SQL may reference.
type nl2sqlInput struct {
messages []*schema.Message
tables []*document.TableSchema
}
// promptResponse is the JSON payload the model is prompted to return.
type promptResponse struct {
// SQL is the generated statement; empty when generation failed.
SQL string `json:"sql"`
// ErrCode is 0 on success, non-zero for a model-reported failure reason.
ErrCode int `json:"err_code"`
// ErrMsg is the human-readable failure detail.
ErrMsg string `json:"err_msg"`
}

View File

@@ -0,0 +1,139 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"testing"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/schema"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
)
// TestNL2SQL covers the three main paths of (*n2s).NL2SQL: missing table
// metadata, an unparseable model reply, and a successful SQL extraction.
func TestNL2SQL(t *testing.T) {
ctx := context.Background()
// No tables: the chain's first lambda must reject the request.
t.Run("test table meta not provided", func(t *testing.T) {
impl, err := NewNL2SQL(ctx, &mockChatModel{"mock resp"}, prompt.FromMessages(schema.Jinja2,
schema.SystemMessage("system message 123"),
schema.UserMessage("{{messages}}, {{table_meta}}"),
))
assert.NoError(t, err)
sql, err := impl.NL2SQL(ctx, []*schema.Message{schema.UserMessage("hello")}, nil)
assert.Error(t, err)
assert.Equal(t, "", sql)
})
// Model replies with non-JSON content: unmarshal must fail.
t.Run("test parse failed", func(t *testing.T) {
impl, err := NewNL2SQL(ctx, &mockChatModel{"mock resp"}, prompt.FromMessages(schema.Jinja2,
schema.SystemMessage("system message 123"),
schema.UserMessage("{{messages}}, {{table_meta}}"),
))
assert.NoError(t, err)
sql, err := impl.NL2SQL(ctx, []*schema.Message{schema.UserMessage("hello")}, []*document.TableSchema{
{
Name: "mock_table_1",
Comment: "hello",
Columns: []*document.Column{
{
ID: 121,
Name: "id",
Type: document.TableColumnTypeInteger,
Description: "test",
Nullable: false,
IsPrimary: true,
Sequence: 0,
},
{
ID: 123,
Name: "col_1",
Type: document.TableColumnTypeString,
Description: "column_1",
Nullable: true,
IsPrimary: false,
Sequence: 1,
},
},
},
})
assert.Error(t, err)
assert.Equal(t, "", sql)
})
// Model replies with a well-formed promptResponse: SQL is returned.
t.Run("test success", func(t *testing.T) {
impl, err := NewNL2SQL(ctx, &mockChatModel{`{"sql":"mock sql","err_code":0,"err_msg":""}`}, prompt.FromMessages(schema.Jinja2,
schema.SystemMessage("system message 123"),
schema.UserMessage("{{messages}}, {{table_meta}}"),
))
assert.NoError(t, err)
sql, err := impl.NL2SQL(ctx, []*schema.Message{schema.UserMessage("hello")}, []*document.TableSchema{
{
Name: "mock_table_1",
Comment: "hello",
Columns: []*document.Column{
{
ID: 121,
Name: "id",
Type: document.TableColumnTypeInteger,
Description: "test",
Nullable: false,
IsPrimary: true,
Sequence: 0,
},
{
ID: 123,
Name: "col_1",
Type: document.TableColumnTypeString,
Description: "column_1",
Nullable: true,
IsPrimary: false,
Sequence: 1,
},
},
},
})
assert.NoError(t, err)
assert.Equal(t, "mock sql", sql)
})
}
// mockChatModel is a stub chat model whose Generate always returns a fixed
// assistant message.
type mockChatModel struct {
// content is the raw content of every generated assistant message.
content string
}
// Generate returns a fixed assistant message carrying the mock content;
// the input messages and options are ignored.
func (m mockChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
	msg := schema.AssistantMessage(m.content, nil)
	return msg, nil
}
// Stream is a stub: it reports success with a nil stream.
func (m mockChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
return nil, nil
}
// BindTools is a no-op on the mock model.
func (m mockChatModel) BindTools(tools []*schema.ToolInfo) error {
return nil
}
// sys is the NL2SQL system prompt: it instructs the model to translate
// natural language into MySQL SQL and to reply with a JSON object matching
// promptResponse (sql / err_code / err_msg).
const sys = "# Role: NL2SQL Consultant\n\n## Goals\nTranslate natural language statements into SQL queries in MySQL standard. Follow the Constraints and return only a JSON always.\n\n## Format\n- JSON format only. JSON contains field \"sql\" for generated SQL, filed \"err_code\" for reason type, field \"err_msg\" for detail reason (prefer more than 10 words)\n- Don't use \"```json\" markdown format\n\n## Skills\n- Good at Translate natural language statements into SQL queries in MySQL standard.\n\n## Define\n\"err_code\" Reason Type Define:\n- 0 means you generated a SQL\n- 3002 means you cannot generate a SQL because of timeout\n- 3003 means you cannot generate a SQL because of table schema missing\n- 3005 means you cannot generate a SQL because of some term is ambiguous\n\n## Example\nQ: Help me implement NL2SQL.\n.table schema description: CREATE TABLE `sales_records` (\\n `sales_id` bigint(20) unsigned NOT NULL COMMENT 'id of sales person',\\n `product_id` bigint(64) COMMENT 'id of product',\\n `sale_date` datetime(3) COMMENT 'sold date and time',\\n `quantity_sold` int(11) COMMENT 'sold amount',\\n PRIMARY KEY (`sales_id`)\\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='销售记录表';\n.natural language description of the SQL requirement: ​​​​查询上月的销量总额第一名的销售员和他的销售总额\nA: {\n \"sql\":\"SELECT sales_id, SUM(quantity_sold) AS total_sales FROM sales_records WHERE MONTH(sale_date) = MONTH(CURRENT_DATE - INTERVAL 1 MONTH) AND YEAR(sale_date) = YEAR(CURRENT_DATE - INTERVAL 1 MONTH) GROUP BY sales_id ORDER BY total_sales DESC LIMIT 1\",\n \"err_code\":0,\n \"err_msg\":\"SQL query generated successfully\"\n}"
// usr is the user prompt template carrying the table schema and the request.
// NOTE(review): it references {{tableSchema}} while the chain supplies a
// "table_schema" variable — confirm which template is used at runtime.
const usr = "help me implement NL2SQL.\ntable schema description:{{tableSchema}}\nnatural language description of the SQL requirement: {{chat_history}}."

View File

@@ -0,0 +1,96 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package veocr
import (
"context"
"fmt"
"net/http"
"net/url"
"strconv"
"github.com/volcengine/volc-sdk-golang/service/visual"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
)
// Config configures the Volcengine Visual OCR client and the optional
// request parameters forwarded to the OCRNormal endpoint. Nil pointer
// fields are omitted from the request (except Mode, which newForm defaults
// to "text_block").
type Config struct {
Client *visual.Visual
// see: https://www.volcengine.com/docs/6790/117730
ApproximatePixel *int // default: 0
Mode *string // default: "text_block"
FilterThresh *int // default: 80
HalfToFull *bool // default: false
}
// NewOCR constructs an OCR implementation backed by the Volcengine visual
// service described by config.
func NewOCR(config *Config) ocr.OCR {
	impl := &ocrImpl{config: config}
	return impl
}
// ocrImpl implements ocr.OCR on top of the Volcengine visual service.
type ocrImpl struct {
config *Config
}
// FromBase64 runs OCR on a base64-encoded image and returns the recognized
// text lines.
func (o *ocrImpl) FromBase64(ctx context.Context, b64 string) ([]string, error) {
	form := o.newForm()
	form.Add("image_base64", b64)
	resp, status, err := o.config.Client.OCRNormal(form)
	if err != nil {
		return nil, err
	}
	if status == http.StatusOK {
		return resp.Data.LineTexts, nil
	}
	return nil, fmt.Errorf("[FromBase64] failed, status code=%d", status)
}
// FromURL runs OCR on an image the service fetches from the given URL and
// returns the recognized text lines.
func (o *ocrImpl) FromURL(ctx context.Context, url string) ([]string, error) {
	form := o.newForm()
	form.Add("image_url", url)
	resp, statusCode, err := o.config.Client.OCRNormal(form)
	if err != nil {
		return nil, err
	}
	if statusCode != http.StatusOK {
		// Bug fix: the tag previously said "[FromBase64]" (copy-paste from
		// the sibling method); report the actual call site.
		return nil, fmt.Errorf("[FromURL] failed, status code=%d", statusCode)
	}
	return resp.Data.LineTexts, nil
}
// newForm builds the common OCR request form from the config. Mode falls
// back to "text_block" when unset; every other parameter is only sent when
// set, leaving the service-side default in effect.
func (o *ocrImpl) newForm() url.Values {
form := url.Values{}
if o.config.ApproximatePixel != nil {
form.Add("approximate_pixel", strconv.FormatInt(int64(*o.config.ApproximatePixel), 10))
}
if o.config.Mode != nil {
form.Add("mode", *o.config.Mode)
} else {
// Explicit client-side default, unlike the other parameters.
form.Add("mode", "text_block")
}
if o.config.FilterThresh != nil {
form.Add("filter_thresh", strconv.FormatInt(int64(*o.config.FilterThresh), 10))
}
if o.config.HalfToFull != nil {
form.Add("half_to_full", strconv.FormatBool(*o.config.HalfToFull))
}
return form
}

View File

@@ -0,0 +1,37 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
	"fmt"

	"github.com/coze-dev/coze-studio/backend/infra/contract/document"
)
// alignTableSliceValue coerces each cell in row to the column type declared
// at the same position in schema, preserving the cell's column ID and name.
// Cells are rewritten in place.
//
// Returns an error when a cell cannot be converted, or when row has more
// cells than schema has columns (previously this panicked with an index out
// of range).
func alignTableSliceValue(schema []*document.Column, row []*document.ColumnData) (err error) {
	if len(row) > len(schema) {
		return fmt.Errorf("[alignTableSliceValue] row has %d cells but schema has %d columns", len(row), len(schema))
	}
	for i, col := range row {
		newCol, err := assertValAs(schema[i].Type, col.GetStringValue())
		if err != nil {
			return err
		}
		newCol.ColumnID = col.ColumnID
		newCol.ColumnName = col.ColumnName
		row[i] = newCol
	}
	return nil
}

View File

@@ -0,0 +1,142 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"testing"
"time"
. "github.com/bytedance/mockey"
"github.com/smartystreets/goconvey/convey"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
// TestAssertVal checks type inference for each value shape: empty string,
// boolean, integer, float, formatted time, and the string fallback.
func TestAssertVal(t *testing.T) {
PatchConvey("test assertVal", t, func() {
// Empty values stay untyped.
convey.So(assertVal(""), convey.ShouldEqual, document.ColumnData{
Type: document.TableColumnTypeUnknown,
ValString: ptr.Of(""),
})
convey.So(assertVal("true"), convey.ShouldEqual, document.ColumnData{
Type: document.TableColumnTypeBoolean,
ValBoolean: ptr.Of(true),
})
convey.So(assertVal("10"), convey.ShouldEqual, document.ColumnData{
Type: document.TableColumnTypeInteger,
ValInteger: ptr.Of(int64(10)),
})
convey.So(assertVal("1.0"), convey.ShouldEqual, document.ColumnData{
Type: document.TableColumnTypeNumber,
ValNumber: ptr.Of(1.0),
})
// Round-trip through the format so the parsed value compares equal.
ts := time.Now().Format(timeFormat)
now, err := time.Parse(timeFormat, ts)
convey.So(err, convey.ShouldBeNil)
convey.So(assertVal(ts), convey.ShouldEqual, document.ColumnData{
Type: document.TableColumnTypeTime,
ValTime: ptr.Of(now),
})
convey.So(assertVal("hello"), convey.ShouldEqual, document.ColumnData{
Type: document.TableColumnTypeString,
ValString: ptr.Of("hello"),
})
})
}
// TestAssertValAs table-drives assertValAs over every supported column
// type, checking both successful conversions and parse failures.
func TestAssertValAs(t *testing.T) {
PatchConvey("test assertValAs", t, func() {
type testCase struct {
typ document.TableColumnType
val string
isErr bool // expect a parse/conversion error
data *document.ColumnData // expected result when isErr is false
}
// Round-trip through the format so the parsed value compares equal.
ts := time.Now().Format(timeFormat)
now, _ := time.Parse(timeFormat, ts)
cases := []testCase{
{
typ: document.TableColumnTypeString,
val: "hello",
isErr: false,
data: &document.ColumnData{Type: document.TableColumnTypeString, ValString: ptr.Of("hello")},
},
{
typ: document.TableColumnTypeInteger,
val: "1",
isErr: false,
data: &document.ColumnData{Type: document.TableColumnTypeInteger, ValInteger: ptr.Of(int64(1))},
},
{
typ: document.TableColumnTypeInteger,
val: "hello",
isErr: true,
},
{
typ: document.TableColumnTypeTime,
val: ts,
isErr: false,
data: &document.ColumnData{Type: document.TableColumnTypeTime, ValTime: ptr.Of(now)},
},
{
typ: document.TableColumnTypeTime,
val: "hello",
isErr: true,
},
{
typ: document.TableColumnTypeNumber,
val: "1.0",
isErr: false,
data: &document.ColumnData{Type: document.TableColumnTypeNumber, ValNumber: ptr.Of(1.0)},
},
{
typ: document.TableColumnTypeNumber,
val: "hello",
isErr: true,
},
{
typ: document.TableColumnTypeBoolean,
val: "true",
isErr: false,
data: &document.ColumnData{Type: document.TableColumnTypeBoolean, ValBoolean: ptr.Of(true)},
},
{
typ: document.TableColumnTypeBoolean,
val: "hello",
isErr: true,
},
{
typ: document.TableColumnTypeUnknown,
val: "hello",
isErr: true,
},
}
for _, c := range cases {
v, err := assertValAs(c.typ, c.val)
if c.isErr {
convey.So(err, convey.ShouldNotBeNil)
convey.So(v, convey.ShouldBeNil)
} else {
convey.So(err, convey.ShouldBeNil)
convey.So(v, convey.ShouldEqual, c.data)
}
}
})
}

View File

@@ -0,0 +1,116 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"fmt"
"regexp"
"strings"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
// Regular expressions used by chunkCustom's trim step, compiled once at
// package scope.
var (
// spaceRegex matches runs of whitespace (collapsed to a single space).
spaceRegex = regexp.MustCompile(`\s+`)
// urlRegex matches http(s):// and www. URLs (removed when TrimURLAndEmail is set).
urlRegex = regexp.MustCompile(`https?://\S+|www\.\S+`)
// emailRegex matches common email addresses (removed when TrimURLAndEmail is set).
emailRegex = regexp.MustCompile(`[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}`)
)
// chunkCustom splits text into documents using the configured separator,
// chunk size (in runes), and overlap (a percentage of chunk size). Each
// separator-delimited part is optionally trimmed of URLs/emails and
// whitespace, then emitted in chunks of at most ChunkSize runes; when
// Overlap > 0, the tail of each emitted chunk is carried into the next.
// Configured ExtraMeta is copied into every document's metadata.
func chunkCustom(_ context.Context, text string, config *contract.Config, opts ...parser.Option) (docs []*schema.Document, err error) {
cs := config.ChunkingStrategy
if cs.Overlap >= cs.ChunkSize {
return nil, fmt.Errorf("[chunkCustom] invalid param, overlap >= chunk_size")
}
var (
parts = strings.Split(text, cs.Separator)
buffer []rune
currentLength int64
options = parser.GetCommonOptions(&parser.Options{ExtraMeta: map[string]any{}}, opts...)
)
// trim removes URLs/emails and normalizes whitespace per config.
trim := func(text string) string {
if cs.TrimURLAndEmail {
text = urlRegex.ReplaceAllString(text, "")
text = emailRegex.ReplaceAllString(text, "")
}
if cs.TrimSpace {
text = strings.TrimSpace(text)
text = spaceRegex.ReplaceAllString(text, " ")
}
return text
}
// add flushes the buffer into a new document (no-op when empty).
// NOTE(review): add does not reset currentLength; the trailing add in
// processPart leaves a stale currentLength for the next part's first
// slice — confirm intended.
add := func() {
if len(buffer) == 0 {
return
}
doc := &schema.Document{
Content: string(buffer),
MetaData: map[string]any{},
}
for k, v := range options.ExtraMeta {
doc.MetaData[k] = v
}
docs = append(docs, doc)
buffer = []rune{}
}
// processPart consumes one separator-delimited part, emitting full
// chunks as the buffer fills, then flushes whatever remains — so chunks
// do not span separators.
processPart := func(part string) {
runes := []rune(part)
for partLength := int64(len(runes)); partLength > 0; partLength = int64(len(runes)) {
pos := min(partLength, cs.ChunkSize-currentLength)
buffer = append(buffer, runes[:pos]...)
currentLength = int64(len(buffer))
if currentLength >= cs.ChunkSize {
add()
if cs.Overlap > 0 {
// Seed the next chunk with the tail of the previous one.
buffer = getOverlap([]rune(docs[len(docs)-1].Content), cs.Overlap, cs.ChunkSize)
currentLength = int64(len(buffer))
} else {
currentLength = 0
}
}
runes = runes[pos:]
}
add()
}
for _, part := range parts {
processPart(trim(part))
}
// NOTE(review): processPart already flushes per part, so this final add
// appears to be a defensive no-op.
add()
return docs, nil
}
// getOverlap returns the trailing runes to carry over into the next chunk.
// overlapRatio is a percentage of chunkSize; when runes is not longer than
// the computed overlap, the whole slice is returned unchanged.
func getOverlap(runes []rune, overlapRatio int64, chunkSize int64) []rune {
	overlapLen := int64(float64(chunkSize) * float64(overlapRatio) / 100)
	total := int64(len(runes))
	if total <= overlapLen {
		return runes
	}
	return runes[total-overlapLen:]
}

View File

@@ -0,0 +1,47 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
// TestChunkCustom verifies that newline-separated text with no overlap is
// split into one chunk per line (10 numbered landmarks -> 10 slices).
func TestChunkCustom(t *testing.T) {
ctx := context.Background()
t.Run("test \n no overlap", func(t *testing.T) {
text := "1. Eiffel Tower: Located in Paris, France, it is one of the most famous landmarks in the world, designed by Gustave Eiffel and built in 1889.\n2. The Great Wall: Located in China, it is one of the Seven Wonders of the World, built from the Qin Dynasty to the Ming Dynasty, with a total length of over 20000 kilometers.\n3. Grand Canyon National Park: Located in Arizona, USA, it is famous for its deep canyons and magnificent scenery, which are cut by the Colorado River.\n4. The Colosseum: Located in Rome, Italy, built between 70-80 AD, it was the largest circular arena in the ancient Roman Empire.\n5. Taj Mahal: Located in Agra, India, it was completed by Mughal Emperor Shah Jahan in 1653 to commemorate his wife and is one of the New Seven Wonders of the World.\n6. Sydney Opera House: Located in Sydney Harbour, Australia, it is one of the most iconic buildings of the 20th century, renowned for its unique sailboat design.\n7. Louvre Museum: Located in Paris, France, it is one of the largest museums in the world with a rich collection, including Leonardo da Vinci's Mona Lisa and Greece's Venus de Milo.\n8. Niagara Falls: located at the border of the United States and Canada, consisting of three main waterfalls, its spectacular scenery attracts millions of tourists every year.\n9. St. Sophia Cathedral: located in Istanbul, Türkiye, originally built in 537 A.D., it used to be an Orthodox cathedral and mosque, and now it is a museum.\n10. Machu Picchu: an ancient Inca site located on the plateau of the Andes Mountains in Peru, one of the New Seven Wonders of the World, with an altitude of over 2400 meters."
cs := &parser.ChunkingStrategy{
ChunkType: parser.ChunkTypeCustom,
ChunkSize: 1000,
Separator: "\n",
Overlap: 0,
TrimSpace: true,
TrimURLAndEmail: true,
}
slices, err := chunkCustom(ctx, text, &parser.Config{ChunkingStrategy: cs})
assert.NoError(t, err)
assert.Len(t, slices, 10)
})
}

View File

@@ -0,0 +1,213 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"fmt"
"strconv"
"time"
"unicode/utf8"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
const (
// timeFormat is the reference layout used for parsing and formatting
// time-typed column values.
timeFormat = "2006-01-02 15:04:05"
)
// assertValAs parses val into a ColumnData of the requested column type.
//
// An empty val yields a typed ColumnData with no value set (NULL-like),
// regardless of type. Time values accept either a unix-second timestamp or
// a string in the "2006-01-02 15:04:05" layout. Unsupported types return
// an error.
func assertValAs(typ document.TableColumnType, val string) (*document.ColumnData, error) {
	if val == "" {
		return &document.ColumnData{
			Type: typ,
		}, nil
	}
	switch typ {
	case document.TableColumnTypeString:
		return &document.ColumnData{
			Type:      document.TableColumnTypeString,
			ValString: &val,
		}, nil
	case document.TableColumnTypeInteger:
		i, err := strconv.ParseInt(val, 10, 64)
		if err != nil {
			return nil, err
		}
		return &document.ColumnData{
			Type:       document.TableColumnTypeInteger,
			ValInteger: &i,
		}, nil
	case document.TableColumnTypeTime:
		// Support both unix timestamps and formatted time strings.
		// (A previous val == "" check here was unreachable: empty values
		// are handled at the top of the function.)
		if i, err := strconv.ParseInt(val, 10, 64); err == nil {
			t := time.Unix(i, 0)
			return &document.ColumnData{
				Type:    document.TableColumnTypeTime,
				ValTime: &t,
			}, nil
		}
		t, err := time.Parse(timeFormat, val)
		if err != nil {
			return nil, err
		}
		return &document.ColumnData{
			Type:    document.TableColumnTypeTime,
			ValTime: &t,
		}, nil
	case document.TableColumnTypeNumber:
		f, err := strconv.ParseFloat(val, 64)
		if err != nil {
			return nil, err
		}
		return &document.ColumnData{
			Type:      document.TableColumnTypeNumber,
			ValNumber: &f,
		}, nil
	case document.TableColumnTypeBoolean:
		t, err := strconv.ParseBool(val)
		if err != nil {
			return nil, err
		}
		return &document.ColumnData{
			Type:       document.TableColumnTypeBoolean,
			ValBoolean: &t,
		}, nil
	case document.TableColumnTypeImage:
		return &document.ColumnData{
			Type:     document.TableColumnTypeImage,
			ValImage: &val,
		}, nil
	default:
		return nil, fmt.Errorf("[assertValAs] type not support, type=%d, val=%s", typ, val)
	}
}
// assertValAsForce converts val to the given column type without ever
// failing: when parsing fails, the value is left unset if nullable, or set
// to the type's zero value otherwise. Unknown types fall back to string.
func assertValAsForce(typ document.TableColumnType, val string, nullable bool) *document.ColumnData {
	data := &document.ColumnData{Type: typ}
	switch typ {
	case document.TableColumnTypeString:
		data.ValString = ptr.Of(val)
	case document.TableColumnTypeInteger:
		i, err := strconv.ParseInt(val, 10, 64)
		if err == nil {
			data.ValInteger = &i
		} else if !nullable {
			data.ValInteger = ptr.Of(int64(0))
		}
	case document.TableColumnTypeTime:
		t, err := time.Parse(timeFormat, val)
		if err == nil {
			data.ValTime = &t
		} else if !nullable {
			data.ValTime = ptr.Of(time.Time{})
		}
	case document.TableColumnTypeNumber:
		f, err := strconv.ParseFloat(val, 64)
		if err == nil {
			data.ValNumber = &f
		} else if !nullable {
			data.ValNumber = ptr.Of(0.0)
		}
	case document.TableColumnTypeBoolean:
		b, err := strconv.ParseBool(val)
		if err == nil {
			data.ValBoolean = &b
		} else if !nullable {
			data.ValBoolean = ptr.Of(false)
		}
	case document.TableColumnTypeImage:
		data.ValImage = ptr.Of(val)
	default:
		data.ValString = ptr.Of(val)
	}
	return data
}
// assertVal infers a column type for val by attempting parses in order:
// bool, int, float, time, then string as the fallback. An empty string is
// reported as Unknown.
//
// NOTE(review): because bool is tried before int, values like "1"/"0"/"t"
// are classified as booleans (strconv.ParseBool accepts them) — confirm
// this ordering is intended.
func assertVal(val string) document.ColumnData {
// TODO: image type is not handled yet.
if val == "" {
return document.ColumnData{
Type: document.TableColumnTypeUnknown,
ValString: &val,
}
}
if t, err := strconv.ParseBool(val); err == nil {
return document.ColumnData{
Type: document.TableColumnTypeBoolean,
ValBoolean: &t,
}
}
if i, err := strconv.ParseInt(val, 10, 64); err == nil {
return document.ColumnData{
Type: document.TableColumnTypeInteger,
ValInteger: &i,
}
}
if f, err := strconv.ParseFloat(val, 64); err == nil {
return document.ColumnData{
Type: document.TableColumnTypeNumber,
ValNumber: &f,
}
}
if t, err := time.Parse(timeFormat, val); err == nil {
return document.ColumnData{
Type: document.TableColumnTypeTime,
ValTime: &t,
}
}
return document.ColumnData{
Type: document.TableColumnTypeString,
ValString: &val,
}
}
// transformColumnType merges an observed column type (src) into a
// previously inferred one (dst), widening where needed:
//   - unknown on either side yields the other side;
//   - string as the destination always wins;
//   - equal types are kept;
//   - integer widens to number;
//   - any other mismatch degrades to string.
func transformColumnType(src, dst document.TableColumnType) document.TableColumnType {
	switch {
	case src == document.TableColumnTypeUnknown:
		return dst
	case dst == document.TableColumnTypeUnknown:
		return src
	case dst == document.TableColumnTypeString:
		return dst
	case src == dst:
		return dst
	case src == document.TableColumnTypeInteger && dst == document.TableColumnTypeNumber:
		return dst
	default:
		return document.TableColumnTypeString
	}
}
// charCount reports the number of Unicode code points (not bytes) in text.
func charCount(text string) int64 {
	n := utf8.RuneCountInString(text)
	return int64(n)
}

View File

@@ -0,0 +1,36 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"fmt"
"time"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
)
// putImageObject uploads the image bytes to object storage under the
// knowledge prefix and returns the string to use as the image's src
// attribute (rendered via imgSrcFormat).
//
// The object name embeds the uid, a nanosecond timestamp, and a secret from
// createSecret so it is unique and unguessable.
//
// (The named return was renamed from `format` to `imgSrc`: the function
// returns an image source string, not a format.)
func putImageObject(ctx context.Context, st storage.Storage, imgExt string, uid int64, img []byte) (imgSrc string, err error) {
	secret := createSecret(uid, imgExt)
	fileName := fmt.Sprintf("%d_%d_%s.%s", uid, time.Now().UnixNano(), secret, imgExt)
	objectName := fmt.Sprintf("%s/%s", knowledgePrefix, fileName)
	if err := st.PutObject(ctx, objectName, img); err != nil {
		return "", err
	}
	return fmt.Sprintf(imgSrcFormat, objectName), nil
}

View File

@@ -0,0 +1,77 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"fmt"
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/pkg/goutil"
)
// NewManager builds the builtin parser manager with its object storage
// backend, OCR engine, and image-annotation chat model.
func NewManager(storage storage.Storage, ocr ocr.OCR, imageAnnotationModel chatmodel.BaseChatModel) parser.Manager {
	m := &manager{
		ocr:     ocr,
		storage: storage,
		model:   imageAnnotationModel,
	}
	return m
}
// manager implements parser.Manager, dispatching to a file-type-specific
// parse function in GetParser.
type manager struct {
// ocr extracts text from images embedded in documents.
ocr ocr.OCR
// storage holds extracted images and intermediate artifacts.
storage storage.Storage
// model annotates standalone image files.
model chatmodel.BaseChatModel
}
func (m *manager) GetParser(config *parser.Config) (parser.Parser, error) {
var pFn parseFn
if config.ParsingStrategy.HeaderLine == 0 && config.ParsingStrategy.DataStartLine == 0 {
config.ParsingStrategy.DataStartLine = 1
} else if config.ParsingStrategy.HeaderLine >= config.ParsingStrategy.DataStartLine {
return nil, fmt.Errorf("[GetParser] invalid header line and data start line, header=%d, data_start=%d",
config.ParsingStrategy.HeaderLine, config.ParsingStrategy.DataStartLine)
}
switch config.FileExtension {
case parser.FileExtensionPDF:
pFn = parseByPython(config, m.storage, m.ocr, goutil.GetPython3Path(), goutil.GetPythonFilePath("parse_pdf.py"))
case parser.FileExtensionTXT:
pFn = parseText(config)
case parser.FileExtensionMarkdown:
pFn = parseMarkdown(config, m.storage, m.ocr)
case parser.FileExtensionDocx:
pFn = parseByPython(config, m.storage, m.ocr, goutil.GetPython3Path(), goutil.GetPythonFilePath("parse_docx.py"))
case parser.FileExtensionCSV:
pFn = parseCSV(config)
case parser.FileExtensionXLSX:
pFn = parseXLSX(config)
case parser.FileExtensionJSON:
pFn = parseJSON(config)
case parser.FileExtensionJsonMaps:
pFn = parseJSONMaps(config)
case parser.FileExtensionJPG, parser.FileExtensionJPEG, parser.FileExtensionPNG:
pFn = parseImage(config, m.model)
default:
return nil, fmt.Errorf("[Parse] document type not support, type=%s", config.FileExtension)
}
return &p{parseFn: pFn}, nil
}

View File

@@ -0,0 +1,53 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"encoding/csv"
"errors"
"io"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
"github.com/dimchansky/utfbom"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
// parseCSV returns a parseFn that reads CSV content (skipping any leading
// BOM) and converts rows into documents via parseByRowIterator.
func parseCSV(config *contract.Config) parseFn {
	return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
		r := csv.NewReader(utfbom.SkipOnly(reader))
		return parseByRowIterator(&csvIterator{reader: r}, config, opts...)
	}
}
// csvIterator adapts csv.Reader to the row-iterator interface consumed by
// parseByRowIterator.
type csvIterator struct {
reader *csv.Reader
}
// NextRow reads the next CSV record. end is true (with a nil error) once
// the input is exhausted; any other read failure is returned as err.
func (c *csvIterator) NextRow() (row []string, end bool, err error) {
	row, err = c.reader.Read()
	if err != nil {
		if errors.Is(err, io.EOF) {
			return nil, true, nil
		}
		// Bug fix: this previously returned the zero-valued named `err`
		// while the actual read error was held in a local, silently
		// dropping read failures.
		return nil, false, err
	}
	return row, false, nil
}

View File

@@ -0,0 +1,200 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"bytes"
"context"
"fmt"
"io"
"os"
"testing"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
// TestParseCSV parses the CSV fixture twice: first inferring the column
// schema from the header line, then with a fully specified column schema,
// asserting the per-document column metadata in both cases via assertSheet.
func TestParseCSV(t *testing.T) {
	ctx := context.Background()
	b, err := os.ReadFile("./test_data/test_csv.csv")
	assert.NoError(t, err)

	// Case 1: no Columns supplied — names/types are inferred from the header.
	r1 := bytes.NewReader(b)
	c1 := &contract.Config{
		FileExtension: contract.FileExtensionCSV,
		ParsingStrategy: &contract.ParsingStrategy{
			HeaderLine:    0,
			DataStartLine: 1,
			RowsCount:     20,
		},
		ChunkingStrategy: nil,
	}
	p1 := parseCSV(c1)
	docs, err := p1(ctx, r1, parser.WithExtraMeta(map[string]any{
		"document_id":  int64(123),
		"knowledge_id": int64(456),
	}))
	assert.NoError(t, err)
	for i, doc := range docs {
		assertSheet(t, i, doc)
	}

	// Case 2: parse again with an explicit column schema covering all types.
	r2 := bytes.NewReader(b)
	c2 := &contract.Config{
		FileExtension: contract.FileExtensionCSV,
		ParsingStrategy: &contract.ParsingStrategy{
			HeaderLine:    0,
			DataStartLine: 1,
			RowsCount:     10,
			Columns: []*document.Column{
				{
					ID:       0,
					Name:     "col_string_indexing",
					Type:     document.TableColumnTypeString,
					Nullable: false,
					Sequence: 0,
				},
				{
					ID:       0,
					Name:     "col_string",
					Type:     document.TableColumnTypeString,
					Nullable: false,
					Sequence: 1,
				},
				{
					ID:       0,
					Name:     "col_int",
					Type:     document.TableColumnTypeInteger,
					Nullable: false,
					Sequence: 2,
				},
				{
					ID:       0,
					Name:     "col_number",
					Type:     document.TableColumnTypeNumber,
					Nullable: true,
					Sequence: 3,
				},
				{
					ID:       0,
					Name:     "col_bool",
					Type:     document.TableColumnTypeBoolean,
					Nullable: true,
					Sequence: 4,
				},
				{
					ID:       0,
					Name:     "col_time",
					Type:     document.TableColumnTypeTime,
					Nullable: true,
					Sequence: 5,
				},
			},
		},
		ChunkingStrategy: nil,
	}
	p2 := parseCSV(c2)
	docs, err = p2(ctx, r2, parser.WithExtraMeta(map[string]any{
		"document_id":  int64(123),
		"knowledge_id": int64(456),
	}))
	assert.NoError(t, err)
	for i, doc := range docs {
		assertSheet(t, i, doc)
	}
}
// TestParseCSVBadCases covers a fixture whose rows contain missing values:
// the file is first parsed with inferred columns, then re-parsed after
// marking column 5 as non-nullable, asserting that IgnoreColumnTypeErr
// forces a usable value into every row.
func TestParseCSVBadCases(t *testing.T) {
	t.Run("test nil row", func(t *testing.T) {
		ctx := context.Background()
		f, err := os.Open("test_data/test_csv_badcase_1.csv")
		assert.NoError(t, err)
		b, err := io.ReadAll(f)
		assert.NoError(t, err)
		// First pass: infer the schema from the file itself.
		pfn := parseCSV(&contract.Config{
			FileExtension: "csv",
			ParsingStrategy: &contract.ParsingStrategy{
				ExtractImage:        true,
				ExtractTable:        true,
				ImageOCR:            false,
				SheetID:             nil,
				HeaderLine:          0,
				DataStartLine:       1,
				RowsCount:           0,
				IsAppend:            false,
				Columns:             nil,
				IgnoreColumnTypeErr: true,
				ImageAnnotationType: 0,
			},
		})
		resp, err := pfn(ctx, bytes.NewReader(b))
		assert.NoError(t, err)
		assert.True(t, len(resp) > 0)
		cols, err := document.GetDocumentColumns(resp[0])
		assert.NoError(t, err)
		// Tighten the inferred schema: column 5 must now always hold a value.
		cols[5].Nullable = false
		npfn := parseCSV(&contract.Config{
			FileExtension: "csv",
			ParsingStrategy: &contract.ParsingStrategy{
				ExtractImage:        true,
				ExtractTable:        true,
				ImageOCR:            false,
				SheetID:             nil,
				HeaderLine:          0,
				DataStartLine:       1,
				RowsCount:           0,
				IsAppend:            false,
				Columns:             cols,
				IgnoreColumnTypeErr: true,
				ImageAnnotationType: 0,
			},
		})
		resp, err = npfn(ctx, bytes.NewReader(b))
		assert.NoError(t, err)
		assert.True(t, len(resp) > 0)
		for _, item := range resp {
			data, err := document.GetDocumentColumnData(item)
			assert.NoError(t, err)
			// Non-nullable column must carry a concrete (forced) value.
			assert.NotNil(t, data[5].GetValue())
		}
	})
}
// assertSheet checks that a parsed table document carries the expected column
// schema, per-row column data, and the extra metadata injected by the parser
// options, and dumps the row values for manual inspection.
func assertSheet(t *testing.T, i int, doc *schema.Document) {
	fmt.Printf("sheet[%d]:\n", i)
	meta := doc.MetaData
	assert.NotNil(t, meta)
	assert.NotNil(t, meta[document.MetaDataKeyColumns])
	cols, ok := meta[document.MetaDataKeyColumns].([]*document.Column)
	assert.True(t, ok)
	assert.NotNil(t, meta[document.MetaDataKeyColumnData])
	rowData, ok := meta[document.MetaDataKeyColumnData].([]*document.ColumnData)
	assert.True(t, ok)
	assert.Equal(t, int64(123), meta["document_id"].(int64))
	assert.Equal(t, int64(456), meta["knowledge_id"].(int64))
	for idx, cell := range rowData {
		fmt.Printf("row[%d]: %v=%v\n", idx, cols[idx].Name, cell.GetStringValue())
	}
}

View File

@@ -0,0 +1,172 @@
import io
import os
import json
import sys
import base64
import logging
import time
from abc import ABC
from typing import List, IO
from docx import ImagePart
from docx.oxml import CT_P, CT_Tbl
from docx.table import Table
from docx.text.paragraph import Paragraph
from docx import Document
from PIL import Image
logger = logging.getLogger(__name__)
class DocxLoader(ABC):
    """Extract text, inline images and tables from a .docx byte stream.

    load() returns a flat list of dicts, in document order:
      {"type": "text",  "content": <str>}
      {"type": "image", "content": <base64 PNG str>}
      {"type": "table", "table":   <list[list[str]]>}
    """

    def __init__(
        self,
        file_content: IO[bytes],
        extract_images: bool = True,
        extract_tables: bool = True,
    ):
        # Raw .docx content as a binary stream.
        self.file_content = file_content
        # Whether images / tables are emitted as their own result items.
        self.extract_images = extract_images
        self.extract_tables = extract_tables

    def load(self) -> List[dict]:
        """Walk the document body, accumulating paragraph text and flushing it
        as one "text" item whenever an image or table interrupts the flow (or
        at the end of the document)."""
        result = []
        doc = Document(self.file_content)
        it = iter(doc.element.body)
        text = ""
        for part in it:
            blocks = self.parse_part(part, doc)
            if blocks is None or len(blocks) == 0:
                continue
            for block in blocks:
                # parse_part/parse_run return a list element when a run holds
                # one or more ImageParts.
                if self.extract_images and isinstance(block, list):
                    for b in block:
                        image = io.BytesIO()
                        try:
                            # Re-encode every image to PNG regardless of its
                            # original format inside the archive.
                            Image.open(io.BytesIO(b.image.blob)).save(image, format="png")
                        except Exception as e:
                            logging.error(f"load image failed, time={time.asctime()}, err:{e}")
                            raise RuntimeError("ExtractImageError")
                        # Flush pending text so ordering is preserved.
                        if len(text) > 0:
                            result.append(
                                {
                                    "content": text,
                                    "type": "text",
                                }
                            )
                            text = ""
                        result.append(
                            {
                                "content": base64.b64encode(image.getvalue()).decode('utf-8'),
                                "type": "image",
                            }
                        )
                if isinstance(block, Paragraph):
                    text += block.text
                if self.extract_tables and isinstance(block, Table):
                    rows = block.rows
                    if len(text) > 0:
                        result.append(
                            {
                                "content": text,
                                "type": "text",
                            }
                        )
                        text = ""
                    table = self.convert_table(rows)
                    result.append(
                        {
                            "table": table,
                            "type": "table",
                        }
                    )
            # Paragraph separator between consecutive body parts.
            if text:
                text += "\n\n"
        # Flush whatever text remains after the last body part.
        if len(text) > 0:
            result.append(
                {
                    "content": text,
                    "type": "text",
                }
            )
        return result

    def parse_part(self, block, doc: Document):
        """Map one raw body element to a list of Paragraph / Table / ImagePart
        blocks; returns None for element types that are not handled."""
        if isinstance(block, CT_P):
            blocks = []
            para = Paragraph(block, doc)
            image_part = self.get_image_part(para, doc)
            if image_part and para.text:
                # Mixed paragraph (text + pictures): split per run to keep
                # the original interleaving order.
                blocks.extend(self.parse_run(para))
            elif image_part:
                blocks.append(image_part)
            elif para.text:
                blocks.append(para)
            return blocks
        elif isinstance(block, CT_Tbl):
            return [Table(block, doc)]

    def parse_run(self, para: Paragraph):
        """Split a paragraph into runs, yielding ImagePart lists for runs that
        contain pictures and Paragraph wrappers for textual runs."""
        runs = para.runs
        paras = []
        if runs is None or len(runs) == 0:
            return paras
        for run in runs:
            if run is None or run.element is None:
                continue
            p = Paragraph(run.element, para)
            # NOTE(review): `para` is passed where a Document is expected;
            # this relies on Paragraph exposing a compatible `.part` — verify.
            image_part = self.get_image_part(p, para)
            if image_part:
                paras.append(image_part)
            else:
                paras.append(p)
        return paras

    @staticmethod
    def get_image_part(graph: Paragraph, doc: Document):
        """Collect the ImageParts referenced by <pic:pic> elements inside the
        paragraph, resolving r:embed relationship ids against the package."""
        images = graph._element.xpath(".//pic:pic")
        image_parts = []
        for image in images:
            for img_id in image.xpath(".//a:blip/@r:embed"):
                part = doc.part.related_parts[img_id]
                if isinstance(part, ImagePart):
                    image_parts.append(part)
        return image_parts

    @staticmethod
    def convert_table(rows) -> List[List[str]]:
        """Convert table rows into a 2-D list of cell texts (None cells -> '')."""
        resp_rows = []
        for i, row in enumerate(rows):
            resp_row = []
            for j, cell in enumerate(row.cells):
                resp_row.append(cell.text if cell is not None else '')
            resp_rows.append(resp_row)
        return resp_rows
if __name__ == "__main__":
    # Subprocess protocol: the parent wires fd 4 as the request pipe (JSON
    # options), stdin as the raw .docx bytes, and fd 3 as the response pipe.
    # stdout is only used for progress/log lines.
    w = os.fdopen(3, "wb", )
    r = os.fdopen(4, "rb", )
    try:
        req = json.load(r)
        ei, et = req['extract_images'], req['extract_tables']
        loader = DocxLoader(file_content=io.BytesIO(sys.stdin.buffer.read()), extract_images=ei, extract_tables=et)
        resp = loader.load()
        print(f"Extracted {len(resp)} items")
        # Success response shape: {"content": [...]} — keep non-ASCII as-is.
        result = json.dumps({"content": resp}, ensure_ascii=False)
        w.write(str.encode(result))
        w.flush()
        w.close()
        print("Docx parse done")
    except Exception as e:
        # Failure response shape: {"error": "..."}; the reader side is
        # expected to distinguish it from the success shape.
        print("Docx parse error", e)
        w.write(str.encode(json.dumps({"error": str(e)})))
        w.flush()
        w.close()

View File

@@ -0,0 +1,91 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"encoding/base64"
"fmt"
"io"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
// parseImage builds a parseFn that turns a single image into exactly one
// document. With ImageAnnotationTypeModel, the image bytes are inlined as a
// base64 data URL and captioned by the chat model; with
// ImageAnnotationTypeManual the content is left empty for later annotation.
func parseImage(config *contract.Config, model chatmodel.BaseChatModel) parseFn {
	return func(ctx context.Context, reader io.Reader, opts ...parser.Option) ([]*schema.Document, error) {
		commonOpts := parser.GetCommonOptions(&parser.Options{}, opts...)
		doc := &schema.Document{MetaData: map[string]any{}}
		for key, val := range commonOpts.ExtraMeta {
			doc.MetaData[key] = val
		}
		switch config.ParsingStrategy.ImageAnnotationType {
		case contract.ImageAnnotationTypeManual:
			// Caption will be provided manually; nothing to do here.
		case contract.ImageAnnotationTypeModel:
			if model == nil {
				return nil, errorx.New(errno.ErrKnowledgeNonRetryableCode, errorx.KV("reason", "model is not provided"))
			}
			raw, err := io.ReadAll(reader)
			if err != nil {
				return nil, err
			}
			mime := fmt.Sprintf("image/%s", config.FileExtension)
			dataURL := fmt.Sprintf("data:%s;base64,%s", mime, base64.StdEncoding.EncodeToString(raw))
			msg := &schema.Message{
				Role: schema.User,
				MultiContent: []schema.ChatMessagePart{
					{
						Type: schema.ChatMessagePartTypeText,
						//Text: "Give a short description of the image.", // TODO: prompt in current language
						Text: "简短描述下这张图片",
					},
					{
						Type: schema.ChatMessagePartTypeImageURL,
						ImageURL: &schema.ChatMessageImageURL{
							URL:      dataURL,
							MIMEType: mime,
						},
					},
				},
			}
			output, err := model.Generate(ctx, []*schema.Message{msg})
			if err != nil {
				return nil, fmt.Errorf("[parseImage] model generate failed: %w", err)
			}
			doc.Content = output.Content
		default:
			return nil, fmt.Errorf("[parseImage] unknown image annotation type=%d", config.ParsingStrategy.ImageAnnotationType)
		}
		return []*schema.Document{doc}, nil
	}
}

View File

@@ -0,0 +1,170 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"encoding/json"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
// rowIterator yields table rows one at a time. end reports that the source
// is exhausted (row is nil in that case); err reports a read failure.
type rowIterator interface {
	NextRow() (row []string, end bool, err error)
}
// parseByRowIterator converts tabular rows from iter into schema.Documents,
// one document per data row.
//
// Behavior:
//   - The row at ps.HeaderLine provides the column schema. When a schema is
//     already supplied (append mode or ps.Columns non-empty), that schema is
//     used instead and cell values are validated/converted against the
//     declared types.
//   - Rows from ps.DataStartLine onward become data rows. Without a supplied
//     schema, values are kept as raw strings while column types are inferred,
//     then aligned after reading.
//   - ps.RowsCount, when non-zero, caps the number of parsed data rows.
//   - With no data rows at all, a single schema-only document (flagged with
//     MetaDataKeyColumnsOnly) is returned.
func parseByRowIterator(iter rowIterator, config *contract.Config, opts ...parser.Option) (
	docs []*schema.Document, err error) {
	ps := config.ParsingStrategy
	options := parser.GetCommonOptions(&parser.Options{}, opts...)
	i := 0
	columnsProvides := ps.IsAppend || len(ps.Columns) > 0
	rev := make(map[int]*document.Column) // column sequence -> schema
	var (
		expColumns []*document.Column
		expData    [][]*document.ColumnData
	)
	for {
		row, end, err := iter.NextRow()
		if err != nil {
			return nil, err
		}
		if end {
			break
		}
		if i == ps.HeaderLine {
			if columnsProvides {
				expColumns = ps.Columns
			} else {
				for j, col := range row {
					expColumns = append(expColumns, &document.Column{
						Name:     col,
						Type:     document.TableColumnTypeUnknown,
						Sequence: j,
					})
				}
			}
			for j := range expColumns {
				tc := expColumns[j]
				rev[tc.Sequence] = tc
			}
		}
		if i >= ps.DataStartLine {
			var rowData []*document.ColumnData
			for j := range row {
				colSchema, found := rev[j]
				if !found { // column pruning: skip cells with no schema entry
					continue
				}
				val := row[j]
				if columnsProvides {
					var data *document.ColumnData
					if ps.IgnoreColumnTypeErr {
						// Force a best-effort conversion instead of failing.
						data = assertValAsForce(colSchema.Type, val, colSchema.Nullable)
					} else {
						data, err = assertValAs(colSchema.Type, val)
						if err != nil {
							return nil, err
						}
					}
					data.ColumnID = colSchema.ID
					data.ColumnName = colSchema.Name
					rowData = append(rowData, data)
				} else {
					// Schema inference: keep the raw string, widen the
					// column type as new values are observed.
					exp := assertVal(val)
					colSchema.Type = transformColumnType(colSchema.Type, exp.Type)
					rowData = append(rowData, &document.ColumnData{
						ColumnID:   colSchema.ID,
						ColumnName: colSchema.Name,
						Type:       document.TableColumnTypeUnknown,
						ValString:  &val,
					})
				}
			}
			if rowData != nil {
				expData = append(expData, rowData)
			}
		}
		i++
		// Bug fix: the cap previously compared ps.RowsCount against len(docs),
		// which is always zero inside this loop (docs is only filled after
		// it), so RowsCount never limited anything. Compare against the
		// number of collected data rows instead.
		if ps.RowsCount != 0 && len(expData) == ps.RowsCount {
			break
		}
	}
	if !columnsProvides {
		// Schema was inferred: default unresolved column types to string and
		// coerce the collected raw values to the final column types.
		for _, col := range expColumns {
			if col.Type == document.TableColumnTypeUnknown {
				col.Type = document.TableColumnTypeString
			}
		}
		for _, row := range expData {
			if err = alignTableSliceValue(expColumns, row); err != nil {
				return nil, err
			}
		}
	}
	if len(expData) == 0 {
		// return a special document with columns only if there is no data
		doc := &schema.Document{
			MetaData: map[string]any{
				document.MetaDataKeyColumns:     expColumns,
				document.MetaDataKeyColumnsOnly: struct{}{},
			},
		}
		for k, v := range options.ExtraMeta {
			doc.MetaData[k] = v
		}
		return []*schema.Document{doc}, nil
	}
	for j := range expData {
		contentMapping := make(map[string]string)
		for _, col := range expData[j] {
			contentMapping[col.ColumnName] = col.GetStringValue()
		}
		b, err := json.Marshal(contentMapping)
		if err != nil {
			return nil, err
		}
		doc := &schema.Document{
			Content: string(b), // set for tables in text
			MetaData: map[string]any{
				document.MetaDataKeyColumns:    expColumns,
				document.MetaDataKeyColumnData: expData[j],
			},
		}
		for k, v := range options.ExtraMeta {
			doc.MetaData[k] = v
		}
		docs = append(docs, doc)
	}
	return docs, nil
}

View File

@@ -0,0 +1,97 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"encoding/json"
"fmt"
"io"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
// parseJSON builds a parseFn for a JSON array of flat string objects
// ([{"col": "val", ...}, ...]); each object becomes one table row.
func parseJSON(config *contract.Config) parseFn {
	return func(ctx context.Context, reader io.Reader, opts ...parser.Option) ([]*schema.Document, error) {
		raw, err := io.ReadAll(reader)
		if err != nil {
			return nil, err
		}
		var items []map[string]string
		if err = json.Unmarshal(raw, &items); err != nil {
			return nil, err
		}
		if len(items) == 0 {
			return nil, fmt.Errorf("[parseJSON] json data is empty")
		}
		var header []string
		if config.ParsingStrategy.IsAppend {
			// Append mode: the header comes from the declared columns.
			for _, col := range config.ParsingStrategy.Columns {
				header = append(header, col.Name)
			}
		} else {
			// Header order follows Go's (randomized) map iteration over the
			// keys of the first item.
			for key := range items[0] {
				header = append(header, key)
			}
		}
		it := &jsonIterator{
			header: header,
			rows:   items,
			i:      0,
		}
		return parseByRowIterator(it, config, opts...)
	}
}
// jsonIterator replays a decoded JSON array as table rows: the first NextRow
// call yields the header, subsequent calls yield one data row per item, with
// absent keys filled in as empty strings.
type jsonIterator struct {
	header []string
	rows   []map[string]string
	i      int
}

// NextRow implements rowIterator over the decoded JSON items.
func (j *jsonIterator) NextRow() ([]string, bool, error) {
	if j.i == 0 {
		j.i++
		return j.header, false, nil
	}
	if j.i > len(j.rows) {
		return nil, true, nil
	}
	item := j.rows[j.i-1]
	j.i++
	row := make([]string, 0, len(j.header))
	for _, name := range j.header {
		row = append(row, item[name]) // map zero value "" covers absent keys
	}
	return row, false, nil
}

View File

@@ -0,0 +1,130 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"encoding/json"
"fmt"
"io"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
// parseJSONMaps builds a parseFn for pre-structured custom content: a JSON
// array of string->string maps where each map is one table row. A synthetic
// header row (derived from the first item's keys) is always emitted by the
// iterator, so the rebuilt config pins HeaderLine=0 / DataStartLine=1.
func parseJSONMaps(config *contract.Config) parseFn {
	return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
		b, err := io.ReadAll(reader)
		if err != nil {
			return nil, err
		}
		var customContent []map[string]string
		if err = json.Unmarshal(b, &customContent); err != nil {
			return nil, err
		}
		// NOTE(review): this mutates the caller's config when ParsingStrategy
		// is nil — confirm callers never share/reuse the config concurrently.
		if config.ParsingStrategy == nil {
			config.ParsingStrategy = &contract.ParsingStrategy{
				HeaderLine:    0,
				DataStartLine: 1,
				RowsCount:     0,
			}
		}
		iter := &customContentContainer{
			i:             0,
			colIdx:        nil,
			customContent: customContent,
			curColumns:    config.ParsingStrategy.Columns,
		}
		// Rebuild the config with fixed header/data positions and no row cap;
		// everything else is carried over from the caller's config.
		newConfig := &contract.Config{
			FileExtension: config.FileExtension,
			ParsingStrategy: &contract.ParsingStrategy{
				SheetID:       config.ParsingStrategy.SheetID,
				HeaderLine:    0,
				DataStartLine: 1,
				RowsCount:     0,
				IsAppend:      config.ParsingStrategy.IsAppend,
				Columns:       config.ParsingStrategy.Columns,
			},
			ChunkingStrategy: config.ChunkingStrategy,
		}
		return parseByRowIterator(iter, newConfig, opts...)
	}
}
// customContentContainer adapts pre-structured rows ([]map[string]string) to
// the rowIterator interface. The first NextRow call synthesizes a header from
// the keys of the first row (declared columns first, in declared order,
// followed by any undeclared keys); subsequent calls emit every row,
// including the first one.
type customContentContainer struct {
	i             int            // index of the next data row to emit
	colIdx        map[string]int // column name -> position in the emitted row
	customContent []map[string]string
	curColumns    []*document.Column // declared columns, pinned to the front of the header
}

// NextRow implements rowIterator over the custom content. Keys present in a
// row but absent from the header produce an error; header columns missing
// from a row are emitted as empty strings.
func (c *customContentContainer) NextRow() (row []string, end bool, err error) {
	if c.i == 0 && c.colIdx == nil {
		if len(c.customContent) == 0 {
			return nil, false, fmt.Errorf("[customContentContainer] data is nil")
		}
		headerRow := c.customContent[0]
		founded := make(map[string]struct{})
		colIdx := make(map[string]int, len(headerRow))
		// Declared columns come first, keeping their declared order.
		for _, col := range c.curColumns {
			name := col.Name
			if _, found := headerRow[name]; found {
				founded[name] = struct{}{}
				colIdx[name] = len(colIdx)
				row = append(row, name)
			}
		}
		// Then any keys of the first row that were not declared.
		for name := range headerRow {
			if _, found := founded[name]; !found {
				colIdx[name] = len(colIdx)
				row = append(row, name)
			}
		}
		c.colIdx = colIdx
		return row, false, nil
	}
	if c.i >= len(c.customContent) {
		return nil, true, nil
	}
	content := c.customContent[c.i]
	c.i++
	// Bug fix: size the row by the header width, not by len(content). A row
	// with fewer keys than the header previously caused an index-out-of-range
	// panic, because cell positions come from c.colIdx. Missing cells now
	// stay "".
	row = make([]string, len(c.colIdx))
	for k, v := range content {
		idx, found := c.colIdx[k]
		if !found {
			return nil, false, fmt.Errorf("[customContentContainer] column not found, name=%s", k)
		}
		row[idx] = v
	}
	return row, false, nil
}

View File

@@ -0,0 +1,97 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"bytes"
"context"
"testing"
"github.com/cloudwego/eino/components/document/parser"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
// TestParseTableCustomContent feeds an inline two-row "json maps" payload
// through parseJSONMaps with a full column schema and asserts the resulting
// per-document column metadata via assertSheet.
func TestParseTableCustomContent(t *testing.T) {
	ctx := context.Background()
	b := []byte(`[{"col_string_indexing":"hello","col_string":"asd","col_int":"1","col_number":"1","col_bool":"true","col_time":"2006-01-02 15:04:05"},{"col_string_indexing":"bye","col_string":"","col_int":"2","col_number":"2.0","col_bool":"false","col_time":""}]`)
	reader := bytes.NewReader(b)
	config := &contract.Config{
		FileExtension: contract.FileExtensionJsonMaps,
		ParsingStrategy: &contract.ParsingStrategy{
			HeaderLine:    0,
			DataStartLine: 1,
			RowsCount:     10,
			Columns: []*document.Column{
				{
					ID:       0,
					Name:     "col_string_indexing",
					Type:     document.TableColumnTypeString,
					Nullable: false,
					Sequence: 0,
				},
				{
					ID:       0,
					Name:     "col_string",
					Type:     document.TableColumnTypeString,
					Nullable: false,
					Sequence: 1,
				},
				{
					ID:       0,
					Name:     "col_int",
					Type:     document.TableColumnTypeInteger,
					Nullable: false,
					Sequence: 2,
				},
				{
					ID:       0,
					Name:     "col_number",
					Type:     document.TableColumnTypeNumber,
					Nullable: true,
					Sequence: 3,
				},
				{
					ID:       0,
					Name:     "col_bool",
					Type:     document.TableColumnTypeBoolean,
					Nullable: true,
					Sequence: 4,
				},
				{
					ID:       0,
					Name:     "col_time",
					Type:     document.TableColumnTypeTime,
					Nullable: true,
					Sequence: 5,
				},
			},
		},
	}
	pfn := parseJSONMaps(config)
	docs, err := pfn(ctx, reader, parser.WithExtraMeta(map[string]any{
		"document_id":  int64(123),
		"knowledge_id": int64(456),
	}))
	assert.NoError(t, err)
	for i, doc := range docs {
		assertSheet(t, i, doc)
	}
}

View File

@@ -0,0 +1,133 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"bytes"
"context"
"testing"
"github.com/cloudwego/eino/components/document/parser"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
// TestParseJSON parses an inline JSON array of QA records without a declared
// schema (columns are inferred from the first item's keys) and asserts the
// resulting documents via assertSheet.
func TestParseJSON(t *testing.T) {
	b := []byte(`[
	{
		"department": "心血管科",
		"title": "高血压患者能吃党参吗?",
		"question": "我有高血压这两天女婿来的时候给我拿了些党参泡水喝,您好高血压可以吃党参吗?",
		"answer": "高血压病人可以口服党参的。党参有降血脂,降血压的作用,可以彻底消除血液中的垃圾,从而对冠心病以及心血管疾病的患者都有一定的稳定预防工作作用,因此平时口服党参能远离三高的危害。另外党参除了益气养血,降低中枢神经作用,调整消化系统功能,健脾补肺的功能。感谢您的进行咨询,期望我的解释对你有所帮助。"
	},
	{
		"department": "消化科",
		"title": "哪家医院能治胃反流",
		"question": "烧心打隔咳嗽低烧以有4年多",
		"answer": "建议你用奥美拉唑同时,加用吗丁啉或莫沙必利或援生力维,另外还可以加用达喜片"
	}
]`)
	reader := bytes.NewReader(b)
	config := &contract.Config{
		FileExtension: contract.FileExtensionJSON,
		ParsingStrategy: &contract.ParsingStrategy{
			HeaderLine:    0,
			DataStartLine: 1,
			RowsCount:     2,
		},
		ChunkingStrategy: nil,
	}
	pfn := parseJSON(config)
	docs, err := pfn(context.Background(), reader, parser.WithExtraMeta(map[string]any{
		"document_id":  int64(123),
		"knowledge_id": int64(456),
	}))
	assert.NoError(t, err)
	for i, doc := range docs {
		assertSheet(t, i, doc)
	}
}
// TestParseJSONWithSchema parses the same QA payload as TestParseJSON, but
// with an explicit column schema (IDs 101-104), so values are validated
// against the declared column types instead of being inferred.
func TestParseJSONWithSchema(t *testing.T) {
	b := []byte(`[
	{
		"department": "心血管科",
		"title": "高血压患者能吃党参吗?",
		"question": "我有高血压这两天女婿来的时候给我拿了些党参泡水喝,您好高血压可以吃党参吗?",
		"answer": "高血压病人可以口服党参的。党参有降血脂,降血压的作用,可以彻底消除血液中的垃圾,从而对冠心病以及心血管疾病的患者都有一定的稳定预防工作作用,因此平时口服党参能远离三高的危害。另外党参除了益气养血,降低中枢神经作用,调整消化系统功能,健脾补肺的功能。感谢您的进行咨询,期望我的解释对你有所帮助。"
	},
	{
		"department": "消化科",
		"title": "哪家医院能治胃反流",
		"question": "烧心打隔咳嗽低烧以有4年多",
		"answer": "建议你用奥美拉唑同时,加用吗丁啉或莫沙必利或援生力维,另外还可以加用达喜片"
	}
]`)
	reader := bytes.NewReader(b)
	config := &contract.Config{
		FileExtension: contract.FileExtensionJSON,
		ParsingStrategy: &contract.ParsingStrategy{
			HeaderLine:    0,
			DataStartLine: 1,
			RowsCount:     2,
			Columns: []*document.Column{
				{
					ID:       101,
					Name:     "department",
					Type:     document.TableColumnTypeString,
					Nullable: false,
					Sequence: 0,
				},
				{
					ID:       102,
					Name:     "title",
					Type:     document.TableColumnTypeString,
					Nullable: false,
					Sequence: 1,
				},
				{
					ID:       103,
					Name:     "question",
					Type:     document.TableColumnTypeString,
					Nullable: false,
					Sequence: 2,
				},
				{
					ID:       104,
					Name:     "answer",
					Type:     document.TableColumnTypeString,
					Nullable: false,
					Sequence: 3,
				},
			},
		},
	}
	pfn := parseJSON(config)
	docs, err := pfn(context.Background(), reader, parser.WithExtraMeta(map[string]any{
		"document_id":  int64(123),
		"knowledge_id": int64(456),
	}))
	assert.NoError(t, err)
	for i, doc := range docs {
		assertSheet(t, i, doc)
	}
}

View File

@@ -0,0 +1,221 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
"github.com/yuin/goldmark"
"github.com/yuin/goldmark/ast"
"github.com/yuin/goldmark/text"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
// parseMarkdown builds a parseFn that walks a markdown AST and splits the
// text into chunks of at most cs.ChunkSize characters. When image extraction
// is enabled, each image with a valid URL is downloaded, re-uploaded via
// storage, and its stored URI (plus optional OCR text) is spliced into the
// chunk stream at its original position.
func parseMarkdown(config *contract.Config, storage storage.Storage, ocr ocr.OCR) parseFn {
	return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
		options := parser.GetCommonOptions(&parser.Options{}, opts...)
		mdParser := goldmark.DefaultParser()
		b, err := io.ReadAll(reader)
		if err != nil {
			return nil, err
		}
		node := mdParser.Parse(text.NewReader(b))
		cs := config.ChunkingStrategy
		ps := config.ParsingStrategy
		if cs.ChunkType != contract.ChunkTypeCustom && cs.ChunkType != contract.ChunkTypeDefault {
			return nil, fmt.Errorf("[parseMarkdown] chunk type not support, chunk type=%d", cs.ChunkType)
		}
		var (
			last       *schema.Document // chunk currently being filled
			emptySlice bool             // true while `last` holds no real content
		)
		// addSliceContent appends text to the current chunk and marks it
		// non-empty.
		addSliceContent := func(content string) {
			emptySlice = false
			last.Content += content
		}
		// newSlice starts a fresh chunk, optionally seeding it with the
		// overlap tail of the previous chunk. Note emptySlice is reset to
		// true afterwards, so a chunk containing only overlap still counts
		// as empty.
		newSlice := func(needOverlap bool) {
			last = &schema.Document{
				MetaData: map[string]any{},
			}
			for k, v := range options.ExtraMeta {
				last.MetaData[k] = v
			}
			if needOverlap && cs.Overlap > 0 && len(docs) > 0 {
				overlap := getOverlap([]rune(docs[len(docs)-1].Content), cs.Overlap, cs.ChunkSize)
				addSliceContent(string(overlap))
			}
			emptySlice = true
		}
		// pushSlice finalizes the current chunk (if it has content) and
		// begins the next one.
		pushSlice := func() {
			if !emptySlice && last.Content != "" {
				docs = append(docs, last)
				newSlice(true)
			}
		}
		// trim applies the configured URL/email removal and whitespace
		// normalization to a text fragment.
		trim := func(text string) string {
			if cs.TrimURLAndEmail {
				text = urlRegex.ReplaceAllString(text, "")
				text = emailRegex.ReplaceAllString(text, "")
			}
			if cs.TrimSpace {
				text = strings.TrimSpace(text)
				text = spaceRegex.ReplaceAllString(text, " ")
			}
			return text
		}
		// downloadImage fetches an image URL with a 5s timeout.
		downloadImage := func(ctx context.Context, url string) ([]byte, error) {
			client := &http.Client{Timeout: 5 * time.Second}
			req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
			if err != nil {
				return nil, fmt.Errorf("failed to create HTTP request: %w", err)
			}
			resp, err := client.Do(req)
			if err != nil {
				return nil, fmt.Errorf("failed to download image: %w", err)
			}
			defer resp.Body.Close()
			if resp.StatusCode != http.StatusOK {
				return nil, fmt.Errorf("failed to download image, status code: %d", resp.StatusCode)
			}
			data, err := io.ReadAll(resp.Body)
			if err != nil {
				return nil, fmt.Errorf("failed to read image content: %w", err)
			}
			return data, nil
		}
		// walker handles two node kinds: leaf text nodes are split by the
		// separator and packed into fixed-size chunks; image nodes are
		// downloaded, re-hosted and inlined.
		walker := func(n ast.Node, entering bool) (ast.WalkStatus, error) {
			if !entering {
				return ast.WalkContinue, nil
			}
			switch n.Kind() {
			case ast.KindText:
				if n.HasChildren() {
					break
				}
				textNode := n.(*ast.Text)
				plainText := trim(string(textNode.Segment.Value(b)))
				for _, part := range strings.Split(plainText, cs.Separator) {
					runes := []rune(part)
					// Fill the current chunk up to ChunkSize, spilling the
					// remainder into new chunks as needed.
					for partLength := int64(len(runes)); partLength > 0; partLength = int64(len(runes)) {
						pos := min(partLength, cs.ChunkSize-charCount(last.Content))
						chunk := runes[:pos]
						addSliceContent(string(chunk))
						runes = runes[pos:]
						if charCount(last.Content) >= cs.ChunkSize {
							pushSlice()
						}
					}
				}
			case ast.KindImage:
				if !ps.ExtractImage {
					break
				}
				imageNode := n.(*ast.Image)
				if ps.ExtractImage {
					imageURL := string(imageNode.Destination)
					if _, err = url.ParseRequestURI(imageURL); err == nil {
						// Use the URL's last dot-separated segment as the
						// file extension for the re-hosted object.
						sp := strings.Split(imageURL, ".")
						if len(sp) == 0 {
							return ast.WalkStop, fmt.Errorf("failed to extract image extension, url=%s", imageURL)
						}
						ext := sp[len(sp)-1]
						img, err := downloadImage(ctx, imageURL)
						if err != nil {
							return ast.WalkStop, fmt.Errorf("failed to download image: %w", err)
						}
						imgSrc, err := putImageObject(ctx, storage, ext, getCreatorIDFromExtraMeta(options.ExtraMeta), img)
						if err != nil {
							return ast.WalkStop, err
						}
						// Close the current text chunk so the image marker
						// starts a chunk of its own.
						if !emptySlice && last.Content != "" {
							pushSlice()
						} else {
							newSlice(false)
						}
						addSliceContent(fmt.Sprintf("\n%s\n", imgSrc))
						if ps.ImageOCR && ocr != nil {
							texts, err := ocr.FromBase64(ctx, base64.StdEncoding.EncodeToString(img))
							if err != nil {
								return ast.WalkStop, fmt.Errorf("failed to perform OCR on image: %w", err)
							}
							addSliceContent(strings.Join(texts, "\n"))
						}
						if charCount(last.Content) >= cs.ChunkSize {
							pushSlice()
						}
					} else {
						logs.CtxInfof(ctx, "[parseMarkdown] not a valid image url, skip, got=%s", imageURL)
					}
				}
			}
			return ast.WalkContinue, nil
		}
		newSlice(false)
		if err = ast.Walk(node, walker); err != nil {
			return nil, err
		}
		// Flush the trailing partial chunk.
		if !emptySlice {
			pushSlice()
		}
		return docs, nil
	}
}

View File

@@ -0,0 +1,75 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"fmt"
"os"
"testing"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
ms "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/storage"
)
// TestParseMarkdown chunks a markdown fixture with a mocked storage backend
// (image uploads always succeed) and checks each produced chunk's content and
// metadata via assertDoc.
func TestParseMarkdown(t *testing.T) {
	ctx := context.Background()
	ctrl := gomock.NewController(t)
	mockStorage := ms.NewMockStorage(ctrl)
	mockStorage.EXPECT().PutObject(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
	pfn := parseMarkdown(&contract.Config{
		FileExtension: contract.FileExtensionMarkdown,
		ParsingStrategy: &contract.ParsingStrategy{
			ExtractImage: true,
			ExtractTable: true,
			ImageOCR:     true,
		},
		ChunkingStrategy: &contract.ChunkingStrategy{
			ChunkType:       contract.ChunkTypeCustom,
			ChunkSize:       800,
			Separator:       "\n",
			Overlap:         10,
			TrimSpace:       true,
			TrimURLAndEmail: true,
		},
	}, mockStorage, nil) // ocr is nil, so the ImageOCR flag is effectively a no-op
	f, err := os.Open("test_data/test_markdown.md")
	assert.NoError(t, err)
	docs, err := pfn(ctx, f, parser.WithExtraMeta(map[string]any{
		"document_id":  int64(123),
		"knowledge_id": int64(456),
	}))
	assert.NoError(t, err)
	for _, doc := range docs {
		assertDoc(t, doc)
	}
}
// assertDoc verifies a parsed markdown chunk has non-empty content and the
// extra metadata injected via parser options, printing the chunk for manual
// inspection.
func assertDoc(t *testing.T, doc *schema.Document) {
	assert.NotZero(t, doc.Content)
	fmt.Println(doc.Content)
	meta := doc.MetaData
	assert.NotNil(t, meta)
	assert.Equal(t, int64(123), meta["document_id"].(int64))
	assert.Equal(t, int64(456), meta["knowledge_id"].(int64))
}

View File

@@ -0,0 +1,152 @@
import io
import json
import os
import sys
import base64
from typing import Literal
import pdfplumber
from PIL import Image, ImageChops
from pdfminer.pdfcolor import (
LITERAL_DEVICE_CMYK,
)
from pdfminer.pdftypes import (
LITERALS_DCT_DECODE,
LITERALS_FLATE_DECODE,
)
def bbox_overlap(bbox1, bbox2):
    """Return the intersection area of two boxes divided by the smaller box's area.

    Boxes are (x0, y0, x1, y1) tuples. Returns 0 when the boxes do not
    intersect or when either box is degenerate (zero area).
    """
    ax0, ay0, ax1, ay1 = bbox1
    bx0, by0, bx1, by1 = bbox2
    inter_w = max(0, min(ax1, bx1) - max(ax0, bx0))
    inter_h = max(0, min(ay1, by1) - max(ay0, by0))
    area_a = (ax1 - ax0) * (ay1 - ay0)
    area_b = (bx1 - bx0) * (by1 - by0)
    if area_a == 0 or area_b == 0:
        return 0
    return (inter_w * inter_h) / min(area_a, area_b)
def is_structured_table(table):
    """Return True when the table has at least 2 rows and a widest row of
    at least 2 cells, i.e. it looks like real tabular data rather than a
    stray line pdfplumber mis-detected as a table."""
    if not table:
        return False
    widest = max(len(row) for row in table)
    return len(table) >= 2 and widest >= 2
def extract_pdf_content(pdf_data: bytes, extract_images, extract_tables: bool, filter_pages=None):
    """Extract text, images and tables from a PDF, ordered by position.

    Args:
        pdf_data: raw bytes of the PDF file.
        extract_images: when truthy, decode embedded images to base64 PNG
            (or raw stream bytes when no conversion is needed).
        extract_tables: when True, run pdfplumber table detection per page.
        filter_pages: optional collection of 1-based page numbers to skip
            (was annotated as `[]`, which is a meaningless literal annotation;
            a None default is now provided for direct calls).

    Returns:
        A list of dicts with keys 'type' ('text' | 'image' | 'table'),
        'page' (1-based), 'bbox', plus 'content' (text / base64 image) or
        'table' (cell matrix), sorted by page then top-left position.
        Unstructured tables and text/image items that overlap an already-kept
        text block by more than 80% are dropped.
    """
    with pdfplumber.open(io.BytesIO(pdf_data)) as pdf:
        content = []
        for page_num, page in enumerate(pdf.pages):
            if filter_pages is not None and page_num + 1 in filter_pages:
                print(f"Skip page {page_num + 1}...")
                continue
            print(f"Processing page {page_num + 1}...")
            text = page.extract_text(x_tolerance=2)
            content.append({
                'type': 'text',
                'content': text,
                'page': page_num + 1,
                'bbox': page.bbox
            })
            if extract_images:
                for img in page.images:
                    try:
                        filters = img['stream'].get_filters()
                        data = img['stream'].get_data()
                        buffered = io.BytesIO()
                        if filters[-1][0] in LITERALS_DCT_DECODE:
                            if LITERAL_DEVICE_CMYK in img['colorspace']:
                                # JPEG stored in CMYK: invert and convert to RGB
                                # before re-encoding as PNG.
                                i = Image.open(io.BytesIO(data))
                                i = ImageChops.invert(i)
                                i = i.convert("RGB")
                                i.save(buffered, format="PNG")
                            else:
                                # Plain JPEG stream can be passed through as-is.
                                buffered.write(data)
                        elif len(filters) == 1 and filters[0][0] in LITERALS_FLATE_DECODE:
                            width, height = img['srcsize']
                            # Infer channel count from the decoded payload size
                            # (reuse `data` instead of re-fetching the stream).
                            channels = len(data) / width / height / (img['bits'] / 8)
                            if img['bits'] == 1:
                                mode = "1"
                            elif img['bits'] == 8 and channels == 1:
                                mode = "L"
                            elif img['bits'] == 8 and channels == 3:
                                mode = "RGB"
                            elif img['bits'] == 8 and channels == 4:
                                mode = "CMYK"
                            else:
                                # Previously `mode` was left unbound here and the
                                # frombytes call raised NameError; raise a clear
                                # error instead (caught below, image is skipped).
                                raise ValueError(
                                    f"unsupported raw image: bits={img['bits']}, channels={channels}")
                            i = Image.frombytes(mode, img['srcsize'], data, "raw")
                            i.save(buffered, format="PNG")
                        else:
                            buffered.write(data)
                        content.append({
                            'type': 'image',
                            'content': base64.b64encode(buffered.getvalue()).decode('utf-8'),
                            'page': page_num + 1,
                            'bbox': (img['x0'], img['top'], img['x1'], img['bottom'])
                        })
                    except Exception as err:
                        print(f"Skipping an unsupported image on page {page_num + 1}, error message: {err}")
            if extract_tables:
                tables = page.extract_tables()
                for table in tables:
                    content.append({
                        'type': 'table',
                        'table': table,
                        'page': page_num + 1,
                        'bbox': page.bbox
                    })
        # Reading order: page first, then top coordinate, then left coordinate.
        content.sort(key=lambda x: (x['page'], x['bbox'][1], x['bbox'][0]))
        filtered_content = []
        for item in content:
            if item['type'] == 'table':
                # Keep only tables that look structured; never overlap-filter them.
                if is_structured_table(item['table']):
                    filtered_content.append(item)
                continue
            # Drop text/image items that mostly duplicate an already-kept text block.
            overlaps = any(
                prev['type'] == 'text' and bbox_overlap(item['bbox'], prev['bbox']) > 0.8
                for prev in filtered_content
            )
            if not overlaps:
                filtered_content.append(item)
        return filtered_content
if __name__ == "__main__":
    # Protocol with the parent process (see the Go side, parseByPython):
    #   stdin  - raw PDF bytes
    #   fd 4   - JSON request: {"extract_images", "extract_tables", "filter_pages"}
    #   fd 3   - JSON response: {"content": [...]} on success, {"error": "..."} on failure
    #   stdout - free-form progress logs, forwarded to the parent's stdout
    w = os.fdopen(3, "wb", )
    r = os.fdopen(4, "rb", )
    pdf_data = sys.stdin.buffer.read()
    print(f"Read {len(pdf_data)} bytes of PDF data")
    try:
        req = json.load(r)
        ei, et, fp = req['extract_images'], req['extract_tables'], req['filter_pages']
        extracted_content = extract_pdf_content(pdf_data, ei, et, fp)
        print(f"Extracted {len(extracted_content)} items")
        result = json.dumps({"content": extracted_content}, ensure_ascii=False)
        w.write(str.encode(result))
        w.flush()
        # Closing fd 3 signals EOF so the parent's JSON decoder can finish.
        w.close()
        print("Pdf parse done")
    except Exception as e:
        # Report the failure through the same pipe so the parent can surface it.
        print("Pdf parse error", e)
        w.write(str.encode(json.dumps({"error": str(e)})))
        w.flush()
        w.close()

View File

@@ -0,0 +1,49 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"fmt"
"io"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
// parseText builds a parseFn that reads the whole input as plain text and
// chunks it according to the configured chunking strategy. Only the custom
// and default chunk types are supported.
func parseText(config *contract.Config) parseFn {
	return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
		raw, err := io.ReadAll(reader)
		if err != nil {
			return nil, err
		}
		chunkType := config.ChunkingStrategy.ChunkType
		if chunkType != contract.ChunkTypeCustom && chunkType != contract.ChunkTypeDefault {
			return nil, fmt.Errorf("[parseText] chunk type not support, type=%d", chunkType)
		}
		docs, err = chunkCustom(ctx, string(raw), config, opts...)
		if err != nil {
			return nil, err
		}
		return docs, nil
	}
}

View File

@@ -0,0 +1,78 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"io"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
"github.com/xuri/excelize/v2"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
// parseXLSX builds a parseFn that streams one sheet of an XLSX workbook
// through parseByRowIterator using the configured table parsing strategy.
func parseXLSX(config *contract.Config) parseFn {
	return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
		f, err := excelize.OpenReader(reader)
		if err != nil {
			return nil, err
		}
		// Default to the first sheet unless the parsing strategy pins one.
		sheetID := 0
		if config.ParsingStrategy.SheetID != nil {
			sheetID = *config.ParsingStrategy.SheetID
		}
		rows, err := f.Rows(f.GetSheetName(sheetID))
		if err != nil {
			return nil, err
		}
		// NOTE(review): neither f nor rows is ever closed; the iterator keeps
		// streaming from rows inside parseByRowIterator — confirm whether
		// excelize requires a Close once iteration finishes.
		iter := &xlsxIterator{rows, 0}
		return parseByRowIterator(iter, config, opts...)
	}
}
// xlsxIterator adapts excelize's streaming row reader to the row-iterator
// contract used by parseByRowIterator, normalizing every row to the width
// of the first row seen.
type xlsxIterator struct {
	rows         *excelize.Rows
	firstRowSize int
}

// NextRow yields the next sheet row. Short rows are padded with empty cells
// and long rows truncated so each row matches the first row's cell count.
func (x *xlsxIterator) NextRow() (row []string, end bool, err error) {
	if !x.rows.Next() {
		return nil, true, nil
	}
	cells, err := x.rows.Columns()
	if err != nil {
		return nil, false, err
	}
	want := x.firstRowSize
	switch {
	case want == 0:
		// The first row establishes the expected width.
		x.firstRowSize = len(cells)
	case len(cells) < want:
		cells = append(cells, make([]string, want-len(cells))...)
	case len(cells) > want:
		cells = cells[:want]
	}
	return cells, false, nil
}

View File

@@ -0,0 +1,171 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"bytes"
"context"
"os"
"testing"
"github.com/cloudwego/eino/components/document/parser"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
// TestParseXLSX parses the checked-in workbook with a column schema that
// matches the sheet's native cell types and validates every produced row
// document via assertSheet.
func TestParseXLSX(t *testing.T) {
	ctx := context.Background()
	b, err := os.ReadFile("./test_data/test_xlsx.xlsx")
	assert.NoError(t, err)
	reader := bytes.NewReader(b)
	config := &contract.Config{
		FileExtension: contract.FileExtensionXLSX,
		ParsingStrategy: &contract.ParsingStrategy{
			HeaderLine:    0,  // first row holds the column names
			DataStartLine: 1,  // data begins right after the header
			RowsCount:     10,
			Columns: []*document.Column{
				{
					ID:       0,
					Name:     "col_string_indexing",
					Type:     document.TableColumnTypeString,
					Nullable: false,
					Sequence: 0,
				},
				{
					ID:       0,
					Name:     "col_string",
					Type:     document.TableColumnTypeString,
					Nullable: true,
					Sequence: 1,
				},
				{
					ID:       0,
					Name:     "col_int",
					Type:     document.TableColumnTypeInteger,
					Nullable: false,
					Sequence: 2,
				},
				{
					ID:       0,
					Name:     "col_number",
					Type:     document.TableColumnTypeNumber,
					Nullable: true,
					Sequence: 3,
				},
				{
					ID:       0,
					Name:     "col_bool",
					Type:     document.TableColumnTypeBoolean,
					Nullable: true,
					Sequence: 4,
				},
				{
					ID:       0,
					Name:     "col_time",
					Type:     document.TableColumnTypeTime,
					Nullable: true,
					Sequence: 5,
				},
			},
		},
		ChunkingStrategy: nil, // table parsing does not chunk
	}
	pfn := parseXLSX(config)
	docs, err := pfn(ctx, reader, parser.WithExtraMeta(map[string]any{
		"document_id":  int64(123),
		"knowledge_id": int64(456),
	}))
	assert.NoError(t, err)
	for i, doc := range docs {
		assertSheet(t, i, doc)
	}
}
// TestParseXLSXConvertColumnType parses the same workbook but declares column
// types that differ from the sheet's native cell types, with
// IgnoreColumnTypeErr enabled, to exercise the lenient type-conversion path.
func TestParseXLSXConvertColumnType(t *testing.T) {
	ctx := context.Background()
	b, err := os.ReadFile("./test_data/test_xlsx.xlsx")
	assert.NoError(t, err)
	reader := bytes.NewReader(b)
	config := &contract.Config{
		FileExtension: contract.FileExtensionXLSX,
		ParsingStrategy: &contract.ParsingStrategy{
			HeaderLine:          0,
			DataStartLine:       1,
			RowsCount:           10,
			IgnoreColumnTypeErr: true, // unconvertible cells become null instead of failing
			Columns: []*document.Column{
				{
					ID:       0,
					Name:     "col_string_indexing",
					Type:     document.TableColumnTypeString,
					Nullable: false,
					Sequence: 0,
				},
				{
					ID:       0,
					Name:     "col_string",
					Type:     document.TableColumnTypeInteger, // string -> int: null
					Nullable: true,
					Sequence: 1,
				},
				{
					ID:       0,
					Name:     "col_int",
					Type:     document.TableColumnTypeString, // int -> string: strconv
					Nullable: false,
					Sequence: 2,
				},
				{
					ID:       0,
					Name:     "col_number",
					Type:     document.TableColumnTypeString, // float -> string: strconv
					Nullable: true,
					Sequence: 3,
				},
				//{
				//	ID:       0,
				//	Name:     "col_bool",
				//	Type:     document.TableColumnTypeBoolean, // trim
				//	Nullable: true,
				//	Sequence: 4,
				//},
				//{
				//	ID:       0,
				//	Name:     "col_time",
				//	Type:     document.TableColumnTypeTime, // trim
				//	Nullable: true,
				//	Sequence: 5,
				//},
			},
		},
		ChunkingStrategy: nil,
	}
	pfn := parseXLSX(config)
	docs, err := pfn(ctx, reader, parser.WithExtraMeta(map[string]any{
		"document_id":  int64(123),
		"knowledge_id": int64(456),
	}))
	assert.NoError(t, err)
	for i, doc := range docs {
		assertSheet(t, i, doc)
	}
}

View File

@@ -0,0 +1,35 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"io"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
)
// p adapts a bare parseFn to the eino parser.Parser interface by embedding it.
type p struct {
	parseFn
}

// Parse delegates directly to the wrapped parseFn.
func (p p) Parse(ctx context.Context, reader io.Reader, opts ...parser.Option) ([]*schema.Document, error) {
	return p.parseFn(ctx, reader, opts...)
}

// parseFn is the uniform signature implemented by every builtin format parser.
type parseFn func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error)

View File

@@ -0,0 +1,270 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"os"
"os/exec"
"strings"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
)
// Item types emitted by the Python PDF helper (see parseByPython).
const (
	contentTypeText  = "text"
	contentTypeImage = "image"
	contentTypeTable = "table"
)

// pyParseRequest is the JSON request sent to the Python helper on its fd 4.
type pyParseRequest struct {
	ExtractImages bool  `json:"extract_images"`
	ExtractTables bool  `json:"extract_tables"`
	FilterPages   []int `json:"filter_pages"` // 1-based page numbers to skip; may be nil
}

// pyParseResult is the JSON response read back from the helper on its fd 3.
type pyParseResult struct {
	Error   string            `json:"error"`   // non-empty when the helper failed
	Content []*pyParseContent `json:"content"` // extracted items in page/position order
}

// pyParseContent is a single extracted item: text, a base64 image, or a table.
type pyParseContent struct {
	Type    string     `json:"type"`    // one of contentTypeText/Image/Table
	Content string     `json:"content"` // text body, or base64-encoded image bytes
	Table   [][]string `json:"table"`   // table cells; set only when Type is "table"
	Page    int        `json:"page"`    // 1-based source page number
}
// pyPDFTableIterator walks an in-memory table (as returned by the Python PDF
// helper) row by row, satisfying the row-iterator contract.
type pyPDFTableIterator struct {
	i    int        // index of the next row to yield
	rows [][]string // table cells, one slice per row
}

// NextRow returns the next table row, or end=true once all rows are consumed.
func (p *pyPDFTableIterator) NextRow() (row []string, end bool, err error) {
	if p.i < len(p.rows) {
		next := p.rows[p.i]
		p.i++
		return next, false, nil
	}
	return nil, true, nil
}
// parseByPython shells out to a Python helper script (scriptPath, executed
// with pyPath) that parses a PDF via pdfplumber, then converts the returned
// items into schema.Documents: text is chunked, images are uploaded to
// storage (optionally OCR'd), and tables are re-rendered as <table> chunks.
//
// Wire protocol with the child process:
//   - stdin:                      raw PDF bytes, streamed from reader
//   - ExtraFiles[0] (child fd 3): child writes the JSON result here
//   - ExtraFiles[1] (child fd 4): child reads the JSON request here
//   - stdout:                     child progress logs, forwarded to our stdout
func parseByPython(config *contract.Config, storage storage.Storage, ocr ocr.OCR, pyPath, scriptPath string) parseFn {
	return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
		// pr/pw: request pipe — we write the JSON request, the child reads it on fd 4.
		pr, pw, err := os.Pipe()
		if err != nil {
			return nil, fmt.Errorf("[parseByPython] create rpipe failed, %w", err)
		}
		// r/w: result pipe — the child writes the JSON result on fd 3, we read it.
		r, w, err := os.Pipe()
		if err != nil {
			return nil, fmt.Errorf("[parseByPython] create pipe failed: %w", err)
		}
		options := parser.GetCommonOptions(&parser.Options{ExtraMeta: map[string]any{}}, opts...)
		reqb, err := json.Marshal(pyParseRequest{
			ExtractImages: config.ParsingStrategy.ExtractImage,
			ExtractTables: config.ParsingStrategy.ExtractTable,
			FilterPages:   config.ParsingStrategy.FilterPages,
		})
		if err != nil {
			return nil, fmt.Errorf("[parseByPython] create parse request failed, %w", err)
		}
		// NOTE(review): the request is written before the child starts, which
		// relies on it fitting in the pipe buffer — confirm requests stay small.
		if _, err = pw.Write(reqb); err != nil {
			return nil, fmt.Errorf("[parseByPython] write parse request bytes failed, %w", err)
		}
		if err = pw.Close(); err != nil {
			return nil, fmt.Errorf("[parseByPython] close write request pipe failed, %w", err)
		}
		cmd := exec.Command(pyPath, scriptPath)
		cmd.Stdin = reader
		cmd.Stdout = os.Stdout
		// ExtraFiles map to child fds starting at 3: fd 3 = w (result out), fd 4 = pr (request in).
		cmd.ExtraFiles = []*os.File{w, pr}
		if err = cmd.Start(); err != nil {
			return nil, fmt.Errorf("[parseByPython] failed to start Python script: %w", err)
		}
		// Close our copy of the result write end so the decoder below sees EOF
		// once the child finishes writing.
		if err = w.Close(); err != nil {
			return nil, fmt.Errorf("[parseByPython] failed to close write pipe: %w", err)
		}
		result := &pyParseResult{}
		if err = json.NewDecoder(r).Decode(result); err != nil {
			return nil, fmt.Errorf("[parseByPython] failed to decode result: %w", err)
		}
		if err = cmd.Wait(); err != nil {
			return nil, fmt.Errorf("[parseByPython] cmd wait err: %w", err)
		}
		if result.Error != "" {
			return nil, fmt.Errorf("[parseByPython] python execution failed: %s", result.Error)
		}
		for i, item := range result.Content {
			switch item.Type {
			case contentTypeText:
				partDocs, err := chunkCustom(ctx, item.Content, config, opts...)
				if err != nil {
					return nil, fmt.Errorf("[parseByPython] chunk text failed, %w", err)
				}
				docs = append(docs, partDocs...)
			case contentTypeImage:
				if !config.ParsingStrategy.ExtractImage {
					continue
				}
				image, err := base64.StdEncoding.DecodeString(item.Content)
				if err != nil {
					return nil, fmt.Errorf("[parseByPython] decode image failed, %w", err)
				}
				imgSrc, err := putImageObject(ctx, storage, "png", getCreatorIDFromExtraMeta(options.ExtraMeta), image)
				if err != nil {
					return nil, err
				}
				label := fmt.Sprintf("\n%s", imgSrc)
				if config.ParsingStrategy.ImageOCR && ocr != nil {
					texts, err := ocr.FromBase64(ctx, item.Content)
					if err != nil {
						return nil, fmt.Errorf("[parseByPython] FromBase64 failed, %w", err)
					}
					label += strings.Join(texts, "\n")
				}
				// If no text item follows, emit the image label as its own doc;
				// otherwise prepend it to the next text item so it gets chunked
				// together with the surrounding text.
				if i == len(result.Content)-1 || result.Content[i+1].Type != "text" {
					doc := &schema.Document{
						Content:  label,
						MetaData: map[string]any{},
					}
					for k, v := range options.ExtraMeta {
						doc.MetaData[k] = v
					}
					docs = append(docs, doc)
				} else {
					// TODO: slight issue here — the img label may be truncated by a short chunk size
					result.Content[i+1].Content = label + result.Content[i+1].Content
				}
			case contentTypeTable:
				if !config.ParsingStrategy.ExtractTable {
					continue
				}
				// Feed the extracted cells through the generic table-row parser
				// (first row treated as the header), then re-render as HTML.
				iterator := &pyPDFTableIterator{i: 0, rows: item.Table}
				rawTableDocs, err := parseByRowIterator(iterator, &contract.Config{
					FileExtension: contract.FileExtensionCSV,
					ParsingStrategy: &contract.ParsingStrategy{
						HeaderLine:    0,
						DataStartLine: 1,
						RowsCount:     0,
					},
					ChunkingStrategy: config.ChunkingStrategy,
				}, opts...)
				if err != nil {
					return nil, fmt.Errorf("[parseByPython] parse table failed, %w", err)
				}
				fmtTableDocs, err := formatTablesInDocument(rawTableDocs)
				if err != nil {
					return nil, fmt.Errorf("[parseByPython] format table failed, %w", err)
				}
				docs = append(docs, fmtTableDocs...)
			default:
				return nil, fmt.Errorf("[parseByPython] invalid content type: %s", item.Type)
			}
		}
		return docs, nil
	}
}
// formatTablesInDocument flattens row-oriented table documents (one
// column-data row per input doc) into HTML <table>...</table> chunks.
// The header row comes from the first document's column metadata; a new
// chunk is started whenever the accumulated HTML would reach maxSize bytes.
// Column-data metadata is stripped from the emitted documents while all other
// metadata of the first input document is copied onto each chunk.
func formatTablesInDocument(input []*schema.Document) (output []*schema.Document, err error) {
	const (
		maxSize              = 65535 // per-chunk size cap — presumably a storage column limit; TODO confirm
		tableStart, tableEnd = "<table>", "</table>"
	)
	var (
		buffer   strings.Builder
		firstDoc *schema.Document // metadata template; set on the first input doc
	)
	endSize := len(tableEnd)
	buffer.WriteString(tableStart)
	// push finalizes the current buffer into an output document (appending the
	// closing tag) and resets the buffer with a fresh opening tag.
	push := func() {
		newDoc := &schema.Document{
			Content:  buffer.String() + tableEnd,
			MetaData: map[string]any{},
		}
		for k, v := range firstDoc.MetaData {
			if k == document.MetaDataKeyColumnData {
				continue
			}
			newDoc.MetaData[k] = v
		}
		output = append(output, newDoc)
		buffer.Reset()
		buffer.WriteString(tableStart)
	}
	// write appends one <tr> row and flushes the chunk if it would overflow.
	write := func(contents []string) {
		row := fmt.Sprintf("<tr><td>%s</td></tr>", strings.Join(contents, "</td><td>"))
		buffer.WriteString(row)
		if buffer.Len()+endSize >= maxSize {
			push()
		}
	}
	for i := range input {
		doc := input[i]
		if i == 0 {
			// Emit the header row from the first document's column schema.
			firstDoc = doc
			cols, err := document.GetDocumentColumns(doc)
			if err != nil {
				return nil, fmt.Errorf("[formatTablesInDocument] invalid table columns, %w", err)
			}
			values := make([]string, 0, len(cols))
			for _, col := range cols {
				values = append(values, col.Name)
			}
			write(values)
		}
		data, err := document.GetDocumentColumnData(doc)
		if err != nil {
			return nil, fmt.Errorf("[formatTablesInDocument] invalid table data, %w", err)
		}
		values := make([]string, 0, len(data))
		for _, col := range data {
			values = append(values, col.GetNullableStringValue())
		}
		write(values)
	}
	// Flush any trailing partial chunk (skip when the buffer holds only the
	// freshly-reset opening tag).
	if buffer.String() != tableStart {
		push()
	}
	return
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 65 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 38 KiB

View File

@@ -0,0 +1,3 @@
col_string_indexing,col_string,col_int,col_number,col_bool,col_time
hello,asd,1,1.0,TRUE,2006-01-02 15:04:02
bye,,2,2.0,TRUE,
1 col_string_indexing col_string col_int col_number col_bool col_time
2 hello asd 1 1.0 TRUE 2006-01-02 15:04:02
3 bye 2 2.0 TRUE

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1 @@
col_string_indexing,col_string,col_int,col_number,col_bool,col_time
1 col_string_indexing col_string col_int col_number col_bool col_time

View File

@@ -0,0 +1,272 @@
# 1. 欢迎使用 Cmd Markdown 编辑阅读器
<!-- TOC -->
- [1. 欢迎使用 Cmd Markdown 编辑阅读器](#1-欢迎使用-cmd-markdown-编辑阅读器)
- [1.1. markdown扩展需求](#11-markdown扩展需求)
- [1.1.1. 一、各种流程图](#111-一各种流程图)
- [1.1.2. [Windows/Mac/Linux 全平台客户端](https://www.zybuluo.com/cmd/)](#112-windowsmaclinux-全平台客户端httpswwwzybuluocomcmd)
- [1.2. 什么是 Markdown](#12-什么是-markdown)
- [1.2.1. 制作一份待办事宜 [Todo 列表](https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown#13-待办事宜-todo-列表)](#121-制作一份待办事宜-todo-列表httpswwwzybuluocommdeditorurlhttpswwwzybuluocomstaticeditormd-helpmarkdown13-待办事宜-todo-列表)
- [1.2.2. 书写一个质能守恒公式[^LaTeX]](#122-书写一个质能守恒公式^latex)
- [1.2.3. 高亮一段代码[^code]](#123-高亮一段代码^code)
- [1.2.4. 高效绘制 [流程图](https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown#7-流程图)](#124-高效绘制-流程图httpswwwzybuluocommdeditorurlhttpswwwzybuluocomstaticeditormd-helpmarkdown7-流程图)
- [1.2.5. 高效绘制 [序列图](https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown#8-序列图)](#125-高效绘制-序列图httpswwwzybuluocommdeditorurlhttpswwwzybuluocomstaticeditormd-helpmarkdown8-序列图)
- [1.2.6. 绘制表格](#126-绘制表格)
- [1.2.7. 更详细语法说明](#127-更详细语法说明)
- [1.3. 什么是 Cmd Markdown](#13-什么是-cmd-markdown)
- [1.3.1. 实时同步预览](#131-实时同步预览)
- [1.3.2. 编辑工具栏](#132-编辑工具栏)
- [1.3.3. 编辑模式](#133-编辑模式)
- [1.3.4. 实时的云端文稿](#134-实时的云端文稿)
- [1.3.5. 离线模式](#135-离线模式)
- [1.3.6. 管理工具栏](#136-管理工具栏)
- [1.3.7. 阅读工具栏](#137-阅读工具栏)
- [1.3.8. 阅读模式](#138-阅读模式)
- [1.3.9. 标签、分类和搜索](#139-标签分类和搜索)
- [1.3.10. 文稿发布和分享](#1310-文稿发布和分享)
<!-- /TOC -->
[ ] dddd
[x] xxxx
第一行
第二行
------
> 一个快速笔记工具,可生成网页快速分享。
## 1.1. markdown扩展需求
1. 目录
2. 表情
3. 粘贴截图
4. 流程图、时序图
5. 数学公式
6. 标签
7. 简单动画
### 1.1.1. 一、各种流程图
1. 时序图
```seq
Alice->Bob: Hello Bob, how are you?
Note right of Bob: Bob thinks
Bob-->Alice: I am good thanks!
```
2. 流程图
```flow
st=>start: Start
op=>operation: Your Operation
cond=>condition: Yes or No?
e=>end
st->op->cond
cond(yes)->e
cond(no)->op
```
3. 甘特图
```gantt
title 项目开发流程
section 项目确定
需求分析 :a1, 2016-06-22, 3d
可行性报告 :after a1, 5d
概念验证 : 5d
section 项目实施
概要设计 :2016-07-05, 5d
详细设计 :2016-07-08, 10d
编码 :2016-07-15, 10d
测试 :2016-07-22, 5d
section 发布验收
发布: 2d
验收: 3d
```
4. Mermaid 流程图
```graphLR
A[Hard edge] -->|Link text| B(Round edge)
B --> C{Decision}
C -->|One| D[Result one]
C -->|Two| E[Result two]
```
5. Mermaid 序列图
```sequence
Alice->John: Hello John, how are you?
loop every minute
John-->Alice: Great!
end
```
我们理解您需要更便捷更高效的工具记录思想,整理笔记、知识,并将其中承载的价值传播给他人,**Cmd Markdown** 是我们给出的答案 —— 我们为记录思想和分享知识提供更专业的工具。 您可以使用 Cmd Markdown
> * 整理知识,学习笔记
> * 发布日记,杂文,所见所想
> * 撰写发布技术文稿(代码支持)
> * 撰写发布学术论文LaTeX 公式支持)
![cmd-markdown-logo](logo.png)
除了您现在看到的这个 Cmd Markdown 在线版本,您还可以前往以下网址下载:
### 1.1.2. [Windows/Mac/Linux 全平台客户端](https://www.zybuluo.com/cmd/)
> 请保留此份 Cmd Markdown 的欢迎稿兼使用说明,如需撰写新稿件,点击顶部工具栏右侧的 <i class="icon-file"></i> **新文稿** 或者使用快捷键 `Ctrl+Alt+N`。
------
## 1.2. 什么是 Markdown
Markdown 是一种方便记忆、书写的纯文本标记语言,用户可以使用这些标记符号以最小的输入代价生成极富表现力的文档:譬如您正在阅读的这份文档。它使用简单的符号标记不同的标题,分割不同的段落,**粗体** 或者 *斜体* 某些文字,更棒的是,它还可以
### 1.2.1. 制作一份待办事宜 [Todo 列表](https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown#13-待办事宜-todo-列表)
- [ ] 支持以 PDF 格式导出文稿
- [ ] 改进 Cmd 渲染算法,使用局部渲染技术提高渲染效率
- [x] 新增 Todo 列表功能
- [x] 修复 LaTex 公式渲染问题
- [x] 新增 LaTex 公式编号功能
### 1.2.2. 书写一个质能守恒公式[^LaTeX]
$$E=mc^2$$
### 1.2.3. 高亮一段代码[^code]
```python
@requires_authorization
class SomeClass:
pass
if __name__ == '__main__':
# A comment
print 'hello world'
```
### 1.2.4. 高效绘制 [流程图](https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown#7-流程图)
```flow
st=>start: Start
op=>operation: Your Operation
cond=>condition: Yes or No?
e=>end
st->op->cond
cond(yes)->e
cond(no)->op
```
### 1.2.5. 高效绘制 [序列图](https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown#8-序列图)
```seq
Alice->Bob: Hello Bob, how are you?
Note right of Bob: Bob thinks
Bob-->Alice: I am good thanks!
```
### 1.2.6. 绘制表格
| 项目 | 价格 | 数量 |
| -------- | -----: | :----: |
| 计算机 | \$1600 | 5 |
| 手机 | \$12 | 12 |
| 管线 | \$1 | 234 |
### 1.2.7. 更详细语法说明
想要查看更详细的语法说明,可以参考我们准备的 [Cmd Markdown 简明语法手册][1],进阶用户可以参考 [Cmd Markdown 高阶语法手册][2] 了解更多高级功能。
总而言之,不同于其它 *所见即所得* 的编辑器:你只需使用键盘专注于书写文本内容,就可以生成印刷级的排版格式,省却在键盘和工具栏之间来回切换,调整内容和格式的麻烦。**Markdown 在流畅的书写和印刷级的阅读体验之间找到了平衡。** 目前它已经成为世界上最大的技术分享网站 GitHub 和 技术问答网站 StackOverFlow 的御用书写格式。
---
## 1.3. 什么是 Cmd Markdown
您可以使用很多工具书写 Markdown但是 Cmd Markdown 是这个星球上我们已知的、最好的 Markdown 工具——没有之一 :)因为深信文字的力量,所以我们和你一样,对流畅书写,分享思想和知识,以及阅读体验有极致的追求,我们把对于这些诉求的回应整合在 Cmd Markdown并且一次两次三次乃至无数次地提升这个工具的体验最终将它演化成一个 **编辑/发布/阅读** Markdown 的在线平台——您可以在任何地方,任何系统/设备上管理这里的文字。
### 1.3.1. 实时同步预览
我们将 Cmd Markdown 的主界面一分为二,左边为**编辑区**,右边为**预览区**,在编辑区的操作会实时地渲染到预览区方便查看最终的版面效果,并且如果你在其中一个区拖动滚动条,我们有一个巧妙的算法把另一个区的滚动条同步到等价的位置,超酷!
### 1.3.2. 编辑工具栏
也许您还是一个 Markdown 语法的新手,在您完全熟悉它之前,我们在 **编辑区** 的顶部放置了一个如下图所示的工具栏,您可以使用鼠标在工具栏上调整格式,不过我们仍旧鼓励你使用键盘标记格式,提高书写的流畅度。
![tool-editor](toolbar-editor.png)
### 1.3.3. 编辑模式
完全心无旁骛的方式编辑文字:点击 **编辑工具栏** 最右测的拉伸按钮或者按下 `Ctrl + M`,将 Cmd Markdown 切换到独立的编辑模式,这是一个极度简洁的写作环境,所有可能会引起分心的元素都已经被挪除,超清爽!
### 1.3.4. 实时的云端文稿
为了保障数据安全Cmd Markdown 会将您每一次击键的内容保存至云端,同时在 **编辑工具栏** 的最右侧提示 `已保存` 的字样。无需担心浏览器崩溃,机器掉电或者地震,海啸——在编辑的过程中随时关闭浏览器或者机器,下一次回到 Cmd Markdown 的时候继续写作。
### 1.3.5. 离线模式
在网络环境不稳定的情况下记录文字一样很安全在您写作的时候如果电脑突然失去网络连接Cmd Markdown 会智能切换至离线模式,将您后续键入的文字保存在本地,直到网络恢复再将他们传送至云端,即使在网络恢复前关闭浏览器或者电脑,一样没有问题,等到下次开启 Cmd Markdown 的时候,她会提醒您将离线保存的文字传送至云端。简而言之,我们尽最大的努力保障您文字的安全。
### 1.3.6. 管理工具栏
为了便于管理您的文稿,在 **预览区** 的顶部放置了如下所示的 **管理工具栏**
通过管理工具栏可以:
<i class="icon-share"></i> 发布:将当前的文稿生成固定链接,在网络上发布,分享
<i class="icon-file"></i> 新建:开始撰写一篇新的文稿
<i class="icon-trash"></i> 删除:删除当前的文稿
<i class="icon-cloud"></i> 导出:将当前的文稿转化为 Markdown 文本或者 Html 格式,并导出到本地
<i class="icon-reorder"></i> 列表:所有新增和过往的文稿都可以在这里查看、操作
<i class="icon-pencil"></i> 模式:切换 普通/Vim/Emacs 编辑模式
### 1.3.7. 阅读工具栏
通过 **预览区** 右上角的 **阅读工具栏**,可以查看当前文稿的目录并增强阅读体验。
工具栏上的五个图标依次为:
<i class="icon-list"></i> 目录:快速导航当前文稿的目录结构以跳转到感兴趣的段落
<i class="icon-chevron-sign-left"></i> 视图:互换左边编辑区和右边预览区的位置
<i class="icon-adjust"></i> 主题:内置了黑白两种模式的主题,试试 **黑色主题**,超炫!
<i class="icon-desktop"></i> 阅读:心无旁骛的阅读模式提供超一流的阅读体验
<i class="icon-fullscreen"></i> 全屏:简洁,简洁,再简洁,一个完全沉浸式的写作和阅读环境
### 1.3.8. 阅读模式
**阅读工具栏** 点击 <i class="icon-desktop"></i> 或者按下 `Ctrl+Alt+M` 随即进入独立的阅读模式界面,我们在版面渲染上的每一个细节:字体,字号,行间距,前背景色都倾注了大量的时间,努力提升阅读的体验和品质。
### 1.3.9. 标签、分类和搜索
在编辑区任意行首位置输入以下格式的文字可以标签当前文档:
标签: 未分类
标签以后的文稿在【文件列表】Ctrl+Alt+F里会按照标签分类用户可以同时使用键盘或者鼠标浏览查看或者在【文件列表】的搜索文本框内搜索标题关键字过滤文稿如下图所示
![file-list](file-list.png)
### 1.3.10. 文稿发布和分享
在您使用 Cmd Markdown 记录,创作,整理,阅读文稿的同时,我们不仅希望它是一个有力的工具,更希望您的思想和知识通过这个平台,连同优质的阅读体验,将他们分享给有相同志趣的人,进而鼓励更多的人来到这里记录分享他们的思想和知识,尝试点击 <i class="icon-share"></i> (Ctrl+Alt+P) 发布这份文档给好友吧!
------
再一次感谢您花费时间阅读这份欢迎稿,点击 <i class="icon-file"></i> (Ctrl+Alt+N) 开始撰写新的文稿吧!祝您在这里记录、阅读、分享愉快!
作者 [@ghosert][3]
2015 年 06月 15日
[^LaTeX]: 支持 **LaTeX** 编辑显示支持,例如:$\sum_{i=1}^n a_i=0$ 访问 [MathJax][4] 参考更多使用方法。
[^code]: 代码高亮功能支持包括 Java, Python, JavaScript 在内的,**四十一**种主流编程语言。
[1]: https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown
[2]: https://www.zybuluo.com/mdeditor?url=https://www.zybuluo.com/static/editor/md-help.markdown#cmd-markdown-高阶语法手册
[3]: http://weibo.com/ghosert
[4]: http://meta.math.stackexchange.com/questions/5020/mathjax-basic-tutorial-and-quick-reference

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.6 KiB

View File

@@ -0,0 +1,74 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"crypto/sha256"
"encoding/base64"
"fmt"
"math/rand"
"path"
"strings"
"time"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
)
// baseWord is the 62-character alphabet used for generated secrets.
const baseWord = "1Aa2Bb3Cc4Dd5Ee6Ff7Gg8Hh9Ii0JjKkLlMmNnOoPpQqRrSsTtUuVvWwXxYyZz"

const knowledgePrefix = "BIZ_KNOWLEDGE"

const imgSrcFormat = `<img src="" data-tos-key="%s">`

// createSecret derives a 10-character pseudo-random secret for an upload,
// mixing the uploader ID, file type, current time, and a random number.
// It hashes the mixed input with SHA-256, base64-encodes the digest, keeps
// the first 10 characters, and maps each onto the baseWord alphabet.
// The result is NOT deterministic (time and math/rand are involved).
func createSecret(uid int64, fileType string) string {
	const secretLen = 10
	input := fmt.Sprintf("upload_%d_Ma*9)fhi_%d_gou_%s_rand_%d", uid, time.Now().Unix(), fileType, rand.Intn(100000))
	// Fix: was sha256.Sum256([]byte(fmt.Sprintf("%s", input))) — the Sprintf
	// was a no-op copy of an already-string value (staticcheck S1025).
	hash := sha256.Sum256([]byte(input))
	hashString := base64.StdEncoding.EncodeToString(hash[:])
	if len(hashString) > secretLen {
		hashString = hashString[:secretLen]
	}
	// Build with strings.Builder instead of quadratic string concatenation.
	var b strings.Builder
	b.Grow(secretLen)
	for _, char := range hashString {
		b.WriteByte(baseWord[int(char)%62])
	}
	return b.String()
}
// getExtension returns the file extension of uri without the leading dot,
// or "" when uri is empty or its last path element has no extension.
func getExtension(uri string) string {
	if uri == "" {
		return ""
	}
	ext := path.Ext(path.Base(uri))
	// TrimPrefix on "" is still "", so the no-extension case needs no branch.
	return strings.TrimPrefix(ext, ".")
}
// getCreatorIDFromExtraMeta extracts the creator user ID from parser extra
// metadata, returning 0 when the map is nil, the key is absent, or the value
// is not an int64.
func getCreatorIDFromExtraMeta(extraMeta map[string]any) int64 {
	// Indexing a nil map is safe in Go and yields the zero value, so no
	// explicit nil check is needed.
	v, ok := extraMeta[document.MetaDataKeyCreatorID]
	if !ok {
		return 0
	}
	id, _ := v.(int64) // non-int64 values fall back to 0
	return id
}

View File

@@ -0,0 +1,151 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package progressbar
import (
"context"
"errors"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/progressbar"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
// ProgressBarImpl tracks knowledge-import progress in a shared cache so the
// progress of a long-running import can be polled across processes.
type ProgressBarImpl struct {
	CacheCli     cache.Cmdable // cache client holding all progress counters
	PrimaryKeyID int64         // entity ID used to namespace the cache keys
	Total        int64         // total number of work items
	ErrMsg       string        // last error reported in-process via ReportError
}

const (
	// ttl bounds the lifetime of every progress key in the cache.
	ttl = time.Hour * 2
	// Cache key templates; each is formatted with the primary key ID.
	ProgressBarStartTimeRedisKey    = "RedisBiz.Knowledge_ProgressBar_StartTime_%d"
	ProgressBarErrMsgRedisKey       = "RedisBiz.Knowledge_ProgressBar_ErrMsg_%d"
	ProgressBarTotalNumRedisKey     = "RedisBiz.Knowledge_ProgressBar_TotalNum_%d"
	ProgressBarProcessedNumRedisKey = "RedisBiz.Knowledge_ProgressBar_ProcessedNum_%d"
	// DefaultProcessTime is the fallback remaining-seconds estimate when no
	// better estimate can be computed.
	DefaultProcessTime = 300
	// ProcessDone / ProcessInit are the percentages reported for finished
	// (or errored) and not-yet-started imports respectively.
	ProcessDone = 100
	ProcessInit = 0
)
// NewProgressBar returns a cache-backed progress bar for the entity pkID.
// When needInit is true, all four progress keys (total, processed, error
// message, start time) are reset in the cache with the standard TTL.
func NewProgressBar(ctx context.Context, pkID int64, total int64, cacheCli cache.Cmdable, needInit bool) progressbar.ProgressBar {
	bar := &ProgressBarImpl{
		CacheCli:     cacheCli,
		PrimaryKeyID: pkID,
		Total:        total,
	}
	if needInit {
		cacheCli.Set(ctx, fmt.Sprintf(ProgressBarTotalNumRedisKey, pkID), total, ttl)
		cacheCli.Set(ctx, fmt.Sprintf(ProgressBarProcessedNumRedisKey, pkID), 0, ttl)
		cacheCli.Set(ctx, fmt.Sprintf(ProgressBarErrMsgRedisKey, pkID), "", ttl)
		cacheCli.Set(ctx, fmt.Sprintf(ProgressBarStartTimeRedisKey, pkID), time.Now().Unix(), ttl)
	}
	return bar
}
// AddN increments the processed counter by n. It refuses to advance once an
// error has been reported in-process, returning that error instead.
func (p *ProgressBarImpl) AddN(n int) error {
	if p.ErrMsg != "" {
		return errors.New(p.ErrMsg)
	}
	key := fmt.Sprintf(ProgressBarProcessedNumRedisKey, p.PrimaryKeyID)
	_, err := p.CacheCli.IncrBy(context.Background(), key, int64(n)).Result()
	return err
}
// ReportError records a failure both in-process (so later AddN calls fail
// fast) and in the cache (so other pollers see it via GetProgress).
func (p *ProgressBarImpl) ReportError(err error) error {
	p.ErrMsg = err.Error()
	key := fmt.Sprintf(ProgressBarErrMsgRedisKey, p.PrimaryKeyID)
	_, setErr := p.CacheCli.Set(context.Background(), key, p.ErrMsg, ttl).Result()
	return setErr
}
// fetchCounter reads the int64 counter stored under keyFmt (formatted with
// the bar's primary key). A missing key, empty value, or unparsable value is
// treated as 0, matching the "not started yet" semantics; a genuine cache
// failure is returned as an error.
func (p *ProgressBarImpl) fetchCounter(ctx context.Context, keyFmt string) (*int64, error) {
	s, err := p.CacheCli.Get(ctx, fmt.Sprintf(keyFmt, p.PrimaryKeyID)).Result()
	if errors.Is(err, redis.Nil) {
		return ptr.Of(int64(0)), nil
	}
	if err != nil {
		// Fix: the previous implementation checked len(s) == 0 before err and
		// therefore silently treated real cache errors as 0 (Result returns ""
		// alongside the error). Surface them instead.
		return nil, err
	}
	if len(s) == 0 {
		return ptr.Of(int64(0)), nil
	}
	num, convErr := conv.StrToInt64(s)
	if convErr != nil {
		// Keep the original lenient behavior: an unparsable value counts as 0.
		return ptr.Of(int64(0)), nil
	}
	return ptr.Of(num), nil
}

// GetProgress reports the completion percentage, an estimated number of
// seconds remaining, and any recorded error message for this progress bar.
// A recorded error (or a cache failure) is reported as 100% done with the
// error text; an unknown total reports 0% with the default time estimate.
func (p *ProgressBarImpl) GetProgress(ctx context.Context) (percent int, remainSec int, errMsg string) {
	errMsg, err := p.CacheCli.Get(ctx, fmt.Sprintf(ProgressBarErrMsgRedisKey, p.PrimaryKeyID)).Result()
	if errors.Is(err, redis.Nil) {
		errMsg = ""
	} else if err != nil {
		return ProcessDone, 0, err.Error()
	}
	if len(errMsg) != 0 {
		return ProcessDone, 0, errMsg
	}
	totalNum, err := p.fetchCounter(ctx, ProgressBarTotalNumRedisKey)
	if err != nil {
		return ProcessDone, 0, err.Error()
	}
	processedNum, err := p.fetchCounter(ctx, ProgressBarProcessedNumRedisKey)
	if err != nil {
		return ProcessDone, 0, err.Error()
	}
	if ptr.From(totalNum) == 0 {
		// Total not recorded yet: the import has not really started.
		return ProcessInit, DefaultProcessTime, ""
	}
	startTime, err := p.fetchCounter(ctx, ProgressBarStartTimeRedisKey)
	if err != nil {
		return ProcessDone, 0, err.Error()
	}
	percent = int(float64(ptr.From(processedNum)) / float64(ptr.From(totalNum)) * 100)
	if ptr.From(startTime) == 0 || ptr.From(processedNum) == 0 {
		// Fix: processedNum == 0 previously divided by zero below, yielding a
		// garbage remainSec; fall back to the default estimate instead.
		remainSec = DefaultProcessTime
	} else {
		// Linear extrapolation: remaining items at the observed per-item rate.
		usedSec := time.Now().Unix() - ptr.From(startTime)
		remainSec = int(float64(ptr.From(totalNum)-ptr.From(processedNum)) / float64(ptr.From(processedNum)) * float64(usedSec))
	}
	return
}

View File

@@ -0,0 +1,70 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package rrf
import (
"context"
"fmt"
"sort"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/rerank"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
// NewRRFReranker builds a reciprocal-rank-fusion reranker. A zero k falls
// back to 60, the commonly recommended RRF smoothing constant.
func NewRRFReranker(k int64) rerank.Reranker {
	const defaultK = 60
	if k == 0 {
		k = defaultK
	}
	return &rrfReranker{k: k}
}
// rrfReranker fuses multiple ranked lists using reciprocal rank fusion.
type rrfReranker struct {
	k int64 // smoothing constant; larger k flattens rank differences
}
// Rerank fuses the ranked lists in req.Data with reciprocal rank fusion and
// returns the documents sorted by fused score, truncated to req.TopN when set.
func (r *rrfReranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Response, error) {
	// len() of a nil slice is 0, so a separate nil check is unnecessary.
	if req == nil || len(req.Data) == 0 {
		return nil, fmt.Errorf("invalid request: no data provided")
	}
	id2Score := make(map[string]float64)
	id2Data := make(map[string]*rerank.Data)
	for _, resultList := range req.Data {
		for rank := range resultList {
			result := resultList[rank]
			if result != nil && result.Document != nil {
				// NOTE(review): classic RRF sums 1/(k+rank) across all lists;
				// this keeps only the best (max) contribution per document —
				// confirm this is intentional before changing.
				score := 1.0 / (float64(rank) + float64(r.k))
				if score > id2Score[result.Document.ID] {
					id2Score[result.Document.ID] = score
					id2Data[result.Document.ID] = result
				}
			}
		}
	}
	sorted := make([]*rerank.Data, 0, len(id2Data))
	for _, data := range id2Data {
		sorted = append(sorted, data)
	}
	// BUGFIX: break score ties on document ID so that random map iteration
	// order cannot make the output non-deterministic.
	sort.Slice(sorted, func(i, j int) bool {
		si, sj := id2Score[sorted[i].Document.ID], id2Score[sorted[j].Document.ID]
		if si != sj {
			return si > sj
		}
		return sorted[i].Document.ID < sorted[j].Document.ID
	})
	topN := int64(len(sorted))
	if req.TopN != nil && ptr.From(req.TopN) != 0 && ptr.From(req.TopN) < topN {
		topN = ptr.From(req.TopN)
	}
	return &rerank.Response{SortedData: sorted[:topN]}, nil
}

View File

@@ -0,0 +1,161 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vikingdb
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"github.com/volcengine/volc-sdk-golang/base"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/rerank"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
// Config holds volcengine credentials used to sign rerank requests.
type Config struct {
	AK string     // access key
	SK string     // secret key
	Region string // default cn-north-1
}
// NewReranker wraps the VikingDB knowledge-service rerank API. An empty
// region falls back to cn-north-1.
func NewReranker(config *Config) rerank.Reranker {
	if len(config.Region) == 0 {
		config.Region = "cn-north-1"
	}
	r := &reranker{config: config}
	return r
}
const (
	// NOTE(review): the endpoint host is pinned to cn-beijing while Config's
	// default Region is cn-north-1 — confirm the signing region/domain pairing.
	domain = "api-knowledgebase.mlp.cn-beijing.volces.com"
	// defaultModel is the rerank model sent with every request.
	defaultModel = "base-multilingual-rerank"
)
// reranker calls the VikingDB knowledge-service rerank endpoint.
type reranker struct {
	config *Config // credentials and region for request signing
}
// rerankReq is the JSON payload for the rerank endpoint: one (query, content)
// pair per candidate, scored by RerankModel.
type rerankReq struct {
	Datas []rerankData `json:"datas"`
	RerankModel string `json:"rerank_model"`
}
// rerankData is one query/candidate pair; Title is optional.
type rerankData struct {
	Query string `json:"query"`
	Content string `json:"content"`
	Title *string `json:"title,omitempty"`
}
// rerankResp mirrors the service response. Scores is parallel to the
// request's Datas; a non-zero Code signals a service-side failure.
type rerankResp struct {
	Code int64 `json:"code"`
	Message string `json:"message"`
	Data struct {
		Scores []float64 `json:"scores"`
		TokenUsage int64 `json:"token_usage"`
	} `json:"data"`
}
// Rerank scores every retrieved document against req.Query via the VikingDB
// rerank endpoint and returns the candidates sorted by score descending,
// truncated to req.TopN when set.
//
// All channels in req.Data are flattened into a single candidate list; the
// remote model scores each (query, content) pair in one call.
func (r *reranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Response, error) {
	rReq := &rerankReq{
		Datas:       make([]rerankData, 0, len(req.Data)),
		RerankModel: defaultModel,
	}
	var flat []*rerank.Data
	for _, channel := range req.Data {
		flat = append(flat, channel...)
	}
	for _, item := range flat {
		rReq.Datas = append(rReq.Datas, rerankData{
			Query:   req.Query,
			Content: item.Document.Content,
		})
	}
	body, err := json.Marshal(rReq)
	if err != nil {
		return nil, err
	}
	resp, err := http.DefaultClient.Do(r.prepareRequest(body))
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	// BUGFIX: fail fast on non-200 responses; their bodies may not follow the
	// expected JSON schema at all.
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("[Rerank] unexpected http status=%d, body=%s", resp.StatusCode, respBody)
	}
	rResp := rerankResp{}
	if err = json.Unmarshal(respBody, &rResp); err != nil {
		return nil, err
	}
	if rResp.Code != 0 {
		return nil, fmt.Errorf("[Rerank] failed, code=%d, msg=%v", rResp.Code, rResp.Message)
	}
	// BUGFIX: indexing flat[i] below panicked if the service returned a
	// different number of scores than candidates.
	if len(rResp.Data.Scores) != len(flat) {
		return nil, fmt.Errorf("[Rerank] score count mismatch, scores=%d, candidates=%d", len(rResp.Data.Scores), len(flat))
	}
	sorted := make([]*rerank.Data, 0, len(rResp.Data.Scores))
	for i, score := range rResp.Data.Scores {
		sorted = append(sorted, &rerank.Data{
			Document: flat[i].Document,
			Score:    score,
		})
	}
	sort.Slice(sorted, func(i, j int) bool {
		return sorted[i].Score > sorted[j].Score
	})
	right := len(sorted)
	if req.TopN != nil {
		right = min(right, int(*req.TopN))
	}
	// BUGFIX: a negative TopN previously made sorted[:right] panic.
	if right < 0 {
		right = 0
	}
	return &rerank.Response{
		SortedData: sorted[:right],
		TokenUsage: ptr.Of(rResp.Data.TokenUsage),
	}, nil
}
// prepareRequest builds and volcengine-signs the rerank POST request.
// http.NewRequest cannot fail here: the method and URL are fixed constants,
// so its error is deliberately discarded.
func (r *reranker) prepareRequest(body []byte) *http.Request {
	u := url.URL{
		Scheme: "https",
		Host: domain,
		Path: "/api/knowledge/service/rerank",
	}
	req, _ := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body))
	req.Header.Add("Accept", "application/json")
	req.Header.Add("Content-Type", "application/json")
	// NOTE(review): net/http derives the wire Host header from req.URL /
	// req.Host and ignores this header entry when sending; it may still be
	// consumed by the signature below — confirm before removing.
	req.Header.Add("Host", domain)
	credential := base.Credentials{
		AccessKeyID: r.config.AK,
		SecretAccessKey: r.config.SK,
		Service: "air",
		Region: r.config.Region,
	}
	// Sign attaches volcengine auth headers derived from the credentials.
	req = credential.Sign(req)
	return req
}

View File

@@ -0,0 +1,48 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vikingdb
//func TestRun(t *testing.T) {
// AK := os.Getenv("test_ak")
// SK := os.Getenv("test_sk")
//
// r := NewReranker(&Config{
// AK: AK,
// SK: SK,
// })
// resp, err := r.Rerank(context.Background(), &rerank.Request{
// Data: [][]*knowledge.RetrieveSlice{
// {
// {Slice: &entity.Slice{PlainText: "吉尼斯世界纪录网站数据显示蓝鲸是目前已知世界上最大的动物体长可达30米相当于一架波音737飞机的长度"}},
// {Slice: &entity.Slice{PlainText: "一头成年雌性弓头鲸可以长到22米长而一头雄性鲸鱼可以长到18米长"}},
// },
// },
// Query: "世界上最大的鲸鱼是什么?",
// TopN: nil,
// })
// assert.NoError(t, err)
//
// for _, item := range resp.Sorted {
// fmt.Println(item.Slice.PlainText, item.Score)
// }
// // 吉尼斯世界纪录网站数据显示蓝鲸是目前已知世界上最大的动物体长可达30米相当于一架波音737飞机的长度 0.6209664529733573
// // 一头成年雌性弓头鲸可以长到22米长而一头雄性鲸鱼可以长到18米长 0.4269785303456468
//
// fmt.Println(resp.TokenUsage)
// // 95
//
//}

View File

@@ -0,0 +1,21 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package elasticsearch
const (
	// topK is the default hit count used when the caller does not supply a
	// retriever TopK option.
	topK = 10
)

View File

@@ -0,0 +1,116 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package elasticsearch
import (
"context"
"fmt"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
"github.com/coze-dev/coze-studio/backend/infra/contract/es"
)
// ManagerConfig wires the ES client used by the full-text searchstore.
type ManagerConfig struct {
	Client es.Client
}
// NewManager builds the full-text (BM25) searchstore manager backed by the
// configured ES client.
func NewManager(config *ManagerConfig) searchstore.Manager {
	m := &esManager{config: config}
	return m
}
// esManager implements searchstore.Manager on top of an ES client.
type esManager struct {
	config *ManagerConfig
}
// Create ensures an index named req.CollectionName exists with a property
// mapping derived from req.Fields. It is a no-op when the index already
// exists. Built-in fields (id, creator id, text content) are added with
// default property types when the request does not declare them.
func (e *esManager) Create(ctx context.Context, req *searchstore.CreateRequest) error {
	cli := e.config.Client
	index := req.CollectionName
	indexExists, err := cli.Exists(ctx, index)
	if err != nil {
		return err
	}
	if indexExists {
		return nil
	}
	properties := make(map[string]any, len(req.Fields)+3)
	var foundID, foundCreatorID, foundTextContent bool
	for _, field := range req.Fields {
		// Track which built-in fields the caller declared so defaults are
		// only added for the missing ones.
		switch field.Name {
		case searchstore.FieldID:
			foundID = true
		case searchstore.FieldCreatorID:
			foundCreatorID = true
		case searchstore.FieldTextContent:
			foundTextContent = true
		}
		var property any
		switch field.Type {
		case searchstore.FieldTypeInt64:
			property = cli.Types().NewLongNumberProperty()
		case searchstore.FieldTypeText:
			property = cli.Types().NewTextProperty()
		default:
			return fmt.Errorf("[Create] es unsupported field type: %d", field.Type)
		}
		properties[field.Name] = property
	}
	if !foundID {
		properties[searchstore.FieldID] = cli.Types().NewLongNumberProperty()
	}
	if !foundCreatorID {
		properties[searchstore.FieldCreatorID] = cli.Types().NewUnsignedLongNumberProperty()
	}
	if !foundTextContent {
		properties[searchstore.FieldTextContent] = cli.Types().NewTextProperty()
	}
	// Fixed the original's redundant "if err != nil { return err }; return err"
	// tail by returning the call result directly.
	return cli.CreateIndex(ctx, index, properties)
}
// Drop deletes the ES index backing the named collection.
func (e *esManager) Drop(ctx context.Context, req *searchstore.DropRequest) error {
	return e.config.Client.DeleteIndex(ctx, req.CollectionName)
}
// GetType identifies this manager as the full-text (non-vector) store.
func (e *esManager) GetType() searchstore.SearchStoreType {
	return searchstore.TypeTextStore
}
// GetSearchStore returns a SearchStore view over the named index. Existence
// of the index is not verified here.
func (e *esManager) GetSearchStore(ctx context.Context, collectionName string) (searchstore.SearchStore, error) {
	ss := &esSearchStore{
		config:    e.config,
		indexName: collectionName,
	}
	return ss, nil
}
// GetEmbedding returns nil: the text store performs no vectorization.
func (e *esManager) GetEmbedding() embedding.Embedder {
	return nil
}

View File

@@ -0,0 +1,329 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package elasticsearch
import (
"bytes"
"context"
"encoding/json"
"fmt"
"reflect"
"strconv"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/contract/es"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
// esSearchStore is a SearchStore bound to a single ES index.
type esSearchStore struct {
	config *ManagerConfig
	indexName string // ES index backing this store
}
// Store bulk-indexes docs into the backing index and returns the stored
// document IDs in input order.
func (e *esSearchStore) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
	implSpecOptions := indexer.GetImplSpecificOptions(&searchstore.IndexerOptions{}, opts...)
	// The named return err lets this defer observe every failure path below
	// and push it to the optional progress bar. ReportError's own error is
	// deliberately dropped: reporting is best-effort.
	defer func() {
		if err != nil {
			if implSpecOptions.ProgressBar != nil {
				implSpecOptions.ProgressBar.ReportError(err)
			}
		}
	}()
	cli := e.config.Client
	index := e.indexName
	bi, err := cli.NewBulkIndexer(index)
	if err != nil {
		return nil, err
	}
	ids = make([]string, 0, len(docs))
	for _, doc := range docs {
		// Flatten the document (content, creator id, external storage) into
		// the ES field mapping.
		fieldMapping, err := e.fromDocument(doc)
		if err != nil {
			return nil, err
		}
		body, err := json.Marshal(fieldMapping)
		if err != nil {
			return nil, err
		}
		if err = bi.Add(ctx, es.BulkIndexerItem{
			Index: e.indexName,
			Action: "index",
			DocumentID: doc.ID,
			Body: bytes.NewReader(body),
		}); err != nil {
			return nil, err
		}
		ids = append(ids, doc.ID)
		// Tick the progress bar once per document when one was supplied.
		if implSpecOptions.ProgressBar != nil {
			if err = implSpecOptions.ProgressBar.AddN(1); err != nil {
				return nil, err
			}
		}
	}
	// Close flushes any buffered bulk items.
	if err = bi.Close(ctx); err != nil {
		return nil, err
	}
	return ids, nil
}
// Retrieve runs a bool query for query against the backing index and returns
// the matching documents with scores normalized by parseSearchResult.
func (e *esSearchStore) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
	var (
		cli = e.config.Client
		index = e.indexName
		options = retriever.GetCommonOptions(&retriever.Options{TopK: ptr.Of(topK)}, opts...)
		implSpecOptions = retriever.GetImplSpecificOptions(&searchstore.RetrieverOptions{}, opts...)
		req = &es.Request{
			Query: &es.Query{
				Bool: &es.BoolQuery{},
			},
			Size: options.TopK,
		}
	)
	// Match on the default text-content field unless the caller requested a
	// multi-field match.
	if implSpecOptions.MultiMatch == nil {
		req.Query.Bool.Must = append(req.Query.Bool.Must,
			es.NewMatchQuery(searchstore.FieldTextContent, query))
	} else {
		req.Query.Bool.Must = append(req.Query.Bool.Must,
			es.NewMultiMatchQuery(implSpecOptions.MultiMatch.Fields, query,
				"best_fields", es.Or))
	}
	// Translate any caller-supplied filter DSL into extra bool clauses.
	dsl, err := searchstore.LoadDSL(options.DSLInfo)
	if err != nil {
		return nil, err
	}
	if err = e.travDSL(req.Query, dsl); err != nil {
		return nil, err
	}
	if options.ScoreThreshold != nil {
		req.MinScore = options.ScoreThreshold
	}
	resp, err := cli.Search(ctx, index, req)
	if err != nil {
		return nil, err
	}
	docs, err := e.parseSearchResult(resp)
	if err != nil {
		return nil, err
	}
	return docs, nil
}
// Delete bulk-removes the given document IDs from the backing index.
func (e *esSearchStore) Delete(ctx context.Context, ids []string) error {
	bi, err := e.config.Client.NewBulkIndexer(e.indexName)
	if err != nil {
		return err
	}
	for _, docID := range ids {
		item := es.BulkIndexerItem{
			Index:      e.indexName,
			Action:     "delete",
			DocumentID: docID,
		}
		if addErr := bi.Add(ctx, item); addErr != nil {
			return addErr
		}
	}
	// Close flushes any buffered deletions.
	return bi.Close(ctx)
}
// travDSL recursively translates a searchstore DSL tree into ES bool-query
// clauses, mutating query in place. A nil dsl is a no-op.
func (e *esSearchStore) travDSL(query *es.Query, dsl *searchstore.DSL) error {
	if dsl == nil {
		return nil
	}
	switch dsl.Op {
	case searchstore.OpEq, searchstore.OpNe:
		// Scalars are normalized to strings; use the first stringified
		// element when available.
		arr := stringifyValue(dsl.Value)
		v := dsl.Value
		if len(arr) > 0 {
			v = arr[0]
		}
		if dsl.Op == searchstore.OpEq {
			query.Bool.Must = append(query.Bool.Must,
				es.NewEqualQuery(dsl.Field, v))
		} else {
			query.Bool.MustNot = append(query.Bool.MustNot,
				es.NewEqualQuery(dsl.Field, v))
		}
	case searchstore.OpLike:
		s, ok := dsl.Value.(string)
		if !ok {
			return fmt.Errorf("[travDSL] OpLike value should be string, but got %v", dsl.Value)
		}
		query.Bool.Must = append(query.Bool.Must, es.NewMatchQuery(dsl.Field, s))
	case searchstore.OpIn:
		// BUGFIX: previously appended onto the MustNot slice while assigning
		// to Must, corrupting both clause lists.
		query.Bool.Must = append(query.Bool.Must,
			es.NewInQuery(dsl.Field, stringifyValue(dsl.Value)))
	case searchstore.OpAnd, searchstore.OpOr:
		conds, ok := dsl.Value.([]*searchstore.DSL)
		if !ok {
			return fmt.Errorf("[travDSL] value type assertion failed for or")
		}
		for _, cond := range conds {
			// BUGFIX: the sub-query needs its Bool initialized — the
			// recursive call below appends to sub.Bool's clause slices.
			sub := &es.Query{Bool: &es.BoolQuery{}}
			if err := e.travDSL(sub, cond); err != nil {
				return err
			}
			if dsl.Op == searchstore.OpOr {
				query.Bool.Should = append(query.Bool.Should, *sub)
			} else {
				query.Bool.Must = append(query.Bool.Must, *sub)
			}
		}
	default:
		return fmt.Errorf("[trav] unknown op %s", dsl.Op)
	}
	return nil
}
// parseSearchResult converts an ES response into schema.Documents. Scores are
// normalized relative to the first (top) hit so the best result scores 1.0.
func (e *esSearchStore) parseSearchResult(resp *es.Response) (docs []*schema.Document, err error) {
	docs = make([]*schema.Document, 0, len(resp.Hits.Hits))
	firstScore := 0.0
	for i, hit := range resp.Hits.Hits {
		var src map[string]any
		// UseNumber keeps numeric fields as json.Number so int64 ids survive
		// without float64 precision loss.
		d := json.NewDecoder(bytes.NewReader(hit.Source_))
		d.UseNumber()
		if err = d.Decode(&src); err != nil {
			return nil, err
		}
		ext := make(map[string]any)
		doc := &schema.Document{MetaData: map[string]any{document.MetaDataKeyExternalStorage: ext}}
		for field, val := range src {
			ok := true
			switch field {
			case searchstore.FieldTextContent:
				doc.Content, ok = val.(string)
			case searchstore.FieldCreatorID:
				var jn json.Number
				jn, ok = val.(json.Number)
				if ok {
					doc.MetaData[document.MetaDataKeyCreatorID], ok = assertJSONNumber(jn).(int64)
				}
			default:
				// Unknown fields go to the external-storage metadata map,
				// unwrapping json.Number into int64/float64 where possible.
				if jn, jok := val.(json.Number); jok {
					ext[field] = assertJSONNumber(jn)
				} else {
					ext[field] = val
				}
			}
			if !ok {
				return nil, fmt.Errorf("[parseSearchResult] type assertion failed, field=%s, val=%v", field, val)
			}
		}
		if hit.Id_ != nil {
			doc.ID = *hit.Id_
		}
		if hit.Score_ == nil { // unexpected
			return nil, fmt.Errorf("[parseSearchResult] es retrieve score not found")
		}
		score := float64(ptr.From(hit.Score_))
		if i == 0 {
			firstScore = score
		}
		// BUGFIX: when the top hit scores 0, score/firstScore produced NaN;
		// fall back to the raw score in that case.
		if firstScore != 0 {
			doc.WithScore(score / firstScore)
		} else {
			doc.WithScore(score)
		}
		docs = append(docs, doc)
	}
	return docs, nil
}
// fromDocument flattens a schema.Document into the ES field mapping: text
// content, creator id, plus any external-storage metadata entries.
func (e *esSearchStore) fromDocument(doc *schema.Document) (map[string]any, error) {
	if doc.MetaData == nil {
		return nil, fmt.Errorf("[fromDocument] es document meta data is nil")
	}
	creatorID, ok := doc.MetaData[searchstore.FieldCreatorID].(int64)
	if !ok {
		return nil, fmt.Errorf("[fromDocument] creator id not found or type invalid")
	}
	fieldMapping := map[string]any{
		searchstore.FieldCreatorID:   creatorID,
		searchstore.FieldTextContent: doc.Content,
	}
	ext, hasExt := doc.MetaData[document.MetaDataKeyExternalStorage].(map[string]any)
	if hasExt {
		for key, value := range ext {
			fieldMapping[key] = value
		}
	}
	return fieldMapping, nil
}
// stringifyValue normalizes a DSL filter value into a []any of strings: each
// element of a slice/array (or the value itself for scalars) is rendered as a
// decimal / float / raw string. Elements of unsupported kinds are passed
// through as their underlying values.
func stringifyValue(dslValue any) []any {
	value := reflect.ValueOf(dslValue)
	switch value.Kind() {
	case reflect.Slice, reflect.Array:
		length := value.Len()
		slice := make([]any, 0, length)
		for i := 0; i < length; i++ {
			elem := value.Index(i)
			switch elem.Kind() {
			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
				slice = append(slice, strconv.FormatInt(elem.Int(), 10))
			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
				slice = append(slice, strconv.FormatUint(elem.Uint(), 10))
			case reflect.Float32, reflect.Float64:
				slice = append(slice, strconv.FormatFloat(elem.Float(), 'f', -1, 64))
			case reflect.String:
				slice = append(slice, elem.String())
			default:
				// BUGFIX: append the underlying value, not the reflect.Value
				// wrapper, which previously leaked into query construction.
				slice = append(slice, elem.Interface())
			}
		}
		return slice
	default:
		// Format the original value rather than its reflect.Value wrapper.
		return []any{fmt.Sprintf("%v", dslValue)}
	}
}
func assertJSONNumber(f json.Number) any {
if i64, err := f.Int64(); err == nil {
return i64
}
if f64, err := f.Float64(); err == nil {
return f64
}
return f.String()
}

View File

@@ -0,0 +1,22 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package milvus
const (
	batchSize = 100 // default rows per milvus insert batch
	topK = 4        // default result count for retrieval
)

View File

@@ -0,0 +1,119 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package milvus
import (
"fmt"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
)
// denseFieldName names the collection column holding the dense vector of name.
func denseFieldName(name string) string {
	return "dense_" + name
}

// denseIndexName names the index built over the dense vector of name.
func denseIndexName(name string) string {
	return "index_dense_" + name
}

// sparseFieldName names the collection column holding the sparse vector of name.
func sparseFieldName(name string) string {
	return "sparse_" + name
}

// sparseIndexName names the index built over the sparse vector of name.
func sparseIndexName(name string) string {
	return "index_sparse_" + name
}
// convertFieldType maps searchstore field types to their milvus equivalents;
// an unknown type yields FieldTypeNone and an error.
func convertFieldType(typ searchstore.FieldType) (entity.FieldType, error) {
	mapping := map[searchstore.FieldType]entity.FieldType{
		searchstore.FieldTypeInt64:        entity.FieldTypeInt64,
		searchstore.FieldTypeText:         entity.FieldTypeVarChar,
		searchstore.FieldTypeDenseVector:  entity.FieldTypeFloatVector,
		searchstore.FieldTypeSparseVector: entity.FieldTypeSparseVector,
	}
	if mt, ok := mapping[typ]; ok {
		return mt, nil
	}
	return entity.FieldTypeNone, fmt.Errorf("[convertFieldType] unknown field type: %v", typ)
}
// convertDense downcasts float64 embedding rows to float32 rows.
func convertDense(dense [][]float64) [][]float32 {
	return slices.Transform(dense, func(vec []float64) []float32 {
		row := make([]float32, len(vec))
		for i, v := range vec {
			row[i] = float32(v)
		}
		return row
	})
}
// convertMilvusDenseVector downcasts float64 embedding rows into milvus
// float vectors.
func convertMilvusDenseVector(dense [][]float64) []entity.Vector {
	return slices.Transform(dense, func(vec []float64) entity.Vector {
		row := make([]float32, len(vec))
		for i, v := range vec {
			row[i] = float32(v)
		}
		return entity.FloatVector(row)
	})
}
// convertSparse converts index->weight maps into milvus sparse embeddings.
// Index/weight pair order within one embedding is unspecified (map iteration).
func convertSparse(sparse []map[int]float64) ([]entity.SparseEmbedding, error) {
	out := make([]entity.SparseEmbedding, 0, len(sparse))
	for _, m := range sparse {
		indices := make([]uint32, 0, len(m))
		weights := make([]float32, 0, len(m))
		for idx, w := range m {
			indices = append(indices, uint32(idx))
			weights = append(weights, float32(w))
		}
		emb, err := entity.NewSliceSparseEmbedding(indices, weights)
		if err != nil {
			return nil, err
		}
		out = append(out, emb)
	}
	return out, nil
}
// convertMilvusSparseVector converts index->weight maps into milvus vectors
// (sparse embeddings satisfy the entity.Vector interface).
func convertMilvusSparseVector(sparse []map[int]float64) ([]entity.Vector, error) {
	out := make([]entity.Vector, 0, len(sparse))
	for _, m := range sparse {
		indices := make([]uint32, 0, len(m))
		weights := make([]float32, 0, len(m))
		for idx, w := range m {
			indices = append(indices, uint32(idx))
			weights = append(weights, float32(w))
		}
		emb, err := entity.NewSliceSparseEmbedding(indices, weights)
		if err != nil {
			return nil, err
		}
		out = append(out, emb)
	}
	return out, nil
}

View File

@@ -0,0 +1,334 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package milvus
import (
"context"
"fmt"
"strings"
mentity "github.com/milvus-io/milvus/client/v2/entity"
mindex "github.com/milvus-io/milvus/client/v2/index"
client "github.com/milvus-io/milvus/client/v2/milvusclient"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
// ManagerConfig configures the milvus-backed vector searchstore manager.
// Only Client and Embedding are required; every other field has a default
// applied by NewManager.
type ManagerConfig struct {
	Client *client.Client // required
	Embedding embedding.Embedder // required
	EnableHybrid *bool // optional: default Embedding.SupportStatus() == embedding.SupportDenseAndSparse
	DenseIndex mindex.Index // optional: default HNSW, M=30, efConstruction=360
	DenseMetric mentity.MetricType // optional: default IP
	SparseIndex mindex.Index // optional: default SPARSE_INVERTED_INDEX, drop_ratio=0.2
	SparseMetric mentity.MetricType // optional: default IP
	ShardNum int // optional: default 1
	BatchSize int // optional: default 100
}
// NewManager validates config, fills every optional field with its documented
// default, and returns the milvus-backed vector searchstore manager. Hybrid
// search is forced off when the embedder cannot produce sparse vectors.
func NewManager(config *ManagerConfig) (searchstore.Manager, error) {
	if config.Client == nil {
		return nil, fmt.Errorf("[NewManager] milvus client not provided")
	}
	if config.Embedding == nil {
		return nil, fmt.Errorf("[NewManager] milvus embedder not provided")
	}
	supportsSparse := config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse
	switch {
	case config.EnableHybrid == nil:
		config.EnableHybrid = ptr.Of(supportsSparse)
	case ptr.From(config.EnableHybrid) && !supportsSparse:
		logs.Warnf("[NewManager] milvus embedding not support sparse, so hybrid search is disabled.")
		config.EnableHybrid = ptr.Of(false)
	}
	if config.DenseMetric == "" {
		config.DenseMetric = mentity.IP
	}
	if config.DenseIndex == nil {
		config.DenseIndex = mindex.NewHNSWIndex(config.DenseMetric, 30, 360)
	}
	if config.SparseMetric == "" {
		config.SparseMetric = mentity.IP
	}
	if config.SparseIndex == nil {
		config.SparseIndex = mindex.NewSparseInvertedIndex(config.SparseMetric, 0.2)
	}
	if config.ShardNum == 0 {
		config.ShardNum = 1
	}
	if config.BatchSize == 0 {
		config.BatchSize = 100
	}
	return &milvusManager{config: config}, nil
}
// milvusManager implements searchstore.Manager on top of a milvus client.
type milvusManager struct {
	config *ManagerConfig
}
// Create provisions the collection, its vector indexes, and loads it into
// memory, failing if the loaded collection cannot be confirmed to exist.
func (m *milvusManager) Create(ctx context.Context, req *searchstore.CreateRequest) error {
	if err := m.createCollection(ctx, req); err != nil {
		return fmt.Errorf("[Create] create collection failed, %w", err)
	}
	if err := m.createIndexes(ctx, req); err != nil {
		return fmt.Errorf("[Create] create indexes failed, %w", err)
	}
	exists, err := m.loadCollection(ctx, req.CollectionName)
	if err != nil {
		return fmt.Errorf("[Create] load collection failed, %w", err)
	}
	if !exists {
		return fmt.Errorf("[Create] load collection failed, collection=%v does not exist", req.CollectionName)
	}
	return nil
}
// Drop removes the named collection (and its data) from milvus.
func (m *milvusManager) Drop(ctx context.Context, req *searchstore.DropRequest) error {
	return m.config.Client.DropCollection(ctx, client.NewDropCollectionOption(req.CollectionName))
}
// GetType identifies this manager as the vector store.
func (m *milvusManager) GetType() searchstore.SearchStoreType {
	return searchstore.TypeVectorStore
}
// GetSearchStore ensures the collection is loaded, then returns a SearchStore
// bound to it. A missing collection yields a non-retryable knowledge error.
func (m *milvusManager) GetSearchStore(ctx context.Context, collectionName string) (searchstore.SearchStore, error) {
	exists, err := m.loadCollection(ctx, collectionName)
	if err != nil {
		return nil, err
	}
	if !exists {
		return nil, errorx.New(errno.ErrKnowledgeNonRetryableCode,
			errorx.KVf("reason", "[GetSearchStore] collection=%v does not exist", collectionName))
	}
	ss := &milvusSearchStore{
		config:         m.config,
		collectionName: collectionName,
	}
	return ss, nil
}
// createCollection creates the milvus collection described by req if it does
// not already exist; an existing collection is left untouched.
func (m *milvusManager) createCollection(ctx context.Context, req *searchstore.CreateRequest) error {
	if req.CollectionName == "" || len(req.Fields) == 0 {
		return fmt.Errorf("[createCollection] invalid request params")
	}
	cli := m.config.Client
	collectionName := req.CollectionName
	has, err := cli.HasCollection(ctx, client.NewHasCollectionOption(collectionName))
	if err != nil {
		return fmt.Errorf("[createCollection] HasCollection failed, %w", err)
	}
	if has {
		return nil
	}
	fields, err := m.convertFields(req.Fields)
	if err != nil {
		return err
	}
	opt := client.NewCreateCollectionOption(collectionName, &mentity.Schema{
		CollectionName: collectionName,
		// Fixed: was fmt.Sprintf("created by coze") with no arguments, which
		// go vet flags as a constant format string.
		Description:        "created by coze",
		AutoID:             false,
		Fields:             fields,
		EnableDynamicField: false,
	}).WithShardNum(int32(m.config.ShardNum))
	for k, v := range req.CollectionMeta {
		opt.WithProperty(k, v)
	}
	if err = cli.CreateCollection(ctx, opt); err != nil {
		return fmt.Errorf("[createCollection] CreateCollection failed, %w", err)
	}
	return nil
}
// createIndexes ensures a dense (and, when supported, sparse) vector index
// exists for every indexing field, skipping indexes milvus already reports.
func (m *milvusManager) createIndexes(ctx context.Context, req *searchstore.CreateRequest) error {
	collectionName := req.CollectionName
	indexes, err := m.config.Client.ListIndexes(ctx, client.NewListIndexOption(req.CollectionName))
	if err != nil {
		// "index not found" just means no indexes exist yet; anything else
		// is fatal.
		if !strings.Contains(err.Error(), "index not found") {
			return fmt.Errorf("[createIndexes] ListIndexes failed, %w", err)
		}
	}
	createdIndexes := sets.FromSlice(indexes)
	var ops []func() error
	for i := range req.Fields {
		f := req.Fields[i]
		if !f.Indexing {
			continue
		}
		ops = append(ops, m.tryCreateIndex(ctx, collectionName, denseFieldName(f.Name), denseIndexName(f.Name), m.config.DenseIndex, createdIndexes))
		// NOTE(review): gated on SupportStatus rather than the EnableHybrid
		// flag normalized in NewManager — confirm which should control sparse
		// index creation.
		if m.config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse {
			ops = append(ops, m.tryCreateIndex(ctx, collectionName, sparseFieldName(f.Name), sparseIndexName(f.Name), m.config.SparseIndex, createdIndexes))
		}
	}
	for _, op := range ops {
		if err := op(); err != nil {
			return fmt.Errorf("[createIndexes] failed, %w", err)
		}
	}
	return nil
}
// tryCreateIndex returns a deferred operation that creates indexName over
// fieldName, skipping work when the index already appears in createdIndexes.
func (m *milvusManager) tryCreateIndex(ctx context.Context, collectionName, fieldName, indexName string, idx mindex.Index, createdIndexes sets.Set[string]) func() error {
	return func() error {
		if _, ok := createdIndexes[indexName]; ok {
			logs.CtxInfof(ctx, "[tryCreateIndex] index exists, so skip, collectionName=%s, fieldName=%s, idx=%v, type=%s\n",
				collectionName, fieldName, indexName, idx.IndexType())
			return nil
		}
		opt := client.NewCreateIndexOption(collectionName, fieldName, idx).WithIndexName(indexName)
		task, err := m.config.Client.CreateIndex(ctx, opt)
		if err != nil {
			return fmt.Errorf("[tryCreateIndex] CreateIndex failed, %w", err)
		}
		if waitErr := task.Await(ctx); waitErr != nil {
			return fmt.Errorf("[tryCreateIndex] await failed, %w", waitErr)
		}
		logs.CtxInfof(ctx, "[tryCreateIndex] CreateIndex success, collectionName=%s, fieldName=%s, idx=%v, type=%s\n",
			collectionName, fieldName, indexName, idx.IndexType())
		return nil
	}
}
// loadCollection makes sure the collection is loaded into memory. It returns
// whether the collection can be used; an error is returned for transient or
// unexpected load states.
func (m *milvusManager) loadCollection(ctx context.Context, collectionName string) (exists bool, err error) {
	cli := m.config.Client
	stat, err := cli.GetLoadState(ctx, client.NewGetLoadStateOption(collectionName))
	if err != nil {
		return false, fmt.Errorf("[loadCollection] GetLoadState failed, %w", err)
	}
	switch stat.State {
	case mentity.LoadStateNotLoad:
		task, err := cli.LoadCollection(ctx, client.NewLoadCollectionOption(collectionName))
		if err != nil {
			return false, fmt.Errorf("[loadCollection] LoadCollection failed, collection=%v, %w", collectionName, err)
		}
		if err = task.Await(ctx); err != nil {
			return false, fmt.Errorf("[loadCollection] await failed, collection=%v, %w", collectionName, err)
		}
		return true, nil
	case mentity.LoadStateLoaded:
		return true, nil
	case mentity.LoadStateLoading:
		// BUGFIX: message previously said "unloading" for the loading state.
		return true, fmt.Errorf("[loadCollection] collection is still loading, retry later, collection=%v", collectionName)
	case mentity.LoadStateUnloading:
		// NOTE(review): unloading reports exists=false with no error, which
		// callers surface as "collection does not exist" — confirm intended.
		return false, nil
	default:
		// BUGFIX: previously formatted the whole state struct with %d;
		// report the enum value itself.
		return false, fmt.Errorf("[loadCollection] load state unexpected, state=%v", stat.State)
	}
}
// convertFields maps searchstore field specs to milvus schema fields.
//
// Indexing fields must be text; each gets a dense-vector column, plus a
// sparse-vector column when the embedder supports sparse. Only the built-in
// content field additionally keeps the original text. Built-in id /
// creator-id columns are appended when the request omits them.
func (m *milvusManager) convertFields(fields []*searchstore.Field) ([]*mentity.Field, error) {
	var foundID, foundCreatorID bool
	resp := make([]*mentity.Field, 0, len(fields))
	for _, f := range fields {
		// Record which built-in fields the caller declared explicitly.
		switch f.Name {
		case searchstore.FieldID:
			foundID = true
		case searchstore.FieldCreatorID:
			foundCreatorID = true
		default:
		}
		if f.Indexing {
			if f.Type != searchstore.FieldTypeText {
				return nil, fmt.Errorf("[convertFields] milvus only support text field indexing, field=%s, type=%d", f.Name, f.Type)
			}
			// When indexing, only the content field stores the original text.
			if f.Name == searchstore.FieldTextContent {
				resp = append(resp, mentity.NewField().
					WithName(f.Name).
					WithDescription(f.Description).
					WithIsPrimaryKey(f.IsPrimary).
					WithNullable(f.Nullable).
					WithDataType(mentity.FieldTypeVarChar).
					WithMaxLength(65535))
			}
			resp = append(resp, mentity.NewField().
				WithName(denseFieldName(f.Name)).
				WithDataType(mentity.FieldTypeFloatVector).
				WithDim(m.config.Embedding.Dimensions()))
			// NOTE(review): gated on SupportStatus rather than the
			// EnableHybrid flag computed in NewManager — confirm which
			// should control sparse column creation.
			if m.config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse {
				resp = append(resp, mentity.NewField().
					WithName(sparseFieldName(f.Name)).
					WithDataType(mentity.FieldTypeSparseVector))
			}
		} else {
			mf := mentity.NewField().
				WithName(f.Name).
				WithDescription(f.Description).
				WithIsPrimaryKey(f.IsPrimary).
				WithNullable(f.Nullable)
			typ, err := convertFieldType(f.Type)
			if err != nil {
				return nil, err
			}
			mf.WithDataType(typ)
			// VarChar needs a max length; float vectors need the embedding dim.
			if typ == mentity.FieldTypeVarChar {
				mf.WithMaxLength(65535)
			} else if typ == mentity.FieldTypeFloatVector {
				mf.WithDim(m.config.Embedding.Dimensions())
			}
			resp = append(resp, mf)
		}
	}
	if !foundID {
		resp = append(resp, mentity.NewField().
			WithName(searchstore.FieldID).
			WithDataType(mentity.FieldTypeInt64).
			WithIsPrimaryKey(true).
			WithNullable(false))
	}
	if !foundCreatorID {
		resp = append(resp, mentity.NewField().
			WithName(searchstore.FieldCreatorID).
			WithDataType(mentity.FieldTypeInt64).
			WithNullable(false))
	}
	return resp, nil
}
// GetEmbedding exposes the embedder this manager was configured with.
func (m *milvusManager) GetEmbedding() embedding.Embedder {
	cfg := m.config
	return cfg.Embedding
}

View File

@@ -0,0 +1,600 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package milvus
import (
"context"
"encoding/json"
"fmt"
"math"
"reflect"
"sort"
"strconv"
"strings"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/milvus-io/milvus/client/v2/column"
mentity "github.com/milvus-io/milvus/client/v2/entity"
mindex "github.com/milvus-io/milvus/client/v2/index"
client "github.com/milvus-io/milvus/client/v2/milvusclient"
"github.com/slongfield/pyfmt"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
)
// milvusSearchStore implements searchstore.SearchStore on top of a single
// milvus collection.
type milvusSearchStore struct {
	config         *ManagerConfig // shared client and embedding configuration
	collectionName string         // milvus collection this store operates on
}
// Store upserts documents into the collection in batches of batchSize and
// returns the stored primary-key ids rendered as strings. Fields listed in
// the IndexingFields option are embedded inside documents2Columns. When the
// Partition option is set, the partition is created on demand and all rows
// are written into it. Progress (and any terminal error) is reported to the
// optional progress bar.
func (m *milvusSearchStore) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
	if len(docs) == 0 {
		return nil, nil
	}
	implSpecOptions := indexer.GetImplSpecificOptions(&searchstore.IndexerOptions{}, opts...)
	// Surface the terminal error through the progress bar, if one is attached.
	defer func() {
		if err != nil {
			if implSpecOptions.ProgressBar != nil {
				implSpecOptions.ProgressBar.ReportError(err)
			}
		}
	}()
	indexingFields := make(sets.Set[string])
	for _, field := range implSpecOptions.IndexingFields {
		indexingFields[field] = struct{}{}
	}
	if implSpecOptions.Partition != nil {
		// Create the target partition lazily on first use.
		partition := *implSpecOptions.Partition
		hasPartition, err := m.config.Client.HasPartition(ctx, client.NewHasPartitionOption(m.collectionName, partition))
		if err != nil {
			return nil, fmt.Errorf("[Store] HasPartition failed, %w", err)
		}
		if !hasPartition {
			if err = m.config.Client.CreatePartition(ctx, client.NewCreatePartitionOption(m.collectionName, partition)); err != nil {
				return nil, fmt.Errorf("[Store] CreatePartition failed, %w", err)
			}
		}
	}
	for _, part := range slices.Chunks(docs, batchSize) {
		columns, err := m.documents2Columns(ctx, part, indexingFields)
		if err != nil {
			return nil, err
		}
		createReq := client.NewColumnBasedInsertOption(m.collectionName, columns...)
		if implSpecOptions.Partition != nil {
			createReq.WithPartition(*implSpecOptions.Partition)
		}
		result, err := m.config.Client.Upsert(ctx, createReq)
		if err != nil {
			return nil, fmt.Errorf("[Store] upsert failed, %w", err)
		}
		// Collect returned primary keys; the column may be int64 or string.
		partIDs := result.IDs
		for i := 0; i < partIDs.Len(); i++ {
			var sid string
			if partIDs.Type() == mentity.FieldTypeInt64 {
				id, err := partIDs.GetAsInt64(i)
				if err != nil {
					return nil, err
				}
				sid = strconv.FormatInt(id, 10)
			} else {
				sid, err = partIDs.GetAsString(i)
				if err != nil {
					return nil, err
				}
			}
			ids = append(ids, sid)
		}
		if implSpecOptions.ProgressBar != nil {
			if err = implSpecOptions.ProgressBar.AddN(len(part)); err != nil {
				return nil, err
			}
		}
	}
	return ids, nil
}
// Retrieve embeds the query and runs an ANN search over the collection.
// When the schema has a sparse-vector field and hybrid mode is enabled, a
// dense+sparse hybrid search with RRF reranking is issued; otherwise a plain
// dense search runs and scores are normalized by the index metric type.
func (m *milvusSearchStore) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
	cli := m.config.Client
	emb := m.config.Embedding
	options := retriever.GetCommonOptions(&retriever.Options{TopK: ptr.Of(topK)}, opts...)
	implSpecOptions := retriever.GetImplSpecificOptions(&searchstore.RetrieverOptions{}, opts...)
	desc, err := cli.DescribeCollection(ctx, client.NewDescribeCollectionOption(m.collectionName))
	if err != nil {
		return nil, err
	}
	var (
		dense        [][]float64
		sparse       []map[int]float64
		expr         string
		result       []client.ResultSet
		fields       = desc.Schema.Fields
		outputFields []string
		enableSparse = m.enableSparse(fields)
	)
	if options.DSLInfo != nil {
		expr, err = m.dsl2Expr(options.DSLInfo)
		if err != nil {
			return nil, err
		}
	}
	if enableSparse {
		dense, sparse, err = emb.EmbedStringsHybrid(ctx, []string{query})
		if err != nil {
			return nil, fmt.Errorf("[Retrieve] EmbedStringsHybrid failed, %w", err)
		}
	} else {
		dense, err = emb.EmbedStrings(ctx, []string{query})
		if err != nil {
			return nil, fmt.Errorf("[Retrieve] EmbedStrings failed, %w", err)
		}
	}
	dv := convertMilvusDenseVector(dense)
	sv, err := convertMilvusSparseVector(sparse)
	if err != nil {
		return nil, err
	}
	for _, field := range fields {
		outputFields = append(outputFields, field.Name)
	}
	var scoreNormType *mindex.MetricType
	if enableSparse {
		var annRequests []*client.AnnRequest
		for _, field := range fields {
			var (
				vector      []mentity.Vector
				metricsType mindex.MetricType
			)
			switch field.DataType {
			case mentity.FieldTypeFloatVector:
				vector = dv
				metricsType, err = m.getIndexMetricsType(ctx, denseIndexName(field.Name))
			case mentity.FieldTypeSparseVector:
				vector = sv
				metricsType, err = m.getIndexMetricsType(ctx, sparseIndexName(field.Name))
			default:
				// Scalar fields carry no vector index; skip them so we never
				// build an AnnRequest with a nil vector and empty metric.
				continue
			}
			if err != nil {
				return nil, err
			}
			annRequests = append(annRequests,
				client.NewAnnRequest(field.Name, ptr.From(options.TopK), vector...).
					WithSearchParam(mindex.MetricTypeKey, string(metricsType)).
					WithFilter(expr),
			)
		}
		searchOption := client.NewHybridSearchOption(m.collectionName, ptr.From(options.TopK), annRequests...).
			WithPartitons(implSpecOptions.Partitions...). // sic: SDK method name
			WithReranker(client.NewRRFReranker()).
			WithOutputFields(outputFields...)
		result, err = cli.HybridSearch(ctx, searchOption)
		if err != nil {
			return nil, fmt.Errorf("[Retrieve] HybridSearch failed, %w", err)
		}
	} else {
		indexes, err := cli.ListIndexes(ctx, client.NewListIndexOption(m.collectionName))
		if err != nil {
			return nil, fmt.Errorf("[Retrieve] ListIndexes failed, %w", err)
		}
		if len(indexes) != 1 {
			return nil, fmt.Errorf("[Retrieve] restrict single index ann search, but got %d, collection=%s",
				len(indexes), m.collectionName)
		}
		metricsType, err := m.getIndexMetricsType(ctx, indexes[0])
		if err != nil {
			return nil, err
		}
		// Remember the metric so resultSet2Document can normalize scores.
		scoreNormType = &metricsType
		searchOption := client.NewSearchOption(m.collectionName, ptr.From(options.TopK), dv).
			WithPartitions(implSpecOptions.Partitions...).
			WithFilter(expr).
			WithOutputFields(outputFields...).
			WithSearchParam(mindex.MetricTypeKey, string(metricsType))
		result, err = cli.Search(ctx, searchOption)
		if err != nil {
			return nil, fmt.Errorf("[Retrieve] Search failed, %w", err)
		}
	}
	docs, err := m.resultSet2Document(result, scoreNormType)
	if err != nil {
		return nil, fmt.Errorf("[Retrieve] resultSet2Document failed, %w", err)
	}
	return docs, nil
}
// Delete removes the rows whose primary keys match the given string ids.
// Every id must parse as a base-10 int64.
func (m *milvusSearchStore) Delete(ctx context.Context, ids []string) error {
	parsed := make([]int64, 0, len(ids))
	for _, raw := range ids {
		v, convErr := strconv.ParseInt(raw, 10, 64)
		if convErr != nil {
			return convErr
		}
		parsed = append(parsed, v)
	}
	opt := client.NewDeleteOption(m.collectionName).WithInt64IDs(searchstore.FieldID, parsed)
	_, err := m.config.Client.Delete(ctx, opt)
	return err
}
// documents2Columns converts one batch of documents into milvus columns.
//
// Builtin columns (id, creator_id, text content) are always produced. Values
// found under MetaDataKeyExternalStorage become extra columns; each column's
// type is fixed by the first value observed and every later document must
// agree with it. Fields listed in indexingFields are embedded through the
// configured embedder and emitted as dense (and, when supported, sparse)
// vector columns; for those, the raw text is only persisted for the builtin
// content field, and only when at least one document has non-empty content.
func (m *milvusSearchStore) documents2Columns(ctx context.Context, docs []*schema.Document, indexingFields sets.Set[string]) (
	cols []column.Column, err error) {
	var (
		ids           []int64
		contents      []string
		creatorIDs    []int64
		emptyContents = true // stays true while every doc.Content is ""
	)
	colMapping := map[string]any{}
	colTypeMapping := map[string]searchstore.FieldType{
		searchstore.FieldID:          searchstore.FieldTypeInt64,
		searchstore.FieldCreatorID:   searchstore.FieldTypeInt64,
		searchstore.FieldTextContent: searchstore.FieldTypeText,
	}
	for _, doc := range docs {
		if doc.MetaData == nil {
			return nil, fmt.Errorf("[documents2Columns] meta data is nil")
		}
		id, err := strconv.ParseInt(doc.ID, 10, 64)
		if err != nil {
			return nil, fmt.Errorf("[documents2Columns] parse id failed, %w", err)
		}
		ids = append(ids, id)
		contents = append(contents, doc.Content)
		if doc.Content != "" {
			emptyContents = false
		}
		creatorID, err := document.GetDocumentCreatorID(doc)
		if err != nil {
			return nil, fmt.Errorf("[documents2Columns] creator_id not found or type invalid., %w", err)
		}
		creatorIDs = append(creatorIDs, creatorID)
		ext, ok := doc.MetaData[document.MetaDataKeyExternalStorage].(map[string]any)
		if !ok {
			continue
		}
		// Accumulate external-storage values per field, verifying that every
		// document agrees on the field's Go type.
		for field := range ext {
			val := ext[field]
			container := colMapping[field]
			switch t := val.(type) {
			case uint, uint8, uint16, uint32, uint64, uintptr:
				var c []int64
				if container == nil {
					colTypeMapping[field] = searchstore.FieldTypeInt64
				} else {
					c, ok = container.([]int64)
					if !ok {
						return nil, fmt.Errorf("[documents2Columns] container type not int64")
					}
				}
				c = append(c, int64(reflect.ValueOf(t).Uint()))
				colMapping[field] = c
			case int, int8, int16, int32, int64:
				var c []int64
				if container == nil {
					colTypeMapping[field] = searchstore.FieldTypeInt64
				} else {
					c, ok = container.([]int64)
					if !ok {
						return nil, fmt.Errorf("[documents2Columns] container type not int64")
					}
				}
				c = append(c, reflect.ValueOf(t).Int())
				colMapping[field] = c
			case string:
				var c []string
				if container == nil {
					colTypeMapping[field] = searchstore.FieldTypeText
				} else {
					c, ok = container.([]string)
					if !ok {
						return nil, fmt.Errorf("[documents2Columns] container type not string")
					}
				}
				c = append(c, t)
				colMapping[field] = c
			case []float64:
				var c [][]float64
				if container == nil {
					colTypeMapping[field] = searchstore.FieldTypeDenseVector
				} else {
					c, ok = container.([][]float64)
					if !ok {
						return nil, fmt.Errorf("[documents2Columns] container type not []float64")
					}
				}
				c = append(c, t)
				colMapping[field] = c
			case map[int]float64:
				var c []map[int]float64
				if container == nil {
					colTypeMapping[field] = searchstore.FieldTypeSparseVector
				} else {
					c, ok = container.([]map[int]float64)
					if !ok {
						return nil, fmt.Errorf("[documents2Columns] container type not map[int]float64")
					}
				}
				c = append(c, t)
				colMapping[field] = c
			default:
				return nil, fmt.Errorf("[documents2Columns] val type not support, val=%v", val)
			}
		}
	}
	colMapping[searchstore.FieldID] = ids
	colMapping[searchstore.FieldCreatorID] = creatorIDs
	colMapping[searchstore.FieldTextContent] = contents
	// Materialize each accumulated container into a typed milvus column.
	for fieldName, container := range colMapping {
		colType := colTypeMapping[fieldName]
		switch colType {
		case searchstore.FieldTypeInt64:
			c, ok := container.([]int64)
			if !ok {
				return nil, fmt.Errorf("[documents2Columns] container type not int64")
			}
			cols = append(cols, column.NewColumnInt64(fieldName, c))
		case searchstore.FieldTypeText:
			c, ok := container.([]string)
			if !ok {
				return nil, fmt.Errorf("[documents2Columns] container type not string")
			}
			if _, indexing := indexingFields[fieldName]; indexing {
				// Only persist the raw text for the builtin content column.
				if fieldName == searchstore.FieldTextContent && !emptyContents {
					cols = append(cols, column.NewColumnVarChar(fieldName, c))
				}
				var (
					emb    = m.config.Embedding
					dense  [][]float64
					sparse []map[int]float64
				)
				if emb.SupportStatus() == embedding.SupportDenseAndSparse {
					dense, sparse, err = emb.EmbedStringsHybrid(ctx, c)
				} else {
					dense, err = emb.EmbedStrings(ctx, c)
				}
				if err != nil {
					return nil, fmt.Errorf("[documents2Columns] embed failed, %w", err)
				}
				cols = append(cols, column.NewColumnFloatVector(denseFieldName(fieldName), int(emb.Dimensions()), convertDense(dense)))
				if emb.SupportStatus() == embedding.SupportDenseAndSparse {
					s, err := convertSparse(sparse)
					if err != nil {
						return nil, err
					}
					cols = append(cols, column.NewColumnSparseVectors(sparseFieldName(fieldName), s))
				}
			} else {
				cols = append(cols, column.NewColumnVarChar(fieldName, c))
			}
		case searchstore.FieldTypeDenseVector:
			c, ok := container.([][]float64)
			if !ok {
				return nil, fmt.Errorf("[documents2Columns] container type not []float64")
			}
			cols = append(cols, column.NewColumnFloatVector(fieldName, int(m.config.Embedding.Dimensions()), convertDense(c)))
		case searchstore.FieldTypeSparseVector:
			c, ok := container.([]map[int]float64)
			if !ok {
				return nil, fmt.Errorf("[documents2Columns] container type not map[int]float64")
			}
			sparse, err := convertSparse(c)
			if err != nil {
				return nil, err
			}
			cols = append(cols, column.NewColumnSparseVectors(fieldName, sparse))
		default:
			return nil, fmt.Errorf("[documents2Columns] column type not support, type=%d", colType)
		}
	}
	return cols, nil
}
// resultSet2Document flattens milvus result sets into schema documents sorted
// by score descending, then normalizes scores into [0, 1] for plain dense
// search. Normalization is skipped for hybrid search (RRF scores) and when no
// metric type is supplied.
func (m *milvusSearchStore) resultSet2Document(result []client.ResultSet, metricsType *mindex.MetricType) (docs []*schema.Document, err error) {
	docs = make([]*schema.Document, 0, len(result))
	minScore := math.MaxFloat64
	maxScore := 0.0
	for _, r := range result {
		for i := 0; i < r.ResultCount; i++ {
			ext := make(map[string]any)
			doc := &schema.Document{MetaData: map[string]any{document.MetaDataKeyExternalStorage: ext}}
			score := float64(r.Scores[i])
			// Track the raw score range for later normalization.
			minScore = min(minScore, score)
			maxScore = max(maxScore, score)
			doc.WithScore(score)
			for _, field := range r.Fields {
				switch field.Name() {
				case searchstore.FieldID:
					id, err := field.GetAsInt64(i)
					if err != nil {
						return nil, err
					}
					doc.ID = strconv.FormatInt(id, 10)
				case searchstore.FieldTextContent:
					doc.Content, err = field.GetAsString(i)
				case searchstore.FieldCreatorID:
					doc.MetaData[document.MetaDataKeyCreatorID], err = field.GetAsInt64(i)
				default:
					// Unknown columns are surfaced via external-storage metadata.
					ext[field.Name()], err = field.Get(i)
				}
				if err != nil {
					return nil, err
				}
			}
			docs = append(docs, doc)
		}
	}
	sort.Slice(docs, func(i, j int) bool {
		return docs[i].Score() > docs[j].Score()
	})
	// Normalize scores only for plain dense search; hybrid results and
	// unknown metrics are returned unchanged.
	if (m.config.EnableHybrid != nil && *m.config.EnableHybrid) || metricsType == nil {
		return docs, nil
	}
	switch *metricsType {
	case mentity.L2:
		// L2 is a distance (smaller is better): map into [0, 1] with 1 best,
		// then reverse so the best document comes first again.
		base := maxScore - minScore
		for i := range docs {
			if base == 0 {
				docs[i].WithScore(1.0)
			} else {
				docs[i].WithScore(1.0 - (docs[i].Score()-minScore)/base)
			}
		}
		docs = slices.Reverse(docs)
	case mentity.IP, mentity.COSINE:
		// Similarity in [-1, 1] mapped linearly into [0, 1].
		for i := range docs {
			docs[i].WithScore((docs[i].Score() + 1) / 2)
		}
	default:
	}
	return docs, nil
}
// enableSparse reports whether hybrid (dense + sparse) retrieval can be used:
// the schema must contain a sparse-vector field, hybrid mode must be enabled
// in config, and the embedder must support sparse embeddings.
func (m *milvusSearchStore) enableSparse(fields []*mentity.Field) bool {
	found := false
	for _, field := range fields {
		if field.DataType == mentity.FieldTypeSparseVector {
			found = true
			break
		}
	}
	// EnableHybrid is an optional pointer (resultSet2Document nil-checks it);
	// treat nil as disabled instead of panicking on dereference.
	return found &&
		m.config.EnableHybrid != nil && *m.config.EnableHybrid &&
		m.config.Embedding.SupportStatus() == embedding.SupportDenseAndSparse
}
// dsl2Expr renders a searchstore DSL tree into a milvus boolean filter
// expression. Leaf operators map to ==, !=, LIKE and IN; AND / OR nodes are
// joined recursively.
//
// NOTE(review): leaf values are interpolated verbatim by pyfmt (string values
// are not quoted or escaped here) — confirm callers only pass trusted,
// pre-formatted values.
func (m *milvusSearchStore) dsl2Expr(src map[string]interface{}) (string, error) {
	if src == nil {
		return "", nil
	}
	dsl, err := searchstore.LoadDSL(src)
	if err != nil {
		return "", err
	}
	var travDSL func(dsl *searchstore.DSL) (string, error)
	travDSL = func(dsl *searchstore.DSL) (string, error) {
		kv := map[string]interface{}{
			"field": dsl.Field,
			"val":   dsl.Value,
		}
		switch dsl.Op {
		case searchstore.OpEq:
			return pyfmt.Fmt("{field} == {val}", kv)
		case searchstore.OpNe:
			return pyfmt.Fmt("{field} != {val}", kv)
		case searchstore.OpLike:
			return pyfmt.Fmt("{field} LIKE {val}", kv)
		case searchstore.OpIn:
			// Render the IN list as a JSON array, e.g. [1,2,3].
			b, err := json.Marshal(dsl.Value)
			if err != nil {
				return "", err
			}
			kv["val"] = string(b)
			return pyfmt.Fmt("{field} IN {val}", kv)
		case searchstore.OpAnd, searchstore.OpOr:
			sub, ok := dsl.Value.([]*searchstore.DSL)
			if !ok {
				return "", fmt.Errorf("[dsl2Expr] invalid sub dsl")
			}
			var items []string
			for _, s := range sub {
				str, err := travDSL(s)
				if err != nil {
					return "", fmt.Errorf("[dsl2Expr] parse sub failed, %w", err)
				}
				items = append(items, str)
			}
			if dsl.Op == searchstore.OpAnd {
				return strings.Join(items, " AND "), nil
			} else {
				return strings.Join(items, " OR "), nil
			}
		default:
			return "", fmt.Errorf("[dsl2Expr] unknown op type=%s", dsl.Op)
		}
	}
	return travDSL(dsl)
}
// getIndexMetricsType looks up the metric type (L2 / IP / COSINE / ...) the
// named index was built with.
func (m *milvusSearchStore) getIndexMetricsType(ctx context.Context, indexName string) (mindex.MetricType, error) {
	idx, err := m.config.Client.DescribeIndex(ctx, client.NewDescribeIndexOption(m.collectionName, indexName))
	if err != nil {
		return "", fmt.Errorf("[getIndexMetricsType] describe index failed, collection=%s, index=%s, %w",
			m.collectionName, indexName, err)
	}
	params := idx.Params()
	metric, ok := params[mindex.MetricTypeKey]
	if !ok { // unexpected
		return "", fmt.Errorf("[getIndexMetricsType] invalid index params, collection=%s, index=%s", m.collectionName, indexName)
	}
	return mindex.MetricType(metric), nil
}

View File

@@ -0,0 +1,122 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vikingdb
import (
"fmt"
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
embcontract "github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
// VikingEmbeddingModelName identifies a server-side embedding model offered
// by vikingdb's vectorization service.
type VikingEmbeddingModelName string

const (
	ModelNameDoubaoEmbedding       VikingEmbeddingModelName = "doubao-embedding"
	ModelNameDoubaoEmbeddingLarge  VikingEmbeddingModelName = "doubao-embedding-large"
	ModelNameDoubaoEmbeddingVision VikingEmbeddingModelName = "doubao-embedding-vision"
	ModelNameBGELargeZH            VikingEmbeddingModelName = "bge-large-zh"
	ModelNameBGEM3                 VikingEmbeddingModelName = "bge-m3"
	ModelNameBGEVisualizedM3       VikingEmbeddingModelName = "bge-visualized-m3"
	// Combined dense+sparse model names, kept for reference but not enabled:
	//ModelNameDoubaoEmbeddingAndM3 VikingEmbeddingModelName = "doubao-embedding-and-m3"
	//ModelNameDoubaoEmbeddingLargeAndM3 VikingEmbeddingModelName = "doubao-embedding-large-and-m3"
	//ModelNameBGELargeZHAndM3 VikingEmbeddingModelName = "bge-large-zh-and-m3"
)
// Dimensions returns the embedding dimensionality produced by this model,
// or 0 for an unrecognized model name.
func (v VikingEmbeddingModelName) Dimensions() int64 {
	switch v {
	case ModelNameDoubaoEmbeddingLarge:
		return 4096
	case ModelNameDoubaoEmbedding, ModelNameDoubaoEmbeddingVision:
		return 2048
	case ModelNameBGELargeZH, ModelNameBGEM3, ModelNameBGEVisualizedM3:
		return 1024
	}
	return 0
}
// ModelVersion returns the pinned model version for doubao models, or nil for
// models that are not versioned here.
func (v VikingEmbeddingModelName) ModelVersion() *string {
	versions := map[VikingEmbeddingModelName]string{
		ModelNameDoubaoEmbedding:       "240515",
		ModelNameDoubaoEmbeddingLarge:  "240915",
		ModelNameDoubaoEmbeddingVision: "250328",
	}
	if ver, ok := versions[v]; ok {
		return ptr.Of(ver)
	}
	return nil
}
// SupportStatus reports which embedding kinds the model produces. Only bge-m3
// supports sparse embeddings; every other (and unknown) model is dense-only.
func (v VikingEmbeddingModelName) SupportStatus() embcontract.SupportStatus {
	if v == ModelNameBGEM3 {
		return embcontract.SupportDenseAndSparse
	}
	return embcontract.SupportDense
}
// IndexType mirrors the vikingdb vector index algorithms.
type IndexType string

const (
	IndexTypeHNSW       IndexType = vikingdb.HNSW
	IndexTypeHNSWHybrid IndexType = vikingdb.HNSW_HYBRID
	IndexTypeFlat       IndexType = vikingdb.FLAT
	IndexTypeIVF        IndexType = vikingdb.IVF
	IndexTypeDiskANN    IndexType = vikingdb.DiskANN
)

// IndexDistance mirrors the vikingdb distance metrics.
type IndexDistance string

const (
	IndexDistanceIP     IndexDistance = vikingdb.IP
	IndexDistanceL2     IndexDistance = vikingdb.L2
	IndexDistanceCosine IndexDistance = vikingdb.COSINE
)

// IndexQuant mirrors the vikingdb quantization options.
type IndexQuant string

const (
	IndexQuantInt8  IndexQuant = vikingdb.Int8
	IndexQuantFloat IndexQuant = vikingdb.Float
	IndexQuantFix16 IndexQuant = vikingdb.Fix16
	IndexQuantPQ    IndexQuant = vikingdb.PQ
)
// Request/response keys of vikingdb's embedding API, plus the fixed index
// name used for every collection created by this package.
const (
	vikingEmbeddingUseDense           = "return_dense"
	vikingEmbeddingUseSparse          = "return_sparse"
	vikingEmbeddingRespSentenceDense  = "sentence_dense_embedding"
	vikingEmbeddingRespSentenceSparse = "sentence_sparse_embedding"
	vikingIndexName                   = "opencoze_index"
)

// Substrings matched against SDK error messages to recognize not-found errors.
const (
	errCollectionNotFound = "collection not found"
	errIndexNotFound      = "index not found"
)
// denseFieldName derives the dense-vector column name for a source field.
func denseFieldName(name string) string {
	return fmt.Sprint("dense_", name)
}

View File

@@ -0,0 +1,331 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vikingdb
import (
"context"
"fmt"
"strings"
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
// ManagerConfig wires a vikingdb service client together with indexing and
// embedding settings for the searchstore manager.
type ManagerConfig struct {
	Service         *vikingdb.VikingDBService // vikingdb SDK client (required)
	IndexingConfig  *VikingIndexingConfig     // optional; defaults filled in NewManager
	EmbeddingConfig *VikingEmbeddingConfig    // required
	// TODO: cache viking collection & index client
}

// VikingIndexingConfig holds vector-index build parameters; zero/nil values
// are replaced with defaults by NewManager.
type VikingIndexingConfig struct {
	// vector index config
	Type     IndexType      // default: hnsw / hnsw_hybrid
	Distance *IndexDistance // default: ip
	Quant    *IndexQuant    // default: int8
	HnswM    *int64         // default: 20
	HnswCef  *int64         // default: 400
	HnswSef  *int64         // default: 800
	// others
	CpuQuota   int64 // default: 2
	ShardCount int64 // default: 1
}

// VikingEmbeddingConfig selects between vikingdb's server-side vectorization
// and a locally-run (builtin) embedder.
type VikingEmbeddingConfig struct {
	UseVikingEmbedding bool // true: server-side vectorization; false: BuiltinEmbedding
	EnableHybrid       bool // dense + sparse retrieval (viking embedding only)
	// viking embedding config
	ModelName    VikingEmbeddingModelName
	ModelVersion *string
	DenseWeight  *float64
	// builtin embedding config
	BuiltinEmbedding embedding.Embedder
}
// NewManager builds a vikingdb-backed searchstore.Manager. It validates the
// service and embedding configuration up front and fills every unset indexing
// option with its documented default, so later code can dereference the
// option pointers unconditionally.
func NewManager(config *ManagerConfig) (searchstore.Manager, error) {
	if err := validateManagerConfig(config); err != nil {
		return nil, err
	}
	fillIndexingDefaults(config)
	return &manager{
		config: config,
	}, nil
}

// validateManagerConfig checks the invariants NewManager requires: a service
// client, a usable embedding setup, and hybrid mode only where sparse
// embeddings are actually available.
func validateManagerConfig(config *ManagerConfig) error {
	if config.Service == nil {
		return fmt.Errorf("[NewManager] vikingdb service is nil")
	}
	if config.EmbeddingConfig == nil {
		return fmt.Errorf("[NewManager] vikingdb embedding config is nil")
	}
	if !config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.BuiltinEmbedding == nil {
		return fmt.Errorf("[NewManager] vikingdb built embedding not provided")
	}
	if !config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.EnableHybrid {
		return fmt.Errorf("[NewManager] vikingdb hybrid not support for builtin embedding")
	}
	if config.EmbeddingConfig.UseVikingEmbedding && config.EmbeddingConfig.ModelName == "" {
		return fmt.Errorf("[NewManager] vikingdb model name is empty")
	}
	if config.EmbeddingConfig.UseVikingEmbedding &&
		config.EmbeddingConfig.EnableHybrid &&
		config.EmbeddingConfig.ModelName.SupportStatus() != embedding.SupportDenseAndSparse {
		return fmt.Errorf("[NewManager] vikingdb embedding model not support sparse embedding, model=%v", config.EmbeddingConfig.ModelName)
	}
	return nil
}

// fillIndexingDefaults populates unset indexing options: hnsw (hnsw_hybrid
// when viking hybrid embedding is on), ip distance, int8 quantization,
// hnsw m=20 / cef=400 / sef=800, 2 CPUs and 1 shard.
func fillIndexingDefaults(config *ManagerConfig) {
	if config.IndexingConfig == nil {
		config.IndexingConfig = &VikingIndexingConfig{}
	}
	if config.IndexingConfig.Type == "" {
		if !config.EmbeddingConfig.UseVikingEmbedding || !config.EmbeddingConfig.EnableHybrid {
			config.IndexingConfig.Type = IndexTypeHNSW
		} else {
			config.IndexingConfig.Type = IndexTypeHNSWHybrid
		}
	}
	if config.IndexingConfig.Distance == nil {
		config.IndexingConfig.Distance = ptr.Of(IndexDistanceIP)
	}
	if config.IndexingConfig.Quant == nil {
		config.IndexingConfig.Quant = ptr.Of(IndexQuantInt8)
	}
	if config.IndexingConfig.HnswM == nil {
		config.IndexingConfig.HnswM = ptr.Of(int64(20))
	}
	if config.IndexingConfig.HnswCef == nil {
		config.IndexingConfig.HnswCef = ptr.Of(int64(400))
	}
	if config.IndexingConfig.HnswSef == nil {
		config.IndexingConfig.HnswSef = ptr.Of(int64(800))
	}
	if config.IndexingConfig.CpuQuota == 0 {
		config.IndexingConfig.CpuQuota = 2
	}
	if config.IndexingConfig.ShardCount == 0 {
		config.IndexingConfig.ShardCount = 1
	}
}
// manager implements searchstore.Manager on top of vikingdb.
type manager struct {
	config *ManagerConfig
}
// Create provisions the vikingdb collection and its index, creating each only
// when it does not already exist.
func (m *manager) Create(ctx context.Context, req *searchstore.CreateRequest) error {
	if err := m.createCollection(ctx, req); err != nil {
		return err
	}
	return m.createIndex(ctx, req)
}
// Drop removes the collection's index and then the collection itself,
// tolerating not-found errors so the call is idempotent.
func (m *manager) Drop(_ context.Context, req *searchstore.DropRequest) error {
	svc := m.config.Service
	if err := svc.DropIndex(req.CollectionName, vikingIndexName); err != nil && !strings.Contains(err.Error(), errIndexNotFound) {
		return err
	}
	if err := svc.DropCollection(req.CollectionName); err != nil && !strings.Contains(err.Error(), errCollectionNotFound) {
		return err
	}
	return nil
}
// GetType identifies this manager as a vector-store backend.
func (m *manager) GetType() searchstore.SearchStoreType {
	return searchstore.TypeVectorStore
}
// GetSearchStore resolves the named collection and wraps it in a SearchStore.
func (m *manager) GetSearchStore(_ context.Context, collectionName string) (searchstore.SearchStore, error) {
	coll, err := m.config.Service.GetCollection(collectionName)
	if err != nil {
		return nil, err
	}
	store := &vkSearchStore{manager: m, collection: coll}
	return store, nil
}
// createCollection creates the collection described by req unless it already
// exists; vectorize options are passed through when mapFields produced any.
func (m *manager) createCollection(ctx context.Context, req *searchstore.CreateRequest) error {
	svc := m.config.Service
	existing, err := svc.GetCollection(req.CollectionName)
	if err == nil && existing != nil {
		return nil // already provisioned
	}
	if err != nil && !strings.Contains(err.Error(), errCollectionNotFound) {
		return err
	}
	fields, vopts, err := m.mapFields(req.Fields)
	if err != nil {
		return err
	}
	if vopts != nil {
		_, err = svc.CreateCollection(req.CollectionName, fields, "", vopts)
	} else {
		_, err = svc.CreateCollection(req.CollectionName, fields, "")
	}
	if err != nil {
		return err
	}
	logs.CtxInfof(ctx, "[vikingdb] Create collection success, collection=%s", req.CollectionName)
	return nil
}
// createIndex creates the package's fixed-name vector index on the collection
// unless it already exists, using the configured (or defaulted) parameters.
func (m *manager) createIndex(ctx context.Context, req *searchstore.CreateRequest) error {
	svc := m.config.Service
	existing, err := svc.GetIndex(req.CollectionName, vikingIndexName)
	if err == nil && existing != nil {
		return nil // index already present
	}
	if err != nil && !strings.Contains(err.Error(), errIndexNotFound) {
		return err
	}
	cfg := m.config.IndexingConfig
	vectorIndex := &vikingdb.VectorIndexParams{
		IndexType: string(cfg.Type),
		Distance:  string(ptr.From(cfg.Distance)),
		Quant:     string(ptr.From(cfg.Quant)),
		HnswM:     ptr.From(cfg.HnswM),
		HnswCef:   ptr.From(cfg.HnswCef),
		HnswSef:   ptr.From(cfg.HnswSef),
	}
	opts := vikingdb.NewIndexOptions().
		SetVectorIndex(vectorIndex).
		SetCpuQuota(cfg.CpuQuota).
		SetShardCount(cfg.ShardCount)
	if _, err = svc.CreateIndex(req.CollectionName, vikingIndexName, opts); err != nil {
		return err
	}
	logs.CtxInfof(ctx, "[vikingdb] Create index success, collection=%s, index=%s", req.CollectionName, vikingIndexName)
	return nil
}
// mapFields translates searchstore field definitions into vikingdb schema
// fields, plus vectorize options when server-side (viking) embedding is in
// use. Indexed fields must be text; with builtin embedding they additionally
// get an explicit dense-vector field named dense_<field>. Builtin id and
// creator_id fields are appended when absent from the input.
func (m *manager) mapFields(srcFields []*searchstore.Field) ([]vikingdb.Field, []*vikingdb.VectorizeTuple, error) {
	var (
		foundID        bool
		foundCreatorID bool
		dstFields      = make([]vikingdb.Field, 0, len(srcFields))
		vectorizeOpts  []*vikingdb.VectorizeTuple
		embConfig      = m.config.EmbeddingConfig
	)
	for _, srcField := range srcFields {
		// Track whether the caller already supplied the builtin key fields.
		switch srcField.Name {
		case searchstore.FieldID:
			foundID = true
		case searchstore.FieldCreatorID:
			foundCreatorID = true
		default:
		}
		if srcField.Indexing {
			if srcField.Type != searchstore.FieldTypeText {
				return nil, nil, fmt.Errorf("[mapFields] currently only support text field indexing, field=%s", srcField.Name)
			}
			if embConfig.UseVikingEmbedding {
				// Server-side vectorization: vikingdb embeds the text field.
				vt := vikingdb.NewVectorizeTuple().SetDense(m.newVectorizeModelConf(srcField.Name))
				if embConfig.EnableHybrid {
					vt = vt.SetSparse(m.newVectorizeModelConf(srcField.Name))
				}
				vectorizeOpts = append(vectorizeOpts, vt)
			} else {
				// Builtin embedding: vectors are stored in an explicit field.
				dstFields = append(dstFields, vikingdb.Field{
					FieldName:  denseFieldName(srcField.Name),
					FieldType:  vikingdb.Vector,
					DefaultVal: nil,
					Dim:        m.getDims(),
				})
			}
		}
		// The source field itself is always stored alongside any vectors.
		dstField := vikingdb.Field{
			FieldName:    srcField.Name,
			IsPrimaryKey: srcField.IsPrimary,
		}
		switch srcField.Type {
		case searchstore.FieldTypeInt64:
			dstField.FieldType = vikingdb.Int64
		case searchstore.FieldTypeText:
			dstField.FieldType = vikingdb.Text
		case searchstore.FieldTypeDenseVector:
			dstField.FieldType = vikingdb.Vector
			dstField.Dim = m.getDims()
		case searchstore.FieldTypeSparseVector:
			dstField.FieldType = vikingdb.Sparse_Vector
		default:
			return nil, nil, fmt.Errorf("unknown field type: %v", srcField.Type)
		}
		dstFields = append(dstFields, dstField)
	}
	if !foundID {
		dstFields = append(dstFields, vikingdb.Field{
			FieldName:    searchstore.FieldID,
			FieldType:    vikingdb.Int64,
			IsPrimaryKey: true,
		})
	}
	if !foundCreatorID {
		dstFields = append(dstFields, vikingdb.Field{
			FieldName: searchstore.FieldCreatorID,
			FieldType: vikingdb.Int64,
		})
	}
	return dstFields, vectorizeOpts, nil
}
// newVectorizeModelConf builds the server-side vectorization config for one
// text field, wiring in the configured model name, dimension, and (when set)
// the pinned model version.
func (m *manager) newVectorizeModelConf(fieldName string) *vikingdb.VectorizeModelConf {
	cfg := m.config.EmbeddingConfig
	conf := vikingdb.NewVectorizeModelConf().
		SetTextField(fieldName).
		SetModelName(string(cfg.ModelName)).
		SetDim(m.getDims())
	if v := cfg.ModelVersion; v != nil {
		conf = conf.SetModelVersion(ptr.From(v))
	}
	return conf
}
// getDims returns the vector dimension of whichever embedder is active.
func (m *manager) getDims() int64 {
	cfg := m.config.EmbeddingConfig
	if !cfg.UseVikingEmbedding {
		return cfg.BuiltinEmbedding.Dimensions()
	}
	return cfg.ModelName.Dimensions()
}

View File

@@ -0,0 +1,388 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vikingdb
import (
"context"
"encoding/json"
"fmt"
"reflect"
"strconv"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
// vkSearchStore implements searchstore.SearchStore for a single vikingdb
// collection; the index client may be resolved lazily at retrieval time.
type vkSearchStore struct {
	*manager
	collection *vikingdb.Collection
	index      *vikingdb.Index // may be nil until first use
}
// Store upserts the given documents into the collection in batches of 100.
// When viking-side embedding is disabled, dense vectors for the indexing
// fields are computed locally before the upsert. Any failure is also
// reported to the optional progress bar from the indexer options.
// Returns the ids of all input documents on success.
func (v *vkSearchStore) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
	if len(docs) == 0 {
		return nil, nil
	}
	implSpecOptions := indexer.GetImplSpecificOptions(&searchstore.IndexerOptions{}, opts...)
	// Report the final error (named return) to the progress bar, if present.
	defer func() {
		if err != nil {
			if implSpecOptions.ProgressBar != nil {
				_ = implSpecOptions.ProgressBar.ReportError(err)
			}
		}
	}()
	docsWithoutVector, err := slices.TransformWithErrorCheck(docs, v.document2DataWithoutVector)
	if err != nil {
		return nil, fmt.Errorf("[Store] vikingdb failed to transform documents, %w", err)
	}
	indexingFields := sets.FromSlice(implSpecOptions.IndexingFields)
	// Embed and upsert in chunks of 100 rows to respect service batch limits.
	for _, part := range slices.Chunks(docsWithoutVector, 100) {
		docsWithVector, err := v.addEmbedding(ctx, part, indexingFields)
		if err != nil {
			return nil, err
		}
		if err := v.collection.UpsertData(docsWithVector); err != nil {
			return nil, err
		}
	}
	ids = slices.Transform(docs, func(a *schema.Document) string { return a.ID })
	return
}
// Retrieve performs a similarity search on the collection's index and maps
// the hits back to schema.Documents. With viking-side embedding the query
// text is sent as-is (multi-modal search); otherwise the query is embedded
// locally and searched by vector. TopK defaults to 4 when not supplied.
func (v *vkSearchStore) Retrieve(ctx context.Context, query string, opts ...retriever.Option) (docs []*schema.Document, err error) {
	indexClient := v.index
	// Lazily resolve the index client when it was not attached at creation.
	if indexClient == nil {
		foundIndex := false
		for _, index := range v.collection.Indexes {
			if index.IndexName == vikingIndexName {
				foundIndex = true
				break
			}
		}
		if !foundIndex {
			return nil, fmt.Errorf("[Retrieve] vikingdb index not found, name=%s", vikingIndexName)
		}
		indexClient, err = v.config.Service.GetIndex(v.collection.CollectionName, vikingIndexName)
		if err != nil {
			return nil, fmt.Errorf("[Retrieve] vikingdb failed to get index, %w", err)
		}
	}
	options := retriever.GetCommonOptions(&retriever.Options{TopK: ptr.Of(4)}, opts...)
	implSpecOptions := retriever.GetImplSpecificOptions(&searchstore.RetrieverOptions{}, opts...)
	searchOpts := vikingdb.NewSearchOptions().
		SetLimit(int64(ptr.From(options.TopK))).
		SetText(query).
		SetRetry(true)
	filter, err := v.genFilter(ctx, options, implSpecOptions)
	if err != nil {
		return nil, fmt.Errorf("[Retrieve] vikingdb failed to build filter, %w", err)
	}
	if filter != nil {
		// Cross-partition recall is not supported; a filter is used instead.
		searchOpts = searchOpts.SetFilter(filter)
	}
	var data []*vikingdb.Data
	if v.config.EmbeddingConfig.UseVikingEmbedding {
		data, err = indexClient.SearchWithMultiModal(searchOpts)
	} else {
		// Embed the query locally, then search with the single dense vector.
		var dense [][]float64
		dense, err = v.config.EmbeddingConfig.BuiltinEmbedding.EmbedStrings(ctx, []string{query})
		if err != nil {
			return nil, fmt.Errorf("[Retrieve] embed failed, %w", err)
		}
		if len(dense) != 1 {
			return nil, fmt.Errorf("[Retrieve] unexpected dense vector size, expected=1, got=%d", len(dense))
		}
		data, err = indexClient.SearchByVector(dense[0], searchOpts)
	}
	if err != nil {
		return nil, fmt.Errorf("[Retrieve] vikingdb search failed, %w", err)
	}
	docs, err = v.parseSearchResult(data)
	if err != nil {
		return nil, err
	}
	return
}
// Delete removes documents by primary-key id, issuing batched delete calls
// of at most 100 ids each.
func (v *vkSearchStore) Delete(ctx context.Context, ids []string) error {
	for _, batch := range slices.Chunks(ids, 100) {
		err := v.collection.DeleteData(batch)
		if err != nil {
			return err
		}
	}
	return nil
}
// document2DataWithoutVector converts a schema.Document into a vikingdb.Data
// row holding only scalar fields (id, creator id, text content plus any
// external-storage fields). Dense vectors are attached later by addEmbedding.
// The document id must be a decimal int64.
func (v *vkSearchStore) document2DataWithoutVector(doc *schema.Document) (data vikingdb.Data, err error) {
	cid, err := document.GetDocumentCreatorID(doc)
	if err != nil {
		return data, err
	}
	id64, err := strconv.ParseInt(doc.ID, 10, 64)
	if err != nil {
		return data, err
	}
	row := map[string]interface{}{
		searchstore.FieldID:          id64,
		searchstore.FieldCreatorID:   cid,
		searchstore.FieldTextContent: doc.Content,
	}
	// External storage is optional on a document; best-effort merge when present.
	if ext, extErr := document.GetDocumentExternalStorage(doc); extErr == nil {
		for k, item := range ext {
			row[k] = item
		}
	}
	data = vikingdb.Data{
		Id:     doc.ID,
		Fields: row,
	}
	return data, nil
}
// addEmbedding computes dense vectors for every indexing field with the
// builtin embedder and stores each vector under the field's dense-vector
// name. When viking-side embedding is enabled, rows pass through untouched
// because the service vectorizes the text itself. Rows are mutated in place.
func (v *vkSearchStore) addEmbedding(ctx context.Context, rows []vikingdb.Data, indexingFields map[string]struct{}) ([]vikingdb.Data, error) {
	if v.config.EmbeddingConfig.UseVikingEmbedding {
		return rows, nil
	}
	embedder := v.config.EmbeddingConfig.BuiltinEmbedding
	for field := range indexingFields {
		// Collect the field's text from every row; all rows must carry it.
		texts := make([]string, 0, len(rows))
		for _, row := range rows {
			raw, found := row.Fields[field]
			if !found {
				return nil, fmt.Errorf("[addEmbedding] indexing field not found in document, field=%s", field)
			}
			text, isStr := raw.(string)
			if !isStr {
				return nil, fmt.Errorf("[addEmbedding] val not string, field=%s, val=%v", field, raw)
			}
			texts = append(texts, text)
		}
		vectors, err := embedder.EmbedStrings(ctx, texts)
		if err != nil {
			return nil, fmt.Errorf("[addEmbedding] failed to embed, %w", err)
		}
		if len(vectors) != len(texts) {
			return nil, fmt.Errorf("[addEmbedding] unexpected dense vector size, expected=%d, got=%d", len(texts), len(vectors))
		}
		denseField := denseFieldName(field)
		for i, vec := range vectors {
			rows[i].Fields[denseField] = vec
		}
	}
	return rows, nil
}
// parseSearchResult converts vikingdb search hits back into schema.Documents.
// Well-known fields (id, creator_id, content) are mapped onto the document
// itself; every other field is collected into the document's external-storage
// map, with json.Number values narrowed to int64/float64 when possible.
func (v *vkSearchStore) parseSearchResult(result []*vikingdb.Data) ([]*schema.Document, error) {
	docs := make([]*schema.Document, 0, len(result))
	for _, data := range result {
		ext := make(map[string]any)
		doc := document.WithDocumentExternalStorage(&schema.Document{MetaData: map[string]any{}}, ext).
			WithScore(data.Score)
		for field, val := range data.Fields {
			switch field {
			case searchstore.FieldID:
				// Numeric fields come back from the SDK as json.Number.
				jn, ok := val.(json.Number)
				if !ok {
					return nil, fmt.Errorf("[parseSearchResult] id type assertion failed, val=%v", val)
				}
				doc.ID = jn.String()
			case searchstore.FieldCreatorID:
				jn, ok := val.(json.Number)
				if !ok {
					return nil, fmt.Errorf("[parseSearchResult] creator_id type assertion failed, val=%v", val)
				}
				creatorID, err := jn.Int64()
				if err != nil {
					return nil, fmt.Errorf("[parseSearchResult] creator_id value not int64, val=%v", jn.String())
				}
				doc = document.WithDocumentCreatorID(doc, creatorID)
			case searchstore.FieldTextContent:
				text, ok := val.(string)
				if !ok {
					return nil, fmt.Errorf("[parseSearchResult] content value not string, val=%v", val)
				}
				doc.Content = text
			default:
				// Unknown fields go to external storage; try int64, then
				// float64, falling back to the raw string representation.
				switch t := val.(type) {
				case json.Number:
					if i64, err := t.Int64(); err == nil {
						ext[field] = i64
					} else if f64, err := t.Float64(); err == nil {
						ext[field] = f64
					} else {
						ext[field] = t.String()
					}
				default:
					ext[field] = val
				}
			}
		}
		docs = append(docs, doc)
	}
	return docs, nil
}
// genFilter builds the vikingdb retrieval filter by combining the caller's
// DSL filter with an optional partition constraint. vikingdb does not
// support cross-partition recall, so partitions are expressed as a "must"
// condition on the partition key and ANDed with the DSL filter.
// Returns nil when neither a DSL filter nor partitions are supplied.
func (v *vkSearchStore) genFilter(ctx context.Context, co *retriever.Options, ro *searchstore.RetrieverOptions) (map[string]any, error) {
	filter, err := v.dsl2Filter(ctx, co.DSLInfo)
	if err != nil {
		return nil, err
	}
	if ro.PartitionKey != nil && len(ro.Partitions) > 0 {
		var (
			key       = ptr.From(ro.PartitionKey)
			fieldType = ""
			conds     any
		)
		// Resolve the partition key's declared type from the collection schema.
		for _, field := range v.collection.Fields {
			if field.FieldName == key {
				fieldType = field.FieldType
				break
			}
		}
		if fieldType == "" {
			return nil, fmt.Errorf("[Retrieve] partition key not found, key=%s", key)
		}
		// Partition values arrive as strings; coerce them to the field's type.
		switch fieldType {
		case vikingdb.Int64:
			c := make([]int64, 0, len(ro.Partitions))
			for _, item := range ro.Partitions {
				i64, err := strconv.ParseInt(item, 10, 64)
				if err != nil {
					return nil, fmt.Errorf("[Retrieve] partition value parse error, key=%s, val=%v, err=%v", key, item, err)
				}
				c = append(c, i64)
			}
			conds = c
		case vikingdb.String:
			conds = ro.Partitions
		default:
			return nil, fmt.Errorf("[Retrieve] invalid field type for partition, key=%s, type=%s", key, fieldType)
		}
		op := map[string]any{"op": "must", "field": key, "conds": conds}
		if filter == nil {
			// No DSL filter: the partition constraint is the whole filter.
			// (The original condition was inverted, which discarded an
			// existing DSL filter and otherwise ANDed with a nil condition.)
			filter = op
		} else {
			filter = map[string]any{
				"op":    "and",
				"conds": []map[string]any{op, filter},
			}
		}
	}
	return filter, nil
}
// dsl2Filter translates a searchstore DSL tree into the vikingdb filter map
// format. Eq/In map to "must", Ne to "must_not", And/Or recurse over their
// sub-filters. Unsupported operators (OpLike) are skipped with a warning and
// yield a nil filter; nil sub-filters are dropped from And/Or conditions.
func (v *vkSearchStore) dsl2Filter(ctx context.Context, src map[string]any) (map[string]any, error) {
	dsl, err := searchstore.LoadDSL(src)
	if err != nil {
		return nil, err
	}
	if dsl == nil {
		return nil, nil
	}
	// vikingdb expects "conds" to be a list; wrap scalar values.
	toSliceValue := func(val any) any {
		if val == nil {
			// reflect.TypeOf(nil) is nil and Kind() on it would panic.
			return []any{val}
		}
		if reflect.TypeOf(val).Kind() == reflect.Slice {
			return val
		}
		return []any{val}
	}
	var filter map[string]any
	switch dsl.Op {
	case searchstore.OpEq, searchstore.OpIn:
		filter = map[string]any{
			"op":    "must",
			"field": dsl.Field,
			"conds": toSliceValue(dsl.Value),
		}
	case searchstore.OpNe:
		filter = map[string]any{
			"op":    "must_not",
			"field": dsl.Field,
			"conds": toSliceValue(dsl.Value),
		}
	case searchstore.OpLike:
		// Substring match is not expressible in vikingdb filters; skip it.
		logs.CtxWarnf(ctx, "[dsl2Filter] vikingdb invalid dsl type, skip, type=%s", dsl.Op)
	case searchstore.OpAnd, searchstore.OpOr:
		var conds []map[string]any
		sub, ok := dsl.Value.([]map[string]any)
		if !ok {
			return nil, fmt.Errorf("[dsl2Filter] invalid value for and/or, should be []map[string]any")
		}
		for _, subDSL := range sub {
			cond, err := v.dsl2Filter(ctx, subDSL)
			if err != nil {
				return nil, err
			}
			if cond == nil {
				// A skipped sub-filter (e.g. OpLike) must not appear as a
				// nil condition in the combined filter.
				continue
			}
			conds = append(conds, cond)
		}
		op := "and"
		if dsl.Op == searchstore.OpOr {
			op = "or"
		}
		filter = map[string]any{
			"op":    op,
			"field": dsl.Field,
			"conds": conds,
		}
	}
	return filter, nil
}

View File

@@ -0,0 +1,290 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vikingdb
import (
"context"
"fmt"
"os"
"testing"
"github.com/cloudwego/eino-ext/components/embedding/openai"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/stretchr/testify/assert"
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/impl/embedding/wrap"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
// TestVikingEmbeddingIntegration exercises the full create/store/retrieve/
// drop lifecycle against a live vikingdb service using viking-side
// (Doubao) embedding. It is gated behind ENABLE_VIKINGDB_INTEGRATION_TEST
// and requires VIKING_DB_AK / VIKING_DB_SK credentials in the environment.
// NOTE(review): the subtests are order-dependent (create -> store ->
// retrieve -> drop) and share collectionName.
func TestVikingEmbeddingIntegration(t *testing.T) {
	if os.Getenv("ENABLE_VIKINGDB_INTEGRATION_TEST") != "true" {
		return
	}
	ctx := context.Background()
	svc := vikingdb.NewVikingDBService(
		"api-vikingdb.volces.com",
		"cn-beijing",
		os.Getenv("VIKING_DB_AK"),
		os.Getenv("VIKING_DB_SK"),
		"https",
	)
	cfg := &ManagerConfig{
		Service:        svc,
		IndexingConfig: nil,
		EmbeddingConfig: &VikingEmbeddingConfig{
			UseVikingEmbedding: true,
			EnableHybrid:       false,
			ModelName:          ModelNameDoubaoEmbedding,
			ModelVersion:       ModelNameDoubaoEmbedding.ModelVersion(),
			DenseWeight:        nil,
			BuiltinEmbedding:   nil,
		},
	}
	mgr, err := NewManager(cfg)
	assert.NoError(t, err)
	collectionName := "test_coze_coll_1"
	// Create a collection with id / creator_id / document_id scalars and an
	// indexed text-content field.
	t.Run("create", func(t *testing.T) {
		err = mgr.Create(ctx, &searchstore.CreateRequest{
			CollectionName: collectionName,
			Fields: []*searchstore.Field{
				{
					Name:      searchstore.FieldID,
					Type:      searchstore.FieldTypeInt64,
					IsPrimary: true,
				},
				{
					Name: searchstore.FieldCreatorID,
					Type: searchstore.FieldTypeInt64,
				},
				{
					Name: "document_id",
					Type: searchstore.FieldTypeInt64,
				},
				{
					Name:     searchstore.FieldTextContent,
					Type:     searchstore.FieldTypeText,
					Indexing: true,
				},
			},
			CollectionMeta: nil,
		})
		assert.NoError(t, err)
	})
	// Store three documents; two share document_id 567 and one uses 568, so
	// the retrieve step can verify partition filtering.
	t.Run("store", func(t *testing.T) {
		ss, err := mgr.GetSearchStore(ctx, collectionName)
		assert.NoError(t, err)
		ids, err := ss.Store(ctx, []*schema.Document{
			{
				ID:      "101",
				Content: "埃菲尔铁塔:位于法国巴黎,是世界上最著名的地标之一,由居斯塔夫・埃菲尔设计并建于 1889 年。",
				MetaData: map[string]any{
					document.MetaDataKeyCreatorID: int64(111),
					document.MetaDataKeyExternalStorage: map[string]any{
						"document_id": int64(567),
					},
				},
			},
			{
				ID:      "102",
				Content: "长城:位于中国,是世界七大奇迹之一,从秦至明代修筑而成,全长超过 2 万公里",
				MetaData: map[string]any{
					document.MetaDataKeyCreatorID: int64(111),
					document.MetaDataKeyExternalStorage: map[string]any{
						"document_id": int64(567),
					},
				},
			},
			{
				ID:      "103",
				Content: "罗马斗兽场:位于意大利罗马,于公元 70-80 年间建成,是古罗马帝国最大的圆形竞技场。",
				MetaData: map[string]any{
					document.MetaDataKeyCreatorID: int64(111),
					document.MetaDataKeyExternalStorage: map[string]any{
						"document_id": int64(568),
					},
				},
			},
		}, searchstore.WithIndexingFields([]string{searchstore.FieldTextContent}))
		assert.NoError(t, err)
		fmt.Println(ids)
	})
	// Retrieve with both a DSL filter (creator_id) and a partition
	// constraint (document_id=567).
	t.Run("retrieve", func(t *testing.T) {
		ss, err := mgr.GetSearchStore(ctx, collectionName)
		assert.NoError(t, err)
		dsl := &searchstore.DSL{
			Op:    searchstore.OpIn,
			Field: "creator_id",
			Value: int64(111),
		}
		opts := []retriever.Option{
			searchstore.WithRetrieverPartitionKey("document_id"),
			searchstore.WithPartitions([]string{"567"}),
			retriever.WithDSLInfo(dsl.DSL()),
		}
		resp, err := ss.Retrieve(ctx, "旅游景点推荐", opts...)
		assert.NoError(t, err)
		fmt.Println(resp)
	})
	// Clean up the collection created above.
	t.Run("drop", func(t *testing.T) {
		assert.NoError(t, mgr.Drop(ctx, &searchstore.DropRequest{CollectionName: collectionName}))
	})
}
// TestBuiltinEmbeddingIntegration mirrors TestVikingEmbeddingIntegration
// but vectorizes locally with an OpenAI-compatible embedder (1024 dims)
// instead of viking-side embedding. Gated behind
// ENABLE_VIKINGDB_INTEGRATION_TEST; additionally requires the
// OPENAI_EMBEDDING_* environment variables.
// NOTE(review): subtests are order-dependent and share collectionName.
func TestBuiltinEmbeddingIntegration(t *testing.T) {
	if os.Getenv("ENABLE_VIKINGDB_INTEGRATION_TEST") != "true" {
		return
	}
	ctx := context.Background()
	svc := vikingdb.NewVikingDBService(
		"api-vikingdb.volces.com",
		"cn-beijing",
		os.Getenv("VIKING_DB_AK"),
		os.Getenv("VIKING_DB_SK"),
		"https",
	)
	embConfig := &openai.EmbeddingConfig{
		APIKey:     os.Getenv("OPENAI_EMBEDDING_API_KEY"),
		ByAzure:    true,
		BaseURL:    os.Getenv("OPENAI_EMBEDDING_BASE_URL"),
		Model:      os.Getenv("OPENAI_EMBEDDING_MODEL"),
		Dimensions: ptr.Of(1024),
	}
	emb, err := wrap.NewOpenAIEmbedder(ctx, embConfig, 1024)
	assert.NoError(t, err)
	cfg := &ManagerConfig{
		Service:        svc,
		IndexingConfig: nil,
		EmbeddingConfig: &VikingEmbeddingConfig{
			UseVikingEmbedding: false,
			BuiltinEmbedding:   emb,
		},
	}
	mgr, err := NewManager(cfg)
	assert.NoError(t, err)
	collectionName := "test_coze_coll_2"
	// Create a collection with the same schema as the viking-embedding test.
	t.Run("create", func(t *testing.T) {
		err = mgr.Create(ctx, &searchstore.CreateRequest{
			CollectionName: collectionName,
			Fields: []*searchstore.Field{
				{
					Name:      searchstore.FieldID,
					Type:      searchstore.FieldTypeInt64,
					IsPrimary: true,
				},
				{
					Name: searchstore.FieldCreatorID,
					Type: searchstore.FieldTypeInt64,
				},
				{
					Name: "document_id",
					Type: searchstore.FieldTypeInt64,
				},
				{
					Name:     searchstore.FieldTextContent,
					Type:     searchstore.FieldTypeText,
					Indexing: true,
				},
			},
			CollectionMeta: nil,
		})
		assert.NoError(t, err)
	})
	// Store three documents; dense vectors are computed by the local embedder.
	t.Run("store", func(t *testing.T) {
		ss, err := mgr.GetSearchStore(ctx, collectionName)
		assert.NoError(t, err)
		ids, err := ss.Store(ctx, []*schema.Document{
			{
				ID:      "101",
				Content: "埃菲尔铁塔:位于法国巴黎,是世界上最著名的地标之一,由居斯塔夫・埃菲尔设计并建于 1889 年。",
				MetaData: map[string]any{
					document.MetaDataKeyCreatorID: int64(111),
					document.MetaDataKeyExternalStorage: map[string]any{
						"document_id": int64(567),
					},
				},
			},
			{
				ID:      "102",
				Content: "长城:位于中国,是世界七大奇迹之一,从秦至明代修筑而成,全长超过 2 万公里",
				MetaData: map[string]any{
					document.MetaDataKeyCreatorID: int64(111),
					document.MetaDataKeyExternalStorage: map[string]any{
						"document_id": int64(567),
					},
				},
			},
			{
				ID:      "103",
				Content: "罗马斗兽场:位于意大利罗马,于公元 70-80 年间建成,是古罗马帝国最大的圆形竞技场。",
				MetaData: map[string]any{
					document.MetaDataKeyCreatorID: int64(111),
					document.MetaDataKeyExternalStorage: map[string]any{
						"document_id": int64(568),
					},
				},
			},
		}, searchstore.WithIndexingFields([]string{searchstore.FieldTextContent}))
		assert.NoError(t, err)
		fmt.Println(ids)
	})
	// Retrieve with a DSL filter and a partition constraint, as in the
	// viking-embedding test; here the query is embedded locally first.
	t.Run("retrieve", func(t *testing.T) {
		ss, err := mgr.GetSearchStore(ctx, collectionName)
		assert.NoError(t, err)
		dsl := &searchstore.DSL{
			Op:    searchstore.OpIn,
			Field: "creator_id",
			Value: int64(111),
		}
		opts := []retriever.Option{
			searchstore.WithRetrieverPartitionKey("document_id"),
			searchstore.WithPartitions([]string{"567"}),
			retriever.WithDSLInfo(dsl.DSL()),
		}
		resp, err := ss.Retrieve(ctx, "旅游景点推荐", opts...)
		assert.NoError(t, err)
		fmt.Println(resp)
	})
	// Clean up the collection created above.
	t.Run("drop", func(t *testing.T) {
		assert.NoError(t, mgr.Drop(ctx, &searchstore.DropRequest{CollectionName: collectionName}))
	})
}