feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
1293
backend/application/app/app.go
Normal file
1293
backend/application/app/app.go
Normal file
File diff suppressed because it is too large
Load Diff
56
backend/application/app/convert.go
Normal file
56
backend/application/app/convert.go
Normal file
@@ -0,0 +1,56 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package app
|
||||
|
||||
import (
|
||||
resourceCommon "github.com/coze-dev/coze-studio/backend/api/model/resource/common"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/app/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
func toResourceType(resType resourceCommon.ResType) (entity.ResourceType, error) {
|
||||
switch resType {
|
||||
case resourceCommon.ResType_Plugin:
|
||||
return entity.ResourceTypeOfPlugin, nil
|
||||
case resourceCommon.ResType_Workflow:
|
||||
return entity.ResourceTypeOfWorkflow, nil
|
||||
case resourceCommon.ResType_Knowledge:
|
||||
return entity.ResourceTypeOfKnowledge, nil
|
||||
case resourceCommon.ResType_Database:
|
||||
return entity.ResourceTypeOfDatabase, nil
|
||||
default:
|
||||
return "", errorx.New(errno.ErrAppInvalidParamCode,
|
||||
errorx.KVf(errno.APPMsgKey, "unsupported resource type '%s'", resType))
|
||||
}
|
||||
}
|
||||
|
||||
func toThriftResourceType(resType entity.ResourceType) (resourceCommon.ResType, error) {
|
||||
switch resType {
|
||||
case entity.ResourceTypeOfPlugin:
|
||||
return resourceCommon.ResType_Plugin, nil
|
||||
case entity.ResourceTypeOfWorkflow:
|
||||
return resourceCommon.ResType_Workflow, nil
|
||||
case entity.ResourceTypeOfKnowledge:
|
||||
return resourceCommon.ResType_Knowledge, nil
|
||||
case entity.ResourceTypeOfDatabase:
|
||||
return resourceCommon.ResType_Database, nil
|
||||
default:
|
||||
return 0, errorx.New(errno.ErrAppInvalidParamCode,
|
||||
errorx.KVf(errno.APPMsgKey, "unsupported resource type '%s'", resType))
|
||||
}
|
||||
}
|
||||
71
backend/application/app/init.go
Normal file
71
backend/application/app/init.go
Normal file
@@ -0,0 +1,71 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package app
|
||||
|
||||
import (
|
||||
redisV9 "github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/app/repository"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/app/service"
|
||||
connector "github.com/coze-dev/coze-studio/backend/domain/connector/service"
|
||||
variables "github.com/coze-dev/coze-studio/backend/domain/memory/variables/service"
|
||||
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
|
||||
user "github.com/coze-dev/coze-studio/backend/domain/user/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
)
|
||||
|
||||
type ServiceComponents struct {
|
||||
IDGen idgen.IDGenerator
|
||||
DB *gorm.DB
|
||||
OSS storage.Storage
|
||||
CacheCli *redisV9.Client
|
||||
ProjectEventBus search.ProjectEventBus
|
||||
|
||||
UserSVC user.User
|
||||
ConnectorSVC connector.Connector
|
||||
VariablesSVC variables.Variables
|
||||
}
|
||||
|
||||
func InitService(components *ServiceComponents) (*APPApplicationService, error) {
|
||||
appRepo := repository.NewAPPRepo(&repository.APPRepoComponents{
|
||||
IDGen: components.IDGen,
|
||||
DB: components.DB,
|
||||
CacheCli: components.CacheCli,
|
||||
})
|
||||
|
||||
domainComponents := &service.Components{
|
||||
IDGen: components.IDGen,
|
||||
DB: components.DB,
|
||||
APPRepo: appRepo,
|
||||
}
|
||||
|
||||
domainSVC := service.NewService(domainComponents)
|
||||
|
||||
APPApplicationSVC.DomainSVC = domainSVC
|
||||
APPApplicationSVC.appRepo = appRepo
|
||||
|
||||
APPApplicationSVC.oss = components.OSS
|
||||
APPApplicationSVC.projectEventBus = components.ProjectEventBus
|
||||
|
||||
APPApplicationSVC.userSVC = components.UserSVC
|
||||
APPApplicationSVC.connectorSVC = components.ConnectorSVC
|
||||
APPApplicationSVC.variablesSVC = components.VariablesSVC
|
||||
|
||||
return APPApplicationSVC, nil
|
||||
}
|
||||
32
backend/application/app/model.go
Normal file
32
backend/application/app/model.go
Normal file
@@ -0,0 +1,32 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package app
|
||||
|
||||
import (
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/resource/common"
|
||||
)
|
||||
|
||||
type copyMetaInfo struct {
|
||||
scene common.ResourceCopyScene
|
||||
|
||||
userID int64
|
||||
appSpaceID int64
|
||||
copyTaskID string
|
||||
|
||||
fromAppID int64
|
||||
toAppID *int64
|
||||
}
|
||||
362
backend/application/application.go
Normal file
362
backend/application/application.go
Normal file
@@ -0,0 +1,362 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package application
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/application/openauth"
|
||||
"github.com/coze-dev/coze-studio/backend/application/template"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crosssearch"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/application/app"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/appinfra"
|
||||
"github.com/coze-dev/coze-studio/backend/application/connector"
|
||||
"github.com/coze-dev/coze-studio/backend/application/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/application/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/application/memory"
|
||||
"github.com/coze-dev/coze-studio/backend/application/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/application/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/application/prompt"
|
||||
"github.com/coze-dev/coze-studio/backend/application/search"
|
||||
"github.com/coze-dev/coze-studio/backend/application/shortcutcmd"
|
||||
"github.com/coze-dev/coze-studio/backend/application/singleagent"
|
||||
"github.com/coze-dev/coze-studio/backend/application/upload"
|
||||
"github.com/coze-dev/coze-studio/backend/application/user"
|
||||
"github.com/coze-dev/coze-studio/backend/application/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossagent"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossagentrun"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossconnector"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossconversation"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatabase"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatacopy"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossknowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossmessage"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossmodelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossuser"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossvariables"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
|
||||
agentrunImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/agentrun"
|
||||
connectorImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/connector"
|
||||
conversationImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/conversation"
|
||||
crossuserImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/crossuser"
|
||||
databaseImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/database"
|
||||
dataCopyImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/datacopy"
|
||||
knowledgeImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/knowledge"
|
||||
messageImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/message"
|
||||
modelmgrImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/modelmgr"
|
||||
pluginImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/plugin"
|
||||
searchImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/search"
|
||||
singleagentImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/singleagent"
|
||||
variablesImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/variables"
|
||||
workflowImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/checkpoint"
|
||||
)
|
||||
|
||||
type eventbusImpl struct {
|
||||
resourceEventBus search.ResourceEventBus
|
||||
projectEventBus search.ProjectEventBus
|
||||
}
|
||||
|
||||
type basicServices struct {
|
||||
infra *appinfra.AppDependencies
|
||||
eventbus *eventbusImpl
|
||||
modelMgrSVC *modelmgr.ModelmgrApplicationService
|
||||
connectorSVC *connector.ConnectorApplicationService
|
||||
userSVC *user.UserApplicationService
|
||||
promptSVC *prompt.PromptApplicationService
|
||||
templateSVC *template.ApplicationService
|
||||
openAuthSVC *openauth.OpenAuthApplicationService
|
||||
}
|
||||
|
||||
type primaryServices struct {
|
||||
basicServices *basicServices
|
||||
infra *appinfra.AppDependencies
|
||||
|
||||
pluginSVC *plugin.PluginApplicationService
|
||||
memorySVC *memory.MemoryApplicationServices
|
||||
knowledgeSVC *knowledge.KnowledgeApplicationService
|
||||
workflowSVC *workflow.ApplicationService
|
||||
shortcutSVC *shortcutcmd.ShortcutCmdApplicationService
|
||||
}
|
||||
|
||||
type complexServices struct {
|
||||
primaryServices *primaryServices
|
||||
singleAgentSVC *singleagent.SingleAgentApplicationService
|
||||
appSVC *app.APPApplicationService
|
||||
searchSVC *search.SearchApplicationService
|
||||
conversationSVC *conversation.ConversationApplicationService
|
||||
}
|
||||
|
||||
func Init(ctx context.Context) (err error) {
|
||||
infra, err := appinfra.Init(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
eventbus := initEventBus(infra)
|
||||
|
||||
basicServices, err := initBasicServices(ctx, infra, eventbus)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Init - initBasicServices failed, err: %v", err)
|
||||
}
|
||||
|
||||
primaryServices, err := initPrimaryServices(ctx, basicServices)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Init - initPrimaryServices failed, err: %v", err)
|
||||
}
|
||||
|
||||
complexServices, err := initComplexServices(ctx, primaryServices)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Init - initVitalServices failed, err: %v", err)
|
||||
}
|
||||
|
||||
crossconnector.SetDefaultSVC(connectorImpl.InitDomainService(basicServices.connectorSVC.DomainSVC))
|
||||
crossdatabase.SetDefaultSVC(databaseImpl.InitDomainService(primaryServices.memorySVC.DatabaseDomainSVC))
|
||||
crossknowledge.SetDefaultSVC(knowledgeImpl.InitDomainService(primaryServices.knowledgeSVC.DomainSVC))
|
||||
crossmodelmgr.SetDefaultSVC(modelmgrImpl.InitDomainService(basicServices.modelMgrSVC.DomainSVC))
|
||||
crossplugin.SetDefaultSVC(pluginImpl.InitDomainService(primaryServices.pluginSVC.DomainSVC))
|
||||
crossvariables.SetDefaultSVC(variablesImpl.InitDomainService(primaryServices.memorySVC.VariablesDomainSVC))
|
||||
crossworkflow.SetDefaultSVC(workflowImpl.InitDomainService(primaryServices.workflowSVC.DomainSVC))
|
||||
crossconversation.SetDefaultSVC(conversationImpl.InitDomainService(complexServices.conversationSVC.ConversationDomainSVC))
|
||||
crossmessage.SetDefaultSVC(messageImpl.InitDomainService(complexServices.conversationSVC.MessageDomainSVC))
|
||||
crossagentrun.SetDefaultSVC(agentrunImpl.InitDomainService(complexServices.conversationSVC.AgentRunDomainSVC))
|
||||
crossagent.SetDefaultSVC(singleagentImpl.InitDomainService(complexServices.singleAgentSVC.DomainSVC))
|
||||
crossuser.SetDefaultSVC(crossuserImpl.InitDomainService(basicServices.userSVC.DomainSVC))
|
||||
crossdatacopy.SetDefaultSVC(dataCopyImpl.InitDomainService(basicServices.infra))
|
||||
crosssearch.SetDefaultSVC(searchImpl.InitDomainService(complexServices.searchSVC.DomainSVC))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func initEventBus(infra *appinfra.AppDependencies) *eventbusImpl {
|
||||
e := &eventbusImpl{}
|
||||
e.resourceEventBus = search.NewResourceEventBus(infra.ResourceEventProducer)
|
||||
e.projectEventBus = search.NewProjectEventBus(infra.AppEventProducer)
|
||||
|
||||
return e
|
||||
}
|
||||
|
||||
// initBasicServices init basic services that only depends on infra.
|
||||
func initBasicServices(ctx context.Context, infra *appinfra.AppDependencies, e *eventbusImpl) (*basicServices, error) {
|
||||
upload.InitService(infra.TOSClient, infra.CacheCli)
|
||||
openAuthSVC := openauth.InitService(infra.DB, infra.IDGenSVC)
|
||||
promptSVC := prompt.InitService(infra.DB, infra.IDGenSVC, e.resourceEventBus)
|
||||
modelMgrSVC, err := modelmgr.InitService(infra.DB, infra.IDGenSVC, infra.TOSClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
connectorSVC := connector.InitService(infra.TOSClient)
|
||||
userSVC := user.InitService(ctx, infra.DB, infra.TOSClient, infra.IDGenSVC)
|
||||
templateSVC := template.InitService(ctx, &template.ServiceComponents{
|
||||
DB: infra.DB,
|
||||
IDGen: infra.IDGenSVC,
|
||||
Storage: infra.TOSClient,
|
||||
})
|
||||
|
||||
return &basicServices{
|
||||
infra: infra,
|
||||
eventbus: e,
|
||||
modelMgrSVC: modelMgrSVC,
|
||||
connectorSVC: connectorSVC,
|
||||
userSVC: userSVC,
|
||||
promptSVC: promptSVC,
|
||||
templateSVC: templateSVC,
|
||||
openAuthSVC: openAuthSVC,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// initPrimaryServices init primary services that depends on basic services.
|
||||
func initPrimaryServices(ctx context.Context, basicServices *basicServices) (*primaryServices, error) {
|
||||
pluginSVC, err := plugin.InitService(ctx, basicServices.toPluginServiceComponents())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
memorySVC := memory.InitService(basicServices.toMemoryServiceComponents())
|
||||
|
||||
knowledgeSVC, err := knowledge.InitService(basicServices.toKnowledgeServiceComponents(memorySVC))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
workflowDomainSVC := workflow.InitService(
|
||||
basicServices.toWorkflowServiceComponents(pluginSVC, memorySVC, knowledgeSVC))
|
||||
|
||||
shortcutSVC := shortcutcmd.InitService(basicServices.infra.DB, basicServices.infra.IDGenSVC)
|
||||
|
||||
return &primaryServices{
|
||||
basicServices: basicServices,
|
||||
pluginSVC: pluginSVC,
|
||||
memorySVC: memorySVC,
|
||||
knowledgeSVC: knowledgeSVC,
|
||||
workflowSVC: workflowDomainSVC,
|
||||
shortcutSVC: shortcutSVC,
|
||||
infra: basicServices.infra,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// initComplexServices init complex services that depends on primary services.
|
||||
func initComplexServices(ctx context.Context, p *primaryServices) (*complexServices, error) {
|
||||
singleAgentSVC, err := singleagent.InitService(p.toSingleAgentServiceComponents())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
appSVC, err := app.InitService(p.toAPPServiceComponents())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
searchSVC, err := search.InitService(ctx, p.toSearchServiceComponents(singleAgentSVC, appSVC))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conversationSVC := conversation.InitService(p.toConversationComponents(singleAgentSVC))
|
||||
|
||||
return &complexServices{
|
||||
primaryServices: p,
|
||||
singleAgentSVC: singleAgentSVC,
|
||||
appSVC: appSVC,
|
||||
searchSVC: searchSVC,
|
||||
conversationSVC: conversationSVC,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *basicServices) toPluginServiceComponents() *plugin.ServiceComponents {
|
||||
return &plugin.ServiceComponents{
|
||||
IDGen: b.infra.IDGenSVC,
|
||||
DB: b.infra.DB,
|
||||
EventBus: b.eventbus.resourceEventBus,
|
||||
OSS: b.infra.TOSClient,
|
||||
UserSVC: b.userSVC.DomainSVC,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *basicServices) toKnowledgeServiceComponents(memoryService *memory.MemoryApplicationServices) *knowledge.ServiceComponents {
|
||||
return &knowledge.ServiceComponents{
|
||||
DB: b.infra.DB,
|
||||
IDGenSVC: b.infra.IDGenSVC,
|
||||
Storage: b.infra.TOSClient,
|
||||
RDB: memoryService.RDBDomainSVC,
|
||||
ImageX: b.infra.ImageXClient,
|
||||
ES: b.infra.ESClient,
|
||||
EventBus: b.eventbus.resourceEventBus,
|
||||
CacheCli: b.infra.CacheCli,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *basicServices) toMemoryServiceComponents() *memory.ServiceComponents {
|
||||
return &memory.ServiceComponents{
|
||||
IDGen: b.infra.IDGenSVC,
|
||||
DB: b.infra.DB,
|
||||
EventBus: b.eventbus.resourceEventBus,
|
||||
TosClient: b.infra.TOSClient,
|
||||
ResourceDomainNotifier: b.eventbus.resourceEventBus,
|
||||
CacheCli: b.infra.CacheCli,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *basicServices) toWorkflowServiceComponents(pluginSVC *plugin.PluginApplicationService, memorySVC *memory.MemoryApplicationServices, knowledgeSVC *knowledge.KnowledgeApplicationService) *workflow.ServiceComponents {
|
||||
return &workflow.ServiceComponents{
|
||||
IDGen: b.infra.IDGenSVC,
|
||||
DB: b.infra.DB,
|
||||
Cache: b.infra.CacheCli,
|
||||
Tos: b.infra.TOSClient,
|
||||
ImageX: b.infra.ImageXClient,
|
||||
DatabaseDomainSVC: memorySVC.DatabaseDomainSVC,
|
||||
VariablesDomainSVC: memorySVC.VariablesDomainSVC,
|
||||
PluginDomainSVC: pluginSVC.DomainSVC,
|
||||
KnowledgeDomainSVC: knowledgeSVC.DomainSVC,
|
||||
ModelManager: b.modelMgrSVC.DomainSVC,
|
||||
DomainNotifier: b.eventbus.resourceEventBus,
|
||||
CPStore: checkpoint.NewRedisStore(b.infra.CacheCli),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *primaryServices) toSingleAgentServiceComponents() *singleagent.ServiceComponents {
|
||||
|
||||
return &singleagent.ServiceComponents{
|
||||
IDGen: p.basicServices.infra.IDGenSVC,
|
||||
DB: p.basicServices.infra.DB,
|
||||
Cache: p.basicServices.infra.CacheCli,
|
||||
TosClient: p.basicServices.infra.TOSClient,
|
||||
ImageX: p.basicServices.infra.ImageXClient,
|
||||
ModelMgrDomainSVC: p.basicServices.modelMgrSVC.DomainSVC,
|
||||
UserDomainSVC: p.basicServices.userSVC.DomainSVC,
|
||||
EventBus: p.basicServices.eventbus.projectEventBus,
|
||||
DatabaseDomainSVC: p.memorySVC.DatabaseDomainSVC,
|
||||
ConnectorDomainSVC: p.basicServices.connectorSVC.DomainSVC,
|
||||
KnowledgeDomainSVC: p.knowledgeSVC.DomainSVC,
|
||||
PluginDomainSVC: p.pluginSVC.DomainSVC,
|
||||
WorkflowDomainSVC: p.workflowSVC.DomainSVC,
|
||||
VariablesDomainSVC: p.memorySVC.VariablesDomainSVC,
|
||||
ShortcutCMDDomainSVC: p.shortcutSVC.ShortCutDomainSVC,
|
||||
CPStore: checkpoint.NewRedisStore(p.infra.CacheCli),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *primaryServices) toSearchServiceComponents(singleAgentSVC *singleagent.SingleAgentApplicationService, appSVC *app.APPApplicationService) *search.ServiceComponents {
|
||||
infra := p.basicServices.infra
|
||||
|
||||
return &search.ServiceComponents{
|
||||
DB: infra.DB,
|
||||
Cache: infra.CacheCli,
|
||||
TOS: infra.TOSClient,
|
||||
ESClient: infra.ESClient,
|
||||
ProjectEventBus: p.basicServices.eventbus.projectEventBus,
|
||||
SingleAgentDomainSVC: singleAgentSVC.DomainSVC,
|
||||
APPDomainSVC: appSVC.DomainSVC,
|
||||
KnowledgeDomainSVC: p.knowledgeSVC.DomainSVC,
|
||||
PluginDomainSVC: p.pluginSVC.DomainSVC,
|
||||
WorkflowDomainSVC: p.workflowSVC.DomainSVC,
|
||||
UserDomainSVC: p.basicServices.userSVC.DomainSVC,
|
||||
ConnectorDomainSVC: p.basicServices.connectorSVC.DomainSVC,
|
||||
PromptDomainSVC: p.basicServices.promptSVC.DomainSVC,
|
||||
DatabaseDomainSVC: p.memorySVC.DatabaseDomainSVC,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *primaryServices) toAPPServiceComponents() *app.ServiceComponents {
|
||||
infra := p.basicServices.infra
|
||||
basic := p.basicServices
|
||||
return &app.ServiceComponents{
|
||||
IDGen: infra.IDGenSVC,
|
||||
DB: infra.DB,
|
||||
OSS: infra.TOSClient,
|
||||
CacheCli: infra.CacheCli,
|
||||
ProjectEventBus: basic.eventbus.projectEventBus,
|
||||
UserSVC: basic.userSVC.DomainSVC,
|
||||
ConnectorSVC: basic.connectorSVC.DomainSVC,
|
||||
VariablesSVC: p.memorySVC.VariablesDomainSVC,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *primaryServices) toConversationComponents(singleAgentSVC *singleagent.SingleAgentApplicationService) *conversation.ServiceComponents {
|
||||
infra := p.basicServices.infra
|
||||
|
||||
return &conversation.ServiceComponents{
|
||||
DB: infra.DB,
|
||||
IDGen: infra.IDGenSVC,
|
||||
TosClient: infra.TOSClient,
|
||||
ImageX: infra.ImageXClient,
|
||||
SingleAgentDomainSVC: singleAgentSVC.DomainSVC,
|
||||
}
|
||||
}
|
||||
132
backend/application/base/appinfra/app_infra.go
Normal file
132
backend/application/base/appinfra/app_infra.go
Normal file
@@ -0,0 +1,132 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package appinfra
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/es"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/imagex/veimagex"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/mysql"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
)
|
||||
|
||||
type AppDependencies struct {
|
||||
DB *gorm.DB
|
||||
CacheCli *redis.Client
|
||||
IDGenSVC idgen.IDGenerator
|
||||
ESClient es.Client
|
||||
ImageXClient imagex.ImageX
|
||||
TOSClient storage.Storage
|
||||
ResourceEventProducer eventbus.Producer
|
||||
AppEventProducer eventbus.Producer
|
||||
}
|
||||
|
||||
func Init(ctx context.Context) (*AppDependencies, error) {
|
||||
deps := &AppDependencies{}
|
||||
var err error
|
||||
|
||||
deps.DB, err = mysql.New()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
deps.CacheCli = redis.New()
|
||||
|
||||
deps.IDGenSVC, err = idgen.New(deps.CacheCli)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
deps.ESClient, err = es.New()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
deps.ImageXClient, err = initImageX(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
deps.TOSClient, err = initTOS(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
deps.ResourceEventProducer, err = initResourceEventBusProducer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
deps.AppEventProducer, err = initAppEventProducer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return deps, nil
|
||||
}
|
||||
|
||||
func initImageX(ctx context.Context) (imagex.ImageX, error) {
|
||||
|
||||
uploadComponentType := os.Getenv(consts.FileUploadComponentType)
|
||||
|
||||
if uploadComponentType != consts.FileUploadComponentTypeImagex {
|
||||
return storage.NewImagex(ctx)
|
||||
}
|
||||
return veimagex.New(
|
||||
os.Getenv(consts.VeImageXAK),
|
||||
os.Getenv(consts.VeImageXSK),
|
||||
os.Getenv(consts.VeImageXDomain),
|
||||
os.Getenv(consts.VeImageXUploadHost),
|
||||
os.Getenv(consts.VeImageXTemplate),
|
||||
[]string{os.Getenv(consts.VeImageXServerID)},
|
||||
)
|
||||
}
|
||||
|
||||
func initTOS(ctx context.Context) (storage.Storage, error) {
|
||||
return storage.New(ctx)
|
||||
}
|
||||
|
||||
func initResourceEventBusProducer() (eventbus.Producer, error) {
|
||||
nameServer := os.Getenv(consts.MQServer)
|
||||
resourceEventBusProducer, err := eventbus.NewProducer(nameServer,
|
||||
consts.RMQTopicResource, consts.RMQConsumeGroupResource, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init resource producer failed, err=%w", err)
|
||||
}
|
||||
|
||||
return resourceEventBusProducer, nil
|
||||
}
|
||||
|
||||
func initAppEventProducer() (eventbus.Producer, error) {
|
||||
nameServer := os.Getenv(consts.MQServer)
|
||||
appEventProducer, err := eventbus.NewProducer(nameServer, consts.RMQTopicApp, consts.RMQConsumeGroupApp, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init app producer failed, err=%w", err)
|
||||
}
|
||||
|
||||
return appEventProducer, nil
|
||||
}
|
||||
42
backend/application/base/ctxutil/api_auth.go
Normal file
42
backend/application/base/ctxutil/api_auth.go
Normal file
@@ -0,0 +1,42 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package ctxutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/openauth/openapiauth/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
)
|
||||
|
||||
func GetApiAuthFromCtx(ctx context.Context) *entity.ApiKey {
|
||||
data, ok := ctxcache.Get[*entity.ApiKey](ctx, consts.OpenapiAuthKeyInCtx)
|
||||
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func MustGetUIDFromApiAuthCtx(ctx context.Context) int64 {
|
||||
apiKeyInfo := GetApiAuthFromCtx(ctx)
|
||||
if apiKeyInfo == nil {
|
||||
panic("mustGetUIDFromApiAuthCtx: apiKeyInfo is nil")
|
||||
}
|
||||
return apiKeyInfo.UserID
|
||||
}
|
||||
35
backend/application/base/ctxutil/request.go
Normal file
35
backend/application/base/ctxutil/request.go
Normal file
@@ -0,0 +1,35 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package ctxutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
func GetRequestFullPathFromCtx(ctx context.Context) string {
|
||||
contextValue := ctx.Value("request.full_path")
|
||||
if contextValue == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
fullPath, ok := contextValue.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fullPath
|
||||
}
|
||||
52
backend/application/base/ctxutil/session.go
Normal file
52
backend/application/base/ctxutil/session.go
Normal file
@@ -0,0 +1,52 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package ctxutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/user/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
)
|
||||
|
||||
func GetUserSessionFromCtx(ctx context.Context) *entity.Session {
|
||||
data, ok := ctxcache.Get[*entity.Session](ctx, consts.SessionDataKeyInCtx)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
func MustGetUIDFromCtx(ctx context.Context) int64 {
|
||||
sessionData := GetUserSessionFromCtx(ctx)
|
||||
if sessionData == nil {
|
||||
panic("mustGetUIDFromCtx: sessionData is nil")
|
||||
}
|
||||
|
||||
return sessionData.UserID
|
||||
}
|
||||
|
||||
func GetUIDFromCtx(ctx context.Context) *int64 {
|
||||
sessionData := GetUserSessionFromCtx(ctx)
|
||||
if sessionData == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &sessionData.UserID
|
||||
}
|
||||
327
backend/application/base/pluginutil/api.go
Normal file
327
backend/application/base/pluginutil/api.go
Normal file
@@ -0,0 +1,327 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package pluginutil
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
func APIParamsToOpenapiOperation(reqParams, respParams []*common.APIParameter) (*openapi3.Operation, error) {
|
||||
op := &openapi3.Operation{}
|
||||
|
||||
hasSetReqBody := false
|
||||
hasSetParams := false
|
||||
|
||||
for _, apiParam := range reqParams {
|
||||
if apiParam.Location != common.ParameterLocation_Body {
|
||||
if !hasSetParams {
|
||||
hasSetParams = true
|
||||
op.Parameters = []*openapi3.ParameterRef{}
|
||||
}
|
||||
|
||||
_apiParam, err := toOpenapiParameter(apiParam)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
op.Parameters = append(op.Parameters, &openapi3.ParameterRef{
|
||||
Value: _apiParam,
|
||||
})
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
var mType *openapi3.MediaType
|
||||
if hasSetReqBody {
|
||||
mType = op.RequestBody.Value.Content[plugin.MediaTypeJson]
|
||||
} else {
|
||||
hasSetReqBody = true
|
||||
mType = &openapi3.MediaType{
|
||||
Schema: &openapi3.SchemaRef{
|
||||
Value: &openapi3.Schema{
|
||||
Type: openapi3.TypeObject,
|
||||
Properties: map[string]*openapi3.SchemaRef{},
|
||||
},
|
||||
},
|
||||
}
|
||||
op.RequestBody = &openapi3.RequestBodyRef{
|
||||
Value: &openapi3.RequestBody{
|
||||
Content: map[string]*openapi3.MediaType{
|
||||
plugin.MediaTypeJson: mType,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
_apiParam, err := toOpenapi3Schema(apiParam)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mType.Schema.Value.Properties[apiParam.Name] = &openapi3.SchemaRef{
|
||||
Value: _apiParam,
|
||||
}
|
||||
if apiParam.IsRequired {
|
||||
mType.Schema.Value.Required = append(mType.Schema.Value.Required, apiParam.Name)
|
||||
}
|
||||
}
|
||||
|
||||
hasSetRespBody := false
|
||||
|
||||
for _, apiParam := range respParams {
|
||||
if !hasSetRespBody {
|
||||
hasSetRespBody = true
|
||||
op.Responses = map[string]*openapi3.ResponseRef{
|
||||
strconv.Itoa(http.StatusOK): {
|
||||
Value: &openapi3.Response{
|
||||
Content: map[string]*openapi3.MediaType{
|
||||
plugin.MediaTypeJson: {
|
||||
Schema: &openapi3.SchemaRef{
|
||||
Value: &openapi3.Schema{
|
||||
Type: openapi3.TypeObject,
|
||||
Properties: map[string]*openapi3.SchemaRef{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
_apiParam, err := toOpenapi3Schema(apiParam)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, _ := op.Responses[strconv.Itoa(http.StatusOK)]
|
||||
mType, _ := resp.Value.Content[plugin.MediaTypeJson] // only support application/json
|
||||
mType.Schema.Value.Properties[apiParam.Name] = &openapi3.SchemaRef{
|
||||
Value: _apiParam,
|
||||
}
|
||||
|
||||
if apiParam.IsRequired {
|
||||
mType.Schema.Value.Required = append(mType.Schema.Value.Required, apiParam.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func toOpenapiParameter(apiParam *common.APIParameter) (*openapi3.Parameter, error) {
|
||||
paramType, ok := plugin.ToOpenapiParamType(apiParam.Type)
|
||||
if !ok {
|
||||
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
|
||||
errorx.KVf(errno.PluginMsgKey, "the type '%s' of field '%s' is invalid", apiParam.Type, apiParam.Name))
|
||||
}
|
||||
|
||||
if paramType == openapi3.TypeObject {
|
||||
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
|
||||
errorx.KVf(errno.PluginMsgKey, "the type of field '%s' cannot be 'object'", apiParam.Name))
|
||||
}
|
||||
|
||||
paramSchema := &openapi3.Schema{
|
||||
Type: paramType,
|
||||
Default: apiParam.GlobalDefault,
|
||||
Extensions: map[string]interface{}{
|
||||
plugin.APISchemaExtendGlobalDisable: apiParam.GlobalDisable,
|
||||
},
|
||||
}
|
||||
|
||||
if paramType == openapi3.TypeArray {
|
||||
if apiParam.Location == common.ParameterLocation_Path {
|
||||
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
|
||||
errorx.KVf(errno.PluginMsgKey, "the type of field '%s' cannot be 'array'", apiParam.Name))
|
||||
}
|
||||
if len(apiParam.SubParameters) == 0 {
|
||||
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
|
||||
errorx.KVf(errno.PluginMsgKey, "the sub parameters of field '%s' is required", apiParam.Name))
|
||||
}
|
||||
|
||||
arrayItem := apiParam.SubParameters[0]
|
||||
arrayItemType, ok := plugin.ToOpenapiParamType(arrayItem.Type)
|
||||
if !ok {
|
||||
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
|
||||
errorx.KVf(errno.PluginMsgKey, "the item type '%s' of field '%s' is invalid", arrayItemType, apiParam.Name))
|
||||
}
|
||||
|
||||
if arrayItemType == openapi3.TypeObject || arrayItemType == openapi3.TypeArray {
|
||||
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
|
||||
errorx.KVf(errno.PluginMsgKey, "the item type of field '%s' cannot be 'array' or 'object'", apiParam.Name))
|
||||
}
|
||||
|
||||
itemSchema := &openapi3.Schema{
|
||||
Type: arrayItemType,
|
||||
Description: arrayItem.Desc,
|
||||
Extensions: map[string]any{},
|
||||
}
|
||||
|
||||
if arrayItem.GetAssistType() > 0 {
|
||||
aType, ok := plugin.ToAPIAssistType(arrayItem.GetAssistType())
|
||||
if !ok {
|
||||
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
|
||||
errorx.KVf(errno.PluginMsgKey, "the assist type '%s' of field '%s' is invalid", arrayItem.GetAssistType(), apiParam.Name))
|
||||
}
|
||||
itemSchema.Extensions[plugin.APISchemaExtendAssistType] = aType
|
||||
format, ok := plugin.AssistTypeToFormat(aType)
|
||||
if !ok {
|
||||
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
|
||||
errorx.KVf(errno.PluginMsgKey, "the assist type '%s' of field '%s' is invalid", aType, apiParam.Name))
|
||||
}
|
||||
itemSchema.Format = format
|
||||
}
|
||||
|
||||
paramSchema.Items = &openapi3.SchemaRef{
|
||||
Value: itemSchema,
|
||||
}
|
||||
}
|
||||
|
||||
if apiParam.LocalDefault != nil && *apiParam.LocalDefault != "" {
|
||||
paramSchema.Default = apiParam.LocalDefault
|
||||
}
|
||||
if apiParam.LocalDisable {
|
||||
paramSchema.Extensions[plugin.APISchemaExtendLocalDisable] = true
|
||||
}
|
||||
if apiParam.VariableRef != nil && *apiParam.VariableRef != "" {
|
||||
paramSchema.Extensions[plugin.APISchemaExtendVariableRef] = apiParam.VariableRef
|
||||
}
|
||||
|
||||
if apiParam.GetAssistType() > 0 {
|
||||
aType, ok := plugin.ToAPIAssistType(apiParam.GetAssistType())
|
||||
if !ok {
|
||||
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
|
||||
errorx.KVf(errno.PluginMsgKey, "the assist type '%s' of field '%s' is invalid", apiParam.GetAssistType(), apiParam.Name))
|
||||
}
|
||||
paramSchema.Extensions[plugin.APISchemaExtendAssistType] = aType
|
||||
format, ok := plugin.AssistTypeToFormat(aType)
|
||||
if !ok {
|
||||
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
|
||||
errorx.KVf(errno.PluginMsgKey, "the assist type '%s' of field '%s' is invalid", aType, apiParam.Name))
|
||||
}
|
||||
paramSchema.Format = format
|
||||
}
|
||||
|
||||
loc, ok := plugin.ToHTTPParamLocation(apiParam.Location)
|
||||
if !ok {
|
||||
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
|
||||
errorx.KVf(errno.PluginMsgKey, "the location '%s' of field '%s' is invalid ", apiParam.Location, apiParam.Name))
|
||||
}
|
||||
|
||||
param := &openapi3.Parameter{
|
||||
Description: apiParam.Desc,
|
||||
Name: apiParam.Name,
|
||||
In: string(loc),
|
||||
Required: apiParam.IsRequired,
|
||||
Schema: &openapi3.SchemaRef{
|
||||
Value: paramSchema,
|
||||
},
|
||||
}
|
||||
|
||||
return param, nil
|
||||
}
|
||||
|
||||
func toOpenapi3Schema(apiParam *common.APIParameter) (*openapi3.Schema, error) {
|
||||
paramType, ok := plugin.ToOpenapiParamType(apiParam.Type)
|
||||
if !ok {
|
||||
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
|
||||
errorx.KVf(errno.PluginMsgKey, "the type '%s' of field '%s' is invalid", apiParam.Type, apiParam.Name))
|
||||
}
|
||||
|
||||
sc := &openapi3.Schema{
|
||||
Description: apiParam.Desc,
|
||||
Type: paramType,
|
||||
Default: apiParam.GlobalDefault,
|
||||
Extensions: map[string]interface{}{
|
||||
plugin.APISchemaExtendGlobalDisable: apiParam.GlobalDisable,
|
||||
},
|
||||
}
|
||||
if apiParam.LocalDefault != nil && *apiParam.LocalDefault != "" {
|
||||
sc.Default = apiParam.LocalDefault
|
||||
}
|
||||
if apiParam.LocalDisable {
|
||||
sc.Extensions[plugin.APISchemaExtendLocalDisable] = true
|
||||
}
|
||||
if apiParam.VariableRef != nil && *apiParam.VariableRef != "" {
|
||||
sc.Extensions[plugin.APISchemaExtendVariableRef] = apiParam.VariableRef
|
||||
}
|
||||
|
||||
if apiParam.GetAssistType() > 0 {
|
||||
aType, ok := plugin.ToAPIAssistType(apiParam.GetAssistType())
|
||||
if !ok {
|
||||
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
|
||||
errorx.KVf(errno.PluginMsgKey, "the assist type '%s' of field '%s' is invalid", apiParam.GetAssistType(), apiParam.Name))
|
||||
}
|
||||
sc.Extensions[plugin.APISchemaExtendAssistType] = aType
|
||||
format, ok := plugin.AssistTypeToFormat(aType)
|
||||
if !ok {
|
||||
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
|
||||
errorx.KVf(errno.PluginMsgKey, "the assist type '%s' of field '%s' is invalid", aType, apiParam.Name))
|
||||
}
|
||||
sc.Format = format
|
||||
}
|
||||
|
||||
switch paramType {
|
||||
case openapi3.TypeObject:
|
||||
sc.Properties = map[string]*openapi3.SchemaRef{}
|
||||
for _, subParam := range apiParam.SubParameters {
|
||||
_subParam, err := toOpenapi3Schema(subParam)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sc.Properties[subParam.Name] = &openapi3.SchemaRef{
|
||||
Value: _subParam,
|
||||
}
|
||||
if subParam.IsRequired {
|
||||
sc.Required = append(sc.Required, subParam.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return sc, nil
|
||||
|
||||
case openapi3.TypeArray:
|
||||
if len(apiParam.SubParameters) == 0 {
|
||||
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
|
||||
errorx.KVf(errno.PluginMsgKey, "the sub-parameters of field '%s' are required", apiParam.Name))
|
||||
}
|
||||
|
||||
arrayItem := apiParam.SubParameters[0]
|
||||
itemType, ok := plugin.ToOpenapiParamType(arrayItem.Type)
|
||||
if !ok {
|
||||
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
|
||||
errorx.KVf(errno.PluginMsgKey, "the item type '%s' of field '%s' is invalid", itemType, apiParam.Name))
|
||||
}
|
||||
|
||||
subParam, err := toOpenapi3Schema(arrayItem)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sc.Items = &openapi3.SchemaRef{
|
||||
Value: subParam,
|
||||
}
|
||||
|
||||
return sc, nil
|
||||
}
|
||||
|
||||
return sc, nil
|
||||
}
|
||||
41
backend/application/connector/connector.go
Normal file
41
backend/application/connector/connector.go
Normal file
@@ -0,0 +1,41 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package connector
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/connector/entity"
|
||||
connector "github.com/coze-dev/coze-studio/backend/domain/connector/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
)
|
||||
|
||||
type ConnectorApplicationService struct {
|
||||
DomainSVC connector.Connector
|
||||
}
|
||||
|
||||
var ConnectorApplicationSVC *ConnectorApplicationService
|
||||
|
||||
func New(domainSVC connector.Connector, tosClient storage.Storage) *ConnectorApplicationService {
|
||||
return &ConnectorApplicationService{
|
||||
DomainSVC: domainSVC,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ConnectorApplicationService) List(ctx context.Context) ([]*entity.Connector, error) {
|
||||
return c.DomainSVC.List(ctx)
|
||||
}
|
||||
29
backend/application/connector/init.go
Normal file
29
backend/application/connector/init.go
Normal file
@@ -0,0 +1,29 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package connector
|
||||
|
||||
import (
|
||||
connector "github.com/coze-dev/coze-studio/backend/domain/connector/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
)
|
||||
|
||||
func InitService(tos storage.Storage) *ConnectorApplicationService {
|
||||
connectorDomainSVC := connector.NewService(tos)
|
||||
ConnectorApplicationSVC = New(connectorDomainSVC, tos)
|
||||
|
||||
return ConnectorApplicationSVC
|
||||
}
|
||||
471
backend/application/conversation/agent_run.go
Normal file
471
backend/application/conversation/agent_run.go
Normal file
@@ -0,0 +1,471 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"strconv"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/message"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/run"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
|
||||
crossDomainMessage "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
saEntity "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
convEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
|
||||
msgEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
cmdEntity "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity"
|
||||
sseImpl "github.com/coze-dev/coze-studio/backend/infra/impl/sse"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
func (c *ConversationApplicationService) Run(ctx context.Context, sseSender *sseImpl.SSenderImpl, ar *run.AgentRunRequest) error {
|
||||
agentInfo, caErr := c.checkAgent(ctx, ar)
|
||||
if caErr != nil {
|
||||
logs.CtxErrorf(ctx, "checkAgent err:%v", caErr)
|
||||
return caErr
|
||||
}
|
||||
|
||||
userID := ctxutil.MustGetUIDFromCtx(ctx)
|
||||
conversationData, ccErr := c.checkConversation(ctx, ar, userID)
|
||||
|
||||
if ccErr != nil {
|
||||
logs.CtxErrorf(ctx, "checkConversation err:%v", ccErr)
|
||||
return ccErr
|
||||
}
|
||||
|
||||
if ar.RegenMessageID != nil && ptr.From(ar.RegenMessageID) > 0 {
|
||||
msgMeta, err := c.MessageDomainSVC.GetByID(ctx, ptr.From(ar.RegenMessageID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if msgMeta != nil {
|
||||
if msgMeta.UserID != conv.Int64ToStr(userID) {
|
||||
return errorx.New(errno.ErrConversationPermissionCode, errorx.KV("msg", "message not match"))
|
||||
}
|
||||
delErr := c.MessageDomainSVC.Delete(ctx, &msgEntity.DeleteMeta{
|
||||
RunIDs: []int64{msgMeta.RunID},
|
||||
})
|
||||
if delErr != nil {
|
||||
return delErr
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
var shortcutCmd *cmdEntity.ShortcutCmd
|
||||
if ar.GetShortcutCmdID() > 0 {
|
||||
cmdID := ar.GetShortcutCmdID()
|
||||
cmdMeta, err := c.ShortcutDomainSVC.GetByCmdID(ctx, cmdID, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
shortcutCmd = cmdMeta
|
||||
}
|
||||
|
||||
arr, err := c.buildAgentRunRequest(ctx, ar, userID, agentInfo.SpaceID, conversationData, shortcutCmd)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "buildAgentRunRequest err:%v", err)
|
||||
return err
|
||||
}
|
||||
streamer, err := c.AgentRunDomainSVC.AgentRun(ctx, arr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.pullStream(ctx, sseSender, streamer, ar)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ConversationApplicationService) pullStream(ctx context.Context, sseSender *sseImpl.SSenderImpl, arStream *schema.StreamReader[*entity.AgentRunResponse], req *run.AgentRunRequest) {
|
||||
var ackMessageInfo *entity.ChunkMessageItem
|
||||
for {
|
||||
chunk, recvErr := arStream.Recv()
|
||||
if recvErr != nil {
|
||||
if errors.Is(recvErr, io.EOF) {
|
||||
return
|
||||
}
|
||||
sseSender.Send(ctx, buildErrorEvent(errno.ErrConversationAgentRunError, recvErr.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
switch chunk.Event {
|
||||
case entity.RunEventCreated, entity.RunEventInProgress, entity.RunEventCompleted:
|
||||
case entity.RunEventError:
|
||||
id, err := c.GenID(ctx)
|
||||
if err != nil {
|
||||
sseSender.Send(ctx, buildErrorEvent(errno.ErrConversationAgentRunError, err.Error()))
|
||||
|
||||
} else {
|
||||
sseSender.Send(ctx, buildMessageChunkEvent(run.RunEventMessage, buildErrMsg(ackMessageInfo, chunk.Error, id)))
|
||||
}
|
||||
case entity.RunEventStreamDone:
|
||||
sseSender.Send(ctx, buildDoneEvent(run.RunEventDone))
|
||||
case entity.RunEventAck:
|
||||
ackMessageInfo = chunk.ChunkMessageItem
|
||||
sseSender.Send(ctx, buildMessageChunkEvent(run.RunEventMessage, buildARSM2Message(chunk, req)))
|
||||
case entity.RunEventMessageDelta, entity.RunEventMessageCompleted:
|
||||
sseSender.Send(ctx, buildMessageChunkEvent(run.RunEventMessage, buildARSM2Message(chunk, req)))
|
||||
default:
|
||||
logs.CtxErrorf(ctx, "unknown handler event:%v", chunk.Event)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func buildARSM2Message(chunk *entity.AgentRunResponse, req *run.AgentRunRequest) []byte {
|
||||
chunkMessageItem := chunk.ChunkMessageItem
|
||||
|
||||
chunkMessage := &run.RunStreamResponse{
|
||||
ConversationID: strconv.FormatInt(chunkMessageItem.ConversationID, 10),
|
||||
IsFinish: ptr.Of(chunk.ChunkMessageItem.IsFinish),
|
||||
Message: &message.ChatMessage{
|
||||
Role: string(chunkMessageItem.Role),
|
||||
ContentType: string(chunkMessageItem.ContentType),
|
||||
MessageID: strconv.FormatInt(chunkMessageItem.ID, 10),
|
||||
SectionID: strconv.FormatInt(chunkMessageItem.SectionID, 10),
|
||||
ContentTime: chunkMessageItem.CreatedAt,
|
||||
ExtraInfo: buildExt(chunkMessageItem.Ext),
|
||||
ReplyID: strconv.FormatInt(chunkMessageItem.ReplyID, 10),
|
||||
|
||||
Status: "",
|
||||
Type: string(chunkMessageItem.MessageType),
|
||||
Content: chunkMessageItem.Content,
|
||||
ReasoningContent: chunkMessageItem.ReasoningContent,
|
||||
RequiredAction: chunkMessageItem.RequiredAction,
|
||||
},
|
||||
Index: int32(chunkMessageItem.Index),
|
||||
SeqID: int32(chunkMessageItem.SeqID),
|
||||
}
|
||||
if chunkMessageItem.MessageType == crossDomainMessage.MessageTypeAck {
|
||||
chunkMessage.Message.Content = req.GetQuery()
|
||||
chunkMessage.Message.ContentType = req.GetContentType()
|
||||
chunkMessage.Message.ExtraInfo = &message.ExtraInfo{
|
||||
LocalMessageID: req.GetLocalMessageID(),
|
||||
}
|
||||
} else {
|
||||
chunkMessage.Message.ExtraInfo = buildExt(chunkMessageItem.Ext)
|
||||
chunkMessage.Message.SenderID = ptr.Of(strconv.FormatInt(chunkMessageItem.AgentID, 10))
|
||||
chunkMessage.Message.Content = chunkMessageItem.Content
|
||||
|
||||
if chunkMessageItem.MessageType == crossDomainMessage.MessageTypeKnowledge {
|
||||
chunkMessage.Message.Type = string(crossDomainMessage.MessageTypeVerbose)
|
||||
}
|
||||
}
|
||||
|
||||
if chunk.ChunkMessageItem.IsFinish && chunkMessageItem.MessageType == crossDomainMessage.MessageTypeAnswer {
|
||||
chunkMessage.Message.Content = ""
|
||||
chunkMessage.Message.ReasoningContent = ptr.Of("")
|
||||
}
|
||||
|
||||
mCM, _ := json.Marshal(chunkMessage)
|
||||
return mCM
|
||||
}
|
||||
|
||||
func buildExt(extra map[string]string) *message.ExtraInfo {
|
||||
if extra == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &message.ExtraInfo{
|
||||
InputTokens: extra["input_tokens"],
|
||||
OutputTokens: extra["output_tokens"],
|
||||
Token: extra["token"],
|
||||
PluginStatus: extra["plugin_status"],
|
||||
TimeCost: extra["time_cost"],
|
||||
WorkflowTokens: extra["workflow_tokens"],
|
||||
BotState: extra["bot_state"],
|
||||
PluginRequest: extra["plugin_request"],
|
||||
ToolName: extra["tool_name"],
|
||||
Plugin: extra["plugin"],
|
||||
MockHitInfo: extra["mock_hit_info"],
|
||||
MessageTitle: extra["message_title"],
|
||||
StreamPluginRunning: extra["stream_plugin_running"],
|
||||
ExecuteDisplayName: extra["execute_display_name"],
|
||||
TaskType: extra["task_type"],
|
||||
ReferFormat: extra["refer_format"],
|
||||
}
|
||||
}
|
||||
func buildErrMsg(ackChunk *entity.ChunkMessageItem, err *entity.RunError, id int64) []byte {
|
||||
|
||||
chunkMessage := &run.RunStreamResponse{
|
||||
IsFinish: ptr.Of(true),
|
||||
ConversationID: strconv.FormatInt(ackChunk.ConversationID, 10),
|
||||
Message: &message.ChatMessage{
|
||||
Role: string(schema.Assistant),
|
||||
ContentType: string(crossDomainMessage.ContentTypeText),
|
||||
Type: string(crossDomainMessage.MessageTypeAnswer),
|
||||
MessageID: strconv.FormatInt(id, 10),
|
||||
SectionID: strconv.FormatInt(ackChunk.SectionID, 10),
|
||||
ReplyID: strconv.FormatInt(ackChunk.ReplyID, 10),
|
||||
Content: "Something error:" + err.Msg,
|
||||
ExtraInfo: &message.ExtraInfo{},
|
||||
},
|
||||
}
|
||||
|
||||
mCM, _ := json.Marshal(chunkMessage)
|
||||
return mCM
|
||||
}
|
||||
|
||||
func (c *ConversationApplicationService) GenID(ctx context.Context) (int64, error) {
|
||||
id, err := c.appContext.IDGen.GenID(ctx)
|
||||
return id, err
|
||||
}
|
||||
|
||||
func (c *ConversationApplicationService) checkConversation(ctx context.Context, ar *run.AgentRunRequest, userID int64) (*convEntity.Conversation, error) {
|
||||
var conversationData *convEntity.Conversation
|
||||
if ar.ConversationID > 0 {
|
||||
|
||||
realCurrCon, err := c.ConversationDomainSVC.GetCurrentConversation(ctx, &convEntity.GetCurrent{
|
||||
UserID: userID,
|
||||
AgentID: ar.BotID,
|
||||
Scene: ptr.From(ar.Scene),
|
||||
ConnectorID: consts.CozeConnectorID,
|
||||
})
|
||||
logs.CtxInfof(ctx, "conversatioin data:%v", conv.DebugJsonToStr(realCurrCon))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if realCurrCon != nil {
|
||||
conversationData = realCurrCon
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if ar.ConversationID == 0 || conversationData == nil {
|
||||
|
||||
conData, err := c.ConversationDomainSVC.Create(ctx, &convEntity.CreateMeta{
|
||||
AgentID: ar.BotID,
|
||||
UserID: userID,
|
||||
Scene: ptr.From(ar.Scene),
|
||||
ConnectorID: consts.CozeConnectorID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logs.CtxInfof(ctx, "conversatioin create data:%v", conv.DebugJsonToStr(conData))
|
||||
conversationData = conData
|
||||
|
||||
ar.ConversationID = conversationData.ID
|
||||
}
|
||||
|
||||
if conversationData.CreatorID != userID {
|
||||
return nil, errorx.New(errno.ErrConversationPermissionCode, errorx.KV("msg", "conversation not match"))
|
||||
}
|
||||
|
||||
return conversationData, nil
|
||||
}
|
||||
|
||||
func (c *ConversationApplicationService) checkAgent(ctx context.Context, ar *run.AgentRunRequest) (*saEntity.SingleAgent, error) {
|
||||
agentInfo, err := c.appContext.SingleAgentDomainSVC.GetSingleAgent(ctx, ar.BotID, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if agentInfo == nil {
|
||||
return nil, errorx.New(errno.ErrAgentNotExists)
|
||||
}
|
||||
return agentInfo, nil
|
||||
}
|
||||
|
||||
func (c *ConversationApplicationService) buildAgentRunRequest(ctx context.Context, ar *run.AgentRunRequest, userID int64, spaceID int64, conversationData *convEntity.Conversation, shortcutCMD *cmdEntity.ShortcutCmd) (*entity.AgentRunMeta, error) {
|
||||
var contentType crossDomainMessage.ContentType
|
||||
contentType = crossDomainMessage.ContentTypeText
|
||||
|
||||
if ptr.From(ar.ContentType) != string(crossDomainMessage.ContentTypeText) {
|
||||
contentType = crossDomainMessage.ContentTypeMix
|
||||
}
|
||||
|
||||
shortcutCMDData, err := c.buildTools(ctx, ar.ToolList, shortcutCMD)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
arm := &entity.AgentRunMeta{
|
||||
ConversationID: conversationData.ID,
|
||||
AgentID: ar.BotID,
|
||||
Content: c.buildMultiContent(ctx, ar),
|
||||
DisplayContent: c.buildDisplayContent(ctx, ar),
|
||||
SpaceID: spaceID,
|
||||
UserID: conv.Int64ToStr(userID),
|
||||
SectionID: conversationData.SectionID,
|
||||
PreRetrieveTools: shortcutCMDData,
|
||||
IsDraft: ptr.From(ar.DraftMode),
|
||||
ConnectorID: consts.CozeConnectorID,
|
||||
ContentType: contentType,
|
||||
Ext: ar.Extra,
|
||||
}
|
||||
return arm, nil
|
||||
}
|
||||
|
||||
func (c *ConversationApplicationService) buildDisplayContent(ctx context.Context, ar *run.AgentRunRequest) string {
|
||||
if *ar.ContentType == run.ContentTypeText {
|
||||
return ""
|
||||
}
|
||||
return ar.Query
|
||||
}
|
||||
|
||||
func (c *ConversationApplicationService) buildTools(ctx context.Context, tools []*run.Tool, shortcutCMD *cmdEntity.ShortcutCmd) ([]*entity.Tool, error) {
|
||||
var ts []*entity.Tool
|
||||
for _, tool := range tools {
|
||||
if shortcutCMD != nil {
|
||||
|
||||
arguments := make(map[string]string)
|
||||
for key, parametersStruct := range tool.Parameters {
|
||||
if parametersStruct == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
arguments[key] = parametersStruct.Value
|
||||
// uri需要转换成url
|
||||
if parametersStruct.ResourceType == consts.ShortcutCommandResourceType {
|
||||
|
||||
resourceInfo, err := c.appContext.ImageX.GetResourceURL(ctx, parametersStruct.Value)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arguments[key] = resourceInfo.URL
|
||||
}
|
||||
}
|
||||
|
||||
argBytes, err := json.Marshal(arguments)
|
||||
if err == nil {
|
||||
ts = append(ts, &entity.Tool{
|
||||
PluginID: shortcutCMD.PluginID,
|
||||
Arguments: string(argBytes),
|
||||
ToolName: shortcutCMD.PluginToolName,
|
||||
ToolID: shortcutCMD.PluginToolID,
|
||||
Type: agentrun.ToolType(shortcutCMD.ToolType),
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return ts, nil
|
||||
}
|
||||
|
||||
func (c *ConversationApplicationService) buildMultiContent(ctx context.Context, ar *run.AgentRunRequest) []*crossDomainMessage.InputMetaData {
|
||||
var multiContents []*crossDomainMessage.InputMetaData
|
||||
|
||||
switch *ar.ContentType {
|
||||
case run.ContentTypeText:
|
||||
multiContents = append(multiContents, &crossDomainMessage.InputMetaData{
|
||||
Type: crossDomainMessage.InputTypeText,
|
||||
Text: ar.Query,
|
||||
})
|
||||
case run.ContentTypeImage, run.ContentTypeFile, run.ContentTypeMix, run.ContentTypeVideo, run.ContentTypeAudio:
|
||||
var mc *run.MixContentModel
|
||||
|
||||
err := json.Unmarshal([]byte(ar.Query), &mc)
|
||||
if err != nil {
|
||||
multiContents = append(multiContents, &crossDomainMessage.InputMetaData{
|
||||
Type: crossDomainMessage.InputTypeText,
|
||||
Text: ar.Query,
|
||||
})
|
||||
return multiContents
|
||||
}
|
||||
|
||||
mcContent, newItemList := c.parseMultiContent(ctx, mc.ItemList)
|
||||
|
||||
multiContents = append(multiContents, mcContent...)
|
||||
|
||||
mc.ItemList = newItemList
|
||||
mcByte, err := json.Marshal(mc)
|
||||
if err == nil {
|
||||
ar.Query = string(mcByte)
|
||||
}
|
||||
}
|
||||
|
||||
return multiContents
|
||||
}
|
||||
|
||||
func (c *ConversationApplicationService) parseMultiContent(ctx context.Context, mc []*run.Item) (multiContents []*crossDomainMessage.InputMetaData, mcNew []*run.Item) {
|
||||
for index, item := range mc {
|
||||
switch item.Type {
|
||||
case run.ContentTypeText:
|
||||
multiContents = append(multiContents, &crossDomainMessage.InputMetaData{
|
||||
Type: crossDomainMessage.InputTypeText,
|
||||
Text: item.Text,
|
||||
})
|
||||
case run.ContentTypeImage:
|
||||
|
||||
resourceUrl, err := c.getUrlByUri(ctx, item.Image.Key)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to unescape resource url, err is %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
mc[index].Image.ImageThumb.URL = resourceUrl
|
||||
mc[index].Image.ImageOri.URL = resourceUrl
|
||||
|
||||
multiContents = append(multiContents, &crossDomainMessage.InputMetaData{
|
||||
Type: crossDomainMessage.InputTypeImage,
|
||||
FileData: []*crossDomainMessage.FileData{
|
||||
{
|
||||
Url: resourceUrl,
|
||||
},
|
||||
},
|
||||
})
|
||||
case run.ContentTypeFile, run.ContentTypeAudio, run.ContentTypeVideo:
|
||||
|
||||
resourceUrl, err := c.getUrlByUri(ctx, item.File.FileKey)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
mc[index].File.FileURL = resourceUrl
|
||||
|
||||
multiContents = append(multiContents, &crossDomainMessage.InputMetaData{
|
||||
Type: crossDomainMessage.InputType(item.Type),
|
||||
FileData: []*crossDomainMessage.FileData{
|
||||
{
|
||||
Url: resourceUrl,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return multiContents, mc
|
||||
}
|
||||
|
||||
func (s *ConversationApplicationService) getUrlByUri(ctx context.Context, uri string) (string, error) {
|
||||
|
||||
url, err := s.appContext.ImageX.GetResourceURL(ctx, uri)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return url.URL, nil
|
||||
}
|
||||
51
backend/application/conversation/build_chunk_event.go
Normal file
51
backend/application/conversation/build_chunk_event.go
Normal file
@@ -0,0 +1,51 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/hertz-contrib/sse"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/run"
|
||||
)
|
||||
|
||||
func buildDoneEvent(event string) *sse.Event {
|
||||
return &sse.Event{
|
||||
Event: event,
|
||||
}
|
||||
}
|
||||
|
||||
func buildErrorEvent(errCode int64, errMsg string) *sse.Event {
|
||||
errData := run.ErrorData{
|
||||
Code: errCode,
|
||||
Msg: errMsg,
|
||||
}
|
||||
ed, _ := json.Marshal(errData)
|
||||
|
||||
return &sse.Event{
|
||||
Event: run.RunEventError,
|
||||
Data: ed,
|
||||
}
|
||||
}
|
||||
|
||||
func buildMessageChunkEvent(event string, chunkMsg []byte) *sse.Event {
|
||||
return &sse.Event{
|
||||
Event: event,
|
||||
Data: chunkMsg,
|
||||
}
|
||||
}
|
||||
188
backend/application/conversation/conversation.go
Normal file
188
backend/application/conversation/conversation.go
Normal file
@@ -0,0 +1,188 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
agentrun "github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
|
||||
conversationService "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/service"
|
||||
message "github.com/coze-dev/coze-studio/backend/domain/conversation/message/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/service"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
type ConversationApplicationService struct {
|
||||
appContext *ServiceComponents
|
||||
|
||||
AgentRunDomainSVC agentrun.Run
|
||||
ConversationDomainSVC conversationService.Conversation
|
||||
MessageDomainSVC message.Message
|
||||
|
||||
ShortcutDomainSVC service.ShortcutCmd
|
||||
}
|
||||
|
||||
var ConversationSVC = new(ConversationApplicationService)
|
||||
|
||||
type OpenapiAgentRunApplication struct {
|
||||
ShortcutDomainSVC service.ShortcutCmd
|
||||
}
|
||||
|
||||
var ConversationOpenAPISVC = new(OpenapiAgentRunApplication)
|
||||
|
||||
func (c *ConversationApplicationService) ClearHistory(ctx context.Context, req *conversation.ClearConversationHistoryRequest) (*conversation.ClearConversationHistoryResponse, error) {
|
||||
resp := new(conversation.ClearConversationHistoryResponse)
|
||||
|
||||
conversationID := req.ConversationID
|
||||
|
||||
// get conversation
|
||||
currentRes, err := c.ConversationDomainSVC.GetByID(ctx, conversationID)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
if currentRes == nil {
|
||||
return resp, errorx.New(errno.ErrConversationNotFound)
|
||||
}
|
||||
// check user
|
||||
userID := ctxutil.GetUIDFromCtx(ctx)
|
||||
if userID == nil || *userID != currentRes.CreatorID {
|
||||
return resp, errorx.New(errno.ErrConversationNotFound, errorx.KV("msg", "user not match"))
|
||||
}
|
||||
|
||||
// delete conversation
|
||||
err = c.ConversationDomainSVC.Delete(ctx, conversationID)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
// create new conversation
|
||||
convRes, err := c.ConversationDomainSVC.Create(ctx, &entity.CreateMeta{
|
||||
AgentID: currentRes.AgentID,
|
||||
UserID: currentRes.CreatorID,
|
||||
Scene: currentRes.Scene,
|
||||
ConnectorID: consts.CozeConnectorID,
|
||||
})
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
resp.NewSectionID = convRes.SectionID
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *ConversationApplicationService) CreateSection(ctx context.Context, conversationID int64) (int64, error) {
|
||||
currentRes, err := c.ConversationDomainSVC.GetByID(ctx, conversationID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if currentRes == nil {
|
||||
return 0, errorx.New(errno.ErrConversationNotFound, errorx.KV("msg", "conversation not found"))
|
||||
}
|
||||
var userID int64
|
||||
if currentRes.ConnectorID == consts.CozeConnectorID {
|
||||
userID = ctxutil.MustGetUIDFromCtx(ctx)
|
||||
} else {
|
||||
userID = ctxutil.MustGetUIDFromApiAuthCtx(ctx)
|
||||
}
|
||||
|
||||
if userID != currentRes.CreatorID {
|
||||
return 0, errorx.New(errno.ErrConversationNotFound, errorx.KV("msg", "user not match"))
|
||||
}
|
||||
|
||||
convRes, err := c.ConversationDomainSVC.NewConversationCtx(ctx, &entity.NewConversationCtxRequest{
|
||||
ID: conversationID,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return convRes.SectionID, nil
|
||||
}
|
||||
|
||||
func (c *ConversationApplicationService) CreateConversation(ctx context.Context, agentID int64, connectorID int64) (*conversation.CreateConversationResponse, error) {
|
||||
resp := new(conversation.CreateConversationResponse)
|
||||
apiKeyInfo := ctxutil.GetApiAuthFromCtx(ctx)
|
||||
userID := apiKeyInfo.UserID
|
||||
if connectorID != consts.WebSDKConnectorID {
|
||||
connectorID = apiKeyInfo.ConnectorID
|
||||
}
|
||||
|
||||
conversationData, err := c.ConversationDomainSVC.Create(ctx, &entity.CreateMeta{
|
||||
AgentID: agentID,
|
||||
UserID: userID,
|
||||
ConnectorID: connectorID,
|
||||
Scene: common.Scene_SceneOpenApi,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.ConversationData = &conversation.ConversationData{
|
||||
Id: conversationData.ID,
|
||||
LastSectionID: &conversationData.SectionID,
|
||||
ConnectorID: &conversationData.ConnectorID,
|
||||
CreatedAt: conversationData.CreatedAt,
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *ConversationApplicationService) ListConversation(ctx context.Context, req *conversation.ListConversationsApiRequest) (*conversation.ListConversationsApiResponse, error) {
|
||||
|
||||
resp := new(conversation.ListConversationsApiResponse)
|
||||
|
||||
apiKeyInfo := ctxutil.GetApiAuthFromCtx(ctx)
|
||||
userID := apiKeyInfo.UserID
|
||||
connectorID := apiKeyInfo.ConnectorID
|
||||
|
||||
if userID == 0 {
|
||||
return resp, errorx.New(errno.ErrConversationNotFound)
|
||||
}
|
||||
if ptr.From(req.ConnectorID) == consts.WebSDKConnectorID {
|
||||
connectorID = ptr.From(req.ConnectorID)
|
||||
}
|
||||
|
||||
conversationDOList, hasMore, err := c.ConversationDomainSVC.List(ctx, &entity.ListMeta{
|
||||
UserID: userID,
|
||||
AgentID: req.GetBotID(),
|
||||
ConnectorID: connectorID,
|
||||
Scene: common.Scene_SceneOpenApi,
|
||||
Page: int(req.GetPageNum()),
|
||||
Limit: int(req.GetPageSize()),
|
||||
})
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
conversationData := slices.Transform(conversationDOList, func(conv *entity.Conversation) *conversation.ConversationData {
|
||||
return &conversation.ConversationData{
|
||||
Id: conv.ID,
|
||||
LastSectionID: &conv.SectionID,
|
||||
ConnectorID: &conv.ConnectorID,
|
||||
CreatedAt: conv.CreatedAt,
|
||||
}
|
||||
})
|
||||
|
||||
resp.Data = &conversation.ListConversationData{
|
||||
Conversations: conversationData,
|
||||
HasMore: hasMore,
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
76
backend/application/conversation/init.go
Normal file
76
backend/application/conversation/init.go
Normal file
@@ -0,0 +1,76 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/application/singleagent"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/repository"
|
||||
agentrun "github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/service"
|
||||
convRepo "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/repository"
|
||||
conversation "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/service"
|
||||
msgRepo "github.com/coze-dev/coze-studio/backend/domain/conversation/message/repository"
|
||||
message "github.com/coze-dev/coze-studio/backend/domain/conversation/message/service"
|
||||
shortcutRepo "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/repository"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
)
|
||||
|
||||
type ServiceComponents struct {
|
||||
IDGen idgen.IDGenerator
|
||||
DB *gorm.DB
|
||||
TosClient storage.Storage
|
||||
ImageX imagex.ImageX
|
||||
|
||||
SingleAgentDomainSVC singleagent.SingleAgent
|
||||
}
|
||||
|
||||
func InitService(s *ServiceComponents) *ConversationApplicationService {
|
||||
mDomainComponents := &message.Components{
|
||||
MessageRepo: msgRepo.NewMessageRepo(s.DB, s.IDGen),
|
||||
}
|
||||
messageDomainSVC := message.NewService(mDomainComponents)
|
||||
|
||||
cDomainComponents := &conversation.Components{
|
||||
ConversationRepo: convRepo.NewConversationRepo(s.DB, s.IDGen),
|
||||
}
|
||||
|
||||
conversationDomainSVC := conversation.NewService(cDomainComponents)
|
||||
|
||||
arDomainComponents := &agentrun.Components{
|
||||
RunRecordRepo: repository.NewRunRecordRepo(s.DB, s.IDGen),
|
||||
}
|
||||
|
||||
agentRunDomainSVC := agentrun.NewService(arDomainComponents)
|
||||
components := &service.Components{
|
||||
ShortCutCmdRepo: shortcutRepo.NewShortCutCmdRepo(s.DB, s.IDGen),
|
||||
}
|
||||
shortcutCmdDomainSVC := service.NewShortcutCommandService(components)
|
||||
|
||||
ConversationSVC.AgentRunDomainSVC = agentRunDomainSVC
|
||||
ConversationSVC.MessageDomainSVC = messageDomainSVC
|
||||
ConversationSVC.ConversationDomainSVC = conversationDomainSVC
|
||||
ConversationSVC.appContext = s
|
||||
ConversationSVC.ShortcutDomainSVC = shortcutCmdDomainSVC
|
||||
|
||||
ConversationOpenAPISVC.ShortcutDomainSVC = shortcutCmdDomainSVC
|
||||
|
||||
return ConversationSVC
|
||||
}
|
||||
293
backend/application/conversation/message.go
Normal file
293
backend/application/conversation/message.go
Normal file
@@ -0,0 +1,293 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/message"
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
singleAgentEntity "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
convEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
func (c *ConversationApplicationService) GetMessageList(ctx context.Context, mr *message.GetMessageListRequest) (*message.GetMessageListResponse, error) {
|
||||
// Get Conversation ID by agent id & userID & scene
|
||||
userID := ctxutil.GetUIDFromCtx(ctx)
|
||||
|
||||
agentID, err := strconv.ParseInt(mr.BotID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
currentConversation, isNewCreate, err := c.getCurrentConversation(ctx, *userID, agentID, *mr.Scene, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if isNewCreate {
|
||||
return &message.GetMessageListResponse{
|
||||
MessageList: []*message.ChatMessage{},
|
||||
Cursor: mr.Cursor,
|
||||
NextCursor: "0",
|
||||
NextHasMore: false,
|
||||
ConversationID: strconv.FormatInt(currentConversation.ID, 10),
|
||||
LastSectionID: ptr.Of(strconv.FormatInt(currentConversation.SectionID, 10)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
cursor, err := strconv.ParseInt(mr.Cursor, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mListMessages, err := c.MessageDomainSVC.List(ctx, &entity.ListMeta{
|
||||
ConversationID: currentConversation.ID,
|
||||
AgentID: agentID,
|
||||
Limit: int(mr.Count),
|
||||
Cursor: cursor,
|
||||
Direction: loadDirectionToScrollDirection(mr.LoadDirection),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// get agent id
|
||||
var agentIDs []int64
|
||||
for _, mOne := range mListMessages.Messages {
|
||||
agentIDs = append(agentIDs, mOne.AgentID)
|
||||
}
|
||||
|
||||
agentInfo, err := c.buildAgentInfo(ctx, agentIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp := c.buildMessageListResponse(ctx, mListMessages, currentConversation)
|
||||
|
||||
resp.ParticipantInfoMap = map[string]*message.MsgParticipantInfo{}
|
||||
for _, aOne := range agentInfo {
|
||||
resp.ParticipantInfoMap[aOne.ID] = aOne
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (c *ConversationApplicationService) buildAgentInfo(ctx context.Context, agentIDs []int64) ([]*message.MsgParticipantInfo, error) {
|
||||
var result []*message.MsgParticipantInfo
|
||||
if len(agentIDs) > 0 {
|
||||
agentInfos, err := c.appContext.SingleAgentDomainSVC.MGetSingleAgentDraft(ctx, agentIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result = slices.Transform(agentInfos, func(a *singleAgentEntity.SingleAgent) *message.MsgParticipantInfo {
|
||||
return &message.MsgParticipantInfo{
|
||||
ID: strconv.FormatInt(a.AgentID, 10),
|
||||
Name: a.Name,
|
||||
UserID: strconv.FormatInt(a.CreatorID, 10),
|
||||
Desc: a.Desc,
|
||||
AvatarURL: a.IconURI,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *ConversationApplicationService) getCurrentConversation(ctx context.Context, userID int64, agentID int64, scene common.Scene, connectorID *int64) (*convEntity.Conversation, bool, error) {
|
||||
var currentConversation *convEntity.Conversation
|
||||
var isNewCreate bool
|
||||
|
||||
if connectorID == nil && scene == common.Scene_Playground {
|
||||
connectorID = ptr.Of(consts.CozeConnectorID)
|
||||
}
|
||||
|
||||
currentConversation, err := c.ConversationDomainSVC.GetCurrentConversation(ctx, &convEntity.GetCurrent{
|
||||
UserID: userID,
|
||||
Scene: scene,
|
||||
AgentID: agentID,
|
||||
ConnectorID: ptr.From(connectorID),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, isNewCreate, err
|
||||
}
|
||||
|
||||
if currentConversation == nil { // new conversation
|
||||
// create conversation
|
||||
ccNew, err := c.ConversationDomainSVC.Create(ctx, &convEntity.CreateMeta{
|
||||
AgentID: agentID,
|
||||
UserID: userID,
|
||||
Scene: scene,
|
||||
ConnectorID: ptr.From(connectorID),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, isNewCreate, err
|
||||
}
|
||||
if ccNew == nil {
|
||||
return nil, isNewCreate,
|
||||
errorx.New(errno.ErrConversationNotFound)
|
||||
}
|
||||
isNewCreate = true
|
||||
currentConversation = ccNew
|
||||
}
|
||||
|
||||
return currentConversation, isNewCreate, nil
|
||||
}
|
||||
|
||||
func loadDirectionToScrollDirection(direction *message.LoadDirection) entity.ScrollPageDirection {
|
||||
if direction != nil && *direction == message.LoadDirection_Next {
|
||||
return entity.ScrollPageDirectionNext
|
||||
}
|
||||
return entity.ScrollPageDirectionPrev
|
||||
}
|
||||
|
||||
func (c *ConversationApplicationService) buildMessageListResponse(ctx context.Context, mListMessages *entity.ListResult, currentConversation *convEntity.Conversation) *message.GetMessageListResponse {
|
||||
var messages []*message.ChatMessage
|
||||
runToQuestionIDMap := make(map[int64]int64)
|
||||
|
||||
for _, mMessage := range mListMessages.Messages {
|
||||
if mMessage.MessageType == model.MessageTypeQuestion {
|
||||
runToQuestionIDMap[mMessage.RunID] = mMessage.ID
|
||||
}
|
||||
}
|
||||
|
||||
for _, mMessage := range mListMessages.Messages {
|
||||
messages = append(messages, c.buildDomainMsg2VOMessage(ctx, mMessage, runToQuestionIDMap))
|
||||
}
|
||||
|
||||
resp := &message.GetMessageListResponse{
|
||||
MessageList: messages,
|
||||
Cursor: strconv.FormatInt(mListMessages.PrevCursor, 10),
|
||||
NextCursor: strconv.FormatInt(mListMessages.NextCursor, 10),
|
||||
ConversationID: strconv.FormatInt(currentConversation.ID, 10),
|
||||
LastSectionID: ptr.Of(strconv.FormatInt(currentConversation.SectionID, 10)),
|
||||
ConnectorConversationID: strconv.FormatInt(currentConversation.ID, 10),
|
||||
}
|
||||
|
||||
if mListMessages.Direction == entity.ScrollPageDirectionPrev {
|
||||
resp.Hasmore = mListMessages.HasMore
|
||||
} else {
|
||||
resp.NextHasMore = mListMessages.HasMore
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
func (c *ConversationApplicationService) buildDomainMsg2VOMessage(ctx context.Context, dm *entity.Message, runToQuestionIDMap map[int64]int64) *message.ChatMessage {
|
||||
cm := &message.ChatMessage{
|
||||
MessageID: strconv.FormatInt(dm.ID, 10),
|
||||
Role: string(dm.Role),
|
||||
Type: string(dm.MessageType),
|
||||
Content: dm.Content,
|
||||
ContentType: string(dm.ContentType),
|
||||
ReplyID: "0",
|
||||
SectionID: strconv.FormatInt(dm.SectionID, 10),
|
||||
ExtraInfo: buildDExt2ApiExt(dm.Ext),
|
||||
ContentTime: dm.CreatedAt,
|
||||
Status: "available",
|
||||
Source: 0,
|
||||
ReasoningContent: ptr.Of(dm.ReasoningContent),
|
||||
}
|
||||
|
||||
if dm.Status == model.MessageStatusBroken {
|
||||
cm.BrokenPos = ptr.Of(dm.Position)
|
||||
}
|
||||
|
||||
if dm.ContentType == model.ContentTypeMix && dm.DisplayContent != "" {
|
||||
cm.Content = dm.DisplayContent
|
||||
}
|
||||
|
||||
if dm.MessageType != model.MessageTypeQuestion {
|
||||
cm.ReplyID = strconv.FormatInt(runToQuestionIDMap[dm.RunID], 10)
|
||||
cm.SenderID = ptr.Of(strconv.FormatInt(dm.AgentID, 10))
|
||||
}
|
||||
return cm
|
||||
}
|
||||
|
||||
func buildDExt2ApiExt(extra map[string]string) *message.ExtraInfo {
|
||||
return &message.ExtraInfo{
|
||||
InputTokens: extra["input_tokens"],
|
||||
OutputTokens: extra["output_tokens"],
|
||||
Token: extra["token"],
|
||||
PluginStatus: extra["plugin_status"],
|
||||
TimeCost: extra["time_cost"],
|
||||
WorkflowTokens: extra["workflow_tokens"],
|
||||
BotState: extra["bot_state"],
|
||||
PluginRequest: extra["plugin_request"],
|
||||
ToolName: extra["tool_name"],
|
||||
Plugin: extra["plugin"],
|
||||
MockHitInfo: extra["mock_hit_info"],
|
||||
MessageTitle: extra["message_title"],
|
||||
StreamPluginRunning: extra["stream_plugin_running"],
|
||||
ExecuteDisplayName: extra["execute_display_name"],
|
||||
TaskType: extra["task_type"],
|
||||
ReferFormat: extra["refer_format"],
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ConversationApplicationService) DeleteMessage(ctx context.Context, mr *message.DeleteMessageRequest) (*message.DeleteMessageResponse, error) {
|
||||
resp := new(message.DeleteMessageResponse)
|
||||
messageInfo, err := c.MessageDomainSVC.GetByID(ctx, mr.MessageID)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
if messageInfo == nil {
|
||||
return resp, errorx.New(errno.ErrConversationMessageNotFound)
|
||||
}
|
||||
|
||||
userID := ctxutil.GetUIDFromCtx(ctx)
|
||||
if messageInfo.UserID != conv.Int64ToStr(*userID) {
|
||||
return resp, errorx.New(errno.ErrConversationPermissionCode, errorx.KV("msg", "permission denied"))
|
||||
}
|
||||
|
||||
err = c.AgentRunDomainSVC.Delete(ctx, []int64{messageInfo.RunID})
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
err = c.MessageDomainSVC.Delete(ctx, &entity.DeleteMeta{
|
||||
RunIDs: []int64{messageInfo.RunID},
|
||||
})
|
||||
if err != nil {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *ConversationApplicationService) BreakMessage(ctx context.Context, mr *message.BreakMessageRequest) (*message.BreakMessageResponse, error) {
|
||||
resp := new(message.BreakMessageResponse)
|
||||
|
||||
err := c.MessageDomainSVC.Broken(ctx, &entity.BrokenMeta{
|
||||
ID: *mr.AnswerMessageID,
|
||||
Position: mr.BrokenPos,
|
||||
})
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
329
backend/application/conversation/openapi_agent_run.go
Normal file
329
backend/application/conversation/openapi_agent_run.go
Normal file
@@ -0,0 +1,329 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"strconv"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/run"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
saEntity "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
convEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
|
||||
cmdEntity "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity"
|
||||
sseImpl "github.com/coze-dev/coze-studio/backend/infra/impl/sse"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
func (a *OpenapiAgentRunApplication) OpenapiAgentRun(ctx context.Context, sseSender *sseImpl.SSenderImpl, ar *run.ChatV3Request) error {
|
||||
|
||||
apiKeyInfo := ctxutil.GetApiAuthFromCtx(ctx)
|
||||
creatorID := apiKeyInfo.UserID
|
||||
connectorID := apiKeyInfo.ConnectorID
|
||||
|
||||
if ptr.From(ar.ConnectorID) == consts.WebSDKConnectorID {
|
||||
connectorID = ptr.From(ar.ConnectorID)
|
||||
}
|
||||
agentInfo, caErr := a.checkAgent(ctx, ar, connectorID)
|
||||
if caErr != nil {
|
||||
logs.CtxErrorf(ctx, "checkAgent err:%v", caErr)
|
||||
return caErr
|
||||
}
|
||||
|
||||
conversationData, ccErr := a.checkConversation(ctx, ar, creatorID, connectorID)
|
||||
if ccErr != nil {
|
||||
logs.CtxErrorf(ctx, "checkConversation err:%v", ccErr)
|
||||
return ccErr
|
||||
}
|
||||
|
||||
spaceID := agentInfo.SpaceID
|
||||
arr, err := a.buildAgentRunRequest(ctx, ar, connectorID, spaceID, conversationData)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "buildAgentRunRequest err:%v", err)
|
||||
return err
|
||||
}
|
||||
streamer, err := ConversationSVC.AgentRunDomainSVC.AgentRun(ctx, arr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
a.pullStream(ctx, sseSender, streamer)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *OpenapiAgentRunApplication) checkConversation(ctx context.Context, ar *run.ChatV3Request, userID int64, connectorID int64) (*convEntity.Conversation, error) {
|
||||
var conversationData *convEntity.Conversation
|
||||
if ptr.From(ar.ConversationID) > 0 {
|
||||
conData, err := ConversationSVC.ConversationDomainSVC.GetByID(ctx, ptr.From(ar.ConversationID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conversationData = conData
|
||||
}
|
||||
|
||||
if ptr.From(ar.ConversationID) == 0 || conversationData == nil {
|
||||
|
||||
conData, err := ConversationSVC.ConversationDomainSVC.Create(ctx, &convEntity.CreateMeta{
|
||||
AgentID: ar.BotID,
|
||||
UserID: userID,
|
||||
ConnectorID: connectorID,
|
||||
Scene: common.Scene_SceneOpenApi,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if conData == nil {
|
||||
return nil, errors.New("conversation data is nil")
|
||||
}
|
||||
conversationData = conData
|
||||
|
||||
ar.ConversationID = ptr.Of(conversationData.ID)
|
||||
}
|
||||
|
||||
if conversationData.CreatorID != userID {
|
||||
return nil, errors.New("conversation data not match")
|
||||
}
|
||||
|
||||
return conversationData, nil
|
||||
}
|
||||
|
||||
func (a *OpenapiAgentRunApplication) checkAgent(ctx context.Context, ar *run.ChatV3Request, connectorID int64) (*saEntity.SingleAgent, error) {
|
||||
agentInfo, err := ConversationSVC.appContext.SingleAgentDomainSVC.ObtainAgentByIdentity(ctx, &singleagent.AgentIdentity{
|
||||
AgentID: ar.BotID,
|
||||
IsDraft: false,
|
||||
ConnectorID: connectorID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if agentInfo == nil {
|
||||
return nil, errors.New("agent info is nil")
|
||||
}
|
||||
return agentInfo, nil
|
||||
}
|
||||
|
||||
func (a *OpenapiAgentRunApplication) buildAgentRunRequest(ctx context.Context, ar *run.ChatV3Request, connectorID int64, spaceID int64, conversationData *convEntity.Conversation) (*entity.AgentRunMeta, error) {
|
||||
|
||||
shortcutCMDData, err := a.buildTools(ctx, ar.ShortcutCommand)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
multiContent, contentType, err := a.buildMultiContent(ctx, ar)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
displayContent := a.buildDisplayContent(ctx, ar)
|
||||
arm := &entity.AgentRunMeta{
|
||||
ConversationID: ptr.From(ar.ConversationID),
|
||||
AgentID: ar.BotID,
|
||||
Content: multiContent,
|
||||
DisplayContent: displayContent,
|
||||
SpaceID: spaceID,
|
||||
UserID: ar.User,
|
||||
SectionID: conversationData.SectionID,
|
||||
PreRetrieveTools: shortcutCMDData,
|
||||
IsDraft: false,
|
||||
ConnectorID: connectorID,
|
||||
ContentType: contentType,
|
||||
Ext: ar.ExtraParams,
|
||||
}
|
||||
return arm, nil
|
||||
}
|
||||
|
||||
func (a *OpenapiAgentRunApplication) buildTools(ctx context.Context, shortcmd *run.ShortcutCommandDetail) ([]*entity.Tool, error) {
|
||||
var ts []*entity.Tool
|
||||
|
||||
if shortcmd == nil {
|
||||
return ts, nil
|
||||
}
|
||||
|
||||
var shortcutCMD *cmdEntity.ShortcutCmd
|
||||
cmdMeta, err := a.ShortcutDomainSVC.GetByCmdID(ctx, shortcmd.CommandID, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shortcutCMD = cmdMeta
|
||||
if shortcutCMD != nil {
|
||||
argBytes, err := json.Marshal(shortcmd.Parameters)
|
||||
if err == nil {
|
||||
ts = append(ts, &entity.Tool{
|
||||
PluginID: shortcutCMD.PluginID,
|
||||
Arguments: string(argBytes),
|
||||
ToolName: shortcutCMD.PluginToolName,
|
||||
ToolID: shortcutCMD.PluginToolID,
|
||||
Type: agentrun.ToolType(shortcutCMD.ToolType),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return ts, nil
|
||||
}
|
||||
|
||||
func (a *OpenapiAgentRunApplication) buildDisplayContent(_ context.Context, ar *run.ChatV3Request) string {
|
||||
for _, item := range ar.AdditionalMessages {
|
||||
if item.ContentType == run.ContentTypeMixApi {
|
||||
return item.Content
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *OpenapiAgentRunApplication) buildMultiContent(ctx context.Context, ar *run.ChatV3Request) ([]*message.InputMetaData, message.ContentType, error) {
|
||||
var multiContents []*message.InputMetaData
|
||||
contentType := message.ContentTypeText
|
||||
|
||||
for _, item := range ar.AdditionalMessages {
|
||||
if item == nil {
|
||||
continue
|
||||
}
|
||||
if item.Role != string(schema.User) {
|
||||
return nil, contentType, errors.New("role not match")
|
||||
}
|
||||
if item.ContentType == run.ContentTypeText {
|
||||
if item.Content == "" {
|
||||
continue
|
||||
}
|
||||
multiContents = append(multiContents, &message.InputMetaData{
|
||||
Type: message.InputTypeText,
|
||||
Text: item.Content,
|
||||
})
|
||||
}
|
||||
|
||||
if item.ContentType == run.ContentTypeMixApi {
|
||||
contentType = message.ContentTypeMix
|
||||
var inputs []*run.AdditionalContent
|
||||
err := json.Unmarshal([]byte(item.Content), &inputs)
|
||||
|
||||
logs.CtxInfof(ctx, "inputs:%v, err:%v", conv.DebugJsonToStr(inputs), err)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, one := range inputs {
|
||||
if one == nil {
|
||||
continue
|
||||
}
|
||||
switch message.InputType(one.Type) {
|
||||
case message.InputTypeText:
|
||||
multiContents = append(multiContents, &message.InputMetaData{
|
||||
Type: message.InputTypeText,
|
||||
Text: ptr.From(one.Text),
|
||||
})
|
||||
case message.InputTypeImage, message.InputTypeFile:
|
||||
multiContents = append(multiContents, &message.InputMetaData{
|
||||
Type: message.InputType(one.Type),
|
||||
FileData: []*message.FileData{
|
||||
{
|
||||
Url: one.GetFileURL(),
|
||||
},
|
||||
},
|
||||
})
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return multiContents, contentType, nil
|
||||
}
|
||||
|
||||
func (a *OpenapiAgentRunApplication) pullStream(ctx context.Context, sseSender *sseImpl.SSenderImpl, streamer *schema.StreamReader[*entity.AgentRunResponse]) {
|
||||
for {
|
||||
chunk, recvErr := streamer.Recv()
|
||||
logs.CtxInfof(ctx, "chunk :%v, err:%v", conv.DebugJsonToStr(chunk), recvErr)
|
||||
if recvErr != nil {
|
||||
if errors.Is(recvErr, io.EOF) {
|
||||
return
|
||||
}
|
||||
sseSender.Send(ctx, buildErrorEvent(errno.ErrConversationAgentRunError, recvErr.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
switch chunk.Event {
|
||||
|
||||
case entity.RunEventError:
|
||||
sseSender.Send(ctx, buildErrorEvent(chunk.Error.Code, chunk.Error.Msg))
|
||||
case entity.RunEventStreamDone:
|
||||
sseSender.Send(ctx, buildDoneEvent(string(entity.RunEventStreamDone)))
|
||||
case entity.RunEventAck:
|
||||
case entity.RunEventCreated, entity.RunEventCancelled, entity.RunEventInProgress, entity.RunEventFailed, entity.RunEventCompleted:
|
||||
sseSender.Send(ctx, buildMessageChunkEvent(string(chunk.Event), buildARSM2ApiChatMessage(chunk)))
|
||||
case entity.RunEventMessageDelta, entity.RunEventMessageCompleted:
|
||||
sseSender.Send(ctx, buildMessageChunkEvent(string(chunk.Event), buildARSM2ApiMessage(chunk)))
|
||||
|
||||
default:
|
||||
logs.CtxErrorf(ctx, "unknow handler event:%v", chunk.Event)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildARSM2ApiMessage(chunk *entity.AgentRunResponse) []byte {
|
||||
chunkMessageItem := chunk.ChunkMessageItem
|
||||
chunkMessage := &run.ChatV3MessageDetail{
|
||||
ID: strconv.FormatInt(chunkMessageItem.ID, 10),
|
||||
ConversationID: strconv.FormatInt(chunkMessageItem.ConversationID, 10),
|
||||
BotID: strconv.FormatInt(chunkMessageItem.AgentID, 10),
|
||||
Role: string(chunkMessageItem.Role),
|
||||
Type: string(chunkMessageItem.MessageType),
|
||||
Content: chunkMessageItem.Content,
|
||||
ContentType: string(chunkMessageItem.ContentType),
|
||||
MetaData: chunkMessageItem.Ext,
|
||||
ChatID: strconv.FormatInt(chunkMessageItem.RunID, 10),
|
||||
ReasoningContent: chunkMessageItem.ReasoningContent,
|
||||
}
|
||||
|
||||
mCM, _ := json.Marshal(chunkMessage)
|
||||
return mCM
|
||||
}
|
||||
|
||||
func buildARSM2ApiChatMessage(chunk *entity.AgentRunResponse) []byte {
|
||||
chunkRunItem := chunk.ChunkRunItem
|
||||
chunkMessage := &run.ChatV3ChatDetail{
|
||||
ID: chunkRunItem.ID,
|
||||
ConversationID: chunkRunItem.ConversationID,
|
||||
BotID: chunkRunItem.AgentID,
|
||||
Status: string(chunkRunItem.Status),
|
||||
SectionID: ptr.Of(chunkRunItem.SectionID),
|
||||
CreatedAt: ptr.Of(int32(chunkRunItem.CreatedAt / 1000)),
|
||||
CompletedAt: ptr.Of(int32(chunkRunItem.CompletedAt / 1000)),
|
||||
FailedAt: ptr.Of(int32(chunkRunItem.FailedAt / 1000)),
|
||||
}
|
||||
if chunkRunItem.Usage != nil {
|
||||
chunkMessage.Usage = &run.Usage{
|
||||
TokenCount: ptr.Of(int32(chunkRunItem.Usage.LlmTotalTokens)),
|
||||
InputTokens: ptr.Of(int32(chunkRunItem.Usage.LlmPromptTokens)),
|
||||
OutputTokens: ptr.Of(int32(chunkRunItem.Usage.LlmCompletionTokens)),
|
||||
}
|
||||
}
|
||||
mCM, _ := json.Marshal(chunkMessage)
|
||||
return mCM
|
||||
}
|
||||
133
backend/application/conversation/openapi_message.go
Normal file
133
backend/application/conversation/openapi_message.go
Normal file
@@ -0,0 +1,133 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/message"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/run"
|
||||
message3 "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
convEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
type OpenapiMessageApplication struct{}
|
||||
|
||||
var OpenapiMessageApplicationService = new(OpenapiMessageApplication)
|
||||
|
||||
func (m *OpenapiMessageApplication) GetApiMessageList(ctx context.Context, mr *message.ListMessageApiRequest) (*message.ListMessageApiResponse, error) {
|
||||
// Get Conversation ID by agent id & userID & scene
|
||||
userID := ctxutil.MustGetUIDFromApiAuthCtx(ctx)
|
||||
|
||||
currentConversation, err := getConversation(ctx, mr.ConversationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if currentConversation == nil {
|
||||
return nil, errorx.New(errno.ErrConversationNotFound)
|
||||
}
|
||||
|
||||
if currentConversation.CreatorID != userID {
|
||||
return nil, errorx.New(errno.ErrConversationPermissionCode, errorx.KV("msg", "permission denied"))
|
||||
}
|
||||
|
||||
msgListMeta := &entity.ListMeta{
|
||||
ConversationID: currentConversation.ID,
|
||||
AgentID: currentConversation.AgentID,
|
||||
Limit: int(ptr.From(mr.Limit)),
|
||||
}
|
||||
|
||||
if mr.BeforeID != nil {
|
||||
msgListMeta.Direction = entity.ScrollPageDirectionPrev
|
||||
msgListMeta.Cursor = *mr.BeforeID
|
||||
} else {
|
||||
msgListMeta.Direction = entity.ScrollPageDirectionNext
|
||||
msgListMeta.Cursor = ptr.From(mr.AfterID)
|
||||
}
|
||||
if mr.Order == nil {
|
||||
msgListMeta.OrderBy = ptr.Of(message.OrderByDesc)
|
||||
} else {
|
||||
msgListMeta.OrderBy = mr.Order
|
||||
}
|
||||
|
||||
mListMessages, err := ConversationSVC.MessageDomainSVC.List(ctx, msgListMeta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// get agent id
|
||||
var agentIDs []int64
|
||||
for _, mOne := range mListMessages.Messages {
|
||||
agentIDs = append(agentIDs, mOne.AgentID)
|
||||
}
|
||||
|
||||
resp := m.buildMessageListResponse(ctx, mListMessages, currentConversation)
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func getConversation(ctx context.Context, conversationID int64) (*convEntity.Conversation, error) {
|
||||
conversationInfo, err := ConversationSVC.ConversationDomainSVC.GetByID(ctx, conversationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conversationInfo, nil
|
||||
}
|
||||
|
||||
func (m *OpenapiMessageApplication) buildMessageListResponse(ctx context.Context, mListMessages *entity.ListResult, currentConversation *convEntity.Conversation) *message.ListMessageApiResponse {
|
||||
messagesVO := slices.Transform(mListMessages.Messages, func(dm *entity.Message) *message.OpenMessageApi {
|
||||
|
||||
content := dm.Content
|
||||
|
||||
msg := &message.OpenMessageApi{
|
||||
ID: dm.ID,
|
||||
ConversationID: dm.ConversationID,
|
||||
BotID: dm.AgentID,
|
||||
Role: string(dm.Role),
|
||||
Type: string(dm.MessageType),
|
||||
Content: content,
|
||||
ContentType: string(dm.ContentType),
|
||||
SectionID: strconv.FormatInt(dm.SectionID, 10),
|
||||
CreatedAt: dm.CreatedAt,
|
||||
UpdatedAt: dm.UpdatedAt,
|
||||
ChatID: dm.RunID,
|
||||
MetaData: dm.Ext,
|
||||
}
|
||||
if dm.ContentType == message3.ContentTypeMix && dm.DisplayContent != "" {
|
||||
msg.Content = dm.DisplayContent
|
||||
msg.ContentType = run.ContentTypeMixApi
|
||||
}
|
||||
return msg
|
||||
})
|
||||
|
||||
resp := &message.ListMessageApiResponse{
|
||||
Messages: messagesVO,
|
||||
HasMore: ptr.Of(mListMessages.HasMore),
|
||||
FirstID: ptr.Of(mListMessages.PrevCursor),
|
||||
LastID: ptr.Of(mListMessages.NextCursor),
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
780
backend/application/knowledge/convertor.go
Normal file
780
backend/application/knowledge/convertor.go
Normal file
@@ -0,0 +1,780 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
modelCommon "github.com/coze-dev/coze-studio/backend/api/model/common"
|
||||
knowledgeModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/flow/dataengine/dataset"
|
||||
"github.com/coze-dev/coze-studio/backend/application/upload"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
func assertValAs(typ document.TableColumnType, val string) (*document.ColumnData, error) {
|
||||
cd := &document.ColumnData{
|
||||
Type: typ,
|
||||
}
|
||||
if val == "" {
|
||||
return cd, nil
|
||||
}
|
||||
switch typ {
|
||||
case document.TableColumnTypeString:
|
||||
return &document.ColumnData{
|
||||
Type: document.TableColumnTypeString,
|
||||
ValString: &val,
|
||||
}, nil
|
||||
|
||||
case document.TableColumnTypeInteger:
|
||||
i, err := strconv.ParseInt(val, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &document.ColumnData{
|
||||
Type: document.TableColumnTypeInteger,
|
||||
ValInteger: &i,
|
||||
}, nil
|
||||
|
||||
case document.TableColumnTypeTime:
|
||||
// 支持时间戳和时间字符串
|
||||
i, err := strconv.ParseInt(val, 10, 64)
|
||||
if err == nil {
|
||||
t := time.Unix(i, 0)
|
||||
return &document.ColumnData{
|
||||
Type: document.TableColumnTypeTime,
|
||||
ValTime: &t,
|
||||
}, nil
|
||||
|
||||
}
|
||||
t, err := time.Parse(time.DateTime, val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &document.ColumnData{
|
||||
Type: document.TableColumnTypeTime,
|
||||
ValTime: &t,
|
||||
}, nil
|
||||
|
||||
case document.TableColumnTypeNumber:
|
||||
f, err := strconv.ParseFloat(val, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &document.ColumnData{
|
||||
Type: document.TableColumnTypeNumber,
|
||||
ValNumber: &f,
|
||||
}, nil
|
||||
|
||||
case document.TableColumnTypeBoolean:
|
||||
t, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &document.ColumnData{
|
||||
Type: document.TableColumnTypeBoolean,
|
||||
ValBoolean: &t,
|
||||
}, nil
|
||||
case document.TableColumnTypeImage:
|
||||
return &document.ColumnData{
|
||||
Type: document.TableColumnTypeImage,
|
||||
ValImage: &val,
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("[assertValAs] type not support, type=%d, val=%s", typ, val)
|
||||
}
|
||||
}
|
||||
|
||||
func convertTableDataType2Entity(t dataset.TableDataType) service.TableDataType {
|
||||
switch t {
|
||||
case dataset.TableDataType_AllData:
|
||||
return service.AllData
|
||||
case dataset.TableDataType_OnlySchema:
|
||||
return service.OnlySchema
|
||||
case dataset.TableDataType_OnlyPreview:
|
||||
return service.OnlyPreview
|
||||
default:
|
||||
return service.AllData
|
||||
}
|
||||
}
|
||||
|
||||
func convertTableSheet2Entity(sheet *dataset.TableSheet) *entity.TableSheet {
|
||||
if sheet == nil {
|
||||
return nil
|
||||
}
|
||||
return &entity.TableSheet{
|
||||
SheetId: sheet.GetSheetID(),
|
||||
StartLineIdx: sheet.GetStartLineIdx(),
|
||||
HeaderLineIdx: sheet.GetHeaderLineIdx(),
|
||||
}
|
||||
}
|
||||
|
||||
func convertDocTableSheet2Model(sheet entity.TableSheet) *dataset.DocTableSheet {
|
||||
return &dataset.DocTableSheet{
|
||||
ID: sheet.SheetId,
|
||||
SheetName: sheet.SheetName,
|
||||
TotalRow: sheet.TotalRows,
|
||||
}
|
||||
}
|
||||
|
||||
func convertTableMeta(t []*entity.TableColumn) []*modelCommon.DocTableColumn {
|
||||
if len(t) == 0 {
|
||||
return nil
|
||||
}
|
||||
resp := make([]*modelCommon.DocTableColumn, 0)
|
||||
for i := range t {
|
||||
if t[i] == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
resp = append(resp, &modelCommon.DocTableColumn{
|
||||
ID: t[i].ID,
|
||||
ColumnName: t[i].Name,
|
||||
IsSemantic: t[i].Indexing,
|
||||
Desc: &t[i].Description,
|
||||
Sequence: t[i].Sequence,
|
||||
ColumnType: convertColumnType(t[i].Type),
|
||||
})
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func convertColumnType(t document.TableColumnType) *modelCommon.ColumnType {
|
||||
switch t {
|
||||
case document.TableColumnTypeString:
|
||||
return modelCommon.ColumnTypePtr(modelCommon.ColumnType_Text)
|
||||
case document.TableColumnTypeBoolean:
|
||||
return modelCommon.ColumnTypePtr(modelCommon.ColumnType_Boolean)
|
||||
case document.TableColumnTypeNumber:
|
||||
return modelCommon.ColumnTypePtr(modelCommon.ColumnType_Float)
|
||||
case document.TableColumnTypeTime:
|
||||
return modelCommon.ColumnTypePtr(modelCommon.ColumnType_Date)
|
||||
case document.TableColumnTypeInteger:
|
||||
return modelCommon.ColumnTypePtr(modelCommon.ColumnType_Number)
|
||||
case document.TableColumnTypeImage:
|
||||
return modelCommon.ColumnTypePtr(modelCommon.ColumnType_Image)
|
||||
default:
|
||||
return modelCommon.ColumnTypePtr(modelCommon.ColumnType_Text)
|
||||
}
|
||||
}
|
||||
|
||||
func convertDocTableSheet(t *entity.TableSheet) *modelCommon.DocTableSheet {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
return &modelCommon.DocTableSheet{
|
||||
ID: t.SheetId,
|
||||
SheetName: t.SheetName,
|
||||
TotalRow: t.TotalRows,
|
||||
}
|
||||
}
|
||||
|
||||
func convertSlice2Model(sliceEntity *entity.Slice) *dataset.SliceInfo {
|
||||
if sliceEntity == nil {
|
||||
return nil
|
||||
}
|
||||
return &dataset.SliceInfo{
|
||||
SliceID: sliceEntity.ID,
|
||||
Content: convertSliceContent(sliceEntity),
|
||||
Status: convertSliceStatus2Model(sliceEntity.SliceStatus),
|
||||
HitCount: sliceEntity.Hit,
|
||||
CharCount: sliceEntity.CharCount,
|
||||
Sequence: sliceEntity.Sequence,
|
||||
DocumentID: sliceEntity.DocumentID,
|
||||
ChunkInfo: "",
|
||||
}
|
||||
}
|
||||
|
||||
func convertSliceContent(s *entity.Slice) string {
|
||||
if len(s.RawContent) == 0 {
|
||||
return ""
|
||||
}
|
||||
if s.RawContent[0].Type == knowledgeModel.SliceContentTypeTable {
|
||||
tableData := make([]sliceContentData, 0, len(s.RawContent[0].Table.Columns))
|
||||
for _, col := range s.RawContent[0].Table.Columns {
|
||||
tableData = append(tableData, sliceContentData{
|
||||
ColumnID: strconv.FormatInt(col.ColumnID, 10),
|
||||
ColumnName: col.ColumnName,
|
||||
Value: col.GetNullableStringValue(),
|
||||
Desc: "",
|
||||
})
|
||||
}
|
||||
b, _ := json.Marshal(tableData)
|
||||
return string(b)
|
||||
}
|
||||
return s.GetSliceContent()
|
||||
}
|
||||
|
||||
type sliceContentData struct {
|
||||
ColumnID string `json:"column_id"`
|
||||
ColumnName string `json:"column_name"`
|
||||
Value string `json:"value"`
|
||||
Desc string `json:"desc"`
|
||||
}
|
||||
|
||||
func convertSliceStatus2Model(status knowledgeModel.SliceStatus) dataset.SliceStatus {
|
||||
switch status {
|
||||
case knowledgeModel.SliceStatusInit:
|
||||
return dataset.SliceStatus_PendingVectoring
|
||||
case knowledgeModel.SliceStatusFinishStore:
|
||||
return dataset.SliceStatus_FinishVectoring
|
||||
case knowledgeModel.SliceStatusFailed:
|
||||
return dataset.SliceStatus_Deactive
|
||||
default:
|
||||
return dataset.SliceStatus_PendingVectoring
|
||||
}
|
||||
}
|
||||
func convertFilterStrategy2Model(strategy *entity.ParsingStrategy) *dataset.FilterStrategy {
|
||||
if strategy == nil {
|
||||
return nil
|
||||
}
|
||||
if len(strategy.FilterPages) != 0 {
|
||||
return &dataset.FilterStrategy{
|
||||
FilterPage: slices.Transform(strategy.FilterPages, func(page int) int32 {
|
||||
return int32(page)
|
||||
}),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func convertDocument2Model(documentEntity *entity.Document) *dataset.DocumentInfo {
|
||||
if documentEntity == nil {
|
||||
return nil
|
||||
}
|
||||
chunkStrategy := convertChunkingStrategy2Model(documentEntity.ChunkingStrategy)
|
||||
filterStrategy := convertFilterStrategy2Model(documentEntity.ParsingStrategy)
|
||||
parseStrategy, _ := convertParsingStrategy2Model(documentEntity.ParsingStrategy)
|
||||
docInfo := &dataset.DocumentInfo{
|
||||
Name: documentEntity.Name,
|
||||
DocumentID: documentEntity.ID,
|
||||
TosURI: &documentEntity.URI,
|
||||
CreateTime: int32(documentEntity.CreatedAtMs / 1000),
|
||||
UpdateTime: int32(documentEntity.UpdatedAtMs / 1000),
|
||||
CreatorID: ptr.Of(documentEntity.CreatorID),
|
||||
SliceCount: int32(documentEntity.SliceCount),
|
||||
Type: string(documentEntity.FileExtension),
|
||||
Size: int32(documentEntity.Size),
|
||||
CharCount: int32(documentEntity.CharCount),
|
||||
Status: convertDocumentStatus2Model(documentEntity.Status),
|
||||
HitCount: int32(documentEntity.Hits),
|
||||
SourceType: convertDocumentSource2Model(documentEntity.Source),
|
||||
FormatType: convertDocumentTypeEntity2Dataset(documentEntity.Type),
|
||||
WebURL: &documentEntity.URL,
|
||||
TableMeta: convertTableColumns2Model(documentEntity.TableInfo.Columns),
|
||||
StatusDescript: &documentEntity.StatusMsg,
|
||||
SpaceID: ptr.Of(documentEntity.SpaceID),
|
||||
EditableAppendContent: nil,
|
||||
FilterStrategy: filterStrategy,
|
||||
PreviewTosURL: &documentEntity.URL,
|
||||
ChunkStrategy: chunkStrategy,
|
||||
ParsingStrategy: parseStrategy,
|
||||
}
|
||||
return docInfo
|
||||
}
|
||||
|
||||
func convertDocumentSource2Entity(sourceType dataset.DocumentSource) entity.DocumentSource {
|
||||
switch sourceType {
|
||||
case dataset.DocumentSource_Custom:
|
||||
return entity.DocumentSourceCustom
|
||||
case dataset.DocumentSource_Document:
|
||||
return entity.DocumentSourceLocal
|
||||
default:
|
||||
return entity.DocumentSourceLocal
|
||||
}
|
||||
}
|
||||
|
||||
func convertDocumentSource2Model(sourceType entity.DocumentSource) dataset.DocumentSource {
|
||||
switch sourceType {
|
||||
case entity.DocumentSourceCustom:
|
||||
return dataset.DocumentSource_Custom
|
||||
case entity.DocumentSourceLocal:
|
||||
return dataset.DocumentSource_Document
|
||||
default:
|
||||
return dataset.DocumentSource_Document
|
||||
}
|
||||
}
|
||||
|
||||
func convertDocumentStatus2Model(status entity.DocumentStatus) dataset.DocumentStatus {
|
||||
switch status {
|
||||
case entity.DocumentStatusDeleted:
|
||||
return dataset.DocumentStatus_Deleted
|
||||
case entity.DocumentStatusEnable, entity.DocumentStatusInit:
|
||||
return dataset.DocumentStatus_Enable
|
||||
case entity.DocumentStatusFailed:
|
||||
return dataset.DocumentStatus_Failed
|
||||
default:
|
||||
return dataset.DocumentStatus_Processing
|
||||
}
|
||||
}
|
||||
|
||||
func convertTableColumns2Entity(columns []*dataset.TableColumn) []*entity.TableColumn {
|
||||
if len(columns) == 0 {
|
||||
return nil
|
||||
}
|
||||
columnEntities := make([]*entity.TableColumn, 0, len(columns))
|
||||
for i := range columns {
|
||||
columnEntities = append(columnEntities, &entity.TableColumn{
|
||||
ID: columns[i].GetID(),
|
||||
Name: columns[i].GetColumnName(),
|
||||
Type: convertColumnType2Entity(columns[i].GetColumnType()),
|
||||
Description: columns[i].GetDesc(),
|
||||
Indexing: columns[i].GetIsSemantic(),
|
||||
Sequence: columns[i].GetSequence(),
|
||||
})
|
||||
}
|
||||
return columnEntities
|
||||
}
|
||||
|
||||
func convertTableColumns2Model(columns []*entity.TableColumn) []*dataset.TableColumn {
|
||||
if len(columns) == 0 {
|
||||
return nil
|
||||
}
|
||||
columnModels := make([]*dataset.TableColumn, 0, len(columns))
|
||||
for i := range columns {
|
||||
columnType := convertColumnType2Model(columns[i].Type)
|
||||
columnModels = append(columnModels, &dataset.TableColumn{
|
||||
ID: columns[i].ID,
|
||||
ColumnName: columns[i].Name,
|
||||
ColumnType: &columnType,
|
||||
Desc: &columns[i].Description,
|
||||
IsSemantic: columns[i].Indexing,
|
||||
Sequence: columns[i].Sequence,
|
||||
})
|
||||
}
|
||||
return columnModels
|
||||
}
|
||||
|
||||
func convertTableColumnDataSlice(cols []*entity.TableColumn, data []*document.ColumnData) (map[string]string, error) {
|
||||
if len(cols) != len(data) {
|
||||
return nil, fmt.Errorf("[convertTableColumnDataSlice] invalid cols and vals, len(cols)=%d, len(vals)=%d", len(cols), len(data))
|
||||
}
|
||||
|
||||
resp := make(map[string]string, len(data))
|
||||
for i := range data {
|
||||
col := cols[i]
|
||||
val := data[i]
|
||||
content := ""
|
||||
if val != nil {
|
||||
content = val.GetStringValue()
|
||||
}
|
||||
resp[strconv.FormatInt(col.Sequence, 10)] = content
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func convertColumnType2Model(columnType document.TableColumnType) dataset.ColumnType {
|
||||
switch columnType {
|
||||
case document.TableColumnTypeString:
|
||||
return dataset.ColumnType_Text
|
||||
case document.TableColumnTypeInteger:
|
||||
return dataset.ColumnType_Number
|
||||
case document.TableColumnTypeImage:
|
||||
return dataset.ColumnType_Image
|
||||
case document.TableColumnTypeBoolean:
|
||||
return dataset.ColumnType_Boolean
|
||||
case document.TableColumnTypeTime:
|
||||
return dataset.ColumnType_Date
|
||||
case document.TableColumnTypeNumber:
|
||||
return dataset.ColumnType_Float
|
||||
default:
|
||||
return dataset.ColumnType_Text
|
||||
}
|
||||
}
|
||||
|
||||
func convertColumnType2Entity(columnType dataset.ColumnType) document.TableColumnType {
|
||||
switch columnType {
|
||||
case dataset.ColumnType_Text:
|
||||
return document.TableColumnTypeString
|
||||
case dataset.ColumnType_Number:
|
||||
return document.TableColumnTypeInteger
|
||||
case dataset.ColumnType_Image:
|
||||
return document.TableColumnTypeImage
|
||||
case dataset.ColumnType_Boolean:
|
||||
return document.TableColumnTypeBoolean
|
||||
case dataset.ColumnType_Date:
|
||||
return document.TableColumnTypeTime
|
||||
case dataset.ColumnType_Float:
|
||||
return document.TableColumnTypeNumber
|
||||
default:
|
||||
return document.TableColumnTypeString
|
||||
}
|
||||
}
|
||||
|
||||
func convertParsingStrategy2Entity(strategy *dataset.ParsingStrategy, sheet *dataset.TableSheet, captionType *dataset.CaptionType, filterStrategy *dataset.FilterStrategy) *entity.ParsingStrategy {
|
||||
if strategy == nil && sheet == nil && captionType == nil {
|
||||
return nil
|
||||
}
|
||||
res := &entity.ParsingStrategy{}
|
||||
if strategy != nil {
|
||||
res.ExtractImage = strategy.GetImageExtraction()
|
||||
res.ExtractTable = strategy.GetTableExtraction()
|
||||
res.ImageOCR = strategy.GetImageOcr()
|
||||
res.ParsingType = convertParsingType2Entity(strategy.GetParsingType())
|
||||
if strategy.GetParsingType() == dataset.ParsingType_FastParsing {
|
||||
res.ExtractImage = false
|
||||
res.ExtractTable = false
|
||||
res.ImageOCR = false
|
||||
}
|
||||
}
|
||||
if sheet != nil {
|
||||
res.SheetID = sheet.GetSheetID()
|
||||
res.HeaderLine = int(sheet.GetHeaderLineIdx())
|
||||
res.DataStartLine = int(sheet.GetStartLineIdx())
|
||||
}
|
||||
if filterStrategy != nil {
|
||||
res.FilterPages = slices.Transform(filterStrategy.GetFilterPage(), func(page int32) int { return int(page) })
|
||||
}
|
||||
res.CaptionType = convertCaptionType2Entity(captionType)
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func convertParsingType2Entity(pt dataset.ParsingType) entity.ParsingType {
|
||||
switch pt {
|
||||
case dataset.ParsingType_AccurateParsing:
|
||||
return entity.ParsingType_AccurateParsing
|
||||
case dataset.ParsingType_FastParsing:
|
||||
return entity.ParsingType_FastParsing
|
||||
default:
|
||||
return entity.ParsingType_FastParsing
|
||||
}
|
||||
}
|
||||
|
||||
func convertParsingStrategy2Model(strategy *entity.ParsingStrategy) (s *dataset.ParsingStrategy, sheet *dataset.TableSheet) {
|
||||
if strategy == nil {
|
||||
return nil, nil
|
||||
}
|
||||
sheet = &dataset.TableSheet{
|
||||
SheetID: strategy.SheetID,
|
||||
HeaderLineIdx: int64(strategy.HeaderLine),
|
||||
StartLineIdx: int64(strategy.DataStartLine),
|
||||
}
|
||||
return &dataset.ParsingStrategy{
|
||||
ParsingType: ptr.Of(convertParsingType2Model(strategy.ParsingType)),
|
||||
ImageExtraction: &strategy.ExtractImage,
|
||||
TableExtraction: &strategy.ExtractTable,
|
||||
ImageOcr: &strategy.ImageOCR,
|
||||
}, sheet
|
||||
}
|
||||
func convertParsingType2Model(pt entity.ParsingType) dataset.ParsingType {
|
||||
switch pt {
|
||||
case entity.ParsingType_AccurateParsing:
|
||||
return dataset.ParsingType_AccurateParsing
|
||||
case entity.ParsingType_FastParsing:
|
||||
return dataset.ParsingType_FastParsing
|
||||
default:
|
||||
return dataset.ParsingType_FastParsing
|
||||
}
|
||||
}
|
||||
func convertChunkingStrategy2Entity(strategy *dataset.ChunkStrategy) *entity.ChunkingStrategy {
|
||||
if strategy == nil {
|
||||
return nil
|
||||
}
|
||||
if strategy.ChunkType == dataset.ChunkType_DefaultChunk {
|
||||
return &entity.ChunkingStrategy{
|
||||
ChunkType: convertChunkType2Entity(dataset.ChunkType_DefaultChunk),
|
||||
}
|
||||
}
|
||||
return &entity.ChunkingStrategy{
|
||||
ChunkType: convertChunkType2Entity(strategy.ChunkType),
|
||||
ChunkSize: strategy.GetMaxTokens(),
|
||||
Separator: strategy.GetSeparator(),
|
||||
Overlap: strategy.GetOverlap(),
|
||||
TrimSpace: strategy.GetRemoveExtraSpaces(),
|
||||
TrimURLAndEmail: strategy.GetRemoveUrlsEmails(),
|
||||
MaxDepth: strategy.GetMaxLevel(),
|
||||
SaveTitle: strategy.GetSaveTitle(),
|
||||
}
|
||||
}
|
||||
|
||||
func GetExtension(uri string) string {
|
||||
if uri == "" {
|
||||
return ""
|
||||
}
|
||||
fileExtension := path.Base(uri)
|
||||
ext := path.Ext(fileExtension)
|
||||
if ext != "" {
|
||||
return strings.TrimPrefix(ext, ".")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
func convertCaptionType2Entity(ct *dataset.CaptionType) *parser.ImageAnnotationType {
|
||||
if ct == nil {
|
||||
return nil
|
||||
}
|
||||
switch ptr.From(ct) {
|
||||
case dataset.CaptionType_Auto:
|
||||
return ptr.Of(parser.ImageAnnotationTypeModel)
|
||||
case dataset.CaptionType_Manual:
|
||||
return ptr.Of(parser.ImageAnnotationTypeManual)
|
||||
default:
|
||||
return ptr.Of(parser.ImageAnnotationTypeModel)
|
||||
}
|
||||
}
|
||||
func convertDatasetStatus2Entity(status dataset.DatasetStatus) model.KnowledgeStatus {
|
||||
switch status {
|
||||
case dataset.DatasetStatus_DatasetReady:
|
||||
return model.KnowledgeStatusEnable
|
||||
case dataset.DatasetStatus_DatasetForbid, dataset.DatasetStatus_DatasetDeleted:
|
||||
return model.KnowledgeStatusDisable
|
||||
default:
|
||||
return model.KnowledgeStatusEnable
|
||||
}
|
||||
}
|
||||
|
||||
func convertChunkType2model(chunkType parser.ChunkType) dataset.ChunkType {
|
||||
switch chunkType {
|
||||
case parser.ChunkTypeCustom:
|
||||
return dataset.ChunkType_CustomChunk
|
||||
case parser.ChunkTypeDefault:
|
||||
return dataset.ChunkType_DefaultChunk
|
||||
case parser.ChunkTypeLeveled:
|
||||
return dataset.ChunkType_LevelChunk
|
||||
default:
|
||||
return dataset.ChunkType_CustomChunk
|
||||
}
|
||||
}
|
||||
|
||||
func convertChunkType2Entity(chunkType dataset.ChunkType) parser.ChunkType {
|
||||
switch chunkType {
|
||||
case dataset.ChunkType_CustomChunk:
|
||||
return parser.ChunkTypeCustom
|
||||
case dataset.ChunkType_DefaultChunk:
|
||||
return parser.ChunkTypeDefault
|
||||
case dataset.ChunkType_LevelChunk:
|
||||
return parser.ChunkTypeLeveled
|
||||
default:
|
||||
return parser.ChunkTypeDefault
|
||||
}
|
||||
}
|
||||
|
||||
func convertChunkingStrategy2Model(chunkingStrategy *entity.ChunkingStrategy) *dataset.ChunkStrategy {
|
||||
if chunkingStrategy == nil {
|
||||
return nil
|
||||
}
|
||||
return &dataset.ChunkStrategy{
|
||||
Separator: chunkingStrategy.Separator,
|
||||
MaxTokens: chunkingStrategy.ChunkSize,
|
||||
RemoveExtraSpaces: chunkingStrategy.TrimSpace,
|
||||
RemoveUrlsEmails: chunkingStrategy.TrimURLAndEmail,
|
||||
ChunkType: convertChunkType2model(chunkingStrategy.ChunkType),
|
||||
Overlap: &chunkingStrategy.Overlap,
|
||||
MaxLevel: &chunkingStrategy.MaxDepth,
|
||||
SaveTitle: &chunkingStrategy.SaveTitle,
|
||||
}
|
||||
}
|
||||
|
||||
func convertDocumentTypeEntity2Dataset(formatType model.DocumentType) dataset.FormatType {
|
||||
switch formatType {
|
||||
case model.DocumentTypeText:
|
||||
return dataset.FormatType_Text
|
||||
case model.DocumentTypeTable:
|
||||
return dataset.FormatType_Table
|
||||
case model.DocumentTypeImage:
|
||||
return dataset.FormatType_Image
|
||||
default:
|
||||
return dataset.FormatType_Text
|
||||
}
|
||||
}
|
||||
|
||||
func convertDocumentTypeDataset2Entity(formatType dataset.FormatType) model.DocumentType {
|
||||
switch formatType {
|
||||
case dataset.FormatType_Text:
|
||||
return model.DocumentTypeText
|
||||
case dataset.FormatType_Table:
|
||||
return model.DocumentTypeTable
|
||||
case dataset.FormatType_Image:
|
||||
return model.DocumentTypeImage
|
||||
default:
|
||||
return model.DocumentTypeUnknown
|
||||
}
|
||||
}
|
||||
|
||||
func batchConvertKnowledgeEntity2Model(ctx context.Context, knowledgeEntity []*model.Knowledge) (map[int64]*dataset.Dataset, error) {
|
||||
knowledgeMap := map[int64]*dataset.Dataset{}
|
||||
for _, k := range knowledgeEntity {
|
||||
documentEntity, err := KnowledgeSVC.DomainSVC.ListDocument(ctx, &service.ListDocumentRequest{
|
||||
KnowledgeID: k.ID,
|
||||
SelectAll: true,
|
||||
})
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "list document failed, err: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
datasetStatus := dataset.DatasetStatus_DatasetReady
|
||||
if k.Status == model.KnowledgeStatusDisable {
|
||||
datasetStatus = dataset.DatasetStatus_DatasetForbid
|
||||
}
|
||||
|
||||
var (
|
||||
rule *entity.ChunkingStrategy
|
||||
totalSize int64
|
||||
sliceCount int32
|
||||
processingFileList []string
|
||||
processingFileIDList []string
|
||||
fileList []string
|
||||
)
|
||||
for i := range documentEntity.Documents {
|
||||
doc := documentEntity.Documents[i]
|
||||
totalSize += doc.Size
|
||||
sliceCount += int32(doc.SliceCount)
|
||||
if doc.Status == entity.DocumentStatusChunking || doc.Status == entity.DocumentStatusUploading {
|
||||
processingFileList = append(processingFileList, doc.Name)
|
||||
processingFileIDList = append(processingFileIDList, strconv.FormatInt(doc.ID, 10))
|
||||
}
|
||||
if i == 0 {
|
||||
rule = doc.ChunkingStrategy
|
||||
}
|
||||
fileList = append(fileList, doc.Name)
|
||||
}
|
||||
knowledgeMap[k.ID] = &dataset.Dataset{
|
||||
DatasetID: k.ID,
|
||||
Name: k.Name,
|
||||
FileList: fileList,
|
||||
AllFileSize: totalSize,
|
||||
BotUsedCount: 0,
|
||||
Status: datasetStatus,
|
||||
ProcessingFileList: processingFileList,
|
||||
UpdateTime: int32(k.UpdatedAtMs / 1000),
|
||||
IconURI: k.IconURI,
|
||||
IconURL: k.IconURL,
|
||||
Description: k.Description,
|
||||
CanEdit: true,
|
||||
CreateTime: int32(k.CreatedAtMs / 1000),
|
||||
CreatorID: k.CreatorID,
|
||||
SpaceID: k.SpaceID,
|
||||
FailedFileList: nil,
|
||||
FormatType: convertDocumentTypeEntity2Dataset(k.Type),
|
||||
SliceCount: sliceCount,
|
||||
DocCount: int32(len(documentEntity.Documents)),
|
||||
HitCount: int32(k.SliceHit),
|
||||
ChunkStrategy: convertChunkingStrategy2Model(rule),
|
||||
ProcessingFileIDList: processingFileIDList,
|
||||
ProjectID: strconv.FormatInt(k.AppID, 10),
|
||||
}
|
||||
}
|
||||
return knowledgeMap, nil
|
||||
}
|
||||
|
||||
func convertSourceInfo(sourceInfo *dataset.SourceInfo) (*service.TableSourceInfo, error) {
|
||||
if sourceInfo == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
fType := sourceInfo.FileType
|
||||
if fType == nil && sourceInfo.TosURI != nil {
|
||||
split := strings.Split(sourceInfo.GetTosURI(), ".")
|
||||
fType = &split[len(split)-1]
|
||||
}
|
||||
|
||||
var customContent []map[string]string
|
||||
if sourceInfo.CustomContent != nil {
|
||||
if err := json.Unmarshal([]byte(sourceInfo.GetCustomContent()), &customContent); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &service.TableSourceInfo{
|
||||
FileType: fType,
|
||||
Uri: sourceInfo.TosURI,
|
||||
FileBase64: sourceInfo.FileBase64,
|
||||
CustomContent: customContent,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func convertCreateDocReviewReq(req *dataset.CreateDocumentReviewRequest) *service.CreateDocumentReviewRequest {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
var captionType *dataset.CaptionType
|
||||
if req.GetChunkStrategy() != nil {
|
||||
captionType = req.GetChunkStrategy().CaptionType
|
||||
}
|
||||
resp := &service.CreateDocumentReviewRequest{
|
||||
ChunkStrategy: convertChunkingStrategy2Entity(req.ChunkStrategy),
|
||||
ParsingStrategy: convertParsingStrategy2Entity(req.ParsingStrategy, nil, captionType, nil),
|
||||
}
|
||||
resp.KnowledgeID = req.GetDatasetID()
|
||||
resp.Reviews = slices.Transform(req.GetReviews(), func(r *dataset.ReviewInput) *service.ReviewInput {
|
||||
return &service.ReviewInput{
|
||||
DocumentName: r.GetDocumentName(),
|
||||
DocumentType: r.GetDocumentType(),
|
||||
TosUri: r.GetTosURI(),
|
||||
DocumentID: ptr.Of(r.GetDocumentID()),
|
||||
}
|
||||
})
|
||||
return resp
|
||||
}
|
||||
|
||||
func convertReviewStatus2Model(status *entity.ReviewStatus) *dataset.ReviewStatus {
|
||||
if status == nil {
|
||||
return nil
|
||||
}
|
||||
switch *status {
|
||||
case entity.ReviewStatus_Enable:
|
||||
return dataset.ReviewStatusPtr(dataset.ReviewStatus_Enable)
|
||||
case entity.ReviewStatus_Processing:
|
||||
return dataset.ReviewStatusPtr(dataset.ReviewStatus_Processing)
|
||||
case entity.ReviewStatus_Failed:
|
||||
return dataset.ReviewStatusPtr(dataset.ReviewStatus_Failed)
|
||||
case entity.ReviewStatus_ForceStop:
|
||||
return dataset.ReviewStatusPtr(dataset.ReviewStatus_ForceStop)
|
||||
default:
|
||||
return dataset.ReviewStatusPtr(dataset.ReviewStatus_Processing)
|
||||
}
|
||||
}
|
||||
|
||||
func getIconURI(tp dataset.FormatType) string {
|
||||
switch tp {
|
||||
case dataset.FormatType_Text:
|
||||
return upload.TextKnowledgeDefaultIcon
|
||||
case dataset.FormatType_Table:
|
||||
return upload.TableKnowledgeDefaultIcon
|
||||
case dataset.FormatType_Image:
|
||||
return upload.ImageKnowledgeDefaultIcon
|
||||
default:
|
||||
return upload.TextKnowledgeDefaultIcon
|
||||
}
|
||||
}
|
||||
|
||||
func convertFormatType2Entity(tp dataset.FormatType) model.DocumentType {
|
||||
switch tp {
|
||||
case dataset.FormatType_Text:
|
||||
return model.DocumentTypeText
|
||||
case dataset.FormatType_Table:
|
||||
return model.DocumentTypeTable
|
||||
case dataset.FormatType_Image:
|
||||
return model.DocumentTypeImage
|
||||
default:
|
||||
return model.DocumentTypeUnknown
|
||||
}
|
||||
}
|
||||
445
backend/application/knowledge/init.go
Normal file
445
backend/application/knowledge/init.go
Normal file
@@ -0,0 +1,445 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/embedding/ark"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/openai"
|
||||
ao "github.com/cloudwego/eino-ext/components/model/ark"
|
||||
"github.com/cloudwego/eino-ext/components/model/deepseek"
|
||||
"github.com/cloudwego/eino-ext/components/model/gemini"
|
||||
"github.com/cloudwego/eino-ext/components/model/ollama"
|
||||
mo "github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino-ext/components/model/qwen"
|
||||
"github.com/cloudwego/eino/components/prompt"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/milvus-io/milvus/client/v2/milvusclient"
|
||||
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
|
||||
"github.com/volcengine/volc-sdk-golang/service/visual"
|
||||
"google.golang.org/genai"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/application/search"
|
||||
knowledgeImpl "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/nl2sql"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/es"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/messages2query"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
chatmodelImpl "github.com/coze-dev/coze-studio/backend/infra/impl/chatmodel"
|
||||
builtinNL2SQL "github.com/coze-dev/coze-studio/backend/infra/impl/document/nl2sql/builtin"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/veocr"
|
||||
builtinParser "github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf"
|
||||
sses "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/elasticsearch"
|
||||
ssmilvus "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/milvus"
|
||||
ssvikingdb "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/vikingdb"
|
||||
arkemb "github.com/coze-dev/coze-studio/backend/infra/impl/embedding/ark"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/embedding/wrap"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
|
||||
builtinM2Q "github.com/coze-dev/coze-studio/backend/infra/impl/messages2query/builtin"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
)
|
||||
|
||||
type ServiceComponents struct {
|
||||
DB *gorm.DB
|
||||
IDGenSVC idgen.IDGenerator
|
||||
Storage storage.Storage
|
||||
RDB rdb.RDB
|
||||
ImageX imagex.ImageX
|
||||
ES es.Client
|
||||
EventBus search.ResourceEventBus
|
||||
CacheCli cache.Cmdable
|
||||
}
|
||||
|
||||
func InitService(c *ServiceComponents) (*KnowledgeApplicationService, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
nameServer := os.Getenv(consts.MQServer)
|
||||
|
||||
knowledgeProducer, err := eventbus.NewProducer(nameServer, consts.RMQTopicKnowledge, consts.RMQConsumeGroupKnowledge, 2)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init knowledge producer failed, err=%w", err)
|
||||
}
|
||||
|
||||
var sManagers []searchstore.Manager
|
||||
|
||||
// es full text search
|
||||
sManagers = append(sManagers, sses.NewManager(&sses.ManagerConfig{Client: c.ES}))
|
||||
|
||||
// vector search
|
||||
mgr, err := getVectorStore(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init vector store failed, err=%w", err)
|
||||
}
|
||||
sManagers = append(sManagers, mgr)
|
||||
|
||||
var ocrImpl ocr.OCR
|
||||
switch os.Getenv("OCR_TYPE") {
|
||||
case "ve":
|
||||
ocrAK := os.Getenv("VE_OCR_AK")
|
||||
ocrSK := os.Getenv("VE_OCR_SK")
|
||||
inst := visual.NewInstance()
|
||||
inst.Client.SetAccessKey(ocrAK)
|
||||
inst.Client.SetSecretKey(ocrSK)
|
||||
ocrImpl = veocr.NewOCR(&veocr.Config{Client: inst})
|
||||
default:
|
||||
// accept ocr not configured
|
||||
}
|
||||
|
||||
root, err := os.Getwd()
|
||||
if err != nil {
|
||||
logs.Warnf("[InitConfig] Failed to get current working directory: %v", err)
|
||||
root = os.Getenv("PWD")
|
||||
}
|
||||
|
||||
var rewriter messages2query.MessagesToQuery
|
||||
if rewriterChatModel, _, err := getBuiltinChatModel(ctx, "M2Q_"); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
filePath := filepath.Join(root, "resources/conf/prompt/messages_to_query_template_jinja2.json")
|
||||
rewriterTemplate, err := readJinja2PromptTemplate(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rewriter, err = builtinM2Q.NewMessagesToQuery(ctx, rewriterChatModel, rewriterTemplate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var n2s nl2sql.NL2SQL
|
||||
if n2sChatModel, _, err := getBuiltinChatModel(ctx, "NL2SQL_"); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
filePath := filepath.Join(root, "resources/conf/prompt/nl2sql_template_jinja2.json")
|
||||
n2sTemplate, err := readJinja2PromptTemplate(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n2s, err = builtinNL2SQL.NewNL2SQL(ctx, n2sChatModel, n2sTemplate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
imageAnnoChatModel, configured, err := getBuiltinChatModel(ctx, "IA_")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
knowledgeDomainSVC, knowledgeEventHandler := knowledgeImpl.NewKnowledgeSVC(&knowledgeImpl.KnowledgeSVCConfig{
|
||||
DB: c.DB,
|
||||
IDGen: c.IDGenSVC,
|
||||
RDB: c.RDB,
|
||||
Producer: knowledgeProducer,
|
||||
SearchStoreManagers: sManagers,
|
||||
ParseManager: builtinParser.NewManager(c.Storage, ocrImpl, imageAnnoChatModel), // default builtin
|
||||
Storage: c.Storage,
|
||||
Rewriter: rewriter,
|
||||
Reranker: rrf.NewRRFReranker(0), // default rrf
|
||||
NL2Sql: n2s,
|
||||
OCR: ocrImpl,
|
||||
CacheCli: c.CacheCli,
|
||||
IsAutoAnnotationSupported: configured,
|
||||
ModelFactory: chatmodelImpl.NewDefaultFactory(),
|
||||
})
|
||||
|
||||
if err = eventbus.RegisterConsumer(nameServer, consts.RMQTopicKnowledge, consts.RMQConsumeGroupKnowledge, knowledgeEventHandler); err != nil {
|
||||
return nil, fmt.Errorf("register knowledge consumer failed, err=%w", err)
|
||||
}
|
||||
|
||||
KnowledgeSVC.DomainSVC = knowledgeDomainSVC
|
||||
KnowledgeSVC.eventBus = c.EventBus
|
||||
KnowledgeSVC.storage = c.Storage
|
||||
return KnowledgeSVC, nil
|
||||
}
|
||||
|
||||
func getVectorStore(ctx context.Context) (searchstore.Manager, error) {
|
||||
vsType := os.Getenv("VECTOR_STORE_TYPE")
|
||||
|
||||
switch vsType {
|
||||
case "milvus":
|
||||
cctx, cancel := context.WithTimeout(ctx, time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
milvusAddr := os.Getenv("MILVUS_ADDR")
|
||||
mc, err := milvusclient.New(cctx, &milvusclient.ClientConfig{Address: milvusAddr})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init milvus client failed, err=%w", err)
|
||||
}
|
||||
|
||||
emb, err := getEmbedding(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init milvus embedding failed, err=%w", err)
|
||||
}
|
||||
|
||||
mgr, err := ssmilvus.NewManager(&ssmilvus.ManagerConfig{
|
||||
Client: mc,
|
||||
Embedding: emb,
|
||||
EnableHybrid: ptr.Of(true),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init milvus vector store failed, err=%w", err)
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
case "vikingdb":
|
||||
var (
|
||||
host = os.Getenv("VIKING_DB_HOST")
|
||||
region = os.Getenv("VIKING_DB_REGION")
|
||||
ak = os.Getenv("VIKING_DB_AK")
|
||||
sk = os.Getenv("VIKING_DB_SK")
|
||||
scheme = os.Getenv("VIKING_DB_SCHEME")
|
||||
modelName = os.Getenv("VIKING_DB_MODEL_NAME")
|
||||
)
|
||||
if ak == "" || sk == "" {
|
||||
return nil, fmt.Errorf("invalid vikingdb ak / sk")
|
||||
}
|
||||
if host == "" {
|
||||
host = "api-vikingdb.volces.com"
|
||||
}
|
||||
if region == "" {
|
||||
region = "cn-beijing"
|
||||
}
|
||||
if scheme == "" {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
var embConfig *ssvikingdb.VikingEmbeddingConfig
|
||||
if modelName != "" {
|
||||
embName := ssvikingdb.VikingEmbeddingModelName(modelName)
|
||||
if embName.Dimensions() == 0 {
|
||||
return nil, fmt.Errorf("embedding model not support, model_name=%s", modelName)
|
||||
}
|
||||
embConfig = &ssvikingdb.VikingEmbeddingConfig{
|
||||
UseVikingEmbedding: true,
|
||||
EnableHybrid: embName.SupportStatus() == embedding.SupportDenseAndSparse,
|
||||
ModelName: embName,
|
||||
ModelVersion: embName.ModelVersion(),
|
||||
DenseWeight: ptr.Of(0.2),
|
||||
BuiltinEmbedding: nil,
|
||||
}
|
||||
} else {
|
||||
builtinEmbedding, err := getEmbedding(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("builtint embedding init failed, err=%w", err)
|
||||
}
|
||||
|
||||
embConfig = &ssvikingdb.VikingEmbeddingConfig{
|
||||
UseVikingEmbedding: false,
|
||||
EnableHybrid: false,
|
||||
BuiltinEmbedding: builtinEmbedding,
|
||||
}
|
||||
}
|
||||
svc := vikingdb.NewVikingDBService(host, region, ak, sk, scheme)
|
||||
mgr, err := ssvikingdb.NewManager(&ssvikingdb.ManagerConfig{
|
||||
Service: svc,
|
||||
IndexingConfig: nil, // use default config
|
||||
EmbeddingConfig: embConfig,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init vikingdb manager failed, err=%w", err)
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected vector store type, type=%s", vsType)
|
||||
}
|
||||
}
|
||||
|
||||
func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
|
||||
var emb embedding.Embedder
|
||||
|
||||
switch os.Getenv("EMBEDDING_TYPE") {
|
||||
case "openai":
|
||||
var (
|
||||
openAIEmbeddingBaseURL = os.Getenv("OPENAI_EMBEDDING_BASE_URL")
|
||||
openAIEmbeddingModel = os.Getenv("OPENAI_EMBEDDING_MODEL")
|
||||
openAIEmbeddingApiKey = os.Getenv("OPENAI_EMBEDDING_API_KEY")
|
||||
openAIEmbeddingByAzure = os.Getenv("OPENAI_EMBEDDING_BY_AZURE")
|
||||
openAIEmbeddingApiVersion = os.Getenv("OPENAI_EMBEDDING_API_VERSION")
|
||||
openAIEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_DIMS")
|
||||
openAIRequestEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_REQUEST_DIMS")
|
||||
)
|
||||
|
||||
byAzure, err := strconv.ParseBool(openAIEmbeddingByAzure)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init openai embedding by_azure failed, err=%w", err)
|
||||
}
|
||||
|
||||
dims, err := strconv.ParseInt(openAIEmbeddingDims, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init openai embedding dims failed, err=%w", err)
|
||||
}
|
||||
|
||||
openAICfg := &openai.EmbeddingConfig{
|
||||
APIKey: openAIEmbeddingApiKey,
|
||||
ByAzure: byAzure,
|
||||
BaseURL: openAIEmbeddingBaseURL,
|
||||
APIVersion: openAIEmbeddingApiVersion,
|
||||
Model: openAIEmbeddingModel,
|
||||
// Dimensions: ptr.Of(int(dims)),
|
||||
}
|
||||
reqDims := conv.StrToInt64D(openAIRequestEmbeddingDims, 0)
|
||||
if reqDims > 0 {
|
||||
// some openai model not support request dims
|
||||
openAICfg.Dimensions = ptr.Of(int(reqDims))
|
||||
}
|
||||
|
||||
emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, dims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init openai embedding failed, err=%w", err)
|
||||
}
|
||||
|
||||
case "ark":
|
||||
var (
|
||||
arkEmbeddingBaseURL = os.Getenv("ARK_EMBEDDING_BASE_URL")
|
||||
arkEmbeddingModel = os.Getenv("ARK_EMBEDDING_MODEL")
|
||||
arkEmbeddingAK = os.Getenv("ARK_EMBEDDING_AK")
|
||||
arkEmbeddingDims = os.Getenv("ARK_EMBEDDING_DIMS")
|
||||
)
|
||||
|
||||
dims, err := strconv.ParseInt(arkEmbeddingDims, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init ark embedding dims failed, err=%w", err)
|
||||
}
|
||||
|
||||
emb, err = arkemb.NewArkEmbedder(ctx, &ark.EmbeddingConfig{
|
||||
APIKey: arkEmbeddingAK,
|
||||
Model: arkEmbeddingModel,
|
||||
BaseURL: arkEmbeddingBaseURL,
|
||||
}, dims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init ark embedding client failed, err=%w", err)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("init knowledge embedding failed, type not configured")
|
||||
}
|
||||
|
||||
return emb, nil
|
||||
}
|
||||
|
||||
func getBuiltinChatModel(ctx context.Context, envPrefix string) (bcm chatmodel.BaseChatModel, configured bool, err error) {
|
||||
getEnv := func(key string) string {
|
||||
if val := os.Getenv(envPrefix + key); val != "" {
|
||||
return val
|
||||
}
|
||||
return os.Getenv(key)
|
||||
}
|
||||
|
||||
switch getEnv("BUILTIN_CM_TYPE") {
|
||||
case "openai":
|
||||
byAzure, _ := strconv.ParseBool(getEnv("BUILTIN_CM_OPENAI_BY_AZURE"))
|
||||
bcm, err = mo.NewChatModel(ctx, &mo.ChatModelConfig{
|
||||
APIKey: getEnv("BUILTIN_CM_OPENAI_API_KEY"),
|
||||
ByAzure: byAzure,
|
||||
BaseURL: getEnv("BUILTIN_CM_OPENAI_BASE_URL"),
|
||||
Model: getEnv("BUILTIN_CM_OPENAI_MODEL"),
|
||||
})
|
||||
case "ark":
|
||||
bcm, err = ao.NewChatModel(ctx, &ao.ChatModelConfig{
|
||||
APIKey: getEnv("BUILTIN_CM_ARK_API_KEY"),
|
||||
Model: getEnv("BUILTIN_CM_ARK_MODEL"),
|
||||
BaseURL: getEnv("BUILTIN_CM_ARK_BASE_URL"),
|
||||
})
|
||||
case "deepseek":
|
||||
bcm, err = deepseek.NewChatModel(ctx, &deepseek.ChatModelConfig{
|
||||
APIKey: getEnv("BUILTIN_CM_DEEPSEEK_API_KEY"),
|
||||
BaseURL: getEnv("BUILTIN_CM_DEEPSEEK_BASE_URL"),
|
||||
Model: getEnv("BUILTIN_CM_DEEPSEEK_MODEL"),
|
||||
})
|
||||
case "ollama":
|
||||
bcm, err = ollama.NewChatModel(ctx, &ollama.ChatModelConfig{
|
||||
BaseURL: getEnv("BUILTIN_CM_OLLAMA_BASE_URL"),
|
||||
Model: getEnv("BUILTIN_CM_OLLAMA_MODEL"),
|
||||
})
|
||||
case "qwen":
|
||||
bcm, err = qwen.NewChatModel(ctx, &qwen.ChatModelConfig{
|
||||
APIKey: getEnv("BUILTIN_CM_QWEN_API_KEY"),
|
||||
BaseURL: getEnv("BUILTIN_CM_QWEN_BASE_URL"),
|
||||
Model: getEnv("BUILTIN_CM_QWEN_MODEL"),
|
||||
})
|
||||
case "gemini":
|
||||
backend, convErr := strconv.ParseInt(getEnv("BUILTIN_CM_GEMINI_BACKEND"), 10, 64)
|
||||
if convErr != nil {
|
||||
return nil, false, convErr
|
||||
}
|
||||
c, clientErr := genai.NewClient(ctx, &genai.ClientConfig{
|
||||
APIKey: getEnv("BUILTIN_CM_GEMINI_API_KEY"),
|
||||
Backend: genai.Backend(backend),
|
||||
Project: getEnv("BUILTIN_CM_GEMINI_PROJECT"),
|
||||
Location: getEnv("BUILTIN_CM_GEMINI_LOCATION"),
|
||||
HTTPOptions: genai.HTTPOptions{
|
||||
BaseURL: getEnv("BUILTIN_CM_GEMINI_BASE_URL"),
|
||||
},
|
||||
})
|
||||
if clientErr != nil {
|
||||
return nil, false, clientErr
|
||||
}
|
||||
bcm, err = gemini.NewChatModel(ctx, &gemini.Config{
|
||||
Client: c,
|
||||
Model: getEnv("BUILTIN_CM_GEMINI_MODEL"),
|
||||
})
|
||||
default:
|
||||
// accept builtin chat model not configured
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("knowledge init openai chat mode failed, %w", err)
|
||||
}
|
||||
if bcm != nil {
|
||||
configured = true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func readJinja2PromptTemplate(jsonFilePath string) (prompt.ChatTemplate, error) {
|
||||
b, err := os.ReadFile(jsonFilePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var m2qMessages []*schema.Message
|
||||
if err = json.Unmarshal(b, &m2qMessages); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tpl := make([]schema.MessagesTemplate, len(m2qMessages))
|
||||
for i := range m2qMessages {
|
||||
tpl[i] = m2qMessages[i]
|
||||
}
|
||||
return prompt.FromMessages(schema.Jinja2, tpl...), nil
|
||||
}
|
||||
1123
backend/application/knowledge/knowledge.go
Normal file
1123
backend/application/knowledge/knowledge.go
Normal file
File diff suppressed because it is too large
Load Diff
288
backend/application/memory/convertor.go
Normal file
288
backend/application/memory/convertor.go
Normal file
@@ -0,0 +1,288 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package memory
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/base"
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/table"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/memory/database/entity"
|
||||
database "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
)
|
||||
|
||||
func convertAddDatabase(req *table.AddDatabaseRequest) *database.CreateDatabaseRequest {
|
||||
fieldItems := make([]*model.FieldItem, 0, len(req.FieldList))
|
||||
for _, field := range req.FieldList {
|
||||
fieldItems = append(fieldItems, &model.FieldItem{
|
||||
Name: field.Name,
|
||||
Desc: field.Desc,
|
||||
Type: field.Type,
|
||||
MustRequired: field.MustRequired,
|
||||
})
|
||||
}
|
||||
|
||||
return &database.CreateDatabaseRequest{
|
||||
Database: &entity.Database{
|
||||
IconURI: req.IconURI,
|
||||
CreatorID: req.CreatorID,
|
||||
SpaceID: req.SpaceID,
|
||||
AppID: req.ProjectID,
|
||||
TableName: req.TableName,
|
||||
TableDesc: req.TableDesc,
|
||||
FieldList: fieldItems,
|
||||
RwMode: req.RwMode,
|
||||
PromptDisabled: req.PromptDisabled,
|
||||
ExtraInfo: req.ExtraInfo,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func ConvertDatabaseRes(res *entity.Database) *table.SingleDatabaseResponse {
|
||||
return &table.SingleDatabaseResponse{
|
||||
DatabaseInfo: convertDatabaseRes(res),
|
||||
|
||||
Code: 0,
|
||||
Msg: "success",
|
||||
BaseResp: &base.BaseResp{
|
||||
StatusCode: 0,
|
||||
StatusMessage: "success",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertUpdateDatabase converts the API update request to domain request
|
||||
func ConvertUpdateDatabase(req *table.UpdateDatabaseRequest) *database.UpdateDatabaseRequest {
|
||||
fieldItems := make([]*model.FieldItem, 0, len(req.FieldList))
|
||||
for _, field := range req.FieldList {
|
||||
fieldItems = append(fieldItems, &model.FieldItem{
|
||||
Name: field.Name,
|
||||
Desc: field.Desc,
|
||||
AlterID: field.AlterId,
|
||||
Type: field.Type,
|
||||
MustRequired: field.MustRequired,
|
||||
})
|
||||
}
|
||||
|
||||
return &database.UpdateDatabaseRequest{
|
||||
Database: &entity.Database{
|
||||
ID: req.ID,
|
||||
IconURI: req.IconURI,
|
||||
TableName: req.TableName,
|
||||
TableDesc: req.TableDesc,
|
||||
FieldList: fieldItems,
|
||||
RwMode: req.RwMode,
|
||||
PromptDisabled: req.PromptDisabled,
|
||||
ExtraInfo: req.ExtraInfo,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// convertUpdateDatabaseResult converts the domain update response to API response
|
||||
func convertUpdateDatabaseResult(res *database.UpdateDatabaseResponse) *table.SingleDatabaseResponse {
|
||||
return &table.SingleDatabaseResponse{
|
||||
DatabaseInfo: convertDatabaseRes(res.Database),
|
||||
|
||||
Code: 0,
|
||||
Msg: "success",
|
||||
BaseResp: &base.BaseResp{
|
||||
StatusCode: 0,
|
||||
StatusMessage: "success",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func convertDatabaseRes(db *entity.Database) *table.DatabaseInfo {
|
||||
fieldItems := make([]*table.FieldItem, 0, len(db.FieldList))
|
||||
for _, field := range db.FieldList {
|
||||
fieldItems = append(fieldItems, &table.FieldItem{
|
||||
Name: field.Name,
|
||||
Desc: field.Desc,
|
||||
Type: field.Type,
|
||||
MustRequired: field.MustRequired,
|
||||
AlterId: field.AlterID,
|
||||
IsSystemField: field.IsSystemField,
|
||||
})
|
||||
}
|
||||
|
||||
return &table.DatabaseInfo{
|
||||
ID: db.ID,
|
||||
SpaceID: db.SpaceID,
|
||||
ProjectID: db.AppID,
|
||||
IconURI: db.IconURI,
|
||||
IconURL: db.IconURL,
|
||||
TableName: db.TableName,
|
||||
TableDesc: db.TableDesc,
|
||||
Status: db.Status,
|
||||
CreatorID: db.CreatorID,
|
||||
CreateTime: db.CreatedAtMs,
|
||||
UpdateTime: db.UpdatedAtMs,
|
||||
FieldList: fieldItems,
|
||||
ActualTableName: db.ActualTableName,
|
||||
RwMode: table.BotTableRWMode(db.RwMode),
|
||||
PromptDisabled: db.PromptDisabled,
|
||||
IsVisible: db.IsVisible,
|
||||
DraftID: db.DraftID,
|
||||
ExtraInfo: db.ExtraInfo,
|
||||
IsAddedToBot: db.IsAddedToAgent,
|
||||
DatamodelTableID: getDataModelTableID(db.ActualTableName),
|
||||
}
|
||||
}
|
||||
|
||||
// convertListDatabase converts the API list request to domain request
|
||||
func convertListDatabase(req *table.ListDatabaseRequest) *database.ListDatabaseRequest {
|
||||
dRes := &database.ListDatabaseRequest{
|
||||
SpaceID: req.SpaceID,
|
||||
TableName: req.TableName,
|
||||
TableType: req.TableType,
|
||||
AppID: req.GetProjectID(),
|
||||
Limit: int(req.GetLimit()),
|
||||
Offset: int(req.GetOffset()),
|
||||
}
|
||||
|
||||
if req.CreatorID != nil && *req.CreatorID != 0 {
|
||||
dRes.CreatorID = req.CreatorID
|
||||
}
|
||||
|
||||
if len(req.OrderBy) > 0 {
|
||||
dRes.OrderBy = make([]*model.OrderBy, len(req.OrderBy))
|
||||
for i, order := range req.OrderBy {
|
||||
dRes.OrderBy[i] = &model.OrderBy{
|
||||
Field: order.Field,
|
||||
Direction: order.Direction,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return dRes
|
||||
}
|
||||
|
||||
// convertListDatabaseRes converts the domain list response to API response
|
||||
func convertListDatabaseRes(res *database.ListDatabaseResponse, bindDatabases []*entity.Database) *table.ListDatabaseResponse {
|
||||
databaseInfos := make([]*table.DatabaseInfo, 0, len(res.Databases))
|
||||
dbMap := slices.ToMap(bindDatabases, func(e *entity.Database) (int64, *entity.Database) {
|
||||
return e.ID, e
|
||||
})
|
||||
for _, db := range res.Databases {
|
||||
databaseInfo := convertDatabaseRes(db)
|
||||
if _, ok := dbMap[db.ID]; ok {
|
||||
databaseInfo.IsAddedToBot = ptr.Of(true)
|
||||
}
|
||||
databaseInfos = append(databaseInfos, databaseInfo)
|
||||
}
|
||||
|
||||
return &table.ListDatabaseResponse{
|
||||
DatabaseInfoList: databaseInfos,
|
||||
TotalCount: res.TotalCount,
|
||||
|
||||
Code: 0,
|
||||
Msg: "success",
|
||||
BaseResp: &base.BaseResp{
|
||||
StatusCode: 0,
|
||||
StatusMessage: "success",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// convertListDatabaseRecordsRes converts domain ListDatabaseRecordResponse to API ListDatabaseRecordsResponse
|
||||
func convertListDatabaseRecordsRes(res *database.ListDatabaseRecordResponse) *table.ListDatabaseRecordsResponse {
|
||||
apiRes := &table.ListDatabaseRecordsResponse{
|
||||
Data: res.Records,
|
||||
TotalNum: int32(res.TotalCount),
|
||||
HasMore: res.HasMore,
|
||||
FieldList: make([]*table.FieldItem, 0, len(res.FieldList)),
|
||||
|
||||
Code: 0,
|
||||
Msg: "success",
|
||||
BaseResp: &base.BaseResp{
|
||||
StatusCode: 0,
|
||||
StatusMessage: "success",
|
||||
},
|
||||
}
|
||||
|
||||
for _, field := range res.FieldList {
|
||||
apiRes.FieldList = append(apiRes.FieldList, &table.FieldItem{
|
||||
Name: field.Name,
|
||||
Desc: field.Desc,
|
||||
Type: field.Type,
|
||||
MustRequired: field.MustRequired,
|
||||
})
|
||||
}
|
||||
|
||||
return apiRes
|
||||
}
|
||||
|
||||
func getDataModelTableID(actualTableName string) string {
|
||||
tableID := ""
|
||||
tableIDStr := strings.Split(actualTableName, "_")
|
||||
if len(tableIDStr) < 2 {
|
||||
return tableID
|
||||
}
|
||||
|
||||
return tableIDStr[1]
|
||||
}
|
||||
|
||||
func convertToBotTableList(databases []*entity.Database, agentID int64, relationMap map[int64]*model.AgentToDatabase) []*table.BotTable {
|
||||
if len(databases) == 0 {
|
||||
return []*table.BotTable{}
|
||||
}
|
||||
|
||||
botTables := make([]*table.BotTable, 0, len(databases))
|
||||
for _, db := range databases {
|
||||
fieldItems := make([]*table.FieldItem, 0, len(db.FieldList))
|
||||
for _, field := range db.FieldList {
|
||||
fieldItems = append(fieldItems, &table.FieldItem{
|
||||
Name: field.Name,
|
||||
Desc: field.Desc,
|
||||
Type: field.Type,
|
||||
MustRequired: field.MustRequired,
|
||||
AlterId: field.AlterID,
|
||||
IsSystemField: field.IsSystemField,
|
||||
})
|
||||
}
|
||||
|
||||
botTable := &table.BotTable{
|
||||
ID: db.ID,
|
||||
BotID: agentID,
|
||||
TableID: strconv.FormatInt(db.ID, 10),
|
||||
TableName: db.TableName,
|
||||
TableDesc: db.TableDesc,
|
||||
Status: table.BotTableStatus(db.Status),
|
||||
CreatorID: db.CreatorID,
|
||||
CreateTime: db.CreatedAtMs,
|
||||
UpdateTime: db.UpdatedAtMs,
|
||||
FieldList: fieldItems,
|
||||
ActualTableName: db.ActualTableName,
|
||||
RwMode: table.BotTableRWMode(db.RwMode),
|
||||
}
|
||||
|
||||
if r, ok := relationMap[db.ID]; ok {
|
||||
botTable.ExtraInfo = map[string]string{
|
||||
"prompt_disabled": fmt.Sprintf("%t", r.PromptDisabled),
|
||||
}
|
||||
}
|
||||
|
||||
botTables = append(botTables, botTable)
|
||||
}
|
||||
|
||||
return botTables
|
||||
}
|
||||
1033
backend/application/memory/database.go
Normal file
1033
backend/application/memory/database.go
Normal file
File diff suppressed because it is too large
Load Diff
64
backend/application/memory/init.go
Normal file
64
backend/application/memory/init.go
Normal file
@@ -0,0 +1,64 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package memory
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
database "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/memory/variables/repository"
|
||||
variables "github.com/coze-dev/coze-studio/backend/domain/memory/variables/service"
|
||||
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
rdbService "github.com/coze-dev/coze-studio/backend/infra/impl/rdb"
|
||||
)
|
||||
|
||||
type MemoryApplicationServices struct {
|
||||
VariablesDomainSVC variables.Variables
|
||||
DatabaseDomainSVC database.Database
|
||||
RDBDomainSVC rdb.RDB
|
||||
}
|
||||
|
||||
type ServiceComponents struct {
|
||||
IDGen idgen.IDGenerator
|
||||
DB *gorm.DB
|
||||
EventBus search.ResourceEventBus
|
||||
TosClient storage.Storage
|
||||
ResourceDomainNotifier search.ResourceEventBus
|
||||
CacheCli *redis.Client
|
||||
}
|
||||
|
||||
func InitService(c *ServiceComponents) *MemoryApplicationServices {
|
||||
repo := repository.NewVariableRepo(c.DB, c.IDGen)
|
||||
variablesDomainSVC := variables.NewService(repo)
|
||||
rdbSVC := rdbService.NewService(c.DB, c.IDGen)
|
||||
databaseDomainSVC := database.NewService(rdbSVC, c.DB, c.IDGen, c.TosClient, c.CacheCli)
|
||||
|
||||
VariableApplicationSVC.DomainSVC = variablesDomainSVC
|
||||
DatabaseApplicationSVC.DomainSVC = databaseDomainSVC
|
||||
DatabaseApplicationSVC.eventbus = c.ResourceDomainNotifier
|
||||
|
||||
return &MemoryApplicationServices{
|
||||
VariablesDomainSVC: variablesDomainSVC,
|
||||
DatabaseDomainSVC: databaseDomainSVC,
|
||||
RDBDomainSVC: rdbSVC,
|
||||
}
|
||||
}
|
||||
385
backend/application/memory/variables.go
Normal file
385
backend/application/memory/variables.go
Normal file
@@ -0,0 +1,385 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/base"
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/variables"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/kvmemory"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/project_memory"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/memory/variables/entity"
|
||||
variables "github.com/coze-dev/coze-studio/backend/domain/memory/variables/service"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/i18n"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
type VariableApplicationService struct {
|
||||
DomainSVC variables.Variables
|
||||
}
|
||||
|
||||
var VariableApplicationSVC = VariableApplicationService{}
|
||||
|
||||
var i18nLocal2GroupVariableInfo = map[i18n.Locale]map[project_memory.VariableChannel]project_memory.GroupVariableInfo{
|
||||
i18n.LocaleEN: {
|
||||
project_memory.VariableChannel_APP: {
|
||||
GroupName: "App variable",
|
||||
GroupDesc: "Configures data accessed across multiple development scenarios in the app. It is initialized to a default value each time a new request is sent.",
|
||||
},
|
||||
project_memory.VariableChannel_Custom: {
|
||||
GroupName: "User variable",
|
||||
GroupDesc: "Persistently stores and reads project date for users, such as the preferred language and custom settings.",
|
||||
},
|
||||
project_memory.VariableChannel_System: {
|
||||
GroupName: "System variable",
|
||||
GroupDesc: "Displays the data that you enabled as needed, which can be used to identify users via IDs or handle channel-specific features. The data is automatically generated and is read-only.",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var channel2GroupVariableInfo = map[project_memory.VariableChannel]project_memory.GroupVariableInfo{
|
||||
project_memory.VariableChannel_APP: {
|
||||
GroupName: "应用变量",
|
||||
GroupDesc: "用于配置应用中多处开发场景需要访问的数据,每次新请求均会初始化为默认值。",
|
||||
GroupExtDesc: "",
|
||||
IsReadOnly: false,
|
||||
SubGroupList: []*project_memory.GroupVariableInfo{},
|
||||
VarInfoList: []*project_memory.Variable{},
|
||||
},
|
||||
project_memory.VariableChannel_Custom: {
|
||||
GroupName: "用户变量",
|
||||
GroupDesc: "用于存储每个用户使用项目过程中,需要持久化存储和读取的数据,如用户的语言偏好、个性化设置等。",
|
||||
GroupExtDesc: "",
|
||||
IsReadOnly: false,
|
||||
SubGroupList: []*project_memory.GroupVariableInfo{},
|
||||
VarInfoList: []*project_memory.Variable{},
|
||||
},
|
||||
project_memory.VariableChannel_System: {
|
||||
GroupName: "系统变量",
|
||||
GroupDesc: "可选择开启你需要获取的,系统在用户在请求自动产生的数据,仅可读不可修改。如用于通过ID识别用户或处理某些渠道特有的功能。",
|
||||
GroupExtDesc: "",
|
||||
IsReadOnly: true,
|
||||
SubGroupList: []*project_memory.GroupVariableInfo{},
|
||||
VarInfoList: []*project_memory.Variable{},
|
||||
},
|
||||
}
|
||||
|
||||
func (v *VariableApplicationService) GetSysVariableConf(ctx context.Context, req *kvmemory.GetSysVariableConfRequest) (*kvmemory.GetSysVariableConfResponse, error) {
|
||||
vars := v.DomainSVC.GetSysVariableConf(ctx)
|
||||
|
||||
return &kvmemory.GetSysVariableConfResponse{
|
||||
Conf: vars,
|
||||
GroupConf: vars.GroupByName(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (v *VariableApplicationService) GetProjectVariablesMeta(ctx context.Context, appOwnerID int64, req *project_memory.GetProjectVariableListReq) (*project_memory.GetProjectVariableListResp, error) {
|
||||
uid := ctxutil.GetUIDFromCtx(ctx)
|
||||
if uid == nil {
|
||||
return nil, errorx.New(errno.ErrMemoryPermissionCode, errorx.KV("msg", "session required"))
|
||||
}
|
||||
|
||||
version := ""
|
||||
if req.Version != 0 {
|
||||
version = fmt.Sprintf("%d", req.Version)
|
||||
}
|
||||
|
||||
meta, err := v.DomainSVC.GetProjectVariablesMeta(ctx, req.ProjectID, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupConf, err := v.toGroupVariableInfo(ctx, meta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &project_memory.GetProjectVariableListResp{
|
||||
VariableList: meta.ToProjectVariables(),
|
||||
GroupConf: groupConf,
|
||||
CanEdit: appOwnerID == *uid,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (v *VariableApplicationService) getGroupVariableConf(ctx context.Context, channel project_memory.VariableChannel) project_memory.GroupVariableInfo {
|
||||
groupConf, ok := channel2GroupVariableInfo[channel]
|
||||
if !ok {
|
||||
return project_memory.GroupVariableInfo{}
|
||||
}
|
||||
|
||||
local := i18n.GetLocale(ctx)
|
||||
i18nConf, ok := i18nLocal2GroupVariableInfo[local][channel]
|
||||
if ok {
|
||||
groupConf.GroupName = i18nConf.GroupName
|
||||
groupConf.GroupDesc = i18nConf.GroupDesc
|
||||
}
|
||||
|
||||
return groupConf
|
||||
}
|
||||
|
||||
func (v *VariableApplicationService) toGroupVariableInfo(ctx context.Context, meta *entity.VariablesMeta) ([]*project_memory.GroupVariableInfo, error) {
|
||||
channel2Vars := meta.GroupByChannel()
|
||||
groupConfList := make([]*project_memory.GroupVariableInfo, 0, len(channel2Vars))
|
||||
|
||||
showChannels := []project_memory.VariableChannel{
|
||||
project_memory.VariableChannel_APP,
|
||||
project_memory.VariableChannel_Custom,
|
||||
project_memory.VariableChannel_System,
|
||||
}
|
||||
|
||||
for _, channel := range showChannels {
|
||||
ch := channel
|
||||
vars := channel2Vars[ch]
|
||||
groupConf := v.getGroupVariableConf(ctx, ch)
|
||||
groupConf.DefaultChannel = &ch
|
||||
if channel != project_memory.VariableChannel_System {
|
||||
groupConf.VarInfoList = vars
|
||||
groupConfList = append(groupConfList, &groupConf)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
key2Var := make(map[string]*project_memory.Variable)
|
||||
for _, v := range vars {
|
||||
key2Var[v.Keyword] = v
|
||||
}
|
||||
|
||||
// project_memory.VariableChannel_System
|
||||
sysVars := v.DomainSVC.GetSysVariableConf(ctx).RemoveLocalChannelVariable()
|
||||
groupName2Group := sysVars.GroupByName()
|
||||
subGroupList := make([]*project_memory.GroupVariableInfo, 0, len(groupName2Group))
|
||||
|
||||
for _, group := range groupName2Group {
|
||||
var e entity.SysConfVariables = group.VarInfoList
|
||||
varList := make([]*project_memory.Variable, 0, len(group.VarInfoList))
|
||||
|
||||
for _, defaultSysMeta := range e.ToVariables().ToProjectVariables() {
|
||||
sysMetaInUserConf := key2Var[defaultSysMeta.Keyword]
|
||||
if sysMetaInUserConf == nil {
|
||||
varList = append(varList, defaultSysMeta)
|
||||
} else {
|
||||
varList = append(varList, sysMetaInUserConf)
|
||||
}
|
||||
}
|
||||
|
||||
pGroupVariableInfo := &project_memory.GroupVariableInfo{
|
||||
GroupName: group.GroupName,
|
||||
GroupDesc: group.GroupDesc,
|
||||
GroupExtDesc: group.GroupExtDesc,
|
||||
IsReadOnly: true,
|
||||
VarInfoList: varList,
|
||||
}
|
||||
|
||||
subGroupList = append(subGroupList, pGroupVariableInfo)
|
||||
}
|
||||
|
||||
groupConf.SubGroupList = subGroupList
|
||||
groupConfList = append(groupConfList, &groupConf)
|
||||
}
|
||||
|
||||
return groupConfList, nil
|
||||
}
|
||||
|
||||
func (v *VariableApplicationService) UpdateProjectVariable(ctx context.Context, req project_memory.UpdateProjectVariableReq) (*project_memory.UpdateProjectVariableResp, error) {
|
||||
uid := ctxutil.GetUIDFromCtx(ctx)
|
||||
if uid == nil {
|
||||
return nil, errorx.New(errno.ErrMemoryPermissionCode, errorx.KV("msg", "session required"))
|
||||
}
|
||||
|
||||
if req.UserID == 0 {
|
||||
req.UserID = *uid
|
||||
}
|
||||
|
||||
// TODO: project owner check
|
||||
|
||||
sysVars := v.DomainSVC.GetSysVariableConf(ctx).ToVariables()
|
||||
|
||||
sysVarsKeys2Meta := make(map[string]*entity.VariableMeta)
|
||||
for _, v := range sysVars.Variables {
|
||||
sysVarsKeys2Meta[v.Keyword] = v
|
||||
}
|
||||
|
||||
list := make([]*project_memory.Variable, 0, len(req.VariableList))
|
||||
for _, v := range req.VariableList {
|
||||
if v.Channel == project_memory.VariableChannel_System &&
|
||||
sysVarsKeys2Meta[v.Keyword] == nil {
|
||||
logs.CtxInfof(ctx, "sys variable not found, keyword: %s", v.Keyword)
|
||||
continue
|
||||
}
|
||||
|
||||
list = append(list, v)
|
||||
}
|
||||
|
||||
key2Var := make(map[string]*project_memory.Variable)
|
||||
for _, v := range req.VariableList {
|
||||
key2Var[v.Keyword] = v
|
||||
}
|
||||
|
||||
for _, v := range sysVars.Variables {
|
||||
if key2Var[v.Keyword] == nil {
|
||||
list = append(list, v.ToProjectVariable())
|
||||
} else {
|
||||
if key2Var[v.Keyword].DefaultValue != v.DefaultValue ||
|
||||
key2Var[v.Keyword].VariableType != v.VariableType {
|
||||
return nil, errorx.New(errno.ErrMemoryPermissionCode, errorx.KV("msg", "can not update system variable"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, vv := range list {
|
||||
if vv.Channel == project_memory.VariableChannel_APP {
|
||||
e := entity.NewVariableMeta(vv)
|
||||
err := e.CheckSchema(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_, err := v.DomainSVC.UpsertProjectMeta(ctx, req.ProjectID, "", req.UserID, entity.NewVariables(list))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &project_memory.UpdateProjectVariableResp{
|
||||
Code: 0,
|
||||
Msg: "success",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (v *VariableApplicationService) GetVariableMeta(ctx context.Context, req *project_memory.GetMemoryVariableMetaReq) (*project_memory.GetMemoryVariableMetaResp, error) {
|
||||
vars, err := v.DomainSVC.GetVariableMeta(ctx, req.ConnectorID, req.ConnectorType, req.GetVersion())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
vars.RemoveDisableVariable()
|
||||
|
||||
return &project_memory.GetMemoryVariableMetaResp{
|
||||
VariableMap: vars.GroupByChannel(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (v *VariableApplicationService) DeleteVariableInstance(ctx context.Context, req *kvmemory.DelProfileMemoryRequest) (*kvmemory.DelProfileMemoryResponse, error) {
|
||||
uid := ctxutil.GetUIDFromCtx(ctx)
|
||||
if uid == nil {
|
||||
return nil, errorx.New(errno.ErrMemoryPermissionCode, errorx.KV("msg", "session required"))
|
||||
}
|
||||
|
||||
bizType := ternary.IFElse(req.BotID == 0, project_memory.VariableConnector_Project, project_memory.VariableConnector_Bot)
|
||||
bizID := ternary.IFElse(req.BotID == 0, req.ProjectID, fmt.Sprintf("%d", req.BotID))
|
||||
|
||||
e := entity.NewUserVariableMeta(&model.UserVariableMeta{
|
||||
BizType: bizType,
|
||||
BizID: bizID,
|
||||
Version: "",
|
||||
ConnectorID: req.GetConnectorID(),
|
||||
ConnectorUID: fmt.Sprintf("%d", *uid),
|
||||
})
|
||||
|
||||
err := v.DomainSVC.DeleteVariableInstance(ctx, e, req.Keywords)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &kvmemory.DelProfileMemoryResponse{}, nil
|
||||
}
|
||||
|
||||
func (v *VariableApplicationService) GetPlayGroundMemory(ctx context.Context, req *kvmemory.GetProfileMemoryRequest) (*kvmemory.GetProfileMemoryResponse, error) {
|
||||
uid := ctxutil.GetUIDFromCtx(ctx)
|
||||
if uid == nil {
|
||||
return nil, errorx.New(errno.ErrMemoryPermissionCode, errorx.KV("msg", "session required"))
|
||||
}
|
||||
|
||||
isProjectKV := req.ProjectID != nil
|
||||
versionStr := strconv.FormatInt(req.GetProjectVersion(), 10)
|
||||
if req.GetProjectVersion() == 0 {
|
||||
versionStr = ""
|
||||
}
|
||||
|
||||
bizType := ternary.IFElse(isProjectKV, project_memory.VariableConnector_Project, project_memory.VariableConnector_Bot)
|
||||
bizID := ternary.IFElse(isProjectKV, req.GetProjectID(), fmt.Sprintf("%d", req.BotID))
|
||||
version := ternary.IFElse(isProjectKV, versionStr, "")
|
||||
connectId := ternary.IFElse(req.ConnectorID == nil, consts.CozeConnectorID, req.GetConnectorID())
|
||||
connectorUID := ternary.IFElse(req.UserID == 0, *uid, req.UserID)
|
||||
|
||||
e := entity.NewUserVariableMeta(&model.UserVariableMeta{
|
||||
BizType: bizType,
|
||||
BizID: bizID,
|
||||
Version: version,
|
||||
ConnectorID: connectId,
|
||||
ConnectorUID: fmt.Sprintf("%d", connectorUID),
|
||||
})
|
||||
|
||||
res, err := v.DomainSVC.GetVariableChannelInstance(ctx, e, req.Keywords, req.VariableChannel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &kvmemory.GetProfileMemoryResponse{
|
||||
Memories: res,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (v *VariableApplicationService) SetVariableInstance(ctx context.Context, req *kvmemory.SetKvMemoryReq) (*kvmemory.SetKvMemoryResp, error) {
|
||||
uid := ctxutil.GetUIDFromCtx(ctx)
|
||||
if uid == nil {
|
||||
return nil, errorx.New(errno.ErrMemoryPermissionCode, errorx.KV("msg", "session required"))
|
||||
}
|
||||
|
||||
isProjectKV := req.ProjectID != nil
|
||||
versionStr := strconv.FormatInt(req.GetProjectVersion(), 10)
|
||||
if req.GetProjectVersion() == 0 {
|
||||
versionStr = ""
|
||||
}
|
||||
|
||||
bizType := ternary.IFElse(isProjectKV, project_memory.VariableConnector_Project, project_memory.VariableConnector_Bot)
|
||||
bizID := ternary.IFElse(isProjectKV, req.GetProjectID(), fmt.Sprintf("%d", req.BotID))
|
||||
version := ternary.IFElse(isProjectKV, versionStr, "")
|
||||
connectId := ternary.IFElse(req.ConnectorID == nil, consts.CozeConnectorID, req.GetConnectorID())
|
||||
connectorUID := ternary.IFElse(req.GetUserID() == 0, *uid, req.GetUserID())
|
||||
|
||||
e := entity.NewUserVariableMeta(&model.UserVariableMeta{
|
||||
BizType: bizType,
|
||||
BizID: bizID,
|
||||
Version: version,
|
||||
ConnectorID: connectId,
|
||||
ConnectorUID: fmt.Sprintf("%d", connectorUID),
|
||||
})
|
||||
|
||||
exitKeys, err := v.DomainSVC.SetVariableInstance(ctx, e, req.Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
exitKeysStr, _ := json.Marshal(exitKeys)
|
||||
|
||||
return &kvmemory.SetKvMemoryResp{
|
||||
BaseResp: &base.BaseResp{
|
||||
Extra: map[string]string{"existKeys": string(exitKeysStr)},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
198
backend/application/modelmgr/init.go
Normal file
198
backend/application/modelmgr/init.go
Normal file
@@ -0,0 +1,198 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package modelmgr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
"gorm.io/gorm"
|
||||
|
||||
crossmodelmgr "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/modelmgr/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/modelmgr/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
func InitService(db *gorm.DB, idgen idgen.IDGenerator, oss storage.Storage) (*ModelmgrApplicationService, error) {
|
||||
svc := service.NewModelManager(db, idgen, oss)
|
||||
if err := loadStaticModelConfig(svc, oss); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ModelmgrApplicationSVC.DomainSVC = svc
|
||||
|
||||
return ModelmgrApplicationSVC, nil
|
||||
}
|
||||
|
||||
func loadStaticModelConfig(svc modelmgr.Manager, oss storage.Storage) error {
|
||||
ctx := context.Background()
|
||||
|
||||
id2Meta := make(map[int64]*entity.ModelMeta)
|
||||
var cursor *string
|
||||
for {
|
||||
req := &modelmgr.ListModelMetaRequest{
|
||||
Status: []entity.ModelMetaStatus{
|
||||
crossmodelmgr.StatusInUse,
|
||||
crossmodelmgr.StatusPending,
|
||||
crossmodelmgr.StatusDeleted,
|
||||
},
|
||||
Limit: 100,
|
||||
Cursor: cursor,
|
||||
}
|
||||
listMetaResp, err := svc.ListModelMeta(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, item := range listMetaResp.ModelMetaList {
|
||||
cpItem := item
|
||||
id2Meta[cpItem.ID] = cpItem
|
||||
}
|
||||
if !listMetaResp.HasMore {
|
||||
break
|
||||
}
|
||||
cursor = listMetaResp.NextCursor
|
||||
}
|
||||
|
||||
root, err := os.Getwd()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
envModelMeta, envModelEntity, err := initModelByEnv(root, "resources/conf/model/template")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filePath := filepath.Join(root, "resources/conf/model/meta")
|
||||
staticModelMeta, err := readDirYaml[crossmodelmgr.ModelMeta](filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
staticModelMeta = append(staticModelMeta, envModelMeta...)
|
||||
for _, modelMeta := range staticModelMeta {
|
||||
if _, found := id2Meta[modelMeta.ID]; !found {
|
||||
if modelMeta.IconURI == "" && modelMeta.IconURL == "" {
|
||||
return fmt.Errorf("missing icon URI or icon URL, id=%d", modelMeta.ID)
|
||||
} else if modelMeta.IconURL != "" {
|
||||
// do nothing
|
||||
} else if modelMeta.IconURI != "" {
|
||||
// try local path
|
||||
base := filepath.Base(modelMeta.IconURI)
|
||||
iconPath := filepath.Join("resources/conf/model/icon", base)
|
||||
if _, err = os.Stat(iconPath); err == nil {
|
||||
// try upload icon
|
||||
icon, err := os.ReadFile(iconPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key := fmt.Sprintf("icon_%s_%d", base, time.Now().Second())
|
||||
if err := oss.PutObject(ctx, key, icon); err != nil {
|
||||
return err
|
||||
}
|
||||
modelMeta.IconURI = key
|
||||
} else if errors.Is(err, os.ErrNotExist) {
|
||||
// try to get object from uri
|
||||
if _, err := oss.GetObject(ctx, modelMeta.IconURI); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
newMeta, err := svc.CreateModelMeta(ctx, modelMeta)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
logs.Infof("[loadStaticModelConfig] model meta conflict for id=%d, skip", newMeta.ID)
|
||||
}
|
||||
return err
|
||||
} else {
|
||||
logs.Infof("[loadStaticModelConfig] model meta create success, id=%d", newMeta.ID)
|
||||
}
|
||||
id2Meta[newMeta.ID] = newMeta
|
||||
} else {
|
||||
logs.Infof("[loadStaticModelConfig] model meta founded, skip create, id=%d", modelMeta.ID)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
filePath = filepath.Join(root, "resources/conf/model/entity")
|
||||
staticModel, err := readDirYaml[crossmodelmgr.Model](filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
staticModel = append(staticModel, envModelEntity...)
|
||||
for _, modelEntity := range staticModel {
|
||||
curModelEntities, err := svc.MGetModelByID(ctx, &modelmgr.MGetModelRequest{IDs: []int64{modelEntity.ID}})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(curModelEntities) > 0 {
|
||||
logs.Infof("[loadStaticModelConfig] model entity founded, skip create, id=%d", modelEntity.ID)
|
||||
continue
|
||||
}
|
||||
meta, found := id2Meta[modelEntity.Meta.ID]
|
||||
if !found {
|
||||
return fmt.Errorf("model meta not found for id=%d, model_id=%d", modelEntity.Meta.ID, modelEntity.ID)
|
||||
}
|
||||
modelEntity.Meta = *meta
|
||||
if _, err = svc.CreateModel(ctx, &entity.Model{Model: modelEntity}); err != nil {
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
logs.Infof("[loadStaticModelConfig] model entity conflict for id=%d, skip", modelEntity.ID)
|
||||
}
|
||||
return err
|
||||
} else {
|
||||
logs.Infof("[loadStaticModelConfig] model entity create success, id=%d", modelEntity.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func readDirYaml[T any](dir string) ([]*T, error) {
|
||||
des, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp := make([]*T, 0, len(des))
|
||||
for _, file := range des {
|
||||
if file.IsDir() {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(file.Name(), ".yaml") || strings.HasSuffix(file.Name(), ".yml") {
|
||||
filePath := filepath.Join(dir, file.Name())
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var content T
|
||||
if err := yaml.Unmarshal(data, &content); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp = append(resp, &content)
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
148
backend/application/modelmgr/init_by_env.go
Normal file
148
backend/application/modelmgr/init_by_env.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package modelmgr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
func initModelByEnv(wd, templatePath string) (metaSlice []*modelmgr.ModelMeta, entitySlice []*modelmgr.Model, err error) {
|
||||
metaRoot := filepath.Join(wd, templatePath, "meta")
|
||||
entityRoot := filepath.Join(wd, templatePath, "entity")
|
||||
|
||||
for i := -1; i < 1000; i++ {
|
||||
rawProtocol := os.Getenv(concatEnvKey(modelProtocolPrefix, i))
|
||||
if rawProtocol == "" {
|
||||
if i < 0 {
|
||||
continue
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
protocol := chatmodel.Protocol(rawProtocol)
|
||||
info, valid := getModelEnv(i)
|
||||
if !valid {
|
||||
break
|
||||
}
|
||||
|
||||
mapping, found := modelMapping[protocol]
|
||||
if !found {
|
||||
return nil, nil, fmt.Errorf("[initModelByEnv] unsupport protocol: %s", rawProtocol)
|
||||
}
|
||||
|
||||
switch protocol {
|
||||
case chatmodel.ProtocolArk:
|
||||
fileSuffix, foundTemplate := mapping[info.modelName]
|
||||
if !foundTemplate {
|
||||
logs.Warnf("[initModelByEnv] unsupport model=%s, using default config", info.modelName)
|
||||
}
|
||||
modelMeta, err := readYaml[modelmgr.ModelMeta](filepath.Join(metaRoot, concatTemplateFileName("model_meta_template_ark", fileSuffix)))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
modelEntity, err := readYaml[modelmgr.Model](filepath.Join(entityRoot, concatTemplateFileName("model_entity_template_ark", fileSuffix)))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
id, err := strconv.ParseInt(info.id, 10, 64)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// meta 和 entity 用一个 id,有概率冲突
|
||||
modelMeta.ID = id
|
||||
modelMeta.ConnConfig.Model = info.modelID
|
||||
modelMeta.ConnConfig.APIKey = info.apiKey
|
||||
if info.baseURL != "" {
|
||||
modelMeta.ConnConfig.BaseURL = info.baseURL
|
||||
}
|
||||
modelEntity.ID = id
|
||||
modelEntity.Meta.ID = id
|
||||
if !foundTemplate {
|
||||
modelEntity.Name = info.modelName
|
||||
}
|
||||
|
||||
metaSlice = append(metaSlice, modelMeta)
|
||||
entitySlice = append(entitySlice, modelEntity)
|
||||
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("[initModelByEnv] unsupport protocol: %s", rawProtocol)
|
||||
}
|
||||
}
|
||||
|
||||
return metaSlice, entitySlice, nil
|
||||
}
|
||||
|
||||
type envModelInfo struct {
|
||||
id, modelName, modelID, apiKey, baseURL string
|
||||
}
|
||||
|
||||
func getModelEnv(idx int) (info envModelInfo, valid bool) {
|
||||
info.id = os.Getenv(concatEnvKey(modelOpenCozeIDPrefix, idx))
|
||||
info.modelName = os.Getenv(concatEnvKey(modelNamePrefix, idx))
|
||||
info.modelID = os.Getenv(concatEnvKey(modelIDPrefix, idx))
|
||||
info.apiKey = os.Getenv(concatEnvKey(modelApiKeyPrefix, idx))
|
||||
info.baseURL = os.Getenv(concatEnvKey(modelBaseURLPrefix, idx))
|
||||
valid = info.modelName != "" && info.modelID != "" && info.apiKey != ""
|
||||
return
|
||||
}
|
||||
|
||||
func readYaml[T any](fPath string) (*T, error) {
|
||||
data, err := os.ReadFile(fPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var content T
|
||||
if err := yaml.Unmarshal(data, &content); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &content, nil
|
||||
}
|
||||
|
||||
func concatEnvKey(prefix string, idx int) string {
|
||||
if idx < 0 {
|
||||
return prefix
|
||||
}
|
||||
return fmt.Sprintf("%s_%d", prefix, idx)
|
||||
}
|
||||
|
||||
func concatTemplateFileName(prefix, suffix string) string {
|
||||
if suffix == "" {
|
||||
return prefix + ".yaml"
|
||||
}
|
||||
return prefix + "_" + suffix + ".yaml"
|
||||
}
|
||||
|
||||
const (
|
||||
modelProtocolPrefix = "MODEL_PROTOCOL" // model protocol
|
||||
modelOpenCozeIDPrefix = "MODEL_OPENCOZE_ID" // opencoze model id
|
||||
modelNamePrefix = "MODEL_NAME" // model name,
|
||||
modelIDPrefix = "MODEL_ID" // model in conn config
|
||||
modelApiKeyPrefix = "MODEL_API_KEY" // model api key
|
||||
modelBaseURLPrefix = "MODEL_BASE_URL" // model base url
|
||||
)
|
||||
|
||||
var modelMapping = map[chatmodel.Protocol]map[string]string{
|
||||
chatmodel.ProtocolArk: {
|
||||
"doubao-seed-1.6": "doubao-seed-1.6",
|
||||
"doubao-seed-1.6-flash": "doubao-seed-1.6-flash",
|
||||
"doubao-seed-1.6-thinking": "doubao-seed-1.6-thinking",
|
||||
"doubao-1.5-thinking-vision-pro": "doubao-1.5-thinking-vision-pro",
|
||||
"doubao-1.5-thinking-pro": "doubao-1.5-thinking-pro",
|
||||
"doubao-1.5-vision-pro": "doubao-1.5-vision-pro",
|
||||
"doubao-1.5-vision-lite": "doubao-1.5-vision-lite",
|
||||
"doubao-1.5-pro-32k": "doubao-1.5-pro-32k",
|
||||
"doubao-1.5-pro-256k": "doubao-1.5-pro-256k",
|
||||
"doubao-1.5-lite": "doubao-1.5-lite",
|
||||
"deepseek-r1": "volc_deepseek-r1",
|
||||
"deepseek-v3": "volc_deepseek-v3",
|
||||
},
|
||||
}
|
||||
31
backend/application/modelmgr/init_by_env_test.go
Normal file
31
backend/application/modelmgr/init_by_env_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package modelmgr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
)
|
||||
|
||||
func TestInitByEnv(t *testing.T) {
|
||||
i := 0
|
||||
for k := range modelMapping[chatmodel.ProtocolArk] {
|
||||
_ = os.Setenv(concatEnvKey(modelProtocolPrefix, i), "ark")
|
||||
_ = os.Setenv(concatEnvKey(modelOpenCozeIDPrefix, i), fmt.Sprintf("%d", 45678+i))
|
||||
_ = os.Setenv(concatEnvKey(modelNamePrefix, i), k)
|
||||
_ = os.Setenv(concatEnvKey(modelIDPrefix, i), k)
|
||||
_ = os.Setenv(concatEnvKey(modelApiKeyPrefix, i), "mock_api_key")
|
||||
i++
|
||||
}
|
||||
|
||||
wd, err := os.Getwd()
|
||||
assert.NoError(t, err)
|
||||
|
||||
ms, es, err := initModelByEnv(wd, "../../conf/model/template")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, ms, len(modelMapping[chatmodel.ProtocolArk]))
|
||||
assert.Len(t, es, len(modelMapping[chatmodel.ProtocolArk]))
|
||||
}
|
||||
205
backend/application/modelmgr/modelmgr.go
Normal file
205
backend/application/modelmgr/modelmgr.go
Normal file
@@ -0,0 +1,205 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package modelmgr
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
modelmgrEntity "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/modelmgr"
|
||||
modelEntity "github.com/coze-dev/coze-studio/backend/domain/modelmgr/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/i18n"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
type ModelmgrApplicationService struct {
|
||||
DomainSVC modelmgr.Manager
|
||||
}
|
||||
|
||||
var ModelmgrApplicationSVC = &ModelmgrApplicationService{}
|
||||
|
||||
func (m *ModelmgrApplicationService) GetModelList(ctx context.Context, req *developer_api.GetTypeListRequest) (
|
||||
resp *developer_api.GetTypeListResponse, err error,
|
||||
) {
|
||||
// 一般不太可能同时配置这么多模型
|
||||
const modelMaxLimit = 300
|
||||
|
||||
modelResp, err := m.DomainSVC.ListModel(ctx, &modelmgr.ListModelRequest{
|
||||
Limit: modelMaxLimit,
|
||||
Cursor: nil,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
locale := i18n.GetLocale(ctx)
|
||||
modelList, err := slices.TransformWithErrorCheck(modelResp.ModelList, func(m *modelEntity.Model) (*developer_api.Model, error) {
|
||||
logs.CtxInfof(ctx, "ChatModel DefaultParameters: %v", m.DefaultParameters)
|
||||
return modelDo2To(m, locale)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &developer_api.GetTypeListResponse{
|
||||
Code: 0,
|
||||
Msg: "success",
|
||||
Data: &developer_api.GetTypeListData{
|
||||
ModelList: modelList,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func modelDo2To(model *modelEntity.Model, locale i18n.Locale) (*developer_api.Model, error) {
|
||||
mm := model.Meta
|
||||
|
||||
mps := slices.Transform(model.DefaultParameters,
|
||||
func(param *modelmgrEntity.Parameter) *developer_api.ModelParameter {
|
||||
return parameterDo2To(param, locale)
|
||||
},
|
||||
)
|
||||
|
||||
modalSet := sets.FromSlice(mm.Capability.InputModal)
|
||||
|
||||
return &developer_api.Model{
|
||||
Name: model.Name,
|
||||
ModelType: model.ID,
|
||||
ModelClass: mm.Protocol.TOModelClass(),
|
||||
ModelIcon: mm.IconURL,
|
||||
ModelInputPrice: 0,
|
||||
ModelOutputPrice: 0,
|
||||
ModelQuota: &developer_api.ModelQuota{
|
||||
TokenLimit: int32(mm.Capability.MaxTokens),
|
||||
TokenResp: int32(mm.Capability.OutputTokens),
|
||||
// TokenSystem: 0,
|
||||
// TokenUserIn: 0,
|
||||
// TokenToolsIn: 0,
|
||||
// TokenToolsOut: 0,
|
||||
// TokenData: 0,
|
||||
// TokenHistory: 0,
|
||||
// TokenCutSwitch: false,
|
||||
PriceIn: 0,
|
||||
PriceOut: 0,
|
||||
SystemPromptLimit: nil,
|
||||
},
|
||||
ModelName: mm.Name,
|
||||
ModelClassName: mm.Protocol.TOModelClass().String(),
|
||||
IsOffline: mm.Status != modelmgrEntity.StatusInUse,
|
||||
ModelParams: mps,
|
||||
ModelDesc: []*developer_api.ModelDescGroup{
|
||||
{
|
||||
GroupName: "Description",
|
||||
Desc: []string{model.Description},
|
||||
},
|
||||
},
|
||||
FuncConfig: nil,
|
||||
EndpointName: nil,
|
||||
ModelTagList: nil,
|
||||
IsUpRequired: nil,
|
||||
ModelBriefDesc: mm.Description.Read(locale),
|
||||
ModelSeries: &developer_api.ModelSeriesInfo{ // TODO: 替换为真实配置
|
||||
SeriesName: "热门模型",
|
||||
},
|
||||
ModelStatusDetails: nil,
|
||||
ModelAbility: &developer_api.ModelAbility{
|
||||
CotDisplay: ptr.Of(mm.Capability.Reasoning),
|
||||
FunctionCall: ptr.Of(mm.Capability.FunctionCall),
|
||||
ImageUnderstanding: ptr.Of(modalSet.Contains(modelmgrEntity.ModalImage)),
|
||||
VideoUnderstanding: ptr.Of(modalSet.Contains(modelmgrEntity.ModalVideo)),
|
||||
AudioUnderstanding: ptr.Of(modalSet.Contains(modelmgrEntity.ModalAudio)),
|
||||
SupportMultiModal: ptr.Of(len(modalSet) > 1),
|
||||
PrefillResp: ptr.Of(mm.Capability.PrefillResponse),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parameterDo2To(param *modelmgrEntity.Parameter, locale i18n.Locale) *developer_api.ModelParameter {
|
||||
if param == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
apiOptions := make([]*developer_api.Option, 0, len(param.Options))
|
||||
for _, opt := range param.Options {
|
||||
apiOptions = append(apiOptions, &developer_api.Option{
|
||||
Label: opt.Label,
|
||||
Value: opt.Value,
|
||||
})
|
||||
}
|
||||
|
||||
var custom string
|
||||
var creative, balance, precise *string
|
||||
if val, ok := param.DefaultVal[modelmgrEntity.DefaultTypeDefault]; ok {
|
||||
custom = val
|
||||
}
|
||||
|
||||
if val, ok := param.DefaultVal[modelmgrEntity.DefaultTypeCreative]; ok {
|
||||
creative = ptr.Of(val)
|
||||
}
|
||||
|
||||
if val, ok := param.DefaultVal[modelmgrEntity.DefaultTypeBalance]; ok {
|
||||
balance = ptr.Of(val)
|
||||
}
|
||||
|
||||
if val, ok := param.DefaultVal[modelmgrEntity.DefaultTypePrecise]; ok {
|
||||
precise = ptr.Of(val)
|
||||
}
|
||||
|
||||
return &developer_api.ModelParameter{
|
||||
Name: string(param.Name),
|
||||
Label: param.Label.Read(locale),
|
||||
Desc: param.Desc.Read(locale),
|
||||
Type: func() developer_api.ModelParamType {
|
||||
switch param.Type {
|
||||
case modelmgrEntity.ValueTypeBoolean:
|
||||
return developer_api.ModelParamType_Boolean
|
||||
case modelmgrEntity.ValueTypeInt:
|
||||
return developer_api.ModelParamType_Int
|
||||
case modelmgrEntity.ValueTypeFloat:
|
||||
return developer_api.ModelParamType_Float
|
||||
default:
|
||||
return developer_api.ModelParamType_String
|
||||
}
|
||||
}(),
|
||||
Min: param.Min,
|
||||
Max: param.Max,
|
||||
Precision: int32(param.Precision),
|
||||
DefaultVal: &developer_api.ModelParamDefaultValue{
|
||||
DefaultVal: custom,
|
||||
Creative: creative,
|
||||
Balance: balance,
|
||||
Precise: precise,
|
||||
},
|
||||
Options: apiOptions,
|
||||
ParamClass: &developer_api.ModelParamClass{
|
||||
ClassID: func() int32 {
|
||||
switch param.Style.Widget {
|
||||
case modelmgrEntity.WidgetSlider:
|
||||
return 1
|
||||
case modelmgrEntity.WidgetRadioButtons:
|
||||
return 2
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}(),
|
||||
Label: param.Style.Label.Read(locale),
|
||||
},
|
||||
}
|
||||
}
|
||||
39
backend/application/openauth/init.go
Normal file
39
backend/application/openauth/init.go
Normal file
@@ -0,0 +1,39 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package openauth
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
openapiauth2 "github.com/coze-dev/coze-studio/backend/domain/openauth/openapiauth"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
)
|
||||
|
||||
var (
|
||||
openapiAuthDomainSVC openapiauth2.APIAuth
|
||||
)
|
||||
|
||||
func InitService(db *gorm.DB, idGenSVC idgen.IDGenerator) *OpenAuthApplicationService {
|
||||
openapiAuthDomainSVC = openapiauth2.NewService(&openapiauth2.Components{
|
||||
IDGen: idGenSVC,
|
||||
DB: db,
|
||||
})
|
||||
|
||||
OpenAuthApplication.OpenAPIDomainSVC = openapiAuthDomainSVC
|
||||
|
||||
return OpenAuthApplication
|
||||
}
|
||||
203
backend/application/openauth/openapiauth.go
Normal file
203
backend/application/openauth/openapiauth.go
Normal file
@@ -0,0 +1,203 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package openauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
openapimodel "github.com/coze-dev/coze-studio/backend/api/model/permission/openapiauth"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
openapi "github.com/coze-dev/coze-studio/backend/domain/openauth/openapiauth"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/openauth/openapiauth/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
type OpenAuthApplicationService struct {
|
||||
OpenAPIDomainSVC openapi.APIAuth
|
||||
}
|
||||
|
||||
var OpenAuthApplication = &OpenAuthApplicationService{}
|
||||
|
||||
func (s *OpenAuthApplicationService) GetPersonalAccessTokenAndPermission(ctx context.Context, req *openapimodel.GetPersonalAccessTokenAndPermissionRequest) (*openapimodel.GetPersonalAccessTokenAndPermissionResponse, error) {
|
||||
|
||||
resp := new(openapimodel.GetPersonalAccessTokenAndPermissionResponse)
|
||||
|
||||
userID := ctxutil.GetUIDFromCtx(ctx)
|
||||
|
||||
appReq := &entity.GetApiKey{
|
||||
ID: req.ID,
|
||||
}
|
||||
apiKeyResp, err := openapiAuthDomainSVC.Get(ctx, appReq)
|
||||
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "OpenAuthApplicationService.GetPersonalAccessTokenAndPermission failed, err=%v", err)
|
||||
return resp, errors.New("GetPersonalAccessTokenAndPermission failed")
|
||||
}
|
||||
if apiKeyResp == nil {
|
||||
return resp, errors.New("GetPersonalAccessTokenAndPermission failed")
|
||||
}
|
||||
|
||||
if apiKeyResp.UserID != *userID {
|
||||
return resp, errors.New("permission not match")
|
||||
}
|
||||
resp.Data = &openapimodel.GetPersonalAccessTokenAndPermissionResponseData{
|
||||
PersonalAccessToken: &openapimodel.PersonalAccessToken{
|
||||
ID: apiKeyResp.ID,
|
||||
Name: apiKeyResp.Name,
|
||||
ExpireAt: apiKeyResp.ExpiredAt,
|
||||
CreatedAt: apiKeyResp.CreatedAt,
|
||||
UpdatedAt: apiKeyResp.UpdatedAt,
|
||||
},
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *OpenAuthApplicationService) CreatePersonalAccessToken(ctx context.Context, req *openapimodel.CreatePersonalAccessTokenAndPermissionRequest) (*openapimodel.CreatePersonalAccessTokenAndPermissionResponse, error) {
|
||||
resp := new(openapimodel.CreatePersonalAccessTokenAndPermissionResponse)
|
||||
userID := ctxutil.GetUIDFromCtx(ctx)
|
||||
|
||||
appReq := &entity.CreateApiKey{
|
||||
Name: req.Name,
|
||||
Expire: req.ExpireAt,
|
||||
UserID: *userID,
|
||||
}
|
||||
|
||||
if req.DurationDay == "customize" {
|
||||
appReq.Expire = req.ExpireAt
|
||||
} else {
|
||||
expireDay, err := strconv.ParseInt(req.DurationDay, 10, 64)
|
||||
if err != nil {
|
||||
return resp, errors.New("invalid expireDay")
|
||||
}
|
||||
appReq.Expire = time.Now().Add(time.Duration(expireDay) * time.Hour * 24).Unix()
|
||||
}
|
||||
|
||||
apiKeyResp, err := openapiAuthDomainSVC.Create(ctx, appReq)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "OpenAuthApplicationService.CreatePersonalAccessToken failed, err=%v", err)
|
||||
return resp, errors.New("CreatePersonalAccessToken failed")
|
||||
}
|
||||
resp.Data = &openapimodel.CreatePersonalAccessTokenAndPermissionResponseData{
|
||||
PersonalAccessToken: &openapimodel.PersonalAccessToken{
|
||||
ID: apiKeyResp.ID,
|
||||
Name: apiKeyResp.Name,
|
||||
ExpireAt: apiKeyResp.ExpiredAt,
|
||||
|
||||
CreatedAt: apiKeyResp.CreatedAt,
|
||||
UpdatedAt: apiKeyResp.UpdatedAt,
|
||||
},
|
||||
Token: apiKeyResp.ApiKey,
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *OpenAuthApplicationService) ListPersonalAccessTokens(ctx context.Context, req *openapimodel.ListPersonalAccessTokensRequest) (*openapimodel.ListPersonalAccessTokensResponse, error) {
|
||||
|
||||
resp := new(openapimodel.ListPersonalAccessTokensResponse)
|
||||
|
||||
userID := ctxutil.GetUIDFromCtx(ctx)
|
||||
appReq := &entity.ListApiKey{
|
||||
UserID: *userID,
|
||||
Page: *req.Page,
|
||||
Limit: *req.Size,
|
||||
}
|
||||
|
||||
apiKeyResp, err := openapiAuthDomainSVC.List(ctx, appReq)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "OpenAuthApplicationService.ListPersonalAccessTokens failed, err=%v", err)
|
||||
return resp, errors.New("ListPersonalAccessTokens failed")
|
||||
}
|
||||
|
||||
if apiKeyResp == nil {
|
||||
return resp, nil
|
||||
}
|
||||
resp.Data = &openapimodel.ListPersonalAccessTokensResponseData{
|
||||
HasMore: apiKeyResp.HasMore,
|
||||
PersonalAccessTokens: slices.Transform(apiKeyResp.ApiKeys, func(a *entity.ApiKey) *openapimodel.PersonalAccessTokenWithCreatorInfo {
|
||||
lastUsedAt := a.LastUsedAt
|
||||
if lastUsedAt == 0 {
|
||||
lastUsedAt = -1
|
||||
}
|
||||
return &openapimodel.PersonalAccessTokenWithCreatorInfo{
|
||||
ID: a.ID,
|
||||
Name: a.Name,
|
||||
ExpireAt: a.ExpiredAt,
|
||||
CreatedAt: a.CreatedAt,
|
||||
UpdatedAt: a.UpdatedAt,
|
||||
LastUsedAt: lastUsedAt,
|
||||
}
|
||||
}),
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *OpenAuthApplicationService) DeletePersonalAccessTokenAndPermission(ctx context.Context, req *openapimodel.DeletePersonalAccessTokenAndPermissionRequest) (*openapimodel.DeletePersonalAccessTokenAndPermissionResponse, error) {
|
||||
resp := new(openapimodel.DeletePersonalAccessTokenAndPermissionResponse)
|
||||
|
||||
userID := ctxutil.GetUIDFromCtx(ctx)
|
||||
|
||||
appReq := &entity.DeleteApiKey{
|
||||
ID: req.ID,
|
||||
UserID: *userID,
|
||||
}
|
||||
err := openapiAuthDomainSVC.Delete(ctx, appReq)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "OpenAuthApplicationService.DeletePersonalAccessTokenAndPermission failed, err=%v", err)
|
||||
return resp, errors.New("DeletePersonalAccessTokenAndPermission failed")
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *OpenAuthApplicationService) UpdatePersonalAccessTokenAndPermission(ctx context.Context, req *openapimodel.UpdatePersonalAccessTokenAndPermissionRequest) (*openapimodel.UpdatePersonalAccessTokenAndPermissionResponse, error) {
|
||||
resp := new(openapimodel.UpdatePersonalAccessTokenAndPermissionResponse)
|
||||
userID := ctxutil.GetUIDFromCtx(ctx)
|
||||
|
||||
upErr := openapiAuthDomainSVC.Save(ctx, &entity.SaveMeta{
|
||||
ID: req.ID,
|
||||
Name: ptr.Of(req.Name),
|
||||
UserID: *userID,
|
||||
})
|
||||
|
||||
return resp, upErr
|
||||
}
|
||||
|
||||
func (s *OpenAuthApplicationService) UpdateLastUsedAt(ctx context.Context, apiID int64, userID int64) error {
|
||||
upErr := openapiAuthDomainSVC.Save(ctx, &entity.SaveMeta{
|
||||
ID: apiID,
|
||||
LastUsedAt: ptr.Of(time.Now().Unix()),
|
||||
UserID: userID,
|
||||
})
|
||||
return upErr
|
||||
}
|
||||
|
||||
func (s *OpenAuthApplicationService) CheckPermission(ctx context.Context, token string) (*entity.ApiKey, error) {
|
||||
appReq := &entity.CheckPermission{
|
||||
ApiKey: token,
|
||||
}
|
||||
apiKey, err := openapiAuthDomainSVC.CheckPermission(ctx, appReq)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "OpenAuthApplicationService.CheckPermission failed, err=%v", err)
|
||||
return nil, errors.New("CheckPermission failed")
|
||||
}
|
||||
return apiKey, nil
|
||||
}
|
||||
79
backend/application/plugin/init.go
Normal file
79
backend/application/plugin/init.go
Normal file
@@ -0,0 +1,79 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
pluginConf "github.com/coze-dev/coze-studio/backend/domain/plugin/conf"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/repository"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/service"
|
||||
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
|
||||
user "github.com/coze-dev/coze-studio/backend/domain/user/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
)
|
||||
|
||||
type ServiceComponents struct {
|
||||
IDGen idgen.IDGenerator
|
||||
DB *gorm.DB
|
||||
OSS storage.Storage
|
||||
EventBus search.ResourceEventBus
|
||||
UserSVC user.User
|
||||
}
|
||||
|
||||
func InitService(ctx context.Context, components *ServiceComponents) (*PluginApplicationService, error) {
|
||||
err := pluginConf.InitConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toolRepo := repository.NewToolRepo(&repository.ToolRepoComponents{
|
||||
IDGen: components.IDGen,
|
||||
DB: components.DB,
|
||||
})
|
||||
|
||||
pluginRepo := repository.NewPluginRepo(&repository.PluginRepoComponents{
|
||||
IDGen: components.IDGen,
|
||||
DB: components.DB,
|
||||
})
|
||||
|
||||
oauthRepo := repository.NewOAuthRepo(&repository.OAuthRepoComponents{
|
||||
IDGen: components.IDGen,
|
||||
DB: components.DB,
|
||||
})
|
||||
|
||||
pluginSVC := service.NewService(&service.Components{
|
||||
IDGen: components.IDGen,
|
||||
DB: components.DB,
|
||||
OSS: components.OSS,
|
||||
PluginRepo: pluginRepo,
|
||||
ToolRepo: toolRepo,
|
||||
OAuthRepo: oauthRepo,
|
||||
})
|
||||
|
||||
PluginApplicationSVC.DomainSVC = pluginSVC
|
||||
PluginApplicationSVC.eventbus = components.EventBus
|
||||
PluginApplicationSVC.oss = components.OSS
|
||||
PluginApplicationSVC.userSVC = components.UserSVC
|
||||
PluginApplicationSVC.pluginRepo = pluginRepo
|
||||
PluginApplicationSVC.toolRepo = toolRepo
|
||||
|
||||
return PluginApplicationSVC, nil
|
||||
}
|
||||
34
backend/application/plugin/model.go
Normal file
34
backend/application/plugin/model.go
Normal file
@@ -0,0 +1,34 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package plugin
|
||||
|
||||
import (
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
)
|
||||
|
||||
type CopyPluginRequest struct {
|
||||
PluginID int64
|
||||
UserID int64
|
||||
CopyScene model.CopyScene
|
||||
TargetAPPID *int64
|
||||
}
|
||||
|
||||
type CopyPluginResponse struct {
|
||||
Plugin *entity.PluginInfo
|
||||
Tools map[int64]*entity.ToolInfo // old tool id -> new tool id
|
||||
}
|
||||
1759
backend/application/plugin/plugin.go
Normal file
1759
backend/application/plugin/plugin.go
Normal file
File diff suppressed because it is too large
Load Diff
34
backend/application/prompt/init.go
Normal file
34
backend/application/prompt/init.go
Normal file
@@ -0,0 +1,34 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package prompt
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/application/search"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/prompt/repository"
|
||||
prompt "github.com/coze-dev/coze-studio/backend/domain/prompt/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
)
|
||||
|
||||
func InitService(db *gorm.DB, idGenSVC idgen.IDGenerator, re search.ResourceEventBus) *PromptApplicationService {
|
||||
repo := repository.NewPromptRepo(db, idGenSVC)
|
||||
PromptSVC.DomainSVC = prompt.NewService(repo)
|
||||
PromptSVC.eventbus = re
|
||||
|
||||
return PromptSVC
|
||||
}
|
||||
238
backend/application/prompt/prompt.go
Normal file
238
backend/application/prompt/prompt.go
Normal file
@@ -0,0 +1,238 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package prompt
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/playground"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/resource/common"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
"github.com/coze-dev/coze-studio/backend/application/search"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/prompt/entity"
|
||||
prompt "github.com/coze-dev/coze-studio/backend/domain/prompt/service"
|
||||
searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
type PromptApplicationService struct {
|
||||
DomainSVC prompt.Prompt
|
||||
eventbus search.ResourceEventBus
|
||||
}
|
||||
|
||||
var PromptSVC = &PromptApplicationService{}
|
||||
|
||||
func (p *PromptApplicationService) UpsertPromptResource(ctx context.Context, req *playground.UpsertPromptResourceRequest) (resp *playground.UpsertPromptResourceResponse, err error) {
|
||||
session := ctxutil.GetUserSessionFromCtx(ctx)
|
||||
if session == nil {
|
||||
return nil, errorx.New(errno.ErrPromptPermissionCode, errorx.KV("msg", "no session data provided"))
|
||||
}
|
||||
|
||||
promptID := req.Prompt.GetID()
|
||||
if promptID == 0 {
|
||||
// create a new prompt resource
|
||||
resp, err = p.createPromptResource(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pErr := p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{
|
||||
OpType: searchEntity.Created,
|
||||
Resource: &searchEntity.ResourceDocument{
|
||||
ResType: common.ResType_Prompt,
|
||||
ResID: resp.Data.ID,
|
||||
Name: req.Prompt.Name,
|
||||
SpaceID: req.Prompt.SpaceID,
|
||||
OwnerID: &session.UserID,
|
||||
PublishStatus: ptr.Of(common.PublishStatus_Published),
|
||||
},
|
||||
})
|
||||
if pErr != nil {
|
||||
logs.CtxErrorf(ctx, "publish resource event failed: %v", pErr)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// update an existing prompt resource
|
||||
resp, err = p.updatePromptResource(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pErr := p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{
|
||||
OpType: searchEntity.Updated,
|
||||
Resource: &searchEntity.ResourceDocument{
|
||||
ResType: common.ResType_Prompt,
|
||||
ResID: resp.Data.ID,
|
||||
Name: req.Prompt.Name,
|
||||
SpaceID: req.Prompt.SpaceID,
|
||||
},
|
||||
})
|
||||
if pErr != nil {
|
||||
logs.CtxErrorf(ctx, "publish resource event failed: %v", pErr)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (p *PromptApplicationService) GetPromptResourceInfo(ctx context.Context, req *playground.GetPromptResourceInfoRequest) (
|
||||
resp *playground.GetPromptResourceInfoResponse, err error,
|
||||
) {
|
||||
promptInfo, err := p.DomainSVC.GetPromptResource(ctx, req.GetPromptResourceID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &playground.GetPromptResourceInfoResponse{
|
||||
Data: promptInfoDo2To(promptInfo),
|
||||
Code: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *PromptApplicationService) GetOfficialPromptResourceList(ctx context.Context, c *playground.GetOfficialPromptResourceListRequest) (
|
||||
*playground.GetOfficialPromptResourceListResponse, error,
|
||||
) {
|
||||
session := ctxutil.GetUserSessionFromCtx(ctx)
|
||||
if session == nil {
|
||||
return nil, errorx.New(errno.ErrPromptPermissionCode, errorx.KV("msg", "no session data provided"))
|
||||
}
|
||||
|
||||
promptList, err := p.DomainSVC.ListOfficialPromptResource(ctx, c.GetKeyword())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &playground.GetOfficialPromptResourceListResponse{
|
||||
PromptResourceList: slices.Transform(promptList, func(p *entity.PromptResource) *playground.PromptResource {
|
||||
return promptInfoDo2To(p)
|
||||
}),
|
||||
Code: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *PromptApplicationService) DeletePromptResource(ctx context.Context, req *playground.DeletePromptResourceRequest) (resp *playground.DeletePromptResourceResponse, err error) {
|
||||
uid := ctxutil.GetUIDFromCtx(ctx)
|
||||
if uid == nil {
|
||||
return nil, errorx.New(errno.ErrPromptPermissionCode, errorx.KV("msg", "no session data provided"))
|
||||
}
|
||||
|
||||
promptInfo, err := p.DomainSVC.GetPromptResource(ctx, req.GetPromptResourceID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if promptInfo.CreatorID != *uid {
|
||||
return nil, errorx.New(errno.ErrPromptPermissionCode, errorx.KV("msg", "no permission"))
|
||||
}
|
||||
|
||||
err = p.DomainSVC.DeletePromptResource(ctx, req.GetPromptResourceID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pErr := p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{
|
||||
OpType: searchEntity.Deleted,
|
||||
Resource: &searchEntity.ResourceDocument{
|
||||
ResType: common.ResType_Prompt,
|
||||
ResID: req.GetPromptResourceID(),
|
||||
},
|
||||
})
|
||||
if pErr != nil {
|
||||
logs.CtxErrorf(ctx, "publish resource event failed: %v", pErr)
|
||||
}
|
||||
|
||||
return &playground.DeletePromptResourceResponse{
|
||||
Code: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *PromptApplicationService) createPromptResource(ctx context.Context, req *playground.UpsertPromptResourceRequest) (resp *playground.UpsertPromptResourceResponse, err error) {
|
||||
do := p.toPromptResourceDO(req.Prompt)
|
||||
uid := ctxutil.GetUIDFromCtx(ctx)
|
||||
|
||||
do.CreatorID = *uid
|
||||
|
||||
promptID, err := p.DomainSVC.CreatePromptResource(ctx, do)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &playground.UpsertPromptResourceResponse{
|
||||
Data: &playground.ShowPromptResource{
|
||||
ID: promptID,
|
||||
},
|
||||
Code: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *PromptApplicationService) updatePromptResource(ctx context.Context, req *playground.UpsertPromptResourceRequest) (resp *playground.UpsertPromptResourceResponse, err error) {
|
||||
promptID := req.Prompt.GetID()
|
||||
|
||||
promptResource, err := p.DomainSVC.GetPromptResource(ctx, promptID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logs.CtxInfof(ctx, "promptResource.SpaceID: %v , promptResource.CreatorID : %v", promptResource.SpaceID, promptResource.CreatorID)
|
||||
uid := ctxutil.GetUIDFromCtx(ctx)
|
||||
|
||||
if promptResource.CreatorID != *uid {
|
||||
return nil, errorx.New(errno.ErrPromptPermissionCode, errorx.KV("msg", "no permission"))
|
||||
}
|
||||
|
||||
promptResource.Name = req.Prompt.GetName()
|
||||
promptResource.Description = req.Prompt.GetDescription()
|
||||
promptResource.PromptText = req.Prompt.GetPromptText()
|
||||
|
||||
err = p.DomainSVC.UpdatePromptResource(ctx, promptResource)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &playground.UpsertPromptResourceResponse{
|
||||
Data: &playground.ShowPromptResource{
|
||||
ID: promptID,
|
||||
},
|
||||
Code: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *PromptApplicationService) toPromptResourceDO(m *playground.PromptResource) *entity.PromptResource {
|
||||
e := entity.PromptResource{}
|
||||
e.ID = m.GetID()
|
||||
e.PromptText = m.GetPromptText()
|
||||
e.SpaceID = m.GetSpaceID()
|
||||
e.Name = m.GetName()
|
||||
e.Description = m.GetDescription()
|
||||
|
||||
return &e
|
||||
}
|
||||
|
||||
func promptInfoDo2To(p *entity.PromptResource) *playground.PromptResource {
|
||||
return &playground.PromptResource{
|
||||
ID: ptr.Of(p.ID),
|
||||
SpaceID: ptr.Of(p.SpaceID),
|
||||
Name: ptr.Of(p.Name),
|
||||
Description: ptr.Of(p.Description),
|
||||
PromptText: ptr.Of(p.PromptText),
|
||||
}
|
||||
}
|
||||
100
backend/application/search/init.go
Normal file
100
backend/application/search/init.go
Normal file
@@ -0,0 +1,100 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/application/singleagent"
|
||||
app "github.com/coze-dev/coze-studio/backend/domain/app/service"
|
||||
connector "github.com/coze-dev/coze-studio/backend/domain/connector/service"
|
||||
knowledge "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
|
||||
database "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/service"
|
||||
prompt "github.com/coze-dev/coze-studio/backend/domain/prompt/service"
|
||||
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
|
||||
user "github.com/coze-dev/coze-studio/backend/domain/user/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/es"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
)
|
||||
|
||||
type ServiceComponents struct {
|
||||
DB *gorm.DB
|
||||
Cache *redis.Client
|
||||
TOS storage.Storage
|
||||
ESClient es.Client
|
||||
ProjectEventBus ProjectEventBus
|
||||
ResourceEventBus ResourceEventBus
|
||||
SingleAgentDomainSVC singleagent.SingleAgent
|
||||
APPDomainSVC app.AppService
|
||||
KnowledgeDomainSVC knowledge.Knowledge
|
||||
PluginDomainSVC service.PluginService
|
||||
WorkflowDomainSVC workflow.Service
|
||||
UserDomainSVC user.User
|
||||
ConnectorDomainSVC connector.Connector
|
||||
PromptDomainSVC prompt.Prompt
|
||||
DatabaseDomainSVC database.Database
|
||||
}
|
||||
|
||||
func InitService(ctx context.Context, s *ServiceComponents) (*SearchApplicationService, error) {
|
||||
searchDomainSVC := search.NewDomainService(ctx, s.ESClient)
|
||||
|
||||
SearchSVC.DomainSVC = searchDomainSVC
|
||||
SearchSVC.ServiceComponents = s
|
||||
|
||||
// setup consumer
|
||||
searchConsumer := search.NewProjectHandler(ctx, s.ESClient)
|
||||
|
||||
logs.Infof("start search domain consumer...")
|
||||
nameServer := os.Getenv(consts.MQServer)
|
||||
|
||||
err := eventbus.RegisterConsumer(nameServer, consts.RMQTopicApp, consts.RMQConsumeGroupApp, searchConsumer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("register search consumer failed, err=%w", err)
|
||||
}
|
||||
|
||||
searchResourceConsumer := search.NewResourceHandler(ctx, s.ESClient)
|
||||
|
||||
err = eventbus.RegisterConsumer(nameServer, consts.RMQTopicResource, consts.RMQConsumeGroupResource, searchResourceConsumer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("register search consumer failed, err=%w", err)
|
||||
}
|
||||
|
||||
return SearchSVC, nil
|
||||
}
|
||||
|
||||
type (
|
||||
ResourceEventBus = search.ResourceEventBus
|
||||
ProjectEventBus = search.ProjectEventBus
|
||||
)
|
||||
|
||||
func NewResourceEventBus(p eventbus.Producer) search.ResourceEventBus {
|
||||
return search.NewResourceEventBus(p)
|
||||
}
|
||||
|
||||
func NewProjectEventBus(p eventbus.Producer) search.ProjectEventBus {
|
||||
return search.NewProjectEventBus(p)
|
||||
}
|
||||
194
backend/application/search/project_pack.go
Normal file
194
backend/application/search/project_pack.go
Normal file
@@ -0,0 +1,194 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/intelligence"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/intelligence/common"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/app/entity"
|
||||
appService "github.com/coze-dev/coze-studio/backend/domain/app/service"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
type projectInfo struct {
|
||||
iconURI string
|
||||
desc string
|
||||
}
|
||||
|
||||
type ProjectPacker interface {
|
||||
GetProjectInfo(ctx context.Context) (*projectInfo, error)
|
||||
GetPermissionInfo() *intelligence.IntelligencePermissionInfo
|
||||
GetPublishedInfo(ctx context.Context) *intelligence.IntelligencePublishInfo
|
||||
GetUserInfo(ctx context.Context, userID int64) *common.User
|
||||
}
|
||||
|
||||
func NewPackProject(uid, projectID int64, tp common.IntelligenceType, s *SearchApplicationService) (ProjectPacker, error) {
|
||||
base := projectBase{SVC: s, projectID: projectID, iType: tp, uid: uid}
|
||||
|
||||
switch tp {
|
||||
case common.IntelligenceType_Bot:
|
||||
return &agentPacker{projectBase: base}, nil
|
||||
case common.IntelligenceType_Project:
|
||||
return &appPacker{projectBase: base}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported project_type: %d , project_id : %d", tp, projectID)
|
||||
}
|
||||
|
||||
type projectBase struct {
|
||||
projectID int64 // agent_id or application_id
|
||||
uid int64
|
||||
SVC *SearchApplicationService
|
||||
iType common.IntelligenceType
|
||||
}
|
||||
|
||||
func (p *projectBase) GetPermissionInfo() *intelligence.IntelligencePermissionInfo {
|
||||
return &intelligence.IntelligencePermissionInfo{
|
||||
InCollaboration: false,
|
||||
CanDelete: true,
|
||||
CanView: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *projectBase) GetUserInfo(ctx context.Context, userID int64) *common.User {
|
||||
u, err := p.SVC.UserDomainSVC.GetUserInfo(ctx, userID)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "[projectBase-GetUserInfo] failed to get user info, user_id: %d, err: %v", userID, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &common.User{
|
||||
UserID: u.UserID,
|
||||
AvatarURL: u.IconURL,
|
||||
UserUniqueName: u.UniqueName,
|
||||
}
|
||||
}
|
||||
|
||||
type agentPacker struct {
|
||||
projectBase
|
||||
}
|
||||
|
||||
func (a *agentPacker) GetProjectInfo(ctx context.Context) (*projectInfo, error) {
|
||||
agent, err := a.SVC.SingleAgentDomainSVC.GetSingleAgentDraft(ctx, a.projectID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if agent == nil {
|
||||
return nil, fmt.Errorf("agent info is nil")
|
||||
}
|
||||
return &projectInfo{
|
||||
iconURI: agent.IconURI,
|
||||
desc: agent.Desc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *agentPacker) GetPublishedInfo(ctx context.Context) *intelligence.IntelligencePublishInfo {
|
||||
pubInfo, err := p.SVC.SingleAgentDomainSVC.GetPublishedInfo(ctx, p.projectID)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "[agent-GetPublishedInfo]failed to get published info, agent_id: %d, err: %v", p.projectID, err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
connectors := make([]*common.ConnectorInfo, 0, len(pubInfo.ConnectorID2PublishTime))
|
||||
for connectorID := range pubInfo.ConnectorID2PublishTime {
|
||||
c, err := p.SVC.ConnectorDomainSVC.GetByID(ctx, connectorID)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to get connector by id: %d, err: %v", connectorID, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
connectors = append(connectors, &common.ConnectorInfo{
|
||||
ID: conv.Int64ToStr(c.ID),
|
||||
Name: c.Name,
|
||||
ConnectorStatus: common.ConnectorDynamicStatus(c.ConnectorStatus),
|
||||
Icon: c.URL,
|
||||
})
|
||||
}
|
||||
|
||||
return &intelligence.IntelligencePublishInfo{
|
||||
PublishTime: conv.Int64ToStr(pubInfo.LastPublishTimeMS / 1000),
|
||||
HasPublished: pubInfo.LastPublishTimeMS > 0,
|
||||
Connectors: connectors,
|
||||
}
|
||||
}
|
||||
|
||||
type appPacker struct {
|
||||
projectBase
|
||||
}
|
||||
|
||||
func (a *appPacker) GetProjectInfo(ctx context.Context) (*projectInfo, error) {
|
||||
app, err := a.SVC.APPDomainSVC.GetDraftAPP(ctx, a.projectID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &projectInfo{
|
||||
iconURI: app.GetIconURI(),
|
||||
desc: app.GetDesc(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *appPacker) GetPublishedInfo(ctx context.Context) *intelligence.IntelligencePublishInfo {
|
||||
record, exist, err := a.SVC.APPDomainSVC.GetAPPPublishRecord(ctx, &appService.GetAPPPublishRecordRequest{
|
||||
APPID: a.projectID,
|
||||
Oldest: true,
|
||||
})
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "[app-GetPublishedInfo] failed to get published info, app_id=%d, err=%v", a.projectID, err)
|
||||
return nil
|
||||
}
|
||||
if !exist {
|
||||
return &intelligence.IntelligencePublishInfo{
|
||||
PublishTime: "",
|
||||
HasPublished: false,
|
||||
Connectors: nil,
|
||||
}
|
||||
}
|
||||
|
||||
connectorInfo := make([]*common.ConnectorInfo, 0, len(record.ConnectorPublishRecords))
|
||||
connectorIDs := slices.Transform(record.ConnectorPublishRecords, func(c *entity.ConnectorPublishRecord) int64 {
|
||||
return c.ConnectorID
|
||||
})
|
||||
|
||||
connectors, err := a.SVC.ConnectorDomainSVC.GetByIDs(ctx, connectorIDs)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "[app-GetPublishedInfo] failed to get connector info, app_id=%d, err=%v", a.projectID, err)
|
||||
} else {
|
||||
for _, c := range connectors {
|
||||
connectorInfo = append(connectorInfo, &common.ConnectorInfo{
|
||||
ID: conv.Int64ToStr(c.ID),
|
||||
Name: c.Name,
|
||||
ConnectorStatus: common.ConnectorDynamicStatus(c.ConnectorStatus),
|
||||
Icon: c.URL,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return &intelligence.IntelligencePublishInfo{
|
||||
PublishTime: strconv.FormatInt(record.APP.GetPublishedAtMS()/1000, 10),
|
||||
HasPublished: record.APP.Published(),
|
||||
Connectors: connectorInfo,
|
||||
}
|
||||
}
|
||||
452
backend/application/search/project_search.go
Normal file
452
backend/application/search/project_search.go
Normal file
@@ -0,0 +1,452 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
search2 "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/search"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/flow/marketplace/marketplace_common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/flow/marketplace/product_common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/flow/marketplace/product_public_api"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/intelligence"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/intelligence/common"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/taskgroup"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
var projectType2iconURI = map[common.IntelligenceType]string{
|
||||
common.IntelligenceType_Bot: consts.DefaultAgentIcon,
|
||||
common.IntelligenceType_Project: consts.DefaultAppIcon,
|
||||
}
|
||||
|
||||
func (s *SearchApplicationService) GetDraftIntelligenceList(ctx context.Context, req *intelligence.GetDraftIntelligenceListRequest) (
|
||||
resp *intelligence.GetDraftIntelligenceListResponse, err error,
|
||||
) {
|
||||
userID := ctxutil.GetUIDFromCtx(ctx)
|
||||
if userID == nil {
|
||||
return nil, errorx.New(errno.ErrSearchPermissionCode, errorx.KV("msg", "session is required"))
|
||||
}
|
||||
|
||||
do := searchRequestTo2Do(*userID, req)
|
||||
|
||||
searchResp, err := s.DomainSVC.SearchProjects(ctx, do)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(searchResp.Data) == 0 {
|
||||
return &intelligence.GetDraftIntelligenceListResponse{
|
||||
Data: &intelligence.DraftIntelligenceListData{
|
||||
Intelligences: make([]*intelligence.IntelligenceData, 0),
|
||||
Total: 0,
|
||||
HasMore: false,
|
||||
NextCursorID: "",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
tasks := taskgroup.NewUninterruptibleTaskGroup(ctx, len(searchResp.Data))
|
||||
lock := sync.Mutex{}
|
||||
intelligenceDataList := make([]*intelligence.IntelligenceData, len(searchResp.Data))
|
||||
|
||||
logs.CtxDebugf(ctx, "[GetDraftIntelligenceList] searchResp.Data: %v", conv.DebugJsonToStr(searchResp.Data))
|
||||
|
||||
for idx := range searchResp.Data {
|
||||
data := searchResp.Data[idx]
|
||||
index := idx
|
||||
tasks.Go(func() error {
|
||||
info, err := s.packIntelligenceData(ctx, data)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "[packIntelligenceData] failed id %v, type %d , name %s, err: %v", data.ID, data.Type, data.GetName(), err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
intelligenceDataList[index] = info
|
||||
return nil
|
||||
})
|
||||
|
||||
s.packIntelligenceData(ctx, data)
|
||||
}
|
||||
|
||||
_ = tasks.Wait()
|
||||
filterDataList := make([]*intelligence.IntelligenceData, 0)
|
||||
for _, data := range intelligenceDataList {
|
||||
if data != nil {
|
||||
filterDataList = append(filterDataList, data)
|
||||
}
|
||||
}
|
||||
|
||||
return &intelligence.GetDraftIntelligenceListResponse{
|
||||
Code: 0,
|
||||
Data: &intelligence.DraftIntelligenceListData{
|
||||
Intelligences: filterDataList,
|
||||
Total: int32(len(filterDataList)),
|
||||
HasMore: searchResp.HasMore,
|
||||
NextCursorID: searchResp.NextCursor,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SearchApplicationService) PublicFavoriteProduct(ctx context.Context, req *product_public_api.FavoriteProductRequest) (*product_public_api.FavoriteProductResponse, error) {
|
||||
isFav := !req.GetIsCancel()
|
||||
entityID := req.GetEntityID()
|
||||
typ := req.GetEntityType()
|
||||
|
||||
switch req.GetEntityType() {
|
||||
case product_common.ProductEntityType_Bot, product_common.ProductEntityType_Project:
|
||||
err := s.favoriteProject(ctx, entityID, typ, isFav)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
return nil, errorx.New(errno.ErrSearchInvalidParamCode, errorx.KV("msg", fmt.Sprintf("invalid entity type '%d'", req.GetEntityType())))
|
||||
}
|
||||
|
||||
return &product_public_api.FavoriteProductResponse{
|
||||
IsFirstFavorite: ptr.Of(false),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SearchApplicationService) favoriteProject(ctx context.Context, projectID int64, typ product_common.ProductEntityType, isFav bool) error {
|
||||
var entityType common.IntelligenceType
|
||||
if typ == product_common.ProductEntityType_Bot {
|
||||
entityType = common.IntelligenceType_Bot
|
||||
} else {
|
||||
entityType = common.IntelligenceType_Project
|
||||
}
|
||||
err := s.ProjectEventBus.PublishProject(ctx, &searchEntity.ProjectDomainEvent{
|
||||
OpType: searchEntity.Updated,
|
||||
Project: &searchEntity.ProjectDocument{
|
||||
ID: projectID,
|
||||
IsFav: ptr.Of(ternary.IFElse(isFav, 1, 0)),
|
||||
FavTimeMS: ptr.Of(time.Now().UnixMilli()),
|
||||
Type: entityType,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SearchApplicationService) PublicGetUserFavoriteList(ctx context.Context, req *product_public_api.GetUserFavoriteListV2Request) (resp *product_public_api.GetUserFavoriteListV2Response, err error) {
|
||||
userID := ctxutil.GetUIDFromCtx(ctx)
|
||||
if userID == nil {
|
||||
return nil, errorx.New(errno.ErrSearchPermissionCode, errorx.KV("msg", "session required"))
|
||||
}
|
||||
|
||||
var data *product_public_api.GetUserFavoriteListDataV2
|
||||
switch req.GetEntityType() {
|
||||
case product_common.ProductEntityType_Project, product_common.ProductEntityType_Bot, product_common.ProductEntityType_Common:
|
||||
data, err = s.searchFavProjects(ctx, *userID, req)
|
||||
default:
|
||||
return nil, errorx.New(errno.ErrSearchInvalidParamCode, errorx.KV("msg", fmt.Sprintf("invalid entity type '%d'", req.GetEntityType())))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp = &product_public_api.GetUserFavoriteListV2Response{
|
||||
Data: data,
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *SearchApplicationService) searchFavProjects(ctx context.Context, userID int64, req *product_public_api.GetUserFavoriteListV2Request) (*product_public_api.GetUserFavoriteListDataV2, error) {
|
||||
var types []common.IntelligenceType
|
||||
if req.GetEntityType() == product_common.ProductEntityType_Common {
|
||||
types = []common.IntelligenceType{common.IntelligenceType_Bot, common.IntelligenceType_Project}
|
||||
} else if req.GetEntityType() == product_common.ProductEntityType_Bot {
|
||||
types = []common.IntelligenceType{common.IntelligenceType_Bot}
|
||||
} else {
|
||||
types = []common.IntelligenceType{common.IntelligenceType_Project}
|
||||
}
|
||||
|
||||
res, err := SearchSVC.DomainSVC.SearchProjects(ctx, &searchEntity.SearchProjectsRequest{
|
||||
OwnerID: userID,
|
||||
Types: types,
|
||||
IsFav: true,
|
||||
OrderFiledName: search2.FieldOfFavTime,
|
||||
OrderAsc: false,
|
||||
Limit: req.PageSize,
|
||||
Cursor: req.GetCursorID(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(res.Data) == 0 {
|
||||
return &product_public_api.GetUserFavoriteListDataV2{
|
||||
FavoriteEntities: []*product_common.FavoriteEntity{},
|
||||
CursorID: res.NextCursor,
|
||||
HasMore: res.HasMore,
|
||||
}, nil
|
||||
}
|
||||
|
||||
favEntities := make([]*product_common.FavoriteEntity, 0, len(res.Data))
|
||||
for _, r := range res.Data {
|
||||
favEntity, err := s.projectResourceToProductInfo(ctx, userID, r)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "[pluginResourceToProductInfo] failed to get project info, id=%v, type=%d, err=%v",
|
||||
r.ID, r.Type, err)
|
||||
continue
|
||||
}
|
||||
favEntities = append(favEntities, favEntity)
|
||||
}
|
||||
|
||||
data := &product_public_api.GetUserFavoriteListDataV2{
|
||||
FavoriteEntities: favEntities,
|
||||
CursorID: res.NextCursor,
|
||||
HasMore: res.HasMore,
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (s *SearchApplicationService) projectResourceToProductInfo(ctx context.Context, userID int64, doc *searchEntity.ProjectDocument) (favEntity *product_common.FavoriteEntity, err error) {
|
||||
typ := func() product_common.ProductEntityType {
|
||||
if doc.Type == common.IntelligenceType_Bot {
|
||||
return product_common.ProductEntityType_Bot
|
||||
}
|
||||
return product_common.ProductEntityType_Project
|
||||
}()
|
||||
|
||||
packer, err := NewPackProject(userID, doc.ID, doc.Type, s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pi, err := packer.GetProjectInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ui := packer.GetUserInfo(ctx, userID)
|
||||
|
||||
var userInfo *product_common.UserInfo
|
||||
if ui != nil {
|
||||
userInfo = &product_common.UserInfo{
|
||||
UserID: ui.UserID,
|
||||
UserName: ui.UserUniqueName,
|
||||
Name: ui.Nickname,
|
||||
AvatarURL: ui.AvatarURL,
|
||||
FollowType: ptr.Of(marketplace_common.FollowType_Unknown),
|
||||
}
|
||||
}
|
||||
|
||||
e := &product_common.FavoriteEntity{
|
||||
EntityID: doc.ID,
|
||||
EntityType: typ,
|
||||
Name: doc.GetName(),
|
||||
IconURL: pi.iconURI,
|
||||
Description: pi.desc,
|
||||
SpaceID: doc.GetSpaceID(),
|
||||
HasSpacePermission: true,
|
||||
FavoriteAt: doc.GetFavTime(),
|
||||
UserInfo: userInfo,
|
||||
}
|
||||
|
||||
return e, nil
|
||||
}
|
||||
|
||||
func (s *SearchApplicationService) GetUserRecentlyEditIntelligence(ctx context.Context, req intelligence.GetUserRecentlyEditIntelligenceRequest) (
|
||||
resp *intelligence.GetUserRecentlyEditIntelligenceResponse, err error,
|
||||
) {
|
||||
userID := ctxutil.GetUIDFromCtx(ctx)
|
||||
if userID == nil {
|
||||
return nil, errorx.New(errno.ErrSearchPermissionCode, errorx.KV("msg", "session required"))
|
||||
}
|
||||
|
||||
res, err := SearchSVC.DomainSVC.SearchProjects(ctx, &searchEntity.SearchProjectsRequest{
|
||||
OwnerID: *userID,
|
||||
Types: req.Types,
|
||||
IsRecentlyOpen: true,
|
||||
OrderFiledName: search2.FieldOfRecentlyOpenTime,
|
||||
OrderAsc: false,
|
||||
Limit: req.Size,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
intelligenceDataList := make([]*intelligence.IntelligenceData, 0, len(res.Data))
|
||||
for idx := range res.Data {
|
||||
data := res.Data[idx]
|
||||
info, err := s.packIntelligenceData(ctx, data)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "[packIntelligenceData] failed id %v, type %d, name %s, err: %v", data.ID, data.Type, data.GetName(), err)
|
||||
continue
|
||||
}
|
||||
intelligenceDataList = append(intelligenceDataList, info)
|
||||
}
|
||||
|
||||
resp = &intelligence.GetUserRecentlyEditIntelligenceResponse{
|
||||
Data: &intelligence.GetUserRecentlyEditIntelligenceData{
|
||||
IntelligenceInfoList: intelligenceDataList,
|
||||
},
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *SearchApplicationService) packIntelligenceData(ctx context.Context, doc *searchEntity.ProjectDocument) (*intelligence.IntelligenceData, error) {
|
||||
intelligenceData := &intelligence.IntelligenceData{
|
||||
Type: doc.Type,
|
||||
BasicInfo: &common.IntelligenceBasicInfo{
|
||||
ID: doc.ID,
|
||||
Name: doc.GetName(),
|
||||
SpaceID: doc.GetSpaceID(),
|
||||
OwnerID: doc.GetOwnerID(),
|
||||
Status: doc.Status,
|
||||
CreateTime: doc.GetCreateTime() / 1000,
|
||||
UpdateTime: doc.GetUpdateTime() / 1000,
|
||||
PublishTime: doc.GetPublishTime() / 1000,
|
||||
},
|
||||
}
|
||||
|
||||
uid := ctxutil.MustGetUIDFromCtx(ctx)
|
||||
|
||||
packer, err := NewPackProject(uid, doc.ID, doc.Type, s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
projInfo, err := packer.GetProjectInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "GetProjectInfo failed, id: %v, type: %v", doc.ID, doc.Type)
|
||||
}
|
||||
|
||||
intelligenceData.BasicInfo.Description = projInfo.desc
|
||||
intelligenceData.BasicInfo.IconURI = projInfo.iconURI
|
||||
intelligenceData.BasicInfo.IconURL = s.getProjectIconURL(ctx, projInfo.iconURI, doc.Type)
|
||||
intelligenceData.PermissionInfo = packer.GetPermissionInfo()
|
||||
|
||||
publishedInf := packer.GetPublishedInfo(ctx)
|
||||
if publishedInf != nil {
|
||||
intelligenceData.PublishInfo = packer.GetPublishedInfo(ctx)
|
||||
} else {
|
||||
intelligenceData.PublishInfo = &intelligence.IntelligencePublishInfo{
|
||||
HasPublished: false,
|
||||
}
|
||||
}
|
||||
|
||||
intelligenceData.OwnerInfo = packer.GetUserInfo(ctx, doc.GetOwnerID())
|
||||
intelligenceData.LatestAuditInfo = &common.AuditInfo{}
|
||||
intelligenceData.FavoriteInfo = s.buildProjectFavoriteInfo(doc)
|
||||
intelligenceData.OtherInfo = s.buildProjectOtherInfo(doc)
|
||||
|
||||
return intelligenceData, nil
|
||||
}
|
||||
|
||||
func (s *SearchApplicationService) buildProjectFavoriteInfo(doc *searchEntity.ProjectDocument) *intelligence.FavoriteInfo {
|
||||
isFav := doc.GetIsFav()
|
||||
favTime := doc.GetFavTime()
|
||||
|
||||
return &intelligence.FavoriteInfo{
|
||||
IsFav: isFav,
|
||||
FavTime: conv.Int64ToStr(favTime / 1000),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SearchApplicationService) buildProjectOtherInfo(doc *searchEntity.ProjectDocument) *intelligence.OtherInfo {
|
||||
otherInfo := &intelligence.OtherInfo{
|
||||
BotMode: intelligence.BotMode_SingleMode,
|
||||
RecentlyOpenTime: conv.Int64ToStr(doc.GetRecentlyOpenTime() / 1000),
|
||||
}
|
||||
if doc.Type == common.IntelligenceType_Project {
|
||||
otherInfo.BotMode = intelligence.BotMode_WorkflowMode
|
||||
}
|
||||
|
||||
return otherInfo
|
||||
}
|
||||
|
||||
func searchRequestTo2Do(userID int64, req *intelligence.GetDraftIntelligenceListRequest) *searchEntity.SearchProjectsRequest {
|
||||
orderBy := func() string {
|
||||
switch req.GetOrderBy() {
|
||||
case intelligence.OrderBy_PublishTime:
|
||||
return search2.FieldOfPublishTime
|
||||
case intelligence.OrderBy_UpdateTime:
|
||||
return search2.FieldOfUpdateTime
|
||||
case intelligence.OrderBy_CreateTime:
|
||||
return search2.FieldOfCreateTime
|
||||
default:
|
||||
return search2.FieldOfUpdateTime
|
||||
}
|
||||
}()
|
||||
|
||||
searchReq := &searchEntity.SearchProjectsRequest{
|
||||
SpaceID: req.GetSpaceID(),
|
||||
Name: req.GetName(),
|
||||
OwnerID: 0,
|
||||
Limit: req.GetSize(),
|
||||
Cursor: req.GetCursorID(),
|
||||
OrderFiledName: orderBy,
|
||||
OrderAsc: false,
|
||||
Types: req.GetTypes(),
|
||||
Status: req.GetStatus(),
|
||||
IsFav: req.GetIsFav(),
|
||||
IsRecentlyOpen: req.GetRecentlyOpen(),
|
||||
IsPublished: req.GetHasPublished(),
|
||||
}
|
||||
|
||||
if req.GetSearchScope() == intelligence.SearchScope_CreateByMe {
|
||||
searchReq.OwnerID = userID
|
||||
}
|
||||
|
||||
return searchReq
|
||||
}
|
||||
|
||||
func (s *SearchApplicationService) getProjectDefaultIconURL(ctx context.Context, tp common.IntelligenceType) string {
|
||||
iconURL, ok := projectType2iconURI[tp]
|
||||
if !ok {
|
||||
logs.CtxWarnf(ctx, "[getProjectDefaultIconURL] don't have type: %d default icon", tp)
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
return s.getURL(ctx, iconURL)
|
||||
}
|
||||
|
||||
func (s *SearchApplicationService) getProjectIconURL(ctx context.Context, uri string, tp common.IntelligenceType) string {
|
||||
if uri == "" {
|
||||
return s.getProjectDefaultIconURL(ctx, tp)
|
||||
}
|
||||
|
||||
url := s.getURL(ctx, uri)
|
||||
if url != "" {
|
||||
return url
|
||||
}
|
||||
|
||||
return s.getProjectDefaultIconURL(ctx, tp)
|
||||
}
|
||||
325
backend/application/search/resource_pack.go
Normal file
325
backend/application/search/resource_pack.go
Normal file
@@ -0,0 +1,325 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/resource/common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/table"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
|
||||
dbservice "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
var defaultAction = []*common.ResourceAction{
|
||||
{
|
||||
Key: common.ActionKey_Edit,
|
||||
Enable: true,
|
||||
},
|
||||
{
|
||||
Key: common.ActionKey_Delete,
|
||||
Enable: true,
|
||||
},
|
||||
{
|
||||
Key: common.ActionKey_Copy,
|
||||
Enable: true,
|
||||
},
|
||||
}
|
||||
|
||||
type ResourcePacker interface {
|
||||
GetDataInfo(ctx context.Context) (*dataInfo, error)
|
||||
GetActions(ctx context.Context) []*common.ResourceAction
|
||||
GetProjectDefaultActions(ctx context.Context) []*common.ProjectResourceAction
|
||||
}
|
||||
|
||||
func NewResourcePacker(resID int64, t common.ResType, appContext *ServiceComponents) (ResourcePacker, error) {
|
||||
base := resourceBasePacker{appContext: appContext, resID: resID}
|
||||
|
||||
switch t {
|
||||
case common.ResType_Plugin:
|
||||
return &pluginPacker{resourceBasePacker: base}, nil
|
||||
case common.ResType_Workflow:
|
||||
return &workflowPacker{resourceBasePacker: base}, nil
|
||||
case common.ResType_Knowledge:
|
||||
return &knowledgePacker{resourceBasePacker: base}, nil
|
||||
case common.ResType_Prompt:
|
||||
return &promptPacker{resourceBasePacker: base}, nil
|
||||
case common.ResType_Database:
|
||||
return &databasePacker{resourceBasePacker: base}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported resource type: %s , resID: %d", t, resID)
|
||||
}
|
||||
|
||||
type resourceBasePacker struct {
|
||||
resID int64
|
||||
appContext *ServiceComponents
|
||||
}
|
||||
|
||||
type dataInfo struct {
|
||||
iconURI *string
|
||||
iconURL string
|
||||
desc *string
|
||||
status *int32
|
||||
}
|
||||
|
||||
func (b *resourceBasePacker) GetActions(ctx context.Context) []*common.ResourceAction {
|
||||
return defaultAction
|
||||
}
|
||||
|
||||
func (b *resourceBasePacker) GetProjectDefaultActions(ctx context.Context) []*common.ProjectResourceAction {
|
||||
return []*common.ProjectResourceAction{}
|
||||
}
|
||||
|
||||
type pluginPacker struct {
|
||||
resourceBasePacker
|
||||
}
|
||||
|
||||
func (p *pluginPacker) GetDataInfo(ctx context.Context) (*dataInfo, error) {
|
||||
plugin, err := p.appContext.PluginDomainSVC.GetDraftPlugin(ctx, p.resID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iconURL, err := p.appContext.TOS.GetObjectUrl(ctx, plugin.GetIconURI())
|
||||
if err != nil {
|
||||
logs.CtxWarnf(ctx, "get icon url failed with '%s', err=%v", plugin.GetIconURI(), err)
|
||||
}
|
||||
|
||||
return &dataInfo{
|
||||
iconURI: ptr.Of(plugin.GetIconURI()),
|
||||
iconURL: iconURL,
|
||||
desc: ptr.Of(plugin.GetDesc()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *pluginPacker) GetProjectDefaultActions(ctx context.Context) []*common.ProjectResourceAction {
|
||||
return []*common.ProjectResourceAction{
|
||||
{
|
||||
Key: common.ProjectResourceActionKey_Rename,
|
||||
Enable: true,
|
||||
},
|
||||
{
|
||||
Key: common.ProjectResourceActionKey_Copy,
|
||||
Enable: true,
|
||||
},
|
||||
{
|
||||
Key: common.ProjectResourceActionKey_Delete,
|
||||
Enable: true,
|
||||
},
|
||||
{
|
||||
Key: common.ProjectResourceActionKey_CopyToLibrary,
|
||||
Enable: true,
|
||||
},
|
||||
{
|
||||
Key: common.ProjectResourceActionKey_MoveToLibrary,
|
||||
Enable: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type workflowPacker struct {
|
||||
resourceBasePacker
|
||||
}
|
||||
|
||||
func (w *workflowPacker) GetDataInfo(ctx context.Context) (*dataInfo, error) {
|
||||
info, err := w.appContext.WorkflowDomainSVC.Get(ctx, &vo.GetPolicy{
|
||||
ID: w.resID,
|
||||
MetaOnly: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &dataInfo{
|
||||
iconURI: &info.IconURI,
|
||||
iconURL: info.IconURL,
|
||||
desc: &info.Desc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *workflowPacker) GetProjectDefaultActions(ctx context.Context) []*common.ProjectResourceAction {
|
||||
return []*common.ProjectResourceAction{
|
||||
{
|
||||
Key: common.ProjectResourceActionKey_Rename,
|
||||
Enable: true,
|
||||
},
|
||||
{
|
||||
Key: common.ProjectResourceActionKey_Copy,
|
||||
Enable: true,
|
||||
},
|
||||
{
|
||||
Key: common.ProjectResourceActionKey_CopyToLibrary,
|
||||
Enable: true,
|
||||
},
|
||||
{
|
||||
Key: common.ProjectResourceActionKey_MoveToLibrary,
|
||||
Enable: true,
|
||||
},
|
||||
{
|
||||
Key: common.ProjectResourceActionKey_Delete,
|
||||
Enable: true,
|
||||
},
|
||||
{
|
||||
Key: common.ProjectResourceActionKey_UpdateDesc,
|
||||
Enable: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type knowledgePacker struct {
|
||||
resourceBasePacker
|
||||
}
|
||||
|
||||
func (k *knowledgePacker) GetDataInfo(ctx context.Context) (*dataInfo, error) {
|
||||
res, err := k.appContext.KnowledgeDomainSVC.GetKnowledgeByID(ctx, &service.GetKnowledgeByIDRequest{
|
||||
KnowledgeID: k.resID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
kn := res.Knowledge
|
||||
|
||||
return &dataInfo{
|
||||
iconURI: ptr.Of(kn.IconURI),
|
||||
iconURL: kn.IconURL,
|
||||
desc: ptr.Of(kn.Description),
|
||||
status: ptr.Of(int32(kn.Status)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (k *knowledgePacker) GetActions(ctx context.Context) []*common.ResourceAction {
|
||||
return []*common.ResourceAction{
|
||||
{
|
||||
Key: common.ActionKey_Delete,
|
||||
Enable: true,
|
||||
},
|
||||
{
|
||||
Key: common.ActionKey_EnableSwitch,
|
||||
Enable: true,
|
||||
},
|
||||
{
|
||||
Key: common.ActionKey_Edit,
|
||||
Enable: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
func (k *knowledgePacker) GetProjectDefaultActions(ctx context.Context) []*common.ProjectResourceAction {
|
||||
return []*common.ProjectResourceAction{
|
||||
{
|
||||
Key: common.ProjectResourceActionKey_Rename,
|
||||
Enable: true,
|
||||
},
|
||||
{
|
||||
Key: common.ProjectResourceActionKey_Copy,
|
||||
Enable: true,
|
||||
},
|
||||
{
|
||||
Key: common.ProjectResourceActionKey_CopyToLibrary,
|
||||
Enable: true,
|
||||
},
|
||||
{
|
||||
Key: common.ProjectResourceActionKey_MoveToLibrary,
|
||||
Enable: true,
|
||||
},
|
||||
{
|
||||
Key: common.ProjectResourceActionKey_Delete,
|
||||
Enable: true,
|
||||
},
|
||||
{
|
||||
Key: common.ProjectResourceActionKey_Disable,
|
||||
Enable: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type promptPacker struct {
|
||||
resourceBasePacker
|
||||
}
|
||||
|
||||
func (p *promptPacker) GetDataInfo(ctx context.Context) (*dataInfo, error) {
|
||||
pInfo, err := p.appContext.PromptDomainSVC.GetPromptResource(ctx, p.resID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dataInfo{
|
||||
iconURI: nil, // prompt don't have custom icon
|
||||
iconURL: "",
|
||||
desc: &pInfo.Description,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type databasePacker struct {
|
||||
resourceBasePacker
|
||||
}
|
||||
|
||||
func (d *databasePacker) GetDataInfo(ctx context.Context) (*dataInfo, error) {
|
||||
listResp, err := d.appContext.DatabaseDomainSVC.MGetDatabase(ctx, &dbservice.MGetDatabaseRequest{Basics: []*database.DatabaseBasic{
|
||||
{
|
||||
ID: d.resID,
|
||||
TableType: table.TableType_OnlineTable,
|
||||
},
|
||||
}})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(listResp.Databases) == 0 {
|
||||
return nil, fmt.Errorf("online database not found, id: %d", d.resID)
|
||||
}
|
||||
|
||||
return &dataInfo{
|
||||
iconURI: ptr.Of(listResp.Databases[0].IconURI),
|
||||
iconURL: listResp.Databases[0].IconURL,
|
||||
desc: ptr.Of(listResp.Databases[0].TableDesc),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *databasePacker) GetActions(ctx context.Context) []*common.ResourceAction {
|
||||
return []*common.ResourceAction{
|
||||
{
|
||||
Key: common.ActionKey_Delete,
|
||||
Enable: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (d *databasePacker) GetProjectDefaultActions(ctx context.Context) []*common.ProjectResourceAction {
|
||||
return []*common.ProjectResourceAction{
|
||||
{
|
||||
Key: common.ProjectResourceActionKey_Copy,
|
||||
Enable: true,
|
||||
},
|
||||
{
|
||||
Key: common.ProjectResourceActionKey_CopyToLibrary,
|
||||
Enable: true,
|
||||
},
|
||||
{
|
||||
Key: common.ProjectResourceActionKey_MoveToLibrary,
|
||||
Enable: true,
|
||||
},
|
||||
{
|
||||
Key: common.ProjectResourceActionKey_Delete,
|
||||
Enable: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
392
backend/application/search/resource_search.go
Normal file
392
backend/application/search/resource_search.go
Normal file
@@ -0,0 +1,392 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"slices"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
knowledgeModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/resource"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/resource/common"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/search/entity"
|
||||
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/taskgroup"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
var SearchSVC = &SearchApplicationService{}
|
||||
|
||||
type SearchApplicationService struct {
|
||||
*ServiceComponents
|
||||
DomainSVC search.Search
|
||||
}
|
||||
|
||||
var resType2iconURI = map[common.ResType]string{
|
||||
common.ResType_Plugin: consts.DefaultPluginIcon,
|
||||
common.ResType_Workflow: consts.DefaultWorkflowIcon,
|
||||
common.ResType_Knowledge: consts.DefaultDatasetIcon,
|
||||
common.ResType_Prompt: consts.DefaultPromptIcon,
|
||||
common.ResType_Database: consts.DefaultDatabaseIcon,
|
||||
// ResType_UI: consts.DefaultWorkflowIcon,
|
||||
// ResType_Voice: consts.DefaultPluginIcon,
|
||||
// ResType_Imageflow: consts.DefaultPluginIcon,
|
||||
}
|
||||
|
||||
func (s *SearchApplicationService) LibraryResourceList(ctx context.Context, req *resource.LibraryResourceListRequest) (resp *resource.LibraryResourceListResponse, err error) {
|
||||
userID := ctxutil.GetUIDFromCtx(ctx)
|
||||
if userID == nil {
|
||||
return nil, errorx.New(errno.ErrSearchPermissionCode, errorx.KV("msg", "session required"))
|
||||
}
|
||||
|
||||
searchReq := &entity.SearchResourcesRequest{
|
||||
SpaceID: req.GetSpaceID(),
|
||||
OwnerID: 0,
|
||||
Name: req.GetName(),
|
||||
ResTypeFilter: req.GetResTypeFilter(),
|
||||
PublishStatusFilter: req.GetPublishStatusFilter(),
|
||||
SearchKeys: req.GetSearchKeys(),
|
||||
Cursor: req.GetCursor(),
|
||||
Limit: req.GetSize(),
|
||||
}
|
||||
|
||||
// 设置用户过滤
|
||||
if req.IsSetUserFilter() && req.GetUserFilter() > 0 {
|
||||
searchReq.OwnerID = ptr.From(userID)
|
||||
}
|
||||
|
||||
searchResp, err := s.DomainSVC.SearchResources(ctx, searchReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lock := sync.Mutex{}
|
||||
tasks := taskgroup.NewUninterruptibleTaskGroup(ctx, 10)
|
||||
resources := make([]*common.ResourceInfo, len(searchResp.Data))
|
||||
for idx := range searchResp.Data {
|
||||
v := searchResp.Data[idx]
|
||||
index := idx
|
||||
tasks.Go(func() error {
|
||||
ri, err := s.packResource(ctx, v)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "[LibraryResourceList] packResource failed, will ignore resID: %d, Name : %s, resType: %d, err: %v",
|
||||
v.ResID, v.GetName(), v.ResType, err)
|
||||
return err
|
||||
}
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
resources[index] = ri
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
_ = tasks.Wait()
|
||||
filterResource := make([]*common.ResourceInfo, 0)
|
||||
for _, res := range resources {
|
||||
if res == nil {
|
||||
continue
|
||||
}
|
||||
filterResource = append(filterResource, res)
|
||||
}
|
||||
|
||||
return &resource.LibraryResourceListResponse{
|
||||
Code: 0,
|
||||
ResourceList: filterResource,
|
||||
Cursor: ptr.Of(searchResp.NextCursor),
|
||||
HasMore: searchResp.HasMore,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SearchApplicationService) getResourceDefaultIconURL(ctx context.Context, tp common.ResType) string {
|
||||
iconURL, ok := resType2iconURI[tp]
|
||||
if !ok {
|
||||
logs.CtxWarnf(ctx, "[getDefaultIconURL] don't have type: %d default icon", tp)
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
return s.getURL(ctx, iconURL)
|
||||
}
|
||||
|
||||
func (s *SearchApplicationService) getURL(ctx context.Context, uri string) string {
|
||||
url, err := s.TOS.GetObjectUrl(ctx, uri)
|
||||
if err != nil {
|
||||
logs.CtxWarnf(ctx, "[getDefaultIconURLWitURI] GetObjectUrl failed, uri: %s, err: %v", uri, err)
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
return url
|
||||
}
|
||||
|
||||
func (s *SearchApplicationService) getResourceIconURL(ctx context.Context, uri *string, tp common.ResType) string {
|
||||
if uri == nil || *uri == "" {
|
||||
return s.getResourceDefaultIconURL(ctx, tp)
|
||||
}
|
||||
|
||||
url := s.getURL(ctx, *uri)
|
||||
if url != "" {
|
||||
return url
|
||||
}
|
||||
|
||||
return s.getResourceDefaultIconURL(ctx, tp)
|
||||
}
|
||||
|
||||
func (s *SearchApplicationService) packUserInfo(ctx context.Context, ri *common.ResourceInfo, ownerID int64) *common.ResourceInfo {
|
||||
u, err := s.UserDomainSVC.GetUserInfo(ctx, ownerID)
|
||||
if err != nil {
|
||||
logs.CtxWarnf(ctx, "[LibraryResourceList] GetUserInfo failed, uid: %d, resID: %d, Name : %s, err: %v",
|
||||
ownerID, ri.ResID, ri.GetName(), err)
|
||||
} else {
|
||||
ri.CreatorName = ptr.Of(u.Name)
|
||||
ri.CreatorAvatar = ptr.Of(u.IconURL)
|
||||
}
|
||||
|
||||
if ri.GetCreatorAvatar() == "" {
|
||||
ri.CreatorAvatar = ptr.Of(s.getURL(ctx, consts.DefaultUserIcon))
|
||||
}
|
||||
|
||||
return ri
|
||||
}
|
||||
|
||||
func (s *SearchApplicationService) packResource(ctx context.Context, doc *entity.ResourceDocument) (*common.ResourceInfo, error) {
|
||||
ri := &common.ResourceInfo{
|
||||
ResID: ptr.Of(doc.ResID),
|
||||
ResType: ptr.Of(doc.ResType),
|
||||
Name: doc.Name,
|
||||
SpaceID: doc.SpaceID,
|
||||
CreatorID: doc.OwnerID,
|
||||
ResSubType: doc.ResSubType,
|
||||
PublishStatus: doc.PublishStatus,
|
||||
EditTime: ptr.Of(doc.GetUpdateTime() / 1000),
|
||||
}
|
||||
|
||||
if doc.BizStatus != nil {
|
||||
ri.BizResStatus = ptr.Of(int32(*doc.BizStatus))
|
||||
}
|
||||
|
||||
packer, err := NewResourcePacker(doc.ResID, doc.ResType, s.ServiceComponents)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "NewResourcePacker failed")
|
||||
}
|
||||
|
||||
ri = s.packUserInfo(ctx, ri, doc.GetOwnerID())
|
||||
ri.Actions = packer.GetActions(ctx)
|
||||
|
||||
data, err := packer.GetDataInfo(ctx)
|
||||
if err != nil {
|
||||
logs.CtxWarnf(ctx, "[packResource] GetDataInfo failed, resID: %d, Name : %s, resType: %d, err: %v",
|
||||
doc.ResID, doc.GetName(), doc.ResType, err)
|
||||
|
||||
ri.Icon = ptr.Of(s.getResourceDefaultIconURL(ctx, doc.ResType))
|
||||
|
||||
return ri, nil // Warn : weak dependency data
|
||||
}
|
||||
ri.BizResStatus = data.status
|
||||
ri.Desc = data.desc
|
||||
ri.Icon = ternary.IFElse(len(data.iconURL) > 0,
|
||||
&data.iconURL, ptr.Of(s.getResourceIconURL(ctx, data.iconURI, doc.ResType)))
|
||||
ri.BizExtend = map[string]string{
|
||||
"url": ptr.From(ri.Icon),
|
||||
}
|
||||
return ri, nil
|
||||
}
|
||||
|
||||
func (s *SearchApplicationService) ProjectResourceList(ctx context.Context, req *resource.ProjectResourceListRequest) (resp *resource.ProjectResourceListResponse, err error) {
|
||||
resources, err := s.getAPPAllResources(ctx, req.GetProjectID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resourceGroups, err := s.packAPPResources(ctx, resources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resourceGroups = s.sortAPPResources(resourceGroups)
|
||||
|
||||
return &resource.ProjectResourceListResponse{
|
||||
ResourceGroups: resourceGroups,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SearchApplicationService) getAPPAllResources(ctx context.Context, appID int64) ([]*entity.ResourceDocument, error) {
|
||||
cursor := ""
|
||||
resources := make([]*entity.ResourceDocument, 0, 100)
|
||||
|
||||
for {
|
||||
res, err := s.DomainSVC.SearchResources(ctx, &entity.SearchResourcesRequest{
|
||||
APPID: appID,
|
||||
Cursor: cursor,
|
||||
Limit: 100,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resources = append(resources, res.Data...)
|
||||
|
||||
hasMore := res.HasMore
|
||||
cursor = res.NextCursor
|
||||
|
||||
if !hasMore {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return resources, nil
|
||||
}
|
||||
|
||||
func (s *SearchApplicationService) packAPPResources(ctx context.Context, resources []*entity.ResourceDocument) ([]*common.ProjectResourceGroup, error) {
|
||||
workflowGroup := &common.ProjectResourceGroup{
|
||||
GroupType: common.ProjectResourceGroupType_Workflow,
|
||||
ResourceList: []*common.ProjectResourceInfo{},
|
||||
}
|
||||
dataGroup := &common.ProjectResourceGroup{
|
||||
GroupType: common.ProjectResourceGroupType_Data,
|
||||
ResourceList: []*common.ProjectResourceInfo{},
|
||||
}
|
||||
pluginGroup := &common.ProjectResourceGroup{
|
||||
GroupType: common.ProjectResourceGroupType_Plugin,
|
||||
ResourceList: []*common.ProjectResourceInfo{},
|
||||
}
|
||||
|
||||
lock := sync.Mutex{}
|
||||
tasks := taskgroup.NewUninterruptibleTaskGroup(ctx, 10)
|
||||
for idx := range resources {
|
||||
v := resources[idx]
|
||||
|
||||
tasks.Go(func() error {
|
||||
ri, err := s.packProjectResource(ctx, v)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "packAPPResources failed, will ignore resID: %d, Name : %s, resType: %d, err: %v",
|
||||
v.ResID, v.GetName(), v.ResType, err)
|
||||
return err
|
||||
}
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
switch v.ResType {
|
||||
case common.ResType_Workflow:
|
||||
workflowGroup.ResourceList = append(workflowGroup.ResourceList, ri)
|
||||
case common.ResType_Plugin:
|
||||
pluginGroup.ResourceList = append(pluginGroup.ResourceList, ri)
|
||||
case common.ResType_Database, common.ResType_Knowledge:
|
||||
dataGroup.ResourceList = append(dataGroup.GetResourceList(), ri)
|
||||
default:
|
||||
logs.CtxWarnf(ctx, "unsupported resType: %d", v.ResType)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
_ = tasks.Wait()
|
||||
|
||||
resourceGroups := []*common.ProjectResourceGroup{
|
||||
workflowGroup,
|
||||
pluginGroup,
|
||||
dataGroup,
|
||||
}
|
||||
|
||||
return resourceGroups, nil
|
||||
}
|
||||
|
||||
func (s *SearchApplicationService) packProjectResource(ctx context.Context, resource *entity.ResourceDocument) (*common.ProjectResourceInfo, error) {
|
||||
packer, err := NewResourcePacker(resource.ResID, resource.ResType, s.ServiceComponents)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := &common.ProjectResourceInfo{
|
||||
ResID: resource.ResID,
|
||||
ResType: resource.ResType,
|
||||
ResSubType: resource.ResSubType,
|
||||
Name: resource.GetName(),
|
||||
Actions: packer.GetProjectDefaultActions(ctx),
|
||||
}
|
||||
|
||||
if resource.ResType == common.ResType_Knowledge {
|
||||
info.BizExtend = map[string]string{
|
||||
"format_type": strconv.FormatInt(int64(resource.GetResSubType()), 10),
|
||||
}
|
||||
di, err := packer.GetDataInfo(ctx)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "GetDataInfo failed, resID=%d, resType=%d, err=%v",
|
||||
resource.ResID, resource.ResType, err)
|
||||
} else {
|
||||
info.BizResStatus = ptr.Of(*di.status)
|
||||
if *di.status == int32(knowledgeModel.KnowledgeStatusDisable) {
|
||||
actions := slices.Clone(info.Actions)
|
||||
for _, a := range actions {
|
||||
if a.Key == common.ProjectResourceActionKey_Disable {
|
||||
a.Key = common.ProjectResourceActionKey_Enable
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if resource.ResType == common.ResType_Plugin {
|
||||
err = s.PluginDomainSVC.CheckPluginToolsDebugStatus(ctx, resource.ResID)
|
||||
if err != nil {
|
||||
var e errorx.StatusError
|
||||
if !errors.As(err, &e) {
|
||||
logs.CtxErrorf(ctx, "CheckPluginToolsDebugStatus failed, resID=%d, resType=%d, err=%v",
|
||||
resource.ResID, resource.ResType, err)
|
||||
} else {
|
||||
actions := slices.Clone(info.Actions)
|
||||
for _, a := range actions {
|
||||
if a.Key == common.ProjectResourceActionKey_MoveToLibrary ||
|
||||
a.Key == common.ProjectResourceActionKey_CopyToLibrary {
|
||||
a.Enable = false
|
||||
a.Hint = ptr.Of(e.Msg())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func (s *ServiceComponents) sortAPPResources(resourceGroups []*common.ProjectResourceGroup) []*common.ProjectResourceGroup {
|
||||
for _, g := range resourceGroups {
|
||||
slices.SortFunc(g.ResourceList, func(a, b *common.ProjectResourceInfo) int {
|
||||
if a.Name == b.Name {
|
||||
return 0
|
||||
}
|
||||
if a.Name < b.Name {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
})
|
||||
}
|
||||
return resourceGroups
|
||||
}
|
||||
40
backend/application/shortcutcmd/init.go
Normal file
40
backend/application/shortcutcmd/init.go
Normal file
@@ -0,0 +1,40 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package shortcutcmd
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/repository"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
)
|
||||
|
||||
var ShortcutCmdSVC *ShortcutCmdApplicationService
|
||||
|
||||
func InitService(db *gorm.DB, idGenSVC idgen.IDGenerator) *ShortcutCmdApplicationService {
|
||||
|
||||
components := &service.Components{
|
||||
ShortCutCmdRepo: repository.NewShortCutCmdRepo(db, idGenSVC),
|
||||
}
|
||||
shortcutCmdDomainSVC := service.NewShortcutCommandService(components)
|
||||
|
||||
ShortcutCmdSVC = &ShortcutCmdApplicationService{
|
||||
ShortCutDomainSVC: shortcutCmdDomainSVC,
|
||||
}
|
||||
return ShortcutCmdSVC
|
||||
}
|
||||
119
backend/application/shortcutcmd/shortcut_cmd.go
Normal file
119
backend/application/shortcutcmd/shortcut_cmd.go
Normal file
@@ -0,0 +1,119 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package shortcutcmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/playground"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/service"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
)
|
||||
|
||||
type ShortcutCmdApplicationService struct {
|
||||
ShortCutDomainSVC service.ShortcutCmd
|
||||
}
|
||||
|
||||
func (s *ShortcutCmdApplicationService) Handler(ctx context.Context, req *playground.CreateUpdateShortcutCommandRequest) (*playground.ShortcutCommand, error) {
|
||||
|
||||
cr, buildErr := s.buildReq(ctx, req)
|
||||
if buildErr != nil {
|
||||
return nil, buildErr
|
||||
}
|
||||
var err error
|
||||
var cmdDO *entity.ShortcutCmd
|
||||
if cr.CommandID > 0 {
|
||||
cmdDO, err = s.ShortCutDomainSVC.UpdateCMD(ctx, cr)
|
||||
} else {
|
||||
cmdDO, err = s.ShortCutDomainSVC.CreateCMD(ctx, cr)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cmdDO == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.buildDo2Vo(ctx, cmdDO), nil
|
||||
}
|
||||
func (s *ShortcutCmdApplicationService) buildReq(ctx context.Context, req *playground.CreateUpdateShortcutCommandRequest) (*entity.ShortcutCmd, error) {
|
||||
|
||||
uid := ctxutil.MustGetUIDFromCtx(ctx)
|
||||
|
||||
var workflowID int64
|
||||
var pluginID int64
|
||||
var err error
|
||||
if req.GetShortcuts().GetWorkFlowID() != "" {
|
||||
workflowID, err = strconv.ParseInt(req.GetShortcuts().GetWorkFlowID(), 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if req.GetShortcuts().GetPluginID() != "" {
|
||||
pluginID, err = strconv.ParseInt(req.GetShortcuts().GetPluginID(), 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &entity.ShortcutCmd{
|
||||
ObjectID: req.GetObjectID(),
|
||||
CommandID: req.GetShortcuts().CommandID,
|
||||
CommandName: req.GetShortcuts().CommandName,
|
||||
ShortcutCommand: req.GetShortcuts().ShortcutCommand,
|
||||
Description: req.GetShortcuts().Description,
|
||||
SendType: int32(req.GetShortcuts().SendType),
|
||||
ToolType: int32(req.GetShortcuts().ToolType),
|
||||
WorkFlowID: workflowID,
|
||||
PluginID: pluginID,
|
||||
Components: req.GetShortcuts().ComponentsList,
|
||||
CardSchema: req.GetShortcuts().CardSchema,
|
||||
ToolInfo: req.GetShortcuts().ToolInfo,
|
||||
CreatorID: uid,
|
||||
PluginToolID: req.GetShortcuts().PluginAPIID,
|
||||
PluginToolName: req.GetShortcuts().PluginAPIName,
|
||||
TemplateQuery: req.GetShortcuts().TemplateQuery,
|
||||
ShortcutIcon: req.GetShortcuts().ShortcutIcon,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ShortcutCmdApplicationService) buildDo2Vo(ctx context.Context, do *entity.ShortcutCmd) *playground.ShortcutCommand {
|
||||
|
||||
return &playground.ShortcutCommand{
|
||||
ObjectID: do.ObjectID,
|
||||
CommandID: do.CommandID,
|
||||
CommandName: do.CommandName,
|
||||
ShortcutCommand: do.ShortcutCommand,
|
||||
Description: do.Description,
|
||||
SendType: playground.SendType(do.SendType),
|
||||
ToolType: playground.ToolType(do.ToolType),
|
||||
WorkFlowID: conv.Int64ToStr(do.WorkFlowID),
|
||||
PluginID: conv.Int64ToStr(do.PluginID),
|
||||
ComponentsList: do.Components,
|
||||
CardSchema: do.CardSchema,
|
||||
ToolInfo: do.ToolInfo,
|
||||
PluginAPIID: do.PluginToolID,
|
||||
PluginAPIName: do.PluginToolName,
|
||||
TemplateQuery: do.TemplateQuery,
|
||||
ShortcutIcon: do.ShortcutIcon,
|
||||
}
|
||||
}
|
||||
204
backend/application/singleagent/create.go
Normal file
204
backend/application/singleagent/create.go
Normal file
@@ -0,0 +1,204 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package singleagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
modelmgrEntity "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
||||
intelligence "github.com/coze-dev/coze-studio/backend/api/model/intelligence/common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/modelmgr"
|
||||
searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
func (s *SingleAgentApplicationService) CreateSingleAgentDraft(ctx context.Context, req *developer_api.DraftBotCreateRequest) (*developer_api.DraftBotCreateResponse, error) {
|
||||
do, err := s.draftBotCreateRequestToSingleAgent(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userID := ctxutil.MustGetUIDFromCtx(ctx)
|
||||
agentID, err := s.DomainSVC.CreateSingleAgentDraft(ctx, userID, do)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = s.appContext.EventBus.PublishProject(ctx, &searchEntity.ProjectDomainEvent{
|
||||
OpType: searchEntity.Created,
|
||||
Project: &searchEntity.ProjectDocument{
|
||||
Status: intelligence.IntelligenceStatus_Using,
|
||||
Type: intelligence.IntelligenceType_Bot,
|
||||
ID: agentID,
|
||||
SpaceID: &req.SpaceID,
|
||||
OwnerID: &userID,
|
||||
Name: &do.Name,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &developer_api.DraftBotCreateResponse{Data: &developer_api.DraftBotCreateData{
|
||||
BotID: agentID,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) draftBotCreateRequestToSingleAgent(ctx context.Context, req *developer_api.DraftBotCreateRequest) (*entity.SingleAgent, error) {
|
||||
sa, err := s.newDefaultSingleAgent(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sa.SpaceID = req.SpaceID
|
||||
sa.Name = req.GetName()
|
||||
sa.Desc = req.GetDescription()
|
||||
sa.IconURI = req.GetIconURI()
|
||||
|
||||
return sa, nil
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) newDefaultSingleAgent(ctx context.Context) (*entity.SingleAgent, error) {
|
||||
mi, err := s.defaultModelInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
return &entity.SingleAgent{
|
||||
SingleAgent: &singleagent.SingleAgent{
|
||||
OnboardingInfo: &bot_common.OnboardingInfo{},
|
||||
ModelInfo: mi,
|
||||
Prompt: &bot_common.PromptInfo{},
|
||||
Plugin: []*bot_common.PluginInfo{},
|
||||
Knowledge: &bot_common.Knowledge{
|
||||
TopK: ptr.Of(int64(1)),
|
||||
MinScore: ptr.Of(float64(0.01)),
|
||||
SearchStrategy: ptr.Of(bot_common.SearchStrategy_SemanticSearch),
|
||||
RecallStrategy: &bot_common.RecallStrategy{
|
||||
UseNl2sql: ptr.Of(true),
|
||||
UseRerank: ptr.Of(true),
|
||||
UseRewrite: ptr.Of(true),
|
||||
},
|
||||
},
|
||||
Workflow: []*bot_common.WorkflowInfo{},
|
||||
SuggestReply: &bot_common.SuggestReplyInfo{},
|
||||
JumpConfig: &bot_common.JumpConfig{},
|
||||
Database: []*bot_common.Database{},
|
||||
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) defaultModelInfo(ctx context.Context) (*bot_common.ModelInfo, error) {
|
||||
modelResp, err := s.appContext.ModelMgrDomainSVC.ListModel(ctx, &modelmgr.ListModelRequest{
|
||||
Status: []modelmgrEntity.ModelEntityStatus{modelmgrEntity.ModelEntityStatusDefault, modelmgrEntity.ModelEntityStatusInUse},
|
||||
Limit: 1,
|
||||
Cursor: nil,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(modelResp.ModelList) == 0 {
|
||||
return nil, errorx.New(errno.ErrAgentResourceNotFound, errorx.KV("type", "model"), errorx.KV("id", "default"))
|
||||
}
|
||||
|
||||
dm := modelResp.ModelList[0]
|
||||
|
||||
var temperature *float64
|
||||
if tp, ok := dm.FindParameter(modelmgrEntity.Temperature); ok {
|
||||
t, err := tp.GetFloat(modelmgrEntity.DefaultTypeBalance)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
temperature = ptr.Of(t)
|
||||
}
|
||||
|
||||
var maxTokens *int32
|
||||
if tp, ok := dm.FindParameter(modelmgrEntity.MaxTokens); ok {
|
||||
t, err := tp.GetInt(modelmgrEntity.DefaultTypeBalance)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
maxTokens = ptr.Of(int32(t))
|
||||
} else if dm.Meta.ConnConfig.MaxTokens != nil {
|
||||
maxTokens = ptr.Of(int32(*dm.Meta.ConnConfig.MaxTokens))
|
||||
}
|
||||
|
||||
var topP *float64
|
||||
if tp, ok := dm.FindParameter(modelmgrEntity.TopP); ok {
|
||||
t, err := tp.GetFloat(modelmgrEntity.DefaultTypeBalance)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
topP = ptr.Of(t)
|
||||
}
|
||||
|
||||
var topK *int32
|
||||
if tp, ok := dm.FindParameter(modelmgrEntity.TopK); ok {
|
||||
t, err := tp.GetInt(modelmgrEntity.DefaultTypeBalance)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
topK = ptr.Of(int32(t))
|
||||
}
|
||||
|
||||
var frequencyPenalty *float64
|
||||
if tp, ok := dm.FindParameter(modelmgrEntity.FrequencyPenalty); ok {
|
||||
t, err := tp.GetFloat(modelmgrEntity.DefaultTypeBalance)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frequencyPenalty = ptr.Of(t)
|
||||
}
|
||||
|
||||
var presencePenalty *float64
|
||||
if tp, ok := dm.FindParameter(modelmgrEntity.PresencePenalty); ok {
|
||||
t, err := tp.GetFloat(modelmgrEntity.DefaultTypeBalance)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
presencePenalty = ptr.Of(t)
|
||||
}
|
||||
|
||||
return &bot_common.ModelInfo{
|
||||
ModelId: ptr.Of(dm.ID),
|
||||
Temperature: temperature,
|
||||
MaxTokens: maxTokens,
|
||||
TopP: topP,
|
||||
FrequencyPenalty: frequencyPenalty,
|
||||
PresencePenalty: presencePenalty,
|
||||
TopK: topK,
|
||||
ModelStyle: bot_common.ModelStylePtr(bot_common.ModelStyle_Balance),
|
||||
ShortMemoryPolicy: &bot_common.ShortMemoryPolicy{
|
||||
ContextMode: bot_common.ContextModePtr(bot_common.ContextMode_FunctionCall_2),
|
||||
HistoryRound: ptr.Of[int32](3),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
177
backend/application/singleagent/duplicate.go
Normal file
177
backend/application/singleagent/duplicate.go
Normal file
@@ -0,0 +1,177 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package singleagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
intelligence "github.com/coze-dev/coze-studio/backend/api/model/intelligence/common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/project_memory"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity"
|
||||
shortcutCMDEntity "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
)
|
||||
|
||||
type duplicateAgentResourceFn func(ctx context.Context, appContext *ServiceComponents, oldAgent, newAgent *entity.SingleAgent) (*entity.SingleAgent, error)
|
||||
|
||||
func (s *SingleAgentApplicationService) DuplicateDraftBot(ctx context.Context, req *developer_api.DuplicateDraftBotRequest) (*developer_api.DuplicateDraftBotResponse, error) {
|
||||
draftAgent, err := s.ValidateAgentDraftAccess(ctx, req.BotID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newAgentID, err := s.appContext.IDGen.GenID(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userID := ctxutil.MustGetUIDFromCtx(ctx)
|
||||
duplicateInfo := &entity.DuplicateInfo{
|
||||
NewAgentID: newAgentID,
|
||||
SpaceID: req.GetSpaceID(),
|
||||
UserID: userID,
|
||||
DraftAgent: draftAgent,
|
||||
}
|
||||
|
||||
newAgent, err := s.DomainSVC.DuplicateInMemory(ctx, duplicateInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
duplicateFns := []duplicateAgentResourceFn{
|
||||
duplicateVariables,
|
||||
duplicatePlugin,
|
||||
duplicateShortCommand,
|
||||
}
|
||||
|
||||
for _, fn := range duplicateFns {
|
||||
newAgent, err = fn(ctx, s.appContext, draftAgent, newAgent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
_, err = s.DomainSVC.CreateSingleAgentDraftWithID(ctx, userID, newAgentID, newAgent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userInfo, err := s.appContext.UserDomainSVC.GetUserInfo(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = s.appContext.EventBus.PublishProject(ctx, &searchEntity.ProjectDomainEvent{
|
||||
OpType: searchEntity.Created,
|
||||
Project: &searchEntity.ProjectDocument{
|
||||
Status: intelligence.IntelligenceStatus_Using,
|
||||
Type: intelligence.IntelligenceType_Bot,
|
||||
ID: newAgent.AgentID,
|
||||
SpaceID: &req.SpaceID,
|
||||
OwnerID: &userID,
|
||||
Name: &newAgent.Name,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &developer_api.DuplicateDraftBotResponse{
|
||||
Data: &developer_api.DuplicateDraftBotData{
|
||||
BotID: newAgent.AgentID,
|
||||
Name: newAgent.Name,
|
||||
UserInfo: &developer_api.Creator{
|
||||
ID: userID,
|
||||
Name: userInfo.Name,
|
||||
AvatarURL: userInfo.IconURL,
|
||||
Self: userID == draftAgent.CreatorID,
|
||||
UserUniqueName: userInfo.UniqueName,
|
||||
UserLabel: nil,
|
||||
},
|
||||
},
|
||||
Code: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func duplicateVariables(ctx context.Context, appContext *ServiceComponents, oldAgent, newAgent *entity.SingleAgent) (*entity.SingleAgent, error) {
|
||||
if oldAgent.VariablesMetaID == nil || *oldAgent.VariablesMetaID <= 0 {
|
||||
return newAgent, nil
|
||||
}
|
||||
|
||||
vars, err := appContext.VariablesDomainSVC.GetVariableMetaByID(ctx, *oldAgent.VariablesMetaID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
vars.ID = 0
|
||||
vars.BizID = conv.Int64ToStr(newAgent.AgentID)
|
||||
vars.BizType = project_memory.VariableConnector_Bot
|
||||
vars.Version = ""
|
||||
vars.CreatorID = newAgent.CreatorID
|
||||
|
||||
varMetaID, err := appContext.VariablesDomainSVC.UpsertMeta(ctx, vars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newAgent.VariablesMetaID = &varMetaID
|
||||
|
||||
return newAgent, nil
|
||||
}
|
||||
|
||||
func duplicatePlugin(ctx context.Context, _ *ServiceComponents, oldAgent, newAgent *entity.SingleAgent) (*entity.SingleAgent, error) {
|
||||
err := crossplugin.DefaultSVC().DuplicateDraftAgentTools(ctx, oldAgent.AgentID, newAgent.AgentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newAgent, nil
|
||||
}
|
||||
|
||||
func duplicateShortCommand(ctx context.Context, appContext *ServiceComponents, oldAgent, newAgent *entity.SingleAgent) (*entity.SingleAgent, error) {
|
||||
metas, err := appContext.ShortcutCMDDomainSVC.ListCMD(ctx, &shortcutCMDEntity.ListMeta{
|
||||
SpaceID: oldAgent.SpaceID,
|
||||
ObjectID: oldAgent.AgentID,
|
||||
IsOnline: 0,
|
||||
CommandIDs: slices.Transform(oldAgent.ShortcutCommand, func(a string) int64 {
|
||||
return conv.StrToInt64D(a, 0)
|
||||
}),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
shortcutCommandIDs := make([]string, 0, len(metas))
|
||||
for _, meta := range metas {
|
||||
meta.ObjectID = newAgent.AgentID
|
||||
meta.CreatorID = newAgent.CreatorID
|
||||
do, err := appContext.ShortcutCMDDomainSVC.CreateCMD(ctx, meta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
shortcutCommandIDs = append(shortcutCommandIDs, conv.Int64ToStr(do.CommandID))
|
||||
}
|
||||
|
||||
newAgent.ShortcutCommand = shortcutCommandIDs
|
||||
|
||||
return newAgent, nil
|
||||
}
|
||||
587
backend/application/singleagent/get.go
Normal file
587
backend/application/singleagent/get.go
Normal file
@@ -0,0 +1,587 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package singleagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
|
||||
knowledgeModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/playground"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
knowledge "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/modelmgr"
|
||||
modelEntity "github.com/coze-dev/coze-studio/backend/domain/modelmgr/entity"
|
||||
pluginEntity "github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/service"
|
||||
shortcutCMDEntity "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity"
|
||||
workflowEntity "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
func (s *SingleAgentApplicationService) GetAgentBotInfo(ctx context.Context, req *playground.GetDraftBotInfoAgwRequest) (*playground.GetDraftBotInfoAgwResponse, error) {
|
||||
agentInfo, err := s.DomainSVC.GetSingleAgent(ctx, req.GetBotID(), req.GetVersion())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if agentInfo == nil {
|
||||
return nil, errorx.New(errno.ErrAgentInvalidParamCode, errorx.KVf("msg", "agent %d not found", req.GetBotID()))
|
||||
}
|
||||
|
||||
vo, err := s.singleAgentDraftDo2Vo(ctx, agentInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
klInfos, err := s.fetchKnowledgeDetails(ctx, agentInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelInfos, err := s.fetchModelDetails(ctx, agentInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toolInfos, err := s.fetchToolDetails(ctx, agentInfo, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pluginInfos, err := s.fetchPluginDetails(ctx, agentInfo, toolInfos)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
workflowInfos, err := s.fetchWorkflowDetails(ctx, agentInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
shortCutCmdResp, err := s.fetchShortcutCMD(ctx, agentInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
workflowDetailMap, err := workflowDo2Vo(workflowInfos)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &playground.GetDraftBotInfoAgwResponse{
|
||||
Data: &playground.GetDraftBotInfoAgwData{
|
||||
BotInfo: vo,
|
||||
BotOptionData: &playground.BotOptionData{
|
||||
ModelDetailMap: modelInfoDo2Vo(modelInfos),
|
||||
KnowledgeDetailMap: knowledgeInfoDo2Vo(klInfos),
|
||||
PluginAPIDetailMap: toolInfoDo2Vo(toolInfos),
|
||||
PluginDetailMap: s.pluginInfoDo2Vo(ctx, pluginInfos),
|
||||
WorkflowDetailMap: workflowDetailMap,
|
||||
ShortcutCommandList: shortCutCmdResp,
|
||||
},
|
||||
SpaceID: agentInfo.SpaceID,
|
||||
Editable: ptr.Of(true),
|
||||
Deletable: ptr.Of(true),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) fetchShortcutCMD(ctx context.Context, agentInfo *entity.SingleAgent) ([]*playground.ShortcutCommand, error) {
|
||||
var cmdVOs []*playground.ShortcutCommand
|
||||
if len(agentInfo.ShortcutCommand) == 0 {
|
||||
return cmdVOs, nil
|
||||
}
|
||||
|
||||
cmdDOs, err := s.appContext.ShortcutCMDDomainSVC.ListCMD(ctx, &shortcutCMDEntity.ListMeta{
|
||||
SpaceID: agentInfo.SpaceID,
|
||||
ObjectID: agentInfo.AgentID,
|
||||
CommandIDs: slices.Transform(agentInfo.ShortcutCommand, func(a string) int64 {
|
||||
return conv.StrToInt64D(a, 0)
|
||||
}),
|
||||
})
|
||||
|
||||
logs.CtxInfof(ctx, "fetchShortcutCMD cmdDOs = %v, err = %v", conv.DebugJsonToStr(cmdDOs), err)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cmdVOs = s.shortcutCMDDo2Vo(cmdDOs)
|
||||
return cmdVOs, nil
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) shortcutCMDDo2Vo(cmdDOs []*shortcutCMDEntity.ShortcutCmd) []*playground.ShortcutCommand {
|
||||
return slices.Transform(cmdDOs, func(cmdDO *shortcutCMDEntity.ShortcutCmd) *playground.ShortcutCommand {
|
||||
return &playground.ShortcutCommand{
|
||||
ObjectID: cmdDO.ObjectID,
|
||||
CommandID: cmdDO.CommandID,
|
||||
CommandName: cmdDO.CommandName,
|
||||
ShortcutCommand: cmdDO.ShortcutCommand,
|
||||
Description: cmdDO.Description,
|
||||
SendType: playground.SendType(cmdDO.SendType),
|
||||
ToolType: playground.ToolType(cmdDO.ToolType),
|
||||
WorkFlowID: conv.Int64ToStr(cmdDO.WorkFlowID),
|
||||
PluginID: conv.Int64ToStr(cmdDO.PluginID),
|
||||
PluginAPIName: cmdDO.PluginToolName,
|
||||
PluginAPIID: cmdDO.PluginToolID,
|
||||
ShortcutIcon: cmdDO.ShortcutIcon,
|
||||
TemplateQuery: cmdDO.TemplateQuery,
|
||||
ComponentsList: cmdDO.Components,
|
||||
CardSchema: cmdDO.CardSchema,
|
||||
ToolInfo: cmdDO.ToolInfo,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) fetchModelDetails(ctx context.Context, agentInfo *entity.SingleAgent) ([]*modelEntity.Model, error) {
|
||||
if agentInfo.ModelInfo.ModelId == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
modelID := agentInfo.ModelInfo.GetModelId()
|
||||
modelInfos, err := s.appContext.ModelMgrDomainSVC.MGetModelByID(ctx, &modelmgr.MGetModelRequest{
|
||||
IDs: []int64{modelID},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch model(%d) details failed: %v", modelID, err)
|
||||
}
|
||||
|
||||
return modelInfos, nil
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) fetchKnowledgeDetails(ctx context.Context, agentInfo *entity.SingleAgent) ([]*knowledgeModel.Knowledge, error) {
|
||||
knowledgeIDs := make([]int64, 0, len(agentInfo.Knowledge.KnowledgeInfo))
|
||||
for _, v := range agentInfo.Knowledge.KnowledgeInfo {
|
||||
id, err := conv.StrToInt64(v.GetId())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid knowledge id: %s", v.GetId())
|
||||
}
|
||||
knowledgeIDs = append(knowledgeIDs, id)
|
||||
}
|
||||
|
||||
if len(knowledgeIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
listResp, err := s.appContext.KnowledgeDomainSVC.ListKnowledge(ctx, &knowledge.ListKnowledgeRequest{
|
||||
IDs: knowledgeIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch knowledge details failed: %v", err)
|
||||
}
|
||||
|
||||
return listResp.KnowledgeList, err
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) fetchToolDetails(ctx context.Context, agentInfo *entity.SingleAgent, req *playground.GetDraftBotInfoAgwRequest) ([]*pluginEntity.ToolInfo, error) {
|
||||
return s.appContext.PluginDomainSVC.MGetAgentTools(ctx, &service.MGetAgentToolsRequest{
|
||||
SpaceID: agentInfo.SpaceID,
|
||||
AgentID: req.GetBotID(),
|
||||
IsDraft: true,
|
||||
VersionAgentTools: slices.Transform(agentInfo.Plugin, func(a *bot_common.PluginInfo) pluginEntity.VersionAgentTool {
|
||||
return pluginEntity.VersionAgentTool{
|
||||
ToolID: a.GetApiId(),
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) fetchPluginDetails(ctx context.Context, agentInfo *entity.SingleAgent, toolInfos []*pluginEntity.ToolInfo) ([]*pluginEntity.PluginInfo, error) {
|
||||
vPlugins := make([]pluginEntity.VersionPlugin, 0, len(agentInfo.Plugin))
|
||||
vPluginMap := make(map[string]bool, len(agentInfo.Plugin))
|
||||
for _, v := range toolInfos {
|
||||
k := fmt.Sprintf("%d:%s", v.PluginID, v.GetVersion())
|
||||
if vPluginMap[k] {
|
||||
continue
|
||||
}
|
||||
vPluginMap[k] = true
|
||||
vPlugins = append(vPlugins, pluginEntity.VersionPlugin{
|
||||
PluginID: v.PluginID,
|
||||
Version: v.GetVersion(),
|
||||
})
|
||||
}
|
||||
return s.appContext.PluginDomainSVC.MGetVersionPlugins(ctx, vPlugins)
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) fetchWorkflowDetails(ctx context.Context, agentInfo *entity.SingleAgent) ([]*workflowEntity.Workflow, error) {
|
||||
if len(agentInfo.Workflow) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
policy := &vo.MGetPolicy{
|
||||
MetaQuery: vo.MetaQuery{
|
||||
IDs: slices.Transform(agentInfo.Workflow, func(a *bot_common.WorkflowInfo) int64 {
|
||||
return a.GetWorkflowId()
|
||||
}),
|
||||
},
|
||||
QType: vo.FromLatestVersion,
|
||||
}
|
||||
ret, _, err := s.appContext.WorkflowDomainSVC.MGet(ctx, policy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch workflow details failed: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func modelInfoDo2Vo(modelInfos []*modelEntity.Model) map[int64]*playground.ModelDetail {
|
||||
return slices.ToMap(modelInfos, func(e *modelEntity.Model) (int64, *playground.ModelDetail) {
|
||||
return e.ID, toModelDetail(e)
|
||||
})
|
||||
}
|
||||
|
||||
func toModelDetail(m *modelEntity.Model) *playground.ModelDetail {
|
||||
mm := m.Meta
|
||||
|
||||
return &playground.ModelDetail{
|
||||
Name: ptr.Of(m.Name),
|
||||
ModelName: ptr.Of(m.Meta.Name),
|
||||
ModelID: ptr.Of(m.ID),
|
||||
ModelFamily: ptr.Of(int64(mm.Protocol.TOModelClass())),
|
||||
ModelIconURL: ptr.Of(mm.IconURL),
|
||||
}
|
||||
}
|
||||
|
||||
func knowledgeInfoDo2Vo(klInfos []*knowledgeModel.Knowledge) map[string]*playground.KnowledgeDetail {
|
||||
return slices.ToMap(klInfos, func(e *knowledgeModel.Knowledge) (string, *playground.KnowledgeDetail) {
|
||||
return fmt.Sprintf("%v", e.ID), &playground.KnowledgeDetail{
|
||||
ID: ptr.Of(fmt.Sprintf("%d", e.ID)),
|
||||
Name: ptr.Of(e.Name),
|
||||
IconURL: ptr.Of(e.IconURL),
|
||||
FormatType: func() playground.DataSetType {
|
||||
switch e.Type {
|
||||
case knowledgeModel.DocumentTypeText:
|
||||
return playground.DataSetType_Text
|
||||
case knowledgeModel.DocumentTypeTable:
|
||||
return playground.DataSetType_Table
|
||||
case knowledgeModel.DocumentTypeImage:
|
||||
return playground.DataSetType_Image
|
||||
}
|
||||
return playground.DataSetType_Text
|
||||
}(),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func toolInfoDo2Vo(toolInfos []*pluginEntity.ToolInfo) map[int64]*playground.PluginAPIDetal {
|
||||
return slices.ToMap(toolInfos, func(e *pluginEntity.ToolInfo) (int64, *playground.PluginAPIDetal) {
|
||||
return e.ID, &playground.PluginAPIDetal{
|
||||
ID: ptr.Of(e.ID),
|
||||
Name: ptr.Of(e.GetName()),
|
||||
Description: ptr.Of(e.GetDesc()),
|
||||
PluginID: ptr.Of(e.PluginID),
|
||||
Parameters: parametersDo2Vo(e.Operation),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) pluginInfoDo2Vo(ctx context.Context, pluginInfos []*pluginEntity.PluginInfo) map[int64]*playground.PluginDetal {
|
||||
return slices.ToMap(pluginInfos, func(v *pluginEntity.PluginInfo) (int64, *playground.PluginDetal) {
|
||||
e := v.PluginInfo
|
||||
|
||||
var iconURL string
|
||||
if e.GetIconURI() != "" {
|
||||
var err error
|
||||
iconURL, err = s.appContext.TosClient.GetObjectUrl(ctx, e.GetIconURI())
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "get icon url failed, err = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return e.ID, &playground.PluginDetal{
|
||||
ID: ptr.Of(e.ID),
|
||||
Name: ptr.Of(e.GetName()),
|
||||
Description: ptr.Of(e.GetDesc()),
|
||||
PluginType: (*int64)(&e.PluginType),
|
||||
IconURL: &iconURL,
|
||||
PluginStatus: (*int64)(ptr.Of(plugin_develop_common.PluginStatus_PUBLISHED)),
|
||||
IsOfficial: func() *bool {
|
||||
if e.SpaceID == 0 {
|
||||
return ptr.Of(true)
|
||||
}
|
||||
return ptr.Of(false)
|
||||
}(),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func parametersDo2Vo(op *plugin.Openapi3Operation) []*playground.PluginParameter {
|
||||
var convertReqBody func(paramName string, isRequired bool, sc *openapi3.Schema) *playground.PluginParameter
|
||||
convertReqBody = func(paramName string, isRequired bool, sc *openapi3.Schema) *playground.PluginParameter {
|
||||
if disabledParam(sc) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var assistType *int64
|
||||
if v, ok := sc.Extensions[plugin.APISchemaExtendAssistType]; ok {
|
||||
if _v, ok := v.(string); ok {
|
||||
assistType = toParameterAssistType(_v)
|
||||
}
|
||||
}
|
||||
|
||||
paramInfo := &playground.PluginParameter{
|
||||
Name: ptr.Of(paramName),
|
||||
Type: ptr.Of(sc.Type),
|
||||
Description: ptr.Of(sc.Description),
|
||||
IsRequired: ptr.Of(isRequired),
|
||||
AssistType: assistType,
|
||||
}
|
||||
|
||||
switch sc.Type {
|
||||
case openapi3.TypeObject:
|
||||
required := slices.ToMap(sc.Required, func(e string) (string, bool) {
|
||||
return e, true
|
||||
})
|
||||
|
||||
subParams := make([]*playground.PluginParameter, 0, len(sc.Properties))
|
||||
for subParamName, prop := range sc.Properties {
|
||||
subParamInfo := convertReqBody(subParamName, required[subParamName], prop.Value)
|
||||
if subParamInfo != nil {
|
||||
subParams = append(subParams, subParamInfo)
|
||||
}
|
||||
}
|
||||
|
||||
paramInfo.SubParameters = subParams
|
||||
|
||||
return paramInfo
|
||||
case openapi3.TypeArray:
|
||||
paramInfo.SubType = ptr.Of(sc.Items.Value.Type)
|
||||
if sc.Items.Value.Type != openapi3.TypeObject {
|
||||
return paramInfo
|
||||
}
|
||||
|
||||
required := slices.ToMap(sc.Required, func(e string) (string, bool) {
|
||||
return e, true
|
||||
})
|
||||
|
||||
subParams := make([]*playground.PluginParameter, 0, len(sc.Items.Value.Properties))
|
||||
for subParamName, prop := range sc.Items.Value.Properties {
|
||||
subParamInfo := convertReqBody(subParamName, required[subParamName], prop.Value)
|
||||
if subParamInfo != nil {
|
||||
subParams = append(subParams, subParamInfo)
|
||||
}
|
||||
}
|
||||
|
||||
paramInfo.SubParameters = subParams
|
||||
|
||||
return paramInfo
|
||||
default:
|
||||
return paramInfo
|
||||
}
|
||||
}
|
||||
|
||||
var params []*playground.PluginParameter
|
||||
|
||||
for _, prop := range op.Parameters {
|
||||
paramVal := prop.Value
|
||||
schemaVal := paramVal.Schema.Value
|
||||
if schemaVal.Type == openapi3.TypeObject || schemaVal.Type == openapi3.TypeArray {
|
||||
continue
|
||||
}
|
||||
|
||||
if disabledParam(prop.Value.Schema.Value) {
|
||||
continue
|
||||
}
|
||||
|
||||
var assistType *int64
|
||||
if v, ok := schemaVal.Extensions[plugin.APISchemaExtendAssistType]; ok {
|
||||
if _v, ok := v.(string); ok {
|
||||
assistType = toParameterAssistType(_v)
|
||||
}
|
||||
}
|
||||
|
||||
params = append(params, &playground.PluginParameter{
|
||||
Name: ptr.Of(paramVal.Name),
|
||||
Description: ptr.Of(paramVal.Description),
|
||||
IsRequired: ptr.Of(paramVal.Required),
|
||||
Type: ptr.Of(schemaVal.Type),
|
||||
AssistType: assistType,
|
||||
})
|
||||
}
|
||||
|
||||
if op.RequestBody == nil || op.RequestBody.Value == nil || len(op.RequestBody.Value.Content) == 0 {
|
||||
return params
|
||||
}
|
||||
|
||||
for _, mType := range op.RequestBody.Value.Content {
|
||||
schemaVal := mType.Schema.Value
|
||||
if len(schemaVal.Properties) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
required := slices.ToMap(schemaVal.Required, func(e string) (string, bool) {
|
||||
return e, true
|
||||
})
|
||||
|
||||
for paramName, prop := range schemaVal.Properties {
|
||||
paramInfo := convertReqBody(paramName, required[paramName], prop.Value)
|
||||
if paramInfo != nil {
|
||||
params = append(params, paramInfo)
|
||||
}
|
||||
}
|
||||
|
||||
break // 只取一种 MIME
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
func toParameterAssistType(assistType string) *int64 {
|
||||
if assistType == "" {
|
||||
return nil
|
||||
}
|
||||
switch plugin.APIFileAssistType(assistType) {
|
||||
case plugin.AssistTypeFile:
|
||||
return ptr.Of(int64(plugin_develop_common.AssistParameterType_CODE))
|
||||
case plugin.AssistTypeImage:
|
||||
return ptr.Of(int64(plugin_develop_common.AssistParameterType_IMAGE))
|
||||
case plugin.AssistTypeDoc:
|
||||
return ptr.Of(int64(plugin_develop_common.AssistParameterType_DOC))
|
||||
case plugin.AssistTypePPT:
|
||||
return ptr.Of(int64(plugin_develop_common.AssistParameterType_PPT))
|
||||
case plugin.AssistTypeCode:
|
||||
return ptr.Of(int64(plugin_develop_common.AssistParameterType_CODE))
|
||||
case plugin.AssistTypeExcel:
|
||||
return ptr.Of(int64(plugin_develop_common.AssistParameterType_EXCEL))
|
||||
case plugin.AssistTypeZIP:
|
||||
return ptr.Of(int64(plugin_develop_common.AssistParameterType_ZIP))
|
||||
case plugin.AssistTypeVideo:
|
||||
return ptr.Of(int64(plugin_develop_common.AssistParameterType_VIDEO))
|
||||
case plugin.AssistTypeAudio:
|
||||
return ptr.Of(int64(plugin_develop_common.AssistParameterType_AUDIO))
|
||||
case plugin.AssistTypeTXT:
|
||||
return ptr.Of(int64(plugin_develop_common.AssistParameterType_TXT))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func workflowDo2Vo(wfInfos []*workflowEntity.Workflow) (map[int64]*playground.WorkflowDetail, error) {
|
||||
result := make(map[int64]*playground.WorkflowDetail, len(wfInfos))
|
||||
for _, e := range wfInfos {
|
||||
parameters, err := slices.TransformWithErrorCheck(e.InputParams, toPluginParameter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[e.ID] = &playground.WorkflowDetail{
|
||||
ID: ptr.Of(e.ID),
|
||||
Name: ptr.Of(e.Name),
|
||||
Description: ptr.Of(e.Desc),
|
||||
IconURL: ptr.Of(e.IconURL),
|
||||
PluginID: ptr.Of(e.ID),
|
||||
APIDetail: &playground.PluginAPIDetal{
|
||||
ID: ptr.Of(e.ID),
|
||||
Name: ptr.Of(e.Name),
|
||||
Description: ptr.Of(e.Desc),
|
||||
PluginID: ptr.Of(e.ID),
|
||||
Parameters: parameters,
|
||||
},
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func toPluginParameter(info *vo.NamedTypeInfo) (*playground.PluginParameter, error) {
|
||||
if info == nil {
|
||||
return nil, fmt.Errorf("named type info is nil")
|
||||
}
|
||||
p := &playground.PluginParameter{
|
||||
Name: ptr.Of(info.Name),
|
||||
Description: ptr.Of(info.Desc),
|
||||
IsRequired: ptr.Of(info.Required),
|
||||
}
|
||||
|
||||
switch info.Type {
|
||||
case vo.DataTypeString, vo.DataTypeFile, vo.DataTypeTime:
|
||||
p.Type = ptr.Of("string")
|
||||
if info.Type == vo.DataTypeFile {
|
||||
p.AssistType = toWorkflowParameterAssistType(string(*info.FileType))
|
||||
}
|
||||
|
||||
case vo.DataTypeInteger:
|
||||
p.Type = ptr.Of("integer")
|
||||
case vo.DataTypeNumber:
|
||||
p.Type = ptr.Of("number")
|
||||
case vo.DataTypeBoolean:
|
||||
p.Type = ptr.Of("boolean")
|
||||
case vo.DataTypeObject:
|
||||
p.Type = ptr.Of("object")
|
||||
p.SubParameters = make([]*playground.PluginParameter, 0, len(info.Properties))
|
||||
for _, sub := range info.Properties {
|
||||
subParameter, err := toPluginParameter(sub)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.SubParameters = append(p.SubParameters, subParameter)
|
||||
}
|
||||
case vo.DataTypeArray:
|
||||
p.Type = ptr.Of("array")
|
||||
eleParameter, err := toPluginParameter(info.ElemTypeInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.SubType = eleParameter.Type
|
||||
p.SubParameters = []*playground.PluginParameter{eleParameter}
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown named type info type: %s", info.Type)
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func toWorkflowParameterAssistType(assistType string) *int64 {
|
||||
if assistType == "" {
|
||||
return nil
|
||||
}
|
||||
switch vo.FileSubType(assistType) {
|
||||
case vo.FileTypeDefault:
|
||||
return ptr.Of(int64(workflow.AssistParameterType_DEFAULT))
|
||||
case vo.FileTypeImage:
|
||||
return ptr.Of(int64(workflow.AssistParameterType_IMAGE))
|
||||
case vo.FileTypeDocument:
|
||||
return ptr.Of(int64(workflow.AssistParameterType_DOC))
|
||||
case vo.FileTypePPT:
|
||||
return ptr.Of(int64(workflow.AssistParameterType_PPT))
|
||||
case vo.FileTypeCode:
|
||||
return ptr.Of(int64(workflow.AssistParameterType_CODE))
|
||||
case vo.FileTypeExcel:
|
||||
return ptr.Of(int64(workflow.AssistParameterType_EXCEL))
|
||||
case vo.FileTypeZip:
|
||||
return ptr.Of(int64(workflow.AssistParameterType_ZIP))
|
||||
case vo.FileTypeVideo:
|
||||
return ptr.Of(int64(workflow.AssistParameterType_VIDEO))
|
||||
case vo.FileTypeAudio:
|
||||
return ptr.Of(int64(workflow.AssistParameterType_AUDIO))
|
||||
case vo.FileTypeTxt:
|
||||
return ptr.Of(int64(workflow.AssistParameterType_TXT))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
94
backend/application/singleagent/image.go
Normal file
94
backend/application/singleagent/image.go
Normal file
@@ -0,0 +1,94 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package singleagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/playground"
|
||||
)
|
||||
|
||||
func (s *SingleAgentApplicationService) GetUploadAuthToken(ctx context.Context, req *developer_api.GetUploadAuthTokenRequest) (*developer_api.GetUploadAuthTokenResponse, error) {
|
||||
|
||||
authToken, err := s.getAuthToken(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
prefix := s.getUploadPrefix(req.Scene, req.DataType)
|
||||
|
||||
return &developer_api.GetUploadAuthTokenResponse{
|
||||
Data: &developer_api.GetUploadAuthTokenData{
|
||||
ServiceID: authToken.ServiceID,
|
||||
UploadPathPrefix: prefix,
|
||||
UploadHost: authToken.UploadHost,
|
||||
Auth: &developer_api.UploadAuthTokenInfo{
|
||||
AccessKeyID: authToken.AccessKeyID,
|
||||
SecretAccessKey: authToken.SecretAccessKey,
|
||||
SessionToken: authToken.SessionToken,
|
||||
ExpiredTime: authToken.ExpiredTime,
|
||||
CurrentTime: authToken.CurrentTime,
|
||||
},
|
||||
Schema: authToken.HostScheme,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
func (s *SingleAgentApplicationService) getAuthToken(ctx context.Context) (*bot_common.AuthToken, error) {
|
||||
uploadAuthToken, err := s.appContext.ImageX.GetUploadAuth(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authToken := &bot_common.AuthToken{
|
||||
ServiceID: s.appContext.ImageX.GetServerID(),
|
||||
AccessKeyID: uploadAuthToken.AccessKeyID,
|
||||
SecretAccessKey: uploadAuthToken.SecretAccessKey,
|
||||
SessionToken: uploadAuthToken.SessionToken,
|
||||
ExpiredTime: uploadAuthToken.ExpiredTime,
|
||||
CurrentTime: uploadAuthToken.CurrentTime,
|
||||
UploadHost: s.appContext.ImageX.GetUploadHost(ctx),
|
||||
HostScheme: uploadAuthToken.HostScheme,
|
||||
}
|
||||
return authToken, nil
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) getUploadPrefix(scene, dataType string) string {
|
||||
return strings.Replace(scene, "_", "-", -1) + "-" + dataType
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) GetImagexShortUrl(ctx context.Context, req *playground.GetImagexShortUrlRequest) (*playground.GetImagexShortUrlResponse, error) {
|
||||
urlInfo := make(map[string]*playground.UrlInfo, len(req.Uris))
|
||||
for _, uri := range req.Uris {
|
||||
resURL, err := s.appContext.ImageX.GetResourceURL(ctx, uri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
urlInfo[uri] = &playground.UrlInfo{
|
||||
URL: resURL.URL,
|
||||
ReviewStatus: true,
|
||||
}
|
||||
}
|
||||
|
||||
return &playground.GetImagexShortUrlResponse{
|
||||
Data: &playground.GetImagexShortUrlData{
|
||||
URLInfo: urlInfo,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
85
backend/application/singleagent/init.go
Normal file
85
backend/application/singleagent/init.go
Normal file
@@ -0,0 +1,85 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package singleagent
|
||||
|
||||
import (
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/repository"
|
||||
singleagent "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/service"
|
||||
connector "github.com/coze-dev/coze-studio/backend/domain/connector/service"
|
||||
knowledge "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
|
||||
database "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||
variables "github.com/coze-dev/coze-studio/backend/domain/memory/variables/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/service"
|
||||
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
|
||||
shortcutCmd "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/service"
|
||||
user "github.com/coze-dev/coze-studio/backend/domain/user/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/jsoncache"
|
||||
)
|
||||
|
||||
type (
|
||||
SingleAgent = singleagent.SingleAgent
|
||||
)
|
||||
|
||||
var SingleAgentSVC *SingleAgentApplicationService
|
||||
|
||||
type ServiceComponents struct {
|
||||
IDGen idgen.IDGenerator
|
||||
DB *gorm.DB
|
||||
Cache *redis.Client
|
||||
TosClient storage.Storage
|
||||
ImageX imagex.ImageX
|
||||
EventBus search.ProjectEventBus
|
||||
CounterRepo repository.CounterRepository
|
||||
|
||||
KnowledgeDomainSVC knowledge.Knowledge
|
||||
ModelMgrDomainSVC modelmgr.Manager
|
||||
PluginDomainSVC service.PluginService
|
||||
WorkflowDomainSVC workflow.Service
|
||||
UserDomainSVC user.User
|
||||
VariablesDomainSVC variables.Variables
|
||||
ConnectorDomainSVC connector.Connector
|
||||
DatabaseDomainSVC database.Database
|
||||
ShortcutCMDDomainSVC shortcutCmd.ShortcutCmd
|
||||
CPStore compose.CheckPointStore
|
||||
}
|
||||
|
||||
func InitService(c *ServiceComponents) (*SingleAgentApplicationService, error) {
|
||||
domainComponents := &singleagent.Components{
|
||||
AgentDraftRepo: repository.NewSingleAgentRepo(c.DB, c.IDGen, c.Cache),
|
||||
AgentVersionRepo: repository.NewSingleAgentVersionRepo(c.DB, c.IDGen),
|
||||
PublishInfoRepo: jsoncache.New[entity.PublishInfo]("agent:publish:last:", c.Cache),
|
||||
CounterRepo: repository.NewCounterRepo(c.Cache),
|
||||
CPStore: c.CPStore,
|
||||
ModelFactory: chatmodel.NewDefaultFactory(),
|
||||
}
|
||||
|
||||
singleAgentDomainSVC := singleagent.NewService(domainComponents)
|
||||
SingleAgentSVC = newApplicationService(c, singleAgentDomainSVC)
|
||||
|
||||
return SingleAgentSVC, nil
|
||||
}
|
||||
260
backend/application/singleagent/publish.go
Normal file
260
backend/application/singleagent/publish.go
Normal file
@@ -0,0 +1,260 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package singleagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/intelligence/common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/playground"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
search "github.com/coze-dev/coze-studio/backend/domain/search/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/taskgroup"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
func (s *SingleAgentApplicationService) PublishAgent(ctx context.Context, req *developer_api.PublishDraftBotRequest) (*developer_api.PublishDraftBotResponse, error) {
|
||||
draftAgent, err := s.ValidateAgentDraftAccess(ctx, req.BotID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
version, err := s.getPublishAgentVersion(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
connectorIDs := make([]int64, 0, len(req.Connectors))
|
||||
for v := range req.Connectors {
|
||||
var id int64
|
||||
id, err = strconv.ParseInt(v, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !entity.PublishConnectorIDWhiteList[id] {
|
||||
return nil, errorx.New(errno.ErrAgentPermissionCode, errorx.KV("msg", fmt.Sprintf("connector %d not allowed", id)))
|
||||
}
|
||||
|
||||
connectorIDs = append(connectorIDs, id)
|
||||
}
|
||||
|
||||
p := &entity.SingleAgentPublish{
|
||||
ConnectorIds: connectorIDs,
|
||||
Version: version,
|
||||
PublishID: req.GetPublishID(),
|
||||
PublishInfo: req.HistoryInfo,
|
||||
}
|
||||
|
||||
publishFns := []publishFn{
|
||||
publishAgentVariables,
|
||||
publishAgentPlugins,
|
||||
publishShortcutCommand,
|
||||
publishDatabase,
|
||||
}
|
||||
|
||||
for _, pubFn := range publishFns {
|
||||
draftAgent, err = pubFn(ctx, s.appContext, p, draftAgent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = s.DomainSVC.SavePublishRecord(ctx, p, draftAgent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tasks := taskgroup.NewUninterruptibleTaskGroup(ctx, len(connectorIDs))
|
||||
publishResult := make(map[string]*developer_api.ConnectorBindResult, len(connectorIDs))
|
||||
lock := sync.Mutex{}
|
||||
|
||||
for _, connectorID := range connectorIDs {
|
||||
tasks.Go(func() error {
|
||||
_, err = s.DomainSVC.CreateSingleAgent(ctx, connectorID, version, draftAgent)
|
||||
if err != nil {
|
||||
logs.CtxWarnf(ctx, "create single agent failed: %v, agentID: %d, connectorID: %d , version : %s", err, draftAgent.AgentID, connectorID, version)
|
||||
lock.Lock()
|
||||
publishResult[conv.Int64ToStr(connectorID)] = &developer_api.ConnectorBindResult{
|
||||
PublishResultStatus: ptr.Of(developer_api.PublishResultStatus_Failed),
|
||||
}
|
||||
lock.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
// do other connector publish logic if need
|
||||
|
||||
lock.Lock()
|
||||
publishResult[conv.Int64ToStr(connectorID)] = &developer_api.ConnectorBindResult{
|
||||
PublishResultStatus: ptr.Of(developer_api.PublishResultStatus_Success),
|
||||
}
|
||||
lock.Unlock()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
_ = tasks.Wait()
|
||||
|
||||
err = s.appContext.EventBus.PublishProject(ctx, &search.ProjectDomainEvent{
|
||||
OpType: search.Updated,
|
||||
Project: &search.ProjectDocument{
|
||||
ID: draftAgent.AgentID,
|
||||
HasPublished: ptr.Of(1),
|
||||
PublishTimeMS: ptr.Of(time.Now().UnixMilli()),
|
||||
Type: common.IntelligenceType_Bot,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
logs.CtxWarnf(ctx, "publish project event failed, agentID: %d, err : %v", draftAgent.AgentID, err)
|
||||
}
|
||||
|
||||
return &developer_api.PublishDraftBotResponse{
|
||||
Data: &developer_api.PublishDraftBotData{
|
||||
CheckNotPass: false,
|
||||
PublishResult: publishResult,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) getPublishAgentVersion(ctx context.Context, req *developer_api.PublishDraftBotRequest) (string, error) {
|
||||
version := req.GetCommitVersion()
|
||||
if version != "" {
|
||||
return version, nil
|
||||
}
|
||||
|
||||
v, err := s.appContext.IDGen.GenID(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
version = fmt.Sprintf("%v", v)
|
||||
|
||||
return version, nil
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) GetAgentPopupInfo(ctx context.Context, req *playground.GetBotPopupInfoRequest) (*playground.GetBotPopupInfoResponse, error) {
|
||||
uid := ctxutil.MustGetUIDFromCtx(ctx)
|
||||
agentPopupCountInfo := make(map[playground.BotPopupType]int64, len(req.BotPopupTypes))
|
||||
|
||||
for _, agentPopupType := range req.BotPopupTypes {
|
||||
count, err := s.DomainSVC.GetAgentPopupCount(ctx, uid, req.GetBotID(), agentPopupType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
agentPopupCountInfo[agentPopupType] = count
|
||||
}
|
||||
|
||||
return &playground.GetBotPopupInfoResponse{
|
||||
Data: &playground.BotPopupInfoData{
|
||||
BotPopupCountInfo: agentPopupCountInfo,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) UpdateAgentPopupInfo(ctx context.Context, req *playground.UpdateBotPopupInfoRequest) (*playground.UpdateBotPopupInfoResponse, error) {
|
||||
uid := ctxutil.MustGetUIDFromCtx(ctx)
|
||||
|
||||
err := s.DomainSVC.IncrAgentPopupCount(ctx, uid, req.GetBotID(), req.GetBotPopupType())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &playground.UpdateBotPopupInfoResponse{
|
||||
Code: 0,
|
||||
Msg: "success",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) GetPublishConnectorList(ctx context.Context, req *developer_api.PublishConnectorListRequest) (*developer_api.PublishConnectorListResponse, error) {
|
||||
data, err := s.DomainSVC.GetPublishConnectorList(ctx, req.BotID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &developer_api.PublishConnectorListResponse{
|
||||
PublishConnectorList: data.PublishConnectorList,
|
||||
Code: 0,
|
||||
Msg: "success",
|
||||
}, nil
|
||||
}
|
||||
|
||||
type publishFn func(ctx context.Context, appContext *ServiceComponents, publishInfo *entity.SingleAgentPublish, agent *entity.SingleAgent) (*entity.SingleAgent, error)
|
||||
|
||||
func publishAgentVariables(ctx context.Context, appContext *ServiceComponents, publishInfo *entity.SingleAgentPublish, agent *entity.SingleAgent) (*entity.SingleAgent, error) {
|
||||
draftAgent := agent
|
||||
if draftAgent.VariablesMetaID != nil || *draftAgent.VariablesMetaID == 0 {
|
||||
return draftAgent, nil
|
||||
}
|
||||
|
||||
var newVariableMetaID int64
|
||||
newVariableMetaID, err := appContext.VariablesDomainSVC.PublishMeta(ctx, *draftAgent.VariablesMetaID, publishInfo.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
draftAgent.VariablesMetaID = ptr.Of(newVariableMetaID)
|
||||
|
||||
return draftAgent, nil
|
||||
}
|
||||
|
||||
func publishAgentPlugins(ctx context.Context, appContext *ServiceComponents, publishInfo *entity.SingleAgentPublish, agent *entity.SingleAgent) (*entity.SingleAgent, error) {
|
||||
err := appContext.PluginDomainSVC.PublishAgentTools(ctx, agent.AgentID, publishInfo.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return agent, nil
|
||||
}
|
||||
|
||||
func publishShortcutCommand(ctx context.Context, appContext *ServiceComponents, publishInfo *entity.SingleAgentPublish, agent *entity.SingleAgent) (*entity.SingleAgent, error) {
|
||||
logs.CtxInfof(ctx, "publishShortcutCommand agentID: %d, shortcutCommand: %v", agent.AgentID, agent.ShortcutCommand)
|
||||
if agent.ShortcutCommand == nil || len(agent.ShortcutCommand) == 0 {
|
||||
return agent, nil
|
||||
}
|
||||
cmdIDs := slices.Transform(agent.ShortcutCommand, func(a string) int64 {
|
||||
return conv.StrToInt64D(a, 0)
|
||||
})
|
||||
err := appContext.ShortcutCMDDomainSVC.PublishCMDs(ctx, agent.AgentID, cmdIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return agent, nil
|
||||
}
|
||||
|
||||
func publishDatabase(ctx context.Context, appContext *ServiceComponents, publishInfo *entity.SingleAgentPublish, agent *entity.SingleAgent) (*entity.SingleAgent, error) {
|
||||
onlineResp, err := appContext.DatabaseDomainSVC.PublishDatabase(ctx, &database.PublishDatabaseRequest{AgentID: agent.AgentID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
agent.Database = onlineResp.OnlineDatabases
|
||||
return agent, nil
|
||||
}
|
||||
733
backend/application/singleagent/single_agent.go
Normal file
733
backend/application/singleagent/single_agent.go
Normal file
@@ -0,0 +1,733 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package singleagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
shortcutCmd "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/service"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
intelligence "github.com/coze-dev/coze-studio/backend/api/model/intelligence/common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/playground"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/table"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatabase"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
singleagent "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/service"
|
||||
variableEntity "github.com/coze-dev/coze-studio/backend/domain/memory/variables/entity"
|
||||
shortcutEntity "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity"
|
||||
|
||||
searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
type SingleAgentApplicationService struct {
|
||||
appContext *ServiceComponents
|
||||
DomainSVC singleagent.SingleAgent
|
||||
ShortcutCMDSVC shortcutCmd.ShortcutCmd
|
||||
}
|
||||
|
||||
func newApplicationService(s *ServiceComponents, domain singleagent.SingleAgent) *SingleAgentApplicationService {
|
||||
return &SingleAgentApplicationService{
|
||||
appContext: s,
|
||||
DomainSVC: domain,
|
||||
ShortcutCMDSVC: s.ShortcutCMDDomainSVC,
|
||||
}
|
||||
}
|
||||
|
||||
const onboardingInfoMaxLength = 65535
|
||||
|
||||
func (s *SingleAgentApplicationService) generateOnboardingStr(onboardingInfo *bot_common.OnboardingInfo) (string, error) {
|
||||
onboarding := playground.OnboardingContent{}
|
||||
if onboardingInfo != nil {
|
||||
onboarding.Prologue = ptr.Of(onboardingInfo.GetPrologue())
|
||||
onboarding.SuggestedQuestions = onboardingInfo.GetSuggestedQuestions()
|
||||
onboarding.SuggestedQuestionsShowMode = onboardingInfo.SuggestedQuestionsShowMode
|
||||
}
|
||||
|
||||
onboardingInfoStr, err := sonic.MarshalString(onboarding)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return onboardingInfoStr, nil
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) UpdateSingleAgentDraft(ctx context.Context, req *playground.UpdateDraftBotInfoAgwRequest) (*playground.UpdateDraftBotInfoAgwResponse, error) {
|
||||
if req.BotInfo.OnboardingInfo != nil {
|
||||
infoStr, err := s.generateOnboardingStr(req.BotInfo.OnboardingInfo)
|
||||
if err != nil {
|
||||
return nil, errorx.New(errno.ErrAgentPermissionCode, errorx.KV("msg", "onboarding_info invalidate"))
|
||||
}
|
||||
|
||||
if len(infoStr) > onboardingInfoMaxLength {
|
||||
return nil, errorx.New(errno.ErrAgentPermissionCode, errorx.KV("msg", "onboarding_info is too long"))
|
||||
}
|
||||
}
|
||||
|
||||
agentID := req.BotInfo.GetBotId()
|
||||
currentAgentInfo, err := s.ValidateAgentDraftAccess(ctx, agentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userID := ctxutil.MustGetUIDFromCtx(ctx)
|
||||
|
||||
updateAgentInfo, err := s.applyAgentUpdates(currentAgentInfo, req.BotInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if req.BotInfo.VariableList != nil {
|
||||
var (
|
||||
varsMetaID int64
|
||||
vars = variableEntity.NewVariablesWithAgentVariables(req.BotInfo.VariableList)
|
||||
)
|
||||
|
||||
varsMetaID, err = s.appContext.VariablesDomainSVC.UpsertBotMeta(ctx, agentID, "", userID, vars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updateAgentInfo.VariablesMetaID = &varsMetaID
|
||||
}
|
||||
|
||||
err = s.DomainSVC.UpdateSingleAgentDraft(ctx, updateAgentInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = s.appContext.EventBus.PublishProject(ctx, &searchEntity.ProjectDomainEvent{
|
||||
OpType: searchEntity.Updated,
|
||||
Project: &searchEntity.ProjectDocument{
|
||||
ID: agentID,
|
||||
Name: &updateAgentInfo.Name,
|
||||
Type: intelligence.IntelligenceType_Bot,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &playground.UpdateDraftBotInfoAgwResponse{
|
||||
Data: &playground.UpdateDraftBotInfoAgwData{
|
||||
HasChange: ptr.Of(true),
|
||||
CheckNotPass: false,
|
||||
Branch: playground.BranchPtr(playground.Branch_PersonalDraft),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) UpdatePromptDisable(ctx context.Context, req *table.UpdateDatabaseBotSwitchRequest) (*table.UpdateDatabaseBotSwitchResponse, error) {
|
||||
agentID := req.GetBotID()
|
||||
draft, err := s.ValidateAgentDraftAccess(ctx, agentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(draft.Database) == 0 {
|
||||
return nil, fmt.Errorf("agent %d has no database", agentID) // TODO(@fanlv): 错误码
|
||||
}
|
||||
|
||||
dbInfos := draft.Database
|
||||
var found bool
|
||||
for _, db := range dbInfos {
|
||||
if db.GetTableId() == conv.Int64ToStr(req.GetDatabaseID()) {
|
||||
db.PromptDisabled = ptr.Of(req.GetPromptDisable())
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil, fmt.Errorf("database %d not found in agent %d", req.GetDatabaseID(), agentID) // TODO(@fanlv): 错误码
|
||||
}
|
||||
|
||||
draft.Database = dbInfos
|
||||
err = s.DomainSVC.UpdateSingleAgentDraft(ctx, draft)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &table.UpdateDatabaseBotSwitchResponse{
|
||||
Code: 0,
|
||||
Msg: "success",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) UnBindDatabase(ctx context.Context, req *table.BindDatabaseToBotRequest) (*table.BindDatabaseToBotResponse, error) {
|
||||
agentID := req.GetBotID()
|
||||
draft, err := s.ValidateAgentDraftAccess(ctx, agentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(draft.Database) == 0 {
|
||||
return nil, fmt.Errorf("agent %d has no database", agentID)
|
||||
}
|
||||
|
||||
dbInfos := draft.Database
|
||||
var found bool
|
||||
newDBInfos := make([]*bot_common.Database, 0)
|
||||
for _, db := range dbInfos {
|
||||
if db.GetTableId() == conv.Int64ToStr(req.GetDatabaseID()) {
|
||||
found = true
|
||||
continue
|
||||
}
|
||||
newDBInfos = append(newDBInfos, db)
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil, fmt.Errorf("database %d not found in agent %d", req.GetDatabaseID(), agentID)
|
||||
}
|
||||
|
||||
draft.Database = newDBInfos
|
||||
err = s.DomainSVC.UpdateSingleAgentDraft(ctx, draft)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = crossdatabase.DefaultSVC().UnBindDatabase(ctx, &database.UnBindDatabaseToAgentRequest{
|
||||
AgentID: agentID,
|
||||
DraftDatabaseID: req.GetDatabaseID(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &table.BindDatabaseToBotResponse{
|
||||
Code: 0,
|
||||
Msg: "success",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) BindDatabase(ctx context.Context, req *table.BindDatabaseToBotRequest) (*table.BindDatabaseToBotResponse, error) {
|
||||
agentID := req.GetBotID()
|
||||
draft, err := s.ValidateAgentDraftAccess(ctx, agentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dbMap := slices.ToMap(draft.Database, func(d *bot_common.Database) (string, *bot_common.Database) {
|
||||
return d.GetTableId(), d
|
||||
})
|
||||
if _, ok := dbMap[conv.Int64ToStr(req.GetDatabaseID())]; ok {
|
||||
return nil, errorx.New(errno.ErrAgentPermissionCode, errorx.KVf("msg", "database %d already bound to agent %d", req.GetDatabaseID(), agentID))
|
||||
}
|
||||
|
||||
basics := []*database.DatabaseBasic{
|
||||
{
|
||||
ID: req.DatabaseID,
|
||||
TableType: table.TableType_DraftTable,
|
||||
},
|
||||
}
|
||||
|
||||
draftRes, err := crossdatabase.DefaultSVC().MGetDatabase(ctx, &database.MGetDatabaseRequest{
|
||||
Basics: basics,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(draftRes.Databases) == 0 {
|
||||
return nil, fmt.Errorf("database %d not found", req.DatabaseID)
|
||||
}
|
||||
|
||||
draftDatabase := draftRes.Databases[0]
|
||||
|
||||
fields := make([]*bot_common.FieldItem, 0, len(draftDatabase.FieldList))
|
||||
for _, field := range draftDatabase.FieldList {
|
||||
fields = append(fields, &bot_common.FieldItem{
|
||||
Name: ptr.Of(field.Name),
|
||||
Desc: ptr.Of(field.Desc),
|
||||
Type: ptr.Of(bot_common.FieldItemType(field.Type)),
|
||||
MustRequired: ptr.Of(field.MustRequired),
|
||||
AlterId: ptr.Of(field.AlterID),
|
||||
Id: ptr.Of(int64(0)),
|
||||
})
|
||||
}
|
||||
|
||||
bindDB := &bot_common.Database{
|
||||
TableId: ptr.Of(strconv.FormatInt(draftDatabase.ID, 10)),
|
||||
TableName: ptr.Of(draftDatabase.TableName),
|
||||
TableDesc: ptr.Of(draftDatabase.TableDesc),
|
||||
FieldList: fields,
|
||||
RWMode: ptr.Of(bot_common.BotTableRWMode(draftDatabase.RwMode)),
|
||||
}
|
||||
|
||||
if len(draft.Database) == 0 {
|
||||
draft.Database = make([]*bot_common.Database, 0, 1)
|
||||
}
|
||||
draft.Database = append(draft.Database, bindDB)
|
||||
|
||||
err = s.DomainSVC.UpdateSingleAgentDraft(ctx, draft)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = crossdatabase.DefaultSVC().BindDatabase(ctx, &database.BindDatabaseToAgentRequest{
|
||||
AgentID: agentID,
|
||||
DraftDatabaseID: req.GetDatabaseID(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &table.BindDatabaseToBotResponse{
|
||||
Code: 0,
|
||||
Msg: "success",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) applyAgentUpdates(target *entity.SingleAgent, patch *bot_common.BotInfoForUpdate) (*entity.SingleAgent, error) {
|
||||
if patch.Name != nil {
|
||||
target.Name = *patch.Name
|
||||
}
|
||||
|
||||
if patch.Description != nil {
|
||||
target.Desc = *patch.Description
|
||||
}
|
||||
|
||||
if patch.IconUri != nil {
|
||||
target.IconURI = *patch.IconUri
|
||||
}
|
||||
|
||||
if patch.OnboardingInfo != nil {
|
||||
target.OnboardingInfo = patch.OnboardingInfo
|
||||
}
|
||||
|
||||
if patch.ModelInfo != nil {
|
||||
target.ModelInfo = patch.ModelInfo
|
||||
}
|
||||
|
||||
if patch.PromptInfo != nil {
|
||||
target.Prompt = patch.PromptInfo
|
||||
}
|
||||
|
||||
if patch.WorkflowInfoList != nil {
|
||||
target.Workflow = patch.WorkflowInfoList
|
||||
}
|
||||
|
||||
if patch.PluginInfoList != nil {
|
||||
target.Plugin = patch.PluginInfoList
|
||||
}
|
||||
|
||||
if patch.Knowledge != nil {
|
||||
target.Knowledge = patch.Knowledge
|
||||
}
|
||||
|
||||
if patch.SuggestReplyInfo != nil {
|
||||
target.SuggestReply = patch.SuggestReplyInfo
|
||||
}
|
||||
|
||||
if patch.BackgroundImageInfoList != nil {
|
||||
target.BackgroundImageInfoList = patch.BackgroundImageInfoList
|
||||
}
|
||||
|
||||
if patch.Agents != nil && len(patch.Agents) > 0 && patch.Agents[0].JumpConfig != nil {
|
||||
target.JumpConfig = patch.Agents[0].JumpConfig
|
||||
}
|
||||
|
||||
if patch.ShortcutSort != nil {
|
||||
target.ShortcutCommand = patch.ShortcutSort
|
||||
}
|
||||
|
||||
if patch.DatabaseList != nil {
|
||||
for _, db := range patch.DatabaseList {
|
||||
if db.PromptDisabled == nil {
|
||||
db.PromptDisabled = ptr.Of(false) // default is false
|
||||
}
|
||||
}
|
||||
target.Database = patch.DatabaseList
|
||||
}
|
||||
|
||||
return target, nil
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) DeleteAgentDraft(ctx context.Context, req *developer_api.DeleteDraftBotRequest) (*developer_api.DeleteDraftBotResponse, error) {
|
||||
_, err := s.ValidateAgentDraftAccess(ctx, req.GetBotID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = s.DomainSVC.DeleteAgentDraft(ctx, req.GetSpaceID(), req.GetBotID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = s.appContext.EventBus.PublishProject(ctx, &searchEntity.ProjectDomainEvent{
|
||||
OpType: searchEntity.Deleted,
|
||||
Project: &searchEntity.ProjectDocument{
|
||||
ID: req.GetBotID(),
|
||||
Type: intelligence.IntelligenceType_Bot,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
logs.CtxWarnf(ctx, "publish delete project event failed id = %v , err = %v", req.GetBotID(), err)
|
||||
}
|
||||
|
||||
return &developer_api.DeleteDraftBotResponse{
|
||||
Data: &developer_api.DeleteDraftBotData{},
|
||||
Code: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) singleAgentDraftDo2Vo(ctx context.Context, do *entity.SingleAgent) (*bot_common.BotInfo, error) {
|
||||
vo := &bot_common.BotInfo{
|
||||
BotId: do.AgentID,
|
||||
Name: do.Name,
|
||||
Description: do.Desc,
|
||||
IconUri: do.IconURI,
|
||||
OnboardingInfo: do.OnboardingInfo,
|
||||
ModelInfo: do.ModelInfo,
|
||||
PromptInfo: do.Prompt,
|
||||
PluginInfoList: do.Plugin,
|
||||
Knowledge: do.Knowledge,
|
||||
WorkflowInfoList: do.Workflow,
|
||||
SuggestReplyInfo: do.SuggestReply,
|
||||
CreatorId: do.CreatorID,
|
||||
TaskInfo: &bot_common.TaskInfo{},
|
||||
CreateTime: do.CreatedAt / 1000,
|
||||
UpdateTime: do.UpdatedAt / 1000,
|
||||
BotMode: bot_common.BotMode_SingleMode,
|
||||
BackgroundImageInfoList: do.BackgroundImageInfoList,
|
||||
Status: bot_common.BotStatus_Using,
|
||||
DatabaseList: do.Database,
|
||||
ShortcutSort: do.ShortcutCommand,
|
||||
}
|
||||
|
||||
if do.VariablesMetaID != nil {
|
||||
vars, err := s.appContext.VariablesDomainSVC.GetVariableMetaByID(ctx, *do.VariablesMetaID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if vars != nil {
|
||||
vo.VariableList = vars.ToAgentVariables()
|
||||
}
|
||||
}
|
||||
|
||||
if vo.IconUri != "" {
|
||||
url, err := s.appContext.TosClient.GetObjectUrl(ctx, vo.IconUri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vo.IconUrl = url
|
||||
}
|
||||
|
||||
if vo.ModelInfo == nil || vo.ModelInfo.ModelId == nil {
|
||||
mi, err := s.defaultModelInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vo.ModelInfo = mi
|
||||
}
|
||||
|
||||
return vo, nil
|
||||
}
|
||||
|
||||
func disabledParam(schemaVal *openapi3.Schema) bool {
|
||||
if len(schemaVal.Extensions) == 0 {
|
||||
return false
|
||||
}
|
||||
globalDisable, localDisable := false, false
|
||||
if v, ok := schemaVal.Extensions[plugin.APISchemaExtendLocalDisable]; ok {
|
||||
localDisable = v.(bool)
|
||||
}
|
||||
if v, ok := schemaVal.Extensions[plugin.APISchemaExtendGlobalDisable]; ok {
|
||||
globalDisable = v.(bool)
|
||||
}
|
||||
return globalDisable || localDisable
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) UpdateAgentDraftDisplayInfo(ctx context.Context, req *developer_api.UpdateDraftBotDisplayInfoRequest) (*developer_api.UpdateDraftBotDisplayInfoResponse, error) {
|
||||
uid := ctxutil.GetUIDFromCtx(ctx)
|
||||
if uid == nil {
|
||||
return nil, errorx.New(errno.ErrAgentPermissionCode, errorx.KV("msg", "session required"))
|
||||
}
|
||||
|
||||
_, err := s.ValidateAgentDraftAccess(ctx, req.BotID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
draftInfoDo := &entity.AgentDraftDisplayInfo{
|
||||
AgentID: req.BotID,
|
||||
DisplayInfo: req.DisplayInfo,
|
||||
SpaceID: req.SpaceID,
|
||||
}
|
||||
|
||||
err = s.DomainSVC.UpdateAgentDraftDisplayInfo(ctx, *uid, draftInfoDo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &developer_api.UpdateDraftBotDisplayInfoResponse{
|
||||
Code: 0,
|
||||
Msg: "success",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) GetAgentDraftDisplayInfo(ctx context.Context, req *developer_api.GetDraftBotDisplayInfoRequest) (*developer_api.GetDraftBotDisplayInfoResponse, error) {
|
||||
uid := ctxutil.GetUIDFromCtx(ctx)
|
||||
if uid == nil {
|
||||
return nil, errorx.New(errno.ErrAgentPermissionCode, errorx.KV("msg", "session required"))
|
||||
}
|
||||
|
||||
_, err := s.ValidateAgentDraftAccess(ctx, req.BotID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
draftInfoDo, err := s.DomainSVC.GetAgentDraftDisplayInfo(ctx, *uid, req.BotID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &developer_api.GetDraftBotDisplayInfoResponse{
|
||||
Code: 0,
|
||||
Msg: "success",
|
||||
Data: draftInfoDo.DisplayInfo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) ValidateAgentDraftAccess(ctx context.Context, agentID int64) (*entity.SingleAgent, error) {
|
||||
uid := ctxutil.GetUIDFromCtx(ctx)
|
||||
if uid == nil {
|
||||
uid = ptr.Of(int64(888))
|
||||
// return nil, errorx.New(errno.ErrAgentPermissionCode, errorx.KV("msg", "session uid not found"))
|
||||
}
|
||||
|
||||
do, err := s.DomainSVC.GetSingleAgentDraft(ctx, agentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if do == nil {
|
||||
return nil, errorx.New(errno.ErrAgentPermissionCode, errorx.KVf("msg", "No agent draft(%d) found for the given agent ID", agentID))
|
||||
}
|
||||
|
||||
if do.SpaceID == consts.TemplateSpaceID { // duplicate template, not need check uid permission
|
||||
return do, nil
|
||||
}
|
||||
|
||||
if do.CreatorID != *uid {
|
||||
logs.CtxErrorf(ctx, "user(%d) is not the creator(%d) of the agent draft", *uid, do.CreatorID)
|
||||
|
||||
return do, errorx.New(errno.ErrAgentPermissionCode, errorx.KV("detail", "you are not the agent owner"))
|
||||
}
|
||||
|
||||
return do, nil
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) ListAgentPublishHistory(ctx context.Context, req *developer_api.ListDraftBotHistoryRequest) (*developer_api.ListDraftBotHistoryResponse, error) {
|
||||
resp := &developer_api.ListDraftBotHistoryResponse{}
|
||||
draftAgent, err := s.ValidateAgentDraftAccess(ctx, req.BotID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var connectorID *int64
|
||||
if req.GetConnectorID() != "" {
|
||||
var id int64
|
||||
id, err = conv.StrToInt64(req.GetConnectorID())
|
||||
if err != nil {
|
||||
return nil, errorx.New(errno.ErrAgentInvalidParamCode, errorx.KV("msg", fmt.Sprintf("ConnectorID %v invalidate", *req.ConnectorID)))
|
||||
}
|
||||
|
||||
connectorID = ptr.Of(id)
|
||||
}
|
||||
|
||||
historyList, err := s.DomainSVC.ListAgentPublishHistory(ctx, draftAgent.AgentID, req.PageIndex, req.PageSize, connectorID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
uid := ctxutil.MustGetUIDFromCtx(ctx)
|
||||
resp.Data = &developer_api.ListDraftBotHistoryData{}
|
||||
|
||||
for _, v := range historyList {
|
||||
connectorInfos := make([]*developer_api.ConnectorInfo, 0, len(v.ConnectorIds))
|
||||
infos, err := s.appContext.ConnectorDomainSVC.GetByIDs(ctx, v.ConnectorIds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, info := range infos {
|
||||
connectorInfos = append(connectorInfos, info.ToVO())
|
||||
}
|
||||
|
||||
creator, err := s.appContext.UserDomainSVC.GetUserProfiles(ctx, v.CreatorID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := ""
|
||||
if v.PublishInfo != nil {
|
||||
info = *v.PublishInfo
|
||||
}
|
||||
|
||||
historyInfo := &developer_api.HistoryInfo{
|
||||
HistoryType: developer_api.HistoryType_FLAG,
|
||||
Version: v.Version,
|
||||
Info: info,
|
||||
CreateTime: conv.Int64ToStr(v.CreatedAt / 1000),
|
||||
ConnectorInfos: connectorInfos,
|
||||
Creator: &developer_api.Creator{
|
||||
ID: v.CreatorID,
|
||||
Name: creator.Name,
|
||||
AvatarURL: creator.IconURL,
|
||||
Self: uid == v.CreatorID,
|
||||
// UserUniqueName: creator.UserUniqueName, // TODO(@fanlv) : user domain 补完以后再改
|
||||
// UserLabel TODO
|
||||
},
|
||||
PublishID: &v.PublishID,
|
||||
}
|
||||
|
||||
resp.Data.HistoryInfos = append(resp.Data.HistoryInfos, historyInfo)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) ReportUserBehavior(ctx context.Context, req *playground.ReportUserBehaviorRequest) (resp *playground.ReportUserBehaviorResponse, err error) {
|
||||
err = s.appContext.EventBus.PublishProject(ctx, &searchEntity.ProjectDomainEvent{
|
||||
OpType: searchEntity.Updated,
|
||||
Project: &searchEntity.ProjectDocument{
|
||||
ID: req.ResourceID,
|
||||
SpaceID: req.SpaceID,
|
||||
Type: intelligence.IntelligenceType_Bot,
|
||||
IsRecentlyOpen: ptr.Of(1),
|
||||
RecentlyOpenMS: ptr.Of(time.Now().UnixMilli()),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
logs.CtxWarnf(ctx, "publish updated project event failed id=%v, err=%v", req.ResourceID, err)
|
||||
}
|
||||
|
||||
return &playground.ReportUserBehaviorResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *SingleAgentApplicationService) GetAgentOnlineInfo(ctx context.Context, req *playground.GetBotOnlineInfoReq) (*bot_common.OpenAPIBotInfo, error) {
|
||||
|
||||
uid := ctxutil.MustGetUIDFromApiAuthCtx(ctx)
|
||||
|
||||
connectorID, err := conv.StrToInt64(ptr.From(req.ConnectorID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if connectorID == 0 {
|
||||
connectorID = ctxutil.GetApiAuthFromCtx(ctx).ConnectorID
|
||||
}
|
||||
agentInfo, err := s.DomainSVC.ObtainAgentByIdentity(ctx, &entity.AgentIdentity{
|
||||
AgentID: req.BotID,
|
||||
ConnectorID: connectorID,
|
||||
Version: ptr.From(req.Version),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if agentInfo == nil {
|
||||
logs.CtxErrorf(ctx, "agent(%d) is not exist", req.BotID)
|
||||
return nil, errorx.New(errno.ErrAgentPermissionCode, errorx.KV("msg", "agent not exist"))
|
||||
}
|
||||
if agentInfo.CreatorID != uid {
|
||||
return nil, errorx.New(errno.ErrPromptPermissionCode, errorx.KV("msg", "agent not own"))
|
||||
}
|
||||
combineInfo := &bot_common.OpenAPIBotInfo{
|
||||
BotID: agentInfo.AgentID,
|
||||
Name: agentInfo.Name,
|
||||
Description: agentInfo.Desc,
|
||||
IconURL: agentInfo.IconURI,
|
||||
Version: agentInfo.Version,
|
||||
BotMode: bot_common.BotMode_SingleMode,
|
||||
PromptInfo: agentInfo.Prompt,
|
||||
OnboardingInfo: agentInfo.OnboardingInfo,
|
||||
ModelInfo: agentInfo.ModelInfo,
|
||||
WorkflowInfoList: agentInfo.Workflow,
|
||||
PluginInfoList: agentInfo.Plugin,
|
||||
}
|
||||
|
||||
if agentInfo.IconURI != "" {
|
||||
url, err := s.appContext.TosClient.GetObjectUrl(ctx, agentInfo.IconURI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
combineInfo.IconURL = url
|
||||
}
|
||||
|
||||
if len(agentInfo.ShortcutCommand) > 0 {
|
||||
shortcutInfos, err := s.ShortcutCMDSVC.ListCMD(ctx, &shortcutEntity.ListMeta{
|
||||
ObjectID: agentInfo.AgentID,
|
||||
IsOnline: 1,
|
||||
CommandIDs: slices.Transform(agentInfo.ShortcutCommand, func(s string) int64 {
|
||||
i, _ := conv.StrToInt64(s)
|
||||
return i
|
||||
}),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
combineInfo.ShortcutCommands = make([]*bot_common.ShortcutCommandInfo, 0, len(shortcutInfos))
|
||||
combineInfo.ShortcutCommands = slices.Transform(shortcutInfos, func(si *shortcutEntity.ShortcutCmd) *bot_common.ShortcutCommandInfo {
|
||||
url := ""
|
||||
if si.ShortcutIcon != nil && si.ShortcutIcon.URI != "" {
|
||||
getUrl, e := s.appContext.TosClient.GetObjectUrl(ctx, si.ShortcutIcon.URI)
|
||||
if e == nil {
|
||||
url = getUrl
|
||||
}
|
||||
}
|
||||
|
||||
return &bot_common.ShortcutCommandInfo{
|
||||
ID: si.CommandID,
|
||||
Name: si.CommandName,
|
||||
Description: si.Description,
|
||||
IconURL: url,
|
||||
QueryTemplate: si.TemplateQuery,
|
||||
AgentID: ptr.Of(si.ObjectID),
|
||||
Command: si.ShortcutCommand,
|
||||
Components: slices.Transform(si.Components, func(i *playground.Components) *bot_common.ShortcutCommandComponent {
|
||||
return &bot_common.ShortcutCommandComponent{
|
||||
Name: i.Name,
|
||||
Description: i.Description,
|
||||
Type: i.InputType.String(),
|
||||
ToolParameter: ptr.Of(i.Parameter),
|
||||
Options: i.Options,
|
||||
DefaultValue: ptr.Of(i.DefaultValue.Value),
|
||||
IsHide: i.Hide,
|
||||
}
|
||||
}),
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
return combineInfo, nil
|
||||
}
|
||||
43
backend/application/template/init.go
Normal file
43
backend/application/template/init.go
Normal file
@@ -0,0 +1,43 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package template
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/template/repository"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
)
|
||||
|
||||
type ServiceComponents struct {
|
||||
DB *gorm.DB
|
||||
IDGen idgen.IDGenerator
|
||||
Storage storage.Storage
|
||||
}
|
||||
|
||||
func InitService(ctx context.Context, components *ServiceComponents) *ApplicationService {
|
||||
|
||||
tRepo := repository.NewTemplateDAO(components.DB, components.IDGen)
|
||||
|
||||
ApplicationSVC.templateRepo = tRepo
|
||||
ApplicationSVC.storage = components.Storage
|
||||
|
||||
return ApplicationSVC
|
||||
}
|
||||
92
backend/application/template/template.go
Normal file
92
backend/application/template/template.go
Normal file
@@ -0,0 +1,92 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package template
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
productAPI "github.com/coze-dev/coze-studio/backend/api/model/flow/marketplace/product_public_api"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/template/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/template/repository"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
)
|
||||
|
||||
type ApplicationService struct {
|
||||
templateRepo repository.TemplateRepository
|
||||
storage storage.Storage
|
||||
}
|
||||
|
||||
var ApplicationSVC = &ApplicationService{}
|
||||
|
||||
func (t *ApplicationService) PublicGetProductList(ctx context.Context, req *productAPI.GetProductListRequest) (resp *productAPI.GetProductListResponse, err error) {
|
||||
pageSize := 50
|
||||
if req.PageSize > 0 {
|
||||
pageSize = int(req.PageSize)
|
||||
}
|
||||
pagination := &entity.Pagination{
|
||||
Limit: pageSize,
|
||||
Offset: int(req.PageNum) * pageSize,
|
||||
}
|
||||
|
||||
listResp, allNum, err := t.templateRepo.List(ctx, &entity.TemplateFilter{SpaceID: ptr.Of(int64(consts.TemplateSpaceID))}, pagination, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
products := make([]*productAPI.ProductInfo, 0, len(listResp))
|
||||
for _, item := range listResp {
|
||||
meta := item.MetaInfo
|
||||
for _, cover := range meta.Covers {
|
||||
objURL, uRrr := t.storage.GetObjectUrl(ctx, cover.URI)
|
||||
if uRrr == nil {
|
||||
cover.URL = objURL
|
||||
}
|
||||
}
|
||||
|
||||
avatarURL, uRrr := t.storage.GetObjectUrl(ctx, "default_icon/connector-coze.png")
|
||||
if uRrr == nil {
|
||||
if meta.Seller != nil {
|
||||
meta.Seller.AvatarURL = avatarURL
|
||||
}
|
||||
if meta.UserInfo != nil {
|
||||
meta.UserInfo.AvatarURL = avatarURL
|
||||
}
|
||||
}
|
||||
|
||||
products = append(products, &productAPI.ProductInfo{
|
||||
MetaInfo: item.MetaInfo,
|
||||
BotExtra: item.AgentExtra,
|
||||
WorkflowExtra: item.WorkflowExtra,
|
||||
ProjectExtra: item.ProjectExtra,
|
||||
})
|
||||
}
|
||||
hasMore := false
|
||||
if int64(int(req.PageNum)*pageSize) < allNum {
|
||||
hasMore = true
|
||||
}
|
||||
resp = &productAPI.GetProductListResponse{
|
||||
Data: &productAPI.GetProductListData{
|
||||
Products: products,
|
||||
HasMore: hasMore,
|
||||
Total: int32(allNum),
|
||||
},
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
25
backend/application/upload/consts.go
Normal file
25
backend/application/upload/consts.go
Normal file
@@ -0,0 +1,25 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package upload
|
||||
|
||||
const maxFileSize = 200 * 1024 * 1024
|
||||
const (
|
||||
TextKnowledgeDefaultIcon = "default_icon/text_kn_default_icon.png"
|
||||
TableKnowledgeDefaultIcon = "default_icon/table_kn_default_icon.png"
|
||||
ImageKnowledgeDefaultIcon = "default_icon/image_kn_default_icon.png"
|
||||
DatabaseDefaultIcon = "default_icon/default_database_icon.png"
|
||||
)
|
||||
708
backend/application/upload/icon.go
Normal file
708
backend/application/upload/icon.go
Normal file
@@ -0,0 +1,708 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package upload
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"image"
|
||||
_ "image/gif"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"io"
|
||||
"math"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"path"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "golang.org/x/image/tiff"
|
||||
_ "golang.org/x/image/webp"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/flow/dataengine/dataset"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/playground"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/upload/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
func InitService(oss storage.Storage, cache cache.Cmdable) {
|
||||
SVC.cache = cache
|
||||
SVC.oss = oss
|
||||
}
|
||||
|
||||
var SVC = &UploadService{}
|
||||
|
||||
type UploadService struct {
|
||||
oss storage.Storage
|
||||
cache cache.Cmdable
|
||||
}
|
||||
|
||||
const (
|
||||
uploadKey = "UploadServiceUpload:%s"
|
||||
uploadPartKey = "UploadServiceUpload:%s/parts"
|
||||
partKey = "UploadServiceUpload/%s/part-%s"
|
||||
)
|
||||
|
||||
func (u *UploadService) PartUploadFileInit(ctx context.Context, objKey string) (uploadID string, err error) {
|
||||
uploadID = uuid.NewString()
|
||||
key := fmt.Sprintf(uploadKey, uploadID)
|
||||
err = u.cache.HSet(ctx,
|
||||
key,
|
||||
"objkey", objKey,
|
||||
).Err()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
err = u.cache.Expire(ctx, key, time.Minute*10).Err()
|
||||
return
|
||||
}
|
||||
|
||||
type PartUploadFileRequest struct {
|
||||
UploadID string
|
||||
PartNumber string
|
||||
Data []byte
|
||||
}
|
||||
|
||||
type PartUploadFileResponse struct {
|
||||
Crc32 string
|
||||
}
|
||||
|
||||
type PartUploadFileCompleteRequest struct {
|
||||
UploadID string
|
||||
ObjKey string
|
||||
Crc32Map map[string]string
|
||||
}
|
||||
|
||||
func (u *UploadService) PartUploadFile(ctx context.Context, req *PartUploadFileRequest) (resp *PartUploadFileResponse, err error) {
|
||||
key := fmt.Sprintf(uploadKey, req.UploadID)
|
||||
exists, err := u.cache.Exists(ctx, key).Result()
|
||||
if err != nil || exists == 0 {
|
||||
return nil, fmt.Errorf("upload session invalid: %v", err)
|
||||
}
|
||||
crc32Val := crc32.ChecksumIEEE(req.Data)
|
||||
partTosKey := fmt.Sprintf(partKey, req.UploadID, req.PartNumber)
|
||||
err = u.oss.PutObject(ctx, partTosKey, req.Data, storage.WithExpires(time.Now().Add(10*time.Minute)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
partMeta := map[string]interface{}{
|
||||
"tos_key": partTosKey,
|
||||
}
|
||||
partMetaData, err := sonic.Marshal(partMeta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
partKey := fmt.Sprintf(uploadPartKey, req.UploadID)
|
||||
err = u.cache.HSet(ctx, partKey, req.PartNumber, string(partMetaData)).Err()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = u.cache.Expire(ctx, partKey, time.Minute*10).Err()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &PartUploadFileResponse{
|
||||
Crc32: fmt.Sprintf("%08x", crc32Val),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type tosPart struct {
|
||||
PartNum int
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func getContentType(uri string) (contentType string) {
|
||||
_ = mime.AddExtensionType(".svg", "image/svg+xml")
|
||||
_ = mime.AddExtensionType(".svgz", "image/svg+xml")
|
||||
_ = mime.AddExtensionType(".webp", "image/webp")
|
||||
_ = mime.AddExtensionType(".ico", "image/x-icon")
|
||||
fileExtension := path.Base(uri)
|
||||
ext := path.Ext(fileExtension)
|
||||
contentType = mime.TypeByExtension(ext)
|
||||
return
|
||||
}
|
||||
|
||||
func (u *UploadService) PartUploadFileComplete(ctx context.Context, req *PartUploadFileCompleteRequest) error {
|
||||
partKey := fmt.Sprintf(uploadPartKey, req.UploadID)
|
||||
parts, err := u.cache.HGetAll(ctx, partKey).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tosParts := []*tosPart{}
|
||||
for partNumStr, partData := range parts {
|
||||
var partMeta map[string]string
|
||||
if err := sonic.Unmarshal([]byte(partData), &partMeta); err != nil {
|
||||
return fmt.Errorf("failed to parse part metadata: %v", err)
|
||||
}
|
||||
partNum, err := strconv.ParseInt(partNumStr, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
objKey, exist := partMeta["tos_key"]
|
||||
if !exist {
|
||||
return errors.New("tos key not exist")
|
||||
}
|
||||
byteData, err := u.oss.GetObject(ctx, objKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tosParts = append(tosParts, &tosPart{PartNum: int(partNum), Data: byteData})
|
||||
}
|
||||
if len(tosParts) == 0 {
|
||||
return errors.New("tos part is null")
|
||||
}
|
||||
sort.Slice(tosParts, func(i, j int) bool { return tosParts[i].PartNum < tosParts[j].PartNum })
|
||||
if tosParts[len(tosParts)-1].PartNum != len(tosParts) || len(tosParts) != len(req.Crc32Map) {
|
||||
return errors.New("check parts fail")
|
||||
}
|
||||
totalData := []byte{}
|
||||
for _, val := range tosParts {
|
||||
crc32 := fmt.Sprintf("%08x", crc32.ChecksumIEEE(val.Data))
|
||||
crc32Check := req.Crc32Map[strconv.Itoa(val.PartNum)]
|
||||
if crc32 != crc32Check {
|
||||
return errors.New("crc32 check fail")
|
||||
}
|
||||
totalData = append(totalData, val.Data...)
|
||||
}
|
||||
contentType := getContentType(req.ObjKey)
|
||||
if len(contentType) != 0 {
|
||||
err = u.oss.PutObject(ctx, req.ObjKey, totalData, storage.WithContentType(contentType))
|
||||
} else {
|
||||
err = u.oss.PutObject(ctx, req.ObjKey, totalData)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *UploadService) GetIcon(ctx context.Context, req *developer_api.GetIconRequest) (
|
||||
resp *developer_api.GetIconResponse, err error,
|
||||
) {
|
||||
iconURI := map[developer_api.IconType]string{
|
||||
developer_api.IconType_Bot: consts.DefaultAgentIcon,
|
||||
developer_api.IconType_User: consts.DefaultUserIcon,
|
||||
developer_api.IconType_Plugin: consts.DefaultPluginIcon,
|
||||
developer_api.IconType_Dataset: consts.DefaultDatasetIcon,
|
||||
developer_api.IconType_Workflow: consts.DefaultWorkflowIcon,
|
||||
developer_api.IconType_Imageflow: consts.DefaultPluginIcon,
|
||||
developer_api.IconType_Society: consts.DefaultPluginIcon,
|
||||
developer_api.IconType_Connector: consts.DefaultPluginIcon,
|
||||
developer_api.IconType_ChatFlow: consts.DefaultPluginIcon,
|
||||
developer_api.IconType_Voice: consts.DefaultPluginIcon,
|
||||
developer_api.IconType_Enterprise: consts.DefaultTeamIcon,
|
||||
}
|
||||
|
||||
uri := iconURI[req.GetIconType()]
|
||||
if uri == "" {
|
||||
return nil, errorx.New(errno.ErrUploadInvalidType,
|
||||
errorx.KV("type", conv.Int64ToStr(int64(req.GetIconType()))))
|
||||
}
|
||||
|
||||
url, err := u.oss.GetObjectUrl(ctx, iconURI[req.GetIconType()])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &developer_api.GetIconResponse{
|
||||
Data: &developer_api.GetIconResponseData{
|
||||
IconList: []*developer_api.Icon{
|
||||
{
|
||||
URL: url,
|
||||
URI: uri,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
func stringToMap(input string) map[string]string {
|
||||
result := make(map[string]string)
|
||||
|
||||
pairs := strings.Split(input, ",")
|
||||
|
||||
for _, pair := range pairs {
|
||||
parts := strings.Split(pair, ":")
|
||||
if len(parts) == 2 {
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
result[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
func (u *UploadService) UploadFileCommon(ctx context.Context, req *developer_api.CommonUploadRequest, fullPath string) (*developer_api.CommonUploadResponse, error) {
|
||||
resp := developer_api.NewCommonUploadResponse()
|
||||
re := regexp.MustCompile(`/api/playground/upload/([^?]+)`)
|
||||
match := re.FindStringSubmatch(fullPath)
|
||||
if len(match) == 0 {
|
||||
return nil, errorx.New(errno.ErrUploadInvalidParamCode, errorx.KV("msg", "tos key not found"))
|
||||
}
|
||||
objKey := match[1]
|
||||
if strings.Contains(fullPath, "?uploads") {
|
||||
uploadID, err := u.PartUploadFileInit(ctx, objKey)
|
||||
if err != nil {
|
||||
return resp, errorx.New(errno.ErrUploadSystemErrorCode, errorx.KV("msg", err.Error()))
|
||||
}
|
||||
resp.Error = &developer_api.Error{Code: 200}
|
||||
resp.Payload = &developer_api.Payload{UploadID: uploadID}
|
||||
return resp, nil
|
||||
}
|
||||
if len(ptr.From(req.PartNumber)) != 0 {
|
||||
_, err := u.PartUploadFile(ctx, &PartUploadFileRequest{
|
||||
UploadID: ptr.From(req.UploadID),
|
||||
PartNumber: ptr.From(req.PartNumber),
|
||||
Data: req.ByteData,
|
||||
})
|
||||
if err != nil {
|
||||
return resp, errorx.New(errno.ErrUploadSystemErrorCode, errorx.KV("msg", err.Error()))
|
||||
}
|
||||
resp.Error = &developer_api.Error{Code: 200}
|
||||
return resp, nil
|
||||
}
|
||||
if len(ptr.From(req.UploadID)) != 0 {
|
||||
mp := stringToMap(string(req.ByteData))
|
||||
err := u.PartUploadFileComplete(ctx, &PartUploadFileCompleteRequest{
|
||||
UploadID: ptr.From(req.UploadID),
|
||||
ObjKey: objKey,
|
||||
Crc32Map: mp,
|
||||
})
|
||||
if err != nil {
|
||||
return resp, errorx.New(errno.ErrUploadSystemErrorCode, errorx.KV("msg", err.Error()))
|
||||
}
|
||||
resp.Error = &developer_api.Error{Code: 200}
|
||||
resp.Payload = &developer_api.Payload{Key: uuid.NewString()}
|
||||
return resp, nil
|
||||
}
|
||||
var err error
|
||||
contentType := getContentType(objKey)
|
||||
if len(contentType) != 0 {
|
||||
err = u.oss.PutObject(ctx, objKey, req.ByteData, storage.WithContentType(contentType))
|
||||
} else {
|
||||
err = u.oss.PutObject(ctx, objKey, req.ByteData)
|
||||
}
|
||||
if err != nil {
|
||||
return resp, errorx.New(errno.ErrUploadSystemErrorCode, errorx.KV("msg", err.Error()))
|
||||
}
|
||||
resp.Error = &developer_api.Error{Code: 200}
|
||||
resp.Payload = &developer_api.Payload{Key: uuid.NewString()}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (u *UploadService) UploadFile(ctx context.Context, data []byte, objKey string) (*developer_api.UploadFileResponse, error) {
|
||||
err := u.oss.PutObject(ctx, objKey, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url, err := u.oss.GetObjectUrl(ctx, objKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &developer_api.UploadFileResponse{
|
||||
Data: &developer_api.UploadFileData{
|
||||
UploadURL: url,
|
||||
UploadURI: objKey,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (u *UploadService) GetShortcutIcons(ctx context.Context) ([]*playground.FileInfo, error) {
|
||||
shortcutIcons := entity.GetDefaultShortcutIconURI()
|
||||
fileList := make([]*playground.FileInfo, 0, len(shortcutIcons))
|
||||
for _, uri := range shortcutIcons {
|
||||
url, err := u.oss.GetObjectUrl(ctx, uri)
|
||||
if err == nil {
|
||||
fileList = append(fileList, &playground.FileInfo{
|
||||
URL: url,
|
||||
URI: uri,
|
||||
})
|
||||
}
|
||||
}
|
||||
return fileList, nil
|
||||
}
|
||||
|
||||
func parseMultipartFormData(ctx context.Context, req *playground.UploadFileOpenRequest) (*multipart.Form, error) {
|
||||
_, params, err := mime.ParseMediaType(req.ContentType)
|
||||
if err != nil {
|
||||
return nil, errorx.New(errno.ErrUploadInvalidContentTypeCode, errorx.KV("content-type", req.ContentType))
|
||||
}
|
||||
br := bytes.NewReader(req.Data)
|
||||
mr := multipart.NewReader(br, params["boundary"])
|
||||
|
||||
form, err := mr.ReadForm(maxFileSize)
|
||||
if errors.Is(err, multipart.ErrMessageTooLarge) {
|
||||
return nil, errorx.New(errno.ErrUploadInvalidFileSizeCode)
|
||||
} else if err != nil {
|
||||
return nil, errorx.New(errno.ErrUploadMultipartFormDataReadFailedCode)
|
||||
}
|
||||
return form, nil
|
||||
}
|
||||
|
||||
func genObjName(name string, id string) string {
|
||||
|
||||
return fmt.Sprintf("%s/%s/%s",
|
||||
"bot_files",
|
||||
id,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
func (u *UploadService) UploadFileOpen(ctx context.Context, req *playground.UploadFileOpenRequest) (*playground.UploadFileOpenResponse, error) {
|
||||
resp := playground.UploadFileOpenResponse{}
|
||||
resp.File = new(playground.File)
|
||||
uid := ctxutil.MustGetUIDFromApiAuthCtx(ctx)
|
||||
if uid == 0 {
|
||||
return nil, errorx.New(errno.ErrKnowledgePermissionCode, errorx.KV("msg", "session required"))
|
||||
}
|
||||
form, err := parseMultipartFormData(ctx, req)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "parse multipart form data failed, err: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
if len(form.File["file"]) == 0 {
|
||||
return nil, errorx.New(errno.ErrUploadEmptyFileCode)
|
||||
} else if len(form.File["file"]) > 1 {
|
||||
return nil, errorx.New(errno.ErrUploadFileUploadGreaterOneCode)
|
||||
}
|
||||
fileHeader := form.File["file"][0]
|
||||
|
||||
// open file
|
||||
file, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
return nil, errorx.New(errno.ErrUploadSystemErrorCode, errorx.KV("msg", "fileHeader open failed"))
|
||||
}
|
||||
defer file.Close()
|
||||
data, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
return nil, errorx.New(errno.ErrUploadSystemErrorCode, errorx.KV("msg", "file upload io read failed"))
|
||||
}
|
||||
resp.File.Bytes = int64(len(data))
|
||||
randID := uuid.NewString()
|
||||
objName := genObjName(fileHeader.Filename, randID)
|
||||
resp.File.FileName = fileHeader.Filename
|
||||
resp.File.URI = objName
|
||||
err = u.oss.PutObject(ctx, objName, data)
|
||||
if err != nil {
|
||||
return nil, errorx.New(errno.ErrUploadSystemErrorCode, errorx.KV("msg", "file upload to oss failed"))
|
||||
}
|
||||
url, err := u.oss.GetObjectUrl(ctx, objName)
|
||||
if err != nil {
|
||||
return nil, errorx.New(errno.ErrUploadSystemErrorCode, errorx.KV("msg", "get object url failed"))
|
||||
}
|
||||
resp.File.CreatedAt = time.Now().Unix()
|
||||
resp.File.URL = url
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (u *UploadService) GetIconForDataset(ctx context.Context, req *dataset.GetIconRequest) (*dataset.GetIconResponse, error) {
|
||||
resp := dataset.NewGetIconResponse()
|
||||
var uri string
|
||||
switch req.FormatType {
|
||||
case dataset.FormatType_Text:
|
||||
uri = TextKnowledgeDefaultIcon
|
||||
case dataset.FormatType_Table:
|
||||
uri = TableKnowledgeDefaultIcon
|
||||
case dataset.FormatType_Image:
|
||||
uri = ImageKnowledgeDefaultIcon
|
||||
case dataset.FormatType_Database:
|
||||
uri = DatabaseDefaultIcon
|
||||
default:
|
||||
uri = TextKnowledgeDefaultIcon
|
||||
}
|
||||
|
||||
iconUrl, err := u.oss.GetObjectUrl(ctx, uri)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
resp.Icon = &dataset.Icon{
|
||||
URL: iconUrl,
|
||||
URI: uri,
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (u *UploadService) UploadSessionKey(ctx context.Context, sessionKey string, tosKey string) error {
|
||||
return u.cache.Set(ctx, sessionKey, tosKey, time.Minute*30).Err()
|
||||
}
|
||||
|
||||
type GetObjInfoBySessionKey struct {
|
||||
ObjKey string
|
||||
Width int32
|
||||
Height int32
|
||||
}
|
||||
|
||||
func isImageUri(uri string) bool {
|
||||
if uri == "" {
|
||||
return false
|
||||
}
|
||||
uri = strings.ToLower(uri)
|
||||
fileExtension := path.Base(uri)
|
||||
ext := path.Ext(fileExtension)
|
||||
ext = ext[1:]
|
||||
imageExtensions := map[string]bool{
|
||||
"jpg": true,
|
||||
"jpeg": true,
|
||||
"png": true,
|
||||
"gif": true,
|
||||
"bmp": true,
|
||||
"webp": true,
|
||||
"tiff": true,
|
||||
"svg": true,
|
||||
"ico": true,
|
||||
}
|
||||
|
||||
// 检查扩展名是否在图片扩展名列表中
|
||||
return imageExtensions[ext]
|
||||
}
|
||||
func (u *UploadService) GetObjInfoBySessionKey(ctx context.Context, sessionKey string) (*GetObjInfoBySessionKey, error) {
|
||||
resp := GetObjInfoBySessionKey{}
|
||||
objKey, err := u.cache.Get(ctx, sessionKey).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.ObjKey = objKey
|
||||
if isImageUri(objKey) {
|
||||
content, err := u.oss.GetObject(ctx, objKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if isSVG(objKey) {
|
||||
width, height, err := getSVGDimensions(content)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "get svg dimensions failed, err: %v", err)
|
||||
// default val
|
||||
resp.Width = 100
|
||||
resp.Height = 100
|
||||
return &resp, nil
|
||||
}
|
||||
resp.Width = width
|
||||
resp.Height = height
|
||||
} else {
|
||||
img, _, err := image.Decode(bytes.NewReader(content))
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "decode image failed, err: %v", err)
|
||||
// default val
|
||||
resp.Width = 100
|
||||
resp.Height = 100
|
||||
return &resp, nil
|
||||
}
|
||||
resp.Width = int32(img.Bounds().Dx())
|
||||
resp.Height = int32(img.Bounds().Dy())
|
||||
}
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
type SVG struct {
|
||||
Width string `xml:"width,attr"`
|
||||
Height string `xml:"height,attr"`
|
||||
ViewBox string `xml:"viewBox,attr"`
|
||||
}
|
||||
|
||||
// 获取 SVG 尺寸
|
||||
func getSVGDimensions(content []byte) (width, height int32, err error) {
|
||||
decoder := xml.NewDecoder(bytes.NewReader(content))
|
||||
|
||||
var svg SVG
|
||||
if err := decoder.Decode(&svg); err != nil {
|
||||
return 100, 100, nil
|
||||
}
|
||||
|
||||
// 尝试从width属性获取
|
||||
if svg.Width != "" {
|
||||
w, err := parseDimension(svg.Width)
|
||||
if err == nil {
|
||||
width = w
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试从height属性获取
|
||||
if svg.Height != "" {
|
||||
h, err := parseDimension(svg.Height)
|
||||
if err == nil {
|
||||
height = h
|
||||
}
|
||||
}
|
||||
|
||||
// 如果width或height未设置,尝试从viewBox获取
|
||||
if width == 0 || height == 0 {
|
||||
if svg.ViewBox != "" {
|
||||
parts := strings.Fields(svg.ViewBox)
|
||||
if len(parts) >= 4 {
|
||||
if width == 0 {
|
||||
w, err := strconv.ParseInt(parts[2], 10, 32)
|
||||
if err == nil {
|
||||
width = int32(w)
|
||||
}
|
||||
}
|
||||
if height == 0 {
|
||||
h, err := strconv.ParseInt(parts[3], 10, 32)
|
||||
if err == nil {
|
||||
height = int32(h)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if width == 0 || height == 0 {
|
||||
return 100, 100, nil
|
||||
}
|
||||
|
||||
return width, height, nil
|
||||
}
|
||||
func parseDimension(dim string) (int32, error) {
|
||||
// 去除单位(px, pt, em, %等)和空格
|
||||
dim = strings.TrimSpace(dim)
|
||||
dim = strings.TrimRightFunc(dim, func(r rune) bool {
|
||||
return (r < '0' || r > '9') && r != '.' && r != '-' && r != '+'
|
||||
})
|
||||
|
||||
// 解析为float64
|
||||
value, err := strconv.ParseFloat(dim, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 四舍五入转换为int32
|
||||
if value > math.MaxInt32 {
|
||||
return math.MaxInt32, nil
|
||||
}
|
||||
if value < math.MinInt32 {
|
||||
return math.MinInt32, nil
|
||||
}
|
||||
return int32(math.Round(value)), nil
|
||||
}
|
||||
func isSVG(uri string) bool {
|
||||
uri = strings.ToLower(uri)
|
||||
fileExtension := path.Base(uri)
|
||||
ext := path.Ext(fileExtension)
|
||||
ext = ext[1:]
|
||||
return ext == "svg"
|
||||
}
|
||||
func (u *UploadService) ApplyImageUpload(ctx context.Context, req *developer_api.ApplyUploadActionRequest, host string) (*developer_api.ApplyUploadActionResponse, error) {
|
||||
resp := developer_api.ApplyUploadActionResponse{}
|
||||
storeUri := "tos-cn-i-v4nquku3lp/" + uuid.NewString() + ptr.From(req.FileExtension)
|
||||
sessionKey := uuid.NewString()
|
||||
auth := uuid.NewString()
|
||||
uploadID := uuid.NewString()
|
||||
uploadHost := string(host) + consts.UploadURI
|
||||
resp.ResponseMetadata = &developer_api.ResponseMetadata{
|
||||
RequestId: uuid.NewString(),
|
||||
Action: "ApplyImageUpload",
|
||||
Version: "",
|
||||
Service: "",
|
||||
Region: "",
|
||||
}
|
||||
resp.Result = &developer_api.ApplyUploadActionResult{
|
||||
UploadAddress: &developer_api.UploadAddress{
|
||||
StoreInfos: []*developer_api.StoreInfo{
|
||||
{
|
||||
StoreUri: storeUri,
|
||||
Auth: auth,
|
||||
UploadID: uploadID,
|
||||
},
|
||||
},
|
||||
UploadHosts: []string{uploadHost},
|
||||
SessionKey: sessionKey,
|
||||
},
|
||||
InnerUploadAddress: &developer_api.InnerUploadAddress{
|
||||
UploadNodes: []*developer_api.UploadNode{
|
||||
{
|
||||
StoreInfos: []*developer_api.StoreInfo{
|
||||
{
|
||||
StoreUri: storeUri,
|
||||
Auth: auth,
|
||||
UploadID: uploadID,
|
||||
},
|
||||
},
|
||||
UploadHost: uploadHost,
|
||||
SessionKey: sessionKey,
|
||||
},
|
||||
},
|
||||
},
|
||||
RequestId: ptr.Of(uuid.NewString()),
|
||||
}
|
||||
err := u.UploadSessionKey(ctx, sessionKey, storeUri)
|
||||
if err != nil {
|
||||
return &resp, errorx.New(errno.ErrUploadSystemErrorCode, errorx.KV("msg", err.Error()))
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (u *UploadService) CommitImageUpload(ctx context.Context, req *developer_api.ApplyUploadActionRequest, host string) (*developer_api.ApplyUploadActionResponse, error) {
|
||||
resp := developer_api.ApplyUploadActionResponse{}
|
||||
type ssKey struct {
|
||||
SessionKey string `json:"SessionKey"`
|
||||
}
|
||||
sskey := ssKey{}
|
||||
err := sonic.Unmarshal(req.ByteData, &sskey)
|
||||
if err != nil {
|
||||
return &resp, errorx.New(errno.ErrUploadSystemErrorCode, errorx.KV("msg", err.Error()))
|
||||
}
|
||||
objInfo, err := u.GetObjInfoBySessionKey(ctx, sskey.SessionKey)
|
||||
if err != nil {
|
||||
return &resp, errorx.New(errno.ErrUploadSystemErrorCode, errorx.KV("msg", err.Error()))
|
||||
}
|
||||
resp.ResponseMetadata = &developer_api.ResponseMetadata{
|
||||
RequestId: uuid.NewString(),
|
||||
Action: "ApplyImageUpload",
|
||||
Version: "",
|
||||
Service: "",
|
||||
Region: "",
|
||||
}
|
||||
resp.Result = &developer_api.ApplyUploadActionResult{
|
||||
Results: []*developer_api.UploadResult{
|
||||
{
|
||||
Uri: objInfo.ObjKey,
|
||||
UriStatus: 2000,
|
||||
},
|
||||
},
|
||||
RequestId: ptr.Of(uuid.NewString()),
|
||||
PluginResult: []*developer_api.PluginResult{
|
||||
{
|
||||
FileName: objInfo.ObjKey,
|
||||
SourceUri: objInfo.ObjKey,
|
||||
ImageUri: objInfo.ObjKey,
|
||||
ImageWidth: objInfo.Width,
|
||||
ImageHeight: objInfo.Height,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
40
backend/application/user/init.go
Normal file
40
backend/application/user/init.go
Normal file
@@ -0,0 +1,40 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/user/repository"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/user/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/idgen"
|
||||
)
|
||||
|
||||
func InitService(ctx context.Context, db *gorm.DB, oss storage.Storage, idgen idgen.IDGenerator) *UserApplicationService {
|
||||
UserApplicationSVC.DomainSVC = service.NewUserDomain(ctx, &service.Components{
|
||||
IconOSS: oss,
|
||||
IDGen: idgen,
|
||||
UserRepo: repository.NewUserRepo(db),
|
||||
SpaceRepo: repository.NewSpaceRepo(db),
|
||||
})
|
||||
UserApplicationSVC.oss = oss
|
||||
|
||||
return UserApplicationSVC
|
||||
}
|
||||
318
backend/application/user/user.go
Normal file
318
backend/application/user/user.go
Normal file
@@ -0,0 +1,318 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/mail"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/playground"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/passport"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/user/entity"
|
||||
user "github.com/coze-dev/coze-studio/backend/domain/user/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
var UserApplicationSVC = &UserApplicationService{}
|
||||
|
||||
type UserApplicationService struct {
|
||||
oss storage.Storage
|
||||
DomainSVC user.User
|
||||
}
|
||||
|
||||
// 添加一个简单的 email 验证函数
|
||||
func isValidEmail(email string) bool {
|
||||
// 如果 email 字符串格式不正确,它会返回一个 error
|
||||
_, err := mail.ParseAddress(email)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func (u *UserApplicationService) PassportWebEmailRegisterV2(ctx context.Context, locale string, req *passport.PassportWebEmailRegisterV2PostRequest) (
|
||||
resp *passport.PassportWebEmailRegisterV2PostResponse, sessionKey string, err error,
|
||||
) {
|
||||
// 验证 email 格式是否合法
|
||||
if !isValidEmail(req.GetEmail()) {
|
||||
return nil, "", errorx.New(errno.ErrUserInvalidParamCode, errorx.KV("msg", "Invalid email"))
|
||||
}
|
||||
|
||||
userInfo, err := u.DomainSVC.Create(ctx, &user.CreateUserRequest{
|
||||
Email: req.GetEmail(),
|
||||
Password: req.GetPassword(),
|
||||
|
||||
Locale: locale,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
userInfo, err = u.DomainSVC.Login(ctx, req.GetEmail(), req.GetPassword())
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return &passport.PassportWebEmailRegisterV2PostResponse{
|
||||
Data: userDo2PassportTo(userInfo),
|
||||
Code: 0,
|
||||
}, userInfo.SessionKey, nil
|
||||
}
|
||||
|
||||
// PassportWebLogoutGet 处理用户登出请求
|
||||
func (u *UserApplicationService) PassportWebLogoutGet(ctx context.Context, req *passport.PassportWebLogoutGetRequest) (
|
||||
resp *passport.PassportWebLogoutGetResponse, err error,
|
||||
) {
|
||||
uid := ctxutil.MustGetUIDFromCtx(ctx)
|
||||
|
||||
err = u.DomainSVC.Logout(ctx, uid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &passport.PassportWebLogoutGetResponse{
|
||||
Code: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// PassportWebEmailLoginPost 处理用户邮箱登录请求
|
||||
func (u *UserApplicationService) PassportWebEmailLoginPost(ctx context.Context, req *passport.PassportWebEmailLoginPostRequest) (
|
||||
resp *passport.PassportWebEmailLoginPostResponse, sessionKey string, err error,
|
||||
) {
|
||||
userInfo, err := u.DomainSVC.Login(ctx, req.GetEmail(), req.GetPassword())
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return &passport.PassportWebEmailLoginPostResponse{
|
||||
Data: userDo2PassportTo(userInfo),
|
||||
Code: 0,
|
||||
}, userInfo.SessionKey, nil
|
||||
}
|
||||
|
||||
func (u *UserApplicationService) PassportWebEmailPasswordResetGet(ctx context.Context, req *passport.PassportWebEmailPasswordResetGetRequest) (
|
||||
resp *passport.PassportWebEmailPasswordResetGetResponse, err error,
|
||||
) {
|
||||
err = u.DomainSVC.ResetPassword(ctx, req.GetEmail(), req.GetPassword())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &passport.PassportWebEmailPasswordResetGetResponse{
|
||||
Code: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (u *UserApplicationService) PassportAccountInfoV2(ctx context.Context, req *passport.PassportAccountInfoV2Request) (
|
||||
resp *passport.PassportAccountInfoV2Response, err error,
|
||||
) {
|
||||
userID := ctxutil.MustGetUIDFromCtx(ctx)
|
||||
|
||||
userInfo, err := u.DomainSVC.GetUserInfo(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &passport.PassportAccountInfoV2Response{
|
||||
Data: userDo2PassportTo(userInfo),
|
||||
Code: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UserUpdateAvatar 更新用户头像
|
||||
func (u *UserApplicationService) UserUpdateAvatar(ctx context.Context, mimeType string, req *passport.UserUpdateAvatarRequest) (
|
||||
resp *passport.UserUpdateAvatarResponse, err error,
|
||||
) {
|
||||
// 根据 MIME type 获取文件后缀
|
||||
var ext string
|
||||
switch mimeType {
|
||||
case "image/jpeg", "image/jpg":
|
||||
ext = "jpg"
|
||||
case "image/png":
|
||||
ext = "png"
|
||||
case "image/gif":
|
||||
ext = "gif"
|
||||
case "image/webp":
|
||||
ext = "webp"
|
||||
default:
|
||||
return nil, errorx.WrapByCode(err, errno.ErrUserInvalidParamCode,
|
||||
errorx.KV("msg", "unsupported image type"))
|
||||
}
|
||||
|
||||
uid := ctxutil.MustGetUIDFromCtx(ctx)
|
||||
|
||||
url, err := u.DomainSVC.UpdateAvatar(ctx, uid, ext, req.GetAvatar())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &passport.UserUpdateAvatarResponse{
|
||||
Data: &passport.UserUpdateAvatarResponseData{
|
||||
WebURI: url,
|
||||
},
|
||||
Code: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UserUpdateProfile 更新用户资料
|
||||
func (u *UserApplicationService) UserUpdateProfile(ctx context.Context, req *passport.UserUpdateProfileRequest) (
|
||||
resp *passport.UserUpdateProfileResponse, err error,
|
||||
) {
|
||||
userID := ctxutil.MustGetUIDFromCtx(ctx)
|
||||
|
||||
err = u.DomainSVC.UpdateProfile(ctx, &user.UpdateProfileRequest{
|
||||
UserID: userID,
|
||||
Name: req.Name,
|
||||
UniqueName: req.UserUniqueName,
|
||||
Description: req.Description,
|
||||
Locale: req.Locale,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &passport.UserUpdateProfileResponse{
|
||||
Code: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (u *UserApplicationService) GetSpaceListV2(ctx context.Context, req *playground.GetSpaceListV2Request) (
|
||||
resp *playground.GetSpaceListV2Response, err error,
|
||||
) {
|
||||
uid := ctxutil.MustGetUIDFromCtx(ctx)
|
||||
|
||||
spaces, err := u.DomainSVC.GetUserSpaceList(ctx, uid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
botSpaces := slices.Transform(spaces, func(space *entity.Space) *playground.BotSpaceV2 {
|
||||
return &playground.BotSpaceV2{
|
||||
ID: space.ID,
|
||||
Name: space.Name,
|
||||
Description: space.Description,
|
||||
SpaceType: playground.SpaceType(space.SpaceType),
|
||||
IconURL: space.IconURL,
|
||||
}
|
||||
})
|
||||
|
||||
return &playground.GetSpaceListV2Response{
|
||||
Data: &playground.SpaceInfo{
|
||||
BotSpaceList: botSpaces,
|
||||
HasPersonalSpace: true,
|
||||
TeamSpaceNum: 0,
|
||||
RecentlyUsedSpaceList: botSpaces,
|
||||
Total: ptr.Of(int32(len(botSpaces))),
|
||||
HasMore: ptr.Of(false),
|
||||
},
|
||||
Code: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (u *UserApplicationService) MGetUserBasicInfo(ctx context.Context, req *playground.MGetUserBasicInfoRequest) (
|
||||
resp *playground.MGetUserBasicInfoResponse, err error,
|
||||
) {
|
||||
userIDs, err := slices.TransformWithErrorCheck(req.GetUserIds(), func(s string) (int64, error) {
|
||||
return strconv.ParseInt(s, 10, 64)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errorx.WrapByCode(err, errno.ErrUserInvalidParamCode, errorx.KV("msg", "invalid user id"))
|
||||
}
|
||||
|
||||
userInfos, err := u.DomainSVC.MGetUserProfiles(ctx, userIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &playground.MGetUserBasicInfoResponse{
|
||||
UserBasicInfoMap: slices.ToMap(userInfos, func(userInfo *entity.User) (string, *playground.UserBasicInfo) {
|
||||
return strconv.FormatInt(userInfo.UserID, 10), userDo2PlaygroundTo(userInfo)
|
||||
}),
|
||||
Code: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (u *UserApplicationService) UpdateUserProfileCheck(ctx context.Context, req *developer_api.UpdateUserProfileCheckRequest) (resp *developer_api.UpdateUserProfileCheckResponse, err error) {
|
||||
if req.GetUserUniqueName() == "" {
|
||||
return &developer_api.UpdateUserProfileCheckResponse{
|
||||
Code: 0,
|
||||
Msg: "no content to update",
|
||||
}, nil
|
||||
}
|
||||
|
||||
validateResp, err := u.DomainSVC.ValidateProfileUpdate(ctx, &user.ValidateProfileUpdateRequest{
|
||||
UniqueName: req.UserUniqueName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &developer_api.UpdateUserProfileCheckResponse{
|
||||
Code: int64(validateResp.Code),
|
||||
Msg: validateResp.Msg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (u *UserApplicationService) ValidateSession(ctx context.Context, sessionKey string) (*entity.Session, error) {
|
||||
session, exist, err := u.DomainSVC.ValidateSession(ctx, sessionKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !exist {
|
||||
return nil, errorx.New(errno.ErrUserAuthenticationFailed, errorx.KV("reason", "session not exist"))
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func userDo2PassportTo(userDo *entity.User) *passport.User {
|
||||
var locale *string
|
||||
if userDo.Locale != "" {
|
||||
locale = ptr.Of(userDo.Locale)
|
||||
}
|
||||
|
||||
return &passport.User{
|
||||
UserIDStr: userDo.UserID,
|
||||
Name: userDo.Name,
|
||||
ScreenName: ptr.Of(userDo.Name),
|
||||
UserUniqueName: userDo.UniqueName,
|
||||
Email: userDo.Email,
|
||||
Description: userDo.Description,
|
||||
AvatarURL: userDo.IconURL,
|
||||
AppUserInfo: &passport.AppUserInfo{
|
||||
UserUniqueName: userDo.UniqueName,
|
||||
},
|
||||
Locale: locale,
|
||||
|
||||
UserCreateTime: userDo.CreatedAt / 1000,
|
||||
}
|
||||
}
|
||||
|
||||
func userDo2PlaygroundTo(userDo *entity.User) *playground.UserBasicInfo {
|
||||
return &playground.UserBasicInfo{
|
||||
UserId: userDo.UserID,
|
||||
Username: userDo.Name,
|
||||
UserUniqueName: ptr.Of(userDo.UniqueName),
|
||||
UserAvatar: userDo.IconURL,
|
||||
CreateTime: ptr.Of(userDo.CreatedAt / 1000),
|
||||
}
|
||||
}
|
||||
87
backend/application/workflow/init.go
Normal file
87
backend/application/workflow/init.go
Normal file
@@ -0,0 +1,87 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
|
||||
wfdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/database"
|
||||
wfknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/knowledge"
|
||||
wfmodel "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/model"
|
||||
wfplugin "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/plugin"
|
||||
wfsearch "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/search"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/workflow/variable"
|
||||
knowledge "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
|
||||
dbservice "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||
variables "github.com/coze-dev/coze-studio/backend/domain/memory/variables/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/modelmgr"
|
||||
plugin "github.com/coze-dev/coze-studio/backend/domain/plugin/service"
|
||||
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
crosscode "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
crossplugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
|
||||
crosssearch "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/search"
|
||||
crossvariable "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner"
|
||||
)
|
||||
|
||||
type ServiceComponents struct {
|
||||
IDGen idgen.IDGenerator
|
||||
DB *gorm.DB
|
||||
Cache *redis.Client
|
||||
DatabaseDomainSVC dbservice.Database
|
||||
VariablesDomainSVC variables.Variables
|
||||
PluginDomainSVC plugin.PluginService
|
||||
KnowledgeDomainSVC knowledge.Knowledge
|
||||
ModelManager modelmgr.Manager
|
||||
DomainNotifier search.ResourceEventBus
|
||||
Tos storage.Storage
|
||||
ImageX imagex.ImageX
|
||||
CPStore compose.CheckPointStore
|
||||
}
|
||||
|
||||
func InitService(components *ServiceComponents) *ApplicationService {
|
||||
workflowRepo := service.NewWorkflowRepository(components.IDGen, components.DB, components.Cache,
|
||||
components.Tos, components.CPStore)
|
||||
workflow.SetRepository(workflowRepo)
|
||||
|
||||
workflowDomainSVC := service.NewWorkflowService(workflowRepo)
|
||||
crossdatabase.SetDatabaseOperator(wfdatabase.NewDatabaseRepository(components.DatabaseDomainSVC))
|
||||
crossvariable.SetVariableHandler(variable.NewVariableHandler(components.VariablesDomainSVC))
|
||||
crossvariable.SetVariablesMetaGetter(variable.NewVariablesMetaGetter(components.VariablesDomainSVC))
|
||||
crossplugin.SetPluginService(wfplugin.NewPluginService(components.PluginDomainSVC, components.Tos))
|
||||
crossknowledge.SetKnowledgeOperator(wfknowledge.NewKnowledgeRepository(components.KnowledgeDomainSVC, components.IDGen))
|
||||
crossmodel.SetManager(wfmodel.NewModelManager(components.ModelManager, nil))
|
||||
crosscode.SetCodeRunner(coderunner.NewRunner())
|
||||
crosssearch.SetNotifier(wfsearch.NewNotify(components.DomainNotifier))
|
||||
|
||||
SVC.DomainSVC = workflowDomainSVC
|
||||
SVC.ImageX = components.ImageX
|
||||
SVC.TosClient = components.Tos
|
||||
SVC.IDGenerator = components.IDGen
|
||||
|
||||
return SVC
|
||||
}
|
||||
3815
backend/application/workflow/workflow.go
Normal file
3815
backend/application/workflow/workflow.go
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user