From 8137b0aee5d079e90abb8a2359a6d855f897b019 Mon Sep 17 00:00:00 2001 From: N3ko <1377648701@qq.com> Date: Tue, 29 Jul 2025 19:02:03 +0800 Subject: [PATCH] feat: init improvements (#174) --- backend/application/knowledge/init.go | 25 ++++++++++++ .../model/template/model_template_openai.yaml | 2 +- backend/go.mod | 7 ++-- backend/go.sum | 10 +++-- .../infra/impl/document/ocr/veocr/ve_ocr.go | 38 +++++++++++++++++-- backend/infra/impl/embedding/ark/ark.go | 21 +++++++++- backend/infra/impl/embedding/wrap/ollama.go | 32 ++++++++++++++++ .../infra/impl/modelmgr/static/modelmgr.go | 5 +++ docker/.env.example | 8 +++- 9 files changed, 133 insertions(+), 15 deletions(-) create mode 100644 backend/infra/impl/embedding/wrap/ollama.go diff --git a/backend/application/knowledge/init.go b/backend/application/knowledge/init.go index 2cf587a6..d2591028 100644 --- a/backend/application/knowledge/init.go +++ b/backend/application/knowledge/init.go @@ -26,6 +26,7 @@ import ( "time" "github.com/cloudwego/eino-ext/components/embedding/ark" + ollamaEmb "github.com/cloudwego/eino-ext/components/embedding/ollama" "github.com/cloudwego/eino-ext/components/embedding/openai" ao "github.com/cloudwego/eino-ext/components/model/ark" "github.com/cloudwego/eino-ext/components/model/deepseek" @@ -111,6 +112,9 @@ func InitService(c *ServiceComponents) (*KnowledgeApplicationService, error) { case "ve": ocrAK := os.Getenv("VE_OCR_AK") ocrSK := os.Getenv("VE_OCR_SK") + if ocrAK == "" || ocrSK == "" { + logs.Warnf("[ve_ocr] ak / sk not configured, ocr might not work well") + } inst := visual.NewInstance() inst.Client.SetAccessKey(ocrAK) inst.Client.SetSecretKey(ocrSK) @@ -346,6 +350,27 @@ func getEmbedding(ctx context.Context) (embedding.Embedder, error) { if err != nil { return nil, fmt.Errorf("init ark embedding client failed, err=%w", err) } + + case "ollama": + var ( + ollamaEmbeddingBaseURL = os.Getenv("OLLAMA_EMBEDDING_BASE_URL") + ollamaEmbeddingModel = os.Getenv("OLLAMA_EMBEDDING_MODEL") + 
ollamaEmbeddingDims = os.Getenv("OLLAMA_EMBEDDING_DIMS") + ) + + dims, err := strconv.ParseInt(ollamaEmbeddingDims, 10, 64) + if err != nil { + return nil, fmt.Errorf("init ollama embedding dims failed, err=%w", err) + } + + emb, err = wrap.NewOllamaEmbedder(ctx, &ollamaEmb.EmbeddingConfig{ + BaseURL: ollamaEmbeddingBaseURL, + Model: ollamaEmbeddingModel, + }, dims) + if err != nil { + return nil, fmt.Errorf("init ollama embedding failed, err=%w", err) + } + default: return nil, fmt.Errorf("init knowledge embedding failed, type not configured") } diff --git a/backend/conf/model/template/model_template_openai.yaml b/backend/conf/model/template/model_template_openai.yaml index 5a32b075..84fb7449 100755 --- a/backend/conf/model/template/model_template_openai.yaml +++ b/backend/conf/model/template/model_template_openai.yaml @@ -157,7 +157,7 @@ meta: top_k: 0 stop: [] openai: - by_azure: true + by_azure: false api_version: "" response_format: type: text diff --git a/backend/go.mod b/backend/go.mod index f701b6db..11a7b965 100755 --- a/backend/go.mod +++ b/backend/go.mod @@ -12,7 +12,7 @@ require ( github.com/apache/thrift v0.21.0 github.com/bytedance/mockey v1.2.14 github.com/bytedance/sonic v1.13.2 - github.com/cloudwego/eino v0.3.51 + github.com/cloudwego/eino v0.3.55 github.com/cloudwego/eino-ext/components/model/ark v0.1.15 github.com/cloudwego/eino-ext/components/model/claude v0.1.1 github.com/cloudwego/eino-ext/components/model/deepseek v0.0.0-20250715055739-0d0e28441a2f @@ -55,6 +55,7 @@ require github.com/alicebob/miniredis/v2 v2.34.0 require ( github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/cloudwego/eino-ext/components/embedding/ark v0.0.0-20250522060253-ddb617598b09 + github.com/cloudwego/eino-ext/components/embedding/ollama v0.0.0-20250728060543-79ec300857b8 github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20250522060253-ddb617598b09 github.com/cloudwego/eino-ext/components/model/gemini v0.1.2 
github.com/cloudwego/eino-ext/components/model/ollama v0.0.0-20250610035057-2c4e7c8488a5 @@ -66,7 +67,7 @@ require ( github.com/jinzhu/copier v0.4.0 github.com/mattn/go-shellwords v1.0.12 github.com/nsqio/go-nsq v1.1.0 - github.com/ollama/ollama v0.6.5 + github.com/ollama/ollama v0.9.6 github.com/rbretecher/go-postman-collection v0.9.0 github.com/volcengine/ve-tos-golang-sdk/v2 v2.7.17 github.com/yuin/goldmark v1.4.13 @@ -246,7 +247,7 @@ require ( github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/uber/jaeger-client-go v2.30.0+incompatible // indirect - github.com/volcengine/volcengine-go-sdk v1.1.20 // indirect + github.com/volcengine/volcengine-go-sdk v1.1.20 github.com/x448/float16 v0.8.4 // indirect github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect github.com/xuri/efp v0.0.0-20240408161823-9ad904a10d6d // indirect diff --git a/backend/go.sum b/backend/go.sum index 7c5fafc1..37dcbbec 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -938,10 +938,12 @@ github.com/circonus-labs/circonusllhist v0.1.3/go.mod h1:kMXHVDlOchFAehlya5ePtbp github.com/clbanning/mxj v1.8.4/go.mod h1:BVjHeAH+rl9rs6f+QIpeRl0tfu10SXn1pUSa5PVGJng= github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4= github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= -github.com/cloudwego/eino v0.3.51 h1:emSaDu49v9EEJYOusL42Li/VL5QBSyBvhxO9ZcKPZvs= -github.com/cloudwego/eino v0.3.51/go.mod h1:wUjz990apdsaOraOXdh6CdhVXq8DJsOvLsVlxNTcNfY= +github.com/cloudwego/eino v0.3.55 h1:lMZrGtEh0k3qykQTLNXSXuAa98OtF2tS43GMHyvN7nA= +github.com/cloudwego/eino v0.3.55/go.mod h1:wUjz990apdsaOraOXdh6CdhVXq8DJsOvLsVlxNTcNfY= github.com/cloudwego/eino-ext/components/embedding/ark v0.0.0-20250522060253-ddb617598b09 h1:hZScBE/Etiji2RqjlABcAkq6n1uzYPu+jo4GV5TF8Hc= github.com/cloudwego/eino-ext/components/embedding/ark 
v0.0.0-20250522060253-ddb617598b09/go.mod h1:pLtH5BZKgb7/bB8+P3W5/f1d46gTl9K77+08j88Gb4k= +github.com/cloudwego/eino-ext/components/embedding/ollama v0.0.0-20250728060543-79ec300857b8 h1:uJrs6SmfYnca8A+k9+3qJ4MYwYHMncUlGac1mYQT+Ak= +github.com/cloudwego/eino-ext/components/embedding/ollama v0.0.0-20250728060543-79ec300857b8/go.mod h1:nav79aUcd+UR24dLA+7l7RcHCMlg26zbDAKvjONdrw0= github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20250522060253-ddb617598b09 h1:C8RjF193iguUuevkuv0q4SC+XGlM/DlJEgic7l8OUAI= github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20250522060253-ddb617598b09/go.mod h1:S09z/CAQNyx+AbgfJRQXLUAYlPpxQWWLVuQxO34F90A= github.com/cloudwego/eino-ext/components/model/ark v0.1.15 h1:ydOvtEK67VI5DvNgg64eTxbjxMYhGBMOVP2okaZKk18= @@ -1616,8 +1618,8 @@ github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+ github.com/nyaruka/phonenumbers v1.0.55 h1:bj0nTO88Y68KeUQ/n3Lo2KgK7lM1hF7L9NFuwcCl3yg= github.com/nyaruka/phonenumbers v1.0.55/go.mod h1:sDaTZ/KPX5f8qyV9qN+hIm+4ZBARJrupC6LuhshJq1U= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= -github.com/ollama/ollama v0.6.5 h1:vXKkVX57ql/1ZzMw4SVK866Qfd6pjwEcITVyEpF0QXQ= -github.com/ollama/ollama v0.6.5/go.mod h1:pGgtoNyc9DdM6oZI6yMfI6jTk2Eh4c36c2GpfQCH7PY= +github.com/ollama/ollama v0.9.6 h1:HZNJmB52pMt6zLkGkkheBuXBXM5478eiSAj7GR75AMc= +github.com/ollama/ollama v0.9.6/go.mod h1:zLwx3iZ3AI4Rc/egsrx3u1w4RU2MHQ/Ylxse48jvyt4= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= diff --git a/backend/infra/impl/document/ocr/veocr/ve_ocr.go b/backend/infra/impl/document/ocr/veocr/ve_ocr.go index 59d3b1ee..56a01708 100644 --- a/backend/infra/impl/document/ocr/veocr/ve_ocr.go +++ b/backend/infra/impl/document/ocr/veocr/ve_ocr.go @@ 
-18,12 +18,16 @@ package veocr import ( "context" + "errors" "fmt" "net/http" "net/url" "strconv" + "github.com/coze-dev/coze-studio/backend/pkg/errorx" + "github.com/coze-dev/coze-studio/backend/types/errno" "github.com/volcengine/volc-sdk-golang/service/visual" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" "github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr" ) @@ -52,10 +56,14 @@ func (o *ocrImpl) FromBase64(ctx context.Context, b64 string) ([]string, error) resp, statusCode, err := o.config.Client.OCRNormal(form) if err != nil { - return nil, err + return nil, o.handleError(fmt.Errorf("[ve_ocr][FromBase64] OCRNormal err: %w", err)) } if statusCode != http.StatusOK { - return nil, fmt.Errorf("[FromBase64] failed, status code=%d", statusCode) + err = fmt.Errorf("[ve_ocr][FromBase64] OCRNormal failed, status code=%d", statusCode) + if statusCode == http.StatusBadRequest { + return nil, errorx.WrapByCode(err, errno.ErrKnowledgeNonRetryableCode) + } + return nil, err } return resp.Data.LineTexts, nil @@ -67,10 +75,14 @@ func (o *ocrImpl) FromURL(ctx context.Context, url string) ([]string, error) { resp, statusCode, err := o.config.Client.OCRNormal(form) if err != nil { - return nil, err + return nil, o.handleError(fmt.Errorf("[ve_ocr][FromURL] OCRNormal error: %w", err)) } if statusCode != http.StatusOK { - return nil, fmt.Errorf("[FromBase64] failed, status code=%d", statusCode) + err = fmt.Errorf("[ve_ocr][FromURL] OCRNormal failed, status code=%d", statusCode) + if statusCode == http.StatusBadRequest { + return nil, errorx.WrapByCode(err, errno.ErrKnowledgeNonRetryableCode) + } + return nil, err } return resp.Data.LineTexts, nil @@ -94,3 +106,21 @@ func (o *ocrImpl) newForm() url.Values { } return form } + +func (o *ocrImpl) handleError(err error) error { + var ( + apiErr = &model.APIError{} + reqErr = &model.RequestError{} + ) + if errors.As(err, &apiErr) { + if apiErr.HTTPStatusCode >= http.StatusInternalServerError || + 
apiErr.HTTPStatusCode == http.StatusTooManyRequests { + return err + } + } else if errors.As(err, &reqErr) { + if reqErr.HTTPStatusCode >= http.StatusInternalServerError { + return err + } + } + return errorx.WrapByCode(err, errno.ErrKnowledgeNonRetryableCode) +} diff --git a/backend/infra/impl/embedding/ark/ark.go b/backend/infra/impl/embedding/ark/ark.go index be3cf728..8e850589 100644 --- a/backend/infra/impl/embedding/ark/ark.go +++ b/backend/infra/impl/embedding/ark/ark.go @@ -18,11 +18,16 @@ package ark import ( "context" + "errors" "fmt" "math" + "net/http" "github.com/cloudwego/eino-ext/components/embedding/ark" "github.com/cloudwego/eino/components/embedding" + "github.com/coze-dev/coze-studio/backend/pkg/errorx" + "github.com/coze-dev/coze-studio/backend/types/errno" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" contract "github.com/coze-dev/coze-studio/backend/infra/contract/embedding" "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" @@ -51,7 +56,21 @@ func (d embWrap) EmbedStrings(ctx context.Context, texts []string, opts ...embed } normed, err := d.slicedNormL2(partResult) if err != nil { - return nil, err + var ( + apiErr = &model.APIError{} + reqErr = &model.RequestError{} + ) + if errors.As(err, &apiErr) { + if apiErr.HTTPStatusCode >= http.StatusInternalServerError || + apiErr.HTTPStatusCode == http.StatusTooManyRequests { + return nil, err + } + } else if errors.As(err, &reqErr) { + if reqErr.HTTPStatusCode >= http.StatusInternalServerError { + return nil, err + } + } + return nil, errorx.WrapByCode(err, errno.ErrKnowledgeNonRetryableCode) } resp = append(resp, normed...) 
} diff --git a/backend/infra/impl/embedding/wrap/ollama.go b/backend/infra/impl/embedding/wrap/ollama.go new file mode 100644 index 00000000..bc062572 --- /dev/null +++ b/backend/infra/impl/embedding/wrap/ollama.go @@ -0,0 +1,32 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package wrap + +import ( + "context" + + "github.com/cloudwego/eino-ext/components/embedding/ollama" + contract "github.com/coze-dev/coze-studio/backend/infra/contract/embedding" +) + +func NewOllamaEmbedder(ctx context.Context, config *ollama.EmbeddingConfig, dimensions int64) (contract.Embedder, error) { + emb, err := ollama.NewEmbedder(ctx, config) + if err != nil { + return nil, err + } + return &denseOnlyWrap{dims: dimensions, Embedder: emb}, nil +} diff --git a/backend/infra/impl/modelmgr/static/modelmgr.go b/backend/infra/impl/modelmgr/static/modelmgr.go index 8b05054b..adf32e68 100644 --- a/backend/infra/impl/modelmgr/static/modelmgr.go +++ b/backend/infra/impl/modelmgr/static/modelmgr.go @@ -24,9 +24,14 @@ import ( "github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/lang/sets" + "github.com/coze-dev/coze-studio/backend/pkg/logs" ) func NewModelMgr(staticModels []*modelmgr.Model) (modelmgr.Manager, error) { + if len(staticModels) == 0 { + logs.Warnf("[NewModelMgr] no static models found, please check if the config 
has been loaded correctly") + } + mapping := make(map[int64]*modelmgr.Model, len(staticModels)) for i := range staticModels { mapping[staticModels[i].ID] = staticModels[i] diff --git a/docker/.env.example b/docker/.env.example index a82a892a..d4d14166 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -90,8 +90,8 @@ export VIKING_DB_MODEL_NAME="" # if vikingdb model name is not set, you need to # Settings for Embedding # The Embedding model relied on by knowledge base vectorization does not need to be configured # if the vector database comes with built-in Embedding functionality (such as VikingDB). Currently, -# Coze Studio supports three access methods: openai, ark, and custom http. Users can simply choose one of them when using -# embedding type: openai / ark / http +# Coze Studio supports four access methods: openai, ark, ollama, and custom http. Users can simply choose one of them when using +# embedding type: openai / ark / ollama / http export EMBEDDING_TYPE="ark" # openai embedding export OPENAI_EMBEDDING_BASE_URL="" # (string) OpenAI base_url @@ -108,6 +108,10 @@ export ARK_EMBEDDING_AK="" export ARK_EMBEDDING_DIMS="2048" export ARK_EMBEDDING_BASE_URL="" +# ollama embedding +export OLLAMA_EMBEDDING_BASE_URL="" +export OLLAMA_EMBEDDING_MODEL="" +export OLLAMA_EMBEDDING_DIMS="" # http embedding export HTTP_EMBEDDING_ADDR="http://127.0.0.1:6543"