feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
1390
backend/domain/workflow/internal/execute/callback.go
Normal file
1390
backend/domain/workflow/internal/execute/callback.go
Normal file
File diff suppressed because it is too large
Load Diff
143
backend/domain/workflow/internal/execute/collect_token.go
Normal file
143
backend/domain/workflow/internal/execute/collect_token.go
Normal file
@@ -0,0 +1,143 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package execute
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
callbacks2 "github.com/cloudwego/eino/utils/callbacks"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
)
|
||||
|
||||
type TokenCollector struct {
|
||||
Key string
|
||||
Usage *model.TokenUsage
|
||||
wg sync.WaitGroup
|
||||
mu sync.Mutex
|
||||
Parent *TokenCollector
|
||||
}
|
||||
|
||||
func newTokenCollector(key string, parent *TokenCollector) *TokenCollector {
|
||||
return &TokenCollector{
|
||||
Key: key,
|
||||
Usage: &model.TokenUsage{},
|
||||
Parent: parent,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TokenCollector) addTokenUsage(usage *model.TokenUsage) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.Usage.PromptTokens += usage.PromptTokens
|
||||
t.Usage.CompletionTokens += usage.CompletionTokens
|
||||
t.Usage.TotalTokens += usage.TotalTokens
|
||||
|
||||
if t.Parent != nil {
|
||||
t.Parent.addTokenUsage(usage)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TokenCollector) wait() *model.TokenUsage {
|
||||
t.wg.Wait()
|
||||
t.mu.Lock()
|
||||
usage := &model.TokenUsage{
|
||||
PromptTokens: t.Usage.PromptTokens,
|
||||
CompletionTokens: t.Usage.CompletionTokens,
|
||||
TotalTokens: t.Usage.TotalTokens,
|
||||
}
|
||||
t.mu.Unlock()
|
||||
return usage
|
||||
}
|
||||
|
||||
func getTokenCollector(ctx context.Context) *TokenCollector {
|
||||
c := GetExeCtx(ctx)
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return c.TokenCollector
|
||||
}
|
||||
|
||||
func GetTokenCallbackHandler() callbacks.Handler {
|
||||
return callbacks2.NewHandlerHelper().ChatModel(&callbacks2.ModelCallbackHandler{
|
||||
OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context {
|
||||
c := getTokenCollector(ctx)
|
||||
if c == nil {
|
||||
return ctx
|
||||
}
|
||||
c.wg.Add(1)
|
||||
return ctx
|
||||
},
|
||||
OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context {
|
||||
c := getTokenCollector(ctx)
|
||||
if c == nil {
|
||||
return ctx
|
||||
}
|
||||
if output.TokenUsage == nil {
|
||||
c.wg.Done()
|
||||
return ctx
|
||||
}
|
||||
c.addTokenUsage(output.TokenUsage)
|
||||
c.wg.Done()
|
||||
return ctx
|
||||
},
|
||||
OnEndWithStreamOutput: func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context {
|
||||
c := getTokenCollector(ctx)
|
||||
if c == nil {
|
||||
output.Close()
|
||||
return ctx
|
||||
}
|
||||
safego.Go(ctx, func() {
|
||||
defer func() {
|
||||
output.Close()
|
||||
c.wg.Done()
|
||||
}()
|
||||
|
||||
newC := &model.TokenUsage{}
|
||||
|
||||
for {
|
||||
chunk, err := output.Recv()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if chunk.TokenUsage == nil {
|
||||
continue
|
||||
}
|
||||
newC.PromptTokens += chunk.TokenUsage.PromptTokens
|
||||
newC.CompletionTokens += chunk.TokenUsage.CompletionTokens
|
||||
newC.TotalTokens += chunk.TokenUsage.TotalTokens
|
||||
}
|
||||
|
||||
c.addTokenUsage(newC)
|
||||
})
|
||||
return ctx
|
||||
},
|
||||
OnError: func(ctx context.Context, runInfo *callbacks.RunInfo, runErr error) context.Context {
|
||||
c := getTokenCollector(ctx)
|
||||
if c == nil {
|
||||
return ctx
|
||||
}
|
||||
c.wg.Done()
|
||||
return ctx
|
||||
},
|
||||
}).Handler()
|
||||
}
|
||||
68
backend/domain/workflow/internal/execute/consts.go
Normal file
68
backend/domain/workflow/internal/execute/consts.go
Normal file
@@ -0,0 +1,68 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package execute
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
)
|
||||
|
||||
const (
|
||||
foregroundRunTimeout = 10 * time.Minute
|
||||
backgroundRunTimeout = 24 * time.Hour
|
||||
maxNodeCountPerWorkflow = 1000
|
||||
maxNodeCountPerExecution = 1000
|
||||
cancelCheckInterval = 200 * time.Millisecond
|
||||
)
|
||||
|
||||
type StaticConfig struct {
|
||||
ForegroundRunTimeout time.Duration
|
||||
BackgroundRunTimeout time.Duration
|
||||
MaxNodeCountPerWorkflow int
|
||||
MaxNodeCountPerExecution int
|
||||
}
|
||||
|
||||
func GetStaticConfig() *StaticConfig {
|
||||
return &StaticConfig{
|
||||
ForegroundRunTimeout: foregroundRunTimeout,
|
||||
BackgroundRunTimeout: backgroundRunTimeout,
|
||||
MaxNodeCountPerWorkflow: maxNodeCountPerWorkflow,
|
||||
MaxNodeCountPerExecution: maxNodeCountPerExecution,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
executedNodeCountKey = "executed_node_count"
|
||||
)
|
||||
|
||||
func IncrAndCheckExecutedNodes(ctx context.Context) (int64, bool) {
|
||||
counter, ok := ctxcache.Get[atomic.Int64](ctx, executedNodeCountKey)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
current := counter.Add(1)
|
||||
return current, current > maxNodeCountPerExecution
|
||||
}
|
||||
|
||||
func InitExecutedNodesCounter(ctx context.Context) context.Context {
|
||||
ctxcache.Store(ctx, executedNodeCountKey, atomic.Int64{})
|
||||
return ctx
|
||||
}
|
||||
365
backend/domain/workflow/internal/execute/context.go
Normal file
365
backend/domain/workflow/internal/execute/context.go
Normal file
@@ -0,0 +1,365 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package execute
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
type Context struct {
|
||||
RootCtx
|
||||
|
||||
*SubWorkflowCtx
|
||||
|
||||
*NodeCtx
|
||||
|
||||
*BatchInfo
|
||||
|
||||
TokenCollector *TokenCollector
|
||||
|
||||
StartTime int64 // UnixMilli
|
||||
|
||||
CheckPointID string
|
||||
}
|
||||
|
||||
type RootCtx struct {
|
||||
RootWorkflowBasic *entity.WorkflowBasic
|
||||
RootExecuteID int64
|
||||
ResumeEvent *entity.InterruptEvent
|
||||
ExeCfg vo.ExecuteConfig
|
||||
}
|
||||
|
||||
type SubWorkflowCtx struct {
|
||||
SubWorkflowBasic *entity.WorkflowBasic
|
||||
SubExecuteID int64
|
||||
}
|
||||
|
||||
type NodeCtx struct {
|
||||
NodeKey vo.NodeKey
|
||||
NodeExecuteID int64
|
||||
NodeName string
|
||||
NodeType entity.NodeType
|
||||
NodePath []string
|
||||
TerminatePlan *vo.TerminatePlan
|
||||
|
||||
ResumingEvent *entity.InterruptEvent
|
||||
SubWorkflowExeID int64 // if this node is subworkflow node, the execute id of the sub workflow
|
||||
|
||||
CurrentRetryCount int
|
||||
}
|
||||
|
||||
type BatchInfo struct {
|
||||
Index int
|
||||
Items map[string]any
|
||||
CompositeNodeKey vo.NodeKey
|
||||
}
|
||||
|
||||
type contextKey struct{}
|
||||
|
||||
func restoreWorkflowCtx(ctx context.Context, h *WorkflowHandler) (context.Context, error) {
|
||||
var storedCtx *Context
|
||||
err := compose.ProcessState[ExeContextStore](ctx, func(ctx context.Context, state ExeContextStore) error {
|
||||
if state == nil {
|
||||
return errors.New("state is nil")
|
||||
}
|
||||
|
||||
var e error
|
||||
storedCtx, _, e = state.GetWorkflowCtx()
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return ctx, err
|
||||
}
|
||||
|
||||
if storedCtx == nil {
|
||||
return ctx, errors.New("stored workflow context is nil")
|
||||
}
|
||||
|
||||
storedCtx.ResumeEvent = h.resumeEvent
|
||||
|
||||
// restore the parent-child relationship between token collectors
|
||||
if storedCtx.TokenCollector != nil && storedCtx.TokenCollector.Parent != nil {
|
||||
currentC := GetExeCtx(ctx)
|
||||
currentTokenCollector := currentC.TokenCollector
|
||||
storedCtx.TokenCollector.Parent = currentTokenCollector
|
||||
}
|
||||
|
||||
return context.WithValue(ctx, contextKey{}, storedCtx), nil
|
||||
}
|
||||
|
||||
func restoreNodeCtx(ctx context.Context, nodeKey vo.NodeKey, resumeEvent *entity.InterruptEvent,
|
||||
exactlyResuming bool) (context.Context, error) {
|
||||
var storedCtx *Context
|
||||
err := compose.ProcessState[ExeContextStore](ctx, func(ctx context.Context, state ExeContextStore) error {
|
||||
if state == nil {
|
||||
return errors.New("state is nil")
|
||||
}
|
||||
var e error
|
||||
storedCtx, _, e = state.GetNodeCtx(nodeKey)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return ctx, err
|
||||
}
|
||||
|
||||
if storedCtx == nil {
|
||||
return ctx, errors.New("stored node context is nil")
|
||||
}
|
||||
|
||||
if exactlyResuming {
|
||||
storedCtx.NodeCtx.ResumingEvent = resumeEvent
|
||||
} else {
|
||||
storedCtx.NodeCtx.ResumingEvent = nil
|
||||
}
|
||||
|
||||
existingC := GetExeCtx(ctx)
|
||||
if existingC != nil {
|
||||
storedCtx.RootCtx.ResumeEvent = existingC.RootCtx.ResumeEvent
|
||||
}
|
||||
|
||||
// restore the parent-child relationship between token collectors
|
||||
if storedCtx.TokenCollector != nil && storedCtx.TokenCollector.Parent != nil {
|
||||
currentC := GetExeCtx(ctx)
|
||||
currentTokenCollector := currentC.TokenCollector
|
||||
storedCtx.TokenCollector.Parent = currentTokenCollector
|
||||
}
|
||||
|
||||
storedCtx.NodeCtx.CurrentRetryCount = 0
|
||||
|
||||
return context.WithValue(ctx, contextKey{}, storedCtx), nil
|
||||
}
|
||||
|
||||
func tryRestoreNodeCtx(ctx context.Context, nodeKey vo.NodeKey) (context.Context, bool) {
|
||||
var storedCtx *Context
|
||||
err := compose.ProcessState[ExeContextStore](ctx, func(ctx context.Context, state ExeContextStore) error {
|
||||
if state == nil {
|
||||
return errors.New("state is nil")
|
||||
}
|
||||
var e error
|
||||
storedCtx, _, e = state.GetNodeCtx(nodeKey)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil || storedCtx == nil {
|
||||
return ctx, false
|
||||
}
|
||||
|
||||
storedCtx.NodeCtx.ResumingEvent = nil
|
||||
|
||||
existingC := GetExeCtx(ctx)
|
||||
if existingC != nil {
|
||||
storedCtx.RootCtx.ResumeEvent = existingC.RootCtx.ResumeEvent
|
||||
}
|
||||
|
||||
// restore the parent-child relationship between token collectors
|
||||
if storedCtx.TokenCollector != nil && storedCtx.TokenCollector.Parent != nil && existingC != nil {
|
||||
currentTokenCollector := existingC.TokenCollector
|
||||
storedCtx.TokenCollector.Parent = currentTokenCollector
|
||||
}
|
||||
|
||||
storedCtx.NodeCtx.CurrentRetryCount = 0
|
||||
|
||||
return context.WithValue(ctx, contextKey{}, storedCtx), true
|
||||
}
|
||||
|
||||
func PrepareRootExeCtx(ctx context.Context, h *WorkflowHandler) (context.Context, error) {
|
||||
var parentTokenCollector *TokenCollector
|
||||
if currentC := GetExeCtx(ctx); currentC != nil {
|
||||
parentTokenCollector = currentC.TokenCollector
|
||||
}
|
||||
|
||||
rootExeCtx := &Context{
|
||||
RootCtx: RootCtx{
|
||||
RootWorkflowBasic: h.rootWorkflowBasic,
|
||||
RootExecuteID: h.rootExecuteID,
|
||||
ResumeEvent: h.resumeEvent,
|
||||
ExeCfg: h.exeCfg,
|
||||
},
|
||||
|
||||
TokenCollector: newTokenCollector(fmt.Sprintf("wf_%d", h.rootWorkflowBasic.ID), parentTokenCollector),
|
||||
StartTime: time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
if h.requireCheckpoint {
|
||||
rootExeCtx.CheckPointID = strconv.FormatInt(h.rootExecuteID, 10)
|
||||
err := compose.ProcessState[ExeContextStore](ctx, func(ctx context.Context, state ExeContextStore) error {
|
||||
if state == nil {
|
||||
return errors.New("state is nil")
|
||||
}
|
||||
return state.SetWorkflowCtx(rootExeCtx)
|
||||
})
|
||||
if err != nil {
|
||||
return ctx, err
|
||||
}
|
||||
}
|
||||
|
||||
return context.WithValue(ctx, contextKey{}, rootExeCtx), nil
|
||||
}
|
||||
|
||||
func GetExeCtx(ctx context.Context) *Context {
|
||||
c := ctx.Value(contextKey{})
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return c.(*Context)
|
||||
}
|
||||
|
||||
func PrepareSubExeCtx(ctx context.Context, wb *entity.WorkflowBasic, requireCheckpoint bool) (context.Context, error) {
|
||||
c := GetExeCtx(ctx)
|
||||
if c == nil {
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
subExecuteID, err := workflow.GetRepository().GenID(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var newCheckpointID string
|
||||
if len(c.CheckPointID) > 0 {
|
||||
newCheckpointID = c.CheckPointID + "_" + strconv.FormatInt(subExecuteID, 10)
|
||||
}
|
||||
|
||||
newC := &Context{
|
||||
RootCtx: c.RootCtx,
|
||||
SubWorkflowCtx: &SubWorkflowCtx{
|
||||
SubWorkflowBasic: wb,
|
||||
SubExecuteID: subExecuteID,
|
||||
},
|
||||
NodeCtx: c.NodeCtx,
|
||||
BatchInfo: c.BatchInfo,
|
||||
TokenCollector: newTokenCollector(fmt.Sprintf("sub_wf_%d", wb.ID), c.TokenCollector),
|
||||
CheckPointID: newCheckpointID,
|
||||
StartTime: time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
if requireCheckpoint {
|
||||
err := compose.ProcessState[ExeContextStore](ctx, func(ctx context.Context, state ExeContextStore) error {
|
||||
if state == nil {
|
||||
return errors.New("state is nil")
|
||||
}
|
||||
return state.SetWorkflowCtx(newC)
|
||||
})
|
||||
if err != nil {
|
||||
return ctx, err
|
||||
}
|
||||
}
|
||||
|
||||
newC.NodeCtx.SubWorkflowExeID = subExecuteID
|
||||
|
||||
return context.WithValue(ctx, contextKey{}, newC), nil
|
||||
}
|
||||
|
||||
func PrepareNodeExeCtx(ctx context.Context, nodeKey vo.NodeKey, nodeName string, nodeType entity.NodeType, plan *vo.TerminatePlan) (context.Context, error) {
|
||||
c := GetExeCtx(ctx)
|
||||
if c == nil {
|
||||
return ctx, nil
|
||||
}
|
||||
nodeExecuteID, err := workflow.GetRepository().GenID(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newC := &Context{
|
||||
RootCtx: c.RootCtx,
|
||||
SubWorkflowCtx: c.SubWorkflowCtx,
|
||||
NodeCtx: &NodeCtx{
|
||||
NodeKey: nodeKey,
|
||||
NodeExecuteID: nodeExecuteID,
|
||||
NodeName: nodeName,
|
||||
NodeType: nodeType,
|
||||
TerminatePlan: plan,
|
||||
},
|
||||
BatchInfo: c.BatchInfo,
|
||||
StartTime: time.Now().UnixMilli(),
|
||||
CheckPointID: c.CheckPointID,
|
||||
}
|
||||
|
||||
if c.NodeCtx == nil { // node within top level workflow, also not under composite node
|
||||
newC.NodeCtx.NodePath = []string{string(nodeKey)}
|
||||
} else {
|
||||
if c.BatchInfo == nil {
|
||||
newC.NodeCtx.NodePath = append(c.NodeCtx.NodePath, string(nodeKey))
|
||||
} else {
|
||||
newC.NodeCtx.NodePath = append(c.NodeCtx.NodePath, InterruptEventIndexPrefix+strconv.Itoa(c.BatchInfo.Index), string(nodeKey))
|
||||
}
|
||||
}
|
||||
|
||||
tc := c.TokenCollector
|
||||
if entity.NodeMetaByNodeType(nodeType).MayUseChatModel {
|
||||
tc = newTokenCollector(strings.Join(append([]string{string(newC.NodeType)}, newC.NodeCtx.NodePath...), "."), c.TokenCollector)
|
||||
}
|
||||
newC.TokenCollector = tc
|
||||
|
||||
return context.WithValue(ctx, contextKey{}, newC), nil
|
||||
}
|
||||
|
||||
func InheritExeCtxWithBatchInfo(ctx context.Context, index int, items map[string]any) (context.Context, string) {
|
||||
c := GetExeCtx(ctx)
|
||||
if c == nil {
|
||||
return ctx, ""
|
||||
}
|
||||
var newCheckpointID string
|
||||
if len(c.CheckPointID) > 0 {
|
||||
newCheckpointID = c.CheckPointID
|
||||
if c.SubWorkflowCtx != nil {
|
||||
newCheckpointID += "_" + strconv.Itoa(int(c.SubWorkflowCtx.SubExecuteID))
|
||||
}
|
||||
newCheckpointID += "_" + string(c.NodeCtx.NodeKey)
|
||||
newCheckpointID += "_" + strconv.Itoa(index)
|
||||
}
|
||||
return context.WithValue(ctx, contextKey{}, &Context{
|
||||
RootCtx: c.RootCtx,
|
||||
SubWorkflowCtx: c.SubWorkflowCtx,
|
||||
NodeCtx: c.NodeCtx,
|
||||
TokenCollector: c.TokenCollector,
|
||||
BatchInfo: &BatchInfo{
|
||||
Index: index,
|
||||
Items: items,
|
||||
CompositeNodeKey: c.NodeCtx.NodeKey,
|
||||
},
|
||||
CheckPointID: newCheckpointID,
|
||||
}), newCheckpointID
|
||||
}
|
||||
|
||||
type ExeContextStore interface {
|
||||
GetNodeCtx(key vo.NodeKey) (*Context, bool, error)
|
||||
SetNodeCtx(key vo.NodeKey, value *Context) error
|
||||
GetWorkflowCtx() (*Context, bool, error)
|
||||
SetWorkflowCtx(value *Context) error
|
||||
}
|
||||
120
backend/domain/workflow/internal/execute/event.go
Normal file
120
backend/domain/workflow/internal/execute/event.go
Normal file
@@ -0,0 +1,120 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package execute
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
)
|
||||
|
||||
type EventType string
|
||||
|
||||
const (
|
||||
WorkflowStart EventType = "workflow_start"
|
||||
WorkflowSuccess EventType = "workflow_success"
|
||||
WorkflowFailed EventType = "workflow_failed"
|
||||
WorkflowCancel EventType = "workflow_cancel"
|
||||
WorkflowInterrupt EventType = "workflow_interrupt"
|
||||
WorkflowResume EventType = "workflow_resume"
|
||||
NodeStart EventType = "node_start"
|
||||
NodeEnd EventType = "node_end"
|
||||
NodeEndStreaming EventType = "node_end_streaming" // absolutely end, after all streaming content are sent
|
||||
NodeError EventType = "node_error"
|
||||
NodeStreamingInput EventType = "node_streaming_input"
|
||||
NodeStreamingOutput EventType = "node_streaming_output"
|
||||
FunctionCall EventType = "function_call"
|
||||
ToolResponse EventType = "tool_response"
|
||||
ToolStreamingResponse EventType = "tool_streaming_response"
|
||||
ToolError EventType = "tool_error"
|
||||
)
|
||||
|
||||
type Event struct {
|
||||
Type EventType
|
||||
|
||||
*Context
|
||||
|
||||
Duration time.Duration
|
||||
Input map[string]any
|
||||
Output map[string]any
|
||||
|
||||
// if the node is output_emitter or exit node with answer as terminate plan, this field will be set.
|
||||
// it contains the incremental change in the output.
|
||||
Answer string
|
||||
StreamEnd bool
|
||||
|
||||
RawOutput map[string]any
|
||||
|
||||
Err error
|
||||
Token *TokenInfo
|
||||
|
||||
InterruptEvents []*entity.InterruptEvent
|
||||
|
||||
functionCall *entity.FunctionCallInfo
|
||||
toolResponse *entity.ToolResponseInfo
|
||||
|
||||
outputExtractor func(o map[string]any) string
|
||||
extra *entity.NodeExtra
|
||||
|
||||
done chan struct{}
|
||||
|
||||
nodeCount int32
|
||||
}
|
||||
|
||||
type TokenInfo struct {
|
||||
InputToken int64
|
||||
OutputToken int64
|
||||
TotalToken int64
|
||||
}
|
||||
|
||||
func (e *Event) GetInputTokens() int64 {
|
||||
if e.Token == nil {
|
||||
return 0
|
||||
}
|
||||
return e.Token.InputToken
|
||||
}
|
||||
|
||||
func (e *Event) GetOutputTokens() int64 {
|
||||
if e.Token == nil {
|
||||
return 0
|
||||
}
|
||||
return e.Token.OutputToken
|
||||
}
|
||||
|
||||
func (e *Event) GetResumedEventID() int64 {
|
||||
if e.Context == nil {
|
||||
return 0
|
||||
}
|
||||
if e.Context.RootCtx.ResumeEvent == nil {
|
||||
return 0
|
||||
}
|
||||
return e.Context.RootCtx.ResumeEvent.ID
|
||||
}
|
||||
|
||||
func (e *Event) GetFunctionCallInfo() (*entity.FunctionCallInfo, bool) {
|
||||
if e.functionCall == nil {
|
||||
return nil, false
|
||||
}
|
||||
return e.functionCall, true
|
||||
}
|
||||
|
||||
func (e *Event) GetToolResponse() (*entity.ToolResponseInfo, bool) {
|
||||
if e.toolResponse == nil {
|
||||
return nil, false
|
||||
}
|
||||
return e.toolResponse, true
|
||||
}
|
||||
910
backend/domain/workflow/internal/execute/event_handle.go
Normal file
910
backend/domain/workflow/internal/execute/event_handle.go
Normal file
@@ -0,0 +1,910 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package execute
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
func setRootWorkflowSuccess(ctx context.Context, event *Event, repo workflow.Repository,
|
||||
sw *schema.StreamWriter[*entity.Message]) (err error) {
|
||||
exeID := event.RootCtx.RootExecuteID
|
||||
wfExec := &entity.WorkflowExecution{
|
||||
ID: exeID,
|
||||
Duration: event.Duration,
|
||||
Status: entity.WorkflowSuccess,
|
||||
Output: ptr.Of(mustMarshalToString(event.Output)),
|
||||
TokenInfo: &entity.TokenUsage{
|
||||
InputTokens: event.GetInputTokens(),
|
||||
OutputTokens: event.GetOutputTokens(),
|
||||
},
|
||||
}
|
||||
|
||||
var (
|
||||
updatedRows int64
|
||||
currentStatus entity.WorkflowExecuteStatus
|
||||
)
|
||||
if updatedRows, currentStatus, err = repo.UpdateWorkflowExecution(ctx, wfExec, []entity.WorkflowExecuteStatus{entity.WorkflowRunning}); err != nil {
|
||||
return fmt.Errorf("failed to save workflow execution when successful: %v", err)
|
||||
} else if updatedRows == 0 {
|
||||
return fmt.Errorf("failed to update workflow execution to success for execution id %d, current status is %v", exeID, currentStatus)
|
||||
}
|
||||
|
||||
rootWkID := event.RootWorkflowBasic.ID
|
||||
exeCfg := event.ExeCfg
|
||||
if exeCfg.Mode == vo.ExecuteModeDebug {
|
||||
if err := repo.UpdateWorkflowDraftTestRunSuccess(ctx, rootWkID); err != nil {
|
||||
return fmt.Errorf("failed to save workflow draft test run success: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if sw != nil {
|
||||
sw.Send(&entity.Message{
|
||||
StateMessage: &entity.StateMessage{
|
||||
ExecuteID: event.RootExecuteID,
|
||||
EventID: event.GetResumedEventID(),
|
||||
Status: entity.WorkflowSuccess,
|
||||
Usage: ternary.IFElse(event.Token == nil, nil, &entity.TokenUsage{
|
||||
InputTokens: event.GetInputTokens(),
|
||||
OutputTokens: event.GetOutputTokens(),
|
||||
}),
|
||||
},
|
||||
}, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type terminateSignal string
|
||||
|
||||
const (
|
||||
noTerminate terminateSignal = "no_terminate"
|
||||
workflowSuccess terminateSignal = "workflowSuccess"
|
||||
workflowAbort terminateSignal = "workflowAbort"
|
||||
lastNodeDone terminateSignal = "lastNodeDone"
|
||||
)
|
||||
|
||||
func handleEvent(ctx context.Context, event *Event, repo workflow.Repository,
|
||||
sw *schema.StreamWriter[*entity.Message], // when this workflow's caller needs to receive intermediate results
|
||||
) (signal terminateSignal, err error) {
|
||||
switch event.Type {
|
||||
case WorkflowStart:
|
||||
exeID := event.RootCtx.RootExecuteID
|
||||
var parentNodeID *string
|
||||
var parentNodeExecuteID *int64
|
||||
wb := event.RootWorkflowBasic
|
||||
if event.SubWorkflowCtx != nil {
|
||||
exeID = event.SubExecuteID
|
||||
parentNodeID = ptr.Of(string(event.NodeCtx.NodeKey))
|
||||
parentNodeExecuteID = ptr.Of(event.NodeCtx.NodeExecuteID)
|
||||
wb = event.SubWorkflowBasic
|
||||
}
|
||||
|
||||
if parentNodeID != nil { // root workflow execution has already been created
|
||||
var logID string
|
||||
logID, _ = ctx.Value("log-id").(string)
|
||||
|
||||
wfExec := &entity.WorkflowExecution{
|
||||
ID: exeID,
|
||||
WorkflowID: wb.ID,
|
||||
Version: wb.Version,
|
||||
SpaceID: wb.SpaceID,
|
||||
ExecuteConfig: event.ExeCfg,
|
||||
Status: entity.WorkflowRunning,
|
||||
Input: ptr.Of(mustMarshalToString(event.Input)),
|
||||
RootExecutionID: event.RootExecuteID,
|
||||
ParentNodeID: parentNodeID,
|
||||
ParentNodeExecuteID: parentNodeExecuteID,
|
||||
NodeCount: event.nodeCount,
|
||||
CommitID: wb.CommitID,
|
||||
LogID: logID,
|
||||
}
|
||||
|
||||
if err = repo.CreateWorkflowExecution(ctx, wfExec); err != nil {
|
||||
return noTerminate, fmt.Errorf("failed to create workflow execution: %v", err)
|
||||
}
|
||||
|
||||
nodeExec := &entity.NodeExecution{
|
||||
ID: event.NodeExecuteID,
|
||||
SubWorkflowExecution: &entity.WorkflowExecution{
|
||||
ID: exeID,
|
||||
},
|
||||
}
|
||||
if err = repo.UpdateNodeExecution(ctx, nodeExec); err != nil {
|
||||
return noTerminate, fmt.Errorf("failed to update subworkflow node execution with subExecuteID: %v", err)
|
||||
}
|
||||
} else if sw != nil {
|
||||
sw.Send(&entity.Message{
|
||||
StateMessage: &entity.StateMessage{
|
||||
ExecuteID: event.RootExecuteID,
|
||||
EventID: event.GetResumedEventID(),
|
||||
SpaceID: event.Context.RootCtx.RootWorkflowBasic.SpaceID,
|
||||
Status: entity.WorkflowRunning,
|
||||
},
|
||||
}, nil)
|
||||
}
|
||||
|
||||
if len(wb.Version) == 0 {
|
||||
if err = repo.CreateSnapshotIfNeeded(ctx, wb.ID, wb.CommitID); err != nil {
|
||||
return noTerminate, fmt.Errorf("failed to create snapshot: %v", err)
|
||||
}
|
||||
}
|
||||
case WorkflowSuccess:
|
||||
// sub workflow, no need to wait for exit node to be done
|
||||
if event.SubWorkflowCtx != nil {
|
||||
exeID := event.RootCtx.RootExecuteID
|
||||
if event.SubWorkflowCtx != nil {
|
||||
exeID = event.SubExecuteID
|
||||
}
|
||||
wfExec := &entity.WorkflowExecution{
|
||||
ID: exeID,
|
||||
Duration: event.Duration,
|
||||
Status: entity.WorkflowSuccess,
|
||||
Output: ptr.Of(mustMarshalToString(event.Output)),
|
||||
TokenInfo: &entity.TokenUsage{
|
||||
InputTokens: event.GetInputTokens(),
|
||||
OutputTokens: event.GetOutputTokens(),
|
||||
},
|
||||
}
|
||||
|
||||
var (
|
||||
updatedRows int64
|
||||
currentStatus entity.WorkflowExecuteStatus
|
||||
)
|
||||
if updatedRows, currentStatus, err = repo.UpdateWorkflowExecution(ctx, wfExec, []entity.WorkflowExecuteStatus{entity.WorkflowRunning}); err != nil {
|
||||
return noTerminate, fmt.Errorf("failed to save workflow execution when successful: %v", err)
|
||||
} else if updatedRows == 0 {
|
||||
return noTerminate, fmt.Errorf("failed to update workflow execution to success for execution id %d, current status is %v", exeID, currentStatus)
|
||||
}
|
||||
|
||||
return noTerminate, nil
|
||||
}
|
||||
|
||||
return workflowSuccess, nil
|
||||
case WorkflowFailed:
|
||||
exeID := event.RootCtx.RootExecuteID
|
||||
wfID := event.RootCtx.RootWorkflowBasic.ID
|
||||
if event.SubWorkflowCtx != nil {
|
||||
exeID = event.SubExecuteID
|
||||
wfID = event.SubWorkflowBasic.ID
|
||||
}
|
||||
|
||||
logs.CtxErrorf(ctx, "workflow execution failed: %v", event.Err)
|
||||
|
||||
wfExec := &entity.WorkflowExecution{
|
||||
ID: exeID,
|
||||
Duration: event.Duration,
|
||||
Status: entity.WorkflowFailed,
|
||||
TokenInfo: &entity.TokenUsage{
|
||||
InputTokens: event.GetInputTokens(),
|
||||
OutputTokens: event.GetOutputTokens(),
|
||||
},
|
||||
}
|
||||
|
||||
var wfe vo.WorkflowError
|
||||
if !errors.As(event.Err, &wfe) {
|
||||
if errors.Is(event.Err, context.DeadlineExceeded) {
|
||||
wfe = vo.WorkflowTimeoutErr
|
||||
} else if errors.Is(event.Err, context.Canceled) {
|
||||
wfe = vo.CancelErr
|
||||
} else {
|
||||
wfe = vo.WrapError(errno.ErrWorkflowExecuteFail, event.Err, errorx.KV("cause", vo.UnwrapRootErr(event.Err).Error()))
|
||||
}
|
||||
}
|
||||
|
||||
if cause := errors.Unwrap(event.Err); cause != nil {
|
||||
logs.CtxErrorf(ctx, "workflow %d for exeID %d returns err: %v, cause: %v",
|
||||
wfID, exeID, event.Err, cause)
|
||||
} else {
|
||||
logs.CtxErrorf(ctx, "workflow %d for exeID %d returns err: %v",
|
||||
wfID, exeID, event.Err)
|
||||
}
|
||||
|
||||
errMsg := wfe.Msg()[:min(1000, len(wfe.Msg()))]
|
||||
wfExec.ErrorCode = ptr.Of(strconv.Itoa(int(wfe.Code())))
|
||||
wfExec.FailReason = ptr.Of(errMsg)
|
||||
|
||||
var (
|
||||
updatedRows int64
|
||||
currentStatus entity.WorkflowExecuteStatus
|
||||
)
|
||||
if updatedRows, currentStatus, err = repo.UpdateWorkflowExecution(ctx, wfExec, []entity.WorkflowExecuteStatus{entity.WorkflowRunning}); err != nil {
|
||||
return noTerminate, fmt.Errorf("failed to save workflow execution when failed: %v", err)
|
||||
} else if updatedRows == 0 {
|
||||
return noTerminate, fmt.Errorf("failed to update workflow execution to failed for execution id %d, current status is %v", exeID, currentStatus)
|
||||
}
|
||||
|
||||
if event.SubWorkflowCtx == nil {
|
||||
if sw != nil {
|
||||
sw.Send(&entity.Message{
|
||||
StateMessage: &entity.StateMessage{
|
||||
ExecuteID: event.RootExecuteID,
|
||||
EventID: event.GetResumedEventID(),
|
||||
Status: entity.WorkflowFailed,
|
||||
Usage: wfExec.TokenInfo,
|
||||
LastError: wfe,
|
||||
},
|
||||
}, nil)
|
||||
}
|
||||
return workflowAbort, nil
|
||||
}
|
||||
case WorkflowInterrupt:
|
||||
if event.done != nil {
|
||||
defer close(event.done)
|
||||
}
|
||||
|
||||
exeID := event.RootCtx.RootExecuteID
|
||||
if event.SubWorkflowCtx != nil {
|
||||
exeID = event.SubExecuteID
|
||||
}
|
||||
wfExec := &entity.WorkflowExecution{
|
||||
ID: exeID,
|
||||
Status: entity.WorkflowInterrupted,
|
||||
}
|
||||
|
||||
var (
|
||||
updatedRows int64
|
||||
currentStatus entity.WorkflowExecuteStatus
|
||||
)
|
||||
if updatedRows, currentStatus, err = repo.UpdateWorkflowExecution(ctx, wfExec, []entity.WorkflowExecuteStatus{entity.WorkflowRunning}); err != nil {
|
||||
return noTerminate, fmt.Errorf("failed to save workflow execution when interrupted: %v", err)
|
||||
} else if updatedRows == 0 {
|
||||
return noTerminate, fmt.Errorf("failed to update workflow execution to interrupted for execution id %d, current status is %v", exeID, currentStatus)
|
||||
}
|
||||
|
||||
if event.RootCtx.ResumeEvent != nil {
|
||||
needPop := false
|
||||
for _, ie := range event.InterruptEvents {
|
||||
if ie.NodeKey == event.RootCtx.ResumeEvent.NodeKey {
|
||||
if reflect.DeepEqual(ie.NodePath, event.RootCtx.ResumeEvent.NodePath) {
|
||||
needPop = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if needPop {
|
||||
// the current resuming node emits an interrupt event again
|
||||
// need to remove the previous interrupt event because the node is not 'END', but 'Error',
|
||||
// so it didn't remove the previous interrupt OnEnd
|
||||
deletedEvent, deleted, err := repo.PopFirstInterruptEvent(ctx, exeID)
|
||||
if err != nil {
|
||||
return noTerminate, err
|
||||
}
|
||||
|
||||
if !deleted {
|
||||
return noTerminate, fmt.Errorf("interrupt events does not exist, wfExeID: %d", exeID)
|
||||
}
|
||||
|
||||
if deletedEvent.ID != event.RootCtx.ResumeEvent.ID {
|
||||
return noTerminate, fmt.Errorf("interrupt event id mismatch when deleting, expect: %d, actual: %d",
|
||||
event.RootCtx.ResumeEvent.ID, deletedEvent.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: there maybe time gap here
|
||||
|
||||
if err := repo.SaveInterruptEvents(ctx, event.RootExecuteID, event.InterruptEvents); err != nil {
|
||||
return noTerminate, fmt.Errorf("failed to save interrupt events: %v", err)
|
||||
}
|
||||
|
||||
if sw != nil && event.SubWorkflowCtx == nil { // only send interrupt event when is root workflow
|
||||
firstIE, found, err := repo.GetFirstInterruptEvent(ctx, event.RootExecuteID)
|
||||
if err != nil {
|
||||
return noTerminate, fmt.Errorf("failed to get first interrupt event: %v", err)
|
||||
}
|
||||
|
||||
if !found {
|
||||
return noTerminate, fmt.Errorf("interrupt event does not exist, wfExeID: %d", event.RootExecuteID)
|
||||
}
|
||||
|
||||
nodeKey := firstIE.NodeKey
|
||||
|
||||
sw.Send(&entity.Message{
|
||||
DataMessage: &entity.DataMessage{
|
||||
ExecuteID: event.RootExecuteID,
|
||||
Role: schema.Assistant,
|
||||
Type: entity.Answer,
|
||||
Content: firstIE.InterruptData, // TODO: may need to extract from InterruptData the actual info for user
|
||||
NodeID: string(nodeKey),
|
||||
NodeType: firstIE.NodeType,
|
||||
NodeTitle: firstIE.NodeTitle,
|
||||
Last: true,
|
||||
},
|
||||
}, nil)
|
||||
|
||||
sw.Send(&entity.Message{
|
||||
StateMessage: &entity.StateMessage{
|
||||
ExecuteID: event.RootExecuteID,
|
||||
EventID: event.GetResumedEventID(),
|
||||
Status: entity.WorkflowInterrupted,
|
||||
InterruptEvent: firstIE,
|
||||
},
|
||||
}, nil)
|
||||
}
|
||||
|
||||
return workflowAbort, nil
|
||||
case WorkflowCancel:
|
||||
exeID := event.RootCtx.RootExecuteID
|
||||
if event.SubWorkflowCtx != nil {
|
||||
exeID = event.SubExecuteID
|
||||
}
|
||||
wfExec := &entity.WorkflowExecution{
|
||||
ID: exeID,
|
||||
Duration: event.Duration,
|
||||
Status: entity.WorkflowCancel,
|
||||
TokenInfo: &entity.TokenUsage{
|
||||
InputTokens: event.GetInputTokens(),
|
||||
OutputTokens: event.GetOutputTokens(),
|
||||
},
|
||||
}
|
||||
|
||||
var (
|
||||
updatedRows int64
|
||||
currentStatus entity.WorkflowExecuteStatus
|
||||
)
|
||||
|
||||
if err = repo.CancelAllRunningNodes(ctx, exeID); err != nil {
|
||||
logs.CtxErrorf(ctx, err.Error())
|
||||
}
|
||||
|
||||
if updatedRows, currentStatus, err = repo.UpdateWorkflowExecution(ctx, wfExec, []entity.WorkflowExecuteStatus{entity.WorkflowRunning,
|
||||
entity.WorkflowInterrupted}); err != nil {
|
||||
return noTerminate, fmt.Errorf("failed to save workflow execution when canceled: %v", err)
|
||||
} else if updatedRows == 0 {
|
||||
return noTerminate, fmt.Errorf("failed to update workflow execution to canceled for execution id %d, current status is %v", exeID, currentStatus)
|
||||
}
|
||||
|
||||
if event.SubWorkflowCtx == nil {
|
||||
if sw != nil {
|
||||
sw.Send(&entity.Message{
|
||||
StateMessage: &entity.StateMessage{
|
||||
ExecuteID: event.RootExecuteID,
|
||||
EventID: event.GetResumedEventID(),
|
||||
Status: entity.WorkflowCancel,
|
||||
Usage: wfExec.TokenInfo,
|
||||
LastError: vo.CancelErr,
|
||||
},
|
||||
}, nil)
|
||||
}
|
||||
return workflowAbort, nil
|
||||
}
|
||||
case WorkflowResume:
|
||||
if sw == nil || event.SubWorkflowCtx != nil {
|
||||
return noTerminate, nil
|
||||
}
|
||||
|
||||
sw.Send(&entity.Message{
|
||||
StateMessage: &entity.StateMessage{
|
||||
ExecuteID: event.RootExecuteID,
|
||||
EventID: event.GetResumedEventID(),
|
||||
SpaceID: event.RootWorkflowBasic.SpaceID,
|
||||
Status: entity.WorkflowRunning,
|
||||
},
|
||||
}, nil)
|
||||
case NodeStart:
|
||||
if event.Context == nil {
|
||||
panic("nil event context")
|
||||
}
|
||||
|
||||
wfExeID := event.RootCtx.RootExecuteID
|
||||
if event.SubWorkflowCtx != nil {
|
||||
wfExeID = event.SubExecuteID
|
||||
}
|
||||
nodeExec := &entity.NodeExecution{
|
||||
ID: event.NodeExecuteID,
|
||||
ExecuteID: wfExeID,
|
||||
NodeID: string(event.NodeKey),
|
||||
NodeName: event.NodeName,
|
||||
NodeType: event.NodeType,
|
||||
Status: entity.NodeRunning,
|
||||
Input: ptr.Of(mustMarshalToString(event.Input)),
|
||||
Extra: event.extra,
|
||||
}
|
||||
if event.BatchInfo != nil {
|
||||
nodeExec.Index = event.BatchInfo.Index
|
||||
nodeExec.Items = ptr.Of(mustMarshalToString(event.BatchInfo.Items))
|
||||
nodeExec.ParentNodeID = ptr.Of(string(event.BatchInfo.CompositeNodeKey))
|
||||
}
|
||||
if err = repo.CreateNodeExecution(ctx, nodeExec); err != nil {
|
||||
return noTerminate, fmt.Errorf("failed to create node execution: %v", err)
|
||||
}
|
||||
case NodeEnd, NodeEndStreaming:
|
||||
nodeExec := &entity.NodeExecution{
|
||||
ID: event.NodeExecuteID,
|
||||
Status: entity.NodeSuccess,
|
||||
Duration: event.Duration,
|
||||
TokenInfo: &entity.TokenUsage{
|
||||
InputTokens: event.GetInputTokens(),
|
||||
OutputTokens: event.GetOutputTokens(),
|
||||
},
|
||||
Extra: event.extra,
|
||||
}
|
||||
|
||||
if event.Err != nil {
|
||||
var wfe vo.WorkflowError
|
||||
if !errors.As(event.Err, &wfe) {
|
||||
panic("node end: event.Err is not a WorkflowError")
|
||||
}
|
||||
|
||||
if cause := errors.Unwrap(event.Err); cause != nil {
|
||||
logs.CtxWarnf(ctx, "node %s for exeID %d end with warning: %v, cause: %v",
|
||||
event.NodeKey, event.NodeExecuteID, event.Err, cause)
|
||||
} else {
|
||||
logs.CtxWarnf(ctx, "node %s for exeID %d end with warning: %v",
|
||||
event.NodeKey, event.NodeExecuteID, event.Err)
|
||||
}
|
||||
nodeExec.ErrorInfo = ptr.Of(wfe.Msg())
|
||||
nodeExec.ErrorLevel = ptr.Of(string(wfe.Level()))
|
||||
}
|
||||
|
||||
if event.outputExtractor != nil {
|
||||
nodeExec.Output = ptr.Of(event.outputExtractor(event.Output))
|
||||
nodeExec.RawOutput = ptr.Of(event.outputExtractor(event.RawOutput))
|
||||
} else {
|
||||
nodeExec.Output = ptr.Of(mustMarshalToString(event.Output))
|
||||
nodeExec.RawOutput = ptr.Of(mustMarshalToString(event.RawOutput))
|
||||
}
|
||||
|
||||
fcInfos := getFCInfos(ctx, event.NodeKey)
|
||||
if len(fcInfos) > 0 {
|
||||
if nodeExec.Extra.ResponseExtra == nil {
|
||||
nodeExec.Extra.ResponseExtra = map[string]any{}
|
||||
}
|
||||
nodeExec.Extra.ResponseExtra["fc_called_detail"] = &entity.FCCalledDetail{
|
||||
FCCalledList: make([]*entity.FCCalled, 0, len(fcInfos)),
|
||||
}
|
||||
for _, fcInfo := range fcInfos {
|
||||
nodeExec.Extra.ResponseExtra["fc_called_detail"].(*entity.FCCalledDetail).FCCalledList = append(nodeExec.Extra.ResponseExtra["fc_called_detail"].(*entity.FCCalledDetail).FCCalledList, &entity.FCCalled{
|
||||
Input: fcInfo.inputString(),
|
||||
Output: fcInfo.outputString(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if event.Input != nil {
|
||||
nodeExec.Input = ptr.Of(mustMarshalToString(event.Input))
|
||||
}
|
||||
|
||||
if event.NodeCtx.ResumingEvent != nil {
|
||||
firstIE, found, err := repo.GetFirstInterruptEvent(ctx, event.RootCtx.RootExecuteID)
|
||||
if err != nil {
|
||||
return noTerminate, err
|
||||
}
|
||||
|
||||
if found && firstIE.ID == event.NodeCtx.ResumingEvent.ID {
|
||||
deletedEvent, deleted, err := repo.PopFirstInterruptEvent(ctx, event.RootCtx.RootExecuteID)
|
||||
if err != nil {
|
||||
return noTerminate, err
|
||||
}
|
||||
|
||||
if !deleted {
|
||||
return noTerminate, fmt.Errorf("node end: interrupt events does not exist, wfExeID: %d", event.RootCtx.RootExecuteID)
|
||||
}
|
||||
|
||||
if deletedEvent.ID != event.NodeCtx.ResumingEvent.ID {
|
||||
return noTerminate, fmt.Errorf("interrupt event id mismatch when deleting, expect: %d, actual: %d",
|
||||
event.RootCtx.ResumeEvent.ID, deletedEvent.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if event.NodeCtx.SubWorkflowExeID > 0 {
|
||||
nodeExec.SubWorkflowExecution = &entity.WorkflowExecution{
|
||||
ID: event.NodeCtx.SubWorkflowExeID,
|
||||
}
|
||||
}
|
||||
|
||||
if err = repo.UpdateNodeExecution(ctx, nodeExec); err != nil {
|
||||
return noTerminate, fmt.Errorf("failed to save node execution: %v", err)
|
||||
}
|
||||
|
||||
if sw != nil && event.Type == NodeEnd {
|
||||
var content string
|
||||
switch event.NodeType {
|
||||
case entity.NodeTypeOutputEmitter:
|
||||
content = event.Answer
|
||||
case entity.NodeTypeExit:
|
||||
if event.Context.SubWorkflowCtx != nil {
|
||||
// if the exit node belongs to a sub workflow, do not send data message
|
||||
return noTerminate, nil
|
||||
}
|
||||
|
||||
if *event.Context.NodeCtx.TerminatePlan == vo.ReturnVariables {
|
||||
content = mustMarshalToString(event.Output)
|
||||
} else {
|
||||
content = event.Answer
|
||||
}
|
||||
default:
|
||||
return noTerminate, nil
|
||||
}
|
||||
|
||||
sw.Send(&entity.Message{
|
||||
DataMessage: &entity.DataMessage{
|
||||
ExecuteID: event.RootExecuteID,
|
||||
Role: schema.Assistant,
|
||||
Type: entity.Answer,
|
||||
Content: content,
|
||||
NodeID: string(event.NodeKey),
|
||||
NodeType: event.NodeType,
|
||||
NodeTitle: event.NodeName,
|
||||
Last: true,
|
||||
Usage: ternary.IFElse(event.Token == nil, nil, &entity.TokenUsage{
|
||||
InputTokens: event.GetInputTokens(),
|
||||
OutputTokens: event.GetOutputTokens(),
|
||||
}),
|
||||
},
|
||||
}, nil)
|
||||
}
|
||||
|
||||
if event.NodeType == entity.NodeTypeExit && event.SubWorkflowCtx == nil {
|
||||
return lastNodeDone, nil
|
||||
}
|
||||
case NodeStreamingOutput:
|
||||
nodeExec := &entity.NodeExecution{
|
||||
ID: event.NodeExecuteID,
|
||||
Extra: event.extra,
|
||||
}
|
||||
|
||||
if event.outputExtractor != nil {
|
||||
nodeExec.Output = ptr.Of(event.outputExtractor(event.Output))
|
||||
} else {
|
||||
nodeExec.Output = ptr.Of(mustMarshalToString(event.Output))
|
||||
}
|
||||
|
||||
if err = repo.UpdateNodeExecutionStreaming(ctx, nodeExec); err != nil {
|
||||
return noTerminate, fmt.Errorf("failed to save node execution: %v", err)
|
||||
}
|
||||
|
||||
if sw == nil {
|
||||
return noTerminate, nil
|
||||
}
|
||||
|
||||
if event.NodeType == entity.NodeTypeExit {
|
||||
if event.Context.SubWorkflowCtx != nil {
|
||||
return noTerminate, nil
|
||||
}
|
||||
} else if event.NodeType == entity.NodeTypeVariableAggregator {
|
||||
return noTerminate, nil
|
||||
}
|
||||
|
||||
sw.Send(&entity.Message{
|
||||
DataMessage: &entity.DataMessage{
|
||||
ExecuteID: event.RootExecuteID,
|
||||
Role: schema.Assistant,
|
||||
Type: entity.Answer,
|
||||
Content: event.Answer,
|
||||
NodeID: string(event.NodeKey),
|
||||
NodeType: event.NodeType,
|
||||
NodeTitle: event.NodeName,
|
||||
Last: event.StreamEnd,
|
||||
},
|
||||
}, nil)
|
||||
case NodeStreamingInput:
|
||||
nodeExec := &entity.NodeExecution{
|
||||
ID: event.NodeExecuteID,
|
||||
Input: ptr.Of(mustMarshalToString(event.Input)),
|
||||
}
|
||||
if err = repo.UpdateNodeExecution(ctx, nodeExec); err != nil {
|
||||
return noTerminate, fmt.Errorf("failed to save node execution: %v", err)
|
||||
}
|
||||
|
||||
case NodeError:
|
||||
var errorInfo, errorLevel string
|
||||
var wfe vo.WorkflowError
|
||||
if !errors.As(event.Err, &wfe) {
|
||||
if errors.Is(event.Err, context.DeadlineExceeded) {
|
||||
wfe = vo.NodeTimeoutErr
|
||||
} else if errors.Is(event.Err, context.Canceled) {
|
||||
wfe = vo.CancelErr
|
||||
} else {
|
||||
wfe = vo.WrapError(errno.ErrWorkflowExecuteFail, event.Err, errorx.KV("cause", vo.UnwrapRootErr(event.Err).Error()))
|
||||
}
|
||||
}
|
||||
|
||||
if cause := errors.Unwrap(event.Err); cause != nil {
|
||||
logs.CtxErrorf(ctx, "node %s for exeID %d returns err: %v, cause: %v",
|
||||
event.NodeKey, event.NodeExecuteID, event.Err, cause)
|
||||
} else {
|
||||
logs.CtxErrorf(ctx, "node %s for exeID %d returns err: %v",
|
||||
event.NodeKey, event.NodeExecuteID, event.Err)
|
||||
}
|
||||
|
||||
errorInfo = wfe.Msg()[:min(1000, len(wfe.Msg()))]
|
||||
errorLevel = string(wfe.Level())
|
||||
|
||||
if event.Context == nil || event.Context.NodeCtx == nil {
|
||||
return noTerminate, fmt.Errorf("nil event context")
|
||||
}
|
||||
|
||||
nodeExec := &entity.NodeExecution{
|
||||
ID: event.NodeExecuteID,
|
||||
Status: entity.NodeFailed,
|
||||
ErrorInfo: ptr.Of(errorInfo),
|
||||
ErrorLevel: ptr.Of(errorLevel),
|
||||
Duration: event.Duration,
|
||||
TokenInfo: &entity.TokenUsage{
|
||||
InputTokens: event.GetInputTokens(),
|
||||
OutputTokens: event.GetOutputTokens(),
|
||||
},
|
||||
}
|
||||
if err = repo.UpdateNodeExecution(ctx, nodeExec); err != nil {
|
||||
return noTerminate, fmt.Errorf("failed to save node execution: %v", err)
|
||||
}
|
||||
case FunctionCall:
|
||||
cacheFunctionCall(ctx, event)
|
||||
if sw == nil {
|
||||
return noTerminate, nil
|
||||
}
|
||||
sw.Send(&entity.Message{
|
||||
DataMessage: &entity.DataMessage{
|
||||
ExecuteID: event.RootExecuteID,
|
||||
Role: schema.Assistant,
|
||||
Type: entity.FunctionCall,
|
||||
FunctionCall: event.functionCall,
|
||||
},
|
||||
}, nil)
|
||||
case ToolResponse:
|
||||
cacheToolResponse(ctx, event)
|
||||
if sw == nil {
|
||||
return noTerminate, nil
|
||||
}
|
||||
sw.Send(&entity.Message{
|
||||
DataMessage: &entity.DataMessage{
|
||||
ExecuteID: event.RootExecuteID,
|
||||
Role: schema.Tool,
|
||||
Type: entity.ToolResponse,
|
||||
Last: true,
|
||||
ToolResponse: event.toolResponse,
|
||||
},
|
||||
}, nil)
|
||||
case ToolStreamingResponse:
|
||||
cacheToolStreamingResponse(ctx, event)
|
||||
if sw == nil {
|
||||
return noTerminate, nil
|
||||
}
|
||||
sw.Send(&entity.Message{
|
||||
DataMessage: &entity.DataMessage{
|
||||
ExecuteID: event.RootExecuteID,
|
||||
Role: schema.Tool,
|
||||
Type: entity.ToolResponse,
|
||||
Last: event.StreamEnd,
|
||||
ToolResponse: event.toolResponse,
|
||||
},
|
||||
}, nil)
|
||||
case ToolError:
|
||||
// TODO: optimize this log
|
||||
logs.CtxErrorf(ctx, "received tool error event: %v", event)
|
||||
default:
|
||||
panic("unimplemented event type: " + event.Type)
|
||||
}
|
||||
|
||||
return noTerminate, nil
|
||||
}
|
||||
|
||||
type fcCacheKey struct{}
|
||||
type fcInfo struct {
|
||||
input *entity.FunctionCallInfo
|
||||
output *entity.ToolResponseInfo
|
||||
}
|
||||
|
||||
func HandleExecuteEvent(ctx context.Context,
|
||||
wfExeID int64,
|
||||
eventChan <-chan *Event,
|
||||
cancelFn context.CancelFunc,
|
||||
timeoutFn context.CancelFunc,
|
||||
repo workflow.Repository,
|
||||
sw *schema.StreamWriter[*entity.Message],
|
||||
exeCfg vo.ExecuteConfig,
|
||||
) (event *Event) {
|
||||
var (
|
||||
wfSuccessEvent *Event
|
||||
lastNodeIsDone bool
|
||||
cancelled bool
|
||||
)
|
||||
|
||||
ctx = context.WithValue(ctx, fcCacheKey{}, make(map[vo.NodeKey]map[string]*fcInfo))
|
||||
|
||||
handler := func(event *Event) *Event {
|
||||
var (
|
||||
nodeType entity.NodeType
|
||||
nodeKey vo.NodeKey
|
||||
)
|
||||
if event.Context.NodeCtx != nil {
|
||||
nodeType = event.Context.NodeCtx.NodeType
|
||||
nodeKey = event.Context.NodeCtx.NodeKey
|
||||
}
|
||||
|
||||
logs.CtxInfof(ctx, "receiving event type= %v, workflowID= %v, nodeType= %v, nodeKey = %s",
|
||||
event.Type, event.RootWorkflowBasic.ID, nodeType, nodeKey)
|
||||
|
||||
signal, err := handleEvent(ctx, event, repo, sw)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to handle event: %v", err)
|
||||
}
|
||||
|
||||
switch signal {
|
||||
case noTerminate:
|
||||
// continue to next event
|
||||
case workflowAbort:
|
||||
return event
|
||||
case workflowSuccess: // workflow success, wait for exit node to be done
|
||||
wfSuccessEvent = event
|
||||
if lastNodeIsDone || exeCfg.Mode == vo.ExecuteModeNodeDebug {
|
||||
if err = setRootWorkflowSuccess(ctx, wfSuccessEvent, repo, sw); err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to set root workflow success for workflow %d: %v",
|
||||
wfSuccessEvent.RootWorkflowBasic.ID, err)
|
||||
}
|
||||
return wfSuccessEvent
|
||||
}
|
||||
case lastNodeDone: // exit node done, wait for workflow success
|
||||
lastNodeIsDone = true
|
||||
if wfSuccessEvent != nil {
|
||||
if err = setRootWorkflowSuccess(ctx, wfSuccessEvent, repo, sw); err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to set root workflow success: %v", err)
|
||||
}
|
||||
return wfSuccessEvent
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if exeCfg.Cancellable {
|
||||
// Add cancellation check timer
|
||||
cancelTicker := time.NewTicker(cancelCheckInterval)
|
||||
defer func() {
|
||||
logs.CtxInfof(ctx, "[handleExecuteEvent] finish, returned event type: %v, workflow id: %d",
|
||||
event.Type, event.Context.RootWorkflowBasic.ID)
|
||||
cancelTicker.Stop() // Clean up timer
|
||||
if timeoutFn != nil {
|
||||
timeoutFn()
|
||||
}
|
||||
cancelFn()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-cancelTicker.C:
|
||||
if cancelled {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check cancellation status in Redis
|
||||
isCancelled, err := repo.GetWorkflowCancelFlag(ctx, wfExeID)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to check cancellation status for workflow %d: %v", wfExeID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if isCancelled {
|
||||
cancelled = true
|
||||
logs.CtxInfof(ctx, "workflow %d cancellation detected", wfExeID)
|
||||
cancelFn()
|
||||
}
|
||||
case event = <-eventChan:
|
||||
if terminalE := handler(event); terminalE != nil {
|
||||
return terminalE
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
defer func() {
|
||||
logs.CtxInfof(ctx, "[handleExecuteEvent] finish, returned event type: %v, workflow id: %d",
|
||||
event.Type, event.Context.RootWorkflowBasic.ID)
|
||||
if timeoutFn != nil {
|
||||
timeoutFn()
|
||||
}
|
||||
cancelFn()
|
||||
}()
|
||||
|
||||
for e := range eventChan {
|
||||
if terminalE := handler(e); terminalE != nil {
|
||||
return terminalE
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
func mustMarshalToString[T any](m map[string]T) string {
|
||||
if len(m) == 0 {
|
||||
return "{}"
|
||||
}
|
||||
|
||||
b, err := sonic.ConfigStd.MarshalToString(m) // keep the order of the keys
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func cacheFunctionCall(ctx context.Context, event *Event) {
|
||||
c := ctx.Value(fcCacheKey{}).(map[vo.NodeKey]map[string]*fcInfo)
|
||||
if _, ok := c[event.NodeKey]; !ok {
|
||||
c[event.NodeKey] = make(map[string]*fcInfo)
|
||||
}
|
||||
c[event.NodeKey][event.functionCall.CallID] = &fcInfo{
|
||||
input: event.functionCall,
|
||||
}
|
||||
}
|
||||
|
||||
func cacheToolResponse(ctx context.Context, event *Event) {
|
||||
c := ctx.Value(fcCacheKey{}).(map[vo.NodeKey]map[string]*fcInfo)
|
||||
if _, ok := c[event.NodeKey]; !ok {
|
||||
c[event.NodeKey] = make(map[string]*fcInfo)
|
||||
}
|
||||
|
||||
c[event.NodeKey][event.toolResponse.CallID].output = event.toolResponse
|
||||
}
|
||||
|
||||
func cacheToolStreamingResponse(ctx context.Context, event *Event) {
|
||||
c := ctx.Value(fcCacheKey{}).(map[vo.NodeKey]map[string]*fcInfo)
|
||||
if _, ok := c[event.NodeKey]; !ok {
|
||||
c[event.NodeKey] = make(map[string]*fcInfo)
|
||||
}
|
||||
if c[event.NodeKey][event.toolResponse.CallID].output == nil {
|
||||
c[event.NodeKey][event.toolResponse.CallID].output = event.toolResponse
|
||||
}
|
||||
c[event.NodeKey][event.toolResponse.CallID].output.Response += event.toolResponse.Response
|
||||
}
|
||||
|
||||
func getFCInfos(ctx context.Context, nodeKey vo.NodeKey) map[string]*fcInfo {
|
||||
c := ctx.Value(fcCacheKey{}).(map[vo.NodeKey]map[string]*fcInfo)
|
||||
return c[nodeKey]
|
||||
}
|
||||
|
||||
func (f *fcInfo) inputString() string {
|
||||
m, err := sonic.MarshalString(f.input)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (f *fcInfo) outputString() string {
|
||||
if f.output == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"data": f.output.Response, // TODO: traceID, code, message?
|
||||
}
|
||||
b, err := sonic.MarshalString(m)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return b
|
||||
}
|
||||
77
backend/domain/workflow/internal/execute/tool_option.go
Normal file
77
backend/domain/workflow/internal/execute/tool_option.go
Normal file
@@ -0,0 +1,77 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package execute
|
||||
|
||||
import (
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
type workflowToolOption struct {
|
||||
resumeReq *entity.ResumeRequest
|
||||
sw *schema.StreamWriter[*entity.Message]
|
||||
exeCfg vo.ExecuteConfig
|
||||
allInterruptEvents map[string]*entity.ToolInterruptEvent
|
||||
parentTokenCollector *TokenCollector
|
||||
}
|
||||
|
||||
func WithResume(req *entity.ResumeRequest, all map[string]*entity.ToolInterruptEvent) tool.Option {
|
||||
return tool.WrapImplSpecificOptFn(func(opts *workflowToolOption) {
|
||||
opts.resumeReq = req
|
||||
opts.allInterruptEvents = all
|
||||
})
|
||||
}
|
||||
|
||||
func WithIntermediateStreamWriter(sw *schema.StreamWriter[*entity.Message]) tool.Option {
|
||||
return tool.WrapImplSpecificOptFn(func(opts *workflowToolOption) {
|
||||
opts.sw = sw
|
||||
})
|
||||
}
|
||||
|
||||
func WithExecuteConfig(cfg vo.ExecuteConfig) tool.Option {
|
||||
return tool.WrapImplSpecificOptFn(func(opts *workflowToolOption) {
|
||||
opts.exeCfg = cfg
|
||||
})
|
||||
}
|
||||
|
||||
func GetResumeRequest(opts ...tool.Option) (*entity.ResumeRequest, map[string]*entity.ToolInterruptEvent) {
|
||||
opt := tool.GetImplSpecificOptions(&workflowToolOption{}, opts...)
|
||||
return opt.resumeReq, opt.allInterruptEvents
|
||||
}
|
||||
|
||||
func GetIntermediateStreamWriter(opts ...tool.Option) *schema.StreamWriter[*entity.Message] {
|
||||
opt := tool.GetImplSpecificOptions(&workflowToolOption{}, opts...)
|
||||
return opt.sw
|
||||
}
|
||||
|
||||
func GetExecuteConfig(opts ...tool.Option) vo.ExecuteConfig {
|
||||
opt := tool.GetImplSpecificOptions(&workflowToolOption{}, opts...)
|
||||
return opt.exeCfg
|
||||
}
|
||||
|
||||
// WithMessagePipe returns an Option which is meant to be passed to the tool workflow, as well as a StreamReader to read the messages from the tool workflow.
|
||||
// This Option will apply to ALL workflow tools to be executed by eino's ToolsNode. The workflow tools will emit messages to this stream.
|
||||
// The caller can receive from the returned StreamReader to get the messages from the tool workflow.
|
||||
func WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message]) {
|
||||
sr, sw := schema.Pipe[*entity.Message](10)
|
||||
opt := compose.WithToolsNodeOption(compose.WithToolOption(WithIntermediateStreamWriter(sw)))
|
||||
return opt, sr
|
||||
}
|
||||
Reference in New Issue
Block a user