coze-studio/backend/domain/workflow/internal/execute/callback.go

1409 lines
36 KiB
Go

/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package execute
import (
"context"
"errors"
"fmt"
"io"
"reflect"
"slices"
"strconv"
"strings"
"time"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
callbacks2 "github.com/cloudwego/eino/utils/callbacks"
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
workflow2 "github.com/coze-dev/coze-studio/backend/api/model/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
)
// NodeHandler is the callbacks.Handler for a single workflow node. It reacts only
// to lambda-component callbacks whose RunInfo.Name equals nodeKey and converts
// them into node-level Events sent on ch.
type NodeHandler struct {
	// nodeKey identifies the node; callbacks carrying any other name are ignored.
	nodeKey vo.NodeKey
	// nodeName is the display name passed to PrepareNodeExeCtx.
	nodeName string
	// ch receives every Event this handler emits.
	ch chan<- *Event
	// resumePath is a private copy of resumeEvent.NodePath (empty when not resuming).
	resumePath []string
	// resumeEvent is the interrupt event being resumed, or nil for a fresh run.
	resumeEvent *entity.InterruptEvent
	// terminatePlan is the Exit node's terminate plan; where it is inspected, nil
	// is treated as vo.ReturnVariables.
	terminatePlan *vo.TerminatePlan
}
// WorkflowHandler is the callbacks.Handler for one workflow run: either the root
// workflow (subWorkflowBasic == nil) or a sub workflow. It reacts only to
// workflow-component callbacks whose RunInfo.Name is the root or sub workflow ID
// and converts them into workflow-level Events sent on ch.
type WorkflowHandler struct {
	// ch receives every Event this handler emits.
	ch chan<- *Event
	// rootWorkflowBasic describes the root workflow of the run.
	rootWorkflowBasic *entity.WorkflowBasic
	// rootExecuteID is the root execute ID; the root handler uses it to read the
	// persisted workflow cancel flag.
	rootExecuteID int64
	// subWorkflowBasic is nil for the root handler and describes the sub
	// workflow otherwise.
	subWorkflowBasic *entity.WorkflowBasic
	// nodeCount is reported on the WorkflowStart event.
	nodeCount int32
	// requireCheckpoint is forwarded to PrepareSubExeCtx for sub workflows.
	requireCheckpoint bool
	// resumeEvent is the interrupt event being resumed, or nil for a fresh run.
	resumeEvent *entity.InterruptEvent
	// exeCfg is the execute config given to NewRootWorkflowHandler. It is not read
	// directly in the code visible here; presumably consumed by PrepareRootExeCtx /
	// restoreWorkflowCtx, which receive the whole handler — TODO confirm.
	exeCfg workflowModel.ExecuteConfig
	// rootTokenCollector is not referenced in the code visible here.
	// NOTE(review): confirm where it is populated and read.
	rootTokenCollector *TokenCollector
}
// ToolHandler reports calls to one tool (matched by info.Name against the
// callback RunInfo.Name) as tool-related Events on ch.
type ToolHandler struct {
	// ch receives every Event this handler emits.
	ch chan<- *Event
	// info describes the tracked function; it is copied into every emitted event.
	info entity.FunctionInfo
}
// NewRootWorkflowHandler builds the callbacks.Handler for a top-level workflow run.
//
// resumedEvent is nil for a fresh run and non-nil when the run resumes from an
// interrupt. nodeCount is reported on the WorkflowStart event. The returned
// handler has no subWorkflowBasic, which marks it as the root handler.
func NewRootWorkflowHandler(wb *entity.WorkflowBasic, executeID int64, requireCheckpoint bool,
	ch chan<- *Event, resumedEvent *entity.InterruptEvent, exeCfg workflowModel.ExecuteConfig, nodeCount int32,
) callbacks.Handler {
	h := &WorkflowHandler{
		ch:        ch,
		exeCfg:    exeCfg,
		nodeCount: nodeCount,
	}
	h.rootWorkflowBasic = wb
	h.rootExecuteID = executeID
	h.requireCheckpoint = requireCheckpoint
	h.resumeEvent = resumedEvent
	return h
}
// NewSubWorkflowHandler derives the callbacks.Handler for a sub workflow from its
// parent's handler.
//
// The event channel, root workflow, root execute ID and checkpoint requirement
// are inherited from parent. exeCfg and rootTokenCollector are NOT copied and
// stay at their zero values on the sub handler.
func NewSubWorkflowHandler(parent *WorkflowHandler, subWB *entity.WorkflowBasic,
	resumedEvent *entity.InterruptEvent, nodeCount int32,
) callbacks.Handler {
	sub := &WorkflowHandler{
		subWorkflowBasic: subWB,
		resumeEvent:      resumedEvent,
		nodeCount:        nodeCount,
	}
	// inherited from the parent run
	sub.ch = parent.ch
	sub.rootWorkflowBasic = parent.rootWorkflowBasic
	sub.rootExecuteID = parent.rootExecuteID
	sub.requireCheckpoint = parent.requireCheckpoint
	return sub
}
// getRootWorkflowID returns the root workflow's ID, or 0 when it is not set.
func (w *WorkflowHandler) getRootWorkflowID() int64 {
	if w.rootWorkflowBasic == nil {
		return 0
	}
	return w.rootWorkflowBasic.ID
}
// getSubWorkflowID returns the sub workflow's ID, or 0 for the root handler.
func (w *WorkflowHandler) getSubWorkflowID() int64 {
	if w.subWorkflowBasic == nil {
		return 0
	}
	return w.subWorkflowBasic.ID
}
// NewNodeHandler creates the callbacks.Handler for the node identified by key.
//
// When resumeEvent is non-nil its NodePath is cloned, so the handler owns a copy
// that later mutation of the event cannot affect.
func NewNodeHandler(key string, name string, ch chan<- *Event, resumeEvent *entity.InterruptEvent, plan *vo.TerminatePlan) callbacks.Handler {
	h := &NodeHandler{
		nodeKey:       vo.NodeKey(key),
		nodeName:      name,
		ch:            ch,
		resumeEvent:   resumeEvent,
		terminatePlan: plan,
	}
	if resumeEvent != nil {
		h.resumePath = slices.Clone(resumeEvent.NodePath)
	}
	return h
}
// NewToolHandler returns a tool-component callbacks.Handler that reports the
// start, end (plain or streamed) and error of calls to the tool described by
// info as Events on ch.
func NewToolHandler(ch chan<- *Event, info entity.FunctionInfo) callbacks.Handler {
	th := &ToolHandler{ch: ch, info: info}

	toolCallbacks := &callbacks2.ToolCallbackHandler{
		OnStart:               th.OnStart,
		OnEnd:                 th.OnEnd,
		OnEndWithStreamOutput: th.OnEndWithStreamOutput,
		OnError:               th.OnError,
	}
	return callbacks2.NewHandlerHelper().Tool(toolCallbacks).Handler()
}
// initWorkflowCtx prepares or restores the execute context for the workflow this
// handler owns, and reports whether the workflow is being resumed.
//
//   - Root workflow: restored when a resume event is present, otherwise freshly
//     prepared.
//   - Sub workflow: restored only when its parent node's NodePath is a prefix of
//     the resume event's NodePath; otherwise freshly prepared.
//
// On failure the error is logged and the original ctx is returned with
// resume=false. It panics if a resuming sub workflow has no execute context or
// no parent node context.
func (w *WorkflowHandler) initWorkflowCtx(ctx context.Context) (context.Context, bool) {
	var (
		newCtx context.Context
		err    error
	)

	if w.subWorkflowBasic == nil { // root workflow
		if w.resumeEvent != nil {
			if newCtx, err = restoreWorkflowCtx(ctx, w); err != nil {
				logs.Errorf("failed to restore root execute context: %v", err)
				return ctx, false
			}
			return newCtx, true
		}
		if newCtx, err = PrepareRootExeCtx(ctx, w); err != nil {
			logs.Errorf("failed to prepare root exe context: %v", err)
			return ctx, false
		}
		return newCtx, false
	}

	// Sub workflow: it is on the resume path iff the parent node's path is a
	// prefix of the resume event's path.
	resume := false
	if w.resumeEvent != nil {
		c := GetExeCtx(ctx)
		if c == nil {
			panic("nil execute context")
		}
		if c.NodeCtx == nil {
			panic("sub workflow exe ctx must under a parent node ctx")
		}
		path, resumePath := c.NodeCtx.NodePath, w.resumeEvent.NodePath
		resume = len(path) <= len(resumePath) && slices.Equal(path, resumePath[:len(path)])
	}

	if resume {
		if newCtx, err = restoreWorkflowCtx(ctx, w); err != nil {
			logs.Errorf("failed to restore sub execute context: %v", err)
			return ctx, false
		}
		return newCtx, true
	}

	if newCtx, err = PrepareSubExeCtx(ctx, w.subWorkflowBasic, w.requireCheckpoint); err != nil {
		// previously logged as "root" — this branch prepares a sub workflow context
		logs.Errorf("failed to prepare sub exe context: %v", err)
		return ctx, false
	}
	return newCtx, false
}
// OnStart initializes the execute context of the workflow this handler owns and
// emits WorkflowStart (fresh run, with the input map and node count) or
// WorkflowResume (resumed run).
//
// For the root workflow it first reads the persisted cancel flag. If the run is
// already canceled, it returns an already-canceled context and emits nothing.
// A failure to read the flag is only logged.
func (w *WorkflowHandler) OnStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
	ownsRun := info.Component == compose.ComponentOfWorkflow &&
		(info.Name == strconv.FormatInt(w.getRootWorkflowID(), 10) ||
			info.Name == strconv.FormatInt(w.getSubWorkflowID(), 10))
	if !ownsRun {
		return ctx
	}

	newCtx, resumed := w.initWorkflowCtx(ctx)

	if w.subWorkflowBasic == nil {
		// check if already canceled
		canceled, err := workflow.GetRepository().GetWorkflowCancelFlag(newCtx, w.rootExecuteID)
		if err != nil {
			logs.Errorf("failed to get workflow cancel flag: %v", err)
		}
		if canceled {
			cancelCtx, cancelFn := context.WithCancel(newCtx)
			cancelFn()
			return cancelCtx
		}
	}

	exeCtx := GetExeCtx(newCtx)
	if resumed {
		w.ch <- &Event{Type: WorkflowResume, Context: exeCtx}
		return newCtx
	}

	w.ch <- &Event{
		Type:      WorkflowStart,
		Context:   exeCtx,
		Input:     input.(map[string]any),
		nodeCount: w.nodeCount,
	}
	return newCtx
}
// OnEnd emits WorkflowSuccess for the workflow this handler owns.
//
// The output map is used as both Output and RawOutput. When the execute context
// has a TokenCollector, OnEnd waits for it and attaches the token usage.
func (w *WorkflowHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
	ownsRun := info.Component == compose.ComponentOfWorkflow &&
		(info.Name == strconv.FormatInt(w.getRootWorkflowID(), 10) ||
			info.Name == strconv.FormatInt(w.getSubWorkflowID(), 10))
	if !ownsRun {
		return ctx
	}

	exeCtx := GetExeCtx(ctx)
	outputMap := output.(map[string]any)

	event := &Event{
		Type:      WorkflowSuccess,
		Context:   exeCtx,
		Output:    outputMap,
		RawOutput: outputMap,
		Duration:  time.Since(time.UnixMilli(exeCtx.StartTime)),
	}

	if collector := exeCtx.TokenCollector; collector != nil {
		usage := collector.wait()
		event.Token = &TokenInfo{
			InputToken:  int64(usage.PromptTokens),
			OutputToken: int64(usage.CompletionTokens),
			TotalToken:  int64(usage.TotalTokens),
		}
	}

	w.ch <- event
	return ctx
}
// InterruptEventIndexPrefix prefixes the index segment that extractInterruptEvents
// inserts into an interrupt event's NodePath for events raised inside an indexed
// nested interrupt. initNodeCtx builds the same kind of segment from
// BatchInfo.Index when matching resume paths.
const InterruptEventIndexPrefix = "interrupt_event_index_"

// extractInterruptEvents flattens interruptInfo into its leaf InterruptEvents.
// Each event's NodePath is set to prefixes followed by the graph keys, node keys
// and (for nested interrupts) index segments leading to it.
//
// Each rerun node's event is looked up in the nodes.InterruptEventStore carried in
// interruptInfo.State, falling back to RerunNodesExtra. Nodes whose event cannot
// be obtained are logged and skipped. An error is returned if the state at this
// or any nested level is not an InterruptEventStore.
func extractInterruptEvents(interruptInfo *compose.InterruptInfo, prefixes ...string) (interruptEvents []*entity.InterruptEvent, err error) {
	ieStore, ok := interruptInfo.State.(nodes.InterruptEventStore)
	if !ok {
		return nil, errors.New("failed to extract interrupt event store from interrupt info")
	}

	// withPrefix returns prefixes followed by elems in a freshly allocated slice.
	// A plain append(prefixes, ...) may write into prefixes' spare capacity (the
	// variadic slice is shared with the caller). Sibling paths, including
	// NodePaths already stored on earlier events, would then overwrite each other.
	withPrefix := func(elems ...string) []string {
		return append(slices.Clone(prefixes), elems...)
	}

	for _, nodeKey := range interruptInfo.RerunNodes {
		interruptE, found, getErr := ieStore.GetInterruptEvent(vo.NodeKey(nodeKey))
		if getErr != nil {
			logs.Errorf("failed to extract interrupt event from node key %s: %v", nodeKey, getErr)
			continue
		}
		if !found {
			// e.g. tool interrupts carry their event in RerunNodesExtra
			extra := interruptInfo.RerunNodesExtra[nodeKey]
			if extra == nil {
				continue
			}
			var isEvent bool
			interruptE, isEvent = extra.(*entity.InterruptEvent)
			if !isEvent {
				logs.Errorf("failed to extract tool interrupt event from node key %s: unexpected extra type %T", nodeKey, extra)
				continue
			}
		}

		switch {
		case len(interruptE.NestedInterruptInfo) > 0:
			// one nested interrupt per index, e.g. inside a batch / loop node
			for index, nested := range interruptE.NestedInterruptInfo {
				nestedEvents, nestedErr := extractInterruptEvents(nested,
					withPrefix(string(interruptE.NodeKey), InterruptEventIndexPrefix+strconv.Itoa(index))...)
				if nestedErr != nil {
					return nil, nestedErr
				}
				interruptEvents = append(interruptEvents, nestedEvents...)
			}
		case interruptE.SubWorkflowInterruptInfo != nil:
			subEvents, subErr := extractInterruptEvents(interruptE.SubWorkflowInterruptInfo,
				withPrefix(string(interruptE.NodeKey))...)
			if subErr != nil {
				return nil, subErr
			}
			interruptEvents = append(interruptEvents, subEvents...)
		default:
			// leaf: the node itself interrupted
			interruptE.NodePath = withPrefix(string(interruptE.NodeKey))
			interruptEvents = append(interruptEvents, interruptE)
		}
	}

	for graphKey, subGraphInfo := range interruptInfo.SubGraphs {
		subInterruptEvents, subErr := extractInterruptEvents(subGraphInfo, withPrefix(graphKey)...)
		if subErr != nil {
			return nil, subErr
		}
		interruptEvents = append(interruptEvents, subInterruptEvents...)
	}

	return interruptEvents, nil
}
// OnError turns an error of the workflow this handler owns into a terminal Event.
//
//   - Interrupt: only the root handler reacts. It collects the InterruptEvents
//     (via extractInterruptEvents) and waits for token collection to finish. It
//     then emits WorkflowInterrupt and blocks until the event's done channel is
//     signaled.
//   - context.Canceled: emits WorkflowCancel with duration and token usage.
//   - Anything else: logs the error and emits WorkflowFailed with duration,
//     token usage and the error.
func (w *WorkflowHandler) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
	ownsRun := info.Component == compose.ComponentOfWorkflow &&
		(info.Name == strconv.FormatInt(w.getRootWorkflowID(), 10) ||
			info.Name == strconv.FormatInt(w.getSubWorkflowID(), 10))
	if !ownsRun {
		return ctx
	}

	c := GetExeCtx(ctx)

	if interruptInfo, isInterrupt := compose.ExtractInterruptInfo(err); isInterrupt {
		if w.subWorkflowBasic != nil { // interrupts are reported by the root handler only
			return ctx
		}

		events, extractErr := extractInterruptEvents(interruptInfo)
		if extractErr != nil {
			logs.Errorf("failed to extract interrupt events: %v", extractErr)
			return ctx
		}
		for _, ie := range events {
			logs.CtxInfof(ctx, "emit interrupt event id= %d, eventType= %d, nodeID= %s", ie.ID,
				ie.EventType, ie.NodeKey)
		}

		if c.TokenCollector != nil { // wait until all streaming chunks are collected
			_ = c.TokenCollector.wait()
		}

		done := make(chan struct{})
		w.ch <- &Event{
			Type:            WorkflowInterrupt,
			Context:         c,
			InterruptEvents: events,
			done:            done,
		}
		<-done
		return ctx
	}

	var event *Event
	if errors.Is(err, context.Canceled) {
		event = &Event{
			Type:     WorkflowCancel,
			Context:  c,
			Duration: time.Since(time.UnixMilli(c.StartTime)),
		}
	} else {
		logs.CtxErrorf(ctx, "workflow failed: %v", err)
		event = &Event{
			Type:     WorkflowFailed,
			Context:  c,
			Duration: time.Since(time.UnixMilli(c.StartTime)),
			Err:      err,
		}
	}

	if c.TokenCollector != nil {
		usage := c.TokenCollector.wait()
		event.Token = &TokenInfo{
			InputToken:  int64(usage.PromptTokens),
			OutputToken: int64(usage.CompletionTokens),
			TotalToken:  int64(usage.TotalTokens),
		}
	}

	w.ch <- event
	return ctx
}
// OnStartWithStreamInput is the streaming-input counterpart of OnStart.
//
// Like OnStart, it initializes the execute context and honours the root cancel
// flag. For a fresh run it drains the input stream synchronously and merges the
// chunks into one map before emitting WorkflowStart. The stream is always closed.
// If a receive or merge fails, the error is logged and no start event is emitted.
func (w *WorkflowHandler) OnStartWithStreamInput(ctx context.Context, info *callbacks.RunInfo,
	input *schema.StreamReader[callbacks.CallbackInput],
) context.Context {
	ownsRun := info.Component == compose.ComponentOfWorkflow &&
		(info.Name == strconv.FormatInt(w.getRootWorkflowID(), 10) ||
			info.Name == strconv.FormatInt(w.getSubWorkflowID(), 10))
	if !ownsRun {
		input.Close()
		return ctx
	}

	newCtx, resumed := w.initWorkflowCtx(ctx)

	if w.subWorkflowBasic == nil {
		// check if already canceled
		canceled, err := workflow.GetRepository().GetWorkflowCancelFlag(newCtx, w.rootExecuteID)
		if err != nil {
			logs.Errorf("failed to get workflow cancel flag: %v", err)
		}
		if canceled {
			input.Close()
			cancelCtx, cancelFn := context.WithCancel(newCtx)
			cancelFn()
			return cancelCtx
		}
	}

	if resumed {
		input.Close()
		w.ch <- &Event{Type: WorkflowResume, Context: GetExeCtx(newCtx)}
		return newCtx
	}

	// consumes the stream synchronously because a workflow can only have Invoke or Stream.
	defer input.Close()

	merged := make(map[string]any)
	for {
		chunk, recvErr := input.Recv()
		if recvErr == io.EOF {
			break
		}
		if recvErr != nil {
			logs.Errorf("failed to receive stream input: %v", recvErr)
			return newCtx
		}

		var concatErr error
		if merged, concatErr = nodes.ConcatTwoMaps(merged, chunk.(map[string]any)); concatErr != nil {
			logs.Errorf("failed to concat two maps: %v", concatErr)
			return newCtx
		}
	}

	w.ch <- &Event{
		Type:      WorkflowStart,
		Context:   GetExeCtx(newCtx),
		Input:     merged,
		nodeCount: w.nodeCount,
	}
	return newCtx
}
// OnEndWithStreamOutput handles the streamed output of the workflow this handler
// owns.
//
// The stream is drained on a goroutine and its chunks are merged into one map.
// When the stream ends, WorkflowSuccess is emitted with the merged map as both
// Output and RawOutput, matching OnEnd. Token usage is attached when a
// TokenCollector is present.
//
// Receive errors for which schema.GetSourceName reports a source are skipped.
// Any other receive error, and any merge error, is passed to OnError, so the
// failure surfaces as an event instead of ending the goroutine silently.
func (w *WorkflowHandler) OnEndWithStreamOutput(ctx context.Context, info *callbacks.RunInfo,
	output *schema.StreamReader[callbacks.CallbackOutput],
) context.Context {
	if info.Component != compose.ComponentOfWorkflow || (info.Name != strconv.FormatInt(w.getRootWorkflowID(), 10) &&
		info.Name != strconv.FormatInt(w.getSubWorkflowID(), 10)) {
		output.Close()
		return ctx
	}

	safego.Go(ctx, func() {
		defer output.Close()

		fullOutput := make(map[string]any)
		for {
			chunk, e := output.Recv()
			if e != nil {
				if e == io.EOF {
					break
				}
				if _, ok := schema.GetSourceName(e); ok {
					continue
				}
				logs.Errorf("workflow OnEndWithStreamOutput failed to receive stream output: %v", e)
				_ = w.OnError(ctx, info, e)
				return
			}

			fullOutput, e = nodes.ConcatTwoMaps(fullOutput, chunk.(map[string]any))
			if e != nil {
				logs.Errorf("failed to concat two maps: %v", e)
				// report like the receive-error path; previously the run got no terminal event
				_ = w.OnError(ctx, info, e)
				return
			}
		}

		c := GetExeCtx(ctx)
		e := &Event{
			Type:      WorkflowSuccess,
			Context:   c,
			Duration:  time.Since(time.UnixMilli(c.StartTime)),
			Output:    fullOutput,
			RawOutput: fullOutput,
		}
		if c.TokenCollector != nil {
			usage := c.TokenCollector.wait()
			e.Token = &TokenInfo{
				InputToken:  int64(usage.PromptTokens),
				OutputToken: int64(usage.CompletionTokens),
				TotalToken:  int64(usage.TotalTokens),
			}
		}
		w.ch <- e
	})

	return ctx
}
// initNodeCtx prepares or restores this node's execute context. The returned
// bool is true when the node should be treated as resumed, i.e. no NodeStart
// event should be emitted for it.
//
// The node is "on the resume path" when its full path is a prefix of the resume
// event's NodePath. Its full path is the parent node path, plus an index segment
// when it is the direct child of a batch's composite node, plus its own key. It
// is "exactly resuming" when the two paths are equal. Nodes on the path are
// restored via restoreNodeCtx.
//
// Other nodes may still be restorable from a checkpoint, e.g. one of several
// parallel interrupts that is not resumed this time (tryRestoreNodeCtx).
// Otherwise a fresh context is prepared. On failure the error is logged and the
// original ctx is returned.
func (n *NodeHandler) initNodeCtx(ctx context.Context, typ entity.NodeType) (context.Context, bool) {
	var onResumePath, exactlyResuming bool

	if len(n.resumePath) > 0 {
		c := GetExeCtx(ctx)
		if c == nil {
			panic("nil execute context")
		}

		if c.NodeCtx == nil { // top level node
			onResumePath = n.resumePath[0] == string(n.nodeKey)
			exactlyResuming = onResumePath && len(n.resumePath) == 1
		} else {
			path := slices.Clone(c.NodeCtx.NodePath)
			// immediate inner node under composite node
			if c.BatchInfo != nil && c.BatchInfo.CompositeNodeKey == c.NodeCtx.NodeKey {
				path = append(path, InterruptEventIndexPrefix+strconv.Itoa(c.BatchInfo.Index))
			}
			path = append(path, string(n.nodeKey))

			if len(path) <= len(n.resumePath) {
				onResumePath = slices.Equal(path, n.resumePath[:len(path)])
				exactlyResuming = onResumePath && len(path) == len(n.resumePath)
			}
		}
	}

	if !onResumePath {
		// even if this node is not on the resume path, it could still restore from checkpoint,
		// for example:
		// this workflow has parallel interrupts, this node is one of them(or along the path of one of them),
		// but not resumed this time
		if restoredCtx, restored := tryRestoreNodeCtx(ctx, n.nodeKey); restored {
			logs.CtxInfof(ctx, "[tryRestoreNodeCtx] restored, nodeKey= %s", n.nodeKey)
			return restoredCtx, true
		}

		preparedCtx, err := PrepareNodeExeCtx(ctx, n.nodeKey, n.nodeName, typ, n.terminatePlan)
		if err != nil {
			logs.Errorf("failed to prepare node execute context: %v", err)
			return ctx, false
		}
		return preparedCtx, false
	}

	restoredCtx, err := restoreNodeCtx(ctx, n.nodeKey, n.resumeEvent, exactlyResuming)
	if err != nil {
		logs.Errorf("failed to restore node execute context: %v", err)
		return ctx, true
	}

	var resumeEventID int64
	if c := GetExeCtx(restoredCtx); c != nil && c.RootCtx.ResumeEvent != nil {
		resumeEventID = c.RootCtx.ResumeEvent.ID
	}
	logs.CtxInfof(ctx, "[restoreNodeCtx] restored nodeKey= %s, root.resumeEventID= %d", n.nodeKey, resumeEventID)
	return restoredCtx, true
}
// OnStart prepares (or restores) this node's execute context. For a fresh node it
// emits NodeStart with the input map and the current (sub) execute ID. Resumed or
// checkpoint-restored nodes emit nothing here.
func (n *NodeHandler) OnStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
	if info.Component != compose.ComponentOfLambda || info.Name != string(n.nodeKey) {
		return ctx
	}

	newCtx, resumed := n.initNodeCtx(ctx, entity.NodeType(info.Type))
	if resumed {
		return newCtx
	}

	c := GetExeCtx(newCtx)
	if c == nil {
		panic("nil node context")
	}

	event := &Event{
		Type:    NodeStart,
		Context: c,
		Input:   input.(map[string]any),
		extra:   &entity.NodeExtra{},
	}
	// SubExecuteID is only read when a sub workflow context exists.
	if c.SubWorkflowCtx != nil {
		event.extra.CurrentSubExecuteID = c.SubExecuteID
	} else {
		event.extra.CurrentSubExecuteID = c.RootExecuteID
	}

	n.ch <- event
	return newCtx
}
// OnEnd emits NodeEnd for this node.
//
// output is either a plain map[string]any or a *nodes.StructuredCallbackOutput;
// any other type is ignored.
//   - A plain map is used as both Output and RawOutput.
//   - A StructuredCallbackOutput carries separate Output/RawOutput maps, an Extra
//     map (copied into the event's ResponseExtra) and an Error (the event's Err).
//
// Token usage is attached for node types whose metadata says they may use a chat
// model. OutputEmitter nodes, and Exit nodes whose plan is vo.UseAnswerContent,
// also carry the "output" string as Answer and an output extractor. Such Exit
// nodes additionally mark ResponseExtra["terminal_plan"]. InputReceiver nodes
// report their output as the event's Input.
func (n *NodeHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
	if info.Component != compose.ComponentOfLambda || info.Name != string(n.nodeKey) {
		return ctx
	}

	var (
		outputMap, rawOutputMap, customExtra map[string]any
		errInfo                              vo.WorkflowError
	)
	switch o := output.(type) {
	case map[string]any:
		outputMap, rawOutputMap = o, o
	case *nodes.StructuredCallbackOutput:
		outputMap = o.Output
		rawOutputMap = o.RawOutput
		customExtra = o.Extra
		errInfo = o.Error
	default:
		return ctx
	}

	// extractOutput renders the "output" field as text.
	extractOutput := func(o map[string]any) string {
		if str, ok := o["output"].(string); ok {
			return str
		}
		return fmt.Sprint(o["output"])
	}

	c := GetExeCtx(ctx)
	e := &Event{
		Type:      NodeEnd,
		Context:   c,
		Duration:  time.Since(time.UnixMilli(c.StartTime)),
		Output:    outputMap,
		RawOutput: rawOutputMap,
		Err:       errInfo,
		extra:     &entity.NodeExtra{},
	}

	if c.TokenCollector != nil && entity.NodeMetaByNodeType(c.NodeType).MayUseChatModel {
		usage := c.TokenCollector.wait()
		e.Token = &TokenInfo{
			InputToken:  int64(usage.PromptTokens),
			OutputToken: int64(usage.CompletionTokens),
			TotalToken:  int64(usage.TotalTokens),
		}
	}

	// Read the answer from outputMap: re-asserting output as a map panicked for
	// StructuredCallbackOutput, and c.TerminatePlan may be nil.
	answerNode := c.NodeType == entity.NodeTypeOutputEmitter ||
		(c.NodeType == entity.NodeTypeExit && c.TerminatePlan != nil && *c.TerminatePlan == vo.UseAnswerContent)
	if answerNode {
		if answer, ok := outputMap["output"].(string); ok {
			e.Answer = answer
		}
	}

	if len(customExtra) > 0 {
		e.extra.ResponseExtra = make(map[string]any, len(customExtra))
		for k, v := range customExtra {
			e.extra.ResponseExtra[k] = v
		}
	}

	if c.SubWorkflowCtx == nil {
		e.extra.CurrentSubExecuteID = c.RootExecuteID
	} else {
		e.extra.CurrentSubExecuteID = c.SubExecuteID
	}

	switch entity.NodeType(info.Type) {
	case entity.NodeTypeExit:
		// a nil plan means vo.ReturnVariables
		if n.terminatePlan != nil && *n.terminatePlan == vo.UseAnswerContent {
			// Add the marker without discarding CurrentSubExecuteID, as the
			// streaming Exit path does.
			if e.extra.ResponseExtra == nil {
				e.extra.ResponseExtra = map[string]any{}
			}
			e.extra.ResponseExtra["terminal_plan"] = workflow2.TerminatePlanType_USESETTING
			e.outputExtractor = extractOutput
		}
	case entity.NodeTypeOutputEmitter:
		e.outputExtractor = extractOutput
	case entity.NodeTypeInputReceiver:
		e.Input = outputMap
	default:
	}

	n.ch <- e
	return ctx
}
// OnError handles an error raised by this node.
//
// An interrupt-rerun error emits no event. Instead the node's execute context is
// saved into the graph state (ExeContextStore) so it can be restored on resume;
// a failure to save is only logged. A context.Canceled error is reported as
// NodeError only when a node execute context exists. Every other error is
// reported as NodeError with duration and token usage.
func (n *NodeHandler) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
	if info.Component != compose.ComponentOfLambda || info.Name != string(n.nodeKey) {
		return ctx
	}

	c := GetExeCtx(ctx)

	if _, interrupted := compose.IsInterruptRerunError(err); interrupted { // current node interrupts
		saveErr := compose.ProcessState[ExeContextStore](ctx, func(ctx context.Context, state ExeContextStore) error {
			if state == nil {
				return errors.New("state is nil")
			}
			logs.CtxInfof(ctx, "[SetNodeCtx] nodeKey= %s", n.nodeKey)
			return state.SetNodeCtx(n.nodeKey, c)
		})
		if saveErr != nil {
			logs.Errorf("failed to process state: %v", saveErr)
		}
		return ctx
	}

	if errors.Is(err, context.Canceled) && (c == nil || c.NodeCtx == nil) {
		return ctx
	}

	event := &Event{
		Type:     NodeError,
		Context:  c,
		Duration: time.Since(time.UnixMilli(c.StartTime)),
		Err:      err,
	}
	if c.TokenCollector != nil {
		usage := c.TokenCollector.wait()
		event.Token = &TokenInfo{
			InputToken:  int64(usage.PromptTokens),
			OutputToken: int64(usage.CompletionTokens),
			TotalToken:  int64(usage.TotalTokens),
		}
	}

	n.ch <- event
	return ctx
}
// OnStartWithStreamInput handles a node whose input arrives as a stream. Only
// Exit, OutputEmitter and VariableAggregator nodes are expected here; any other
// node type panics.
//
// If the node is resumed or restored, the stream is closed and nothing is
// emitted. Otherwise NodeStart is emitted immediately, without Input (unlike
// OnStart) because the input is still streaming, and the stream is drained on a
// goroutine. The merged input is emitted as a final NodeStreamingInput event.
// VariableAggregator nodes also get an intermediate NodeStreamingInput event
// whenever the merged input changes.
func (n *NodeHandler) OnStartWithStreamInput(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context {
	if info.Component != compose.ComponentOfLambda || info.Name != string(n.nodeKey) {
		input.Close()
		return ctx
	}
	// currently Exit, OutputEmitter can potentially trigger this.
	// VariableAggregator can also potentially trigger this.
	// we may receive nodes.KeyIsFinished from the stream, which should be discarded when concatenating the map.
	if info.Type != string(entity.NodeTypeExit) &&
		info.Type != string(entity.NodeTypeOutputEmitter) &&
		info.Type != string(entity.NodeTypeVariableAggregator) {
		panic(fmt.Sprintf("impossible, node type= %s", info.Type))
	}
	newCtx, resumed := n.initNodeCtx(ctx, entity.NodeType(info.Type))
	if resumed {
		input.Close()
		return newCtx
	}
	c := GetExeCtx(newCtx)
	if c == nil {
		panic("nil node context")
	}
	// NodeStart carries no Input and no CurrentSubExecuteID here.
	e := &Event{
		Type:    NodeStart,
		Context: c,
	}
	if entity.NodeType(info.Type) == entity.NodeTypeExit {
		// a nil plan means vo.ReturnVariables
		terminatePlan := n.terminatePlan
		if terminatePlan == nil {
			terminatePlan = ptr.Of(vo.ReturnVariables)
		}
		if *terminatePlan == vo.UseAnswerContent {
			e.extra = &entity.NodeExtra{
				ResponseExtra: map[string]any{
					"terminal_plan": workflow2.TerminatePlanType_USESETTING,
				},
			}
		}
	}
	n.ch <- e
	safego.Go(ctx, func() {
		defer input.Close()
		fullInput := make(map[string]any)
		var previous map[string]any
		for {
			chunk, e := input.Recv()
			if e != nil {
				if e == io.EOF {
					break
				}
				// errors for which schema.GetSourceName reports a source are skipped
				if _, ok := schema.GetSourceName(e); ok {
					continue
				}
				logs.Errorf("node OnStartWithStreamInput failed to receive stream output: %v", e)
				// newCtx (not ctx) carries this node's execute context for OnError
				_ = n.OnError(newCtx, info, e)
				return
			}
			// NOTE(review): the change detection below assumes ConcatTwoMaps returns a
			// new map rather than mutating fullInput in place (otherwise previous and
			// fullInput alias and never differ) — confirm.
			previous = fullInput
			fullInput, e = nodes.ConcatTwoMaps(fullInput, chunk.(map[string]any))
			if e != nil {
				// NOTE(review): unlike the receive-error path, a merge error ends the
				// goroutine without calling OnError or emitting any event — confirm intended.
				logs.Errorf("failed to concat two maps: %v", e)
				return
			}
			if info.Type == string(entity.NodeTypeVariableAggregator) {
				if !reflect.DeepEqual(fullInput, previous) {
					n.ch <- &Event{
						Type:    NodeStreamingInput,
						Context: c,
						Input:   fullInput,
					}
				}
			}
		}
		// final merged input, emitted for every supported node type
		n.ch <- &Event{
			Type:    NodeStreamingInput,
			Context: c,
			Input:   fullInput,
		}
	})
	return newCtx
}
// OnEndWithStreamOutput handles a node's streamed output. Consumption depends on
// the node type (info.Type):
//
//   - LLM: drained on a goroutine. The Output/RawOutput of the
//     *nodes.StructuredCallbackOutput chunks are merged and a single
//     NodeEndStreaming event is emitted at the end. It carries token usage, the
//     last non-nil chunk Error as Err, and the reasoning content (when present)
//     in ResponseExtra.
//   - VariableAggregator: drained on a goroutine. A NodeStreamingOutput event is
//     emitted whenever the merged output changes, followed by NodeEndStreaming.
//   - Exit, OutputEmitter, SubWorkflow: map chunks are merged and relayed as
//     NodeStreamingOutput events, followed by NodeEndStreaming. The last
//     streaming event is flagged StreamEnd. Exit nodes are consumed
//     asynchronously; OutputEmitter and SubWorkflow nodes are consumed before
//     returning.
//
// Any other node type panics. Receive and merge errors are reported via OnError.
func (n *NodeHandler) OnEndWithStreamOutput(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context {
	if info.Component != compose.ComponentOfLambda || info.Name != string(n.nodeKey) {
		output.Close()
		return ctx
	}
	c := GetExeCtx(ctx)
	switch t := entity.NodeType(info.Type); t {
	case entity.NodeTypeLLM:
		safego.Go(ctx, func() {
			defer output.Close()
			fullOutput := make(map[string]any)
			fullRawOutput := make(map[string]any)
			var warning error // last non-nil Error reported by a chunk
			for {
				chunk, e := output.Recv()
				if e != nil {
					if e == io.EOF {
						break
					}
					logs.Errorf("node OnEndWithStreamOutput failed to receive stream output: %v", e)
					_ = n.OnError(ctx, info, e)
					return
				}
				so := chunk.(*nodes.StructuredCallbackOutput)
				fullOutput, e = nodes.ConcatTwoMaps(fullOutput, so.Output)
				if e != nil {
					logs.Errorf("failed to concat two maps: %v", e)
					_ = n.OnError(ctx, info, e)
					return
				}
				fullRawOutput, e = nodes.ConcatTwoMaps(fullRawOutput, so.RawOutput)
				if e != nil {
					logs.Errorf("failed to concat two maps: %v", e)
					_ = n.OnError(ctx, info, e)
					return
				}
				if so.Error != nil {
					warning = so.Error
				}
			}
			e := &Event{
				Type:      NodeEndStreaming,
				Context:   c,
				Output:    fullOutput,
				RawOutput: fullRawOutput,
				Duration:  time.Since(time.UnixMilli(c.StartTime)),
				Err:       warning,
				extra:     &entity.NodeExtra{},
			}
			if c.TokenCollector != nil {
				usage := c.TokenCollector.wait()
				e.Token = &TokenInfo{
					InputToken:  int64(usage.PromptTokens),
					OutputToken: int64(usage.CompletionTokens),
					TotalToken:  int64(usage.TotalTokens),
				}
			}
			if c.SubWorkflowCtx == nil {
				e.extra.CurrentSubExecuteID = c.RootExecuteID
			} else {
				e.extra.CurrentSubExecuteID = c.SubExecuteID
			}
			// extractLLMOutput renders "output" as text: nil yields "", and a
			// non-string value is formatted instead of panicking on a type assertion.
			extractLLMOutput := func(o map[string]any) string {
				switch v := o["output"].(type) {
				case nil:
					return ""
				case string:
					return v
				default:
					return fmt.Sprint(v)
				}
			}
			// TODO: hard-coded string
			// Text extraction applies only when the output is exactly {"output"} or
			// {"output", "reasoning_content"}.
			if _, ok := fullOutput["output"]; ok {
				if len(fullOutput) == 1 {
					e.outputExtractor = extractLLMOutput
				} else if len(fullOutput) == 2 {
					if reasoning, ok := fullOutput["reasoning_content"]; ok {
						e.outputExtractor = extractLLMOutput
						if reasoningText, isText := reasoning.(string); isText {
							e.extra.ResponseExtra = map[string]any{
								"reasoning_content": reasoningText,
							}
						}
					}
				}
			}
			n.ch <- e
		})
	case entity.NodeTypeVariableAggregator:
		safego.Go(ctx, func() {
			defer output.Close()
			extra := &entity.NodeExtra{}
			if c.SubWorkflowCtx == nil {
				extra.CurrentSubExecuteID = c.RootExecuteID
			} else {
				extra.CurrentSubExecuteID = c.SubExecuteID
			}
			extra.ResponseExtra = make(map[string]any)
			fullOutput := &nodes.StructuredCallbackOutput{
				Output:    make(map[string]any),
				RawOutput: make(map[string]any),
			}
			var (
				previous *nodes.StructuredCallbackOutput
				first    = true
			)
			for {
				chunk, e := output.Recv()
				if e != nil {
					if e == io.EOF {
						break
					}
					logs.Errorf("node OnEndWithStreamOutput failed to receive stream output: %v", e)
					_ = n.OnError(ctx, info, e)
					return
				}
				previous = fullOutput
				fullOutputMap, e := nodes.ConcatTwoMaps(fullOutput.Output, chunk.(*nodes.StructuredCallbackOutput).Output)
				if e != nil {
					logs.Errorf("failed to concat two maps: %v", e)
					_ = n.OnError(ctx, info, e)
					return
				}
				fullRawOutput, e := nodes.ConcatTwoMaps(fullOutput.RawOutput, chunk.(*nodes.StructuredCallbackOutput).RawOutput)
				if e != nil {
					logs.Errorf("failed to concat two maps: %v", e)
					_ = n.OnError(ctx, info, e)
					return
				}
				// the Extra of the chunk that produces the first emitted delta is attached to that delta
				if first {
					extra.ResponseExtra = chunk.(*nodes.StructuredCallbackOutput).Extra
				}
				fullOutput = &nodes.StructuredCallbackOutput{
					Output:    fullOutputMap,
					RawOutput: fullRawOutput,
				}
				if !reflect.DeepEqual(fullOutput, previous) {
					deltaEvent := &Event{
						Type:      NodeStreamingOutput,
						Context:   c,
						Output:    fullOutput.Output,
						RawOutput: fullOutput.RawOutput,
					}
					if first {
						deltaEvent.extra = extra
						first = false
					}
					n.ch <- deltaEvent
				}
			}
			e := &Event{
				Type:      NodeEndStreaming,
				Context:   c,
				Output:    fullOutput.Output,
				RawOutput: fullOutput.RawOutput,
				Duration:  time.Since(time.UnixMilli(c.StartTime)),
			}
			n.ch <- e
		})
	case entity.NodeTypeExit, entity.NodeTypeOutputEmitter, entity.NodeTypeSubWorkflow:
		// extractOutput renders the "output" field as text.
		extractOutput := func(o map[string]any) string {
			str, ok := o["output"].(string)
			if ok {
				return str
			}
			return fmt.Sprint(o["output"])
		}
		// consumer relays the stream. The first event is sent immediately; later
		// events are held back two deep (previousEvent, secondPreviousEvent) so that,
		// at EOF, the final event can be flagged StreamEnd and a trailing empty-answer
		// chunk can be folded into the one before it.
		consumer := func(ctx context.Context) context.Context {
			defer output.Close()
			fullOutput := make(map[string]any)
			var firstEvent, previousEvent, secondPreviousEvent *Event
			for {
				chunk, err := output.Recv()
				if err != nil {
					if err == io.EOF {
						if previousEvent != nil {
							previousEmpty := len(previousEvent.Answer) == 0
							if previousEmpty { // concat the empty previous chunk with the second previous chunk
								if secondPreviousEvent != nil {
									secondPreviousEvent.StreamEnd = true
									n.ch <- secondPreviousEvent
								} else {
									previousEvent.StreamEnd = true
									n.ch <- previousEvent
								}
							} else {
								if secondPreviousEvent != nil {
									n.ch <- secondPreviousEvent
								}
								previousEvent.StreamEnd = true
								n.ch <- previousEvent
							}
						} else { // only sent first event, or no event at all
							n.ch <- &Event{
								Type:      NodeStreamingOutput,
								Context:   c,
								Output:    fullOutput,
								StreamEnd: true,
							}
						}
						break
					}
					// errors for which schema.GetSourceName reports a source are skipped
					if _, ok := schema.GetSourceName(err); ok {
						continue
					}
					logs.Errorf("node OnEndWithStreamOutput failed to receive stream output: %v", err)
					return n.OnError(ctx, info, err)
				}
				if secondPreviousEvent != nil {
					n.ch <- secondPreviousEvent
				}
				fullOutput, err = nodes.ConcatTwoMaps(fullOutput, chunk.(map[string]any))
				if err != nil {
					logs.Errorf("failed to concat two maps: %v", err)
					return n.OnError(ctx, info, err)
				}
				deltaEvent := &Event{
					Type:    NodeStreamingOutput,
					Context: c,
					Output:  fullOutput,
				}
				if delta, ok := chunk.(map[string]any)["output"]; ok {
					if entity.NodeType(info.Type) == entity.NodeTypeOutputEmitter {
						deltaEvent.Answer = strings.TrimSuffix(delta.(string), nodes.KeyIsFinished)
						deltaEvent.outputExtractor = extractOutput
					} else if n.terminatePlan != nil && *n.terminatePlan == vo.UseAnswerContent {
						deltaEvent.Answer = strings.TrimSuffix(delta.(string), nodes.KeyIsFinished)
						deltaEvent.outputExtractor = extractOutput
					}
				}
				if firstEvent == nil { // prioritize sending the first event asap.
					firstEvent = deltaEvent
					n.ch <- firstEvent
				} else {
					secondPreviousEvent = previousEvent
					previousEvent = deltaEvent
				}
			}
			e := &Event{
				Type:      NodeEndStreaming,
				Context:   c,
				Output:    fullOutput,
				RawOutput: fullOutput,
				Duration:  time.Since(time.UnixMilli(c.StartTime)),
				extra:     &entity.NodeExtra{},
			}
			if answer, ok := fullOutput["output"]; ok {
				if entity.NodeType(info.Type) == entity.NodeTypeOutputEmitter {
					e.Answer = answer.(string)
					e.outputExtractor = extractOutput
				} else if n.terminatePlan != nil && *n.terminatePlan == vo.UseAnswerContent {
					e.Answer = answer.(string)
					e.outputExtractor = extractOutput
				}
			}
			if c.SubWorkflowCtx == nil {
				e.extra.CurrentSubExecuteID = c.RootExecuteID
			} else {
				e.extra.CurrentSubExecuteID = c.SubExecuteID
			}
			if t == entity.NodeTypeExit {
				terminatePlan := n.terminatePlan
				if terminatePlan == nil {
					terminatePlan = ptr.Of(vo.ReturnVariables)
				}
				if *terminatePlan == vo.UseAnswerContent {
					e.extra.ResponseExtra = map[string]any{
						"terminal_plan": workflow2.TerminatePlanType_USESETTING,
					}
				}
			}
			n.ch <- e
			return ctx
		}
		switch c.NodeType {
		case entity.NodeTypeExit:
			// handles Exit node asynchronously to keep the typewriter effect for workflow tool returning directly.
			// safego.Go, like every other goroutine in this file, instead of a bare `go`.
			safego.Go(ctx, func() {
				_ = consumer(ctx)
			})
			return ctx
		case entity.NodeTypeOutputEmitter, entity.NodeTypeSubWorkflow:
			return consumer(ctx)
		default:
			// c.NodeType disagrees with info.Type, so nothing will consume the
			// stream; close it rather than leak it.
			output.Close()
		}
	default:
		panic(fmt.Sprintf("impossible, node type= %s", info.Type))
	}
	return ctx
}
// OnStart emits a FunctionCall event for the tracked tool. The event carries the
// tool call ID and the JSON-decoded arguments (nil when ArgumentsInJSON is
// empty). If the arguments cannot be decoded, the error is logged and no event
// is sent.
func (t *ToolHandler) OnStart(ctx context.Context, info *callbacks.RunInfo,
	input *tool.CallbackInput,
) context.Context {
	if info.Name != t.info.Name {
		return ctx
	}

	var args map[string]any
	if raw := input.ArgumentsInJSON; raw != "" {
		if err := sonic.UnmarshalString(raw, &args); err != nil {
			logs.Errorf("failed to unmarshal arguments: %v", err)
			return ctx
		}
	}

	call := &entity.FunctionCallInfo{
		FunctionInfo: t.info,
		CallID:       compose.GetToolCallID(ctx),
		Arguments:    args,
	}
	t.ch <- &Event{
		Type:         FunctionCall,
		Context:      GetExeCtx(ctx),
		functionCall: call,
	}
	return ctx
}
// OnEnd reports the tracked tool's complete (non-streamed) response as a
// ToolResponse event.
func (t *ToolHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo,
	output *tool.CallbackOutput,
) context.Context {
	if info.Name != t.info.Name {
		return ctx
	}

	resp := &entity.ToolResponseInfo{
		FunctionInfo: t.info,
		CallID:       compose.GetToolCallID(ctx),
		Response:     output.Response,
		Complete:     true,
	}
	t.ch <- &Event{
		Type:         ToolResponse,
		Context:      GetExeCtx(ctx),
		toolResponse: resp,
	}
	return ctx
}
// OnEndWithStreamOutput relays the tracked tool's streamed response as
// ToolStreamingResponse events, draining the stream on a goroutine.
//
// The first chunk is sent immediately. Each later chunk is held back one step so
// the final one can be flagged StreamEnd. If nothing is held back at EOF (zero
// or one chunk), a separate empty StreamEnd event with Complete set is sent
// instead. A receive error is logged and reported through OnError.
func (t *ToolHandler) OnEndWithStreamOutput(ctx context.Context, info *callbacks.RunInfo,
	output *schema.StreamReader[*tool.CallbackOutput],
) context.Context {
	if info.Name != t.info.Name {
		output.Close()
		return ctx
	}

	safego.Go(ctx, func() {
		c := GetExeCtx(ctx)
		defer output.Close()

		// callID is constant for this call; compute it once instead of per chunk.
		// (The unused full-response accumulation, a quadratic string +=, was dropped.)
		var (
			firstEvent, pending *Event
			callID              = compose.GetToolCallID(ctx)
		)
		for {
			chunk, e := output.Recv()
			if e != nil {
				if e == io.EOF {
					if pending != nil {
						pending.StreamEnd = true
						t.ch <- pending
					} else {
						t.ch <- &Event{
							Type:      ToolStreamingResponse,
							Context:   c,
							StreamEnd: true,
							toolResponse: &entity.ToolResponseInfo{
								FunctionInfo: t.info,
								CallID:       callID,
								Complete:     true,
							},
						}
					}
					return
				}
				logs.Errorf("tool OnEndWithStreamOutput failed to receive stream output: %v", e)
				_ = t.OnError(ctx, info, e)
				return
			}

			// a new chunk arrived, so the held-back one is not the last: flush it
			if pending != nil {
				t.ch <- pending
			}
			deltaEvent := &Event{
				Type:    ToolStreamingResponse,
				Context: c,
				toolResponse: &entity.ToolResponseInfo{
					FunctionInfo: t.info,
					CallID:       callID,
					Response:     chunk.Response,
				},
			}
			if firstEvent == nil {
				firstEvent = deltaEvent
				t.ch <- firstEvent
			} else {
				pending = deltaEvent
			}
		}
	})
	return ctx
}
// OnError reports a failure of the tracked tool as a ToolError event. The event
// carries the function info, the tool call ID and the error.
func (t *ToolHandler) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
	if info.Name != t.info.Name {
		return ctx
	}

	call := &entity.FunctionCallInfo{
		FunctionInfo: t.info,
		CallID:       compose.GetToolCallID(ctx),
	}
	t.ch <- &Event{
		Type:         ToolError,
		Context:      GetExeCtx(ctx),
		functionCall: call,
		Err:          err,
	}
	return ctx
}