coze-studio/backend/domain/workflow/internal/nodes/template.go

433 lines
11 KiB
Go

/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nodes
import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"github.com/bytedance/sonic"
"github.com/bytedance/sonic/ast"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
// TemplatePart is one piece of a parsed template: either a run of literal
// text or a single {{...}} variable placeholder.
type TemplatePart struct {
	// IsVariable is true for a placeholder and false for literal text.
	IsVariable bool
	// Value is the literal text, or for a placeholder the whitespace-trimmed
	// expression found between the braces.
	Value string
	// Root is the first dot-separated segment of a placeholder with any
	// [n] accessors removed.
	Root string
	// SubPathsBeforeSlice holds the segments that follow Root, up to but not
	// including the first segment that contains an index accessor. It is
	// empty when the root segment itself is indexed.
	SubPathsBeforeSlice []string
	// JsonPath is the full lookup path: string elements are object keys and
	// int elements are array indexes.
	JsonPath []any
	// literal is the placeholder exactly as written in the template; it is
	// what gets emitted when the placeholder cannot be resolved.
	literal string
}

// re matches a {{...}} placeholder and captures the expression inside it.
// The capture group is greedy, so it may carry trailing whitespace that the
// parser has to trim.
var re = regexp.MustCompile(`{{\s*([^}]+)\s*}}`)

// ParseTemplate splits template into an ordered list of literal and variable
// parts. Concatenating the literal values and the placeholders' original
// text reproduces the input.
//
// A placeholder expression is a dot-separated path whose segments may carry
// integer index accessors, e.g. "a.b[1][2].c". Surrounding whitespace inside
// the braces is ignored. A placeholder whose accessors are malformed
// (unclosed bracket, non-integer index, trailing characters after "]") is not
// a variable: its raw text, braces included, is kept as part of the
// neighbouring literal.
//
// The returned slice is never nil.
func ParseTemplate(template string) []TemplatePart {
	matches := re.FindAllStringSubmatchIndex(template, -1)
	parts := make([]TemplatePart, 0, 2*len(matches)+1)
	lastEnd := 0

	for _, match := range matches {
		start, end := match[0], match[1]

		// The regexp's capture keeps trailing whitespace; trim it so that
		// "{{ a.b }}" addresses key "b" rather than "b ".
		val := strings.TrimSpace(template[match[2]:match[3]])
		segments := strings.Split(val, ".")

		jsonPath, ok := parsePlaceholderPath(segments)
		if !ok {
			// Malformed: deliberately leave lastEnd untouched so the raw
			// placeholder text is emitted exactly once, as part of the next
			// literal part.
			continue
		}

		// Literal text between the previous variable and this one.
		if start > lastEnd {
			parts = append(parts, TemplatePart{Value: template[lastEnd:start]})
		}

		var subPaths []string
		if !strings.Contains(segments[0], "[") {
			for _, segment := range segments[1:] {
				if strings.Contains(segment, "[") {
					break
				}
				subPaths = append(subPaths, segment)
			}
		}

		parts = append(parts, TemplatePart{
			IsVariable:          true,
			Value:               val,
			Root:                removeSlice(segments[0]),
			SubPathsBeforeSlice: subPaths,
			JsonPath:            jsonPath,
			literal:             template[start:end],
		})
		lastEnd = end
	}

	// Whatever follows the last variable is literal text.
	if lastEnd < len(template) {
		parts = append(parts, TemplatePart{Value: template[lastEnd:]})
	}

	return parts
}

// parsePlaceholderPath converts the dot-separated segments of a placeholder
// into a lookup path of object keys (string) and array indexes (int).
// It reports ok=false when any segment has a malformed index accessor.
func parsePlaceholderPath(segments []string) (path []any, ok bool) {
	for _, segment := range segments {
		firstBracket := strings.Index(segment, "[")
		if firstBracket == -1 {
			// No accessor: the whole segment is a key.
			path = append(path, segment)
			continue
		}

		if key := segment[:firstBracket]; key != "" {
			path = append(path, key)
		}

		// Consume consecutive accessors such as [1][2].
		rest := segment[firstBracket:]
		for strings.HasPrefix(rest, "[") {
			closeBracket := strings.Index(rest, "]")
			if closeBracket == -1 {
				return nil, false
			}
			idx, err := strconv.Atoi(rest[1:closeBracket])
			if err != nil {
				return nil, false
			}
			path = append(path, idx)
			rest = rest[closeBracket+1:]
		}
		if rest != "" {
			// Trailing characters after the last "]".
			return nil, false
		}
	}
	return path, true
}

// removeSlice returns s truncated at its first "[", or s unchanged when it
// contains none.
func removeSlice(s string) string {
	name, _, _ := strings.Cut(s, "[")
	return name
}
// renderOptions gathers the optional behaviours that callers can switch on
// when rendering a template.
type renderOptions struct {
	// type2CustomRenderer maps a value's dynamic type to the function used
	// to turn that value into text.
	type2CustomRenderer map[reflect.Type]func(any) (string, error)
	// reservedKey lists root keys that are always rendered, without being
	// checked for node skipping.
	reservedKey map[string]struct{}
	// nilRenderer, when set, produces the text for a value that is nil.
	nilRenderer func() (string, error)
}

// RenderOption adjusts a renderOptions value in place.
type RenderOption func(options *renderOptions)

// WithNilRender installs fn as the renderer used for nil values.
func WithNilRender(fn func() (string, error)) RenderOption {
	return func(o *renderOptions) {
		o.nilRenderer = fn
	}
}

// WithCustomRender registers fn as the renderer for values whose dynamic
// type is rType, replacing any renderer previously registered for that type.
func WithCustomRender(rType reflect.Type, fn func(any) (string, error)) RenderOption {
	return func(o *renderOptions) {
		renderers := o.type2CustomRenderer
		if renderers == nil {
			renderers = make(map[reflect.Type]func(any) (string, error))
			o.type2CustomRenderer = renderers
		}
		renderers[rType] = fn
	}
}

// WithReservedKey marks every given key as reserved. Calls accumulate: keys
// from earlier options are kept.
func WithReservedKey(keys ...string) RenderOption {
	return func(o *renderOptions) {
		reserved := o.reservedKey
		if reserved == nil {
			reserved = make(map[string]struct{})
			o.reservedKey = reserved
		}
		for i := range keys {
			reserved[keys[i]] = struct{}{}
		}
	}
}
// renderConfig is the frozen sonic configuration used by TemplatePart.Render
// to marshal values that are not plain strings, numbers or booleans.
// SortMapKeys is enabled so that map keys are written in sorted order.
var renderConfig = sonic.Config{
SortMapKeys: true,
}.Froze()
// joinJsonPath formats a lookup path back into its textual form, e.g.
// ["a", 0, "b"] becomes "a[0].b". String elements are object keys and are
// separated from whatever precedes them by a dot; all other elements are
// rendered as bracketed indexes.
func joinJsonPath(p []any) string {
	var sb strings.Builder
	for i, segment := range p {
		switch v := segment.(type) {
		case string:
			// A key needs a separator after any earlier element, including
			// an index: "a[0].b", not "a[0]b".
			if i > 0 {
				sb.WriteByte('.')
			}
			sb.WriteString(v)
		case int:
			sb.WriteByte('[')
			sb.WriteString(strconv.Itoa(v))
			sb.WriteByte(']')
		default:
			fmt.Fprintf(&sb, "[%d]", v)
		}
	}
	return sb.String()
}
// Render looks up tp.JsonPath in the JSON document m and returns the value
// as text.
//
// Resolved values are rendered as follows:
//   - nil: the result of the WithNilRender function if one was given,
//     otherwise the empty string;
//   - a value whose dynamic type has a WithCustomRender function: that
//     function's result;
//   - string, number and bool: their plain textual form;
//   - anything else: marshalled with renderConfig.
//
// When the path cannot be resolved, the placeholder's original text is
// returned with a nil error, except in two cases that are reported as errors:
// an integer index that does not exist in its parent array
// (errno.ErrArrIndexOutOfRange), and an integer index for which the lookup
// fails with a syntax error (errno.ErrIndexingNilArray).
func (tp TemplatePart) Render(m []byte, opts ...RenderOption) (string, error) {
	options := &renderOptions{
		type2CustomRenderer: make(map[reflect.Type]func(any) (string, error)),
	}
	for _, opt := range opts {
		opt(options)
	}

	n, err := sonic.Get(m, tp.JsonPath...)
	if err != nil {
		var syntaxErr ast.SyntaxError
		if !errors.Is(err, ast.ErrNotExist) && !errors.As(err, &syntaxErr) {
			return tp.literal, nil
		}

		// Walk the path one segment at a time to find the first segment that
		// fails, remembering the last node that did resolve.
		var parent ast.Node
		for i := range tp.JsonPath {
			current, err := sonic.Get(m, tp.JsonPath[:i+1]...)
			if err == nil {
				parent = current
				continue
			}

			// A missing object field is not an error: just echo the placeholder.
			index, isIndex := tp.JsonPath[i].(int)
			if !isIndex {
				return tp.literal, nil
			}

			switch {
			case errors.Is(err, ast.ErrNotExist):
				if !parent.Exists() {
					// The very first segment is an index, so there is no
					// resolved parent to inspect. Treat it like any other
					// unresolvable placeholder instead of panicking.
					return tp.literal, nil
				}
				elems, err := parent.Array()
				if err != nil { // parent is not an array
					return tp.literal, nil
				}
				return "", vo.NewError(errno.ErrArrIndexOutOfRange,
					errorx.KV("arr_name", joinJsonPath(tp.JsonPath[:i])),
					errorx.KV("req_index", strconv.Itoa(index)),
					errorx.KV("arr_len", strconv.Itoa(len(elems))))
			case errors.As(err, &syntaxErr):
				return "", vo.NewError(errno.ErrIndexingNilArray,
					errorx.KV("arr_name", joinJsonPath(tp.JsonPath[:i])),
					errorx.KV("req_index", strconv.Itoa(index)))
			}

			return tp.literal, nil // any other failure: just print
		}

		return tp.literal, nil
	}

	value, err := n.InterfaceUseNumber()
	if err != nil {
		return tp.literal, nil
	}

	if value == nil {
		if options.nilRenderer != nil {
			return options.nilRenderer()
		}
		return "", nil
	}

	if len(options.type2CustomRenderer) > 0 {
		if fn, ok := options.type2CustomRenderer[reflect.TypeOf(value)]; ok {
			return fn(value)
		}
	}

	switch v := value.(type) {
	case string:
		return v, nil
	case json.Number:
		return v.String(), nil
	case bool:
		return strconv.FormatBool(v), nil
	default:
		ms, err := renderConfig.MarshalToString(value) // keep order of the map keys
		if err != nil {
			return "", err
		}
		return ms, nil
	}
}
// Skipped reports whether the source this placeholder refers to was skipped,
// and whether the placeholder refers to a source that does not exist.
//
// With no source information at all, both results are false. Otherwise the
// placeholder's root and sub paths are followed through resolvedSources as
// far as they match:
//   - if the walk ends on a non-intermediate source, skipped reflects that
//     source's FieldType;
//   - if it ends on an intermediate source, skipped is true only when every
//     non-intermediate source beneath it is skipped;
//   - if the root is unknown, or a sub path is missing under an intermediate
//     source, invalid is true.
func (tp TemplatePart) Skipped(resolvedSources map[string]*schema.SourceInfo) (skipped bool, invalid bool) {
	if len(resolvedSources) == 0 { // no information available, maybe outside the scope of a workflow
		return false, false
	}

	source, found := resolvedSources[tp.Root]
	if !found { // the root names a source that does not exist
		return false, true
	}
	if !source.IsIntermediate {
		return source.FieldType == schema.FieldSkipped, false
	}

	for _, key := range tp.SubPathsBeforeSlice {
		child, exists := source.SubSources[key]
		switch {
		case exists:
			source = child
		case source.IsIntermediate: // the path names a source that does not exist
			return false, true
		default: // the path goes deeper than the field source
			return source.FieldType == schema.FieldSkipped, false
		}
	}

	if !source.IsIntermediate {
		return source.FieldType == schema.FieldSkipped, false
	}

	// An intermediate source counts as skipped only when every
	// non-intermediate source below it is skipped; visit them all with an
	// explicit stack and stop at the first one that is not.
	pending := []*schema.SourceInfo{source}
	for len(pending) > 0 {
		top := len(pending) - 1
		current := pending[top]
		pending = pending[:top]

		if !current.IsIntermediate {
			if current.FieldType != schema.FieldSkipped {
				return false, false
			}
			continue
		}
		for _, sub := range current.SubSources {
			pending = append(pending, sub)
		}
	}

	return true, false
}
// TypeInfo returns the type information for the field this placeholder
// addresses: it starts at types[tp.Root] and descends through Properties
// once per entry in tp.SubPathsBeforeSlice. It returns nil when the root or
// any property along the way is not present.
func (tp TemplatePart) TypeInfo(types map[string]*vo.TypeInfo) *vo.TypeInfo {
	current, found := types[tp.Root]
	if !found {
		return nil
	}

	for _, name := range tp.SubPathsBeforeSlice {
		// A lookup in an empty or nil Properties map simply reports "absent".
		next, exists := current.Properties[name]
		if !exists {
			return nil
		}
		current = next
	}

	return current
}
// Render expands every {{...}} placeholder in tpl using the values in input
// and returns the resulting text.
//
// input is marshalled to JSON once and each placeholder is resolved against
// that document. sources is passed through ResolveStreamSources and then used
// to decide, per placeholder, whether it is skipped (nothing is written) or
// refers to a non-existing source (its original text is written). Roots
// registered through WithReservedKey bypass that check and are always
// rendered. The first rendering error aborts the call.
func Render(ctx context.Context, tpl string, input map[string]any, sources map[string]*schema.SourceInfo, opts ...RenderOption) (string, error) {
	mi, err := sonic.Marshal(input)
	if err != nil {
		return "", err
	}

	resolvedSources, err := ResolveStreamSources(ctx, sources)
	if err != nil {
		return "", err
	}

	var cfg renderOptions
	for _, apply := range opts {
		apply(&cfg)
	}

	var out strings.Builder
	for _, part := range ParseTemplate(tpl) {
		if !part.IsVariable {
			out.WriteString(part.Value)
			continue
		}

		// Reserved roots always render; everything else is first checked
		// against the resolved sources.
		if _, reserved := cfg.reservedKey[part.Root]; !reserved {
			skipped, invalid := part.Skipped(resolvedSources)
			if skipped {
				continue
			}
			if invalid {
				out.WriteString(part.literal)
				continue
			}
		}

		rendered, err := part.Render(mi, opts...)
		if err != nil {
			return "", err
		}
		out.WriteString(rendered)
	}

	return out.String(), nil
}