refactor(knowledge): Move all the dependent components to app infra (#795)
This commit is contained in:
parent
23a468c72c
commit
f940edf585
|
|
@ -21,7 +21,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/alicebob/miniredis/v2"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
@ -31,6 +30,8 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/alicebob/miniredis/v2"
|
||||||
|
|
||||||
"github.com/bytedance/mockey"
|
"github.com/bytedance/mockey"
|
||||||
"github.com/cloudwego/eino/callbacks"
|
"github.com/cloudwego/eino/callbacks"
|
||||||
model2 "github.com/cloudwego/eino/components/model"
|
model2 "github.com/cloudwego/eino/components/model"
|
||||||
|
|
|
||||||
|
|
@ -67,6 +67,7 @@ import (
|
||||||
variablesImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/variables"
|
variablesImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/variables"
|
||||||
workflowImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/workflow"
|
workflowImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/workflow"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/eventbus"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/eventbus"
|
||||||
|
"github.com/coze-dev/coze-studio/backend/infra/impl/chatmodel"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/checkpoint"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/checkpoint"
|
||||||
implEventbus "github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
|
implEventbus "github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
|
||||||
)
|
)
|
||||||
|
|
@ -191,7 +192,9 @@ func initPrimaryServices(ctx context.Context, basicServices *basicServices) (*pr
|
||||||
|
|
||||||
memorySVC := memory.InitService(basicServices.toMemoryServiceComponents())
|
memorySVC := memory.InitService(basicServices.toMemoryServiceComponents())
|
||||||
|
|
||||||
knowledgeSVC, err := knowledge.InitService(basicServices.toKnowledgeServiceComponents(memorySVC))
|
knowledgeSVC, err := knowledge.InitService(ctx,
|
||||||
|
basicServices.toKnowledgeServiceComponents(memorySVC),
|
||||||
|
basicServices.eventbus.resourceEventBus)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -256,14 +259,18 @@ func (b *basicServices) toPluginServiceComponents() *plugin.ServiceComponents {
|
||||||
func (b *basicServices) toKnowledgeServiceComponents(memoryService *memory.MemoryApplicationServices) *knowledge.ServiceComponents {
|
func (b *basicServices) toKnowledgeServiceComponents(memoryService *memory.MemoryApplicationServices) *knowledge.ServiceComponents {
|
||||||
return &knowledge.ServiceComponents{
|
return &knowledge.ServiceComponents{
|
||||||
DB: b.infra.DB,
|
DB: b.infra.DB,
|
||||||
IDGenSVC: b.infra.IDGenSVC,
|
IDGen: b.infra.IDGenSVC,
|
||||||
Storage: b.infra.TOSClient,
|
|
||||||
RDB: memoryService.RDBDomainSVC,
|
RDB: memoryService.RDBDomainSVC,
|
||||||
|
Producer: b.infra.KnowledgeEventProducer,
|
||||||
SearchStoreManagers: b.infra.SearchStoreManagers,
|
SearchStoreManagers: b.infra.SearchStoreManagers,
|
||||||
EventBus: b.eventbus.resourceEventBus,
|
ParseManager: b.infra.ParserManager,
|
||||||
CacheCli: b.infra.CacheCli,
|
Storage: b.infra.TOSClient,
|
||||||
|
Rewriter: b.infra.Rewriter,
|
||||||
|
Reranker: b.infra.Reranker,
|
||||||
|
NL2Sql: b.infra.NL2SQL,
|
||||||
OCR: b.infra.OCR,
|
OCR: b.infra.OCR,
|
||||||
ParserManager: b.infra.ParserManager,
|
CacheCli: b.infra.CacheCli,
|
||||||
|
ModelFactory: chatmodel.NewDefaultFactory(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -280,18 +287,19 @@ func (b *basicServices) toMemoryServiceComponents() *memory.ServiceComponents {
|
||||||
|
|
||||||
func (b *basicServices) toWorkflowServiceComponents(pluginSVC *plugin.PluginApplicationService, memorySVC *memory.MemoryApplicationServices, knowledgeSVC *knowledge.KnowledgeApplicationService) *workflow.ServiceComponents {
|
func (b *basicServices) toWorkflowServiceComponents(pluginSVC *plugin.PluginApplicationService, memorySVC *memory.MemoryApplicationServices, knowledgeSVC *knowledge.KnowledgeApplicationService) *workflow.ServiceComponents {
|
||||||
return &workflow.ServiceComponents{
|
return &workflow.ServiceComponents{
|
||||||
IDGen: b.infra.IDGenSVC,
|
IDGen: b.infra.IDGenSVC,
|
||||||
DB: b.infra.DB,
|
DB: b.infra.DB,
|
||||||
Cache: b.infra.CacheCli,
|
Cache: b.infra.CacheCli,
|
||||||
Tos: b.infra.TOSClient,
|
Tos: b.infra.TOSClient,
|
||||||
ImageX: b.infra.ImageXClient,
|
ImageX: b.infra.ImageXClient,
|
||||||
DatabaseDomainSVC: memorySVC.DatabaseDomainSVC,
|
DatabaseDomainSVC: memorySVC.DatabaseDomainSVC,
|
||||||
VariablesDomainSVC: memorySVC.VariablesDomainSVC,
|
VariablesDomainSVC: memorySVC.VariablesDomainSVC,
|
||||||
PluginDomainSVC: pluginSVC.DomainSVC,
|
PluginDomainSVC: pluginSVC.DomainSVC,
|
||||||
KnowledgeDomainSVC: knowledgeSVC.DomainSVC,
|
KnowledgeDomainSVC: knowledgeSVC.DomainSVC,
|
||||||
DomainNotifier: b.eventbus.resourceEventBus,
|
DomainNotifier: b.eventbus.resourceEventBus,
|
||||||
CPStore: checkpoint.NewRedisStore(b.infra.CacheCli),
|
CPStore: checkpoint.NewRedisStore(b.infra.CacheCli),
|
||||||
CodeRunner: b.infra.CodeRunner,
|
CodeRunner: b.infra.CodeRunner,
|
||||||
|
WorkflowBuildInChatModel: b.infra.WorkflowBuildInChatModel,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,9 +18,11 @@ package appinfra
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -31,26 +33,32 @@ import (
|
||||||
"github.com/cloudwego/eino-ext/components/embedding/gemini"
|
"github.com/cloudwego/eino-ext/components/embedding/gemini"
|
||||||
"github.com/cloudwego/eino-ext/components/embedding/ollama"
|
"github.com/cloudwego/eino-ext/components/embedding/ollama"
|
||||||
"github.com/cloudwego/eino-ext/components/embedding/openai"
|
"github.com/cloudwego/eino-ext/components/embedding/openai"
|
||||||
|
"github.com/cloudwego/eino/components/prompt"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
"github.com/milvus-io/milvus/client/v2/milvusclient"
|
"github.com/milvus-io/milvus/client/v2/milvusclient"
|
||||||
"github.com/volcengine/volc-sdk-golang/service/visual"
|
"github.com/volcengine/volc-sdk-golang/service/visual"
|
||||||
|
|
||||||
"github.com/coze-dev/coze-studio/backend/application/internal"
|
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
|
||||||
|
"github.com/coze-dev/coze-studio/backend/infra/contract/document/nl2sql"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||||
|
"github.com/coze-dev/coze-studio/backend/infra/contract/document/rerank"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
|
||||||
|
"github.com/coze-dev/coze-studio/backend/infra/contract/messages2query"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/direct"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/direct"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/sandbox"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/sandbox"
|
||||||
|
builtinNL2SQL "github.com/coze-dev/coze-studio/backend/infra/impl/document/nl2sql/builtin"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/ppocr"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/ppocr"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/veocr"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/veocr"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/ppstructure"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/ppstructure"
|
||||||
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/elasticsearch"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/elasticsearch"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/milvus"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/milvus"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/vikingdb"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/vikingdb"
|
||||||
|
|
@ -61,6 +69,7 @@ import (
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/idgen"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/idgen"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/imagex/veimagex"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/imagex/veimagex"
|
||||||
|
builtinM2Q "github.com/coze-dev/coze-studio/backend/infra/impl/messages2query/builtin"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/mysql"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/mysql"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/storage"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/storage"
|
||||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||||
|
|
@ -70,19 +79,24 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type AppDependencies struct {
|
type AppDependencies struct {
|
||||||
DB *gorm.DB
|
DB *gorm.DB
|
||||||
CacheCli cache.Cmdable
|
CacheCli cache.Cmdable
|
||||||
IDGenSVC idgen.IDGenerator
|
IDGenSVC idgen.IDGenerator
|
||||||
ESClient es.Client
|
ESClient es.Client
|
||||||
ImageXClient imagex.ImageX
|
ImageXClient imagex.ImageX
|
||||||
TOSClient storage.Storage
|
TOSClient storage.Storage
|
||||||
ResourceEventProducer eventbus.Producer
|
ResourceEventProducer eventbus.Producer
|
||||||
AppEventProducer eventbus.Producer
|
AppEventProducer eventbus.Producer
|
||||||
ModelMgr modelmgr.Manager
|
KnowledgeEventProducer eventbus.Producer
|
||||||
CodeRunner coderunner.Runner
|
ModelMgr modelmgr.Manager
|
||||||
OCR ocr.OCR
|
CodeRunner coderunner.Runner
|
||||||
ParserManager parser.Manager
|
OCR ocr.OCR
|
||||||
SearchStoreManagers []searchstore.Manager
|
ParserManager parser.Manager
|
||||||
|
SearchStoreManagers []searchstore.Manager
|
||||||
|
Reranker rerank.Reranker
|
||||||
|
Rewriter messages2query.MessagesToQuery
|
||||||
|
NL2SQL nl2sql.NL2SQL
|
||||||
|
WorkflowBuildInChatModel chatmodel.BaseChatModel
|
||||||
}
|
}
|
||||||
|
|
||||||
func Init(ctx context.Context) (*AppDependencies, error) {
|
func Init(ctx context.Context) (*AppDependencies, error) {
|
||||||
|
|
@ -126,6 +140,23 @@ func Init(ctx context.Context) (*AppDependencies, error) {
|
||||||
return nil, fmt.Errorf("init app event producer failed, err=%w", err)
|
return nil, fmt.Errorf("init app event producer failed, err=%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
deps.KnowledgeEventProducer, err = initKnowledgeEventBusProducer()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("init knowledge event bus producer failed, err=%w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
deps.Reranker = rrf.NewRRFReranker(0)
|
||||||
|
|
||||||
|
deps.Rewriter, err = initRewriter(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("init rewriter failed, err=%w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
deps.NL2SQL, err = initNL2SQL(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("init nl2sql failed, err=%w", err)
|
||||||
|
}
|
||||||
|
|
||||||
deps.ModelMgr, err = initModelMgr()
|
deps.ModelMgr, err = initModelMgr()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("init model manager failed, err=%w", err)
|
return nil, fmt.Errorf("init model manager failed, err=%w", err)
|
||||||
|
|
@ -135,11 +166,21 @@ func Init(ctx context.Context) (*AppDependencies, error) {
|
||||||
|
|
||||||
deps.OCR = initOCR()
|
deps.OCR = initOCR()
|
||||||
|
|
||||||
imageAnnotationModel, _, err := internal.GetBuiltinChatModel(ctx, "IA_")
|
imageAnnotationModel, _, err := getBuiltinChatModel(ctx, "IA_")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get builtin chat model failed, err=%w", err)
|
return nil, fmt.Errorf("get builtin chat model failed, err=%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ok bool
|
||||||
|
deps.WorkflowBuildInChatModel, ok, err = getBuiltinChatModel(ctx, "WKR_")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get workflow builtin chat model failed, err=%w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
logs.CtxWarnf(ctx, "workflow builtin chat model for knowledge recall not configured")
|
||||||
|
}
|
||||||
|
|
||||||
deps.ParserManager, err = initParserManager(deps.TOSClient, deps.OCR, imageAnnotationModel)
|
deps.ParserManager, err = initParserManager(deps.TOSClient, deps.OCR, imageAnnotationModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("init parser manager failed, err=%w", err)
|
return nil, fmt.Errorf("init parser manager failed, err=%w", err)
|
||||||
|
|
@ -166,6 +207,71 @@ func initSearchStoreManagers(ctx context.Context, es es.Client) ([]searchstore.M
|
||||||
return []searchstore.Manager{esSearchstoreManager, mgr}, nil
|
return []searchstore.Manager{esSearchstoreManager, mgr}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func initRewriter(ctx context.Context) (messages2query.MessagesToQuery, error) {
|
||||||
|
rewriterChatModel, _, err := getBuiltinChatModel(ctx, "M2Q_")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath := filepath.Join(getWorkingDirectory(), "resources/conf/prompt/messages_to_query_template_jinja2.json")
|
||||||
|
rewriterTemplate, err := readJinja2PromptTemplate(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter, err := builtinM2Q.NewMessagesToQuery(ctx, rewriterChatModel, rewriterTemplate)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return rewriter, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getWorkingDirectory() string {
|
||||||
|
root, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
logs.Warnf("[InitConfig] Failed to get current working directory: %v", err)
|
||||||
|
root = os.Getenv("PWD")
|
||||||
|
}
|
||||||
|
return root
|
||||||
|
}
|
||||||
|
|
||||||
|
func readJinja2PromptTemplate(jsonFilePath string) (prompt.ChatTemplate, error) {
|
||||||
|
b, err := os.ReadFile(jsonFilePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var m2qMessages []*schema.Message
|
||||||
|
if err = json.Unmarshal(b, &m2qMessages); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tpl := make([]schema.MessagesTemplate, len(m2qMessages))
|
||||||
|
for i := range m2qMessages {
|
||||||
|
tpl[i] = m2qMessages[i]
|
||||||
|
}
|
||||||
|
return prompt.FromMessages(schema.Jinja2, tpl...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func initNL2SQL(ctx context.Context) (nl2sql.NL2SQL, error) {
|
||||||
|
n2sChatModel, _, err := getBuiltinChatModel(ctx, "NL2SQL_")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath := filepath.Join(getWorkingDirectory(), "resources/conf/prompt/nl2sql_template_jinja2.json")
|
||||||
|
n2sTemplate, err := readJinja2PromptTemplate(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
n2s, err := builtinNL2SQL.NewNL2SQL(ctx, n2sChatModel, n2sTemplate)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return n2s, nil
|
||||||
|
}
|
||||||
|
|
||||||
func initImageX(ctx context.Context) (imagex.ImageX, error) {
|
func initImageX(ctx context.Context) (imagex.ImageX, error) {
|
||||||
uploadComponentType := os.Getenv(consts.FileUploadComponentType)
|
uploadComponentType := os.Getenv(consts.FileUploadComponentType)
|
||||||
if uploadComponentType != consts.FileUploadComponentTypeImagex {
|
if uploadComponentType != consts.FileUploadComponentTypeImagex {
|
||||||
|
|
@ -206,6 +312,17 @@ func initAppEventProducer() (eventbus.Producer, error) {
|
||||||
return appEventProducer, nil
|
return appEventProducer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func initKnowledgeEventBusProducer() (eventbus.Producer, error) {
|
||||||
|
nameServer := os.Getenv(consts.MQServer)
|
||||||
|
|
||||||
|
knowledgeProducer, err := eventbus.NewProducer(nameServer, consts.RMQTopicKnowledge, consts.RMQConsumeGroupKnowledge, 2)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("init knowledge producer failed, err=%w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return knowledgeProducer, nil
|
||||||
|
}
|
||||||
|
|
||||||
func initCodeRunner() coderunner.Runner {
|
func initCodeRunner() coderunner.Runner {
|
||||||
switch typ := os.Getenv(consts.CodeRunnerType); typ {
|
switch typ := os.Getenv(consts.CodeRunnerType); typ {
|
||||||
case "sandbox":
|
case "sandbox":
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package internal
|
package appinfra
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
@ -33,7 +33,7 @@ import (
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetBuiltinChatModel(ctx context.Context, envPrefix string) (bcm chatmodel.BaseChatModel, configured bool, err error) {
|
func getBuiltinChatModel(ctx context.Context, envPrefix string) (bcm chatmodel.BaseChatModel, configured bool, err error) {
|
||||||
getEnv := func(key string) string {
|
getEnv := func(key string) string {
|
||||||
if val := os.Getenv(envPrefix + key); val != "" {
|
if val := os.Getenv(envPrefix + key); val != "" {
|
||||||
return val
|
return val
|
||||||
|
|
@ -99,7 +99,7 @@ func GetBuiltinChatModel(ctx context.Context, envPrefix string) (bcm chatmodel.B
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, fmt.Errorf("knowledge init openai chat mode failed, %w", err)
|
return nil, false, fmt.Errorf("builtin %s chat model init failed, %w", envPrefix, err)
|
||||||
}
|
}
|
||||||
if bcm != nil {
|
if bcm != nil {
|
||||||
configured = true
|
configured = true
|
||||||
|
|
@ -18,132 +18,27 @@ package knowledge
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"github.com/cloudwego/eino/components/prompt"
|
|
||||||
"github.com/cloudwego/eino/schema"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
|
|
||||||
"github.com/coze-dev/coze-studio/backend/application/internal"
|
|
||||||
"github.com/coze-dev/coze-studio/backend/application/search"
|
"github.com/coze-dev/coze-studio/backend/application/search"
|
||||||
knowledgeImpl "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
|
knowledgeImpl "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/nl2sql"
|
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
|
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/messages2query"
|
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
|
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
|
||||||
chatmodelImpl "github.com/coze-dev/coze-studio/backend/infra/impl/chatmodel"
|
|
||||||
builtinNL2SQL "github.com/coze-dev/coze-studio/backend/infra/impl/document/nl2sql/builtin"
|
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf"
|
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
|
||||||
builtinM2Q "github.com/coze-dev/coze-studio/backend/infra/impl/messages2query/builtin"
|
|
||||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
|
||||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ServiceComponents struct {
|
type ServiceComponents = knowledgeImpl.KnowledgeSVCConfig
|
||||||
DB *gorm.DB
|
|
||||||
IDGenSVC idgen.IDGenerator
|
|
||||||
Storage storage.Storage
|
|
||||||
RDB rdb.RDB
|
|
||||||
EventBus search.ResourceEventBus
|
|
||||||
CacheCli cache.Cmdable
|
|
||||||
OCR ocr.OCR
|
|
||||||
ParserManager parser.Manager
|
|
||||||
SearchStoreManagers []searchstore.Manager
|
|
||||||
}
|
|
||||||
|
|
||||||
func InitService(c *ServiceComponents) (*KnowledgeApplicationService, error) {
|
func InitService(ctx context.Context, c *ServiceComponents, bus search.ResourceEventBus) (*KnowledgeApplicationService, error) {
|
||||||
ctx := context.Background()
|
knowledgeDomainSVC, knowledgeEventHandler := knowledgeImpl.NewKnowledgeSVC(c)
|
||||||
|
|
||||||
nameServer := os.Getenv(consts.MQServer)
|
nameServer := os.Getenv(consts.MQServer)
|
||||||
|
if err := eventbus.DefaultSVC().RegisterConsumer(nameServer, consts.RMQTopicKnowledge, consts.RMQConsumeGroupKnowledge, knowledgeEventHandler); err != nil {
|
||||||
knowledgeProducer, err := eventbus.NewProducer(nameServer, consts.RMQTopicKnowledge, consts.RMQConsumeGroupKnowledge, 2)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("init knowledge producer failed, err=%w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
root, err := os.Getwd()
|
|
||||||
if err != nil {
|
|
||||||
logs.Warnf("[InitConfig] Failed to get current working directory: %v", err)
|
|
||||||
root = os.Getenv("PWD")
|
|
||||||
}
|
|
||||||
|
|
||||||
var rewriter messages2query.MessagesToQuery
|
|
||||||
if rewriterChatModel, _, err := internal.GetBuiltinChatModel(ctx, "M2Q_"); err != nil {
|
|
||||||
return nil, err
|
|
||||||
} else {
|
|
||||||
filePath := filepath.Join(root, "resources/conf/prompt/messages_to_query_template_jinja2.json")
|
|
||||||
rewriterTemplate, err := readJinja2PromptTemplate(filePath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
rewriter, err = builtinM2Q.NewMessagesToQuery(ctx, rewriterChatModel, rewriterTemplate)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var n2s nl2sql.NL2SQL
|
|
||||||
if n2sChatModel, _, err := internal.GetBuiltinChatModel(ctx, "NL2SQL_"); err != nil {
|
|
||||||
return nil, err
|
|
||||||
} else {
|
|
||||||
filePath := filepath.Join(root, "resources/conf/prompt/nl2sql_template_jinja2.json")
|
|
||||||
n2sTemplate, err := readJinja2PromptTemplate(filePath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
n2s, err = builtinNL2SQL.NewNL2SQL(ctx, n2sChatModel, n2sTemplate)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
knowledgeDomainSVC, knowledgeEventHandler := knowledgeImpl.NewKnowledgeSVC(&knowledgeImpl.KnowledgeSVCConfig{
|
|
||||||
DB: c.DB,
|
|
||||||
IDGen: c.IDGenSVC,
|
|
||||||
RDB: c.RDB,
|
|
||||||
Producer: knowledgeProducer,
|
|
||||||
SearchStoreManagers: c.SearchStoreManagers,
|
|
||||||
ParseManager: c.ParserManager,
|
|
||||||
Storage: c.Storage,
|
|
||||||
Rewriter: rewriter,
|
|
||||||
Reranker: rrf.NewRRFReranker(0), // default rrf
|
|
||||||
NL2Sql: n2s,
|
|
||||||
OCR: c.OCR,
|
|
||||||
CacheCli: c.CacheCli,
|
|
||||||
ModelFactory: chatmodelImpl.NewDefaultFactory(),
|
|
||||||
})
|
|
||||||
|
|
||||||
if err = eventbus.DefaultSVC().RegisterConsumer(nameServer, consts.RMQTopicKnowledge, consts.RMQConsumeGroupKnowledge, knowledgeEventHandler); err != nil {
|
|
||||||
return nil, fmt.Errorf("register knowledge consumer failed, err=%w", err)
|
return nil, fmt.Errorf("register knowledge consumer failed, err=%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
KnowledgeSVC.DomainSVC = knowledgeDomainSVC
|
KnowledgeSVC.DomainSVC = knowledgeDomainSVC
|
||||||
KnowledgeSVC.eventBus = c.EventBus
|
KnowledgeSVC.eventBus = bus
|
||||||
KnowledgeSVC.storage = c.Storage
|
KnowledgeSVC.storage = c.Storage
|
||||||
return KnowledgeSVC, nil
|
return KnowledgeSVC, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func readJinja2PromptTemplate(jsonFilePath string) (prompt.ChatTemplate, error) {
|
|
||||||
b, err := os.ReadFile(jsonFilePath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var m2qMessages []*schema.Message
|
|
||||||
if err = json.Unmarshal(b, &m2qMessages); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
tpl := make([]schema.MessagesTemplate, len(m2qMessages))
|
|
||||||
for i := range m2qMessages {
|
|
||||||
tpl[i] = m2qMessages[i]
|
|
||||||
}
|
|
||||||
return prompt.FromMessages(schema.Jinja2, tpl...), nil
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,6 @@ import (
|
||||||
"github.com/cloudwego/eino/compose"
|
"github.com/cloudwego/eino/compose"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"github.com/coze-dev/coze-studio/backend/application/internal"
|
|
||||||
"github.com/coze-dev/coze-studio/backend/crossdomain/impl/code"
|
"github.com/coze-dev/coze-studio/backend/crossdomain/impl/code"
|
||||||
knowledge "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
|
knowledge "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
|
||||||
dbservice "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
dbservice "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||||
|
|
@ -35,41 +34,34 @@ import (
|
||||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/service"
|
"github.com/coze-dev/coze-studio/backend/domain/workflow/service"
|
||||||
workflowservice "github.com/coze-dev/coze-studio/backend/domain/workflow/service"
|
workflowservice "github.com/coze-dev/coze-studio/backend/domain/workflow/service"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||||
|
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ServiceComponents struct {
|
type ServiceComponents struct {
|
||||||
IDGen idgen.IDGenerator
|
IDGen idgen.IDGenerator
|
||||||
DB *gorm.DB
|
DB *gorm.DB
|
||||||
Cache cache.Cmdable
|
Cache cache.Cmdable
|
||||||
DatabaseDomainSVC dbservice.Database
|
DatabaseDomainSVC dbservice.Database
|
||||||
VariablesDomainSVC variables.Variables
|
VariablesDomainSVC variables.Variables
|
||||||
PluginDomainSVC plugin.PluginService
|
PluginDomainSVC plugin.PluginService
|
||||||
KnowledgeDomainSVC knowledge.Knowledge
|
KnowledgeDomainSVC knowledge.Knowledge
|
||||||
DomainNotifier search.ResourceEventBus
|
DomainNotifier search.ResourceEventBus
|
||||||
Tos storage.Storage
|
Tos storage.Storage
|
||||||
ImageX imagex.ImageX
|
ImageX imagex.ImageX
|
||||||
CPStore compose.CheckPointStore
|
CPStore compose.CheckPointStore
|
||||||
CodeRunner coderunner.Runner
|
CodeRunner coderunner.Runner
|
||||||
|
WorkflowBuildInChatModel chatmodel.BaseChatModel
|
||||||
}
|
}
|
||||||
|
|
||||||
func InitService(ctx context.Context, components *ServiceComponents) (*ApplicationService, error) {
|
func InitService(ctx context.Context, components *ServiceComponents) (*ApplicationService, error) {
|
||||||
bcm, ok, err := internal.GetBuiltinChatModel(ctx, "WKR_")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
logs.CtxWarnf(ctx, "workflow builtin chat model for knowledge recall not configured")
|
|
||||||
}
|
|
||||||
|
|
||||||
service.RegisterAllNodeAdaptors()
|
service.RegisterAllNodeAdaptors()
|
||||||
|
|
||||||
workflowRepo := service.NewWorkflowRepository(components.IDGen, components.DB, components.Cache,
|
workflowRepo := service.NewWorkflowRepository(components.IDGen, components.DB, components.Cache,
|
||||||
components.Tos, components.CPStore, bcm)
|
components.Tos, components.CPStore, components.WorkflowBuildInChatModel)
|
||||||
workflow.SetRepository(workflowRepo)
|
workflow.SetRepository(workflowRepo)
|
||||||
|
|
||||||
workflowDomainSVC := service.NewWorkflowService(workflowRepo)
|
workflowDomainSVC := service.NewWorkflowService(workflowRepo)
|
||||||
|
|
@ -83,5 +75,5 @@ func InitService(ctx context.Context, components *ServiceComponents) (*Applicati
|
||||||
SVC.TosClient = components.Tos
|
SVC.TosClient = components.Tos
|
||||||
SVC.IDGenerator = components.IDGen
|
SVC.IDGenerator = components.IDGen
|
||||||
|
|
||||||
return SVC, err
|
return SVC, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -58,9 +58,7 @@ import (
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
|
||||||
rdbEntity "github.com/coze-dev/coze-studio/backend/infra/contract/rdb/entity"
|
rdbEntity "github.com/coze-dev/coze-studio/backend/infra/contract/rdb/entity"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin"
|
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/progressbar"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/progressbar"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf"
|
|
||||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||||
|
|
@ -87,12 +85,6 @@ func NewKnowledgeSVC(config *KnowledgeSVCConfig) (Knowledge, eventbus.ConsumerHa
|
||||||
cacheCli: config.CacheCli,
|
cacheCli: config.CacheCli,
|
||||||
modelFactory: config.ModelFactory,
|
modelFactory: config.ModelFactory,
|
||||||
}
|
}
|
||||||
if svc.reranker == nil {
|
|
||||||
svc.reranker = rrf.NewRRFReranker(0)
|
|
||||||
}
|
|
||||||
if svc.parseManager == nil {
|
|
||||||
svc.parseManager = builtin.NewManager(config.Storage, config.OCR, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
return svc, svc
|
return svc, svc
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,8 @@ import (
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
|
||||||
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin"
|
||||||
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf"
|
||||||
sses "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/elasticsearch"
|
sses "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/elasticsearch"
|
||||||
ssmilvus "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/milvus"
|
ssmilvus "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/milvus"
|
||||||
hembed "github.com/coze-dev/coze-studio/backend/infra/impl/embedding/http"
|
hembed "github.com/coze-dev/coze-studio/backend/infra/impl/embedding/http"
|
||||||
|
|
@ -169,10 +171,10 @@ func (suite *KnowledgeTestSuite) SetupSuite() {
|
||||||
RDB: rdbService,
|
RDB: rdbService,
|
||||||
Producer: knowledgeProducer,
|
Producer: knowledgeProducer,
|
||||||
SearchStoreManagers: mgrs,
|
SearchStoreManagers: mgrs,
|
||||||
ParseManager: nil, // default builtin
|
ParseManager: builtin.NewManager(tosClient, nil, nil), // default builtin
|
||||||
Storage: tosClient,
|
Storage: tosClient,
|
||||||
Rewriter: nil,
|
Rewriter: nil,
|
||||||
Reranker: nil, // default rrf
|
Reranker: rrf.NewRRFReranker(0), // default rrf
|
||||||
EnableCompactTable: ptr.Of(true),
|
EnableCompactTable: ptr.Of(true),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,8 @@ import (
|
||||||
knowledgeModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
knowledgeModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
|
"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||||
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin"
|
||||||
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf"
|
||||||
"github.com/coze-dev/coze-studio/backend/infra/impl/rdb"
|
"github.com/coze-dev/coze-studio/backend/infra/impl/rdb"
|
||||||
producerMock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/eventbus"
|
producerMock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/eventbus"
|
||||||
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen"
|
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen"
|
||||||
|
|
@ -98,11 +100,14 @@ func MockKnowledgeSVC(t *testing.T) Knowledge {
|
||||||
mockStorage.EXPECT().PutObject(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
mockStorage.EXPECT().PutObject(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||||
rdb := rdb.NewService(db, mockIDGen)
|
rdb := rdb.NewService(db, mockIDGen)
|
||||||
svc, _ := NewKnowledgeSVC(&KnowledgeSVCConfig{
|
svc, _ := NewKnowledgeSVC(&KnowledgeSVCConfig{
|
||||||
DB: db,
|
DB: db,
|
||||||
IDGen: mockIDGen,
|
IDGen: mockIDGen,
|
||||||
Storage: mockStorage,
|
Storage: mockStorage,
|
||||||
Producer: producer,
|
Producer: producer,
|
||||||
RDB: rdb,
|
RDB: rdb,
|
||||||
|
Reranker: rrf.NewRRFReranker(0),
|
||||||
|
ParseManager: builtin.NewManager(mockStorage, nil, nil), // default builtin
|
||||||
|
|
||||||
})
|
})
|
||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue