fix: workflow tool closes stream writer correctly (#1839)
This commit is contained in:
@@ -48,7 +48,7 @@ type Executable interface {
|
||||
|
||||
type AsTool interface {
|
||||
WorkflowAsModelTool(ctx context.Context, policies []*vo.GetPolicy) ([]ToolFromWorkflow, error)
|
||||
WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message])
|
||||
WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message], func())
|
||||
WithExecuteConfig(cfg workflowModel.ExecuteConfig) compose.Option
|
||||
WithResumeToolWorkflow(resumingEvent *entity.ToolInterruptEvent, resumeData string,
|
||||
allInterruptEvents map[string]*entity.ToolInterruptEvent) compose.Option
|
||||
|
||||
@@ -85,6 +85,7 @@ type ToolResponseInfo struct {
|
||||
FunctionInfo
|
||||
CallID string
|
||||
Response string
|
||||
Complete bool
|
||||
}
|
||||
|
||||
type ToolType = workflow.PluginType
|
||||
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
"strconv"
|
||||
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
@@ -47,7 +46,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
|
||||
workflowSC = r.schema
|
||||
eventChan = r.eventChan
|
||||
resumedEvent = r.interruptEvent
|
||||
sw = r.streamWriter
|
||||
sw = r.container
|
||||
)
|
||||
|
||||
if wb.AppID != nil && exeCfg.AppID == nil {
|
||||
@@ -148,7 +147,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
|
||||
var (
|
||||
resumeEvent = r.interruptEvent
|
||||
eventChan = r.eventChan
|
||||
sw = r.streamWriter
|
||||
container = r.container
|
||||
)
|
||||
subHandler := execute.NewSubWorkflowHandler(
|
||||
parentHandler,
|
||||
@@ -186,7 +185,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
|
||||
opts = append(opts, WrapOpt(subO, ns.Key))
|
||||
}
|
||||
} else if subNS.Type == entity.NodeTypeLLM {
|
||||
llmNodeOpts, err := llmToolCallbackOptions(ctx, subNS, eventChan, sw)
|
||||
llmNodeOpts, err := llmToolCallbackOptions(ctx, subNS, eventChan, container)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -209,7 +208,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
|
||||
opts = append(opts, WrapOpt(WrapOpt(subO, parent.Key), ns.Key))
|
||||
}
|
||||
} else if subNS.Type == entity.NodeTypeLLM {
|
||||
llmNodeOpts, err := llmToolCallbackOptions(ctx, subNS, eventChan, sw)
|
||||
llmNodeOpts, err := llmToolCallbackOptions(ctx, subNS, eventChan, container)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -224,7 +223,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
|
||||
}
|
||||
|
||||
func llmToolCallbackOptions(ctx context.Context, ns *schema2.NodeSchema, eventChan chan *execute.Event,
|
||||
sw *schema.StreamWriter[*entity.Message]) (
|
||||
container *execute.StreamContainer) (
|
||||
opts []einoCompose.Option, err error) {
|
||||
// this is a LLM node.
|
||||
// check if it has any tools, if no tools, then no callback options needed
|
||||
@@ -280,6 +279,12 @@ func llmToolCallbackOptions(ctx context.Context, ns *schema2.NodeSchema, eventCh
|
||||
opt = einoCompose.WithLambdaOption(nodes.WithOptsForNested(opt)).DesignateNode(string(ns.Key))
|
||||
opts = append(opts, opt)
|
||||
}
|
||||
|
||||
if container != nil {
|
||||
toolMsgOpt := llm.WithToolWorkflowStreamContainer(container)
|
||||
opt := einoCompose.WithLambdaOption(toolMsgOpt).DesignateNode(string(ns.Key))
|
||||
opts = append(opts, opt)
|
||||
}
|
||||
}
|
||||
if fcParams.PluginFCParam != nil {
|
||||
for _, p := range fcParams.PluginFCParam.PluginList {
|
||||
@@ -321,11 +326,5 @@ func llmToolCallbackOptions(ctx context.Context, ns *schema2.NodeSchema, eventCh
|
||||
}
|
||||
}
|
||||
|
||||
if sw != nil {
|
||||
toolMsgOpt := llm.WithToolWorkflowMessageWriter(sw)
|
||||
opt := einoCompose.WithLambdaOption(toolMsgOpt).DesignateNode(string(ns.Key))
|
||||
opts = append(opts, opt)
|
||||
}
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
@@ -38,15 +38,17 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
type WorkflowRunner struct {
|
||||
basic *entity.WorkflowBasic
|
||||
input string
|
||||
resumeReq *entity.ResumeRequest
|
||||
schema *schema2.WorkflowSchema
|
||||
streamWriter *schema.StreamWriter[*entity.Message]
|
||||
config model.ExecuteConfig
|
||||
basic *entity.WorkflowBasic
|
||||
input string
|
||||
resumeReq *entity.ResumeRequest
|
||||
schema *schema2.WorkflowSchema
|
||||
sw *schema.StreamWriter[*entity.Message]
|
||||
container *execute.StreamContainer
|
||||
config model.ExecuteConfig
|
||||
|
||||
executeID int64
|
||||
eventChan chan *execute.Event
|
||||
@@ -84,13 +86,19 @@ func NewWorkflowRunner(b *entity.WorkflowBasic, sc *schema2.WorkflowSchema, conf
|
||||
opt(options)
|
||||
}
|
||||
|
||||
var container *execute.StreamContainer
|
||||
if options.streamWriter != nil {
|
||||
container = execute.NewStreamContainer(options.streamWriter)
|
||||
}
|
||||
|
||||
return &WorkflowRunner{
|
||||
basic: b,
|
||||
input: options.input,
|
||||
resumeReq: options.resumeReq,
|
||||
schema: sc,
|
||||
streamWriter: options.streamWriter,
|
||||
config: config,
|
||||
basic: b,
|
||||
input: options.input,
|
||||
resumeReq: options.resumeReq,
|
||||
schema: sc,
|
||||
sw: options.streamWriter,
|
||||
container: container,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,14 +116,16 @@ func (r *WorkflowRunner) Prepare(ctx context.Context) (
|
||||
resumeReq = r.resumeReq
|
||||
wb = r.basic
|
||||
sc = r.schema
|
||||
sw = r.streamWriter
|
||||
sw = r.sw
|
||||
container = r.container
|
||||
config = r.config
|
||||
)
|
||||
|
||||
if r.resumeReq == nil {
|
||||
executeID, err = repo.GenID(ctx)
|
||||
if err != nil {
|
||||
return ctx, 0, nil, nil, fmt.Errorf("failed to generate workflow execute ID: %w", err)
|
||||
return ctx, 0, nil, nil, vo.WrapError(errno.ErrIDGenError,
|
||||
fmt.Errorf("failed to generate workflow execute ID: %w", err))
|
||||
}
|
||||
} else {
|
||||
executeID = resumeReq.ExecuteID
|
||||
@@ -148,6 +158,15 @@ func (r *WorkflowRunner) Prepare(ctx context.Context) (
|
||||
r.eventChan = eventChan
|
||||
r.interruptEvent = interruptEvent
|
||||
|
||||
if container != nil {
|
||||
go container.PipeAll()
|
||||
defer func() {
|
||||
if err != nil {
|
||||
container.Done()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
ctx, composeOpts, err := r.designateOptions(ctx)
|
||||
if err != nil {
|
||||
return ctx, 0, nil, nil, err
|
||||
@@ -277,8 +296,8 @@ func (r *WorkflowRunner) Prepare(ctx context.Context) (
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if sw != nil {
|
||||
sw.Close()
|
||||
if container != nil {
|
||||
container.Done()
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
@@ -33,17 +33,22 @@ import (
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
const answerKey = "output"
|
||||
|
||||
type invokableWorkflow struct {
|
||||
workflowTool
|
||||
invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error)
|
||||
}
|
||||
|
||||
type workflowTool struct {
|
||||
info *schema.ToolInfo
|
||||
invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error)
|
||||
terminatePlan vo.TerminatePlan
|
||||
wfEntity *entity.Workflow
|
||||
sc *schema2.WorkflowSchema
|
||||
repo wf.Repository
|
||||
terminatePlan vo.TerminatePlan
|
||||
}
|
||||
|
||||
func NewInvokableWorkflow(info *schema.ToolInfo,
|
||||
@@ -54,12 +59,14 @@ func NewInvokableWorkflow(info *schema.ToolInfo,
|
||||
repo wf.Repository,
|
||||
) wf.ToolFromWorkflow {
|
||||
return &invokableWorkflow{
|
||||
info: info,
|
||||
invoke: invoke,
|
||||
terminatePlan: terminatePlan,
|
||||
wfEntity: wfEntity,
|
||||
sc: sc,
|
||||
repo: repo,
|
||||
workflowTool: workflowTool{
|
||||
info: info,
|
||||
wfEntity: wfEntity,
|
||||
sc: sc,
|
||||
repo: repo,
|
||||
terminatePlan: terminatePlan,
|
||||
},
|
||||
invoke: invoke,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,6 +84,52 @@ func resumeOnce(rInfo *entity.ResumeRequest, callID string, allIEs map[string]*e
|
||||
}
|
||||
}
|
||||
|
||||
func (wt *workflowTool) prepare(ctx context.Context, rInfo *entity.ResumeRequest, argumentsInJSON string, opts ...tool.Option) (
|
||||
cancelCtx context.Context, executeID int64, input map[string]any, callOpts []einoCompose.Option, err error) {
|
||||
cfg := execute.GetExecuteConfig(opts...)
|
||||
|
||||
var runOpts []WorkflowRunnerOption
|
||||
if rInfo != nil && !rInfo.Resumed {
|
||||
runOpts = append(runOpts, WithResumeReq(rInfo))
|
||||
} else {
|
||||
runOpts = append(runOpts, WithInput(argumentsInJSON))
|
||||
}
|
||||
if container := execute.GetParentStreamContainer(opts...); container != nil {
|
||||
sr, sw := schema.Pipe[*entity.Message](10)
|
||||
container.AddChild(sr)
|
||||
runOpts = append(runOpts, WithStreamWriter(sw))
|
||||
}
|
||||
|
||||
var ws *nodes.ConversionWarnings
|
||||
|
||||
if (rInfo == nil || rInfo.Resumed) && len(wt.wfEntity.InputParams) > 0 {
|
||||
if err = sonic.UnmarshalString(argumentsInJSON, &input); err != nil {
|
||||
err = vo.WrapError(errno.ErrSerializationDeserializationFail, err)
|
||||
return
|
||||
}
|
||||
|
||||
var entryNode *schema2.NodeSchema
|
||||
for _, node := range wt.sc.Nodes {
|
||||
if node.Type == entity.NodeTypeEntry {
|
||||
entryNode = node
|
||||
break
|
||||
}
|
||||
}
|
||||
if entryNode == nil {
|
||||
panic("entry node not found in tool workflow")
|
||||
}
|
||||
input, ws, err = nodes.ConvertInputs(ctx, input, entryNode.OutputTypes)
|
||||
if err != nil {
|
||||
return
|
||||
} else if ws != nil {
|
||||
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
|
||||
}
|
||||
}
|
||||
|
||||
cancelCtx, executeID, callOpts, _, err = NewWorkflowRunner(wt.wfEntity.GetBasic(), wt.sc, cfg, runOpts...).Prepare(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
|
||||
rInfo, allIEs := execute.GetResumeRequest(opts...)
|
||||
var (
|
||||
@@ -97,52 +150,9 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st
|
||||
return "", einoCompose.InterruptAndRerun
|
||||
}
|
||||
|
||||
cfg := execute.GetExecuteConfig(opts...)
|
||||
defer resumeOnce(rInfo, callID, allIEs)
|
||||
|
||||
var runOpts []WorkflowRunnerOption
|
||||
if rInfo != nil && !rInfo.Resumed {
|
||||
runOpts = append(runOpts, WithResumeReq(rInfo))
|
||||
} else {
|
||||
runOpts = append(runOpts, WithInput(argumentsInJSON))
|
||||
}
|
||||
if sw := execute.GetIntermediateStreamWriter(opts...); sw != nil {
|
||||
runOpts = append(runOpts, WithStreamWriter(sw))
|
||||
}
|
||||
|
||||
var (
|
||||
cancelCtx context.Context
|
||||
executeID int64
|
||||
callOpts []einoCompose.Option
|
||||
in map[string]any
|
||||
err error
|
||||
ws *nodes.ConversionWarnings
|
||||
)
|
||||
|
||||
if rInfo == nil && len(i.wfEntity.InputParams) > 0 {
|
||||
if err = sonic.UnmarshalString(argumentsInJSON, &in); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var entryNode *schema2.NodeSchema
|
||||
for _, node := range i.sc.Nodes {
|
||||
if node.Type == entity.NodeTypeEntry {
|
||||
entryNode = node
|
||||
break
|
||||
}
|
||||
}
|
||||
if entryNode == nil {
|
||||
panic("entry node not found in tool workflow")
|
||||
}
|
||||
in, ws, err = nodes.ConvertInputs(ctx, in, entryNode.OutputTypes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else if ws != nil {
|
||||
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
|
||||
}
|
||||
}
|
||||
|
||||
cancelCtx, executeID, callOpts, _, err = NewWorkflowRunner(i.wfEntity.GetBasic(), i.sc, cfg, runOpts...).Prepare(ctx)
|
||||
cancelCtx, executeID, in, callOpts, err := i.prepare(ctx, rInfo, argumentsInJSON, opts...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -198,12 +208,8 @@ func (i *invokableWorkflow) GetWorkflow() *entity.Workflow {
|
||||
}
|
||||
|
||||
type streamableWorkflow struct {
|
||||
info *schema.ToolInfo
|
||||
stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error)
|
||||
terminatePlan vo.TerminatePlan
|
||||
wfEntity *entity.Workflow
|
||||
sc *schema2.WorkflowSchema
|
||||
repo wf.Repository
|
||||
workflowTool
|
||||
stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error)
|
||||
}
|
||||
|
||||
func NewStreamableWorkflow(info *schema.ToolInfo,
|
||||
@@ -214,12 +220,14 @@ func NewStreamableWorkflow(info *schema.ToolInfo,
|
||||
repo wf.Repository,
|
||||
) wf.ToolFromWorkflow {
|
||||
return &streamableWorkflow{
|
||||
info: info,
|
||||
stream: stream,
|
||||
terminatePlan: terminatePlan,
|
||||
wfEntity: wfEntity,
|
||||
sc: sc,
|
||||
repo: repo,
|
||||
workflowTool: workflowTool{
|
||||
info: info,
|
||||
wfEntity: wfEntity,
|
||||
sc: sc,
|
||||
repo: repo,
|
||||
terminatePlan: terminatePlan,
|
||||
},
|
||||
stream: stream,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -247,52 +255,9 @@ func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON
|
||||
return nil, einoCompose.InterruptAndRerun
|
||||
}
|
||||
|
||||
cfg := execute.GetExecuteConfig(opts...)
|
||||
defer resumeOnce(rInfo, callID, allIEs)
|
||||
|
||||
var runOpts []WorkflowRunnerOption
|
||||
if rInfo != nil && !rInfo.Resumed {
|
||||
runOpts = append(runOpts, WithResumeReq(rInfo))
|
||||
} else {
|
||||
runOpts = append(runOpts, WithInput(argumentsInJSON))
|
||||
}
|
||||
if sw := execute.GetIntermediateStreamWriter(opts...); sw != nil {
|
||||
runOpts = append(runOpts, WithStreamWriter(sw))
|
||||
}
|
||||
|
||||
var (
|
||||
cancelCtx context.Context
|
||||
executeID int64
|
||||
callOpts []einoCompose.Option
|
||||
in map[string]any
|
||||
err error
|
||||
ws *nodes.ConversionWarnings
|
||||
)
|
||||
|
||||
if rInfo == nil && len(s.wfEntity.InputParams) > 0 {
|
||||
if err = sonic.UnmarshalString(argumentsInJSON, &in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var entryNode *schema2.NodeSchema
|
||||
for _, node := range s.sc.Nodes {
|
||||
if node.Type == entity.NodeTypeEntry {
|
||||
entryNode = node
|
||||
break
|
||||
}
|
||||
}
|
||||
if entryNode == nil {
|
||||
panic("entry node not found in tool workflow")
|
||||
}
|
||||
in, ws, err = nodes.ConvertInputs(ctx, in, entryNode.OutputTypes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if ws != nil {
|
||||
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
|
||||
}
|
||||
}
|
||||
|
||||
cancelCtx, executeID, callOpts, _, err = NewWorkflowRunner(s.wfEntity.GetBasic(), s.sc, cfg, runOpts...).Prepare(ctx)
|
||||
cancelCtx, executeID, in, callOpts, err := s.prepare(ctx, rInfo, argumentsInJSON, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -370,6 +370,10 @@ func (w *WorkflowHandler) OnError(ctx context.Context, info *callbacks.RunInfo,
|
||||
interruptEvent.EventType, interruptEvent.NodeKey)
|
||||
}
|
||||
|
||||
if c.TokenCollector != nil { // wait until all streaming chunks are collected
|
||||
_ = c.TokenCollector.wait()
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
w.ch <- &Event{
|
||||
@@ -1309,6 +1313,7 @@ func (t *ToolHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo,
|
||||
FunctionInfo: t.info,
|
||||
CallID: compose.GetToolCallID(ctx),
|
||||
Response: output.Response,
|
||||
Complete: true,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1347,6 +1352,7 @@ func (t *ToolHandler) OnEndWithStreamOutput(ctx context.Context, info *callbacks
|
||||
toolResponse: &entity.ToolResponseInfo{
|
||||
FunctionInfo: t.info,
|
||||
CallID: callID,
|
||||
Complete: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,6 +76,20 @@ func (t *TokenCollector) add(i int) {
|
||||
return
|
||||
}
|
||||
|
||||
func (t *TokenCollector) startStreamCounting() {
|
||||
t.wg.Add(1)
|
||||
if t.Parent != nil {
|
||||
t.Parent.startStreamCounting()
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TokenCollector) finishStreamCounting() {
|
||||
t.wg.Done()
|
||||
if t.Parent != nil {
|
||||
t.Parent.finishStreamCounting()
|
||||
}
|
||||
}
|
||||
|
||||
func getTokenCollector(ctx context.Context) *TokenCollector {
|
||||
c := GetExeCtx(ctx)
|
||||
if c == nil {
|
||||
@@ -92,7 +106,6 @@ func GetTokenCallbackHandler() callbacks.Handler {
|
||||
return ctx
|
||||
}
|
||||
c.add(1)
|
||||
//c.wg.Add(1)
|
||||
return ctx
|
||||
},
|
||||
OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context {
|
||||
@@ -114,6 +127,7 @@ func GetTokenCallbackHandler() callbacks.Handler {
|
||||
output.Close()
|
||||
return ctx
|
||||
}
|
||||
c.startStreamCounting()
|
||||
safego.Go(ctx, func() {
|
||||
defer func() {
|
||||
output.Close()
|
||||
@@ -141,6 +155,7 @@ func GetTokenCallbackHandler() callbacks.Handler {
|
||||
if newC.TotalTokens > 0 {
|
||||
c.addTokenUsage(newC)
|
||||
}
|
||||
c.finishStreamCounting()
|
||||
})
|
||||
return ctx
|
||||
},
|
||||
|
||||
@@ -789,6 +789,7 @@ func HandleExecuteEvent(ctx context.Context,
|
||||
logs.CtxInfof(ctx, "[handleExecuteEvent] finish, returned event type: %v, workflow id: %d",
|
||||
event.Type, event.Context.RootWorkflowBasic.ID)
|
||||
cancelTicker.Stop() // Clean up timer
|
||||
waitUntilToolFinish(ctx)
|
||||
if timeoutFn != nil {
|
||||
timeoutFn()
|
||||
}
|
||||
@@ -880,6 +881,7 @@ func cacheToolStreamingResponse(ctx context.Context, event *Event) {
|
||||
c[event.NodeKey][event.toolResponse.CallID].output = event.toolResponse
|
||||
}
|
||||
c[event.NodeKey][event.toolResponse.CallID].output.Response += event.toolResponse.Response
|
||||
c[event.NodeKey][event.toolResponse.CallID].output.Complete = event.toolResponse.Complete
|
||||
}
|
||||
|
||||
func getFCInfos(ctx context.Context, nodeKey vo.NodeKey) map[string]*fcInfo {
|
||||
@@ -887,6 +889,35 @@ func getFCInfos(ctx context.Context, nodeKey vo.NodeKey) map[string]*fcInfo {
|
||||
return c[nodeKey]
|
||||
}
|
||||
|
||||
func waitUntilToolFinish(ctx context.Context) {
|
||||
var cnt int
|
||||
outer:
|
||||
for {
|
||||
if cnt > 1000 {
|
||||
return
|
||||
}
|
||||
|
||||
c := ctx.Value(fcCacheKey{}).(map[vo.NodeKey]map[string]*fcInfo)
|
||||
if len(c) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, m := range c {
|
||||
for _, info := range m {
|
||||
if info.output == nil {
|
||||
cnt++
|
||||
continue outer
|
||||
}
|
||||
|
||||
if !info.output.Complete {
|
||||
cnt++
|
||||
continue outer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fcInfo) inputString() string {
|
||||
if f.input == nil {
|
||||
return ""
|
||||
|
||||
74
backend/domain/workflow/internal/execute/stream_container.go
Normal file
74
backend/domain/workflow/internal/execute/stream_container.go
Normal file
@@ -0,0 +1,74 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package execute
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
)
|
||||
|
||||
type StreamContainer struct {
|
||||
sw *schema.StreamWriter[*entity.Message]
|
||||
subStreams chan *schema.StreamReader[*entity.Message]
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewStreamContainer(sw *schema.StreamWriter[*entity.Message]) *StreamContainer {
|
||||
return &StreamContainer{
|
||||
sw: sw,
|
||||
subStreams: make(chan *schema.StreamReader[*entity.Message]),
|
||||
}
|
||||
}
|
||||
|
||||
func (sc *StreamContainer) AddChild(sr *schema.StreamReader[*entity.Message]) {
|
||||
sc.wg.Add(1)
|
||||
sc.subStreams <- sr
|
||||
}
|
||||
|
||||
func (sc *StreamContainer) PipeAll() {
|
||||
sc.wg.Add(1)
|
||||
|
||||
for sr := range sc.subStreams {
|
||||
go func() {
|
||||
defer sr.Close()
|
||||
|
||||
for {
|
||||
msg, err := sr.Recv()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
sc.wg.Done()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
sc.sw.Send(msg, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (sc *StreamContainer) Done() {
|
||||
sc.wg.Done()
|
||||
sc.wg.Wait()
|
||||
close(sc.subStreams)
|
||||
sc.sw.Close()
|
||||
}
|
||||
@@ -27,7 +27,7 @@ import (
|
||||
|
||||
type workflowToolOption struct {
|
||||
resumeReq *entity.ResumeRequest
|
||||
sw *schema.StreamWriter[*entity.Message]
|
||||
streamContainer *StreamContainer
|
||||
exeCfg workflowModel.ExecuteConfig
|
||||
allInterruptEvents map[string]*entity.ToolInterruptEvent
|
||||
parentTokenCollector *TokenCollector
|
||||
@@ -40,9 +40,9 @@ func WithResume(req *entity.ResumeRequest, all map[string]*entity.ToolInterruptE
|
||||
})
|
||||
}
|
||||
|
||||
func WithIntermediateStreamWriter(sw *schema.StreamWriter[*entity.Message]) tool.Option {
|
||||
func WithParentStreamContainer(sc *StreamContainer) tool.Option {
|
||||
return tool.WrapImplSpecificOptFn(func(opts *workflowToolOption) {
|
||||
opts.sw = sw
|
||||
opts.streamContainer = sc
|
||||
})
|
||||
}
|
||||
|
||||
@@ -57,9 +57,9 @@ func GetResumeRequest(opts ...tool.Option) (*entity.ResumeRequest, map[string]*e
|
||||
return opt.resumeReq, opt.allInterruptEvents
|
||||
}
|
||||
|
||||
func GetIntermediateStreamWriter(opts ...tool.Option) *schema.StreamWriter[*entity.Message] {
|
||||
func GetParentStreamContainer(opts ...tool.Option) *StreamContainer {
|
||||
opt := tool.GetImplSpecificOptions(&workflowToolOption{}, opts...)
|
||||
return opt.sw
|
||||
return opt.streamContainer
|
||||
}
|
||||
|
||||
func GetExecuteConfig(opts ...tool.Option) workflowModel.ExecuteConfig {
|
||||
@@ -67,11 +67,22 @@ func GetExecuteConfig(opts ...tool.Option) workflowModel.ExecuteConfig {
|
||||
return opt.exeCfg
|
||||
}
|
||||
|
||||
// WithMessagePipe returns an Option which is meant to be passed to the tool workflow, as well as a StreamReader to read the messages from the tool workflow.
|
||||
// This Option will apply to ALL workflow tools to be executed by eino's ToolsNode. The workflow tools will emit messages to this stream.
|
||||
// WithMessagePipe returns an Option which is meant to be passed to the tool workflow,
|
||||
// as well as a StreamReader to read the messages from the tool workflow.
|
||||
// This Option will apply to ALL workflow tools to be executed by eino's ToolsNode.
|
||||
// The workflow tools will emit messages to this stream.
|
||||
// The caller can receive from the returned StreamReader to get the messages from the tool workflow.
|
||||
func WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message]) {
|
||||
func WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message], func()) {
|
||||
sr, sw := schema.Pipe[*entity.Message](10)
|
||||
opt := compose.WithToolsNodeOption(compose.WithToolOption(WithIntermediateStreamWriter(sw)))
|
||||
return opt, sr
|
||||
container := &StreamContainer{
|
||||
sw: sw,
|
||||
subStreams: make(chan *schema.StreamReader[*entity.Message]),
|
||||
}
|
||||
|
||||
go container.PipeAll()
|
||||
|
||||
opt := compose.WithToolsNodeOption(compose.WithToolOption(WithParentStreamContainer(container)))
|
||||
return opt, sr, func() {
|
||||
container.Done()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -446,9 +446,6 @@ func (b *Batch) Invoke(ctx context.Context, in map[string]any, opts ...nodes.Nod
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fmt.Println("save interruptEvent in state within batch: ", iEvent)
|
||||
fmt.Println("save composite info in state within batch: ", compState)
|
||||
|
||||
return nil, compose.InterruptAndRerun
|
||||
} else {
|
||||
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
|
||||
|
||||
@@ -92,7 +92,6 @@ func ConvertInputs(ctx context.Context, in map[string]any, tInfo map[string]*vo.
|
||||
t, ok := tInfo[k]
|
||||
if !ok {
|
||||
// for input fields not explicitly defined, just pass the string value through
|
||||
logs.CtxWarnf(ctx, "input %s not found in type info", k)
|
||||
if !options.skipUnknownFields {
|
||||
out[k] = in[k]
|
||||
}
|
||||
@@ -323,7 +322,6 @@ func convertToObject(ctx context.Context, in any, path string, t *vo.TypeInfo, o
|
||||
propType, ok := t.Properties[k]
|
||||
if !ok {
|
||||
// for input fields not explicitly defined, just pass the value through
|
||||
logs.CtxWarnf(ctx, "input %s.%s not found in type info", path, k)
|
||||
if !options.skipUnknownFields {
|
||||
out[k] = v
|
||||
}
|
||||
|
||||
@@ -143,6 +143,7 @@ const (
|
||||
knowledgeUserPromptTemplateKey = "knowledge_user_prompt_prefix"
|
||||
templateNodeKey = "template"
|
||||
llmNodeKey = "llm"
|
||||
reactGraphName = "workflow_llm_react_agent"
|
||||
outputConvertNodeKey = "output_convert"
|
||||
)
|
||||
|
||||
@@ -620,6 +621,7 @@ func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2
|
||||
ToolCallingModel: m,
|
||||
ToolsConfig: compose.ToolsNodeConfig{Tools: tools},
|
||||
ModelNodeName: agentModelName,
|
||||
GraphName: reactGraphName,
|
||||
}
|
||||
|
||||
if len(toolsReturnDirectly) > 0 {
|
||||
@@ -635,7 +637,7 @@ func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2
|
||||
}
|
||||
|
||||
agentNode, opts := reactAgent.ExportGraph()
|
||||
opts = append(opts, compose.WithNodeName("workflow_llm_react_agent"))
|
||||
opts = append(opts, compose.WithNodeName(reactGraphName))
|
||||
_ = g.AddGraphNode(llmNodeKey, agentNode, opts...)
|
||||
} else {
|
||||
_ = g.AddChatModelNode(llmNodeKey, modelWithInfo)
|
||||
@@ -867,12 +869,12 @@ func jsonParse(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo
|
||||
}
|
||||
|
||||
type llmOptions struct {
|
||||
toolWorkflowSW *schema.StreamWriter[*entity.Message]
|
||||
toolWorkflowContainer *execute.StreamContainer
|
||||
}
|
||||
|
||||
func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) nodes.NodeOption {
|
||||
func WithToolWorkflowStreamContainer(container *execute.StreamContainer) nodes.NodeOption {
|
||||
return nodes.WrapImplSpecificOptFn(func(o *llmOptions) {
|
||||
o.toolWorkflowSW = sw
|
||||
o.toolWorkflowContainer = container
|
||||
})
|
||||
}
|
||||
|
||||
@@ -880,7 +882,8 @@ type llmState = map[string]any
|
||||
|
||||
const agentModelName = "agent_model"
|
||||
|
||||
func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeOption) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) {
|
||||
func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeOption) (
|
||||
composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) {
|
||||
c := execute.GetExeCtx(ctx)
|
||||
if c != nil {
|
||||
resumingEvent = c.NodeCtx.ResumingEvent
|
||||
@@ -890,7 +893,7 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeO
|
||||
if c != nil && c.RootCtx.ResumeEvent != nil {
|
||||
// check if we are not resuming, but previously interrupted. Interrupt immediately.
|
||||
if resumingEvent == nil {
|
||||
err := compose.ProcessState(ctx, func(ctx context.Context, state ToolInterruptEventStore) error {
|
||||
err = compose.ProcessState(ctx, func(ctx context.Context, state ToolInterruptEventStore) error {
|
||||
var e error
|
||||
previousToolES, e = state.GetToolInterruptEvents(c.NodeKey)
|
||||
if e != nil {
|
||||
@@ -899,11 +902,12 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeO
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return
|
||||
}
|
||||
|
||||
if len(previousToolES) > 0 {
|
||||
return nil, nil, compose.InterruptAndRerun
|
||||
err = compose.InterruptAndRerun
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -936,7 +940,7 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeO
|
||||
return e
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return
|
||||
}
|
||||
composeOpts = append(composeOpts, compose.WithToolsNodeOption(
|
||||
compose.WithToolOption(
|
||||
@@ -986,27 +990,9 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeO
|
||||
}
|
||||
|
||||
llmOpts := nodes.GetImplSpecificOptions(&llmOptions{}, opts...)
|
||||
if llmOpts.toolWorkflowSW != nil {
|
||||
toolMsgOpt, toolMsgSR := execute.WithMessagePipe()
|
||||
composeOpts = append(composeOpts, toolMsgOpt)
|
||||
|
||||
safego.Go(ctx, func() {
|
||||
defer toolMsgSR.Close()
|
||||
for {
|
||||
msg, err := toolMsgSR.Recv()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
logs.CtxErrorf(ctx, "failed to receive message from tool workflow: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logs.Infof("received message from tool workflow: %+v", msg)
|
||||
|
||||
llmOpts.toolWorkflowSW.Send(msg, nil)
|
||||
}
|
||||
})
|
||||
if container := llmOpts.toolWorkflowContainer; container != nil {
|
||||
composeOpts = append(composeOpts, compose.WithToolsNodeOption(compose.WithToolOption(
|
||||
execute.WithParentStreamContainer(container))))
|
||||
}
|
||||
|
||||
resolvedSources, err := nodes.ResolveStreamSources(ctx, l.fullSources)
|
||||
|
||||
@@ -33,7 +33,7 @@ type asToolImpl struct {
|
||||
repo workflow.Repository
|
||||
}
|
||||
|
||||
func (a *asToolImpl) WithMessagePipe() (einoCompose.Option, *schema.StreamReader[*entity.Message]) {
|
||||
func (a *asToolImpl) WithMessagePipe() (einoCompose.Option, *schema.StreamReader[*entity.Message], func()) {
|
||||
return execute.WithMessagePipe()
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user