fix: workflow tool closes stream writer correctly (#1839)
This commit is contained in:
@@ -370,6 +370,10 @@ func (w *WorkflowHandler) OnError(ctx context.Context, info *callbacks.RunInfo,
|
||||
interruptEvent.EventType, interruptEvent.NodeKey)
|
||||
}
|
||||
|
||||
if c.TokenCollector != nil { // wait until all streaming chunks are collected
|
||||
_ = c.TokenCollector.wait()
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
w.ch <- &Event{
|
||||
@@ -1309,6 +1313,7 @@ func (t *ToolHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo,
|
||||
FunctionInfo: t.info,
|
||||
CallID: compose.GetToolCallID(ctx),
|
||||
Response: output.Response,
|
||||
Complete: true,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1347,6 +1352,7 @@ func (t *ToolHandler) OnEndWithStreamOutput(ctx context.Context, info *callbacks
|
||||
toolResponse: &entity.ToolResponseInfo{
|
||||
FunctionInfo: t.info,
|
||||
CallID: callID,
|
||||
Complete: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,6 +76,20 @@ func (t *TokenCollector) add(i int) {
|
||||
return
|
||||
}
|
||||
|
||||
func (t *TokenCollector) startStreamCounting() {
|
||||
t.wg.Add(1)
|
||||
if t.Parent != nil {
|
||||
t.Parent.startStreamCounting()
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TokenCollector) finishStreamCounting() {
|
||||
t.wg.Done()
|
||||
if t.Parent != nil {
|
||||
t.Parent.finishStreamCounting()
|
||||
}
|
||||
}
|
||||
|
||||
func getTokenCollector(ctx context.Context) *TokenCollector {
|
||||
c := GetExeCtx(ctx)
|
||||
if c == nil {
|
||||
@@ -92,7 +106,6 @@ func GetTokenCallbackHandler() callbacks.Handler {
|
||||
return ctx
|
||||
}
|
||||
c.add(1)
|
||||
//c.wg.Add(1)
|
||||
return ctx
|
||||
},
|
||||
OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context {
|
||||
@@ -114,6 +127,7 @@ func GetTokenCallbackHandler() callbacks.Handler {
|
||||
output.Close()
|
||||
return ctx
|
||||
}
|
||||
c.startStreamCounting()
|
||||
safego.Go(ctx, func() {
|
||||
defer func() {
|
||||
output.Close()
|
||||
@@ -141,6 +155,7 @@ func GetTokenCallbackHandler() callbacks.Handler {
|
||||
if newC.TotalTokens > 0 {
|
||||
c.addTokenUsage(newC)
|
||||
}
|
||||
c.finishStreamCounting()
|
||||
})
|
||||
return ctx
|
||||
},
|
||||
|
||||
@@ -789,6 +789,7 @@ func HandleExecuteEvent(ctx context.Context,
|
||||
logs.CtxInfof(ctx, "[handleExecuteEvent] finish, returned event type: %v, workflow id: %d",
|
||||
event.Type, event.Context.RootWorkflowBasic.ID)
|
||||
cancelTicker.Stop() // Clean up timer
|
||||
waitUntilToolFinish(ctx)
|
||||
if timeoutFn != nil {
|
||||
timeoutFn()
|
||||
}
|
||||
@@ -880,6 +881,7 @@ func cacheToolStreamingResponse(ctx context.Context, event *Event) {
|
||||
c[event.NodeKey][event.toolResponse.CallID].output = event.toolResponse
|
||||
}
|
||||
c[event.NodeKey][event.toolResponse.CallID].output.Response += event.toolResponse.Response
|
||||
c[event.NodeKey][event.toolResponse.CallID].output.Complete = event.toolResponse.Complete
|
||||
}
|
||||
|
||||
func getFCInfos(ctx context.Context, nodeKey vo.NodeKey) map[string]*fcInfo {
|
||||
@@ -887,6 +889,35 @@ func getFCInfos(ctx context.Context, nodeKey vo.NodeKey) map[string]*fcInfo {
|
||||
return c[nodeKey]
|
||||
}
|
||||
|
||||
func waitUntilToolFinish(ctx context.Context) {
|
||||
var cnt int
|
||||
outer:
|
||||
for {
|
||||
if cnt > 1000 {
|
||||
return
|
||||
}
|
||||
|
||||
c := ctx.Value(fcCacheKey{}).(map[vo.NodeKey]map[string]*fcInfo)
|
||||
if len(c) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, m := range c {
|
||||
for _, info := range m {
|
||||
if info.output == nil {
|
||||
cnt++
|
||||
continue outer
|
||||
}
|
||||
|
||||
if !info.output.Complete {
|
||||
cnt++
|
||||
continue outer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fcInfo) inputString() string {
|
||||
if f.input == nil {
|
||||
return ""
|
||||
|
||||
74
backend/domain/workflow/internal/execute/stream_container.go
Normal file
74
backend/domain/workflow/internal/execute/stream_container.go
Normal file
@@ -0,0 +1,74 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package execute
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
)
|
||||
|
||||
type StreamContainer struct {
|
||||
sw *schema.StreamWriter[*entity.Message]
|
||||
subStreams chan *schema.StreamReader[*entity.Message]
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewStreamContainer(sw *schema.StreamWriter[*entity.Message]) *StreamContainer {
|
||||
return &StreamContainer{
|
||||
sw: sw,
|
||||
subStreams: make(chan *schema.StreamReader[*entity.Message]),
|
||||
}
|
||||
}
|
||||
|
||||
func (sc *StreamContainer) AddChild(sr *schema.StreamReader[*entity.Message]) {
|
||||
sc.wg.Add(1)
|
||||
sc.subStreams <- sr
|
||||
}
|
||||
|
||||
func (sc *StreamContainer) PipeAll() {
|
||||
sc.wg.Add(1)
|
||||
|
||||
for sr := range sc.subStreams {
|
||||
go func() {
|
||||
defer sr.Close()
|
||||
|
||||
for {
|
||||
msg, err := sr.Recv()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
sc.wg.Done()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
sc.sw.Send(msg, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (sc *StreamContainer) Done() {
|
||||
sc.wg.Done()
|
||||
sc.wg.Wait()
|
||||
close(sc.subStreams)
|
||||
sc.sw.Close()
|
||||
}
|
||||
@@ -27,7 +27,7 @@ import (
|
||||
|
||||
type workflowToolOption struct {
|
||||
resumeReq *entity.ResumeRequest
|
||||
sw *schema.StreamWriter[*entity.Message]
|
||||
streamContainer *StreamContainer
|
||||
exeCfg workflowModel.ExecuteConfig
|
||||
allInterruptEvents map[string]*entity.ToolInterruptEvent
|
||||
parentTokenCollector *TokenCollector
|
||||
@@ -40,9 +40,9 @@ func WithResume(req *entity.ResumeRequest, all map[string]*entity.ToolInterruptE
|
||||
})
|
||||
}
|
||||
|
||||
func WithIntermediateStreamWriter(sw *schema.StreamWriter[*entity.Message]) tool.Option {
|
||||
func WithParentStreamContainer(sc *StreamContainer) tool.Option {
|
||||
return tool.WrapImplSpecificOptFn(func(opts *workflowToolOption) {
|
||||
opts.sw = sw
|
||||
opts.streamContainer = sc
|
||||
})
|
||||
}
|
||||
|
||||
@@ -57,9 +57,9 @@ func GetResumeRequest(opts ...tool.Option) (*entity.ResumeRequest, map[string]*e
|
||||
return opt.resumeReq, opt.allInterruptEvents
|
||||
}
|
||||
|
||||
func GetIntermediateStreamWriter(opts ...tool.Option) *schema.StreamWriter[*entity.Message] {
|
||||
func GetParentStreamContainer(opts ...tool.Option) *StreamContainer {
|
||||
opt := tool.GetImplSpecificOptions(&workflowToolOption{}, opts...)
|
||||
return opt.sw
|
||||
return opt.streamContainer
|
||||
}
|
||||
|
||||
func GetExecuteConfig(opts ...tool.Option) workflowModel.ExecuteConfig {
|
||||
@@ -67,11 +67,22 @@ func GetExecuteConfig(opts ...tool.Option) workflowModel.ExecuteConfig {
|
||||
return opt.exeCfg
|
||||
}
|
||||
|
||||
// WithMessagePipe returns an Option which is meant to be passed to the tool workflow, as well as a StreamReader to read the messages from the tool workflow.
|
||||
// This Option will apply to ALL workflow tools to be executed by eino's ToolsNode. The workflow tools will emit messages to this stream.
|
||||
// WithMessagePipe returns an Option which is meant to be passed to the tool workflow,
|
||||
// as well as a StreamReader to read the messages from the tool workflow.
|
||||
// This Option will apply to ALL workflow tools to be executed by eino's ToolsNode.
|
||||
// The workflow tools will emit messages to this stream.
|
||||
// The caller can receive from the returned StreamReader to get the messages from the tool workflow.
|
||||
func WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message]) {
|
||||
func WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message], func()) {
|
||||
sr, sw := schema.Pipe[*entity.Message](10)
|
||||
opt := compose.WithToolsNodeOption(compose.WithToolOption(WithIntermediateStreamWriter(sw)))
|
||||
return opt, sr
|
||||
container := &StreamContainer{
|
||||
sw: sw,
|
||||
subStreams: make(chan *schema.StreamReader[*entity.Message]),
|
||||
}
|
||||
|
||||
go container.PipeAll()
|
||||
|
||||
opt := compose.WithToolsNodeOption(compose.WithToolOption(WithParentStreamContainer(container)))
|
||||
return opt, sr, func() {
|
||||
container.Done()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user