feat: manually mirror opencoze's code from bytedance

Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
fanlv
2025-07-20 17:36:12 +08:00
commit 890153324f
14811 changed files with 1923430 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,56 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package app
import (
resourceCommon "github.com/coze-dev/coze-studio/backend/api/model/resource/common"
"github.com/coze-dev/coze-studio/backend/domain/app/entity"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
func toResourceType(resType resourceCommon.ResType) (entity.ResourceType, error) {
switch resType {
case resourceCommon.ResType_Plugin:
return entity.ResourceTypeOfPlugin, nil
case resourceCommon.ResType_Workflow:
return entity.ResourceTypeOfWorkflow, nil
case resourceCommon.ResType_Knowledge:
return entity.ResourceTypeOfKnowledge, nil
case resourceCommon.ResType_Database:
return entity.ResourceTypeOfDatabase, nil
default:
return "", errorx.New(errno.ErrAppInvalidParamCode,
errorx.KVf(errno.APPMsgKey, "unsupported resource type '%s'", resType))
}
}
func toThriftResourceType(resType entity.ResourceType) (resourceCommon.ResType, error) {
switch resType {
case entity.ResourceTypeOfPlugin:
return resourceCommon.ResType_Plugin, nil
case entity.ResourceTypeOfWorkflow:
return resourceCommon.ResType_Workflow, nil
case entity.ResourceTypeOfKnowledge:
return resourceCommon.ResType_Knowledge, nil
case entity.ResourceTypeOfDatabase:
return resourceCommon.ResType_Database, nil
default:
return 0, errorx.New(errno.ErrAppInvalidParamCode,
errorx.KVf(errno.APPMsgKey, "unsupported resource type '%s'", resType))
}
}

View File

@@ -0,0 +1,71 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package app
import (
redisV9 "github.com/redis/go-redis/v9"
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/domain/app/repository"
"github.com/coze-dev/coze-studio/backend/domain/app/service"
connector "github.com/coze-dev/coze-studio/backend/domain/connector/service"
variables "github.com/coze-dev/coze-studio/backend/domain/memory/variables/service"
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
user "github.com/coze-dev/coze-studio/backend/domain/user/service"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
)
type ServiceComponents struct {
IDGen idgen.IDGenerator
DB *gorm.DB
OSS storage.Storage
CacheCli *redisV9.Client
ProjectEventBus search.ProjectEventBus
UserSVC user.User
ConnectorSVC connector.Connector
VariablesSVC variables.Variables
}
func InitService(components *ServiceComponents) (*APPApplicationService, error) {
appRepo := repository.NewAPPRepo(&repository.APPRepoComponents{
IDGen: components.IDGen,
DB: components.DB,
CacheCli: components.CacheCli,
})
domainComponents := &service.Components{
IDGen: components.IDGen,
DB: components.DB,
APPRepo: appRepo,
}
domainSVC := service.NewService(domainComponents)
APPApplicationSVC.DomainSVC = domainSVC
APPApplicationSVC.appRepo = appRepo
APPApplicationSVC.oss = components.OSS
APPApplicationSVC.projectEventBus = components.ProjectEventBus
APPApplicationSVC.userSVC = components.UserSVC
APPApplicationSVC.connectorSVC = components.ConnectorSVC
APPApplicationSVC.variablesSVC = components.VariablesSVC
return APPApplicationSVC, nil
}

View File

@@ -0,0 +1,32 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package app
import (
"github.com/coze-dev/coze-studio/backend/api/model/resource/common"
)
type copyMetaInfo struct {
scene common.ResourceCopyScene
userID int64
appSpaceID int64
copyTaskID string
fromAppID int64
toAppID *int64
}

View File

@@ -0,0 +1,362 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package application
import (
"context"
"fmt"
"github.com/coze-dev/coze-studio/backend/application/openauth"
"github.com/coze-dev/coze-studio/backend/application/template"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crosssearch"
"github.com/coze-dev/coze-studio/backend/application/app"
"github.com/coze-dev/coze-studio/backend/application/base/appinfra"
"github.com/coze-dev/coze-studio/backend/application/connector"
"github.com/coze-dev/coze-studio/backend/application/conversation"
"github.com/coze-dev/coze-studio/backend/application/knowledge"
"github.com/coze-dev/coze-studio/backend/application/memory"
"github.com/coze-dev/coze-studio/backend/application/modelmgr"
"github.com/coze-dev/coze-studio/backend/application/plugin"
"github.com/coze-dev/coze-studio/backend/application/prompt"
"github.com/coze-dev/coze-studio/backend/application/search"
"github.com/coze-dev/coze-studio/backend/application/shortcutcmd"
"github.com/coze-dev/coze-studio/backend/application/singleagent"
"github.com/coze-dev/coze-studio/backend/application/upload"
"github.com/coze-dev/coze-studio/backend/application/user"
"github.com/coze-dev/coze-studio/backend/application/workflow"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossagent"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossagentrun"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossconnector"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossconversation"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatabase"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatacopy"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossknowledge"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossmessage"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossmodelmgr"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossuser"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossvariables"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
agentrunImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/agentrun"
connectorImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/connector"
conversationImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/conversation"
crossuserImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/crossuser"
databaseImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/database"
dataCopyImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/datacopy"
knowledgeImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/knowledge"
messageImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/message"
modelmgrImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/modelmgr"
pluginImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/plugin"
searchImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/search"
singleagentImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/singleagent"
variablesImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/variables"
workflowImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/workflow"
"github.com/coze-dev/coze-studio/backend/infra/impl/checkpoint"
)
type eventbusImpl struct {
resourceEventBus search.ResourceEventBus
projectEventBus search.ProjectEventBus
}
type basicServices struct {
infra *appinfra.AppDependencies
eventbus *eventbusImpl
modelMgrSVC *modelmgr.ModelmgrApplicationService
connectorSVC *connector.ConnectorApplicationService
userSVC *user.UserApplicationService
promptSVC *prompt.PromptApplicationService
templateSVC *template.ApplicationService
openAuthSVC *openauth.OpenAuthApplicationService
}
type primaryServices struct {
basicServices *basicServices
infra *appinfra.AppDependencies
pluginSVC *plugin.PluginApplicationService
memorySVC *memory.MemoryApplicationServices
knowledgeSVC *knowledge.KnowledgeApplicationService
workflowSVC *workflow.ApplicationService
shortcutSVC *shortcutcmd.ShortcutCmdApplicationService
}
type complexServices struct {
primaryServices *primaryServices
singleAgentSVC *singleagent.SingleAgentApplicationService
appSVC *app.APPApplicationService
searchSVC *search.SearchApplicationService
conversationSVC *conversation.ConversationApplicationService
}
func Init(ctx context.Context) (err error) {
infra, err := appinfra.Init(ctx)
if err != nil {
return err
}
eventbus := initEventBus(infra)
basicServices, err := initBasicServices(ctx, infra, eventbus)
if err != nil {
return fmt.Errorf("Init - initBasicServices failed, err: %v", err)
}
primaryServices, err := initPrimaryServices(ctx, basicServices)
if err != nil {
return fmt.Errorf("Init - initPrimaryServices failed, err: %v", err)
}
complexServices, err := initComplexServices(ctx, primaryServices)
if err != nil {
return fmt.Errorf("Init - initVitalServices failed, err: %v", err)
}
crossconnector.SetDefaultSVC(connectorImpl.InitDomainService(basicServices.connectorSVC.DomainSVC))
crossdatabase.SetDefaultSVC(databaseImpl.InitDomainService(primaryServices.memorySVC.DatabaseDomainSVC))
crossknowledge.SetDefaultSVC(knowledgeImpl.InitDomainService(primaryServices.knowledgeSVC.DomainSVC))
crossmodelmgr.SetDefaultSVC(modelmgrImpl.InitDomainService(basicServices.modelMgrSVC.DomainSVC))
crossplugin.SetDefaultSVC(pluginImpl.InitDomainService(primaryServices.pluginSVC.DomainSVC))
crossvariables.SetDefaultSVC(variablesImpl.InitDomainService(primaryServices.memorySVC.VariablesDomainSVC))
crossworkflow.SetDefaultSVC(workflowImpl.InitDomainService(primaryServices.workflowSVC.DomainSVC))
crossconversation.SetDefaultSVC(conversationImpl.InitDomainService(complexServices.conversationSVC.ConversationDomainSVC))
crossmessage.SetDefaultSVC(messageImpl.InitDomainService(complexServices.conversationSVC.MessageDomainSVC))
crossagentrun.SetDefaultSVC(agentrunImpl.InitDomainService(complexServices.conversationSVC.AgentRunDomainSVC))
crossagent.SetDefaultSVC(singleagentImpl.InitDomainService(complexServices.singleAgentSVC.DomainSVC))
crossuser.SetDefaultSVC(crossuserImpl.InitDomainService(basicServices.userSVC.DomainSVC))
crossdatacopy.SetDefaultSVC(dataCopyImpl.InitDomainService(basicServices.infra))
crosssearch.SetDefaultSVC(searchImpl.InitDomainService(complexServices.searchSVC.DomainSVC))
return nil
}
func initEventBus(infra *appinfra.AppDependencies) *eventbusImpl {
e := &eventbusImpl{}
e.resourceEventBus = search.NewResourceEventBus(infra.ResourceEventProducer)
e.projectEventBus = search.NewProjectEventBus(infra.AppEventProducer)
return e
}
// initBasicServices init basic services that only depends on infra.
func initBasicServices(ctx context.Context, infra *appinfra.AppDependencies, e *eventbusImpl) (*basicServices, error) {
upload.InitService(infra.TOSClient, infra.CacheCli)
openAuthSVC := openauth.InitService(infra.DB, infra.IDGenSVC)
promptSVC := prompt.InitService(infra.DB, infra.IDGenSVC, e.resourceEventBus)
modelMgrSVC, err := modelmgr.InitService(infra.DB, infra.IDGenSVC, infra.TOSClient)
if err != nil {
return nil, err
}
connectorSVC := connector.InitService(infra.TOSClient)
userSVC := user.InitService(ctx, infra.DB, infra.TOSClient, infra.IDGenSVC)
templateSVC := template.InitService(ctx, &template.ServiceComponents{
DB: infra.DB,
IDGen: infra.IDGenSVC,
Storage: infra.TOSClient,
})
return &basicServices{
infra: infra,
eventbus: e,
modelMgrSVC: modelMgrSVC,
connectorSVC: connectorSVC,
userSVC: userSVC,
promptSVC: promptSVC,
templateSVC: templateSVC,
openAuthSVC: openAuthSVC,
}, nil
}
// initPrimaryServices init primary services that depends on basic services.
func initPrimaryServices(ctx context.Context, basicServices *basicServices) (*primaryServices, error) {
pluginSVC, err := plugin.InitService(ctx, basicServices.toPluginServiceComponents())
if err != nil {
return nil, err
}
memorySVC := memory.InitService(basicServices.toMemoryServiceComponents())
knowledgeSVC, err := knowledge.InitService(basicServices.toKnowledgeServiceComponents(memorySVC))
if err != nil {
return nil, err
}
workflowDomainSVC := workflow.InitService(
basicServices.toWorkflowServiceComponents(pluginSVC, memorySVC, knowledgeSVC))
shortcutSVC := shortcutcmd.InitService(basicServices.infra.DB, basicServices.infra.IDGenSVC)
return &primaryServices{
basicServices: basicServices,
pluginSVC: pluginSVC,
memorySVC: memorySVC,
knowledgeSVC: knowledgeSVC,
workflowSVC: workflowDomainSVC,
shortcutSVC: shortcutSVC,
infra: basicServices.infra,
}, nil
}
// initComplexServices init complex services that depends on primary services.
func initComplexServices(ctx context.Context, p *primaryServices) (*complexServices, error) {
singleAgentSVC, err := singleagent.InitService(p.toSingleAgentServiceComponents())
if err != nil {
return nil, err
}
appSVC, err := app.InitService(p.toAPPServiceComponents())
if err != nil {
return nil, err
}
searchSVC, err := search.InitService(ctx, p.toSearchServiceComponents(singleAgentSVC, appSVC))
if err != nil {
return nil, err
}
conversationSVC := conversation.InitService(p.toConversationComponents(singleAgentSVC))
return &complexServices{
primaryServices: p,
singleAgentSVC: singleAgentSVC,
appSVC: appSVC,
searchSVC: searchSVC,
conversationSVC: conversationSVC,
}, nil
}
func (b *basicServices) toPluginServiceComponents() *plugin.ServiceComponents {
return &plugin.ServiceComponents{
IDGen: b.infra.IDGenSVC,
DB: b.infra.DB,
EventBus: b.eventbus.resourceEventBus,
OSS: b.infra.TOSClient,
UserSVC: b.userSVC.DomainSVC,
}
}
func (b *basicServices) toKnowledgeServiceComponents(memoryService *memory.MemoryApplicationServices) *knowledge.ServiceComponents {
return &knowledge.ServiceComponents{
DB: b.infra.DB,
IDGenSVC: b.infra.IDGenSVC,
Storage: b.infra.TOSClient,
RDB: memoryService.RDBDomainSVC,
ImageX: b.infra.ImageXClient,
ES: b.infra.ESClient,
EventBus: b.eventbus.resourceEventBus,
CacheCli: b.infra.CacheCli,
}
}
func (b *basicServices) toMemoryServiceComponents() *memory.ServiceComponents {
return &memory.ServiceComponents{
IDGen: b.infra.IDGenSVC,
DB: b.infra.DB,
EventBus: b.eventbus.resourceEventBus,
TosClient: b.infra.TOSClient,
ResourceDomainNotifier: b.eventbus.resourceEventBus,
CacheCli: b.infra.CacheCli,
}
}
func (b *basicServices) toWorkflowServiceComponents(pluginSVC *plugin.PluginApplicationService, memorySVC *memory.MemoryApplicationServices, knowledgeSVC *knowledge.KnowledgeApplicationService) *workflow.ServiceComponents {
return &workflow.ServiceComponents{
IDGen: b.infra.IDGenSVC,
DB: b.infra.DB,
Cache: b.infra.CacheCli,
Tos: b.infra.TOSClient,
ImageX: b.infra.ImageXClient,
DatabaseDomainSVC: memorySVC.DatabaseDomainSVC,
VariablesDomainSVC: memorySVC.VariablesDomainSVC,
PluginDomainSVC: pluginSVC.DomainSVC,
KnowledgeDomainSVC: knowledgeSVC.DomainSVC,
ModelManager: b.modelMgrSVC.DomainSVC,
DomainNotifier: b.eventbus.resourceEventBus,
CPStore: checkpoint.NewRedisStore(b.infra.CacheCli),
}
}
func (p *primaryServices) toSingleAgentServiceComponents() *singleagent.ServiceComponents {
return &singleagent.ServiceComponents{
IDGen: p.basicServices.infra.IDGenSVC,
DB: p.basicServices.infra.DB,
Cache: p.basicServices.infra.CacheCli,
TosClient: p.basicServices.infra.TOSClient,
ImageX: p.basicServices.infra.ImageXClient,
ModelMgrDomainSVC: p.basicServices.modelMgrSVC.DomainSVC,
UserDomainSVC: p.basicServices.userSVC.DomainSVC,
EventBus: p.basicServices.eventbus.projectEventBus,
DatabaseDomainSVC: p.memorySVC.DatabaseDomainSVC,
ConnectorDomainSVC: p.basicServices.connectorSVC.DomainSVC,
KnowledgeDomainSVC: p.knowledgeSVC.DomainSVC,
PluginDomainSVC: p.pluginSVC.DomainSVC,
WorkflowDomainSVC: p.workflowSVC.DomainSVC,
VariablesDomainSVC: p.memorySVC.VariablesDomainSVC,
ShortcutCMDDomainSVC: p.shortcutSVC.ShortCutDomainSVC,
CPStore: checkpoint.NewRedisStore(p.infra.CacheCli),
}
}
func (p *primaryServices) toSearchServiceComponents(singleAgentSVC *singleagent.SingleAgentApplicationService, appSVC *app.APPApplicationService) *search.ServiceComponents {
infra := p.basicServices.infra
return &search.ServiceComponents{
DB: infra.DB,
Cache: infra.CacheCli,
TOS: infra.TOSClient,
ESClient: infra.ESClient,
ProjectEventBus: p.basicServices.eventbus.projectEventBus,
SingleAgentDomainSVC: singleAgentSVC.DomainSVC,
APPDomainSVC: appSVC.DomainSVC,
KnowledgeDomainSVC: p.knowledgeSVC.DomainSVC,
PluginDomainSVC: p.pluginSVC.DomainSVC,
WorkflowDomainSVC: p.workflowSVC.DomainSVC,
UserDomainSVC: p.basicServices.userSVC.DomainSVC,
ConnectorDomainSVC: p.basicServices.connectorSVC.DomainSVC,
PromptDomainSVC: p.basicServices.promptSVC.DomainSVC,
DatabaseDomainSVC: p.memorySVC.DatabaseDomainSVC,
}
}
func (p *primaryServices) toAPPServiceComponents() *app.ServiceComponents {
infra := p.basicServices.infra
basic := p.basicServices
return &app.ServiceComponents{
IDGen: infra.IDGenSVC,
DB: infra.DB,
OSS: infra.TOSClient,
CacheCli: infra.CacheCli,
ProjectEventBus: basic.eventbus.projectEventBus,
UserSVC: basic.userSVC.DomainSVC,
ConnectorSVC: basic.connectorSVC.DomainSVC,
VariablesSVC: p.memorySVC.VariablesDomainSVC,
}
}
func (p *primaryServices) toConversationComponents(singleAgentSVC *singleagent.SingleAgentApplicationService) *conversation.ServiceComponents {
infra := p.basicServices.infra
return &conversation.ServiceComponents{
DB: infra.DB,
IDGen: infra.IDGenSVC,
TosClient: infra.TOSClient,
ImageX: infra.ImageXClient,
SingleAgentDomainSVC: singleAgentSVC.DomainSVC,
}
}

View File

@@ -0,0 +1,132 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package appinfra
import (
"context"
"fmt"
"os"
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
"github.com/coze-dev/coze-studio/backend/infra/impl/es"
"github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
"github.com/coze-dev/coze-studio/backend/infra/impl/idgen"
"github.com/coze-dev/coze-studio/backend/infra/impl/imagex/veimagex"
"github.com/coze-dev/coze-studio/backend/infra/impl/mysql"
"github.com/coze-dev/coze-studio/backend/infra/impl/storage"
"github.com/coze-dev/coze-studio/backend/types/consts"
)
type AppDependencies struct {
DB *gorm.DB
CacheCli *redis.Client
IDGenSVC idgen.IDGenerator
ESClient es.Client
ImageXClient imagex.ImageX
TOSClient storage.Storage
ResourceEventProducer eventbus.Producer
AppEventProducer eventbus.Producer
}
func Init(ctx context.Context) (*AppDependencies, error) {
deps := &AppDependencies{}
var err error
deps.DB, err = mysql.New()
if err != nil {
return nil, err
}
deps.CacheCli = redis.New()
deps.IDGenSVC, err = idgen.New(deps.CacheCli)
if err != nil {
return nil, err
}
deps.ESClient, err = es.New()
if err != nil {
return nil, err
}
deps.ImageXClient, err = initImageX(ctx)
if err != nil {
return nil, err
}
deps.TOSClient, err = initTOS(ctx)
if err != nil {
return nil, err
}
deps.ResourceEventProducer, err = initResourceEventBusProducer()
if err != nil {
return nil, err
}
deps.AppEventProducer, err = initAppEventProducer()
if err != nil {
return nil, err
}
return deps, nil
}
func initImageX(ctx context.Context) (imagex.ImageX, error) {
uploadComponentType := os.Getenv(consts.FileUploadComponentType)
if uploadComponentType != consts.FileUploadComponentTypeImagex {
return storage.NewImagex(ctx)
}
return veimagex.New(
os.Getenv(consts.VeImageXAK),
os.Getenv(consts.VeImageXSK),
os.Getenv(consts.VeImageXDomain),
os.Getenv(consts.VeImageXUploadHost),
os.Getenv(consts.VeImageXTemplate),
[]string{os.Getenv(consts.VeImageXServerID)},
)
}
func initTOS(ctx context.Context) (storage.Storage, error) {
return storage.New(ctx)
}
func initResourceEventBusProducer() (eventbus.Producer, error) {
nameServer := os.Getenv(consts.MQServer)
resourceEventBusProducer, err := eventbus.NewProducer(nameServer,
consts.RMQTopicResource, consts.RMQConsumeGroupResource, 1)
if err != nil {
return nil, fmt.Errorf("init resource producer failed, err=%w", err)
}
return resourceEventBusProducer, nil
}
func initAppEventProducer() (eventbus.Producer, error) {
nameServer := os.Getenv(consts.MQServer)
appEventProducer, err := eventbus.NewProducer(nameServer, consts.RMQTopicApp, consts.RMQConsumeGroupApp, 1)
if err != nil {
return nil, fmt.Errorf("init app producer failed, err=%w", err)
}
return appEventProducer, nil
}

View File

@@ -0,0 +1,42 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ctxutil
import (
"context"
"github.com/coze-dev/coze-studio/backend/domain/openauth/openapiauth/entity"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/types/consts"
)
func GetApiAuthFromCtx(ctx context.Context) *entity.ApiKey {
data, ok := ctxcache.Get[*entity.ApiKey](ctx, consts.OpenapiAuthKeyInCtx)
if !ok {
return nil
}
return data
}
func MustGetUIDFromApiAuthCtx(ctx context.Context) int64 {
apiKeyInfo := GetApiAuthFromCtx(ctx)
if apiKeyInfo == nil {
panic("mustGetUIDFromApiAuthCtx: apiKeyInfo is nil")
}
return apiKeyInfo.UserID
}

View File

@@ -0,0 +1,35 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ctxutil
import (
"context"
)
func GetRequestFullPathFromCtx(ctx context.Context) string {
contextValue := ctx.Value("request.full_path")
if contextValue == nil {
return ""
}
fullPath, ok := contextValue.(string)
if !ok {
return ""
}
return fullPath
}

View File

@@ -0,0 +1,52 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ctxutil
import (
"context"
"github.com/coze-dev/coze-studio/backend/domain/user/entity"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/types/consts"
)
func GetUserSessionFromCtx(ctx context.Context) *entity.Session {
data, ok := ctxcache.Get[*entity.Session](ctx, consts.SessionDataKeyInCtx)
if !ok {
return nil
}
return data
}
func MustGetUIDFromCtx(ctx context.Context) int64 {
sessionData := GetUserSessionFromCtx(ctx)
if sessionData == nil {
panic("mustGetUIDFromCtx: sessionData is nil")
}
return sessionData.UserID
}
func GetUIDFromCtx(ctx context.Context) *int64 {
sessionData := GetUserSessionFromCtx(ctx)
if sessionData == nil {
return nil
}
return &sessionData.UserID
}

View File

@@ -0,0 +1,327 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package pluginutil
import (
"net/http"
"strconv"
"github.com/getkin/kin-openapi/openapi3"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
func APIParamsToOpenapiOperation(reqParams, respParams []*common.APIParameter) (*openapi3.Operation, error) {
op := &openapi3.Operation{}
hasSetReqBody := false
hasSetParams := false
for _, apiParam := range reqParams {
if apiParam.Location != common.ParameterLocation_Body {
if !hasSetParams {
hasSetParams = true
op.Parameters = []*openapi3.ParameterRef{}
}
_apiParam, err := toOpenapiParameter(apiParam)
if err != nil {
return nil, err
}
op.Parameters = append(op.Parameters, &openapi3.ParameterRef{
Value: _apiParam,
})
continue
}
var mType *openapi3.MediaType
if hasSetReqBody {
mType = op.RequestBody.Value.Content[plugin.MediaTypeJson]
} else {
hasSetReqBody = true
mType = &openapi3.MediaType{
Schema: &openapi3.SchemaRef{
Value: &openapi3.Schema{
Type: openapi3.TypeObject,
Properties: map[string]*openapi3.SchemaRef{},
},
},
}
op.RequestBody = &openapi3.RequestBodyRef{
Value: &openapi3.RequestBody{
Content: map[string]*openapi3.MediaType{
plugin.MediaTypeJson: mType,
},
},
}
}
_apiParam, err := toOpenapi3Schema(apiParam)
if err != nil {
return nil, err
}
mType.Schema.Value.Properties[apiParam.Name] = &openapi3.SchemaRef{
Value: _apiParam,
}
if apiParam.IsRequired {
mType.Schema.Value.Required = append(mType.Schema.Value.Required, apiParam.Name)
}
}
hasSetRespBody := false
for _, apiParam := range respParams {
if !hasSetRespBody {
hasSetRespBody = true
op.Responses = map[string]*openapi3.ResponseRef{
strconv.Itoa(http.StatusOK): {
Value: &openapi3.Response{
Content: map[string]*openapi3.MediaType{
plugin.MediaTypeJson: {
Schema: &openapi3.SchemaRef{
Value: &openapi3.Schema{
Type: openapi3.TypeObject,
Properties: map[string]*openapi3.SchemaRef{},
},
},
},
},
},
},
}
}
_apiParam, err := toOpenapi3Schema(apiParam)
if err != nil {
return nil, err
}
resp, _ := op.Responses[strconv.Itoa(http.StatusOK)]
mType, _ := resp.Value.Content[plugin.MediaTypeJson] // only support application/json
mType.Schema.Value.Properties[apiParam.Name] = &openapi3.SchemaRef{
Value: _apiParam,
}
if apiParam.IsRequired {
mType.Schema.Value.Required = append(mType.Schema.Value.Required, apiParam.Name)
}
}
return op, nil
}
func toOpenapiParameter(apiParam *common.APIParameter) (*openapi3.Parameter, error) {
paramType, ok := plugin.ToOpenapiParamType(apiParam.Type)
if !ok {
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
errorx.KVf(errno.PluginMsgKey, "the type '%s' of field '%s' is invalid", apiParam.Type, apiParam.Name))
}
if paramType == openapi3.TypeObject {
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
errorx.KVf(errno.PluginMsgKey, "the type of field '%s' cannot be 'object'", apiParam.Name))
}
paramSchema := &openapi3.Schema{
Type: paramType,
Default: apiParam.GlobalDefault,
Extensions: map[string]interface{}{
plugin.APISchemaExtendGlobalDisable: apiParam.GlobalDisable,
},
}
if paramType == openapi3.TypeArray {
if apiParam.Location == common.ParameterLocation_Path {
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
errorx.KVf(errno.PluginMsgKey, "the type of field '%s' cannot be 'array'", apiParam.Name))
}
if len(apiParam.SubParameters) == 0 {
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
errorx.KVf(errno.PluginMsgKey, "the sub parameters of field '%s' is required", apiParam.Name))
}
arrayItem := apiParam.SubParameters[0]
arrayItemType, ok := plugin.ToOpenapiParamType(arrayItem.Type)
if !ok {
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
errorx.KVf(errno.PluginMsgKey, "the item type '%s' of field '%s' is invalid", arrayItemType, apiParam.Name))
}
if arrayItemType == openapi3.TypeObject || arrayItemType == openapi3.TypeArray {
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
errorx.KVf(errno.PluginMsgKey, "the item type of field '%s' cannot be 'array' or 'object'", apiParam.Name))
}
itemSchema := &openapi3.Schema{
Type: arrayItemType,
Description: arrayItem.Desc,
Extensions: map[string]any{},
}
if arrayItem.GetAssistType() > 0 {
aType, ok := plugin.ToAPIAssistType(arrayItem.GetAssistType())
if !ok {
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
errorx.KVf(errno.PluginMsgKey, "the assist type '%s' of field '%s' is invalid", arrayItem.GetAssistType(), apiParam.Name))
}
itemSchema.Extensions[plugin.APISchemaExtendAssistType] = aType
format, ok := plugin.AssistTypeToFormat(aType)
if !ok {
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
errorx.KVf(errno.PluginMsgKey, "the assist type '%s' of field '%s' is invalid", aType, apiParam.Name))
}
itemSchema.Format = format
}
paramSchema.Items = &openapi3.SchemaRef{
Value: itemSchema,
}
}
if apiParam.LocalDefault != nil && *apiParam.LocalDefault != "" {
paramSchema.Default = apiParam.LocalDefault
}
if apiParam.LocalDisable {
paramSchema.Extensions[plugin.APISchemaExtendLocalDisable] = true
}
if apiParam.VariableRef != nil && *apiParam.VariableRef != "" {
paramSchema.Extensions[plugin.APISchemaExtendVariableRef] = apiParam.VariableRef
}
if apiParam.GetAssistType() > 0 {
aType, ok := plugin.ToAPIAssistType(apiParam.GetAssistType())
if !ok {
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
errorx.KVf(errno.PluginMsgKey, "the assist type '%s' of field '%s' is invalid", apiParam.GetAssistType(), apiParam.Name))
}
paramSchema.Extensions[plugin.APISchemaExtendAssistType] = aType
format, ok := plugin.AssistTypeToFormat(aType)
if !ok {
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
errorx.KVf(errno.PluginMsgKey, "the assist type '%s' of field '%s' is invalid", aType, apiParam.Name))
}
paramSchema.Format = format
}
loc, ok := plugin.ToHTTPParamLocation(apiParam.Location)
if !ok {
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
errorx.KVf(errno.PluginMsgKey, "the location '%s' of field '%s' is invalid ", apiParam.Location, apiParam.Name))
}
param := &openapi3.Parameter{
Description: apiParam.Desc,
Name: apiParam.Name,
In: string(loc),
Required: apiParam.IsRequired,
Schema: &openapi3.SchemaRef{
Value: paramSchema,
},
}
return param, nil
}
func toOpenapi3Schema(apiParam *common.APIParameter) (*openapi3.Schema, error) {
paramType, ok := plugin.ToOpenapiParamType(apiParam.Type)
if !ok {
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
errorx.KVf(errno.PluginMsgKey, "the type '%s' of field '%s' is invalid", apiParam.Type, apiParam.Name))
}
sc := &openapi3.Schema{
Description: apiParam.Desc,
Type: paramType,
Default: apiParam.GlobalDefault,
Extensions: map[string]interface{}{
plugin.APISchemaExtendGlobalDisable: apiParam.GlobalDisable,
},
}
if apiParam.LocalDefault != nil && *apiParam.LocalDefault != "" {
sc.Default = apiParam.LocalDefault
}
if apiParam.LocalDisable {
sc.Extensions[plugin.APISchemaExtendLocalDisable] = true
}
if apiParam.VariableRef != nil && *apiParam.VariableRef != "" {
sc.Extensions[plugin.APISchemaExtendVariableRef] = apiParam.VariableRef
}
if apiParam.GetAssistType() > 0 {
aType, ok := plugin.ToAPIAssistType(apiParam.GetAssistType())
if !ok {
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
errorx.KVf(errno.PluginMsgKey, "the assist type '%s' of field '%s' is invalid", apiParam.GetAssistType(), apiParam.Name))
}
sc.Extensions[plugin.APISchemaExtendAssistType] = aType
format, ok := plugin.AssistTypeToFormat(aType)
if !ok {
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
errorx.KVf(errno.PluginMsgKey, "the assist type '%s' of field '%s' is invalid", aType, apiParam.Name))
}
sc.Format = format
}
switch paramType {
case openapi3.TypeObject:
sc.Properties = map[string]*openapi3.SchemaRef{}
for _, subParam := range apiParam.SubParameters {
_subParam, err := toOpenapi3Schema(subParam)
if err != nil {
return nil, err
}
sc.Properties[subParam.Name] = &openapi3.SchemaRef{
Value: _subParam,
}
if subParam.IsRequired {
sc.Required = append(sc.Required, subParam.Name)
}
}
return sc, nil
case openapi3.TypeArray:
if len(apiParam.SubParameters) == 0 {
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
errorx.KVf(errno.PluginMsgKey, "the sub-parameters of field '%s' are required", apiParam.Name))
}
arrayItem := apiParam.SubParameters[0]
itemType, ok := plugin.ToOpenapiParamType(arrayItem.Type)
if !ok {
return nil, errorx.New(errno.ErrPluginInvalidParamCode,
errorx.KVf(errno.PluginMsgKey, "the item type '%s' of field '%s' is invalid", itemType, apiParam.Name))
}
subParam, err := toOpenapi3Schema(arrayItem)
if err != nil {
return nil, err
}
sc.Items = &openapi3.SchemaRef{
Value: subParam,
}
return sc, nil
}
return sc, nil
}

View File

@@ -0,0 +1,41 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package connector
import (
"context"
"github.com/coze-dev/coze-studio/backend/domain/connector/entity"
connector "github.com/coze-dev/coze-studio/backend/domain/connector/service"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
)
type ConnectorApplicationService struct {
DomainSVC connector.Connector
}
var ConnectorApplicationSVC *ConnectorApplicationService
func New(domainSVC connector.Connector, tosClient storage.Storage) *ConnectorApplicationService {
return &ConnectorApplicationService{
DomainSVC: domainSVC,
}
}
func (c *ConnectorApplicationService) List(ctx context.Context) ([]*entity.Connector, error) {
return c.DomainSVC.List(ctx)
}

View File

@@ -0,0 +1,29 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package connector
import (
connector "github.com/coze-dev/coze-studio/backend/domain/connector/service"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
)
func InitService(tos storage.Storage) *ConnectorApplicationService {
connectorDomainSVC := connector.NewService(tos)
ConnectorApplicationSVC = New(connectorDomainSVC, tos)
return ConnectorApplicationSVC
}

View File

@@ -0,0 +1,471 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package conversation
import (
"context"
"encoding/json"
"errors"
"io"
"strconv"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/api/model/conversation/message"
"github.com/coze-dev/coze-studio/backend/api/model/conversation/run"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
crossDomainMessage "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
saEntity "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
convEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
msgEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
cmdEntity "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity"
sseImpl "github.com/coze-dev/coze-studio/backend/infra/impl/sse"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/consts"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
func (c *ConversationApplicationService) Run(ctx context.Context, sseSender *sseImpl.SSenderImpl, ar *run.AgentRunRequest) error {
agentInfo, caErr := c.checkAgent(ctx, ar)
if caErr != nil {
logs.CtxErrorf(ctx, "checkAgent err:%v", caErr)
return caErr
}
userID := ctxutil.MustGetUIDFromCtx(ctx)
conversationData, ccErr := c.checkConversation(ctx, ar, userID)
if ccErr != nil {
logs.CtxErrorf(ctx, "checkConversation err:%v", ccErr)
return ccErr
}
if ar.RegenMessageID != nil && ptr.From(ar.RegenMessageID) > 0 {
msgMeta, err := c.MessageDomainSVC.GetByID(ctx, ptr.From(ar.RegenMessageID))
if err != nil {
return err
}
if msgMeta != nil {
if msgMeta.UserID != conv.Int64ToStr(userID) {
return errorx.New(errno.ErrConversationPermissionCode, errorx.KV("msg", "message not match"))
}
delErr := c.MessageDomainSVC.Delete(ctx, &msgEntity.DeleteMeta{
RunIDs: []int64{msgMeta.RunID},
})
if delErr != nil {
return delErr
}
}
}
var shortcutCmd *cmdEntity.ShortcutCmd
if ar.GetShortcutCmdID() > 0 {
cmdID := ar.GetShortcutCmdID()
cmdMeta, err := c.ShortcutDomainSVC.GetByCmdID(ctx, cmdID, 0)
if err != nil {
return err
}
shortcutCmd = cmdMeta
}
arr, err := c.buildAgentRunRequest(ctx, ar, userID, agentInfo.SpaceID, conversationData, shortcutCmd)
if err != nil {
logs.CtxErrorf(ctx, "buildAgentRunRequest err:%v", err)
return err
}
streamer, err := c.AgentRunDomainSVC.AgentRun(ctx, arr)
if err != nil {
return err
}
c.pullStream(ctx, sseSender, streamer, ar)
return nil
}
func (c *ConversationApplicationService) pullStream(ctx context.Context, sseSender *sseImpl.SSenderImpl, arStream *schema.StreamReader[*entity.AgentRunResponse], req *run.AgentRunRequest) {
var ackMessageInfo *entity.ChunkMessageItem
for {
chunk, recvErr := arStream.Recv()
if recvErr != nil {
if errors.Is(recvErr, io.EOF) {
return
}
sseSender.Send(ctx, buildErrorEvent(errno.ErrConversationAgentRunError, recvErr.Error()))
return
}
switch chunk.Event {
case entity.RunEventCreated, entity.RunEventInProgress, entity.RunEventCompleted:
case entity.RunEventError:
id, err := c.GenID(ctx)
if err != nil {
sseSender.Send(ctx, buildErrorEvent(errno.ErrConversationAgentRunError, err.Error()))
} else {
sseSender.Send(ctx, buildMessageChunkEvent(run.RunEventMessage, buildErrMsg(ackMessageInfo, chunk.Error, id)))
}
case entity.RunEventStreamDone:
sseSender.Send(ctx, buildDoneEvent(run.RunEventDone))
case entity.RunEventAck:
ackMessageInfo = chunk.ChunkMessageItem
sseSender.Send(ctx, buildMessageChunkEvent(run.RunEventMessage, buildARSM2Message(chunk, req)))
case entity.RunEventMessageDelta, entity.RunEventMessageCompleted:
sseSender.Send(ctx, buildMessageChunkEvent(run.RunEventMessage, buildARSM2Message(chunk, req)))
default:
logs.CtxErrorf(ctx, "unknown handler event:%v", chunk.Event)
}
}
}
func buildARSM2Message(chunk *entity.AgentRunResponse, req *run.AgentRunRequest) []byte {
chunkMessageItem := chunk.ChunkMessageItem
chunkMessage := &run.RunStreamResponse{
ConversationID: strconv.FormatInt(chunkMessageItem.ConversationID, 10),
IsFinish: ptr.Of(chunk.ChunkMessageItem.IsFinish),
Message: &message.ChatMessage{
Role: string(chunkMessageItem.Role),
ContentType: string(chunkMessageItem.ContentType),
MessageID: strconv.FormatInt(chunkMessageItem.ID, 10),
SectionID: strconv.FormatInt(chunkMessageItem.SectionID, 10),
ContentTime: chunkMessageItem.CreatedAt,
ExtraInfo: buildExt(chunkMessageItem.Ext),
ReplyID: strconv.FormatInt(chunkMessageItem.ReplyID, 10),
Status: "",
Type: string(chunkMessageItem.MessageType),
Content: chunkMessageItem.Content,
ReasoningContent: chunkMessageItem.ReasoningContent,
RequiredAction: chunkMessageItem.RequiredAction,
},
Index: int32(chunkMessageItem.Index),
SeqID: int32(chunkMessageItem.SeqID),
}
if chunkMessageItem.MessageType == crossDomainMessage.MessageTypeAck {
chunkMessage.Message.Content = req.GetQuery()
chunkMessage.Message.ContentType = req.GetContentType()
chunkMessage.Message.ExtraInfo = &message.ExtraInfo{
LocalMessageID: req.GetLocalMessageID(),
}
} else {
chunkMessage.Message.ExtraInfo = buildExt(chunkMessageItem.Ext)
chunkMessage.Message.SenderID = ptr.Of(strconv.FormatInt(chunkMessageItem.AgentID, 10))
chunkMessage.Message.Content = chunkMessageItem.Content
if chunkMessageItem.MessageType == crossDomainMessage.MessageTypeKnowledge {
chunkMessage.Message.Type = string(crossDomainMessage.MessageTypeVerbose)
}
}
if chunk.ChunkMessageItem.IsFinish && chunkMessageItem.MessageType == crossDomainMessage.MessageTypeAnswer {
chunkMessage.Message.Content = ""
chunkMessage.Message.ReasoningContent = ptr.Of("")
}
mCM, _ := json.Marshal(chunkMessage)
return mCM
}
func buildExt(extra map[string]string) *message.ExtraInfo {
if extra == nil {
return nil
}
return &message.ExtraInfo{
InputTokens: extra["input_tokens"],
OutputTokens: extra["output_tokens"],
Token: extra["token"],
PluginStatus: extra["plugin_status"],
TimeCost: extra["time_cost"],
WorkflowTokens: extra["workflow_tokens"],
BotState: extra["bot_state"],
PluginRequest: extra["plugin_request"],
ToolName: extra["tool_name"],
Plugin: extra["plugin"],
MockHitInfo: extra["mock_hit_info"],
MessageTitle: extra["message_title"],
StreamPluginRunning: extra["stream_plugin_running"],
ExecuteDisplayName: extra["execute_display_name"],
TaskType: extra["task_type"],
ReferFormat: extra["refer_format"],
}
}
func buildErrMsg(ackChunk *entity.ChunkMessageItem, err *entity.RunError, id int64) []byte {
chunkMessage := &run.RunStreamResponse{
IsFinish: ptr.Of(true),
ConversationID: strconv.FormatInt(ackChunk.ConversationID, 10),
Message: &message.ChatMessage{
Role: string(schema.Assistant),
ContentType: string(crossDomainMessage.ContentTypeText),
Type: string(crossDomainMessage.MessageTypeAnswer),
MessageID: strconv.FormatInt(id, 10),
SectionID: strconv.FormatInt(ackChunk.SectionID, 10),
ReplyID: strconv.FormatInt(ackChunk.ReplyID, 10),
Content: "Something error:" + err.Msg,
ExtraInfo: &message.ExtraInfo{},
},
}
mCM, _ := json.Marshal(chunkMessage)
return mCM
}
func (c *ConversationApplicationService) GenID(ctx context.Context) (int64, error) {
id, err := c.appContext.IDGen.GenID(ctx)
return id, err
}
func (c *ConversationApplicationService) checkConversation(ctx context.Context, ar *run.AgentRunRequest, userID int64) (*convEntity.Conversation, error) {
var conversationData *convEntity.Conversation
if ar.ConversationID > 0 {
realCurrCon, err := c.ConversationDomainSVC.GetCurrentConversation(ctx, &convEntity.GetCurrent{
UserID: userID,
AgentID: ar.BotID,
Scene: ptr.From(ar.Scene),
ConnectorID: consts.CozeConnectorID,
})
logs.CtxInfof(ctx, "conversatioin data:%v", conv.DebugJsonToStr(realCurrCon))
if err != nil {
return nil, err
}
if realCurrCon != nil {
conversationData = realCurrCon
}
}
if ar.ConversationID == 0 || conversationData == nil {
conData, err := c.ConversationDomainSVC.Create(ctx, &convEntity.CreateMeta{
AgentID: ar.BotID,
UserID: userID,
Scene: ptr.From(ar.Scene),
ConnectorID: consts.CozeConnectorID,
})
if err != nil {
return nil, err
}
logs.CtxInfof(ctx, "conversatioin create data:%v", conv.DebugJsonToStr(conData))
conversationData = conData
ar.ConversationID = conversationData.ID
}
if conversationData.CreatorID != userID {
return nil, errorx.New(errno.ErrConversationPermissionCode, errorx.KV("msg", "conversation not match"))
}
return conversationData, nil
}
func (c *ConversationApplicationService) checkAgent(ctx context.Context, ar *run.AgentRunRequest) (*saEntity.SingleAgent, error) {
agentInfo, err := c.appContext.SingleAgentDomainSVC.GetSingleAgent(ctx, ar.BotID, "")
if err != nil {
return nil, err
}
if agentInfo == nil {
return nil, errorx.New(errno.ErrAgentNotExists)
}
return agentInfo, nil
}
func (c *ConversationApplicationService) buildAgentRunRequest(ctx context.Context, ar *run.AgentRunRequest, userID int64, spaceID int64, conversationData *convEntity.Conversation, shortcutCMD *cmdEntity.ShortcutCmd) (*entity.AgentRunMeta, error) {
var contentType crossDomainMessage.ContentType
contentType = crossDomainMessage.ContentTypeText
if ptr.From(ar.ContentType) != string(crossDomainMessage.ContentTypeText) {
contentType = crossDomainMessage.ContentTypeMix
}
shortcutCMDData, err := c.buildTools(ctx, ar.ToolList, shortcutCMD)
if err != nil {
return nil, err
}
arm := &entity.AgentRunMeta{
ConversationID: conversationData.ID,
AgentID: ar.BotID,
Content: c.buildMultiContent(ctx, ar),
DisplayContent: c.buildDisplayContent(ctx, ar),
SpaceID: spaceID,
UserID: conv.Int64ToStr(userID),
SectionID: conversationData.SectionID,
PreRetrieveTools: shortcutCMDData,
IsDraft: ptr.From(ar.DraftMode),
ConnectorID: consts.CozeConnectorID,
ContentType: contentType,
Ext: ar.Extra,
}
return arm, nil
}
func (c *ConversationApplicationService) buildDisplayContent(ctx context.Context, ar *run.AgentRunRequest) string {
if *ar.ContentType == run.ContentTypeText {
return ""
}
return ar.Query
}
func (c *ConversationApplicationService) buildTools(ctx context.Context, tools []*run.Tool, shortcutCMD *cmdEntity.ShortcutCmd) ([]*entity.Tool, error) {
var ts []*entity.Tool
for _, tool := range tools {
if shortcutCMD != nil {
arguments := make(map[string]string)
for key, parametersStruct := range tool.Parameters {
if parametersStruct == nil {
continue
}
arguments[key] = parametersStruct.Value
// uri需要转换成url
if parametersStruct.ResourceType == consts.ShortcutCommandResourceType {
resourceInfo, err := c.appContext.ImageX.GetResourceURL(ctx, parametersStruct.Value)
if err != nil {
return nil, err
}
arguments[key] = resourceInfo.URL
}
}
argBytes, err := json.Marshal(arguments)
if err == nil {
ts = append(ts, &entity.Tool{
PluginID: shortcutCMD.PluginID,
Arguments: string(argBytes),
ToolName: shortcutCMD.PluginToolName,
ToolID: shortcutCMD.PluginToolID,
Type: agentrun.ToolType(shortcutCMD.ToolType),
})
}
}
}
return ts, nil
}
func (c *ConversationApplicationService) buildMultiContent(ctx context.Context, ar *run.AgentRunRequest) []*crossDomainMessage.InputMetaData {
var multiContents []*crossDomainMessage.InputMetaData
switch *ar.ContentType {
case run.ContentTypeText:
multiContents = append(multiContents, &crossDomainMessage.InputMetaData{
Type: crossDomainMessage.InputTypeText,
Text: ar.Query,
})
case run.ContentTypeImage, run.ContentTypeFile, run.ContentTypeMix, run.ContentTypeVideo, run.ContentTypeAudio:
var mc *run.MixContentModel
err := json.Unmarshal([]byte(ar.Query), &mc)
if err != nil {
multiContents = append(multiContents, &crossDomainMessage.InputMetaData{
Type: crossDomainMessage.InputTypeText,
Text: ar.Query,
})
return multiContents
}
mcContent, newItemList := c.parseMultiContent(ctx, mc.ItemList)
multiContents = append(multiContents, mcContent...)
mc.ItemList = newItemList
mcByte, err := json.Marshal(mc)
if err == nil {
ar.Query = string(mcByte)
}
}
return multiContents
}
func (c *ConversationApplicationService) parseMultiContent(ctx context.Context, mc []*run.Item) (multiContents []*crossDomainMessage.InputMetaData, mcNew []*run.Item) {
for index, item := range mc {
switch item.Type {
case run.ContentTypeText:
multiContents = append(multiContents, &crossDomainMessage.InputMetaData{
Type: crossDomainMessage.InputTypeText,
Text: item.Text,
})
case run.ContentTypeImage:
resourceUrl, err := c.getUrlByUri(ctx, item.Image.Key)
if err != nil {
continue
}
if err != nil {
logs.CtxErrorf(ctx, "failed to unescape resource url, err is %v", err)
continue
}
mc[index].Image.ImageThumb.URL = resourceUrl
mc[index].Image.ImageOri.URL = resourceUrl
multiContents = append(multiContents, &crossDomainMessage.InputMetaData{
Type: crossDomainMessage.InputTypeImage,
FileData: []*crossDomainMessage.FileData{
{
Url: resourceUrl,
},
},
})
case run.ContentTypeFile, run.ContentTypeAudio, run.ContentTypeVideo:
resourceUrl, err := c.getUrlByUri(ctx, item.File.FileKey)
if err != nil {
continue
}
mc[index].File.FileURL = resourceUrl
multiContents = append(multiContents, &crossDomainMessage.InputMetaData{
Type: crossDomainMessage.InputType(item.Type),
FileData: []*crossDomainMessage.FileData{
{
Url: resourceUrl,
},
},
})
}
}
return multiContents, mc
}
func (s *ConversationApplicationService) getUrlByUri(ctx context.Context, uri string) (string, error) {
url, err := s.appContext.ImageX.GetResourceURL(ctx, uri)
if err != nil {
return "", err
}
return url.URL, nil
}

View File

@@ -0,0 +1,51 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package conversation
import (
"encoding/json"
"github.com/hertz-contrib/sse"
"github.com/coze-dev/coze-studio/backend/api/model/conversation/run"
)
func buildDoneEvent(event string) *sse.Event {
return &sse.Event{
Event: event,
}
}
func buildErrorEvent(errCode int64, errMsg string) *sse.Event {
errData := run.ErrorData{
Code: errCode,
Msg: errMsg,
}
ed, _ := json.Marshal(errData)
return &sse.Event{
Event: run.RunEventError,
Data: ed,
}
}
func buildMessageChunkEvent(event string, chunkMsg []byte) *sse.Event {
return &sse.Event{
Event: event,
Data: chunkMsg,
}
}

View File

@@ -0,0 +1,188 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package conversation
import (
"context"
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
"github.com/coze-dev/coze-studio/backend/api/model/conversation/conversation"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
agentrun "github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/service"
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
conversationService "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/service"
message "github.com/coze-dev/coze-studio/backend/domain/conversation/message/service"
"github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/service"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/types/consts"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type ConversationApplicationService struct {
appContext *ServiceComponents
AgentRunDomainSVC agentrun.Run
ConversationDomainSVC conversationService.Conversation
MessageDomainSVC message.Message
ShortcutDomainSVC service.ShortcutCmd
}
var ConversationSVC = new(ConversationApplicationService)
type OpenapiAgentRunApplication struct {
ShortcutDomainSVC service.ShortcutCmd
}
var ConversationOpenAPISVC = new(OpenapiAgentRunApplication)
func (c *ConversationApplicationService) ClearHistory(ctx context.Context, req *conversation.ClearConversationHistoryRequest) (*conversation.ClearConversationHistoryResponse, error) {
resp := new(conversation.ClearConversationHistoryResponse)
conversationID := req.ConversationID
// get conversation
currentRes, err := c.ConversationDomainSVC.GetByID(ctx, conversationID)
if err != nil {
return resp, err
}
if currentRes == nil {
return resp, errorx.New(errno.ErrConversationNotFound)
}
// check user
userID := ctxutil.GetUIDFromCtx(ctx)
if userID == nil || *userID != currentRes.CreatorID {
return resp, errorx.New(errno.ErrConversationNotFound, errorx.KV("msg", "user not match"))
}
// delete conversation
err = c.ConversationDomainSVC.Delete(ctx, conversationID)
if err != nil {
return resp, err
}
// create new conversation
convRes, err := c.ConversationDomainSVC.Create(ctx, &entity.CreateMeta{
AgentID: currentRes.AgentID,
UserID: currentRes.CreatorID,
Scene: currentRes.Scene,
ConnectorID: consts.CozeConnectorID,
})
if err != nil {
return resp, err
}
resp.NewSectionID = convRes.SectionID
return resp, nil
}
func (c *ConversationApplicationService) CreateSection(ctx context.Context, conversationID int64) (int64, error) {
currentRes, err := c.ConversationDomainSVC.GetByID(ctx, conversationID)
if err != nil {
return 0, err
}
if currentRes == nil {
return 0, errorx.New(errno.ErrConversationNotFound, errorx.KV("msg", "conversation not found"))
}
var userID int64
if currentRes.ConnectorID == consts.CozeConnectorID {
userID = ctxutil.MustGetUIDFromCtx(ctx)
} else {
userID = ctxutil.MustGetUIDFromApiAuthCtx(ctx)
}
if userID != currentRes.CreatorID {
return 0, errorx.New(errno.ErrConversationNotFound, errorx.KV("msg", "user not match"))
}
convRes, err := c.ConversationDomainSVC.NewConversationCtx(ctx, &entity.NewConversationCtxRequest{
ID: conversationID,
})
if err != nil {
return 0, err
}
return convRes.SectionID, nil
}
func (c *ConversationApplicationService) CreateConversation(ctx context.Context, agentID int64, connectorID int64) (*conversation.CreateConversationResponse, error) {
resp := new(conversation.CreateConversationResponse)
apiKeyInfo := ctxutil.GetApiAuthFromCtx(ctx)
userID := apiKeyInfo.UserID
if connectorID != consts.WebSDKConnectorID {
connectorID = apiKeyInfo.ConnectorID
}
conversationData, err := c.ConversationDomainSVC.Create(ctx, &entity.CreateMeta{
AgentID: agentID,
UserID: userID,
ConnectorID: connectorID,
Scene: common.Scene_SceneOpenApi,
})
if err != nil {
return nil, err
}
resp.ConversationData = &conversation.ConversationData{
Id: conversationData.ID,
LastSectionID: &conversationData.SectionID,
ConnectorID: &conversationData.ConnectorID,
CreatedAt: conversationData.CreatedAt,
}
return resp, nil
}
func (c *ConversationApplicationService) ListConversation(ctx context.Context, req *conversation.ListConversationsApiRequest) (*conversation.ListConversationsApiResponse, error) {
resp := new(conversation.ListConversationsApiResponse)
apiKeyInfo := ctxutil.GetApiAuthFromCtx(ctx)
userID := apiKeyInfo.UserID
connectorID := apiKeyInfo.ConnectorID
if userID == 0 {
return resp, errorx.New(errno.ErrConversationNotFound)
}
if ptr.From(req.ConnectorID) == consts.WebSDKConnectorID {
connectorID = ptr.From(req.ConnectorID)
}
conversationDOList, hasMore, err := c.ConversationDomainSVC.List(ctx, &entity.ListMeta{
UserID: userID,
AgentID: req.GetBotID(),
ConnectorID: connectorID,
Scene: common.Scene_SceneOpenApi,
Page: int(req.GetPageNum()),
Limit: int(req.GetPageSize()),
})
if err != nil {
return resp, err
}
conversationData := slices.Transform(conversationDOList, func(conv *entity.Conversation) *conversation.ConversationData {
return &conversation.ConversationData{
Id: conv.ID,
LastSectionID: &conv.SectionID,
ConnectorID: &conv.ConnectorID,
CreatedAt: conv.CreatedAt,
}
})
resp.Data = &conversation.ListConversationData{
Conversations: conversationData,
HasMore: hasMore,
}
return resp, nil
}

View File

@@ -0,0 +1,76 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package conversation
import (
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/application/singleagent"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/repository"
agentrun "github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/service"
convRepo "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/repository"
conversation "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/service"
msgRepo "github.com/coze-dev/coze-studio/backend/domain/conversation/message/repository"
message "github.com/coze-dev/coze-studio/backend/domain/conversation/message/service"
shortcutRepo "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/repository"
"github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/service"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
)
type ServiceComponents struct {
IDGen idgen.IDGenerator
DB *gorm.DB
TosClient storage.Storage
ImageX imagex.ImageX
SingleAgentDomainSVC singleagent.SingleAgent
}
func InitService(s *ServiceComponents) *ConversationApplicationService {
mDomainComponents := &message.Components{
MessageRepo: msgRepo.NewMessageRepo(s.DB, s.IDGen),
}
messageDomainSVC := message.NewService(mDomainComponents)
cDomainComponents := &conversation.Components{
ConversationRepo: convRepo.NewConversationRepo(s.DB, s.IDGen),
}
conversationDomainSVC := conversation.NewService(cDomainComponents)
arDomainComponents := &agentrun.Components{
RunRecordRepo: repository.NewRunRecordRepo(s.DB, s.IDGen),
}
agentRunDomainSVC := agentrun.NewService(arDomainComponents)
components := &service.Components{
ShortCutCmdRepo: shortcutRepo.NewShortCutCmdRepo(s.DB, s.IDGen),
}
shortcutCmdDomainSVC := service.NewShortcutCommandService(components)
ConversationSVC.AgentRunDomainSVC = agentRunDomainSVC
ConversationSVC.MessageDomainSVC = messageDomainSVC
ConversationSVC.ConversationDomainSVC = conversationDomainSVC
ConversationSVC.appContext = s
ConversationSVC.ShortcutDomainSVC = shortcutCmdDomainSVC
ConversationOpenAPISVC.ShortcutDomainSVC = shortcutCmdDomainSVC
return ConversationSVC
}

View File

@@ -0,0 +1,293 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package conversation
import (
"context"
"strconv"
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
"github.com/coze-dev/coze-studio/backend/api/model/conversation/message"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
singleAgentEntity "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
convEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/types/consts"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
func (c *ConversationApplicationService) GetMessageList(ctx context.Context, mr *message.GetMessageListRequest) (*message.GetMessageListResponse, error) {
// Get Conversation ID by agent id & userID & scene
userID := ctxutil.GetUIDFromCtx(ctx)
agentID, err := strconv.ParseInt(mr.BotID, 10, 64)
if err != nil {
return nil, err
}
currentConversation, isNewCreate, err := c.getCurrentConversation(ctx, *userID, agentID, *mr.Scene, nil)
if err != nil {
return nil, err
}
if isNewCreate {
return &message.GetMessageListResponse{
MessageList: []*message.ChatMessage{},
Cursor: mr.Cursor,
NextCursor: "0",
NextHasMore: false,
ConversationID: strconv.FormatInt(currentConversation.ID, 10),
LastSectionID: ptr.Of(strconv.FormatInt(currentConversation.SectionID, 10)),
}, nil
}
cursor, err := strconv.ParseInt(mr.Cursor, 10, 64)
if err != nil {
return nil, err
}
mListMessages, err := c.MessageDomainSVC.List(ctx, &entity.ListMeta{
ConversationID: currentConversation.ID,
AgentID: agentID,
Limit: int(mr.Count),
Cursor: cursor,
Direction: loadDirectionToScrollDirection(mr.LoadDirection),
})
if err != nil {
return nil, err
}
// get agent id
var agentIDs []int64
for _, mOne := range mListMessages.Messages {
agentIDs = append(agentIDs, mOne.AgentID)
}
agentInfo, err := c.buildAgentInfo(ctx, agentIDs)
if err != nil {
return nil, err
}
resp := c.buildMessageListResponse(ctx, mListMessages, currentConversation)
resp.ParticipantInfoMap = map[string]*message.MsgParticipantInfo{}
for _, aOne := range agentInfo {
resp.ParticipantInfoMap[aOne.ID] = aOne
}
return resp, err
}
func (c *ConversationApplicationService) buildAgentInfo(ctx context.Context, agentIDs []int64) ([]*message.MsgParticipantInfo, error) {
var result []*message.MsgParticipantInfo
if len(agentIDs) > 0 {
agentInfos, err := c.appContext.SingleAgentDomainSVC.MGetSingleAgentDraft(ctx, agentIDs)
if err != nil {
return nil, err
}
result = slices.Transform(agentInfos, func(a *singleAgentEntity.SingleAgent) *message.MsgParticipantInfo {
return &message.MsgParticipantInfo{
ID: strconv.FormatInt(a.AgentID, 10),
Name: a.Name,
UserID: strconv.FormatInt(a.CreatorID, 10),
Desc: a.Desc,
AvatarURL: a.IconURI,
}
})
}
return result, nil
}
func (c *ConversationApplicationService) getCurrentConversation(ctx context.Context, userID int64, agentID int64, scene common.Scene, connectorID *int64) (*convEntity.Conversation, bool, error) {
var currentConversation *convEntity.Conversation
var isNewCreate bool
if connectorID == nil && scene == common.Scene_Playground {
connectorID = ptr.Of(consts.CozeConnectorID)
}
currentConversation, err := c.ConversationDomainSVC.GetCurrentConversation(ctx, &convEntity.GetCurrent{
UserID: userID,
Scene: scene,
AgentID: agentID,
ConnectorID: ptr.From(connectorID),
})
if err != nil {
return nil, isNewCreate, err
}
if currentConversation == nil { // new conversation
// create conversation
ccNew, err := c.ConversationDomainSVC.Create(ctx, &convEntity.CreateMeta{
AgentID: agentID,
UserID: userID,
Scene: scene,
ConnectorID: ptr.From(connectorID),
})
if err != nil {
return nil, isNewCreate, err
}
if ccNew == nil {
return nil, isNewCreate,
errorx.New(errno.ErrConversationNotFound)
}
isNewCreate = true
currentConversation = ccNew
}
return currentConversation, isNewCreate, nil
}
func loadDirectionToScrollDirection(direction *message.LoadDirection) entity.ScrollPageDirection {
if direction != nil && *direction == message.LoadDirection_Next {
return entity.ScrollPageDirectionNext
}
return entity.ScrollPageDirectionPrev
}
func (c *ConversationApplicationService) buildMessageListResponse(ctx context.Context, mListMessages *entity.ListResult, currentConversation *convEntity.Conversation) *message.GetMessageListResponse {
var messages []*message.ChatMessage
runToQuestionIDMap := make(map[int64]int64)
for _, mMessage := range mListMessages.Messages {
if mMessage.MessageType == model.MessageTypeQuestion {
runToQuestionIDMap[mMessage.RunID] = mMessage.ID
}
}
for _, mMessage := range mListMessages.Messages {
messages = append(messages, c.buildDomainMsg2VOMessage(ctx, mMessage, runToQuestionIDMap))
}
resp := &message.GetMessageListResponse{
MessageList: messages,
Cursor: strconv.FormatInt(mListMessages.PrevCursor, 10),
NextCursor: strconv.FormatInt(mListMessages.NextCursor, 10),
ConversationID: strconv.FormatInt(currentConversation.ID, 10),
LastSectionID: ptr.Of(strconv.FormatInt(currentConversation.SectionID, 10)),
ConnectorConversationID: strconv.FormatInt(currentConversation.ID, 10),
}
if mListMessages.Direction == entity.ScrollPageDirectionPrev {
resp.Hasmore = mListMessages.HasMore
} else {
resp.NextHasMore = mListMessages.HasMore
}
return resp
}
func (c *ConversationApplicationService) buildDomainMsg2VOMessage(ctx context.Context, dm *entity.Message, runToQuestionIDMap map[int64]int64) *message.ChatMessage {
cm := &message.ChatMessage{
MessageID: strconv.FormatInt(dm.ID, 10),
Role: string(dm.Role),
Type: string(dm.MessageType),
Content: dm.Content,
ContentType: string(dm.ContentType),
ReplyID: "0",
SectionID: strconv.FormatInt(dm.SectionID, 10),
ExtraInfo: buildDExt2ApiExt(dm.Ext),
ContentTime: dm.CreatedAt,
Status: "available",
Source: 0,
ReasoningContent: ptr.Of(dm.ReasoningContent),
}
if dm.Status == model.MessageStatusBroken {
cm.BrokenPos = ptr.Of(dm.Position)
}
if dm.ContentType == model.ContentTypeMix && dm.DisplayContent != "" {
cm.Content = dm.DisplayContent
}
if dm.MessageType != model.MessageTypeQuestion {
cm.ReplyID = strconv.FormatInt(runToQuestionIDMap[dm.RunID], 10)
cm.SenderID = ptr.Of(strconv.FormatInt(dm.AgentID, 10))
}
return cm
}
func buildDExt2ApiExt(extra map[string]string) *message.ExtraInfo {
return &message.ExtraInfo{
InputTokens: extra["input_tokens"],
OutputTokens: extra["output_tokens"],
Token: extra["token"],
PluginStatus: extra["plugin_status"],
TimeCost: extra["time_cost"],
WorkflowTokens: extra["workflow_tokens"],
BotState: extra["bot_state"],
PluginRequest: extra["plugin_request"],
ToolName: extra["tool_name"],
Plugin: extra["plugin"],
MockHitInfo: extra["mock_hit_info"],
MessageTitle: extra["message_title"],
StreamPluginRunning: extra["stream_plugin_running"],
ExecuteDisplayName: extra["execute_display_name"],
TaskType: extra["task_type"],
ReferFormat: extra["refer_format"],
}
}
func (c *ConversationApplicationService) DeleteMessage(ctx context.Context, mr *message.DeleteMessageRequest) (*message.DeleteMessageResponse, error) {
resp := new(message.DeleteMessageResponse)
messageInfo, err := c.MessageDomainSVC.GetByID(ctx, mr.MessageID)
if err != nil {
return resp, err
}
if messageInfo == nil {
return resp, errorx.New(errno.ErrConversationMessageNotFound)
}
userID := ctxutil.GetUIDFromCtx(ctx)
if messageInfo.UserID != conv.Int64ToStr(*userID) {
return resp, errorx.New(errno.ErrConversationPermissionCode, errorx.KV("msg", "permission denied"))
}
err = c.AgentRunDomainSVC.Delete(ctx, []int64{messageInfo.RunID})
if err != nil {
return resp, err
}
err = c.MessageDomainSVC.Delete(ctx, &entity.DeleteMeta{
RunIDs: []int64{messageInfo.RunID},
})
if err != nil {
return resp, nil
}
return resp, nil
}
func (c *ConversationApplicationService) BreakMessage(ctx context.Context, mr *message.BreakMessageRequest) (*message.BreakMessageResponse, error) {
resp := new(message.BreakMessageResponse)
err := c.MessageDomainSVC.Broken(ctx, &entity.BrokenMeta{
ID: *mr.AnswerMessageID,
Position: mr.BrokenPos,
})
if err != nil {
return resp, err
}
return resp, nil
}

View File

@@ -0,0 +1,329 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package conversation
import (
"context"
"encoding/json"
"errors"
"io"
"strconv"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
"github.com/coze-dev/coze-studio/backend/api/model/conversation/run"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
saEntity "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
convEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
cmdEntity "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity"
sseImpl "github.com/coze-dev/coze-studio/backend/infra/impl/sse"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/consts"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
func (a *OpenapiAgentRunApplication) OpenapiAgentRun(ctx context.Context, sseSender *sseImpl.SSenderImpl, ar *run.ChatV3Request) error {
apiKeyInfo := ctxutil.GetApiAuthFromCtx(ctx)
creatorID := apiKeyInfo.UserID
connectorID := apiKeyInfo.ConnectorID
if ptr.From(ar.ConnectorID) == consts.WebSDKConnectorID {
connectorID = ptr.From(ar.ConnectorID)
}
agentInfo, caErr := a.checkAgent(ctx, ar, connectorID)
if caErr != nil {
logs.CtxErrorf(ctx, "checkAgent err:%v", caErr)
return caErr
}
conversationData, ccErr := a.checkConversation(ctx, ar, creatorID, connectorID)
if ccErr != nil {
logs.CtxErrorf(ctx, "checkConversation err:%v", ccErr)
return ccErr
}
spaceID := agentInfo.SpaceID
arr, err := a.buildAgentRunRequest(ctx, ar, connectorID, spaceID, conversationData)
if err != nil {
logs.CtxErrorf(ctx, "buildAgentRunRequest err:%v", err)
return err
}
streamer, err := ConversationSVC.AgentRunDomainSVC.AgentRun(ctx, arr)
if err != nil {
return err
}
a.pullStream(ctx, sseSender, streamer)
return nil
}
func (a *OpenapiAgentRunApplication) checkConversation(ctx context.Context, ar *run.ChatV3Request, userID int64, connectorID int64) (*convEntity.Conversation, error) {
var conversationData *convEntity.Conversation
if ptr.From(ar.ConversationID) > 0 {
conData, err := ConversationSVC.ConversationDomainSVC.GetByID(ctx, ptr.From(ar.ConversationID))
if err != nil {
return nil, err
}
conversationData = conData
}
if ptr.From(ar.ConversationID) == 0 || conversationData == nil {
conData, err := ConversationSVC.ConversationDomainSVC.Create(ctx, &convEntity.CreateMeta{
AgentID: ar.BotID,
UserID: userID,
ConnectorID: connectorID,
Scene: common.Scene_SceneOpenApi,
})
if err != nil {
return nil, err
}
if conData == nil {
return nil, errors.New("conversation data is nil")
}
conversationData = conData
ar.ConversationID = ptr.Of(conversationData.ID)
}
if conversationData.CreatorID != userID {
return nil, errors.New("conversation data not match")
}
return conversationData, nil
}
func (a *OpenapiAgentRunApplication) checkAgent(ctx context.Context, ar *run.ChatV3Request, connectorID int64) (*saEntity.SingleAgent, error) {
agentInfo, err := ConversationSVC.appContext.SingleAgentDomainSVC.ObtainAgentByIdentity(ctx, &singleagent.AgentIdentity{
AgentID: ar.BotID,
IsDraft: false,
ConnectorID: connectorID,
})
if err != nil {
return nil, err
}
if agentInfo == nil {
return nil, errors.New("agent info is nil")
}
return agentInfo, nil
}
func (a *OpenapiAgentRunApplication) buildAgentRunRequest(ctx context.Context, ar *run.ChatV3Request, connectorID int64, spaceID int64, conversationData *convEntity.Conversation) (*entity.AgentRunMeta, error) {
shortcutCMDData, err := a.buildTools(ctx, ar.ShortcutCommand)
if err != nil {
return nil, err
}
multiContent, contentType, err := a.buildMultiContent(ctx, ar)
if err != nil {
return nil, err
}
displayContent := a.buildDisplayContent(ctx, ar)
arm := &entity.AgentRunMeta{
ConversationID: ptr.From(ar.ConversationID),
AgentID: ar.BotID,
Content: multiContent,
DisplayContent: displayContent,
SpaceID: spaceID,
UserID: ar.User,
SectionID: conversationData.SectionID,
PreRetrieveTools: shortcutCMDData,
IsDraft: false,
ConnectorID: connectorID,
ContentType: contentType,
Ext: ar.ExtraParams,
}
return arm, nil
}
func (a *OpenapiAgentRunApplication) buildTools(ctx context.Context, shortcmd *run.ShortcutCommandDetail) ([]*entity.Tool, error) {
var ts []*entity.Tool
if shortcmd == nil {
return ts, nil
}
var shortcutCMD *cmdEntity.ShortcutCmd
cmdMeta, err := a.ShortcutDomainSVC.GetByCmdID(ctx, shortcmd.CommandID, 0)
if err != nil {
return nil, err
}
shortcutCMD = cmdMeta
if shortcutCMD != nil {
argBytes, err := json.Marshal(shortcmd.Parameters)
if err == nil {
ts = append(ts, &entity.Tool{
PluginID: shortcutCMD.PluginID,
Arguments: string(argBytes),
ToolName: shortcutCMD.PluginToolName,
ToolID: shortcutCMD.PluginToolID,
Type: agentrun.ToolType(shortcutCMD.ToolType),
})
}
}
return ts, nil
}
func (a *OpenapiAgentRunApplication) buildDisplayContent(_ context.Context, ar *run.ChatV3Request) string {
for _, item := range ar.AdditionalMessages {
if item.ContentType == run.ContentTypeMixApi {
return item.Content
}
}
return ""
}
func (a *OpenapiAgentRunApplication) buildMultiContent(ctx context.Context, ar *run.ChatV3Request) ([]*message.InputMetaData, message.ContentType, error) {
var multiContents []*message.InputMetaData
contentType := message.ContentTypeText
for _, item := range ar.AdditionalMessages {
if item == nil {
continue
}
if item.Role != string(schema.User) {
return nil, contentType, errors.New("role not match")
}
if item.ContentType == run.ContentTypeText {
if item.Content == "" {
continue
}
multiContents = append(multiContents, &message.InputMetaData{
Type: message.InputTypeText,
Text: item.Content,
})
}
if item.ContentType == run.ContentTypeMixApi {
contentType = message.ContentTypeMix
var inputs []*run.AdditionalContent
err := json.Unmarshal([]byte(item.Content), &inputs)
logs.CtxInfof(ctx, "inputs:%v, err:%v", conv.DebugJsonToStr(inputs), err)
if err != nil {
continue
}
for _, one := range inputs {
if one == nil {
continue
}
switch message.InputType(one.Type) {
case message.InputTypeText:
multiContents = append(multiContents, &message.InputMetaData{
Type: message.InputTypeText,
Text: ptr.From(one.Text),
})
case message.InputTypeImage, message.InputTypeFile:
multiContents = append(multiContents, &message.InputMetaData{
Type: message.InputType(one.Type),
FileData: []*message.FileData{
{
Url: one.GetFileURL(),
},
},
})
default:
continue
}
}
}
}
return multiContents, contentType, nil
}
func (a *OpenapiAgentRunApplication) pullStream(ctx context.Context, sseSender *sseImpl.SSenderImpl, streamer *schema.StreamReader[*entity.AgentRunResponse]) {
for {
chunk, recvErr := streamer.Recv()
logs.CtxInfof(ctx, "chunk :%v, err:%v", conv.DebugJsonToStr(chunk), recvErr)
if recvErr != nil {
if errors.Is(recvErr, io.EOF) {
return
}
sseSender.Send(ctx, buildErrorEvent(errno.ErrConversationAgentRunError, recvErr.Error()))
return
}
switch chunk.Event {
case entity.RunEventError:
sseSender.Send(ctx, buildErrorEvent(chunk.Error.Code, chunk.Error.Msg))
case entity.RunEventStreamDone:
sseSender.Send(ctx, buildDoneEvent(string(entity.RunEventStreamDone)))
case entity.RunEventAck:
case entity.RunEventCreated, entity.RunEventCancelled, entity.RunEventInProgress, entity.RunEventFailed, entity.RunEventCompleted:
sseSender.Send(ctx, buildMessageChunkEvent(string(chunk.Event), buildARSM2ApiChatMessage(chunk)))
case entity.RunEventMessageDelta, entity.RunEventMessageCompleted:
sseSender.Send(ctx, buildMessageChunkEvent(string(chunk.Event), buildARSM2ApiMessage(chunk)))
default:
logs.CtxErrorf(ctx, "unknow handler event:%v", chunk.Event)
}
}
}
func buildARSM2ApiMessage(chunk *entity.AgentRunResponse) []byte {
chunkMessageItem := chunk.ChunkMessageItem
chunkMessage := &run.ChatV3MessageDetail{
ID: strconv.FormatInt(chunkMessageItem.ID, 10),
ConversationID: strconv.FormatInt(chunkMessageItem.ConversationID, 10),
BotID: strconv.FormatInt(chunkMessageItem.AgentID, 10),
Role: string(chunkMessageItem.Role),
Type: string(chunkMessageItem.MessageType),
Content: chunkMessageItem.Content,
ContentType: string(chunkMessageItem.ContentType),
MetaData: chunkMessageItem.Ext,
ChatID: strconv.FormatInt(chunkMessageItem.RunID, 10),
ReasoningContent: chunkMessageItem.ReasoningContent,
}
mCM, _ := json.Marshal(chunkMessage)
return mCM
}
func buildARSM2ApiChatMessage(chunk *entity.AgentRunResponse) []byte {
chunkRunItem := chunk.ChunkRunItem
chunkMessage := &run.ChatV3ChatDetail{
ID: chunkRunItem.ID,
ConversationID: chunkRunItem.ConversationID,
BotID: chunkRunItem.AgentID,
Status: string(chunkRunItem.Status),
SectionID: ptr.Of(chunkRunItem.SectionID),
CreatedAt: ptr.Of(int32(chunkRunItem.CreatedAt / 1000)),
CompletedAt: ptr.Of(int32(chunkRunItem.CompletedAt / 1000)),
FailedAt: ptr.Of(int32(chunkRunItem.FailedAt / 1000)),
}
if chunkRunItem.Usage != nil {
chunkMessage.Usage = &run.Usage{
TokenCount: ptr.Of(int32(chunkRunItem.Usage.LlmTotalTokens)),
InputTokens: ptr.Of(int32(chunkRunItem.Usage.LlmPromptTokens)),
OutputTokens: ptr.Of(int32(chunkRunItem.Usage.LlmCompletionTokens)),
}
}
mCM, _ := json.Marshal(chunkMessage)
return mCM
}

View File

@@ -0,0 +1,133 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package conversation
import (
"context"
"strconv"
"github.com/coze-dev/coze-studio/backend/api/model/conversation/message"
"github.com/coze-dev/coze-studio/backend/api/model/conversation/run"
message3 "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
convEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type OpenapiMessageApplication struct{}
var OpenapiMessageApplicationService = new(OpenapiMessageApplication)
func (m *OpenapiMessageApplication) GetApiMessageList(ctx context.Context, mr *message.ListMessageApiRequest) (*message.ListMessageApiResponse, error) {
// Get Conversation ID by agent id & userID & scene
userID := ctxutil.MustGetUIDFromApiAuthCtx(ctx)
currentConversation, err := getConversation(ctx, mr.ConversationID)
if err != nil {
return nil, err
}
if currentConversation == nil {
return nil, errorx.New(errno.ErrConversationNotFound)
}
if currentConversation.CreatorID != userID {
return nil, errorx.New(errno.ErrConversationPermissionCode, errorx.KV("msg", "permission denied"))
}
msgListMeta := &entity.ListMeta{
ConversationID: currentConversation.ID,
AgentID: currentConversation.AgentID,
Limit: int(ptr.From(mr.Limit)),
}
if mr.BeforeID != nil {
msgListMeta.Direction = entity.ScrollPageDirectionPrev
msgListMeta.Cursor = *mr.BeforeID
} else {
msgListMeta.Direction = entity.ScrollPageDirectionNext
msgListMeta.Cursor = ptr.From(mr.AfterID)
}
if mr.Order == nil {
msgListMeta.OrderBy = ptr.Of(message.OrderByDesc)
} else {
msgListMeta.OrderBy = mr.Order
}
mListMessages, err := ConversationSVC.MessageDomainSVC.List(ctx, msgListMeta)
if err != nil {
return nil, err
}
// get agent id
var agentIDs []int64
for _, mOne := range mListMessages.Messages {
agentIDs = append(agentIDs, mOne.AgentID)
}
resp := m.buildMessageListResponse(ctx, mListMessages, currentConversation)
return resp, err
}
func getConversation(ctx context.Context, conversationID int64) (*convEntity.Conversation, error) {
conversationInfo, err := ConversationSVC.ConversationDomainSVC.GetByID(ctx, conversationID)
if err != nil {
return nil, err
}
return conversationInfo, nil
}
func (m *OpenapiMessageApplication) buildMessageListResponse(ctx context.Context, mListMessages *entity.ListResult, currentConversation *convEntity.Conversation) *message.ListMessageApiResponse {
messagesVO := slices.Transform(mListMessages.Messages, func(dm *entity.Message) *message.OpenMessageApi {
content := dm.Content
msg := &message.OpenMessageApi{
ID: dm.ID,
ConversationID: dm.ConversationID,
BotID: dm.AgentID,
Role: string(dm.Role),
Type: string(dm.MessageType),
Content: content,
ContentType: string(dm.ContentType),
SectionID: strconv.FormatInt(dm.SectionID, 10),
CreatedAt: dm.CreatedAt,
UpdatedAt: dm.UpdatedAt,
ChatID: dm.RunID,
MetaData: dm.Ext,
}
if dm.ContentType == message3.ContentTypeMix && dm.DisplayContent != "" {
msg.Content = dm.DisplayContent
msg.ContentType = run.ContentTypeMixApi
}
return msg
})
resp := &message.ListMessageApiResponse{
Messages: messagesVO,
HasMore: ptr.Of(mListMessages.HasMore),
FirstID: ptr.Of(mListMessages.PrevCursor),
LastID: ptr.Of(mListMessages.NextCursor),
}
return resp
}

View File

@@ -0,0 +1,780 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package knowledge
import (
"context"
"encoding/json"
"fmt"
"path"
"strconv"
"strings"
"time"
modelCommon "github.com/coze-dev/coze-studio/backend/api/model/common"
knowledgeModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/api/model/flow/dataengine/dataset"
"github.com/coze-dev/coze-studio/backend/application/upload"
"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
"github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
func assertValAs(typ document.TableColumnType, val string) (*document.ColumnData, error) {
cd := &document.ColumnData{
Type: typ,
}
if val == "" {
return cd, nil
}
switch typ {
case document.TableColumnTypeString:
return &document.ColumnData{
Type: document.TableColumnTypeString,
ValString: &val,
}, nil
case document.TableColumnTypeInteger:
i, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, err
}
return &document.ColumnData{
Type: document.TableColumnTypeInteger,
ValInteger: &i,
}, nil
case document.TableColumnTypeTime:
// 支持时间戳和时间字符串
i, err := strconv.ParseInt(val, 10, 64)
if err == nil {
t := time.Unix(i, 0)
return &document.ColumnData{
Type: document.TableColumnTypeTime,
ValTime: &t,
}, nil
}
t, err := time.Parse(time.DateTime, val)
if err != nil {
return nil, err
}
return &document.ColumnData{
Type: document.TableColumnTypeTime,
ValTime: &t,
}, nil
case document.TableColumnTypeNumber:
f, err := strconv.ParseFloat(val, 64)
if err != nil {
return nil, err
}
return &document.ColumnData{
Type: document.TableColumnTypeNumber,
ValNumber: &f,
}, nil
case document.TableColumnTypeBoolean:
t, err := strconv.ParseBool(val)
if err != nil {
return nil, err
}
return &document.ColumnData{
Type: document.TableColumnTypeBoolean,
ValBoolean: &t,
}, nil
case document.TableColumnTypeImage:
return &document.ColumnData{
Type: document.TableColumnTypeImage,
ValImage: &val,
}, nil
default:
return nil, fmt.Errorf("[assertValAs] type not support, type=%d, val=%s", typ, val)
}
}
func convertTableDataType2Entity(t dataset.TableDataType) service.TableDataType {
switch t {
case dataset.TableDataType_AllData:
return service.AllData
case dataset.TableDataType_OnlySchema:
return service.OnlySchema
case dataset.TableDataType_OnlyPreview:
return service.OnlyPreview
default:
return service.AllData
}
}
func convertTableSheet2Entity(sheet *dataset.TableSheet) *entity.TableSheet {
if sheet == nil {
return nil
}
return &entity.TableSheet{
SheetId: sheet.GetSheetID(),
StartLineIdx: sheet.GetStartLineIdx(),
HeaderLineIdx: sheet.GetHeaderLineIdx(),
}
}
func convertDocTableSheet2Model(sheet entity.TableSheet) *dataset.DocTableSheet {
return &dataset.DocTableSheet{
ID: sheet.SheetId,
SheetName: sheet.SheetName,
TotalRow: sheet.TotalRows,
}
}
func convertTableMeta(t []*entity.TableColumn) []*modelCommon.DocTableColumn {
if len(t) == 0 {
return nil
}
resp := make([]*modelCommon.DocTableColumn, 0)
for i := range t {
if t[i] == nil {
continue
}
resp = append(resp, &modelCommon.DocTableColumn{
ID: t[i].ID,
ColumnName: t[i].Name,
IsSemantic: t[i].Indexing,
Desc: &t[i].Description,
Sequence: t[i].Sequence,
ColumnType: convertColumnType(t[i].Type),
})
}
return resp
}
func convertColumnType(t document.TableColumnType) *modelCommon.ColumnType {
switch t {
case document.TableColumnTypeString:
return modelCommon.ColumnTypePtr(modelCommon.ColumnType_Text)
case document.TableColumnTypeBoolean:
return modelCommon.ColumnTypePtr(modelCommon.ColumnType_Boolean)
case document.TableColumnTypeNumber:
return modelCommon.ColumnTypePtr(modelCommon.ColumnType_Float)
case document.TableColumnTypeTime:
return modelCommon.ColumnTypePtr(modelCommon.ColumnType_Date)
case document.TableColumnTypeInteger:
return modelCommon.ColumnTypePtr(modelCommon.ColumnType_Number)
case document.TableColumnTypeImage:
return modelCommon.ColumnTypePtr(modelCommon.ColumnType_Image)
default:
return modelCommon.ColumnTypePtr(modelCommon.ColumnType_Text)
}
}
func convertDocTableSheet(t *entity.TableSheet) *modelCommon.DocTableSheet {
if t == nil {
return nil
}
return &modelCommon.DocTableSheet{
ID: t.SheetId,
SheetName: t.SheetName,
TotalRow: t.TotalRows,
}
}
func convertSlice2Model(sliceEntity *entity.Slice) *dataset.SliceInfo {
if sliceEntity == nil {
return nil
}
return &dataset.SliceInfo{
SliceID: sliceEntity.ID,
Content: convertSliceContent(sliceEntity),
Status: convertSliceStatus2Model(sliceEntity.SliceStatus),
HitCount: sliceEntity.Hit,
CharCount: sliceEntity.CharCount,
Sequence: sliceEntity.Sequence,
DocumentID: sliceEntity.DocumentID,
ChunkInfo: "",
}
}
func convertSliceContent(s *entity.Slice) string {
if len(s.RawContent) == 0 {
return ""
}
if s.RawContent[0].Type == knowledgeModel.SliceContentTypeTable {
tableData := make([]sliceContentData, 0, len(s.RawContent[0].Table.Columns))
for _, col := range s.RawContent[0].Table.Columns {
tableData = append(tableData, sliceContentData{
ColumnID: strconv.FormatInt(col.ColumnID, 10),
ColumnName: col.ColumnName,
Value: col.GetNullableStringValue(),
Desc: "",
})
}
b, _ := json.Marshal(tableData)
return string(b)
}
return s.GetSliceContent()
}
type sliceContentData struct {
ColumnID string `json:"column_id"`
ColumnName string `json:"column_name"`
Value string `json:"value"`
Desc string `json:"desc"`
}
func convertSliceStatus2Model(status knowledgeModel.SliceStatus) dataset.SliceStatus {
switch status {
case knowledgeModel.SliceStatusInit:
return dataset.SliceStatus_PendingVectoring
case knowledgeModel.SliceStatusFinishStore:
return dataset.SliceStatus_FinishVectoring
case knowledgeModel.SliceStatusFailed:
return dataset.SliceStatus_Deactive
default:
return dataset.SliceStatus_PendingVectoring
}
}
func convertFilterStrategy2Model(strategy *entity.ParsingStrategy) *dataset.FilterStrategy {
if strategy == nil {
return nil
}
if len(strategy.FilterPages) != 0 {
return &dataset.FilterStrategy{
FilterPage: slices.Transform(strategy.FilterPages, func(page int) int32 {
return int32(page)
}),
}
}
return nil
}
func convertDocument2Model(documentEntity *entity.Document) *dataset.DocumentInfo {
if documentEntity == nil {
return nil
}
chunkStrategy := convertChunkingStrategy2Model(documentEntity.ChunkingStrategy)
filterStrategy := convertFilterStrategy2Model(documentEntity.ParsingStrategy)
parseStrategy, _ := convertParsingStrategy2Model(documentEntity.ParsingStrategy)
docInfo := &dataset.DocumentInfo{
Name: documentEntity.Name,
DocumentID: documentEntity.ID,
TosURI: &documentEntity.URI,
CreateTime: int32(documentEntity.CreatedAtMs / 1000),
UpdateTime: int32(documentEntity.UpdatedAtMs / 1000),
CreatorID: ptr.Of(documentEntity.CreatorID),
SliceCount: int32(documentEntity.SliceCount),
Type: string(documentEntity.FileExtension),
Size: int32(documentEntity.Size),
CharCount: int32(documentEntity.CharCount),
Status: convertDocumentStatus2Model(documentEntity.Status),
HitCount: int32(documentEntity.Hits),
SourceType: convertDocumentSource2Model(documentEntity.Source),
FormatType: convertDocumentTypeEntity2Dataset(documentEntity.Type),
WebURL: &documentEntity.URL,
TableMeta: convertTableColumns2Model(documentEntity.TableInfo.Columns),
StatusDescript: &documentEntity.StatusMsg,
SpaceID: ptr.Of(documentEntity.SpaceID),
EditableAppendContent: nil,
FilterStrategy: filterStrategy,
PreviewTosURL: &documentEntity.URL,
ChunkStrategy: chunkStrategy,
ParsingStrategy: parseStrategy,
}
return docInfo
}
func convertDocumentSource2Entity(sourceType dataset.DocumentSource) entity.DocumentSource {
switch sourceType {
case dataset.DocumentSource_Custom:
return entity.DocumentSourceCustom
case dataset.DocumentSource_Document:
return entity.DocumentSourceLocal
default:
return entity.DocumentSourceLocal
}
}
func convertDocumentSource2Model(sourceType entity.DocumentSource) dataset.DocumentSource {
switch sourceType {
case entity.DocumentSourceCustom:
return dataset.DocumentSource_Custom
case entity.DocumentSourceLocal:
return dataset.DocumentSource_Document
default:
return dataset.DocumentSource_Document
}
}
func convertDocumentStatus2Model(status entity.DocumentStatus) dataset.DocumentStatus {
switch status {
case entity.DocumentStatusDeleted:
return dataset.DocumentStatus_Deleted
case entity.DocumentStatusEnable, entity.DocumentStatusInit:
return dataset.DocumentStatus_Enable
case entity.DocumentStatusFailed:
return dataset.DocumentStatus_Failed
default:
return dataset.DocumentStatus_Processing
}
}
func convertTableColumns2Entity(columns []*dataset.TableColumn) []*entity.TableColumn {
if len(columns) == 0 {
return nil
}
columnEntities := make([]*entity.TableColumn, 0, len(columns))
for i := range columns {
columnEntities = append(columnEntities, &entity.TableColumn{
ID: columns[i].GetID(),
Name: columns[i].GetColumnName(),
Type: convertColumnType2Entity(columns[i].GetColumnType()),
Description: columns[i].GetDesc(),
Indexing: columns[i].GetIsSemantic(),
Sequence: columns[i].GetSequence(),
})
}
return columnEntities
}
func convertTableColumns2Model(columns []*entity.TableColumn) []*dataset.TableColumn {
if len(columns) == 0 {
return nil
}
columnModels := make([]*dataset.TableColumn, 0, len(columns))
for i := range columns {
columnType := convertColumnType2Model(columns[i].Type)
columnModels = append(columnModels, &dataset.TableColumn{
ID: columns[i].ID,
ColumnName: columns[i].Name,
ColumnType: &columnType,
Desc: &columns[i].Description,
IsSemantic: columns[i].Indexing,
Sequence: columns[i].Sequence,
})
}
return columnModels
}
func convertTableColumnDataSlice(cols []*entity.TableColumn, data []*document.ColumnData) (map[string]string, error) {
if len(cols) != len(data) {
return nil, fmt.Errorf("[convertTableColumnDataSlice] invalid cols and vals, len(cols)=%d, len(vals)=%d", len(cols), len(data))
}
resp := make(map[string]string, len(data))
for i := range data {
col := cols[i]
val := data[i]
content := ""
if val != nil {
content = val.GetStringValue()
}
resp[strconv.FormatInt(col.Sequence, 10)] = content
}
return resp, nil
}
func convertColumnType2Model(columnType document.TableColumnType) dataset.ColumnType {
switch columnType {
case document.TableColumnTypeString:
return dataset.ColumnType_Text
case document.TableColumnTypeInteger:
return dataset.ColumnType_Number
case document.TableColumnTypeImage:
return dataset.ColumnType_Image
case document.TableColumnTypeBoolean:
return dataset.ColumnType_Boolean
case document.TableColumnTypeTime:
return dataset.ColumnType_Date
case document.TableColumnTypeNumber:
return dataset.ColumnType_Float
default:
return dataset.ColumnType_Text
}
}
func convertColumnType2Entity(columnType dataset.ColumnType) document.TableColumnType {
switch columnType {
case dataset.ColumnType_Text:
return document.TableColumnTypeString
case dataset.ColumnType_Number:
return document.TableColumnTypeInteger
case dataset.ColumnType_Image:
return document.TableColumnTypeImage
case dataset.ColumnType_Boolean:
return document.TableColumnTypeBoolean
case dataset.ColumnType_Date:
return document.TableColumnTypeTime
case dataset.ColumnType_Float:
return document.TableColumnTypeNumber
default:
return document.TableColumnTypeString
}
}
func convertParsingStrategy2Entity(strategy *dataset.ParsingStrategy, sheet *dataset.TableSheet, captionType *dataset.CaptionType, filterStrategy *dataset.FilterStrategy) *entity.ParsingStrategy {
if strategy == nil && sheet == nil && captionType == nil {
return nil
}
res := &entity.ParsingStrategy{}
if strategy != nil {
res.ExtractImage = strategy.GetImageExtraction()
res.ExtractTable = strategy.GetTableExtraction()
res.ImageOCR = strategy.GetImageOcr()
res.ParsingType = convertParsingType2Entity(strategy.GetParsingType())
if strategy.GetParsingType() == dataset.ParsingType_FastParsing {
res.ExtractImage = false
res.ExtractTable = false
res.ImageOCR = false
}
}
if sheet != nil {
res.SheetID = sheet.GetSheetID()
res.HeaderLine = int(sheet.GetHeaderLineIdx())
res.DataStartLine = int(sheet.GetStartLineIdx())
}
if filterStrategy != nil {
res.FilterPages = slices.Transform(filterStrategy.GetFilterPage(), func(page int32) int { return int(page) })
}
res.CaptionType = convertCaptionType2Entity(captionType)
return res
}
func convertParsingType2Entity(pt dataset.ParsingType) entity.ParsingType {
switch pt {
case dataset.ParsingType_AccurateParsing:
return entity.ParsingType_AccurateParsing
case dataset.ParsingType_FastParsing:
return entity.ParsingType_FastParsing
default:
return entity.ParsingType_FastParsing
}
}
func convertParsingStrategy2Model(strategy *entity.ParsingStrategy) (s *dataset.ParsingStrategy, sheet *dataset.TableSheet) {
if strategy == nil {
return nil, nil
}
sheet = &dataset.TableSheet{
SheetID: strategy.SheetID,
HeaderLineIdx: int64(strategy.HeaderLine),
StartLineIdx: int64(strategy.DataStartLine),
}
return &dataset.ParsingStrategy{
ParsingType: ptr.Of(convertParsingType2Model(strategy.ParsingType)),
ImageExtraction: &strategy.ExtractImage,
TableExtraction: &strategy.ExtractTable,
ImageOcr: &strategy.ImageOCR,
}, sheet
}
func convertParsingType2Model(pt entity.ParsingType) dataset.ParsingType {
switch pt {
case entity.ParsingType_AccurateParsing:
return dataset.ParsingType_AccurateParsing
case entity.ParsingType_FastParsing:
return dataset.ParsingType_FastParsing
default:
return dataset.ParsingType_FastParsing
}
}
func convertChunkingStrategy2Entity(strategy *dataset.ChunkStrategy) *entity.ChunkingStrategy {
if strategy == nil {
return nil
}
if strategy.ChunkType == dataset.ChunkType_DefaultChunk {
return &entity.ChunkingStrategy{
ChunkType: convertChunkType2Entity(dataset.ChunkType_DefaultChunk),
}
}
return &entity.ChunkingStrategy{
ChunkType: convertChunkType2Entity(strategy.ChunkType),
ChunkSize: strategy.GetMaxTokens(),
Separator: strategy.GetSeparator(),
Overlap: strategy.GetOverlap(),
TrimSpace: strategy.GetRemoveExtraSpaces(),
TrimURLAndEmail: strategy.GetRemoveUrlsEmails(),
MaxDepth: strategy.GetMaxLevel(),
SaveTitle: strategy.GetSaveTitle(),
}
}
func GetExtension(uri string) string {
if uri == "" {
return ""
}
fileExtension := path.Base(uri)
ext := path.Ext(fileExtension)
if ext != "" {
return strings.TrimPrefix(ext, ".")
}
return ""
}
func convertCaptionType2Entity(ct *dataset.CaptionType) *parser.ImageAnnotationType {
if ct == nil {
return nil
}
switch ptr.From(ct) {
case dataset.CaptionType_Auto:
return ptr.Of(parser.ImageAnnotationTypeModel)
case dataset.CaptionType_Manual:
return ptr.Of(parser.ImageAnnotationTypeManual)
default:
return ptr.Of(parser.ImageAnnotationTypeModel)
}
}
func convertDatasetStatus2Entity(status dataset.DatasetStatus) model.KnowledgeStatus {
switch status {
case dataset.DatasetStatus_DatasetReady:
return model.KnowledgeStatusEnable
case dataset.DatasetStatus_DatasetForbid, dataset.DatasetStatus_DatasetDeleted:
return model.KnowledgeStatusDisable
default:
return model.KnowledgeStatusEnable
}
}
func convertChunkType2model(chunkType parser.ChunkType) dataset.ChunkType {
switch chunkType {
case parser.ChunkTypeCustom:
return dataset.ChunkType_CustomChunk
case parser.ChunkTypeDefault:
return dataset.ChunkType_DefaultChunk
case parser.ChunkTypeLeveled:
return dataset.ChunkType_LevelChunk
default:
return dataset.ChunkType_CustomChunk
}
}
func convertChunkType2Entity(chunkType dataset.ChunkType) parser.ChunkType {
switch chunkType {
case dataset.ChunkType_CustomChunk:
return parser.ChunkTypeCustom
case dataset.ChunkType_DefaultChunk:
return parser.ChunkTypeDefault
case dataset.ChunkType_LevelChunk:
return parser.ChunkTypeLeveled
default:
return parser.ChunkTypeDefault
}
}
func convertChunkingStrategy2Model(chunkingStrategy *entity.ChunkingStrategy) *dataset.ChunkStrategy {
if chunkingStrategy == nil {
return nil
}
return &dataset.ChunkStrategy{
Separator: chunkingStrategy.Separator,
MaxTokens: chunkingStrategy.ChunkSize,
RemoveExtraSpaces: chunkingStrategy.TrimSpace,
RemoveUrlsEmails: chunkingStrategy.TrimURLAndEmail,
ChunkType: convertChunkType2model(chunkingStrategy.ChunkType),
Overlap: &chunkingStrategy.Overlap,
MaxLevel: &chunkingStrategy.MaxDepth,
SaveTitle: &chunkingStrategy.SaveTitle,
}
}
func convertDocumentTypeEntity2Dataset(formatType model.DocumentType) dataset.FormatType {
switch formatType {
case model.DocumentTypeText:
return dataset.FormatType_Text
case model.DocumentTypeTable:
return dataset.FormatType_Table
case model.DocumentTypeImage:
return dataset.FormatType_Image
default:
return dataset.FormatType_Text
}
}
func convertDocumentTypeDataset2Entity(formatType dataset.FormatType) model.DocumentType {
switch formatType {
case dataset.FormatType_Text:
return model.DocumentTypeText
case dataset.FormatType_Table:
return model.DocumentTypeTable
case dataset.FormatType_Image:
return model.DocumentTypeImage
default:
return model.DocumentTypeUnknown
}
}
func batchConvertKnowledgeEntity2Model(ctx context.Context, knowledgeEntity []*model.Knowledge) (map[int64]*dataset.Dataset, error) {
knowledgeMap := map[int64]*dataset.Dataset{}
for _, k := range knowledgeEntity {
documentEntity, err := KnowledgeSVC.DomainSVC.ListDocument(ctx, &service.ListDocumentRequest{
KnowledgeID: k.ID,
SelectAll: true,
})
if err != nil {
logs.CtxErrorf(ctx, "list document failed, err: %v", err)
return nil, err
}
datasetStatus := dataset.DatasetStatus_DatasetReady
if k.Status == model.KnowledgeStatusDisable {
datasetStatus = dataset.DatasetStatus_DatasetForbid
}
var (
rule *entity.ChunkingStrategy
totalSize int64
sliceCount int32
processingFileList []string
processingFileIDList []string
fileList []string
)
for i := range documentEntity.Documents {
doc := documentEntity.Documents[i]
totalSize += doc.Size
sliceCount += int32(doc.SliceCount)
if doc.Status == entity.DocumentStatusChunking || doc.Status == entity.DocumentStatusUploading {
processingFileList = append(processingFileList, doc.Name)
processingFileIDList = append(processingFileIDList, strconv.FormatInt(doc.ID, 10))
}
if i == 0 {
rule = doc.ChunkingStrategy
}
fileList = append(fileList, doc.Name)
}
knowledgeMap[k.ID] = &dataset.Dataset{
DatasetID: k.ID,
Name: k.Name,
FileList: fileList,
AllFileSize: totalSize,
BotUsedCount: 0,
Status: datasetStatus,
ProcessingFileList: processingFileList,
UpdateTime: int32(k.UpdatedAtMs / 1000),
IconURI: k.IconURI,
IconURL: k.IconURL,
Description: k.Description,
CanEdit: true,
CreateTime: int32(k.CreatedAtMs / 1000),
CreatorID: k.CreatorID,
SpaceID: k.SpaceID,
FailedFileList: nil,
FormatType: convertDocumentTypeEntity2Dataset(k.Type),
SliceCount: sliceCount,
DocCount: int32(len(documentEntity.Documents)),
HitCount: int32(k.SliceHit),
ChunkStrategy: convertChunkingStrategy2Model(rule),
ProcessingFileIDList: processingFileIDList,
ProjectID: strconv.FormatInt(k.AppID, 10),
}
}
return knowledgeMap, nil
}
func convertSourceInfo(sourceInfo *dataset.SourceInfo) (*service.TableSourceInfo, error) {
if sourceInfo == nil {
return nil, nil
}
fType := sourceInfo.FileType
if fType == nil && sourceInfo.TosURI != nil {
split := strings.Split(sourceInfo.GetTosURI(), ".")
fType = &split[len(split)-1]
}
var customContent []map[string]string
if sourceInfo.CustomContent != nil {
if err := json.Unmarshal([]byte(sourceInfo.GetCustomContent()), &customContent); err != nil {
return nil, err
}
}
return &service.TableSourceInfo{
FileType: fType,
Uri: sourceInfo.TosURI,
FileBase64: sourceInfo.FileBase64,
CustomContent: customContent,
}, nil
}
func convertCreateDocReviewReq(req *dataset.CreateDocumentReviewRequest) *service.CreateDocumentReviewRequest {
if req == nil {
return nil
}
var captionType *dataset.CaptionType
if req.GetChunkStrategy() != nil {
captionType = req.GetChunkStrategy().CaptionType
}
resp := &service.CreateDocumentReviewRequest{
ChunkStrategy: convertChunkingStrategy2Entity(req.ChunkStrategy),
ParsingStrategy: convertParsingStrategy2Entity(req.ParsingStrategy, nil, captionType, nil),
}
resp.KnowledgeID = req.GetDatasetID()
resp.Reviews = slices.Transform(req.GetReviews(), func(r *dataset.ReviewInput) *service.ReviewInput {
return &service.ReviewInput{
DocumentName: r.GetDocumentName(),
DocumentType: r.GetDocumentType(),
TosUri: r.GetTosURI(),
DocumentID: ptr.Of(r.GetDocumentID()),
}
})
return resp
}
func convertReviewStatus2Model(status *entity.ReviewStatus) *dataset.ReviewStatus {
if status == nil {
return nil
}
switch *status {
case entity.ReviewStatus_Enable:
return dataset.ReviewStatusPtr(dataset.ReviewStatus_Enable)
case entity.ReviewStatus_Processing:
return dataset.ReviewStatusPtr(dataset.ReviewStatus_Processing)
case entity.ReviewStatus_Failed:
return dataset.ReviewStatusPtr(dataset.ReviewStatus_Failed)
case entity.ReviewStatus_ForceStop:
return dataset.ReviewStatusPtr(dataset.ReviewStatus_ForceStop)
default:
return dataset.ReviewStatusPtr(dataset.ReviewStatus_Processing)
}
}
func getIconURI(tp dataset.FormatType) string {
switch tp {
case dataset.FormatType_Text:
return upload.TextKnowledgeDefaultIcon
case dataset.FormatType_Table:
return upload.TableKnowledgeDefaultIcon
case dataset.FormatType_Image:
return upload.ImageKnowledgeDefaultIcon
default:
return upload.TextKnowledgeDefaultIcon
}
}
func convertFormatType2Entity(tp dataset.FormatType) model.DocumentType {
switch tp {
case dataset.FormatType_Text:
return model.DocumentTypeText
case dataset.FormatType_Table:
return model.DocumentTypeTable
case dataset.FormatType_Image:
return model.DocumentTypeImage
default:
return model.DocumentTypeUnknown
}
}

View File

@@ -0,0 +1,445 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package knowledge
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strconv"
"time"
"github.com/cloudwego/eino-ext/components/embedding/ark"
"github.com/cloudwego/eino-ext/components/embedding/openai"
ao "github.com/cloudwego/eino-ext/components/model/ark"
"github.com/cloudwego/eino-ext/components/model/deepseek"
"github.com/cloudwego/eino-ext/components/model/gemini"
"github.com/cloudwego/eino-ext/components/model/ollama"
mo "github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino-ext/components/model/qwen"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/schema"
"github.com/milvus-io/milvus/client/v2/milvusclient"
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
"github.com/volcengine/volc-sdk-golang/service/visual"
"google.golang.org/genai"
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/application/search"
knowledgeImpl "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/nl2sql"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
"github.com/coze-dev/coze-studio/backend/infra/contract/es"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
"github.com/coze-dev/coze-studio/backend/infra/contract/messages2query"
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
chatmodelImpl "github.com/coze-dev/coze-studio/backend/infra/impl/chatmodel"
builtinNL2SQL "github.com/coze-dev/coze-studio/backend/infra/impl/document/nl2sql/builtin"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/veocr"
builtinParser "github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf"
sses "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/elasticsearch"
ssmilvus "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/milvus"
ssvikingdb "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/vikingdb"
arkemb "github.com/coze-dev/coze-studio/backend/infra/impl/embedding/ark"
"github.com/coze-dev/coze-studio/backend/infra/impl/embedding/wrap"
"github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
builtinM2Q "github.com/coze-dev/coze-studio/backend/infra/impl/messages2query/builtin"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/consts"
)
type ServiceComponents struct {
DB *gorm.DB
IDGenSVC idgen.IDGenerator
Storage storage.Storage
RDB rdb.RDB
ImageX imagex.ImageX
ES es.Client
EventBus search.ResourceEventBus
CacheCli cache.Cmdable
}
func InitService(c *ServiceComponents) (*KnowledgeApplicationService, error) {
ctx := context.Background()
nameServer := os.Getenv(consts.MQServer)
knowledgeProducer, err := eventbus.NewProducer(nameServer, consts.RMQTopicKnowledge, consts.RMQConsumeGroupKnowledge, 2)
if err != nil {
return nil, fmt.Errorf("init knowledge producer failed, err=%w", err)
}
var sManagers []searchstore.Manager
// es full text search
sManagers = append(sManagers, sses.NewManager(&sses.ManagerConfig{Client: c.ES}))
// vector search
mgr, err := getVectorStore(ctx)
if err != nil {
return nil, fmt.Errorf("init vector store failed, err=%w", err)
}
sManagers = append(sManagers, mgr)
var ocrImpl ocr.OCR
switch os.Getenv("OCR_TYPE") {
case "ve":
ocrAK := os.Getenv("VE_OCR_AK")
ocrSK := os.Getenv("VE_OCR_SK")
inst := visual.NewInstance()
inst.Client.SetAccessKey(ocrAK)
inst.Client.SetSecretKey(ocrSK)
ocrImpl = veocr.NewOCR(&veocr.Config{Client: inst})
default:
// accept ocr not configured
}
root, err := os.Getwd()
if err != nil {
logs.Warnf("[InitConfig] Failed to get current working directory: %v", err)
root = os.Getenv("PWD")
}
var rewriter messages2query.MessagesToQuery
if rewriterChatModel, _, err := getBuiltinChatModel(ctx, "M2Q_"); err != nil {
return nil, err
} else {
filePath := filepath.Join(root, "resources/conf/prompt/messages_to_query_template_jinja2.json")
rewriterTemplate, err := readJinja2PromptTemplate(filePath)
if err != nil {
return nil, err
}
rewriter, err = builtinM2Q.NewMessagesToQuery(ctx, rewriterChatModel, rewriterTemplate)
if err != nil {
return nil, err
}
}
var n2s nl2sql.NL2SQL
if n2sChatModel, _, err := getBuiltinChatModel(ctx, "NL2SQL_"); err != nil {
return nil, err
} else {
filePath := filepath.Join(root, "resources/conf/prompt/nl2sql_template_jinja2.json")
n2sTemplate, err := readJinja2PromptTemplate(filePath)
if err != nil {
return nil, err
}
n2s, err = builtinNL2SQL.NewNL2SQL(ctx, n2sChatModel, n2sTemplate)
if err != nil {
return nil, err
}
}
imageAnnoChatModel, configured, err := getBuiltinChatModel(ctx, "IA_")
if err != nil {
return nil, err
}
knowledgeDomainSVC, knowledgeEventHandler := knowledgeImpl.NewKnowledgeSVC(&knowledgeImpl.KnowledgeSVCConfig{
DB: c.DB,
IDGen: c.IDGenSVC,
RDB: c.RDB,
Producer: knowledgeProducer,
SearchStoreManagers: sManagers,
ParseManager: builtinParser.NewManager(c.Storage, ocrImpl, imageAnnoChatModel), // default builtin
Storage: c.Storage,
Rewriter: rewriter,
Reranker: rrf.NewRRFReranker(0), // default rrf
NL2Sql: n2s,
OCR: ocrImpl,
CacheCli: c.CacheCli,
IsAutoAnnotationSupported: configured,
ModelFactory: chatmodelImpl.NewDefaultFactory(),
})
if err = eventbus.RegisterConsumer(nameServer, consts.RMQTopicKnowledge, consts.RMQConsumeGroupKnowledge, knowledgeEventHandler); err != nil {
return nil, fmt.Errorf("register knowledge consumer failed, err=%w", err)
}
KnowledgeSVC.DomainSVC = knowledgeDomainSVC
KnowledgeSVC.eventBus = c.EventBus
KnowledgeSVC.storage = c.Storage
return KnowledgeSVC, nil
}
func getVectorStore(ctx context.Context) (searchstore.Manager, error) {
vsType := os.Getenv("VECTOR_STORE_TYPE")
switch vsType {
case "milvus":
cctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
milvusAddr := os.Getenv("MILVUS_ADDR")
mc, err := milvusclient.New(cctx, &milvusclient.ClientConfig{Address: milvusAddr})
if err != nil {
return nil, fmt.Errorf("init milvus client failed, err=%w", err)
}
emb, err := getEmbedding(ctx)
if err != nil {
return nil, fmt.Errorf("init milvus embedding failed, err=%w", err)
}
mgr, err := ssmilvus.NewManager(&ssmilvus.ManagerConfig{
Client: mc,
Embedding: emb,
EnableHybrid: ptr.Of(true),
})
if err != nil {
return nil, fmt.Errorf("init milvus vector store failed, err=%w", err)
}
return mgr, nil
case "vikingdb":
var (
host = os.Getenv("VIKING_DB_HOST")
region = os.Getenv("VIKING_DB_REGION")
ak = os.Getenv("VIKING_DB_AK")
sk = os.Getenv("VIKING_DB_SK")
scheme = os.Getenv("VIKING_DB_SCHEME")
modelName = os.Getenv("VIKING_DB_MODEL_NAME")
)
if ak == "" || sk == "" {
return nil, fmt.Errorf("invalid vikingdb ak / sk")
}
if host == "" {
host = "api-vikingdb.volces.com"
}
if region == "" {
region = "cn-beijing"
}
if scheme == "" {
scheme = "https"
}
var embConfig *ssvikingdb.VikingEmbeddingConfig
if modelName != "" {
embName := ssvikingdb.VikingEmbeddingModelName(modelName)
if embName.Dimensions() == 0 {
return nil, fmt.Errorf("embedding model not support, model_name=%s", modelName)
}
embConfig = &ssvikingdb.VikingEmbeddingConfig{
UseVikingEmbedding: true,
EnableHybrid: embName.SupportStatus() == embedding.SupportDenseAndSparse,
ModelName: embName,
ModelVersion: embName.ModelVersion(),
DenseWeight: ptr.Of(0.2),
BuiltinEmbedding: nil,
}
} else {
builtinEmbedding, err := getEmbedding(ctx)
if err != nil {
return nil, fmt.Errorf("builtint embedding init failed, err=%w", err)
}
embConfig = &ssvikingdb.VikingEmbeddingConfig{
UseVikingEmbedding: false,
EnableHybrid: false,
BuiltinEmbedding: builtinEmbedding,
}
}
svc := vikingdb.NewVikingDBService(host, region, ak, sk, scheme)
mgr, err := ssvikingdb.NewManager(&ssvikingdb.ManagerConfig{
Service: svc,
IndexingConfig: nil, // use default config
EmbeddingConfig: embConfig,
})
if err != nil {
return nil, fmt.Errorf("init vikingdb manager failed, err=%w", err)
}
return mgr, nil
default:
return nil, fmt.Errorf("unexpected vector store type, type=%s", vsType)
}
}
func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
var emb embedding.Embedder
switch os.Getenv("EMBEDDING_TYPE") {
case "openai":
var (
openAIEmbeddingBaseURL = os.Getenv("OPENAI_EMBEDDING_BASE_URL")
openAIEmbeddingModel = os.Getenv("OPENAI_EMBEDDING_MODEL")
openAIEmbeddingApiKey = os.Getenv("OPENAI_EMBEDDING_API_KEY")
openAIEmbeddingByAzure = os.Getenv("OPENAI_EMBEDDING_BY_AZURE")
openAIEmbeddingApiVersion = os.Getenv("OPENAI_EMBEDDING_API_VERSION")
openAIEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_DIMS")
openAIRequestEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_REQUEST_DIMS")
)
byAzure, err := strconv.ParseBool(openAIEmbeddingByAzure)
if err != nil {
return nil, fmt.Errorf("init openai embedding by_azure failed, err=%w", err)
}
dims, err := strconv.ParseInt(openAIEmbeddingDims, 10, 64)
if err != nil {
return nil, fmt.Errorf("init openai embedding dims failed, err=%w", err)
}
openAICfg := &openai.EmbeddingConfig{
APIKey: openAIEmbeddingApiKey,
ByAzure: byAzure,
BaseURL: openAIEmbeddingBaseURL,
APIVersion: openAIEmbeddingApiVersion,
Model: openAIEmbeddingModel,
// Dimensions: ptr.Of(int(dims)),
}
reqDims := conv.StrToInt64D(openAIRequestEmbeddingDims, 0)
if reqDims > 0 {
// some openai model not support request dims
openAICfg.Dimensions = ptr.Of(int(reqDims))
}
emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, dims)
if err != nil {
return nil, fmt.Errorf("init openai embedding failed, err=%w", err)
}
case "ark":
var (
arkEmbeddingBaseURL = os.Getenv("ARK_EMBEDDING_BASE_URL")
arkEmbeddingModel = os.Getenv("ARK_EMBEDDING_MODEL")
arkEmbeddingAK = os.Getenv("ARK_EMBEDDING_AK")
arkEmbeddingDims = os.Getenv("ARK_EMBEDDING_DIMS")
)
dims, err := strconv.ParseInt(arkEmbeddingDims, 10, 64)
if err != nil {
return nil, fmt.Errorf("init ark embedding dims failed, err=%w", err)
}
emb, err = arkemb.NewArkEmbedder(ctx, &ark.EmbeddingConfig{
APIKey: arkEmbeddingAK,
Model: arkEmbeddingModel,
BaseURL: arkEmbeddingBaseURL,
}, dims)
if err != nil {
return nil, fmt.Errorf("init ark embedding client failed, err=%w", err)
}
default:
return nil, fmt.Errorf("init knowledge embedding failed, type not configured")
}
return emb, nil
}
func getBuiltinChatModel(ctx context.Context, envPrefix string) (bcm chatmodel.BaseChatModel, configured bool, err error) {
getEnv := func(key string) string {
if val := os.Getenv(envPrefix + key); val != "" {
return val
}
return os.Getenv(key)
}
switch getEnv("BUILTIN_CM_TYPE") {
case "openai":
byAzure, _ := strconv.ParseBool(getEnv("BUILTIN_CM_OPENAI_BY_AZURE"))
bcm, err = mo.NewChatModel(ctx, &mo.ChatModelConfig{
APIKey: getEnv("BUILTIN_CM_OPENAI_API_KEY"),
ByAzure: byAzure,
BaseURL: getEnv("BUILTIN_CM_OPENAI_BASE_URL"),
Model: getEnv("BUILTIN_CM_OPENAI_MODEL"),
})
case "ark":
bcm, err = ao.NewChatModel(ctx, &ao.ChatModelConfig{
APIKey: getEnv("BUILTIN_CM_ARK_API_KEY"),
Model: getEnv("BUILTIN_CM_ARK_MODEL"),
BaseURL: getEnv("BUILTIN_CM_ARK_BASE_URL"),
})
case "deepseek":
bcm, err = deepseek.NewChatModel(ctx, &deepseek.ChatModelConfig{
APIKey: getEnv("BUILTIN_CM_DEEPSEEK_API_KEY"),
BaseURL: getEnv("BUILTIN_CM_DEEPSEEK_BASE_URL"),
Model: getEnv("BUILTIN_CM_DEEPSEEK_MODEL"),
})
case "ollama":
bcm, err = ollama.NewChatModel(ctx, &ollama.ChatModelConfig{
BaseURL: getEnv("BUILTIN_CM_OLLAMA_BASE_URL"),
Model: getEnv("BUILTIN_CM_OLLAMA_MODEL"),
})
case "qwen":
bcm, err = qwen.NewChatModel(ctx, &qwen.ChatModelConfig{
APIKey: getEnv("BUILTIN_CM_QWEN_API_KEY"),
BaseURL: getEnv("BUILTIN_CM_QWEN_BASE_URL"),
Model: getEnv("BUILTIN_CM_QWEN_MODEL"),
})
case "gemini":
backend, convErr := strconv.ParseInt(getEnv("BUILTIN_CM_GEMINI_BACKEND"), 10, 64)
if convErr != nil {
return nil, false, convErr
}
c, clientErr := genai.NewClient(ctx, &genai.ClientConfig{
APIKey: getEnv("BUILTIN_CM_GEMINI_API_KEY"),
Backend: genai.Backend(backend),
Project: getEnv("BUILTIN_CM_GEMINI_PROJECT"),
Location: getEnv("BUILTIN_CM_GEMINI_LOCATION"),
HTTPOptions: genai.HTTPOptions{
BaseURL: getEnv("BUILTIN_CM_GEMINI_BASE_URL"),
},
})
if clientErr != nil {
return nil, false, clientErr
}
bcm, err = gemini.NewChatModel(ctx, &gemini.Config{
Client: c,
Model: getEnv("BUILTIN_CM_GEMINI_MODEL"),
})
default:
// accept builtin chat model not configured
}
if err != nil {
return nil, false, fmt.Errorf("knowledge init openai chat mode failed, %w", err)
}
if bcm != nil {
configured = true
}
return
}
func readJinja2PromptTemplate(jsonFilePath string) (prompt.ChatTemplate, error) {
b, err := os.ReadFile(jsonFilePath)
if err != nil {
return nil, err
}
var m2qMessages []*schema.Message
if err = json.Unmarshal(b, &m2qMessages); err != nil {
return nil, err
}
tpl := make([]schema.MessagesTemplate, len(m2qMessages))
for i := range m2qMessages {
tpl[i] = m2qMessages[i]
}
return prompt.FromMessages(schema.Jinja2, tpl...), nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,288 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package memory
import (
"fmt"
"strconv"
"strings"
"github.com/coze-dev/coze-studio/backend/api/model/base"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/api/model/table"
"github.com/coze-dev/coze-studio/backend/domain/memory/database/entity"
database "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
)
func convertAddDatabase(req *table.AddDatabaseRequest) *database.CreateDatabaseRequest {
fieldItems := make([]*model.FieldItem, 0, len(req.FieldList))
for _, field := range req.FieldList {
fieldItems = append(fieldItems, &model.FieldItem{
Name: field.Name,
Desc: field.Desc,
Type: field.Type,
MustRequired: field.MustRequired,
})
}
return &database.CreateDatabaseRequest{
Database: &entity.Database{
IconURI: req.IconURI,
CreatorID: req.CreatorID,
SpaceID: req.SpaceID,
AppID: req.ProjectID,
TableName: req.TableName,
TableDesc: req.TableDesc,
FieldList: fieldItems,
RwMode: req.RwMode,
PromptDisabled: req.PromptDisabled,
ExtraInfo: req.ExtraInfo,
},
}
}
func ConvertDatabaseRes(res *entity.Database) *table.SingleDatabaseResponse {
return &table.SingleDatabaseResponse{
DatabaseInfo: convertDatabaseRes(res),
Code: 0,
Msg: "success",
BaseResp: &base.BaseResp{
StatusCode: 0,
StatusMessage: "success",
},
}
}
// ConvertUpdateDatabase converts the API update request to domain request
func ConvertUpdateDatabase(req *table.UpdateDatabaseRequest) *database.UpdateDatabaseRequest {
fieldItems := make([]*model.FieldItem, 0, len(req.FieldList))
for _, field := range req.FieldList {
fieldItems = append(fieldItems, &model.FieldItem{
Name: field.Name,
Desc: field.Desc,
AlterID: field.AlterId,
Type: field.Type,
MustRequired: field.MustRequired,
})
}
return &database.UpdateDatabaseRequest{
Database: &entity.Database{
ID: req.ID,
IconURI: req.IconURI,
TableName: req.TableName,
TableDesc: req.TableDesc,
FieldList: fieldItems,
RwMode: req.RwMode,
PromptDisabled: req.PromptDisabled,
ExtraInfo: req.ExtraInfo,
},
}
}
// convertUpdateDatabaseResult converts the domain update response to API response
func convertUpdateDatabaseResult(res *database.UpdateDatabaseResponse) *table.SingleDatabaseResponse {
return &table.SingleDatabaseResponse{
DatabaseInfo: convertDatabaseRes(res.Database),
Code: 0,
Msg: "success",
BaseResp: &base.BaseResp{
StatusCode: 0,
StatusMessage: "success",
},
}
}
func convertDatabaseRes(db *entity.Database) *table.DatabaseInfo {
fieldItems := make([]*table.FieldItem, 0, len(db.FieldList))
for _, field := range db.FieldList {
fieldItems = append(fieldItems, &table.FieldItem{
Name: field.Name,
Desc: field.Desc,
Type: field.Type,
MustRequired: field.MustRequired,
AlterId: field.AlterID,
IsSystemField: field.IsSystemField,
})
}
return &table.DatabaseInfo{
ID: db.ID,
SpaceID: db.SpaceID,
ProjectID: db.AppID,
IconURI: db.IconURI,
IconURL: db.IconURL,
TableName: db.TableName,
TableDesc: db.TableDesc,
Status: db.Status,
CreatorID: db.CreatorID,
CreateTime: db.CreatedAtMs,
UpdateTime: db.UpdatedAtMs,
FieldList: fieldItems,
ActualTableName: db.ActualTableName,
RwMode: table.BotTableRWMode(db.RwMode),
PromptDisabled: db.PromptDisabled,
IsVisible: db.IsVisible,
DraftID: db.DraftID,
ExtraInfo: db.ExtraInfo,
IsAddedToBot: db.IsAddedToAgent,
DatamodelTableID: getDataModelTableID(db.ActualTableName),
}
}
// convertListDatabase converts the API list request to domain request
func convertListDatabase(req *table.ListDatabaseRequest) *database.ListDatabaseRequest {
dRes := &database.ListDatabaseRequest{
SpaceID: req.SpaceID,
TableName: req.TableName,
TableType: req.TableType,
AppID: req.GetProjectID(),
Limit: int(req.GetLimit()),
Offset: int(req.GetOffset()),
}
if req.CreatorID != nil && *req.CreatorID != 0 {
dRes.CreatorID = req.CreatorID
}
if len(req.OrderBy) > 0 {
dRes.OrderBy = make([]*model.OrderBy, len(req.OrderBy))
for i, order := range req.OrderBy {
dRes.OrderBy[i] = &model.OrderBy{
Field: order.Field,
Direction: order.Direction,
}
}
}
return dRes
}
// convertListDatabaseRes converts the domain list response to API response
func convertListDatabaseRes(res *database.ListDatabaseResponse, bindDatabases []*entity.Database) *table.ListDatabaseResponse {
databaseInfos := make([]*table.DatabaseInfo, 0, len(res.Databases))
dbMap := slices.ToMap(bindDatabases, func(e *entity.Database) (int64, *entity.Database) {
return e.ID, e
})
for _, db := range res.Databases {
databaseInfo := convertDatabaseRes(db)
if _, ok := dbMap[db.ID]; ok {
databaseInfo.IsAddedToBot = ptr.Of(true)
}
databaseInfos = append(databaseInfos, databaseInfo)
}
return &table.ListDatabaseResponse{
DatabaseInfoList: databaseInfos,
TotalCount: res.TotalCount,
Code: 0,
Msg: "success",
BaseResp: &base.BaseResp{
StatusCode: 0,
StatusMessage: "success",
},
}
}
// convertListDatabaseRecordsRes converts domain ListDatabaseRecordResponse to API ListDatabaseRecordsResponse
func convertListDatabaseRecordsRes(res *database.ListDatabaseRecordResponse) *table.ListDatabaseRecordsResponse {
apiRes := &table.ListDatabaseRecordsResponse{
Data: res.Records,
TotalNum: int32(res.TotalCount),
HasMore: res.HasMore,
FieldList: make([]*table.FieldItem, 0, len(res.FieldList)),
Code: 0,
Msg: "success",
BaseResp: &base.BaseResp{
StatusCode: 0,
StatusMessage: "success",
},
}
for _, field := range res.FieldList {
apiRes.FieldList = append(apiRes.FieldList, &table.FieldItem{
Name: field.Name,
Desc: field.Desc,
Type: field.Type,
MustRequired: field.MustRequired,
})
}
return apiRes
}
func getDataModelTableID(actualTableName string) string {
tableID := ""
tableIDStr := strings.Split(actualTableName, "_")
if len(tableIDStr) < 2 {
return tableID
}
return tableIDStr[1]
}
func convertToBotTableList(databases []*entity.Database, agentID int64, relationMap map[int64]*model.AgentToDatabase) []*table.BotTable {
if len(databases) == 0 {
return []*table.BotTable{}
}
botTables := make([]*table.BotTable, 0, len(databases))
for _, db := range databases {
fieldItems := make([]*table.FieldItem, 0, len(db.FieldList))
for _, field := range db.FieldList {
fieldItems = append(fieldItems, &table.FieldItem{
Name: field.Name,
Desc: field.Desc,
Type: field.Type,
MustRequired: field.MustRequired,
AlterId: field.AlterID,
IsSystemField: field.IsSystemField,
})
}
botTable := &table.BotTable{
ID: db.ID,
BotID: agentID,
TableID: strconv.FormatInt(db.ID, 10),
TableName: db.TableName,
TableDesc: db.TableDesc,
Status: table.BotTableStatus(db.Status),
CreatorID: db.CreatorID,
CreateTime: db.CreatedAtMs,
UpdateTime: db.UpdatedAtMs,
FieldList: fieldItems,
ActualTableName: db.ActualTableName,
RwMode: table.BotTableRWMode(db.RwMode),
}
if r, ok := relationMap[db.ID]; ok {
botTable.ExtraInfo = map[string]string{
"prompt_disabled": fmt.Sprintf("%t", r.PromptDisabled),
}
}
botTables = append(botTables, botTable)
}
return botTables
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,64 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package memory
import (
"gorm.io/gorm"
"github.com/redis/go-redis/v9"
database "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
"github.com/coze-dev/coze-studio/backend/domain/memory/variables/repository"
variables "github.com/coze-dev/coze-studio/backend/domain/memory/variables/service"
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
rdbService "github.com/coze-dev/coze-studio/backend/infra/impl/rdb"
)
type MemoryApplicationServices struct {
VariablesDomainSVC variables.Variables
DatabaseDomainSVC database.Database
RDBDomainSVC rdb.RDB
}
type ServiceComponents struct {
IDGen idgen.IDGenerator
DB *gorm.DB
EventBus search.ResourceEventBus
TosClient storage.Storage
ResourceDomainNotifier search.ResourceEventBus
CacheCli *redis.Client
}
func InitService(c *ServiceComponents) *MemoryApplicationServices {
repo := repository.NewVariableRepo(c.DB, c.IDGen)
variablesDomainSVC := variables.NewService(repo)
rdbSVC := rdbService.NewService(c.DB, c.IDGen)
databaseDomainSVC := database.NewService(rdbSVC, c.DB, c.IDGen, c.TosClient, c.CacheCli)
VariableApplicationSVC.DomainSVC = variablesDomainSVC
DatabaseApplicationSVC.DomainSVC = databaseDomainSVC
DatabaseApplicationSVC.eventbus = c.ResourceDomainNotifier
return &MemoryApplicationServices{
VariablesDomainSVC: variablesDomainSVC,
DatabaseDomainSVC: databaseDomainSVC,
RDBDomainSVC: rdbSVC,
}
}

View File

@@ -0,0 +1,385 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package memory
import (
"context"
"encoding/json"
"fmt"
"strconv"
"github.com/coze-dev/coze-studio/backend/api/model/base"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/variables"
"github.com/coze-dev/coze-studio/backend/api/model/kvmemory"
"github.com/coze-dev/coze-studio/backend/api/model/project_memory"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
"github.com/coze-dev/coze-studio/backend/domain/memory/variables/entity"
variables "github.com/coze-dev/coze-studio/backend/domain/memory/variables/service"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/i18n"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/consts"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type VariableApplicationService struct {
DomainSVC variables.Variables
}
var VariableApplicationSVC = VariableApplicationService{}
var i18nLocal2GroupVariableInfo = map[i18n.Locale]map[project_memory.VariableChannel]project_memory.GroupVariableInfo{
i18n.LocaleEN: {
project_memory.VariableChannel_APP: {
GroupName: "App variable",
GroupDesc: "Configures data accessed across multiple development scenarios in the app. It is initialized to a default value each time a new request is sent.",
},
project_memory.VariableChannel_Custom: {
GroupName: "User variable",
GroupDesc: "Persistently stores and reads project date for users, such as the preferred language and custom settings.",
},
project_memory.VariableChannel_System: {
GroupName: "System variable",
GroupDesc: "Displays the data that you enabled as needed, which can be used to identify users via IDs or handle channel-specific features. The data is automatically generated and is read-only.",
},
},
}
var channel2GroupVariableInfo = map[project_memory.VariableChannel]project_memory.GroupVariableInfo{
project_memory.VariableChannel_APP: {
GroupName: "应用变量",
GroupDesc: "用于配置应用中多处开发场景需要访问的数据,每次新请求均会初始化为默认值。",
GroupExtDesc: "",
IsReadOnly: false,
SubGroupList: []*project_memory.GroupVariableInfo{},
VarInfoList: []*project_memory.Variable{},
},
project_memory.VariableChannel_Custom: {
GroupName: "用户变量",
GroupDesc: "用于存储每个用户使用项目过程中,需要持久化存储和读取的数据,如用户的语言偏好、个性化设置等。",
GroupExtDesc: "",
IsReadOnly: false,
SubGroupList: []*project_memory.GroupVariableInfo{},
VarInfoList: []*project_memory.Variable{},
},
project_memory.VariableChannel_System: {
GroupName: "系统变量",
GroupDesc: "可选择开启你需要获取的系统在用户在请求自动产生的数据仅可读不可修改。如用于通过ID识别用户或处理某些渠道特有的功能。",
GroupExtDesc: "",
IsReadOnly: true,
SubGroupList: []*project_memory.GroupVariableInfo{},
VarInfoList: []*project_memory.Variable{},
},
}
func (v *VariableApplicationService) GetSysVariableConf(ctx context.Context, req *kvmemory.GetSysVariableConfRequest) (*kvmemory.GetSysVariableConfResponse, error) {
vars := v.DomainSVC.GetSysVariableConf(ctx)
return &kvmemory.GetSysVariableConfResponse{
Conf: vars,
GroupConf: vars.GroupByName(),
}, nil
}
func (v *VariableApplicationService) GetProjectVariablesMeta(ctx context.Context, appOwnerID int64, req *project_memory.GetProjectVariableListReq) (*project_memory.GetProjectVariableListResp, error) {
uid := ctxutil.GetUIDFromCtx(ctx)
if uid == nil {
return nil, errorx.New(errno.ErrMemoryPermissionCode, errorx.KV("msg", "session required"))
}
version := ""
if req.Version != 0 {
version = fmt.Sprintf("%d", req.Version)
}
meta, err := v.DomainSVC.GetProjectVariablesMeta(ctx, req.ProjectID, version)
if err != nil {
return nil, err
}
groupConf, err := v.toGroupVariableInfo(ctx, meta)
if err != nil {
return nil, err
}
return &project_memory.GetProjectVariableListResp{
VariableList: meta.ToProjectVariables(),
GroupConf: groupConf,
CanEdit: appOwnerID == *uid,
}, nil
}
func (v *VariableApplicationService) getGroupVariableConf(ctx context.Context, channel project_memory.VariableChannel) project_memory.GroupVariableInfo {
groupConf, ok := channel2GroupVariableInfo[channel]
if !ok {
return project_memory.GroupVariableInfo{}
}
local := i18n.GetLocale(ctx)
i18nConf, ok := i18nLocal2GroupVariableInfo[local][channel]
if ok {
groupConf.GroupName = i18nConf.GroupName
groupConf.GroupDesc = i18nConf.GroupDesc
}
return groupConf
}
func (v *VariableApplicationService) toGroupVariableInfo(ctx context.Context, meta *entity.VariablesMeta) ([]*project_memory.GroupVariableInfo, error) {
channel2Vars := meta.GroupByChannel()
groupConfList := make([]*project_memory.GroupVariableInfo, 0, len(channel2Vars))
showChannels := []project_memory.VariableChannel{
project_memory.VariableChannel_APP,
project_memory.VariableChannel_Custom,
project_memory.VariableChannel_System,
}
for _, channel := range showChannels {
ch := channel
vars := channel2Vars[ch]
groupConf := v.getGroupVariableConf(ctx, ch)
groupConf.DefaultChannel = &ch
if channel != project_memory.VariableChannel_System {
groupConf.VarInfoList = vars
groupConfList = append(groupConfList, &groupConf)
continue
}
key2Var := make(map[string]*project_memory.Variable)
for _, v := range vars {
key2Var[v.Keyword] = v
}
// project_memory.VariableChannel_System
sysVars := v.DomainSVC.GetSysVariableConf(ctx).RemoveLocalChannelVariable()
groupName2Group := sysVars.GroupByName()
subGroupList := make([]*project_memory.GroupVariableInfo, 0, len(groupName2Group))
for _, group := range groupName2Group {
var e entity.SysConfVariables = group.VarInfoList
varList := make([]*project_memory.Variable, 0, len(group.VarInfoList))
for _, defaultSysMeta := range e.ToVariables().ToProjectVariables() {
sysMetaInUserConf := key2Var[defaultSysMeta.Keyword]
if sysMetaInUserConf == nil {
varList = append(varList, defaultSysMeta)
} else {
varList = append(varList, sysMetaInUserConf)
}
}
pGroupVariableInfo := &project_memory.GroupVariableInfo{
GroupName: group.GroupName,
GroupDesc: group.GroupDesc,
GroupExtDesc: group.GroupExtDesc,
IsReadOnly: true,
VarInfoList: varList,
}
subGroupList = append(subGroupList, pGroupVariableInfo)
}
groupConf.SubGroupList = subGroupList
groupConfList = append(groupConfList, &groupConf)
}
return groupConfList, nil
}
func (v *VariableApplicationService) UpdateProjectVariable(ctx context.Context, req project_memory.UpdateProjectVariableReq) (*project_memory.UpdateProjectVariableResp, error) {
uid := ctxutil.GetUIDFromCtx(ctx)
if uid == nil {
return nil, errorx.New(errno.ErrMemoryPermissionCode, errorx.KV("msg", "session required"))
}
if req.UserID == 0 {
req.UserID = *uid
}
// TODO: project owner check
sysVars := v.DomainSVC.GetSysVariableConf(ctx).ToVariables()
sysVarsKeys2Meta := make(map[string]*entity.VariableMeta)
for _, v := range sysVars.Variables {
sysVarsKeys2Meta[v.Keyword] = v
}
list := make([]*project_memory.Variable, 0, len(req.VariableList))
for _, v := range req.VariableList {
if v.Channel == project_memory.VariableChannel_System &&
sysVarsKeys2Meta[v.Keyword] == nil {
logs.CtxInfof(ctx, "sys variable not found, keyword: %s", v.Keyword)
continue
}
list = append(list, v)
}
key2Var := make(map[string]*project_memory.Variable)
for _, v := range req.VariableList {
key2Var[v.Keyword] = v
}
for _, v := range sysVars.Variables {
if key2Var[v.Keyword] == nil {
list = append(list, v.ToProjectVariable())
} else {
if key2Var[v.Keyword].DefaultValue != v.DefaultValue ||
key2Var[v.Keyword].VariableType != v.VariableType {
return nil, errorx.New(errno.ErrMemoryPermissionCode, errorx.KV("msg", "can not update system variable"))
}
}
}
for _, vv := range list {
if vv.Channel == project_memory.VariableChannel_APP {
e := entity.NewVariableMeta(vv)
err := e.CheckSchema(ctx)
if err != nil {
return nil, err
}
}
}
_, err := v.DomainSVC.UpsertProjectMeta(ctx, req.ProjectID, "", req.UserID, entity.NewVariables(list))
if err != nil {
return nil, err
}
return &project_memory.UpdateProjectVariableResp{
Code: 0,
Msg: "success",
}, nil
}
func (v *VariableApplicationService) GetVariableMeta(ctx context.Context, req *project_memory.GetMemoryVariableMetaReq) (*project_memory.GetMemoryVariableMetaResp, error) {
vars, err := v.DomainSVC.GetVariableMeta(ctx, req.ConnectorID, req.ConnectorType, req.GetVersion())
if err != nil {
return nil, err
}
vars.RemoveDisableVariable()
return &project_memory.GetMemoryVariableMetaResp{
VariableMap: vars.GroupByChannel(),
}, nil
}
func (v *VariableApplicationService) DeleteVariableInstance(ctx context.Context, req *kvmemory.DelProfileMemoryRequest) (*kvmemory.DelProfileMemoryResponse, error) {
uid := ctxutil.GetUIDFromCtx(ctx)
if uid == nil {
return nil, errorx.New(errno.ErrMemoryPermissionCode, errorx.KV("msg", "session required"))
}
bizType := ternary.IFElse(req.BotID == 0, project_memory.VariableConnector_Project, project_memory.VariableConnector_Bot)
bizID := ternary.IFElse(req.BotID == 0, req.ProjectID, fmt.Sprintf("%d", req.BotID))
e := entity.NewUserVariableMeta(&model.UserVariableMeta{
BizType: bizType,
BizID: bizID,
Version: "",
ConnectorID: req.GetConnectorID(),
ConnectorUID: fmt.Sprintf("%d", *uid),
})
err := v.DomainSVC.DeleteVariableInstance(ctx, e, req.Keywords)
if err != nil {
return nil, err
}
return &kvmemory.DelProfileMemoryResponse{}, nil
}
func (v *VariableApplicationService) GetPlayGroundMemory(ctx context.Context, req *kvmemory.GetProfileMemoryRequest) (*kvmemory.GetProfileMemoryResponse, error) {
uid := ctxutil.GetUIDFromCtx(ctx)
if uid == nil {
return nil, errorx.New(errno.ErrMemoryPermissionCode, errorx.KV("msg", "session required"))
}
isProjectKV := req.ProjectID != nil
versionStr := strconv.FormatInt(req.GetProjectVersion(), 10)
if req.GetProjectVersion() == 0 {
versionStr = ""
}
bizType := ternary.IFElse(isProjectKV, project_memory.VariableConnector_Project, project_memory.VariableConnector_Bot)
bizID := ternary.IFElse(isProjectKV, req.GetProjectID(), fmt.Sprintf("%d", req.BotID))
version := ternary.IFElse(isProjectKV, versionStr, "")
connectId := ternary.IFElse(req.ConnectorID == nil, consts.CozeConnectorID, req.GetConnectorID())
connectorUID := ternary.IFElse(req.UserID == 0, *uid, req.UserID)
e := entity.NewUserVariableMeta(&model.UserVariableMeta{
BizType: bizType,
BizID: bizID,
Version: version,
ConnectorID: connectId,
ConnectorUID: fmt.Sprintf("%d", connectorUID),
})
res, err := v.DomainSVC.GetVariableChannelInstance(ctx, e, req.Keywords, req.VariableChannel)
if err != nil {
return nil, err
}
return &kvmemory.GetProfileMemoryResponse{
Memories: res,
}, nil
}
func (v *VariableApplicationService) SetVariableInstance(ctx context.Context, req *kvmemory.SetKvMemoryReq) (*kvmemory.SetKvMemoryResp, error) {
uid := ctxutil.GetUIDFromCtx(ctx)
if uid == nil {
return nil, errorx.New(errno.ErrMemoryPermissionCode, errorx.KV("msg", "session required"))
}
isProjectKV := req.ProjectID != nil
versionStr := strconv.FormatInt(req.GetProjectVersion(), 10)
if req.GetProjectVersion() == 0 {
versionStr = ""
}
bizType := ternary.IFElse(isProjectKV, project_memory.VariableConnector_Project, project_memory.VariableConnector_Bot)
bizID := ternary.IFElse(isProjectKV, req.GetProjectID(), fmt.Sprintf("%d", req.BotID))
version := ternary.IFElse(isProjectKV, versionStr, "")
connectId := ternary.IFElse(req.ConnectorID == nil, consts.CozeConnectorID, req.GetConnectorID())
connectorUID := ternary.IFElse(req.GetUserID() == 0, *uid, req.GetUserID())
e := entity.NewUserVariableMeta(&model.UserVariableMeta{
BizType: bizType,
BizID: bizID,
Version: version,
ConnectorID: connectId,
ConnectorUID: fmt.Sprintf("%d", connectorUID),
})
exitKeys, err := v.DomainSVC.SetVariableInstance(ctx, e, req.Data)
if err != nil {
return nil, err
}
exitKeysStr, _ := json.Marshal(exitKeys)
return &kvmemory.SetKvMemoryResp{
BaseResp: &base.BaseResp{
Extra: map[string]string{"existKeys": string(exitKeysStr)},
},
}, nil
}

View File

@@ -0,0 +1,198 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package modelmgr
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"gopkg.in/yaml.v3"
"gorm.io/gorm"
crossmodelmgr "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
"github.com/coze-dev/coze-studio/backend/domain/modelmgr"
"github.com/coze-dev/coze-studio/backend/domain/modelmgr/entity"
"github.com/coze-dev/coze-studio/backend/domain/modelmgr/service"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/infra/impl/idgen"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
func InitService(db *gorm.DB, idgen idgen.IDGenerator, oss storage.Storage) (*ModelmgrApplicationService, error) {
svc := service.NewModelManager(db, idgen, oss)
if err := loadStaticModelConfig(svc, oss); err != nil {
return nil, err
}
ModelmgrApplicationSVC.DomainSVC = svc
return ModelmgrApplicationSVC, nil
}
func loadStaticModelConfig(svc modelmgr.Manager, oss storage.Storage) error {
ctx := context.Background()
id2Meta := make(map[int64]*entity.ModelMeta)
var cursor *string
for {
req := &modelmgr.ListModelMetaRequest{
Status: []entity.ModelMetaStatus{
crossmodelmgr.StatusInUse,
crossmodelmgr.StatusPending,
crossmodelmgr.StatusDeleted,
},
Limit: 100,
Cursor: cursor,
}
listMetaResp, err := svc.ListModelMeta(ctx, req)
if err != nil {
return err
}
for _, item := range listMetaResp.ModelMetaList {
cpItem := item
id2Meta[cpItem.ID] = cpItem
}
if !listMetaResp.HasMore {
break
}
cursor = listMetaResp.NextCursor
}
root, err := os.Getwd()
if err != nil {
return err
}
envModelMeta, envModelEntity, err := initModelByEnv(root, "resources/conf/model/template")
if err != nil {
return err
}
filePath := filepath.Join(root, "resources/conf/model/meta")
staticModelMeta, err := readDirYaml[crossmodelmgr.ModelMeta](filePath)
if err != nil {
return err
}
staticModelMeta = append(staticModelMeta, envModelMeta...)
for _, modelMeta := range staticModelMeta {
if _, found := id2Meta[modelMeta.ID]; !found {
if modelMeta.IconURI == "" && modelMeta.IconURL == "" {
return fmt.Errorf("missing icon URI or icon URL, id=%d", modelMeta.ID)
} else if modelMeta.IconURL != "" {
// do nothing
} else if modelMeta.IconURI != "" {
// try local path
base := filepath.Base(modelMeta.IconURI)
iconPath := filepath.Join("resources/conf/model/icon", base)
if _, err = os.Stat(iconPath); err == nil {
// try upload icon
icon, err := os.ReadFile(iconPath)
if err != nil {
return err
}
key := fmt.Sprintf("icon_%s_%d", base, time.Now().Second())
if err := oss.PutObject(ctx, key, icon); err != nil {
return err
}
modelMeta.IconURI = key
} else if errors.Is(err, os.ErrNotExist) {
// try to get object from uri
if _, err := oss.GetObject(ctx, modelMeta.IconURI); err != nil {
return err
}
} else {
return err
}
}
newMeta, err := svc.CreateModelMeta(ctx, modelMeta)
if err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
logs.Infof("[loadStaticModelConfig] model meta conflict for id=%d, skip", newMeta.ID)
}
return err
} else {
logs.Infof("[loadStaticModelConfig] model meta create success, id=%d", newMeta.ID)
}
id2Meta[newMeta.ID] = newMeta
} else {
logs.Infof("[loadStaticModelConfig] model meta founded, skip create, id=%d", modelMeta.ID)
}
}
filePath = filepath.Join(root, "resources/conf/model/entity")
staticModel, err := readDirYaml[crossmodelmgr.Model](filePath)
if err != nil {
return err
}
staticModel = append(staticModel, envModelEntity...)
for _, modelEntity := range staticModel {
curModelEntities, err := svc.MGetModelByID(ctx, &modelmgr.MGetModelRequest{IDs: []int64{modelEntity.ID}})
if err != nil {
return err
}
if len(curModelEntities) > 0 {
logs.Infof("[loadStaticModelConfig] model entity founded, skip create, id=%d", modelEntity.ID)
continue
}
meta, found := id2Meta[modelEntity.Meta.ID]
if !found {
return fmt.Errorf("model meta not found for id=%d, model_id=%d", modelEntity.Meta.ID, modelEntity.ID)
}
modelEntity.Meta = *meta
if _, err = svc.CreateModel(ctx, &entity.Model{Model: modelEntity}); err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
logs.Infof("[loadStaticModelConfig] model entity conflict for id=%d, skip", modelEntity.ID)
}
return err
} else {
logs.Infof("[loadStaticModelConfig] model entity create success, id=%d", modelEntity.ID)
}
}
return nil
}
func readDirYaml[T any](dir string) ([]*T, error) {
des, err := os.ReadDir(dir)
if err != nil {
return nil, err
}
resp := make([]*T, 0, len(des))
for _, file := range des {
if file.IsDir() {
continue
}
if strings.HasSuffix(file.Name(), ".yaml") || strings.HasSuffix(file.Name(), ".yml") {
filePath := filepath.Join(dir, file.Name())
data, err := os.ReadFile(filePath)
if err != nil {
return nil, err
}
var content T
if err := yaml.Unmarshal(data, &content); err != nil {
return nil, err
}
resp = append(resp, &content)
}
}
return resp, nil
}

View File

@@ -0,0 +1,148 @@
package modelmgr
import (
"fmt"
"os"
"path/filepath"
"strconv"
"gopkg.in/yaml.v3"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
func initModelByEnv(wd, templatePath string) (metaSlice []*modelmgr.ModelMeta, entitySlice []*modelmgr.Model, err error) {
metaRoot := filepath.Join(wd, templatePath, "meta")
entityRoot := filepath.Join(wd, templatePath, "entity")
for i := -1; i < 1000; i++ {
rawProtocol := os.Getenv(concatEnvKey(modelProtocolPrefix, i))
if rawProtocol == "" {
if i < 0 {
continue
} else {
break
}
}
protocol := chatmodel.Protocol(rawProtocol)
info, valid := getModelEnv(i)
if !valid {
break
}
mapping, found := modelMapping[protocol]
if !found {
return nil, nil, fmt.Errorf("[initModelByEnv] unsupport protocol: %s", rawProtocol)
}
switch protocol {
case chatmodel.ProtocolArk:
fileSuffix, foundTemplate := mapping[info.modelName]
if !foundTemplate {
logs.Warnf("[initModelByEnv] unsupport model=%s, using default config", info.modelName)
}
modelMeta, err := readYaml[modelmgr.ModelMeta](filepath.Join(metaRoot, concatTemplateFileName("model_meta_template_ark", fileSuffix)))
if err != nil {
return nil, nil, err
}
modelEntity, err := readYaml[modelmgr.Model](filepath.Join(entityRoot, concatTemplateFileName("model_entity_template_ark", fileSuffix)))
if err != nil {
return nil, nil, err
}
id, err := strconv.ParseInt(info.id, 10, 64)
if err != nil {
return nil, nil, err
}
// meta 和 entity 用一个 id有概率冲突
modelMeta.ID = id
modelMeta.ConnConfig.Model = info.modelID
modelMeta.ConnConfig.APIKey = info.apiKey
if info.baseURL != "" {
modelMeta.ConnConfig.BaseURL = info.baseURL
}
modelEntity.ID = id
modelEntity.Meta.ID = id
if !foundTemplate {
modelEntity.Name = info.modelName
}
metaSlice = append(metaSlice, modelMeta)
entitySlice = append(entitySlice, modelEntity)
default:
return nil, nil, fmt.Errorf("[initModelByEnv] unsupport protocol: %s", rawProtocol)
}
}
return metaSlice, entitySlice, nil
}
type envModelInfo struct {
id, modelName, modelID, apiKey, baseURL string
}
func getModelEnv(idx int) (info envModelInfo, valid bool) {
info.id = os.Getenv(concatEnvKey(modelOpenCozeIDPrefix, idx))
info.modelName = os.Getenv(concatEnvKey(modelNamePrefix, idx))
info.modelID = os.Getenv(concatEnvKey(modelIDPrefix, idx))
info.apiKey = os.Getenv(concatEnvKey(modelApiKeyPrefix, idx))
info.baseURL = os.Getenv(concatEnvKey(modelBaseURLPrefix, idx))
valid = info.modelName != "" && info.modelID != "" && info.apiKey != ""
return
}
func readYaml[T any](fPath string) (*T, error) {
data, err := os.ReadFile(fPath)
if err != nil {
return nil, err
}
var content T
if err := yaml.Unmarshal(data, &content); err != nil {
return nil, err
}
return &content, nil
}
func concatEnvKey(prefix string, idx int) string {
if idx < 0 {
return prefix
}
return fmt.Sprintf("%s_%d", prefix, idx)
}
func concatTemplateFileName(prefix, suffix string) string {
if suffix == "" {
return prefix + ".yaml"
}
return prefix + "_" + suffix + ".yaml"
}
const (
modelProtocolPrefix = "MODEL_PROTOCOL" // model protocol
modelOpenCozeIDPrefix = "MODEL_OPENCOZE_ID" // opencoze model id
modelNamePrefix = "MODEL_NAME" // model name,
modelIDPrefix = "MODEL_ID" // model in conn config
modelApiKeyPrefix = "MODEL_API_KEY" // model api key
modelBaseURLPrefix = "MODEL_BASE_URL" // model base url
)
var modelMapping = map[chatmodel.Protocol]map[string]string{
chatmodel.ProtocolArk: {
"doubao-seed-1.6": "doubao-seed-1.6",
"doubao-seed-1.6-flash": "doubao-seed-1.6-flash",
"doubao-seed-1.6-thinking": "doubao-seed-1.6-thinking",
"doubao-1.5-thinking-vision-pro": "doubao-1.5-thinking-vision-pro",
"doubao-1.5-thinking-pro": "doubao-1.5-thinking-pro",
"doubao-1.5-vision-pro": "doubao-1.5-vision-pro",
"doubao-1.5-vision-lite": "doubao-1.5-vision-lite",
"doubao-1.5-pro-32k": "doubao-1.5-pro-32k",
"doubao-1.5-pro-256k": "doubao-1.5-pro-256k",
"doubao-1.5-lite": "doubao-1.5-lite",
"deepseek-r1": "volc_deepseek-r1",
"deepseek-v3": "volc_deepseek-v3",
},
}

View File

@@ -0,0 +1,31 @@
package modelmgr
import (
"fmt"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
)
func TestInitByEnv(t *testing.T) {
i := 0
for k := range modelMapping[chatmodel.ProtocolArk] {
_ = os.Setenv(concatEnvKey(modelProtocolPrefix, i), "ark")
_ = os.Setenv(concatEnvKey(modelOpenCozeIDPrefix, i), fmt.Sprintf("%d", 45678+i))
_ = os.Setenv(concatEnvKey(modelNamePrefix, i), k)
_ = os.Setenv(concatEnvKey(modelIDPrefix, i), k)
_ = os.Setenv(concatEnvKey(modelApiKeyPrefix, i), "mock_api_key")
i++
}
wd, err := os.Getwd()
assert.NoError(t, err)
ms, es, err := initModelByEnv(wd, "../../conf/model/template")
assert.NoError(t, err)
assert.Len(t, ms, len(modelMapping[chatmodel.ProtocolArk]))
assert.Len(t, es, len(modelMapping[chatmodel.ProtocolArk]))
}

View File

@@ -0,0 +1,205 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package modelmgr
import (
"context"
modelmgrEntity "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
"github.com/coze-dev/coze-studio/backend/domain/modelmgr"
modelEntity "github.com/coze-dev/coze-studio/backend/domain/modelmgr/entity"
"github.com/coze-dev/coze-studio/backend/pkg/i18n"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
type ModelmgrApplicationService struct {
DomainSVC modelmgr.Manager
}
var ModelmgrApplicationSVC = &ModelmgrApplicationService{}
func (m *ModelmgrApplicationService) GetModelList(ctx context.Context, req *developer_api.GetTypeListRequest) (
resp *developer_api.GetTypeListResponse, err error,
) {
// 一般不太可能同时配置这么多模型
const modelMaxLimit = 300
modelResp, err := m.DomainSVC.ListModel(ctx, &modelmgr.ListModelRequest{
Limit: modelMaxLimit,
Cursor: nil,
})
if err != nil {
return nil, err
}
locale := i18n.GetLocale(ctx)
modelList, err := slices.TransformWithErrorCheck(modelResp.ModelList, func(m *modelEntity.Model) (*developer_api.Model, error) {
logs.CtxInfof(ctx, "ChatModel DefaultParameters: %v", m.DefaultParameters)
return modelDo2To(m, locale)
})
if err != nil {
return nil, err
}
return &developer_api.GetTypeListResponse{
Code: 0,
Msg: "success",
Data: &developer_api.GetTypeListData{
ModelList: modelList,
},
}, nil
}
func modelDo2To(model *modelEntity.Model, locale i18n.Locale) (*developer_api.Model, error) {
mm := model.Meta
mps := slices.Transform(model.DefaultParameters,
func(param *modelmgrEntity.Parameter) *developer_api.ModelParameter {
return parameterDo2To(param, locale)
},
)
modalSet := sets.FromSlice(mm.Capability.InputModal)
return &developer_api.Model{
Name: model.Name,
ModelType: model.ID,
ModelClass: mm.Protocol.TOModelClass(),
ModelIcon: mm.IconURL,
ModelInputPrice: 0,
ModelOutputPrice: 0,
ModelQuota: &developer_api.ModelQuota{
TokenLimit: int32(mm.Capability.MaxTokens),
TokenResp: int32(mm.Capability.OutputTokens),
// TokenSystem: 0,
// TokenUserIn: 0,
// TokenToolsIn: 0,
// TokenToolsOut: 0,
// TokenData: 0,
// TokenHistory: 0,
// TokenCutSwitch: false,
PriceIn: 0,
PriceOut: 0,
SystemPromptLimit: nil,
},
ModelName: mm.Name,
ModelClassName: mm.Protocol.TOModelClass().String(),
IsOffline: mm.Status != modelmgrEntity.StatusInUse,
ModelParams: mps,
ModelDesc: []*developer_api.ModelDescGroup{
{
GroupName: "Description",
Desc: []string{model.Description},
},
},
FuncConfig: nil,
EndpointName: nil,
ModelTagList: nil,
IsUpRequired: nil,
ModelBriefDesc: mm.Description.Read(locale),
ModelSeries: &developer_api.ModelSeriesInfo{ // TODO: 替换为真实配置
SeriesName: "热门模型",
},
ModelStatusDetails: nil,
ModelAbility: &developer_api.ModelAbility{
CotDisplay: ptr.Of(mm.Capability.Reasoning),
FunctionCall: ptr.Of(mm.Capability.FunctionCall),
ImageUnderstanding: ptr.Of(modalSet.Contains(modelmgrEntity.ModalImage)),
VideoUnderstanding: ptr.Of(modalSet.Contains(modelmgrEntity.ModalVideo)),
AudioUnderstanding: ptr.Of(modalSet.Contains(modelmgrEntity.ModalAudio)),
SupportMultiModal: ptr.Of(len(modalSet) > 1),
PrefillResp: ptr.Of(mm.Capability.PrefillResponse),
},
}, nil
}
func parameterDo2To(param *modelmgrEntity.Parameter, locale i18n.Locale) *developer_api.ModelParameter {
if param == nil {
return nil
}
apiOptions := make([]*developer_api.Option, 0, len(param.Options))
for _, opt := range param.Options {
apiOptions = append(apiOptions, &developer_api.Option{
Label: opt.Label,
Value: opt.Value,
})
}
var custom string
var creative, balance, precise *string
if val, ok := param.DefaultVal[modelmgrEntity.DefaultTypeDefault]; ok {
custom = val
}
if val, ok := param.DefaultVal[modelmgrEntity.DefaultTypeCreative]; ok {
creative = ptr.Of(val)
}
if val, ok := param.DefaultVal[modelmgrEntity.DefaultTypeBalance]; ok {
balance = ptr.Of(val)
}
if val, ok := param.DefaultVal[modelmgrEntity.DefaultTypePrecise]; ok {
precise = ptr.Of(val)
}
return &developer_api.ModelParameter{
Name: string(param.Name),
Label: param.Label.Read(locale),
Desc: param.Desc.Read(locale),
Type: func() developer_api.ModelParamType {
switch param.Type {
case modelmgrEntity.ValueTypeBoolean:
return developer_api.ModelParamType_Boolean
case modelmgrEntity.ValueTypeInt:
return developer_api.ModelParamType_Int
case modelmgrEntity.ValueTypeFloat:
return developer_api.ModelParamType_Float
default:
return developer_api.ModelParamType_String
}
}(),
Min: param.Min,
Max: param.Max,
Precision: int32(param.Precision),
DefaultVal: &developer_api.ModelParamDefaultValue{
DefaultVal: custom,
Creative: creative,
Balance: balance,
Precise: precise,
},
Options: apiOptions,
ParamClass: &developer_api.ModelParamClass{
ClassID: func() int32 {
switch param.Style.Widget {
case modelmgrEntity.WidgetSlider:
return 1
case modelmgrEntity.WidgetRadioButtons:
return 2
default:
return 0
}
}(),
Label: param.Style.Label.Read(locale),
},
}
}

View File

@@ -0,0 +1,39 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package openauth
import (
"gorm.io/gorm"
openapiauth2 "github.com/coze-dev/coze-studio/backend/domain/openauth/openapiauth"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
)
var (
openapiAuthDomainSVC openapiauth2.APIAuth
)
func InitService(db *gorm.DB, idGenSVC idgen.IDGenerator) *OpenAuthApplicationService {
openapiAuthDomainSVC = openapiauth2.NewService(&openapiauth2.Components{
IDGen: idGenSVC,
DB: db,
})
OpenAuthApplication.OpenAPIDomainSVC = openapiAuthDomainSVC
return OpenAuthApplication
}

View File

@@ -0,0 +1,203 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package openauth
import (
"context"
"strconv"
"time"
"github.com/pkg/errors"
openapimodel "github.com/coze-dev/coze-studio/backend/api/model/permission/openapiauth"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
openapi "github.com/coze-dev/coze-studio/backend/domain/openauth/openapiauth"
"github.com/coze-dev/coze-studio/backend/domain/openauth/openapiauth/entity"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
type OpenAuthApplicationService struct {
OpenAPIDomainSVC openapi.APIAuth
}
var OpenAuthApplication = &OpenAuthApplicationService{}
func (s *OpenAuthApplicationService) GetPersonalAccessTokenAndPermission(ctx context.Context, req *openapimodel.GetPersonalAccessTokenAndPermissionRequest) (*openapimodel.GetPersonalAccessTokenAndPermissionResponse, error) {
resp := new(openapimodel.GetPersonalAccessTokenAndPermissionResponse)
userID := ctxutil.GetUIDFromCtx(ctx)
appReq := &entity.GetApiKey{
ID: req.ID,
}
apiKeyResp, err := openapiAuthDomainSVC.Get(ctx, appReq)
if err != nil {
logs.CtxErrorf(ctx, "OpenAuthApplicationService.GetPersonalAccessTokenAndPermission failed, err=%v", err)
return resp, errors.New("GetPersonalAccessTokenAndPermission failed")
}
if apiKeyResp == nil {
return resp, errors.New("GetPersonalAccessTokenAndPermission failed")
}
if apiKeyResp.UserID != *userID {
return resp, errors.New("permission not match")
}
resp.Data = &openapimodel.GetPersonalAccessTokenAndPermissionResponseData{
PersonalAccessToken: &openapimodel.PersonalAccessToken{
ID: apiKeyResp.ID,
Name: apiKeyResp.Name,
ExpireAt: apiKeyResp.ExpiredAt,
CreatedAt: apiKeyResp.CreatedAt,
UpdatedAt: apiKeyResp.UpdatedAt,
},
}
return resp, nil
}
func (s *OpenAuthApplicationService) CreatePersonalAccessToken(ctx context.Context, req *openapimodel.CreatePersonalAccessTokenAndPermissionRequest) (*openapimodel.CreatePersonalAccessTokenAndPermissionResponse, error) {
resp := new(openapimodel.CreatePersonalAccessTokenAndPermissionResponse)
userID := ctxutil.GetUIDFromCtx(ctx)
appReq := &entity.CreateApiKey{
Name: req.Name,
Expire: req.ExpireAt,
UserID: *userID,
}
if req.DurationDay == "customize" {
appReq.Expire = req.ExpireAt
} else {
expireDay, err := strconv.ParseInt(req.DurationDay, 10, 64)
if err != nil {
return resp, errors.New("invalid expireDay")
}
appReq.Expire = time.Now().Add(time.Duration(expireDay) * time.Hour * 24).Unix()
}
apiKeyResp, err := openapiAuthDomainSVC.Create(ctx, appReq)
if err != nil {
logs.CtxErrorf(ctx, "OpenAuthApplicationService.CreatePersonalAccessToken failed, err=%v", err)
return resp, errors.New("CreatePersonalAccessToken failed")
}
resp.Data = &openapimodel.CreatePersonalAccessTokenAndPermissionResponseData{
PersonalAccessToken: &openapimodel.PersonalAccessToken{
ID: apiKeyResp.ID,
Name: apiKeyResp.Name,
ExpireAt: apiKeyResp.ExpiredAt,
CreatedAt: apiKeyResp.CreatedAt,
UpdatedAt: apiKeyResp.UpdatedAt,
},
Token: apiKeyResp.ApiKey,
}
return resp, nil
}
func (s *OpenAuthApplicationService) ListPersonalAccessTokens(ctx context.Context, req *openapimodel.ListPersonalAccessTokensRequest) (*openapimodel.ListPersonalAccessTokensResponse, error) {
resp := new(openapimodel.ListPersonalAccessTokensResponse)
userID := ctxutil.GetUIDFromCtx(ctx)
appReq := &entity.ListApiKey{
UserID: *userID,
Page: *req.Page,
Limit: *req.Size,
}
apiKeyResp, err := openapiAuthDomainSVC.List(ctx, appReq)
if err != nil {
logs.CtxErrorf(ctx, "OpenAuthApplicationService.ListPersonalAccessTokens failed, err=%v", err)
return resp, errors.New("ListPersonalAccessTokens failed")
}
if apiKeyResp == nil {
return resp, nil
}
resp.Data = &openapimodel.ListPersonalAccessTokensResponseData{
HasMore: apiKeyResp.HasMore,
PersonalAccessTokens: slices.Transform(apiKeyResp.ApiKeys, func(a *entity.ApiKey) *openapimodel.PersonalAccessTokenWithCreatorInfo {
lastUsedAt := a.LastUsedAt
if lastUsedAt == 0 {
lastUsedAt = -1
}
return &openapimodel.PersonalAccessTokenWithCreatorInfo{
ID: a.ID,
Name: a.Name,
ExpireAt: a.ExpiredAt,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
LastUsedAt: lastUsedAt,
}
}),
}
return resp, nil
}
func (s *OpenAuthApplicationService) DeletePersonalAccessTokenAndPermission(ctx context.Context, req *openapimodel.DeletePersonalAccessTokenAndPermissionRequest) (*openapimodel.DeletePersonalAccessTokenAndPermissionResponse, error) {
resp := new(openapimodel.DeletePersonalAccessTokenAndPermissionResponse)
userID := ctxutil.GetUIDFromCtx(ctx)
appReq := &entity.DeleteApiKey{
ID: req.ID,
UserID: *userID,
}
err := openapiAuthDomainSVC.Delete(ctx, appReq)
if err != nil {
logs.CtxErrorf(ctx, "OpenAuthApplicationService.DeletePersonalAccessTokenAndPermission failed, err=%v", err)
return resp, errors.New("DeletePersonalAccessTokenAndPermission failed")
}
return resp, nil
}
func (s *OpenAuthApplicationService) UpdatePersonalAccessTokenAndPermission(ctx context.Context, req *openapimodel.UpdatePersonalAccessTokenAndPermissionRequest) (*openapimodel.UpdatePersonalAccessTokenAndPermissionResponse, error) {
resp := new(openapimodel.UpdatePersonalAccessTokenAndPermissionResponse)
userID := ctxutil.GetUIDFromCtx(ctx)
upErr := openapiAuthDomainSVC.Save(ctx, &entity.SaveMeta{
ID: req.ID,
Name: ptr.Of(req.Name),
UserID: *userID,
})
return resp, upErr
}
func (s *OpenAuthApplicationService) UpdateLastUsedAt(ctx context.Context, apiID int64, userID int64) error {
upErr := openapiAuthDomainSVC.Save(ctx, &entity.SaveMeta{
ID: apiID,
LastUsedAt: ptr.Of(time.Now().Unix()),
UserID: userID,
})
return upErr
}
func (s *OpenAuthApplicationService) CheckPermission(ctx context.Context, token string) (*entity.ApiKey, error) {
appReq := &entity.CheckPermission{
ApiKey: token,
}
apiKey, err := openapiAuthDomainSVC.CheckPermission(ctx, appReq)
if err != nil {
logs.CtxErrorf(ctx, "OpenAuthApplicationService.CheckPermission failed, err=%v", err)
return nil, errors.New("CheckPermission failed")
}
return apiKey, nil
}

View File

@@ -0,0 +1,79 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package plugin
import (
"context"
"gorm.io/gorm"
pluginConf "github.com/coze-dev/coze-studio/backend/domain/plugin/conf"
"github.com/coze-dev/coze-studio/backend/domain/plugin/repository"
"github.com/coze-dev/coze-studio/backend/domain/plugin/service"
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
user "github.com/coze-dev/coze-studio/backend/domain/user/service"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
)
type ServiceComponents struct {
IDGen idgen.IDGenerator
DB *gorm.DB
OSS storage.Storage
EventBus search.ResourceEventBus
UserSVC user.User
}
func InitService(ctx context.Context, components *ServiceComponents) (*PluginApplicationService, error) {
err := pluginConf.InitConfig(ctx)
if err != nil {
return nil, err
}
toolRepo := repository.NewToolRepo(&repository.ToolRepoComponents{
IDGen: components.IDGen,
DB: components.DB,
})
pluginRepo := repository.NewPluginRepo(&repository.PluginRepoComponents{
IDGen: components.IDGen,
DB: components.DB,
})
oauthRepo := repository.NewOAuthRepo(&repository.OAuthRepoComponents{
IDGen: components.IDGen,
DB: components.DB,
})
pluginSVC := service.NewService(&service.Components{
IDGen: components.IDGen,
DB: components.DB,
OSS: components.OSS,
PluginRepo: pluginRepo,
ToolRepo: toolRepo,
OAuthRepo: oauthRepo,
})
PluginApplicationSVC.DomainSVC = pluginSVC
PluginApplicationSVC.eventbus = components.EventBus
PluginApplicationSVC.oss = components.OSS
PluginApplicationSVC.userSVC = components.UserSVC
PluginApplicationSVC.pluginRepo = pluginRepo
PluginApplicationSVC.toolRepo = toolRepo
return PluginApplicationSVC, nil
}

View File

@@ -0,0 +1,34 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package plugin
import (
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
)
type CopyPluginRequest struct {
PluginID int64
UserID int64
CopyScene model.CopyScene
TargetAPPID *int64
}
type CopyPluginResponse struct {
Plugin *entity.PluginInfo
Tools map[int64]*entity.ToolInfo // old tool id -> new tool id
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,34 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package prompt
import (
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/application/search"
"github.com/coze-dev/coze-studio/backend/domain/prompt/repository"
prompt "github.com/coze-dev/coze-studio/backend/domain/prompt/service"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
)
func InitService(db *gorm.DB, idGenSVC idgen.IDGenerator, re search.ResourceEventBus) *PromptApplicationService {
repo := repository.NewPromptRepo(db, idGenSVC)
PromptSVC.DomainSVC = prompt.NewService(repo)
PromptSVC.eventbus = re
return PromptSVC
}

View File

@@ -0,0 +1,238 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package prompt
import (
"context"
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/playground"
"github.com/coze-dev/coze-studio/backend/api/model/resource/common"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
"github.com/coze-dev/coze-studio/backend/application/search"
"github.com/coze-dev/coze-studio/backend/domain/prompt/entity"
prompt "github.com/coze-dev/coze-studio/backend/domain/prompt/service"
searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type PromptApplicationService struct {
DomainSVC prompt.Prompt
eventbus search.ResourceEventBus
}
var PromptSVC = &PromptApplicationService{}
func (p *PromptApplicationService) UpsertPromptResource(ctx context.Context, req *playground.UpsertPromptResourceRequest) (resp *playground.UpsertPromptResourceResponse, err error) {
session := ctxutil.GetUserSessionFromCtx(ctx)
if session == nil {
return nil, errorx.New(errno.ErrPromptPermissionCode, errorx.KV("msg", "no session data provided"))
}
promptID := req.Prompt.GetID()
if promptID == 0 {
// create a new prompt resource
resp, err = p.createPromptResource(ctx, req)
if err != nil {
return nil, err
}
pErr := p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{
OpType: searchEntity.Created,
Resource: &searchEntity.ResourceDocument{
ResType: common.ResType_Prompt,
ResID: resp.Data.ID,
Name: req.Prompt.Name,
SpaceID: req.Prompt.SpaceID,
OwnerID: &session.UserID,
PublishStatus: ptr.Of(common.PublishStatus_Published),
},
})
if pErr != nil {
logs.CtxErrorf(ctx, "publish resource event failed: %v", pErr)
}
return resp, nil
}
// update an existing prompt resource
resp, err = p.updatePromptResource(ctx, req)
if err != nil {
return nil, err
}
pErr := p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{
OpType: searchEntity.Updated,
Resource: &searchEntity.ResourceDocument{
ResType: common.ResType_Prompt,
ResID: resp.Data.ID,
Name: req.Prompt.Name,
SpaceID: req.Prompt.SpaceID,
},
})
if pErr != nil {
logs.CtxErrorf(ctx, "publish resource event failed: %v", pErr)
}
return resp, nil
}
func (p *PromptApplicationService) GetPromptResourceInfo(ctx context.Context, req *playground.GetPromptResourceInfoRequest) (
resp *playground.GetPromptResourceInfoResponse, err error,
) {
promptInfo, err := p.DomainSVC.GetPromptResource(ctx, req.GetPromptResourceID())
if err != nil {
return nil, err
}
return &playground.GetPromptResourceInfoResponse{
Data: promptInfoDo2To(promptInfo),
Code: 0,
}, nil
}
func (p *PromptApplicationService) GetOfficialPromptResourceList(ctx context.Context, c *playground.GetOfficialPromptResourceListRequest) (
*playground.GetOfficialPromptResourceListResponse, error,
) {
session := ctxutil.GetUserSessionFromCtx(ctx)
if session == nil {
return nil, errorx.New(errno.ErrPromptPermissionCode, errorx.KV("msg", "no session data provided"))
}
promptList, err := p.DomainSVC.ListOfficialPromptResource(ctx, c.GetKeyword())
if err != nil {
return nil, err
}
return &playground.GetOfficialPromptResourceListResponse{
PromptResourceList: slices.Transform(promptList, func(p *entity.PromptResource) *playground.PromptResource {
return promptInfoDo2To(p)
}),
Code: 0,
}, nil
}
func (p *PromptApplicationService) DeletePromptResource(ctx context.Context, req *playground.DeletePromptResourceRequest) (resp *playground.DeletePromptResourceResponse, err error) {
uid := ctxutil.GetUIDFromCtx(ctx)
if uid == nil {
return nil, errorx.New(errno.ErrPromptPermissionCode, errorx.KV("msg", "no session data provided"))
}
promptInfo, err := p.DomainSVC.GetPromptResource(ctx, req.GetPromptResourceID())
if err != nil {
return nil, err
}
if promptInfo.CreatorID != *uid {
return nil, errorx.New(errno.ErrPromptPermissionCode, errorx.KV("msg", "no permission"))
}
err = p.DomainSVC.DeletePromptResource(ctx, req.GetPromptResourceID())
if err != nil {
return nil, err
}
pErr := p.eventbus.PublishResources(ctx, &searchEntity.ResourceDomainEvent{
OpType: searchEntity.Deleted,
Resource: &searchEntity.ResourceDocument{
ResType: common.ResType_Prompt,
ResID: req.GetPromptResourceID(),
},
})
if pErr != nil {
logs.CtxErrorf(ctx, "publish resource event failed: %v", pErr)
}
return &playground.DeletePromptResourceResponse{
Code: 0,
}, nil
}
func (p *PromptApplicationService) createPromptResource(ctx context.Context, req *playground.UpsertPromptResourceRequest) (resp *playground.UpsertPromptResourceResponse, err error) {
do := p.toPromptResourceDO(req.Prompt)
uid := ctxutil.GetUIDFromCtx(ctx)
do.CreatorID = *uid
promptID, err := p.DomainSVC.CreatePromptResource(ctx, do)
if err != nil {
return nil, err
}
return &playground.UpsertPromptResourceResponse{
Data: &playground.ShowPromptResource{
ID: promptID,
},
Code: 0,
}, nil
}
func (p *PromptApplicationService) updatePromptResource(ctx context.Context, req *playground.UpsertPromptResourceRequest) (resp *playground.UpsertPromptResourceResponse, err error) {
promptID := req.Prompt.GetID()
promptResource, err := p.DomainSVC.GetPromptResource(ctx, promptID)
if err != nil {
return nil, err
}
logs.CtxInfof(ctx, "promptResource.SpaceID: %v , promptResource.CreatorID : %v", promptResource.SpaceID, promptResource.CreatorID)
uid := ctxutil.GetUIDFromCtx(ctx)
if promptResource.CreatorID != *uid {
return nil, errorx.New(errno.ErrPromptPermissionCode, errorx.KV("msg", "no permission"))
}
promptResource.Name = req.Prompt.GetName()
promptResource.Description = req.Prompt.GetDescription()
promptResource.PromptText = req.Prompt.GetPromptText()
err = p.DomainSVC.UpdatePromptResource(ctx, promptResource)
if err != nil {
return nil, err
}
return &playground.UpsertPromptResourceResponse{
Data: &playground.ShowPromptResource{
ID: promptID,
},
Code: 0,
}, nil
}
func (p *PromptApplicationService) toPromptResourceDO(m *playground.PromptResource) *entity.PromptResource {
e := entity.PromptResource{}
e.ID = m.GetID()
e.PromptText = m.GetPromptText()
e.SpaceID = m.GetSpaceID()
e.Name = m.GetName()
e.Description = m.GetDescription()
return &e
}
func promptInfoDo2To(p *entity.PromptResource) *playground.PromptResource {
return &playground.PromptResource{
ID: ptr.Of(p.ID),
SpaceID: ptr.Of(p.SpaceID),
Name: ptr.Of(p.Name),
Description: ptr.Of(p.Description),
PromptText: ptr.Of(p.PromptText),
}
}

View File

@@ -0,0 +1,100 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package search
import (
"context"
"fmt"
"os"
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/application/singleagent"
app "github.com/coze-dev/coze-studio/backend/domain/app/service"
connector "github.com/coze-dev/coze-studio/backend/domain/connector/service"
knowledge "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
database "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
"github.com/coze-dev/coze-studio/backend/domain/plugin/service"
prompt "github.com/coze-dev/coze-studio/backend/domain/prompt/service"
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
user "github.com/coze-dev/coze-studio/backend/domain/user/service"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/infra/contract/es"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
"github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/consts"
)
type ServiceComponents struct {
DB *gorm.DB
Cache *redis.Client
TOS storage.Storage
ESClient es.Client
ProjectEventBus ProjectEventBus
ResourceEventBus ResourceEventBus
SingleAgentDomainSVC singleagent.SingleAgent
APPDomainSVC app.AppService
KnowledgeDomainSVC knowledge.Knowledge
PluginDomainSVC service.PluginService
WorkflowDomainSVC workflow.Service
UserDomainSVC user.User
ConnectorDomainSVC connector.Connector
PromptDomainSVC prompt.Prompt
DatabaseDomainSVC database.Database
}
func InitService(ctx context.Context, s *ServiceComponents) (*SearchApplicationService, error) {
searchDomainSVC := search.NewDomainService(ctx, s.ESClient)
SearchSVC.DomainSVC = searchDomainSVC
SearchSVC.ServiceComponents = s
// setup consumer
searchConsumer := search.NewProjectHandler(ctx, s.ESClient)
logs.Infof("start search domain consumer...")
nameServer := os.Getenv(consts.MQServer)
err := eventbus.RegisterConsumer(nameServer, consts.RMQTopicApp, consts.RMQConsumeGroupApp, searchConsumer)
if err != nil {
return nil, fmt.Errorf("register search consumer failed, err=%w", err)
}
searchResourceConsumer := search.NewResourceHandler(ctx, s.ESClient)
err = eventbus.RegisterConsumer(nameServer, consts.RMQTopicResource, consts.RMQConsumeGroupResource, searchResourceConsumer)
if err != nil {
return nil, fmt.Errorf("register search consumer failed, err=%w", err)
}
return SearchSVC, nil
}
type (
ResourceEventBus = search.ResourceEventBus
ProjectEventBus = search.ProjectEventBus
)
func NewResourceEventBus(p eventbus.Producer) search.ResourceEventBus {
return search.NewResourceEventBus(p)
}
func NewProjectEventBus(p eventbus.Producer) search.ProjectEventBus {
return search.NewProjectEventBus(p)
}

View File

@@ -0,0 +1,194 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package search
import (
"context"
"fmt"
"strconv"
"github.com/coze-dev/coze-studio/backend/api/model/intelligence"
"github.com/coze-dev/coze-studio/backend/api/model/intelligence/common"
"github.com/coze-dev/coze-studio/backend/domain/app/entity"
appService "github.com/coze-dev/coze-studio/backend/domain/app/service"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
type projectInfo struct {
iconURI string
desc string
}
type ProjectPacker interface {
GetProjectInfo(ctx context.Context) (*projectInfo, error)
GetPermissionInfo() *intelligence.IntelligencePermissionInfo
GetPublishedInfo(ctx context.Context) *intelligence.IntelligencePublishInfo
GetUserInfo(ctx context.Context, userID int64) *common.User
}
func NewPackProject(uid, projectID int64, tp common.IntelligenceType, s *SearchApplicationService) (ProjectPacker, error) {
base := projectBase{SVC: s, projectID: projectID, iType: tp, uid: uid}
switch tp {
case common.IntelligenceType_Bot:
return &agentPacker{projectBase: base}, nil
case common.IntelligenceType_Project:
return &appPacker{projectBase: base}, nil
}
return nil, fmt.Errorf("unsupported project_type: %d , project_id : %d", tp, projectID)
}
type projectBase struct {
projectID int64 // agent_id or application_id
uid int64
SVC *SearchApplicationService
iType common.IntelligenceType
}
func (p *projectBase) GetPermissionInfo() *intelligence.IntelligencePermissionInfo {
return &intelligence.IntelligencePermissionInfo{
InCollaboration: false,
CanDelete: true,
CanView: true,
}
}
func (p *projectBase) GetUserInfo(ctx context.Context, userID int64) *common.User {
u, err := p.SVC.UserDomainSVC.GetUserInfo(ctx, userID)
if err != nil {
logs.CtxErrorf(ctx, "[projectBase-GetUserInfo] failed to get user info, user_id: %d, err: %v", userID, err)
return nil
}
return &common.User{
UserID: u.UserID,
AvatarURL: u.IconURL,
UserUniqueName: u.UniqueName,
}
}
type agentPacker struct {
projectBase
}
func (a *agentPacker) GetProjectInfo(ctx context.Context) (*projectInfo, error) {
agent, err := a.SVC.SingleAgentDomainSVC.GetSingleAgentDraft(ctx, a.projectID)
if err != nil {
return nil, err
}
if agent == nil {
return nil, fmt.Errorf("agent info is nil")
}
return &projectInfo{
iconURI: agent.IconURI,
desc: agent.Desc,
}, nil
}
func (p *agentPacker) GetPublishedInfo(ctx context.Context) *intelligence.IntelligencePublishInfo {
pubInfo, err := p.SVC.SingleAgentDomainSVC.GetPublishedInfo(ctx, p.projectID)
if err != nil {
logs.CtxErrorf(ctx, "[agent-GetPublishedInfo]failed to get published info, agent_id: %d, err: %v", p.projectID, err)
return nil
}
connectors := make([]*common.ConnectorInfo, 0, len(pubInfo.ConnectorID2PublishTime))
for connectorID := range pubInfo.ConnectorID2PublishTime {
c, err := p.SVC.ConnectorDomainSVC.GetByID(ctx, connectorID)
if err != nil {
logs.CtxErrorf(ctx, "failed to get connector by id: %d, err: %v", connectorID, err)
continue
}
connectors = append(connectors, &common.ConnectorInfo{
ID: conv.Int64ToStr(c.ID),
Name: c.Name,
ConnectorStatus: common.ConnectorDynamicStatus(c.ConnectorStatus),
Icon: c.URL,
})
}
return &intelligence.IntelligencePublishInfo{
PublishTime: conv.Int64ToStr(pubInfo.LastPublishTimeMS / 1000),
HasPublished: pubInfo.LastPublishTimeMS > 0,
Connectors: connectors,
}
}
type appPacker struct {
projectBase
}
func (a *appPacker) GetProjectInfo(ctx context.Context) (*projectInfo, error) {
app, err := a.SVC.APPDomainSVC.GetDraftAPP(ctx, a.projectID)
if err != nil {
return nil, err
}
return &projectInfo{
iconURI: app.GetIconURI(),
desc: app.GetDesc(),
}, nil
}
func (a *appPacker) GetPublishedInfo(ctx context.Context) *intelligence.IntelligencePublishInfo {
record, exist, err := a.SVC.APPDomainSVC.GetAPPPublishRecord(ctx, &appService.GetAPPPublishRecordRequest{
APPID: a.projectID,
Oldest: true,
})
if err != nil {
logs.CtxErrorf(ctx, "[app-GetPublishedInfo] failed to get published info, app_id=%d, err=%v", a.projectID, err)
return nil
}
if !exist {
return &intelligence.IntelligencePublishInfo{
PublishTime: "",
HasPublished: false,
Connectors: nil,
}
}
connectorInfo := make([]*common.ConnectorInfo, 0, len(record.ConnectorPublishRecords))
connectorIDs := slices.Transform(record.ConnectorPublishRecords, func(c *entity.ConnectorPublishRecord) int64 {
return c.ConnectorID
})
connectors, err := a.SVC.ConnectorDomainSVC.GetByIDs(ctx, connectorIDs)
if err != nil {
logs.CtxErrorf(ctx, "[app-GetPublishedInfo] failed to get connector info, app_id=%d, err=%v", a.projectID, err)
} else {
for _, c := range connectors {
connectorInfo = append(connectorInfo, &common.ConnectorInfo{
ID: conv.Int64ToStr(c.ID),
Name: c.Name,
ConnectorStatus: common.ConnectorDynamicStatus(c.ConnectorStatus),
Icon: c.URL,
})
}
}
return &intelligence.IntelligencePublishInfo{
PublishTime: strconv.FormatInt(record.APP.GetPublishedAtMS()/1000, 10),
HasPublished: record.APP.Published(),
Connectors: connectorInfo,
}
}

View File

@@ -0,0 +1,452 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package search
import (
"context"
"fmt"
"sync"
"time"
search2 "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/search"
"github.com/coze-dev/coze-studio/backend/api/model/flow/marketplace/marketplace_common"
"github.com/coze-dev/coze-studio/backend/api/model/flow/marketplace/product_common"
"github.com/coze-dev/coze-studio/backend/api/model/flow/marketplace/product_public_api"
"github.com/coze-dev/coze-studio/backend/api/model/intelligence"
"github.com/coze-dev/coze-studio/backend/api/model/intelligence/common"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/taskgroup"
"github.com/coze-dev/coze-studio/backend/types/consts"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
var projectType2iconURI = map[common.IntelligenceType]string{
common.IntelligenceType_Bot: consts.DefaultAgentIcon,
common.IntelligenceType_Project: consts.DefaultAppIcon,
}
func (s *SearchApplicationService) GetDraftIntelligenceList(ctx context.Context, req *intelligence.GetDraftIntelligenceListRequest) (
resp *intelligence.GetDraftIntelligenceListResponse, err error,
) {
userID := ctxutil.GetUIDFromCtx(ctx)
if userID == nil {
return nil, errorx.New(errno.ErrSearchPermissionCode, errorx.KV("msg", "session is required"))
}
do := searchRequestTo2Do(*userID, req)
searchResp, err := s.DomainSVC.SearchProjects(ctx, do)
if err != nil {
return nil, err
}
if len(searchResp.Data) == 0 {
return &intelligence.GetDraftIntelligenceListResponse{
Data: &intelligence.DraftIntelligenceListData{
Intelligences: make([]*intelligence.IntelligenceData, 0),
Total: 0,
HasMore: false,
NextCursorID: "",
},
}, nil
}
tasks := taskgroup.NewUninterruptibleTaskGroup(ctx, len(searchResp.Data))
lock := sync.Mutex{}
intelligenceDataList := make([]*intelligence.IntelligenceData, len(searchResp.Data))
logs.CtxDebugf(ctx, "[GetDraftIntelligenceList] searchResp.Data: %v", conv.DebugJsonToStr(searchResp.Data))
for idx := range searchResp.Data {
data := searchResp.Data[idx]
index := idx
tasks.Go(func() error {
info, err := s.packIntelligenceData(ctx, data)
if err != nil {
logs.CtxErrorf(ctx, "[packIntelligenceData] failed id %v, type %d , name %s, err: %v", data.ID, data.Type, data.GetName(), err)
return err
}
lock.Lock()
defer lock.Unlock()
intelligenceDataList[index] = info
return nil
})
s.packIntelligenceData(ctx, data)
}
_ = tasks.Wait()
filterDataList := make([]*intelligence.IntelligenceData, 0)
for _, data := range intelligenceDataList {
if data != nil {
filterDataList = append(filterDataList, data)
}
}
return &intelligence.GetDraftIntelligenceListResponse{
Code: 0,
Data: &intelligence.DraftIntelligenceListData{
Intelligences: filterDataList,
Total: int32(len(filterDataList)),
HasMore: searchResp.HasMore,
NextCursorID: searchResp.NextCursor,
},
}, nil
}
func (s *SearchApplicationService) PublicFavoriteProduct(ctx context.Context, req *product_public_api.FavoriteProductRequest) (*product_public_api.FavoriteProductResponse, error) {
isFav := !req.GetIsCancel()
entityID := req.GetEntityID()
typ := req.GetEntityType()
switch req.GetEntityType() {
case product_common.ProductEntityType_Bot, product_common.ProductEntityType_Project:
err := s.favoriteProject(ctx, entityID, typ, isFav)
if err != nil {
return nil, err
}
default:
return nil, errorx.New(errno.ErrSearchInvalidParamCode, errorx.KV("msg", fmt.Sprintf("invalid entity type '%d'", req.GetEntityType())))
}
return &product_public_api.FavoriteProductResponse{
IsFirstFavorite: ptr.Of(false),
}, nil
}
func (s *SearchApplicationService) favoriteProject(ctx context.Context, projectID int64, typ product_common.ProductEntityType, isFav bool) error {
var entityType common.IntelligenceType
if typ == product_common.ProductEntityType_Bot {
entityType = common.IntelligenceType_Bot
} else {
entityType = common.IntelligenceType_Project
}
err := s.ProjectEventBus.PublishProject(ctx, &searchEntity.ProjectDomainEvent{
OpType: searchEntity.Updated,
Project: &searchEntity.ProjectDocument{
ID: projectID,
IsFav: ptr.Of(ternary.IFElse(isFav, 1, 0)),
FavTimeMS: ptr.Of(time.Now().UnixMilli()),
Type: entityType,
},
})
if err != nil {
return err
}
return nil
}
func (s *SearchApplicationService) PublicGetUserFavoriteList(ctx context.Context, req *product_public_api.GetUserFavoriteListV2Request) (resp *product_public_api.GetUserFavoriteListV2Response, err error) {
userID := ctxutil.GetUIDFromCtx(ctx)
if userID == nil {
return nil, errorx.New(errno.ErrSearchPermissionCode, errorx.KV("msg", "session required"))
}
var data *product_public_api.GetUserFavoriteListDataV2
switch req.GetEntityType() {
case product_common.ProductEntityType_Project, product_common.ProductEntityType_Bot, product_common.ProductEntityType_Common:
data, err = s.searchFavProjects(ctx, *userID, req)
default:
return nil, errorx.New(errno.ErrSearchInvalidParamCode, errorx.KV("msg", fmt.Sprintf("invalid entity type '%d'", req.GetEntityType())))
}
if err != nil {
return nil, err
}
resp = &product_public_api.GetUserFavoriteListV2Response{
Data: data,
}
return resp, nil
}
func (s *SearchApplicationService) searchFavProjects(ctx context.Context, userID int64, req *product_public_api.GetUserFavoriteListV2Request) (*product_public_api.GetUserFavoriteListDataV2, error) {
var types []common.IntelligenceType
if req.GetEntityType() == product_common.ProductEntityType_Common {
types = []common.IntelligenceType{common.IntelligenceType_Bot, common.IntelligenceType_Project}
} else if req.GetEntityType() == product_common.ProductEntityType_Bot {
types = []common.IntelligenceType{common.IntelligenceType_Bot}
} else {
types = []common.IntelligenceType{common.IntelligenceType_Project}
}
res, err := SearchSVC.DomainSVC.SearchProjects(ctx, &searchEntity.SearchProjectsRequest{
OwnerID: userID,
Types: types,
IsFav: true,
OrderFiledName: search2.FieldOfFavTime,
OrderAsc: false,
Limit: req.PageSize,
Cursor: req.GetCursorID(),
})
if err != nil {
return nil, err
}
if len(res.Data) == 0 {
return &product_public_api.GetUserFavoriteListDataV2{
FavoriteEntities: []*product_common.FavoriteEntity{},
CursorID: res.NextCursor,
HasMore: res.HasMore,
}, nil
}
favEntities := make([]*product_common.FavoriteEntity, 0, len(res.Data))
for _, r := range res.Data {
favEntity, err := s.projectResourceToProductInfo(ctx, userID, r)
if err != nil {
logs.CtxErrorf(ctx, "[pluginResourceToProductInfo] failed to get project info, id=%v, type=%d, err=%v",
r.ID, r.Type, err)
continue
}
favEntities = append(favEntities, favEntity)
}
data := &product_public_api.GetUserFavoriteListDataV2{
FavoriteEntities: favEntities,
CursorID: res.NextCursor,
HasMore: res.HasMore,
}
return data, nil
}
func (s *SearchApplicationService) projectResourceToProductInfo(ctx context.Context, userID int64, doc *searchEntity.ProjectDocument) (favEntity *product_common.FavoriteEntity, err error) {
typ := func() product_common.ProductEntityType {
if doc.Type == common.IntelligenceType_Bot {
return product_common.ProductEntityType_Bot
}
return product_common.ProductEntityType_Project
}()
packer, err := NewPackProject(userID, doc.ID, doc.Type, s)
if err != nil {
return nil, err
}
pi, err := packer.GetProjectInfo(ctx)
if err != nil {
return nil, err
}
ui := packer.GetUserInfo(ctx, userID)
var userInfo *product_common.UserInfo
if ui != nil {
userInfo = &product_common.UserInfo{
UserID: ui.UserID,
UserName: ui.UserUniqueName,
Name: ui.Nickname,
AvatarURL: ui.AvatarURL,
FollowType: ptr.Of(marketplace_common.FollowType_Unknown),
}
}
e := &product_common.FavoriteEntity{
EntityID: doc.ID,
EntityType: typ,
Name: doc.GetName(),
IconURL: pi.iconURI,
Description: pi.desc,
SpaceID: doc.GetSpaceID(),
HasSpacePermission: true,
FavoriteAt: doc.GetFavTime(),
UserInfo: userInfo,
}
return e, nil
}
func (s *SearchApplicationService) GetUserRecentlyEditIntelligence(ctx context.Context, req intelligence.GetUserRecentlyEditIntelligenceRequest) (
resp *intelligence.GetUserRecentlyEditIntelligenceResponse, err error,
) {
userID := ctxutil.GetUIDFromCtx(ctx)
if userID == nil {
return nil, errorx.New(errno.ErrSearchPermissionCode, errorx.KV("msg", "session required"))
}
res, err := SearchSVC.DomainSVC.SearchProjects(ctx, &searchEntity.SearchProjectsRequest{
OwnerID: *userID,
Types: req.Types,
IsRecentlyOpen: true,
OrderFiledName: search2.FieldOfRecentlyOpenTime,
OrderAsc: false,
Limit: req.Size,
})
if err != nil {
return nil, err
}
intelligenceDataList := make([]*intelligence.IntelligenceData, 0, len(res.Data))
for idx := range res.Data {
data := res.Data[idx]
info, err := s.packIntelligenceData(ctx, data)
if err != nil {
logs.CtxErrorf(ctx, "[packIntelligenceData] failed id %v, type %d, name %s, err: %v", data.ID, data.Type, data.GetName(), err)
continue
}
intelligenceDataList = append(intelligenceDataList, info)
}
resp = &intelligence.GetUserRecentlyEditIntelligenceResponse{
Data: &intelligence.GetUserRecentlyEditIntelligenceData{
IntelligenceInfoList: intelligenceDataList,
},
}
return resp, nil
}
func (s *SearchApplicationService) packIntelligenceData(ctx context.Context, doc *searchEntity.ProjectDocument) (*intelligence.IntelligenceData, error) {
intelligenceData := &intelligence.IntelligenceData{
Type: doc.Type,
BasicInfo: &common.IntelligenceBasicInfo{
ID: doc.ID,
Name: doc.GetName(),
SpaceID: doc.GetSpaceID(),
OwnerID: doc.GetOwnerID(),
Status: doc.Status,
CreateTime: doc.GetCreateTime() / 1000,
UpdateTime: doc.GetUpdateTime() / 1000,
PublishTime: doc.GetPublishTime() / 1000,
},
}
uid := ctxutil.MustGetUIDFromCtx(ctx)
packer, err := NewPackProject(uid, doc.ID, doc.Type, s)
if err != nil {
return nil, err
}
projInfo, err := packer.GetProjectInfo(ctx)
if err != nil {
return nil, errorx.Wrapf(err, "GetProjectInfo failed, id: %v, type: %v", doc.ID, doc.Type)
}
intelligenceData.BasicInfo.Description = projInfo.desc
intelligenceData.BasicInfo.IconURI = projInfo.iconURI
intelligenceData.BasicInfo.IconURL = s.getProjectIconURL(ctx, projInfo.iconURI, doc.Type)
intelligenceData.PermissionInfo = packer.GetPermissionInfo()
publishedInf := packer.GetPublishedInfo(ctx)
if publishedInf != nil {
intelligenceData.PublishInfo = packer.GetPublishedInfo(ctx)
} else {
intelligenceData.PublishInfo = &intelligence.IntelligencePublishInfo{
HasPublished: false,
}
}
intelligenceData.OwnerInfo = packer.GetUserInfo(ctx, doc.GetOwnerID())
intelligenceData.LatestAuditInfo = &common.AuditInfo{}
intelligenceData.FavoriteInfo = s.buildProjectFavoriteInfo(doc)
intelligenceData.OtherInfo = s.buildProjectOtherInfo(doc)
return intelligenceData, nil
}
func (s *SearchApplicationService) buildProjectFavoriteInfo(doc *searchEntity.ProjectDocument) *intelligence.FavoriteInfo {
isFav := doc.GetIsFav()
favTime := doc.GetFavTime()
return &intelligence.FavoriteInfo{
IsFav: isFav,
FavTime: conv.Int64ToStr(favTime / 1000),
}
}
func (s *SearchApplicationService) buildProjectOtherInfo(doc *searchEntity.ProjectDocument) *intelligence.OtherInfo {
otherInfo := &intelligence.OtherInfo{
BotMode: intelligence.BotMode_SingleMode,
RecentlyOpenTime: conv.Int64ToStr(doc.GetRecentlyOpenTime() / 1000),
}
if doc.Type == common.IntelligenceType_Project {
otherInfo.BotMode = intelligence.BotMode_WorkflowMode
}
return otherInfo
}
func searchRequestTo2Do(userID int64, req *intelligence.GetDraftIntelligenceListRequest) *searchEntity.SearchProjectsRequest {
orderBy := func() string {
switch req.GetOrderBy() {
case intelligence.OrderBy_PublishTime:
return search2.FieldOfPublishTime
case intelligence.OrderBy_UpdateTime:
return search2.FieldOfUpdateTime
case intelligence.OrderBy_CreateTime:
return search2.FieldOfCreateTime
default:
return search2.FieldOfUpdateTime
}
}()
searchReq := &searchEntity.SearchProjectsRequest{
SpaceID: req.GetSpaceID(),
Name: req.GetName(),
OwnerID: 0,
Limit: req.GetSize(),
Cursor: req.GetCursorID(),
OrderFiledName: orderBy,
OrderAsc: false,
Types: req.GetTypes(),
Status: req.GetStatus(),
IsFav: req.GetIsFav(),
IsRecentlyOpen: req.GetRecentlyOpen(),
IsPublished: req.GetHasPublished(),
}
if req.GetSearchScope() == intelligence.SearchScope_CreateByMe {
searchReq.OwnerID = userID
}
return searchReq
}
func (s *SearchApplicationService) getProjectDefaultIconURL(ctx context.Context, tp common.IntelligenceType) string {
iconURL, ok := projectType2iconURI[tp]
if !ok {
logs.CtxWarnf(ctx, "[getProjectDefaultIconURL] don't have type: %d default icon", tp)
return ""
}
return s.getURL(ctx, iconURL)
}
func (s *SearchApplicationService) getProjectIconURL(ctx context.Context, uri string, tp common.IntelligenceType) string {
if uri == "" {
return s.getProjectDefaultIconURL(ctx, tp)
}
url := s.getURL(ctx, uri)
if url != "" {
return url
}
return s.getProjectDefaultIconURL(ctx, tp)
}

View File

@@ -0,0 +1,325 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package search
import (
"context"
"fmt"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/api/model/resource/common"
"github.com/coze-dev/coze-studio/backend/api/model/table"
"github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
dbservice "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
var defaultAction = []*common.ResourceAction{
{
Key: common.ActionKey_Edit,
Enable: true,
},
{
Key: common.ActionKey_Delete,
Enable: true,
},
{
Key: common.ActionKey_Copy,
Enable: true,
},
}
type ResourcePacker interface {
GetDataInfo(ctx context.Context) (*dataInfo, error)
GetActions(ctx context.Context) []*common.ResourceAction
GetProjectDefaultActions(ctx context.Context) []*common.ProjectResourceAction
}
func NewResourcePacker(resID int64, t common.ResType, appContext *ServiceComponents) (ResourcePacker, error) {
base := resourceBasePacker{appContext: appContext, resID: resID}
switch t {
case common.ResType_Plugin:
return &pluginPacker{resourceBasePacker: base}, nil
case common.ResType_Workflow:
return &workflowPacker{resourceBasePacker: base}, nil
case common.ResType_Knowledge:
return &knowledgePacker{resourceBasePacker: base}, nil
case common.ResType_Prompt:
return &promptPacker{resourceBasePacker: base}, nil
case common.ResType_Database:
return &databasePacker{resourceBasePacker: base}, nil
}
return nil, fmt.Errorf("unsupported resource type: %s , resID: %d", t, resID)
}
type resourceBasePacker struct {
resID int64
appContext *ServiceComponents
}
type dataInfo struct {
iconURI *string
iconURL string
desc *string
status *int32
}
func (b *resourceBasePacker) GetActions(ctx context.Context) []*common.ResourceAction {
return defaultAction
}
func (b *resourceBasePacker) GetProjectDefaultActions(ctx context.Context) []*common.ProjectResourceAction {
return []*common.ProjectResourceAction{}
}
type pluginPacker struct {
resourceBasePacker
}
func (p *pluginPacker) GetDataInfo(ctx context.Context) (*dataInfo, error) {
plugin, err := p.appContext.PluginDomainSVC.GetDraftPlugin(ctx, p.resID)
if err != nil {
return nil, err
}
iconURL, err := p.appContext.TOS.GetObjectUrl(ctx, plugin.GetIconURI())
if err != nil {
logs.CtxWarnf(ctx, "get icon url failed with '%s', err=%v", plugin.GetIconURI(), err)
}
return &dataInfo{
iconURI: ptr.Of(plugin.GetIconURI()),
iconURL: iconURL,
desc: ptr.Of(plugin.GetDesc()),
}, nil
}
func (p *pluginPacker) GetProjectDefaultActions(ctx context.Context) []*common.ProjectResourceAction {
return []*common.ProjectResourceAction{
{
Key: common.ProjectResourceActionKey_Rename,
Enable: true,
},
{
Key: common.ProjectResourceActionKey_Copy,
Enable: true,
},
{
Key: common.ProjectResourceActionKey_Delete,
Enable: true,
},
{
Key: common.ProjectResourceActionKey_CopyToLibrary,
Enable: true,
},
{
Key: common.ProjectResourceActionKey_MoveToLibrary,
Enable: true,
},
}
}
type workflowPacker struct {
resourceBasePacker
}
func (w *workflowPacker) GetDataInfo(ctx context.Context) (*dataInfo, error) {
info, err := w.appContext.WorkflowDomainSVC.Get(ctx, &vo.GetPolicy{
ID: w.resID,
MetaOnly: true,
})
if err != nil {
return nil, err
}
return &dataInfo{
iconURI: &info.IconURI,
iconURL: info.IconURL,
desc: &info.Desc,
}, nil
}
func (w *workflowPacker) GetProjectDefaultActions(ctx context.Context) []*common.ProjectResourceAction {
return []*common.ProjectResourceAction{
{
Key: common.ProjectResourceActionKey_Rename,
Enable: true,
},
{
Key: common.ProjectResourceActionKey_Copy,
Enable: true,
},
{
Key: common.ProjectResourceActionKey_CopyToLibrary,
Enable: true,
},
{
Key: common.ProjectResourceActionKey_MoveToLibrary,
Enable: true,
},
{
Key: common.ProjectResourceActionKey_Delete,
Enable: true,
},
{
Key: common.ProjectResourceActionKey_UpdateDesc,
Enable: true,
},
}
}
type knowledgePacker struct {
resourceBasePacker
}
func (k *knowledgePacker) GetDataInfo(ctx context.Context) (*dataInfo, error) {
res, err := k.appContext.KnowledgeDomainSVC.GetKnowledgeByID(ctx, &service.GetKnowledgeByIDRequest{
KnowledgeID: k.resID,
})
if err != nil {
return nil, err
}
kn := res.Knowledge
return &dataInfo{
iconURI: ptr.Of(kn.IconURI),
iconURL: kn.IconURL,
desc: ptr.Of(kn.Description),
status: ptr.Of(int32(kn.Status)),
}, nil
}
func (k *knowledgePacker) GetActions(ctx context.Context) []*common.ResourceAction {
return []*common.ResourceAction{
{
Key: common.ActionKey_Delete,
Enable: true,
},
{
Key: common.ActionKey_EnableSwitch,
Enable: true,
},
{
Key: common.ActionKey_Edit,
Enable: true,
},
}
}
func (k *knowledgePacker) GetProjectDefaultActions(ctx context.Context) []*common.ProjectResourceAction {
return []*common.ProjectResourceAction{
{
Key: common.ProjectResourceActionKey_Rename,
Enable: true,
},
{
Key: common.ProjectResourceActionKey_Copy,
Enable: true,
},
{
Key: common.ProjectResourceActionKey_CopyToLibrary,
Enable: true,
},
{
Key: common.ProjectResourceActionKey_MoveToLibrary,
Enable: true,
},
{
Key: common.ProjectResourceActionKey_Delete,
Enable: true,
},
{
Key: common.ProjectResourceActionKey_Disable,
Enable: true,
},
}
}
type promptPacker struct {
resourceBasePacker
}
func (p *promptPacker) GetDataInfo(ctx context.Context) (*dataInfo, error) {
pInfo, err := p.appContext.PromptDomainSVC.GetPromptResource(ctx, p.resID)
if err != nil {
return nil, err
}
return &dataInfo{
iconURI: nil, // prompt don't have custom icon
iconURL: "",
desc: &pInfo.Description,
}, nil
}
type databasePacker struct {
resourceBasePacker
}
func (d *databasePacker) GetDataInfo(ctx context.Context) (*dataInfo, error) {
listResp, err := d.appContext.DatabaseDomainSVC.MGetDatabase(ctx, &dbservice.MGetDatabaseRequest{Basics: []*database.DatabaseBasic{
{
ID: d.resID,
TableType: table.TableType_OnlineTable,
},
}})
if err != nil {
return nil, err
}
if len(listResp.Databases) == 0 {
return nil, fmt.Errorf("online database not found, id: %d", d.resID)
}
return &dataInfo{
iconURI: ptr.Of(listResp.Databases[0].IconURI),
iconURL: listResp.Databases[0].IconURL,
desc: ptr.Of(listResp.Databases[0].TableDesc),
}, nil
}
func (d *databasePacker) GetActions(ctx context.Context) []*common.ResourceAction {
return []*common.ResourceAction{
{
Key: common.ActionKey_Delete,
Enable: true,
},
}
}
func (d *databasePacker) GetProjectDefaultActions(ctx context.Context) []*common.ProjectResourceAction {
return []*common.ProjectResourceAction{
{
Key: common.ProjectResourceActionKey_Copy,
Enable: true,
},
{
Key: common.ProjectResourceActionKey_CopyToLibrary,
Enable: true,
},
{
Key: common.ProjectResourceActionKey_MoveToLibrary,
Enable: true,
},
{
Key: common.ProjectResourceActionKey_Delete,
Enable: true,
},
}
}

View File

@@ -0,0 +1,392 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package search
import (
"context"
"errors"
"slices"
"strconv"
"sync"
knowledgeModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/api/model/resource"
"github.com/coze-dev/coze-studio/backend/api/model/resource/common"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
"github.com/coze-dev/coze-studio/backend/domain/search/entity"
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/taskgroup"
"github.com/coze-dev/coze-studio/backend/types/consts"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
var SearchSVC = &SearchApplicationService{}
type SearchApplicationService struct {
*ServiceComponents
DomainSVC search.Search
}
var resType2iconURI = map[common.ResType]string{
common.ResType_Plugin: consts.DefaultPluginIcon,
common.ResType_Workflow: consts.DefaultWorkflowIcon,
common.ResType_Knowledge: consts.DefaultDatasetIcon,
common.ResType_Prompt: consts.DefaultPromptIcon,
common.ResType_Database: consts.DefaultDatabaseIcon,
// ResType_UI: consts.DefaultWorkflowIcon,
// ResType_Voice: consts.DefaultPluginIcon,
// ResType_Imageflow: consts.DefaultPluginIcon,
}
func (s *SearchApplicationService) LibraryResourceList(ctx context.Context, req *resource.LibraryResourceListRequest) (resp *resource.LibraryResourceListResponse, err error) {
userID := ctxutil.GetUIDFromCtx(ctx)
if userID == nil {
return nil, errorx.New(errno.ErrSearchPermissionCode, errorx.KV("msg", "session required"))
}
searchReq := &entity.SearchResourcesRequest{
SpaceID: req.GetSpaceID(),
OwnerID: 0,
Name: req.GetName(),
ResTypeFilter: req.GetResTypeFilter(),
PublishStatusFilter: req.GetPublishStatusFilter(),
SearchKeys: req.GetSearchKeys(),
Cursor: req.GetCursor(),
Limit: req.GetSize(),
}
// 设置用户过滤
if req.IsSetUserFilter() && req.GetUserFilter() > 0 {
searchReq.OwnerID = ptr.From(userID)
}
searchResp, err := s.DomainSVC.SearchResources(ctx, searchReq)
if err != nil {
return nil, err
}
lock := sync.Mutex{}
tasks := taskgroup.NewUninterruptibleTaskGroup(ctx, 10)
resources := make([]*common.ResourceInfo, len(searchResp.Data))
for idx := range searchResp.Data {
v := searchResp.Data[idx]
index := idx
tasks.Go(func() error {
ri, err := s.packResource(ctx, v)
if err != nil {
logs.CtxErrorf(ctx, "[LibraryResourceList] packResource failed, will ignore resID: %d, Name : %s, resType: %d, err: %v",
v.ResID, v.GetName(), v.ResType, err)
return err
}
lock.Lock()
defer lock.Unlock()
resources[index] = ri
return nil
})
}
_ = tasks.Wait()
filterResource := make([]*common.ResourceInfo, 0)
for _, res := range resources {
if res == nil {
continue
}
filterResource = append(filterResource, res)
}
return &resource.LibraryResourceListResponse{
Code: 0,
ResourceList: filterResource,
Cursor: ptr.Of(searchResp.NextCursor),
HasMore: searchResp.HasMore,
}, nil
}
func (s *SearchApplicationService) getResourceDefaultIconURL(ctx context.Context, tp common.ResType) string {
iconURL, ok := resType2iconURI[tp]
if !ok {
logs.CtxWarnf(ctx, "[getDefaultIconURL] don't have type: %d default icon", tp)
return ""
}
return s.getURL(ctx, iconURL)
}
func (s *SearchApplicationService) getURL(ctx context.Context, uri string) string {
url, err := s.TOS.GetObjectUrl(ctx, uri)
if err != nil {
logs.CtxWarnf(ctx, "[getDefaultIconURLWitURI] GetObjectUrl failed, uri: %s, err: %v", uri, err)
return ""
}
return url
}
func (s *SearchApplicationService) getResourceIconURL(ctx context.Context, uri *string, tp common.ResType) string {
if uri == nil || *uri == "" {
return s.getResourceDefaultIconURL(ctx, tp)
}
url := s.getURL(ctx, *uri)
if url != "" {
return url
}
return s.getResourceDefaultIconURL(ctx, tp)
}
func (s *SearchApplicationService) packUserInfo(ctx context.Context, ri *common.ResourceInfo, ownerID int64) *common.ResourceInfo {
u, err := s.UserDomainSVC.GetUserInfo(ctx, ownerID)
if err != nil {
logs.CtxWarnf(ctx, "[LibraryResourceList] GetUserInfo failed, uid: %d, resID: %d, Name : %s, err: %v",
ownerID, ri.ResID, ri.GetName(), err)
} else {
ri.CreatorName = ptr.Of(u.Name)
ri.CreatorAvatar = ptr.Of(u.IconURL)
}
if ri.GetCreatorAvatar() == "" {
ri.CreatorAvatar = ptr.Of(s.getURL(ctx, consts.DefaultUserIcon))
}
return ri
}
func (s *SearchApplicationService) packResource(ctx context.Context, doc *entity.ResourceDocument) (*common.ResourceInfo, error) {
ri := &common.ResourceInfo{
ResID: ptr.Of(doc.ResID),
ResType: ptr.Of(doc.ResType),
Name: doc.Name,
SpaceID: doc.SpaceID,
CreatorID: doc.OwnerID,
ResSubType: doc.ResSubType,
PublishStatus: doc.PublishStatus,
EditTime: ptr.Of(doc.GetUpdateTime() / 1000),
}
if doc.BizStatus != nil {
ri.BizResStatus = ptr.Of(int32(*doc.BizStatus))
}
packer, err := NewResourcePacker(doc.ResID, doc.ResType, s.ServiceComponents)
if err != nil {
return nil, errorx.Wrapf(err, "NewResourcePacker failed")
}
ri = s.packUserInfo(ctx, ri, doc.GetOwnerID())
ri.Actions = packer.GetActions(ctx)
data, err := packer.GetDataInfo(ctx)
if err != nil {
logs.CtxWarnf(ctx, "[packResource] GetDataInfo failed, resID: %d, Name : %s, resType: %d, err: %v",
doc.ResID, doc.GetName(), doc.ResType, err)
ri.Icon = ptr.Of(s.getResourceDefaultIconURL(ctx, doc.ResType))
return ri, nil // Warn : weak dependency data
}
ri.BizResStatus = data.status
ri.Desc = data.desc
ri.Icon = ternary.IFElse(len(data.iconURL) > 0,
&data.iconURL, ptr.Of(s.getResourceIconURL(ctx, data.iconURI, doc.ResType)))
ri.BizExtend = map[string]string{
"url": ptr.From(ri.Icon),
}
return ri, nil
}
func (s *SearchApplicationService) ProjectResourceList(ctx context.Context, req *resource.ProjectResourceListRequest) (resp *resource.ProjectResourceListResponse, err error) {
resources, err := s.getAPPAllResources(ctx, req.GetProjectID())
if err != nil {
return nil, err
}
resourceGroups, err := s.packAPPResources(ctx, resources)
if err != nil {
return nil, err
}
resourceGroups = s.sortAPPResources(resourceGroups)
return &resource.ProjectResourceListResponse{
ResourceGroups: resourceGroups,
}, nil
}
func (s *SearchApplicationService) getAPPAllResources(ctx context.Context, appID int64) ([]*entity.ResourceDocument, error) {
cursor := ""
resources := make([]*entity.ResourceDocument, 0, 100)
for {
res, err := s.DomainSVC.SearchResources(ctx, &entity.SearchResourcesRequest{
APPID: appID,
Cursor: cursor,
Limit: 100,
})
if err != nil {
return nil, err
}
resources = append(resources, res.Data...)
hasMore := res.HasMore
cursor = res.NextCursor
if !hasMore {
break
}
}
return resources, nil
}
func (s *SearchApplicationService) packAPPResources(ctx context.Context, resources []*entity.ResourceDocument) ([]*common.ProjectResourceGroup, error) {
workflowGroup := &common.ProjectResourceGroup{
GroupType: common.ProjectResourceGroupType_Workflow,
ResourceList: []*common.ProjectResourceInfo{},
}
dataGroup := &common.ProjectResourceGroup{
GroupType: common.ProjectResourceGroupType_Data,
ResourceList: []*common.ProjectResourceInfo{},
}
pluginGroup := &common.ProjectResourceGroup{
GroupType: common.ProjectResourceGroupType_Plugin,
ResourceList: []*common.ProjectResourceInfo{},
}
lock := sync.Mutex{}
tasks := taskgroup.NewUninterruptibleTaskGroup(ctx, 10)
for idx := range resources {
v := resources[idx]
tasks.Go(func() error {
ri, err := s.packProjectResource(ctx, v)
if err != nil {
logs.CtxErrorf(ctx, "packAPPResources failed, will ignore resID: %d, Name : %s, resType: %d, err: %v",
v.ResID, v.GetName(), v.ResType, err)
return err
}
lock.Lock()
defer lock.Unlock()
switch v.ResType {
case common.ResType_Workflow:
workflowGroup.ResourceList = append(workflowGroup.ResourceList, ri)
case common.ResType_Plugin:
pluginGroup.ResourceList = append(pluginGroup.ResourceList, ri)
case common.ResType_Database, common.ResType_Knowledge:
dataGroup.ResourceList = append(dataGroup.GetResourceList(), ri)
default:
logs.CtxWarnf(ctx, "unsupported resType: %d", v.ResType)
}
return nil
})
}
_ = tasks.Wait()
resourceGroups := []*common.ProjectResourceGroup{
workflowGroup,
pluginGroup,
dataGroup,
}
return resourceGroups, nil
}
func (s *SearchApplicationService) packProjectResource(ctx context.Context, resource *entity.ResourceDocument) (*common.ProjectResourceInfo, error) {
packer, err := NewResourcePacker(resource.ResID, resource.ResType, s.ServiceComponents)
if err != nil {
return nil, err
}
info := &common.ProjectResourceInfo{
ResID: resource.ResID,
ResType: resource.ResType,
ResSubType: resource.ResSubType,
Name: resource.GetName(),
Actions: packer.GetProjectDefaultActions(ctx),
}
if resource.ResType == common.ResType_Knowledge {
info.BizExtend = map[string]string{
"format_type": strconv.FormatInt(int64(resource.GetResSubType()), 10),
}
di, err := packer.GetDataInfo(ctx)
if err != nil {
logs.CtxErrorf(ctx, "GetDataInfo failed, resID=%d, resType=%d, err=%v",
resource.ResID, resource.ResType, err)
} else {
info.BizResStatus = ptr.Of(*di.status)
if *di.status == int32(knowledgeModel.KnowledgeStatusDisable) {
actions := slices.Clone(info.Actions)
for _, a := range actions {
if a.Key == common.ProjectResourceActionKey_Disable {
a.Key = common.ProjectResourceActionKey_Enable
break
}
}
}
}
}
if resource.ResType == common.ResType_Plugin {
err = s.PluginDomainSVC.CheckPluginToolsDebugStatus(ctx, resource.ResID)
if err != nil {
var e errorx.StatusError
if !errors.As(err, &e) {
logs.CtxErrorf(ctx, "CheckPluginToolsDebugStatus failed, resID=%d, resType=%d, err=%v",
resource.ResID, resource.ResType, err)
} else {
actions := slices.Clone(info.Actions)
for _, a := range actions {
if a.Key == common.ProjectResourceActionKey_MoveToLibrary ||
a.Key == common.ProjectResourceActionKey_CopyToLibrary {
a.Enable = false
a.Hint = ptr.Of(e.Msg())
}
}
}
}
}
return info, nil
}
func (s *ServiceComponents) sortAPPResources(resourceGroups []*common.ProjectResourceGroup) []*common.ProjectResourceGroup {
for _, g := range resourceGroups {
slices.SortFunc(g.ResourceList, func(a, b *common.ProjectResourceInfo) int {
if a.Name == b.Name {
return 0
}
if a.Name < b.Name {
return -1
}
return 1
})
}
return resourceGroups
}

View File

@@ -0,0 +1,40 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package shortcutcmd
import (
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/repository"
"github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/service"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
)
var ShortcutCmdSVC *ShortcutCmdApplicationService
func InitService(db *gorm.DB, idGenSVC idgen.IDGenerator) *ShortcutCmdApplicationService {
components := &service.Components{
ShortCutCmdRepo: repository.NewShortCutCmdRepo(db, idGenSVC),
}
shortcutCmdDomainSVC := service.NewShortcutCommandService(components)
ShortcutCmdSVC = &ShortcutCmdApplicationService{
ShortCutDomainSVC: shortcutCmdDomainSVC,
}
return ShortcutCmdSVC
}

View File

@@ -0,0 +1,119 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package shortcutcmd
import (
"context"
"strconv"
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/playground"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
"github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity"
"github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/service"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
)
type ShortcutCmdApplicationService struct {
ShortCutDomainSVC service.ShortcutCmd
}
func (s *ShortcutCmdApplicationService) Handler(ctx context.Context, req *playground.CreateUpdateShortcutCommandRequest) (*playground.ShortcutCommand, error) {
cr, buildErr := s.buildReq(ctx, req)
if buildErr != nil {
return nil, buildErr
}
var err error
var cmdDO *entity.ShortcutCmd
if cr.CommandID > 0 {
cmdDO, err = s.ShortCutDomainSVC.UpdateCMD(ctx, cr)
} else {
cmdDO, err = s.ShortCutDomainSVC.CreateCMD(ctx, cr)
}
if err != nil {
return nil, err
}
if cmdDO == nil {
return nil, nil
}
return s.buildDo2Vo(ctx, cmdDO), nil
}
func (s *ShortcutCmdApplicationService) buildReq(ctx context.Context, req *playground.CreateUpdateShortcutCommandRequest) (*entity.ShortcutCmd, error) {
uid := ctxutil.MustGetUIDFromCtx(ctx)
var workflowID int64
var pluginID int64
var err error
if req.GetShortcuts().GetWorkFlowID() != "" {
workflowID, err = strconv.ParseInt(req.GetShortcuts().GetWorkFlowID(), 10, 64)
if err != nil {
return nil, err
}
}
if req.GetShortcuts().GetPluginID() != "" {
pluginID, err = strconv.ParseInt(req.GetShortcuts().GetPluginID(), 10, 64)
if err != nil {
return nil, err
}
}
return &entity.ShortcutCmd{
ObjectID: req.GetObjectID(),
CommandID: req.GetShortcuts().CommandID,
CommandName: req.GetShortcuts().CommandName,
ShortcutCommand: req.GetShortcuts().ShortcutCommand,
Description: req.GetShortcuts().Description,
SendType: int32(req.GetShortcuts().SendType),
ToolType: int32(req.GetShortcuts().ToolType),
WorkFlowID: workflowID,
PluginID: pluginID,
Components: req.GetShortcuts().ComponentsList,
CardSchema: req.GetShortcuts().CardSchema,
ToolInfo: req.GetShortcuts().ToolInfo,
CreatorID: uid,
PluginToolID: req.GetShortcuts().PluginAPIID,
PluginToolName: req.GetShortcuts().PluginAPIName,
TemplateQuery: req.GetShortcuts().TemplateQuery,
ShortcutIcon: req.GetShortcuts().ShortcutIcon,
}, nil
}
func (s *ShortcutCmdApplicationService) buildDo2Vo(ctx context.Context, do *entity.ShortcutCmd) *playground.ShortcutCommand {
return &playground.ShortcutCommand{
ObjectID: do.ObjectID,
CommandID: do.CommandID,
CommandName: do.CommandName,
ShortcutCommand: do.ShortcutCommand,
Description: do.Description,
SendType: playground.SendType(do.SendType),
ToolType: playground.ToolType(do.ToolType),
WorkFlowID: conv.Int64ToStr(do.WorkFlowID),
PluginID: conv.Int64ToStr(do.PluginID),
ComponentsList: do.Components,
CardSchema: do.CardSchema,
ToolInfo: do.ToolInfo,
PluginAPIID: do.PluginToolID,
PluginAPIName: do.PluginToolName,
TemplateQuery: do.TemplateQuery,
ShortcutIcon: do.ShortcutIcon,
}
}

View File

@@ -0,0 +1,204 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package singleagent
import (
"context"
"time"
modelmgrEntity "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
intelligence "github.com/coze-dev/coze-studio/backend/api/model/intelligence/common"
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_common"
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
"github.com/coze-dev/coze-studio/backend/domain/modelmgr"
searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
func (s *SingleAgentApplicationService) CreateSingleAgentDraft(ctx context.Context, req *developer_api.DraftBotCreateRequest) (*developer_api.DraftBotCreateResponse, error) {
do, err := s.draftBotCreateRequestToSingleAgent(ctx, req)
if err != nil {
return nil, err
}
userID := ctxutil.MustGetUIDFromCtx(ctx)
agentID, err := s.DomainSVC.CreateSingleAgentDraft(ctx, userID, do)
if err != nil {
return nil, err
}
err = s.appContext.EventBus.PublishProject(ctx, &searchEntity.ProjectDomainEvent{
OpType: searchEntity.Created,
Project: &searchEntity.ProjectDocument{
Status: intelligence.IntelligenceStatus_Using,
Type: intelligence.IntelligenceType_Bot,
ID: agentID,
SpaceID: &req.SpaceID,
OwnerID: &userID,
Name: &do.Name,
},
})
if err != nil {
return nil, err
}
return &developer_api.DraftBotCreateResponse{Data: &developer_api.DraftBotCreateData{
BotID: agentID,
}}, nil
}
func (s *SingleAgentApplicationService) draftBotCreateRequestToSingleAgent(ctx context.Context, req *developer_api.DraftBotCreateRequest) (*entity.SingleAgent, error) {
sa, err := s.newDefaultSingleAgent(ctx)
if err != nil {
return nil, err
}
sa.SpaceID = req.SpaceID
sa.Name = req.GetName()
sa.Desc = req.GetDescription()
sa.IconURI = req.GetIconURI()
return sa, nil
}
func (s *SingleAgentApplicationService) newDefaultSingleAgent(ctx context.Context) (*entity.SingleAgent, error) {
mi, err := s.defaultModelInfo(ctx)
if err != nil {
return nil, err
}
now := time.Now().UnixMilli()
return &entity.SingleAgent{
SingleAgent: &singleagent.SingleAgent{
OnboardingInfo: &bot_common.OnboardingInfo{},
ModelInfo: mi,
Prompt: &bot_common.PromptInfo{},
Plugin: []*bot_common.PluginInfo{},
Knowledge: &bot_common.Knowledge{
TopK: ptr.Of(int64(1)),
MinScore: ptr.Of(float64(0.01)),
SearchStrategy: ptr.Of(bot_common.SearchStrategy_SemanticSearch),
RecallStrategy: &bot_common.RecallStrategy{
UseNl2sql: ptr.Of(true),
UseRerank: ptr.Of(true),
UseRewrite: ptr.Of(true),
},
},
Workflow: []*bot_common.WorkflowInfo{},
SuggestReply: &bot_common.SuggestReplyInfo{},
JumpConfig: &bot_common.JumpConfig{},
Database: []*bot_common.Database{},
CreatedAt: now,
UpdatedAt: now,
},
}, nil
}
func (s *SingleAgentApplicationService) defaultModelInfo(ctx context.Context) (*bot_common.ModelInfo, error) {
modelResp, err := s.appContext.ModelMgrDomainSVC.ListModel(ctx, &modelmgr.ListModelRequest{
Status: []modelmgrEntity.ModelEntityStatus{modelmgrEntity.ModelEntityStatusDefault, modelmgrEntity.ModelEntityStatusInUse},
Limit: 1,
Cursor: nil,
})
if err != nil {
return nil, err
}
if len(modelResp.ModelList) == 0 {
return nil, errorx.New(errno.ErrAgentResourceNotFound, errorx.KV("type", "model"), errorx.KV("id", "default"))
}
dm := modelResp.ModelList[0]
var temperature *float64
if tp, ok := dm.FindParameter(modelmgrEntity.Temperature); ok {
t, err := tp.GetFloat(modelmgrEntity.DefaultTypeBalance)
if err != nil {
return nil, err
}
temperature = ptr.Of(t)
}
var maxTokens *int32
if tp, ok := dm.FindParameter(modelmgrEntity.MaxTokens); ok {
t, err := tp.GetInt(modelmgrEntity.DefaultTypeBalance)
if err != nil {
return nil, err
}
maxTokens = ptr.Of(int32(t))
} else if dm.Meta.ConnConfig.MaxTokens != nil {
maxTokens = ptr.Of(int32(*dm.Meta.ConnConfig.MaxTokens))
}
var topP *float64
if tp, ok := dm.FindParameter(modelmgrEntity.TopP); ok {
t, err := tp.GetFloat(modelmgrEntity.DefaultTypeBalance)
if err != nil {
return nil, err
}
topP = ptr.Of(t)
}
var topK *int32
if tp, ok := dm.FindParameter(modelmgrEntity.TopK); ok {
t, err := tp.GetInt(modelmgrEntity.DefaultTypeBalance)
if err != nil {
return nil, err
}
topK = ptr.Of(int32(t))
}
var frequencyPenalty *float64
if tp, ok := dm.FindParameter(modelmgrEntity.FrequencyPenalty); ok {
t, err := tp.GetFloat(modelmgrEntity.DefaultTypeBalance)
if err != nil {
return nil, err
}
frequencyPenalty = ptr.Of(t)
}
var presencePenalty *float64
if tp, ok := dm.FindParameter(modelmgrEntity.PresencePenalty); ok {
t, err := tp.GetFloat(modelmgrEntity.DefaultTypeBalance)
if err != nil {
return nil, err
}
presencePenalty = ptr.Of(t)
}
return &bot_common.ModelInfo{
ModelId: ptr.Of(dm.ID),
Temperature: temperature,
MaxTokens: maxTokens,
TopP: topP,
FrequencyPenalty: frequencyPenalty,
PresencePenalty: presencePenalty,
TopK: topK,
ModelStyle: bot_common.ModelStylePtr(bot_common.ModelStyle_Balance),
ShortMemoryPolicy: &bot_common.ShortMemoryPolicy{
ContextMode: bot_common.ContextModePtr(bot_common.ContextMode_FunctionCall_2),
HistoryRound: ptr.Of[int32](3),
},
}, nil
}

View File

@@ -0,0 +1,177 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package singleagent
import (
"context"
intelligence "github.com/coze-dev/coze-studio/backend/api/model/intelligence/common"
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
"github.com/coze-dev/coze-studio/backend/api/model/project_memory"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin"
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity"
shortcutCMDEntity "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
)
type duplicateAgentResourceFn func(ctx context.Context, appContext *ServiceComponents, oldAgent, newAgent *entity.SingleAgent) (*entity.SingleAgent, error)
func (s *SingleAgentApplicationService) DuplicateDraftBot(ctx context.Context, req *developer_api.DuplicateDraftBotRequest) (*developer_api.DuplicateDraftBotResponse, error) {
draftAgent, err := s.ValidateAgentDraftAccess(ctx, req.BotID)
if err != nil {
return nil, err
}
newAgentID, err := s.appContext.IDGen.GenID(ctx)
if err != nil {
return nil, err
}
userID := ctxutil.MustGetUIDFromCtx(ctx)
duplicateInfo := &entity.DuplicateInfo{
NewAgentID: newAgentID,
SpaceID: req.GetSpaceID(),
UserID: userID,
DraftAgent: draftAgent,
}
newAgent, err := s.DomainSVC.DuplicateInMemory(ctx, duplicateInfo)
if err != nil {
return nil, err
}
duplicateFns := []duplicateAgentResourceFn{
duplicateVariables,
duplicatePlugin,
duplicateShortCommand,
}
for _, fn := range duplicateFns {
newAgent, err = fn(ctx, s.appContext, draftAgent, newAgent)
if err != nil {
return nil, err
}
}
_, err = s.DomainSVC.CreateSingleAgentDraftWithID(ctx, userID, newAgentID, newAgent)
if err != nil {
return nil, err
}
userInfo, err := s.appContext.UserDomainSVC.GetUserInfo(ctx, userID)
if err != nil {
return nil, err
}
err = s.appContext.EventBus.PublishProject(ctx, &searchEntity.ProjectDomainEvent{
OpType: searchEntity.Created,
Project: &searchEntity.ProjectDocument{
Status: intelligence.IntelligenceStatus_Using,
Type: intelligence.IntelligenceType_Bot,
ID: newAgent.AgentID,
SpaceID: &req.SpaceID,
OwnerID: &userID,
Name: &newAgent.Name,
},
})
if err != nil {
return nil, err
}
return &developer_api.DuplicateDraftBotResponse{
Data: &developer_api.DuplicateDraftBotData{
BotID: newAgent.AgentID,
Name: newAgent.Name,
UserInfo: &developer_api.Creator{
ID: userID,
Name: userInfo.Name,
AvatarURL: userInfo.IconURL,
Self: userID == draftAgent.CreatorID,
UserUniqueName: userInfo.UniqueName,
UserLabel: nil,
},
},
Code: 0,
}, nil
}
func duplicateVariables(ctx context.Context, appContext *ServiceComponents, oldAgent, newAgent *entity.SingleAgent) (*entity.SingleAgent, error) {
if oldAgent.VariablesMetaID == nil || *oldAgent.VariablesMetaID <= 0 {
return newAgent, nil
}
vars, err := appContext.VariablesDomainSVC.GetVariableMetaByID(ctx, *oldAgent.VariablesMetaID)
if err != nil {
return nil, err
}
vars.ID = 0
vars.BizID = conv.Int64ToStr(newAgent.AgentID)
vars.BizType = project_memory.VariableConnector_Bot
vars.Version = ""
vars.CreatorID = newAgent.CreatorID
varMetaID, err := appContext.VariablesDomainSVC.UpsertMeta(ctx, vars)
if err != nil {
return nil, err
}
newAgent.VariablesMetaID = &varMetaID
return newAgent, nil
}
func duplicatePlugin(ctx context.Context, _ *ServiceComponents, oldAgent, newAgent *entity.SingleAgent) (*entity.SingleAgent, error) {
err := crossplugin.DefaultSVC().DuplicateDraftAgentTools(ctx, oldAgent.AgentID, newAgent.AgentID)
if err != nil {
return nil, err
}
return newAgent, nil
}
func duplicateShortCommand(ctx context.Context, appContext *ServiceComponents, oldAgent, newAgent *entity.SingleAgent) (*entity.SingleAgent, error) {
metas, err := appContext.ShortcutCMDDomainSVC.ListCMD(ctx, &shortcutCMDEntity.ListMeta{
SpaceID: oldAgent.SpaceID,
ObjectID: oldAgent.AgentID,
IsOnline: 0,
CommandIDs: slices.Transform(oldAgent.ShortcutCommand, func(a string) int64 {
return conv.StrToInt64D(a, 0)
}),
})
if err != nil {
return nil, err
}
shortcutCommandIDs := make([]string, 0, len(metas))
for _, meta := range metas {
meta.ObjectID = newAgent.AgentID
meta.CreatorID = newAgent.CreatorID
do, err := appContext.ShortcutCMDDomainSVC.CreateCMD(ctx, meta)
if err != nil {
return nil, err
}
shortcutCommandIDs = append(shortcutCommandIDs, conv.Int64ToStr(do.CommandID))
}
newAgent.ShortcutCommand = shortcutCommandIDs
return newAgent, nil
}

View File

@@ -0,0 +1,587 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package singleagent
import (
"context"
"fmt"
"github.com/getkin/kin-openapi/openapi3"
knowledgeModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_common"
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/playground"
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
"github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
knowledge "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
"github.com/coze-dev/coze-studio/backend/domain/modelmgr"
modelEntity "github.com/coze-dev/coze-studio/backend/domain/modelmgr/entity"
pluginEntity "github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
"github.com/coze-dev/coze-studio/backend/domain/plugin/service"
shortcutCMDEntity "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity"
workflowEntity "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
func (s *SingleAgentApplicationService) GetAgentBotInfo(ctx context.Context, req *playground.GetDraftBotInfoAgwRequest) (*playground.GetDraftBotInfoAgwResponse, error) {
agentInfo, err := s.DomainSVC.GetSingleAgent(ctx, req.GetBotID(), req.GetVersion())
if err != nil {
return nil, err
}
if agentInfo == nil {
return nil, errorx.New(errno.ErrAgentInvalidParamCode, errorx.KVf("msg", "agent %d not found", req.GetBotID()))
}
vo, err := s.singleAgentDraftDo2Vo(ctx, agentInfo)
if err != nil {
return nil, err
}
klInfos, err := s.fetchKnowledgeDetails(ctx, agentInfo)
if err != nil {
return nil, err
}
modelInfos, err := s.fetchModelDetails(ctx, agentInfo)
if err != nil {
return nil, err
}
toolInfos, err := s.fetchToolDetails(ctx, agentInfo, req)
if err != nil {
return nil, err
}
pluginInfos, err := s.fetchPluginDetails(ctx, agentInfo, toolInfos)
if err != nil {
return nil, err
}
workflowInfos, err := s.fetchWorkflowDetails(ctx, agentInfo)
if err != nil {
return nil, err
}
shortCutCmdResp, err := s.fetchShortcutCMD(ctx, agentInfo)
if err != nil {
return nil, err
}
workflowDetailMap, err := workflowDo2Vo(workflowInfos)
if err != nil {
return nil, err
}
return &playground.GetDraftBotInfoAgwResponse{
Data: &playground.GetDraftBotInfoAgwData{
BotInfo: vo,
BotOptionData: &playground.BotOptionData{
ModelDetailMap: modelInfoDo2Vo(modelInfos),
KnowledgeDetailMap: knowledgeInfoDo2Vo(klInfos),
PluginAPIDetailMap: toolInfoDo2Vo(toolInfos),
PluginDetailMap: s.pluginInfoDo2Vo(ctx, pluginInfos),
WorkflowDetailMap: workflowDetailMap,
ShortcutCommandList: shortCutCmdResp,
},
SpaceID: agentInfo.SpaceID,
Editable: ptr.Of(true),
Deletable: ptr.Of(true),
},
}, nil
}
func (s *SingleAgentApplicationService) fetchShortcutCMD(ctx context.Context, agentInfo *entity.SingleAgent) ([]*playground.ShortcutCommand, error) {
var cmdVOs []*playground.ShortcutCommand
if len(agentInfo.ShortcutCommand) == 0 {
return cmdVOs, nil
}
cmdDOs, err := s.appContext.ShortcutCMDDomainSVC.ListCMD(ctx, &shortcutCMDEntity.ListMeta{
SpaceID: agentInfo.SpaceID,
ObjectID: agentInfo.AgentID,
CommandIDs: slices.Transform(agentInfo.ShortcutCommand, func(a string) int64 {
return conv.StrToInt64D(a, 0)
}),
})
logs.CtxInfof(ctx, "fetchShortcutCMD cmdDOs = %v, err = %v", conv.DebugJsonToStr(cmdDOs), err)
if err != nil {
return nil, err
}
cmdVOs = s.shortcutCMDDo2Vo(cmdDOs)
return cmdVOs, nil
}
func (s *SingleAgentApplicationService) shortcutCMDDo2Vo(cmdDOs []*shortcutCMDEntity.ShortcutCmd) []*playground.ShortcutCommand {
return slices.Transform(cmdDOs, func(cmdDO *shortcutCMDEntity.ShortcutCmd) *playground.ShortcutCommand {
return &playground.ShortcutCommand{
ObjectID: cmdDO.ObjectID,
CommandID: cmdDO.CommandID,
CommandName: cmdDO.CommandName,
ShortcutCommand: cmdDO.ShortcutCommand,
Description: cmdDO.Description,
SendType: playground.SendType(cmdDO.SendType),
ToolType: playground.ToolType(cmdDO.ToolType),
WorkFlowID: conv.Int64ToStr(cmdDO.WorkFlowID),
PluginID: conv.Int64ToStr(cmdDO.PluginID),
PluginAPIName: cmdDO.PluginToolName,
PluginAPIID: cmdDO.PluginToolID,
ShortcutIcon: cmdDO.ShortcutIcon,
TemplateQuery: cmdDO.TemplateQuery,
ComponentsList: cmdDO.Components,
CardSchema: cmdDO.CardSchema,
ToolInfo: cmdDO.ToolInfo,
}
})
}
func (s *SingleAgentApplicationService) fetchModelDetails(ctx context.Context, agentInfo *entity.SingleAgent) ([]*modelEntity.Model, error) {
if agentInfo.ModelInfo.ModelId == nil {
return nil, nil
}
modelID := agentInfo.ModelInfo.GetModelId()
modelInfos, err := s.appContext.ModelMgrDomainSVC.MGetModelByID(ctx, &modelmgr.MGetModelRequest{
IDs: []int64{modelID},
})
if err != nil {
return nil, fmt.Errorf("fetch model(%d) details failed: %v", modelID, err)
}
return modelInfos, nil
}
func (s *SingleAgentApplicationService) fetchKnowledgeDetails(ctx context.Context, agentInfo *entity.SingleAgent) ([]*knowledgeModel.Knowledge, error) {
knowledgeIDs := make([]int64, 0, len(agentInfo.Knowledge.KnowledgeInfo))
for _, v := range agentInfo.Knowledge.KnowledgeInfo {
id, err := conv.StrToInt64(v.GetId())
if err != nil {
return nil, fmt.Errorf("invalid knowledge id: %s", v.GetId())
}
knowledgeIDs = append(knowledgeIDs, id)
}
if len(knowledgeIDs) == 0 {
return nil, nil
}
listResp, err := s.appContext.KnowledgeDomainSVC.ListKnowledge(ctx, &knowledge.ListKnowledgeRequest{
IDs: knowledgeIDs,
})
if err != nil {
return nil, fmt.Errorf("fetch knowledge details failed: %v", err)
}
return listResp.KnowledgeList, err
}
func (s *SingleAgentApplicationService) fetchToolDetails(ctx context.Context, agentInfo *entity.SingleAgent, req *playground.GetDraftBotInfoAgwRequest) ([]*pluginEntity.ToolInfo, error) {
return s.appContext.PluginDomainSVC.MGetAgentTools(ctx, &service.MGetAgentToolsRequest{
SpaceID: agentInfo.SpaceID,
AgentID: req.GetBotID(),
IsDraft: true,
VersionAgentTools: slices.Transform(agentInfo.Plugin, func(a *bot_common.PluginInfo) pluginEntity.VersionAgentTool {
return pluginEntity.VersionAgentTool{
ToolID: a.GetApiId(),
}
}),
})
}
func (s *SingleAgentApplicationService) fetchPluginDetails(ctx context.Context, agentInfo *entity.SingleAgent, toolInfos []*pluginEntity.ToolInfo) ([]*pluginEntity.PluginInfo, error) {
vPlugins := make([]pluginEntity.VersionPlugin, 0, len(agentInfo.Plugin))
vPluginMap := make(map[string]bool, len(agentInfo.Plugin))
for _, v := range toolInfos {
k := fmt.Sprintf("%d:%s", v.PluginID, v.GetVersion())
if vPluginMap[k] {
continue
}
vPluginMap[k] = true
vPlugins = append(vPlugins, pluginEntity.VersionPlugin{
PluginID: v.PluginID,
Version: v.GetVersion(),
})
}
return s.appContext.PluginDomainSVC.MGetVersionPlugins(ctx, vPlugins)
}
func (s *SingleAgentApplicationService) fetchWorkflowDetails(ctx context.Context, agentInfo *entity.SingleAgent) ([]*workflowEntity.Workflow, error) {
if len(agentInfo.Workflow) == 0 {
return nil, nil
}
policy := &vo.MGetPolicy{
MetaQuery: vo.MetaQuery{
IDs: slices.Transform(agentInfo.Workflow, func(a *bot_common.WorkflowInfo) int64 {
return a.GetWorkflowId()
}),
},
QType: vo.FromLatestVersion,
}
ret, _, err := s.appContext.WorkflowDomainSVC.MGet(ctx, policy)
if err != nil {
return nil, fmt.Errorf("fetch workflow details failed: %v", err)
}
return ret, nil
}
func modelInfoDo2Vo(modelInfos []*modelEntity.Model) map[int64]*playground.ModelDetail {
return slices.ToMap(modelInfos, func(e *modelEntity.Model) (int64, *playground.ModelDetail) {
return e.ID, toModelDetail(e)
})
}
func toModelDetail(m *modelEntity.Model) *playground.ModelDetail {
mm := m.Meta
return &playground.ModelDetail{
Name: ptr.Of(m.Name),
ModelName: ptr.Of(m.Meta.Name),
ModelID: ptr.Of(m.ID),
ModelFamily: ptr.Of(int64(mm.Protocol.TOModelClass())),
ModelIconURL: ptr.Of(mm.IconURL),
}
}
func knowledgeInfoDo2Vo(klInfos []*knowledgeModel.Knowledge) map[string]*playground.KnowledgeDetail {
return slices.ToMap(klInfos, func(e *knowledgeModel.Knowledge) (string, *playground.KnowledgeDetail) {
return fmt.Sprintf("%v", e.ID), &playground.KnowledgeDetail{
ID: ptr.Of(fmt.Sprintf("%d", e.ID)),
Name: ptr.Of(e.Name),
IconURL: ptr.Of(e.IconURL),
FormatType: func() playground.DataSetType {
switch e.Type {
case knowledgeModel.DocumentTypeText:
return playground.DataSetType_Text
case knowledgeModel.DocumentTypeTable:
return playground.DataSetType_Table
case knowledgeModel.DocumentTypeImage:
return playground.DataSetType_Image
}
return playground.DataSetType_Text
}(),
}
})
}
func toolInfoDo2Vo(toolInfos []*pluginEntity.ToolInfo) map[int64]*playground.PluginAPIDetal {
return slices.ToMap(toolInfos, func(e *pluginEntity.ToolInfo) (int64, *playground.PluginAPIDetal) {
return e.ID, &playground.PluginAPIDetal{
ID: ptr.Of(e.ID),
Name: ptr.Of(e.GetName()),
Description: ptr.Of(e.GetDesc()),
PluginID: ptr.Of(e.PluginID),
Parameters: parametersDo2Vo(e.Operation),
}
})
}
func (s *SingleAgentApplicationService) pluginInfoDo2Vo(ctx context.Context, pluginInfos []*pluginEntity.PluginInfo) map[int64]*playground.PluginDetal {
return slices.ToMap(pluginInfos, func(v *pluginEntity.PluginInfo) (int64, *playground.PluginDetal) {
e := v.PluginInfo
var iconURL string
if e.GetIconURI() != "" {
var err error
iconURL, err = s.appContext.TosClient.GetObjectUrl(ctx, e.GetIconURI())
if err != nil {
logs.CtxErrorf(ctx, "get icon url failed, err = %v", err)
}
}
return e.ID, &playground.PluginDetal{
ID: ptr.Of(e.ID),
Name: ptr.Of(e.GetName()),
Description: ptr.Of(e.GetDesc()),
PluginType: (*int64)(&e.PluginType),
IconURL: &iconURL,
PluginStatus: (*int64)(ptr.Of(plugin_develop_common.PluginStatus_PUBLISHED)),
IsOfficial: func() *bool {
if e.SpaceID == 0 {
return ptr.Of(true)
}
return ptr.Of(false)
}(),
}
})
}
func parametersDo2Vo(op *plugin.Openapi3Operation) []*playground.PluginParameter {
var convertReqBody func(paramName string, isRequired bool, sc *openapi3.Schema) *playground.PluginParameter
convertReqBody = func(paramName string, isRequired bool, sc *openapi3.Schema) *playground.PluginParameter {
if disabledParam(sc) {
return nil
}
var assistType *int64
if v, ok := sc.Extensions[plugin.APISchemaExtendAssistType]; ok {
if _v, ok := v.(string); ok {
assistType = toParameterAssistType(_v)
}
}
paramInfo := &playground.PluginParameter{
Name: ptr.Of(paramName),
Type: ptr.Of(sc.Type),
Description: ptr.Of(sc.Description),
IsRequired: ptr.Of(isRequired),
AssistType: assistType,
}
switch sc.Type {
case openapi3.TypeObject:
required := slices.ToMap(sc.Required, func(e string) (string, bool) {
return e, true
})
subParams := make([]*playground.PluginParameter, 0, len(sc.Properties))
for subParamName, prop := range sc.Properties {
subParamInfo := convertReqBody(subParamName, required[subParamName], prop.Value)
if subParamInfo != nil {
subParams = append(subParams, subParamInfo)
}
}
paramInfo.SubParameters = subParams
return paramInfo
case openapi3.TypeArray:
paramInfo.SubType = ptr.Of(sc.Items.Value.Type)
if sc.Items.Value.Type != openapi3.TypeObject {
return paramInfo
}
required := slices.ToMap(sc.Required, func(e string) (string, bool) {
return e, true
})
subParams := make([]*playground.PluginParameter, 0, len(sc.Items.Value.Properties))
for subParamName, prop := range sc.Items.Value.Properties {
subParamInfo := convertReqBody(subParamName, required[subParamName], prop.Value)
if subParamInfo != nil {
subParams = append(subParams, subParamInfo)
}
}
paramInfo.SubParameters = subParams
return paramInfo
default:
return paramInfo
}
}
var params []*playground.PluginParameter
for _, prop := range op.Parameters {
paramVal := prop.Value
schemaVal := paramVal.Schema.Value
if schemaVal.Type == openapi3.TypeObject || schemaVal.Type == openapi3.TypeArray {
continue
}
if disabledParam(prop.Value.Schema.Value) {
continue
}
var assistType *int64
if v, ok := schemaVal.Extensions[plugin.APISchemaExtendAssistType]; ok {
if _v, ok := v.(string); ok {
assistType = toParameterAssistType(_v)
}
}
params = append(params, &playground.PluginParameter{
Name: ptr.Of(paramVal.Name),
Description: ptr.Of(paramVal.Description),
IsRequired: ptr.Of(paramVal.Required),
Type: ptr.Of(schemaVal.Type),
AssistType: assistType,
})
}
if op.RequestBody == nil || op.RequestBody.Value == nil || len(op.RequestBody.Value.Content) == 0 {
return params
}
for _, mType := range op.RequestBody.Value.Content {
schemaVal := mType.Schema.Value
if len(schemaVal.Properties) == 0 {
continue
}
required := slices.ToMap(schemaVal.Required, func(e string) (string, bool) {
return e, true
})
for paramName, prop := range schemaVal.Properties {
paramInfo := convertReqBody(paramName, required[paramName], prop.Value)
if paramInfo != nil {
params = append(params, paramInfo)
}
}
break // 只取一种 MIME
}
return params
}
func toParameterAssistType(assistType string) *int64 {
if assistType == "" {
return nil
}
switch plugin.APIFileAssistType(assistType) {
case plugin.AssistTypeFile:
return ptr.Of(int64(plugin_develop_common.AssistParameterType_CODE))
case plugin.AssistTypeImage:
return ptr.Of(int64(plugin_develop_common.AssistParameterType_IMAGE))
case plugin.AssistTypeDoc:
return ptr.Of(int64(plugin_develop_common.AssistParameterType_DOC))
case plugin.AssistTypePPT:
return ptr.Of(int64(plugin_develop_common.AssistParameterType_PPT))
case plugin.AssistTypeCode:
return ptr.Of(int64(plugin_develop_common.AssistParameterType_CODE))
case plugin.AssistTypeExcel:
return ptr.Of(int64(plugin_develop_common.AssistParameterType_EXCEL))
case plugin.AssistTypeZIP:
return ptr.Of(int64(plugin_develop_common.AssistParameterType_ZIP))
case plugin.AssistTypeVideo:
return ptr.Of(int64(plugin_develop_common.AssistParameterType_VIDEO))
case plugin.AssistTypeAudio:
return ptr.Of(int64(plugin_develop_common.AssistParameterType_AUDIO))
case plugin.AssistTypeTXT:
return ptr.Of(int64(plugin_develop_common.AssistParameterType_TXT))
default:
return nil
}
}
func workflowDo2Vo(wfInfos []*workflowEntity.Workflow) (map[int64]*playground.WorkflowDetail, error) {
result := make(map[int64]*playground.WorkflowDetail, len(wfInfos))
for _, e := range wfInfos {
parameters, err := slices.TransformWithErrorCheck(e.InputParams, toPluginParameter)
if err != nil {
return nil, err
}
result[e.ID] = &playground.WorkflowDetail{
ID: ptr.Of(e.ID),
Name: ptr.Of(e.Name),
Description: ptr.Of(e.Desc),
IconURL: ptr.Of(e.IconURL),
PluginID: ptr.Of(e.ID),
APIDetail: &playground.PluginAPIDetal{
ID: ptr.Of(e.ID),
Name: ptr.Of(e.Name),
Description: ptr.Of(e.Desc),
PluginID: ptr.Of(e.ID),
Parameters: parameters,
},
}
}
return result, nil
}
func toPluginParameter(info *vo.NamedTypeInfo) (*playground.PluginParameter, error) {
if info == nil {
return nil, fmt.Errorf("named type info is nil")
}
p := &playground.PluginParameter{
Name: ptr.Of(info.Name),
Description: ptr.Of(info.Desc),
IsRequired: ptr.Of(info.Required),
}
switch info.Type {
case vo.DataTypeString, vo.DataTypeFile, vo.DataTypeTime:
p.Type = ptr.Of("string")
if info.Type == vo.DataTypeFile {
p.AssistType = toWorkflowParameterAssistType(string(*info.FileType))
}
case vo.DataTypeInteger:
p.Type = ptr.Of("integer")
case vo.DataTypeNumber:
p.Type = ptr.Of("number")
case vo.DataTypeBoolean:
p.Type = ptr.Of("boolean")
case vo.DataTypeObject:
p.Type = ptr.Of("object")
p.SubParameters = make([]*playground.PluginParameter, 0, len(info.Properties))
for _, sub := range info.Properties {
subParameter, err := toPluginParameter(sub)
if err != nil {
return nil, err
}
p.SubParameters = append(p.SubParameters, subParameter)
}
case vo.DataTypeArray:
p.Type = ptr.Of("array")
eleParameter, err := toPluginParameter(info.ElemTypeInfo)
if err != nil {
return nil, err
}
p.SubType = eleParameter.Type
p.SubParameters = []*playground.PluginParameter{eleParameter}
default:
return nil, fmt.Errorf("unknown named type info type: %s", info.Type)
}
return p, nil
}
func toWorkflowParameterAssistType(assistType string) *int64 {
if assistType == "" {
return nil
}
switch vo.FileSubType(assistType) {
case vo.FileTypeDefault:
return ptr.Of(int64(workflow.AssistParameterType_DEFAULT))
case vo.FileTypeImage:
return ptr.Of(int64(workflow.AssistParameterType_IMAGE))
case vo.FileTypeDocument:
return ptr.Of(int64(workflow.AssistParameterType_DOC))
case vo.FileTypePPT:
return ptr.Of(int64(workflow.AssistParameterType_PPT))
case vo.FileTypeCode:
return ptr.Of(int64(workflow.AssistParameterType_CODE))
case vo.FileTypeExcel:
return ptr.Of(int64(workflow.AssistParameterType_EXCEL))
case vo.FileTypeZip:
return ptr.Of(int64(workflow.AssistParameterType_ZIP))
case vo.FileTypeVideo:
return ptr.Of(int64(workflow.AssistParameterType_VIDEO))
case vo.FileTypeAudio:
return ptr.Of(int64(workflow.AssistParameterType_AUDIO))
case vo.FileTypeTxt:
return ptr.Of(int64(workflow.AssistParameterType_TXT))
default:
return nil
}
}

View File

@@ -0,0 +1,94 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package singleagent
import (
"context"
"strings"
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_common"
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/playground"
)
func (s *SingleAgentApplicationService) GetUploadAuthToken(ctx context.Context, req *developer_api.GetUploadAuthTokenRequest) (*developer_api.GetUploadAuthTokenResponse, error) {
authToken, err := s.getAuthToken(ctx)
if err != nil {
return nil, err
}
prefix := s.getUploadPrefix(req.Scene, req.DataType)
return &developer_api.GetUploadAuthTokenResponse{
Data: &developer_api.GetUploadAuthTokenData{
ServiceID: authToken.ServiceID,
UploadPathPrefix: prefix,
UploadHost: authToken.UploadHost,
Auth: &developer_api.UploadAuthTokenInfo{
AccessKeyID: authToken.AccessKeyID,
SecretAccessKey: authToken.SecretAccessKey,
SessionToken: authToken.SessionToken,
ExpiredTime: authToken.ExpiredTime,
CurrentTime: authToken.CurrentTime,
},
Schema: authToken.HostScheme,
},
}, nil
}
func (s *SingleAgentApplicationService) getAuthToken(ctx context.Context) (*bot_common.AuthToken, error) {
uploadAuthToken, err := s.appContext.ImageX.GetUploadAuth(ctx)
if err != nil {
return nil, err
}
authToken := &bot_common.AuthToken{
ServiceID: s.appContext.ImageX.GetServerID(),
AccessKeyID: uploadAuthToken.AccessKeyID,
SecretAccessKey: uploadAuthToken.SecretAccessKey,
SessionToken: uploadAuthToken.SessionToken,
ExpiredTime: uploadAuthToken.ExpiredTime,
CurrentTime: uploadAuthToken.CurrentTime,
UploadHost: s.appContext.ImageX.GetUploadHost(ctx),
HostScheme: uploadAuthToken.HostScheme,
}
return authToken, nil
}
func (s *SingleAgentApplicationService) getUploadPrefix(scene, dataType string) string {
return strings.Replace(scene, "_", "-", -1) + "-" + dataType
}
func (s *SingleAgentApplicationService) GetImagexShortUrl(ctx context.Context, req *playground.GetImagexShortUrlRequest) (*playground.GetImagexShortUrlResponse, error) {
urlInfo := make(map[string]*playground.UrlInfo, len(req.Uris))
for _, uri := range req.Uris {
resURL, err := s.appContext.ImageX.GetResourceURL(ctx, uri)
if err != nil {
return nil, err
}
urlInfo[uri] = &playground.UrlInfo{
URL: resURL.URL,
ReviewStatus: true,
}
}
return &playground.GetImagexShortUrlResponse{
Data: &playground.GetImagexShortUrlData{
URLInfo: urlInfo,
},
}, nil
}

View File

@@ -0,0 +1,85 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package singleagent
import (
"github.com/cloudwego/eino/compose"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/repository"
singleagent "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/service"
connector "github.com/coze-dev/coze-studio/backend/domain/connector/service"
knowledge "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
database "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
variables "github.com/coze-dev/coze-studio/backend/domain/memory/variables/service"
"github.com/coze-dev/coze-studio/backend/domain/modelmgr"
"github.com/coze-dev/coze-studio/backend/domain/plugin/service"
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
shortcutCmd "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/service"
user "github.com/coze-dev/coze-studio/backend/domain/user/service"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/infra/impl/chatmodel"
"github.com/coze-dev/coze-studio/backend/pkg/jsoncache"
)
type (
SingleAgent = singleagent.SingleAgent
)
var SingleAgentSVC *SingleAgentApplicationService
type ServiceComponents struct {
IDGen idgen.IDGenerator
DB *gorm.DB
Cache *redis.Client
TosClient storage.Storage
ImageX imagex.ImageX
EventBus search.ProjectEventBus
CounterRepo repository.CounterRepository
KnowledgeDomainSVC knowledge.Knowledge
ModelMgrDomainSVC modelmgr.Manager
PluginDomainSVC service.PluginService
WorkflowDomainSVC workflow.Service
UserDomainSVC user.User
VariablesDomainSVC variables.Variables
ConnectorDomainSVC connector.Connector
DatabaseDomainSVC database.Database
ShortcutCMDDomainSVC shortcutCmd.ShortcutCmd
CPStore compose.CheckPointStore
}
func InitService(c *ServiceComponents) (*SingleAgentApplicationService, error) {
domainComponents := &singleagent.Components{
AgentDraftRepo: repository.NewSingleAgentRepo(c.DB, c.IDGen, c.Cache),
AgentVersionRepo: repository.NewSingleAgentVersionRepo(c.DB, c.IDGen),
PublishInfoRepo: jsoncache.New[entity.PublishInfo]("agent:publish:last:", c.Cache),
CounterRepo: repository.NewCounterRepo(c.Cache),
CPStore: c.CPStore,
ModelFactory: chatmodel.NewDefaultFactory(),
}
singleAgentDomainSVC := singleagent.NewService(domainComponents)
SingleAgentSVC = newApplicationService(c, singleAgentDomainSVC)
return SingleAgentSVC, nil
}

View File

@@ -0,0 +1,260 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package singleagent
import (
"context"
"fmt"
"strconv"
"sync"
"time"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/api/model/intelligence/common"
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/playground"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
search "github.com/coze-dev/coze-studio/backend/domain/search/entity"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/taskgroup"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
func (s *SingleAgentApplicationService) PublishAgent(ctx context.Context, req *developer_api.PublishDraftBotRequest) (*developer_api.PublishDraftBotResponse, error) {
draftAgent, err := s.ValidateAgentDraftAccess(ctx, req.BotID)
if err != nil {
return nil, err
}
version, err := s.getPublishAgentVersion(ctx, req)
if err != nil {
return nil, err
}
connectorIDs := make([]int64, 0, len(req.Connectors))
for v := range req.Connectors {
var id int64
id, err = strconv.ParseInt(v, 10, 64)
if err != nil {
return nil, err
}
if !entity.PublishConnectorIDWhiteList[id] {
return nil, errorx.New(errno.ErrAgentPermissionCode, errorx.KV("msg", fmt.Sprintf("connector %d not allowed", id)))
}
connectorIDs = append(connectorIDs, id)
}
p := &entity.SingleAgentPublish{
ConnectorIds: connectorIDs,
Version: version,
PublishID: req.GetPublishID(),
PublishInfo: req.HistoryInfo,
}
publishFns := []publishFn{
publishAgentVariables,
publishAgentPlugins,
publishShortcutCommand,
publishDatabase,
}
for _, pubFn := range publishFns {
draftAgent, err = pubFn(ctx, s.appContext, p, draftAgent)
if err != nil {
return nil, err
}
}
err = s.DomainSVC.SavePublishRecord(ctx, p, draftAgent)
if err != nil {
return nil, err
}
tasks := taskgroup.NewUninterruptibleTaskGroup(ctx, len(connectorIDs))
publishResult := make(map[string]*developer_api.ConnectorBindResult, len(connectorIDs))
lock := sync.Mutex{}
for _, connectorID := range connectorIDs {
tasks.Go(func() error {
_, err = s.DomainSVC.CreateSingleAgent(ctx, connectorID, version, draftAgent)
if err != nil {
logs.CtxWarnf(ctx, "create single agent failed: %v, agentID: %d, connectorID: %d , version : %s", err, draftAgent.AgentID, connectorID, version)
lock.Lock()
publishResult[conv.Int64ToStr(connectorID)] = &developer_api.ConnectorBindResult{
PublishResultStatus: ptr.Of(developer_api.PublishResultStatus_Failed),
}
lock.Unlock()
return err
}
// do other connector publish logic if need
lock.Lock()
publishResult[conv.Int64ToStr(connectorID)] = &developer_api.ConnectorBindResult{
PublishResultStatus: ptr.Of(developer_api.PublishResultStatus_Success),
}
lock.Unlock()
return nil
})
}
_ = tasks.Wait()
err = s.appContext.EventBus.PublishProject(ctx, &search.ProjectDomainEvent{
OpType: search.Updated,
Project: &search.ProjectDocument{
ID: draftAgent.AgentID,
HasPublished: ptr.Of(1),
PublishTimeMS: ptr.Of(time.Now().UnixMilli()),
Type: common.IntelligenceType_Bot,
},
})
if err != nil {
logs.CtxWarnf(ctx, "publish project event failed, agentID: %d, err : %v", draftAgent.AgentID, err)
}
return &developer_api.PublishDraftBotResponse{
Data: &developer_api.PublishDraftBotData{
CheckNotPass: false,
PublishResult: publishResult,
},
}, nil
}
func (s *SingleAgentApplicationService) getPublishAgentVersion(ctx context.Context, req *developer_api.PublishDraftBotRequest) (string, error) {
version := req.GetCommitVersion()
if version != "" {
return version, nil
}
v, err := s.appContext.IDGen.GenID(ctx)
if err != nil {
return "", err
}
version = fmt.Sprintf("%v", v)
return version, nil
}
func (s *SingleAgentApplicationService) GetAgentPopupInfo(ctx context.Context, req *playground.GetBotPopupInfoRequest) (*playground.GetBotPopupInfoResponse, error) {
uid := ctxutil.MustGetUIDFromCtx(ctx)
agentPopupCountInfo := make(map[playground.BotPopupType]int64, len(req.BotPopupTypes))
for _, agentPopupType := range req.BotPopupTypes {
count, err := s.DomainSVC.GetAgentPopupCount(ctx, uid, req.GetBotID(), agentPopupType)
if err != nil {
return nil, err
}
agentPopupCountInfo[agentPopupType] = count
}
return &playground.GetBotPopupInfoResponse{
Data: &playground.BotPopupInfoData{
BotPopupCountInfo: agentPopupCountInfo,
},
}, nil
}
func (s *SingleAgentApplicationService) UpdateAgentPopupInfo(ctx context.Context, req *playground.UpdateBotPopupInfoRequest) (*playground.UpdateBotPopupInfoResponse, error) {
uid := ctxutil.MustGetUIDFromCtx(ctx)
err := s.DomainSVC.IncrAgentPopupCount(ctx, uid, req.GetBotID(), req.GetBotPopupType())
if err != nil {
return nil, err
}
return &playground.UpdateBotPopupInfoResponse{
Code: 0,
Msg: "success",
}, nil
}
func (s *SingleAgentApplicationService) GetPublishConnectorList(ctx context.Context, req *developer_api.PublishConnectorListRequest) (*developer_api.PublishConnectorListResponse, error) {
data, err := s.DomainSVC.GetPublishConnectorList(ctx, req.BotID)
if err != nil {
return nil, err
}
return &developer_api.PublishConnectorListResponse{
PublishConnectorList: data.PublishConnectorList,
Code: 0,
Msg: "success",
}, nil
}
type publishFn func(ctx context.Context, appContext *ServiceComponents, publishInfo *entity.SingleAgentPublish, agent *entity.SingleAgent) (*entity.SingleAgent, error)
func publishAgentVariables(ctx context.Context, appContext *ServiceComponents, publishInfo *entity.SingleAgentPublish, agent *entity.SingleAgent) (*entity.SingleAgent, error) {
draftAgent := agent
if draftAgent.VariablesMetaID != nil || *draftAgent.VariablesMetaID == 0 {
return draftAgent, nil
}
var newVariableMetaID int64
newVariableMetaID, err := appContext.VariablesDomainSVC.PublishMeta(ctx, *draftAgent.VariablesMetaID, publishInfo.Version)
if err != nil {
return nil, err
}
draftAgent.VariablesMetaID = ptr.Of(newVariableMetaID)
return draftAgent, nil
}
func publishAgentPlugins(ctx context.Context, appContext *ServiceComponents, publishInfo *entity.SingleAgentPublish, agent *entity.SingleAgent) (*entity.SingleAgent, error) {
err := appContext.PluginDomainSVC.PublishAgentTools(ctx, agent.AgentID, publishInfo.Version)
if err != nil {
return nil, err
}
return agent, nil
}
func publishShortcutCommand(ctx context.Context, appContext *ServiceComponents, publishInfo *entity.SingleAgentPublish, agent *entity.SingleAgent) (*entity.SingleAgent, error) {
logs.CtxInfof(ctx, "publishShortcutCommand agentID: %d, shortcutCommand: %v", agent.AgentID, agent.ShortcutCommand)
if agent.ShortcutCommand == nil || len(agent.ShortcutCommand) == 0 {
return agent, nil
}
cmdIDs := slices.Transform(agent.ShortcutCommand, func(a string) int64 {
return conv.StrToInt64D(a, 0)
})
err := appContext.ShortcutCMDDomainSVC.PublishCMDs(ctx, agent.AgentID, cmdIDs)
if err != nil {
return nil, err
}
return agent, nil
}
func publishDatabase(ctx context.Context, appContext *ServiceComponents, publishInfo *entity.SingleAgentPublish, agent *entity.SingleAgent) (*entity.SingleAgent, error) {
onlineResp, err := appContext.DatabaseDomainSVC.PublishDatabase(ctx, &database.PublishDatabaseRequest{AgentID: agent.AgentID})
if err != nil {
return nil, err
}
agent.Database = onlineResp.OnlineDatabases
return agent, nil
}

View File

@@ -0,0 +1,733 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package singleagent
import (
"context"
"fmt"
"strconv"
"time"
shortcutCmd "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/service"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/types/consts"
"github.com/bytedance/sonic"
"github.com/getkin/kin-openapi/openapi3"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
intelligence "github.com/coze-dev/coze-studio/backend/api/model/intelligence/common"
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_common"
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/playground"
"github.com/coze-dev/coze-studio/backend/api/model/table"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatabase"
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
singleagent "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/service"
variableEntity "github.com/coze-dev/coze-studio/backend/domain/memory/variables/entity"
shortcutEntity "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity"
searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type SingleAgentApplicationService struct {
appContext *ServiceComponents
DomainSVC singleagent.SingleAgent
ShortcutCMDSVC shortcutCmd.ShortcutCmd
}
func newApplicationService(s *ServiceComponents, domain singleagent.SingleAgent) *SingleAgentApplicationService {
return &SingleAgentApplicationService{
appContext: s,
DomainSVC: domain,
ShortcutCMDSVC: s.ShortcutCMDDomainSVC,
}
}
const onboardingInfoMaxLength = 65535
func (s *SingleAgentApplicationService) generateOnboardingStr(onboardingInfo *bot_common.OnboardingInfo) (string, error) {
onboarding := playground.OnboardingContent{}
if onboardingInfo != nil {
onboarding.Prologue = ptr.Of(onboardingInfo.GetPrologue())
onboarding.SuggestedQuestions = onboardingInfo.GetSuggestedQuestions()
onboarding.SuggestedQuestionsShowMode = onboardingInfo.SuggestedQuestionsShowMode
}
onboardingInfoStr, err := sonic.MarshalString(onboarding)
if err != nil {
return "", err
}
return onboardingInfoStr, nil
}
func (s *SingleAgentApplicationService) UpdateSingleAgentDraft(ctx context.Context, req *playground.UpdateDraftBotInfoAgwRequest) (*playground.UpdateDraftBotInfoAgwResponse, error) {
if req.BotInfo.OnboardingInfo != nil {
infoStr, err := s.generateOnboardingStr(req.BotInfo.OnboardingInfo)
if err != nil {
return nil, errorx.New(errno.ErrAgentPermissionCode, errorx.KV("msg", "onboarding_info invalidate"))
}
if len(infoStr) > onboardingInfoMaxLength {
return nil, errorx.New(errno.ErrAgentPermissionCode, errorx.KV("msg", "onboarding_info is too long"))
}
}
agentID := req.BotInfo.GetBotId()
currentAgentInfo, err := s.ValidateAgentDraftAccess(ctx, agentID)
if err != nil {
return nil, err
}
userID := ctxutil.MustGetUIDFromCtx(ctx)
updateAgentInfo, err := s.applyAgentUpdates(currentAgentInfo, req.BotInfo)
if err != nil {
return nil, err
}
if req.BotInfo.VariableList != nil {
var (
varsMetaID int64
vars = variableEntity.NewVariablesWithAgentVariables(req.BotInfo.VariableList)
)
varsMetaID, err = s.appContext.VariablesDomainSVC.UpsertBotMeta(ctx, agentID, "", userID, vars)
if err != nil {
return nil, err
}
updateAgentInfo.VariablesMetaID = &varsMetaID
}
err = s.DomainSVC.UpdateSingleAgentDraft(ctx, updateAgentInfo)
if err != nil {
return nil, err
}
err = s.appContext.EventBus.PublishProject(ctx, &searchEntity.ProjectDomainEvent{
OpType: searchEntity.Updated,
Project: &searchEntity.ProjectDocument{
ID: agentID,
Name: &updateAgentInfo.Name,
Type: intelligence.IntelligenceType_Bot,
},
})
if err != nil {
return nil, err
}
return &playground.UpdateDraftBotInfoAgwResponse{
Data: &playground.UpdateDraftBotInfoAgwData{
HasChange: ptr.Of(true),
CheckNotPass: false,
Branch: playground.BranchPtr(playground.Branch_PersonalDraft),
},
}, nil
}
func (s *SingleAgentApplicationService) UpdatePromptDisable(ctx context.Context, req *table.UpdateDatabaseBotSwitchRequest) (*table.UpdateDatabaseBotSwitchResponse, error) {
agentID := req.GetBotID()
draft, err := s.ValidateAgentDraftAccess(ctx, agentID)
if err != nil {
return nil, err
}
if len(draft.Database) == 0 {
return nil, fmt.Errorf("agent %d has no database", agentID) // TODO@fanlv: 错误码
}
dbInfos := draft.Database
var found bool
for _, db := range dbInfos {
if db.GetTableId() == conv.Int64ToStr(req.GetDatabaseID()) {
db.PromptDisabled = ptr.Of(req.GetPromptDisable())
found = true
break
}
}
if !found {
return nil, fmt.Errorf("database %d not found in agent %d", req.GetDatabaseID(), agentID) // TODO@fanlv: 错误码
}
draft.Database = dbInfos
err = s.DomainSVC.UpdateSingleAgentDraft(ctx, draft)
if err != nil {
return nil, err
}
return &table.UpdateDatabaseBotSwitchResponse{
Code: 0,
Msg: "success",
}, nil
}
func (s *SingleAgentApplicationService) UnBindDatabase(ctx context.Context, req *table.BindDatabaseToBotRequest) (*table.BindDatabaseToBotResponse, error) {
agentID := req.GetBotID()
draft, err := s.ValidateAgentDraftAccess(ctx, agentID)
if err != nil {
return nil, err
}
if len(draft.Database) == 0 {
return nil, fmt.Errorf("agent %d has no database", agentID)
}
dbInfos := draft.Database
var found bool
newDBInfos := make([]*bot_common.Database, 0)
for _, db := range dbInfos {
if db.GetTableId() == conv.Int64ToStr(req.GetDatabaseID()) {
found = true
continue
}
newDBInfos = append(newDBInfos, db)
}
if !found {
return nil, fmt.Errorf("database %d not found in agent %d", req.GetDatabaseID(), agentID)
}
draft.Database = newDBInfos
err = s.DomainSVC.UpdateSingleAgentDraft(ctx, draft)
if err != nil {
return nil, err
}
err = crossdatabase.DefaultSVC().UnBindDatabase(ctx, &database.UnBindDatabaseToAgentRequest{
AgentID: agentID,
DraftDatabaseID: req.GetDatabaseID(),
})
if err != nil {
return nil, err
}
return &table.BindDatabaseToBotResponse{
Code: 0,
Msg: "success",
}, nil
}
func (s *SingleAgentApplicationService) BindDatabase(ctx context.Context, req *table.BindDatabaseToBotRequest) (*table.BindDatabaseToBotResponse, error) {
agentID := req.GetBotID()
draft, err := s.ValidateAgentDraftAccess(ctx, agentID)
if err != nil {
return nil, err
}
dbMap := slices.ToMap(draft.Database, func(d *bot_common.Database) (string, *bot_common.Database) {
return d.GetTableId(), d
})
if _, ok := dbMap[conv.Int64ToStr(req.GetDatabaseID())]; ok {
return nil, errorx.New(errno.ErrAgentPermissionCode, errorx.KVf("msg", "database %d already bound to agent %d", req.GetDatabaseID(), agentID))
}
basics := []*database.DatabaseBasic{
{
ID: req.DatabaseID,
TableType: table.TableType_DraftTable,
},
}
draftRes, err := crossdatabase.DefaultSVC().MGetDatabase(ctx, &database.MGetDatabaseRequest{
Basics: basics,
})
if err != nil {
return nil, err
}
if len(draftRes.Databases) == 0 {
return nil, fmt.Errorf("database %d not found", req.DatabaseID)
}
draftDatabase := draftRes.Databases[0]
fields := make([]*bot_common.FieldItem, 0, len(draftDatabase.FieldList))
for _, field := range draftDatabase.FieldList {
fields = append(fields, &bot_common.FieldItem{
Name: ptr.Of(field.Name),
Desc: ptr.Of(field.Desc),
Type: ptr.Of(bot_common.FieldItemType(field.Type)),
MustRequired: ptr.Of(field.MustRequired),
AlterId: ptr.Of(field.AlterID),
Id: ptr.Of(int64(0)),
})
}
bindDB := &bot_common.Database{
TableId: ptr.Of(strconv.FormatInt(draftDatabase.ID, 10)),
TableName: ptr.Of(draftDatabase.TableName),
TableDesc: ptr.Of(draftDatabase.TableDesc),
FieldList: fields,
RWMode: ptr.Of(bot_common.BotTableRWMode(draftDatabase.RwMode)),
}
if len(draft.Database) == 0 {
draft.Database = make([]*bot_common.Database, 0, 1)
}
draft.Database = append(draft.Database, bindDB)
err = s.DomainSVC.UpdateSingleAgentDraft(ctx, draft)
if err != nil {
return nil, err
}
err = crossdatabase.DefaultSVC().BindDatabase(ctx, &database.BindDatabaseToAgentRequest{
AgentID: agentID,
DraftDatabaseID: req.GetDatabaseID(),
})
if err != nil {
return nil, err
}
return &table.BindDatabaseToBotResponse{
Code: 0,
Msg: "success",
}, nil
}
func (s *SingleAgentApplicationService) applyAgentUpdates(target *entity.SingleAgent, patch *bot_common.BotInfoForUpdate) (*entity.SingleAgent, error) {
if patch.Name != nil {
target.Name = *patch.Name
}
if patch.Description != nil {
target.Desc = *patch.Description
}
if patch.IconUri != nil {
target.IconURI = *patch.IconUri
}
if patch.OnboardingInfo != nil {
target.OnboardingInfo = patch.OnboardingInfo
}
if patch.ModelInfo != nil {
target.ModelInfo = patch.ModelInfo
}
if patch.PromptInfo != nil {
target.Prompt = patch.PromptInfo
}
if patch.WorkflowInfoList != nil {
target.Workflow = patch.WorkflowInfoList
}
if patch.PluginInfoList != nil {
target.Plugin = patch.PluginInfoList
}
if patch.Knowledge != nil {
target.Knowledge = patch.Knowledge
}
if patch.SuggestReplyInfo != nil {
target.SuggestReply = patch.SuggestReplyInfo
}
if patch.BackgroundImageInfoList != nil {
target.BackgroundImageInfoList = patch.BackgroundImageInfoList
}
if patch.Agents != nil && len(patch.Agents) > 0 && patch.Agents[0].JumpConfig != nil {
target.JumpConfig = patch.Agents[0].JumpConfig
}
if patch.ShortcutSort != nil {
target.ShortcutCommand = patch.ShortcutSort
}
if patch.DatabaseList != nil {
for _, db := range patch.DatabaseList {
if db.PromptDisabled == nil {
db.PromptDisabled = ptr.Of(false) // default is false
}
}
target.Database = patch.DatabaseList
}
return target, nil
}
func (s *SingleAgentApplicationService) DeleteAgentDraft(ctx context.Context, req *developer_api.DeleteDraftBotRequest) (*developer_api.DeleteDraftBotResponse, error) {
_, err := s.ValidateAgentDraftAccess(ctx, req.GetBotID())
if err != nil {
return nil, err
}
err = s.DomainSVC.DeleteAgentDraft(ctx, req.GetSpaceID(), req.GetBotID())
if err != nil {
return nil, err
}
err = s.appContext.EventBus.PublishProject(ctx, &searchEntity.ProjectDomainEvent{
OpType: searchEntity.Deleted,
Project: &searchEntity.ProjectDocument{
ID: req.GetBotID(),
Type: intelligence.IntelligenceType_Bot,
},
})
if err != nil {
logs.CtxWarnf(ctx, "publish delete project event failed id = %v , err = %v", req.GetBotID(), err)
}
return &developer_api.DeleteDraftBotResponse{
Data: &developer_api.DeleteDraftBotData{},
Code: 0,
}, nil
}
func (s *SingleAgentApplicationService) singleAgentDraftDo2Vo(ctx context.Context, do *entity.SingleAgent) (*bot_common.BotInfo, error) {
vo := &bot_common.BotInfo{
BotId: do.AgentID,
Name: do.Name,
Description: do.Desc,
IconUri: do.IconURI,
OnboardingInfo: do.OnboardingInfo,
ModelInfo: do.ModelInfo,
PromptInfo: do.Prompt,
PluginInfoList: do.Plugin,
Knowledge: do.Knowledge,
WorkflowInfoList: do.Workflow,
SuggestReplyInfo: do.SuggestReply,
CreatorId: do.CreatorID,
TaskInfo: &bot_common.TaskInfo{},
CreateTime: do.CreatedAt / 1000,
UpdateTime: do.UpdatedAt / 1000,
BotMode: bot_common.BotMode_SingleMode,
BackgroundImageInfoList: do.BackgroundImageInfoList,
Status: bot_common.BotStatus_Using,
DatabaseList: do.Database,
ShortcutSort: do.ShortcutCommand,
}
if do.VariablesMetaID != nil {
vars, err := s.appContext.VariablesDomainSVC.GetVariableMetaByID(ctx, *do.VariablesMetaID)
if err != nil {
return nil, err
}
if vars != nil {
vo.VariableList = vars.ToAgentVariables()
}
}
if vo.IconUri != "" {
url, err := s.appContext.TosClient.GetObjectUrl(ctx, vo.IconUri)
if err != nil {
return nil, err
}
vo.IconUrl = url
}
if vo.ModelInfo == nil || vo.ModelInfo.ModelId == nil {
mi, err := s.defaultModelInfo(ctx)
if err != nil {
return nil, err
}
vo.ModelInfo = mi
}
return vo, nil
}
func disabledParam(schemaVal *openapi3.Schema) bool {
if len(schemaVal.Extensions) == 0 {
return false
}
globalDisable, localDisable := false, false
if v, ok := schemaVal.Extensions[plugin.APISchemaExtendLocalDisable]; ok {
localDisable = v.(bool)
}
if v, ok := schemaVal.Extensions[plugin.APISchemaExtendGlobalDisable]; ok {
globalDisable = v.(bool)
}
return globalDisable || localDisable
}
func (s *SingleAgentApplicationService) UpdateAgentDraftDisplayInfo(ctx context.Context, req *developer_api.UpdateDraftBotDisplayInfoRequest) (*developer_api.UpdateDraftBotDisplayInfoResponse, error) {
uid := ctxutil.GetUIDFromCtx(ctx)
if uid == nil {
return nil, errorx.New(errno.ErrAgentPermissionCode, errorx.KV("msg", "session required"))
}
_, err := s.ValidateAgentDraftAccess(ctx, req.BotID)
if err != nil {
return nil, err
}
draftInfoDo := &entity.AgentDraftDisplayInfo{
AgentID: req.BotID,
DisplayInfo: req.DisplayInfo,
SpaceID: req.SpaceID,
}
err = s.DomainSVC.UpdateAgentDraftDisplayInfo(ctx, *uid, draftInfoDo)
if err != nil {
return nil, err
}
return &developer_api.UpdateDraftBotDisplayInfoResponse{
Code: 0,
Msg: "success",
}, nil
}
func (s *SingleAgentApplicationService) GetAgentDraftDisplayInfo(ctx context.Context, req *developer_api.GetDraftBotDisplayInfoRequest) (*developer_api.GetDraftBotDisplayInfoResponse, error) {
uid := ctxutil.GetUIDFromCtx(ctx)
if uid == nil {
return nil, errorx.New(errno.ErrAgentPermissionCode, errorx.KV("msg", "session required"))
}
_, err := s.ValidateAgentDraftAccess(ctx, req.BotID)
if err != nil {
return nil, err
}
draftInfoDo, err := s.DomainSVC.GetAgentDraftDisplayInfo(ctx, *uid, req.BotID)
if err != nil {
return nil, err
}
return &developer_api.GetDraftBotDisplayInfoResponse{
Code: 0,
Msg: "success",
Data: draftInfoDo.DisplayInfo,
}, nil
}
func (s *SingleAgentApplicationService) ValidateAgentDraftAccess(ctx context.Context, agentID int64) (*entity.SingleAgent, error) {
uid := ctxutil.GetUIDFromCtx(ctx)
if uid == nil {
uid = ptr.Of(int64(888))
// return nil, errorx.New(errno.ErrAgentPermissionCode, errorx.KV("msg", "session uid not found"))
}
do, err := s.DomainSVC.GetSingleAgentDraft(ctx, agentID)
if err != nil {
return nil, err
}
if do == nil {
return nil, errorx.New(errno.ErrAgentPermissionCode, errorx.KVf("msg", "No agent draft(%d) found for the given agent ID", agentID))
}
if do.SpaceID == consts.TemplateSpaceID { // duplicate template, not need check uid permission
return do, nil
}
if do.CreatorID != *uid {
logs.CtxErrorf(ctx, "user(%d) is not the creator(%d) of the agent draft", *uid, do.CreatorID)
return do, errorx.New(errno.ErrAgentPermissionCode, errorx.KV("detail", "you are not the agent owner"))
}
return do, nil
}
func (s *SingleAgentApplicationService) ListAgentPublishHistory(ctx context.Context, req *developer_api.ListDraftBotHistoryRequest) (*developer_api.ListDraftBotHistoryResponse, error) {
resp := &developer_api.ListDraftBotHistoryResponse{}
draftAgent, err := s.ValidateAgentDraftAccess(ctx, req.BotID)
if err != nil {
return nil, err
}
var connectorID *int64
if req.GetConnectorID() != "" {
var id int64
id, err = conv.StrToInt64(req.GetConnectorID())
if err != nil {
return nil, errorx.New(errno.ErrAgentInvalidParamCode, errorx.KV("msg", fmt.Sprintf("ConnectorID %v invalidate", *req.ConnectorID)))
}
connectorID = ptr.Of(id)
}
historyList, err := s.DomainSVC.ListAgentPublishHistory(ctx, draftAgent.AgentID, req.PageIndex, req.PageSize, connectorID)
if err != nil {
return nil, err
}
uid := ctxutil.MustGetUIDFromCtx(ctx)
resp.Data = &developer_api.ListDraftBotHistoryData{}
for _, v := range historyList {
connectorInfos := make([]*developer_api.ConnectorInfo, 0, len(v.ConnectorIds))
infos, err := s.appContext.ConnectorDomainSVC.GetByIDs(ctx, v.ConnectorIds)
if err != nil {
return nil, err
}
for _, info := range infos {
connectorInfos = append(connectorInfos, info.ToVO())
}
creator, err := s.appContext.UserDomainSVC.GetUserProfiles(ctx, v.CreatorID)
if err != nil {
return nil, err
}
info := ""
if v.PublishInfo != nil {
info = *v.PublishInfo
}
historyInfo := &developer_api.HistoryInfo{
HistoryType: developer_api.HistoryType_FLAG,
Version: v.Version,
Info: info,
CreateTime: conv.Int64ToStr(v.CreatedAt / 1000),
ConnectorInfos: connectorInfos,
Creator: &developer_api.Creator{
ID: v.CreatorID,
Name: creator.Name,
AvatarURL: creator.IconURL,
Self: uid == v.CreatorID,
// UserUniqueName: creator.UserUniqueName, // TODO(@fanlv) : user domain 补完以后再改
// UserLabel TODO
},
PublishID: &v.PublishID,
}
resp.Data.HistoryInfos = append(resp.Data.HistoryInfos, historyInfo)
}
return resp, nil
}
func (s *SingleAgentApplicationService) ReportUserBehavior(ctx context.Context, req *playground.ReportUserBehaviorRequest) (resp *playground.ReportUserBehaviorResponse, err error) {
err = s.appContext.EventBus.PublishProject(ctx, &searchEntity.ProjectDomainEvent{
OpType: searchEntity.Updated,
Project: &searchEntity.ProjectDocument{
ID: req.ResourceID,
SpaceID: req.SpaceID,
Type: intelligence.IntelligenceType_Bot,
IsRecentlyOpen: ptr.Of(1),
RecentlyOpenMS: ptr.Of(time.Now().UnixMilli()),
},
})
if err != nil {
logs.CtxWarnf(ctx, "publish updated project event failed id=%v, err=%v", req.ResourceID, err)
}
return &playground.ReportUserBehaviorResponse{}, nil
}
func (s *SingleAgentApplicationService) GetAgentOnlineInfo(ctx context.Context, req *playground.GetBotOnlineInfoReq) (*bot_common.OpenAPIBotInfo, error) {
uid := ctxutil.MustGetUIDFromApiAuthCtx(ctx)
connectorID, err := conv.StrToInt64(ptr.From(req.ConnectorID))
if err != nil {
return nil, err
}
if connectorID == 0 {
connectorID = ctxutil.GetApiAuthFromCtx(ctx).ConnectorID
}
agentInfo, err := s.DomainSVC.ObtainAgentByIdentity(ctx, &entity.AgentIdentity{
AgentID: req.BotID,
ConnectorID: connectorID,
Version: ptr.From(req.Version),
})
if err != nil {
return nil, err
}
if agentInfo == nil {
logs.CtxErrorf(ctx, "agent(%d) is not exist", req.BotID)
return nil, errorx.New(errno.ErrAgentPermissionCode, errorx.KV("msg", "agent not exist"))
}
if agentInfo.CreatorID != uid {
return nil, errorx.New(errno.ErrPromptPermissionCode, errorx.KV("msg", "agent not own"))
}
combineInfo := &bot_common.OpenAPIBotInfo{
BotID: agentInfo.AgentID,
Name: agentInfo.Name,
Description: agentInfo.Desc,
IconURL: agentInfo.IconURI,
Version: agentInfo.Version,
BotMode: bot_common.BotMode_SingleMode,
PromptInfo: agentInfo.Prompt,
OnboardingInfo: agentInfo.OnboardingInfo,
ModelInfo: agentInfo.ModelInfo,
WorkflowInfoList: agentInfo.Workflow,
PluginInfoList: agentInfo.Plugin,
}
if agentInfo.IconURI != "" {
url, err := s.appContext.TosClient.GetObjectUrl(ctx, agentInfo.IconURI)
if err != nil {
return nil, err
}
combineInfo.IconURL = url
}
if len(agentInfo.ShortcutCommand) > 0 {
shortcutInfos, err := s.ShortcutCMDSVC.ListCMD(ctx, &shortcutEntity.ListMeta{
ObjectID: agentInfo.AgentID,
IsOnline: 1,
CommandIDs: slices.Transform(agentInfo.ShortcutCommand, func(s string) int64 {
i, _ := conv.StrToInt64(s)
return i
}),
})
if err != nil {
return nil, err
}
combineInfo.ShortcutCommands = make([]*bot_common.ShortcutCommandInfo, 0, len(shortcutInfos))
combineInfo.ShortcutCommands = slices.Transform(shortcutInfos, func(si *shortcutEntity.ShortcutCmd) *bot_common.ShortcutCommandInfo {
url := ""
if si.ShortcutIcon != nil && si.ShortcutIcon.URI != "" {
getUrl, e := s.appContext.TosClient.GetObjectUrl(ctx, si.ShortcutIcon.URI)
if e == nil {
url = getUrl
}
}
return &bot_common.ShortcutCommandInfo{
ID: si.CommandID,
Name: si.CommandName,
Description: si.Description,
IconURL: url,
QueryTemplate: si.TemplateQuery,
AgentID: ptr.Of(si.ObjectID),
Command: si.ShortcutCommand,
Components: slices.Transform(si.Components, func(i *playground.Components) *bot_common.ShortcutCommandComponent {
return &bot_common.ShortcutCommandComponent{
Name: i.Name,
Description: i.Description,
Type: i.InputType.String(),
ToolParameter: ptr.Of(i.Parameter),
Options: i.Options,
DefaultValue: ptr.Of(i.DefaultValue.Value),
IsHide: i.Hide,
}
}),
}
})
}
return combineInfo, nil
}

View File

@@ -0,0 +1,43 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package template
import (
"context"
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/domain/template/repository"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
)
type ServiceComponents struct {
DB *gorm.DB
IDGen idgen.IDGenerator
Storage storage.Storage
}
func InitService(ctx context.Context, components *ServiceComponents) *ApplicationService {
tRepo := repository.NewTemplateDAO(components.DB, components.IDGen)
ApplicationSVC.templateRepo = tRepo
ApplicationSVC.storage = components.Storage
return ApplicationSVC
}

View File

@@ -0,0 +1,92 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package template
import (
"context"
productAPI "github.com/coze-dev/coze-studio/backend/api/model/flow/marketplace/product_public_api"
"github.com/coze-dev/coze-studio/backend/domain/template/entity"
"github.com/coze-dev/coze-studio/backend/domain/template/repository"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/types/consts"
)
type ApplicationService struct {
templateRepo repository.TemplateRepository
storage storage.Storage
}
var ApplicationSVC = &ApplicationService{}
func (t *ApplicationService) PublicGetProductList(ctx context.Context, req *productAPI.GetProductListRequest) (resp *productAPI.GetProductListResponse, err error) {
pageSize := 50
if req.PageSize > 0 {
pageSize = int(req.PageSize)
}
pagination := &entity.Pagination{
Limit: pageSize,
Offset: int(req.PageNum) * pageSize,
}
listResp, allNum, err := t.templateRepo.List(ctx, &entity.TemplateFilter{SpaceID: ptr.Of(int64(consts.TemplateSpaceID))}, pagination, "")
if err != nil {
return nil, err
}
products := make([]*productAPI.ProductInfo, 0, len(listResp))
for _, item := range listResp {
meta := item.MetaInfo
for _, cover := range meta.Covers {
objURL, uRrr := t.storage.GetObjectUrl(ctx, cover.URI)
if uRrr == nil {
cover.URL = objURL
}
}
avatarURL, uRrr := t.storage.GetObjectUrl(ctx, "default_icon/connector-coze.png")
if uRrr == nil {
if meta.Seller != nil {
meta.Seller.AvatarURL = avatarURL
}
if meta.UserInfo != nil {
meta.UserInfo.AvatarURL = avatarURL
}
}
products = append(products, &productAPI.ProductInfo{
MetaInfo: item.MetaInfo,
BotExtra: item.AgentExtra,
WorkflowExtra: item.WorkflowExtra,
ProjectExtra: item.ProjectExtra,
})
}
hasMore := false
if int64(int(req.PageNum)*pageSize) < allNum {
hasMore = true
}
resp = &productAPI.GetProductListResponse{
Data: &productAPI.GetProductListData{
Products: products,
HasMore: hasMore,
Total: int32(allNum),
},
}
return resp, nil
}

View File

@@ -0,0 +1,25 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package upload
const maxFileSize = 200 * 1024 * 1024
const (
TextKnowledgeDefaultIcon = "default_icon/text_kn_default_icon.png"
TableKnowledgeDefaultIcon = "default_icon/table_kn_default_icon.png"
ImageKnowledgeDefaultIcon = "default_icon/image_kn_default_icon.png"
DatabaseDefaultIcon = "default_icon/default_database_icon.png"
)

View File

@@ -0,0 +1,708 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package upload
import (
"bytes"
"context"
"encoding/xml"
"errors"
"fmt"
"hash/crc32"
"image"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"io"
"math"
"mime"
"mime/multipart"
"path"
"regexp"
"sort"
"strconv"
"strings"
"time"
_ "golang.org/x/image/tiff"
_ "golang.org/x/image/webp"
"github.com/google/uuid"
"github.com/coze-dev/coze-studio/backend/api/model/flow/dataengine/dataset"
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/playground"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
"github.com/coze-dev/coze-studio/backend/domain/upload/entity"
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/consts"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
func InitService(oss storage.Storage, cache cache.Cmdable) {
SVC.cache = cache
SVC.oss = oss
}
var SVC = &UploadService{}
type UploadService struct {
oss storage.Storage
cache cache.Cmdable
}
const (
uploadKey = "UploadServiceUpload:%s"
uploadPartKey = "UploadServiceUpload:%s/parts"
partKey = "UploadServiceUpload/%s/part-%s"
)
func (u *UploadService) PartUploadFileInit(ctx context.Context, objKey string) (uploadID string, err error) {
uploadID = uuid.NewString()
key := fmt.Sprintf(uploadKey, uploadID)
err = u.cache.HSet(ctx,
key,
"objkey", objKey,
).Err()
if err != nil {
return "", err
}
err = u.cache.Expire(ctx, key, time.Minute*10).Err()
return
}
type PartUploadFileRequest struct {
UploadID string
PartNumber string
Data []byte
}
type PartUploadFileResponse struct {
Crc32 string
}
type PartUploadFileCompleteRequest struct {
UploadID string
ObjKey string
Crc32Map map[string]string
}
func (u *UploadService) PartUploadFile(ctx context.Context, req *PartUploadFileRequest) (resp *PartUploadFileResponse, err error) {
key := fmt.Sprintf(uploadKey, req.UploadID)
exists, err := u.cache.Exists(ctx, key).Result()
if err != nil || exists == 0 {
return nil, fmt.Errorf("upload session invalid: %v", err)
}
crc32Val := crc32.ChecksumIEEE(req.Data)
partTosKey := fmt.Sprintf(partKey, req.UploadID, req.PartNumber)
err = u.oss.PutObject(ctx, partTosKey, req.Data, storage.WithExpires(time.Now().Add(10*time.Minute)))
if err != nil {
return nil, err
}
partMeta := map[string]interface{}{
"tos_key": partTosKey,
}
partMetaData, err := sonic.Marshal(partMeta)
if err != nil {
return nil, err
}
partKey := fmt.Sprintf(uploadPartKey, req.UploadID)
err = u.cache.HSet(ctx, partKey, req.PartNumber, string(partMetaData)).Err()
if err != nil {
return nil, err
}
err = u.cache.Expire(ctx, partKey, time.Minute*10).Err()
if err != nil {
return nil, err
}
return &PartUploadFileResponse{
Crc32: fmt.Sprintf("%08x", crc32Val),
}, nil
}
type tosPart struct {
PartNum int
Data []byte
}
func getContentType(uri string) (contentType string) {
_ = mime.AddExtensionType(".svg", "image/svg+xml")
_ = mime.AddExtensionType(".svgz", "image/svg+xml")
_ = mime.AddExtensionType(".webp", "image/webp")
_ = mime.AddExtensionType(".ico", "image/x-icon")
fileExtension := path.Base(uri)
ext := path.Ext(fileExtension)
contentType = mime.TypeByExtension(ext)
return
}
func (u *UploadService) PartUploadFileComplete(ctx context.Context, req *PartUploadFileCompleteRequest) error {
partKey := fmt.Sprintf(uploadPartKey, req.UploadID)
parts, err := u.cache.HGetAll(ctx, partKey).Result()
if err != nil {
return err
}
tosParts := []*tosPart{}
for partNumStr, partData := range parts {
var partMeta map[string]string
if err := sonic.Unmarshal([]byte(partData), &partMeta); err != nil {
return fmt.Errorf("failed to parse part metadata: %v", err)
}
partNum, err := strconv.ParseInt(partNumStr, 10, 64)
if err != nil {
return err
}
objKey, exist := partMeta["tos_key"]
if !exist {
return errors.New("tos key not exist")
}
byteData, err := u.oss.GetObject(ctx, objKey)
if err != nil {
return err
}
tosParts = append(tosParts, &tosPart{PartNum: int(partNum), Data: byteData})
}
if len(tosParts) == 0 {
return errors.New("tos part is null")
}
sort.Slice(tosParts, func(i, j int) bool { return tosParts[i].PartNum < tosParts[j].PartNum })
if tosParts[len(tosParts)-1].PartNum != len(tosParts) || len(tosParts) != len(req.Crc32Map) {
return errors.New("check parts fail")
}
totalData := []byte{}
for _, val := range tosParts {
crc32 := fmt.Sprintf("%08x", crc32.ChecksumIEEE(val.Data))
crc32Check := req.Crc32Map[strconv.Itoa(val.PartNum)]
if crc32 != crc32Check {
return errors.New("crc32 check fail")
}
totalData = append(totalData, val.Data...)
}
contentType := getContentType(req.ObjKey)
if len(contentType) != 0 {
err = u.oss.PutObject(ctx, req.ObjKey, totalData, storage.WithContentType(contentType))
} else {
err = u.oss.PutObject(ctx, req.ObjKey, totalData)
}
return err
}
func (u *UploadService) GetIcon(ctx context.Context, req *developer_api.GetIconRequest) (
resp *developer_api.GetIconResponse, err error,
) {
iconURI := map[developer_api.IconType]string{
developer_api.IconType_Bot: consts.DefaultAgentIcon,
developer_api.IconType_User: consts.DefaultUserIcon,
developer_api.IconType_Plugin: consts.DefaultPluginIcon,
developer_api.IconType_Dataset: consts.DefaultDatasetIcon,
developer_api.IconType_Workflow: consts.DefaultWorkflowIcon,
developer_api.IconType_Imageflow: consts.DefaultPluginIcon,
developer_api.IconType_Society: consts.DefaultPluginIcon,
developer_api.IconType_Connector: consts.DefaultPluginIcon,
developer_api.IconType_ChatFlow: consts.DefaultPluginIcon,
developer_api.IconType_Voice: consts.DefaultPluginIcon,
developer_api.IconType_Enterprise: consts.DefaultTeamIcon,
}
uri := iconURI[req.GetIconType()]
if uri == "" {
return nil, errorx.New(errno.ErrUploadInvalidType,
errorx.KV("type", conv.Int64ToStr(int64(req.GetIconType()))))
}
url, err := u.oss.GetObjectUrl(ctx, iconURI[req.GetIconType()])
if err != nil {
return nil, err
}
return &developer_api.GetIconResponse{
Data: &developer_api.GetIconResponseData{
IconList: []*developer_api.Icon{
{
URL: url,
URI: uri,
},
},
},
}, nil
}
func stringToMap(input string) map[string]string {
result := make(map[string]string)
pairs := strings.Split(input, ",")
for _, pair := range pairs {
parts := strings.Split(pair, ":")
if len(parts) == 2 {
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
result[key] = value
}
}
return result
}
func (u *UploadService) UploadFileCommon(ctx context.Context, req *developer_api.CommonUploadRequest, fullPath string) (*developer_api.CommonUploadResponse, error) {
resp := developer_api.NewCommonUploadResponse()
re := regexp.MustCompile(`/api/playground/upload/([^?]+)`)
match := re.FindStringSubmatch(fullPath)
if len(match) == 0 {
return nil, errorx.New(errno.ErrUploadInvalidParamCode, errorx.KV("msg", "tos key not found"))
}
objKey := match[1]
if strings.Contains(fullPath, "?uploads") {
uploadID, err := u.PartUploadFileInit(ctx, objKey)
if err != nil {
return resp, errorx.New(errno.ErrUploadSystemErrorCode, errorx.KV("msg", err.Error()))
}
resp.Error = &developer_api.Error{Code: 200}
resp.Payload = &developer_api.Payload{UploadID: uploadID}
return resp, nil
}
if len(ptr.From(req.PartNumber)) != 0 {
_, err := u.PartUploadFile(ctx, &PartUploadFileRequest{
UploadID: ptr.From(req.UploadID),
PartNumber: ptr.From(req.PartNumber),
Data: req.ByteData,
})
if err != nil {
return resp, errorx.New(errno.ErrUploadSystemErrorCode, errorx.KV("msg", err.Error()))
}
resp.Error = &developer_api.Error{Code: 200}
return resp, nil
}
if len(ptr.From(req.UploadID)) != 0 {
mp := stringToMap(string(req.ByteData))
err := u.PartUploadFileComplete(ctx, &PartUploadFileCompleteRequest{
UploadID: ptr.From(req.UploadID),
ObjKey: objKey,
Crc32Map: mp,
})
if err != nil {
return resp, errorx.New(errno.ErrUploadSystemErrorCode, errorx.KV("msg", err.Error()))
}
resp.Error = &developer_api.Error{Code: 200}
resp.Payload = &developer_api.Payload{Key: uuid.NewString()}
return resp, nil
}
var err error
contentType := getContentType(objKey)
if len(contentType) != 0 {
err = u.oss.PutObject(ctx, objKey, req.ByteData, storage.WithContentType(contentType))
} else {
err = u.oss.PutObject(ctx, objKey, req.ByteData)
}
if err != nil {
return resp, errorx.New(errno.ErrUploadSystemErrorCode, errorx.KV("msg", err.Error()))
}
resp.Error = &developer_api.Error{Code: 200}
resp.Payload = &developer_api.Payload{Key: uuid.NewString()}
return resp, err
}
func (u *UploadService) UploadFile(ctx context.Context, data []byte, objKey string) (*developer_api.UploadFileResponse, error) {
err := u.oss.PutObject(ctx, objKey, data)
if err != nil {
return nil, err
}
url, err := u.oss.GetObjectUrl(ctx, objKey)
if err != nil {
return nil, err
}
return &developer_api.UploadFileResponse{
Data: &developer_api.UploadFileData{
UploadURL: url,
UploadURI: objKey,
},
}, nil
}
func (u *UploadService) GetShortcutIcons(ctx context.Context) ([]*playground.FileInfo, error) {
shortcutIcons := entity.GetDefaultShortcutIconURI()
fileList := make([]*playground.FileInfo, 0, len(shortcutIcons))
for _, uri := range shortcutIcons {
url, err := u.oss.GetObjectUrl(ctx, uri)
if err == nil {
fileList = append(fileList, &playground.FileInfo{
URL: url,
URI: uri,
})
}
}
return fileList, nil
}
func parseMultipartFormData(ctx context.Context, req *playground.UploadFileOpenRequest) (*multipart.Form, error) {
_, params, err := mime.ParseMediaType(req.ContentType)
if err != nil {
return nil, errorx.New(errno.ErrUploadInvalidContentTypeCode, errorx.KV("content-type", req.ContentType))
}
br := bytes.NewReader(req.Data)
mr := multipart.NewReader(br, params["boundary"])
form, err := mr.ReadForm(maxFileSize)
if errors.Is(err, multipart.ErrMessageTooLarge) {
return nil, errorx.New(errno.ErrUploadInvalidFileSizeCode)
} else if err != nil {
return nil, errorx.New(errno.ErrUploadMultipartFormDataReadFailedCode)
}
return form, nil
}
func genObjName(name string, id string) string {
return fmt.Sprintf("%s/%s/%s",
"bot_files",
id,
name,
)
}
func (u *UploadService) UploadFileOpen(ctx context.Context, req *playground.UploadFileOpenRequest) (*playground.UploadFileOpenResponse, error) {
resp := playground.UploadFileOpenResponse{}
resp.File = new(playground.File)
uid := ctxutil.MustGetUIDFromApiAuthCtx(ctx)
if uid == 0 {
return nil, errorx.New(errno.ErrKnowledgePermissionCode, errorx.KV("msg", "session required"))
}
form, err := parseMultipartFormData(ctx, req)
if err != nil {
logs.CtxErrorf(ctx, "parse multipart form data failed, err: %v", err)
return nil, err
}
if len(form.File["file"]) == 0 {
return nil, errorx.New(errno.ErrUploadEmptyFileCode)
} else if len(form.File["file"]) > 1 {
return nil, errorx.New(errno.ErrUploadFileUploadGreaterOneCode)
}
fileHeader := form.File["file"][0]
// open file
file, err := fileHeader.Open()
if err != nil {
return nil, errorx.New(errno.ErrUploadSystemErrorCode, errorx.KV("msg", "fileHeader open failed"))
}
defer file.Close()
data, err := io.ReadAll(file)
if err != nil {
return nil, errorx.New(errno.ErrUploadSystemErrorCode, errorx.KV("msg", "file upload io read failed"))
}
resp.File.Bytes = int64(len(data))
randID := uuid.NewString()
objName := genObjName(fileHeader.Filename, randID)
resp.File.FileName = fileHeader.Filename
resp.File.URI = objName
err = u.oss.PutObject(ctx, objName, data)
if err != nil {
return nil, errorx.New(errno.ErrUploadSystemErrorCode, errorx.KV("msg", "file upload to oss failed"))
}
url, err := u.oss.GetObjectUrl(ctx, objName)
if err != nil {
return nil, errorx.New(errno.ErrUploadSystemErrorCode, errorx.KV("msg", "get object url failed"))
}
resp.File.CreatedAt = time.Now().Unix()
resp.File.URL = url
return &resp, nil
}
func (u *UploadService) GetIconForDataset(ctx context.Context, req *dataset.GetIconRequest) (*dataset.GetIconResponse, error) {
resp := dataset.NewGetIconResponse()
var uri string
switch req.FormatType {
case dataset.FormatType_Text:
uri = TextKnowledgeDefaultIcon
case dataset.FormatType_Table:
uri = TableKnowledgeDefaultIcon
case dataset.FormatType_Image:
uri = ImageKnowledgeDefaultIcon
case dataset.FormatType_Database:
uri = DatabaseDefaultIcon
default:
uri = TextKnowledgeDefaultIcon
}
iconUrl, err := u.oss.GetObjectUrl(ctx, uri)
if err != nil {
return resp, err
}
resp.Icon = &dataset.Icon{
URL: iconUrl,
URI: uri,
}
return resp, nil
}
func (u *UploadService) UploadSessionKey(ctx context.Context, sessionKey string, tosKey string) error {
return u.cache.Set(ctx, sessionKey, tosKey, time.Minute*30).Err()
}
type GetObjInfoBySessionKey struct {
ObjKey string
Width int32
Height int32
}
func isImageUri(uri string) bool {
if uri == "" {
return false
}
uri = strings.ToLower(uri)
fileExtension := path.Base(uri)
ext := path.Ext(fileExtension)
ext = ext[1:]
imageExtensions := map[string]bool{
"jpg": true,
"jpeg": true,
"png": true,
"gif": true,
"bmp": true,
"webp": true,
"tiff": true,
"svg": true,
"ico": true,
}
// 检查扩展名是否在图片扩展名列表中
return imageExtensions[ext]
}
func (u *UploadService) GetObjInfoBySessionKey(ctx context.Context, sessionKey string) (*GetObjInfoBySessionKey, error) {
resp := GetObjInfoBySessionKey{}
objKey, err := u.cache.Get(ctx, sessionKey).Result()
if err != nil {
return nil, err
}
resp.ObjKey = objKey
if isImageUri(objKey) {
content, err := u.oss.GetObject(ctx, objKey)
if err != nil {
return nil, err
}
if isSVG(objKey) {
width, height, err := getSVGDimensions(content)
if err != nil {
logs.CtxErrorf(ctx, "get svg dimensions failed, err: %v", err)
// default val
resp.Width = 100
resp.Height = 100
return &resp, nil
}
resp.Width = width
resp.Height = height
} else {
img, _, err := image.Decode(bytes.NewReader(content))
if err != nil {
logs.CtxErrorf(ctx, "decode image failed, err: %v", err)
// default val
resp.Width = 100
resp.Height = 100
return &resp, nil
}
resp.Width = int32(img.Bounds().Dx())
resp.Height = int32(img.Bounds().Dy())
}
}
return &resp, nil
}
type SVG struct {
Width string `xml:"width,attr"`
Height string `xml:"height,attr"`
ViewBox string `xml:"viewBox,attr"`
}
// 获取 SVG 尺寸
func getSVGDimensions(content []byte) (width, height int32, err error) {
decoder := xml.NewDecoder(bytes.NewReader(content))
var svg SVG
if err := decoder.Decode(&svg); err != nil {
return 100, 100, nil
}
// 尝试从width属性获取
if svg.Width != "" {
w, err := parseDimension(svg.Width)
if err == nil {
width = w
}
}
// 尝试从height属性获取
if svg.Height != "" {
h, err := parseDimension(svg.Height)
if err == nil {
height = h
}
}
// 如果width或height未设置尝试从viewBox获取
if width == 0 || height == 0 {
if svg.ViewBox != "" {
parts := strings.Fields(svg.ViewBox)
if len(parts) >= 4 {
if width == 0 {
w, err := strconv.ParseInt(parts[2], 10, 32)
if err == nil {
width = int32(w)
}
}
if height == 0 {
h, err := strconv.ParseInt(parts[3], 10, 32)
if err == nil {
height = int32(h)
}
}
}
}
}
if width == 0 || height == 0 {
return 100, 100, nil
}
return width, height, nil
}
func parseDimension(dim string) (int32, error) {
// 去除单位(px, pt, em, %等)和空格
dim = strings.TrimSpace(dim)
dim = strings.TrimRightFunc(dim, func(r rune) bool {
return (r < '0' || r > '9') && r != '.' && r != '-' && r != '+'
})
// 解析为float64
value, err := strconv.ParseFloat(dim, 64)
if err != nil {
return 0, err
}
// 四舍五入转换为int32
if value > math.MaxInt32 {
return math.MaxInt32, nil
}
if value < math.MinInt32 {
return math.MinInt32, nil
}
return int32(math.Round(value)), nil
}
func isSVG(uri string) bool {
uri = strings.ToLower(uri)
fileExtension := path.Base(uri)
ext := path.Ext(fileExtension)
ext = ext[1:]
return ext == "svg"
}
func (u *UploadService) ApplyImageUpload(ctx context.Context, req *developer_api.ApplyUploadActionRequest, host string) (*developer_api.ApplyUploadActionResponse, error) {
resp := developer_api.ApplyUploadActionResponse{}
storeUri := "tos-cn-i-v4nquku3lp/" + uuid.NewString() + ptr.From(req.FileExtension)
sessionKey := uuid.NewString()
auth := uuid.NewString()
uploadID := uuid.NewString()
uploadHost := string(host) + consts.UploadURI
resp.ResponseMetadata = &developer_api.ResponseMetadata{
RequestId: uuid.NewString(),
Action: "ApplyImageUpload",
Version: "",
Service: "",
Region: "",
}
resp.Result = &developer_api.ApplyUploadActionResult{
UploadAddress: &developer_api.UploadAddress{
StoreInfos: []*developer_api.StoreInfo{
{
StoreUri: storeUri,
Auth: auth,
UploadID: uploadID,
},
},
UploadHosts: []string{uploadHost},
SessionKey: sessionKey,
},
InnerUploadAddress: &developer_api.InnerUploadAddress{
UploadNodes: []*developer_api.UploadNode{
{
StoreInfos: []*developer_api.StoreInfo{
{
StoreUri: storeUri,
Auth: auth,
UploadID: uploadID,
},
},
UploadHost: uploadHost,
SessionKey: sessionKey,
},
},
},
RequestId: ptr.Of(uuid.NewString()),
}
err := u.UploadSessionKey(ctx, sessionKey, storeUri)
if err != nil {
return &resp, errorx.New(errno.ErrUploadSystemErrorCode, errorx.KV("msg", err.Error()))
}
return &resp, nil
}
func (u *UploadService) CommitImageUpload(ctx context.Context, req *developer_api.ApplyUploadActionRequest, host string) (*developer_api.ApplyUploadActionResponse, error) {
resp := developer_api.ApplyUploadActionResponse{}
type ssKey struct {
SessionKey string `json:"SessionKey"`
}
sskey := ssKey{}
err := sonic.Unmarshal(req.ByteData, &sskey)
if err != nil {
return &resp, errorx.New(errno.ErrUploadSystemErrorCode, errorx.KV("msg", err.Error()))
}
objInfo, err := u.GetObjInfoBySessionKey(ctx, sskey.SessionKey)
if err != nil {
return &resp, errorx.New(errno.ErrUploadSystemErrorCode, errorx.KV("msg", err.Error()))
}
resp.ResponseMetadata = &developer_api.ResponseMetadata{
RequestId: uuid.NewString(),
Action: "ApplyImageUpload",
Version: "",
Service: "",
Region: "",
}
resp.Result = &developer_api.ApplyUploadActionResult{
Results: []*developer_api.UploadResult{
{
Uri: objInfo.ObjKey,
UriStatus: 2000,
},
},
RequestId: ptr.Of(uuid.NewString()),
PluginResult: []*developer_api.PluginResult{
{
FileName: objInfo.ObjKey,
SourceUri: objInfo.ObjKey,
ImageUri: objInfo.ObjKey,
ImageWidth: objInfo.Width,
ImageHeight: objInfo.Height,
},
},
}
return &resp, nil
}

View File

@@ -0,0 +1,40 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package user
import (
"context"
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/domain/user/repository"
"github.com/coze-dev/coze-studio/backend/domain/user/service"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/infra/impl/idgen"
)
func InitService(ctx context.Context, db *gorm.DB, oss storage.Storage, idgen idgen.IDGenerator) *UserApplicationService {
UserApplicationSVC.DomainSVC = service.NewUserDomain(ctx, &service.Components{
IconOSS: oss,
IDGen: idgen,
UserRepo: repository.NewUserRepo(db),
SpaceRepo: repository.NewSpaceRepo(db),
})
UserApplicationSVC.oss = oss
return UserApplicationSVC
}

View File

@@ -0,0 +1,318 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package user
import (
"context"
"net/mail"
"strconv"
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/playground"
"github.com/coze-dev/coze-studio/backend/api/model/passport"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
"github.com/coze-dev/coze-studio/backend/domain/user/entity"
user "github.com/coze-dev/coze-studio/backend/domain/user/service"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
var UserApplicationSVC = &UserApplicationService{}
type UserApplicationService struct {
oss storage.Storage
DomainSVC user.User
}
// 添加一个简单的 email 验证函数
func isValidEmail(email string) bool {
// 如果 email 字符串格式不正确,它会返回一个 error
_, err := mail.ParseAddress(email)
return err == nil
}
func (u *UserApplicationService) PassportWebEmailRegisterV2(ctx context.Context, locale string, req *passport.PassportWebEmailRegisterV2PostRequest) (
resp *passport.PassportWebEmailRegisterV2PostResponse, sessionKey string, err error,
) {
// 验证 email 格式是否合法
if !isValidEmail(req.GetEmail()) {
return nil, "", errorx.New(errno.ErrUserInvalidParamCode, errorx.KV("msg", "Invalid email"))
}
userInfo, err := u.DomainSVC.Create(ctx, &user.CreateUserRequest{
Email: req.GetEmail(),
Password: req.GetPassword(),
Locale: locale,
})
if err != nil {
return nil, "", err
}
userInfo, err = u.DomainSVC.Login(ctx, req.GetEmail(), req.GetPassword())
if err != nil {
return nil, "", err
}
return &passport.PassportWebEmailRegisterV2PostResponse{
Data: userDo2PassportTo(userInfo),
Code: 0,
}, userInfo.SessionKey, nil
}
// PassportWebLogoutGet 处理用户登出请求
func (u *UserApplicationService) PassportWebLogoutGet(ctx context.Context, req *passport.PassportWebLogoutGetRequest) (
resp *passport.PassportWebLogoutGetResponse, err error,
) {
uid := ctxutil.MustGetUIDFromCtx(ctx)
err = u.DomainSVC.Logout(ctx, uid)
if err != nil {
return nil, err
}
return &passport.PassportWebLogoutGetResponse{
Code: 0,
}, nil
}
// PassportWebEmailLoginPost 处理用户邮箱登录请求
func (u *UserApplicationService) PassportWebEmailLoginPost(ctx context.Context, req *passport.PassportWebEmailLoginPostRequest) (
resp *passport.PassportWebEmailLoginPostResponse, sessionKey string, err error,
) {
userInfo, err := u.DomainSVC.Login(ctx, req.GetEmail(), req.GetPassword())
if err != nil {
return nil, "", err
}
return &passport.PassportWebEmailLoginPostResponse{
Data: userDo2PassportTo(userInfo),
Code: 0,
}, userInfo.SessionKey, nil
}
func (u *UserApplicationService) PassportWebEmailPasswordResetGet(ctx context.Context, req *passport.PassportWebEmailPasswordResetGetRequest) (
resp *passport.PassportWebEmailPasswordResetGetResponse, err error,
) {
err = u.DomainSVC.ResetPassword(ctx, req.GetEmail(), req.GetPassword())
if err != nil {
return nil, err
}
return &passport.PassportWebEmailPasswordResetGetResponse{
Code: 0,
}, nil
}
func (u *UserApplicationService) PassportAccountInfoV2(ctx context.Context, req *passport.PassportAccountInfoV2Request) (
resp *passport.PassportAccountInfoV2Response, err error,
) {
userID := ctxutil.MustGetUIDFromCtx(ctx)
userInfo, err := u.DomainSVC.GetUserInfo(ctx, userID)
if err != nil {
return nil, err
}
return &passport.PassportAccountInfoV2Response{
Data: userDo2PassportTo(userInfo),
Code: 0,
}, nil
}
// UserUpdateAvatar 更新用户头像
func (u *UserApplicationService) UserUpdateAvatar(ctx context.Context, mimeType string, req *passport.UserUpdateAvatarRequest) (
resp *passport.UserUpdateAvatarResponse, err error,
) {
// 根据 MIME type 获取文件后缀
var ext string
switch mimeType {
case "image/jpeg", "image/jpg":
ext = "jpg"
case "image/png":
ext = "png"
case "image/gif":
ext = "gif"
case "image/webp":
ext = "webp"
default:
return nil, errorx.WrapByCode(err, errno.ErrUserInvalidParamCode,
errorx.KV("msg", "unsupported image type"))
}
uid := ctxutil.MustGetUIDFromCtx(ctx)
url, err := u.DomainSVC.UpdateAvatar(ctx, uid, ext, req.GetAvatar())
if err != nil {
return nil, err
}
return &passport.UserUpdateAvatarResponse{
Data: &passport.UserUpdateAvatarResponseData{
WebURI: url,
},
Code: 0,
}, nil
}
// UserUpdateProfile 更新用户资料
func (u *UserApplicationService) UserUpdateProfile(ctx context.Context, req *passport.UserUpdateProfileRequest) (
resp *passport.UserUpdateProfileResponse, err error,
) {
userID := ctxutil.MustGetUIDFromCtx(ctx)
err = u.DomainSVC.UpdateProfile(ctx, &user.UpdateProfileRequest{
UserID: userID,
Name: req.Name,
UniqueName: req.UserUniqueName,
Description: req.Description,
Locale: req.Locale,
})
if err != nil {
return nil, err
}
return &passport.UserUpdateProfileResponse{
Code: 0,
}, nil
}
func (u *UserApplicationService) GetSpaceListV2(ctx context.Context, req *playground.GetSpaceListV2Request) (
resp *playground.GetSpaceListV2Response, err error,
) {
uid := ctxutil.MustGetUIDFromCtx(ctx)
spaces, err := u.DomainSVC.GetUserSpaceList(ctx, uid)
if err != nil {
return nil, err
}
botSpaces := slices.Transform(spaces, func(space *entity.Space) *playground.BotSpaceV2 {
return &playground.BotSpaceV2{
ID: space.ID,
Name: space.Name,
Description: space.Description,
SpaceType: playground.SpaceType(space.SpaceType),
IconURL: space.IconURL,
}
})
return &playground.GetSpaceListV2Response{
Data: &playground.SpaceInfo{
BotSpaceList: botSpaces,
HasPersonalSpace: true,
TeamSpaceNum: 0,
RecentlyUsedSpaceList: botSpaces,
Total: ptr.Of(int32(len(botSpaces))),
HasMore: ptr.Of(false),
},
Code: 0,
}, nil
}
func (u *UserApplicationService) MGetUserBasicInfo(ctx context.Context, req *playground.MGetUserBasicInfoRequest) (
resp *playground.MGetUserBasicInfoResponse, err error,
) {
userIDs, err := slices.TransformWithErrorCheck(req.GetUserIds(), func(s string) (int64, error) {
return strconv.ParseInt(s, 10, 64)
})
if err != nil {
return nil, errorx.WrapByCode(err, errno.ErrUserInvalidParamCode, errorx.KV("msg", "invalid user id"))
}
userInfos, err := u.DomainSVC.MGetUserProfiles(ctx, userIDs)
if err != nil {
return nil, err
}
return &playground.MGetUserBasicInfoResponse{
UserBasicInfoMap: slices.ToMap(userInfos, func(userInfo *entity.User) (string, *playground.UserBasicInfo) {
return strconv.FormatInt(userInfo.UserID, 10), userDo2PlaygroundTo(userInfo)
}),
Code: 0,
}, nil
}
func (u *UserApplicationService) UpdateUserProfileCheck(ctx context.Context, req *developer_api.UpdateUserProfileCheckRequest) (resp *developer_api.UpdateUserProfileCheckResponse, err error) {
if req.GetUserUniqueName() == "" {
return &developer_api.UpdateUserProfileCheckResponse{
Code: 0,
Msg: "no content to update",
}, nil
}
validateResp, err := u.DomainSVC.ValidateProfileUpdate(ctx, &user.ValidateProfileUpdateRequest{
UniqueName: req.UserUniqueName,
})
if err != nil {
return nil, err
}
return &developer_api.UpdateUserProfileCheckResponse{
Code: int64(validateResp.Code),
Msg: validateResp.Msg,
}, nil
}
func (u *UserApplicationService) ValidateSession(ctx context.Context, sessionKey string) (*entity.Session, error) {
session, exist, err := u.DomainSVC.ValidateSession(ctx, sessionKey)
if err != nil {
return nil, err
}
if !exist {
return nil, errorx.New(errno.ErrUserAuthenticationFailed, errorx.KV("reason", "session not exist"))
}
return session, nil
}
func userDo2PassportTo(userDo *entity.User) *passport.User {
var locale *string
if userDo.Locale != "" {
locale = ptr.Of(userDo.Locale)
}
return &passport.User{
UserIDStr: userDo.UserID,
Name: userDo.Name,
ScreenName: ptr.Of(userDo.Name),
UserUniqueName: userDo.UniqueName,
Email: userDo.Email,
Description: userDo.Description,
AvatarURL: userDo.IconURL,
AppUserInfo: &passport.AppUserInfo{
UserUniqueName: userDo.UniqueName,
},
Locale: locale,
UserCreateTime: userDo.CreatedAt / 1000,
}
}
func userDo2PlaygroundTo(userDo *entity.User) *playground.UserBasicInfo {
return &playground.UserBasicInfo{
UserId: userDo.UserID,
Username: userDo.Name,
UserUniqueName: ptr.Of(userDo.UniqueName),
UserAvatar: userDo.IconURL,
CreateTime: ptr.Of(userDo.CreatedAt / 1000),
}
}

View File

@@ -0,0 +1,87 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package workflow
import (
"github.com/cloudwego/eino/compose"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
wfdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/database"
wfknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/knowledge"
wfmodel "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/model"
wfplugin "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/plugin"
wfsearch "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/search"
"github.com/coze-dev/coze-studio/backend/crossdomain/workflow/variable"
knowledge "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
dbservice "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
variables "github.com/coze-dev/coze-studio/backend/domain/memory/variables/service"
"github.com/coze-dev/coze-studio/backend/domain/modelmgr"
plugin "github.com/coze-dev/coze-studio/backend/domain/plugin/service"
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
crosscode "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
crossdatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
crossplugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
crosssearch "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/search"
crossvariable "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
"github.com/coze-dev/coze-studio/backend/domain/workflow/service"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner"
)
type ServiceComponents struct {
IDGen idgen.IDGenerator
DB *gorm.DB
Cache *redis.Client
DatabaseDomainSVC dbservice.Database
VariablesDomainSVC variables.Variables
PluginDomainSVC plugin.PluginService
KnowledgeDomainSVC knowledge.Knowledge
ModelManager modelmgr.Manager
DomainNotifier search.ResourceEventBus
Tos storage.Storage
ImageX imagex.ImageX
CPStore compose.CheckPointStore
}
func InitService(components *ServiceComponents) *ApplicationService {
workflowRepo := service.NewWorkflowRepository(components.IDGen, components.DB, components.Cache,
components.Tos, components.CPStore)
workflow.SetRepository(workflowRepo)
workflowDomainSVC := service.NewWorkflowService(workflowRepo)
crossdatabase.SetDatabaseOperator(wfdatabase.NewDatabaseRepository(components.DatabaseDomainSVC))
crossvariable.SetVariableHandler(variable.NewVariableHandler(components.VariablesDomainSVC))
crossvariable.SetVariablesMetaGetter(variable.NewVariablesMetaGetter(components.VariablesDomainSVC))
crossplugin.SetPluginService(wfplugin.NewPluginService(components.PluginDomainSVC, components.Tos))
crossknowledge.SetKnowledgeOperator(wfknowledge.NewKnowledgeRepository(components.KnowledgeDomainSVC, components.IDGen))
crossmodel.SetManager(wfmodel.NewModelManager(components.ModelManager, nil))
crosscode.SetCodeRunner(coderunner.NewRunner())
crosssearch.SetNotifier(wfsearch.NewNotify(components.DomainNotifier))
SVC.DomainSVC = workflowDomainSVC
SVC.ImageX = components.ImageX
SVC.TosClient = components.Tos
SVC.IDGenerator = components.IDGen
return SVC
}

File diff suppressed because it is too large Load Diff