fix: workflow tool closes stream writer correctly (#1839)

This commit is contained in:
shentongmartin
2025-08-27 16:29:42 +08:00
committed by GitHub
parent 263a75b1c0
commit 5562800958
19 changed files with 742 additions and 620 deletions

View File

@@ -143,6 +143,7 @@ const (
knowledgeUserPromptTemplateKey = "knowledge_user_prompt_prefix"
templateNodeKey = "template"
llmNodeKey = "llm"
reactGraphName = "workflow_llm_react_agent"
outputConvertNodeKey = "output_convert"
)
@@ -620,6 +621,7 @@ func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2
ToolCallingModel: m,
ToolsConfig: compose.ToolsNodeConfig{Tools: tools},
ModelNodeName: agentModelName,
GraphName: reactGraphName,
}
if len(toolsReturnDirectly) > 0 {
@@ -635,7 +637,7 @@ func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2
}
agentNode, opts := reactAgent.ExportGraph()
opts = append(opts, compose.WithNodeName("workflow_llm_react_agent"))
opts = append(opts, compose.WithNodeName(reactGraphName))
_ = g.AddGraphNode(llmNodeKey, agentNode, opts...)
} else {
_ = g.AddChatModelNode(llmNodeKey, modelWithInfo)
@@ -867,12 +869,12 @@ func jsonParse(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo
}
type llmOptions struct {
toolWorkflowSW *schema.StreamWriter[*entity.Message]
toolWorkflowContainer *execute.StreamContainer
}
func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) nodes.NodeOption {
func WithToolWorkflowStreamContainer(container *execute.StreamContainer) nodes.NodeOption {
return nodes.WrapImplSpecificOptFn(func(o *llmOptions) {
o.toolWorkflowSW = sw
o.toolWorkflowContainer = container
})
}
@@ -880,7 +882,8 @@ type llmState = map[string]any
const agentModelName = "agent_model"
func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeOption) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) {
func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeOption) (
composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) {
c := execute.GetExeCtx(ctx)
if c != nil {
resumingEvent = c.NodeCtx.ResumingEvent
@@ -890,7 +893,7 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeO
if c != nil && c.RootCtx.ResumeEvent != nil {
// check if we are not resuming, but previously interrupted. Interrupt immediately.
if resumingEvent == nil {
err := compose.ProcessState(ctx, func(ctx context.Context, state ToolInterruptEventStore) error {
err = compose.ProcessState(ctx, func(ctx context.Context, state ToolInterruptEventStore) error {
var e error
previousToolES, e = state.GetToolInterruptEvents(c.NodeKey)
if e != nil {
@@ -899,11 +902,12 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeO
return nil
})
if err != nil {
return nil, nil, err
return
}
if len(previousToolES) > 0 {
return nil, nil, compose.InterruptAndRerun
err = compose.InterruptAndRerun
return
}
}
}
@@ -936,7 +940,7 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeO
return e
})
if err != nil {
return nil, nil, err
return
}
composeOpts = append(composeOpts, compose.WithToolsNodeOption(
compose.WithToolOption(
@@ -986,27 +990,9 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeO
}
llmOpts := nodes.GetImplSpecificOptions(&llmOptions{}, opts...)
if llmOpts.toolWorkflowSW != nil {
toolMsgOpt, toolMsgSR := execute.WithMessagePipe()
composeOpts = append(composeOpts, toolMsgOpt)
safego.Go(ctx, func() {
defer toolMsgSR.Close()
for {
msg, err := toolMsgSR.Recv()
if err != nil {
if err == io.EOF {
return
}
logs.CtxErrorf(ctx, "failed to receive message from tool workflow: %v", err)
return
}
logs.Infof("received message from tool workflow: %+v", msg)
llmOpts.toolWorkflowSW.Send(msg, nil)
}
})
if container := llmOpts.toolWorkflowContainer; container != nil {
composeOpts = append(composeOpts, compose.WithToolsNodeOption(compose.WithToolOption(
execute.WithParentStreamContainer(container))))
}
resolvedSources, err := nodes.ResolveStreamSources(ctx, l.fullSources)