coze-studio/backend/domain/workflow/internal/compose/field_fill.go

318 lines
8.0 KiB
Go

/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"fmt"
"maps"
"slices"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
// outputValueFiller will fill the output value with nil if the key is not present in the output map.
// if a node emits stream as output, the node needs to handle these absent keys in stream themselves.
func (s *NodeSchema) outputValueFiller() func(ctx context.Context, output map[string]any) (map[string]any, error) {
if len(s.OutputTypes) == 0 {
return func(ctx context.Context, output map[string]any) (map[string]any, error) {
return output, nil
}
}
return func(ctx context.Context, output map[string]any) (map[string]any, error) {
newOutput := make(map[string]any)
for k := range output {
newOutput[k] = output[k]
}
for k, tInfo := range s.OutputTypes {
if err := FillIfNotRequired(tInfo, newOutput, k, FillNil, false); err != nil {
return nil, err
}
}
return newOutput, nil
}
}
// inputValueFiller will fill the input value with default value(zero or nil) if the input value is not present in map.
// if a node accepts stream as input, the node needs to handle these absent keys in stream themselves.
func (s *NodeSchema) inputValueFiller() func(ctx context.Context, input map[string]any) (map[string]any, error) {
if len(s.InputTypes) == 0 {
return func(ctx context.Context, input map[string]any) (map[string]any, error) {
return input, nil
}
}
return func(ctx context.Context, input map[string]any) (map[string]any, error) {
newInput := make(map[string]any)
for k := range input {
newInput[k] = input[k]
}
for k, tInfo := range s.InputTypes {
if err := FillIfNotRequired(tInfo, newInput, k, FillZero, false); err != nil {
return nil, err
}
}
return newInput, nil
}
}
func (s *NodeSchema) streamInputValueFiller() func(ctx context.Context,
input *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]any] {
fn := func(ctx context.Context, i map[string]any) (map[string]any, error) {
newI := make(map[string]any)
for k := range i {
newI[k] = i[k]
}
for k, tInfo := range s.InputTypes {
if err := replaceNilWithZero(tInfo, newI, k); err != nil {
return nil, err
}
}
return newI, nil
}
return func(ctx context.Context, input *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]any] {
return schema.StreamReaderWithConvert(input, func(in map[string]any) (map[string]any, error) {
return fn(ctx, in)
})
}
}
type FillStrategy string
const (
FillZero FillStrategy = "zero"
FillNil FillStrategy = "nil"
)
func FillIfNotRequired(tInfo *vo.TypeInfo, container map[string]any, k string, strategy FillStrategy, isInsideObject bool) error {
v, ok := container[k]
if ok {
if len(tInfo.Properties) == 0 {
if v == nil && strategy == FillZero {
if isInsideObject {
return nil
}
v = tInfo.Zero()
container[k] = v
return nil
}
if v != nil && tInfo.Type == vo.DataTypeArray {
val, ok := v.([]any)
if !ok {
valStr, ok := v.(string)
if ok {
err := sonic.UnmarshalString(valStr, &val)
if err != nil {
return err
}
container[k] = val
} else {
return fmt.Errorf("layer field %s is not a []any or string", k)
}
}
elemTInfo := tInfo.ElemTypeInfo
copiedVal := slices.Clone(val)
container[k] = copiedVal
for i := range copiedVal {
if copiedVal[i] == nil {
if strategy == FillZero {
copiedVal[i] = elemTInfo.Zero()
continue
}
}
if len(elemTInfo.Properties) > 0 {
subContainer, ok := copiedVal[i].(map[string]any)
if !ok {
return fmt.Errorf("map item under array %s is not map[string]any", k)
}
newSubContainer := maps.Clone(subContainer)
if newSubContainer == nil {
newSubContainer = make(map[string]any)
}
for subK, subL := range elemTInfo.Properties {
if err := FillIfNotRequired(subL, newSubContainer, subK, strategy, true); err != nil {
return err
}
}
copiedVal[i] = newSubContainer
}
}
}
} else {
if v == nil {
return nil
}
// recursively handle the layered object.
subContainer, ok := v.(map[string]any)
if !ok {
subContainerStr, ok := v.(string)
if ok {
subContainer = make(map[string]any)
err := sonic.UnmarshalString(subContainerStr, &subContainer)
if err != nil {
return err
}
container[k] = subContainer
} else {
return fmt.Errorf("layer field %s is not a map[string]any or string", k)
}
}
newSubContainer := maps.Clone(subContainer)
if newSubContainer == nil {
newSubContainer = make(map[string]any)
}
for subK, subT := range tInfo.Properties {
if err := FillIfNotRequired(subT, newSubContainer, subK, strategy, true); err != nil {
return err
}
}
container[k] = newSubContainer
}
} else {
if tInfo.Required {
return fmt.Errorf("output field %s is required but not present", k)
} else {
var z any
if strategy == FillZero {
if !isInsideObject {
z = tInfo.Zero()
}
}
container[k] = z
// if it's an object, recursively handle the layeredFieldInfo.
if len(tInfo.Properties) > 0 {
z = make(map[string]any)
container[k] = z
subContainer := z.(map[string]any)
for subK, subL := range tInfo.Properties {
if err := FillIfNotRequired(subL, subContainer, subK, strategy, true); err != nil {
return err
}
}
}
}
}
return nil
}
func replaceNilWithZero(tInfo *vo.TypeInfo, container map[string]any, k string) error {
v, ok := container[k]
if !ok {
return nil
}
if len(tInfo.Properties) == 0 {
if v == nil {
v = tInfo.Zero()
container[k] = v
return nil
}
if tInfo.Type == vo.DataTypeArray {
val, ok := v.([]any)
if !ok {
valStr, ok := v.(string)
if ok {
err := sonic.UnmarshalString(valStr, &val)
if err != nil {
return err
}
container[k] = val
} else {
return fmt.Errorf("layer field %s is not a []any or string", k)
}
}
elemTInfo := tInfo.ElemTypeInfo
copiedVal := slices.Clone(val)
container[k] = copiedVal
for i := range copiedVal {
if copiedVal[i] == nil {
copiedVal[i] = elemTInfo.Zero()
continue
}
if len(elemTInfo.Properties) > 0 {
subContainer, ok := copiedVal[i].(map[string]any)
if !ok {
return fmt.Errorf("map item under array %s is not map[string]any", k)
}
newSubContainer := maps.Clone(subContainer)
for subK, subL := range elemTInfo.Properties {
if err := replaceNilWithZero(subL, newSubContainer, subK); err != nil {
return err
}
}
copiedVal[i] = newSubContainer
}
}
}
} else {
if v == nil {
return nil
}
// recursively handle the layered object.
subContainer, ok := v.(map[string]any)
if !ok {
subContainerStr, ok := v.(string)
if ok {
subContainer = make(map[string]any)
err := sonic.UnmarshalString(subContainerStr, &subContainer)
if err != nil {
return err
}
container[k] = subContainer
} else {
return fmt.Errorf("layer field %s is not a map[string]any or string", k)
}
}
newSubContainer := maps.Clone(subContainer)
for subK, subT := range tInfo.Properties {
if err := replaceNilWithZero(subT, newSubContainer, subK); err != nil {
return err
}
}
container[k] = newSubContainer
}
return nil
}