feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
387
backend/domain/plugin/service/agent_tool.go
Normal file
387
backend/domain/plugin/service/agent_tool.go
Normal file
@@ -0,0 +1,387 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/repository"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
func (p *pluginServiceImpl) BindAgentTools(ctx context.Context, agentID int64, toolIDs []int64) (err error) {
|
||||
return p.toolRepo.BindDraftAgentTools(ctx, agentID, toolIDs)
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) DuplicateDraftAgentTools(ctx context.Context, fromAgentID, toAgentID int64) (err error) {
|
||||
return p.toolRepo.DuplicateDraftAgentTools(ctx, fromAgentID, toAgentID)
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) GetDraftAgentToolByName(ctx context.Context, agentID int64, toolName string) (tool *entity.ToolInfo, err error) {
|
||||
draftAgentTool, exist, err := p.toolRepo.GetDraftAgentToolWithToolName(ctx, agentID, toolName)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "GetDraftAgentToolWithToolName failed, agentID=%d, toolName=%s", agentID, toolName)
|
||||
}
|
||||
if !exist {
|
||||
return nil, errorx.New(errno.ErrPluginRecordNotFound)
|
||||
}
|
||||
|
||||
tool, exist, err = p.toolRepo.GetOnlineTool(ctx, draftAgentTool.ID)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "GetOnlineTool failed, id=%d", draftAgentTool.ID)
|
||||
}
|
||||
if !exist {
|
||||
return nil, errorx.New(errno.ErrPluginRecordNotFound)
|
||||
}
|
||||
|
||||
draftAgentTool, err = mergeAgentToolInfo(ctx, tool, draftAgentTool)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "mergeAgentToolInfo failed")
|
||||
}
|
||||
|
||||
return draftAgentTool, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) MGetAgentTools(ctx context.Context, req *MGetAgentToolsRequest) (tools []*entity.ToolInfo, err error) {
|
||||
toolIDs := make([]int64, 0, len(req.VersionAgentTools))
|
||||
for _, v := range req.VersionAgentTools {
|
||||
toolIDs = append(toolIDs, v.ToolID)
|
||||
}
|
||||
|
||||
existTools, err := p.toolRepo.MGetOnlineTools(ctx, toolIDs, repository.WithToolID())
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "MGetOnlineTools failed, toolIDs=%v", toolIDs)
|
||||
}
|
||||
|
||||
if len(existTools) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
existMap := make(map[int64]bool, len(existTools))
|
||||
for _, tool := range existTools {
|
||||
existMap[tool.ID] = true
|
||||
}
|
||||
|
||||
if req.IsDraft {
|
||||
existToolIDs := make([]int64, 0, len(existMap))
|
||||
for _, v := range req.VersionAgentTools {
|
||||
if existMap[v.ToolID] {
|
||||
existToolIDs = append(existToolIDs, v.ToolID)
|
||||
}
|
||||
}
|
||||
|
||||
tools, err = p.toolRepo.MGetDraftAgentTools(ctx, req.AgentID, existToolIDs)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "MGetDraftAgentTools failed, agentID=%d, toolIDs=%v", req.AgentID, existToolIDs)
|
||||
}
|
||||
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
vTools := make([]entity.VersionAgentTool, 0, len(existMap))
|
||||
for _, v := range req.VersionAgentTools {
|
||||
if existMap[v.ToolID] {
|
||||
vTools = append(vTools, v)
|
||||
}
|
||||
}
|
||||
|
||||
tools, err = p.toolRepo.MGetVersionAgentTool(ctx, req.AgentID, vTools)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "MGetVersionAgentTool failed, agentID=%d, vTools=%v", req.AgentID, vTools)
|
||||
}
|
||||
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) PublishAgentTools(ctx context.Context, agentID int64, agentVersion string) (err error) {
|
||||
tools, err := p.toolRepo.GetSpaceAllDraftAgentTools(ctx, agentID)
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "GetSpaceAllDraftAgentTools failed, agentID=%d", agentID)
|
||||
}
|
||||
|
||||
err = p.toolRepo.BatchCreateVersionAgentTools(ctx, agentID, agentVersion, tools)
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "BatchCreateVersionAgentTools failed, agentID=%d, agentVersion=%s", agentID, agentVersion)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) UpdateBotDefaultParams(ctx context.Context, req *UpdateBotDefaultParamsRequest) (err error) {
|
||||
_, exist, err := p.pluginRepo.GetOnlinePlugin(ctx, req.PluginID, repository.WithPluginID())
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "GetOnlinePlugin failed, pluginID=%d", req.PluginID)
|
||||
}
|
||||
if !exist {
|
||||
return errorx.New(errno.ErrPluginRecordNotFound)
|
||||
}
|
||||
|
||||
draftAgentTool, exist, err := p.toolRepo.GetDraftAgentToolWithToolName(ctx, req.AgentID, req.ToolName)
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "GetDraftAgentToolWithToolName failed, agentID=%d, toolName=%s", req.AgentID, req.ToolName)
|
||||
}
|
||||
if !exist {
|
||||
return errorx.New(errno.ErrPluginRecordNotFound)
|
||||
}
|
||||
|
||||
onlineTool, exist, err := p.toolRepo.GetOnlineTool(ctx, draftAgentTool.ID)
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "GetOnlineTool failed, id=%d", draftAgentTool.ID)
|
||||
}
|
||||
if !exist {
|
||||
return errorx.New(errno.ErrPluginRecordNotFound)
|
||||
}
|
||||
|
||||
op := onlineTool.Operation
|
||||
|
||||
if req.Parameters != nil {
|
||||
op.Parameters = req.Parameters
|
||||
}
|
||||
|
||||
if req.RequestBody != nil {
|
||||
mType, ok := req.RequestBody.Value.Content[model.MediaTypeJson]
|
||||
if !ok {
|
||||
return fmt.Errorf("the '%s' media type is not defined in request body", model.MediaTypeJson)
|
||||
}
|
||||
if op.RequestBody.Value.Content == nil {
|
||||
op.RequestBody.Value.Content = map[string]*openapi3.MediaType{}
|
||||
}
|
||||
op.RequestBody.Value.Content[model.MediaTypeJson] = mType
|
||||
}
|
||||
|
||||
if req.Responses != nil {
|
||||
newRespRef, ok := req.Responses[strconv.Itoa(http.StatusOK)]
|
||||
if !ok {
|
||||
return fmt.Errorf("the '%d' status code is not defined in responses", http.StatusOK)
|
||||
}
|
||||
newMIMEType, ok := newRespRef.Value.Content[model.MediaTypeJson]
|
||||
if !ok {
|
||||
return fmt.Errorf("the '%s' media type is not defined in responses", model.MediaTypeJson)
|
||||
}
|
||||
|
||||
if op.Responses == nil {
|
||||
op.Responses = map[string]*openapi3.ResponseRef{}
|
||||
}
|
||||
|
||||
oldRespRef, ok := op.Responses[strconv.Itoa(http.StatusOK)]
|
||||
if !ok {
|
||||
oldRespRef = &openapi3.ResponseRef{
|
||||
Value: &openapi3.Response{
|
||||
Content: map[string]*openapi3.MediaType{},
|
||||
},
|
||||
}
|
||||
op.Responses[strconv.Itoa(http.StatusOK)] = oldRespRef
|
||||
}
|
||||
|
||||
if oldRespRef.Value.Content == nil {
|
||||
oldRespRef.Value.Content = map[string]*openapi3.MediaType{}
|
||||
}
|
||||
|
||||
oldRespRef.Value.Content[model.MediaTypeJson] = newMIMEType
|
||||
}
|
||||
|
||||
updatedTool := &entity.ToolInfo{
|
||||
Version: onlineTool.Version,
|
||||
Method: onlineTool.Method,
|
||||
SubURL: onlineTool.SubURL,
|
||||
Operation: op,
|
||||
}
|
||||
err = p.toolRepo.UpdateDraftAgentTool(ctx, &repository.UpdateDraftAgentToolRequest{
|
||||
AgentID: req.AgentID,
|
||||
ToolName: req.ToolName,
|
||||
Tool: updatedTool,
|
||||
})
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "UpdateDraftAgentTool failed, agentID=%d, toolName=%s", req.AgentID, req.ToolName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func mergeAgentToolInfo(ctx context.Context, dest, src *entity.ToolInfo) (*entity.ToolInfo, error) {
|
||||
dest.Version = src.Version
|
||||
dest.Method = src.Method
|
||||
dest.SubURL = src.SubURL
|
||||
|
||||
newParameters, err := mergeParameters(ctx, dest.Operation.Parameters, src.Operation.Parameters)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "mergeParameters failed")
|
||||
}
|
||||
|
||||
dest.Operation.Parameters = newParameters
|
||||
|
||||
newReqBody, err := mergeRequestBody(ctx, dest.Operation.RequestBody, src.Operation.RequestBody)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "mergeRequestBody failed")
|
||||
}
|
||||
|
||||
dest.Operation.RequestBody = newReqBody
|
||||
|
||||
newRespBody, err := mergeResponseBody(ctx, dest.Operation.Responses, src.Operation.Responses)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "mergeResponseBody failed")
|
||||
}
|
||||
|
||||
dest.Operation.Responses = newRespBody
|
||||
|
||||
return dest, nil
|
||||
}
|
||||
|
||||
func mergeParameters(ctx context.Context, dest, src openapi3.Parameters) (openapi3.Parameters, error) {
|
||||
if len(dest) == 0 || len(src) == 0 {
|
||||
return src, nil
|
||||
}
|
||||
|
||||
srcMap := make(map[string]*openapi3.ParameterRef, len(src))
|
||||
for _, p := range src {
|
||||
srcMap[p.Value.Name] = p
|
||||
}
|
||||
|
||||
for _, dp := range dest {
|
||||
sp, ok := srcMap[dp.Value.Name]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
dv := dp.Value.Schema.Value
|
||||
sv := sp.Value.Schema.Value
|
||||
|
||||
if dv.Extensions == nil {
|
||||
dv.Extensions = make(map[string]any)
|
||||
}
|
||||
|
||||
if v, ok := sv.Extensions[model.APISchemaExtendLocalDisable]; ok {
|
||||
dv.Extensions[model.APISchemaExtendLocalDisable] = v
|
||||
}
|
||||
|
||||
if v, ok := sv.Extensions[model.APISchemaExtendVariableRef]; ok {
|
||||
dv.Extensions[model.APISchemaExtendVariableRef] = v
|
||||
}
|
||||
|
||||
dv.Default = sv.Default
|
||||
}
|
||||
|
||||
return dest, nil
|
||||
}
|
||||
|
||||
func mergeRequestBody(ctx context.Context, dest, src *openapi3.RequestBodyRef) (*openapi3.RequestBodyRef, error) {
|
||||
if dest == nil || src == nil {
|
||||
return src, nil
|
||||
}
|
||||
|
||||
for ct, dm := range dest.Value.Content {
|
||||
sm, ok := src.Value.Content[ct]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
nv, err := mergeMediaSchema(ctx, dm.Schema.Value, sm.Schema.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dm.Schema.Value = nv
|
||||
}
|
||||
|
||||
return dest, nil
|
||||
}
|
||||
|
||||
func mergeMediaSchema(ctx context.Context, dest, src *openapi3.Schema) (*openapi3.Schema, error) {
|
||||
if dest.Extensions == nil {
|
||||
dest.Extensions = map[string]any{}
|
||||
}
|
||||
if v, ok := src.Extensions[model.APISchemaExtendLocalDisable]; ok {
|
||||
dest.Extensions[model.APISchemaExtendLocalDisable] = v
|
||||
}
|
||||
if v, ok := src.Extensions[model.APISchemaExtendVariableRef]; ok {
|
||||
dest.Extensions[model.APISchemaExtendVariableRef] = v
|
||||
}
|
||||
|
||||
dest.Default = src.Default
|
||||
|
||||
switch dest.Type {
|
||||
case openapi3.TypeObject:
|
||||
for k, dv := range dest.Properties {
|
||||
sv, ok := src.Properties[k]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
nv, err := mergeMediaSchema(ctx, dv.Value, sv.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dv.Value = nv
|
||||
}
|
||||
|
||||
return dest, nil
|
||||
|
||||
case openapi3.TypeArray:
|
||||
nv, err := mergeMediaSchema(ctx, dest.Items.Value, src.Items.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dest.Items.Value = nv
|
||||
|
||||
return dest, nil
|
||||
|
||||
default:
|
||||
return dest, nil
|
||||
}
|
||||
}
|
||||
|
||||
func mergeResponseBody(ctx context.Context, dest, src openapi3.Responses) (openapi3.Responses, error) {
|
||||
if len(dest) == 0 || len(src) == 0 {
|
||||
return src, nil
|
||||
}
|
||||
|
||||
for code, dr := range dest {
|
||||
sr := src[code]
|
||||
if dr == nil || sr == nil {
|
||||
continue
|
||||
}
|
||||
if len(dr.Value.Content) == 0 || len(sr.Value.Content) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
for ct, dm := range dr.Value.Content {
|
||||
sm, ok := sr.Value.Content[ct]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
nv, err := mergeMediaSchema(ctx, dm.Schema.Value, sm.Schema.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dm.Schema.Value = nv
|
||||
}
|
||||
}
|
||||
|
||||
return dest, nil
|
||||
}
|
||||
1352
backend/domain/plugin/service/exec_tool.go
Normal file
1352
backend/domain/plugin/service/exec_tool.go
Normal file
File diff suppressed because it is too large
Load Diff
903
backend/domain/plugin/service/plugin_draft.go
Normal file
903
backend/domain/plugin/service/plugin_draft.go
Normal file
@@ -0,0 +1,903 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
searchModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/search"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
||||
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
||||
resCommon "github.com/coze-dev/coze-studio/backend/api/model/resource/common"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crosssearch"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/internal/openapi"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/repository"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
func (p *pluginServiceImpl) CreateDraftPlugin(ctx context.Context, req *CreateDraftPluginRequest) (pluginID int64, err error) {
|
||||
mf := entity.NewDefaultPluginManifest()
|
||||
mf.NameForHuman = req.Name
|
||||
mf.NameForModel = req.Name
|
||||
mf.DescriptionForHuman = req.Desc
|
||||
mf.DescriptionForModel = req.Desc
|
||||
mf.API.Type, _ = model.ToPluginType(req.PluginType)
|
||||
mf.LogoURL = req.IconURI
|
||||
|
||||
authV2, err := req.AuthInfo.toAuthV2()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
mf.Auth = authV2
|
||||
|
||||
for loc, params := range req.CommonParams {
|
||||
location, ok := model.ToHTTPParamLocation(loc)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("invalid location '%s'", loc.String())
|
||||
}
|
||||
for _, param := range params {
|
||||
mParams := mf.CommonParams[location]
|
||||
mParams = append(mParams, &plugin_develop_common.CommonParamSchema{
|
||||
Name: param.Name,
|
||||
Value: param.Value,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
doc := entity.NewDefaultOpenapiDoc()
|
||||
doc.Servers = append(doc.Servers, &openapi3.Server{
|
||||
URL: req.ServerURL,
|
||||
})
|
||||
doc.Info.Title = req.Name
|
||||
doc.Info.Description = req.Desc
|
||||
|
||||
err = doc.Validate(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
err = mf.Validate(false)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
pl := entity.NewPluginInfo(&model.PluginInfo{
|
||||
IconURI: ptr.Of(req.IconURI),
|
||||
SpaceID: req.SpaceID,
|
||||
ServerURL: ptr.Of(req.ServerURL),
|
||||
DeveloperID: req.DeveloperID,
|
||||
APPID: req.ProjectID,
|
||||
PluginType: req.PluginType,
|
||||
Manifest: mf,
|
||||
OpenapiDoc: doc,
|
||||
})
|
||||
|
||||
pluginID, err = p.pluginRepo.CreateDraftPlugin(ctx, pl)
|
||||
if err != nil {
|
||||
return 0, errorx.Wrapf(err, "CreateDraftPlugin failed")
|
||||
}
|
||||
|
||||
return pluginID, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) GetDraftPlugin(ctx context.Context, pluginID int64) (plugin *entity.PluginInfo, err error) {
|
||||
pl, exist, err := p.pluginRepo.GetDraftPlugin(ctx, pluginID)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "GetDraftPlugin failed, pluginID=%d", pluginID)
|
||||
}
|
||||
if !exist {
|
||||
return nil, errorx.New(errno.ErrPluginRecordNotFound)
|
||||
}
|
||||
|
||||
return pl, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) MGetDraftPlugins(ctx context.Context, pluginIDs []int64) (plugins []*entity.PluginInfo, err error) {
|
||||
plugins, err = p.pluginRepo.MGetDraftPlugins(ctx, pluginIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return plugins, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) ListDraftPlugins(ctx context.Context, req *ListDraftPluginsRequest) (resp *ListDraftPluginsResponse, err error) {
|
||||
if req.PageInfo.Name == nil || *req.PageInfo.Name == "" {
|
||||
res, err := p.pluginRepo.ListDraftPlugins(ctx, &repository.ListDraftPluginsRequest{
|
||||
SpaceID: req.SpaceID,
|
||||
APPID: req.APPID,
|
||||
PageInfo: req.PageInfo,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "ListDraftPlugins failed, spaceID=%d, appID=%d", req.SpaceID, req.APPID)
|
||||
}
|
||||
|
||||
return &ListDraftPluginsResponse{
|
||||
Plugins: res.Plugins,
|
||||
Total: res.Total,
|
||||
}, nil
|
||||
}
|
||||
|
||||
res, err := crosssearch.DefaultSVC().SearchResources(ctx, &searchModel.SearchResourcesRequest{
|
||||
SpaceID: req.SpaceID,
|
||||
APPID: req.APPID,
|
||||
Name: *req.PageInfo.Name,
|
||||
OrderAsc: false,
|
||||
ResTypeFilter: []resCommon.ResType{
|
||||
resCommon.ResType_Plugin,
|
||||
},
|
||||
OrderFiledName: func() string {
|
||||
if req.PageInfo.SortBy == nil || *req.PageInfo.SortBy != entity.SortByCreatedAt {
|
||||
return searchModel.FieldOfUpdateTime
|
||||
}
|
||||
return searchModel.FieldOfCreateTime
|
||||
}(),
|
||||
Page: ptr.Of(int32(req.PageInfo.Page)),
|
||||
Limit: int32(req.PageInfo.Size),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "SearchResources failed, spaceID=%d, appID=%d", req.SpaceID, req.APPID)
|
||||
}
|
||||
|
||||
plugins := make([]*entity.PluginInfo, 0, len(res.Data))
|
||||
for _, pl := range res.Data {
|
||||
draftPlugin, exist, err := p.pluginRepo.GetDraftPlugin(ctx, pl.ResID)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "GetDraftPlugin failed, pluginID=%d", pl.ResID)
|
||||
}
|
||||
if !exist {
|
||||
logs.CtxWarnf(ctx, "draft plugin not exist, pluginID=%d", pl.ResID)
|
||||
continue
|
||||
}
|
||||
plugins = append(plugins, draftPlugin)
|
||||
}
|
||||
|
||||
total := int64(0)
|
||||
if res.TotalHits != nil {
|
||||
total = *res.TotalHits
|
||||
}
|
||||
|
||||
return &ListDraftPluginsResponse{
|
||||
Plugins: plugins,
|
||||
Total: total,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) CreateDraftPluginWithCode(ctx context.Context, req *CreateDraftPluginWithCodeRequest) (resp *CreateDraftPluginWithCodeResponse, err error) {
|
||||
err = req.OpenapiDoc.Validate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = req.Manifest.Validate(false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := p.pluginRepo.CreateDraftPluginWithCode(ctx, &repository.CreateDraftPluginWithCodeRequest{
|
||||
SpaceID: req.SpaceID,
|
||||
DeveloperID: req.DeveloperID,
|
||||
ProjectID: req.ProjectID,
|
||||
Manifest: req.Manifest,
|
||||
OpenapiDoc: req.OpenapiDoc,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "CreateDraftPluginWithCode failed")
|
||||
}
|
||||
|
||||
resp = &CreateDraftPluginWithCodeResponse{
|
||||
Plugin: res.Plugin,
|
||||
Tools: res.Tools,
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) UpdateDraftPluginWithCode(ctx context.Context, req *UpdateDraftPluginWithCodeRequest) (err error) {
|
||||
doc := req.OpenapiDoc
|
||||
mf := req.Manifest
|
||||
|
||||
err = doc.Validate(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = mf.Validate(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
apiSchemas := make(map[entity.UniqueToolAPI]*model.Openapi3Operation, len(doc.Paths))
|
||||
apis := make([]entity.UniqueToolAPI, 0, len(doc.Paths))
|
||||
|
||||
for subURL, pathItem := range doc.Paths {
|
||||
for method, op := range pathItem.Operations() {
|
||||
api := entity.UniqueToolAPI{
|
||||
SubURL: subURL,
|
||||
Method: method,
|
||||
}
|
||||
apiSchemas[api] = model.NewOpenapi3Operation(op)
|
||||
apis = append(apis, api)
|
||||
}
|
||||
}
|
||||
|
||||
oldDraftTools, err := p.toolRepo.GetPluginAllDraftTools(ctx, req.PluginID)
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "GetPluginAllDraftTools failed, pluginID=%d", req.PluginID)
|
||||
}
|
||||
|
||||
draftPlugin, exist, err := p.pluginRepo.GetDraftPlugin(ctx, req.PluginID)
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "GetDraftPlugin failed, pluginID=%d", req.PluginID)
|
||||
}
|
||||
if !exist {
|
||||
return errorx.New(errno.ErrPluginRecordNotFound)
|
||||
}
|
||||
|
||||
if draftPlugin.GetServerURL() != doc.Servers[0].URL {
|
||||
for _, draftTool := range oldDraftTools {
|
||||
draftTool.DebugStatus = ptr.Of(common.APIDebugStatus_DebugWaiting)
|
||||
}
|
||||
}
|
||||
|
||||
oldDraftToolsMap := slices.ToMap(oldDraftTools, func(e *entity.ToolInfo) (entity.UniqueToolAPI, *entity.ToolInfo) {
|
||||
return entity.UniqueToolAPI{
|
||||
SubURL: e.GetSubURL(),
|
||||
Method: e.GetMethod(),
|
||||
}, e
|
||||
})
|
||||
|
||||
// 1. 删除 tool -> 关闭启用
|
||||
for api, oldTool := range oldDraftToolsMap {
|
||||
_, ok := apiSchemas[api]
|
||||
if !ok {
|
||||
oldTool.DebugStatus = ptr.Of(common.APIDebugStatus_DebugWaiting)
|
||||
oldTool.ActivatedStatus = ptr.Of(model.DeactivateTool)
|
||||
}
|
||||
}
|
||||
|
||||
newDraftTools := make([]*entity.ToolInfo, 0, len(apis))
|
||||
for api, newOp := range apiSchemas {
|
||||
oldTool, ok := oldDraftToolsMap[api]
|
||||
if ok { // 2. 更新 tool -> 覆盖
|
||||
oldTool.ActivatedStatus = ptr.Of(model.ActivateTool)
|
||||
oldTool.Operation = newOp
|
||||
if needResetDebugStatusTool(ctx, newOp, oldTool.Operation) {
|
||||
oldTool.DebugStatus = ptr.Of(common.APIDebugStatus_DebugWaiting)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 3. 新增 tool
|
||||
newDraftTools = append(newDraftTools, &entity.ToolInfo{
|
||||
PluginID: req.PluginID,
|
||||
ActivatedStatus: ptr.Of(model.ActivateTool),
|
||||
DebugStatus: ptr.Of(common.APIDebugStatus_DebugWaiting),
|
||||
SubURL: ptr.Of(api.SubURL),
|
||||
Method: ptr.Of(api.Method),
|
||||
Operation: newOp,
|
||||
})
|
||||
}
|
||||
|
||||
err = p.pluginRepo.UpdateDraftPluginWithCode(ctx, &repository.UpdatePluginDraftWithCode{
|
||||
PluginID: req.PluginID,
|
||||
OpenapiDoc: doc,
|
||||
Manifest: mf,
|
||||
UpdatedTools: oldDraftTools,
|
||||
NewDraftTools: newDraftTools,
|
||||
})
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "UpdateDraftPluginWithCode failed, pluginID=%d", req.PluginID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func needResetDebugStatusTool(_ context.Context, nt, ot *model.Openapi3Operation) bool {
|
||||
if len(ot.Parameters) != len(ot.Parameters) {
|
||||
return true
|
||||
}
|
||||
|
||||
otParams := make(map[string]*openapi3.Parameter, len(ot.Parameters))
|
||||
cnt := make(map[string]int, len(nt.Parameters))
|
||||
|
||||
for _, p := range nt.Parameters {
|
||||
cnt[p.Value.Name]++
|
||||
}
|
||||
for _, p := range ot.Parameters {
|
||||
cnt[p.Value.Name]--
|
||||
otParams[p.Value.Name] = p.Value
|
||||
}
|
||||
for _, v := range cnt {
|
||||
if v != 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for _, p := range nt.Parameters {
|
||||
np, op := p.Value, otParams[p.Value.Name]
|
||||
if np.In != op.In {
|
||||
return true
|
||||
}
|
||||
if np.Required != op.Required {
|
||||
return true
|
||||
}
|
||||
|
||||
if !isJsonSchemaEqual(op.Schema.Value, np.Schema.Value) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
nReqBody, oReqBody := nt.RequestBody.Value, ot.RequestBody.Value
|
||||
if len(nReqBody.Content) != len(oReqBody.Content) {
|
||||
return true
|
||||
}
|
||||
cnt = make(map[string]int, len(nReqBody.Content))
|
||||
for ct := range nReqBody.Content {
|
||||
cnt[ct]++
|
||||
}
|
||||
for ct := range oReqBody.Content {
|
||||
cnt[ct]--
|
||||
}
|
||||
for _, v := range cnt {
|
||||
if v != 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for ct, nct := range nReqBody.Content {
|
||||
oct := oReqBody.Content[ct]
|
||||
if !isJsonSchemaEqual(nct.Schema.Value, oct.Schema.Value) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isJsonSchemaEqual(nsc, osc *openapi3.Schema) bool {
|
||||
if nsc.Type != osc.Type {
|
||||
return false
|
||||
}
|
||||
if nsc.Format != osc.Format {
|
||||
return false
|
||||
}
|
||||
if nsc.Default != osc.Default {
|
||||
return false
|
||||
}
|
||||
if nsc.Extensions[model.APISchemaExtendAssistType] != osc.Extensions[model.APISchemaExtendAssistType] {
|
||||
return false
|
||||
}
|
||||
if nsc.Extensions[model.APISchemaExtendGlobalDisable] != osc.Extensions[model.APISchemaExtendGlobalDisable] {
|
||||
return false
|
||||
}
|
||||
|
||||
switch nsc.Type {
|
||||
case openapi3.TypeObject:
|
||||
if len(nsc.Required) != len(osc.Required) {
|
||||
return false
|
||||
}
|
||||
if len(nsc.Required) > 0 {
|
||||
cnt := make(map[string]int, len(nsc.Required))
|
||||
for _, x := range nsc.Required {
|
||||
cnt[x]++
|
||||
}
|
||||
for _, x := range osc.Required {
|
||||
cnt[x]--
|
||||
}
|
||||
for _, v := range cnt {
|
||||
if v != 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(nsc.Properties) != len(osc.Properties) {
|
||||
return false
|
||||
}
|
||||
if len(nsc.Properties) > 0 {
|
||||
for paramName, np := range nsc.Properties {
|
||||
op, ok := osc.Properties[paramName]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if !isJsonSchemaEqual(np.Value, op.Value) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
case openapi3.TypeArray:
|
||||
if !isJsonSchemaEqual(nsc.Items.Value, osc.Items.Value) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) UpdateDraftPlugin(ctx context.Context, req *UpdateDraftPluginRequest) (err error) {
|
||||
oldPlugin, exist, err := p.pluginRepo.GetDraftPlugin(ctx, req.PluginID)
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "GetDraftPlugin failed, pluginID=%d", req.PluginID)
|
||||
}
|
||||
if !exist {
|
||||
return errorx.New(errno.ErrPluginRecordNotFound)
|
||||
}
|
||||
|
||||
doc, err := updatePluginOpenapiDoc(ctx, oldPlugin.OpenapiDoc, req)
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "updatePluginOpenapiDoc failed")
|
||||
}
|
||||
mf, err := updatePluginManifest(ctx, oldPlugin.Manifest, req)
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "updatePluginManifest failed")
|
||||
}
|
||||
|
||||
newPlugin := entity.NewPluginInfo(&model.PluginInfo{
|
||||
ID: req.PluginID,
|
||||
IconURI: ptr.Of(mf.LogoURL),
|
||||
ServerURL: req.URL,
|
||||
Manifest: mf,
|
||||
OpenapiDoc: doc,
|
||||
})
|
||||
|
||||
if newPlugin.GetServerURL() == "" ||
|
||||
oldPlugin.GetServerURL() == newPlugin.GetServerURL() {
|
||||
err = p.pluginRepo.UpdateDraftPluginWithoutURLChanged(ctx, newPlugin)
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "UpdateDraftPluginWithoutURLChanged failed, pluginID=%d", req.PluginID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err = p.pluginRepo.UpdateDraftPlugin(ctx, newPlugin)
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "UpdateDraftPlugin failed, pluginID=%d", req.PluginID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func updatePluginOpenapiDoc(_ context.Context, doc *model.Openapi3T, req *UpdateDraftPluginRequest) (*model.Openapi3T, error) {
|
||||
if req.Name != nil {
|
||||
doc.Info.Title = *req.Name
|
||||
}
|
||||
|
||||
if req.Desc != nil {
|
||||
doc.Info.Description = *req.Desc
|
||||
}
|
||||
|
||||
if req.URL != nil {
|
||||
hasServer := false
|
||||
for _, svr := range doc.Servers {
|
||||
if svr.URL == *req.URL {
|
||||
hasServer = true
|
||||
}
|
||||
}
|
||||
if !hasServer {
|
||||
doc.Servers = openapi3.Servers{{URL: *req.URL}}
|
||||
}
|
||||
}
|
||||
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
func updatePluginManifest(_ context.Context, mf *entity.PluginManifest, req *UpdateDraftPluginRequest) (*entity.PluginManifest, error) {
|
||||
if req.Name != nil {
|
||||
mf.NameForHuman = *req.Name
|
||||
mf.NameForModel = *req.Name
|
||||
}
|
||||
|
||||
if req.Desc != nil {
|
||||
mf.DescriptionForHuman = *req.Desc
|
||||
mf.DescriptionForModel = *req.Desc
|
||||
}
|
||||
|
||||
if req.Icon != nil {
|
||||
mf.LogoURL = req.Icon.URI
|
||||
}
|
||||
|
||||
if len(req.CommonParams) > 0 {
|
||||
if mf.CommonParams == nil {
|
||||
mf.CommonParams = make(map[model.HTTPParamLocation][]*plugin_develop_common.CommonParamSchema, len(req.CommonParams))
|
||||
}
|
||||
for loc, params := range req.CommonParams {
|
||||
location, ok := model.ToHTTPParamLocation(loc)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid location '%s'", loc.String())
|
||||
}
|
||||
commonParams := make([]*plugin_develop_common.CommonParamSchema, 0, len(params))
|
||||
for _, param := range params {
|
||||
commonParams = append(commonParams, &plugin_develop_common.CommonParamSchema{
|
||||
Name: param.Name,
|
||||
Value: param.Value,
|
||||
})
|
||||
}
|
||||
mf.CommonParams[location] = commonParams
|
||||
}
|
||||
}
|
||||
|
||||
if req.AuthInfo != nil {
|
||||
authV2, err := req.AuthInfo.toAuthV2()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mf.Auth = authV2
|
||||
}
|
||||
|
||||
return mf, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) DeleteDraftPlugin(ctx context.Context, pluginID int64) (err error) {
|
||||
return p.pluginRepo.DeleteDraftPlugin(ctx, pluginID)
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) MGetDraftTools(ctx context.Context, toolIDs []int64) (tools []*entity.ToolInfo, err error) {
|
||||
tools, err = p.toolRepo.MGetDraftTools(ctx, toolIDs)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "MGetDraftTools failed, toolIDs=%v", toolIDs)
|
||||
}
|
||||
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) UpdateDraftTool(ctx context.Context, req *UpdateToolDraftRequest) (err error) {
|
||||
draftPlugin, exist, err := p.pluginRepo.GetDraftPlugin(ctx, req.PluginID)
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "GetDraftPlugin failed, pluginID=%d", req.PluginID)
|
||||
}
|
||||
if !exist {
|
||||
return errorx.New(errno.ErrPluginRecordNotFound)
|
||||
}
|
||||
|
||||
draftTool, exist, err := p.toolRepo.GetDraftTool(ctx, req.ToolID)
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "GetDraftTool failed, toolID=%d", req.ToolID)
|
||||
}
|
||||
if !exist {
|
||||
return errorx.New(errno.ErrPluginRecordNotFound)
|
||||
}
|
||||
|
||||
if req.Method != nil && req.SubURL != nil {
|
||||
api := entity.UniqueToolAPI{
|
||||
SubURL: ptr.FromOrDefault(req.SubURL, ""),
|
||||
Method: ptr.FromOrDefault(req.Method, ""),
|
||||
}
|
||||
existTool, exist, err := p.toolRepo.GetDraftToolWithAPI(ctx, draftTool.PluginID, api)
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "GetDraftToolWithAPI failed, pluginID=%d, api=%v", draftTool.PluginID, api)
|
||||
}
|
||||
if exist && draftTool.ID != existTool.ID {
|
||||
return errorx.New(errno.ErrPluginDuplicatedTool, errorx.KVf(errno.PluginMsgKey, "[%s]:%s", api.Method, api.SubURL))
|
||||
}
|
||||
}
|
||||
|
||||
var activatedStatus *model.ActivatedStatus
|
||||
if req.Disabled != nil {
|
||||
if *req.Disabled {
|
||||
activatedStatus = ptr.Of(model.DeactivateTool)
|
||||
} else {
|
||||
activatedStatus = ptr.Of(model.ActivateTool)
|
||||
}
|
||||
}
|
||||
|
||||
debugStatus := draftTool.DebugStatus
|
||||
if req.Method != nil ||
|
||||
req.SubURL != nil ||
|
||||
req.Parameters != nil ||
|
||||
req.RequestBody != nil ||
|
||||
req.Responses != nil {
|
||||
debugStatus = ptr.Of(common.APIDebugStatus_DebugWaiting)
|
||||
}
|
||||
|
||||
op := draftTool.Operation
|
||||
if req.Name != nil {
|
||||
op.OperationID = *req.Name
|
||||
}
|
||||
if req.Desc != nil {
|
||||
op.Summary = *req.Desc
|
||||
}
|
||||
if req.Parameters != nil {
|
||||
op.Parameters = req.Parameters
|
||||
}
|
||||
if req.APIExtend != nil {
|
||||
if op.Extensions == nil {
|
||||
op.Extensions = map[string]any{}
|
||||
}
|
||||
authMode, ok := model.ToAPIAuthMode(req.APIExtend.AuthMode)
|
||||
if ok {
|
||||
op.Extensions[model.APISchemaExtendAuthMode] = authMode
|
||||
}
|
||||
}
|
||||
|
||||
if req.RequestBody == nil {
|
||||
op.RequestBody = draftTool.Operation.RequestBody
|
||||
} else {
|
||||
mType, ok := req.RequestBody.Value.Content[model.MediaTypeJson]
|
||||
if !ok {
|
||||
return fmt.Errorf("the '%s' media type is not defined in request body", model.MediaTypeJson)
|
||||
}
|
||||
if op.RequestBody == nil || op.RequestBody.Value == nil || op.RequestBody.Value.Content == nil {
|
||||
op.RequestBody = &openapi3.RequestBodyRef{
|
||||
Value: &openapi3.RequestBody{
|
||||
Content: map[string]*openapi3.MediaType{},
|
||||
},
|
||||
}
|
||||
}
|
||||
op.RequestBody.Value.Content[model.MediaTypeJson] = mType
|
||||
}
|
||||
|
||||
if req.Responses == nil {
|
||||
op.Responses = draftTool.Operation.Responses
|
||||
} else {
|
||||
newRespRef, ok := req.Responses[strconv.Itoa(http.StatusOK)]
|
||||
if !ok {
|
||||
return fmt.Errorf("the '%d' status code is not defined in responses", http.StatusOK)
|
||||
}
|
||||
newMIMEType, ok := newRespRef.Value.Content[model.MediaTypeJson]
|
||||
if !ok {
|
||||
return fmt.Errorf("the '%s' media type is not defined in responses", model.MediaTypeJson)
|
||||
}
|
||||
|
||||
if op.Responses == nil {
|
||||
op.Responses = map[string]*openapi3.ResponseRef{}
|
||||
}
|
||||
|
||||
oldRespRef, ok := op.Responses[strconv.Itoa(http.StatusOK)]
|
||||
if !ok {
|
||||
oldRespRef = &openapi3.ResponseRef{
|
||||
Value: &openapi3.Response{
|
||||
Content: map[string]*openapi3.MediaType{},
|
||||
},
|
||||
}
|
||||
op.Responses[strconv.Itoa(http.StatusOK)] = oldRespRef
|
||||
}
|
||||
|
||||
if oldRespRef.Value.Content == nil {
|
||||
oldRespRef.Value.Content = map[string]*openapi3.MediaType{}
|
||||
}
|
||||
|
||||
oldRespRef.Value.Content[model.MediaTypeJson] = newMIMEType
|
||||
}
|
||||
|
||||
updatedTool := &entity.ToolInfo{
|
||||
ID: req.ToolID,
|
||||
PluginID: req.PluginID,
|
||||
ActivatedStatus: activatedStatus,
|
||||
DebugStatus: debugStatus,
|
||||
Method: req.Method,
|
||||
SubURL: req.SubURL,
|
||||
Operation: op,
|
||||
}
|
||||
|
||||
components := draftPlugin.OpenapiDoc.Components
|
||||
if req.SaveExample != nil && !*req.SaveExample &&
|
||||
components != nil && components.Examples != nil {
|
||||
delete(components.Examples, draftTool.Operation.OperationID)
|
||||
} else if req.DebugExample != nil {
|
||||
if components == nil {
|
||||
components = &openapi3.Components{}
|
||||
}
|
||||
if components.Examples == nil {
|
||||
components.Examples = make(map[string]*openapi3.ExampleRef)
|
||||
}
|
||||
|
||||
draftPlugin.OpenapiDoc.Components = components
|
||||
|
||||
reqExample, respExample := map[string]any{}, map[string]any{}
|
||||
if req.DebugExample.ReqExample != "" {
|
||||
err = sonic.UnmarshalString(req.DebugExample.ReqExample, &reqExample)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey, "invalid request example"))
|
||||
}
|
||||
}
|
||||
if req.DebugExample.RespExample != "" {
|
||||
err = sonic.UnmarshalString(req.DebugExample.RespExample, &respExample)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey, "invalid response example"))
|
||||
}
|
||||
}
|
||||
|
||||
components.Examples[draftTool.Operation.OperationID] = &openapi3.ExampleRef{
|
||||
Value: &openapi3.Example{
|
||||
Value: map[string]any{
|
||||
"ReqExample": reqExample,
|
||||
"RespExample": respExample,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
err = p.toolRepo.UpdateDraftToolAndDebugExample(ctx, draftPlugin.ID, draftPlugin.OpenapiDoc, updatedTool)
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "UpdateDraftToolAndDebugExample failed, pluginID=%d, toolID=%d", draftPlugin.ID, req.ToolID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) ConvertToOpenapi3Doc(ctx context.Context, req *ConvertToOpenapi3DocRequest) (resp *ConvertToOpenapi3DocResponse) {
|
||||
var err error
|
||||
defer func() {
|
||||
if err != nil {
|
||||
logs.Errorf("ConvertToOpenapi3Doc failed, err=%s", err)
|
||||
|
||||
resp.ErrMsg = "internal server error"
|
||||
|
||||
var e errorx.StatusError
|
||||
if errors.As(err, &e) {
|
||||
resp.ErrMsg = e.Msg()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
resp = &ConvertToOpenapi3DocResponse{}
|
||||
|
||||
cvt, format, err := getConvertFunc(ctx, req.RawInput)
|
||||
if err != nil {
|
||||
resp.Format = format
|
||||
return resp
|
||||
}
|
||||
|
||||
doc, mf, err := cvt(ctx, req.RawInput)
|
||||
if err != nil {
|
||||
resp.Format = format
|
||||
return resp
|
||||
}
|
||||
|
||||
err = validateConvertResult(ctx, req, doc, mf)
|
||||
if err != nil {
|
||||
resp.Format = format
|
||||
return resp
|
||||
}
|
||||
|
||||
return &ConvertToOpenapi3DocResponse{
|
||||
OpenapiDoc: doc,
|
||||
Manifest: mf,
|
||||
Format: format,
|
||||
ErrMsg: "",
|
||||
}
|
||||
}
|
||||
|
||||
type convertFunc func(ctx context.Context, rawInput string) (*model.Openapi3T, *entity.PluginManifest, error)
|
||||
|
||||
func getConvertFunc(ctx context.Context, rawInput string) (convertFunc, common.PluginDataFormat, error) {
|
||||
if strings.HasPrefix(rawInput, "curl") {
|
||||
return openapi.CurlToOpenapi3Doc, common.PluginDataFormat_Curl, nil
|
||||
}
|
||||
|
||||
if strings.Contains(rawInput, "_postman_id") { // postman collection
|
||||
return openapi.PostmanToOpenapi3Doc, common.PluginDataFormat_Postman, nil
|
||||
}
|
||||
|
||||
var vd struct {
|
||||
OpenAPI string `json:"openapi" yaml:"openapi"`
|
||||
Swagger string `json:"swagger" yaml:"swagger"`
|
||||
}
|
||||
|
||||
err := sonic.UnmarshalString(rawInput, &vd)
|
||||
if err != nil {
|
||||
err = yaml.Unmarshal([]byte(rawInput), &vd)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("invalid schema")
|
||||
}
|
||||
}
|
||||
|
||||
if vd.OpenAPI == "3" || strings.HasPrefix(vd.OpenAPI, "3.") {
|
||||
return openapi.ToOpenapi3Doc, common.PluginDataFormat_OpenAPI, nil
|
||||
}
|
||||
|
||||
if vd.Swagger == "2" || strings.HasPrefix(vd.Swagger, "2.") {
|
||||
return openapi.SwaggerToOpenapi3Doc, common.PluginDataFormat_Swagger, nil
|
||||
}
|
||||
|
||||
return nil, 0, fmt.Errorf("invalid schema")
|
||||
}
|
||||
|
||||
func validateConvertResult(ctx context.Context, req *ConvertToOpenapi3DocRequest, doc *model.Openapi3T, mf *entity.PluginManifest) error {
|
||||
if req.PluginServerURL != nil {
|
||||
if doc.Servers[0].URL != *req.PluginServerURL {
|
||||
return errorx.New(errno.ErrPluginConvertProtocolFailed, errorx.KV(errno.PluginMsgKey, "inconsistent API URL prefix"))
|
||||
}
|
||||
}
|
||||
|
||||
err := doc.Validate(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = mf.Validate(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) CreateDraftToolsWithCode(ctx context.Context, req *CreateDraftToolsWithCodeRequest) (resp *CreateDraftToolsWithCodeResponse, err error) {
|
||||
err = req.OpenapiDoc.Validate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toolAPIs := make([]entity.UniqueToolAPI, 0, len(req.OpenapiDoc.Paths))
|
||||
for path, item := range req.OpenapiDoc.Paths {
|
||||
for method := range item.Operations() {
|
||||
toolAPIs = append(toolAPIs, entity.UniqueToolAPI{
|
||||
SubURL: path,
|
||||
Method: method,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
existTools, err := p.toolRepo.MGetDraftToolWithAPI(ctx, req.PluginID, toolAPIs,
|
||||
repository.WithToolID(),
|
||||
repository.WithToolMethod(),
|
||||
repository.WithToolSubURL())
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "MGetDraftToolWithAPI failed, pluginID=%d, apis=%v", req.PluginID, toolAPIs)
|
||||
}
|
||||
|
||||
duplicatedTools := make([]entity.UniqueToolAPI, 0, len(existTools))
|
||||
for _, api := range toolAPIs {
|
||||
if _, exist := existTools[api]; exist {
|
||||
duplicatedTools = append(duplicatedTools, api)
|
||||
}
|
||||
}
|
||||
|
||||
if !req.ConflictAndUpdate && len(duplicatedTools) > 0 {
|
||||
return &CreateDraftToolsWithCodeResponse{
|
||||
DuplicatedTools: duplicatedTools,
|
||||
}, nil
|
||||
}
|
||||
|
||||
tools := make([]*entity.ToolInfo, 0, len(toolAPIs))
|
||||
for path, item := range req.OpenapiDoc.Paths {
|
||||
for method, op := range item.Operations() {
|
||||
tools = append(tools, &entity.ToolInfo{
|
||||
PluginID: req.PluginID,
|
||||
Method: ptr.Of(method),
|
||||
SubURL: ptr.Of(path),
|
||||
ActivatedStatus: ptr.Of(model.ActivateTool),
|
||||
DebugStatus: ptr.Of(common.APIDebugStatus_DebugWaiting),
|
||||
Operation: model.NewOpenapi3Operation(op),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
err = p.toolRepo.UpsertDraftTools(ctx, req.PluginID, tools)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "UpsertDraftTools failed, pluginID=%d", req.PluginID)
|
||||
}
|
||||
|
||||
resp = &CreateDraftToolsWithCodeResponse{}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
512
backend/domain/plugin/service/plugin_oauth.go
Normal file
512
backend/domain/plugin/service/plugin_oauth.go
Normal file
@@ -0,0 +1,512 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/utils"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/taskgroup"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
var (
|
||||
initOnce = sync.Once{}
|
||||
lastActiveInterval = 15 * 24 * time.Hour
|
||||
)
|
||||
|
||||
func (p *pluginServiceImpl) processOAuthAccessToken(ctx context.Context) {
|
||||
const (
|
||||
deleteLimit = 100
|
||||
refreshLimit = 50
|
||||
)
|
||||
|
||||
for {
|
||||
now := time.Now()
|
||||
|
||||
lastActiveAt := now.Add(-lastActiveInterval)
|
||||
err := p.oauthRepo.DeleteInactiveAuthorizationCodeTokens(ctx, lastActiveAt.UnixMilli(), deleteLimit)
|
||||
if err != nil {
|
||||
logs.CtxWarnf(ctx, "DeleteInactiveAuthorizationCodeTokens failed, err=%v", err)
|
||||
}
|
||||
|
||||
err = p.oauthRepo.DeleteExpiredAuthorizationCodeTokens(ctx, now.UnixMilli(), deleteLimit)
|
||||
if err != nil {
|
||||
logs.CtxWarnf(ctx, "DeleteExpiredAuthorizationCodeTokens failed, err=%v", err)
|
||||
}
|
||||
|
||||
refreshTokenList, err := p.oauthRepo.GetAuthorizationCodeRefreshTokens(ctx, now.UnixMilli(), refreshLimit)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "GetAuthorizationCodeRefreshTokens failed, err=%v", err)
|
||||
<-time.After(time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
taskGroups := taskgroup.NewTaskGroup(ctx, 3)
|
||||
expired := make([]int64, 0, len(refreshTokenList))
|
||||
|
||||
for _, info := range refreshTokenList {
|
||||
if info.GetNextTokenRefreshAtMS() == 0 || info.TokenExpiredAtMS == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if info.GetNextTokenRefreshAtMS() > now.UnixMilli() ||
|
||||
info.LastActiveAtMS <= lastActiveAt.UnixMilli() {
|
||||
expired = append(expired, info.RecordID)
|
||||
continue
|
||||
}
|
||||
|
||||
taskGroups.Go(func() error {
|
||||
p.refreshToken(ctx, info)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
_ = taskGroups.Wait()
|
||||
|
||||
if len(expired) > 0 {
|
||||
err = p.oauthRepo.BatchDeleteAuthorizationCodeByIDs(ctx, expired)
|
||||
if err != nil {
|
||||
logs.CtxWarnf(ctx, "BatchDeleteAuthorizationCodeByIDs failed, err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
<-time.After(5 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) refreshToken(ctx context.Context, info *entity.AuthorizationCodeInfo) {
|
||||
config := oauth2.Config{
|
||||
ClientID: info.Config.ClientID,
|
||||
ClientSecret: info.Config.ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
TokenURL: info.Config.AuthorizationURL,
|
||||
},
|
||||
Scopes: strings.Split(info.Config.Scope, " "),
|
||||
}
|
||||
|
||||
token := &oauth2.Token{
|
||||
AccessToken: info.AccessToken,
|
||||
RefreshToken: info.RefreshToken,
|
||||
Expiry: time.UnixMilli(info.TokenExpiredAtMS),
|
||||
}
|
||||
|
||||
source := config.TokenSource(ctx, token)
|
||||
|
||||
var (
|
||||
err error
|
||||
newToken *oauth2.Token
|
||||
)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
newToken, err = source.Token()
|
||||
if err == nil {
|
||||
token = newToken
|
||||
break
|
||||
}
|
||||
<-time.After(time.Second)
|
||||
}
|
||||
if err != nil {
|
||||
logs.CtxInfof(ctx, "refreshToken failed, recordID=%d, err=%v", info.RecordID, err)
|
||||
err = p.oauthRepo.BatchDeleteAuthorizationCodeByIDs(ctx, []int64{info.RecordID})
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "BatchDeleteAuthorizationCodeByIDs failed, recordID=%d, err=%v", info.RecordID, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
var expiredAtMS int64
|
||||
if !token.Expiry.IsZero() && token.Expiry.After(time.Now()) {
|
||||
expiredAtMS = token.Expiry.UnixMilli()
|
||||
}
|
||||
|
||||
err = p.oauthRepo.UpsertAuthorizationCode(ctx, &entity.AuthorizationCodeInfo{
|
||||
Meta: &entity.AuthorizationCodeMeta{
|
||||
UserID: info.Meta.UserID,
|
||||
PluginID: info.Meta.PluginID,
|
||||
IsDraft: info.Meta.IsDraft,
|
||||
},
|
||||
Config: info.Config,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
TokenExpiredAtMS: expiredAtMS,
|
||||
NextTokenRefreshAtMS: ptr.Of(getNextTokenRefreshAtMS(expiredAtMS)),
|
||||
})
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
<-time.After(time.Second)
|
||||
}
|
||||
if err != nil {
|
||||
logs.CtxInfof(ctx, "UpsertAuthorizationCode failed, recordID=%d, err=%v", info.RecordID, err)
|
||||
err = p.oauthRepo.BatchDeleteAuthorizationCodeByIDs(ctx, []int64{info.RecordID})
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "BatchDeleteAuthorizationCodeByIDs failed, recordID=%d, err=%v", info.RecordID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) GetAccessToken(ctx context.Context, oa *entity.OAuthInfo) (accessToken string, err error) {
|
||||
switch oa.OAuthMode {
|
||||
case model.AuthzSubTypeOfOAuthAuthorizationCode:
|
||||
accessToken, err = p.getAccessTokenByAuthorizationCode(ctx, oa.AuthorizationCode)
|
||||
default:
|
||||
return "", fmt.Errorf("invalid oauth mode '%s'", oa.OAuthMode)
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) getAccessTokenByAuthorizationCode(ctx context.Context, ci *entity.AuthorizationCodeInfo) (accessToken string, err error) {
|
||||
meta := ci.Meta
|
||||
info, exist, err := p.oauthRepo.GetAuthorizationCode(ctx, ci.Meta)
|
||||
if err != nil {
|
||||
return "", errorx.Wrapf(err, "GetAuthorizationCode failed, userID=%s, pluginID=%d, isDraft=%p",
|
||||
meta.UserID, meta.PluginID, meta.IsDraft)
|
||||
}
|
||||
if !exist {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if !isValidAuthCodeConfig(info.Config, ci.Config, info.TokenExpiredAtMS, info.LastActiveAtMS) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
if now-info.LastActiveAtMS > time.Minute.Milliseconds() { // don't update too frequently
|
||||
err = p.oauthRepo.UpdateAuthorizationCodeLastActiveAt(ctx, meta, now)
|
||||
if err != nil {
|
||||
logs.CtxWarnf(ctx, "UpdateAuthorizationCodeLastActiveAt failed, userID=%s, pluginID=%d, isDraft=%t, err=%v",
|
||||
meta.UserID, meta.PluginID, meta.IsDraft, err)
|
||||
}
|
||||
}
|
||||
|
||||
return info.AccessToken, nil
|
||||
}
|
||||
|
||||
func isValidAuthCodeConfig(o, n *model.OAuthAuthorizationCodeConfig, expireAt, lastActiveAt int64) bool {
|
||||
now := time.Now()
|
||||
|
||||
if expireAt > 0 && expireAt <= now.UnixMilli() {
|
||||
return false
|
||||
}
|
||||
if lastActiveAt > 0 && lastActiveAt <= now.Add(-lastActiveInterval).UnixMilli() {
|
||||
return false
|
||||
}
|
||||
|
||||
if o.ClientID != n.ClientID {
|
||||
return false
|
||||
}
|
||||
if o.ClientSecret != n.ClientSecret {
|
||||
return false
|
||||
}
|
||||
if o.ClientURL != n.ClientURL {
|
||||
return false
|
||||
}
|
||||
if o.AuthorizationURL != n.AuthorizationURL {
|
||||
return false
|
||||
}
|
||||
if o.AuthorizationContentType != n.AuthorizationContentType {
|
||||
return false
|
||||
}
|
||||
|
||||
oldScope := strings.Split(o.Scope, " ")
|
||||
newScope := strings.Split(n.Scope, " ")
|
||||
|
||||
if len(oldScope) != len(newScope) {
|
||||
return false
|
||||
}
|
||||
|
||||
m := make(map[string]bool, len(oldScope))
|
||||
for _, v := range oldScope {
|
||||
m[v] = false
|
||||
}
|
||||
for _, v := range newScope {
|
||||
if _, ok := m[v]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) OAuthCode(ctx context.Context, code string, state *entity.OAuthState) (err error) {
|
||||
var plugin *entity.PluginInfo
|
||||
if state.IsDraft {
|
||||
plugin, err = p.GetDraftPlugin(ctx, state.PluginID)
|
||||
} else {
|
||||
plugin, err = p.GetOnlinePlugin(ctx, state.PluginID)
|
||||
}
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "GetPlugin failed, pluginID=%d", state.PluginID)
|
||||
}
|
||||
|
||||
authInfo := plugin.GetAuthInfo()
|
||||
if authInfo.SubType != model.AuthzSubTypeOfOAuthAuthorizationCode {
|
||||
return errorx.New(errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "plugin auth type is not oauth authorization code"))
|
||||
}
|
||||
if authInfo.AuthOfOAuthAuthorizationCode == nil {
|
||||
return errorx.New(errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "plugin auth info is nil"))
|
||||
}
|
||||
|
||||
config := getStanderOAuthConfig(authInfo.AuthOfOAuthAuthorizationCode)
|
||||
|
||||
token, err := config.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "exchange token failed"))
|
||||
}
|
||||
|
||||
meta := &entity.AuthorizationCodeMeta{
|
||||
UserID: state.UserID,
|
||||
PluginID: state.PluginID,
|
||||
IsDraft: state.IsDraft,
|
||||
}
|
||||
|
||||
var expiredAtMS int64
|
||||
if !token.Expiry.IsZero() && token.Expiry.After(time.Now()) {
|
||||
expiredAtMS = token.Expiry.UnixMilli()
|
||||
}
|
||||
|
||||
err = p.saveAccessToken(ctx, &entity.OAuthInfo{
|
||||
OAuthMode: model.AuthzSubTypeOfOAuthAuthorizationCode,
|
||||
AuthorizationCode: &entity.AuthorizationCodeInfo{
|
||||
Meta: meta,
|
||||
Config: authInfo.AuthOfOAuthAuthorizationCode,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
TokenExpiredAtMS: expiredAtMS,
|
||||
NextTokenRefreshAtMS: ptr.Of(getNextTokenRefreshAtMS(expiredAtMS)),
|
||||
LastActiveAtMS: time.Now().UnixMilli(),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "SaveAccessToken failed, pluginID=%d", state.PluginID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) saveAccessToken(ctx context.Context, oa *entity.OAuthInfo) (err error) {
|
||||
switch oa.OAuthMode {
|
||||
case model.AuthzSubTypeOfOAuthAuthorizationCode:
|
||||
err = p.saveAuthCodeAccessToken(ctx, oa.AuthorizationCode)
|
||||
default:
|
||||
return fmt.Errorf("[standardOAuth] invalid oauth mode '%s'", oa.OAuthMode)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) saveAuthCodeAccessToken(ctx context.Context, info *entity.AuthorizationCodeInfo) (err error) {
|
||||
meta := info.Meta
|
||||
err = p.oauthRepo.UpsertAuthorizationCode(ctx, info)
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "SaveAuthorizationCodeInfo failed, userID=%s, pluginID=%d, isDraft=%t",
|
||||
meta.UserID, meta.PluginID, meta.IsDraft)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getNextTokenRefreshAtMS(expiredAtMS int64) int64 {
|
||||
if expiredAtMS == 0 {
|
||||
return 0
|
||||
}
|
||||
return time.Now().Add(time.Duration((expiredAtMS-time.Now().UnixMilli())/2) * time.Millisecond).UnixMilli()
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) RevokeAccessToken(ctx context.Context, meta *entity.AuthorizationCodeMeta) (err error) {
|
||||
return p.oauthRepo.DeleteAuthorizationCode(ctx, meta)
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) GetOAuthStatus(ctx context.Context, userID, pluginID int64) (resp *GetOAuthStatusResponse, err error) {
|
||||
pl, exist, err := p.pluginRepo.GetDraftPlugin(ctx, pluginID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !exist {
|
||||
return nil, fmt.Errorf("draft plugin '%d' not found", pluginID)
|
||||
}
|
||||
|
||||
authInfo := pl.GetAuthInfo()
|
||||
if authInfo.Type == model.AuthzTypeOfNone || authInfo.Type == model.AuthzTypeOfService {
|
||||
return &GetOAuthStatusResponse{
|
||||
IsOauth: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
needAuth, authURL, err := p.getPluginOAuthStatus(ctx, userID, pl, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
status := common.OAuthStatus_Authorized
|
||||
if needAuth {
|
||||
status = common.OAuthStatus_Unauthorized
|
||||
}
|
||||
|
||||
resp = &GetOAuthStatusResponse{
|
||||
IsOauth: true,
|
||||
Status: status,
|
||||
OAuthURL: authURL,
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) getPluginOAuthStatus(ctx context.Context, userID int64, plugin *entity.PluginInfo, isDraft bool) (needAuth bool, authURL string, err error) {
|
||||
authInfo := plugin.GetAuthInfo()
|
||||
|
||||
if authInfo.Type != model.AuthzTypeOfOAuth {
|
||||
return false, "", fmt.Errorf("invalid auth type '%v'", authInfo.Type)
|
||||
}
|
||||
if authInfo.SubType != model.AuthzSubTypeOfOAuthAuthorizationCode {
|
||||
return false, "", fmt.Errorf("invalid auth sub type '%v'", authInfo.SubType)
|
||||
}
|
||||
|
||||
authCode := &entity.AuthorizationCodeInfo{
|
||||
Meta: &entity.AuthorizationCodeMeta{
|
||||
UserID: conv.Int64ToStr(userID),
|
||||
PluginID: plugin.ID,
|
||||
IsDraft: isDraft,
|
||||
},
|
||||
Config: plugin.Manifest.Auth.AuthOfOAuthAuthorizationCode,
|
||||
}
|
||||
|
||||
accessToken, err := p.GetAccessToken(ctx, &entity.OAuthInfo{
|
||||
OAuthMode: model.AuthzSubTypeOfOAuthAuthorizationCode,
|
||||
AuthorizationCode: authCode,
|
||||
})
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
needAuth = accessToken == ""
|
||||
|
||||
authURL, err = genAuthURL(authCode)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
return needAuth, authURL, nil
|
||||
}
|
||||
|
||||
func genAuthURL(info *entity.AuthorizationCodeInfo) (string, error) {
|
||||
config := getStanderOAuthConfig(info.Config)
|
||||
|
||||
state := &entity.OAuthState{
|
||||
ClientName: "",
|
||||
UserID: info.Meta.UserID,
|
||||
PluginID: info.Meta.PluginID,
|
||||
IsDraft: info.Meta.IsDraft,
|
||||
}
|
||||
stateStr, err := json.Marshal(state)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal state failed, err=%v", err)
|
||||
}
|
||||
encryptState, err := utils.EncryptByAES(stateStr, utils.StateSecretKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("encrypt state failed, err=%v", err)
|
||||
}
|
||||
|
||||
authURL := config.AuthCodeURL(encryptState)
|
||||
|
||||
return authURL, nil
|
||||
}
|
||||
|
||||
func getStanderOAuthConfig(config *model.OAuthAuthorizationCodeConfig) *oauth2.Config {
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
return &oauth2.Config{
|
||||
ClientID: config.ClientID,
|
||||
ClientSecret: config.ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
TokenURL: config.AuthorizationURL,
|
||||
AuthURL: config.ClientURL,
|
||||
},
|
||||
RedirectURL: fmt.Sprintf("https://%s/api/oauth/authorization_code", os.Getenv("SERVER_HOST")),
|
||||
Scopes: strings.Split(config.Scope, " "),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) GetAgentPluginsOAuthStatus(ctx context.Context, userID, agentID int64) (status []*AgentPluginOAuthStatus, err error) {
|
||||
pluginIDs, err := p.toolRepo.GetAgentPluginIDs(ctx, agentID)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "GetAgentPluginIDs failed, agentID=%d", agentID)
|
||||
}
|
||||
|
||||
if len(pluginIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
plugins, err := p.pluginRepo.MGetOnlinePlugins(ctx, pluginIDs)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "MGetOnlinePlugins failed, pluginIDs=%v", pluginIDs)
|
||||
}
|
||||
|
||||
for _, plugin := range plugins {
|
||||
authInfo := plugin.GetAuthInfo()
|
||||
if authInfo.Type == model.AuthzTypeOfNone || authInfo.Type == model.AuthzTypeOfService {
|
||||
continue
|
||||
}
|
||||
|
||||
needAuth, _, err := p.getPluginOAuthStatus(ctx, userID, plugin, false)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "getPluginOAuthStatus failed, pluginID=%d, err=%v", plugin.ID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
iconURL := ""
|
||||
if plugin.GetIconURI() != "" {
|
||||
iconURL, _ = p.oss.GetObjectUrl(ctx, plugin.GetIconURI())
|
||||
}
|
||||
|
||||
authStatus := common.OAuthStatus_Authorized
|
||||
if needAuth {
|
||||
authStatus = common.OAuthStatus_Unauthorized
|
||||
}
|
||||
|
||||
status = append(status, &AgentPluginOAuthStatus{
|
||||
PluginID: plugin.ID,
|
||||
PluginName: plugin.GetName(),
|
||||
PluginIconURL: iconURL,
|
||||
Status: authStatus,
|
||||
})
|
||||
}
|
||||
|
||||
return status, nil
|
||||
}
|
||||
382
backend/domain/plugin/service/plugin_online.go
Normal file
382
backend/domain/plugin/service/plugin_online.go
Normal file
@@ -0,0 +1,382 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
searchModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/search"
|
||||
pluginCommon "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
||||
resCommon "github.com/coze-dev/coze-studio/backend/api/model/resource/common"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crosssearch"
|
||||
pluginConf "github.com/coze-dev/coze-studio/backend/domain/plugin/conf"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/repository"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
func (p *pluginServiceImpl) GetOnlinePlugin(ctx context.Context, pluginID int64) (plugin *entity.PluginInfo, err error) {
|
||||
pl, exist, err := p.pluginRepo.GetOnlinePlugin(ctx, pluginID)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "GetOnlinePlugin failed, pluginID=%d", pluginID)
|
||||
}
|
||||
if !exist {
|
||||
return nil, errorx.New(errno.ErrPluginRecordNotFound)
|
||||
}
|
||||
|
||||
return pl, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) MGetOnlinePlugins(ctx context.Context, pluginIDs []int64) (plugins []*entity.PluginInfo, err error) {
|
||||
plugins, err = p.pluginRepo.MGetOnlinePlugins(ctx, pluginIDs)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "MGetOnlinePlugins failed, pluginIDs=%v", pluginIDs)
|
||||
}
|
||||
|
||||
res := make([]*model.PluginInfo, 0, len(plugins))
|
||||
for _, pl := range plugins {
|
||||
res = append(res, pl.PluginInfo)
|
||||
}
|
||||
|
||||
return plugins, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) GetOnlineTool(ctx context.Context, toolID int64) (tool *entity.ToolInfo, err error) {
|
||||
tool, exist, err := p.toolRepo.GetOnlineTool(ctx, toolID)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "GetOnlineTool failed, toolID=%d", toolID)
|
||||
}
|
||||
if !exist {
|
||||
return nil, errorx.New(errno.ErrPluginRecordNotFound)
|
||||
}
|
||||
|
||||
return tool, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) MGetOnlineTools(ctx context.Context, toolIDs []int64) (tools []*entity.ToolInfo, err error) {
|
||||
tools, err = p.toolRepo.MGetOnlineTools(ctx, toolIDs)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "MGetOnlineTools failed, toolIDs=%v", toolIDs)
|
||||
}
|
||||
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) MGetVersionTools(ctx context.Context, versionTools []entity.VersionTool) (tools []*entity.ToolInfo, err error) {
|
||||
tools, err = p.toolRepo.MGetVersionTools(ctx, versionTools)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "MGetVersionTools failed, versionTools=%v", versionTools)
|
||||
}
|
||||
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) ListPluginProducts(ctx context.Context, req *ListPluginProductsRequest) (resp *ListPluginProductsResponse, err error) {
|
||||
plugins := slices.Transform(pluginConf.GetAllPluginProducts(), func(p *pluginConf.PluginInfo) *entity.PluginInfo {
|
||||
return entity.NewPluginInfo(p.Info)
|
||||
})
|
||||
sort.Slice(plugins, func(i, j int) bool {
|
||||
return plugins[i].GetRefProductID() < plugins[j].GetRefProductID()
|
||||
})
|
||||
|
||||
return &ListPluginProductsResponse{
|
||||
Plugins: plugins,
|
||||
Total: int64(len(plugins)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) GetPluginProductAllTools(ctx context.Context, pluginID int64) (tools []*entity.ToolInfo, err error) {
|
||||
res, err := p.toolRepo.GetPluginAllOnlineTools(ctx, pluginID)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "GetPluginAllOnlineTools failed, pluginID=%d", pluginID)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) DeleteAPPAllPlugins(ctx context.Context, appID int64) (pluginIDs []int64, err error) {
|
||||
return p.pluginRepo.DeleteAPPAllPlugins(ctx, appID)
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) GetAPPAllPlugins(ctx context.Context, appID int64) (plugins []*entity.PluginInfo, err error) {
|
||||
plugins, err = p.pluginRepo.GetAPPAllDraftPlugins(ctx, appID)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "GetAPPAllDraftPlugins failed, appID=%d", appID)
|
||||
}
|
||||
|
||||
return plugins, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) MGetVersionPlugins(ctx context.Context, versionPlugins []entity.VersionPlugin) (plugins []*entity.PluginInfo, err error) {
|
||||
plugins, err = p.pluginRepo.MGetVersionPlugins(ctx, versionPlugins)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "MGetVersionPlugins failed, versionPlugins=%v", versionPlugins)
|
||||
}
|
||||
|
||||
return plugins, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) ListCustomOnlinePlugins(ctx context.Context, spaceID int64, pageInfo entity.PageInfo) (plugins []*entity.PluginInfo, total int64, err error) {
|
||||
if pageInfo.Name == nil || *pageInfo.Name == "" {
|
||||
plugins, total, err = p.pluginRepo.ListCustomOnlinePlugins(ctx, spaceID, pageInfo)
|
||||
if err != nil {
|
||||
return nil, 0, errorx.Wrapf(err, "ListCustomOnlinePlugins failed, spaceID=%d", spaceID)
|
||||
}
|
||||
return plugins, total, nil
|
||||
}
|
||||
|
||||
res, err := crosssearch.DefaultSVC().SearchResources(ctx, &searchModel.SearchResourcesRequest{
|
||||
SpaceID: spaceID,
|
||||
Name: *pageInfo.Name,
|
||||
OrderAsc: false,
|
||||
ResTypeFilter: []resCommon.ResType{
|
||||
resCommon.ResType_Plugin,
|
||||
},
|
||||
OrderFiledName: func() string {
|
||||
if pageInfo.SortBy == nil || *pageInfo.SortBy != entity.SortByCreatedAt {
|
||||
return searchModel.FieldOfUpdateTime
|
||||
}
|
||||
return searchModel.FieldOfCreateTime
|
||||
}(),
|
||||
Page: ptr.Of(int32(pageInfo.Page)),
|
||||
Limit: int32(pageInfo.Size),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, 0, errorx.Wrapf(err, "SearchResources failed, spaceID=%d", spaceID)
|
||||
}
|
||||
|
||||
plugins = make([]*entity.PluginInfo, 0, len(res.Data))
|
||||
for _, pl := range res.Data {
|
||||
draftPlugin, exist, err := p.pluginRepo.GetOnlinePlugin(ctx, pl.ResID)
|
||||
if err != nil {
|
||||
return nil, 0, errorx.Wrapf(err, "GetOnlinePlugin failed, pluginID=%d", pl.ResID)
|
||||
}
|
||||
if !exist {
|
||||
logs.CtxWarnf(ctx, "online plugin not exist, pluginID=%d", pl.ResID)
|
||||
continue
|
||||
}
|
||||
plugins = append(plugins, draftPlugin)
|
||||
}
|
||||
|
||||
if res.TotalHits != nil {
|
||||
total = *res.TotalHits
|
||||
}
|
||||
|
||||
return plugins, total, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) MGetPluginLatestVersion(ctx context.Context, pluginIDs []int64) (resp *MGetPluginLatestVersionResponse, err error) {
|
||||
plugins, err := p.pluginRepo.MGetOnlinePlugins(ctx, pluginIDs,
|
||||
repository.WithPluginID(),
|
||||
repository.WithPluginVersion())
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "MGetOnlinePlugins failed, pluginIDs=%v", pluginIDs)
|
||||
}
|
||||
|
||||
versions := make(map[int64]string, len(plugins))
|
||||
for _, pl := range plugins {
|
||||
versions[pl.ID] = pl.GetVersion()
|
||||
}
|
||||
|
||||
resp = &MGetPluginLatestVersionResponse{
|
||||
Versions: versions,
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) CopyPlugin(ctx context.Context, req *CopyPluginRequest) (resp *CopyPluginResponse, err error) {
|
||||
err = p.checkCanCopyPlugin(ctx, req.PluginID, req.CopyScene)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
plugin, tools, err := p.getCopySourcePluginAndTools(ctx, req.PluginID, req.CopyScene)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.changePluginAndToolsInfoForCopy(req, plugin, tools)
|
||||
|
||||
toolMap := make(map[int64]*entity.ToolInfo, len(tools))
|
||||
for _, tool := range tools {
|
||||
toolMap[tool.ID] = tool
|
||||
}
|
||||
|
||||
plugin, tools, err = p.pluginRepo.CopyPlugin(ctx, &repository.CopyPluginRequest{
|
||||
Plugin: plugin,
|
||||
Tools: tools,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "CopyPlugin failed, pluginID=%d", req.PluginID)
|
||||
}
|
||||
|
||||
resp = &CopyPluginResponse{
|
||||
Plugin: plugin,
|
||||
Tools: toolMap,
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) changePluginAndToolsInfoForCopy(req *CopyPluginRequest, plugin *entity.PluginInfo, tools []*entity.ToolInfo) {
|
||||
plugin.Version = nil
|
||||
plugin.VersionDesc = nil
|
||||
|
||||
plugin.DeveloperID = req.UserID
|
||||
|
||||
if req.CopyScene != model.CopySceneOfAPPDuplicate {
|
||||
plugin.SetName(fmt.Sprintf("%s_copy", plugin.GetName()))
|
||||
}
|
||||
|
||||
if req.CopyScene == model.CopySceneOfToLibrary {
|
||||
const (
|
||||
defaultVersion = "v0.0.1"
|
||||
defaultVersionDesc = "copy to library"
|
||||
)
|
||||
|
||||
plugin.APPID = nil
|
||||
plugin.Version = ptr.Of(defaultVersion)
|
||||
plugin.VersionDesc = ptr.Of(defaultVersionDesc)
|
||||
|
||||
for _, tool := range tools {
|
||||
tool.Version = ptr.Of(defaultVersion)
|
||||
}
|
||||
}
|
||||
|
||||
if req.CopyScene == model.CopySceneOfToAPP {
|
||||
plugin.APPID = req.TargetAPPID
|
||||
|
||||
for _, tool := range tools {
|
||||
tool.DebugStatus = ptr.Of(pluginCommon.APIDebugStatus_DebugPassed)
|
||||
}
|
||||
}
|
||||
|
||||
if req.CopyScene == model.CopySceneOfAPPDuplicate {
|
||||
plugin.APPID = req.TargetAPPID
|
||||
}
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) checkCanCopyPlugin(ctx context.Context, pluginID int64, scene model.CopyScene) (err error) {
|
||||
switch scene {
|
||||
case model.CopySceneOfToAPP, model.CopySceneOfDuplicate, model.CopySceneOfAPPDuplicate:
|
||||
return nil
|
||||
case model.CopySceneOfToLibrary:
|
||||
return p.checkToolsDebugStatus(ctx, pluginID)
|
||||
default:
|
||||
return fmt.Errorf("unsupported copy scene '%s'", scene)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) getCopySourcePluginAndTools(ctx context.Context, pluginID int64, scene model.CopyScene) (plugin *entity.PluginInfo, tools []*entity.ToolInfo, err error) {
|
||||
switch scene {
|
||||
case model.CopySceneOfToAPP:
|
||||
return p.getOnlinePluginAndTools(ctx, pluginID)
|
||||
case model.CopySceneOfToLibrary, model.CopySceneOfDuplicate, model.CopySceneOfAPPDuplicate:
|
||||
return p.getDraftPluginAndTools(ctx, pluginID)
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unsupported copy scene '%s'", scene)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) getOnlinePluginAndTools(ctx context.Context, pluginID int64) (plugin *entity.PluginInfo, tools []*entity.ToolInfo, err error) {
|
||||
onlinePlugin, exist, err := p.pluginRepo.GetOnlinePlugin(ctx, pluginID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if !exist {
|
||||
return nil, nil, errorx.New(errno.ErrPluginRecordNotFound)
|
||||
}
|
||||
|
||||
onlineTools, err := p.toolRepo.GetPluginAllOnlineTools(ctx, pluginID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return onlinePlugin, onlineTools, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) getDraftPluginAndTools(ctx context.Context, pluginID int64) (plugin *entity.PluginInfo, tools []*entity.ToolInfo, err error) {
|
||||
draftPlugin, exist, err := p.pluginRepo.GetDraftPlugin(ctx, pluginID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if !exist {
|
||||
return nil, nil, errorx.New(errno.ErrPluginRecordNotFound)
|
||||
}
|
||||
|
||||
draftTools, err := p.toolRepo.GetPluginAllDraftTools(ctx, pluginID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return draftPlugin, draftTools, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) MoveAPPPluginToLibrary(ctx context.Context, pluginID int64) (draftPlugin *entity.PluginInfo, err error) {
|
||||
draftPlugin, exist, err := p.pluginRepo.GetDraftPlugin(ctx, pluginID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !exist {
|
||||
return nil, errorx.New(errno.ErrPluginRecordNotFound)
|
||||
}
|
||||
|
||||
err = p.checkToolsDebugStatus(ctx, pluginID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
draftTools, err := p.toolRepo.GetPluginAllDraftTools(ctx, pluginID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.changePluginAndToolsInfoForMove(draftPlugin, draftTools)
|
||||
|
||||
err = p.pluginRepo.MoveAPPPluginToLibrary(ctx, draftPlugin, draftTools)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "MoveAPPPluginToLibrary failed, pluginID=%d", pluginID)
|
||||
}
|
||||
|
||||
return draftPlugin, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) changePluginAndToolsInfoForMove(plugin *entity.PluginInfo,
|
||||
tools []*entity.ToolInfo) {
|
||||
|
||||
const (
|
||||
defaultVersion = "v0.0.1"
|
||||
defaultVersionDesc = "move to library"
|
||||
)
|
||||
|
||||
plugin.Version = ptr.Of(defaultVersion)
|
||||
plugin.VersionDesc = ptr.Of(defaultVersionDesc)
|
||||
|
||||
for _, tool := range tools {
|
||||
tool.Version = ptr.Of(defaultVersion)
|
||||
}
|
||||
|
||||
plugin.APPID = nil
|
||||
}
|
||||
234
backend/domain/plugin/service/plugin_release.go
Normal file
234
backend/domain/plugin/service/plugin_release.go
Normal file
@@ -0,0 +1,234 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/mod/semver"
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/repository"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
func (p *pluginServiceImpl) GetPluginNextVersion(ctx context.Context, pluginID int64) (version string, err error) {
|
||||
const defaultVersion = "v1.0.0"
|
||||
|
||||
pl, exist, err := p.pluginRepo.GetOnlinePlugin(ctx, pluginID)
|
||||
if err != nil {
|
||||
return "", errorx.Wrapf(err, "GetOnlinePlugin failed, pluginID=%d", pluginID)
|
||||
}
|
||||
if !exist {
|
||||
return defaultVersion, nil
|
||||
}
|
||||
|
||||
parts := strings.Split(pl.GetVersion(), ".") // Remove the 'v' and split
|
||||
if len(parts) < 3 {
|
||||
logs.CtxWarnf(ctx, "invalid version format '%s'", pl.GetVersion())
|
||||
return defaultVersion, nil
|
||||
}
|
||||
|
||||
patch, err := strconv.ParseInt(parts[2], 10, 64)
|
||||
if err != nil {
|
||||
logs.CtxWarnf(ctx, "invalid version format '%s'", pl.GetVersion())
|
||||
return defaultVersion, nil
|
||||
}
|
||||
|
||||
parts[2] = strconv.FormatInt(patch+1, 10)
|
||||
nextVersion := strings.Join(parts, ".")
|
||||
|
||||
return nextVersion, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) PublishPlugin(ctx context.Context, req *PublishPluginRequest) (err error) {
|
||||
draftPlugin, exist, err := p.pluginRepo.GetDraftPlugin(ctx, req.PluginID)
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "GetDraftPlugin failed, pluginID=%d", req.PluginID)
|
||||
}
|
||||
if !exist {
|
||||
return errorx.New(errno.ErrPluginRecordNotFound)
|
||||
}
|
||||
|
||||
err = p.checkToolsDebugStatus(ctx, req.PluginID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
onlinePlugin, exist, err := p.pluginRepo.GetOnlinePlugin(ctx, req.PluginID)
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "GetOnlinePlugin failed, pluginID=%d", req.PluginID)
|
||||
}
|
||||
if exist && onlinePlugin.Version != nil {
|
||||
if semver.Compare(req.Version, *onlinePlugin.Version) != 1 {
|
||||
return errorx.New(errno.ErrPluginInvalidParamCode,
|
||||
errorx.KVf(errno.PluginMsgKey, "version must be greater than the online version '%s' and format like 'v1.0.0'",
|
||||
*onlinePlugin.Version))
|
||||
}
|
||||
}
|
||||
|
||||
draftPlugin.Version = &req.Version
|
||||
draftPlugin.VersionDesc = &req.VersionDesc
|
||||
|
||||
err = p.pluginRepo.PublishPlugin(ctx, draftPlugin)
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "PublishPlugin failed, pluginID=%d", req.PluginID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) PublishAPPPlugins(ctx context.Context, req *PublishAPPPluginsRequest) (resp *PublishAPPPluginsResponse, err error) {
|
||||
resp = &PublishAPPPluginsResponse{}
|
||||
|
||||
draftPlugins, err := p.pluginRepo.GetAPPAllDraftPlugins(ctx, req.APPID)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "GetAPPAllDraftPlugins failed, appID=%d", req.APPID)
|
||||
}
|
||||
|
||||
failedPluginIDs, err := p.checkCanPublishAPPPlugins(ctx, req.Version, draftPlugins)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "checkCanPublishAPPPlugins failed, appID=%d, appVerion=%s", req.APPID, req.Version)
|
||||
}
|
||||
|
||||
for _, draftPlugin := range draftPlugins {
|
||||
draftPlugin.Version = &req.Version
|
||||
draftPlugin.VersionDesc = ptr.Of(fmt.Sprintf("publish %s", req.Version))
|
||||
resp.AllDraftPlugins = append(resp.AllDraftPlugins, draftPlugin.PluginInfo)
|
||||
}
|
||||
|
||||
if len(failedPluginIDs) > 0 {
|
||||
draftPluginMap := slices.ToMap(draftPlugins, func(plugin *entity.PluginInfo) (int64, *entity.PluginInfo) {
|
||||
return plugin.ID, plugin
|
||||
})
|
||||
|
||||
failedPlugins := make([]*entity.PluginInfo, 0, len(failedPluginIDs))
|
||||
for _, failedPluginID := range failedPluginIDs {
|
||||
failedPlugins = append(failedPlugins, draftPluginMap[failedPluginID])
|
||||
}
|
||||
for _, failedPlugin := range failedPlugins {
|
||||
resp.FailedPlugins = append(resp.FailedPlugins, failedPlugin.PluginInfo)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
err = p.pluginRepo.PublishPlugins(ctx, draftPlugins)
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "PublishPlugins failed, appID=%d", req.APPID)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) checkCanPublishAPPPlugins(ctx context.Context, version string, draftPlugins []*entity.PluginInfo) (failedPluginIDs []int64, err error) {
|
||||
failedPluginIDs = make([]int64, 0, len(draftPlugins))
|
||||
|
||||
draftPluginIDs := slices.Transform(draftPlugins, func(plugin *entity.PluginInfo) int64 {
|
||||
return plugin.ID
|
||||
})
|
||||
|
||||
// 1. check version
|
||||
onlinePlugins, err := p.pluginRepo.MGetOnlinePlugins(ctx, draftPluginIDs,
|
||||
repository.WithPluginID(),
|
||||
repository.WithPluginVersion())
|
||||
if err != nil {
|
||||
return nil, errorx.Wrapf(err, "MGetOnlinePlugins failed, pluginIDs=%v", draftPluginIDs)
|
||||
}
|
||||
|
||||
if len(onlinePlugins) > 0 {
|
||||
for _, onlinePlugin := range onlinePlugins {
|
||||
if onlinePlugin.Version == nil {
|
||||
continue
|
||||
}
|
||||
if semver.Compare(version, *onlinePlugin.Version) != 1 {
|
||||
failedPluginIDs = append(failedPluginIDs, onlinePlugin.ID)
|
||||
}
|
||||
}
|
||||
if len(failedPluginIDs) > 0 {
|
||||
logs.CtxErrorf(ctx, "invalid version of plugins '%v'", failedPluginIDs)
|
||||
return failedPluginIDs, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2. check debug status
|
||||
for _, draftPlugin := range draftPlugins {
|
||||
err = p.checkToolsDebugStatus(ctx, draftPlugin.ID)
|
||||
if err != nil {
|
||||
failedPluginIDs = append(failedPluginIDs, draftPlugin.ID)
|
||||
logs.CtxErrorf(ctx, "checkToolsDebugStatus failed, pluginID=%d, err=%s", draftPlugin.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(failedPluginIDs) > 0 {
|
||||
return failedPluginIDs, nil
|
||||
}
|
||||
|
||||
return failedPluginIDs, nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) checkToolsDebugStatus(ctx context.Context, pluginID int64) (err error) {
|
||||
res, err := p.toolRepo.GetPluginAllDraftTools(ctx, pluginID,
|
||||
repository.WithToolID(),
|
||||
repository.WithToolDebugStatus(),
|
||||
repository.WithToolActivatedStatus(),
|
||||
)
|
||||
if err != nil {
|
||||
return errorx.Wrapf(err, "GetPluginAllDraftTools failed, pluginID=%d", pluginID)
|
||||
}
|
||||
|
||||
if len(res) == 0 {
|
||||
return errorx.New(errno.ErrPluginToolsCheckFailed, errorx.KVf(errno.PluginMsgKey,
|
||||
"at least one activated tool is required in plugin"))
|
||||
}
|
||||
|
||||
activatedTools := make([]*entity.ToolInfo, 0, len(res))
|
||||
for _, tool := range res {
|
||||
if tool.GetActivatedStatus() == model.DeactivateTool {
|
||||
continue
|
||||
}
|
||||
activatedTools = append(activatedTools, tool)
|
||||
}
|
||||
|
||||
if len(activatedTools) == 0 {
|
||||
return errorx.New(errno.ErrPluginToolsCheckFailed, errorx.KVf(errno.PluginMsgKey,
|
||||
"at least one activated tool is required in plugin"))
|
||||
}
|
||||
|
||||
for _, tool := range activatedTools {
|
||||
if tool.GetDebugStatus() != common.APIDebugStatus_DebugWaiting {
|
||||
continue
|
||||
}
|
||||
return errorx.New(errno.ErrPluginToolsCheckFailed, errorx.KVf(errno.PluginMsgKey,
|
||||
"tools in plugin have not debugged yet"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *pluginServiceImpl) CheckPluginToolsDebugStatus(ctx context.Context, pluginID int64) (err error) {
|
||||
return p.checkToolsDebugStatus(ctx, pluginID)
|
||||
}
|
||||
400
backend/domain/plugin/service/service.go
Normal file
400
backend/domain/plugin/service/service.go
Normal file
@@ -0,0 +1,400 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination ../../../internal/mock/domain/plugin/interface.go --package mockPlugin -source service.go
|
||||
type PluginService interface {
|
||||
// Draft Plugin
|
||||
CreateDraftPlugin(ctx context.Context, req *CreateDraftPluginRequest) (pluginID int64, err error)
|
||||
CreateDraftPluginWithCode(ctx context.Context, req *CreateDraftPluginWithCodeRequest) (resp *CreateDraftPluginWithCodeResponse, err error)
|
||||
GetDraftPlugin(ctx context.Context, pluginID int64) (plugin *entity.PluginInfo, err error)
|
||||
MGetDraftPlugins(ctx context.Context, pluginIDs []int64) (plugins []*entity.PluginInfo, err error)
|
||||
ListDraftPlugins(ctx context.Context, req *ListDraftPluginsRequest) (resp *ListDraftPluginsResponse, err error)
|
||||
UpdateDraftPlugin(ctx context.Context, plugin *UpdateDraftPluginRequest) (err error)
|
||||
UpdateDraftPluginWithCode(ctx context.Context, req *UpdateDraftPluginWithCodeRequest) (err error)
|
||||
DeleteDraftPlugin(ctx context.Context, pluginID int64) (err error)
|
||||
DeleteAPPAllPlugins(ctx context.Context, appID int64) (pluginIDs []int64, err error)
|
||||
GetAPPAllPlugins(ctx context.Context, appID int64) (plugins []*entity.PluginInfo, err error)
|
||||
|
||||
// Online Plugin
|
||||
PublishPlugin(ctx context.Context, req *PublishPluginRequest) (err error)
|
||||
PublishAPPPlugins(ctx context.Context, req *PublishAPPPluginsRequest) (resp *PublishAPPPluginsResponse, err error)
|
||||
GetOnlinePlugin(ctx context.Context, pluginID int64) (plugin *entity.PluginInfo, err error)
|
||||
MGetOnlinePlugins(ctx context.Context, pluginIDs []int64) (plugins []*entity.PluginInfo, err error)
|
||||
MGetPluginLatestVersion(ctx context.Context, pluginIDs []int64) (resp *MGetPluginLatestVersionResponse, err error)
|
||||
GetPluginNextVersion(ctx context.Context, pluginID int64) (version string, err error)
|
||||
MGetVersionPlugins(ctx context.Context, versionPlugins []entity.VersionPlugin) (plugins []*entity.PluginInfo, err error)
|
||||
ListCustomOnlinePlugins(ctx context.Context, spaceID int64, pageInfo entity.PageInfo) (plugins []*entity.PluginInfo, total int64, err error)
|
||||
|
||||
// Draft Tool
|
||||
MGetDraftTools(ctx context.Context, toolIDs []int64) (tools []*entity.ToolInfo, err error)
|
||||
UpdateDraftTool(ctx context.Context, req *UpdateToolDraftRequest) (err error)
|
||||
ConvertToOpenapi3Doc(ctx context.Context, req *ConvertToOpenapi3DocRequest) (resp *ConvertToOpenapi3DocResponse)
|
||||
CreateDraftToolsWithCode(ctx context.Context, req *CreateDraftToolsWithCodeRequest) (resp *CreateDraftToolsWithCodeResponse, err error)
|
||||
CheckPluginToolsDebugStatus(ctx context.Context, pluginID int64) (err error)
|
||||
|
||||
// Online Tool
|
||||
GetOnlineTool(ctx context.Context, toolID int64) (tool *entity.ToolInfo, err error)
|
||||
MGetOnlineTools(ctx context.Context, toolIDs []int64) (tools []*entity.ToolInfo, err error)
|
||||
MGetVersionTools(ctx context.Context, versionTools []entity.VersionTool) (tools []*entity.ToolInfo, err error)
|
||||
CopyPlugin(ctx context.Context, req *CopyPluginRequest) (resp *CopyPluginResponse, err error)
|
||||
MoveAPPPluginToLibrary(ctx context.Context, pluginID int64) (plugin *entity.PluginInfo, err error)
|
||||
|
||||
// Agent Tool
|
||||
BindAgentTools(ctx context.Context, agentID int64, toolIDs []int64) (err error)
|
||||
DuplicateDraftAgentTools(ctx context.Context, fromAgentID, toAgentID int64) (err error)
|
||||
GetDraftAgentToolByName(ctx context.Context, agentID int64, toolName string) (tool *entity.ToolInfo, err error)
|
||||
MGetAgentTools(ctx context.Context, req *MGetAgentToolsRequest) (tools []*entity.ToolInfo, err error)
|
||||
UpdateBotDefaultParams(ctx context.Context, req *UpdateBotDefaultParamsRequest) (err error)
|
||||
|
||||
PublishAgentTools(ctx context.Context, agentID int64, agentVersion string) (err error)
|
||||
|
||||
ExecuteTool(ctx context.Context, req *ExecuteToolRequest, opts ...entity.ExecuteToolOpt) (resp *ExecuteToolResponse, err error)
|
||||
|
||||
// Product
|
||||
ListPluginProducts(ctx context.Context, req *ListPluginProductsRequest) (resp *ListPluginProductsResponse, err error)
|
||||
GetPluginProductAllTools(ctx context.Context, pluginID int64) (tools []*entity.ToolInfo, err error)
|
||||
|
||||
GetOAuthStatus(ctx context.Context, userID, pluginID int64) (resp *GetOAuthStatusResponse, err error)
|
||||
GetAgentPluginsOAuthStatus(ctx context.Context, userID, agentID int64) (status []*AgentPluginOAuthStatus, err error)
|
||||
OAuthCode(ctx context.Context, code string, state *entity.OAuthState) (err error)
|
||||
GetAccessToken(ctx context.Context, oa *entity.OAuthInfo) (accessToken string, err error)
|
||||
RevokeAccessToken(ctx context.Context, meta *entity.AuthorizationCodeMeta) (err error)
|
||||
}
|
||||
|
||||
type CreateDraftPluginRequest struct {
|
||||
PluginType common.PluginType
|
||||
IconURI string
|
||||
SpaceID int64
|
||||
DeveloperID int64
|
||||
ProjectID *int64
|
||||
Name string
|
||||
Desc string
|
||||
ServerURL string
|
||||
CommonParams map[common.ParameterLocation][]*common.CommonParamSchema
|
||||
AuthInfo *PluginAuthInfo
|
||||
}
|
||||
|
||||
type UpdateDraftPluginWithCodeRequest struct {
|
||||
UserID int64
|
||||
PluginID int64
|
||||
OpenapiDoc *model.Openapi3T
|
||||
Manifest *entity.PluginManifest
|
||||
}
|
||||
|
||||
type UpdateDraftPluginRequest struct {
|
||||
PluginID int64
|
||||
Name *string
|
||||
Desc *string
|
||||
URL *string
|
||||
Icon *common.PluginIcon
|
||||
CommonParams map[common.ParameterLocation][]*common.CommonParamSchema
|
||||
AuthInfo *PluginAuthInfo
|
||||
}
|
||||
|
||||
type ListDraftPluginsRequest struct {
|
||||
SpaceID int64
|
||||
APPID int64
|
||||
PageInfo entity.PageInfo
|
||||
}
|
||||
|
||||
type ListDraftPluginsResponse struct {
|
||||
Plugins []*entity.PluginInfo
|
||||
Total int64
|
||||
}
|
||||
|
||||
type CreateDraftPluginWithCodeRequest struct {
|
||||
SpaceID int64
|
||||
DeveloperID int64
|
||||
ProjectID *int64
|
||||
Manifest *entity.PluginManifest
|
||||
OpenapiDoc *model.Openapi3T
|
||||
}
|
||||
|
||||
type CreateDraftPluginWithCodeResponse struct {
|
||||
Plugin *entity.PluginInfo
|
||||
Tools []*entity.ToolInfo
|
||||
}
|
||||
|
||||
type CreateDraftToolsWithCodeRequest struct {
|
||||
PluginID int64
|
||||
OpenapiDoc *model.Openapi3T
|
||||
|
||||
ConflictAndUpdate bool
|
||||
}
|
||||
|
||||
type CreateDraftToolsWithCodeResponse struct {
|
||||
DuplicatedTools []entity.UniqueToolAPI
|
||||
}
|
||||
|
||||
type PluginAuthInfo struct {
|
||||
AuthzType *model.AuthzType
|
||||
Location *model.HTTPParamLocation
|
||||
Key *string
|
||||
ServiceToken *string
|
||||
OAuthInfo *string
|
||||
AuthzSubType *model.AuthzSubType
|
||||
AuthzPayload *string
|
||||
}
|
||||
|
||||
func (p PluginAuthInfo) toAuthV2() (*model.AuthV2, error) {
|
||||
if p.AuthzType == nil {
|
||||
return nil, errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey, "auth type is required"))
|
||||
}
|
||||
|
||||
switch *p.AuthzType {
|
||||
case model.AuthzTypeOfNone:
|
||||
return &model.AuthV2{
|
||||
Type: model.AuthzTypeOfNone,
|
||||
}, nil
|
||||
|
||||
case model.AuthzTypeOfOAuth:
|
||||
m, err := p.authOfOAuthToAuthV2()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case model.AuthzTypeOfService:
|
||||
m, err := p.authOfServiceToAuthV2()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
|
||||
default:
|
||||
return nil, errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"the type '%s' of auth is invalid", *p.AuthzType))
|
||||
}
|
||||
}
|
||||
|
||||
func (p PluginAuthInfo) authOfOAuthToAuthV2() (*model.AuthV2, error) {
|
||||
if p.AuthzSubType == nil {
|
||||
return nil, errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey, "sub-auth type is required"))
|
||||
}
|
||||
|
||||
if p.OAuthInfo == nil || *p.OAuthInfo == "" {
|
||||
return nil, errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey, "oauth info is required"))
|
||||
}
|
||||
|
||||
oauthInfo := make(map[string]string)
|
||||
err := sonic.Unmarshal([]byte(*p.OAuthInfo), &oauthInfo)
|
||||
if err != nil {
|
||||
return nil, errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey, "invalid oauth info"))
|
||||
}
|
||||
|
||||
if *p.AuthzSubType == model.AuthzSubTypeOfOAuthClientCredentials {
|
||||
_oauthInfo := &model.OAuthClientCredentialsConfig{
|
||||
ClientID: oauthInfo["client_id"],
|
||||
ClientSecret: oauthInfo["client_secret"],
|
||||
TokenURL: oauthInfo["token_url"],
|
||||
}
|
||||
|
||||
str, err := sonic.MarshalString(_oauthInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal oauth info failed, err=%v", err)
|
||||
}
|
||||
|
||||
return &model.AuthV2{
|
||||
Type: model.AuthzTypeOfOAuth,
|
||||
SubType: model.AuthzSubTypeOfOAuthClientCredentials,
|
||||
Payload: str,
|
||||
AuthOfOAuthClientCredentials: _oauthInfo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if *p.AuthzSubType == model.AuthzSubTypeOfOAuthAuthorizationCode {
|
||||
contentType := oauthInfo["authorization_content_type"]
|
||||
if contentType != model.MediaTypeJson { // only support application/json
|
||||
return nil, errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"the type '%s' of authorization content is invalid", contentType))
|
||||
}
|
||||
|
||||
_oauthInfo := &model.OAuthAuthorizationCodeConfig{
|
||||
ClientID: oauthInfo["client_id"],
|
||||
ClientSecret: oauthInfo["client_secret"],
|
||||
ClientURL: oauthInfo["client_url"],
|
||||
Scope: oauthInfo["scope"],
|
||||
AuthorizationURL: oauthInfo["authorization_url"],
|
||||
AuthorizationContentType: contentType,
|
||||
}
|
||||
|
||||
str, err := sonic.MarshalString(_oauthInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal oauth info failed, err=%v", err)
|
||||
}
|
||||
|
||||
return &model.AuthV2{
|
||||
Type: model.AuthzTypeOfOAuth,
|
||||
SubType: model.AuthzSubTypeOfOAuthAuthorizationCode,
|
||||
Payload: str,
|
||||
AuthOfOAuthAuthorizationCode: _oauthInfo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"the type '%s' of sub-auth is invalid", *p.AuthzSubType))
|
||||
}
|
||||
|
||||
func (p PluginAuthInfo) authOfServiceToAuthV2() (*model.AuthV2, error) {
|
||||
if p.AuthzSubType == nil {
|
||||
return nil, fmt.Errorf("sub-auth type is required")
|
||||
}
|
||||
|
||||
if *p.AuthzSubType == model.AuthzSubTypeOfServiceAPIToken {
|
||||
if p.Location == nil {
|
||||
return nil, fmt.Errorf("'Location' of sub-auth is required")
|
||||
}
|
||||
if p.ServiceToken == nil {
|
||||
return nil, fmt.Errorf("'ServiceToken' of sub-auth is required")
|
||||
}
|
||||
if p.Key == nil {
|
||||
return nil, fmt.Errorf("'Key' of sub-auth is required")
|
||||
}
|
||||
|
||||
tokenAuth := &model.AuthOfAPIToken{
|
||||
ServiceToken: *p.ServiceToken,
|
||||
Location: model.HTTPParamLocation(strings.ToLower(string(*p.Location))),
|
||||
Key: *p.Key,
|
||||
}
|
||||
|
||||
str, err := sonic.MarshalString(tokenAuth)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal token auth failed, err=%v", err)
|
||||
}
|
||||
|
||||
return &model.AuthV2{
|
||||
Type: model.AuthzTypeOfService,
|
||||
SubType: model.AuthzSubTypeOfServiceAPIToken,
|
||||
Payload: str,
|
||||
AuthOfAPIToken: tokenAuth,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"the type '%s' of sub-auth is invalid", *p.AuthzSubType))
|
||||
}
|
||||
|
||||
type PublishPluginRequest = model.PublishPluginRequest
|
||||
|
||||
type PublishAPPPluginsRequest = model.PublishAPPPluginsRequest
|
||||
|
||||
type PublishAPPPluginsResponse = model.PublishAPPPluginsResponse
|
||||
|
||||
type MGetPluginLatestVersionResponse = model.MGetPluginLatestVersionResponse
|
||||
|
||||
type UpdateToolDraftRequest struct {
|
||||
PluginID int64
|
||||
ToolID int64
|
||||
Name *string
|
||||
Desc *string
|
||||
SubURL *string
|
||||
Method *string
|
||||
Parameters openapi3.Parameters
|
||||
RequestBody *openapi3.RequestBodyRef
|
||||
Responses openapi3.Responses
|
||||
Disabled *bool
|
||||
SaveExample *bool
|
||||
DebugExample *common.DebugExample
|
||||
APIExtend *common.APIExtend
|
||||
}
|
||||
|
||||
type MGetAgentToolsRequest = model.MGetAgentToolsRequest
|
||||
|
||||
type UpdateBotDefaultParamsRequest struct {
|
||||
PluginID int64
|
||||
AgentID int64
|
||||
ToolName string
|
||||
Parameters openapi3.Parameters
|
||||
RequestBody *openapi3.RequestBodyRef
|
||||
Responses openapi3.Responses
|
||||
}
|
||||
|
||||
type ExecuteToolRequest = model.ExecuteToolRequest
|
||||
|
||||
type ExecuteToolResponse = model.ExecuteToolResponse
|
||||
|
||||
type ListPluginProductsRequest struct{}
|
||||
|
||||
type ListPluginProductsResponse struct {
|
||||
Plugins []*entity.PluginInfo
|
||||
Total int64
|
||||
}
|
||||
|
||||
type ConvertToOpenapi3DocRequest struct {
|
||||
RawInput string
|
||||
PluginServerURL *string
|
||||
}
|
||||
|
||||
type ConvertToOpenapi3DocResponse struct {
|
||||
OpenapiDoc *model.Openapi3T
|
||||
Manifest *entity.PluginManifest
|
||||
Format common.PluginDataFormat
|
||||
ErrMsg string
|
||||
}
|
||||
|
||||
type GetOAuthStatusResponse struct {
|
||||
IsOauth bool
|
||||
Status common.OAuthStatus
|
||||
OAuthURL string
|
||||
}
|
||||
|
||||
type AgentPluginOAuthStatus struct {
|
||||
PluginID int64
|
||||
PluginName string
|
||||
PluginIconURL string
|
||||
Status common.OAuthStatus
|
||||
}
|
||||
|
||||
type CopyPluginRequest struct {
|
||||
UserID int64
|
||||
PluginID int64
|
||||
CopyScene model.CopyScene
|
||||
|
||||
TargetAPPID *int64
|
||||
}
|
||||
|
||||
type CopyPluginResponse struct {
|
||||
Plugin *entity.PluginInfo
|
||||
Tools map[int64]*entity.ToolInfo // old tool id -> new tool
|
||||
}
|
||||
|
||||
type MoveAPPPluginToLibRequest struct {
|
||||
PluginID int64
|
||||
}
|
||||
|
||||
type GetAccessTokenRequest struct {
|
||||
UserID string
|
||||
PluginID *int64
|
||||
Mode model.AuthzSubType
|
||||
OAuthInfo *entity.OAuthInfo
|
||||
}
|
||||
67
backend/domain/plugin/service/service_impl.go
Normal file
67
backend/domain/plugin/service/service_impl.go
Normal file
@@ -0,0 +1,67 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/repository"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
)
|
||||
|
||||
type Components struct {
|
||||
IDGen idgen.IDGenerator
|
||||
DB *gorm.DB
|
||||
OSS storage.Storage
|
||||
PluginRepo repository.PluginRepository
|
||||
ToolRepo repository.ToolRepository
|
||||
OAuthRepo repository.OAuthRepository
|
||||
}
|
||||
|
||||
func NewService(components *Components) PluginService {
|
||||
impl := &pluginServiceImpl{
|
||||
db: components.DB,
|
||||
oss: components.OSS,
|
||||
pluginRepo: components.PluginRepo,
|
||||
toolRepo: components.ToolRepo,
|
||||
oauthRepo: components.OAuthRepo,
|
||||
httpCli: resty.New(),
|
||||
}
|
||||
|
||||
initOnce.Do(func() {
|
||||
ctx := context.Background()
|
||||
safego.Go(ctx, func() {
|
||||
impl.processOAuthAccessToken(ctx)
|
||||
})
|
||||
})
|
||||
|
||||
return impl
|
||||
}
|
||||
|
||||
type pluginServiceImpl struct {
|
||||
db *gorm.DB
|
||||
oss storage.Storage
|
||||
pluginRepo repository.PluginRepository
|
||||
toolRepo repository.ToolRepository
|
||||
oauthRepo repository.OAuthRepository
|
||||
httpCli *resty.Client
|
||||
}
|
||||
Reference in New Issue
Block a user