292 lines
		
	
	
		
			8.5 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			292 lines
		
	
	
		
			8.5 KiB
		
	
	
	
		
			Go
		
	
	
	
/*
 | 
						|
 * Copyright 2025 coze-dev Authors
 | 
						|
 *
 | 
						|
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
 * you may not use this file except in compliance with the License.
 | 
						|
 * You may obtain a copy of the License at
 | 
						|
 *
 | 
						|
 *     http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
 *
 | 
						|
 * Unless required by applicable law or agreed to in writing, software
 | 
						|
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
 * See the License for the specific language governing permissions and
 | 
						|
 * limitations under the License.
 | 
						|
 */
 | 
						|
 | 
						|
package agentflow
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"errors"
 | 
						|
	"io"
 | 
						|
	"slices"
 | 
						|
 | 
						|
	"github.com/google/uuid"
 | 
						|
 | 
						|
	"github.com/cloudwego/eino/compose"
 | 
						|
	"github.com/cloudwego/eino/schema"
 | 
						|
 | 
						|
	"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
 | 
						|
	crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/logs"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/safego"
 | 
						|
)
 | 
						|
 | 
						|
type AgentState struct {
 | 
						|
	Messages                 []*schema.Message
 | 
						|
	UserInput                *schema.Message
 | 
						|
	ReturnDirectlyToolCallID string
 | 
						|
}
 | 
						|
 | 
						|
type AgentRequest struct {
 | 
						|
	UserID  string
 | 
						|
	Input   *schema.Message
 | 
						|
	History []*schema.Message
 | 
						|
 | 
						|
	Identity *singleagent.AgentIdentity
 | 
						|
 | 
						|
	ResumeInfo   *singleagent.InterruptInfo
 | 
						|
	PreCallTools []*agentrun.ToolsRetriever
 | 
						|
	Variables    map[string]string
 | 
						|
}
 | 
						|
 | 
						|
type AgentRunner struct {
 | 
						|
	runner            compose.Runnable[*AgentRequest, *schema.Message]
 | 
						|
	requireCheckpoint bool
 | 
						|
 | 
						|
	returnDirectlyTools map[string]struct{}
 | 
						|
	containWfTool       bool
 | 
						|
	modelInfo           *modelmgr.Model
 | 
						|
}
 | 
						|
 | 
						|
func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
 | 
						|
	sr *schema.StreamReader[*entity.AgentEvent], err error,
 | 
						|
) {
 | 
						|
	executeID := uuid.New()
 | 
						|
 | 
						|
	hdl, sr, sw := newReplyCallback(ctx, executeID.String(), r.returnDirectlyTools)
 | 
						|
 | 
						|
	var composeOpts []compose.Option
 | 
						|
	var pipeMsgOpt compose.Option
 | 
						|
	var workflowMsgSr *schema.StreamReader[*crossworkflow.WorkflowMessage]
 | 
						|
	if r.containWfTool {
 | 
						|
		cfReq := crossworkflow.ExecuteConfig{
 | 
						|
			AgentID:      &req.Identity.AgentID,
 | 
						|
			ConnectorUID: req.UserID,
 | 
						|
			ConnectorID:  req.Identity.ConnectorID,
 | 
						|
			BizType:      crossworkflow.BizTypeAgent,
 | 
						|
		}
 | 
						|
		if req.Identity.IsDraft {
 | 
						|
			cfReq.Mode = crossworkflow.ExecuteModeDebug
 | 
						|
		} else {
 | 
						|
			cfReq.Mode = crossworkflow.ExecuteModeRelease
 | 
						|
		}
 | 
						|
		wfConfig := crossworkflow.DefaultSVC().WithExecuteConfig(cfReq)
 | 
						|
		composeOpts = append(composeOpts, wfConfig)
 | 
						|
		pipeMsgOpt, workflowMsgSr = crossworkflow.DefaultSVC().WithMessagePipe()
 | 
						|
		composeOpts = append(composeOpts, pipeMsgOpt)
 | 
						|
	}
 | 
						|
 | 
						|
	composeOpts = append(composeOpts, compose.WithCallbacks(hdl))
 | 
						|
	_ = compose.RegisterSerializableType[*AgentState]("agent_state")
 | 
						|
	if r.requireCheckpoint {
 | 
						|
 | 
						|
		defaultCheckPointID := executeID.String()
 | 
						|
		if req.ResumeInfo != nil {
 | 
						|
			resumeInfo := req.ResumeInfo
 | 
						|
			if resumeInfo.InterruptType != singleagent.InterruptEventType_OauthPlugin {
 | 
						|
				defaultCheckPointID = resumeInfo.InterruptID
 | 
						|
				opts := crossworkflow.DefaultSVC().WithResumeToolWorkflow(resumeInfo.AllWfInterruptData[resumeInfo.ToolCallID], req.Input.Content, resumeInfo.AllWfInterruptData)
 | 
						|
				composeOpts = append(composeOpts, opts)
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		composeOpts = append(composeOpts, compose.WithCheckPointID(defaultCheckPointID))
 | 
						|
	}
 | 
						|
	if r.containWfTool && workflowMsgSr != nil {
 | 
						|
		safego.Go(ctx, func() {
 | 
						|
			r.processWfMidAnswerStream(ctx, sw, workflowMsgSr)
 | 
						|
		})
 | 
						|
	}
 | 
						|
	safego.Go(ctx, func() {
 | 
						|
		defer func() {
 | 
						|
			if pe := recover(); pe != nil {
 | 
						|
				logs.CtxErrorf(ctx, "[AgentRunner] StreamExecute recover, err: %v", pe)
 | 
						|
 | 
						|
				sw.Send(nil, errors.New("internal server error"))
 | 
						|
			}
 | 
						|
			sw.Close()
 | 
						|
		}()
 | 
						|
		_, _ = r.runner.Stream(ctx, req, composeOpts...)
 | 
						|
	})
 | 
						|
 | 
						|
	return sr, nil
 | 
						|
}
 | 
						|
 | 
						|
func (r *AgentRunner) processWfMidAnswerStream(_ context.Context, sw *schema.StreamWriter[*entity.AgentEvent], wfStream *schema.StreamReader[*crossworkflow.WorkflowMessage]) {
 | 
						|
	streamInitialized := false
 | 
						|
	var srT *schema.StreamReader[*schema.Message]
 | 
						|
	var swT *schema.StreamWriter[*schema.Message]
 | 
						|
	defer func() {
 | 
						|
		if swT != nil {
 | 
						|
			swT.Close()
 | 
						|
		}
 | 
						|
	}()
 | 
						|
	for {
 | 
						|
		msg, err := wfStream.Recv()
 | 
						|
 | 
						|
		if err == io.EOF {
 | 
						|
			break
 | 
						|
		}
 | 
						|
		if msg == nil || msg.DataMessage == nil {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		if msg.DataMessage.NodeType != crossworkflow.NodeTypeOutputEmitter {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
		if !streamInitialized {
 | 
						|
			streamInitialized = true
 | 
						|
			srT, swT = schema.Pipe[*schema.Message](5)
 | 
						|
			sw.Send(&entity.AgentEvent{
 | 
						|
				EventType:     singleagent.EventTypeOfToolMidAnswer,
 | 
						|
				ToolMidAnswer: srT,
 | 
						|
			}, nil)
 | 
						|
		}
 | 
						|
		swT.Send(&schema.Message{
 | 
						|
			Role:    msg.DataMessage.Role,
 | 
						|
			Content: msg.DataMessage.Content,
 | 
						|
			Extra: func(msg *crossworkflow.WorkflowMessage) map[string]any {
 | 
						|
 | 
						|
				extra := make(map[string]any)
 | 
						|
				extra["workflow_node_name"] = msg.NodeTitle
 | 
						|
				if msg.DataMessage.Last {
 | 
						|
					extra["is_finish"] = true
 | 
						|
				}
 | 
						|
				return extra
 | 
						|
			}(msg),
 | 
						|
		}, nil)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (r *AgentRunner) PreHandlerReq(ctx context.Context, req *AgentRequest) *AgentRequest {
 | 
						|
	req.Input = r.preHandlerInput(req.Input)
 | 
						|
	req.History = r.preHandlerHistory(req.History)
 | 
						|
	logs.CtxInfof(ctx, "[AgentRunner] PreHandlerReq, req: %v", conv.DebugJsonToStr(req))
 | 
						|
 | 
						|
	return req
 | 
						|
}
 | 
						|
 | 
						|
func (r *AgentRunner) preHandlerInput(input *schema.Message) *schema.Message {
 | 
						|
	var multiContent []schema.ChatMessagePart
 | 
						|
 | 
						|
	if len(input.MultiContent) == 0 {
 | 
						|
		return input
 | 
						|
	}
 | 
						|
 | 
						|
	unSupportMultiPart := make([]schema.ChatMessagePart, 0, len(input.MultiContent))
 | 
						|
 | 
						|
	for _, v := range input.MultiContent {
 | 
						|
		switch v.Type {
 | 
						|
		case schema.ChatMessagePartTypeImageURL:
 | 
						|
			if !r.isSupportImage() {
 | 
						|
				unSupportMultiPart = append(unSupportMultiPart, v)
 | 
						|
			} else {
 | 
						|
				multiContent = append(multiContent, v)
 | 
						|
			}
 | 
						|
		case schema.ChatMessagePartTypeFileURL:
 | 
						|
			if !r.isSupportFile() {
 | 
						|
				unSupportMultiPart = append(unSupportMultiPart, v)
 | 
						|
			} else {
 | 
						|
				multiContent = append(multiContent, v)
 | 
						|
			}
 | 
						|
		case schema.ChatMessagePartTypeAudioURL:
 | 
						|
			if !r.isSupportAudio() {
 | 
						|
				unSupportMultiPart = append(unSupportMultiPart, v)
 | 
						|
			} else {
 | 
						|
				multiContent = append(multiContent, v)
 | 
						|
			}
 | 
						|
		case schema.ChatMessagePartTypeVideoURL:
 | 
						|
			if !r.isSupportVideo() {
 | 
						|
				unSupportMultiPart = append(unSupportMultiPart, v)
 | 
						|
			} else {
 | 
						|
				multiContent = append(multiContent, v)
 | 
						|
			}
 | 
						|
		case schema.ChatMessagePartTypeText:
 | 
						|
		default:
 | 
						|
			multiContent = append(multiContent, v)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	for _, v := range input.MultiContent {
 | 
						|
		if v.Type != schema.ChatMessagePartTypeText {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		if r.isSupportMultiContent() {
 | 
						|
			if len(multiContent) > 0 {
 | 
						|
				v.Text = concatContentString(v.Text, unSupportMultiPart)
 | 
						|
				multiContent = append(multiContent, v)
 | 
						|
			} else {
 | 
						|
				input.Content = concatContentString(v.Text, unSupportMultiPart)
 | 
						|
			}
 | 
						|
		} else {
 | 
						|
			input.Content = concatContentString(v.Text, unSupportMultiPart)
 | 
						|
		}
 | 
						|
 | 
						|
	}
 | 
						|
	input.MultiContent = multiContent
 | 
						|
	return input
 | 
						|
}
 | 
						|
func concatContentString(textContent string, unSupportTypeURL []schema.ChatMessagePart) string {
 | 
						|
	if len(unSupportTypeURL) == 0 {
 | 
						|
		return textContent
 | 
						|
	}
 | 
						|
	for _, v := range unSupportTypeURL {
 | 
						|
		switch v.Type {
 | 
						|
		case schema.ChatMessagePartTypeImageURL:
 | 
						|
			textContent += "  this is a image:" + v.ImageURL.URL
 | 
						|
		case schema.ChatMessagePartTypeFileURL:
 | 
						|
			textContent += "  this is a file:" + v.FileURL.URL
 | 
						|
		case schema.ChatMessagePartTypeAudioURL:
 | 
						|
			textContent += "  this is a audio:" + v.AudioURL.URL
 | 
						|
		case schema.ChatMessagePartTypeVideoURL:
 | 
						|
			textContent += "  this is a video:" + v.VideoURL.URL
 | 
						|
		default:
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return textContent
 | 
						|
}
 | 
						|
 | 
						|
func (r *AgentRunner) preHandlerHistory(history []*schema.Message) []*schema.Message {
 | 
						|
	var hm []*schema.Message
 | 
						|
	for _, msg := range history {
 | 
						|
		if msg.Role == schema.User {
 | 
						|
			msg = r.preHandlerInput(msg)
 | 
						|
		}
 | 
						|
		hm = append(hm, msg)
 | 
						|
	}
 | 
						|
	return hm
 | 
						|
}
 | 
						|
 | 
						|
func (r *AgentRunner) isSupportMultiContent() bool {
 | 
						|
	return len(r.modelInfo.Meta.Capability.InputModal) > 1
 | 
						|
}
 | 
						|
func (r *AgentRunner) isSupportImage() bool {
 | 
						|
	return slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalImage)
 | 
						|
}
 | 
						|
func (r *AgentRunner) isSupportFile() bool {
 | 
						|
	return slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalFile)
 | 
						|
}
 | 
						|
func (r *AgentRunner) isSupportAudio() bool {
 | 
						|
	return slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalAudio)
 | 
						|
}
 | 
						|
func (r *AgentRunner) isSupportVideo() bool {
 | 
						|
	return slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalVideo)
 | 
						|
}
 |