diff --git a/.gitignore b/.gitignore index 0d311d15..e433ddc5 100644 --- a/.gitignore +++ b/.gitignore @@ -34,6 +34,7 @@ output/* # Vscode files .vscode/settings.json +.vscode/launch.json /patches /oldimpl diff --git a/backend/api/handler/coze/bot_open_api_service.go b/backend/api/handler/coze/bot_open_api_service.go index 5ec84a4e..9faba088 100644 --- a/backend/api/handler/coze/bot_open_api_service.go +++ b/backend/api/handler/coze/bot_open_api_service.go @@ -29,6 +29,7 @@ import ( "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/protocol/consts" + "github.com/coze-dev/coze-studio/backend/api/model/app/bot_open_api" ) diff --git a/backend/api/handler/coze/database_service.go b/backend/api/handler/coze/database_service.go index 38900fee..4dc31929 100644 --- a/backend/api/handler/coze/database_service.go +++ b/backend/api/handler/coze/database_service.go @@ -23,6 +23,7 @@ import ( "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/protocol/consts" + "github.com/coze-dev/coze-studio/backend/api/model/data/database/table" "github.com/coze-dev/coze-studio/backend/api/model/data/knowledge" "github.com/coze-dev/coze-studio/backend/application/memory" diff --git a/backend/api/handler/coze/intelligence_service.go b/backend/api/handler/coze/intelligence_service.go index a3dc8f2a..76938d70 100644 --- a/backend/api/handler/coze/intelligence_service.go +++ b/backend/api/handler/coze/intelligence_service.go @@ -24,6 +24,7 @@ import ( "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/protocol/consts" + "github.com/coze-dev/coze-studio/backend/api/model/app/intelligence" "github.com/coze-dev/coze-studio/backend/api/model/app/intelligence/common" project "github.com/coze-dev/coze-studio/backend/api/model/app/intelligence/project" diff --git a/backend/api/handler/coze/playground_service.go b/backend/api/handler/coze/playground_service.go index 15731c8b..1ce2ce49 100644 --- a/backend/api/handler/coze/playground_service.go +++ b/backend/api/handler/coze/playground_service.go @@ -23,6 +23,7 @@ import ( "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/protocol/consts" + "github.com/coze-dev/coze-studio/backend/api/model/playground" appApplication "github.com/coze-dev/coze-studio/backend/application/app" "github.com/coze-dev/coze-studio/backend/application/prompt" diff --git a/backend/api/model/crossdomain/message/message.go b/backend/api/model/crossdomain/message/message.go index 3e27bd38..e516a67e 100644 --- a/backend/api/model/crossdomain/message/message.go +++ b/backend/api/model/crossdomain/message/message.go @@ -101,4 +101,7 @@ const ( MessageTypeFlowUp MessageType = "follow_up" MessageTypeInterrupt MessageType = "interrupt" MessageTypeVerbose MessageType = "verbose" + + MessageTypeToolAsAnswer MessageType = "tool_as_answer" + MessageTypeToolMidAnswer MessageType = "tool_mid_answer" ) diff --git a/backend/api/model/crossdomain/singleagent/single_agent.go b/backend/api/model/crossdomain/singleagent/single_agent.go index 9b3088e3..5ca2c973 100644 --- a/backend/api/model/crossdomain/singleagent/single_agent.go +++ b/backend/api/model/crossdomain/singleagent/single_agent.go @@ -39,6 +39,7 @@ type EventType string const ( EventTypeOfChatModelAnswer EventType = "chatmodel_answer" EventTypeOfToolsAsChatModelStream EventType = "tools_as_chatmodel_answer" + EventTypeOfToolMidAnswer EventType = "tool_mid_answer" EventTypeOfToolsMessage EventType = "tools_message" EventTypeOfFuncCall EventType = "func_call" EventTypeOfSuggest EventType = "suggest" @@ -49,6 +50,9 @@ const ( type AgentEvent struct { EventType EventType + ToolMidAnswer *schema.StreamReader[*schema.Message] + ToolAsChatModelAnswer *schema.StreamReader[*schema.Message] + ChatModelAnswer *schema.StreamReader[*schema.Message] ToolsMessage []*schema.Message FuncCall *schema.Message diff --git a/backend/crossdomain/contract/crossworkflow/crossworkflow.go b/backend/crossdomain/contract/crossworkflow/crossworkflow.go index 50156392..bbc06b01 100644 --- a/backend/crossdomain/contract/crossworkflow/crossworkflow.go +++ b/backend/crossdomain/contract/crossworkflow/crossworkflow.go @@ -19,9 +19,12 @@ package crossworkflow import ( "context" + "github.com/cloudwego/eino/compose" einoCompose "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" "github.com/coze-dev/coze-studio/backend/domain/workflow" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" workflowEntity "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" ) @@ -37,10 +40,18 @@ type Workflow interface { GetWorkflowIDsByAppID(ctx context.Context, appID int64) ([]int64, error) SyncExecuteWorkflow(ctx context.Context, config vo.ExecuteConfig, input map[string]any) (*workflowEntity.WorkflowExecution, vo.TerminatePlan, error) WithExecuteConfig(cfg vo.ExecuteConfig) einoCompose.Option + WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message]) } type ExecuteConfig = vo.ExecuteConfig type ExecuteMode = vo.ExecuteMode +type NodeType = entity.NodeType + +type WorkflowMessage = entity.Message + +const ( + NodeTypeOutputEmitter NodeType = "OutputEmitter" +) const ( ExecuteModeDebug ExecuteMode = "debug" diff --git a/backend/crossdomain/impl/singleagent/single_agent.go b/backend/crossdomain/impl/singleagent/single_agent.go index f233f2e6..1475dc22 100644 --- a/backend/crossdomain/impl/singleagent/single_agent.go +++ b/backend/crossdomain/impl/singleagent/single_agent.go @@ -165,6 +165,10 @@ func (c *impl) buildSchemaMessage(ctx context.Context, msgs []*message.Message) if err != nil { continue } + if len(sm.ReasoningContent) > 0 { + sm.ReasoningContent = "" + } + schemaMessage = append(schemaMessage, c.parseMessageURI(ctx, sm)) } diff --git a/backend/crossdomain/impl/workflow/workflow.go b/backend/crossdomain/impl/workflow/workflow.go index 2b8b0549..651e6221 100644 --- a/backend/crossdomain/impl/workflow/workflow.go +++ b/backend/crossdomain/impl/workflow/workflow.go @@ -19,10 +19,13 @@ package workflow import ( "context" + "github.com/cloudwego/eino/compose" einoCompose "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" "github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow" "github.com/coze-dev/coze-studio/backend/domain/workflow" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" workflowEntity "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" @@ -72,6 +75,10 @@ func (i *impl) WithExecuteConfig(cfg vo.ExecuteConfig) einoCompose.Option { return i.DomainSVC.WithExecuteConfig(cfg) } +func (i *impl) WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message]) { + return i.DomainSVC.WithMessagePipe() +} + func (i *impl) GetWorkflowIDsByAppID(ctx context.Context, appID int64) ([]int64, error) { metas, _, err := i.DomainSVC.MGet(ctx, &vo.MGetPolicy{ MetaQuery: vo.MetaQuery{ diff --git a/backend/crossdomain/workflow/plugin/plugin_test.go b/backend/crossdomain/workflow/plugin/plugin_test.go index e6a0a460..bb0f55a3 100644 --- a/backend/crossdomain/workflow/plugin/plugin_test.go +++ b/backend/crossdomain/workflow/plugin/plugin_test.go @@ -19,10 +19,11 @@ package plugin import ( "testing" + "github.com/stretchr/testify/assert" + common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" workflow3 "github.com/coze-dev/coze-studio/backend/api/model/workflow" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" - "github.com/stretchr/testify/assert" ) func TestToWorkflowAPIParameter(t *testing.T) { diff --git a/backend/domain/agent/singleagent/internal/agentflow/agent_flow_runner.go b/backend/domain/agent/singleagent/internal/agentflow/agent_flow_runner.go index 9cb0bb7b..17c9d0ec 100644 --- a/backend/domain/agent/singleagent/internal/agentflow/agent_flow_runner.go +++ b/backend/domain/agent/singleagent/internal/agentflow/agent_flow_runner.go @@ -19,6 +19,7 @@ package agentflow import ( "context" "errors" + "io" "slices" "github.com/google/uuid" @@ -33,6 +34,7 @@ import ( "github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr" "github.com/coze-dev/coze-studio/backend/pkg/lang/conv" "github.com/coze-dev/coze-studio/backend/pkg/logs" + "github.com/coze-dev/coze-studio/backend/pkg/safego" ) type AgentState struct { @@ -69,7 +71,49 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) ( hdl, sr, sw := newReplyCallback(ctx, executeID.String(), r.returnDirectlyTools) - go func() { + var composeOpts []compose.Option + var pipeMsgOpt compose.Option + var workflowMsgSr *schema.StreamReader[*crossworkflow.WorkflowMessage] + if r.containWfTool { + cfReq := crossworkflow.ExecuteConfig{ + AgentID: &req.Identity.AgentID, + ConnectorUID: req.UserID, + ConnectorID: req.Identity.ConnectorID, + BizType: crossworkflow.BizTypeAgent, + } + if req.Identity.IsDraft { + cfReq.Mode = crossworkflow.ExecuteModeDebug + } else { + cfReq.Mode = crossworkflow.ExecuteModeRelease + } + wfConfig := crossworkflow.DefaultSVC().WithExecuteConfig(cfReq) + composeOpts = append(composeOpts, wfConfig) + pipeMsgOpt, workflowMsgSr = crossworkflow.DefaultSVC().WithMessagePipe() + composeOpts = append(composeOpts, pipeMsgOpt) + } + + composeOpts = append(composeOpts, compose.WithCallbacks(hdl)) + _ = compose.RegisterSerializableType[*AgentState]("agent_state") + if r.requireCheckpoint { + + defaultCheckPointID := executeID.String() + if req.ResumeInfo != nil { + resumeInfo := req.ResumeInfo + if resumeInfo.InterruptType != singleagent.InterruptEventType_OauthPlugin { + defaultCheckPointID = resumeInfo.InterruptID + opts := crossworkflow.DefaultSVC().WithResumeToolWorkflow(resumeInfo.AllWfInterruptData[resumeInfo.ToolCallID], req.Input.Content, resumeInfo.AllWfInterruptData) + composeOpts = append(composeOpts, opts) + } + } + + composeOpts = append(composeOpts, compose.WithCheckPointID(defaultCheckPointID)) + } + if r.containWfTool && workflowMsgSr != nil { + safego.Go(ctx, func() { + r.processWfMidAnswerStream(ctx, sw, workflowMsgSr) + }) + } + safego.Go(ctx, func() { defer func() { if pe := recover(); pe != nil { logs.CtxErrorf(ctx, "[AgentRunner] StreamExecute recover, err: %v", pe) @@ -78,45 +122,58 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) ( } sw.Close() }() - - var composeOpts []compose.Option - composeOpts = append(composeOpts, compose.WithCallbacks(hdl)) - _ = compose.RegisterSerializableType[*AgentState]("agent_state") - if r.requireCheckpoint { - - defaultCheckPointID := executeID.String() - if req.ResumeInfo != nil { - resumeInfo := req.ResumeInfo - if resumeInfo.InterruptType != singleagent.InterruptEventType_OauthPlugin { - defaultCheckPointID = resumeInfo.InterruptID - opts := crossworkflow.DefaultSVC().WithResumeToolWorkflow(resumeInfo.AllWfInterruptData[resumeInfo.ToolCallID], req.Input.Content, resumeInfo.AllWfInterruptData) - composeOpts = append(composeOpts, opts) - } - } - - composeOpts = append(composeOpts, compose.WithCheckPointID(defaultCheckPointID)) - } - if r.containWfTool { - cfReq := crossworkflow.ExecuteConfig{ - AgentID: &req.Identity.AgentID, - ConnectorUID: req.UserID, - ConnectorID: req.Identity.ConnectorID, - BizType: crossworkflow.BizTypeAgent, - } - if req.Identity.IsDraft { - cfReq.Mode = crossworkflow.ExecuteModeDebug - } else { - cfReq.Mode = crossworkflow.ExecuteModeRelease - } - wfConfig := crossworkflow.DefaultSVC().WithExecuteConfig(cfReq) - composeOpts = append(composeOpts, wfConfig) - } _, _ = r.runner.Stream(ctx, req, composeOpts...) - }() + }) return sr, nil } +func (r *AgentRunner) processWfMidAnswerStream(_ context.Context, sw *schema.StreamWriter[*entity.AgentEvent], wfStream *schema.StreamReader[*crossworkflow.WorkflowMessage]) { + streamInitialized := false + var srT *schema.StreamReader[*schema.Message] + var swT *schema.StreamWriter[*schema.Message] + defer func() { + if swT != nil { + swT.Close() + } + }() + for { + msg, err := wfStream.Recv() + + if err == io.EOF { + break + } + if msg == nil || msg.DataMessage == nil { + continue + } + + if msg.DataMessage.NodeType != crossworkflow.NodeTypeOutputEmitter { + continue + } + if !streamInitialized { + streamInitialized = true + srT, swT = schema.Pipe[*schema.Message](5) + sw.Send(&entity.AgentEvent{ + EventType: singleagent.EventTypeOfToolMidAnswer, + ToolMidAnswer: srT, + }, nil) + } + swT.Send(&schema.Message{ + Role: msg.DataMessage.Role, + Content: msg.DataMessage.Content, + Extra: func(msg *crossworkflow.WorkflowMessage) map[string]any { + + extra := make(map[string]any) + extra["workflow_node_name"] = msg.NodeTitle + if msg.DataMessage.Last { + extra["is_finish"] = true + } + return extra + }(msg), + }, nil) + } +} + func (r *AgentRunner) PreHandlerReq(ctx context.Context, req *AgentRequest) *AgentRequest { req.Input = r.preHandlerInput(req.Input) req.History = r.preHandlerHistory(req.History) diff --git a/backend/domain/agent/singleagent/internal/agentflow/callback_reply_chunk.go b/backend/domain/agent/singleagent/internal/agentflow/callback_reply_chunk.go index ca0ebb81..a2163731 100644 --- a/backend/domain/agent/singleagent/internal/agentflow/callback_reply_chunk.go +++ b/backend/domain/agent/singleagent/internal/agentflow/callback_reply_chunk.go @@ -267,7 +267,6 @@ func convInterruptEventType(interruptEvent any) singleagent.InterruptEventType { } func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) { - defer output.Close() var toolsMsgChunks [][]*schema.Message var sr *schema.StreamReader[*schema.Message] var sw *schema.StreamWriter[*schema.Message] @@ -280,7 +279,6 @@ func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output * returnDirectToolsMap := make(map[int]bool) isReturnDirectToolsFirstCheck := true isToolsMsgChunksInit := false - for { cbOut, err := output.Recv() if errors.Is(err, io.EOF) { @@ -318,8 +316,8 @@ func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output * if !streamInitialized { sr, sw = schema.Pipe[*schema.Message](5) r.sw.Send(&entity.AgentEvent{ - EventType: singleagent.EventTypeOfToolsAsChatModelStream, - ChatModelAnswer: sr, + EventType: singleagent.EventTypeOfToolsAsChatModelStream, + ToolAsChatModelAnswer: sr, }, nil) streamInitialized = true } diff --git a/backend/domain/conversation/agentrun/entity/run_record.go b/backend/domain/conversation/agentrun/entity/run_record.go index f03fe8a0..e9e6f769 100644 --- a/backend/domain/conversation/agentrun/entity/run_record.go +++ b/backend/domain/conversation/agentrun/entity/run_record.go @@ -141,15 +141,17 @@ type AgentRunResponse struct { } type AgentRespEvent struct { - EventType message.MessageType + EventType message.MessageType `json:"event_type"` - ModelAnswer *schema.StreamReader[*schema.Message] - ToolsMessage []*schema.Message - FuncCall *schema.Message - Suggest *schema.Message - Knowledge []*schema.Document - Interrupt *singleagent.InterruptInfo - Err error + ToolMidAnswer *schema.StreamReader[*schema.Message] + ToolAsAnswer *schema.StreamReader[*schema.Message] + ModelAnswer *schema.StreamReader[*schema.Message] + ToolsMessage []*schema.Message + FuncCall *schema.Message + Suggest *schema.Message + Knowledge []*schema.Document + Interrupt *singleagent.InterruptInfo + Err error } type ModelAnswerEvent struct { diff --git a/backend/domain/conversation/agentrun/service/agent_run_impl.go b/backend/domain/conversation/agentrun/service/agent_run_impl.go index 0a974ce7..f137b83f 100644 --- a/backend/domain/conversation/agentrun/service/agent_run_impl.go +++ b/backend/domain/conversation/agentrun/service/agent_run_impl.go @@ -203,8 +203,12 @@ func transformEventMap(eventType singleagent.EventType) (message.MessageType, er return message.MessageTypeKnowledge, nil case singleagent.EventTypeOfToolsMessage: return message.MessageTypeToolResponse, nil - case singleagent.EventTypeOfChatModelAnswer, singleagent.EventTypeOfToolsAsChatModelStream: + case singleagent.EventTypeOfChatModelAnswer: return message.MessageTypeAnswer, nil + case singleagent.EventTypeOfToolsAsChatModelStream: + return message.MessageTypeToolAsAnswer, nil + case singleagent.EventTypeOfToolMidAnswer: + return message.MessageTypeToolMidAnswer, nil case singleagent.EventTypeOfSuggest: return message.MessageTypeFlowUp, nil case singleagent.EventTypeOfInterrupt: @@ -241,12 +245,12 @@ func (c *runImpl) buildAgentMessage2Create(ctx context.Context, chunk *entity.Ag buildExt = arm.Ext msg.DisplayContent = arm.DisplayContent - case message.MessageTypeAnswer: + case message.MessageTypeAnswer, message.MessageTypeToolAsAnswer: msg.Role = schema.Assistant msg.ContentType = message.ContentTypeText case message.MessageTypeToolResponse: - msg.Role = schema.Tool + msg.Role = schema.Assistant msg.ContentType = message.ContentTypeText msg.Content = chunk.ToolsMessage[0].Content @@ -261,7 +265,7 @@ func (c *runImpl) buildAgentMessage2Create(ctx context.Context, chunk *entity.Ag msg.Role = schema.Assistant msg.ContentType = message.ContentTypeText - knowledgeContent := c.buildKnowledge(ctx, arm, chunk) + knowledgeContent := c.buildKnowledge(ctx, chunk) if knowledgeContent != nil { knInfo, err := json.Marshal(knowledgeContent) if err == nil { @@ -461,6 +465,9 @@ func (c *runImpl) pull(_ context.Context, mainChan chan *entity.AgentRespEvent, Knowledge: rm.Knowledge, Suggest: rm.Suggest, Interrupt: rm.Interrupt, + + ToolMidAnswer: rm.ToolMidAnswer, + ToolAsAnswer: rm.ToolAsChatModelAnswer, } mainChan <- respChunk @@ -478,9 +485,12 @@ func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent }() reasoningContent := bytes.NewBuffer([]byte{}) - var createPreMsg = true - var preFinalAnswerMsg *msgEntity.Message + var firstAnswerMsg *msgEntity.Message + var reasoningMsg *msgEntity.Message + isSendFinishAnswer := false + var preToolResponseMsg *msgEntity.Message + toolResponseMsgContent := bytes.NewBuffer([]byte{}) for { chunk, ok := <-mainChan if !ok || chunk == nil { @@ -489,6 +499,19 @@ func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent logs.CtxInfof(ctx, "hanlder event:%v,err:%v", conv.DebugJsonToStr(chunk), chunk.Err) if chunk.Err != nil { if errors.Is(chunk.Err, io.EOF) { + if !isSendFinishAnswer { + isSendFinishAnswer = true + if firstAnswerMsg != nil && len(reasoningContent.String()) > 0 { + c.saveReasoningContent(ctx, firstAnswerMsg, reasoningContent.String()) + reasoningContent.Reset() + } + + finishErr := c.handlerFinalAnswerFinish(ctx, sw, rtDependence) + if finishErr != nil { + err = finishErr + return + } + } return } c.handlerErr(ctx, chunk.Err, sw) @@ -501,45 +524,156 @@ func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent if err != nil { return } + + if preToolResponseMsg == nil { + var cErr error + preToolResponseMsg, cErr = c.PreCreateAnswer(ctx, rtDependence) + if cErr != nil { + err = cErr + return + } + } case message.MessageTypeToolResponse: - err = c.handlerTooResponse(ctx, chunk, sw, rtDependence) + err = c.handlerTooResponse(ctx, chunk, sw, rtDependence, preToolResponseMsg, toolResponseMsgContent.String()) if err != nil { return } + preToolResponseMsg = nil // reset case message.MessageTypeKnowledge: err = c.handlerKnowledge(ctx, chunk, sw, rtDependence) if err != nil { return } + case message.MessageTypeToolMidAnswer: + fullMidAnswerContent := bytes.NewBuffer([]byte{}) + var usage *msgEntity.UsageExt + toolMidAnswerMsg, cErr := c.PreCreateAnswer(ctx, rtDependence) + + if cErr != nil { + err = cErr + return + } + + var preMsgIsFinish = false + for { + streamMsg, receErr := chunk.ToolMidAnswer.Recv() + if receErr != nil { + if errors.Is(receErr, io.EOF) { + break + } + err = receErr + return + } + if preMsgIsFinish { + toolMidAnswerMsg, cErr = c.PreCreateAnswer(ctx, rtDependence) + if cErr != nil { + err = cErr + return + } + preMsgIsFinish = false + } + if streamMsg == nil { + continue + } + if firstAnswerMsg == nil && len(streamMsg.Content) > 0 { + if reasoningMsg != nil { + toolMidAnswerMsg = deepcopy.Copy(reasoningMsg).(*msgEntity.Message) + } + firstAnswerMsg = deepcopy.Copy(toolMidAnswerMsg).(*msgEntity.Message) + } + + if streamMsg.Extra != nil { + if val, ok := streamMsg.Extra["workflow_node_name"]; ok && val != nil { + toolMidAnswerMsg.Ext["message_title"] = val.(string) + } + } + + sendMidAnswerMsg := c.buildSendMsg(ctx, toolMidAnswerMsg, false, rtDependence) + sendMidAnswerMsg.Content = streamMsg.Content + toolResponseMsgContent.WriteString(streamMsg.Content) + fullMidAnswerContent.WriteString(streamMsg.Content) + + c.runEvent.SendMsgEvent(entity.RunEventMessageDelta, sendMidAnswerMsg, sw) + + if streamMsg != nil && streamMsg.ResponseMeta != nil { + usage = c.handlerUsage(streamMsg.ResponseMeta) + } + + if streamMsg.Extra["is_finish"] == true { + preMsgIsFinish = true + sendMidAnswerMsg := c.buildSendMsg(ctx, toolMidAnswerMsg, false, rtDependence) + sendMidAnswerMsg.Content = fullMidAnswerContent.String() + fullMidAnswerContent.Reset() + hfErr := c.handlerAnswer(ctx, sendMidAnswerMsg, sw, usage, rtDependence, toolMidAnswerMsg) + if hfErr != nil { + err = hfErr + return + } + } + } + + case message.MessageTypeToolAsAnswer: + var usage *msgEntity.UsageExt + fullContent := bytes.NewBuffer([]byte{}) + toolAsAnswerMsg, cErr := c.PreCreateAnswer(ctx, rtDependence) + if cErr != nil { + err = cErr + return + } + if firstAnswerMsg == nil { + firstAnswerMsg = toolAsAnswerMsg + } + + for { + streamMsg, receErr := chunk.ToolAsAnswer.Recv() + if receErr != nil { + if errors.Is(receErr, io.EOF) { + + answer := c.buildSendMsg(ctx, toolAsAnswerMsg, false, rtDependence) + answer.Content = fullContent.String() + hfErr := c.handlerAnswer(ctx, answer, sw, usage, rtDependence, toolAsAnswerMsg) + if hfErr != nil { + err = hfErr + return + } + break + } + err = receErr + return + } + + if streamMsg != nil && streamMsg.ResponseMeta != nil { + usage = c.handlerUsage(streamMsg.ResponseMeta) + } + sendMsg := c.buildSendMsg(ctx, toolAsAnswerMsg, false, rtDependence) + fullContent.WriteString(streamMsg.Content) + sendMsg.Content = streamMsg.Content + c.runEvent.SendMsgEvent(entity.RunEventMessageDelta, sendMsg, sw) + } + case message.MessageTypeAnswer: fullContent := bytes.NewBuffer([]byte{}) var usage *msgEntity.UsageExt var isToolCalls = false - + var modelAnswerMsg *msgEntity.Message for { streamMsg, receErr := chunk.ModelAnswer.Recv() - if receErr != nil { if errors.Is(receErr, io.EOF) { if isToolCalls { break } - - finalAnswer := c.buildSendMsg(ctx, preFinalAnswerMsg, false, rtDependence) - - finalAnswer.Content = fullContent.String() - finalAnswer.ReasoningContent = ptr.Of(reasoningContent.String()) - hfErr := c.handlerFinalAnswer(ctx, finalAnswer, sw, usage, rtDependence, preFinalAnswerMsg) + if modelAnswerMsg == nil { + break + } + answer := c.buildSendMsg(ctx, modelAnswerMsg, false, rtDependence) + answer.Content = fullContent.String() + hfErr := c.handlerAnswer(ctx, answer, sw, usage, rtDependence, modelAnswerMsg) if hfErr != nil { err = hfErr return } - finishErr := c.handlerFinalAnswerFinish(ctx, sw, rtDependence) - if finishErr != nil { - err = finishErr - return - } break } err = receErr @@ -557,32 +691,64 @@ func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent if streamMsg != nil && len(streamMsg.ReasoningContent) == 0 && len(streamMsg.Content) == 0 { continue } - if createPreMsg && (len(streamMsg.ReasoningContent) > 0 || len(streamMsg.Content) > 0) { - preFinalAnswerMsg, err = c.PreCreateFinalAnswer(ctx, rtDependence) - if err != nil { - return + + if len(streamMsg.ReasoningContent) > 0 { + if reasoningMsg == nil { + reasoningMsg, err = c.PreCreateAnswer(ctx, rtDependence) + if err != nil { + return + } } - createPreMsg = false + + sendReasoningMsg := c.buildSendMsg(ctx, reasoningMsg, false, rtDependence) + reasoningContent.WriteString(streamMsg.ReasoningContent) + sendReasoningMsg.ReasoningContent = ptr.Of(streamMsg.ReasoningContent) + c.runEvent.SendMsgEvent(entity.RunEventMessageDelta, sendReasoningMsg, sw) } + if len(streamMsg.Content) > 0 { - sendMsg := c.buildSendMsg(ctx, preFinalAnswerMsg, false, rtDependence) - reasoningContent.WriteString(streamMsg.ReasoningContent) - sendMsg.ReasoningContent = ptr.Of(streamMsg.ReasoningContent) + if modelAnswerMsg == nil { + modelAnswerMsg, err = c.PreCreateAnswer(ctx, rtDependence) + if err != nil { + return + } + if firstAnswerMsg == nil { + if reasoningMsg != nil { + modelAnswerMsg.ID = reasoningMsg.ID + } + firstAnswerMsg = modelAnswerMsg + } + } - fullContent.WriteString(streamMsg.Content) - sendMsg.Content = streamMsg.Content - - c.runEvent.SendMsgEvent(entity.RunEventMessageDelta, sendMsg, sw) + sendAnswerMsg := c.buildSendMsg(ctx, modelAnswerMsg, false, rtDependence) + fullContent.WriteString(streamMsg.Content) + sendAnswerMsg.Content = streamMsg.Content + c.runEvent.SendMsgEvent(entity.RunEventMessageDelta, sendAnswerMsg, sw) + } } case message.MessageTypeFlowUp: + if isSendFinishAnswer { + + if firstAnswerMsg != nil && len(reasoningContent.String()) > 0 { + c.saveReasoningContent(ctx, firstAnswerMsg, reasoningContent.String()) + } + + isSendFinishAnswer = true + finishErr := c.handlerFinalAnswerFinish(ctx, sw, rtDependence) + if finishErr != nil { + err = finishErr + return + } + } + err = c.handlerSuggest(ctx, chunk, sw, rtDependence) if err != nil { return } case message.MessageTypeInterrupt: - err = c.handlerInterrupt(ctx, chunk, sw, rtDependence) + err = c.handlerInterrupt(ctx, chunk, sw, rtDependence, firstAnswerMsg, reasoningContent.String()) if err != nil { return } @@ -590,12 +756,22 @@ func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent } } -func (c *runImpl) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) error { +func (c *runImpl) saveReasoningContent(ctx context.Context, firstAnswerMsg *msgEntity.Message, reasoningContent string) { + _, err := crossmessage.DefaultSVC().Edit(ctx, &message.Message{ + ID: firstAnswerMsg.ID, + ReasoningContent: reasoningContent, + }) + if err != nil { + logs.CtxInfof(ctx, "save reasoning content failed, err: %v", err) + } +} + +func (c *runImpl) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence, firstAnswerMsg *msgEntity.Message, reasoningCOntent string) error { interruptData, cType, err := c.parseInterruptData(ctx, chunk.Interrupt) if err != nil { return err } - preMsg, err := c.PreCreateFinalAnswer(ctx, rtDependence) + preMsg, err := c.PreCreateAnswer(ctx, rtDependence) if err != nil { return err } @@ -616,8 +792,10 @@ func (c *runImpl) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespE c.runEvent.SendMsgEvent(entity.RunEventMessageDelta, deltaAnswer, sw) finalAnswer := deepcopy.Copy(deltaAnswer).(*entity.ChunkMessageItem) - - err = c.handlerFinalAnswer(ctx, finalAnswer, sw, nil, rtDependence, preMsg) + if len(reasoningCOntent) > 0 && firstAnswerMsg == nil { + finalAnswer.ReasoningContent = ptr.Of(reasoningCOntent) + } + err = c.handlerAnswer(ctx, finalAnswer, sw, nil, rtDependence, preMsg) if err != nil { return err } @@ -626,11 +804,6 @@ func (c *runImpl) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespE if err != nil { return err } - - err = c.handlerFinalAnswerFinish(ctx, sw, rtDependence) - if err != nil { - return err - } return nil } @@ -733,7 +906,7 @@ func (c *runImpl) handlerErr(_ context.Context, err error, sw *schema.StreamWrit }) } -func (c *runImpl) PreCreateFinalAnswer(ctx context.Context, rtDependence *runtimeDependence) (*msgEntity.Message, error) { +func (c *runImpl) PreCreateAnswer(ctx context.Context, rtDependence *runtimeDependence) (*msgEntity.Message, error) { arm := rtDependence.runMeta msgMeta := &msgEntity.Message{ ConversationID: arm.ConversationID, @@ -765,7 +938,7 @@ func (c *runImpl) PreCreateFinalAnswer(ctx context.Context, rtDependence *runtim return crossmessage.DefaultSVC().PreCreate(ctx, msgMeta) } -func (c *runImpl) handlerFinalAnswer(ctx context.Context, msg *entity.ChunkMessageItem, sw *schema.StreamWriter[*entity.AgentRunResponse], usage *msgEntity.UsageExt, rtDependence *runtimeDependence, preFinalAnswerMsg *msgEntity.Message) error { +func (c *runImpl) handlerAnswer(ctx context.Context, msg *entity.ChunkMessageItem, sw *schema.StreamWriter[*entity.AgentRunResponse], usage *msgEntity.UsageExt, rtDependence *runtimeDependence, preAnswerMsg *msgEntity.Message) error { if len(msg.Content) == 0 && len(ptr.From(msg.ReasoningContent)) == 0 { return nil @@ -801,12 +974,15 @@ func (c *runImpl) handlerFinalAnswer(ctx context.Context, msg *entity.ChunkMessa if err != nil { return err } - preFinalAnswerMsg.Content = msg.Content - preFinalAnswerMsg.ReasoningContent = ptr.From(msg.ReasoningContent) - preFinalAnswerMsg.Ext = msg.Ext - preFinalAnswerMsg.ContentType = msg.ContentType - preFinalAnswerMsg.ModelContent = string(mc) - _, err = crossmessage.DefaultSVC().Create(ctx, preFinalAnswerMsg) + preAnswerMsg.Content = msg.Content + preAnswerMsg.ReasoningContent = ptr.From(msg.ReasoningContent) + preAnswerMsg.Ext = msg.Ext + preAnswerMsg.ContentType = msg.ContentType + preAnswerMsg.ModelContent = string(mc) + preAnswerMsg.CreatedAt = 0 + preAnswerMsg.UpdatedAt = 0 + + _, err = crossmessage.DefaultSVC().Create(ctx, preAnswerMsg) if err != nil { return err } @@ -860,9 +1036,23 @@ func (c *runImpl) handlerAckMessage(_ context.Context, input *msgEntity.Message, return nil } -func (c *runImpl) handlerTooResponse(ctx context.Context, chunk *entity.AgentRespEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) error { +func (c *runImpl) handlerTooResponse(ctx context.Context, chunk *entity.AgentRespEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence, preToolResponseMsg *msgEntity.Message, toolResponseMsgContent string) error { + cm := c.buildAgentMessage2Create(ctx, chunk, message.MessageTypeToolResponse, rtDependence) - cmData, err := crossmessage.DefaultSVC().Create(ctx, cm) + + var cmData *message.Message + var err error + + if preToolResponseMsg != nil { + cm.ID = preToolResponseMsg.ID + cm.CreatedAt = preToolResponseMsg.CreatedAt + cm.UpdatedAt = preToolResponseMsg.UpdatedAt + if len(toolResponseMsgContent) > 0 { + cm.Content = toolResponseMsgContent + "\n" + cm.Content + } + } + + cmData, err = crossmessage.DefaultSVC().Create(ctx, cm) if err != nil { return err } @@ -902,7 +1092,7 @@ func (c *runImpl) handlerKnowledge(ctx context.Context, chunk *entity.AgentRespE return nil } -func (c *runImpl) buildKnowledge(_ context.Context, arm *entity.AgentRunMeta, chunk *entity.AgentRespEvent) *msgEntity.VerboseInfo { +func (c *runImpl) buildKnowledge(_ context.Context, chunk *entity.AgentRespEvent) *msgEntity.VerboseInfo { var recallDatas []msgEntity.RecallDataInfo for _, kOne := range chunk.Knowledge { recallDatas = append(recallDatas, msgEntity.RecallDataInfo{ diff --git a/backend/domain/conversation/message/internal/dal/message.go b/backend/domain/conversation/message/internal/dal/message.go index f1446e92..d1842ee9 100644 --- a/backend/domain/conversation/message/internal/dal/message.go +++ b/backend/domain/conversation/message/internal/dal/message.go @@ -241,6 +241,12 @@ func (dao *MessageDAO) messageDO2PO(ctx context.Context, msgDo *entity.Message) UpdatedAt: time.Now().UnixMilli(), ReasoningContent: msgDo.ReasoningContent, } + if msgDo.CreatedAt > 0 { + msgPO.CreatedAt = msgDo.CreatedAt + } + if msgDo.UpdatedAt > 0 { + msgPO.UpdatedAt = msgDo.UpdatedAt + } if msgDo.ModelContent != "" { msgPO.ModelContent = msgDo.ModelContent diff --git a/backend/domain/workflow/interface.go b/backend/domain/workflow/interface.go index b7a91097..c76ccb39 100644 --- a/backend/domain/workflow/interface.go +++ b/backend/domain/workflow/interface.go @@ -21,6 +21,7 @@ import ( "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/compose" + "github.com/coze-dev/coze-studio/backend/api/model/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" diff --git a/backend/infra/impl/embedding/http/http.go b/backend/infra/impl/embedding/http/http.go index 36d7c70a..f9645b22 100644 --- a/backend/infra/impl/embedding/http/http.go +++ b/backend/infra/impl/embedding/http/http.go @@ -28,6 +28,7 @@ import ( "time" opt "github.com/cloudwego/eino/components/embedding" + "github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/lang/slices"