/* * Copyright 2025 coze-dev Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package knowledge import ( "context" "errors" "github.com/spf13/cast" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge" crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" ) const outputList = "outputList" type RetrieveConfig struct { KnowledgeIDs []int64 RetrievalStrategy *knowledge.RetrievalStrategy } func (r *RetrieveConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) { ns := &schema.NodeSchema{ Key: vo.NodeKey(n.ID), Type: entity.NodeTypeKnowledgeRetriever, Name: n.Data.Meta.Title, Configs: r, } inputs := n.Data.Inputs datasetListInfoParam := inputs.DatasetParam[0] datasetIDs := datasetListInfoParam.Input.Value.Content.([]any) knowledgeIDs := make([]int64, 0, len(datasetIDs)) for _, id := range datasetIDs { k, err := cast.ToInt64E(id) if err != nil { return nil, err } knowledgeIDs = append(knowledgeIDs, k) } r.KnowledgeIDs = knowledgeIDs retrievalStrategy := &knowledge.RetrievalStrategy{} var getDesignatedParamContent = func(name string) (any, bool) { for _, param := range inputs.DatasetParam { if param.Name == name { return param.Input.Value.Content, true } } return nil, false } if content, ok := getDesignatedParamContent("topK"); ok { topK, err := cast.ToInt64E(content) if err != nil { return nil, err } retrievalStrategy.TopK = &topK } if content, ok := getDesignatedParamContent("useRerank"); ok { useRerank, err := cast.ToBoolE(content) if err != nil { return nil, err } retrievalStrategy.EnableRerank = useRerank } if content, ok := getDesignatedParamContent("useRewrite"); ok { useRewrite, err := cast.ToBoolE(content) if err != nil { return nil, err } retrievalStrategy.EnableQueryRewrite = useRewrite } if content, ok := getDesignatedParamContent("isPersonalOnly"); ok { isPersonalOnly, err := cast.ToBoolE(content) if err != nil { return nil, err } retrievalStrategy.IsPersonalOnly = isPersonalOnly } if content, ok := getDesignatedParamContent("useNl2sql"); ok { useNl2sql, err := cast.ToBoolE(content) if err != nil { return nil, err } retrievalStrategy.EnableNL2SQL = useNl2sql } if content, ok := getDesignatedParamContent("minScore"); ok { minScore, err := cast.ToFloat64E(content) if err != nil { return nil, err } retrievalStrategy.MinScore = &minScore } if content, ok := getDesignatedParamContent("strategy"); ok { strategy, err := cast.ToInt64E(content) if err != nil { return nil, err } searchType, err := convertRetrievalSearchType(strategy) if err != nil { return nil, err } retrievalStrategy.SearchType = searchType } r.RetrievalStrategy = retrievalStrategy if err := convert.SetInputsForNodeSchema(n, ns); err != nil { return nil, err } if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil { return nil, err } return ns, nil } func (r *RetrieveConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) { if len(r.KnowledgeIDs) == 0 { return nil, errors.New("knowledge ids are required") } if r.RetrievalStrategy == nil { return nil, errors.New("retrieval strategy is required") } return &Retrieve{ knowledgeIDs: r.KnowledgeIDs, retrievalStrategy: r.RetrievalStrategy, }, nil } type Retrieve struct { knowledgeIDs []int64 retrievalStrategy *knowledge.RetrievalStrategy } func (kr *Retrieve) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) { query, ok := input["Query"].(string) if !ok { return nil, errors.New("capital query key is required") } req := &knowledge.RetrieveRequest{ Query: query, KnowledgeIDs: kr.knowledgeIDs, Strategy: kr.retrievalStrategy, } response, err := crossknowledge.DefaultSVC().Retrieve(ctx, req) if err != nil { return nil, err } result := make(map[string]any) result[outputList] = slices.Transform(response.RetrieveSlices, func(m *knowledge.RetrieveSlice) any { return map[string]any{ "documentId": m.Slice.DocumentID, "output": m.Slice.GetSliceContent(), } }) return result, nil }