feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
32
backend/api/model/crossdomain/agentrun/agent_run.go
Normal file
32
backend/api/model/crossdomain/agentrun/agent_run.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package agentrun
|
||||
|
||||
type Tool struct {
|
||||
PluginID int64 `json:"plugin_id"`
|
||||
ToolID int64 `json:"tool_id"`
|
||||
Arguments string `json:"arguments"`
|
||||
ToolName string `json:"tool_name"`
|
||||
Type ToolType `json:"type"`
|
||||
}
|
||||
|
||||
type ToolType int32
|
||||
|
||||
const (
|
||||
ToolTypePlugin ToolType = 2
|
||||
ToolTypeWorkflow ToolType = 1
|
||||
)
|
||||
|
||||
type ToolsRetriever struct {
|
||||
PluginID int64
|
||||
ToolName string
|
||||
ToolID int64
|
||||
Arguments string
|
||||
Type ToolType
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
LlmPromptTokens int64 `json:"llm_prompt_tokens"`
|
||||
LlmCompletionTokens int64 `json:"llm_completion_tokens"`
|
||||
LlmTotalTokens int64 `json:"llm_total_tokens"`
|
||||
WorkflowTokens *int64 `json:"workflow_tokens,omitempty"`
|
||||
WorkflowCost *int64 `json:"workflow_cost,omitempty"`
|
||||
}
|
||||
12
backend/api/model/crossdomain/connector/connector.go
Normal file
12
backend/api/model/crossdomain/connector/connector.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package connector
|
||||
|
||||
import "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
|
||||
|
||||
type Connector struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
URI string `json:"uri"`
|
||||
URL string `json:"url"`
|
||||
Desc string `json:"description"`
|
||||
ConnectorStatus developer_api.ConnectorDynamicStatus `json:"connector_status"`
|
||||
}
|
||||
45
backend/api/model/crossdomain/conversation/conversation.go
Normal file
45
backend/api/model/crossdomain/conversation/conversation.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package conversation
|
||||
|
||||
import "github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
|
||||
|
||||
type GetCurrent struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Scene common.Scene `json:"scene"`
|
||||
AgentID int64 `json:"agent_id"`
|
||||
ConnectorID int64 `json:"connector_id"`
|
||||
}
|
||||
|
||||
type Scene int32
|
||||
|
||||
const (
|
||||
SceneDefault Scene = 0
|
||||
SceneExplore Scene = 1
|
||||
SceneBotStore Scene = 2
|
||||
SceneCozeHome Scene = 3
|
||||
ScenePlayground Scene = 4
|
||||
SceneEvaluation Scene = 5
|
||||
SceneAgentAPP Scene = 6
|
||||
ScenePromptOptimize Scene = 7
|
||||
SceneGenerateAgentInfo Scene = 8
|
||||
SceneOpenApi Scene = 9
|
||||
)
|
||||
|
||||
type Conversation struct {
|
||||
ID int64 `json:"id"`
|
||||
SectionID int64 `json:"section_id"`
|
||||
AgentID int64 `json:"agent_id"`
|
||||
ConnectorID int64 `json:"connector_id"`
|
||||
CreatorID int64 `json:"creator_id"`
|
||||
Scene common.Scene `json:"scene"`
|
||||
Status ConversationStatus `json:"status"`
|
||||
Ext string `json:"ext"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
}
|
||||
|
||||
type ConversationStatus int32
|
||||
|
||||
const (
|
||||
ConversationStatusNormal ConversationStatus = 1
|
||||
ConversationStatusDeleted ConversationStatus = 2
|
||||
)
|
||||
76
backend/api/model/crossdomain/database/const.go
Normal file
76
backend/api/model/crossdomain/database/const.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package database
|
||||
|
||||
type OperateType int64
|
||||
|
||||
const (
|
||||
OperateType_Custom OperateType = 0
|
||||
OperateType_Insert OperateType = 1
|
||||
OperateType_Update OperateType = 2
|
||||
OperateType_Delete OperateType = 3
|
||||
OperateType_Select OperateType = 4
|
||||
)
|
||||
|
||||
type Operation int64
|
||||
|
||||
const (
|
||||
Operation_EQUAL Operation = 1
|
||||
Operation_NOT_EQUAL Operation = 2
|
||||
Operation_GREATER_THAN Operation = 3
|
||||
Operation_LESS_THAN Operation = 4
|
||||
Operation_GREATER_EQUAL Operation = 5
|
||||
Operation_LESS_EQUAL Operation = 6
|
||||
Operation_IN Operation = 7
|
||||
Operation_NOT_IN Operation = 8
|
||||
Operation_IS_NULL Operation = 9
|
||||
Operation_IS_NOT_NULL Operation = 10
|
||||
Operation_LIKE Operation = 11
|
||||
Operation_NOT_LIKE Operation = 12
|
||||
)
|
||||
|
||||
type Logic int64
|
||||
|
||||
const (
|
||||
Logic_And Logic = 1
|
||||
Logic_Or Logic = 2
|
||||
)
|
||||
|
||||
// SQLType indicates the type of SQL, e.g., parameterized (with '?') or raw SQL.
|
||||
type SQLType int32
|
||||
|
||||
const (
|
||||
SQLType_Parameterized SQLType = 0
|
||||
SQLType_Raw SQLType = 1 // Complete/raw SQL
|
||||
)
|
||||
|
||||
type DocumentSourceType int64
|
||||
|
||||
const (
|
||||
DocumentSourceType_Document DocumentSourceType = 0
|
||||
)
|
||||
|
||||
type TableReadDataMethod int
|
||||
|
||||
var (
|
||||
TableReadDataMethodOnlyHeader TableReadDataMethod = 1
|
||||
TableReadDataMethodPreview TableReadDataMethod = 2
|
||||
TableReadDataMethodAll TableReadDataMethod = 3
|
||||
TableReadDataMethodHead TableReadDataMethod = 4
|
||||
)
|
||||
|
||||
type ColumnTypeCategory int64
|
||||
|
||||
const (
|
||||
ColumnTypeCategoryText ColumnTypeCategory = 0
|
||||
ColumnTypeCategoryNumber ColumnTypeCategory = 1
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultCreateTimeColName = "bstudio_create_time"
|
||||
DefaultCidColName = "bstudio_connector_id"
|
||||
DefaultUidColName = "bstudio_connector_uid"
|
||||
DefaultIDColName = "bstudio_id"
|
||||
|
||||
DefaultCreateTimeDisplayColName = "bstudio_create_time"
|
||||
DefaultUidDisplayColName = "uuid"
|
||||
DefaultIDDisplayColName = "id"
|
||||
)
|
||||
190
backend/api/model/crossdomain/database/database.go
Normal file
190
backend/api/model/crossdomain/database/database.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/table"
|
||||
)
|
||||
|
||||
type ExecuteSQLRequest struct {
|
||||
SQL *string // set if OperateType is 0.
|
||||
SQLType SQLType // SQLType indicates the type of SQL: parameterized or raw SQL. It takes effect if OperateType is 0.
|
||||
|
||||
DatabaseID int64
|
||||
UserID string
|
||||
SpaceID int64
|
||||
ConnectorID *int64
|
||||
SQLParams []*SQLParamVal
|
||||
TableType table.TableType
|
||||
OperateType OperateType
|
||||
|
||||
// set the following values if OperateType is not 0.
|
||||
SelectFieldList *SelectFieldList
|
||||
OrderByList []OrderBy
|
||||
Limit *int64
|
||||
Offset *int64
|
||||
Condition *ComplexCondition
|
||||
UpsertRows []*UpsertRow
|
||||
}
|
||||
|
||||
type ExecuteSQLResponse struct {
|
||||
// Records contains the query result, where each map represents a row.
|
||||
// The map's key is the column name, and the value is the raw data from the database.
|
||||
// The caller is responsible for type assertion and conversion to the desired format.
|
||||
// Common types returned by database drivers include:
|
||||
// - Text: []uint8 (can be converted to string)
|
||||
// - Number: int64
|
||||
// - Float: float64
|
||||
// - Boolean: bool
|
||||
// - Date: time.Time
|
||||
Records []map[string]any
|
||||
FieldList []*FieldItem
|
||||
RowsAffected *int64
|
||||
}
|
||||
|
||||
type PublishDatabaseRequest struct {
|
||||
AgentID int64
|
||||
}
|
||||
|
||||
type PublishDatabaseResponse struct {
|
||||
OnlineDatabases []*bot_common.Database
|
||||
}
|
||||
|
||||
type SQLParamVal struct {
|
||||
ValueType table.FieldItemType
|
||||
ISNull bool
|
||||
Value *string
|
||||
Name *string
|
||||
}
|
||||
|
||||
type OrderBy struct {
|
||||
Field string
|
||||
Direction table.SortDirection
|
||||
}
|
||||
|
||||
type UpsertRow struct {
|
||||
Records []*Record
|
||||
}
|
||||
|
||||
type Record struct {
|
||||
FieldId string
|
||||
FieldValue string
|
||||
}
|
||||
|
||||
type SelectFieldList struct {
|
||||
FieldID []string
|
||||
IsDistinct bool
|
||||
}
|
||||
|
||||
type ComplexCondition struct {
|
||||
Conditions []*Condition
|
||||
// NestedConditions *ComplexCondition
|
||||
Logic Logic
|
||||
}
|
||||
|
||||
type Condition struct {
|
||||
Left string
|
||||
Operation Operation
|
||||
Right string
|
||||
}
|
||||
|
||||
type FieldItem struct {
|
||||
Name string
|
||||
Desc string
|
||||
Type table.FieldItemType
|
||||
MustRequired bool
|
||||
AlterID int64
|
||||
IsSystemField bool
|
||||
PhysicalName string
|
||||
// ID int64
|
||||
}
|
||||
|
||||
type Database struct {
|
||||
ID int64
|
||||
IconURI string
|
||||
|
||||
CreatorID int64
|
||||
SpaceID int64
|
||||
|
||||
CreatedAtMs int64
|
||||
UpdatedAtMs int64
|
||||
DeletedAtMs int64
|
||||
|
||||
AppID int64
|
||||
IconURL string
|
||||
TableName string
|
||||
TableDesc string
|
||||
Status table.BotTableStatus
|
||||
FieldList []*FieldItem
|
||||
ActualTableName string
|
||||
RwMode table.BotTableRWMode
|
||||
PromptDisabled bool
|
||||
IsVisible bool
|
||||
DraftID *int64
|
||||
OnlineID *int64
|
||||
ExtraInfo map[string]string
|
||||
IsAddedToAgent *bool
|
||||
TableType *table.TableType
|
||||
}
|
||||
|
||||
func (d *Database) GetDraftID() int64 {
|
||||
if d.DraftID == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return *d.DraftID
|
||||
}
|
||||
|
||||
func (d *Database) GetOnlineID() int64 {
|
||||
if d.OnlineID == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return *d.OnlineID
|
||||
}
|
||||
|
||||
type DatabaseBasic struct {
|
||||
ID int64
|
||||
TableType table.TableType
|
||||
NeedSysFields bool
|
||||
}
|
||||
|
||||
type DeleteDatabaseRequest struct {
|
||||
ID int64
|
||||
}
|
||||
|
||||
type AgentToDatabase struct {
|
||||
AgentID int64
|
||||
DatabaseID int64
|
||||
TableType table.TableType
|
||||
PromptDisabled bool
|
||||
}
|
||||
|
||||
type AgentToDatabaseBasic struct {
|
||||
AgentID int64
|
||||
DatabaseID int64
|
||||
}
|
||||
|
||||
type BindDatabaseToAgentRequest struct {
|
||||
DraftDatabaseID int64
|
||||
AgentID int64
|
||||
}
|
||||
|
||||
type UnBindDatabaseToAgentRequest struct {
|
||||
DraftDatabaseID int64
|
||||
AgentID int64
|
||||
}
|
||||
|
||||
type MGetDatabaseRequest struct {
|
||||
Basics []*DatabaseBasic
|
||||
}
|
||||
type MGetDatabaseResponse struct {
|
||||
Databases []*Database
|
||||
}
|
||||
|
||||
type GetAllDatabaseByAppIDRequest struct {
|
||||
AppID int64
|
||||
}
|
||||
|
||||
type GetAllDatabaseByAppIDResponse struct {
|
||||
Databases []*Database // online databases
|
||||
}
|
||||
269
backend/api/model/crossdomain/knowledge/knowledge.go
Normal file
269
backend/api/model/crossdomain/knowledge/knowledge.go
Normal file
@@ -0,0 +1,269 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"github.com/bytedance/sonic"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
type ListKnowledgeRequest struct {
|
||||
IDs []int64
|
||||
SpaceID *int64
|
||||
AppID *int64
|
||||
Name *string // 完全匹配
|
||||
Status []int32
|
||||
UserID *int64
|
||||
Query *string // 模糊匹配
|
||||
Page *int
|
||||
PageSize *int
|
||||
Order *Order
|
||||
OrderType *OrderType
|
||||
FormatType *DocumentType
|
||||
}
|
||||
|
||||
type Order int32
|
||||
|
||||
const (
|
||||
OrderCreatedAt Order = 1
|
||||
OrderUpdatedAt Order = 2
|
||||
)
|
||||
|
||||
type OrderType int32
|
||||
|
||||
const (
|
||||
OrderTypeAsc OrderType = 1
|
||||
OrderTypeDesc OrderType = 2
|
||||
)
|
||||
|
||||
type DocumentType int64
|
||||
|
||||
const (
|
||||
DocumentTypeText DocumentType = 0 // 文本
|
||||
DocumentTypeTable DocumentType = 1 // 表格
|
||||
DocumentTypeImage DocumentType = 2 // 图片
|
||||
DocumentTypeUnknown DocumentType = 9 // 未知
|
||||
)
|
||||
|
||||
type ListKnowledgeResponse struct {
|
||||
KnowledgeList []*Knowledge
|
||||
Total int64
|
||||
}
|
||||
|
||||
type Knowledge struct {
|
||||
Info
|
||||
SliceHit int64
|
||||
Type DocumentType
|
||||
Status KnowledgeStatus
|
||||
}
|
||||
|
||||
type Info struct {
|
||||
ID int64
|
||||
Name string
|
||||
Description string
|
||||
IconURI string
|
||||
IconURL string
|
||||
CreatorID int64
|
||||
SpaceID int64
|
||||
AppID int64
|
||||
CreatedAtMs int64
|
||||
UpdatedAtMs int64
|
||||
DeletedAtMs int64
|
||||
}
|
||||
|
||||
type KnowledgeStatus int64
|
||||
|
||||
const (
|
||||
KnowledgeStatusInit KnowledgeStatus = 0
|
||||
KnowledgeStatusEnable KnowledgeStatus = 1
|
||||
KnowledgeStatusDisable KnowledgeStatus = 3
|
||||
)
|
||||
|
||||
type RetrieveRequest struct {
|
||||
Query string
|
||||
ChatHistory []*schema.Message
|
||||
|
||||
// 从指定的知识库和文档中召回
|
||||
KnowledgeIDs []int64
|
||||
DocumentIDs []int64 // todo: 确认下这个场景
|
||||
|
||||
// 召回策略
|
||||
Strategy *RetrievalStrategy
|
||||
|
||||
// 用于 nl2sql 和 message to query 的 chat model config
|
||||
ChatModelProtocol *chatmodel.Protocol
|
||||
ChatModelConfig *chatmodel.Config
|
||||
}
|
||||
|
||||
type RetrievalStrategy struct {
|
||||
TopK *int64 // 1-10 default 3
|
||||
MinScore *float64 // 0.01-0.99 default 0.5
|
||||
MaxTokens *int64
|
||||
|
||||
SelectType SelectType // 调用方式
|
||||
SearchType SearchType // 搜索策略
|
||||
EnableQueryRewrite bool
|
||||
EnableRerank bool
|
||||
EnableNL2SQL bool
|
||||
}
|
||||
|
||||
type SelectType int64
|
||||
|
||||
const (
|
||||
SelectTypeAuto = 0 // 自动调用
|
||||
SelectTypeOnDemand = 1 // 按需调用
|
||||
)
|
||||
|
||||
type SearchType int64
|
||||
|
||||
const (
|
||||
SearchTypeSemantic SearchType = 0 // 语义
|
||||
SearchTypeFullText SearchType = 1 // 全文
|
||||
SearchTypeHybrid SearchType = 2 // 混合
|
||||
)
|
||||
|
||||
type RetrieveResponse struct {
|
||||
RetrieveSlices []*RetrieveSlice
|
||||
}
|
||||
|
||||
type RetrieveSlice struct {
|
||||
Slice *Slice
|
||||
Score float64
|
||||
}
|
||||
|
||||
type Slice struct {
|
||||
Info
|
||||
|
||||
KnowledgeID int64
|
||||
DocumentID int64
|
||||
DocumentName string
|
||||
RawContent []*SliceContent
|
||||
SliceStatus SliceStatus
|
||||
ByteCount int64 // 切片 bytes
|
||||
CharCount int64 // 切片字符数
|
||||
Sequence int64 // 切片位置序号
|
||||
Hit int64 // 命中次数
|
||||
Extra map[string]string
|
||||
}
|
||||
|
||||
func (s *Slice) GetSliceContent() string {
|
||||
if len(s.RawContent) == 0 {
|
||||
return ""
|
||||
}
|
||||
if s.RawContent[0].Type == SliceContentTypeTable {
|
||||
contentMap := map[string]string{}
|
||||
for _, column := range s.RawContent[0].Table.Columns {
|
||||
contentMap[column.ColumnName] = column.GetStringValue()
|
||||
}
|
||||
byteData, err := sonic.Marshal(contentMap)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(byteData)
|
||||
}
|
||||
data := ""
|
||||
for i := range s.RawContent {
|
||||
item := s.RawContent[i]
|
||||
if item == nil {
|
||||
continue
|
||||
}
|
||||
if item.Type == SliceContentTypeTable {
|
||||
var contentMap map[string]string
|
||||
for _, column := range s.RawContent[0].Table.Columns {
|
||||
contentMap[column.ColumnName] = column.GetStringValue()
|
||||
}
|
||||
byteData, err := sonic.Marshal(contentMap)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
data += string(byteData)
|
||||
}
|
||||
if item.Type == SliceContentTypeText {
|
||||
data += ptr.From(item.Text)
|
||||
}
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
type SliceContent struct {
|
||||
Type SliceContentType
|
||||
|
||||
Text *string
|
||||
Image *SliceImage
|
||||
Table *SliceTable
|
||||
}
|
||||
|
||||
type SliceStatus int64
|
||||
|
||||
const (
|
||||
SliceStatusInit SliceStatus = 0 // 初始化
|
||||
SliceStatusFinishStore SliceStatus = 1 // searchStore存储完成
|
||||
SliceStatusFailed SliceStatus = 9 // 失败
|
||||
)
|
||||
|
||||
type SliceContentType int64
|
||||
|
||||
const (
|
||||
SliceContentTypeText SliceContentType = 0
|
||||
//SliceContentTypeImage SliceContentType = 1
|
||||
SliceContentTypeTable SliceContentType = 2
|
||||
)
|
||||
|
||||
type SliceImage struct {
|
||||
Base64 []byte
|
||||
URI string
|
||||
OCR bool // 是否使用 ocr 提取了文本
|
||||
OCRText *string
|
||||
}
|
||||
|
||||
type SliceTable struct { // table slice 为一行数据
|
||||
Columns []*document.ColumnData
|
||||
}
|
||||
|
||||
type DeleteKnowledgeRequest struct {
|
||||
KnowledgeID int64
|
||||
}
|
||||
type GetKnowledgeByIDRequest struct {
|
||||
KnowledgeID int64
|
||||
}
|
||||
|
||||
type GetKnowledgeByIDResponse struct {
|
||||
Knowledge *Knowledge
|
||||
}
|
||||
|
||||
type MGetKnowledgeByIDRequest struct {
|
||||
KnowledgeIDs []int64
|
||||
}
|
||||
|
||||
type MGetKnowledgeByIDResponse struct {
|
||||
Knowledge []*Knowledge
|
||||
}
|
||||
|
||||
type CopyKnowledgeRequest struct {
|
||||
KnowledgeID int64
|
||||
TargetAppID int64
|
||||
TargetSpaceID int64
|
||||
TargetUserID int64
|
||||
TaskUniqKey string
|
||||
}
|
||||
type CopyStatus int64
|
||||
|
||||
const (
|
||||
CopyStatus_Successful CopyStatus = 1
|
||||
CopyStatus_Processing CopyStatus = 2
|
||||
CopyStatus_Failed CopyStatus = 3
|
||||
CopyStatus_KeepOrigin CopyStatus = 4
|
||||
)
|
||||
|
||||
type CopyKnowledgeResponse struct {
|
||||
OriginKnowledgeID int64
|
||||
TargetKnowledgeID int64
|
||||
CopyStatus CopyStatus
|
||||
ErrMsg string
|
||||
}
|
||||
type MoveKnowledgeToLibraryRequest struct {
|
||||
KnowledgeID int64
|
||||
}
|
||||
87
backend/api/model/crossdomain/message/message.go
Normal file
87
backend/api/model/crossdomain/message/message.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/message"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
ID int64 `json:"id"`
|
||||
ConversationID int64 `json:"conversation_id"`
|
||||
RunID int64 `json:"run_id"`
|
||||
AgentID int64 `json:"agent_id"`
|
||||
SectionID int64 `json:"section_id"`
|
||||
Content string `json:"content"`
|
||||
MultiContent []*InputMetaData `json:"multi_content"`
|
||||
ContentType ContentType `json:"content_type"`
|
||||
DisplayContent string `json:"display_content"`
|
||||
Role schema.RoleType `json:"role"`
|
||||
Name string `json:"name"`
|
||||
Status MessageStatus `json:"status"`
|
||||
MessageType MessageType `json:"message_type"`
|
||||
ModelContent string `json:"model_content"`
|
||||
Position int32 `json:"position"`
|
||||
UserID string `json:"user_id"`
|
||||
Ext map[string]string `json:"ext"`
|
||||
ReasoningContent string `json:"reasoning_content"`
|
||||
RequiredAction *message.RequiredAction `json:"required_action"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
}
|
||||
|
||||
type InputMetaData struct {
|
||||
Type InputType `json:"type"`
|
||||
Text string `json:"text"`
|
||||
FileData []*FileData `json:"file_data"`
|
||||
}
|
||||
|
||||
type MessageStatus int32
|
||||
|
||||
const (
|
||||
MessageStatusAvailable MessageStatus = 1
|
||||
MessageStatusDeleted MessageStatus = 2
|
||||
MessageStatusBroken MessageStatus = 4
|
||||
)
|
||||
|
||||
type InputType string
|
||||
|
||||
const (
|
||||
InputTypeText InputType = "text"
|
||||
InputTypeFile InputType = "file"
|
||||
InputTypeImage InputType = "image"
|
||||
InputTypeVideo InputType = "video"
|
||||
InputTypeAudio InputType = "audio"
|
||||
)
|
||||
|
||||
type FileData struct {
|
||||
Url string `json:"url"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type ContentType string
|
||||
|
||||
const (
|
||||
ContentTypeText ContentType = "text"
|
||||
ContentTypeImage ContentType = "image"
|
||||
ContentTypeVideo ContentType = "video"
|
||||
ContentTypeMusic ContentType = "music"
|
||||
ContentTypeCard ContentType = "card"
|
||||
ContentTypeWidget ContentType = "widget"
|
||||
ContentTypeAPP ContentType = "app"
|
||||
ContentTypeMix ContentType = "mix"
|
||||
)
|
||||
|
||||
type MessageType string
|
||||
|
||||
const (
|
||||
MessageTypeAck MessageType = "ack"
|
||||
MessageTypeQuestion MessageType = "question"
|
||||
MessageTypeFunctionCall MessageType = "function_call"
|
||||
MessageTypeToolResponse MessageType = "tool_response"
|
||||
MessageTypeKnowledge MessageType = "knowledge"
|
||||
MessageTypeAnswer MessageType = "answer"
|
||||
MessageTypeFlowUp MessageType = "follow_up"
|
||||
MessageTypeInterrupt MessageType = "interrupt"
|
||||
MessageTypeVerbose MessageType = "verbose"
|
||||
)
|
||||
68
backend/api/model/crossdomain/modelmgr/const.go
Normal file
68
backend/api/model/crossdomain/modelmgr/const.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package modelmgr
|
||||
|
||||
type ParameterName string
|
||||
|
||||
const (
|
||||
Temperature ParameterName = "temperature"
|
||||
TopP ParameterName = "top_p"
|
||||
TopK ParameterName = "top_k"
|
||||
MaxTokens ParameterName = "max_tokens"
|
||||
RespFormat ParameterName = "response_format"
|
||||
FrequencyPenalty ParameterName = "frequency_penalty"
|
||||
PresencePenalty ParameterName = "presence_penalty"
|
||||
)
|
||||
|
||||
type ValueType string
|
||||
|
||||
const (
|
||||
ValueTypeInt ValueType = "int"
|
||||
ValueTypeFloat ValueType = "float"
|
||||
ValueTypeBoolean ValueType = "boolean"
|
||||
ValueTypeString ValueType = "string"
|
||||
)
|
||||
|
||||
type DefaultType string
|
||||
|
||||
const (
|
||||
DefaultTypeDefault DefaultType = "default_val"
|
||||
DefaultTypeCreative DefaultType = "creative"
|
||||
DefaultTypeBalance DefaultType = "balance"
|
||||
DefaultTypePrecise DefaultType = "precise"
|
||||
)
|
||||
|
||||
// Deprecated
|
||||
type Scenario int64 // 模型实体使用场景
|
||||
|
||||
type Modal string
|
||||
|
||||
const (
|
||||
ModalText Modal = "text"
|
||||
ModalImage Modal = "image"
|
||||
ModalFile Modal = "file"
|
||||
ModalAudio Modal = "audio"
|
||||
ModalVideo Modal = "video"
|
||||
)
|
||||
|
||||
type ModelMetaStatus int64 // 模型实体状态
|
||||
|
||||
const (
|
||||
StatusInUse ModelMetaStatus = 1 // 应用中,可使用可新建
|
||||
StatusPending ModelMetaStatus = 5 // 待下线,可使用不可新建
|
||||
StatusDeleted ModelMetaStatus = 10 // 已下线,不可使用不可新建
|
||||
)
|
||||
|
||||
type Widget string
|
||||
|
||||
const (
|
||||
WidgetSlider Widget = "slider"
|
||||
WidgetRadioButtons Widget = "radio_buttons"
|
||||
)
|
||||
|
||||
type ModelEntityStatus int64
|
||||
|
||||
const (
|
||||
ModelEntityStatusDefault ModelEntityStatus = 0
|
||||
ModelEntityStatusInUse ModelEntityStatus = 1
|
||||
ModelEntityStatusPending ModelEntityStatus = 5
|
||||
ModelEntityStatusDeleted ModelEntityStatus = 10
|
||||
)
|
||||
171
backend/api/model/crossdomain/modelmgr/modelmgr.go
Normal file
171
backend/api/model/crossdomain/modelmgr/modelmgr.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package modelmgr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/i18n"
|
||||
)
|
||||
|
||||
type MGetModelRequest struct {
|
||||
IDs []int64
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
ID int64 `yaml:"id"`
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
DefaultParameters []*Parameter `yaml:"default_parameters"`
|
||||
|
||||
CreatedAtMs int64
|
||||
UpdatedAtMs int64
|
||||
DeletedAtMs int64
|
||||
|
||||
Meta ModelMeta `yaml:"meta"`
|
||||
}
|
||||
|
||||
type Parameter struct {
|
||||
Name ParameterName `json:"name" yaml:"name"`
|
||||
Label *MultilingualText `json:"label,omitempty" yaml:"label,omitempty"`
|
||||
Desc *MultilingualText `json:"desc" yaml:"desc"`
|
||||
Type ValueType `json:"type" yaml:"type"`
|
||||
Min string `json:"min" yaml:"min"`
|
||||
Max string `json:"max" yaml:"max"`
|
||||
DefaultVal DefaultValue `json:"default_val" yaml:"default_val"`
|
||||
Precision int `json:"precision,omitempty" yaml:"precision,omitempty"` // float precision, default 2
|
||||
Options []*ParamOption `json:"options" yaml:"options"` // enum options
|
||||
Style DisplayStyle `json:"param_class" yaml:"style"`
|
||||
}
|
||||
|
||||
func (p *Parameter) GetFloat(tp DefaultType) (float64, error) {
|
||||
if p.Type != ValueTypeFloat {
|
||||
return 0, fmt.Errorf("unexpected paramerter type, name=%v, expect=%v, given=%v",
|
||||
p.Name, ValueTypeFloat, p.Type)
|
||||
}
|
||||
|
||||
if tp != DefaultTypeDefault && p.DefaultVal[tp] == "" {
|
||||
tp = DefaultTypeDefault
|
||||
}
|
||||
|
||||
val, ok := p.DefaultVal[tp]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("unexpected default type, name=%v, type=%v", p.Name, tp)
|
||||
}
|
||||
|
||||
return strconv.ParseFloat(val, 64)
|
||||
}
|
||||
|
||||
func (p *Parameter) GetInt(tp DefaultType) (int64, error) {
|
||||
if p.Type != ValueTypeInt {
|
||||
return 0, fmt.Errorf("unexpected paramerter type, name=%v, expect=%v, given=%v",
|
||||
p.Name, ValueTypeInt, p.Type)
|
||||
}
|
||||
|
||||
if tp != DefaultTypeDefault && p.DefaultVal[tp] == "" {
|
||||
tp = DefaultTypeDefault
|
||||
}
|
||||
val, ok := p.DefaultVal[tp]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("unexpected default type, name=%v, type=%v", p.Name, tp)
|
||||
}
|
||||
return strconv.ParseInt(val, 10, 64)
|
||||
}
|
||||
|
||||
func (p *Parameter) GetBool(tp DefaultType) (bool, error) {
|
||||
if p.Type != ValueTypeBoolean {
|
||||
return false, fmt.Errorf("unexpected paramerter type, name=%v, expect=%v, given=%v",
|
||||
p.Name, ValueTypeBoolean, p.Type)
|
||||
}
|
||||
if tp != DefaultTypeDefault && p.DefaultVal[tp] == "" {
|
||||
tp = DefaultTypeDefault
|
||||
}
|
||||
val, ok := p.DefaultVal[tp]
|
||||
if !ok {
|
||||
return false, fmt.Errorf("unexpected default type, name=%v, type=%v", p.Name, tp)
|
||||
}
|
||||
return strconv.ParseBool(val)
|
||||
}
|
||||
|
||||
func (p *Parameter) GetString(tp DefaultType) (string, error) {
|
||||
if tp != DefaultTypeDefault && p.DefaultVal[tp] == "" {
|
||||
tp = DefaultTypeDefault
|
||||
}
|
||||
|
||||
val, ok := p.DefaultVal[tp]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unexpected default type, name=%v, type=%v", p.Name, tp)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
type ModelMeta struct {
|
||||
ID int64 `yaml:"id"`
|
||||
Name string `yaml:"name"`
|
||||
IconURI string `yaml:"icon_uri"`
|
||||
IconURL string `yaml:"icon_url"`
|
||||
Description *MultilingualText `yaml:"description"`
|
||||
|
||||
CreatedAtMs int64
|
||||
UpdatedAtMs int64
|
||||
DeletedAtMs int64
|
||||
|
||||
Protocol chatmodel.Protocol `yaml:"protocol"` // 模型通信协议
|
||||
Capability *Capability `yaml:"capability"` // 模型能力
|
||||
ConnConfig *chatmodel.Config `yaml:"conn_config"` // 模型连接配置
|
||||
Status ModelMetaStatus `yaml:"status"` // 模型状态
|
||||
}
|
||||
|
||||
type DefaultValue map[DefaultType]string
|
||||
|
||||
type DisplayStyle struct {
|
||||
Widget Widget `json:"class_id" yaml:"widget"`
|
||||
Label *MultilingualText `json:"label" yaml:"label"`
|
||||
}
|
||||
|
||||
type ParamOption struct {
|
||||
Label string `json:"label"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
type Capability struct {
|
||||
// Model supports function calling
|
||||
FunctionCall bool `json:"function_call" yaml:"function_call" mapstructure:"function_call"`
|
||||
// Input modals
|
||||
InputModal []Modal `json:"input_modal,omitempty" yaml:"input_modal,omitempty" mapstructure:"input_modal,omitempty"`
|
||||
// Input tokens
|
||||
InputTokens int `json:"input_tokens" yaml:"input_tokens" mapstructure:"input_tokens"`
|
||||
// Model supports json mode
|
||||
JSONMode bool `json:"json_mode" yaml:"json_mode" mapstructure:"json_mode"`
|
||||
// Max tokens
|
||||
MaxTokens int `json:"max_tokens" yaml:"max_tokens" mapstructure:"max_tokens"`
|
||||
// Output modals
|
||||
OutputModal []Modal `json:"output_modal,omitempty" yaml:"output_modal,omitempty" mapstructure:"output_modal,omitempty"`
|
||||
// Output tokens
|
||||
OutputTokens int `json:"output_tokens" yaml:"output_tokens" mapstructure:"output_tokens"`
|
||||
// Model supports prefix caching
|
||||
PrefixCaching bool `json:"prefix_caching" yaml:"prefix_caching" mapstructure:"prefix_caching"`
|
||||
// Model supports reasoning
|
||||
Reasoning bool `json:"reasoning" yaml:"reasoning" mapstructure:"reasoning"`
|
||||
// Model supports prefill response
|
||||
PrefillResponse bool `json:"prefill_response" yaml:"prefill_response" mapstructure:"prefill_response"`
|
||||
}
|
||||
|
||||
type MultilingualText struct {
|
||||
ZH string `json:"zh,omitempty" yaml:"zh,omitempty"`
|
||||
EN string `json:"en,omitempty" yaml:"en,omitempty"`
|
||||
}
|
||||
|
||||
func (m *MultilingualText) Read(locale i18n.Locale) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
switch locale {
|
||||
case i18n.LocaleZH:
|
||||
return m.ZH
|
||||
case i18n.LocaleEN:
|
||||
return m.EN
|
||||
default:
|
||||
return m.EN
|
||||
}
|
||||
}
|
||||
110
backend/api/model/crossdomain/plugin/consts.go
Normal file
110
backend/api/model/crossdomain/plugin/consts.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package plugin
|
||||
|
||||
import "github.com/getkin/kin-openapi/openapi3"
|
||||
|
||||
type PluginType string
|
||||
|
||||
const (
|
||||
PluginTypeOfCloud PluginType = "openapi"
|
||||
)
|
||||
|
||||
type AuthzType string
|
||||
|
||||
const (
|
||||
AuthzTypeOfNone AuthzType = "none"
|
||||
AuthzTypeOfService AuthzType = "service_http"
|
||||
AuthzTypeOfOAuth AuthzType = "oauth"
|
||||
)
|
||||
|
||||
type AuthzSubType string
|
||||
|
||||
const (
|
||||
AuthzSubTypeOfServiceAPIToken AuthzSubType = "token/api_key"
|
||||
AuthzSubTypeOfOAuthAuthorizationCode AuthzSubType = "authorization_code"
|
||||
AuthzSubTypeOfOAuthClientCredentials AuthzSubType = "client_credentials"
|
||||
)
|
||||
|
||||
type HTTPParamLocation string
|
||||
|
||||
const (
|
||||
ParamInHeader HTTPParamLocation = openapi3.ParameterInHeader
|
||||
ParamInPath HTTPParamLocation = openapi3.ParameterInPath
|
||||
ParamInQuery HTTPParamLocation = openapi3.ParameterInQuery
|
||||
ParamInBody HTTPParamLocation = "body"
|
||||
)
|
||||
|
||||
type ActivatedStatus int32
|
||||
|
||||
const (
|
||||
ActivateTool ActivatedStatus = 0
|
||||
DeactivateTool ActivatedStatus = 1
|
||||
)
|
||||
|
||||
type ProjectType int8
|
||||
|
||||
const (
|
||||
ProjectTypeOfAgent ProjectType = 1
|
||||
ProjectTypeOfAPP ProjectType = 2
|
||||
)
|
||||
|
||||
type ExecuteScene string
|
||||
|
||||
const (
|
||||
ExecSceneOfOnlineAgent ExecuteScene = "online_agent"
|
||||
ExecSceneOfDraftAgent ExecuteScene = "draft_agent"
|
||||
ExecSceneOfWorkflow ExecuteScene = "workflow"
|
||||
ExecSceneOfToolDebug ExecuteScene = "tool_debug"
|
||||
)
|
||||
|
||||
type InvalidResponseProcessStrategy int8
|
||||
|
||||
const (
|
||||
InvalidResponseProcessStrategyOfReturnRaw InvalidResponseProcessStrategy = 0 // If the value of a field is invalid, the raw response value of the field is returned.
|
||||
InvalidResponseProcessStrategyOfReturnDefault InvalidResponseProcessStrategy = 1 // If the value of a field is invalid, the default value of the field is returned.
|
||||
)
|
||||
|
||||
const (
|
||||
APISchemaExtendAssistType = "x-assist-type"
|
||||
APISchemaExtendGlobalDisable = "x-global-disable"
|
||||
APISchemaExtendLocalDisable = "x-local-disable"
|
||||
APISchemaExtendVariableRef = "x-variable-ref"
|
||||
APISchemaExtendAuthMode = "x-auth-mode"
|
||||
)
|
||||
|
||||
type ToolAuthMode string
|
||||
|
||||
const (
|
||||
ToolAuthModeOfRequired ToolAuthMode = "required"
|
||||
ToolAuthModeOfSupported ToolAuthMode = "supported"
|
||||
ToolAuthModeOfDisabled ToolAuthMode = "disabled"
|
||||
)
|
||||
|
||||
type APIFileAssistType string
|
||||
|
||||
const (
|
||||
AssistTypeFile APIFileAssistType = "file"
|
||||
AssistTypeImage APIFileAssistType = "image"
|
||||
AssistTypeDoc APIFileAssistType = "doc"
|
||||
AssistTypePPT APIFileAssistType = "ppt"
|
||||
AssistTypeCode APIFileAssistType = "code"
|
||||
AssistTypeExcel APIFileAssistType = "excel"
|
||||
AssistTypeZIP APIFileAssistType = "zip"
|
||||
AssistTypeVideo APIFileAssistType = "video"
|
||||
AssistTypeAudio APIFileAssistType = "audio"
|
||||
AssistTypeTXT APIFileAssistType = "txt"
|
||||
)
|
||||
|
||||
type CopyScene string
|
||||
|
||||
const (
|
||||
CopySceneOfToAPP CopyScene = "to_app"
|
||||
CopySceneOfToLibrary CopyScene = "to_library"
|
||||
CopySceneOfDuplicate CopyScene = "duplicate"
|
||||
CopySceneOfAPPDuplicate CopyScene = "app_duplicate"
|
||||
)
|
||||
|
||||
type InterruptEventType string
|
||||
|
||||
const (
|
||||
InterruptEventTypeOfToolNeedOAuth InterruptEventType = "tool_need_oauth"
|
||||
)
|
||||
270
backend/api/model/crossdomain/plugin/convert.go
Normal file
270
backend/api/model/crossdomain/plugin/convert.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
|
||||
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
||||
)
|
||||
|
||||
var httpParamLocations = map[common.ParameterLocation]HTTPParamLocation{
|
||||
common.ParameterLocation_Path: ParamInPath,
|
||||
common.ParameterLocation_Query: ParamInQuery,
|
||||
common.ParameterLocation_Body: ParamInBody,
|
||||
common.ParameterLocation_Header: ParamInHeader,
|
||||
}
|
||||
|
||||
func ToHTTPParamLocation(loc common.ParameterLocation) (HTTPParamLocation, bool) {
|
||||
_loc, ok := httpParamLocations[loc]
|
||||
return _loc, ok
|
||||
}
|
||||
|
||||
var thriftHTTPParamLocations = func() map[HTTPParamLocation]common.ParameterLocation {
|
||||
locations := make(map[HTTPParamLocation]common.ParameterLocation, len(httpParamLocations))
|
||||
for k, v := range httpParamLocations {
|
||||
locations[v] = k
|
||||
}
|
||||
return locations
|
||||
}()
|
||||
|
||||
func ToThriftHTTPParamLocation(loc HTTPParamLocation) (common.ParameterLocation, bool) {
|
||||
_loc, ok := thriftHTTPParamLocations[loc]
|
||||
return _loc, ok
|
||||
}
|
||||
|
||||
var openapiTypes = map[common.ParameterType]string{
|
||||
common.ParameterType_String: openapi3.TypeString,
|
||||
common.ParameterType_Integer: openapi3.TypeInteger,
|
||||
common.ParameterType_Number: openapi3.TypeNumber,
|
||||
common.ParameterType_Object: openapi3.TypeObject,
|
||||
common.ParameterType_Array: openapi3.TypeArray,
|
||||
common.ParameterType_Bool: openapi3.TypeBoolean,
|
||||
}
|
||||
|
||||
func ToOpenapiParamType(typ common.ParameterType) (string, bool) {
|
||||
_typ, ok := openapiTypes[typ]
|
||||
return _typ, ok
|
||||
}
|
||||
|
||||
var thriftParameterTypes = func() map[string]common.ParameterType {
|
||||
types := make(map[string]common.ParameterType, len(openapiTypes))
|
||||
for k, v := range openapiTypes {
|
||||
types[v] = k
|
||||
}
|
||||
return types
|
||||
}()
|
||||
|
||||
func ToThriftParamType(typ string) (common.ParameterType, bool) {
|
||||
_typ, ok := thriftParameterTypes[typ]
|
||||
return _typ, ok
|
||||
}
|
||||
|
||||
var apiAssistTypes = map[common.AssistParameterType]APIFileAssistType{
|
||||
common.AssistParameterType_DEFAULT: AssistTypeFile,
|
||||
common.AssistParameterType_IMAGE: AssistTypeImage,
|
||||
common.AssistParameterType_DOC: AssistTypeDoc,
|
||||
common.AssistParameterType_PPT: AssistTypePPT,
|
||||
common.AssistParameterType_CODE: AssistTypeCode,
|
||||
common.AssistParameterType_EXCEL: AssistTypeExcel,
|
||||
common.AssistParameterType_ZIP: AssistTypeZIP,
|
||||
common.AssistParameterType_VIDEO: AssistTypeVideo,
|
||||
common.AssistParameterType_AUDIO: AssistTypeAudio,
|
||||
common.AssistParameterType_TXT: AssistTypeTXT,
|
||||
}
|
||||
|
||||
func ToAPIAssistType(typ common.AssistParameterType) (APIFileAssistType, bool) {
|
||||
_typ, ok := apiAssistTypes[typ]
|
||||
return _typ, ok
|
||||
}
|
||||
|
||||
var thriftAPIAssistTypes = func() map[APIFileAssistType]common.AssistParameterType {
|
||||
types := make(map[APIFileAssistType]common.AssistParameterType, len(apiAssistTypes))
|
||||
for k, v := range apiAssistTypes {
|
||||
types[v] = k
|
||||
}
|
||||
return types
|
||||
}()
|
||||
|
||||
func ToThriftAPIAssistType(typ APIFileAssistType) (common.AssistParameterType, bool) {
|
||||
_typ, ok := thriftAPIAssistTypes[typ]
|
||||
return _typ, ok
|
||||
}
|
||||
|
||||
func IsValidAPIAssistType(typ APIFileAssistType) bool {
|
||||
_, ok := thriftAPIAssistTypes[typ]
|
||||
return ok
|
||||
}
|
||||
|
||||
var httpMethods = map[common.APIMethod]string{
|
||||
common.APIMethod_GET: http.MethodGet,
|
||||
common.APIMethod_POST: http.MethodPost,
|
||||
common.APIMethod_PUT: http.MethodPut,
|
||||
common.APIMethod_DELETE: http.MethodDelete,
|
||||
common.APIMethod_PATCH: http.MethodPatch,
|
||||
}
|
||||
|
||||
var thriftAPIMethods = func() map[string]common.APIMethod {
|
||||
methods := make(map[string]common.APIMethod, len(httpMethods))
|
||||
for k, v := range httpMethods {
|
||||
methods[v] = k
|
||||
}
|
||||
return methods
|
||||
}()
|
||||
|
||||
func ToThriftAPIMethod(method string) (common.APIMethod, bool) {
|
||||
_method, ok := thriftAPIMethods[method]
|
||||
return _method, ok
|
||||
}
|
||||
|
||||
func ToHTTPMethod(method common.APIMethod) (string, bool) {
|
||||
_method, ok := httpMethods[method]
|
||||
return _method, ok
|
||||
}
|
||||
|
||||
var assistTypeToFormat = map[APIFileAssistType]string{
|
||||
AssistTypeFile: "file_url",
|
||||
AssistTypeImage: "image_url",
|
||||
AssistTypeDoc: "doc_url",
|
||||
AssistTypePPT: "ppt_url",
|
||||
AssistTypeCode: "code_url",
|
||||
AssistTypeExcel: "excel_url",
|
||||
AssistTypeZIP: "zip_url",
|
||||
AssistTypeVideo: "video_url",
|
||||
AssistTypeAudio: "audio_url",
|
||||
AssistTypeTXT: "txt_url",
|
||||
}
|
||||
|
||||
func AssistTypeToFormat(typ APIFileAssistType) (string, bool) {
|
||||
format, ok := assistTypeToFormat[typ]
|
||||
return format, ok
|
||||
}
|
||||
|
||||
var formatToAssistType = func() map[string]APIFileAssistType {
|
||||
types := make(map[string]APIFileAssistType, len(assistTypeToFormat))
|
||||
for k, v := range assistTypeToFormat {
|
||||
types[v] = k
|
||||
}
|
||||
return types
|
||||
}()
|
||||
|
||||
func FormatToAssistType(format string) (APIFileAssistType, bool) {
|
||||
typ, ok := formatToAssistType[format]
|
||||
return typ, ok
|
||||
}
|
||||
|
||||
var assistTypeToThriftFormat = map[APIFileAssistType]common.PluginParamTypeFormat{
|
||||
AssistTypeFile: common.PluginParamTypeFormat_FileUrl,
|
||||
AssistTypeImage: common.PluginParamTypeFormat_ImageUrl,
|
||||
AssistTypeDoc: common.PluginParamTypeFormat_DocUrl,
|
||||
AssistTypePPT: common.PluginParamTypeFormat_PptUrl,
|
||||
AssistTypeCode: common.PluginParamTypeFormat_CodeUrl,
|
||||
AssistTypeExcel: common.PluginParamTypeFormat_ExcelUrl,
|
||||
AssistTypeZIP: common.PluginParamTypeFormat_ZipUrl,
|
||||
AssistTypeVideo: common.PluginParamTypeFormat_VideoUrl,
|
||||
AssistTypeAudio: common.PluginParamTypeFormat_AudioUrl,
|
||||
AssistTypeTXT: common.PluginParamTypeFormat_TxtUrl,
|
||||
}
|
||||
|
||||
func AssistTypeToThriftFormat(typ APIFileAssistType) (common.PluginParamTypeFormat, bool) {
|
||||
format, ok := assistTypeToThriftFormat[typ]
|
||||
return format, ok
|
||||
}
|
||||
|
||||
var authTypes = map[common.AuthorizationType]AuthzType{
|
||||
common.AuthorizationType_None: AuthzTypeOfNone,
|
||||
common.AuthorizationType_Service: AuthzTypeOfService,
|
||||
common.AuthorizationType_OAuth: AuthzTypeOfOAuth,
|
||||
common.AuthorizationType_Standard: AuthzTypeOfOAuth, // deprecated, the same as OAuth
|
||||
}
|
||||
|
||||
func ToAuthType(typ common.AuthorizationType) (AuthzType, bool) {
|
||||
_type, ok := authTypes[typ]
|
||||
return _type, ok
|
||||
}
|
||||
|
||||
var thriftAuthTypes = func() map[AuthzType]common.AuthorizationType {
|
||||
types := make(map[AuthzType]common.AuthorizationType, len(authTypes))
|
||||
for k, v := range authTypes {
|
||||
if v == AuthzTypeOfOAuth {
|
||||
types[v] = common.AuthorizationType_OAuth
|
||||
} else {
|
||||
types[v] = k
|
||||
}
|
||||
}
|
||||
return types
|
||||
}()
|
||||
|
||||
func ToThriftAuthType(typ AuthzType) (common.AuthorizationType, bool) {
|
||||
_type, ok := thriftAuthTypes[typ]
|
||||
return _type, ok
|
||||
}
|
||||
|
||||
var subAuthTypes = map[int32]AuthzSubType{
|
||||
int32(common.ServiceAuthSubType_ApiKey): AuthzSubTypeOfServiceAPIToken,
|
||||
int32(common.ServiceAuthSubType_OAuthAuthorizationCode): AuthzSubTypeOfOAuthAuthorizationCode,
|
||||
}
|
||||
|
||||
func ToAuthSubType(typ int32) (AuthzSubType, bool) {
|
||||
_type, ok := subAuthTypes[typ]
|
||||
return _type, ok
|
||||
}
|
||||
|
||||
var thriftSubAuthTypes = func() map[AuthzSubType]int32 {
|
||||
types := make(map[AuthzSubType]int32, len(subAuthTypes))
|
||||
for k, v := range subAuthTypes {
|
||||
types[v] = int32(k)
|
||||
}
|
||||
return types
|
||||
}()
|
||||
|
||||
func ToThriftAuthSubType(typ AuthzSubType) (int32, bool) {
|
||||
_type, ok := thriftSubAuthTypes[typ]
|
||||
return _type, ok
|
||||
}
|
||||
|
||||
var pluginTypes = map[common.PluginType]PluginType{
|
||||
common.PluginType_PLUGIN: PluginTypeOfCloud,
|
||||
}
|
||||
|
||||
func ToPluginType(typ common.PluginType) (PluginType, bool) {
|
||||
_type, ok := pluginTypes[typ]
|
||||
return _type, ok
|
||||
}
|
||||
|
||||
var thriftPluginTypes = func() map[PluginType]common.PluginType {
|
||||
types := make(map[PluginType]common.PluginType, len(pluginTypes))
|
||||
for k, v := range pluginTypes {
|
||||
types[v] = k
|
||||
}
|
||||
return types
|
||||
}()
|
||||
|
||||
func ToThriftPluginType(typ PluginType) (common.PluginType, bool) {
|
||||
_type, ok := thriftPluginTypes[typ]
|
||||
return _type, ok
|
||||
}
|
||||
|
||||
var apiAuthModes = map[common.PluginToolAuthType]ToolAuthMode{
|
||||
common.PluginToolAuthType_Required: ToolAuthModeOfRequired,
|
||||
common.PluginToolAuthType_Supported: ToolAuthModeOfSupported,
|
||||
common.PluginToolAuthType_Disable: ToolAuthModeOfDisabled,
|
||||
}
|
||||
|
||||
func ToAPIAuthMode(mode common.PluginToolAuthType) (ToolAuthMode, bool) {
|
||||
_mode, ok := apiAuthModes[mode]
|
||||
return _mode, ok
|
||||
}
|
||||
|
||||
var thriftAPIAuthModes = func() map[ToolAuthMode]common.PluginToolAuthType {
|
||||
modes := make(map[ToolAuthMode]common.PluginToolAuthType, len(apiAuthModes))
|
||||
for k, v := range apiAuthModes {
|
||||
modes[v] = k
|
||||
}
|
||||
return modes
|
||||
}()
|
||||
|
||||
func ToThriftAPIAuthMode(mode ToolAuthMode) (common.PluginToolAuthType, bool) {
|
||||
_mode, ok := thriftAPIAuthModes[mode]
|
||||
return _mode, ok
|
||||
}
|
||||
429
backend/api/model/crossdomain/plugin/openai.go
Normal file
429
backend/api/model/crossdomain/plugin/openai.go
Normal file
@@ -0,0 +1,429 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
type Openapi3T openapi3.T
|
||||
|
||||
func (ot Openapi3T) Validate(ctx context.Context) (err error) {
|
||||
err = ptr.Of(openapi3.T(ot)).Validate(ctx)
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey, err.Error()))
|
||||
}
|
||||
|
||||
if ot.Info == nil {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"info is required"))
|
||||
}
|
||||
if ot.Info.Title == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"the title of info is required"))
|
||||
}
|
||||
if ot.Info.Description == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"the description of info is required"))
|
||||
}
|
||||
|
||||
if len(ot.Servers) != 1 {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"server is required and only one server is allowed"))
|
||||
}
|
||||
|
||||
serverURL := ot.Servers[0].URL
|
||||
urlSchema, err := url.Parse(serverURL)
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid server url '%s'", serverURL))
|
||||
}
|
||||
if urlSchema.Scheme != "https" {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"server url must start with 'https://'"))
|
||||
}
|
||||
if urlSchema.Host == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid server url '%s'", serverURL))
|
||||
}
|
||||
if len(serverURL) > 512 {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
|
||||
"server url '%s' is too long", serverURL))
|
||||
}
|
||||
|
||||
for _, pathItem := range ot.Paths {
|
||||
for _, op := range pathItem.Operations() {
|
||||
err = NewOpenapi3Operation(op).Validate(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewOpenapi3Operation(op *openapi3.Operation) *Openapi3Operation {
|
||||
return &Openapi3Operation{
|
||||
Operation: op,
|
||||
}
|
||||
}
|
||||
|
||||
type Openapi3Operation struct {
|
||||
*openapi3.Operation
|
||||
}
|
||||
|
||||
func (op *Openapi3Operation) MarshalJSON() ([]byte, error) {
|
||||
return op.Operation.MarshalJSON()
|
||||
}
|
||||
|
||||
func (op *Openapi3Operation) UnmarshalJSON(data []byte) error {
|
||||
op.Operation = &openapi3.Operation{}
|
||||
return op.Operation.UnmarshalJSON(data)
|
||||
}
|
||||
|
||||
func (op *Openapi3Operation) Validate(ctx context.Context) (err error) {
|
||||
err = op.Operation.Validate(ctx)
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey, "operation is invalid, err=%s", err))
|
||||
}
|
||||
|
||||
if op.OperationID == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey, "operationID is required"))
|
||||
}
|
||||
if op.Summary == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey, "summary is required"))
|
||||
}
|
||||
|
||||
err = validateOpenapi3RequestBody(op.RequestBody)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateOpenapi3Parameters(op.Parameters)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateOpenapi3Responses(op.Responses)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (op *Openapi3Operation) ToEinoSchemaParameterInfo(ctx context.Context) (map[string]*schema.ParameterInfo, error) {
|
||||
convertType := func(openapiType string) schema.DataType {
|
||||
switch openapiType {
|
||||
case openapi3.TypeString:
|
||||
return schema.String
|
||||
case openapi3.TypeInteger:
|
||||
return schema.Integer
|
||||
case openapi3.TypeObject:
|
||||
return schema.Object
|
||||
case openapi3.TypeArray:
|
||||
return schema.Array
|
||||
case openapi3.TypeBoolean:
|
||||
return schema.Boolean
|
||||
case openapi3.TypeNumber:
|
||||
return schema.Number
|
||||
default:
|
||||
return schema.Null
|
||||
}
|
||||
}
|
||||
|
||||
var convertReqBody func(sc *openapi3.Schema, isRequired bool) (*schema.ParameterInfo, error)
|
||||
convertReqBody = func(sc *openapi3.Schema, isRequired bool) (*schema.ParameterInfo, error) {
|
||||
if disabledParam(sc) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
paramInfo := &schema.ParameterInfo{
|
||||
Type: convertType(sc.Type),
|
||||
Desc: sc.Description,
|
||||
Required: isRequired,
|
||||
}
|
||||
|
||||
switch sc.Type {
|
||||
case openapi3.TypeObject:
|
||||
required := slices.ToMap(sc.Required, func(e string) (string, bool) {
|
||||
return e, true
|
||||
})
|
||||
|
||||
subParams := make(map[string]*schema.ParameterInfo, len(sc.Properties))
|
||||
for paramName, prop := range sc.Properties {
|
||||
subParam, err := convertReqBody(prop.Value, required[paramName])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
subParams[paramName] = subParam
|
||||
}
|
||||
|
||||
paramInfo.SubParams = subParams
|
||||
|
||||
case openapi3.TypeArray:
|
||||
ele, err := convertReqBody(sc.Items.Value, isRequired)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramInfo.ElemInfo = ele
|
||||
|
||||
case openapi3.TypeString, openapi3.TypeInteger, openapi3.TypeBoolean, openapi3.TypeNumber:
|
||||
return paramInfo, nil
|
||||
|
||||
default:
|
||||
return nil, errorx.New(errno.ErrSearchInvalidParamCode, errorx.KVf(errno.PluginMsgKey,
|
||||
"unsupported json type '%s'", sc.Type))
|
||||
}
|
||||
|
||||
return paramInfo, nil
|
||||
}
|
||||
|
||||
result := make(map[string]*schema.ParameterInfo)
|
||||
|
||||
for _, prop := range op.Parameters {
|
||||
paramVal := prop.Value
|
||||
schemaVal := paramVal.Schema.Value
|
||||
if schemaVal.Type == openapi3.TypeObject || schemaVal.Type == openapi3.TypeArray {
|
||||
continue
|
||||
}
|
||||
|
||||
if disabledParam(prop.Value.Schema.Value) {
|
||||
continue
|
||||
}
|
||||
|
||||
paramInfo := &schema.ParameterInfo{
|
||||
Type: convertType(schemaVal.Type),
|
||||
Desc: paramVal.Description,
|
||||
Required: paramVal.Required,
|
||||
}
|
||||
|
||||
if _, ok := result[paramVal.Name]; ok {
|
||||
logs.CtxWarnf(ctx, "duplicate parameter name '%s'", paramVal.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
result[paramVal.Name] = paramInfo
|
||||
}
|
||||
|
||||
if op.RequestBody == nil || op.RequestBody.Value == nil || len(op.RequestBody.Value.Content) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
for _, mType := range op.RequestBody.Value.Content {
|
||||
schemaVal := mType.Schema.Value
|
||||
if len(schemaVal.Properties) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
required := slices.ToMap(schemaVal.Required, func(e string) (string, bool) {
|
||||
return e, true
|
||||
})
|
||||
|
||||
for paramName, prop := range schemaVal.Properties {
|
||||
paramInfo, err := convertReqBody(prop.Value, required[paramName])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, ok := result[paramName]; ok {
|
||||
logs.CtxWarnf(ctx, "duplicate parameter name '%s'", paramName)
|
||||
continue
|
||||
}
|
||||
|
||||
result[paramName] = paramInfo
|
||||
}
|
||||
|
||||
break // 只取一种 MIME
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func validateOpenapi3RequestBody(bodyRef *openapi3.RequestBodyRef) (err error) {
|
||||
if bodyRef == nil {
|
||||
return nil
|
||||
}
|
||||
if bodyRef.Value == nil || len(bodyRef.Value.Content) == 0 {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"request body is required"))
|
||||
}
|
||||
|
||||
body := bodyRef.Value
|
||||
if len(body.Content) != 1 {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"request body only supports one media type"))
|
||||
}
|
||||
|
||||
var mType *openapi3.MediaType
|
||||
for _, ct := range mediaTypeArray {
|
||||
var ok bool
|
||||
mType, ok = body.Content[ct]
|
||||
if ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
if mType == nil {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid media type, request body only the following types: [%s]", strings.Join(mediaTypeArray, ", ")))
|
||||
}
|
||||
|
||||
if mType.Schema == nil || mType.Schema.Value == nil {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"request body schema is required"))
|
||||
}
|
||||
|
||||
sc := mType.Schema.Value
|
||||
if sc.Type == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"request body only supports 'object' type"))
|
||||
}
|
||||
if sc.Type != openapi3.TypeObject {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"request body only supports 'object' type"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateOpenapi3Parameters(params openapi3.Parameters) (err error) {
|
||||
if len(params) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, param := range params {
|
||||
if param == nil || param.Value == nil || param.Value.Schema == nil || param.Value.Schema.Value == nil {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"parameter schema is required"))
|
||||
}
|
||||
|
||||
paramVal := param.Value
|
||||
|
||||
if paramVal.In == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"parameter location is required"))
|
||||
}
|
||||
if paramVal.In == string(ParamInBody) {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
|
||||
"the location of parameter '%s' cannot be 'body'", paramVal.Name))
|
||||
}
|
||||
|
||||
paramSchema := paramVal.Schema.Value
|
||||
if paramSchema.Type == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
|
||||
"the type of parameter '%s' is required", paramVal.Name))
|
||||
}
|
||||
if paramSchema.Type == openapi3.TypeObject {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
|
||||
"the type of parameter '%s' cannot be 'object'", paramVal.Name))
|
||||
}
|
||||
if paramVal.In == openapi3.ParameterInPath && paramSchema.Type == openapi3.TypeArray {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KVf(errno.PluginMsgKey,
|
||||
"the type of parameter '%s' cannot be 'array'", paramVal.Name))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MIME Type
|
||||
const (
|
||||
MediaTypeJson = "application/json"
|
||||
MediaTypeProblemJson = "application/problem+json"
|
||||
MediaTypeFormURLEncoded = "application/x-www-form-urlencoded"
|
||||
MediaTypeXYaml = "application/x-yaml"
|
||||
MediaTypeYaml = "application/yaml"
|
||||
)
|
||||
|
||||
var mediaTypeArray = []string{
|
||||
MediaTypeJson,
|
||||
MediaTypeProblemJson,
|
||||
MediaTypeFormURLEncoded,
|
||||
MediaTypeXYaml,
|
||||
MediaTypeYaml,
|
||||
}
|
||||
|
||||
func validateOpenapi3Responses(responses openapi3.Responses) (err error) {
|
||||
if len(responses) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// default status 不处理
|
||||
// 只处理 '200' status
|
||||
if len(responses) != 1 {
|
||||
if len(responses) != 2 {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"response only supports '200' status"))
|
||||
} else if _, ok := responses["default"]; !ok {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"response only supports '200' status"))
|
||||
}
|
||||
}
|
||||
|
||||
resp, ok := responses[strconv.Itoa(http.StatusOK)]
|
||||
if !ok || resp == nil {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"response only supports '200' status"))
|
||||
}
|
||||
if resp.Value == nil {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"response schema is required"))
|
||||
}
|
||||
if len(resp.Value.Content) != 1 {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"response only supports 'application/json' media type"))
|
||||
}
|
||||
mType, ok := resp.Value.Content[MediaTypeJson]
|
||||
if !ok || mType == nil {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"response only supports 'application/json' media type"))
|
||||
|
||||
}
|
||||
if mType.Schema == nil || mType.Schema.Value == nil {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"the media type schema of response is required"))
|
||||
}
|
||||
|
||||
sc := mType.Schema.Value
|
||||
if sc.Type == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"response body only supports 'object' type"))
|
||||
}
|
||||
if sc.Type != openapi3.TypeObject {
|
||||
return errorx.New(errno.ErrPluginInvalidOpenapi3Doc, errorx.KV(errno.PluginMsgKey,
|
||||
"response body only supports 'object' type"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func disabledParam(schemaVal *openapi3.Schema) bool {
|
||||
if len(schemaVal.Extensions) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
globalDisable, localDisable := false, false
|
||||
if v, ok := schemaVal.Extensions[APISchemaExtendLocalDisable]; ok {
|
||||
localDisable = v.(bool)
|
||||
}
|
||||
|
||||
if v, ok := schemaVal.Extensions[APISchemaExtendGlobalDisable]; ok {
|
||||
globalDisable = v.(bool)
|
||||
}
|
||||
|
||||
return globalDisable || localDisable
|
||||
}
|
||||
51
backend/api/model/crossdomain/plugin/option.go
Normal file
51
backend/api/model/crossdomain/plugin/option.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package plugin
|
||||
|
||||
type ExecuteToolOption struct {
|
||||
ProjectInfo *ProjectInfo
|
||||
|
||||
AutoGenRespSchema bool
|
||||
|
||||
ToolVersion string
|
||||
Operation *Openapi3Operation
|
||||
InvalidRespProcessStrategy InvalidResponseProcessStrategy
|
||||
}
|
||||
|
||||
type ExecuteToolOpt func(o *ExecuteToolOption)
|
||||
|
||||
type ProjectInfo struct {
|
||||
ProjectID int64 // agentID or appID
|
||||
ProjectVersion *string // if version si nil, use latest version
|
||||
ProjectType ProjectType // agent or app
|
||||
|
||||
ConnectorID int64
|
||||
}
|
||||
|
||||
func WithProjectInfo(info *ProjectInfo) ExecuteToolOpt {
|
||||
return func(o *ExecuteToolOption) {
|
||||
o.ProjectInfo = info
|
||||
}
|
||||
}
|
||||
|
||||
func WithToolVersion(version string) ExecuteToolOpt {
|
||||
return func(o *ExecuteToolOption) {
|
||||
o.ToolVersion = version
|
||||
}
|
||||
}
|
||||
|
||||
func WithOpenapiOperation(op *Openapi3Operation) ExecuteToolOpt {
|
||||
return func(o *ExecuteToolOption) {
|
||||
o.Operation = op
|
||||
}
|
||||
}
|
||||
|
||||
func WithInvalidRespProcessStrategy(strategy InvalidResponseProcessStrategy) ExecuteToolOpt {
|
||||
return func(o *ExecuteToolOption) {
|
||||
o.InvalidRespProcessStrategy = strategy
|
||||
}
|
||||
}
|
||||
|
||||
func WithAutoGenRespSchema() ExecuteToolOpt {
|
||||
return func(o *ExecuteToolOption) {
|
||||
o.AutoGenRespSchema = true
|
||||
}
|
||||
}
|
||||
153
backend/api/model/crossdomain/plugin/plugin.go
Normal file
153
backend/api/model/crossdomain/plugin/plugin.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
|
||||
api "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
||||
)
|
||||
|
||||
type VersionPlugin struct {
|
||||
PluginID int64
|
||||
Version string
|
||||
}
|
||||
|
||||
type VersionTool struct {
|
||||
ToolID int64
|
||||
Version string
|
||||
}
|
||||
|
||||
type MGetPluginLatestVersionResponse struct {
|
||||
Versions map[int64]string // pluginID vs version
|
||||
}
|
||||
|
||||
type PluginInfo struct {
|
||||
ID int64
|
||||
PluginType api.PluginType
|
||||
SpaceID int64
|
||||
DeveloperID int64
|
||||
APPID *int64
|
||||
RefProductID *int64 // for product plugin
|
||||
IconURI *string
|
||||
ServerURL *string
|
||||
Version *string
|
||||
VersionDesc *string
|
||||
|
||||
CreatedAt int64
|
||||
UpdatedAt int64
|
||||
|
||||
Manifest *PluginManifest
|
||||
OpenapiDoc *Openapi3T
|
||||
}
|
||||
|
||||
func (p PluginInfo) SetName(name string) {
|
||||
if p.Manifest == nil || p.OpenapiDoc == nil {
|
||||
return
|
||||
}
|
||||
p.Manifest.NameForModel = name
|
||||
p.Manifest.NameForHuman = name
|
||||
p.OpenapiDoc.Info.Title = name
|
||||
}
|
||||
|
||||
func (p PluginInfo) GetName() string {
|
||||
if p.Manifest == nil {
|
||||
return ""
|
||||
}
|
||||
return p.Manifest.NameForHuman
|
||||
}
|
||||
|
||||
func (p PluginInfo) GetDesc() string {
|
||||
if p.Manifest == nil {
|
||||
return ""
|
||||
}
|
||||
return p.Manifest.DescriptionForHuman
|
||||
}
|
||||
|
||||
func (p PluginInfo) GetAuthInfo() *AuthV2 {
|
||||
if p.Manifest == nil {
|
||||
return nil
|
||||
}
|
||||
return p.Manifest.Auth
|
||||
}
|
||||
|
||||
func (p PluginInfo) IsOfficial() bool {
|
||||
return p.RefProductID != nil
|
||||
}
|
||||
|
||||
func (p PluginInfo) GetIconURI() string {
|
||||
if p.IconURI == nil {
|
||||
return ""
|
||||
}
|
||||
return *p.IconURI
|
||||
}
|
||||
|
||||
func (p PluginInfo) Published() bool {
|
||||
return p.Version != nil
|
||||
}
|
||||
|
||||
type VersionAgentTool struct {
|
||||
ToolName *string
|
||||
ToolID int64
|
||||
|
||||
AgentVersion *string
|
||||
}
|
||||
|
||||
type MGetAgentToolsRequest struct {
|
||||
AgentID int64
|
||||
SpaceID int64
|
||||
IsDraft bool
|
||||
|
||||
VersionAgentTools []VersionAgentTool
|
||||
}
|
||||
|
||||
type ExecuteToolRequest struct {
|
||||
UserID string
|
||||
PluginID int64
|
||||
ToolID int64
|
||||
ExecDraftTool bool // if true, execute draft tool
|
||||
ExecScene ExecuteScene
|
||||
|
||||
ArgumentsInJson string
|
||||
}
|
||||
|
||||
type ExecuteToolResponse struct {
|
||||
Tool *ToolInfo
|
||||
Request string
|
||||
TrimmedResp string
|
||||
RawResp string
|
||||
|
||||
RespSchema openapi3.Responses
|
||||
}
|
||||
|
||||
type PublishPluginRequest struct {
|
||||
PluginID int64
|
||||
Version string
|
||||
VersionDesc string
|
||||
}
|
||||
|
||||
type PublishAPPPluginsRequest struct {
|
||||
APPID int64
|
||||
Version string
|
||||
}
|
||||
|
||||
type PublishAPPPluginsResponse struct {
|
||||
FailedPlugins []*PluginInfo
|
||||
AllDraftPlugins []*PluginInfo
|
||||
}
|
||||
|
||||
type CheckCanPublishPluginsRequest struct {
|
||||
PluginIDs []int64
|
||||
Version string
|
||||
}
|
||||
|
||||
type CheckCanPublishPluginsResponse struct {
|
||||
InvalidPlugins []*PluginInfo
|
||||
}
|
||||
|
||||
type ToolInterruptEvent struct {
|
||||
Event InterruptEventType
|
||||
ToolNeedOAuth *ToolNeedOAuthInterruptEvent
|
||||
}
|
||||
|
||||
type ToolNeedOAuthInterruptEvent struct {
|
||||
Message string
|
||||
}
|
||||
497
backend/api/model/crossdomain/plugin/plugin_manifest.go
Normal file
497
backend/api/model/crossdomain/plugin/plugin_manifest.go
Normal file
@@ -0,0 +1,497 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
api "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/utils"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
)
|
||||
|
||||
type PluginManifest struct {
|
||||
SchemaVersion string `json:"schema_version" yaml:"schema_version"`
|
||||
NameForModel string `json:"name_for_model" yaml:"name_for_model"`
|
||||
NameForHuman string `json:"name_for_human" yaml:"name_for_human"`
|
||||
DescriptionForModel string `json:"description_for_model" yaml:"description_for_model"`
|
||||
DescriptionForHuman string `json:"description_for_human" yaml:"description_for_human"`
|
||||
Auth *AuthV2 `json:"auth" yaml:"auth"`
|
||||
LogoURL string `json:"logo_url" yaml:"logo_url"`
|
||||
API APIDesc `json:"api" yaml:"api"`
|
||||
CommonParams map[HTTPParamLocation][]*api.CommonParamSchema `json:"common_params" yaml:"common_params"`
|
||||
}
|
||||
|
||||
func (mf *PluginManifest) Copy() (*PluginManifest, error) {
|
||||
if mf == nil {
|
||||
return mf, nil
|
||||
}
|
||||
|
||||
b, err := json.Marshal(mf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mf_ := &PluginManifest{}
|
||||
err = json.Unmarshal(b, mf_)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mf_, err
|
||||
}
|
||||
|
||||
func (mf *PluginManifest) EncryptAuthPayload() (*PluginManifest, error) {
|
||||
if mf == nil || mf.Auth == nil {
|
||||
return mf, nil
|
||||
}
|
||||
|
||||
mf_, err := mf.Copy()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if mf_.Auth.Payload == "" {
|
||||
return mf_, nil
|
||||
}
|
||||
|
||||
payload_, err := utils.EncryptByAES([]byte(mf_.Auth.Payload), utils.AuthSecretKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mf_.Auth.Payload = payload_
|
||||
|
||||
return mf_, nil
|
||||
}
|
||||
|
||||
func (mf *PluginManifest) Validate(skipAuthPayload bool) (err error) {
|
||||
if mf.SchemaVersion != "v1" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid schema version '%s'", mf.SchemaVersion))
|
||||
}
|
||||
if mf.NameForModel == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"name for model is required"))
|
||||
}
|
||||
if mf.NameForHuman == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"name for human is required"))
|
||||
}
|
||||
if mf.DescriptionForModel == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"description for model is required"))
|
||||
}
|
||||
if mf.DescriptionForHuman == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"description for human is required"))
|
||||
}
|
||||
if mf.API.Type != PluginTypeOfCloud {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid api type '%s'", mf.API.Type))
|
||||
}
|
||||
|
||||
err = mf.validateAuthInfo(skipAuthPayload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for loc := range mf.CommonParams {
|
||||
if loc != ParamInBody &&
|
||||
loc != ParamInHeader &&
|
||||
loc != ParamInQuery &&
|
||||
loc != ParamInPath {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid location '%s' in common params", loc))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mf *PluginManifest) validateAuthInfo(skipAuthPayload bool) (err error) {
|
||||
if mf.Auth == nil {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"auth is required"))
|
||||
}
|
||||
|
||||
if mf.Auth.Payload != "" {
|
||||
js := json.RawMessage{}
|
||||
err = sonic.UnmarshalString(mf.Auth.Payload, &js)
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid auth payload"))
|
||||
}
|
||||
}
|
||||
|
||||
if mf.Auth.Type == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"auth type is required"))
|
||||
}
|
||||
|
||||
if mf.Auth.Type != AuthzTypeOfNone &&
|
||||
mf.Auth.Type != AuthzTypeOfOAuth &&
|
||||
mf.Auth.Type != AuthzTypeOfService {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid auth type '%s'", mf.Auth.Type))
|
||||
}
|
||||
|
||||
if mf.Auth.Type == AuthzTypeOfNone {
|
||||
return nil
|
||||
}
|
||||
|
||||
if mf.Auth.SubType == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"sub-auth type is required"))
|
||||
}
|
||||
|
||||
switch mf.Auth.SubType {
|
||||
case AuthzSubTypeOfServiceAPIToken:
|
||||
err = mf.validateServiceToken(skipAuthPayload)
|
||||
//case AuthzSubTypeOfOAuthClientCredentials:
|
||||
// err = mf.validateClientCredentials()
|
||||
case AuthzSubTypeOfOAuthAuthorizationCode:
|
||||
err = mf.validateAuthCode(skipAuthPayload)
|
||||
default:
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid sub-auth type '%s'", mf.Auth.SubType))
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mf *PluginManifest) validateServiceToken(skipAuthPayload bool) (err error) {
|
||||
if mf.Auth.AuthOfAPIToken == nil {
|
||||
err = sonic.UnmarshalString(mf.Auth.Payload, &mf.Auth.AuthOfAPIToken)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid auth payload"))
|
||||
}
|
||||
}
|
||||
|
||||
if skipAuthPayload {
|
||||
return nil
|
||||
}
|
||||
|
||||
apiToken := mf.Auth.AuthOfAPIToken
|
||||
|
||||
if apiToken.ServiceToken == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"service token is required"))
|
||||
}
|
||||
if apiToken.Key == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"key is required"))
|
||||
}
|
||||
|
||||
loc := HTTPParamLocation(strings.ToLower(string(apiToken.Location)))
|
||||
if loc != ParamInHeader && loc != ParamInQuery {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid location '%s'", apiToken.Location))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mf *PluginManifest) validateClientCredentials() (err error) {
|
||||
if mf.Auth.AuthOfOAuthClientCredentials == nil {
|
||||
err = sonic.UnmarshalString(mf.Auth.Payload, &mf.Auth.AuthOfOAuthClientCredentials)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid auth payload"))
|
||||
}
|
||||
}
|
||||
|
||||
clientCredentials := mf.Auth.AuthOfOAuthClientCredentials
|
||||
|
||||
if clientCredentials.ClientID == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"client id is required"))
|
||||
}
|
||||
if clientCredentials.ClientSecret == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"client secret is required"))
|
||||
}
|
||||
if clientCredentials.TokenURL == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"token url is required"))
|
||||
}
|
||||
|
||||
urlParse, err := url.Parse(clientCredentials.TokenURL)
|
||||
if err != nil || urlParse.Hostname() == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid token url"))
|
||||
}
|
||||
if urlParse.Scheme != "https" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"token url scheme must be 'https'"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mf *PluginManifest) validateAuthCode(skipAuthPayload bool) (err error) {
|
||||
if mf.Auth.AuthOfOAuthAuthorizationCode == nil {
|
||||
err = sonic.UnmarshalString(mf.Auth.Payload, &mf.Auth.AuthOfOAuthAuthorizationCode)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid auth payload"))
|
||||
}
|
||||
}
|
||||
|
||||
if skipAuthPayload {
|
||||
return nil
|
||||
}
|
||||
|
||||
authCode := mf.Auth.AuthOfOAuthAuthorizationCode
|
||||
|
||||
if authCode.ClientID == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"client id is required"))
|
||||
}
|
||||
if authCode.ClientSecret == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"client secret is required"))
|
||||
}
|
||||
if authCode.AuthorizationContentType != MediaTypeJson {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"authorization content type must be 'application/json'"))
|
||||
}
|
||||
if authCode.AuthorizationURL == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"token url is required"))
|
||||
}
|
||||
if authCode.ClientURL == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"client url is required"))
|
||||
}
|
||||
|
||||
urlParse, err := url.Parse(authCode.AuthorizationURL)
|
||||
if err != nil || urlParse.Hostname() == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid authorization url"))
|
||||
}
|
||||
if urlParse.Scheme != "https" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"authorization url scheme must be 'https'"))
|
||||
}
|
||||
|
||||
urlParse, err = url.Parse(authCode.ClientURL)
|
||||
if err != nil || urlParse.Hostname() == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid client url"))
|
||||
}
|
||||
if urlParse.Scheme != "https" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"client url scheme must be 'https'"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type Auth struct {
|
||||
Type string `json:"type" validate:"required"`
|
||||
AuthorizationType string `json:"authorization_type,omitempty"`
|
||||
ClientURL string `json:"client_url,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
AuthorizationURL string `json:"authorization_url,omitempty"`
|
||||
AuthorizationContentType string `json:"authorization_content_type,omitempty"`
|
||||
Platform string `json:"platform,omitempty"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
Location string `json:"location,omitempty"`
|
||||
Key string `json:"key,omitempty"`
|
||||
ServiceToken string `json:"service_token,omitempty"`
|
||||
SubType string `json:"sub_type"`
|
||||
Payload string `json:"payload"`
|
||||
}
|
||||
|
||||
type AuthV2 struct {
|
||||
Type AuthzType `json:"type" yaml:"type"`
|
||||
SubType AuthzSubType `json:"sub_type" yaml:"sub_type"`
|
||||
Payload string `json:"payload" yaml:"payload"`
|
||||
// service
|
||||
AuthOfAPIToken *AuthOfAPIToken `json:"-"`
|
||||
|
||||
// oauth
|
||||
AuthOfOAuthAuthorizationCode *OAuthAuthorizationCodeConfig `json:"-"`
|
||||
AuthOfOAuthClientCredentials *OAuthClientCredentialsConfig `json:"-"`
|
||||
}
|
||||
|
||||
func (au *AuthV2) UnmarshalJSON(data []byte) error {
|
||||
auth := &Auth{} // 兼容老数据
|
||||
err := json.Unmarshal(data, auth)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid plugin manifest json"))
|
||||
}
|
||||
|
||||
au.Type = AuthzType(auth.Type)
|
||||
au.SubType = AuthzSubType(auth.SubType)
|
||||
|
||||
if au.Type == "" {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"plugin auth type is required"))
|
||||
}
|
||||
|
||||
if auth.Payload != "" {
|
||||
payload_, err := utils.DecryptByAES(auth.Payload, utils.AuthSecretKey)
|
||||
if err == nil {
|
||||
auth.Payload = string(payload_)
|
||||
}
|
||||
}
|
||||
|
||||
switch au.Type {
|
||||
case AuthzTypeOfNone:
|
||||
case AuthzTypeOfOAuth:
|
||||
err = au.unmarshalOAuth(auth)
|
||||
case AuthzTypeOfService:
|
||||
err = au.unmarshalService(auth)
|
||||
default:
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid plugin auth type '%s'", au.Type))
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (au *AuthV2) unmarshalService(auth *Auth) (err error) {
|
||||
if au.SubType == "" && au.Payload == "" { // 兼容老数据
|
||||
au.SubType = AuthzSubTypeOfServiceAPIToken
|
||||
}
|
||||
|
||||
var payload []byte
|
||||
|
||||
if au.SubType == AuthzSubTypeOfServiceAPIToken {
|
||||
if len(auth.ServiceToken) > 0 {
|
||||
au.AuthOfAPIToken = &AuthOfAPIToken{
|
||||
Location: HTTPParamLocation(strings.ToLower(auth.Location)),
|
||||
Key: auth.Key,
|
||||
ServiceToken: auth.ServiceToken,
|
||||
}
|
||||
} else {
|
||||
token := &AuthOfAPIToken{}
|
||||
err = json.Unmarshal([]byte(auth.Payload), token)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid auth payload json"))
|
||||
}
|
||||
au.AuthOfAPIToken = token
|
||||
}
|
||||
|
||||
payload, err = json.Marshal(au.AuthOfAPIToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(payload) == 0 {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid plugin sub-auth type '%s'", au.SubType))
|
||||
}
|
||||
|
||||
au.Payload = string(payload)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (au *AuthV2) unmarshalOAuth(auth *Auth) (err error) {
|
||||
if au.SubType == "" { // 兼容老数据
|
||||
au.SubType = AuthzSubTypeOfOAuthAuthorizationCode
|
||||
}
|
||||
|
||||
var payload []byte
|
||||
|
||||
if au.SubType == AuthzSubTypeOfOAuthAuthorizationCode {
|
||||
if len(auth.ClientSecret) > 0 {
|
||||
au.AuthOfOAuthAuthorizationCode = &OAuthAuthorizationCodeConfig{
|
||||
ClientID: auth.ClientID,
|
||||
ClientSecret: auth.ClientSecret,
|
||||
ClientURL: auth.ClientURL,
|
||||
Scope: auth.Scope,
|
||||
AuthorizationURL: auth.AuthorizationURL,
|
||||
AuthorizationContentType: auth.AuthorizationContentType,
|
||||
}
|
||||
} else {
|
||||
oauth := &OAuthAuthorizationCodeConfig{}
|
||||
err = json.Unmarshal([]byte(auth.Payload), oauth)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid auth payload json"))
|
||||
}
|
||||
au.AuthOfOAuthAuthorizationCode = oauth
|
||||
}
|
||||
|
||||
payload, err = json.Marshal(au.AuthOfOAuthAuthorizationCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if au.SubType == AuthzSubTypeOfOAuthClientCredentials {
|
||||
oauth := &OAuthClientCredentialsConfig{}
|
||||
err = json.Unmarshal([]byte(auth.Payload), oauth)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrPluginInvalidManifest, errorx.KV(errno.PluginMsgKey,
|
||||
"invalid auth payload json"))
|
||||
}
|
||||
au.AuthOfOAuthClientCredentials = oauth
|
||||
|
||||
payload, err = json.Marshal(au.AuthOfOAuthClientCredentials)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(payload) == 0 {
|
||||
return errorx.New(errno.ErrPluginInvalidManifest, errorx.KVf(errno.PluginMsgKey,
|
||||
"invalid plugin sub-auth type '%s'", au.SubType))
|
||||
}
|
||||
|
||||
au.Payload = string(payload)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type AuthOfAPIToken struct {
|
||||
// Location is the location of the parameter.
|
||||
// It can be "header" or "query".
|
||||
Location HTTPParamLocation `json:"location"`
|
||||
// Key is the name of the parameter.
|
||||
Key string `json:"key"`
|
||||
// ServiceToken is the simple authorization information for the service.
|
||||
ServiceToken string `json:"service_token"`
|
||||
}
|
||||
|
||||
type OAuthAuthorizationCodeConfig struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
// ClientURL is the URL of authorization endpoint.
|
||||
ClientURL string `json:"client_url"`
|
||||
// Scope is the scope of the authorization request.
|
||||
// If multiple scopes are requested, they must be separated by a space.
|
||||
Scope string `json:"scope,omitempty"`
|
||||
// AuthorizationURL is the URL of token exchange endpoint.
|
||||
AuthorizationURL string `json:"authorization_url"`
|
||||
// AuthorizationContentType is the content type of the authorization request, and it must be "application/json".
|
||||
AuthorizationContentType string `json:"authorization_content_type"`
|
||||
}
|
||||
|
||||
type OAuthClientCredentialsConfig struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
TokenURL string `json:"token_url"`
|
||||
}
|
||||
|
||||
type APIDesc struct {
|
||||
Type PluginType `json:"type" validate:"required"`
|
||||
}
|
||||
566
backend/api/model/crossdomain/plugin/toolinfo.go
Normal file
566
backend/api/model/crossdomain/plugin/toolinfo.go
Normal file
@@ -0,0 +1,566 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
gonanoid "github.com/matoous/go-nanoid"
|
||||
|
||||
productAPI "github.com/coze-dev/coze-studio/backend/api/model/flow/marketplace/product_public_api"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
||||
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
)
|
||||
|
||||
type ToolInfo struct {
|
||||
ID int64
|
||||
PluginID int64
|
||||
CreatedAt int64
|
||||
UpdatedAt int64
|
||||
Version *string
|
||||
|
||||
ActivatedStatus *ActivatedStatus
|
||||
DebugStatus *plugin_develop_common.APIDebugStatus
|
||||
|
||||
Method *string
|
||||
SubURL *string
|
||||
Operation *Openapi3Operation
|
||||
}
|
||||
|
||||
func (t ToolInfo) GetName() string {
|
||||
if t.Operation == nil {
|
||||
return ""
|
||||
}
|
||||
return t.Operation.OperationID
|
||||
}
|
||||
|
||||
func (t ToolInfo) GetDesc() string {
|
||||
if t.Operation == nil {
|
||||
return ""
|
||||
}
|
||||
return t.Operation.Summary
|
||||
}
|
||||
|
||||
func (t ToolInfo) GetVersion() string {
|
||||
return ptr.FromOrDefault(t.Version, "")
|
||||
}
|
||||
|
||||
func (t ToolInfo) GetActivatedStatus() ActivatedStatus {
|
||||
return ptr.FromOrDefault(t.ActivatedStatus, ActivateTool)
|
||||
}
|
||||
|
||||
func (t ToolInfo) GetSubURL() string {
|
||||
return ptr.FromOrDefault(t.SubURL, "")
|
||||
}
|
||||
|
||||
func (t ToolInfo) GetMethod() string {
|
||||
return strings.ToUpper(ptr.FromOrDefault(t.Method, ""))
|
||||
}
|
||||
|
||||
func (t ToolInfo) GetDebugStatus() common.APIDebugStatus {
|
||||
return ptr.FromOrDefault(t.DebugStatus, common.APIDebugStatus_DebugWaiting)
|
||||
}
|
||||
|
||||
func (t ToolInfo) GetResponseOpenapiSchema() (*openapi3.Schema, error) {
|
||||
op := t.Operation
|
||||
if op == nil {
|
||||
return nil, fmt.Errorf("operation is required")
|
||||
}
|
||||
|
||||
resp, ok := op.Responses[strconv.Itoa(http.StatusOK)]
|
||||
if !ok || resp == nil || resp.Value == nil || len(resp.Value.Content) == 0 {
|
||||
return nil, fmt.Errorf("response status '200' not found")
|
||||
}
|
||||
|
||||
mType, ok := resp.Value.Content[MediaTypeJson] // only support application/json
|
||||
if !ok || mType == nil || mType.Schema == nil || mType.Schema.Value == nil {
|
||||
return nil, fmt.Errorf("media type '%s' not found in response", MediaTypeJson)
|
||||
}
|
||||
|
||||
return mType.Schema.Value, nil
|
||||
}
|
||||
|
||||
type paramMetaInfo struct {
|
||||
name string
|
||||
desc string
|
||||
required bool
|
||||
location string
|
||||
}
|
||||
|
||||
func (t ToolInfo) ToRespAPIParameter() ([]*common.APIParameter, error) {
|
||||
op := t.Operation
|
||||
if op == nil {
|
||||
return nil, fmt.Errorf("operation is required")
|
||||
}
|
||||
|
||||
respSchema, err := t.GetResponseOpenapiSchema()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := make([]*common.APIParameter, 0, len(op.Parameters))
|
||||
if len(respSchema.Properties) == 0 {
|
||||
return params, nil
|
||||
}
|
||||
|
||||
required := slices.ToMap(respSchema.Required, func(e string) (string, bool) {
|
||||
return e, true
|
||||
})
|
||||
|
||||
for subParamName, prop := range respSchema.Properties {
|
||||
if prop == nil || prop.Value == nil {
|
||||
return nil, fmt.Errorf("the schema of property '%s' is required", subParamName)
|
||||
}
|
||||
|
||||
paramMeta := paramMetaInfo{
|
||||
name: subParamName,
|
||||
desc: prop.Value.Description,
|
||||
location: string(ParamInBody),
|
||||
required: required[subParamName],
|
||||
}
|
||||
apiParam, err := toAPIParameter(paramMeta, prop.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params = append(params, apiParam)
|
||||
}
|
||||
|
||||
return params, nil
|
||||
}
|
||||
|
||||
func (t ToolInfo) ToReqAPIParameter() ([]*common.APIParameter, error) {
|
||||
op := t.Operation
|
||||
if op == nil {
|
||||
return nil, fmt.Errorf("operation is required")
|
||||
}
|
||||
|
||||
params := make([]*common.APIParameter, 0, len(op.Parameters))
|
||||
for _, param := range op.Parameters {
|
||||
if param == nil || param.Value == nil || param.Value.Schema == nil || param.Value.Schema.Value == nil {
|
||||
return nil, fmt.Errorf("parameter schema is required")
|
||||
}
|
||||
|
||||
paramVal := param.Value
|
||||
schemaVal := paramVal.Schema.Value
|
||||
|
||||
if schemaVal.Type == openapi3.TypeObject {
|
||||
return nil, fmt.Errorf("the type of parameter '%s' cannot be 'object'", paramVal.Name)
|
||||
}
|
||||
|
||||
if schemaVal.Type == openapi3.TypeArray {
|
||||
if paramVal.In == openapi3.ParameterInPath {
|
||||
return nil, fmt.Errorf("the type of field '%s' cannot be 'array'", paramVal.Name)
|
||||
}
|
||||
if schemaVal.Items == nil || schemaVal.Items.Value == nil {
|
||||
return nil, fmt.Errorf("the item schema of field '%s' is required", paramVal.Name)
|
||||
}
|
||||
item := schemaVal.Items.Value
|
||||
if item.Type == openapi3.TypeObject || item.Type == openapi3.TypeArray {
|
||||
return nil, fmt.Errorf("the item type of parameter '%s' cannot be 'object' or 'array'", paramVal.Name)
|
||||
}
|
||||
}
|
||||
|
||||
paramMeta := paramMetaInfo{
|
||||
name: paramVal.Name,
|
||||
desc: paramVal.Description,
|
||||
location: paramVal.In,
|
||||
required: paramVal.Required,
|
||||
}
|
||||
apiParam, err := toAPIParameter(paramMeta, schemaVal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params = append(params, apiParam)
|
||||
}
|
||||
|
||||
if op.RequestBody == nil || op.RequestBody.Value == nil || len(op.RequestBody.Value.Content) == 0 {
|
||||
return params, nil
|
||||
}
|
||||
|
||||
for _, mType := range op.RequestBody.Value.Content {
|
||||
if mType == nil || mType.Schema == nil || mType.Schema.Value == nil {
|
||||
return nil, fmt.Errorf("request body schema is required")
|
||||
}
|
||||
|
||||
schemaVal := mType.Schema.Value
|
||||
if len(schemaVal.Properties) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
required := slices.ToMap(schemaVal.Required, func(e string) (string, bool) {
|
||||
return e, true
|
||||
})
|
||||
|
||||
for subParamName, prop := range schemaVal.Properties {
|
||||
if prop == nil || prop.Value == nil {
|
||||
return nil, fmt.Errorf("the schema of property '%s' is required", subParamName)
|
||||
}
|
||||
|
||||
paramMeta := paramMetaInfo{
|
||||
name: subParamName,
|
||||
desc: prop.Value.Description,
|
||||
location: string(ParamInBody),
|
||||
required: required[subParamName],
|
||||
}
|
||||
apiParam, err := toAPIParameter(paramMeta, prop.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params = append(params, apiParam)
|
||||
}
|
||||
|
||||
break // 只取一种 MIME
|
||||
}
|
||||
|
||||
return params, nil
|
||||
}
|
||||
|
||||
func toAPIParameter(paramMeta paramMetaInfo, sc *openapi3.Schema) (*common.APIParameter, error) {
|
||||
if sc == nil {
|
||||
return nil, fmt.Errorf("schema is requred")
|
||||
}
|
||||
|
||||
apiType, ok := ToThriftParamType(strings.ToLower(sc.Type))
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("the type '%s' of filed '%s' is invalid", sc.Type, paramMeta.name)
|
||||
}
|
||||
location, ok := ToThriftHTTPParamLocation(HTTPParamLocation(paramMeta.location))
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("the location '%s' of field '%s' is invalid", paramMeta.location, paramMeta.name)
|
||||
}
|
||||
|
||||
apiParam := &common.APIParameter{
|
||||
ID: gonanoid.MustID(10),
|
||||
Name: paramMeta.name,
|
||||
Desc: paramMeta.desc,
|
||||
Type: apiType,
|
||||
Location: location, // 使用父节点的值
|
||||
IsRequired: paramMeta.required,
|
||||
SubParameters: []*common.APIParameter{},
|
||||
}
|
||||
|
||||
if sc.Default != nil {
|
||||
apiParam.LocalDefault = ptr.Of(fmt.Sprintf("%v", sc.Default))
|
||||
}
|
||||
|
||||
if sc.Format != "" {
|
||||
aType, ok := FormatToAssistType(sc.Format)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("the format '%s' of field '%s' is invalid", sc.Format, paramMeta.name)
|
||||
}
|
||||
_aType, ok := ToThriftAPIAssistType(aType)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("assist type '%s' of field '%s' is invalid", aType, paramMeta.name)
|
||||
}
|
||||
apiParam.AssistType = ptr.Of(_aType)
|
||||
}
|
||||
|
||||
if v, ok := sc.Extensions[APISchemaExtendGlobalDisable]; ok {
|
||||
if disable, ok := v.(bool); ok {
|
||||
apiParam.GlobalDisable = disable
|
||||
}
|
||||
}
|
||||
if v, ok := sc.Extensions[APISchemaExtendLocalDisable]; ok {
|
||||
if disable, ok := v.(bool); ok {
|
||||
apiParam.LocalDisable = disable
|
||||
}
|
||||
}
|
||||
if v, ok := sc.Extensions[APISchemaExtendVariableRef]; ok {
|
||||
if ref, ok := v.(string); ok {
|
||||
apiParam.VariableRef = ptr.Of(ref)
|
||||
apiParam.DefaultParamSource = ptr.Of(common.DefaultParamSource_Variable)
|
||||
}
|
||||
}
|
||||
|
||||
switch sc.Type {
|
||||
case openapi3.TypeObject:
|
||||
if len(sc.Properties) == 0 {
|
||||
return apiParam, nil
|
||||
}
|
||||
|
||||
required := slices.ToMap(sc.Required, func(e string) (string, bool) {
|
||||
return e, true
|
||||
})
|
||||
for subParamName, prop := range sc.Properties {
|
||||
if prop == nil || prop.Value == nil {
|
||||
return nil, fmt.Errorf("the schema of property '%s' is required", subParamName)
|
||||
}
|
||||
|
||||
subMeta := paramMetaInfo{
|
||||
name: subParamName,
|
||||
desc: prop.Value.Description,
|
||||
required: required[subParamName],
|
||||
location: paramMeta.location,
|
||||
}
|
||||
subParam, err := toAPIParameter(subMeta, prop.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
apiParam.SubParameters = append(apiParam.SubParameters, subParam)
|
||||
}
|
||||
|
||||
return apiParam, nil
|
||||
|
||||
case openapi3.TypeArray:
|
||||
if sc.Items == nil || sc.Items.Value == nil {
|
||||
return nil, fmt.Errorf("the item schema of field '%s' is required", paramMeta.name)
|
||||
}
|
||||
|
||||
item := sc.Items.Value
|
||||
|
||||
subMeta := paramMetaInfo{
|
||||
name: "[Array Item]",
|
||||
desc: item.Description,
|
||||
location: paramMeta.location,
|
||||
required: paramMeta.required,
|
||||
}
|
||||
subParam, err := toAPIParameter(subMeta, item)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
apiParam.SubParameters = append(apiParam.SubParameters, subParam)
|
||||
|
||||
return apiParam, nil
|
||||
}
|
||||
|
||||
return apiParam, nil
|
||||
}
|
||||
|
||||
func (t ToolInfo) ToPluginParameters() ([]*common.PluginParameter, error) {
|
||||
op := t.Operation
|
||||
if op == nil {
|
||||
return nil, fmt.Errorf("operation is required")
|
||||
}
|
||||
|
||||
var params []*common.PluginParameter
|
||||
|
||||
for _, prop := range op.Parameters {
|
||||
if prop == nil || prop.Value == nil || prop.Value.Schema == nil || prop.Value.Schema.Value == nil {
|
||||
return nil, fmt.Errorf("parameter schema is required")
|
||||
}
|
||||
|
||||
paramVal := prop.Value
|
||||
schemaVal := paramVal.Schema.Value
|
||||
|
||||
if schemaVal.Type == openapi3.TypeObject {
|
||||
return nil, fmt.Errorf("the type of parameter '%s' cannot be 'object'", paramVal.Name)
|
||||
}
|
||||
|
||||
var arrayItemType string
|
||||
if schemaVal.Type == openapi3.TypeArray {
|
||||
if paramVal.In == openapi3.ParameterInPath {
|
||||
return nil, fmt.Errorf("the type of field '%s' cannot be 'array'", paramVal.Name)
|
||||
}
|
||||
if schemaVal.Items == nil || schemaVal.Items.Value == nil {
|
||||
return nil, fmt.Errorf("the item schema of field '%s' is required", paramVal.Name)
|
||||
}
|
||||
item := schemaVal.Items.Value
|
||||
if item.Type == openapi3.TypeObject || item.Type == openapi3.TypeArray {
|
||||
return nil, fmt.Errorf("the item type of parameter '%s' cannot be 'object' or 'array'", paramVal.Name)
|
||||
}
|
||||
|
||||
arrayItemType = item.Type
|
||||
}
|
||||
|
||||
if disabledParam(schemaVal) {
|
||||
continue
|
||||
}
|
||||
|
||||
var assistType *common.PluginParamTypeFormat
|
||||
if v, ok := schemaVal.Extensions[APISchemaExtendAssistType]; ok {
|
||||
_v, ok := v.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
f, ok := AssistTypeToThriftFormat(APIFileAssistType(_v))
|
||||
if ok {
|
||||
return nil, fmt.Errorf("the assist type '%s' of field '%s' is invalid", _v, paramVal.Name)
|
||||
}
|
||||
assistType = ptr.Of(f)
|
||||
}
|
||||
|
||||
params = append(params, &common.PluginParameter{
|
||||
Name: paramVal.Name,
|
||||
Desc: paramVal.Description,
|
||||
Required: paramVal.Required,
|
||||
Type: schemaVal.Type,
|
||||
SubType: arrayItemType,
|
||||
Format: assistType,
|
||||
SubParameters: []*common.PluginParameter{},
|
||||
})
|
||||
}
|
||||
|
||||
if op.RequestBody == nil || op.RequestBody.Value == nil || len(op.RequestBody.Value.Content) == 0 {
|
||||
return params, nil
|
||||
}
|
||||
|
||||
for _, mType := range op.RequestBody.Value.Content {
|
||||
if mType == nil || mType.Schema == nil || mType.Schema.Value == nil {
|
||||
return nil, fmt.Errorf("request body schema is required")
|
||||
}
|
||||
|
||||
schemaVal := mType.Schema.Value
|
||||
if len(schemaVal.Properties) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
required := slices.ToMap(schemaVal.Required, func(e string) (string, bool) {
|
||||
return e, true
|
||||
})
|
||||
|
||||
for subParamName, prop := range schemaVal.Properties {
|
||||
if prop == nil || prop.Value == nil {
|
||||
return nil, fmt.Errorf("the schema of property '%s' is required", subParamName)
|
||||
}
|
||||
|
||||
paramMeta := paramMetaInfo{
|
||||
name: subParamName,
|
||||
desc: prop.Value.Description,
|
||||
required: required[subParamName],
|
||||
}
|
||||
paramInfo, err := toPluginParameter(paramMeta, prop.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if paramInfo != nil {
|
||||
params = append(params, paramInfo)
|
||||
}
|
||||
}
|
||||
|
||||
break // 只取一种 MIME
|
||||
}
|
||||
|
||||
return params, nil
|
||||
}
|
||||
|
||||
func toPluginParameter(paramMeta paramMetaInfo, sc *openapi3.Schema) (*common.PluginParameter, error) {
|
||||
if sc == nil {
|
||||
return nil, fmt.Errorf("schema is required")
|
||||
}
|
||||
|
||||
if disabledParam(sc) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var assistType *common.PluginParamTypeFormat
|
||||
if v, ok := sc.Extensions[APISchemaExtendAssistType]; ok {
|
||||
if _v, ok := v.(string); ok {
|
||||
f, ok := AssistTypeToThriftFormat(APIFileAssistType(_v))
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("the assist type '%s' of field '%s' is invalid", _v, paramMeta.name)
|
||||
}
|
||||
assistType = ptr.Of(f)
|
||||
}
|
||||
}
|
||||
|
||||
pluginParam := &common.PluginParameter{
|
||||
Name: paramMeta.name,
|
||||
Type: sc.Type,
|
||||
Desc: paramMeta.desc,
|
||||
Required: paramMeta.required,
|
||||
Format: assistType,
|
||||
SubParameters: []*common.PluginParameter{},
|
||||
}
|
||||
|
||||
switch sc.Type {
|
||||
case openapi3.TypeObject:
|
||||
if len(sc.Properties) == 0 {
|
||||
return pluginParam, nil
|
||||
}
|
||||
|
||||
required := slices.ToMap(sc.Required, func(e string) (string, bool) {
|
||||
return e, true
|
||||
})
|
||||
for subParamName, prop := range sc.Properties {
|
||||
if prop == nil || prop.Value == nil {
|
||||
return nil, fmt.Errorf("the schema of property '%s' is required", subParamName)
|
||||
}
|
||||
|
||||
subMeta := paramMetaInfo{
|
||||
name: subParamName,
|
||||
desc: prop.Value.Description,
|
||||
required: required[subParamName],
|
||||
}
|
||||
subParam, err := toPluginParameter(subMeta, prop.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pluginParam.SubParameters = append(pluginParam.SubParameters, subParam)
|
||||
}
|
||||
|
||||
return pluginParam, nil
|
||||
|
||||
case openapi3.TypeArray:
|
||||
if sc.Items == nil || sc.Items.Value == nil {
|
||||
return nil, fmt.Errorf("the item schema of field '%s' is required", paramMeta.name)
|
||||
}
|
||||
|
||||
item := sc.Items.Value
|
||||
pluginParam.SubType = item.Type
|
||||
|
||||
if item.Type != openapi3.TypeObject {
|
||||
return pluginParam, nil
|
||||
}
|
||||
|
||||
subMeta := paramMetaInfo{
|
||||
desc: item.Description,
|
||||
required: paramMeta.required,
|
||||
}
|
||||
subParam, err := toPluginParameter(subMeta, item)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pluginParam.SubParameters = append(pluginParam.SubParameters, subParam.SubParameters...)
|
||||
|
||||
return pluginParam, nil
|
||||
}
|
||||
|
||||
return pluginParam, nil
|
||||
}
|
||||
|
||||
func (t ToolInfo) ToToolParameters() ([]*productAPI.ToolParameter, error) {
|
||||
apiParams, err := t.ToReqAPIParameter()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var toToolParams func(apiParams []*common.APIParameter) ([]*productAPI.ToolParameter, error)
|
||||
toToolParams = func(apiParams []*common.APIParameter) ([]*productAPI.ToolParameter, error) {
|
||||
params := make([]*productAPI.ToolParameter, 0, len(apiParams))
|
||||
for _, apiParam := range apiParams {
|
||||
typ, _ := ToOpenapiParamType(apiParam.Type)
|
||||
toolParam := &productAPI.ToolParameter{
|
||||
Name: apiParam.Name,
|
||||
Description: apiParam.Desc,
|
||||
Type: typ,
|
||||
IsRequired: apiParam.IsRequired,
|
||||
SubParameter: []*productAPI.ToolParameter{},
|
||||
}
|
||||
|
||||
if len(apiParam.SubParameters) > 0 {
|
||||
subParams, err := toToolParams(apiParam.SubParameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toolParam.SubParameter = append(toolParam.SubParameter, subParams...)
|
||||
}
|
||||
|
||||
params = append(params, toolParam)
|
||||
}
|
||||
|
||||
return params, nil
|
||||
}
|
||||
|
||||
return toToolParams(apiParams)
|
||||
}
|
||||
64
backend/api/model/crossdomain/search/resource_doc.go
Normal file
64
backend/api/model/crossdomain/search/resource_doc.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
resource "github.com/coze-dev/coze-studio/backend/api/model/resource/common"
|
||||
)
|
||||
|
||||
type ResourceDocument struct {
|
||||
ResID int64 `json:"res_id"`
|
||||
ResType resource.ResType `json:"res_type"`
|
||||
ResSubType *int32 `json:"res_sub_type,omitempty"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
OwnerID *int64 `json:"owner_id,omitempty"`
|
||||
SpaceID *int64 `json:"space_id,omitempty"`
|
||||
APPID *int64 `json:"app_id,omitempty"`
|
||||
BizStatus *int64 `json:"biz_status,omitempty"`
|
||||
PublishStatus *resource.PublishStatus `json:"publish_status,omitempty"`
|
||||
|
||||
CreateTimeMS *int64 `json:"create_time,omitempty"`
|
||||
UpdateTimeMS *int64 `json:"update_time,omitempty"`
|
||||
PublishTimeMS *int64 `json:"publish_time,omitempty"`
|
||||
}
|
||||
|
||||
func (r *ResourceDocument) GetName() string {
|
||||
if r.Name != nil {
|
||||
return *r.Name
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (r *ResourceDocument) GetOwnerID() int64 {
|
||||
if r.OwnerID != nil {
|
||||
return *r.OwnerID
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetUpdateTime 获取更新时间
|
||||
func (r *ResourceDocument) GetUpdateTime() int64 {
|
||||
if r.UpdateTimeMS != nil {
|
||||
return *r.UpdateTimeMS
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (r *ResourceDocument) GetResSubType() int32 {
|
||||
if r.ResSubType != nil {
|
||||
return *r.ResSubType
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (r *ResourceDocument) GetCreateTime() int64 {
|
||||
if r.CreateTimeMS == nil {
|
||||
return 0
|
||||
}
|
||||
return *r.CreateTimeMS
|
||||
}
|
||||
|
||||
func (r *ResourceDocument) GetPublishTime() int64 {
|
||||
if r.PublishTimeMS == nil {
|
||||
return 0
|
||||
}
|
||||
return *r.PublishTimeMS
|
||||
}
|
||||
38
backend/api/model/crossdomain/search/search.go
Normal file
38
backend/api/model/crossdomain/search/search.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
resource "github.com/coze-dev/coze-studio/backend/api/model/resource/common"
|
||||
)
|
||||
|
||||
type SearchResourcesRequest struct {
|
||||
SpaceID int64
|
||||
OwnerID int64
|
||||
Name string
|
||||
APPID int64
|
||||
|
||||
OrderFiledName string
|
||||
OrderAsc bool
|
||||
ResTypeFilter []resource.ResType
|
||||
PublishStatusFilter resource.PublishStatus
|
||||
SearchKeys []string
|
||||
|
||||
Cursor string
|
||||
Page *int32
|
||||
Limit int32
|
||||
}
|
||||
|
||||
type SearchResourcesResponse struct {
|
||||
HasMore bool
|
||||
NextCursor string
|
||||
TotalHits *int64
|
||||
|
||||
Data []*ResourceDocument
|
||||
}
|
||||
|
||||
const (
|
||||
FieldOfCreateTime = "create_time"
|
||||
FieldOfUpdateTime = "update_time"
|
||||
FieldOfPublishTime = "publish_time"
|
||||
FieldOfFavTime = "fav_time"
|
||||
FieldOfRecentlyOpenTime = "recently_open_time"
|
||||
)
|
||||
106
backend/api/model/crossdomain/singleagent/single_agent.go
Normal file
106
backend/api/model/crossdomain/singleagent/single_agent.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package singleagent
|
||||
|
||||
import (
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_common"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
|
||||
)
|
||||
|
||||
type AgentRuntime struct {
|
||||
AgentVersion string
|
||||
IsDraft bool
|
||||
SpaceID int64
|
||||
ConnectorID int64
|
||||
PreRetrieveTools []*agentrun.Tool
|
||||
}
|
||||
|
||||
type EventType string
|
||||
|
||||
const (
|
||||
EventTypeOfChatModelAnswer EventType = "chatmodel_answer"
|
||||
EventTypeOfToolsMessage EventType = "tools_message"
|
||||
EventTypeOfFuncCall EventType = "func_call"
|
||||
EventTypeOfSuggest EventType = "suggest"
|
||||
EventTypeOfKnowledge EventType = "knowledge"
|
||||
EventTypeOfInterrupt EventType = "interrupt"
|
||||
)
|
||||
|
||||
type AgentEvent struct {
|
||||
EventType EventType
|
||||
|
||||
ChatModelAnswer *schema.StreamReader[*schema.Message]
|
||||
ToolsMessage []*schema.Message
|
||||
FuncCall *schema.Message
|
||||
Suggest *schema.Message
|
||||
Knowledge []*schema.Document
|
||||
Interrupt *InterruptInfo
|
||||
}
|
||||
|
||||
type SingleAgent struct {
|
||||
AgentID int64
|
||||
CreatorID int64
|
||||
SpaceID int64
|
||||
Name string
|
||||
Desc string
|
||||
IconURI string
|
||||
CreatedAt int64
|
||||
UpdatedAt int64
|
||||
Version string
|
||||
DeletedAt gorm.DeletedAt
|
||||
|
||||
VariablesMetaID *int64
|
||||
OnboardingInfo *bot_common.OnboardingInfo
|
||||
ModelInfo *bot_common.ModelInfo
|
||||
Prompt *bot_common.PromptInfo
|
||||
Plugin []*bot_common.PluginInfo
|
||||
Knowledge *bot_common.Knowledge
|
||||
Workflow []*bot_common.WorkflowInfo
|
||||
SuggestReply *bot_common.SuggestReplyInfo
|
||||
JumpConfig *bot_common.JumpConfig
|
||||
BackgroundImageInfoList []*bot_common.BackgroundImageInfo
|
||||
Database []*bot_common.Database
|
||||
ShortcutCommand []string
|
||||
}
|
||||
|
||||
type InterruptEventType int64
|
||||
|
||||
const (
|
||||
InterruptEventType_LocalPlugin InterruptEventType = 1
|
||||
InterruptEventType_Question InterruptEventType = 2
|
||||
InterruptEventType_RequireInfos InterruptEventType = 3
|
||||
InterruptEventType_SceneChat InterruptEventType = 4
|
||||
InterruptEventType_InputNode InterruptEventType = 5
|
||||
InterruptEventType_WorkflowLocalPlugin InterruptEventType = 6
|
||||
InterruptEventType_OauthPlugin InterruptEventType = 7
|
||||
InterruptEventType_WorkflowLLM InterruptEventType = 100
|
||||
)
|
||||
|
||||
type InterruptInfo struct {
|
||||
AllToolInterruptData map[string]*plugin.ToolInterruptEvent
|
||||
AllWfInterruptData map[string]*crossworkflow.ToolInterruptEvent
|
||||
ToolCallID string
|
||||
InterruptType InterruptEventType
|
||||
InterruptID string
|
||||
}
|
||||
|
||||
type ExecuteRequest struct {
|
||||
Identity *AgentIdentity
|
||||
UserID string
|
||||
|
||||
Input *schema.Message
|
||||
History []*schema.Message
|
||||
ResumeInfo *InterruptInfo
|
||||
PreCallTools []*agentrun.ToolsRetriever
|
||||
}
|
||||
|
||||
type AgentIdentity struct {
|
||||
AgentID int64
|
||||
// State AgentState
|
||||
Version string
|
||||
IsDraft bool
|
||||
ConnectorID int64
|
||||
}
|
||||
13
backend/api/model/crossdomain/variables/variable_instance.go
Normal file
13
backend/api/model/crossdomain/variables/variable_instance.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package variables
|
||||
|
||||
import (
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/project_memory"
|
||||
)
|
||||
|
||||
type UserVariableMeta struct {
|
||||
BizType project_memory.VariableConnector
|
||||
BizID string
|
||||
Version string
|
||||
ConnectorUID string
|
||||
ConnectorID int64
|
||||
}
|
||||
Reference in New Issue
Block a user