292 lines
		
	
	
		
			8.5 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			292 lines
		
	
	
		
			8.5 KiB
		
	
	
	
		
			Go
		
	
	
	
| /*
 | |
|  * Copyright 2025 coze-dev Authors
 | |
|  *
 | |
|  * Licensed under the Apache License, Version 2.0 (the "License");
 | |
|  * you may not use this file except in compliance with the License.
 | |
|  * You may obtain a copy of the License at
 | |
|  *
 | |
|  *     http://www.apache.org/licenses/LICENSE-2.0
 | |
|  *
 | |
|  * Unless required by applicable law or agreed to in writing, software
 | |
|  * distributed under the License is distributed on an "AS IS" BASIS,
 | |
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
|  * See the License for the specific language governing permissions and
 | |
|  * limitations under the License.
 | |
|  */
 | |
| 
 | |
| package agentflow
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"errors"
 | |
| 	"io"
 | |
| 	"slices"
 | |
| 
 | |
| 	"github.com/google/uuid"
 | |
| 
 | |
| 	"github.com/cloudwego/eino/compose"
 | |
| 	"github.com/cloudwego/eino/schema"
 | |
| 
 | |
| 	"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
 | |
| 	"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
 | |
| 	"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
 | |
| 	"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
 | |
| 	"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
 | |
| 	"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
 | |
| 	"github.com/coze-dev/coze-studio/backend/pkg/logs"
 | |
| 	"github.com/coze-dev/coze-studio/backend/pkg/safego"
 | |
| )
 | |
| 
 | |
| type AgentState struct {
 | |
| 	Messages                 []*schema.Message
 | |
| 	UserInput                *schema.Message
 | |
| 	ReturnDirectlyToolCallID string
 | |
| }
 | |
| 
 | |
| type AgentRequest struct {
 | |
| 	UserID  string
 | |
| 	Input   *schema.Message
 | |
| 	History []*schema.Message
 | |
| 
 | |
| 	Identity *singleagent.AgentIdentity
 | |
| 
 | |
| 	ResumeInfo   *singleagent.InterruptInfo
 | |
| 	PreCallTools []*agentrun.ToolsRetriever
 | |
| 	Variables    map[string]string
 | |
| }
 | |
| 
 | |
| type AgentRunner struct {
 | |
| 	runner            compose.Runnable[*AgentRequest, *schema.Message]
 | |
| 	requireCheckpoint bool
 | |
| 
 | |
| 	returnDirectlyTools map[string]struct{}
 | |
| 	containWfTool       bool
 | |
| 	modelInfo           *modelmgr.Model
 | |
| }
 | |
| 
 | |
| func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
 | |
| 	sr *schema.StreamReader[*entity.AgentEvent], err error,
 | |
| ) {
 | |
| 	executeID := uuid.New()
 | |
| 
 | |
| 	hdl, sr, sw := newReplyCallback(ctx, executeID.String(), r.returnDirectlyTools)
 | |
| 
 | |
| 	var composeOpts []compose.Option
 | |
| 	var pipeMsgOpt compose.Option
 | |
| 	var workflowMsgSr *schema.StreamReader[*crossworkflow.WorkflowMessage]
 | |
| 	if r.containWfTool {
 | |
| 		cfReq := crossworkflow.ExecuteConfig{
 | |
| 			AgentID:      &req.Identity.AgentID,
 | |
| 			ConnectorUID: req.UserID,
 | |
| 			ConnectorID:  req.Identity.ConnectorID,
 | |
| 			BizType:      crossworkflow.BizTypeAgent,
 | |
| 		}
 | |
| 		if req.Identity.IsDraft {
 | |
| 			cfReq.Mode = crossworkflow.ExecuteModeDebug
 | |
| 		} else {
 | |
| 			cfReq.Mode = crossworkflow.ExecuteModeRelease
 | |
| 		}
 | |
| 		wfConfig := crossworkflow.DefaultSVC().WithExecuteConfig(cfReq)
 | |
| 		composeOpts = append(composeOpts, wfConfig)
 | |
| 		pipeMsgOpt, workflowMsgSr = crossworkflow.DefaultSVC().WithMessagePipe()
 | |
| 		composeOpts = append(composeOpts, pipeMsgOpt)
 | |
| 	}
 | |
| 
 | |
| 	composeOpts = append(composeOpts, compose.WithCallbacks(hdl))
 | |
| 	_ = compose.RegisterSerializableType[*AgentState]("agent_state")
 | |
| 	if r.requireCheckpoint {
 | |
| 
 | |
| 		defaultCheckPointID := executeID.String()
 | |
| 		if req.ResumeInfo != nil {
 | |
| 			resumeInfo := req.ResumeInfo
 | |
| 			if resumeInfo.InterruptType != singleagent.InterruptEventType_OauthPlugin {
 | |
| 				defaultCheckPointID = resumeInfo.InterruptID
 | |
| 				opts := crossworkflow.DefaultSVC().WithResumeToolWorkflow(resumeInfo.AllWfInterruptData[resumeInfo.ToolCallID], req.Input.Content, resumeInfo.AllWfInterruptData)
 | |
| 				composeOpts = append(composeOpts, opts)
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		composeOpts = append(composeOpts, compose.WithCheckPointID(defaultCheckPointID))
 | |
| 	}
 | |
| 	if r.containWfTool && workflowMsgSr != nil {
 | |
| 		safego.Go(ctx, func() {
 | |
| 			r.processWfMidAnswerStream(ctx, sw, workflowMsgSr)
 | |
| 		})
 | |
| 	}
 | |
| 	safego.Go(ctx, func() {
 | |
| 		defer func() {
 | |
| 			if pe := recover(); pe != nil {
 | |
| 				logs.CtxErrorf(ctx, "[AgentRunner] StreamExecute recover, err: %v", pe)
 | |
| 
 | |
| 				sw.Send(nil, errors.New("internal server error"))
 | |
| 			}
 | |
| 			sw.Close()
 | |
| 		}()
 | |
| 		_, _ = r.runner.Stream(ctx, req, composeOpts...)
 | |
| 	})
 | |
| 
 | |
| 	return sr, nil
 | |
| }
 | |
| 
 | |
| func (r *AgentRunner) processWfMidAnswerStream(_ context.Context, sw *schema.StreamWriter[*entity.AgentEvent], wfStream *schema.StreamReader[*crossworkflow.WorkflowMessage]) {
 | |
| 	streamInitialized := false
 | |
| 	var srT *schema.StreamReader[*schema.Message]
 | |
| 	var swT *schema.StreamWriter[*schema.Message]
 | |
| 	defer func() {
 | |
| 		if swT != nil {
 | |
| 			swT.Close()
 | |
| 		}
 | |
| 	}()
 | |
| 	for {
 | |
| 		msg, err := wfStream.Recv()
 | |
| 
 | |
| 		if err == io.EOF {
 | |
| 			break
 | |
| 		}
 | |
| 		if msg == nil || msg.DataMessage == nil {
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		if msg.DataMessage.NodeType != crossworkflow.NodeTypeOutputEmitter {
 | |
| 			continue
 | |
| 		}
 | |
| 		if !streamInitialized {
 | |
| 			streamInitialized = true
 | |
| 			srT, swT = schema.Pipe[*schema.Message](5)
 | |
| 			sw.Send(&entity.AgentEvent{
 | |
| 				EventType:     singleagent.EventTypeOfToolMidAnswer,
 | |
| 				ToolMidAnswer: srT,
 | |
| 			}, nil)
 | |
| 		}
 | |
| 		swT.Send(&schema.Message{
 | |
| 			Role:    msg.DataMessage.Role,
 | |
| 			Content: msg.DataMessage.Content,
 | |
| 			Extra: func(msg *crossworkflow.WorkflowMessage) map[string]any {
 | |
| 
 | |
| 				extra := make(map[string]any)
 | |
| 				extra["workflow_node_name"] = msg.NodeTitle
 | |
| 				if msg.DataMessage.Last {
 | |
| 					extra["is_finish"] = true
 | |
| 				}
 | |
| 				return extra
 | |
| 			}(msg),
 | |
| 		}, nil)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (r *AgentRunner) PreHandlerReq(ctx context.Context, req *AgentRequest) *AgentRequest {
 | |
| 	req.Input = r.preHandlerInput(req.Input)
 | |
| 	req.History = r.preHandlerHistory(req.History)
 | |
| 	logs.CtxInfof(ctx, "[AgentRunner] PreHandlerReq, req: %v", conv.DebugJsonToStr(req))
 | |
| 
 | |
| 	return req
 | |
| }
 | |
| 
 | |
| func (r *AgentRunner) preHandlerInput(input *schema.Message) *schema.Message {
 | |
| 	var multiContent []schema.ChatMessagePart
 | |
| 
 | |
| 	if len(input.MultiContent) == 0 {
 | |
| 		return input
 | |
| 	}
 | |
| 
 | |
| 	unSupportMultiPart := make([]schema.ChatMessagePart, 0, len(input.MultiContent))
 | |
| 
 | |
| 	for _, v := range input.MultiContent {
 | |
| 		switch v.Type {
 | |
| 		case schema.ChatMessagePartTypeImageURL:
 | |
| 			if !r.isSupportImage() {
 | |
| 				unSupportMultiPart = append(unSupportMultiPart, v)
 | |
| 			} else {
 | |
| 				multiContent = append(multiContent, v)
 | |
| 			}
 | |
| 		case schema.ChatMessagePartTypeFileURL:
 | |
| 			if !r.isSupportFile() {
 | |
| 				unSupportMultiPart = append(unSupportMultiPart, v)
 | |
| 			} else {
 | |
| 				multiContent = append(multiContent, v)
 | |
| 			}
 | |
| 		case schema.ChatMessagePartTypeAudioURL:
 | |
| 			if !r.isSupportAudio() {
 | |
| 				unSupportMultiPart = append(unSupportMultiPart, v)
 | |
| 			} else {
 | |
| 				multiContent = append(multiContent, v)
 | |
| 			}
 | |
| 		case schema.ChatMessagePartTypeVideoURL:
 | |
| 			if !r.isSupportVideo() {
 | |
| 				unSupportMultiPart = append(unSupportMultiPart, v)
 | |
| 			} else {
 | |
| 				multiContent = append(multiContent, v)
 | |
| 			}
 | |
| 		case schema.ChatMessagePartTypeText:
 | |
| 		default:
 | |
| 			multiContent = append(multiContent, v)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	for _, v := range input.MultiContent {
 | |
| 		if v.Type != schema.ChatMessagePartTypeText {
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		if r.isSupportMultiContent() {
 | |
| 			if len(multiContent) > 0 {
 | |
| 				v.Text = concatContentString(v.Text, unSupportMultiPart)
 | |
| 				multiContent = append(multiContent, v)
 | |
| 			} else {
 | |
| 				input.Content = concatContentString(v.Text, unSupportMultiPart)
 | |
| 			}
 | |
| 		} else {
 | |
| 			input.Content = concatContentString(v.Text, unSupportMultiPart)
 | |
| 		}
 | |
| 
 | |
| 	}
 | |
| 	input.MultiContent = multiContent
 | |
| 	return input
 | |
| }
 | |
| func concatContentString(textContent string, unSupportTypeURL []schema.ChatMessagePart) string {
 | |
| 	if len(unSupportTypeURL) == 0 {
 | |
| 		return textContent
 | |
| 	}
 | |
| 	for _, v := range unSupportTypeURL {
 | |
| 		switch v.Type {
 | |
| 		case schema.ChatMessagePartTypeImageURL:
 | |
| 			textContent += "  this is a image:" + v.ImageURL.URL
 | |
| 		case schema.ChatMessagePartTypeFileURL:
 | |
| 			textContent += "  this is a file:" + v.FileURL.URL
 | |
| 		case schema.ChatMessagePartTypeAudioURL:
 | |
| 			textContent += "  this is a audio:" + v.AudioURL.URL
 | |
| 		case schema.ChatMessagePartTypeVideoURL:
 | |
| 			textContent += "  this is a video:" + v.VideoURL.URL
 | |
| 		default:
 | |
| 		}
 | |
| 	}
 | |
| 	return textContent
 | |
| }
 | |
| 
 | |
| func (r *AgentRunner) preHandlerHistory(history []*schema.Message) []*schema.Message {
 | |
| 	var hm []*schema.Message
 | |
| 	for _, msg := range history {
 | |
| 		if msg.Role == schema.User {
 | |
| 			msg = r.preHandlerInput(msg)
 | |
| 		}
 | |
| 		hm = append(hm, msg)
 | |
| 	}
 | |
| 	return hm
 | |
| }
 | |
| 
 | |
| func (r *AgentRunner) isSupportMultiContent() bool {
 | |
| 	return len(r.modelInfo.Meta.Capability.InputModal) > 1
 | |
| }
 | |
| func (r *AgentRunner) isSupportImage() bool {
 | |
| 	return slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalImage)
 | |
| }
 | |
| func (r *AgentRunner) isSupportFile() bool {
 | |
| 	return slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalFile)
 | |
| }
 | |
| func (r *AgentRunner) isSupportAudio() bool {
 | |
| 	return slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalAudio)
 | |
| }
 | |
| func (r *AgentRunner) isSupportVideo() bool {
 | |
| 	return slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalVideo)
 | |
| }
 |