fix: workflow tool closes stream writer correctly (#1839)

This commit is contained in:
shentongmartin
2025-08-27 16:29:42 +08:00
committed by GitHub
parent 263a75b1c0
commit 5562800958
19 changed files with 742 additions and 620 deletions

View File

@@ -370,6 +370,10 @@ func (w *WorkflowHandler) OnError(ctx context.Context, info *callbacks.RunInfo,
interruptEvent.EventType, interruptEvent.NodeKey)
}
if c.TokenCollector != nil { // wait until all streaming chunks are collected
_ = c.TokenCollector.wait()
}
done := make(chan struct{})
w.ch <- &Event{
@@ -1309,6 +1313,7 @@ func (t *ToolHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo,
FunctionInfo: t.info,
CallID: compose.GetToolCallID(ctx),
Response: output.Response,
Complete: true,
},
}
@@ -1347,6 +1352,7 @@ func (t *ToolHandler) OnEndWithStreamOutput(ctx context.Context, info *callbacks
toolResponse: &entity.ToolResponseInfo{
FunctionInfo: t.info,
CallID: callID,
Complete: true,
},
}
}

View File

@@ -76,6 +76,20 @@ func (t *TokenCollector) add(i int) {
return
}
func (t *TokenCollector) startStreamCounting() {
t.wg.Add(1)
if t.Parent != nil {
t.Parent.startStreamCounting()
}
}
func (t *TokenCollector) finishStreamCounting() {
t.wg.Done()
if t.Parent != nil {
t.Parent.finishStreamCounting()
}
}
func getTokenCollector(ctx context.Context) *TokenCollector {
c := GetExeCtx(ctx)
if c == nil {
@@ -92,7 +106,6 @@ func GetTokenCallbackHandler() callbacks.Handler {
return ctx
}
c.add(1)
//c.wg.Add(1)
return ctx
},
OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context {
@@ -114,6 +127,7 @@ func GetTokenCallbackHandler() callbacks.Handler {
output.Close()
return ctx
}
c.startStreamCounting()
safego.Go(ctx, func() {
defer func() {
output.Close()
@@ -141,6 +155,7 @@ func GetTokenCallbackHandler() callbacks.Handler {
if newC.TotalTokens > 0 {
c.addTokenUsage(newC)
}
c.finishStreamCounting()
})
return ctx
},

View File

@@ -789,6 +789,7 @@ func HandleExecuteEvent(ctx context.Context,
logs.CtxInfof(ctx, "[handleExecuteEvent] finish, returned event type: %v, workflow id: %d",
event.Type, event.Context.RootWorkflowBasic.ID)
cancelTicker.Stop() // Clean up timer
waitUntilToolFinish(ctx)
if timeoutFn != nil {
timeoutFn()
}
@@ -880,6 +881,7 @@ func cacheToolStreamingResponse(ctx context.Context, event *Event) {
c[event.NodeKey][event.toolResponse.CallID].output = event.toolResponse
}
c[event.NodeKey][event.toolResponse.CallID].output.Response += event.toolResponse.Response
c[event.NodeKey][event.toolResponse.CallID].output.Complete = event.toolResponse.Complete
}
func getFCInfos(ctx context.Context, nodeKey vo.NodeKey) map[string]*fcInfo {
@@ -887,6 +889,35 @@ func getFCInfos(ctx context.Context, nodeKey vo.NodeKey) map[string]*fcInfo {
return c[nodeKey]
}
func waitUntilToolFinish(ctx context.Context) {
var cnt int
outer:
for {
if cnt > 1000 {
return
}
c := ctx.Value(fcCacheKey{}).(map[vo.NodeKey]map[string]*fcInfo)
if len(c) == 0 {
return
}
for _, m := range c {
for _, info := range m {
if info.output == nil {
cnt++
continue outer
}
if !info.output.Complete {
cnt++
continue outer
}
}
}
}
}
func (f *fcInfo) inputString() string {
if f.input == nil {
return ""

View File

@@ -0,0 +1,74 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package execute
import (
"errors"
"io"
"sync"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
)
type StreamContainer struct {
sw *schema.StreamWriter[*entity.Message]
subStreams chan *schema.StreamReader[*entity.Message]
wg sync.WaitGroup
}
func NewStreamContainer(sw *schema.StreamWriter[*entity.Message]) *StreamContainer {
return &StreamContainer{
sw: sw,
subStreams: make(chan *schema.StreamReader[*entity.Message]),
}
}
func (sc *StreamContainer) AddChild(sr *schema.StreamReader[*entity.Message]) {
sc.wg.Add(1)
sc.subStreams <- sr
}
func (sc *StreamContainer) PipeAll() {
sc.wg.Add(1)
for sr := range sc.subStreams {
go func() {
defer sr.Close()
for {
msg, err := sr.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
sc.wg.Done()
return
}
}
sc.sw.Send(msg, err)
}
}()
}
}
func (sc *StreamContainer) Done() {
sc.wg.Done()
sc.wg.Wait()
close(sc.subStreams)
sc.sw.Close()
}

View File

@@ -27,7 +27,7 @@ import (
type workflowToolOption struct {
resumeReq *entity.ResumeRequest
sw *schema.StreamWriter[*entity.Message]
streamContainer *StreamContainer
exeCfg workflowModel.ExecuteConfig
allInterruptEvents map[string]*entity.ToolInterruptEvent
parentTokenCollector *TokenCollector
@@ -40,9 +40,9 @@ func WithResume(req *entity.ResumeRequest, all map[string]*entity.ToolInterruptE
})
}
func WithIntermediateStreamWriter(sw *schema.StreamWriter[*entity.Message]) tool.Option {
func WithParentStreamContainer(sc *StreamContainer) tool.Option {
return tool.WrapImplSpecificOptFn(func(opts *workflowToolOption) {
opts.sw = sw
opts.streamContainer = sc
})
}
@@ -57,9 +57,9 @@ func GetResumeRequest(opts ...tool.Option) (*entity.ResumeRequest, map[string]*e
return opt.resumeReq, opt.allInterruptEvents
}
func GetIntermediateStreamWriter(opts ...tool.Option) *schema.StreamWriter[*entity.Message] {
func GetParentStreamContainer(opts ...tool.Option) *StreamContainer {
opt := tool.GetImplSpecificOptions(&workflowToolOption{}, opts...)
return opt.sw
return opt.streamContainer
}
func GetExecuteConfig(opts ...tool.Option) workflowModel.ExecuteConfig {
@@ -67,11 +67,22 @@ func GetExecuteConfig(opts ...tool.Option) workflowModel.ExecuteConfig {
return opt.exeCfg
}
// WithMessagePipe returns an Option which is meant to be passed to the tool workflow, as well as a StreamReader to read the messages from the tool workflow.
// This Option will apply to ALL workflow tools to be executed by eino's ToolsNode. The workflow tools will emit messages to this stream.
// WithMessagePipe returns an Option which is meant to be passed to the tool workflow,
// as well as a StreamReader to read the messages from the tool workflow.
// This Option will apply to ALL workflow tools to be executed by eino's ToolsNode.
// The workflow tools will emit messages to this stream.
// The caller can receive from the returned StreamReader to get the messages from the tool workflow.
func WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message]) {
func WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message], func()) {
sr, sw := schema.Pipe[*entity.Message](10)
opt := compose.WithToolsNodeOption(compose.WithToolOption(WithIntermediateStreamWriter(sw)))
return opt, sr
container := &StreamContainer{
sw: sw,
subStreams: make(chan *schema.StreamReader[*entity.Message]),
}
go container.PipeAll()
opt := compose.WithToolsNodeOption(compose.WithToolOption(WithParentStreamContainer(container)))
return opt, sr, func() {
container.Done()
}
}