refactor: how to add a node type in workflow (#558)
This commit is contained in:
@@ -18,7 +18,6 @@ package emitter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
@@ -26,28 +25,77 @@ import (
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
)
|
||||
|
||||
type OutputEmitter struct {
|
||||
cfg *Config
|
||||
Template string
|
||||
FullSources map[string]*schema2.SourceInfo
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Template string
|
||||
FullSources map[string]*nodes.SourceInfo
|
||||
Template string
|
||||
}
|
||||
|
||||
func New(_ context.Context, cfg *Config) (*OutputEmitter, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
|
||||
ns := &schema2.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeOutputEmitter,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
content := n.Data.Inputs.Content
|
||||
streamingOutput := n.Data.Inputs.StreamingOutput
|
||||
|
||||
if streamingOutput {
|
||||
ns.StreamConfigs = &schema2.StreamConfig{
|
||||
RequireStreamingInput: true,
|
||||
}
|
||||
} else {
|
||||
ns.StreamConfigs = &schema2.StreamConfig{
|
||||
RequireStreamingInput: false,
|
||||
}
|
||||
}
|
||||
|
||||
if content != nil {
|
||||
if content.Type != vo.VariableTypeString {
|
||||
return nil, fmt.Errorf("output emitter node's content type must be %s, got %s", vo.VariableTypeString, content.Type)
|
||||
}
|
||||
|
||||
if content.Value.Type != vo.BlockInputValueTypeLiteral {
|
||||
return nil, fmt.Errorf("output emitter node's content value type must be %s, got %s", vo.BlockInputValueTypeLiteral, content.Value.Type)
|
||||
}
|
||||
|
||||
if content.Value.Content == nil {
|
||||
c.Template = ""
|
||||
} else {
|
||||
template, ok := content.Value.Content.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("output emitter node's content value must be string, got %v", content.Value.Content)
|
||||
}
|
||||
|
||||
c.Template = template
|
||||
}
|
||||
}
|
||||
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (c *Config) Build(_ context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
|
||||
return &OutputEmitter{
|
||||
cfg: cfg,
|
||||
Template: c.Template,
|
||||
FullSources: ns.FullSources,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -59,10 +107,10 @@ type cachedVal struct {
|
||||
|
||||
type cacheStore struct {
|
||||
store map[string]*cachedVal
|
||||
infos map[string]*nodes.SourceInfo
|
||||
infos map[string]*schema2.SourceInfo
|
||||
}
|
||||
|
||||
func newCacheStore(infos map[string]*nodes.SourceInfo) *cacheStore {
|
||||
func newCacheStore(infos map[string]*schema2.SourceInfo) *cacheStore {
|
||||
return &cacheStore{
|
||||
store: make(map[string]*cachedVal),
|
||||
infos: infos,
|
||||
@@ -76,7 +124,7 @@ func (c *cacheStore) put(k string, v any) (any, error) {
|
||||
}
|
||||
|
||||
if !sInfo.IsIntermediate { // this is not an intermediate object container
|
||||
isStream := sInfo.FieldType == nodes.FieldIsStream
|
||||
isStream := sInfo.FieldType == schema2.FieldIsStream
|
||||
if !isStream {
|
||||
_, ok := c.store[k]
|
||||
if !ok {
|
||||
@@ -159,7 +207,7 @@ func (c *cacheStore) put(k string, v any) (any, error) {
|
||||
func (c *cacheStore) finished(k string) bool {
|
||||
cached, ok := c.store[k]
|
||||
if !ok {
|
||||
return c.infos[k].FieldType == nodes.FieldSkipped
|
||||
return c.infos[k].FieldType == schema2.FieldSkipped
|
||||
}
|
||||
|
||||
if cached.finished {
|
||||
@@ -182,7 +230,7 @@ func (c *cacheStore) finished(k string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *cacheStore) find(part nodes.TemplatePart) (root any, subCache *cachedVal, sourceInfo *nodes.SourceInfo,
|
||||
func (c *cacheStore) find(part nodes.TemplatePart) (root any, subCache *cachedVal, sourceInfo *schema2.SourceInfo,
|
||||
actualPath []string,
|
||||
) {
|
||||
rootCached, ok := c.store[part.Root]
|
||||
@@ -230,7 +278,7 @@ func (c *cacheStore) readyForPart(part nodes.TemplatePart, sw *schema.StreamWrit
|
||||
hasErr bool, partFinished bool) {
|
||||
cachedRoot, subCache, sourceInfo, _ := c.find(part)
|
||||
if cachedRoot != nil && subCache != nil {
|
||||
if subCache.finished || sourceInfo.FieldType == nodes.FieldIsStream {
|
||||
if subCache.finished || sourceInfo.FieldType == schema2.FieldIsStream {
|
||||
hasErr = renderAndSend(part, part.Root, cachedRoot, sw)
|
||||
if hasErr {
|
||||
return true, false
|
||||
@@ -315,14 +363,14 @@ func merge(a, b any) any {
|
||||
|
||||
const outputKey = "output"
|
||||
|
||||
func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[map[string]any]) (out *schema.StreamReader[map[string]any], err error) {
|
||||
resolvedSources, err := nodes.ResolveStreamSources(ctx, e.cfg.FullSources)
|
||||
func (e *OutputEmitter) Transform(ctx context.Context, in *schema.StreamReader[map[string]any]) (out *schema.StreamReader[map[string]any], err error) {
|
||||
resolvedSources, err := nodes.ResolveStreamSources(ctx, e.FullSources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sr, sw := schema.Pipe[map[string]any](0)
|
||||
parts := nodes.ParseTemplate(e.cfg.Template)
|
||||
parts := nodes.ParseTemplate(e.Template)
|
||||
safego.Go(ctx, func() {
|
||||
hasErr := false
|
||||
defer func() {
|
||||
@@ -454,7 +502,7 @@ func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[
|
||||
shouldChangePart = true
|
||||
}
|
||||
} else {
|
||||
if sourceInfo.FieldType == nodes.FieldIsStream {
|
||||
if sourceInfo.FieldType == schema2.FieldIsStream {
|
||||
currentV := v
|
||||
for i := 0; i < len(actualPath)-1; i++ {
|
||||
currentM, ok := currentV.(map[string]any)
|
||||
@@ -518,8 +566,8 @@ func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[
|
||||
return sr, nil
|
||||
}
|
||||
|
||||
func (e *OutputEmitter) Emit(ctx context.Context, in map[string]any) (output map[string]any, err error) {
|
||||
s, err := nodes.Render(ctx, e.cfg.Template, in, e.cfg.FullSources)
|
||||
func (e *OutputEmitter) Invoke(ctx context.Context, in map[string]any) (output map[string]any, err error) {
|
||||
s, err := nodes.Render(ctx, e.Template, in, e.FullSources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user