139 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			139 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			Go
		
	
	
	
| /*
 | |
|  * Copyright 2025 coze-dev Authors
 | |
|  *
 | |
|  * Licensed under the Apache License, Version 2.0 (the "License");
 | |
|  * you may not use this file except in compliance with the License.
 | |
|  * You may obtain a copy of the License at
 | |
|  *
 | |
|  *     http://www.apache.org/licenses/LICENSE-2.0
 | |
|  *
 | |
|  * Unless required by applicable law or agreed to in writing, software
 | |
|  * distributed under the License is distributed on an "AS IS" BASIS,
 | |
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
|  * See the License for the specific language governing permissions and
 | |
|  * limitations under the License.
 | |
|  */
 | |
| 
 | |
| package http
 | |
| 
 | |
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"time"

	opt "github.com/cloudwego/eino/components/embedding"

	"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"

	"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
)
 | |
| 
 | |
// pathEmbed is the endpoint path, appended to the configured service
// address, that every embedding request is POSTed to.
const pathEmbed = "/embedding"

type (
	// embedReq is the JSON body sent to pathEmbed. Field order is part of the
	// wire format produced by json.Marshal and is intentionally kept stable.
	embedReq struct {
		// Texts is the batch of input strings to embed.
		Texts []string `json:"texts"`
		// NeedSparse is set to true only by EmbedStringsHybrid, the only
		// caller that reads embedResp.Sparse.
		NeedSparse bool `json:"need_sparse"`
	}

	// embedResp is the JSON body decoded from the service's reply.
	embedResp struct {
		// Dense holds dense vectors; presumably one per requested text, in
		// request order (service contract not visible here; verify).
		Dense [][]float64 `json:"dense"`
		// Sparse holds sparse vectors as integer-keyed weight maps; only
		// consumed when the request had need_sparse=true.
		Sparse []map[int]float64 `json:"sparse"`
	}
)
 | |
| 
 | |
| func NewEmbedding(addr string, dims int64, batchSize int) (embedding.Embedder, error) {
 | |
| 	cli := &http.Client{Timeout: time.Second * 30}
 | |
| 	return &embedder{
 | |
| 		cli:       cli,
 | |
| 		addr:      addr,
 | |
| 		dim:       dims,
 | |
| 		batchSize: batchSize,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
// embedder is the embedding.Embedder returned by NewEmbedding. It forwards
// texts in batches to a remote HTTP embedding service.
type embedder struct {
	cli       *http.Client // client reused for every request
	addr      string       // service base address; pathEmbed is appended to it
	dim       int64        // dimension reported by Dimensions
	batchSize int          // maximum number of texts per request
}
 | |
| 
 | |
| func (e *embedder) EmbedStrings(ctx context.Context, texts []string, opts ...opt.Option) ([][]float64, error) {
 | |
| 	dense := make([][]float64, 0, len(texts))
 | |
| 	for _, part := range slices.Chunks(texts, e.batchSize) {
 | |
| 		rb, err := json.Marshal(&embedReq{
 | |
| 			Texts:      part,
 | |
| 			NeedSparse: false,
 | |
| 		})
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		req, err := http.NewRequestWithContext(ctx, http.MethodPost, e.addr+pathEmbed, bytes.NewReader(rb))
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		req.Header.Set("Content-Type", "application/json; charset=utf-8")
 | |
| 		resp, err := e.do(req)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		dense = append(dense, resp.Dense...)
 | |
| 	}
 | |
| 	return dense, nil
 | |
| }
 | |
| 
 | |
| func (e *embedder) EmbedStringsHybrid(ctx context.Context, texts []string, opts ...opt.Option) ([][]float64, []map[int]float64, error) {
 | |
| 	dense := make([][]float64, 0, len(texts))
 | |
| 	sparse := make([]map[int]float64, 0, len(texts))
 | |
| 	for _, part := range slices.Chunks(texts, e.batchSize) {
 | |
| 		rb, err := json.Marshal(&embedReq{
 | |
| 			Texts:      part,
 | |
| 			NeedSparse: true,
 | |
| 		})
 | |
| 		if err != nil {
 | |
| 			return nil, nil, err
 | |
| 		}
 | |
| 		req, err := http.NewRequestWithContext(ctx, http.MethodPost, e.addr+pathEmbed, bytes.NewReader(rb))
 | |
| 		if err != nil {
 | |
| 			return nil, nil, err
 | |
| 		}
 | |
| 		req.Header.Set("Content-Type", "application/json; charset=utf-8")
 | |
| 		resp, err := e.do(req)
 | |
| 		if err != nil {
 | |
| 			return nil, nil, err
 | |
| 		}
 | |
| 		dense = append(dense, resp.Dense...)
 | |
| 		sparse = append(sparse, resp.Sparse...)
 | |
| 	}
 | |
| 	return dense, sparse, nil
 | |
| }
 | |
| 
 | |
| func (e *embedder) do(req *http.Request) (*embedResp, error) {
 | |
| 	resp, err := e.cli.Do(req)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	defer resp.Body.Close()
 | |
| 	b, err := io.ReadAll(resp.Body)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	r := &embedResp{}
 | |
| 	if err = json.Unmarshal(b, r); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return r, nil
 | |
| }
 | |
| 
 | |
| func (e *embedder) Dimensions() int64 {
 | |
| 	return e.dim
 | |
| }
 | |
| 
 | |
| func (e *embedder) SupportStatus() embedding.SupportStatus {
 | |
| 	return embedding.SupportDenseAndSparse
 | |
| }
 |