diff --git a/backend/infra/impl/embedding/http/http.go b/backend/infra/impl/embedding/http/http.go index 8a3640a4..36d7c70a 100644 --- a/backend/infra/impl/embedding/http/http.go +++ b/backend/infra/impl/embedding/http/http.go @@ -20,18 +20,25 @@ import ( "bytes" "context" "encoding/json" + "fmt" "io" "net/http" + "net/url" + "strconv" "time" opt "github.com/cloudwego/eino/components/embedding" + "github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" "github.com/coze-dev/coze-studio/backend/infra/contract/embedding" ) -const pathEmbed = "/embedding" +const ( + pathEmbed = "/embedding" + pathSupportStatus = "/support_status" +) type embedReq struct { Texts []string `json:"texts"` @@ -45,11 +52,56 @@ type embedResp struct { func NewEmbedding(addr string, dims int64, batchSize int) (embedding.Embedder, error) { cli := &http.Client{Timeout: time.Second * 30} + status := embedding.SupportDenseAndSparse + + getStatusErr := func() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + path, err := url.JoinPath(addr, pathSupportStatus) + if err != nil { + return fmt.Errorf("url join path failed, %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, path, nil) + if err != nil { + return fmt.Errorf("new request failed, %w", err) + } + resp, err := cli.Do(req) + if err != nil { + return fmt.Errorf("/support_status failed, %w", err) + } + if resp == nil { + return fmt.Errorf("/support_status nil response") + } else if resp.StatusCode != http.StatusOK { + return fmt.Errorf("/support_status bad status code: %d", resp.StatusCode) + } else if resp.Body == nil { + return fmt.Errorf("/support_status nil response body") + } + defer resp.Body.Close() + b, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("/support_status read response body failed, %w", err) + } + i, err := strconv.ParseInt(string(b), 10, 64) + if err != nil { + return fmt.Errorf("/support_status parse response body failed, %w", err) + } + s := embedding.SupportStatus(i) + if s != embedding.SupportDense && s != embedding.SupportDenseAndSparse { + return fmt.Errorf("invalid support status=%d", s) + } + status = s + return nil + }() + if getStatusErr != nil { + logs.Errorf("[NewEmbedding] http embedding get support status failed, using default SupportDenseAndSparse=3, %v", getStatusErr) + } + return &embedder{ cli: cli, addr: addr, dim: dims, batchSize: batchSize, + status: status, }, nil } @@ -58,6 +110,7 @@ type embedder struct { addr string dim int64 batchSize int + status embedding.SupportStatus } func (e *embedder) EmbedStrings(ctx context.Context, texts []string, opts ...opt.Option) ([][]float64, error) { @@ -70,7 +123,11 @@ func (e *embedder) EmbedStrings(ctx context.Context, texts []string, opts ...opt if err != nil { return nil, err } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, e.addr+pathEmbed, bytes.NewReader(rb)) + path, err := url.JoinPath(e.addr, pathEmbed) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, path, bytes.NewReader(rb)) if err != nil { return nil, err } @@ -85,6 +142,9 @@ func (e *embedder) EmbedStrings(ctx context.Context, texts []string, opts ...opt } func (e *embedder) EmbedStringsHybrid(ctx context.Context, texts []string, opts ...opt.Option) ([][]float64, []map[int]float64, error) { + if e.status == embedding.SupportDense { + return nil, nil, fmt.Errorf("support status=%d not support EmbedStringsHybrid", e.status) + } dense := make([][]float64, 0, len(texts)) sparse := make([]map[int]float64, 0, len(texts)) for _, part := range slices.Chunks(texts, e.batchSize) { @@ -95,7 +155,11 @@ func (e *embedder) EmbedStringsHybrid(ctx context.Context, texts []string, opts if err != nil { return nil, nil, err } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, e.addr+pathEmbed, bytes.NewReader(rb)) + path, err := url.JoinPath(e.addr, pathEmbed) + if err != nil { + return nil, nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, path, bytes.NewReader(rb)) if err != nil { return nil, nil, err } @@ -134,5 +198,5 @@ func (e *embedder) Dimensions() int64 { } func (e *embedder) SupportStatus() embedding.SupportStatus { - return embedding.SupportDenseAndSparse + return e.status }