feat: http embedding using /support_status to return support vector types (#564)

This commit is contained in:
N3ko 2025-08-07 15:10:20 +08:00 committed by GitHub
parent e2b1f6e381
commit 8cc2a7768c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 68 additions and 4 deletions

View File

@ -20,18 +20,25 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"net/http" "net/http"
"net/url"
"strconv"
"time" "time"
opt "github.com/cloudwego/eino/components/embedding" opt "github.com/cloudwego/eino/components/embedding"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices" "github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding" "github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
) )
// Endpoint paths, joined onto the embedding service's base address.
const (
	// pathEmbed is the endpoint that embedding requests are POSTed to.
	pathEmbed = "/embedding"
	// pathSupportStatus is the endpoint queried with GET by NewEmbedding to
	// learn which vector types the service supports.
	pathSupportStatus = "/support_status"
)
type embedReq struct { type embedReq struct {
Texts []string `json:"texts"` Texts []string `json:"texts"`
@ -45,11 +52,56 @@ type embedResp struct {
func NewEmbedding(addr string, dims int64, batchSize int) (embedding.Embedder, error) { func NewEmbedding(addr string, dims int64, batchSize int) (embedding.Embedder, error) {
cli := &http.Client{Timeout: time.Second * 30} cli := &http.Client{Timeout: time.Second * 30}
status := embedding.SupportDenseAndSparse
getStatusErr := func() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
path, err := url.JoinPath(addr, pathSupportStatus)
if err != nil {
return fmt.Errorf("url join path failed, %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, path, nil)
if err != nil {
return fmt.Errorf("new request failed, %w", err)
}
resp, err := cli.Do(req)
if err != nil {
return fmt.Errorf("/support_status failed, %w", err)
}
if resp == nil {
return fmt.Errorf("/support_status nil response")
} else if resp.StatusCode != http.StatusOK {
return fmt.Errorf("/support_status bad status code: %d", resp.StatusCode)
} else if resp.Body == nil {
return fmt.Errorf("/support_status nil response body")
}
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("/support_status read response body failed, %w", err)
}
i, err := strconv.ParseInt(string(b), 10, 64)
if err != nil {
return fmt.Errorf("/support_status parse response body failed, %w", err)
}
s := embedding.SupportStatus(i)
if s != embedding.SupportDense && s != embedding.SupportDenseAndSparse {
return fmt.Errorf("invalid support status=%d", s)
}
status = s
return nil
}()
if getStatusErr != nil {
logs.Errorf("[NewEmbedding] http embedding get support status failed, using default SupportDenseAndSparse=3, %v", getStatusErr)
}
return &embedder{ return &embedder{
cli: cli, cli: cli,
addr: addr, addr: addr,
dim: dims, dim: dims,
batchSize: batchSize, batchSize: batchSize,
status: status,
}, nil }, nil
} }
@ -58,6 +110,7 @@ type embedder struct {
addr string addr string
dim int64 dim int64
batchSize int batchSize int
status embedding.SupportStatus
} }
func (e *embedder) EmbedStrings(ctx context.Context, texts []string, opts ...opt.Option) ([][]float64, error) { func (e *embedder) EmbedStrings(ctx context.Context, texts []string, opts ...opt.Option) ([][]float64, error) {
@ -70,7 +123,11 @@ func (e *embedder) EmbedStrings(ctx context.Context, texts []string, opts ...opt
if err != nil { if err != nil {
return nil, err return nil, err
} }
req, err := http.NewRequestWithContext(ctx, http.MethodPost, e.addr+pathEmbed, bytes.NewReader(rb)) path, err := url.JoinPath(e.addr, pathEmbed)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, path, bytes.NewReader(rb))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -85,6 +142,9 @@ func (e *embedder) EmbedStrings(ctx context.Context, texts []string, opts ...opt
} }
func (e *embedder) EmbedStringsHybrid(ctx context.Context, texts []string, opts ...opt.Option) ([][]float64, []map[int]float64, error) { func (e *embedder) EmbedStringsHybrid(ctx context.Context, texts []string, opts ...opt.Option) ([][]float64, []map[int]float64, error) {
if e.status == embedding.SupportDense {
return nil, nil, fmt.Errorf("support status=%d not support EmbedStringsHybrid", e.status)
}
dense := make([][]float64, 0, len(texts)) dense := make([][]float64, 0, len(texts))
sparse := make([]map[int]float64, 0, len(texts)) sparse := make([]map[int]float64, 0, len(texts))
for _, part := range slices.Chunks(texts, e.batchSize) { for _, part := range slices.Chunks(texts, e.batchSize) {
@ -95,7 +155,11 @@ func (e *embedder) EmbedStringsHybrid(ctx context.Context, texts []string, opts
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
req, err := http.NewRequestWithContext(ctx, http.MethodPost, e.addr+pathEmbed, bytes.NewReader(rb)) path, err := url.JoinPath(e.addr, pathEmbed)
if err != nil {
return nil, nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, path, bytes.NewReader(rb))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -134,5 +198,5 @@ func (e *embedder) Dimensions() int64 {
} }
func (e *embedder) SupportStatus() embedding.SupportStatus { func (e *embedder) SupportStatus() embedding.SupportStatus {
return embedding.SupportDenseAndSparse return e.status
} }