diff --git a/backend/application/knowledge/init.go b/backend/application/knowledge/init.go index 0a00be16..6e1d0659 100644 --- a/backend/application/knowledge/init.go +++ b/backend/application/knowledge/init.go @@ -280,6 +280,14 @@ func getVectorStore(ctx context.Context) (searchstore.Manager, error) { } func getEmbedding(ctx context.Context) (embedding.Embedder, error) { + var batchSize int + if bs, err := strconv.ParseInt(os.Getenv("EMBEDDING_MAX_BATCH_SIZE"), 10, 64); err != nil || bs <= 0 { + logs.CtxWarnf(ctx, "EMBEDDING_MAX_BATCH_SIZE not set / invalid, using default batchSize=100") + batchSize = 100 + } else { + batchSize = int(bs) + } + var emb embedding.Embedder switch os.Getenv("EMBEDDING_TYPE") { @@ -318,7 +326,7 @@ func getEmbedding(ctx context.Context) (embedding.Embedder, error) { openAICfg.Dimensions = ptr.Of(int(reqDims)) } - emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, dims) + emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, dims, batchSize) if err != nil { return nil, fmt.Errorf("init openai embedding failed, err=%w", err) } @@ -340,7 +348,7 @@ func getEmbedding(ctx context.Context) (embedding.Embedder, error) { APIKey: arkEmbeddingAK, Model: arkEmbeddingModel, BaseURL: arkEmbeddingBaseURL, - }, dims) + }, dims, batchSize) if err != nil { return nil, fmt.Errorf("init ark embedding client failed, err=%w", err) } @@ -360,7 +368,7 @@ func getEmbedding(ctx context.Context) (embedding.Embedder, error) { emb, err = wrap.NewOllamaEmbedder(ctx, &ollamaEmb.EmbeddingConfig{ BaseURL: ollamaEmbeddingBaseURL, Model: ollamaEmbeddingModel, - }, dims) + }, dims, batchSize) if err != nil { return nil, fmt.Errorf("init ollama embedding failed, err=%w", err) } @@ -374,7 +382,7 @@ func getEmbedding(ctx context.Context) (embedding.Embedder, error) { if err != nil { return nil, fmt.Errorf("init http embedding dims failed, err=%w", err) } - emb, err = http.NewEmbedding(httpEmbeddingBaseURL, dims) + emb, err = http.NewEmbedding(httpEmbeddingBaseURL, dims, batchSize) if err != 
nil { return nil, fmt.Errorf("init http embedding failed, err=%w", err) } diff --git a/backend/infra/impl/document/searchstore/vikingdb/vk_test.go b/backend/infra/impl/document/searchstore/vikingdb/vk_test.go index 4826dafa..d113746e 100644 --- a/backend/infra/impl/document/searchstore/vikingdb/vk_test.go +++ b/backend/infra/impl/document/searchstore/vikingdb/vk_test.go @@ -179,7 +179,7 @@ func TestBuiltinEmbeddingIntegration(t *testing.T) { Model: os.Getenv("OPENAI_EMBEDDING_MODEL"), Dimensions: ptr.Of(1024), } - emb, err := wrap.NewOpenAIEmbedder(ctx, embConfig, 1024) + emb, err := wrap.NewOpenAIEmbedder(ctx, embConfig, 1024, 100) assert.NoError(t, err) cfg := &ManagerConfig{ diff --git a/backend/infra/impl/embedding/ark/ark.go b/backend/infra/impl/embedding/ark/ark.go index a3c0b0ed..02976998 100644 --- a/backend/infra/impl/embedding/ark/ark.go +++ b/backend/infra/impl/embedding/ark/ark.go @@ -33,23 +33,24 @@ import ( "github.com/coze-dev/coze-studio/backend/types/errno" ) -func NewArkEmbedder(ctx context.Context, config *ark.EmbeddingConfig, dimensions int64) (contract.Embedder, error) { +func NewArkEmbedder(ctx context.Context, config *ark.EmbeddingConfig, dimensions int64, batchSize int) (contract.Embedder, error) { emb, err := ark.NewEmbedder(ctx, config) if err != nil { return nil, err } - return &embWrap{dims: dimensions, Embedder: emb}, nil + return &embWrap{dims: dimensions, batchSize: batchSize, Embedder: emb}, nil } type embWrap struct { - dims int64 + dims int64 + batchSize int embedding.Embedder } func (d embWrap) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) { resp := make([][]float64, 0, len(texts)) - for _, part := range slices.Chunks(texts, 100) { + for _, part := range slices.Chunks(texts, d.batchSize) { partResult, err := d.Embedder.EmbedStrings(ctx, part, opts...) 
if err != nil { return nil, err diff --git a/backend/infra/impl/embedding/http/http.go b/backend/infra/impl/embedding/http/http.go index 2f101699..da002375 100644 --- a/backend/infra/impl/embedding/http/http.go +++ b/backend/infra/impl/embedding/http/http.go @@ -25,6 +25,7 @@ import ( "time" opt "github.com/cloudwego/eino/components/embedding" + "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" "github.com/coze-dev/coze-studio/backend/infra/contract/embedding" ) @@ -41,88 +42,90 @@ type embedResp struct { Sparse []map[int]float64 `json:"sparse"` } -func NewEmbedding(addr string, dims int64) (embedding.Embedder, error) { +func NewEmbedding(addr string, dims int64, batchSize int) (embedding.Embedder, error) { cli := &http.Client{Timeout: time.Second * 30} return &embedder{ - cli: cli, - addr: addr, - dim: dims, + cli: cli, + addr: addr, + dim: dims, + batchSize: batchSize, }, nil } type embedder struct { - cli *http.Client - addr string - dim int64 + cli *http.Client + addr string + dim int64 + batchSize int } func (e *embedder) EmbedStrings(ctx context.Context, texts []string, opts ...opt.Option) ([][]float64, error) { - rb, err := json.Marshal(&embedReq{ - Texts: texts, - NeedSparse: false, - }) - if err != nil { - return nil, err + dense := make([][]float64, 0, len(texts)) + for _, part := range slices.Chunks(texts, e.batchSize) { + rb, err := json.Marshal(&embedReq{ + Texts: part, + NeedSparse: false, + }) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, e.addr+pathEmbed, bytes.NewReader(rb)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json; charset=utf-8") + resp, err := e.do(req) + if err != nil { + return nil, err + } + dense = append(dense, resp.Dense...) 
} - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, e.addr+pathEmbed, bytes.NewReader(rb)) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json; charset=utf-8") - - resp, err := e.cli.Do(req) - if err != nil { - return nil, err - } - - defer resp.Body.Close() - b, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - r := &embedResp{} - if err = json.Unmarshal(b, r); err != nil { - return nil, err - } - - return r.Dense, nil + return dense, nil } func (e *embedder) EmbedStringsHybrid(ctx context.Context, texts []string, opts ...opt.Option) ([][]float64, []map[int]float64, error) { - rb, err := json.Marshal(&embedReq{ - Texts: texts, - NeedSparse: true, - }) - if err != nil { - return nil, nil, err + dense := make([][]float64, 0, len(texts)) + sparse := make([]map[int]float64, 0, len(texts)) + for _, part := range slices.Chunks(texts, e.batchSize) { + rb, err := json.Marshal(&embedReq{ + Texts: part, + NeedSparse: true, + }) + if err != nil { + return nil, nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, e.addr+pathEmbed, bytes.NewReader(rb)) + if err != nil { + return nil, nil, err + } + req.Header.Set("Content-Type", "application/json; charset=utf-8") + resp, err := e.do(req) + if err != nil { + return nil, nil, err + } + dense = append(dense, resp.Dense...) + sparse = append(sparse, resp.Sparse...) 
} + return dense, sparse, nil +} - req, err := http.NewRequestWithContext(ctx, http.MethodPost, e.addr+pathEmbed, bytes.NewReader(rb)) - if err != nil { - return nil, nil, err - } - req.Header.Set("Content-Type", "application/json; charset=utf-8") - +func (e *embedder) do(req *http.Request) (*embedResp, error) { resp, err := e.cli.Do(req) if err != nil { - return nil, nil, err + return nil, err } defer resp.Body.Close() b, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, err + return nil, err } r := &embedResp{} if err = json.Unmarshal(b, r); err != nil { - return nil, nil, err + return nil, err } - - return r.Dense, r.Sparse, nil - + return r, nil } func (e *embedder) Dimensions() int64 { diff --git a/backend/infra/impl/embedding/http/http_test.go b/backend/infra/impl/embedding/http/http_test.go index 5c88b63b..3a77b910 100644 --- a/backend/infra/impl/embedding/http/http_test.go +++ b/backend/infra/impl/embedding/http/http_test.go @@ -31,7 +31,7 @@ func TestHTTPEmbedding(t *testing.T) { } ctx := context.Background() - emb, err := NewEmbedding("http://127.0.0.1:6543", 1024) + emb, err := NewEmbedding("http://127.0.0.1:6543", 1024, 10) assert.NoError(t, err) texts := []string{ "hello", diff --git a/backend/infra/impl/embedding/wrap/dense_only.go b/backend/infra/impl/embedding/wrap/dense_only.go index b706450e..6f248d3b 100644 --- a/backend/infra/impl/embedding/wrap/dense_only.go +++ b/backend/infra/impl/embedding/wrap/dense_only.go @@ -27,13 +27,14 @@ import ( ) type denseOnlyWrap struct { - dims int64 + dims int64 + batchSize int embedding.Embedder } func (d denseOnlyWrap) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) { resp := make([][]float64, 0, len(texts)) - for _, part := range slices.Chunks(texts, 100) { + for _, part := range slices.Chunks(texts, d.batchSize) { partResult, err := d.Embedder.EmbedStrings(ctx, part, opts...) 
if err != nil { return nil, err diff --git a/backend/infra/impl/embedding/wrap/ollama.go b/backend/infra/impl/embedding/wrap/ollama.go index 516ac075..c3cb5484 100644 --- a/backend/infra/impl/embedding/wrap/ollama.go +++ b/backend/infra/impl/embedding/wrap/ollama.go @@ -24,10 +24,10 @@ import ( contract "github.com/coze-dev/coze-studio/backend/infra/contract/embedding" ) -func NewOllamaEmbedder(ctx context.Context, config *ollama.EmbeddingConfig, dimensions int64) (contract.Embedder, error) { +func NewOllamaEmbedder(ctx context.Context, config *ollama.EmbeddingConfig, dimensions int64, batchSize int) (contract.Embedder, error) { emb, err := ollama.NewEmbedder(ctx, config) if err != nil { return nil, err } - return &denseOnlyWrap{dims: dimensions, Embedder: emb}, nil + return &denseOnlyWrap{dims: dimensions, batchSize: batchSize, Embedder: emb}, nil } diff --git a/backend/infra/impl/embedding/wrap/openai.go b/backend/infra/impl/embedding/wrap/openai.go index 7b02734a..d7607560 100644 --- a/backend/infra/impl/embedding/wrap/openai.go +++ b/backend/infra/impl/embedding/wrap/openai.go @@ -24,10 +24,10 @@ import ( contract "github.com/coze-dev/coze-studio/backend/infra/contract/embedding" ) -func NewOpenAIEmbedder(ctx context.Context, config *openai.EmbeddingConfig, dimensions int64) (contract.Embedder, error) { +func NewOpenAIEmbedder(ctx context.Context, config *openai.EmbeddingConfig, dimensions int64, batchSize int) (contract.Embedder, error) { emb, err := openai.NewEmbedder(ctx, config) if err != nil { return nil, err } - return &denseOnlyWrap{dims: dimensions, Embedder: emb}, nil + return &denseOnlyWrap{dims: dimensions, batchSize: batchSize, Embedder: emb}, nil } diff --git a/docker/.env.example b/docker/.env.example index ef0a8df0..f407baf6 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -97,6 +97,7 @@ export VIKING_DB_MODEL_NAME="" # if vikingdb model name is not set, you need to # Coze Studio supports three access methods: openai, ark, ollama, 
and custom http. Users can simply choose one of them when using # embedding type: openai / ark / ollama / http export EMBEDDING_TYPE="ark" +export EMBEDDING_MAX_BATCH_SIZE=100 # openai embedding export OPENAI_EMBEDDING_BASE_URL="" # (string) OpenAI base_url export OPENAI_EMBEDDING_MODEL="" # (string) OpenAI embedding model