556 lines
14 KiB
Go
556 lines
14 KiB
Go
/*
 * Copyright 2025 coze-dev Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
package emitter
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
|
|
"github.com/bytedance/sonic"
|
|
"github.com/cloudwego/eino/schema"
|
|
|
|
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
|
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
|
)
|
|
|
|
type OutputEmitter struct {
|
|
cfg *Config
|
|
}
|
|
|
|
type Config struct {
|
|
Template string
|
|
FullSources map[string]*nodes.SourceInfo
|
|
}
|
|
|
|
func New(_ context.Context, cfg *Config) (*OutputEmitter, error) {
|
|
if cfg == nil {
|
|
return nil, errors.New("config is required")
|
|
}
|
|
|
|
return &OutputEmitter{
|
|
cfg: cfg,
|
|
}, nil
|
|
}
|
|
|
|
type cachedVal struct {
|
|
val any
|
|
finished bool
|
|
subCaches *cacheStore
|
|
}
|
|
|
|
type cacheStore struct {
|
|
store map[string]*cachedVal
|
|
infos map[string]*nodes.SourceInfo
|
|
}
|
|
|
|
func newCacheStore(infos map[string]*nodes.SourceInfo) *cacheStore {
|
|
return &cacheStore{
|
|
store: make(map[string]*cachedVal),
|
|
infos: infos,
|
|
}
|
|
}
|
|
|
|
func (c *cacheStore) put(k string, v any) (any, error) {
|
|
sInfo, ok := c.infos[k]
|
|
if !ok {
|
|
return nil, fmt.Errorf("no such key found from SourceInfos: %s", k)
|
|
}
|
|
|
|
if !sInfo.IsIntermediate { // this is not an intermediate object container
|
|
isStream := sInfo.FieldType == nodes.FieldIsStream
|
|
if !isStream {
|
|
_, ok := c.store[k]
|
|
if !ok {
|
|
out := &cachedVal{
|
|
val: v,
|
|
finished: true,
|
|
}
|
|
c.store[k] = out
|
|
return v, nil
|
|
} else {
|
|
return nil, fmt.Errorf("source %s not intermediate, not stream, appears multiple times", k)
|
|
}
|
|
} else { // this is an actual stream
|
|
vStr, ok := v.(string) // stream value should be string
|
|
if !ok {
|
|
return nil, fmt.Errorf("source %s is not intermediate, is stream, but value type not str: %T", k, v)
|
|
}
|
|
|
|
isFinished := strings.HasSuffix(vStr, nodes.KeyIsFinished)
|
|
if isFinished {
|
|
vStr = strings.TrimSuffix(vStr, nodes.KeyIsFinished)
|
|
}
|
|
|
|
var existingStr string
|
|
existing, ok := c.store[k]
|
|
if ok {
|
|
existingStr = existing.val.(string)
|
|
}
|
|
|
|
existingStr = existingStr + vStr
|
|
out := &cachedVal{
|
|
val: existingStr,
|
|
finished: isFinished,
|
|
}
|
|
c.store[k] = out
|
|
|
|
return vStr, nil
|
|
}
|
|
}
|
|
|
|
if len(sInfo.SubSources) == 0 {
|
|
// k is intermediate container, needs to check its sub sources
|
|
return nil, fmt.Errorf("source %s is intermediate, but does not have sub sources", k)
|
|
}
|
|
|
|
vMap, ok := v.(map[string]interface{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("source %s is intermediate, but value type not map: %T", k, v)
|
|
}
|
|
|
|
currentCache, existed := c.store[k]
|
|
if !existed {
|
|
currentCache = &cachedVal{
|
|
val: v,
|
|
subCaches: newCacheStore(sInfo.SubSources),
|
|
}
|
|
c.store[k] = currentCache
|
|
} else {
|
|
// already cached k before, merge cached value with new value
|
|
currentCache.val = merge(currentCache.val, v)
|
|
}
|
|
|
|
subCacheStore := currentCache.subCaches
|
|
for subK := range subCacheStore.infos {
|
|
subV, ok := vMap[subK]
|
|
if !ok { // subK not present in this chunk
|
|
continue
|
|
}
|
|
_, err := subCacheStore.put(subK, subV)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
_ = c.finished(k)
|
|
|
|
return vMap, nil
|
|
}
|
|
|
|
func (c *cacheStore) finished(k string) bool {
|
|
cached, ok := c.store[k]
|
|
if !ok {
|
|
return c.infos[k].FieldType == nodes.FieldSkipped
|
|
}
|
|
|
|
if cached.finished {
|
|
return true
|
|
}
|
|
|
|
sInfo := c.infos[k]
|
|
if !sInfo.IsIntermediate {
|
|
return cached.finished
|
|
}
|
|
|
|
for subK := range sInfo.SubSources {
|
|
subFinished := cached.subCaches.finished(subK)
|
|
if !subFinished {
|
|
return false
|
|
}
|
|
}
|
|
|
|
cached.finished = true
|
|
return true
|
|
}
|
|
|
|
func (c *cacheStore) find(part nodes.TemplatePart) (root any, subCache *cachedVal, sourceInfo *nodes.SourceInfo,
|
|
actualPath []string,
|
|
) {
|
|
rootCached, ok := c.store[part.Root]
|
|
if !ok {
|
|
return nil, nil, nil, nil
|
|
}
|
|
|
|
// now try to find the nearest match within the cached tree
|
|
subPaths := part.SubPathsBeforeSlice
|
|
currentCache := rootCached
|
|
currentSource := c.infos[part.Root]
|
|
for i := range subPaths {
|
|
if !currentSource.IsIntermediate {
|
|
// currentSource is already the leaf, no need to look further
|
|
break
|
|
}
|
|
subPath := subPaths[i]
|
|
subInfo, ok := currentSource.SubSources[subPath]
|
|
if !ok {
|
|
// this sub path is not in the source info tree
|
|
// it's just a user defined variable field in the template
|
|
break
|
|
}
|
|
|
|
actualPath = append(actualPath, subPath)
|
|
|
|
subCache, ok = currentCache.subCaches.store[subPath]
|
|
if !ok {
|
|
// subPath corresponds to a real Field Source,
|
|
// if it's not cached, then it hasn't appeared in the stream yet
|
|
return rootCached.val, nil, subInfo, actualPath
|
|
}
|
|
if !subCache.finished {
|
|
return rootCached.val, subCache, subInfo, actualPath
|
|
}
|
|
|
|
currentCache = subCache
|
|
currentSource = subInfo
|
|
}
|
|
|
|
return rootCached.val, currentCache, currentSource, actualPath
|
|
}
|
|
|
|
func (c *cacheStore) readyForPart(part nodes.TemplatePart, sw *schema.StreamWriter[map[string]any]) (
|
|
hasErr bool, partFinished bool) {
|
|
cachedRoot, subCache, sourceInfo, _ := c.find(part)
|
|
if cachedRoot != nil && subCache != nil {
|
|
if subCache.finished || sourceInfo.FieldType == nodes.FieldIsStream {
|
|
hasErr = renderAndSend(part, part.Root, cachedRoot, sw)
|
|
if hasErr {
|
|
return true, false
|
|
}
|
|
if subCache.finished { // move on to next part in template
|
|
return false, true
|
|
}
|
|
}
|
|
}
|
|
|
|
return false, false
|
|
}
|
|
|
|
func (c *cacheStore) fillZero(nodeKey vo.NodeKey) map[string]any {
|
|
filled := make(map[string]any)
|
|
for field, sInfo := range c.infos {
|
|
if !sInfo.FromNode(nodeKey) {
|
|
continue
|
|
}
|
|
|
|
cacheV, ok := c.store[field]
|
|
if !sInfo.IsIntermediate {
|
|
if !ok {
|
|
c.store[field] = &cachedVal{
|
|
val: sInfo.TypeInfo.Zero(),
|
|
finished: true,
|
|
subCaches: nil,
|
|
}
|
|
|
|
filled[field] = true
|
|
}
|
|
|
|
continue
|
|
}
|
|
|
|
if !ok {
|
|
cacheV = &cachedVal{
|
|
val: make(map[string]any),
|
|
subCaches: newCacheStore(sInfo.SubSources),
|
|
}
|
|
c.store[field] = cacheV
|
|
}
|
|
|
|
subFilled := cacheV.subCaches.fillZero(nodeKey)
|
|
if len(subFilled) > 0 {
|
|
filled[field] = subFilled
|
|
}
|
|
}
|
|
|
|
return filled
|
|
}
|
|
|
|
func merge(a, b any) any {
|
|
aStr, ok1 := a.(string)
|
|
bStr, ok2 := b.(string)
|
|
if ok1 && ok2 {
|
|
if strings.HasSuffix(bStr, nodes.KeyIsFinished) {
|
|
bStr = strings.TrimSuffix(bStr, nodes.KeyIsFinished)
|
|
}
|
|
return aStr + bStr
|
|
}
|
|
|
|
aMap, ok1 := a.(map[string]interface{})
|
|
bMap, ok2 := b.(map[string]interface{})
|
|
if ok1 && ok2 {
|
|
merged := make(map[string]any)
|
|
for k, v := range aMap {
|
|
merged[k] = v
|
|
}
|
|
for k, v := range bMap {
|
|
if _, ok := merged[k]; !ok { // only bMap has this field, just set it
|
|
merged[k] = v
|
|
continue
|
|
}
|
|
merged[k] = merge(merged[k], v)
|
|
}
|
|
return merged
|
|
}
|
|
|
|
panic(fmt.Errorf("can only merge two maps or two strings, a type: %T, b type: %T", a, b))
|
|
}
|
|
|
|
// outputKey is the map key under which every emitted chunk, as well as the
// result of Emit, carries its rendered text.
const outputKey = "output"
|
|
|
|
func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[map[string]any]) (out *schema.StreamReader[map[string]any], err error) {
|
|
resolvedSources, err := nodes.ResolveStreamSources(ctx, e.cfg.FullSources)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sr, sw := schema.Pipe[map[string]any](0)
|
|
parts := nodes.ParseTemplate(e.cfg.Template)
|
|
safego.Go(ctx, func() {
|
|
hasErr := false
|
|
defer func() {
|
|
if !hasErr {
|
|
sw.Send(map[string]any{outputKey: nodes.KeyIsFinished}, nil)
|
|
}
|
|
sw.Close()
|
|
in.Close()
|
|
}()
|
|
|
|
caches := newCacheStore(resolvedSources)
|
|
|
|
partsLoop:
|
|
for _, part := range parts {
|
|
select {
|
|
case <-ctx.Done(): // canceled by Eino workflow engine
|
|
sw.Send(nil, ctx.Err())
|
|
hasErr = true
|
|
return
|
|
default:
|
|
}
|
|
|
|
if !part.IsVariable { // literal string within template, just emit it
|
|
sw.Send(map[string]any{outputKey: part.Value}, nil)
|
|
continue
|
|
}
|
|
|
|
// now this 'part' is a variable, first check if the source(s) for it are skipped (the nodes are not selected)
|
|
// if skipped, just move on to the next 'part'
|
|
skipped, invalid := part.Skipped(resolvedSources)
|
|
if skipped {
|
|
continue
|
|
}
|
|
if invalid {
|
|
sw.Send(map[string]any{outputKey: "{{" + part.Value + "}}"}, nil)
|
|
continue
|
|
}
|
|
|
|
// now this 'part' definitely should have a match, look for a hit within cache store
|
|
// if found in cache store, emit the root only if the match is finished or the match is stream
|
|
// the rule for a hit: the nearest match within the source tree
|
|
// if hit, and the cachedVal is also finished, continue to next template part
|
|
var partFinished bool
|
|
hasErr, partFinished = caches.readyForPart(part, sw)
|
|
if hasErr {
|
|
return
|
|
}
|
|
if partFinished {
|
|
continue
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done(): // canceled by Eino workflow engine or timeout
|
|
sw.Send(nil, ctx.Err())
|
|
hasErr = true
|
|
return
|
|
default:
|
|
}
|
|
|
|
shouldChangePart := false
|
|
|
|
chunk, err := in.Recv()
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
// current part is not fulfilled, emit the literal part content and move on to next part
|
|
sw.Send(map[string]any{outputKey: part.Value}, nil)
|
|
break
|
|
}
|
|
|
|
if sn, ok := schema.GetSourceName(err); ok {
|
|
// received end signal for a particular predecessor nodeID, do the following:
|
|
// - obtain the field sources mapped from this predecessor node
|
|
// - check which fields are still missing in the cache store
|
|
// - fill zero value for these missing fields
|
|
// - check if the current template part should be rendered and sent immediately
|
|
// - check if we should move on to next part in template
|
|
filled := caches.fillZero(vo.NodeKey(sn))
|
|
if _, okk := filled[part.Root]; okk {
|
|
// current part is influenced by the 'fill zero' operation
|
|
hasErr, shouldChangePart = caches.readyForPart(part, sw)
|
|
if hasErr {
|
|
return
|
|
}
|
|
if shouldChangePart {
|
|
continue partsLoop
|
|
}
|
|
}
|
|
|
|
continue
|
|
}
|
|
|
|
hasErr = true
|
|
sw.Send(nil, err) // real error
|
|
return
|
|
}
|
|
|
|
chunkLoop:
|
|
for k := range chunk {
|
|
v := chunk[k]
|
|
v, err = caches.put(k, v) // always update the cache
|
|
if err != nil {
|
|
hasErr = true
|
|
sw.Send(nil, err)
|
|
return
|
|
}
|
|
|
|
// needs to check if this 'k' is the current part's root
|
|
// if it is, do the case analysis:
|
|
// - the source is a leaf (not intermediate):
|
|
// - the source is stream, emit the formatted stream content immediately
|
|
// - the source is not stream, emit the formatted one-time content immediately
|
|
// - the source is intermediate:
|
|
// - the source is not finished, do not emit it
|
|
// - the source is finished, emit the full content (cached + new) immediately
|
|
if k == part.Root {
|
|
cachedRoot, subCache, sourceInfo, actualPath := caches.find(part)
|
|
if sourceInfo == nil {
|
|
panic("impossible, k is part.root, but sourceInfo is nil")
|
|
}
|
|
|
|
if subCache != nil {
|
|
if sourceInfo.IsIntermediate {
|
|
if subCache.finished {
|
|
hasErr = renderAndSend(part, part.Root, cachedRoot, sw)
|
|
if hasErr {
|
|
return
|
|
}
|
|
shouldChangePart = true
|
|
}
|
|
} else {
|
|
if sourceInfo.FieldType == nodes.FieldIsStream {
|
|
currentV := v
|
|
for i := 0; i < len(actualPath)-1; i++ {
|
|
currentM, ok := currentV.(map[string]any)
|
|
if !ok {
|
|
panic("emit item not map[string]any")
|
|
}
|
|
currentV, ok = currentM[actualPath[i]]
|
|
if !ok {
|
|
continue chunkLoop
|
|
}
|
|
}
|
|
|
|
if len(actualPath) > 0 {
|
|
finalV, ok := currentV.(map[string]any)[actualPath[len(actualPath)-1]]
|
|
if !ok {
|
|
continue chunkLoop
|
|
}
|
|
currentV = finalV
|
|
}
|
|
vStr, ok := currentV.(string)
|
|
if !ok {
|
|
panic(fmt.Errorf("source %s is not intermediate, is stream, but value type not str: %T", k, v))
|
|
}
|
|
|
|
if strings.HasSuffix(vStr, nodes.KeyIsFinished) {
|
|
vStr = strings.TrimSuffix(vStr, nodes.KeyIsFinished)
|
|
}
|
|
|
|
var delta any
|
|
delta = vStr
|
|
for j := len(actualPath) - 1; j >= 0; j-- {
|
|
delta = map[string]any{
|
|
actualPath[j]: delta,
|
|
}
|
|
}
|
|
|
|
hasErr = renderAndSend(part, part.Root, delta, sw)
|
|
} else {
|
|
hasErr = renderAndSend(part, part.Root, v, sw)
|
|
}
|
|
|
|
if hasErr {
|
|
return
|
|
}
|
|
if subCache.finished {
|
|
shouldChangePart = true
|
|
}
|
|
}
|
|
}
|
|
|
|
}
|
|
}
|
|
|
|
if shouldChangePart {
|
|
continue partsLoop
|
|
}
|
|
}
|
|
}
|
|
})
|
|
|
|
return sr, nil
|
|
}
|
|
|
|
func (e *OutputEmitter) Emit(ctx context.Context, in map[string]any) (output map[string]any, err error) {
|
|
s, err := nodes.Render(ctx, e.cfg.Template, in, e.cfg.FullSources)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
output = map[string]any{
|
|
outputKey: s,
|
|
}
|
|
|
|
return output, nil
|
|
}
|
|
|
|
func renderAndSend(tp nodes.TemplatePart, k string, v any, sw *schema.StreamWriter[map[string]any]) bool /*hasError*/ {
|
|
m, err := sonic.Marshal(map[string]any{k: v})
|
|
if err != nil {
|
|
sw.Send(nil, err)
|
|
return true
|
|
}
|
|
|
|
r, err := tp.Render(m)
|
|
if err != nil {
|
|
sw.Send(nil, err)
|
|
return true
|
|
}
|
|
|
|
if len(r) == 0 { // won't send if formatted result is empty string
|
|
return false
|
|
}
|
|
|
|
logs.Infof("send: %v", r)
|
|
|
|
sw.Send(map[string]any{outputKey: r}, nil)
|
|
return false
|
|
}
|