From 87df1f80ebfbb8e577f2bdec07c61684008e67b8 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Thu, 16 Oct 2025 16:36:12 +0800 Subject: [PATCH 01/65] feat: define AgenticModel component interface --- components/agency/callback_extra.go | 101 ++++++++ components/agency/interface.go | 29 +++ components/agency/option.go | 77 ++++++ components/types.go | 1 + go.mod | 3 +- go.sum | 7 +- schema/agentic_message.go | 374 ++++++++++++++++++++++++++++ schema/anthropic/citation.go | 49 ++++ schema/anthropic/types.go | 10 + schema/google/candidate_meta.go | 66 +++++ schema/message.go | 4 +- schema/openai/annotation.go | 55 ++++ schema/openai/types.go | 10 + 13 files changed, 780 insertions(+), 6 deletions(-) create mode 100644 components/agency/callback_extra.go create mode 100644 components/agency/interface.go create mode 100644 components/agency/option.go create mode 100644 schema/agentic_message.go create mode 100644 schema/anthropic/citation.go create mode 100644 schema/anthropic/types.go create mode 100644 schema/google/candidate_meta.go create mode 100644 schema/openai/annotation.go create mode 100644 schema/openai/types.go diff --git a/components/agency/callback_extra.go b/components/agency/callback_extra.go new file mode 100644 index 000000000..984756d1a --- /dev/null +++ b/components/agency/callback_extra.go @@ -0,0 +1,101 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package agency + +import ( + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" +) + +// TokenUsageMeta is the token usage for the model. +type TokenUsageMeta struct { + InputTokens int64 `json:"input_tokens"` + InputTokensDetails InputTokensUsageDetails `json:"input_tokens_details"` + OutputTokens int64 `json:"output_tokens"` + OutputTokensDetails OutputTokensUsageDetails `json:"output_tokens_details"` + TotalTokens int64 `json:"total_tokens"` +} + +type InputTokensUsageDetails struct { + CachedTokens int64 `json:"cached_tokens"` +} + +type OutputTokensUsageDetails struct { + ReasoningTokens int64 `json:"reasoning_tokens"` +} + +// Config is the config for the model. +type Config struct { + // Model is the model name. + Model string + // Temperature is the temperature, which controls the randomness of the model. + Temperature float32 + // TopP is the top p, which controls the diversity of the model. + TopP float32 +} + +// CallbackInput is the input for the model callback. +type CallbackInput struct { + // Responses is the responses to be sent to the model. + Responses []*schema.AgenticMessage + // Tools is the tools to be used in the model. + Tools []*schema.ToolInfo + // Config is the config for the model. + Config *Config + // Extra is the extra information for the callback. + Extra map[string]any +} + +// CallbackOutput is the output for the model callback. +type CallbackOutput struct { + // Response is the response generated by the model. + Response *schema.AgenticMessage + // Config is the config for the model. + Config *Config + // Usage is the token usage of this request. + Usage *TokenUsageMeta + // Extra is the extra information for the callback. + Extra map[string]any +} + +// ConvCallbackInput converts the callback input to the model callback input. +func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { + switch t := src.(type) { + case *CallbackInput: // when callback is triggered within component implementation, the input is usually already a typed *model.CallbackInput + return t + case []*schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the input is the input of Chat Model interface, which is []*schema.AgenticMessage + return &CallbackInput{ + Responses: t, + } + default: + return nil + } +} + +// ConvCallbackOutput converts the callback output to the model callback output. +func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { + switch t := src.(type) { + case *CallbackOutput: // when callback is triggered within component implementation, the output is usually already a typed *model.CallbackOutput + return t + case *schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the output is the output of Chat Model interface, which is *schema.AgenticMessage + return &CallbackOutput{ + Response: t, + } + default: + return nil + } +} diff --git a/components/agency/interface.go b/components/agency/interface.go new file mode 100644 index 000000000..e33d6a933 --- /dev/null +++ b/components/agency/interface.go @@ -0,0 +1,29 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package agency + +import ( + "context" + + "github.com/cloudwego/eino/schema" +) + +type AgenticModel interface { + Generate(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.AgenticMessage, error) + Stream(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.StreamReader[*schema.AgenticMessage], error) + WithTools(tools []*schema.ToolInfo) (AgenticModel, error) +} diff --git a/components/agency/option.go b/components/agency/option.go new file mode 100644 index 000000000..17028f6e9 --- /dev/null +++ b/components/agency/option.go @@ -0,0 +1,77 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package agency + +// Options is the common options for the model. +type Options struct { +} + +// Option is the call option for ChatModel component. +type Option struct { + apply func(opts *Options) + + implSpecificOptFn any +} + +// WrapImplSpecificOptFn is the option to wrap the implementation specific option function. +func WrapImplSpecificOptFn[T any](optFn func(*T)) Option { + return Option{ + implSpecificOptFn: optFn, + } +} + +// GetCommonOptions extract model Options from Option list, optionally providing a base Options with default values. +func GetCommonOptions(base *Options, opts ...Option) *Options { + if base == nil { + base = &Options{} + } + + for i := range opts { + opt := opts[i] + if opt.apply != nil { + opt.apply(base) + } + } + + return base +} + +// GetImplSpecificOptions extract the implementation specific options from Option list, optionally providing a base options with default values. +// e.g. +// +// myOption := &MyOption{ +// Field1: "default_value", +// } +// +// myOption := model.GetImplSpecificOptions(myOption, opts...) +func GetImplSpecificOptions[T any](base *T, opts ...Option) *T { + if base == nil { + base = new(T) + } + + for i := range opts { + opt := opts[i] + if opt.implSpecificOptFn != nil { + optFn, ok := opt.implSpecificOptFn.(func(*T)) + if ok { + optFn(base) + } + } + } + + return base +} diff --git a/components/types.go b/components/types.go index a546ae59f..a23d82a68 100644 --- a/components/types.go +++ b/components/types.go @@ -68,6 +68,7 @@ const ( ComponentOfPrompt Component = "ChatTemplate" // ComponentOfChatModel identifies chat model components. ComponentOfChatModel Component = "ChatModel" + ComponentOfAgenticModel Component = "AgenticModel" // ComponentOfEmbedding identifies embedding components. ComponentOfEmbedding Component = "Embedding" // ComponentOfIndexer identifies indexer components. diff --git a/go.mod b/go.mod index cfa6957cc..0b87a6cab 100644 --- a/go.mod +++ b/go.mod @@ -41,6 +41,7 @@ require ( github.com/yargevad/filepathx v1.0.0 // indirect golang.org/x/arch v0.11.0 // indirect golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect - golang.org/x/sys v0.26.0 // indirect + golang.org/x/sys v0.29.0 // indirect + golang.org/x/term v0.28.0 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect ) diff --git a/go.sum b/go.sum index a80d6399b..5813766b2 100644 --- a/go.sum +++ b/go.sum @@ -117,9 +117,10 @@ golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= +golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= diff --git a/schema/agentic_message.go b/schema/agentic_message.go new file mode 100644 index 000000000..e386ed044 --- /dev/null +++ b/schema/agentic_message.go @@ -0,0 +1,374 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package schema + +import ( + "github.com/cloudwego/eino/schema/anthropic" + "github.com/cloudwego/eino/schema/google" + "github.com/cloudwego/eino/schema/openai" + "github.com/eino-contrib/jsonschema" +) + +type ContentBlockType string + +const ( + ContentBlockTypeReasoning ContentBlockType = "reasoning" + ContentBlockTypeUserInputText ContentBlockType = "user_input_text" + ContentBlockTypeUserInputImage ContentBlockType = "user_input_image" + ContentBlockTypeUserInputAudio ContentBlockType = "user_input_audio" + ContentBlockTypeUserInputVideo ContentBlockType = "user_input_video" + ContentBlockTypeUserInputFile ContentBlockType = "user_input_file" + ContentBlockTypeAssistantGenText ContentBlockType = "assistant_gen_text" + ContentBlockTypeAssistantGenImage ContentBlockType = "assistant_gen_image" + ContentBlockTypeAssistantGenAudio ContentBlockType = "assistant_gen_audio" + ContentBlockTypeAssistantGenVideo ContentBlockType = "assistant_gen_video" + ContentBlockTypeFunctionToolCall ContentBlockType = "function_tool_call" + ContentBlockTypeFunctionToolResult ContentBlockType = "function_tool_result" + ContentBlockTypeServerToolCall ContentBlockType = "server_tool_call" + ContentBlockTypeServerToolResult ContentBlockType = "server_tool_result" + ContentBlockTypeMCPToolCall ContentBlockType = "mcp_tool_call" + ContentBlockTypeMCPToolResult ContentBlockType = "mcp_tool_result" + ContentBlockTypeMCPListTools ContentBlockType = "mcp_list_tools" + ContentBlockTypeMCPToolApprovalRequest ContentBlockType = "mcp_tool_approval_request" + ContentBlockTypeMCPToolApprovalResponse ContentBlockType = "mcp_tool_approval_response" +) + +type AgenticRoleType string + +const ( + AgenticRoleTypeDeveloper AgenticRoleType = "developer" + AgenticRoleTypeSystem AgenticRoleType = "system" + AgenticRoleTypeUser AgenticRoleType = "user" + AgenticRoleTypeAssistant AgenticRoleType = "assistant" +) + +type AgenticMessage struct { + ResponseMeta *AgenticResponseMeta + + Role AgenticRoleType + ContentBlocks []*ContentBlock + + Extra map[string]any +} + +type AgenticResponseMeta struct { + Status *string + FinishReason string + + TokenUsage *TokenUsage + + GoogleAdditionalMeta *google.CandidateMeta +} + +type StreamMeta struct { + // Index is used for streaming to identify the chunk of the block for concatenation. + Index *int + // Streaming phase of the content block. + Phase StreamPhase +} + +type ContentBlock struct { + Type ContentBlockType + + Reasoning *Reasoning + + UserInputText *UserInputText + UserInputImage *UserInputImage + UserInputAudio *UserInputAudio + UserInputVideo *UserInputVideo + UserInputFile *UserInputFile + + AssistantGenText *AssistantGenText + AssistantGenImage *AssistantGenImage + AssistantGenAudio *AssistantGenAudio + AssistantGenVideo *AssistantGenVideo + + // FunctionToolCall holds invocation details for a user-defined tool. + FunctionToolCall *FunctionToolCall + // FunctionToolResult is the result from a user-defined tool call. + FunctionToolResult *FunctionToolResult + // ServerToolCall holds invocation details for a provider built-in tool run on the model server. + ServerToolCall *ServerToolCall + // ServerToolResult is the result from a provider built-in tool run on the model server. + ServerToolResult *ServerToolResult + + // MCPToolCall holds invocation details for an MCP tool managed by the model server. + MCPToolCall *MCPToolCall + // MCPToolResult is the result from an MCP tool managed by the model server. + MCPToolResult *MCPToolResult + // MCPListToolsResult lists available MCP tools reported by the model server. + MCPListToolsResult *MCPListToolsResult + // MCPToolApprovalRequest requests user approval for an MCP tool call when required. + MCPToolApprovalRequest *MCPToolApprovalRequest + // MCPToolApprovalResponse records the user's approval decision for an MCP tool call. + MCPToolApprovalResponse *MCPToolApprovalResponse + + StreamMeta *StreamMeta +} + +type StreamPhase string + +const ( + StreamPhaseStart StreamPhase = "start" + StreamPhaseDelta StreamPhase = "delta" + StreamPhaseStop StreamPhase = "stop" +) + +type UserInputText struct { + Text string + + // Extra stores additional information. + Extra map[string]any +} + +type UserInputImage struct { + URL *string + Base64Data *string + MIMEType string + Detail ImageURLDetail + + // Extra stores additional information. + Extra map[string]any +} + +type UserInputAudio struct { + URL *string + Base64Data *string + MIMEType string + + // Extra stores additional information. + Extra map[string]any +} + +type UserInputVideo struct { + URL *string + Base64Data *string + MIMEType string + + // Extra stores additional information. + Extra map[string]any +} + +type UserInputFile struct { + URL *string + Name *string + Base64Data *string + MIMEType string + + // Extra stores additional information. + Extra map[string]any +} + +type AssistantGenText struct { + Text string + + OpenAIAnnotations []*openai.TextAnnotation + AnthropicCitations []*anthropic.TextCitation + + // Extra stores additional information. + Extra map[string]any +} + +type AssistantGenImage struct { + URL *string + Base64Data *string + MIMEType string + + // Extra stores additional information. + Extra map[string]any +} + +type AssistantGenAudio struct { + URL *string + Base64Data *string + MIMEType string + + // Extra stores additional information. + Extra map[string]any +} + +type AssistantGenVideo struct { + URL *string + Base64Data *string + MIMEType string + + // Extra stores additional information. + Extra map[string]any +} + +type Reasoning struct { + // Summary is the reasoning content summary. + Summary []*ReasoningSummary + + // EncryptedContent is the encrypted reasoning content. + EncryptedContent string + + // Extra stores additional information. + Extra map[string]any +} + +type ReasoningSummary struct { + // Index specifies the ReasoningSummary chunk to be concatenated during streaming. + Index *int + + Text string +} + +type FunctionToolCall struct { + // CallID is the unique identifier for the tool call. + CallID string + + // Name specifies the function tool invoked. + Name string + + // Arguments is the JSON string arguments for the function tool call. + Arguments string + + // Extra stores additional information + Extra map[string]any +} + +type FunctionToolResult struct { + // CallID is the unique identifier for the tool call. + CallID string + + // Name specifies the function tool invoked. + Name string + + // Result is the function tool result returned by the user + Result string + + // Extra stores additional information. + Extra map[string]any +} + +type ServerToolCall struct { + // Name specifies the server-side tool invoked. + // Supplied by the model server (e.g., `web_search` for OpenAI, `googleSearch` for Gemini). + Name string + + // CallID is the unique identifier for the tool call. + // Empty if not provided by the model server. + CallID string + + // Arguments are the raw inputs to the server-side tool, + // supplied by the component implementer. + Arguments any + + // Extra stores additional information. + Extra map[string]any +} + +type ServerToolResult struct { + // Name specifies the server-side tool invoked. + // Supplied by the model server (e.g., `web_search` for OpenAI, `googleSearch` for Gemini). + Name string + + // CallID is the unique identifier for the tool call. + // Empty if not provided by the model server. + CallID string + + // Result refers to the raw output generated by the server-side tool, + // supplied by the component implementer. + Result any + + // Extra stores additional information. + Extra map[string]any +} + +type MCPToolCall struct { + // ServerLabel is the MCP server label used to identify it in tool calls + ServerLabel string + // ApprovalRequestID is the unique ID of the approval request. + ApprovalRequestID string + // CallID is the unique ID of the tool call. + CallID string + // Name is the name of the tool to run. + Name string + // Arguments is the JSON string arguments for the tool call. + Arguments string + + // Extra stores additional information. + Extra map[string]any +} + +type MCPToolResult struct { + // CallID is the unique ID of the tool call. + CallID string + // Name is the name of the tool to run. + Name string + // Result is the JSON string with the tool result. + Result string + // Error returned when the server fails to run the tool. + Error *MCPToolCallError + + // Extra stores additional information. + Extra map[string]any +} + +type MCPToolCallError struct { + Code int64 + Error string +} + +type MCPListToolsResult struct { + // ServerLabel is the MCP server label used to identify it in tool calls. + ServerLabel string + // Tools is the list of tools available on the server. + Tools []MCPListToolsItem + // Error returned when the server fails to list tools. + Error string + + // Extra stores additional information. + Extra map[string]any +} + +type MCPListToolsItem struct { + // Name is the name of the tool. + Name string + // Description is the description of the tool. + Description string + // InputSchema is the JSON schema that describes the tool input. + InputSchema *jsonschema.Schema +} + +type MCPToolApprovalRequest struct { + // CallID is the unique ID of the tool call. + CallID string + // Name is the name of the tool to run. + Name string + // Arguments is the JSON string arguments for the tool call. + Arguments string + // ServerLabel is the MCP server label used to identify it in tool calls. + ServerLabel string + + // Extra stores additional information. + Extra map[string]any +} + +type MCPToolApprovalResponse struct { + // ApprovalRequestID is the approval request ID being responded to. + ApprovalRequestID string + // Approve indicates whether the request is approved. + Approve bool + // Reason is the rationale for the decision. + // Optional. + Reason string + + // Extra stores additional information. + Extra map[string]any +} diff --git a/schema/anthropic/citation.go b/schema/anthropic/citation.go new file mode 100644 index 000000000..24a4c5aa6 --- /dev/null +++ b/schema/anthropic/citation.go @@ -0,0 +1,49 @@ +package anthropic + +type TextCitation struct { + Type TextCitationType `json:"type"` + + CharLocation *CitationCharLocation `json:"char_location,omitempty"` + PageLocation *CitationPageLocation `json:"page_location,omitempty"` + ContentBlockLocation *CitationContentBlockLocation `json:"content_block_location,omitempty"` + WebSearchResultLocation *CitationWebSearchResultLocation `json:"web_search_result_location,omitempty"` +} + +type CitationCharLocation struct { + CitedText string `json:"cited_text"` + + DocumentTitle string `json:"document_title"` + DocumentIndex int64 `json:"document_index"` + + StartCharIndex int64 `json:"start_char_index"` + EndCharIndex int64 `json:"end_char_index"` +} + +type CitationPageLocation struct { + CitedText string `json:"cited_text"` + + DocumentTitle string `json:"document_title"` + DocumentIndex int64 `json:"document_index"` + + StartPageNumber int64 `json:"start_page_number"` + EndPageNumber int64 `json:"end_page_number"` +} + +type CitationContentBlockLocation struct { + CitedText string `json:"cited_text"` + + DocumentTitle string `json:"document_title"` + DocumentIndex int64 `json:"document_index"` + + StartBlockIndex int64 `json:"start_block_index"` + EndBlockIndex int64 `json:"end_block_index"` +} + +type CitationWebSearchResultLocation struct { + CitedText string `json:"cited_text"` + + Title string `json:"title"` + URL string `json:"url"` + + EncryptedIndex string `json:"encrypted_index"` +} diff --git a/schema/anthropic/types.go b/schema/anthropic/types.go new file mode 100644 index 000000000..fbc85475d --- /dev/null +++ b/schema/anthropic/types.go @@ -0,0 +1,10 @@ +package anthropic + +type TextCitationType string + +const ( + TextCitationTypeCharLocation TextCitationType = "char_location" + TextCitationTypePageLocation TextCitationType = "page_location" + TextCitationTypeContentBlockLocation TextCitationType = "content_block_location" + TextCitationTypeWebSearchResultLocation TextCitationType = "web_search_result_location" +) diff --git a/schema/google/candidate_meta.go b/schema/google/candidate_meta.go new file mode 100644 index 000000000..aead31c2e --- /dev/null +++ b/schema/google/candidate_meta.go @@ -0,0 +1,66 @@ +package google + +type CandidateMeta struct { + GroundingMeta *GroundingMetadata `json:"grounding_meta,omitempty"` +} + +type GroundingMetadata struct { + // List of supporting references retrieved from specified grounding source. + GroundingChunks []*GroundingChunk `json:"grounding_chunks,omitempty"` + // Optional. List of grounding support. + GroundingSupports []*GroundingSupport `json:"grounding_supports,omitempty"` + // Optional. Google search entry for the following-up web searches. + SearchEntryPoint *SearchEntryPoint `json:"search_entry_point,omitempty"` + // Optional. Web search queries for the following-up web search. + WebSearchQueries []string `json:"web_search_queries,omitempty"` +} + +type GroundingChunk struct { + // Grounding chunk from the web. + Web *GroundingChunkWeb `json:"web,omitempty"` +} + +// Chunk from the web. +type GroundingChunkWeb struct { + // Domain of the (original) URI. This field is not supported in Gemini API. + Domain string `json:"domain,omitempty"` + // Title of the chunk. + Title string `json:"title,omitempty"` + // URI reference of the chunk. + URI string `json:"uri,omitempty"` +} + +type GroundingSupport struct { + // Confidence score of the support references. Ranges from 0 to 1. 1 is the most confident. + // For Gemini 2.0 and before, this list must have the same size as the grounding_chunk_indices. + // For Gemini 2.5 and after, this list will be empty and should be ignored. + ConfidenceScores []float32 `json:"confidence_scores,omitempty"` + // A list of indices (into 'grounding_chunk') specifying the citations associated with + // the claim. For instance [1,3,4] means that grounding_chunk[1], grounding_chunk[3], + // grounding_chunk[4] are the retrieved content attributed to the claim. + GroundingChunkIndices []int32 `json:"grounding_chunk_indices,omitempty"` + // Segment of the content this support belongs to. + Segment *Segment `json:"segment,omitempty"` +} + +// Segment of the content. +type Segment struct { + // Output only. End index in the given Part, measured in bytes. Offset from the start + // of the Part, exclusive, starting at zero. + EndIndex int32 `json:"end_index,omitempty"` + // Output only. The index of a Part object within its parent Content object. + PartIndex int32 `json:"part_index,omitempty"` + // Output only. Start index in the given Part, measured in bytes. Offset from the start + // of the Part, inclusive, starting at zero. + StartIndex int32 `json:"start_index,omitempty"` + // Output only. The text corresponding to the segment from the response. + Text string `json:"text,omitempty"` +} + +// Google search entry point. +type SearchEntryPoint struct { + // Optional. Web content snippet that can be embedded in a web page or an app webview. + RenderedContent string `json:"rendered_content,omitempty"` + // Optional. Base64 encoded JSON representing array of tuple. + SDKBlob []byte `json:"sdk_blob,omitempty"` +} diff --git a/schema/message.go b/schema/message.go index 3746244bb..fefb2079e 100644 --- a/schema/message.go +++ b/schema/message.go @@ -694,10 +694,10 @@ type TokenUsage struct { PromptTokenDetails PromptTokenDetails `json:"prompt_token_details"` // CompletionTokens is the number of completion tokens. CompletionTokens int `json:"completion_tokens"` + // CompletionTokenDetails is a breakdown of the completion tokens. + CompletionTokenDetails CompletionTokensDetails `json:"completion_token_details"` // TotalTokens is the total number of tokens. TotalTokens int `json:"total_tokens"` - // CompletionTokensDetails is breakdown of completion tokens. - CompletionTokensDetails CompletionTokensDetails `json:"completion_token_details"` } type CompletionTokensDetails struct { diff --git a/schema/openai/annotation.go b/schema/openai/annotation.go new file mode 100644 index 000000000..a834e8072 --- /dev/null +++ b/schema/openai/annotation.go @@ -0,0 +1,55 @@ +package openai + +type TextAnnotation struct { + Type TextAnnotationType `json:"type"` + + FileCitation *TextAnnotationFileCitation `json:"file_citation,omitempty"` + URLCitation *TextAnnotationURLCitation `json:"url_citation,omitempty"` + ContainerFileCitation *TextAnnotationContainerFileCitation `json:"container_file_citation,omitempty"` + FilePath *TextAnnotationFilePath `json:"file_path,omitempty"` +} + +type TextAnnotationFileCitation struct { + // The ID of the file. + FileID string `json:"file_id"` + // The filename of the file cited. + Filename string `json:"filename"` + + // The index of the file in the list of files. + Index int64 `json:"index"` +} + +type TextAnnotationURLCitation struct { + // The title of the web resource. + Title string `json:"title"` + // The URL of the web resource. + URL string `json:"url"` + + // The index of the first character of the URL citation in the message. + StartIndex int64 `json:"start_index"` + // The index of the last character of the URL citation in the message. + EndIndex int64 `json:"end_index"` +} + +type TextAnnotationContainerFileCitation struct { + // The ID of the container file. + ContainerID string `json:"container_id"` + + // The ID of the file. + FileID string `json:"file_id"` + // The filename of the container file cited. + Filename string `json:"filename"` + + // The index of the first character of the container file citation in the message. + StartIndex int64 `json:"start_index"` + // The index of the last character of the container file citation in the message. + EndIndex int64 `json:"end_index"` +} + +type TextAnnotationFilePath struct { + // The ID of the file. + FileID string `json:"file_id"` + + // The index of the file in the list of files. + Index int64 `json:"index"` +} diff --git a/schema/openai/types.go b/schema/openai/types.go new file mode 100644 index 000000000..60cee4361 --- /dev/null +++ b/schema/openai/types.go @@ -0,0 +1,10 @@ +package openai + +type TextAnnotationType string + +const ( + TextAnnotationTypeFileCitation TextAnnotationType = "file_citation" + TextAnnotationTypeURLCitation TextAnnotationType = "url_citation" + TextAnnotationTypeContainerFileCitation TextAnnotationType = "container_file_citation" + TextAnnotationTypeFilePath TextAnnotationType = "file_path" +) From 0728ce93f5c869070607b3866d903cb472871868 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Tue, 25 Nov 2025 11:35:01 +0800 Subject: [PATCH 02/65] feat: change the index in StreamMeta to a non-pointer (#573) --- schema/agentic_message.go | 14 ++------------ schema/anthropic/citation.go | 16 ++++++++++++++++ schema/anthropic/types.go | 16 ++++++++++++++++ schema/google/candidate_meta.go | 16 ++++++++++++++++ schema/openai/annotation.go | 16 ++++++++++++++++ schema/openai/types.go | 16 ++++++++++++++++ 6 files changed, 82 insertions(+), 12 deletions(-) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index e386ed044..84a933c9e 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -75,10 +75,8 @@ type AgenticResponseMeta struct { } type StreamMeta struct { - // Index is used for streaming to identify the chunk of the block for concatenation. - Index *int - // Streaming phase of the content block. - Phase StreamPhase + // Index is the index position of this block in the final response. + Index int } type ContentBlock struct { @@ -120,14 +118,6 @@ type ContentBlock struct { StreamMeta *StreamMeta } -type StreamPhase string - -const ( - StreamPhaseStart StreamPhase = "start" - StreamPhaseDelta StreamPhase = "delta" - StreamPhaseStop StreamPhase = "stop" -) - type UserInputText struct { Text string diff --git a/schema/anthropic/citation.go b/schema/anthropic/citation.go index 24a4c5aa6..064477688 100644 --- a/schema/anthropic/citation.go +++ b/schema/anthropic/citation.go @@ -1,3 +1,19 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package anthropic type TextCitation struct { diff --git a/schema/anthropic/types.go b/schema/anthropic/types.go index fbc85475d..cc8b1f877 100644 --- a/schema/anthropic/types.go +++ b/schema/anthropic/types.go @@ -1,3 +1,19 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package anthropic type TextCitationType string diff --git a/schema/google/candidate_meta.go b/schema/google/candidate_meta.go index aead31c2e..8cd324254 100644 --- a/schema/google/candidate_meta.go +++ b/schema/google/candidate_meta.go @@ -1,3 +1,19 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package google type CandidateMeta struct { diff --git a/schema/openai/annotation.go b/schema/openai/annotation.go index a834e8072..ad4b6b91f 100644 --- a/schema/openai/annotation.go +++ b/schema/openai/annotation.go @@ -1,3 +1,19 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package openai type TextAnnotation struct { diff --git a/schema/openai/types.go b/schema/openai/types.go index 60cee4361..321ee2a9e 100644 --- a/schema/openai/types.go +++ b/schema/openai/types.go @@ -1,3 +1,19 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package openai type TextAnnotationType string From 918c9d670137ac49a4986f57d8035cb5c5845c28 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Tue, 25 Nov 2025 21:34:26 +0800 Subject: [PATCH 03/65] feat: improve AgenticResponseMeta definition (#575) --- schema/agentic_message.go | 30 +++++++---------- schema/{anthropic => claude}/citation.go | 2 +- schema/claude/messages_meta.go | 22 +++++++++++++ schema/{anthropic => claude}/types.go | 2 +- .../response_meta.go} | 6 ++-- schema/openai/response_meta.go | 33 +++++++++++++++++++ 6 files changed, 72 insertions(+), 23 deletions(-) rename schema/{anthropic => claude}/citation.go (99%) create mode 100644 schema/claude/messages_meta.go rename schema/{anthropic => claude}/types.go (98%) rename schema/{google/candidate_meta.go => gemini/response_meta.go} (95%) create mode 100644 schema/openai/response_meta.go diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 84a933c9e..0d2b01047 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -17,8 +17,8 @@ package schema import ( - "github.com/cloudwego/eino/schema/anthropic" - "github.com/cloudwego/eino/schema/google" + "github.com/cloudwego/eino/schema/claude" + "github.com/cloudwego/eino/schema/gemini" "github.com/cloudwego/eino/schema/openai" "github.com/eino-contrib/jsonschema" ) @@ -66,16 +66,16 @@ type AgenticMessage struct { } type AgenticResponseMeta struct { - Status *string - FinishReason string - TokenUsage *TokenUsage - GoogleAdditionalMeta *google.CandidateMeta + OpenAIExtensions *openai.ResponseMeta + GeminiExtensions *gemini.ResponseMeta + ClaudeExtensions *claude.MessageMeta + Extensions any } type StreamMeta struct { - // Index is the index position of this block in the final response. + // Index specifies the index position of this block in the final response. Index int } @@ -166,8 +166,8 @@ type UserInputFile struct { type AssistantGenText struct { Text string - OpenAIAnnotations []*openai.TextAnnotation - AnthropicCitations []*anthropic.TextCitation + OpenAIAnnotations []*openai.TextAnnotation + ClaudeCitations []*claude.TextCitation // Extra stores additional information. Extra map[string]any @@ -203,7 +203,6 @@ type AssistantGenVideo struct { type Reasoning struct { // Summary is the reasoning content summary. Summary []*ReasoningSummary - // EncryptedContent is the encrypted reasoning content. EncryptedContent string @@ -212,8 +211,8 @@ type Reasoning struct { } type ReasoningSummary struct { - // Index specifies the ReasoningSummary chunk to be concatenated during streaming. - Index *int + // Index specifies the index position of this summary in the final Reasoning. + Index int Text string } @@ -221,10 +220,8 @@ type ReasoningSummary struct { type FunctionToolCall struct { // CallID is the unique identifier for the tool call. CallID string - // Name specifies the function tool invoked. Name string - // Arguments is the JSON string arguments for the function tool call. Arguments string @@ -235,10 +232,8 @@ type FunctionToolCall struct { type FunctionToolResult struct { // CallID is the unique identifier for the tool call. CallID string - // Name specifies the function tool invoked. Name string - // Result is the function tool result returned by the user Result string @@ -250,15 +245,12 @@ type ServerToolCall struct { // Name specifies the server-side tool invoked. // Supplied by the model server (e.g., `web_search` for OpenAI, `googleSearch` for Gemini). Name string - // CallID is the unique identifier for the tool call. // Empty if not provided by the model server. CallID string - // Arguments are the raw inputs to the server-side tool, // supplied by the component implementer. Arguments any - // Extra stores additional information. Extra map[string]any } diff --git a/schema/anthropic/citation.go b/schema/claude/citation.go similarity index 99% rename from schema/anthropic/citation.go rename to schema/claude/citation.go index 064477688..b5092d3bc 100644 --- a/schema/anthropic/citation.go +++ b/schema/claude/citation.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package anthropic +package claude type TextCitation struct { Type TextCitationType `json:"type"` diff --git a/schema/claude/messages_meta.go b/schema/claude/messages_meta.go new file mode 100644 index 000000000..a72dded2a --- /dev/null +++ b/schema/claude/messages_meta.go @@ -0,0 +1,22 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package claude + +type MessageMeta struct { + ID string `json:"id"` + StopReason string `json:"stop_reason"` +} diff --git a/schema/anthropic/types.go b/schema/claude/types.go similarity index 98% rename from schema/anthropic/types.go rename to schema/claude/types.go index cc8b1f877..cbf8784f6 100644 --- a/schema/anthropic/types.go +++ b/schema/claude/types.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package anthropic +package claude type TextCitationType string diff --git a/schema/google/candidate_meta.go b/schema/gemini/response_meta.go similarity index 95% rename from schema/google/candidate_meta.go rename to schema/gemini/response_meta.go index 8cd324254..3bc590a72 100644 --- a/schema/google/candidate_meta.go +++ b/schema/gemini/response_meta.go @@ -14,9 +14,11 @@ * limitations under the License. */ -package google +package gemini -type CandidateMeta struct { +type ResponseMeta struct { + ID string `json:"id"` + FinishReason string `json:"finish_reason"` GroundingMeta *GroundingMetadata `json:"grounding_meta,omitempty"` } diff --git a/schema/openai/response_meta.go b/schema/openai/response_meta.go new file mode 100644 index 000000000..1b184073e --- /dev/null +++ b/schema/openai/response_meta.go @@ -0,0 +1,33 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +type ResponseMeta struct { + ID string `json:"id"` + Status string `json:"status"` + Error *ResponseError `json:"error,omitempty"` + IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` +} + +type ResponseError struct { + Code string `json:"code"` + Message string `json:"message"` +} + +type IncompleteDetails struct { + Reason string `json:"reason"` +} From 3fc7ff0e11dd6db7258bbfc01c3a43c6ce779103 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Wed, 26 Nov 2025 16:08:32 +0800 Subject: [PATCH 04/65] feat: improve AssistantGenText definition (#577) --- schema/agentic_message.go | 5 +++-- schema/claude/{types.go => consts.go} | 0 schema/claude/{citation.go => content_block.go} | 4 ++++ schema/claude/{messages_meta.go => message_meta.go} | 0 schema/openai/{types.go => consts.go} | 0 schema/openai/{annotation.go => content_block.go} | 5 +++++ 6 files changed, 12 insertions(+), 2 deletions(-) rename schema/claude/{types.go => consts.go} (100%) rename schema/claude/{citation.go => content_block.go} (96%) rename schema/claude/{messages_meta.go => message_meta.go} (100%) rename schema/openai/{types.go => consts.go} (100%) rename schema/openai/{annotation.go => content_block.go} (94%) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 0d2b01047..09ab28e40 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -166,8 +166,9 @@ type UserInputFile struct { type AssistantGenText struct { Text string - OpenAIAnnotations []*openai.TextAnnotation - ClaudeCitations []*claude.TextCitation + OpenAIExtensions *openai.OutputText + ClaudeExtensions *claude.TextBlock + Extensions any // Extra stores additional information. Extra map[string]any diff --git a/schema/claude/types.go b/schema/claude/consts.go similarity index 100% rename from schema/claude/types.go rename to schema/claude/consts.go diff --git a/schema/claude/citation.go b/schema/claude/content_block.go similarity index 96% rename from schema/claude/citation.go rename to schema/claude/content_block.go index b5092d3bc..ba297126e 100644 --- a/schema/claude/citation.go +++ b/schema/claude/content_block.go @@ -16,6 +16,10 @@ package claude +type TextBlock struct { + Citations []*TextCitation `json:"citations"` +} + type TextCitation struct { Type TextCitationType `json:"type"` diff --git a/schema/claude/messages_meta.go b/schema/claude/message_meta.go similarity index 100% rename from schema/claude/messages_meta.go rename to schema/claude/message_meta.go diff --git a/schema/openai/types.go b/schema/openai/consts.go similarity index 100% rename from schema/openai/types.go rename to schema/openai/consts.go diff --git a/schema/openai/annotation.go b/schema/openai/content_block.go similarity index 94% rename from schema/openai/annotation.go rename to schema/openai/content_block.go index ad4b6b91f..5135964b7 100644 --- a/schema/openai/annotation.go +++ b/schema/openai/content_block.go @@ -16,6 +16,11 @@ package openai +type OutputText struct { + ItemID string `json:"item_id"` + Annotations []*TextAnnotation `json:"annotations"` +} + type TextAnnotation struct { Type TextAnnotationType `json:"type"` From 368d25d4dbdf147713016cf1c667138b6e0185f0 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Wed, 26 Nov 2025 17:52:22 +0800 Subject: [PATCH 05/65] feat: improve extension type name (#578) --- schema/agentic_message.go | 14 +++++++------- schema/claude/content_block.go | 2 +- .../claude/{message_meta.go => response_meta.go} | 2 +- schema/gemini/response_meta.go | 2 +- schema/openai/content_block.go | 2 +- schema/openai/response_meta.go | 2 +- 6 files changed, 12 insertions(+), 12 deletions(-) rename schema/claude/{message_meta.go => response_meta.go} (95%) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 09ab28e40..3530038e2 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -68,10 +68,10 @@ type AgenticMessage struct { type AgenticResponseMeta struct { TokenUsage *TokenUsage - OpenAIExtensions *openai.ResponseMeta - GeminiExtensions *gemini.ResponseMeta - ClaudeExtensions *claude.MessageMeta - Extensions any + OpenAIExtension *openai.ResponseMetaExtension + GeminiExtension *gemini.ResponseMetaExtension + ClaudeExtension *claude.ResponseMetaExtension + Extension any } type StreamMeta struct { @@ -166,9 +166,9 @@ type UserInputFile struct { type AssistantGenText struct { Text string - OpenAIExtensions *openai.OutputText - ClaudeExtensions *claude.TextBlock - Extensions any + OpenAIExtension *openai.AssistantGenTextExtension + ClaudeExtension *claude.AssistantGenTextExtension + Extension any // Extra stores additional information. Extra map[string]any diff --git a/schema/claude/content_block.go b/schema/claude/content_block.go index ba297126e..4421db807 100644 --- a/schema/claude/content_block.go +++ b/schema/claude/content_block.go @@ -16,7 +16,7 @@ package claude -type TextBlock struct { +type AssistantGenTextExtension struct { Citations []*TextCitation `json:"citations"` } diff --git a/schema/claude/message_meta.go b/schema/claude/response_meta.go similarity index 95% rename from schema/claude/message_meta.go rename to schema/claude/response_meta.go index a72dded2a..7d9dbe740 100644 --- a/schema/claude/message_meta.go +++ b/schema/claude/response_meta.go @@ -16,7 +16,7 @@ package claude -type MessageMeta struct { +type ResponseMetaExtension struct { ID string `json:"id"` StopReason string `json:"stop_reason"` } diff --git a/schema/gemini/response_meta.go b/schema/gemini/response_meta.go index 3bc590a72..bb4af92c9 100644 --- a/schema/gemini/response_meta.go +++ b/schema/gemini/response_meta.go @@ -16,7 +16,7 @@ package gemini -type ResponseMeta struct { +type ResponseMetaExtension struct { ID string `json:"id"` FinishReason string `json:"finish_reason"` GroundingMeta *GroundingMetadata `json:"grounding_meta,omitempty"` diff --git a/schema/openai/content_block.go b/schema/openai/content_block.go index 5135964b7..dfa83109a 100644 --- a/schema/openai/content_block.go +++ b/schema/openai/content_block.go @@ -16,7 +16,7 @@ package openai -type OutputText struct { +type AssistantGenTextExtension struct { ItemID string `json:"item_id"` Annotations []*TextAnnotation `json:"annotations"` } diff --git a/schema/openai/response_meta.go b/schema/openai/response_meta.go index 1b184073e..809fbb1a5 100644 --- a/schema/openai/response_meta.go +++ b/schema/openai/response_meta.go @@ -16,7 +16,7 @@ package openai -type ResponseMeta struct { +type ResponseMetaExtension struct { ID string `json:"id"` Status string `json:"status"` Error *ResponseError `json:"error,omitempty"` From d49e463dcee87e4815b52887ed70c96d0e301c59 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Wed, 26 Nov 2025 20:23:14 +0800 Subject: [PATCH 06/65] feat: modify package name (#579) --- components/{agency => agentic}/callback_extra.go | 2 +- components/{agency => agentic}/interface.go | 6 +++--- components/{agency => agentic}/option.go | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) rename components/{agency => agentic}/callback_extra.go (99%) rename components/{agency => agentic}/interface.go (89%) rename components/{agency => agentic}/option.go (99%) diff --git a/components/agency/callback_extra.go b/components/agentic/callback_extra.go similarity index 99% rename from components/agency/callback_extra.go rename to components/agentic/callback_extra.go index 984756d1a..f824750f9 100644 --- a/components/agency/callback_extra.go +++ b/components/agentic/callback_extra.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package agency +package agentic import ( "github.com/cloudwego/eino/callbacks" diff --git a/components/agency/interface.go b/components/agentic/interface.go similarity index 89% rename from components/agency/interface.go rename to components/agentic/interface.go index e33d6a933..e9960d332 100644 --- a/components/agency/interface.go +++ b/components/agentic/interface.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package agency +package agentic import ( "context" @@ -22,8 +22,8 @@ import ( "github.com/cloudwego/eino/schema" ) -type AgenticModel interface { +type Model interface { Generate(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.AgenticMessage, error) Stream(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.StreamReader[*schema.AgenticMessage], error) - WithTools(tools []*schema.ToolInfo) (AgenticModel, error) + WithTools(tools []*schema.ToolInfo) (Model, error) } diff --git a/components/agency/option.go b/components/agentic/option.go similarity index 99% rename from components/agency/option.go rename to components/agentic/option.go index 17028f6e9..b000d6893 100644 --- a/components/agency/option.go +++ b/components/agentic/option.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package agency +package agentic // Options is the common options for the model. type Options struct { From baee8c59768d7f5728ffc53ff0a9e1c54c71ee84 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Wed, 26 Nov 2025 21:06:29 +0800 Subject: [PATCH 07/65] feat: remove TokenUsage definition in CallbackOutput (#580) --- components/agentic/callback_extra.go | 31 ++++++---------------------- 1 file changed, 6 insertions(+), 25 deletions(-) diff --git a/components/agentic/callback_extra.go b/components/agentic/callback_extra.go index f824750f9..f35b37779 100644 --- a/components/agentic/callback_extra.go +++ b/components/agentic/callback_extra.go @@ -21,23 +21,6 @@ import ( "github.com/cloudwego/eino/schema" ) -// TokenUsageMeta is the token usage for the model. -type TokenUsageMeta struct { - InputTokens int64 `json:"input_tokens"` - InputTokensDetails InputTokensUsageDetails `json:"input_tokens_details"` - OutputTokens int64 `json:"output_tokens"` - OutputTokensDetails OutputTokensUsageDetails `json:"output_tokens_details"` - TotalTokens int64 `json:"total_tokens"` -} - -type InputTokensUsageDetails struct { - CachedTokens int64 `json:"cached_tokens"` -} - -type OutputTokensUsageDetails struct { - ReasoningTokens int64 `json:"reasoning_tokens"` -} - // Config is the config for the model. type Config struct { // Model is the model name. @@ -50,8 +33,8 @@ type Config struct { // CallbackInput is the input for the model callback. type CallbackInput struct { - // Responses is the responses to be sent to the model. - Responses []*schema.AgenticMessage + // Messages is the messages to be sent to the model. + Messages []*schema.AgenticMessage // Tools is the tools to be used in the model. Tools []*schema.ToolInfo // Config is the config for the model. @@ -62,12 +45,10 @@ type CallbackInput struct { // CallbackOutput is the output for the model callback. type CallbackOutput struct { - // Response is the response generated by the model. - Response *schema.AgenticMessage + // Message is the message generated by the model. + Message *schema.AgenticMessage // Config is the config for the model. Config *Config - // Usage is the token usage of this request. - Usage *TokenUsageMeta // Extra is the extra information for the callback. Extra map[string]any } @@ -79,7 +60,7 @@ func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { return t case []*schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the input is the input of Chat Model interface, which is []*schema.AgenticMessage return &CallbackInput{ - Responses: t, + Messages: t, } default: return nil @@ -93,7 +74,7 @@ func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { return t case *schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the output is the output of Chat Model interface, which is *schema.AgenticMessage return &CallbackOutput{ - Response: t, + Message: t, } default: return nil From 7dda9377a9543189a2125b759dedfa57f15db8fe Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 1 Dec 2025 11:51:09 +0800 Subject: [PATCH 08/65] feat: add helper functions for AgenticMessage (#582) --- components/agentic/option.go | 62 +++++++++++++++++ schema/agentic_message.go | 119 ++++++++++++++++++++++++++++----- schema/openai/content_block.go | 1 - 3 files changed, 164 insertions(+), 18 deletions(-) diff --git a/components/agentic/option.go b/components/agentic/option.go index b000d6893..dae6139be 100644 --- a/components/agentic/option.go +++ b/components/agentic/option.go @@ -16,8 +16,22 @@ package agentic +import ( + "github.com/cloudwego/eino/schema" +) + // Options is the common options for the model. type Options struct { + // Temperature is the temperature for the model, which controls the randomness of the model. + Temperature *float32 + // Model is the model name. + Model *string + // TopP is the top p for the model, which controls the diversity of the model. + TopP *float32 + // Tools is a list of tools the model may call. + Tools []*schema.ToolInfo + // ToolChoice controls which tool is called by the model. + ToolChoice *schema.ToolChoice } // Option is the call option for ChatModel component. @@ -27,6 +41,54 @@ type Option struct { implSpecificOptFn any } +// WithTemperature is the option to set the temperature for the model. +func WithTemperature(temperature float32) Option { + return Option{ + apply: func(opts *Options) { + opts.Temperature = &temperature + }, + } +} + +// WithModel is the option to set the model name. +func WithModel(name string) Option { + return Option{ + apply: func(opts *Options) { + opts.Model = &name + }, + } +} + +// WithTopP is the option to set the top p for the model. +func WithTopP(topP float32) Option { + return Option{ + apply: func(opts *Options) { + opts.TopP = &topP + }, + } +} + +// WithTools is the option to set tools for the model. +func WithTools(tools []*schema.ToolInfo) Option { + if tools == nil { + tools = []*schema.ToolInfo{} + } + return Option{ + apply: func(opts *Options) { + opts.Tools = tools + }, + } +} + +// WithToolChoice is the option to set tool choice for the model. +func WithToolChoice(toolChoice schema.ToolChoice) Option { + return Option{ + apply: func(opts *Options) { + opts.ToolChoice = &toolChoice + }, + } +} + // WrapImplSpecificOptFn is the option to wrap the implementation specific option function. func WrapImplSpecificOptFn[T any](optFn func(*T)) Option { return Option{ diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 3530038e2..3953fac7c 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -126,8 +126,8 @@ type UserInputText struct { } type UserInputImage struct { - URL *string - Base64Data *string + URL string + Base64Data string MIMEType string Detail ImageURLDetail @@ -136,8 +136,8 @@ type UserInputImage struct { } type UserInputAudio struct { - URL *string - Base64Data *string + URL string + Base64Data string MIMEType string // Extra stores additional information. @@ -145,8 +145,8 @@ type UserInputAudio struct { } type UserInputVideo struct { - URL *string - Base64Data *string + URL string + Base64Data string MIMEType string // Extra stores additional information. @@ -154,9 +154,9 @@ type UserInputVideo struct { } type UserInputFile struct { - URL *string - Name *string - Base64Data *string + URL string + Name string + Base64Data string MIMEType string // Extra stores additional information. @@ -175,8 +175,8 @@ type AssistantGenText struct { } type AssistantGenImage struct { - URL *string - Base64Data *string + URL string + Base64Data string MIMEType string // Extra stores additional information. @@ -184,8 +184,8 @@ type AssistantGenImage struct { } type AssistantGenAudio struct { - URL *string - Base64Data *string + URL string + Base64Data string MIMEType string // Extra stores additional information. @@ -193,8 +193,8 @@ type AssistantGenAudio struct { } type AssistantGenVideo struct { - URL *string - Base64Data *string + URL string + Base64Data string MIMEType string // Extra stores additional information. @@ -290,6 +290,8 @@ type MCPToolCall struct { } type MCPToolResult struct { + // ServerLabel is the MCP server label used to identify it in tool calls + ServerLabel string // CallID is the unique ID of the tool call. CallID string // Name is the name of the tool to run. @@ -304,7 +306,7 @@ type MCPToolResult struct { } type MCPToolCallError struct { - Code int64 + Code *int64 Error string } @@ -312,7 +314,7 @@ type MCPListToolsResult struct { // ServerLabel is the MCP server label used to identify it in tool calls. ServerLabel string // Tools is the list of tools available on the server. - Tools []MCPListToolsItem + Tools []*MCPListToolsItem // Error returned when the server fails to list tools. Error string @@ -355,3 +357,86 @@ type MCPToolApprovalResponse struct { // Extra stores additional information. Extra map[string]any } + +// DeveloperAgenticMessage represents a message with AgenticRoleType "developer". +func DeveloperAgenticMessage(text string) *AgenticMessage { + return &AgenticMessage{ + Role: AgenticRoleTypeDeveloper, + ContentBlocks: []*ContentBlock{NewContentBlock(&UserInputText{Text: text})}, + } +} + +// SystemAgenticMessage represents a message with AgenticRoleType "system". +func SystemAgenticMessage(text string) *AgenticMessage { + return &AgenticMessage{ + Role: AgenticRoleTypeSystem, + ContentBlocks: []*ContentBlock{NewContentBlock(&UserInputText{Text: text})}, + } +} + +// UserAgenticMessage represents a message with AgenticRoleType "user". +func UserAgenticMessage(text string) *AgenticMessage { + return &AgenticMessage{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{NewContentBlock(&UserInputText{Text: text})}, + } +} + +// FunctionToolResultAgenticMessage represents a function tool result message with AgenticRoleType "user". +func FunctionToolResultAgenticMessage(callID, name, result string) *AgenticMessage { + return &AgenticMessage{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + NewContentBlock(&FunctionToolResult{ + CallID: callID, + Name: name, + Result: result, + }), + }, + } +} + +func NewContentBlock(block any) *ContentBlock { + switch b := block.(type) { + case *Reasoning: + return &ContentBlock{Type: ContentBlockTypeReasoning, Reasoning: b} + case *UserInputText: + return &ContentBlock{Type: ContentBlockTypeUserInputText, UserInputText: b} + case *UserInputImage: + return &ContentBlock{Type: ContentBlockTypeUserInputImage, UserInputImage: b} + case *UserInputAudio: + return &ContentBlock{Type: ContentBlockTypeUserInputAudio, UserInputAudio: b} + case *UserInputVideo: + return &ContentBlock{Type: ContentBlockTypeUserInputVideo, UserInputVideo: b} + case *UserInputFile: + return &ContentBlock{Type: ContentBlockTypeUserInputFile, UserInputFile: b} + case *AssistantGenText: + return &ContentBlock{Type: ContentBlockTypeAssistantGenText, AssistantGenText: b} + case *AssistantGenImage: + return &ContentBlock{Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: b} + case *AssistantGenAudio: + return &ContentBlock{Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: b} + case *AssistantGenVideo: + return &ContentBlock{Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: b} + case *FunctionToolCall: + return &ContentBlock{Type: ContentBlockTypeFunctionToolCall, FunctionToolCall: b} + case *FunctionToolResult: + return &ContentBlock{Type: ContentBlockTypeFunctionToolResult, FunctionToolResult: b} + case *ServerToolCall: + return &ContentBlock{Type: ContentBlockTypeServerToolCall, ServerToolCall: b} + case *ServerToolResult: + return &ContentBlock{Type: ContentBlockTypeServerToolResult, ServerToolResult: b} + case *MCPToolCall: + return &ContentBlock{Type: ContentBlockTypeMCPToolCall, MCPToolCall: b} + case *MCPToolResult: + return &ContentBlock{Type: ContentBlockTypeMCPToolResult, MCPToolResult: b} + case *MCPListToolsResult: + return &ContentBlock{Type: ContentBlockTypeMCPListTools, MCPListToolsResult: b} + case *MCPToolApprovalRequest: + return &ContentBlock{Type: ContentBlockTypeMCPToolApprovalRequest, MCPToolApprovalRequest: b} + case *MCPToolApprovalResponse: + return &ContentBlock{Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: b} + default: + return nil + } +} diff --git a/schema/openai/content_block.go b/schema/openai/content_block.go index dfa83109a..b0408e310 100644 --- a/schema/openai/content_block.go +++ b/schema/openai/content_block.go @@ -17,7 +17,6 @@ package openai type AssistantGenTextExtension struct { - ItemID string `json:"item_id"` Annotations []*TextAnnotation `json:"annotations"` } From 2203090138048832bce3524bd56535d0de92725d Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 1 Dec 2025 12:08:12 +0800 Subject: [PATCH 09/65] feat: improve MCPToolCallError definition (#592) --- schema/agentic_message.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 3953fac7c..44fc37bb3 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -306,8 +306,8 @@ type MCPToolResult struct { } type MCPToolCallError struct { - Code *int64 - Error string + Code *int64 + Message string } type MCPListToolsResult struct { From f0366a866da8316e2d166fed278ad89ef3d22690 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 1 Dec 2025 14:15:54 +0800 Subject: [PATCH 10/65] feat: improve Options definition (#593) --- components/agentic/option.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/components/agentic/option.go b/components/agentic/option.go index dae6139be..ac117ddb4 100644 --- a/components/agentic/option.go +++ b/components/agentic/option.go @@ -23,11 +23,11 @@ import ( // Options is the common options for the model. type Options struct { // Temperature is the temperature for the model, which controls the randomness of the model. - Temperature *float32 + Temperature *float64 // Model is the model name. Model *string // TopP is the top p for the model, which controls the diversity of the model. - TopP *float32 + TopP *float64 // Tools is a list of tools the model may call. Tools []*schema.ToolInfo // ToolChoice controls which tool is called by the model. @@ -42,7 +42,7 @@ type Option struct { } // WithTemperature is the option to set the temperature for the model. -func WithTemperature(temperature float32) Option { +func WithTemperature(temperature float64) Option { return Option{ apply: func(opts *Options) { opts.Temperature = &temperature @@ -60,7 +60,7 @@ func WithModel(name string) Option { } // WithTopP is the option to set the top p for the model. -func WithTopP(topP float32) Option { +func WithTopP(topP float64) Option { return Option{ apply: func(opts *Options) { opts.TopP = &topP From 4cf22d5b9bb49dbc5273f394c82ed2f8ea928f5d Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 1 Dec 2025 14:41:51 +0800 Subject: [PATCH 11/65] feat: add CallbackInput definition for CallbackInput (#594) --- components/agentic/callback_extra.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/components/agentic/callback_extra.go b/components/agentic/callback_extra.go index f35b37779..389408d33 100644 --- a/components/agentic/callback_extra.go +++ b/components/agentic/callback_extra.go @@ -37,6 +37,8 @@ type CallbackInput struct { Messages []*schema.AgenticMessage // Tools is the tools to be used in the model. Tools []*schema.ToolInfo + // ToolChoice controls which tool is called by the model. + ToolChoice *schema.ToolChoice // Config is the config for the model. Config *Config // Extra is the extra information for the callback. From eafa0d57266f766f75329abd66fbc03c1253750e Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 1 Dec 2025 17:01:51 +0800 Subject: [PATCH 12/65] feat: define 'omitempty' flag in json tag (#595) --- schema/agentic_message.go | 4 ++-- schema/claude/content_block.go | 42 +++++++++++++++++----------------- schema/claude/response_meta.go | 4 ++-- schema/gemini/response_meta.go | 4 ++-- schema/openai/content_block.go | 32 +++++++++++++------------- schema/openai/response_meta.go | 10 ++++---- 6 files changed, 48 insertions(+), 48 deletions(-) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 44fc37bb3..367debd97 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -76,7 +76,7 @@ type AgenticResponseMeta struct { type StreamMeta struct { // Index specifies the index position of this block in the final response. - Index int + Index int64 } type ContentBlock struct { @@ -213,7 +213,7 @@ type Reasoning struct { type ReasoningSummary struct { // Index specifies the index position of this summary in the final Reasoning. - Index int + Index int64 Text string } diff --git a/schema/claude/content_block.go b/schema/claude/content_block.go index 4421db807..0c43d1045 100644 --- a/schema/claude/content_block.go +++ b/schema/claude/content_block.go @@ -17,11 +17,11 @@ package claude type AssistantGenTextExtension struct { - Citations []*TextCitation `json:"citations"` + Citations []*TextCitation `json:"citations,omitempty"` } type TextCitation struct { - Type TextCitationType `json:"type"` + Type TextCitationType `json:"type,omitempty"` CharLocation *CitationCharLocation `json:"char_location,omitempty"` PageLocation *CitationPageLocation `json:"page_location,omitempty"` @@ -30,40 +30,40 @@ type TextCitation struct { } type CitationCharLocation struct { - CitedText string `json:"cited_text"` + CitedText string `json:"cited_text,omitempty"` - DocumentTitle string `json:"document_title"` - DocumentIndex int64 `json:"document_index"` + DocumentTitle string `json:"document_title,omitempty"` + DocumentIndex int64 `json:"document_index,omitempty"` - StartCharIndex int64 `json:"start_char_index"` - EndCharIndex int64 `json:"end_char_index"` + StartCharIndex int64 `json:"start_char_index,omitempty"` + EndCharIndex int64 `json:"end_char_index,omitempty"` } type CitationPageLocation struct { - CitedText string `json:"cited_text"` + CitedText string `json:"cited_text,omitempty"` - DocumentTitle string `json:"document_title"` - DocumentIndex int64 `json:"document_index"` + DocumentTitle string `json:"document_title,omitempty"` + DocumentIndex int64 `json:"document_index,omitempty"` - StartPageNumber int64 `json:"start_page_number"` - EndPageNumber int64 `json:"end_page_number"` + StartPageNumber int64 `json:"start_page_number,omitempty"` + EndPageNumber int64 `json:"end_page_number,omitempty"` } type CitationContentBlockLocation struct { - CitedText string `json:"cited_text"` + CitedText string `json:"cited_text,omitempty"` - DocumentTitle string `json:"document_title"` - DocumentIndex int64 `json:"document_index"` + DocumentTitle string `json:"document_title,omitempty"` + DocumentIndex int64 `json:"document_index,omitempty"` - StartBlockIndex int64 `json:"start_block_index"` - EndBlockIndex int64 `json:"end_block_index"` + StartBlockIndex int64 `json:"start_block_index,omitempty"` + EndBlockIndex int64 `json:"end_block_index,omitempty"` } type CitationWebSearchResultLocation struct { - CitedText string `json:"cited_text"` + CitedText string `json:"cited_text,omitempty"` - Title string `json:"title"` - URL string `json:"url"` + Title string `json:"title,omitempty"` + URL string `json:"url,omitempty"` - EncryptedIndex string `json:"encrypted_index"` + EncryptedIndex string `json:"encrypted_index,omitempty"` } diff --git a/schema/claude/response_meta.go b/schema/claude/response_meta.go index 7d9dbe740..9f60dd713 100644 --- a/schema/claude/response_meta.go +++ b/schema/claude/response_meta.go @@ -17,6 +17,6 @@ package claude type ResponseMetaExtension struct { - ID string `json:"id"` - StopReason string `json:"stop_reason"` + ID string `json:"id,omitempty"` + StopReason string `json:"stop_reason,omitempty"` } diff --git a/schema/gemini/response_meta.go b/schema/gemini/response_meta.go index bb4af92c9..a5b3f626c 100644 --- a/schema/gemini/response_meta.go +++ b/schema/gemini/response_meta.go @@ -17,8 +17,8 @@ package gemini type ResponseMetaExtension struct { - ID string `json:"id"` - FinishReason string `json:"finish_reason"` + ID string `json:"id,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` GroundingMeta *GroundingMetadata `json:"grounding_meta,omitempty"` } diff --git a/schema/openai/content_block.go b/schema/openai/content_block.go index b0408e310..5d92be8f7 100644 --- a/schema/openai/content_block.go +++ b/schema/openai/content_block.go @@ -17,11 +17,11 @@ package openai type AssistantGenTextExtension struct { - Annotations []*TextAnnotation `json:"annotations"` + Annotations []*TextAnnotation `json:"annotations,omitempty"` } type TextAnnotation struct { - Type TextAnnotationType `json:"type"` + Type TextAnnotationType `json:"type,omitempty"` FileCitation *TextAnnotationFileCitation `json:"file_citation,omitempty"` URLCitation *TextAnnotationURLCitation `json:"url_citation,omitempty"` @@ -31,45 +31,45 @@ type TextAnnotation struct { type TextAnnotationFileCitation struct { // The ID of the file. - FileID string `json:"file_id"` + FileID string `json:"file_id,omitempty"` // The filename of the file cited. - Filename string `json:"filename"` + Filename string `json:"filename,omitempty"` // The index of the file in the list of files. - Index int64 `json:"index"` + Index int64 `json:"index,omitempty"` } type TextAnnotationURLCitation struct { // The title of the web resource. - Title string `json:"title"` + Title string `json:"title,omitempty"` // The URL of the web resource. - URL string `json:"url"` + URL string `json:"url,omitempty"` // The index of the first character of the URL citation in the message. - StartIndex int64 `json:"start_index"` + StartIndex int64 `json:"start_index,omitempty"` // The index of the last character of the URL citation in the message. - EndIndex int64 `json:"end_index"` + EndIndex int64 `json:"end_index,omitempty"` } type TextAnnotationContainerFileCitation struct { // The ID of the container file. - ContainerID string `json:"container_id"` + ContainerID string `json:"container_id,omitempty"` // The ID of the file. - FileID string `json:"file_id"` + FileID string `json:"file_id,omitempty"` // The filename of the container file cited. - Filename string `json:"filename"` + Filename string `json:"filename,omitempty"` // The index of the first character of the container file citation in the message. - StartIndex int64 `json:"start_index"` + StartIndex int64 `json:"start_index,omitempty"` // The index of the last character of the container file citation in the message. - EndIndex int64 `json:"end_index"` + EndIndex int64 `json:"end_index,omitempty"` } type TextAnnotationFilePath struct { // The ID of the file. - FileID string `json:"file_id"` + FileID string `json:"file_id,omitempty"` // The index of the file in the list of files. - Index int64 `json:"index"` + Index int64 `json:"index,omitempty"` } diff --git a/schema/openai/response_meta.go b/schema/openai/response_meta.go index 809fbb1a5..90e884173 100644 --- a/schema/openai/response_meta.go +++ b/schema/openai/response_meta.go @@ -17,17 +17,17 @@ package openai type ResponseMetaExtension struct { - ID string `json:"id"` - Status string `json:"status"` + ID string `json:"id,omitempty"` + Status string `json:"status,omitempty"` Error *ResponseError `json:"error,omitempty"` IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` } type ResponseError struct { - Code string `json:"code"` - Message string `json:"message"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` } type IncompleteDetails struct { - Reason string `json:"reason"` + Reason string `json:"reason,omitempty"` } From 082d950362f6f9f9a30d74c842011650c208a59c Mon Sep 17 00:00:00 2001 From: mrh997 Date: Tue, 2 Dec 2025 19:09:01 +0800 Subject: [PATCH 13/65] fix: MCPToolApprovalRequest definition (#600) --- schema/agentic_message.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 367debd97..93dd817ca 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -42,7 +42,7 @@ const ( ContentBlockTypeServerToolResult ContentBlockType = "server_tool_result" ContentBlockTypeMCPToolCall ContentBlockType = "mcp_tool_call" ContentBlockTypeMCPToolResult ContentBlockType = "mcp_tool_result" - ContentBlockTypeMCPListTools ContentBlockType = "mcp_list_tools" + ContentBlockTypeMCPListToolsResult ContentBlockType = "mcp_list_tools_result" ContentBlockTypeMCPToolApprovalRequest ContentBlockType = "mcp_tool_approval_request" ContentBlockTypeMCPToolApprovalResponse ContentBlockType = "mcp_tool_approval_response" ) @@ -332,8 +332,8 @@ type MCPListToolsItem struct { } type MCPToolApprovalRequest struct { - // CallID is the unique ID of the tool call. - CallID string + // ID is the approval request ID. + ID string // Name is the name of the tool to run. Name string // Arguments is the JSON string arguments for the tool call. @@ -431,7 +431,7 @@ func NewContentBlock(block any) *ContentBlock { case *MCPToolResult: return &ContentBlock{Type: ContentBlockTypeMCPToolResult, MCPToolResult: b} case *MCPListToolsResult: - return &ContentBlock{Type: ContentBlockTypeMCPListTools, MCPListToolsResult: b} + return &ContentBlock{Type: ContentBlockTypeMCPListToolsResult, MCPListToolsResult: b} case *MCPToolApprovalRequest: return &ContentBlock{Type: ContentBlockTypeMCPToolApprovalRequest, MCPToolApprovalRequest: b} case *MCPToolApprovalResponse: From a21e1f58ff9c5fd9de4b2176d4809dd3fd1ac95a Mon Sep 17 00:00:00 2001 From: mrh997 Date: Wed, 3 Dec 2025 15:24:00 +0800 Subject: [PATCH 14/65] feat: define StreamResponseError for openai (#601) --- schema/openai/response_meta.go | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/schema/openai/response_meta.go b/schema/openai/response_meta.go index 90e884173..e1933065b 100644 --- a/schema/openai/response_meta.go +++ b/schema/openai/response_meta.go @@ -17,10 +17,11 @@ package openai type ResponseMetaExtension struct { - ID string `json:"id,omitempty"` - Status string `json:"status,omitempty"` - Error *ResponseError `json:"error,omitempty"` - IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` + ID string `json:"id,omitempty"` + Status string `json:"status,omitempty"` + Error *ResponseError `json:"error,omitempty"` + StreamError *StreamResponseError `json:"stream_error,omitempty"` + IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` } type ResponseError struct { @@ -28,6 +29,12 @@ type ResponseError struct { Message string `json:"message,omitempty"` } +type StreamResponseError struct { + Code string + Message string + Param string +} + type IncompleteDetails struct { Reason string `json:"reason,omitempty"` } From 7d68eeb925a3aa309ec0c8f7dda0c1eb0bad2a29 Mon Sep 17 00:00:00 2001 From: Megumin Date: Wed, 3 Dec 2025 17:22:51 +0800 Subject: [PATCH 15/65] feat: support agentic message concat (#576) feat(agentic_model): - format print - support agentic chat template - support to compose agentic odel&agentic tools node - support agentic tool node - support agentic message concat --- components/agentic/callback_extra_test.go | 35 + components/agentic/option_test.go | 79 + components/prompt/callback_extra.go | 38 + components/prompt/chat_template_agentic.go | 84 + .../prompt/chat_template_agentic_test.go | 111 ++ components/prompt/interface.go | 5 + components/types.go | 1 + compose/chain.go | 43 +- compose/chain_branch.go | 49 +- compose/chain_parallel.go | 43 + compose/component_to_graph_node.go | 33 + compose/graph.go | 40 +- compose/tools_node_agentic.go | 125 ++ compose/tools_node_agentic_test.go | 244 +++ compose/types.go | 15 +- internal/concat.go | 6 +- schema/agentic_message.go | 1352 ++++++++++++++++ schema/agentic_message_test.go | 1381 +++++++++++++++++ schema/message.go | 69 +- 19 files changed, 3707 insertions(+), 46 deletions(-) create mode 100644 components/agentic/callback_extra_test.go create mode 100644 components/agentic/option_test.go create mode 100644 components/prompt/chat_template_agentic.go create mode 100644 components/prompt/chat_template_agentic_test.go create mode 100644 compose/tools_node_agentic.go create mode 100644 compose/tools_node_agentic_test.go create mode 100644 schema/agentic_message_test.go diff --git a/components/agentic/callback_extra_test.go b/components/agentic/callback_extra_test.go new file mode 100644 index 000000000..a77da6cd2 --- /dev/null +++ b/components/agentic/callback_extra_test.go @@ -0,0 +1,35 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package agentic + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestConvModel(t *testing.T) { + assert.NotNil(t, ConvCallbackInput(&CallbackInput{})) + assert.NotNil(t, ConvCallbackInput([]*schema.AgenticMessage{})) + assert.Nil(t, ConvCallbackInput("asd")) + + assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{})) + assert.NotNil(t, ConvCallbackOutput(&schema.AgenticMessage{})) + assert.Nil(t, ConvCallbackOutput("asd")) +} diff --git a/components/agentic/option_test.go b/components/agentic/option_test.go new file mode 100644 index 000000000..d349f35ac --- /dev/null +++ b/components/agentic/option_test.go @@ -0,0 +1,79 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package agentic + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestCommon(t *testing.T) { + o := GetCommonOptions(nil, + WithTools([]*schema.ToolInfo{{Name: "test"}}), + WithModel("test"), + WithTemperature(0.1), + WithToolChoice(schema.ToolChoiceAllowed), + WithTopP(0.1), + ) + assert.Len(t, o.Tools, 1) + assert.Equal(t, "test", o.Tools[0].Name) + assert.Equal(t, "test", *o.Model) + assert.Equal(t, float64(0.1), *o.Temperature) + assert.Equal(t, schema.ToolChoiceAllowed, *o.ToolChoice) + assert.Equal(t, float64(0.1), *o.TopP) +} + +func TestImplSpecificOpts(t *testing.T) { + type implSpecificOptions struct { + conf string + index int + } + + withConf := func(conf string) func(o *implSpecificOptions) { + return func(o *implSpecificOptions) { + o.conf = conf + } + } + + withIndex := func(index int) func(o *implSpecificOptions) { + return func(o *implSpecificOptions) { + o.index = index + } + } + + documentOption1 := WrapImplSpecificOptFn(withConf("test_conf")) + documentOption2 := WrapImplSpecificOptFn(withIndex(1)) + + implSpecificOpts := GetImplSpecificOptions(&implSpecificOptions{}, documentOption1, documentOption2) + + assert.Equal(t, &implSpecificOptions{ + conf: "test_conf", + index: 1, + }, implSpecificOpts) + documentOption1 = WrapImplSpecificOptFn(withConf("test_conf")) + documentOption2 = WrapImplSpecificOptFn(withIndex(1)) + + implSpecificOpts = GetImplSpecificOptions(&implSpecificOptions{}, documentOption1, documentOption2) + + assert.Equal(t, &implSpecificOptions{ + conf: "test_conf", + index: 1, + }, implSpecificOpts) +} diff --git a/components/prompt/callback_extra.go b/components/prompt/callback_extra.go index 324a418f3..ff5c3a8ff 100644 --- a/components/prompt/callback_extra.go +++ b/components/prompt/callback_extra.go @@ -21,6 +21,44 @@ import ( "github.com/cloudwego/eino/schema" ) +type AgenticCallbackInput struct { + Variables map[string]any + Templates []schema.AgenticMessagesTemplate + Extra map[string]any +} + +type AgenticCallbackOutput struct { + Result []*schema.AgenticMessage + Templates []schema.AgenticMessagesTemplate + Extra map[string]any +} + +func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput { + switch t := src.(type) { + case *AgenticCallbackInput: + return t + case map[string]any: + return &AgenticCallbackInput{ + Variables: t, + } + default: + return nil + } +} + +func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOutput { + switch t := src.(type) { + case *AgenticCallbackOutput: + return t + case []*schema.AgenticMessage: + return &AgenticCallbackOutput{ + Result: t, + } + default: + return nil + } +} + // CallbackInput is the input for the callback. type CallbackInput struct { // Variables is the variables for the callback. diff --git a/components/prompt/chat_template_agentic.go b/components/prompt/chat_template_agentic.go new file mode 100644 index 000000000..937d46f26 --- /dev/null +++ b/components/prompt/chat_template_agentic.go @@ -0,0 +1,84 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "context" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/schema" +) + +// FromAgenticMessages creates a new DefaultAgenticChatTemplate from the given templates and format type. +// eg. +// +// template := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// // in chain, or graph +// chain := compose.NewChain[map[string]any, []*schema.AgenticMessage]() +// chain.AppendAgenticChatTemplate(template) +func FromAgenticMessages(formatType schema.FormatType, templates ...schema.AgenticMessagesTemplate) *DefaultAgenticChatTemplate { + return &DefaultAgenticChatTemplate{ + templates: templates, + formatType: formatType, + } +} + +type DefaultAgenticChatTemplate struct { + templates []schema.AgenticMessagesTemplate + formatType schema.FormatType +} + +func (t *DefaultAgenticChatTemplate) Format(ctx context.Context, vs map[string]any, opts ...Option) (result []*schema.AgenticMessage, err error) { + ctx = callbacks.EnsureRunInfo(ctx, t.GetType(), components.ComponentOfAgenticPrompt) + ctx = callbacks.OnStart(ctx, &AgenticCallbackInput{ + Variables: vs, + Templates: t.templates, + }) + defer func() { + if err != nil { + _ = callbacks.OnError(ctx, err) + } + }() + + result = make([]*schema.AgenticMessage, 0, len(t.templates)) + for _, template := range t.templates { + msgs, err := template.Format(ctx, vs, t.formatType) + if err != nil { + return nil, err + } + + result = append(result, msgs...) + } + + _ = callbacks.OnEnd(ctx, &AgenticCallbackOutput{ + Result: result, + Templates: t.templates, + }) + + return result, nil +} + +// GetType returns the type of the chat template (Default). +func (t *DefaultAgenticChatTemplate) GetType() string { + return "Default" +} + +// IsCallbacksEnabled checks if the callbacks are enabled for the chat template. +func (t *DefaultAgenticChatTemplate) IsCallbacksEnabled() bool { + return true +} diff --git a/components/prompt/chat_template_agentic_test.go b/components/prompt/chat_template_agentic_test.go new file mode 100644 index 000000000..aaa7d6405 --- /dev/null +++ b/components/prompt/chat_template_agentic_test.go @@ -0,0 +1,111 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestAgenticFormat(t *testing.T) { + pyFmtTestTemplate := []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "{context}"}}, + }, + }, + schema.AgenticMessagesPlaceholder("chat_history", true), + } + jinja2TestTemplate := []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "{{context}}"}}, + }, + }, + schema.AgenticMessagesPlaceholder("chat_history", true), + } + goFmtTestTemplate := []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "{{.context}}"}}, + }, + }, + schema.AgenticMessagesPlaceholder("chat_history", true), + } + testValues := map[string]any{ + "context": "it's beautiful day", + "chat_history": []*schema.AgenticMessage{ + { + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "1"}}, + }, + }, + { + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "2"}}, + }, + }, + }, + } + expected := []*schema.AgenticMessage{ + { + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "it's beautiful day"}}, + }, + }, + { + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "1"}}, + }, + }, + { + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "2"}}, + }, + }, + } + + // FString + chatTemplate := FromAgenticMessages(schema.FString, pyFmtTestTemplate...) + msgs, err := chatTemplate.Format(context.Background(), testValues) + assert.Nil(t, err) + assert.Equal(t, expected, msgs) + + // Jinja2 + chatTemplate = FromAgenticMessages(schema.Jinja2, jinja2TestTemplate...) + msgs, err = chatTemplate.Format(context.Background(), testValues) + assert.Nil(t, err) + assert.Equal(t, expected, msgs) + + // GoTemplate + chatTemplate = FromAgenticMessages(schema.GoTemplate, goFmtTestTemplate...) + msgs, err = chatTemplate.Format(context.Background(), testValues) + assert.Nil(t, err) + assert.Equal(t, expected, msgs) +} diff --git a/components/prompt/interface.go b/components/prompt/interface.go index eac695eda..7ffe7216a 100644 --- a/components/prompt/interface.go +++ b/components/prompt/interface.go @@ -23,6 +23,7 @@ import ( ) var _ ChatTemplate = &DefaultChatTemplate{} +var _ AgenticChatTemplate = &DefaultAgenticChatTemplate{} // ChatTemplate formats a variables map into a list of messages for a ChatModel. // @@ -42,3 +43,7 @@ var _ ChatTemplate = &DefaultChatTemplate{} type ChatTemplate interface { Format(ctx context.Context, vs map[string]any, opts ...Option) ([]*schema.Message, error) } + +type AgenticChatTemplate interface { + Format(ctx context.Context, vs map[string]any, opts ...Option) ([]*schema.AgenticMessage, error) +} diff --git a/components/types.go b/components/types.go index a23d82a68..2ba088e93 100644 --- a/components/types.go +++ b/components/types.go @@ -66,6 +66,7 @@ type Component string const ( // ComponentOfPrompt identifies chat template components. ComponentOfPrompt Component = "ChatTemplate" + ComponentOfAgenticPrompt Component = "AgenticChatTemplate" // ComponentOfChatModel identifies chat model components. ComponentOfChatModel Component = "ChatModel" ComponentOfAgenticModel Component = "AgenticModel" diff --git a/compose/chain.go b/compose/chain.go index 5e4a8e1c0..8484e8767 100644 --- a/compose/chain.go +++ b/compose/chain.go @@ -22,6 +22,7 @@ import ( "fmt" "reflect" + "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -174,6 +175,18 @@ func (c *Chain[I, O]) AppendChatModel(node model.BaseChatModel, opts ...GraphAdd return c } +// AppendAgenticModel add a agentic.Model node to the chain. +// e.g. +// +// model, err := openai.NewAgenticModel(ctx, config) +// if err != nil {...} +// chain.AppendAgenticModel(model) +func (c *Chain[I, O]) AppendAgenticModel(node agentic.Model, opts ...GraphAddNodeOpt) *Chain[I, O] { + gNode, options := toAgenticModelNode(node, opts...) + c.addNode(gNode, options) + return c +} + // AppendChatTemplate add a ChatTemplate node to the chain. // eg. // @@ -189,11 +202,23 @@ func (c *Chain[I, O]) AppendChatTemplate(node prompt.ChatTemplate, opts ...Graph return c } +// AppendAgenticChatTemplate add a prompt.AgenticChatTemplate node to the chain. +// eg. +// +// chatTemplate, err := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// chain.AppendAgenticChatTemplate(chatTemplate) +func (c *Chain[I, O]) AppendAgenticChatTemplate(node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) *Chain[I, O] { + gNode, options := toAgenticChatTemplateNode(node, opts...) + c.addNode(gNode, options) + return c +} + // AppendToolsNode add a ToolsNode node to the chain. // e.g. // -// toolsNode, err := tools.NewToolNode(ctx, &tools.ToolsNodeConfig{ -// Tools: []tools.Tool{...}, +// toolsNode, err := compose.NewToolNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tools.BaseTool{...}, // }) // // chain.AppendToolsNode(toolsNode) @@ -203,6 +228,20 @@ func (c *Chain[I, O]) AppendToolsNode(node *ToolsNode, opts ...GraphAddNodeOpt) return c } +// AppendAgenticToolsNode add a AgenticToolsNode node to the chain. +// e.g. +// +// toolsNode, err := compose.NewAgenticToolsNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tools.BaseTool{...}, +// }) +// +// chain.AppendAgenticToolsNode(toolsNode) +func (c *Chain[I, O]) AppendAgenticToolsNode(node *AgenticToolsNode, opts ...GraphAddNodeOpt) *Chain[I, O] { + gNode, options := toAgenticToolsNode(node, opts...) + c.addNode(gNode, options) + return c +} + // AppendDocumentTransformer add a DocumentTransformer node to the chain. // e.g. // diff --git a/compose/chain_branch.go b/compose/chain_branch.go index ec3a433af..004dbfac3 100644 --- a/compose/chain_branch.go +++ b/compose/chain_branch.go @@ -20,6 +20,7 @@ import ( "context" "fmt" + "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -146,6 +147,22 @@ func (cb *ChainBranch) AddChatModel(key string, node model.BaseChatModel, opts . return cb.addNode(key, gNode, options) } +// AddAgenticModel adds a agentic.Model node to the branch. +// eg. +// +// model1, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o", +// }) +// model2, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o-mini", +// }) +// cb.AddAgenticModel("agentic_model_key_1", model1) +// cb.AddAgenticModel("agentic_model_key_2", model2) +func (cb *ChainBranch) AddAgenticModel(key string, node agentic.Model, opts ...GraphAddNodeOpt) *ChainBranch { + gNode, options := toAgenticModelNode(node, opts...) + return cb.addNode(key, gNode, options) +} + // AddChatTemplate adds a ChatTemplate node to the branch. // eg. // @@ -167,11 +184,26 @@ func (cb *ChainBranch) AddChatTemplate(key string, node prompt.ChatTemplate, opt return cb.addNode(key, gNode, options) } +// AddAgenticChatTemplate adds a prompt.AgenticChatTemplate node to the branch. +// eg. +// +// chatTemplate, err := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// cb.AddAgenticChatTemplate("chat_template_key_01", chatTemplate) +// +// chatTemplate2, err := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// cb.AddAgenticChatTemplate("chat_template_key_02", chatTemplate2) +func (cb *ChainBranch) AddAgenticChatTemplate(key string, node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) *ChainBranch { + gNode, options := toAgenticChatTemplateNode(node, opts...) + return cb.addNode(key, gNode, options) +} + // AddToolsNode adds a ToolsNode to the branch. // eg. // -// toolsNode, err := tools.NewToolNode(ctx, &tools.ToolsNodeConfig{ -// Tools: []tools.Tool{...}, +// toolsNode, err := compose.NewToolNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tools.BaseTool{...}, // }) // // cb.AddToolsNode("tools_node_key", toolsNode) @@ -180,6 +212,19 @@ func (cb *ChainBranch) AddToolsNode(key string, node *ToolsNode, opts ...GraphAd return cb.addNode(key, gNode, options) } +// AddAgenticToolsNode adds a AgenticToolsNode to the branch. +// eg. +// +// toolsNode, err := compose.NewAgenticToolsNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tools.BaseTool{...}, +// }) +// +// cb.AddAgenticToolsNode("tools_node_key", toolsNode) +func (cb *ChainBranch) AddAgenticToolsNode(key string, node *AgenticToolsNode, opts ...GraphAddNodeOpt) *ChainBranch { + gNode, options := toAgenticToolsNode(node, opts...) + return cb.addNode(key, gNode, options) +} + // AddLambda adds a Lambda node to the branch. // eg. // diff --git a/compose/chain_parallel.go b/compose/chain_parallel.go index 64cdf2db1..128ed4a26 100644 --- a/compose/chain_parallel.go +++ b/compose/chain_parallel.go @@ -19,6 +19,7 @@ package compose import ( "fmt" + "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -70,6 +71,24 @@ func (p *Parallel) AddChatModel(outputKey string, node model.BaseChatModel, opts return p.addNode(outputKey, gNode, options) } +// AddAgenticModel adds a agentic.Model to the parallel. +// eg. +// +// model1, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o", +// }) +// +// model2, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o", +// }) +// +// p.AddAgenticModel("output_key1", model1) +// p.AddAgenticModel("output_key2", model2) +func (p *Parallel) AddAgenticModel(outputKey string, node agentic.Model, opts ...GraphAddNodeOpt) *Parallel { + gNode, options := toAgenticModelNode(node, append(opts, WithOutputKey(outputKey))...) + return p.addNode(outputKey, gNode, options) +} + // AddChatTemplate adds a chat template to the parallel. // eg. // @@ -84,6 +103,17 @@ func (p *Parallel) AddChatTemplate(outputKey string, node prompt.ChatTemplate, o return p.addNode(outputKey, gNode, options) } +// AddAgenticChatTemplate adds a prompt.AgenticChatTemplate to the parallel. +// eg. +// +// chatTemplate01, err := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// p.AddAgenticChatTemplate("output_key01", chatTemplate01) +func (p *Parallel) AddAgenticChatTemplate(outputKey string, node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) *Parallel { + gNode, options := toAgenticChatTemplateNode(node, append(opts, WithOutputKey(outputKey))...) + return p.addNode(outputKey, gNode, options) +} + // AddToolsNode adds a tools node to the parallel. // eg. // @@ -97,6 +127,19 @@ func (p *Parallel) AddToolsNode(outputKey string, node *ToolsNode, opts ...Graph return p.addNode(outputKey, gNode, options) } +// AddAgenticToolsNode adds a tools node to the parallel. +// eg. +// +// toolsNode, err := compose.NewAgenticToolsNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tool.BaseTool{...}, +// }) +// +// p.AddAgenticToolsNode("output_key01", toolsNode) +func (p *Parallel) AddAgenticToolsNode(outputKey string, node *AgenticToolsNode, opts ...GraphAddNodeOpt) *Parallel { + gNode, options := toAgenticToolsNode(node, append(opts, WithOutputKey(outputKey))...) + return p.addNode(outputKey, gNode, options) +} + // AddLambda adds a lambda node to the parallel. // eg. // diff --git a/compose/component_to_graph_node.go b/compose/component_to_graph_node.go index ab4694f1a..e64ce4f19 100644 --- a/compose/component_to_graph_node.go +++ b/compose/component_to_graph_node.go @@ -18,6 +18,7 @@ package compose import ( "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -101,6 +102,17 @@ func toChatModelNode(node model.BaseChatModel, opts ...GraphAddNodeOpt) (*graphN opts...) } +func toAgenticModelNode(node agentic.Model, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { + return toComponentNode( + node, + components.ComponentOfAgenticModel, + node.Generate, + node.Stream, + nil, nil, + opts..., + ) +} + func toChatTemplateNode(node prompt.ChatTemplate, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { return toComponentNode( node, @@ -112,6 +124,16 @@ func toChatTemplateNode(node prompt.ChatTemplate, opts ...GraphAddNodeOpt) (*gra opts...) } +func toAgenticChatTemplateNode(node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { + return toComponentNode( + node, + components.ComponentOfAgenticPrompt, + node.Format, + nil, nil, nil, + opts..., + ) +} + func toDocumentTransformerNode(node document.Transformer, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { return toComponentNode( node, @@ -134,6 +156,17 @@ func toToolsNode(node *ToolsNode, opts ...GraphAddNodeOpt) (*graphNode, *graphAd opts...) } +func toAgenticToolsNode(node *AgenticToolsNode, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { + return toComponentNode( + node, + ComponentOfAgenticToolsNode, + node.Invoke, + node.Stream, + nil, nil, + opts..., + ) +} + func toLambdaNode(node *Lambda, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { info, options := getNodeInfo(opts...) diff --git a/compose/graph.go b/compose/graph.go index 9370665f0..877b8fb42 100644 --- a/compose/graph.go +++ b/compose/graph.go @@ -23,6 +23,7 @@ import ( "reflect" "strings" + "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -352,6 +353,19 @@ func (g *graph) AddChatModelNode(key string, node model.BaseChatModel, opts ...G return g.addNode(key, gNode, options) } +// AddAgenticModelNode add node that implements agentic.Model. +// e.g. +// +// model, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o", +// }) +// +// graph.AddAgenticModelNode("agentic_model_node_key", model) +func (g *graph) AddAgenticModelNode(key string, node agentic.Model, opts ...GraphAddNodeOpt) error { + gNode, options := toAgenticModelNode(node, opts...) + return g.addNode(key, gNode, options) +} + // AddChatTemplateNode add node that implements prompt.ChatTemplate. // e.g. // @@ -366,10 +380,21 @@ func (g *graph) AddChatTemplateNode(key string, node prompt.ChatTemplate, opts . return g.addNode(key, gNode, options) } -// AddToolsNode adds a node that implements tools.ToolsNode. +// AddAgenticChatTemplateNode add node that implements prompt.AgenticChatTemplate. +// e.g. +// +// chatTemplate, err := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// graph.AddAgenticChatTemplateNode("chat_template_node_key", chatTemplate) +func (g *graph) AddAgenticChatTemplateNode(key string, node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) error { + gNode, options := toAgenticChatTemplateNode(node, opts...) + return g.addNode(key, gNode, options) +} + +// AddToolsNode adds a node that implements ToolsNode. // e.g. // -// toolsNode, err := tools.NewToolNode(ctx, &tools.ToolsNodeConfig{}) +// toolsNode, err := compose.NewToolNode(ctx, &compose.ToolsNodeConfig{}) // // graph.AddToolsNode("tools_node_key", toolsNode) func (g *graph) AddToolsNode(key string, node *ToolsNode, opts ...GraphAddNodeOpt) error { @@ -377,6 +402,17 @@ func (g *graph) AddToolsNode(key string, node *ToolsNode, opts ...GraphAddNodeOp return g.addNode(key, gNode, options) } +// AddAgenticToolsNode adds a node that implements AgenticToolsNode. +// e.g. +// +// toolsNode, err := compose.NewAgenticToolsNode(ctx, &compose.ToolsNodeConfig{}) +// +// graph.AddAgenticToolsNode("tools_node_key", toolsNode) +func (g *graph) AddAgenticToolsNode(key string, node *AgenticToolsNode, opts ...GraphAddNodeOpt) error { + gNode, options := toAgenticToolsNode(node, opts...) + return g.addNode(key, gNode, options) +} + // AddDocumentTransformerNode adds a node that implements document.Transformer. // e.g. // diff --git a/compose/tools_node_agentic.go b/compose/tools_node_agentic.go new file mode 100644 index 000000000..38c5c89de --- /dev/null +++ b/compose/tools_node_agentic.go @@ -0,0 +1,125 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + + "github.com/cloudwego/eino/schema" +) + +// NewAgenticToolsNode creates a new AgenticToolsNode. +// e.g. +// +// conf := &ToolsNodeConfig{ +// Tools: []tool.BaseTool{invokableTool1, streamableTool2}, +// } +// toolsNode, err := NewAgenticToolsNode(ctx, conf) +func NewAgenticToolsNode(ctx context.Context, conf *ToolsNodeConfig) (*AgenticToolsNode, error) { + tn, err := NewToolNode(ctx, conf) + if err != nil { + return nil, err + } + return &AgenticToolsNode{inner: tn}, nil +} + +type AgenticToolsNode struct { + inner *ToolsNode +} + +func (a *AgenticToolsNode) Invoke(ctx context.Context, input *schema.AgenticMessage, opts ...ToolsNodeOption) ([]*schema.AgenticMessage, error) { + result, err := a.inner.Invoke(ctx, agenticMessageToToolCallMessage(input), opts...) + if err != nil { + return nil, err + } + return toolMessageToAgenticMessage(result), nil +} + +func (a *AgenticToolsNode) Stream(ctx context.Context, input *schema.AgenticMessage, + opts ...ToolsNodeOption) (*schema.StreamReader[[]*schema.AgenticMessage], error) { + result, err := a.inner.Stream(ctx, agenticMessageToToolCallMessage(input), opts...) + if err != nil { + return nil, err + } + return streamToolMessageToAgenticMessage(result), nil +} + +func agenticMessageToToolCallMessage(input *schema.AgenticMessage) *schema.Message { + var tc []schema.ToolCall + for _, block := range input.ContentBlocks { + if block.Type != schema.ContentBlockTypeFunctionToolCall || block.FunctionToolCall == nil { + continue + } + tc = append(tc, schema.ToolCall{ + ID: block.FunctionToolCall.CallID, + Function: schema.FunctionCall{ + Name: block.FunctionToolCall.Name, + Arguments: block.FunctionToolCall.Arguments, + }, + }) + } + return &schema.Message{ + Role: schema.Assistant, + ToolCalls: tc, + } +} + +func toolMessageToAgenticMessage(input []*schema.Message) []*schema.AgenticMessage { + var results []*schema.ContentBlock + for _, m := range input { + results = append(results, &schema.ContentBlock{ + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: m.ToolCallID, + Name: m.ToolName, + Result: m.Content, + Extra: m.Extra, + }, + }) + } + return []*schema.AgenticMessage{{ + Role: schema.AgenticRoleTypeUser, + ContentBlocks: results, + }} +} + +func streamToolMessageToAgenticMessage(input *schema.StreamReader[[]*schema.Message]) *schema.StreamReader[[]*schema.AgenticMessage] { + return schema.StreamReaderWithConvert(input, func(t []*schema.Message) ([]*schema.AgenticMessage, error) { + var results []*schema.ContentBlock + for i, m := range t { + if m == nil { + continue + } + results = append(results, &schema.ContentBlock{ + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: m.ToolCallID, + Name: m.ToolName, + Result: m.Content, + Extra: m.Extra, + }, + StreamMeta: &schema.StreamMeta{Index: int64(i)}, + }) + } + return []*schema.AgenticMessage{{ + Role: schema.AgenticRoleTypeUser, + ContentBlocks: results, + }}, nil + }) +} + +func (a *AgenticToolsNode) GetType() string { return "" } diff --git a/compose/tools_node_agentic_test.go b/compose/tools_node_agentic_test.go new file mode 100644 index 000000000..dcd3177a9 --- /dev/null +++ b/compose/tools_node_agentic_test.go @@ -0,0 +1,244 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "io" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestAgenticMessageToToolCallMessage(t *testing.T) { + input := &schema.AgenticMessage{ + ContentBlocks: []*schema.ContentBlock{ + { + Type: schema.ContentBlockTypeFunctionToolCall, + FunctionToolCall: &schema.FunctionToolCall{ + CallID: "1", + Name: "name1", + Arguments: "arg1", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolCall, + FunctionToolCall: &schema.FunctionToolCall{ + CallID: "2", + Name: "name2", + Arguments: "arg2", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolCall, + FunctionToolCall: &schema.FunctionToolCall{ + CallID: "3", + Name: "name3", + Arguments: "arg3", + }, + }, + }, + } + ret := agenticMessageToToolCallMessage(input) + assert.Equal(t, schema.Assistant, ret.Role) + assert.Equal(t, []schema.ToolCall{ + { + ID: "1", + Function: schema.FunctionCall{ + Name: "name1", + Arguments: "arg1", + }, + }, + { + ID: "2", + Function: schema.FunctionCall{ + Name: "name2", + Arguments: "arg2", + }, + }, + { + ID: "3", + Function: schema.FunctionCall{ + Name: "name3", + Arguments: "arg3", + }, + }, + }, ret.ToolCalls) +} + +func TestToolMessageToAgenticMessage(t *testing.T) { + input := []*schema.Message{ + { + Role: schema.Tool, + Content: "content1", + ToolCallID: "1", + ToolName: "name1", + }, + { + Role: schema.Tool, + Content: "content2", + ToolCallID: "2", + ToolName: "name2", + }, + { + Role: schema.Tool, + Content: "content3", + ToolCallID: "3", + ToolName: "name3", + }, + } + ret := toolMessageToAgenticMessage(input) + assert.Equal(t, 1, len(ret)) + assert.Equal(t, schema.AgenticRoleTypeUser, ret[0].Role) + assert.Equal(t, []*schema.ContentBlock{ + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "1", + Name: "name1", + Result: "content1", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "2", + Name: "name2", + Result: "content2", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "3", + Name: "name3", + Result: "content3", + }, + }, + }, ret[0].ContentBlocks) +} + +func TestStreamToolMessageToAgenticMessage(t *testing.T) { + input := schema.StreamReaderFromArray([][]*schema.Message{ + { + { + Role: schema.Tool, + Content: "content1-1", + ToolName: "name1", + ToolCallID: "1", + }, + nil, nil, + }, + { + nil, + { + Role: schema.Tool, + Content: "content2-1", + ToolName: "name2", + ToolCallID: "2", + }, + nil, + }, + { + { + Role: schema.Tool, + Content: "content1-2", + ToolName: "name2", + ToolCallID: "2", + }, + nil, nil, + }, + { + nil, nil, + { + Role: schema.Tool, + Content: "content3-1", + ToolName: "name3", + ToolCallID: "3", + }, + }, + { + nil, + { + Role: schema.Tool, + Content: "content2-2", + ToolName: "name2", + ToolCallID: "2", + }, + nil, + }, + { + nil, nil, + { + Role: schema.Tool, + Content: "content3-2", + ToolName: "name3", + ToolCallID: "3", + }, + }, + }) + ret := streamToolMessageToAgenticMessage(input) + var chunks [][]*schema.AgenticMessage + for { + chunk, err := ret.Recv() + if err == io.EOF { + break + } + assert.NoError(t, err) + chunks = append(chunks, chunk) + } + result, err := schema.ConcatAgenticMessagesArray(chunks) + assert.NoError(t, err) + assert.Equal(t, []*schema.AgenticMessage{ + { + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "1", + Name: "name1", + Result: "content1-1content1-2", + Extra: map[string]interface{}{}, + }, + StreamMeta: &schema.StreamMeta{Index: 0}, + }, + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "2", + Name: "name2", + Result: "content2-1content2-2", + Extra: map[string]interface{}{}, + }, + StreamMeta: &schema.StreamMeta{Index: 1}, + }, + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "3", + Name: "name3", + Result: "content3-1content3-2", + Extra: map[string]interface{}{}, + }, + StreamMeta: &schema.StreamMeta{Index: 2}, + }, + }, + }, + }, result) +} diff --git a/compose/types.go b/compose/types.go index 13d925df2..54f8e2be3 100644 --- a/compose/types.go +++ b/compose/types.go @@ -25,13 +25,14 @@ type component = components.Component // built-in component types in graph node. // it represents the type of the most primitive executable object provided by the user. const ( - ComponentOfUnknown component = "Unknown" - ComponentOfGraph component = "Graph" - ComponentOfWorkflow component = "Workflow" - ComponentOfChain component = "Chain" - ComponentOfPassthrough component = "Passthrough" - ComponentOfToolsNode component = "ToolsNode" - ComponentOfLambda component = "Lambda" + ComponentOfUnknown component = "Unknown" + ComponentOfGraph component = "Graph" + ComponentOfWorkflow component = "Workflow" + ComponentOfChain component = "Chain" + ComponentOfPassthrough component = "Passthrough" + ComponentOfToolsNode component = "ToolsNode" + ComponentOfAgenticToolsNode component = "AgenticToolsNode" + ComponentOfLambda component = "Lambda" ) // NodeTriggerMode controls the triggering mode of graph nodes. diff --git a/internal/concat.go b/internal/concat.go index 2681322ab..fd9b8abc5 100644 --- a/internal/concat.go +++ b/internal/concat.go @@ -99,7 +99,7 @@ func ConcatItems[T any](items []T) (T, error) { if typ.Kind() == reflect.Map { cv, err = concatMaps(v) } else { - cv, err = concatSliceValue(v) + cv, err = ConcatSliceValue(v) } if err != nil { @@ -158,7 +158,7 @@ func concatMaps(ms reflect.Value) (reflect.Value, error) { if v.Type().Elem().Kind() == reflect.Map { cv, err = concatMaps(v) } else { - cv, err = concatSliceValue(v) + cv, err = ConcatSliceValue(v) } if err != nil { @@ -171,7 +171,7 @@ func concatMaps(ms reflect.Value) (reflect.Value, error) { return ret, nil } -func concatSliceValue(val reflect.Value) (reflect.Value, error) { +func ConcatSliceValue(val reflect.Value) (reflect.Value, error) { elmType := val.Type().Elem() if val.Len() == 1 { diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 93dd817ca..2139201ec 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -17,9 +17,17 @@ package schema import ( + "context" + "fmt" + "reflect" + "strings" + "github.com/cloudwego/eino/schema/claude" "github.com/cloudwego/eino/schema/gemini" + + "github.com/cloudwego/eino/internal" "github.com/cloudwego/eino/schema/openai" + "github.com/eino-contrib/jsonschema" ) @@ -440,3 +448,1347 @@ func NewContentBlock(block any) *ContentBlock { return nil } } + +// AgenticMessagesTemplate is the interface for messages template. +// It's used to render a template to a list of agentic messages. +// e.g. +// +// chatTemplate := prompt.FromAgenticMessages( +// &schema.AgenticMessage{ +// Role: schema.AgenticRoleTypeSystem, +// ContentBlocks: []*schema.ContentBlock{ +// {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "you are an eino helper"}}, +// }, +// }, +// schema.AgenticMessagesPlaceholder("history", false), // <= this will use the value of "history" in params +// ) +// msgs, err := chatTemplate.Format(ctx, params) +type AgenticMessagesTemplate interface { + Format(ctx context.Context, vs map[string]any, formatType FormatType) ([]*AgenticMessage, error) +} + +var _ AgenticMessagesTemplate = &AgenticMessage{} +var _ AgenticMessagesTemplate = AgenticMessagesPlaceholder("", false) + +type agenticMessagesPlaceholder struct { + key string + optional bool +} + +// AgenticMessagesPlaceholder can render a placeholder to a list of agentic messages in params. +// e.g. +// +// placeholder := AgenticMessagesPlaceholder("history", false) +// params := map[string]any{ +// "history": []*schema.AgenticMessage{ +// &schema.AgenticMessage{ +// Role: schema.AgenticRoleTypeSystem, +// ContentBlocks: []*schema.ContentBlock{ +// {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "you are an eino helper"}}, +// }, +// }, +// }, +// } +// chatTemplate := chatTpl := prompt.FromMessages( +// schema.AgenticMessagesPlaceholder("history", false), // <= this will use the value of "history" in params +// ) +// msgs, err := chatTemplate.Format(ctx, params) +func AgenticMessagesPlaceholder(key string, optional bool) AgenticMessagesTemplate { + return &agenticMessagesPlaceholder{ + key: key, + optional: optional, + } +} + +func (p *agenticMessagesPlaceholder) Format(_ context.Context, vs map[string]any, _ FormatType) ([]*AgenticMessage, error) { + v, ok := vs[p.key] + if !ok { + if p.optional { + return []*AgenticMessage{}, nil + } + + return nil, fmt.Errorf("message placeholder format: %s not found", p.key) + } + + msgs, ok := v.([]*AgenticMessage) + if !ok { + return nil, fmt.Errorf("only agentic messages can be used to format message placeholder, key: %v, actual type: %v", p.key, reflect.TypeOf(v)) + } + + return msgs, nil +} + +// Format returns the agentic messages after rendering by the given formatType. +// It formats only the user input fields (UserInputText, UserInputImage, UserInputAudio, UserInputVideo, UserInputFile). +// e.g. +// +// msg := &schema.AgenticMessage{ +// Role: schema.AgenticRoleTypeUser, +// ContentBlocks: []*schema.ContentBlock{ +// {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "hello {name}"}}, +// }, +// } +// msgs, err := msg.Format(ctx, map[string]any{"name": "eino"}, schema.FString) +// // msgs[0].ContentBlocks[0].UserInputText.Text will be "hello eino" +func (m *AgenticMessage) Format(_ context.Context, vs map[string]any, formatType FormatType) ([]*AgenticMessage, error) { + copied := *m + + if len(m.ContentBlocks) > 0 { + copiedBlocks := make([]*ContentBlock, len(m.ContentBlocks)) + for i, block := range m.ContentBlocks { + if block == nil { + copiedBlocks[i] = nil + continue + } + + copiedBlock := *block + var err error + + switch block.Type { + case ContentBlockTypeUserInputText: + if block.UserInputText != nil { + copiedBlock.UserInputText, err = formatUserInputText(block.UserInputText, vs, formatType) + if err != nil { + return nil, err + } + } + case ContentBlockTypeUserInputImage: + if block.UserInputImage != nil { + copiedBlock.UserInputImage, err = formatUserInputImage(block.UserInputImage, vs, formatType) + if err != nil { + return nil, err + } + } + case ContentBlockTypeUserInputAudio: + if block.UserInputAudio != nil { + copiedBlock.UserInputAudio, err = formatUserInputAudio(block.UserInputAudio, vs, formatType) + if err != nil { + return nil, err + } + } + case ContentBlockTypeUserInputVideo: + if block.UserInputVideo != nil { + copiedBlock.UserInputVideo, err = formatUserInputVideo(block.UserInputVideo, vs, formatType) + if err != nil { + return nil, err + } + } + case ContentBlockTypeUserInputFile: + if block.UserInputFile != nil { + copiedBlock.UserInputFile, err = formatUserInputFile(block.UserInputFile, vs, formatType) + if err != nil { + return nil, err + } + } + } + + copiedBlocks[i] = &copiedBlock + } + copied.ContentBlocks = copiedBlocks + } + + return []*AgenticMessage{&copied}, nil +} + +func formatUserInputText(uit *UserInputText, vs map[string]any, formatType FormatType) (*UserInputText, error) { + text, err := formatContent(uit.Text, vs, formatType) + if err != nil { + return nil, err + } + copied := *uit + copied.Text = text + return &copied, nil +} + +func formatUserInputImage(uii *UserInputImage, vs map[string]any, formatType FormatType) (*UserInputImage, error) { + copied := *uii + if uii.URL != "" { + url, err := formatContent(uii.URL, vs, formatType) + if err != nil { + return nil, err + } + copied.URL = url + } + if uii.Base64Data != "" { + base64data, err := formatContent(uii.Base64Data, vs, formatType) + if err != nil { + return nil, err + } + copied.Base64Data = base64data + } + return &copied, nil +} + +func formatUserInputAudio(uia *UserInputAudio, vs map[string]any, formatType FormatType) (*UserInputAudio, error) { + copied := *uia + if uia.URL != "" { + url, err := formatContent(uia.URL, vs, formatType) + if err != nil { + return nil, err + } + copied.URL = url + } + if uia.Base64Data != "" { + base64data, err := formatContent(uia.Base64Data, vs, formatType) + if err != nil { + return nil, err + } + copied.Base64Data = base64data + } + return &copied, nil +} + +func formatUserInputVideo(uiv *UserInputVideo, vs map[string]any, formatType FormatType) (*UserInputVideo, error) { + copied := *uiv + if uiv.URL != "" { + url, err := formatContent(uiv.URL, vs, formatType) + if err != nil { + return nil, err + } + copied.URL = url + } + if uiv.Base64Data != "" { + base64data, err := formatContent(uiv.Base64Data, vs, formatType) + if err != nil { + return nil, err + } + copied.Base64Data = base64data + } + return &copied, nil +} + +func formatUserInputFile(uif *UserInputFile, vs map[string]any, formatType FormatType) (*UserInputFile, error) { + copied := *uif + if uif.URL != "" { + url, err := formatContent(uif.URL, vs, formatType) + if err != nil { + return nil, err + } + copied.URL = url + } + if uif.Name != "" { + name, err := formatContent(uif.Name, vs, formatType) + if err != nil { + return nil, err + } + copied.Name = name + } + if uif.Base64Data != "" { + base64data, err := formatContent(uif.Base64Data, vs, formatType) + if err != nil { + return nil, err + } + copied.Base64Data = base64data + } + return &copied, nil +} + +func ConcatAgenticMessagesArray(mas [][]*AgenticMessage) ([]*AgenticMessage, error) { + return buildConcatGenericArray[AgenticMessage](ConcatAgenticMessages)(mas) +} + +func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { + var ( + role AgenticRoleType + blocksList [][]*ContentBlock + blocks []*ContentBlock + metas []*AgenticResponseMeta + ) + + if len(msgs) == 1 { + return msgs[0], nil + } + + for idx, msg := range msgs { + if msg == nil { + return nil, fmt.Errorf("message at index %d is nil", idx) + } + + if msg.Role != "" { + if role == "" { + role = msg.Role + } else if role != msg.Role { + return nil, fmt.Errorf("cannot concat messages with different roles: got '%s' and '%s'", role, msg.Role) + } + } + + for _, block := range msg.ContentBlocks { + if block.StreamMeta == nil { + // Non-streaming block + if len(blocksList) > 0 { + // Cannot mix streaming and non-streaming blocks + return nil, fmt.Errorf("found non-streaming block after streaming blocks") + } + // Collect non-streaming block + blocks = append(blocks, block) + } else { + // Streaming block + if len(blocks) > 0 { + // Cannot mix non-streaming and streaming blocks + return nil, fmt.Errorf("found streaming block after non-streaming blocks") + } + // Collect streaming block by index + blocksList = expandSlice(int(block.StreamMeta.Index), blocksList) + blocksList[block.StreamMeta.Index] = append(blocksList[block.StreamMeta.Index], block) + } + } + + if msg.ResponseMeta != nil { + metas = append(metas, msg.ResponseMeta) + } + } + + meta, err := concatAgenticResponseMeta(metas) + if err != nil { + return nil, fmt.Errorf("failed to concat agentic response meta: %w", err) + } + + if len(blocksList) > 0 { + // All blocks are streaming, concat each group by index + blocks = make([]*ContentBlock, len(blocksList)) + for i, bs := range blocksList { + if len(bs) == 0 { + continue + } + b, err := concatAgenticContentBlocks(bs) + if err != nil { + return nil, fmt.Errorf("failed to concat content blocks at index %d: %w", i, err) + } + blocks[i] = b + } + } + + for i := 0; i < len(blocks); i++ { + if blocks[i] == nil { + blocks = append(blocks[:i], blocks[i+1:]...) + } + } + + return &AgenticMessage{ + ResponseMeta: meta, + Role: role, + ContentBlocks: blocks, + }, nil +} + +func concatAgenticResponseMeta(metas []*AgenticResponseMeta) (*AgenticResponseMeta, error) { + if len(metas) == 0 { + return nil, nil + } + ret := &AgenticResponseMeta{ + TokenUsage: &TokenUsage{}, + OpenAIExtension: nil, + ClaudeExtension: nil, + GeminiExtension: nil, + Extension: nil, + } + for _, meta := range metas { + ret.Extension = meta.Extension + ret.OpenAIExtension = meta.OpenAIExtension + ret.ClaudeExtension = meta.ClaudeExtension + ret.GeminiExtension = meta.GeminiExtension + if meta.TokenUsage != nil { + ret.TokenUsage.CompletionTokens += meta.TokenUsage.CompletionTokens + ret.TokenUsage.CompletionTokenDetails.ReasoningTokens += meta.TokenUsage.CompletionTokenDetails.ReasoningTokens + ret.TokenUsage.PromptTokens += meta.TokenUsage.PromptTokens + ret.TokenUsage.PromptTokenDetails.CachedTokens += meta.TokenUsage.PromptTokenDetails.CachedTokens + ret.TokenUsage.TotalTokens += meta.TokenUsage.TotalTokens + } + } + return ret, nil +} + +func concatAgenticContentBlocks(blocks []*ContentBlock) (*ContentBlock, error) { + if len(blocks) == 0 { + return nil, fmt.Errorf("no content blocks to concat") + } + blockType := blocks[0].Type + index := blocks[0].StreamMeta.Index + switch blockType { + case ContentBlockTypeReasoning: + return concatContentBlockHelper(blocks, blockType, "reasoning", + func(b *ContentBlock) *Reasoning { return b.Reasoning }, + concatReasoning, + func(r *Reasoning) *ContentBlock { + return &ContentBlock{Type: blockType, Reasoning: r, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeUserInputText: + return concatContentBlockHelper(blocks, blockType, "user input text", + func(b *ContentBlock) *UserInputText { return b.UserInputText }, + concatUserInputText, + func(t *UserInputText) *ContentBlock { + return &ContentBlock{Type: blockType, UserInputText: t, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeUserInputImage: + return concatContentBlockHelper(blocks, blockType, "user input image", + func(b *ContentBlock) *UserInputImage { return b.UserInputImage }, + concatUserInputImage, + func(i *UserInputImage) *ContentBlock { + return &ContentBlock{Type: blockType, UserInputImage: i, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeUserInputAudio: + return concatContentBlockHelper(blocks, blockType, "user input audio", + func(b *ContentBlock) *UserInputAudio { return b.UserInputAudio }, + concatUserInputAudio, + func(a *UserInputAudio) *ContentBlock { + return &ContentBlock{Type: blockType, UserInputAudio: a, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeUserInputVideo: + return concatContentBlockHelper(blocks, blockType, "user input video", + func(b *ContentBlock) *UserInputVideo { return b.UserInputVideo }, + concatUserInputVideo, + func(v *UserInputVideo) *ContentBlock { + return &ContentBlock{Type: blockType, UserInputVideo: v, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeUserInputFile: + return concatContentBlockHelper(blocks, blockType, "user input file", + func(b *ContentBlock) *UserInputFile { return b.UserInputFile }, + concatUserInputFile, + func(f *UserInputFile) *ContentBlock { + return &ContentBlock{Type: blockType, UserInputFile: f, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeAssistantGenText: + return concatContentBlockHelper(blocks, blockType, "assistant gen text", + func(b *ContentBlock) *AssistantGenText { return b.AssistantGenText }, + concatAssistantGenText, + func(t *AssistantGenText) *ContentBlock { + return &ContentBlock{Type: blockType, AssistantGenText: t, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeAssistantGenImage: + return concatContentBlockHelper(blocks, blockType, "assistant gen image", + func(b *ContentBlock) *AssistantGenImage { return b.AssistantGenImage }, + concatAssistantGenImage, + func(i *AssistantGenImage) *ContentBlock { + return &ContentBlock{Type: blockType, AssistantGenImage: i, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeAssistantGenAudio: + return concatContentBlockHelper(blocks, blockType, "assistant gen audio", + func(b *ContentBlock) *AssistantGenAudio { return b.AssistantGenAudio }, + concatAssistantGenAudio, + func(a *AssistantGenAudio) *ContentBlock { + return &ContentBlock{Type: blockType, AssistantGenAudio: a, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeAssistantGenVideo: + return concatContentBlockHelper(blocks, blockType, "assistant gen video", + func(b *ContentBlock) *AssistantGenVideo { return b.AssistantGenVideo }, + concatAssistantGenVideo, + func(v *AssistantGenVideo) *ContentBlock { + return &ContentBlock{Type: blockType, AssistantGenVideo: v, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeFunctionToolCall: + return concatContentBlockHelper(blocks, blockType, "function tool call", + func(b *ContentBlock) *FunctionToolCall { return b.FunctionToolCall }, + concatFunctionToolCall, + func(c *FunctionToolCall) *ContentBlock { + return &ContentBlock{Type: blockType, FunctionToolCall: c, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeFunctionToolResult: + return concatContentBlockHelper(blocks, blockType, "function tool result", + func(b *ContentBlock) *FunctionToolResult { return b.FunctionToolResult }, + concatFunctionToolResult, + func(r *FunctionToolResult) *ContentBlock { + return &ContentBlock{Type: blockType, FunctionToolResult: r, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeServerToolCall: + return concatContentBlockHelper(blocks, blockType, "server tool call", + func(b *ContentBlock) *ServerToolCall { return b.ServerToolCall }, + concatServerToolCall, + func(c *ServerToolCall) *ContentBlock { + return &ContentBlock{Type: blockType, ServerToolCall: c, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeServerToolResult: + return concatContentBlockHelper(blocks, blockType, "server tool result", + func(b *ContentBlock) *ServerToolResult { return b.ServerToolResult }, + concatServerToolResult, + func(r *ServerToolResult) *ContentBlock { + return &ContentBlock{Type: blockType, ServerToolResult: r, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeMCPToolCall: + return concatContentBlockHelper(blocks, blockType, "MCP tool call", + func(b *ContentBlock) *MCPToolCall { return b.MCPToolCall }, + concatMCPToolCall, + func(c *MCPToolCall) *ContentBlock { + return &ContentBlock{Type: blockType, MCPToolCall: c, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeMCPToolResult: + return concatContentBlockHelper(blocks, blockType, "MCP tool result", + func(b *ContentBlock) *MCPToolResult { return b.MCPToolResult }, + concatMCPToolResult, + func(r *MCPToolResult) *ContentBlock { + return &ContentBlock{Type: blockType, MCPToolResult: r, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeMCPListToolsResult: + return concatContentBlockHelper(blocks, blockType, "MCP list tools", + func(b *ContentBlock) *MCPListToolsResult { return b.MCPListToolsResult }, + concatMCPListToolsResult, + func(r *MCPListToolsResult) *ContentBlock { + return &ContentBlock{Type: blockType, MCPListToolsResult: r, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeMCPToolApprovalRequest: + return concatContentBlockHelper(blocks, blockType, "MCP tool approval request", + func(b *ContentBlock) *MCPToolApprovalRequest { return b.MCPToolApprovalRequest }, + concatMCPToolApprovalRequest, + func(r *MCPToolApprovalRequest) *ContentBlock { + return &ContentBlock{Type: blockType, MCPToolApprovalRequest: r, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeMCPToolApprovalResponse: + return concatContentBlockHelper(blocks, blockType, "MCP tool approval response", + func(b *ContentBlock) *MCPToolApprovalResponse { return b.MCPToolApprovalResponse }, + concatMCPToolApprovalResponse, + func(r *MCPToolApprovalResponse) *ContentBlock { + return &ContentBlock{Type: blockType, MCPToolApprovalResponse: r, StreamMeta: &StreamMeta{Index: index}} + }) + + default: + return nil, fmt.Errorf("unknown content block type: %s", blockType) + } +} + +// concatContentBlockHelper is a generic helper function that reduces code duplication +// for concatenating content blocks of a specific type. +func concatContentBlockHelper[T any]( + blocks []*ContentBlock, + expectedType ContentBlockType, + typeName string, + getter func(*ContentBlock) *T, + concatFunc func([]*T) (*T, error), + constructor func(*T) *ContentBlock, +) (*ContentBlock, error) { + items, err := genericGetTFromContentBlocks(blocks, func(block *ContentBlock) (*T, error) { + if block.Type != expectedType { + return nil, fmt.Errorf("expected %s block, got %s", typeName, block.Type) + } + item := getter(block) + if item == nil { + return nil, fmt.Errorf("%s content is nil", typeName) + } + return item, nil + }) + if err != nil { + return nil, err + } + + concatenated, err := concatFunc(items) + if err != nil { + return nil, err + } + + return constructor(concatenated), nil +} + +func genericGetTFromContentBlocks[T any](blocks []*ContentBlock, checkAndGetter func(block *ContentBlock) (T, error)) ([]T, error) { + ret := make([]T, 0, len(blocks)) + for _, block := range blocks { + t, err := checkAndGetter(block) + if err != nil { + return nil, err + } + ret = append(ret, t) + } + return ret, nil +} + +// Concatenation strategies for different content block types: +// +// String concatenation (incremental streaming): +// - Reasoning: Summary texts are concatenated, grouped by Index if present +// - UserInputText: Text fields are concatenated +// - AssistantGenText: Text fields are concatenated, annotations/citations are merged +// - FunctionToolCall: Arguments (JSON strings) are concatenated incrementally +// - FunctionToolResult: Result strings are concatenated +// - ServerToolCall: Arguments are merged (last non-nil value for any type) +// - ServerToolResult: Results are merged using internal.ConcatItems +// - MCPToolCall: Arguments (JSON strings) are concatenated incrementally +// - MCPToolResult: Result strings are concatenated +// - MCPListToolsResult: Tools arrays are merged +// - MCPToolApprovalRequest: Arguments are concatenated +// +// Take last block (non-streaming content): +// - UserInputImage, UserInputAudio, UserInputVideo, UserInputFile: Return last block +// - AssistantGenImage, AssistantGenAudio, AssistantGenVideo: Return last block +// - MCPToolApprovalResponse: Return last block +// + +func concatReasoning(reasons []*Reasoning) (*Reasoning, error) { + if len(reasons) == 0 { + return nil, fmt.Errorf("no reasoning found") + } + if len(reasons) == 1 { + return reasons[0], nil + } + + ret := &Reasoning{ + Summary: make([]*ReasoningSummary, 0), + EncryptedContent: "", + Extra: make(map[string]any), + } + + // Collect all summaries from all reasons + allSummaries := make([]*ReasoningSummary, 0) + for _, r := range reasons { + if r == nil { + continue + } + allSummaries = append(allSummaries, r.Summary...) + if r.EncryptedContent != "" { + ret.EncryptedContent += r.EncryptedContent + } + for k, v := range r.Extra { + ret.Extra[k] = v + } + } + + // Group by Index and concatenate Text for same Index + // Use dynamic array that expands as needed + var summaryArray []*ReasoningSummary + for _, s := range allSummaries { + idx := s.Index + // Expand array if needed + summaryArray = expandSlice(int(idx), summaryArray) + if summaryArray[idx] == nil { + // Create new entry with a copy of Index + summaryArray[idx] = &ReasoningSummary{ + Index: idx, + Text: s.Text, + } + } else { + // Concatenate text for same index + summaryArray[idx].Text += s.Text + } + } + + // Convert array to slice, filtering out nil entries + ret.Summary = make([]*ReasoningSummary, 0, len(summaryArray)) + for _, summary := range summaryArray { + if summary != nil { + ret.Summary = append(ret.Summary, summary) + } + } + + return ret, nil +} + +func concatUserInputText(texts []*UserInputText) (*UserInputText, error) { + if len(texts) == 0 { + return nil, fmt.Errorf("no user input text found") + } + if len(texts) == 1 { + return texts[0], nil + } + + ret := &UserInputText{ + Text: "", + Extra: make(map[string]any), + } + + for _, t := range texts { + if t == nil { + continue + } + ret.Text += t.Text + for k, v := range t.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatUserInputImage(images []*UserInputImage) (*UserInputImage, error) { + if len(images) == 0 { + return nil, fmt.Errorf("no user input image found") + } + return images[len(images)-1], nil +} + +func concatUserInputAudio(audios []*UserInputAudio) (*UserInputAudio, error) { + if len(audios) == 0 { + return nil, fmt.Errorf("no user input audio found") + } + return audios[len(audios)-1], nil +} + +func concatUserInputVideo(videos []*UserInputVideo) (*UserInputVideo, error) { + if len(videos) == 0 { + return nil, fmt.Errorf("no user input video found") + } + return videos[len(videos)-1], nil +} + +func concatUserInputFile(files []*UserInputFile) (*UserInputFile, error) { + if len(files) == 0 { + return nil, fmt.Errorf("no user input file found") + } + return files[len(files)-1], nil +} + +func concatAssistantGenText(texts []*AssistantGenText) (*AssistantGenText, error) { + if len(texts) == 0 { + return nil, fmt.Errorf("no assistant gen text found") + } + if len(texts) == 1 { + return texts[0], nil + } + + ret := &AssistantGenText{ + Text: "", + OpenAIExtension: nil, + ClaudeExtension: nil, + Extra: make(map[string]any), + } + + for _, t := range texts { + if t == nil { + continue + } + ret.Text += t.Text + if t.OpenAIExtension != nil { + if ret.OpenAIExtension == nil { + ret.OpenAIExtension = &openai.AssistantGenTextExtension{} + } + ret.OpenAIExtension.Annotations = append(ret.OpenAIExtension.Annotations, t.OpenAIExtension.Annotations...) + } + if t.ClaudeExtension != nil { + if ret.ClaudeExtension == nil { + ret.ClaudeExtension = &claude.AssistantGenTextExtension{} + } + ret.ClaudeExtension.Citations = append(ret.ClaudeExtension.Citations, t.ClaudeExtension.Citations...) + } + for k, v := range t.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatAssistantGenImage(images []*AssistantGenImage) (*AssistantGenImage, error) { + if len(images) == 0 { + return nil, fmt.Errorf("no assistant gen image found") + } + return images[len(images)-1], nil +} + +func concatAssistantGenAudio(audios []*AssistantGenAudio) (*AssistantGenAudio, error) { + if len(audios) == 0 { + return nil, fmt.Errorf("no assistant gen audio found") + } + return audios[len(audios)-1], nil +} + +func concatAssistantGenVideo(videos []*AssistantGenVideo) (*AssistantGenVideo, error) { + if len(videos) == 0 { + return nil, fmt.Errorf("no assistant gen video found") + } + return videos[len(videos)-1], nil +} + +func concatFunctionToolCall(calls []*FunctionToolCall) (*FunctionToolCall, error) { + if len(calls) == 0 { + return nil, fmt.Errorf("no function tool call found") + } + if len(calls) == 1 { + return calls[0], nil + } + + // For tool calls, arguments are typically built incrementally during streaming + ret := &FunctionToolCall{ + Extra: make(map[string]any), + } + + for _, c := range calls { + if c == nil { + continue + } + if ret.CallID == "" { + ret.CallID = c.CallID + } + if ret.Name == "" { + ret.Name = c.Name + } + ret.Arguments += c.Arguments + for k, v := range c.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatFunctionToolResult(results []*FunctionToolResult) (*FunctionToolResult, error) { + if len(results) == 0 { + return nil, fmt.Errorf("no function tool result found") + } + if len(results) == 1 { + return results[0], nil + } + + ret := &FunctionToolResult{ + Extra: make(map[string]any), + } + + for _, r := range results { + if r == nil { + continue + } + if ret.CallID == "" { + ret.CallID = r.CallID + } + if ret.Name == "" { + ret.Name = r.Name + } + ret.Result += r.Result + for k, v := range r.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatServerToolCall(calls []*ServerToolCall) (*ServerToolCall, error) { + if len(calls) == 0 { + return nil, fmt.Errorf("no server tool call found") + } + if len(calls) == 1 { + return calls[0], nil + } + + // ServerToolCall Arguments is of type any; merge strategy uses the last non-nil value + ret := &ServerToolCall{ + Extra: make(map[string]any), + } + + for _, c := range calls { + if c == nil { + continue + } + if ret.Name == "" { + ret.Name = c.Name + } + if ret.CallID == "" { + ret.CallID = c.CallID + } + if c.Arguments != nil { + ret.Arguments = c.Arguments + } + for k, v := range c.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatServerToolResult(results []*ServerToolResult) (*ServerToolResult, error) { + if len(results) == 0 { + return nil, fmt.Errorf("no server tool result found") + } + if len(results) == 1 { + return results[0], nil + } + + // ServerToolResult Result is of type any; merge strategy uses the last non-nil value + ret := &ServerToolResult{ + Extra: make(map[string]any), + } + + tZeroResult := reflect.TypeOf(results[0].Result) + data := reflect.MakeSlice(reflect.SliceOf(tZeroResult), 0, 0) + for _, r := range results { + if r == nil { + continue + } + if ret.Name == "" { + ret.Name = r.Name + } + if ret.CallID == "" { + ret.CallID = r.CallID + } + if r.Result != nil { + vResult := reflect.ValueOf(r.Result) + if tZeroResult != vResult.Type() { + return nil, fmt.Errorf("tool result types are different: %v %v", tZeroResult, vResult.Type()) + } + data = reflect.Append(data, vResult) + } + for k, v := range r.Extra { + ret.Extra[k] = v + } + } + + d, err := internal.ConcatSliceValue(data) + if err != nil { + return nil, fmt.Errorf("failed to concat server tool result: %v", err) + } + ret.Result = d + + return ret, nil +} + +func concatMCPToolCall(calls []*MCPToolCall) (*MCPToolCall, error) { + if len(calls) == 0 { + return nil, fmt.Errorf("no mcp tool call found") + } + if len(calls) == 1 { + return calls[0], nil + } + + ret := &MCPToolCall{ + Extra: make(map[string]any), + } + + for _, c := range calls { + if c == nil { + continue + } + if ret.ServerLabel == "" { + ret.ServerLabel = c.ServerLabel + } + if ret.ApprovalRequestID == "" { + ret.ApprovalRequestID = c.ApprovalRequestID + } + if ret.CallID == "" { + ret.CallID = c.CallID + } + if ret.Name == "" { + ret.Name = c.Name + } + ret.Arguments += c.Arguments + for k, v := range c.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatMCPToolResult(results []*MCPToolResult) (*MCPToolResult, error) { + if len(results) == 0 { + return nil, fmt.Errorf("no mcp tool result found") + } + if len(results) == 1 { + return results[0], nil + } + + ret := &MCPToolResult{ + Extra: make(map[string]any), + } + + for _, r := range results { + if r == nil { + continue + } + if ret.CallID == "" { + ret.CallID = r.CallID + } + if ret.Name == "" { + ret.Name = r.Name + } + ret.Result += r.Result + if r.Error != nil { + ret.Error = r.Error // Use the last error + } + for k, v := range r.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatMCPListToolsResult(results []*MCPListToolsResult) (*MCPListToolsResult, error) { + if len(results) == 0 { + return nil, fmt.Errorf("no mcp list tools result found") + } + if len(results) == 1 { + return results[0], nil + } + + ret := &MCPListToolsResult{ + Tools: make([]*MCPListToolsItem, 0), + Extra: make(map[string]any), + } + + for _, r := range results { + if r == nil { + continue + } + if ret.ServerLabel == "" { + ret.ServerLabel = r.ServerLabel + } + ret.Tools = append(ret.Tools, r.Tools...) + if r.Error != "" { + ret.Error = r.Error // Use the last error + } + for k, v := range r.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatMCPToolApprovalRequest(requests []*MCPToolApprovalRequest) (*MCPToolApprovalRequest, error) { + if len(requests) == 0 { + return nil, fmt.Errorf("no mcp tool approval request found") + } + if len(requests) == 1 { + return requests[0], nil + } + + ret := &MCPToolApprovalRequest{ + Extra: make(map[string]any), + } + + for _, r := range requests { + if r == nil { + continue + } + if ret.ID == "" { + ret.ID = r.ID + } + if ret.Name == "" { + ret.Name = r.Name + } + ret.Arguments += r.Arguments + if ret.ServerLabel == "" { + ret.ServerLabel = r.ServerLabel + } + for k, v := range r.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatMCPToolApprovalResponse(responses []*MCPToolApprovalResponse) (*MCPToolApprovalResponse, error) { + if len(responses) == 0 { + return nil, fmt.Errorf("no mcp tool approval response found") + } + if len(responses) == 1 { + return responses[0], nil + } + + return responses[len(responses)-1], nil +} + +func expandSlice[T any](idx int, s []T) []T { + if len(s) > idx { + return s + } + return append(s, make([]T, idx-len(s)+1)...) +} + +func (m *AgenticMessage) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf("role: %s\n", m.Role)) + + if len(m.ContentBlocks) > 0 { + sb.WriteString("content_blocks:\n") + for i, block := range m.ContentBlocks { + if block == nil { + continue + } + sb.WriteString(fmt.Sprintf(" [%d] %s", i, block.String())) + } + } + + if m.ResponseMeta != nil { + sb.WriteString(m.ResponseMeta.String()) + } + + return sb.String() +} + +func (b *ContentBlock) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf("type: %s\n", b.Type)) + + switch b.Type { + case ContentBlockTypeReasoning: + if b.Reasoning != nil { + sb.WriteString(b.Reasoning.String()) + } + case ContentBlockTypeUserInputText: + if b.UserInputText != nil { + sb.WriteString(b.UserInputText.String()) + } + case ContentBlockTypeUserInputImage: + if b.UserInputImage != nil { + sb.WriteString(b.UserInputImage.String()) + } + case ContentBlockTypeUserInputAudio: + if b.UserInputAudio != nil { + sb.WriteString(b.UserInputAudio.String()) + } + case ContentBlockTypeUserInputVideo: + if b.UserInputVideo != nil { + sb.WriteString(b.UserInputVideo.String()) + } + case ContentBlockTypeUserInputFile: + if b.UserInputFile != nil { + sb.WriteString(b.UserInputFile.String()) + } + case ContentBlockTypeAssistantGenText: + if b.AssistantGenText != nil { + sb.WriteString(b.AssistantGenText.String()) + } + case ContentBlockTypeAssistantGenImage: + if b.AssistantGenImage != nil { + sb.WriteString(b.AssistantGenImage.String()) + } + case ContentBlockTypeAssistantGenAudio: + if b.AssistantGenAudio != nil { + sb.WriteString(b.AssistantGenAudio.String()) + } + case ContentBlockTypeAssistantGenVideo: + if b.AssistantGenVideo != nil { + sb.WriteString(b.AssistantGenVideo.String()) + } + case ContentBlockTypeFunctionToolCall: + if b.FunctionToolCall != nil { + sb.WriteString(b.FunctionToolCall.String()) + } + case ContentBlockTypeFunctionToolResult: + if b.FunctionToolResult != nil { + sb.WriteString(b.FunctionToolResult.String()) + } + case ContentBlockTypeServerToolCall: + if b.ServerToolCall != nil { + sb.WriteString(b.ServerToolCall.String()) + } + case ContentBlockTypeServerToolResult: + if b.ServerToolResult != nil { + sb.WriteString(b.ServerToolResult.String()) + } + case ContentBlockTypeMCPToolCall: + if b.MCPToolCall != nil { + sb.WriteString(b.MCPToolCall.String()) + } + case ContentBlockTypeMCPToolResult: + if b.MCPToolResult != nil { + sb.WriteString(b.MCPToolResult.String()) + } + case ContentBlockTypeMCPListToolsResult: + if b.MCPListToolsResult != nil { + sb.WriteString(b.MCPListToolsResult.String()) + } + case ContentBlockTypeMCPToolApprovalRequest: + if b.MCPToolApprovalRequest != nil { + sb.WriteString(b.MCPToolApprovalRequest.String()) + } + case ContentBlockTypeMCPToolApprovalResponse: + if b.MCPToolApprovalResponse != nil { + sb.WriteString(b.MCPToolApprovalResponse.String()) + } + } + + if b.StreamMeta != nil { + sb.WriteString(fmt.Sprintf(" stream_index: %d\n", b.StreamMeta.Index)) + } + + return sb.String() +} + +func (r *Reasoning) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" summary: %d items\n", len(r.Summary))) + for _, s := range r.Summary { + sb.WriteString(fmt.Sprintf(" [%d] %s\n", s.Index, s.Text)) + } + if r.EncryptedContent != "" { + sb.WriteString(fmt.Sprintf(" encrypted_content: %s\n", truncateString(r.EncryptedContent, 50))) + } + return sb.String() +} + +func (u *UserInputText) String() string { + return fmt.Sprintf(" text: %s\n", u.Text) +} + +func (u *UserInputImage) String() string { + return formatMediaString(u.URL, u.Base64Data, u.MIMEType, u.Detail) +} + +func (u *UserInputAudio) String() string { + return formatMediaString(u.URL, u.Base64Data, u.MIMEType, "") +} + +func (u *UserInputVideo) String() string { + return formatMediaString(u.URL, u.Base64Data, u.MIMEType, "") +} + +func (u *UserInputFile) String() string { + sb := &strings.Builder{} + if u.Name != "" { + sb.WriteString(fmt.Sprintf(" name: %s\n", u.Name)) + } + sb.WriteString(formatMediaString(u.URL, u.Base64Data, u.MIMEType, "")) + return sb.String() +} + +func (a *AssistantGenText) String() string { + return fmt.Sprintf(" text: %s\n", a.Text) +} + +func (a *AssistantGenImage) String() string { + return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") +} + +func (a *AssistantGenAudio) String() string { + return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") +} + +func (a *AssistantGenVideo) String() string { + return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") +} + +func (f *FunctionToolCall) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" call_id: %s\n", f.CallID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", f.Name)) + sb.WriteString(fmt.Sprintf(" arguments: %s\n", f.Arguments)) + return sb.String() +} + +func (f *FunctionToolResult) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" call_id: %s\n", f.CallID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", f.Name)) + sb.WriteString(fmt.Sprintf(" result: %s\n", f.Result)) + return sb.String() +} + +func (s *ServerToolCall) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" name: %s\n", s.Name)) + if s.CallID != "" { + sb.WriteString(fmt.Sprintf(" call_id: %s\n", s.CallID)) + } + sb.WriteString(fmt.Sprintf(" arguments: %v\n", s.Arguments)) + return sb.String() +} + +func (s *ServerToolResult) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" name: %s\n", s.Name)) + if s.CallID != "" { + sb.WriteString(fmt.Sprintf(" call_id: %s\n", s.CallID)) + } + sb.WriteString(fmt.Sprintf(" result: %v\n", s.Result)) + return sb.String() +} + +func (m *MCPToolCall) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) + sb.WriteString(fmt.Sprintf(" call_id: %s\n", m.CallID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", m.Name)) + sb.WriteString(fmt.Sprintf(" arguments: %s\n", m.Arguments)) + if m.ApprovalRequestID != "" { + sb.WriteString(fmt.Sprintf(" approval_request_id: %s\n", m.ApprovalRequestID)) + } + return sb.String() +} + +func (m *MCPToolResult) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" call_id: %s\n", m.CallID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", m.Name)) + sb.WriteString(fmt.Sprintf(" result: %s\n", m.Result)) + if m.Error != nil { + sb.WriteString(fmt.Sprintf(" error: [%d] %s\n", *m.Error.Code, m.Error.Message)) + } + return sb.String() +} + +func (m *MCPListToolsResult) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) + sb.WriteString(fmt.Sprintf(" tools: %d items\n", len(m.Tools))) + for _, tool := range m.Tools { + sb.WriteString(fmt.Sprintf(" - %s: %s\n", tool.Name, tool.Description)) + } + if m.Error != "" { + sb.WriteString(fmt.Sprintf(" error: %s\n", m.Error)) + } + return sb.String() +} + +func (m *MCPToolApprovalRequest) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) + sb.WriteString(fmt.Sprintf(" id: %s\n", m.ID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", m.Name)) + sb.WriteString(fmt.Sprintf(" arguments: %s\n", m.Arguments)) + return sb.String() +} + +func (m *MCPToolApprovalResponse) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" approval_request_id: %s\n", m.ApprovalRequestID)) + sb.WriteString(fmt.Sprintf(" approve: %v\n", m.Approve)) + if m.Reason != "" { + sb.WriteString(fmt.Sprintf(" reason: %s\n", m.Reason)) + } + return sb.String() +} + +func (a *AgenticResponseMeta) String() string { + sb := &strings.Builder{} + sb.WriteString("response_meta:\n") + if a.TokenUsage != nil { + sb.WriteString(fmt.Sprintf(" token_usage: prompt=%d, completion=%d, total=%d\n", + a.TokenUsage.PromptTokens, + a.TokenUsage.CompletionTokens, + a.TokenUsage.TotalTokens)) + } + return sb.String() +} + +// truncateString truncates a string to maxLen characters, adding "..." if truncated +func truncateString(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." +} + +// formatMediaString formats URL, Base64Data, MIMEType and Detail for media content +func formatMediaString(url, base64Data string, mimeType string, detail any) string { + sb := &strings.Builder{} + if url != "" { + sb.WriteString(fmt.Sprintf(" url: %s\n", truncateString(url, 100))) + } + if base64Data != "" { + // Only show first few characters of base64 data + sb.WriteString(fmt.Sprintf(" base64_data: %s... (%d bytes)\n", truncateString(base64Data, 20), len(base64Data))) + } + if mimeType != "" { + sb.WriteString(fmt.Sprintf(" mime_type: %s\n", mimeType)) + } + if detail != nil && detail != "" { + sb.WriteString(fmt.Sprintf(" detail: %v\n", detail)) + } + return sb.String() +} diff --git a/schema/agentic_message_test.go b/schema/agentic_message_test.go new file mode 100644 index 000000000..0cafcd9ff --- /dev/null +++ b/schema/agentic_message_test.go @@ -0,0 +1,1381 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package schema + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConcatAgenticMessages(t *testing.T) { + t.Run("single message", func(t *testing.T) { + msg := &AgenticMessage{ + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Hello", + }, + }, + }, + } + + result, err := ConcatAgenticMessages([]*AgenticMessage{msg}) + assert.NoError(t, err) + assert.Equal(t, msg, result) + }) + + t.Run("nil message in stream", func(t *testing.T) { + msgs := []*AgenticMessage{ + {Role: AgenticRoleTypeAssistant}, + nil, + {Role: AgenticRoleTypeAssistant}, + } + + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "message at index 1 is nil") + }) + + t.Run("different roles", func(t *testing.T) { + msgs := []*AgenticMessage{ + {Role: AgenticRoleTypeUser}, + {Role: AgenticRoleTypeAssistant}, + } + + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat messages with different roles") + }) + + t.Run("concat text blocks", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Hello ", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "World!", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Equal(t, AgenticRoleTypeAssistant, result.Role) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "Hello World!", result.ContentBlocks[0].AssistantGenText.Text) + }) + + t.Run("concat reasoning with nil index", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Summary: []*ReasoningSummary{ + {Index: 0, Text: "First "}, + }, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Summary: []*ReasoningSummary{ + {Index: 0, Text: "Second"}, + }, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Len(t, result.ContentBlocks[0].Reasoning.Summary, 1) + assert.Equal(t, "First Second", result.ContentBlocks[0].Reasoning.Summary[0].Text) + assert.Equal(t, int64(0), result.ContentBlocks[0].Reasoning.Summary[0].Index) + }) + + t.Run("concat reasoning with index", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Summary: []*ReasoningSummary{ + {Index: 0, Text: "Part1-"}, + {Index: 1, Text: "Part2-"}, + }, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Summary: []*ReasoningSummary{ + {Index: 0, Text: "Part3"}, + {Index: 1, Text: "Part4"}, + }, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Len(t, result.ContentBlocks[0].Reasoning.Summary, 2) + assert.Equal(t, "Part1-Part3", result.ContentBlocks[0].Reasoning.Summary[0].Text) + assert.Equal(t, "Part2-Part4", result.ContentBlocks[0].Reasoning.Summary[1].Text) + }) + + t.Run("concat user input text", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{ + Text: "Hello ", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{ + Text: "World!", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "Hello World!", result.ContentBlocks[0].UserInputText.Text) + }) + + t.Run("concat user input image", func(t *testing.T) { + url1 := "https://example.com/image1.jpg" + url2 := "https://example.com/image2.jpg" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputImage, + UserInputImage: &UserInputImage{ + URL: url1, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputImage, + UserInputImage: &UserInputImage{ + URL: url2, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + // Should take the last image + assert.Equal(t, url2, result.ContentBlocks[0].UserInputImage.URL) + }) + + t.Run("concat user input audio", func(t *testing.T) { + url1 := "https://example.com/audio1.mp3" + url2 := "https://example.com/audio2.mp3" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: url1, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: url2, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + // Should take the last audio + assert.Equal(t, url2, result.ContentBlocks[0].UserInputAudio.URL) + }) + + t.Run("concat user input video", func(t *testing.T) { + url1 := "https://example.com/video1.mp4" + url2 := "https://example.com/video2.mp4" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: url1, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: url2, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + // Should take the last video + assert.Equal(t, url2, result.ContentBlocks[0].UserInputVideo.URL) + }) + + t.Run("concat assistant gen text", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Generated ", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Text", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "Generated Text", result.ContentBlocks[0].AssistantGenText.Text) + }) + + t.Run("concat assistant gen image", func(t *testing.T) { + url1 := "https://example.com/gen_image1.jpg" + url2 := "https://example.com/gen_image2.jpg" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + URL: url1, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + URL: url2, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + // Should take the last image + assert.Equal(t, url2, result.ContentBlocks[0].AssistantGenImage.URL) + }) + + t.Run("concat assistant gen audio", func(t *testing.T) { + url1 := "https://example.com/gen_audio1.mp3" + url2 := "https://example.com/gen_audio2.mp3" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenAudio, + AssistantGenAudio: &AssistantGenAudio{ + URL: url1, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenAudio, + AssistantGenAudio: &AssistantGenAudio{ + URL: url2, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + // Should take the last audio + assert.Equal(t, url2, result.ContentBlocks[0].AssistantGenAudio.URL) + }) + + t.Run("concat assistant gen video", func(t *testing.T) { + url1 := "https://example.com/gen_video1.mp4" + url2 := "https://example.com/gen_video2.mp4" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenVideo, + AssistantGenVideo: &AssistantGenVideo{ + URL: url1, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenVideo, + AssistantGenVideo: &AssistantGenVideo{ + URL: url2, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + // Should take the last video + assert.Equal(t, url2, result.ContentBlocks[0].AssistantGenVideo.URL) + }) + + t.Run("concat function tool call", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + CallID: "call_123", + Name: "get_weather", + Arguments: `{"location`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + Arguments: `":"NYC"}`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "call_123", result.ContentBlocks[0].FunctionToolCall.CallID) + assert.Equal(t, "get_weather", result.ContentBlocks[0].FunctionToolCall.Name) + assert.Equal(t, `{"location":"NYC"}`, result.ContentBlocks[0].FunctionToolCall.Arguments) + }) + + t.Run("concat function tool result", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeFunctionToolResult, + FunctionToolResult: &FunctionToolResult{ + CallID: "call_123", + Name: "get_weather", + Result: `{"temp`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeFunctionToolResult, + FunctionToolResult: &FunctionToolResult{ + Result: `":72}`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "call_123", result.ContentBlocks[0].FunctionToolResult.CallID) + assert.Equal(t, "get_weather", result.ContentBlocks[0].FunctionToolResult.Name) + assert.Equal(t, `{"temp":72}`, result.ContentBlocks[0].FunctionToolResult.Result) + }) + + t.Run("concat server tool call", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeServerToolCall, + ServerToolCall: &ServerToolCall{ + CallID: "server_call_1", + Name: "server_func", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeServerToolCall, + ServerToolCall: &ServerToolCall{ + Arguments: map[string]any{"key": "value"}, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "server_call_1", result.ContentBlocks[0].ServerToolCall.CallID) + assert.Equal(t, "server_func", result.ContentBlocks[0].ServerToolCall.Name) + assert.NotNil(t, result.ContentBlocks[0].ServerToolCall.Arguments) + }) + + t.Run("concat server tool result", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeServerToolResult, + ServerToolResult: &ServerToolResult{ + CallID: "server_call_1", + Name: "server_func", + Result: "result1", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeServerToolResult, + ServerToolResult: &ServerToolResult{ + Result: "result2", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "server_call_1", result.ContentBlocks[0].ServerToolResult.CallID) + assert.Equal(t, "server_func", result.ContentBlocks[0].ServerToolResult.Name) + }) + + t.Run("concat mcp tool call", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + ServerLabel: "mcp-server", + CallID: "mcp_call_1", + Name: "mcp_func", + Arguments: `{"arg`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + Arguments: `":123}`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPToolCall.ServerLabel) + assert.Equal(t, "mcp_call_1", result.ContentBlocks[0].MCPToolCall.CallID) + assert.Equal(t, "mcp_func", result.ContentBlocks[0].MCPToolCall.Name) + assert.Equal(t, `{"arg":123}`, result.ContentBlocks[0].MCPToolCall.Arguments) + }) + + t.Run("concat mcp tool result", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolResult, + MCPToolResult: &MCPToolResult{ + CallID: "mcp_call_1", + Name: "mcp_func", + Result: `{"res`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolResult, + MCPToolResult: &MCPToolResult{ + Result: `ult":true}`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp_call_1", result.ContentBlocks[0].MCPToolResult.CallID) + assert.Equal(t, "mcp_func", result.ContentBlocks[0].MCPToolResult.Name) + assert.Equal(t, `{"result":true}`, result.ContentBlocks[0].MCPToolResult.Result) + }) + + t.Run("concat mcp list tools", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPListToolsResult, + MCPListToolsResult: &MCPListToolsResult{ + ServerLabel: "mcp-server", + Tools: []*MCPListToolsItem{ + {Name: "tool1"}, + }, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPListToolsResult, + MCPListToolsResult: &MCPListToolsResult{ + Tools: []*MCPListToolsItem{ + {Name: "tool2"}, + }, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPListToolsResult.ServerLabel) + assert.Len(t, result.ContentBlocks[0].MCPListToolsResult.Tools, 2) + }) + + t.Run("concat mcp tool approval request", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolApprovalRequest, + MCPToolApprovalRequest: &MCPToolApprovalRequest{ + ID: "approval_1", + Name: "approval_func", + ServerLabel: "mcp-server", + Arguments: `{"request`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolApprovalRequest, + MCPToolApprovalRequest: &MCPToolApprovalRequest{ + Arguments: `":1}`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "approval_1", result.ContentBlocks[0].MCPToolApprovalRequest.ID) + assert.Equal(t, "approval_func", result.ContentBlocks[0].MCPToolApprovalRequest.Name) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPToolApprovalRequest.ServerLabel) + assert.Equal(t, `{"request":1}`, result.ContentBlocks[0].MCPToolApprovalRequest.Arguments) + }) + + t.Run("concat mcp tool approval response", func(t *testing.T) { + response1 := &MCPToolApprovalResponse{ + ApprovalRequestID: "approval_1", + Approve: false, + } + response2 := &MCPToolApprovalResponse{ + ApprovalRequestID: "approval_1", + Approve: true, + } + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolApprovalResponse, + MCPToolApprovalResponse: response1, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolApprovalResponse, + MCPToolApprovalResponse: response2, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + // Should take the last response + assert.Equal(t, response2, result.ContentBlocks[0].MCPToolApprovalResponse) + }) + + t.Run("concat response meta", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ResponseMeta: &AgenticResponseMeta{ + TokenUsage: &TokenUsage{ + PromptTokens: 10, + CompletionTokens: 5, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ResponseMeta: &AgenticResponseMeta{ + TokenUsage: &TokenUsage{ + PromptTokens: 10, + CompletionTokens: 15, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.NotNil(t, result.ResponseMeta) + assert.Equal(t, 20, result.ResponseMeta.TokenUsage.CompletionTokens) + assert.Equal(t, 20, result.ResponseMeta.TokenUsage.PromptTokens) + }) + + t.Run("mixed streaming and non-streaming blocks error", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Hello", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "World", + }, + // No StreamMeta - non-streaming + }, + }, + }, + } + + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "found non-streaming block after streaming blocks") + }) + + t.Run("concat MCP tool call", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + ServerLabel: "mcp-server", + CallID: "call_456", + Name: "list_files", + Arguments: `{"path`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + Arguments: `":"/tmp"}`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPToolCall.ServerLabel) + assert.Equal(t, "call_456", result.ContentBlocks[0].MCPToolCall.CallID) + assert.Equal(t, `{"path":"/tmp"}`, result.ContentBlocks[0].MCPToolCall.Arguments) + }) + + t.Run("concat user input text", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{ + Text: "What is ", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{ + Text: "the weather?", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "What is the weather?", result.ContentBlocks[0].UserInputText.Text) + }) + + t.Run("multiple stream indexes - sparse indexes", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Index0-", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Index2-", + }, + StreamMeta: &StreamMeta{Index: 2}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Part2", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Part2", + }, + StreamMeta: &StreamMeta{Index: 2}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 2) + assert.Equal(t, "Index0-Part2", result.ContentBlocks[0].AssistantGenText.Text) + assert.Equal(t, "Index2-Part2", result.ContentBlocks[1].AssistantGenText.Text) + }) + + t.Run("multiple stream indexes - mixed content types", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Text ", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + CallID: "call_1", + Name: "func1", + Arguments: `{"a`, + }, + StreamMeta: &StreamMeta{Index: 1}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Content", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + Arguments: `":1}`, + }, + StreamMeta: &StreamMeta{Index: 1}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 2) + assert.Equal(t, "Text Content", result.ContentBlocks[0].AssistantGenText.Text) + assert.Equal(t, "call_1", result.ContentBlocks[1].FunctionToolCall.CallID) + assert.Equal(t, "func1", result.ContentBlocks[1].FunctionToolCall.Name) + assert.Equal(t, `{"a":1}`, result.ContentBlocks[1].FunctionToolCall.Arguments) + }) + + t.Run("multiple stream indexes - three indexes", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "A", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "B", + }, + StreamMeta: &StreamMeta{Index: 1}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "C", + }, + StreamMeta: &StreamMeta{Index: 2}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "1", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "2", + }, + StreamMeta: &StreamMeta{Index: 1}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "3", + }, + StreamMeta: &StreamMeta{Index: 2}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 3) + assert.Equal(t, "A1", result.ContentBlocks[0].AssistantGenText.Text) + assert.Equal(t, "B2", result.ContentBlocks[1].AssistantGenText.Text) + assert.Equal(t, "C3", result.ContentBlocks[2].AssistantGenText.Text) + }) +} + +func TestAgenticMessageFormat(t *testing.T) { + m := &AgenticMessage{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{Text: "{a}"}, + }, + { + Type: ContentBlockTypeUserInputImage, + UserInputImage: &UserInputImage{ + URL: "{b}", + Base64Data: "{c}", + }, + }, + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: "{d}", + Base64Data: "{e}", + }, + }, + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: "{f}", + Base64Data: "{g}", + }, + }, + { + Type: ContentBlockTypeUserInputFile, + UserInputFile: &UserInputFile{ + URL: "{h}", + Base64Data: "{i}", + }, + }, + }, + } + + result, err := m.Format(context.Background(), map[string]any{ + "a": "1", "b": "2", "c": "3", "d": "4", "e": "5", "f": "6", "g": "7", "h": "8", "i": "9", + }, FString) + assert.NoError(t, err) + assert.Equal(t, []*AgenticMessage{{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{Text: "1"}, + }, + { + Type: ContentBlockTypeUserInputImage, + UserInputImage: &UserInputImage{ + URL: "2", + Base64Data: "3", + }, + }, + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: "4", + Base64Data: "5", + }, + }, + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: "6", + Base64Data: "7", + }, + }, + { + Type: ContentBlockTypeUserInputFile, + UserInputFile: &UserInputFile{ + URL: "8", + Base64Data: "9", + }, + }, + }, + }}, result) +} + +func TestAgenticPlaceholderFormat(t *testing.T) { + ctx := context.Background() + ph := AgenticMessagesPlaceholder("a", false) + + result, err := ph.Format(ctx, map[string]any{ + "a": []*AgenticMessage{{Role: AgenticRoleTypeUser}, {Role: AgenticRoleTypeUser}}, + }, FString) + assert.NoError(t, err) + assert.Equal(t, 2, len(result)) + + ph = AgenticMessagesPlaceholder("a", true) + + result, err = ph.Format(ctx, map[string]any{}, FString) + assert.NoError(t, err) + assert.Equal(t, 0, len(result)) +} + +func ptrOf[T any](v T) *T { + return &v +} + +func TestAgenticMessageString(t *testing.T) { + longBase64 := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + + msg := &AgenticMessage{ + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{ + Text: "What's the weather like in New York City today?", + }, + }, + { + Type: ContentBlockTypeUserInputImage, + UserInputImage: &UserInputImage{ + URL: "https://example.com/weather-map.jpg", + Base64Data: longBase64, + MIMEType: "image/jpeg", + Detail: ImageURLDetailHigh, + }, + }, + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Summary: []*ReasoningSummary{ + {Index: 0, Text: "First, I need to identify the location (New York City) from the user's query."}, + {Index: 1, Text: "Then, I should call the weather API to get current conditions."}, + {Index: 2, Text: "Finally, I'll format the response in a user-friendly way with temperature and conditions."}, + }, + EncryptedContent: "encrypted_reasoning_content_that_is_very_long_and_will_be_truncated_for_display", + }, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "I'll check the current weather in New York City for you.", + }, + }, + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + CallID: "call_weather_123", + Name: "get_current_weather", + Arguments: `{"location":"New York City","unit":"fahrenheit"}`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + { + Type: ContentBlockTypeFunctionToolResult, + FunctionToolResult: &FunctionToolResult{ + CallID: "call_weather_123", + Name: "get_current_weather", + Result: `{"temperature":72,"condition":"sunny","humidity":45,"wind_speed":8}`, + }, + }, + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + ServerLabel: "weather-mcp-server", + CallID: "mcp_forecast_456", + Name: "get_7day_forecast", + Arguments: `{"city":"New York","days":7}`, + ApprovalRequestID: "approval_req_789", + }, + }, + { + Type: ContentBlockTypeMCPToolResult, + MCPToolResult: &MCPToolResult{ + CallID: "mcp_forecast_456", + Name: "get_7day_forecast", + Result: `{"status":"partial","days_available":3}`, + Error: &MCPToolCallError{ + Code: ptrOf[int64](503), + Message: "Service temporarily unavailable for full 7-day forecast", + }, + }, + }, + { + Type: ContentBlockTypeMCPListToolsResult, + MCPListToolsResult: &MCPListToolsResult{ + ServerLabel: "weather-mcp-server", + Tools: []*MCPListToolsItem{ + {Name: "get_current_weather", Description: "Get current weather conditions for a location"}, + {Name: "get_7day_forecast", Description: "Get 7-day weather forecast"}, + {Name: "get_weather_alerts", Description: "Get active weather alerts and warnings"}, + }, + }, + }, + }, + ResponseMeta: &AgenticResponseMeta{ + TokenUsage: &TokenUsage{ + PromptTokens: 250, + CompletionTokens: 180, + TotalTokens: 430, + }, + }, + } + + // Print the formatted output + output := msg.String() + + assert.Equal(t, `role: assistant +content_blocks: + [0] type: user_input_text + text: What's the weather like in New York City today? + [1] type: user_input_image + url: https://example.com/weather-map.jpg + base64_data: iVBORw0KGgoAAAANSUhE...... (96 bytes) + mime_type: image/jpeg + detail: high + [2] type: reasoning + summary: 3 items + [0] First, I need to identify the location (New York City) from the user's query. + [1] Then, I should call the weather API to get current conditions. + [2] Finally, I'll format the response in a user-friendly way with temperature and conditions. + encrypted_content: encrypted_reasoning_content_that_is_very_long_and_... + [3] type: assistant_gen_text + text: I'll check the current weather in New York City for you. + [4] type: function_tool_call + call_id: call_weather_123 + name: get_current_weather + arguments: {"location":"New York City","unit":"fahrenheit"} + stream_index: 0 + [5] type: function_tool_result + call_id: call_weather_123 + name: get_current_weather + result: {"temperature":72,"condition":"sunny","humidity":45,"wind_speed":8} + [6] type: mcp_tool_call + server_label: weather-mcp-server + call_id: mcp_forecast_456 + name: get_7day_forecast + arguments: {"city":"New York","days":7} + approval_request_id: approval_req_789 + [7] type: mcp_tool_result + call_id: mcp_forecast_456 + name: get_7day_forecast + result: {"status":"partial","days_available":3} + error: [503] Service temporarily unavailable for full 7-day forecast + [8] type: mcp_list_tools_result + server_label: weather-mcp-server + tools: 3 items + - get_current_weather: Get current weather conditions for a location + - get_7day_forecast: Get 7-day weather forecast + - get_weather_alerts: Get active weather alerts and warnings +response_meta: + token_usage: prompt=250, completion=180, total=430 +`, output) +} diff --git a/schema/message.go b/schema/message.go index fefb2079e..02cdefaad 100644 --- a/schema/message.go +++ b/schema/message.go @@ -40,47 +40,56 @@ func init() { internal.RegisterStreamChunkConcatFunc(ConcatMessages) internal.RegisterStreamChunkConcatFunc(ConcatMessageArray) + internal.RegisterStreamChunkConcatFunc(ConcatAgenticMessages) + internal.RegisterStreamChunkConcatFunc(ConcatAgenticMessagesArray) + internal.RegisterStreamChunkConcatFunc(ConcatToolResults) } -// ConcatMessageArray merges aligned slices of messages into a single slice, -// concatenating messages at the same index across the input arrays. -func ConcatMessageArray(mas [][]*Message) ([]*Message, error) { - arrayLen := len(mas[0]) +func buildConcatGenericArray[T any](f func([]*T) (*T, error)) func([][]*T) ([]*T, error) { + return func(mas [][]*T) ([]*T, error) { + arrayLen := len(mas[0]) - ret := make([]*Message, arrayLen) - slicesToConcat := make([][]*Message, arrayLen) + ret := make([]*T, arrayLen) + slicesToConcat := make([][]*T, arrayLen) - for _, ma := range mas { - if len(ma) != arrayLen { - return nil, fmt.Errorf("unexpected array length. "+ - "Got %d, expected %d", len(ma), arrayLen) - } + for _, ma := range mas { + if len(ma) != arrayLen { + return nil, fmt.Errorf("unexpected array length. "+ + "Got %d, expected %d", len(ma), arrayLen) + } - for i := 0; i < arrayLen; i++ { - m := ma[i] - if m != nil { - slicesToConcat[i] = append(slicesToConcat[i], m) + for i := 0; i < arrayLen; i++ { + m := ma[i] + if m != nil { + slicesToConcat[i] = append(slicesToConcat[i], m) + } } } - } - for i, slice := range slicesToConcat { - if len(slice) == 0 { - ret[i] = nil - } else if len(slice) == 1 { - ret[i] = slice[0] - } else { - cm, err := ConcatMessages(slice) - if err != nil { - return nil, err - } + for i, slice := range slicesToConcat { + if len(slice) == 0 { + ret[i] = nil + } else if len(slice) == 1 { + ret[i] = slice[0] + } else { + cm, err := f(slice) + if err != nil { + return nil, err + } - ret[i] = cm + ret[i] = cm + } } + + return ret, nil } +} - return ret, nil +// ConcatMessageArray merges aligned slices of messages into a single slice, +// concatenating messages at the same index across the input arrays. +func ConcatMessageArray(mas [][]*Message) ([]*Message, error) { + return buildConcatGenericArray[Message](ConcatMessages)(mas) } // FormatType used by MessageTemplate.Format @@ -721,7 +730,7 @@ var _ MessagesTemplate = MessagesPlaceholder("", false) // e.g. // // chatTemplate := prompt.FromMessages( -// schema.SystemMessage("you are eino helper"), +// schema.SystemMessage("you are an eino helper"), // schema.MessagesPlaceholder("history", false), // <= this will use the value of "history" in params // ) // msgs, err := chatTemplate.Format(ctx, params) @@ -739,7 +748,7 @@ type messagesPlaceholder struct { // // placeholder := MessagesPlaceholder("history", false) // params := map[string]any{ -// "history": []*schema.Message{{Role: "user", Content: "what is eino?"}, {Role: "assistant", Content: "eino is a great freamwork to build llm apps"}}, +// "history": []*schema.Message{{Role: "user", Content: "what is eino?"}, {Role: "assistant", Content: "eino is a great framework to build llm apps"}}, // "query": "how to use eino?", // } // chatTemplate := chatTpl := prompt.FromMessages( From 5af647532eb9df9e1cfcdb3cb3b423e4191d1284 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Tue, 6 Jan 2026 16:48:56 +0800 Subject: [PATCH 16/65] fix: concat agentic messages (#604) --- components/agentic/callback_extra.go | 5 +- components/agentic/option.go | 7 +- components/agentic/option_test.go | 2 +- components/model/callback_extra.go | 6 +- components/prompt/callback_extra.go | 2 + components/types.go | 2 + compose/tools_node_agentic.go | 7 +- compose/tools_node_agentic_test.go | 37 +- schema/agentic_message.go | 1052 ++++++++++------- schema/agentic_message_test.go | 552 ++++++--- schema/claude/consts.go | 1 + .../claude/{content_block.go => extension.go} | 70 +- schema/claude/extension_test.go | 190 +++ schema/claude/response_meta.go | 22 - .../gemini/{response_meta.go => extension.go} | 43 +- schema/gemini/extension_test.go | 79 ++ schema/message.go | 4 +- schema/openai/consts.go | 69 ++ schema/openai/content_block.go | 75 -- schema/openai/extension.go | 206 ++++ schema/openai/extension_test.go | 193 +++ schema/openai/response_meta.go | 40 - schema/tool.go | 21 + 23 files changed, 1942 insertions(+), 743 deletions(-) rename schema/claude/{content_block.go => extension.go} (51%) create mode 100644 schema/claude/extension_test.go delete mode 100644 schema/claude/response_meta.go rename schema/gemini/{response_meta.go => extension.go} (76%) create mode 100644 schema/gemini/extension_test.go delete mode 100644 schema/openai/content_block.go create mode 100644 schema/openai/extension.go create mode 100644 schema/openai/extension_test.go delete mode 100644 schema/openai/response_meta.go diff --git a/components/agentic/callback_extra.go b/components/agentic/callback_extra.go index 389408d33..2c5a656fa 100644 --- a/components/agentic/callback_extra.go +++ b/components/agentic/callback_extra.go @@ -14,6 +14,7 @@ * limitations under the License. */ +// Package agentic defines callback payloads and configuration types for agentic models. package agentic import ( @@ -26,9 +27,9 @@ type Config struct { // Model is the model name. Model string // Temperature is the temperature, which controls the randomness of the model. - Temperature float32 + Temperature float64 // TopP is the top p, which controls the diversity of the model. - TopP float32 + TopP float64 } // CallbackInput is the input for the model callback. diff --git a/components/agentic/option.go b/components/agentic/option.go index ac117ddb4..d8873442a 100644 --- a/components/agentic/option.go +++ b/components/agentic/option.go @@ -30,8 +30,10 @@ type Options struct { TopP *float64 // Tools is a list of tools the model may call. Tools []*schema.ToolInfo - // ToolChoice controls which tool is called by the model. + // ToolChoice controls how the model call the tools. ToolChoice *schema.ToolChoice + // AllowedTools is a list of allowed tools the model may call. + AllowedTools []*schema.AllowedTool } // Option is the call option for ChatModel component. @@ -81,10 +83,11 @@ func WithTools(tools []*schema.ToolInfo) Option { } // WithToolChoice is the option to set tool choice for the model. -func WithToolChoice(toolChoice schema.ToolChoice) Option { +func WithToolChoice(toolChoice schema.ToolChoice, allowedTools ...*schema.AllowedTool) Option { return Option{ apply: func(opts *Options) { opts.ToolChoice = &toolChoice + opts.AllowedTools = allowedTools }, } } diff --git a/components/agentic/option_test.go b/components/agentic/option_test.go index d349f35ac..2c5bac652 100644 --- a/components/agentic/option_test.go +++ b/components/agentic/option_test.go @@ -29,7 +29,7 @@ func TestCommon(t *testing.T) { WithTools([]*schema.ToolInfo{{Name: "test"}}), WithModel("test"), WithTemperature(0.1), - WithToolChoice(schema.ToolChoiceAllowed), + WithToolChoice(schema.ToolChoiceAllowed, []*schema.AllowedTool{{FunctionToolName: "test"}}...), WithTopP(0.1), ) assert.Len(t, o.Tools, 1) diff --git a/components/model/callback_extra.go b/components/model/callback_extra.go index 8591c4373..2767e2e5e 100644 --- a/components/model/callback_extra.go +++ b/components/model/callback_extra.go @@ -29,17 +29,17 @@ type TokenUsage struct { PromptTokenDetails PromptTokenDetails // CompletionTokens is the number of completion tokens. CompletionTokens int + // CompletionTokensDetails is a breakdown of the completion tokens. + CompletionTokensDetails CompletionTokensDetails // TotalTokens is the total number of tokens. TotalTokens int - // CompletionTokensDetails is breakdown of completion tokens. - CompletionTokensDetails CompletionTokensDetails `json:"completion_token_details"` } type CompletionTokensDetails struct { // ReasoningTokens tokens generated by the model for reasoning. // This is currently supported by OpenAI, Gemini, ARK and Qwen chat models. // For other models, this field will be 0. - ReasoningTokens int `json:"reasoning_tokens,omitempty"` + ReasoningTokens int } // PromptTokenDetails provides a breakdown of prompt token usage. diff --git a/components/prompt/callback_extra.go b/components/prompt/callback_extra.go index ff5c3a8ff..3be780543 100644 --- a/components/prompt/callback_extra.go +++ b/components/prompt/callback_extra.go @@ -33,6 +33,7 @@ type AgenticCallbackOutput struct { Extra map[string]any } +// ConvAgenticCallbackInput converts the callback input to the agentic callback input. func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput { switch t := src.(type) { case *AgenticCallbackInput: @@ -46,6 +47,7 @@ func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput } } +// ConvAgenticCallbackOutput converts the callback output to the agentic callback output. func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOutput { switch t := src.(type) { case *AgenticCallbackOutput: diff --git a/components/types.go b/components/types.go index 2ba088e93..2b0ad8f0e 100644 --- a/components/types.go +++ b/components/types.go @@ -66,9 +66,11 @@ type Component string const ( // ComponentOfPrompt identifies chat template components. ComponentOfPrompt Component = "ChatTemplate" + // ComponentOfAgenticPrompt identifies agentic template components. ComponentOfAgenticPrompt Component = "AgenticChatTemplate" // ComponentOfChatModel identifies chat model components. ComponentOfChatModel Component = "ChatModel" + // ComponentOfAgenticModel identifies agentic model components. ComponentOfAgenticModel Component = "AgenticModel" // ComponentOfEmbedding identifies embedding components. ComponentOfEmbedding Component = "Embedding" diff --git a/compose/tools_node_agentic.go b/compose/tools_node_agentic.go index 38c5c89de..96aef7b72 100644 --- a/compose/tools_node_agentic.go +++ b/compose/tools_node_agentic.go @@ -70,6 +70,7 @@ func agenticMessageToToolCallMessage(input *schema.AgenticMessage) *schema.Messa Name: block.FunctionToolCall.Name, Arguments: block.FunctionToolCall.Arguments, }, + Extra: block.Extra, }) } return &schema.Message{ @@ -87,8 +88,8 @@ func toolMessageToAgenticMessage(input []*schema.Message) []*schema.AgenticMessa CallID: m.ToolCallID, Name: m.ToolName, Result: m.Content, - Extra: m.Extra, }, + Extra: m.Extra, }) } return []*schema.AgenticMessage{{ @@ -110,9 +111,9 @@ func streamToolMessageToAgenticMessage(input *schema.StreamReader[[]*schema.Mess CallID: m.ToolCallID, Name: m.ToolName, Result: m.Content, - Extra: m.Extra, }, - StreamMeta: &schema.StreamMeta{Index: int64(i)}, + StreamingMeta: &schema.StreamingMeta{Index: i}, + Extra: m.Extra, }) } return []*schema.AgenticMessage{{ diff --git a/compose/tools_node_agentic_test.go b/compose/tools_node_agentic_test.go index dcd3177a9..4641dd8ae 100644 --- a/compose/tools_node_agentic_test.go +++ b/compose/tools_node_agentic_test.go @@ -20,6 +20,7 @@ import ( "io" "testing" + "github.com/bytedance/sonic" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/schema" @@ -155,13 +156,14 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { nil, }, { + nil, { Role: schema.Tool, - Content: "content1-2", + Content: "content2-2", ToolName: "name2", ToolCallID: "2", }, - nil, nil, + nil, }, { nil, nil, @@ -172,16 +174,6 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { ToolCallID: "3", }, }, - { - nil, - { - Role: schema.Tool, - Content: "content2-2", - ToolName: "name2", - ToolCallID: "2", - }, - nil, - }, { nil, nil, { @@ -204,7 +196,11 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { } result, err := schema.ConcatAgenticMessagesArray(chunks) assert.NoError(t, err) - assert.Equal(t, []*schema.AgenticMessage{ + + actualStr, err := sonic.MarshalString(result) + assert.NoError(t, err) + + expected := []*schema.AgenticMessage{ { Role: schema.AgenticRoleTypeUser, ContentBlocks: []*schema.ContentBlock{ @@ -213,10 +209,8 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { FunctionToolResult: &schema.FunctionToolResult{ CallID: "1", Name: "name1", - Result: "content1-1content1-2", - Extra: map[string]interface{}{}, + Result: "content1-1", }, - StreamMeta: &schema.StreamMeta{Index: 0}, }, { Type: schema.ContentBlockTypeFunctionToolResult, @@ -224,9 +218,7 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { CallID: "2", Name: "name2", Result: "content2-1content2-2", - Extra: map[string]interface{}{}, }, - StreamMeta: &schema.StreamMeta{Index: 1}, }, { Type: schema.ContentBlockTypeFunctionToolResult, @@ -234,11 +226,14 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { CallID: "3", Name: "name3", Result: "content3-1content3-2", - Extra: map[string]interface{}{}, }, - StreamMeta: &schema.StreamMeta{Index: 2}, }, }, }, - }, result) + } + + expectedStr, err := sonic.MarshalString(expected) + assert.NoError(t, err) + + assert.Equal(t, expectedStr, actualStr) } diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 2139201ec..b2225b2c7 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -20,15 +20,15 @@ import ( "context" "fmt" "reflect" + "sort" "strings" - "github.com/cloudwego/eino/schema/claude" - "github.com/cloudwego/eino/schema/gemini" + "github.com/eino-contrib/jsonschema" "github.com/cloudwego/eino/internal" + "github.com/cloudwego/eino/schema/claude" + "github.com/cloudwego/eino/schema/gemini" "github.com/cloudwego/eino/schema/openai" - - "github.com/eino-contrib/jsonschema" ) type ContentBlockType string @@ -82,9 +82,9 @@ type AgenticResponseMeta struct { Extension any } -type StreamMeta struct { +type StreamingMeta struct { // Index specifies the index position of this block in the final response. - Index int64 + Index int } type ContentBlock struct { @@ -123,14 +123,12 @@ type ContentBlock struct { // MCPToolApprovalResponse records the user's approval decision for an MCP tool call. MCPToolApprovalResponse *MCPToolApprovalResponse - StreamMeta *StreamMeta + StreamingMeta *StreamingMeta + Extra map[string]any } type UserInputText struct { Text string - - // Extra stores additional information. - Extra map[string]any } type UserInputImage struct { @@ -138,27 +136,18 @@ type UserInputImage struct { Base64Data string MIMEType string Detail ImageURLDetail - - // Extra stores additional information. - Extra map[string]any } type UserInputAudio struct { URL string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type UserInputVideo struct { URL string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type UserInputFile struct { @@ -166,9 +155,6 @@ type UserInputFile struct { Name string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type AssistantGenText struct { @@ -177,51 +163,37 @@ type AssistantGenText struct { OpenAIExtension *openai.AssistantGenTextExtension ClaudeExtension *claude.AssistantGenTextExtension Extension any - - // Extra stores additional information. - Extra map[string]any } type AssistantGenImage struct { URL string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type AssistantGenAudio struct { URL string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type AssistantGenVideo struct { URL string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type Reasoning struct { // Summary is the reasoning content summary. Summary []*ReasoningSummary + // EncryptedContent is the encrypted reasoning content. EncryptedContent string - - // Extra stores additional information. - Extra map[string]any } type ReasoningSummary struct { // Index specifies the index position of this summary in the final Reasoning. - Index int64 + Index int Text string } @@ -229,39 +201,37 @@ type ReasoningSummary struct { type FunctionToolCall struct { // CallID is the unique identifier for the tool call. CallID string + // Name specifies the function tool invoked. Name string + // Arguments is the JSON string arguments for the function tool call. Arguments string - - // Extra stores additional information - Extra map[string]any } type FunctionToolResult struct { // CallID is the unique identifier for the tool call. CallID string + // Name specifies the function tool invoked. Name string + // Result is the function tool result returned by the user Result string - - // Extra stores additional information. - Extra map[string]any } type ServerToolCall struct { // Name specifies the server-side tool invoked. // Supplied by the model server (e.g., `web_search` for OpenAI, `googleSearch` for Gemini). Name string + // CallID is the unique identifier for the tool call. // Empty if not provided by the model server. CallID string + // Arguments are the raw inputs to the server-side tool, // supplied by the component implementer. Arguments any - // Extra stores additional information. - Extra map[string]any } type ServerToolResult struct { @@ -276,41 +246,40 @@ type ServerToolResult struct { // Result refers to the raw output generated by the server-side tool, // supplied by the component implementer. Result any - - // Extra stores additional information. - Extra map[string]any } type MCPToolCall struct { // ServerLabel is the MCP server label used to identify it in tool calls ServerLabel string - // ApprovalRequestID is the unique ID of the approval request. + + // ApprovalRequestID is the approval request ID. ApprovalRequestID string + // CallID is the unique ID of the tool call. CallID string + // Name is the name of the tool to run. Name string + // Arguments is the JSON string arguments for the tool call. Arguments string - - // Extra stores additional information. - Extra map[string]any } type MCPToolResult struct { // ServerLabel is the MCP server label used to identify it in tool calls ServerLabel string + // CallID is the unique ID of the tool call. CallID string + // Name is the name of the tool to run. Name string + // Result is the JSON string with the tool result. Result string + // Error returned when the server fails to run the tool. Error *MCPToolCallError - - // Extra stores additional information. - Extra map[string]any } type MCPToolCallError struct { @@ -321,49 +290,49 @@ type MCPToolCallError struct { type MCPListToolsResult struct { // ServerLabel is the MCP server label used to identify it in tool calls. ServerLabel string + // Tools is the list of tools available on the server. Tools []*MCPListToolsItem + // Error returned when the server fails to list tools. Error string - - // Extra stores additional information. - Extra map[string]any } type MCPListToolsItem struct { // Name is the name of the tool. Name string + // Description is the description of the tool. Description string - // InputSchema is the JSON schema that describes the tool input. + + // InputSchema is the JSON schema that describes the tool input parameters. InputSchema *jsonschema.Schema } type MCPToolApprovalRequest struct { // ID is the approval request ID. ID string + // Name is the name of the tool to run. Name string + // Arguments is the JSON string arguments for the tool call. Arguments string + // ServerLabel is the MCP server label used to identify it in tool calls. ServerLabel string - - // Extra stores additional information. - Extra map[string]any } type MCPToolApprovalResponse struct { // ApprovalRequestID is the approval request ID being responded to. ApprovalRequestID string + // Approve indicates whether the request is approved. Approve bool + // Reason is the rationale for the decision. // Optional. Reason string - - // Extra stores additional information. - Extra map[string]any } // DeveloperAgenticMessage represents a message with AgenticRoleType "developer". @@ -404,8 +373,33 @@ func FunctionToolResultAgenticMessage(callID, name, result string) *AgenticMessa } } -func NewContentBlock(block any) *ContentBlock { - switch b := block.(type) { +type contentBlockVariant interface { + Reasoning | userInputVariant | assistantGenVariant | functionToolCallVariant | serverToolCallVariant | mcpToolCallVariant +} + +type userInputVariant interface { + UserInputText | UserInputImage | UserInputAudio | UserInputVideo | UserInputFile +} + +type assistantGenVariant interface { + AssistantGenText | AssistantGenImage | AssistantGenAudio | AssistantGenVideo +} + +type functionToolCallVariant interface { + FunctionToolCall | FunctionToolResult +} + +type serverToolCallVariant interface { + ServerToolCall | ServerToolResult +} + +type mcpToolCallVariant interface { + MCPToolCall | MCPToolResult | MCPListToolsResult | MCPToolApprovalRequest | MCPToolApprovalResponse +} + +// NewContentBlock creates a new ContentBlock with the given content. +func NewContentBlock[T contentBlockVariant](content *T) *ContentBlock { + switch b := any(content).(type) { case *Reasoning: return &ContentBlock{Type: ContentBlockTypeReasoning, Reasoning: b} case *UserInputText: @@ -449,6 +443,13 @@ func NewContentBlock(block any) *ContentBlock { } } +// NewContentBlockChunk creates a new ContentBlock with the given content and streaming metadata. +func NewContentBlockChunk[T contentBlockVariant](content *T, meta *StreamingMeta) *ContentBlock { + block := NewContentBlock(content) + block.StreamingMeta = meta + return block +} + // AgenticMessagesTemplate is the interface for messages template. // It's used to render a template to a list of agentic messages. // e.g. @@ -683,16 +684,19 @@ func formatUserInputFile(uif *UserInputFile, vs map[string]any, formatType Forma return &copied, nil } +// ConcatAgenticMessagesArray concatenates multiple streams of AgenticMessage into a single slice of AgenticMessage. func ConcatAgenticMessagesArray(mas [][]*AgenticMessage) ([]*AgenticMessage, error) { return buildConcatGenericArray[AgenticMessage](ConcatAgenticMessages)(mas) } +// ConcatAgenticMessages concatenates a list of AgenticMessage chunks into a single AgenticMessage. func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { var ( - role AgenticRoleType - blocksList [][]*ContentBlock - blocks []*ContentBlock - metas []*AgenticResponseMeta + role AgenticRoleType + blocks []*ContentBlock + metas []*AgenticResponseMeta + blockIndices []int + indexToBlocks = map[int][]*ContentBlock{} ) if len(msgs) == 1 { @@ -713,9 +717,12 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { } for _, block := range msg.ContentBlocks { - if block.StreamMeta == nil { + if block == nil { + continue + } + if block.StreamingMeta == nil { // Non-streaming block - if len(blocksList) > 0 { + if len(blockIndices) > 0 { // Cannot mix streaming and non-streaming blocks return nil, fmt.Errorf("found non-streaming block after streaming blocks") } @@ -728,8 +735,12 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { return nil, fmt.Errorf("found streaming block after non-streaming blocks") } // Collect streaming block by index - blocksList = expandSlice(int(block.StreamMeta.Index), blocksList) - blocksList[block.StreamMeta.Index] = append(blocksList[block.StreamMeta.Index], block) + if blocks_, ok := indexToBlocks[block.StreamingMeta.Index]; ok { + indexToBlocks[block.StreamingMeta.Index] = append(blocks_, block) + } else { + blockIndices = append(blockIndices, block.StreamingMeta.Index) + indexToBlocks[block.StreamingMeta.Index] = []*ContentBlock{block} + } } } @@ -743,219 +754,254 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { return nil, fmt.Errorf("failed to concat agentic response meta: %w", err) } - if len(blocksList) > 0 { + if len(blockIndices) > 0 { // All blocks are streaming, concat each group by index - blocks = make([]*ContentBlock, len(blocksList)) - for i, bs := range blocksList { - if len(bs) == 0 { - continue - } - b, err := concatAgenticContentBlocks(bs) + indexToBlock := map[int]*ContentBlock{} + for idx, bs := range indexToBlocks { + b, err := concatChunksOfSameContentBlock(bs) if err != nil { - return nil, fmt.Errorf("failed to concat content blocks at index %d: %w", i, err) + return nil, err } - blocks[i] = b + indexToBlock[idx] = b } - } - - for i := 0; i < len(blocks); i++ { - if blocks[i] == nil { - blocks = append(blocks[:i], blocks[i+1:]...) + blocks = make([]*ContentBlock, 0, len(blockIndices)) + sort.Slice(blockIndices, func(i, j int) bool { + return blockIndices[i] < blockIndices[j] + }) + for _, idx := range blockIndices { + blocks = append(blocks, indexToBlock[idx]) } } return &AgenticMessage{ - ResponseMeta: meta, Role: role, + ResponseMeta: meta, ContentBlocks: blocks, }, nil } -func concatAgenticResponseMeta(metas []*AgenticResponseMeta) (*AgenticResponseMeta, error) { +func concatAgenticResponseMeta(metas []*AgenticResponseMeta) (ret *AgenticResponseMeta, err error) { if len(metas) == 0 { return nil, nil } - ret := &AgenticResponseMeta{ - TokenUsage: &TokenUsage{}, - OpenAIExtension: nil, - ClaudeExtension: nil, - GeminiExtension: nil, - Extension: nil, - } + + openaiExtensions := make([]*openai.ResponseMetaExtension, 0, len(metas)) + claudeExtensions := make([]*claude.ResponseMetaExtension, 0, len(metas)) + geminiExtensions := make([]*gemini.ResponseMetaExtension, 0, len(metas)) + tokenUsages := make([]*TokenUsage, 0, len(metas)) + + var ( + extType reflect.Type + extensions reflect.Value + ) + for _, meta := range metas { - ret.Extension = meta.Extension - ret.OpenAIExtension = meta.OpenAIExtension - ret.ClaudeExtension = meta.ClaudeExtension - ret.GeminiExtension = meta.GeminiExtension if meta.TokenUsage != nil { - ret.TokenUsage.CompletionTokens += meta.TokenUsage.CompletionTokens - ret.TokenUsage.CompletionTokenDetails.ReasoningTokens += meta.TokenUsage.CompletionTokenDetails.ReasoningTokens - ret.TokenUsage.PromptTokens += meta.TokenUsage.PromptTokens - ret.TokenUsage.PromptTokenDetails.CachedTokens += meta.TokenUsage.PromptTokenDetails.CachedTokens - ret.TokenUsage.TotalTokens += meta.TokenUsage.TotalTokens + tokenUsages = append(tokenUsages, meta.TokenUsage) + } + + var isConsistent bool + + if meta.Extension != nil { + extType, isConsistent = validateExtensionType(extType, meta.Extension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.Extension)) + } + if !extensions.IsValid() { + extensions = reflect.MakeSlice(reflect.SliceOf(extType), 0, len(metas)) + } + extensions = reflect.Append(extensions, reflect.ValueOf(meta.Extension)) + } + + if meta.OpenAIExtension != nil { + extType, isConsistent = validateExtensionType(extType, meta.OpenAIExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.OpenAIExtension)) + } + openaiExtensions = append(openaiExtensions, meta.OpenAIExtension) + } + + if meta.ClaudeExtension != nil { + extType, isConsistent = validateExtensionType(extType, meta.ClaudeExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.ClaudeExtension)) + } + claudeExtensions = append(claudeExtensions, meta.ClaudeExtension) + } + + if meta.GeminiExtension != nil { + extType, isConsistent = validateExtensionType(extType, meta.GeminiExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.GeminiExtension)) + } + geminiExtensions = append(geminiExtensions, meta.GeminiExtension) + } + } + + ret = &AgenticResponseMeta{ + TokenUsage: concatTokenUsage(tokenUsages), + } + + if extensions.IsValid() && !extensions.IsZero() { + var extension reflect.Value + extension, err = internal.ConcatSliceValue(extensions) + if err != nil { + return nil, fmt.Errorf("failed to concat extensions: %w", err) + } + ret.Extension = extension.Interface() + } + + if len(openaiExtensions) > 0 { + ret.OpenAIExtension, err = openai.ConcatResponseMetaExtensions(openaiExtensions) + if err != nil { + return nil, fmt.Errorf("failed to concat openai extensions: %w", err) + } + } + + if len(claudeExtensions) > 0 { + ret.ClaudeExtension, err = claude.ConcatResponseMetaExtensions(claudeExtensions) + if err != nil { + return nil, fmt.Errorf("failed to concat claude extensions: %w", err) + } + } + + if len(geminiExtensions) > 0 { + ret.GeminiExtension, err = gemini.ConcatResponseMetaExtensions(geminiExtensions) + if err != nil { + return nil, fmt.Errorf("failed to concat gemini extensions: %w", err) } } + return ret, nil } -func concatAgenticContentBlocks(blocks []*ContentBlock) (*ContentBlock, error) { +func concatTokenUsage(usages []*TokenUsage) *TokenUsage { + if len(usages) == 0 { + return nil + } + + ret := &TokenUsage{} + + for _, usage := range usages { + if usage == nil { + continue + } + ret.CompletionTokens += usage.CompletionTokens + ret.CompletionTokensDetails.ReasoningTokens += usage.CompletionTokensDetails.ReasoningTokens + ret.PromptTokens += usage.PromptTokens + ret.PromptTokenDetails.CachedTokens += usage.PromptTokenDetails.CachedTokens + ret.TotalTokens += usage.TotalTokens + } + + return ret +} + +func concatChunksOfSameContentBlock(blocks []*ContentBlock) (*ContentBlock, error) { if len(blocks) == 0 { return nil, fmt.Errorf("no content blocks to concat") } + blockType := blocks[0].Type - index := blocks[0].StreamMeta.Index + switch blockType { case ContentBlockTypeReasoning: - return concatContentBlockHelper(blocks, blockType, "reasoning", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *Reasoning { return b.Reasoning }, - concatReasoning, - func(r *Reasoning) *ContentBlock { - return &ContentBlock{Type: blockType, Reasoning: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatReasoning) case ContentBlockTypeUserInputText: - return concatContentBlockHelper(blocks, blockType, "user input text", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *UserInputText { return b.UserInputText }, - concatUserInputText, - func(t *UserInputText) *ContentBlock { - return &ContentBlock{Type: blockType, UserInputText: t, StreamMeta: &StreamMeta{Index: index}} - }) + concatUserInputTexts) case ContentBlockTypeUserInputImage: - return concatContentBlockHelper(blocks, blockType, "user input image", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *UserInputImage { return b.UserInputImage }, - concatUserInputImage, - func(i *UserInputImage) *ContentBlock { - return &ContentBlock{Type: blockType, UserInputImage: i, StreamMeta: &StreamMeta{Index: index}} - }) + concatUserInputImages) case ContentBlockTypeUserInputAudio: - return concatContentBlockHelper(blocks, blockType, "user input audio", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *UserInputAudio { return b.UserInputAudio }, - concatUserInputAudio, - func(a *UserInputAudio) *ContentBlock { - return &ContentBlock{Type: blockType, UserInputAudio: a, StreamMeta: &StreamMeta{Index: index}} - }) + concatUserInputAudios) case ContentBlockTypeUserInputVideo: - return concatContentBlockHelper(blocks, blockType, "user input video", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *UserInputVideo { return b.UserInputVideo }, - concatUserInputVideo, - func(v *UserInputVideo) *ContentBlock { - return &ContentBlock{Type: blockType, UserInputVideo: v, StreamMeta: &StreamMeta{Index: index}} - }) + concatUserInputVideos) case ContentBlockTypeUserInputFile: - return concatContentBlockHelper(blocks, blockType, "user input file", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *UserInputFile { return b.UserInputFile }, - concatUserInputFile, - func(f *UserInputFile) *ContentBlock { - return &ContentBlock{Type: blockType, UserInputFile: f, StreamMeta: &StreamMeta{Index: index}} - }) + concatUserInputFiles) case ContentBlockTypeAssistantGenText: - return concatContentBlockHelper(blocks, blockType, "assistant gen text", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *AssistantGenText { return b.AssistantGenText }, - concatAssistantGenText, - func(t *AssistantGenText) *ContentBlock { - return &ContentBlock{Type: blockType, AssistantGenText: t, StreamMeta: &StreamMeta{Index: index}} - }) + concatAssistantGenTexts) case ContentBlockTypeAssistantGenImage: - return concatContentBlockHelper(blocks, blockType, "assistant gen image", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *AssistantGenImage { return b.AssistantGenImage }, - concatAssistantGenImage, - func(i *AssistantGenImage) *ContentBlock { - return &ContentBlock{Type: blockType, AssistantGenImage: i, StreamMeta: &StreamMeta{Index: index}} - }) + concatAssistantGenImages) case ContentBlockTypeAssistantGenAudio: - return concatContentBlockHelper(blocks, blockType, "assistant gen audio", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *AssistantGenAudio { return b.AssistantGenAudio }, - concatAssistantGenAudio, - func(a *AssistantGenAudio) *ContentBlock { - return &ContentBlock{Type: blockType, AssistantGenAudio: a, StreamMeta: &StreamMeta{Index: index}} - }) + concatAssistantGenAudios) case ContentBlockTypeAssistantGenVideo: - return concatContentBlockHelper(blocks, blockType, "assistant gen video", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *AssistantGenVideo { return b.AssistantGenVideo }, - concatAssistantGenVideo, - func(v *AssistantGenVideo) *ContentBlock { - return &ContentBlock{Type: blockType, AssistantGenVideo: v, StreamMeta: &StreamMeta{Index: index}} - }) + concatAssistantGenVideos) case ContentBlockTypeFunctionToolCall: - return concatContentBlockHelper(blocks, blockType, "function tool call", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *FunctionToolCall { return b.FunctionToolCall }, - concatFunctionToolCall, - func(c *FunctionToolCall) *ContentBlock { - return &ContentBlock{Type: blockType, FunctionToolCall: c, StreamMeta: &StreamMeta{Index: index}} - }) + concatFunctionToolCalls) case ContentBlockTypeFunctionToolResult: - return concatContentBlockHelper(blocks, blockType, "function tool result", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *FunctionToolResult { return b.FunctionToolResult }, - concatFunctionToolResult, - func(r *FunctionToolResult) *ContentBlock { - return &ContentBlock{Type: blockType, FunctionToolResult: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatFunctionToolResults) case ContentBlockTypeServerToolCall: - return concatContentBlockHelper(blocks, blockType, "server tool call", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *ServerToolCall { return b.ServerToolCall }, - concatServerToolCall, - func(c *ServerToolCall) *ContentBlock { - return &ContentBlock{Type: blockType, ServerToolCall: c, StreamMeta: &StreamMeta{Index: index}} - }) + concatServerToolCalls) case ContentBlockTypeServerToolResult: - return concatContentBlockHelper(blocks, blockType, "server tool result", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *ServerToolResult { return b.ServerToolResult }, - concatServerToolResult, - func(r *ServerToolResult) *ContentBlock { - return &ContentBlock{Type: blockType, ServerToolResult: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatServerToolResults) case ContentBlockTypeMCPToolCall: - return concatContentBlockHelper(blocks, blockType, "MCP tool call", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *MCPToolCall { return b.MCPToolCall }, - concatMCPToolCall, - func(c *MCPToolCall) *ContentBlock { - return &ContentBlock{Type: blockType, MCPToolCall: c, StreamMeta: &StreamMeta{Index: index}} - }) + concatMCPToolCalls) case ContentBlockTypeMCPToolResult: - return concatContentBlockHelper(blocks, blockType, "MCP tool result", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *MCPToolResult { return b.MCPToolResult }, - concatMCPToolResult, - func(r *MCPToolResult) *ContentBlock { - return &ContentBlock{Type: blockType, MCPToolResult: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatMCPToolResults) case ContentBlockTypeMCPListToolsResult: - return concatContentBlockHelper(blocks, blockType, "MCP list tools", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *MCPListToolsResult { return b.MCPListToolsResult }, - concatMCPListToolsResult, - func(r *MCPListToolsResult) *ContentBlock { - return &ContentBlock{Type: blockType, MCPListToolsResult: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatMCPListToolsResults) case ContentBlockTypeMCPToolApprovalRequest: - return concatContentBlockHelper(blocks, blockType, "MCP tool approval request", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *MCPToolApprovalRequest { return b.MCPToolApprovalRequest }, - concatMCPToolApprovalRequest, - func(r *MCPToolApprovalRequest) *ContentBlock { - return &ContentBlock{Type: blockType, MCPToolApprovalRequest: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatMCPToolApprovalRequests) case ContentBlockTypeMCPToolApprovalResponse: - return concatContentBlockHelper(blocks, blockType, "MCP tool approval response", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *MCPToolApprovalResponse { return b.MCPToolApprovalResponse }, - concatMCPToolApprovalResponse, - func(r *MCPToolApprovalResponse) *ContentBlock { - return &ContentBlock{Type: blockType, MCPToolApprovalResponse: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatMCPToolApprovalResponses) default: return nil, fmt.Errorf("unknown content block type: %s", blockType) @@ -964,21 +1010,19 @@ func concatAgenticContentBlocks(blocks []*ContentBlock) (*ContentBlock, error) { // concatContentBlockHelper is a generic helper function that reduces code duplication // for concatenating content blocks of a specific type. -func concatContentBlockHelper[T any]( +func concatContentBlockHelper[T contentBlockVariant]( blocks []*ContentBlock, expectedType ContentBlockType, - typeName string, getter func(*ContentBlock) *T, concatFunc func([]*T) (*T, error), - constructor func(*T) *ContentBlock, ) (*ContentBlock, error) { items, err := genericGetTFromContentBlocks(blocks, func(block *ContentBlock) (*T, error) { if block.Type != expectedType { - return nil, fmt.Errorf("expected %s block, got %s", typeName, block.Type) + return nil, fmt.Errorf("content block type mismatch: expected '%s', but got '%s'", expectedType, block.Type) } item := getter(block) if item == nil { - return nil, fmt.Errorf("%s content is nil", typeName) + return nil, fmt.Errorf("'%s' content is nil", expectedType) } return item, nil }) @@ -988,10 +1032,28 @@ func concatContentBlockHelper[T any]( concatenated, err := concatFunc(items) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to concat '%s' content blocks: %w", expectedType, err) + } + + extras := make([]map[string]any, 0, len(blocks)) + for _, block := range blocks { + if len(block.Extra) > 0 { + extras = append(extras, block.Extra) + } + } + + var extra map[string]any + if len(extras) > 0 { + extra, err = internal.ConcatItems(extras) + if err != nil { + return nil, fmt.Errorf("failed to concat content block extras: %w", err) + } } - return constructor(concatenated), nil + block := NewContentBlock(concatenated) + block.Extra = extra + + return block, nil } func genericGetTFromContentBlocks[T any](blocks []*ContentBlock, checkAndGetter func(block *ContentBlock) (T, error)) ([]T, error) { @@ -1006,43 +1068,14 @@ func genericGetTFromContentBlocks[T any](blocks []*ContentBlock, checkAndGetter return ret, nil } -// Concatenation strategies for different content block types: -// -// String concatenation (incremental streaming): -// - Reasoning: Summary texts are concatenated, grouped by Index if present -// - UserInputText: Text fields are concatenated -// - AssistantGenText: Text fields are concatenated, annotations/citations are merged -// - FunctionToolCall: Arguments (JSON strings) are concatenated incrementally -// - FunctionToolResult: Result strings are concatenated -// - ServerToolCall: Arguments are merged (last non-nil value for any type) -// - ServerToolResult: Results are merged using internal.ConcatItems -// - MCPToolCall: Arguments (JSON strings) are concatenated incrementally -// - MCPToolResult: Result strings are concatenated -// - MCPListToolsResult: Tools arrays are merged -// - MCPToolApprovalRequest: Arguments are concatenated -// -// Take last block (non-streaming content): -// - UserInputImage, UserInputAudio, UserInputVideo, UserInputFile: Return last block -// - AssistantGenImage, AssistantGenAudio, AssistantGenVideo: Return last block -// - MCPToolApprovalResponse: Return last block -// - func concatReasoning(reasons []*Reasoning) (*Reasoning, error) { if len(reasons) == 0 { return nil, fmt.Errorf("no reasoning found") } - if len(reasons) == 1 { - return reasons[0], nil - } - ret := &Reasoning{ - Summary: make([]*ReasoningSummary, 0), - EncryptedContent: "", - Extra: make(map[string]any), - } + ret := &Reasoning{} - // Collect all summaries from all reasons - allSummaries := make([]*ReasoningSummary, 0) + var allSummaries []*ReasoningSummary for _, r := range reasons { if r == nil { continue @@ -1051,157 +1084,269 @@ func concatReasoning(reasons []*Reasoning) (*Reasoning, error) { if r.EncryptedContent != "" { ret.EncryptedContent += r.EncryptedContent } - for k, v := range r.Extra { - ret.Extra[k] = v - } } - // Group by Index and concatenate Text for same Index - // Use dynamic array that expands as needed - var summaryArray []*ReasoningSummary + var ( + indices []int + indexToSummary = map[int]*ReasoningSummary{} + ) + for _, s := range allSummaries { - idx := s.Index - // Expand array if needed - summaryArray = expandSlice(int(idx), summaryArray) - if summaryArray[idx] == nil { - // Create new entry with a copy of Index - summaryArray[idx] = &ReasoningSummary{ - Index: idx, - Text: s.Text, - } - } else { - // Concatenate text for same index - summaryArray[idx].Text += s.Text + if s == nil { + continue + } + if indexToSummary[s.Index] == nil { + indexToSummary[s.Index] = &ReasoningSummary{} + indices = append(indices, s.Index) } + indexToSummary[s.Index].Text += s.Text } - // Convert array to slice, filtering out nil entries - ret.Summary = make([]*ReasoningSummary, 0, len(summaryArray)) - for _, summary := range summaryArray { - if summary != nil { - ret.Summary = append(ret.Summary, summary) - } + sort.Slice(indices, func(i, j int) bool { + return indices[i] < indices[j] + }) + + ret.Summary = make([]*ReasoningSummary, 0, len(indices)) + for _, idx := range indices { + ret.Summary = append(ret.Summary, indexToSummary[idx]) } return ret, nil } -func concatUserInputText(texts []*UserInputText) (*UserInputText, error) { +func concatUserInputTexts(texts []*UserInputText) (*UserInputText, error) { if len(texts) == 0 { return nil, fmt.Errorf("no user input text found") } if len(texts) == 1 { return texts[0], nil } - - ret := &UserInputText{ - Text: "", - Extra: make(map[string]any), - } - - for _, t := range texts { - if t == nil { - continue - } - ret.Text += t.Text - for k, v := range t.Extra { - ret.Extra[k] = v - } - } - - return ret, nil + return nil, fmt.Errorf("cannot concat multiple user input texts") } -func concatUserInputImage(images []*UserInputImage) (*UserInputImage, error) { +func concatUserInputImages(images []*UserInputImage) (*UserInputImage, error) { if len(images) == 0 { return nil, fmt.Errorf("no user input image found") } - return images[len(images)-1], nil + if len(images) == 1 { + return images[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input images") } -func concatUserInputAudio(audios []*UserInputAudio) (*UserInputAudio, error) { +func concatUserInputAudios(audios []*UserInputAudio) (*UserInputAudio, error) { if len(audios) == 0 { return nil, fmt.Errorf("no user input audio found") } - return audios[len(audios)-1], nil + if len(audios) == 1 { + return audios[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input audios") } -func concatUserInputVideo(videos []*UserInputVideo) (*UserInputVideo, error) { +func concatUserInputVideos(videos []*UserInputVideo) (*UserInputVideo, error) { if len(videos) == 0 { return nil, fmt.Errorf("no user input video found") } - return videos[len(videos)-1], nil + if len(videos) == 1 { + return videos[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input videos") } -func concatUserInputFile(files []*UserInputFile) (*UserInputFile, error) { +func concatUserInputFiles(files []*UserInputFile) (*UserInputFile, error) { if len(files) == 0 { return nil, fmt.Errorf("no user input file found") } - return files[len(files)-1], nil + if len(files) == 1 { + return files[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input files") } -func concatAssistantGenText(texts []*AssistantGenText) (*AssistantGenText, error) { +func concatAssistantGenTexts(texts []*AssistantGenText) (ret *AssistantGenText, err error) { if len(texts) == 0 { - return nil, fmt.Errorf("no assistant gen text found") + return nil, fmt.Errorf("no assistant generated text found") } if len(texts) == 1 { return texts[0], nil } - ret := &AssistantGenText{ - Text: "", - OpenAIExtension: nil, - ClaudeExtension: nil, - Extra: make(map[string]any), - } + ret = &AssistantGenText{} + + openaiExtensions := make([]*openai.AssistantGenTextExtension, 0, len(texts)) + claudeExtensions := make([]*claude.AssistantGenTextExtension, 0, len(texts)) + + var ( + extType reflect.Type + extensions reflect.Value + ) for _, t := range texts { if t == nil { continue } + ret.Text += t.Text + + var isConsistent bool + + if t.Extension != nil { + extType, isConsistent = validateExtensionType(extType, t.Extension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in assistant generated text chunks: '%s' vs '%s'", + extType, reflect.TypeOf(t.Extension)) + } + if !extensions.IsValid() { + extensions = reflect.MakeSlice(reflect.SliceOf(extType), 0, len(texts)) + } + extensions = reflect.Append(extensions, reflect.ValueOf(t.Extension)) + } + if t.OpenAIExtension != nil { - if ret.OpenAIExtension == nil { - ret.OpenAIExtension = &openai.AssistantGenTextExtension{} + extType, isConsistent = validateExtensionType(extType, t.OpenAIExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in assistant generated text chunks: '%s' vs '%s'", + extType, reflect.TypeOf(t.OpenAIExtension)) } - ret.OpenAIExtension.Annotations = append(ret.OpenAIExtension.Annotations, t.OpenAIExtension.Annotations...) + openaiExtensions = append(openaiExtensions, t.OpenAIExtension) } + if t.ClaudeExtension != nil { - if ret.ClaudeExtension == nil { - ret.ClaudeExtension = &claude.AssistantGenTextExtension{} + extType, isConsistent = validateExtensionType(extType, t.ClaudeExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in assistant generated text chunks: '%s' vs '%s'", + extType, reflect.TypeOf(t.ClaudeExtension)) } - ret.ClaudeExtension.Citations = append(ret.ClaudeExtension.Citations, t.ClaudeExtension.Citations...) + claudeExtensions = append(claudeExtensions, t.ClaudeExtension) } - for k, v := range t.Extra { - ret.Extra[k] = v + } + + if extensions.IsValid() && !extensions.IsZero() { + ret.Extension, err = internal.ConcatSliceValue(extensions) + if err != nil { + return nil, err + } + ret.Extension = extensions.Interface() + } + + if len(openaiExtensions) > 0 { + ret.OpenAIExtension, err = openai.ConcatAssistantGenTextExtensions(openaiExtensions) + if err != nil { + return nil, err + } + } + + if len(claudeExtensions) > 0 { + ret.ClaudeExtension, err = claude.ConcatAssistantGenTextExtensions(claudeExtensions) + if err != nil { + return nil, err } } return ret, nil } -func concatAssistantGenImage(images []*AssistantGenImage) (*AssistantGenImage, error) { +func concatAssistantGenImages(images []*AssistantGenImage) (*AssistantGenImage, error) { if len(images) == 0 { return nil, fmt.Errorf("no assistant gen image found") } - return images[len(images)-1], nil + if len(images) == 1 { + return images[0], nil + } + + ret := &AssistantGenImage{} + + for _, img := range images { + if img == nil { + continue + } + + ret.Base64Data += img.Base64Data + + if ret.URL == "" { + ret.URL = img.URL + } else if img.URL != "" && ret.URL != img.URL { + return nil, fmt.Errorf("inconsistent URLs in assistant generated image chunks: '%s' vs '%s'", ret.URL, img.URL) + } + + if ret.MIMEType == "" { + ret.MIMEType = img.MIMEType + } else if img.MIMEType != "" && ret.MIMEType != img.MIMEType { + return nil, fmt.Errorf("inconsistent MIME types in assistant generated image chunks: '%s' vs '%s'", ret.MIMEType, img.MIMEType) + } + } + + return ret, nil } -func concatAssistantGenAudio(audios []*AssistantGenAudio) (*AssistantGenAudio, error) { +func concatAssistantGenAudios(audios []*AssistantGenAudio) (*AssistantGenAudio, error) { if len(audios) == 0 { return nil, fmt.Errorf("no assistant gen audio found") } - return audios[len(audios)-1], nil + if len(audios) == 1 { + return audios[0], nil + } + + ret := &AssistantGenAudio{} + + for _, audio := range audios { + if audio == nil { + continue + } + + ret.Base64Data += audio.Base64Data + + if ret.URL == "" { + ret.URL = audio.URL + } else if audio.URL != "" && ret.URL != audio.URL { + return nil, fmt.Errorf("inconsistent URLs in assistant generated audio chunks: '%s' vs '%s'", ret.URL, audio.URL) + } + + if ret.MIMEType == "" { + ret.MIMEType = audio.MIMEType + } else if audio.MIMEType != "" && ret.MIMEType != audio.MIMEType { + return nil, fmt.Errorf("inconsistent MIME types in assistant generated audio chunks: '%s' vs '%s'", ret.MIMEType, audio.MIMEType) + } + } + + return ret, nil } -func concatAssistantGenVideo(videos []*AssistantGenVideo) (*AssistantGenVideo, error) { +func concatAssistantGenVideos(videos []*AssistantGenVideo) (*AssistantGenVideo, error) { if len(videos) == 0 { return nil, fmt.Errorf("no assistant gen video found") } - return videos[len(videos)-1], nil + if len(videos) == 1 { + return videos[0], nil + } + + ret := &AssistantGenVideo{} + + for _, video := range videos { + if video == nil { + continue + } + + ret.Base64Data += video.Base64Data + + if ret.URL == "" { + ret.URL = video.URL + } else if video.URL != "" && ret.URL != video.URL { + return nil, fmt.Errorf("inconsistent URLs in assistant generated video chunks: '%s' vs '%s'", ret.URL, video.URL) + } + + if ret.MIMEType == "" { + ret.MIMEType = video.MIMEType + } else if video.MIMEType != "" && ret.MIMEType != video.MIMEType { + return nil, fmt.Errorf("inconsistent MIME types in assistant generated video chunks: '%s' vs '%s'", ret.MIMEType, video.MIMEType) + } + } + + return ret, nil } -func concatFunctionToolCall(calls []*FunctionToolCall) (*FunctionToolCall, error) { +func concatFunctionToolCalls(calls []*FunctionToolCall) (*FunctionToolCall, error) { if len(calls) == 0 { return nil, fmt.Errorf("no function tool call found") } @@ -1209,31 +1354,32 @@ func concatFunctionToolCall(calls []*FunctionToolCall) (*FunctionToolCall, error return calls[0], nil } - // For tool calls, arguments are typically built incrementally during streaming - ret := &FunctionToolCall{ - Extra: make(map[string]any), - } + ret := &FunctionToolCall{} for _, c := range calls { if c == nil { continue } + if ret.CallID == "" { ret.CallID = c.CallID + } else if c.CallID != "" && c.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for function tool call, but got '%s'", ret.CallID, c.CallID) } + if ret.Name == "" { ret.Name = c.Name + } else if c.Name != "" && c.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for function tool call, but got '%s'", ret.Name, c.Name) } + ret.Arguments += c.Arguments - for k, v := range c.Extra { - ret.Extra[k] = v - } } return ret, nil } -func concatFunctionToolResult(results []*FunctionToolResult) (*FunctionToolResult, error) { +func concatFunctionToolResults(results []*FunctionToolResult) (*FunctionToolResult, error) { if len(results) == 0 { return nil, fmt.Errorf("no function tool result found") } @@ -1241,30 +1387,32 @@ func concatFunctionToolResult(results []*FunctionToolResult) (*FunctionToolResul return results[0], nil } - ret := &FunctionToolResult{ - Extra: make(map[string]any), - } + ret := &FunctionToolResult{} for _, r := range results { if r == nil { continue } + if ret.CallID == "" { ret.CallID = r.CallID + } else if r.CallID != "" && r.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for function tool result, but got '%s'", ret.CallID, r.CallID) } + if ret.Name == "" { ret.Name = r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for function tool result, but got '%s'", ret.Name, r.Name) } + ret.Result += r.Result - for k, v := range r.Extra { - ret.Extra[k] = v - } } return ret, nil } -func concatServerToolCall(calls []*ServerToolCall) (*ServerToolCall, error) { +func concatServerToolCalls(calls []*ServerToolCall) (ret *ServerToolCall, err error) { if len(calls) == 0 { return nil, fmt.Errorf("no server tool call found") } @@ -1272,33 +1420,54 @@ func concatServerToolCall(calls []*ServerToolCall) (*ServerToolCall, error) { return calls[0], nil } - // ServerToolCall Arguments is of type any; merge strategy uses the last non-nil value - ret := &ServerToolCall{ - Extra: make(map[string]any), - } + ret = &ServerToolCall{} + + var ( + argsType reflect.Type + argsChunks reflect.Value + ) for _, c := range calls { if c == nil { continue } - if ret.Name == "" { - ret.Name = c.Name - } + if ret.CallID == "" { ret.CallID = c.CallID + } else if c.CallID != "" && c.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for server tool call, but got '%s'", ret.CallID, c.CallID) + } + + if ret.Name == "" { + ret.Name = c.Name + } else if c.Name != "" && c.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for server tool call, but got '%s'", ret.Name, c.Name) } + if c.Arguments != nil { - ret.Arguments = c.Arguments + argsType_ := reflect.TypeOf(c.Arguments) + if argsType == nil { + argsType = argsType_ + argsChunks = reflect.MakeSlice(reflect.SliceOf(argsType), 0, len(calls)) + } else if argsType != argsType_ { + return nil, fmt.Errorf("expected type '%s' for server tool call arguments, but got '%s'", argsType, argsType_) + } + argsChunks = reflect.Append(argsChunks, reflect.ValueOf(c.Arguments)) } - for k, v := range c.Extra { - ret.Extra[k] = v + } + + if argsChunks.IsValid() && !argsChunks.IsZero() { + arguments, err := internal.ConcatSliceValue(argsChunks) + if err != nil { + return nil, err } + ret.Arguments = arguments.Interface() } return ret, nil } -func concatServerToolResult(results []*ServerToolResult) (*ServerToolResult, error) { +func concatServerToolResults(results []*ServerToolResult) (ret *ServerToolResult, err error) { if len(results) == 0 { return nil, fmt.Errorf("no server tool result found") } @@ -1306,45 +1475,54 @@ func concatServerToolResult(results []*ServerToolResult) (*ServerToolResult, err return results[0], nil } - // ServerToolResult Result is of type any; merge strategy uses the last non-nil value - ret := &ServerToolResult{ - Extra: make(map[string]any), - } + ret = &ServerToolResult{} + + var ( + resType reflect.Type + resChunks reflect.Value + ) - tZeroResult := reflect.TypeOf(results[0].Result) - data := reflect.MakeSlice(reflect.SliceOf(tZeroResult), 0, 0) for _, r := range results { if r == nil { continue } - if ret.Name == "" { - ret.Name = r.Name - } + if ret.CallID == "" { ret.CallID = r.CallID + } else if r.CallID != "" && r.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for server tool result, but got '%s'", ret.CallID, r.CallID) } + + if ret.Name == "" { + ret.Name = r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for server tool result, but got '%s'", ret.Name, r.Name) + } + if r.Result != nil { - vResult := reflect.ValueOf(r.Result) - if tZeroResult != vResult.Type() { - return nil, fmt.Errorf("tool result types are different: %v %v", tZeroResult, vResult.Type()) + resType_ := reflect.TypeOf(r.Result) + if resType == nil { + resType = resType_ + resChunks = reflect.MakeSlice(reflect.SliceOf(resType), 0, len(results)) + } else if resType != resType_ { + return nil, fmt.Errorf("expected type '%s' for server tool result, but got '%s'", resType, resType_) } - data = reflect.Append(data, vResult) - } - for k, v := range r.Extra { - ret.Extra[k] = v + resChunks = reflect.Append(resChunks, reflect.ValueOf(r.Result)) } } - d, err := internal.ConcatSliceValue(data) - if err != nil { - return nil, fmt.Errorf("failed to concat server tool result: %v", err) + if resChunks.IsValid() && !resChunks.IsZero() { + result, err := internal.ConcatSliceValue(resChunks) + if err != nil { + return nil, fmt.Errorf("failed to concat server tool result: %v", err) + } + ret.Result = result.Interface() } - ret.Result = d return ret, nil } -func concatMCPToolCall(calls []*MCPToolCall) (*MCPToolCall, error) { +func concatMCPToolCalls(calls []*MCPToolCall) (*MCPToolCall, error) { if len(calls) == 0 { return nil, fmt.Errorf("no mcp tool call found") } @@ -1352,36 +1530,38 @@ func concatMCPToolCall(calls []*MCPToolCall) (*MCPToolCall, error) { return calls[0], nil } - ret := &MCPToolCall{ - Extra: make(map[string]any), - } + ret := &MCPToolCall{} for _, c := range calls { if c == nil { continue } + + ret.Arguments += c.Arguments + if ret.ServerLabel == "" { ret.ServerLabel = c.ServerLabel + } else if c.ServerLabel != "" && c.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp tool call, but got '%s'", ret.ServerLabel, c.ServerLabel) } - if ret.ApprovalRequestID == "" { - ret.ApprovalRequestID = c.ApprovalRequestID - } + if ret.CallID == "" { ret.CallID = c.CallID + } else if c.CallID != "" && c.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for mcp tool call, but got '%s'", ret.CallID, c.CallID) } + if ret.Name == "" { ret.Name = c.Name - } - ret.Arguments += c.Arguments - for k, v := range c.Extra { - ret.Extra[k] = v + } else if c.Name != "" && c.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for mcp tool call, but got '%s'", ret.Name, c.Name) } } return ret, nil } -func concatMCPToolResult(results []*MCPToolResult) (*MCPToolResult, error) { +func concatMCPToolResults(results []*MCPToolResult) (*MCPToolResult, error) { if len(results) == 0 { return nil, fmt.Errorf("no mcp tool result found") } @@ -1389,33 +1569,44 @@ func concatMCPToolResult(results []*MCPToolResult) (*MCPToolResult, error) { return results[0], nil } - ret := &MCPToolResult{ - Extra: make(map[string]any), - } + ret := &MCPToolResult{} for _, r := range results { if r == nil { continue } + + if r.Result != "" { + ret.Result = r.Result + } + + if ret.ServerLabel == "" { + ret.ServerLabel = r.ServerLabel + } else if r.ServerLabel != "" && r.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp tool result, but got '%s'", ret.ServerLabel, r.ServerLabel) + } + if ret.CallID == "" { ret.CallID = r.CallID + } else if r.CallID != "" && r.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for mcp tool result, but got '%s'", ret.CallID, r.CallID) } + if ret.Name == "" { ret.Name = r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for mcp tool result, but got '%s'", ret.Name, r.Name) } - ret.Result += r.Result + if r.Error != nil { - ret.Error = r.Error // Use the last error - } - for k, v := range r.Extra { - ret.Extra[k] = v + ret.Error = r.Error } } return ret, nil } -func concatMCPListToolsResult(results []*MCPListToolsResult) (*MCPListToolsResult, error) { +func concatMCPListToolsResults(results []*MCPListToolsResult) (*MCPListToolsResult, error) { if len(results) == 0 { return nil, fmt.Errorf("no mcp list tools result found") } @@ -1423,31 +1614,30 @@ func concatMCPListToolsResult(results []*MCPListToolsResult) (*MCPListToolsResul return results[0], nil } - ret := &MCPListToolsResult{ - Tools: make([]*MCPListToolsItem, 0), - Extra: make(map[string]any), - } + ret := &MCPListToolsResult{} for _, r := range results { if r == nil { continue } - if ret.ServerLabel == "" { - ret.ServerLabel = r.ServerLabel - } + ret.Tools = append(ret.Tools, r.Tools...) + if r.Error != "" { - ret.Error = r.Error // Use the last error + ret.Error = r.Error } - for k, v := range r.Extra { - ret.Extra[k] = v + + if ret.ServerLabel == "" { + ret.ServerLabel = r.ServerLabel + } else if r.ServerLabel != "" && r.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp list tools result, but got '%s'", ret.ServerLabel, r.ServerLabel) } } return ret, nil } -func concatMCPToolApprovalRequest(requests []*MCPToolApprovalRequest) (*MCPToolApprovalRequest, error) { +func concatMCPToolApprovalRequests(requests []*MCPToolApprovalRequest) (*MCPToolApprovalRequest, error) { if len(requests) == 0 { return nil, fmt.Errorf("no mcp tool approval request found") } @@ -1455,50 +1645,48 @@ func concatMCPToolApprovalRequest(requests []*MCPToolApprovalRequest) (*MCPToolA return requests[0], nil } - ret := &MCPToolApprovalRequest{ - Extra: make(map[string]any), - } + ret := &MCPToolApprovalRequest{} for _, r := range requests { if r == nil { continue } + + ret.Arguments += r.Arguments + if ret.ID == "" { ret.ID = r.ID + } else if r.ID != "" && r.ID != ret.ID { + return nil, fmt.Errorf("expected request ID '%s' for mcp tool approval request, but got '%s'", ret.ID, r.ID) } + if ret.Name == "" { ret.Name = r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for mcp tool approval request, but got '%s'", ret.Name, r.Name) } - ret.Arguments += r.Arguments + if ret.ServerLabel == "" { ret.ServerLabel = r.ServerLabel - } - for k, v := range r.Extra { - ret.Extra[k] = v + } else if r.ServerLabel != "" && r.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp tool approval request, but got '%s'", ret.ServerLabel, r.ServerLabel) } } return ret, nil } -func concatMCPToolApprovalResponse(responses []*MCPToolApprovalResponse) (*MCPToolApprovalResponse, error) { +func concatMCPToolApprovalResponses(responses []*MCPToolApprovalResponse) (*MCPToolApprovalResponse, error) { if len(responses) == 0 { return nil, fmt.Errorf("no mcp tool approval response found") } if len(responses) == 1 { return responses[0], nil } - - return responses[len(responses)-1], nil -} - -func expandSlice[T any](idx int, s []T) []T { - if len(s) > idx { - return s - } - return append(s, make([]T, idx-len(s)+1)...) + return nil, fmt.Errorf("cannot concat multiple mcp tool approval responses") } +// String returns the string representation of AgenticMessage. func (m *AgenticMessage) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf("role: %s\n", m.Role)) @@ -1520,6 +1708,7 @@ func (m *AgenticMessage) String() string { return sb.String() } +// String returns the string representation of ContentBlock. func (b *ContentBlock) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf("type: %s\n", b.Type)) @@ -1603,13 +1792,14 @@ func (b *ContentBlock) String() string { } } - if b.StreamMeta != nil { - sb.WriteString(fmt.Sprintf(" stream_index: %d\n", b.StreamMeta.Index)) + if b.StreamingMeta != nil { + sb.WriteString(fmt.Sprintf(" stream_index: %d\n", b.StreamingMeta.Index)) } return sb.String() } +// String returns the string representation of Reasoning. func (r *Reasoning) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" summary: %d items\n", len(r.Summary))) @@ -1622,22 +1812,27 @@ func (r *Reasoning) String() string { return sb.String() } +// String returns the string representation of UserInputText. func (u *UserInputText) String() string { return fmt.Sprintf(" text: %s\n", u.Text) } +// String returns the string representation of UserInputImage. func (u *UserInputImage) String() string { return formatMediaString(u.URL, u.Base64Data, u.MIMEType, u.Detail) } +// String returns the string representation of UserInputAudio. func (u *UserInputAudio) String() string { return formatMediaString(u.URL, u.Base64Data, u.MIMEType, "") } +// String returns the string representation of UserInputVideo. func (u *UserInputVideo) String() string { return formatMediaString(u.URL, u.Base64Data, u.MIMEType, "") } +// String returns the string representation of UserInputFile. func (u *UserInputFile) String() string { sb := &strings.Builder{} if u.Name != "" { @@ -1647,22 +1842,27 @@ func (u *UserInputFile) String() string { return sb.String() } +// String returns the string representation of AssistantGenText. func (a *AssistantGenText) String() string { return fmt.Sprintf(" text: %s\n", a.Text) } +// String returns the string representation of AssistantGenImage. func (a *AssistantGenImage) String() string { return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") } +// String returns the string representation of AssistantGenAudio. func (a *AssistantGenAudio) String() string { return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") } +// String returns the string representation of AssistantGenVideo. func (a *AssistantGenVideo) String() string { return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") } +// String returns the string representation of FunctionToolCall. func (f *FunctionToolCall) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" call_id: %s\n", f.CallID)) @@ -1671,6 +1871,7 @@ func (f *FunctionToolCall) String() string { return sb.String() } +// String returns the string representation of FunctionToolResult. func (f *FunctionToolResult) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" call_id: %s\n", f.CallID)) @@ -1679,6 +1880,7 @@ func (f *FunctionToolResult) String() string { return sb.String() } +// String returns the string representation of ServerToolCall. func (s *ServerToolCall) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" name: %s\n", s.Name)) @@ -1689,6 +1891,7 @@ func (s *ServerToolCall) String() string { return sb.String() } +// String returns the string representation of ServerToolResult. func (s *ServerToolResult) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" name: %s\n", s.Name)) @@ -1699,18 +1902,17 @@ func (s *ServerToolResult) String() string { return sb.String() } +// String returns the string representation of MCPToolCall. func (m *MCPToolCall) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) sb.WriteString(fmt.Sprintf(" call_id: %s\n", m.CallID)) sb.WriteString(fmt.Sprintf(" name: %s\n", m.Name)) sb.WriteString(fmt.Sprintf(" arguments: %s\n", m.Arguments)) - if m.ApprovalRequestID != "" { - sb.WriteString(fmt.Sprintf(" approval_request_id: %s\n", m.ApprovalRequestID)) - } return sb.String() } +// String returns the string representation of MCPToolResult. func (m *MCPToolResult) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" call_id: %s\n", m.CallID)) @@ -1722,6 +1924,7 @@ func (m *MCPToolResult) String() string { return sb.String() } +// String returns the string representation of MCPListToolsResult. func (m *MCPListToolsResult) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) @@ -1735,6 +1938,7 @@ func (m *MCPListToolsResult) String() string { return sb.String() } +// String returns the string representation of MCPToolApprovalRequest. func (m *MCPToolApprovalRequest) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) @@ -1744,6 +1948,7 @@ func (m *MCPToolApprovalRequest) String() string { return sb.String() } +// String returns the string representation of MCPToolApprovalResponse. func (m *MCPToolApprovalResponse) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" approval_request_id: %s\n", m.ApprovalRequestID)) @@ -1754,6 +1959,7 @@ func (m *MCPToolApprovalResponse) String() string { return sb.String() } +// String returns the string representation of AgenticResponseMeta. func (a *AgenticResponseMeta) String() string { sb := &strings.Builder{} sb.WriteString("response_meta:\n") @@ -1792,3 +1998,17 @@ func formatMediaString(url, base64Data string, mimeType string, detail any) stri } return sb.String() } + +func validateExtensionType(expected reflect.Type, actual any) (reflect.Type, bool) { + if actual == nil { + return expected, true + } + actualType := reflect.TypeOf(actual) + if expected == nil { + return actualType, true + } + if expected != actualType { + return expected, false + } + return expected, true +} diff --git a/schema/agentic_message_test.go b/schema/agentic_message_test.go index 0cafcd9ff..016aa5c4e 100644 --- a/schema/agentic_message_test.go +++ b/schema/agentic_message_test.go @@ -18,6 +18,7 @@ package schema import ( "context" + "reflect" "testing" "github.com/stretchr/testify/assert" @@ -75,7 +76,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Hello ", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -87,7 +88,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "World!", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -112,7 +113,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Index: 0, Text: "First "}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -126,7 +127,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Index: 0, Text: "Second"}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -137,7 +138,7 @@ func TestConcatAgenticMessages(t *testing.T) { assert.Len(t, result.ContentBlocks, 1) assert.Len(t, result.ContentBlocks[0].Reasoning.Summary, 1) assert.Equal(t, "First Second", result.ContentBlocks[0].Reasoning.Summary[0].Text) - assert.Equal(t, int64(0), result.ContentBlocks[0].Reasoning.Summary[0].Index) + assert.Equal(t, 0, result.ContentBlocks[0].Reasoning.Summary[0].Index) }) t.Run("concat reasoning with index", func(t *testing.T) { @@ -153,7 +154,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Index: 1, Text: "Part2-"}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -168,7 +169,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Index: 1, Text: "Part4"}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -185,26 +186,26 @@ func TestConcatAgenticMessages(t *testing.T) { t.Run("concat user input text", func(t *testing.T) { msgs := []*AgenticMessage{ { - Role: AgenticRoleTypeUser, + Role: AgenticRoleTypeAssistant, ContentBlocks: []*ContentBlock{ { - Type: ContentBlockTypeUserInputText, - UserInputText: &UserInputText{ + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ Text: "Hello ", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, { - Role: AgenticRoleTypeUser, + Role: AgenticRoleTypeAssistant, ContentBlocks: []*ContentBlock{ { - Type: ContentBlockTypeUserInputText, - UserInputText: &UserInputText{ + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ Text: "World!", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -213,35 +214,35 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - assert.Equal(t, "Hello World!", result.ContentBlocks[0].UserInputText.Text) + assert.Equal(t, "Hello World!", result.ContentBlocks[0].AssistantGenText.Text) }) - t.Run("concat user input image", func(t *testing.T) { - url1 := "https://example.com/image1.jpg" - url2 := "https://example.com/image2.jpg" + t.Run("concat assistant gen image", func(t *testing.T) { + base1 := "1" + base2 := "2" msgs := []*AgenticMessage{ { - Role: AgenticRoleTypeUser, + Role: AgenticRoleTypeAssistant, ContentBlocks: []*ContentBlock{ { - Type: ContentBlockTypeUserInputImage, - UserInputImage: &UserInputImage{ - URL: url1, + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + Base64Data: base1, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, { - Role: AgenticRoleTypeUser, + Role: AgenticRoleTypeAssistant, ContentBlocks: []*ContentBlock{ { - Type: ContentBlockTypeUserInputImage, - UserInputImage: &UserInputImage{ - URL: url2, + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + Base64Data: base2, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -250,11 +251,10 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - // Should take the last image - assert.Equal(t, url2, result.ContentBlocks[0].UserInputImage.URL) + assert.Equal(t, "12", result.ContentBlocks[0].AssistantGenImage.Base64Data) }) - t.Run("concat user input audio", func(t *testing.T) { + t.Run("concat user input audio - should error", func(t *testing.T) { url1 := "https://example.com/audio1.mp3" url2 := "https://example.com/audio2.mp3" @@ -267,7 +267,7 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputAudio: &UserInputAudio{ URL: url1, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -279,20 +279,18 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputAudio: &UserInputAudio{ URL: url2, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, } - result, err := ConcatAgenticMessages(msgs) - assert.NoError(t, err) - assert.Len(t, result.ContentBlocks, 1) - // Should take the last audio - assert.Equal(t, url2, result.ContentBlocks[0].UserInputAudio.URL) + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat multiple user input audios") }) - t.Run("concat user input video", func(t *testing.T) { + t.Run("concat user input video - should error", func(t *testing.T) { url1 := "https://example.com/video1.mp4" url2 := "https://example.com/video2.mp4" @@ -305,7 +303,7 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputVideo: &UserInputVideo{ URL: url1, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -317,17 +315,15 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputVideo: &UserInputVideo{ URL: url2, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, } - result, err := ConcatAgenticMessages(msgs) - assert.NoError(t, err) - assert.Len(t, result.ContentBlocks, 1) - // Should take the last video - assert.Equal(t, url2, result.ContentBlocks[0].UserInputVideo.URL) + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat multiple user input videos") }) t.Run("concat assistant gen text", func(t *testing.T) { @@ -340,7 +336,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Generated ", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -352,7 +348,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Text", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -365,9 +361,6 @@ func TestConcatAgenticMessages(t *testing.T) { }) t.Run("concat assistant gen image", func(t *testing.T) { - url1 := "https://example.com/gen_image1.jpg" - url2 := "https://example.com/gen_image2.jpg" - msgs := []*AgenticMessage{ { Role: AgenticRoleTypeAssistant, @@ -375,9 +368,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: &AssistantGenImage{ - URL: url1, + Base64Data: "part1", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -387,9 +380,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: &AssistantGenImage{ - URL: url2, + Base64Data: "part2", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -398,14 +391,10 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - // Should take the last image - assert.Equal(t, url2, result.ContentBlocks[0].AssistantGenImage.URL) + assert.Equal(t, "part1part2", result.ContentBlocks[0].AssistantGenImage.Base64Data) }) t.Run("concat assistant gen audio", func(t *testing.T) { - url1 := "https://example.com/gen_audio1.mp3" - url2 := "https://example.com/gen_audio2.mp3" - msgs := []*AgenticMessage{ { Role: AgenticRoleTypeAssistant, @@ -413,9 +402,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: &AssistantGenAudio{ - URL: url1, + Base64Data: "audio1", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -425,9 +414,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: &AssistantGenAudio{ - URL: url2, + Base64Data: "audio2", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -436,14 +425,10 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - // Should take the last audio - assert.Equal(t, url2, result.ContentBlocks[0].AssistantGenAudio.URL) + assert.Equal(t, "audio1audio2", result.ContentBlocks[0].AssistantGenAudio.Base64Data) }) t.Run("concat assistant gen video", func(t *testing.T) { - url1 := "https://example.com/gen_video1.mp4" - url2 := "https://example.com/gen_video2.mp4" - msgs := []*AgenticMessage{ { Role: AgenticRoleTypeAssistant, @@ -451,9 +436,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: &AssistantGenVideo{ - URL: url1, + Base64Data: "video1", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -463,9 +448,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: &AssistantGenVideo{ - URL: url2, + Base64Data: "video2", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -474,8 +459,7 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - // Should take the last video - assert.Equal(t, url2, result.ContentBlocks[0].AssistantGenVideo.URL) + assert.Equal(t, "video1video2", result.ContentBlocks[0].AssistantGenVideo.Base64Data) }) t.Run("concat function tool call", func(t *testing.T) { @@ -490,7 +474,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: "get_weather", Arguments: `{"location`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -502,7 +486,7 @@ func TestConcatAgenticMessages(t *testing.T) { FunctionToolCall: &FunctionToolCall{ Arguments: `":"NYC"}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -528,7 +512,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: "get_weather", Result: `{"temp`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -540,7 +524,7 @@ func TestConcatAgenticMessages(t *testing.T) { FunctionToolResult: &FunctionToolResult{ Result: `":72}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -565,7 +549,7 @@ func TestConcatAgenticMessages(t *testing.T) { CallID: "server_call_1", Name: "server_func", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -577,7 +561,7 @@ func TestConcatAgenticMessages(t *testing.T) { ServerToolCall: &ServerToolCall{ Arguments: map[string]any{"key": "value"}, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -603,7 +587,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: "server_func", Result: "result1", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -611,11 +595,9 @@ func TestConcatAgenticMessages(t *testing.T) { Role: AgenticRoleTypeAssistant, ContentBlocks: []*ContentBlock{ { - Type: ContentBlockTypeServerToolResult, - ServerToolResult: &ServerToolResult{ - Result: "result2", - }, - StreamMeta: &StreamMeta{Index: 0}, + Type: ContentBlockTypeServerToolResult, + ServerToolResult: &ServerToolResult{}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -626,6 +608,7 @@ func TestConcatAgenticMessages(t *testing.T) { assert.Len(t, result.ContentBlocks, 1) assert.Equal(t, "server_call_1", result.ContentBlocks[0].ServerToolResult.CallID) assert.Equal(t, "server_func", result.ContentBlocks[0].ServerToolResult.Name) + assert.Equal(t, "result1", result.ContentBlocks[0].ServerToolResult.Result) }) t.Run("concat mcp tool call", func(t *testing.T) { @@ -641,7 +624,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: "mcp_func", Arguments: `{"arg`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -653,7 +636,7 @@ func TestConcatAgenticMessages(t *testing.T) { MCPToolCall: &MCPToolCall{ Arguments: `":123}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -676,11 +659,12 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeMCPToolResult, MCPToolResult: &MCPToolResult{ - CallID: "mcp_call_1", - Name: "mcp_func", - Result: `{"res`, + ServerLabel: "mcp-server", + CallID: "mcp_call_1", + Name: "mcp_func", + Result: `First`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -690,9 +674,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeMCPToolResult, MCPToolResult: &MCPToolResult{ - Result: `ult":true}`, + Result: `Second`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -701,9 +685,10 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPToolResult.ServerLabel) assert.Equal(t, "mcp_call_1", result.ContentBlocks[0].MCPToolResult.CallID) assert.Equal(t, "mcp_func", result.ContentBlocks[0].MCPToolResult.Name) - assert.Equal(t, `{"result":true}`, result.ContentBlocks[0].MCPToolResult.Result) + assert.Equal(t, `Second`, result.ContentBlocks[0].MCPToolResult.Result) }) t.Run("concat mcp list tools", func(t *testing.T) { @@ -719,7 +704,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Name: "tool1"}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -733,7 +718,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Name: "tool2"}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -759,7 +744,7 @@ func TestConcatAgenticMessages(t *testing.T) { ServerLabel: "mcp-server", Arguments: `{"request`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -771,7 +756,7 @@ func TestConcatAgenticMessages(t *testing.T) { MCPToolApprovalRequest: &MCPToolApprovalRequest{ Arguments: `":1}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -786,7 +771,7 @@ func TestConcatAgenticMessages(t *testing.T) { assert.Equal(t, `{"request":1}`, result.ContentBlocks[0].MCPToolApprovalRequest.Arguments) }) - t.Run("concat mcp tool approval response", func(t *testing.T) { + t.Run("concat mcp tool approval response - should error", func(t *testing.T) { response1 := &MCPToolApprovalResponse{ ApprovalRequestID: "approval_1", Approve: false, @@ -803,7 +788,7 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: response1, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -813,17 +798,15 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: response2, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, } - result, err := ConcatAgenticMessages(msgs) - assert.NoError(t, err) - assert.Len(t, result.ContentBlocks, 1) - // Should take the last response - assert.Equal(t, response2, result.ContentBlocks[0].MCPToolApprovalResponse) + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat multiple mcp tool approval responses") }) t.Run("concat response meta", func(t *testing.T) { @@ -865,7 +848,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Hello", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -877,7 +860,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "World", }, - // No StreamMeta - non-streaming + // No StreamingMeta - non-streaming }, }, }, @@ -901,7 +884,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: "list_files", Arguments: `{"path`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -913,7 +896,7 @@ func TestConcatAgenticMessages(t *testing.T) { MCPToolCall: &MCPToolCall{ Arguments: `":"/tmp"}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -927,7 +910,7 @@ func TestConcatAgenticMessages(t *testing.T) { assert.Equal(t, `{"path":"/tmp"}`, result.ContentBlocks[0].MCPToolCall.Arguments) }) - t.Run("concat user input text", func(t *testing.T) { + t.Run("concat user input text - should error", func(t *testing.T) { msgs := []*AgenticMessage{ { Role: AgenticRoleTypeUser, @@ -937,7 +920,7 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputText: &UserInputText{ Text: "What is ", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -949,16 +932,15 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputText: &UserInputText{ Text: "the weather?", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, } - result, err := ConcatAgenticMessages(msgs) - assert.NoError(t, err) - assert.Len(t, result.ContentBlocks, 1) - assert.Equal(t, "What is the weather?", result.ContentBlocks[0].UserInputText.Text) + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat multiple user input texts") }) t.Run("multiple stream indexes - sparse indexes", func(t *testing.T) { @@ -971,14 +953,14 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Index0-", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ Text: "Index2-", }, - StreamMeta: &StreamMeta{Index: 2}, + StreamingMeta: &StreamingMeta{Index: 2}, }, }, }, @@ -990,14 +972,14 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Part2", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ Text: "Part2", }, - StreamMeta: &StreamMeta{Index: 2}, + StreamingMeta: &StreamingMeta{Index: 2}, }, }, }, @@ -1020,7 +1002,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Text ", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeFunctionToolCall, @@ -1029,7 +1011,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: "func1", Arguments: `{"a`, }, - StreamMeta: &StreamMeta{Index: 1}, + StreamingMeta: &StreamingMeta{Index: 1}, }, }, }, @@ -1041,14 +1023,14 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Content", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeFunctionToolCall, FunctionToolCall: &FunctionToolCall{ Arguments: `":1}`, }, - StreamMeta: &StreamMeta{Index: 1}, + StreamingMeta: &StreamingMeta{Index: 1}, }, }, }, @@ -1073,21 +1055,21 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "A", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ Text: "B", }, - StreamMeta: &StreamMeta{Index: 1}, + StreamingMeta: &StreamingMeta{Index: 1}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ Text: "C", }, - StreamMeta: &StreamMeta{Index: 2}, + StreamingMeta: &StreamingMeta{Index: 2}, }, }, }, @@ -1099,21 +1081,21 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "1", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ Text: "2", }, - StreamMeta: &StreamMeta{Index: 1}, + StreamingMeta: &StreamingMeta{Index: 1}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ Text: "3", }, - StreamMeta: &StreamMeta{Index: 2}, + StreamingMeta: &StreamingMeta{Index: 2}, }, }, }, @@ -1276,7 +1258,7 @@ func TestAgenticMessageString(t *testing.T) { Name: "get_current_weather", Arguments: `{"location":"New York City","unit":"fahrenheit"}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeFunctionToolResult, @@ -1289,11 +1271,10 @@ func TestAgenticMessageString(t *testing.T) { { Type: ContentBlockTypeMCPToolCall, MCPToolCall: &MCPToolCall{ - ServerLabel: "weather-mcp-server", - CallID: "mcp_forecast_456", - Name: "get_7day_forecast", - Arguments: `{"city":"New York","days":7}`, - ApprovalRequestID: "approval_req_789", + ServerLabel: "weather-mcp-server", + CallID: "mcp_forecast_456", + Name: "get_7day_forecast", + Arguments: `{"city":"New York","days":7}`, }, }, { @@ -1363,7 +1344,6 @@ content_blocks: call_id: mcp_forecast_456 name: get_7day_forecast arguments: {"city":"New York","days":7} - approval_request_id: approval_req_789 [7] type: mcp_tool_result call_id: mcp_forecast_456 name: get_7day_forecast @@ -1378,4 +1358,294 @@ content_blocks: response_meta: token_usage: prompt=250, completion=180, total=430 `, output) + + t.Run("full fields", func(t *testing.T) { + msg := &AgenticMessage{ + Role: AgenticRoleTypeSystem, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: "http://audio.com", + Base64Data: "audio_data", + MIMEType: "audio/mp3", + }, + }, + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: "http://video.com", + Base64Data: "video_data", + MIMEType: "video/mp4", + }, + }, + { + Type: ContentBlockTypeUserInputFile, + UserInputFile: &UserInputFile{ + URL: "http://file.com", + Name: "file.txt", + Base64Data: "file_data", + MIMEType: "text/plain", + }, + }, + { + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + URL: "http://gen_image.com", + Base64Data: "gen_image_data", + MIMEType: "image/png", + }, + }, + { + Type: ContentBlockTypeAssistantGenAudio, + AssistantGenAudio: &AssistantGenAudio{ + URL: "http://gen_audio.com", + Base64Data: "gen_audio_data", + MIMEType: "audio/wav", + }, + }, + { + Type: ContentBlockTypeAssistantGenVideo, + AssistantGenVideo: &AssistantGenVideo{ + URL: "http://gen_video.com", + Base64Data: "gen_video_data", + MIMEType: "video/mp4", + }, + }, + { + Type: ContentBlockTypeServerToolCall, + ServerToolCall: &ServerToolCall{ + Name: "server_tool", + CallID: "call_1", + Arguments: map[string]any{"a": 1}, + }, + }, + { + Type: ContentBlockTypeServerToolResult, + ServerToolResult: &ServerToolResult{ + Name: "server_tool", + CallID: "call_1", + Result: map[string]any{"success": true}, + }, + }, + { + Type: ContentBlockTypeMCPToolApprovalRequest, + MCPToolApprovalRequest: &MCPToolApprovalRequest{ + ID: "req_1", + Name: "mcp_tool", + ServerLabel: "mcp_server", + Arguments: "{}", + }, + }, + { + Type: ContentBlockTypeMCPToolApprovalResponse, + MCPToolApprovalResponse: &MCPToolApprovalResponse{ + ApprovalRequestID: "req_1", + Approve: true, + Reason: "looks good", + }, + }, + }, + } + + s := msg.String() + assert.Contains(t, s, "role: system") + assert.Contains(t, s, "type: user_input_audio") + assert.Contains(t, s, "http://audio.com") + assert.Contains(t, s, "type: user_input_video") + assert.Contains(t, s, "http://video.com") + assert.Contains(t, s, "type: user_input_file") + assert.Contains(t, s, "file.txt") + assert.Contains(t, s, "type: assistant_gen_image") + assert.Contains(t, s, "http://gen_image.com") + assert.Contains(t, s, "type: assistant_gen_audio") + assert.Contains(t, s, "http://gen_audio.com") + assert.Contains(t, s, "type: assistant_gen_video") + assert.Contains(t, s, "http://gen_video.com") + assert.Contains(t, s, "type: server_tool_call") + assert.Contains(t, s, "server_tool") + assert.Contains(t, s, "map[a:1]") + assert.Contains(t, s, "type: server_tool_result") + assert.Contains(t, s, "map[success:true]") + assert.Contains(t, s, "type: mcp_tool_approval_request") + assert.Contains(t, s, "req_1") + assert.Contains(t, s, "type: mcp_tool_approval_response") + assert.Contains(t, s, "looks good") + }) + + t.Run("nil/empty fields", func(t *testing.T) { + msg := &AgenticMessage{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + {Type: ContentBlockTypeUserInputAudio, UserInputAudio: &UserInputAudio{}}, // empty + {Type: ContentBlockTypeUserInputVideo, UserInputVideo: &UserInputVideo{}}, + {Type: ContentBlockTypeUserInputFile, UserInputFile: &UserInputFile{}}, + {Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: &AssistantGenImage{}}, + {Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: &AssistantGenAudio{}}, + {Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: &AssistantGenVideo{}}, + {Type: ContentBlockTypeServerToolCall, ServerToolCall: &ServerToolCall{Name: "t"}}, // No CallID + {Type: ContentBlockTypeServerToolResult, ServerToolResult: &ServerToolResult{Name: "t"}}, // No CallID + {Type: ContentBlockTypeMCPToolResult, MCPToolResult: &MCPToolResult{Name: "t"}}, // No Error + {Type: ContentBlockTypeMCPListToolsResult, MCPListToolsResult: &MCPListToolsResult{}}, // No Error + {Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: &MCPToolApprovalResponse{Approve: false}}, // No Reason + nil, // Nil block in slice + }, + } + + s := msg.String() + assert.Contains(t, s, "type: user_input_audio") + assert.NotContains(t, s, "mime_type:") + assert.Contains(t, s, "type: server_tool_call") + }) + + t.Run("nil content struct in block", func(t *testing.T) { + // Test cases where the specific content struct is nil but type is set + // This shouldn't crash and should just print type + msg := &AgenticMessage{ + ContentBlocks: []*ContentBlock{ + {Type: ContentBlockTypeReasoning, Reasoning: nil}, + {Type: ContentBlockTypeUserInputText, UserInputText: nil}, + {Type: ContentBlockTypeUserInputImage, UserInputImage: nil}, + {Type: ContentBlockTypeUserInputAudio, UserInputAudio: nil}, + {Type: ContentBlockTypeUserInputVideo, UserInputVideo: nil}, + {Type: ContentBlockTypeUserInputFile, UserInputFile: nil}, + {Type: ContentBlockTypeAssistantGenText, AssistantGenText: nil}, + {Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: nil}, + {Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: nil}, + {Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: nil}, + {Type: ContentBlockTypeFunctionToolCall, FunctionToolCall: nil}, + {Type: ContentBlockTypeFunctionToolResult, FunctionToolResult: nil}, + {Type: ContentBlockTypeServerToolCall, ServerToolCall: nil}, + {Type: ContentBlockTypeServerToolResult, ServerToolResult: nil}, + {Type: ContentBlockTypeMCPToolCall, MCPToolCall: nil}, + {Type: ContentBlockTypeMCPToolResult, MCPToolResult: nil}, + {Type: ContentBlockTypeMCPListToolsResult, MCPListToolsResult: nil}, + {Type: ContentBlockTypeMCPToolApprovalRequest, MCPToolApprovalRequest: nil}, + {Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: nil}, + }, + } + s := msg.String() + assert.Contains(t, s, "type: reasoning") + // ensure no panic and basic output present + }) +} + +func TestDeveloperAgenticMessage(t *testing.T) { + t.Run("basic", func(t *testing.T) { + msg := DeveloperAgenticMessage("developer") + assert.Equal(t, AgenticRoleTypeDeveloper, msg.Role) + assert.Len(t, msg.ContentBlocks, 1) + assert.Equal(t, "developer", msg.ContentBlocks[0].UserInputText.Text) + }) +} + +func TestSystemAgenticMessage(t *testing.T) { + t.Run("basic", func(t *testing.T) { + msg := SystemAgenticMessage("system") + assert.Equal(t, AgenticRoleTypeSystem, msg.Role) + assert.Len(t, msg.ContentBlocks, 1) + assert.Equal(t, "system", msg.ContentBlocks[0].UserInputText.Text) + }) +} + +func TestUserAgenticMessage(t *testing.T) { + t.Run("basic", func(t *testing.T) { + msg := UserAgenticMessage("user") + assert.Equal(t, AgenticRoleTypeUser, msg.Role) + assert.Len(t, msg.ContentBlocks, 1) + assert.Equal(t, "user", msg.ContentBlocks[0].UserInputText.Text) + }) +} + +func TestFunctionToolResultAgenticMessage(t *testing.T) { + t.Run("basic", func(t *testing.T) { + msg := FunctionToolResultAgenticMessage("call_1", "tool_name", "result_str") + assert.Equal(t, AgenticRoleTypeUser, msg.Role) + assert.Len(t, msg.ContentBlocks, 1) + assert.Equal(t, ContentBlockTypeFunctionToolResult, msg.ContentBlocks[0].Type) + assert.Equal(t, "call_1", msg.ContentBlocks[0].FunctionToolResult.CallID) + assert.Equal(t, "tool_name", msg.ContentBlocks[0].FunctionToolResult.Name) + assert.Equal(t, "result_str", msg.ContentBlocks[0].FunctionToolResult.Result) + }) +} + +func TestNewContentBlock(t *testing.T) { + cbType := reflect.TypeOf(ContentBlock{}) + for i := 0; i < cbType.NumField(); i++ { + field := cbType.Field(i) + + // Skip non-content fields + if field.Name == "Type" || field.Name == "Extra" || field.Name == "StreamingMeta" { + continue + } + + t.Run(field.Name, func(t *testing.T) { + // Ensure field is a pointer + assert.Equal(t, reflect.Ptr, field.Type.Kind(), "Field %s should be a pointer", field.Name) + + // Create a new instance of the field's type + // field.Type is *T, so Elem() is T. reflect.New(T) returns *T. + elemType := field.Type.Elem() + inputVal := reflect.New(elemType) + input := inputVal.Interface() + + // Call NewContentBlock (generic) via type switch + var block *ContentBlock + switch v := input.(type) { + case *Reasoning: + block = NewContentBlock(v) + case *UserInputText: + block = NewContentBlock(v) + case *UserInputImage: + block = NewContentBlock(v) + case *UserInputAudio: + block = NewContentBlock(v) + case *UserInputVideo: + block = NewContentBlock(v) + case *UserInputFile: + block = NewContentBlock(v) + case *AssistantGenText: + block = NewContentBlock(v) + case *AssistantGenImage: + block = NewContentBlock(v) + case *AssistantGenAudio: + block = NewContentBlock(v) + case *AssistantGenVideo: + block = NewContentBlock(v) + case *FunctionToolCall: + block = NewContentBlock(v) + case *FunctionToolResult: + block = NewContentBlock(v) + case *ServerToolCall: + block = NewContentBlock(v) + case *ServerToolResult: + block = NewContentBlock(v) + case *MCPToolCall: + block = NewContentBlock(v) + case *MCPToolResult: + block = NewContentBlock(v) + case *MCPListToolsResult: + block = NewContentBlock(v) + case *MCPToolApprovalRequest: + block = NewContentBlock(v) + case *MCPToolApprovalResponse: + block = NewContentBlock(v) + default: + t.Fatalf("unsupported ContentBlock field type: %T", input) + } + + // Assertions + assert.NotNil(t, block, "NewContentBlock should return non-nil for type %T", input) + + // Check if the corresponding field in block is set equals to input + blockVal := reflect.ValueOf(block).Elem() + fieldVal := blockVal.FieldByName(field.Name) + assert.True(t, fieldVal.IsValid(), "Field %s not found in result", field.Name) + assert.Equal(t, input, fieldVal.Interface(), "Field %s should match input", field.Name) + + // Check Type is set + typeVal := blockVal.FieldByName("Type") + assert.NotEmpty(t, typeVal.String(), "Type should be set for %s", field.Name) + }) + } } diff --git a/schema/claude/consts.go b/schema/claude/consts.go index cbf8784f6..714b0362e 100644 --- a/schema/claude/consts.go +++ b/schema/claude/consts.go @@ -14,6 +14,7 @@ * limitations under the License. */ +// Package claude defines constants for claude. package claude type TextCitationType string diff --git a/schema/claude/content_block.go b/schema/claude/extension.go similarity index 51% rename from schema/claude/content_block.go rename to schema/claude/extension.go index 0c43d1045..5df8d8907 100644 --- a/schema/claude/content_block.go +++ b/schema/claude/extension.go @@ -16,6 +16,15 @@ package claude +import ( + "fmt" +) + +type ResponseMetaExtension struct { + ID string `json:"id,omitempty"` + StopReason string `json:"stop_reason,omitempty"` +} + type AssistantGenTextExtension struct { Citations []*TextCitation `json:"citations,omitempty"` } @@ -33,30 +42,30 @@ type CitationCharLocation struct { CitedText string `json:"cited_text,omitempty"` DocumentTitle string `json:"document_title,omitempty"` - DocumentIndex int64 `json:"document_index,omitempty"` + DocumentIndex int `json:"document_index,omitempty"` - StartCharIndex int64 `json:"start_char_index,omitempty"` - EndCharIndex int64 `json:"end_char_index,omitempty"` + StartCharIndex int `json:"start_char_index,omitempty"` + EndCharIndex int `json:"end_char_index,omitempty"` } type CitationPageLocation struct { CitedText string `json:"cited_text,omitempty"` DocumentTitle string `json:"document_title,omitempty"` - DocumentIndex int64 `json:"document_index,omitempty"` + DocumentIndex int `json:"document_index,omitempty"` - StartPageNumber int64 `json:"start_page_number,omitempty"` - EndPageNumber int64 `json:"end_page_number,omitempty"` + StartPageNumber int `json:"start_page_number,omitempty"` + EndPageNumber int `json:"end_page_number,omitempty"` } type CitationContentBlockLocation struct { CitedText string `json:"cited_text,omitempty"` DocumentTitle string `json:"document_title,omitempty"` - DocumentIndex int64 `json:"document_index,omitempty"` + DocumentIndex int `json:"document_index,omitempty"` - StartBlockIndex int64 `json:"start_block_index,omitempty"` - EndBlockIndex int64 `json:"end_block_index,omitempty"` + StartBlockIndex int `json:"start_block_index,omitempty"` + EndBlockIndex int `json:"end_block_index,omitempty"` } type CitationWebSearchResultLocation struct { @@ -67,3 +76,46 @@ type CitationWebSearchResultLocation struct { EncryptedIndex string `json:"encrypted_index,omitempty"` } + +// ConcatAssistantGenTextExtensions concatenates multiple AssistantGenTextExtension chunks into a single one. +func ConcatAssistantGenTextExtensions(chunks []*AssistantGenTextExtension) (*AssistantGenTextExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no assistant generated text extension found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + + ret := &AssistantGenTextExtension{ + Citations: make([]*TextCitation, 0, len(chunks)), + } + + for _, ext := range chunks { + ret.Citations = append(ret.Citations, ext.Citations...) + } + + return ret, nil +} + +// ConcatResponseMetaExtensions concatenates multiple ResponseMetaExtension chunks into a single one. +func ConcatResponseMetaExtensions(chunks []*ResponseMetaExtension) (*ResponseMetaExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no response meta extension found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + + ret := &ResponseMetaExtension{} + + for _, ext := range chunks { + if ext.ID != "" { + ret.ID = ext.ID + } + if ext.StopReason != "" { + ret.StopReason = ext.StopReason + } + } + + return ret, nil +} diff --git a/schema/claude/extension_test.go b/schema/claude/extension_test.go new file mode 100644 index 000000000..474fe740b --- /dev/null +++ b/schema/claude/extension_test.go @@ -0,0 +1,190 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package claude + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConcatAssistantGenTextExtensions(t *testing.T) { + t.Run("multiple extensions - concatenates all citations", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Citations: []*TextCitation{ + { + Type: "char_location", + CharLocation: &CitationCharLocation{ + CitedText: "citation 1", + DocumentIndex: 0, + }, + }, + }, + }, + { + Citations: []*TextCitation{ + { + Type: "page_location", + PageLocation: &CitationPageLocation{ + CitedText: "citation 2", + StartPageNumber: 1, + EndPageNumber: 2, + }, + }, + { + Type: "web_search_result_location", + WebSearchResultLocation: &CitationWebSearchResultLocation{ + CitedText: "citation 3", + URL: "https://example.com", + }, + }, + }, + }, + { + Citations: []*TextCitation{ + { + Type: "content_block_location", + ContentBlockLocation: &CitationContentBlockLocation{ + CitedText: "citation 4", + StartBlockIndex: 0, + EndBlockIndex: 5, + }, + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Citations, 4) + assert.Equal(t, "citation 1", result.Citations[0].CharLocation.CitedText) + assert.Equal(t, "citation 2", result.Citations[1].PageLocation.CitedText) + assert.Equal(t, "citation 3", result.Citations[2].WebSearchResultLocation.CitedText) + assert.Equal(t, "citation 4", result.Citations[3].ContentBlockLocation.CitedText) + }) + + t.Run("mixed empty and non-empty citations", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + {Citations: nil}, + { + Citations: []*TextCitation{ + { + Type: "char_location", + CharLocation: &CitationCharLocation{ + CitedText: "text1", + }, + }, + }, + }, + {Citations: []*TextCitation{}}, + { + Citations: []*TextCitation{ + { + Type: "page_location", + PageLocation: &CitationPageLocation{ + CitedText: "text2", + }, + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Citations, 2) + assert.Equal(t, "text1", result.Citations[0].CharLocation.CitedText) + assert.Equal(t, "text2", result.Citations[1].PageLocation.CitedText) + }) + + t.Run("streaming scenario - citations arrive in chunks", func(t *testing.T) { + // Simulates streaming where citations arrive progressively + exts := []*AssistantGenTextExtension{ + { + Citations: []*TextCitation{ + {Type: "char_location", CharLocation: &CitationCharLocation{CitedText: "chunk1"}}, + }, + }, + { + Citations: []*TextCitation{ + {Type: "char_location", CharLocation: &CitationCharLocation{CitedText: "chunk2"}}, + }, + }, + { + Citations: []*TextCitation{ + {Type: "char_location", CharLocation: &CitationCharLocation{CitedText: "chunk3"}}, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Citations, 3) + assert.Equal(t, "chunk1", result.Citations[0].CharLocation.CitedText) + assert.Equal(t, "chunk2", result.Citations[1].CharLocation.CitedText) + assert.Equal(t, "chunk3", result.Citations[2].CharLocation.CitedText) + }) +} + +func TestConcatResponseMetaExtensions(t *testing.T) { + t.Run("multiple extensions - takes last non-empty values", func(t *testing.T) { + exts := []*ResponseMetaExtension{ + { + ID: "msg_1", + StopReason: "stop_1", + }, + { + ID: "msg_2", + StopReason: "", + }, + { + ID: "", + StopReason: "stop_3", + }, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "msg_2", result.ID) // Last non-empty ID + assert.Equal(t, "stop_3", result.StopReason) // Last non-empty StopReason + }) + + t.Run("all empty fields", func(t *testing.T) { + exts := []*ResponseMetaExtension{ + {ID: "", StopReason: ""}, + {ID: "", StopReason: ""}, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "", result.ID) + assert.Equal(t, "", result.StopReason) + }) + + t.Run("streaming scenario - ID in first chunk, StopReason in last", func(t *testing.T) { + exts := []*ResponseMetaExtension{ + {ID: "msg_stream_123", StopReason: ""}, + {ID: "", StopReason: ""}, + {ID: "", StopReason: "end_turn"}, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "msg_stream_123", result.ID) + assert.Equal(t, "end_turn", result.StopReason) + }) +} diff --git a/schema/claude/response_meta.go b/schema/claude/response_meta.go deleted file mode 100644 index 9f60dd713..000000000 --- a/schema/claude/response_meta.go +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package claude - -type ResponseMetaExtension struct { - ID string `json:"id,omitempty"` - StopReason string `json:"stop_reason,omitempty"` -} diff --git a/schema/gemini/response_meta.go b/schema/gemini/extension.go similarity index 76% rename from schema/gemini/response_meta.go rename to schema/gemini/extension.go index a5b3f626c..efbc4f4bd 100644 --- a/schema/gemini/response_meta.go +++ b/schema/gemini/extension.go @@ -14,8 +14,13 @@ * limitations under the License. */ +// Package gemini defines the extension for gemini. package gemini +import ( + "fmt" +) + type ResponseMetaExtension struct { ID string `json:"id,omitempty"` FinishReason string `json:"finish_reason,omitempty"` @@ -38,7 +43,7 @@ type GroundingChunk struct { Web *GroundingChunkWeb `json:"web,omitempty"` } -// Chunk from the web. +// GroundingChunkWeb is the chunk from the web. type GroundingChunkWeb struct { // Domain of the (original) URI. This field is not supported in Gemini API. Domain string `json:"domain,omitempty"` @@ -56,7 +61,7 @@ type GroundingSupport struct { // A list of indices (into 'grounding_chunk') specifying the citations associated with // the claim. For instance [1,3,4] means that grounding_chunk[1], grounding_chunk[3], // grounding_chunk[4] are the retrieved content attributed to the claim. - GroundingChunkIndices []int32 `json:"grounding_chunk_indices,omitempty"` + GroundingChunkIndices []int `json:"grounding_chunk_indices,omitempty"` // Segment of the content this support belongs to. Segment *Segment `json:"segment,omitempty"` } @@ -65,20 +70,46 @@ type GroundingSupport struct { type Segment struct { // Output only. End index in the given Part, measured in bytes. Offset from the start // of the Part, exclusive, starting at zero. - EndIndex int32 `json:"end_index,omitempty"` + EndIndex int `json:"end_index,omitempty"` // Output only. The index of a Part object within its parent Content object. - PartIndex int32 `json:"part_index,omitempty"` + PartIndex int `json:"part_index,omitempty"` // Output only. Start index in the given Part, measured in bytes. Offset from the start // of the Part, inclusive, starting at zero. - StartIndex int32 `json:"start_index,omitempty"` + StartIndex int `json:"start_index,omitempty"` // Output only. The text corresponding to the segment from the response. Text string `json:"text,omitempty"` } -// Google search entry point. +// SearchEntryPoint is the Google search entry point. type SearchEntryPoint struct { // Optional. Web content snippet that can be embedded in a web page or an app webview. RenderedContent string `json:"rendered_content,omitempty"` // Optional. Base64 encoded JSON representing array of tuple. SDKBlob []byte `json:"sdk_blob,omitempty"` } + +// ConcatResponseMetaExtensions concatenates multiple ResponseMetaExtension chunks into a single one. +func ConcatResponseMetaExtensions(chunks []*ResponseMetaExtension) (*ResponseMetaExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no response meta extension found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + + ret := &ResponseMetaExtension{} + + for _, ext := range chunks { + if ext.ID != "" { + ret.ID = ext.ID + } + if ext.FinishReason != "" { + ret.FinishReason = ext.FinishReason + } + if ext.GroundingMeta != nil { + ret.GroundingMeta = ext.GroundingMeta + } + } + + return ret, nil +} diff --git a/schema/gemini/extension_test.go b/schema/gemini/extension_test.go new file mode 100644 index 000000000..56f390aa8 --- /dev/null +++ b/schema/gemini/extension_test.go @@ -0,0 +1,79 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package gemini + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConcatResponseMetaExtensions(t *testing.T) { + t.Run("multiple extensions - takes last non-empty values", func(t *testing.T) { + meta1 := &GroundingMetadata{WebSearchQueries: []string{"query1"}} + meta2 := &GroundingMetadata{WebSearchQueries: []string{"query2"}} + + exts := []*ResponseMetaExtension{ + { + ID: "resp_1", + FinishReason: "STOP", + GroundingMeta: meta1, + }, + { + ID: "resp_2", + FinishReason: "", + GroundingMeta: nil, + }, + { + ID: "", + FinishReason: "MAX_TOKENS", + GroundingMeta: meta2, + }, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "resp_2", result.ID) + assert.Equal(t, "MAX_TOKENS", result.FinishReason) + assert.Equal(t, meta2, result.GroundingMeta) + }) + + t.Run("streaming scenario", func(t *testing.T) { + meta := &GroundingMetadata{ + GroundingChunks: []*GroundingChunk{ + { + Web: &GroundingChunkWeb{ + Title: "Example", + URI: "https://example.com", + }, + }, + }, + } + + exts := []*ResponseMetaExtension{ + {ID: "stream_123", FinishReason: "", GroundingMeta: nil}, + {ID: "", FinishReason: "", GroundingMeta: nil}, + {ID: "", FinishReason: "STOP", GroundingMeta: meta}, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "stream_123", result.ID) + assert.Equal(t, "STOP", result.FinishReason) + assert.Equal(t, meta, result.GroundingMeta) + }) +} diff --git a/schema/message.go b/schema/message.go index 02cdefaad..611bcedca 100644 --- a/schema/message.go +++ b/schema/message.go @@ -703,10 +703,10 @@ type TokenUsage struct { PromptTokenDetails PromptTokenDetails `json:"prompt_token_details"` // CompletionTokens is the number of completion tokens. CompletionTokens int `json:"completion_tokens"` - // CompletionTokenDetails is a breakdown of the completion tokens. - CompletionTokenDetails CompletionTokensDetails `json:"completion_token_details"` // TotalTokens is the total number of tokens. TotalTokens int `json:"total_tokens"` + // CompletionTokensDetails is breakdown of completion tokens. + CompletionTokensDetails CompletionTokensDetails `json:"completion_token_details"` } type CompletionTokensDetails struct { diff --git a/schema/openai/consts.go b/schema/openai/consts.go index 321ee2a9e..5958cef40 100644 --- a/schema/openai/consts.go +++ b/schema/openai/consts.go @@ -14,6 +14,7 @@ * limitations under the License. */ +// Package openai defines constants for openai. package openai type TextAnnotationType string @@ -24,3 +25,71 @@ const ( TextAnnotationTypeContainerFileCitation TextAnnotationType = "container_file_citation" TextAnnotationTypeFilePath TextAnnotationType = "file_path" ) + +type ReasoningEffort string + +const ( + ReasoningEffortMinimal ReasoningEffort = "minimal" + ReasoningEffortLow ReasoningEffort = "low" + ReasoningEffortMedium ReasoningEffort = "medium" + ReasoningEffortHigh ReasoningEffort = "high" +) + +type ReasoningSummary string + +const ( + ReasoningSummaryAuto ReasoningSummary = "auto" + ReasoningSummaryConcise ReasoningSummary = "concise" + ReasoningSummaryDetailed ReasoningSummary = "detailed" +) + +type ServiceTier string + +const ( + ServiceTierAuto ServiceTier = "auto" + ServiceTierDefault ServiceTier = "default" + ServiceTierFlex ServiceTier = "flex" + ServiceTierScale ServiceTier = "scale" + ServiceTierPriority ServiceTier = "priority" +) + +type PromptCacheRetention string + +const ( + PromptCacheRetentionInMemory PromptCacheRetention = "in-memory" + PromptCacheRetention24h PromptCacheRetention = "24h" +) + +type ResponseStatus string + +const ( + ResponseStatusCompleted ResponseStatus = "completed" + ResponseStatusFailed ResponseStatus = "failed" + ResponseStatusInProgress ResponseStatus = "in_progress" + ResponseStatusCancelled ResponseStatus = "cancelled" + ResponseStatusQueued ResponseStatus = "queued" + ResponseStatusIncomplete ResponseStatus = "incomplete" +) + +type ResponseErrorCode string + +const ( + ResponseErrorCodeServerError ResponseErrorCode = "server_error" + ResponseErrorCodeRateLimitExceeded ResponseErrorCode = "rate_limit_exceeded" + ResponseErrorCodeInvalidPrompt ResponseErrorCode = "invalid_prompt" + ResponseErrorCodeVectorStoreTimeout ResponseErrorCode = "vector_store_timeout" + ResponseErrorCodeInvalidImage ResponseErrorCode = "invalid_image" + ResponseErrorCodeInvalidImageFormat ResponseErrorCode = "invalid_image_format" + ResponseErrorCodeInvalidBase64Image ResponseErrorCode = "invalid_base64_image" + ResponseErrorCodeInvalidImageURL ResponseErrorCode = "invalid_image_url" + ResponseErrorCodeImageTooLarge ResponseErrorCode = "image_too_large" + ResponseErrorCodeImageTooSmall ResponseErrorCode = "image_too_small" + ResponseErrorCodeImageParseError ResponseErrorCode = "image_parse_error" + ResponseErrorCodeImageContentPolicyViolation ResponseErrorCode = "image_content_policy_violation" + ResponseErrorCodeInvalidImageMode ResponseErrorCode = "invalid_image_mode" + ResponseErrorCodeImageFileTooLarge ResponseErrorCode = "image_file_too_large" + ResponseErrorCodeUnsupportedImageMediaType ResponseErrorCode = "unsupported_image_media_type" + ResponseErrorCodeEmptyImageFile ResponseErrorCode = "empty_image_file" + ResponseErrorCodeFailedToDownloadImage ResponseErrorCode = "failed_to_download_image" + ResponseErrorCodeImageFileNotFound ResponseErrorCode = "image_file_not_found" +) diff --git a/schema/openai/content_block.go b/schema/openai/content_block.go deleted file mode 100644 index 5d92be8f7..000000000 --- a/schema/openai/content_block.go +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package openai - -type AssistantGenTextExtension struct { - Annotations []*TextAnnotation `json:"annotations,omitempty"` -} - -type TextAnnotation struct { - Type TextAnnotationType `json:"type,omitempty"` - - FileCitation *TextAnnotationFileCitation `json:"file_citation,omitempty"` - URLCitation *TextAnnotationURLCitation `json:"url_citation,omitempty"` - ContainerFileCitation *TextAnnotationContainerFileCitation `json:"container_file_citation,omitempty"` - FilePath *TextAnnotationFilePath `json:"file_path,omitempty"` -} - -type TextAnnotationFileCitation struct { - // The ID of the file. - FileID string `json:"file_id,omitempty"` - // The filename of the file cited. - Filename string `json:"filename,omitempty"` - - // The index of the file in the list of files. - Index int64 `json:"index,omitempty"` -} - -type TextAnnotationURLCitation struct { - // The title of the web resource. - Title string `json:"title,omitempty"` - // The URL of the web resource. - URL string `json:"url,omitempty"` - - // The index of the first character of the URL citation in the message. - StartIndex int64 `json:"start_index,omitempty"` - // The index of the last character of the URL citation in the message. - EndIndex int64 `json:"end_index,omitempty"` -} - -type TextAnnotationContainerFileCitation struct { - // The ID of the container file. - ContainerID string `json:"container_id,omitempty"` - - // The ID of the file. - FileID string `json:"file_id,omitempty"` - // The filename of the container file cited. - Filename string `json:"filename,omitempty"` - - // The index of the first character of the container file citation in the message. - StartIndex int64 `json:"start_index,omitempty"` - // The index of the last character of the container file citation in the message. - EndIndex int64 `json:"end_index,omitempty"` -} - -type TextAnnotationFilePath struct { - // The ID of the file. - FileID string `json:"file_id,omitempty"` - - // The index of the file in the list of files. - Index int64 `json:"index,omitempty"` -} diff --git a/schema/openai/extension.go b/schema/openai/extension.go new file mode 100644 index 000000000..c30d2d8ec --- /dev/null +++ b/schema/openai/extension.go @@ -0,0 +1,206 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "fmt" + "sort" +) + +type ResponseMetaExtension struct { + ID string `json:"id,omitempty"` + Status ResponseStatus `json:"status,omitempty"` + Error *ResponseError `json:"error,omitempty"` + IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` + PreviousResponseID string `json:"previous_response_id,omitempty"` + Reasoning *Reasoning `json:"reasoning,omitempty"` + ServiceTier ServiceTier `json:"service_tier,omitempty"` + CreatedAt int64 `json:"created_at,omitempty"` + PromptCacheRetention PromptCacheRetention `json:"prompt_cache_retention,omitempty"` +} + +type AssistantGenTextExtension struct { + Refusal *OutputRefusal `json:"refusal,omitempty"` + Annotations []*TextAnnotation `json:"annotations,omitempty"` +} + +type ResponseError struct { + Code ResponseErrorCode `json:"code,omitempty"` + Message string `json:"message,omitempty"` +} + +type IncompleteDetails struct { + Reason string `json:"reason,omitempty"` +} + +type Reasoning struct { + Effort ReasoningEffort `json:"effort,omitempty"` + Summary ReasoningSummary `json:"summary,omitempty"` +} + +type OutputRefusal struct { + Reason string `json:"reason,omitempty"` +} + +type TextAnnotation struct { + Index int `json:"index,omitempty"` + + Type TextAnnotationType `json:"type,omitempty"` + + FileCitation *TextAnnotationFileCitation `json:"file_citation,omitempty"` + URLCitation *TextAnnotationURLCitation `json:"url_citation,omitempty"` + ContainerFileCitation *TextAnnotationContainerFileCitation `json:"container_file_citation,omitempty"` + FilePath *TextAnnotationFilePath `json:"file_path,omitempty"` +} + +type TextAnnotationFileCitation struct { + // The ID of the file. + FileID string `json:"file_id,omitempty"` + // The filename of the file cited. + Filename string `json:"filename,omitempty"` + + // The index of the file in the list of files. + Index int `json:"index,omitempty"` +} + +type TextAnnotationURLCitation struct { + // The title of the web resource. + Title string `json:"title,omitempty"` + // The URL of the web resource. + URL string `json:"url,omitempty"` + + // The index of the first character of the URL citation in the message. + StartIndex int `json:"start_index,omitempty"` + // The index of the last character of the URL citation in the message. + EndIndex int `json:"end_index,omitempty"` +} + +type TextAnnotationContainerFileCitation struct { + // The ID of the container file. + ContainerID string `json:"container_id,omitempty"` + + // The ID of the file. + FileID string `json:"file_id,omitempty"` + // The filename of the container file cited. + Filename string `json:"filename,omitempty"` + + // The index of the first character of the container file citation in the message. + StartIndex int `json:"start_index,omitempty"` + // The index of the last character of the container file citation in the message. + EndIndex int `json:"end_index,omitempty"` +} + +type TextAnnotationFilePath struct { + // The ID of the file. + FileID string `json:"file_id,omitempty"` + + // The index of the file in the list of files. + Index int `json:"index,omitempty"` +} + +// ConcatAssistantGenTextExtensions concatenates multiple AssistantGenTextExtension chunks into a single one. +func ConcatAssistantGenTextExtensions(chunks []*AssistantGenTextExtension) (*AssistantGenTextExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no assistant generated text extension found") + } + + ret := &AssistantGenTextExtension{} + + var allAnnotations []*TextAnnotation + for _, ext := range chunks { + allAnnotations = append(allAnnotations, ext.Annotations...) + } + + var ( + indices []int + indexToAnnotation = map[int]*TextAnnotation{} + ) + + for _, an := range allAnnotations { + if an == nil { + continue + } + if indexToAnnotation[an.Index] == nil { + indexToAnnotation[an.Index] = an + indices = append(indices, an.Index) + } else { + return nil, fmt.Errorf("duplicate annotation index %d", an.Index) + } + } + + sort.Slice(indices, func(i, j int) bool { + return indices[i] < indices[j] + }) + + ret.Annotations = make([]*TextAnnotation, 0, len(indices)) + for _, idx := range indices { + an := *indexToAnnotation[idx] + an.Index = 0 // clear index + ret.Annotations = append(ret.Annotations, &an) + } + + for _, ext := range chunks { + if ext.Refusal == nil { + continue + } + if ret.Refusal == nil { + ret.Refusal = ext.Refusal + } else { + ret.Refusal.Reason += ext.Refusal.Reason + } + } + + return ret, nil +} + +// ConcatResponseMetaExtensions concatenates multiple ResponseMetaExtension chunks into a single one. +func ConcatResponseMetaExtensions(chunks []*ResponseMetaExtension) (*ResponseMetaExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no response meta extension found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + + ret := &ResponseMetaExtension{} + + for _, ext := range chunks { + if ext.ID != "" { + ret.ID = ext.ID + } + if ext.Status != "" { + ret.Status = ext.Status + } + if ext.Error != nil { + ret.Error = ext.Error + } + if ext.IncompleteDetails != nil { + ret.IncompleteDetails = ext.IncompleteDetails + } + if ext.PreviousResponseID != "" { + ret.PreviousResponseID = ext.PreviousResponseID + } + if ext.Reasoning != nil { + ret.Reasoning = ext.Reasoning + } + if ext.ServiceTier != "" { + ret.ServiceTier = ext.ServiceTier + } + } + + return ret, nil +} diff --git a/schema/openai/extension_test.go b/schema/openai/extension_test.go new file mode 100644 index 000000000..640982fdf --- /dev/null +++ b/schema/openai/extension_test.go @@ -0,0 +1,193 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConcatResponseMetaExtensions(t *testing.T) { + t.Run("multiple extensions - takes last non-empty values", func(t *testing.T) { + err1 := &ResponseError{Code: "err1", Message: "msg1"} + incomplete := &IncompleteDetails{Reason: "max_tokens"} + + exts := []*ResponseMetaExtension{ + { + ID: "id_1", + Status: "in_progress", + Error: err1, + IncompleteDetails: nil, + }, + { + ID: "id_2", + Status: "", + Error: nil, + IncompleteDetails: nil, + }, + { + ID: "", + Status: "completed", + Error: nil, + IncompleteDetails: incomplete, + }, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "id_2", result.ID) + assert.Equal(t, ResponseStatus("completed"), result.Status) + assert.Equal(t, err1, result.Error) + assert.Equal(t, incomplete, result.IncompleteDetails) + }) + + t.Run("streaming scenario", func(t *testing.T) { + exts := []*ResponseMetaExtension{ + {ID: "chatcmpl_stream", Status: "", Error: nil, IncompleteDetails: nil}, + {ID: "", Status: ResponseStatus("in_progress"), Error: nil, IncompleteDetails: nil}, + {ID: "", Status: ResponseStatus("completed"), Error: nil, IncompleteDetails: nil}, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "chatcmpl_stream", result.ID) + assert.Equal(t, ResponseStatus("completed"), result.Status) + }) +} + +func TestConcatAssistantGenTextExtensions(t *testing.T) { + t.Run("single extension with annotations", func(t *testing.T) { + ext := &AssistantGenTextExtension{ + Annotations: []*TextAnnotation{ + { + Index: 0, + Type: "file_citation", + FileCitation: &TextAnnotationFileCitation{ + FileID: "file_123", + Filename: "doc.pdf", + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions([]*AssistantGenTextExtension{ext}) + assert.NoError(t, err) + assert.Len(t, result.Annotations, 1) + assert.Equal(t, "file_123", result.Annotations[0].FileCitation.FileID) + }) + + t.Run("multiple extensions - merges annotations by index", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Annotations: []*TextAnnotation{ + { + Index: 0, + Type: "file_citation", + FileCitation: &TextAnnotationFileCitation{ + FileID: "file_1", + }, + }, + }, + }, + { + Annotations: []*TextAnnotation{ + { + Index: 2, + Type: "url_citation", + URLCitation: &TextAnnotationURLCitation{ + URL: "https://example.com", + }, + }, + }, + }, + { + Annotations: []*TextAnnotation{ + { + Index: 1, + Type: "file_path", + FilePath: &TextAnnotationFilePath{ + FileID: "file_2", + }, + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Annotations, 3) + assert.Equal(t, "file_1", result.Annotations[0].FileCitation.FileID) + assert.Equal(t, "file_2", result.Annotations[1].FilePath.FileID) + assert.Equal(t, "https://example.com", result.Annotations[2].URLCitation.URL) + }) + + t.Run("streaming scenario - annotations arrive in chunks", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Annotations: []*TextAnnotation{ + {Index: 0, Type: "file_citation", FileCitation: &TextAnnotationFileCitation{FileID: "f1"}}, + }, + }, + { + Annotations: []*TextAnnotation{ + {Index: 1, Type: "url_citation", URLCitation: &TextAnnotationURLCitation{URL: "url1"}}, + }, + }, + { + Annotations: []*TextAnnotation{ + {Index: 2, Type: "file_path", FilePath: &TextAnnotationFilePath{FileID: "f2"}}, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Annotations, 3) + assert.Equal(t, "f1", result.Annotations[0].FileCitation.FileID) + assert.Equal(t, "url1", result.Annotations[1].URLCitation.URL) + assert.Equal(t, "f2", result.Annotations[2].FilePath.FileID) + }) + + t.Run("multiple extensions - concatenates refusal reason", func(t *testing.T) { + ext1 := &AssistantGenTextExtension{Refusal: &OutputRefusal{Reason: "A"}} + ext2 := &AssistantGenTextExtension{Refusal: &OutputRefusal{Reason: "B"}} + + result, err := ConcatAssistantGenTextExtensions([]*AssistantGenTextExtension{ext1, ext2}) + assert.NoError(t, err) + assert.NotNil(t, result.Refusal) + assert.Equal(t, "AB", result.Refusal.Reason) + }) + + t.Run("duplicate index - error occurrence", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Annotations: []*TextAnnotation{ + {Index: 0, Type: "file_citation", FileCitation: &TextAnnotationFileCitation{FileID: "first"}}, + }, + }, + { + Annotations: []*TextAnnotation{ + {Index: 0, Type: "url_citation", URLCitation: &TextAnnotationURLCitation{URL: "second"}}, + }, + }, + } + + _, err := ConcatAssistantGenTextExtensions(exts) + assert.Error(t, err) + }) +} diff --git a/schema/openai/response_meta.go b/schema/openai/response_meta.go deleted file mode 100644 index e1933065b..000000000 --- a/schema/openai/response_meta.go +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package openai - -type ResponseMetaExtension struct { - ID string `json:"id,omitempty"` - Status string `json:"status,omitempty"` - Error *ResponseError `json:"error,omitempty"` - StreamError *StreamResponseError `json:"stream_error,omitempty"` - IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` -} - -type ResponseError struct { - Code string `json:"code,omitempty"` - Message string `json:"message,omitempty"` -} - -type StreamResponseError struct { - Code string - Message string - Param string -} - -type IncompleteDetails struct { - Reason string `json:"reason,omitempty"` -} diff --git a/schema/tool.go b/schema/tool.go index ccc93b6a3..2d6bf90db 100644 --- a/schema/tool.go +++ b/schema/tool.go @@ -59,6 +59,27 @@ const ( ToolChoiceForced ToolChoice = "forced" ) +type AllowedTool struct { + // FunctionToolName is the name of the function tool. + FunctionToolName string + + MCPTool *AllowedMCPTool + + ServerTool *AllowedServerTool +} +type AllowedMCPTool struct { + // ServerLabel is the label of the MCP server. + ServerLabel string + // The name of the MCP tool. + Name string +} + +type AllowedServerTool struct { + // The name of the server tool. + Name string +} + +// ToolInfo is the information of a tool. // ToolInfo describes a tool that can be passed to a ChatModel via // [ToolCallingChatModel.WithTools] or [ChatModel.BindTools]. // From bf2c536fc81a01f9bb419fe90a56f19ee952b9db Mon Sep 17 00:00:00 2001 From: mrh997 Date: Tue, 6 Jan 2026 16:48:56 +0800 Subject: [PATCH 17/65] fix: concat agentic messages (#604) --- components/model/callback_extra.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/components/model/callback_extra.go b/components/model/callback_extra.go index 2767e2e5e..afff3f0a7 100644 --- a/components/model/callback_extra.go +++ b/components/model/callback_extra.go @@ -29,10 +29,10 @@ type TokenUsage struct { PromptTokenDetails PromptTokenDetails // CompletionTokens is the number of completion tokens. CompletionTokens int - // CompletionTokensDetails is a breakdown of the completion tokens. - CompletionTokensDetails CompletionTokensDetails // TotalTokens is the total number of tokens. TotalTokens int + // CompletionTokensDetails is a breakdown of the completion tokens. + CompletionTokensDetails CompletionTokensDetails } type CompletionTokensDetails struct { From 9d84fdd3df093f90e36552dc37d0d0aed7d53e8f Mon Sep 17 00:00:00 2001 From: Megumin Date: Thu, 8 Jan 2026 15:34:53 +0800 Subject: [PATCH 18/65] fix(schema): agentic concat support extra (#670) --- schema/agentic_message.go | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index b2225b2c7..2ba02b689 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -695,8 +695,10 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { role AgenticRoleType blocks []*ContentBlock metas []*AgenticResponseMeta + extra map[string]any blockIndices []int indexToBlocks = map[int][]*ContentBlock{} + extraList = make([]map[string]any, 0, len(msgs)) ) if len(msgs) == 1 { @@ -747,6 +749,10 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { if msg.ResponseMeta != nil { metas = append(metas, msg.ResponseMeta) } + + if msg.Extra != nil { + extraList = append(extraList, msg.Extra) + } } meta, err := concatAgenticResponseMeta(metas) @@ -758,7 +764,8 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { // All blocks are streaming, concat each group by index indexToBlock := map[int]*ContentBlock{} for idx, bs := range indexToBlocks { - b, err := concatChunksOfSameContentBlock(bs) + var b *ContentBlock + b, err = concatChunksOfSameContentBlock(bs) if err != nil { return nil, err } @@ -773,10 +780,18 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { } } + if len(extraList) > 0 { + extra, err = concatExtra(extraList) + if err != nil { + return nil, err + } + } + return &AgenticMessage{ Role: role, ResponseMeta: meta, ContentBlocks: blocks, + Extra: extra, }, nil } From 1fd8814a53f21076e831f5c16f8922d7f7250288 Mon Sep 17 00:00:00 2001 From: Megumin Date: Thu, 8 Jan 2026 19:36:17 +0800 Subject: [PATCH 19/65] feat(schema): optimize agent message format (#671) --- schema/agentic_message.go | 28 +++- schema/agentic_message_test.go | 270 +++++++++++++++++---------------- 2 files changed, 164 insertions(+), 134 deletions(-) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 2ba02b689..5008f0c75 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -18,6 +18,7 @@ package schema import ( "context" + "encoding/json" "fmt" "reflect" "sort" @@ -1834,7 +1835,7 @@ func (u *UserInputText) String() string { // String returns the string representation of UserInputImage. func (u *UserInputImage) String() string { - return formatMediaString(u.URL, u.Base64Data, u.MIMEType, u.Detail) + return formatMediaString(u.URL, u.Base64Data, u.MIMEType, string(u.Detail)) } // String returns the string representation of UserInputAudio. @@ -1902,7 +1903,7 @@ func (s *ServerToolCall) String() string { if s.CallID != "" { sb.WriteString(fmt.Sprintf(" call_id: %s\n", s.CallID)) } - sb.WriteString(fmt.Sprintf(" arguments: %v\n", s.Arguments)) + sb.WriteString(fmt.Sprintf(" arguments: %s\n", printAny(s.Arguments))) return sb.String() } @@ -1913,7 +1914,7 @@ func (s *ServerToolResult) String() string { if s.CallID != "" { sb.WriteString(fmt.Sprintf(" call_id: %s\n", s.CallID)) } - sb.WriteString(fmt.Sprintf(" result: %v\n", s.Result)) + sb.WriteString(fmt.Sprintf(" result: %s\n", printAny(s.Result))) return sb.String() } @@ -1996,7 +1997,7 @@ func truncateString(s string, maxLen int) string { } // formatMediaString formats URL, Base64Data, MIMEType and Detail for media content -func formatMediaString(url, base64Data string, mimeType string, detail any) string { +func formatMediaString(url, base64Data string, mimeType string, detail string) string { sb := &strings.Builder{} if url != "" { sb.WriteString(fmt.Sprintf(" url: %s\n", truncateString(url, 100))) @@ -2008,8 +2009,8 @@ func formatMediaString(url, base64Data string, mimeType string, detail any) stri if mimeType != "" { sb.WriteString(fmt.Sprintf(" mime_type: %s\n", mimeType)) } - if detail != nil && detail != "" { - sb.WriteString(fmt.Sprintf(" detail: %v\n", detail)) + if detail != "" { + sb.WriteString(fmt.Sprintf(" detail: %s\n", detail)) } return sb.String() } @@ -2027,3 +2028,18 @@ func validateExtensionType(expected reflect.Type, actual any) (reflect.Type, boo } return expected, true } + +func printAny(a any) string { + switch v := a.(type) { + case string: + return v + case fmt.Stringer: + return v.String() + default: + b, err := json.MarshalIndent(a, "", " ") + if err != nil { + return fmt.Sprintf("%v", a) + } + return string(b) + } +} diff --git a/schema/agentic_message_test.go b/schema/agentic_message_test.go index 016aa5c4e..4beb74930 100644 --- a/schema/agentic_message_test.go +++ b/schema/agentic_message_test.go @@ -1234,6 +1234,61 @@ func TestAgenticMessageString(t *testing.T) { Detail: ImageURLDetailHigh, }, }, + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: "http://audio.com", + Base64Data: "audio_data", + MIMEType: "audio/mp3", + }, + }, + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: "http://video.com", + Base64Data: "video_data", + MIMEType: "video/mp4", + }, + }, + { + Type: ContentBlockTypeUserInputFile, + UserInputFile: &UserInputFile{ + URL: "http://file.com", + Name: "file.txt", + Base64Data: "file_data", + MIMEType: "text/plain", + }, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "I'll check the current weather in New York City for you.", + }, + }, + { + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + URL: "http://gen_image.com", + Base64Data: "gen_image_data", + MIMEType: "image/png", + }, + }, + { + Type: ContentBlockTypeAssistantGenAudio, + AssistantGenAudio: &AssistantGenAudio{ + URL: "http://gen_audio.com", + Base64Data: "gen_audio_data", + MIMEType: "audio/wav", + }, + }, + { + Type: ContentBlockTypeAssistantGenVideo, + AssistantGenVideo: &AssistantGenVideo{ + URL: "http://gen_video.com", + Base64Data: "gen_video_data", + MIMEType: "video/mp4", + }, + }, { Type: ContentBlockTypeReasoning, Reasoning: &Reasoning{ @@ -1245,12 +1300,6 @@ func TestAgenticMessageString(t *testing.T) { EncryptedContent: "encrypted_reasoning_content_that_is_very_long_and_will_be_truncated_for_display", }, }, - { - Type: ContentBlockTypeAssistantGenText, - AssistantGenText: &AssistantGenText{ - Text: "I'll check the current weather in New York City for you.", - }, - }, { Type: ContentBlockTypeFunctionToolCall, FunctionToolCall: &FunctionToolCall{ @@ -1268,6 +1317,39 @@ func TestAgenticMessageString(t *testing.T) { Result: `{"temperature":72,"condition":"sunny","humidity":45,"wind_speed":8}`, }, }, + { + Type: ContentBlockTypeServerToolCall, + ServerToolCall: &ServerToolCall{ + Name: "server_tool", + CallID: "call_1", + Arguments: map[string]any{"a": 1}, + }, + }, + { + Type: ContentBlockTypeServerToolResult, + ServerToolResult: &ServerToolResult{ + Name: "server_tool", + CallID: "call_1", + Result: map[string]any{"success": true}, + }, + }, + { + Type: ContentBlockTypeMCPToolApprovalRequest, + MCPToolApprovalRequest: &MCPToolApprovalRequest{ + ID: "req_1", + Name: "mcp_tool", + ServerLabel: "mcp_server", + Arguments: "{}", + }, + }, + { + Type: ContentBlockTypeMCPToolApprovalResponse, + MCPToolApprovalResponse: &MCPToolApprovalResponse{ + ApprovalRequestID: "req_1", + Approve: true, + Reason: "looks good", + }, + }, { Type: ContentBlockTypeMCPToolCall, MCPToolCall: &MCPToolCall{ @@ -1322,34 +1404,80 @@ content_blocks: base64_data: iVBORw0KGgoAAAANSUhE...... (96 bytes) mime_type: image/jpeg detail: high - [2] type: reasoning + [2] type: user_input_audio + url: http://audio.com + base64_data: audio_data... (10 bytes) + mime_type: audio/mp3 + [3] type: user_input_video + url: http://video.com + base64_data: video_data... (10 bytes) + mime_type: video/mp4 + [4] type: user_input_file + name: file.txt + url: http://file.com + base64_data: file_data... (9 bytes) + mime_type: text/plain + [5] type: assistant_gen_text + text: I'll check the current weather in New York City for you. + [6] type: assistant_gen_image + url: http://gen_image.com + base64_data: gen_image_data... (14 bytes) + mime_type: image/png + [7] type: assistant_gen_audio + url: http://gen_audio.com + base64_data: gen_audio_data... (14 bytes) + mime_type: audio/wav + [8] type: assistant_gen_video + url: http://gen_video.com + base64_data: gen_video_data... (14 bytes) + mime_type: video/mp4 + [9] type: reasoning summary: 3 items [0] First, I need to identify the location (New York City) from the user's query. [1] Then, I should call the weather API to get current conditions. [2] Finally, I'll format the response in a user-friendly way with temperature and conditions. encrypted_content: encrypted_reasoning_content_that_is_very_long_and_... - [3] type: assistant_gen_text - text: I'll check the current weather in New York City for you. - [4] type: function_tool_call + [10] type: function_tool_call call_id: call_weather_123 name: get_current_weather arguments: {"location":"New York City","unit":"fahrenheit"} stream_index: 0 - [5] type: function_tool_result + [11] type: function_tool_result call_id: call_weather_123 name: get_current_weather result: {"temperature":72,"condition":"sunny","humidity":45,"wind_speed":8} - [6] type: mcp_tool_call + [12] type: server_tool_call + name: server_tool + call_id: call_1 + arguments: { + "a": 1 +} + [13] type: server_tool_result + name: server_tool + call_id: call_1 + result: { + "success": true +} + [14] type: mcp_tool_approval_request + server_label: mcp_server + id: req_1 + name: mcp_tool + arguments: {} + [15] type: mcp_tool_approval_response + approval_request_id: req_1 + approve: true + reason: looks good + [16] type: mcp_tool_call server_label: weather-mcp-server call_id: mcp_forecast_456 name: get_7day_forecast arguments: {"city":"New York","days":7} - [7] type: mcp_tool_result + [17] type: mcp_tool_result call_id: mcp_forecast_456 name: get_7day_forecast result: {"status":"partial","days_available":3} error: [503] Service temporarily unavailable for full 7-day forecast - [8] type: mcp_list_tools_result + [18] type: mcp_list_tools_result server_label: weather-mcp-server tools: 3 items - get_current_weather: Get current weather conditions for a location @@ -1359,120 +1487,6 @@ response_meta: token_usage: prompt=250, completion=180, total=430 `, output) - t.Run("full fields", func(t *testing.T) { - msg := &AgenticMessage{ - Role: AgenticRoleTypeSystem, - ContentBlocks: []*ContentBlock{ - { - Type: ContentBlockTypeUserInputAudio, - UserInputAudio: &UserInputAudio{ - URL: "http://audio.com", - Base64Data: "audio_data", - MIMEType: "audio/mp3", - }, - }, - { - Type: ContentBlockTypeUserInputVideo, - UserInputVideo: &UserInputVideo{ - URL: "http://video.com", - Base64Data: "video_data", - MIMEType: "video/mp4", - }, - }, - { - Type: ContentBlockTypeUserInputFile, - UserInputFile: &UserInputFile{ - URL: "http://file.com", - Name: "file.txt", - Base64Data: "file_data", - MIMEType: "text/plain", - }, - }, - { - Type: ContentBlockTypeAssistantGenImage, - AssistantGenImage: &AssistantGenImage{ - URL: "http://gen_image.com", - Base64Data: "gen_image_data", - MIMEType: "image/png", - }, - }, - { - Type: ContentBlockTypeAssistantGenAudio, - AssistantGenAudio: &AssistantGenAudio{ - URL: "http://gen_audio.com", - Base64Data: "gen_audio_data", - MIMEType: "audio/wav", - }, - }, - { - Type: ContentBlockTypeAssistantGenVideo, - AssistantGenVideo: &AssistantGenVideo{ - URL: "http://gen_video.com", - Base64Data: "gen_video_data", - MIMEType: "video/mp4", - }, - }, - { - Type: ContentBlockTypeServerToolCall, - ServerToolCall: &ServerToolCall{ - Name: "server_tool", - CallID: "call_1", - Arguments: map[string]any{"a": 1}, - }, - }, - { - Type: ContentBlockTypeServerToolResult, - ServerToolResult: &ServerToolResult{ - Name: "server_tool", - CallID: "call_1", - Result: map[string]any{"success": true}, - }, - }, - { - Type: ContentBlockTypeMCPToolApprovalRequest, - MCPToolApprovalRequest: &MCPToolApprovalRequest{ - ID: "req_1", - Name: "mcp_tool", - ServerLabel: "mcp_server", - Arguments: "{}", - }, - }, - { - Type: ContentBlockTypeMCPToolApprovalResponse, - MCPToolApprovalResponse: &MCPToolApprovalResponse{ - ApprovalRequestID: "req_1", - Approve: true, - Reason: "looks good", - }, - }, - }, - } - - s := msg.String() - assert.Contains(t, s, "role: system") - assert.Contains(t, s, "type: user_input_audio") - assert.Contains(t, s, "http://audio.com") - assert.Contains(t, s, "type: user_input_video") - assert.Contains(t, s, "http://video.com") - assert.Contains(t, s, "type: user_input_file") - assert.Contains(t, s, "file.txt") - assert.Contains(t, s, "type: assistant_gen_image") - assert.Contains(t, s, "http://gen_image.com") - assert.Contains(t, s, "type: assistant_gen_audio") - assert.Contains(t, s, "http://gen_audio.com") - assert.Contains(t, s, "type: assistant_gen_video") - assert.Contains(t, s, "http://gen_video.com") - assert.Contains(t, s, "type: server_tool_call") - assert.Contains(t, s, "server_tool") - assert.Contains(t, s, "map[a:1]") - assert.Contains(t, s, "type: server_tool_result") - assert.Contains(t, s, "map[success:true]") - assert.Contains(t, s, "type: mcp_tool_approval_request") - assert.Contains(t, s, "req_1") - assert.Contains(t, s, "type: mcp_tool_approval_response") - assert.Contains(t, s, "looks good") - }) - t.Run("nil/empty fields", func(t *testing.T) { msg := &AgenticMessage{ Role: AgenticRoleTypeUser, From 86a7b6e3638b4bb1344cdccf8627573f3d31f822 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 12 Jan 2026 12:08:31 +0800 Subject: [PATCH 20/65] fix: openai ConcatResponseMetaExtensions (#678) --- schema/openai/extension.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/schema/openai/extension.go b/schema/openai/extension.go index c30d2d8ec..1e10c411e 100644 --- a/schema/openai/extension.go +++ b/schema/openai/extension.go @@ -200,6 +200,12 @@ func ConcatResponseMetaExtensions(chunks []*ResponseMetaExtension) (*ResponseMet if ext.ServiceTier != "" { ret.ServiceTier = ext.ServiceTier } + if ext.CreatedAt != 0 { + ret.CreatedAt = ext.CreatedAt + } + if ext.PromptCacheRetention != "" { + ret.PromptCacheRetention = ext.PromptCacheRetention + } } return ret, nil From 28ed2b9ca49cf8382140486914b7650a5699259a Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 12 Jan 2026 18:04:30 +0800 Subject: [PATCH 21/65] feat: improve comment (#679) --- components/agentic/interface.go | 3 + schema/agentic_message.go | 164 +++++++++++++++++++++++++------- schema/tool.go | 1 + 3 files changed, 131 insertions(+), 37 deletions(-) diff --git a/components/agentic/interface.go b/components/agentic/interface.go index e9960d332..e62a8eeab 100644 --- a/components/agentic/interface.go +++ b/components/agentic/interface.go @@ -25,5 +25,8 @@ import ( type Model interface { Generate(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.AgenticMessage, error) Stream(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.StreamReader[*schema.AgenticMessage], error) + + // WithTools returns a new Model instance with the specified tools bound. + // This method does not modify the current instance, making it safer for concurrent use. WithTools(tools []*schema.ToolInfo) (Model, error) } diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 5008f0c75..63c78c3eb 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -66,122 +66,208 @@ const ( ) type AgenticMessage struct { + // ResponseMeta is the response metadata. ResponseMeta *AgenticResponseMeta - Role AgenticRoleType + // Role is the message role. + Role AgenticRoleType + + // ContentBlocks is the list of content blocks. ContentBlocks []*ContentBlock + // Extra is the additional information. Extra map[string]any } type AgenticResponseMeta struct { + // TokenUsage is the token usage. TokenUsage *TokenUsage + // OpenAIExtension is the extension for OpenAI. OpenAIExtension *openai.ResponseMetaExtension + + // GeminiExtension is the extension for Gemini. GeminiExtension *gemini.ResponseMetaExtension + + // ClaudeExtension is the extension for Claude. ClaudeExtension *claude.ResponseMetaExtension - Extension any -} -type StreamingMeta struct { - // Index specifies the index position of this block in the final response. - Index int + // Extension is the extension for other models, supplied by the component implementer. + Extension any } type ContentBlock struct { Type ContentBlockType + // Reasoning contains the reasoning content generated by the model. Reasoning *Reasoning - UserInputText *UserInputText + // UserInputText contains the text content provided by the user. + UserInputText *UserInputText + + // UserInputImage contains the image content provided by the user. UserInputImage *UserInputImage + + // UserInputAudio contains the audio content provided by the user. UserInputAudio *UserInputAudio + + // UserInputVideo contains the video content provided by the user. UserInputVideo *UserInputVideo - UserInputFile *UserInputFile - AssistantGenText *AssistantGenText + // UserInputFile contains the file content provided by the user. + UserInputFile *UserInputFile + + // AssistantGenText contains the text content generated by the model. + AssistantGenText *AssistantGenText + + // AssistantGenImage contains the image content generated by the model. AssistantGenImage *AssistantGenImage + + // AssistantGenAudio contains the audio content generated by the model. AssistantGenAudio *AssistantGenAudio + + // AssistantGenVideo contains the video content generated by the model. AssistantGenVideo *AssistantGenVideo - // FunctionToolCall holds invocation details for a user-defined tool. + // FunctionToolCall contains the invocation details for a user-defined tool. FunctionToolCall *FunctionToolCall - // FunctionToolResult is the result from a user-defined tool call. + + // FunctionToolResult contains the result returned from a user-defined tool call. FunctionToolResult *FunctionToolResult - // ServerToolCall holds invocation details for a provider built-in tool run on the model server. + + // ServerToolCall contains the invocation details for a provider built-in tool executed on the model server. ServerToolCall *ServerToolCall - // ServerToolResult is the result from a provider built-in tool run on the model server. + + // ServerToolResult contains the result returned from a provider built-in tool executed on the model server. ServerToolResult *ServerToolResult - // MCPToolCall holds invocation details for an MCP tool managed by the model server. + // MCPToolCall contains the invocation details for an MCP tool managed by the model server. MCPToolCall *MCPToolCall - // MCPToolResult is the result from an MCP tool managed by the model server. + + // MCPToolResult contains the result returned from an MCP tool managed by the model server. MCPToolResult *MCPToolResult - // MCPListToolsResult lists available MCP tools reported by the model server. + + // MCPListToolsResult contains the list of available MCP tools reported by the model server. MCPListToolsResult *MCPListToolsResult - // MCPToolApprovalRequest requests user approval for an MCP tool call when required. + + // MCPToolApprovalRequest contains the user approval request for an MCP tool call when required. MCPToolApprovalRequest *MCPToolApprovalRequest - // MCPToolApprovalResponse records the user's approval decision for an MCP tool call. + + // MCPToolApprovalResponse contains the user's approval decision for an MCP tool call. MCPToolApprovalResponse *MCPToolApprovalResponse + // StreamingMeta contains metadata for streaming responses. StreamingMeta *StreamingMeta - Extra map[string]any + + // Extra contains additional information for the content block. + Extra map[string]any +} + +type StreamingMeta struct { + // Index specifies the index position of this block in the final response. + Index int } type UserInputText struct { + // Text is the text content. Text string } type UserInputImage struct { - URL string + // URL is the HTTP/HTTPS link. + URL string + + // Base64Data is the binary data in Base64 encoded string format. Base64Data string - MIMEType string - Detail ImageURLDetail + + // MIMEType is the mime type, e.g. "image/png". + MIMEType string + + // Detail is the quality of the image url. + Detail ImageURLDetail } type UserInputAudio struct { - URL string + // URL is the HTTP/HTTPS link. + URL string + + // Base64Data is the binary data in Base64 encoded string format. Base64Data string - MIMEType string + + // MIMEType is the mime type, e.g. "audio/wav". + MIMEType string } type UserInputVideo struct { - URL string + // URL is the HTTP/HTTPS link. + URL string + + // Base64Data is the binary data in Base64 encoded string format. Base64Data string - MIMEType string + + // MIMEType is the mime type, e.g. "video/mp4". + MIMEType string } type UserInputFile struct { - URL string - Name string + // URL is the HTTP/HTTPS link. + URL string + + // Name is the filename. + Name string + + // Base64Data is the binary data in Base64 encoded string format. Base64Data string - MIMEType string + + // MIMEType is the mime type, e.g. "application/pdf". + MIMEType string } type AssistantGenText struct { + // Text is the generated text. Text string + // OpenAIExtension is the extension for OpenAI. OpenAIExtension *openai.AssistantGenTextExtension + + // ClaudeExtension is the extension for Claude. ClaudeExtension *claude.AssistantGenTextExtension - Extension any + + // Extension is the extension for other models, supplied by the component implementer. + Extension any } type AssistantGenImage struct { - URL string + // URL is the HTTP/HTTPS link. + URL string + + // Base64Data is the binary data in Base64 encoded string format. Base64Data string - MIMEType string + + // MIMEType is the mime type, e.g. "image/png". + MIMEType string } type AssistantGenAudio struct { - URL string + // URL is the HTTP/HTTPS link. + URL string + + // Base64Data is the binary data in Base64 encoded string format. Base64Data string - MIMEType string + + // MIMEType is the mime type, e.g. "audio/wav". + MIMEType string } type AssistantGenVideo struct { - URL string + // URL is the HTTP/HTTPS link. + URL string + + // Base64Data is the binary data in Base64 encoded string format. Base64Data string - MIMEType string + + // MIMEType is the mime type, e.g. "video/mp4". + MIMEType string } type Reasoning struct { @@ -196,6 +282,7 @@ type ReasoningSummary struct { // Index specifies the index position of this summary in the final Reasoning. Index int + // Text is the reasoning content summary. Text string } @@ -284,7 +371,10 @@ type MCPToolResult struct { } type MCPToolCallError struct { - Code *int64 + // Code is the error code. + Code *int64 + + // Message is the error message. Message string } diff --git a/schema/tool.go b/schema/tool.go index 2d6bf90db..efed6d34b 100644 --- a/schema/tool.go +++ b/schema/tool.go @@ -67,6 +67,7 @@ type AllowedTool struct { ServerTool *AllowedServerTool } + type AllowedMCPTool struct { // ServerLabel is the label of the MCP server. ServerLabel string From 906136eb23609b21ef7aadf94a52ff21d61773c2 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Tue, 13 Jan 2026 21:41:07 +0800 Subject: [PATCH 22/65] feat: add agentic callbacks template (#681) --- components/agentic/callback_extra.go | 85 ----- components/agentic/callback_extra_test.go | 35 -- components/agentic/interface.go | 32 -- components/agentic/option.go | 142 -------- components/agentic/option_test.go | 79 ----- components/model/callback_extra.go | 18 +- components/model/interface.go | 12 + components/model/option.go | 31 +- components/model/option_test.go | 16 + ...te_agentic.go => agentic_chat_template.go} | 16 +- .../prompt/agentic_chat_template_test.go | 124 +++++++ components/prompt/callback_extra.go | 50 +-- components/prompt/callback_extra_test.go | 21 +- .../prompt/chat_template_agentic_test.go | 111 ------- components/prompt/interface.go | 1 + ..._node_agentic.go => agentic_tools_node.go} | 0 ...tic_test.go => agentic_tools_node_test.go} | 0 compose/chain.go | 3 +- compose/chain_branch.go | 3 +- compose/chain_parallel.go | 3 +- compose/component_to_graph_node.go | 3 +- compose/graph.go | 3 +- schema/agentic_message.go | 8 +- utils/callbacks/template.go | 176 +++++++++- utils/callbacks/template_test.go | 304 +++++++++++++++++- 25 files changed, 694 insertions(+), 582 deletions(-) delete mode 100644 components/agentic/callback_extra.go delete mode 100644 components/agentic/callback_extra_test.go delete mode 100644 components/agentic/interface.go delete mode 100644 components/agentic/option.go delete mode 100644 components/agentic/option_test.go rename components/prompt/{chat_template_agentic.go => agentic_chat_template.go} (87%) create mode 100644 components/prompt/agentic_chat_template_test.go delete mode 100644 components/prompt/chat_template_agentic_test.go rename compose/{tools_node_agentic.go => agentic_tools_node.go} (100%) rename compose/{tools_node_agentic_test.go => agentic_tools_node_test.go} (100%) diff --git a/components/agentic/callback_extra.go b/components/agentic/callback_extra.go deleted file mode 100644 index 2c5a656fa..000000000 --- a/components/agentic/callback_extra.go +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Package agentic defines callback payloads and configuration types for agentic models. -package agentic - -import ( - "github.com/cloudwego/eino/callbacks" - "github.com/cloudwego/eino/schema" -) - -// Config is the config for the model. -type Config struct { - // Model is the model name. - Model string - // Temperature is the temperature, which controls the randomness of the model. - Temperature float64 - // TopP is the top p, which controls the diversity of the model. - TopP float64 -} - -// CallbackInput is the input for the model callback. -type CallbackInput struct { - // Messages is the messages to be sent to the model. - Messages []*schema.AgenticMessage - // Tools is the tools to be used in the model. - Tools []*schema.ToolInfo - // ToolChoice controls which tool is called by the model. - ToolChoice *schema.ToolChoice - // Config is the config for the model. - Config *Config - // Extra is the extra information for the callback. - Extra map[string]any -} - -// CallbackOutput is the output for the model callback. -type CallbackOutput struct { - // Message is the message generated by the model. - Message *schema.AgenticMessage - // Config is the config for the model. - Config *Config - // Extra is the extra information for the callback. - Extra map[string]any -} - -// ConvCallbackInput converts the callback input to the model callback input. -func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { - switch t := src.(type) { - case *CallbackInput: // when callback is triggered within component implementation, the input is usually already a typed *model.CallbackInput - return t - case []*schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the input is the input of Chat Model interface, which is []*schema.AgenticMessage - return &CallbackInput{ - Messages: t, - } - default: - return nil - } -} - -// ConvCallbackOutput converts the callback output to the model callback output. -func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { - switch t := src.(type) { - case *CallbackOutput: // when callback is triggered within component implementation, the output is usually already a typed *model.CallbackOutput - return t - case *schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the output is the output of Chat Model interface, which is *schema.AgenticMessage - return &CallbackOutput{ - Message: t, - } - default: - return nil - } -} diff --git a/components/agentic/callback_extra_test.go b/components/agentic/callback_extra_test.go deleted file mode 100644 index a77da6cd2..000000000 --- a/components/agentic/callback_extra_test.go +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package agentic - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/cloudwego/eino/schema" -) - -func TestConvModel(t *testing.T) { - assert.NotNil(t, ConvCallbackInput(&CallbackInput{})) - assert.NotNil(t, ConvCallbackInput([]*schema.AgenticMessage{})) - assert.Nil(t, ConvCallbackInput("asd")) - - assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{})) - assert.NotNil(t, ConvCallbackOutput(&schema.AgenticMessage{})) - assert.Nil(t, ConvCallbackOutput("asd")) -} diff --git a/components/agentic/interface.go b/components/agentic/interface.go deleted file mode 100644 index e62a8eeab..000000000 --- a/components/agentic/interface.go +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package agentic - -import ( - "context" - - "github.com/cloudwego/eino/schema" -) - -type Model interface { - Generate(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.AgenticMessage, error) - Stream(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.StreamReader[*schema.AgenticMessage], error) - - // WithTools returns a new Model instance with the specified tools bound. - // This method does not modify the current instance, making it safer for concurrent use. - WithTools(tools []*schema.ToolInfo) (Model, error) -} diff --git a/components/agentic/option.go b/components/agentic/option.go deleted file mode 100644 index d8873442a..000000000 --- a/components/agentic/option.go +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package agentic - -import ( - "github.com/cloudwego/eino/schema" -) - -// Options is the common options for the model. -type Options struct { - // Temperature is the temperature for the model, which controls the randomness of the model. - Temperature *float64 - // Model is the model name. - Model *string - // TopP is the top p for the model, which controls the diversity of the model. - TopP *float64 - // Tools is a list of tools the model may call. - Tools []*schema.ToolInfo - // ToolChoice controls how the model call the tools. - ToolChoice *schema.ToolChoice - // AllowedTools is a list of allowed tools the model may call. - AllowedTools []*schema.AllowedTool -} - -// Option is the call option for ChatModel component. -type Option struct { - apply func(opts *Options) - - implSpecificOptFn any -} - -// WithTemperature is the option to set the temperature for the model. -func WithTemperature(temperature float64) Option { - return Option{ - apply: func(opts *Options) { - opts.Temperature = &temperature - }, - } -} - -// WithModel is the option to set the model name. -func WithModel(name string) Option { - return Option{ - apply: func(opts *Options) { - opts.Model = &name - }, - } -} - -// WithTopP is the option to set the top p for the model. -func WithTopP(topP float64) Option { - return Option{ - apply: func(opts *Options) { - opts.TopP = &topP - }, - } -} - -// WithTools is the option to set tools for the model. -func WithTools(tools []*schema.ToolInfo) Option { - if tools == nil { - tools = []*schema.ToolInfo{} - } - return Option{ - apply: func(opts *Options) { - opts.Tools = tools - }, - } -} - -// WithToolChoice is the option to set tool choice for the model. -func WithToolChoice(toolChoice schema.ToolChoice, allowedTools ...*schema.AllowedTool) Option { - return Option{ - apply: func(opts *Options) { - opts.ToolChoice = &toolChoice - opts.AllowedTools = allowedTools - }, - } -} - -// WrapImplSpecificOptFn is the option to wrap the implementation specific option function. -func WrapImplSpecificOptFn[T any](optFn func(*T)) Option { - return Option{ - implSpecificOptFn: optFn, - } -} - -// GetCommonOptions extract model Options from Option list, optionally providing a base Options with default values. -func GetCommonOptions(base *Options, opts ...Option) *Options { - if base == nil { - base = &Options{} - } - - for i := range opts { - opt := opts[i] - if opt.apply != nil { - opt.apply(base) - } - } - - return base -} - -// GetImplSpecificOptions extract the implementation specific options from Option list, optionally providing a base options with default values. -// e.g. -// -// myOption := &MyOption{ -// Field1: "default_value", -// } -// -// myOption := model.GetImplSpecificOptions(myOption, opts...) -func GetImplSpecificOptions[T any](base *T, opts ...Option) *T { - if base == nil { - base = new(T) - } - - for i := range opts { - opt := opts[i] - if opt.implSpecificOptFn != nil { - optFn, ok := opt.implSpecificOptFn.(func(*T)) - if ok { - optFn(base) - } - } - } - - return base -} diff --git a/components/agentic/option_test.go b/components/agentic/option_test.go deleted file mode 100644 index 2c5bac652..000000000 --- a/components/agentic/option_test.go +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package agentic - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/cloudwego/eino/schema" -) - -func TestCommon(t *testing.T) { - o := GetCommonOptions(nil, - WithTools([]*schema.ToolInfo{{Name: "test"}}), - WithModel("test"), - WithTemperature(0.1), - WithToolChoice(schema.ToolChoiceAllowed, []*schema.AllowedTool{{FunctionToolName: "test"}}...), - WithTopP(0.1), - ) - assert.Len(t, o.Tools, 1) - assert.Equal(t, "test", o.Tools[0].Name) - assert.Equal(t, "test", *o.Model) - assert.Equal(t, float64(0.1), *o.Temperature) - assert.Equal(t, schema.ToolChoiceAllowed, *o.ToolChoice) - assert.Equal(t, float64(0.1), *o.TopP) -} - -func TestImplSpecificOpts(t *testing.T) { - type implSpecificOptions struct { - conf string - index int - } - - withConf := func(conf string) func(o *implSpecificOptions) { - return func(o *implSpecificOptions) { - o.conf = conf - } - } - - withIndex := func(index int) func(o *implSpecificOptions) { - return func(o *implSpecificOptions) { - o.index = index - } - } - - documentOption1 := WrapImplSpecificOptFn(withConf("test_conf")) - documentOption2 := WrapImplSpecificOptFn(withIndex(1)) - - implSpecificOpts := GetImplSpecificOptions(&implSpecificOptions{}, documentOption1, documentOption2) - - assert.Equal(t, &implSpecificOptions{ - conf: "test_conf", - index: 1, - }, implSpecificOpts) - documentOption1 = WrapImplSpecificOptFn(withConf("test_conf")) - documentOption2 = WrapImplSpecificOptFn(withIndex(1)) - - implSpecificOpts = GetImplSpecificOptions(&implSpecificOptions{}, documentOption1, documentOption2) - - assert.Equal(t, &implSpecificOptions{ - conf: "test_conf", - index: 1, - }, implSpecificOpts) -} diff --git a/components/model/callback_extra.go b/components/model/callback_extra.go index afff3f0a7..ed9096d5c 100644 --- a/components/model/callback_extra.go +++ b/components/model/callback_extra.go @@ -31,15 +31,15 @@ type TokenUsage struct { CompletionTokens int // TotalTokens is the total number of tokens. TotalTokens int - // CompletionTokensDetails is a breakdown of the completion tokens. - CompletionTokensDetails CompletionTokensDetails + // CompletionTokensDetails is breakdown of completion tokens. + CompletionTokensDetails CompletionTokensDetails `json:"completion_token_details"` } type CompletionTokensDetails struct { // ReasoningTokens tokens generated by the model for reasoning. // This is currently supported by OpenAI, Gemini, ARK and Qwen chat models. // For other models, this field will be 0. - ReasoningTokens int + ReasoningTokens int `json:"reasoning_tokens,omitempty"` } // PromptTokenDetails provides a breakdown of prompt token usage. @@ -66,6 +66,8 @@ type Config struct { type CallbackInput struct { // Messages is the messages to be sent to the model. Messages []*schema.Message + // AgenticMessages is the agentic messages to be sent to the agentic model. + AgenticMessages []*schema.AgenticMessage // Tools is the tools to be used in the model. Tools []*schema.ToolInfo // ToolChoice is the tool choice, which controls the tool to be used in the model. @@ -80,6 +82,8 @@ type CallbackInput struct { type CallbackOutput struct { // Message is the message generated by the model. Message *schema.Message + // AgenticMessage is the agentic message generated by the agentic model. + AgenticMessage *schema.AgenticMessage // Config is the config for the model. Config *Config // TokenUsage is the token usage of this request. @@ -97,6 +101,10 @@ func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { return &CallbackInput{ Messages: t, } + case []*schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the input is the input of Agentic Model interface, which is []*schema.AgenticMessage + return &CallbackInput{ + AgenticMessages: t, + } default: return nil } @@ -111,6 +119,10 @@ func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { return &CallbackOutput{ Message: t, } + case *schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the output is the output of Agentic Model interface, which is *schema.AgenticMessage + return &CallbackOutput{ + AgenticMessage: t, + } default: return nil } diff --git a/components/model/interface.go b/components/model/interface.go index deb7b56dd..cf79785bc 100644 --- a/components/model/interface.go +++ b/components/model/interface.go @@ -89,3 +89,15 @@ type ToolCallingChatModel interface { // This method does not modify the current instance, making it safer for concurrent use. WithTools(tools []*schema.ToolInfo) (ToolCallingChatModel, error) } + +// AgenticModel defines the interface for agentic models that support AgenticMessage. +// It provides methods for generating complete and streaming outputs, and supports +// tool calling via the WithTools method. +type AgenticModel interface { + Generate(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.AgenticMessage, error) + Stream(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.StreamReader[*schema.AgenticMessage], error) + + // WithTools returns a new Model instance with the specified tools bound. + // This method does not modify the current instance, making it safer for concurrent use. + WithTools(tools []*schema.ToolInfo) (AgenticModel, error) +} diff --git a/components/model/option.go b/components/model/option.go index 9fd96116c..0173d22aa 100644 --- a/components/model/option.go +++ b/components/model/option.go @@ -22,21 +22,29 @@ import "github.com/cloudwego/eino/schema" type Options struct { // Temperature is the temperature for the model, which controls the randomness of the model. Temperature *float32 - // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return an finish reason of "length". - MaxTokens *int // Model is the model name. Model *string // TopP is the top p for the model, which controls the diversity of the model. TopP *float32 - // Stop is the stop words for the model, which controls the stopping condition of the model. - Stop []string // Tools is a list of tools the model may call. Tools []*schema.ToolInfo // ToolChoice controls which tool is called by the model. ToolChoice *schema.ToolChoice + + // Options only for chat model. + + // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return an finish reason of "length". + MaxTokens *int // AllowedToolNames specifies a list of tool names that the model is allowed to call. // This allows for constraining the model to a specific subset of the available tools. AllowedToolNames []string + // Stop is the stop words for the model, which controls the stopping condition of the model. + Stop []string + + // Options only for agentic model. + + // AllowedTools is a list of allowed tools the model may call. + AllowedTools []*schema.AllowedTool } // Option is a call-time option for a ChatModel. Options are immutable and @@ -59,6 +67,7 @@ func WithTemperature(temperature float32) Option { } // WithMaxTokens is the option to set the max tokens for the model. +// Only available for ChatModel. func WithMaxTokens(maxTokens int) Option { return Option{ apply: func(opts *Options) { @@ -86,6 +95,7 @@ func WithTopP(topP float32) Option { } // WithStop is the option to set the stop words for the model. +// Only available for ChatModel. func WithStop(stop []string) Option { return Option{ apply: func(opts *Options) { @@ -108,6 +118,7 @@ func WithTools(tools []*schema.ToolInfo) Option { // WithToolChoice sets the tool choice for the model. It also allows for providing a list of // tool names to constrain the model to a specific subset of the available tools. +// Only available for ChatModel. func WithToolChoice(toolChoice schema.ToolChoice, allowedToolNames ...string) Option { return Option{ apply: func(opts *Options) { @@ -117,6 +128,18 @@ func WithToolChoice(toolChoice schema.ToolChoice, allowedToolNames ...string) Op } } +// WithAgenticToolChoice is the option to set tool choice for the agentic model. +// Only available for AgenticModel. +func WithAgenticToolChoice(toolChoice schema.ToolChoice, allowedTools ...*schema.AllowedTool) Option { + return Option{ + apply: func(opts *Options) { + opts.ToolChoice = &toolChoice + opts.AllowedTools = allowedTools + }, + } +} + +// WrapImplSpecificOptFn is the option to wrap the implementation specific option function. // WrapImplSpecificOptFn wraps an implementation-specific option function into // an [Option] so it can be passed alongside standard options. // diff --git a/components/model/option_test.go b/components/model/option_test.go index 36872c30e..bfacdd17c 100644 --- a/components/model/option_test.go +++ b/components/model/option_test.go @@ -82,6 +82,22 @@ func TestOptions(t *testing.T) { convey.So(opts.Tools, convey.ShouldNotBeNil) convey.So(len(opts.Tools), convey.ShouldEqual, 0) }) + + convey.Convey("test agentic tool choice option", t, func() { + var ( + toolChoice = schema.ToolChoiceForced + allowedTools = []*schema.AllowedTool{ + {FunctionToolName: "agentic_tool"}, + } + ) + opts := GetCommonOptions( + nil, + WithAgenticToolChoice(toolChoice, allowedTools...), + ) + + convey.So(opts.ToolChoice, convey.ShouldResemble, &toolChoice) + convey.So(opts.AllowedTools, convey.ShouldResemble, allowedTools) + }) } type implOption struct { diff --git a/components/prompt/chat_template_agentic.go b/components/prompt/agentic_chat_template.go similarity index 87% rename from components/prompt/chat_template_agentic.go rename to components/prompt/agentic_chat_template.go index 937d46f26..512a60ecd 100644 --- a/components/prompt/chat_template_agentic.go +++ b/components/prompt/agentic_chat_template.go @@ -1,5 +1,5 @@ /* - * Copyright 2025 CloudWeGo Authors + * Copyright 2026 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -45,9 +45,9 @@ type DefaultAgenticChatTemplate struct { func (t *DefaultAgenticChatTemplate) Format(ctx context.Context, vs map[string]any, opts ...Option) (result []*schema.AgenticMessage, err error) { ctx = callbacks.EnsureRunInfo(ctx, t.GetType(), components.ComponentOfAgenticPrompt) - ctx = callbacks.OnStart(ctx, &AgenticCallbackInput{ - Variables: vs, - Templates: t.templates, + ctx = callbacks.OnStart(ctx, &CallbackInput{ + Variables: vs, + AgenticTemplates: t.templates, }) defer func() { if err != nil { @@ -65,15 +65,15 @@ func (t *DefaultAgenticChatTemplate) Format(ctx context.Context, vs map[string]a result = append(result, msgs...) } - _ = callbacks.OnEnd(ctx, &AgenticCallbackOutput{ - Result: result, - Templates: t.templates, + _ = callbacks.OnEnd(ctx, &CallbackOutput{ + AgenticResult: result, + AgenticTemplates: t.templates, }) return result, nil } -// GetType returns the type of the chat template (Default). +// GetType returns the type of the agentic template (DefaultAgentic). func (t *DefaultAgenticChatTemplate) GetType() string { return "Default" } diff --git a/components/prompt/agentic_chat_template_test.go b/components/prompt/agentic_chat_template_test.go new file mode 100644 index 000000000..42d7a8630 --- /dev/null +++ b/components/prompt/agentic_chat_template_test.go @@ -0,0 +1,124 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "context" + "errors" + "testing" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" + "github.com/stretchr/testify/assert" +) + +type mockAgenticTemplate struct { + err error +} + +func (m *mockAgenticTemplate) Format(ctx context.Context, vs map[string]any, formatType schema.FormatType) ([]*schema.AgenticMessage, error) { + if m.err != nil { + return nil, m.err + } + return []*schema.AgenticMessage{schema.UserAgenticMessage("mocked")}, nil +} + +func TestFromAgenticMessages(t *testing.T) { + t.Run("create template", func(t *testing.T) { + tpl := schema.UserAgenticMessage("hello") + ft := schema.FString + at := FromAgenticMessages(ft, tpl) + + assert.NotNil(t, at) + assert.Equal(t, ft, at.formatType) + assert.Len(t, at.templates, 1) + assert.Same(t, tpl, at.templates[0]) + }) +} + +func TestDefaultAgenticTemplate_GetType(t *testing.T) { + t.Run("get type", func(t *testing.T) { + at := &DefaultAgenticChatTemplate{} + assert.Equal(t, "Default", at.GetType()) + }) +} + +func TestDefaultAgenticTemplate_IsCallbacksEnabled(t *testing.T) { + t.Run("callbacks enabled", func(t *testing.T) { + at := &DefaultAgenticChatTemplate{} + assert.True(t, at.IsCallbacksEnabled()) + }) +} + +func TestDefaultAgenticTemplate_Format(t *testing.T) { + t.Run("success", func(t *testing.T) { + // Mock callback handler + cb := callbacks.NewHandlerBuilder(). + OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { + assert.Equal(t, "Default", info.Type) + return ctx + }). + OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { + assert.Equal(t, "Default", info.Type) + return ctx + }). + OnErrorFn(func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + assert.Fail(t, "unexpected error callback") + return ctx + }). + Build() + + tpl := schema.UserAgenticMessage("hello {val}") + at := FromAgenticMessages(schema.FString, tpl) + + ctx := context.Background() + ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{ + Type: "Default", + Component: "agentic_prompt", + }, cb) + + res, err := at.Format(ctx, map[string]any{"val": "world"}) + assert.NoError(t, err) + assert.Len(t, res, 1) + assert.Equal(t, "hello world", res[0].ContentBlocks[0].UserInputText.Text) + }) + + t.Run("template format error", func(t *testing.T) { + mockErr := errors.New("mock error") + mockTpl := &mockAgenticTemplate{err: mockErr} + at := FromAgenticMessages(schema.FString, mockTpl) + + // Mock callback handler to verify OnError + cb := callbacks.NewHandlerBuilder(). + OnErrorFn(func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + assert.Equal(t, mockErr, err) + return ctx + }). + Build() + + ctx := context.Background() + ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{ + Type: "Default", + Component: "agentic_prompt", + }, cb) + + res, err := at.Format(ctx, map[string]any{}) + assert.Error(t, err) + assert.Nil(t, res) + assert.Equal(t, mockErr, err) + }) +} diff --git a/components/prompt/callback_extra.go b/components/prompt/callback_extra.go index 3be780543..4c27f37c6 100644 --- a/components/prompt/callback_extra.go +++ b/components/prompt/callback_extra.go @@ -21,52 +21,14 @@ import ( "github.com/cloudwego/eino/schema" ) -type AgenticCallbackInput struct { - Variables map[string]any - Templates []schema.AgenticMessagesTemplate - Extra map[string]any -} - -type AgenticCallbackOutput struct { - Result []*schema.AgenticMessage - Templates []schema.AgenticMessagesTemplate - Extra map[string]any -} - -// ConvAgenticCallbackInput converts the callback input to the agentic callback input. -func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput { - switch t := src.(type) { - case *AgenticCallbackInput: - return t - case map[string]any: - return &AgenticCallbackInput{ - Variables: t, - } - default: - return nil - } -} - -// ConvAgenticCallbackOutput converts the callback output to the agentic callback output. -func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOutput { - switch t := src.(type) { - case *AgenticCallbackOutput: - return t - case []*schema.AgenticMessage: - return &AgenticCallbackOutput{ - Result: t, - } - default: - return nil - } -} - // CallbackInput is the input for the callback. type CallbackInput struct { // Variables is the variables for the callback. Variables map[string]any // Templates is the templates for the callback. Templates []schema.MessagesTemplate + // AgenticTemplates is the agentic templates for the callback. + AgenticTemplates []schema.AgenticMessagesTemplate // Extra is the extra information for the callback. Extra map[string]any } @@ -75,8 +37,12 @@ type CallbackInput struct { type CallbackOutput struct { // Result is the result for the callback. Result []*schema.Message + // AgenticResult is the agentic result for the callback. + AgenticResult []*schema.AgenticMessage // Templates is the templates for the callback. Templates []schema.MessagesTemplate + // AgenticTemplates is the agentic templates for the callback. + AgenticTemplates []schema.AgenticMessagesTemplate // Extra is the extra information for the callback. Extra map[string]any } @@ -104,6 +70,10 @@ func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { return &CallbackOutput{ Result: t, } + case []*schema.AgenticMessage: + return &CallbackOutput{ + AgenticResult: t, + } default: return nil } diff --git a/components/prompt/callback_extra_test.go b/components/prompt/callback_extra_test.go index 456297e29..4b48ec114 100644 --- a/components/prompt/callback_extra_test.go +++ b/components/prompt/callback_extra_test.go @@ -25,11 +25,28 @@ import ( ) func TestConvPrompt(t *testing.T) { - assert.NotNil(t, ConvCallbackInput(&CallbackInput{})) + assert.NotNil(t, ConvCallbackInput(&CallbackInput{ + AgenticTemplates: []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{}, + }, + })) assert.NotNil(t, ConvCallbackInput(map[string]any{})) assert.Nil(t, ConvCallbackInput("asd")) - assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{})) + assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{ + AgenticResult: []*schema.AgenticMessage{ + {}, + }, + AgenticTemplates: []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{}, + }, + })) assert.NotNil(t, ConvCallbackOutput([]*schema.Message{})) + + agenticResult := []*schema.AgenticMessage{{}} + out := ConvCallbackOutput(agenticResult) + assert.NotNil(t, out) + assert.Equal(t, agenticResult, out.AgenticResult) + assert.Nil(t, ConvCallbackOutput("asd")) } diff --git a/components/prompt/chat_template_agentic_test.go b/components/prompt/chat_template_agentic_test.go deleted file mode 100644 index aaa7d6405..000000000 --- a/components/prompt/chat_template_agentic_test.go +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package prompt - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/cloudwego/eino/schema" -) - -func TestAgenticFormat(t *testing.T) { - pyFmtTestTemplate := []schema.AgenticMessagesTemplate{ - &schema.AgenticMessage{ - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "{context}"}}, - }, - }, - schema.AgenticMessagesPlaceholder("chat_history", true), - } - jinja2TestTemplate := []schema.AgenticMessagesTemplate{ - &schema.AgenticMessage{ - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "{{context}}"}}, - }, - }, - schema.AgenticMessagesPlaceholder("chat_history", true), - } - goFmtTestTemplate := []schema.AgenticMessagesTemplate{ - &schema.AgenticMessage{ - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "{{.context}}"}}, - }, - }, - schema.AgenticMessagesPlaceholder("chat_history", true), - } - testValues := map[string]any{ - "context": "it's beautiful day", - "chat_history": []*schema.AgenticMessage{ - { - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "1"}}, - }, - }, - { - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "2"}}, - }, - }, - }, - } - expected := []*schema.AgenticMessage{ - { - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "it's beautiful day"}}, - }, - }, - { - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "1"}}, - }, - }, - { - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "2"}}, - }, - }, - } - - // FString - chatTemplate := FromAgenticMessages(schema.FString, pyFmtTestTemplate...) - msgs, err := chatTemplate.Format(context.Background(), testValues) - assert.Nil(t, err) - assert.Equal(t, expected, msgs) - - // Jinja2 - chatTemplate = FromAgenticMessages(schema.Jinja2, jinja2TestTemplate...) - msgs, err = chatTemplate.Format(context.Background(), testValues) - assert.Nil(t, err) - assert.Equal(t, expected, msgs) - - // GoTemplate - chatTemplate = FromAgenticMessages(schema.GoTemplate, goFmtTestTemplate...) - msgs, err = chatTemplate.Format(context.Background(), testValues) - assert.Nil(t, err) - assert.Equal(t, expected, msgs) -} diff --git a/components/prompt/interface.go b/components/prompt/interface.go index 7ffe7216a..2d5a2cbed 100644 --- a/components/prompt/interface.go +++ b/components/prompt/interface.go @@ -44,6 +44,7 @@ type ChatTemplate interface { Format(ctx context.Context, vs map[string]any, opts ...Option) ([]*schema.Message, error) } +// AgenticChatTemplate formats variables into a list of agentic messages according to a prompt schema. type AgenticChatTemplate interface { Format(ctx context.Context, vs map[string]any, opts ...Option) ([]*schema.AgenticMessage, error) } diff --git a/compose/tools_node_agentic.go b/compose/agentic_tools_node.go similarity index 100% rename from compose/tools_node_agentic.go rename to compose/agentic_tools_node.go diff --git a/compose/tools_node_agentic_test.go b/compose/agentic_tools_node_test.go similarity index 100% rename from compose/tools_node_agentic_test.go rename to compose/agentic_tools_node_test.go diff --git a/compose/chain.go b/compose/chain.go index 8484e8767..abfa6bf1d 100644 --- a/compose/chain.go +++ b/compose/chain.go @@ -22,7 +22,6 @@ import ( "fmt" "reflect" - "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -181,7 +180,7 @@ func (c *Chain[I, O]) AppendChatModel(node model.BaseChatModel, opts ...GraphAdd // model, err := openai.NewAgenticModel(ctx, config) // if err != nil {...} // chain.AppendAgenticModel(model) -func (c *Chain[I, O]) AppendAgenticModel(node agentic.Model, opts ...GraphAddNodeOpt) *Chain[I, O] { +func (c *Chain[I, O]) AppendAgenticModel(node model.AgenticModel, opts ...GraphAddNodeOpt) *Chain[I, O] { gNode, options := toAgenticModelNode(node, opts...) c.addNode(gNode, options) return c diff --git a/compose/chain_branch.go b/compose/chain_branch.go index 004dbfac3..84fb11048 100644 --- a/compose/chain_branch.go +++ b/compose/chain_branch.go @@ -20,7 +20,6 @@ import ( "context" "fmt" - "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -158,7 +157,7 @@ func (cb *ChainBranch) AddChatModel(key string, node model.BaseChatModel, opts . // }) // cb.AddAgenticModel("agentic_model_key_1", model1) // cb.AddAgenticModel("agentic_model_key_2", model2) -func (cb *ChainBranch) AddAgenticModel(key string, node agentic.Model, opts ...GraphAddNodeOpt) *ChainBranch { +func (cb *ChainBranch) AddAgenticModel(key string, node model.AgenticModel, opts ...GraphAddNodeOpt) *ChainBranch { gNode, options := toAgenticModelNode(node, opts...) return cb.addNode(key, gNode, options) } diff --git a/compose/chain_parallel.go b/compose/chain_parallel.go index 128ed4a26..463140be2 100644 --- a/compose/chain_parallel.go +++ b/compose/chain_parallel.go @@ -19,7 +19,6 @@ package compose import ( "fmt" - "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -84,7 +83,7 @@ func (p *Parallel) AddChatModel(outputKey string, node model.BaseChatModel, opts // // p.AddAgenticModel("output_key1", model1) // p.AddAgenticModel("output_key2", model2) -func (p *Parallel) AddAgenticModel(outputKey string, node agentic.Model, opts ...GraphAddNodeOpt) *Parallel { +func (p *Parallel) AddAgenticModel(outputKey string, node model.AgenticModel, opts ...GraphAddNodeOpt) *Parallel { gNode, options := toAgenticModelNode(node, append(opts, WithOutputKey(outputKey))...) return p.addNode(outputKey, gNode, options) } diff --git a/compose/component_to_graph_node.go b/compose/component_to_graph_node.go index e64ce4f19..4bd27fe34 100644 --- a/compose/component_to_graph_node.go +++ b/compose/component_to_graph_node.go @@ -18,7 +18,6 @@ package compose import ( "github.com/cloudwego/eino/components" - "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -102,7 +101,7 @@ func toChatModelNode(node model.BaseChatModel, opts ...GraphAddNodeOpt) (*graphN opts...) } -func toAgenticModelNode(node agentic.Model, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { +func toAgenticModelNode(node model.AgenticModel, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { return toComponentNode( node, components.ComponentOfAgenticModel, diff --git a/compose/graph.go b/compose/graph.go index 877b8fb42..bcf5ae423 100644 --- a/compose/graph.go +++ b/compose/graph.go @@ -23,7 +23,6 @@ import ( "reflect" "strings" - "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -361,7 +360,7 @@ func (g *graph) AddChatModelNode(key string, node model.BaseChatModel, opts ...G // }) // // graph.AddAgenticModelNode("agentic_model_node_key", model) -func (g *graph) AddAgenticModelNode(key string, node agentic.Model, opts ...GraphAddNodeOpt) error { +func (g *graph) AddAgenticModelNode(key string, node model.AgenticModel, opts ...GraphAddNodeOpt) error { gNode, options := toAgenticModelNode(node, opts...) return g.addNode(key, gNode, options) } diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 63c78c3eb..a4554d38e 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -66,15 +66,15 @@ const ( ) type AgenticMessage struct { - // ResponseMeta is the response metadata. - ResponseMeta *AgenticResponseMeta - // Role is the message role. Role AgenticRoleType // ContentBlocks is the list of content blocks. ContentBlocks []*ContentBlock + // ResponseMeta is the response metadata. + ResponseMeta *AgenticResponseMeta + // Extra is the additional information. Extra map[string]any } @@ -541,7 +541,7 @@ func NewContentBlockChunk[T contentBlockVariant](content *T, meta *StreamingMeta return block } -// AgenticMessagesTemplate is the interface for messages template. +// AgenticMessagesTemplate is the interface for agentic messages template. // It's used to render a template to a list of agentic messages. // e.g. // diff --git a/utils/callbacks/template.go b/utils/callbacks/template.go index e04bddd63..4c73e6bbc 100644 --- a/utils/callbacks/template.go +++ b/utils/callbacks/template.go @@ -55,17 +55,20 @@ func NewHandlerHelper() *HandlerHelper { // // then use the handler with runnable.Invoke(ctx, input, compose.WithCallbacks(handler)) type HandlerHelper struct { - promptHandler *PromptCallbackHandler - chatModelHandler *ModelCallbackHandler - embeddingHandler *EmbeddingCallbackHandler - indexerHandler *IndexerCallbackHandler - retrieverHandler *RetrieverCallbackHandler - loaderHandler *LoaderCallbackHandler - transformerHandler *TransformerCallbackHandler - toolHandler *ToolCallbackHandler - toolsNodeHandler *ToolsNodeCallbackHandlers - agentHandler *AgentCallbackHandler - composeTemplates map[components.Component]callbacks.Handler + promptHandler *PromptCallbackHandler + chatModelHandler *ModelCallbackHandler + embeddingHandler *EmbeddingCallbackHandler + indexerHandler *IndexerCallbackHandler + retrieverHandler *RetrieverCallbackHandler + loaderHandler *LoaderCallbackHandler + transformerHandler *TransformerCallbackHandler + toolHandler *ToolCallbackHandler + toolsNodeHandler *ToolsNodeCallbackHandlers + agenticPromptHandler *AgenticPromptCallbackHandler + agenticModelHandler *AgenticModelCallbackHandler + agenticToolsNodeHandler *AgenticToolsNodeCallbackHandlers + agentHandler *AgentCallbackHandler + composeTemplates map[components.Component]callbacks.Handler } // Handler returns the callbacks.Handler created by HandlerHelper. @@ -127,6 +130,24 @@ func (c *HandlerHelper) ToolsNode(handler *ToolsNodeCallbackHandlers) *HandlerHe return c } +// AgenticPrompt sets the agentic prompt handler for the handler helper, which will be called when the agentic prompt component is executed. +func (c *HandlerHelper) AgenticPrompt(handler *AgenticPromptCallbackHandler) *HandlerHelper { + c.agenticPromptHandler = handler + return c +} + +// AgenticModel sets the agentic chat model handler for the handler helper, which will be called when the agentic chat model component is executed. +func (c *HandlerHelper) AgenticModel(handler *AgenticModelCallbackHandler) *HandlerHelper { + c.agenticModelHandler = handler + return c +} + +// AgenticToolsNode sets the agentic tools node handler for the handler helper, which will be called when the agentic tools node is executed. +func (c *HandlerHelper) AgenticToolsNode(handler *AgenticToolsNodeCallbackHandlers) *HandlerHelper { + c.agenticToolsNodeHandler = handler + return c +} + // Agent sets the agent handler for the handler helper, which will be called when the agent is executed. func (c *HandlerHelper) Agent(handler *AgentCallbackHandler) *HandlerHelper { c.agentHandler = handler @@ -161,8 +182,12 @@ func (c *handlerTemplate) OnStart(ctx context.Context, info *callbacks.RunInfo, switch info.Component { case components.ComponentOfPrompt: return c.promptHandler.OnStart(ctx, info, prompt.ConvCallbackInput(input)) + case components.ComponentOfAgenticPrompt: + return c.agenticPromptHandler.OnStart(ctx, info, prompt.ConvCallbackInput(input)) case components.ComponentOfChatModel: return c.chatModelHandler.OnStart(ctx, info, model.ConvCallbackInput(input)) + case components.ComponentOfAgenticModel: + return c.agenticModelHandler.OnStart(ctx, info, model.ConvCallbackInput(input)) case components.ComponentOfEmbedding: return c.embeddingHandler.OnStart(ctx, info, embedding.ConvCallbackInput(input)) case components.ComponentOfIndexer: @@ -177,6 +202,8 @@ func (c *handlerTemplate) OnStart(ctx context.Context, info *callbacks.RunInfo, return c.toolHandler.OnStart(ctx, info, tool.ConvCallbackInput(input)) case compose.ComponentOfToolsNode: return c.toolsNodeHandler.OnStart(ctx, info, convToolsNodeCallbackInput(input)) + case compose.ComponentOfAgenticToolsNode: + return c.agenticToolsNodeHandler.OnStart(ctx, info, convAgenticToolsNodeCallbackInput(input)) case adk.ComponentOfAgent: return c.agentHandler.OnStart(ctx, info, adk.ConvAgentCallbackInput(input)) case compose.ComponentOfGraph, @@ -194,8 +221,12 @@ func (c *handlerTemplate) OnEnd(ctx context.Context, info *callbacks.RunInfo, ou switch info.Component { case components.ComponentOfPrompt: return c.promptHandler.OnEnd(ctx, info, prompt.ConvCallbackOutput(output)) + case components.ComponentOfAgenticPrompt: + return c.agenticPromptHandler.OnEnd(ctx, info, prompt.ConvCallbackOutput(output)) case components.ComponentOfChatModel: return c.chatModelHandler.OnEnd(ctx, info, model.ConvCallbackOutput(output)) + case components.ComponentOfAgenticModel: + return c.agenticModelHandler.OnEnd(ctx, info, model.ConvCallbackOutput(output)) case components.ComponentOfEmbedding: return c.embeddingHandler.OnEnd(ctx, info, embedding.ConvCallbackOutput(output)) case components.ComponentOfIndexer: @@ -210,6 +241,8 @@ func (c *handlerTemplate) OnEnd(ctx context.Context, info *callbacks.RunInfo, ou return c.toolHandler.OnEnd(ctx, info, tool.ConvCallbackOutput(output)) case compose.ComponentOfToolsNode: return c.toolsNodeHandler.OnEnd(ctx, info, convToolsNodeCallbackOutput(output)) + case compose.ComponentOfAgenticToolsNode: + return c.agenticToolsNodeHandler.OnEnd(ctx, info, convAgenticToolsNodeCallbackOutput(output)) case adk.ComponentOfAgent: return c.agentHandler.OnEnd(ctx, info, adk.ConvAgentCallbackOutput(output)) case compose.ComponentOfGraph, @@ -227,8 +260,12 @@ func (c *handlerTemplate) OnError(ctx context.Context, info *callbacks.RunInfo, switch info.Component { case components.ComponentOfPrompt: return c.promptHandler.OnError(ctx, info, err) + case components.ComponentOfAgenticPrompt: + return c.agenticPromptHandler.OnError(ctx, info, err) case components.ComponentOfChatModel: return c.chatModelHandler.OnError(ctx, info, err) + case components.ComponentOfAgenticModel: + return c.agenticModelHandler.OnError(ctx, info, err) case components.ComponentOfEmbedding: return c.embeddingHandler.OnError(ctx, info, err) case components.ComponentOfIndexer: @@ -243,6 +280,8 @@ func (c *handlerTemplate) OnError(ctx context.Context, info *callbacks.RunInfo, return c.toolHandler.OnError(ctx, info, err) case compose.ComponentOfToolsNode: return c.toolsNodeHandler.OnError(ctx, info, err) + case compose.ComponentOfAgenticToolsNode: + return c.agenticToolsNodeHandler.OnError(ctx, info, err) case compose.ComponentOfGraph, compose.ComponentOfChain, compose.ComponentOfLambda: @@ -275,6 +314,11 @@ func (c *handlerTemplate) OnEndWithStreamOutput(ctx context.Context, info *callb schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*model.CallbackOutput, error) { return model.ConvCallbackOutput(item), nil })) + case components.ComponentOfAgenticModel: + return c.agenticModelHandler.OnEndWithStreamOutput(ctx, info, + schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*model.CallbackOutput, error) { + return model.ConvCallbackOutput(item), nil + })) case components.ComponentOfTool: return c.toolHandler.OnEndWithStreamOutput(ctx, info, schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*tool.CallbackOutput, error) { @@ -285,6 +329,11 @@ func (c *handlerTemplate) OnEndWithStreamOutput(ctx context.Context, info *callb schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) ([]*schema.Message, error) { return convToolsNodeCallbackOutput(item), nil })) + case compose.ComponentOfAgenticToolsNode: + return c.agenticToolsNodeHandler.OnEndWithStreamOutput(ctx, info, + schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) ([]*schema.AgenticMessage, error) { + return convAgenticToolsNodeCallbackOutput(item), nil + })) case compose.ComponentOfGraph, compose.ComponentOfChain, compose.ComponentOfLambda: @@ -295,6 +344,8 @@ func (c *handlerTemplate) OnEndWithStreamOutput(ctx context.Context, info *callb } // Needed checks if the callback handler is needed for the given timing. +// +//nolint:cyclop func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { if info == nil { return false @@ -305,6 +356,10 @@ func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, t if c.chatModelHandler != nil && c.chatModelHandler.Needed(ctx, info, timing) { return true } + case components.ComponentOfAgenticModel: + if c.agenticModelHandler != nil && c.agenticModelHandler.Needed(ctx, info, timing) { + return true + } case components.ComponentOfEmbedding: if c.embeddingHandler != nil && c.embeddingHandler.Needed(ctx, info, timing) { return true @@ -321,6 +376,10 @@ func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, t if c.promptHandler != nil && c.promptHandler.Needed(ctx, info, timing) { return true } + case components.ComponentOfAgenticPrompt: + if c.agenticPromptHandler != nil && c.agenticPromptHandler.Needed(ctx, info, timing) { + return true + } case components.ComponentOfRetriever: if c.retrieverHandler != nil && c.retrieverHandler.Needed(ctx, info, timing) { return true @@ -337,6 +396,10 @@ func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, t if c.toolsNodeHandler != nil && c.toolsNodeHandler.Needed(ctx, info, timing) { return true } + case compose.ComponentOfAgenticToolsNode: + if c.agenticToolsNodeHandler != nil && c.agenticToolsNodeHandler.Needed(ctx, info, timing) { + return true + } case adk.ComponentOfAgent: if c.agentHandler != nil && c.agentHandler.Needed(ctx, info, timing) { return true @@ -596,3 +659,94 @@ func (ch *AgentCallbackHandler) Needed(ctx context.Context, info *callbacks.RunI return false } } + +// AgenticPromptCallbackHandler is the handler for the agentic prompt callback. +type AgenticPromptCallbackHandler struct { + // OnStart is the callback function for the start of the agentic prompt. + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context + // OnEnd is the callback function for the end of the agentic prompt. + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *prompt.CallbackOutput) context.Context + // OnError is the callback function for the error of the agentic prompt. + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. +func (ch *AgenticPromptCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false + } +} + +// AgenticModelCallbackHandler is the handler for the agentic chat model callback. +type AgenticModelCallbackHandler struct { + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context + OnEndWithStreamOutput func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. +func (ch *AgenticModelCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + case callbacks.TimingOnEndWithStreamOutput: + return ch.OnEndWithStreamOutput != nil + default: + return false + } +} + +// AgenticToolsNodeCallbackHandlers defines optional callbacks for the Agentic Tools node +// lifecycle events. +type AgenticToolsNodeCallbackHandlers struct { + OnStart func(ctx context.Context, info *callbacks.RunInfo, input *schema.AgenticMessage) context.Context + OnEnd func(ctx context.Context, info *callbacks.RunInfo, input []*schema.AgenticMessage) context.Context + OnEndWithStreamOutput func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[[]*schema.AgenticMessage]) context.Context + OnError func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context +} + +// Needed reports whether a handler is registered for the given timing. +func (ch *AgenticToolsNodeCallbackHandlers) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnEndWithStreamOutput: + return ch.OnEndWithStreamOutput != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false + } +} + +func convAgenticToolsNodeCallbackInput(src callbacks.CallbackInput) *schema.AgenticMessage { + switch t := src.(type) { + case *schema.AgenticMessage: + return t + default: + return nil + } +} + +func convAgenticToolsNodeCallbackOutput(src callbacks.CallbackInput) []*schema.AgenticMessage { + switch t := src.(type) { + case []*schema.AgenticMessage: + return t + default: + return nil + } +} diff --git a/utils/callbacks/template_test.go b/utils/callbacks/template_test.go index 84ed6dfc6..f599e5300 100644 --- a/utils/callbacks/template_test.go +++ b/utils/callbacks/template_test.go @@ -142,6 +142,58 @@ func TestNewComponentTemplate(t *testing.T) { cnt++ return ctx }).Build()). + AgenticModel(&AgenticModelCallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context { + cnt++ + return ctx + }, + OnEndWithStreamOutput: func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context { + output.Close() + cnt++ + return ctx + }, + OnError: func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }). + AgenticPrompt(&AgenticPromptCallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *prompt.CallbackOutput) context.Context { + cnt++ + return ctx + }, + OnError: func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }). + AgenticToolsNode(&AgenticToolsNodeCallbackHandlers{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *schema.AgenticMessage) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, input []*schema.AgenticMessage) context.Context { + cnt++ + return ctx + }, + OnEndWithStreamOutput: func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[[]*schema.AgenticMessage]) context.Context { + output.Close() + cnt++ + return ctx + }, + OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }). Handler() types := []components.Component{ @@ -151,6 +203,9 @@ func TestNewComponentTemplate(t *testing.T) { components.ComponentOfRetriever, components.ComponentOfTool, compose.ComponentOfLambda, + components.ComponentOfAgenticModel, + components.ComponentOfAgenticPrompt, + compose.ComponentOfAgenticToolsNode, } handler := tpl.Handler() @@ -169,28 +224,28 @@ func TestNewComponentTemplate(t *testing.T) { handler.OnEndWithStreamOutput(ctx, &callbacks.RunInfo{Component: typ}, sor) } - assert.Equal(t, 22, cnt) + assert.Equal(t, 33, cnt) ctx = context.Background() ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, handler) callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 22, cnt) + assert.Equal(t, 33, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfPrompt}) ctx = callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 23, cnt) + assert.Equal(t, 34, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer}) callbacks.OnEnd[any](ctx, nil) - assert.Equal(t, 23, cnt) + assert.Equal(t, 34, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfEmbedding}) callbacks.OnError(ctx, nil) - assert.Equal(t, 24, cnt) + assert.Equal(t, 35, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}) callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 24, cnt) + assert.Equal(t, 35, cnt) tpl.Transformer(&TransformerCallbackHandler{ OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *document.TransformerCallbackInput) context.Context { @@ -250,6 +305,37 @@ func TestNewComponentTemplate(t *testing.T) { } } }, + }).AgenticPrompt(&AgenticPromptCallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *prompt.CallbackOutput) context.Context { + cnt++ + return ctx + }, + OnError: func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }).AgenticToolsNode(&AgenticToolsNodeCallbackHandlers{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *schema.AgenticMessage) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, input []*schema.AgenticMessage) context.Context { + cnt++ + return ctx + }, + OnEndWithStreamOutput: func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[[]*schema.AgenticMessage]) context.Context { + output.Close() + cnt++ + return ctx + }, + OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, }) handler = tpl.Handler() @@ -257,36 +343,222 @@ func TestNewComponentTemplate(t *testing.T) { ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, handler) ctx = callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 25, cnt) + assert.Equal(t, 36, cnt) callbacks.OnEnd[any](ctx, nil) - assert.Equal(t, 26, cnt) + assert.Equal(t, 37, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}) callbacks.OnEnd[any](ctx, nil) - assert.Equal(t, 27, cnt) + assert.Equal(t, 38, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: compose.ComponentOfToolsNode}) callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 28, cnt) + assert.Equal(t, 39, cnt) sr, sw := schema.Pipe[any](0) sw.Close() callbacks.OnEndWithStreamOutput[any](ctx, sr) - assert.Equal(t, 29, cnt) + assert.Equal(t, 40, cnt) sr1, sw1 := schema.Pipe[[]*schema.Message](1) sw1.Send([]*schema.Message{{}}, nil) sw1.Close() callbacks.OnEndWithStreamOutput[[]*schema.Message](ctx, sr1) - assert.Equal(t, 30, cnt) - - callbacks.OnError(ctx, nil) - assert.Equal(t, 30, cnt) + // Check AgenticModel stream + sir2, siw2 := schema.Pipe[callbacks.CallbackOutput](1) + siw2.Close() + handler.OnEndWithStreamOutput(ctx, &callbacks.RunInfo{Component: components.ComponentOfAgenticModel}, sir2) + assert.Equal(t, 42, cnt) + + // Check AgenticToolsNode stream + sir3, siw3 := schema.Pipe[callbacks.CallbackOutput](1) + siw3.Close() + handler.OnEndWithStreamOutput(ctx, &callbacks.RunInfo{Component: compose.ComponentOfAgenticToolsNode}, sir3) + assert.Equal(t, 43, cnt) ctx = callbacks.ReuseHandlers(ctx, nil) callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 30, cnt) + assert.Equal(t, 43, cnt) + }) + + t.Run("EdgeCases", func(t *testing.T) { + ctx := context.Background() + cnt := 0 + + // 1. Test Graph and Chain Setters and Execution + tpl := NewHandlerHelper(). + Graph(callbacks.NewHandlerBuilder(). + OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { + cnt++ + return ctx + }).Build()). + Chain(callbacks.NewHandlerBuilder(). + OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { + cnt++ + return ctx + }).Build()) + + h := tpl.Handler() + + // Trigger Graph OnStart + h.OnStart(ctx, &callbacks.RunInfo{Component: compose.ComponentOfGraph}, nil) + assert.Equal(t, 1, cnt) + + // Trigger Chain OnEnd + h.OnEnd(ctx, &callbacks.RunInfo{Component: compose.ComponentOfChain}, nil) + assert.Equal(t, 2, cnt) + + // 2. Test Needed logic for Graph/Chain when handler is present/absent + // Graph is present (OnStart) + needed := h.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: compose.ComponentOfGraph}, callbacks.TimingOnStart) + assert.True(t, needed) + + // Chain is present (OnEnd) - but we check OnStart which is not defined in the builder above? + // NewHandlerBuilder returns a handler that usually returns true for Needed if the specific func is not nil. + // Let's verify Chain OnStart is NOT needed because we only set OnEndFn. + needed = h.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: compose.ComponentOfChain}, callbacks.TimingOnStart) + assert.False(t, needed) // Should be false because OnStartFn wasn't set for Chain + + // Lambda is NOT present + needed = h.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: compose.ComponentOfLambda}, callbacks.TimingOnStart) + assert.False(t, needed) + + // 3. Test Conversion Fallbacks (Default cases) + // We need a handler with ToolsNode and AgenticToolsNode to test their conversion fallbacks + tpl2 := NewHandlerHelper(). + ToolsNode(&ToolsNodeCallbackHandlers{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *schema.Message) context.Context { + if input == nil { + cnt++ + } + return ctx + }, + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, input []*schema.Message) context.Context { + if input == nil { + cnt++ + } + return ctx + }, + }). + AgenticToolsNode(&AgenticToolsNodeCallbackHandlers{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *schema.AgenticMessage) context.Context { + if input == nil { + cnt++ + } + return ctx + }, + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, input []*schema.AgenticMessage) context.Context { + if input == nil { + cnt++ + } + return ctx + }, + }) + + h2 := tpl2.Handler() + + // Pass wrong type (string) to trigger default case in convToolsNodeCallbackInput -> returns nil + h2.OnStart(ctx, &callbacks.RunInfo{Component: compose.ComponentOfToolsNode}, "wrong-input-type") + assert.Equal(t, 3, cnt) // +1 + + // Pass wrong type to trigger default case in convToolsNodeCallbackOutput -> returns nil + h2.OnEnd(ctx, &callbacks.RunInfo{Component: compose.ComponentOfToolsNode}, "wrong-output-type") + assert.Equal(t, 4, cnt) // +1 + + // Pass wrong type to trigger default case in convAgenticToolsNodeCallbackInput -> returns nil + h2.OnStart(ctx, &callbacks.RunInfo{Component: compose.ComponentOfAgenticToolsNode}, "wrong-input-type") + assert.Equal(t, 5, cnt) // +1 + + // Pass wrong type to trigger default case in convAgenticToolsNodeCallbackOutput -> returns nil + h2.OnEnd(ctx, &callbacks.RunInfo{Component: compose.ComponentOfAgenticToolsNode}, "wrong-output-type") + assert.Equal(t, 6, cnt) // +1 + + // 4. Test Needed for Agentic components when handlers are Set vs Unset + // tpl2 has AgenticToolsNode set + needed = h2.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: compose.ComponentOfAgenticToolsNode}, callbacks.TimingOnStart) + assert.True(t, needed) + + // tpl2 does NOT have AgenticModel set + needed = h2.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfAgenticModel}, callbacks.TimingOnStart) + assert.False(t, needed) + + // Set it now + tpl2.AgenticModel(&AgenticModelCallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context { + return ctx + }, + }) + + needed = h2.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfAgenticModel}, callbacks.TimingOnStart) + assert.True(t, needed) + + // Check invalid component + needed = h2.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: "UnknownComponent"}, callbacks.TimingOnStart) + assert.False(t, needed) + + // Check RunInfo nil + needed = h2.(callbacks.TimingChecker).Needed(ctx, nil, callbacks.TimingOnStart) + assert.False(t, needed) + + // 5. Test Needed for Transformer, Loader, Indexer, etc to ensure switch coverage + tpl3 := NewHandlerHelper(). + Transformer(&TransformerCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *document.TransformerCallbackInput) context.Context { + return ctx + }}). + Loader(&LoaderCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *document.LoaderCallbackInput) context.Context { + return ctx + }}). + Indexer(&IndexerCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *indexer.CallbackInput) context.Context { + return ctx + }}). + Retriever(&RetrieverCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *retriever.CallbackInput) context.Context { + return ctx + }}). + Embedding(&EmbeddingCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *embedding.CallbackInput) context.Context { + return ctx + }}). + Tool(&ToolCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *tool.CallbackInput) context.Context { + return ctx + }}) + + h3 := tpl3.Handler() + checker := h3.(callbacks.TimingChecker) + + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfRetriever}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfEmbedding}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfTool}, callbacks.TimingOnStart)) + + // Verify False paths (by using a helper without them) + emptyH := NewHandlerHelper().Handler().(callbacks.TimingChecker) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, callbacks.TimingOnStart)) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}, callbacks.TimingOnStart)) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer}, callbacks.TimingOnStart)) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfRetriever}, callbacks.TimingOnStart)) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfEmbedding}, callbacks.TimingOnStart)) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfTool}, callbacks.TimingOnStart)) + + // 6. Test Needed for remaining components (ChatModel, Prompt, AgenticPrompt) + tpl4 := NewHandlerHelper(). + ChatModel(&ModelCallbackHandler{OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context { + return ctx + }}). + Prompt(&PromptCallbackHandler{OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context { + return ctx + }}). + AgenticPrompt(&AgenticPromptCallbackHandler{OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context { + return ctx + }}) + + h4 := tpl4.Handler() + checker4 := h4.(callbacks.TimingChecker) + + assert.True(t, checker4.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfChatModel}, callbacks.TimingOnStart)) + assert.True(t, checker4.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfPrompt}, callbacks.TimingOnStart)) + assert.True(t, checker4.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfAgenticPrompt}, callbacks.TimingOnStart)) }) } From 75a2655a440e7e6fa3dce19fb7274290409dacbb Mon Sep 17 00:00:00 2001 From: mrh997 Date: Thu, 15 Jan 2026 16:19:47 +0800 Subject: [PATCH 23/65] feat: improve AgenticToolChoice (#684) --- components/model/option.go | 17 ++++++++--------- components/model/option_test.go | 13 ++++++++++--- compose/workflow.go | 18 ++++++++++++++++++ schema/tool.go | 31 +++++++++++++++++++++++++++++-- 4 files changed, 65 insertions(+), 14 deletions(-) diff --git a/components/model/option.go b/components/model/option.go index 0173d22aa..a337b7af2 100644 --- a/components/model/option.go +++ b/components/model/option.go @@ -28,11 +28,11 @@ type Options struct { TopP *float32 // Tools is a list of tools the model may call. Tools []*schema.ToolInfo - // ToolChoice controls which tool is called by the model. - ToolChoice *schema.ToolChoice - // Options only for chat model. + // Options only available for chat model. + // ToolChoice controls which tool is called by the model. + ToolChoice *schema.ToolChoice // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return an finish reason of "length". MaxTokens *int // AllowedToolNames specifies a list of tool names that the model is allowed to call. @@ -41,10 +41,10 @@ type Options struct { // Stop is the stop words for the model, which controls the stopping condition of the model. Stop []string - // Options only for agentic model. + // Options only available for agentic model. - // AllowedTools is a list of allowed tools the model may call. - AllowedTools []*schema.AllowedTool + // AgenticToolChoice controls how the agentic model calls tools. + AgenticToolChoice *schema.AgenticToolChoice } // Option is a call-time option for a ChatModel. Options are immutable and @@ -130,11 +130,10 @@ func WithToolChoice(toolChoice schema.ToolChoice, allowedToolNames ...string) Op // WithAgenticToolChoice is the option to set tool choice for the agentic model. // Only available for AgenticModel. -func WithAgenticToolChoice(toolChoice schema.ToolChoice, allowedTools ...*schema.AllowedTool) Option { +func WithAgenticToolChoice(toolChoice *schema.AgenticToolChoice) Option { return Option{ apply: func(opts *Options) { - opts.ToolChoice = &toolChoice - opts.AllowedTools = allowedTools + opts.AgenticToolChoice = toolChoice }, } } diff --git a/components/model/option_test.go b/components/model/option_test.go index bfacdd17c..aa43e6e01 100644 --- a/components/model/option_test.go +++ b/components/model/option_test.go @@ -92,11 +92,18 @@ func TestOptions(t *testing.T) { ) opts := GetCommonOptions( nil, - WithAgenticToolChoice(toolChoice, allowedTools...), + WithAgenticToolChoice(&schema.AgenticToolChoice{ + Type: toolChoice, + Forced: &schema.AgenticForcedToolChoice{ + Tools: allowedTools, + }, + }), ) - convey.So(opts.ToolChoice, convey.ShouldResemble, &toolChoice) - convey.So(opts.AllowedTools, convey.ShouldResemble, allowedTools) + convey.So(opts.AgenticToolChoice, convey.ShouldNotBeNil) + convey.So(opts.AgenticToolChoice.Type, convey.ShouldEqual, toolChoice) + convey.So(opts.AgenticToolChoice.Forced, convey.ShouldNotBeNil) + convey.So(opts.AgenticToolChoice.Forced.Tools, convey.ShouldResemble, allowedTools) }) } diff --git a/compose/workflow.go b/compose/workflow.go index c3e4331a3..6b50962bb 100644 --- a/compose/workflow.go +++ b/compose/workflow.go @@ -89,18 +89,36 @@ func (wf *Workflow[I, O]) AddChatModelNode(key string, chatModel model.BaseChatM return wf.initNode(key) } +// AddAgenticModelNode adds an agentic model node and returns it. +func (wf *Workflow[I, O]) AddAgenticModelNode(key string, agenticModel model.AgenticModel, opts ...GraphAddNodeOpt) *WorkflowNode { + _ = wf.g.AddAgenticModelNode(key, agenticModel, opts...) + return wf.initNode(key) +} + // AddChatTemplateNode adds a chat template node and returns it. func (wf *Workflow[I, O]) AddChatTemplateNode(key string, chatTemplate prompt.ChatTemplate, opts ...GraphAddNodeOpt) *WorkflowNode { _ = wf.g.AddChatTemplateNode(key, chatTemplate, opts...) return wf.initNode(key) } +// AddAgenticChatTemplateNode adds an agentic chat template node and returns it. +func (wf *Workflow[I, O]) AddAgenticChatTemplateNode(key string, chatTemplate prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) *WorkflowNode { + _ = wf.g.AddAgenticChatTemplateNode(key, chatTemplate, opts...) + return wf.initNode(key) +} + // AddToolsNode adds a tools node and returns it. func (wf *Workflow[I, O]) AddToolsNode(key string, tools *ToolsNode, opts ...GraphAddNodeOpt) *WorkflowNode { _ = wf.g.AddToolsNode(key, tools, opts...) return wf.initNode(key) } +// AddAgenticToolsNode adds an agentic tools node and returns it. +func (wf *Workflow[I, O]) AddAgenticToolsNode(key string, tools *AgenticToolsNode, opts ...GraphAddNodeOpt) *WorkflowNode { + _ = wf.g.AddAgenticToolsNode(key, tools, opts...) + return wf.initNode(key) +} + // AddRetrieverNode adds a retriever node and returns it. func (wf *Workflow[I, O]) AddRetrieverNode(key string, retriever retriever.Retriever, opts ...GraphAddNodeOpt) *WorkflowNode { _ = wf.g.AddRetrieverNode(key, retriever, opts...) diff --git a/schema/tool.go b/schema/tool.go index efed6d34b..a067d87db 100644 --- a/schema/tool.go +++ b/schema/tool.go @@ -59,6 +59,31 @@ const ( ToolChoiceForced ToolChoice = "forced" ) +type AgenticToolChoice struct { + // Type is the tool choice mode. + Type ToolChoice + // Allowed optionally specifies the list of tools that the model is permitted to call. + Allowed *AgenticAllowedToolChoice + // Forced optionally specifies the list of tools that the model is required to call. + Forced *AgenticForcedToolChoice +} + +// AgenticAllowedToolChoice specifies a list of allowed tools for the model. +type AgenticAllowedToolChoice struct { + // Tools is the list of allowed tools for the model to call. + // Optional. + Tools []*AllowedTool +} + +// AgenticForcedToolChoice specifies a list of tools that the model must call. +type AgenticForcedToolChoice struct { + // Tools is the list of tools that the model must call. + // Optional. + Tools []*AllowedTool +} + +// AllowedTool represents a tool that the model is allowed or forced to call. +// Exactly one of FunctionToolName, MCPTool, or ServerTool must be specified. type AllowedTool struct { // FunctionToolName is the name of the function tool. FunctionToolName string @@ -68,15 +93,17 @@ type AllowedTool struct { ServerTool *AllowedServerTool } +// AllowedMCPTool contains the information for identifying an MCP tool. type AllowedMCPTool struct { // ServerLabel is the label of the MCP server. ServerLabel string - // The name of the MCP tool. + // Name is the name of the MCP tool. Name string } +// AllowedServerTool contains the information for identifying a server tool. type AllowedServerTool struct { - // The name of the server tool. + // Name is the name of the server tool. Name string } From 3d88d3296d1a65414bd33295bceafe4c049df628 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Thu, 15 Jan 2026 18:04:59 +0800 Subject: [PATCH 24/65] feat: define AgenticCallbackInput/Output (#689) --- components/model/agentic_callback_extra.go | 92 +++++++++++++++++++ .../model/agentic_callback_extra_test.go | 35 +++++++ components/model/callback_extra.go | 12 --- components/model/option_test.go | 2 +- components/prompt/agentic_callback_extra.go | 70 ++++++++++++++ .../prompt/agentic_callback_extra_test.go | 46 ++++++++++ components/prompt/agentic_chat_template.go | 4 +- components/prompt/callback_extra.go | 10 -- components/prompt/callback_extra_test.go | 17 +--- schema/agentic_message.go | 9 -- schema/agentic_message_test.go | 9 -- schema/tool.go | 8 +- 12 files changed, 256 insertions(+), 58 deletions(-) create mode 100644 components/model/agentic_callback_extra.go create mode 100644 components/model/agentic_callback_extra_test.go create mode 100644 components/prompt/agentic_callback_extra.go create mode 100644 components/prompt/agentic_callback_extra_test.go diff --git a/components/model/agentic_callback_extra.go b/components/model/agentic_callback_extra.go new file mode 100644 index 000000000..28dd366e6 --- /dev/null +++ b/components/model/agentic_callback_extra.go @@ -0,0 +1,92 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package model + +import ( + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" +) + +// AgenticConfig is the config for the agentic model. +type AgenticConfig struct { + // Model is the model name. + Model string + // Temperature is the temperature, which controls the randomness of the agentic model. + Temperature float32 + // TopP is the top p, which controls the diversity of the agentic model. + TopP float32 +} + +// AgenticCallbackInput is the input for the agentic model callback. +type AgenticCallbackInput struct { + // AgenticMessages is the agentic messages to be sent to the agentic model. + AgenticMessages []*schema.AgenticMessage + // Tools is the tools to be used in the agentic model. + Tools []*schema.ToolInfo + // Config is the config for the agentic model. + Config *AgenticConfig + // Extra is the extra information for the callback. + Extra map[string]any +} + +// AgenticCallbackOutput is the output for the agentic model callback. +type AgenticCallbackOutput struct { + // AgenticMessage is the agentic message generated by the agentic model. + AgenticMessage *schema.AgenticMessage + // Config is the config for the agentic model. + Config *AgenticConfig + // TokenUsage is the token usage of this request. + TokenUsage *TokenUsage + // Extra is the extra information for the callback. + Extra map[string]any +} + +// ConvAgenticCallbackInput converts the callback input to the agentic model callback input. +func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput { + switch t := src.(type) { + case *AgenticCallbackInput: + // when callback is triggered within component implementation, + // the input is usually already a typed *model.AgenticCallbackInput + return t + case []*schema.AgenticMessage: + // when callback is injected by graph node, not the component implementation itself, + // the input is the input of Agentic Model interface, which is []*schema.AgenticMessage + return &AgenticCallbackInput{ + AgenticMessages: t, + } + default: + return nil + } +} + +// ConvAgenticCallbackOutput converts the callback output to the agentic model callback output. +func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOutput { + switch t := src.(type) { + case *AgenticCallbackOutput: + // when callback is triggered within component implementation, + // the output is usually already a typed *model.AgenticCallbackOutput + return t + case *schema.AgenticMessage: + // when callback is injected by graph node, not the component implementation itself, + // the output is the output of Agentic Model interface, which is *schema.AgenticMessage + return &AgenticCallbackOutput{ + AgenticMessage: t, + } + default: + return nil + } +} diff --git a/components/model/agentic_callback_extra_test.go b/components/model/agentic_callback_extra_test.go new file mode 100644 index 000000000..937367477 --- /dev/null +++ b/components/model/agentic_callback_extra_test.go @@ -0,0 +1,35 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package model + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestConvAgenticModel(t *testing.T) { + assert.NotNil(t, ConvAgenticCallbackInput(&AgenticCallbackInput{})) + assert.NotNil(t, ConvAgenticCallbackInput([]*schema.AgenticMessage{})) + assert.Nil(t, ConvAgenticCallbackInput("asd")) + + assert.NotNil(t, ConvAgenticCallbackOutput(&AgenticCallbackOutput{})) + assert.NotNil(t, ConvAgenticCallbackOutput(&schema.AgenticMessage{})) + assert.Nil(t, ConvAgenticCallbackOutput("asd")) +} diff --git a/components/model/callback_extra.go b/components/model/callback_extra.go index ed9096d5c..8591c4373 100644 --- a/components/model/callback_extra.go +++ b/components/model/callback_extra.go @@ -66,8 +66,6 @@ type Config struct { type CallbackInput struct { // Messages is the messages to be sent to the model. Messages []*schema.Message - // AgenticMessages is the agentic messages to be sent to the agentic model. - AgenticMessages []*schema.AgenticMessage // Tools is the tools to be used in the model. Tools []*schema.ToolInfo // ToolChoice is the tool choice, which controls the tool to be used in the model. @@ -82,8 +80,6 @@ type CallbackInput struct { type CallbackOutput struct { // Message is the message generated by the model. Message *schema.Message - // AgenticMessage is the agentic message generated by the agentic model. - AgenticMessage *schema.AgenticMessage // Config is the config for the model. Config *Config // TokenUsage is the token usage of this request. @@ -101,10 +97,6 @@ func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { return &CallbackInput{ Messages: t, } - case []*schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the input is the input of Agentic Model interface, which is []*schema.AgenticMessage - return &CallbackInput{ - AgenticMessages: t, - } default: return nil } @@ -119,10 +111,6 @@ func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { return &CallbackOutput{ Message: t, } - case *schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the output is the output of Agentic Model interface, which is *schema.AgenticMessage - return &CallbackOutput{ - AgenticMessage: t, - } default: return nil } diff --git a/components/model/option_test.go b/components/model/option_test.go index aa43e6e01..c836933b7 100644 --- a/components/model/option_test.go +++ b/components/model/option_test.go @@ -87,7 +87,7 @@ func TestOptions(t *testing.T) { var ( toolChoice = schema.ToolChoiceForced allowedTools = []*schema.AllowedTool{ - {FunctionToolName: "agentic_tool"}, + {FunctionName: "agentic_tool"}, } ) opts := GetCommonOptions( diff --git a/components/prompt/agentic_callback_extra.go b/components/prompt/agentic_callback_extra.go new file mode 100644 index 000000000..1170854a1 --- /dev/null +++ b/components/prompt/agentic_callback_extra.go @@ -0,0 +1,70 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" +) + +// AgenticCallbackInput is the input for the callback. +type AgenticCallbackInput struct { + // Variables is the variables for the callback. + Variables map[string]any + // AgenticTemplates is the agentic templates for the callback. + AgenticTemplates []schema.AgenticMessagesTemplate + // Extra is the extra information for the callback. + Extra map[string]any +} + +// AgenticCallbackOutput is the output for the callback. +type AgenticCallbackOutput struct { + // AgenticResult is the agentic result for the callback. + AgenticResult []*schema.AgenticMessage + // AgenticTemplates is the agentic templates for the callback. + AgenticTemplates []schema.AgenticMessagesTemplate + // Extra is the extra information for the callback. + Extra map[string]any +} + +// ConvAgenticCallbackInput converts the callback input to the agentic prompt callback input. +func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput { + switch t := src.(type) { + case *AgenticCallbackInput: + return t + case map[string]any: + return &AgenticCallbackInput{ + Variables: t, + } + default: + return nil + } +} + +// ConvAgenticCallbackOutput converts the callback output to the agentic prompt callback output. +func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOutput { + switch t := src.(type) { + case *AgenticCallbackOutput: + return t + case []*schema.AgenticMessage: + return &AgenticCallbackOutput{ + AgenticResult: t, + } + default: + return nil + } +} diff --git a/components/prompt/agentic_callback_extra_test.go b/components/prompt/agentic_callback_extra_test.go new file mode 100644 index 000000000..6dda1a349 --- /dev/null +++ b/components/prompt/agentic_callback_extra_test.go @@ -0,0 +1,46 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestConvAgenticPrompt(t *testing.T) { + assert.NotNil(t, ConvAgenticCallbackInput(&AgenticCallbackInput{ + Variables: map[string]any{}, + AgenticTemplates: []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{}, + }, + })) + assert.NotNil(t, ConvAgenticCallbackInput(map[string]any{})) + assert.Nil(t, ConvAgenticCallbackInput("asd")) + + assert.NotNil(t, ConvAgenticCallbackOutput(&AgenticCallbackOutput{ + AgenticResult: []*schema.AgenticMessage{ + {}, + }, + AgenticTemplates: []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{}, + }, + })) + assert.NotNil(t, ConvAgenticCallbackOutput([]*schema.AgenticMessage{})) +} diff --git a/components/prompt/agentic_chat_template.go b/components/prompt/agentic_chat_template.go index 512a60ecd..c6c300d5a 100644 --- a/components/prompt/agentic_chat_template.go +++ b/components/prompt/agentic_chat_template.go @@ -45,7 +45,7 @@ type DefaultAgenticChatTemplate struct { func (t *DefaultAgenticChatTemplate) Format(ctx context.Context, vs map[string]any, opts ...Option) (result []*schema.AgenticMessage, err error) { ctx = callbacks.EnsureRunInfo(ctx, t.GetType(), components.ComponentOfAgenticPrompt) - ctx = callbacks.OnStart(ctx, &CallbackInput{ + ctx = callbacks.OnStart(ctx, &AgenticCallbackInput{ Variables: vs, AgenticTemplates: t.templates, }) @@ -65,7 +65,7 @@ func (t *DefaultAgenticChatTemplate) Format(ctx context.Context, vs map[string]a result = append(result, msgs...) } - _ = callbacks.OnEnd(ctx, &CallbackOutput{ + _ = callbacks.OnEnd(ctx, &AgenticCallbackOutput{ AgenticResult: result, AgenticTemplates: t.templates, }) diff --git a/components/prompt/callback_extra.go b/components/prompt/callback_extra.go index 4c27f37c6..324a418f3 100644 --- a/components/prompt/callback_extra.go +++ b/components/prompt/callback_extra.go @@ -27,8 +27,6 @@ type CallbackInput struct { Variables map[string]any // Templates is the templates for the callback. Templates []schema.MessagesTemplate - // AgenticTemplates is the agentic templates for the callback. - AgenticTemplates []schema.AgenticMessagesTemplate // Extra is the extra information for the callback. Extra map[string]any } @@ -37,12 +35,8 @@ type CallbackInput struct { type CallbackOutput struct { // Result is the result for the callback. Result []*schema.Message - // AgenticResult is the agentic result for the callback. - AgenticResult []*schema.AgenticMessage // Templates is the templates for the callback. Templates []schema.MessagesTemplate - // AgenticTemplates is the agentic templates for the callback. - AgenticTemplates []schema.AgenticMessagesTemplate // Extra is the extra information for the callback. Extra map[string]any } @@ -70,10 +64,6 @@ func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { return &CallbackOutput{ Result: t, } - case []*schema.AgenticMessage: - return &CallbackOutput{ - AgenticResult: t, - } default: return nil } diff --git a/components/prompt/callback_extra_test.go b/components/prompt/callback_extra_test.go index 4b48ec114..ad8a3c0c2 100644 --- a/components/prompt/callback_extra_test.go +++ b/components/prompt/callback_extra_test.go @@ -26,27 +26,20 @@ import ( func TestConvPrompt(t *testing.T) { assert.NotNil(t, ConvCallbackInput(&CallbackInput{ - AgenticTemplates: []schema.AgenticMessagesTemplate{ - &schema.AgenticMessage{}, + Templates: []schema.MessagesTemplate{ + &schema.Message{}, }, })) assert.NotNil(t, ConvCallbackInput(map[string]any{})) assert.Nil(t, ConvCallbackInput("asd")) assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{ - AgenticResult: []*schema.AgenticMessage{ + Result: []*schema.Message{ {}, }, - AgenticTemplates: []schema.AgenticMessagesTemplate{ - &schema.AgenticMessage{}, + Templates: []schema.MessagesTemplate{ + &schema.Message{}, }, })) assert.NotNil(t, ConvCallbackOutput([]*schema.Message{})) - - agenticResult := []*schema.AgenticMessage{{}} - out := ConvCallbackOutput(agenticResult) - assert.NotNil(t, out) - assert.Equal(t, agenticResult, out.AgenticResult) - - assert.Nil(t, ConvCallbackOutput("asd")) } diff --git a/schema/agentic_message.go b/schema/agentic_message.go index a4554d38e..743f67855 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -59,7 +59,6 @@ const ( type AgenticRoleType string const ( - AgenticRoleTypeDeveloper AgenticRoleType = "developer" AgenticRoleTypeSystem AgenticRoleType = "system" AgenticRoleTypeUser AgenticRoleType = "user" AgenticRoleTypeAssistant AgenticRoleType = "assistant" @@ -426,14 +425,6 @@ type MCPToolApprovalResponse struct { Reason string } -// DeveloperAgenticMessage represents a message with AgenticRoleType "developer". -func DeveloperAgenticMessage(text string) *AgenticMessage { - return &AgenticMessage{ - Role: AgenticRoleTypeDeveloper, - ContentBlocks: []*ContentBlock{NewContentBlock(&UserInputText{Text: text})}, - } -} - // SystemAgenticMessage represents a message with AgenticRoleType "system". func SystemAgenticMessage(text string) *AgenticMessage { return &AgenticMessage{ diff --git a/schema/agentic_message_test.go b/schema/agentic_message_test.go index 4beb74930..144c0077e 100644 --- a/schema/agentic_message_test.go +++ b/schema/agentic_message_test.go @@ -1544,15 +1544,6 @@ response_meta: }) } -func TestDeveloperAgenticMessage(t *testing.T) { - t.Run("basic", func(t *testing.T) { - msg := DeveloperAgenticMessage("developer") - assert.Equal(t, AgenticRoleTypeDeveloper, msg.Role) - assert.Len(t, msg.ContentBlocks, 1) - assert.Equal(t, "developer", msg.ContentBlocks[0].UserInputText.Text) - }) -} - func TestSystemAgenticMessage(t *testing.T) { t.Run("basic", func(t *testing.T) { msg := SystemAgenticMessage("system") diff --git a/schema/tool.go b/schema/tool.go index a067d87db..c195d1f9e 100644 --- a/schema/tool.go +++ b/schema/tool.go @@ -83,13 +83,15 @@ type AgenticForcedToolChoice struct { } // AllowedTool represents a tool that the model is allowed or forced to call. -// Exactly one of FunctionToolName, MCPTool, or ServerTool must be specified. +// Exactly one of FunctionName, MCPTool, or ServerTool must be specified. type AllowedTool struct { - // FunctionToolName is the name of the function tool. - FunctionToolName string + // FunctionName specifies a function tool by name. + FunctionName string + // MCPTool specifies an MCP tool. MCPTool *AllowedMCPTool + // ServerTool specifies a server tool. ServerTool *AllowedServerTool } From d5aa8a0577ef328c3011df316683b3fc9f4e0ebb Mon Sep 17 00:00:00 2001 From: mrh997 Date: Thu, 15 Jan 2026 21:49:22 +0800 Subject: [PATCH 25/65] feat: improve callback definition (#692) --- components/model/agentic_callback_extra.go | 12 ++++++------ components/prompt/agentic_callback_extra.go | 14 +++++++------- components/prompt/agentic_callback_extra_test.go | 6 +++--- components/prompt/agentic_chat_template.go | 8 ++++---- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/components/model/agentic_callback_extra.go b/components/model/agentic_callback_extra.go index 28dd366e6..54d49ff72 100644 --- a/components/model/agentic_callback_extra.go +++ b/components/model/agentic_callback_extra.go @@ -33,8 +33,8 @@ type AgenticConfig struct { // AgenticCallbackInput is the input for the agentic model callback. type AgenticCallbackInput struct { - // AgenticMessages is the agentic messages to be sent to the agentic model. - AgenticMessages []*schema.AgenticMessage + // Messages is the agentic messages to be sent to the agentic model. + Messages []*schema.AgenticMessage // Tools is the tools to be used in the agentic model. Tools []*schema.ToolInfo // Config is the config for the agentic model. @@ -45,8 +45,8 @@ type AgenticCallbackInput struct { // AgenticCallbackOutput is the output for the agentic model callback. type AgenticCallbackOutput struct { - // AgenticMessage is the agentic message generated by the agentic model. - AgenticMessage *schema.AgenticMessage + // Message is the agentic message generated by the agentic model. + Message *schema.AgenticMessage // Config is the config for the agentic model. Config *AgenticConfig // TokenUsage is the token usage of this request. @@ -66,7 +66,7 @@ func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput // when callback is injected by graph node, not the component implementation itself, // the input is the input of Agentic Model interface, which is []*schema.AgenticMessage return &AgenticCallbackInput{ - AgenticMessages: t, + Messages: t, } default: return nil @@ -84,7 +84,7 @@ func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOut // when callback is injected by graph node, not the component implementation itself, // the output is the output of Agentic Model interface, which is *schema.AgenticMessage return &AgenticCallbackOutput{ - AgenticMessage: t, + Message: t, } default: return nil diff --git a/components/prompt/agentic_callback_extra.go b/components/prompt/agentic_callback_extra.go index 1170854a1..315d5a4da 100644 --- a/components/prompt/agentic_callback_extra.go +++ b/components/prompt/agentic_callback_extra.go @@ -25,18 +25,18 @@ import ( type AgenticCallbackInput struct { // Variables is the variables for the callback. Variables map[string]any - // AgenticTemplates is the agentic templates for the callback. - AgenticTemplates []schema.AgenticMessagesTemplate + // Templates is the agentic templates for the callback. + Templates []schema.AgenticMessagesTemplate // Extra is the extra information for the callback. Extra map[string]any } // AgenticCallbackOutput is the output for the callback. type AgenticCallbackOutput struct { - // AgenticResult is the agentic result for the callback. - AgenticResult []*schema.AgenticMessage - // AgenticTemplates is the agentic templates for the callback. - AgenticTemplates []schema.AgenticMessagesTemplate + // Result is the agentic result for the callback. + Result []*schema.AgenticMessage + // Templates is the agentic templates for the callback. + Templates []schema.AgenticMessagesTemplate // Extra is the extra information for the callback. Extra map[string]any } @@ -62,7 +62,7 @@ func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOut return t case []*schema.AgenticMessage: return &AgenticCallbackOutput{ - AgenticResult: t, + Result: t, } default: return nil diff --git a/components/prompt/agentic_callback_extra_test.go b/components/prompt/agentic_callback_extra_test.go index 6dda1a349..67982be80 100644 --- a/components/prompt/agentic_callback_extra_test.go +++ b/components/prompt/agentic_callback_extra_test.go @@ -27,7 +27,7 @@ import ( func TestConvAgenticPrompt(t *testing.T) { assert.NotNil(t, ConvAgenticCallbackInput(&AgenticCallbackInput{ Variables: map[string]any{}, - AgenticTemplates: []schema.AgenticMessagesTemplate{ + Templates: []schema.AgenticMessagesTemplate{ &schema.AgenticMessage{}, }, })) @@ -35,10 +35,10 @@ func TestConvAgenticPrompt(t *testing.T) { assert.Nil(t, ConvAgenticCallbackInput("asd")) assert.NotNil(t, ConvAgenticCallbackOutput(&AgenticCallbackOutput{ - AgenticResult: []*schema.AgenticMessage{ + Result: []*schema.AgenticMessage{ {}, }, - AgenticTemplates: []schema.AgenticMessagesTemplate{ + Templates: []schema.AgenticMessagesTemplate{ &schema.AgenticMessage{}, }, })) diff --git a/components/prompt/agentic_chat_template.go b/components/prompt/agentic_chat_template.go index c6c300d5a..41d291065 100644 --- a/components/prompt/agentic_chat_template.go +++ b/components/prompt/agentic_chat_template.go @@ -46,8 +46,8 @@ type DefaultAgenticChatTemplate struct { func (t *DefaultAgenticChatTemplate) Format(ctx context.Context, vs map[string]any, opts ...Option) (result []*schema.AgenticMessage, err error) { ctx = callbacks.EnsureRunInfo(ctx, t.GetType(), components.ComponentOfAgenticPrompt) ctx = callbacks.OnStart(ctx, &AgenticCallbackInput{ - Variables: vs, - AgenticTemplates: t.templates, + Variables: vs, + Templates: t.templates, }) defer func() { if err != nil { @@ -66,8 +66,8 @@ func (t *DefaultAgenticChatTemplate) Format(ctx context.Context, vs map[string]a } _ = callbacks.OnEnd(ctx, &AgenticCallbackOutput{ - AgenticResult: result, - AgenticTemplates: t.templates, + Result: result, + Templates: t.templates, }) return result, nil From a7c6486f46bbf9f8194d7a33524d92991726f578 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 19 Jan 2026 20:58:32 +0800 Subject: [PATCH 26/65] feat: improve callback definition (#702) --- schema/agentic_message.go | 61 +++++++------------------------- schema/agentic_message_test.go | 45 ++++++++--------------- schema/tool.go | 4 +++ utils/callbacks/template.go | 14 ++++---- utils/callbacks/template_test.go | 8 ++--- 5 files changed, 41 insertions(+), 91 deletions(-) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 743f67855..ead2d866d 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -270,19 +270,12 @@ type AssistantGenVideo struct { } type Reasoning struct { - // Summary is the reasoning content summary. - Summary []*ReasoningSummary - - // EncryptedContent is the encrypted reasoning content. - EncryptedContent string -} - -type ReasoningSummary struct { - // Index specifies the index position of this summary in the final Reasoning. - Index int - - // Text is the reasoning content summary. + // Text is either the thought summary or the raw reasoning text itself. Text string + + // Signature contains encrypted reasoning tokens. + // Required by some models when passing reasoning text back. + Signature string } type FunctionToolCall struct { @@ -1172,42 +1165,15 @@ func concatReasoning(reasons []*Reasoning) (*Reasoning, error) { ret := &Reasoning{} - var allSummaries []*ReasoningSummary for _, r := range reasons { - if r == nil { - continue + if r.Text != "" { + ret.Text += r.Text } - allSummaries = append(allSummaries, r.Summary...) - if r.EncryptedContent != "" { - ret.EncryptedContent += r.EncryptedContent + if r.Signature != "" { + ret.Signature += r.Signature } } - var ( - indices []int - indexToSummary = map[int]*ReasoningSummary{} - ) - - for _, s := range allSummaries { - if s == nil { - continue - } - if indexToSummary[s.Index] == nil { - indexToSummary[s.Index] = &ReasoningSummary{} - indices = append(indices, s.Index) - } - indexToSummary[s.Index].Text += s.Text - } - - sort.Slice(indices, func(i, j int) bool { - return indices[i] < indices[j] - }) - - ret.Summary = make([]*ReasoningSummary, 0, len(indices)) - for _, idx := range indices { - ret.Summary = append(ret.Summary, indexToSummary[idx]) - } - return ret, nil } @@ -1899,12 +1865,9 @@ func (b *ContentBlock) String() string { // String returns the string representation of Reasoning. func (r *Reasoning) String() string { sb := &strings.Builder{} - sb.WriteString(fmt.Sprintf(" summary: %d items\n", len(r.Summary))) - for _, s := range r.Summary { - sb.WriteString(fmt.Sprintf(" [%d] %s\n", s.Index, s.Text)) - } - if r.EncryptedContent != "" { - sb.WriteString(fmt.Sprintf(" encrypted_content: %s\n", truncateString(r.EncryptedContent, 50))) + sb.WriteString(fmt.Sprintf(" text: %s\n", r.Text)) + if r.Signature != "" { + sb.WriteString(fmt.Sprintf(" signature: %s\n", truncateString(r.Signature, 50))) } return sb.String() } diff --git a/schema/agentic_message_test.go b/schema/agentic_message_test.go index 144c0077e..e8a1003f5 100644 --- a/schema/agentic_message_test.go +++ b/schema/agentic_message_test.go @@ -109,9 +109,7 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeReasoning, Reasoning: &Reasoning{ - Summary: []*ReasoningSummary{ - {Index: 0, Text: "First "}, - }, + Text: "First ", }, StreamingMeta: &StreamingMeta{Index: 0}, }, @@ -123,9 +121,7 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeReasoning, Reasoning: &Reasoning{ - Summary: []*ReasoningSummary{ - {Index: 0, Text: "Second"}, - }, + Text: "Second", }, StreamingMeta: &StreamingMeta{Index: 0}, }, @@ -136,9 +132,7 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - assert.Len(t, result.ContentBlocks[0].Reasoning.Summary, 1) - assert.Equal(t, "First Second", result.ContentBlocks[0].Reasoning.Summary[0].Text) - assert.Equal(t, 0, result.ContentBlocks[0].Reasoning.Summary[0].Index) + assert.Equal(t, "First Second", result.ContentBlocks[0].Reasoning.Text) }) t.Run("concat reasoning with index", func(t *testing.T) { @@ -149,10 +143,7 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeReasoning, Reasoning: &Reasoning{ - Summary: []*ReasoningSummary{ - {Index: 0, Text: "Part1-"}, - {Index: 1, Text: "Part2-"}, - }, + Text: "Part1-", }, StreamingMeta: &StreamingMeta{Index: 0}, }, @@ -164,10 +155,7 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeReasoning, Reasoning: &Reasoning{ - Summary: []*ReasoningSummary{ - {Index: 0, Text: "Part3"}, - {Index: 1, Text: "Part4"}, - }, + Text: "Part3", }, StreamingMeta: &StreamingMeta{Index: 0}, }, @@ -178,9 +166,7 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - assert.Len(t, result.ContentBlocks[0].Reasoning.Summary, 2) - assert.Equal(t, "Part1-Part3", result.ContentBlocks[0].Reasoning.Summary[0].Text) - assert.Equal(t, "Part2-Part4", result.ContentBlocks[0].Reasoning.Summary[1].Text) + assert.Equal(t, "Part1-Part3", result.ContentBlocks[0].Reasoning.Text) }) t.Run("concat user input text", func(t *testing.T) { @@ -1292,12 +1278,10 @@ func TestAgenticMessageString(t *testing.T) { { Type: ContentBlockTypeReasoning, Reasoning: &Reasoning{ - Summary: []*ReasoningSummary{ - {Index: 0, Text: "First, I need to identify the location (New York City) from the user's query."}, - {Index: 1, Text: "Then, I should call the weather API to get current conditions."}, - {Index: 2, Text: "Finally, I'll format the response in a user-friendly way with temperature and conditions."}, - }, - EncryptedContent: "encrypted_reasoning_content_that_is_very_long_and_will_be_truncated_for_display", + Text: "First, I need to identify the location (New York City) from the user's query.\n" + + "Then, I should call the weather API to get current conditions.\n" + + "Finally, I'll format the response in a user-friendly way with temperature and conditions.", + Signature: "encrypted_reasoning_content_that_is_very_long_and_will_be_truncated_for_display", }, }, { @@ -1432,11 +1416,10 @@ content_blocks: base64_data: gen_video_data... (14 bytes) mime_type: video/mp4 [9] type: reasoning - summary: 3 items - [0] First, I need to identify the location (New York City) from the user's query. - [1] Then, I should call the weather API to get current conditions. - [2] Finally, I'll format the response in a user-friendly way with temperature and conditions. - encrypted_content: encrypted_reasoning_content_that_is_very_long_and_... + text: First, I need to identify the location (New York City) from the user's query. +Then, I should call the weather API to get current conditions. +Finally, I'll format the response in a user-friendly way with temperature and conditions. + signature: encrypted_reasoning_content_that_is_very_long_and_... [10] type: function_tool_call call_id: call_weather_123 name: get_current_weather diff --git a/schema/tool.go b/schema/tool.go index c195d1f9e..a49306047 100644 --- a/schema/tool.go +++ b/schema/tool.go @@ -62,9 +62,13 @@ const ( type AgenticToolChoice struct { // Type is the tool choice mode. Type ToolChoice + // Allowed optionally specifies the list of tools that the model is permitted to call. + // Optional. Allowed *AgenticAllowedToolChoice + // Forced optionally specifies the list of tools that the model is required to call. + // Optional. Forced *AgenticForcedToolChoice } diff --git a/utils/callbacks/template.go b/utils/callbacks/template.go index 4c73e6bbc..4c2c709da 100644 --- a/utils/callbacks/template.go +++ b/utils/callbacks/template.go @@ -187,7 +187,7 @@ func (c *handlerTemplate) OnStart(ctx context.Context, info *callbacks.RunInfo, case components.ComponentOfChatModel: return c.chatModelHandler.OnStart(ctx, info, model.ConvCallbackInput(input)) case components.ComponentOfAgenticModel: - return c.agenticModelHandler.OnStart(ctx, info, model.ConvCallbackInput(input)) + return c.agenticModelHandler.OnStart(ctx, info, model.ConvAgenticCallbackInput(input)) case components.ComponentOfEmbedding: return c.embeddingHandler.OnStart(ctx, info, embedding.ConvCallbackInput(input)) case components.ComponentOfIndexer: @@ -226,7 +226,7 @@ func (c *handlerTemplate) OnEnd(ctx context.Context, info *callbacks.RunInfo, ou case components.ComponentOfChatModel: return c.chatModelHandler.OnEnd(ctx, info, model.ConvCallbackOutput(output)) case components.ComponentOfAgenticModel: - return c.agenticModelHandler.OnEnd(ctx, info, model.ConvCallbackOutput(output)) + return c.agenticModelHandler.OnEnd(ctx, info, model.ConvAgenticCallbackOutput(output)) case components.ComponentOfEmbedding: return c.embeddingHandler.OnEnd(ctx, info, embedding.ConvCallbackOutput(output)) case components.ComponentOfIndexer: @@ -316,8 +316,8 @@ func (c *handlerTemplate) OnEndWithStreamOutput(ctx context.Context, info *callb })) case components.ComponentOfAgenticModel: return c.agenticModelHandler.OnEndWithStreamOutput(ctx, info, - schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*model.CallbackOutput, error) { - return model.ConvCallbackOutput(item), nil + schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*model.AgenticCallbackOutput, error) { + return model.ConvAgenticCallbackOutput(item), nil })) case components.ComponentOfTool: return c.toolHandler.OnEndWithStreamOutput(ctx, info, @@ -686,9 +686,9 @@ func (ch *AgenticPromptCallbackHandler) Needed(ctx context.Context, runInfo *cal // AgenticModelCallbackHandler is the handler for the agentic chat model callback. type AgenticModelCallbackHandler struct { - OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context - OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context - OnEndWithStreamOutput func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.AgenticCallbackInput) context.Context + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.AgenticCallbackOutput) context.Context + OnEndWithStreamOutput func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.AgenticCallbackOutput]) context.Context OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context } diff --git a/utils/callbacks/template_test.go b/utils/callbacks/template_test.go index f599e5300..dcc0e5c7f 100644 --- a/utils/callbacks/template_test.go +++ b/utils/callbacks/template_test.go @@ -143,15 +143,15 @@ func TestNewComponentTemplate(t *testing.T) { return ctx }).Build()). AgenticModel(&AgenticModelCallbackHandler{ - OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context { + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.AgenticCallbackInput) context.Context { cnt++ return ctx }, - OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context { + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.AgenticCallbackOutput) context.Context { cnt++ return ctx }, - OnEndWithStreamOutput: func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context { + OnEndWithStreamOutput: func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.AgenticCallbackOutput]) context.Context { output.Close() cnt++ return ctx @@ -485,7 +485,7 @@ func TestNewComponentTemplate(t *testing.T) { // Set it now tpl2.AgenticModel(&AgenticModelCallbackHandler{ - OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context { + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.AgenticCallbackInput) context.Context { return ctx }, }) From 3b481db3a5a556f007a11b814dc2ad0e6314edeb Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 19 Jan 2026 22:12:10 +0800 Subject: [PATCH 27/65] feat: agentic model support MaxTokens (#703) --- components/model/agentic_callback_extra.go | 2 ++ components/model/option.go | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/components/model/agentic_callback_extra.go b/components/model/agentic_callback_extra.go index 54d49ff72..9a769cf7e 100644 --- a/components/model/agentic_callback_extra.go +++ b/components/model/agentic_callback_extra.go @@ -25,6 +25,8 @@ import ( type AgenticConfig struct { // Model is the model name. Model string + // MaxTokens is the max number of output tokens, if reached the max tokens, the model will stop generating. + MaxTokens int // Temperature is the temperature, which controls the randomness of the agentic model. Temperature float32 // TopP is the top p, which controls the diversity of the agentic model. diff --git a/components/model/option.go b/components/model/option.go index a337b7af2..a46b71b19 100644 --- a/components/model/option.go +++ b/components/model/option.go @@ -28,13 +28,13 @@ type Options struct { TopP *float32 // Tools is a list of tools the model may call. Tools []*schema.ToolInfo + // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return a finish reason of "length". + MaxTokens *int // Options only available for chat model. // ToolChoice controls which tool is called by the model. ToolChoice *schema.ToolChoice - // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return an finish reason of "length". - MaxTokens *int // AllowedToolNames specifies a list of tool names that the model is allowed to call. // This allows for constraining the model to a specific subset of the available tools. AllowedToolNames []string From 5e29f6fa7f1e05e094da4044d0e46e10bcbe1250 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Tue, 20 Jan 2026 13:21:40 +0800 Subject: [PATCH 28/65] feat: agentic model support stop option --- components/model/option.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/components/model/option.go b/components/model/option.go index a46b71b19..936b0fbda 100644 --- a/components/model/option.go +++ b/components/model/option.go @@ -30,6 +30,8 @@ type Options struct { Tools []*schema.ToolInfo // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return a finish reason of "length". MaxTokens *int + // Stop is the stop words for the model, which controls the stopping condition of the model. + Stop []string // Options only available for chat model. @@ -38,8 +40,6 @@ type Options struct { // AllowedToolNames specifies a list of tool names that the model is allowed to call. // This allows for constraining the model to a specific subset of the available tools. AllowedToolNames []string - // Stop is the stop words for the model, which controls the stopping condition of the model. - Stop []string // Options only available for agentic model. @@ -67,7 +67,6 @@ func WithTemperature(temperature float32) Option { } // WithMaxTokens is the option to set the max tokens for the model. -// Only available for ChatModel. func WithMaxTokens(maxTokens int) Option { return Option{ apply: func(opts *Options) { @@ -95,7 +94,6 @@ func WithTopP(topP float32) Option { } // WithStop is the option to set the stop words for the model. -// Only available for ChatModel. func WithStop(stop []string) Option { return Option{ apply: func(opts *Options) { From 1840123edb2667e724372d0f970528c96d8ea3a7 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Fri, 13 Mar 2026 10:51:24 +0800 Subject: [PATCH 29/65] feat: add json tag for agentic message (#880) --- schema/agentic_message.go | 198 ++++++++++++++++++------------------ utils/callbacks/template.go | 2 +- 2 files changed, 100 insertions(+), 100 deletions(-) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index ead2d866d..95e14c0df 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -66,356 +66,356 @@ const ( type AgenticMessage struct { // Role is the message role. - Role AgenticRoleType + Role AgenticRoleType `json:"role"` // ContentBlocks is the list of content blocks. - ContentBlocks []*ContentBlock + ContentBlocks []*ContentBlock `json:"content_blocks,omitempty"` // ResponseMeta is the response metadata. - ResponseMeta *AgenticResponseMeta + ResponseMeta *AgenticResponseMeta `json:"response_meta,omitempty"` // Extra is the additional information. - Extra map[string]any + Extra map[string]any `json:"extra,omitempty"` } type AgenticResponseMeta struct { // TokenUsage is the token usage. - TokenUsage *TokenUsage + TokenUsage *TokenUsage `json:"token_usage,omitempty"` // OpenAIExtension is the extension for OpenAI. - OpenAIExtension *openai.ResponseMetaExtension + OpenAIExtension *openai.ResponseMetaExtension `json:"openai_extension,omitempty"` // GeminiExtension is the extension for Gemini. - GeminiExtension *gemini.ResponseMetaExtension + GeminiExtension *gemini.ResponseMetaExtension `json:"gemini_extension,omitempty"` // ClaudeExtension is the extension for Claude. - ClaudeExtension *claude.ResponseMetaExtension + ClaudeExtension *claude.ResponseMetaExtension `json:"claude_extension,omitempty"` // Extension is the extension for other models, supplied by the component implementer. - Extension any + Extension any `json:"extension,omitempty"` } type ContentBlock struct { - Type ContentBlockType + Type ContentBlockType `json:"type"` // Reasoning contains the reasoning content generated by the model. - Reasoning *Reasoning + Reasoning *Reasoning `json:"reasoning,omitempty"` // UserInputText contains the text content provided by the user. - UserInputText *UserInputText + UserInputText *UserInputText `json:"user_input_text,omitempty"` // UserInputImage contains the image content provided by the user. - UserInputImage *UserInputImage + UserInputImage *UserInputImage `json:"user_input_image,omitempty"` // UserInputAudio contains the audio content provided by the user. - UserInputAudio *UserInputAudio + UserInputAudio *UserInputAudio `json:"user_input_audio,omitempty"` // UserInputVideo contains the video content provided by the user. - UserInputVideo *UserInputVideo + UserInputVideo *UserInputVideo `json:"user_input_video,omitempty"` // UserInputFile contains the file content provided by the user. - UserInputFile *UserInputFile + UserInputFile *UserInputFile `json:"user_input_file,omitempty"` // AssistantGenText contains the text content generated by the model. - AssistantGenText *AssistantGenText + AssistantGenText *AssistantGenText `json:"assistant_gen_text,omitempty"` // AssistantGenImage contains the image content generated by the model. - AssistantGenImage *AssistantGenImage + AssistantGenImage *AssistantGenImage `json:"assistant_gen_image,omitempty"` // AssistantGenAudio contains the audio content generated by the model. - AssistantGenAudio *AssistantGenAudio + AssistantGenAudio *AssistantGenAudio `json:"assistant_gen_audio,omitempty"` // AssistantGenVideo contains the video content generated by the model. - AssistantGenVideo *AssistantGenVideo + AssistantGenVideo *AssistantGenVideo `json:"assistant_gen_video,omitempty"` // FunctionToolCall contains the invocation details for a user-defined tool. - FunctionToolCall *FunctionToolCall + FunctionToolCall *FunctionToolCall `json:"function_tool_call,omitempty"` // FunctionToolResult contains the result returned from a user-defined tool call. - FunctionToolResult *FunctionToolResult + FunctionToolResult *FunctionToolResult `json:"function_tool_result,omitempty"` // ServerToolCall contains the invocation details for a provider built-in tool executed on the model server. - ServerToolCall *ServerToolCall + ServerToolCall *ServerToolCall `json:"server_tool_call,omitempty"` // ServerToolResult contains the result returned from a provider built-in tool executed on the model server. - ServerToolResult *ServerToolResult + ServerToolResult *ServerToolResult `json:"server_tool_result,omitempty"` // MCPToolCall contains the invocation details for an MCP tool managed by the model server. - MCPToolCall *MCPToolCall + MCPToolCall *MCPToolCall `json:"mcp_tool_call,omitempty"` // MCPToolResult contains the result returned from an MCP tool managed by the model server. - MCPToolResult *MCPToolResult + MCPToolResult *MCPToolResult `json:"mcp_tool_result,omitempty"` // MCPListToolsResult contains the list of available MCP tools reported by the model server. - MCPListToolsResult *MCPListToolsResult + MCPListToolsResult *MCPListToolsResult `json:"mcp_list_tools_result,omitempty"` // MCPToolApprovalRequest contains the user approval request for an MCP tool call when required. - MCPToolApprovalRequest *MCPToolApprovalRequest + MCPToolApprovalRequest *MCPToolApprovalRequest `json:"mcp_tool_approval_request,omitempty"` // MCPToolApprovalResponse contains the user's approval decision for an MCP tool call. - MCPToolApprovalResponse *MCPToolApprovalResponse + MCPToolApprovalResponse *MCPToolApprovalResponse `json:"mcp_tool_approval_response,omitempty"` // StreamingMeta contains metadata for streaming responses. - StreamingMeta *StreamingMeta + StreamingMeta *StreamingMeta `json:"streaming_meta,omitempty"` // Extra contains additional information for the content block. - Extra map[string]any + Extra map[string]any `json:"extra,omitempty"` } type StreamingMeta struct { // Index specifies the index position of this block in the final response. - Index int + Index int `json:"index"` } type UserInputText struct { // Text is the text content. - Text string + Text string `json:"text,omitempty"` } type UserInputImage struct { // URL is the HTTP/HTTPS link. - URL string + URL string `json:"url,omitempty"` // Base64Data is the binary data in Base64 encoded string format. - Base64Data string + Base64Data string `json:"base64_data,omitempty"` // MIMEType is the mime type, e.g. "image/png". - MIMEType string + MIMEType string `json:"mime_type,omitempty"` // Detail is the quality of the image url. - Detail ImageURLDetail + Detail ImageURLDetail `json:"detail,omitempty"` } type UserInputAudio struct { // URL is the HTTP/HTTPS link. - URL string + URL string `json:"url,omitempty"` // Base64Data is the binary data in Base64 encoded string format. - Base64Data string + Base64Data string `json:"base64_data,omitempty"` // MIMEType is the mime type, e.g. "audio/wav". - MIMEType string + MIMEType string `json:"mime_type,omitempty"` } type UserInputVideo struct { // URL is the HTTP/HTTPS link. - URL string + URL string `json:"url,omitempty"` // Base64Data is the binary data in Base64 encoded string format. - Base64Data string + Base64Data string `json:"base64_data,omitempty"` // MIMEType is the mime type, e.g. "video/mp4". - MIMEType string + MIMEType string `json:"mime_type,omitempty"` } type UserInputFile struct { // URL is the HTTP/HTTPS link. - URL string + URL string `json:"url,omitempty"` // Name is the filename. - Name string + Name string `json:"name,omitempty"` // Base64Data is the binary data in Base64 encoded string format. - Base64Data string + Base64Data string `json:"base64_data,omitempty"` // MIMEType is the mime type, e.g. "application/pdf". - MIMEType string + MIMEType string `json:"mime_type,omitempty"` } type AssistantGenText struct { // Text is the generated text. - Text string + Text string `json:"text,omitempty"` // OpenAIExtension is the extension for OpenAI. - OpenAIExtension *openai.AssistantGenTextExtension + OpenAIExtension *openai.AssistantGenTextExtension `json:"openai_extension,omitempty"` // ClaudeExtension is the extension for Claude. - ClaudeExtension *claude.AssistantGenTextExtension + ClaudeExtension *claude.AssistantGenTextExtension `json:"claude_extension,omitempty"` // Extension is the extension for other models, supplied by the component implementer. - Extension any + Extension any `json:"extension,omitempty"` } type AssistantGenImage struct { // URL is the HTTP/HTTPS link. - URL string + URL string `json:"url,omitempty"` // Base64Data is the binary data in Base64 encoded string format. - Base64Data string + Base64Data string `json:"base64_data,omitempty"` // MIMEType is the mime type, e.g. "image/png". - MIMEType string + MIMEType string `json:"mime_type,omitempty"` } type AssistantGenAudio struct { // URL is the HTTP/HTTPS link. - URL string + URL string `json:"url,omitempty"` // Base64Data is the binary data in Base64 encoded string format. - Base64Data string + Base64Data string `json:"base64_data,omitempty"` // MIMEType is the mime type, e.g. "audio/wav". - MIMEType string + MIMEType string `json:"mime_type,omitempty"` } type AssistantGenVideo struct { // URL is the HTTP/HTTPS link. - URL string + URL string `json:"url,omitempty"` // Base64Data is the binary data in Base64 encoded string format. - Base64Data string + Base64Data string `json:"base64_data,omitempty"` // MIMEType is the mime type, e.g. "video/mp4". - MIMEType string + MIMEType string `json:"mime_type,omitempty"` } type Reasoning struct { // Text is either the thought summary or the raw reasoning text itself. - Text string + Text string `json:"text,omitempty"` // Signature contains encrypted reasoning tokens. // Required by some models when passing reasoning text back. - Signature string + Signature string `json:"signature,omitempty"` } type FunctionToolCall struct { // CallID is the unique identifier for the tool call. - CallID string + CallID string `json:"call_id,omitempty"` // Name specifies the function tool invoked. - Name string + Name string `json:"name"` // Arguments is the JSON string arguments for the function tool call. - Arguments string + Arguments string `json:"arguments,omitempty"` } type FunctionToolResult struct { // CallID is the unique identifier for the tool call. - CallID string + CallID string `json:"call_id,omitempty"` // Name specifies the function tool invoked. - Name string + Name string `json:"name"` // Result is the function tool result returned by the user - Result string + Result string `json:"result,omitempty"` } type ServerToolCall struct { // Name specifies the server-side tool invoked. // Supplied by the model server (e.g., `web_search` for OpenAI, `googleSearch` for Gemini). - Name string + Name string `json:"name"` // CallID is the unique identifier for the tool call. // Empty if not provided by the model server. - CallID string + CallID string `json:"call_id,omitempty"` // Arguments are the raw inputs to the server-side tool, // supplied by the component implementer. - Arguments any + Arguments any `json:"arguments,omitempty"` } type ServerToolResult struct { // Name specifies the server-side tool invoked. // Supplied by the model server (e.g., `web_search` for OpenAI, `googleSearch` for Gemini). - Name string + Name string `json:"name"` // CallID is the unique identifier for the tool call. // Empty if not provided by the model server. - CallID string + CallID string `json:"call_id,omitempty"` // Result refers to the raw output generated by the server-side tool, // supplied by the component implementer. - Result any + Result any `json:"result,omitempty"` } type MCPToolCall struct { // ServerLabel is the MCP server label used to identify it in tool calls - ServerLabel string + ServerLabel string `json:"server_label,omitempty"` // ApprovalRequestID is the approval request ID. - ApprovalRequestID string + ApprovalRequestID string `json:"approval_request_id,omitempty"` // CallID is the unique ID of the tool call. - CallID string + CallID string `json:"call_id,omitempty"` // Name is the name of the tool to run. - Name string + Name string `json:"name"` // Arguments is the JSON string arguments for the tool call. - Arguments string + Arguments string `json:"arguments,omitempty"` } type MCPToolResult struct { // ServerLabel is the MCP server label used to identify it in tool calls - ServerLabel string + ServerLabel string `json:"server_label,omitempty"` // CallID is the unique ID of the tool call. - CallID string + CallID string `json:"call_id,omitempty"` // Name is the name of the tool to run. - Name string + Name string `json:"name"` // Result is the JSON string with the tool result. - Result string + Result string `json:"result,omitempty"` // Error returned when the server fails to run the tool. - Error *MCPToolCallError + Error *MCPToolCallError `json:"error,omitempty"` } type MCPToolCallError struct { // Code is the error code. - Code *int64 + Code *int64 `json:"code,omitempty"` // Message is the error message. - Message string + Message string `json:"message,omitempty"` } type MCPListToolsResult struct { // ServerLabel is the MCP server label used to identify it in tool calls. - ServerLabel string + ServerLabel string `json:"server_label,omitempty"` // Tools is the list of tools available on the server. - Tools []*MCPListToolsItem + Tools []*MCPListToolsItem `json:"tools,omitempty"` // Error returned when the server fails to list tools. - Error string + Error string `json:"error,omitempty"` } type MCPListToolsItem struct { // Name is the name of the tool. - Name string + Name string `json:"name"` // Description is the description of the tool. - Description string + Description string `json:"description"` // InputSchema is the JSON schema that describes the tool input parameters. - InputSchema *jsonschema.Schema + InputSchema *jsonschema.Schema `json:"input_schema,omitempty"` } type MCPToolApprovalRequest struct { // ID is the approval request ID. - ID string + ID string `json:"id,omitempty"` // Name is the name of the tool to run. - Name string + Name string `json:"name"` // Arguments is the JSON string arguments for the tool call. - Arguments string + Arguments string `json:"arguments,omitempty"` // ServerLabel is the MCP server label used to identify it in tool calls. - ServerLabel string + ServerLabel string `json:"server_label,omitempty"` } type MCPToolApprovalResponse struct { // ApprovalRequestID is the approval request ID being responded to. - ApprovalRequestID string + ApprovalRequestID string `json:"approval_request_id,omitempty"` // Approve indicates whether the request is approved. - Approve bool + Approve bool `json:"approve"` // Reason is the rationale for the decision. // Optional. - Reason string + Reason string `json:"reason,omitempty"` } // SystemAgenticMessage represents a message with AgenticRoleType "system". diff --git a/utils/callbacks/template.go b/utils/callbacks/template.go index 4c2c709da..f01a849b6 100644 --- a/utils/callbacks/template.go +++ b/utils/callbacks/template.go @@ -64,10 +64,10 @@ type HandlerHelper struct { transformerHandler *TransformerCallbackHandler toolHandler *ToolCallbackHandler toolsNodeHandler *ToolsNodeCallbackHandlers + agentHandler *AgentCallbackHandler agenticPromptHandler *AgenticPromptCallbackHandler agenticModelHandler *AgenticModelCallbackHandler agenticToolsNodeHandler *AgenticToolsNodeCallbackHandlers - agentHandler *AgentCallbackHandler composeTemplates map[components.Component]callbacks.Handler } From 927da7a6e0fe97f4405117310569c15a0c202da4 Mon Sep 17 00:00:00 2001 From: Ryo Date: Fri, 13 Mar 2026 14:09:00 +0800 Subject: [PATCH 30/65] =?UTF-8?q?feat(adk):=20add=20agentmd=20middleware?= =?UTF-8?q?=20for=20auto-injecting=20Agents.md=20into=20m=E2=80=A6=20(#882?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit feat(adk): add agentmd middleware for auto-injecting Agents.md into model input Change-Id: I34add4f925a23c6d6821925c482a21f6cddfddd4 --- adk/middlewares/agentsmd/agentsmd.go | 183 +++ adk/middlewares/agentsmd/agentsmd_test.go | 1420 +++++++++++++++++++++ adk/middlewares/agentsmd/loader.go | 299 +++++ 3 files changed, 1902 insertions(+) create mode 100644 adk/middlewares/agentsmd/agentsmd.go create mode 100644 adk/middlewares/agentsmd/agentsmd_test.go create mode 100644 adk/middlewares/agentsmd/loader.go diff --git a/adk/middlewares/agentsmd/agentsmd.go b/adk/middlewares/agentsmd/agentsmd.go new file mode 100644 index 000000000..7d29896a7 --- /dev/null +++ b/adk/middlewares/agentsmd/agentsmd.go @@ -0,0 +1,183 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package agentsmd provides a middleware that automatically injects Agents.md +// file contents into model input messages. The injection is transient — content +// is prepended at model call time and never persisted to conversation state, +// so it is naturally excluded from summarization / compression. +package agentsmd + +import ( + "context" + "fmt" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +// Config defines the configuration for the agentsmd middleware. +type Config struct { + // Backend provides file access for loading Agents.md files. + // Implementations can use local filesystem, remote storage, or any other backend. + // Required. + Backend Backend + + // AgentsMDFiles specifies the ordered list of Agents.md file paths to load. + // Files are loaded and injected in the given order. + // Supports @import syntax inside files for recursive inclusion (max depth 5). + AgentsMDFiles []string + + // AllAgentsMDMaxBytes limits the total byte size of all loaded Agents.md content. + // Files are loaded in order; once the cumulative size exceeds this limit, + // remaining files are skipped. Each individual file is always loaded in full. + // 0 means no limit. + AllAgentsMDMaxBytes int + + // OnLoadWarning is an optional callback invoked when a non-fatal error occurs + // during Agents.md file loading (e.g. file not found, circular @import, depth + // exceeded). If nil, warnings are logged via log.Printf. + // + // Note: Backend.Read errors other than os.ErrNotExist (e.g. permission denied, + // I/O errors) are NOT treated as warnings and will abort the loading process. + OnLoadWarning func(filePath string, err error) +} + +// New creates an agentsmd middleware that injects Agents.md content into every +// model call. The content is loaded from the configured file paths via Backend +// on each model invocation. +// +// Recommended: place this middleware AFTER the summarization middleware, so that +// Agents.md content is excluded from summarization/compression. +func New(_ context.Context, cfg *Config) (adk.ChatModelAgentMiddleware, error) { + if err := cfg.validate(); err != nil { + return nil, err + } + + return &middleware{ + BaseChatModelAgentMiddleware: &adk.BaseChatModelAgentMiddleware{}, + loader: newLoaderConfig(cfg.Backend, cfg.AgentsMDFiles, cfg.AllAgentsMDMaxBytes, cfg.OnLoadWarning), + }, nil +} + +type middleware struct { + *adk.BaseChatModelAgentMiddleware + loader *loaderConfig +} + +// WrapModel returns a proxy model that prepends Agents.md content to the input +// messages on every Generate/Stream call. The injected message is never written +// back to ChatModelAgentState, so summarization and reduction middlewares are +// unaffected. +func (m *middleware) WrapModel(_ context.Context, cm model.BaseChatModel, _ *adk.ModelContext) (model.BaseChatModel, error) { + return &agentMDModel{ + inner: cm, + loader: m.loader, + }, nil +} + +// agentMDModel wraps a BaseChatModel to prepend Agents.md content to input. +type agentMDModel struct { + inner model.BaseChatModel + loader *loaderConfig +} + +func (m *agentMDModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + messages, err := m.prependAgentMD(ctx, input) + if err != nil { + return nil, err + } + return m.inner.Generate(ctx, messages, opts...) +} + +func (m *agentMDModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + messages, err := m.prependAgentMD(ctx, input) + if err != nil { + return nil, err + } + return m.inner.Stream(ctx, messages, opts...) +} + +const agentsMDCacheKey = "__agentsmd_content_cache__" + +// prependAgentMD loads the current Agents.md content and inserts it before the +// first User role message. If all configured agent files are empty (or skipped), +// the original input is returned unchanged. +// The loaded content is cached in RunLocalValue for the duration of the agent Run(). +func (m *agentMDModel) prependAgentMD(ctx context.Context, input []*schema.Message) ([]*schema.Message, error) { + var content string + + // Try to get cached content from RunLocalValue. + if cached, found, err := adk.GetRunLocalValue(ctx, agentsMDCacheKey); err == nil && found { + if s, ok := cached.(string); ok { + content = s + } + } + + if content == "" { + var err error + content, err = m.loader.load(ctx) + if err != nil { + return nil, fmt.Errorf("[agentsmd]: failed to load agent files: %w", err) + } + // Cache the loaded content for subsequent model calls in this Run(). + if content != "" { + _ = adk.SetRunLocalValue(ctx, agentsMDCacheKey, content) + } + } + if content == "" { + return input, nil + } + + agentMDMsg := &schema.Message{ + Role: schema.User, + Content: content, + } + + // Insert agentMDMsg before the first User role message. + messages := make([]*schema.Message, 0, len(input)+1) + inserted := false + for i, msg := range input { + if !inserted && msg.Role == schema.User { + messages = append(messages, agentMDMsg) + messages = append(messages, input[i:]...) + inserted = true + break + } + messages = append(messages, msg) + } + if !inserted { + // No User message found; append at the end as fallback. + messages = append(messages, agentMDMsg) + } + return messages, nil +} + +func (c *Config) validate() error { + if c == nil { + return fmt.Errorf("[agentsmd]: config is required") + } + if c.Backend == nil { + return fmt.Errorf("[agentsmd]: backend is required") + } + if len(c.AgentsMDFiles) == 0 { + return fmt.Errorf("[agentsmd]: at least one agent file path is required") + } + if c.AllAgentsMDMaxBytes < 0 { + return fmt.Errorf("[agentsmd]: AllAgentMDDocsMaxBytes must be non-negative") + } + return nil +} diff --git a/adk/middlewares/agentsmd/agentsmd_test.go b/adk/middlewares/agentsmd/agentsmd_test.go new file mode 100644 index 000000000..e3d7e00e9 --- /dev/null +++ b/adk/middlewares/agentsmd/agentsmd_test.go @@ -0,0 +1,1420 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package agentsmd + +import ( + "context" + "fmt" + "os" + "strings" + "testing" + + "github.com/cloudwego/eino/adk/filesystem" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +// --- test helpers --- + +type mockModel struct { + lastInput []*schema.Message +} + +func (m *mockModel) Generate(_ context.Context, input []*schema.Message, _ ...model.Option) (*schema.Message, error) { + m.lastInput = input + return &schema.Message{Role: schema.Assistant, Content: "ok"}, nil +} + +func (m *mockModel) Stream(_ context.Context, input []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + m.lastInput = input + return nil, nil +} + +type memBackend struct { + files map[string]string +} + +func newMemBackend() *memBackend { + return &memBackend{files: make(map[string]string)} +} + +func (b *memBackend) set(path string, content string) { + b.files[path] = content +} + +func (b *memBackend) Read(_ context.Context, req *ReadRequest) (*filesystem.FileContent, error) { + content, ok := b.files[req.FilePath] + if !ok { + return nil, fmt.Errorf("file not found: %s: %w", req.FilePath, os.ErrNotExist) + } + return &filesystem.FileContent{Content: content}, nil +} + +// errBackend always returns a non-ErrNotExist error on Read, simulating I/O failures. +type errBackend struct{} + +func (b *errBackend) Read(_ context.Context, req *ReadRequest) (*filesystem.FileContent, error) { + return nil, fmt.Errorf("permission denied: %s", req.FilePath) +} + +// partialErrBackend returns content for known files and I/O error for others. +type partialErrBackend struct { + files map[string]string +} + +func (b *partialErrBackend) Read(_ context.Context, req *ReadRequest) (*filesystem.FileContent, error) { + content, ok := b.files[req.FilePath] + if !ok { + return nil, fmt.Errorf("I/O error reading %s", req.FilePath) + } + return &filesystem.FileContent{Content: content}, nil +} + +// --- tests --- + +func TestNew_Validation(t *testing.T) { + ctx := context.Background() + b := newMemBackend() + + _, err := New(ctx, nil) + if err == nil { + t.Fatal("expected error for nil config") + } + + _, err = New(ctx, &Config{}) + if err == nil { + t.Fatal("expected error for empty config") + } + + _, err = New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/test.md"}, AllAgentsMDMaxBytes: -1}) + if err == nil { + t.Fatal("expected error for negative max bytes") + } + + _, err = New(ctx, &Config{AgentsMDFiles: []string{"/test.md"}}) + if err == nil { + t.Fatal("expected error for nil backend") + } +} + +func TestMiddleware_BasicInjection(t *testing.T) { + b := newMemBackend() + b.set("/agent.md", "You are a helpful assistant.") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + userMsg := &schema.Message{Role: schema.User, Content: "hello"} + if _, err = wrapped.Generate(ctx, []*schema.Message{userMsg}); err != nil { + t.Fatal(err) + } + + if len(mock.lastInput) != 2 { + t.Fatalf("expected 2 messages, got %d", len(mock.lastInput)) + } + if mock.lastInput[0].Role != schema.User { + t.Fatalf("expected first message role User, got %s", mock.lastInput[0].Role) + } + if !strings.Contains(mock.lastInput[0].Content, "You are a helpful assistant.") { + t.Fatalf("expected agent.md content in first message, got %q", mock.lastInput[0].Content) + } + if !strings.Contains(mock.lastInput[0].Content, "") { + t.Fatalf("expected system-reminder tag, got %q", mock.lastInput[0].Content) + } + if mock.lastInput[1].Content != "hello" { + t.Fatalf("expected original message preserved, got %q", mock.lastInput[1].Content) + } +} + +func TestMiddleware_MultipleFiles(t *testing.T) { + b := newMemBackend() + b.set("/a.md", "instruction A") + b.set("/b.md", "instruction B") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/a.md", "/b.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + if _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}); err != nil { + t.Fatal(err) + } + + content := mock.lastInput[0].Content + idxA := strings.Index(content, "instruction A") + idxB := strings.Index(content, "instruction B") + if idxA < 0 || idxB < 0 { + t.Fatalf("both files should be included, content: %q", content) + } + if idxA >= idxB { + t.Fatal("file A should appear before file B") + } +} + +func TestMiddleware_ImportResolution(t *testing.T) { + b := newMemBackend() + b.set("/project/agent.md", "main instructions\n@sub/rules.md\nend") + b.set("/project/sub/rules.md", "imported rule") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/project/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + if _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}); err != nil { + t.Fatal(err) + } + + content := mock.lastInput[0].Content + // Original text should be preserved with @path intact. + if !strings.Contains(content, "main instructions") { + t.Fatalf("should contain original text, got %q", content) + } + if !strings.Contains(content, "@sub/rules.md") { + t.Fatalf("@import reference should be preserved in original text, got %q", content) + } + if !strings.Contains(content, "end") { + t.Fatalf("should contain original trailing text, got %q", content) + } + // Imported file should appear as a separate section. + if !strings.Contains(content, "Contents of /project/sub/rules.md") { + t.Fatalf("imported file should have its own section, got %q", content) + } + if !strings.Contains(content, "imported rule") { + t.Fatalf("imported file content should be present, got %q", content) + } +} + +func TestMiddleware_RecursiveImport(t *testing.T) { + b := newMemBackend() + b.set("/a.md", "top\n@/b.md") + b.set("/b.md", "middle\n@/c.md") + b.set("/c.md", "leaf content") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/a.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + if _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}); err != nil { + t.Fatal(err) + } + + content := mock.lastInput[0].Content + // All three files should appear as separate sections. + for _, section := range []string{"Contents of /a.md", "Contents of /b.md", "Contents of /c.md"} { + if !strings.Contains(content, section) { + t.Fatalf("expected section %q in content, got %q", section, content) + } + } + for _, text := range []string{"top", "middle", "leaf content"} { + if !strings.Contains(content, text) { + t.Fatalf("expected %q in content, got %q", text, content) + } + } + // Sections should appear in order: a, b, c. + idxA := strings.Index(content, "Contents of /a.md") + idxB := strings.Index(content, "Contents of /b.md") + idxC := strings.Index(content, "Contents of /c.md") + if !(idxA < idxB && idxB < idxC) { + t.Fatalf("sections should appear in order a < b < c, got a=%d b=%d c=%d", idxA, idxB, idxC) + } +} + +func TestMiddleware_MaxImportDepth(t *testing.T) { + b := newMemBackend() + for i := 0; i < 7; i++ { + var content string + if i < 6 { + content = fmt.Sprintf("level %d\n@/level%d.md", i, i+1) + } else { + content = fmt.Sprintf("level %d", i) + } + b.set(fmt.Sprintf("/level%d.md", i), content) + } + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/level0.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + // Import failure at depth > 5 is logged, not returned as error. + _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}) + if err != nil { + t.Fatalf("expected no error (depth exceeded is logged), got %v", err) + } + // Levels 0-5 should be present as sections; level 6 fails silently. + content := mock.lastInput[0].Content + for i := 0; i <= 5; i++ { + want := fmt.Sprintf("Contents of /level%d.md", i) + if !strings.Contains(content, want) { + t.Fatalf("expected %q in content, got %q", want, content) + } + } + if strings.Contains(content, "Contents of /level6.md") { + t.Fatalf("level6 should not be present (depth exceeded), got %q", content) + } +} + +func TestMiddleware_CircularImport(t *testing.T) { + b := newMemBackend() + b.set("/a.md", "@/b.md") + b.set("/b.md", "@/a.md") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/a.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + // Circular import failure is logged, not returned as error. + _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}) + if err != nil { + t.Fatalf("expected no error (circular import is logged), got %v", err) + } + // /a.md and /b.md should both be present; the circular ref from b->a is skipped. + content := mock.lastInput[0].Content + if !strings.Contains(content, "Contents of /a.md") { + t.Fatalf("expected /a.md section, got %q", content) + } + if !strings.Contains(content, "Contents of /b.md") { + t.Fatalf("expected /b.md section, got %q", content) + } +} + +func TestMiddleware_MaxBytesLimit(t *testing.T) { + b := newMemBackend() + b.set("/a.md", "AAAA") // 4 bytes + b.set("/b.md", "BBBB") // 4 bytes + + ctx := context.Background() + mw, err := New(ctx, &Config{ + Backend: b, + AgentsMDFiles: []string{"/a.md", "/b.md"}, + AllAgentsMDMaxBytes: 5, // file a (4) fits, file b (4) would exceed + }) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + if _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}); err != nil { + t.Fatal(err) + } + + content := mock.lastInput[0].Content + if !strings.Contains(content, "AAAA") { + t.Fatal("first file should be included") + } + if strings.Contains(content, "BBBB") { + t.Fatal("second file should be excluded due to max bytes") + } +} + +func TestMiddleware_NotPersistedInState(t *testing.T) { + b := newMemBackend() + b.set("/agent.md", "agent instructions") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + originalMsgs := []*schema.Message{{Role: schema.User, Content: "hello"}} + if _, err = wrapped.Generate(ctx, originalMsgs); err != nil { + t.Fatal(err) + } + + if len(originalMsgs) != 1 { + t.Fatalf("original messages should not be modified, got %d messages", len(originalMsgs)) + } + if originalMsgs[0].Content != "hello" { + t.Fatalf("original message should be unchanged, got %q", originalMsgs[0].Content) + } + if len(mock.lastInput) != 2 { + t.Fatalf("model should receive 2 messages, got %d", len(mock.lastInput)) + } +} + +func TestMiddleware_AbsoluteImportPath(t *testing.T) { + b := newMemBackend() + b.set("/project/main.md", "start\n@/shared/imported.md\nend") + b.set("/shared/imported.md", "absolute import content") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/project/main.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + if _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}); err != nil { + t.Fatal(err) + } + + content := mock.lastInput[0].Content + // @path preserved in original text. + if !strings.Contains(content, "@/shared/imported.md") { + t.Fatalf("@import reference should be preserved, got %q", content) + } + // Imported content in separate section. + if !strings.Contains(content, "Contents of /shared/imported.md") { + t.Fatalf("expected separate section for imported file, got %q", content) + } + if !strings.Contains(content, "absolute import content") { + t.Fatalf("expected absolute import content, got %q", content) + } +} + +func TestMiddleware_ImportAsSeparateSection(t *testing.T) { + b := newMemBackend() + b.set("/project/agent.md", "Please read @sub/rules.md and also @sub/style.md for guidance.") + b.set("/project/sub/rules.md", "RULE_CONTENT") + b.set("/project/sub/style.md", "STYLE_CONTENT") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/project/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + if _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}); err != nil { + t.Fatal(err) + } + + content := mock.lastInput[0].Content + // Original text preserved with @paths intact. + if !strings.Contains(content, "Please read @sub/rules.md and also @sub/style.md for guidance.") { + t.Fatalf("original text with @paths should be preserved, got %q", content) + } + // Imported files appear as separate sections. + if !strings.Contains(content, "Contents of /project/sub/rules.md") { + t.Fatalf("expected rules.md section, got %q", content) + } + if !strings.Contains(content, "RULE_CONTENT") { + t.Fatalf("expected imported rule content, got %q", content) + } + if !strings.Contains(content, "Contents of /project/sub/style.md") { + t.Fatalf("expected style.md section, got %q", content) + } + if !strings.Contains(content, "STYLE_CONTENT") { + t.Fatalf("expected imported style content, got %q", content) + } + + // Sections should be ordered: agent.md, rules.md, style.md. + idxAgent := strings.Index(content, "Contents of /project/agent.md") + idxRules := strings.Index(content, "Contents of /project/sub/rules.md") + idxStyle := strings.Index(content, "Contents of /project/sub/style.md") + if !(idxAgent < idxRules && idxRules < idxStyle) { + t.Fatalf("sections should appear in order agent < rules < style, got agent=%d rules=%d style=%d", idxAgent, idxRules, idxStyle) + } +} + +// --- loader-specific tests --- + +func TestLoader_NoImportsPassthrough(t *testing.T) { + // Content without any @path should be returned as-is in its section. + b := newMemBackend() + b.set("/agent.md", "plain text without imports\nline two") + + l := newLoaderConfig(b, []string{"/agent.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(content, "plain text without imports") { + t.Fatalf("expected plain content, got %q", content) + } + if !strings.Contains(content, "line two") { + t.Fatalf("expected second line, got %q", content) + } +} + +func TestLoader_ImportAsSeparateSection(t *testing.T) { + // @path in the middle of a sentence should be preserved; imported file is a separate section. + b := newMemBackend() + b.set("/doc.md", "before @/snippet.md after") + b.set("/snippet.md", "INJECTED") + + l := newLoaderConfig(b, []string{"/doc.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + // Original text preserved. + if !strings.Contains(content, "before @/snippet.md after") { + t.Fatalf("original text should be preserved with @path, got %q", content) + } + // Imported file in separate section. + if !strings.Contains(content, "Contents of /snippet.md") { + t.Fatalf("expected separate section for snippet.md, got %q", content) + } + if !strings.Contains(content, "INJECTED") { + t.Fatalf("expected imported content, got %q", content) + } +} + +func TestLoader_MultipleImportsSameLine(t *testing.T) { + // Multiple @path on one line should each get a separate section. + b := newMemBackend() + b.set("/doc.md", "see @/a.txt and @/b.txt here") + b.set("/a.txt", "AAA") + b.set("/b.txt", "BBB") + + l := newLoaderConfig(b, []string{"/doc.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + // Original text preserved. + if !strings.Contains(content, "see @/a.txt and @/b.txt here") { + t.Fatalf("original text should be preserved, got %q", content) + } + // Each imported file has its own section. + if !strings.Contains(content, "Contents of /a.txt") { + t.Fatalf("expected section for a.txt, got %q", content) + } + if !strings.Contains(content, "AAA") { + t.Fatalf("expected a.txt content, got %q", content) + } + if !strings.Contains(content, "Contents of /b.txt") { + t.Fatalf("expected section for b.txt, got %q", content) + } + if !strings.Contains(content, "BBB") { + t.Fatalf("expected b.txt content, got %q", content) + } +} + +func TestLoader_SameFileTwiceOnSameLine(t *testing.T) { + // The same file referenced twice should appear only once as a section (deduped). + b := newMemBackend() + b.set("/doc.md", "@/shared.md and @/shared.md again") + b.set("/shared.md", "SHARED") + + l := newLoaderConfig(b, []string{"/doc.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + // Original text preserved. + if !strings.Contains(content, "@/shared.md and @/shared.md again") { + t.Fatalf("original text should be preserved, got %q", content) + } + // shared.md content should appear only once (deduped). + count := strings.Count(content, "Contents of /shared.md") + if count != 1 { + t.Fatalf("expected shared.md section to appear once (deduped), got %d in %q", count, content) + } +} + +func TestLoader_ImportFileNotFound(t *testing.T) { + b := newMemBackend() + b.set("/doc.md", "load @/missing.md please") + + l := newLoaderConfig(b, []string{"/doc.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error (missing import is logged), got %v", err) + } + // Original text preserved; missing file simply has no section. + if !strings.Contains(content, "load @/missing.md please") { + t.Fatalf("expected original text preserved, got %q", content) + } + if strings.Contains(content, "Contents of /missing.md") { + t.Fatalf("missing file should not have a section, got %q", content) + } +} + +func TestLoader_RelativePathResolution(t *testing.T) { + // Relative path should resolve relative to the host file's directory. + b := newMemBackend() + b.set("/a/b/host.md", "ref @../c/target.md done") + b.set("/a/c/target.md", "TARGET") + + l := newLoaderConfig(b, []string{"/a/b/host.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + // Original text preserved. + if !strings.Contains(content, "ref @../c/target.md done") { + t.Fatalf("original text should be preserved, got %q", content) + } + // Imported file as separate section. + if !strings.Contains(content, "Contents of /a/c/target.md") { + t.Fatalf("expected section for target.md, got %q", content) + } + if !strings.Contains(content, "TARGET") { + t.Fatalf("expected imported content, got %q", content) + } +} + +func TestLoader_RelativeTopLevelPath(t *testing.T) { + // Top-level file uses relative path; imports with ./ resolve correctly. + b := newMemBackend() + b.set("sub/agents.md", "start @./other.md end") + b.set("sub/other.md", "OTHER CONTENT") + + l := newLoaderConfig(b, []string{"sub/agents.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(content, "start @./other.md end") { + t.Fatalf("expected original text preserved, got %q", content) + } + if !strings.Contains(content, "OTHER CONTENT") { + t.Fatalf("expected imported content, got %q", content) + } +} + +func TestLoader_RelativeTopLevelWithDotDotImport(t *testing.T) { + // Top-level file uses relative path; import with ../ resolves correctly. + b := newMemBackend() + b.set("sub/agents.md", "see @../shared/x.md here") + b.set("shared/x.md", "SHARED X") + + l := newLoaderConfig(b, []string{"sub/agents.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(content, "SHARED X") { + t.Fatalf("expected imported content, got %q", content) + } + // filepath.Clean should normalize "sub/../shared/x.md" to "shared/x.md" + if !strings.Contains(content, "Contents of shared/x.md") { + t.Fatalf("expected normalized path in section header, got %q", content) + } +} + +func TestLoader_RelativeTopLevelDedup(t *testing.T) { + // Two top-level relative paths that resolve to the same file via filepath.Clean + // should be deduped (loaded only once). + b := newMemBackend() + b.set("sub/a.md", "CONTENT A") + + l := newLoaderConfig(b, []string{"sub/a.md", "./sub/a.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + count := strings.Count(content, "CONTENT A") + if count != 1 { + t.Fatalf("expected file loaded once (deduped), got %d occurrences in %q", count, content) + } +} + +func TestLoader_AbsoluteTopLevelWithRelativeImport(t *testing.T) { + // Absolute top-level path with relative @import resolves correctly. + b := newMemBackend() + b.set("/project/agents.md", "ref @./lib/helper.md done") + b.set("/project/lib/helper.md", "HELPER") + + l := newLoaderConfig(b, []string{"/project/agents.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(content, "HELPER") { + t.Fatalf("expected imported content, got %q", content) + } + if !strings.Contains(content, "Contents of /project/lib/helper.md") { + t.Fatalf("expected section header, got %q", content) + } +} + +func TestLoader_AbsoluteTopLevelWithDotDotImport(t *testing.T) { + // Absolute top-level path; @import with ../ resolves and normalizes. + b := newMemBackend() + b.set("/project/sub/agents.md", "load @../shared/x.md here") + b.set("/project/shared/x.md", "SHARED") + + l := newLoaderConfig(b, []string{"/project/sub/agents.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(content, "SHARED") { + t.Fatalf("expected imported content, got %q", content) + } + // filepath.Clean normalizes "/project/sub/../shared/x.md" to "/project/shared/x.md" + if !strings.Contains(content, "Contents of /project/shared/x.md") { + t.Fatalf("expected normalized path in section header, got %q", content) + } +} + +func TestLoader_RelativeImportDedup(t *testing.T) { + // Two different relative @import paths that resolve to the same file + // should be deduped via filepath.Clean. + b := newMemBackend() + b.set("/a/main.md", "first @/a/b/shared.md second @../a/b/shared.md end") + b.set("/a/b/shared.md", "SHARED ONCE") + + l := newLoaderConfig(b, []string{"/a/main.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + count := strings.Count(content, "SHARED ONCE") + if count != 1 { + t.Fatalf("expected shared file loaded once (deduped), got %d in %q", count, content) + } +} + +func TestLoader_NestedRelativeImport(t *testing.T) { + // File A imports B via relative path, B imports C via relative path. + // All three should appear as separate sections. + b := newMemBackend() + b.set("/root/main.md", "start @sub/mid.md end") + b.set("/root/sub/mid.md", "mid @deep/leaf.md mid_end") + b.set("/root/sub/deep/leaf.md", "LEAF") + + l := newLoaderConfig(b, []string{"/root/main.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + for _, section := range []string{"Contents of /root/main.md", "Contents of /root/sub/mid.md", "Contents of /root/sub/deep/leaf.md"} { + if !strings.Contains(content, section) { + t.Fatalf("expected section %q, got %q", section, content) + } + } + if !strings.Contains(content, "LEAF") { + t.Fatalf("expected leaf content, got %q", content) + } +} + +func TestLoader_TransitiveImport(t *testing.T) { + // Imported file itself contains @imports; all should appear as separate sections. + b := newMemBackend() + b.set("/main.md", "header @/mid.md footer") + b.set("/mid.md", "mid-start @/leaf.md mid-end") + b.set("/leaf.md", "LEAF_VALUE") + + l := newLoaderConfig(b, []string{"/main.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + for _, section := range []string{"Contents of /main.md", "Contents of /mid.md", "Contents of /leaf.md"} { + if !strings.Contains(content, section) { + t.Fatalf("expected section %q, got %q", section, content) + } + } + if !strings.Contains(content, "LEAF_VALUE") { + t.Fatalf("expected leaf value, got %q", content) + } +} + +func TestLoader_EmptyFile(t *testing.T) { + b := newMemBackend() + b.set("/empty.md", "") + + l := newLoaderConfig(b, []string{"/empty.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + // Empty file is treated as non-existent, so output should be empty. + if content != "" { + t.Fatalf("expected empty output for empty file, got %q", content) + } +} + +func TestLoader_MaxBytesFirstFileFull(t *testing.T) { + // Even if the first file alone exceeds maxBytes, it should still be loaded in full. + b := newMemBackend() + b.set("/big.md", "ABCDEFGHIJ") // 10 bytes + + l := newLoaderConfig(b, []string{"/big.md"}, 3, nil) + content, err := l.load(context.Background()) // maxBytes=3, but first file always loads + if err != nil { + t.Fatal(err) + } + if !strings.Contains(content, "ABCDEFGHIJ") { + t.Fatalf("first file should always load in full, got %q", content) + } +} + +func TestLoader_CircularImportInline(t *testing.T) { + // Circular reference via @import should be detected, logged, and skipped. + b := newMemBackend() + b.set("/a.md", "text @/b.md more") + b.set("/b.md", "ref @/a.md back") + + l := newLoaderConfig(b, []string{"/a.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error (circular import is logged), got %v", err) + } + // Both a and b should have sections; circular back-reference a from b is skipped. + if !strings.Contains(content, "Contents of /a.md") { + t.Fatalf("expected /a.md section, got %q", content) + } + if !strings.Contains(content, "Contents of /b.md") { + t.Fatalf("expected /b.md section, got %q", content) + } +} + +func TestLoader_MaxDepthInline(t *testing.T) { + // Deep chain via @import should be logged at depth > 5, not returned as error. + b := newMemBackend() + for i := 0; i < 7; i++ { + var content string + if i < 6 { + content = fmt.Sprintf("level%d @/level%d.md tail", i, i+1) + } else { + content = fmt.Sprintf("level%d", i) + } + b.set(fmt.Sprintf("/level%d.md", i), content) + } + + l := newLoaderConfig(b, []string{"/level0.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error (depth exceeded is logged), got %v", err) + } + // Levels 0-5 should have sections. + for i := 0; i <= 5; i++ { + want := fmt.Sprintf("Contents of /level%d.md", i) + if !strings.Contains(content, want) { + t.Fatalf("expected %q in content, got %q", want, content) + } + } + // Level 6 should not be present. + if strings.Contains(content, "Contents of /level6.md") { + t.Fatalf("level6 should not be present (depth exceeded), got %q", content) + } +} + +func TestLoader_DiamondDependency(t *testing.T) { + // A imports B and D; B imports C; D also imports C. + // C should appear only once (deduped across the whole load). + b := newMemBackend() + b.set("/a.md", "start @/b.md middle @/d.md end") + b.set("/b.md", "B(@/c.md)") + b.set("/d.md", "D(@/c.md)") + b.set("/c.md", "SHARED") + + l := newLoaderConfig(b, []string{"/a.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("diamond dependency should not be circular, got error: %v", err) + } + + // C should appear only once as a section (deduped). + count := strings.Count(content, "Contents of /c.md") + if count != 1 { + t.Fatalf("expected /c.md section once (deduped), got %d in %q", count, content) + } + // All files should have sections. + for _, section := range []string{"Contents of /a.md", "Contents of /b.md", "Contents of /c.md", "Contents of /d.md"} { + if !strings.Contains(content, section) { + t.Fatalf("expected section %q, got %q", section, content) + } + } +} + +func TestLoader_AtSignInNormalText(t *testing.T) { + // Bare @word without "/" or file extension should not trigger import. + // Email-like patterns (@example.com) with non-allowed extensions should also be ignored. + b := newMemBackend() + b.set("/agent.md", "contact me @ anytime or @ spaces and @someone mentioned and user@example.com and @company.org") + + l := newLoaderConfig(b, []string{"/agent.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(content, "contact me @ anytime") { + t.Fatalf("bare @ should not trigger import, got %q", content) + } + if !strings.Contains(content, "@someone mentioned") { + t.Fatalf("@someone without / or extension should not trigger import, got %q", content) + } + if !strings.Contains(content, "@example.com") { + t.Fatalf("email-like @example.com should not trigger import, got %q", content) + } + if !strings.Contains(content, "@company.org") { + t.Fatalf("email-like @company.org should not trigger import, got %q", content) + } +} + +func TestMiddleware_Stream(t *testing.T) { + b := newMemBackend() + b.set("/agent.md", "stream test") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + _, _ = wrapped.Stream(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}) + + if len(mock.lastInput) != 2 { + t.Fatalf("expected 2 messages for stream, got %d", len(mock.lastInput)) + } + if !strings.Contains(mock.lastInput[0].Content, "stream test") { + t.Fatalf("expected agent.md content in stream input, got %q", mock.lastInput[0].Content) + } +} + +func TestLoader_MaxBytesWithImports(t *testing.T) { + // Two top-level files that both import the same shared file. + // Budget should account for imported file bytes. + b := newMemBackend() + b.set("/a.md", "A(@/shared.md)") + b.set("/b.md", "B(@/shared.md)") + b.set("/shared.md", strings.Repeat("X", 100)) // 100 bytes + + l := newLoaderConfig(b, []string{"/a.md", "/b.md"}, 120, nil) + // /a.md = 14 bytes + /shared.md = 100 bytes => 114 total after /a.md. + // Budget = 120: /b.md (14 bytes) would push to 128, exceeding budget. + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("load failed: %v", err) + } + + // /a.md and its import should be included. + if !strings.Contains(content, strings.Repeat("X", 100)) { + t.Fatal("expected /a.md with shared content to be included") + } + + // /b.md should be excluded because totalBytes exceeded budget after loading /a.md. + if strings.Contains(content, "B(") { + t.Fatalf("expected /b.md to be excluded due to budget, got %q", content) + } +} + +func TestNew_Validation_EmptyAgentFiles(t *testing.T) { + ctx := context.Background() + b := newMemBackend() + + _, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{}}) + if err == nil { + t.Fatal("expected error for empty agent files") + } + if !strings.Contains(err.Error(), "at least one agent file path is required") { + t.Fatalf("unexpected error message: %v", err) + } +} + +func TestMiddleware_GenerateError(t *testing.T) { + // Non-ErrNotExist errors (e.g. permission denied) should propagate. + b := &errBackend{} + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/file.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}) + if err == nil { + t.Fatal("expected error when backend read fails with non-ErrNotExist") + } + if !strings.Contains(err.Error(), "failed to load agent files") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestMiddleware_StreamError(t *testing.T) { + // Non-ErrNotExist errors (e.g. permission denied) should propagate. + b := &errBackend{} + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/file.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + _, err = wrapped.Stream(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}) + if err == nil { + t.Fatal("expected error when backend read fails with non-ErrNotExist for stream") + } + if !strings.Contains(err.Error(), "failed to load agent files") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestLoader_DuplicateTopLevelFiles(t *testing.T) { + // Same file listed twice in AgentFiles; second should be deduped via seen map. + b := newMemBackend() + b.set("/agent.md", "unique content") + + l := newLoaderConfig(b, []string{"/agent.md", "/agent.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + + count := strings.Count(content, "Contents of /agent.md") + if count != 1 { + t.Fatalf("expected /agent.md section once (deduped), got %d", count) + } +} + +func TestLoader_LoadFileError(t *testing.T) { + // Missing file (ErrNotExist) is silently skipped. + b := newMemBackend() + l := newLoaderConfig(b, []string{"/missing.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected missing file to be skipped, got error: %v", err) + } + if content != "" { + t.Fatalf("expected empty output, got %q", content) + } +} + +func TestLoader_MaxBytesStopsImports(t *testing.T) { + // When budget is exhausted, further imports in collectImports should be skipped. + b := newMemBackend() + b.set("/main.md", "@/big.md @/small.md") + b.set("/big.md", strings.Repeat("B", 200)) + b.set("/small.md", "SMALL") + + l := newLoaderConfig(b, []string{"/main.md"}, 50, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + + // main.md itself is loaded (always), big.md pushes over budget, + // small.md should be skipped. + if !strings.Contains(content, "Contents of /main.md") { + t.Fatal("main.md should be present") + } + if strings.Contains(content, "SMALL") { + t.Fatal("small.md should be skipped after budget exhausted") + } +} + +func TestFormatContent_Empty(t *testing.T) { + // formatContent with nil/empty slice should return empty string. + if got := formatContent(nil); got != "" { + t.Fatalf("expected empty string for nil, got %q", got) + } + if got := formatContent([]loadedFile{}); got != "" { + t.Fatalf("expected empty string for empty slice, got %q", got) + } +} + +func TestMiddleware_AllFilesEmpty(t *testing.T) { + // When all agent files have empty content, loader returns "" and + // prependAgentMD returns the original input unchanged. + b := newMemBackend() + b.set("/agent.md", "") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + userMsg := []*schema.Message{{Role: schema.User, Content: "hello"}} + if _, err = wrapped.Generate(ctx, userMsg); err != nil { + t.Fatal(err) + } + // Empty file produces no agentmd content, so original messages pass through unchanged. + if len(mock.lastInput) != 1 { + t.Fatalf("expected 1 message (no agentmd prepended), got %d", len(mock.lastInput)) + } + if mock.lastInput[0].Content != "hello" { + t.Fatalf("expected original message unchanged, got %q", mock.lastInput[0].Content) + } +} + +func TestLoader_ExactOutput(t *testing.T) { + // Verify the exact output format matches the expected structure: + // each file (top-level and imported) gets its own "Contents of ..." section, + // @path references are preserved in the original text. + b := newMemBackend() + b.set("/project/CLAUDE.md", "this is project claude.md\n\n- git workflow @git/git-instructions.md") + b.set("/project/git/git-instructions.md", "this is git-instructions.md") + + l := newLoaderConfig(b, []string{"/project/CLAUDE.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + + expected := ` +As you answer the user's questions, you can use the following context: +Codebase and user instructions are shown below. Be sure to adhere to these instructions. IMPORTANT: These instructions OVERRIDE any default behavior and you MUST follow them exactly as written. + +Contents of /project/CLAUDE.md (instructions): + +this is project claude.md + +- git workflow @git/git-instructions.md + +Contents of /project/git/git-instructions.md (instructions): + +this is git-instructions.md +IMPORTANT: this context may or may not be relevant to your tasks. You should not respond to this context unless it is highly relevant to your task. +` + + if content != expected { + t.Fatalf("output mismatch.\n\ngot:\n%s\n\nexpected:\n%s", content, expected) + } +} + +func TestLoader_MissingFileSkipped(t *testing.T) { + b := newMemBackend() + b.set("/good.md", "GOOD CONTENT") + // /missing.md is not set + + l := newLoaderConfig(b, []string{"/missing.md", "/good.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error for missing file, got %v", err) + } + if !strings.Contains(content, "GOOD CONTENT") { + t.Fatal("expected good.md content in output") + } +} + +func TestLoader_AllMissingFilesSkipped(t *testing.T) { + b := newMemBackend() + + l := newLoaderConfig(b, []string{"/missing1.md", "/missing2.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error for missing files, got %v", err) + } + if content != "" { + t.Fatalf("expected empty output when all files missing, got %q", content) + } +} + +func TestLoader_CircularImportSkipped(t *testing.T) { + b := newMemBackend() + b.set("/a.md", "A content @/b.md") + b.set("/b.md", "B content @/a.md") + + // Circular import in collectImports is logged via onWarning and skipped. + l := newLoaderConfig(b, []string{"/a.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !strings.Contains(content, "A content") { + t.Fatal("expected a.md content") + } + if !strings.Contains(content, "B content") { + t.Fatal("expected b.md content") + } +} + +func TestLoader_DepthExceededSkipped(t *testing.T) { + b := newMemBackend() + // Create a chain that exceeds maxImportDepth (5) + b.set("/l0.md", "@/l1.md") + b.set("/l1.md", "@/l2.md") + b.set("/l2.md", "@/l3.md") + b.set("/l3.md", "@/l4.md") + b.set("/l4.md", "@/l5.md") + b.set("/l5.md", "@/l6.md") + b.set("/l6.md", "DEEP") + + l := newLoaderConfig(b, []string{"/l0.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error for depth exceeded, got %v", err) + } + // Should have content up to the depth limit, deep file skipped. + if !strings.Contains(content, "/l0.md") { + t.Fatal("expected l0.md in output") + } +} + +func TestLoader_OnLoadWarningCallback(t *testing.T) { + b := newMemBackend() + b.set("/good.md", "GOOD CONTENT") + + var warnings []error + onWarning := func(filePath string, err error) { + warnings = append(warnings, fmt.Errorf("%s: %w", filePath, err)) + } + + l := newLoaderConfig(b, []string{"/missing.md", "/good.md"}, 0, onWarning) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !strings.Contains(content, "GOOD CONTENT") { + t.Fatal("expected good.md content in output") + } + if len(warnings) == 0 { + t.Fatal("expected at least one warning for missing file") + } + if !strings.Contains(warnings[0].Error(), "file not found") { + t.Fatalf("expected file not found warning, got %v", warnings[0]) + } +} + +func TestMiddleware_MissingFile_Generate(t *testing.T) { + b := newMemBackend() + // /missing.md not set — will fail to read + + ctx := context.Background() + mw, err := New(ctx, &Config{ + Backend: b, + AgentsMDFiles: []string{"/missing.md"}, + }) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + userMsg := []*schema.Message{{Role: schema.User, Content: "hello"}} + _, err = wrapped.Generate(ctx, userMsg) + if err != nil { + t.Fatalf("expected no error for missing file, got %v", err) + } + // No agent.md content, so original messages should be passed through unchanged. + if len(mock.lastInput) != 1 { + t.Fatalf("expected 1 message (no agentmd prepended), got %d", len(mock.lastInput)) + } +} + +func TestMiddleware_MissingFile_Stream(t *testing.T) { + b := newMemBackend() + + ctx := context.Background() + mw, err := New(ctx, &Config{ + Backend: b, + AgentsMDFiles: []string{"/missing.md"}, + }) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + userMsg := []*schema.Message{{Role: schema.User, Content: "hello"}} + _, err = wrapped.Stream(ctx, userMsg) + if err != nil { + t.Fatalf("expected no error for missing file, got %v", err) + } + if len(mock.lastInput) != 1 { + t.Fatalf("expected 1 message (no agentmd prepended), got %d", len(mock.lastInput)) + } +} + +func TestMiddleware_InsertBeforeFirstUserMessage(t *testing.T) { + b := newMemBackend() + b.set("/agent.md", "agent instructions") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + // Input has a System message before the User message. + input := []*schema.Message{ + {Role: schema.System, Content: "system prompt"}, + {Role: schema.User, Content: "hello"}, + } + if _, err = wrapped.Generate(ctx, input); err != nil { + t.Fatal(err) + } + + if len(mock.lastInput) != 3 { + t.Fatalf("expected 3 messages, got %d", len(mock.lastInput)) + } + if mock.lastInput[0].Role != schema.System { + t.Fatalf("expected first message role System, got %s", mock.lastInput[0].Role) + } + if mock.lastInput[0].Content != "system prompt" { + t.Fatalf("expected system prompt preserved, got %q", mock.lastInput[0].Content) + } + if mock.lastInput[1].Role != schema.User || !strings.Contains(mock.lastInput[1].Content, "agent instructions") { + t.Fatalf("expected agentmd message before user message, got role=%s content=%q", mock.lastInput[1].Role, mock.lastInput[1].Content) + } + if mock.lastInput[2].Role != schema.User || mock.lastInput[2].Content != "hello" { + t.Fatalf("expected original user message at index 2, got role=%s content=%q", mock.lastInput[2].Role, mock.lastInput[2].Content) + } +} + +func TestMiddleware_InsertWithNoUserMessage(t *testing.T) { + b := newMemBackend() + b.set("/agent.md", "agent instructions") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + // Input has no User message at all. + input := []*schema.Message{ + {Role: schema.System, Content: "system prompt"}, + {Role: schema.Assistant, Content: "assistant reply"}, + } + if _, err = wrapped.Generate(ctx, input); err != nil { + t.Fatal(err) + } + + if len(mock.lastInput) != 3 { + t.Fatalf("expected 3 messages, got %d", len(mock.lastInput)) + } + if mock.lastInput[0].Role != schema.System { + t.Fatalf("expected System at index 0, got %s", mock.lastInput[0].Role) + } + if mock.lastInput[1].Role != schema.Assistant { + t.Fatalf("expected Assistant at index 1, got %s", mock.lastInput[1].Role) + } + if mock.lastInput[2].Role != schema.User || !strings.Contains(mock.lastInput[2].Content, "agent instructions") { + t.Fatalf("expected agentmd appended at end, got role=%s content=%q", mock.lastInput[2].Role, mock.lastInput[2].Content) + } +} + +func TestLoader_ImportIOError(t *testing.T) { + // When an imported file returns a non-ErrNotExist error (e.g. I/O error), + // the load should propagate the error (covers collectImports and loadFile error paths). + b := &partialErrBackend{ + files: map[string]string{ + "/main.md": "content @/broken.md", + }, + // /broken.md is NOT in the map, so Read returns I/O error (not ErrNotExist) + } + + l := newLoaderConfig(b, []string{"/main.md"}, 0, nil) + _, err := l.load(context.Background()) + if err == nil { + t.Fatal("expected error from I/O failure on imported file") + } + if !strings.Contains(err.Error(), "I/O error") { + t.Fatalf("expected I/O error, got: %v", err) + } +} diff --git a/adk/middlewares/agentsmd/loader.go b/adk/middlewares/agentsmd/loader.go new file mode 100644 index 000000000..db733383b --- /dev/null +++ b/adk/middlewares/agentsmd/loader.go @@ -0,0 +1,299 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package agentsmd + +import ( + "context" + "errors" + "fmt" + "log" + "os" + "path/filepath" + "regexp" + "strings" + + "github.com/cloudwego/eino/adk/filesystem" + "github.com/cloudwego/eino/adk/internal" +) + +// importRegex matches @path/to/file anywhere in text. +// The path must start with a letter, digit, dot, underscore, slash, or tilde, followed by +// path characters (letters, digits, dots, slashes, hyphens, underscores). +// A post-match filter further requires the path to contain "/" or end with +// an allowed extension (see allowedImportExts), so bare words like @someone +// and email-like patterns like @example.com are ignored. +var importRegex = regexp.MustCompile(`@([a-zA-Z0-9_.~/][a-zA-Z0-9_.~/\-]*)`) + +// allowedImportExts is the set of file extensions recognised as @import targets. +// Paths without "/" must end with one of these extensions to be treated as imports; +// this avoids false positives on email addresses (@example.com) and mentions (@foo.bar). +var allowedImportExts = map[string]bool{ + ".md": true, + ".txt": true, + ".mdx": true, + ".yaml": true, + ".yml": true, + ".json": true, + ".toml": true, +} + +const maxImportDepth = 5 + +// ReadRequest is an alias for filesystem.ReadRequest. +type ReadRequest = filesystem.ReadRequest +type FileContent = filesystem.FileContent + +// Backend defines the file access interface for loading Agents.md files. +// Implementations can use local filesystem, remote storage, or any other backend. +type Backend interface { + // Read reads the content of a file. + // If the file does not exist, implementations should return an error wrapping + // os.ErrNotExist (so that errors.Is(err, os.ErrNotExist) returns true). This allows the loader + // to silently skip missing files and notify via OnLoadWarning callback. + // Other errors (e.g. permission denied, I/O errors) will abort the loading process. + Read(ctx context.Context, req *ReadRequest) (*FileContent, error) +} + +// loaderConfig holds the immutable configuration for creating loaders. +// It is safe for concurrent use by multiple goroutines. +type loaderConfig struct { + backend Backend + files []string // ordered file paths from config + maxBytes int // cumulative read budget; 0 means unlimited + onWarning func(filePath string, err error) // callback for non-fatal loading warnings +} + +func newLoaderConfig(backend Backend, files []string, maxBytes int, onWarning func(filePath string, err error)) *loaderConfig { + if onWarning == nil { + onWarning = func(filePath string, err error) { + log.Printf("[agentsmd] warning: %s: %v", filePath, err) + } + } + return &loaderConfig{ + backend: backend, + files: files, + maxBytes: maxBytes, + onWarning: onWarning, + } +} + +// loader handles loading and @import resolution for agents.md files. +// A new loader is created for each load() call to avoid sharing mutable state +// (totalBytes) across concurrent invocations. +type loader struct { + *loaderConfig + totalBytes int // accumulated bytes during this load call +} + +func (cfg *loaderConfig) newLoader() *loader { + return &loader{loaderConfig: cfg} +} + +// load reads all agents.md files and returns the formatted content. +// Each top-level file and its @imported files appear as separate sections. +func (cfg *loaderConfig) load(ctx context.Context) (string, error) { + l := cfg.newLoader() + + var parts []loadedFile + seen := make(map[string]bool) // dedup across all files and imports + + for i, filePath := range l.files { + files, err := l.loadFile(ctx, filePath, 0, make(map[string]bool), seen) + if err != nil { + return "", fmt.Errorf("failed to load %q: %w", filePath, err) + } + + // If loading this file caused the budget to be exceeded, skip it + // (but always include the first file). + if i > 0 && l.maxBytes > 0 && l.totalBytes > l.maxBytes { + l.onWarning(filePath, fmt.Errorf("skipped: cumulative size %d exceeds max bytes %d", l.totalBytes, l.maxBytes)) + break + } + + parts = append(parts, files...) + } + + return formatContent(parts), nil +} + +// loadFile reads a file via Backend and collects @imported files as separate entries. +// Returns a slice where the first element is this file itself, followed by all +// transitively imported files (in encounter order, preserving @path in original text). +// visited tracks the current ancestor chain to detect circular imports. +// seen tracks globally loaded files to avoid duplicate reads and byte counting. +func (l *loader) loadFile(ctx context.Context, filePath string, depth int, visited map[string]bool, seen map[string]bool) ([]loadedFile, error) { + filePath = filepath.Clean(filePath) + + if depth > maxImportDepth { + l.onWarning(filePath, fmt.Errorf("@import depth exceeds maximum of %d", maxImportDepth)) + return nil, nil + } + + if visited[filePath] { + l.onWarning(filePath, fmt.Errorf("circular @import detected")) + return nil, nil + } + + if seen[filePath] { + return nil, nil + } + + visited[filePath] = true + defer delete(visited, filePath) + + fileContent, err := l.backend.Read(ctx, &ReadRequest{FilePath: filePath, Offset: 1}) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + l.onWarning(filePath, fmt.Errorf("file not found, skipping")) + return nil, nil + } + return nil, err + } + content := "" + if fileContent != nil { + content = fileContent.Content + } + + l.totalBytes += len(content) + seen[filePath] = true + + if content == "" { + return nil, nil + } + + // Collect imported files as separate sections (content stays untouched). + imports, err := l.collectImports(ctx, filePath, content, depth, visited, seen) + if err != nil { + return nil, err + } + + // This file first, then its imports. + result := make([]loadedFile, 0, 1+len(imports)) + result = append(result, loadedFile{path: filePath, content: content}) + result = append(result, imports...) + return result, nil +} + +// collectImports scans content for @path/to/file references and loads each +// imported file (plus its transitive imports). The original content is NOT modified. +// Returns the list of imported loadedFile entries in encounter order. +// seen is shared across the entire load call to avoid duplicate reads. +// Non-fatal errors (file not found, depth exceeded, circular import) are reported +// via onWarning and skipped. Fatal errors (e.g. I/O) are returned. +func (l *loader) collectImports(ctx context.Context, hostPath, content string, depth int, visited map[string]bool, seen map[string]bool) ([]loadedFile, error) { + dir := filepath.Dir(hostPath) + var imports []loadedFile + + matches := importRegex.FindAllStringSubmatch(content, -1) + for _, match := range matches { + rawPath := match[1] + + // Only treat as import if path contains "/" or ends with an allowed extension. + // This avoids false positives on email addresses and social mentions. + if !strings.Contains(rawPath, "/") && !allowedImportExts[filepath.Ext(rawPath)] { + continue + } + + // If budget is exhausted, skip further imports. + if l.maxBytes > 0 && l.totalBytes > l.maxBytes { + break + } + + importPath := rawPath + if !filepath.IsAbs(importPath) { + importPath = filepath.Join(dir, importPath) + } + + if seen[importPath] { + continue + } + + files, err := l.loadFile(ctx, importPath, depth+1, visited, seen) + if err != nil { + return nil, fmt.Errorf("failed to import %q from %q: %w", rawPath, hostPath, err) + } + + imports = append(imports, files...) + } + + return imports, nil +} + +type loadedFile struct { + path string + content string +} + +const formatHeaderEn = ` +As you answer the user's questions, you can use the following context: +Codebase and user instructions are shown below. Be sure to adhere to these instructions. IMPORTANT: These instructions OVERRIDE any default behavior and you MUST follow them exactly as written. +` + +const formatHeaderCn = ` +在回答用户问题时,你可以使用以下上下文: +代码库和用户指令如下。请务必遵守这些指令。重要提示:这些指令会覆盖任何默认行为,你必须严格按照要求执行。 +` + +const formatFileHeaderEn = "\nContents of " + +const formatFileHeaderCn = "\n文件内容:" + +const formatFileLabelEn = " (instructions):\n\n" + +const formatFileLabelCn = "(指令):\n\n" + +const formatFooterEn = `IMPORTANT: this context may or may not be relevant to your tasks. You should not respond to this context unless it is highly relevant to your task. +` + +const formatFooterCn = `重要提示:此上下文可能与你的任务相关,也可能不相关。除非此上下文与你的任务高度相关,否则不要响应此上下文。 +` + +func formatContent(files []loadedFile) string { + if len(files) == 0 { + return "" + } + + header := internal.SelectPrompt(internal.I18nPrompts{ + English: formatHeaderEn, + Chinese: formatHeaderCn, + }) + fileHeader := internal.SelectPrompt(internal.I18nPrompts{ + English: formatFileHeaderEn, + Chinese: formatFileHeaderCn, + }) + fileLabel := internal.SelectPrompt(internal.I18nPrompts{ + English: formatFileLabelEn, + Chinese: formatFileLabelCn, + }) + footer := internal.SelectPrompt(internal.I18nPrompts{ + English: formatFooterEn, + Chinese: formatFooterCn, + }) + + var sb strings.Builder + sb.WriteString(header) + + for _, f := range files { + sb.WriteString(fileHeader) + sb.WriteString(f.path) + sb.WriteString(fileLabel) + sb.WriteString(f.content) + sb.WriteString("\n") + } + sb.WriteString(footer) + return sb.String() +} From 7d3f27e1be846d28cf5356f6a45a4d601f757c3f Mon Sep 17 00:00:00 2001 From: "luohuaqing.2018" Date: Thu, 12 Feb 2026 17:42:07 +0800 Subject: [PATCH 31/65] feat(adk): add TurnLoop and Cancellable interfaces Change-Id: Ifb3be9fadeabdd8f3b6985bb47d2fb38a7004beb --- adk/interface.go | 39 +++ adk/turn_loop.go | 251 +++++++++++++++++++ adk/turn_loop_test.go | 564 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 854 insertions(+) create mode 100644 adk/turn_loop.go create mode 100644 adk/turn_loop_test.go diff --git a/adk/interface.go b/adk/interface.go index 5c06843ae..fef2f98fe 100644 --- a/adk/interface.go +++ b/adk/interface.go @@ -20,6 +20,7 @@ import ( "bytes" "context" "encoding/gob" + "errors" "fmt" "io" @@ -269,3 +270,41 @@ type ResumableAgent interface { Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] } + +// CancelMode specifies when an agent should be canceled. +// Modes can be combined with bitwise OR to cancel at multiple execution points. +// For example, CancelAfterChatModel | CancelAfterToolCall cancels the agent +// after whichever execution point is reached first. +type CancelMode int + +const ( + // CancelImmediate cancels the agent immediately without waiting + // for any execution point. + CancelImmediate CancelMode = 0 + // CancelAfterChatModel cancels the agent after a chat model call completes. + CancelAfterChatModel CancelMode = 1 << iota + // CancelAfterToolCall cancels the agent after a tool call completes. + CancelAfterToolCall +) + +// ErrAgentFinished is returned by Cancel when the agent has already finished execution. +var ErrAgentFinished = errors.New("agent has already finished execution") + +// CancelOption holds options for cancelling an agent. +type CancelOption struct { + Mode CancelMode +} + +// Cancellable is an optional interface that an Agent can implement to support +// cancellation during execution. +type Cancellable interface { + // Cancel signals the agent to stop, either immediately or after reaching the + // specified execution point(s) defined by opt.Mode. + // + // The opt parameter must be non-nil. Use opt.Mode to control when the + // cancellation takes effect (e.g., CancelImmediate, CancelAfterChatModel, + // CancelAfterToolCall, or a combination via bitwise OR). + // + // If the agent has already finished execution, Cancel returns ErrAgentFinished. + Cancel(ctx context.Context, opt *CancelOption) error +} diff --git a/adk/turn_loop.go b/adk/turn_loop.go new file mode 100644 index 000000000..d6ac24071 --- /dev/null +++ b/adk/turn_loop.go @@ -0,0 +1,251 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "fmt" + "time" +) + +// ConsumeMode specifies how a received message should be consumed +// relative to the currently running agent. +type ConsumeMode int + +const ( + // ConsumeNonPreemptive processes the message after the current agent + // finishes. This is the default queued behavior. + ConsumeNonPreemptive ConsumeMode = iota + // ConsumePreemptive cancels the currently running agent (if it + // implements Cancellable) and processes the message immediately. + // If the agent does not implement Cancellable, the message is + // buffered and processed after the agent finishes. + ConsumePreemptive +) + +// ConsumeOption describes how a received message should be consumed. +// It combines ConsumeMode (preemptive vs queued) with CancelMode +// (how to cancel the running agent when preempting). +type ConsumeOption struct { + // Mode specifies whether the message should preempt the current agent + // or be queued. Default zero value is ConsumeNonPreemptive. + Mode ConsumeMode + // CancelOption specifies when and how the running agent should be canceled. + // Only meaningful when Mode is ConsumePreemptive and the agent + // implements Cancellable. Default nil value means CancelImmediate. + CancelOption *CancelOption +} + +// NonPreemptiveConsumeOption is a convenience value for the common +// non-preemptive (queued) case. +var NonPreemptiveConsumeOption = ConsumeOption{Mode: ConsumeNonPreemptive} + +// MessageSource is an interface for pulling typed messages from an external source. +// Receive blocks until a message is available or an error occurs. +// The timeout parameter specifies the maximum duration to wait for a message. +// The returned ConsumeOption indicates whether the message should preempt the +// currently running agent (and how to cancel it) or be queued for processing +// after it finishes. +type MessageSource[T any] interface { + Receive(ctx context.Context, timeout time.Duration) (T, ConsumeOption, error) +} + +// TurnLoopConfig is the configuration for creating a TurnLoop. +type TurnLoopConfig[T any] struct { + // Source provides messages to drive the loop. Required. + Source MessageSource[T] + // GenInput converts a received message into AgentInput. Required. + GenInput func(ctx context.Context, item T) (*AgentInput, error) + // GetAgent returns the Agent to run for a given message. Required. + GetAgent func(ctx context.Context, item T) (Agent, error) + // OnAgentEvent is called for each event emitted by the agent. Optional. + OnAgentEvent func(ctx context.Context, event *AgentEvent) error + // RunOptions are passed to Agent.Run on each turn. Optional. + RunOptions []AgentRunOption + // ReceiveTimeout is the timeout passed to Source.Receive on each iteration. + // Zero means no timeout. Optional. + ReceiveTimeout time.Duration +} + +// TurnLoop is a loop that pulls messages from a source, runs an Agent for +// each message, and dispatches resulting events. It supports preemptive +// cancellation when the source returns ConsumePreemptive and the current +// agent implements Cancellable. +type TurnLoop[T any] struct { + source MessageSource[T] + genInput func(ctx context.Context, item T) (*AgentInput, error) + getAgent func(ctx context.Context, item T) (Agent, error) + onAgentEvent func(ctx context.Context, event *AgentEvent) error + runOptions []AgentRunOption + receiveTimeout time.Duration +} + +// NewTurnLoop creates a new TurnLoop from the given configuration. +// Source, GenInput, and GetAgent are required fields. +func NewTurnLoop[T any](config TurnLoopConfig[T]) (*TurnLoop[T], error) { + if config.Source == nil { + return nil, fmt.Errorf("TurnLoopConfig.Source is required") + } + if config.GenInput == nil { + return nil, fmt.Errorf("TurnLoopConfig.GenInput is required") + } + if config.GetAgent == nil { + return nil, fmt.Errorf("TurnLoopConfig.GetAgent is required") + } + + return &TurnLoop[T]{ + source: config.Source, + genInput: config.GenInput, + getAgent: config.GetAgent, + onAgentEvent: config.OnAgentEvent, + runOptions: config.RunOptions, + receiveTimeout: config.ReceiveTimeout, + }, nil +} + +// recvResult holds the result of a concurrent Receive call. +type recvResult[T any] struct { + item T + option ConsumeOption + err error +} + +// iterResult holds the result of a single AsyncIterator.Next call. +type iterResult struct { + event *AgentEvent + ok bool +} + +// Run starts the blocking loop that continuously receives messages, runs +// agents, and dispatches events. While an agent is running, the next message +// is received concurrently. If that message's ConsumeOption has ConsumePreemptive +// mode and the running agent implements Cancellable, the agent is canceled +// (using the CancelMode from the option) and the new message is processed +// immediately. +func (l *TurnLoop[T]) Run(ctx context.Context) error { + // done is closed when Run returns, signaling background goroutines to exit. + done := make(chan struct{}) + defer close(done) + + // Initial blocking receive — no agent running yet, mode is irrelevant. + item, _, err := l.source.Receive(ctx, l.receiveTimeout) + if err != nil { + return err + } + + for { + input, e := l.genInput(ctx, item) + if e != nil { + return fmt.Errorf("failed to generate agent input: %w", e) + } + + agent, e := l.getAgent(ctx, item) + if e != nil { + return fmt.Errorf("failed to get agent: %w", e) + } + + // Start receiving the next message concurrently with agent execution. + // The channel is buffered so the goroutine never blocks on send. + recvCh := make(chan recvResult[T], 1) + go func() { + i, opt, e_ := l.source.Receive(ctx, l.receiveTimeout) + recvCh <- recvResult[T]{i, opt, e_} + }() + + // Run the agent and forward events through a channel so we can + // select between agent events and incoming messages. + iter := agent.Run(ctx, input, l.runOptions...) + eventCh := make(chan iterResult, 1) + go func() { + for { + event, ok := iter.Next() + select { + case eventCh <- iterResult{event, ok}: + case <-done: + return + } + if !ok { + return + } + } + }() + + var pending *recvResult[T] + var turnErr error + + eventLoop: + for { + select { + case ev := <-eventCh: + if !ev.ok { + break eventLoop + } + if ev.event.Err != nil { + turnErr = fmt.Errorf("agent run failed: %w", ev.event.Err) + break eventLoop + } + if l.onAgentEvent != nil { + if e_ := l.onAgentEvent(ctx, ev.event); e_ != nil { + turnErr = fmt.Errorf("OnAgentEvent failed: %w", e_) + break eventLoop + } + } + + case recv := <-recvCh: + recvCh = nil // nil channel never matches in select + if recv.option.Mode == ConsumePreemptive && recv.err == nil { + if ca, ok := agent.(Cancellable); ok { + if e_ := ca.Cancel(ctx, recv.option.CancelOption); e_ != nil { + return fmt.Errorf("failed to cancel agent: %w", e_) + } + // Drain remaining events after cancellation. + for { + ev := <-eventCh + if !ev.ok { + break + } + } + pending = &recvResult[T]{item: recv.item} + break eventLoop + } + } + // Non-preemptive, preemptive but not Cancellable, or + // source error: buffer and let the eventLoop finish + // processing the current agent's events first. + pending = &recvResult[T]{item: recv.item, err: recv.err} + } + } + + if turnErr != nil { + return turnErr + } + + if pending != nil { + if pending.err != nil { + return pending.err + } + item = pending.item + } else { + // Agent finished before the next message arrived; wait for it. + recv := <-recvCh + if recv.err != nil { + return recv.err + } + item = recv.item + } + } +} diff --git a/adk/turn_loop_test.go b/adk/turn_loop_test.go new file mode 100644 index 000000000..9ef2d1a90 --- /dev/null +++ b/adk/turn_loop_test.go @@ -0,0 +1,564 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/schema" +) + +// --------------------------------------------------------------------------- +// Test mocks +// --------------------------------------------------------------------------- + +// turnLoopMockSource returns items from a slice (all NonPreemptive), then an error. +type turnLoopMockSource struct { + items []string + idx int + err error +} + +func (s *turnLoopMockSource) Receive(_ context.Context, _ time.Duration) (string, ConsumeOption, error) { + if s.idx >= len(s.items) { + return "", NonPreemptiveConsumeOption, s.err + } + item := s.items[s.idx] + s.idx++ + return item, NonPreemptiveConsumeOption, nil +} + +// turnLoopFuncSource delegates Receive to a user-supplied function. +type turnLoopFuncSource[T any] struct { + fn func(ctx context.Context, timeout time.Duration) (T, ConsumeOption, error) +} + +func (s *turnLoopFuncSource[T]) Receive(ctx context.Context, timeout time.Duration) (T, ConsumeOption, error) { + return s.fn(ctx, timeout) +} + +// turnLoopMockAgent emits a fixed list of events per Run call. +type turnLoopMockAgent struct { + name string + events []*AgentEvent +} + +func (a *turnLoopMockAgent) Name(_ context.Context) string { return a.name } +func (a *turnLoopMockAgent) Description(_ context.Context) string { return "mock agent" } +func (a *turnLoopMockAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, gen := NewAsyncIteratorPair[*AgentEvent]() + go func() { + defer gen.Close() + for _, e := range a.events { + gen.Send(e) + } + }() + return iter +} + +// turnLoopCancellableAgent blocks until Cancel is called, then closes its iterator. +// It implements both Agent and Cancellable. +type turnLoopCancellableAgent struct { + name string + startedCh chan struct{} // closed when Run is entered + cancelCh chan struct{} // closed by Cancel + cancelled atomic.Bool + cancelledOpt *CancelOption // records the CancelOption passed to Cancel +} + +func (a *turnLoopCancellableAgent) Name(_ context.Context) string { return a.name } +func (a *turnLoopCancellableAgent) Description(_ context.Context) string { return "cancellable mock" } +func (a *turnLoopCancellableAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, gen := NewAsyncIteratorPair[*AgentEvent]() + close(a.startedCh) + go func() { + defer gen.Close() + <-a.cancelCh + }() + return iter +} + +func (a *turnLoopCancellableAgent) Cancel(_ context.Context, opt *CancelOption) error { + a.cancelled.Store(true) + a.cancelledOpt = opt + close(a.cancelCh) + return nil +} + +// turnLoopBlockingAgent blocks until blockCh is closed, then emits its events. +// It does NOT implement Cancellable. +type turnLoopBlockingAgent struct { + name string + startedCh chan struct{} + blockCh chan struct{} + events []*AgentEvent +} + +func (a *turnLoopBlockingAgent) Name(_ context.Context) string { return a.name } +func (a *turnLoopBlockingAgent) Description(_ context.Context) string { return "blocking mock" } +func (a *turnLoopBlockingAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, gen := NewAsyncIteratorPair[*AgentEvent]() + close(a.startedCh) + go func() { + defer gen.Close() + <-a.blockCh + for _, e := range a.events { + gen.Send(e) + } + }() + return iter +} + +// --------------------------------------------------------------------------- +// Tests — validation +// --------------------------------------------------------------------------- + +func TestNewTurnLoop_Validation(t *testing.T) { + t.Run("missing source", func(t *testing.T) { + _, err := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(_ context.Context, _ string) (*AgentInput, error) { return nil, nil }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, nil }, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "Source") + }) + + t.Run("missing GenInput", func(t *testing.T) { + _, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{}, + GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, nil }, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "GenInput") + }) + + t.Run("missing GetAgent", func(t *testing.T) { + _, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{}, + GenInput: func(_ context.Context, _ string) (*AgentInput, error) { return nil, nil }, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "GetAgent") + }) + + t.Run("valid config", func(t *testing.T) { + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{}, + GenInput: func(_ context.Context, _ string) (*AgentInput, error) { return nil, nil }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, nil }, + }) + require.NoError(t, err) + assert.NotNil(t, loop) + }) +} + +// --------------------------------------------------------------------------- +// Tests — non-preemptive (queued) behavior +// --------------------------------------------------------------------------- + +func TestTurnLoop_NormalLoop(t *testing.T) { + agent := &turnLoopMockAgent{ + name: "test-agent", + events: []*AgentEvent{ + {Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("hello", nil)}}}, + }, + } + + var receivedEvents []*AgentEvent + var receivedItems []string + + source := &turnLoopMockSource{ + items: []string{"msg1", "msg2", "msg3"}, + err: context.DeadlineExceeded, + } + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: source, + GenInput: func(_ context.Context, item string) (*AgentInput, error) { + receivedItems = append(receivedItems, item) + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + OnAgentEvent: func(_ context.Context, event *AgentEvent) error { + receivedEvents = append(receivedEvents, event) + return nil + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.Equal(t, []string{"msg1", "msg2", "msg3"}, receivedItems) + assert.Len(t, receivedEvents, 3) +} + +func TestTurnLoop_SourceError(t *testing.T) { + sourceErr := errors.New("source failure") + source := &turnLoopMockSource{ + items: nil, + err: sourceErr, + } + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: source, + GenInput: func(_ context.Context, _ string) (*AgentInput, error) { + return &AgentInput{}, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return &turnLoopMockAgent{name: "a"}, nil + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, sourceErr) +} + +func TestTurnLoop_GenInputError(t *testing.T) { + genErr := errors.New("gen input failure") + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{items: []string{"msg1"}, err: errors.New("should not reach")}, + GenInput: func(_ context.Context, _ string) (*AgentInput, error) { + return nil, genErr + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return &turnLoopMockAgent{name: "a"}, nil + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, genErr) + assert.Contains(t, err.Error(), "failed to generate agent input") +} + +func TestTurnLoop_GetAgentError(t *testing.T) { + agentErr := errors.New("get agent failure") + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{items: []string{"msg1"}, err: errors.New("should not reach")}, + GenInput: func(_ context.Context, _ string) (*AgentInput, error) { + return &AgentInput{}, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return nil, agentErr + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, agentErr) + assert.Contains(t, err.Error(), "failed to get agent") +} + +func TestTurnLoop_OnAgentEventError(t *testing.T) { + eventErr := errors.New("event handler failure") + agent := &turnLoopMockAgent{ + name: "test-agent", + events: []*AgentEvent{{Output: &AgentOutput{}}}, + } + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{items: []string{"msg1"}, err: errors.New("should not reach")}, + GenInput: func(_ context.Context, _ string) (*AgentInput, error) { + return &AgentInput{}, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + OnAgentEvent: func(_ context.Context, _ *AgentEvent) error { + return eventErr + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, eventErr) + assert.Contains(t, err.Error(), "OnAgentEvent failed") +} + +func TestTurnLoop_ContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + callCount := 0 + source := &turnLoopFuncSource[string]{fn: func(ctx context.Context, _ time.Duration) (string, ConsumeOption, error) { + callCount++ + if callCount > 1 { + cancel() + return "", NonPreemptiveConsumeOption, ctx.Err() + } + return "msg1", NonPreemptiveConsumeOption, nil + }} + + agent := &turnLoopMockAgent{ + name: "test-agent", + events: []*AgentEvent{{Output: &AgentOutput{}}}, + } + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: source, + GenInput: func(_ context.Context, item string) (*AgentInput, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + }) + require.NoError(t, err) + + err = loop.Run(ctx) + assert.ErrorIs(t, err, context.Canceled) +} + +func TestTurnLoop_MultipleEventsPerTurn(t *testing.T) { + agent := &turnLoopMockAgent{ + name: "multi-event-agent", + events: []*AgentEvent{ + {Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("event1", nil)}}}, + {Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("event2", nil)}}}, + {Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("event3", nil)}}}, + }, + } + + var eventCount int + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, + GenInput: func(_ context.Context, item string) (*AgentInput, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + OnAgentEvent: func(_ context.Context, _ *AgentEvent) error { + eventCount++ + return nil + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.Equal(t, 3, eventCount) +} + +func TestTurnLoop_NoOnAgentEvent(t *testing.T) { + agent := &turnLoopMockAgent{ + name: "test-agent", + events: []*AgentEvent{{Output: &AgentOutput{}}}, + } + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, + GenInput: func(_ context.Context, item string) (*AgentInput, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, context.DeadlineExceeded) +} + +func TestTurnLoop_AgentErrorEvent(t *testing.T) { + agentErr := errors.New("agent internal error") + agent := &turnLoopMockAgent{ + name: "error-agent", + events: []*AgentEvent{{Err: agentErr}}, + } + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, + GenInput: func(_ context.Context, item string) (*AgentInput, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, agentErr) + assert.Contains(t, err.Error(), "agent run failed") +} + +// --------------------------------------------------------------------------- +// Tests — preemptive behavior +// --------------------------------------------------------------------------- + +func TestTurnLoop_PreemptiveCancellation(t *testing.T) { + slowAgent := &turnLoopCancellableAgent{ + name: "slow-agent", + startedCh: make(chan struct{}), + cancelCh: make(chan struct{}), + } + + fastAgent := &turnLoopMockAgent{ + name: "fast-agent", + events: []*AgentEvent{{Output: &AgentOutput{}}}, + } + + var processedItems []string + callCount := 0 + source := &turnLoopFuncSource[string]{fn: func(_ context.Context, _ time.Duration) (string, ConsumeOption, error) { + callCount++ + switch callCount { + case 1: + return "slow-msg", NonPreemptiveConsumeOption, nil + case 2: + <-slowAgent.startedCh + return "preempt-msg", ConsumeOption{Mode: ConsumePreemptive, CancelOption: &CancelOption{Mode: CancelImmediate}}, nil + default: + return "", NonPreemptiveConsumeOption, context.DeadlineExceeded + } + }} + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: source, + GenInput: func(_ context.Context, item string) (*AgentInput, error) { + processedItems = append(processedItems, item) + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + }, + GetAgent: func(_ context.Context, item string) (Agent, error) { + if item == "slow-msg" { + return slowAgent, nil + } + return fastAgent, nil + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.True(t, slowAgent.cancelled.Load(), "slow agent should have been cancelled") + assert.Equal(t, CancelImmediate, slowAgent.cancelledOpt.Mode) + assert.Equal(t, []string{"slow-msg", "preempt-msg"}, processedItems) +} + +func TestTurnLoop_PreemptiveWithCancelMode(t *testing.T) { + slowAgent := &turnLoopCancellableAgent{ + name: "slow-agent", + startedCh: make(chan struct{}), + cancelCh: make(chan struct{}), + } + + fastAgent := &turnLoopMockAgent{ + name: "fast-agent", + events: []*AgentEvent{{Output: &AgentOutput{}}}, + } + + callCount := 0 + source := &turnLoopFuncSource[string]{fn: func(_ context.Context, _ time.Duration) (string, ConsumeOption, error) { + callCount++ + switch callCount { + case 1: + return "slow-msg", NonPreemptiveConsumeOption, nil + case 2: + <-slowAgent.startedCh + return "preempt-msg", ConsumeOption{ + Mode: ConsumePreemptive, + CancelOption: &CancelOption{Mode: CancelAfterChatModel | CancelAfterToolCall}, + }, nil + default: + return "", NonPreemptiveConsumeOption, context.DeadlineExceeded + } + }} + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: source, + GenInput: func(_ context.Context, item string) (*AgentInput, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + }, + GetAgent: func(_ context.Context, item string) (Agent, error) { + if item == "slow-msg" { + return slowAgent, nil + } + return fastAgent, nil + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.True(t, slowAgent.cancelled.Load()) + assert.Equal(t, CancelAfterChatModel|CancelAfterToolCall, slowAgent.cancelledOpt.Mode) +} + +func TestTurnLoop_PreemptiveNonCancellableAgent(t *testing.T) { + agentStarted := make(chan struct{}) + agentContinue := make(chan struct{}) + + blockingAgent := &turnLoopBlockingAgent{ + name: "blocking-agent", + startedCh: agentStarted, + blockCh: agentContinue, + events: []*AgentEvent{{Output: &AgentOutput{}}}, + } + fastAgent := &turnLoopMockAgent{ + name: "fast-agent", + events: []*AgentEvent{{Output: &AgentOutput{}}}, + } + + var processedItems []string + callCount := 0 + source := &turnLoopFuncSource[string]{fn: func(_ context.Context, _ time.Duration) (string, ConsumeOption, error) { + callCount++ + switch callCount { + case 1: + return "blocking-msg", NonPreemptiveConsumeOption, nil + case 2: + <-agentStarted + close(agentContinue) + return "preempt-msg", ConsumeOption{Mode: ConsumePreemptive}, nil + default: + return "", NonPreemptiveConsumeOption, context.DeadlineExceeded + } + }} + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: source, + GenInput: func(_ context.Context, item string) (*AgentInput, error) { + processedItems = append(processedItems, item) + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + }, + GetAgent: func(_ context.Context, item string) (Agent, error) { + if item == "blocking-msg" { + return blockingAgent, nil + } + return fastAgent, nil + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.Equal(t, []string{"blocking-msg", "preempt-msg"}, processedItems) +} From 1e7f8a477f129ccb1d76de7390cbf478973b3448 Mon Sep 17 00:00:00 2001 From: "luohuaqing.2018" Date: Thu, 12 Feb 2026 18:13:32 +0800 Subject: [PATCH 32/65] refactor(adk): improve TurnLoop API signatures - Add inputItem param to OnAgentEvent callback - Move RunOptions from static config field to GenInput return value - Change CancelOption from pointer to value type Change-Id: I5bff76223fedb2b48fc5671e90d635a9479b20f6 Co-Authored-By: Claude Opus 4.6 --- adk/interface.go | 8 ++--- adk/turn_loop.go | 26 ++++++++--------- adk/turn_loop_test.go | 68 +++++++++++++++++++++---------------------- 3 files changed, 50 insertions(+), 52 deletions(-) diff --git a/adk/interface.go b/adk/interface.go index fef2f98fe..28ba857cb 100644 --- a/adk/interface.go +++ b/adk/interface.go @@ -301,10 +301,10 @@ type Cancellable interface { // Cancel signals the agent to stop, either immediately or after reaching the // specified execution point(s) defined by opt.Mode. // - // The opt parameter must be non-nil. Use opt.Mode to control when the - // cancellation takes effect (e.g., CancelImmediate, CancelAfterChatModel, - // CancelAfterToolCall, or a combination via bitwise OR). + // Use opt.Mode to control when the cancellation takes effect + // (e.g., CancelImmediate, CancelAfterChatModel, CancelAfterToolCall, + // or a combination via bitwise OR). The zero value defaults to CancelImmediate. // // If the agent has already finished execution, Cancel returns ErrAgentFinished. - Cancel(ctx context.Context, opt *CancelOption) error + Cancel(ctx context.Context, opt CancelOption) error } diff --git a/adk/turn_loop.go b/adk/turn_loop.go index d6ac24071..a8f8d4c4f 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -46,8 +46,8 @@ type ConsumeOption struct { Mode ConsumeMode // CancelOption specifies when and how the running agent should be canceled. // Only meaningful when Mode is ConsumePreemptive and the agent - // implements Cancellable. Default nil value means CancelImmediate. - CancelOption *CancelOption + // implements Cancellable. Default zero value means CancelImmediate. + CancelOption CancelOption } // NonPreemptiveConsumeOption is a convenience value for the common @@ -68,14 +68,14 @@ type MessageSource[T any] interface { type TurnLoopConfig[T any] struct { // Source provides messages to drive the loop. Required. Source MessageSource[T] - // GenInput converts a received message into AgentInput. Required. - GenInput func(ctx context.Context, item T) (*AgentInput, error) + // GenInput converts a received message into AgentInput and optional + // RunOptions for the agent. Required. + GenInput func(ctx context.Context, item T) (*AgentInput, []AgentRunOption, error) // GetAgent returns the Agent to run for a given message. Required. GetAgent func(ctx context.Context, item T) (Agent, error) // OnAgentEvent is called for each event emitted by the agent. Optional. - OnAgentEvent func(ctx context.Context, event *AgentEvent) error - // RunOptions are passed to Agent.Run on each turn. Optional. - RunOptions []AgentRunOption + // The inputItem is the message that triggered the current agent turn. + OnAgentEvent func(ctx context.Context, inputItem T, event *AgentEvent) error // ReceiveTimeout is the timeout passed to Source.Receive on each iteration. // Zero means no timeout. Optional. ReceiveTimeout time.Duration @@ -87,10 +87,9 @@ type TurnLoopConfig[T any] struct { // agent implements Cancellable. type TurnLoop[T any] struct { source MessageSource[T] - genInput func(ctx context.Context, item T) (*AgentInput, error) + genInput func(ctx context.Context, item T) (*AgentInput, []AgentRunOption, error) getAgent func(ctx context.Context, item T) (Agent, error) - onAgentEvent func(ctx context.Context, event *AgentEvent) error - runOptions []AgentRunOption + onAgentEvent func(ctx context.Context, inputItem T, event *AgentEvent) error receiveTimeout time.Duration } @@ -112,7 +111,6 @@ func NewTurnLoop[T any](config TurnLoopConfig[T]) (*TurnLoop[T], error) { genInput: config.GenInput, getAgent: config.GetAgent, onAgentEvent: config.OnAgentEvent, - runOptions: config.RunOptions, receiveTimeout: config.ReceiveTimeout, }, nil } @@ -148,7 +146,7 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { } for { - input, e := l.genInput(ctx, item) + input, runOpts, e := l.genInput(ctx, item) if e != nil { return fmt.Errorf("failed to generate agent input: %w", e) } @@ -168,7 +166,7 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { // Run the agent and forward events through a channel so we can // select between agent events and incoming messages. - iter := agent.Run(ctx, input, l.runOptions...) + iter := agent.Run(ctx, input, runOpts...) eventCh := make(chan iterResult, 1) go func() { for { @@ -199,7 +197,7 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { break eventLoop } if l.onAgentEvent != nil { - if e_ := l.onAgentEvent(ctx, ev.event); e_ != nil { + if e_ := l.onAgentEvent(ctx, item, ev.event); e_ != nil { turnErr = fmt.Errorf("OnAgentEvent failed: %w", e_) break eventLoop } diff --git a/adk/turn_loop_test.go b/adk/turn_loop_test.go index 9ef2d1a90..d27ff4fb8 100644 --- a/adk/turn_loop_test.go +++ b/adk/turn_loop_test.go @@ -84,7 +84,7 @@ type turnLoopCancellableAgent struct { startedCh chan struct{} // closed when Run is entered cancelCh chan struct{} // closed by Cancel cancelled atomic.Bool - cancelledOpt *CancelOption // records the CancelOption passed to Cancel + cancelledOpt CancelOption // records the CancelOption passed to Cancel } func (a *turnLoopCancellableAgent) Name(_ context.Context) string { return a.name } @@ -99,7 +99,7 @@ func (a *turnLoopCancellableAgent) Run(_ context.Context, _ *AgentInput, _ ...Ag return iter } -func (a *turnLoopCancellableAgent) Cancel(_ context.Context, opt *CancelOption) error { +func (a *turnLoopCancellableAgent) Cancel(_ context.Context, opt CancelOption) error { a.cancelled.Store(true) a.cancelledOpt = opt close(a.cancelCh) @@ -137,7 +137,7 @@ func (a *turnLoopBlockingAgent) Run(_ context.Context, _ *AgentInput, _ ...Agent func TestNewTurnLoop_Validation(t *testing.T) { t.Run("missing source", func(t *testing.T) { _, err := NewTurnLoop(TurnLoopConfig[string]{ - GenInput: func(_ context.Context, _ string) (*AgentInput, error) { return nil, nil }, + GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { return nil, nil, nil }, GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, nil }, }) require.Error(t, err) @@ -156,7 +156,7 @@ func TestNewTurnLoop_Validation(t *testing.T) { t.Run("missing GetAgent", func(t *testing.T) { _, err := NewTurnLoop(TurnLoopConfig[string]{ Source: &turnLoopMockSource{}, - GenInput: func(_ context.Context, _ string) (*AgentInput, error) { return nil, nil }, + GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { return nil, nil, nil }, }) require.Error(t, err) assert.Contains(t, err.Error(), "GetAgent") @@ -165,7 +165,7 @@ func TestNewTurnLoop_Validation(t *testing.T) { t.Run("valid config", func(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: &turnLoopMockSource{}, - GenInput: func(_ context.Context, _ string) (*AgentInput, error) { return nil, nil }, + GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { return nil, nil, nil }, GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, nil }, }) require.NoError(t, err) @@ -195,14 +195,14 @@ func TestTurnLoop_NormalLoop(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, error) { + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { receivedItems = append(receivedItems, item) - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil }, GetAgent: func(_ context.Context, _ string) (Agent, error) { return agent, nil }, - OnAgentEvent: func(_ context.Context, event *AgentEvent) error { + OnAgentEvent: func(_ context.Context, _ string, event *AgentEvent) error { receivedEvents = append(receivedEvents, event) return nil }, @@ -224,8 +224,8 @@ func TestTurnLoop_SourceError(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: source, - GenInput: func(_ context.Context, _ string) (*AgentInput, error) { - return &AgentInput{}, nil + GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{}, nil, nil }, GetAgent: func(_ context.Context, _ string) (Agent, error) { return &turnLoopMockAgent{name: "a"}, nil @@ -242,8 +242,8 @@ func TestTurnLoop_GenInputError(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: &turnLoopMockSource{items: []string{"msg1"}, err: errors.New("should not reach")}, - GenInput: func(_ context.Context, _ string) (*AgentInput, error) { - return nil, genErr + GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { + return nil, nil, genErr }, GetAgent: func(_ context.Context, _ string) (Agent, error) { return &turnLoopMockAgent{name: "a"}, nil @@ -261,8 +261,8 @@ func TestTurnLoop_GetAgentError(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: &turnLoopMockSource{items: []string{"msg1"}, err: errors.New("should not reach")}, - GenInput: func(_ context.Context, _ string) (*AgentInput, error) { - return &AgentInput{}, nil + GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{}, nil, nil }, GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, agentErr @@ -284,13 +284,13 @@ func TestTurnLoop_OnAgentEventError(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: &turnLoopMockSource{items: []string{"msg1"}, err: errors.New("should not reach")}, - GenInput: func(_ context.Context, _ string) (*AgentInput, error) { - return &AgentInput{}, nil + GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{}, nil, nil }, GetAgent: func(_ context.Context, _ string) (Agent, error) { return agent, nil }, - OnAgentEvent: func(_ context.Context, _ *AgentEvent) error { + OnAgentEvent: func(_ context.Context, _ string, _ *AgentEvent) error { return eventErr }, }) @@ -321,8 +321,8 @@ func TestTurnLoop_ContextCancellation(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil }, GetAgent: func(_ context.Context, _ string) (Agent, error) { return agent, nil @@ -348,13 +348,13 @@ func TestTurnLoop_MultipleEventsPerTurn(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, - GenInput: func(_ context.Context, item string) (*AgentInput, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil }, GetAgent: func(_ context.Context, _ string) (Agent, error) { return agent, nil }, - OnAgentEvent: func(_ context.Context, _ *AgentEvent) error { + OnAgentEvent: func(_ context.Context, _ string, _ *AgentEvent) error { eventCount++ return nil }, @@ -374,8 +374,8 @@ func TestTurnLoop_NoOnAgentEvent(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, - GenInput: func(_ context.Context, item string) (*AgentInput, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil }, GetAgent: func(_ context.Context, _ string) (Agent, error) { return agent, nil @@ -396,8 +396,8 @@ func TestTurnLoop_AgentErrorEvent(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, - GenInput: func(_ context.Context, item string) (*AgentInput, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil }, GetAgent: func(_ context.Context, _ string) (Agent, error) { return agent, nil @@ -435,7 +435,7 @@ func TestTurnLoop_PreemptiveCancellation(t *testing.T) { return "slow-msg", NonPreemptiveConsumeOption, nil case 2: <-slowAgent.startedCh - return "preempt-msg", ConsumeOption{Mode: ConsumePreemptive, CancelOption: &CancelOption{Mode: CancelImmediate}}, nil + return "preempt-msg", ConsumeOption{Mode: ConsumePreemptive, CancelOption: CancelOption{Mode: CancelImmediate}}, nil default: return "", NonPreemptiveConsumeOption, context.DeadlineExceeded } @@ -443,9 +443,9 @@ func TestTurnLoop_PreemptiveCancellation(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, error) { + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { processedItems = append(processedItems, item) - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil }, GetAgent: func(_ context.Context, item string) (Agent, error) { if item == "slow-msg" { @@ -485,7 +485,7 @@ func TestTurnLoop_PreemptiveWithCancelMode(t *testing.T) { <-slowAgent.startedCh return "preempt-msg", ConsumeOption{ Mode: ConsumePreemptive, - CancelOption: &CancelOption{Mode: CancelAfterChatModel | CancelAfterToolCall}, + CancelOption: CancelOption{Mode: CancelAfterChatModel | CancelAfterToolCall}, }, nil default: return "", NonPreemptiveConsumeOption, context.DeadlineExceeded @@ -494,8 +494,8 @@ func TestTurnLoop_PreemptiveWithCancelMode(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil }, GetAgent: func(_ context.Context, item string) (Agent, error) { if item == "slow-msg" { @@ -545,9 +545,9 @@ func TestTurnLoop_PreemptiveNonCancellableAgent(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, error) { + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { processedItems = append(processedItems, item) - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil }, GetAgent: func(_ context.Context, item string) (Agent, error) { if item == "blocking-msg" { From 72d65cf508ac2847ae6ff5150cd3c26053f2267f Mon Sep 17 00:00:00 2001 From: "luohuaqing.2018" Date: Fri, 13 Feb 2026 12:42:59 +0800 Subject: [PATCH 33/65] fix(adk): TurnLoop.Run Change-Id: Iaa0da27dcab3fdd446828a592142a4504278daf7 --- adk/turn_loop.go | 178 +++++++++++++++++++----------------------- adk/turn_loop_test.go | 52 +++--------- 2 files changed, 95 insertions(+), 135 deletions(-) diff --git a/adk/turn_loop.go b/adk/turn_loop.go index a8f8d4c4f..46100e1ee 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -19,7 +19,10 @@ package adk import ( "context" "fmt" + `runtime/debug` "time" + + `github.com/cloudwego/eino/internal/safe` ) // ConsumeMode specifies how a received message should be consumed @@ -115,34 +118,20 @@ func NewTurnLoop[T any](config TurnLoopConfig[T]) (*TurnLoop[T], error) { }, nil } -// recvResult holds the result of a concurrent Receive call. -type recvResult[T any] struct { - item T - option ConsumeOption - err error -} - -// iterResult holds the result of a single AsyncIterator.Next call. -type iterResult struct { - event *AgentEvent - ok bool -} - -// Run starts the blocking loop that continuously receives messages, runs -// agents, and dispatches events. While an agent is running, the next message -// is received concurrently. If that message's ConsumeOption has ConsumePreemptive -// mode and the running agent implements Cancellable, the agent is canceled -// (using the CancelMode from the option) and the new message is processed -// immediately. +// Run starts the blocking loop that continuously receives messages from the +// source, runs the agent returned by GetAgent for each message, and dispatches +// resulting events to OnAgentEvent. It blocks until the source returns an error +// (including context cancellation) or a callback fails. +// +// If a received message has ConsumePreemptive mode and the current agent +// implements Cancellable, the agent is canceled and the new message is processed +// immediately. If the agent does not implement Cancellable, preemptive messages +// are queued and processed after the current agent finishes. func (l *TurnLoop[T]) Run(ctx context.Context) error { - // done is closed when Run returns, signaling background goroutines to exit. - done := make(chan struct{}) - defer close(done) - - // Initial blocking receive — no agent running yet, mode is irrelevant. - item, _, err := l.source.Receive(ctx, l.receiveTimeout) + // Initial blocking receive — no agent is running yet. + item, option, err := l.source.Receive(ctx, l.receiveTimeout) if err != nil { - return err + return fmt.Errorf("failed to receive message: %w", err) } for { @@ -156,94 +145,91 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { return fmt.Errorf("failed to get agent: %w", e) } - // Start receiving the next message concurrently with agent execution. - // The channel is buffered so the goroutine never blocks on send. - recvCh := make(chan recvResult[T], 1) - go func() { - i, opt, e_ := l.source.Receive(ctx, l.receiveTimeout) - recvCh <- recvResult[T]{i, opt, e_} - }() - - // Run the agent and forward events through a channel so we can - // select between agent events and incoming messages. + ca, isAgentCancellable := agent.(Cancellable) iter := agent.Run(ctx, input, runOpts...) - eventCh := make(chan iterResult, 1) - go func() { + + // handleEvents drains the agent iterator, forwarding each event to the + // OnAgentEvent callback. It is called directly in the non-cancellable + // path and from a goroutine in the cancellable path. + handleEvents := func() error { for { event, ok := iter.Next() - select { - case eventCh <- iterResult{event, ok}: - case <-done: - return - } if !ok { - return + break } - } - }() - - var pending *recvResult[T] - var turnErr error - eventLoop: - for { - select { - case ev := <-eventCh: - if !ev.ok { - break eventLoop - } - if ev.event.Err != nil { - turnErr = fmt.Errorf("agent run failed: %w", ev.event.Err) - break eventLoop + if event.Err != nil { + return fmt.Errorf("agent run failed: %w", event.Err) } + if l.onAgentEvent != nil { - if e_ := l.onAgentEvent(ctx, item, ev.event); e_ != nil { - turnErr = fmt.Errorf("OnAgentEvent failed: %w", e_) - break eventLoop + e = l.onAgentEvent(ctx, item, event) + if e != nil { + return fmt.Errorf("OnAgentEvent callback failed: %w", e) } } + } - case recv := <-recvCh: - recvCh = nil // nil channel never matches in select - if recv.option.Mode == ConsumePreemptive && recv.err == nil { - if ca, ok := agent.(Cancellable); ok { - if e_ := ca.Cancel(ctx, recv.option.CancelOption); e_ != nil { - return fmt.Errorf("failed to cancel agent: %w", e_) - } - // Drain remaining events after cancellation. - for { - ev := <-eventCh - if !ev.ok { - break - } - } - pending = &recvResult[T]{item: recv.item} - break eventLoop + return nil + } + + var handleEventErr error + if isAgentCancellable { + // Cancellable path: consume events in a goroutine so the main + // goroutine can block on Receive concurrently. + done := make(chan struct{}) + + go func() { + defer func() { + // Recover panics from the iterator or callback so they + // don't crash the process; surface them as errors instead. + panicErr := recover() + if panicErr != nil { + handleEventErr = safe.NewPanicErr(panicErr, debug.Stack()) } - } - // Non-preemptive, preemptive but not Cancellable, or - // source error: buffer and let the eventLoop finish - // processing the current agent's events first. - pending = &recvResult[T]{item: recv.item, err: recv.err} + + close(done) + }() + + handleEventErr = handleEvents() + }() + + // Block on the next message while events are being consumed above. + item, option, err = l.source.Receive(ctx, l.receiveTimeout) + if err != nil { + <-done // wait for the event goroutine before returning + return fmt.Errorf("failed to receive message: %w", err) } - } - if turnErr != nil { - return turnErr - } + // If the new message requests preemption, cancel the running agent. + // Cancel triggers the iterator to terminate, which unblocks the + // event goroutine above. + if option.Mode == ConsumePreemptive { + err = ca.Cancel(ctx, option.CancelOption) + if err != nil { + <-done // wait for the event goroutine before returning + return fmt.Errorf("failed to cancel agent: %w", err) + } + } - if pending != nil { - if pending.err != nil { - return pending.err + // Wait for event consumption to finish (normal completion or + // post-cancel drain) before starting the next turn. + <-done + if handleEventErr != nil { + return fmt.Errorf("failed to handle events: %w", handleEventErr) } - item = pending.item } else { - // Agent finished before the next message arrived; wait for it. - recv := <-recvCh - if recv.err != nil { - return recv.err + // Non-cancellable path: consume all events sequentially, then + // block on the next message. + if handleEventErr = handleEvents(); handleEventErr != nil { + return fmt.Errorf("failed to handle events: %w", handleEventErr) + } + + item, option, err = l.source.Receive(ctx, l.receiveTimeout) + if err != nil { + return fmt.Errorf("failed to receive message: %w", err) } - item = recv.item } } } + diff --git a/adk/turn_loop_test.go b/adk/turn_loop_test.go index d27ff4fb8..f66c1298b 100644 --- a/adk/turn_loop_test.go +++ b/adk/turn_loop_test.go @@ -106,30 +106,6 @@ func (a *turnLoopCancellableAgent) Cancel(_ context.Context, opt CancelOption) e return nil } -// turnLoopBlockingAgent blocks until blockCh is closed, then emits its events. -// It does NOT implement Cancellable. -type turnLoopBlockingAgent struct { - name string - startedCh chan struct{} - blockCh chan struct{} - events []*AgentEvent -} - -func (a *turnLoopBlockingAgent) Name(_ context.Context) string { return a.name } -func (a *turnLoopBlockingAgent) Description(_ context.Context) string { return "blocking mock" } -func (a *turnLoopBlockingAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { - iter, gen := NewAsyncIteratorPair[*AgentEvent]() - close(a.startedCh) - go func() { - defer gen.Close() - <-a.blockCh - for _, e := range a.events { - gen.Send(e) - } - }() - return iter -} - // --------------------------------------------------------------------------- // Tests — validation // --------------------------------------------------------------------------- @@ -298,7 +274,7 @@ func TestTurnLoop_OnAgentEventError(t *testing.T) { err = loop.Run(context.Background()) assert.ErrorIs(t, err, eventErr) - assert.Contains(t, err.Error(), "OnAgentEvent failed") + assert.Contains(t, err.Error(), "OnAgentEvent callback failed") } func TestTurnLoop_ContextCancellation(t *testing.T) { @@ -513,14 +489,12 @@ func TestTurnLoop_PreemptiveWithCancelMode(t *testing.T) { } func TestTurnLoop_PreemptiveNonCancellableAgent(t *testing.T) { - agentStarted := make(chan struct{}) - agentContinue := make(chan struct{}) - - blockingAgent := &turnLoopBlockingAgent{ - name: "blocking-agent", - startedCh: agentStarted, - blockCh: agentContinue, - events: []*AgentEvent{{Output: &AgentOutput{}}}, + // A non-cancellable agent cannot be preempted, so the new Run processes + // events sequentially before calling Receive. The preemptive message is + // effectively queued and processed in the next turn. + nonCancellableAgent := &turnLoopMockAgent{ + name: "non-cancellable-agent", + events: []*AgentEvent{{Output: &AgentOutput{}}}, } fastAgent := &turnLoopMockAgent{ name: "fast-agent", @@ -533,10 +507,10 @@ func TestTurnLoop_PreemptiveNonCancellableAgent(t *testing.T) { callCount++ switch callCount { case 1: - return "blocking-msg", NonPreemptiveConsumeOption, nil + return "non-cancel-msg", NonPreemptiveConsumeOption, nil case 2: - <-agentStarted - close(agentContinue) + // Even though Mode is ConsumePreemptive, the agent doesn't + // implement Cancellable, so it's treated as non-preemptive. return "preempt-msg", ConsumeOption{Mode: ConsumePreemptive}, nil default: return "", NonPreemptiveConsumeOption, context.DeadlineExceeded @@ -550,8 +524,8 @@ func TestTurnLoop_PreemptiveNonCancellableAgent(t *testing.T) { return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil }, GetAgent: func(_ context.Context, item string) (Agent, error) { - if item == "blocking-msg" { - return blockingAgent, nil + if item == "non-cancel-msg" { + return nonCancellableAgent, nil } return fastAgent, nil }, @@ -560,5 +534,5 @@ func TestTurnLoop_PreemptiveNonCancellableAgent(t *testing.T) { err = loop.Run(context.Background()) assert.ErrorIs(t, err, context.DeadlineExceeded) - assert.Equal(t, []string{"blocking-msg", "preempt-msg"}, processedItems) + assert.Equal(t, []string{"non-cancel-msg", "preempt-msg"}, processedItems) } From b982edd7eed9f13bcc50bb31e674394ed8c0b56d Mon Sep 17 00:00:00 2001 From: Megumin Date: Sat, 14 Feb 2026 11:19:24 +0800 Subject: [PATCH 34/65] chore --- adk/interface.go | 41 ++++++++++++------- adk/turn_loop.go | 104 ++++++++++++++++++++++++++++++++--------------- 2 files changed, 97 insertions(+), 48 deletions(-) diff --git a/adk/interface.go b/adk/interface.go index 28ba857cb..2e7e3d09c 100644 --- a/adk/interface.go +++ b/adk/interface.go @@ -23,6 +23,7 @@ import ( "errors" "fmt" "io" + "time" "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/internal/core" @@ -290,21 +291,31 @@ const ( // ErrAgentFinished is returned by Cancel when the agent has already finished execution. var ErrAgentFinished = errors.New("agent has already finished execution") -// CancelOption holds options for cancelling an agent. -type CancelOption struct { - Mode CancelMode +type cancelConfig struct { + Mode CancelMode + Timeout *time.Duration } -// Cancellable is an optional interface that an Agent can implement to support -// cancellation during execution. -type Cancellable interface { - // Cancel signals the agent to stop, either immediately or after reaching the - // specified execution point(s) defined by opt.Mode. - // - // Use opt.Mode to control when the cancellation takes effect - // (e.g., CancelImmediate, CancelAfterChatModel, CancelAfterToolCall, - // or a combination via bitwise OR). The zero value defaults to CancelImmediate. - // - // If the agent has already finished execution, Cancel returns ErrAgentFinished. - Cancel(ctx context.Context, opt CancelOption) error +type CancelOption func(*cancelConfig) + +func WithCancelMode(mode CancelMode) CancelOption { + return func(config *cancelConfig) { + config.Mode = mode + } +} + +func WithCancelTimeout(timeout time.Duration) CancelOption { + return func(config *cancelConfig) { + config.Timeout = &timeout + } +} + +type CancelFunc func(context.Context, ...CancelOption) error + +type CancellableRun interface { + RunWithCancel(ctx context.Context, input *AgentInput, options ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) +} + +type CancellableResume interface { + ResumeWithCancel(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) } diff --git a/adk/turn_loop.go b/adk/turn_loop.go index 46100e1ee..164bb44ed 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -19,10 +19,10 @@ package adk import ( "context" "fmt" - `runtime/debug` + "runtime/debug" "time" - `github.com/cloudwego/eino/internal/safe` + "github.com/cloudwego/eino/internal/safe" ) // ConsumeMode specifies how a received message should be consumed @@ -38,33 +38,58 @@ const ( // If the agent does not implement Cancellable, the message is // buffered and processed after the agent finishes. ConsumePreemptive + + ConsumePreemptiveOnTimeout ) -// ConsumeOption describes how a received message should be consumed. -// It combines ConsumeMode (preemptive vs queued) with CancelMode -// (how to cancel the running agent when preempting). -type ConsumeOption struct { - // Mode specifies whether the message should preempt the current agent - // or be queued. Default zero value is ConsumeNonPreemptive. - Mode ConsumeMode - // CancelOption specifies when and how the running agent should be canceled. - // Only meaningful when Mode is ConsumePreemptive and the agent - // implements Cancellable. Default zero value means CancelImmediate. - CancelOption CancelOption +type consumeConfig struct { + Mode ConsumeMode + Timeout time.Duration + CancelOpts []CancelOption +} + +type ConsumeOption func(*consumeConfig) + +func WithPreemptive() ConsumeOption { + return func(config *consumeConfig) { + config.Mode = ConsumePreemptive + } +} + +func WithPreemptiveOnTimeout(timeout time.Duration) ConsumeOption { + return func(config *consumeConfig) { + config.Mode = ConsumePreemptive + config.Timeout = timeout + } +} + +func WithCancelOptions(opts ...CancelOption) ConsumeOption { + return func(config *consumeConfig) { + config.CancelOpts = append(config.CancelOpts, opts...) + } } -// NonPreemptiveConsumeOption is a convenience value for the common -// non-preemptive (queued) case. -var NonPreemptiveConsumeOption = ConsumeOption{Mode: ConsumeNonPreemptive} +type receiveConfig struct { + Timeout *time.Duration + NonBlocking bool +} + +type ReceiveOption func(*receiveConfig) + +func WithReceiveTimeout(timeout time.Duration) ReceiveOption { + return func(config *receiveConfig) { + config.Timeout = &timeout + } +} + +func WithReceiveNonBlocking() ReceiveOption { + return func(config *receiveConfig) { + config.NonBlocking = true + } +} -// MessageSource is an interface for pulling typed messages from an external source. -// Receive blocks until a message is available or an error occurs. -// The timeout parameter specifies the maximum duration to wait for a message. -// The returned ConsumeOption indicates whether the message should preempt the -// currently running agent (and how to cancel it) or be queued for processing -// after it finishes. type MessageSource[T any] interface { - Receive(ctx context.Context, timeout time.Duration) (T, ConsumeOption, error) + Receive(ctx context.Context, option ...ReceiveOption) (context.Context, T, []ConsumeOption, error) } // TurnLoopConfig is the configuration for creating a TurnLoop. @@ -129,24 +154,29 @@ func NewTurnLoop[T any](config TurnLoopConfig[T]) (*TurnLoop[T], error) { // are queued and processed after the current agent finishes. func (l *TurnLoop[T]) Run(ctx context.Context) error { // Initial blocking receive — no agent is running yet. - item, option, err := l.source.Receive(ctx, l.receiveTimeout) + nCtx, item, option, err := l.source.Receive(ctx, WithReceiveTimeout(l.receiveTimeout)) if err != nil { return fmt.Errorf("failed to receive message: %w", err) } for { - input, runOpts, e := l.genInput(ctx, item) + input, runOpts, e := l.genInput(nCtx, item) if e != nil { return fmt.Errorf("failed to generate agent input: %w", e) } - agent, e := l.getAgent(ctx, item) + agent, e := l.getAgent(nCtx, item) if e != nil { return fmt.Errorf("failed to get agent: %w", e) } - ca, isAgentCancellable := agent.(Cancellable) - iter := agent.Run(ctx, input, runOpts...) + var cancelFunc CancelFunc + var iter *AsyncIterator[*AgentEvent] + if ca, isAgentCancellable := agent.(CancellableRun); isAgentCancellable { + iter, cancelFunc = ca.RunWithCancel(nCtx, input, runOpts...) + } else { + iter = agent.Run(nCtx, input, runOpts...) + } // handleEvents drains the agent iterator, forwarding each event to the // OnAgentEvent callback. It is called directly in the non-cancellable @@ -174,7 +204,7 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { } var handleEventErr error - if isAgentCancellable { + if cancelFunc != nil { // Cancellable path: consume events in a goroutine so the main // goroutine can block on Receive concurrently. done := make(chan struct{}) @@ -195,7 +225,7 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { }() // Block on the next message while events are being consumed above. - item, option, err = l.source.Receive(ctx, l.receiveTimeout) + nCtx, item, option, err = l.source.Receive(ctx, WithReceiveTimeout(l.receiveTimeout)) if err != nil { <-done // wait for the event goroutine before returning return fmt.Errorf("failed to receive message: %w", err) @@ -204,8 +234,9 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { // If the new message requests preemption, cancel the running agent. // Cancel triggers the iterator to terminate, which unblocks the // event goroutine above. - if option.Mode == ConsumePreemptive { - err = ca.Cancel(ctx, option.CancelOption) + o := applyConsumeOptions(option) + if o.Mode == ConsumePreemptive { + err = cancelFunc(ctx, o.CancelOpts...) if err != nil { <-done // wait for the event goroutine before returning return fmt.Errorf("failed to cancel agent: %w", err) @@ -225,7 +256,7 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { return fmt.Errorf("failed to handle events: %w", handleEventErr) } - item, option, err = l.source.Receive(ctx, l.receiveTimeout) + nCtx, item, option, err = l.source.Receive(ctx, WithReceiveTimeout(l.receiveTimeout)) if err != nil { return fmt.Errorf("failed to receive message: %w", err) } @@ -233,3 +264,10 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { } } +func applyConsumeOptions(opts []ConsumeOption) *consumeConfig { + var config consumeConfig + for _, opt := range opts { + opt(&config) + } + return &config +} From 729759a245f9a1abac3a54def5f513a418cdde17 Mon Sep 17 00:00:00 2001 From: Megumin Date: Sat, 14 Feb 2026 11:33:38 +0800 Subject: [PATCH 35/65] chore --- adk/turn_loop.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/adk/turn_loop.go b/adk/turn_loop.go index 164bb44ed..376cd667d 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -225,7 +225,7 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { }() // Block on the next message while events are being consumed above. - nCtx, item, option, err = l.source.Receive(ctx, WithReceiveTimeout(l.receiveTimeout)) + nCtx, item, option, err = l.source.Receive(nCtx, WithReceiveTimeout(l.receiveTimeout)) if err != nil { <-done // wait for the event goroutine before returning return fmt.Errorf("failed to receive message: %w", err) @@ -236,7 +236,7 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { // event goroutine above. o := applyConsumeOptions(option) if o.Mode == ConsumePreemptive { - err = cancelFunc(ctx, o.CancelOpts...) + err = cancelFunc(nCtx, o.CancelOpts...) if err != nil { <-done // wait for the event goroutine before returning return fmt.Errorf("failed to cancel agent: %w", err) @@ -256,7 +256,7 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { return fmt.Errorf("failed to handle events: %w", handleEventErr) } - nCtx, item, option, err = l.source.Receive(ctx, WithReceiveTimeout(l.receiveTimeout)) + nCtx, item, option, err = l.source.Receive(nCtx, WithReceiveTimeout(l.receiveTimeout)) if err != nil { return fmt.Errorf("failed to receive message: %w", err) } From dce1a20a77fffb4a5e6c66a1204d5cd45d6a4f56 Mon Sep 17 00:00:00 2001 From: Megumin Date: Sat, 14 Feb 2026 11:42:05 +0800 Subject: [PATCH 36/65] chore --- adk/turn_loop.go | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/adk/turn_loop.go b/adk/turn_loop.go index 376cd667d..673a8fcea 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -58,7 +58,7 @@ func WithPreemptive() ConsumeOption { func WithPreemptiveOnTimeout(timeout time.Duration) ConsumeOption { return func(config *consumeConfig) { - config.Mode = ConsumePreemptive + config.Mode = ConsumePreemptiveOnTimeout config.Timeout = timeout } } @@ -235,12 +235,23 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { // Cancel triggers the iterator to terminate, which unblocks the // event goroutine above. o := applyConsumeOptions(option) - if o.Mode == ConsumePreemptive { + switch o.Mode { + case ConsumePreemptive: err = cancelFunc(nCtx, o.CancelOpts...) if err != nil { <-done // wait for the event goroutine before returning return fmt.Errorf("failed to cancel agent: %w", err) } + case ConsumePreemptiveOnTimeout: + select { + case <-done: + case <-time.After(o.Timeout): + err = cancelFunc(nCtx, o.CancelOpts...) + if err != nil { + <-done // wait for the event goroutine before returning + return fmt.Errorf("failed to cancel agent: %w", err) + } + } } // Wait for event consumption to finish (normal completion or From ff5148cd3c4d174e6baae53e9d9e639f922212f5 Mon Sep 17 00:00:00 2001 From: Megumin Date: Sat, 14 Feb 2026 16:13:04 +0800 Subject: [PATCH 37/65] chore --- adk/turn_loop.go | 33 ++++++++++++--------------------- 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/adk/turn_loop.go b/adk/turn_loop.go index 673a8fcea..ded163547 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -69,27 +69,12 @@ func WithCancelOptions(opts ...CancelOption) ConsumeOption { } } -type receiveConfig struct { - Timeout *time.Duration - NonBlocking bool -} - -type ReceiveOption func(*receiveConfig) - -func WithReceiveTimeout(timeout time.Duration) ReceiveOption { - return func(config *receiveConfig) { - config.Timeout = &timeout - } -} - -func WithReceiveNonBlocking() ReceiveOption { - return func(config *receiveConfig) { - config.NonBlocking = true - } +type ReceiveConfig struct { + Timeout time.Duration } type MessageSource[T any] interface { - Receive(ctx context.Context, option ...ReceiveOption) (context.Context, T, []ConsumeOption, error) + Receive(context.Context, ReceiveConfig) (context.Context, T, []ConsumeOption, error) } // TurnLoopConfig is the configuration for creating a TurnLoop. @@ -154,7 +139,9 @@ func NewTurnLoop[T any](config TurnLoopConfig[T]) (*TurnLoop[T], error) { // are queued and processed after the current agent finishes. func (l *TurnLoop[T]) Run(ctx context.Context) error { // Initial blocking receive — no agent is running yet. - nCtx, item, option, err := l.source.Receive(ctx, WithReceiveTimeout(l.receiveTimeout)) + nCtx, item, option, err := l.source.Receive(ctx, ReceiveConfig{ + Timeout: l.receiveTimeout, + }) if err != nil { return fmt.Errorf("failed to receive message: %w", err) } @@ -225,7 +212,9 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { }() // Block on the next message while events are being consumed above. - nCtx, item, option, err = l.source.Receive(nCtx, WithReceiveTimeout(l.receiveTimeout)) + nCtx, item, option, err = l.source.Receive(nCtx, ReceiveConfig{ + Timeout: l.receiveTimeout, + }) if err != nil { <-done // wait for the event goroutine before returning return fmt.Errorf("failed to receive message: %w", err) @@ -267,7 +256,9 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { return fmt.Errorf("failed to handle events: %w", handleEventErr) } - nCtx, item, option, err = l.source.Receive(nCtx, WithReceiveTimeout(l.receiveTimeout)) + nCtx, item, option, err = l.source.Receive(nCtx, ReceiveConfig{ + Timeout: l.receiveTimeout, + }) if err != nil { return fmt.Errorf("failed to receive message: %w", err) } From a4acca92f7c0358d3e7cda7e8b77954257b03726 Mon Sep 17 00:00:00 2001 From: Megumin Date: Thu, 19 Feb 2026 18:16:46 +0800 Subject: [PATCH 38/65] feat(adk): modify on agent events (#795) --- adk/turn_loop.go | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/adk/turn_loop.go b/adk/turn_loop.go index ded163547..bc4d9b69c 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -88,7 +88,7 @@ type TurnLoopConfig[T any] struct { GetAgent func(ctx context.Context, item T) (Agent, error) // OnAgentEvent is called for each event emitted by the agent. Optional. // The inputItem is the message that triggered the current agent turn. - OnAgentEvent func(ctx context.Context, inputItem T, event *AgentEvent) error + OnAgentEvents func(ctx context.Context, inputItem T, event *AsyncIterator[*AgentEvent]) error // ReceiveTimeout is the timeout passed to Source.Receive on each iteration. // Zero means no timeout. Optional. ReceiveTimeout time.Duration @@ -102,7 +102,7 @@ type TurnLoop[T any] struct { source MessageSource[T] genInput func(ctx context.Context, item T) (*AgentInput, []AgentRunOption, error) getAgent func(ctx context.Context, item T) (Agent, error) - onAgentEvent func(ctx context.Context, inputItem T, event *AgentEvent) error + onAgentEvents func(ctx context.Context, inputItem T, event *AsyncIterator[*AgentEvent]) error receiveTimeout time.Duration } @@ -123,7 +123,7 @@ func NewTurnLoop[T any](config TurnLoopConfig[T]) (*TurnLoop[T], error) { source: config.Source, genInput: config.GenInput, getAgent: config.GetAgent, - onAgentEvent: config.OnAgentEvent, + onAgentEvents: config.OnAgentEvents, receiveTimeout: config.ReceiveTimeout, }, nil } @@ -169,24 +169,10 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { // OnAgentEvent callback. It is called directly in the non-cancellable // path and from a goroutine in the cancellable path. handleEvents := func() error { - for { - event, ok := iter.Next() - if !ok { - break - } - - if event.Err != nil { - return fmt.Errorf("agent run failed: %w", event.Err) - } - - if l.onAgentEvent != nil { - e = l.onAgentEvent(ctx, item, event) - if e != nil { - return fmt.Errorf("OnAgentEvent callback failed: %w", e) - } - } + oe := l.onAgentEvents(ctx, item, iter) + if oe != nil { + return oe } - return nil } From 680c7263ff2d3ca0d30ae966944e912806ccc357 Mon Sep 17 00:00:00 2001 From: Megumin Date: Sat, 21 Feb 2026 11:45:32 +0800 Subject: [PATCH 39/65] feat(adk): turn loop support front and exit loop (#796) --- adk/turn_loop.go | 42 +++++++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/adk/turn_loop.go b/adk/turn_loop.go index bc4d9b69c..e5cf1c36f 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -18,6 +18,7 @@ package adk import ( "context" + "errors" "fmt" "runtime/debug" "time" @@ -75,6 +76,7 @@ type ReceiveConfig struct { type MessageSource[T any] interface { Receive(context.Context, ReceiveConfig) (context.Context, T, []ConsumeOption, error) + Front(context.Context, ReceiveConfig) (context.Context, T, []ConsumeOption, error) } // TurnLoopConfig is the configuration for creating a TurnLoop. @@ -128,6 +130,8 @@ func NewTurnLoop[T any](config TurnLoopConfig[T]) (*TurnLoop[T], error) { }, nil } +var ErrLoopExit = errors.New("loop exit") + // Run starts the blocking loop that continuously receives messages from the // source, runs the agent returned by GetAgent for each message, and dispatches // resulting events to OnAgentEvent. It blocks until the source returns an error @@ -138,15 +142,17 @@ func NewTurnLoop[T any](config TurnLoopConfig[T]) (*TurnLoop[T], error) { // immediately. If the agent does not implement Cancellable, preemptive messages // are queued and processed after the current agent finishes. func (l *TurnLoop[T]) Run(ctx context.Context) error { - // Initial blocking receive — no agent is running yet. - nCtx, item, option, err := l.source.Receive(ctx, ReceiveConfig{ - Timeout: l.receiveTimeout, - }) - if err != nil { - return fmt.Errorf("failed to receive message: %w", err) - } - for { + nCtx, item, option, err := l.source.Receive(ctx, ReceiveConfig{ + Timeout: l.receiveTimeout, + }) + if errors.Is(err, ErrLoopExit) { + return nil + } + if err != nil { + return fmt.Errorf("failed to receive message: %w", err) + } + input, runOpts, e := l.genInput(nCtx, item) if e != nil { return fmt.Errorf("failed to generate agent input: %w", e) @@ -198,12 +204,15 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { }() // Block on the next message while events are being consumed above. - nCtx, item, option, err = l.source.Receive(nCtx, ReceiveConfig{ + _, _, option, err = l.source.Front(nCtx, ReceiveConfig{ Timeout: l.receiveTimeout, }) if err != nil { <-done // wait for the event goroutine before returning - return fmt.Errorf("failed to receive message: %w", err) + if errors.Is(err, ErrLoopExit) { + return nil + } + return fmt.Errorf("failed to front message: %w", err) } // If the new message requests preemption, cancel the running agent. @@ -233,21 +242,20 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { // post-cancel drain) before starting the next turn. <-done if handleEventErr != nil { + if errors.Is(handleEventErr, ErrLoopExit) { + return nil + } return fmt.Errorf("failed to handle events: %w", handleEventErr) } } else { // Non-cancellable path: consume all events sequentially, then // block on the next message. if handleEventErr = handleEvents(); handleEventErr != nil { + if errors.Is(handleEventErr, ErrLoopExit) { + return nil + } return fmt.Errorf("failed to handle events: %w", handleEventErr) } - - nCtx, item, option, err = l.source.Receive(nCtx, ReceiveConfig{ - Timeout: l.receiveTimeout, - }) - if err != nil { - return fmt.Errorf("failed to receive message: %w", err) - } } } } From 2279e23cf60fd2f9cb73c1e69d5b3740ec3bc98d Mon Sep 17 00:00:00 2001 From: IPender Date: Tue, 24 Feb 2026 14:29:37 +0800 Subject: [PATCH 40/65] feat(adk): implement cancel mechanism for ChatModelAgent (#797) --- .gitignore | 1 + adk/cancel_test.go | 1023 +++++++++++++++++++++++++++++++++++++++++ adk/cancel_wrapper.go | 280 +++++++++++ adk/chatmodel.go | 149 ++++-- adk/flow.go | 96 ++-- adk/interface.go | 11 +- adk/react.go | 100 +++- adk/react_test.go | 12 +- adk/runner.go | 96 +++- adk/turn_loop.go | 428 ++++++++++++++--- adk/turn_loop_test.go | 811 +++++++++++++++++++++++++++----- 11 files changed, 2747 insertions(+), 260 deletions(-) create mode 100644 adk/cancel_test.go create mode 100644 adk/cancel_wrapper.go diff --git a/.gitignore b/.gitignore index 8ef36de95..8ac1d568d 100644 --- a/.gitignore +++ b/.gitignore @@ -47,6 +47,7 @@ output/* # Reports (generated analysis files) reports/ +/todos .DS_Store *.log diff --git a/adk/cancel_test.go b/adk/cancel_test.go new file mode 100644 index 000000000..f45e344ff --- /dev/null +++ b/adk/cancel_test.go @@ -0,0 +1,1023 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +type cancelTestChatModel struct { + delay time.Duration + response *schema.Message + startedChan chan struct{} + doneChan chan struct{} +} + +func (m *cancelTestChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + select { + case m.startedChan <- struct{}{}: + default: + } + time.Sleep(m.delay) + select { + case m.doneChan <- struct{}{}: + default: + } + return m.response, nil +} + +func (m *cancelTestChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + m.startedChan <- struct{}{} + time.Sleep(m.delay) + m.doneChan <- struct{}{} + return schema.StreamReaderFromArray([]*schema.Message{m.response}), nil +} + +func (m *cancelTestChatModel) BindTools(tools []*schema.ToolInfo) error { + return nil +} + +type slowTool struct { + name string + delay time.Duration + result string + callCount int32 + startedChan chan struct{} +} + +func newSlowTool(name string, delay time.Duration, result string) *slowTool { + return &slowTool{ + name: name, + delay: delay, + result: result, + startedChan: make(chan struct{}, 10), + } +} + +func (t *slowTool) Info(ctx context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "A slow tool for testing", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Type: "string", Desc: "Input parameter"}, + }), + }, nil +} + +func (t *slowTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { + atomic.AddInt32(&t.callCount, 1) + select { + case t.startedChan <- struct{}{}: + default: + } + time.Sleep(t.delay) + return t.result, nil +} + +type cancelTestStore struct { + m map[string][]byte + mu sync.Mutex +} + +func newCancelTestStore() *cancelTestStore { + return &cancelTestStore{m: make(map[string][]byte)} +} + +func (s *cancelTestStore) Set(_ context.Context, key string, value []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + s.m[key] = value + return nil +} + +func (s *cancelTestStore) Get(_ context.Context, key string) ([]byte, bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + v, ok := s.m[key] + return v, ok, nil +} + +func TestCancelSig(t *testing.T) { + t.Run("BasicCancelSignal", func(t *testing.T) { + cs := newCancelSig() + + cfg := checkCancelSig(cs) + assert.Nil(t, cfg, "Should not be cancelled initially") + + cs.cancel(&cancelConfig{Mode: CancelImmediate}) + + cfg = checkCancelSig(cs) + assert.NotNil(t, cfg, "Should be cancelled after cancel()") + assert.Equal(t, CancelImmediate, cfg.Mode) + }) +} + +func TestRunWithCancel_WithTools(t *testing.T) { + ctx := context.Background() + + t.Run("CancelImmediate_DuringModelCall", func(t *testing.T) { + modelStarted := make(chan struct{}, 1) + st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") + + slowModel := &cancelTestChatModel{ + delay: 2 * time.Second, + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: slowModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: false, + }) + + iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}) + assert.NotNil(t, iter) + assert.NotNil(t, cancelFn) + + eventsCh := make(chan []*AgentEvent, 1) + go func() { + var events []*AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + events = append(events, event) + } + eventsCh <- events + }() + + select { + case <-modelStarted: + case <-time.After(5 * time.Second): + t.Fatal("Model did not start within 5 seconds") + } + + time.Sleep(100 * time.Millisecond) + + err = cancelFn(ctx) + assert.NoError(t, err) + + start := time.Now() + events := <-eventsCh + elapsed := time.Since(start) + + assert.True(t, elapsed < 1*time.Second, "Should return quickly after cancel, elapsed: %v", elapsed) + assert.True(t, len(events) > 0) + + hasInterrupted := false + for _, e := range events { + assert.Nil(t, e.Err, "Should not have error event after cancel") + if e.Action != nil && e.Action.Interrupted != nil { + hasInterrupted = true + } + } + assert.True(t, hasInterrupted, "Should have interrupted event after cancel") + }) + + t.Run("CancelAfterChatModel_DuringToolCall", func(t *testing.T) { + toolStarted := make(chan struct{}, 1) + st := &slowToolWithSignal{ + name: "slow_tool", + delay: 2 * time.Second, + result: "tool result", + startedChan: toolStarted, + } + + modelWithToolCall := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: modelWithToolCall, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + iter, cancelFn := agent.RunWithCancel(ctx, &AgentInput{ + Messages: []Message{schema.UserMessage("Use the tool")}, + }) + assert.NotNil(t, iter) + assert.NotNil(t, cancelFn) + + <-toolStarted + + time.Sleep(100 * time.Millisecond) + + err = cancelFn(ctx, WithCancelMode(CancelAfterChatModel)) + assert.NoError(t, err) + + var events []*AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + assert.Nil(t, event.Err, "Should not have error event after cancel") + events = append(events, event) + } + + assert.True(t, len(events) > 0) + assert.True(t, atomic.LoadInt32(&st.callCount) >= 1, "Tool should have been called") + }) + + t.Run("CancelAfterToolCall_CompletesToolExecution", func(t *testing.T) { + toolStarted := make(chan struct{}, 1) + st := &slowToolWithSignal{ + name: "slow_tool", + delay: 500 * time.Millisecond, + result: "tool result", + startedChan: toolStarted, + } + + modelWithToolCall := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: modelWithToolCall, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + iter, cancelFn := agent.RunWithCancel(ctx, &AgentInput{ + Messages: []Message{schema.UserMessage("Use the tool")}, + }) + assert.NotNil(t, iter) + assert.NotNil(t, cancelFn) + + <-toolStarted + + time.Sleep(100 * time.Millisecond) + + err = cancelFn(ctx, WithCancelMode(CancelAfterToolCall)) + assert.NoError(t, err) + + var events []*AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + assert.Nil(t, event.Err, "Should not have error event after cancel") + events = append(events, event) + } + + assert.True(t, len(events) > 0) + assert.True(t, atomic.LoadInt32(&st.callCount) >= 1, "Tool should have been called") + }) +} + +type slowToolWithSignal struct { + name string + delay time.Duration + result string + callCount int32 + startedChan chan struct{} +} + +func (t *slowToolWithSignal) Info(ctx context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "A slow tool for testing", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Type: "string", Desc: "Input parameter"}, + }), + }, nil +} + +func (t *slowToolWithSignal) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { + atomic.AddInt32(&t.callCount, 1) + t.startedChan <- struct{}{} + time.Sleep(t.delay) + return t.result, nil +} + +type simpleChatModel struct { + response *schema.Message +} + +func (m *simpleChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + return m.response, nil +} + +func (m *simpleChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{m.response}), nil +} + +func (m *simpleChatModel) BindTools(tools []*schema.ToolInfo) error { + return nil +} + +func TestRunWithCancel_WithCheckpoint(t *testing.T) { + ctx := context.Background() + + t.Run("CancelWithCheckpoint", func(t *testing.T) { + modelStarted := make(chan struct{}, 1) + st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") + + slowModel := &cancelTestChatModel{ + delay: 500 * time.Millisecond, + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: slowModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: false, + CheckPointStore: store, + }) + + iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}, WithCheckPointID("cancel-1")) + + <-modelStarted + + err = cancelFn(ctx) + assert.NoError(t, err) + + var events []*AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + assert.Nil(t, event.Err, "Should not have error event after cancel") + events = append(events, event) + } + + assert.True(t, len(events) > 0) + }) +} + +func TestCancelFuncMultipleCalls(t *testing.T) { + ctx := context.Background() + + t.Run("SecondCancelReturnsErrAgentFinished", func(t *testing.T) { + modelStarted := make(chan struct{}, 1) + st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") + + slowModel := &cancelTestChatModel{ + delay: 1 * time.Second, + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: slowModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: false, + }) + + iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}) + + <-modelStarted + + cancelErr := cancelFn(ctx) + assert.NoError(t, cancelErr) + + cancelErr = cancelFn(ctx) + assert.ErrorIs(t, cancelErr, ErrAgentFinished) + + for { + _, ok := iter.Next() + if !ok { + break + } + } + }) +} + +func TestAgentNotCancellable(t *testing.T) { + ctx := context.Background() + + nonCancellableAgent := &nonCancellableTestAgent{ + name: "NonCancellable", + } + + runner := NewRunner(ctx, RunnerConfig{ + Agent: nonCancellableAgent, + EnableStreaming: false, + }) + + t.Run("RunWithCancelReturnsNilCancelFunc", func(t *testing.T) { + iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Hello")}) + assert.NotNil(t, iter) + assert.Nil(t, cancelFn) + + for { + _, ok := iter.Next() + if !ok { + break + } + } + }) +} + +type nonCancellableTestAgent struct { + name string +} + +func (a *nonCancellableTestAgent) Name(_ context.Context) string { + return a.name +} + +func (a *nonCancellableTestAgent) Description(_ context.Context) string { + return "A non-cancellable agent" +} + +func (a *nonCancellableTestAgent) Run(_ context.Context, input *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, gen := NewAsyncIteratorPair[*AgentEvent]() + gen.Send(&AgentEvent{ + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + Message: schema.AssistantMessage("Response", nil), + }, + }, + }) + gen.Close() + return iter +} + +func TestRunWithCancel_Streaming(t *testing.T) { + ctx := context.Background() + + t.Run("CancelImmediate_DuringModelStream", func(t *testing.T) { + modelStarted := make(chan struct{}, 1) + st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") + + slowModel := &cancelTestChatModel{ + delay: 2 * time.Second, + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: slowModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: true, + }) + + iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}) + assert.NotNil(t, iter) + assert.NotNil(t, cancelFn) + + eventsCh := make(chan []*AgentEvent, 1) + go func() { + var events []*AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + events = append(events, event) + } + eventsCh <- events + }() + + select { + case <-modelStarted: + case <-time.After(5 * time.Second): + t.Fatal("Model did not start within 5 seconds") + } + + time.Sleep(100 * time.Millisecond) + + cancelErr := cancelFn(ctx) + assert.NoError(t, cancelErr) + + start := time.Now() + events := <-eventsCh + elapsed := time.Since(start) + + assert.True(t, elapsed < 1*time.Second, "Should return quickly after cancel, elapsed: %v", elapsed) + assert.True(t, len(events) > 0) + + hasInterrupted := false + for _, e := range events { + assert.Nil(t, e.Err, "Should not have error event after cancel") + if e.Action != nil && e.Action.Interrupted != nil { + hasInterrupted = true + } + } + assert.True(t, hasInterrupted, "Should have interrupted event after cancel") + }) + + t.Run("CancelAfterToolCall_Streaming", func(t *testing.T) { + toolStarted := make(chan struct{}, 1) + st := &slowToolWithSignal{ + name: "slow_tool", + delay: 500 * time.Millisecond, + result: "tool result", + startedChan: toolStarted, + } + + modelWithToolCall := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: modelWithToolCall, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: true, + }) + + iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}) + assert.NotNil(t, iter) + assert.NotNil(t, cancelFn) + + <-toolStarted + + time.Sleep(100 * time.Millisecond) + + cancelErr := cancelFn(ctx, WithCancelMode(CancelAfterToolCall)) + assert.NoError(t, cancelErr) + + var events []*AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + assert.Nil(t, event.Err, "Should not have error event after cancel") + events = append(events, event) + } + + assert.True(t, len(events) > 0) + assert.True(t, atomic.LoadInt32(&st.callCount) >= 1, "Tool should have been called") + }) +} + +// TestResumeWithCancel tests the workflow of Cancel followed by Resume. +// +// IMPORTANT: When Cancel is triggered, the cancelableChatModel.Generate/Stream +// method returns immediately with an Interrupt error, but the inner model's +// Generate/Stream call continues running in a background goroutine until completion. +// This means the original model instance's fields (e.g., delay, response) may still +// be read by the background goroutine after Cancel returns. +// +// To avoid data races, we create new agent and runner instances for the Resume phase +// instead of reusing and modifying the original model instance. +func TestResumeWithCancel(t *testing.T) { + ctx := context.Background() + + t.Run("RunWithCancel_ThenResumeWithCancel", func(t *testing.T) { + modelStarted := make(chan struct{}, 1) + modelCallCount := int32(0) + st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") + + slowModel := &cancelTestChatModel{ + delay: 500 * time.Millisecond, + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: slowModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + checkpointID := "resume-cancel-test-1" + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: false, + CheckPointStore: store, + }) + + iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}, WithCheckPointID(checkpointID)) + + <-modelStarted + atomic.AddInt32(&modelCallCount, 1) + + cancelErr := cancelFn(ctx) + assert.NoError(t, cancelErr) + + var events []*AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + assert.Nil(t, event.Err, "Should not have error event after cancel") + events = append(events, event) + } + assert.True(t, len(events) > 0) + + hasInterrupted := false + for _, e := range events { + if e.Action != nil && e.Action.Interrupted != nil { + hasInterrupted = true + break + } + } + assert.True(t, hasInterrupted, "First run should have interrupted event") + + newModelStarted := make(chan struct{}, 1) + slowModel2 := &cancelTestChatModel{ + delay: 100 * time.Millisecond, + response: &schema.Message{ + Role: schema.Assistant, + Content: "Final response after resume", + }, + startedChan: newModelStarted, + doneChan: make(chan struct{}, 1), + } + + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: slowModel2, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: agent2, + EnableStreaming: false, + CheckPointStore: store, + }) + + resumeIter, resumeCancelFn, err := runner2.ResumeWithCancel(ctx, checkpointID) + assert.NoError(t, err) + assert.NotNil(t, resumeIter) + assert.NotNil(t, resumeCancelFn) + + var resumeEvents []*AgentEvent + for { + event, ok := resumeIter.Next() + if !ok { + break + } + assert.Nil(t, event.Err, "Should not have error event during resume") + resumeEvents = append(resumeEvents, event) + } + + assert.True(t, len(resumeEvents) > 0, "Resume should produce events") + }) + + t.Run("ResumeWithCancel_ThenCancel", func(t *testing.T) { + firstModelStarted := make(chan struct{}, 1) + resumeModelStarted := make(chan struct{}, 1) + modelCallCount := int32(0) + st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") + + slowModel := &cancelTestChatModel{ + delay: 500 * time.Millisecond, + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + startedChan: firstModelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: slowModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + checkpointID := "resume-then-cancel-test-1" + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: false, + CheckPointStore: store, + }) + + iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}, WithCheckPointID(checkpointID)) + + <-firstModelStarted + atomic.AddInt32(&modelCallCount, 1) + + cancelErr := cancelFn(ctx) + assert.NoError(t, cancelErr) + + for { + _, ok := iter.Next() + if !ok { + break + } + } + + slowModel2 := &cancelTestChatModel{ + delay: 2 * time.Second, + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + startedChan: resumeModelStarted, + doneChan: make(chan struct{}, 1), + } + + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: slowModel2, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: agent2, + EnableStreaming: false, + CheckPointStore: store, + }) + + resumeIter, resumeCancelFn, err := runner2.ResumeWithCancel(ctx, checkpointID) + assert.NoError(t, err) + + resumeEventsCh := make(chan []*AgentEvent, 1) + go func() { + var events []*AgentEvent + for { + event, ok := resumeIter.Next() + if !ok { + break + } + events = append(events, event) + } + resumeEventsCh <- events + }() + + <-resumeModelStarted + atomic.AddInt32(&modelCallCount, 1) + + time.Sleep(100 * time.Millisecond) + + err = resumeCancelFn(ctx) + assert.NoError(t, err) + + start := time.Now() + resumeEvents := <-resumeEventsCh + elapsed := time.Since(start) + + assert.True(t, elapsed < 1*time.Second, "Resume should return quickly after cancel, elapsed: %v", elapsed) + assert.True(t, len(resumeEvents) > 0) + + hasInterrupted := false + for _, e := range resumeEvents { + assert.Nil(t, e.Err, "Should not have error event after resume cancel") + if e.Action != nil && e.Action.Interrupted != nil { + hasInterrupted = true + } + } + assert.True(t, hasInterrupted, "Resume should have interrupted event after cancel") + }) +} diff --git a/adk/cancel_wrapper.go b/adk/cancel_wrapper.go new file mode 100644 index 000000000..9c00df77f --- /dev/null +++ b/adk/cancel_wrapper.go @@ -0,0 +1,280 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "io" + "runtime/debug" + "time" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/internal/safe" + "github.com/cloudwego/eino/schema" +) + +type cancelWaitResult[T any] struct { + result T + err error + cancelled bool +} + +func waitWithCancel[T any](cs *cancelSig, resultCh <-chan cancelWaitResult[T]) cancelWaitResult[T] { + var timeCh <-chan time.Time + select { + case <-cs.done: + cfg := cs.config.Load().(*cancelConfig) + if cfg.Mode == CancelImmediate { + if cfg.Timeout == nil { + return cancelWaitResult[T]{cancelled: true} + } + timeCh = time.After(*cfg.Timeout) + } + case res := <-resultCh: + return res + } + select { + case <-timeCh: + return cancelWaitResult[T]{cancelled: true} + case res := <-resultCh: + return res + } +} + +type cancelableChatModel struct { + inner model.BaseChatModel + cs *cancelSig +} + +func wrapModelForCancelable(m model.BaseChatModel, cs *cancelSig) *cancelableChatModel { + return &cancelableChatModel{inner: m, cs: cs} +} + +func (c *cancelableChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + if cfg := checkCancelSig(c.cs); cfg != nil && cfg.Mode == CancelImmediate { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + + resultCh := make(chan cancelWaitResult[*schema.Message], 1) + go func() { + defer func() { + if panicErr := recover(); panicErr != nil { + resultCh <- cancelWaitResult[*schema.Message]{err: safe.NewPanicErr(panicErr, debug.Stack())} + } + }() + res, err := c.inner.Generate(ctx, input, opts...) + resultCh <- cancelWaitResult[*schema.Message]{result: res, err: err} + }() + + res := waitWithCancel(c.cs, resultCh) + if res.cancelled { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + return res.result, res.err +} + +func (c *cancelableChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + if cfg := checkCancelSig(c.cs); cfg != nil && cfg.Mode == CancelImmediate { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + + resultCh := make(chan cancelWaitResult[*schema.StreamReader[*schema.Message]], 1) + go func() { + defer func() { + if panicErr := recover(); panicErr != nil { + resultCh <- cancelWaitResult[*schema.StreamReader[*schema.Message]]{err: safe.NewPanicErr(panicErr, debug.Stack())} + } + }() + + stream, err := c.inner.Stream(ctx, input, opts...) + if err != nil { + resultCh <- cancelWaitResult[*schema.StreamReader[*schema.Message]]{err: err} + return + } + copies := stream.Copy(2) + _ = consumeStreamForError(copies[0]) + resultCh <- cancelWaitResult[*schema.StreamReader[*schema.Message]]{result: copies[1]} + }() + + res := waitWithCancel(c.cs, resultCh) + if res.cancelled { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + return res.result, res.err +} + +func cancelableToolInvokable(cs *cancelSig, endpoint compose.InvokableToolEndpoint) compose.InvokableToolEndpoint { + return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { + if cfg := checkCancelSig(cs); cfg != nil && cfg.Mode == CancelImmediate { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + + resultCh := make(chan cancelWaitResult[*compose.ToolOutput], 1) + go func() { + defer func() { + if panicErr := recover(); panicErr != nil { + resultCh <- cancelWaitResult[*compose.ToolOutput]{err: safe.NewPanicErr(panicErr, debug.Stack())} + } + }() + output, err := endpoint(ctx, input) + resultCh <- cancelWaitResult[*compose.ToolOutput]{result: output, err: err} + }() + + res := waitWithCancel(cs, resultCh) + if res.cancelled { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + return res.result, res.err + } +} + +func cancelableToolStreamable(cs *cancelSig, endpoint compose.StreamableToolEndpoint) compose.StreamableToolEndpoint { + return func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { + if cfg := checkCancelSig(cs); cfg != nil && cfg.Mode == CancelImmediate { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + + resultCh := make(chan cancelWaitResult[*schema.StreamReader[string]], 1) + go func() { + defer func() { + if panicErr := recover(); panicErr != nil { + resultCh <- cancelWaitResult[*schema.StreamReader[string]]{err: safe.NewPanicErr(panicErr, debug.Stack())} + } + }() + output, err := endpoint(ctx, input) + if err != nil { + resultCh <- cancelWaitResult[*schema.StreamReader[string]]{err: err} + return + } + copies := output.Result.Copy(2) + _ = consumeStreamForErrorString(copies[0]) + resultCh <- cancelWaitResult[*schema.StreamReader[string]]{result: copies[1]} + }() + + res := waitWithCancel(cs, resultCh) + if res.cancelled { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + if res.err != nil { + return nil, res.err + } + return &compose.StreamToolOutput{Result: res.result}, nil + } +} + +func cancelableToolEnhancedInvokable(cs *cancelSig, endpoint compose.EnhancedInvokableToolEndpoint) compose.EnhancedInvokableToolEndpoint { + return func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedInvokableToolOutput, error) { + if cfg := checkCancelSig(cs); cfg != nil && cfg.Mode == CancelImmediate { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + + resultCh := make(chan cancelWaitResult[*compose.EnhancedInvokableToolOutput], 1) + go func() { + defer func() { + if panicErr := recover(); panicErr != nil { + resultCh <- cancelWaitResult[*compose.EnhancedInvokableToolOutput]{err: safe.NewPanicErr(panicErr, debug.Stack())} + } + }() + output, err := endpoint(ctx, input) + resultCh <- cancelWaitResult[*compose.EnhancedInvokableToolOutput]{result: output, err: err} + }() + + res := waitWithCancel(cs, resultCh) + if res.cancelled { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + return res.result, res.err + } +} + +func cancelableToolEnhancedStreamable(cs *cancelSig, endpoint compose.EnhancedStreamableToolEndpoint) compose.EnhancedStreamableToolEndpoint { + return func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { + if cfg := checkCancelSig(cs); cfg != nil && cfg.Mode == CancelImmediate { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + + resultCh := make(chan cancelWaitResult[*schema.StreamReader[*schema.ToolResult]], 1) + go func() { + defer func() { + if panicErr := recover(); panicErr != nil { + resultCh <- cancelWaitResult[*schema.StreamReader[*schema.ToolResult]]{err: safe.NewPanicErr(panicErr, debug.Stack())} + } + }() + output, err := endpoint(ctx, input) + if err != nil { + resultCh <- cancelWaitResult[*schema.StreamReader[*schema.ToolResult]]{err: err} + return + } + copies := output.Result.Copy(2) + _ = consumeStreamForErrorToolResult(copies[0]) + resultCh <- cancelWaitResult[*schema.StreamReader[*schema.ToolResult]]{result: copies[1]} + }() + + res := waitWithCancel(cs, resultCh) + if res.cancelled { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + if res.err != nil { + return nil, res.err + } + return &compose.EnhancedStreamableToolOutput{Result: res.result}, nil + } +} + +func cancelableTool(cs *cancelSig) compose.ToolMiddleware { + return compose.ToolMiddleware{ + Invokable: func(endpoint compose.InvokableToolEndpoint) compose.InvokableToolEndpoint { + return cancelableToolInvokable(cs, endpoint) + }, + Streamable: func(endpoint compose.StreamableToolEndpoint) compose.StreamableToolEndpoint { + return cancelableToolStreamable(cs, endpoint) + }, + EnhancedInvokable: func(endpoint compose.EnhancedInvokableToolEndpoint) compose.EnhancedInvokableToolEndpoint { + return cancelableToolEnhancedInvokable(cs, endpoint) + }, + EnhancedStreamable: func(endpoint compose.EnhancedStreamableToolEndpoint) compose.EnhancedStreamableToolEndpoint { + return cancelableToolEnhancedStreamable(cs, endpoint) + }, + } +} + +func consumeStreamForErrorString(stream *schema.StreamReader[string]) error { + defer stream.Close() + for { + _, err := stream.Recv() + if err == io.EOF { + return nil + } + if err != nil { + return err + } + } +} + +func consumeStreamForErrorToolResult(stream *schema.StreamReader[*schema.ToolResult]) error { + defer stream.Close() + for { + _, err := stream.Recv() + if err == io.EOF { + return nil + } + if err != nil { + return err + } + } +} diff --git a/adk/chatmodel.go b/adk/chatmodel.go index 56993a7b2..effb270bd 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -38,6 +38,10 @@ import ( "github.com/cloudwego/eino/schema" ) +var _ ResumableAgent = &ChatModelAgent{} +var _ CancellableAgent = &ChatModelAgent{} +var _ CancellableResumableAgent = &ChatModelAgent{} + type chatModelAgentExecCtx struct { runtimeReturnDirectly map[string]bool generator *AsyncGenerator[*AgentEvent] @@ -341,7 +345,8 @@ type ChatModelAgent struct { exeCtx *execContext } -type runFunc func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, instruction string, returnDirectly map[string]bool, opts ...compose.Option) +type runFunc func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], + store *bridgeStore, instruction string, returnDirectly map[string]bool, cs *cancelSig, opts ...compose.Option) // NewChatModelAgent constructs a chat model-backed agent with the provided config. func NewChatModelAgent(ctx context.Context, config *ChatModelAgentConfig) (*ChatModelAgent, error) { @@ -570,7 +575,7 @@ func setOutputToSession(ctx context.Context, msg Message, msgStream MessageStrea } func errFunc(err error) runFunc { - return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, _ string, _ map[string]bool, _ ...compose.Option) { + return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, _ string, _ map[string]bool, _ *cancelSig, _ ...compose.Option) { generator.Send(&AgentEvent{Err: err}) } } @@ -692,19 +697,23 @@ func (a *ChatModelAgent) prepareExecContext(ctx context.Context) (*execContext, } func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { - wrappedModel := buildModelWrappers(a.model, &modelWrapperConfig{ - handlers: a.handlers, - middlewares: a.middlewares, - retryConfig: a.modelRetryConfig, - }) - type noToolsInput struct { input *AgentInput instruction string } return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], - store *bridgeStore, instruction string, _ map[string]bool, opts ...compose.Option) { + store *bridgeStore, instruction string, _ map[string]bool, cs *cancelSig, opts ...compose.Option) { + + wrappedModel := buildModelWrappers(a.model, &modelWrapperConfig{ + handlers: a.handlers, + middlewares: a.middlewares, + retryConfig: a.modelRetryConfig, + }) + + if cs != nil { + wrappedModel = wrapModelForCancelable(wrappedModel, cs) + } chain := compose.NewChain[noToolsInput, Message]( compose.WithGenLocalState(func(ctx context.Context) (state *State) { @@ -750,13 +759,37 @@ func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { } else if msgStream != nil { msgStream.Close() } - } else { + return + } + + info, ok := compose.ExtractInterruptInfo(err) + if !ok { generator.Send(&AgentEvent{Err: err}) + return + } + + data, existed, sErr := store.Get(ctx, bridgeCheckpointID) + if sErr != nil { + generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("failed to get interrupt info: %w", sErr)}) + return + } + if !existed { + generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("interrupt occurred but checkpoint data is missing")}) + return + } + + is := FromInterruptContexts(info.InterruptContexts) + event := CompositeInterrupt(ctx, info, data, is) + event.Action.Interrupted.Data = &ChatModelAgentInterruptInfo{ + Info: info, + Data: data, } + event.AgentName = a.name + generator.Send(event) } } -func (a *ChatModelAgent) buildReactRunFunc(ctx context.Context, bc *execContext) (runFunc, error) { +func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) (runFunc, error) { conf := &reactConfig{ model: a.model, toolsConfig: &bc.toolsNodeConf, @@ -777,8 +810,8 @@ func (a *ChatModelAgent) buildReactRunFunc(ctx context.Context, bc *execContext) } return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, - instruction string, returnDirectly map[string]bool, opts ...compose.Option) { - g, err := newReact(ctx, conf) + instruction string, returnDirectly map[string]bool, cs *cancelSig, opts ...compose.Option) { + g, err := newReact(ctx, conf, cs) if err != nil { generator.Send(&AgentEvent{Err: err}) return @@ -894,7 +927,7 @@ func (a *ChatModelAgent) buildRunFunc(ctx context.Context) runFunc { return } - run, err := a.buildReactRunFunc(ctx, ec) + run, err := a.buildReActRunFunc(ctx, ec) if err != nil { a.run = errFunc(err) return @@ -938,7 +971,7 @@ func (a *ChatModelAgent) getRunFunc(ctx context.Context) (context.Context, runFu if len(runtimeBC.toolsNodeConf.Tools) == 0 { tempRun = a.buildNoToolsRunFunc(ctx) } else { - tempRun, err = a.buildReactRunFunc(ctx, runtimeBC) + tempRun, err = a.buildReActRunFunc(ctx, runtimeBC) if err != nil { return ctx, nil, nil, err } @@ -948,6 +981,15 @@ func (a *ChatModelAgent) getRunFunc(ctx context.Context) (context.Context, runFu } func (a *ChatModelAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, _ := a.runInternal(ctx, input, false, opts...) + return iter +} + +func (a *ChatModelAgent) RunWithCancel(ctx context.Context, input *AgentInput, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { + return a.runInternal(ctx, input, true, opts...) +} + +func (a *ChatModelAgent) runInternal(ctx context.Context, input *AgentInput, withCancel bool, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { iterator, generator := NewAsyncIteratorPair[*AgentEvent]() ctx, run, bc, err := a.getRunFunc(ctx) @@ -956,7 +998,7 @@ func (a *ChatModelAgent) Run(ctx context.Context, input *AgentInput, opts ...Age generator.Send(&AgentEvent{Err: err}) generator.Close() }() - return iterator + return iterator, notCancellableFuncInternal } co := getComposeOptions(opts) @@ -969,6 +1011,13 @@ func (a *ChatModelAgent) Run(ctx context.Context, input *AgentInput, opts ...Age } } + var cs *cancelSig + var cancelFn CancelFunc = notCancellableFuncInternal + if withCancel { + cs = newCancelSig() + cancelFn = buildCancelFunc(cs) + } + go func() { defer func() { panicErr := recover() @@ -990,13 +1039,47 @@ func (a *ChatModelAgent) Run(ctx context.Context, input *AgentInput, opts ...Age returnDirectly = bc.returnDirectly } - run(ctx, input, generator, newBridgeStore(), instruction, returnDirectly, co...) + run(ctx, input, generator, newBridgeStore(), instruction, returnDirectly, cs, co...) }() - return iterator + return iterator, cancelFn +} + +func buildCancelFunc(cs *cancelSig) CancelFunc { + var once sync.Once + + return func(_ context.Context, opts ...CancelOption) error { + cfg := &cancelConfig{ + Mode: CancelImmediate, + } + for _, opt := range opts { + opt(cfg) + } + + cancelled := false + once.Do(func() { + cs.cancel(cfg) + cancelled = true + }) + + if !cancelled { + return ErrAgentFinished + } + + return nil + } } func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, _ := a.resumeInternal(ctx, info, false, opts...) + return iter +} + +func (a *ChatModelAgent) ResumeWithCancel(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { + return a.resumeInternal(ctx, info, true, opts...) +} + +func (a *ChatModelAgent) resumeInternal(ctx context.Context, info *ResumeInfo, withCancel bool, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { iterator, generator := NewAsyncIteratorPair[*AgentEvent]() ctx, run, bc, err := a.getRunFunc(ctx) @@ -1005,7 +1088,7 @@ func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...A generator.Send(&AgentEvent{Err: err}) generator.Close() }() - return iterator + return iterator, notCancellableFuncInternal } co := getComposeOptions(opts) @@ -1018,14 +1101,19 @@ func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...A } } + methodName := "Resume" + if withCancel { + methodName = "ResumeWithCancel" + } + if info.InterruptState == nil { - panic(fmt.Sprintf("ChatModelAgent.Resume: agent '%s' was asked to resume but has no state", a.Name(ctx))) + panic(fmt.Sprintf("ChatModelAgent.%s: agent '%s' was asked to resume but has no state", methodName, a.Name(ctx))) } stateByte, ok := info.InterruptState.([]byte) if !ok { - panic(fmt.Sprintf("ChatModelAgent.Resume: agent '%s' was asked to resume but has invalid interrupt state type: %T", - a.Name(ctx), info.InterruptState)) + panic(fmt.Sprintf("ChatModelAgent.%s: agent '%s' was asked to resume but has invalid interrupt state type: %T", + methodName, a.Name(ctx), info.InterruptState)) } // Migrate legacy checkpoints before resume. @@ -1040,15 +1128,15 @@ func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...A generator.Send(&AgentEvent{Err: err}) generator.Close() }() - return iterator + return iterator, nil } var historyModifier func(ctx context.Context, history []Message) []Message if info.ResumeData != nil { resumeData, ok := info.ResumeData.(*ChatModelAgentResumeData) if !ok { - panic(fmt.Sprintf("ChatModelAgent.Resume: agent '%s' was asked to resume but has invalid resume data type: %T", - a.Name(ctx), info.ResumeData)) + panic(fmt.Sprintf("ChatModelAgent.%s: agent '%s' was asked to resume but has invalid resume data type: %T", + methodName, a.Name(ctx), info.ResumeData)) } historyModifier = resumeData.HistoryModifier } @@ -1064,6 +1152,13 @@ func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...A })) } + var cs *cancelSig + var cancelFn CancelFunc = notCancellableFuncInternal + if withCancel { + cs = newCancelSig() + cancelFn = buildCancelFunc(cs) + } + go func() { defer func() { panicErr := recover() @@ -1086,10 +1181,10 @@ func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...A } run(ctx, &AgentInput{EnableStreaming: info.EnableStreaming}, generator, - newResumeBridgeStore(stateByte), instruction, returnDirectly, co...) + newResumeBridgeStore(stateByte), instruction, returnDirectly, cs, co...) }() - return iterator + return iterator, cancelFn } func getComposeOptions(opts []AgentRunOption) []compose.Option { diff --git a/adk/flow.go b/adk/flow.go index ee4dec96c..93143c641 100644 --- a/adk/flow.go +++ b/adk/flow.go @@ -333,6 +333,15 @@ func buildDefaultHistoryRewriter(agentName string) HistoryRewriter { } func (a *flowAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, _ := a.runInternal(ctx, input, false, opts...) + return iter +} + +func (a *flowAgent) RunWithCancel(ctx context.Context, input *AgentInput, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { + return a.runInternal(ctx, input, true, opts...) +} + +func (a *flowAgent) runInternal(ctx context.Context, input *AgentInput, withCancel bool, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { agentName := a.Name(ctx) var runCtx *runContext @@ -345,7 +354,7 @@ func (a *flowAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRun if err != nil { cbInput := &AgentCallbackInput{Input: input} ctx = callbacks.OnStart(ctx, cbInput) - return wrapIterWithOnEnd(ctx, genErrorIter(err)) + return wrapIterWithOnEnd(ctx, genErrorIter(err)), notCancellableFuncInternal } ctxForSubAgents := ctx @@ -358,19 +367,40 @@ func (a *flowAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRun input = processedInput if wf, ok := a.Agent.(*workflowAgent); ok { - return wrapIterWithOnEnd(ctx, wf.Run(ctx, input, filterCallbackHandlersForNestedAgents(agentName, opts)...)) + return wrapIterWithOnEnd(ctx, wf.Run(ctx, input, filterCallbackHandlersForNestedAgents(agentName, opts)...)), notCancellableFuncInternal } - aIter := a.Agent.Run(ctx, input, filterOptions(agentName, opts)...) + var aIter *AsyncIterator[*AgentEvent] + var cancelFn CancelFunc = notCancellableFuncInternal + + ca, supportCancel := a.Agent.(CancellableAgent) + if withCancel && supportCancel { + aIter, cancelFn = ca.RunWithCancel(ctx, input, filterOptions(agentName, opts)...) + } else { + aIter = a.Agent.Run(ctx, input, filterOptions(agentName, opts)...) + } iterator, generator := NewAsyncIteratorPair[*AgentEvent]() go a.run(ctx, ctxForSubAgents, runCtx, aIter, generator, opts...) - return iterator + return iterator, cancelFn +} + +func notCancellableFuncInternal(_ context.Context, _ ...CancelOption) error { + return ErrAgentNotCancellable } func (a *flowAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, _ := a.resumeInternal(ctx, info, false, opts...) + return iter +} + +func (a *flowAgent) ResumeWithCancel(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { + return a.resumeInternal(ctx, info, true, opts...) +} + +func (a *flowAgent) resumeInternal(ctx context.Context, info *ResumeInfo, withCancel bool, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { agentName := a.Name(ctx) ctx, info = buildResumeInfo(ctx, agentName, info) @@ -383,51 +413,59 @@ func (a *flowAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentR ctx = callbacks.OnStart(ctx, cbInput) if info.WasInterrupted { - ra, ok := a.Agent.(ResumableAgent) - if !ok { + var aIter *AsyncIterator[*AgentEvent] + var cancelFn CancelFunc = notCancellableFuncInternal + + ca, supportCancel := a.Agent.(CancellableResumableAgent) + if withCancel && supportCancel { + aIter, cancelFn = ca.ResumeWithCancel(ctx, info, opts...) + } else if ra, ok := a.Agent.(ResumableAgent); ok { + if _, ok := ra.(*workflowAgent); ok { + filteredOpts := filterCallbackHandlersForNestedAgents(agentName, opts) + aIter := ra.Resume(ctx, info, filteredOpts...) + return wrapIterWithOnEnd(ctx, aIter), cancelFn + } + aIter = ra.Resume(ctx, info, opts...) + } else { return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf("failed to resume agent: agent '%s' is an interrupt point "+ - "but is not a ResumableAgent", agentName))) + "but is not a ResumableAgent", agentName))), notCancellableFuncInternal } - iterator, generator := NewAsyncIteratorPair[*AgentEvent]() - if _, ok := ra.(*workflowAgent); ok { - filteredOpts := filterCallbackHandlersForNestedAgents(agentName, opts) - aIter := ra.Resume(ctx, info, filteredOpts...) - return wrapIterWithOnEnd(ctx, aIter) - } - aIter := ra.Resume(ctx, info, opts...) + iterator, generator := NewAsyncIteratorPair[*AgentEvent]() go a.run(ctx, ctxForSubAgents, getRunCtx(ctxForSubAgents), aIter, generator, opts...) - return iterator + return iterator, cancelFn } nextAgentName, err := getNextResumeAgent(ctx, info) if err != nil { - return wrapIterWithOnEnd(ctx, genErrorIter(err)) + return wrapIterWithOnEnd(ctx, genErrorIter(err)), notCancellableFuncInternal } subAgent := a.getAgent(ctxForSubAgents, nextAgentName) if subAgent == nil { - // the inner agent wrapped by flowAgent may be ANY agent, including flowAgent, - // AgentWithDeterministicTransferTo, or any other custom agent user defined, - // or any combinations of the above in any order, - // that ultimately wraps the flowAgent with sub-agents - // We need to go through these wrappers to reach the flowAgent with sub-agents. if len(a.subAgents) == 0 { + ca, supportCancel := a.Agent.(CancellableResumableAgent) + if withCancel && supportCancel { + iter, cancelFn := ca.ResumeWithCancel(ctx, info, opts...) + return wrapIterWithOnEnd(ctx, iter), cancelFn + } if ra, ok := a.Agent.(ResumableAgent); ok { - // Use ctx (callback-enriched) instead of ctxForSubAgents here. - // This is the inner agent that flowAgent wraps (e.g., supervisorContainer), - // not a sub-agent. The callback context from OnStart should be propagated - // to ensure unified tracing for container patterns. - return wrapIterWithOnEnd(ctx, ra.Resume(ctx, info, opts...)) + return wrapIterWithOnEnd(ctx, ra.Resume(ctx, info, opts...)), notCancellableFuncInternal } return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf( "failed to resume agent: agent '%s' (type %T) has no sub-agents and does not implement ResumableAgent interface. "+ - "To support resume, your custom agent wrapper must implement the ResumableAgent interface", agentName, a.Agent))) + "To support resume, your custom agent wrapper must implement the ResumableAgent interface", agentName, a.Agent))), nil } - return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf("failed to resume agent: sub-agent '%s' not found in agent '%s'", nextAgentName, agentName))) + return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf("failed to resume agent: sub-agent '%s' not found in agent '%s'", nextAgentName, agentName))), notCancellableFuncInternal + } + + ca, supportCancel := ResumableAgent(subAgent).(CancellableResumableAgent) + if withCancel && supportCancel { + iter, cancelFn := ca.ResumeWithCancel(ctxForSubAgents, info, opts...) + return wrapIterWithOnEnd(ctx, iter), cancelFn } - return wrapIterWithOnEnd(ctx, subAgent.Resume(ctxForSubAgents, info, opts...)) + return wrapIterWithOnEnd(ctx, subAgent.Resume(ctxForSubAgents, info, opts...)), notCancellableFuncInternal } type DeterministicTransferConfig struct { diff --git a/adk/interface.go b/adk/interface.go index 2e7e3d09c..1d2f5f070 100644 --- a/adk/interface.go +++ b/adk/interface.go @@ -291,6 +291,9 @@ const ( // ErrAgentFinished is returned by Cancel when the agent has already finished execution. var ErrAgentFinished = errors.New("agent has already finished execution") +// ErrAgentNotCancellable is returned by Cancel when the agent does not support cancellation. +var ErrAgentNotCancellable = errors.New("agent does not implement CancellableAgent interface") + type cancelConfig struct { Mode CancelMode Timeout *time.Duration @@ -298,12 +301,14 @@ type cancelConfig struct { type CancelOption func(*cancelConfig) +// WithCancelMode sets the cancel mode for the cancel operation. func WithCancelMode(mode CancelMode) CancelOption { return func(config *cancelConfig) { config.Mode = mode } } +// WithCancelTimeout sets a timeout duration for CancelImmediate mode. func WithCancelTimeout(timeout time.Duration) CancelOption { return func(config *cancelConfig) { config.Timeout = &timeout @@ -312,10 +317,12 @@ func WithCancelTimeout(timeout time.Duration) CancelOption { type CancelFunc func(context.Context, ...CancelOption) error -type CancellableRun interface { +type CancellableAgent interface { + Agent RunWithCancel(ctx context.Context, input *AgentInput, options ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) } -type CancellableResume interface { +type CancellableResumableAgent interface { + ResumableAgent ResumeWithCancel(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) } diff --git a/adk/react.go b/adk/react.go index 2bf6dd462..0aec5d94c 100644 --- a/adk/react.go +++ b/adk/react.go @@ -22,6 +22,7 @@ import ( "encoding/gob" "errors" "io" + "sync/atomic" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/compose" @@ -299,13 +300,31 @@ func genReactState(config *reactConfig) func(ctx context.Context) *State { } } -func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { +func newReact(ctx context.Context, config *reactConfig, cs *cancelSig) (reactGraph, error) { const ( - initNode_ = "Init" - chatModel_ = "ChatModel" - toolNode_ = "ToolNode" + initNode_ = "Init" + chatModel_ = "ChatModel" + beforeToolNode_ = "BeforeToolNode" + toolNode_ = "ToolNode" + afterToolNode_ = "AfterToolNode" ) + checkCancel := cs != nil + + nodeNameAfterModel := func() string { + if checkCancel { + return beforeToolNode_ + } + return toolNode_ + } + + nodeNameAfterTool := func() string { + if checkCancel { + return afterToolNode_ + } + return chatModel_ + } + g := compose.NewGraph[*reactInput, Message](compose.WithGenLocalState(genReactState(config))) initLambda := func(ctx context.Context, input *reactInput) ([]Message, error) { @@ -318,7 +337,18 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { wrappedModel = buildModelWrappers(config.model, config.modelWrapperConf) } - toolsNode, err := compose.NewToolNode(ctx, config.toolsConfig) + toolsConfig := config.toolsConfig + if checkCancel { + wrappedModel = wrapModelForCancelable(wrappedModel, cs) + tcMWs := make([]compose.ToolMiddleware, 0, len(toolsConfig.ToolCallMiddlewares)+1) + tcMWs = append(tcMWs, cancelableTool(cs)) + tcMWs = append(tcMWs, toolsConfig.ToolCallMiddlewares...) + toolsConfigCopy := *toolsConfig + toolsConfigCopy.ToolCallMiddlewares = tcMWs + toolsConfig = &toolsConfigCopy + } + + toolsNode, err := compose.NewToolNode(ctx, toolsConfig) if err != nil { return nil, err } @@ -369,6 +399,28 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { _ = g.AddEdge(compose.START, initNode_) _ = g.AddEdge(initNode_, chatModel_) + if checkCancel { + beforeToolNode := func(ctx context.Context, input Message) (output Message, err error) { + if sig := checkCancelSig(cs); sig != nil && sig.Mode != CancelAfterToolCall { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + + return input, nil + } + _ = g.AddLambdaNode(beforeToolNode_, compose.InvokableLambda(beforeToolNode), compose.WithNodeName(beforeToolNode_)) + g.AddEdge(beforeToolNode_, toolNode_) + + afterToolNode := func(ctx context.Context, input []Message) (output []Message, err error) { + if sig := checkCancelSig(cs); sig != nil && sig.Mode != CancelAfterChatModel { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + + return input, nil + } + _ = g.AddLambdaNode(afterToolNode_, compose.InvokableLambda(afterToolNode), compose.WithNodeName(afterToolNode_)) + g.AddEdge(afterToolNode_, chatModel_) + } + toolCallCheck := func(ctx context.Context, sMsg MessageStream) (string, error) { defer sMsg.Close() for { @@ -382,11 +434,11 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { } if len(chunk.ToolCalls) > 0 { - return toolNode_, nil + return nodeNameAfterModel(), nil } } } - branch := compose.NewStreamGraphBranch(toolCallCheck, map[string]bool{compose.END: true, toolNode_: true}) + branch := compose.NewStreamGraphBranch(toolCallCheck, map[string]bool{compose.END: true, nodeNameAfterModel(): true}) _ = g.AddBranch(chatModel_, branch) if len(config.toolsReturnDirectly) > 0 { @@ -423,15 +475,43 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { return toolNodeToEndConverter, nil } - return chatModel_, nil + return nodeNameAfterTool(), nil } branch = compose.NewStreamGraphBranch(checkReturnDirect, - map[string]bool{toolNodeToEndConverter: true, chatModel_: true}) + map[string]bool{toolNodeToEndConverter: true, nodeNameAfterTool(): true}) _ = g.AddBranch(toolNode_, branch) } else { - _ = g.AddEdge(toolNode_, chatModel_) + _ = g.AddEdge(toolNode_, nodeNameAfterTool()) } return g, nil } + +type cancelSig struct { + done chan struct{} + config atomic.Value +} + +func newCancelSig() *cancelSig { + return &cancelSig{ + done: make(chan struct{}), + } +} + +func (cs *cancelSig) cancel(cfg *cancelConfig) { + cs.config.Store(cfg) + close(cs.done) +} + +func checkCancelSig(cs *cancelSig) *cancelConfig { + if cs == nil { + return nil + } + select { + case <-cs.done: + return cs.config.Load().(*cancelConfig) + default: + return nil + } +} diff --git a/adk/react_test.go b/adk/react_test.go index 5364f0912..969e73e35 100644 --- a/adk/react_test.go +++ b/adk/react_test.go @@ -144,7 +144,7 @@ func TestReact(t *testing.T) { toolsReturnDirectly: map[string]bool{}, } - graph, err := newReact(ctx, config) + graph, err := newReact(ctx, config, nil) assert.NoError(t, err) assert.NotNil(t, graph) @@ -211,7 +211,7 @@ func TestReact(t *testing.T) { toolsReturnDirectly: map[string]bool{info.Name: true}, } - graph, err := newReact(ctx, config) + graph, err := newReact(ctx, config, nil) assert.NoError(t, err) assert.NotNil(t, graph) @@ -303,7 +303,7 @@ func TestReact(t *testing.T) { toolsReturnDirectly: map[string]bool{}, } - graph, err := newReact(ctx, config) + graph, err := newReact(ctx, config, nil) assert.NoError(t, err) assert.NotNil(t, graph) @@ -413,7 +413,7 @@ func TestReact(t *testing.T) { toolsReturnDirectly: map[string]bool{streamInfo.Name: true}, } - graph, err := newReact(ctx, config) + graph, err := newReact(ctx, config, nil) assert.NoError(t, err) assert.NotNil(t, graph) @@ -502,7 +502,7 @@ func TestReact(t *testing.T) { maxIterations: 6, } - graph, err := newReact(ctx, config) + graph, err := newReact(ctx, config, nil) assert.NoError(t, err) assert.NotNil(t, graph) @@ -532,7 +532,7 @@ func TestReact(t *testing.T) { maxIterations: 5, } - graph, err = newReact(ctx, config) + graph, err = newReact(ctx, config, nil) assert.NoError(t, err) assert.NotNil(t, graph) diff --git a/adk/runner.go b/adk/runner.go index 07a931ac2..a9fd9b94d 100644 --- a/adk/runner.go +++ b/adk/runner.go @@ -74,6 +74,31 @@ func NewRunner(_ context.Context, conf RunnerConfig) *Runner { // upon interruption. func (r *Runner) Run(ctx context.Context, messages []Message, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, _, _ := r.runWithCancel(ctx, messages, false, opts...) + return iter +} + +// Query is a convenience method that starts a new execution with a single user query string. +func (r *Runner) Query(ctx context.Context, + query string, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + + return r.Run(ctx, []Message{schema.UserMessage(query)}, opts...) +} + +// RunWithCancel starts a new execution of the agent and returns both an iterator and a cancel function. +// The cancel function can be used to interrupt the running agent at specific points based on the CancelMode. +// If the Runner was configured with a CheckPointStore and WithCheckPointID option, it will automatically +// save the agent's state upon cancellation for later resumption. +// +// If the agent does not implement CancellableAgent, the returned CancelFunc will be nil. +func (r *Runner) RunWithCancel(ctx context.Context, messages []Message, + opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { + iter, cancelFn, _ := r.runWithCancel(ctx, messages, true, opts...) + return iter, cancelFn +} + +func (r *Runner) runWithCancel(ctx context.Context, messages []Message, withCancel bool, + opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc, error) { o := getCommonOptions(nil, opts...) fa := toFlowAgent(ctx, r.a) @@ -87,22 +112,46 @@ func (r *Runner) Run(ctx context.Context, messages []Message, AddSessionValues(ctx, o.sessionValues) - iter := fa.Run(ctx, input, opts...) + var iter *AsyncIterator[*AgentEvent] + var cancelFn CancelFunc + if withCancel { + if _, ok := r.a.(CancellableAgent); ok { + iter, cancelFn = fa.RunWithCancel(ctx, input, opts...) + } else { + iter = fa.Run(ctx, input, opts...) + } + } else { + iter = fa.Run(ctx, input, opts...) + } + if r.store == nil { - return iter + return iter, cancelFn, nil } niter, gen := NewAsyncIteratorPair[*AgentEvent]() go r.handleIter(ctx, iter, gen, o.checkPointID) - return niter + return niter, cancelFn, nil } -// Query is a convenience method that starts a new execution with a single user query string. -func (r *Runner) Query(ctx context.Context, - query string, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { +// ResumeWithCancel continues an interrupted execution from a checkpoint and returns both an iterator and a cancel function. +// This method uses the "Implicit Resume All" strategy where all previously interrupted points proceed without specific data. +// The cancel function can be used to interrupt the running agent again at specific points based on the CancelMode. +// +// If the agent does not implement CancellableResumableAgent, the returned CancelFunc will be nil. +func (r *Runner) ResumeWithCancel(ctx context.Context, checkPointID string, opts ...AgentRunOption) ( + *AsyncIterator[*AgentEvent], CancelFunc, error) { + return r.resumeWithCancel(ctx, checkPointID, nil, true, opts...) +} - return r.Run(ctx, []Message{schema.UserMessage(query)}, opts...) +// ResumeWithParamsAndCancel continues an interrupted execution from a checkpoint with specific parameters +// and returns both an iterator and a cancel function. +// The params.Targets map should contain the addresses of the components to be resumed as keys. +// +// If the agent does not implement CancellableResumableAgent, the returned CancelFunc will be nil. +func (r *Runner) ResumeWithParamsAndCancel(ctx context.Context, checkPointID string, params *ResumeParams, + opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc, error) { + return r.resumeWithCancel(ctx, checkPointID, params.Targets, true, opts...) } // Resume continues an interrupted execution from a checkpoint, using an "Implicit Resume All" strategy. @@ -114,7 +163,8 @@ func (r *Runner) Query(ctx context.Context, // pattern where an agent only needs to know `wasInterrupted` is true to continue. func (r *Runner) Resume(ctx context.Context, checkPointID string, opts ...AgentRunOption) ( *AsyncIterator[*AgentEvent], error) { - return r.resume(ctx, checkPointID, nil, opts...) + iter, _, err := r.resumeWithCancel(ctx, checkPointID, nil, false, opts...) + return iter, err } // ResumeWithParams continues an interrupted execution from a checkpoint with specific parameters. @@ -136,19 +186,19 @@ func (r *Runner) Resume(ctx context.Context, checkPointID string, opts ...AgentR // naturally re-interrupt if one of their interrupted children re-interrupts, as they receive the // new `CompositeInterrupt` signal from them. func (r *Runner) ResumeWithParams(ctx context.Context, checkPointID string, params *ResumeParams, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], error) { - return r.resume(ctx, checkPointID, params.Targets, opts...) + iter, _, err := r.resumeWithCancel(ctx, checkPointID, params.Targets, false, opts...) + return iter, err } -// resume is the internal implementation for both Resume and ResumeWithParams. -func (r *Runner) resume(ctx context.Context, checkPointID string, resumeData map[string]any, - opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], error) { +func (r *Runner) resumeWithCancel(ctx context.Context, checkPointID string, resumeData map[string]any, + withCancel bool, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc, error) { if r.store == nil { - return nil, fmt.Errorf("failed to resume: store is nil") + return nil, nil, fmt.Errorf("failed to resume: store is nil") } ctx, runCtx, resumeInfo, err := r.loadCheckPoint(ctx, checkPointID) if err != nil { - return nil, fmt.Errorf("failed to load from checkpoint: %w", err) + return nil, nil, fmt.Errorf("failed to load from checkpoint: %w", err) } o := getCommonOptions(nil, opts...) @@ -175,15 +225,27 @@ func (r *Runner) resume(ctx context.Context, checkPointID string, resumeData map } fa := toFlowAgent(ctx, r.a) - aIter := fa.Resume(ctx, resumeInfo, opts...) + + var aIter *AsyncIterator[*AgentEvent] + var cancelFn CancelFunc + if withCancel { + if _, ok := r.a.(CancellableResumableAgent); ok { + aIter, cancelFn = fa.ResumeWithCancel(ctx, resumeInfo, opts...) + } else { + aIter = fa.Resume(ctx, resumeInfo, opts...) + } + } else { + aIter = fa.Resume(ctx, resumeInfo, opts...) + } + if r.store == nil { - return aIter, nil + return aIter, cancelFn, nil } niter, gen := NewAsyncIteratorPair[*AgentEvent]() go r.handleIter(ctx, aIter, gen, &checkPointID) - return niter, nil + return niter, cancelFn, nil } func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEvent], diff --git a/adk/turn_loop.go b/adk/turn_loop.go index e5cf1c36f..af69ae5d3 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -21,6 +21,8 @@ import ( "errors" "fmt" "runtime/debug" + "sync" + "sync/atomic" "time" "github.com/cloudwego/eino/internal/safe" @@ -44,19 +46,24 @@ const ( ) type consumeConfig struct { - Mode ConsumeMode - Timeout time.Duration - CancelOpts []CancelOption + Mode ConsumeMode + Timeout time.Duration + CancelOpts []CancelOption + CheckPointID string } type ConsumeOption func(*consumeConfig) +// WithPreemptive sets the consume mode to preemptive, which cancels the +// currently running agent immediately. func WithPreemptive() ConsumeOption { return func(config *consumeConfig) { config.Mode = ConsumePreemptive } } +// WithPreemptiveOnTimeout sets the consume mode to preemptive with a timeout. +// If the current agent does not complete within the timeout, it will be canceled. func WithPreemptiveOnTimeout(timeout time.Duration) ConsumeOption { return func(config *consumeConfig) { config.Mode = ConsumePreemptiveOnTimeout @@ -64,12 +71,21 @@ func WithPreemptiveOnTimeout(timeout time.Duration) ConsumeOption { } } +// WithCancelOptions appends cancel options to be used when canceling the agent. func WithCancelOptions(opts ...CancelOption) ConsumeOption { return func(config *consumeConfig) { config.CancelOpts = append(config.CancelOpts, opts...) } } +// WithConsumeCheckPointID sets the checkpoint ID for the consumed message. +// When set, the checkpoint will be saved with this ID if an interrupt occurs. +func WithConsumeCheckPointID(id string) ConsumeOption { + return func(config *consumeConfig) { + config.CheckPointID = id + } +} + type ReceiveConfig struct { Timeout time.Duration } @@ -79,6 +95,24 @@ type MessageSource[T any] interface { Front(context.Context, ReceiveConfig) (context.Context, T, []ConsumeOption, error) } +type turnLoopRunConfig[T any] struct { + checkPointID string + item T +} + +// TurnLoopRunOption is an option for TurnLoop.Run. +type TurnLoopRunOption[T any] func(*turnLoopRunConfig[T]) + +// WithTurnLoopResume configures the TurnLoop to resume from a previously saved checkpoint. +// The checkPointID identifies the checkpoint to resume from, and item is the original input +// that triggered the interrupted execution. +func WithTurnLoopResume[T any](checkPointID string, item T) TurnLoopRunOption[T] { + return func(c *turnLoopRunConfig[T]) { + c.checkPointID = checkPointID + c.item = item + } +} + // TurnLoopConfig is the configuration for creating a TurnLoop. type TurnLoopConfig[T any] struct { // Source provides messages to drive the loop. Required. @@ -88,12 +122,16 @@ type TurnLoopConfig[T any] struct { GenInput func(ctx context.Context, item T) (*AgentInput, []AgentRunOption, error) // GetAgent returns the Agent to run for a given message. Required. GetAgent func(ctx context.Context, item T) (Agent, error) - // OnAgentEvent is called for each event emitted by the agent. Optional. + // OnAgentEvents is called for each event emitted by the agent. Optional. // The inputItem is the message that triggered the current agent turn. + // If not provided, the default implementation will consume all events and + // return any error event encountered. OnAgentEvents func(ctx context.Context, inputItem T, event *AsyncIterator[*AgentEvent]) error // ReceiveTimeout is the timeout passed to Source.Receive on each iteration. // Zero means no timeout. Optional. ReceiveTimeout time.Duration + + Store CheckPointStore } // TurnLoop is a loop that pulls messages from a source, runs an Agent for @@ -106,8 +144,63 @@ type TurnLoop[T any] struct { getAgent func(ctx context.Context, item T) (Agent, error) onAgentEvents func(ctx context.Context, inputItem T, event *AsyncIterator[*AgentEvent]) error receiveTimeout time.Duration + store CheckPointStore +} + +type turnLoopCancelSig struct { + done chan struct{} + config atomic.Value } +func newTurnLoopCancelSig() *turnLoopCancelSig { + return &turnLoopCancelSig{ + done: make(chan struct{}), + } +} + +func (cs *turnLoopCancelSig) cancel(cfg *cancelConfig) { + cs.config.Store(cfg) + close(cs.done) +} + +func (cs *turnLoopCancelSig) isCancelled() bool { + select { + case <-cs.done: + return true + default: + return false + } +} + +func (cs *turnLoopCancelSig) getConfig() *cancelConfig { + if v := cs.config.Load(); v != nil { + return v.(*cancelConfig) + } + return nil +} + +type turnLoopCancelSigKey struct{} + +func withTurnLoopCancelSig(ctx context.Context, cs *turnLoopCancelSig) context.Context { + return context.WithValue(ctx, turnLoopCancelSigKey{}, cs) +} + +func getTurnLoopCancelSig(ctx context.Context) *turnLoopCancelSig { + if v, ok := ctx.Value(turnLoopCancelSigKey{}).(*turnLoopCancelSig); ok { + return v + } + return nil +} + +// TurnLoopCancelFunc is the cancel function returned by WithCancel. +// Unlike Agent's CancelFunc, it does not require a context parameter +// since the context is already bound when WithCancel is called. +type TurnLoopCancelFunc func(opts ...CancelOption) error + +// ErrAgentNotCancellableInTurnLoop is returned when WithCancel context is used +// but the Agent does not implement CancellableAgent. +var ErrAgentNotCancellableInTurnLoop = errors.New("agent does not support cancel but WithCancel context was provided") + // NewTurnLoop creates a new TurnLoop from the given configuration. // Source, GenInput, and GetAgent are required fields. func NewTurnLoop[T any](config TurnLoopConfig[T]) (*TurnLoop[T], error) { @@ -121,17 +214,78 @@ func NewTurnLoop[T any](config TurnLoopConfig[T]) (*TurnLoop[T], error) { return nil, fmt.Errorf("TurnLoopConfig.GetAgent is required") } + onAgentEvents := config.OnAgentEvents + if onAgentEvents == nil { + onAgentEvents = func(_ context.Context, _ T, iter *AsyncIterator[*AgentEvent]) error { + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil { + return event.Err + } + } + return nil + } + } + return &TurnLoop[T]{ source: config.Source, genInput: config.GenInput, getAgent: config.GetAgent, - onAgentEvents: config.OnAgentEvents, + onAgentEvents: onAgentEvents, receiveTimeout: config.ReceiveTimeout, + store: config.Store, }, nil } var ErrLoopExit = errors.New("loop exit") +// WithCancel returns a new context and a cancel function that can be used to +// cancel the TurnLoop's Run method externally. Each call to WithCancel creates +// an independent cancel signal, allowing multiple concurrent Run calls with +// separate cancel controls. +// +// The returned TurnLoopCancelFunc does not require a context parameter since +// the context is already bound when WithCancel is called. +// +// Example: +// +// ctx, cancel := turnLoop.WithCancel(context.Background()) +// go func() { +// err := turnLoop.Run(ctx) +// }() +// // Later, to cancel: +// cancel(adk.WithCancelMode(adk.CancelAfterToolCall)) +func (l *TurnLoop[T]) WithCancel(ctx context.Context) (context.Context, TurnLoopCancelFunc) { + cs := newTurnLoopCancelSig() + ctx = withTurnLoopCancelSig(ctx, cs) + + var once sync.Once + cancelFn := func(opts ...CancelOption) error { + cfg := &cancelConfig{ + Mode: CancelImmediate, + } + for _, opt := range opts { + opt(cfg) + } + + cancelled := false + once.Do(func() { + cs.cancel(cfg) + cancelled = true + }) + + if !cancelled { + return ErrAgentFinished + } + return nil + } + + return ctx, cancelFn +} + // Run starts the blocking loop that continuously receives messages from the // source, runs the agent returned by GetAgent for each message, and dispatches // resulting events to OnAgentEvent. It blocks until the source returns an error @@ -141,16 +295,60 @@ var ErrLoopExit = errors.New("loop exit") // implements Cancellable, the agent is canceled and the new message is processed // immediately. If the agent does not implement Cancellable, preemptive messages // are queued and processed after the current agent finishes. -func (l *TurnLoop[T]) Run(ctx context.Context) error { +// +// To enable external cancellation, use WithCancel to create a cancellable context: +// +// ctx, cancel := turnLoop.WithCancel(context.Background()) +// go turnLoop.Run(ctx) +// // Later: cancel() +// +// To enable checkpoint-based resumption, use WithTurnLoopResume: +// +// err := turnLoop.Run(ctx, WithTurnLoopResume("session-123")) +// +//nolint:cyclop,funlen // This is a core method, splitting would make the logic harder to follow +func (l *TurnLoop[T]) Run(ctx context.Context, opts ...TurnLoopRunOption[T]) error { + var runCfg turnLoopRunConfig[T] + for _, opt := range opts { + opt(&runCfg) + } + + cs := getTurnLoopCancelSig(ctx) + toResumeFirst := false + if len(runCfg.checkPointID) > 0 { + toResumeFirst = true + } + for { - nCtx, item, option, err := l.source.Receive(ctx, ReceiveConfig{ - Timeout: l.receiveTimeout, - }) - if errors.Is(err, ErrLoopExit) { + if cs != nil && cs.isCancelled() { return nil } - if err != nil { - return fmt.Errorf("failed to receive message: %w", err) + + var nCtx context.Context + var item T + var checkPointID string + if !toResumeFirst { + var err error + var option []ConsumeOption + nCtx, item, option, err = l.source.Receive(ctx, ReceiveConfig{ + Timeout: l.receiveTimeout, + }) + if errors.Is(err, ErrLoopExit) { + return nil + } + if err != nil { + return fmt.Errorf("failed to receive message: %w", err) + } + o := applyConsumeOptions(option) + checkPointID = o.CheckPointID + } else { + nCtx = ctx + item = runCfg.item + checkPointID = runCfg.checkPointID + } + + if len(checkPointID) > 0 && l.store == nil { + return fmt.Errorf("CheckPointStore is required") } input, runOpts, e := l.genInput(nCtx, item) @@ -165,33 +363,56 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { var cancelFunc CancelFunc var iter *AsyncIterator[*AgentEvent] - if ca, isAgentCancellable := agent.(CancellableRun); isAgentCancellable { - iter, cancelFunc = ca.RunWithCancel(nCtx, input, runOpts...) + _, isAgentCancellable := agent.(CancellableAgent) + if cs != nil && !isAgentCancellable { + return fmt.Errorf("%w: agent %s", ErrAgentNotCancellableInTurnLoop, agent.Name(nCtx)) + } + + if toResumeFirst { + var err error + iter, cancelFunc, err = NewRunner(nCtx, RunnerConfig{ + EnableStreaming: input.EnableStreaming, + Agent: agent, + CheckPointStore: l.store, + }).ResumeWithCancel(nCtx, checkPointID, runOpts...) + if err != nil { + return fmt.Errorf("failed to resume agent: %w", err) + } + toResumeFirst = false + } else if isAgentCancellable { + var cps CheckPointStore + if len(checkPointID) > 0 { + cps = l.store + runOpts = append(runOpts, WithCheckPointID(checkPointID)) + } + iter, cancelFunc = NewRunner(nCtx, RunnerConfig{ + EnableStreaming: input.EnableStreaming, + Agent: agent, + CheckPointStore: cps, + }).RunWithCancel(nCtx, input.Messages, runOpts...) } else { - iter = agent.Run(nCtx, input, runOpts...) + var cps CheckPointStore + if len(checkPointID) > 0 { + cps = l.store + runOpts = append(runOpts, WithCheckPointID(checkPointID)) + } + iter = NewRunner(nCtx, RunnerConfig{ + EnableStreaming: input.EnableStreaming, + Agent: agent, + CheckPointStore: cps, + }).Run(nCtx, input.Messages, runOpts...) } - // handleEvents drains the agent iterator, forwarding each event to the - // OnAgentEvent callback. It is called directly in the non-cancellable - // path and from a goroutine in the cancellable path. handleEvents := func() error { - oe := l.onAgentEvents(ctx, item, iter) - if oe != nil { - return oe - } - return nil + return l.handleEvents(ctx, item, iter, checkPointID) } var handleEventErr error if cancelFunc != nil { - // Cancellable path: consume events in a goroutine so the main - // goroutine can block on Receive concurrently. done := make(chan struct{}) go func() { defer func() { - // Recover panics from the iterator or callback so they - // don't crash the process; surface them as errors instead. panicErr := recover() if panicErr != nil { handleEventErr = safe.NewPanicErr(panicErr, debug.Stack()) @@ -203,63 +424,150 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { handleEventErr = handleEvents() }() - // Block on the next message while events are being consumed above. - _, _, option, err = l.source.Front(nCtx, ReceiveConfig{ - Timeout: l.receiveTimeout, - }) - if err != nil { - <-done // wait for the event goroutine before returning - if errors.Is(err, ErrLoopExit) { + frontDone := make(chan struct{}) + var frontErr error + var option []ConsumeOption + go func() { + defer func() { + panicErr := recover() + if panicErr != nil { + frontErr = safe.NewPanicErr(panicErr, debug.Stack()) + } + + close(frontDone) + }() + _, _, option, frontErr = l.source.Front(nCtx, ReceiveConfig{ + Timeout: l.receiveTimeout, + }) + }() + + var externalCancelled bool + select { + case <-frontDone: + case <-done: + case <-func() <-chan struct{} { + if cs != nil { + return cs.done + } + return nil + }(): + externalCancelled = true + cfg := cs.getConfig() + err := cancelFunc(nCtx, cancelConfigToOpts(cfg)...) + if err != nil && !errors.Is(err, ErrAgentFinished) { + <-done + return fmt.Errorf("failed to cancel agent: %w", err) + } + } + + if externalCancelled { + <-done + return l.wrapHandleEventErr(handleEventErr) + } + + if frontErr != nil { + <-done + if errors.Is(frontErr, ErrLoopExit) { return nil } - return fmt.Errorf("failed to front message: %w", err) + return fmt.Errorf("failed to front message: %w", frontErr) } - // If the new message requests preemption, cancel the running agent. - // Cancel triggers the iterator to terminate, which unblocks the - // event goroutine above. o := applyConsumeOptions(option) switch o.Mode { case ConsumePreemptive: - err = cancelFunc(nCtx, o.CancelOpts...) + err := cancelFunc(nCtx, o.CancelOpts...) if err != nil { - <-done // wait for the event goroutine before returning + <-done return fmt.Errorf("failed to cancel agent: %w", err) } case ConsumePreemptiveOnTimeout: select { case <-done: case <-time.After(o.Timeout): - err = cancelFunc(nCtx, o.CancelOpts...) + err := cancelFunc(nCtx, o.CancelOpts...) if err != nil { - <-done // wait for the event goroutine before returning + <-done + return fmt.Errorf("failed to cancel agent: %w", err) + } + case <-func() <-chan struct{} { + if cs != nil { + return cs.done + } + return nil + }(): + cfg := cs.getConfig() + err := cancelFunc(nCtx, cancelConfigToOpts(cfg)...) + if err != nil && !errors.Is(err, ErrAgentFinished) { + <-done return fmt.Errorf("failed to cancel agent: %w", err) } + <-done + return l.wrapHandleEventErr(handleEventErr) } } - // Wait for event consumption to finish (normal completion or - // post-cancel drain) before starting the next turn. <-done - if handleEventErr != nil { - if errors.Is(handleEventErr, ErrLoopExit) { - return nil - } - return fmt.Errorf("failed to handle events: %w", handleEventErr) + if err := l.wrapHandleEventErr(handleEventErr); err != nil { + return err } } else { - // Non-cancellable path: consume all events sequentially, then - // block on the next message. if handleEventErr = handleEvents(); handleEventErr != nil { - if errors.Is(handleEventErr, ErrLoopExit) { - return nil + if err := l.wrapHandleEventErr(handleEventErr); err != nil { + return err } - return fmt.Errorf("failed to handle events: %w", handleEventErr) } } } } +func (l *TurnLoop[T]) wrapHandleEventErr(handleEventErr error) error { + if handleEventErr == nil { + return nil + } + if errors.Is(handleEventErr, ErrLoopExit) { + return nil + } + var interruptErr *TurnLoopInterruptError[T] + if errors.As(handleEventErr, &interruptErr) { + return interruptErr + } + return fmt.Errorf("failed to handle events: %w", handleEventErr) +} + +func (l *TurnLoop[T]) handleEvents(ctx context.Context, item T, iter *AsyncIterator[*AgentEvent], checkPointID string) error { + copies := copyEventIterator(iter, 2) + oe := l.onAgentEvents(ctx, item, copies[0]) + if oe != nil { + return oe + } + for { + e, ok := copies[1].Next() + if !ok { + break + } + if e.Action != nil && e.Action.Interrupted != nil { + return &TurnLoopInterruptError[T]{ + Item: item, + CheckpointID: checkPointID, + InterruptContexts: e.Action.Interrupted.InterruptContexts, + } + } + } + return nil +} + +func cancelConfigToOpts(cfg *cancelConfig) []CancelOption { + if cfg == nil { + return nil + } + opts := []CancelOption{WithCancelMode(cfg.Mode)} + if cfg.Timeout != nil { + opts = append(opts, WithCancelTimeout(*cfg.Timeout)) + } + return opts +} + func applyConsumeOptions(opts []ConsumeOption) *consumeConfig { var config consumeConfig for _, opt := range opts { @@ -267,3 +575,15 @@ func applyConsumeOptions(opts []ConsumeOption) *consumeConfig { } return &config } + +type TurnLoopInterruptError[T any] struct { + Item T + CheckpointID string + // InterruptContexts provides a structured, user-facing view of the interrupt chain. + // Each context represents a step in the agent hierarchy that was interrupted. + InterruptContexts []*InterruptCtx +} + +func (t *TurnLoopInterruptError[T]) Error() string { + return fmt.Sprintf("TurnLoopInterruptError[%s]: %v", t.CheckpointID, t.InterruptContexts) +} diff --git a/adk/turn_loop_test.go b/adk/turn_loop_test.go index f66c1298b..7c22f3699 100644 --- a/adk/turn_loop_test.go +++ b/adk/turn_loop_test.go @@ -19,53 +19,66 @@ package adk import ( "context" "errors" + "fmt" "sync/atomic" "testing" "time" + "unsafe" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) -// --------------------------------------------------------------------------- -// Test mocks -// --------------------------------------------------------------------------- - -// turnLoopMockSource returns items from a slice (all NonPreemptive), then an error. type turnLoopMockSource struct { items []string idx int err error } -func (s *turnLoopMockSource) Receive(_ context.Context, _ time.Duration) (string, ConsumeOption, error) { +func (s *turnLoopMockSource) Receive(ctx context.Context, cfg ReceiveConfig) (context.Context, string, []ConsumeOption, error) { if s.idx >= len(s.items) { - return "", NonPreemptiveConsumeOption, s.err + return ctx, "", nil, s.err } item := s.items[s.idx] s.idx++ - return item, NonPreemptiveConsumeOption, nil + return ctx, item, nil, nil +} + +func (s *turnLoopMockSource) Front(ctx context.Context, cfg ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + if s.idx >= len(s.items) { + return ctx, "", nil, s.err + } + return ctx, s.items[s.idx], nil, nil } -// turnLoopFuncSource delegates Receive to a user-supplied function. type turnLoopFuncSource[T any] struct { - fn func(ctx context.Context, timeout time.Duration) (T, ConsumeOption, error) + receive func(ctx context.Context, cfg ReceiveConfig) (context.Context, T, []ConsumeOption, error) + front func(ctx context.Context, cfg ReceiveConfig) (context.Context, T, []ConsumeOption, error) } -func (s *turnLoopFuncSource[T]) Receive(ctx context.Context, timeout time.Duration) (T, ConsumeOption, error) { - return s.fn(ctx, timeout) +func (s *turnLoopFuncSource[T]) Receive(ctx context.Context, cfg ReceiveConfig) (context.Context, T, []ConsumeOption, error) { + return s.receive(ctx, cfg) +} + +func (s *turnLoopFuncSource[T]) Front(ctx context.Context, cfg ReceiveConfig) (context.Context, T, []ConsumeOption, error) { + if s.front != nil { + return s.front(ctx, cfg) + } + return s.receive(ctx, cfg) } -// turnLoopMockAgent emits a fixed list of events per Run call. type turnLoopMockAgent struct { name string events []*AgentEvent } func (a *turnLoopMockAgent) Name(_ context.Context) string { return a.name } -func (a *turnLoopMockAgent) Description(_ context.Context) string { return "mock agent" } +func (a *turnLoopMockAgent) Description(_ context.Context) string { return "mock agent" } func (a *turnLoopMockAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, gen := NewAsyncIteratorPair[*AgentEvent]() go func() { @@ -77,39 +90,37 @@ func (a *turnLoopMockAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunO return iter } -// turnLoopCancellableAgent blocks until Cancel is called, then closes its iterator. -// It implements both Agent and Cancellable. type turnLoopCancellableAgent struct { - name string - startedCh chan struct{} // closed when Run is entered - cancelCh chan struct{} // closed by Cancel - cancelled atomic.Bool - cancelledOpt CancelOption // records the CancelOption passed to Cancel + name string + startedCh chan struct{} + cancelCh chan struct{} + cancelled int32 + cancelledOpt []CancelOption } func (a *turnLoopCancellableAgent) Name(_ context.Context) string { return a.name } -func (a *turnLoopCancellableAgent) Description(_ context.Context) string { return "cancellable mock" } +func (a *turnLoopCancellableAgent) Description(_ context.Context) string { return "cancellable mock" } func (a *turnLoopCancellableAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, _ := a.RunWithCancel(context.Background(), nil) + return iter +} + +func (a *turnLoopCancellableAgent) RunWithCancel(_ context.Context, _ *AgentInput, _ ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { iter, gen := NewAsyncIteratorPair[*AgentEvent]() close(a.startedCh) go func() { defer gen.Close() <-a.cancelCh }() - return iter -} - -func (a *turnLoopCancellableAgent) Cancel(_ context.Context, opt CancelOption) error { - a.cancelled.Store(true) - a.cancelledOpt = opt - close(a.cancelCh) - return nil + cancelFunc := func(_ context.Context, opts ...CancelOption) error { + atomic.StoreInt32(&a.cancelled, 1) + a.cancelledOpt = opts + close(a.cancelCh) + return nil + } + return iter, cancelFunc } -// --------------------------------------------------------------------------- -// Tests — validation -// --------------------------------------------------------------------------- - func TestNewTurnLoop_Validation(t *testing.T) { t.Run("missing source", func(t *testing.T) { _, err := NewTurnLoop(TurnLoopConfig[string]{ @@ -138,7 +149,7 @@ func TestNewTurnLoop_Validation(t *testing.T) { assert.Contains(t, err.Error(), "GetAgent") }) - t.Run("valid config", func(t *testing.T) { + t.Run("valid config without OnAgentEvents", func(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: &turnLoopMockSource{}, GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { return nil, nil, nil }, @@ -147,11 +158,18 @@ func TestNewTurnLoop_Validation(t *testing.T) { require.NoError(t, err) assert.NotNil(t, loop) }) -} -// --------------------------------------------------------------------------- -// Tests — non-preemptive (queued) behavior -// --------------------------------------------------------------------------- + t.Run("valid config with OnAgentEvents", func(t *testing.T) { + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{}, + GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { return nil, nil, nil }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, nil }, + OnAgentEvents: func(_ context.Context, _ string, _ *AsyncIterator[*AgentEvent]) error { return nil }, + }) + require.NoError(t, err) + assert.NotNil(t, loop) + }) +} func TestTurnLoop_NormalLoop(t *testing.T) { agent := &turnLoopMockAgent{ @@ -161,8 +179,8 @@ func TestTurnLoop_NormalLoop(t *testing.T) { }, } - var receivedEvents []*AgentEvent var receivedItems []string + var eventCount int source := &turnLoopMockSource{ items: []string{"msg1", "msg2", "msg3"}, @@ -178,8 +196,14 @@ func TestTurnLoop_NormalLoop(t *testing.T) { GetAgent: func(_ context.Context, _ string) (Agent, error) { return agent, nil }, - OnAgentEvent: func(_ context.Context, _ string, event *AgentEvent) error { - receivedEvents = append(receivedEvents, event) + OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + for { + _, ok := iter.Next() + if !ok { + break + } + eventCount++ + } return nil }, }) @@ -188,7 +212,7 @@ func TestTurnLoop_NormalLoop(t *testing.T) { err = loop.Run(context.Background()) assert.ErrorIs(t, err, context.DeadlineExceeded) assert.Equal(t, []string{"msg1", "msg2", "msg3"}, receivedItems) - assert.Len(t, receivedEvents, 3) + assert.Equal(t, 3, eventCount) } func TestTurnLoop_SourceError(t *testing.T) { @@ -206,6 +230,7 @@ func TestTurnLoop_SourceError(t *testing.T) { GetAgent: func(_ context.Context, _ string) (Agent, error) { return &turnLoopMockAgent{name: "a"}, nil }, + OnAgentEvents: func(_ context.Context, _ string, _ *AsyncIterator[*AgentEvent]) error { return nil }, }) require.NoError(t, err) @@ -224,6 +249,7 @@ func TestTurnLoop_GenInputError(t *testing.T) { GetAgent: func(_ context.Context, _ string) (Agent, error) { return &turnLoopMockAgent{name: "a"}, nil }, + OnAgentEvents: func(_ context.Context, _ string, _ *AsyncIterator[*AgentEvent]) error { return nil }, }) require.NoError(t, err) @@ -243,6 +269,7 @@ func TestTurnLoop_GetAgentError(t *testing.T) { GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, agentErr }, + OnAgentEvents: func(_ context.Context, _ string, _ *AsyncIterator[*AgentEvent]) error { return nil }, }) require.NoError(t, err) @@ -251,7 +278,7 @@ func TestTurnLoop_GetAgentError(t *testing.T) { assert.Contains(t, err.Error(), "failed to get agent") } -func TestTurnLoop_OnAgentEventError(t *testing.T) { +func TestTurnLoop_OnAgentEventsError(t *testing.T) { eventErr := errors.New("event handler failure") agent := &turnLoopMockAgent{ name: "test-agent", @@ -266,7 +293,7 @@ func TestTurnLoop_OnAgentEventError(t *testing.T) { GetAgent: func(_ context.Context, _ string) (Agent, error) { return agent, nil }, - OnAgentEvent: func(_ context.Context, _ string, _ *AgentEvent) error { + OnAgentEvents: func(_ context.Context, _ string, _ *AsyncIterator[*AgentEvent]) error { return eventErr }, }) @@ -274,20 +301,20 @@ func TestTurnLoop_OnAgentEventError(t *testing.T) { err = loop.Run(context.Background()) assert.ErrorIs(t, err, eventErr) - assert.Contains(t, err.Error(), "OnAgentEvent callback failed") + assert.Contains(t, err.Error(), "failed to handle events") } func TestTurnLoop_ContextCancellation(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) callCount := 0 - source := &turnLoopFuncSource[string]{fn: func(ctx context.Context, _ time.Duration) (string, ConsumeOption, error) { + source := &turnLoopFuncSource[string]{receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { callCount++ if callCount > 1 { cancel() - return "", NonPreemptiveConsumeOption, ctx.Err() + return ctx, "", nil, ctx.Err() } - return "msg1", NonPreemptiveConsumeOption, nil + return ctx, "msg1", nil, nil }} agent := &turnLoopMockAgent{ @@ -303,6 +330,14 @@ func TestTurnLoop_ContextCancellation(t *testing.T) { GetAgent: func(_ context.Context, _ string) (Agent, error) { return agent, nil }, + OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := iter.Next(); !ok { + break + } + } + return nil + }, }) require.NoError(t, err) @@ -330,8 +365,14 @@ func TestTurnLoop_MultipleEventsPerTurn(t *testing.T) { GetAgent: func(_ context.Context, _ string) (Agent, error) { return agent, nil }, - OnAgentEvent: func(_ context.Context, _ string, _ *AgentEvent) error { - eventCount++ + OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + for { + _, ok := iter.Next() + if !ok { + break + } + eventCount++ + } return nil }, }) @@ -342,7 +383,7 @@ func TestTurnLoop_MultipleEventsPerTurn(t *testing.T) { assert.Equal(t, 3, eventCount) } -func TestTurnLoop_NoOnAgentEvent(t *testing.T) { +func TestTurnLoop_DefaultOnAgentEvents(t *testing.T) { agent := &turnLoopMockAgent{ name: "test-agent", events: []*AgentEvent{{Output: &AgentOutput{}}}, @@ -363,6 +404,28 @@ func TestTurnLoop_NoOnAgentEvent(t *testing.T) { assert.ErrorIs(t, err, context.DeadlineExceeded) } +func TestTurnLoop_DefaultOnAgentEventsWithError(t *testing.T) { + agentErr := errors.New("agent internal error") + agent := &turnLoopMockAgent{ + name: "error-agent", + events: []*AgentEvent{{Err: agentErr}}, + } + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, agentErr) +} + func TestTurnLoop_AgentErrorEvent(t *testing.T) { agentErr := errors.New("agent internal error") agent := &turnLoopMockAgent{ @@ -378,6 +441,18 @@ func TestTurnLoop_AgentErrorEvent(t *testing.T) { GetAgent: func(_ context.Context, _ string) (Agent, error) { return agent, nil }, + OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil { + return fmt.Errorf("agent run failed: %w", event.Err) + } + } + return nil + }, }) require.NoError(t, err) @@ -386,10 +461,6 @@ func TestTurnLoop_AgentErrorEvent(t *testing.T) { assert.Contains(t, err.Error(), "agent run failed") } -// --------------------------------------------------------------------------- -// Tests — preemptive behavior -// --------------------------------------------------------------------------- - func TestTurnLoop_PreemptiveCancellation(t *testing.T) { slowAgent := &turnLoopCancellableAgent{ name: "slow-agent", @@ -403,19 +474,36 @@ func TestTurnLoop_PreemptiveCancellation(t *testing.T) { } var processedItems []string - callCount := 0 - source := &turnLoopFuncSource[string]{fn: func(_ context.Context, _ time.Duration) (string, ConsumeOption, error) { - callCount++ - switch callCount { - case 1: - return "slow-msg", NonPreemptiveConsumeOption, nil - case 2: + receiveCount := 0 + msgs := []struct { + item string + opts []ConsumeOption + err error + }{ + {"slow-msg", nil, nil}, + {"preempt-msg", []ConsumeOption{WithPreemptive(), WithCancelOptions(WithCancelMode(CancelImmediate))}, nil}, + {"", nil, context.DeadlineExceeded}, + } + frontIdx := 0 + source := &turnLoopFuncSource[string]{ + receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + if receiveCount >= len(msgs) { + return ctx, "", nil, context.DeadlineExceeded + } + m := msgs[receiveCount] + receiveCount++ + frontIdx = receiveCount + return ctx, m.item, m.opts, m.err + }, + front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { <-slowAgent.startedCh - return "preempt-msg", ConsumeOption{Mode: ConsumePreemptive, CancelOption: CancelOption{Mode: CancelImmediate}}, nil - default: - return "", NonPreemptiveConsumeOption, context.DeadlineExceeded - } - }} + if frontIdx >= len(msgs) { + return ctx, "", nil, context.DeadlineExceeded + } + m := msgs[frontIdx] + return ctx, m.item, m.opts, m.err + }, + } loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: source, @@ -429,110 +517,603 @@ func TestTurnLoop_PreemptiveCancellation(t *testing.T) { } return fastAgent, nil }, + OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := iter.Next(); !ok { + break + } + } + return nil + }, }) require.NoError(t, err) err = loop.Run(context.Background()) assert.ErrorIs(t, err, context.DeadlineExceeded) - assert.True(t, slowAgent.cancelled.Load(), "slow agent should have been cancelled") - assert.Equal(t, CancelImmediate, slowAgent.cancelledOpt.Mode) + assert.True(t, atomic.LoadInt32(&slowAgent.cancelled) == 1, "slow agent should have been cancelled") assert.Equal(t, []string{"slow-msg", "preempt-msg"}, processedItems) } -func TestTurnLoop_PreemptiveWithCancelMode(t *testing.T) { - slowAgent := &turnLoopCancellableAgent{ - name: "slow-agent", - startedCh: make(chan struct{}), - cancelCh: make(chan struct{}), +func TestTurnLoop_PreemptiveNonCancellableAgent(t *testing.T) { + nonCancellableAgent := &turnLoopMockAgent{ + name: "non-cancellable-agent", + events: []*AgentEvent{{Output: &AgentOutput{}}}, } - fastAgent := &turnLoopMockAgent{ name: "fast-agent", events: []*AgentEvent{{Output: &AgentOutput{}}}, } + var processedItems []string callCount := 0 - source := &turnLoopFuncSource[string]{fn: func(_ context.Context, _ time.Duration) (string, ConsumeOption, error) { + source := &turnLoopFuncSource[string]{receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { callCount++ switch callCount { case 1: - return "slow-msg", NonPreemptiveConsumeOption, nil + return ctx, "non-cancel-msg", nil, nil case 2: - <-slowAgent.startedCh - return "preempt-msg", ConsumeOption{ - Mode: ConsumePreemptive, - CancelOption: CancelOption{Mode: CancelAfterChatModel | CancelAfterToolCall}, - }, nil + return ctx, "preempt-msg", []ConsumeOption{WithPreemptive()}, nil default: - return "", NonPreemptiveConsumeOption, context.DeadlineExceeded + return ctx, "", nil, context.DeadlineExceeded } }} loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: source, GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + processedItems = append(processedItems, item) return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil }, GetAgent: func(_ context.Context, item string) (Agent, error) { - if item == "slow-msg" { - return slowAgent, nil + if item == "non-cancel-msg" { + return nonCancellableAgent, nil } return fastAgent, nil }, + OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := iter.Next(); !ok { + break + } + } + return nil + }, }) require.NoError(t, err) err = loop.Run(context.Background()) assert.ErrorIs(t, err, context.DeadlineExceeded) - assert.True(t, slowAgent.cancelled.Load()) - assert.Equal(t, CancelAfterChatModel|CancelAfterToolCall, slowAgent.cancelledOpt.Mode) + assert.Equal(t, []string{"non-cancel-msg", "preempt-msg"}, processedItems) } -func TestTurnLoop_PreemptiveNonCancellableAgent(t *testing.T) { - // A non-cancellable agent cannot be preempted, so the new Run processes - // events sequentially before calling Receive. The preemptive message is - // effectively queued and processed in the next turn. - nonCancellableAgent := &turnLoopMockAgent{ +func TestTurnLoop_WithCancel_Basic(t *testing.T) { + agent := &turnLoopCancellableAgent{ + name: "test-agent", + startedCh: make(chan struct{}), + cancelCh: make(chan struct{}), + } + + receiveCount := int32(0) + frontBlocked := make(chan struct{}) + source := &turnLoopFuncSource[string]{ + receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + cnt := atomic.AddInt32(&receiveCount, 1) + if cnt == 1 { + return ctx, "msg1", nil, nil + } + return ctx, "", nil, context.DeadlineExceeded + }, + front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + <-frontBlocked + return ctx, "", nil, context.DeadlineExceeded + }, + } + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: source, + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := iter.Next(); !ok { + break + } + } + return nil + }, + }) + require.NoError(t, err) + + ctx, cancel := loop.WithCancel(context.Background()) + done := make(chan error) + go func() { + done <- loop.Run(ctx) + }() + + <-agent.startedCh + e := cancel() + assert.NoError(t, e) + + err = <-done + assert.NoError(t, err) +} + +func TestTurnLoop_WithCancel_DuringAgentRun(t *testing.T) { + slowAgent := &turnLoopCancellableAgent{ + name: "slow-agent", + startedCh: make(chan struct{}), + cancelCh: make(chan struct{}), + } + + receiveCount := int32(0) + frontBlocked := make(chan struct{}) + source := &turnLoopFuncSource[string]{ + receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + cnt := atomic.AddInt32(&receiveCount, 1) + if cnt == 1 { + return ctx, "msg1", nil, nil + } + return ctx, "", nil, context.DeadlineExceeded + }, + front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + <-frontBlocked + return ctx, "", nil, context.DeadlineExceeded + }, + } + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: source, + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return slowAgent, nil + }, + OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := iter.Next(); !ok { + break + } + } + return nil + }, + }) + require.NoError(t, err) + + ctx, cancel := loop.WithCancel(context.Background()) + done := make(chan error) + go func() { + done <- loop.Run(ctx) + }() + + <-slowAgent.startedCh + cancel(WithCancelMode(CancelImmediate)) + + err = <-done + assert.NoError(t, err) + assert.True(t, atomic.LoadInt32(&slowAgent.cancelled) == 1, "agent should have been cancelled") +} + +func TestTurnLoop_WithCancel_NonCancellableAgent_ReturnsError(t *testing.T) { + agent := &turnLoopMockAgent{ name: "non-cancellable-agent", events: []*AgentEvent{{Output: &AgentOutput{}}}, } - fastAgent := &turnLoopMockAgent{ - name: "fast-agent", + + source := &turnLoopFuncSource[string]{ + receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + return ctx, "msg1", nil, nil + }, + } + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: source, + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + }) + require.NoError(t, err) + + ctx, _ := loop.WithCancel(context.Background()) + err = loop.Run(ctx) + + assert.ErrorIs(t, err, ErrAgentNotCancellableInTurnLoop) + assert.Contains(t, err.Error(), "non-cancellable-agent") +} + +func TestTurnLoop_WithCancel_MultipleCalls(t *testing.T) { + agent := &turnLoopMockAgent{ + name: "test-agent", events: []*AgentEvent{{Output: &AgentOutput{}}}, } - var processedItems []string - callCount := 0 - source := &turnLoopFuncSource[string]{fn: func(_ context.Context, _ time.Duration) (string, ConsumeOption, error) { - callCount++ - switch callCount { - case 1: - return "non-cancel-msg", NonPreemptiveConsumeOption, nil - case 2: - // Even though Mode is ConsumePreemptive, the agent doesn't - // implement Cancellable, so it's treated as non-preemptive. - return "preempt-msg", ConsumeOption{Mode: ConsumePreemptive}, nil + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + }) + require.NoError(t, err) + + _, cancel := loop.WithCancel(context.Background()) + + err1 := cancel() + err2 := cancel() + + assert.NoError(t, err1) + assert.ErrorIs(t, err2, ErrAgentFinished) +} + +func TestTurnLoop_WithCancel_IndependentCancels(t *testing.T) { + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{items: []string{}, err: context.DeadlineExceeded}, + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return &turnLoopMockAgent{name: "a"}, nil + }, + }) + require.NoError(t, err) + + ctx1, cancel1 := loop.WithCancel(context.Background()) + ctx2, cancel2 := loop.WithCancel(context.Background()) + + cs1 := getTurnLoopCancelSig(ctx1) + cs2 := getTurnLoopCancelSig(ctx2) + + assert.NotNil(t, cs1) + assert.NotNil(t, cs2) + assert.NotEqual(t, cs1, cs2) + + cancel1() + assert.True(t, cs1.isCancelled()) + assert.False(t, cs2.isCancelled()) + + cancel2() + assert.True(t, cs2.isCancelled()) +} + +func TestTurnLoop_WithCancel_WithCancelOptions(t *testing.T) { + slowAgent := &turnLoopCancellableAgent{ + name: "slow-agent", + startedCh: make(chan struct{}), + cancelCh: make(chan struct{}), + } + + receiveCount := int32(0) + frontBlocked := make(chan struct{}) + source := &turnLoopFuncSource[string]{ + receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + cnt := atomic.AddInt32(&receiveCount, 1) + if cnt == 1 { + return ctx, "msg1", nil, nil + } + return ctx, "", nil, context.DeadlineExceeded + }, + front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + <-frontBlocked + return ctx, "", nil, context.DeadlineExceeded + }, + } + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: source, + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return slowAgent, nil + }, + OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := iter.Next(); !ok { + break + } + } + return nil + }, + }) + require.NoError(t, err) + + ctx, cancel := loop.WithCancel(context.Background()) + done := make(chan error) + go func() { + done <- loop.Run(ctx) + }() + + <-slowAgent.startedCh + cancel(WithCancelMode(CancelAfterToolCall)) + + <-done + assert.Len(t, slowAgent.cancelledOpt, 1) +} + +type turnLoopInMemoryStore struct { + data map[string][]byte +} + +func newTurnLoopInMemoryStore() *turnLoopInMemoryStore { + return &turnLoopInMemoryStore{data: make(map[string][]byte)} +} + +func (s *turnLoopInMemoryStore) Get(_ context.Context, key string) ([]byte, bool, error) { + v, ok := s.data[key] + return v, ok, nil +} + +func (s *turnLoopInMemoryStore) Set(_ context.Context, key string, value []byte) error { + s.data[key] = value + return nil +} + +type turnLoopTestModel struct { + messages []*schema.Message + idx int +} + +func (m *turnLoopTestModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + if m.idx >= len(m.messages) { + return nil, fmt.Errorf("no more messages") + } + msg := m.messages[m.idx] + m.idx++ + return msg, nil +} + +func (m *turnLoopTestModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + panic("not implemented") +} + +func (m *turnLoopTestModel) WithTools(_ []*schema.ToolInfo) (model.ToolCallingChatModel, error) { + return m, nil +} + +type turnLoopSlowModel struct { + delay int64 + startedCh unsafe.Pointer + doneCh chan struct{} + message *schema.Message +} + +func (m *turnLoopSlowModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + if ch := (*chan struct{})(atomic.LoadPointer(&m.startedCh)); ch != nil { + select { + case *ch <- struct{}{}: default: - return "", NonPreemptiveConsumeOption, context.DeadlineExceeded } - }} + } + if delay := atomic.LoadInt64(&m.delay); delay > 0 { + time.Sleep(time.Duration(delay)) + } + if m.doneCh != nil { + select { + case m.doneCh <- struct{}{}: + default: + } + } + return m.message, nil +} + +func (m *turnLoopSlowModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + panic("not implemented") +} + +type turnLoopInterruptTool struct { + name string +} + +func (t *turnLoopInterruptTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "A tool that interrupts", + }, nil +} + +func (t *turnLoopInterruptTool) InvokableRun(ctx context.Context, _ string, _ ...tool.Option) (string, error) { + wasInterrupted, _, _ := tool.GetInterruptState[any](ctx) + if !wasInterrupted { + return "", tool.Interrupt(ctx, "need approval") + } + isResumeTarget, hasData, data := tool.GetResumeContext[string](ctx) + if isResumeTarget && hasData { + return data, nil + } + return "approved", nil +} + +func TestTurnLoop_ExternalCancel_WithStore(t *testing.T) { + store := newTurnLoopInMemoryStore() + const checkPointID = "external-cancel-test" + + modelStarted := make(chan struct{}, 1) + testModel := &turnLoopSlowModel{ + doneCh: make(chan struct{}, 1), + message: schema.AssistantMessage("task completed", nil), + } + atomic.StoreInt64(&testModel.delay, int64(5*time.Second)) + atomic.StorePointer(&testModel.startedCh, unsafe.Pointer(&modelStarted)) + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: "test-agent", + Description: "test agent for external cancel", + Model: testModel, + }) + require.NoError(t, err) + + receiveCount := int32(0) + turnCount := int32(0) + frontBlocked := make(chan struct{}) + source := &turnLoopFuncSource[string]{ + receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + cnt := atomic.AddInt32(&receiveCount, 1) + if cnt == 1 { + return ctx, "msg1", []ConsumeOption{WithConsumeCheckPointID(checkPointID)}, nil + } + return ctx, "", nil, ErrLoopExit + }, + front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + <-frontBlocked + return ctx, "", nil, context.DeadlineExceeded + }, + } loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: source, GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - processedItems = append(processedItems, item) return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil }, - GetAgent: func(_ context.Context, item string) (Agent, error) { - if item == "non-cancel-msg" { - return nonCancellableAgent, nil + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + atomic.AddInt32(&turnCount, 1) + for { + if _, ok := iter.Next(); !ok { + break + } } - return fastAgent, nil + return nil }, + Store: store, + }) + require.NoError(t, err) + + ctx, cancel := loop.WithCancel(context.Background()) + done := make(chan error) + go func() { + done <- loop.Run(ctx) + }() + + select { + case <-modelStarted: + case err := <-done: + t.Fatalf("loop.Run returned early with error: %v", err) + case <-time.After(3 * time.Second): + t.Fatalf("timeout waiting for model to start") + } + + err = cancel(WithCancelMode(CancelImmediate)) + + var runErr error + select { + case runErr = <-done: + case <-time.After(3 * time.Second): + t.Fatalf("timeout waiting for loop.Run to return after cancel") + } + + var interruptErr *TurnLoopInterruptError[string] + require.True(t, errors.As(runErr, &interruptErr), "expected TurnLoopInterruptError, got: %v", runErr) + assert.Equal(t, checkPointID, interruptErr.CheckpointID) + assert.Equal(t, "msg1", interruptErr.Item) + assert.True(t, len(store.data) > 0, "checkpoint should be stored") + + atomic.StorePointer(&testModel.startedCh, nil) + atomic.StoreInt64(&testModel.delay, 0) + + done = make(chan error) + go func() { + done <- loop.Run(context.Background(), WithTurnLoopResume(checkPointID, "msg1")) + }() + + select { + case runErr = <-done: + case <-time.After(5 * time.Second): + t.Fatalf("timeout waiting for loop.Run (resume) to return") + } + assert.NoError(t, runErr) + assert.Equal(t, int32(2), atomic.LoadInt32(&turnCount), "should have 2 turns total (1 from first run + 1 from resume)") +} + +func TestTurnLoop_InternalToolInterrupt_WithCheckpoint_ThenResume(t *testing.T) { + store := newTurnLoopInMemoryStore() + const checkPointID = "tool-interrupt-test" + + interruptTool := &turnLoopInterruptTool{name: "approval_tool"} + + testModel := &turnLoopTestModel{ + messages: []*schema.Message{ + schema.AssistantMessage("", []schema.ToolCall{ + {ID: "call1", Function: schema.FunctionCall{Name: "approval_tool", Arguments: "{}"}}, + }), + schema.AssistantMessage("task completed", nil), + }, + } + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: "test-agent", + Description: "test agent", + Model: testModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{interruptTool}, + }, + }, + }) + require.NoError(t, err) + + receiveCount := int32(0) + frontBlocked := make(chan struct{}) + source := &turnLoopFuncSource[string]{ + receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + cnt := atomic.AddInt32(&receiveCount, 1) + if cnt == 1 { + return ctx, "msg1", []ConsumeOption{WithConsumeCheckPointID(checkPointID)}, nil + } + return ctx, "", nil, ErrLoopExit + }, + front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + <-frontBlocked + return ctx, "", nil, context.DeadlineExceeded + }, + } + + var interruptErr *TurnLoopInterruptError[string] + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: source, + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := iter.Next(); !ok { + break + } + } + return nil + }, + Store: store, }) require.NoError(t, err) err = loop.Run(context.Background()) - assert.ErrorIs(t, err, context.DeadlineExceeded) - assert.Equal(t, []string{"non-cancel-msg", "preempt-msg"}, processedItems) + require.True(t, errors.As(err, &interruptErr), "expected TurnLoopInterruptError, got: %v", err) + assert.Equal(t, checkPointID, interruptErr.CheckpointID) + assert.Equal(t, "msg1", interruptErr.Item) + assert.Len(t, interruptErr.InterruptContexts, 1) + assert.Equal(t, "need approval", interruptErr.InterruptContexts[0].Info) + + interruptID := interruptErr.InterruptContexts[0].ID + assert.NotEmpty(t, interruptID) + assert.True(t, len(store.data) > 0, "checkpoint should be stored") + + testModel.idx = 1 + receiveCount = 0 + + resumeCtx := compose.ResumeWithData(context.Background(), interruptID, "user approved") + err = loop.Run(resumeCtx, WithTurnLoopResume(checkPointID, "msg1")) + assert.NoError(t, err) } From aaa035f57224ff9c6fe09396de31b7a7895f5731 Mon Sep 17 00:00:00 2001 From: Megumin Date: Tue, 24 Feb 2026 16:35:57 +0800 Subject: [PATCH 41/65] feat(adk): modify cancel func (#800) --- adk/cancel_test.go | 23 ++++++++++------------- adk/chatmodel.go | 11 +---------- adk/flow.go | 2 +- adk/interface.go | 5 +---- adk/turn_loop.go | 26 +++++++------------------- adk/turn_loop_test.go | 9 +++------ 6 files changed, 23 insertions(+), 53 deletions(-) diff --git a/adk/cancel_test.go b/adk/cancel_test.go index f45e344ff..702efd3c3 100644 --- a/adk/cancel_test.go +++ b/adk/cancel_test.go @@ -207,7 +207,7 @@ func TestRunWithCancel_WithTools(t *testing.T) { time.Sleep(100 * time.Millisecond) - err = cancelFn(ctx) + err = cancelFn() assert.NoError(t, err) start := time.Now() @@ -276,7 +276,7 @@ func TestRunWithCancel_WithTools(t *testing.T) { time.Sleep(100 * time.Millisecond) - err = cancelFn(ctx, WithCancelMode(CancelAfterChatModel)) + err = cancelFn(WithCancelMode(CancelAfterChatModel)) assert.NoError(t, err) var events []*AgentEvent @@ -342,7 +342,7 @@ func TestRunWithCancel_WithTools(t *testing.T) { time.Sleep(100 * time.Millisecond) - err = cancelFn(ctx, WithCancelMode(CancelAfterToolCall)) + err = cancelFn(WithCancelMode(CancelAfterToolCall)) assert.NoError(t, err) var events []*AgentEvent @@ -452,7 +452,7 @@ func TestRunWithCancel_WithCheckpoint(t *testing.T) { <-modelStarted - err = cancelFn(ctx) + err = cancelFn() assert.NoError(t, err) var events []*AgentEvent @@ -518,12 +518,9 @@ func TestCancelFuncMultipleCalls(t *testing.T) { <-modelStarted - cancelErr := cancelFn(ctx) + cancelErr := cancelFn() assert.NoError(t, cancelErr) - cancelErr = cancelFn(ctx) - assert.ErrorIs(t, cancelErr, ErrAgentFinished) - for { _, ok := iter.Next() if !ok { @@ -654,7 +651,7 @@ func TestRunWithCancel_Streaming(t *testing.T) { time.Sleep(100 * time.Millisecond) - cancelErr := cancelFn(ctx) + cancelErr := cancelFn() assert.NoError(t, cancelErr) start := time.Now() @@ -726,7 +723,7 @@ func TestRunWithCancel_Streaming(t *testing.T) { time.Sleep(100 * time.Millisecond) - cancelErr := cancelFn(ctx, WithCancelMode(CancelAfterToolCall)) + cancelErr := cancelFn(WithCancelMode(CancelAfterToolCall)) assert.NoError(t, cancelErr) var events []*AgentEvent @@ -808,7 +805,7 @@ func TestResumeWithCancel(t *testing.T) { <-modelStarted atomic.AddInt32(&modelCallCount, 1) - cancelErr := cancelFn(ctx) + cancelErr := cancelFn() assert.NoError(t, cancelErr) var events []*AgentEvent @@ -931,7 +928,7 @@ func TestResumeWithCancel(t *testing.T) { <-firstModelStarted atomic.AddInt32(&modelCallCount, 1) - cancelErr := cancelFn(ctx) + cancelErr := cancelFn() assert.NoError(t, cancelErr) for { @@ -1001,7 +998,7 @@ func TestResumeWithCancel(t *testing.T) { time.Sleep(100 * time.Millisecond) - err = resumeCancelFn(ctx) + err = resumeCancelFn() assert.NoError(t, err) start := time.Now() diff --git a/adk/chatmodel.go b/adk/chatmodel.go index effb270bd..10faa7434 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -1047,25 +1047,16 @@ func (a *ChatModelAgent) runInternal(ctx context.Context, input *AgentInput, wit func buildCancelFunc(cs *cancelSig) CancelFunc { var once sync.Once - - return func(_ context.Context, opts ...CancelOption) error { + return func(opts ...CancelOption) error { cfg := &cancelConfig{ Mode: CancelImmediate, } for _, opt := range opts { opt(cfg) } - - cancelled := false once.Do(func() { cs.cancel(cfg) - cancelled = true }) - - if !cancelled { - return ErrAgentFinished - } - return nil } } diff --git a/adk/flow.go b/adk/flow.go index 93143c641..7acfd2137 100644 --- a/adk/flow.go +++ b/adk/flow.go @@ -387,7 +387,7 @@ func (a *flowAgent) runInternal(ctx context.Context, input *AgentInput, withCanc return iterator, cancelFn } -func notCancellableFuncInternal(_ context.Context, _ ...CancelOption) error { +func notCancellableFuncInternal(_ ...CancelOption) error { return ErrAgentNotCancellable } diff --git a/adk/interface.go b/adk/interface.go index 1d2f5f070..c73705ae3 100644 --- a/adk/interface.go +++ b/adk/interface.go @@ -288,9 +288,6 @@ const ( CancelAfterToolCall ) -// ErrAgentFinished is returned by Cancel when the agent has already finished execution. -var ErrAgentFinished = errors.New("agent has already finished execution") - // ErrAgentNotCancellable is returned by Cancel when the agent does not support cancellation. var ErrAgentNotCancellable = errors.New("agent does not implement CancellableAgent interface") @@ -315,7 +312,7 @@ func WithCancelTimeout(timeout time.Duration) CancelOption { } } -type CancelFunc func(context.Context, ...CancelOption) error +type CancelFunc func(...CancelOption) error type CancellableAgent interface { Agent diff --git a/adk/turn_loop.go b/adk/turn_loop.go index af69ae5d3..94b39fe6c 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -192,11 +192,6 @@ func getTurnLoopCancelSig(ctx context.Context) *turnLoopCancelSig { return nil } -// TurnLoopCancelFunc is the cancel function returned by WithCancel. -// Unlike Agent's CancelFunc, it does not require a context parameter -// since the context is already bound when WithCancel is called. -type TurnLoopCancelFunc func(opts ...CancelOption) error - // ErrAgentNotCancellableInTurnLoop is returned when WithCancel context is used // but the Agent does not implement CancellableAgent. var ErrAgentNotCancellableInTurnLoop = errors.New("agent does not support cancel but WithCancel context was provided") @@ -258,7 +253,7 @@ var ErrLoopExit = errors.New("loop exit") // }() // // Later, to cancel: // cancel(adk.WithCancelMode(adk.CancelAfterToolCall)) -func (l *TurnLoop[T]) WithCancel(ctx context.Context) (context.Context, TurnLoopCancelFunc) { +func (l *TurnLoop[T]) WithCancel(ctx context.Context) (context.Context, CancelFunc) { cs := newTurnLoopCancelSig() ctx = withTurnLoopCancelSig(ctx, cs) @@ -270,16 +265,9 @@ func (l *TurnLoop[T]) WithCancel(ctx context.Context) (context.Context, TurnLoop for _, opt := range opts { opt(cfg) } - - cancelled := false once.Do(func() { cs.cancel(cfg) - cancelled = true }) - - if !cancelled { - return ErrAgentFinished - } return nil } @@ -453,8 +441,8 @@ func (l *TurnLoop[T]) Run(ctx context.Context, opts ...TurnLoopRunOption[T]) err }(): externalCancelled = true cfg := cs.getConfig() - err := cancelFunc(nCtx, cancelConfigToOpts(cfg)...) - if err != nil && !errors.Is(err, ErrAgentFinished) { + err := cancelFunc(cancelConfigToOpts(cfg)...) + if err != nil { <-done return fmt.Errorf("failed to cancel agent: %w", err) } @@ -476,7 +464,7 @@ func (l *TurnLoop[T]) Run(ctx context.Context, opts ...TurnLoopRunOption[T]) err o := applyConsumeOptions(option) switch o.Mode { case ConsumePreemptive: - err := cancelFunc(nCtx, o.CancelOpts...) + err := cancelFunc(o.CancelOpts...) if err != nil { <-done return fmt.Errorf("failed to cancel agent: %w", err) @@ -485,7 +473,7 @@ func (l *TurnLoop[T]) Run(ctx context.Context, opts ...TurnLoopRunOption[T]) err select { case <-done: case <-time.After(o.Timeout): - err := cancelFunc(nCtx, o.CancelOpts...) + err := cancelFunc(o.CancelOpts...) if err != nil { <-done return fmt.Errorf("failed to cancel agent: %w", err) @@ -497,8 +485,8 @@ func (l *TurnLoop[T]) Run(ctx context.Context, opts ...TurnLoopRunOption[T]) err return nil }(): cfg := cs.getConfig() - err := cancelFunc(nCtx, cancelConfigToOpts(cfg)...) - if err != nil && !errors.Is(err, ErrAgentFinished) { + err := cancelFunc(cancelConfigToOpts(cfg)...) + if err != nil { <-done return fmt.Errorf("failed to cancel agent: %w", err) } diff --git a/adk/turn_loop_test.go b/adk/turn_loop_test.go index 7c22f3699..88e3679b3 100644 --- a/adk/turn_loop_test.go +++ b/adk/turn_loop_test.go @@ -112,7 +112,7 @@ func (a *turnLoopCancellableAgent) RunWithCancel(_ context.Context, _ *AgentInpu defer gen.Close() <-a.cancelCh }() - cancelFunc := func(_ context.Context, opts ...CancelOption) error { + cancelFunc := func(opts ...CancelOption) error { atomic.StoreInt32(&a.cancelled, 1) a.cancelledOpt = opts close(a.cancelCh) @@ -747,11 +747,8 @@ func TestTurnLoop_WithCancel_MultipleCalls(t *testing.T) { _, cancel := loop.WithCancel(context.Background()) - err1 := cancel() - err2 := cancel() - - assert.NoError(t, err1) - assert.ErrorIs(t, err2, ErrAgentFinished) + err = cancel() + assert.NoError(t, err) } func TestTurnLoop_WithCancel_IndependentCancels(t *testing.T) { From bb3f4cbc777358c658e611fddf2bbf2df93e30dd Mon Sep 17 00:00:00 2001 From: Megumin Date: Tue, 24 Feb 2026 21:07:31 +0800 Subject: [PATCH 42/65] fix(adk): select cancel after front without preemptive (#802) --- adk/turn_loop.go | 65 ++++++++++++++++++++++++++---------------------- 1 file changed, 35 insertions(+), 30 deletions(-) diff --git a/adk/turn_loop.go b/adk/turn_loop.go index 94b39fe6c..f7ec423c5 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -179,6 +179,13 @@ func (cs *turnLoopCancelSig) getConfig() *cancelConfig { return nil } +func (cs *turnLoopCancelSig) getDoneChan() <-chan struct{} { + if cs != nil { + return cs.done + } + return nil +} + type turnLoopCancelSigKey struct{} func withTurnLoopCancelSig(ctx context.Context, cs *turnLoopCancelSig) context.Context { @@ -395,8 +402,8 @@ func (l *TurnLoop[T]) Run(ctx context.Context, opts ...TurnLoopRunOption[T]) err return l.handleEvents(ctx, item, iter, checkPointID) } - var handleEventErr error if cancelFunc != nil { + var handleEventErr error done := make(chan struct{}) go func() { @@ -429,27 +436,14 @@ func (l *TurnLoop[T]) Run(ctx context.Context, opts ...TurnLoopRunOption[T]) err }) }() - var externalCancelled bool select { case <-frontDone: case <-done: - case <-func() <-chan struct{} { - if cs != nil { - return cs.done - } - return nil - }(): - externalCancelled = true - cfg := cs.getConfig() - err := cancelFunc(cancelConfigToOpts(cfg)...) + case <-cs.getDoneChan(): + err := cancelAndWait(cancelFunc, cs, done) if err != nil { - <-done - return fmt.Errorf("failed to cancel agent: %w", err) + return err } - } - - if externalCancelled { - <-done return l.wrapHandleEventErr(handleEventErr) } @@ -478,29 +472,29 @@ func (l *TurnLoop[T]) Run(ctx context.Context, opts ...TurnLoopRunOption[T]) err <-done return fmt.Errorf("failed to cancel agent: %w", err) } - case <-func() <-chan struct{} { - if cs != nil { - return cs.done - } - return nil - }(): - cfg := cs.getConfig() - err := cancelFunc(cancelConfigToOpts(cfg)...) + case <-cs.getDoneChan(): + err := cancelAndWait(cancelFunc, cs, done) if err != nil { - <-done - return fmt.Errorf("failed to cancel agent: %w", err) + return err } - <-done return l.wrapHandleEventErr(handleEventErr) } } - <-done + select { + case <-done: + case <-cs.getDoneChan(): + err := cancelAndWait(cancelFunc, cs, done) + if err != nil { + return err + } + return l.wrapHandleEventErr(handleEventErr) + } if err := l.wrapHandleEventErr(handleEventErr); err != nil { return err } } else { - if handleEventErr = handleEvents(); handleEventErr != nil { + if handleEventErr := handleEvents(); handleEventErr != nil { if err := l.wrapHandleEventErr(handleEventErr); err != nil { return err } @@ -545,6 +539,17 @@ func (l *TurnLoop[T]) handleEvents(ctx context.Context, item T, iter *AsyncItera return nil } +func cancelAndWait(cf CancelFunc, cs *turnLoopCancelSig, done chan struct{}) error { + cfg := cs.getConfig() + err := cf(cancelConfigToOpts(cfg)...) + if err != nil { + <-done + return fmt.Errorf("failed to cancel agent: %w", err) + } + <-done + return nil +} + func cancelConfigToOpts(cfg *cancelConfig) []CancelOption { if cfg == nil { return nil From 2cd1c25787db93d4ddcca961ff045ec9a082cacc Mon Sep 17 00:00:00 2001 From: IPender Date: Mon, 2 Mar 2026 21:43:51 +0800 Subject: [PATCH 43/65] fix: implement IsCallbacksEnabled and GetType for cancelableChatModel (#826) --- adk/cancel_wrapper.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/adk/cancel_wrapper.go b/adk/cancel_wrapper.go index 9c00df77f..396d6a458 100644 --- a/adk/cancel_wrapper.go +++ b/adk/cancel_wrapper.go @@ -19,11 +19,14 @@ package adk import ( "context" "io" + "reflect" "runtime/debug" "time" + "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/internal/generic" "github.com/cloudwego/eino/internal/safe" "github.com/cloudwego/eino/schema" ) @@ -118,6 +121,18 @@ func (c *cancelableChatModel) Stream(ctx context.Context, input []*schema.Messag return res.result, res.err } +func (c *cancelableChatModel) IsCallbacksEnabled() bool { + return components.IsCallbacksEnabled(c.inner) +} + +func (c *cancelableChatModel) GetType() string { + if name, ok := components.GetType(c.inner); ok { + return name + } + + return generic.ParseTypeName(reflect.ValueOf(c.inner)) +} + func cancelableToolInvokable(cs *cancelSig, endpoint compose.InvokableToolEndpoint) compose.InvokableToolEndpoint { return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { if cfg := checkCancelSig(cs); cfg != nil && cfg.Mode == CancelImmediate { From 14044c22367c3e30e0674ef59d002124e7a3be54 Mon Sep 17 00:00:00 2001 From: "shentong.martin" Date: Mon, 23 Mar 2026 10:39:45 +0800 Subject: [PATCH 44/65] fix: rebase error Change-Id: If20fa78dba82a1c177c8ec47090050ea8c1354ed --- adk/chatmodel.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/adk/chatmodel.go b/adk/chatmodel.go index 10faa7434..0848a859a 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -1119,7 +1119,7 @@ func (a *ChatModelAgent) resumeInternal(ctx context.Context, info *ResumeInfo, w generator.Send(&AgentEvent{Err: err}) generator.Close() }() - return iterator, nil + return iterator, notCancellableFuncInternal } var historyModifier func(ctx context.Context, history []Message) []Message From 1230732d043eefd6ef62e9b796b202768523733c Mon Sep 17 00:00:00 2001 From: shentongmartin Date: Thu, 26 Mar 2026 19:44:26 +0800 Subject: [PATCH 45/65] refactor(adk): replace TurnLoop with push-based API (#835) --- .gitignore | 5 +- adk/agent_tool.go | 21 +- adk/call_option.go | 28 + adk/cancel.go | 885 ++++ adk/cancel_edge_test.go | 1268 +++++ adk/cancel_multicall_test.go | 125 + adk/cancel_test.go | 2885 ++++++++++- adk/cancel_wrapper.go | 295 -- adk/chatmodel.go | 373 +- adk/chatmodel_retry_test.go | 145 + adk/flow.go | 114 +- adk/handler.go | 20 + adk/handler_test.go | 318 ++ adk/interface.go | 54 - adk/interrupt.go | 39 +- .../patchtoolcalls/patchtoolcalls.go | 2 +- adk/prebuilt/planexecute/plan_execute_test.go | 232 + adk/react.go | 226 +- adk/react_test.go | 37 +- adk/retry_chatmodel.go | 9 + adk/runctx_test.go | 209 + adk/runner.go | 127 +- adk/turn_loop.go | 1657 +++++-- adk/turn_loop_test.go | 4202 +++++++++++++---- adk/utils.go | 5 +- adk/workflow.go | 88 +- adk/wrappers.go | 38 +- compose/checkpoint_test.go | 28 + compose/graph_manager.go | 9 +- compose/graph_run.go | 82 +- compose/interrupt.go | 4 + examples | 1 + ext | 1 + internal/channel.go | 45 + internal/channel_test.go | 241 + internal/core/address.go | 25 +- internal/core/interrupt.go | 7 + schema/serialization.go | 2 +- 38 files changed, 11572 insertions(+), 2280 deletions(-) create mode 100644 adk/cancel.go create mode 100644 adk/cancel_edge_test.go create mode 100644 adk/cancel_multicall_test.go delete mode 100644 adk/cancel_wrapper.go create mode 160000 examples create mode 160000 ext diff --git a/.gitignore b/.gitignore index 8ac1d568d..04542d49a 100644 --- a/.gitignore +++ b/.gitignore @@ -50,8 +50,11 @@ reports/ /todos .DS_Store -*.log +*.log* +.claude CLAUDE.md +*.jsonl +*.txt # Specs directories */specs diff --git a/adk/agent_tool.go b/adk/agent_tool.go index 9472dab1f..fde319cb4 100644 --- a/adk/agent_tool.go +++ b/adk/agent_tool.go @@ -167,16 +167,19 @@ func (at *agentTool) InvokableRun(ctx context.Context, argumentsInJSON string, o } iter = newInvokableAgentToolRunner(at.agent, ms, enableStreaming).Run(ctx, input, - append(getOptionsByAgentName(at.agent.Name(ctx), opts), WithCheckPointID(bridgeCheckpointID), withSharedParentSession())...) + append(extractAndDeriveCancelCtx(ctx, at.agent.Name(ctx), opts), WithCheckPointID(bridgeCheckpointID), withSharedParentSession())...) } else { if !hasState { return "", fmt.Errorf("agent tool '%s' interrupt has happened, but cannot find interrupt state", at.agent.Name(ctx)) } - ms = newResumeBridgeStore(state) + ms = newResumeBridgeStore(bridgeCheckpointID, state) + + agentOpts := extractAndDeriveCancelCtx(ctx, at.agent.Name(ctx), opts) + agentOpts = append(agentOpts, withSharedParentSession()) iter, err = newInvokableAgentToolRunner(at.agent, ms, enableStreaming). - Resume(ctx, bridgeCheckpointID, append(getOptionsByAgentName(at.agent.Name(ctx), opts), withSharedParentSession())...) + Resume(ctx, bridgeCheckpointID, agentOpts...) if err != nil { return "", err } @@ -281,6 +284,18 @@ func getOptionsByAgentName(agentName string, opts []tool.Option) []AgentRunOptio return ret } +func extractAndDeriveCancelCtx(ctx context.Context, agentName string, opts []tool.Option) []AgentRunOption { + agentOpts := getOptionsByAgentName(agentName, opts) + baseOpts := getCommonOptions(nil, agentOpts...) + if baseOpts.cancelCtx != nil { + childCtx := baseOpts.cancelCtx.deriveChild(ctx) + agentOpts = append(agentOpts, WrapImplSpecificOptFn(func(o *options) { + o.cancelCtx = childCtx + })) + } + return agentOpts +} + func getEmitGeneratorAndEnableStreaming(opts []tool.Option) (*AsyncGenerator[*AgentEvent], bool) { o := tool.GetImplSpecificOptions[agentToolOptions](nil, opts...) if o == nil { diff --git a/adk/call_option.go b/adk/call_option.go index 55e57fd32..ead6ae636 100644 --- a/adk/call_option.go +++ b/adk/call_option.go @@ -24,6 +24,7 @@ type options struct { checkPointID *string skipTransferMessages bool handlers []callbacks.Handler + cancelCtx *cancelContext } // AgentRunOption is the call option for adk Agent. @@ -157,6 +158,33 @@ func filterCallbackHandlersForNestedAgents(currentAgentName string, opts []Agent return filteredOpts } +// filterCancelOption removes any AgentRunOption that sets a cancelCtx on *options. +// This prevents inner (nested) agents from receiving the cancel option when the +// outer flowAgent owns the cancel lifecycle. Inner agents access the cancelContext +// via the Go context (getCancelContext) instead. +func filterCancelOption(opts []AgentRunOption) []AgentRunOption { + if len(opts) == 0 { + return nil + } + var filteredOpts []AgentRunOption + for i := range opts { + opt := opts[i] + if opt.implSpecificOptFn == nil { + filteredOpts = append(filteredOpts, opt) + continue + } + if _, isCommonOpt := opt.implSpecificOptFn.(func(*options)); isCommonOpt { + testOpt := &options{} + opt.implSpecificOptFn.(func(*options))(testOpt) + if testOpt.cancelCtx != nil { + continue + } + } + filteredOpts = append(filteredOpts, opt) + } + return filteredOpts +} + func filterOptions(agentName string, opts []AgentRunOption) []AgentRunOption { if len(opts) == 0 { return nil diff --git a/adk/cancel.go b/adk/cancel.go new file mode 100644 index 000000000..20d72bb20 --- /dev/null +++ b/adk/cancel.go @@ -0,0 +1,885 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "fmt" + "io" + "sync" + "sync/atomic" + "time" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +func init() { + schema.RegisterName[*CancelError]("_eino_adk_cancel_error") + schema.RegisterName[*AgentCancelInfo]("_eino_adk_agent_cancel_info") + schema.RegisterName[*StreamCanceledError]("_eino_adk_stream_cancelled_error") +} + +// CancelMode specifies when an agent should be canceled. +// Modes can be combined with bitwise OR to cancel at multiple safe-points. +// For example, CancelAfterChatModel | CancelAfterToolCalls cancels the agent +// after whichever safe-point is reached first. +type CancelMode int + +const ( + // CancelImmediate cancels the agent as soon as the signal is received, + // without waiting for a ChatModel or ToolCalls safe-point. Propagates + // to all descendant agents via the cancel context hierarchy, including + // agents nested inside AgentTools and workflow sub-agents. + CancelImmediate CancelMode = 0 + // CancelAfterChatModel cancels after the first chat model call that completes + // anywhere in the agent hierarchy, including nested sub-agents, agent tools, + // and workflow branches. The cancel mode propagates to all descendant agents; + // whichever ChatModel finishes first triggers the cancel. The interrupting + // agent emits an interrupt that bubbles up through the agent tree — parent + // agents do not need to reach their own ChatModel safe-point. + CancelAfterChatModel CancelMode = 1 << iota + // CancelAfterToolCalls cancels after the first set of concurrent tool calls + // that completes anywhere in the agent hierarchy. Like CancelAfterChatModel, + // this mode propagates to all descendants and fires at whichever level + // reaches the safe-point first. + CancelAfterToolCalls +) + +// CancelHandle represents a cancel operation that can be waited on. +type CancelHandle struct { + wait func() error +} + +func (h *CancelHandle) Wait() error { + return h.wait() +} + +// AgentCancelFunc is called to request cancellation of a running agent. +// It returns after the cancel request is committed; use the returned handle's +// Wait to block for completion and outcome. +// +// The returned bool reports whether this call contributed to the CancelError +// for the current execution. "Contributed" means this call's cancel options +// were included before cancellation was finalized. It is false when cancellation +// was already finalized (handled or execution completed). +type AgentCancelFunc func(...AgentCancelOption) (*CancelHandle, bool) + +type agentCancelConfig struct { + Mode CancelMode + Timeout *time.Duration +} + +// AgentCancelOption configures cancel behavior. +type AgentCancelOption func(*agentCancelConfig) + +// WithAgentCancelMode sets the cancel mode for the agent cancel operation. +func WithAgentCancelMode(mode CancelMode) AgentCancelOption { + return func(config *agentCancelConfig) { + config.Mode = mode + } +} + +// WithAgentCancelTimeout sets a timeout for the cancel operation. +// This only applies to safe-point modes (CancelAfterChatModel, CancelAfterToolCalls): +// if the safe-point hasn't fired within this duration, the cancel escalates to +// an immediate graph interrupt. +// For CancelImmediate this timeout is ignored — the graph interrupt fires +// immediately with timeout=0. +func WithAgentCancelTimeout(timeout time.Duration) AgentCancelOption { + return func(config *agentCancelConfig) { + config.Timeout = &timeout + } +} + +// AgentCancelInfo contains information about a cancel operation. +type AgentCancelInfo struct { + Mode CancelMode + Escalated bool + Timeout bool +} + +// CancelError is sent via AgentEvent.Err when an agent is canceled. +// Use errors.As to match and extract *CancelError from event errors. +// +// Interrupt absorption: when a cancel is active (shouldCancel() == true), ANY +// interrupt — whether from a cancel safe-point node or from business logic +// (e.g. compose.Interrupt in a tool) — is converted to a CancelError. The +// cancel "absorbs" the business interrupt. This is intentional: +// +// - In concurrent execution (parallel workflows, concurrent tool calls), +// cancel-induced and business interrupts can arrive as a single composite +// signal that cannot be split apart. +// - Even in sequential execution, treating business interrupts as CancelError +// during active cancel gives consistent semantics. +// - The business interrupt is NOT lost — the checkpoint preserves the full +// interrupt hierarchy. On resume (Runner.Resume), the agent re-executes +// the interrupting code path and the business interrupt re-fires naturally. +type CancelError struct { + Info *AgentCancelInfo + + // CheckPointID is the checkpoint ID associated with this cancel operation. + // When non-empty, the cancelled agent's state has been persisted under this ID + // and can be resumed via Runner.Resume or GenInputResult.ResumeFromCheckpointID. + CheckPointID string + + // InterruptContexts provides the interrupt contexts needed for targeted + // resumption via Runner.ResumeWithParams. Each context represents a step + // in the agent hierarchy that was interrupted. This is a slice because + // composite agents (e.g. parallel workflows) may interrupt at multiple + // points simultaneously, matching the shape of AgentAction.Interrupted.InterruptContexts. + // Use each InterruptCtx.ID as a key in ResumeParams.Targets. + InterruptContexts []*InterruptCtx + + interruptSignal *InterruptSignal // unexported — only Runner needs it for checkpoint +} + +func (e *CancelError) Error() string { + return fmt.Sprintf("agent canceled: mode=%v, escalated=%v", e.Info.Mode, e.Info.Escalated) +} + +// Sentinel errors for cancel outcomes. +var ( + // ErrCancelTimeout is returned by CancelHandle.Wait when the cancel operation timed out. + ErrCancelTimeout = errors.New("cancel timed out") + + // ErrExecutionCompleted is returned by CancelHandle.Wait when the agent finished + // before the cancel took effect. "Finished" means the event stream was fully + // drained without any interrupt — normal completion or a fatal error. + // + // Note: business interrupts that occur while cancel is active are absorbed + // into CancelError (see CancelError doc), so they result in nil (cancel + // succeeded), NOT ErrExecutionCompleted. Only execution that completes with + // no interrupt at all produces this error. + ErrExecutionCompleted = errors.New("execution already completed") + + // ErrStreamCanceled is the error sent through the stream when CancelImmediate aborts it. + // It is a *StreamCanceledError so it can be gob-serialized during checkpoint save + // (when stored as agentEventWrapper.StreamErr). + ErrStreamCanceled error = &StreamCanceledError{} +) + +// StreamCanceledError is the concrete error type for ErrStreamCanceled. +// It is exported so that gob can serialize it during checkpoint save when the error +// is stored in agentEventWrapper.StreamErr. +type StreamCanceledError struct{} + +func (e *StreamCanceledError) Error() string { + return "stream canceled" +} + +// WithCancel creates an AgentRunOption that enables cancellation for an agent run. +// It returns the option to pass to Run/Resume and a cancel function. +// Cancel options (mode, timeout) are passed to the returned AgentCancelFunc at call time. +func WithCancel() (AgentRunOption, AgentCancelFunc) { + cc := newCancelContext() + opt := WrapImplSpecificOptFn(func(o *options) { + o.cancelCtx = cc + }) + cancelFn := cc.buildCancelFunc() + return opt, cancelFn +} + +// cancelContext state constants (for int32 CAS). +// +// State transition rules: +// +// stateRunning -> stateCancelling (cancel requested by AgentCancelFunc) +// stateRunning -> stateDone (execution finished without interrupt) +// stateCancelling -> stateCancelHandled (ANY interrupt absorbed as CancelError) +// stateCancelling -> stateDone (execution finished without interrupt while cancel pending) +// +// Terminal states: stateDone, stateCancelHandled. +// +// Note: We intentionally do NOT distinguish between "completed" and "errored" +// terminal states. End-users get the actual outcome from AgentEvent. +// This simplification keeps the state machine minimal — only the cancel/non-cancel +// distinction matters for the AgentCancelFunc return value. +// +// Business interrupt handling: when cancel is active (stateCancelling) and any +// interrupt arrives — cancel-induced OR business — wrapIterWithCancelCtx absorbs +// it as a CancelError and transitions to stateCancelHandled. The business interrupt +// data is preserved in the checkpoint for re-emission on resume. +const ( + // stateRunning is the initial state: agent is executing normally. + stateRunning int32 = 0 + // stateCancelling means AgentCancelFunc has been called and cancelChan is + // closed, but the cancel has not yet been handled by the runFunc. + stateCancelling int32 = 1 + // stateDone means execution has finished through any non-cancel path: + // normal completion, business interrupt, or error. The specific outcome + // is conveyed through AgentEvent, not through the cancel state machine. + stateDone int32 = 2 + // stateCancelHandled means the cancel was processed by the runFunc and a + // CancelError was emitted through the event stream. This is the success + // terminal state for cancellation. + stateCancelHandled int32 = 5 +) + +// interruptSent constants (for int32 CAS). +// +// Transition rules: +// +// interruptNotSent -> interruptImmediate (CancelImmediate or escalation) +const ( + // interruptNotSent means no compose graph interrupt has been sent. + interruptNotSent int32 = 0 + // interruptImmediate means an immediate graph interrupt was sent with + // timeout=0, forcing the graph to stop as soon as possible. + interruptImmediate int32 = 1 +) + +// defaultCancelImmediateGracePeriod is the time a parent's graph interrupt +// waits when the cancelContext has active children (via deriveChild). This +// gives child agents time to propagate their interrupt signal back through +// the agentTool as a CompositeInterrupt. If this proves insufficient for +// deeply nested structures or too slow for latency-sensitive use cases, +// consider making it configurable via an AgentCancelOption. +const defaultCancelImmediateGracePeriod = 1 * time.Second + +type cancelContextKey struct{} + +// withCancelContext stores a cancelContext in the Go context. +func withCancelContext(ctx context.Context, cc *cancelContext) context.Context { + if cc == nil { + return ctx + } + return context.WithValue(ctx, cancelContextKey{}, cc) +} + +// getCancelContext retrieves the cancelContext from the Go context, or nil. +func getCancelContext(ctx context.Context) *cancelContext { + if v := ctx.Value(cancelContextKey{}); v != nil { + return v.(*cancelContext) + } + return nil +} + +type cancelContext struct { + mode int32 // atomic, CancelMode + + cancelChan chan struct{} // closed when cancel is requested (all modes, not just safe-point) + immediateChan chan struct{} // closed when an immediate graph interrupt fires + doneChan chan struct{} // closed when execution completes (by any mark* method) + doneOnce sync.Once // ensures doneChan is closed exactly once + + state int32 // stateRunning, stateCancelling, stateDone, stateCancelHandled + interruptSent int32 // interruptNotSent, interruptImmediate + escalated int32 // 1 if escalated from safe-point to immediate + timeoutEscalated int32 // 1 if escalation was triggered by timeout + startedMode int32 // atomic, mode when state transitioned to cancelling + deadlineUnixNano int64 // atomic, 0 means no deadline + + root bool // true for the original cancelContext created by WithCancel(); false for derived children + parent *cancelContext // non-nil for derived children; used to decrement parent's activeChildren on markDone + + activeChildren int32 // atomic; number of derived children that haven't called markDone() yet + decrementedParent int32 // atomic CAS guard; ensures parent.activeChildren is decremented at most once + + cancelMu sync.Mutex + timeoutOnce sync.Once + timeoutNotify chan struct{} + + mu sync.Mutex + graphInterruptFuncs []func(...compose.GraphInterruptOption) +} + +func newCancelContext() *cancelContext { + return &cancelContext{ + cancelChan: make(chan struct{}), + immediateChan: make(chan struct{}), + doneChan: make(chan struct{}), + timeoutNotify: make(chan struct{}, 1), + root: true, + } +} + +func (cc *cancelContext) isRoot() bool { + return cc != nil && cc.root +} + +// deriveChild creates a child cancelContext that receives cancel propagation +// from the parent. The caller MUST ensure the child's markDone() is eventually +// called (e.g., via wrapIterWithCancelCtx's defer) or that ctx is canceled; +// otherwise the two propagation goroutines will leak. +func (cc *cancelContext) deriveChild(ctx context.Context) *cancelContext { + if cc == nil { + return nil + } + child := newCancelContext() + child.root = false + child.parent = cc + atomic.AddInt32(&cc.activeChildren, 1) + + go func() { + select { + case <-cc.cancelChan: + child.triggerCancel(cc.getMode()) + case <-child.doneChan: + case <-ctx.Done(): + } + }() + + go func() { + select { + case <-cc.immediateChan: + child.triggerImmediateCancel() + case <-child.doneChan: + case <-ctx.Done(): + } + }() + + return child +} + +func (cc *cancelContext) triggerCancel(mode CancelMode) { + cc.setMode(mode) + if atomic.CompareAndSwapInt32(&cc.state, stateRunning, stateCancelling) { + close(cc.cancelChan) + } +} + +func (cc *cancelContext) triggerImmediateCancel() { + atomic.StoreInt32(&cc.escalated, 1) + cc.setMode(CancelImmediate) + if atomic.CompareAndSwapInt32(&cc.state, stateRunning, stateCancelling) { + close(cc.cancelChan) + } + cc.sendImmediateInterrupt() +} + +func (cc *cancelContext) getMode() CancelMode { + if cc == nil { + return CancelImmediate + } + return CancelMode(atomic.LoadInt32(&cc.mode)) +} + +func (cc *cancelContext) setMode(mode CancelMode) { + atomic.StoreInt32(&cc.mode, int32(mode)) +} + +func (cc *cancelContext) getDeadlineUnixNano() int64 { + return atomic.LoadInt64(&cc.deadlineUnixNano) +} + +func (cc *cancelContext) setDeadlineUnixNano(v int64) { + atomic.StoreInt64(&cc.deadlineUnixNano, v) +} + +func (cc *cancelContext) wakeTimeoutController() { + select { + case cc.timeoutNotify <- struct{}{}: + default: + } +} + +// shouldCancel returns true if a cancel has been requested (cancelChan is closed). +func (cc *cancelContext) shouldCancel() bool { + if cc == nil { + return false + } + select { + case <-cc.cancelChan: + return true + default: + return false + } +} + +// sendImmediateInterrupt sends the compose graph interrupt signal via graphInterruptFuncs. +// Also closes immediateChan (used by cancelMonitoredModel to abort an in-progress stream). +// Returns false if an interrupt was already sent or if no graphInterruptFuncs have been +// registered yet (the deferred fire in setGraphInterruptFunc will handle that case). +func (cc *cancelContext) sendImmediateInterrupt() bool { + cc.mu.Lock() + + if !atomic.CompareAndSwapInt32(&cc.interruptSent, interruptNotSent, interruptImmediate) { + cc.mu.Unlock() + return false + } + + close(cc.immediateChan) + + fns := make([]func(...compose.GraphInterruptOption), len(cc.graphInterruptFuncs)) + copy(fns, cc.graphInterruptFuncs) + + if len(fns) == 0 { + cc.mu.Unlock() + return false + } + + for _, fn := range fns { + fn(compose.WithGraphInterruptTimeout(0)) + } + cc.mu.Unlock() + return true +} + +// setGraphInterruptFunc appends a graph interrupt function to the list. +// If an immediate cancel was already requested, fires it retroactively. +// Multiple functions can be registered (e.g. one per parallel sub-agent). +// +// Both this method and sendImmediateInterrupt hold cc.mu across the entire +// check-and-fire sequence, ensuring each interrupt function is called exactly +// once (compose.WithGraphInterrupt returns a non-idempotent closure that panics +// on double-call). +func (cc *cancelContext) setGraphInterruptFunc(interrupt func(...compose.GraphInterruptOption)) { + cc.mu.Lock() + cc.graphInterruptFuncs = append(cc.graphInterruptFuncs, interrupt) + + shouldFire := atomic.LoadInt32(&cc.interruptSent) == interruptImmediate + if shouldFire { + interrupt(compose.WithGraphInterruptTimeout(0)) + } + cc.mu.Unlock() +} + +// markDone marks the execution as finished through any non-cancel path +// (normal completion, business interrupt, or error). +// This is safe to call even if a cancel is in progress — it allows the +// cancel func to detect that execution finished before cancel took effect. +func (cc *cancelContext) markDone() { + if cc == nil { + return + } + if atomic.CompareAndSwapInt32(&cc.state, stateRunning, stateDone) || + atomic.CompareAndSwapInt32(&cc.state, stateCancelling, stateDone) { + cc.doneOnce.Do(func() { close(cc.doneChan) }) + cc.detachFromParent() + } +} + +func (cc *cancelContext) detachFromParent() { + if cc.parent != nil && atomic.CompareAndSwapInt32(&cc.decrementedParent, 0, 1) { + atomic.AddInt32(&cc.parent.activeChildren, -1) + } +} + +func (cc *cancelContext) hasActiveChildren() bool { + return cc != nil && atomic.LoadInt32(&cc.activeChildren) > 0 +} + +func (cc *cancelContext) wrapGraphInterruptWithGracePeriod(interrupt func(...compose.GraphInterruptOption)) func(...compose.GraphInterruptOption) { + return func(opts ...compose.GraphInterruptOption) { + if cc.hasActiveChildren() { + newOpts := make([]compose.GraphInterruptOption, len(opts)+1) + copy(newOpts, opts) + newOpts[len(opts)] = compose.WithGraphInterruptTimeout(defaultCancelImmediateGracePeriod) + opts = newOpts + } + interrupt(opts...) + } +} + +// markCancelHandled signals that the cancel path in the runFunc has created +// and sent a CancelError. Transitions state to stateCancelHandled so that: +// 1. The deferred markDone() becomes a no-op (CAS from cancelling fails). +// 2. buildCancelFunc sees stateCancelHandled and returns nil (cancel succeeded). +// Returns true if the transition succeeded, false if cancel was already handled +// (e.g., by a sub-agent). This prevents duplicate CancelError emission. +func (cc *cancelContext) markCancelHandled() bool { + if cc == nil { + return false + } + if atomic.CompareAndSwapInt32(&cc.state, stateCancelling, stateCancelHandled) { + cc.doneOnce.Do(func() { close(cc.doneChan) }) + cc.detachFromParent() + return true + } + return false +} + +// createCancelError creates a CancelError based on the current cancel state. +func (cc *cancelContext) createCancelError() *CancelError { + info := &AgentCancelInfo{} + info.Mode = cc.getMode() + if atomic.LoadInt32(&cc.escalated) == 1 { + info.Escalated = true + info.Timeout = atomic.LoadInt32(&cc.timeoutEscalated) == 1 + } + return &CancelError{ + Info: info, + } +} + +func (cc *cancelContext) createAndMarkCancelHandled() (*CancelError, bool) { + cc.cancelMu.Lock() + defer cc.cancelMu.Unlock() + cancelErr := cc.createCancelError() + ok := cc.markCancelHandled() + return cancelErr, ok +} + +// buildCancelFunc builds the AgentCancelFunc for external use. +func (cc *cancelContext) buildCancelFunc() AgentCancelFunc { + joinMode := func(a, b CancelMode) CancelMode { + if a == CancelImmediate || b == CancelImmediate { + return CancelImmediate + } + return a | b + } + + parseReq := func(callOpts ...AgentCancelOption) *agentCancelConfig { + cfg := &agentCancelConfig{Mode: CancelImmediate} + for _, opt := range callOpts { + opt(cfg) + } + return cfg + } + + startTimeoutController := func() { + cc.timeoutOnce.Do(func() { + go func() { + for { + select { + case <-cc.doneChan: + return + default: + } + + mode := cc.getMode() + if mode == CancelImmediate { + return + } + + deadline := cc.getDeadlineUnixNano() + if deadline == 0 { + select { + case <-cc.timeoutNotify: + continue + case <-cc.doneChan: + return + } + } + + now := time.Now().UnixNano() + wait := time.Duration(deadline - now) + if wait <= 0 { + atomic.StoreInt32(&cc.escalated, 1) + atomic.StoreInt32(&cc.timeoutEscalated, 1) + cc.sendImmediateInterrupt() + return + } + + timer := time.NewTimer(wait) + select { + case <-timer.C: + timer.Stop() + atomic.StoreInt32(&cc.escalated, 1) + atomic.StoreInt32(&cc.timeoutEscalated, 1) + cc.sendImmediateInterrupt() + return + case <-cc.timeoutNotify: + timer.Stop() + continue + case <-cc.doneChan: + timer.Stop() + return + } + } + }() + }) + } + + newHandle := func(wait func() error) *CancelHandle { + return &CancelHandle{wait: wait} + } + + waitForCompletion := func() error { + <-cc.doneChan + + st := atomic.LoadInt32(&cc.state) + switch st { + case stateDone: + return ErrExecutionCompleted + default: + if atomic.LoadInt32(&cc.timeoutEscalated) == 1 { + return ErrCancelTimeout + } + return nil + } + } + + return func(callOpts ...AgentCancelOption) (*CancelHandle, bool) { + req := parseReq(callOpts...) + + st := atomic.LoadInt32(&cc.state) + switch st { + case stateCancelHandled: + return newHandle(func() error { return nil }), false + case stateDone: + return newHandle(func() error { return ErrExecutionCompleted }), false + } + + var needImmediate, needTimeoutCtl bool + + cc.cancelMu.Lock() + + st = atomic.LoadInt32(&cc.state) + switch st { + case stateCancelHandled: + cc.cancelMu.Unlock() + return newHandle(func() error { return nil }), false + case stateDone: + cc.cancelMu.Unlock() + return newHandle(func() error { return ErrExecutionCompleted }), false + } + + curMode := cc.getMode() + if st == stateRunning { + if !atomic.CompareAndSwapInt32(&cc.state, stateRunning, stateCancelling) { + st = atomic.LoadInt32(&cc.state) + cc.cancelMu.Unlock() + if st == stateDone { + return newHandle(func() error { return ErrExecutionCompleted }), false + } + return newHandle(waitForCompletion), true + } + + curMode = req.Mode + cc.setMode(curMode) + atomic.StoreInt32(&cc.startedMode, int32(curMode)) + close(cc.cancelChan) + } else { + curMode = joinMode(curMode, req.Mode) + cc.setMode(curMode) + } + + if curMode == CancelImmediate { + cc.setDeadlineUnixNano(0) + needImmediate = true + } else if req.Timeout != nil && *req.Timeout > 0 { + proposed := time.Now().Add(*req.Timeout).UnixNano() + existing := cc.getDeadlineUnixNano() + if existing == 0 || proposed < existing { + cc.setDeadlineUnixNano(proposed) + cc.wakeTimeoutController() + } + needTimeoutCtl = cc.getDeadlineUnixNano() != 0 + } + + cc.cancelMu.Unlock() + + if needImmediate { + if atomic.LoadInt32(&cc.startedMode) != int32(CancelImmediate) { + atomic.StoreInt32(&cc.escalated, 1) + } + cc.sendImmediateInterrupt() + } + if needTimeoutCtl { + startTimeoutController() + } + + return newHandle(waitForCompletion), true + } +} + +// wrapIterWithCancelCtx wraps an iterator with cancel lifecycle management. +// It calls markDone when the inner iterator is fully drained, ensuring the +// cancelContext's doneChan is closed and propagation goroutines can exit. +// +// For root cancelContexts (created by WithCancel, not deriveChild), it also +// converts interrupt ACTION events to CancelError when cancel is active. +// This is the single point of interrupt-to-CancelError conversion in the +// system — Runner.handleIter only enriches the resulting CancelError with +// checkpoint metadata. +// +// Interrupt absorption: ALL interrupts are converted when cancel is active, +// including business interrupts (compose.Interrupt from user code). Cancel and +// business interrupts cannot be reliably distinguished in concurrent execution +// (parallel workflows, concurrent tool calls) where they merge into a single +// composite signal. The business interrupt data is preserved in the checkpoint +// and re-fires naturally on resume. +// +// This conversion MUST happen in this wrapper (not deferred to Runner.handleIter) +// because markDone runs as a defer in this goroutine — if the interrupt event +// were passed through unconverted, markDone would transition stateCancelling→stateDone +// before the Runner goroutine could call createAndMarkCancelHandled, causing it +// to fail the CAS. +func wrapIterWithCancelCtx(iter *AsyncIterator[*AgentEvent], cancelCtx *cancelContext) *AsyncIterator[*AgentEvent] { + if cancelCtx == nil { + return iter + } + it, gen := NewAsyncIteratorPair[*AgentEvent]() + go func() { + defer cancelCtx.markDone() + defer gen.Close() + for { + event, ok := iter.Next() + if !ok { + break + } + + if cancelCtx.isRoot() && event.Action != nil && event.Action.internalInterrupted != nil { + if cancelCtx.shouldCancel() { + cancelErr, ok := cancelCtx.createAndMarkCancelHandled() + if ok { + cancelErr.interruptSignal = event.Action.internalInterrupted + gen.Send(&AgentEvent{Err: cancelErr}) + } + return + } + } + + gen.Send(event) + } + }() + return it +} + +// cancelMonitoredModel wraps a model with cancel monitoring. +// Generate: pure delegate to the inner model (CancelAfterChatModel is handled +// by a dedicated node after the ChatModel in the compose graph). +// Stream: pipes chunks through a goroutine that selects on immediateChan for +// CancelImmediate abort. +type cancelMonitoredModel struct { + inner model.BaseChatModel + cancelContext *cancelContext +} + +type recvResult[T any] struct { + data T + err error +} + +func (m *cancelMonitoredModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + return m.inner.Generate(ctx, input, opts...) +} + +func (m *cancelMonitoredModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + stream, err := m.inner.Stream(ctx, input, opts...) + if err != nil { + return nil, err + } + wrapped := wrapStreamWithCancelMonitoring(stream, m.cancelContext) + return wrapped, nil +} + +// wrapStreamWithCancelMonitoring wraps a stream with cancel monitoring. +// When immediateChan fires (CancelImmediate or timeout escalation), the output +// stream is terminated with ErrStreamCanceled. +func wrapStreamWithCancelMonitoring[T any](stream *schema.StreamReader[T], cc *cancelContext) *schema.StreamReader[T] { + if cc == nil { + return stream + } + + // Already canceled — terminate immediately + select { + case <-cc.immediateChan: + stream.Close() + r, w := schema.Pipe[T](1) + var zero T + w.Send(zero, ErrStreamCanceled) + w.Close() + return r + default: + } + + reader, writer := schema.Pipe[T](1) + + go func() { + done := make(chan struct{}) + defer close(done) + defer writer.Close() + defer stream.Close() + + ch := make(chan recvResult[T]) + go func() { + defer close(ch) + for { + chunk, recvErr := stream.Recv() + select { + case ch <- recvResult[T]{chunk, recvErr}: + case <-done: + return + } + if recvErr != nil { + return + } + } + }() + + for { + select { + case <-cc.immediateChan: + var zero T + writer.Send(zero, ErrStreamCanceled) + return + + case r, ok := <-ch: + if !ok { + return + } + if r.err != nil { + if r.err == io.EOF { + return + } + var zero T + writer.Send(zero, r.err) + return + } + if closed := writer.Send(r.data, nil); closed { + return + } + } + } + }() + + return reader +} + +// cancelMonitoredToolHandler wraps streamable tool calls with cancel monitoring. +// When CancelImmediate fires, the tool output stream is terminated with ErrStreamCanceled. +// This handler reads the cancelContext from the Go context via getCancelContext. +type cancelMonitoredToolHandler struct{} + +func (h *cancelMonitoredToolHandler) WrapStreamableToolCall(next compose.StreamableToolEndpoint) compose.StreamableToolEndpoint { + return func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { + output, err := next(ctx, input) + if err != nil { + return nil, err + } + + cc := getCancelContext(ctx) + if cc == nil { + return output, nil + } + + wrapped := wrapStreamWithCancelMonitoring(output.Result, cc) + return &compose.StreamToolOutput{Result: wrapped}, nil + } +} + +func (h *cancelMonitoredToolHandler) WrapEnhancedStreamableToolCall(next compose.EnhancedStreamableToolEndpoint) compose.EnhancedStreamableToolEndpoint { + return func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { + output, err := next(ctx, input) + if err != nil { + return nil, err + } + + cc := getCancelContext(ctx) + if cc == nil { + return output, nil + } + + wrapped := wrapStreamWithCancelMonitoring(output.Result, cc) + return &compose.EnhancedStreamableToolOutput{Result: wrapped}, nil + } +} diff --git a/adk/cancel_edge_test.go b/adk/cancel_edge_test.go new file mode 100644 index 000000000..d3fb02a1a --- /dev/null +++ b/adk/cancel_edge_test.go @@ -0,0 +1,1268 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +// --- helpers shared across edge-case tests --- + +// blockingChatModel blocks until unblockCh is closed, then returns a fixed response. +type blockingChatModel struct { + unblockCh chan struct{} + response *schema.Message + started chan struct{} + callCount int32 +} + +func newBlockingChatModel(response *schema.Message) *blockingChatModel { + return &blockingChatModel{ + unblockCh: make(chan struct{}), + response: response, + started: make(chan struct{}, 1), + } +} + +func (m *blockingChatModel) Generate(ctx context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m.callCount, 1) + select { + case m.started <- struct{}{}: + default: + } + <-m.unblockCh + return m.response, nil +} + +func (m *blockingChatModel) Stream(ctx context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m.callCount, 1) + select { + case m.started <- struct{}{}: + default: + } + <-m.unblockCh + return schema.StreamReaderFromArray([]*schema.Message{m.response}), nil +} + +func (m *blockingChatModel) BindTools(_ []*schema.ToolInfo) error { return nil } + +// errorChatModel returns an error from Generate/Stream. +type errorChatModel struct { + err error + started chan struct{} +} + +func (m *errorChatModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + if m.started != nil { + select { + case m.started <- struct{}{}: + default: + } + } + return nil, m.err +} + +func (m *errorChatModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, m.err +} + +func (m *errorChatModel) BindTools(_ []*schema.ToolInfo) error { return nil } + +// plainResponseModel returns immediately with a fixed text response (no tool calls). +type plainResponseModel struct { + text string +} + +func (m *plainResponseModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return schema.AssistantMessage(m.text, nil), nil +} + +func (m *plainResponseModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage(m.text, nil)}), nil +} + +func (m *plainResponseModel) BindTools(_ []*schema.ToolInfo) error { return nil } + +// blockingTool blocks until unblockCh is closed. +type blockingTool struct { + name string + unblockCh chan struct{} + started chan struct{} + callCount int32 +} + +func newBlockingTool(name string) *blockingTool { + return &blockingTool{ + name: name, + unblockCh: make(chan struct{}), + started: make(chan struct{}, 4), + } +} + +func (t *blockingTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "blocking tool", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Type: "string"}, + }), + }, nil +} + +func (t *blockingTool) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) { + atomic.AddInt32(&t.callCount, 1) + select { + case t.started <- struct{}{}: + default: + } + <-t.unblockCh + return "result", nil +} + +func toolCallMsg(calls ...schema.ToolCall) *schema.Message { + return &schema.Message{Role: schema.Assistant, ToolCalls: calls} +} + +func toolCall(id, name, args string) schema.ToolCall { + return schema.ToolCall{ID: id, Type: "function", Function: schema.FunctionCall{Name: name, Arguments: args}} +} + +func drainEvents(iter *AsyncIterator[*AgentEvent]) ([]*AgentEvent, bool) { + var events []*AgentEvent + hasCancelError := false + for { + e, ok := iter.Next() + if !ok { + break + } + events = append(events, e) + var ce *CancelError + if e.Err != nil && errors.As(e.Err, &ce) { + hasCancelError = true + } + } + return events, hasCancelError +} + +// --- tests --- + +// TestWithCancel_BeforeExecutionStarts verifies that a cancel issued before +// the graph begins executing still produces a CancelError without invoking +// the model or tools. +func TestWithCancel_BeforeExecutionStarts(t *testing.T) { + ctx := context.Background() + + blk := newBlockingChatModel(toolCallMsg(toolCall("c1", "bt", `{"input":"x"}`))) + bt := newBlockingTool("bt") + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: blk, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{bt}}, + }, + }) + assert.NoError(t, err) + + cancelOpt, cancelFn := WithCancel() + + // Extract the cancelContext so we can wait for cancelChan to close, + // ensuring the cancel is fully registered before Run starts. + cc := getCommonOptions(nil, cancelOpt).cancelCtx + + // Call cancel BEFORE calling agent.Run. + // The cancelFunc must succeed (not hang) even though execution hasn't started. + cancelDone := make(chan error, 1) + go func() { + handle, _ := cancelFn() + cancelDone <- handle.Wait() + }() + + // Wait for cancelChan to close so the pre-execution check in runFunc + // deterministically sees shouldCancel()=true (eliminates goroutine scheduling race). + <-cc.cancelChan + + // Now start the run — it should see shouldCancel()=true and emit CancelError immediately. + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hi")}}, cancelOpt) + + _, hasCancelError := drainEvents(iter) + assert.True(t, hasCancelError, "expected CancelError when cancel precedes execution") + + // cancelFn must have already returned (or return quickly now that doneChan is closed). + select { + case cancelErr := <-cancelDone: + // Either nil (cancel handled) or ErrExecutionCompleted is acceptable + // depending on exact timing; what matters is it didn't hang. + _ = cancelErr + case <-time.After(3 * time.Second): + t.Fatal("cancelFn blocked indefinitely after pre-start cancel") + } + + // Model and tool must not have been invoked. + assert.Equal(t, int32(0), atomic.LoadInt32(&bt.callCount), "tool must not be called") +} + +// TestWithCancel_AfterCompletion verifies cancelFn returns ErrExecutionCompleted +// when called after a normal run finishes. +func TestWithCancel_AfterCompletion(t *testing.T) { + ctx := context.Background() + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: &plainResponseModel{text: "done"}, + }) + require.NoError(t, err) + + cancelOpt, cancelFn := WithCancel() + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hi")}}, cancelOpt) + + // Drain all events so the run completes. + for { + _, ok := iter.Next() + if !ok { + break + } + } + + handle, _ := cancelFn() + cancelErr := handle.Wait() + assert.ErrorIs(t, cancelErr, ErrExecutionCompleted) +} + +// TestWithCancel_AfterBusinessInterrupt verifies cancelFn returns ErrExecutionCompleted +// when called after the agent has been interrupted by business logic. +func TestWithCancel_AfterBusinessInterrupt(t *testing.T) { + ctx := context.Background() + + // Use a model that triggers a compose.Interrupt so the agent stops with an interrupt. + interruptModel := &interruptingChatModel{} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: interruptModel, + }) + require.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hi")}, cancelOpt, WithCheckPointID("biz-interrupt-1")) + + // Drain — expect an interrupt action event, no cancel error. + var gotInterrupt bool + for { + e, ok := iter.Next() + if !ok { + break + } + if e.Action != nil && e.Action.Interrupted != nil { + gotInterrupt = true + } + } + assert.True(t, gotInterrupt, "expected business interrupt event") + + handle, _ := cancelFn() + cancelErr := handle.Wait() + assert.ErrorIs(t, cancelErr, ErrExecutionCompleted) +} + +// TestWithCancel_AfterError verifies cancelFn returns ErrExecutionCompleted +// when called after the agent errors out. +func TestWithCancel_AfterError(t *testing.T) { + ctx := context.Background() + + modelErr := errors.New("model exploded") + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: &errorChatModel{err: modelErr}, + }) + require.NoError(t, err) + + cancelOpt, cancelFn := WithCancel() + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hi")}}, cancelOpt) + + for { + _, ok := iter.Next() + if !ok { + break + } + } + + handle, _ := cancelFn() + cancelErr := handle.Wait() + assert.ErrorIs(t, cancelErr, ErrExecutionCompleted) +} + +// TestWithCancel_TimeoutEscalation tests that WithAgentCancelTimeout causes the +// cancel to escalate to immediate when the safe-point hasn't fired yet, and +// that the resulting CancelError has Escalated=true. +// +// Strategy: use CancelAfterChatModel mode. The model blocks (never completes), +// so the safe-point can't fire naturally. After the timeout, escalateToImmediate +// closes immediateChan which aborts the model stream via cancelMonitoredModel +// and causes a CancelError — no compose graph-interrupt races involved. +func TestWithCancel_TimeoutEscalation(t *testing.T) { + ctx := context.Background() + + blk := newBlockingChatModel(schema.AssistantMessage("hello", nil)) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: blk, + }) + require.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: true, // use streaming so cancelMonitoredModel.Stream is exercised + }) + + timeout := 300 * time.Millisecond + // CancelAfterChatModel + timeout: safe-point can't fire (model never finishes), + // so after 300ms the timeout goroutine escalates to immediate. + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("go")}, cancelOpt) + + select { + case <-blk.started: + case <-time.After(5 * time.Second): + t.Fatal("model did not start") + } + + // Fire cancelFn; it will wait for escalation to complete. + start := time.Now() + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel), WithAgentCancelTimeout(timeout)) + cancelErr := handle.Wait() + elapsed := time.Since(start) + + assert.ErrorIs(t, cancelErr, ErrCancelTimeout, "cancel should return ErrCancelTimeout after timeout escalation") + assert.True(t, elapsed >= timeout, "should wait at least the timeout duration, elapsed=%v", elapsed) + assert.True(t, elapsed < 3*time.Second, "should complete shortly after timeout, elapsed=%v", elapsed) + + var cancelError *CancelError + for { + e, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if e.Err != nil && errors.As(e.Err, &ce) { + cancelError = ce + } + } + if assert.NotNil(t, cancelError, "expected CancelError after timeout escalation") { + assert.True(t, cancelError.Info.Escalated, "CancelError should report Escalated=true") + assert.True(t, cancelError.Info.Timeout, "CancelError should report Timeout=true") + } +} + +// TestWithCancel_AfterChatModel_WithTools verifies CancelAfterChatModel fires +// when the model returns tool calls (the safe-point is on the tool-calls path). +func TestWithCancel_AfterChatModel_WithTools(t *testing.T) { + ctx := context.Background() + + blk := newBlockingChatModel(toolCallMsg(toolCall("c1", "bt", `{"input":"x"}`))) + bt := newBlockingTool("bt") + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: blk, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{bt}}, + }, + }) + require.NoError(t, err) + + cancelOpt, cancelFn := WithCancel() + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hi")}}, cancelOpt) + + select { + case <-blk.started: + case <-time.After(5 * time.Second): + t.Fatal("model did not start") + } + + cancelDone := make(chan error, 1) + go func() { + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + cancelDone <- handle.Wait() + }() + + time.Sleep(20 * time.Millisecond) + + close(blk.unblockCh) + + cancelErr := <-cancelDone + assert.NoError(t, cancelErr) + + _, hasCancelError := drainEvents(iter) + assert.True(t, hasCancelError, "CancelError expected after model returns tool calls") +} + +// TestWithCancel_CancelImmediate_StreamAborted verifies that CancelImmediate +// during model streaming surfaces ErrStreamCanceled and completes quickly. +func TestWithCancel_CancelImmediate_StreamAborted(t *testing.T) { + ctx := context.Background() + + blk := newBlockingChatModel(schema.AssistantMessage("hello", nil)) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: blk, + }) + require.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: true, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hi")}, cancelOpt) + + select { + case <-blk.started: + case <-time.After(5 * time.Second): + t.Fatal("model did not start") + } + time.Sleep(50 * time.Millisecond) + + start := time.Now() + handle, _ := cancelFn() + cancelErr := handle.Wait() + assert.NoError(t, cancelErr) + elapsed := time.Since(start) + assert.True(t, elapsed < 2*time.Second, "cancel should complete quickly, elapsed=%v", elapsed) + + var foundStreamCanceled bool + for { + e, ok := iter.Next() + if !ok { + break + } + if e.Err != nil && errors.Is(e.Err, ErrStreamCanceled) { + foundStreamCanceled = true + } + var ce *CancelError + if e.Err != nil && errors.As(e.Err, &ce) { + foundStreamCanceled = true // CancelError wraps stream abort + } + } + assert.True(t, foundStreamCanceled, "expected stream-abort error during immediate cancel") +} + +// TestWithCancel_MultipleToolsConcurrent verifies that CancelAfterToolCalls +// waits for ALL concurrent tool calls to complete before cancelling. +func TestWithCancel_MultipleToolsConcurrent(t *testing.T) { + ctx := context.Background() + + bt1 := newBlockingTool("tool1") + bt2 := newBlockingTool("tool2") + + // Model calls both tools in one response. + modelResp := toolCallMsg( + toolCall("c1", "tool1", `{"input":"a"}`), + toolCall("c2", "tool2", `{"input":"b"}`), + ) + modelWithTools := &simpleChatModel{response: modelResp} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: modelWithTools, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{bt1, bt2}}, + }, + }) + assert.NoError(t, err) + + cancelOpt, cancelFn := WithCancel() + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("go")}}, cancelOpt) + + // Wait for both tools to start. + for i := 0; i < 2; i++ { + select { + case <-bt1.started: + case <-bt2.started: + case <-time.After(5 * time.Second): + t.Fatal("tools did not start") + } + } + + // Request cancel after tool calls while both are still blocking. + cancelDone := make(chan error, 1) + go func() { + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterToolCalls)) + cancelDone <- handle.Wait() + }() + + // Unblock both tools — cancel should fire only after both complete. + time.Sleep(50 * time.Millisecond) + close(bt1.unblockCh) + time.Sleep(50 * time.Millisecond) + close(bt2.unblockCh) + + cancelErr := <-cancelDone + assert.NoError(t, cancelErr) + + assert.Equal(t, int32(1), atomic.LoadInt32(&bt1.callCount), "tool1 should complete") + assert.Equal(t, int32(1), atomic.LoadInt32(&bt2.callCount), "tool2 should complete") + + _, hasCancelError := drainEvents(iter) + assert.True(t, hasCancelError, "expected CancelError after concurrent tools completed") +} + +// TestWithCancel_GraphInterruptRaceBeforeSet verifies that a CancelImmediate +// issued before setGraphInterruptFunc is called still results in cancellation. +// This exercises the retroactive-fire path in setGraphInterruptFunc. +func TestWithCancel_GraphInterruptRaceBeforeSet(t *testing.T) { + ctx := context.Background() + + blk := newBlockingChatModel(schema.AssistantMessage("hi", nil)) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: blk, + }) + require.NoError(t, err) + + cancelOpt, cancelFn := WithCancel() + + // Cancel immediately before run starts. + go func() { + handle, _ := cancelFn() + _ = handle.Wait() + }() + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hi")}}, cancelOpt) + + done := make(chan struct{}) + go func() { + defer close(done) + drainEvents(iter) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("iteration did not complete after pre-start CancelImmediate") + } +} + +// TestWithCancel_NoCheckpointStore verifies cancel completes and does not panic +// when no checkpoint store is configured. +func TestWithCancel_NoCheckpointStore(t *testing.T) { + ctx := context.Background() + + blk := newBlockingChatModel(schema.AssistantMessage("hi", nil)) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: blk, + }) + require.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + // No CheckPointStore set. + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hi")}, cancelOpt) + + select { + case <-blk.started: + case <-time.After(5 * time.Second): + t.Fatal("model did not start") + } + time.Sleep(30 * time.Millisecond) + + handle, _ := cancelFn() + cancelErr := handle.Wait() + assert.NoError(t, cancelErr) + + var ce *CancelError + for { + e, ok := iter.Next() + if !ok { + break + } + if e.Err != nil && errors.As(e.Err, &ce) { + break + } + } + if assert.NotNil(t, ce, "expected CancelError even without checkpoint store") { + assert.Empty(t, ce.CheckPointID, "CheckPointID should be empty without checkpoint store") + } +} + +// TestWithCancel_ModelError verifies that a model error marks the cancelCtx as +// done so that a subsequent cancelFn call returns ErrExecutionCompleted. +func TestWithCancel_ModelError(t *testing.T) { + ctx := context.Background() + + modelErr := errors.New("model failed") + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: &errorChatModel{err: modelErr}, + }) + require.NoError(t, err) + + cancelOpt, cancelFn := WithCancel() + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hi")}}, cancelOpt) + + var gotModelErr bool + for { + e, ok := iter.Next() + if !ok { + break + } + if e.Err != nil && !errors.As(e.Err, new(*CancelError)) { + gotModelErr = true + } + } + assert.True(t, gotModelErr, "expected non-cancel error event from model failure") + + handle, _ := cancelFn() + cancelErr := handle.Wait() + assert.ErrorIs(t, cancelErr, ErrExecutionCompleted, "cancelFn should return ErrExecutionCompleted after model error") +} + +// TestWithCancel_Resume_SafePoint covers CancelAfterChatModel and +// CancelAfterToolCalls on a Resume path. +func TestWithCancel_Resume_SafePoint(t *testing.T) { + ctx := context.Background() + + // --- phase 1: run to get a checkpoint via CancelImmediate --- + blk := newBlockingChatModel(toolCallMsg(toolCall("c1", "bt", `{"input":"x"}`))) + bt := newSlowTool("bt", 50*time.Millisecond, "result") + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: blk, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{bt}}, + }, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner1 := NewRunner(ctx, RunnerConfig{ + Agent: agent1, + CheckPointStore: store, + }) + + cancelOpt1, cancelFn1 := WithCancel() + iter1 := runner1.Run(ctx, []Message{schema.UserMessage("hi")}, cancelOpt1, WithCheckPointID("resume-sp-1")) + + select { + case <-blk.started: + case <-time.After(5 * time.Second): + t.Fatal("model did not start in phase 1") + } + _, _ = cancelFn1() + drainEvents(iter1) + + // --- phase 2: resume, cancel after chat model --- + resumeModel := newBlockingChatModel(toolCallMsg(toolCall("c1", "bt", `{"input":"x"}`))) + + bt2 := newSlowTool("bt", 50*time.Millisecond, "result") + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: resumeModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{bt2}}, + }, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: agent2, + CheckPointStore: store, + }) + + cancelOpt2, cancelFn2 := WithCancel() + resumeIter, err := runner2.Resume(ctx, "resume-sp-1", cancelOpt2) + require.NoError(t, err) + + select { + case <-resumeModel.started: + case <-time.After(5 * time.Second): + t.Fatal("model did not start in phase 2") + } + + cancelDone := make(chan error, 1) + go func() { + handle, _ := cancelFn2(WithAgentCancelMode(CancelAfterChatModel)) + cancelDone <- handle.Wait() + }() + + time.Sleep(50 * time.Millisecond) + + close(resumeModel.unblockCh) + + cancelErr := <-cancelDone + assert.NoError(t, cancelErr) + + _, hasCancelError := drainEvents(resumeIter) + assert.True(t, hasCancelError, "CancelError expected after resumed model returns tool calls") +} + +// callbackTool is a tool that calls onCall when invoked. +type callbackTool struct { + name string + onCall func() +} + +func (t *callbackTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "callback tool", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Type: "string"}, + }), + }, nil +} + +func (t *callbackTool) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) { + if t.onCall != nil { + t.onCall() + } + return "ok", nil +} + +// interruptingChatModel returns a compose.Interrupt error to simulate a +// business interrupt during execution. +type interruptingChatModel struct{} + +func (m *interruptingChatModel) Generate(ctx context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, compose.Interrupt(ctx, "test interrupt") +} + +func (m *interruptingChatModel) Stream(ctx context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, compose.Interrupt(ctx, "test interrupt") +} + +func (m *interruptingChatModel) BindTools(_ []*schema.ToolInfo) error { return nil } + +// TestWithCancel_TargetedResume_CancelImmediate cancels an agent via CancelImmediate, +// extracts InterruptContexts from the resulting CancelError, and uses them +// for targeted resumption via Runner.ResumeWithParams. +func TestWithCancel_TargetedResume_CancelImmediate(t *testing.T) { + ctx := context.Background() + + blk := newBlockingChatModel(toolCallMsg(toolCall("c1", "st", `{"input":"x"}`))) + st := newSlowTool("st", 50*time.Millisecond, "result") + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: blk, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{st}}, + }, + }) + require.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("go")}, cancelOpt, WithCheckPointID("targeted-imm-1")) + + select { + case <-blk.started: + case <-time.After(5 * time.Second): + t.Fatal("model did not start") + } + + handle, _ := cancelFn() // CancelImmediate (default) + cancelErr := handle.Wait() + assert.NoError(t, cancelErr) + + var cancelError *CancelError + for { + e, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if e.Err != nil && errors.As(e.Err, &ce) { + cancelError = ce + } + } + + require.NotNil(t, cancelError, "expected CancelError") + require.NotEmpty(t, cancelError.InterruptContexts, "CancelError should have InterruptContexts for targeted resume") + + // --- resume with targeted params --- + targets := make(map[string]any) + for _, ic := range cancelError.InterruptContexts { + targets[ic.ID] = nil + } + + resumeModel := &plainResponseModel{text: "resumed"} + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: resumeModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{st}}, + }, + }) + require.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: agent2, + CheckPointStore: store, + }) + + resumeIter, err := runner2.ResumeWithParams(ctx, "targeted-imm-1", &ResumeParams{Targets: targets}) + require.NoError(t, err) + + var gotOutput bool + for { + e, ok := resumeIter.Next() + if !ok { + break + } + if e.Err != nil { + t.Fatalf("unexpected error during targeted resume: %v", e.Err) + } + if e.Output != nil && e.Output.MessageOutput != nil { + gotOutput = true + } + } + assert.True(t, gotOutput, "targeted resume should produce output") +} + +// TestWithCancel_TargetedResume_SafePoint cancels an agent via CancelAfterChatModel +// (safe-point) and verifies that InterruptContexts are populated on the CancelError +// and that targeted resume via ResumeWithParams succeeds. +// Since safe-point cancels now use compose.Interrupt, compose saves checkpoint data, +// making the cancel fully resumable. +func TestWithCancel_TargetedResume_SafePoint(t *testing.T) { + ctx := context.Background() + + // The model returns a tool call so the react graph routes to toolPreHandle, + // which detects CancelAfterChatModel and fires compose.Interrupt. + blk := newBlockingChatModel(toolCallMsg(toolCall("c1", "st", `{"input":"x"}`))) + st := newSlowTool("st", 0, "result") + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: blk, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{st}}, + }, + }) + require.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("go")}, cancelOpt, WithCheckPointID("targeted-sp-1")) + + select { + case <-blk.started: + case <-time.After(5 * time.Second): + t.Fatal("model did not start") + } + + // Start cancelFn in background so the CAS happens before the model unblocks. + cancelDone := make(chan error, 1) + go func() { + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + cancelDone <- handle.Wait() + }() + time.Sleep(50 * time.Millisecond) + close(blk.unblockCh) + + cancelErr := <-cancelDone + assert.NoError(t, cancelErr) + + var cancelError *CancelError + for { + e, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if e.Err != nil && errors.As(e.Err, &ce) { + cancelError = ce + } + } + + require.NotNil(t, cancelError, "expected CancelError") + require.NotEmpty(t, cancelError.InterruptContexts, "CancelError should have InterruptContexts for targeted resume") + + // --- resume with targeted params --- + targets := make(map[string]any) + for _, ic := range cancelError.InterruptContexts { + targets[ic.ID] = nil + } + + resumeModel := &plainResponseModel{text: "resumed"} + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: resumeModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{st}}, + }, + }) + require.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: agent2, + CheckPointStore: store, + }) + + resumeIter, err := runner2.ResumeWithParams(ctx, "targeted-sp-1", &ResumeParams{Targets: targets}) + require.NoError(t, err) + + var gotOutput bool + for { + e, ok := resumeIter.Next() + if !ok { + break + } + if e.Err != nil { + t.Fatalf("unexpected error during targeted resume: %v", e.Err) + } + if e.Output != nil && e.Output.MessageOutput != nil { + gotOutput = true + } + } + assert.True(t, gotOutput, "targeted resume should produce output") +} + +// TestWithCancel_Resume_CancelAfterChatModel_MessagePreserved tests both the +// ReAct (with-tools) and noTools paths to ensure that when a +// CancelAfterChatModel safe-point fires and the run is later resumed, the +// original Message returned by the chat model is preserved through the +// StatefulInterrupt checkpoint. +// +// For the ReAct path: the model returns a tool-call message. On resume the +// cancelCheck node must return that same message so the branch routes to the +// ToolNode and the tool actually executes. +// +// For the noTools path: the model returns a plain text message. On resume the +// cancel-check lambda must return that same message as the chain output. +func TestWithCancel_Resume_CancelAfterChatModel_MessagePreserved(t *testing.T) { + t.Run("react_path_tool_call_preserved", func(t *testing.T) { + ctx := context.Background() + + // Phase-2 model returns no tool calls so the graph ends. + // We track whether the tool actually executes on resume. + toolExecuted := make(chan struct{}, 1) + st := &callbackTool{ + name: "my_tool", + onCall: func() { + select { + case toolExecuted <- struct{}{}: + default: + } + }, + } + + // Phase-1 model returns a tool call. + blk := newBlockingChatModel(toolCallMsg(toolCall("c1", "my_tool", `{"input":"x"}`))) + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: blk, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{st}}, + }, + }) + require.NoError(t, err) + + store := newCancelTestStore() + runner1 := NewRunner(ctx, RunnerConfig{ + Agent: agent1, + CheckPointStore: store, + }) + + cancelOpt1, cancelFn1 := WithCancel() + iter1 := runner1.Run(ctx, []Message{schema.UserMessage("hi")}, + cancelOpt1, WithCheckPointID("react-msg-preserved-1")) + + select { + case <-blk.started: + case <-time.After(5 * time.Second): + t.Fatal("model did not start in phase 1") + } + + cancelDone := make(chan error, 1) + go func() { + handle, _ := cancelFn1(WithAgentCancelMode(CancelAfterChatModel)) + cancelDone <- handle.Wait() + }() + time.Sleep(50 * time.Millisecond) + close(blk.unblockCh) + + cancelErr := <-cancelDone + assert.NoError(t, cancelErr) + + _, hasCancelError := drainEvents(iter1) + assert.True(t, hasCancelError, "expected CancelError from phase 1") + + // Phase 2: resume. The model for phase-2 returns plain text (no tool + // calls) so the react graph ends after one iteration. But first the + // tool from the checkpoint must execute. + resumeModel := &plainResponseModel{text: "done"} + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: resumeModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{st}}, + }, + }) + require.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: agent2, + CheckPointStore: store, + }) + + resumeIter, err := runner2.Resume(ctx, "react-msg-preserved-1") + require.NoError(t, err) + + for { + e, ok := resumeIter.Next() + if !ok { + break + } + if e.Err != nil { + t.Fatalf("unexpected error during resume: %v", e.Err) + } + } + + // The key assertion: the tool must have been called during resume, + // which can only happen if the tool-call message was preserved. + select { + case <-toolExecuted: + // success + default: + t.Fatal("tool was not executed on resume — the tool-call message was lost") + } + }) + +} + +// TestHandleRunFuncError_AlreadyHandled_NoDuplicate verifies that when +// markCancelHandled() was already claimed by a sub-agent's handleRunFuncError, +// the sequential workflow's checkCancel does not emit a second CancelError. +// +// Setup: sequential[cma1, cma2] with CancelAfterToolCalls. cma1 has tools, +// cancel fires while tool is running. After tool completes, the safe-point +// fires in cma1's handleRunFuncError (claiming markCancelHandled). The +// sequential workflow's checkCancel at the transition point should find +// markCancelHandled returns false and skip — producing exactly 1 CancelError. +func TestHandleRunFuncError_AlreadyHandled_NoDuplicate(t *testing.T) { + ctx := context.Background() + + bt := newBlockingTool("bt") + + // cma1: model returns a tool call immediately, tool blocks until unblocked + cma1Model := newBlockingChatModel(toolCallMsg(toolCall("c1", "bt", `{"input":"x"}`))) + close(cma1Model.unblockCh) // model returns immediately + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent1", Description: "first", Instruction: "test", + Model: cma1Model, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{bt}}, + }, + }) + require.NoError(t, err) + + agent2Model := &plainResponseModel{text: "agent2-response"} + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent2", Description: "second", Instruction: "test", + Model: agent2Model, + }) + require.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq", Description: "sequential", SubAgents: []Agent{agent1, agent2}, + }) + require.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + // Wait for tool to start + select { + case <-bt.started: + case <-time.After(5 * time.Second): + t.Fatal("Tool did not start") + } + + // Cancel while tool is still running (in goroutine because cancelFn blocks + // until execution finishes), then unblock tool so safe-point fires + go func() { + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterToolCalls)) + _ = handle.Wait() + }() + + // Give cancel time to register, then unblock tool + time.Sleep(50 * time.Millisecond) + close(bt.unblockCh) + + cancelCount := 0 + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + cancelCount++ + } + } + + assert.Equal(t, 1, cancelCount, "Should have exactly one CancelError, no duplicate from handleRunFuncError + checkCancel") +} + +func TestWithCancel_CancelAfterChatModel_NestedAgentTool(t *testing.T) { + ctx := context.Background() + + subAgentModel := newBlockingChatModel(toolCallMsg(toolCall("c1", "sub_tool", `{"input":"x"}`))) + subAgentModelStarted := subAgentModel.started + subTool := newBlockingTool("sub_tool") + + subAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "sub_agent", + Description: "test sub agent", + Instruction: "you are a sub agent", + Model: subAgentModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{subTool}}, + }, + }) + require.NoError(t, err) + + supervisorModel := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{{ + ID: "call_1", Type: "function", + Function: schema.FunctionCall{ + Name: TransferToAgentToolName, + Arguments: `{"agent_name": "sub_agent"}`, + }, + }}, + }, + } + + supervisorAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "supervisor", + Description: "supervisor agent (equivalent to DeepAgent)", + Instruction: "you are a supervisor", + Model: supervisorModel, + }) + require.NoError(t, err) + + agentWithSubAgents, err := SetSubAgents(ctx, supervisorAgent, []Agent{subAgent}) + require.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agentWithSubAgents, + EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + select { + case <-subAgentModelStarted: + case <-time.After(10 * time.Second): + t.Fatal("Sub-agent model did not start") + } + + time.Sleep(50 * time.Millisecond) + + cancelDone := make(chan error, 1) + go func() { + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + cancelDone <- handle.Wait() + }() + + time.Sleep(100 * time.Millisecond) + close(subAgentModel.unblockCh) + + cancelErr := <-cancelDone + assert.NoError(t, cancelErr) + + hasCancelError := false + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + hasCancelError = true + } + } + + assert.True(t, hasCancelError, "CancelError expected from nested agent tool with tools") +} diff --git a/adk/cancel_multicall_test.go b/adk/cancel_multicall_test.go new file mode 100644 index 000000000..790d14fb3 --- /dev/null +++ b/adk/cancel_multicall_test.go @@ -0,0 +1,125 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/compose" +) + +func TestAgentCancelFunc_MultiCall_EscalateToImmediate(t *testing.T) { + cc := newCancelContext() + var interruptCalls int32 + cc.setGraphInterruptFunc(func(opts ...compose.GraphInterruptOption) { + atomic.AddInt32(&interruptCalls, 1) + }) + cancelFn := cc.buildCancelFunc() + + handle1, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + handle2, _ := cancelFn(WithAgentCancelMode(CancelImmediate)) + assert.Equal(t, int32(1), atomic.LoadInt32(&interruptCalls)) + + cancelErr := cc.createCancelError() + assert.Equal(t, CancelImmediate, cancelErr.Info.Mode) + assert.True(t, cancelErr.Info.Escalated) + assert.False(t, cancelErr.Info.Timeout) + + assert.True(t, cc.markCancelHandled()) + assert.NoError(t, handle1.Wait()) + assert.NoError(t, handle2.Wait()) +} + +func TestAgentCancelFunc_MultiCall_JoinSafePointModes(t *testing.T) { + cc := newCancelContext() + cancelFn := cc.buildCancelFunc() + + handle1, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + handle2, _ := cancelFn(WithAgentCancelMode(CancelAfterToolCalls)) + + want := CancelAfterChatModel | CancelAfterToolCalls + assert.Equal(t, want, cc.getMode()) + + assert.True(t, cc.markCancelHandled()) + assert.NoError(t, handle1.Wait()) + assert.NoError(t, handle2.Wait()) +} + +func TestAgentCancelFunc_MultiCall_TimeoutDeadlineJoinUsesAbsoluteTime(t *testing.T) { + cc := newCancelContext() + cancelFn := cc.buildCancelFunc() + + handle1, _ := cancelFn( + WithAgentCancelMode(CancelAfterChatModel), + WithAgentCancelTimeout(200*time.Millisecond), + ) + + firstDeadline := cc.getDeadlineUnixNano() + assert.NotZero(t, firstDeadline) + + time.Sleep(50 * time.Millisecond) + + handle2, _ := cancelFn( + WithAgentCancelMode(CancelAfterToolCalls), + WithAgentCancelTimeout(60*time.Millisecond), + ) + + secondDeadline := cc.getDeadlineUnixNano() + assert.NotZero(t, secondDeadline) + assert.Less(t, secondDeadline, firstDeadline) + + assert.True(t, cc.markCancelHandled()) + assert.NoError(t, handle1.Wait()) + assert.NoError(t, handle2.Wait()) +} + +func TestAgentCancelFunc_MultiCall_TimeoutEscalationReturnsErrCancelTimeout(t *testing.T) { + cc := newCancelContext() + var interruptCalls int32 + interruptCh := make(chan struct{}, 1) + cc.setGraphInterruptFunc(func(opts ...compose.GraphInterruptOption) { + atomic.AddInt32(&interruptCalls, 1) + select { + case interruptCh <- struct{}{}: + default: + } + }) + cancelFn := cc.buildCancelFunc() + handle, _ := cancelFn( + WithAgentCancelMode(CancelAfterChatModel), + WithAgentCancelTimeout(30*time.Millisecond), + ) + + select { + case <-interruptCh: + case <-time.After(1 * time.Second): + t.Fatal("timeout escalation did not interrupt") + } + assert.Equal(t, int32(1), atomic.LoadInt32(&interruptCalls)) + + cancelErr := cc.createCancelError() + assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) + assert.True(t, cancelErr.Info.Escalated) + assert.True(t, cancelErr.Info.Timeout) + + assert.True(t, cc.markCancelHandled()) + assert.Equal(t, ErrCancelTimeout, handle.Wait()) +} diff --git a/adk/cancel_test.go b/adk/cancel_test.go index 702efd3c3..0d88db8cc 100644 --- a/adk/cancel_test.go +++ b/adk/cancel_test.go @@ -18,6 +18,10 @@ package adk import ( "context" + "errors" + "fmt" + "io" + "runtime" "sync" "sync/atomic" "testing" @@ -32,18 +36,30 @@ import ( ) type cancelTestChatModel struct { - delay time.Duration + delayNs int64 response *schema.Message startedChan chan struct{} doneChan chan struct{} } +func (m *cancelTestChatModel) getDelay() time.Duration { + return time.Duration(atomic.LoadInt64(&m.delayNs)) +} + +func (m *cancelTestChatModel) setDelay(d time.Duration) { + atomic.StoreInt64(&m.delayNs, int64(d)) +} + func (m *cancelTestChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { select { case m.startedChan <- struct{}{}: default: } - time.Sleep(m.delay) + select { + case <-time.After(m.getDelay()): + case <-ctx.Done(): + return nil, ctx.Err() + } select { case m.doneChan <- struct{}{}: default: @@ -53,7 +69,7 @@ func (m *cancelTestChatModel) Generate(ctx context.Context, input []*schema.Mess func (m *cancelTestChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { m.startedChan <- struct{}{} - time.Sleep(m.delay) + time.Sleep(m.getDelay()) m.doneChan <- struct{}{} return schema.StreamReaderFromArray([]*schema.Message{m.response}), nil } @@ -95,7 +111,11 @@ func (t *slowTool) InvokableRun(ctx context.Context, argumentsInJSON string, opt case t.startedChan <- struct{}{}: default: } - time.Sleep(t.delay) + select { + case <-time.After(t.delay): + case <-ctx.Done(): + return "", ctx.Err() + } return t.result, nil } @@ -122,22 +142,20 @@ func (s *cancelTestStore) Get(_ context.Context, key string) ([]byte, bool, erro return v, ok, nil } -func TestCancelSig(t *testing.T) { - t.Run("BasicCancelSignal", func(t *testing.T) { - cs := newCancelSig() +func TestCancelContext(t *testing.T) { + t.Run("BasicCancelContext", func(t *testing.T) { + cc := newCancelContext() + assert.False(t, cc.shouldCancel(), "Should not be cancelled initially") - cfg := checkCancelSig(cs) - assert.Nil(t, cfg, "Should not be cancelled initially") + cc.setMode(CancelImmediate) + close(cc.cancelChan) - cs.cancel(&cancelConfig{Mode: CancelImmediate}) - - cfg = checkCancelSig(cs) - assert.NotNil(t, cfg, "Should be cancelled after cancel()") - assert.Equal(t, CancelImmediate, cfg.Mode) + assert.True(t, cc.shouldCancel(), "Should be cancelled after close(cancelChan)") + assert.Equal(t, CancelImmediate, cc.getMode()) }) } -func TestRunWithCancel_WithTools(t *testing.T) { +func TestWithCancel_WithTools(t *testing.T) { ctx := context.Background() t.Run("CancelImmediate_DuringModelCall", func(t *testing.T) { @@ -145,7 +163,7 @@ func TestRunWithCancel_WithTools(t *testing.T) { st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") slowModel := &cancelTestChatModel{ - delay: 2 * time.Second, + delayNs: int64(2 * time.Second), response: &schema.Message{ Role: schema.Assistant, Content: "", @@ -182,7 +200,8 @@ func TestRunWithCancel_WithTools(t *testing.T) { EnableStreaming: false, }) - iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}) + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt) assert.NotNil(t, iter) assert.NotNil(t, cancelFn) @@ -207,24 +226,27 @@ func TestRunWithCancel_WithTools(t *testing.T) { time.Sleep(100 * time.Millisecond) - err = cancelFn() + handle, _ := cancelFn() + err = handle.Wait() assert.NoError(t, err) - start := time.Now() - events := <-eventsCh - elapsed := time.Since(start) + var events []*AgentEvent + select { + case events = <-eventsCh: + case <-time.After(5 * time.Second): + t.Fatal("Timed out waiting for events") + } - assert.True(t, elapsed < 1*time.Second, "Should return quickly after cancel, elapsed: %v", elapsed) assert.True(t, len(events) > 0) - hasInterrupted := false + hasCancelError := false for _, e := range events { - assert.Nil(t, e.Err, "Should not have error event after cancel") - if e.Action != nil && e.Action.Interrupted != nil { - hasInterrupted = true + var cancelErr *CancelError + if e.Err != nil && errors.As(e.Err, &cancelErr) { + hasCancelError = true } } - assert.True(t, hasInterrupted, "Should have interrupted event after cancel") + assert.True(t, hasCancelError, "Should have CancelError event after cancel") }) t.Run("CancelAfterChatModel_DuringToolCall", func(t *testing.T) { @@ -237,6 +259,7 @@ func TestRunWithCancel_WithTools(t *testing.T) { } modelWithToolCall := &simpleChatModel{ + delay: 1 * time.Second, response: &schema.Message{ Role: schema.Assistant, Content: "", @@ -266,9 +289,10 @@ func TestRunWithCancel_WithTools(t *testing.T) { }) assert.NoError(t, err) - iter, cancelFn := agent.RunWithCancel(ctx, &AgentInput{ + cancelOpt, cancelFn := WithCancel() + iter := agent.Run(ctx, &AgentInput{ Messages: []Message{schema.UserMessage("Use the tool")}, - }) + }, cancelOpt) assert.NotNil(t, iter) assert.NotNil(t, cancelFn) @@ -276,7 +300,8 @@ func TestRunWithCancel_WithTools(t *testing.T) { time.Sleep(100 * time.Millisecond) - err = cancelFn(WithCancelMode(CancelAfterChatModel)) + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + err = handle.Wait() assert.NoError(t, err) var events []*AgentEvent @@ -285,7 +310,10 @@ func TestRunWithCancel_WithTools(t *testing.T) { if !ok { break } - assert.Nil(t, event.Err, "Should not have error event after cancel") + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + continue + } events = append(events, event) } @@ -293,7 +321,7 @@ func TestRunWithCancel_WithTools(t *testing.T) { assert.True(t, atomic.LoadInt32(&st.callCount) >= 1, "Tool should have been called") }) - t.Run("CancelAfterToolCall_CompletesToolExecution", func(t *testing.T) { + t.Run("CancelAfterToolCalls_CompletesToolExecution", func(t *testing.T) { toolStarted := make(chan struct{}, 1) st := &slowToolWithSignal{ name: "slow_tool", @@ -332,9 +360,10 @@ func TestRunWithCancel_WithTools(t *testing.T) { }) assert.NoError(t, err) - iter, cancelFn := agent.RunWithCancel(ctx, &AgentInput{ + cancelOpt, cancelFn := WithCancel() + iter := agent.Run(ctx, &AgentInput{ Messages: []Message{schema.UserMessage("Use the tool")}, - }) + }, cancelOpt) assert.NotNil(t, iter) assert.NotNil(t, cancelFn) @@ -342,7 +371,8 @@ func TestRunWithCancel_WithTools(t *testing.T) { time.Sleep(100 * time.Millisecond) - err = cancelFn(WithCancelMode(CancelAfterToolCall)) + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterToolCalls)) + err = handle.Wait() assert.NoError(t, err) var events []*AgentEvent @@ -351,13 +381,120 @@ func TestRunWithCancel_WithTools(t *testing.T) { if !ok { break } - assert.Nil(t, event.Err, "Should not have error event after cancel") + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + continue + } events = append(events, event) } assert.True(t, len(events) > 0) assert.True(t, atomic.LoadInt32(&st.callCount) >= 1, "Tool should have been called") }) + + t.Run("NestedCancelPropagation", func(t *testing.T) { + cc := newCancelContext() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + child := cc.deriveChild(ctx) + assert.NotNil(t, child) + + cc.setMode(CancelImmediate) + + if atomic.CompareAndSwapInt32(&cc.state, stateRunning, stateCancelling) { + close(cc.cancelChan) + } + + select { + case <-child.cancelChan: + case <-time.After(1 * time.Second): + t.Fatal("Child did not receive cancel signal") + } + + assert.True(t, child.shouldCancel()) + assert.Equal(t, CancelImmediate, child.getMode()) + }) + + t.Run("DeepAgentIntegrationCancel", func(t *testing.T) { + ctx := context.Background() + modelStarted := make(chan struct{}, 1) + + leafModel := &cancelTestChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "Leaf result", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + leafModel.setDelay(500 * time.Millisecond) + leafAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "LeafAgent", + Description: "desc", + Model: leafModel, + }) + assert.NoError(t, err) + + rootModel := &cancelTestChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "LeafAgent", + Arguments: `{}`, + }, + }, + }, + }, + startedChan: make(chan struct{}, 1), + doneChan: make(chan struct{}, 1), + } + rootAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "RootAgent", + Description: "desc", + Model: rootModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{NewAgentTool(ctx, leafAgent)}, + }, + }, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: rootAgent, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Run leaf")}, cancelOpt) + + <-modelStarted + + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + err = handle.Wait() + assert.NoError(t, err) + + hasCancelError := false + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil { + var ce *CancelError + if errors.As(event.Err, &ce) { + hasCancelError = true + assert.NotNil(t, ce.interruptSignal, "CancelError should carry interrupt signal") + } + } + } + assert.True(t, hasCancelError, "Should have received CancelError") + }) } type slowToolWithSignal struct { @@ -386,14 +523,29 @@ func (t *slowToolWithSignal) InvokableRun(ctx context.Context, argumentsInJSON s } type simpleChatModel struct { + delay time.Duration response *schema.Message } func (m *simpleChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + if m.delay > 0 { + select { + case <-time.After(m.delay): + case <-ctx.Done(): + return nil, ctx.Err() + } + } return m.response, nil } func (m *simpleChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + if m.delay > 0 { + select { + case <-time.After(m.delay): + case <-ctx.Done(): + return nil, ctx.Err() + } + } return schema.StreamReaderFromArray([]*schema.Message{m.response}), nil } @@ -401,7 +553,7 @@ func (m *simpleChatModel) BindTools(tools []*schema.ToolInfo) error { return nil } -func TestRunWithCancel_WithCheckpoint(t *testing.T) { +func TestWithCancel_WithCheckpoint(t *testing.T) { ctx := context.Background() t.Run("CancelWithCheckpoint", func(t *testing.T) { @@ -409,7 +561,7 @@ func TestRunWithCancel_WithCheckpoint(t *testing.T) { st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") slowModel := &cancelTestChatModel{ - delay: 500 * time.Millisecond, + delayNs: int64(1 * time.Second), response: &schema.Message{ Role: schema.Assistant, Content: "", @@ -448,28 +600,38 @@ func TestRunWithCancel_WithCheckpoint(t *testing.T) { CheckPointStore: store, }) - iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}, WithCheckPointID("cancel-1")) + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt, WithCheckPointID("cancel-1")) <-modelStarted - err = cancelFn() + handle, _ := cancelFn() + err = handle.Wait() assert.NoError(t, err) var events []*AgentEvent + hasCancelError := false + var cancelErrorCheckPointID string for { event, ok := iter.Next() if !ok { break } - assert.Nil(t, event.Err, "Should not have error event after cancel") + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + hasCancelError = true + cancelErrorCheckPointID = ce.CheckPointID + continue + } events = append(events, event) } - assert.True(t, len(events) > 0) + assert.True(t, hasCancelError, "Should have CancelError event after cancel") + assert.Equal(t, "cancel-1", cancelErrorCheckPointID, "CancelError should contain the checkpoint ID") }) } -func TestCancelFuncMultipleCalls(t *testing.T) { +func TestAgentCancelFuncMultipleCalls(t *testing.T) { ctx := context.Background() t.Run("SecondCancelReturnsErrAgentFinished", func(t *testing.T) { @@ -477,7 +639,7 @@ func TestCancelFuncMultipleCalls(t *testing.T) { st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") slowModel := &cancelTestChatModel{ - delay: 1 * time.Second, + delayNs: int64(1 * time.Second), response: &schema.Message{ Role: schema.Assistant, Content: "", @@ -514,11 +676,13 @@ func TestCancelFuncMultipleCalls(t *testing.T) { EnableStreaming: false, }) - iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}) + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt) <-modelStarted - cancelErr := cancelFn() + handle, _ := cancelFn() + cancelErr := handle.Wait() assert.NoError(t, cancelErr) for { @@ -530,58 +694,7 @@ func TestCancelFuncMultipleCalls(t *testing.T) { }) } -func TestAgentNotCancellable(t *testing.T) { - ctx := context.Background() - - nonCancellableAgent := &nonCancellableTestAgent{ - name: "NonCancellable", - } - - runner := NewRunner(ctx, RunnerConfig{ - Agent: nonCancellableAgent, - EnableStreaming: false, - }) - - t.Run("RunWithCancelReturnsNilCancelFunc", func(t *testing.T) { - iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Hello")}) - assert.NotNil(t, iter) - assert.Nil(t, cancelFn) - - for { - _, ok := iter.Next() - if !ok { - break - } - } - }) -} - -type nonCancellableTestAgent struct { - name string -} - -func (a *nonCancellableTestAgent) Name(_ context.Context) string { - return a.name -} - -func (a *nonCancellableTestAgent) Description(_ context.Context) string { - return "A non-cancellable agent" -} - -func (a *nonCancellableTestAgent) Run(_ context.Context, input *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { - iter, gen := NewAsyncIteratorPair[*AgentEvent]() - gen.Send(&AgentEvent{ - Output: &AgentOutput{ - MessageOutput: &MessageVariant{ - Message: schema.AssistantMessage("Response", nil), - }, - }, - }) - gen.Close() - return iter -} - -func TestRunWithCancel_Streaming(t *testing.T) { +func TestWithCancel_Streaming(t *testing.T) { ctx := context.Background() t.Run("CancelImmediate_DuringModelStream", func(t *testing.T) { @@ -589,7 +702,7 @@ func TestRunWithCancel_Streaming(t *testing.T) { st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") slowModel := &cancelTestChatModel{ - delay: 2 * time.Second, + delayNs: int64(2 * time.Second), response: &schema.Message{ Role: schema.Assistant, Content: "", @@ -626,7 +739,8 @@ func TestRunWithCancel_Streaming(t *testing.T) { EnableStreaming: true, }) - iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}) + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt) assert.NotNil(t, iter) assert.NotNil(t, cancelFn) @@ -651,27 +765,30 @@ func TestRunWithCancel_Streaming(t *testing.T) { time.Sleep(100 * time.Millisecond) - cancelErr := cancelFn() + handle, _ := cancelFn() + cancelErr := handle.Wait() assert.NoError(t, cancelErr) - start := time.Now() - events := <-eventsCh - elapsed := time.Since(start) + var events []*AgentEvent + select { + case events = <-eventsCh: + case <-time.After(5 * time.Second): + t.Fatal("Timed out waiting for events") + } - assert.True(t, elapsed < 1*time.Second, "Should return quickly after cancel, elapsed: %v", elapsed) assert.True(t, len(events) > 0) - hasInterrupted := false + hasCancelError := false for _, e := range events { - assert.Nil(t, e.Err, "Should not have error event after cancel") - if e.Action != nil && e.Action.Interrupted != nil { - hasInterrupted = true + var ce *CancelError + if e.Err != nil && errors.As(e.Err, &ce) { + hasCancelError = true } } - assert.True(t, hasInterrupted, "Should have interrupted event after cancel") + assert.True(t, hasCancelError, "Should have CancelError event after cancel") }) - t.Run("CancelAfterToolCall_Streaming", func(t *testing.T) { + t.Run("CancelAfterToolCalls_Streaming", func(t *testing.T) { toolStarted := make(chan struct{}, 1) st := &slowToolWithSignal{ name: "slow_tool", @@ -715,7 +832,8 @@ func TestRunWithCancel_Streaming(t *testing.T) { EnableStreaming: true, }) - iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}) + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt) assert.NotNil(t, iter) assert.NotNil(t, cancelFn) @@ -723,7 +841,8 @@ func TestRunWithCancel_Streaming(t *testing.T) { time.Sleep(100 * time.Millisecond) - cancelErr := cancelFn(WithCancelMode(CancelAfterToolCall)) + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterToolCalls)) + cancelErr := handle.Wait() assert.NoError(t, cancelErr) var events []*AgentEvent @@ -732,7 +851,10 @@ func TestRunWithCancel_Streaming(t *testing.T) { if !ok { break } - assert.Nil(t, event.Err, "Should not have error event after cancel") + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + continue + } events = append(events, event) } @@ -741,26 +863,20 @@ func TestRunWithCancel_Streaming(t *testing.T) { }) } -// TestResumeWithCancel tests the workflow of Cancel followed by Resume. -// -// IMPORTANT: When Cancel is triggered, the cancelableChatModel.Generate/Stream -// method returns immediately with an Interrupt error, but the inner model's -// Generate/Stream call continues running in a background goroutine until completion. -// This means the original model instance's fields (e.g., delay, response) may still -// be read by the background goroutine after Cancel returns. +// TestWithCancel_Resume tests the workflow of Cancel followed by Resume. // // To avoid data races, we create new agent and runner instances for the Resume phase // instead of reusing and modifying the original model instance. -func TestResumeWithCancel(t *testing.T) { +func TestWithCancel_Resume(t *testing.T) { ctx := context.Background() - t.Run("RunWithCancel_ThenResumeWithCancel", func(t *testing.T) { + t.Run("Cancel_ThenResume", func(t *testing.T) { modelStarted := make(chan struct{}, 1) modelCallCount := int32(0) st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") slowModel := &cancelTestChatModel{ - delay: 500 * time.Millisecond, + delayNs: int64(500 * time.Millisecond), response: &schema.Message{ Role: schema.Assistant, Content: "", @@ -800,37 +916,38 @@ func TestResumeWithCancel(t *testing.T) { CheckPointStore: store, }) - iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}, WithCheckPointID(checkpointID)) + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt, WithCheckPointID(checkpointID)) <-modelStarted atomic.AddInt32(&modelCallCount, 1) - cancelErr := cancelFn() + handle, _ := cancelFn() + cancelErr := handle.Wait() assert.NoError(t, cancelErr) var events []*AgentEvent + hasCancelErr := false for { event, ok := iter.Next() if !ok { break } - assert.Nil(t, event.Err, "Should not have error event after cancel") - events = append(events, event) - } - assert.True(t, len(events) > 0) - - hasInterrupted := false - for _, e := range events { - if e.Action != nil && e.Action.Interrupted != nil { - hasInterrupted = true - break + if event.Err != nil { + var ce *CancelError + if errors.As(event.Err, &ce) { + hasCancelErr = true + continue + } + t.Fatalf("unexpected error: %v", event.Err) } + events = append(events, event) } - assert.True(t, hasInterrupted, "First run should have interrupted event") + assert.True(t, hasCancelErr, "Should have CancelError event after cancel") newModelStarted := make(chan struct{}, 1) slowModel2 := &cancelTestChatModel{ - delay: 100 * time.Millisecond, + delayNs: int64(100 * time.Millisecond), response: &schema.Message{ Role: schema.Assistant, Content: "Final response after resume", @@ -858,10 +975,10 @@ func TestResumeWithCancel(t *testing.T) { CheckPointStore: store, }) - resumeIter, resumeCancelFn, err := runner2.ResumeWithCancel(ctx, checkpointID) + resumeCancelOpt, _ := WithCancel() + resumeIter, err := runner2.Resume(ctx, checkpointID, resumeCancelOpt) assert.NoError(t, err) assert.NotNil(t, resumeIter) - assert.NotNil(t, resumeCancelFn) var resumeEvents []*AgentEvent for { @@ -876,14 +993,13 @@ func TestResumeWithCancel(t *testing.T) { assert.True(t, len(resumeEvents) > 0, "Resume should produce events") }) - t.Run("ResumeWithCancel_ThenCancel", func(t *testing.T) { + t.Run("Resume_ThenCancel", func(t *testing.T) { firstModelStarted := make(chan struct{}, 1) - resumeModelStarted := make(chan struct{}, 1) modelCallCount := int32(0) st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") slowModel := &cancelTestChatModel{ - delay: 500 * time.Millisecond, + delayNs: int64(500 * time.Millisecond), response: &schema.Message{ Role: schema.Assistant, Content: "", @@ -923,12 +1039,14 @@ func TestResumeWithCancel(t *testing.T) { CheckPointStore: store, }) - iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}, WithCheckPointID(checkpointID)) + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt, WithCheckPointID(checkpointID)) <-firstModelStarted atomic.AddInt32(&modelCallCount, 1) - cancelErr := cancelFn() + handle, _ := cancelFn() + cancelErr := handle.Wait() assert.NoError(t, cancelErr) for { @@ -938,25 +1056,7 @@ func TestResumeWithCancel(t *testing.T) { } } - slowModel2 := &cancelTestChatModel{ - delay: 2 * time.Second, - response: &schema.Message{ - Role: schema.Assistant, - Content: "", - ToolCalls: []schema.ToolCall{ - { - ID: "call_1", - Type: "function", - Function: schema.FunctionCall{ - Name: "slow_tool", - Arguments: `{"input": "test"}`, - }, - }, - }, - }, - startedChan: resumeModelStarted, - doneChan: make(chan struct{}, 1), - } + slowModel2 := newBlockingChatModel(toolCallMsg(toolCall("call_1", "slow_tool", `{"input": "test"}`))) agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", @@ -977,7 +1077,8 @@ func TestResumeWithCancel(t *testing.T) { CheckPointStore: store, }) - resumeIter, resumeCancelFn, err := runner2.ResumeWithCancel(ctx, checkpointID) + resumeCancelOpt, resumeCancelFn := WithCancel() + resumeIter, err := runner2.Resume(ctx, checkpointID, resumeCancelOpt) assert.NoError(t, err) resumeEventsCh := make(chan []*AgentEvent, 1) @@ -993,13 +1094,13 @@ func TestResumeWithCancel(t *testing.T) { resumeEventsCh <- events }() - <-resumeModelStarted + <-slowModel2.started atomic.AddInt32(&modelCallCount, 1) - time.Sleep(100 * time.Millisecond) - - err = resumeCancelFn() - assert.NoError(t, err) + cancelHandle, _ := resumeCancelFn() + close(slowModel2.unblockCh) + err = cancelHandle.Wait() + assert.True(t, err == nil || errors.Is(err, ErrExecutionCompleted), "unexpected cancel wait error: %v", err) start := time.Now() resumeEvents := <-resumeEventsCh @@ -1008,13 +1109,2447 @@ func TestResumeWithCancel(t *testing.T) { assert.True(t, elapsed < 1*time.Second, "Resume should return quickly after cancel, elapsed: %v", elapsed) assert.True(t, len(resumeEvents) > 0) - hasInterrupted := false + hasCancelError := false for _, e := range resumeEvents { - assert.Nil(t, e.Err, "Should not have error event after resume cancel") - if e.Action != nil && e.Action.Interrupted != nil { - hasInterrupted = true + var ce *CancelError + if e.Err != nil && errors.As(e.Err, &ce) { + hasCancelError = true } } - assert.True(t, hasInterrupted, "Resume should have interrupted event after cancel") + executionCompletedBeforeCancel := errors.Is(err, ErrExecutionCompleted) + assert.True(t, hasCancelError || executionCompletedBeforeCancel, "Resume should have CancelError event after cancel, or execution completed before cancel") + }) +} + +func TestCancelMonitoredToolHandler_StreamableToolCall(t *testing.T) { + t.Run("NoCancelContext_PassesThrough", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + + // Create a stream with some data + r, w := schema.Pipe[string](1) + go func() { + w.Send("chunk1", nil) + w.Send("chunk2", nil) + w.Close() + }() + + next := func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { + return &compose.StreamToolOutput{Result: r}, nil + } + + wrapped := handler.WrapStreamableToolCall(next) + // No cancelContext in the Go context + output, err := wrapped(context.Background(), &compose.ToolInput{Name: "test"}) + assert.NoError(t, err) + + // Should get the original stream unchanged + chunk1, err := output.Result.Recv() + assert.NoError(t, err) + assert.Equal(t, "chunk1", chunk1) + + chunk2, err := output.Result.Recv() + assert.NoError(t, err) + assert.Equal(t, "chunk2", chunk2) + + _, err = output.Result.Recv() + assert.ErrorIs(t, err, io.EOF) + }) + + t.Run("WithCancelContext_NoCancel_StreamsNormally", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + cc := newCancelContext() + + r, w := schema.Pipe[string](1) + go func() { + w.Send("data1", nil) + w.Send("data2", nil) + w.Close() + }() + + next := func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { + return &compose.StreamToolOutput{Result: r}, nil + } + + wrapped := handler.WrapStreamableToolCall(next) + ctx := withCancelContext(context.Background(), cc) + output, err := wrapped(ctx, &compose.ToolInput{Name: "test"}) + assert.NoError(t, err) + + chunk1, err := output.Result.Recv() + assert.NoError(t, err) + assert.Equal(t, "data1", chunk1) + + chunk2, err := output.Result.Recv() + assert.NoError(t, err) + assert.Equal(t, "data2", chunk2) + + _, err = output.Result.Recv() + assert.ErrorIs(t, err, io.EOF) + }) + + t.Run("WithCancelContext_ImmediateCancel_TerminatesStream", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + cc := newCancelContext() + + // Create a slow stream that we'll cancel mid-way + r, w := schema.Pipe[string](1) + go func() { + defer w.Close() + w.Send("chunk1", nil) + time.Sleep(200 * time.Millisecond) + w.Send("chunk2", nil) + }() + + next := func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { + return &compose.StreamToolOutput{Result: r}, nil + } + + wrapped := handler.WrapStreamableToolCall(next) + ctx := withCancelContext(context.Background(), cc) + output, err := wrapped(ctx, &compose.ToolInput{Name: "test"}) + assert.NoError(t, err) + + // Read first chunk + chunk1, err := output.Result.Recv() + assert.NoError(t, err) + assert.Equal(t, "chunk1", chunk1) + + // Fire immediate cancel + close(cc.immediateChan) + + // Next recv should get ErrStreamCanceled + _, err = output.Result.Recv() + assert.ErrorIs(t, err, ErrStreamCanceled) + }) + + t.Run("WithCancelContext_AlreadyCancelled_TerminatesImmediately", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + cc := newCancelContext() + close(cc.immediateChan) // Already canceled + + r, w := schema.Pipe[string](1) + go func() { + w.Send("should-not-see", nil) + w.Close() + }() + + next := func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { + return &compose.StreamToolOutput{Result: r}, nil + } + + wrapped := handler.WrapStreamableToolCall(next) + ctx := withCancelContext(context.Background(), cc) + output, err := wrapped(ctx, &compose.ToolInput{Name: "test"}) + assert.NoError(t, err) + + _, err = output.Result.Recv() + assert.ErrorIs(t, err, ErrStreamCanceled) + }) + + t.Run("NextReturnsError_PropagatesError", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + cc := newCancelContext() + + nextErr := errors.New("tool execution failed") + next := func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { + return nil, nextErr + } + + wrapped := handler.WrapStreamableToolCall(next) + ctx := withCancelContext(context.Background(), cc) + _, err := wrapped(ctx, &compose.ToolInput{Name: "test"}) + assert.ErrorIs(t, err, nextErr) + }) +} + +func TestCancelMonitoredToolHandler_EnhancedStreamableToolCall(t *testing.T) { + t.Run("NoCancelContext_PassesThrough", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + + tr1 := &schema.ToolResult{Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: "chunk1"}}} + r, w := schema.Pipe[*schema.ToolResult](1) + go func() { + w.Send(tr1, nil) + w.Close() + }() + + next := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { + return &compose.EnhancedStreamableToolOutput{Result: r}, nil + } + + wrapped := handler.WrapEnhancedStreamableToolCall(next) + output, err := wrapped(context.Background(), &compose.ToolInput{Name: "test"}) + assert.NoError(t, err) + + result, err := output.Result.Recv() + assert.NoError(t, err) + assert.Equal(t, tr1, result) + + _, err = output.Result.Recv() + assert.ErrorIs(t, err, io.EOF) + }) + + t.Run("WithCancelContext_ImmediateCancel_TerminatesStream", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + cc := newCancelContext() + + tr1 := &schema.ToolResult{Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: "chunk1"}}} + tr2 := &schema.ToolResult{Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: "chunk2"}}} + r, w := schema.Pipe[*schema.ToolResult](1) + go func() { + defer w.Close() + w.Send(tr1, nil) + time.Sleep(200 * time.Millisecond) + w.Send(tr2, nil) + }() + + next := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { + return &compose.EnhancedStreamableToolOutput{Result: r}, nil + } + + wrapped := handler.WrapEnhancedStreamableToolCall(next) + ctx := withCancelContext(context.Background(), cc) + output, err := wrapped(ctx, &compose.ToolInput{Name: "test"}) + assert.NoError(t, err) + + result, err := output.Result.Recv() + assert.NoError(t, err) + assert.Equal(t, tr1, result) + + close(cc.immediateChan) + + _, err = output.Result.Recv() + assert.ErrorIs(t, err, ErrStreamCanceled) + }) + + t.Run("NextReturnsError_PropagatesError", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + cc := newCancelContext() + + nextErr := errors.New("enhanced tool failed") + next := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { + return nil, nextErr + } + + wrapped := handler.WrapEnhancedStreamableToolCall(next) + ctx := withCancelContext(context.Background(), cc) + _, err := wrapped(ctx, &compose.ToolInput{Name: "test"}) + assert.ErrorIs(t, err, nextErr) + }) +} + +func TestCancelContextKey(t *testing.T) { + t.Run("WithAndGet_RoundTrips", func(t *testing.T) { + cc := newCancelContext() + ctx := withCancelContext(context.Background(), cc) + got := getCancelContext(ctx) + assert.Equal(t, cc, got) + }) + + t.Run("Get_NoValue_ReturnsNil", func(t *testing.T) { + got := getCancelContext(context.Background()) + assert.Nil(t, got) }) + + t.Run("With_NilCancelContext_ReturnsOriginalCtx", func(t *testing.T) { + ctx := context.Background() + result := withCancelContext(ctx, nil) + assert.Equal(t, ctx, result) + }) +} + +// -- Tests for cancel support across all agent types -- + +// cancelTestAgent is a ChatModelAgent-based agent where the model blocks until +// signalled, allowing tests to control exactly when to issue a cancel. +func newCancelTestAgent(t *testing.T, name string, modelDelay time.Duration, modelStarted chan struct{}) *ChatModelAgent { + t.Helper() + slowModel := &cancelTestChatModel{ + delayNs: int64(modelDelay), + response: &schema.Message{ + Role: schema.Assistant, + Content: "response from " + name, + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: name, + Description: "Test agent " + name, + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + return agent +} + +func newCancelTestAgentWithTools(t *testing.T, name string, modelDelay time.Duration, modelStarted chan struct{}) *ChatModelAgent { + t.Helper() + toolName := name + "_tool" + slowModel := &cancelTestChatModel{ + delayNs: int64(modelDelay), + response: &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{{ + ID: "call_1", Type: "function", + Function: schema.FunctionCall{ + Name: toolName, + Arguments: `{"input": "test"}`, + }, + }}, + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + st := newSlowTool(toolName, 10*time.Millisecond, "tool result") + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: name, + Description: "Test agent " + name, + Instruction: "You are a test assistant", + Model: slowModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + return agent +} + +func newCancelTestAgentWithToolsFinalAnswer(t *testing.T, name string) *ChatModelAgent { + t.Helper() + toolName := name + "_tool" + finalModel := &cancelTestChatModel{ + delayNs: int64(10 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "final response from " + name, + }, + startedChan: make(chan struct{}, 1), + doneChan: make(chan struct{}, 1), + } + + st := newSlowTool(toolName, 10*time.Millisecond, "tool result") + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: name, + Description: "Test agent " + name, + Instruction: "You are a test assistant", + Model: finalModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + return agent +} + +func TestWithCancel_SequentialAgent(t *testing.T) { + ctx := context.Background() + + t.Run("CancelImmediate_DuringSecondAgent", func(t *testing.T) { + // The first agent completes quickly. The second agent takes a long time. + // Cancel during the second agent's model call. + agent1Started := make(chan struct{}, 1) + agent2Started := make(chan struct{}, 1) + + agent1 := newCancelTestAgent(t, "fast_agent", 50*time.Millisecond, agent1Started) + agent2 := newCancelTestAgent(t, "slow_agent", 5*time.Second, agent2Started) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq_agent", + Description: "Sequential test", + SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, + EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + // Wait for second agent to start + select { + case <-agent2Started: + case <-time.After(10 * time.Second): + t.Fatal("Second agent did not start") + } + + time.Sleep(50 * time.Millisecond) + + // Cancel should NOT return ErrExecutionCompleted (the bug before the fix) + handle, _ := cancelFn() + err = handle.Wait() + assert.NoError(t, err, "Cancel during second agent should succeed, not return ErrExecutionCompleted") + + var events []*AgentEvent + hasCancelError := false + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + hasCancelError = true + } + events = append(events, event) + } + + assert.True(t, hasCancelError, "Should have CancelError event") + }) +} + +func TestWithCancel_LoopAgent(t *testing.T) { + ctx := context.Background() + + t.Run("CancelImmediate_DuringIteration", func(t *testing.T) { + // Agent in a loop. Cancel during second iteration's model call. + modelStarted := make(chan struct{}, 10) + + slowModel := &cancelTestChatModel{ + delayNs: int64(3 * time.Second), + response: &schema.Message{ + Role: schema.Assistant, + Content: "loop response", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 10), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "loop_inner", + Description: "Inner loop agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loopAgent, err := NewLoopAgent(ctx, &LoopAgentConfig{ + Name: "loop_agent", + Description: "Loop test", + SubAgents: []Agent{agent}, + MaxIterations: 10, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: loopAgent, + EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + // Wait for first iteration's model call to start + select { + case <-modelStarted: + case <-time.After(10 * time.Second): + t.Fatal("Model did not start") + } + + time.Sleep(50 * time.Millisecond) + + // Cancel should succeed + handle, _ := cancelFn() + err = handle.Wait() + assert.NoError(t, err, "Cancel during loop iteration should succeed") + + hasCancelError := false + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + hasCancelError = true + } + } + + assert.True(t, hasCancelError, "Should have CancelError event") + }) +} + +func TestWithCancel_ParallelAgent(t *testing.T) { + ctx := context.Background() + + t.Run("CancelImmediate_InterruptsAllBranches", func(t *testing.T) { + agent1Started := make(chan struct{}, 1) + agent2Started := make(chan struct{}, 1) + + // Both agents have long delays, so cancel should interrupt both. + agent1 := newCancelTestAgent(t, "par_agent1", 5*time.Second, agent1Started) + agent2 := newCancelTestAgent(t, "par_agent2", 5*time.Second, agent2Started) + + parAgent, err := NewParallelAgent(ctx, &ParallelAgentConfig{ + Name: "par_agent", + Description: "Parallel test", + SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: parAgent, + EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + // Wait for both agents to start + for i := 0; i < 2; i++ { + select { + case <-agent1Started: + case <-agent2Started: + case <-time.After(10 * time.Second): + t.Fatal("Parallel agents did not start") + } + } + + time.Sleep(50 * time.Millisecond) + + start := time.Now() + handle, _ := cancelFn() + err = handle.Wait() + assert.NoError(t, err, "Cancel during parallel agents should succeed") + + var events []*AgentEvent + hasCancelError := false + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + hasCancelError = true + } + events = append(events, event) + } + elapsed := time.Since(start) + + assert.True(t, hasCancelError, "Should have CancelError event") + assert.True(t, elapsed < 3*time.Second, "Should complete quickly after cancel, elapsed: %v", elapsed) + }) +} + +func TestWithCancel_SupervisorAgent(t *testing.T) { + ctx := context.Background() + + t.Run("CancelImmediate_DuringSubAgent", func(t *testing.T) { + // Supervisor delegates to a slow sub-agent via transfer. + // Cancel during the sub-agent's model call. + supervisorModelStarted := make(chan struct{}, 1) + subAgentModelStarted := make(chan struct{}, 1) + + // The supervisor model returns a transfer_to_agent tool call + supervisorModel := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: TransferToAgentToolName, + Arguments: `{"agent_name": "slow_sub"}`, + }, + }, + }, + }, + } + + supervisorAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "supervisor", + Description: "Supervisor agent", + Instruction: "You are a supervisor", + Model: supervisorModel, + }) + assert.NoError(t, err) + + subAgent := newCancelTestAgent(t, "slow_sub", 5*time.Second, subAgentModelStarted) + + agentWithSubAgents, err := SetSubAgents(ctx, supervisorAgent, []Agent{subAgent}) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agentWithSubAgents, + EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + // Ignore the supervisor model start, wait for the sub-agent model + // The supervisor model is fast (simpleChatModel), so it will start and finish quickly + _ = supervisorModelStarted + select { + case <-subAgentModelStarted: + case <-time.After(10 * time.Second): + t.Fatal("Sub-agent model did not start") + } + + time.Sleep(50 * time.Millisecond) + + start := time.Now() + handle, _ := cancelFn() + err = handle.Wait() + assert.NoError(t, err, "Cancel during sub-agent should succeed") + + hasCancelError := false + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + hasCancelError = true + } + } + elapsed := time.Since(start) + + assert.True(t, hasCancelError, "Should have CancelError event") + assert.True(t, elapsed < 3*time.Second, "Should complete quickly after cancel, elapsed: %v", elapsed) + }) +} + +func TestFilterCancelOption(t *testing.T) { + t.Run("RemovesCancelOption", func(t *testing.T) { + cancelOpt, _ := WithCancel() + sessionOpt := WithSessionValues(map[string]any{"key": "value"}) + opts := []AgentRunOption{cancelOpt, sessionOpt} + + filtered := filterCancelOption(opts) + assert.Len(t, filtered, 1, "Should have removed the cancel option") + + // Verify the remaining option is the session option + testOpt := &options{} + filtered[0].implSpecificOptFn.(func(*options))(testOpt) + assert.NotNil(t, testOpt.sessionValues) + assert.Nil(t, testOpt.cancelCtx) + }) + + t.Run("KeepsNonCancelOptions", func(t *testing.T) { + sessionOpt := WithSessionValues(map[string]any{"key": "value"}) + callbackOpt := WithCallbacks() + opts := []AgentRunOption{sessionOpt, callbackOpt} + + filtered := filterCancelOption(opts) + assert.Len(t, filtered, 2, "Should keep all non-cancel options") + }) + + t.Run("EmptyInput", func(t *testing.T) { + filtered := filterCancelOption(nil) + assert.Nil(t, filtered) + }) +} + +func wrapIterWithMarkDone(iter *AsyncIterator[*AgentEvent], cc *cancelContext) *AsyncIterator[*AgentEvent] { + if cc == nil { + return iter + } + outIter, outGen := NewAsyncIteratorPair[*AgentEvent]() + go func() { + defer cc.markDone() + defer outGen.Close() + for { + event, ok := iter.Next() + if !ok { + return + } + outGen.Send(event) + } + }() + return outIter +} + +func TestWrapIterWithMarkDone(t *testing.T) { + t.Run("MarksDoneAfterDrain", func(t *testing.T) { + cc := newCancelContext() + iter, gen := NewAsyncIteratorPair[*AgentEvent]() + + go func() { + gen.Send(&AgentEvent{AgentName: "test"}) + gen.Close() + }() + + wrapped := wrapIterWithMarkDone(iter, cc) + + event, ok := wrapped.Next() + assert.True(t, ok) + assert.Equal(t, "test", event.AgentName) + + _, ok = wrapped.Next() + assert.False(t, ok) + + // markDone should have been called, so doneChan should be closed + select { + case <-cc.doneChan: + // good + case <-time.After(time.Second): + t.Fatal("doneChan was not closed after drain") + } + }) + + t.Run("NilCancelContext_PassesThrough", func(t *testing.T) { + iter, gen := NewAsyncIteratorPair[*AgentEvent]() + go func() { + gen.Send(&AgentEvent{AgentName: "test"}) + gen.Close() + }() + + wrapped := wrapIterWithMarkDone(iter, nil) + assert.Equal(t, iter, wrapped, "Should return same iter when cc is nil") + }) +} + +func TestGraphInterruptFuncs_Parallel(t *testing.T) { + t.Run("MultipleGraphInterruptFuncsAllCalled", func(t *testing.T) { + cc := newCancelContext() + + var called1, called2 int32 + cc.setGraphInterruptFunc(func(opts ...compose.GraphInterruptOption) { + atomic.AddInt32(&called1, 1) + }) + cc.setGraphInterruptFunc(func(opts ...compose.GraphInterruptOption) { + atomic.AddInt32(&called2, 1) + }) + + // Simulate immediate cancel + cc.setMode(CancelImmediate) + atomic.CompareAndSwapInt32(&cc.state, stateRunning, stateCancelling) + close(cc.cancelChan) + cc.sendImmediateInterrupt() + + assert.Equal(t, int32(1), atomic.LoadInt32(&called1), "First graph interrupt func should be called") + assert.Equal(t, int32(1), atomic.LoadInt32(&called2), "Second graph interrupt func should be called") + }) + + t.Run("RetroactiveFire_OnSetAfterCancel", func(t *testing.T) { + cc := newCancelContext() + + // First set up cancel state with immediate interrupt + cc.setMode(CancelImmediate) + atomic.CompareAndSwapInt32(&cc.state, stateRunning, stateCancelling) + close(cc.cancelChan) + close(cc.immediateChan) + atomic.StoreInt32(&cc.interruptSent, interruptImmediate) + + // Now register a new function - it should be retroactively fired + var called int32 + cc.setGraphInterruptFunc(func(opts ...compose.GraphInterruptOption) { + atomic.AddInt32(&called, 1) + }) + + assert.Equal(t, int32(1), atomic.LoadInt32(&called), "setGraphInterruptFunc should retroactively fire new func") + }) +} + +// -- Tests for transition-point cancel (cancel between sub-agents) -- + +// gatedChatModel is a model that: +// - Signals doneChan when Generate completes +// - Optionally blocks on gateChan before returning (nil gateChan = no blocking) +// - Tracks call count via callCount +type gatedChatModel struct { + response *schema.Message + gateChan chan struct{} // if non-nil, blocks until closed before returning + doneChan chan struct{} // signalled after Generate completes + callCount int32 +} + +func (m *gatedChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m.callCount, 1) + if m.gateChan != nil { + select { + case <-m.gateChan: + case <-ctx.Done(): + return nil, ctx.Err() + } + } + select { + case m.doneChan <- struct{}{}: + default: + } + return m.response, nil +} + +func (m *gatedChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + msg, err := m.Generate(ctx, input, opts...) + if err != nil { + return nil, err + } + return schema.StreamReaderFromArray([]*schema.Message{msg}), nil +} + +func (m *gatedChatModel) BindTools(tools []*schema.ToolInfo) error { + return nil +} + +func TestCheckCancel_Sequential_BetweenSubAgents(t *testing.T) { + ctx := context.Background() + + // CancelAfterToolCalls fires at transition boundaries between sub-agents. + // At a transition boundary, the completed sub-agent's entire execution + // (including any tool calls) is done, satisfying the CancelAfterToolCalls + // contract — even if this particular sub-agent had no tools. + model1 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent1 done"}, + gateChan: make(chan struct{}), + doneChan: make(chan struct{}, 1), + } + model2 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent2 done"}, + doneChan: make(chan struct{}, 1), + } + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent1", Description: "first", Instruction: "test", Model: model1, + }) + assert.NoError(t, err) + + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent2", Description: "second", Instruction: "test", Model: model2, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq", Description: "sequential test", SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hello")}, cancelOpt) + + for atomic.LoadInt32(&model1.callCount) == 0 { + runtime.Gosched() + } + + cancelCalled, result := cancelAsync(cancelFn, WithAgentCancelMode(CancelAfterToolCalls)) + waitForChan(t, cancelCalled, "cancelFn was not called") + close(model1.gateChan) + + assert.NoError(t, result.waitDone(t), "CancelAfterToolCalls should succeed at transition boundary") + + for { + _, ok := iter.Next() + if !ok { + break + } + } + + assert.Equal(t, int32(1), atomic.LoadInt32(&model1.callCount), "Agent1 model should be invoked") + assert.Equal(t, int32(0), atomic.LoadInt32(&model2.callCount), + "Agent2 model should NOT be invoked (CancelAfterToolCalls caught at transition)") +} + +func TestCheckCancel_Loop_BetweenIterations(t *testing.T) { + ctx := context.Background() + + // CancelAfterToolCalls fires at loop iteration boundaries. + // After the first iteration completes, any tool calls it made are done, + // satisfying the CancelAfterToolCalls contract. + mdl := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "loop iter"}, + gateChan: make(chan struct{}), + doneChan: make(chan struct{}, 10), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "loop_inner", Description: "inner", Instruction: "test", Model: mdl, + }) + assert.NoError(t, err) + + loopAgent, err := NewLoopAgent(ctx, &LoopAgentConfig{ + Name: "loop", Description: "loop test", SubAgents: []Agent{agent}, MaxIterations: 3, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: loopAgent, EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hello")}, cancelOpt) + + for atomic.LoadInt32(&mdl.callCount) == 0 { + runtime.Gosched() + } + + cancelCalled, result := cancelAsync(cancelFn, WithAgentCancelMode(CancelAfterToolCalls)) + waitForChan(t, cancelCalled, "cancelFn was not called") + close(mdl.gateChan) + + assert.NoError(t, result.waitDone(t), "CancelAfterToolCalls should succeed at loop transition boundary") + + for { + _, ok := iter.Next() + if !ok { + break + } + } + + assert.Equal(t, int32(1), atomic.LoadInt32(&mdl.callCount), + "Model should be called once; second iteration caught at transition") +} + +func TestCheckCancel_Parallel_PreSpawn(t *testing.T) { + ctx := context.Background() + + // Cancel fires before Run is called. Neither model should be invoked. + model1 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "par1"}, + doneChan: make(chan struct{}, 1), + } + model2 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "par2"}, + doneChan: make(chan struct{}, 1), + } + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "par1", Description: "first", Instruction: "test", Model: model1, + }) + assert.NoError(t, err) + + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "par2", Description: "second", Instruction: "test", Model: model2, + }) + assert.NoError(t, err) + + parAgent, err := NewParallelAgent(ctx, &ParallelAgentConfig{ + Name: "par", Description: "parallel test", SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + // Fire cancel in goroutine (cancelFn blocks until handled) + cancelOpt, cancelFn := WithCancel() + cancelDone := make(chan error, 1) + go func() { + handle, _ := cancelFn() + cancelDone <- handle.Wait() + }() + // Wait for cancelChan to be closed (happens synchronously before the blocking doneChan wait) + time.Sleep(20 * time.Millisecond) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: parAgent, EnableStreaming: false, + }) + + iter := runner.Run(ctx, []Message{schema.UserMessage("hello")}, cancelOpt) + + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + cancelErr = ce + } + } + + // cancelFn should have completed + select { + case err = <-cancelDone: + assert.NoError(t, err) + case <-time.After(5 * time.Second): + t.Fatal("cancelFn did not return") + } + + assert.NotNil(t, cancelErr, "Should have CancelError") + assert.Equal(t, int32(0), atomic.LoadInt32(&model1.callCount), "First model should never be invoked") + assert.Equal(t, int32(0), atomic.LoadInt32(&model2.callCount), "Second model should never be invoked") +} + +func TestCheckCancel_Transfer_BeforeTarget(t *testing.T) { + ctx := context.Background() + + // Supervisor CMA returns a transfer action (instantly). + // Cancel fires after transfer action but before target runs. + // Target model should never be invoked. + supervisorModel := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{{ + ID: "call_1", Type: "function", + Function: schema.FunctionCall{ + Name: TransferToAgentToolName, + Arguments: `{"agent_name": "target"}`, + }, + }}, + }, + } + targetModel := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "target done"}, + doneChan: make(chan struct{}, 1), + } + + supervisorAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "supervisor", Description: "supervisor", Instruction: "test", Model: supervisorModel, + }) + assert.NoError(t, err) + + targetAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "target", Description: "target", Instruction: "test", Model: targetModel, + }) + assert.NoError(t, err) + + agentWithSub, err := SetSubAgents(ctx, supervisorAgent, []Agent{targetAgent}) + assert.NoError(t, err) + + // Fire cancel in goroutine (cancelFn blocks until handled) + cancelOpt, cancelFn := WithCancel() + cancelDone := make(chan error, 1) + go func() { + handle, _ := cancelFn() + cancelDone <- handle.Wait() + }() + time.Sleep(20 * time.Millisecond) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agentWithSub, EnableStreaming: false, + }) + + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + cancelErr = ce + } + } + + select { + case err = <-cancelDone: + assert.NoError(t, err) + case <-time.After(5 * time.Second): + t.Fatal("cancelFn did not return") + } + + assert.NotNil(t, cancelErr, "Should have CancelError") + assert.Equal(t, int32(0), atomic.LoadInt32(&targetModel.callCount), "Target model should never be invoked") +} + +func TestCheckCancel_AlreadyHandled_NoDuplicate(t *testing.T) { + ctx := context.Background() + + // In a sequential agent, if the first CMA handles the cancel (graph interrupt), + // the workflow's transition check should NOT emit a duplicate CancelError. + // Use a slow model so cancel fires during its execution (handled by CMA). + modelStarted := make(chan struct{}, 1) + model1 := &cancelTestChatModel{ + delayNs: int64(2 * time.Second), + response: &schema.Message{Role: schema.Assistant, Content: "agent1"}, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + model2 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent2"}, + doneChan: make(chan struct{}, 1), + } + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent1", Description: "first", Instruction: "test", Model: model1, + }) + assert.NoError(t, err) + + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent2", Description: "second", Instruction: "test", Model: model2, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq", Description: "sequential", SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + // Wait for model to start, then cancel during model execution + select { + case <-modelStarted: + case <-time.After(5 * time.Second): + t.Fatal("Model did not start") + } + time.Sleep(50 * time.Millisecond) + handle, _ := cancelFn() + err = handle.Wait() + assert.NoError(t, err) + + cancelCount := 0 + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + cancelCount++ + } + } + + assert.Equal(t, 1, cancelCount, "Should have exactly one CancelError, no duplicate from workflow transition") + assert.Equal(t, int32(0), atomic.LoadInt32(&model2.callCount), "Second agent should not run") +} + +// Tests for CancelAfterChatModel/CancelAfterToolCalls in nested workflow structures. +// These verify that safe-point cancel modes propagate through the entire agent hierarchy +// and fire at whichever nested level reaches the safe-point first. + +func TestCancel_SequentialWorkflow_CancelAfterChatModel(t *testing.T) { + ctx := context.Background() + agent1Started := make(chan struct{}, 1) + + agent1 := newCancelTestAgentWithTools(t, "seq_slow", 500*time.Millisecond, agent1Started) + agent2 := newCancelTestAgentWithTools(t, "seq_fast", 50*time.Millisecond, make(chan struct{}, 1)) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq_agent", + Description: "Sequential workflow", + SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt, WithCheckPointID("seq-cancel-1")) + + select { + case <-agent1Started: + case <-time.After(10 * time.Second): + t.Fatal("First agent did not start") + } + + handle, contributed := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + assert.True(t, contributed, "Cancel should contribute") + err = handle.Wait() + assert.NoError(t, err) + + hasCancelError := false + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil && errors.As(event.Err, &cancelErr) { + hasCancelError = true + } + } + + assert.True(t, hasCancelError, "Should have CancelError") + assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) + assert.NotEmpty(t, cancelErr.CheckPointID, "CancelError should have checkpoint ID") + assert.NotNil(t, cancelErr.interruptSignal, "CancelError should have interrupt signal for checkpoint") + + resumeAgent1 := newCancelTestAgentWithToolsFinalAnswer(t, "seq_slow") + resumeAgent2 := newCancelTestAgentWithToolsFinalAnswer(t, "seq_fast") + + resumeSeq, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq_agent", + Description: "Sequential workflow", + SubAgents: []Agent{resumeAgent1, resumeAgent2}, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: resumeSeq, + CheckPointStore: store, + }) + + resumeIter, err := runner2.Resume(ctx, "seq-cancel-1") + assert.NoError(t, err) + assert.NotNil(t, resumeIter) + + var resumeEvents []*AgentEvent + for { + event, ok := resumeIter.Next() + if !ok { + break + } + assert.Nil(t, event.Err, "Should not have error during resume") + resumeEvents = append(resumeEvents, event) + } + assert.True(t, len(resumeEvents) > 0, "Resume should produce events") +} + +// -- Tests for CancelImmediate in nested agent structures -- + +func newTestChatModel(response *schema.Message, delay time.Duration) *cancelTestChatModel { + m := &cancelTestChatModel{ + response: response, + startedChan: make(chan struct{}, 1), + doneChan: make(chan struct{}, 1), + } + if delay > 0 { + m.setDelay(delay) + } + return m +} + +func newToolCallResponse(toolName string) *schema.Message { + return &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + {ID: "call_1", Type: "function", Function: schema.FunctionCall{Name: toolName, Arguments: `{}`}}, + }, + } +} + +func newAgentWithTool(t *testing.T, ctx context.Context, name string, mdl model.BaseChatModel, subAgent Agent) (Agent, error) { + t.Helper() + return NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: name, + Description: name, + Model: mdl, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{NewAgentTool(ctx, subAgent)}, + }, + }, + }) +} + +func waitForChan(t *testing.T, ch <-chan struct{}, msg string) { + t.Helper() + select { + case <-ch: + case <-time.After(10 * time.Second): + t.Fatal(msg) + } +} + +func drainCancelError(t *testing.T, iter *AsyncIterator[*AgentEvent]) *CancelError { + t.Helper() + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil { + errors.As(event.Err, &cancelErr) + } + } + return cancelErr +} + +func drainResumeErrors(t *testing.T, iter *AsyncIterator[*AgentEvent]) []error { + t.Helper() + var errs []error + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil { + errs = append(errs, event.Err) + } + } + return errs +} + +type cancelResult struct { + err error + contributed bool + done chan struct{} +} + +func cancelAsync(cancelFn AgentCancelFunc, opts ...AgentCancelOption) (cancelCalled chan struct{}, result *cancelResult) { + cancelCalled = make(chan struct{}) + result = &cancelResult{done: make(chan struct{})} + go func() { + handle, contributed := cancelFn(opts...) + result.contributed = contributed + close(cancelCalled) + result.err = handle.Wait() + close(result.done) + }() + return +} + +func (r *cancelResult) waitDone(t *testing.T) error { + t.Helper() + select { + case <-r.done: + return r.err + case <-time.After(10 * time.Second): + t.Fatal("cancel did not complete") + return nil + } +} + +func TestCancelImmediate_AgentTool_PreservesChildCheckpoint(t *testing.T) { + ctx := context.Background() + + leafModel := newTestChatModel( + &schema.Message{Role: schema.Assistant, Content: "leaf response"}, 2*time.Second) + leafAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "leaf_agent", Description: "Leaf agent in agentTool", Model: leafModel, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "inner_seq", Description: "Inner sequential workflow", SubAgents: []Agent{leafAgent}, + }) + assert.NoError(t, err) + + rootModel := newTestChatModel(newToolCallResponse("inner_seq"), 0) + rootAgent, err := newAgentWithTool(t, ctx, "root_agent", rootModel, seqAgent) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{Agent: rootAgent, CheckPointStore: store}) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt, WithCheckPointID("immediate-agent-tool-1")) + + waitForChan(t, leafModel.startedChan, "Leaf agent model did not start") + + handle, contributed := cancelFn() + assert.True(t, contributed) + assert.NoError(t, handle.Wait()) + + cancelErr := drainCancelError(t, iter) + assert.NotNil(t, cancelErr, "Should have CancelError from CancelImmediate through agentTool") + assert.NotEmpty(t, cancelErr.CheckPointID) + assert.NotNil(t, cancelErr.interruptSignal) + + resumeLeaf, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "leaf_agent", Description: "Leaf agent in agentTool", + Model: newTestChatModel(&schema.Message{Role: schema.Assistant, Content: "resumed leaf"}, 0), + }) + assert.NoError(t, err) + resumeSeq, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "inner_seq", Description: "Inner sequential workflow", SubAgents: []Agent{resumeLeaf}, + }) + assert.NoError(t, err) + resumeRoot, err := newAgentWithTool(t, ctx, "root_agent", + newTestChatModel(&schema.Message{Role: schema.Assistant, Content: "resumed root"}, 0), resumeSeq) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{Agent: resumeRoot, CheckPointStore: store}) + resumeIter, err := runner2.Resume(ctx, "immediate-agent-tool-1") + assert.NoError(t, err) + assert.Empty(t, drainResumeErrors(t, resumeIter), "Resume should complete without errors") +} + +func TestCancelImmediate_ParallelWorkflow_WithAgentTool(t *testing.T) { + ctx := context.Background() + + leafModel := newTestChatModel( + &schema.Message{Role: schema.Assistant, Content: "leaf response"}, 2*time.Second) + leafAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "leaf_agent", Description: "Leaf agent in agentTool", Model: leafModel, + }) + assert.NoError(t, err) + + agentWithTool, err := newAgentWithTool(t, ctx, "agent_with_tool", + newTestChatModel(newToolCallResponse("leaf_agent"), 0), leafAgent) + assert.NoError(t, err) + + simpleStarted := make(chan struct{}, 1) + simpleAgent := newCancelTestAgent(t, "simple_agent", 2*time.Second, simpleStarted) + + parAgent, err := NewParallelAgent(ctx, &ParallelAgentConfig{ + Name: "par_agent", Description: "Parallel with agentTool and simple agent", + SubAgents: []Agent{agentWithTool, simpleAgent}, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{Agent: parAgent, EnableStreaming: false}) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + waitForChan(t, leafModel.startedChan, "Leaf agent did not start") + waitForChan(t, simpleStarted, "Simple agent did not start") + + start := time.Now() + handle, _ := cancelFn() + assert.NoError(t, handle.Wait()) + + cancelErr := drainCancelError(t, iter) + elapsed := time.Since(start) + + assert.NotNil(t, cancelErr, "Should have CancelError from parallel with agentTool") + assert.True(t, elapsed < 5*time.Second, "Should complete quickly after cancel, elapsed: %v", elapsed) +} + +type cancelUnawareAgent struct { + name string + desc string + delay time.Duration + response string +} + +type multiResponseGatedModel struct { + responses []*schema.Message + gateChan chan struct{} + gateOnce bool + gated int32 + doneChan chan struct{} + callCount int32 +} + +func (m *multiResponseGatedModel) Generate(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { + idx := atomic.AddInt32(&m.callCount, 1) + if m.gateChan != nil && (!m.gateOnce || atomic.CompareAndSwapInt32(&m.gated, 0, 1)) { + select { + case <-m.gateChan: + case <-ctx.Done(): + return nil, ctx.Err() + } + } + if len(m.responses) == 0 { + return nil, fmt.Errorf("multiResponseGatedModel: no responses configured") + } + resp := m.responses[(int(idx)-1)%len(m.responses)] + if m.doneChan != nil { + select { + case m.doneChan <- struct{}{}: + default: + } + } + return resp, nil +} + +func (m *multiResponseGatedModel) Stream(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + resp, err := m.Generate(ctx, msgs, opts...) + if err != nil { + return nil, err + } + return schema.StreamReaderFromArray([]*schema.Message{resp}), nil +} + +func (m *multiResponseGatedModel) BindTools(tools []*schema.ToolInfo) error { return nil } + +func (a *cancelUnawareAgent) Name(_ context.Context) string { return a.name } +func (a *cancelUnawareAgent) Description(_ context.Context) string { return a.desc } + +func (a *cancelUnawareAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, gen := NewAsyncIteratorPair[*AgentEvent]() + go func() { + defer gen.Close() + // Intentionally ignores ctx.Done() — simulates a custom agent that + // does not participate in the cancel protocol at all. + // Delay is kept short (relative to grace period) to avoid goroutine + // leak lasting long after the test completes. + time.Sleep(a.delay) + }() + return iter +} + +func TestCancelImmediate_CustomAgent_GracePeriodFallback(t *testing.T) { + ctx := context.Background() + + customAgent := &cancelUnawareAgent{ + name: "custom_slow", desc: "A custom agent that ignores cancel", + delay: 5 * time.Second, response: "custom response", + } + + rootModel := newTestChatModel(newToolCallResponse("custom_slow"), 0) + rootAgent, err := newAgentWithTool(t, ctx, "root_agent", rootModel, customAgent) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{Agent: rootAgent, EnableStreaming: false}) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + waitForChan(t, rootModel.startedChan, "Root model did not start") + waitForChan(t, rootModel.doneChan, "Root model did not finish") + + start := time.Now() + handle, _ := cancelFn() + assert.NoError(t, handle.Wait()) + + cancelErr := drainCancelError(t, iter) + elapsed := time.Since(start) + + assert.NotNil(t, cancelErr, "Should have CancelError (from grace period fallback)") + assert.True(t, elapsed < 5*time.Second, + "Should complete within grace period + overhead, elapsed: %v", elapsed) +} + +func TestCancelImmediate_MultiLevelNesting(t *testing.T) { + ctx := context.Background() + + innerLeafModel := newTestChatModel( + &schema.Message{Role: schema.Assistant, Content: "inner leaf response"}, 2*time.Second) + innerLeafAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "inner_leaf", Description: "Innermost leaf agent", Model: innerLeafModel, + }) + assert.NoError(t, err) + + middleAgent, err := newAgentWithTool(t, ctx, "middle_agent", + newTestChatModel(newToolCallResponse("inner_leaf"), 0), innerLeafAgent) + assert.NoError(t, err) + + rootAgent, err := newAgentWithTool(t, ctx, "root_agent", + newTestChatModel(newToolCallResponse("middle_agent"), 0), middleAgent) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{Agent: rootAgent, CheckPointStore: store}) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt, WithCheckPointID("multi-level-1")) + + waitForChan(t, innerLeafModel.startedChan, "Inner leaf model did not start") + + start := time.Now() + handle, contributed := cancelFn() + assert.True(t, contributed) + assert.NoError(t, handle.Wait()) + + cancelErr := drainCancelError(t, iter) + elapsed := time.Since(start) + + assert.NotNil(t, cancelErr, "Should have CancelError from multi-level nesting") + assert.NotEmpty(t, cancelErr.CheckPointID) + assert.NotNil(t, cancelErr.interruptSignal) + assert.True(t, elapsed < 5*time.Second, "Should complete quickly, elapsed: %v", elapsed) + + resumeInnerLeaf, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "inner_leaf", Description: "Innermost leaf agent", + Model: newTestChatModel(&schema.Message{Role: schema.Assistant, Content: "resumed inner leaf"}, 0), + }) + assert.NoError(t, err) + resumeMiddle, err := newAgentWithTool(t, ctx, "middle_agent", + newTestChatModel(&schema.Message{Role: schema.Assistant, Content: "resumed middle"}, 0), resumeInnerLeaf) + assert.NoError(t, err) + resumeRoot, err := newAgentWithTool(t, ctx, "root_agent", + newTestChatModel(&schema.Message{Role: schema.Assistant, Content: "resumed root"}, 0), resumeMiddle) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{Agent: resumeRoot, CheckPointStore: store}) + resumeIter, err := runner2.Resume(ctx, "multi-level-1") + assert.NoError(t, err) + assert.Empty(t, drainResumeErrors(t, resumeIter), "Resume should complete without errors") +} + +func TestCancelImmediate_SequentialTransitionBoundary(t *testing.T) { + ctx := context.Background() + + model1 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent1 done"}, + gateChan: make(chan struct{}), + doneChan: make(chan struct{}, 1), + } + model2 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent2 done"}, + doneChan: make(chan struct{}, 1), + } + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent1", Description: "first", Instruction: "test", Model: model1, + }) + assert.NoError(t, err) + + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent2", Description: "second", Instruction: "test", Model: model2, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq", Description: "sequential test", SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hello")}, cancelOpt) + + for atomic.LoadInt32(&model1.callCount) == 0 { + runtime.Gosched() + } + + cancelCalled, result := cancelAsync(cancelFn) + waitForChan(t, cancelCalled, "cancelFn was not called") + close(model1.gateChan) + + assert.NoError(t, result.waitDone(t), "CancelImmediate should succeed at transition") + + cancelErr := drainCancelError(t, iter) + + assert.NotNil(t, cancelErr, "Should have CancelError at transition boundary") + assert.Equal(t, int32(1), atomic.LoadInt32(&model1.callCount), "Agent1 model should be invoked") + assert.Equal(t, int32(0), atomic.LoadInt32(&model2.callCount), "Agent2 model should NOT be invoked (caught at transition)") +} + +func TestCancelImmediate_LoopTransitionBoundary(t *testing.T) { + ctx := context.Background() + + mdl := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "loop iter"}, + gateChan: make(chan struct{}), + doneChan: make(chan struct{}, 10), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "loop_inner", Description: "inner", Instruction: "test", Model: mdl, + }) + assert.NoError(t, err) + + loopAgent, err := NewLoopAgent(ctx, &LoopAgentConfig{ + Name: "loop", Description: "loop test", SubAgents: []Agent{agent}, MaxIterations: 5, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: loopAgent, EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hello")}, cancelOpt) + + for atomic.LoadInt32(&mdl.callCount) == 0 { + runtime.Gosched() + } + + cancelCalled, result := cancelAsync(cancelFn) + waitForChan(t, cancelCalled, "cancelFn was not called") + close(mdl.gateChan) + + assert.NoError(t, result.waitDone(t), "CancelImmediate should succeed at loop transition") + + for { + _, ok := iter.Next() + if !ok { + break + } + } + + assert.Equal(t, int32(1), atomic.LoadInt32(&mdl.callCount), + "Model should be called once; second iteration caught at transition") +} + +func TestCancelAfterChatModel_SequentialTransitionBoundary(t *testing.T) { + ctx := context.Background() + + model1 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent1 done"}, + gateChan: make(chan struct{}), + doneChan: make(chan struct{}, 1), + } + model2 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent2 done"}, + doneChan: make(chan struct{}, 1), + } + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent1", Description: "first", Instruction: "test", Model: model1, + }) + assert.NoError(t, err) + + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent2", Description: "second", Instruction: "test", Model: model2, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq", Description: "sequential test", SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, + EnableStreaming: false, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hello")}, cancelOpt, WithCheckPointID("chatmodel-transition-1")) + + for atomic.LoadInt32(&model1.callCount) == 0 { + runtime.Gosched() + } + + cancelCalled, result := cancelAsync(cancelFn, WithAgentCancelMode(CancelAfterChatModel)) + waitForChan(t, cancelCalled, "cancelFn was not called") + close(model1.gateChan) + + assert.NoError(t, result.waitDone(t), "CancelAfterChatModel should succeed at transition boundary") + + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + cancelErr = ce + } + } + + assert.NotNil(t, cancelErr, "Should have CancelError at transition boundary") + assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) + assert.Equal(t, int32(1), atomic.LoadInt32(&model1.callCount), "Agent1 model should be invoked") + assert.Equal(t, int32(0), atomic.LoadInt32(&model2.callCount), + "Agent2 model should NOT be invoked (CancelAfterChatModel caught at transition)") +} + +func TestCancelAfterChatModel_Sequential_Agent1CompletesCancelBeforeAgent2Resume(t *testing.T) { + ctx := context.Background() + + model1 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent1 done"}, + gateChan: make(chan struct{}), + doneChan: make(chan struct{}, 1), + } + model2 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent2 done"}, + doneChan: make(chan struct{}, 1), + } + model3 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent3 done"}, + doneChan: make(chan struct{}, 1), + } + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent1", Description: "first", Instruction: "test", Model: model1, + }) + assert.NoError(t, err) + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent2", Description: "second", Instruction: "test", Model: model2, + }) + assert.NoError(t, err) + agent3, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent3", Description: "third", Instruction: "test", Model: model3, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq", Description: "3-agent sequential", SubAgents: []Agent{agent1, agent2, agent3}, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, CheckPointStore: store, EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hello")}, cancelOpt, + WithCheckPointID("seq-transition-resume-1")) + + for atomic.LoadInt32(&model1.callCount) == 0 { + runtime.Gosched() + } + + cancelCalled, result := cancelAsync(cancelFn, WithAgentCancelMode(CancelAfterChatModel)) + waitForChan(t, cancelCalled, "cancelFn was not called") + close(model1.gateChan) + + assert.NoError(t, result.waitDone(t)) + + cancelErr := drainCancelError(t, iter) + assert.NotNil(t, cancelErr, "Should have CancelError") + assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) + assert.Equal(t, int32(1), atomic.LoadInt32(&model1.callCount)) + assert.Equal(t, int32(0), atomic.LoadInt32(&model2.callCount), + "Agent2 should NOT run (cancel caught at transition after agent1)") + assert.Equal(t, int32(0), atomic.LoadInt32(&model3.callCount)) + assert.NotEmpty(t, cancelErr.CheckPointID) + + resumeModel2 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "resumed agent2"}, + doneChan: make(chan struct{}, 1), + } + resumeModel3 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "resumed agent3"}, + doneChan: make(chan struct{}, 1), + } + + resumeAgent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent1", Description: "first", Instruction: "test", + Model: &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "should not run"}, + doneChan: make(chan struct{}, 1), + }, + }) + assert.NoError(t, err) + resumeAgent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent2", Description: "second", Instruction: "test", Model: resumeModel2, + }) + assert.NoError(t, err) + resumeAgent3, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent3", Description: "third", Instruction: "test", Model: resumeModel3, + }) + assert.NoError(t, err) + + resumeSeq, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq", Description: "3-agent sequential", + SubAgents: []Agent{resumeAgent1, resumeAgent2, resumeAgent3}, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: resumeSeq, CheckPointStore: store, EnableStreaming: false, + }) + resumeIter, err := runner2.Resume(ctx, "seq-transition-resume-1") + assert.NoError(t, err) + assert.Empty(t, drainResumeErrors(t, resumeIter), "Resume should complete without errors") + + assert.Equal(t, int32(1), atomic.LoadInt32(&resumeModel2.callCount), + "Agent2 should run on resume") + assert.Equal(t, int32(1), atomic.LoadInt32(&resumeModel3.callCount), + "Agent3 should run on resume") +} + +func TestCancelAfterToolCalls_LoopTransitionBoundary(t *testing.T) { + ctx := context.Background() + + // Model that returns tool calls on odd calls and no tools on even calls. + // This completes one ReAct cycle per pair of calls: + // call 1 (gated): returns tool call → tool runs → call 2: returns no tools → END + // The gate only blocks the very first call. After that, all calls proceed instantly. + mdl := &multiResponseGatedModel{ + responses: []*schema.Message{ + {Role: schema.Assistant, ToolCalls: []schema.ToolCall{{ + ID: "call_1", Type: "function", + Function: schema.FunctionCall{Name: "loop_tool", Arguments: `{"input": "test"}`}, + }}}, + {Role: schema.Assistant, Content: "iteration done"}, + }, + gateChan: make(chan struct{}), + gateOnce: true, + doneChan: make(chan struct{}, 10), + } + + st := &slowTool{ + name: "loop_tool", + delay: 10 * time.Millisecond, + result: "tool done", + startedChan: make(chan struct{}, 10), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "loop_inner", Description: "inner", Instruction: "test", Model: mdl, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + loopAgent, err := NewLoopAgent(ctx, &LoopAgentConfig{ + Name: "loop", Description: "loop test", SubAgents: []Agent{agent}, MaxIterations: 10, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{Agent: loopAgent, CheckPointStore: store}) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt, WithCheckPointID("toolcalls-loop-1")) + + // Wait for the model to be entered (blocked on gate) + for atomic.LoadInt32(&mdl.callCount) == 0 { + runtime.Gosched() + } + + // Fire cancel, wait for it to be registered, then release the gate + cancelCalled, result := cancelAsync(cancelFn, WithAgentCancelMode(CancelAfterToolCalls)) + waitForChan(t, cancelCalled, "cancelFn was not called") + close(mdl.gateChan) + + // Iteration 1 completes fully (model→tool→model-no-tools→END). + // The CancelAfterToolCalls safe-point inside ReAct fires after tool calls, + // OR the transition boundary catches it before iteration 2. + // Note: this test doesn't deterministically distinguish which path fires — + // both are semantically correct for CancelAfterToolCalls. The transition- + // boundary code path for CancelAfterToolCalls in loops is not definitively + // covered here because the ReAct safe-point may handle it first. + assert.NoError(t, result.waitDone(t)) + + cancelErr := drainCancelError(t, iter) + assert.NotNil(t, cancelErr, "Should have CancelError from CancelAfterToolCalls in loop") + assert.Equal(t, CancelAfterToolCalls, cancelErr.Info.Mode) + assert.NotEmpty(t, cancelErr.CheckPointID) +} + +func TestCancelContext_ActiveChildren_Tracking(t *testing.T) { + t.Run("DeriveChild_IncrementsActiveChildren", func(t *testing.T) { + parent := newCancelContext() + assert.False(t, parent.hasActiveChildren()) + + ctx := context.Background() + child := parent.deriveChild(ctx) + assert.True(t, parent.hasActiveChildren()) + assert.Equal(t, int32(1), atomic.LoadInt32(&parent.activeChildren)) + + child.markDone() + time.Sleep(10 * time.Millisecond) + assert.False(t, parent.hasActiveChildren()) + assert.Equal(t, int32(0), atomic.LoadInt32(&parent.activeChildren)) + }) + + t.Run("MultipleChildren_AllTracked", func(t *testing.T) { + parent := newCancelContext() + ctx := context.Background() + + child1 := parent.deriveChild(ctx) + child2 := parent.deriveChild(ctx) + assert.Equal(t, int32(2), atomic.LoadInt32(&parent.activeChildren)) + + child1.markDone() + time.Sleep(10 * time.Millisecond) + assert.Equal(t, int32(1), atomic.LoadInt32(&parent.activeChildren)) + assert.True(t, parent.hasActiveChildren()) + + child2.markDone() + time.Sleep(10 * time.Millisecond) + assert.False(t, parent.hasActiveChildren()) + }) + + t.Run("MarkCancelHandled_AlsoDecrementsParent", func(t *testing.T) { + parent := newCancelContext() + ctx := context.Background() + + child := parent.deriveChild(ctx) + assert.True(t, parent.hasActiveChildren()) + + child.triggerCancel(CancelImmediate) + child.markCancelHandled() + time.Sleep(10 * time.Millisecond) + assert.False(t, parent.hasActiveChildren()) + }) + + t.Run("GracePeriodWrapper_AppliesWhenChildrenActive", func(t *testing.T) { + parent := newCancelContext() + ctx := context.Background() + + var receivedOpts []compose.GraphInterruptOption + mockInterrupt := func(opts ...compose.GraphInterruptOption) { + receivedOpts = opts + } + + wrapped := parent.wrapGraphInterruptWithGracePeriod(mockInterrupt) + + // No children: no options appended + receivedOpts = nil + wrapped() + assert.Empty(t, receivedOpts, "Should pass no extra options when no children") + + // With active child: one timeout option appended + _ = parent.deriveChild(ctx) + receivedOpts = nil + wrapped() + assert.Len(t, receivedOpts, 1, "Should add exactly one timeout option when children are active") + + // Caller-provided options are preserved, grace period option appended after + receivedOpts = nil + callerOpt := compose.WithGraphInterruptTimeout(0) + wrapped(callerOpt) + assert.Len(t, receivedOpts, 2, + "Should append timeout option after caller-provided options when children are active") + // Note: verifying the exact timeout value (defaultCancelImmediateGracePeriod) + // requires access to unexported compose.graphInterruptOptions. The integration + // tests (TestCancelImmediate_AgentTool_PreservesChildCheckpoint) verify the + // actual behavioral effect — child interrupts propagate within the grace period. + }) +} + +func TestCancel_ParallelWorkflow_CancelAfterChatModel(t *testing.T) { + ctx := context.Background() + slowStarted := make(chan struct{}, 1) + + slowAgent := newCancelTestAgentWithTools(t, "par_slow", 1*time.Second, slowStarted) + fastAgent := newCancelTestAgentWithTools(t, "par_fast", 50*time.Millisecond, make(chan struct{}, 1)) + + parAgent, err := NewParallelAgent(ctx, &ParallelAgentConfig{ + Name: "par_agent", + Description: "Parallel workflow", + SubAgents: []Agent{slowAgent, fastAgent}, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: parAgent, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt, WithCheckPointID("par-cancel-1")) + + select { + case <-slowStarted: + case <-time.After(10 * time.Second): + t.Fatal("Slow agent did not start") + } + + handle, contributed := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + assert.True(t, contributed, "Cancel should contribute") + err = handle.Wait() + assert.NoError(t, err) + + hasCancelError := false + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil && errors.As(event.Err, &cancelErr) { + hasCancelError = true + } + } + + assert.True(t, hasCancelError, "Should have CancelError from parallel workflow") + assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) + assert.NotEmpty(t, cancelErr.CheckPointID, "CancelError should have checkpoint ID") + + resumeSlow := newCancelTestAgentWithToolsFinalAnswer(t, "par_slow") + resumeFast := newCancelTestAgentWithToolsFinalAnswer(t, "par_fast") + + resumePar, err := NewParallelAgent(ctx, &ParallelAgentConfig{ + Name: "par_agent", + Description: "Parallel workflow", + SubAgents: []Agent{resumeSlow, resumeFast}, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: resumePar, + CheckPointStore: store, + }) + + resumeIter, err := runner2.Resume(ctx, "par-cancel-1") + assert.NoError(t, err) + assert.NotNil(t, resumeIter) + + var resumeErrors []error + for { + event, ok := resumeIter.Next() + if !ok { + break + } + if event.Err != nil { + resumeErrors = append(resumeErrors, event.Err) + } + } + assert.Empty(t, resumeErrors, "Resume should complete without errors") +} + +func TestCancel_LoopWorkflow_CancelAfterChatModel(t *testing.T) { + ctx := context.Background() + modelStarted := make(chan struct{}, 10) + + agent := newCancelTestAgentWithTools(t, "loop_inner", 500*time.Millisecond, modelStarted) + + loopAgent, err := NewLoopAgent(ctx, &LoopAgentConfig{ + Name: "loop_agent", + Description: "Loop workflow", + SubAgents: []Agent{agent}, + MaxIterations: 10, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: loopAgent, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt, WithCheckPointID("loop-cancel-1")) + + select { + case <-modelStarted: + case <-time.After(10 * time.Second): + t.Fatal("Model did not start") + } + + handle, contributed := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + assert.True(t, contributed, "Cancel should contribute") + err = handle.Wait() + assert.NoError(t, err) + + hasCancelError := false + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil && errors.As(event.Err, &cancelErr) { + hasCancelError = true + } + } + + assert.True(t, hasCancelError, "Should have CancelError from loop workflow") + assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) + assert.NotEmpty(t, cancelErr.CheckPointID, "CancelError should have checkpoint ID") + + resumeAgent := newCancelTestAgentWithToolsFinalAnswer(t, "loop_inner") + + resumeLoop, err := NewLoopAgent(ctx, &LoopAgentConfig{ + Name: "loop_agent", + Description: "Loop workflow", + SubAgents: []Agent{resumeAgent}, + MaxIterations: 10, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: resumeLoop, + CheckPointStore: store, + }) + + resumeIter, err := runner2.Resume(ctx, "loop-cancel-1") + assert.NoError(t, err) + assert.NotNil(t, resumeIter) + + var resumeEvents []*AgentEvent + for { + event, ok := resumeIter.Next() + if !ok { + break + } + assert.Nil(t, event.Err, "Should not have error during resume") + resumeEvents = append(resumeEvents, event) + } + assert.True(t, len(resumeEvents) > 0, "Resume should produce events") +} + +func TestCancel_NestedWorkflow_AgentTool_CancelAfterChatModel(t *testing.T) { + // Structure: Runner -> RootCMA (with tools) -> agentTool -> flowAgent -> seqWorkflow -> LeafCMA + ctx := context.Background() + leafStarted := make(chan struct{}, 1) + + leafModel := &cancelTestChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "leaf response", + }, + startedChan: leafStarted, + doneChan: make(chan struct{}, 1), + } + leafModel.setDelay(500 * time.Millisecond) + + leafAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "leaf_agent", + Description: "Leaf agent in workflow", + Model: leafModel, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "inner_seq", + Description: "Inner sequential workflow", + SubAgents: []Agent{leafAgent}, + }) + assert.NoError(t, err) + + rootModel := &cancelTestChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "inner_seq", + Arguments: `{}`, + }, + }, + }, + }, + startedChan: make(chan struct{}, 1), + doneChan: make(chan struct{}, 1), + } + rootAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "root_agent", + Description: "Root agent", + Model: rootModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{NewAgentTool(ctx, seqAgent)}, + }, + }, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: rootAgent, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt, WithCheckPointID("nested-cancel-1")) + + select { + case <-leafStarted: + case <-time.After(10 * time.Second): + t.Fatal("Leaf agent model did not start") + } + + handle, contributed := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + assert.True(t, contributed, "Cancel should contribute") + err = handle.Wait() + assert.NoError(t, err) + + hasCancelError := false + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil && errors.As(event.Err, &cancelErr) { + hasCancelError = true + } + } + + assert.True(t, hasCancelError, "Should have CancelError from deeply nested workflow") + assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) + assert.NotEmpty(t, cancelErr.CheckPointID, "CancelError should have checkpoint ID") + assert.NotNil(t, cancelErr.interruptSignal, "CancelError should carry interrupt signal through agent tree") + + // Phase 2: Resume from checkpoint — new instances to avoid data races + resumeLeafModel := &cancelTestChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "resumed leaf response", + }, + startedChan: make(chan struct{}, 1), + doneChan: make(chan struct{}, 1), + } + resumeLeaf, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "leaf_agent", + Description: "Leaf agent in workflow", + Model: resumeLeafModel, + }) + assert.NoError(t, err) + + resumeSeq, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "inner_seq", + Description: "Inner sequential workflow", + SubAgents: []Agent{resumeLeaf}, + }) + assert.NoError(t, err) + + resumeRootModel := &cancelTestChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "resumed root response", + }, + startedChan: make(chan struct{}, 1), + doneChan: make(chan struct{}, 1), + } + resumeRoot, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "root_agent", + Description: "Root agent", + Model: resumeRootModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{NewAgentTool(ctx, resumeSeq)}, + }, + }, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: resumeRoot, + CheckPointStore: store, + }) + + resumeIter, err := runner2.Resume(ctx, "nested-cancel-1") + assert.NoError(t, err) + assert.NotNil(t, resumeIter) + + var resumeErrors []error + for { + event, ok := resumeIter.Next() + if !ok { + break + } + if event.Err != nil { + resumeErrors = append(resumeErrors, event.Err) + } + } + assert.Empty(t, resumeErrors, "Resume should complete without errors") +} + +func TestCancel_CancelAfterToolCalls_InSequentialWorkflow(t *testing.T) { + ctx := context.Background() + toolStarted := make(chan struct{}, 1) + + st := &slowTool{ + name: "slow_tool", + delay: 200 * time.Millisecond, + result: "tool done", + startedChan: toolStarted, + } + + modelWithToolCall := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + } + + agentWithTools, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent_with_tools", + Description: "Agent with slow tool", + Model: modelWithToolCall, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq_agent", + Description: "Sequential workflow with tool agent", + SubAgents: []Agent{agentWithTools}, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt, WithCheckPointID("tool-cancel-1")) + + select { + case <-toolStarted: + case <-time.After(10 * time.Second): + t.Fatal("Tool did not start") + } + + // Cancel after tool calls — should wait for the tool to finish, then cancel + handle, contributed := cancelFn(WithAgentCancelMode(CancelAfterToolCalls)) + assert.True(t, contributed, "Cancel should contribute") + err = handle.Wait() + assert.NoError(t, err) + + hasCancelError := false + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil && errors.As(event.Err, &cancelErr) { + hasCancelError = true + } + } + + assert.True(t, hasCancelError, "Should have CancelError after tool calls complete") + assert.Equal(t, CancelAfterToolCalls, cancelErr.Info.Mode) + assert.NotEmpty(t, cancelErr.CheckPointID, "CancelError should have checkpoint ID") + + // Phase 2: Resume from checkpoint — new instances + resumeTool := &slowTool{ + name: "slow_tool", + delay: 50 * time.Millisecond, + result: "resumed tool done", + startedChan: make(chan struct{}, 1), + } + + resumeModel := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "resumed response after tool", + }, + } + + resumeAgentWithTools, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent_with_tools", + Description: "Agent with slow tool", + Model: resumeModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{resumeTool}, + }, + }, + }) + assert.NoError(t, err) + + resumeSeq, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq_agent", + Description: "Sequential workflow with tool agent", + SubAgents: []Agent{resumeAgentWithTools}, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: resumeSeq, + CheckPointStore: store, + }) + + resumeIter, err := runner2.Resume(ctx, "tool-cancel-1") + assert.NoError(t, err) + assert.NotNil(t, resumeIter) + + var resumeEvents []*AgentEvent + for { + event, ok := resumeIter.Next() + if !ok { + break + } + assert.Nil(t, event.Err, "Should not have error during resume") + resumeEvents = append(resumeEvents, event) + } + assert.True(t, len(resumeEvents) > 0, "Resume should produce events") } diff --git a/adk/cancel_wrapper.go b/adk/cancel_wrapper.go deleted file mode 100644 index 396d6a458..000000000 --- a/adk/cancel_wrapper.go +++ /dev/null @@ -1,295 +0,0 @@ -/* - * Copyright 2026 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package adk - -import ( - "context" - "io" - "reflect" - "runtime/debug" - "time" - - "github.com/cloudwego/eino/components" - "github.com/cloudwego/eino/components/model" - "github.com/cloudwego/eino/compose" - "github.com/cloudwego/eino/internal/generic" - "github.com/cloudwego/eino/internal/safe" - "github.com/cloudwego/eino/schema" -) - -type cancelWaitResult[T any] struct { - result T - err error - cancelled bool -} - -func waitWithCancel[T any](cs *cancelSig, resultCh <-chan cancelWaitResult[T]) cancelWaitResult[T] { - var timeCh <-chan time.Time - select { - case <-cs.done: - cfg := cs.config.Load().(*cancelConfig) - if cfg.Mode == CancelImmediate { - if cfg.Timeout == nil { - return cancelWaitResult[T]{cancelled: true} - } - timeCh = time.After(*cfg.Timeout) - } - case res := <-resultCh: - return res - } - select { - case <-timeCh: - return cancelWaitResult[T]{cancelled: true} - case res := <-resultCh: - return res - } -} - -type cancelableChatModel struct { - inner model.BaseChatModel - cs *cancelSig -} - -func wrapModelForCancelable(m model.BaseChatModel, cs *cancelSig) *cancelableChatModel { - return &cancelableChatModel{inner: m, cs: cs} -} - -func (c *cancelableChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { - if cfg := checkCancelSig(c.cs); cfg != nil && cfg.Mode == CancelImmediate { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - - resultCh := make(chan cancelWaitResult[*schema.Message], 1) - go func() { - defer func() { - if panicErr := recover(); panicErr != nil { - resultCh <- cancelWaitResult[*schema.Message]{err: safe.NewPanicErr(panicErr, debug.Stack())} - } - }() - res, err := c.inner.Generate(ctx, input, opts...) - resultCh <- cancelWaitResult[*schema.Message]{result: res, err: err} - }() - - res := waitWithCancel(c.cs, resultCh) - if res.cancelled { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - return res.result, res.err -} - -func (c *cancelableChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { - if cfg := checkCancelSig(c.cs); cfg != nil && cfg.Mode == CancelImmediate { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - - resultCh := make(chan cancelWaitResult[*schema.StreamReader[*schema.Message]], 1) - go func() { - defer func() { - if panicErr := recover(); panicErr != nil { - resultCh <- cancelWaitResult[*schema.StreamReader[*schema.Message]]{err: safe.NewPanicErr(panicErr, debug.Stack())} - } - }() - - stream, err := c.inner.Stream(ctx, input, opts...) - if err != nil { - resultCh <- cancelWaitResult[*schema.StreamReader[*schema.Message]]{err: err} - return - } - copies := stream.Copy(2) - _ = consumeStreamForError(copies[0]) - resultCh <- cancelWaitResult[*schema.StreamReader[*schema.Message]]{result: copies[1]} - }() - - res := waitWithCancel(c.cs, resultCh) - if res.cancelled { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - return res.result, res.err -} - -func (c *cancelableChatModel) IsCallbacksEnabled() bool { - return components.IsCallbacksEnabled(c.inner) -} - -func (c *cancelableChatModel) GetType() string { - if name, ok := components.GetType(c.inner); ok { - return name - } - - return generic.ParseTypeName(reflect.ValueOf(c.inner)) -} - -func cancelableToolInvokable(cs *cancelSig, endpoint compose.InvokableToolEndpoint) compose.InvokableToolEndpoint { - return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { - if cfg := checkCancelSig(cs); cfg != nil && cfg.Mode == CancelImmediate { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - - resultCh := make(chan cancelWaitResult[*compose.ToolOutput], 1) - go func() { - defer func() { - if panicErr := recover(); panicErr != nil { - resultCh <- cancelWaitResult[*compose.ToolOutput]{err: safe.NewPanicErr(panicErr, debug.Stack())} - } - }() - output, err := endpoint(ctx, input) - resultCh <- cancelWaitResult[*compose.ToolOutput]{result: output, err: err} - }() - - res := waitWithCancel(cs, resultCh) - if res.cancelled { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - return res.result, res.err - } -} - -func cancelableToolStreamable(cs *cancelSig, endpoint compose.StreamableToolEndpoint) compose.StreamableToolEndpoint { - return func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { - if cfg := checkCancelSig(cs); cfg != nil && cfg.Mode == CancelImmediate { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - - resultCh := make(chan cancelWaitResult[*schema.StreamReader[string]], 1) - go func() { - defer func() { - if panicErr := recover(); panicErr != nil { - resultCh <- cancelWaitResult[*schema.StreamReader[string]]{err: safe.NewPanicErr(panicErr, debug.Stack())} - } - }() - output, err := endpoint(ctx, input) - if err != nil { - resultCh <- cancelWaitResult[*schema.StreamReader[string]]{err: err} - return - } - copies := output.Result.Copy(2) - _ = consumeStreamForErrorString(copies[0]) - resultCh <- cancelWaitResult[*schema.StreamReader[string]]{result: copies[1]} - }() - - res := waitWithCancel(cs, resultCh) - if res.cancelled { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - if res.err != nil { - return nil, res.err - } - return &compose.StreamToolOutput{Result: res.result}, nil - } -} - -func cancelableToolEnhancedInvokable(cs *cancelSig, endpoint compose.EnhancedInvokableToolEndpoint) compose.EnhancedInvokableToolEndpoint { - return func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedInvokableToolOutput, error) { - if cfg := checkCancelSig(cs); cfg != nil && cfg.Mode == CancelImmediate { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - - resultCh := make(chan cancelWaitResult[*compose.EnhancedInvokableToolOutput], 1) - go func() { - defer func() { - if panicErr := recover(); panicErr != nil { - resultCh <- cancelWaitResult[*compose.EnhancedInvokableToolOutput]{err: safe.NewPanicErr(panicErr, debug.Stack())} - } - }() - output, err := endpoint(ctx, input) - resultCh <- cancelWaitResult[*compose.EnhancedInvokableToolOutput]{result: output, err: err} - }() - - res := waitWithCancel(cs, resultCh) - if res.cancelled { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - return res.result, res.err - } -} - -func cancelableToolEnhancedStreamable(cs *cancelSig, endpoint compose.EnhancedStreamableToolEndpoint) compose.EnhancedStreamableToolEndpoint { - return func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { - if cfg := checkCancelSig(cs); cfg != nil && cfg.Mode == CancelImmediate { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - - resultCh := make(chan cancelWaitResult[*schema.StreamReader[*schema.ToolResult]], 1) - go func() { - defer func() { - if panicErr := recover(); panicErr != nil { - resultCh <- cancelWaitResult[*schema.StreamReader[*schema.ToolResult]]{err: safe.NewPanicErr(panicErr, debug.Stack())} - } - }() - output, err := endpoint(ctx, input) - if err != nil { - resultCh <- cancelWaitResult[*schema.StreamReader[*schema.ToolResult]]{err: err} - return - } - copies := output.Result.Copy(2) - _ = consumeStreamForErrorToolResult(copies[0]) - resultCh <- cancelWaitResult[*schema.StreamReader[*schema.ToolResult]]{result: copies[1]} - }() - - res := waitWithCancel(cs, resultCh) - if res.cancelled { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - if res.err != nil { - return nil, res.err - } - return &compose.EnhancedStreamableToolOutput{Result: res.result}, nil - } -} - -func cancelableTool(cs *cancelSig) compose.ToolMiddleware { - return compose.ToolMiddleware{ - Invokable: func(endpoint compose.InvokableToolEndpoint) compose.InvokableToolEndpoint { - return cancelableToolInvokable(cs, endpoint) - }, - Streamable: func(endpoint compose.StreamableToolEndpoint) compose.StreamableToolEndpoint { - return cancelableToolStreamable(cs, endpoint) - }, - EnhancedInvokable: func(endpoint compose.EnhancedInvokableToolEndpoint) compose.EnhancedInvokableToolEndpoint { - return cancelableToolEnhancedInvokable(cs, endpoint) - }, - EnhancedStreamable: func(endpoint compose.EnhancedStreamableToolEndpoint) compose.EnhancedStreamableToolEndpoint { - return cancelableToolEnhancedStreamable(cs, endpoint) - }, - } -} - -func consumeStreamForErrorString(stream *schema.StreamReader[string]) error { - defer stream.Close() - for { - _, err := stream.Recv() - if err == io.EOF { - return nil - } - if err != nil { - return err - } - } -} - -func consumeStreamForErrorToolResult(stream *schema.StreamReader[*schema.ToolResult]) error { - defer stream.Close() - for { - _, err := stream.Recv() - if err == io.EOF { - return nil - } - if err != nil { - return err - } - } -} diff --git a/adk/chatmodel.go b/adk/chatmodel.go index 0848a859a..a0126455a 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -39,8 +39,6 @@ import ( ) var _ ResumableAgent = &ChatModelAgent{} -var _ CancellableAgent = &ChatModelAgent{} -var _ CancellableResumableAgent = &ChatModelAgent{} type chatModelAgentExecCtx struct { runtimeReturnDirectly map[string]bool @@ -345,8 +343,19 @@ type ChatModelAgent struct { exeCtx *execContext } -type runFunc func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], - store *bridgeStore, instruction string, returnDirectly map[string]bool, cs *cancelSig, opts ...compose.Option) +// runParams holds the parameters for a runFunc invocation. +type runParams struct { + input *AgentInput + generator *AsyncGenerator[*AgentEvent] + store *bridgeStore + instruction string + returnDirectly map[string]bool + cancelCtx *cancelContext + cancelCtxOwned bool + composeOpts []compose.Option +} + +type runFunc func(ctx context.Context, p *runParams) // NewChatModelAgent constructs a chat model-backed agent with the provided config. func NewChatModelAgent(ctx context.Context, config *ChatModelAgentConfig) (*ChatModelAgent, error) { @@ -378,6 +387,16 @@ func NewChatModelAgent(ctx context.Context, config *ChatModelAgentConfig) (*Chat ) tc.ToolCallMiddlewares = append(tc.ToolCallMiddlewares, collectToolMiddlewaresFromMiddlewares(config.Middlewares)...) + // Cancel monitoring middleware (innermost — close to the tool endpoint). + // This allows early abort of the raw tool result stream when immediateChan fires + // (CancelImmediate or timeout escalation), while requiring outer wrappers to + // propagate stream errors such as ErrStreamCanceled without swallowing them. + cancelToolHandler := &cancelMonitoredToolHandler{} + tc.ToolCallMiddlewares = append(tc.ToolCallMiddlewares, compose.ToolMiddleware{ + Streamable: cancelToolHandler.WrapStreamableToolCall, + EnhancedStreamable: cancelToolHandler.WrapEnhancedStreamableToolCall, + }) + return &ChatModelAgent{ name: config.Name, description: config.Description, @@ -575,8 +594,8 @@ func setOutputToSession(ctx context.Context, msg Message, msgStream MessageStrea } func errFunc(err error) runFunc { - return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, _ string, _ map[string]bool, _ *cancelSig, _ ...compose.Option) { - generator.Send(&AgentEvent{Err: err}) + return func(ctx context.Context, p *runParams) { + p.generator.Send(&AgentEvent{Err: err}) } } @@ -696,25 +715,74 @@ func (a *ChatModelAgent) prepareExecContext(ctx context.Context) (*execContext, }, nil } +// handleRunFuncError is the common error handler for buildNoToolsRunFunc and buildReActRunFunc. +// It handles compose interrupts (both cancel-triggered and business) +// and generic errors, sending the appropriate event to the generator. +func (a *ChatModelAgent) handleRunFuncError( + ctx context.Context, + err error, + cancelCtx *cancelContext, + cancelCtxOwned bool, + store *bridgeStore, + generator *AsyncGenerator[*AgentEvent], +) { + info, ok := compose.ExtractInterruptInfo(err) + if ok { + if cancelCtx != nil { + // Note: there is a benign TOCTOU window here. Between shouldCancel() + // returning false and markDone() executing, a concurrent cancel could + // transition stateRunning→stateCancelling. markDone() then does + // stateCancelling→stateDone, and the cancel func receives + // ErrExecutionCompleted (execution finished before cancel took effect). + if !cancelCtx.shouldCancel() { + cancelCtx.markDone() + } + } + + data, existed, sErr := store.Get(ctx, bridgeCheckpointID) + if sErr != nil { + generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("failed to get interrupt info: %w", sErr)}) + return + } + if !existed { + generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("interrupt occurred but checkpoint data is missing")}) + return + } + + is := FromInterruptContexts(info.InterruptContexts) + event := CompositeInterrupt(ctx, info, data, is) + event.Action.Interrupted.Data = &ChatModelAgentInterruptInfo{ + Info: info, + Data: data, + } + event.AgentName = a.name + generator.Send(event) + return + } + + if cancelCtxOwned && cancelCtx != nil { + cancelCtx.markDone() + } + generator.Send(&AgentEvent{Err: err}) +} + func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { type noToolsInput struct { input *AgentInput instruction string } - return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], - store *bridgeStore, instruction string, _ map[string]bool, cs *cancelSig, opts ...compose.Option) { + return func(ctx context.Context, p *runParams) { + cancelCtx := p.cancelCtx + ctx = withCancelContext(ctx, cancelCtx) wrappedModel := buildModelWrappers(a.model, &modelWrapperConfig{ - handlers: a.handlers, - middlewares: a.middlewares, - retryConfig: a.modelRetryConfig, + handlers: a.handlers, + middlewares: a.middlewares, + retryConfig: a.modelRetryConfig, + cancelContext: cancelCtx, }) - if cs != nil { - wrappedModel = wrapModelForCancelable(wrappedModel, cs) - } - chain := compose.NewChain[noToolsInput, Message]( compose.WithGenLocalState(func(ctx context.Context) (state *State) { return &State{} @@ -724,37 +792,63 @@ func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { if err != nil { return nil, err } + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + st.Messages = append(st.Messages, messages...) + return nil + }) return messages, nil })). AppendChatModel(wrappedModel) - r, err := chain.Compile(ctx, compose.WithGraphName(a.name), - compose.WithCheckPointStore(store), + var compileOptions []compose.GraphCompileOption + compileOptions = append(compileOptions, + compose.WithGraphName(a.name), + compose.WithCheckPointStore(p.store), compose.WithSerializer(&gobSerializer{})) + + if cancelCtx != nil { + var interrupt func(...compose.GraphInterruptOption) + ctx, interrupt = compose.WithGraphInterrupt(ctx) + cancelCtx.setGraphInterruptFunc(cancelCtx.wrapGraphInterruptWithGracePeriod(interrupt)) + } + + r, err := chain.Compile(ctx, compileOptions...) if err != nil { - generator.Send(&AgentEvent{Err: err}) + p.generator.Send(&AgentEvent{Err: err}) return } ctx = withChatModelAgentExecCtx(ctx, &chatModelAgentExecCtx{ - generator: generator, + generator: p.generator, }) - in := noToolsInput{input: input, instruction: instruction} + // Pre-execution cancel check + if cancelCtx != nil && cancelCtx.shouldCancel() { + if cancelCtx.getMode() == CancelImmediate || atomic.LoadInt32(&cancelCtx.escalated) == 1 { + cancelErr, ok := cancelCtx.createAndMarkCancelHandled() + if !ok { + return + } + p.generator.Send(&AgentEvent{Err: cancelErr}) + return + } + } + + in := noToolsInput{input: p.input, instruction: p.instruction} var msg Message var msgStream MessageStream - if input.EnableStreaming { - msgStream, err = r.Stream(ctx, in, opts...) + if p.input.EnableStreaming { + msgStream, err = r.Stream(ctx, in, p.composeOpts...) } else { - msg, err = r.Invoke(ctx, in, opts...) + msg, err = r.Invoke(ctx, in, p.composeOpts...) } if err == nil { if a.outputKey != "" { err = setOutputToSession(ctx, msg, msgStream, a.outputKey) if err != nil { - generator.Send(&AgentEvent{Err: err}) + p.generator.Send(&AgentEvent{Err: err}) } } else if msgStream != nil { msgStream.Close() @@ -762,30 +856,7 @@ func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { return } - info, ok := compose.ExtractInterruptInfo(err) - if !ok { - generator.Send(&AgentEvent{Err: err}) - return - } - - data, existed, sErr := store.Get(ctx, bridgeCheckpointID) - if sErr != nil { - generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("failed to get interrupt info: %w", sErr)}) - return - } - if !existed { - generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("interrupt occurred but checkpoint data is missing")}) - return - } - - is := FromInterruptContexts(info.InterruptContexts) - event := CompositeInterrupt(ctx, info, data, is) - event.Action.Interrupted.Data = &ChatModelAgentInterruptInfo{ - Info: info, - Data: data, - } - event.AgentName = a.name - generator.Send(event) + a.handleRunFuncError(ctx, err, cancelCtx, p.cancelCtxOwned, p.store, p.generator) } } @@ -809,11 +880,17 @@ func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) ( instruction string } - return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, - instruction string, returnDirectly map[string]bool, cs *cancelSig, opts ...compose.Option) { - g, err := newReact(ctx, conf, cs) + return func(ctx context.Context, p *runParams) { + cancelCtx := p.cancelCtx + conf.cancelCtx = cancelCtx + if conf.modelWrapperConf != nil { + conf.modelWrapperConf.cancelContext = cancelCtx + } + ctx = withCancelContext(ctx, cancelCtx) + + g, err := newReact(ctx, conf) if err != nil { - generator.Send(&AgentEvent{Err: err}) + p.generator.Send(&AgentEvent{Err: err}) return } @@ -825,7 +902,7 @@ func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) ( return nil, genErr } return &reactInput{ - messages: messages, + Messages: messages, }, nil }), ). @@ -834,38 +911,56 @@ func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) ( var compileOptions []compose.GraphCompileOption compileOptions = append(compileOptions, compose.WithGraphName(a.name), - compose.WithCheckPointStore(store), + compose.WithCheckPointStore(p.store), compose.WithSerializer(&gobSerializer{}), compose.WithMaxRunSteps(math.MaxInt)) + if cancelCtx != nil { + var interrupt func(...compose.GraphInterruptOption) + ctx, interrupt = compose.WithGraphInterrupt(ctx) + cancelCtx.setGraphInterruptFunc(cancelCtx.wrapGraphInterruptWithGracePeriod(interrupt)) + } + runnable, err_ := chain.Compile(ctx, compileOptions...) if err_ != nil { - generator.Send(&AgentEvent{Err: err_}) + p.generator.Send(&AgentEvent{Err: err_}) return } ctx = withChatModelAgentExecCtx(ctx, &chatModelAgentExecCtx{ - runtimeReturnDirectly: returnDirectly, - generator: generator, + runtimeReturnDirectly: p.returnDirectly, + generator: p.generator, }) + // Pre-execution cancel check + if cancelCtx != nil && cancelCtx.shouldCancel() { + if cancelCtx.getMode() == CancelImmediate || atomic.LoadInt32(&cancelCtx.escalated) == 1 { + cancelErr, ok := cancelCtx.createAndMarkCancelHandled() + if !ok { + return + } + p.generator.Send(&AgentEvent{Err: cancelErr}) + return + } + } + in := reactRunInput{ - input: input, - instruction: instruction, + input: p.input, + instruction: p.instruction, } var runOpts []compose.Option - runOpts = append(runOpts, opts...) + runOpts = append(runOpts, p.composeOpts...) if a.toolsConfig.EmitInternalEvents { - runOpts = append(runOpts, compose.WithToolsNodeOption(compose.WithToolOption(withAgentToolEventGenerator(generator)))) + runOpts = append(runOpts, compose.WithToolsNodeOption(compose.WithToolOption(withAgentToolEventGenerator(p.generator)))) } - if input.EnableStreaming { + if p.input.EnableStreaming { runOpts = append(runOpts, compose.WithToolsNodeOption(compose.WithToolOption(withAgentToolEnableStreaming(true)))) } var msg Message var msgStream MessageStream - if input.EnableStreaming { + if p.input.EnableStreaming { msgStream, err_ = runnable.Stream(ctx, in, runOpts...) } else { msg, err_ = runnable.Invoke(ctx, in, runOpts...) @@ -875,7 +970,7 @@ func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) ( if a.outputKey != "" { err_ = setOutputToSession(ctx, msg, msgStream, a.outputKey) if err_ != nil { - generator.Send(&AgentEvent{Err: err_}) + p.generator.Send(&AgentEvent{Err: err_}) } } else if msgStream != nil { msgStream.Close() @@ -884,31 +979,7 @@ func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) ( return } - info, ok := compose.ExtractInterruptInfo(err_) - if !ok { - generator.Send(&AgentEvent{Err: err_}) - return - } - - data, existed, err := store.Get(ctx, bridgeCheckpointID) - if err != nil { - generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("failed to get interrupt info: %w", err)}) - return - } - if !existed { - generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("interrupt occurred but checkpoint data is missing")}) - return - } - - is := FromInterruptContexts(info.InterruptContexts) - - event := CompositeInterrupt(ctx, info, data, is) - event.Action.Interrupted.Data = &ChatModelAgentInterruptInfo{ - Info: info, - Data: data, - } - event.AgentName = a.name - generator.Send(event) + a.handleRunFuncError(ctx, err_, cancelCtx, p.cancelCtxOwned, p.store, p.generator) }, nil } @@ -981,24 +1052,25 @@ func (a *ChatModelAgent) getRunFunc(ctx context.Context) (context.Context, runFu } func (a *ChatModelAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { - iter, _ := a.runInternal(ctx, input, false, opts...) - return iter -} - -func (a *ChatModelAgent) RunWithCancel(ctx context.Context, input *AgentInput, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { - return a.runInternal(ctx, input, true, opts...) -} - -func (a *ChatModelAgent) runInternal(ctx context.Context, input *AgentInput, withCancel bool, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + o := getCommonOptions(nil, opts...) + cancelCtx := o.cancelCtx + cancelCtxOwned := cancelCtx != nil && getCancelContext(ctx) == nil + if cancelCtx == nil { + cancelCtx = getCancelContext(ctx) + } + ctx, run, bc, err := a.getRunFunc(ctx) if err != nil { go func() { - generator.Send(&AgentEvent{Err: err}) + if cancelCtxOwned && cancelCtx != nil { + defer cancelCtx.markDone() + } + generator.Send(&AgentEvent{Err: fmt.Errorf("ChatModelAgent getRunFunc error: %w", err)}) generator.Close() }() - return iterator, notCancellableFuncInternal + return iterator } co := getComposeOptions(opts) @@ -1011,13 +1083,6 @@ func (a *ChatModelAgent) runInternal(ctx context.Context, input *AgentInput, wit } } - var cs *cancelSig - var cancelFn CancelFunc = notCancellableFuncInternal - if withCancel { - cs = newCancelSig() - cancelFn = buildCancelFunc(cs) - } - go func() { defer func() { panicErr := recover() @@ -1039,47 +1104,44 @@ func (a *ChatModelAgent) runInternal(ctx context.Context, input *AgentInput, wit returnDirectly = bc.returnDirectly } - run(ctx, input, generator, newBridgeStore(), instruction, returnDirectly, cs, co...) + run(ctx, &runParams{ + input: input, + generator: generator, + store: newBridgeStore(), + instruction: instruction, + returnDirectly: returnDirectly, + cancelCtx: cancelCtx, + cancelCtxOwned: cancelCtxOwned, + composeOpts: co, + }) }() - return iterator, cancelFn -} - -func buildCancelFunc(cs *cancelSig) CancelFunc { - var once sync.Once - return func(opts ...CancelOption) error { - cfg := &cancelConfig{ - Mode: CancelImmediate, - } - for _, opt := range opts { - opt(cfg) - } - once.Do(func() { - cs.cancel(cfg) - }) - return nil + if cancelCtxOwned { + return wrapIterWithCancelCtx(iterator, cancelCtx) } + return iterator } func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { - iter, _ := a.resumeInternal(ctx, info, false, opts...) - return iter -} - -func (a *ChatModelAgent) ResumeWithCancel(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { - return a.resumeInternal(ctx, info, true, opts...) -} - -func (a *ChatModelAgent) resumeInternal(ctx context.Context, info *ResumeInfo, withCancel bool, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + o := getCommonOptions(nil, opts...) + cancelCtx := o.cancelCtx + cancelCtxOwned := cancelCtx != nil && getCancelContext(ctx) == nil + if cancelCtx == nil { + cancelCtx = getCancelContext(ctx) + } + ctx, run, bc, err := a.getRunFunc(ctx) if err != nil { go func() { - generator.Send(&AgentEvent{Err: err}) + if cancelCtxOwned && cancelCtx != nil { + defer cancelCtx.markDone() + } + generator.Send(&AgentEvent{Err: fmt.Errorf("ChatModelAgent getRunFunc error: %w", err)}) generator.Close() }() - return iterator, notCancellableFuncInternal + return iterator } co := getComposeOptions(opts) @@ -1092,19 +1154,18 @@ func (a *ChatModelAgent) resumeInternal(ctx context.Context, info *ResumeInfo, w } } - methodName := "Resume" - if withCancel { - methodName = "ResumeWithCancel" + if info == nil { + panic(fmt.Sprintf("ChatModelAgent.Resume: agent '%s' was asked to resume but info is nil", a.Name(ctx))) } if info.InterruptState == nil { - panic(fmt.Sprintf("ChatModelAgent.%s: agent '%s' was asked to resume but has no state", methodName, a.Name(ctx))) + panic(fmt.Sprintf("ChatModelAgent.Resume: agent '%s' was asked to resume but has no state", a.Name(ctx))) } stateByte, ok := info.InterruptState.([]byte) if !ok { - panic(fmt.Sprintf("ChatModelAgent.%s: agent '%s' was asked to resume but has invalid interrupt state type: %T", - methodName, a.Name(ctx), info.InterruptState)) + panic(fmt.Sprintf("ChatModelAgent.Resume: agent '%s' was asked to resume but has invalid interrupt state type: %T", + a.Name(ctx), info.InterruptState)) } // Migrate legacy checkpoints before resume. @@ -1119,15 +1180,15 @@ func (a *ChatModelAgent) resumeInternal(ctx context.Context, info *ResumeInfo, w generator.Send(&AgentEvent{Err: err}) generator.Close() }() - return iterator, notCancellableFuncInternal + return iterator } var historyModifier func(ctx context.Context, history []Message) []Message if info.ResumeData != nil { resumeData, ok := info.ResumeData.(*ChatModelAgentResumeData) if !ok { - panic(fmt.Sprintf("ChatModelAgent.%s: agent '%s' was asked to resume but has invalid resume data type: %T", - methodName, a.Name(ctx), info.ResumeData)) + panic(fmt.Sprintf("ChatModelAgent.Resume: agent '%s' was asked to resume but has invalid resume data type: %T", + a.Name(ctx), info.ResumeData)) } historyModifier = resumeData.HistoryModifier } @@ -1143,13 +1204,6 @@ func (a *ChatModelAgent) resumeInternal(ctx context.Context, info *ResumeInfo, w })) } - var cs *cancelSig - var cancelFn CancelFunc = notCancellableFuncInternal - if withCancel { - cs = newCancelSig() - cancelFn = buildCancelFunc(cs) - } - go func() { defer func() { panicErr := recover() @@ -1171,11 +1225,22 @@ func (a *ChatModelAgent) resumeInternal(ctx context.Context, info *ResumeInfo, w returnDirectly = bc.returnDirectly } - run(ctx, &AgentInput{EnableStreaming: info.EnableStreaming}, generator, - newResumeBridgeStore(stateByte), instruction, returnDirectly, cs, co...) + run(ctx, &runParams{ + input: &AgentInput{EnableStreaming: info.EnableStreaming}, + generator: generator, + store: newResumeBridgeStore(bridgeCheckpointID, stateByte), + instruction: instruction, + returnDirectly: returnDirectly, + cancelCtx: cancelCtx, + cancelCtxOwned: cancelCtxOwned, + composeOpts: co, + }) }() - return iterator, cancelFn + if cancelCtxOwned { + return wrapIterWithCancelCtx(iterator, cancelCtx) + } + return iterator } func getComposeOptions(opts []AgentRunOption) []compose.Option { diff --git a/adk/chatmodel_retry_test.go b/adk/chatmodel_retry_test.go index 00c89b352..0cb2a87bd 100644 --- a/adk/chatmodel_retry_test.go +++ b/adk/chatmodel_retry_test.go @@ -1046,3 +1046,148 @@ func TestSequentialWorkflow_NoRetryConfig_StreamError_StopsFlow(t *testing.T) { assert.Equal(t, 0, len(capturingModel.capturedInputs), "Agent B should NOT be called due to error") assert.Equal(t, int32(1), atomic.LoadInt32(&noRetryModel.callCount), "Model should only be called once (no retry)") } + +// failThenToolCallStreamModel is a ChatModel that: +// - First Stream() call: yields a partial chunk then fails with a retryable error mid-stream. +// - Second Stream() call (retry): yields a tool-call message (success). +// - Third Generate() call (after tool result): yields a final assistant message. +// +// This exercises the path where the eventSenderModel copies the first stream, +// wraps its error as WillRetryError, and sends it as an event to the session. +// The retryModelWrapper then retries, gets a clean stream with a tool call, +// the tool interrupts, and checkpoint save needs to gob-encode the session +// (which still contains the unconsumed WillRetryError event stream). +type failThenToolCallStreamModel struct { + streamCallCount int32 + genCallCount int32 +} + +func (m *failThenToolCallStreamModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m.genCallCount, 1) + return schema.AssistantMessage("final answer", nil), nil +} + +func (m *failThenToolCallStreamModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + count := atomic.AddInt32(&m.streamCallCount, 1) + + sr, sw := schema.Pipe[*schema.Message](10) + go func() { + defer sw.Close() + if count == 1 { + // First call: yield a partial chunk then fail. + sw.Send(schema.AssistantMessage("partial", nil), nil) + sw.Send(nil, errRetryAble) + return + } + // Second call (retry): yield a tool-call message. + sw.Send(schema.AssistantMessage("", []schema.ToolCall{{ + ID: "call-1", + Function: schema.FunctionCall{ + Name: "interrupt_tool", + Arguments: `{}`, + }, + }}), nil) + }() + return sr, nil +} + +func (m *failThenToolCallStreamModel) WithTools(_ []*schema.ToolInfo) (model.ToolCallingChatModel, error) { + return m, nil +} + +// interruptToolForRetryTest is a tool that always interrupts. +type interruptToolForRetryTest struct{} + +func (t *interruptToolForRetryTest) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: "interrupt_tool", + Desc: "tool that interrupts", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Type: "string"}, + }), + }, nil +} + +func (t *interruptToolForRetryTest) InvokableRun(ctx context.Context, _ string, _ ...tool.Option) (string, error) { + return "", tool.Interrupt(ctx, "interrupted by tool") +} + +// TestCheckpointSave_WillRetryError_StreamNotConsumed verifies that checkpoint +// saving succeeds when the session contains an event with an unconsumed stream +// that ends with WillRetryError. +// +// Scenario: +// 1. ChatModelAgent with retry (MaxRetries=1) and a tool that always interrupts +// 2. Model.Stream() #1 yields "partial" then errRetryAble mid-stream +// → eventSenderModel copies the stream, wraps the error as WillRetryError, +// sends the event to the session (stream NOT consumed by anyone yet) +// → retryModelWrapper detects error on its copy, retries +// 3. Model.Stream() #2 succeeds with a tool-call message +// 4. Tool executes → interrupts +// 5. Runner.handleIter sees the interrupt → saveCheckPoint → gob encodes runSession +// 6. The session has the WillRetryError event with an unconsumed stream +// → agentEventWrapper.GobEncode proactively consumes the stream via +// getMessageFromWrappedEvent, so MessageVariant.GobEncode sees an error-free +// array and succeeds +func TestCheckpointSave_WillRetryError_StreamNotConsumed(t *testing.T) { + ctx := context.Background() + + mdl := &failThenToolCallStreamModel{} + itool := &interruptToolForRetryTest{} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Agent for checkpoint stream error test", + Instruction: "You are a test agent.", + Model: mdl, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{itool}, + }, + }, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 1, + IsRetryAble: func(_ context.Context, err error) bool { + return errors.Is(err, errRetryAble) + }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { + return time.Millisecond // fast retry for test + }, + }, + }) + assert.NoError(t, err) + + store := newMyStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: true, + CheckPointStore: store, + }) + + iter := runner.Run(ctx, + []Message{schema.UserMessage("hello")}, + WithCheckPointID("ckpt-1"), + ) + + var events []*AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + events = append(events, event) + + if event.Err != nil { + t.Logf("event error: %v", event.Err) + } + } + + // Verify the checkpoint was saved successfully. + _, exists, _ := store.Get(ctx, "ckpt-1") + assert.True(t, exists, "checkpoint should be saved successfully; "+ + "if this fails, the WillRetryError stream in the session caused gob encoding to fail") + + // Sanity: the model should have been called twice for Stream (fail + retry). + assert.Equal(t, int32(2), atomic.LoadInt32(&mdl.streamCallCount), + "model should be called twice: first fail, then retry success") +} diff --git a/adk/flow.go b/adk/flow.go index 7acfd2137..52a346c74 100644 --- a/adk/flow.go +++ b/adk/flow.go @@ -333,15 +333,6 @@ func buildDefaultHistoryRewriter(agentName string) HistoryRewriter { } func (a *flowAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { - iter, _ := a.runInternal(ctx, input, false, opts...) - return iter -} - -func (a *flowAgent) RunWithCancel(ctx context.Context, input *AgentInput, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { - return a.runInternal(ctx, input, true, opts...) -} - -func (a *flowAgent) runInternal(ctx context.Context, input *AgentInput, withCancel bool, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { agentName := a.Name(ctx) var runCtx *runContext @@ -349,12 +340,16 @@ func (a *flowAgent) runInternal(ctx context.Context, input *AgentInput, withCanc ctx = AppendAddressSegment(ctx, AddressSegmentAgent, agentName) o := getCommonOptions(nil, opts...) + cancelCtx := o.cancelCtx processedInput, err := a.genAgentInput(ctx, runCtx, o.skipTransferMessages) if err != nil { + if cancelCtx != nil { + cancelCtx.markDone() + } cbInput := &AgentCallbackInput{Input: input} ctx = callbacks.OnStart(ctx, cbInput) - return wrapIterWithOnEnd(ctx, genErrorIter(err)), notCancellableFuncInternal + return wrapIterWithOnEnd(ctx, genErrorIter(err)) } ctxForSubAgents := ctx @@ -367,105 +362,90 @@ func (a *flowAgent) runInternal(ctx context.Context, input *AgentInput, withCanc input = processedInput if wf, ok := a.Agent.(*workflowAgent); ok { - return wrapIterWithOnEnd(ctx, wf.Run(ctx, input, filterCallbackHandlersForNestedAgents(agentName, opts)...)), notCancellableFuncInternal + ctx = withCancelContext(ctx, cancelCtx) + filteredOpts := filterCancelOption(filterCallbackHandlersForNestedAgents(agentName, opts)) + iter := wf.Run(ctx, input, filteredOpts...) + iter = wrapIterWithCancelCtx(iter, cancelCtx) + return wrapIterWithOnEnd(ctx, iter) } - var aIter *AsyncIterator[*AgentEvent] - var cancelFn CancelFunc = notCancellableFuncInternal - - ca, supportCancel := a.Agent.(CancellableAgent) - if withCancel && supportCancel { - aIter, cancelFn = ca.RunWithCancel(ctx, input, filterOptions(agentName, opts)...) - } else { - aIter = a.Agent.Run(ctx, input, filterOptions(agentName, opts)...) - } + aIter := a.Agent.Run(withCancelContext(ctx, cancelCtx), input, filterOptions(agentName, opts)...) iterator, generator := NewAsyncIteratorPair[*AgentEvent]() - go a.run(ctx, ctxForSubAgents, runCtx, aIter, generator, opts...) - - return iterator, cancelFn -} + go a.run(withCancelContext(ctx, cancelCtx), withCancelContext(ctxForSubAgents, cancelCtx), runCtx, aIter, generator, filterCancelOption(opts)...) -func notCancellableFuncInternal(_ ...CancelOption) error { - return ErrAgentNotCancellable + return wrapIterWithCancelCtx(iterator, cancelCtx) } func (a *flowAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { - iter, _ := a.resumeInternal(ctx, info, false, opts...) - return iter -} - -func (a *flowAgent) ResumeWithCancel(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { - return a.resumeInternal(ctx, info, true, opts...) -} - -func (a *flowAgent) resumeInternal(ctx context.Context, info *ResumeInfo, withCancel bool, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { agentName := a.Name(ctx) ctx, info = buildResumeInfo(ctx, agentName, info) ctxForSubAgents := ctx + o := getCommonOptions(nil, opts...) + cancelCtx := o.cancelCtx + agentType := getAgentType(a.Agent) ctx = initAgentCallbacks(ctx, agentName, agentType, filterOptions(agentName, opts)...) cbInput := &AgentCallbackInput{ResumeInfo: info} ctx = callbacks.OnStart(ctx, cbInput) if info.WasInterrupted { - var aIter *AsyncIterator[*AgentEvent] - var cancelFn CancelFunc = notCancellableFuncInternal - - ca, supportCancel := a.Agent.(CancellableResumableAgent) - if withCancel && supportCancel { - aIter, cancelFn = ca.ResumeWithCancel(ctx, info, opts...) - } else if ra, ok := a.Agent.(ResumableAgent); ok { + if ra, ok := a.Agent.(ResumableAgent); ok { if _, ok := ra.(*workflowAgent); ok { - filteredOpts := filterCallbackHandlersForNestedAgents(agentName, opts) + ctx = withCancelContext(ctx, cancelCtx) + filteredOpts := filterCancelOption(filterCallbackHandlersForNestedAgents(agentName, opts)) aIter := ra.Resume(ctx, info, filteredOpts...) - return wrapIterWithOnEnd(ctx, aIter), cancelFn + aIter = wrapIterWithCancelCtx(aIter, cancelCtx) + return wrapIterWithOnEnd(ctx, aIter) } - aIter = ra.Resume(ctx, info, opts...) - } else { - return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf("failed to resume agent: agent '%s' is an interrupt point "+ - "but is not a ResumableAgent", agentName))), notCancellableFuncInternal + + aIter := ra.Resume(withCancelContext(ctx, cancelCtx), info, opts...) + + iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + go a.run(withCancelContext(ctx, cancelCtx), withCancelContext(ctxForSubAgents, cancelCtx), getRunCtx(ctxForSubAgents), aIter, generator, filterCancelOption(opts)...) + return wrapIterWithCancelCtx(iterator, cancelCtx) } - iterator, generator := NewAsyncIteratorPair[*AgentEvent]() - go a.run(ctx, ctxForSubAgents, getRunCtx(ctxForSubAgents), aIter, generator, opts...) - return iterator, cancelFn + if cancelCtx != nil { + cancelCtx.markDone() + } + return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf("failed to resume agent: agent '%s' is an interrupt point "+ + "but is not a ResumableAgent", agentName))) } nextAgentName, err := getNextResumeAgent(ctx, info) if err != nil { - return wrapIterWithOnEnd(ctx, genErrorIter(err)), notCancellableFuncInternal + if cancelCtx != nil { + cancelCtx.markDone() + } + return wrapIterWithOnEnd(ctx, genErrorIter(err)) } subAgent := a.getAgent(ctxForSubAgents, nextAgentName) if subAgent == nil { if len(a.subAgents) == 0 { - ca, supportCancel := a.Agent.(CancellableResumableAgent) - if withCancel && supportCancel { - iter, cancelFn := ca.ResumeWithCancel(ctx, info, opts...) - return wrapIterWithOnEnd(ctx, iter), cancelFn - } if ra, ok := a.Agent.(ResumableAgent); ok { - return wrapIterWithOnEnd(ctx, ra.Resume(ctx, info, opts...)), notCancellableFuncInternal + ctx = withCancelContext(ctx, cancelCtx) + innerIter := ra.Resume(ctx, info, filterCancelOption(opts)...) + return wrapIterWithCancelCtx(wrapIterWithOnEnd(ctx, innerIter), cancelCtx) } return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf( "failed to resume agent: agent '%s' (type %T) has no sub-agents and does not implement ResumableAgent interface. "+ - "To support resume, your custom agent wrapper must implement the ResumableAgent interface", agentName, a.Agent))), nil + "To support resume, your custom agent wrapper must implement the ResumableAgent interface", agentName, a.Agent))) } - return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf("failed to resume agent: sub-agent '%s' not found in agent '%s'", nextAgentName, agentName))), notCancellableFuncInternal - } - - ca, supportCancel := ResumableAgent(subAgent).(CancellableResumableAgent) - if withCancel && supportCancel { - iter, cancelFn := ca.ResumeWithCancel(ctxForSubAgents, info, opts...) - return wrapIterWithOnEnd(ctx, iter), cancelFn + if cancelCtx != nil { + cancelCtx.markDone() + } + return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf("failed to resume agent: sub-agent '%s' not found in agent '%s'", nextAgentName, agentName))) } - return wrapIterWithOnEnd(ctx, subAgent.Resume(ctxForSubAgents, info, opts...)), notCancellableFuncInternal + ctxForSubAgents = withCancelContext(ctxForSubAgents, cancelCtx) + innerIter := subAgent.Resume(ctxForSubAgents, info, filterCancelOption(opts)...) + return wrapIterWithCancelCtx(wrapIterWithOnEnd(ctx, innerIter), cancelCtx) } type DeterministicTransferConfig struct { diff --git a/adk/handler.go b/adk/handler.go index 7c7ebba71..423282a7a 100644 --- a/adk/handler.go +++ b/adk/handler.go @@ -47,6 +47,12 @@ type ToolContext struct { CallID string } +// ToolCallsContext contains metadata about the tool calls that just completed. +type ToolCallsContext struct { + // ToolCalls contains the tool call metadata from the model's response. + ToolCalls []ToolContext +} + // ModelContext contains context information passed to WrapModel. type ModelContext struct { // Tools contains the current tool list configured for the agent. @@ -57,6 +63,8 @@ type ModelContext struct { // This is populated at request time from the agent's ModelRetryConfig. // Used by EventSenderModelWrapper to wrap stream errors appropriately. ModelRetryConfig *ModelRetryConfig + + cancelContext *cancelContext } // ChatModelAgentContext contains runtime information passed to handlers before each ChatModelAgent run. @@ -138,6 +146,14 @@ type ChatModelAgentMiddleware interface { // - Tools: the current tool list that was sent to the model AfterModelRewriteState(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) + // AfterToolCallsRewriteState is called after all concurrent tool calls in an iteration complete. + // The input state includes all messages up to and including the tool call results. + // The returned state is persisted to the agent's internal state. + // + // The ToolCallsContext provides metadata about the tool calls that just completed, + // derived from the assistant message's ToolCalls field. + AfterToolCallsRewriteState(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) + // WrapInvokableToolCall wraps a tool's synchronous execution with custom behavior. // Return the input endpoint unchanged and nil error if no wrapping is needed. // @@ -247,6 +263,10 @@ func (b *BaseChatModelAgentMiddleware) AfterModelRewriteState(ctx context.Contex return ctx, state, nil } +func (b *BaseChatModelAgentMiddleware) AfterToolCallsRewriteState(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) { + return ctx, state, nil +} + // SetRunLocalValue sets a key-value pair that persists for the duration of the current agent Run() invocation. // The value is scoped to this specific execution and is not shared across different Run() calls or agent instances. // diff --git a/adk/handler_test.go b/adk/handler_test.go index e56da3842..abdb0ecab 100644 --- a/adk/handler_test.go +++ b/adk/handler_test.go @@ -111,6 +111,15 @@ func (h *testAfterModelRewriteStateHandler) AfterModelRewriteState(ctx context.C return h.fn(ctx, state, mc) } +type testAfterToolCallsHandler struct { + *BaseChatModelAgentMiddleware + fn func(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) +} + +func (h *testAfterToolCallsHandler) AfterToolCallsRewriteState(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) { + return h.fn(ctx, state, tc) +} + type testToolWrapperHandler struct { *BaseChatModelAgentMiddleware wrapInvokableFn func(context.Context, InvokableToolCallEndpoint, *ToolContext) InvokableToolCallEndpoint @@ -1820,3 +1829,312 @@ func TestToolContextInWrappers(t *testing.T) { assert.Equal(t, "test_call_id_123", capturedCallID, "ToolContext should have correct call ID") }) } + +func TestAfterToolCallsRewriteState(t *testing.T) { + t.Run("ReceivesCorrectToolCallsContext", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + tool1 := &namedTool{name: "tool_alpha"} + tool2 := &namedTool{name: "tool_beta"} + + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + // First call: model returns two tool calls + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("calling tools", []schema.ToolCall{ + {ID: "call_1", Function: schema.FunctionCall{Name: "tool_alpha", Arguments: "{}"}}, + {ID: "call_2", Function: schema.FunctionCall{Name: "tool_beta", Arguments: "{}"}}, + }), nil).Times(1) + + // Second call: model returns final response + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("done", nil), nil).Times(1) + + var mu sync.Mutex + var capturedTC *ToolCallsContext + var capturedState *ChatModelAgentState + callCount := 0 + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1, tool2}, + }, + }, + Handlers: []ChatModelAgentMiddleware{ + &testAfterToolCallsHandler{fn: func(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) { + mu.Lock() + callCount++ + capturedTC = tc + capturedState = &ChatModelAgentState{Messages: append([]Message{}, state.Messages...)} + mu.Unlock() + return ctx, state, nil + }}, + }, + }) + assert.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + mu.Lock() + defer mu.Unlock() + + // Should be called exactly once (one iteration with tool calls) + assert.Equal(t, 1, callCount) + + // ToolCallsContext should have the two tool calls + assert.NotNil(t, capturedTC) + assert.Len(t, capturedTC.ToolCalls, 2) + assert.Equal(t, "tool_alpha", capturedTC.ToolCalls[0].Name) + assert.Equal(t, "call_1", capturedTC.ToolCalls[0].CallID) + assert.Equal(t, "tool_beta", capturedTC.ToolCalls[1].Name) + assert.Equal(t, "call_2", capturedTC.ToolCalls[1].CallID) + + // State should contain: system msg + user msg + assistant msg + 2 tool results + assert.NotNil(t, capturedState) + assert.True(t, len(capturedState.Messages) >= 4, "expected at least 4 messages, got %d", len(capturedState.Messages)) + + // Check tool results are in state + toolResultCount := 0 + for _, msg := range capturedState.Messages { + if msg.Role == schema.Tool { + toolResultCount++ + } + } + assert.Equal(t, 2, toolResultCount) + }) + + t.Run("NotCalledWithoutToolCalls", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + // Model returns a direct response with no tool calls + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("direct response", nil), nil).Times(1) + + callCount := 0 + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: cm, + Handlers: []ChatModelAgentMiddleware{ + &testAfterToolCallsHandler{fn: func(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) { + callCount++ + return ctx, state, nil + }}, + }, + }) + assert.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + assert.Equal(t, 0, callCount, "AfterToolCallsRewriteState should not be called when no tool calls happen") + }) + + t.Run("CanModifyStatePersistsToNextIteration", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + tool1 := &namedTool{name: "mytool"} + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + // First call: model returns a tool call + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("calling", []schema.ToolCall{ + {ID: "c1", Function: schema.FunctionCall{Name: "mytool", Arguments: "{}"}}, + }), nil).Times(1) + + // Second call: capture messages to verify the injected message is present + var capturedMsgs []*schema.Message + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...interface{}) (*schema.Message, error) { + capturedMsgs = append([]*schema.Message{}, msgs...) + return schema.AssistantMessage("final", nil), nil + }).Times(1) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1}, + }, + }, + Handlers: []ChatModelAgentMiddleware{ + &testAfterToolCallsHandler{fn: func(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) { + // Inject a user message into state + state.Messages = append(state.Messages, schema.UserMessage("injected_by_middleware")) + return ctx, state, nil + }}, + }, + }) + assert.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("original")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + // The injected message should be visible in the second model call + assert.NotNil(t, capturedMsgs) + found := false + for _, msg := range capturedMsgs { + if msg.Content == "injected_by_middleware" { + found = true + break + } + } + assert.True(t, found, "Injected message should persist to the next model call") + }) +} + +func TestToolResultNotDuplicated(t *testing.T) { + t.Run("SecondModelCallHasNoToolResultDuplication", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + tool1 := &namedTool{name: "mytool"} + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("calling", []schema.ToolCall{ + {ID: "c1", Function: schema.FunctionCall{Name: "mytool", Arguments: "{}"}}, + }), nil).Times(1) + + var capturedMsgs []*schema.Message + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...interface{}) (*schema.Message, error) { + capturedMsgs = append([]*schema.Message{}, msgs...) + return schema.AssistantMessage("final", nil), nil + }).Times(1) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are helpful.", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1}, + }, + }, + }) + assert.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hello")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + assert.NotNil(t, capturedMsgs) + assert.Equal(t, 4, len(capturedMsgs), + "expected [system, user, assistant, tool_result], got %d messages", len(capturedMsgs)) + assert.Equal(t, schema.System, capturedMsgs[0].Role) + assert.Equal(t, schema.User, capturedMsgs[1].Role) + assert.Equal(t, schema.Assistant, capturedMsgs[2].Role) + assert.Equal(t, schema.Tool, capturedMsgs[3].Role) + + toolResultCount := 0 + for _, msg := range capturedMsgs { + if msg.Role == schema.Tool { + toolResultCount++ + } + } + assert.Equal(t, 1, toolResultCount, + "tool result should appear exactly once, got %d", toolResultCount) + }) + + t.Run("HandlerInjectedMessagePresentWithoutDuplication", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + tool1 := &namedTool{name: "mytool"} + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("calling", []schema.ToolCall{ + {ID: "c1", Function: schema.FunctionCall{Name: "mytool", Arguments: "{}"}}, + }), nil).Times(1) + + var capturedMsgs []*schema.Message + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...interface{}) (*schema.Message, error) { + capturedMsgs = append([]*schema.Message{}, msgs...) + return schema.AssistantMessage("final", nil), nil + }).Times(1) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are helpful.", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1}, + }, + }, + Handlers: []ChatModelAgentMiddleware{ + &testAfterToolCallsHandler{fn: func(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) { + state.Messages = append(state.Messages, schema.UserMessage("injected")) + return ctx, state, nil + }}, + }, + }) + assert.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hello")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + assert.NotNil(t, capturedMsgs) + assert.Equal(t, 5, len(capturedMsgs), + "expected [system, user, assistant, tool_result, injected], got %d messages", len(capturedMsgs)) + assert.Equal(t, schema.System, capturedMsgs[0].Role) + assert.Equal(t, schema.User, capturedMsgs[1].Role) + assert.Equal(t, schema.Assistant, capturedMsgs[2].Role) + assert.Equal(t, schema.Tool, capturedMsgs[3].Role) + assert.Equal(t, "injected", capturedMsgs[4].Content) + + toolResultCount := 0 + for _, msg := range capturedMsgs { + if msg.Role == schema.Tool { + toolResultCount++ + } + } + assert.Equal(t, 1, toolResultCount, + "tool result should appear exactly once, got %d", toolResultCount) + }) +} diff --git a/adk/interface.go b/adk/interface.go index c73705ae3..5c06843ae 100644 --- a/adk/interface.go +++ b/adk/interface.go @@ -20,10 +20,8 @@ import ( "bytes" "context" "encoding/gob" - "errors" "fmt" "io" - "time" "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/internal/core" @@ -271,55 +269,3 @@ type ResumableAgent interface { Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] } - -// CancelMode specifies when an agent should be canceled. -// Modes can be combined with bitwise OR to cancel at multiple execution points. -// For example, CancelAfterChatModel | CancelAfterToolCall cancels the agent -// after whichever execution point is reached first. -type CancelMode int - -const ( - // CancelImmediate cancels the agent immediately without waiting - // for any execution point. - CancelImmediate CancelMode = 0 - // CancelAfterChatModel cancels the agent after a chat model call completes. - CancelAfterChatModel CancelMode = 1 << iota - // CancelAfterToolCall cancels the agent after a tool call completes. - CancelAfterToolCall -) - -// ErrAgentNotCancellable is returned by Cancel when the agent does not support cancellation. -var ErrAgentNotCancellable = errors.New("agent does not implement CancellableAgent interface") - -type cancelConfig struct { - Mode CancelMode - Timeout *time.Duration -} - -type CancelOption func(*cancelConfig) - -// WithCancelMode sets the cancel mode for the cancel operation. -func WithCancelMode(mode CancelMode) CancelOption { - return func(config *cancelConfig) { - config.Mode = mode - } -} - -// WithCancelTimeout sets a timeout duration for CancelImmediate mode. -func WithCancelTimeout(timeout time.Duration) CancelOption { - return func(config *cancelConfig) { - config.Timeout = &timeout - } -} - -type CancelFunc func(...CancelOption) error - -type CancellableAgent interface { - Agent - RunWithCancel(ctx context.Context, input *AgentInput, options ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) -} - -type CancellableResumableAgent interface { - ResumableAgent - ResumeWithCancel(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) -} diff --git a/adk/interrupt.go b/adk/interrupt.go index 5941d0724..fce09d4cf 100644 --- a/adk/interrupt.go +++ b/adk/interrupt.go @@ -22,6 +22,7 @@ import ( "encoding/gob" "errors" "fmt" + "sync" "github.com/cloudwego/eino/internal/core" "github.com/cloudwego/eino/schema" @@ -183,6 +184,11 @@ func WithCheckPointID(id string) AgentRunOption { func init() { schema.RegisterName[*serialization]("_eino_adk_serialization") schema.RegisterName[*WorkflowInterruptInfo]("_eino_adk_workflow_interrupt_info") + // Register []byte for gob: the cancel refactor routes bridge store checkpoint + // bytes ([]byte) through InterruptState.State (type any) inside the outer + // serialization struct. Gob requires concrete types behind interface fields + // to be registered. + gob.Register([]byte{}) } // serialization CheckpointSchema: root checkpoint payload (gob). @@ -266,6 +272,10 @@ func (r *Runner) saveCheckPoint( info *InterruptInfo, is *core.InterruptSignal, ) error { + if r.store == nil { + return nil + } + runCtx := getRunCtx(ctx) id2Addr, id2State := core.SignalToPersistenceMaps(is) @@ -287,31 +297,36 @@ func (r *Runner) saveCheckPoint( const bridgeCheckpointID = "adk_react_mock_key" func newBridgeStore() *bridgeStore { - return &bridgeStore{} + return &bridgeStore{data: make(map[string][]byte)} } -func newResumeBridgeStore(data []byte) *bridgeStore { +func newResumeBridgeStore(checkPointID string, data []byte) *bridgeStore { return &bridgeStore{ - Data: data, - Valid: true, + data: map[string][]byte{checkPointID: data}, } } type bridgeStore struct { - Data []byte - Valid bool + mu sync.Mutex + data map[string][]byte } -func (m *bridgeStore) Get(_ context.Context, _ string) ([]byte, bool, error) { - if m.Valid { - return m.Data, true, nil +func (m *bridgeStore) Get(_ context.Context, key string) ([]byte, bool, error) { + m.mu.Lock() + defer m.mu.Unlock() + if v, ok := m.data[key]; ok { + return v, true, nil } return nil, false, nil } -func (m *bridgeStore) Set(_ context.Context, _ string, checkPoint []byte) error { - m.Data = checkPoint - m.Valid = true +func (m *bridgeStore) Set(_ context.Context, key string, checkPoint []byte) error { + m.mu.Lock() + defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[string][]byte) + } + m.data[key] = checkPoint return nil } diff --git a/adk/middlewares/patchtoolcalls/patchtoolcalls.go b/adk/middlewares/patchtoolcalls/patchtoolcalls.go index 75fb5fcbf..833ca3794 100644 --- a/adk/middlewares/patchtoolcalls/patchtoolcalls.go +++ b/adk/middlewares/patchtoolcalls/patchtoolcalls.go @@ -121,6 +121,6 @@ func (m *middleware) createPatchedToolMessage(ctx context.Context, tc schema.Too } const ( - defaultPatchedToolMessageTemplate = "Tool call %s with id %s was cancelled - another message came in before it could be completed." + defaultPatchedToolMessageTemplate = "Tool call %s with id %s was canceled - another message came in before it could be completed." defaultPatchedToolMessageTemplateChinese = "工具调用 %s(ID 为 %s)已被取消——在其完成之前收到了另一条消息。" ) diff --git a/adk/prebuilt/planexecute/plan_execute_test.go b/adk/prebuilt/planexecute/plan_execute_test.go index fb7360357..6734a16b8 100644 --- a/adk/prebuilt/planexecute/plan_execute_test.go +++ b/adk/prebuilt/planexecute/plan_execute_test.go @@ -18,9 +18,12 @@ package planexecute import ( "context" + "errors" "fmt" "strings" + "sync" "testing" + "time" "github.com/bytedance/sonic" "github.com/stretchr/testify/assert" @@ -1002,3 +1005,232 @@ func TestPlanExecuteAgentInterruptResume(t *testing.T) { assert.True(t, hasAssistantCompletion, "Should have assistant completion message") assert.True(t, hasBreakLoop, "Should have break loop action indicating completion") } + +// slowChatModel is a ChatModel that blocks for a configurable duration. +type slowChatModel struct { + delay time.Duration + response *schema.Message + startedChan chan struct{} + startedOnce sync.Once +} + +func (m *slowChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + m.startedOnce.Do(func() { + close(m.startedChan) + }) + + select { + case <-time.After(m.delay): + return m.response, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (m *slowChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + msg, err := m.Generate(ctx, input, opts...) + if err != nil { + return nil, err + } + sr, sw := schema.Pipe[*schema.Message](1) + sw.Send(msg, nil) + sw.Close() + return sr, nil +} + +func (m *slowChatModel) BindTools(tools []*schema.ToolInfo) error { return nil } +func (m *slowChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) { + return m, nil +} + +// TestWithCancel_PlanExecute_DuringExecution verifies that cancel works +// during the executor (ChatModelAgent) phase of the PlanExecute agent. +func TestWithCancel_PlanExecute_DuringExecution(t *testing.T) { + ctx := context.Background() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Planner: returns a plan quickly + mockPlanner := mockAdk.NewMockAgent(ctrl) + mockPlanner.EXPECT().Name(gomock.Any()).Return("planner").AnyTimes() + mockPlanner.EXPECT().Description(gomock.Any()).Return("a planner agent").AnyTimes() + + plan := &defaultPlan{Steps: []string{"Step 1", "Step 2"}} + userInput := []adk.Message{schema.UserMessage("test task")} + + mockPlanner.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, input *adk.AgentInput, opts ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { + iterator, generator := adk.NewAsyncIteratorPair[*adk.AgentEvent]() + adk.AddSessionValue(ctx, PlanSessionKey, plan) + adk.AddSessionValue(ctx, UserInputSessionKey, userInput) + planJSON, _ := sonic.MarshalString(plan) + msg := schema.AssistantMessage(planJSON, nil) + generator.Send(adk.EventFromMessage(msg, nil, schema.Assistant, "")) + generator.Close() + return iterator + }, + ).Times(1) + + // Executor: uses a slow model that we can cancel + executorStarted := make(chan struct{}) + slowModel := &slowChatModel{ + delay: 5 * time.Second, + response: schema.AssistantMessage("step result", nil), + startedChan: executorStarted, + } + + executor, err := NewExecutor(ctx, &ExecutorConfig{ + Model: slowModel, + MaxIterations: 5, + }) + assert.NoError(t, err) + + // Replanner: should not be reached since we cancel during executor + mockReplanner := mockAdk.NewMockAgent(ctrl) + mockReplanner.EXPECT().Name(gomock.Any()).Return("replanner").AnyTimes() + mockReplanner.EXPECT().Description(gomock.Any()).Return("a replanner agent").AnyTimes() + + agent, err := New(ctx, &Config{ + Planner: mockPlanner, + Executor: executor, + Replanner: mockReplanner, + MaxIterations: 5, + }) + assert.NoError(t, err) + + runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: agent}) + + cancelOpt, cancelFn := adk.WithCancel() + iter := runner.Run(ctx, userInput, cancelOpt) + + // Wait for the executor's model to start + select { + case <-executorStarted: + case <-time.After(10 * time.Second): + t.Fatal("Executor model did not start") + } + + time.Sleep(50 * time.Millisecond) + + // Cancel should NOT return ErrExecutionCompleted + handle, _ := cancelFn() + err = handle.Wait() + assert.NoError(t, err, "Cancel during executor should succeed") + + hasCancelError := false + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *adk.CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + hasCancelError = true + } + } + + assert.True(t, hasCancelError, "Should have CancelError event") +} + +// TestWithCancel_PlanExecute_BetweenTransitions verifies that cancel works +// when fired between agent transitions (e.g., after planner, before executor starts). +func TestWithCancel_PlanExecute_BetweenTransitions(t *testing.T) { + ctx := context.Background() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + plannerDone := make(chan struct{}) + + // Planner: signals when done + mockPlanner := mockAdk.NewMockAgent(ctrl) + mockPlanner.EXPECT().Name(gomock.Any()).Return("planner").AnyTimes() + mockPlanner.EXPECT().Description(gomock.Any()).Return("a planner agent").AnyTimes() + + plan := &defaultPlan{Steps: []string{"Step 1"}} + userInput := []adk.Message{schema.UserMessage("test task")} + + mockPlanner.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, input *adk.AgentInput, opts ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { + iterator, generator := adk.NewAsyncIteratorPair[*adk.AgentEvent]() + go func() { + defer generator.Close() + adk.AddSessionValue(ctx, PlanSessionKey, plan) + adk.AddSessionValue(ctx, UserInputSessionKey, userInput) + planJSON, _ := sonic.MarshalString(plan) + msg := schema.AssistantMessage(planJSON, nil) + generator.Send(adk.EventFromMessage(msg, nil, schema.Assistant, "")) + close(plannerDone) + }() + return iterator + }, + ).Times(1) + + // Executor: slow model to give time to observe cancel + executorModelStarted := make(chan struct{}) + slowExecModel := &slowChatModel{ + delay: 5 * time.Second, + response: schema.AssistantMessage("step result", nil), + startedChan: executorModelStarted, + } + + executor, err := NewExecutor(ctx, &ExecutorConfig{ + Model: slowExecModel, + MaxIterations: 5, + }) + assert.NoError(t, err) + + mockReplanner := mockAdk.NewMockAgent(ctrl) + mockReplanner.EXPECT().Name(gomock.Any()).Return("replanner").AnyTimes() + mockReplanner.EXPECT().Description(gomock.Any()).Return("a replanner agent").AnyTimes() + + agent, err := New(ctx, &Config{ + Planner: mockPlanner, + Executor: executor, + Replanner: mockReplanner, + MaxIterations: 5, + }) + assert.NoError(t, err) + + runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: agent}) + + cancelOpt, cancelFn := adk.WithCancel() + iter := runner.Run(ctx, userInput, cancelOpt) + + // Wait for planner to finish, then cancel before executor has a chance to produce output + select { + case <-plannerDone: + case <-time.After(10 * time.Second): + t.Fatal("Planner did not finish") + } + + // Cancel after planner, during executor phase + // The executor is a ChatModelAgent which will handle the cancel + select { + case <-executorModelStarted: + case <-time.After(10 * time.Second): + t.Fatal("Executor model did not start") + } + + start := time.Now() + handle, _ := cancelFn() + err = handle.Wait() + assert.NoError(t, err, "Cancel between transitions should succeed") + + hasCancelError := false + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *adk.CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + hasCancelError = true + } + } + elapsed := time.Since(start) + + assert.True(t, hasCancelError, "Should have CancelError event") + assert.True(t, elapsed < 3*time.Second, "Should complete quickly after cancel, elapsed: %v", elapsed) +} diff --git a/adk/react.go b/adk/react.go index 0aec5d94c..07fdbde9a 100644 --- a/adk/react.go +++ b/adk/react.go @@ -22,7 +22,6 @@ import ( "encoding/gob" "errors" "io" - "sync/atomic" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/compose" @@ -82,6 +81,7 @@ func init() { // when decoding checkpoints created by v0.8.0 - v0.8.3 gob.Register(&AgentEvent{}) gob.Register(int(0)) + schema.RegisterName[*reactInput]("_eino_adk_react_input") } func (s *State) getReturnDirectlyEvent() *AgentEvent { @@ -238,7 +238,7 @@ func SendToolGenAction(ctx context.Context, toolName string, action *AgentAction } type reactInput struct { - messages []Message + Messages []Message } type reactConfig struct { @@ -254,6 +254,8 @@ type reactConfig struct { agentName string maxIterations int + + cancelCtx *cancelContext } func genToolInfos(ctx context.Context, config *compose.ToolsNodeConfig) ([]*schema.ToolInfo, error) { @@ -271,8 +273,6 @@ func genToolInfos(ctx context.Context, config *compose.ToolsNodeConfig) ([]*sche } type reactGraph = *compose.Graph[*reactInput, Message] -type sToolNodeOutput = *schema.StreamReader[[]Message] -type sGraphOutput = MessageStream func getReturnDirectlyToolCallID(ctx context.Context) (string, bool) { var toolCallID string @@ -300,37 +300,25 @@ func genReactState(config *reactConfig) func(ctx context.Context) *State { } } -func newReact(ctx context.Context, config *reactConfig, cs *cancelSig) (reactGraph, error) { +func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { const ( - initNode_ = "Init" - chatModel_ = "ChatModel" - beforeToolNode_ = "BeforeToolNode" - toolNode_ = "ToolNode" - afterToolNode_ = "AfterToolNode" + initNode_ = "Init" + chatModel_ = "ChatModel" + cancelCheckNode_ = "CancelCheck" + toolNode_ = "ToolNode" + afterToolCallsNode_ = "AfterToolCalls" + afterToolCallsCancelCheckNode_ = "AfterToolCallsCancelCheck" ) - checkCancel := cs != nil - - nodeNameAfterModel := func() string { - if checkCancel { - return beforeToolNode_ - } - return toolNode_ - } - - nodeNameAfterTool := func() string { - if checkCancel { - return afterToolNode_ - } - return chatModel_ - } - + cancelCtx := config.cancelCtx g := compose.NewGraph[*reactInput, Message](compose.WithGenLocalState(genReactState(config))) - - initLambda := func(ctx context.Context, input *reactInput) ([]Message, error) { - return input.messages, nil - } - _ = g.AddLambdaNode(initNode_, compose.InvokableLambda(initLambda), compose.WithNodeName(initNode_)) + _ = g.AddLambdaNode(initNode_, compose.InvokableLambda(func(ctx context.Context, input *reactInput) ([]Message, error) { + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + st.Messages = append(st.Messages, input.Messages...) + return nil + }) + return input.Messages, nil + }), compose.WithNodeName(initNode_)) var wrappedModel model.BaseChatModel = config.model if config.modelWrapperConf != nil { @@ -338,39 +326,43 @@ func newReact(ctx context.Context, config *reactConfig, cs *cancelSig) (reactGra } toolsConfig := config.toolsConfig - if checkCancel { - wrappedModel = wrapModelForCancelable(wrappedModel, cs) - tcMWs := make([]compose.ToolMiddleware, 0, len(toolsConfig.ToolCallMiddlewares)+1) - tcMWs = append(tcMWs, cancelableTool(cs)) - tcMWs = append(tcMWs, toolsConfig.ToolCallMiddlewares...) - toolsConfigCopy := *toolsConfig - toolsConfigCopy.ToolCallMiddlewares = tcMWs - toolsConfig = &toolsConfigCopy - } toolsNode, err := compose.NewToolNode(ctx, toolsConfig) if err != nil { return nil, err } - modelPreHandle := func(ctx context.Context, input []Message, st *State) ([]Message, error) { - if st.getRemainingIterations() <= 0 { - return nil, ErrExceedMaxIterations + _ = g.AddChatModelNode(chatModel_, wrappedModel, compose.WithStatePreHandler( + func(ctx context.Context, input []Message, st *State) ([]Message, error) { + if st.getRemainingIterations() <= 0 { + return nil, ErrExceedMaxIterations + } + st.decrementRemainingIterations() + return input, nil + }), compose.WithNodeName(chatModel_)) + + // CancelAfterChatModel safe-point: on the tool-calls path, after the branch + // has confirmed that the model response contains tool calls (i.e. not a final + // answer). Skipped entirely when the model produces a final answer. + _ = g.AddLambdaNode(cancelCheckNode_, compose.InvokableLambda(func(ctx context.Context, msg Message) (Message, error) { + if cancelCtx != nil && cancelCtx.shouldCancel() { + if cancelCtx.getMode()&CancelAfterChatModel != 0 { + return nil, compose.StatefulInterrupt(ctx, "CancelAfterChatModel", msg) + } } - st.decrementRemainingIterations() - return input, nil - } - _ = g.AddChatModelNode(chatModel_, wrappedModel, - compose.WithStatePreHandler(modelPreHandle), compose.WithNodeName(chatModel_)) + wasInterrupted, hasState, state := compose.GetInterruptState[Message](ctx) + if wasInterrupted && hasState { + msg = state + } + return msg, nil + }), compose.WithNodeName(cancelCheckNode_)) toolPreHandle := func(ctx context.Context, _ Message, st *State) (Message, error) { input := st.Messages[len(st.Messages)-1] - returnDirectly := config.toolsReturnDirectly if execCtx := getChatModelAgentExecCtx(ctx); execCtx != nil && len(execCtx.runtimeReturnDirectly) > 0 { returnDirectly = execCtx.runtimeReturnDirectly } - if len(returnDirectly) > 0 { for i := range input.ToolCalls { toolName := input.ToolCalls[i].Function.Name @@ -379,10 +371,8 @@ func newReact(ctx context.Context, config *reactConfig, cs *cancelSig) (reactGra } } } - return input, nil } - toolPostHandle := func(ctx context.Context, out *schema.StreamReader[[]*schema.Message], st *State) (*schema.StreamReader[[]*schema.Message], error) { if event := st.getReturnDirectlyEvent(); event != nil { getChatModelAgentExecCtx(ctx).send(event) @@ -390,36 +380,62 @@ func newReact(ctx context.Context, config *reactConfig, cs *cancelSig) (reactGra } return out, nil } - _ = g.AddToolsNode(toolNode_, toolsNode, compose.WithStatePreHandler(toolPreHandle), compose.WithStreamStatePostHandler(toolPostHandle), compose.WithNodeName(toolNode_)) - _ = g.AddEdge(compose.START, initNode_) - _ = g.AddEdge(initNode_, chatModel_) - - if checkCancel { - beforeToolNode := func(ctx context.Context, input Message) (output Message, err error) { - if sig := checkCancelSig(cs); sig != nil && sig.Mode != CancelAfterToolCall { - return nil, compose.Interrupt(ctx, "cancelled externally") + // AfterToolCalls node: calls AfterToolCallsRewriteState handlers after all tool calls complete. + // The graph auto-materializes the ToolsNode stream into []Message before this node. + afterToolCalls := func(ctx context.Context, toolResults []Message) ([]Message, error) { + var stateMessages []Message + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + stateMessages = st.Messages + return nil + }) + + state := &ChatModelAgentState{Messages: append(stateMessages, toolResults...)} + + if config.modelWrapperConf != nil { + assistantMsg := stateMessages[len(stateMessages)-1] + tc := &ToolCallsContext{} + for _, toolCall := range assistantMsg.ToolCalls { + tc.ToolCalls = append(tc.ToolCalls, ToolContext{Name: toolCall.Function.Name, CallID: toolCall.ID}) } - return input, nil + for _, handler := range config.modelWrapperConf.handlers { + var err error + ctx, state, err = handler.AfterToolCallsRewriteState(ctx, state, tc) + if err != nil { + return nil, err + } + } } - _ = g.AddLambdaNode(beforeToolNode_, compose.InvokableLambda(beforeToolNode), compose.WithNodeName(beforeToolNode_)) - g.AddEdge(beforeToolNode_, toolNode_) - afterToolNode := func(ctx context.Context, input []Message) (output []Message, err error) { - if sig := checkCancelSig(cs); sig != nil && sig.Mode != CancelAfterChatModel { - return nil, compose.Interrupt(ctx, "cancelled externally") - } + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + st.Messages = state.Messages + return nil + }) - return input, nil + return toolResults, nil + } + _ = g.AddLambdaNode(afterToolCallsNode_, compose.InvokableLambda(afterToolCalls), + compose.WithNodeName(afterToolCallsNode_)) + + // AfterToolCallsCancelCheck: CancelAfterToolCalls safe-point, separated from toolPostHandle. + afterToolCallsCancelCheck := func(ctx context.Context, toolResults []Message) ([]Message, error) { + if cancelCtx != nil && cancelCtx.shouldCancel() { + if cancelCtx.getMode()&CancelAfterToolCalls != 0 { + return nil, compose.Interrupt(ctx, "CancelAfterToolCalls") + } } - _ = g.AddLambdaNode(afterToolNode_, compose.InvokableLambda(afterToolNode), compose.WithNodeName(afterToolNode_)) - g.AddEdge(afterToolNode_, chatModel_) + return toolResults, nil } + _ = g.AddLambdaNode(afterToolCallsCancelCheckNode_, compose.InvokableLambda(afterToolCallsCancelCheck), + compose.WithNodeName(afterToolCallsCancelCheckNode_)) + + _ = g.AddEdge(compose.START, initNode_) + _ = g.AddEdge(initNode_, chatModel_) toolCallCheck := func(ctx context.Context, sMsg MessageStream) (string, error) { defer sMsg.Close() @@ -434,84 +450,54 @@ func newReact(ctx context.Context, config *reactConfig, cs *cancelSig) (reactGra } if len(chunk.ToolCalls) > 0 { - return nodeNameAfterModel(), nil + return cancelCheckNode_, nil } } } - branch := compose.NewStreamGraphBranch(toolCallCheck, map[string]bool{compose.END: true, nodeNameAfterModel(): true}) + branch := compose.NewStreamGraphBranch(toolCallCheck, map[string]bool{compose.END: true, cancelCheckNode_: true}) _ = g.AddBranch(chatModel_, branch) + _ = g.AddEdge(cancelCheckNode_, toolNode_) + _ = g.AddEdge(toolNode_, afterToolCallsNode_) + _ = g.AddEdge(afterToolCallsNode_, afterToolCallsCancelCheckNode_) + if len(config.toolsReturnDirectly) > 0 { const ( toolNodeToEndConverter = "ToolNodeToEndConverter" ) - cvt := func(ctx context.Context, sToolCallMessages sToolNodeOutput) (sGraphOutput, error) { + cvt := func(ctx context.Context, toolResults []Message) (Message, error) { id, _ := getReturnDirectlyToolCallID(ctx) - return schema.StreamReaderWithConvert(sToolCallMessages, - func(in []Message) (Message, error) { - - for _, chunk := range in { - if chunk != nil && chunk.ToolCallID == id { - return chunk, nil - } - } + for _, msg := range toolResults { + if msg != nil && msg.ToolCallID == id { + return msg, nil + } + } - return nil, schema.ErrNoValue - }), nil + return nil, errors.New("return directly tool call result not found") } - _ = g.AddLambdaNode(toolNodeToEndConverter, compose.TransformableLambda(cvt), + _ = g.AddLambdaNode(toolNodeToEndConverter, compose.InvokableLambda(cvt), compose.WithNodeName(toolNodeToEndConverter)) _ = g.AddEdge(toolNodeToEndConverter, compose.END) - checkReturnDirect := func(ctx context.Context, - sToolCallMessages sToolNodeOutput) (string, error) { - + checkReturnDirect := func(ctx context.Context, toolResults []Message) (string, error) { _, ok := getReturnDirectlyToolCallID(ctx) if ok { return toolNodeToEndConverter, nil } - return nodeNameAfterTool(), nil + return chatModel_, nil } - branch = compose.NewStreamGraphBranch(checkReturnDirect, - map[string]bool{toolNodeToEndConverter: true, nodeNameAfterTool(): true}) - _ = g.AddBranch(toolNode_, branch) + returnDirectBranch := compose.NewGraphBranch(checkReturnDirect, + map[string]bool{toolNodeToEndConverter: true, chatModel_: true}) + _ = g.AddBranch(afterToolCallsCancelCheckNode_, returnDirectBranch) } else { - _ = g.AddEdge(toolNode_, nodeNameAfterTool()) + _ = g.AddEdge(afterToolCallsCancelCheckNode_, chatModel_) } return g, nil } - -type cancelSig struct { - done chan struct{} - config atomic.Value -} - -func newCancelSig() *cancelSig { - return &cancelSig{ - done: make(chan struct{}), - } -} - -func (cs *cancelSig) cancel(cfg *cancelConfig) { - cs.config.Store(cfg) - close(cs.done) -} - -func checkCancelSig(cs *cancelSig) *cancelConfig { - if cs == nil { - return nil - } - select { - case <-cs.done: - return cs.config.Load().(*cancelConfig) - default: - return nil - } -} diff --git a/adk/react_test.go b/adk/react_test.go index 969e73e35..b0a6c3985 100644 --- a/adk/react_test.go +++ b/adk/react_test.go @@ -23,6 +23,7 @@ import ( "errors" "fmt" "io" + "math" "math/rand" "testing" @@ -144,16 +145,16 @@ func TestReact(t *testing.T) { toolsReturnDirectly: map[string]bool{}, } - graph, err := newReact(ctx, config, nil) + graph, err := newReact(ctx, config) assert.NoError(t, err) assert.NotNil(t, graph) - compiled, err := graph.Compile(ctx) + compiled, err := graph.Compile(ctx, compose.WithMaxRunSteps(math.MaxInt)) assert.NoError(t, err) assert.NotNil(t, compiled) // Test with a user message - result, err := compiled.Invoke(ctx, &reactInput{messages: []Message{ + result, err := compiled.Invoke(ctx, &reactInput{Messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", @@ -211,16 +212,16 @@ func TestReact(t *testing.T) { toolsReturnDirectly: map[string]bool{info.Name: true}, } - graph, err := newReact(ctx, config, nil) + graph, err := newReact(ctx, config) assert.NoError(t, err) assert.NotNil(t, graph) - compiled, err := graph.Compile(ctx) + compiled, err := graph.Compile(ctx, compose.WithMaxRunSteps(math.MaxInt)) assert.NoError(t, err) assert.NotNil(t, compiled) // Test with a user message when tool returns directly - result, err := compiled.Invoke(ctx, &reactInput{messages: []Message{ + result, err := compiled.Invoke(ctx, &reactInput{Messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", @@ -303,16 +304,16 @@ func TestReact(t *testing.T) { toolsReturnDirectly: map[string]bool{}, } - graph, err := newReact(ctx, config, nil) + graph, err := newReact(ctx, config) assert.NoError(t, err) assert.NotNil(t, graph) - compiled, err := graph.Compile(ctx) + compiled, err := graph.Compile(ctx, compose.WithMaxRunSteps(math.MaxInt)) assert.NoError(t, err) assert.NotNil(t, compiled) // Test streaming with a user message - outStream, err := compiled.Stream(ctx, &reactInput{messages: []Message{ + outStream, err := compiled.Stream(ctx, &reactInput{Messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", @@ -413,11 +414,11 @@ func TestReact(t *testing.T) { toolsReturnDirectly: map[string]bool{streamInfo.Name: true}, } - graph, err := newReact(ctx, config, nil) + graph, err := newReact(ctx, config) assert.NoError(t, err) assert.NotNil(t, graph) - compiled, err := graph.Compile(ctx) + compiled, err := graph.Compile(ctx, compose.WithMaxRunSteps(math.MaxInt)) assert.NoError(t, err) assert.NotNil(t, compiled) @@ -425,7 +426,7 @@ func TestReact(t *testing.T) { times = 0 // Test streaming with a user message when tool returns directly - outStream, err := compiled.Stream(ctx, &reactInput{messages: []Message{ + outStream, err := compiled.Stream(ctx, &reactInput{Messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", @@ -502,16 +503,16 @@ func TestReact(t *testing.T) { maxIterations: 6, } - graph, err := newReact(ctx, config, nil) + graph, err := newReact(ctx, config) assert.NoError(t, err) assert.NotNil(t, graph) - compiled, err := graph.Compile(ctx) + compiled, err := graph.Compile(ctx, compose.WithMaxRunSteps(math.MaxInt)) assert.NoError(t, err) assert.NotNil(t, compiled) // Test with a user message - result, err := compiled.Invoke(ctx, &reactInput{messages: []Message{ + result, err := compiled.Invoke(ctx, &reactInput{Messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", @@ -532,16 +533,16 @@ func TestReact(t *testing.T) { maxIterations: 5, } - graph, err = newReact(ctx, config, nil) + graph, err = newReact(ctx, config) assert.NoError(t, err) assert.NotNil(t, graph) - compiled, err = graph.Compile(ctx) + compiled, err = graph.Compile(ctx, compose.WithMaxRunSteps(math.MaxInt)) assert.NoError(t, err) assert.NotNil(t, compiled) // Test with a user message - result, err = compiled.Invoke(ctx, &reactInput{messages: []Message{ + result, err = compiled.Invoke(ctx, &reactInput{Messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", diff --git a/adk/retry_chatmodel.go b/adk/retry_chatmodel.go index 8ae4e2aac..bac955033 100644 --- a/adk/retry_chatmodel.go +++ b/adk/retry_chatmodel.go @@ -196,6 +196,11 @@ func (r *retryModelWrapper) Generate(ctx context.Context, input []*schema.Messag return out, nil } + // Never retry interrupt errors (e.g. cancel safe-point interrupts). + if _, ok := compose.ExtractInterruptInfo(err); ok { + return nil, err + } + if !isRetryAble(ctx, err) { return nil, err } @@ -238,6 +243,10 @@ func (r *retryModelWrapper) Stream(ctx context.Context, input []*schema.Message, stream, err := r.inner.Stream(ctx, input, opts...) if err != nil { + // Never retry interrupt errors (e.g. cancel safe-point interrupts). + if _, ok := compose.ExtractInterruptInfo(err); ok { + return nil, err + } if !isRetryAble(ctx, err) { return nil, err } diff --git a/adk/runctx_test.go b/adk/runctx_test.go index 7f164b3e2..bef1f44eb 100644 --- a/adk/runctx_test.go +++ b/adk/runctx_test.go @@ -17,7 +17,10 @@ package adk import ( + "bytes" "context" + "encoding/gob" + "errors" "testing" "time" @@ -423,3 +426,209 @@ func TestForkJoinRunCtx(t *testing.T) { mainRunCtx.Session.addEvent(eventF) assert.Equal(t, []string{"A", "B", "C1", "D", "E", "F"}, getEventNames(mainRunCtx.Session.getEvents()), "After F") } + +// makeStreamingEventWrapper creates an agentEventWrapper with a streaming MessageOutput +// whose stream yields the given message then terminates with streamErr (or io.EOF if nil). +func makeStreamingEventWrapper(msg Message, streamErr error) *agentEventWrapper { + r, w := schema.Pipe[Message](2) + w.Send(msg, nil) + if streamErr != nil { + w.Send(nil, streamErr) + } + w.Close() + + return &agentEventWrapper{ + AgentEvent: &AgentEvent{ + AgentName: "test-agent", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: true, + MessageStream: r, + Role: schema.Assistant, + }, + }, + }, + } +} + +func TestGobEncodeStreamErrors(t *testing.T) { + t.Run("WillRetryError_unconsumed_stream_fails_GobEncode", func(t *testing.T) { + // An agentEventWrapper whose stream yields a message then WillRetryError. + // Without pre-consuming (no getMessageFromWrappedEvent call), GobEncode + // reaches MessageVariant.GobEncode which treats non-EOF errors as fatal. + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("partial", nil), + &WillRetryError{ErrStr: "model error", RetryAttempt: 1}, + ) + + _, err := wrapper.GobEncode() + assert.NoError(t, err, "GobEncode should handle WillRetryError streams gracefully") + }) + + t.Run("ErrStreamCanceled_unconsumed_stream_fails_GobEncode", func(t *testing.T) { + // Same scenario but with ErrStreamCanceled (*errors.errorString). + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("partial", nil), + ErrStreamCanceled, + ) + + _, err := wrapper.GobEncode() + assert.NoError(t, err, "GobEncode should handle ErrStreamCanceled streams gracefully") + }) + + t.Run("successful_stream_GobEncode_succeeds", func(t *testing.T) { + // Control: a clean stream (no error) should encode fine. + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("hello", nil), + nil, // no stream error + ) + + data, err := wrapper.GobEncode() + assert.NoError(t, err) + assert.NotEmpty(t, data) + + // Verify round-trip decode works. + decoded := &agentEventWrapper{AgentEvent: &AgentEvent{}} + err = decoded.GobDecode(data) + assert.NoError(t, err) + assert.Equal(t, "test-agent", decoded.AgentName) + }) + + t.Run("preconsumed_WillRetryError_GobEncode_succeeds", func(t *testing.T) { + // When getMessageFromWrappedEvent is called first, WillRetryError is + // cached in StreamErr and the stream is replaced with an error-free array. + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("partial", nil), + &WillRetryError{ErrStr: "model error", RetryAttempt: 1}, + ) + + _, consumeErr := getMessageFromWrappedEvent(wrapper) + assert.Error(t, consumeErr) + + data, err := wrapper.GobEncode() + assert.NoError(t, err, "GobEncode should succeed after pre-consuming WillRetryError stream") + assert.NotEmpty(t, data) + }) + + t.Run("preconsumed_ErrStreamCanceled_GobEncode_succeeds", func(t *testing.T) { + // ErrStreamCanceled is a *StreamCanceledError which IS gob-registered. + // After getMessageFromWrappedEvent, StreamErr = ErrStreamCanceled. + // Since it's registered, gob encoding succeeds. + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("partial", nil), + ErrStreamCanceled, + ) + + _, consumeErr := getMessageFromWrappedEvent(wrapper) + assert.Error(t, consumeErr) + + data, err := wrapper.GobEncode() + assert.NoError(t, err, "GobEncode should succeed; ErrStreamCanceled is gob-registered") + assert.NotEmpty(t, data) + }) + + t.Run("GobEncode_roundtrip_preserves_content", func(t *testing.T) { + // Verify that after GobEncode with a WillRetryError stream, + // the decoded wrapper has the partial message content and StreamErr intact. + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("partial response", nil), + &WillRetryError{ErrStr: "err", RetryAttempt: 1}, + ) + + data, err := wrapper.GobEncode() + assert.NoError(t, err) + + decoded := &agentEventWrapper{AgentEvent: &AgentEvent{}} + err = decoded.GobDecode(data) + assert.NoError(t, err) + assert.Equal(t, "test-agent", decoded.AgentName) + assert.True(t, decoded.Output.MessageOutput.IsStreaming) + // The stream should be consumable and yield the partial message. + msg, recvErr := decoded.Output.MessageOutput.MessageStream.Recv() + assert.NoError(t, recvErr) + assert.Contains(t, msg.Content, "partial response") + // StreamErr should be preserved for end-user visibility. + var willRetryErr *WillRetryError + assert.True(t, errors.As(decoded.StreamErr, &willRetryErr)) + assert.Equal(t, "err", willRetryErr.ErrStr) + }) + + t.Run("GobEncode_roundtrip_preserves_ErrStreamCanceled", func(t *testing.T) { + // ErrStreamCanceled (*StreamCanceledError) is gob-registered, so + // StreamErr should survive encoding/decoding. + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("partial", nil), + ErrStreamCanceled, + ) + + data, err := wrapper.GobEncode() + assert.NoError(t, err) + + decoded := &agentEventWrapper{AgentEvent: &AgentEvent{}} + err = decoded.GobDecode(data) + assert.NoError(t, err) + var streamCanceledErr *StreamCanceledError + assert.ErrorAs(t, decoded.StreamErr, &streamCanceledErr) + }) + + t.Run("GobEncode_idempotent", func(t *testing.T) { + // Calling GobEncode twice should succeed both times (stream replaced on first call). + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("hello", nil), + &WillRetryError{ErrStr: "err", RetryAttempt: 1}, + ) + + data1, err := wrapper.GobEncode() + assert.NoError(t, err) + + data2, err := wrapper.GobEncode() + assert.NoError(t, err) + + // Both should decode to equivalent content. + d1, d2 := &agentEventWrapper{AgentEvent: &AgentEvent{}}, &agentEventWrapper{AgentEvent: &AgentEvent{}} + assert.NoError(t, d1.GobDecode(data1)) + assert.NoError(t, d2.GobDecode(data2)) + assert.Equal(t, d1.AgentName, d2.AgentName) + }) + + t.Run("GobEncode_non_streaming_unaffected", func(t *testing.T) { + // Non-streaming events should encode/decode as before. + wrapper := &agentEventWrapper{ + AgentEvent: &AgentEvent{ + AgentName: "non-stream-agent", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: false, + Message: schema.AssistantMessage("direct", nil), + Role: schema.Assistant, + }, + }, + }, + } + + data, err := wrapper.GobEncode() + assert.NoError(t, err) + + decoded := &agentEventWrapper{AgentEvent: &AgentEvent{}} + assert.NoError(t, decoded.GobDecode(data)) + assert.Equal(t, "non-stream-agent", decoded.AgentName) + assert.False(t, decoded.Output.MessageOutput.IsStreaming) + }) + + t.Run("GobEncode_within_runSession", func(t *testing.T) { + // Simulate the real scenario: a runSession with a streaming event containing + // WillRetryError is gob-encoded (as happens during checkpoint save). + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("checkpoint content", nil), + &WillRetryError{ErrStr: "retry", RetryAttempt: 1}, + ) + + session := newRunSession() + session.Events = []*agentEventWrapper{wrapper} + + // Encode the entire session (the checkpoint path). + var buf bytes.Buffer + err := gob.NewEncoder(&buf).Encode(session) + assert.NoError(t, err, "encoding runSession with WillRetryError stream should succeed") + }) +} diff --git a/adk/runner.go b/adk/runner.go index a9fd9b94d..4881122a6 100644 --- a/adk/runner.go +++ b/adk/runner.go @@ -18,6 +18,7 @@ package adk import ( "context" + "errors" "fmt" "runtime/debug" "sync" @@ -41,6 +42,8 @@ type Runner struct { type CheckPointStore = core.CheckPointStore +type CheckPointDeleter = core.CheckPointDeleter + type RunnerConfig struct { Agent Agent EnableStreaming bool @@ -74,31 +77,6 @@ func NewRunner(_ context.Context, conf RunnerConfig) *Runner { // upon interruption. func (r *Runner) Run(ctx context.Context, messages []Message, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { - iter, _, _ := r.runWithCancel(ctx, messages, false, opts...) - return iter -} - -// Query is a convenience method that starts a new execution with a single user query string. -func (r *Runner) Query(ctx context.Context, - query string, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { - - return r.Run(ctx, []Message{schema.UserMessage(query)}, opts...) -} - -// RunWithCancel starts a new execution of the agent and returns both an iterator and a cancel function. -// The cancel function can be used to interrupt the running agent at specific points based on the CancelMode. -// If the Runner was configured with a CheckPointStore and WithCheckPointID option, it will automatically -// save the agent's state upon cancellation for later resumption. -// -// If the agent does not implement CancellableAgent, the returned CancelFunc will be nil. -func (r *Runner) RunWithCancel(ctx context.Context, messages []Message, - opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { - iter, cancelFn, _ := r.runWithCancel(ctx, messages, true, opts...) - return iter, cancelFn -} - -func (r *Runner) runWithCancel(ctx context.Context, messages []Message, withCancel bool, - opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc, error) { o := getCommonOptions(nil, opts...) fa := toFlowAgent(ctx, r.a) @@ -112,46 +90,23 @@ func (r *Runner) runWithCancel(ctx context.Context, messages []Message, withCanc AddSessionValues(ctx, o.sessionValues) - var iter *AsyncIterator[*AgentEvent] - var cancelFn CancelFunc - if withCancel { - if _, ok := r.a.(CancellableAgent); ok { - iter, cancelFn = fa.RunWithCancel(ctx, input, opts...) - } else { - iter = fa.Run(ctx, input, opts...) - } - } else { - iter = fa.Run(ctx, input, opts...) - } + iter := fa.Run(ctx, input, opts...) - if r.store == nil { - return iter, cancelFn, nil + if r.store == nil && o.cancelCtx == nil { + return iter } niter, gen := NewAsyncIteratorPair[*AgentEvent]() - go r.handleIter(ctx, iter, gen, o.checkPointID) - return niter, cancelFn, nil + go r.handleIter(ctx, iter, gen, o.checkPointID, o.cancelCtx) + return niter } -// ResumeWithCancel continues an interrupted execution from a checkpoint and returns both an iterator and a cancel function. -// This method uses the "Implicit Resume All" strategy where all previously interrupted points proceed without specific data. -// The cancel function can be used to interrupt the running agent again at specific points based on the CancelMode. -// -// If the agent does not implement CancellableResumableAgent, the returned CancelFunc will be nil. -func (r *Runner) ResumeWithCancel(ctx context.Context, checkPointID string, opts ...AgentRunOption) ( - *AsyncIterator[*AgentEvent], CancelFunc, error) { - return r.resumeWithCancel(ctx, checkPointID, nil, true, opts...) -} +// Query is a convenience method that starts a new execution with a single user query string. +func (r *Runner) Query(ctx context.Context, + query string, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { -// ResumeWithParamsAndCancel continues an interrupted execution from a checkpoint with specific parameters -// and returns both an iterator and a cancel function. -// The params.Targets map should contain the addresses of the components to be resumed as keys. -// -// If the agent does not implement CancellableResumableAgent, the returned CancelFunc will be nil. -func (r *Runner) ResumeWithParamsAndCancel(ctx context.Context, checkPointID string, params *ResumeParams, - opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc, error) { - return r.resumeWithCancel(ctx, checkPointID, params.Targets, true, opts...) + return r.Run(ctx, []Message{schema.UserMessage(query)}, opts...) } // Resume continues an interrupted execution from a checkpoint, using an "Implicit Resume All" strategy. @@ -163,8 +118,7 @@ func (r *Runner) ResumeWithParamsAndCancel(ctx context.Context, checkPointID str // pattern where an agent only needs to know `wasInterrupted` is true to continue. func (r *Runner) Resume(ctx context.Context, checkPointID string, opts ...AgentRunOption) ( *AsyncIterator[*AgentEvent], error) { - iter, _, err := r.resumeWithCancel(ctx, checkPointID, nil, false, opts...) - return iter, err + return r.resumeInternal(ctx, checkPointID, nil, opts...) } // ResumeWithParams continues an interrupted execution from a checkpoint with specific parameters. @@ -186,19 +140,18 @@ func (r *Runner) Resume(ctx context.Context, checkPointID string, opts ...AgentR // naturally re-interrupt if one of their interrupted children re-interrupts, as they receive the // new `CompositeInterrupt` signal from them. func (r *Runner) ResumeWithParams(ctx context.Context, checkPointID string, params *ResumeParams, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], error) { - iter, _, err := r.resumeWithCancel(ctx, checkPointID, params.Targets, false, opts...) - return iter, err + return r.resumeInternal(ctx, checkPointID, params.Targets, opts...) } -func (r *Runner) resumeWithCancel(ctx context.Context, checkPointID string, resumeData map[string]any, - withCancel bool, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc, error) { +func (r *Runner) resumeInternal(ctx context.Context, checkPointID string, resumeData map[string]any, + opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], error) { if r.store == nil { - return nil, nil, fmt.Errorf("failed to resume: store is nil") + return nil, fmt.Errorf("failed to resume: store is nil") } ctx, runCtx, resumeInfo, err := r.loadCheckPoint(ctx, checkPointID) if err != nil { - return nil, nil, fmt.Errorf("failed to load from checkpoint: %w", err) + return nil, fmt.Errorf("failed to load from checkpoint: %w", err) } o := getCommonOptions(nil, opts...) @@ -226,30 +179,20 @@ func (r *Runner) resumeWithCancel(ctx context.Context, checkPointID string, resu fa := toFlowAgent(ctx, r.a) - var aIter *AsyncIterator[*AgentEvent] - var cancelFn CancelFunc - if withCancel { - if _, ok := r.a.(CancellableResumableAgent); ok { - aIter, cancelFn = fa.ResumeWithCancel(ctx, resumeInfo, opts...) - } else { - aIter = fa.Resume(ctx, resumeInfo, opts...) - } - } else { - aIter = fa.Resume(ctx, resumeInfo, opts...) - } + aIter := fa.Resume(ctx, resumeInfo, opts...) - if r.store == nil { - return aIter, cancelFn, nil + if r.store == nil && o.cancelCtx == nil { + return aIter, nil } niter, gen := NewAsyncIteratorPair[*AgentEvent]() - go r.handleIter(ctx, aIter, gen, &checkPointID) - return niter, cancelFn, nil + go r.handleIter(ctx, aIter, gen, &checkPointID, o.cancelCtx) + return niter, nil } func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEvent], - gen *AsyncGenerator[*AgentEvent], checkPointID *string) { + gen *AsyncGenerator[*AgentEvent], checkPointID *string, cancelCtx *cancelContext) { defer func() { panicErr := recover() if panicErr != nil { @@ -269,6 +212,25 @@ func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEven break } + if event.Err != nil { + var cancelErr *CancelError + if errors.As(event.Err, &cancelErr) { + if cancelCtx != nil && cancelCtx.isRoot() && cancelCtx.shouldCancel() { + cancelCtx.markCancelHandled() + } + if cancelErr.interruptSignal != nil && checkPointID != nil { + cancelErr.CheckPointID = *checkPointID + cancelErr.InterruptContexts = core.ToInterruptContexts(cancelErr.interruptSignal, allowedAddressSegmentTypes) + err := r.saveCheckPoint(ctx, *checkPointID, &InterruptInfo{}, cancelErr.interruptSignal) + if err != nil { + gen.Send(&AgentEvent{Err: fmt.Errorf("failed to save checkpoint on cancel: %w", err)}) + } + } + gen.Send(event) + break + } + } + if event.Action != nil && event.Action.internalInterrupted != nil { if interruptSignal != nil { // even if multiple interrupt happens, they should be merged into one @@ -293,8 +255,7 @@ func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEven legacyData = event.Action.Interrupted.Data if checkPointID != nil { - // save checkpoint first before sending interrupt event, - // so when end-user receives interrupt event, they can resume from this checkpoint + // save checkpoint first before sending interrupt event, so when end-user receives interrupt event, they can resume from this checkpoint err := r.saveCheckPoint(ctx, *checkPointID, &InterruptInfo{ Data: legacyData, }, interruptSignal) diff --git a/adk/turn_loop.go b/adk/turn_loop.go index f7ec423c5..88979876e 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -17,7 +17,9 @@ package adk import ( + "bytes" "context" + "encoding/gob" "errors" "fmt" "runtime/debug" @@ -25,558 +27,1387 @@ import ( "sync/atomic" "time" + "github.com/cloudwego/eino/internal" "github.com/cloudwego/eino/internal/safe" ) -// ConsumeMode specifies how a received message should be consumed -// relative to the currently running agent. -type ConsumeMode int - -const ( - // ConsumeNonPreemptive processes the message after the current agent - // finishes. This is the default queued behavior. - ConsumeNonPreemptive ConsumeMode = iota - // ConsumePreemptive cancels the currently running agent (if it - // implements Cancellable) and processes the message immediately. - // If the agent does not implement Cancellable, the message is - // buffered and processed after the agent finishes. - ConsumePreemptive - - ConsumePreemptiveOnTimeout -) - -type consumeConfig struct { - Mode ConsumeMode - Timeout time.Duration - CancelOpts []CancelOption - CheckPointID string +// stopSignal coordinates the Stop() call with per-turn watcher goroutines. +// +// Lifecycle overview: +// +// 1. SIGNAL — Stop() calls signal() which bumps the generation counter, +// stores the AgentCancelOptions, and deposits a one-shot notification +// in the buffered notify channel. +// +// 2. DONE — Stop() calls closeDone() which permanently closes the done +// channel. This acts as a durable "stopped" flag: any current or future +// select on done fires immediately, ensuring that every watcher — +// including watchers in turns that start after Stop() but before the +// run loop observes isStopped() — can reliably detect the stop. +// +// 3. RECEIVE — The per-turn watchStopSignal goroutine selects on the done +// channel (the durable flag) and the notify channel (to detect mode +// escalation from a second Stop call). On either signal, it calls +// agentCancelFunc to cancel the running agent. +// +// The generation counter (gen) de-duplicates wakes so that the watcher only +// acts when a new Stop() call has been made, supporting mode escalation +// (e.g. CancelAfterToolCalls followed by CancelImmediate). +type stopSignal struct { + // done is closed exactly once by closeDone(). A closed channel is + // readable forever, so it serves as a durable stop flag for all watchers. + done chan struct{} + + mu sync.Mutex + gen uint64 + agentCancelOpts []AgentCancelOption + // notify is a buffered(1) channel that wakes the current turn's watcher + // when Stop() is called. Unlike done, it supports repeated Stop() calls + // for cancel-mode escalation. + notify chan struct{} } -type ConsumeOption func(*consumeConfig) - -// WithPreemptive sets the consume mode to preemptive, which cancels the -// currently running agent immediately. -func WithPreemptive() ConsumeOption { - return func(config *consumeConfig) { - config.Mode = ConsumePreemptive +func newStopSignal() *stopSignal { + return &stopSignal{ + done: make(chan struct{}), + notify: make(chan struct{}, 1), } } -// WithPreemptiveOnTimeout sets the consume mode to preemptive with a timeout. -// If the current agent does not complete within the timeout, it will be canceled. -func WithPreemptiveOnTimeout(timeout time.Duration) ConsumeOption { - return func(config *consumeConfig) { - config.Mode = ConsumePreemptiveOnTimeout - config.Timeout = timeout +// signal records a stop request and wakes the current turn's watcher (if any). +// The non-blocking send means the notification is silently coalesced when the +// buffer is already full — this is safe because gen de-duplicates in the watcher. +func (s *stopSignal) signal(cfg *stopConfig) { + s.mu.Lock() + s.gen++ + s.agentCancelOpts = cfg.agentCancelOpts + s.mu.Unlock() + select { + case s.notify <- struct{}{}: + default: } } -// WithCancelOptions appends cancel options to be used when canceling the agent. -func WithCancelOptions(opts ...CancelOption) ConsumeOption { - return func(config *consumeConfig) { - config.CancelOpts = append(config.CancelOpts, opts...) +// isStopped returns true if closeDone() has been called. +func (s *stopSignal) isStopped() bool { + select { + case <-s.done: + return true + default: + return false } } -// WithConsumeCheckPointID sets the checkpoint ID for the consumed message. -// When set, the checkpoint will be saved with this ID if an interrupt occurs. -func WithConsumeCheckPointID(id string) ConsumeOption { - return func(config *consumeConfig) { - config.CheckPointID = id +// closeDone permanently marks the stop as committed. All current and future +// selects on s.done will fire immediately after this call. +func (s *stopSignal) closeDone() { + close(s.done) +} + +// check returns the current generation and a snapshot of the cancel options. +func (s *stopSignal) check() (uint64, []AgentCancelOption) { + s.mu.Lock() + defer s.mu.Unlock() + return s.gen, append([]AgentCancelOption{}, s.agentCancelOpts...) +} + +// preemptSignal coordinates preemption between Push callers and the run loop. +// +// Lifecycle overview: +// +// 1. HOLD — A Push caller (or the run loop itself) calls holdRunLoop() to +// increment holdCount. While holdCount > 0 the run loop blocks at +// waitForPreemptOrUnhold(), preventing it from starting a new turn. +// +// 2. REQUEST — The Push caller calls requestPreempt() which sets +// preemptRequested=true, bumps preemptGen, stores cancelOpts/acks, and +// wakes both the run-loop (via cond) and the in-turn watcher goroutine +// (via notify channel). +// +// 3. RECEIVE — The per-turn watchPreemptSignal goroutine calls +// receivePreempt(), obtains the cancel opts and ack channels, invokes +// agentCancelFunc to cancel the running agent, and closes the ack +// channels to notify Push callers. +// +// 4. UNHOLD — After the turn finishes (or if the Push caller decides not +// to preempt), unholdRunLoop() / endTurnAndUnhold() decrements +// holdCount. When holdCount reaches 0, all signal state is reset. +// +// The run loop brackets every turn with holdRunLoop() / endTurnAndUnhold() +// so that a concurrent Push caller's hold keeps holdCount > 0 even after +// the turn ends, preventing the loop from racing into a new turn before +// the Push caller's preempt request is delivered. +// +// Fields currentTC and currentRunCtx are stored here (rather than on +// TurnLoop) so that holdAndGetTurn() can atomically snapshot the turn +// state and increment holdCount under the same mu lock, eliminating the +// TOCTOU race between reading the turn and holding the loop. +type preemptSignal struct { + mu sync.Mutex + cond *sync.Cond + holdCount int + preemptRequested bool + preemptGen uint64 + agentCancelOpts []AgentCancelOption + pendingAckList []chan struct{} + notify chan struct{} + + currentTC any + currentRunCtx context.Context +} + +func newPreemptSignal() *preemptSignal { + s := &preemptSignal{notify: make(chan struct{}, 1)} + s.cond = sync.NewCond(&s.mu) + return s +} + +func (s *preemptSignal) holdRunLoop() { + s.mu.Lock() + s.holdCount++ + s.mu.Unlock() +} + +func (s *preemptSignal) setTurn(ctx context.Context, tc any) { + s.mu.Lock() + s.currentRunCtx = ctx + s.currentTC = tc + s.mu.Unlock() +} + +func (s *preemptSignal) holdAndGetTurn() (context.Context, any) { + s.mu.Lock() + defer s.mu.Unlock() + s.holdCount++ + return s.currentRunCtx, s.currentTC +} + +// requestPreempt records a preempt request and wakes both waiters. +// If holdCount is 0, no one is listening — close the ack immediately as a no-op. +func (s *preemptSignal) requestPreempt(ack chan struct{}, opts ...AgentCancelOption) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.holdCount <= 0 { + if ack != nil { + close(ack) + } + return + } + + s.preemptRequested = true + s.preemptGen++ + s.agentCancelOpts = opts + if ack != nil { + s.pendingAckList = append(s.pendingAckList, ack) } + select { + case s.notify <- struct{}{}: + default: + } + + s.cond.Broadcast() } -type ReceiveConfig struct { - Timeout time.Duration +// receivePreempt is called by the per-turn watcher goroutine to consume a +// pending preempt. It drains pendingAckList (so the watcher can close them +// after invoking agentCancelFunc) but intentionally preserves preemptRequested +// and preemptGen — these are needed by waitForPreemptOrUnhold on the run loop. +func (s *preemptSignal) receivePreempt() (bool, uint64, []AgentCancelOption, []chan struct{}) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.preemptRequested { + ackList := s.pendingAckList + s.pendingAckList = nil + return true, s.preemptGen, s.agentCancelOpts, ackList + } + return false, 0, nil, nil } -type MessageSource[T any] interface { - Receive(context.Context, ReceiveConfig) (context.Context, T, []ConsumeOption, error) - Front(context.Context, ReceiveConfig) (context.Context, T, []ConsumeOption, error) +// waitForPreemptOrUnhold blocks the run loop between turns. It returns early +// (preempted=false) when holdCount is 0 (no Push caller is holding). Otherwise +// it blocks until either a preempt is requested or all holders release. +func (s *preemptSignal) waitForPreemptOrUnhold() (preempted bool, opts []AgentCancelOption, ackList []chan struct{}) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.holdCount <= 0 { + return false, nil, nil + } + + for s.holdCount > 0 && !s.preemptRequested { + s.cond.Wait() + } + + if s.preemptRequested { + ackList = s.pendingAckList + s.pendingAckList = nil + return true, s.agentCancelOpts, ackList + } + return false, nil, nil } -type turnLoopRunConfig[T any] struct { - checkPointID string - item T +// resetLocked clears all signal state and closes pending ack channels so the +// next cycle starts clean and blocked Push callers are unblocked. Must be +// called with s.mu held. Does NOT touch holdCount, currentTC, or currentRunCtx +// — callers are responsible for those. +func (s *preemptSignal) resetLocked() { + s.preemptRequested = false + s.preemptGen = 0 + s.agentCancelOpts = nil + for _, ack := range s.pendingAckList { + close(ack) + } + s.pendingAckList = nil + select { + case <-s.notify: + default: + } } -// TurnLoopRunOption is an option for TurnLoop.Run. -type TurnLoopRunOption[T any] func(*turnLoopRunConfig[T]) +// unholdRunLoop drops one hold. When holdCount reaches 0, all signal state is +// reset so the next cycle starts clean. +func (s *preemptSignal) unholdRunLoop() { + s.mu.Lock() + defer s.mu.Unlock() + + s.holdCount-- + if s.holdCount < 0 { + s.holdCount = 0 + } + if s.holdCount == 0 { + s.resetLocked() + } + s.cond.Broadcast() +} -// WithTurnLoopResume configures the TurnLoop to resume from a previously saved checkpoint. -// The checkPointID identifies the checkpoint to resume from, and item is the original input -// that triggered the interrupted execution. -func WithTurnLoopResume[T any](checkPointID string, item T) TurnLoopRunOption[T] { - return func(c *turnLoopRunConfig[T]) { - c.checkPointID = checkPointID - c.item = item +// endTurnAndUnhold is called by the run loop after runAgentAndHandleEvents +// returns. It clears the current turn context and drops the run loop's hold. +func (s *preemptSignal) endTurnAndUnhold() { + s.mu.Lock() + defer s.mu.Unlock() + + s.currentTC = nil + s.currentRunCtx = nil + s.holdCount-- + if s.holdCount < 0 { + s.holdCount = 0 + } + if s.holdCount == 0 { + s.resetLocked() } + s.cond.Broadcast() +} + +// drainAll forcefully resets all preemptSignal state and closes any pending +// ack channels. Called during TurnLoop cleanup to prevent ack channels from +// leaking when the run loop exits (e.g. due to Stop) while a Push caller +// still holds a reference. +func (s *preemptSignal) drainAll() { + s.mu.Lock() + defer s.mu.Unlock() + + s.holdCount = 0 + s.currentTC = nil + s.currentRunCtx = nil + s.resetLocked() + s.cond.Broadcast() } // TurnLoopConfig is the configuration for creating a TurnLoop. type TurnLoopConfig[T any] struct { - // Source provides messages to drive the loop. Required. - Source MessageSource[T] - // GenInput converts a received message into AgentInput and optional - // RunOptions for the agent. Required. - GenInput func(ctx context.Context, item T) (*AgentInput, []AgentRunOption, error) - // GetAgent returns the Agent to run for a given message. Required. - GetAgent func(ctx context.Context, item T) (Agent, error) - // OnAgentEvents is called for each event emitted by the agent. Optional. - // The inputItem is the message that triggered the current agent turn. - // If not provided, the default implementation will consume all events and - // return any error event encountered. - OnAgentEvents func(ctx context.Context, inputItem T, event *AsyncIterator[*AgentEvent]) error - // ReceiveTimeout is the timeout passed to Source.Receive on each iteration. - // Zero means no timeout. Optional. - ReceiveTimeout time.Duration - + // GenInput receives the TurnLoop instance and all buffered items, and decides what to process. + // It returns which items to consume now vs keep for later turns. + // The loop parameter allows calling Push() or Stop() directly from within the callback. + // Required. + GenInput func(ctx context.Context, loop *TurnLoop[T], items []T) (*GenInputResult[T], error) + + // GenResume is called exactly once when the TurnLoop detects a mid-turn + // checkpoint on startup (i.e. CheckpointID is configured and the stored + // checkpoint has runner state from an interrupted agent execution). + // It receives: + // - canceledItems: the items being processed when the prior run was canceled + // - unhandledItems: items buffered but not processed when the prior run exited + // - newItems: items that were Push()-ed before Run() was called + // + // It returns a GenResumeResult describing how to resume the interrupted agent + // turn (optional ResumeParams) and how to manipulate the buffer + // (Consumed/Remaining) before continuing. + GenResume func(ctx context.Context, loop *TurnLoop[T], canceledItems, unhandledItems, newItems []T) (*GenResumeResult[T], error) + + // PrepareAgent returns an Agent configured to handle the consumed items. + // This callback should set up the agent with appropriate system prompt, + // tools, and middlewares based on what items are being processed. + // Called once per turn with the items that GenInput decided to consume. + // The loop parameter allows calling Push() or Stop() directly from within the callback. + // Required. + PrepareAgent func(ctx context.Context, loop *TurnLoop[T], consumed []T) (Agent, error) + + // OnAgentEvents is called to handle events emitted by the agent. + // The TurnContext provides per-turn info and control: + // - tc.Consumed: items that triggered this agent execution + // - tc.Loop: allows calling Push() or Stop() directly from within the callback + // - tc.Preempted / tc.Stopped: signals while processing events + // Optional. If not provided, events are drained and errors (except CancelError + // from Stop-triggered cancellation) are returned as ExitReason. + OnAgentEvents func(ctx context.Context, tc *TurnContext[T], events *AsyncIterator[*AgentEvent]) error + + // Store is the checkpoint store for persistence and resume. Optional. + // When set together with CheckpointID, enables automatic checkpoint-based resume. + // The TurnLoop always persists both runner checkpoint bytes and item bookkeeping + // (CanceledItems, UnhandledItems) via gob encoding, so T must be gob-encodable + // when Store is used. Store CheckPointStore + + // CheckpointID, when set together with Store, enables automatic + // checkpoint-based resume. On Run(), the TurnLoop queries Store for this ID: + // - If a checkpoint exists with runner state (mid-turn interrupt), + // GenResume is called to plan the resume turn. + // - If a checkpoint exists without runner state (between-turns), + // the stored unhandled items are buffered and the loop proceeds + // normally via GenInput. + // - If no checkpoint exists, the loop starts fresh. + // + // On exit, if the TurnLoop saved a new checkpoint, it is saved under this + // same CheckpointID. On clean exit (no checkpoint saved), the existing + // checkpoint under CheckpointID is deleted to prevent stale resumption. + CheckpointID string } -// TurnLoop is a loop that pulls messages from a source, runs an Agent for -// each message, and dispatches resulting events. It supports preemptive -// cancellation when the source returns ConsumePreemptive and the current -// agent implements Cancellable. -type TurnLoop[T any] struct { - source MessageSource[T] - genInput func(ctx context.Context, item T) (*AgentInput, []AgentRunOption, error) - getAgent func(ctx context.Context, item T) (Agent, error) - onAgentEvents func(ctx context.Context, inputItem T, event *AsyncIterator[*AgentEvent]) error - receiveTimeout time.Duration - store CheckPointStore +// GenInputResult contains the result of GenInput processing. +type GenInputResult[T any] struct { + // RunCtx, if non-nil, overrides the context for this turn's execution + // (PrepareAgent, agent run, OnAgentEvents). + // + // Must be derived from the ctx passed to GenInput to preserve the + // TurnLoop's cancellation semantics and inherited values. For example: + // + // runCtx := context.WithValue(ctx, traceKey{}, extractTraceID(items)) + // return &GenInputResult[T]{RunCtx: runCtx, ...}, nil + // + // If nil, the TurnLoop's context is used unchanged. + RunCtx context.Context + + // Input is the agent input to execute + Input *AgentInput + + // RunOpts are the options for this agent run + RunOpts []AgentRunOption + + // Consumed are the items selected for this turn. + // They are removed from the buffer and passed to PrepareAgent. + Consumed []T + + // Remaining are the items to keep in the buffer for a future turn. + // TurnLoop pushes Remaining back into the buffer before running the agent. + // + // Items from the GenInput input slice that are in neither Consumed nor Remaining + // are dropped by the loop. + Remaining []T } -type turnLoopCancelSig struct { - done chan struct{} - config atomic.Value +// GenResumeResult contains the result of GenResume processing. +type GenResumeResult[T any] struct { + // RunCtx, if non-nil, overrides the context for this resumed turn's execution + // (PrepareAgent, agent resume, OnAgentEvents). + RunCtx context.Context + + // RunOpts are the options for this agent resume run. + RunOpts []AgentRunOption + + // ResumeParams are optional parameters for resuming an interrupted agent. + ResumeParams *ResumeParams + + // Consumed are the items selected for this resumed turn. + // They are removed from the buffer and passed to PrepareAgent. + Consumed []T + + // Remaining are the items to keep in the buffer for a future turn. + // TurnLoop pushes Remaining back into the buffer before resuming the agent. + // + // Items from (canceledItems, unhandledItems, newItems) that are in neither Consumed + // nor Remaining are dropped by the loop. + Remaining []T } -func newTurnLoopCancelSig() *turnLoopCancelSig { - return &turnLoopCancelSig{ - done: make(chan struct{}), - } +type turnRunSpec[T any] struct { + runCtx context.Context + input *AgentInput + runOpts []AgentRunOption + resumeParams *ResumeParams + isResume bool + consumed []T + resumeBytes []byte } -func (cs *turnLoopCancelSig) cancel(cfg *cancelConfig) { - cs.config.Store(cfg) - close(cs.done) +type turnPlan[T any] struct { + turnCtx context.Context + remaining []T + spec *turnRunSpec[T] } -func (cs *turnLoopCancelSig) isCancelled() bool { - select { - case <-cs.done: - return true - default: - return false +func (l *TurnLoop[T]) planTurn( + ctx context.Context, + isResume bool, + items []T, + pr *turnLoopPendingResume[T], +) (*turnPlan[T], error) { + if !isResume { + result, err := l.config.GenInput(ctx, l, items) + if err != nil { + return nil, err + } + if result == nil { + return nil, errors.New("GenInputResult is nil") + } + if result.Input == nil { + return nil, errors.New("agent input is nil") + } + turnCtx := ctx + if result.RunCtx != nil { + turnCtx = result.RunCtx + } + return &turnPlan[T]{ + turnCtx: turnCtx, + remaining: result.Remaining, + spec: &turnRunSpec[T]{ + runCtx: result.RunCtx, + input: result.Input, + runOpts: result.RunOpts, + consumed: result.Consumed, + }, + }, nil + } + if pr == nil { + return nil, errors.New("resume payload is nil") + } + if l.config.GenResume == nil { + return nil, errors.New("GenResume is required for resume") + } + resumeResult, err := l.config.GenResume(ctx, l, pr.canceled, pr.unhandled, pr.newItems) + if err != nil { + return nil, err + } + if resumeResult == nil { + return nil, errors.New("GenResumeResult is nil") + } + turnCtx := ctx + if resumeResult.RunCtx != nil { + turnCtx = resumeResult.RunCtx } + return &turnPlan[T]{ + turnCtx: turnCtx, + remaining: resumeResult.Remaining, + spec: &turnRunSpec[T]{ + runCtx: resumeResult.RunCtx, + runOpts: resumeResult.RunOpts, + resumeParams: resumeResult.ResumeParams, + isResume: true, + consumed: resumeResult.Consumed, + resumeBytes: pr.resumeBytes, + }, + }, nil } -func (cs *turnLoopCancelSig) getConfig() *cancelConfig { - if v := cs.config.Load(); v != nil { - return v.(*cancelConfig) - } - return nil +// TurnLoopExitState is returned when TurnLoop exits, containing the exit reason +// and any items that were not processed. +type TurnLoopExitState[T any] struct { + // ExitReason indicates why the loop exited. + // nil means clean exit (Stop() was called and completed normally). + // Non-nil values include context errors, callback errors, *CancelError, etc. + // When Stop() cancels a running agent, ExitReason will be a *CancelError. + ExitReason error + + // UnhandledItems contains items that were buffered but not processed. + // This is always valid regardless of ExitReason. + UnhandledItems []T + + // CanceledItems contains the items whose turn was canceled by Stop(). + // This is set when Stop() is called during a running turn, even if it + // did not contribute to the final CancelError. + // It can be used to reconstruct GenInput/PrepareAgent inputs when resuming. + CanceledItems []T } -func (cs *turnLoopCancelSig) getDoneChan() <-chan struct{} { - if cs != nil { - return cs.done +// TurnContext provides per-turn context to the OnAgentEvents callback. +type TurnContext[T any] struct { + // Loop is the TurnLoop instance, allowing Push() or Stop() calls. + Loop *TurnLoop[T] + + // Consumed contains items that triggered this agent execution. + Consumed []T + + // Preempted is closed when a preempt signal fires for the current turn + // (via Push with WithPreempt) and at least one preemptive Push contributed + // to the CancelError for the current turn. + // "Contributed" means the preempt's cancel options were included in the + // CancelError before it was finalized. Remains open if no preempt contributed. + // Use in a select to detect preemption while processing events. + Preempted <-chan struct{} + + // Stopped is closed when a Stop() call contributed to the CancelError for the + // current turn. + // "Contributed" means Stop's cancel options were included in the CancelError + // before it was finalized. Remains open if Stop did not contribute. + // Use in a select to detect stop while processing events. + Stopped <-chan struct{} +} + +// TurnLoop is a push-based event loop for agent execution. +// Users push items via Push() and the loop processes them through the agent. +// +// Create with NewTurnLoop, then start with Run: +// +// loop := NewTurnLoop(cfg) +// // pass loop to other components, push initial items, etc. +// loop.Run(ctx) +// +// # Permissive API +// +// All methods are valid on a not-yet-running loop: +// - Push: items are buffered and will be processed once Run is called. +// - Stop: sets the stopped flag; a subsequent Run will exit immediately. +// - Wait: blocks until Run is called AND the loop exits. If Run is never +// called, Wait blocks forever (this is a programming error, analogous +// to reading from a channel that nobody writes to). +type TurnLoop[T any] struct { + config TurnLoopConfig[T] + + buffer *internal.UnboundedChan[T] + + stopped int32 + started int32 + + done chan struct{} + + result *TurnLoopExitState[T] + + stopOnce sync.Once + + runOnce sync.Once + + stopSig *stopSignal + + preemptSig *preemptSignal + + runErr error + + canceledItems []T + + checkPointRunnerBytes []byte + + pendingResume *turnLoopPendingResume[T] + + loadCheckpointID string + + onAgentEvents func(ctx context.Context, tc *TurnContext[T], events *AsyncIterator[*AgentEvent]) error +} + +type turnLoopCheckpoint[T any] struct { + RunnerCheckpoint []byte + // HasRunnerState reports whether RunnerCheckpoint contains resumable runner state. + // It is false for "between turns" checkpoints where no agent execution was + // interrupted (e.g. Stop() before the first turn or between turns). + HasRunnerState bool + UnhandledItems []T + CanceledItems []T +} + +// ErrCheckpointStoreNil is returned when a checkpoint operation requires a Store +// but none was configured in TurnLoopConfig. +var ErrCheckpointStoreNil = errors.New("checkpoint store is nil") + +func marshalTurnLoopCheckpoint[T any](c *turnLoopCheckpoint[T]) ([]byte, error) { + buf := new(bytes.Buffer) + if err := gob.NewEncoder(buf).Encode(c); err != nil { + return nil, err } - return nil + return buf.Bytes(), nil } -type turnLoopCancelSigKey struct{} +func unmarshalTurnLoopCheckpoint[T any](data []byte) (*turnLoopCheckpoint[T], error) { + var c turnLoopCheckpoint[T] + if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&c); err != nil { + return nil, err + } + return &c, nil +} -func withTurnLoopCancelSig(ctx context.Context, cs *turnLoopCancelSig) context.Context { - return context.WithValue(ctx, turnLoopCancelSigKey{}, cs) +func (l *TurnLoop[T]) saveTurnLoopCheckpoint(ctx context.Context, checkPointID string, c *turnLoopCheckpoint[T]) error { + if l.config.Store == nil { + return ErrCheckpointStoreNil + } + data, err := marshalTurnLoopCheckpoint(c) + if err != nil { + return err + } + return l.config.Store.Set(ctx, checkPointID, data) } -func getTurnLoopCancelSig(ctx context.Context) *turnLoopCancelSig { - if v, ok := ctx.Value(turnLoopCancelSigKey{}).(*turnLoopCancelSig); ok { - return v +func (l *TurnLoop[T]) deleteTurnLoopCheckpoint(ctx context.Context, checkPointID string) error { + if l.config.Store == nil { + return nil } - return nil + if deleter, ok := l.config.Store.(CheckPointDeleter); ok { + return deleter.Delete(ctx, checkPointID) + } + return l.config.Store.Set(ctx, checkPointID, nil) } -// ErrAgentNotCancellableInTurnLoop is returned when WithCancel context is used -// but the Agent does not implement CancellableAgent. -var ErrAgentNotCancellableInTurnLoop = errors.New("agent does not support cancel but WithCancel context was provided") +func (l *TurnLoop[T]) tryLoadCheckpoint(ctx context.Context) error { + checkPointID := l.config.CheckpointID + if checkPointID == "" || l.config.Store == nil { + return nil + } + + l.loadCheckpointID = checkPointID -// NewTurnLoop creates a new TurnLoop from the given configuration. -// Source, GenInput, and GetAgent are required fields. -func NewTurnLoop[T any](config TurnLoopConfig[T]) (*TurnLoop[T], error) { - if config.Source == nil { - return nil, fmt.Errorf("TurnLoopConfig.Source is required") + data, existed, err := l.config.Store.Get(ctx, checkPointID) + if err != nil { + return fmt.Errorf("failed to load checkpoint[%s]: %w", checkPointID, err) + } + if !existed { + return nil } - if config.GenInput == nil { - return nil, fmt.Errorf("TurnLoopConfig.GenInput is required") + + var cp *turnLoopCheckpoint[T] + if len(data) == 0 { + return nil } - if config.GetAgent == nil { - return nil, fmt.Errorf("TurnLoopConfig.GetAgent is required") + cp, err = unmarshalTurnLoopCheckpoint[T](data) + if err != nil { + return fmt.Errorf("failed to unmarshal checkpoint[%s]: %w", checkPointID, err) } - onAgentEvents := config.OnAgentEvents - if onAgentEvents == nil { - onAgentEvents = func(_ context.Context, _ T, iter *AsyncIterator[*AgentEvent]) error { - for { - event, ok := iter.Next() - if !ok { - break - } - if event.Err != nil { - return event.Err - } - } - return nil + newItems := l.buffer.TakeAll() + + if cp.HasRunnerState { + if len(cp.RunnerCheckpoint) == 0 { + l.buffer.PushFront(newItems) + return fmt.Errorf("checkpoint[%s] has runner state but bytes are empty", checkPointID) + } + l.pendingResume = &turnLoopPendingResume[T]{ + canceled: append([]T{}, cp.CanceledItems...), + unhandled: append([]T{}, cp.UnhandledItems...), + newItems: append([]T{}, newItems...), + resumeBytes: append([]byte{}, cp.RunnerCheckpoint...), } + } else { + items := make([]T, 0, len(cp.UnhandledItems)+len(newItems)) + items = append(items, cp.UnhandledItems...) + items = append(items, newItems...) + l.buffer.PushFront(items) } - return &TurnLoop[T]{ - source: config.Source, - genInput: config.GenInput, - getAgent: config.GetAgent, - onAgentEvents: onAgentEvents, - receiveTimeout: config.ReceiveTimeout, - store: config.Store, - }, nil + return nil +} + +type turnLoopPendingResume[T any] struct { + canceled []T + unhandled []T + newItems []T + resumeBytes []byte +} + +type stopConfig struct { + agentCancelOpts []AgentCancelOption +} + +// StopOption is an option for Stop(). +type StopOption func(*stopConfig) + +// WithAgentCancel sets the agent cancel options to use when stopping the loop. +// These options control how the currently running agent is cancelled. +func WithAgentCancel(opts ...AgentCancelOption) StopOption { + return func(cfg *stopConfig) { + cfg.agentCancelOpts = opts + } +} + +type pushConfig[T any] struct { + preempt bool + preemptDelay time.Duration + agentCancelOpts []AgentCancelOption + pushStrategy func(context.Context, *TurnContext[T]) []PushOption[T] +} + +// PushOption is an option for Push(). +type PushOption[T any] func(*pushConfig[T]) + +// WithPreempt signals that the current agent should be canceled after pushing. +// This enables atomic "push + preempt" to avoid race conditions between +// pushing an urgent item and triggering preemption. +// The loop will cancel the current agent turn and continue with the next turn, +// where GenInput will see all buffered items including the newly pushed one. +func WithPreempt[T any](agentCancelOpts ...AgentCancelOption) PushOption[T] { + return func(cfg *pushConfig[T]) { + cfg.preempt = true + cfg.agentCancelOpts = agentCancelOpts + } } -var ErrLoopExit = errors.New("loop exit") +// WithPreemptDelay sets a delay duration before preemption takes effect. +// When used with WithPreempt, the push will succeed immediately, but the +// preemption signal will be delayed by the specified duration. +// This allows the current agent to continue processing for a grace period +// before being preempted. +func WithPreemptDelay[T any](delay time.Duration) PushOption[T] { + return func(cfg *pushConfig[T]) { + cfg.preemptDelay = delay + } +} -// WithCancel returns a new context and a cancel function that can be used to -// cancel the TurnLoop's Run method externally. Each call to WithCancel creates -// an independent cancel signal, allowing multiple concurrent Run calls with -// separate cancel controls. +// WithPushStrategy provides dynamic push option resolution based on the current turn state. +// The callback receives the current turn's context and TurnContext (nil if no turn is active) +// and returns the actual PushOptions to apply. When WithPushStrategy is used, all other +// PushOptions passed to the same Push call are ignored. // -// The returned TurnLoopCancelFunc does not require a context parameter since -// the context is already bound when WithCancel is called. +// The returned options must not contain another WithPushStrategy; any nested +// strategy is silently stripped. // -// Example: +// Example: preempt only if the current turn is processing low-priority items: // -// ctx, cancel := turnLoop.WithCancel(context.Background()) -// go func() { -// err := turnLoop.Run(ctx) -// }() -// // Later, to cancel: -// cancel(adk.WithCancelMode(adk.CancelAfterToolCall)) -func (l *TurnLoop[T]) WithCancel(ctx context.Context) (context.Context, CancelFunc) { - cs := newTurnLoopCancelSig() - ctx = withTurnLoopCancelSig(ctx, cs) +// loop.Push(urgentItem, WithPushStrategy(func(ctx context.Context, tc *TurnContext[MyItem]) []PushOption[MyItem] { +// if tc == nil { +// return nil // between turns, plain push +// } +// if isLowPriority(tc.Consumed) { +// return []PushOption[MyItem]{WithPreempt[MyItem]()} +// } +// return nil // don't preempt high-priority work +// })) +func WithPushStrategy[T any](fn func(ctx context.Context, tc *TurnContext[T]) []PushOption[T]) PushOption[T] { + return func(cfg *pushConfig[T]) { + cfg.pushStrategy = fn + } +} - var once sync.Once - cancelFn := func(opts ...CancelOption) error { - cfg := &cancelConfig{ - Mode: CancelImmediate, +func defaultTurnLoopOnAgentEvents[T any](_ context.Context, _ *TurnContext[T], events *AsyncIterator[*AgentEvent]) error { + for { + event, ok := events.Next() + if !ok { + break } - for _, opt := range opts { - opt(cfg) + if event.Err != nil { + return event.Err } - once.Do(func() { - cs.cancel(cfg) - }) - return nil } - - return ctx, cancelFn + return nil } -// Run starts the blocking loop that continuously receives messages from the -// source, runs the agent returned by GetAgent for each message, and dispatches -// resulting events to OnAgentEvent. It blocks until the source returns an error -// (including context cancellation) or a callback fails. +// NewTurnLoop creates a new TurnLoop without starting it. +// The returned loop accepts Push and Stop calls immediately; pushed items +// are buffered until Run is called. +// Call Run to start the processing goroutine. // -// If a received message has ConsumePreemptive mode and the current agent -// implements Cancellable, the agent is canceled and the new message is processed -// immediately. If the agent does not implement Cancellable, preemptive messages -// are queued and processed after the current agent finishes. -// -// To enable external cancellation, use WithCancel to create a cancellable context: +// NewTurnLoop panics if GenInput or PrepareAgent is nil. +func NewTurnLoop[T any](cfg TurnLoopConfig[T]) *TurnLoop[T] { + if cfg.GenInput == nil { + panic("adk: NewTurnLoop: GenInput is required") + } + if cfg.PrepareAgent == nil { + panic("adk: NewTurnLoop: PrepareAgent is required") + } + + l := &TurnLoop[T]{ + config: cfg, + buffer: internal.NewUnboundedChan[T](), + done: make(chan struct{}), + stopSig: newStopSignal(), + preemptSig: newPreemptSignal(), + } + if cfg.OnAgentEvents != nil { + l.onAgentEvents = cfg.OnAgentEvents + } else { + l.onAgentEvents = defaultTurnLoopOnAgentEvents[T] + } + return l +} + +func (l *TurnLoop[T]) start(ctx context.Context) { + l.runOnce.Do(func() { + atomic.StoreInt32(&l.started, 1) + go l.run(ctx) + }) +} + +// Run starts the loop's processing goroutine. It is non-blocking: the loop +// runs in the background and results are obtained via Wait. // -// ctx, cancel := turnLoop.WithCancel(context.Background()) -// go turnLoop.Run(ctx) -// // Later: cancel() +// If CheckpointID is configured in TurnLoopConfig and a matching checkpoint +// exists in Store, the loop automatically resumes from that checkpoint. +// Otherwise it starts fresh with whatever items were Push()-ed. // -// To enable checkpoint-based resumption, use WithTurnLoopResume: +// Calling Run more than once is a no-op: only the first call starts the loop. +func (l *TurnLoop[T]) Run(ctx context.Context) { + l.start(ctx) +} + +// Push adds an item to the loop's buffer for processing. +// This method is non-blocking and thread-safe. +// Returns false if the loop has stopped, true otherwise. If a preemptive push +// succeeds, the second return value is a channel that is closed when the loop +// has acknowledged the preempt signal (by either initiating cancellation of the +// current agent run or reaching a point where no cancellation is needed). +// If the loop has not been started yet (Run not called), items are buffered +// and will be processed once Run is called. // -// err := turnLoop.Run(ctx, WithTurnLoopResume("session-123")) +// Use WithPreempt() to atomically push an item and signal preemption of the current agent. +// This is useful for urgent items that should interrupt the current processing. +// The returned channel may be waited on if the caller needs to ensure the preempt +// signal has been observed. // -//nolint:cyclop,funlen // This is a core method, splitting would make the logic harder to follow -func (l *TurnLoop[T]) Run(ctx context.Context, opts ...TurnLoopRunOption[T]) error { - var runCfg turnLoopRunConfig[T] +// Use WithPreemptDelay() together with WithPreempt() to delay the preemption signal. +// Push returns immediately after the item is buffered, and a goroutine is spawned +// to signal preemption after the delay. +func (l *TurnLoop[T]) Push(item T, opts ...PushOption[T]) (bool, <-chan struct{}) { + cfg := &pushConfig[T]{} for _, opt := range opts { - opt(&runCfg) + opt(cfg) } - cs := getTurnLoopCancelSig(ctx) - toResumeFirst := false - if len(runCfg.checkPointID) > 0 { - toResumeFirst = true + if cfg.pushStrategy != nil { + return l.pushWithStrategy(item, cfg) } - for { - if cs != nil && cs.isCancelled() { - return nil - } + return l.pushWithConfig(item, cfg) +} - var nCtx context.Context - var item T - var checkPointID string - if !toResumeFirst { - var err error - var option []ConsumeOption - nCtx, item, option, err = l.source.Receive(ctx, ReceiveConfig{ - Timeout: l.receiveTimeout, - }) - if errors.Is(err, ErrLoopExit) { - return nil - } - if err != nil { - return fmt.Errorf("failed to receive message: %w", err) +// pushWithStrategy atomically holds the run loop and snapshots the current turn, +// then calls the strategy callback with a guaranteed-stable TurnContext. If the +// strategy returns preempt options, the hold is kept and a preempt is requested; +// otherwise the hold is released and the item is buffered as a plain push. +func (l *TurnLoop[T]) pushWithStrategy(item T, cfg *pushConfig[T]) (bool, <-chan struct{}) { + strategy := cfg.pushStrategy + + runCtx, tcAny := l.preemptSig.holdAndGetTurn() + if runCtx == nil { + runCtx = context.Background() + } + var tc *TurnContext[T] + if tcAny != nil { + tc = tcAny.(*TurnContext[T]) + } + realOpts := strategy(runCtx, tc) + cfg = &pushConfig[T]{} + for _, opt := range realOpts { + opt(cfg) + } + cfg.pushStrategy = nil + + if !cfg.preempt { + l.preemptSig.unholdRunLoop() + return l.buffer.TrySend(item), nil + } + + if atomic.LoadInt32(&l.stopped) != 0 { + l.preemptSig.unholdRunLoop() + return false, nil + } + + if !l.buffer.TrySend(item) { + l.preemptSig.unholdRunLoop() + return false, nil + } + + ack := make(chan struct{}) + if atomic.LoadInt32(&l.started) == 0 { + l.preemptSig.unholdRunLoop() + close(ack) + return true, ack + } + + if cfg.preemptDelay > 0 { + go func() { + select { + case <-time.After(cfg.preemptDelay): + l.preemptSig.requestPreempt(ack, cfg.agentCancelOpts...) + case <-l.done: + l.preemptSig.unholdRunLoop() + close(ack) } - o := applyConsumeOptions(option) - checkPointID = o.CheckPointID - } else { - nCtx = ctx - item = runCfg.item - checkPointID = runCfg.checkPointID + }() + } else { + l.preemptSig.requestPreempt(ack, cfg.agentCancelOpts...) + } + return true, ack +} + +func (l *TurnLoop[T]) pushWithConfig(item T, cfg *pushConfig[T]) (bool, <-chan struct{}) { + if atomic.LoadInt32(&l.stopped) != 0 { + return false, nil + } + + if cfg.preempt { + l.preemptSig.holdRunLoop() + + if !l.buffer.TrySend(item) { + l.preemptSig.unholdRunLoop() + return false, nil } - if len(checkPointID) > 0 && l.store == nil { - return fmt.Errorf("CheckPointStore is required") + ack := make(chan struct{}) + if atomic.LoadInt32(&l.started) == 0 { + l.preemptSig.unholdRunLoop() + close(ack) + return true, ack } - input, runOpts, e := l.genInput(nCtx, item) - if e != nil { - return fmt.Errorf("failed to generate agent input: %w", e) + if cfg.preemptDelay > 0 { + go func() { + select { + case <-time.After(cfg.preemptDelay): + l.preemptSig.requestPreempt(ack, cfg.agentCancelOpts...) + case <-l.done: + l.preemptSig.unholdRunLoop() + close(ack) + } + }() + } else { + l.preemptSig.requestPreempt(ack, cfg.agentCancelOpts...) } + return true, ack + } + + return l.buffer.TrySend(item), nil +} + +// Stop signals the loop to stop and returns immediately (non-blocking). +// The loop will finish the current turn (or cancel it via WithAgentCancel options), +// then exit without starting a new turn. +// Use WithAgentCancel to control how the currently running agent is cancelled. +// This method is idempotent - multiple calls update cancel options. +// Call Wait() to block until the loop has fully exited and get the result. +// +// Stop may be called before Run. In that case, the stopped flag is set and +// a subsequent Run will exit the loop immediately. +// +// If the running agent does not support the WithCancel AgentRunOption, +// Stop degrades to "exit the loop on entering the next iteration" — the +// current agent turn runs to completion before the loop exits. +func (l *TurnLoop[T]) Stop(opts ...StopOption) { + cfg := &stopConfig{} + for _, opt := range opts { + opt(cfg) + } + + l.stopSig.signal(cfg) + + l.stopOnce.Do(func() { + l.stopSig.closeDone() + atomic.StoreInt32(&l.stopped, 1) + l.buffer.Close() + }) +} + +// Wait blocks until the loop exits and returns the result. +// This method is safe to call from multiple goroutines. +// All callers will receive the same result. +// +// Wait blocks until Run is called AND the loop exits. If Run is +// ever called, Wait blocks forever. +func (l *TurnLoop[T]) Wait() *TurnLoopExitState[T] { + <-l.done + return l.result +} + +func (l *TurnLoop[T]) run(ctx context.Context) { + defer l.cleanup(ctx) + + if err := l.tryLoadCheckpoint(ctx); err != nil { + l.runErr = err + return + } - agent, e := l.getAgent(nCtx, item) - if e != nil { - return fmt.Errorf("failed to get agent: %w", e) + // Monitor context cancellation: close the buffer so that a blocking + // Receive() unblocks. The loop will then check ctx.Err() and exit. + go func() { + select { + case <-ctx.Done(): + l.buffer.Close() + case <-l.done: } + }() - var cancelFunc CancelFunc - var iter *AsyncIterator[*AgentEvent] - _, isAgentCancellable := agent.(CancellableAgent) - if cs != nil && !isAgentCancellable { - return fmt.Errorf("%w: agent %s", ErrAgentNotCancellableInTurnLoop, agent.Name(nCtx)) + for { + if l.stopSig.isStopped() { + return } - if toResumeFirst { - var err error - iter, cancelFunc, err = NewRunner(nCtx, RunnerConfig{ - EnableStreaming: input.EnableStreaming, - Agent: agent, - CheckPointStore: l.store, - }).ResumeWithCancel(nCtx, checkPointID, runOpts...) - if err != nil { - return fmt.Errorf("failed to resume agent: %w", err) + isResume := false + var pr *turnLoopPendingResume[T] + var items []T + var pushBack []T + + if l.pendingResume != nil { + isResume = true + pr = l.pendingResume + l.pendingResume = nil + + pushBack = make([]T, 0, len(pr.canceled)+len(pr.unhandled)+len(pr.newItems)) + pushBack = append(pushBack, pr.canceled...) + pushBack = append(pushBack, pr.unhandled...) + pushBack = append(pushBack, pr.newItems...) + } else { + first, ok := l.buffer.Receive() + if !ok { + if err := ctx.Err(); err != nil { + l.runErr = err + } + return } - toResumeFirst = false - } else if isAgentCancellable { - var cps CheckPointStore - if len(checkPointID) > 0 { - cps = l.store - runOpts = append(runOpts, WithCheckPointID(checkPointID)) + + if err := ctx.Err(); err != nil { + l.buffer.PushFront([]T{first}) + l.runErr = err + return } - iter, cancelFunc = NewRunner(nCtx, RunnerConfig{ - EnableStreaming: input.EnableStreaming, - Agent: agent, - CheckPointStore: cps, - }).RunWithCancel(nCtx, input.Messages, runOpts...) - } else { - var cps CheckPointStore - if len(checkPointID) > 0 { - cps = l.store - runOpts = append(runOpts, WithCheckPointID(checkPointID)) + + if l.stopSig.isStopped() { + l.buffer.PushFront([]T{first}) + return } - iter = NewRunner(nCtx, RunnerConfig{ - EnableStreaming: input.EnableStreaming, - Agent: agent, - CheckPointStore: cps, - }).Run(nCtx, input.Messages, runOpts...) + + rest := l.buffer.TakeAll() + items = append([]T{first}, rest...) + pushBack = items } - handleEvents := func() error { - return l.handleEvents(ctx, item, iter, checkPointID) + // Drain any pending preempt that arrived between turns. A Push caller + // may have called holdRunLoop + requestPreempt while the loop was + // between iterations; acknowledge and release before planning the + // next turn. Use drainAll to release all pusher holds at once — + // multiple concurrent Push(WithPreempt) callers each hold a ref. + if preempted, _, ackList := l.preemptSig.waitForPreemptOrUnhold(); preempted { + for _, ack := range ackList { + close(ack) + } + l.preemptSig.drainAll() } - if cancelFunc != nil { - var handleEventErr error - done := make(chan struct{}) + plan, err := l.planTurn(ctx, isResume, items, pr) + if err != nil { + if len(pushBack) > 0 { + l.buffer.PushFront(pushBack) + } + l.runErr = err + return + } - go func() { - defer func() { - panicErr := recover() - if panicErr != nil { - handleEventErr = safe.NewPanicErr(panicErr, debug.Stack()) - } + if l.stopSig.isStopped() { + if len(pushBack) > 0 { + l.buffer.PushFront(pushBack) + } + return + } - close(done) - }() + agent, err := l.config.PrepareAgent(plan.turnCtx, l, plan.spec.consumed) + if err != nil { + if len(pushBack) > 0 { + l.buffer.PushFront(pushBack) + } + l.runErr = err + return + } - handleEventErr = handleEvents() - }() + if l.stopSig.isStopped() { + if len(pushBack) > 0 { + l.buffer.PushFront(pushBack) + } + return + } - frontDone := make(chan struct{}) - var frontErr error - var option []ConsumeOption - go func() { - defer func() { - panicErr := recover() - if panicErr != nil { - frontErr = safe.NewPanicErr(panicErr, debug.Stack()) - } + l.buffer.PushFront(plan.remaining) - close(frontDone) - }() - _, _, option, frontErr = l.source.Front(nCtx, ReceiveConfig{ - Timeout: l.receiveTimeout, - }) - }() + // Bracket the turn with holdRunLoop / endTurnAndUnhold. The run loop's + // own hold ensures that if a Push caller also holds mid-turn, the total + // holdCount stays > 0 after endTurnAndUnhold, blocking the loop at + // waitForPreemptOrUnhold until the Push caller's preempt is resolved. + l.preemptSig.holdRunLoop() + runErr := l.runAgentAndHandleEvents(plan.turnCtx, agent, plan.spec) - select { - case <-frontDone: - case <-done: - case <-cs.getDoneChan(): - err := cancelAndWait(cancelFunc, cs, done) - if err != nil { - return err - } - return l.wrapHandleEventErr(handleEventErr) - } + l.preemptSig.endTurnAndUnhold() - if frontErr != nil { - <-done - if errors.Is(frontErr, ErrLoopExit) { - return nil - } - return fmt.Errorf("failed to front message: %w", frontErr) - } + if runErr != nil { + l.runErr = runErr + return + } + } +} - o := applyConsumeOptions(option) - switch o.Mode { - case ConsumePreemptive: - err := cancelFunc(o.CancelOpts...) - if err != nil { - <-done - return fmt.Errorf("failed to cancel agent: %w", err) - } - case ConsumePreemptiveOnTimeout: - select { - case <-done: - case <-time.After(o.Timeout): - err := cancelFunc(o.CancelOpts...) - if err != nil { - <-done - return fmt.Errorf("failed to cancel agent: %w", err) +func (l *TurnLoop[T]) setupBridgeStore(spec *turnRunSpec[T], runOpts []AgentRunOption) ([]AgentRunOption, *bridgeStore, error) { + store := l.config.Store + if store == nil && spec.isResume { + return nil, nil, fmt.Errorf("failed to resume agent: %w", ErrCheckpointStoreNil) + } + if store == nil { + return runOpts, nil, nil + } + runOpts = append(runOpts, WithCheckPointID(bridgeCheckpointID)) + if spec.isResume { + if len(spec.resumeBytes) == 0 { + return nil, nil, fmt.Errorf("resume checkpoint is empty") + } + return runOpts, newResumeBridgeStore(bridgeCheckpointID, spec.resumeBytes), nil + } + return runOpts, newBridgeStore(), nil +} + +// watchPreemptSignal runs for the lifetime of a single turn. It listens on the +// notify channel for preempt requests and relays them to agentCancelFunc. +// +// preemptGen de-duplicates notifications: multiple notify wakes can fire for the +// same logical preempt (e.g. cond.Broadcast + channel send), so the watcher +// only acts when the generation advances. +// +// On the first preempt whose cancel actually contributed (i.e. the cancel options +// were accepted before the CancelError was finalized), preemptDone is closed to +// wake runAgentAndHandleEvents's select. +func (l *TurnLoop[T]) watchPreemptSignal(done <-chan struct{}, agentCancelFunc AgentCancelFunc, preemptDone chan struct{}) { + var lastGen uint64 + for { + select { + case <-done: + return + case <-l.preemptSig.notify: + if preempted, gen, opts, ackList := l.preemptSig.receivePreempt(); preempted { + if gen != lastGen { + firstPreempt := lastGen == 0 + lastGen = gen + // CancelHandle is intentionally not awaited here: agentCancelFunc commits the cancel signal synchronously, + // while waiting would block until the turn finishes and can deadlock this watcher against the done signal. + _, contributed := agentCancelFunc(opts...) + if firstPreempt && contributed { + close(preemptDone) } - case <-cs.getDoneChan(): - err := cancelAndWait(cancelFunc, cs, done) - if err != nil { - return err + for _, ack := range ackList { + close(ack) } - return l.wrapHandleEventErr(handleEventErr) } } + } + } +} - select { - case <-done: - case <-cs.getDoneChan(): - err := cancelAndWait(cancelFunc, cs, done) - if err != nil { - return err +// watchStopSignal runs for the lifetime of a single turn. It selects on two +// channels from stopSignal: +// +// - done (permanently closed after Stop): the durable stop flag. Fires +// immediately for any watcher, even those in turns started after +// Stop() but before the run loop observed isStopped(). This eliminates +// the race where a previous turn's watcher consumed the one-shot notify, +// leaving the current turn unable to detect the stop. +// +// - notify (one-shot, buffered 1): fires when a new Stop() call is made, +// enabling cancel-mode escalation (e.g. CancelAfterToolCalls → CancelImmediate). +// The generation counter de-duplicates wakes, analogous to preemptGen in +// watchPreemptSignal. +// +// On the first cancel that actually contributed (i.e. the cancel was accepted +// before the CancelError was finalized), stoppedDone is closed to wake +// runAgentAndHandleEvents's select. +func (l *TurnLoop[T]) watchStopSignal(done <-chan struct{}, agentCancelFunc AgentCancelFunc, stoppedDone chan struct{}) { + var lastGen uint64 + stoppedClosed := false + for { + select { + case <-done: + return + case <-l.stopSig.notify: + gen, opts := l.stopSig.check() + if gen != lastGen { + lastGen = gen + // CancelHandle is intentionally not awaited here: agentCancelFunc + // commits the cancel signal synchronously, while waiting would block + // until the turn finishes and can deadlock this watcher against done. + _, contributed := agentCancelFunc(opts...) + if contributed && !stoppedClosed { + close(stoppedDone) + stoppedClosed = true } - return l.wrapHandleEventErr(handleEventErr) - } - if err := l.wrapHandleEventErr(handleEventErr); err != nil { - return err } - } else { - if handleEventErr := handleEvents(); handleEventErr != nil { - if err := l.wrapHandleEventErr(handleEventErr); err != nil { - return err - } + case <-l.stopSig.done: + _, opts := l.stopSig.check() + _, contributed := agentCancelFunc(opts...) + if contributed && !stoppedClosed { + close(stoppedDone) + stoppedClosed = true } + <-done + return } } } -func (l *TurnLoop[T]) wrapHandleEventErr(handleEventErr error) error { - if handleEventErr == nil { - return nil +func (l *TurnLoop[T]) runAgentAndHandleEvents( + ctx context.Context, + agent Agent, + spec *turnRunSpec[T], +) error { + var iter *AsyncIterator[*AgentEvent] + defer func() { + if l.stopSig.isStopped() && len(l.canceledItems) == 0 { + l.canceledItems = append([]T{}, spec.consumed...) + } + }() + + runOpts, ms, err := l.setupBridgeStore(spec, spec.runOpts) + if err != nil { + return err } - if errors.Is(handleEventErr, ErrLoopExit) { - return nil + store := l.config.Store + cancelOpt, agentCancelFunc := WithCancel() + runOpts = append(runOpts, cancelOpt) + + enableStreaming := false + if spec.input != nil { + enableStreaming = spec.input.EnableStreaming } - var interruptErr *TurnLoopInterruptError[T] - if errors.As(handleEventErr, &interruptErr) { - return interruptErr + runner := NewRunner(ctx, RunnerConfig{ + EnableStreaming: enableStreaming, + Agent: agent, + CheckPointStore: ms, + }) + + preemptDone := make(chan struct{}) + stoppedDone := make(chan struct{}) + + tc := &TurnContext[T]{ + Loop: l, + Consumed: spec.consumed, + Preempted: preemptDone, + Stopped: stoppedDone, } - return fmt.Errorf("failed to handle events: %w", handleEventErr) -} + l.preemptSig.setTurn(ctx, tc) -func (l *TurnLoop[T]) handleEvents(ctx context.Context, item T, iter *AsyncIterator[*AgentEvent], checkPointID string) error { - copies := copyEventIterator(iter, 2) - oe := l.onAgentEvents(ctx, item, copies[0]) - if oe != nil { - return oe - } - for { - e, ok := copies[1].Next() - if !ok { - break + if spec.isResume { + var err error + if spec.resumeParams != nil { + iter, err = runner.ResumeWithParams(ctx, bridgeCheckpointID, spec.resumeParams, runOpts...) + } else { + iter, err = runner.Resume(ctx, bridgeCheckpointID, runOpts...) } - if e.Action != nil && e.Action.Interrupted != nil { - return &TurnLoopInterruptError[T]{ - Item: item, - CheckpointID: checkPointID, - InterruptContexts: e.Action.Interrupted.InterruptContexts, - } + if err != nil { + return fmt.Errorf("failed to resume agent: %w", err) } + } else { + iter = runner.Run(ctx, spec.input.Messages, runOpts...) } - return nil -} -func cancelAndWait(cf CancelFunc, cs *turnLoopCancelSig, done chan struct{}) error { - cfg := cs.getConfig() - err := cf(cancelConfigToOpts(cfg)...) - if err != nil { - <-done - return fmt.Errorf("failed to cancel agent: %w", err) + handleEvents := func() error { + return l.onAgentEvents(ctx, tc, iter) } - <-done - return nil -} -func cancelConfigToOpts(cfg *cancelConfig) []CancelOption { - if cfg == nil { + done := make(chan struct{}) + var handleErr error + + go func() { + defer func() { + panicErr := recover() + if panicErr != nil { + handleErr = safe.NewPanicErr(panicErr, debug.Stack()) + } + close(done) + }() + handleErr = handleEvents() + }() + go l.watchPreemptSignal(done, agentCancelFunc, preemptDone) + go l.watchStopSignal(done, agentCancelFunc, stoppedDone) + + finalizeCheckpoint := func() error { + if store != nil && ms != nil { + data, ok, err := ms.Get(ctx, bridgeCheckpointID) + if err != nil { + return fmt.Errorf("failed to read runner checkpoint: %w", err) + } + if ok { + l.checkPointRunnerBytes = append([]byte{}, data...) + } + } return nil } - opts := []CancelOption{WithCancelMode(cfg.Mode)} - if cfg.Timeout != nil { - opts = append(opts, WithCancelTimeout(*cfg.Timeout)) + + // Wait for the turn to end. Three outcomes: + // + // done: Events fully handled (normal or error). If Stop() was + // called, save checkpoint so the caller can resume later. + // Also handle the select race: if preemptDone is closed + // too, treat as a preempt (return nil) instead of leaking + // the CancelError. + // + // preemptDone: A preemptive Push successfully cancelled the agent. + // Wait for the handleEvents goroutine to drain, then + // return nil — the run loop will start a new turn. + // + // stoppedDone: Stop() cancelled the agent. Save checkpoint so the + // caller can resume later. + select { + case <-done: + select { + case <-preemptDone: + return nil + default: + } + if l.stopSig.isStopped() { + if err := finalizeCheckpoint(); err != nil { + if handleErr != nil { + handleErr = fmt.Errorf("%w; checkpoint error: %v", handleErr, err) + } else { + handleErr = err + } + } + } + return handleErr + case <-preemptDone: + <-done + return nil + case <-stoppedDone: + <-done + if err := finalizeCheckpoint(); err != nil { + if handleErr != nil { + handleErr = fmt.Errorf("%w; checkpoint error: %v", handleErr, err) + } else { + handleErr = err + } + } + return handleErr } - return opts } -func applyConsumeOptions(opts []ConsumeOption) *consumeConfig { - var config consumeConfig - for _, opt := range opts { - opt(&config) +func (l *TurnLoop[T]) cleanup(ctx context.Context) { + atomic.StoreInt32(&l.stopped, 1) + + unhandled := l.buffer.TakeAll() + checkpointID := l.config.CheckpointID + shouldSaveCheckpoint := l.config.Store != nil && checkpointID != "" && l.stopSig.isStopped() + if shouldSaveCheckpoint { + cp := &turnLoopCheckpoint[T]{ + RunnerCheckpoint: l.checkPointRunnerBytes, + HasRunnerState: len(l.checkPointRunnerBytes) > 0, + UnhandledItems: unhandled, + CanceledItems: l.canceledItems, + } + err := l.saveTurnLoopCheckpoint(ctx, checkpointID, cp) + if err != nil { + saveErr := fmt.Errorf("failed to save turn loop checkpoint: %w", err) + if l.runErr != nil { + l.runErr = fmt.Errorf("%w; %v", l.runErr, saveErr) + } else { + l.runErr = saveErr + } + } + } else if l.loadCheckpointID != "" { + _ = l.deleteTurnLoopCheckpoint(ctx, l.loadCheckpointID) } - return &config -} -type TurnLoopInterruptError[T any] struct { - Item T - CheckpointID string - // InterruptContexts provides a structured, user-facing view of the interrupt chain. - // Each context represents a step in the agent hierarchy that was interrupted. - InterruptContexts []*InterruptCtx -} + l.result = &TurnLoopExitState[T]{ + ExitReason: l.runErr, + UnhandledItems: unhandled, + CanceledItems: l.canceledItems, + } -func (t *TurnLoopInterruptError[T]) Error() string { - return fmt.Sprintf("TurnLoopInterruptError[%s]: %v", t.CheckpointID, t.InterruptContexts) + l.preemptSig.drainAll() + l.buffer.Close() + close(l.done) } diff --git a/adk/turn_loop_test.go b/adk/turn_loop_test.go index 88e3679b3..6e3159cfc 100644 --- a/adk/turn_loop_test.go +++ b/adk/turn_loop_test.go @@ -20,1097 +20,3703 @@ import ( "context" "errors" "fmt" + "sync" "sync/atomic" "testing" "time" - "unsafe" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/cloudwego/eino/components/model" - "github.com/cloudwego/eino/components/tool" - "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) -type turnLoopMockSource struct { - items []string - idx int - err error +type turnLoopMockAgent struct { + name string + events []*AgentEvent + runFunc func(ctx context.Context, input *AgentInput) (*AgentOutput, error) + cancelFunc func(opts ...AgentCancelOption) error } -func (s *turnLoopMockSource) Receive(ctx context.Context, cfg ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - if s.idx >= len(s.items) { - return ctx, "", nil, s.err - } - item := s.items[s.idx] - s.idx++ - return ctx, item, nil, nil -} +func (a *turnLoopMockAgent) Name(_ context.Context) string { return a.name } +func (a *turnLoopMockAgent) Description(_ context.Context) string { return "mock agent" } +func (a *turnLoopMockAgent) Run(ctx context.Context, input *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, gen := NewAsyncIteratorPair[*AgentEvent]() -func (s *turnLoopMockSource) Front(ctx context.Context, cfg ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - if s.idx >= len(s.items) { - return ctx, "", nil, s.err + if a.runFunc != nil { + go func() { + defer gen.Close() + output, err := a.runFunc(ctx, input) + if err != nil { + gen.Send(&AgentEvent{Err: err}) + return + } + gen.Send(&AgentEvent{Output: output}) + }() + return iter } - return ctx, s.items[s.idx], nil, nil + + go func() { + defer gen.Close() + for _, e := range a.events { + gen.Send(e) + } + }() + return iter } -type turnLoopFuncSource[T any] struct { - receive func(ctx context.Context, cfg ReceiveConfig) (context.Context, T, []ConsumeOption, error) - front func(ctx context.Context, cfg ReceiveConfig) (context.Context, T, []ConsumeOption, error) +type turnLoopCheckpointStore struct { + m map[string][]byte + mu sync.Mutex } -func (s *turnLoopFuncSource[T]) Receive(ctx context.Context, cfg ReceiveConfig) (context.Context, T, []ConsumeOption, error) { - return s.receive(ctx, cfg) +func (s *turnLoopCheckpointStore) Set(_ context.Context, key string, value []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + s.m[key] = value + return nil } -func (s *turnLoopFuncSource[T]) Front(ctx context.Context, cfg ReceiveConfig) (context.Context, T, []ConsumeOption, error) { - if s.front != nil { - return s.front(ctx, cfg) - } - return s.receive(ctx, cfg) +func (s *turnLoopCheckpointStore) Get(_ context.Context, key string) ([]byte, bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + v, ok := s.m[key] + return v, ok, nil } -type turnLoopMockAgent struct { - name string - events []*AgentEvent +type turnLoopCancellableMockAgent struct { + name string + runFunc func(ctx context.Context, input *AgentInput) (*AgentOutput, error) + onCancel func(cc *cancelContext) + cancel context.CancelFunc + mu sync.Mutex } -func (a *turnLoopMockAgent) Name(_ context.Context) string { return a.name } -func (a *turnLoopMockAgent) Description(_ context.Context) string { return "mock agent" } -func (a *turnLoopMockAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { +func (a *turnLoopCancellableMockAgent) Name(_ context.Context) string { return a.name } +func (a *turnLoopCancellableMockAgent) Description(_ context.Context) string { return "mock agent" } + +func (a *turnLoopCancellableMockAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, gen := NewAsyncIteratorPair[*AgentEvent]() + + o := getCommonOptions(nil, opts...) + cc := o.cancelCtx + + a.mu.Lock() + var cancelCtx context.Context + cancelCtx, a.cancel = context.WithCancel(ctx) + a.mu.Unlock() + go func() { defer gen.Close() - for _, e := range a.events { - gen.Send(e) + if cc != nil { + go func() { + <-cc.cancelChan + // CRITICAL: call onCancel BEFORE cancel() to avoid race condition. + // If cancel() fires first, the runFunc returns immediately, + // flowAgent's defer calls markDone(), and doneChan closes + // before onCancel can read cc.config. + if a.onCancel != nil { + a.onCancel(cc) + } + a.mu.Lock() + if a.cancel != nil { + a.cancel() + } + a.mu.Unlock() + }() + } + + output, err := a.runFunc(cancelCtx, input) + if err != nil { + gen.Send(&AgentEvent{Err: err}) + return } + gen.Send(&AgentEvent{Output: output}) }() return iter } -type turnLoopCancellableAgent struct { - name string - startedCh chan struct{} - cancelCh chan struct{} - cancelled int32 - cancelledOpt []CancelOption -} - -func (a *turnLoopCancellableAgent) Name(_ context.Context) string { return a.name } -func (a *turnLoopCancellableAgent) Description(_ context.Context) string { return "cancellable mock" } -func (a *turnLoopCancellableAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { - iter, _ := a.RunWithCancel(context.Background(), nil) - return iter +type turnLoopStopModeProbeAgent struct { + ccCh chan *cancelContext } -func (a *turnLoopCancellableAgent) RunWithCancel(_ context.Context, _ *AgentInput, _ ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { +func (a *turnLoopStopModeProbeAgent) Name(_ context.Context) string { return "probe" } +func (a *turnLoopStopModeProbeAgent) Description(_ context.Context) string { return "probe" } +func (a *turnLoopStopModeProbeAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, gen := NewAsyncIteratorPair[*AgentEvent]() - close(a.startedCh) + o := getCommonOptions(nil, opts...) + cc := o.cancelCtx + a.ccCh <- cc go func() { defer gen.Close() - <-a.cancelCh + <-cc.cancelChan + for { + if cc.getMode() == CancelImmediate { + gen.Send(&AgentEvent{Err: cc.createCancelError()}) + return + } + time.Sleep(1 * time.Millisecond) + } }() - cancelFunc := func(opts ...CancelOption) error { - atomic.StoreInt32(&a.cancelled, 1) - a.cancelledOpt = opts - close(a.cancelCh) - return nil - } - return iter, cancelFunc + return iter } -func TestNewTurnLoop_Validation(t *testing.T) { - t.Run("missing source", func(t *testing.T) { - _, err := NewTurnLoop(TurnLoopConfig[string]{ - GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { return nil, nil, nil }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, nil }, - }) - require.Error(t, err) - assert.Contains(t, err.Error(), "Source") - }) +func newAndRunTurnLoop[T any](ctx context.Context, cfg TurnLoopConfig[T]) *TurnLoop[T] { + l := NewTurnLoop(cfg) + l.Run(ctx) + return l +} - t.Run("missing GenInput", func(t *testing.T) { - _, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{}, - GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, nil }, - }) - require.Error(t, err) - assert.Contains(t, err.Error(), "GenInput") - }) +func TestTurnLoop_RunAndPush(t *testing.T) { + processedItems := make([]string, 0) + var mu sync.Mutex - t.Run("missing GetAgent", func(t *testing.T) { - _, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{}, - GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { return nil, nil, nil }, - }) - require.Error(t, err) - assert.Contains(t, err.Error(), "GetAgent") + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + mu.Lock() + processedItems = append(processedItems, items...) + mu.Unlock() + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, }) - t.Run("valid config without OnAgentEvents", func(t *testing.T) { - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{}, - GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { return nil, nil, nil }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, nil }, - }) - require.NoError(t, err) - assert.NotNil(t, loop) - }) + loop.Push("msg1") + loop.Push("msg2") - t.Run("valid config with OnAgentEvents", func(t *testing.T) { - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{}, - GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { return nil, nil, nil }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, nil }, - OnAgentEvents: func(_ context.Context, _ string, _ *AsyncIterator[*AgentEvent]) error { return nil }, - }) - require.NoError(t, err) - assert.NotNil(t, loop) - }) + time.Sleep(100 * time.Millisecond) + + loop.Stop() + result := loop.Wait() + + mu.Lock() + defer mu.Unlock() + + assert.NoError(t, result.ExitReason) + assert.True(t, len(processedItems) > 0, "should have processed at least one item") } -func TestTurnLoop_NormalLoop(t *testing.T) { - agent := &turnLoopMockAgent{ - name: "test-agent", - events: []*AgentEvent{ - {Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("hello", nil)}}}, +func TestTurnLoop_PushReturnsErrorAfterStop(t *testing.T) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items, + }, nil }, - } + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) - var receivedItems []string - var eventCount int + loop.Stop() - source := &turnLoopMockSource{ - items: []string{"msg1", "msg2", "msg3"}, - err: context.DeadlineExceeded, - } + ok, _ := loop.Push("msg1") + assert.False(t, ok) +} - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - receivedItems = append(receivedItems, item) - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil +func TestTurnLoop_StopIsIdempotent(t *testing.T) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { - return agent, nil - }, - OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { - for { - _, ok := iter.Next() - if !ok { - break - } - eventCount++ - } - return nil + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil }, }) - require.NoError(t, err) - err = loop.Run(context.Background()) - assert.ErrorIs(t, err, context.DeadlineExceeded) - assert.Equal(t, []string{"msg1", "msg2", "msg3"}, receivedItems) - assert.Equal(t, 3, eventCount) + loop.Stop() + loop.Stop() + loop.Stop() + + result := loop.Wait() + assert.NoError(t, result.ExitReason) } -func TestTurnLoop_SourceError(t *testing.T) { - sourceErr := errors.New("source failure") - source := &turnLoopMockSource{ - items: nil, - err: sourceErr, +func TestTurnLoop_WaitMultipleGoroutines(t *testing.T) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Stop() + + var wg sync.WaitGroup + results := make([]*TurnLoopExitState[string], 3) + + for i := 0; i < 3; i++ { + i := i + wg.Add(1) + go func() { + defer wg.Done() + results[i] = loop.Wait() + }() } - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: source, - GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{}, nil, nil + wg.Wait() + + assert.Equal(t, results[0], results[1]) + assert.Equal(t, results[1], results[2]) +} + +func TestTurnLoop_UnhandledItemsOnStop(t *testing.T) { + started := make(chan struct{}) + blocked := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + close(started) + <-blocked + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items[:1], + Remaining: items[1:], + }, nil }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { - return &turnLoopMockAgent{name: "a"}, nil + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil }, - OnAgentEvents: func(_ context.Context, _ string, _ *AsyncIterator[*AgentEvent]) error { return nil }, }) - require.NoError(t, err) - err = loop.Run(context.Background()) - assert.ErrorIs(t, err, sourceErr) + loop.Push("msg1") + loop.Push("msg2") + loop.Push("msg3") + + <-started + + loop.Stop() + close(blocked) + + result := loop.Wait() + assert.True(t, len(result.UnhandledItems) >= 0, "should return unhandled items") } func TestTurnLoop_GenInputError(t *testing.T) { - genErr := errors.New("gen input failure") + genErr := errors.New("gen input error") - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{items: []string{"msg1"}, err: errors.New("should not reach")}, - GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { - return nil, nil, genErr + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return nil, genErr }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { - return &turnLoopMockAgent{name: "a"}, nil + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil }, - OnAgentEvents: func(_ context.Context, _ string, _ *AsyncIterator[*AgentEvent]) error { return nil }, }) - require.NoError(t, err) - err = loop.Run(context.Background()) - assert.ErrorIs(t, err, genErr) - assert.Contains(t, err.Error(), "failed to generate agent input") + loop.Push("msg1") + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, genErr) } func TestTurnLoop_GetAgentError(t *testing.T) { - agentErr := errors.New("get agent failure") + agentErr := errors.New("get agent error") - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{items: []string{"msg1"}, err: errors.New("should not reach")}, - GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{}, nil, nil + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { return nil, agentErr }, - OnAgentEvents: func(_ context.Context, _ string, _ *AsyncIterator[*AgentEvent]) error { return nil }, }) - require.NoError(t, err) - err = loop.Run(context.Background()) - assert.ErrorIs(t, err, agentErr) - assert.Contains(t, err.Error(), "failed to get agent") + loop.Push("msg1") + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, agentErr) } -func TestTurnLoop_OnAgentEventsError(t *testing.T) { - eventErr := errors.New("event handler failure") - agent := &turnLoopMockAgent{ - name: "test-agent", - events: []*AgentEvent{{Output: &AgentOutput{}}}, - } +func TestTurnLoop_BatchProcessing(t *testing.T) { + var batches [][]string + var mu sync.Mutex - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{items: []string{"msg1"}, err: errors.New("should not reach")}, - GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{}, nil, nil - }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { - return agent, nil + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + mu.Lock() + batches = append(batches, items) + mu.Unlock() + + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items[:1], + Remaining: items[1:], + }, nil }, - OnAgentEvents: func(_ context.Context, _ string, _ *AsyncIterator[*AgentEvent]) error { - return eventErr + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil }, }) - require.NoError(t, err) - err = loop.Run(context.Background()) - assert.ErrorIs(t, err, eventErr) - assert.Contains(t, err.Error(), "failed to handle events") -} + loop.Push("msg1") + loop.Push("msg2") + loop.Push("msg3") -func TestTurnLoop_ContextCancellation(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + time.Sleep(200 * time.Millisecond) - callCount := 0 - source := &turnLoopFuncSource[string]{receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - callCount++ - if callCount > 1 { - cancel() - return ctx, "", nil, ctx.Err() - } - return ctx, "msg1", nil, nil - }} + loop.Stop() + loop.Wait() - agent := &turnLoopMockAgent{ - name: "test-agent", - events: []*AgentEvent{{Output: &AgentOutput{}}}, - } + mu.Lock() + defer mu.Unlock() - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil - }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { - return agent, nil + assert.True(t, len(batches) > 0, "should have processed at least one batch") +} + +func TestTurnLoop_StopWithMode(t *testing.T) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil }, - OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { - for { - if _, ok := iter.Next(); !ok { - break - } - } - return nil + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil }, }) - require.NoError(t, err) - err = loop.Run(ctx) - assert.ErrorIs(t, err, context.Canceled) + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelAfterToolCalls))) + + result := loop.Wait() + assert.NoError(t, result.ExitReason) } -func TestTurnLoop_MultipleEventsPerTurn(t *testing.T) { - agent := &turnLoopMockAgent{ - name: "multi-event-agent", - events: []*AgentEvent{ - {Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("event1", nil)}}}, - {Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("event2", nil)}}}, - {Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("event3", nil)}}}, +func TestTurnLoop_Preempt_CancelsCurrentAgent(t *testing.T) { + agentStarted := make(chan struct{}) + agentCancelled := make(chan struct{}) + agentStartedOnce := sync.Once{} + agentCancelledOnce := sync.Once{} + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + <-ctx.Done() + agentCancelledOnce.Do(func() { + close(agentCancelled) + }) + return &AgentOutput{}, nil }, } - var eventCount int + genInputCalls := int32(0) + secondGenInputCalled := make(chan struct{}) + secondGenInputOnce := sync.Once{} - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil - }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { return agent, nil }, - OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { - for { - _, ok := iter.Next() - if !ok { - break - } - eventCount++ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + count := atomic.AddInt32(&genInputCalls, 1) + if count >= 2 { + secondGenInputOnce.Do(func() { + close(secondGenInputCalled) + }) } - return nil + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil }, }) - require.NoError(t, err) - err = loop.Run(context.Background()) - assert.ErrorIs(t, err, context.DeadlineExceeded) - assert.Equal(t, 3, eventCount) -} + loop.Push("first") -func TestTurnLoop_DefaultOnAgentEvents(t *testing.T) { - agent := &turnLoopMockAgent{ - name: "test-agent", - events: []*AgentEvent{{Output: &AgentOutput{}}}, + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") } - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil - }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { - return agent, nil - }, - }) - require.NoError(t, err) - - err = loop.Run(context.Background()) - assert.ErrorIs(t, err, context.DeadlineExceeded) -} + loop.Push("urgent", WithPreempt[string]()) -func TestTurnLoop_DefaultOnAgentEventsWithError(t *testing.T) { - agentErr := errors.New("agent internal error") - agent := &turnLoopMockAgent{ - name: "error-agent", - events: []*AgentEvent{{Err: agentErr}}, + select { + case <-agentCancelled: + case <-time.After(1 * time.Second): + t.Fatal("agent was not cancelled by preempt") } - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil - }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { - return agent, nil - }, - }) - require.NoError(t, err) + select { + case <-secondGenInputCalled: + case <-time.After(1 * time.Second): + t.Fatal("second GenInput was not called after preempt") + } - err = loop.Run(context.Background()) - assert.ErrorIs(t, err, agentErr) + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.GreaterOrEqual(t, atomic.LoadInt32(&genInputCalls), int32(2)) } -func TestTurnLoop_AgentErrorEvent(t *testing.T) { - agentErr := errors.New("agent internal error") - agent := &turnLoopMockAgent{ - name: "error-agent", - events: []*AgentEvent{{Err: agentErr}}, +func TestTurnLoop_Preempt_DiscardsConsumedItems(t *testing.T) { + agentStarted := make(chan struct{}) + agentDone := make(chan struct{}) + agentStartedOnce := sync.Once{} + agentDoneOnce := sync.Once{} + firstAgentRun := true + var firstRunMu sync.Mutex + + genInputResults := make([][]string, 0) + var mu sync.Mutex + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + firstRunMu.Lock() + isFirst := firstAgentRun + firstAgentRun = false + firstRunMu.Unlock() + + if isFirst { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + <-ctx.Done() + } else { + agentDoneOnce.Do(func() { + close(agentDone) + }) + } + return &AgentOutput{}, nil + }, } - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil - }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { return agent, nil }, - OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { - for { - event, ok := iter.Next() - if !ok { - break - } - if event.Err != nil { - return fmt.Errorf("agent run failed: %w", event.Err) - } - } - return nil + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + mu.Lock() + genInputResults = append(genInputResults, items) + mu.Unlock() + + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil }, }) - require.NoError(t, err) - err = loop.Run(context.Background()) - assert.ErrorIs(t, err, agentErr) - assert.Contains(t, err.Error(), "agent run failed") -} + loop.Push("first") + + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } + + loop.Push("urgent", WithPreempt[string]()) -func TestTurnLoop_PreemptiveCancellation(t *testing.T) { - slowAgent := &turnLoopCancellableAgent{ - name: "slow-agent", - startedCh: make(chan struct{}), - cancelCh: make(chan struct{}), + select { + case <-agentDone: + case <-time.After(1 * time.Second): + t.Fatal("second agent run did not complete") } - fastAgent := &turnLoopMockAgent{ - name: "fast-agent", - events: []*AgentEvent{{Output: &AgentOutput{}}}, + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + + mu.Lock() + defer mu.Unlock() + assert.GreaterOrEqual(t, len(genInputResults), 2) + if len(genInputResults) >= 2 { + assert.NotContains(t, genInputResults[1], "first") + assert.Contains(t, genInputResults[1], "urgent") } +} - var processedItems []string - receiveCount := 0 - msgs := []struct { - item string - opts []ConsumeOption - err error - }{ - {"slow-msg", nil, nil}, - {"preempt-msg", []ConsumeOption{WithPreemptive(), WithCancelOptions(WithCancelMode(CancelImmediate))}, nil}, - {"", nil, context.DeadlineExceeded}, - } - frontIdx := 0 - source := &turnLoopFuncSource[string]{ - receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - if receiveCount >= len(msgs) { - return ctx, "", nil, context.DeadlineExceeded - } - m := msgs[receiveCount] - receiveCount++ - frontIdx = receiveCount - return ctx, m.item, m.opts, m.err - }, - front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - <-slowAgent.startedCh - if frontIdx >= len(msgs) { - return ctx, "", nil, context.DeadlineExceeded - } - m := msgs[frontIdx] - return ctx, m.item, m.opts, m.err +func TestTurnLoop_Preempt_WithAgentCancelMode(t *testing.T) { + agentStarted := make(chan struct{}) + cancelFuncCalled := make(chan struct{}) + agentStartedOnce := sync.Once{} + cancelFuncCalledOnce := sync.Once{} + firstCancelModeUsed := CancelImmediate + var cancelModeMu sync.Mutex + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + <-ctx.Done() + return &AgentOutput{}, nil + }, + onCancel: func(cc *cancelContext) { + cancelModeMu.Lock() + cancelFuncCalledOnce.Do(func() { + firstCancelModeUsed = cc.getMode() + close(cancelFuncCalled) + }) + cancelModeMu.Unlock() }, } - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - processedItems = append(processedItems, item) - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil - }, - GetAgent: func(_ context.Context, item string) (Agent, error) { - if item == "slow-msg" { - return slowAgent, nil - } - return fastAgent, nil + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil }, - OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { - for { - if _, ok := iter.Next(); !ok { - break - } - } - return nil + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil }, }) - require.NoError(t, err) - err = loop.Run(context.Background()) - assert.ErrorIs(t, err, context.DeadlineExceeded) - assert.True(t, atomic.LoadInt32(&slowAgent.cancelled) == 1, "slow agent should have been cancelled") - assert.Equal(t, []string{"slow-msg", "preempt-msg"}, processedItems) -} + loop.Push("first") -func TestTurnLoop_PreemptiveNonCancellableAgent(t *testing.T) { - nonCancellableAgent := &turnLoopMockAgent{ - name: "non-cancellable-agent", - events: []*AgentEvent{{Output: &AgentOutput{}}}, - } - fastAgent := &turnLoopMockAgent{ - name: "fast-agent", - events: []*AgentEvent{{Output: &AgentOutput{}}}, + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") } - var processedItems []string - callCount := 0 - source := &turnLoopFuncSource[string]{receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - callCount++ - switch callCount { - case 1: - return ctx, "non-cancel-msg", nil, nil - case 2: - return ctx, "preempt-msg", []ConsumeOption{WithPreemptive()}, nil - default: - return ctx, "", nil, context.DeadlineExceeded - } - }} + loop.Push("urgent", WithPreempt[string](WithAgentCancelMode(CancelAfterToolCalls))) - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - processedItems = append(processedItems, item) - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil - }, - GetAgent: func(_ context.Context, item string) (Agent, error) { - if item == "non-cancel-msg" { - return nonCancellableAgent, nil - } - return fastAgent, nil - }, - OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { - for { - if _, ok := iter.Next(); !ok { - break - } - } - return nil - }, - }) - require.NoError(t, err) + select { + case <-cancelFuncCalled: + case <-time.After(1 * time.Second): + t.Fatal("cancelFunc was not called by preempt") + } - err = loop.Run(context.Background()) - assert.ErrorIs(t, err, context.DeadlineExceeded) - assert.Equal(t, []string{"non-cancel-msg", "preempt-msg"}, processedItems) + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + cancelModeMu.Lock() + actualMode := firstCancelModeUsed + cancelModeMu.Unlock() + assert.Equal(t, CancelAfterToolCalls, actualMode) } -func TestTurnLoop_WithCancel_Basic(t *testing.T) { - agent := &turnLoopCancellableAgent{ - name: "test-agent", - startedCh: make(chan struct{}), - cancelCh: make(chan struct{}), - } +func TestTurnLoop_PreemptAck_ClosesAfterCancelIsInitiated(t *testing.T) { + agentStarted := make(chan struct{}) + cancelObserved := make(chan struct{}) + agentFinishGate := make(chan struct{}) + agentStartedOnce := sync.Once{} + cancelObservedOnce := sync.Once{} - receiveCount := int32(0) - frontBlocked := make(chan struct{}) - source := &turnLoopFuncSource[string]{ - receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - cnt := atomic.AddInt32(&receiveCount, 1) - if cnt == 1 { - return ctx, "msg1", nil, nil - } - return ctx, "", nil, context.DeadlineExceeded + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { close(agentStarted) }) + <-ctx.Done() + <-agentFinishGate + return &AgentOutput{}, nil }, - front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - <-frontBlocked - return ctx, "", nil, context.DeadlineExceeded + onCancel: func(cc *cancelContext) { + cancelObservedOnce.Do(func() { close(cancelObserved) }) }, } - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil - }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { return agent, nil }, - OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { - for { - if _, ok := iter.Next(); !ok { - break - } - } - return nil + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil }, }) - require.NoError(t, err) - ctx, cancel := loop.WithCancel(context.Background()) - done := make(chan error) - go func() { - done <- loop.Run(ctx) - }() + _, _ = loop.Push("first") - <-agent.startedCh - e := cancel() - assert.NoError(t, e) + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } - err = <-done - assert.NoError(t, err) -} + ok, ack := loop.Push("urgent", WithPreempt[string](WithAgentCancelMode(CancelAfterToolCalls))) + assert.True(t, ok) + assert.NotNil(t, ack) -func TestTurnLoop_WithCancel_DuringAgentRun(t *testing.T) { - slowAgent := &turnLoopCancellableAgent{ - name: "slow-agent", - startedCh: make(chan struct{}), - cancelCh: make(chan struct{}), + select { + case <-ack: + case <-time.After(1 * time.Second): + t.Fatal("preempt ack was not closed") } - receiveCount := int32(0) - frontBlocked := make(chan struct{}) - source := &turnLoopFuncSource[string]{ - receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - cnt := atomic.AddInt32(&receiveCount, 1) - if cnt == 1 { - return ctx, "msg1", nil, nil - } - return ctx, "", nil, context.DeadlineExceeded - }, - front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - <-frontBlocked - return ctx, "", nil, context.DeadlineExceeded - }, + select { + case <-cancelObserved: + case <-time.After(1 * time.Second): + t.Fatal("cancel was not initiated") } - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil - }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { - return slowAgent, nil + close(agentFinishGate) + + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) +} + +func TestTurnLoop_PreemptAck_ClosesImmediatelyIfLoopNotStarted(t *testing.T) { + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil }, - OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { - for { - if _, ok := iter.Next(); !ok { - break - } - } - return nil + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil }, }) - require.NoError(t, err) - - ctx, cancel := loop.WithCancel(context.Background()) - done := make(chan error) - go func() { - done <- loop.Run(ctx) - }() - <-slowAgent.startedCh - cancel(WithCancelMode(CancelImmediate)) + ok, ack := loop.Push("urgent", WithPreempt[string]()) + assert.True(t, ok) + assert.NotNil(t, ack) - err = <-done - assert.NoError(t, err) - assert.True(t, atomic.LoadInt32(&slowAgent.cancelled) == 1, "agent should have been cancelled") + select { + case <-ack: + case <-time.After(1 * time.Second): + t.Fatal("preempt ack was not closed") + } } -func TestTurnLoop_WithCancel_NonCancellableAgent_ReturnsError(t *testing.T) { - agent := &turnLoopMockAgent{ - name: "non-cancellable-agent", - events: []*AgentEvent{{Output: &AgentOutput{}}}, - } +func TestTurnLoop_Preempt_EscalatesOnSecondPreempt(t *testing.T) { + agentStarted := make(chan struct{}) + firstCancelSeen := make(chan struct{}) + agentFinishGate := make(chan struct{}) + agentStartedOnce := sync.Once{} + firstCancelOnce := sync.Once{} - source := &turnLoopFuncSource[string]{ - receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - return ctx, "msg1", nil, nil + var ccPtr atomic.Value + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { close(agentStarted) }) + <-ctx.Done() + <-agentFinishGate + return &AgentOutput{}, nil + }, + onCancel: func(cc *cancelContext) { + ccPtr.Store(cc) + firstCancelOnce.Do(func() { close(firstCancelSeen) }) }, } - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil - }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { return agent, nil }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil + }, }) - require.NoError(t, err) - ctx, _ := loop.WithCancel(context.Background()) - err = loop.Run(ctx) + loop.Push("first") + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } - assert.ErrorIs(t, err, ErrAgentNotCancellableInTurnLoop) - assert.Contains(t, err.Error(), "non-cancellable-agent") -} + loop.Push("urgent1", WithPreempt[string](WithAgentCancelMode(CancelAfterChatModel))) + select { + case <-firstCancelSeen: + case <-time.After(1 * time.Second): + t.Fatal("first preempt did not trigger cancel") + } -func TestTurnLoop_WithCancel_MultipleCalls(t *testing.T) { - agent := &turnLoopMockAgent{ - name: "test-agent", - events: []*AgentEvent{{Output: &AgentOutput{}}}, + loop.Push("urgent2", WithPreempt[string](WithAgentCancelMode(CancelImmediate))) + + deadline := time.Now().Add(1 * time.Second) + for time.Now().Before(deadline) { + v := ccPtr.Load() + if v == nil { + time.Sleep(5 * time.Millisecond) + continue + } + cc := v.(*cancelContext) + if cc.getMode() == CancelImmediate && atomic.LoadInt32(&cc.escalated) == 1 { + break + } + time.Sleep(5 * time.Millisecond) } - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil - }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { - return agent, nil - }, - }) - require.NoError(t, err) + v := ccPtr.Load() + if v == nil { + t.Fatal("cancel context was not captured") + } + cc := v.(*cancelContext) + assert.Equal(t, CancelImmediate, cc.getMode()) + assert.Equal(t, int32(1), atomic.LoadInt32(&cc.escalated)) - _, cancel := loop.WithCancel(context.Background()) + close(agentFinishGate) - err = cancel() - assert.NoError(t, err) + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) } -func TestTurnLoop_WithCancel_IndependentCancels(t *testing.T) { - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{items: []string{}, err: context.DeadlineExceeded}, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil +func TestTurnLoop_Preempt_JoinsSafePointModesOnSecondPreempt(t *testing.T) { + agentStarted := make(chan struct{}) + firstCancelSeen := make(chan struct{}) + agentFinishGate := make(chan struct{}) + agentStartedOnce := sync.Once{} + firstCancelOnce := sync.Once{} + + var ccPtr atomic.Value + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { close(agentStarted) }) + <-ctx.Done() + <-agentFinishGate + return &AgentOutput{}, nil + }, + onCancel: func(cc *cancelContext) { + ccPtr.Store(cc) + firstCancelOnce.Do(func() { close(firstCancelSeen) }) + }, + } + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { - return &turnLoopMockAgent{name: "a"}, nil + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil }, }) - require.NoError(t, err) - ctx1, cancel1 := loop.WithCancel(context.Background()) - ctx2, cancel2 := loop.WithCancel(context.Background()) + loop.Push("first") + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } + + loop.Push("urgent1", WithPreempt[string](WithAgentCancelMode(CancelAfterChatModel))) + select { + case <-firstCancelSeen: + case <-time.After(1 * time.Second): + t.Fatal("first preempt did not trigger cancel") + } + + loop.Push("urgent2", WithPreempt[string](WithAgentCancelMode(CancelAfterToolCalls))) - cs1 := getTurnLoopCancelSig(ctx1) - cs2 := getTurnLoopCancelSig(ctx2) + want := CancelAfterChatModel | CancelAfterToolCalls + deadline := time.Now().Add(1 * time.Second) + for time.Now().Before(deadline) { + v := ccPtr.Load() + if v == nil { + time.Sleep(5 * time.Millisecond) + continue + } + cc := v.(*cancelContext) + if cc.getMode() == want { + break + } + time.Sleep(5 * time.Millisecond) + } - assert.NotNil(t, cs1) - assert.NotNil(t, cs2) - assert.NotEqual(t, cs1, cs2) + v := ccPtr.Load() + if v == nil { + t.Fatal("cancel context was not captured") + } + cc := v.(*cancelContext) + assert.Equal(t, want, cc.getMode()) - cancel1() - assert.True(t, cs1.isCancelled()) - assert.False(t, cs2.isCancelled()) + close(agentFinishGate) - cancel2() - assert.True(t, cs2.isCancelled()) + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) } -func TestTurnLoop_WithCancel_WithCancelOptions(t *testing.T) { - slowAgent := &turnLoopCancellableAgent{ - name: "slow-agent", - startedCh: make(chan struct{}), - cancelCh: make(chan struct{}), - } +func TestTurnLoop_Push_WithoutPreempt_DoesNotCancel(t *testing.T) { + agentRunCount := 0 + agentDone := make(chan struct{}) - receiveCount := int32(0) - frontBlocked := make(chan struct{}) - source := &turnLoopFuncSource[string]{ - receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - cnt := atomic.AddInt32(&receiveCount, 1) - if cnt == 1 { - return ctx, "msg1", nil, nil + agent := &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentRunCount++ + if agentRunCount == 1 { + time.Sleep(100 * time.Millisecond) } - return ctx, "", nil, context.DeadlineExceeded - }, - front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - <-frontBlocked - return ctx, "", nil, context.DeadlineExceeded + if agentRunCount == 2 { + close(agentDone) + } + return &AgentOutput{}, nil }, } - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { - return slowAgent, nil + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil }, - OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { - for { - if _, ok := iter.Next(); !ok { + }) + + loop.Push("first") + time.Sleep(20 * time.Millisecond) + loop.Push("second") + + select { + case <-agentDone: + case <-time.After(1 * time.Second): + t.Fatal("second agent run did not complete") + } + + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.Equal(t, 2, agentRunCount) +} + +func TestTurnLoop_PreemptDelay_NoMispreemptOnNaturalCompletion(t *testing.T) { + agent1Started := make(chan struct{}) + agent1Done := make(chan struct{}) + agent2Started := make(chan struct{}) + agent2Done := make(chan struct{}) + agent1StartedOnce := sync.Once{} + agent1DoneOnce := sync.Once{} + agent2StartedOnce := sync.Once{} + agent2DoneOnce := sync.Once{} + + var agentRunCount int32 + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + count := atomic.AddInt32(&agentRunCount, 1) + if count == 1 { + agent1StartedOnce.Do(func() { close(agent1Started) }) + time.Sleep(50 * time.Millisecond) + agent1DoneOnce.Do(func() { close(agent1Done) }) + } else if count == 2 { + agent2StartedOnce.Do(func() { close(agent2Started) }) + time.Sleep(100 * time.Millisecond) + select { + case <-ctx.Done(): + t.Error("Agent2 was unexpectedly cancelled") + return nil, ctx.Err() + default: + } + agent2DoneOnce.Do(func() { close(agent2Done) }) + } + return &AgentOutput{}, nil + }, + } + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil + }, + }) + + loop.Push("first") + + select { + case <-agent1Started: + case <-time.After(1 * time.Second): + t.Fatal("agent1 did not start") + } + + loop.Push("second", WithPreempt[string](), WithPreemptDelay[string](500*time.Millisecond)) + + select { + case <-agent1Done: + case <-time.After(1 * time.Second): + t.Fatal("agent1 did not complete naturally") + } + + select { + case <-agent2Started: + case <-time.After(1 * time.Second): + t.Fatal("agent2 did not start") + } + + select { + case <-agent2Done: + case <-time.After(1 * time.Second): + t.Fatal("agent2 did not complete - may have been incorrectly preempted") + } + + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.Equal(t, int32(2), atomic.LoadInt32(&agentRunCount)) +} + +func TestTurnLoop_ConcurrentPush(t *testing.T) { + var count int32 + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + atomic.AddInt32(&count, int32(len(items))) + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + for j := 0; j < 10; j++ { + _, _ = loop.Push(fmt.Sprintf("msg-%d-%d", i, j)) + } + }(i) + } + + wg.Wait() + time.Sleep(200 * time.Millisecond) + + loop.Stop() + result := loop.Wait() + + processed := atomic.LoadInt32(&count) + unhandled := len(result.UnhandledItems) + + assert.True(t, processed > 0, "should have processed some items") + assert.True(t, int(processed)+unhandled <= 100, "total should not exceed pushed amount") +} + +func TestTurnLoop_StopAfterReceive_RecoverItem(t *testing.T) { + receiveStarted := make(chan struct{}) + cancelDone := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + close(receiveStarted) + <-cancelDone + time.Sleep(50 * time.Millisecond) + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + <-receiveStarted + + loop.Stop() + close(cancelDone) + + result := loop.Wait() + assert.NoError(t, result.ExitReason) +} + +func TestTurnLoop_StopAfterGenInput_RecoverConsumed(t *testing.T) { + genInputDone := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + close(genInputDone) + time.Sleep(50 * time.Millisecond) + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items[:1], + Remaining: items[1:], + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + time.Sleep(100 * time.Millisecond) + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + loop.Push("msg2") + + <-genInputDone + + time.Sleep(60 * time.Millisecond) + loop.Stop() + + result := loop.Wait() + assert.NoError(t, result.ExitReason) +} + +func TestTurnLoop_GetAgentError_RecoverConsumed(t *testing.T) { + agentErr := errors.New("get agent error") + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items[:1], + Remaining: items[1:], + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + return nil, agentErr + }, + }) + + loop.Push("msg1") + loop.Push("msg2") + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, agentErr) + assert.True(t, len(result.UnhandledItems) >= 1, "should recover at least the consumed item and remaining") +} + +func TestTurnLoop_GenInputError_RecoverItems(t *testing.T) { + genErr := errors.New("gen input error") + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return nil, genErr + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + loop.Push("msg2") + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, genErr) + assert.Len(t, result.UnhandledItems, 2, "should recover all items when GenInput fails") + assert.Contains(t, result.UnhandledItems, "msg1") + assert.Contains(t, result.UnhandledItems, "msg2") +} + +func TestTurnLoop_PrepareAgentError_RecoverItemsInOrder(t *testing.T) { + agentErr := errors.New("prepare agent error") + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + var urgent string + remaining := make([]string, 0, len(items)) + for _, item := range items { + if item == "urgent" { + urgent = item + } else { + remaining = append(remaining, item) + } + } + if urgent != "" { + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: []string{urgent}, + Remaining: remaining, + }, nil + } + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items[:1], + Remaining: items[1:], + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return nil, agentErr + }, + }) + + loop.Push("msg1") + loop.Push("urgent") + loop.Push("msg2") + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, agentErr) + assert.Len(t, result.UnhandledItems, 3, "should recover all items") + assert.Equal(t, []string{"msg1", "urgent", "msg2"}, result.UnhandledItems, + "should preserve original push order even when GenInput selects non-prefix items") +} + +// Context cancel tests: the TurnLoop monitors context cancellation by closing +// the internal buffer when ctx.Done() fires, which unblocks the blocking +// Receive() call. The loop then checks ctx.Err() and exits with the context error. + +func TestTurnLoop_ContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + genInputStarted := make(chan struct{}) + genInputDone := make(chan struct{}) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + close(genInputStarted) + <-genInputDone + if err := ctx.Err(); err != nil { + return nil, err + } + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + + <-genInputStarted + cancel() + close(genInputDone) + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, context.Canceled) +} + +func TestTurnLoop_ContextDeadlineExceeded(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + select { + case <-time.After(100 * time.Millisecond): + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + case <-ctx.Done(): + return nil, ctx.Err() + } + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, context.DeadlineExceeded) +} + +func TestTurnLoop_ContextCancelBeforeReceive(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + // Push before Run to guarantee the item is buffered before the + // context-monitoring goroutine can close the buffer. + _, _ = loop.Push("msg1") + loop.Run(ctx) + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, context.Canceled) + assert.Len(t, result.UnhandledItems, 1) +} + +func TestTurnLoop_ContextCancelDuringBlockingReceive(t *testing.T) { + // When context is cancelled while Receive() is blocking (no items in buffer), + // the context monitoring goroutine closes the buffer, which unblocks Receive(). + ctx, cancel := context.WithCancel(context.Background()) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + // Don't push any items — let Receive() block + time.Sleep(50 * time.Millisecond) + cancel() + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, context.Canceled) +} + +func TestTurnLoop_ContextCancelAfterGenInput_RecoverItems(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + genInputCount := 0 + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + genInputCount++ + if genInputCount == 1 { + cancel() + } + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items[:1], + Remaining: items[1:], + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + loop.Push("msg2") + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, context.Canceled) + assert.True(t, len(result.UnhandledItems) >= 1, "should recover consumed and remaining items") +} + +func TestTurnLoop_OnAgentEventsReceivesEvents(t *testing.T) { + var receivedEvents []*AgentEvent + var receivedConsumed []string + var mu sync.Mutex + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + mu.Lock() + receivedConsumed = append(receivedConsumed, tc.Consumed...) + mu.Unlock() + + for { + event, ok := events.Next() + if !ok { + break + } + mu.Lock() + receivedEvents = append(receivedEvents, event) + mu.Unlock() + } + return nil + }, + }) + + loop.Push("msg1") + + time.Sleep(100 * time.Millisecond) + + loop.Stop() + result := loop.Wait() + + assert.NoError(t, result.ExitReason) + + mu.Lock() + defer mu.Unlock() + assert.True(t, len(receivedConsumed) > 0, "should have received consumed items") +} + +func TestTurnLoop_StopDuringAgentExecution(t *testing.T) { + agentStarted := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + time.Sleep(200 * time.Millisecond) + for { + _, ok := events.Next() + if !ok { + break + } + } + return nil + }, + }) + + loop.Push("msg1") + + <-agentStarted + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.Equal(t, []string{"msg1"}, result.CanceledItems) +} + +func TestTurnLoop_StopCheckPointIDInCancelError(t *testing.T) { + ctx := context.Background() + modelStarted := make(chan struct{}, 1) + checkpointID := "turn-loop-cancel-ckpt-1" + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: checkpointID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + + loop.Push("msg1") + + <-modelStarted + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + + result := loop.Wait() + + var cancelErr *CancelError + assert.True(t, errors.As(result.ExitReason, &cancelErr), "ExitReason should be a *CancelError") + + store.mu.Lock() + defer store.mu.Unlock() + _, ok := store.m[checkpointID] + assert.True(t, ok, "checkpoint should be saved under the configured CheckpointID") +} + +func TestTurnLoop_StopWithoutCheckpointIDDoesNotPersist(t *testing.T) { + ctx := context.Background() + modelStarted := make(chan struct{}, 1) + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + + loop.Push("msg1") + + <-modelStarted + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + + result := loop.Wait() + + var cancelErr *CancelError + assert.True(t, errors.As(result.ExitReason, &cancelErr), "ExitReason should be a *CancelError") + + store.mu.Lock() + defer store.mu.Unlock() + assert.Empty(t, store.m, "no checkpoint should be saved when CheckpointID is not configured") +} + +func TestTurnLoop_StopBetweenTurnsAndResume(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "between-turns-session" + + loop := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("a") + loop.Push("b") + loop.Stop() + loop.Run(ctx) + + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + + var seen []string + var mu sync.Mutex + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + mu.Lock() + seen = append([]string{}, items...) + mu.Unlock() + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + _, ok := events.Next() + if !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + + loop2.Push("c") + loop2.Run(ctx) + exit2 := loop2.Wait() + assert.NoError(t, exit2.ExitReason) + + mu.Lock() + defer mu.Unlock() + assert.Equal(t, []string{"a", "b", "c"}, seen) +} + +func TestTurnLoop_StopDuringAgentExecution_PersistAndResume(t *testing.T) { + ctx := context.Background() + modelStarted := make(chan struct{}, 1) + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "mid-turn-session" + + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + + loop.Push("msg1") + <-modelStarted + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + exit := loop.Wait() + + store.mu.Lock() + _, ok := store.m[cpID] + store.mu.Unlock() + assert.True(t, ok) + _ = exit + + slowModel.setDelay(10 * time.Millisecond) + + var consumed2 []string + var genResumeCalled bool + var genInputCalled bool + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenResume: func(ctx context.Context, _ *TurnLoop[string], canceledItems []string, unhandledItems []string, newItems []string) (*GenResumeResult[string], error) { + genResumeCalled = true + return &GenResumeResult[string]{ + Consumed: canceledItems, + Remaining: append(append([]string{}, unhandledItems...), newItems...), + }, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + genInputCalled = true + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + consumed2 = append([]string{}, consumed...) + return agent, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + _, ok := events.Next() + if !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + + loop2.Run(ctx) + exit2 := loop2.Wait() + assert.NoError(t, exit2.ExitReason) + assert.Equal(t, []string{"msg1"}, consumed2) + assert.True(t, genResumeCalled) + assert.False(t, genInputCalled) +} + +func TestTurnLoop_CheckpointIDWithoutStore_FreshStart(t *testing.T) { + ctx := context.Background() + var genInputCalled bool + loop := NewTurnLoop(TurnLoopConfig[string]{ + CheckpointID: "some-id", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + genInputCalled = true + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + loop.Push("a") + loop.Run(ctx) + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + assert.True(t, genInputCalled) +} + +func TestTurnLoop_CheckpointNotFound_FreshStart(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + var genInputCalled bool + loop := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: "nonexistent-id", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + genInputCalled = true + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + loop.Push("a") + loop.Run(ctx) + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + assert.True(t, genInputCalled) +} + +func TestTurnLoop_CheckpointEmptyData_TreatedAsNoCheckpoint(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + store.m["cp-empty"] = nil + + var genInputCalled bool + loop := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: "cp-empty", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + genInputCalled = true + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + loop.Push("a") + loop.Run(ctx) + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + assert.True(t, genInputCalled) +} + +type errorCheckpointStore struct { + getErr error + setErr error +} + +func (s *errorCheckpointStore) Get(_ context.Context, _ string) ([]byte, bool, error) { + return nil, false, s.getErr +} + +func (s *errorCheckpointStore) Set(_ context.Context, _ string, _ []byte) error { + return s.setErr +} + +func TestTurnLoop_CheckpointLoadError_ReturnsError(t *testing.T) { + ctx := context.Background() + store := &errorCheckpointStore{getErr: fmt.Errorf("store unavailable")} + loop := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: "cp-1", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Run(ctx) + exit := loop.Wait() + assert.Error(t, exit.ExitReason) + assert.Contains(t, exit.ExitReason.Error(), "store unavailable") +} + +func TestTurnLoop_CheckpointCorruptData_ReturnsError(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + store.m["cp-corrupt"] = []byte("not-valid-gob-data") + loop := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: "cp-corrupt", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Run(ctx) + exit := loop.Wait() + assert.Error(t, exit.ExitReason) + assert.Contains(t, exit.ExitReason.Error(), "failed to unmarshal checkpoint") +} + +func TestTurnLoop_CheckpointSaveError_ReturnsError(t *testing.T) { + ctx := context.Background() + modelStarted := make(chan struct{}, 1) + saveStore := &errorCheckpointStore{setErr: fmt.Errorf("write failed")} + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: saveStore, + CheckpointID: "cp-1", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + loop.Push("msg1") + <-modelStarted + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + exit := loop.Wait() + assert.Error(t, exit.ExitReason) + assert.Contains(t, exit.ExitReason.Error(), "write failed") +} + +func TestTurnLoop_StaleCheckpointDeletion_OnCleanResume(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "stale-session" + + loop1 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop1.Push("a") + loop1.Stop() + loop1.Run(ctx) + loop1.Wait() + + store.mu.Lock() + _, exists := store.m[cpID] + store.mu.Unlock() + assert.True(t, exists, "checkpoint should exist after first loop saves it") + + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + loop2.Push("b") + loop2.Run(ctx) + exit2 := loop2.Wait() + assert.NoError(t, exit2.ExitReason) + + store.mu.Lock() + _, exists = store.m[cpID] + store.mu.Unlock() + assert.True(t, exists, "checkpoint should still exist because loop2 was stopped and saved a new one") +} + +func TestTurnLoop_StaleCheckpointDeletion_ContextCancel(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "delete-on-cancel" + + loop1 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop1.Push("a") + loop1.Stop() + loop1.Run(ctx) + loop1.Wait() + + store.mu.Lock() + _, exists := store.m[cpID] + store.mu.Unlock() + assert.True(t, exists, "checkpoint saved after loop1") + + ctx2, cancel2 := context.WithCancel(ctx) + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + cancel2() + return nil + }, + }) + loop2.Push("b") + loop2.Run(ctx2) + exit2 := loop2.Wait() + assert.ErrorIs(t, exit2.ExitReason, context.Canceled) + + store.mu.Lock() + v, exists := store.m[cpID] + store.mu.Unlock() + deletedViaNil := exists && v == nil + deletedViaAbsence := !exists + assert.True(t, deletedViaNil || deletedViaAbsence, "stale checkpoint should be deleted when loop exits via context cancellation") +} + +type deletableCheckpointStore struct { + turnLoopCheckpointStore + deleteCalled bool + deletedKey string +} + +func (s *deletableCheckpointStore) Delete(_ context.Context, key string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.deleteCalled = true + s.deletedKey = key + delete(s.m, key) + return nil +} + +func TestTurnLoop_CheckpointDeleter_CalledOnContextCancel(t *testing.T) { + ctx := context.Background() + store := &deletableCheckpointStore{ + turnLoopCheckpointStore: turnLoopCheckpointStore{m: make(map[string][]byte)}, + } + cpID := "deleter-session" + + loop1 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop1.Push("a") + loop1.Stop() + loop1.Run(ctx) + loop1.Wait() + + store.mu.Lock() + _, exists := store.m[cpID] + store.mu.Unlock() + assert.True(t, exists, "checkpoint saved after loop1") + + ctx2, cancel2 := context.WithCancel(ctx) + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + cancel2() + return nil + }, + }) + loop2.Push("b") + loop2.Run(ctx2) + exit2 := loop2.Wait() + assert.ErrorIs(t, exit2.ExitReason, context.Canceled) + + store.mu.Lock() + defer store.mu.Unlock() + assert.True(t, store.deleteCalled, "CheckPointDeleter.Delete should be called") + assert.Equal(t, cpID, store.deletedKey) + _, exists = store.m[cpID] + assert.False(t, exists, "checkpoint should be removed from store") +} + +func TestTurnLoop_GenResumeNil_Error(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "resume-nil-session" + modelStarted := make(chan struct{}, 1) + + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loop1 := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + loop1.Push("msg1") + <-modelStarted + loop1.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop1.Wait() + + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop2.Run(ctx) + exit2 := loop2.Wait() + assert.Error(t, exit2.ExitReason) + assert.Contains(t, exit2.ExitReason.Error(), "GenResume is required") +} + +func TestTurnLoop_SameCheckpointID_OverwritePattern(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "overwrite-session" + + loop1 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop1.Push("a") + loop1.Push("b") + loop1.Stop() + loop1.Run(ctx) + loop1.Wait() + + store.mu.Lock() + data1 := append([]byte{}, store.m[cpID]...) + store.mu.Unlock() + assert.NotEmpty(t, data1) + + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop2.Push("c") + loop2.Stop() + loop2.Run(ctx) + loop2.Wait() + + store.mu.Lock() + data2 := append([]byte{}, store.m[cpID]...) + store.mu.Unlock() + assert.NotEmpty(t, data2) + assert.NotEqual(t, data1, data2, "checkpoint data should change because items are different") + + var seen []string + var mu sync.Mutex + loop3 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + mu.Lock() + seen = append([]string{}, items...) + mu.Unlock() + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + loop3.Push("d") + loop3.Run(ctx) + exit3 := loop3.Wait() + assert.NoError(t, exit3.ExitReason) + + mu.Lock() + defer mu.Unlock() + assert.Equal(t, []string{"a", "b", "c", "d"}, seen, "should see loop2's unhandled items (a,b,c from loop2's checkpoint) plus new d") +} + +func TestTurnLoop_CheckpointHasRunnerStateButEmptyBytes(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "empty-runner-bytes" + + cp := &turnLoopCheckpoint[string]{ + HasRunnerState: true, + RunnerCheckpoint: nil, + UnhandledItems: []string{"x"}, + } + data, err := marshalTurnLoopCheckpoint(cp) + assert.NoError(t, err) + store.m[cpID] = data + + loop := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Run(ctx) + exit := loop.Wait() + assert.Error(t, exit.ExitReason) + assert.Contains(t, exit.ExitReason.Error(), "has runner state but bytes are empty") +} + +func TestTurnLoop_GenResumeReturnsError(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "resume-err-session" + modelStarted := make(chan struct{}, 1) + + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loop1 := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + loop1.Push("msg1") + <-modelStarted + loop1.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop1.Wait() + + genResumeErr := fmt.Errorf("resume callback failed") + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + GenResume: func(ctx context.Context, _ *TurnLoop[string], canceled, unhandled, newItems []string) (*GenResumeResult[string], error) { + return nil, genResumeErr + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop2.Run(ctx) + exit2 := loop2.Wait() + assert.Error(t, exit2.ExitReason) + assert.ErrorIs(t, exit2.ExitReason, genResumeErr) +} + +func TestTurnLoop_CheckpointSaveError_MergesWithExistingError(t *testing.T) { + ctx := context.Background() + modelStarted := make(chan struct{}, 1) + saveStore := &errorCheckpointStore{setErr: fmt.Errorf("disk full")} + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: saveStore, + CheckpointID: "cp-merge-err", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + loop.Push("msg1") + <-modelStarted + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + exit := loop.Wait() + assert.Error(t, exit.ExitReason) + errStr := exit.ExitReason.Error() + assert.Contains(t, errStr, "disk full") + var ce *CancelError + assert.True(t, errors.As(exit.ExitReason, &ce), "should wrap original CancelError") +} + +func TestTurnLoop_ResumeWithParams(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "resume-params-session" + modelStarted := make(chan struct{}, 1) + + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loop1 := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + loop1.Push("msg1") + <-modelStarted + loop1.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + exit1 := loop1.Wait() + var ce *CancelError + assert.True(t, errors.As(exit1.ExitReason, &ce)) + + var resumeParamsUsed *ResumeParams + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + GenResume: func(ctx context.Context, _ *TurnLoop[string], canceled, unhandled, newItems []string) (*GenResumeResult[string], error) { + params := &ResumeParams{ + Targets: map[string]any{"some-address": "user-data"}, + } + resumeParamsUsed = params + return &GenResumeResult[string]{ + ResumeParams: params, + Consumed: append(append(canceled, unhandled...), newItems...), + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + loop2.Run(ctx) + exit2 := loop2.Wait() + assert.NotNil(t, resumeParamsUsed, "GenResume should have been called with ResumeParams") + assert.Contains(t, resumeParamsUsed.Targets, "some-address") + _ = exit2 +} + +func TestTurnLoop_StopOptionsArePassed(t *testing.T) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelAfterToolCalls))) + + result := loop.Wait() + assert.NoError(t, result.ExitReason) +} + +func TestTurnLoop_Stop_EscalatesCancelMode(t *testing.T) { + ctx := context.Background() + agentStarted := make(chan *cancelContext, 1) + probe := &turnLoopStopModeProbeAgent{ccCh: agentStarted} + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return probe, nil + }, + }) + + loop.Push("msg1") + cc := <-agentStarted + + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelAfterToolCalls), WithAgentCancelTimeout(10*time.Second))) + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + + deadline := time.After(1 * time.Second) + for { + if cc.getMode() == CancelImmediate { + break + } + select { + case <-deadline: + t.Fatal("cancel mode did not escalate to CancelImmediate") + default: + } + time.Sleep(1 * time.Millisecond) + } + + exit := loop.Wait() + var ce *CancelError + assert.True(t, errors.As(exit.ExitReason, &ce)) + assert.Equal(t, CancelImmediate, ce.Info.Mode) +} + +func TestTurnLoop_DefaultOnAgentEvents_ErrorPropagation(t *testing.T) { + agentErr := errors.New("agent execution error") + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + return nil, agentErr + }, + }, nil + }, + // No OnAgentEvents — use default handler + }) + + loop.Push("msg1") + + result := loop.Wait() + // The default handler should propagate the agent error as ExitReason + assert.Error(t, result.ExitReason) +} + +func TestTurnLoop_OnAgentEventsError(t *testing.T) { + handlerErr := errors.New("event handler error") + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + // Drain events then return error + for { + _, ok := events.Next() + if !ok { + break + } + } + return handlerErr + }, + }) + + loop.Push("msg1") + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, handlerErr) +} + +func TestTurnLoop_StopCallFromGenInput(t *testing.T) { + // Test that calling Stop() from within GenInput works correctly + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, loop *TurnLoop[string], items []string) (*GenInputResult[string], error) { + loop.Stop() + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + + result := loop.Wait() + assert.NoError(t, result.ExitReason) +} + +func TestTurnLoop_PushFromOnAgentEvents(t *testing.T) { + // Test that calling Push() from within OnAgentEvents works + pushCount := int32(0) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + _, ok := events.Next() + if !ok { + break + } + } + count := atomic.AddInt32(&pushCount, 1) + if count == 1 { + // Push a follow-up item from the callback + _, _ = tc.Loop.Push("follow-up") + } else { + tc.Loop.Stop() + } + return nil + }, + }) + + loop.Push("initial") + + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.GreaterOrEqual(t, atomic.LoadInt32(&pushCount), int32(2)) +} + +// Tests for NewTurnLoop: the permissive API where Push, Stop, and Wait are +// all valid on a not-yet-running loop. + +func TestNewTurnLoop_PushBeforeRun(t *testing.T) { + // Items pushed before Run are buffered and processed after Run starts. + var processedItems []string + var mu sync.Mutex + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + mu.Lock() + processedItems = append(processedItems, items...) + mu.Unlock() + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + // Push before Run — items should be buffered. + ok, _ := loop.Push("msg1") + assert.True(t, ok) + ok, _ = loop.Push("msg2") + assert.True(t, ok) + + loop.Run(context.Background()) + + time.Sleep(100 * time.Millisecond) + + loop.Stop() + result := loop.Wait() + + mu.Lock() + defer mu.Unlock() + + assert.NoError(t, result.ExitReason) + assert.Contains(t, processedItems, "msg1") + assert.Contains(t, processedItems, "msg2") +} + +func TestNewTurnLoop_StopBeforeRun(t *testing.T) { + // Stop before Run sets the stopped flag. When Run is called, the loop + // exits immediately and buffered items appear as UnhandledItems. + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + t.Fatal("GenInput should not be called") + return nil, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + t.Fatal("PrepareAgent should not be called") + return nil, nil + }, + }) + + loop.Push("msg1") + loop.Push("msg2") + loop.Stop() + + // Push after Stop returns false. + ok, _ := loop.Push("msg3") + assert.False(t, ok) + + loop.Run(context.Background()) + result := loop.Wait() + + assert.NoError(t, result.ExitReason) + assert.Equal(t, []string{"msg1", "msg2"}, result.UnhandledItems) +} + +func TestNewTurnLoop_WaitBeforeRun(t *testing.T) { + // Wait blocks until Run is called AND the loop exits. + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + waitDone := make(chan *TurnLoopExitState[string], 1) + go func() { + waitDone <- loop.Wait() + }() + + // Wait should not return yet since Run hasn't been called. + select { + case <-waitDone: + t.Fatal("Wait returned before Run was called") + case <-time.After(50 * time.Millisecond): + // expected + } + + loop.Push("msg1") + loop.Stop() + loop.Run(context.Background()) + + select { + case result := <-waitDone: + assert.NoError(t, result.ExitReason) + assert.Equal(t, []string{"msg1"}, result.UnhandledItems) + case <-time.After(1 * time.Second): + t.Fatal("Wait did not return after Run + Stop") + } +} + +func TestNewTurnLoop_RunIsIdempotent(t *testing.T) { + var genInputCalls int32 + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + atomic.AddInt32(&genInputCalls, 1) + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + loop.Run(context.Background()) + loop.Run(context.Background()) + loop.Run(context.Background()) + + time.Sleep(100 * time.Millisecond) + + loop.Stop() + result := loop.Wait() + + assert.NoError(t, result.ExitReason) + assert.True(t, atomic.LoadInt32(&genInputCalls) >= 1) +} + +func TestNewTurnLoop_StopBeforeRun_ThenWait(t *testing.T) { + // Demonstrates the full sequence: create, push, stop, run, wait. + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + t.Fatal("GenInput should not be called after Stop") + return nil, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + t.Fatal("PrepareAgent should not be called after Stop") + return nil, nil + }, + }) + + loop.Push("a") + loop.Push("b") + loop.Push("c") + loop.Stop() + + // Run after Stop: the loop goroutine starts but exits immediately. + loop.Run(context.Background()) + + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.Equal(t, []string{"a", "b", "c"}, result.UnhandledItems) +} + +func TestNewTurnLoop_ConcurrentPushAndRun(t *testing.T) { + // Concurrent Push and Run should not race. + for i := 0; i < 100; i++ { + var count int32 + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + atomic.AddInt32(&count, int32(len(items))) + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + _, _ = loop.Push("item") + }() + + go func() { + defer wg.Done() + loop.Run(context.Background()) + }() + + wg.Wait() + + time.Sleep(50 * time.Millisecond) + + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + + processed := atomic.LoadInt32(&count) + unhandled := len(result.UnhandledItems) + assert.True(t, int(processed)+unhandled <= 1, + "total should not exceed pushed amount") + } +} + +type turnCtxKey struct{} + +func TestTurnLoop_RunCtx_Propagation(t *testing.T) { + // Verify that GenInputResult.RunCtx is propagated to PrepareAgent, + // the agent run, and OnAgentEvents. + + const traceVal = "trace-123" + var prepareCtxVal, agentCtxVal, eventsCtxVal string + + cfg := TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, loop *TurnLoop[string], items []string) (*GenInputResult[string], error) { + // Derive a new context with per-item trace data + runCtx := context.WithValue(ctx, turnCtxKey{}, traceVal) + return &GenInputResult[string]{ + RunCtx: runCtx, + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, loop *TurnLoop[string], consumed []string) (Agent, error) { + if v, ok := ctx.Value(turnCtxKey{}).(string); ok { + prepareCtxVal = v + } + return &turnLoopMockAgent{ + name: "trace-agent", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + if v, ok := ctx.Value(turnCtxKey{}).(string); ok { + agentCtxVal = v + } + return &AgentOutput{}, nil + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + if v, ok := ctx.Value(turnCtxKey{}).(string); ok { + eventsCtxVal = v + } + for { + if _, ok := events.Next(); !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + } + + loop := NewTurnLoop(cfg) + loop.Push("hello") + loop.Run(context.Background()) + result := loop.Wait() + + assert.Nil(t, result.ExitReason) + assert.Equal(t, traceVal, prepareCtxVal, "PrepareAgent should receive RunCtx") + assert.Equal(t, traceVal, agentCtxVal, "Agent run should receive RunCtx") + assert.Equal(t, traceVal, eventsCtxVal, "OnAgentEvents should receive RunCtx") +} + +func TestTurnLoop_TurnContext_PreemptedChannel(t *testing.T) { + preemptedSeen := make(chan struct{}) + agentStarted := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "slow", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + select { + case <-tc.Preempted: + close(preemptedSeen) + case <-time.After(5 * time.Second): + t.Error("timed out waiting for Preempted channel") + } + // Drain events + for { + if _, ok := events.Next(); !ok { break } } return nil }, }) - require.NoError(t, err) - ctx, cancel := loop.WithCancel(context.Background()) - done := make(chan error) + loop.Push("msg1") + <-agentStarted + loop.Push("msg2", WithPreempt[string](WithAgentCancelMode(CancelImmediate))) + + select { + case <-preemptedSeen: + // success + case <-time.After(5 * time.Second): + t.Fatal("preempted channel was never observed in OnAgentEvents") + } + + loop.Stop() + loop.Wait() +} + +// ============================================================================= +// preemptSignal unit tests (direct testing of the hold/preempt/unhold mechanism) +// ============================================================================= + +func TestPreemptSignal_HoldCountLifecycle(t *testing.T) { + s := newPreemptSignal() + + s.holdRunLoop() + s.holdRunLoop() + + done := make(chan bool) go func() { - done <- loop.Run(ctx) + preempted, _, _ := s.waitForPreemptOrUnhold() + done <- preempted }() - <-slowAgent.startedCh - cancel(WithCancelMode(CancelAfterToolCall)) + select { + case <-done: + t.Fatal("waitForPreemptOrUnhold should block while holdCount > 0") + case <-time.After(50 * time.Millisecond): + } - <-done - assert.Len(t, slowAgent.cancelledOpt, 1) -} + s.unholdRunLoop() -type turnLoopInMemoryStore struct { - data map[string][]byte -} + select { + case <-done: + t.Fatal("waitForPreemptOrUnhold should still block (holdCount=1)") + case <-time.After(50 * time.Millisecond): + } -func newTurnLoopInMemoryStore() *turnLoopInMemoryStore { - return &turnLoopInMemoryStore{data: make(map[string][]byte)} -} + s.unholdRunLoop() -func (s *turnLoopInMemoryStore) Get(_ context.Context, key string) ([]byte, bool, error) { - v, ok := s.data[key] - return v, ok, nil + select { + case preempted := <-done: + assert.False(t, preempted, "should return not-preempted when all holds released") + case <-time.After(1 * time.Second): + t.Fatal("waitForPreemptOrUnhold should unblock when holdCount reaches 0") + } } -func (s *turnLoopInMemoryStore) Set(_ context.Context, key string, value []byte) error { - s.data[key] = value - return nil -} +func TestPreemptSignal_RequestPreemptWithNoHold(t *testing.T) { + s := newPreemptSignal() -type turnLoopTestModel struct { - messages []*schema.Message - idx int -} + ack := make(chan struct{}) + s.requestPreempt(ack) -func (m *turnLoopTestModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { - if m.idx >= len(m.messages) { - return nil, fmt.Errorf("no more messages") + select { + case <-ack: + case <-time.After(100 * time.Millisecond): + t.Fatal("ack should be closed immediately when holdCount is 0") } - msg := m.messages[m.idx] - m.idx++ - return msg, nil } -func (m *turnLoopTestModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { - panic("not implemented") -} +func TestPreemptSignal_RequestPreemptWakesWaiter(t *testing.T) { + s := newPreemptSignal() + s.holdRunLoop() + + done := make(chan struct { + preempted bool + ackList []chan struct{} + }) + go func() { + preempted, _, ackList := s.waitForPreemptOrUnhold() + done <- struct { + preempted bool + ackList []chan struct{} + }{preempted, ackList} + }() -func (m *turnLoopTestModel) WithTools(_ []*schema.ToolInfo) (model.ToolCallingChatModel, error) { - return m, nil + ack := make(chan struct{}) + s.requestPreempt(ack) + + select { + case result := <-done: + assert.True(t, result.preempted) + assert.Len(t, result.ackList, 1) + close(result.ackList[0]) + case <-time.After(1 * time.Second): + t.Fatal("waitForPreemptOrUnhold should wake on requestPreempt") + } } -type turnLoopSlowModel struct { - delay int64 - startedCh unsafe.Pointer - doneCh chan struct{} - message *schema.Message +func TestPreemptSignal_HoldAndGetTurn(t *testing.T) { + s := newPreemptSignal() + s.setTurn(context.Background(), "turn-A") + + ctx, tc := s.holdAndGetTurn() + assert.NotNil(t, ctx) + assert.Equal(t, "turn-A", tc) + + s.endTurnAndUnhold() + + _, tc2 := s.holdAndGetTurn() + assert.Nil(t, tc2, "TC should be nil after endTurnAndUnhold") + s.unholdRunLoop() } -func (m *turnLoopSlowModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { - if ch := (*chan struct{})(atomic.LoadPointer(&m.startedCh)); ch != nil { - select { - case *ch <- struct{}{}: - default: +func TestPreemptSignal_EndTurnPreservesSignalWhenHoldRemains(t *testing.T) { + s := newPreemptSignal() + + s.holdRunLoop() + s.holdRunLoop() + + ack := make(chan struct{}) + s.requestPreempt(ack) + + s.endTurnAndUnhold() + + done := make(chan bool) + go func() { + preempted, _, ackList := s.waitForPreemptOrUnhold() + for _, a := range ackList { + close(a) } + done <- preempted + }() + + select { + case preempted := <-done: + assert.True(t, preempted, "signal state should be preserved when holdCount > 0 after endTurnAndUnhold") + case <-time.After(1 * time.Second): + t.Fatal("waiter should see the preserved preempt signal") } - if delay := atomic.LoadInt64(&m.delay); delay > 0 { - time.Sleep(time.Duration(delay)) - } - if m.doneCh != nil { - select { - case m.doneCh <- struct{}{}: - default: - } + + select { + case <-ack: + case <-time.After(100 * time.Millisecond): + t.Fatal("ack should have been closed") } - return m.message, nil } -func (m *turnLoopSlowModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { - panic("not implemented") +func TestPreemptSignal_ConcurrentHoldRequestUnhold(t *testing.T) { + s := newPreemptSignal() + + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + s.holdRunLoop() + ack := make(chan struct{}) + s.requestPreempt(ack) + s.unholdRunLoop() + <-ack + }() + } + wg.Wait() } -type turnLoopInterruptTool struct { - name string -} +// ============================================================================= +// Integration tests for race-prone preempt scenarios +// ============================================================================= -func (t *turnLoopInterruptTool) Info(_ context.Context) (*schema.ToolInfo, error) { - return &schema.ToolInfo{ - Name: t.name, - Desc: "A tool that interrupts", - }, nil -} +func TestTurnLoop_ConcurrentPreemptsDuringTurn(t *testing.T) { + agentStarted := make(chan struct{}) + agentStartedOnce := sync.Once{} -func (t *turnLoopInterruptTool) InvokableRun(ctx context.Context, _ string, _ ...tool.Option) (string, error) { - wasInterrupted, _, _ := tool.GetInterruptState[any](ctx) - if !wasInterrupted { - return "", tool.Interrupt(ctx, "need approval") - } - isResumeTarget, hasData, data := tool.GetResumeContext[string](ctx) - if isResumeTarget && hasData { - return data, nil + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + <-ctx.Done() + return &AgentOutput{}, nil + }, } - return "approved", nil -} -func TestTurnLoop_ExternalCancel_WithStore(t *testing.T) { - store := newTurnLoopInMemoryStore() - const checkPointID = "external-cancel-test" + var genInputCount int32 - modelStarted := make(chan struct{}, 1) - testModel := &turnLoopSlowModel{ - doneCh: make(chan struct{}, 1), - message: schema.AssistantMessage("task completed", nil), + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + atomic.AddInt32(&genInputCount, 1) + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items, + }, nil + }, + }) + + loop.Push("seed") + + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") } - atomic.StoreInt64(&testModel.delay, int64(5*time.Second)) - atomic.StorePointer(&testModel.startedCh, unsafe.Pointer(&modelStarted)) - agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ - Name: "test-agent", - Description: "test agent for external cancel", - Model: testModel, - }) - require.NoError(t, err) + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + ok, ack := loop.Push(fmt.Sprintf("urgent-%d", i), WithPreempt[string]()) + if ok && ack != nil { + select { + case <-ack: + case <-time.After(5 * time.Second): + t.Error("ack channel not closed within timeout") + } + } + }(i) + } + + wg.Wait() + time.Sleep(200 * time.Millisecond) + + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.True(t, atomic.LoadInt32(&genInputCount) >= 2, "should have had at least the initial turn + one preempted turn") +} - receiveCount := int32(0) +func TestTurnLoop_PreemptDuringTurnTransition(t *testing.T) { turnCount := int32(0) - frontBlocked := make(chan struct{}) - source := &turnLoopFuncSource[string]{ - receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - cnt := atomic.AddInt32(&receiveCount, 1) - if cnt == 1 { - return ctx, "msg1", []ConsumeOption{WithConsumeCheckPointID(checkPointID)}, nil + firstTurnDone := make(chan struct{}) + firstTurnOnce := sync.Once{} + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "fast"}, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + count := atomic.AddInt32(&turnCount, 1) + if count == 1 { + firstTurnOnce.Do(func() { + close(firstTurnDone) + }) } - return ctx, "", nil, ErrLoopExit - }, - front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - <-frontBlocked - return ctx, "", nil, context.DeadlineExceeded + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items, + }, nil }, + }) + + loop.Push("first") + + select { + case <-firstTurnDone: + case <-time.After(1 * time.Second): + t.Fatal("first turn did not start") + } + + time.Sleep(50 * time.Millisecond) + + ok, ack := loop.Push("transitional", WithPreempt[string]()) + assert.True(t, ok, "push should succeed") + if ack != nil { + select { + case <-ack: + case <-time.After(2 * time.Second): + t.Fatal("ack should be closed even if preempt arrived during/after turn transition") + } } - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil + time.Sleep(100 * time.Millisecond) + + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.True(t, atomic.LoadInt32(&turnCount) >= 2, "transitional item should have been processed") +} + +func TestTurnLoop_PushStrategy_DuringTurnTransition(t *testing.T) { + agentStarted := make(chan struct{}) + agentStartedOnce := sync.Once{} + allowFinish := make(chan struct{}) + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + select { + case <-allowFinish: + return &AgentOutput{}, nil + case <-ctx.Done(): + return &AgentOutput{}, nil + } }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { + } + + var genInputCount int32 + secondTurnDone := make(chan struct{}) + secondTurnOnce := sync.Once{} + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { return agent, nil }, - OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { - atomic.AddInt32(&turnCount, 1) - for { - if _, ok := iter.Next(); !ok { - break - } + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + count := atomic.AddInt32(&genInputCount, 1) + if count >= 2 { + secondTurnOnce.Do(func() { + close(secondTurnDone) + }) } - return nil + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items, + }, nil }, - Store: store, }) - require.NoError(t, err) - ctx, cancel := loop.WithCancel(context.Background()) - done := make(chan error) + loop.Push("first") + + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } + + strategyBlocker := make(chan struct{}) + var strategyTCNotNil int32 + go func() { - done <- loop.Run(ctx) + loop.Push("strategic-item", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + if tc != nil { + atomic.StoreInt32(&strategyTCNotNil, 1) + } + <-strategyBlocker + return []PushOption[string]{WithPreempt[string]()} + })) }() + time.Sleep(50 * time.Millisecond) + close(allowFinish) + time.Sleep(50 * time.Millisecond) + close(strategyBlocker) + select { - case <-modelStarted: - case err := <-done: - t.Fatalf("loop.Run returned early with error: %v", err) + case <-secondTurnDone: case <-time.After(3 * time.Second): - t.Fatalf("timeout waiting for model to start") + t.Fatal("second turn should eventually run after strategy resolves") } - err = cancel(WithCancelMode(CancelImmediate)) + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.True(t, atomic.LoadInt32(&genInputCount) >= 2) +} - var runErr error - select { - case runErr = <-done: - case <-time.After(3 * time.Second): - t.Fatalf("timeout waiting for loop.Run to return after cancel") +func TestTurnLoop_ConcurrentPreemptAndStop(t *testing.T) { + for iter := 0; iter < 20; iter++ { + t.Run(fmt.Sprintf("iter_%d", iter), func(t *testing.T) { + ctx := context.Background() + + agentStarted := make(chan struct{}) + agentStartedOnce := sync.Once{} + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + <-ctx.Done() + return &AgentOutput{}, nil + }, + } + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items, + }, nil + }, + }) + + loop.Push("seed") + + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + _, ack := loop.Push("preempt-item", WithPreempt[string]()) + if ack != nil { + <-ack + } + }() + + go func() { + defer wg.Done() + loop.Stop() + }() + + wg.Wait() + loop.Wait() + }) } +} + +func TestTurnLoop_ConcurrentPushStrategyAndStop(t *testing.T) { + for iter := 0; iter < 20; iter++ { + t.Run(fmt.Sprintf("iter_%d", iter), func(t *testing.T) { + ctx := context.Background() + + agentStarted := make(chan struct{}) + agentStartedOnce := sync.Once{} + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + <-ctx.Done() + return &AgentOutput{}, nil + }, + } - var interruptErr *TurnLoopInterruptError[string] - require.True(t, errors.As(runErr, &interruptErr), "expected TurnLoopInterruptError, got: %v", runErr) - assert.Equal(t, checkPointID, interruptErr.CheckpointID) - assert.Equal(t, "msg1", interruptErr.Item) - assert.True(t, len(store.data) > 0, "checkpoint should be stored") + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items, + }, nil + }, + }) + + loop.Push("seed") + + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } - atomic.StorePointer(&testModel.startedCh, nil) - atomic.StoreInt64(&testModel.delay, 0) + var wg sync.WaitGroup + wg.Add(2) - done = make(chan error) - go func() { - done <- loop.Run(context.Background(), WithTurnLoopResume(checkPointID, "msg1")) - }() + go func() { + defer wg.Done() + _, ack := loop.Push("strategic-item", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + return []PushOption[string]{WithPreempt[string]()} + })) + if ack != nil { + <-ack + } + }() + + go func() { + defer wg.Done() + loop.Stop() + }() + + wg.Wait() + loop.Wait() + }) + } +} +func TestTurnLoop_TurnContext_StoppedChannel(t *testing.T) { + stoppedSeen := make(chan struct{}) + agentStarted := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "slow", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + select { + case <-tc.Stopped: + close(stoppedSeen) + case <-time.After(5 * time.Second): + t.Error("timed out waiting for Stopped channel") + } + // Drain events + for { + if _, ok := events.Next(); !ok { + break + } + } + return nil + }, + }) + + loop.Push("msg1") + <-agentStarted + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) select { - case runErr = <-done: + case <-stoppedSeen: + // success case <-time.After(5 * time.Second): - t.Fatalf("timeout waiting for loop.Run (resume) to return") + t.Fatal("stopped channel was never observed in OnAgentEvents") } - assert.NoError(t, runErr) - assert.Equal(t, int32(2), atomic.LoadInt32(&turnCount), "should have 2 turns total (1 from first run + 1 from resume)") + + loop.Wait() } -func TestTurnLoop_InternalToolInterrupt_WithCheckpoint_ThenResume(t *testing.T) { - store := newTurnLoopInMemoryStore() - const checkPointID = "tool-interrupt-test" +func TestTurnLoop_PushStrategy_DuringTurn(t *testing.T) { + agentStarted := make(chan struct{}) + agentStartedOnce := sync.Once{} + agentCancelled := make(chan struct{}) + agentCancelledOnce := sync.Once{} + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + <-ctx.Done() + agentCancelledOnce.Do(func() { + close(agentCancelled) + }) + return &AgentOutput{}, nil + }, + } - interruptTool := &turnLoopInterruptTool{name: "approval_tool"} + genInputCalls := int32(0) + secondGenInputCalled := make(chan struct{}) + secondGenInputOnce := sync.Once{} - testModel := &turnLoopTestModel{ - messages: []*schema.Message{ - schema.AssistantMessage("", []schema.ToolCall{ - {ID: "call1", Function: schema.FunctionCall{Name: "approval_tool", Arguments: "{}"}}, - }), - schema.AssistantMessage("task completed", nil), + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + count := atomic.AddInt32(&genInputCalls, 1) + if count >= 2 { + secondGenInputOnce.Do(func() { + close(secondGenInputCalled) + }) + } + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil + }, + }) + + loop.Push("first") + + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") } - agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ - Name: "test-agent", - Description: "test agent", - Model: testModel, - ToolsConfig: ToolsConfig{ - ToolsNodeConfig: compose.ToolsNodeConfig{ - Tools: []tool.BaseTool{interruptTool}, - }, + // Strategy inspects TurnContext during a running turn and decides to preempt. + var strategyCalled int32 + var strategyTC *TurnContext[string] + loop.Push("urgent", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + atomic.AddInt32(&strategyCalled, 1) + strategyTC = tc + return []PushOption[string]{WithPreempt[string]()} + })) + + select { + case <-agentCancelled: + case <-time.After(1 * time.Second): + t.Fatal("agent was not cancelled by strategy-returned preempt") + } + + select { + case <-secondGenInputCalled: + case <-time.After(1 * time.Second): + t.Fatal("second GenInput was not called after preempt") + } + + loop.Stop() + loop.Wait() + + assert.Equal(t, int32(1), atomic.LoadInt32(&strategyCalled)) + assert.NotNil(t, strategyTC, "strategy should receive non-nil TurnContext during a turn") + assert.Equal(t, []string{"first"}, strategyTC.Consumed) +} + +func TestTurnLoop_PushStrategy_BetweenTurns(t *testing.T) { + // Push with strategy before Run() — TurnContext should be nil. + var strategyCalled int32 + var strategyTCWasNil bool + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + return &AgentOutput{}, nil }, - }) - require.NoError(t, err) + } - receiveCount := int32(0) - frontBlocked := make(chan struct{}) - source := &turnLoopFuncSource[string]{ - receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - cnt := atomic.AddInt32(&receiveCount, 1) - if cnt == 1 { - return ctx, "msg1", []ConsumeOption{WithConsumeCheckPointID(checkPointID)}, nil + agentDone := make(chan struct{}) + agentDoneOnce := sync.Once{} + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + Remaining: nil, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + _, ok := events.Next() + if !ok { + break + } } - return ctx, "", nil, ErrLoopExit + agentDoneOnce.Do(func() { + close(agentDone) + }) + return nil + }, + }) + + // Push with strategy — no turn is active yet, so tc should be nil. + loop.Push("item", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + atomic.AddInt32(&strategyCalled, 1) + strategyTCWasNil = (tc == nil) + return nil // plain push, no preempt + })) + + select { + case <-agentDone: + case <-time.After(2 * time.Second): + t.Fatal("agent did not complete") + } + + loop.Stop() + loop.Wait() + + assert.Equal(t, int32(1), atomic.LoadInt32(&strategyCalled)) + assert.True(t, strategyTCWasNil, "strategy should receive nil TurnContext between turns") +} + +func TestTurnLoop_PushStrategy_OverridesOtherOptions(t *testing.T) { + // Push with both WithPreempt and WithPushStrategy — only strategy's result applies. + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + return &AgentOutput{}, nil + }, + } + + agentDone := make(chan struct{}) + agentDoneOnce := sync.Once{} + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil }, - front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - <-frontBlocked - return ctx, "", nil, context.DeadlineExceeded + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + Remaining: nil, + }, nil }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + _, ok := events.Next() + if !ok { + break + } + } + agentDoneOnce.Do(func() { + close(agentDone) + }) + return nil + }, + }) + + // Strategy returns nil (no preempt), even though WithPreempt is also passed. + // The strategy should override — so the agent should NOT be preempted. + ok, ack := loop.Push("item", WithPreempt[string](), WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + return nil // no preempt + })) + assert.True(t, ok) + assert.Nil(t, ack, "ack should be nil since strategy returned no preempt") + + select { + case <-agentDone: + case <-time.After(2 * time.Second): + t.Fatal("agent did not complete normally") } - var interruptErr *TurnLoopInterruptError[string] - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil + loop.Stop() + loop.Wait() +} + +func TestTurnLoop_PushStrategy_NestedStrategyStripped(t *testing.T) { + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + return &AgentOutput{}, nil }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { + } + + agentDone := make(chan struct{}) + agentDoneOnce := sync.Once{} + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { return agent, nil }, - OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + Remaining: nil, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { for { - if _, ok := iter.Next(); !ok { + _, ok := events.Next() + if !ok { break } } + agentDoneOnce.Do(func() { + close(agentDone) + }) return nil }, - Store: store, }) - require.NoError(t, err) - err = loop.Run(context.Background()) - require.True(t, errors.As(err, &interruptErr), "expected TurnLoopInterruptError, got: %v", err) - assert.Equal(t, checkPointID, interruptErr.CheckpointID) - assert.Equal(t, "msg1", interruptErr.Item) - assert.Len(t, interruptErr.InterruptContexts, 1) - assert.Equal(t, "need approval", interruptErr.InterruptContexts[0].Info) + // Strategy returns another WithPushStrategy — the nested one should be stripped. + innerCalled := int32(0) + ok, ack := loop.Push("item", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + return []PushOption[string]{ + WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + atomic.AddInt32(&innerCalled, 1) + return []PushOption[string]{WithPreempt[string]()} + }), + } + })) + assert.True(t, ok) + assert.Nil(t, ack, "ack should be nil since nested strategy was stripped (no preempt)") - interruptID := interruptErr.InterruptContexts[0].ID - assert.NotEmpty(t, interruptID) - assert.True(t, len(store.data) > 0, "checkpoint should be stored") + select { + case <-agentDone: + case <-time.After(2 * time.Second): + t.Fatal("agent did not complete normally") + } - testModel.idx = 1 - receiveCount = 0 + loop.Stop() + loop.Wait() - resumeCtx := compose.ResumeWithData(context.Background(), interruptID, "user approved") - err = loop.Run(resumeCtx, WithTurnLoopResume(checkPointID, "msg1")) - assert.NoError(t, err) + assert.Equal(t, int32(0), atomic.LoadInt32(&innerCalled), "nested strategy should not be called") +} + +func TestTurnLoop_PushStrategy_ConsumedInspection(t *testing.T) { + // Strategy preempts only when current turn is processing "low-priority" items. + agentStarted := make(chan struct{}) + agentStartedOnce := sync.Once{} + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + <-ctx.Done() + return &AgentOutput{}, nil + }, + } + + genInputCalls := int32(0) + secondGenInputItems := make(chan []string, 1) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + count := atomic.AddInt32(&genInputCalls, 1) + if count >= 2 { + select { + case secondGenInputItems <- append([]string{}, items...): + default: + } + } + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil + }, + }) + + loop.Push("low-priority-task") + + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } + + // Strategy checks Consumed and preempts because current turn has "low-priority" items. + loop.Push("urgent-task", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + if tc != nil && len(tc.Consumed) > 0 && tc.Consumed[0] == "low-priority-task" { + return []PushOption[string]{WithPreempt[string]()} + } + return nil + })) + + select { + case items := <-secondGenInputItems: + assert.Contains(t, items, "urgent-task") + case <-time.After(2 * time.Second): + t.Fatal("second GenInput was not called after strategy-driven preempt") + } + + loop.Stop() + loop.Wait() } diff --git a/adk/utils.go b/adk/utils.go index 62ca8d2c6..11d3d7eeb 100644 --- a/adk/utils.go +++ b/adk/utils.go @@ -135,6 +135,7 @@ func getMessageFromWrappedEvent(e *agentEventWrapper) (Message, error) { // consumeStream drains the message stream, setting concatenatedMessage on // success or StreamErr on failure. The stream is always replaced with an // error-free, materialized version safe for gob encoding. +// Must be called at most once (guarded by callers checking concatenatedMessage/StreamErr). func (e *agentEventWrapper) consumeStream() { e.mu.Lock() defer e.mu.Unlock() @@ -154,10 +155,6 @@ func (e *agentEventWrapper) consumeStream() { break } e.StreamErr = err - // Replace the stream with successfully received messages only (no error at the end). - // The error is preserved in StreamErr for users to check. - // We intentionally exclude the error from the new stream to ensure gob encoding - // compatibility, as the stream may be consumed during serialization. e.AgentEvent.Output.MessageOutput.MessageStream = schema.StreamReaderFromArray(msgs) return } diff --git a/adk/workflow.go b/adk/workflow.go index 9d63d7347..00411e33b 100644 --- a/adk/workflow.go +++ b/adk/workflow.go @@ -175,7 +175,6 @@ func (a *workflowAgent) runSequential(ctx context.Context, startIdx := 0 - // seqCtx tracks the accumulated RunPath across the sequence. seqCtx := ctx // If we are resuming, find which sub-agent to start from and prepare its context. @@ -193,12 +192,28 @@ func (a *workflowAgent) runSequential(ctx context.Context, for i := startIdx; i < len(a.subAgents); i++ { subAgent := a.subAgents[i] + // Cancel check at transition boundary between sub-agents. + // Transition boundaries are always safe to cancel at — no sub-agent + // work is in progress, so any cancel mode is honoured. + if cancelCtx := getCancelContext(ctx); cancelCtx != nil && cancelCtx.shouldCancel() { + state := &sequentialWorkflowState{InterruptIndex: i} + event := cancelAtTransition(ctx, "Sequential workflow cancel at transition", state) + generator.Send(event) + return nil + } + var subIterator *AsyncIterator[*AgentEvent] if seqState != nil { - subIterator = subAgent.Resume(seqCtx, &ResumeInfo{ - EnableStreaming: info.EnableStreaming, - InterruptInfo: info.Data.(*WorkflowInterruptInfo).SequentialInterruptInfo, - }, opts...) + wfInfo, _ := info.Data.(*WorkflowInterruptInfo) + if wfInfo != nil && wfInfo.SequentialInterruptInfo != nil { + // Sub-agent was interrupted — resume it. + subIterator = subAgent.Resume(seqCtx, &ResumeInfo{ + EnableStreaming: info.EnableStreaming, + InterruptInfo: wfInfo.SequentialInterruptInfo, + }, opts...) + } else { + subIterator = subAgent.Run(seqCtx, nil, opts...) + } seqState = nil } else { subIterator = subAgent.Run(seqCtx, nil, opts...) @@ -304,7 +319,6 @@ func (a *workflowAgent) runLoop(ctx context.Context, generator *AsyncGenerator[* startIter := 0 startIdx := 0 - // loopCtx tracks the accumulated RunPath across the full sequence within a single iteration. loopCtx := ctx if loopState != nil { @@ -329,13 +343,25 @@ func (a *workflowAgent) runLoop(ctx context.Context, generator *AsyncGenerator[* for j := startIdx; j < len(a.subAgents); j++ { subAgent := a.subAgents[j] + if cancelCtx := getCancelContext(ctx); cancelCtx != nil && cancelCtx.shouldCancel() { + state := &loopWorkflowState{LoopIterations: i, SubAgentIndex: j} + event := cancelAtTransition(ctx, "Loop workflow cancel at transition", state) + generator.Send(event) + return nil + } + var subIterator *AsyncIterator[*AgentEvent] if loopState != nil { - // This is the agent we need to resume. - subIterator = subAgent.Resume(loopCtx, &ResumeInfo{ - EnableStreaming: resumeInfo.EnableStreaming, - InterruptInfo: resumeInfo.Data.(*WorkflowInterruptInfo).SequentialInterruptInfo, - }, opts...) + wfInfo, _ := resumeInfo.Data.(*WorkflowInterruptInfo) + if wfInfo != nil && wfInfo.SequentialInterruptInfo != nil { + // Sub-agent was interrupted — resume it. + subIterator = subAgent.Resume(loopCtx, &ResumeInfo{ + EnableStreaming: resumeInfo.EnableStreaming, + InterruptInfo: wfInfo.SequentialInterruptInfo, + }, opts...) + } else { + subIterator = subAgent.Run(loopCtx, nil, opts...) + } loopState = nil // Only resume the first time. } else { subIterator = subAgent.Run(loopCtx, nil, opts...) @@ -468,6 +494,15 @@ func (a *workflowAgent) runParallel(ctx context.Context, generator *AsyncGenerat } } + // Cancel check before spawning parallel goroutines. No sub-agent work + // is in progress, so any cancel mode is honoured at this boundary. + if cancelCtx := getCancelContext(ctx); cancelCtx != nil && cancelCtx.shouldCancel() { + state := ¶llelWorkflowState{} + event := cancelAtTransition(ctx, "Parallel workflow cancel before spawn", state) + generator.Send(event) + return nil + } + for i := range a.subAgents { wg.Add(1) go func(idx int, agent *flowAgent) { @@ -483,11 +518,13 @@ func (a *workflowAgent) runParallel(ctx context.Context, generator *AsyncGenerat var iterator *AsyncIterator[*AgentEvent] if _, ok := agentNames[agent.Name(ctx)]; ok { - // This branch was interrupted and needs to be resumed. - iterator = agent.Resume(childContexts[idx], &ResumeInfo{ + childResumeInfo := &ResumeInfo{ EnableStreaming: resumeInfo.EnableStreaming, - InterruptInfo: resumeInfo.Data.(*WorkflowInterruptInfo).ParallelInterruptInfo[idx], - }, opts...) + } + if wfInfo, ok := resumeInfo.Data.(*WorkflowInterruptInfo); ok && wfInfo != nil { + childResumeInfo.InterruptInfo = wfInfo.ParallelInterruptInfo[idx] + } + iterator = agent.Resume(childContexts[idx], childResumeInfo, opts...) } else if parState != nil { // We are resuming, but this child is not in the next points map. // This means it finished successfully, so we don't run it. @@ -550,6 +587,27 @@ func (a *workflowAgent) runParallel(ctx context.Context, generator *AsyncGenerat return nil } +func cancelAtTransition(ctx context.Context, info string, state any) *AgentEvent { + // state is the workflow checkpoint state (e.g. sequentialWorkflowState); + // nil for subContexts because this is a leaf interrupt with no child signals. + is, err := core.Interrupt(ctx, info, state, nil, + core.WithLayerPayload(getRunCtx(ctx).RunPath)) + if err != nil { + return &AgentEvent{Err: err} + } + + contexts := core.ToInterruptContexts(is, allowedAddressSegmentTypes) + + return &AgentEvent{ + Action: &AgentAction{ + Interrupted: &InterruptInfo{ + InterruptContexts: contexts, + }, + internalInterrupted: is, + }, + } +} + type SequentialAgentConfig struct { Name string Description string diff --git a/adk/wrappers.go b/adk/wrappers.go index 5061f5be8..eb23549fd 100644 --- a/adk/wrappers.go +++ b/adk/wrappers.go @@ -34,10 +34,11 @@ type generateEndpoint func(ctx context.Context, input []*schema.Message, opts .. type streamEndpoint func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) type modelWrapperConfig struct { - handlers []ChatModelAgentMiddleware - middlewares []AgentMiddleware - retryConfig *ModelRetryConfig - toolInfos []*schema.ToolInfo + handlers []ChatModelAgentMiddleware + middlewares []AgentMiddleware + retryConfig *ModelRetryConfig + toolInfos []*schema.ToolInfo + cancelContext *cancelContext } func buildModelWrappers(m model.BaseChatModel, config *modelWrapperConfig) model.BaseChatModel { @@ -54,6 +55,7 @@ func buildModelWrappers(m model.BaseChatModel, config *modelWrapperConfig) model middlewares: config.middlewares, toolInfos: config.toolInfos, modelRetryConfig: config.retryConfig, + cancelContext: config.cancelContext, } return wrapped @@ -252,11 +254,18 @@ func NewEventSenderModelWrapper() ChatModelAgentMiddleware { } func (w *eventSenderModelWrapper) WrapModel(_ context.Context, m model.BaseChatModel, mc *ModelContext) (model.BaseChatModel, error) { + inner := m + if mc != nil && mc.cancelContext != nil { + inner = &cancelMonitoredModel{ + inner: inner, + cancelContext: mc.cancelContext, + } + } var retryConfig *ModelRetryConfig if mc != nil { retryConfig = mc.ModelRetryConfig } - return &eventSenderModel{inner: m, modelRetryConfig: retryConfig}, nil + return &eventSenderModel{inner: inner, modelRetryConfig: retryConfig}, nil } type eventSenderModel struct { @@ -490,6 +499,7 @@ type stateModelWrapper struct { middlewares []AgentMiddleware toolInfos []*schema.ToolInfo modelRetryConfig *ModelRetryConfig + cancelContext *cancelContext } func (w *stateModelWrapper) IsCallbacksEnabled() bool { @@ -515,6 +525,7 @@ func (w *stateModelWrapper) hasUserEventSender() bool { func (w *stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) generateEndpoint { hasUserEventSender := w.hasUserEventSender() retryConfig := w.modelRetryConfig + cc := w.cancelContext for i := len(w.handlers) - 1; i >= 0; i-- { handler := w.handlers[i] @@ -523,7 +534,7 @@ func (w *stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) gene endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { baseOpts := &model.Options{Tools: baseToolInfos} commonOpts := model.GetCommonOptions(baseOpts, opts...) - mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: retryConfig} + mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: retryConfig, cancelContext: cc} wrappedModel, err := handler.WrapModel(ctx, &endpointModel{generate: innerEndpoint}, mc) if err != nil { return nil, err @@ -540,7 +551,7 @@ func (w *stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) gene if execCtx == nil || execCtx.generator == nil { return innerEndpoint(ctx, input, opts...) } - mc := &ModelContext{ModelRetryConfig: retryConfig} + mc := &ModelContext{ModelRetryConfig: retryConfig, cancelContext: cc} wrappedModel, err := eventSender.WrapModel(ctx, &endpointModel{generate: innerEndpoint}, mc) if err != nil { return nil, err @@ -563,6 +574,7 @@ func (w *stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) gene func (w *stateModelWrapper) wrapStreamEndpoint(endpoint streamEndpoint) streamEndpoint { hasUserEventSender := w.hasUserEventSender() retryConfig := w.modelRetryConfig + cc := w.cancelContext for i := len(w.handlers) - 1; i >= 0; i-- { handler := w.handlers[i] @@ -571,7 +583,7 @@ func (w *stateModelWrapper) wrapStreamEndpoint(endpoint streamEndpoint) streamEn endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { baseOpts := &model.Options{Tools: baseToolInfos} commonOpts := model.GetCommonOptions(baseOpts, opts...) - mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: retryConfig} + mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: retryConfig, cancelContext: cc} wrappedModel, err := handler.WrapModel(ctx, &endpointModel{stream: innerEndpoint}, mc) if err != nil { return nil, err @@ -588,7 +600,7 @@ func (w *stateModelWrapper) wrapStreamEndpoint(endpoint streamEndpoint) streamEn if execCtx == nil || execCtx.generator == nil { return innerEndpoint(ctx, input, opts...) } - mc := &ModelContext{ModelRetryConfig: retryConfig} + mc := &ModelContext{ModelRetryConfig: retryConfig, cancelContext: cc} wrappedModel, err := eventSender.WrapModel(ctx, &endpointModel{stream: innerEndpoint}, mc) if err != nil { return nil, err @@ -615,7 +627,7 @@ func (w *stateModelWrapper) Generate(ctx context.Context, input []*schema.Messag return nil }) - state := &ChatModelAgentState{Messages: append(stateMessages, input...)} + state := &ChatModelAgentState{Messages: stateMessages} for _, m := range w.middlewares { if m.BeforeChatModel != nil { @@ -627,7 +639,7 @@ func (w *stateModelWrapper) Generate(ctx context.Context, input []*schema.Messag baseOpts := &model.Options{Tools: w.toolInfos} commonOpts := model.GetCommonOptions(baseOpts, opts...) - mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: w.modelRetryConfig} + mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: w.modelRetryConfig, cancelContext: w.cancelContext} for _, handler := range w.handlers { var err error ctx, state, err = handler.BeforeModelRewriteState(ctx, state, mc) @@ -681,7 +693,7 @@ func (w *stateModelWrapper) Stream(ctx context.Context, input []*schema.Message, return nil }) - state := &ChatModelAgentState{Messages: append(stateMessages, input...)} + state := &ChatModelAgentState{Messages: stateMessages} for _, m := range w.middlewares { if m.BeforeChatModel != nil { @@ -693,7 +705,7 @@ func (w *stateModelWrapper) Stream(ctx context.Context, input []*schema.Message, baseOpts := &model.Options{Tools: w.toolInfos} commonOpts := model.GetCommonOptions(baseOpts, opts...) - mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: w.modelRetryConfig} + mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: w.modelRetryConfig, cancelContext: w.cancelContext} for _, handler := range w.handlers { var err error ctx, state, err = handler.BeforeModelRewriteState(ctx, state, mc) diff --git a/compose/checkpoint_test.go b/compose/checkpoint_test.go index c24b6ce6f..a86c02fb3 100644 --- a/compose/checkpoint_test.go +++ b/compose/checkpoint_test.go @@ -1383,6 +1383,7 @@ func TestCancelInterrupt(t *testing.T) { info, success := ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.AfterNodes) + assert.True(t, info.FromGraphInterrupt) result, err := r.Invoke(ctx, "input", WithCheckPointID("1")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1397,6 +1398,7 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.AfterNodes) + assert.True(t, info.FromGraphInterrupt) result, err = r.Invoke(ctx, "input", WithCheckPointID("2")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1412,6 +1414,7 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.RerunNodes) + assert.True(t, info.FromGraphInterrupt) result, err = r.Invoke(ctx, "input", WithCheckPointID("3")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1441,6 +1444,7 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.AfterNodes) + assert.True(t, info.FromGraphInterrupt) result, err = r.Invoke(ctx, "input", WithCheckPointID("1")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1455,6 +1459,7 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.AfterNodes) + assert.True(t, info.FromGraphInterrupt) result, err = r.Invoke(ctx, "input", WithCheckPointID("2")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1470,6 +1475,7 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.RerunNodes) + assert.True(t, info.FromGraphInterrupt) result, err = r.Invoke(ctx, "input", WithCheckPointID("3")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1510,6 +1516,7 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, 2, len(info.AfterNodes)) + assert.True(t, info.FromGraphInterrupt) result2, err := rr.Invoke(ctx, "input", WithCheckPointID("1")) assert.NoError(t, err) assert.Equal(t, map[string]any{ @@ -1528,6 +1535,7 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, 2, len(info.RerunNodes)) + assert.True(t, info.FromGraphInterrupt) result2, err = rr.Invoke(ctx, "input", WithCheckPointID("2")) assert.NoError(t, err) assert.Equal(t, map[string]any{ @@ -1536,6 +1544,26 @@ func TestCancelInterrupt(t *testing.T) { }, result2) } +func TestBusinessInterruptFromGraphInterruptFalse(t *testing.T) { + g := NewGraph[string, string]() + _ = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return "", Interrupt(ctx, "biz") + })) + _ = g.AddEdge(START, "1") + _ = g.AddEdge("1", END) + + ctx := context.Background() + r, err := g.Compile(ctx, WithCheckPointStore(newInMemoryStore())) + assert.NoError(t, err) + + _, err = r.Invoke(ctx, "input", WithCheckPointID("biz")) + assert.Error(t, err) + info, existed := ExtractInterruptInfo(err) + assert.True(t, existed) + assert.False(t, info.FromGraphInterrupt) + assert.Equal(t, []string{"1"}, info.RerunNodes) +} + func TestPersistRerunInputNonStream(t *testing.T) { store := newInMemoryStore() diff --git a/compose/graph_manager.go b/compose/graph_manager.go index 944a0cf0a..46df3488e 100644 --- a/compose/graph_manager.go +++ b/compose/graph_manager.go @@ -496,12 +496,15 @@ func receiveWithListening(recv func() (*task, bool), cancel chan *time.Duration) return p.ta, p.closed, false, false, nil case timeout, ok := <-cancel: if !ok { - // unreachable - break + // The cancel channel has been closed — this means a previous call to + // receiveWithListening already consumed the cancel signal (task completed + // at the same time as cancel, and select picked the task result). Since + // cancel was already issued, treat this as an immediate cancel rather than + // blocking forever on resultCh. + return nil, false, true, true, nil } canceled = true if timeout == nil { - // canceled without timeout break } timeoutCh = time.After(*timeout) diff --git a/compose/graph_run.go b/compose/graph_run.go index a3e81ecf1..770cf16de 100644 --- a/compose/graph_run.go +++ b/compose/graph_run.go @@ -434,6 +434,7 @@ type interruptTempInfo struct { interruptBeforeNodes []string interruptAfterNodes []string interruptRerunExtra map[string]any + fromGraphInterrupt bool signals []*core.InterruptSignal } @@ -442,6 +443,7 @@ func (ti *interruptTempInfo) collectCanceledInfo(canceled bool, canceledTasks, c if !canceled { return } + ti.fromGraphInterrupt = true if len(canceledTasks) > 0 { for _, t := range canceledTasks { ti.interruptRerunNodes = append(ti.interruptRerunNodes, t.nodeKey) @@ -459,6 +461,13 @@ func (r *runner) resolveInterruptCompletedTasks(tempInfo *interruptTempInfo, com if info := isSubGraphInterrupt(completedTask.err); info != nil { tempInfo.subGraphInterrupts[completedTask.nodeKey] = info tempInfo.signals = append(tempInfo.signals, info.signal) + // Propagate FromGraphInterrupt from the sub-graph to the parent. + // The sub-graph's task manager may have consumed the cancel + // channel value before the parent's, so only the sub-graph + // knows the interrupt was triggered by a graph-level cancel. + if info.Info != nil && info.Info.FromGraphInterrupt { + tempInfo.fromGraphInterrupt = true + } continue } @@ -515,27 +524,27 @@ func (r *runner) handleInterrupt( if r.runCtx != nil { // current graph has enable state if state, ok := ctx.Value(stateKey{}).(*internalState); ok { - cp.State = state.state + state.mu.Lock() + copiedState, err := deepCopyState(state.state) + state.mu.Unlock() + if err != nil { + return fmt.Errorf("failed to copy state: %w", err) + } + cp.State = copiedState } } intInfo := &InterruptInfo{ - State: cp.State, - AfterNodes: tempInfo.interruptAfterNodes, - BeforeNodes: tempInfo.interruptBeforeNodes, - RerunNodes: tempInfo.interruptRerunNodes, - RerunNodesExtra: tempInfo.interruptRerunExtra, - SubGraphs: make(map[string]*InterruptInfo), + State: cp.State, + AfterNodes: tempInfo.interruptAfterNodes, + BeforeNodes: tempInfo.interruptBeforeNodes, + RerunNodes: tempInfo.interruptRerunNodes, + RerunNodesExtra: tempInfo.interruptRerunExtra, + SubGraphs: make(map[string]*InterruptInfo), + FromGraphInterrupt: tempInfo.fromGraphInterrupt, } - var info any - if cp.State != nil { - copiedState, err := deepCopyState(cp.State) - if err != nil { - return fmt.Errorf("failed to copy state: %w", err) - } - info = copiedState - } + info := cp.State is, err := core.Interrupt(ctx, info, nil, tempInfo.signals) if err != nil { @@ -581,15 +590,18 @@ func deepCopyState(state any) (any, error) { // Create new instance of the same type stateType := reflect.TypeOf(state) - if stateType.Kind() == reflect.Ptr { + isPtr := stateType.Kind() == reflect.Ptr + if isPtr { stateType = stateType.Elem() } - newState := reflect.New(stateType).Interface() - - if err := serializer.Unmarshal(data, newState); err != nil { + newStatePtr := reflect.New(stateType).Interface() + if err := serializer.Unmarshal(data, newStatePtr); err != nil { return nil, fmt.Errorf("failed to unmarshal state: %w", err) } - return newState, nil + if isPtr { + return newStatePtr, nil + } + return reflect.ValueOf(newStatePtr).Elem().Interface(), nil } func (r *runner) handleInterruptWithSubGraphAndRerunNodes( @@ -645,27 +657,27 @@ func (r *runner) handleInterruptWithSubGraphAndRerunNodes( if r.runCtx != nil { // current graph has enable state if state, ok := ctx.Value(stateKey{}).(*internalState); ok { - cp.State = state.state + state.mu.Lock() + copiedState, err_ := deepCopyState(state.state) + state.mu.Unlock() + if err_ != nil { + return fmt.Errorf("failed to copy state: %w", err_) + } + cp.State = copiedState } } intInfo := &InterruptInfo{ - State: cp.State, - BeforeNodes: tempInfo.interruptBeforeNodes, - AfterNodes: tempInfo.interruptAfterNodes, - RerunNodes: tempInfo.interruptRerunNodes, - RerunNodesExtra: tempInfo.interruptRerunExtra, - SubGraphs: make(map[string]*InterruptInfo), + State: cp.State, + BeforeNodes: tempInfo.interruptBeforeNodes, + AfterNodes: tempInfo.interruptAfterNodes, + RerunNodes: tempInfo.interruptRerunNodes, + RerunNodesExtra: tempInfo.interruptRerunExtra, + SubGraphs: make(map[string]*InterruptInfo), + FromGraphInterrupt: tempInfo.fromGraphInterrupt, } - var info any - if cp.State != nil { - copiedState, err_ := deepCopyState(cp.State) - if err_ != nil { - return fmt.Errorf("failed to copy state: %w", err_) - } - info = copiedState - } + info := cp.State is, err := core.Interrupt(ctx, info, nil, tempInfo.signals) if err != nil { diff --git a/compose/interrupt.go b/compose/interrupt.go index 98a5eeecc..cd423a1d6 100644 --- a/compose/interrupt.go +++ b/compose/interrupt.go @@ -263,6 +263,10 @@ type InterruptInfo struct { RerunNodesExtra map[string]any SubGraphs map[string]*InterruptInfo InterruptContexts []*InterruptCtx + // FromGraphInterrupt indicates whether the interrupt was triggered by a graph-level + // cancel operation (e.g., via WithGraphInterrupt) rather than business logic. + // When true, the interrupt originated from an external cancellation request. + FromGraphInterrupt bool } func init() { diff --git a/examples b/examples new file mode 160000 index 000000000..4afd5a3f2 --- /dev/null +++ b/examples @@ -0,0 +1 @@ +Subproject commit 4afd5a3f26a4db4833088505b9f7a0f631e9f231 diff --git a/ext b/ext new file mode 160000 index 000000000..f061db7e8 --- /dev/null +++ b/ext @@ -0,0 +1 @@ +Subproject commit f061db7e84191705db6c48f0085938de84f90742 diff --git a/internal/channel.go b/internal/channel.go index 2351c87e9..8e36d8939 100644 --- a/internal/channel.go +++ b/internal/channel.go @@ -46,6 +46,21 @@ func (ch *UnboundedChan[T]) Send(value T) { ch.notEmpty.Signal() // Wake up one goroutine waiting to receive } +// TrySend attempts to put an item into the channel. +// Returns false if the channel is closed, true otherwise. +func (ch *UnboundedChan[T]) TrySend(value T) bool { + ch.mutex.Lock() + defer ch.mutex.Unlock() + + if ch.closed { + return false + } + + ch.buffer = append(ch.buffer, value) + ch.notEmpty.Signal() + return true +} + // Receive gets an item from the channel (blocks if empty) func (ch *UnboundedChan[T]) Receive() (T, bool) { ch.mutex.Lock() @@ -76,3 +91,33 @@ func (ch *UnboundedChan[T]) Close() { ch.notEmpty.Broadcast() // Wake up all waiting goroutines } } + +// TakeAll removes and returns all values from the channel atomically. +// Returns nil if the channel is empty. +func (ch *UnboundedChan[T]) TakeAll() []T { + ch.mutex.Lock() + defer ch.mutex.Unlock() + + if len(ch.buffer) == 0 { + return nil + } + + values := ch.buffer + ch.buffer = nil + return values +} + +// PushFront adds values to the front of the channel. +// This is useful for recovering values that need to be reprocessed. +// Does nothing if values is empty. +func (ch *UnboundedChan[T]) PushFront(values []T) { + if len(values) == 0 { + return + } + + ch.mutex.Lock() + defer ch.mutex.Unlock() + + ch.buffer = append(append([]T{}, values...), ch.buffer...) + ch.notEmpty.Signal() +} diff --git a/internal/channel_test.go b/internal/channel_test.go index 736a27413..bed2383f1 100644 --- a/internal/channel_test.go +++ b/internal/channel_test.go @@ -219,3 +219,244 @@ func TestUnboundedChan_BlockingReceive(t *testing.T) { t.Error("Receive should have unblocked") } } + +func TestUnboundedChan_TakeAll(t *testing.T) { + ch := NewUnboundedChan[int]() + + // Test TakeAll on empty channel + items := ch.TakeAll() + if items != nil { + t.Errorf("TakeAll on empty channel should return nil, got %v", items) + } + + // Send some values + ch.Send(1) + ch.Send(2) + ch.Send(3) + + // Test TakeAll returns all values + items = ch.TakeAll() + if len(items) != 3 { + t.Errorf("expected 3 values, got %d", len(items)) + } + if items[0] != 1 || items[1] != 2 || items[2] != 3 { + t.Errorf("unexpected values: %v", items) + } + + // Channel should be empty now + if len(ch.buffer) != 0 { + t.Errorf("channel should be empty after TakeAll, got %d values", len(ch.buffer)) + } + + // TakeAll again should return nil + items = ch.TakeAll() + if items != nil { + t.Errorf("TakeAll on empty channel should return nil, got %v", items) + } +} + +func TestUnboundedChan_TakeAll_Partial(t *testing.T) { + ch := NewUnboundedChan[int]() + + // Send values + ch.Send(1) + ch.Send(2) + ch.Send(3) + + // Receive one + val, ok := ch.Receive() + if !ok || val != 1 { + t.Errorf("expected (1, true), got (%d, %v)", val, ok) + } + + // TakeAll should return remaining values + items := ch.TakeAll() + if len(items) != 2 { + t.Errorf("expected 2 values, got %d", len(items)) + } + if items[0] != 2 || items[1] != 3 { + t.Errorf("unexpected values: %v", items) + } +} + +func TestUnboundedChan_PushFront(t *testing.T) { + ch := NewUnboundedChan[int]() + + // Test PushFront with empty values (should do nothing) + ch.PushFront(nil) + ch.PushFront([]int{}) + if len(ch.buffer) != 0 { + t.Errorf("PushFront with empty values should not add anything, got %d values", len(ch.buffer)) + } + + // Send some values + ch.Send(3) + ch.Send(4) + + // PushFront should prepend values + ch.PushFront([]int{1, 2}) + + if len(ch.buffer) != 4 { + t.Errorf("expected 4 values, got %d", len(ch.buffer)) + } + if ch.buffer[0] != 1 || ch.buffer[1] != 2 || ch.buffer[2] != 3 || ch.buffer[3] != 4 { + t.Errorf("unexpected buffer: %v", ch.buffer) + } + + // Receive should return in correct order + val, _ := ch.Receive() + if val != 1 { + t.Errorf("expected 1, got %d", val) + } + val, _ = ch.Receive() + if val != 2 { + t.Errorf("expected 2, got %d", val) + } +} + +func TestUnboundedChan_PushFront_EmptyChannel(t *testing.T) { + ch := NewUnboundedChan[int]() + + // PushFront to empty channel + ch.PushFront([]int{1, 2, 3}) + + if len(ch.buffer) != 3 { + t.Errorf("expected 3 values, got %d", len(ch.buffer)) + } + + // Receive should work + val, ok := ch.Receive() + if !ok || val != 1 { + t.Errorf("expected (1, true), got (%d, %v)", val, ok) + } +} + +func TestUnboundedChan_PushFront_UnblocksReceive(t *testing.T) { + ch := NewUnboundedChan[int]() + + // Start a blocking receive + receiveDone := make(chan int) + go func() { + val, _ := ch.Receive() + receiveDone <- val + }() + + // Ensure receive is blocked + select { + case <-receiveDone: + t.Error("Receive should block on empty channel") + case <-time.After(50 * time.Millisecond): + // This is expected + } + + // PushFront should unblock the receive + ch.PushFront([]int{42}) + + select { + case val := <-receiveDone: + if val != 42 { + t.Errorf("expected 42, got %d", val) + } + case <-time.After(50 * time.Millisecond): + t.Error("Receive should have unblocked after PushFront") + } +} + +func TestUnboundedChan_PushFront_SpareCapacity(t *testing.T) { + ch := NewUnboundedChan[int]() + + // Pre-fill the channel so PushFront has something to append + ch.Send(10) + ch.Send(20) + + // Create a slice with spare capacity: len=2, cap=10. + // Elements beyond len (index 2-9) must not be corrupted by PushFront. + src := make([]int, 3, 10) + src[0] = 1 + src[1] = 2 + src[2] = 3 // sentinel — must survive PushFront(src[:2]) + + ch.PushFront(src[:2]) + + // Verify the sentinel was NOT overwritten by the channel's existing buffer + if src[2] != 3 { + t.Errorf("PushFront corrupted caller's backing array: src[2] = %d, want 3", src[2]) + } + + // Verify channel drains correctly: [1, 2, 10, 20] + expected := []int{1, 2, 10, 20} + for i, want := range expected { + got, ok := ch.Receive() + if !ok { + t.Fatalf("Receive returned ok=false at index %d", i) + } + if got != want { + t.Errorf("index %d: got %d, want %d", i, got, want) + } + } +} + +func TestUnboundedChan_TakeAll_PushFront_Concurrent(t *testing.T) { + ch := NewUnboundedChan[int]() + const numOps = 100 + + var wg sync.WaitGroup + wg.Add(3) + + // Goroutine 1: Send values + go func() { + defer wg.Done() + for i := 0; i < numOps; i++ { + ch.Send(i) + time.Sleep(time.Microsecond) + } + }() + + // Goroutine 2: TakeAll periodically + takeAllResults := make([][]int, 0) + var mu sync.Mutex + go func() { + defer wg.Done() + for i := 0; i < numOps/10; i++ { + items := ch.TakeAll() + if items != nil { + mu.Lock() + takeAllResults = append(takeAllResults, items) + mu.Unlock() + } + time.Sleep(10 * time.Microsecond) + } + }() + + // Goroutine 3: PushFront periodically + go func() { + defer wg.Done() + for i := 0; i < numOps/10; i++ { + ch.PushFront([]int{-i}) + time.Sleep(10 * time.Microsecond) + } + }() + + wg.Wait() + ch.Close() + + // Drain remaining values + remaining := ch.TakeAll() + if remaining != nil { + mu.Lock() + takeAllResults = append(takeAllResults, remaining) + mu.Unlock() + } + + // Count total values collected + total := 0 + for _, batch := range takeAllResults { + total += len(batch) + } + + // We should have exactly numOps (from Send) + numOps/10 (from PushFront) values + expected := numOps + numOps/10 + if total != expected { + t.Errorf("expected %d values, got %d", expected, total) + } +} diff --git a/internal/core/address.go b/internal/core/address.go index 8efabf943..c1ecb17fb 100644 --- a/internal/core/address.go +++ b/internal/core/address.go @@ -88,7 +88,7 @@ type addrCtx struct { type globalResumeInfoKey struct{} type globalResumeInfo struct { - mu sync.Mutex + mu sync.RWMutex id2ResumeData map[string]any id2ResumeDataUsed map[string]bool id2State map[string]InterruptState @@ -147,24 +147,21 @@ func AppendAddressSegment(ctx context.Context, segType AddressSegmentType, segID return context.WithValue(ctx, addrCtxKey{}, runCtx) } + rInfo.mu.Lock() + defer rInfo.mu.Unlock() + var id string for id_, addr := range rInfo.id2Addr { if addr.Equals(currentAddress) { - rInfo.mu.Lock() if used, ok := rInfo.id2StateUsed[id_]; !ok || !used { runCtx.interruptState = generic.PtrOf(rInfo.id2State[id_]) rInfo.id2StateUsed[id_] = true id = id_ - rInfo.mu.Unlock() break } - rInfo.mu.Unlock() } } - // take from globalResumeInfo the data for the new address if there is any - rInfo.mu.Lock() - defer rInfo.mu.Unlock() used := rInfo.id2ResumeDataUsed[id] if !used { rData, existed := rInfo.id2ResumeData[id] @@ -175,10 +172,6 @@ func AppendAddressSegment(ctx context.Context, segType AddressSegmentType, segID } } - // Also mark as resume target if any descendant address is a resume target. - // This allows composite components (e.g., a tool containing a nested graph) to know - // they should execute their children to reach the actual resume target. - // We only consider descendants whose resume data has not yet been consumed. if !runCtx.isResumeTarget { for id_, addr := range rInfo.id2Addr { if len(addr) > len(currentAddress) && addr[:len(currentAddress)].Equals(currentAddress) { @@ -202,6 +195,9 @@ func GetNextResumptionPoints(ctx context.Context) (map[string]bool, error) { return nil, fmt.Errorf("GetNextResumptionPoints: failed to get resume info from context") } + rInfo.mu.RLock() + defer rInfo.mu.RUnlock() + nextPoints := make(map[string]bool) parentAddrLen := len(parentAddr) @@ -276,6 +272,9 @@ func PopulateInterruptState(ctx context.Context, id2Addr map[string]Address, id2State map[string]InterruptState) context.Context { rInfo, ok := ctx.Value(globalResumeInfoKey{}).(*globalResumeInfo) if ok { + rInfo.mu.Lock() + defer rInfo.mu.Unlock() + if rInfo.id2Addr == nil { rInfo.id2Addr = make(map[string]Address) } @@ -299,17 +298,13 @@ func PopulateInterruptState(ctx context.Context, id2Addr map[string]Address, if addr.Equals(runCtx.addr) { if used, ok := rInfo.id2StateUsed[id_]; !ok || !used { runCtx.interruptState = generic.PtrOf(rInfo.id2State[id_]) - rInfo.mu.Lock() rInfo.id2StateUsed[id_] = true - rInfo.mu.Unlock() } if used, ok := rInfo.id2ResumeDataUsed[id_]; !ok || !used { runCtx.isResumeTarget = true runCtx.resumeData = rInfo.id2ResumeData[id_] - rInfo.mu.Lock() rInfo.id2ResumeDataUsed[id_] = true - rInfo.mu.Unlock() } break diff --git a/internal/core/interrupt.go b/internal/core/interrupt.go index d7a934a3d..174e0c47c 100644 --- a/internal/core/interrupt.go +++ b/internal/core/interrupt.go @@ -29,6 +29,13 @@ type CheckPointStore interface { Set(ctx context.Context, checkPointID string, checkPoint []byte) error } +// CheckPointDeleter is an optional interface that CheckPointStore implementations +// can implement to support explicit checkpoint deletion. If the Store does not +// implement this interface, deletion is performed by writing an empty value via Set. +type CheckPointDeleter interface { + Delete(ctx context.Context, checkPointID string) error +} + type InterruptSignal struct { ID string Address diff --git a/schema/serialization.go b/schema/serialization.go index 7a719b0a8..22fa16ade 100644 --- a/schema/serialization.go +++ b/schema/serialization.go @@ -25,7 +25,7 @@ import ( ) func init() { - RegisterName[Message]("_eino_message") + RegisterName[*Message]("_eino_message") RegisterName[[]*Message]("_eino_message_slice") RegisterName[Document]("_eino_document") RegisterName[RoleType]("_eino_role_type") From 8fd61237b0e7dd861360174f29a3e414d8916ad6 Mon Sep 17 00:00:00 2001 From: shentongmartin Date: Fri, 27 Mar 2026 11:46:36 +0800 Subject: [PATCH 46/65] fix(adk): skip saving checkpoint when TurnLoop is idle (#916) * fix(adk): skip saving checkpoint when TurnLoop is idle When Stop() is called on an idle TurnLoop (no active agent run, no unhandled items, no canceled items), the resulting checkpoint contains no meaningful state. Skip saving such checkpoints to avoid unnecessary store writes. - Add isIdle check in cleanup() before checkpoint save decision - Add TestTurnLoop_StopWhileIdle_SkipsCheckpoint test Change-Id: I6aeaff5ed5833a971cb95298193fdb96d904baf8 * fix(internal): merge id2State in PopulateInterruptState instead of replacing PopulateInterruptState merged id2Addr entries one by one but replaced id2State wholesale. In a parallel workflow resume, two goroutines share the same globalResumeInfo. If one goroutine's compose graph called PopulateInterruptState (replacing id2State with compose-only entries) before the other goroutine looked up its outer-level entry, the lookup returned a zero-value InterruptState with State=nil, triggering the 'has no state' panic in ChatModelAgent.Resume. Change id2State handling to merge entry by entry, consistent with id2Addr. Change-Id: Ia21f65289bff7beb2bc383fb033926ad9c92d7e7 * fix(adk): keep watching for cancel escalation after stopSig.done When watchStopSignal entered the stopSig.done branch, it processed the initial cancel and then blocked on <-done (turn completion), never looping back to check notify. This meant a subsequent Stop() call with a higher cancel mode (e.g. CancelImmediate) was never forwarded to the agent, causing TestTurnLoop_Stop_EscalatesCancelMode to time out. Replace the blocking <-done with an inner loop that selects on both done and notify, so escalation signals are always delivered. Also apply the generation-based dedup check consistent with the notify branch. Change-Id: Ia6a04d00a2b44625ffbcb625ff0e559c12ed145f --- adk/turn_loop.go | 34 ++++++++++++++++++++++++++-------- adk/turn_loop_test.go | 28 ++++++++++++++++++++++++++++ internal/core/address.go | 7 ++++++- 3 files changed, 60 insertions(+), 9 deletions(-) diff --git a/adk/turn_loop.go b/adk/turn_loop.go index 88979876e..000842589 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -1228,14 +1228,31 @@ func (l *TurnLoop[T]) watchStopSignal(done <-chan struct{}, agentCancelFunc Agen } } case <-l.stopSig.done: - _, opts := l.stopSig.check() - _, contributed := agentCancelFunc(opts...) - if contributed && !stoppedClosed { - close(stoppedDone) - stoppedClosed = true + gen, opts := l.stopSig.check() + if gen != lastGen { + lastGen = gen + _, contributed := agentCancelFunc(opts...) + if contributed && !stoppedClosed { + close(stoppedDone) + stoppedClosed = true + } + } + for { + select { + case <-done: + return + case <-l.stopSig.notify: + gen, opts := l.stopSig.check() + if gen != lastGen { + lastGen = gen + _, contributed := agentCancelFunc(opts...) + if contributed && !stoppedClosed { + close(stoppedDone) + stoppedClosed = true + } + } + } } - <-done - return } } } @@ -1380,7 +1397,8 @@ func (l *TurnLoop[T]) cleanup(ctx context.Context) { unhandled := l.buffer.TakeAll() checkpointID := l.config.CheckpointID - shouldSaveCheckpoint := l.config.Store != nil && checkpointID != "" && l.stopSig.isStopped() + isIdle := len(l.checkPointRunnerBytes) == 0 && len(unhandled) == 0 && len(l.canceledItems) == 0 + shouldSaveCheckpoint := l.config.Store != nil && checkpointID != "" && l.stopSig.isStopped() && !isIdle if shouldSaveCheckpoint { cp := &turnLoopCheckpoint[T]{ RunnerCheckpoint: l.checkPointRunnerBytes, diff --git a/adk/turn_loop_test.go b/adk/turn_loop_test.go index 6e3159cfc..72753fc5e 100644 --- a/adk/turn_loop_test.go +++ b/adk/turn_loop_test.go @@ -1484,6 +1484,34 @@ func TestTurnLoop_StopWithoutCheckpointIDDoesNotPersist(t *testing.T) { assert.Empty(t, store.m, "no checkpoint should be saved when CheckpointID is not configured") } +func TestTurnLoop_StopWhileIdle_SkipsCheckpoint(t *testing.T) { + ctx := context.Background() + store := &deletableCheckpointStore{ + turnLoopCheckpointStore: turnLoopCheckpointStore{m: make(map[string][]byte)}, + } + cpID := "idle-session" + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Stop() + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + + store.mu.Lock() + defer store.mu.Unlock() + _, exists := store.m[cpID] + assert.False(t, exists, "no checkpoint should be saved when TurnLoop is idle") +} + func TestTurnLoop_StopBetweenTurnsAndResume(t *testing.T) { ctx := context.Background() store := &turnLoopCheckpointStore{m: make(map[string][]byte)} diff --git a/internal/core/address.go b/internal/core/address.go index c1ecb17fb..bb2400a92 100644 --- a/internal/core/address.go +++ b/internal/core/address.go @@ -281,7 +281,12 @@ func PopulateInterruptState(ctx context.Context, id2Addr map[string]Address, for id, addr := range id2Addr { rInfo.id2Addr[id] = addr } - rInfo.id2State = id2State + if rInfo.id2State == nil { + rInfo.id2State = make(map[string]InterruptState) + } + for id, state := range id2State { + rInfo.id2State[id] = state + } } else { rInfo = &globalResumeInfo{ id2Addr: id2Addr, From d832c4f7da8e86f7a221c466d8c171a22785b0cc Mon Sep 17 00:00:00 2001 From: shentongmartin Date: Wed, 1 Apr 2026 11:41:15 +0800 Subject: [PATCH 47/65] feat(adk): export NewEventSenderToolWrapper for customizable tool event position (#926) --- adk/chatmodel.go | 39 ++- adk/wrappers.go | 88 ++++--- adk/wrappers_test.go | 597 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 680 insertions(+), 44 deletions(-) diff --git a/adk/chatmodel.go b/adk/chatmodel.go index a0126455a..d7382b309 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -283,13 +283,35 @@ type ChatModelAgentConfig struct { // the default event sender to avoid duplicate events. // // Tool call lifecycle (outermost to innermost): - // 1. eventSenderToolHandler (internal ToolMiddleware - sends tool result events after all processing) + // 1. eventSenderToolWrapper (internal ToolMiddleware - sends tool result events after all processing) // 2. ToolsConfig.ToolCallMiddlewares (ToolMiddleware) // 3. AgentMiddleware.WrapToolCall (ToolMiddleware) // 4. ChatModelAgentMiddleware.WrapToolCall (wrapper, first registered is outermost) // 5. callbackInjectedToolCall (internal - injects callbacks if tool doesn't handle them) // 6. Tool.InvokableRun/StreamableRun // + // Custom Tool Event Sender Position: + // By default, tool result events are emitted by an internal event sender placed before + // all user middlewares (outermost), so events reflect the fully processed tool output. + // To control exactly where in the handler chain tool events are emitted, pass + // NewEventSenderToolWrapper() as one of the Handlers. Its position determines which + // middlewares' effects are visible in the emitted event: + // + // agent, _ := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ + // Handlers: []adk.ChatModelAgentMiddleware{ + // loggingHandler, // Outermost: sees event-sender output + // adk.NewEventSenderToolWrapper(), // Events reflect output from handlers below + // sanitizationHandler, // Innermost: runs first, modifies tool output + // }, + // }) + // + // Handler order: first registered is outermost. So [A, B, C] wraps as A(B(C(tool))). + // The event sender captures tool output in post-processing, so its position controls + // which handlers' modifications are included in the emitted events. + // + // When NewEventSenderToolWrapper is detected in Handlers, the framework skips + // the default event sender to avoid duplicate events. + // // Tool List Modification: // // There are two ways to modify the tool list: @@ -371,20 +393,15 @@ func NewChatModelAgent(ctx context.Context, config *ChatModelAgentConfig) (*Chat tc := config.ToolsConfig // Tool call middleware execution order (outermost to innermost): - // 1. eventSenderToolHandler (internal - sends tool result events after all modifications) + // 1. eventSenderToolWrapper (internal - sends tool result events after all modifications) // 2. User-provided ToolsConfig.ToolCallMiddlewares (original order preserved) // 3. Middlewares' WrapToolCall (in registration order) // 4. ChatModelAgentMiddleware.WrapToolCall (in registration order) // 5. callbackInjectedToolCall (internal - injects callbacks if tool doesn't handle them) - eventSender := &eventSenderToolHandler{} - tc.ToolCallMiddlewares = append( - []compose.ToolMiddleware{{Invokable: eventSender.WrapInvokableToolCall, - Streamable: eventSender.WrapStreamableToolCall, - EnhancedInvokable: eventSender.WrapEnhancedInvokableToolCall, - EnhancedStreamable: eventSender.WrapEnhancedStreamableToolCall, - }}, - tc.ToolCallMiddlewares..., - ) + if !hasUserEventSenderToolWrapper(config.Handlers) { + defaultToolEventSender := handlersToToolMiddlewares([]ChatModelAgentMiddleware{NewEventSenderToolWrapper()}) + tc.ToolCallMiddlewares = append(defaultToolEventSender, tc.ToolCallMiddlewares...) + } tc.ToolCallMiddlewares = append(tc.ToolCallMiddlewares, collectToolMiddlewaresFromMiddlewares(config.Middlewares)...) // Cancel monitoring middleware (innermost — close to the tool endpoint). diff --git a/adk/wrappers.go b/adk/wrappers.go index eb23549fd..e6ac617ea 100644 --- a/adk/wrappers.go +++ b/adk/wrappers.go @@ -350,20 +350,33 @@ func popToolGenAction(ctx context.Context, toolName string) *AgentAction { return action } -type eventSenderToolHandler struct{} +type eventSenderToolWrapper struct { + *BaseChatModelAgentMiddleware +} + +// NewEventSenderToolWrapper returns a ChatModelAgentMiddleware that sends tool result events. +// By default, the framework places this before all user middlewares (outermost), so events +// reflect the fully processed tool output. To control exactly where events are emitted, +// include this in ChatModelAgentConfig.Handlers at the desired position. +// When detected in Handlers, the framework skips the default event sender to avoid duplicates. +func NewEventSenderToolWrapper() ChatModelAgentMiddleware { + return &eventSenderToolWrapper{ + BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, + } +} -func (h *eventSenderToolHandler) WrapInvokableToolCall(next compose.InvokableToolEndpoint) compose.InvokableToolEndpoint { - return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { - output, err := next(ctx, input) +func (w *eventSenderToolWrapper) WrapInvokableToolCall(_ context.Context, endpoint InvokableToolCallEndpoint, tCtx *ToolContext) (InvokableToolCallEndpoint, error) { + return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { + result, err := endpoint(ctx, argumentsInJSON, opts...) if err != nil { - return nil, err + return "", err } - toolName := input.Name - callID := input.CallID + toolName := tCtx.Name + callID := tCtx.CallID prePopAction := popToolGenAction(ctx, toolName) - msg := schema.ToolMessage(output.Result, callID, schema.WithToolName(toolName)) + msg := schema.ToolMessage(result, callID, schema.WithToolName(toolName)) event := EventFromMessage(msg, nil, schema.Tool, toolName) if prePopAction != nil { event.Action = prePopAction @@ -379,22 +392,22 @@ func (h *eventSenderToolHandler) WrapInvokableToolCall(next compose.InvokableToo return nil }) - return output, nil - } + return result, nil + }, nil } -func (h *eventSenderToolHandler) WrapStreamableToolCall(next compose.StreamableToolEndpoint) compose.StreamableToolEndpoint { - return func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { - output, err := next(ctx, input) +func (w *eventSenderToolWrapper) WrapStreamableToolCall(_ context.Context, endpoint StreamableToolCallEndpoint, tCtx *ToolContext) (StreamableToolCallEndpoint, error) { + return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) { + result, err := endpoint(ctx, argumentsInJSON, opts...) if err != nil { return nil, err } - toolName := input.Name - callID := input.CallID + toolName := tCtx.Name + callID := tCtx.CallID prePopAction := popToolGenAction(ctx, toolName) - streams := output.Result.Copy(2) + streams := result.Copy(2) cvt := func(in string) (Message, error) { return schema.ToolMessage(in, callID, schema.WithToolName(toolName)), nil @@ -413,23 +426,23 @@ func (h *eventSenderToolHandler) WrapStreamableToolCall(next compose.StreamableT return nil }) - return &compose.StreamToolOutput{Result: streams[1]}, nil - } + return streams[1], nil + }, nil } -func (h *eventSenderToolHandler) WrapEnhancedInvokableToolCall(next compose.EnhancedInvokableToolEndpoint) compose.EnhancedInvokableToolEndpoint { - return func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedInvokableToolOutput, error) { - output, err := next(ctx, input) +func (w *eventSenderToolWrapper) WrapEnhancedInvokableToolCall(_ context.Context, endpoint EnhancedInvokableToolCallEndpoint, tCtx *ToolContext) (EnhancedInvokableToolCallEndpoint, error) { + return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { + result, err := endpoint(ctx, toolArgument, opts...) if err != nil { return nil, err } - toolName := input.Name - callID := input.CallID + toolName := tCtx.Name + callID := tCtx.CallID prePopAction := popToolGenAction(ctx, toolName) msg := schema.ToolMessage("", callID, schema.WithToolName(toolName)) - msg.UserInputMultiContent, err = output.Result.ToMessageInputParts() + msg.UserInputMultiContent, err = result.ToMessageInputParts() if err != nil { return nil, err } @@ -448,22 +461,22 @@ func (h *eventSenderToolHandler) WrapEnhancedInvokableToolCall(next compose.Enha return nil }) - return output, nil - } + return result, nil + }, nil } -func (h *eventSenderToolHandler) WrapEnhancedStreamableToolCall(next compose.EnhancedStreamableToolEndpoint) compose.EnhancedStreamableToolEndpoint { - return func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { - output, err := next(ctx, input) +func (w *eventSenderToolWrapper) WrapEnhancedStreamableToolCall(_ context.Context, endpoint EnhancedStreamableToolCallEndpoint, tCtx *ToolContext) (EnhancedStreamableToolCallEndpoint, error) { + return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { + result, err := endpoint(ctx, toolArgument, opts...) if err != nil { return nil, err } - toolName := input.Name - callID := input.CallID + toolName := tCtx.Name + callID := tCtx.CallID prePopAction := popToolGenAction(ctx, toolName) - streams := output.Result.Copy(2) + streams := result.Copy(2) cvt := func(in *schema.ToolResult) (Message, error) { msg := schema.ToolMessage("", callID, schema.WithToolName(toolName)) @@ -488,8 +501,17 @@ func (h *eventSenderToolHandler) WrapEnhancedStreamableToolCall(next compose.Enh return nil }) - return &compose.EnhancedStreamableToolOutput{Result: streams[1]}, nil + return streams[1], nil + }, nil +} + +func hasUserEventSenderToolWrapper(handlers []ChatModelAgentMiddleware) bool { + for _, handler := range handlers { + if _, ok := handler.(*eventSenderToolWrapper); ok { + return true + } } + return false } type stateModelWrapper struct { diff --git a/adk/wrappers_test.go b/adk/wrappers_test.go index 5fd8acef5..acb6588be 100644 --- a/adk/wrappers_test.go +++ b/adk/wrappers_test.go @@ -1085,3 +1085,600 @@ func (m *contentModifyingModelWrapper) Stream(ctx context.Context, input []*sche result.Content = m.newContent return schema.StreamReaderFromArray([]*schema.Message{result}), nil } + +type mockToolCallingModel struct { + mu sync.Mutex + generateCalls int + toolCallName string +} + +func (m *mockToolCallingModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + m.mu.Lock() + m.generateCalls++ + calls := m.generateCalls + m.mu.Unlock() + if calls == 1 { + return schema.AssistantMessage("calling tool", []schema.ToolCall{ + {ID: "tc-1", Function: schema.FunctionCall{Name: m.toolCallName, Arguments: `{"input":"test"}`}}, + }), nil + } + return schema.AssistantMessage("done", nil), nil +} + +func (m *mockToolCallingModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + msg, err := m.Generate(ctx, input, opts...) + if err != nil { + return nil, err + } + return schema.StreamReaderFromArray([]*schema.Message{msg}), nil +} + +func (m *mockToolCallingModel) WithTools(_ []*schema.ToolInfo) (model.ToolCallingChatModel, error) { + return m, nil +} + +type invokableTestTool struct { + name string + result string +} + +func (t *invokableTestTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "test tool", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Desc: "input", Required: true, Type: schema.String}, + }), + }, nil +} + +func (t *invokableTestTool) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) { + return t.result, nil +} + +type streamableTestTool struct { + name string + result string +} + +func (t *streamableTestTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "test tool", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Desc: "input", Required: true, Type: schema.String}, + }), + }, nil +} + +func (t *streamableTestTool) StreamableRun(_ context.Context, _ string, _ ...tool.Option) (*schema.StreamReader[string], error) { + return schema.StreamReaderFromArray([]string{t.result}), nil +} + +type enhancedInvokableTestTool struct { + name string + result string +} + +func (t *enhancedInvokableTestTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "test tool", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Desc: "input", Required: true, Type: schema.String}, + }), + }, nil +} + +func (t *enhancedInvokableTestTool) InvokableRun(_ context.Context, _ *schema.ToolArgument, _ ...tool.Option) (*schema.ToolResult, error) { + return &schema.ToolResult{ + Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: t.result}}, + }, nil +} + +type enhancedStreamableTestTool struct { + name string + result string +} + +func (t *enhancedStreamableTestTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "test tool", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Desc: "input", Required: true, Type: schema.String}, + }), + }, nil +} + +func (t *enhancedStreamableTestTool) StreamableRun(_ context.Context, _ *schema.ToolArgument, _ ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { + return schema.StreamReaderFromArray([]*schema.ToolResult{ + {Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: t.result}}}, + }), nil +} + +type invokableResultModifier struct { + *BaseChatModelAgentMiddleware + modifiedResult string +} + +func (h *invokableResultModifier) WrapInvokableToolCall(_ context.Context, endpoint InvokableToolCallEndpoint, _ *ToolContext) (InvokableToolCallEndpoint, error) { + return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { + _, err := endpoint(ctx, argumentsInJSON, opts...) + if err != nil { + return "", err + } + return h.modifiedResult, nil + }, nil +} + +type streamableResultModifier struct { + *BaseChatModelAgentMiddleware + modifiedResult string +} + +func (h *streamableResultModifier) WrapStreamableToolCall(_ context.Context, endpoint StreamableToolCallEndpoint, _ *ToolContext) (StreamableToolCallEndpoint, error) { + return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) { + sr, err := endpoint(ctx, argumentsInJSON, opts...) + if err != nil { + return nil, err + } + sr.Close() + return schema.StreamReaderFromArray([]string{h.modifiedResult}), nil + }, nil +} + +type enhancedInvokableResultModifier struct { + *BaseChatModelAgentMiddleware + modifiedResult string +} + +func (h *enhancedInvokableResultModifier) WrapEnhancedInvokableToolCall(_ context.Context, endpoint EnhancedInvokableToolCallEndpoint, _ *ToolContext) (EnhancedInvokableToolCallEndpoint, error) { + return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { + _, err := endpoint(ctx, toolArgument, opts...) + if err != nil { + return nil, err + } + return &schema.ToolResult{ + Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: h.modifiedResult}}, + }, nil + }, nil +} + +type enhancedStreamableResultModifier struct { + *BaseChatModelAgentMiddleware + modifiedResult string +} + +func (h *enhancedStreamableResultModifier) WrapEnhancedStreamableToolCall(_ context.Context, endpoint EnhancedStreamableToolCallEndpoint, _ *ToolContext) (EnhancedStreamableToolCallEndpoint, error) { + return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { + sr, err := endpoint(ctx, toolArgument, opts...) + if err != nil { + return nil, err + } + sr.Close() + return schema.StreamReaderFromArray([]*schema.ToolResult{ + {Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: h.modifiedResult}}}, + }), nil + }, nil +} + +func collectToolEvents(it *AsyncIterator[*AgentEvent]) []*AgentEvent { + var toolEvents []*AgentEvent + for { + ev, ok := it.Next() + if !ok { + break + } + if ev.Output == nil || ev.Output.MessageOutput == nil { + continue + } + mo := ev.Output.MessageOutput + if mo.Message != nil && mo.Message.Role == schema.Tool { + toolEvents = append(toolEvents, ev) + continue + } + if mo.IsStreaming && mo.Role == schema.Tool && mo.MessageStream != nil { + toolEvents = append(toolEvents, ev) + } + } + return toolEvents +} + +func collectToolContent(events []*AgentEvent) []string { + var contents []string + for _, ev := range events { + mo := ev.Output.MessageOutput + if !mo.IsStreaming && mo.Message != nil { + if mo.Message.Content != "" { + contents = append(contents, mo.Message.Content) + } else if len(mo.Message.UserInputMultiContent) > 0 { + for _, part := range mo.Message.UserInputMultiContent { + if part.Text != "" { + contents = append(contents, part.Text) + } + } + } + continue + } + if mo.IsStreaming && mo.MessageStream != nil { + var msgs []*schema.Message + for { + msg, err := mo.MessageStream.Recv() + if err != nil { + break + } + msgs = append(msgs, msg) + } + if len(msgs) > 0 { + concated, err := schema.ConcatMessages(msgs) + if err == nil { + if concated.Content != "" { + contents = append(contents, concated.Content) + } else if len(concated.UserInputMultiContent) > 0 { + for _, part := range concated.UserInputMultiContent { + if part.Text != "" { + contents = append(contents, part.Text) + } + } + } + } + } + } + } + return contents +} + +func TestEventSenderToolHandler(t *testing.T) { + t.Run("Invokable", func(t *testing.T) { + t.Run("DefaultSendsEvent", func(t *testing.T) { + ctx := context.Background() + testTool := &invokableTestTool{name: "test_tool", result: "invokable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: false}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, "invokable_output") + }) + + t.Run("UserConfiguredSkipsDefault", func(t *testing.T) { + ctx := context.Background() + testTool := &invokableTestTool{name: "test_tool", result: "invokable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{NewEventSenderToolWrapper()}, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: false}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + }) + + t.Run("InnermostGetsOriginalOutput", func(t *testing.T) { + ctx := context.Background() + originalResult := "original_invokable_output" + modifiedResult := "modified_invokable_output" + testTool := &invokableTestTool{name: "test_tool", result: originalResult} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{ + NewEventSenderToolWrapper(), + &invokableResultModifier{ + BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, + modifiedResult: modifiedResult, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: false}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.GreaterOrEqual(t, len(toolEvents), 1) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, originalResult) + }) + }) + + t.Run("Streamable", func(t *testing.T) { + t.Run("DefaultSendsEvent", func(t *testing.T) { + ctx := context.Background() + testTool := &streamableTestTool{name: "test_tool", result: "streamable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: true}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, "streamable_output") + }) + + t.Run("UserConfiguredSkipsDefault", func(t *testing.T) { + ctx := context.Background() + testTool := &streamableTestTool{name: "test_tool", result: "streamable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{NewEventSenderToolWrapper()}, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: true}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + }) + + t.Run("InnermostGetsOriginalOutput", func(t *testing.T) { + ctx := context.Background() + originalResult := "original_streamable_output" + modifiedResult := "modified_streamable_output" + testTool := &streamableTestTool{name: "test_tool", result: originalResult} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{ + NewEventSenderToolWrapper(), + &streamableResultModifier{ + BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, + modifiedResult: modifiedResult, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: true}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.GreaterOrEqual(t, len(toolEvents), 1) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, originalResult) + }) + }) + + t.Run("EnhancedInvokable", func(t *testing.T) { + t.Run("DefaultSendsEvent", func(t *testing.T) { + ctx := context.Background() + testTool := &enhancedInvokableTestTool{name: "test_tool", result: "enhanced_invokable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: false}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, "enhanced_invokable_output") + }) + + t.Run("UserConfiguredSkipsDefault", func(t *testing.T) { + ctx := context.Background() + testTool := &enhancedInvokableTestTool{name: "test_tool", result: "enhanced_invokable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{NewEventSenderToolWrapper()}, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: false}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + }) + + t.Run("InnermostGetsOriginalOutput", func(t *testing.T) { + ctx := context.Background() + originalResult := "original_enhanced_invokable_output" + modifiedResult := "modified_enhanced_invokable_output" + testTool := &enhancedInvokableTestTool{name: "test_tool", result: originalResult} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{ + NewEventSenderToolWrapper(), + &enhancedInvokableResultModifier{ + BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, + modifiedResult: modifiedResult, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: false}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.GreaterOrEqual(t, len(toolEvents), 1) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, originalResult) + }) + }) + + t.Run("EnhancedStreamable", func(t *testing.T) { + t.Run("DefaultSendsEvent", func(t *testing.T) { + ctx := context.Background() + testTool := &enhancedStreamableTestTool{name: "test_tool", result: "enhanced_streamable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: true}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, "enhanced_streamable_output") + }) + + t.Run("UserConfiguredSkipsDefault", func(t *testing.T) { + ctx := context.Background() + testTool := &enhancedStreamableTestTool{name: "test_tool", result: "enhanced_streamable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{NewEventSenderToolWrapper()}, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: true}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + }) + + t.Run("InnermostGetsOriginalOutput", func(t *testing.T) { + ctx := context.Background() + originalResult := "original_enhanced_streamable_output" + modifiedResult := "modified_enhanced_streamable_output" + testTool := &enhancedStreamableTestTool{name: "test_tool", result: originalResult} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{ + NewEventSenderToolWrapper(), + &enhancedStreamableResultModifier{ + BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, + modifiedResult: modifiedResult, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: true}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.GreaterOrEqual(t, len(toolEvents), 1) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, originalResult) + }) + }) +} From a64d35e9fefdfa94f0b74aab5a69588914a1ecd3 Mon Sep 17 00:00:00 2001 From: shentongmartin Date: Thu, 2 Apr 2026 20:17:29 +0800 Subject: [PATCH 48/65] fix(adk): prevent panic when orphaned tool goroutine sends event after agent cancellation (#929) * fix(adk): prevent panic when orphaned tool goroutine sends event after agent cancellation When CancelAfterChatModel times out and escalates to CancelImmediate, GraphInterrupt fires with timeout=0. The compose graph returns immediately, orphaning parallel tool goroutines. When an orphaned tool completes, eventSenderToolWrapper tries to send an event via the AsyncGenerator which is already closed, causing 'send on closed channel' panic. - Add isImmediateCancelled() to cancelContext for checking immediateChan - Make chatModelAgentExecCtx.send cancel-aware: skip send when immediate cancel is active - Use trySend as safety net for the TOCTOU race window - Route SendEvent() through execCtx.send() instead of direct generator.Send() Change-Id: Ic7e0194c860e2692a3cddc559911ab379024f650 * test(adk): add test for orphaned tool goroutine panic after CancelImmediate - unit_send_after_close: directly reproduces the panic by sending to a closed generator with isImmediateCancelled=true - unit_send_after_close_without_cancel_ctx: verifies trySend safety net prevents panic even without cancelCtx - integration_cancel_escalation_orphans_tool: end-to-end test with slow tool, CancelAfterChatModel timeout escalation, and orphaned goroutine Change-Id: Ia82fa957b102ccc2ac42094d18d4b15db2a1701c * test(adk): improve coverage for orphaned tool goroutine fix Add test cases for: - nil execCtx and nil generator defensive guards - nil cancelContext in isImmediateCancelled - TOCTOU race window (isImmediateCancelled=false but generator closed) - SendEvent public API with closed generator - SendEvent without exec context Change-Id: I197c36f34675f5376cbe5f830b15db6ca873cd1f --- adk/cancel.go | 16 ++++ adk/cancel_test.go | 186 +++++++++++++++++++++++++++++++++++++++++++++ adk/chatmodel.go | 11 ++- adk/handler.go | 2 +- adk/utils.go | 4 + 5 files changed, 216 insertions(+), 3 deletions(-) diff --git a/adk/cancel.go b/adk/cancel.go index 20d72bb20..f119d79f5 100644 --- a/adk/cancel.go +++ b/adk/cancel.go @@ -403,6 +403,22 @@ func (cc *cancelContext) shouldCancel() bool { } } +// isImmediateCancelled returns true if an immediate graph interrupt has been +// fired (CancelImmediate or timeout escalation). This is stronger than +// shouldCancel: it means the compose graph is being torn down right now and +// orphaned goroutines should not attempt to send events. +func (cc *cancelContext) isImmediateCancelled() bool { + if cc == nil { + return false + } + select { + case <-cc.immediateChan: + return true + default: + return false + } +} + // sendImmediateInterrupt sends the compose graph interrupt signal via graphInterruptFuncs. // Also closes immediateChan (used by cancelMonitoredModel to abort an in-progress stream). // Returns false if an interrupt was already sent or if no graphInterruptFuncs have been diff --git a/adk/cancel_test.go b/adk/cancel_test.go index 0d88db8cc..105c9ea13 100644 --- a/adk/cancel_test.go +++ b/adk/cancel_test.go @@ -2305,6 +2305,192 @@ func TestCancel_SequentialWorkflow_CancelAfterChatModel(t *testing.T) { assert.True(t, len(resumeEvents) > 0, "Resume should produce events") } +func TestCancelImmediate_OrphanedToolGoroutine_NoPanic(t *testing.T) { + t.Run("unit_send_after_close", func(t *testing.T) { + _, gen := NewAsyncIteratorPair[*AgentEvent]() + + cc := newCancelContext() + cc.setMode(CancelImmediate) + close(cc.cancelChan) + close(cc.immediateChan) + + gen.Close() + + execCtx := &chatModelAgentExecCtx{ + generator: gen, + cancelCtx: cc, + } + + assert.NotPanics(t, func() { + execCtx.send(&AgentEvent{AgentName: "test"}) + }, "send after generator.Close must not panic") + }) + + t.Run("unit_send_after_close_without_cancel_ctx", func(t *testing.T) { + _, gen := NewAsyncIteratorPair[*AgentEvent]() + gen.Close() + + execCtx := &chatModelAgentExecCtx{ + generator: gen, + } + + assert.NotPanics(t, func() { + execCtx.send(&AgentEvent{AgentName: "test"}) + }, "send after generator.Close must not panic even without cancelCtx (trySend safety net)") + }) + + t.Run("unit_send_nil_execCtx", func(t *testing.T) { + var execCtx *chatModelAgentExecCtx + assert.NotPanics(t, func() { + execCtx.send(&AgentEvent{AgentName: "test"}) + }, "send on nil execCtx must not panic") + }) + + t.Run("unit_send_nil_generator", func(t *testing.T) { + execCtx := &chatModelAgentExecCtx{} + assert.NotPanics(t, func() { + execCtx.send(&AgentEvent{AgentName: "test"}) + }, "send with nil generator must not panic") + }) + + t.Run("unit_isImmediateCancelled_nil_cancelContext", func(t *testing.T) { + var cc *cancelContext + assert.False(t, cc.isImmediateCancelled(), "nil cancelContext should return false") + }) + + t.Run("unit_trySend_race_window", func(t *testing.T) { + _, gen := NewAsyncIteratorPair[*AgentEvent]() + cc := newCancelContext() + + gen.Close() + + execCtx := &chatModelAgentExecCtx{ + generator: gen, + cancelCtx: cc, + } + + assert.NotPanics(t, func() { + execCtx.send(&AgentEvent{AgentName: "test"}) + }, "trySend must handle the case where isImmediateCancelled is false but generator is closed") + }) + + t.Run("unit_SendEvent_after_close", func(t *testing.T) { + _, gen := NewAsyncIteratorPair[*AgentEvent]() + + cc := newCancelContext() + cc.setMode(CancelImmediate) + close(cc.cancelChan) + close(cc.immediateChan) + + gen.Close() + + execCtx := &chatModelAgentExecCtx{ + generator: gen, + cancelCtx: cc, + } + + ctx := withChatModelAgentExecCtx(context.Background(), execCtx) + + assert.NotPanics(t, func() { + err := SendEvent(ctx, &AgentEvent{AgentName: "test"}) + assert.NoError(t, err) + }, "SendEvent after generator.Close must not panic") + }) + + t.Run("unit_SendEvent_no_execCtx", func(t *testing.T) { + err := SendEvent(context.Background(), &AgentEvent{AgentName: "test"}) + assert.Error(t, err, "SendEvent without execCtx should return error") + }) + + t.Run("integration_cancel_escalation_orphans_tool", func(t *testing.T) { + ctx := context.Background() + + toolStarted := make(chan struct{}, 1) + toolDone := make(chan struct{}, 1) + st := &slowToolWithSignal{ + name: "orphan_tool", + delay: 2 * time.Second, + result: "tool result", + startedChan: toolStarted, + } + + mdl := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_orphan_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "orphan_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "OrphanTestAgent", + Description: "Test agent for orphaned tool goroutine panic", + Model: mdl, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + cancelOpt, cancelFn := WithCancel() + iter := agent.Run(ctx, &AgentInput{ + Messages: []Message{schema.UserMessage("Use the tool")}, + }, cancelOpt) + assert.NotNil(t, iter) + + select { + case <-toolStarted: + case <-time.After(10 * time.Second): + t.Fatal("Tool did not start") + } + + timeout := 50 * time.Millisecond + handle, contributed := cancelFn( + WithAgentCancelMode(CancelAfterChatModel), + WithAgentCancelTimeout(timeout), + ) + assert.True(t, contributed, "Cancel should contribute") + + err = handle.Wait() + assert.True(t, err == nil || errors.Is(err, ErrCancelTimeout), + "handle.Wait should return nil or ErrCancelTimeout, got: %v", err) + + for { + _, ok := iter.Next() + if !ok { + break + } + } + + go func() { + time.Sleep(3 * time.Second) + select { + case toolDone <- struct{}{}: + default: + } + }() + + runtime.Gosched() + time.Sleep(3 * time.Second) + + select { + case <-toolDone: + default: + } + }) +} + // -- Tests for CancelImmediate in nested agent structures -- func newTestChatModel(response *schema.Message, delay time.Duration) *cancelTestChatModel { diff --git a/adk/chatmodel.go b/adk/chatmodel.go index d7382b309..4b736d51d 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -43,12 +43,17 @@ var _ ResumableAgent = &ChatModelAgent{} type chatModelAgentExecCtx struct { runtimeReturnDirectly map[string]bool generator *AsyncGenerator[*AgentEvent] + cancelCtx *cancelContext } func (e *chatModelAgentExecCtx) send(event *AgentEvent) { - if e != nil && e.generator != nil { - e.generator.Send(event) + if e == nil || e.generator == nil { + return + } + if e.cancelCtx != nil && e.cancelCtx.isImmediateCancelled() { + return } + e.generator.trySend(event) } type chatModelAgentExecCtxKey struct{} @@ -837,6 +842,7 @@ func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { ctx = withChatModelAgentExecCtx(ctx, &chatModelAgentExecCtx{ generator: p.generator, + cancelCtx: cancelCtx, }) // Pre-execution cancel check @@ -947,6 +953,7 @@ func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) ( ctx = withChatModelAgentExecCtx(ctx, &chatModelAgentExecCtx{ runtimeReturnDirectly: p.returnDirectly, generator: p.generator, + cancelCtx: cancelCtx, }) // Pre-execution cancel check diff --git a/adk/handler.go b/adk/handler.go index 423282a7a..854063f16 100644 --- a/adk/handler.go +++ b/adk/handler.go @@ -347,7 +347,7 @@ func SendEvent(ctx context.Context, event *AgentEvent) error { if execCtx == nil || execCtx.generator == nil { return fmt.Errorf("SendEvent failed: must be called within a ChatModelAgent Run() or Resume() execution context") } - execCtx.generator.Send(event) + execCtx.send(event) return nil } diff --git a/adk/utils.go b/adk/utils.go index 11d3d7eeb..5dd890be8 100644 --- a/adk/utils.go +++ b/adk/utils.go @@ -44,6 +44,10 @@ func (ag *AsyncGenerator[T]) Send(v T) { ag.ch.Send(v) } +func (ag *AsyncGenerator[T]) trySend(v T) bool { + return ag.ch.TrySend(v) +} + func (ag *AsyncGenerator[T]) Close() { ag.ch.Close() } From d29b95e998ae857e052b00c4d00c3599b35d9395 Mon Sep 17 00:00:00 2001 From: shentongmartin Date: Wed, 8 Apr 2026 14:06:52 +0800 Subject: [PATCH 49/65] feat(adk): improve TurnLoop stop cleanup and add StopOption controls (#925) * fix(adk): keep late turn loop items Change-Id: Iabee0c25a83d5a25585d3592a41ca6a5fba35c2b * docs(adk): clarify cancel wait semantics Change-Id: Ia0a396b9cc2e43f15e85056d966f20b010dcd2b6 * feat(adk): add WithSkipCheckpoint and WithStopCause StopOptions Add two new StopOption variants for TurnLoop.Stop(): - WithSkipCheckpoint: prevents checkpoint persistence on stop, for cases where the caller does not intend to resume in the future. The flag is sticky across escalation calls. - WithStopCause: attaches a business-supplied reason string. Surfaced in TurnLoopExitState.StopCause and, after the Stopped channel closes, via TurnContext.StopCause(). Uses first-non-empty-wins semantics across multiple Stop() calls. Thread both fields through stopSignal with proper mutex protection. Update cleanup() to skip checkpoint save when skipCheckpoint is set. Change-Id: Ifeat-stop-options-skip-checkpoint-stop-cause --- adk/cancel.go | 10 + adk/turn_loop.go | 149 ++++++++- adk/turn_loop_test.go | 605 ++++++++++++++++++++++++++++++++++++- internal/core/interrupt.go | 8 +- 4 files changed, 748 insertions(+), 24 deletions(-) diff --git a/adk/cancel.go b/adk/cancel.go index f119d79f5..a15699f53 100644 --- a/adk/cancel.go +++ b/adk/cancel.go @@ -67,6 +67,16 @@ type CancelHandle struct { wait func() error } +// Wait blocks until the cancel request reaches a terminal outcome. +// +// It reports the result of the cancel operation itself, not the agent's final +// business error: +// - nil: cancellation succeeded, including the case where a business interrupt +// was absorbed into CancelError while cancellation was active +// - ErrCancelTimeout: the requested safe-point cancellation timed out and was +// escalated to immediate cancellation +// - ErrExecutionCompleted: the execution finished before cancellation took effect, +// meaning the stream drained to completion without any interrupt func (h *CancelHandle) Wait() error { return h.wait() } diff --git a/adk/turn_loop.go b/adk/turn_loop.go index 000842589..7d0f61a3b 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -61,6 +61,8 @@ type stopSignal struct { mu sync.Mutex gen uint64 agentCancelOpts []AgentCancelOption + skipCheckpoint bool + stopCause string // notify is a buffered(1) channel that wakes the current turn's watcher // when Stop() is called. Unlike done, it supports repeated Stop() calls // for cancel-mode escalation. @@ -81,6 +83,12 @@ func (s *stopSignal) signal(cfg *stopConfig) { s.mu.Lock() s.gen++ s.agentCancelOpts = cfg.agentCancelOpts + if cfg.skipCheckpoint { + s.skipCheckpoint = true + } + if cfg.stopCause != "" && s.stopCause == "" { + s.stopCause = cfg.stopCause + } s.mu.Unlock() select { case s.notify <- struct{}{}: @@ -111,6 +119,18 @@ func (s *stopSignal) check() (uint64, []AgentCancelOption) { return s.gen, append([]AgentCancelOption{}, s.agentCancelOpts...) } +func (s *stopSignal) isSkipCheckpoint() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.skipCheckpoint +} + +func (s *stopSignal) getStopCause() string { + s.mu.Lock() + defer s.mu.Unlock() + return s.stopCause +} + // preemptSignal coordinates preemption between Push callers and the run loop. // // Lifecycle overview: @@ -517,9 +537,11 @@ type TurnLoopExitState[T any] struct { // nil means clean exit (Stop() was called and completed normally). // Non-nil values include context errors, callback errors, *CancelError, etc. // When Stop() cancels a running agent, ExitReason will be a *CancelError. + // This never contains checkpoint errors — see CheckpointErr for those. ExitReason error // UnhandledItems contains items that were buffered but not processed. + // These are items for which Push returned true but were never consumed by a turn. // This is always valid regardless of ExitReason. UnhandledItems []T @@ -528,6 +550,33 @@ type TurnLoopExitState[T any] struct { // did not contribute to the final CancelError. // It can be used to reconstruct GenInput/PrepareAgent inputs when resuming. CanceledItems []T + + // StopCause is the business-supplied reason passed via WithStopCause. + // Empty if Stop was not called or no cause was provided. + StopCause string + + // Checkpointed indicates whether a checkpoint save was attempted during cleanup. + // True only when Store is configured, CheckpointID is set, Stop() was called, + // and the loop was not idle at exit time. + Checkpointed bool + + // CheckpointErr is the error from checkpoint save, if any. + // nil when Checkpointed is false (no attempt was made) or when the save succeeded. + CheckpointErr error + + // TakeLateItems returns items that were pushed after the loop stopped + // (i.e., Push returned false for these items). These items are NOT included + // in the checkpoint. + // + // This function is idempotent: the first call computes and caches the result; + // subsequent calls return the same slice. + // + // After TakeLateItems is called, any subsequent Push() will panic. This + // seals the late buffer and prevents items from being silently lost. + // + // It is safe to call TakeLateItems from any goroutine after Wait() returns. + // If TakeLateItems is never called, late items are simply garbage collected. + TakeLateItems func() []T } // TurnContext provides per-turn context to the OnAgentEvents callback. @@ -552,6 +601,11 @@ type TurnContext[T any] struct { // before it was finalized. Remains open if Stop did not contribute. // Use in a select to detect stop while processing events. Stopped <-chan struct{} + + // StopCause returns the business-supplied reason from WithStopCause. + // This value is only meaningful after the Stopped channel is closed. + // Before that, it returns an empty string. + StopCause func() string } // TurnLoop is a push-based event loop for agent execution. @@ -602,6 +656,19 @@ type TurnLoop[T any] struct { loadCheckpointID string onAgentEvents func(ctx context.Context, tc *TurnContext[T], events *AsyncIterator[*AgentEvent]) error + + lateMu sync.Mutex + lateItems []T + lateSealed bool +} + +func (l *TurnLoop[T]) appendLate(item T) { + l.lateMu.Lock() + defer l.lateMu.Unlock() + if l.lateSealed { + panic("TurnLoop: Push called after TakeLateItems") + } + l.lateItems = append(l.lateItems, item) } type turnLoopCheckpoint[T any] struct { @@ -652,7 +719,7 @@ func (l *TurnLoop[T]) deleteTurnLoopCheckpoint(ctx context.Context, checkPointID if deleter, ok := l.config.Store.(CheckPointDeleter); ok { return deleter.Delete(ctx, checkPointID) } - return l.config.Store.Set(ctx, checkPointID, nil) + return nil } func (l *TurnLoop[T]) tryLoadCheckpoint(ctx context.Context) error { @@ -712,6 +779,8 @@ type turnLoopPendingResume[T any] struct { type stopConfig struct { agentCancelOpts []AgentCancelOption + skipCheckpoint bool + stopCause string } // StopOption is an option for Stop(). @@ -725,6 +794,25 @@ func WithAgentCancel(opts ...AgentCancelOption) StopOption { } } +// WithSkipCheckpoint tells the TurnLoop not to persist a checkpoint for this +// Stop call. Use this when the caller does not intend to resume in the future. +// The flag is sticky: once any Stop() call sets it, subsequent calls cannot undo it. +func WithSkipCheckpoint() StopOption { + return func(cfg *stopConfig) { + cfg.skipCheckpoint = true + } +} + +// WithStopCause attaches a business-supplied reason string to this Stop call. +// The cause is surfaced in TurnLoopExitState.StopCause and, after the Stopped +// channel closes, via TurnContext.StopCause(). +// If multiple Stop() calls provide a cause, the first non-empty value wins. +func WithStopCause(cause string) StopOption { + return func(cfg *stopConfig) { + cfg.stopCause = cause + } +} + type pushConfig[T any] struct { preempt bool preemptDelay time.Duration @@ -852,6 +940,9 @@ func (l *TurnLoop[T]) Run(ctx context.Context) { // current agent run or reaching a point where no cancellation is needed). // If the loop has not been started yet (Run not called), items are buffered // and will be processed once Run is called. +// After Wait() returns, failed pushes can be recovered via TurnLoopExitState.TakeLateItems(). +// Once TakeLateItems() has been called, any subsequent push that would become a +// late item will panic instead of being silently dropped. // // Use WithPreempt() to atomically push an item and signal preemption of the current agent. // This is useful for urgent items that should interrupt the current processing. @@ -898,16 +989,22 @@ func (l *TurnLoop[T]) pushWithStrategy(item T, cfg *pushConfig[T]) (bool, <-chan if !cfg.preempt { l.preemptSig.unholdRunLoop() - return l.buffer.TrySend(item), nil + if !l.buffer.TrySend(item) { + l.appendLate(item) + return false, nil + } + return true, nil } if atomic.LoadInt32(&l.stopped) != 0 { l.preemptSig.unholdRunLoop() + l.appendLate(item) return false, nil } if !l.buffer.TrySend(item) { l.preemptSig.unholdRunLoop() + l.appendLate(item) return false, nil } @@ -936,6 +1033,7 @@ func (l *TurnLoop[T]) pushWithStrategy(item T, cfg *pushConfig[T]) (bool, <-chan func (l *TurnLoop[T]) pushWithConfig(item T, cfg *pushConfig[T]) (bool, <-chan struct{}) { if atomic.LoadInt32(&l.stopped) != 0 { + l.appendLate(item) return false, nil } @@ -944,6 +1042,7 @@ func (l *TurnLoop[T]) pushWithConfig(item T, cfg *pushConfig[T]) (bool, <-chan s if !l.buffer.TrySend(item) { l.preemptSig.unholdRunLoop() + l.appendLate(item) return false, nil } @@ -970,7 +1069,11 @@ func (l *TurnLoop[T]) pushWithConfig(item T, cfg *pushConfig[T]) (bool, <-chan s return true, ack } - return l.buffer.TrySend(item), nil + if !l.buffer.TrySend(item) { + l.appendLate(item) + return false, nil + } + return true, nil } // Stop signals the loop to stop and returns immediately (non-blocking). @@ -1295,6 +1398,7 @@ func (l *TurnLoop[T]) runAgentAndHandleEvents( Consumed: spec.consumed, Preempted: preemptDone, Stopped: stoppedDone, + StopCause: l.stopSig.getStopCause, } l.preemptSig.setTurn(ctx, tc) @@ -1398,7 +1502,18 @@ func (l *TurnLoop[T]) cleanup(ctx context.Context) { unhandled := l.buffer.TakeAll() checkpointID := l.config.CheckpointID isIdle := len(l.checkPointRunnerBytes) == 0 && len(unhandled) == 0 && len(l.canceledItems) == 0 - shouldSaveCheckpoint := l.config.Store != nil && checkpointID != "" && l.stopSig.isStopped() && !isIdle + + // Only save checkpoint when the loop exited due to an explicit Stop(). + // If Stop() was called but a callback error happened concurrently, + // the state may be inconsistent — don't checkpoint in that case. + // We consider the exit Stop-caused if runErr is nil (clean stop between + // turns) or a *CancelError (Stop canceled a running agent). + exitCausedByStop := l.runErr == nil || errors.As(l.runErr, new(*CancelError)) + shouldSaveCheckpoint := l.config.Store != nil && checkpointID != "" && l.stopSig.isStopped() && exitCausedByStop && !isIdle && !l.stopSig.isSkipCheckpoint() + + var checkpointed bool + var checkpointErr error + if shouldSaveCheckpoint { cp := &turnLoopCheckpoint[T]{ RunnerCheckpoint: l.checkPointRunnerBytes, @@ -1406,23 +1521,31 @@ func (l *TurnLoop[T]) cleanup(ctx context.Context) { UnhandledItems: unhandled, CanceledItems: l.canceledItems, } - err := l.saveTurnLoopCheckpoint(ctx, checkpointID, cp) - if err != nil { - saveErr := fmt.Errorf("failed to save turn loop checkpoint: %w", err) - if l.runErr != nil { - l.runErr = fmt.Errorf("%w; %v", l.runErr, saveErr) - } else { - l.runErr = saveErr - } - } + checkpointed = true + checkpointErr = l.saveTurnLoopCheckpoint(ctx, checkpointID, cp) } else if l.loadCheckpointID != "" { _ = l.deleteTurnLoopCheckpoint(ctx, l.loadCheckpointID) } + var takeLateOnce sync.Once + var takeLateResult []T + l.result = &TurnLoopExitState[T]{ ExitReason: l.runErr, UnhandledItems: unhandled, CanceledItems: l.canceledItems, + StopCause: l.stopSig.getStopCause(), + Checkpointed: checkpointed, + CheckpointErr: checkpointErr, + TakeLateItems: func() []T { + takeLateOnce.Do(func() { + l.lateMu.Lock() + takeLateResult = append([]T{}, l.lateItems...) + l.lateSealed = true + l.lateMu.Unlock() + }) + return takeLateResult + }, } l.preemptSig.drainAll() diff --git a/adk/turn_loop_test.go b/adk/turn_loop_test.go index 72753fc5e..1b8b2c86d 100644 --- a/adk/turn_loop_test.go +++ b/adk/turn_loop_test.go @@ -1853,7 +1853,9 @@ func TestTurnLoop_CheckpointSaveError_ReturnsError(t *testing.T) { loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) exit := loop.Wait() assert.Error(t, exit.ExitReason) - assert.Contains(t, exit.ExitReason.Error(), "write failed") + assert.True(t, exit.Checkpointed) + assert.Error(t, exit.CheckpointErr) + assert.Contains(t, exit.CheckpointErr.Error(), "write failed") } func TestTurnLoop_StaleCheckpointDeletion_OnCleanResume(t *testing.T) { @@ -1916,7 +1918,7 @@ func TestTurnLoop_StaleCheckpointDeletion_OnCleanResume(t *testing.T) { func TestTurnLoop_StaleCheckpointDeletion_ContextCancel(t *testing.T) { ctx := context.Background() - store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + store := &deletableCheckpointStore{turnLoopCheckpointStore: turnLoopCheckpointStore{m: make(map[string][]byte)}} cpID := "delete-on-cancel" loop1 := NewTurnLoop(TurnLoopConfig[string]{ @@ -1968,11 +1970,10 @@ func TestTurnLoop_StaleCheckpointDeletion_ContextCancel(t *testing.T) { assert.ErrorIs(t, exit2.ExitReason, context.Canceled) store.mu.Lock() - v, exists := store.m[cpID] + _, exists = store.m[cpID] + deleteCalled := store.deleteCalled store.mu.Unlock() - deletedViaNil := exists && v == nil - deletedViaAbsence := !exists - assert.True(t, deletedViaNil || deletedViaAbsence, "stale checkpoint should be deleted when loop exits via context cancellation") + assert.True(t, deleteCalled && !exists, "stale checkpoint should be deleted when loop exits via context cancellation") } type deletableCheckpointStore struct { @@ -2325,10 +2326,11 @@ func TestTurnLoop_CheckpointSaveError_MergesWithExistingError(t *testing.T) { loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) exit := loop.Wait() assert.Error(t, exit.ExitReason) - errStr := exit.ExitReason.Error() - assert.Contains(t, errStr, "disk full") var ce *CancelError - assert.True(t, errors.As(exit.ExitReason, &ce), "should wrap original CancelError") + assert.True(t, errors.As(exit.ExitReason, &ce), "ExitReason should be CancelError, not merged with checkpoint error") + assert.True(t, exit.Checkpointed) + assert.Error(t, exit.CheckpointErr) + assert.Contains(t, exit.CheckpointErr.Error(), "disk full") } func TestTurnLoop_ResumeWithParams(t *testing.T) { @@ -3748,3 +3750,588 @@ func TestTurnLoop_PushStrategy_ConsumedInspection(t *testing.T) { loop.Stop() loop.Wait() } + +func TestTurnLoop_PushAfterStop_BufferedAsLateItems(t *testing.T) { + ctx := context.Background() + processed := make(chan string, 10) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + processed <- tc.Consumed[0] + return nil + }, + }) + + loop.Push("msg1") + <-processed + loop.Stop() + result := loop.Wait() + + // Push after stop — should be buffered as late items + ok1, _ := loop.Push("late1") + ok2, _ := loop.Push("late2") + ok3, _ := loop.Push("late3") + assert.False(t, ok1) + assert.False(t, ok2) + assert.False(t, ok3) + + late := result.TakeLateItems() + assert.Equal(t, []string{"late1", "late2", "late3"}, late) +} + +func TestTurnLoop_TakeLateItems_Idempotent(t *testing.T) { + ctx := context.Background() + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Stop() + loop.Run(ctx) + result := loop.Wait() + + loop.Push("late1") + + first := result.TakeLateItems() + second := result.TakeLateItems() + third := result.TakeLateItems() + + assert.Equal(t, []string{"late1"}, first) + assert.Equal(t, first, second, "subsequent calls should return the same slice") + assert.Equal(t, first, third, "subsequent calls should return the same slice") +} + +func TestTurnLoop_PushAfterTakeLateItems_Panics(t *testing.T) { + ctx := context.Background() + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Stop() + loop.Run(ctx) + result := loop.Wait() + + result.TakeLateItems() + + assert.PanicsWithValue(t, "TurnLoop: Push called after TakeLateItems", func() { + loop.Push("too-late") + }) +} + +func TestTurnLoop_TakeLateItems_NeverCalled_NoImpact(t *testing.T) { + ctx := context.Background() + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Push("b") + loop.Stop() + loop.Run(ctx) + result := loop.Wait() + + // Don't call TakeLateItems — verify UnhandledItems works normally + assert.Contains(t, result.UnhandledItems, "b") + assert.Nil(t, result.ExitReason) +} + +func TestTurnLoop_CheckpointErr_SeparateFromExitReason(t *testing.T) { + ctx := context.Background() + saveStore := &errorCheckpointStore{setErr: fmt.Errorf("storage unavailable")} + + loop := NewTurnLoop(TurnLoopConfig[string]{ + Store: saveStore, + CheckpointID: "cp-separate-err", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Stop() + loop.Run(ctx) + result := loop.Wait() + + // ExitReason should be nil (clean stop), checkpoint error should be separate + assert.Nil(t, result.ExitReason) + assert.True(t, result.Checkpointed) + assert.Error(t, result.CheckpointErr) + assert.Contains(t, result.CheckpointErr.Error(), "storage unavailable") +} + +func TestTurnLoop_Checkpointed_FalseWhenNoStore(t *testing.T) { + ctx := context.Background() + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Stop() + loop.Run(ctx) + result := loop.Wait() + + assert.False(t, result.Checkpointed) + assert.Nil(t, result.CheckpointErr) +} + +func TestTurnLoop_Checkpointed_FalseOnErrorExit(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + genInputErr := errors.New("gen input failed") + + firstTurnDone := make(chan struct{}) + var callCount int32 + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: "cp-err-exit", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + n := atomic.AddInt32(&callCount, 1) + if n > 1 { + return nil, genInputErr + } + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + close(firstTurnDone) + return nil + }, + }) + loop.Push("msg1") + <-firstTurnDone + loop.Push("msg2") + result := loop.Wait() + + // Loop exited from error, not Stop() — checkpoint should not be saved + assert.ErrorIs(t, result.ExitReason, genInputErr) + assert.False(t, result.Checkpointed) + assert.Nil(t, result.CheckpointErr) +} + +func TestTurnLoop_StopConcurrentWithCallbackError_NoCheckpoint(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "stop-concurrent-err" + + prepareErr := errors.New("prepare agent failed") + firstTurnDone := make(chan struct{}) + stopCalled := make(chan struct{}) + var prepareCount int32 + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + n := atomic.AddInt32(&prepareCount, 1) + if n > 1 { + // Wait until Stop() has been called so stopSig.isStopped() is true + <-stopCalled + return nil, prepareErr + } + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + close(firstTurnDone) + return nil + }, + }) + + loop.Push("msg1") + <-firstTurnDone + loop.Push("msg2") + + // Call Stop() and signal PrepareAgent to proceed with error + go func() { + loop.Stop() + close(stopCalled) + }() + + result := loop.Wait() + + // The loop may exit via Stop (clean) or via PrepareAgent error. + // If it exited via PrepareAgent error with Stop also called: + // checkpoint should NOT be saved. + if result.ExitReason != nil && !errors.As(result.ExitReason, new(*CancelError)) { + assert.ErrorIs(t, result.ExitReason, prepareErr) + assert.False(t, result.Checkpointed, "should not checkpoint when exit is caused by callback error") + } + // If Stop won the race, that's fine — checkpoint may or may not be saved + // depending on idle state. The test is about the error path. +} + +func TestTurnLoop_DeleteWithoutCheckPointDeleter_NoOp(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "no-deleter" + + // First loop: save a checkpoint + loop1 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop1.Push("a") + loop1.Stop() + loop1.Run(ctx) + loop1.Wait() + + store.mu.Lock() + _, exists := store.m[cpID] + store.mu.Unlock() + assert.True(t, exists, "checkpoint should be saved") + + // Second loop: exit via context cancel — should try to delete but store + // doesn't implement CheckPointDeleter, so checkpoint persists (no-op) + ctx2, cancel2 := context.WithCancel(ctx) + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + cancel2() + return nil + }, + }) + loop2.Push("b") + loop2.Run(ctx2) + loop2.Wait() + + // Without CheckPointDeleter, the stale checkpoint should NOT be deleted + store.mu.Lock() + v, exists := store.m[cpID] + store.mu.Unlock() + assert.True(t, exists, "checkpoint should still exist without CheckPointDeleter") + assert.NotNil(t, v, "checkpoint should not be set to nil") +} + +func TestTurnLoop_StopWithSkipCheckpoint(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "skip-cp-session" + + loop := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("a") + loop.Push("b") + loop.Stop(WithSkipCheckpoint()) + loop.Run(ctx) + + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + assert.False(t, exit.Checkpointed, "checkpoint should be skipped when WithSkipCheckpoint is used") + + store.mu.Lock() + _, exists := store.m[cpID] + store.mu.Unlock() + assert.False(t, exists, "no checkpoint should be saved when WithSkipCheckpoint is used") +} + +func TestTurnLoop_StopWithSkipCheckpoint_DeletesStaleCheckpoint(t *testing.T) { + ctx := context.Background() + store := &deletableCheckpointStore{ + turnLoopCheckpointStore: turnLoopCheckpointStore{m: make(map[string][]byte)}, + } + cpID := "skip-stale-session" + + loop1 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop1.Push("a") + loop1.Stop() + loop1.Run(ctx) + exit1 := loop1.Wait() + assert.True(t, exit1.Checkpointed) + + store.mu.Lock() + _, exists := store.m[cpID] + store.mu.Unlock() + assert.True(t, exists, "first loop should save checkpoint") + + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop2.Push("b") + loop2.Stop(WithSkipCheckpoint()) + loop2.Run(ctx) + exit2 := loop2.Wait() + assert.False(t, exit2.Checkpointed, "second loop should skip checkpoint") + + store.mu.Lock() + deleteCalled := store.deleteCalled + store.mu.Unlock() + assert.True(t, deleteCalled, "stale checkpoint should be deleted when SkipCheckpoint is used") +} + +func TestTurnLoop_StopWithStopCause(t *testing.T) { + ctx := context.Background() + cause := "user session timeout" + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("a") + loop.Stop(WithStopCause(cause)) + + exit := loop.Wait() + assert.Equal(t, cause, exit.StopCause) +} + +func TestTurnLoop_StopCause_EmptyWhenNoStop(t *testing.T) { + ctx := context.Background() + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Stop() + exit := loop.Wait() + assert.Empty(t, exit.StopCause) +} + +func TestTurnLoop_StopCause_InTurnContext(t *testing.T) { + cause := "business shutdown" + gotCause := make(chan string, 1) + agentStarted := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "slow", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + select { + case <-tc.Stopped: + gotCause <- tc.StopCause() + case <-time.After(5 * time.Second): + t.Error("timed out waiting for Stopped channel") + } + for { + if _, ok := events.Next(); !ok { + break + } + } + return nil + }, + }) + + loop.Push("msg1") + <-agentStarted + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate)), WithStopCause(cause)) + + select { + case c := <-gotCause: + assert.Equal(t, cause, c) + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for StopCause in TurnContext") + } + + exit := loop.Wait() + assert.Equal(t, cause, exit.StopCause) +} + +func TestTurnLoop_StopCause_FirstNonEmptyWins(t *testing.T) { + agentStarted := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "slow", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + for { + if _, ok := events.Next(); !ok { + break + } + } + return nil + }, + }) + + loop.Push("msg1") + <-agentStarted + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelAfterToolCalls)), WithStopCause("first cause")) + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate)), WithStopCause("second cause")) + + exit := loop.Wait() + assert.Equal(t, "first cause", exit.StopCause, "first non-empty StopCause should win") +} + +func TestTurnLoop_SkipCheckpoint_Sticky(t *testing.T) { + agentStarted := make(chan struct{}) + + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "sticky-skip-session" + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "slow", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + for { + if _, ok := events.Next(); !ok { + break + } + } + return nil + }, + }) + + loop.Push("msg1") + <-agentStarted + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelAfterToolCalls)), WithSkipCheckpoint()) + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + + exit := loop.Wait() + assert.False(t, exit.Checkpointed, "SkipCheckpoint should be sticky across multiple Stop calls") + + store.mu.Lock() + _, exists := store.m[cpID] + store.mu.Unlock() + assert.False(t, exists, "no checkpoint should be saved when SkipCheckpoint was set in any Stop call") +} diff --git a/internal/core/interrupt.go b/internal/core/interrupt.go index 174e0c47c..38ddbdae0 100644 --- a/internal/core/interrupt.go +++ b/internal/core/interrupt.go @@ -30,8 +30,12 @@ type CheckPointStore interface { } // CheckPointDeleter is an optional interface that CheckPointStore implementations -// can implement to support explicit checkpoint deletion. If the Store does not -// implement this interface, deletion is performed by writing an empty value via Set. +// can implement to support explicit checkpoint deletion. +// +// If the Store does not implement this interface, stale checkpoints will NOT be +// automatically cleaned up. The store owner is responsible for managing checkpoint +// lifecycle in that case (e.g., via TTL, external cleanup, or implementing this +// interface). type CheckPointDeleter interface { Delete(ctx context.Context, checkPointID string) error } From f25c591cb316aeeecd041542f59b211ddff8114e Mon Sep 17 00:00:00 2001 From: Born Date: Wed, 8 Apr 2026 14:13:14 +0800 Subject: [PATCH 50/65] feat(compose): support tool name and argument aliases in ToolsNode (#931) --- compose/tool_alias_test.go | 1178 ++++++++++++++++++++++++++++++++++++ compose/tool_node.go | 303 +++++++++- 2 files changed, 1459 insertions(+), 22 deletions(-) create mode 100644 compose/tool_alias_test.go diff --git a/compose/tool_alias_test.go b/compose/tool_alias_test.go new file mode 100644 index 000000000..487132cbe --- /dev/null +++ b/compose/tool_alias_test.go @@ -0,0 +1,1178 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/schema" +) + +type searchArgs struct { + Query string `json:"query"` +} + +func TestToolNameAliases(t *testing.T) { + ctx := context.Background() + + // Create test tool + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search for information", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string", Desc: "Search query"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + return "search result", nil + }) + + // Configure aliases + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"search_v1", "query", "find"}, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Test calling tool with alias + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "search_v1", // Using alias + Arguments: `{"query": "test"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Equal(t, "call_1", output[0].ToolCallID) + assert.Contains(t, output[0].Content, "search result") +} + +type searchArgsWithLimit struct { + Query string `json:"query"` + Limit int `json:"limit"` +} + +func TestArgumentsAliases(t *testing.T) { + ctx := context.Background() + + receivedArgs := "" + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search for information", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + "limit": {Type: "integer"}, + }), + }, func(ctx context.Context, args *searchArgsWithLimit) (string, error) { + b, _ := json.Marshal(args) + receivedArgs = string(b) + return "result", nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + ArgumentsAliases: map[string][]string{ + "query": {"q", "search_term"}, + "limit": {"max_results", "count"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Use alias parameters + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "search", + Arguments: `{"q": "test", "max_results": 10}`, // Using aliases + }, + }, + }) + + _, err = node.Invoke(ctx, input) + require.NoError(t, err) + + // Verify tool received canonical parameter names + var args map[string]any + err = json.Unmarshal([]byte(receivedArgs), &args) + require.NoError(t, err) + assert.Equal(t, "test", args["query"]) + assert.Equal(t, float64(10), args["limit"]) + assert.NotContains(t, args, "q") + assert.NotContains(t, args, "max_results") +} + +type emptyArgs struct{} + +func TestAliasConflict(t *testing.T) { + ctx := context.Background() + + tool1 := newTool(&schema.ToolInfo{Name: "search", Desc: "Search"}, func(ctx context.Context, args *emptyArgs) (string, error) { + return "result", nil + }) + tool2 := newTool(&schema.ToolInfo{Name: "query", Desc: "Query"}, func(ctx context.Context, args *emptyArgs) (string, error) { + return "result", nil + }) + + t.Run("tool name alias conflict", func(t *testing.T) { + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1, tool2}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + }, + "query": { + NameAliases: []string{"find"}, // Conflict: find already used by search + }, + }, + } + + _, err := NewToolNode(ctx, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "conflicts with an alias already registered for") + }) + + t.Run("tool name alias conflicts with canonical name", func(t *testing.T) { + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1, tool2}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"query"}, // Conflict: "query" is tool2's canonical name + }, + }, + } + + _, err := NewToolNode(ctx, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "conflicts with existing tool's canonical name") + }) + + t.Run("argument alias conflict", func(t *testing.T) { + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + "limit": {"q"}, // Conflict: q maps to multiple parameters + }, + }, + }, + } + + _, err := NewToolNode(ctx, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "conflicting arg alias") + }) + + t.Run("arg alias conflicts with existing schema property", func(t *testing.T) { + searchWithParams := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + "limit": {Type: "integer"}, + }), + }, func(ctx context.Context, args *emptyArgs) (string, error) { + return "result", nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchWithParams}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + ArgumentsAliases: map[string][]string{ + "limit": {"query"}, // "query" is already a schema property + }, + }, + }, + } + + _, err := NewToolNode(ctx, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "conflicts with existing schema property") + }) +} + +func TestArgumentsAliasesWithHandler(t *testing.T) { + ctx := context.Background() + + executionOrder := []string{} + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + executionOrder = append(executionOrder, "tool_invoke") + return "result", nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + }, + }, + }, + ToolArgumentsHandler: func(ctx context.Context, name, args string) (string, error) { + executionOrder = append(executionOrder, "args_handler") + // Handler receives the original model-returned name (alias) + assert.Equal(t, "search", name) + // Verify alias remapping has already been done + var m map[string]any + err := json.Unmarshal([]byte(args), &m) + require.NoError(t, err) + assert.Contains(t, m, "query") + assert.NotContains(t, m, "q") + return args, nil + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Call with alias name "find" and alias arg "q" + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"q": "test"}`, + }, + }, + }) + + _, err = node.Invoke(ctx, input) + require.NoError(t, err) + + // Verify execution order: alias remapping → ToolArgumentsHandler → tool execution + assert.Equal(t, []string{"args_handler", "tool_invoke"}, executionOrder) +} + +func TestNonExistentToolInAliasConfig(t *testing.T) { + ctx := context.Background() + + tool1 := newTool(&schema.ToolInfo{Name: "search", Desc: "Search"}, func(ctx context.Context, args *emptyArgs) (string, error) { + return "result", nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1}, + ToolAliases: map[string]ToolAliasConfig{ + "non_existent_tool": { // Non-existent tool + NameAliases: []string{"alias1"}, + }, + }, + } + + // Should not error — non-existent tool alias configs are silently skipped + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // The existing tool should still work normally + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "search", + Arguments: `{}`, + }, + }, + }) + output, err := node.Invoke(ctx, input) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "result") +} + +type weatherArgs struct { + Location string `json:"location"` +} + +func TestToolAliasesE2E(t *testing.T) { + ctx := context.Background() + + // Create multiple tools + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search for information", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + "limit": {Type: "integer"}, + }), + }, func(ctx context.Context, args *searchArgsWithLimit) (string, error) { + return "search result", nil + }) + + weatherTool := newTool(&schema.ToolInfo{ + Name: "weather", + Desc: "Get weather information", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "location": {Type: "string"}, + }), + }, func(ctx context.Context, args *weatherArgs) (string, error) { + return "weather result", nil + }) + + // Configure aliases for multiple tools + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool, weatherTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"search_v1", "query"}, + ArgumentsAliases: map[string][]string{ + "query": {"q", "search_term"}, + "limit": {"max_results"}, + }, + }, + "weather": { + NameAliases: []string{"get_weather"}, + ArgumentsAliases: map[string][]string{ + "location": {"loc", "city"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Construct message with multiple tool calls using different aliases + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "search_v1", // Tool name alias + Arguments: `{"q": "test", "max_results": 5}`, // Parameter aliases + }, + }, + { + ID: "call_2", + Function: schema.FunctionCall{ + Name: "get_weather", // Tool name alias + Arguments: `{"city": "Beijing"}`, // Parameter alias + }, + }, + }) + + output, err := node.Invoke(ctx, input) + require.NoError(t, err) + require.Len(t, output, 2) + + // Verify both tools executed successfully + assert.Equal(t, "call_1", output[0].ToolCallID) + assert.Equal(t, "call_2", output[1].ToolCallID) + assert.Contains(t, output[0].Content, "search result") + assert.Contains(t, output[1].Content, "weather result") +} + +func TestRemapArgsEdgeCases(t *testing.T) { + aliasMap := map[string]string{"q": "query"} + + t.Run("empty string", func(t *testing.T) { + result, err := remapArgs("", aliasMap) + assert.NoError(t, err) + assert.Equal(t, "", result) + }) + + t.Run("whitespace only", func(t *testing.T) { + result, err := remapArgs(" ", aliasMap) + assert.NoError(t, err) + assert.Equal(t, " ", result) + }) + + t.Run("non-object JSON", func(t *testing.T) { + result, err := remapArgs(`"hello"`, aliasMap) + assert.NoError(t, err) + assert.Equal(t, `"hello"`, result) + }) + + t.Run("JSON array", func(t *testing.T) { + result, err := remapArgs(`[1,2,3]`, aliasMap) + assert.NoError(t, err) + assert.Equal(t, `[1,2,3]`, result) + }) + + t.Run("invalid JSON", func(t *testing.T) { + result, err := remapArgs(`{invalid`, aliasMap) + assert.NoError(t, err) + assert.Equal(t, `{invalid`, result) + }) + + t.Run("alias and canonical both present", func(t *testing.T) { + // When both alias "q" and canonical "query" exist, alias is kept as-is (not deleted, not overwritten) + result, err := remapArgs(`{"q": "alias_val", "query": "canonical_val"}`, aliasMap) + assert.NoError(t, err) + var m map[string]any + require.NoError(t, json.Unmarshal([]byte(result), &m)) + assert.Equal(t, "canonical_val", m["query"]) + assert.Equal(t, "alias_val", m["q"]) + }) + + t.Run("unknown fields preserved", func(t *testing.T) { + result, err := remapArgs(`{"q": "test", "unknown_field": 42}`, aliasMap) + assert.NoError(t, err) + var m map[string]any + require.NoError(t, json.Unmarshal([]byte(result), &m)) + assert.Equal(t, "test", m["query"]) + assert.NotContains(t, m, "q") + assert.Equal(t, float64(42), m["unknown_field"]) + }) +} + +func TestCanonicalNameCallWithAliasConfigured(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + return "result: " + args.Query, nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Call with canonical name and canonical arg — should work normally + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "search", + Arguments: `{"query": "hello"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "result: hello") +} + +func TestEmptyAliasValidation(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{Name: "search", Desc: "Search"}, func(ctx context.Context, args *emptyArgs) (string, error) { + return "result", nil + }) + + t.Run("empty name alias", func(t *testing.T) { + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{""}, + }, + }, + } + _, err := NewToolNode(ctx, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "empty name alias") + }) + + t.Run("empty arg alias", func(t *testing.T) { + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + ArgumentsAliases: map[string][]string{ + "query": {""}, + }, + }, + }, + } + _, err := NewToolNode(ctx, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "empty argument alias") + }) + + t.Run("empty canonical arg key", func(t *testing.T) { + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + ArgumentsAliases: map[string][]string{ + "": {"q"}, + }, + }, + }, + } + _, err := NewToolNode(ctx, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "empty canonical argument key") + }) +} + +func TestNameAliasSameAsCanonical(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{Name: "search", Desc: "Search"}, func(ctx context.Context, args *emptyArgs) (string, error) { + return "result", nil + }) + + // Alias same as canonical name — should be tolerated (skip, no error) + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"search", "find"}, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Both canonical and alias should work + for _, name := range []string{"search", "find"} { + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: name, + Arguments: `{}`, + }, + }, + }) + output, err := node.Invoke(ctx, input) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "result") + } +} + +func TestToolAliasesWithDynamicToolList(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + return "search result: " + args.Query, nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Use dynamic ToolList via option — alias should still work + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"q": "dynamic"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input, WithToolList(searchTool)) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "search result: dynamic") +} + +func TestToolNameAliasesStream(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search for information", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + return "stream result: " + args.Query, nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"q": "hello"}`, + }, + }, + }) + + reader, err := node.Stream(ctx, input) + require.NoError(t, err) + + var chunks [][]*schema.Message + for { + chunk, err := reader.Recv() + if err != nil { + break + } + chunks = append(chunks, chunk) + } + + msgs, err := schema.ConcatMessageArray(chunks) + require.NoError(t, err) + require.Len(t, msgs, 1) + assert.Equal(t, "call_1", msgs[0].ToolCallID) + assert.Contains(t, msgs[0].Content, "stream result: hello") +} + +func TestEnhancedToolWithAliases(t *testing.T) { + ctx := context.Background() + + enhancedTool := &enhancedInvokableTool{ + info: &schema.ToolInfo{ + Name: "search", + Desc: "Enhanced search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, + fn: func(ctx context.Context, input *schema.ToolArgument) (*schema.ToolResult, error) { + return &schema.ToolResult{ + Parts: []schema.ToolOutputPart{ + {Type: schema.ToolPartTypeText, Text: "enhanced: " + input.Text}, + }, + }, nil + }, + } + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{enhancedTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Call with alias name and alias arg + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"q": "test"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Equal(t, "call_1", output[0].ToolCallID) + // Verify arg alias was remapped: "q" → "query" in the JSON passed to enhanced tool + assert.Contains(t, output[0].UserInputMultiContent[0].Text, "enhanced:") +} + +func TestDynamicToolListAliasRemoved(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + return "search result", nil + }) + + weatherTool := newTool(&schema.ToolInfo{ + Name: "weather", + Desc: "Weather", + }, func(ctx context.Context, args *emptyArgs) (string, error) { + return "weather result", nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool, weatherTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Dynamic tool list only contains weatherTool — "search" and its alias "find" should not be available + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{}`, + }, + }, + }) + + _, err = node.Invoke(ctx, input, WithToolList(weatherTool)) + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +func TestToolAliasesOptionOverridesGlobal(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + return "search result: " + args.Query, nil + }) + + weatherTool := newTool(&schema.ToolInfo{ + Name: "weather", + Desc: "Weather", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "location": {Type: "string"}, + }), + }, func(ctx context.Context, args *weatherArgs) (string, error) { + return "weather result: " + args.Location, nil + }) + + // Global aliases: search has alias "find" + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool, weatherTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + t.Run("opt ToolAliases overrides global in Invoke", func(t *testing.T) { + // opt.ToolAliases defines "lookup" as alias for search (not "find") + optAliases := map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"lookup"}, + ArgumentsAliases: map[string][]string{ + "query": {"keyword"}, + }, + }, + } + + // "lookup" should work with opt aliases + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "lookup", + Arguments: `{"keyword": "test"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input, WithToolList(searchTool), WithToolAliases(optAliases)) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "search result: test") + + // "find" (global alias) should NOT work when opt.ToolAliases is set + input2 := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_2", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"q": "test"}`, + }, + }, + }) + + _, err = node.Invoke(ctx, input2, WithToolList(searchTool), WithToolAliases(optAliases)) + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) + + t.Run("opt ToolAliases overrides global in Stream", func(t *testing.T) { + optAliases := map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"lookup"}, + ArgumentsAliases: map[string][]string{ + "query": {"keyword"}, + }, + }, + } + + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "lookup", + Arguments: `{"keyword": "stream_test"}`, + }, + }, + }) + + reader, err := node.Stream(ctx, input, WithToolList(searchTool), WithToolAliases(optAliases)) + require.NoError(t, err) + + var chunks [][]*schema.Message + for { + chunk, err := reader.Recv() + if err != nil { + break + } + chunks = append(chunks, chunk) + } + + msgs, err := schema.ConcatMessageArray(chunks) + require.NoError(t, err) + require.Len(t, msgs, 1) + assert.Contains(t, msgs[0].Content, "search result: stream_test") + }) + + t.Run("nil opt ToolAliases falls back to global filtered", func(t *testing.T) { + // No WithToolAliases — should use global "find" alias, filtered by ToolList + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"q": "fallback"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input, WithToolList(searchTool)) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "search result: fallback") + }) + + t.Run("opt ToolAliases only without ToolList replaces global", func(t *testing.T) { + // Only WithToolAliases, no WithToolList — should use global tools with opt aliases + optAliases := map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"lookup"}, + ArgumentsAliases: map[string][]string{ + "query": {"keyword"}, + }, + }, + } + + // "lookup" (opt alias) should work + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "lookup", + Arguments: `{"keyword": "only_alias"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input, WithToolAliases(optAliases)) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "search result: only_alias") + + // "find" (global alias) should NOT work when opt.ToolAliases replaces global + input2 := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_2", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"q": "test"}`, + }, + }, + }) + + _, err = node.Invoke(ctx, input2, WithToolAliases(optAliases)) + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) + + t.Run("opt ToolAliases only without ToolList in Stream", func(t *testing.T) { + optAliases := map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"lookup"}, + }, + } + + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "lookup", + Arguments: `{"query": "stream_only_alias"}`, + }, + }, + }) + + reader, err := node.Stream(ctx, input, WithToolAliases(optAliases)) + require.NoError(t, err) + + var chunks [][]*schema.Message + for { + chunk, err := reader.Recv() + if err != nil { + break + } + chunks = append(chunks, chunk) + } + + msgs, err := schema.ConcatMessageArray(chunks) + require.NoError(t, err) + require.Len(t, msgs, 1) + assert.Contains(t, msgs[0].Content, "search result: stream_only_alias") + }) +} + +func TestAliasConfigForToolAddedViaOption(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + return "search result: " + args.Query, nil + }) + + weatherTool := newTool(&schema.ToolInfo{ + Name: "weather", + Desc: "Weather", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "location": {Type: "string"}, + }), + }, func(ctx context.Context, args *weatherArgs) (string, error) { + return "weather result: " + args.Location, nil + }) + + // New with only searchTool, but alias config includes weather tool + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + }, + }, + "weather": { + NameAliases: []string{"forecast"}, + ArgumentsAliases: map[string][]string{ + "location": {"loc"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + t.Run("weather alias works when tool passed via option", func(t *testing.T) { + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "forecast", + Arguments: `{"loc": "Beijing"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input, WithToolList(searchTool, weatherTool)) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "weather result: Beijing") + }) + + t.Run("search alias still works with option tool list", func(t *testing.T) { + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"q": "test"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input, WithToolList(searchTool, weatherTool)) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "search result: test") + }) +} + +func TestOptionWithToolListAndToolAliases(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + return "search result: " + args.Query, nil + }) + + weatherTool := newTool(&schema.ToolInfo{ + Name: "weather", + Desc: "Weather", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "location": {Type: "string"}, + }), + }, func(ctx context.Context, args *weatherArgs) (string, error) { + return "weather result: " + args.Location, nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + t.Run("opt aliases override global when both tool list and aliases provided", func(t *testing.T) { + optAliases := map[string]ToolAliasConfig{ + "weather": { + NameAliases: []string{"forecast"}, + ArgumentsAliases: map[string][]string{ + "location": {"loc"}, + }, + }, + } + + // "forecast" should work via opt aliases + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "forecast", + Arguments: `{"loc": "Shanghai"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input, WithToolList(searchTool, weatherTool), WithToolAliases(optAliases)) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "weather result: Shanghai") + + // "find" (global alias) should NOT work when opt aliases override + input2 := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_2", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"query": "test"}`, + }, + }, + }) + + _, err = node.Invoke(ctx, input2, WithToolList(searchTool, weatherTool), WithToolAliases(optAliases)) + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) +} diff --git a/compose/tool_node.go b/compose/tool_node.go index a8f98a866..f65037e90 100644 --- a/compose/tool_node.go +++ b/compose/tool_node.go @@ -18,11 +18,16 @@ package compose import ( "context" + "encoding/json" "errors" "fmt" "runtime/debug" + "sort" + "strings" "sync" + "github.com/bytedance/sonic" + "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/components/tool" @@ -33,6 +38,8 @@ import ( type toolsNodeOptions struct { ToolOptions []tool.Option ToolList []tool.BaseTool + + ToolAliases map[string]ToolAliasConfig } // ToolsNodeOption is the option func type for ToolsNode. @@ -52,6 +59,15 @@ func WithToolList(tool ...tool.BaseTool) ToolsNodeOption { } } +// WithToolAliases sets the tool aliases for the ToolsNode call option. +// When used with WithToolList, it overrides the global alias configuration for the dynamic tool list. +// When used alone (without WithToolList), it replaces the global alias configuration while keeping the original tool list. +func WithToolAliases(toolAliases map[string]ToolAliasConfig) ToolsNodeOption { + return func(o *toolsNodeOptions) { + o.ToolAliases = toolAliases + } +} + // ToolsNode represents a node capable of executing tools within a graph. // The Graph Node interface is defined as follows: // @@ -62,6 +78,7 @@ func WithToolList(tool ...tool.BaseTool) ToolsNodeOption { // Output: An array of ToolMessage where the order of elements corresponds to the order of ToolCalls in the input type ToolsNode struct { tuple *toolsTuple + tools []tool.BaseTool unknownToolHandler func(ctx context.Context, name, input string) (string, error) executeSequentially bool toolArgumentsHandler func(ctx context.Context, name, input string) (string, error) @@ -69,6 +86,7 @@ type ToolsNode struct { streamToolCallMiddlewares []StreamableToolMiddleware enhancedToolCallMiddlewares []EnhancedInvokableToolMiddleware enhancedStreamToolCallMiddlewares []EnhancedStreamableToolMiddleware + toolAliasConfigs map[string]ToolAliasConfig } // ToolInput represents the input parameters for a tool call execution. @@ -150,11 +168,30 @@ type ToolMiddleware struct { EnhancedStreamable EnhancedStreamableToolMiddleware } +// ToolAliasConfig configures name and argument aliases for a single tool. +type ToolAliasConfig struct { + // NameAliases are alternative names for this tool. + // If the model returns any of these names, it will be resolved to the canonical tool name. + NameAliases []string + + // ArgumentsAliases maps canonical argument keys to their alias lists. + // key=canonical, value=[]alias. Applied to top-level JSON keys before tool execution. + // Example: {"query": ["q", "search_term"], "limit": ["max_results", "count"]} + ArgumentsAliases map[string][]string +} + // ToolsNodeConfig is the config for ToolsNode. type ToolsNodeConfig struct { // Tools specify the list of tools can be called which are BaseTool but must implement InvokableTool or StreamableTool. Tools []tool.BaseTool + // ToolAliases configures name and argument aliases for tools. + // Key is the canonical tool name, value defines its aliases. + // This field is optional. When provided, tool name aliases will be resolved during tool dispatch, + // and argument aliases will be remapped before ToolArgumentsHandler (if configured) and tool execution. + // Execution order: ArgumentsAliases remapping → ToolArgumentsHandler → tool execution + ToolAliases map[string]ToolAliasConfig + // UnknownToolsHandler handles tool calls for non-existent tools when LLM hallucinates. // This field is optional. When not set, calling a non-existent tool will result in an error. // When provided, if the LLM attempts to call a tool that doesn't exist in the Tools list, @@ -219,13 +256,22 @@ func NewToolNode(ctx context.Context, conf *ToolsNodeConfig) (*ToolsNode, error) } } - tuple, err := convTools(ctx, conf.Tools, middlewares, streamMiddlewares, enhancedInvokableMiddlewares, enhancedStreamableMiddlewares) + params := convToolsParams{ + tools: conf.Tools, + aliasConfigs: conf.ToolAliases, + } + params.middlewares.invokable = middlewares + params.middlewares.streamable = streamMiddlewares + params.middlewares.enhancedInvokable = enhancedInvokableMiddlewares + params.middlewares.enhancedStreamable = enhancedStreamableMiddlewares + tuple, err := convTools(ctx, params) if err != nil { return nil, err } return &ToolsNode{ tuple: tuple, + tools: conf.Tools, unknownToolHandler: conf.UnknownToolsHandler, executeSequentially: conf.ExecuteSequentially, toolArgumentsHandler: conf.ToolArgumentsHandler, @@ -233,6 +279,7 @@ func NewToolNode(ctx context.Context, conf *ToolsNodeConfig) (*ToolsNode, error) streamToolCallMiddlewares: streamMiddlewares, enhancedToolCallMiddlewares: enhancedInvokableMiddlewares, enhancedStreamToolCallMiddlewares: enhancedStreamableMiddlewares, + toolAliasConfigs: conf.ToolAliases, }, nil } @@ -273,19 +320,184 @@ type toolsTuple struct { streamEndpoints []StreamableToolEndpoint enhancedInvokableEndpoints []EnhancedInvokableToolEndpoint enhancedStreamableEndpoints []EnhancedStreamableToolEndpoint + // argsAliasMap stores reverse argument alias mappings for each tool. + // key: canonical tool name, value: map[aliasKey]canonicalKey (alias → canonical direction) + argsAliasMap map[string]map[string]string + // canonicalNames stores the canonical name for each tool index + canonicalNames []string + // toolInfos stores the ToolInfo for each tool index, used for alias validation + toolInfos []*schema.ToolInfo +} + +// remapArgs replaces alias keys in the JSON arguments string with canonical keys. +// aliasMap: alias → canonical mapping +func remapArgs(args string, aliasMap map[string]string) (string, error) { + if len(aliasMap) == 0 { + return args, nil + } + + trimmed := strings.TrimSpace(args) + if trimmed == "" || trimmed[0] != '{' { + return args, nil + } + + var m map[string]json.RawMessage + if err := sonic.Unmarshal([]byte(args), &m); err != nil { + return args, nil + } + + changed := false + for alias, canonical := range aliasMap { + if v, ok := m[alias]; ok { + // Only replace if canonical key doesn't exist. + // If both alias and canonical are present (e.g. {"q":"a","query":"b"}), + // the alias key is kept as-is and passed through as an unknown field. + if _, exists := m[canonical]; !exists { + m[canonical] = v + delete(m, alias) + changed = true + } + } + } + + if !changed { + return args, nil + } + + b, err := sonic.Marshal(m) + return string(b), err +} + +type convToolsParams struct { + tools []tool.BaseTool + middlewares struct { + invokable []InvokableToolMiddleware + streamable []StreamableToolMiddleware + enhancedInvokable []EnhancedInvokableToolMiddleware + enhancedStreamable []EnhancedStreamableToolMiddleware + } + aliasConfigs map[string]ToolAliasConfig +} + +func (t *toolsTuple) applyAliasConfigs(aliasConfigs map[string]ToolAliasConfig) error { + t.argsAliasMap = make(map[string]map[string]string) + + sortedToolNames := make([]string, 0, len(aliasConfigs)) + for toolName := range aliasConfigs { + sortedToolNames = append(sortedToolNames, toolName) + } + sort.Strings(sortedToolNames) + + for _, toolName := range sortedToolNames { + aliasConfig := aliasConfigs[toolName] + var ( + toolIdx int + exists bool + ) + if toolIdx, exists = t.indexes[toolName]; !exists { + continue + } + + if err := t.applyNameAliases(toolName, toolIdx, aliasConfig.NameAliases); err != nil { + return err + } + + if err := t.applyArgsAliases(toolName, toolIdx, aliasConfig.ArgumentsAliases); err != nil { + return err + } + } + + return nil +} + +// applyNameAliases validates and registers name aliases for a single tool into the indexes map. +func (t *toolsTuple) applyNameAliases(toolName string, toolIdx int, nameAliases []string) error { + for _, alias := range nameAliases { + if strings.TrimSpace(alias) == "" { + return fmt.Errorf("tool '%s' has empty name alias", toolName) + } + if existingIdx, conflict := t.indexes[alias]; conflict { + if existingIdx != toolIdx { + conflictToolName := t.canonicalNames[existingIdx] + if alias == conflictToolName { + return fmt.Errorf("tool '%s': name alias '%s' conflicts with existing tool's canonical name", toolName, alias) + } + return fmt.Errorf("tool '%s': name alias '%s' conflicts with an alias already registered for tool '%s'", toolName, alias, conflictToolName) + } + continue + } + t.indexes[alias] = toolIdx + } + return nil } -func convTools(ctx context.Context, tools []tool.BaseTool, ms []InvokableToolMiddleware, sms []StreamableToolMiddleware, - ems []EnhancedInvokableToolMiddleware, esms []EnhancedStreamableToolMiddleware) (*toolsTuple, error) { +// applyArgsAliases validates argument aliases against the tool schema and builds a reverse alias map for a single tool. +func (t *toolsTuple) applyArgsAliases(toolName string, toolIdx int, argumentsAliases map[string][]string) error { + if len(argumentsAliases) == 0 { + return nil + } + + schemaKeys := make(map[string]bool) + if info := t.toolInfos[toolIdx]; info != nil && info.ParamsOneOf != nil { + js, err := info.ParamsOneOf.ToJSONSchema() + if err != nil { + return fmt.Errorf("tool '%s': failed to parse JSON schema for alias validation: %w", toolName, err) + } + if js != nil && js.Properties != nil { + for pair := js.Properties.Oldest(); pair != nil; pair = pair.Next() { + schemaKeys[pair.Key] = true + } + } + } + + reverseMap := make(map[string]string) + sortedCanonicals := make([]string, 0, len(argumentsAliases)) + for canonical := range argumentsAliases { + sortedCanonicals = append(sortedCanonicals, canonical) + } + sort.Strings(sortedCanonicals) + + for _, canonical := range sortedCanonicals { + aliases := argumentsAliases[canonical] + if strings.TrimSpace(canonical) == "" { + return fmt.Errorf("tool '%s' has empty canonical argument key", toolName) + } + if strings.Contains(canonical, ".") { + return fmt.Errorf("tool '%s' has unsupported '.' in canonical argument key '%s': nested field matching is not yet supported", + toolName, canonical) + } + for _, alias := range aliases { + if strings.TrimSpace(alias) == "" { + return fmt.Errorf("tool '%s' has empty argument alias for canonical key '%s'", toolName, canonical) + } + if schemaKeys[alias] { + return fmt.Errorf("tool '%s' has arg alias '%s' that conflicts with existing schema property '%s'", + toolName, alias, alias) + } + if existingCanonical, conflict := reverseMap[alias]; conflict { + return fmt.Errorf("tool '%s' has conflicting arg alias '%s' mapped to both '%s' and '%s'", + toolName, alias, existingCanonical, canonical) + } + reverseMap[alias] = canonical + } + } + t.argsAliasMap[toolName] = reverseMap + + return nil +} + +func convTools(ctx context.Context, params convToolsParams) (*toolsTuple, error) { ret := &toolsTuple{ indexes: make(map[string]int), - meta: make([]*executorMeta, len(tools)), - endpoints: make([]InvokableToolEndpoint, len(tools)), - streamEndpoints: make([]StreamableToolEndpoint, len(tools)), - enhancedInvokableEndpoints: make([]EnhancedInvokableToolEndpoint, len(tools)), - enhancedStreamableEndpoints: make([]EnhancedStreamableToolEndpoint, len(tools)), + meta: make([]*executorMeta, len(params.tools)), + endpoints: make([]InvokableToolEndpoint, len(params.tools)), + streamEndpoints: make([]StreamableToolEndpoint, len(params.tools)), + enhancedInvokableEndpoints: make([]EnhancedInvokableToolEndpoint, len(params.tools)), + enhancedStreamableEndpoints: make([]EnhancedStreamableToolEndpoint, len(params.tools)), + canonicalNames: make([]string, len(params.tools)), + toolInfos: make([]*schema.ToolInfo, len(params.tools)), } - for idx, bt := range tools { + for idx, bt := range params.tools { tl, err := bt.Info(ctx) if err != nil { return nil, fmt.Errorf("(NewToolNode) failed to get tool info at idx= %d: %w", idx, err) @@ -310,19 +522,19 @@ func convTools(ctx context.Context, tools []tool.BaseTool, ms []InvokableToolMid meta = parseExecutorInfoFromComponent(components.ComponentOfTool, bt) if st, ok = bt.(tool.StreamableTool); ok { - streamable = wrapStreamToolCall(st, sms, !meta.isComponentCallbackEnabled) + streamable = wrapStreamToolCall(st, params.middlewares.streamable, !meta.isComponentCallbackEnabled) } if it, ok = bt.(tool.InvokableTool); ok { - invokable = wrapToolCall(it, ms, !meta.isComponentCallbackEnabled) + invokable = wrapToolCall(it, params.middlewares.invokable, !meta.isComponentCallbackEnabled) } if eiTool, ok = bt.(tool.EnhancedInvokableTool); ok { - enhancedInvokable = wrapEnhancedInvokableToolCall(eiTool, ems, !meta.isComponentCallbackEnabled) + enhancedInvokable = wrapEnhancedInvokableToolCall(eiTool, params.middlewares.enhancedInvokable, !meta.isComponentCallbackEnabled) } if esTool, ok = bt.(tool.EnhancedStreamableTool); ok { - enhancedStreamable = wrapEnhancedStreamableToolCall(esTool, esms, !meta.isComponentCallbackEnabled) + enhancedStreamable = wrapEnhancedStreamableToolCall(esTool, params.middlewares.enhancedStreamable, !meta.isComponentCallbackEnabled) } if st == nil && it == nil && eiTool == nil && esTool == nil { @@ -348,7 +560,16 @@ func convTools(ctx context.Context, tools []tool.BaseTool, ms []InvokableToolMid ret.streamEndpoints[idx] = streamable ret.enhancedInvokableEndpoints[idx] = enhancedInvokable ret.enhancedStreamableEndpoints[idx] = enhancedStreamable + ret.canonicalNames[idx] = toolName + ret.toolInfos[idx] = tl } + + if len(params.aliasConfigs) > 0 { + if err := ret.applyAliasConfigs(params.aliasConfigs); err != nil { + return nil, err + } + } + return ret, nil } @@ -616,14 +837,27 @@ func (tn *ToolsNode) genToolCallTasks(ctx context.Context, tuple *toolsTuple, toolCallTasks[i].useEnhanced = false } + // Get canonical tool name for looking up argument aliases + canonicalToolName := tuple.canonicalNames[index] + + // Process argument aliases remapping + args := toolCall.Function.Arguments + if aliasMap, hasAliases := tuple.argsAliasMap[canonicalToolName]; hasAliases { + remappedArgs, err := remapArgs(args, aliasMap) + if err != nil { + return nil, fmt.Errorf("failed to remap args for tool[name:%s]: %w", canonicalToolName, err) + } + args = remappedArgs + } + if tn.toolArgumentsHandler != nil { - arg, err := tn.toolArgumentsHandler(ctx, toolCall.Function.Name, toolCall.Function.Arguments) + arg, err := tn.toolArgumentsHandler(ctx, canonicalToolName, args) if err != nil { - return nil, fmt.Errorf("failed to executed tool[name:%s arguments:%s] arguments handler: %w", toolCall.Function.Name, toolCall.Function.Arguments, err) + return nil, fmt.Errorf("failed to executed tool[name:%s arguments:%s] arguments handler: %w", toolCall.Function.Name, args, err) } toolCallTasks[i].arg = arg } else { - toolCallTasks[i].arg = toolCall.Function.Arguments + toolCallTasks[i].arg = args } } } @@ -782,6 +1016,31 @@ func parallelRunToolCall(ctx context.Context, wg.Wait() } +// buildTupleFromOpts rebuilds a toolsTuple when call options override tools or aliases. +func (tn *ToolsNode) buildTupleFromOpts(ctx context.Context, opt *toolsNodeOptions) (*toolsTuple, error) { + tools := opt.ToolList + if tools == nil { + tools = tn.tools + } + aliasConfigs := opt.ToolAliases + if aliasConfigs == nil { + aliasConfigs = tn.toolAliasConfigs + } + p := convToolsParams{ + tools: tools, + aliasConfigs: aliasConfigs, + } + p.middlewares.invokable = tn.toolCallMiddlewares + p.middlewares.streamable = tn.streamToolCallMiddlewares + p.middlewares.enhancedInvokable = tn.enhancedToolCallMiddlewares + p.middlewares.enhancedStreamable = tn.enhancedStreamToolCallMiddlewares + tuple, err := convTools(ctx, p) + if err != nil { + return nil, fmt.Errorf("failed to convert tool list from call option: %w", err) + } + return tuple, nil +} + // Invoke calls the tools and collects the results of invokable tools. // it's parallel if there are multiple tool calls in the input message. func (tn *ToolsNode) Invoke(ctx context.Context, input *schema.Message, @@ -789,11 +1048,11 @@ func (tn *ToolsNode) Invoke(ctx context.Context, input *schema.Message, opt := getToolsNodeOptions(opts...) tuple := tn.tuple - if opt.ToolList != nil { + if opt.ToolList != nil || opt.ToolAliases != nil { var err error - tuple, err = convTools(ctx, opt.ToolList, tn.toolCallMiddlewares, tn.streamToolCallMiddlewares, tn.enhancedToolCallMiddlewares, tn.enhancedStreamToolCallMiddlewares) + tuple, err = tn.buildTupleFromOpts(ctx, opt) if err != nil { - return nil, fmt.Errorf("failed to convert tool list from call option: %w", err) + return nil, err } } @@ -891,11 +1150,11 @@ func (tn *ToolsNode) Stream(ctx context.Context, input *schema.Message, opt := getToolsNodeOptions(opts...) tuple := tn.tuple - if opt.ToolList != nil { + if opt.ToolList != nil || opt.ToolAliases != nil { var err error - tuple, err = convTools(ctx, opt.ToolList, tn.toolCallMiddlewares, tn.streamToolCallMiddlewares, tn.enhancedToolCallMiddlewares, tn.enhancedStreamToolCallMiddlewares) + tuple, err = tn.buildTupleFromOpts(ctx, opt) if err != nil { - return nil, fmt.Errorf("failed to convert tool list from call option: %w", err) + return nil, err } } From 87cea07b55ddb1c6075fc3abdbf2193713cf2cfa Mon Sep 17 00:00:00 2001 From: Ryo Date: Thu, 9 Apr 2026 11:41:14 +0800 Subject: [PATCH 51/65] feat(adk): add failover support for ChatModel (#885) * fix: rebase error Change-Id: If20fa78dba82a1c177c8ec47090050ea8c1354ed * feat(adk): add failover support for ChatModel Change-Id: Ice1b513b4b509e7b540316da9119ff3d529c9bae * feat(adk): add failover support for ChatModel Change-Id: Ice1b513b4b509e7b540316da9119ff3d529c9bae * feat(adk): add failover support for ChatModel Change-Id: Id5483447b74322f6dd495bdd3b994c001094569d * feat(adk): make Name and Description optional in ChatModelAgentConfig * feat(adk): add callback lifecycle management to failoverProxyModel - Extract prepareCallbacks method to reuse callback setup logic between Generate and Stream methods - Add callbacks.ReuseHandlers with proper RunInfo (model type + component) before each failover model invocation so handlers receive correct identity - Add explicit OnStart/OnEnd/OnError callback invocations in Generate and Stream since failoverProxyModel declares IsCallbacksEnabled() = true and the outer layer skips automatic callback injection Change-Id: I0150529024125251828cf6f77c8247aa464b1f84 * fix(adk): preserve partial result in failoverProxyModel.Generate on error Return result instead of nil when target.Generate fails, so that the outer failoverModelWrapper can pass the partial output message to ShouldFailover for inspection. Change-Id: I32d86151a6e133f1a58d5e988bccf42d831a646c * refactor(adk): use EnsureRunInfo in failoverProxyModel and separate ctx for callbacks - Replace manual RunInfo construction + ReuseHandlers with callbacks.EnsureRunInfo for cleaner RunInfo setup - Use nCtx (from EnsureRunInfo) for target model invocation and original ctx for OnStart/OnEnd/OnError callback lifecycle Change-Id: I1d5982d0e1ceeaf8f6648b9c40c229b6a2b07ab8 --------- Co-authored-by: shentong.martin --- adk/chatmodel.go | 94 ++-- adk/chatmodel_test.go | 41 ++ adk/failover_chatmodel.go | 466 +++++++++++++++++++ adk/failover_chatmodel_test.go | 697 ++++++++++++++++++++++++++++ adk/handler.go | 6 + adk/prebuilt/deep/deep.go | 12 +- adk/prebuilt/deep/task_tool.go | 23 +- adk/prebuilt/deep/task_tool_test.go | 1 + adk/wrappers.go | 141 ++++-- adk/wrappers_failover_test.go | 181 ++++++++ adk/wrappers_retry_failover_test.go | 411 ++++++++++++++++ 11 files changed, 1993 insertions(+), 80 deletions(-) create mode 100644 adk/failover_chatmodel.go create mode 100644 adk/failover_chatmodel_test.go create mode 100644 adk/wrappers_failover_test.go create mode 100644 adk/wrappers_retry_failover_test.go diff --git a/adk/chatmodel.go b/adk/chatmodel.go index 4b736d51d..83d8ffd4a 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -44,6 +44,9 @@ type chatModelAgentExecCtx struct { runtimeReturnDirectly map[string]bool generator *AsyncGenerator[*AgentEvent] cancelCtx *cancelContext + + // failoverLastSuccessModel is the last success model only used in failover middleware. + failoverLastSuccessModel model.BaseChatModel } func (e *chatModelAgentExecCtx) send(event *AgentEvent) { @@ -260,13 +263,14 @@ type ChatModelAgentConfig struct { // Model call lifecycle (outermost to innermost wrapper chain): // 1. AgentMiddleware.BeforeChatModel (hook, runs before model call) // 2. ChatModelAgentMiddleware.BeforeModelRewriteState (hook, can modify state before model call) - // 3. retryModelWrapper (internal - retries on failure, if configured) - // 4. eventSenderModelWrapper (internal - sends model response events) - // 5. ChatModelAgentMiddleware.WrapModel (wrapper, first registered is outermost) - // 6. callbackInjectionModelWrapper (internal - injects callbacks if not enabled) - // 7. Model.Generate/Stream - // 8. ChatModelAgentMiddleware.AfterModelRewriteState (hook, can modify state after model call) - // 9. AgentMiddleware.AfterChatModel (hook, runs after model call) + // 3. failoverModelWrapper (internal - failover between models, if configured) + // 4. retryModelWrapper (internal - retries on failure, if configured) + // 5. eventSenderModelWrapper (internal - sends model response events) + // 6. ChatModelAgentMiddleware.WrapModel (wrapper, first registered is outermost) + // 7. callbackInjectionModelWrapper (internal - injects callbacks if not enabled; when failover is enabled, this is handled per-model inside failoverProxyModel instead) + // 8. failoverProxyModel (internal - dispatches to selected failover model, if configured) / Model.Generate/Stream + // 9. ChatModelAgentMiddleware.AfterModelRewriteState (hook, can modify state after model call) + // 10. AgentMiddleware.AfterChatModel (hook, runs after model call) // // Custom Event Sender Position: // By default, events are sent after all user middlewares (WrapModel) have processed the output, @@ -337,6 +341,13 @@ type ChatModelAgentConfig struct { // based on the configured policy. // Optional. If nil, no retry will be performed. ModelRetryConfig *ModelRetryConfig + + // ModelFailoverConfig configures failover behavior for the ChatModel. + // When set, the agent will first try the last successful model (initially the configured Model), + // and on failure, call GetFailoverModel to select alternate models. + // Model field is still required as it serves as the initial model. + // Optional. If nil, no failover will be performed. + ModelFailoverConfig *ModelFailoverConfig } type ChatModelAgent struct { @@ -362,7 +373,8 @@ type ChatModelAgent struct { handlers []ChatModelAgentMiddleware middlewares []AgentMiddleware - modelRetryConfig *ModelRetryConfig + modelRetryConfig *ModelRetryConfig + modelFailoverConfig *ModelFailoverConfig once sync.Once run runFunc @@ -386,6 +398,17 @@ type runFunc func(ctx context.Context, p *runParams) // NewChatModelAgent constructs a chat model-backed agent with the provided config. func NewChatModelAgent(ctx context.Context, config *ChatModelAgentConfig) (*ChatModelAgent, error) { + if config.ModelFailoverConfig != nil { + if config.ModelFailoverConfig.GetFailoverModel == nil { + return nil, errors.New("ModelFailoverConfig.GetFailoverModel is required when ModelFailoverConfig is set") + } + + // ShouldFailover is required when ModelFailoverConfig is set + if config.ModelFailoverConfig.ShouldFailover == nil { + return nil, errors.New("ModelFailoverConfig.ShouldFailover is required when ModelFailoverConfig is set") + } + } + if config.Model == nil { return nil, errors.New("agent 'Model' is required") } @@ -420,18 +443,19 @@ func NewChatModelAgent(ctx context.Context, config *ChatModelAgentConfig) (*Chat }) return &ChatModelAgent{ - name: config.Name, - description: config.Description, - instruction: config.Instruction, - model: config.Model, - toolsConfig: tc, - genModelInput: genInput, - exit: config.Exit, - outputKey: config.OutputKey, - maxIterations: config.MaxIterations, - handlers: config.Handlers, - middlewares: config.Middlewares, - modelRetryConfig: config.ModelRetryConfig, + name: config.Name, + description: config.Description, + instruction: config.Instruction, + model: config.Model, + toolsConfig: tc, + genModelInput: genInput, + exit: config.Exit, + outputKey: config.OutputKey, + maxIterations: config.MaxIterations, + handlers: config.Handlers, + middlewares: config.Middlewares, + modelRetryConfig: config.ModelRetryConfig, + modelFailoverConfig: config.ModelFailoverConfig, }, nil } @@ -799,10 +823,11 @@ func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { ctx = withCancelContext(ctx, cancelCtx) wrappedModel := buildModelWrappers(a.model, &modelWrapperConfig{ - handlers: a.handlers, - middlewares: a.middlewares, - retryConfig: a.modelRetryConfig, - cancelContext: cancelCtx, + handlers: a.handlers, + middlewares: a.middlewares, + retryConfig: a.modelRetryConfig, + failoverConfig: a.modelFailoverConfig, + cancelContext: cancelCtx, }) chain := compose.NewChain[noToolsInput, Message]( @@ -841,8 +866,9 @@ func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { } ctx = withChatModelAgentExecCtx(ctx, &chatModelAgentExecCtx{ - generator: p.generator, - cancelCtx: cancelCtx, + generator: p.generator, + cancelCtx: cancelCtx, + failoverLastSuccessModel: a.model, }) // Pre-execution cancel check @@ -888,10 +914,11 @@ func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) ( model: a.model, toolsConfig: &bc.toolsNodeConf, modelWrapperConf: &modelWrapperConfig{ - handlers: a.handlers, - middlewares: a.middlewares, - retryConfig: a.modelRetryConfig, - toolInfos: bc.toolInfos, + handlers: a.handlers, + middlewares: a.middlewares, + retryConfig: a.modelRetryConfig, + failoverConfig: a.modelFailoverConfig, + toolInfos: bc.toolInfos, }, toolsReturnDirectly: bc.returnDirectly, agentName: a.name, @@ -951,9 +978,10 @@ func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) ( } ctx = withChatModelAgentExecCtx(ctx, &chatModelAgentExecCtx{ - runtimeReturnDirectly: p.returnDirectly, - generator: p.generator, - cancelCtx: cancelCtx, + runtimeReturnDirectly: p.returnDirectly, + generator: p.generator, + cancelCtx: cancelCtx, + failoverLastSuccessModel: a.model, }) // Pre-execution cancel check diff --git a/adk/chatmodel_test.go b/adk/chatmodel_test.go index 3a2f920dd..f3ff6ea05 100644 --- a/adk/chatmodel_test.go +++ b/adk/chatmodel_test.go @@ -2057,3 +2057,44 @@ func TestPreprocessComposeCheckpoint_MigrateErrorIsReturned(t *testing.T) { _, err := preprocessComposeCheckpoint(in) assert.Error(t, err) } + +func TestNewChatModelAgent_FailoverConfigValidation(t *testing.T) { + ctx := context.Background() + cm := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return schema.AssistantMessage("ok", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("ok", nil)}), nil + }, + } + + t.Run("missing GetFailoverModel", func(t *testing.T) { + _, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: cm, + ModelFailoverConfig: &ModelFailoverConfig{ + ShouldFailover: func(context.Context, *schema.Message, error) bool { return true }, + }, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "ModelFailoverConfig.GetFailoverModel") + }) + + t.Run("missing ShouldFailover", func(t *testing.T) { + _, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: cm, + ModelFailoverConfig: &ModelFailoverConfig{ + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return cm, nil, nil + }, + }, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "ModelFailoverConfig.ShouldFailover") + }) +} diff --git a/adk/failover_chatmodel.go b/adk/failover_chatmodel.go new file mode 100644 index 000000000..2a467ed76 --- /dev/null +++ b/adk/failover_chatmodel.go @@ -0,0 +1,466 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "fmt" + "io" + "log" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +type failoverCurrentModelKey struct{} + +type failoverCurrentModel struct { + model model.BaseChatModel +} + +func setFailoverCurrentModel(ctx context.Context, currentModel model.BaseChatModel) context.Context { + return context.WithValue(ctx, failoverCurrentModelKey{}, &failoverCurrentModel{ + model: currentModel, + }) +} + +func getFailoverCurrentModel(ctx context.Context) *failoverCurrentModel { + if fm, ok := ctx.Value(failoverCurrentModelKey{}).(*failoverCurrentModel); ok { + return fm + } + return nil +} + +type failoverHasMoreAttemptsKey struct{} + +// withFailoverHasMoreAttempts sets a flag in context indicating whether additional failover +// attempts remain after the current one. This is read by buildErrWrapper to decide whether +// stream errors should be wrapped as WillRetryError. +func withFailoverHasMoreAttempts(ctx context.Context, hasMore bool) context.Context { + return context.WithValue(ctx, failoverHasMoreAttemptsKey{}, hasMore) +} + +// getFailoverHasMoreAttempts returns true if the current failover attempt has more attempts +// after it, false otherwise (including when no failover context is present). +func getFailoverHasMoreAttempts(ctx context.Context) bool { + v, _ := ctx.Value(failoverHasMoreAttemptsKey{}).(bool) + return v +} + +type failoverProxyModel struct { +} + +func (m *failoverProxyModel) prepareCallbacks(ctx context.Context) (context.Context, model.BaseChatModel, error) { + current := getFailoverCurrentModel(ctx) + if current == nil || current.model == nil { + return nil, nil, errors.New("failover current model not found in context") + } + + typ, _ := components.GetType(current.model) + ctx = callbacks.EnsureRunInfo(ctx, typ, components.ComponentOfChatModel) + + target := current.model + if !components.IsCallbacksEnabled(target) { + target = (&callbackInjectionModelWrapper{}).WrapModel(target) + } + + return ctx, target, nil +} + +func (m *failoverProxyModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + nCtx, target, err := m.prepareCallbacks(ctx) + if err != nil { + return nil, err + } + + ctx = callbacks.OnStart(ctx, input) + + result, err := target.Generate(nCtx, input, opts...) + if err != nil { + callbacks.OnError(ctx, err) + return result, err + } + + callbacks.OnEnd(ctx, result) + + return result, nil +} + +func (m *failoverProxyModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + nCtx, target, err := m.prepareCallbacks(ctx) + if err != nil { + return nil, err + } + + ctx = callbacks.OnStart(ctx, input) + + result, err := target.Stream(nCtx, input, opts...) + if err != nil { + callbacks.OnError(ctx, err) + return nil, err + } + + _, wrappedStream := callbacks.OnEndWithStreamOutput(ctx, result) + return wrappedStream, nil +} + +func (m *failoverProxyModel) IsCallbacksEnabled() bool { + return true +} + +func (m *failoverProxyModel) GetType() string { + return "FailoverProxyModel" +} + +// FailoverContext contains context information during failover process. +type FailoverContext struct { + // FailoverAttempt is the current failover attempt number, starting from 1. + FailoverAttempt uint + + // InputMessages is the original input messages before any transformation. + InputMessages []*schema.Message + + // LastOutputMessage is the output message from the last failed attempt. + // May be nil if no output was produced. For streaming, this may be a partial message + // already received before the stream error. + LastOutputMessage *schema.Message + + // LastErr is the error from the last failed attempt that triggered this failover. + // + // Note: When ModelRetryConfig is also configured, LastErr will be a *RetryExhaustedError + // (if retries were exhausted) rather than the original model error. The original error + // can be retrieved via RetryExhaustedError.LastErr. + LastErr error +} + +// ModelFailoverConfig configures failover behavior for ChatModel. +// When configured, each ChatModel call first tries the last successful model (initially the configured Model), +// and if that fails, calls GetFailoverModel to select alternate models. +type ModelFailoverConfig struct { + // MaxRetries specifies the maximum number of failover attempts. + // + // When failover is triggered, GetFailoverModel will be called up to MaxRetries times + // (FailoverAttempt starts from 1). If GetFailoverModel returns an error, failover + // stops immediately and that error is returned. + // + // A value of 0 means no failover (GetFailoverModel will not be called). + // A value of 1 means GetFailoverModel may be called once. + // + // Note: if lastSuccessModel is set (from a previous successful call), it will be tried + // first before calling GetFailoverModel. + MaxRetries uint + + // ShouldFailover determines whether to fail over to the next model when an error occurs. + // It receives the output message (may be nil if no output is available) and the error (non-nil on failure). + // For streaming errors, outputMessage can carry a partial message accumulated before the error. + // + // Note: When ModelRetryConfig is also configured, outputErr will be a *RetryExhaustedError + // (if retries were exhausted) rather than the original model error. Use errors.As to extract + // the RetryExhaustedError and access RetryExhaustedError.LastErr for the original error: + // + // var retryErr *adk.RetryExhaustedError + // if errors.As(outputErr, &retryErr) { + // // retryErr.LastErr contains the original model error + // } + // + // Note: When the context itself is cancelled (ctx.Err() != nil), failover will stop immediately + // regardless of this function. However, if the model returns context.Canceled or context.DeadlineExceeded + // as an error while the context is still active, this function will still be called. + // Should not be nil when ModelFailoverConfig is set. + // Return true to fail over to the next model, false to stop and return the current result/error. + ShouldFailover func(ctx context.Context, outputMessage *schema.Message, outputErr error) bool + + // GetFailoverModel is called when a model call fails and ShouldFailover returns true. + // It selects the next model to use for the failover attempt and optionally transforms input messages. + // It receives the failover context containing attempt number (starting from 1), original input, and last result. + // Return values: + // - failoverModel: The model to use for this failover attempt. + // - failoverModelInputMessages: The transformed input messages for the failover model. If nil, will use original input. + // - failoverErr: If non-nil, failover stops and this error is returned. + // Should not be nil when ModelFailoverConfig is set via ChatModelAgentConfig. + GetFailoverModel func(ctx context.Context, failoverCtx *FailoverContext) ( + failoverModel model.BaseChatModel, failoverModelInputMessages []*schema.Message, failoverErr error) +} + +func getLastSuccessModel(ctx context.Context) model.BaseChatModel { + if execCtx := getChatModelAgentExecCtx(ctx); execCtx != nil { + return execCtx.failoverLastSuccessModel + } + return nil +} + +func setLastSuccessModel(ctx context.Context, m model.BaseChatModel) { + if execCtx := getChatModelAgentExecCtx(ctx); execCtx != nil { + execCtx.failoverLastSuccessModel = m + } +} + +type failoverModelWrapper struct { + config *ModelFailoverConfig + inner model.BaseChatModel +} + +func newFailoverModelWrapper(inner model.BaseChatModel, config *ModelFailoverConfig) *failoverModelWrapper { + return &failoverModelWrapper{ + config: config, + inner: inner, + } +} + +func (f *failoverModelWrapper) needFailover(ctx context.Context, outputMessage *schema.Message, outputErr error) bool { + if ctx.Err() != nil { + return false + } + + // ShouldFailover is validated at agent construction; nil here indicates a programmer error. + return f.config.ShouldFailover(ctx, outputMessage, outputErr) +} + +func (f *failoverModelWrapper) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + // Defensive: GetFailoverModel is validated non-nil at agent construction. + if f.config.GetFailoverModel == nil { + return f.inner.Generate(ctx, input, opts...) + } + + var lastOutputMessage *schema.Message + var lastErr error + + // Try lastSuccessModel first if available. + if lastSuccess := getLastSuccessModel(ctx); lastSuccess != nil { + if err := ctx.Err(); err != nil { + return nil, err + } + + modelCtx := setFailoverCurrentModel(ctx, lastSuccess) + modelCtx = withFailoverHasMoreAttempts(modelCtx, f.config.MaxRetries > 0) + result, err := f.inner.Generate(modelCtx, input, opts...) + if err == nil { + return result, nil + } + + lastOutputMessage = result + lastErr = err + + if !f.needFailover(ctx, result, err) { + return result, err + } + + log.Printf("failover ChatModel.Generate lastSuccessModel failed: %v", err) + } + + for attempt := uint(1); attempt <= f.config.MaxRetries; attempt++ { + if err := ctx.Err(); err != nil { + return nil, err + } + + failoverCtx := &FailoverContext{ + FailoverAttempt: attempt, + InputMessages: input, + LastOutputMessage: lastOutputMessage, + LastErr: lastErr, + } + + currentModel, currentInput, err := f.config.GetFailoverModel(ctx, failoverCtx) + if err != nil { + return nil, err + } + if currentModel == nil { + return nil, fmt.Errorf("failover GetFailoverModel returned nil model at attempt %d", attempt) + } + + if currentInput == nil { + currentInput = input + } + + modelCtx := setFailoverCurrentModel(ctx, currentModel) + modelCtx = withFailoverHasMoreAttempts(modelCtx, attempt < f.config.MaxRetries) + result, err := f.inner.Generate(modelCtx, currentInput, opts...) + lastOutputMessage = result + lastErr = err + + if err == nil { + setLastSuccessModel(ctx, currentModel) + return result, nil + } + + if !f.needFailover(ctx, result, err) { + return result, err + } + + if attempt < f.config.MaxRetries { + log.Printf("failover ChatModel.Generate attempt %d failed: %v", attempt, err) + } + } + + return lastOutputMessage, lastErr +} + +func (f *failoverModelWrapper) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) ( + *schema.StreamReader[*schema.Message], error) { + // Defensive: GetFailoverModel is validated non-nil at agent construction. + if f.config.GetFailoverModel == nil { + return f.inner.Stream(ctx, input, opts...) + } + + var lastOutputMessage *schema.Message + var lastErr error + + // Try lastSuccessModel first if available. + if lastSuccess := getLastSuccessModel(ctx); lastSuccess != nil { + if err := ctx.Err(); err != nil { + return nil, err + } + + modelCtx := setFailoverCurrentModel(ctx, lastSuccess) + modelCtx = withFailoverHasMoreAttempts(modelCtx, f.config.MaxRetries > 0) + stream, err := f.inner.Stream(modelCtx, input, opts...) + if err != nil { + lastErr = err + if !f.needFailover(ctx, nil, err) { + return nil, err + } + log.Printf("failover ChatModel.Stream lastSuccessModel failed: %v", err) + } else { + copies := stream.Copy(2) + checkCopy := copies[0] + returnCopy := copies[1] + + outMsg, streamErr := consumeStream(checkCopy) + if streamErr != nil { + lastOutputMessage = outMsg + lastErr = streamErr + returnCopy.Close() + + if !f.needFailover(ctx, outMsg, streamErr) { + return nil, streamErr + } + log.Printf("failover ChatModel.Stream lastSuccessModel failed: %v", streamErr) + } else { + return returnCopy, nil + } + } + } + + for attempt := uint(1); attempt <= f.config.MaxRetries; attempt++ { + if err := ctx.Err(); err != nil { + return nil, err + } + + failoverCtx := &FailoverContext{ + FailoverAttempt: attempt, + InputMessages: input, + LastOutputMessage: lastOutputMessage, + LastErr: lastErr, + } + + currentModel, currentInput, err := f.config.GetFailoverModel(ctx, failoverCtx) + if err != nil { + return nil, err + } + if currentModel == nil { + return nil, fmt.Errorf("failover GetFailoverModel returned nil model at attempt %d", attempt) + } + + if currentInput == nil { + currentInput = input + } + + modelCtx := setFailoverCurrentModel(ctx, currentModel) + modelCtx = withFailoverHasMoreAttempts(modelCtx, attempt < f.config.MaxRetries) + stream, err := f.inner.Stream(modelCtx, currentInput, opts...) + if err != nil { + lastErr = err + lastOutputMessage = nil + + if !f.needFailover(ctx, nil, err) { + return nil, err + } + + if attempt < f.config.MaxRetries { + log.Printf("failover ChatModel.Stream attempt %d failed: %v", attempt, err) + } + continue + } + + // The stream returned by f.inner.Stream is already Copy'd by the inner eventSender layer: one + // copy is forwarded to the client in real time via events. Therefore consuming a copy here does + // NOT block client-side streaming. + // + // We Copy the stream into two readers: + // - checkCopy: consumed synchronously to surface mid-stream errors and decide whether to fail over. + // - returnCopy: returned to the caller (stateModelWrapper), which also consumes synchronously to + // build state (AfterModelRewriteState), so waiting here adds no extra latency. + // + // If checkCopy errors and failover is allowed, we close returnCopy and retry with the next model. + // Otherwise we return returnCopy. + // + // NOTE on duplicate events during failover: when a retry happens, events from the failed attempt + // may already have been emitted to the client, and the retry will emit a new stream. Client-side + // handlers are expected to handle multiple rounds (e.g., reset on retry or deduplicate by attempt + // metadata). + copies := stream.Copy(2) + checkCopy := copies[0] + returnCopy := copies[1] + + outMsg, streamErr := consumeStream(checkCopy) + if streamErr != nil { + lastOutputMessage = outMsg + lastErr = streamErr + returnCopy.Close() + + if !f.needFailover(ctx, outMsg, streamErr) { + return nil, streamErr + } + + if attempt < f.config.MaxRetries { + log.Printf("failover ChatModel.Stream attempt %d failed: %v", attempt, streamErr) + } + continue + } + + setLastSuccessModel(ctx, currentModel) + return returnCopy, nil + } + + return nil, lastErr +} + +func consumeStream(stream *schema.StreamReader[*schema.Message]) (*schema.Message, error) { + defer stream.Close() + chunks := make([]*schema.Message, 0) + for { + chunk, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + // ignore concat error + msg, _ := schema.ConcatMessages(chunks) + return msg, err + } + + chunks = append(chunks, chunk) + } + + // Stream completed successfully (EOF). ConcatMessages error is not a stream error, + // so ignore it to avoid incorrectly triggering failover. + msg, _ := schema.ConcatMessages(chunks) + return msg, nil +} diff --git a/adk/failover_chatmodel_test.go b/adk/failover_chatmodel_test.go new file mode 100644 index 000000000..82866e994 --- /dev/null +++ b/adk/failover_chatmodel_test.go @@ -0,0 +1,697 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "io" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +type fakeChatModel struct { + generate func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) + stream func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) + callbacksEnabled bool +} + +func (m *fakeChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + return m.generate(ctx, input, opts...) +} + +func (m *fakeChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return m.stream(ctx, input, opts...) +} + +func (m *fakeChatModel) IsCallbacksEnabled() bool { + return m.callbacksEnabled +} + +func drainMessageStream(sr *schema.StreamReader[*schema.Message]) ([]*schema.Message, error) { + defer sr.Close() + var out []*schema.Message + for { + chunk, err := sr.Recv() + if err == io.EOF { + return out, nil + } + if err != nil { + return out, err + } + out = append(out, chunk) + } +} + +func streamWithMidError(chunks []*schema.Message, err error) *schema.StreamReader[*schema.Message] { + sr, sw := schema.Pipe[*schema.Message](2) + go func() { + defer sw.Close() + for _, c := range chunks { + sw.Send(c, nil) + } + sw.Send(nil, err) + }() + return sr +} + +func streamWithMidErrorControlled(chunks []*schema.Message, err error, firstSent chan struct{}, release chan struct{}) *schema.StreamReader[*schema.Message] { + sr, sw := schema.Pipe[*schema.Message](2) + go func() { + defer sw.Close() + for i, c := range chunks { + sw.Send(c, nil) + if i == 0 && firstSent != nil { + close(firstSent) + if release != nil { + <-release + } + } + } + sw.Send(nil, err) + }() + return sr +} + +func TestFailoverCurrentModelContext(t *testing.T) { + t.Run("set and get", func(t *testing.T) { + ctx := context.Background() + m := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return schema.AssistantMessage("ok", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("ok", nil)}), nil + }, + } + ctx = setFailoverCurrentModel(ctx, m) + got := getFailoverCurrentModel(ctx) + require.NotNil(t, got) + require.Same(t, m, got.model) + }) + + t.Run("wrong type", func(t *testing.T) { + ctx := context.WithValue(context.Background(), failoverCurrentModelKey{}, "bad") + require.Nil(t, getFailoverCurrentModel(ctx)) + }) + + t.Run("missing", func(t *testing.T) { + require.Nil(t, getFailoverCurrentModel(context.Background())) + }) +} + +func TestFailoverProxyModel(t *testing.T) { + t.Run("generate missing context", func(t *testing.T) { + p := &failoverProxyModel{} + _, err := p.Generate(context.Background(), []*schema.Message{schema.UserMessage("hi")}) + require.Error(t, err) + }) + + t.Run("stream missing context", func(t *testing.T) { + p := &failoverProxyModel{} + _, err := p.Stream(context.Background(), []*schema.Message{schema.UserMessage("hi")}) + require.Error(t, err) + }) + + t.Run("generate routes to current model", func(t *testing.T) { + var called int32 + target := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&called, 1) + return schema.AssistantMessage("routed", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("routed", nil)}), nil + }, + } + ctx := setFailoverCurrentModel(context.Background(), target) + p := &failoverProxyModel{} + msg, err := p.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.Equal(t, "routed", msg.Content) + require.Equal(t, int32(1), atomic.LoadInt32(&called)) + }) +} + +func TestFailoverModelWrapper_Generate(t *testing.T) { + t.Run("delegates when GetFailoverModel nil", func(t *testing.T) { + var called int32 + inner := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&called, 1) + return schema.AssistantMessage("inner", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("inner", nil)}), nil + }, + } + w := newFailoverModelWrapper(inner, &ModelFailoverConfig{ + MaxRetries: 2, + ShouldFailover: func(context.Context, *schema.Message, error) bool { return true }, + GetFailoverModel: nil, + }) + msg, err := w.Generate(context.Background(), []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.Equal(t, "inner", msg.Content) + require.Equal(t, int32(1), atomic.LoadInt32(&called)) + }) + + t.Run("failover to second model", func(t *testing.T) { + wantErr := errors.New("first failed") + var shouldCalls int32 + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m1Calls, 1) + return nil, wantErr + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m2Calls, 1) + return schema.AssistantMessage("ok", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + return errors.Is(err, wantErr) + }, + GetFailoverModel: func(_ context.Context, failoverCtx *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.Equal(t, uint(1), failoverCtx.FailoverAttempt) + return m2, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + msg, err := w.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.Equal(t, "ok", msg.Content) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) + }) + + t.Run("canceled error delegates to ShouldFailover", func(t *testing.T) { + var shouldCalls int32 + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, context.Canceled + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 5, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + // User decides to stop on canceled error + return !errors.Is(err, context.Canceled) + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m1, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + _, err := w.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.ErrorIs(t, err, context.Canceled) + // ShouldFailover is called once and returns false, stopping failover + require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) + }) + + t.Run("stops when GetFailoverModel returns error", func(t *testing.T) { + wantErr := errors.New("get model failed") + var called int32 + inner := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&called, 1) + return schema.AssistantMessage("unused", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 3, + ShouldFailover: func(context.Context, *schema.Message, error) bool { return true }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return nil, nil, wantErr + }, + } + + w := newFailoverModelWrapper(inner, cfg) + _, err := w.Generate(context.Background(), []*schema.Message{schema.UserMessage("hi")}) + require.ErrorIs(t, err, wantErr) + require.Equal(t, int32(0), atomic.LoadInt32(&called)) + }) + + t.Run("stops when GetFailoverModel returns nil model", func(t *testing.T) { + cfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(context.Context, *schema.Message, error) bool { return true }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return nil, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + msg, err := w.Generate(context.Background(), []*schema.Message{schema.UserMessage("hi")}) + require.Nil(t, msg) + require.Error(t, err) + require.ErrorContains(t, err, "GetFailoverModel returned nil model") + }) +} + +func TestFailoverModelWrapper_Stream(t *testing.T) { + t.Run("returns stream when first attempt succeeds", func(t *testing.T) { + var shouldCalls int32 + in := schema.UserMessage("hi") + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, input []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + require.Len(t, input, 1) + require.Same(t, in, input[0]) + return schema.StreamReaderFromArray([]*schema.Message{ + schema.AssistantMessage("a", nil), + schema.AssistantMessage("b", nil), + }), nil + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 0, + ShouldFailover: func(context.Context, *schema.Message, error) bool { + atomic.AddInt32(&shouldCalls, 1) + return false + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m1, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := w.Stream(ctx, []*schema.Message{in}) + require.NoError(t, err) + msgs, err := drainMessageStream(sr) + require.NoError(t, err) + require.Len(t, msgs, 2) + require.Equal(t, "a", msgs[0].Content) + require.Equal(t, "b", msgs[1].Content) + require.Equal(t, int32(0), atomic.LoadInt32(&shouldCalls)) + }) + + t.Run("failover when Stream returns error immediately", func(t *testing.T) { + wantErr := errors.New("stream init failed") + var shouldCalls int32 + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return nil, wantErr + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m2Calls, 1) + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("ok", nil)}), nil + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + return errors.Is(err, wantErr) + }, + GetFailoverModel: func(_ context.Context, failoverCtx *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.Equal(t, uint(1), failoverCtx.FailoverAttempt) + return m2, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := w.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + msgs, err := drainMessageStream(sr) + require.NoError(t, err) + require.Len(t, msgs, 1) + require.Equal(t, "ok", msgs[0].Content) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) + }) + + t.Run("failover when stream errors mid-way", func(t *testing.T) { + streamErr := errors.New("mid error") + var shouldCalls int32 + var seenOutput atomic.Value + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return streamWithMidError([]*schema.Message{ + schema.AssistantMessage("p1", nil), + schema.AssistantMessage("p2", nil), + }, streamErr), nil + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m2Calls, 1) + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("final", nil)}), nil + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, out *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + if errors.Is(err, streamErr) && out != nil { + seenOutput.Store(out.Content) + } + return errors.Is(err, streamErr) + }, + GetFailoverModel: func(_ context.Context, failoverCtx *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.Equal(t, uint(1), failoverCtx.FailoverAttempt) + return m2, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := w.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + msgs, err := drainMessageStream(sr) + require.NoError(t, err) + require.Len(t, msgs, 1) + require.Equal(t, "final", msgs[0].Content) + require.Equal(t, "p1p2", seenOutput.Load()) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) + }) + + t.Run("stop when ShouldFailover returns false for mid-way error", func(t *testing.T) { + streamErr := errors.New("mid error") + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return streamWithMidError([]*schema.Message{schema.AssistantMessage("p", nil)}, streamErr), nil + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 3, + ShouldFailover: func(context.Context, *schema.Message, error) bool { + return false + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m1, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := w.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.Nil(t, sr) + require.ErrorIs(t, err, streamErr) + }) + + t.Run("canceled mid-way error delegates to ShouldFailover", func(t *testing.T) { + var shouldCalls int32 + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return streamWithMidError([]*schema.Message{schema.AssistantMessage("p", nil)}, context.Canceled), nil + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 3, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + // User decides to stop on canceled error + return !errors.Is(err, context.Canceled) + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m1, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := w.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.Nil(t, sr) + require.ErrorIs(t, err, context.Canceled) + // ShouldFailover is called once and returns false, stopping failover + require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) + }) + + t.Run("stop when Stream returns error immediately and ShouldFailover returns false", func(t *testing.T) { + wantErr := errors.New("stream init failed") + var shouldCalls int32 + var m1Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return nil, wantErr + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 3, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + require.ErrorIs(t, err, wantErr) + return false + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m1, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := w.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.Nil(t, sr) + require.ErrorIs(t, err, wantErr) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) + }) + + t.Run("stops when GetFailoverModel returns nil model", func(t *testing.T) { + cfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(context.Context, *schema.Message, error) bool { return true }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return nil, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + sr, err := w.Stream(context.Background(), []*schema.Message{schema.UserMessage("hi")}) + require.Nil(t, sr) + require.Error(t, err) + require.ErrorContains(t, err, "GetFailoverModel returned nil model") + }) + + t.Run("stops when GetFailoverModel returns error", func(t *testing.T) { + wantErr := errors.New("get model failed") + var called int32 + inner := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&called, 1) + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("unused", nil)}), nil + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 3, + ShouldFailover: func(context.Context, *schema.Message, error) bool { return true }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return nil, nil, wantErr + }, + } + + w := newFailoverModelWrapper(inner, cfg) + sr, err := w.Stream(context.Background(), []*schema.Message{schema.UserMessage("hi")}) + require.Nil(t, sr) + require.ErrorIs(t, err, wantErr) + require.Equal(t, int32(0), atomic.LoadInt32(&called)) + }) + + t.Run("stops when ctx canceled during mid-way error handling", func(t *testing.T) { + midErr := errors.New("mid error") + var shouldCalls int32 + var m1Calls int32 + var m2Calls int32 + firstSent := make(chan struct{}) + release := make(chan struct{}) + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return streamWithMidErrorControlled( + []*schema.Message{schema.AssistantMessage("p", nil)}, + midErr, + firstSent, + release, + ), nil + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m2Calls, 1) + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("unused", nil)}), nil + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(context.Context, *schema.Message, error) bool { + atomic.AddInt32(&shouldCalls, 1) + return true + }, + GetFailoverModel: func(_ context.Context, failoverCtx *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.Equal(t, uint(1), failoverCtx.FailoverAttempt) + return m2, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + baseCtx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + ctx, cancel := context.WithCancel(baseCtx) + type result struct { + sr *schema.StreamReader[*schema.Message] + err error + } + ch := make(chan result, 1) + go func() { + sr, err := w.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + ch <- result{sr: sr, err: err} + }() + + <-firstSent + cancel() + close(release) + + res := <-ch + if res.sr != nil { + res.sr.Close() + } + require.Nil(t, res.sr) + require.ErrorIs(t, res.err, midErr) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(0), atomic.LoadInt32(&m2Calls)) + require.Equal(t, int32(0), atomic.LoadInt32(&shouldCalls)) + }) +} diff --git a/adk/handler.go b/adk/handler.go index 854063f16..d18abc965 100644 --- a/adk/handler.go +++ b/adk/handler.go @@ -64,6 +64,12 @@ type ModelContext struct { // Used by EventSenderModelWrapper to wrap stream errors appropriately. ModelRetryConfig *ModelRetryConfig + // ModelFailoverConfig contains the failover configuration for the model. + // This is populated at request time from the agent's ModelFailoverConfig. + // Used by EventSenderModelWrapper to wrap stream errors so that failed failover + // attempts are skipped (not treated as fatal) by the flow event processor. + ModelFailoverConfig *ModelFailoverConfig + cancelContext *cancelContext } diff --git a/adk/prebuilt/deep/deep.go b/adk/prebuilt/deep/deep.go index 48b5349a6..3918d47e4 100644 --- a/adk/prebuilt/deep/deep.go +++ b/adk/prebuilt/deep/deep.go @@ -93,6 +93,10 @@ type Config struct { Handlers []adk.ChatModelAgentMiddleware ModelRetryConfig *adk.ModelRetryConfig + // ModelFailoverConfig configures failover behavior for the ChatModel. + // When set, the agent will automatically fail over to alternative models on errors. + // This config is also propagated to the general sub-agent. + ModelFailoverConfig *adk.ModelFailoverConfig // OutputKey stores the agent's response in the session. // Optional. When set, stores output via AddSessionValue(ctx, outputKey, msg.Content). @@ -129,6 +133,7 @@ func New(ctx context.Context, cfg *Config) (adk.ResumableAgent, error) { cfg.MaxIteration, cfg.Middlewares, append(handlers, cfg.Handlers...), + cfg.ModelFailoverConfig, ) if err != nil { return nil, fmt.Errorf("failed to new task tool: %w", err) @@ -146,9 +151,10 @@ func New(ctx context.Context, cfg *Config) (adk.ResumableAgent, error) { Middlewares: cfg.Middlewares, Handlers: append(handlers, cfg.Handlers...), - GenModelInput: genModelInput, - ModelRetryConfig: cfg.ModelRetryConfig, - OutputKey: cfg.OutputKey, + GenModelInput: genModelInput, + ModelRetryConfig: cfg.ModelRetryConfig, + ModelFailoverConfig: cfg.ModelFailoverConfig, + OutputKey: cfg.OutputKey, }) } diff --git a/adk/prebuilt/deep/task_tool.go b/adk/prebuilt/deep/task_tool.go index 6235021bd..e6fcedeb3 100644 --- a/adk/prebuilt/deep/task_tool.go +++ b/adk/prebuilt/deep/task_tool.go @@ -45,8 +45,9 @@ func newTaskToolMiddleware( maxIteration int, middlewares []adk.AgentMiddleware, handlers []adk.ChatModelAgentMiddleware, + modelFailoverConfig *adk.ModelFailoverConfig, ) (adk.ChatModelAgentMiddleware, error) { - t, err := newTaskTool(ctx, taskToolDescriptionGenerator, subAgents, withoutGeneralSubAgent, cm, instruction, toolsConfig, maxIteration, middlewares, handlers) + t, err := newTaskTool(ctx, taskToolDescriptionGenerator, subAgents, withoutGeneralSubAgent, cm, instruction, toolsConfig, maxIteration, middlewares, handlers, modelFailoverConfig) if err != nil { return nil, err } @@ -71,6 +72,7 @@ func newTaskTool( MaxIteration int, middlewares []adk.AgentMiddleware, handlers []adk.ChatModelAgentMiddleware, + modelFailoverConfig *adk.ModelFailoverConfig, ) (tool.InvokableTool, error) { t := &taskTool{ subAgents: map[string]tool.InvokableTool{}, @@ -88,15 +90,16 @@ func newTaskTool( Chinese: generalAgentDescriptionChinese, }) generalAgent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ - Name: generalAgentName, - Description: agentDesc, - Instruction: Instruction, - Model: Model, - ToolsConfig: ToolsConfig, - MaxIterations: MaxIteration, - Middlewares: middlewares, - Handlers: handlers, - GenModelInput: genModelInput, + Name: generalAgentName, + Description: agentDesc, + Instruction: Instruction, + Model: Model, + ToolsConfig: ToolsConfig, + MaxIterations: MaxIteration, + Middlewares: middlewares, + Handlers: handlers, + GenModelInput: genModelInput, + ModelFailoverConfig: modelFailoverConfig, }) if err != nil { return nil, err diff --git a/adk/prebuilt/deep/task_tool_test.go b/adk/prebuilt/deep/task_tool_test.go index 91c3a7784..8d60eb452 100644 --- a/adk/prebuilt/deep/task_tool_test.go +++ b/adk/prebuilt/deep/task_tool_test.go @@ -41,6 +41,7 @@ func TestTaskTool(t *testing.T) { 10, nil, nil, + nil, ) assert.NoError(t, err) diff --git a/adk/wrappers.go b/adk/wrappers.go index e6ac617ea..b4e16d298 100644 --- a/adk/wrappers.go +++ b/adk/wrappers.go @@ -34,28 +34,35 @@ type generateEndpoint func(ctx context.Context, input []*schema.Message, opts .. type streamEndpoint func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) type modelWrapperConfig struct { - handlers []ChatModelAgentMiddleware - middlewares []AgentMiddleware - retryConfig *ModelRetryConfig - toolInfos []*schema.ToolInfo - cancelContext *cancelContext + handlers []ChatModelAgentMiddleware + middlewares []AgentMiddleware + retryConfig *ModelRetryConfig + failoverConfig *ModelFailoverConfig + toolInfos []*schema.ToolInfo + cancelContext *cancelContext } func buildModelWrappers(m model.BaseChatModel, config *modelWrapperConfig) model.BaseChatModel { var wrapped model.BaseChatModel = m - if !components.IsCallbacksEnabled(m) { + // failoverProxyModel must be the innermost wrapper to read the selected failover model from context. + if config.failoverConfig != nil { + wrapped = &failoverProxyModel{} + } + + if !components.IsCallbacksEnabled(wrapped) { wrapped = (&callbackInjectionModelWrapper{}).WrapModel(wrapped) } wrapped = &stateModelWrapper{ - inner: wrapped, - original: m, - handlers: config.handlers, - middlewares: config.middlewares, - toolInfos: config.toolInfos, - modelRetryConfig: config.retryConfig, - cancelContext: config.cancelContext, + inner: wrapped, + original: m, + handlers: config.handlers, + middlewares: config.middlewares, + toolInfos: config.toolInfos, + modelRetryConfig: config.retryConfig, + modelFailoverConfig: config.failoverConfig, + cancelContext: config.cancelContext, } return wrapped @@ -265,12 +272,17 @@ func (w *eventSenderModelWrapper) WrapModel(_ context.Context, m model.BaseChatM if mc != nil { retryConfig = mc.ModelRetryConfig } - return &eventSenderModel{inner: inner, modelRetryConfig: retryConfig}, nil + var failoverConfig *ModelFailoverConfig + if mc != nil { + failoverConfig = mc.ModelFailoverConfig + } + return &eventSenderModel{inner: inner, modelRetryConfig: retryConfig, modelFailoverConfig: failoverConfig}, nil } type eventSenderModel struct { - inner model.BaseChatModel - modelRetryConfig *ModelRetryConfig + inner model.BaseChatModel + modelRetryConfig *ModelRetryConfig + modelFailoverConfig *ModelFailoverConfig } func (m *eventSenderModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { @@ -303,19 +315,12 @@ func (m *eventSenderModel) Stream(ctx context.Context, input []*schema.Message, return nil, errors.New("generator is nil when sending event in Stream: ensure agent state is properly initialized") } - var retryAttempt int - _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { - retryAttempt = st.getRetryAttempt() - return nil - }) - streams := result.Copy(2) eventStream := streams[0] - if m.modelRetryConfig != nil { + if errWrapper := m.buildErrWrapper(ctx); errWrapper != nil { convertOpts := []schema.ConvertOption{ - schema.WithErrWrapper(genErrWrapper(ctx, m.modelRetryConfig.MaxRetries, - retryAttempt, m.modelRetryConfig.IsRetryAble)), + schema.WithErrWrapper(errWrapper), } eventStream = schema.StreamReaderWithConvert(streams[0], func(msg *schema.Message) (*schema.Message, error) { return msg, nil }, @@ -328,6 +333,51 @@ func (m *eventSenderModel) Stream(ctx context.Context, input []*schema.Message, return streams[1], nil } +// buildErrWrapper constructs an error wrapper function for event streams. +// It wraps stream errors as WillRetryError when retry or failover is configured, +// so that flow.go:genAgentInput() can skip events from failed attempts instead of +// treating them as fatal errors. +func (m *eventSenderModel) buildErrWrapper(ctx context.Context) func(error) error { + var retryAttempt int + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + retryAttempt = st.getRetryAttempt() + return nil + }) + + var retryWrapper func(error) error + if m.modelRetryConfig != nil { + retryWrapper = genErrWrapper(ctx, m.modelRetryConfig.MaxRetries, retryAttempt, m.modelRetryConfig.IsRetryAble) + } + + hasFailover := m.modelFailoverConfig != nil + // failoverHasMoreAttempts is set by failoverModelWrapper before each inner call. + // It is true when additional failover attempts remain after the current one, + // meaning stream errors should be wrapped as WillRetryError so the flow layer + // skips them. On the final attempt it is false, so the error propagates normally. + failoverHasMore := getFailoverHasMoreAttempts(ctx) + + if retryWrapper == nil && !(hasFailover && failoverHasMore) { + return nil + } + + return func(err error) error { + // If retry is configured and will retry this error, use the retry wrapper's WillRetryError. + if retryWrapper != nil { + wrapped := retryWrapper(err) + if _, ok := wrapped.(*WillRetryError); ok { + return wrapped + } + } + // Retry won't handle this error (either exhausted or not configured), but + // failover still has more attempts remaining. Wrap it as WillRetryError so + // the flow layer skips this event from the failed attempt. + if hasFailover && failoverHasMore { + return &WillRetryError{ErrStr: err.Error(), err: err} + } + return err + } +} + func popToolGenAction(ctx context.Context, toolName string) *AgentAction { toolCallID := compose.GetToolCallID(ctx) @@ -515,13 +565,14 @@ func hasUserEventSenderToolWrapper(handlers []ChatModelAgentMiddleware) bool { } type stateModelWrapper struct { - inner model.BaseChatModel - original model.BaseChatModel - handlers []ChatModelAgentMiddleware - middlewares []AgentMiddleware - toolInfos []*schema.ToolInfo - modelRetryConfig *ModelRetryConfig - cancelContext *cancelContext + inner model.BaseChatModel + original model.BaseChatModel + handlers []ChatModelAgentMiddleware + middlewares []AgentMiddleware + toolInfos []*schema.ToolInfo + modelRetryConfig *ModelRetryConfig + modelFailoverConfig *ModelFailoverConfig + cancelContext *cancelContext } func (w *stateModelWrapper) IsCallbacksEnabled() bool { @@ -547,6 +598,7 @@ func (w *stateModelWrapper) hasUserEventSender() bool { func (w *stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) generateEndpoint { hasUserEventSender := w.hasUserEventSender() retryConfig := w.modelRetryConfig + failoverConfig := w.modelFailoverConfig cc := w.cancelContext for i := len(w.handlers) - 1; i >= 0; i-- { @@ -573,7 +625,7 @@ func (w *stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) gene if execCtx == nil || execCtx.generator == nil { return innerEndpoint(ctx, input, opts...) } - mc := &ModelContext{ModelRetryConfig: retryConfig, cancelContext: cc} + mc := &ModelContext{ModelRetryConfig: retryConfig, ModelFailoverConfig: failoverConfig, cancelContext: cc} wrappedModel, err := eventSender.WrapModel(ctx, &endpointModel{generate: innerEndpoint}, mc) if err != nil { return nil, err @@ -590,12 +642,23 @@ func (w *stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) gene } } + // Needs to handle failoverWrapper after retryWrapper + if w.modelFailoverConfig != nil { + config := w.modelFailoverConfig + innerEndpoint := endpoint + endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + failoverWrapper := newFailoverModelWrapper(&endpointModel{generate: innerEndpoint}, config) + return failoverWrapper.Generate(ctx, input, opts...) + } + } + return endpoint } func (w *stateModelWrapper) wrapStreamEndpoint(endpoint streamEndpoint) streamEndpoint { hasUserEventSender := w.hasUserEventSender() retryConfig := w.modelRetryConfig + failoverConfig := w.modelFailoverConfig cc := w.cancelContext for i := len(w.handlers) - 1; i >= 0; i-- { @@ -622,7 +685,7 @@ func (w *stateModelWrapper) wrapStreamEndpoint(endpoint streamEndpoint) streamEn if execCtx == nil || execCtx.generator == nil { return innerEndpoint(ctx, input, opts...) } - mc := &ModelContext{ModelRetryConfig: retryConfig, cancelContext: cc} + mc := &ModelContext{ModelRetryConfig: retryConfig, ModelFailoverConfig: failoverConfig, cancelContext: cc} wrappedModel, err := eventSender.WrapModel(ctx, &endpointModel{stream: innerEndpoint}, mc) if err != nil { return nil, err @@ -639,6 +702,16 @@ func (w *stateModelWrapper) wrapStreamEndpoint(endpoint streamEndpoint) streamEn } } + // Needs to handle failoverWrapper after retryWrapper + if w.modelFailoverConfig != nil { + config := w.modelFailoverConfig + innerEndpoint := endpoint + endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + failoverWrapper := newFailoverModelWrapper(&endpointModel{stream: innerEndpoint}, config) + return failoverWrapper.Stream(ctx, input, opts...) + } + } + return endpoint } diff --git a/adk/wrappers_failover_test.go b/adk/wrappers_failover_test.go new file mode 100644 index 000000000..8b14463e1 --- /dev/null +++ b/adk/wrappers_failover_test.go @@ -0,0 +1,181 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +func TestBuildModelWrappers_FailoverProxyInner(t *testing.T) { + base := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return schema.AssistantMessage("ok", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("ok", nil)}), nil + }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 0, + ShouldFailover: func(context.Context, *schema.Message, error) bool { return false }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return base, nil, nil + }, + } + + wrapped := buildModelWrappers(base, &modelWrapperConfig{ + failoverConfig: failoverCfg, + }) + + smw, ok := wrapped.(*stateModelWrapper) + require.True(t, ok) + _, ok = smw.inner.(*failoverProxyModel) + require.True(t, ok) + require.Same(t, base, smw.original) + require.Same(t, failoverCfg, smw.modelFailoverConfig) +} + +func TestStateModelWrapper_Generate_WithFailover(t *testing.T) { + wantErr := errors.New("first failed") + var shouldCalls int32 + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m1Calls, 1) + return schema.AssistantMessage("partial", nil), wantErr + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m2Calls, 1) + return schema.AssistantMessage("ok", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, out *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + require.ErrorIs(t, err, wantErr) + require.NotNil(t, out) + require.Equal(t, "partial", out.Content) + return true + }, + GetFailoverModel: func(_ context.Context, failoverCtx *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.Equal(t, uint(1), failoverCtx.FailoverAttempt) + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + got, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, "ok", got.Content) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) +} + +func TestStateModelWrapper_Stream_WithFailover(t *testing.T) { + streamErr := errors.New("mid error") + var shouldCalls int32 + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return streamWithMidError([]*schema.Message{ + schema.AssistantMessage("p1", nil), + schema.AssistantMessage("p2", nil), + }, streamErr), nil + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m2Calls, 1) + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("final", nil)}), nil + }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, out *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + require.ErrorIs(t, err, streamErr) + require.NotNil(t, out) + require.Equal(t, "p1p2", out.Content) + return true + }, + GetFailoverModel: func(_ context.Context, failoverCtx *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.Equal(t, uint(1), failoverCtx.FailoverAttempt) + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + msgs, err := drainMessageStream(sr) + require.NoError(t, err) + require.Len(t, msgs, 1) + require.Equal(t, "final", msgs[0].Content) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) +} diff --git a/adk/wrappers_retry_failover_test.go b/adk/wrappers_retry_failover_test.go new file mode 100644 index 000000000..98db172e9 --- /dev/null +++ b/adk/wrappers_retry_failover_test.go @@ -0,0 +1,411 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +// TestRetryThenFailover_Generate_RetryExhaustedTriggersFailover tests the combined +// retry + failover path for Generate: m1 always fails, retry exhausted, failover to m2 which succeeds. +func TestRetryThenFailover_Generate_RetryExhaustedTriggersFailover(t *testing.T) { + modelErr := errors.New("model error") + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m1Calls, 1) + return nil, modelErr + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m2Calls, 1) + return schema.AssistantMessage("ok from m2", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + retryCfg := &ModelRetryConfig{ + MaxRetries: 2, + IsRetryAble: func(_ context.Context, err error) bool { return true }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + return err != nil + }, + GetFailoverModel: func(_ context.Context, fc *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.NotNil(t, fc.LastErr) + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + msg, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.Equal(t, "ok from m2", msg.Content) + + // m1: 1 (lastSuccess) + 2 retries = 3 calls on lastSuccess attempt, + // then failover to m2 which also goes through retry wrapper: 1 call succeeds. + require.Equal(t, int32(3), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) +} + +// TestRetryThenFailover_Generate_AllExhausted tests: m1 retry exhausted → failover to m2 → m2 retry exhausted → final error. +func TestRetryThenFailover_Generate_AllExhausted(t *testing.T) { + modelErr := errors.New("always fails") + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m1Calls, 1) + return nil, modelErr + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m2Calls, 1) + return nil, modelErr + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + retryCfg := &ModelRetryConfig{ + MaxRetries: 1, + IsRetryAble: func(_ context.Context, err error) bool { return true }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + return err != nil + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + _, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.Error(t, err) + + // Should be RetryExhaustedError from m2's retry wrapper + var retryErr *RetryExhaustedError + require.True(t, errors.As(err, &retryErr)) + + // m1: 1 initial + 1 retry = 2 calls + require.Equal(t, int32(2), atomic.LoadInt32(&m1Calls)) + // m2: 1 initial + 1 retry = 2 calls + require.Equal(t, int32(2), atomic.LoadInt32(&m2Calls)) +} + +// TestRetryThenFailover_Stream_RetryExhaustedTriggersFailover tests stream path: +// m1 stream always errors mid-way, retry exhausted, failover to m2 which succeeds. +func TestRetryThenFailover_Stream_RetryExhaustedTriggersFailover(t *testing.T) { + streamErr := errors.New("stream mid error") + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return streamWithMidError([]*schema.Message{ + schema.AssistantMessage("partial", nil), + }, streamErr), nil + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m2Calls, 1) + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("ok from m2", nil)}), nil + }, + } + + retryCfg := &ModelRetryConfig{ + MaxRetries: 1, + IsRetryAble: func(_ context.Context, err error) bool { return true }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + return err != nil + }, + GetFailoverModel: func(_ context.Context, fc *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.NotNil(t, fc.LastErr) + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + msgs, err := drainMessageStream(sr) + require.NoError(t, err) + require.Len(t, msgs, 1) + require.Equal(t, "ok from m2", msgs[0].Content) + + // m1: 1 initial + 1 retry = 2 calls on lastSuccess attempt + require.Equal(t, int32(2), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) +} + +// TestRetryThenFailover_Generate_RetrySucceedsNoFailover tests that when retry +// succeeds on the first model, failover is never triggered. +func TestRetryThenFailover_Generate_RetrySucceedsNoFailover(t *testing.T) { + var m1Calls int32 + var failoverCalled int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + n := atomic.AddInt32(&m1Calls, 1) + if n == 1 { + return nil, errors.New("transient error") + } + return schema.AssistantMessage("ok on retry", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + retryCfg := &ModelRetryConfig{ + MaxRetries: 2, + IsRetryAble: func(_ context.Context, err error) bool { return true }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + atomic.AddInt32(&failoverCalled, 1) + return true + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + t.Fatal("GetFailoverModel should not be called when retry succeeds") + return nil, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + msg, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.Equal(t, "ok on retry", msg.Content) + + // 2 calls: first fails, second succeeds via retry + require.Equal(t, int32(2), atomic.LoadInt32(&m1Calls)) + // ShouldFailover should never be called + require.Equal(t, int32(0), atomic.LoadInt32(&failoverCalled)) +} + +// TestRetryThenFailover_Generate_NonRetryableErrorTriggersFailover tests that a non-retryable +// error skips retry and directly triggers failover. +func TestRetryThenFailover_Generate_NonRetryableErrorTriggersFailover(t *testing.T) { + nonRetryableErr := errors.New("non-retryable") + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m1Calls, 1) + return nil, nonRetryableErr + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m2Calls, 1) + return schema.AssistantMessage("ok from m2", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + retryCfg := &ModelRetryConfig{ + MaxRetries: 3, + IsRetryAble: func(_ context.Context, err error) bool { + // Only non-retryable errors + return !errors.Is(err, nonRetryableErr) + }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + return err != nil + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + msg, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.Equal(t, "ok from m2", msg.Content) + + // m1 called only once — non-retryable error skips retry + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) +} + +// TestRetryThenFailover_Stream_AllExhausted tests stream path when both retry and failover are exhausted. +func TestRetryThenFailover_Stream_AllExhausted(t *testing.T) { + streamErr := errors.New("always fails mid-stream") + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return streamWithMidError([]*schema.Message{ + schema.AssistantMessage("p", nil), + }, streamErr), nil + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m2Calls, 1) + return streamWithMidError([]*schema.Message{ + schema.AssistantMessage("p", nil), + }, streamErr), nil + }, + } + + retryCfg := &ModelRetryConfig{ + MaxRetries: 1, + IsRetryAble: func(_ context.Context, err error) bool { return true }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + return err != nil + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + _, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.Error(t, err) + + var retryErr *RetryExhaustedError + require.True(t, errors.As(err, &retryErr)) + + // m1: 1 initial + 1 retry = 2 calls + require.Equal(t, int32(2), atomic.LoadInt32(&m1Calls)) + // m2: 1 initial + 1 retry = 2 calls + require.Equal(t, int32(2), atomic.LoadInt32(&m2Calls)) +} From 8c5a10e4e0612ca5997346c244ea820604d0bb77 Mon Sep 17 00:00:00 2001 From: Megumin Date: Thu, 9 Apr 2026 14:57:26 +0800 Subject: [PATCH 52/65] feat: tool search (#884) feat: tool search definition --- .../dynamictool/toolsearch/prompt.go | 162 ++++ .../dynamictool/toolsearch/toolsearch.go | 457 ++++++++-- .../dynamictool/toolsearch/toolsearch_test.go | 849 ++++++++++-------- components/model/option.go | 37 + schema/agentic_message.go | 52 +- schema/agentic_message_test.go | 2 + schema/message.go | 42 +- schema/tool.go | 101 +++ schema/tool_test.go | 48 + 9 files changed, 1312 insertions(+), 438 deletions(-) create mode 100644 adk/middlewares/dynamictool/toolsearch/prompt.go diff --git a/adk/middlewares/dynamictool/toolsearch/prompt.go b/adk/middlewares/dynamictool/toolsearch/prompt.go new file mode 100644 index 000000000..5aaa56ad1 --- /dev/null +++ b/adk/middlewares/dynamictool/toolsearch/prompt.go @@ -0,0 +1,162 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package toolsearch + +const ( + toolDescription = `Search for or select deferred tools to make them available for use. + +MANDATORY PREREQUISITE - THIS IS A HARD REQUIREMENT + +You MUST use this tool to load deferred tools BEFORE calling them directly. + +This is a BLOCKING REQUIREMENT - deferred tools are NOT available until you load them using this tool. Look for messages in the conversation for the list of tools you can discover. Both query modes (keyword search and direct selection) load the returned tools — once a tool appears in the results, it is immediately available to call. + +Why this is non-negotiable: +- Deferred tools are not loaded until discovered via this tool +- Calling a deferred tool without first loading it will fail +Query modes: + +1. Keyword search - Use keywords when you're unsure which tool to use or need to discover multiple tools at once: + - "list directory" - find tools for listing directories + - "notebook jupyter" - find notebook editing tools + - "slack message" - find slack messaging tools + - Returns up to 5 matching tools ranked by relevance + - All returned tools are immediately available to call — no further selection step needed +2. Direct selection - Use select: when you know the exact tool name: + - "select:mcp__slack__read_channel" + - "select:NotebookEdit" + - "select:Read,Edit,Grep" - load multiple tools at once with comma separation + - Returns the named tool(s) if they exist +IMPORTANT: Both modes load tools equally. Do NOT follow up a keyword search with select: calls for tools already returned — they are already loaded. + +3. Required keyword - Prefix with + to require a match: + - "+linear create issue" - only tools from "linear", ranked by "create"/"issue" + - "+slack send" - only "slack" tools, ranked by "send" + - Useful when you know the service name but not the exact tool +CORRECT Usage Patterns: + + +User: I need to work with slack somehow +Assistant: Let me search for slack tools. +[Calls tool_search with query: "slack"] +Assistant: Found several options including mcp__slack__read_channel. +[Calls mcp__slack__read_channel directly — it was loaded by the keyword search] + + + +User: Edit the Jupyter notebook +Assistant: Let me load the notebook editing tool. +[Calls tool_search with query: "select:NotebookEdit"] +[Calls NotebookEdit] + + + +User: List files in the src directory +Assistant: I can see mcp__filesystem__list_directory in the available tools. Let me select it. +[Calls tool_search with query: "select:mcp__filesystem__list_directory"] +[Calls the tool] + + +INCORRECT Usage Patterns - NEVER DO THESE: + + +User: Read my slack messages +Assistant: [Directly calls mcp__slack__read_channel without loading it first] +WRONG - You must load the tool FIRST using this tool + + + +Assistant: [Calls tool_search with query: "slack", gets back mcp__slack__read_channel] +Assistant: [Calls tool_search with query: "select:mcp__slack__read_channel"] +WRONG - The keyword search already loaded the tool. The select call is redundant. +` + + toolDescriptionChinese = `搜索或选择延迟加载(deferred)的工具,使其可供调用。 + +强制前提条件(MANDATORY PREREQUISITE)— 硬性要求 + +在直接调用任何 延迟加载工具(deferred tools) 之前,你 必须先使用此工具将其加载。 + +这是一个 阻塞性要求(BLOCKING REQUIREMENT) — 延迟加载工具在被加载之前是 不可用的。你需要在对话中查找 消息,以获取可以发现的工具列表。无论使用哪种查询方式(关键字搜索 或 直接选择),只要工具出现在返回结果中,它们就会自动被加载并立即可调用。 + +为什么这是不可协商的规则: +- 延迟加载工具在被发现之前不会被加载 +- 如果你在加载之前直接调用延迟工具,调用将会失败 +查询模式: + +1. 关键字搜索(Keyword search)- 当你不确定具体需要哪个工具,或希望一次发现多个工具时使用关键字搜索: +- "list directory" — 查找用于列出目录的工具 +- "notebook jupyter" — 查找 Jupyter Notebook 编辑工具 +- "slack message" — 查找 Slack 消息相关工具 +- 返回最多 5 个最相关的工具 +- 所有返回的工具都会立即加载并可直接调用 — 不需要额外执行 select 步骤 + +2. 直接选择(Direct selection)— 当你已经知道工具的确切名称时使用 select:: +- "select:mcp__slack__read_channel" +- "select:NotebookEdit" +- "select:Read,Edit,Grep" — 一次加载多个工具 +- 如果工具存在,将被加载并返回 +重要说明:两种模式的加载效果完全相同。不要在关键词搜索之后,对返回的工具再次进行 select: 选择 — 它们已经加载好了。 + +3. 必须匹配关键字(Required keyword)— 在关键字前添加 + 可以 强制匹配特定服务或来源。 +- "+linear create issue" — 仅返回名字中包含 "linear" 的工具,按 "create" / "issue" 排序 +- "+slack send" — 仅返回名字中包含 "slack" 的工具,按 "send" 排序 +- 适用于你知道服务名称但不知道具体工具名称 + +正确使用示例: + + +User: 我需要处理 Slack 相关的事情 +Assistant: 让我搜索 Slack 工具。 +[调用 tool_search,query: "slack"] +Assistant: 找到多个选项,包括 mcp__slack__read_channel。 +[直接调用 mcp__slack__read_channel — 关键字搜索已经加载了该工具] + + + +User: 编辑这个 Jupyter Notebook +Assistant: 让我加载 Notebook 编辑工具。 +[调用 tool_search,query: "select:NotebookEdit"] +[调用 NotebookEdit] + + + +User: 列出 src 目录中的文件 +Assistant: 我看到可用工具中有 mcp__filesystem__list_directory,让我加载它。 +[调用 tool_search,query: "select:mcp__filesystem__list_directory"] +[调用该工具] + + +错误用法(严禁) + + +User: 读取我的 Slack 消息 +Assistant: [不调用 tool_search 工具加载,直接调用 mcp__slack__read_channel] +错误 — 在调用工具之前没有先使用 tool_search 加载该工具。 + + + +Assistant:[调用 tool_search,query: "slack",返回 mcp__slack__read_channel] +Assistant:[再次调用 tool_search,query: "select:mcp__slack__read_channel"] +错误 — 关键字搜索 已经加载了该工具,再次 select 是冗余操作。` + + systemReminderTpl = ` +{{- range .Tools }} +{{ . }} +{{- end }} +` +) diff --git a/adk/middlewares/dynamictool/toolsearch/toolsearch.go b/adk/middlewares/dynamictool/toolsearch/toolsearch.go index 4ee4c216b..55883e914 100644 --- a/adk/middlewares/dynamictool/toolsearch/toolsearch.go +++ b/adk/middlewares/dynamictool/toolsearch/toolsearch.go @@ -18,12 +18,17 @@ package toolsearch import ( + "bytes" "context" "encoding/json" "fmt" - "regexp" + "sort" + "strings" + "text/template" + "unicode" "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/adk/internal" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/schema" @@ -33,6 +38,16 @@ import ( type Config struct { // DynamicTools is a list of tools that can be dynamically searched and loaded by the agent. DynamicTools []tool.BaseTool + + // UseModelToolSearch indicates whether the ChatModel natively supports tool search. + // + // When true, the middleware delegates tool search to the model's native capability. + // + // When false (default), the middleware manages tool visibility by filtering the tool list + // based on tool_search results before each model call. Note that this approach may + // invalidate the model's KV-cache (as the tool list changes between calls), and effectiveness + // depends on the model's ability to work with a dynamically changing tool set. + UseModelToolSearch bool } // New constructs and returns the tool search middleware. @@ -41,7 +56,7 @@ type Config struct { // Instead of passing all tools to the model at once (which can overwhelm context limits), // this middleware: // -// 1. Adds a "tool_search" meta-tool that accepts a regex pattern to search tool names +// 1. Adds a "tool_search" meta-tool that accepts keyword queries to search tools // 2. Initially hides all dynamic tools from the model's tool list // 3. When the model calls tool_search, matching tools become available for subsequent calls // @@ -62,14 +77,55 @@ func New(ctx context.Context, config *Config) (adk.ChatModelAgentMiddleware, err return nil, fmt.Errorf("tools is required") } + tpl, err := template.New("").Parse(systemReminderTpl) + if err != nil { + return nil, err + } + + dynamicToolInfos := make([]*schema.ToolInfo, 0, len(config.DynamicTools)) + mapOfDynamicTools := make(map[string]*schema.ToolInfo, len(config.DynamicTools)) + toolNames := make([]string, 0, len(config.DynamicTools)) + for _, t := range config.DynamicTools { + info, infoErr := t.Info(ctx) + if infoErr != nil { + return nil, fmt.Errorf("failed to get dynamic tool info: %w", infoErr) + } + + if _, ok := mapOfDynamicTools[info.Name]; ok { + return nil, fmt.Errorf("duplicate dynamic tool name: %s", info.Name) + } + + toolNames = append(toolNames, info.Name) + mapOfDynamicTools[info.Name] = info + dynamicToolInfos = append(dynamicToolInfos, info) + } + + buf := &bytes.Buffer{} + err = tpl.Execute(buf, systemReminder{Tools: toolNames}) + if err != nil { + return nil, fmt.Errorf("failed to format system reminder template: %w", err) + } + return &middleware{ - dynamicTools: config.DynamicTools, + dynamicTools: config.DynamicTools, + mapOfDynamicTools: mapOfDynamicTools, + dynamicToolInfos: dynamicToolInfos, + useModelToolSearch: config.UseModelToolSearch, + sr: buf.String(), }, nil } +type systemReminder struct { + Tools []string +} + type middleware struct { adk.BaseChatModelAgentMiddleware - dynamicTools []tool.BaseTool + dynamicTools []tool.BaseTool + mapOfDynamicTools map[string]*schema.ToolInfo + dynamicToolInfos []*schema.ToolInfo + useModelToolSearch bool + sr string } func (m *middleware) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgentContext) (context.Context, *adk.ChatModelAgentContext, error) { @@ -78,123 +134,384 @@ func (m *middleware) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgent } nRunCtx := *runCtx - toolNames, err := getToolNames(ctx, m.dynamicTools) - if err != nil { - return ctx, nil, fmt.Errorf("failed to get tool names: %w", err) - } - nRunCtx.Tools = append(nRunCtx.Tools, newToolSearchTool(toolNames)) + nRunCtx.Tools = make([]tool.BaseTool, len(runCtx.Tools), len(runCtx.Tools)+1+len(m.dynamicTools)) + copy(nRunCtx.Tools, runCtx.Tools) + nRunCtx.Tools = append(nRunCtx.Tools, newToolSearchTool(m.mapOfDynamicTools, m.useModelToolSearch)) nRunCtx.Tools = append(nRunCtx.Tools, m.dynamicTools...) return ctx, &nRunCtx, nil } func (m *middleware) WrapModel(_ context.Context, cm model.BaseChatModel, mc *adk.ModelContext) (model.BaseChatModel, error) { - return &wrapper{allTools: mc.Tools, cm: cm, dynamicTools: m.dynamicTools}, nil + return &wrapper{ + allTools: mc.Tools, + cm: cm, + dynamicToolInfos: m.dynamicToolInfos, + reminder: m.sr, + useModelToolSearch: m.useModelToolSearch, + }, nil } type wrapper struct { - allTools []*schema.ToolInfo - dynamicTools []tool.BaseTool + allTools []*schema.ToolInfo + dynamicToolInfos []*schema.ToolInfo + reminder string + useModelToolSearch bool cm model.BaseChatModel } func (w *wrapper) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { - tools, err := removeTools(ctx, w.allTools, w.dynamicTools, input) + toolsOpts, err := w.resolveTools(ctx, input) if err != nil { return nil, fmt.Errorf("failed to load dynamic tools: %w", err) } - return w.cm.Generate(ctx, input, append(opts, model.WithTools(tools))...) + return w.cm.Generate(ctx, w.insertReminder(input), append(opts, toolsOpts...)...) } func (w *wrapper) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { - tools, err := removeTools(ctx, w.allTools, w.dynamicTools, input) + toolsOpts, err := w.resolveTools(ctx, input) if err != nil { return nil, fmt.Errorf("failed to load dynamic tools: %w", err) } - return w.cm.Stream(ctx, input, append(opts, model.WithTools(tools))...) + return w.cm.Stream(ctx, w.insertReminder(input), append(opts, toolsOpts...)...) +} + +func (w *wrapper) resolveTools(ctx context.Context, input []*schema.Message) ([]model.Option, error) { + if w.useModelToolSearch { + // Model handles tool search natively; remove all dynamic tools from the list. + return calculateTools(ctx, w.allTools, w.dynamicToolInfos, nil, w.useModelToolSearch) + } + return calculateTools(ctx, w.allTools, w.dynamicToolInfos, input, w.useModelToolSearch) +} + +func (w *wrapper) insertReminder(input []*schema.Message) []*schema.Message { + inserted := false + ret := make([]*schema.Message, 0, len(input)+1) + for _, m := range input { + if m.Role != schema.System && !inserted { + inserted = true + ret = append(ret, schema.UserMessage(w.reminder)) + } + ret = append(ret, m) + } + if !inserted { + ret = append(ret, schema.UserMessage(w.reminder)) + } + return ret +} + +func newToolSearchTool(tools map[string]*schema.ToolInfo, useModelToolSearch bool) tool.BaseTool { + if useModelToolSearch { + return &modelToolSearchTool{tools: tools} + } + return &toolSearchTool{tools: tools} } -func newToolSearchTool(toolNames []string) *toolSearchTool { - return &toolSearchTool{toolNames: toolNames} +type toolSearchArgs struct { + Query string `json:"query"` + MaxResults *int `json:"max_results,omitempty"` +} + +type toolSearchResult struct { + Matches []string `json:"matches"` } type toolSearchTool struct { - toolNames []string + tools map[string]*schema.ToolInfo +} + +func (t *toolSearchTool) Info(ctx context.Context) (*schema.ToolInfo, error) { + return getToolSearchToolInfo(), nil +} + +func (t *toolSearchTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { + matches, err := search(argumentsInJSON, t.tools) + if err != nil { + return "", err + } + result := &toolSearchResult{} + for _, m := range matches { + result.Matches = append(result.Matches, m.Name) + } + b, err := json.Marshal(result) + if err != nil { + return "", fmt.Errorf("failed to marshal tool search result: %w", err) + } + return string(b), nil +} + +type modelToolSearchTool struct { + tools map[string]*schema.ToolInfo +} + +func (t *modelToolSearchTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return getToolSearchToolInfo(), nil +} + +func (t *modelToolSearchTool) InvokableRun(_ context.Context, argumentsInJSON *schema.ToolArgument, _ ...tool.Option) (*schema.ToolResult, error) { + ret, err := search(argumentsInJSON.Text, t.tools) + if err != nil { + return nil, err + } + + return &schema.ToolResult{Parts: []schema.ToolOutputPart{ + { + Type: schema.ToolPartTypeToolSearchResult, + ToolSearchResult: &schema.ToolSearchResult{ + Tools: ret, + }, + }, + }}, nil } const ( toolSearchToolName = "tool_search" + defaultMaxResults = 5 ) -func (t *toolSearchTool) Info(ctx context.Context) (*schema.ToolInfo, error) { +func getToolSearchToolInfo() *schema.ToolInfo { return &schema.ToolInfo{ - Name: "tool_search", - Desc: "Search for tools using a regex pattern that matches tool names. Returns a list of matching tool names. Use this when you need a tool but don't have it available yet.", + Name: toolSearchToolName, + Desc: internal.SelectPrompt(internal.I18nPrompts{ + English: toolDescription, + Chinese: toolDescriptionChinese, + }), ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ - "regex_pattern": { + "query": { Type: schema.String, - Desc: "A regex pattern to match tool names against.", + Desc: "Query to find deferred tools. Use \"select:\" for direct selection, or keywords to search.", Required: true, }, + "max_results": { + Type: schema.Integer, + Desc: "Maximum number of results to return (default: 5)", + Required: false, + }, }), - }, nil + } } -type toolSearchArgs struct { - RegexPattern string `json:"regex_pattern"` +func search(argumentsInJSON string, tools map[string]*schema.ToolInfo) ([]*schema.ToolInfo, error) { + var args toolSearchArgs + if err := json.Unmarshal([]byte(argumentsInJSON), &args); err != nil { + return nil, fmt.Errorf("failed to unmarshal tool search arguments: %w", err) + } + + query := strings.TrimSpace(args.Query) + if query == "" { + return nil, fmt.Errorf("query is required") + } + + maxResults := defaultMaxResults + if args.MaxResults != nil && *args.MaxResults > 0 { + maxResults = *args.MaxResults + } + + var matches []string + + // Direct selection mode: select:tool1,tool2 + // max_results is intentionally not applied here because the model has + // already specified the exact tools it wants by name. + if strings.HasPrefix(query, "select:") { + names := strings.Split(strings.TrimPrefix(query, "select:"), ",") + toolSet := make(map[string]bool, len(tools)) + for name := range tools { + toolSet[name] = true + } + for _, name := range names { + name = strings.TrimSpace(name) + if name != "" && toolSet[name] { + matches = append(matches, name) + } + } + } else { + matches = keywordSearch(query, maxResults, tools) + } + + ret := make([]*schema.ToolInfo, 0, len(matches)) + for _, name := range matches { + ti, ok := tools[name] + if !ok { + continue + } + ret = append(ret, ti) + } + return ret, nil } -type toolSearchResult struct { - SelectedTools []string `json:"selectedTools"` +func intMax(a, b int) int { + if a > b { + return a + } + return b } -func (t *toolSearchTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { - var args toolSearchArgs - if err := json.Unmarshal([]byte(argumentsInJSON), &args); err != nil { - return "", fmt.Errorf("failed to unmarshal tool search arguments: %w", err) +func intMin(a, b int) int { + if a < b { + return a } + return b +} + +// scoredTool pairs a tool name with its search score. +type scoredTool struct { + name string + score int +} - if args.RegexPattern == "" { - return "", fmt.Errorf("regex_pattern is required") +// keywordSearch scores all tools against the query keywords and returns the top N. +func keywordSearch(query string, maxResults int, tools map[string]*schema.ToolInfo) []string { + keywords := parseKeywords(query) + if len(keywords) == 0 { + return nil } - re, err := regexp.Compile(args.RegexPattern) - if err != nil { - return "", fmt.Errorf("invalid regex pattern: %w", err) + var scored []scoredTool + + for name, tm := range tools { + nameParts := splitToolName(name) + nameLower := strings.ToLower(name) + descLower := strings.ToLower(tm.Desc) + + totalScore := 0 + allRequiredFound := true + + for _, kw := range keywords { + kwLower := strings.ToLower(kw.word) + kwScore := 0 + + // Score against name parts + for _, part := range nameParts { + partLower := strings.ToLower(part) + if partLower == kwLower { + kwScore = intMax(kwScore, 10) + } else if strings.Contains(partLower, kwLower) { + kwScore = intMax(kwScore, 5) + } + } + + // Score against full name + if strings.Contains(nameLower, kwLower) { + kwScore = intMax(kwScore, 3) + } + + // Score against description (substring match) + if descLower != "" && strings.Contains(descLower, kwLower) { + kwScore = intMax(kwScore, 2) + } + + if kw.required && kwScore == 0 { + allRequiredFound = false + break + } + + totalScore += kwScore + } + + if !allRequiredFound { + continue + } + + if totalScore > 0 { + scored = append(scored, scoredTool{name: name, score: totalScore}) + } } - var matchedTools []string - for _, name := range t.toolNames { - if re.MatchString(name) { - matchedTools = append(matchedTools, name) + // Sort by score descending, then by name for stability + sort.Slice(scored, func(i, j int) bool { + if scored[i].score != scored[j].score { + return scored[i].score > scored[j].score } + return scored[i].name < scored[j].name + }) + + results := make([]string, 0, intMin(maxResults, len(scored))) + for i := 0; i < len(scored) && i < maxResults; i++ { + results = append(results, scored[i].name) } + return results +} + +// keyword represents a parsed search keyword. +type keyword struct { + word string + required bool +} - result := toolSearchResult{ - SelectedTools: matchedTools, +// parseKeywords splits a query string into keywords, handling the '+' required prefix. +func parseKeywords(query string) (keywords []keyword) { + parts := strings.Fields(query) + for _, p := range parts { + if strings.HasPrefix(p, "+") { + word := strings.TrimPrefix(p, "+") + if word != "" { + keywords = append(keywords, keyword{word: word, required: true}) + } + } else if p != "" { + keywords = append(keywords, keyword{word: p, required: false}) + } } + return +} - output, err := json.Marshal(result) - if err != nil { - return "", fmt.Errorf("failed to marshal result: %w", err) +// splitToolName splits a tool name into parts by underscores, double underscores (MCP separator), +// and camelCase boundaries. +func splitToolName(name string) []string { + // First split by double underscore (MCP server__tool separator) + segments := strings.Split(name, "__") + + var parts []string + for _, seg := range segments { + // Split each segment by single underscore + underscoreParts := strings.Split(seg, "_") + for _, up := range underscoreParts { + if up == "" { + continue + } + // Further split by camelCase + camelParts := splitCamelCase(up) + parts = append(parts, camelParts...) + } + } + return parts +} + +// splitCamelCase splits a camelCase or PascalCase string into its constituent words. +func splitCamelCase(s string) []string { + if s == "" { + return nil } - return string(output), nil + var parts []string + runes := []rune(s) + start := 0 + + for i := 1; i < len(runes); i++ { + if unicode.IsUpper(runes[i]) { + if unicode.IsLower(runes[i-1]) { + parts = append(parts, string(runes[start:i])) + start = i + } else if i+1 < len(runes) && unicode.IsLower(runes[i+1]) { + parts = append(parts, string(runes[start:i])) + start = i + } + } + } + parts = append(parts, string(runes[start:])) + + return parts } -func getToolNames(ctx context.Context, tools []tool.BaseTool) ([]string, error) { +// getToolNames extracts just tool names from a slice of BaseTools (used by calculateTools). +func getToolNames(tools []*schema.ToolInfo) []string { ret := make([]string, 0, len(tools)) for _, t := range tools { - info, err := t.Info(ctx) - if err != nil { - return nil, err - } - ret = append(ret, info.Name) + ret = append(ret, t.Name) } - return ret, nil + return ret } -func extractSelectedTools(ctx context.Context, messages []*schema.Message) ([]string, error) { +func extractSelectedTools(_ context.Context, messages []*schema.Message) ([]string, error) { var selectedTools []string for _, message := range messages { if message.Role != schema.Tool || message.ToolName != toolSearchToolName { @@ -206,7 +523,7 @@ func extractSelectedTools(ctx context.Context, messages []*schema.Message) ([]st if err != nil { return nil, fmt.Errorf("failed to unmarshal tool search tool result: %w", err) } - selectedTools = append(selectedTools, result.SelectedTools...) + selectedTools = append(selectedTools, result.Matches...) } return selectedTools, nil } @@ -226,22 +543,30 @@ func invertSelect[T comparable](all []T, selected []T) map[T]struct{} { return result } -func removeTools(ctx context.Context, all []*schema.ToolInfo, dynamicTools []tool.BaseTool, messages []*schema.Message) ([]*schema.ToolInfo, error) { - selectedToolNames, err := extractSelectedTools(ctx, messages) - if err != nil { - return nil, err +func calculateTools(ctx context.Context, all []*schema.ToolInfo, dynamicTools []*schema.ToolInfo, messages []*schema.Message, useModelToolSearch bool) ([]model.Option, error) { + var err error + var ret []model.Option + var selectedToolNames []string + if !useModelToolSearch { + selectedToolNames, err = extractSelectedTools(ctx, messages) + if err != nil { + return nil, err + } } - dynamicToolNames, err := getToolNames(ctx, dynamicTools) - if err != nil { - return nil, err + dynamicToolNames := getToolNames(dynamicTools) + if useModelToolSearch { + // if useModelToolSearch, register tool search tool by WithToolSearchTool + dynamicToolNames = append(dynamicToolNames, toolSearchToolName) + ret = append(ret, model.WithToolSearchTool(getToolSearchToolInfo())) } removeMap := invertSelect(dynamicToolNames, selectedToolNames) - ret := make([]*schema.ToolInfo, 0, len(all)-len(dynamicTools)) + tools := make([]*schema.ToolInfo, 0, len(all)-len(dynamicTools)) for _, info := range all { if _, ok := removeMap[info.Name]; ok { continue } - ret = append(ret, info) + tools = append(tools, info) } + ret = append(ret, model.WithTools(tools)) return ret, nil } diff --git a/adk/middlewares/dynamictool/toolsearch/toolsearch_test.go b/adk/middlewares/dynamictool/toolsearch/toolsearch_test.go index 4b249b9be..20cee35da 100644 --- a/adk/middlewares/dynamictool/toolsearch/toolsearch_test.go +++ b/adk/middlewares/dynamictool/toolsearch/toolsearch_test.go @@ -19,6 +19,10 @@ package toolsearch import ( "context" "encoding/json" + "fmt" + "sort" + "strings" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -27,464 +31,569 @@ import ( "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) -type mockTool struct { - name string - desc string -} +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- -func (m *mockTool) Info(ctx context.Context) (*schema.ToolInfo, error) { - return &schema.ToolInfo{ - Name: m.name, - Desc: m.desc, - }, nil +func makeToolMap(tools ...*schema.ToolInfo) map[string]*schema.ToolInfo { + m := make(map[string]*schema.ToolInfo, len(tools)) + for _, t := range tools { + m[t.Name] = t + } + return m } -func newMockTool(name, desc string) *mockTool { - return &mockTool{name: name, desc: desc} +func ti(name, desc string) *schema.ToolInfo { + return &schema.ToolInfo{Name: name, Desc: desc} } -func TestNew(t *testing.T) { - ctx := context.Background() +func toolNames(infos []*schema.ToolInfo) []string { + names := make([]string, len(infos)) + for i, info := range infos { + names[i] = info.Name + } + sort.Strings(names) + return names +} - t.Run("nil config returns error", func(t *testing.T) { - m, err := New(ctx, nil) - assert.Nil(t, m) - assert.Error(t, err) - assert.Contains(t, err.Error(), "config is required") - }) +func searchJSON(query string, maxResults *int) string { + args := toolSearchArgs{Query: query, MaxResults: maxResults} + b, _ := json.Marshal(args) + return string(b) +} - t.Run("empty tools returns error", func(t *testing.T) { - m, err := New(ctx, &Config{DynamicTools: []tool.BaseTool{}}) - assert.Nil(t, m) - assert.Error(t, err) - assert.Contains(t, err.Error(), "tools is required") - }) +func intPtr(v int) *int { return &v } + +// --------------------------------------------------------------------------- +// TestSearch — unit tests for the search() function +// --------------------------------------------------------------------------- + +func TestSearch(t *testing.T) { + tools := makeToolMap( + ti("get_weather", "Get current weather for a city"), + ti("search_flights", "Search available flights"), + ti("mcp__slack__send_message", "Send a message to Slack channel"), + ti("mcp__slack__read_channel", "Read messages from Slack channel"), + ti("create_calendar_event", "Create a new calendar event"), + ti("NotebookEdit", "Edit Jupyter notebook cells"), + ) + + tests := []struct { + name string + json string + wantNames []string // sorted; nil means expect empty + wantErr bool + }{ + { + name: "keyword exact name part match", + json: searchJSON("weather", nil), + wantNames: []string{"get_weather"}, + }, + { + name: "keyword matches multiple tools", + json: searchJSON("slack", nil), + wantNames: []string{"mcp__slack__read_channel", "mcp__slack__send_message"}, + }, + { + name: "multi-word ranking - send_message ranked first", + json: searchJSON("send message", nil), + wantNames: []string{"mcp__slack__send_message"}, // check first element only + }, + { + name: "required keyword filters to slack only", + json: searchJSON("+slack send", nil), + wantNames: []string{"mcp__slack__read_channel", "mcp__slack__send_message"}, + }, + { + name: "required keyword no match", + json: searchJSON("+github send", nil), + wantNames: nil, + }, + { + name: "direct select single", + json: searchJSON("select:get_weather", nil), + wantNames: []string{"get_weather"}, + }, + { + name: "direct select multiple", + json: searchJSON("select:get_weather,NotebookEdit", nil), + wantNames: []string{"NotebookEdit", "get_weather"}, + }, + { + name: "direct select nonexistent", + json: searchJSON("select:nonexistent", nil), + wantNames: nil, + }, + { + name: "max_results limits output", + json: searchJSON("slack", intPtr(1)), + wantNames: []string{"mcp__slack__read_channel"}, // just check length below + }, + { + name: "camelCase split matches notebook", + json: searchJSON("notebook", nil), + wantNames: []string{"NotebookEdit"}, + }, + { + name: "empty query returns error", + json: searchJSON("", nil), + wantErr: true, + }, + { + name: "description match - jupyter", + json: searchJSON("jupyter", nil), + wantNames: []string{"NotebookEdit"}, + }, + } - t.Run("valid config returns middleware", func(t *testing.T) { - tools := []tool.BaseTool{ - newMockTool("tool1", "desc1"), - newMockTool("tool2", "desc2"), - } - m, err := New(ctx, &Config{DynamicTools: tools}) - assert.NoError(t, err) - assert.NotNil(t, m) - }) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := search(tt.json, tools) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + + // special case: max_results limit + if tt.name == "max_results limits output" { + assert.Len(t, got, 1) + return + } + + // special case: ranking — just check first element + if tt.name == "multi-word ranking - send_message ranked first" { + require.NotEmpty(t, got) + assert.Equal(t, "mcp__slack__send_message", got[0].Name) + return + } + + gotNames := toolNames(got) + if tt.wantNames == nil { + assert.Empty(t, gotNames) + } else { + assert.Equal(t, tt.wantNames, gotNames) + } + }) + } } -func TestMiddleware_BeforeAgent(t *testing.T) { - ctx := context.Background() +// --------------------------------------------------------------------------- +// TestMiddlewareFlow — integration test for UseModelToolSearch=false +// --------------------------------------------------------------------------- - t.Run("nil runCtx returns nil", func(t *testing.T) { - tools := []tool.BaseTool{newMockTool("tool1", "desc1")} - m, err := New(ctx, &Config{DynamicTools: tools}) - require.NoError(t, err) +// simpleTool is a minimal InvokableTool for testing. +type simpleTool struct { + name string + desc string + called bool + mu sync.Mutex +} - newCtx, newRunCtx, err := m.BeforeAgent(ctx, nil) - assert.NoError(t, err) - assert.Equal(t, ctx, newCtx) - assert.Nil(t, newRunCtx) - }) +func (s *simpleTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: s.name, + Desc: s.desc, + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Type: schema.String, Desc: "input", Required: true}, + }), + }, nil +} - t.Run("adds tool_search and dynamic tools", func(t *testing.T) { - tools := []tool.BaseTool{ - newMockTool("tool1", "desc1"), - newMockTool("tool2", "desc2"), - } - m, err := New(ctx, &Config{DynamicTools: tools}) - require.NoError(t, err) +func (s *simpleTool) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) { + s.mu.Lock() + s.called = true + s.mu.Unlock() + return `{"result":"ok"}`, nil +} - middleware := m.(*middleware) - runCtx := &adk.ChatModelAgentContext{ - Tools: []tool.BaseTool{}, - } +func (s *simpleTool) wasCalled() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.called +} - _, newRunCtx, err := middleware.BeforeAgent(ctx, runCtx) - assert.NoError(t, err) - assert.NotNil(t, newRunCtx) - assert.Len(t, newRunCtx.Tools, 3) - }) +// mockChatModel implements model.ToolCallingChatModel. +// It drives a 3-turn conversation: +// +// Turn 1: call tool_search with select:dynamic_tool_a +// Turn 2: call dynamic_tool_a +// Turn 3: return final text +type mockChatModel struct { + mu sync.Mutex + generateCall int + // toolsPerCall records the tool names passed via model.WithTools for each Generate call. + toolsPerCall [][]string } -func TestToolSearchTool_Info(t *testing.T) { - ctx := context.Background() - toolNames := []string{"tool1", "tool2", "tool3"} - tst := newToolSearchTool(toolNames) - - info, err := tst.Info(ctx) - assert.NoError(t, err) - assert.Equal(t, "tool_search", info.Name) - assert.Contains(t, info.Desc, "regex pattern") - assert.NotNil(t, info.ParamsOneOf) +func (m *mockChatModel) Generate(_ context.Context, _ []*schema.Message, opts ...model.Option) (*schema.Message, error) { + options := model.GetCommonOptions(nil, opts...) + var names []string + for _, t := range options.Tools { + names = append(names, t.Name) + } + sort.Strings(names) + + m.mu.Lock() + m.generateCall++ + call := m.generateCall + m.toolsPerCall = append(m.toolsPerCall, names) + m.mu.Unlock() + + switch call { + case 1: + // Ask tool_search to select dynamic_tool_a + return schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "tc1", + Function: schema.FunctionCall{ + Name: toolSearchToolName, + Arguments: `{"query":"select:dynamic_tool_a","max_results":5}`, + }, + }, + }), nil + case 2: + // Call dynamic_tool_a + return schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "tc2", + Function: schema.FunctionCall{ + Name: "dynamic_tool_a", + Arguments: `{"input":"hello"}`, + }, + }, + }), nil + default: + // Final response + return schema.AssistantMessage("done", nil), nil + } } -func TestToolSearchTool_InvokableRun(t *testing.T) { - ctx := context.Background() - toolNames := []string{"get_weather", "get_time", "search_web", "calculate_sum"} - tst := newToolSearchTool(toolNames) +func (m *mockChatModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, fmt.Errorf("not implemented") +} - t.Run("empty regex pattern returns error", func(t *testing.T) { - args := `{"regex_pattern": ""}` - result, err := tst.InvokableRun(ctx, args) - assert.Error(t, err) - assert.Contains(t, err.Error(), "regex_pattern is required") - assert.Empty(t, result) - }) +func (m *mockChatModel) WithTools(_ []*schema.ToolInfo) (model.ToolCallingChatModel, error) { + return m, nil +} - t.Run("invalid json returns error", func(t *testing.T) { - args := `{invalid json}` - result, err := tst.InvokableRun(ctx, args) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to unmarshal") - assert.Empty(t, result) - }) +func (m *mockChatModel) getToolsPerCall() [][]string { + m.mu.Lock() + defer m.mu.Unlock() + ret := make([][]string, len(m.toolsPerCall)) + copy(ret, m.toolsPerCall) + return ret +} - t.Run("invalid regex returns error", func(t *testing.T) { - args := `{"regex_pattern": "[invalid"}` - result, err := tst.InvokableRun(ctx, args) - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid regex pattern") - assert.Empty(t, result) - }) +func TestMiddlewareFlow(t *testing.T) { + ctx := context.Background() - t.Run("matches tools with prefix pattern", func(t *testing.T) { - args := `{"regex_pattern": "^get_"}` - result, err := tst.InvokableRun(ctx, args) - assert.NoError(t, err) + dynamicA := &simpleTool{name: "dynamic_tool_a", desc: "Dynamic tool A"} + dynamicB := &simpleTool{name: "dynamic_tool_b", desc: "Dynamic tool B"} + staticTool := &simpleTool{name: "static_tool", desc: "Static tool"} - var res toolSearchResult - err = json.Unmarshal([]byte(result), &res) - assert.NoError(t, err) - assert.ElementsMatch(t, []string{"get_weather", "get_time"}, res.SelectedTools) + mw, err := New(ctx, &Config{ + DynamicTools: []tool.BaseTool{dynamicA, dynamicB}, + UseModelToolSearch: false, }) - - t.Run("matches tools with suffix pattern", func(t *testing.T) { - args := `{"regex_pattern": "_sum$"}` - result, err := tst.InvokableRun(ctx, args) - assert.NoError(t, err) - - var res toolSearchResult - err = json.Unmarshal([]byte(result), &res) - assert.NoError(t, err) - assert.ElementsMatch(t, []string{"calculate_sum"}, res.SelectedTools) + require.NoError(t, err) + + cm := &mockChatModel{} + + agent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ + Name: "test_agent", + Description: "test", + Instruction: "you are a test agent", + Model: cm, + ToolsConfig: adk.ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{staticTool}, + }, + }, + Handlers: []adk.ChatModelAgentMiddleware{mw}, }) + require.NoError(t, err) - t.Run("matches all tools with wildcard", func(t *testing.T) { - args := `{"regex_pattern": ".*"}` - result, err := tst.InvokableRun(ctx, args) - assert.NoError(t, err) + input := &adk.AgentInput{ + Messages: []adk.Message{schema.UserMessage("test")}, + } + iter := agent.Run(ctx, input) - var res toolSearchResult - err = json.Unmarshal([]byte(result), &res) - assert.NoError(t, err) - assert.ElementsMatch(t, toolNames, res.SelectedTools) - }) + var events []*adk.AgentEvent + for { + ev, ok := iter.Next() + if !ok { + break + } + events = append(events, ev) + } - t.Run("no matches returns empty list", func(t *testing.T) { - args := `{"regex_pattern": "^nonexistent_"}` - result, err := tst.InvokableRun(ctx, args) - assert.NoError(t, err) + // Verify no error event. + for _, ev := range events { + if ev.Err != nil { + t.Fatalf("unexpected error event: %v", ev.Err) + } + } - var res toolSearchResult - err = json.Unmarshal([]byte(result), &res) - assert.NoError(t, err) - assert.Empty(t, res.SelectedTools) - }) + // Verify final output is "done". + lastEvent := events[len(events)-1] + require.NotNil(t, lastEvent.Output) + require.NotNil(t, lastEvent.Output.MessageOutput) + assert.Equal(t, "done", lastEvent.Output.MessageOutput.Message.Content) + + // Verify dynamic_tool_a was actually called. + assert.True(t, dynamicA.wasCalled(), "dynamic_tool_a should have been called") + assert.False(t, dynamicB.wasCalled(), "dynamic_tool_b should not have been called") + + // Verify tool lists per Generate call. + toolsPerCall := cm.getToolsPerCall() + require.Len(t, toolsPerCall, 3, "expected 3 Generate calls") + + // Call 1: tool_search + static_tool; dynamic tools are hidden. + assert.Contains(t, toolsPerCall[0], "tool_search") + assert.Contains(t, toolsPerCall[0], "static_tool") + assert.NotContains(t, toolsPerCall[0], "dynamic_tool_a") + assert.NotContains(t, toolsPerCall[0], "dynamic_tool_b") + + // Call 2: after selecting dynamic_tool_a, it becomes visible. + assert.Contains(t, toolsPerCall[1], "tool_search") + assert.Contains(t, toolsPerCall[1], "static_tool") + assert.Contains(t, toolsPerCall[1], "dynamic_tool_a") + assert.NotContains(t, toolsPerCall[1], "dynamic_tool_b") + + // Call 3: same as call 2. + assert.Contains(t, toolsPerCall[2], "tool_search") + assert.Contains(t, toolsPerCall[2], "static_tool") + assert.Contains(t, toolsPerCall[2], "dynamic_tool_a") + assert.NotContains(t, toolsPerCall[2], "dynamic_tool_b") + + // Verify reminder is present in messages (checked via tool list — the wrapper inserts it). + // The model received messages, and the reminder contains "". + // We indirectly verify this by checking that the middleware ran without error and the + // 3-turn flow completed successfully, which requires the tool_search tool to work. + + // Additional: verify that the reminder contains the dynamic tool names. + mwImpl := mw.(*middleware) + assert.True(t, strings.Contains(mwImpl.sr, "dynamic_tool_a")) + assert.True(t, strings.Contains(mwImpl.sr, "dynamic_tool_b")) + assert.True(t, strings.Contains(mwImpl.sr, "")) } -func TestGetToolNames(t *testing.T) { +// --------------------------------------------------------------------------- +// TestNew — error paths for New() +// --------------------------------------------------------------------------- + +func TestNew(t *testing.T) { ctx := context.Background() - t.Run("returns tool names", func(t *testing.T) { - tools := []tool.BaseTool{ - newMockTool("tool1", "desc1"), - newMockTool("tool2", "desc2"), - newMockTool("tool3", "desc3"), - } - names, err := getToolNames(ctx, tools) - assert.NoError(t, err) - assert.Equal(t, []string{"tool1", "tool2", "tool3"}, names) + t.Run("nil config", func(t *testing.T) { + _, err := New(ctx, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "config is required") }) - t.Run("empty tools returns empty slice", func(t *testing.T) { - names, err := getToolNames(ctx, []tool.BaseTool{}) - assert.NoError(t, err) - assert.Empty(t, names) + t.Run("empty DynamicTools", func(t *testing.T) { + _, err := New(ctx, &Config{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "tools is required") }) -} - -func TestExtractSelectedTools(t *testing.T) { - ctx := context.Background() - - t.Run("extracts selected tools from messages", func(t *testing.T) { - result := toolSearchResult{SelectedTools: []string{"tool1", "tool2"}} - resultJSON, _ := json.Marshal(result) - messages := []*schema.Message{ - schema.UserMessage("hello"), - {Role: schema.Tool, ToolName: toolSearchToolName, Content: string(resultJSON)}, - } - - selected, err := extractSelectedTools(ctx, messages) - assert.NoError(t, err) - assert.ElementsMatch(t, []string{"tool1", "tool2"}, selected) + t.Run("success", func(t *testing.T) { + st := &simpleTool{name: "t1", desc: "tool 1"} + mw, err := New(ctx, &Config{DynamicTools: []tool.BaseTool{st}}) + require.NoError(t, err) + assert.NotNil(t, mw) }) +} - t.Run("handles multiple tool_search results", func(t *testing.T) { - result1 := toolSearchResult{SelectedTools: []string{"tool1"}} - result1JSON, _ := json.Marshal(result1) - result2 := toolSearchResult{SelectedTools: []string{"tool2", "tool3"}} - result2JSON, _ := json.Marshal(result2) +// --------------------------------------------------------------------------- +// TestSplitCamelCase +// --------------------------------------------------------------------------- + +func TestSplitCamelCase(t *testing.T) { + tests := []struct { + input string + want []string + }{ + {"", nil}, + {"hello", []string{"hello"}}, + {"NotebookEdit", []string{"Notebook", "Edit"}}, + {"camelCase", []string{"camel", "Case"}}, + {"HTMLParser", []string{"HTML", "Parser"}}, + {"getURL", []string{"get", "URL"}}, + {"A", []string{"A"}}, + {"AB", []string{"AB"}}, + {"HTTP", []string{"HTTP"}}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := splitCamelCase(tt.input) + assert.Equal(t, tt.want, got) + }) + } +} - messages := []*schema.Message{ - {Role: schema.Tool, ToolName: toolSearchToolName, Content: string(result1JSON)}, - schema.UserMessage("continue"), - {Role: schema.Tool, ToolName: toolSearchToolName, Content: string(result2JSON)}, - } +// --------------------------------------------------------------------------- +// TestInsertReminder +// --------------------------------------------------------------------------- - selected, err := extractSelectedTools(ctx, messages) - assert.NoError(t, err) - assert.ElementsMatch(t, []string{"tool1", "tool2", "tool3"}, selected) - }) +func TestInsertReminder(t *testing.T) { + w := &wrapper{reminder: ""} - t.Run("ignores non-tool_search messages", func(t *testing.T) { - messages := []*schema.Message{ - schema.UserMessage("hello"), - {Role: schema.Tool, ToolName: "other_tool", Content: "some content"}, - {Role: schema.Assistant, Content: "response"}, + t.Run("normal: system then user", func(t *testing.T) { + input := []*schema.Message{ + {Role: schema.System, Content: "sys"}, + {Role: schema.User, Content: "hi"}, } - - selected, err := extractSelectedTools(ctx, messages) - assert.NoError(t, err) - assert.Empty(t, selected) + got := w.insertReminder(input) + require.Len(t, got, 3) + assert.Equal(t, schema.System, got[0].Role) + assert.Equal(t, schema.User, got[1].Role) + assert.Equal(t, "", got[1].Content) + assert.Equal(t, schema.User, got[2].Role) + assert.Equal(t, "hi", got[2].Content) }) - t.Run("returns error for invalid json", func(t *testing.T) { - messages := []*schema.Message{ - {Role: schema.Tool, ToolName: toolSearchToolName, Content: "invalid json"}, + t.Run("all system messages", func(t *testing.T) { + input := []*schema.Message{ + {Role: schema.System, Content: "sys1"}, + {Role: schema.System, Content: "sys2"}, } - - selected, err := extractSelectedTools(ctx, messages) - assert.Error(t, err) - assert.Nil(t, selected) + got := w.insertReminder(input) + require.Len(t, got, 3) + // Reminder appended at the end since no non-system message found during iteration. + assert.Equal(t, schema.System, got[0].Role) + assert.Equal(t, schema.System, got[1].Role) + assert.Equal(t, "", got[2].Content) }) -} -func TestInvertSelect(t *testing.T) { - t.Run("returns items not in selected", func(t *testing.T) { - all := []string{"a", "b", "c", "d"} - selected := []string{"b", "d"} - - result := invertSelect(all, selected) - assert.Len(t, result, 2) - _, hasA := result["a"] - _, hasC := result["c"] - assert.True(t, hasA) - assert.True(t, hasC) + t.Run("empty input", func(t *testing.T) { + got := w.insertReminder(nil) + require.Len(t, got, 1) + assert.Equal(t, "", got[0].Content) }) - t.Run("empty selected returns all", func(t *testing.T) { - all := []string{"a", "b", "c"} - selected := []string{} - - result := invertSelect(all, selected) - assert.Len(t, result, 3) - }) - - t.Run("all selected returns empty", func(t *testing.T) { - all := []string{"a", "b"} - selected := []string{"a", "b"} - - result := invertSelect(all, selected) - assert.Empty(t, result) - }) - - t.Run("works with integers", func(t *testing.T) { - all := []int{1, 2, 3, 4, 5} - selected := []int{2, 4} - - result := invertSelect(all, selected) - assert.Len(t, result, 3) - _, has1 := result[1] - _, has3 := result[3] - _, has5 := result[5] - assert.True(t, has1) - assert.True(t, has3) - assert.True(t, has5) + t.Run("no system messages", func(t *testing.T) { + input := []*schema.Message{ + {Role: schema.User, Content: "hi"}, + {Role: schema.Assistant, Content: "hello"}, + } + got := w.insertReminder(input) + require.Len(t, got, 3) + // Reminder inserted before the first non-system message. + assert.Equal(t, "", got[0].Content) + assert.Equal(t, "hi", got[1].Content) + assert.Equal(t, "hello", got[2].Content) }) } -func TestRemoveTools(t *testing.T) { - ctx := context.Background() - - t.Run("removes unselected dynamic tools", func(t *testing.T) { - allTools := []*schema.ToolInfo{ - {Name: "static_tool"}, - {Name: "dynamic_tool1"}, - {Name: "dynamic_tool2"}, - {Name: "dynamic_tool3"}, - } +// --------------------------------------------------------------------------- +// TestExtractSelectedTools +// --------------------------------------------------------------------------- - dynamicTools := []tool.BaseTool{ - newMockTool("dynamic_tool1", ""), - newMockTool("dynamic_tool2", ""), - newMockTool("dynamic_tool3", ""), - } +func TestExtractSelectedTools(t *testing.T) { + ctx := context.Background() - result := toolSearchResult{SelectedTools: []string{"dynamic_tool1"}} - resultJSON, _ := json.Marshal(result) + t.Run("accumulates from multiple tool_search results", func(t *testing.T) { messages := []*schema.Message{ - {Role: schema.Tool, ToolName: toolSearchToolName, Content: string(resultJSON)}, - } - - tools, err := removeTools(ctx, allTools, dynamicTools, messages) - assert.NoError(t, err) - assert.Len(t, tools, 2) - - toolNames := make([]string, len(tools)) - for i, t := range tools { - toolNames[i] = t.Name + {Role: schema.Tool, ToolName: toolSearchToolName, Content: `{"matches":["tool_a"]}`}, + {Role: schema.Tool, ToolName: toolSearchToolName, Content: `{"matches":["tool_b","tool_c"]}`}, } - assert.ElementsMatch(t, []string{"static_tool", "dynamic_tool1"}, toolNames) + got, err := extractSelectedTools(ctx, messages) + require.NoError(t, err) + assert.Equal(t, []string{"tool_a", "tool_b", "tool_c"}, got) }) - t.Run("remove all dynamic tools when no tool_search result", func(t *testing.T) { - allTools := []*schema.ToolInfo{ - {Name: "static_tool"}, - {Name: "dynamic_tool1"}, - } - - dynamicTools := []tool.BaseTool{ - newMockTool("dynamic_tool1", ""), - } - + t.Run("ignores non tool_search messages", func(t *testing.T) { messages := []*schema.Message{ - schema.UserMessage("hello"), + {Role: schema.User, Content: "hello"}, + {Role: schema.Tool, ToolName: "other_tool", Content: `{"matches":["should_ignore"]}`}, + {Role: schema.Assistant, Content: "world"}, + {Role: schema.Tool, ToolName: toolSearchToolName, Content: `{"matches":["tool_a"]}`}, } - - tools, err := removeTools(ctx, allTools, dynamicTools, messages) - assert.NoError(t, err) - assert.Len(t, tools, 1) - assert.Equal(t, "static_tool", tools[0].Name) + got, err := extractSelectedTools(ctx, messages) + require.NoError(t, err) + assert.Equal(t, []string{"tool_a"}, got) }) - t.Run("handles empty dynamic tools", func(t *testing.T) { - allTools := []*schema.ToolInfo{ - {Name: "static_tool1"}, - {Name: "static_tool2"}, + t.Run("malformed JSON returns error", func(t *testing.T) { + messages := []*schema.Message{ + {Role: schema.Tool, ToolName: toolSearchToolName, Content: `not json`}, } - - dynamicTools := []tool.BaseTool{} - messages := []*schema.Message{} - - tools, err := removeTools(ctx, allTools, dynamicTools, messages) - assert.NoError(t, err) - assert.Len(t, tools, 2) + _, err := extractSelectedTools(ctx, messages) + assert.Error(t, err) }) -} - -type mockChatModel struct { - generateFunc func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) - streamFunc func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) -} -func (m *mockChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { - if m.generateFunc != nil { - return m.generateFunc(ctx, input, opts...) - } - return &schema.Message{Role: schema.Assistant, Content: "response"}, nil + t.Run("nil messages returns nil", func(t *testing.T) { + got, err := extractSelectedTools(ctx, nil) + require.NoError(t, err) + assert.Nil(t, got) + }) } -func (m *mockChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { - if m.streamFunc != nil { - return m.streamFunc(ctx, input, opts...) - } - return nil, nil -} +// --------------------------------------------------------------------------- +// TestCalculateTools +// --------------------------------------------------------------------------- -func TestWrapper_Generate(t *testing.T) { +func TestCalculateTools(t *testing.T) { ctx := context.Background() - t.Run("filters tools based on tool_search result", func(t *testing.T) { - allTools := []*schema.ToolInfo{ - {Name: "static_tool"}, - {Name: "dynamic_tool1"}, - {Name: "dynamic_tool2"}, - } - - dynamicTools := []tool.BaseTool{ - newMockTool("dynamic_tool1", ""), - newMockTool("dynamic_tool2", ""), - } + staticTool := ti("static_tool", "static") + toolSearchInfo := getToolSearchToolInfo() + dynA := ti("dynamic_a", "A") + dynB := ti("dynamic_b", "B") - result := toolSearchResult{SelectedTools: []string{"dynamic_tool1"}} - resultJSON, _ := json.Marshal(result) + allTools := []*schema.ToolInfo{staticTool, toolSearchInfo, dynA, dynB} + dynamicTools := []*schema.ToolInfo{dynA, dynB} + t.Run("no selection: dynamic tools hidden", func(t *testing.T) { messages := []*schema.Message{ - schema.UserMessage("hello"), - {Role: schema.Tool, ToolName: toolSearchToolName, Content: string(resultJSON)}, + {Role: schema.User, Content: "hello"}, } + opts, err := calculateTools(ctx, allTools, dynamicTools, messages, false) + require.NoError(t, err) - w := &wrapper{ - allTools: allTools, - dynamicTools: dynamicTools, - cm: &mockChatModel{ - generateFunc: func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { - options := model.GetCommonOptions(nil, opts...) - assert.Len(t, options.Tools, 2) - assert.Equal(t, "static_tool", options.Tools[0].Name) - assert.Equal(t, "dynamic_tool1", options.Tools[1].Name) - return nil, nil - }, - }, + options := model.GetCommonOptions(nil, opts...) + var names []string + for _, t := range options.Tools { + names = append(names, t.Name) } - - _, err := w.Generate(ctx, messages) - assert.NoError(t, err) + sort.Strings(names) + assert.Equal(t, []string{"static_tool", "tool_search"}, names) }) -} - -func TestWrapper_Stream(t *testing.T) { - ctx := context.Background() - t.Run("filters tools based on tool_search result", func(t *testing.T) { - allTools := []*schema.ToolInfo{ - {Name: "static_tool"}, - {Name: "dynamic_tool1"}, - {Name: "dynamic_tool2"}, + t.Run("partial selection: selected tool visible", func(t *testing.T) { + messages := []*schema.Message{ + {Role: schema.Tool, ToolName: toolSearchToolName, Content: `{"matches":["dynamic_a"]}`}, } + opts, err := calculateTools(ctx, allTools, dynamicTools, messages, false) + require.NoError(t, err) - dynamicTools := []tool.BaseTool{ - newMockTool("dynamic_tool1", ""), - newMockTool("dynamic_tool2", ""), + options := model.GetCommonOptions(nil, opts...) + var names []string + for _, t := range options.Tools { + names = append(names, t.Name) } + sort.Strings(names) + assert.Equal(t, []string{"dynamic_a", "static_tool", "tool_search"}, names) + }) - result := toolSearchResult{SelectedTools: []string{"dynamic_tool1"}} - resultJSON, _ := json.Marshal(result) - - messages := []*schema.Message{ - schema.UserMessage("hello"), - {Role: schema.Tool, ToolName: toolSearchToolName, Content: string(resultJSON)}, - } + t.Run("useModelToolSearch: dynamic tools and tool_search removed from WithTools", func(t *testing.T) { + opts, err := calculateTools(ctx, allTools, dynamicTools, nil, true) + require.NoError(t, err) - w := &wrapper{ - allTools: allTools, - dynamicTools: dynamicTools, - cm: &mockChatModel{ - streamFunc: func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { - options := model.GetCommonOptions(nil, opts...) - assert.Len(t, options.Tools, 2) - assert.Equal(t, "static_tool", options.Tools[0].Name) - assert.Equal(t, "dynamic_tool1", options.Tools[1].Name) - return nil, nil - }, - }, + options := model.GetCommonOptions(nil, opts...) + var names []string + for _, t := range options.Tools { + names = append(names, t.Name) } - - stream, err := w.Stream(ctx, messages) - assert.NoError(t, err) - assert.Nil(t, stream) + assert.Equal(t, []string{"static_tool"}, names) + // ToolSearchTool should be set. + assert.NotNil(t, options.ToolSearchTool) + assert.Equal(t, toolSearchToolName, options.ToolSearchTool.Name) }) } diff --git a/components/model/option.go b/components/model/option.go index 936b0fbda..2222e14a1 100644 --- a/components/model/option.go +++ b/components/model/option.go @@ -28,6 +28,16 @@ type Options struct { TopP *float32 // Tools is a list of tools the model may call. Tools []*schema.ToolInfo + // DeferredTools is a list of tools to be registered with defer_loading=true + // for the model's built-in (server-side) tool search capability. + // These tools are sent to the model API but not loaded into context upfront — + // only their names and descriptions are visible to the model. The model's + // built-in tool search tool searches through them and loads matching ones + // on demand. + DeferredTools []*schema.ToolInfo + + ToolSearchTool *schema.ToolInfo + // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return a finish reason of "length". MaxTokens *int // Stop is the stop words for the model, which controls the stopping condition of the model. @@ -114,6 +124,33 @@ func WithTools(tools []*schema.ToolInfo) Option { } } +// WithToolSearchTool is the option to register a tool search tool with the model. +// When set, the model uses this tool to discover and load deferred tools on demand. +// Note: The tool search tool should NOT be included in WithTools. +func WithToolSearchTool(tool *schema.ToolInfo) Option { + return Option{ + apply: func(opts *Options) { + opts.ToolSearchTool = tool + }, + } +} + +// WithDeferredTools is the option to set deferred tools for the model's +// built-in (server-side) tool search. These tools are registered with +// defer_loading=true so the model can discover and load them on demand +// via its native tool search capability. +// Note: Deferred tools should NOT be included in WithTools. +func WithDeferredTools(tools []*schema.ToolInfo) Option { + if tools == nil { + tools = []*schema.ToolInfo{} + } + return Option{ + apply: func(opts *Options) { + opts.DeferredTools = tools + }, + } +} + // WithToolChoice sets the tool choice for the model. It also allows for providing a list of // tool names to constrain the model to a specific subset of the available tools. // Only available for ChatModel. diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 95e14c0df..43376c146 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -41,6 +41,7 @@ const ( ContentBlockTypeUserInputAudio ContentBlockType = "user_input_audio" ContentBlockTypeUserInputVideo ContentBlockType = "user_input_video" ContentBlockTypeUserInputFile ContentBlockType = "user_input_file" + ContentBlockTypeToolSearchResult ContentBlockType = "tool_search_result" ContentBlockTypeAssistantGenText ContentBlockType = "assistant_gen_text" ContentBlockTypeAssistantGenImage ContentBlockType = "assistant_gen_image" ContentBlockTypeAssistantGenAudio ContentBlockType = "assistant_gen_audio" @@ -134,6 +135,11 @@ type ContentBlock struct { // FunctionToolResult contains the result returned from a user-defined tool call. FunctionToolResult *FunctionToolResult `json:"function_tool_result,omitempty"` + // ToolSearchFunctionToolResult contains the result of a client-side custom tool search tool call. + // It carries the full definitions of newly discovered tools so that the model can + // recognize which tools have been added and are now available for invocation. + ToolSearchFunctionToolResult *ToolSearchFunctionToolResult `json:"tool_search_function_tool_result,omitempty"` + // ServerToolCall contains the invocation details for a provider built-in tool executed on the model server. ServerToolCall *ServerToolCall `json:"server_tool_call,omitempty"` @@ -300,6 +306,28 @@ type FunctionToolResult struct { Result string `json:"result,omitempty"` } +// ToolSearchFunctionToolResult represents the result of a client-side custom tool search +// function tool call. Unlike a regular FunctionToolResult, this carries a ToolSearchResult +// containing the full definitions of newly discovered tools, so the model can recognize +// which tools have been added and are now available for invocation. +type ToolSearchFunctionToolResult struct { + // CallID is the unique identifier for the tool call. + CallID string `json:"call_id,omitempty"` + + // Name specifies the function tool invoked. + Name string `json:"name"` + + // Result is the function tool result returned by the user + Result *ToolSearchResult `json:"result,omitempty"` +} + +func (t *ToolSearchFunctionToolResult) String() string { + if t.Result != nil { + return t.Result.String() + } + return "" +} + type ServerToolCall struct { // Name specifies the server-side tool invoked. // Supplied by the model server (e.g., `web_search` for OpenAI, `googleSearch` for Gemini). @@ -461,7 +489,7 @@ type assistantGenVariant interface { } type functionToolCallVariant interface { - FunctionToolCall | FunctionToolResult + FunctionToolCall | FunctionToolResult | ToolSearchFunctionToolResult } type serverToolCallVariant interface { @@ -487,6 +515,8 @@ func NewContentBlock[T contentBlockVariant](content *T) *ContentBlock { return &ContentBlock{Type: ContentBlockTypeUserInputVideo, UserInputVideo: b} case *UserInputFile: return &ContentBlock{Type: ContentBlockTypeUserInputFile, UserInputFile: b} + case *ToolSearchFunctionToolResult: + return &ContentBlock{Type: ContentBlockTypeToolSearchResult, ToolSearchFunctionToolResult: b} case *AssistantGenText: return &ContentBlock{Type: ContentBlockTypeAssistantGenText, AssistantGenText: b} case *AssistantGenImage: @@ -1028,6 +1058,11 @@ func concatChunksOfSameContentBlock(blocks []*ContentBlock) (*ContentBlock, erro func(b *ContentBlock) *UserInputFile { return b.UserInputFile }, concatUserInputFiles) + case ContentBlockTypeToolSearchResult: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *ToolSearchFunctionToolResult { return b.ToolSearchFunctionToolResult }, + concatToolSearchFunctionToolResult) + case ContentBlockTypeAssistantGenText: return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *AssistantGenText { return b.AssistantGenText }, @@ -1227,6 +1262,16 @@ func concatUserInputFiles(files []*UserInputFile) (*UserInputFile, error) { return nil, fmt.Errorf("cannot concat multiple user input files") } +func concatToolSearchFunctionToolResult(results []*ToolSearchFunctionToolResult) (*ToolSearchFunctionToolResult, error) { + if len(results) == 0 { + return nil, fmt.Errorf("no tool search results found") + } + if len(results) == 1 { + return results[0], nil + } + return nil, fmt.Errorf("cannot concat multiple tool search results") +} + func concatAssistantGenTexts(texts []*AssistantGenText) (ret *AssistantGenText, err error) { if len(texts) == 0 { return nil, fmt.Errorf("no assistant generated text found") @@ -1772,6 +1817,7 @@ func (m *AgenticMessage) String() string { } // String returns the string representation of ContentBlock. +// nolint func (b *ContentBlock) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf("type: %s\n", b.Type)) @@ -1801,6 +1847,10 @@ func (b *ContentBlock) String() string { if b.UserInputFile != nil { sb.WriteString(b.UserInputFile.String()) } + case ContentBlockTypeToolSearchResult: + if b.ToolSearchFunctionToolResult != nil { + sb.WriteString(b.ToolSearchFunctionToolResult.String()) + } case ContentBlockTypeAssistantGenText: if b.AssistantGenText != nil { sb.WriteString(b.AssistantGenText.String()) diff --git a/schema/agentic_message_test.go b/schema/agentic_message_test.go index e8a1003f5..10639f738 100644 --- a/schema/agentic_message_test.go +++ b/schema/agentic_message_test.go @@ -1592,6 +1592,8 @@ func TestNewContentBlock(t *testing.T) { block = NewContentBlock(v) case *UserInputFile: block = NewContentBlock(v) + case *ToolSearchFunctionToolResult: + block = NewContentBlock(v) case *AssistantGenText: block = NewContentBlock(v) case *AssistantGenImage: diff --git a/schema/message.go b/schema/message.go index 611bcedca..890af48ab 100644 --- a/schema/message.go +++ b/schema/message.go @@ -139,7 +139,6 @@ type ToolCall struct { Type string `json:"type"` // Function is the function call to be made. Function FunctionCall `json:"function"` - // Extra is used to store extra information for the tool call. Extra map[string]any `json:"extra,omitempty"` } @@ -222,6 +221,9 @@ type MessageInputPart struct { // File is the file input of the part, it's used when Type is "file_url". File *MessageInputFile `json:"file,omitempty"` + // ToolSearchResult holds the result of a tool search request, containing the matched tool names and their definitions. + ToolSearchResult *ToolSearchResult `json:"tool_search_result,omitempty"` + // Extra is used to store extra information. Extra map[string]any `json:"extra,omitempty"` } @@ -310,6 +312,9 @@ const ( // ToolPartTypeFile means the part is a file url. ToolPartTypeFile ToolPartType = "file" + + // ToolPartTypeToolSearchResult means the part contains tool search results. + ToolPartTypeToolSearchResult ToolPartType = "tool_search_result" ) // ToolOutputImage represents an image in tool output. @@ -336,6 +341,27 @@ type ToolOutputFile struct { MessagePartCommon } +// ToolSearchResult represents the result of a tool search operation. +// When a model issues a tool search call, the framework searches for matching tools +// and returns the results via this struct. +type ToolSearchResult struct { + // Tools contains the full definitions of matched tools that were not previously + // registered. Their complete definitions are required so that the model can + // understand their parameters and usage. + Tools []*ToolInfo +} + +func (t *ToolSearchResult) String() string { + sb := new(strings.Builder) + sb.WriteString("ToolSearchResult[") + for _, tool := range t.Tools { + sb.WriteString(tool.Name) + sb.WriteString(",") + } + sb.WriteString("]") + return sb.String() +} + // ToolOutputPart represents a part of tool execution output. // It supports streaming scenarios through the Index field for chunk merging. type ToolOutputPart struct { @@ -358,6 +384,9 @@ type ToolOutputPart struct { // File is the file content, used when Type is ToolPartTypeFile. File *ToolOutputFile `json:"file,omitempty"` + // ToolSearchResult holds the tool search results, used when Type is ToolPartTypeToolSearchResult. + ToolSearchResult *ToolSearchResult `json:"tool_search_result,omitempty"` + // Extra is used to store extra information. Extra map[string]any `json:"extra,omitempty"` } @@ -422,6 +451,14 @@ func convToolOutputPartToMessageInputPart(toolPart ToolOutputPart) (MessageInput File: &MessageInputFile{MessagePartCommon: toolPart.File.MessagePartCommon}, Extra: toolPart.Extra, }, nil + case ToolPartTypeToolSearchResult: + if toolPart.ToolSearchResult == nil { + return MessageInputPart{}, fmt.Errorf("tool search result is nil for tool part type %v", toolPart.Type) + } + return MessageInputPart{ + Type: ChatMessagePartTypeToolSearchResult, + ToolSearchResult: toolPart.ToolSearchResult, + }, nil default: return MessageInputPart{}, fmt.Errorf("unknown tool part type: %v", toolPart.Type) } @@ -498,6 +535,9 @@ const ( ChatMessagePartTypeFileURL ChatMessagePartType = "file_url" // ChatMessagePartTypeReasoning means the part is a reasoning block. ChatMessagePartTypeReasoning ChatMessagePartType = "reasoning" + + // ChatMessagePartTypeToolSearchResult means the part contains tool search results. + ChatMessagePartTypeToolSearchResult ChatMessagePartType = "tool_search_result" ) // Deprecated: This struct is deprecated as the MultiContent field is deprecated. diff --git a/schema/tool.go b/schema/tool.go index a49306047..f8a0a743e 100644 --- a/schema/tool.go +++ b/schema/tool.go @@ -17,6 +17,9 @@ package schema import ( + "bytes" + "encoding/gob" + "encoding/json" "sort" "github.com/eino-contrib/jsonschema" @@ -137,6 +140,104 @@ type ToolInfo struct { *ParamsOneOf } +type toolInfoForJSON struct { + Name string `json:"name,omitempty"` + Desc string `json:"desc,omitempty"` + Extra map[string]any `json:"extra,omitempty"` + HasParamsOneOf bool `json:"has_params_one_of,omitempty"` + Params map[string]*ParameterInfo `json:"params,omitempty"` + JSONSchema *jsonschema.Schema `json:"json_schema,omitempty"` +} + +type toolInfoForGob struct { + Name string + Desc string + Extra map[string]any + HasParamsOneOf bool + Params map[string]*ParameterInfo + JSONSchema *string +} + +func (t *ToolInfo) MarshalJSON() ([]byte, error) { + tmp := &toolInfoForJSON{ + Name: t.Name, + Desc: t.Desc, + Extra: t.Extra, + } + if t.ParamsOneOf != nil { + tmp.HasParamsOneOf = true + tmp.Params = t.ParamsOneOf.params + tmp.JSONSchema = t.ParamsOneOf.jsonschema + } + return json.Marshal(tmp) +} + +func (t *ToolInfo) UnmarshalJSON(data []byte) error { + tmp := &toolInfoForJSON{} + if err := json.Unmarshal(data, tmp); err != nil { + return err + } + t.Name = tmp.Name + t.Desc = tmp.Desc + t.Extra = tmp.Extra + if tmp.HasParamsOneOf { + t.ParamsOneOf = &ParamsOneOf{ + params: tmp.Params, + jsonschema: tmp.JSONSchema, + } + } + return nil +} + +func (t *ToolInfo) GobEncode() ([]byte, error) { + tmp := &toolInfoForGob{ + Name: t.Name, + Desc: t.Desc, + Extra: t.Extra, + } + if t.ParamsOneOf != nil { + tmp.HasParamsOneOf = true + tmp.Params = t.ParamsOneOf.params + if t.ParamsOneOf.jsonschema != nil { + b, err := json.Marshal(t.ParamsOneOf.jsonschema) + if err != nil { + return nil, err + } + str := string(b) + tmp.JSONSchema = &str + } + } + buf := new(bytes.Buffer) + if err := gob.NewEncoder(buf).Encode(tmp); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func (t *ToolInfo) GobDecode(b []byte) error { + tmp := &toolInfoForGob{} + if err := gob.NewDecoder(bytes.NewBuffer(b)).Decode(tmp); err != nil { + return err + } + t.Name = tmp.Name + t.Desc = tmp.Desc + t.Extra = tmp.Extra + if !tmp.HasParamsOneOf { + return nil + } + t.ParamsOneOf = &ParamsOneOf{ + params: tmp.Params, + } + if tmp.JSONSchema != nil { + s := &jsonschema.Schema{} + if err := json.Unmarshal([]byte(*tmp.JSONSchema), s); err != nil { + return err + } + t.ParamsOneOf.jsonschema = s + } + return nil +} + // ParameterInfo is the information of a parameter. // It is used to describe the parameters of a tool. type ParameterInfo struct { diff --git a/schema/tool_test.go b/schema/tool_test.go index 97af29be2..e8f95c364 100644 --- a/schema/tool_test.go +++ b/schema/tool_test.go @@ -17,6 +17,8 @@ package schema import ( + "bytes" + "encoding/gob" "encoding/json" "testing" @@ -133,3 +135,49 @@ func TestParamsOneOfToJSONSchema(t *testing.T) { }) } + +func TestToolInfoSerialization(t *testing.T) { + ti1 := &ToolInfo{ + ParamsOneOf: NewParamsOneOfByParams(map[string]*ParameterInfo{ + "a": { + Type: String, + Desc: "desc", + }, + }), + } + ti2 := &ToolInfo{ + ParamsOneOf: NewParamsOneOfByJSONSchema(&jsonschema.Schema{ + Type: "string", + }), + } + + // json + b, err := json.Marshal(ti1) + assert.NoError(t, err) + result := &ToolInfo{} + err = json.Unmarshal(b, result) + assert.NoError(t, err) + assert.Equal(t, ti1, result) + b, err = json.Marshal(ti2) + assert.NoError(t, err) + result = &ToolInfo{} + err = json.Unmarshal(b, result) + assert.NoError(t, err) + assert.Equal(t, ti2, result) + + // gob + buf := new(bytes.Buffer) + err = gob.NewEncoder(buf).Encode(ti1) + assert.NoError(t, err) + result = &ToolInfo{} + err = gob.NewDecoder(buf).Decode(result) + assert.NoError(t, err) + assert.Equal(t, ti1, result) + buf = new(bytes.Buffer) + err = gob.NewEncoder(buf).Encode(ti2) + assert.NoError(t, err) + result = &ToolInfo{} + err = gob.NewDecoder(buf).Decode(result) + assert.NoError(t, err) + assert.Equal(t, ti2, result) +} From 99572ace4ac04144c3c5dcfdb1b243f39d4b6661 Mon Sep 17 00:00:00 2001 From: Born Date: Fri, 10 Apr 2026 16:23:12 +0800 Subject: [PATCH 53/65] fix(adk): propagate missing ToolsNodeConfig fields in ChatModelAgent (#945) - Add ToolAliases to prepareExecContext when building ToolsNodeConfig - Add UnknownToolsHandler, ExecuteSequentially, ToolArgumentsHandler, and ToolAliases to applyBeforeAgent when rebuilding after BeforeAgent handlers modify tools - Add tests covering argument alias remapping, name alias dispatch, alias preservation after handler rebuild, and handler-only tool registration with pre-configured aliases --- adk/chatmodel.go | 9 +- adk/chatmodel_test.go | 317 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 324 insertions(+), 2 deletions(-) diff --git a/adk/chatmodel.go b/adk/chatmodel.go index 83d8ffd4a..abfc55fa0 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -683,8 +683,12 @@ func (a *ChatModelAgent) applyBeforeAgent(ctx context.Context, ec *execContext) runtimeEC := &execContext{ instruction: runCtx.Instruction, toolsNodeConf: compose.ToolsNodeConfig{ - Tools: runCtx.Tools, - ToolCallMiddlewares: cloneSlice(ec.toolsNodeConf.ToolCallMiddlewares), + Tools: runCtx.Tools, + ToolCallMiddlewares: cloneSlice(ec.toolsNodeConf.ToolCallMiddlewares), + UnknownToolsHandler: ec.toolsNodeConf.UnknownToolsHandler, + ExecuteSequentially: ec.toolsNodeConf.ExecuteSequentially, + ToolArgumentsHandler: ec.toolsNodeConf.ToolArgumentsHandler, + ToolAliases: ec.toolsNodeConf.ToolAliases, }, returnDirectly: runCtx.ReturnDirectly, toolUpdated: true, @@ -710,6 +714,7 @@ func (a *ChatModelAgent) prepareExecContext(ctx context.Context) (*execContext, UnknownToolsHandler: a.toolsConfig.UnknownToolsHandler, ExecuteSequentially: a.toolsConfig.ExecuteSequentially, ToolArgumentsHandler: a.toolsConfig.ToolArgumentsHandler, + ToolAliases: a.toolsConfig.ToolAliases, } returnDirectly := copyMap(a.toolsConfig.ReturnDirectly) diff --git a/adk/chatmodel_test.go b/adk/chatmodel_test.go index f3ff6ea05..0edab5a2d 100644 --- a/adk/chatmodel_test.go +++ b/adk/chatmodel_test.go @@ -18,11 +18,13 @@ package adk import ( "context" + "encoding/json" "errors" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "github.com/cloudwego/eino/components/model" @@ -2098,3 +2100,318 @@ func TestNewChatModelAgent_FailoverConfigValidation(t *testing.T) { assert.Contains(t, err.Error(), "ModelFailoverConfig.ShouldFailover") }) } + +// aliasCaptureTool captures the raw arguments JSON received by the tool. +type aliasCaptureTool struct { + name string + params map[string]*schema.ParameterInfo + receivedArgs string +} + +func (t *aliasCaptureTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: t.name + " tool", + ParamsOneOf: schema.NewParamsOneOfByParams(t.params), + }, nil +} + +func (t *aliasCaptureTool) InvokableRun(_ context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) { + t.receivedArgs = argumentsInJSON + return "ok", nil +} + +func TestToolAliasesPropagation(t *testing.T) { + t.Run("prepareExecContext_propagates_ToolAliases", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + + captureTool := &aliasCaptureTool{ + name: "grep", + params: map[string]*schema.ParameterInfo{ + "pattern": {Type: schema.String, Desc: "regex pattern"}, + "path": {Type: schema.String, Desc: "search path"}, + }, + } + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + generateCount := 0 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { + generateCount++ + if generateCount == 1 { + return &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "grep", + Arguments: `{"grep_content": "TODO", "path": "/src"}`, + }, + }, + }, + }, nil + } + return schema.AssistantMessage("done", nil), nil + }).AnyTimes() + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "test", + Instruction: "test", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{captureTool}, + ToolAliases: map[string]compose.ToolAliasConfig{ + "grep": { + ArgumentsAliases: map[string][]string{ + "pattern": {"grep_content"}, + }, + }, + }, + }, + }, + }) + require.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("search for TODOs")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + require.NotEmpty(t, captureTool.receivedArgs, "tool should have been called") + var args map[string]any + err = json.Unmarshal([]byte(captureTool.receivedArgs), &args) + require.NoError(t, err) + assert.Equal(t, "TODO", args["pattern"], "alias 'grep_content' should be remapped to 'pattern'") + assert.NotContains(t, args, "grep_content", "alias key should not be present after remapping") + assert.Equal(t, "/src", args["path"]) + }) + + t.Run("applyBeforeAgent_preserves_ToolAliases_when_handler_modifies_tools", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + + captureTool := &aliasCaptureTool{ + name: "grep", + params: map[string]*schema.ParameterInfo{ + "pattern": {Type: schema.String, Desc: "regex pattern"}, + }, + } + + extraTool := &aliasCaptureTool{ + name: "extra_tool", + params: map[string]*schema.ParameterInfo{ + "input": {Type: schema.String}, + }, + } + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + generateCount := 0 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { + generateCount++ + if generateCount == 1 { + return &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "grep", + Arguments: `{"grep_content": "FIXME"}`, + }, + }, + }, + }, nil + } + return schema.AssistantMessage("done", nil), nil + }).AnyTimes() + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + handler := &testToolsHandler{ + BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, + tools: []tool.BaseTool{extraTool}, + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "test", + Instruction: "test", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{captureTool}, + ToolAliases: map[string]compose.ToolAliasConfig{ + "grep": { + ArgumentsAliases: map[string][]string{ + "pattern": {"grep_content"}, + }, + }, + }, + }, + }, + Handlers: []ChatModelAgentMiddleware{handler}, + }) + require.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("search for FIXMEs")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + require.NotEmpty(t, captureTool.receivedArgs, "tool should have been called") + var args map[string]any + err = json.Unmarshal([]byte(captureTool.receivedArgs), &args) + require.NoError(t, err) + assert.Equal(t, "FIXME", args["pattern"], "alias 'grep_content' should be remapped to 'pattern' even after handler rebuild") + assert.NotContains(t, args, "grep_content", "alias key should not be present after remapping") + }) + + t.Run("name_alias_propagated_through_prepareExecContext", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + + captureTool := &aliasCaptureTool{ + name: "grep", + params: map[string]*schema.ParameterInfo{ + "pattern": {Type: schema.String}, + }, + } + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + generateCount := 0 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { + generateCount++ + if generateCount == 1 { + return &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "search_content", + Arguments: `{"pattern": "TODO"}`, + }, + }, + }, + }, nil + } + return schema.AssistantMessage("done", nil), nil + }).AnyTimes() + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "test", + Instruction: "test", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{captureTool}, + ToolAliases: map[string]compose.ToolAliasConfig{ + "grep": { + NameAliases: []string{"search_content"}, + }, + }, + }, + }, + }) + require.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("search")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + require.NotEmpty(t, captureTool.receivedArgs, "tool should have been called via name alias 'search_content'") + var args map[string]any + err = json.Unmarshal([]byte(captureTool.receivedArgs), &args) + require.NoError(t, err) + assert.Equal(t, "TODO", args["pattern"]) + }) + + t.Run("handler_adds_tool_matching_preexisting_ToolAliases_with_no_initial_tools", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + + captureTool := &aliasCaptureTool{ + name: "grep", + params: map[string]*schema.ParameterInfo{ + "pattern": {Type: schema.String, Desc: "regex pattern"}, + }, + } + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + generateCount := 0 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { + generateCount++ + if generateCount == 1 { + return &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "grep", + Arguments: `{"grep_content": "BUG"}`, + }, + }, + }, + }, nil + } + return schema.AssistantMessage("done", nil), nil + }).AnyTimes() + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + handler := &testToolsHandler{ + BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, + tools: []tool.BaseTool{captureTool}, + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "test", + Instruction: "test", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + ToolAliases: map[string]compose.ToolAliasConfig{ + "grep": { + ArgumentsAliases: map[string][]string{ + "pattern": {"grep_content"}, + }, + }, + }, + }, + }, + Handlers: []ChatModelAgentMiddleware{handler}, + }) + require.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("find bugs")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + require.NotEmpty(t, captureTool.receivedArgs, "tool added by handler should have been called") + var args map[string]any + err = json.Unmarshal([]byte(captureTool.receivedArgs), &args) + require.NoError(t, err) + assert.Equal(t, "BUG", args["pattern"], "alias 'grep_content' should be remapped to 'pattern' for handler-added tool") + assert.NotContains(t, args, "grep_content") + }) +} From b710dd866ba6f0729e236a6b89a68e8813e314c0 Mon Sep 17 00:00:00 2001 From: shentongmartin Date: Tue, 14 Apr 2026 14:14:41 +0800 Subject: [PATCH 54/65] refactor(adk): improve cancel propagation, encapsulate TurnLoop stop options, add UntilIdleFor (#942) --- adk/attack_test.go | 449 +++++++++++++ adk/cancel.go | 108 ++- adk/cancel_edge_test.go | 2 +- adk/cancel_recursive_test.go | 409 ++++++++++++ adk/cancel_test.go | 163 +++-- adk/turn_buffer.go | 128 ++++ adk/turn_loop.go | 372 ++++++++--- adk/turn_loop_test.go | 632 +++++++++++++----- adk/wrappers_test.go | 8 +- .../prompt/agentic_chat_template_test.go | 3 +- internal/channel.go | 39 +- internal/channel_test.go | 241 ------- 12 files changed, 1921 insertions(+), 633 deletions(-) create mode 100644 adk/attack_test.go create mode 100644 adk/cancel_recursive_test.go create mode 100644 adk/turn_buffer.go diff --git a/adk/attack_test.go b/adk/attack_test.go new file mode 100644 index 000000000..bfb4462ef --- /dev/null +++ b/adk/attack_test.go @@ -0,0 +1,449 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/schema" +) + +func TestAttack_UntilIdleFor_ConcurrentPushDuringIdleTimer(t *testing.T) { + turnCount := int32(0) + turnDone := make(chan struct{}, 10) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + atomic.AddInt32(&turnCount, 1) + turnDone <- struct{}{} + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-turnDone + + loop.Stop(UntilIdleFor(200 * time.Millisecond)) + + for i := 0; i < 5; i++ { + time.Sleep(50 * time.Millisecond) + loop.Push("concurrent-" + string(rune('a'+i))) + <-turnDone + } + + done := make(chan struct{}) + go func() { + loop.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("loop did not exit after idle timeout — Push did not reset timer correctly") + } + + finalCount := atomic.LoadInt32(&turnCount) + assert.Equal(t, int32(6), finalCount, "all 6 pushes should have been processed") +} + +func TestAttack_UntilIdleFor_MultipleStopCallsFirstWins(t *testing.T) { + turnDone := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(turnDone) + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-turnDone + + loop.Stop(UntilIdleFor(100 * time.Millisecond)) + loop.Stop(UntilIdleFor(10 * time.Minute)) + + done := make(chan struct{}) + go func() { + loop.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("second UntilIdleFor should have been ignored; loop should have exited with 100ms timer") + } +} + +func TestAttack_BareStopOverridesUntilIdleFor(t *testing.T) { + agentStarted := make(chan struct{}) + agentDone := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(agentStarted) + <-agentDone + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-agentStarted + + loop.Stop(UntilIdleFor(10 * time.Minute)) + + loop.Stop() + close(agentDone) + + done := make(chan struct{}) + go func() { + loop.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("bare Stop should override UntilIdleFor and cause immediate shutdown") + } + + exit := loop.Wait() + assert.NoError(t, exit.ExitReason, "bare Stop should exit cleanly") +} + +func TestAttack_StopSignal_NilCancelOptsDoNotDeescalate(t *testing.T) { + agentStarted := make(chan *cancelContext, 1) + probe := &turnLoopStopModeProbeAgent{ccCh: agentStarted} + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return probe, nil + }, + }) + + loop.Push("msg1") + cc := <-agentStarted + + loop.Stop(WithImmediate()) + + time.Sleep(20 * time.Millisecond) + + loop.Stop() + + time.Sleep(20 * time.Millisecond) + mode := cc.getMode() + assert.Equal(t, CancelImmediate, mode, "bare Stop after WithImmediate must not de-escalate cancel mode") + + exit := loop.Wait() + var ce *CancelError + require.True(t, errors.As(exit.ExitReason, &ce)) + assert.Equal(t, CancelImmediate, ce.Info.Mode) +} + +func TestAttack_CanceledItems_EmptyWhenAgentFinishesNormally(t *testing.T) { + agentStarted := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(agentStarted) + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-agentStarted + time.Sleep(50 * time.Millisecond) + loop.Stop() + + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + assert.Empty(t, exit.CanceledItems, "CanceledItems must be empty when agent finished normally") +} + +func TestAttack_TurnBuffer_WakeupDoesNotLoseItems(t *testing.T) { + tb := newTurnBuffer[string]() + + tb.Send("a") + tb.Send("b") + tb.Wakeup() + tb.Send("c") + + var got []string + for i := 0; i < 3; i++ { + val, ok := tb.Receive() + require.True(t, ok) + got = append(got, val) + } + + assert.Equal(t, []string{"a", "b", "c"}, got, "Wakeup must not cause items to be lost") +} + +func TestAttack_TurnBuffer_ClearWakeupPreventsSpuriousReturn(t *testing.T) { + tb := newTurnBuffer[string]() + + tb.Wakeup() + tb.ClearWakeup() + + received := make(chan string, 1) + go func() { + val, ok := tb.Receive() + if ok { + received <- val + } + }() + + time.Sleep(50 * time.Millisecond) + tb.Send("real") + + select { + case val := <-received: + assert.Equal(t, "real", val, "ClearWakeup should prevent spurious empty return") + case <-time.After(2 * time.Second): + t.Fatal("Receive blocked forever despite Send") + } +} + +func TestAttack_StopBeforeRun_UntilIdleFor_ExitsImmediately(t *testing.T) { + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Stop(UntilIdleFor(10 * time.Minute)) + loop.Stop() + + loop.Run(context.Background()) + + done := make(chan struct{}) + go func() { + loop.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("loop should exit immediately when Stop() called before Run()") + } +} + +func TestAttack_PushAfterStop_UntilIdleFor_RoutedToLateItems(t *testing.T) { + turnDone := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(turnDone) + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-turnDone + + loop.Stop(UntilIdleFor(50 * time.Millisecond)) + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + + ok, _ := loop.Push("after-stop") + assert.False(t, ok, "Push after loop exited should return false") + + late := exit.TakeLateItems() + assert.Equal(t, []string{"after-stop"}, late) +} + +func TestAttack_ConcurrentStopEscalation_RaceDetector(t *testing.T) { + agentStarted := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(agentStarted) + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + }) + + loop.Push("msg1") + <-agentStarted + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + switch i % 4 { + case 0: + loop.Stop() + case 1: + loop.Stop(WithImmediate()) + case 2: + loop.Stop(WithGracefulTimeout(100 * time.Millisecond)) + case 3: + loop.Stop(UntilIdleFor(50 * time.Millisecond)) + } + }(i) + } + + wg.Wait() + exit := loop.Wait() + t.Log("ExitReason:", exit.ExitReason) +} + +func TestAttack_StopCause_FirstNonEmptyWins_ConcurrentCallers(t *testing.T) { + turnDone := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(turnDone) + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-turnDone + + loop.Stop(WithStopCause("first-cause")) + loop.Stop(WithStopCause("second-cause")) + + exit := loop.Wait() + assert.Equal(t, "first-cause", exit.StopCause, "first non-empty StopCause should win") +} + +func TestAttack_SkipCheckpoint_Sticky(t *testing.T) { + agentStarted := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(agentStarted) + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + Store: &turnLoopCheckpointStore{m: make(map[string][]byte)}, + CheckpointID: "test-sticky", + }) + + loop.Push("msg1") + <-agentStarted + + loop.Stop(WithSkipCheckpoint()) + loop.Stop(WithImmediate()) + + exit := loop.Wait() + assert.False(t, exit.Checkpointed, "SkipCheckpoint is sticky; checkpoint should be skipped") +} diff --git a/adk/cancel.go b/adk/cancel.go index a15699f53..6d4aa9ad9 100644 --- a/adk/cancel.go +++ b/adk/cancel.go @@ -44,21 +44,21 @@ type CancelMode int const ( // CancelImmediate cancels the agent as soon as the signal is received, - // without waiting for a ChatModel or ToolCalls safe-point. Propagates - // to all descendant agents via the cancel context hierarchy, including - // agents nested inside AgentTools and workflow sub-agents. + // without waiting for a ChatModel or ToolCalls safe-point. + // By default, only the root agent is interrupted; descendant agents inside + // AgentTools are torn down via context cancellation as a side effect. + // Use WithRecursive to propagate explicit immediate-cancel signals to + // descendants for clean teardown with grace period. CancelImmediate CancelMode = 0 - // CancelAfterChatModel cancels after the first chat model call that completes - // anywhere in the agent hierarchy, including nested sub-agents, agent tools, - // and workflow branches. The cancel mode propagates to all descendant agents; - // whichever ChatModel finishes first triggers the cancel. The interrupting - // agent emits an interrupt that bubbles up through the agent tree — parent - // agents do not need to reach their own ChatModel safe-point. + // CancelAfterChatModel cancels after the root agent's next chat model call + // completes. By default, only the root agent checks this safe-point; + // nested sub-agents inside AgentTools are unaware of the cancel. + // Use WithRecursive to propagate the cancel to all descendants — whichever + // ChatModel finishes first triggers the cancel. CancelAfterChatModel CancelMode = 1 << iota - // CancelAfterToolCalls cancels after the first set of concurrent tool calls - // that completes anywhere in the agent hierarchy. Like CancelAfterChatModel, - // this mode propagates to all descendants and fires at whichever level - // reaches the safe-point first. + // CancelAfterToolCalls cancels after the root agent's next set of concurrent + // tool calls completes. By default, only the root agent checks this safe-point. + // Use WithRecursive to propagate to all descendants. CancelAfterToolCalls ) @@ -92,8 +92,9 @@ func (h *CancelHandle) Wait() error { type AgentCancelFunc func(...AgentCancelOption) (*CancelHandle, bool) type agentCancelConfig struct { - Mode CancelMode - Timeout *time.Duration + Mode CancelMode + Recursive bool + Timeout *time.Duration } // AgentCancelOption configures cancel behavior. @@ -118,6 +119,21 @@ func WithAgentCancelTimeout(timeout time.Duration) AgentCancelOption { } } +// WithRecursive opts into recursive cancel propagation. By default, cancel +// modes only affect the root agent; descendant agents inside AgentTools are +// not notified. WithRecursive makes the cancel propagate to all descendants: +// - CancelAfterChatModel / CancelAfterToolCalls: descendants check their own safe-points. +// - CancelImmediate: descendants receive explicit immediate-cancel signals for +// clean teardown; the root uses a grace period to collect child interrupts. +// +// Once any cancel call includes WithRecursive, the flag stays set for the +// entire cancel lifecycle (monotonic escalation). +func WithRecursive() AgentCancelOption { + return func(config *agentCancelConfig) { + config.Recursive = true + } +} + // AgentCancelInfo contains information about a cancel operation. type AgentCancelInfo struct { Mode CancelMode @@ -296,6 +312,9 @@ type cancelContext struct { startedMode int32 // atomic, mode when state transitioned to cancelling deadlineUnixNano int64 // atomic, 0 means no deadline + recursive int32 // atomic; 1 if cancel should propagate to descendant agents via deriveChild + recursiveChan chan struct{} // closed when recursive transitions from 0 to 1 + root bool // true for the original cancelContext created by WithCancel(); false for derived children parent *cancelContext // non-nil for derived children; used to decrement parent's activeChildren on markDone @@ -316,6 +335,7 @@ func newCancelContext() *cancelContext { immediateChan: make(chan struct{}), doneChan: make(chan struct{}), timeoutNotify: make(chan struct{}, 1), + recursiveChan: make(chan struct{}), root: true, } } @@ -324,6 +344,18 @@ func (cc *cancelContext) isRoot() bool { return cc != nil && cc.root } +func (cc *cancelContext) isRecursive() bool { + return cc != nil && atomic.LoadInt32(&cc.recursive) == 1 +} + +// setRecursive(false) is a no-op; recursive is monotonically escalating: +// once set to true, it cannot be reverted. +func (cc *cancelContext) setRecursive(v bool) { + if v && atomic.CompareAndSwapInt32(&cc.recursive, 0, 1) { + close(cc.recursiveChan) + } +} + // deriveChild creates a child cancelContext that receives cancel propagation // from the parent. The caller MUST ensure the child's markDone() is eventually // called (e.g., via wrapIterWithCancelCtx's defer) or that ctx is canceled; @@ -337,10 +369,29 @@ func (cc *cancelContext) deriveChild(ctx context.Context) *cancelContext { child.parent = cc atomic.AddInt32(&cc.activeChildren, 1) + // Each goroutine below propagates one signal class (cancel / immediate) to + // the child. The pattern is a two-phase select: + // Phase 1: wait for the parent signal (or child/ctx completion). + // Phase 2: if the signal fired but recursive mode is not active yet, + // enter a second select waiting for either recursive escalation + // (recursiveChan) or child/ctx completion. This ensures + // non-recursive cancels leave children unaware, while a late + // escalation to recursive still propagates. go func() { select { case <-cc.cancelChan: - child.triggerCancel(cc.getMode()) + if cc.isRecursive() { + child.setRecursive(true) + child.triggerCancel(cc.getMode()) + return + } + select { + case <-cc.recursiveChan: + child.setRecursive(true) + child.triggerCancel(cc.getMode()) + case <-child.doneChan: + case <-ctx.Done(): + } case <-child.doneChan: case <-ctx.Done(): } @@ -349,7 +400,18 @@ func (cc *cancelContext) deriveChild(ctx context.Context) *cancelContext { go func() { select { case <-cc.immediateChan: - child.triggerImmediateCancel() + if cc.isRecursive() { + child.setRecursive(true) + child.triggerImmediateCancel() + return + } + select { + case <-cc.recursiveChan: + child.setRecursive(true) + child.triggerImmediateCancel() + case <-child.doneChan: + case <-ctx.Done(): + } case <-child.doneChan: case <-ctx.Done(): } @@ -504,7 +566,10 @@ func (cc *cancelContext) hasActiveChildren() bool { func (cc *cancelContext) wrapGraphInterruptWithGracePeriod(interrupt func(...compose.GraphInterruptOption)) func(...compose.GraphInterruptOption) { return func(opts ...compose.GraphInterruptOption) { - if cc.hasActiveChildren() { + // Grace period only applies in recursive mode: in shallow mode, + // children are unaware of the cancel and don't need time to propagate + // their interrupt signals back. + if cc.isRecursive() && cc.hasActiveChildren() { newOpts := make([]compose.GraphInterruptOption, len(opts)+1) copy(newOpts, opts) newOpts[len(opts)] = compose.WithGraphInterruptTimeout(defaultCancelImmediateGracePeriod) @@ -682,10 +747,17 @@ func (cc *cancelContext) buildCancelFunc() AgentCancelFunc { curMode = req.Mode cc.setMode(curMode) atomic.StoreInt32(&cc.startedMode, int32(curMode)) + cc.setRecursive(req.Recursive) close(cc.cancelChan) } else { + // Recursive is monotonic: once set, cannot be unset. The first + // cancel call uses the bool directly; subsequent calls only + // escalate (false → true) — setRecursive(false) is a no-op. curMode = joinMode(curMode, req.Mode) cc.setMode(curMode) + if req.Recursive { + cc.setRecursive(true) + } } if curMode == CancelImmediate { diff --git a/adk/cancel_edge_test.go b/adk/cancel_edge_test.go index d3fb02a1a..b0afbe674 100644 --- a/adk/cancel_edge_test.go +++ b/adk/cancel_edge_test.go @@ -1242,7 +1242,7 @@ func TestWithCancel_CancelAfterChatModel_NestedAgentTool(t *testing.T) { cancelDone := make(chan error, 1) go func() { - handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel), WithRecursive()) cancelDone <- handle.Wait() }() diff --git a/adk/cancel_recursive_test.go b/adk/cancel_recursive_test.go new file mode 100644 index 000000000..9f13f55d2 --- /dev/null +++ b/adk/cancel_recursive_test.go @@ -0,0 +1,409 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "runtime" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/compose" +) + +func assertNotClosedWithin(t *testing.T, ch <-chan struct{}, d time.Duration) { + t.Helper() + select { + case <-ch: + t.Fatal("channel was closed but should not have been") + case <-time.After(d): + } +} + +func setupParentChild(t *testing.T) (parent, child *cancelContext, cleanup func()) { + parent = newCancelContext() + ctx, cancel := context.WithCancel(context.Background()) + child = parent.deriveChild(ctx) + cleanup = func() { + child.markDone() + cancel() + } + t.Cleanup(cleanup) + return parent, child, cleanup +} + +func TestDeriveChild(t *testing.T) { + t.Run("Shallow", func(t *testing.T) { + t.Run("DoesNotPropagateSafePoint", func(t *testing.T) { + parent, child, _ := setupParentChild(t) + + parent.triggerCancel(CancelAfterChatModel) + + assertNotClosedWithin(t, child.cancelChan, 50*time.Millisecond) + }) + + t.Run("ImmediateDoesNotPropagate", func(t *testing.T) { + parent, child, _ := setupParentChild(t) + + parent.triggerImmediateCancel() + + assertNotClosedWithin(t, child.immediateChan, 50*time.Millisecond) + }) + + t.Run("GrandchildNoPropagation", func(t *testing.T) { + a := newCancelContext() + ctx, cancel := context.WithCancel(context.Background()) + + b := a.deriveChild(ctx) + c := b.deriveChild(ctx) + t.Cleanup(func() { + c.markDone() + b.markDone() + cancel() + }) + + a.triggerCancel(CancelAfterChatModel) + + assertNotClosedWithin(t, b.cancelChan, 50*time.Millisecond) + assertNotClosedWithin(t, c.cancelChan, 50*time.Millisecond) + }) + + t.Run("NeverRecursive_GoroutineCleanup", func(t *testing.T) { + runtime.GC() + time.Sleep(50 * time.Millisecond) + before := runtime.NumGoroutine() + + parent := newCancelContext() + ctx, cancel := context.WithCancel(context.Background()) + + child := parent.deriveChild(ctx) + + parent.triggerCancel(CancelAfterChatModel) + time.Sleep(100 * time.Millisecond) + + child.markDone() + cancel() + + time.Sleep(200 * time.Millisecond) + runtime.GC() + time.Sleep(50 * time.Millisecond) + after := runtime.NumGoroutine() + + assert.InDelta(t, before, after, 5, "goroutine leak detected: before=%d after=%d", before, after) + }) + }) + + t.Run("Recursive", func(t *testing.T) { + t.Run("PropagatesSafePoint", func(t *testing.T) { + parent, child, _ := setupParentChild(t) + + parent.setRecursive(true) + parent.triggerCancel(CancelAfterChatModel) + + select { + case <-child.cancelChan: + case <-time.After(1 * time.Second): + t.Fatal("child did not receive cancel within 1s") + } + assert.True(t, child.shouldCancel()) + }) + + t.Run("ImmediatePropagates", func(t *testing.T) { + parent, child, _ := setupParentChild(t) + + parent.setRecursive(true) + parent.triggerImmediateCancel() + + select { + case <-child.immediateChan: + case <-time.After(1 * time.Second): + t.Fatal("child did not receive immediate cancel within 1s") + } + assert.True(t, child.isImmediateCancelled()) + }) + + t.Run("GrandchildPropagation", func(t *testing.T) { + a := newCancelContext() + ctx, cancel := context.WithCancel(context.Background()) + + b := a.deriveChild(ctx) + c := b.deriveChild(ctx) + t.Cleanup(func() { + c.markDone() + b.markDone() + cancel() + }) + + a.setRecursive(true) + a.triggerCancel(CancelAfterChatModel) + + select { + case <-b.cancelChan: + case <-time.After(1 * time.Second): + t.Fatal("B did not receive cancel within 1s") + } + + select { + case <-c.cancelChan: + case <-time.After(1 * time.Second): + t.Fatal("C did not receive cancel within 1s") + } + + assert.True(t, b.shouldCancel()) + assert.True(t, c.shouldCancel()) + }) + + t.Run("SetBeforeCancel", func(t *testing.T) { + parent, child, _ := setupParentChild(t) + + parent.setRecursive(true) + + parent.triggerCancel(CancelAfterChatModel) + + select { + case <-child.cancelChan: + case <-time.After(1 * time.Second): + t.Fatal("child did not receive cancel within 1s") + } + assert.True(t, child.shouldCancel()) + }) + + t.Run("AfterRecursiveAndCancelAlreadySet", func(t *testing.T) { + parent := newCancelContext() + ctx, cancel := context.WithCancel(context.Background()) + + parent.setRecursive(true) + parent.triggerCancel(CancelAfterChatModel) + + child := parent.deriveChild(ctx) + t.Cleanup(func() { + child.markDone() + cancel() + }) + + select { + case <-child.cancelChan: + case <-time.After(1 * time.Second): + t.Fatal("child did not immediately receive cancel") + } + assert.True(t, child.shouldCancel()) + }) + }) + + t.Run("Escalation", func(t *testing.T) { + t.Run("EscalateFromNonRecursive", func(t *testing.T) { + parent, child, _ := setupParentChild(t) + + parent.triggerCancel(CancelAfterChatModel) + + assertNotClosedWithin(t, child.cancelChan, 50*time.Millisecond) + + parent.setRecursive(true) + + select { + case <-child.cancelChan: + case <-time.After(1 * time.Second): + t.Fatal("child did not receive cancel after escalation within 1s") + } + assert.True(t, child.shouldCancel()) + }) + + t.Run("EscalateImmediate", func(t *testing.T) { + parent, child, _ := setupParentChild(t) + + parent.triggerImmediateCancel() + + assertNotClosedWithin(t, child.immediateChan, 50*time.Millisecond) + + parent.setRecursive(true) + + select { + case <-child.immediateChan: + case <-time.After(1 * time.Second): + t.Fatal("child did not receive immediate cancel after escalation within 1s") + } + assert.True(t, child.isImmediateCancelled()) + }) + }) +} + +func TestDeriveChild_Race(t *testing.T) { + t.Run("SetRecursiveConcurrentWithCancelChan", func(t *testing.T) { + for i := 0; i < 100; i++ { + parent := newCancelContext() + ctx, cancel := context.WithCancel(context.Background()) + + child := parent.deriveChild(ctx) + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + parent.setRecursive(true) + }() + + go func() { + defer wg.Done() + parent.triggerCancel(CancelAfterChatModel) + }() + + wg.Wait() + + select { + case <-child.cancelChan: + case <-time.After(1 * time.Second): + t.Fatalf("iteration %d: child did not receive cancel within 1s", i) + } + + assert.True(t, child.shouldCancel()) + child.markDone() + cancel() + } + }) + + t.Run("ChildCompletesBeforeEscalation", func(t *testing.T) { + parent := newCancelContext() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + child := parent.deriveChild(ctx) + + parent.triggerCancel(CancelAfterChatModel) + time.Sleep(50 * time.Millisecond) + + child.markDone() + time.Sleep(50 * time.Millisecond) + + parent.setRecursive(true) + + assertNotClosedWithin(t, child.cancelChan, 50*time.Millisecond) + }) + + t.Run("MultipleChildren_PartialCompletion", func(t *testing.T) { + parent := newCancelContext() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + child1 := parent.deriveChild(ctx) + child2 := parent.deriveChild(ctx) + + parent.triggerCancel(CancelAfterChatModel) + time.Sleep(50 * time.Millisecond) + + child1.markDone() + time.Sleep(50 * time.Millisecond) + + parent.setRecursive(true) + + select { + case <-child2.cancelChan: + case <-time.After(1 * time.Second): + t.Fatal("running child did not receive cancel within 1s") + } + + assert.True(t, child2.shouldCancel()) + assert.False(t, child1.shouldCancel()) + child2.markDone() + }) + + t.Run("ContextCancelConcurrentWithRecursive", func(t *testing.T) { + done := make(chan struct{}) + go func() { + defer close(done) + + parent := newCancelContext() + ctx, cancel := context.WithCancel(context.Background()) + + child := parent.deriveChild(ctx) + + parent.triggerCancel(CancelAfterChatModel) + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + cancel() + }() + + go func() { + defer wg.Done() + parent.setRecursive(true) + }() + + wg.Wait() + child.markDone() + }() + + select { + case <-done: + case <-time.After(1 * time.Second): + t.Fatal("deadlock detected") + } + }) + + t.Run("ConcurrentSetRecursive", func(t *testing.T) { + parent := newCancelContext() + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + parent.setRecursive(true) + }() + } + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(1 * time.Second): + t.Fatal("deadlock or panic in concurrent setRecursive") + } + + assert.True(t, parent.isRecursive()) + }) +} + +func TestGracePeriod_OnlyWhenRecursive(t *testing.T) { + parent, _, _ := setupParentChild(t) + + var nonRecursiveOptCount int + wrappedNonRecursive := parent.wrapGraphInterruptWithGracePeriod(func(opts ...compose.GraphInterruptOption) { + nonRecursiveOptCount = len(opts) + }) + wrappedNonRecursive() + assert.Equal(t, 0, nonRecursiveOptCount) + + parent.setRecursive(true) + + var recursiveOptCount int + wrappedRecursive := parent.wrapGraphInterruptWithGracePeriod(func(opts ...compose.GraphInterruptOption) { + recursiveOptCount = len(opts) + }) + wrappedRecursive() + assert.Equal(t, 1, recursiveOptCount) +} diff --git a/adk/cancel_test.go b/adk/cancel_test.go index 105c9ea13..2096a9ac3 100644 --- a/adk/cancel_test.go +++ b/adk/cancel_test.go @@ -142,6 +142,51 @@ func (s *cancelTestStore) Get(_ context.Context, key string) ([]byte, bool, erro return v, ok, nil } +func assertHasCancelError(t *testing.T, events []*AgentEvent) { + t.Helper() + for _, e := range events { + var ce *CancelError + if e.Err != nil && errors.As(e.Err, &ce) { + return + } + } + t.Fatal("expected CancelError in events") +} + +func drainAndAssertCancelError(t *testing.T, iter *AsyncIterator[*AgentEvent]) { + t.Helper() + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + return + } + } + t.Fatal("expected CancelError in event stream") +} + +func drainEventsAndAssertCancelError(t *testing.T, iter *AsyncIterator[*AgentEvent]) []*AgentEvent { + t.Helper() + var events []*AgentEvent + hasCancelError := false + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + hasCancelError = true + } + events = append(events, event) + } + assert.True(t, hasCancelError, "expected CancelError in event stream") + return events +} + func TestCancelContext(t *testing.T) { t.Run("BasicCancelContext", func(t *testing.T) { cc := newCancelContext() @@ -237,16 +282,9 @@ func TestWithCancel_WithTools(t *testing.T) { t.Fatal("Timed out waiting for events") } - assert.True(t, len(events) > 0) + assert.NotEmpty(t, events) - hasCancelError := false - for _, e := range events { - var cancelErr *CancelError - if e.Err != nil && errors.As(e.Err, &cancelErr) { - hasCancelError = true - } - } - assert.True(t, hasCancelError, "Should have CancelError event after cancel") + assertHasCancelError(t, events) }) t.Run("CancelAfterChatModel_DuringToolCall", func(t *testing.T) { @@ -317,7 +355,7 @@ func TestWithCancel_WithTools(t *testing.T) { events = append(events, event) } - assert.True(t, len(events) > 0) + assert.NotEmpty(t, events) assert.True(t, atomic.LoadInt32(&st.callCount) >= 1, "Tool should have been called") }) @@ -388,7 +426,7 @@ func TestWithCancel_WithTools(t *testing.T) { events = append(events, event) } - assert.True(t, len(events) > 0) + assert.NotEmpty(t, events) assert.True(t, atomic.LoadInt32(&st.callCount) >= 1, "Tool should have been called") }) @@ -400,6 +438,7 @@ func TestWithCancel_WithTools(t *testing.T) { child := cc.deriveChild(ctx) assert.NotNil(t, child) + cc.setRecursive(true) cc.setMode(CancelImmediate) if atomic.CompareAndSwapInt32(&cc.state, stateRunning, stateCancelling) { @@ -475,7 +514,7 @@ func TestWithCancel_WithTools(t *testing.T) { <-modelStarted - handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel), WithRecursive()) err = handle.Wait() assert.NoError(t, err) @@ -776,16 +815,9 @@ func TestWithCancel_Streaming(t *testing.T) { t.Fatal("Timed out waiting for events") } - assert.True(t, len(events) > 0) + assert.NotEmpty(t, events) - hasCancelError := false - for _, e := range events { - var ce *CancelError - if e.Err != nil && errors.As(e.Err, &ce) { - hasCancelError = true - } - } - assert.True(t, hasCancelError, "Should have CancelError event after cancel") + assertHasCancelError(t, events) }) t.Run("CancelAfterToolCalls_Streaming", func(t *testing.T) { @@ -858,7 +890,7 @@ func TestWithCancel_Streaming(t *testing.T) { events = append(events, event) } - assert.True(t, len(events) > 0) + assert.NotEmpty(t, events) assert.True(t, atomic.LoadInt32(&st.callCount) >= 1, "Tool should have been called") }) } @@ -990,7 +1022,7 @@ func TestWithCancel_Resume(t *testing.T) { resumeEvents = append(resumeEvents, event) } - assert.True(t, len(resumeEvents) > 0, "Resume should produce events") + assert.NotEmpty(t, resumeEvents, "Resume should produce events") }) t.Run("Resume_ThenCancel", func(t *testing.T) { @@ -1107,7 +1139,7 @@ func TestWithCancel_Resume(t *testing.T) { elapsed := time.Since(start) assert.True(t, elapsed < 1*time.Second, "Resume should return quickly after cancel, elapsed: %v", elapsed) - assert.True(t, len(resumeEvents) > 0) + assert.NotEmpty(t, resumeEvents) hasCancelError := false for _, e := range resumeEvents { @@ -1491,21 +1523,7 @@ func TestWithCancel_SequentialAgent(t *testing.T) { err = handle.Wait() assert.NoError(t, err, "Cancel during second agent should succeed, not return ErrExecutionCompleted") - var events []*AgentEvent - hasCancelError := false - for { - event, ok := iter.Next() - if !ok { - break - } - var ce *CancelError - if event.Err != nil && errors.As(event.Err, &ce) { - hasCancelError = true - } - events = append(events, event) - } - - assert.True(t, hasCancelError, "Should have CancelError event") + drainEventsAndAssertCancelError(t, iter) }) } @@ -1564,19 +1582,7 @@ func TestWithCancel_LoopAgent(t *testing.T) { err = handle.Wait() assert.NoError(t, err, "Cancel during loop iteration should succeed") - hasCancelError := false - for { - event, ok := iter.Next() - if !ok { - break - } - var ce *CancelError - if event.Err != nil && errors.As(event.Err, &ce) { - hasCancelError = true - } - } - - assert.True(t, hasCancelError, "Should have CancelError event") + drainAndAssertCancelError(t, iter) }) } @@ -1623,22 +1629,10 @@ func TestWithCancel_ParallelAgent(t *testing.T) { err = handle.Wait() assert.NoError(t, err, "Cancel during parallel agents should succeed") - var events []*AgentEvent - hasCancelError := false - for { - event, ok := iter.Next() - if !ok { - break - } - var ce *CancelError - if event.Err != nil && errors.As(event.Err, &ce) { - hasCancelError = true - } - events = append(events, event) - } + events := drainEventsAndAssertCancelError(t, iter) elapsed := time.Since(start) - assert.True(t, hasCancelError, "Should have CancelError event") + _ = events assert.True(t, elapsed < 3*time.Second, "Should complete quickly after cancel, elapsed: %v", elapsed) }) } @@ -1707,20 +1701,9 @@ func TestWithCancel_SupervisorAgent(t *testing.T) { err = handle.Wait() assert.NoError(t, err, "Cancel during sub-agent should succeed") - hasCancelError := false - for { - event, ok := iter.Next() - if !ok { - break - } - var ce *CancelError - if event.Err != nil && errors.As(event.Err, &ce) { - hasCancelError = true - } - } + drainAndAssertCancelError(t, iter) elapsed := time.Since(start) - assert.True(t, hasCancelError, "Should have CancelError event") assert.True(t, elapsed < 3*time.Second, "Should complete quickly after cancel, elapsed: %v", elapsed) }) } @@ -2302,7 +2285,7 @@ func TestCancel_SequentialWorkflow_CancelAfterChatModel(t *testing.T) { assert.Nil(t, event.Err, "Should not have error during resume") resumeEvents = append(resumeEvents, event) } - assert.True(t, len(resumeEvents) > 0, "Resume should produce events") + assert.NotEmpty(t, resumeEvents, "Resume should produce events") } func TestCancelImmediate_OrphanedToolGoroutine_NoPanic(t *testing.T) { @@ -2625,7 +2608,7 @@ func TestCancelImmediate_AgentTool_PreservesChildCheckpoint(t *testing.T) { waitForChan(t, leafModel.startedChan, "Leaf agent model did not start") - handle, contributed := cancelFn() + handle, contributed := cancelFn(WithRecursive()) assert.True(t, contributed) assert.NoError(t, handle.Wait()) @@ -3253,23 +3236,27 @@ func TestCancelContext_ActiveChildren_Tracking(t *testing.T) { wrapped := parent.wrapGraphInterruptWithGracePeriod(mockInterrupt) - // No children: no options appended receivedOpts = nil wrapped() assert.Empty(t, receivedOpts, "Should pass no extra options when no children") - // With active child: one timeout option appended _ = parent.deriveChild(ctx) + + receivedOpts = nil + wrapped() + assert.Empty(t, receivedOpts, "Should pass no extra options when children are active but not recursive") + + parent.setRecursive(true) + receivedOpts = nil wrapped() - assert.Len(t, receivedOpts, 1, "Should add exactly one timeout option when children are active") + assert.Len(t, receivedOpts, 1, "Should add exactly one timeout option when children are active and recursive") - // Caller-provided options are preserved, grace period option appended after receivedOpts = nil callerOpt := compose.WithGraphInterruptTimeout(0) wrapped(callerOpt) assert.Len(t, receivedOpts, 2, - "Should append timeout option after caller-provided options when children are active") + "Should append timeout option after caller-provided options when children are active and recursive") // Note: verifying the exact timeout value (defaultCancelImmediateGracePeriod) // requires access to unexported compose.graphInterruptOptions. The integration // tests (TestCancelImmediate_AgentTool_PreservesChildCheckpoint) verify the @@ -3437,7 +3424,7 @@ func TestCancel_LoopWorkflow_CancelAfterChatModel(t *testing.T) { assert.Nil(t, event.Err, "Should not have error during resume") resumeEvents = append(resumeEvents, event) } - assert.True(t, len(resumeEvents) > 0, "Resume should produce events") + assert.NotEmpty(t, resumeEvents, "Resume should produce events") } func TestCancel_NestedWorkflow_AgentTool_CancelAfterChatModel(t *testing.T) { @@ -3514,7 +3501,7 @@ func TestCancel_NestedWorkflow_AgentTool_CancelAfterChatModel(t *testing.T) { t.Fatal("Leaf agent model did not start") } - handle, contributed := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + handle, contributed := cancelFn(WithAgentCancelMode(CancelAfterChatModel), WithRecursive()) assert.True(t, contributed, "Cancel should contribute") err = handle.Wait() assert.NoError(t, err) @@ -3737,5 +3724,5 @@ func TestCancel_CancelAfterToolCalls_InSequentialWorkflow(t *testing.T) { assert.Nil(t, event.Err, "Should not have error during resume") resumeEvents = append(resumeEvents, event) } - assert.True(t, len(resumeEvents) > 0, "Resume should produce events") + assert.NotEmpty(t, resumeEvents, "Resume should produce events") } diff --git a/adk/turn_buffer.go b/adk/turn_buffer.go new file mode 100644 index 000000000..b154587c9 --- /dev/null +++ b/adk/turn_buffer.go @@ -0,0 +1,128 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import "sync" + +type turnBuffer[T any] struct { + buf []T + mu sync.Mutex + notEmpty *sync.Cond + closed bool + woken bool +} + +func newTurnBuffer[T any]() *turnBuffer[T] { + tb := &turnBuffer[T]{} + tb.notEmpty = sync.NewCond(&tb.mu) + return tb +} + +func (tb *turnBuffer[T]) Send(value T) { + tb.mu.Lock() + defer tb.mu.Unlock() + + if tb.closed { + panic("turnBuffer: send on closed buffer") + } + + tb.buf = append(tb.buf, value) + tb.notEmpty.Signal() +} + +func (tb *turnBuffer[T]) TrySend(value T) bool { + tb.mu.Lock() + defer tb.mu.Unlock() + + if tb.closed { + return false + } + + tb.buf = append(tb.buf, value) + tb.notEmpty.Signal() + return true +} + +func (tb *turnBuffer[T]) Receive() (T, bool) { + tb.mu.Lock() + defer tb.mu.Unlock() + + for len(tb.buf) == 0 && !tb.closed && !tb.woken { + tb.notEmpty.Wait() + } + + tb.woken = false + + if len(tb.buf) == 0 { + var zero T + return zero, false + } + + val := tb.buf[0] + tb.buf = tb.buf[1:] + return val, true +} + +func (tb *turnBuffer[T]) Close() { + tb.mu.Lock() + defer tb.mu.Unlock() + + if !tb.closed { + tb.closed = true + tb.notEmpty.Broadcast() + } +} + +func (tb *turnBuffer[T]) TakeAll() []T { + tb.mu.Lock() + defer tb.mu.Unlock() + + if len(tb.buf) == 0 { + return nil + } + + values := tb.buf + tb.buf = nil + return values +} + +func (tb *turnBuffer[T]) PushFront(values []T) { + if len(values) == 0 { + return + } + + tb.mu.Lock() + defer tb.mu.Unlock() + + tb.buf = append(append([]T{}, values...), tb.buf...) + tb.notEmpty.Signal() +} + +func (tb *turnBuffer[T]) Wakeup() { + tb.mu.Lock() + defer tb.mu.Unlock() + + tb.woken = true + tb.notEmpty.Broadcast() +} + +func (tb *turnBuffer[T]) ClearWakeup() { + tb.mu.Lock() + defer tb.mu.Unlock() + + tb.woken = false +} diff --git a/adk/turn_loop.go b/adk/turn_loop.go index 7d0f61a3b..124f65459 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -27,7 +27,6 @@ import ( "sync/atomic" "time" - "github.com/cloudwego/eino/internal" "github.com/cloudwego/eino/internal/safe" ) @@ -54,19 +53,19 @@ import ( // acts when a new Stop() call has been made, supporting mode escalation // (e.g. CancelAfterToolCalls followed by CancelImmediate). type stopSignal struct { - // done is closed exactly once by closeDone(). A closed channel is - // readable forever, so it serves as a durable stop flag for all watchers. done chan struct{} - mu sync.Mutex - gen uint64 + mu sync.Mutex + gen uint64 + // agentCancelOpts controls how the stop interacts with the running agent: + // nil → no cancel; the turn runs to completion (bare Stop) + // empty → CancelImmediate (WithImmediate) + // non-empty → cancel with specific modes (WithGraceful, WithGracefulTimeout) agentCancelOpts []AgentCancelOption skipCheckpoint bool stopCause string - // notify is a buffered(1) channel that wakes the current turn's watcher - // when Stop() is called. Unlike done, it supports repeated Stop() calls - // for cancel-mode escalation. - notify chan struct{} + idleFor time.Duration + notify chan struct{} } func newStopSignal() *stopSignal { @@ -82,13 +81,21 @@ func newStopSignal() *stopSignal { func (s *stopSignal) signal(cfg *stopConfig) { s.mu.Lock() s.gen++ - s.agentCancelOpts = cfg.agentCancelOpts + // Only overwrite when the caller explicitly provides cancel options. + // A bare Stop() leaves cfg.agentCancelOpts nil (no cancel intent), which + // must not de-escalate a previously set cancel policy. + if cfg.agentCancelOpts != nil { + s.agentCancelOpts = cfg.agentCancelOpts + } if cfg.skipCheckpoint { s.skipCheckpoint = true } if cfg.stopCause != "" && s.stopCause == "" { s.stopCause = cfg.stopCause } + if cfg.idleFor > 0 && s.idleFor == 0 { + s.idleFor = cfg.idleFor + } s.mu.Unlock() select { case s.notify <- struct{}{}: @@ -131,6 +138,12 @@ func (s *stopSignal) getStopCause() string { return s.stopCause } +func (s *stopSignal) getIdleFor() time.Duration { + s.mu.Lock() + defer s.mu.Unlock() + return s.idleFor +} + // preemptSignal coordinates preemption between Push callers and the run loop. // // Lifecycle overview: @@ -369,6 +382,16 @@ type TurnLoopConfig[T any] struct { // - tc.Consumed: items that triggered this agent execution // - tc.Loop: allows calling Push() or Stop() directly from within the callback // - tc.Preempted / tc.Stopped: signals while processing events + // + // Error handling: the returned error is only used when the callback itself + // wants to abort the TurnLoop. The TurnLoop already captures CancelError + // from the event stream when the turn is stopped or preempted, so the + // callback should NOT propagate CancelError. In practice, return a non-nil + // error only for callback-internal failures that should terminate the loop; + // return nil when the current agent is canceled by an external Stop or + // Preempt (Preempt cancels the current agent but the loop continues with + // the next turn). + // // Optional. If not provided, events are drained and errors (except CancelError // from Stop-triggered cancellation) are returned as ExitReason. OnAgentEvents func(ctx context.Context, tc *TurnContext[T], events *AsyncIterator[*AgentEvent]) error @@ -534,9 +557,11 @@ func (l *TurnLoop[T]) planTurn( // and any items that were not processed. type TurnLoopExitState[T any] struct { // ExitReason indicates why the loop exited. - // nil means clean exit (Stop() was called and completed normally). + // nil means clean exit (Stop() was called without cancel options, or the + // agent completed normally before Stop took effect). // Non-nil values include context errors, callback errors, *CancelError, etc. - // When Stop() cancels a running agent, ExitReason will be a *CancelError. + // When Stop(WithImmediate()) or Stop(WithGraceful()) cancels a running + // agent, ExitReason will be a *CancelError. // This never contains checkpoint errors — see CheckpointErr for those. ExitReason error @@ -545,9 +570,10 @@ type TurnLoopExitState[T any] struct { // This is always valid regardless of ExitReason. UnhandledItems []T - // CanceledItems contains the items whose turn was canceled by Stop(). - // This is set when Stop() is called during a running turn, even if it - // did not contribute to the final CancelError. + // CanceledItems contains the items whose turn was actually interrupted + // by a cancel (Stop with WithImmediate, WithGraceful, or WithGracefulTimeout). + // Only populated when ExitReason is a *CancelError — if the agent finishes + // normally before the cancel takes effect, CanceledItems is empty. // It can be used to reconstruct GenInput/PrepareAgent inputs when resuming. CanceledItems []T @@ -588,11 +614,15 @@ type TurnContext[T any] struct { Consumed []T // Preempted is closed when a preempt signal fires for the current turn - // (via Push with WithPreempt) and at least one preemptive Push contributed - // to the CancelError for the current turn. + // (via Push with WithPreempt/WithPreemptTimeout) and at least one + // preemptive Push contributed to the CancelError for the current turn. // "Contributed" means the preempt's cancel options were included in the // CancelError before it was finalized. Remains open if no preempt contributed. // Use in a select to detect preemption while processing events. + // + // Both Preempted and Stopped may be closed within the same turn if both + // signals arrive while the agent is still being cancelled. Whichever + // arrives after the cancel is fully handled will not contribute. Preempted <-chan struct{} // Stopped is closed when a Stop() call contributed to the CancelError for the @@ -600,6 +630,8 @@ type TurnContext[T any] struct { // "Contributed" means Stop's cancel options were included in the CancelError // before it was finalized. Remains open if Stop did not contribute. // Use in a select to detect stop while processing events. + // + // See Preempted for the relationship between the two channels. Stopped <-chan struct{} // StopCause returns the business-supplied reason from WithStopCause. @@ -628,7 +660,7 @@ type TurnContext[T any] struct { type TurnLoop[T any] struct { config TurnLoopConfig[T] - buffer *internal.UnboundedChan[T] + buffer *turnBuffer[T] stopped int32 started int32 @@ -777,20 +809,98 @@ type turnLoopPendingResume[T any] struct { resumeBytes []byte } +// SafePoint describes at which boundary the agent may be cancelled. +// It is a bitmask: values can be combined with bitwise OR to accept multiple +// safe points (e.g. AfterToolCalls | AfterChatModel). Internally, SafePoint +// is translated to CancelMode via toCancelMode(). +// +// SafePoint is used only in the preemption API (WithPreempt/WithPreemptTimeout). +// A key design constraint: preemption always targets a safe point — the user's +// intent is to cancel at a well-defined boundary, never to abort immediately. +// Immediate cancellation is only reachable as an automatic timeout escalation +// (via WithPreemptTimeout), not as a direct user choice. This is why SafePoint +// has no "immediate" value and why WithPreempt requires a non-zero SafePoint +// (panics otherwise). +type SafePoint int + +const ( + // AfterToolCalls allows the agent to finish the current tool-call round + // before being cancelled. + AfterToolCalls SafePoint = 1 << iota + // AfterChatModel allows the agent to finish the current chat-model + // call before being cancelled. + AfterChatModel + // AnySafePoint is shorthand for AfterToolCalls | AfterChatModel. + AnySafePoint = AfterToolCalls | AfterChatModel +) + +func (sp SafePoint) toCancelMode() CancelMode { + var mode CancelMode + if sp&AfterToolCalls != 0 { + mode |= CancelAfterToolCalls + } + if sp&AfterChatModel != 0 { + mode |= CancelAfterChatModel + } + return mode +} + type stopConfig struct { agentCancelOpts []AgentCancelOption skipCheckpoint bool stopCause string + idleFor time.Duration } // StopOption is an option for Stop(). type StopOption func(*stopConfig) -// WithAgentCancel sets the agent cancel options to use when stopping the loop. -// These options control how the currently running agent is cancelled. -func WithAgentCancel(opts ...AgentCancelOption) StopOption { +// WithGraceful requests a graceful stop that waits at the nearest safe point +// (after tool calls or after a chat-model call) and propagates recursively to +// nested agents. It does not impose a time limit; use WithGracefulTimeout to +// add a grace period after which the stop escalates to immediate cancellation. +// +// WithGraceful and WithGracefulTimeout are mutually exclusive; if both are +// passed to the same Stop call, the last one wins. +func WithGraceful() StopOption { + return func(cfg *stopConfig) { + cfg.agentCancelOpts = []AgentCancelOption{ + WithAgentCancelMode(CancelAfterChatModel | CancelAfterToolCalls), + WithRecursive(), + } + } +} + +// WithImmediate aborts the running agent turn as soon as possible. +// The agent's context is cancelled immediately without waiting for any +// safe point. Nested agents inside AgentTools are torn down as a side effect. +// +// This is the most aggressive stop mode — typically used when the caller +// wants to shut down the TurnLoop with no intention of resuming. +func WithImmediate() StopOption { return func(cfg *stopConfig) { - cfg.agentCancelOpts = opts + cfg.agentCancelOpts = []AgentCancelOption{} + } +} + +// WithGracefulTimeout is like WithGraceful but adds a grace period. +// If the agent has not reached a safe point within gracePeriod, the stop +// escalates to immediate cancellation. +// +// gracePeriod must be positive; passing a zero or negative duration panics. +// +// WithGraceful and WithGracefulTimeout are mutually exclusive; if both are +// passed to the same Stop call, the last one wins. +func WithGracefulTimeout(gracePeriod time.Duration) StopOption { + if gracePeriod <= 0 { + panic("adk: WithGracefulTimeout: gracePeriod must be positive") + } + return func(cfg *stopConfig) { + cfg.agentCancelOpts = []AgentCancelOption{ + WithAgentCancelMode(CancelAfterChatModel | CancelAfterToolCalls), + WithRecursive(), + WithAgentCancelTimeout(gracePeriod), + } } } @@ -813,6 +923,34 @@ func WithStopCause(cause string) StopOption { } } +// UntilIdleFor defers the stop until the TurnLoop has been continuously idle +// (blocked between turns with no pending items) for at least the given +// duration. Each time a new item arrives the timer resets from zero. +// +// This is useful when business code monitors agent activity externally and +// wants to shut down the loop once there has been no work for a while, without +// racing with concurrent Push calls. +// +// UntilIdleFor is combinable with other StopOptions in the same call. +// For example, Stop(UntilIdleFor(30*time.Second), WithGraceful()) means +// "after 30 s of idle, stop gracefully". If another Stop call is made +// without UntilIdleFor (e.g. Stop(WithImmediate())), the loop shuts down +// immediately, bypassing the idle wait. +// +// Only the first UntilIdleFor duration takes effect; subsequent calls with +// a different duration are ignored. A Stop() call without UntilIdleFor always +// shuts down the loop immediately regardless of any pending idle timer. +// +// duration must be positive; passing a zero or negative value panics. +func UntilIdleFor(duration time.Duration) StopOption { + if duration <= 0 { + panic("adk: UntilIdleFor: duration must be positive") + } + return func(cfg *stopConfig) { + cfg.idleFor = duration + } +} + type pushConfig[T any] struct { preempt bool preemptDelay time.Duration @@ -823,23 +961,58 @@ type pushConfig[T any] struct { // PushOption is an option for Push(). type PushOption[T any] func(*pushConfig[T]) -// WithPreempt signals that the current agent should be canceled after pushing. -// This enables atomic "push + preempt" to avoid race conditions between -// pushing an urgent item and triggering preemption. -// The loop will cancel the current agent turn and continue with the next turn, -// where GenInput will see all buffered items including the newly pushed one. -func WithPreempt[T any](agentCancelOpts ...AgentCancelOption) PushOption[T] { +// WithPreempt signals that the current agent turn should be cancelled at the +// specified safePoint after pushing the new item. The loop cancels the current +// turn and starts a new one, where GenInput will see all buffered items +// including the newly pushed one. +// Use WithPreemptTimeout to add a timeout that escalates to immediate abort. +// +// Because safe points fire at turn-level boundaries (after the chat model +// returns or after all tool calls complete), no nested agent is running at +// the moment of cancellation — nested agents within AgentTools have either +// not started yet (AfterChatModel) or already finished (AfterToolCalls). +// If the preemption escalates to immediate via WithPreemptTimeout, any +// in-flight nested agent is torn down through Go context cancellation. +// +// WithPreempt and WithPreemptTimeout are mutually exclusive; if both are +// passed to the same Push call, the last one wins. +// +// safePoint must not be zero; passing SafePoint(0) panics. +func WithPreempt[T any](safePoint SafePoint) PushOption[T] { + if safePoint == 0 { + panic("adk: SafePoint must not be zero; use AfterToolCalls, AfterChatModel, or AnySafePoint") + } return func(cfg *pushConfig[T]) { cfg.preempt = true - cfg.agentCancelOpts = agentCancelOpts + cfg.agentCancelOpts = []AgentCancelOption{ + WithAgentCancelMode(safePoint.toCancelMode()), + } + } +} + +// WithPreemptTimeout is like WithPreempt but adds a timeout. If the agent has +// not reached the safe point within timeout, the preemption escalates to +// immediate cancellation. +// +// safePoint must not be zero; passing SafePoint(0) panics. +func WithPreemptTimeout[T any](safePoint SafePoint, timeout time.Duration) PushOption[T] { + if safePoint == 0 { + panic("adk: SafePoint must not be zero; use AfterToolCalls, AfterChatModel, or AnySafePoint") + } + return func(cfg *pushConfig[T]) { + cfg.preempt = true + cfg.agentCancelOpts = []AgentCancelOption{ + WithAgentCancelMode(safePoint.toCancelMode()), + WithAgentCancelTimeout(timeout), + } } } // WithPreemptDelay sets a delay duration before preemption takes effect. -// When used with WithPreempt, the push will succeed immediately, but the -// preemption signal will be delayed by the specified duration. -// This allows the current agent to continue processing for a grace period -// before being preempted. +// When used with WithPreempt or WithPreemptTimeout, the push will succeed +// immediately, but the preemption signal will be delayed by the specified +// duration. This allows the current agent to continue processing for a grace +// period before being preempted. func WithPreemptDelay[T any](delay time.Duration) PushOption[T] { return func(cfg *pushConfig[T]) { cfg.preemptDelay = delay @@ -861,7 +1034,7 @@ func WithPreemptDelay[T any](delay time.Duration) PushOption[T] { // return nil // between turns, plain push // } // if isLowPriority(tc.Consumed) { -// return []PushOption[MyItem]{WithPreempt[MyItem]()} +// return []PushOption[MyItem]{WithPreempt[MyItem](AnySafePoint)} // } // return nil // don't preempt high-priority work // })) @@ -900,7 +1073,7 @@ func NewTurnLoop[T any](cfg TurnLoopConfig[T]) *TurnLoop[T] { l := &TurnLoop[T]{ config: cfg, - buffer: internal.NewUnboundedChan[T](), + buffer: newTurnBuffer[T](), done: make(chan struct{}), stopSig: newStopSignal(), preemptSig: newPreemptSignal(), @@ -944,12 +1117,14 @@ func (l *TurnLoop[T]) Run(ctx context.Context) { // Once TakeLateItems() has been called, any subsequent push that would become a // late item will panic instead of being silently dropped. // -// Use WithPreempt() to atomically push an item and signal preemption of the current agent. -// This is useful for urgent items that should interrupt the current processing. +// Use WithPreempt() or WithPreemptTimeout() to atomically push an item and signal +// preemption of the current agent. This is useful for urgent items that should +// interrupt the current processing. // The returned channel may be waited on if the caller needs to ensure the preempt // signal has been observed. // -// Use WithPreemptDelay() together with WithPreempt() to delay the preemption signal. +// Use WithPreemptDelay() together with WithPreempt()/WithPreemptTimeout() to delay +// the preemption signal. // Push returns immediately after the item is buffered, and a goroutine is spawned // to signal preemption after the delay. func (l *TurnLoop[T]) Push(item T, opts ...PushOption[T]) (bool, <-chan struct{}) { @@ -1077,18 +1252,29 @@ func (l *TurnLoop[T]) pushWithConfig(item T, cfg *pushConfig[T]) (bool, <-chan s } // Stop signals the loop to stop and returns immediately (non-blocking). -// The loop will finish the current turn (or cancel it via WithAgentCancel options), -// then exit without starting a new turn. -// Use WithAgentCancel to control how the currently running agent is cancelled. -// This method is idempotent - multiple calls update cancel options. +// Without options, the current agent turn runs to completion and the loop +// exits at the turn boundary without starting a new turn. ExitReason is nil. +// +// Use WithImmediate() to abort the running agent turn immediately. +// Use WithGraceful() to cancel at the nearest safe point with recursive +// propagation to nested agents. +// Use WithGracefulTimeout() for safe-point cancel with an escalation deadline. +// Use UntilIdleFor() to defer the stop until the loop has been continuously +// idle for a given duration; the loop shuts down automatically once the idle +// timer fires. +// +// This method may be called multiple times; subsequent calls update cancel options. +// A Stop() call without UntilIdleFor shuts down the loop immediately, even if +// a prior UntilIdleFor is still waiting. // Call Wait() to block until the loop has fully exited and get the result. // // Stop may be called before Run. In that case, the stopped flag is set and // a subsequent Run will exit the loop immediately. // // If the running agent does not support the WithCancel AgentRunOption, -// Stop degrades to "exit the loop on entering the next iteration" — the -// current agent turn runs to completion before the loop exits. +// all cancel-related options (WithImmediate, WithGraceful, WithGracefulTimeout) +// degrade to "exit the loop on entering the next iteration" — the current +// agent turn runs to completion before the loop exits. func (l *TurnLoop[T]) Stop(opts ...StopOption) { cfg := &stopConfig{} for _, opt := range opts { @@ -1097,6 +1283,14 @@ func (l *TurnLoop[T]) Stop(opts ...StopOption) { l.stopSig.signal(cfg) + if cfg.idleFor > 0 { + l.buffer.Wakeup() + return + } + l.commitStop() +} + +func (l *TurnLoop[T]) commitStop() { l.stopOnce.Do(func() { l.stopSig.closeDone() atomic.StoreInt32(&l.stopped, 1) @@ -1109,7 +1303,7 @@ func (l *TurnLoop[T]) Stop(opts ...StopOption) { // All callers will receive the same result. // // Wait blocks until Run is called AND the loop exits. If Run is -// ever called, Wait blocks forever. +// never called, Wait blocks forever. func (l *TurnLoop[T]) Wait() *TurnLoopExitState[T] { <-l.done return l.result @@ -1153,7 +1347,36 @@ func (l *TurnLoop[T]) run(ctx context.Context) { pushBack = append(pushBack, pr.unhandled...) pushBack = append(pushBack, pr.newItems...) } else { - first, ok := l.buffer.Receive() + var first T + var ok bool + + if idleFor := l.stopSig.getIdleFor(); idleFor > 0 { + l.buffer.ClearWakeup() + idleTimer := time.NewTimer(idleFor) + cancelIdle := make(chan struct{}) + // When the idle timer fires, commitStop closes the buffer via + // buffer.Close(), which broadcasts to unblock the pending + // Receive() call below. + go func() { + select { + case <-idleTimer.C: + l.commitStop() + case <-cancelIdle: + } + }() + + first, ok = l.buffer.Receive() + + idleTimer.Stop() + close(cancelIdle) + } else { + first, ok = l.buffer.Receive() + // Woken up by Stop(UntilIdleFor); re-enter loop to start the idle timer. + if !ok && l.stopSig.getIdleFor() > 0 { + continue + } + } + if !ok { if err := ctx.Err(); err != nil { l.runErr = err @@ -1233,6 +1456,9 @@ func (l *TurnLoop[T]) run(ctx context.Context) { l.preemptSig.endTurnAndUnhold() if runErr != nil { + if errors.As(runErr, new(*CancelError)) && len(l.canceledItems) == 0 { + l.canceledItems = append([]T{}, plan.spec.consumed...) + } l.runErr = runErr return } @@ -1313,47 +1539,36 @@ func (l *TurnLoop[T]) watchPreemptSignal(done <-chan struct{}, agentCancelFunc A func (l *TurnLoop[T]) watchStopSignal(done <-chan struct{}, agentCancelFunc AgentCancelFunc, stoppedDone chan struct{}) { var lastGen uint64 stoppedClosed := false + + tryCancel := func(gen uint64, opts []AgentCancelOption) { + if gen == lastGen { + return + } + lastGen = gen + if opts == nil { + return + } + _, contributed := agentCancelFunc(opts...) + if contributed && !stoppedClosed { + close(stoppedDone) + stoppedClosed = true + } + } + for { select { case <-done: return case <-l.stopSig.notify: - gen, opts := l.stopSig.check() - if gen != lastGen { - lastGen = gen - // CancelHandle is intentionally not awaited here: agentCancelFunc - // commits the cancel signal synchronously, while waiting would block - // until the turn finishes and can deadlock this watcher against done. - _, contributed := agentCancelFunc(opts...) - if contributed && !stoppedClosed { - close(stoppedDone) - stoppedClosed = true - } - } + tryCancel(l.stopSig.check()) case <-l.stopSig.done: - gen, opts := l.stopSig.check() - if gen != lastGen { - lastGen = gen - _, contributed := agentCancelFunc(opts...) - if contributed && !stoppedClosed { - close(stoppedDone) - stoppedClosed = true - } - } + tryCancel(l.stopSig.check()) for { select { case <-done: return case <-l.stopSig.notify: - gen, opts := l.stopSig.check() - if gen != lastGen { - lastGen = gen - _, contributed := agentCancelFunc(opts...) - if contributed && !stoppedClosed { - close(stoppedDone) - stoppedClosed = true - } - } + tryCancel(l.stopSig.check()) } } } @@ -1366,11 +1581,6 @@ func (l *TurnLoop[T]) runAgentAndHandleEvents( spec *turnRunSpec[T], ) error { var iter *AsyncIterator[*AgentEvent] - defer func() { - if l.stopSig.isStopped() && len(l.canceledItems) == 0 { - l.canceledItems = append([]T{}, spec.consumed...) - } - }() runOpts, ms, err := l.setupBridgeStore(spec, spec.runOpts) if err != nil { diff --git a/adk/turn_loop_test.go b/adk/turn_loop_test.go index 1b8b2c86d..4f22ca1a7 100644 --- a/adk/turn_loop_test.go +++ b/adk/turn_loop_test.go @@ -26,6 +26,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/cloudwego/eino/schema" ) @@ -166,6 +167,42 @@ func newAndRunTurnLoop[T any](ctx context.Context, cfg TurnLoopConfig[T]) *TurnL return l } +func newPreemptTestLoop(t *testing.T, agent *turnLoopCancellableMockAgent) *TurnLoop[string] { + t.Helper() + + agentStarted := make(chan struct{}) + agentStartedOnce := sync.Once{} + + originalRunFunc := agent.runFunc + agent.runFunc = func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { close(agentStarted) }) + return originalRunFunc(ctx, input) + } + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil + }, + }) + + loop.Push("first") + + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } + + return loop +} + func TestTurnLoop_RunAndPush(t *testing.T) { processedItems := make([]string, 0) var mu sync.Mutex @@ -197,7 +234,7 @@ func TestTurnLoop_RunAndPush(t *testing.T) { defer mu.Unlock() assert.NoError(t, result.ExitReason) - assert.True(t, len(processedItems) > 0, "should have processed at least one item") + assert.NotEmpty(t, processedItems, "should have processed at least one item") } func TestTurnLoop_PushReturnsErrorAfterStop(t *testing.T) { @@ -296,7 +333,7 @@ func TestTurnLoop_UnhandledItemsOnStop(t *testing.T) { close(blocked) result := loop.Wait() - assert.True(t, len(result.UnhandledItems) >= 0, "should return unhandled items") + assert.NotEmpty(t, result.UnhandledItems, "should return unhandled items") } func TestTurnLoop_GenInputError(t *testing.T) { @@ -368,7 +405,7 @@ func TestTurnLoop_BatchProcessing(t *testing.T) { mu.Lock() defer mu.Unlock() - assert.True(t, len(batches) > 0, "should have processed at least one batch") + assert.NotEmpty(t, batches, "should have processed at least one batch") } func TestTurnLoop_StopWithMode(t *testing.T) { @@ -381,7 +418,7 @@ func TestTurnLoop_StopWithMode(t *testing.T) { }, }) - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelAfterToolCalls))) + loop.Stop(WithGraceful()) result := loop.Wait() assert.NoError(t, result.ExitReason) @@ -438,7 +475,7 @@ func TestTurnLoop_Preempt_CancelsCurrentAgent(t *testing.T) { t.Fatal("agent did not start") } - loop.Push("urgent", WithPreempt[string]()) + loop.Push("urgent", WithPreempt[string](AnySafePoint)) select { case <-agentCancelled: @@ -516,7 +553,7 @@ func TestTurnLoop_Preempt_DiscardsConsumedItems(t *testing.T) { t.Fatal("agent did not start") } - loop.Push("urgent", WithPreempt[string]()) + loop.Push("urgent", WithPreempt[string](AnySafePoint)) select { case <-agentDone: @@ -530,17 +567,13 @@ func TestTurnLoop_Preempt_DiscardsConsumedItems(t *testing.T) { mu.Lock() defer mu.Unlock() - assert.GreaterOrEqual(t, len(genInputResults), 2) - if len(genInputResults) >= 2 { - assert.NotContains(t, genInputResults[1], "first") - assert.Contains(t, genInputResults[1], "urgent") - } + require.GreaterOrEqual(t, len(genInputResults), 2) + assert.NotContains(t, genInputResults[1], "first") + assert.Contains(t, genInputResults[1], "urgent") } func TestTurnLoop_Preempt_WithAgentCancelMode(t *testing.T) { - agentStarted := make(chan struct{}) cancelFuncCalled := make(chan struct{}) - agentStartedOnce := sync.Once{} cancelFuncCalledOnce := sync.Once{} firstCancelModeUsed := CancelImmediate var cancelModeMu sync.Mutex @@ -548,9 +581,6 @@ func TestTurnLoop_Preempt_WithAgentCancelMode(t *testing.T) { agent := &turnLoopCancellableMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { - agentStartedOnce.Do(func() { - close(agentStarted) - }) <-ctx.Done() return &AgentOutput{}, nil }, @@ -564,28 +594,9 @@ func TestTurnLoop_Preempt_WithAgentCancelMode(t *testing.T) { }, } - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { - return agent, nil - }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ - Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, - Consumed: []string{items[0]}, - Remaining: items[1:], - }, nil - }, - }) - - loop.Push("first") + loop := newPreemptTestLoop(t, agent) - select { - case <-agentStarted: - case <-time.After(1 * time.Second): - t.Fatal("agent did not start") - } - - loop.Push("urgent", WithPreempt[string](WithAgentCancelMode(CancelAfterToolCalls))) + loop.Push("urgent", WithPreempt[string](AfterToolCalls)) select { case <-cancelFuncCalled: @@ -603,16 +614,13 @@ func TestTurnLoop_Preempt_WithAgentCancelMode(t *testing.T) { } func TestTurnLoop_PreemptAck_ClosesAfterCancelIsInitiated(t *testing.T) { - agentStarted := make(chan struct{}) cancelObserved := make(chan struct{}) agentFinishGate := make(chan struct{}) - agentStartedOnce := sync.Once{} cancelObservedOnce := sync.Once{} agent := &turnLoopCancellableMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { - agentStartedOnce.Do(func() { close(agentStarted) }) <-ctx.Done() <-agentFinishGate return &AgentOutput{}, nil @@ -622,28 +630,9 @@ func TestTurnLoop_PreemptAck_ClosesAfterCancelIsInitiated(t *testing.T) { }, } - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { - return agent, nil - }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ - Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, - Consumed: []string{items[0]}, - Remaining: items[1:], - }, nil - }, - }) - - _, _ = loop.Push("first") + loop := newPreemptTestLoop(t, agent) - select { - case <-agentStarted: - case <-time.After(1 * time.Second): - t.Fatal("agent did not start") - } - - ok, ack := loop.Push("urgent", WithPreempt[string](WithAgentCancelMode(CancelAfterToolCalls))) + ok, ack := loop.Push("urgent", WithPreempt[string](AfterToolCalls)) assert.True(t, ok) assert.NotNil(t, ack) @@ -676,7 +665,7 @@ func TestTurnLoop_PreemptAck_ClosesImmediatelyIfLoopNotStarted(t *testing.T) { }, }) - ok, ack := loop.Push("urgent", WithPreempt[string]()) + ok, ack := loop.Push("urgent", WithPreempt[string](AnySafePoint)) assert.True(t, ok) assert.NotNil(t, ack) @@ -688,10 +677,8 @@ func TestTurnLoop_PreemptAck_ClosesImmediatelyIfLoopNotStarted(t *testing.T) { } func TestTurnLoop_Preempt_EscalatesOnSecondPreempt(t *testing.T) { - agentStarted := make(chan struct{}) firstCancelSeen := make(chan struct{}) agentFinishGate := make(chan struct{}) - agentStartedOnce := sync.Once{} firstCancelOnce := sync.Once{} var ccPtr atomic.Value @@ -699,7 +686,6 @@ func TestTurnLoop_Preempt_EscalatesOnSecondPreempt(t *testing.T) { agent := &turnLoopCancellableMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { - agentStartedOnce.Do(func() { close(agentStarted) }) <-ctx.Done() <-agentFinishGate return &AgentOutput{}, nil @@ -710,35 +696,18 @@ func TestTurnLoop_Preempt_EscalatesOnSecondPreempt(t *testing.T) { }, } - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { - return agent, nil - }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ - Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, - Consumed: []string{items[0]}, - Remaining: items[1:], - }, nil - }, - }) + loop := newPreemptTestLoop(t, agent) - loop.Push("first") - select { - case <-agentStarted: - case <-time.After(1 * time.Second): - t.Fatal("agent did not start") - } - - loop.Push("urgent1", WithPreempt[string](WithAgentCancelMode(CancelAfterChatModel))) + loop.Push("urgent1", WithPreempt[string](AfterChatModel)) select { case <-firstCancelSeen: case <-time.After(1 * time.Second): t.Fatal("first preempt did not trigger cancel") } - loop.Push("urgent2", WithPreempt[string](WithAgentCancelMode(CancelImmediate))) + loop.Push("urgent2", WithPreemptTimeout[string](AnySafePoint, time.Millisecond)) + wantMode := CancelAfterChatModel | CancelAfterToolCalls deadline := time.Now().Add(1 * time.Second) for time.Now().Before(deadline) { v := ccPtr.Load() @@ -747,7 +716,7 @@ func TestTurnLoop_Preempt_EscalatesOnSecondPreempt(t *testing.T) { continue } cc := v.(*cancelContext) - if cc.getMode() == CancelImmediate && atomic.LoadInt32(&cc.escalated) == 1 { + if cc.getMode() == wantMode && atomic.LoadInt32(&cc.escalated) == 1 { break } time.Sleep(5 * time.Millisecond) @@ -758,7 +727,7 @@ func TestTurnLoop_Preempt_EscalatesOnSecondPreempt(t *testing.T) { t.Fatal("cancel context was not captured") } cc := v.(*cancelContext) - assert.Equal(t, CancelImmediate, cc.getMode()) + assert.Equal(t, wantMode, cc.getMode()) assert.Equal(t, int32(1), atomic.LoadInt32(&cc.escalated)) close(agentFinishGate) @@ -769,10 +738,8 @@ func TestTurnLoop_Preempt_EscalatesOnSecondPreempt(t *testing.T) { } func TestTurnLoop_Preempt_JoinsSafePointModesOnSecondPreempt(t *testing.T) { - agentStarted := make(chan struct{}) firstCancelSeen := make(chan struct{}) agentFinishGate := make(chan struct{}) - agentStartedOnce := sync.Once{} firstCancelOnce := sync.Once{} var ccPtr atomic.Value @@ -780,7 +747,6 @@ func TestTurnLoop_Preempt_JoinsSafePointModesOnSecondPreempt(t *testing.T) { agent := &turnLoopCancellableMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { - agentStartedOnce.Do(func() { close(agentStarted) }) <-ctx.Done() <-agentFinishGate return &AgentOutput{}, nil @@ -791,34 +757,16 @@ func TestTurnLoop_Preempt_JoinsSafePointModesOnSecondPreempt(t *testing.T) { }, } - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { - return agent, nil - }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ - Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, - Consumed: []string{items[0]}, - Remaining: items[1:], - }, nil - }, - }) - - loop.Push("first") - select { - case <-agentStarted: - case <-time.After(1 * time.Second): - t.Fatal("agent did not start") - } + loop := newPreemptTestLoop(t, agent) - loop.Push("urgent1", WithPreempt[string](WithAgentCancelMode(CancelAfterChatModel))) + loop.Push("urgent1", WithPreempt[string](AfterChatModel)) select { case <-firstCancelSeen: case <-time.After(1 * time.Second): t.Fatal("first preempt did not trigger cancel") } - loop.Push("urgent2", WithPreempt[string](WithAgentCancelMode(CancelAfterToolCalls))) + loop.Push("urgent2", WithPreempt[string](AfterToolCalls)) want := CancelAfterChatModel | CancelAfterToolCalls deadline := time.Now().Add(1 * time.Second) @@ -952,7 +900,7 @@ func TestTurnLoop_PreemptDelay_NoMispreemptOnNaturalCompletion(t *testing.T) { t.Fatal("agent1 did not start") } - loop.Push("second", WithPreempt[string](), WithPreemptDelay[string](500*time.Millisecond)) + loop.Push("second", WithPreempt[string](AnySafePoint), WithPreemptDelay[string](500*time.Millisecond)) select { case <-agent1Done: @@ -1093,7 +1041,7 @@ func TestTurnLoop_GetAgentError_RecoverConsumed(t *testing.T) { result := loop.Wait() assert.ErrorIs(t, result.ExitReason, agentErr) - assert.True(t, len(result.UnhandledItems) >= 1, "should recover at least the consumed item and remaining") + assert.NotEmpty(t, result.UnhandledItems, "should recover at least the consumed item and remaining") } func TestTurnLoop_GenInputError_RecoverItems(t *testing.T) { @@ -1293,7 +1241,7 @@ func TestTurnLoop_ContextCancelAfterGenInput_RecoverItems(t *testing.T) { result := loop.Wait() assert.ErrorIs(t, result.ExitReason, context.Canceled) - assert.True(t, len(result.UnhandledItems) >= 1, "should recover consumed and remaining items") + assert.NotEmpty(t, result.UnhandledItems, "should recover consumed and remaining items") } func TestTurnLoop_OnAgentEventsReceivesEvents(t *testing.T) { @@ -1340,7 +1288,7 @@ func TestTurnLoop_OnAgentEventsReceivesEvents(t *testing.T) { mu.Lock() defer mu.Unlock() - assert.True(t, len(receivedConsumed) > 0, "should have received consumed items") + assert.NotEmpty(t, receivedConsumed, "should have received consumed items") } func TestTurnLoop_StopDuringAgentExecution(t *testing.T) { @@ -1372,11 +1320,11 @@ func TestTurnLoop_StopDuringAgentExecution(t *testing.T) { loop.Push("msg1") <-agentStarted - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop.Stop() result := loop.Wait() assert.NoError(t, result.ExitReason) - assert.Equal(t, []string{"msg1"}, result.CanceledItems) + assert.Empty(t, result.CanceledItems) } func TestTurnLoop_StopCheckPointIDInCancelError(t *testing.T) { @@ -1420,7 +1368,7 @@ func TestTurnLoop_StopCheckPointIDInCancelError(t *testing.T) { loop.Push("msg1") <-modelStarted - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop.Stop() result := loop.Wait() @@ -1472,7 +1420,7 @@ func TestTurnLoop_StopWithoutCheckpointIDDoesNotPersist(t *testing.T) { loop.Push("msg1") <-modelStarted - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop.Stop() result := loop.Wait() @@ -1615,7 +1563,7 @@ func TestTurnLoop_StopDuringAgentExecution_PersistAndResume(t *testing.T) { loop.Push("msg1") <-modelStarted - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop.Stop() exit := loop.Wait() store.mu.Lock() @@ -1850,7 +1798,7 @@ func TestTurnLoop_CheckpointSaveError_ReturnsError(t *testing.T) { }) loop.Push("msg1") <-modelStarted - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop.Stop() exit := loop.Wait() assert.Error(t, exit.ExitReason) assert.True(t, exit.Checkpointed) @@ -2092,7 +2040,7 @@ func TestTurnLoop_GenResumeNil_Error(t *testing.T) { }) loop1.Push("msg1") <-modelStarted - loop1.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop1.Stop() loop1.Wait() loop2 := NewTurnLoop(TurnLoopConfig[string]{ @@ -2264,7 +2212,7 @@ func TestTurnLoop_GenResumeReturnsError(t *testing.T) { }) loop1.Push("msg1") <-modelStarted - loop1.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop1.Stop() loop1.Wait() genResumeErr := fmt.Errorf("resume callback failed") @@ -2323,7 +2271,7 @@ func TestTurnLoop_CheckpointSaveError_MergesWithExistingError(t *testing.T) { }) loop.Push("msg1") <-modelStarted - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop.Stop() exit := loop.Wait() assert.Error(t, exit.ExitReason) var ce *CancelError @@ -2371,7 +2319,7 @@ func TestTurnLoop_ResumeWithParams(t *testing.T) { }) loop1.Push("msg1") <-modelStarted - loop1.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop1.Stop() exit1 := loop1.Wait() var ce *CancelError assert.True(t, errors.As(exit1.ExitReason, &ce)) @@ -2413,25 +2361,6 @@ func TestTurnLoop_ResumeWithParams(t *testing.T) { _ = exit2 } -func TestTurnLoop_StopOptionsArePassed(t *testing.T) { - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ - Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, - Consumed: items, - }, nil - }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { - return &turnLoopMockAgent{name: "test"}, nil - }, - }) - - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelAfterToolCalls))) - - result := loop.Wait() - assert.NoError(t, result.ExitReason) -} - func TestTurnLoop_Stop_EscalatesCancelMode(t *testing.T) { ctx := context.Background() agentStarted := make(chan *cancelContext, 1) @@ -2451,8 +2380,8 @@ func TestTurnLoop_Stop_EscalatesCancelMode(t *testing.T) { loop.Push("msg1") cc := <-agentStarted - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelAfterToolCalls), WithAgentCancelTimeout(10*time.Second))) - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop.Stop(WithGracefulTimeout(10 * time.Second)) + loop.Stop(WithImmediate()) deadline := time.After(1 * time.Second) for { @@ -2469,7 +2398,7 @@ func TestTurnLoop_Stop_EscalatesCancelMode(t *testing.T) { exit := loop.Wait() var ce *CancelError - assert.True(t, errors.As(exit.ExitReason, &ce)) + require.True(t, errors.As(exit.ExitReason, &ce)) assert.Equal(t, CancelImmediate, ce.Info.Mode) } @@ -2587,7 +2516,7 @@ func TestTurnLoop_PushFromOnAgentEvents(t *testing.T) { result := loop.Wait() assert.NoError(t, result.ExitReason) - assert.GreaterOrEqual(t, atomic.LoadInt32(&pushCount), int32(2)) + assert.Equal(t, int32(2), atomic.LoadInt32(&pushCount)) } // Tests for NewTurnLoop: the permissive API where Push, Stop, and Wait are @@ -2894,7 +2823,7 @@ func TestTurnLoop_TurnContext_PreemptedChannel(t *testing.T) { loop.Push("msg1") <-agentStarted - loop.Push("msg2", WithPreempt[string](WithAgentCancelMode(CancelImmediate))) + loop.Push("msg2", WithPreemptTimeout[string](AnySafePoint, time.Millisecond)) select { case <-preemptedSeen: @@ -3103,7 +3032,7 @@ func TestTurnLoop_ConcurrentPreemptsDuringTurn(t *testing.T) { wg.Add(1) go func(i int) { defer wg.Done() - ok, ack := loop.Push(fmt.Sprintf("urgent-%d", i), WithPreempt[string]()) + ok, ack := loop.Push(fmt.Sprintf("urgent-%d", i), WithPreemptTimeout[string](AnySafePoint, time.Millisecond)) if ok && ack != nil { select { case <-ack: @@ -3156,7 +3085,7 @@ func TestTurnLoop_PreemptDuringTurnTransition(t *testing.T) { time.Sleep(50 * time.Millisecond) - ok, ack := loop.Push("transitional", WithPreempt[string]()) + ok, ack := loop.Push("transitional", WithPreempt[string](AnySafePoint)) assert.True(t, ok, "push should succeed") if ack != nil { select { @@ -3233,7 +3162,7 @@ func TestTurnLoop_PushStrategy_DuringTurnTransition(t *testing.T) { atomic.StoreInt32(&strategyTCNotNil, 1) } <-strategyBlocker - return []PushOption[string]{WithPreempt[string]()} + return []PushOption[string]{WithPreempt[string](AnySafePoint)} })) }() @@ -3298,7 +3227,7 @@ func TestTurnLoop_ConcurrentPreemptAndStop(t *testing.T) { go func() { defer wg.Done() - _, ack := loop.Push("preempt-item", WithPreempt[string]()) + _, ack := loop.Push("preempt-item", WithPreempt[string](AnySafePoint)) if ack != nil { <-ack } @@ -3360,7 +3289,7 @@ func TestTurnLoop_ConcurrentPushStrategyAndStop(t *testing.T) { go func() { defer wg.Done() _, ack := loop.Push("strategic-item", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { - return []PushOption[string]{WithPreempt[string]()} + return []PushOption[string]{WithPreempt[string](AnySafePoint)} })) if ack != nil { <-ack @@ -3417,7 +3346,7 @@ func TestTurnLoop_TurnContext_StoppedChannel(t *testing.T) { loop.Push("msg1") <-agentStarted - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop.Stop() select { case <-stoppedSeen: @@ -3429,6 +3358,110 @@ func TestTurnLoop_TurnContext_StoppedChannel(t *testing.T) { loop.Wait() } +func TestTurnLoop_TurnContext_BothPreemptedAndStopped(t *testing.T) { + t.Run("PreemptThenStop_OnlyPreemptContributes", func(t *testing.T) { + preemptedSeen := make(chan struct{}) + agentStarted := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "slow", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + select { + case <-tc.Preempted: + close(preemptedSeen) + case <-time.After(5 * time.Second): + t.Error("timed out waiting for Preempted") + } + for { + if _, ok := events.Next(); !ok { + break + } + } + return nil + }, + }) + + loop.Push("msg1") + <-agentStarted + loop.Push("msg2", WithPreemptTimeout[string](AnySafePoint, time.Millisecond)) + + select { + case <-preemptedSeen: + case <-time.After(5 * time.Second): + t.Fatal("Preempted channel was never closed") + } + + loop.Stop(WithImmediate()) + loop.Wait() + }) + + t.Run("StopThenPreempt_OnlyStopContributes", func(t *testing.T) { + stoppedSeen := make(chan struct{}) + agentStarted := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "slow", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + select { + case <-tc.Stopped: + close(stoppedSeen) + case <-time.After(5 * time.Second): + t.Error("timed out waiting for Stopped") + } + for { + if _, ok := events.Next(); !ok { + break + } + } + return nil + }, + }) + + loop.Push("msg1") + <-agentStarted + loop.Stop(WithImmediate()) + + select { + case <-stoppedSeen: + case <-time.After(5 * time.Second): + t.Fatal("Stopped channel was never closed") + } + + loop.Push("msg2", WithPreemptTimeout[string](AnySafePoint, time.Millisecond)) + loop.Wait() + }) +} + func TestTurnLoop_PushStrategy_DuringTurn(t *testing.T) { agentStarted := make(chan struct{}) agentStartedOnce := sync.Once{} @@ -3486,7 +3519,7 @@ func TestTurnLoop_PushStrategy_DuringTurn(t *testing.T) { loop.Push("urgent", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { atomic.AddInt32(&strategyCalled, 1) strategyTC = tc - return []PushOption[string]{WithPreempt[string]()} + return []PushOption[string]{WithPreempt[string](AnySafePoint)} })) select { @@ -3608,7 +3641,7 @@ func TestTurnLoop_PushStrategy_OverridesOtherOptions(t *testing.T) { // Strategy returns nil (no preempt), even though WithPreempt is also passed. // The strategy should override — so the agent should NOT be preempted. - ok, ack := loop.Push("item", WithPreempt[string](), WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + ok, ack := loop.Push("item", WithPreempt[string](AnySafePoint), WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { return nil // no preempt })) assert.True(t, ok) @@ -3666,7 +3699,7 @@ func TestTurnLoop_PushStrategy_NestedStrategyStripped(t *testing.T) { return []PushOption[string]{ WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { atomic.AddInt32(&innerCalled, 1) - return []PushOption[string]{WithPreempt[string]()} + return []PushOption[string]{WithPreempt[string](AnySafePoint)} }), } })) @@ -3735,7 +3768,7 @@ func TestTurnLoop_PushStrategy_ConsumedInspection(t *testing.T) { // Strategy checks Consumed and preempts because current turn has "low-priority" items. loop.Push("urgent-task", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { if tc != nil && len(tc.Consumed) > 0 && tc.Consumed[0] == "low-priority-task" { - return []PushOption[string]{WithPreempt[string]()} + return []PushOption[string]{WithPreempt[string](AnySafePoint)} } return nil })) @@ -4235,7 +4268,7 @@ func TestTurnLoop_StopCause_InTurnContext(t *testing.T) { loop.Push("msg1") <-agentStarted - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate)), WithStopCause(cause)) + loop.Stop(WithStopCause(cause)) select { case c := <-gotCause: @@ -4280,13 +4313,68 @@ func TestTurnLoop_StopCause_FirstNonEmptyWins(t *testing.T) { loop.Push("msg1") <-agentStarted - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelAfterToolCalls)), WithStopCause("first cause")) - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate)), WithStopCause("second cause")) + loop.Stop(WithGraceful(), WithStopCause("first cause")) + loop.Stop(WithStopCause("second cause")) exit := loop.Wait() assert.Equal(t, "first cause", exit.StopCause, "first non-empty StopCause should win") } +func TestTurnLoop_StopBeforeRun_PushThenStop(t *testing.T) { + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + t.Fatal("GenInput should not be called when Stop is called before Run") + return nil, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + t.Fatal("PrepareAgent should not be called when Stop is called before Run") + return nil, nil + }, + }) + + ok, _ := loop.Push("item1") + assert.True(t, ok) + ok, _ = loop.Push("item2") + assert.True(t, ok) + + loop.Stop() + loop.Run(context.Background()) + result := loop.Wait() + + assert.NoError(t, result.ExitReason) + assert.Equal(t, []string{"item1", "item2"}, result.UnhandledItems) + assert.Empty(t, result.CanceledItems) + assert.Empty(t, result.TakeLateItems()) +} + +func TestTurnLoop_StopBeforeRun_StopThenPush(t *testing.T) { + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + t.Fatal("GenInput should not be called when Stop is called before Run") + return nil, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + t.Fatal("PrepareAgent should not be called when Stop is called before Run") + return nil, nil + }, + }) + + loop.Stop() + + ok, _ := loop.Push("item1") + assert.False(t, ok) + ok, _ = loop.Push("item2") + assert.False(t, ok) + + loop.Run(context.Background()) + result := loop.Wait() + + assert.NoError(t, result.ExitReason) + assert.Empty(t, result.UnhandledItems) + assert.Empty(t, result.CanceledItems) + assert.Equal(t, []string{"item1", "item2"}, result.TakeLateItems()) +} + func TestTurnLoop_SkipCheckpoint_Sticky(t *testing.T) { agentStarted := make(chan struct{}) @@ -4324,8 +4412,8 @@ func TestTurnLoop_SkipCheckpoint_Sticky(t *testing.T) { loop.Push("msg1") <-agentStarted - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelAfterToolCalls)), WithSkipCheckpoint()) - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop.Stop(WithGraceful(), WithSkipCheckpoint()) + loop.Stop() exit := loop.Wait() assert.False(t, exit.Checkpointed, "SkipCheckpoint should be sticky across multiple Stop calls") @@ -4335,3 +4423,217 @@ func TestTurnLoop_SkipCheckpoint_Sticky(t *testing.T) { store.mu.Unlock() assert.False(t, exists, "no checkpoint should be saved when SkipCheckpoint was set in any Stop call") } + +func TestWithGracefulTimeout_NonPositive_Panics(t *testing.T) { + assert.PanicsWithValue(t, "adk: WithGracefulTimeout: gracePeriod must be positive", + func() { WithGracefulTimeout(0) }) + assert.PanicsWithValue(t, "adk: WithGracefulTimeout: gracePeriod must be positive", + func() { WithGracefulTimeout(-1 * time.Second) }) +} + +func TestWithPreempt_ZeroSafePoint_Panics(t *testing.T) { + assert.PanicsWithValue(t, "adk: SafePoint must not be zero; use AfterToolCalls, AfterChatModel, or AnySafePoint", + func() { WithPreempt[string](SafePoint(0)) }) +} + +func TestWithPreemptTimeout_ZeroSafePoint_Panics(t *testing.T) { + assert.PanicsWithValue(t, "adk: SafePoint must not be zero; use AfterToolCalls, AfterChatModel, or AnySafePoint", + func() { WithPreemptTimeout[string](SafePoint(0), time.Second) }) +} + +func TestSafePoint_ToCancelMode(t *testing.T) { + assert.Equal(t, CancelAfterToolCalls, AfterToolCalls.toCancelMode()) + assert.Equal(t, CancelAfterChatModel, AfterChatModel.toCancelMode()) + assert.Equal(t, CancelAfterToolCalls|CancelAfterChatModel, AnySafePoint.toCancelMode()) +} + +func TestNewTurnLoop_NilGenInput_Panics(t *testing.T) { + assert.PanicsWithValue(t, "adk: NewTurnLoop: GenInput is required", func() { + NewTurnLoop(TurnLoopConfig[string]{PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { return nil, nil }}) + }) +} + +func TestNewTurnLoop_NilPrepareAgent_Panics(t *testing.T) { + assert.PanicsWithValue(t, "adk: NewTurnLoop: PrepareAgent is required", func() { + NewTurnLoop(TurnLoopConfig[string]{GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return nil, nil + }}) + }) +} + +func TestDeriveChild_NilParent_ReturnsNil(t *testing.T) { + var cc *cancelContext + assert.Nil(t, cc.deriveChild(context.Background())) +} + +func TestUntilIdleFor(t *testing.T) { + t.Run("FiresAfterIdleDuration", func(t *testing.T) { + turnDone := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(turnDone) + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-turnDone + + loop.Stop(UntilIdleFor(50 * time.Millisecond)) + + done := make(chan struct{}) + go func() { + loop.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("loop did not exit after idle timeout") + } + }) + + t.Run("ResetsOnPush", func(t *testing.T) { + turnCount := int32(0) + turnDone := make(chan struct{}, 10) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + atomic.AddInt32(&turnCount, 1) + turnDone <- struct{}{} + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-turnDone + + loop.Stop(UntilIdleFor(200 * time.Millisecond)) + + time.Sleep(100 * time.Millisecond) + loop.Push("msg2") + <-turnDone + + done := make(chan struct{}) + go func() { + loop.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("loop did not exit after idle timeout") + } + + assert.Equal(t, int32(2), atomic.LoadInt32(&turnCount)) + }) + + t.Run("EscalatedByStopWithImmediate", func(t *testing.T) { + agentStarted := make(chan *cancelContext, 1) + probe := &turnLoopStopModeProbeAgent{ccCh: agentStarted} + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return probe, nil + }, + }) + + loop.Push("msg1") + cc := <-agentStarted + + loop.Stop(UntilIdleFor(10 * time.Minute)) + loop.Stop(WithImmediate()) + + deadline := time.After(2 * time.Second) + for { + if cc.getMode() == CancelImmediate { + break + } + select { + case <-deadline: + t.Fatal("cancel mode did not escalate to CancelImmediate") + default: + } + time.Sleep(1 * time.Millisecond) + } + + exit := loop.Wait() + var ce *CancelError + require.True(t, errors.As(exit.ExitReason, &ce)) + assert.Equal(t, CancelImmediate, ce.Info.Mode) + }) + + t.Run("EscalatedByStopWithGraceful", func(t *testing.T) { + agentStarted := make(chan struct{}) + agentDone := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(agentStarted) + <-ctx.Done() + close(agentDone) + return nil, ctx.Err() + }, + }, nil + }, + }) + + loop.Push("msg1") + <-agentStarted + + loop.Stop(UntilIdleFor(10 * time.Minute)) + loop.Stop(WithGracefulTimeout(50 * time.Millisecond)) + + select { + case <-agentDone: + case <-time.After(2 * time.Second): + t.Fatal("agent was not cancelled") + } + + exit := loop.Wait() + assert.Error(t, exit.ExitReason) + }) +} + +func TestUntilIdleFor_NonPositive_Panics(t *testing.T) { + assert.PanicsWithValue(t, "adk: UntilIdleFor: duration must be positive", + func() { UntilIdleFor(0) }) + assert.PanicsWithValue(t, "adk: UntilIdleFor: duration must be positive", + func() { UntilIdleFor(-1 * time.Second) }) +} diff --git a/adk/wrappers_test.go b/adk/wrappers_test.go index acb6588be..f231e3c07 100644 --- a/adk/wrappers_test.go +++ b/adk/wrappers_test.go @@ -1402,7 +1402,7 @@ func TestEventSenderToolHandler(t *testing.T) { NewEventSenderToolWrapper(), &invokableResultModifier{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, - modifiedResult: modifiedResult, + modifiedResult: modifiedResult, }, }, }) @@ -1490,7 +1490,7 @@ func TestEventSenderToolHandler(t *testing.T) { NewEventSenderToolWrapper(), &streamableResultModifier{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, - modifiedResult: modifiedResult, + modifiedResult: modifiedResult, }, }, }) @@ -1578,7 +1578,7 @@ func TestEventSenderToolHandler(t *testing.T) { NewEventSenderToolWrapper(), &enhancedInvokableResultModifier{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, - modifiedResult: modifiedResult, + modifiedResult: modifiedResult, }, }, }) @@ -1666,7 +1666,7 @@ func TestEventSenderToolHandler(t *testing.T) { NewEventSenderToolWrapper(), &enhancedStreamableResultModifier{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, - modifiedResult: modifiedResult, + modifiedResult: modifiedResult, }, }, }) diff --git a/components/prompt/agentic_chat_template_test.go b/components/prompt/agentic_chat_template_test.go index 42d7a8630..f47020a2c 100644 --- a/components/prompt/agentic_chat_template_test.go +++ b/components/prompt/agentic_chat_template_test.go @@ -21,9 +21,10 @@ import ( "errors" "testing" + "github.com/stretchr/testify/assert" + "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/schema" - "github.com/stretchr/testify/assert" ) type mockAgenticTemplate struct { diff --git a/internal/channel.go b/internal/channel.go index 8e36d8939..fa4215359 100644 --- a/internal/channel.go +++ b/internal/channel.go @@ -61,17 +61,18 @@ func (ch *UnboundedChan[T]) TrySend(value T) bool { return true } -// Receive gets an item from the channel (blocks if empty) +// Receive gets an item from the channel (blocks if empty). +// Returns (value, true) if an item was received. +// Returns (zero, false) if the channel was closed with no data remaining. func (ch *UnboundedChan[T]) Receive() (T, bool) { ch.mutex.Lock() defer ch.mutex.Unlock() for len(ch.buffer) == 0 && !ch.closed { - ch.notEmpty.Wait() // Wait until data is available + ch.notEmpty.Wait() } if len(ch.buffer) == 0 { - // Channel is closed and empty var zero T return zero, false } @@ -88,36 +89,6 @@ func (ch *UnboundedChan[T]) Close() { if !ch.closed { ch.closed = true - ch.notEmpty.Broadcast() // Wake up all waiting goroutines + ch.notEmpty.Broadcast() } } - -// TakeAll removes and returns all values from the channel atomically. -// Returns nil if the channel is empty. -func (ch *UnboundedChan[T]) TakeAll() []T { - ch.mutex.Lock() - defer ch.mutex.Unlock() - - if len(ch.buffer) == 0 { - return nil - } - - values := ch.buffer - ch.buffer = nil - return values -} - -// PushFront adds values to the front of the channel. -// This is useful for recovering values that need to be reprocessed. -// Does nothing if values is empty. -func (ch *UnboundedChan[T]) PushFront(values []T) { - if len(values) == 0 { - return - } - - ch.mutex.Lock() - defer ch.mutex.Unlock() - - ch.buffer = append(append([]T{}, values...), ch.buffer...) - ch.notEmpty.Signal() -} diff --git a/internal/channel_test.go b/internal/channel_test.go index bed2383f1..736a27413 100644 --- a/internal/channel_test.go +++ b/internal/channel_test.go @@ -219,244 +219,3 @@ func TestUnboundedChan_BlockingReceive(t *testing.T) { t.Error("Receive should have unblocked") } } - -func TestUnboundedChan_TakeAll(t *testing.T) { - ch := NewUnboundedChan[int]() - - // Test TakeAll on empty channel - items := ch.TakeAll() - if items != nil { - t.Errorf("TakeAll on empty channel should return nil, got %v", items) - } - - // Send some values - ch.Send(1) - ch.Send(2) - ch.Send(3) - - // Test TakeAll returns all values - items = ch.TakeAll() - if len(items) != 3 { - t.Errorf("expected 3 values, got %d", len(items)) - } - if items[0] != 1 || items[1] != 2 || items[2] != 3 { - t.Errorf("unexpected values: %v", items) - } - - // Channel should be empty now - if len(ch.buffer) != 0 { - t.Errorf("channel should be empty after TakeAll, got %d values", len(ch.buffer)) - } - - // TakeAll again should return nil - items = ch.TakeAll() - if items != nil { - t.Errorf("TakeAll on empty channel should return nil, got %v", items) - } -} - -func TestUnboundedChan_TakeAll_Partial(t *testing.T) { - ch := NewUnboundedChan[int]() - - // Send values - ch.Send(1) - ch.Send(2) - ch.Send(3) - - // Receive one - val, ok := ch.Receive() - if !ok || val != 1 { - t.Errorf("expected (1, true), got (%d, %v)", val, ok) - } - - // TakeAll should return remaining values - items := ch.TakeAll() - if len(items) != 2 { - t.Errorf("expected 2 values, got %d", len(items)) - } - if items[0] != 2 || items[1] != 3 { - t.Errorf("unexpected values: %v", items) - } -} - -func TestUnboundedChan_PushFront(t *testing.T) { - ch := NewUnboundedChan[int]() - - // Test PushFront with empty values (should do nothing) - ch.PushFront(nil) - ch.PushFront([]int{}) - if len(ch.buffer) != 0 { - t.Errorf("PushFront with empty values should not add anything, got %d values", len(ch.buffer)) - } - - // Send some values - ch.Send(3) - ch.Send(4) - - // PushFront should prepend values - ch.PushFront([]int{1, 2}) - - if len(ch.buffer) != 4 { - t.Errorf("expected 4 values, got %d", len(ch.buffer)) - } - if ch.buffer[0] != 1 || ch.buffer[1] != 2 || ch.buffer[2] != 3 || ch.buffer[3] != 4 { - t.Errorf("unexpected buffer: %v", ch.buffer) - } - - // Receive should return in correct order - val, _ := ch.Receive() - if val != 1 { - t.Errorf("expected 1, got %d", val) - } - val, _ = ch.Receive() - if val != 2 { - t.Errorf("expected 2, got %d", val) - } -} - -func TestUnboundedChan_PushFront_EmptyChannel(t *testing.T) { - ch := NewUnboundedChan[int]() - - // PushFront to empty channel - ch.PushFront([]int{1, 2, 3}) - - if len(ch.buffer) != 3 { - t.Errorf("expected 3 values, got %d", len(ch.buffer)) - } - - // Receive should work - val, ok := ch.Receive() - if !ok || val != 1 { - t.Errorf("expected (1, true), got (%d, %v)", val, ok) - } -} - -func TestUnboundedChan_PushFront_UnblocksReceive(t *testing.T) { - ch := NewUnboundedChan[int]() - - // Start a blocking receive - receiveDone := make(chan int) - go func() { - val, _ := ch.Receive() - receiveDone <- val - }() - - // Ensure receive is blocked - select { - case <-receiveDone: - t.Error("Receive should block on empty channel") - case <-time.After(50 * time.Millisecond): - // This is expected - } - - // PushFront should unblock the receive - ch.PushFront([]int{42}) - - select { - case val := <-receiveDone: - if val != 42 { - t.Errorf("expected 42, got %d", val) - } - case <-time.After(50 * time.Millisecond): - t.Error("Receive should have unblocked after PushFront") - } -} - -func TestUnboundedChan_PushFront_SpareCapacity(t *testing.T) { - ch := NewUnboundedChan[int]() - - // Pre-fill the channel so PushFront has something to append - ch.Send(10) - ch.Send(20) - - // Create a slice with spare capacity: len=2, cap=10. - // Elements beyond len (index 2-9) must not be corrupted by PushFront. - src := make([]int, 3, 10) - src[0] = 1 - src[1] = 2 - src[2] = 3 // sentinel — must survive PushFront(src[:2]) - - ch.PushFront(src[:2]) - - // Verify the sentinel was NOT overwritten by the channel's existing buffer - if src[2] != 3 { - t.Errorf("PushFront corrupted caller's backing array: src[2] = %d, want 3", src[2]) - } - - // Verify channel drains correctly: [1, 2, 10, 20] - expected := []int{1, 2, 10, 20} - for i, want := range expected { - got, ok := ch.Receive() - if !ok { - t.Fatalf("Receive returned ok=false at index %d", i) - } - if got != want { - t.Errorf("index %d: got %d, want %d", i, got, want) - } - } -} - -func TestUnboundedChan_TakeAll_PushFront_Concurrent(t *testing.T) { - ch := NewUnboundedChan[int]() - const numOps = 100 - - var wg sync.WaitGroup - wg.Add(3) - - // Goroutine 1: Send values - go func() { - defer wg.Done() - for i := 0; i < numOps; i++ { - ch.Send(i) - time.Sleep(time.Microsecond) - } - }() - - // Goroutine 2: TakeAll periodically - takeAllResults := make([][]int, 0) - var mu sync.Mutex - go func() { - defer wg.Done() - for i := 0; i < numOps/10; i++ { - items := ch.TakeAll() - if items != nil { - mu.Lock() - takeAllResults = append(takeAllResults, items) - mu.Unlock() - } - time.Sleep(10 * time.Microsecond) - } - }() - - // Goroutine 3: PushFront periodically - go func() { - defer wg.Done() - for i := 0; i < numOps/10; i++ { - ch.PushFront([]int{-i}) - time.Sleep(10 * time.Microsecond) - } - }() - - wg.Wait() - ch.Close() - - // Drain remaining values - remaining := ch.TakeAll() - if remaining != nil { - mu.Lock() - takeAllResults = append(takeAllResults, remaining) - mu.Unlock() - } - - // Count total values collected - total := 0 - for _, batch := range takeAllResults { - total += len(batch) - } - - // We should have exactly numOps (from Send) + numOps/10 (from PushFront) values - expected := numOps + numOps/10 - if total != expected { - t.Errorf("expected %d values, got %d", expected, total) - } -} From 592c470558493bac8f12a4c885d7ebc87b11f719 Mon Sep 17 00:00:00 2001 From: shentongmartin Date: Wed, 15 Apr 2026 19:41:38 +0800 Subject: [PATCH 55/65] feat(adk): add ShouldRetry callback with EOF-gated verdict signal for retry event signaling (#944) --- adk/chatmodel.go | 7 +- adk/chatmodel_retry_test.go | 1938 ++++++++++++++++++++++++++- adk/failover_chatmodel.go | 6 + adk/retry_chatmodel.go | 498 ++++++- adk/wrappers.go | 126 +- adk/wrappers_retry_failover_test.go | 858 +++++++----- schema/stream.go | 39 +- schema/stream_oneof_test.go | 324 +++++ 8 files changed, 3418 insertions(+), 378 deletions(-) create mode 100644 schema/stream_oneof_test.go diff --git a/adk/chatmodel.go b/adk/chatmodel.go index abfc55fa0..0f1f9f0a8 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -45,8 +45,13 @@ type chatModelAgentExecCtx struct { generator *AsyncGenerator[*AgentEvent] cancelCtx *cancelContext - // failoverLastSuccessModel is the last success model only used in failover middleware. failoverLastSuccessModel model.BaseChatModel + + // suppressEventSend prevents eventSenderModel from emitting AgentEvents for the current + // Generate call. Set to true before each rejected retry attempt and reset to false after. + // Invariant: any code path that emits model output events MUST check this flag. + suppressEventSend bool + retryVerdictSignal *retryVerdictSignal } func (e *chatModelAgentExecCtx) send(event *AgentEvent) { diff --git a/adk/chatmodel_retry_test.go b/adk/chatmodel_retry_test.go index 0cb2a87bd..225fb354d 100644 --- a/adk/chatmodel_retry_test.go +++ b/adk/chatmodel_retry_test.go @@ -26,6 +26,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "github.com/cloudwego/eino/components/model" @@ -38,6 +39,57 @@ import ( var errRetryAble = errors.New("retry-able error") var errNonRetryAble = errors.New("non-retry-able error") +var instantBackoff = func(_ context.Context, _ int) time.Duration { return time.Millisecond } + +type agentEvent struct { + Err error + Output *AgentOutput + StreamContent string +} + +func drainAgentEvents(t *testing.T, iterator *AsyncIterator[*AgentEvent]) []agentEvent { + t.Helper() + var events []agentEvent + for { + event, ok := iterator.Next() + if !ok { + break + } + events = append(events, agentEvent{Err: event.Err, Output: event.Output}) + } + return events +} + +func drainStreamingAgentEvents(t *testing.T, iterator *AsyncIterator[*AgentEvent]) (events []agentEvent, streamTermErrs []error) { + t.Helper() + for { + event, ok := iterator.Next() + if !ok { + break + } + ae := agentEvent{Err: event.Err, Output: event.Output} + if event.Output != nil && event.Output.MessageOutput != nil { + mo := event.Output.MessageOutput + if mo.IsStreaming && mo.MessageStream != nil { + var chunks []string + for { + msg, recvErr := mo.MessageStream.Recv() + if recvErr != nil { + streamTermErrs = append(streamTermErrs, recvErr) + break + } + if msg != nil { + chunks = append(chunks, msg.Content) + } + } + ae.StreamContent = strings.Join(chunks, "") + } + } + events = append(events, ae) + } + return events, streamTermErrs +} + func TestChatModelAgentRetry_NoTools_DirectError_Generate(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) @@ -706,26 +758,6 @@ func TestDefaultBackoff(t *testing.T) { "Delay should still be capped at 10s + jitter for very high attempts, got %v", d100) } -func TestRetryExhaustedError_ErrorString(t *testing.T) { - errWithLast := &RetryExhaustedError{ - LastErr: errors.New("connection timeout"), - TotalRetries: 3, - } - assert.Contains(t, errWithLast.Error(), "exceeds max retries") - assert.Contains(t, errWithLast.Error(), "connection timeout") - - errWithoutLast := &RetryExhaustedError{ - LastErr: nil, - TotalRetries: 3, - } - assert.Equal(t, "exceeds max retries", errWithoutLast.Error()) -} - -func TestWillRetryError_ErrorString(t *testing.T) { - willRetry := &WillRetryError{ErrStr: "transient error", RetryAttempt: 1} - assert.Equal(t, "transient error", willRetry.Error()) -} - type customError struct { code int msg string @@ -1150,9 +1182,7 @@ func TestCheckpointSave_WillRetryError_StreamNotConsumed(t *testing.T) { IsRetryAble: func(_ context.Context, err error) bool { return errors.Is(err, errRetryAble) }, - BackoffFunc: func(_ context.Context, _ int) time.Duration { - return time.Millisecond // fast retry for test - }, + BackoffFunc: instantBackoff, }, }) assert.NoError(t, err) @@ -1191,3 +1221,1865 @@ func TestCheckpointSave_WillRetryError_StreamNotConsumed(t *testing.T) { assert.Equal(t, int32(2), atomic.LoadInt32(&mdl.streamCallCount), "model should be called twice: first fail, then retry success") } + +func TestChatModelAgentRetry_ShouldRetry_RejectMessage_Stream(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + var callCount int32 + cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + count := atomic.AddInt32(&callCount, 1) + r, w := schema.Pipe[*schema.Message](1) + go func() { + if count < 2 { + _ = w.Send(schema.AssistantMessage("bad stream content", nil), nil) + } else { + _ = w.Send(schema.AssistantMessage("good stream content", nil), nil) + } + w.Close() + }() + return r, nil + }).Times(2) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "ShouldRetryStreamTestAgent", + Description: "Test ShouldRetry message rejection in stream mode", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 3, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + if retryCtx.Err != nil { + return &RetryDecision{Retry: true} + } + if retryCtx.OutputMessage != nil && strings.Contains(retryCtx.OutputMessage.Content, "bad") { + return &RetryDecision{Retry: true} + } + return &RetryDecision{Retry: false} + }, + BackoffFunc: instantBackoff, + }, + }) + assert.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + EnableStreaming: true, + } + iterator := agent.Run(ctx, input) + + events, _ := drainStreamingAgentEvents(t, iterator) + var foundGoodContent bool + for _, e := range events { + if e.StreamContent == "good stream content" { + foundGoodContent = true + } + } + require.True(t, foundGoodContent, "should have received good stream content") + assert.Equal(t, int32(2), atomic.LoadInt32(&callCount)) +} + +func TestShouldRetry_Generate(t *testing.T) { + t.Run("RetryContext_Fields", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + var callCount int32 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + count := atomic.AddInt32(&callCount, 1) + if count < 2 { + return schema.AssistantMessage("bad", nil), nil + } + return schema.AssistantMessage("good", nil), nil + }).Times(2) + + var capturedContexts []*RetryContext + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "RetryContextFieldsAgent", + Description: "Test that RetryContext fields are correctly populated", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 3, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + capturedContexts = append(capturedContexts, retryCtx) + if retryCtx.OutputMessage != nil && retryCtx.OutputMessage.Content == "bad" { + return &RetryDecision{Retry: true} + } + return &RetryDecision{Retry: false} + }, + BackoffFunc: instantBackoff, + }, + }) + assert.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + } + iterator := agent.Run(ctx, input) + + for { + event, ok := iterator.Next() + if !ok { + break + } + _ = event + } + + assert.Len(t, capturedContexts, 2, "ShouldRetry should be called twice") + + assert.Equal(t, 1, capturedContexts[0].RetryAttempt) + assert.Len(t, capturedContexts[0].InputMessages, 2) + assert.True(t, len(capturedContexts[0].Options) > 0, "should have default options") + assert.Equal(t, "bad", capturedContexts[0].OutputMessage.Content) + assert.Nil(t, capturedContexts[0].Err) + + assert.Equal(t, 2, capturedContexts[1].RetryAttempt) + assert.Equal(t, "good", capturedContexts[1].OutputMessage.Content) + assert.Nil(t, capturedContexts[1].Err) + }) + + t.Run("RewriteError_OnMessage", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("unrecoverable bad message", nil), nil).Times(1) + + fatalErr := errors.New("fatal: unrecoverable model output") + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "RewriteErrorTestAgent", + Description: "Test ShouldRetry RewriteError on message", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 2, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + if retryCtx.OutputMessage != nil && strings.Contains(retryCtx.OutputMessage.Content, "unrecoverable") { + return &RetryDecision{ + Retry: false, + RewriteError: fatalErr, + } + } + return &RetryDecision{Retry: false} + }, + BackoffFunc: instantBackoff, + }, + }) + assert.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + } + iterator := agent.Run(ctx, input) + + events := drainAgentEvents(t, iterator) + require.NotEmpty(t, events) + foundErr := false + for _, e := range events { + if e.Err != nil && errors.Is(e.Err, fatalErr) { + foundErr = true + } + } + require.True(t, foundErr, "should have received the fatal rewrite error") + }) + + t.Run("RewriteError_OnError", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + origErr := errors.New("original transient error") + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, origErr).Times(1) + + wrappedErr := errors.New("wrapped: original transient error with more context") + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "RewriteErrorOnErrorTestAgent", + Description: "Test ShouldRetry RewriteError replacing original error", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 2, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + if retryCtx.Err != nil { + return &RetryDecision{ + Retry: false, + RewriteError: wrappedErr, + } + } + return &RetryDecision{Retry: false} + }, + BackoffFunc: instantBackoff, + }, + }) + assert.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + } + iterator := agent.Run(ctx, input) + + events := drainAgentEvents(t, iterator) + require.NotEmpty(t, events) + foundErr := false + for _, e := range events { + if e.Err != nil && errors.Is(e.Err, wrappedErr) { + foundErr = true + } + } + require.True(t, foundErr, "should have received the wrapped rewrite error") + }) + + t.Run("AdditionalOptions", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + var callCount int32 + var capturedOpts [][]model.Option + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + count := atomic.AddInt32(&callCount, 1) + capturedOpts = append(capturedOpts, opts) + if count < 2 { + return nil, errRetryAble + } + return schema.AssistantMessage("success", nil), nil + }).Times(2) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "AdditionalOptionsTestAgent", + Description: "Test ShouldRetry AdditionalOptions", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 3, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + if retryCtx.Err != nil { + return &RetryDecision{ + Retry: true, + AdditionalOptions: []model.Option{model.WithMaxTokens(8192)}, + } + } + return &RetryDecision{Retry: false} + }, + BackoffFunc: instantBackoff, + }, + }) + assert.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + } + iterator := agent.Run(ctx, input) + + event, ok := iterator.Next() + assert.True(t, ok) + assert.NotNil(t, event) + assert.Nil(t, event.Err) + assert.Equal(t, int32(2), atomic.LoadInt32(&callCount)) + assert.Equal(t, 2, len(capturedOpts)) + assert.Equal(t, len(capturedOpts[0])+1, len(capturedOpts[1])) + }) + + t.Run("ModifiedInputMessages_NoPersist", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + var callCount int32 + var capturedInputs [][]*schema.Message + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + count := atomic.AddInt32(&callCount, 1) + inputCopy := make([]*schema.Message, len(input)) + copy(inputCopy, input) + capturedInputs = append(capturedInputs, inputCopy) + if count < 2 { + return nil, errRetryAble + } + return schema.AssistantMessage("success", nil), nil + }).Times(2) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "ModifiedInputNoPersistAgent", + Description: "Test ShouldRetry ModifiedInputMessages without persistence", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 3, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + if retryCtx.Err != nil { + return &RetryDecision{ + Retry: true, + ModifiedInputMessages: []*schema.Message{ + schema.SystemMessage("compressed instruction"), + schema.UserMessage("Hello"), + }, + PersistModifiedInputMessages: false, + } + } + return &RetryDecision{Retry: false} + }, + BackoffFunc: instantBackoff, + }, + }) + assert.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + } + iterator := agent.Run(ctx, input) + + event, ok := iterator.Next() + assert.True(t, ok) + assert.NotNil(t, event) + assert.Nil(t, event.Err) + assert.Equal(t, int32(2), atomic.LoadInt32(&callCount)) + assert.Equal(t, 2, len(capturedInputs)) + assert.Equal(t, "compressed instruction", capturedInputs[1][0].Content, "second call should use modified input") + }) + + t.Run("Backoff", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + var callCount int32 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + count := atomic.AddInt32(&callCount, 1) + if count < 2 { + return nil, errRetryAble + } + return schema.AssistantMessage("success", nil), nil + }).Times(2) + + customBackoff := 50 * time.Millisecond + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "BackoffTestAgent", + Description: "Test ShouldRetry custom Backoff in decision", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 3, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + if retryCtx.Err != nil { + return &RetryDecision{ + Retry: true, + Backoff: customBackoff, + } + } + return &RetryDecision{Retry: false} + }, + }, + }) + assert.NoError(t, err) + + start := time.Now() + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + } + iterator := agent.Run(ctx, input) + + event, ok := iterator.Next() + assert.True(t, ok) + assert.NotNil(t, event) + assert.Nil(t, event.Err) + elapsed := time.Since(start) + assert.True(t, elapsed >= customBackoff && elapsed < customBackoff+200*time.Millisecond, "expected backoff ~%v, got %v", customBackoff, elapsed) + assert.Equal(t, int32(2), atomic.LoadInt32(&callCount)) + }) + + t.Run("SuppressFlag_Rejected_NoEvent", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + var callCount int32 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + count := atomic.AddInt32(&callCount, 1) + if count == 1 { + return schema.AssistantMessage("bad", nil), nil + } + return schema.AssistantMessage("good", nil), nil + }).Times(2) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "SuppressRejected", + Description: "Test suppress flag rejects first then accepts", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 1, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + if retryCtx.OutputMessage != nil && retryCtx.OutputMessage.Content == "bad" { + return &RetryDecision{Retry: true} + } + return &RetryDecision{Retry: false} + }, + BackoffFunc: instantBackoff, + }, + }) + assert.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + } + iterator := agent.Run(ctx, input) + + var msgEvents []*AgentEvent + for { + event, ok := iterator.Next() + if !ok { + break + } + if event.Output != nil && event.Output.MessageOutput != nil { + msgEvents = append(msgEvents, event) + } + } + assert.Equal(t, 1, len(msgEvents), "should have exactly 1 message event (suppressed rejected)") + assert.Equal(t, "good", msgEvents[0].Output.MessageOutput.Message.Content) + }) + + t.Run("SuppressFlag_AllRejected_NoEvents", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("always bad", nil), nil).Times(3) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "SuppressAllRejected", + Description: "Test suppress flag all rejected no events", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 2, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + return &RetryDecision{Retry: true} + }, + BackoffFunc: instantBackoff, + }, + }) + assert.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + } + iterator := agent.Run(ctx, input) + + events := drainAgentEvents(t, iterator) + var msgEventCount int + var foundExhaustedErr bool + for _, e := range events { + if e.Output != nil && e.Output.MessageOutput != nil { + msgEventCount++ + } + if e.Err != nil && errors.Is(e.Err, ErrExceedMaxRetries) { + foundExhaustedErr = true + } + } + assert.Equal(t, 0, msgEventCount, "no message events should be emitted when all are rejected") + require.True(t, foundExhaustedErr, "final event should have RetryExhaustedError") + }) + + t.Run("SuppressFlag_Accepted_FirstAttempt", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("perfect", nil), nil).Times(1) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "SuppressAcceptedFirst", + Description: "Test suppress flag accepted first attempt", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 2, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + return &RetryDecision{Retry: false} + }, + BackoffFunc: instantBackoff, + }, + }) + assert.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + } + iterator := agent.Run(ctx, input) + + var msgEvents []*AgentEvent + for { + event, ok := iterator.Next() + if !ok { + break + } + if event.Output != nil && event.Output.MessageOutput != nil { + msgEvents = append(msgEvents, event) + } + } + assert.Equal(t, 1, len(msgEvents), "should have exactly 1 event") + assert.Equal(t, "perfect", msgEvents[0].Output.MessageOutput.Message.Content) + }) + + t.Run("ContextCanceled_DuringSleep", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + var callCount int32 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&callCount, 1) + return nil, errors.New("transient") + }).Times(1) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "ContextCancelDuringSleep", + Description: "Test context cancellation during backoff sleep", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 5, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + return &RetryDecision{Retry: true} + }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 10 * time.Second }, + }, + }) + require.NoError(t, err) + + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + start := time.Now() + iterator := agent.Run(ctx, &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + }) + events := drainAgentEvents(t, iterator) + elapsed := time.Since(start) + + require.True(t, elapsed < 2*time.Second, "should not block for full backoff; elapsed: %v", elapsed) + assert.Equal(t, int32(1), atomic.LoadInt32(&callCount)) + + var foundCtxErr bool + for _, e := range events { + if e.Err != nil && errors.Is(e.Err, context.Canceled) { + foundCtxErr = true + } + } + require.True(t, foundCtxErr, "should have received context.Canceled error") + }) +} + +func TestShouldRetry_Stream(t *testing.T) { + t.Run("ErrorRetry", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + streamErr := errors.New("stream unavailable") + var callCount int32 + cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + count := atomic.AddInt32(&callCount, 1) + if count < 2 { + return nil, streamErr + } + r, w := schema.Pipe[*schema.Message](1) + go func() { + _ = w.Send(schema.AssistantMessage("recovered stream", nil), nil) + w.Close() + }() + return r, nil + }).Times(2) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "StreamErrorRetryAgent", + Description: "Test ShouldRetry when Stream returns error (nil stream)", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 3, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + if retryCtx.Err != nil { + return &RetryDecision{Retry: true} + } + return &RetryDecision{Retry: false} + }, + BackoffFunc: instantBackoff, + }, + }) + assert.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + EnableStreaming: true, + } + iterator := agent.Run(ctx, input) + + events, _ := drainStreamingAgentEvents(t, iterator) + var foundContent bool + for _, e := range events { + if e.StreamContent == "recovered stream" { + foundContent = true + } + } + require.True(t, foundContent, "should have received recovered stream content after error retry") + assert.Equal(t, int32(2), atomic.LoadInt32(&callCount)) + }) + + t.Run("ErrorRewrite", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + streamErr := errors.New("model overloaded") + cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, streamErr).Times(1) + + fatalErr := errors.New("fatal: model overloaded, aborting") + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "StreamErrorRewriteAgent", + Description: "Test ShouldRetry RewriteError when Stream returns error", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 2, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + if retryCtx.Err != nil && strings.Contains(retryCtx.Err.Error(), "overloaded") { + return &RetryDecision{ + Retry: false, + RewriteError: fatalErr, + } + } + return &RetryDecision{Retry: false} + }, + BackoffFunc: instantBackoff, + }, + }) + assert.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + EnableStreaming: true, + } + iterator := agent.Run(ctx, input) + + events := drainAgentEvents(t, iterator) + require.NotEmpty(t, events) + foundErr := false + for _, e := range events { + if e.Err != nil && errors.Is(e.Err, fatalErr) { + foundErr = true + } + } + require.True(t, foundErr, "should have received the fatal rewrite error from stream") + }) + + t.Run("RewriteError_OnMessage", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + r, w := schema.Pipe[*schema.Message](1) + go func() { + _ = w.Send(schema.AssistantMessage("hallucinated garbage output", nil), nil) + w.Close() + }() + return r, nil + }).Times(1) + + fatalErr := errors.New("fatal: hallucinated output detected") + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "StreamRewriteOnMessageAgent", + Description: "Test ShouldRetry RewriteError on successful stream with bad content", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 2, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + if retryCtx.OutputMessage != nil && strings.Contains(retryCtx.OutputMessage.Content, "hallucinated") { + return &RetryDecision{ + Retry: false, + RewriteError: fatalErr, + } + } + return &RetryDecision{Retry: false} + }, + BackoffFunc: instantBackoff, + }, + }) + assert.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + EnableStreaming: true, + } + iterator := agent.Run(ctx, input) + + events := drainAgentEvents(t, iterator) + require.NotEmpty(t, events) + foundErr := false + for _, e := range events { + if e.Err != nil && errors.Is(e.Err, fatalErr) { + foundErr = true + } + } + require.True(t, foundErr, "should have received fatal rewrite error from stream message inspection") + }) + + t.Run("PartialStreamError", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + partialErr := errors.New("connection reset mid-stream") + var callCount int32 + cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + count := atomic.AddInt32(&callCount, 1) + r, w := schema.Pipe[*schema.Message](1) + go func() { + _ = w.Send(schema.AssistantMessage("partial chunk", nil), nil) + if count < 2 { + w.Send(nil, partialErr) + } else { + _ = w.Send(schema.AssistantMessage(" complete", nil), nil) + w.Close() + } + }() + return r, nil + }).Times(2) + + var capturedContexts []*RetryContext + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "StreamPartialErrorAgent", + Description: "Test ShouldRetry when stream has partial content then error", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 3, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + capturedContexts = append(capturedContexts, retryCtx) + if retryCtx.Err != nil { + return &RetryDecision{Retry: true} + } + return &RetryDecision{Retry: false} + }, + BackoffFunc: instantBackoff, + }, + }) + assert.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + EnableStreaming: true, + } + iterator := agent.Run(ctx, input) + + for { + event, ok := iterator.Next() + if !ok { + break + } + if event.Output != nil && event.Output.MessageOutput != nil { + mo := event.Output.MessageOutput + if mo.IsStreaming && mo.MessageStream != nil { + for { + _, err := mo.MessageStream.Recv() + if err != nil { + break + } + } + } + } + } + + assert.Equal(t, int32(2), atomic.LoadInt32(&callCount)) + assert.Equal(t, 2, len(capturedContexts)) + assert.NotNil(t, capturedContexts[0].Err, "first attempt should have stream error") + assert.NotNil(t, capturedContexts[0].OutputMessage, "first attempt should have partial message despite error") + assert.Equal(t, "partial chunk", capturedContexts[0].OutputMessage.Content) + }) + + t.Run("ModifiedInputsAndOptions_WithPersist", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + var callCount int32 + var capturedInputs [][]*schema.Message + var capturedOptLens []int + cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + count := atomic.AddInt32(&callCount, 1) + inputCopy := make([]*schema.Message, len(input)) + copy(inputCopy, input) + capturedInputs = append(capturedInputs, inputCopy) + capturedOptLens = append(capturedOptLens, len(opts)) + + r, w := schema.Pipe[*schema.Message](1) + go func() { + if count < 2 { + _ = w.Send(schema.AssistantMessage("too long response exceeds limit", nil), nil) + } else { + _ = w.Send(schema.AssistantMessage("good response", nil), nil) + } + w.Close() + }() + return r, nil + }).Times(2) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "StreamModifiedInputsPersistAgent", + Description: "Test ShouldRetry with ModifiedInputMessages (persist) and AdditionalOptions in stream mode", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 3, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + if retryCtx.OutputMessage != nil && strings.Contains(retryCtx.OutputMessage.Content, "too long") { + return &RetryDecision{ + Retry: true, + ModifiedInputMessages: []*schema.Message{ + schema.SystemMessage("compressed instruction"), + schema.UserMessage("summarized history"), + }, + PersistModifiedInputMessages: true, + AdditionalOptions: []model.Option{model.WithMaxTokens(16384)}, + } + } + return &RetryDecision{Retry: false} + }, + BackoffFunc: instantBackoff, + }, + }) + assert.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + EnableStreaming: true, + } + iterator := agent.Run(ctx, input) + + events, _ := drainStreamingAgentEvents(t, iterator) + var foundGood bool + for _, e := range events { + if e.StreamContent == "good response" { + foundGood = true + } + } + + require.True(t, foundGood, "should have received good response after retry with modified inputs") + assert.Equal(t, int32(2), atomic.LoadInt32(&callCount)) + assert.Equal(t, 2, len(capturedInputs)) + assert.Equal(t, "compressed instruction", capturedInputs[1][0].Content, "second call should use modified input") + assert.Equal(t, "summarized history", capturedInputs[1][1].Content) + assert.Equal(t, capturedOptLens[0]+1, capturedOptLens[1]) + }) + + t.Run("VerdictSignal_CleanStream_Rejected", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + var callCount int32 + cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + count := atomic.AddInt32(&callCount, 1) + if count == 1 { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("bad", nil)}), nil + } + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("good", nil)}), nil + }).Times(2) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "VerdictCleanRejected", + Description: "Test verdict signal on clean stream rejected", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 1, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + if retryCtx.OutputMessage != nil && retryCtx.OutputMessage.Content == "bad" { + return &RetryDecision{Retry: true} + } + return &RetryDecision{Retry: false} + }, + BackoffFunc: instantBackoff, + }, + }) + assert.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + EnableStreaming: true, + } + iterator := agent.Run(ctx, input) + + var streamEvents []int + for { + event, ok := iterator.Next() + if !ok { + break + } + if event.Output != nil && event.Output.MessageOutput != nil { + mo := event.Output.MessageOutput + if mo.IsStreaming && mo.MessageStream != nil { + idx := len(streamEvents) + streamEvents = append(streamEvents, idx) + var lastErr error + for { + _, recvErr := mo.MessageStream.Recv() + if recvErr != nil { + lastErr = recvErr + break + } + } + if idx == 0 { + var willRetryErr *WillRetryError + assert.True(t, errors.As(lastErr, &willRetryErr), "first stream should end with WillRetryError") + } else { + assert.ErrorIs(t, lastErr, io.EOF, "second stream should end with io.EOF") + } + } + } + } + assert.Equal(t, 2, len(streamEvents), "should have exactly 2 stream events") + assert.Equal(t, int32(2), atomic.LoadInt32(&callCount)) + }) + + t.Run("VerdictSignal_StreamError_Rejected", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + streamErr := errors.New("mid-stream error") + var callCount int32 + cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + count := atomic.AddInt32(&callCount, 1) + if count == 1 { + r, w := schema.Pipe[*schema.Message](1) + go func() { + _ = w.Send(schema.AssistantMessage("partial", nil), nil) + w.Send(nil, streamErr) + }() + return r, nil + } + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("good", nil)}), nil + }).Times(2) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "VerdictStreamErrorRejected", + Description: "Test verdict signal on stream error rejected", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 1, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + if retryCtx.Err != nil { + return &RetryDecision{Retry: true} + } + return &RetryDecision{Retry: false} + }, + BackoffFunc: instantBackoff, + }, + }) + assert.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + EnableStreaming: true, + } + iterator := agent.Run(ctx, input) + + var firstEventHasWillRetry bool + var eventCount int + for { + event, ok := iterator.Next() + if !ok { + break + } + if event.Output != nil && event.Output.MessageOutput != nil { + mo := event.Output.MessageOutput + if mo.IsStreaming && mo.MessageStream != nil { + eventCount++ + for { + _, recvErr := mo.MessageStream.Recv() + if recvErr != nil { + if eventCount == 1 { + var willRetryErr *WillRetryError + if errors.As(recvErr, &willRetryErr) { + firstEventHasWillRetry = true + } + } + break + } + } + } + } + } + assert.True(t, firstEventHasWillRetry, "first event stream should end with WillRetryError via errWrapper path") + assert.Equal(t, 2, eventCount, "should have 2 stream events") + }) + + t.Run("VerdictSignal_Accepted_FirstAttempt", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("perfect", nil)}), nil + }).Times(1) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "VerdictAcceptedFirst", + Description: "Test verdict signal accepted first attempt", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 2, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + return &RetryDecision{Retry: false} + }, + BackoffFunc: instantBackoff, + }, + }) + assert.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + EnableStreaming: true, + } + iterator := agent.Run(ctx, input) + + var eventCount int + for { + event, ok := iterator.Next() + if !ok { + break + } + if event.Output != nil && event.Output.MessageOutput != nil { + mo := event.Output.MessageOutput + if mo.IsStreaming && mo.MessageStream != nil { + eventCount++ + var lastErr error + for { + _, recvErr := mo.MessageStream.Recv() + if recvErr != nil { + lastErr = recvErr + break + } + } + assert.ErrorIs(t, lastErr, io.EOF, "accepted stream should end with io.EOF") + } + } + } + assert.Equal(t, 1, eventCount, "should have exactly 1 event") + }) + + t.Run("VerdictSignal_AllRejected_Exhausted", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("always bad", nil)}), nil + }).Times(3) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "VerdictAllRejected", + Description: "Test verdict signal all rejected exhausted", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 2, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + return &RetryDecision{Retry: true} + }, + BackoffFunc: instantBackoff, + }, + }) + assert.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + EnableStreaming: true, + } + iterator := agent.Run(ctx, input) + + events, streamTermErrs := drainStreamingAgentEvents(t, iterator) + var willRetryCount int + var foundExhaustedErr bool + for _, e := range events { + if e.Err != nil && errors.Is(e.Err, ErrExceedMaxRetries) { + foundExhaustedErr = true + } + } + for _, termErr := range streamTermErrs { + var willRetryErr *WillRetryError + if errors.As(termErr, &willRetryErr) { + willRetryCount++ + } + } + assert.Equal(t, 3, willRetryCount, "all 3 stream events should end with WillRetryError") + require.True(t, foundExhaustedErr, "final error should be RetryExhaustedError") + }) + + t.Run("ShouldRetry_Panics_VerdictStillSent", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("trigger panic", nil)}), nil + }).Times(1) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "ShouldRetryPanicsAgent", + Description: "Test that ShouldRetry panic sends verdict signal and does not deadlock", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 1, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + panic("deliberate panic in ShouldRetry") + }, + BackoffFunc: instantBackoff, + }, + }) + require.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + EnableStreaming: true, + } + + done := make(chan struct{}) + var events []agentEvent + go func() { + defer close(done) + iterator := agent.Run(ctx, input) + events = drainAgentEvents(t, iterator) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("test deadlocked — verdict signal was not sent after ShouldRetry panic") + } + require.NotEmpty(t, events) + var foundPanicErr bool + for _, e := range events { + if e.Err != nil && strings.Contains(e.Err.Error(), "panic") { + foundPanicErr = true + } + } + assert.True(t, foundPanicErr, "should have received a panic error event") + }) +} + +func TestErrStreamCanceled(t *testing.T) { + t.Run("Stream_ShouldRetry_NeverRetried", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + r, w := schema.Pipe[*schema.Message](1) + go func() { + _ = w.Send(schema.AssistantMessage("partial", nil), nil) + w.Send(nil, ErrStreamCanceled) + }() + return r, nil + }).Times(1) + + var shouldRetryCalled int32 + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "StreamCanceledShouldRetry", + Description: "Test ErrStreamCanceled never retried with ShouldRetry", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 3, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + atomic.AddInt32(&shouldRetryCalled, 1) + return &RetryDecision{Retry: true} + }, + BackoffFunc: instantBackoff, + }, + }) + assert.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + EnableStreaming: true, + } + iterator := agent.Run(ctx, input) + + for { + event, ok := iterator.Next() + if !ok { + break + } + if event.Output != nil && event.Output.MessageOutput != nil { + mo := event.Output.MessageOutput + if mo.IsStreaming && mo.MessageStream != nil { + for { + _, recvErr := mo.MessageStream.Recv() + if recvErr != nil { + break + } + } + } + } + } + assert.Equal(t, int32(0), atomic.LoadInt32(&shouldRetryCalled), "ShouldRetry should never be called for ErrStreamCanceled") + }) + + t.Run("Stream_LegacyIsRetryAble_NeverRetried", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + r, w := schema.Pipe[*schema.Message](1) + go func() { + _ = w.Send(schema.AssistantMessage("partial", nil), nil) + w.Send(nil, ErrStreamCanceled) + }() + return r, nil + }).Times(1) + + var isRetryAbleCalled int32 + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "StreamCanceledLegacy", + Description: "Test ErrStreamCanceled never retried with legacy IsRetryAble", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 3, + IsRetryAble: func(_ context.Context, err error) bool { + atomic.AddInt32(&isRetryAbleCalled, 1) + return true + }, + BackoffFunc: instantBackoff, + }, + }) + assert.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + EnableStreaming: true, + } + iterator := agent.Run(ctx, input) + + for { + event, ok := iterator.Next() + if !ok { + break + } + if event.Output != nil && event.Output.MessageOutput != nil { + mo := event.Output.MessageOutput + if mo.IsStreaming && mo.MessageStream != nil { + for { + _, recvErr := mo.MessageStream.Recv() + if recvErr != nil { + break + } + } + } + } + } + assert.Equal(t, int32(0), atomic.LoadInt32(&isRetryAbleCalled), "IsRetryAble should never be called for ErrStreamCanceled") + }) + + t.Run("Generate_ShouldRetry_NeverRetried", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, ErrStreamCanceled).Times(1) + + var shouldRetryCalled int32 + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "GenCanceledShouldRetry", + Description: "Test ErrStreamCanceled in Generate never retried", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 3, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + atomic.AddInt32(&shouldRetryCalled, 1) + return &RetryDecision{Retry: true} + }, + BackoffFunc: instantBackoff, + }, + }) + assert.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + } + iterator := agent.Run(ctx, input) + + for { + _, ok := iterator.Next() + if !ok { + break + } + } + assert.Equal(t, int32(0), atomic.LoadInt32(&shouldRetryCalled), "ShouldRetry should never be called for ErrStreamCanceled") + }) +} + +func TestAttack_ShouldRetry_NilDecisionOnEveryCall(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("ok", nil), nil).Times(1) + + var shouldRetryCalls int32 + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "NilDecisionAgent", + Description: "ShouldRetry always returns nil — should accept on first call", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 3, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + atomic.AddInt32(&shouldRetryCalls, 1) + return nil + }, + BackoffFunc: instantBackoff, + }, + }) + require.NoError(t, err) + + iterator := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("Hello")}}) + events := drainAgentEvents(t, iterator) + + require.NotEmpty(t, events) + assert.Equal(t, int32(1), atomic.LoadInt32(&shouldRetryCalls)) + var foundOK bool + for _, e := range events { + if e.Output != nil && e.Output.MessageOutput != nil && e.Output.MessageOutput.Message != nil { + if e.Output.MessageOutput.Message.Content == "ok" { + foundOK = true + } + } + } + assert.True(t, foundOK, "nil decision should accept the message as-is") +} + +func TestAttack_ShouldRetry_MaxRetriesZero_RejectFirstAttempt(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("bad", nil), nil).Times(1) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "MaxZeroRejectAgent", + Description: "MaxRetries=0 with ShouldRetry rejecting — should exhaust immediately", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 0, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + return &RetryDecision{Retry: true} + }, + BackoffFunc: instantBackoff, + }, + }) + require.NoError(t, err) + + iterator := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("Hello")}}) + events := drainAgentEvents(t, iterator) + + var foundExhausted bool + for _, e := range events { + if e.Err != nil { + var exhaustedErr *RetryExhaustedError + if errors.As(e.Err, &exhaustedErr) { + foundExhausted = true + } + } + } + assert.True(t, foundExhausted, "MaxRetries=0 with Retry:true should produce RetryExhaustedError") +} + +func TestAttack_ShouldRetry_RetryTrueWithRewriteError_IgnoresRewrite(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + var callCount int32 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + count := atomic.AddInt32(&callCount, 1) + if count == 1 { + return nil, errors.New("transient") + } + return schema.AssistantMessage("success", nil), nil + }).Times(2) + + rewriteErr := errors.New("this should be ignored") + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "RetryTrueRewriteAgent", + Description: "Retry=true with RewriteError should ignore the rewrite", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 3, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + if retryCtx.Err != nil { + return &RetryDecision{Retry: true, RewriteError: rewriteErr} + } + return &RetryDecision{Retry: false} + }, + BackoffFunc: instantBackoff, + }, + }) + require.NoError(t, err) + + iterator := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("Hello")}}) + events := drainAgentEvents(t, iterator) + + var foundSuccess bool + for _, e := range events { + if e.Err != nil && errors.Is(e.Err, rewriteErr) { + t.Fatal("RewriteError should be ignored when Retry=true") + } + if e.Output != nil && e.Output.MessageOutput != nil && e.Output.MessageOutput.Message != nil { + if e.Output.MessageOutput.Message.Content == "success" { + foundSuccess = true + } + } + } + assert.True(t, foundSuccess, "should eventually succeed after retry, ignoring RewriteError") +} + +func TestAttack_ShouldRetry_OptionsAccumulateAcrossRetries(t *testing.T) { + ctx := context.Background() + + var capturedOpts [][]model.Option + var callCount int32 + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + count := atomic.AddInt32(&callCount, 1) + capturedOpts = append(capturedOpts, opts) + if count <= 2 { + return nil, errors.New("needs retry") + } + return schema.AssistantMessage("done", nil), nil + }).Times(3) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "OptsAccumulateAgent", + Description: "Verify options accumulate across retries", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 5, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + if retryCtx.Err != nil { + return &RetryDecision{ + Retry: true, + AdditionalOptions: []model.Option{model.WithMaxTokens(100 * retryCtx.RetryAttempt)}, + } + } + return &RetryDecision{Retry: false} + }, + BackoffFunc: instantBackoff, + }, + }) + require.NoError(t, err) + + iterator := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("Hello")}}) + drainAgentEvents(t, iterator) + + require.Len(t, capturedOpts, 3) + assert.True(t, len(capturedOpts[1]) > len(capturedOpts[0]), + "second call should have more options than first (accumulated AdditionalOptions)") + assert.True(t, len(capturedOpts[2]) > len(capturedOpts[1]), + "third call should have more options than second (accumulated AdditionalOptions)") +} + +func TestAttack_ShouldRetry_Stream_NilDecisionAccepts(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("stream ok", nil)}), nil + }).Times(1) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "StreamNilDecisionAgent", + Description: "ShouldRetry returns nil in stream mode — should accept", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 2, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + return nil + }, + BackoffFunc: instantBackoff, + }, + }) + require.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + EnableStreaming: true, + } + iterator := agent.Run(ctx, input) + + events, streamTermErrs := drainStreamingAgentEvents(t, iterator) + var foundStreamContent bool + for _, e := range events { + if e.StreamContent == "stream ok" { + foundStreamContent = true + } + } + assert.True(t, foundStreamContent, "nil decision should accept the stream") + for _, termErr := range streamTermErrs { + assert.Equal(t, io.EOF, termErr, "stream should terminate with clean EOF, not error") + } +} + +func TestAttack_ShouldRetry_Stream_MaxRetriesZero_Exhausted(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("rejected", nil)}), nil + }).Times(1) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "StreamMaxZeroAgent", + Description: "Stream mode with MaxRetries=0 rejecting — should exhaust immediately", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 0, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + return &RetryDecision{Retry: true} + }, + BackoffFunc: instantBackoff, + }, + }) + require.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + EnableStreaming: true, + } + + done := make(chan struct{}) + var events []agentEvent + go func() { + defer close(done) + iterator := agent.Run(ctx, input) + events, _ = drainStreamingAgentEvents(t, iterator) + }() + + select { + case <-done: + case <-time.After(10 * time.Second): + t.Fatal("test deadlocked — Stream MaxRetries=0 with reject should not hang") + } + + var foundExhausted bool + for _, e := range events { + if e.Err != nil { + var exhaustedErr *RetryExhaustedError + if errors.As(e.Err, &exhaustedErr) { + foundExhausted = true + } + } + } + assert.True(t, foundExhausted, "MaxRetries=0 stream reject should produce RetryExhaustedError") +} + +func TestAttack_ShouldRetry_Stream_RewriteErrorOnCleanStream(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("looks good but bad", nil)}), nil + }).Times(1) + + fatalErr := errors.New("fatal: content policy violation") + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "StreamRewriteCleanAgent", + Description: "Stream returns cleanly but ShouldRetry rewrites to error", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 2, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + return &RetryDecision{Retry: false, RewriteError: fatalErr} + }, + BackoffFunc: instantBackoff, + }, + }) + require.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + EnableStreaming: true, + } + + done := make(chan struct{}) + var events []agentEvent + go func() { + defer close(done) + iterator := agent.Run(ctx, input) + events, _ = drainStreamingAgentEvents(t, iterator) + }() + + select { + case <-done: + case <-time.After(10 * time.Second): + t.Fatal("test deadlocked") + } + + var foundFatal bool + for _, e := range events { + if e.Err != nil && errors.Is(e.Err, fatalErr) { + foundFatal = true + } + } + assert.True(t, foundFatal, "RewriteError on clean stream should propagate the fatal error") +} + +func TestAttack_ShouldRetry_ConcatMessagesFails_EmptyStream(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + r, w := schema.Pipe[*schema.Message](1) + w.Close() + return r, nil + }).Times(1) + + var capturedCtx *RetryContext + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "EmptyStreamAgent", + Description: "Stream returns zero chunks — both OutputMessage and Err should be nil", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 1, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + capturedCtx = retryCtx + return &RetryDecision{Retry: false} + }, + BackoffFunc: instantBackoff, + }, + }) + require.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + EnableStreaming: true, + } + + done := make(chan struct{}) + go func() { + defer close(done) + iterator := agent.Run(ctx, input) + drainStreamingAgentEvents(t, iterator) + }() + + select { + case <-done: + case <-time.After(10 * time.Second): + t.Fatal("test deadlocked on empty stream") + } + + require.NotNil(t, capturedCtx) + assert.Nil(t, capturedCtx.OutputMessage, "empty stream should have nil OutputMessage") + assert.Nil(t, capturedCtx.Err, "empty stream should have nil Err") +} + +func TestAttack_ShouldRetry_Stream_MidStreamError_VerdictDoubleRead(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + midStreamErr := errors.New("mid-stream transient error") + + cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + r, w := schema.Pipe[*schema.Message](1) + go func() { + defer w.Close() + _ = w.Send(schema.AssistantMessage("chunk1", nil), nil) + _ = w.Send(nil, midStreamErr) + }() + return r, nil + }).Times(2) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "DoubleReadBugAgent", + Description: "Reproduces signal.ch double-read when event stream hits mid-stream error then EOF", + Instruction: "You are a helpful assistant.", + Model: cm, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 1, + ShouldRetry: func(ctx context.Context, retryCtx *RetryContext) *RetryDecision { + if retryCtx.Err != nil { + return &RetryDecision{Retry: true} + } + return &RetryDecision{Retry: false} + }, + BackoffFunc: instantBackoff, + }, + }) + require.NoError(t, err) + + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello")}, + EnableStreaming: true, + } + + done := make(chan struct{}) + go func() { + defer close(done) + iterator := agent.Run(ctx, input) + for { + event, ok := iterator.Next() + if !ok { + break + } + if event.Output != nil && event.Output.MessageOutput != nil { + mo := event.Output.MessageOutput + if mo.IsStreaming && mo.MessageStream != nil { + for { + _, recvErr := mo.MessageStream.Recv() + if recvErr == io.EOF { + break + } + } + } + } + } + }() + + select { + case <-done: + case <-time.After(10 * time.Second): + t.Fatal("goroutine leak: onEOF blocked on signal.ch after errWrapper already drained the verdict") + } +} diff --git a/adk/failover_chatmodel.go b/adk/failover_chatmodel.go index 2a467ed76..898aedd7c 100644 --- a/adk/failover_chatmodel.go +++ b/adk/failover_chatmodel.go @@ -229,6 +229,12 @@ func (f *failoverModelWrapper) needFailover(ctx context.Context, outputMessage * return false } + // ErrStreamCanceled means the caller voluntarily abandoned the stream; + // never retry or fail over in this case. + if errors.Is(outputErr, ErrStreamCanceled) { + return false + } + // ShouldFailover is validated at agent construction; nil here indicates a programmer error. return f.config.ShouldFailover(ctx, outputMessage, outputErr) } diff --git a/adk/retry_chatmodel.go b/adk/retry_chatmodel.go index bac955033..304e8b9b3 100644 --- a/adk/retry_chatmodel.go +++ b/adk/retry_chatmodel.go @@ -21,7 +21,6 @@ import ( "errors" "fmt" "io" - "log" "math/rand" "time" @@ -79,7 +78,10 @@ func (e *RetryExhaustedError) Unwrap() error { type WillRetryError struct { ErrStr string RetryAttempt int - err error + // OutputMessage is the model's output message from the attempt that triggered the retry, if any. + // May be nil if the model returned an error without producing a message. + OutputMessage *schema.Message + err error } func (e *WillRetryError) Error() string { @@ -94,6 +96,102 @@ func init() { schema.RegisterName[*WillRetryError]("eino_adk_chatmodel_will_retry_error") } +// RetryContext contains context information passed to ModelRetryConfig.ShouldRetry +// during a retry decision. +// +// State combinations for OutputMessage and Err: +// +// OutputMessage != nil, Err == nil → successful call; inspect message quality +// OutputMessage == nil, Err != nil → failed call (Generate error or Stream() error) +// OutputMessage != nil, Err != nil → partial stream (chunks received before mid-stream error) +// OutputMessage == nil, Err == nil → empty stream (zero chunks before EOF) +type RetryContext struct { + // RetryAttempt is the current retry attempt number (1-based). + // For the first retry decision (after the initial call), this is 1. + RetryAttempt int + + // InputMessages is the input messages that were sent to the model for the current attempt. + InputMessages []*schema.Message + + // Options is the model options that were used for the current attempt. + Options []model.Option + + // OutputMessage is the output message from the model, if any. + // This is non-nil when the model returned a message successfully. + // For streaming, this is the fully concatenated message (the entire stream is consumed + // before ShouldRetry is called). + // For streaming with mid-stream errors, this is the partial concatenation of chunks + // received before the error occurred. + // May be nil if the model returned an error without producing a message, or if the + // stream was empty (zero chunks before EOF). + OutputMessage *schema.Message + + // Err is the error from the model call, if any. + // May be nil if the model produced a message without error. + // Note: both OutputMessage and Err can be nil simultaneously for empty streams. + Err error +} + +// RetryDecision represents the decision made by ModelRetryConfig.ShouldRetry. +type RetryDecision struct { + // Retry indicates whether the model call should be retried. + // If false, the model output (or error) is accepted as-is, unless RewriteError is set. + Retry bool + + // RewriteError, when non-nil, overrides the return value of the model call with this error. + // The agent run will fail with this error. + // + // This is useful for two scenarios: + // - When the model returns a "seemingly correct" message (no error) that actually + // contains unrecoverable issues. RewriteError converts the successful output + // into a fatal error. + // - When the model returns an error, but you want to replace it with a different, + // more descriptive error (e.g., adding context or wrapping). + // + // When Retry is true, RewriteError is ignored. + // When Retry is false and RewriteError is non-nil, the model call returns + // RewriteError regardless of whether the original call had an error or a message. + RewriteError error + + // ModifiedInputMessages, when non-nil, replaces the input messages for the next retry. + // + // This enables advanced recovery strategies like context compression or message trimming. + // Only used when Retry is true. Ignored when Retry is false. + ModifiedInputMessages []*schema.Message + + // PersistModifiedInputMessages controls whether ModifiedInputMessages are written + // back to the agent's conversation history, affecting subsequent model calls in + // the agent loop (not just the next retry attempt). + // + // When true, the modified messages replace the current conversation history. + // When false (default), the modified messages are only used for the next retry attempt + // within this retry cycle. + // + // Only used when Retry is true and ModifiedInputMessages is non-nil. + PersistModifiedInputMessages bool + + // AdditionalOptions, when non-nil, provides additional model options for the next retry. + // These options are appended to the existing options, taking precedence via last-wins semantics. + // + // This enables adjustments like increasing MaxTokens for the retry attempt. + // Note: options accumulate across retries within a single retry cycle. If ShouldRetry + // returns AdditionalOptions on every attempt, each set is appended to the previous ones. + // Only the last value for each option key takes effect, but earlier values remain in the slice. + // AdditionalOptions are scoped to the current retry cycle and do not persist to subsequent + // agent iterations — each new model call in the agent loop starts with the original options. + // Only used when Retry is true. Ignored when Retry is false. + AdditionalOptions []model.Option + + // Backoff specifies the duration to wait before the next retry attempt. + // If zero, the default backoff function (from ModelRetryConfig.BackoffFunc or the + // built-in exponential backoff) is used. + // + // This allows the ShouldRetry callback to dynamically control retry timing based on + // the specific error or problematic message encountered. + // Only used when Retry is true. Ignored when Retry is false. + Backoff time.Duration +} + // ModelRetryConfig configures retry behavior for the ChatModel node. // It defines how the agent should handle transient failures when calling the ChatModel. type ModelRetryConfig struct { @@ -102,14 +200,27 @@ type ModelRetryConfig struct { // A value of 3 means up to 3 retry attempts (4 total calls including the initial attempt). MaxRetries int - // IsRetryAble is a function that determines whether an error should trigger a retry. - // If nil, all errors are considered retry-able. - // Return true if the error is transient and the operation should be retried. - // Return false if the error is permanent and should be propagated immediately. + // ShouldRetry determines how to handle a model call result. + // It receives context information about the current attempt including the output message + // and/or error, and returns a decision on whether to retry, what to modify, etc. + // Returning nil is treated as &RetryDecision{Retry: false} (accept as-is). + // + // If nil, defaults to retrying on any non-nil error (backward compatible with IsRetryAble). + // + // Note: When ShouldRetry is set, IsRetryAble is ignored. + // Note: In streaming mode, the entire stream is consumed before ShouldRetry is called. + // The event stream is sent to the client in real time regardless; only the retry + // decision is deferred until the full response is available. + ShouldRetry func(ctx context.Context, retryCtx *RetryContext) *RetryDecision + + // Deprecated: Use ShouldRetry instead for richer retry control including message + // inspection, input modification, and option adjustment. When ShouldRetry is set, + // IsRetryAble is ignored. IsRetryAble func(ctx context.Context, err error) bool // BackoffFunc calculates the delay before the next retry attempt. // The attempt parameter starts at 1 for the first retry. + // Used as the default when RetryDecision.Backoff is zero. // If nil, a default exponential backoff with jitter is used: // base delay 100ms, exponentially increasing up to 10s max, // with random jitter (0-50% of delay) to prevent thundering herd. @@ -166,6 +277,17 @@ func consumeStreamForError(stream *schema.StreamReader[*schema.Message]) error { } } +type retryVerdictSignal struct { + ch chan retryVerdict +} + +type retryVerdict struct { + WillRetry bool + RetryAttempt int + Err error + OutputMessage *schema.Message +} + // retryModelWrapper wraps a BaseChatModel with retry logic. // This is used inside the model wrapper chain, positioned between eventSenderModelWrapper // and stateModelWrapper, so that retry only affects the inner chain (event sending, user wrappers, @@ -180,6 +302,13 @@ func newRetryModelWrapper(inner model.BaseChatModel, config *ModelRetryConfig) * } func (r *retryModelWrapper) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + if r.config.ShouldRetry != nil { + return r.generateWithShouldRetry(ctx, input, opts...) + } + return r.generateLegacy(ctx, input, opts...) +} + +func (r *retryModelWrapper) generateLegacy(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { isRetryAble := r.config.IsRetryAble if isRetryAble == nil { isRetryAble = defaultIsRetryAble @@ -201,23 +330,359 @@ func (r *retryModelWrapper) Generate(ctx context.Context, input []*schema.Messag return nil, err } + if errors.Is(err, ErrStreamCanceled) { + return nil, err + } + if !isRetryAble(ctx, err) { return nil, err } lastErr = err if attempt < r.config.MaxRetries { - log.Printf("retrying ChatModel.Generate (attempt %d/%d): %v", attempt+1, r.config.MaxRetries, err) - time.Sleep(backoffFunc(ctx, attempt+1)) + if err := r.contextAwareSleep(ctx, backoffFunc(ctx, attempt+1)); err != nil { + return nil, err + } + } + } + + return nil, &RetryExhaustedError{LastErr: lastErr, TotalRetries: r.config.MaxRetries} +} + +func (r *retryModelWrapper) generateWithShouldRetry(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + backoffFunc := r.config.BackoffFunc + if backoffFunc == nil { + backoffFunc = defaultBackoff + } + + execCtx := getChatModelAgentExecCtx(ctx) + + currentInput := input + currentOpts := opts + var lastErr error + + defer func() { + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + st.setRetryAttempt(0) + return nil + }) + }() + + for attempt := 0; attempt <= r.config.MaxRetries; attempt++ { + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + st.setRetryAttempt(attempt) + return nil + }) + + // Suppress event sending during Generate: the ShouldRetry callback must decide whether + // to accept or reject the result before any event is emitted. If accepted, the event + // is sent explicitly below (lines after decision check). If rejected, no event leaks. + if execCtx != nil { + execCtx.suppressEventSend = true + } + out, err := r.inner.Generate(ctx, currentInput, currentOpts...) + if execCtx != nil { + execCtx.suppressEventSend = false + } + + if err != nil { + if _, ok := compose.ExtractInterruptInfo(err); ok { + return nil, err + } + + if errors.Is(err, ErrStreamCanceled) { + return nil, err + } + } + + retryCtx := &RetryContext{ + RetryAttempt: attempt + 1, + InputMessages: currentInput, + Options: currentOpts, + OutputMessage: out, + Err: err, + } + decision := r.config.ShouldRetry(ctx, retryCtx) + if decision == nil { + decision = &RetryDecision{} + } + + if !decision.Retry { + if decision.RewriteError != nil { + return nil, decision.RewriteError + } + if err != nil { + return nil, err + } + if execCtx != nil && execCtx.generator != nil && out != nil { + msgCopy := *out + event := EventFromMessage(&msgCopy, nil, schema.Assistant, "") + execCtx.send(event) + } + return out, nil + } + + lastErr = err + if lastErr == nil { + lastErr = fmt.Errorf("model output rejected by ShouldRetry at attempt %d", attempt+1) + } + + if attempt >= r.config.MaxRetries { + break + } + + r.applyDecisionForRetry(¤tInput, ¤tOpts, ctx, decision) + + delay := decision.Backoff + if delay == 0 { + delay = backoffFunc(ctx, attempt+1) + } + + if err := r.contextAwareSleep(ctx, delay); err != nil { + return nil, err } } return nil, &RetryExhaustedError{LastErr: lastErr, TotalRetries: r.config.MaxRetries} } +func (r *retryModelWrapper) contextAwareSleep(ctx context.Context, delay time.Duration) error { + if delay <= 0 { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(delay): + return nil + } +} + +func consumeStreamForMessage(stream *schema.StreamReader[*schema.Message]) (*schema.Message, error) { + defer stream.Close() + var chunks []*schema.Message + for { + chunk, err := stream.Recv() + if err == io.EOF { + if len(chunks) == 0 { + return nil, nil + } + msg, concatErr := schema.ConcatMessages(chunks) + return msg, concatErr + } + if err != nil { + if len(chunks) == 0 { + return nil, err + } + msg, _ := schema.ConcatMessages(chunks) + return msg, err + } + chunks = append(chunks, chunk) + } +} + +func (r *retryModelWrapper) streamWithShouldRetry(ctx context.Context, input []*schema.Message, opts ...model.Option) ( + *schema.StreamReader[*schema.Message], error) { + + backoffFunc := r.config.BackoffFunc + if backoffFunc == nil { + backoffFunc = defaultBackoff + } + + defer func() { + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + st.setRetryAttempt(0) + return nil + }) + }() + + execCtx := getChatModelAgentExecCtx(ctx) + + currentInput := input + currentOpts := opts + var lastErr error + var curSignal *retryVerdictSignal + + // Panic recovery for verdict signal: if ShouldRetry panics, the onEOF/errWrapper closures in + // buildStreamConvertOptions will block forever on signal.ch, causing a goroutine leak. This + // defer ensures a verdict is always sent, even on panic, before re-panicking. + defer func() { + if p := recover(); p != nil { + if curSignal != nil { + select { + case curSignal.ch <- retryVerdict{WillRetry: false, Err: fmt.Errorf("panic: %v", p)}: + default: + } + } + panic(p) + } + }() + + for attempt := 0; attempt <= r.config.MaxRetries; attempt++ { + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + st.setRetryAttempt(attempt) + return nil + }) + + signal := &retryVerdictSignal{ch: make(chan retryVerdict, 1)} + curSignal = signal + if execCtx != nil { + execCtx.retryVerdictSignal = signal + } + + stream, err := r.inner.Stream(ctx, currentInput, currentOpts...) + if err != nil { + // Defensive no-op: when Stream() returns an error, no stream exists, so + // eventSenderModel never creates the StreamReaderWithConvert hooks that would + // read from signal.ch. This send has no consumer — it merely fills the + // buffered(1) slot so the panic-recovery defer (select/default) won't block + // if a later panic tries to send a second verdict. The signal is discarded + // when the next iteration creates a new one. + signal.ch <- retryVerdict{WillRetry: false} + + if _, ok := compose.ExtractInterruptInfo(err); ok { + return nil, err + } + + if errors.Is(err, ErrStreamCanceled) { + return nil, err + } + + retryCtx := &RetryContext{ + RetryAttempt: attempt + 1, + InputMessages: currentInput, + Options: currentOpts, + Err: err, + } + decision := r.config.ShouldRetry(ctx, retryCtx) + if decision == nil { + decision = &RetryDecision{} + } + + if !decision.Retry { + if decision.RewriteError != nil { + return nil, decision.RewriteError + } + return nil, err + } + + lastErr = err + if attempt < r.config.MaxRetries { + r.applyDecisionForRetry(¤tInput, ¤tOpts, ctx, decision) + delay := decision.Backoff + if delay == 0 { + delay = backoffFunc(ctx, attempt+1) + } + if err := r.contextAwareSleep(ctx, delay); err != nil { + return nil, err + } + } + continue + } + + // Split the stream: checkCopy is consumed synchronously here to build the complete + // message for ShouldRetry inspection; returnCopy is returned to the caller and may + // already be consumed downstream in parallel. The verdict signal bridges the two: + // once ShouldRetry decides, the signal tells returnCopy's errWrapper/onEOF whether + // to pass through normally or inject a WillRetryError. + copies := stream.Copy(2) + checkCopy := copies[0] + returnCopy := copies[1] + + msg, streamErr := consumeStreamForMessage(checkCopy) + + if errors.Is(streamErr, ErrStreamCanceled) { + signal.ch <- retryVerdict{WillRetry: false} + returnCopy.Close() + return nil, streamErr + } + + retryCtx := &RetryContext{ + RetryAttempt: attempt + 1, + InputMessages: currentInput, + Options: currentOpts, + OutputMessage: msg, + Err: streamErr, + } + decision := r.config.ShouldRetry(ctx, retryCtx) + if decision == nil { + decision = &RetryDecision{} + } + + if !decision.Retry { + signal.ch <- retryVerdict{WillRetry: false} + + if decision.RewriteError != nil { + returnCopy.Close() + return nil, decision.RewriteError + } + if streamErr != nil { + returnCopy.Close() + return nil, streamErr + } + return returnCopy, nil + } + + verdictErr := streamErr + if verdictErr == nil { + verdictErr = fmt.Errorf("model output rejected by ShouldRetry at attempt %d", attempt+1) + } + signal.ch <- retryVerdict{ + WillRetry: true, + RetryAttempt: attempt, + Err: verdictErr, + OutputMessage: msg, + } + returnCopy.Close() + + lastErr = verdictErr + + if attempt < r.config.MaxRetries { + r.applyDecisionForRetry(¤tInput, ¤tOpts, ctx, decision) + delay := decision.Backoff + if delay == 0 { + delay = backoffFunc(ctx, attempt+1) + } + if err := r.contextAwareSleep(ctx, delay); err != nil { + return nil, err + } + } + } + + return nil, &RetryExhaustedError{LastErr: lastErr, TotalRetries: r.config.MaxRetries} +} + +func (r *retryModelWrapper) applyDecisionForRetry(currentInput *[]*schema.Message, currentOpts *[]model.Option, ctx context.Context, decision *RetryDecision) { + if decision.ModifiedInputMessages != nil { + *currentInput = decision.ModifiedInputMessages + if decision.PersistModifiedInputMessages { + modifiedInput := *currentInput + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + st.Messages = modifiedInput + return nil + }) + } + } + + if decision.AdditionalOptions != nil { + cloned := make([]model.Option, len(*currentOpts), len(*currentOpts)+len(decision.AdditionalOptions)) + copy(cloned, *currentOpts) + *currentOpts = append(cloned, decision.AdditionalOptions...) + } +} + func (r *retryModelWrapper) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) ( *schema.StreamReader[*schema.Message], error) { + if r.config.ShouldRetry != nil { + return r.streamWithShouldRetry(ctx, input, opts...) + } + return r.streamLegacy(ctx, input, opts...) +} + +func (r *retryModelWrapper) streamLegacy(ctx context.Context, input []*schema.Message, opts ...model.Option) ( + *schema.StreamReader[*schema.Message], error) { + isRetryAble := r.config.IsRetryAble if isRetryAble == nil { isRetryAble = defaultIsRetryAble @@ -243,17 +708,20 @@ func (r *retryModelWrapper) Stream(ctx context.Context, input []*schema.Message, stream, err := r.inner.Stream(ctx, input, opts...) if err != nil { - // Never retry interrupt errors (e.g. cancel safe-point interrupts). if _, ok := compose.ExtractInterruptInfo(err); ok { return nil, err } + if errors.Is(err, ErrStreamCanceled) { + return nil, err + } if !isRetryAble(ctx, err) { return nil, err } lastErr = err if attempt < r.config.MaxRetries { - log.Printf("retrying ChatModel.Stream (attempt %d/%d): %v", attempt+1, r.config.MaxRetries, err) - time.Sleep(backoffFunc(ctx, attempt+1)) + if err := r.contextAwareSleep(ctx, backoffFunc(ctx, attempt+1)); err != nil { + return nil, err + } } continue } @@ -268,14 +736,18 @@ func (r *retryModelWrapper) Stream(ctx context.Context, input []*schema.Message, } returnCopy.Close() + if errors.Is(streamErr, ErrStreamCanceled) { + return nil, streamErr + } if !isRetryAble(ctx, streamErr) { return nil, streamErr } lastErr = streamErr if attempt < r.config.MaxRetries { - log.Printf("retrying ChatModel.Stream (attempt %d/%d): %v", attempt+1, r.config.MaxRetries, streamErr) - time.Sleep(backoffFunc(ctx, attempt+1)) + if err := r.contextAwareSleep(ctx, backoffFunc(ctx, attempt+1)); err != nil { + return nil, err + } } } diff --git a/adk/wrappers.go b/adk/wrappers.go index b4e16d298..ce50d7baa 100644 --- a/adk/wrappers.go +++ b/adk/wrappers.go @@ -19,7 +19,9 @@ package adk import ( "context" "errors" + "io" "reflect" + "sync" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components" @@ -292,6 +294,9 @@ func (m *eventSenderModel) Generate(ctx context.Context, input []*schema.Message } execCtx := getChatModelAgentExecCtx(ctx) + if execCtx != nil && execCtx.suppressEventSend { + return result, nil + } if execCtx == nil || execCtx.generator == nil { return nil, errors.New("generator is nil when sending event in Generate: ensure agent state is properly initialized") } @@ -318,10 +323,7 @@ func (m *eventSenderModel) Stream(ctx context.Context, input []*schema.Message, streams := result.Copy(2) eventStream := streams[0] - if errWrapper := m.buildErrWrapper(ctx); errWrapper != nil { - convertOpts := []schema.ConvertOption{ - schema.WithErrWrapper(errWrapper), - } + if convertOpts := m.buildStreamConvertOptions(ctx); len(convertOpts) > 0 { eventStream = schema.StreamReaderWithConvert(streams[0], func(msg *schema.Message) (*schema.Message, error) { return msg, nil }, convertOpts...) @@ -333,20 +335,94 @@ func (m *eventSenderModel) Stream(ctx context.Context, input []*schema.Message, return streams[1], nil } -// buildErrWrapper constructs an error wrapper function for event streams. -// It wraps stream errors as WillRetryError when retry or failover is configured, -// so that flow.go:genAgentInput() can skip events from failed attempts instead of -// treating them as fatal errors. -func (m *eventSenderModel) buildErrWrapper(ctx context.Context) func(error) error { +// buildStreamConvertOptions constructs ConvertOption hooks that gate stream termination behind +// the retry verdict signal protocol. +// +// Verdict signal lifecycle: +// - streamWithShouldRetry creates a new retryVerdictSignal per retry attempt, stores it in +// execCtx.retryVerdictSignal, and sends exactly one retryVerdict after ShouldRetry decides. +// - The closures below capture a *retryVerdictSignal that is nil at closure-creation time; they +// read the live value from execCtx.retryVerdictSignal, which is set before each model call. +// +// Two hooks cooperate to cover all stream termination paths: +// - WithErrWrapper intercepts mid-stream errors. It blocks on the verdict to decide +// whether to wrap the error as WillRetryError (rejected attempt) or pass it through (accepted). +// - WithOnEOF intercepts clean EOF (successful stream). It blocks on the verdict to +// either inject a WillRetryError (rejected) or pass through io.EOF (accepted). +// +// Both hooks share a sync.Once-guarded reader so the verdict channel is read at most once. +// This prevents a goroutine leak when a mid-stream error is followed by EOF: errWrapper fires +// first (caching the verdict), and onEOF reuses the cached value instead of blocking on a +// drained channel. +func (m *eventSenderModel) buildStreamConvertOptions(ctx context.Context) []schema.ConvertOption { var retryAttempt int _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { retryAttempt = st.getRetryAttempt() return nil }) + wrapWithCancelGuard := func(inner func(error) error) func(error) error { + return func(err error) error { + if errors.Is(err, ErrStreamCanceled) { + return err + } + return inner(err) + } + } + + var opts []schema.ConvertOption + var retryWrapper func(error) error if m.modelRetryConfig != nil { - retryWrapper = genErrWrapper(ctx, m.modelRetryConfig.MaxRetries, retryAttempt, m.modelRetryConfig.IsRetryAble) + if m.modelRetryConfig.ShouldRetry != nil { + execCtx := getChatModelAgentExecCtx(ctx) + signal := (*retryVerdictSignal)(nil) + if execCtx != nil { + signal = execCtx.retryVerdictSignal + } + if signal != nil { + var ( + verdictOnce sync.Once + cachedVerdict retryVerdict + ) + readVerdict := func() retryVerdict { + verdictOnce.Do(func() { + cachedVerdict = <-signal.ch + }) + return cachedVerdict + } + + retryWrapper = wrapWithCancelGuard(func(err error) error { + verdict := readVerdict() + if verdict.WillRetry { + return &WillRetryError{ + ErrStr: err.Error(), + RetryAttempt: verdict.RetryAttempt, + OutputMessage: verdict.OutputMessage, + err: err, + } + } + return err + }) + + opts = append(opts, schema.WithOnEOF(func() (any, error) { + verdict := readVerdict() + if verdict.WillRetry { + return nil, &WillRetryError{ + ErrStr: verdict.Err.Error(), + RetryAttempt: verdict.RetryAttempt, + OutputMessage: verdict.OutputMessage, + err: verdict.Err, + } + } + return nil, io.EOF + })) + } + } else { + retryWrapper = wrapWithCancelGuard( + genErrWrapper(ctx, m.modelRetryConfig.MaxRetries, retryAttempt, m.modelRetryConfig.IsRetryAble), + ) + } } hasFailover := m.modelFailoverConfig != nil @@ -357,10 +433,10 @@ func (m *eventSenderModel) buildErrWrapper(ctx context.Context) func(error) erro failoverHasMore := getFailoverHasMoreAttempts(ctx) if retryWrapper == nil && !(hasFailover && failoverHasMore) { - return nil + return opts } - return func(err error) error { + combinedErrWrapper := func(err error) error { // If retry is configured and will retry this error, use the retry wrapper's WillRetryError. if retryWrapper != nil { wrapped := retryWrapper(err) @@ -372,10 +448,16 @@ func (m *eventSenderModel) buildErrWrapper(ctx context.Context) func(error) erro // failover still has more attempts remaining. Wrap it as WillRetryError so // the flow layer skips this event from the failed attempt. if hasFailover && failoverHasMore { + if errors.Is(err, ErrStreamCanceled) { + return err + } return &WillRetryError{ErrStr: err.Error(), err: err} } return err } + opts = append(opts, schema.WithErrWrapper(combinedErrWrapper)) + + return opts } func popToolGenAction(ctx context.Context, toolName string) *AgentAction { @@ -753,6 +835,17 @@ func (w *stateModelWrapper) Generate(ctx context.Context, input []*schema.Messag if err != nil { return nil, err } + + // Re-read State.Messages after Generate completes: when ShouldRetry uses + // PersistModifiedInputMessages, applyDecisionForRetry writes modified messages to State. + // We must pick up those changes before appending the model result. + if w.modelRetryConfig != nil && w.modelRetryConfig.ShouldRetry != nil { + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + state.Messages = st.Messages + return nil + }) + } + state.Messages = append(state.Messages, result) for _, handler := range w.handlers { @@ -823,6 +916,15 @@ func (w *stateModelWrapper) Stream(ctx context.Context, input []*schema.Message, if err != nil { return nil, err } + + // Re-read State.Messages after Stream completes: same rationale as in Generate above. + if w.modelRetryConfig != nil && w.modelRetryConfig.ShouldRetry != nil { + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + state.Messages = st.Messages + return nil + }) + } + state.Messages = append(state.Messages, result) for _, handler := range w.handlers { diff --git a/adk/wrappers_retry_failover_test.go b/adk/wrappers_retry_failover_test.go index 98db172e9..29c4b495a 100644 --- a/adk/wrappers_retry_failover_test.go +++ b/adk/wrappers_retry_failover_test.go @@ -29,383 +29,585 @@ import ( "github.com/cloudwego/eino/schema" ) -// TestRetryThenFailover_Generate_RetryExhaustedTriggersFailover tests the combined -// retry + failover path for Generate: m1 always fails, retry exhausted, failover to m2 which succeeds. -func TestRetryThenFailover_Generate_RetryExhaustedTriggersFailover(t *testing.T) { - modelErr := errors.New("model error") - var m1Calls int32 - var m2Calls int32 - - m1 := &fakeChatModel{ - callbacksEnabled: true, - generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { - atomic.AddInt32(&m1Calls, 1) - return nil, modelErr - }, - stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { +func newFakeChatModel( + gen func(context.Context, []*schema.Message, ...model.Option) (*schema.Message, error), + stream func(context.Context, []*schema.Message, ...model.Option) (*schema.StreamReader[*schema.Message], error), +) *fakeChatModel { + if gen == nil { + gen = func(context.Context, []*schema.Message, ...model.Option) (*schema.Message, error) { return nil, errors.New("unused") - }, + } } - m2 := &fakeChatModel{ - callbacksEnabled: true, - generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { - atomic.AddInt32(&m2Calls, 1) - return schema.AssistantMessage("ok from m2", nil), nil - }, - stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + if stream == nil { + stream = func(context.Context, []*schema.Message, ...model.Option) (*schema.StreamReader[*schema.Message], error) { return nil, errors.New("unused") - }, + } } - - retryCfg := &ModelRetryConfig{ - MaxRetries: 2, - IsRetryAble: func(_ context.Context, err error) bool { return true }, - BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, - } - - failoverCfg := &ModelFailoverConfig{ - MaxRetries: 1, - ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { - return err != nil - }, - GetFailoverModel: func(_ context.Context, fc *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { - require.NotNil(t, fc.LastErr) - return m2, nil, nil - }, - } - - wrapped := buildModelWrappers(m1, &modelWrapperConfig{ - retryConfig: retryCfg, - failoverConfig: failoverCfg, - }) - - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ - failoverLastSuccessModel: m1, - }) - msg, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) - require.NoError(t, err) - require.Equal(t, "ok from m2", msg.Content) - - // m1: 1 (lastSuccess) + 2 retries = 3 calls on lastSuccess attempt, - // then failover to m2 which also goes through retry wrapper: 1 call succeeds. - require.Equal(t, int32(3), atomic.LoadInt32(&m1Calls)) - require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) + return &fakeChatModel{callbacksEnabled: true, generate: gen, stream: stream} } -// TestRetryThenFailover_Generate_AllExhausted tests: m1 retry exhausted → failover to m2 → m2 retry exhausted → final error. -func TestRetryThenFailover_Generate_AllExhausted(t *testing.T) { - modelErr := errors.New("always fails") - var m1Calls int32 - var m2Calls int32 +func TestRetryThenFailover(t *testing.T) { + t.Run("Generate_RetryExhaustedTriggersFailover", func(t *testing.T) { + modelErr := errors.New("model error") + var m1Calls int32 + var m2Calls int32 - m1 := &fakeChatModel{ - callbacksEnabled: true, - generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + m1 := newFakeChatModel(func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { atomic.AddInt32(&m1Calls, 1) return nil, modelErr - }, - stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { - return nil, errors.New("unused") - }, - } - m2 := &fakeChatModel{ - callbacksEnabled: true, - generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + }, nil) + m2 := newFakeChatModel(func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { atomic.AddInt32(&m2Calls, 1) - return nil, modelErr - }, - stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { - return nil, errors.New("unused") - }, - } - - retryCfg := &ModelRetryConfig{ - MaxRetries: 1, - IsRetryAble: func(_ context.Context, err error) bool { return true }, - BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, - } - - failoverCfg := &ModelFailoverConfig{ - MaxRetries: 1, - ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { - return err != nil - }, - GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { - return m2, nil, nil - }, - } - - wrapped := buildModelWrappers(m1, &modelWrapperConfig{ - retryConfig: retryCfg, - failoverConfig: failoverCfg, - }) - - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ - failoverLastSuccessModel: m1, + return schema.AssistantMessage("ok from m2", nil), nil + }, nil) + + retryCfg := &ModelRetryConfig{ + MaxRetries: 2, + IsRetryAble: func(_ context.Context, err error) bool { return true }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + return err != nil + }, + GetFailoverModel: func(_ context.Context, fc *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.NotNil(t, fc.LastErr) + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + msg, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.Equal(t, "ok from m2", msg.Content) + + // m1: 1 (lastSuccess) + 2 retries = 3 calls on lastSuccess attempt, + // then failover to m2 which also goes through retry wrapper: 1 call succeeds. + require.Equal(t, int32(3), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) }) - _, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) - require.Error(t, err) - // Should be RetryExhaustedError from m2's retry wrapper - var retryErr *RetryExhaustedError - require.True(t, errors.As(err, &retryErr)) + t.Run("Generate_AllExhausted", func(t *testing.T) { + modelErr := errors.New("always fails") + var m1Calls int32 + var m2Calls int32 - // m1: 1 initial + 1 retry = 2 calls - require.Equal(t, int32(2), atomic.LoadInt32(&m1Calls)) - // m2: 1 initial + 1 retry = 2 calls - require.Equal(t, int32(2), atomic.LoadInt32(&m2Calls)) -} - -// TestRetryThenFailover_Stream_RetryExhaustedTriggersFailover tests stream path: -// m1 stream always errors mid-way, retry exhausted, failover to m2 which succeeds. -func TestRetryThenFailover_Stream_RetryExhaustedTriggersFailover(t *testing.T) { - streamErr := errors.New("stream mid error") - var m1Calls int32 - var m2Calls int32 - - m1 := &fakeChatModel{ - callbacksEnabled: true, - generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { - return nil, errors.New("unused") - }, - stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + m1 := newFakeChatModel(func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { atomic.AddInt32(&m1Calls, 1) - return streamWithMidError([]*schema.Message{ - schema.AssistantMessage("partial", nil), - }, streamErr), nil - }, - } - m2 := &fakeChatModel{ - callbacksEnabled: true, - generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { - return nil, errors.New("unused") - }, - stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, modelErr + }, nil) + m2 := newFakeChatModel(func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { atomic.AddInt32(&m2Calls, 1) - return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("ok from m2", nil)}), nil - }, - } - - retryCfg := &ModelRetryConfig{ - MaxRetries: 1, - IsRetryAble: func(_ context.Context, err error) bool { return true }, - BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, - } - - failoverCfg := &ModelFailoverConfig{ - MaxRetries: 1, - ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { - return err != nil - }, - GetFailoverModel: func(_ context.Context, fc *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { - require.NotNil(t, fc.LastErr) - return m2, nil, nil - }, - } - - wrapped := buildModelWrappers(m1, &modelWrapperConfig{ - retryConfig: retryCfg, - failoverConfig: failoverCfg, - }) - - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ - failoverLastSuccessModel: m1, + return nil, modelErr + }, nil) + + retryCfg := &ModelRetryConfig{ + MaxRetries: 1, + IsRetryAble: func(_ context.Context, err error) bool { return true }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + return err != nil + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + _, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.Error(t, err) + + // Should be RetryExhaustedError from m2's retry wrapper + var retryErr *RetryExhaustedError + require.True(t, errors.As(err, &retryErr)) + + // m1: 1 initial + 1 retry = 2 calls + require.Equal(t, int32(2), atomic.LoadInt32(&m1Calls)) + // m2: 1 initial + 1 retry = 2 calls + require.Equal(t, int32(2), atomic.LoadInt32(&m2Calls)) }) - sr, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) - require.NoError(t, err) - msgs, err := drainMessageStream(sr) - require.NoError(t, err) - require.Len(t, msgs, 1) - require.Equal(t, "ok from m2", msgs[0].Content) - - // m1: 1 initial + 1 retry = 2 calls on lastSuccess attempt - require.Equal(t, int32(2), atomic.LoadInt32(&m1Calls)) - require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) -} -// TestRetryThenFailover_Generate_RetrySucceedsNoFailover tests that when retry -// succeeds on the first model, failover is never triggered. -func TestRetryThenFailover_Generate_RetrySucceedsNoFailover(t *testing.T) { - var m1Calls int32 - var failoverCalled int32 + t.Run("Generate_RetrySucceedsNoFailover", func(t *testing.T) { + var m1Calls int32 + var failoverCalled int32 - m1 := &fakeChatModel{ - callbacksEnabled: true, - generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + m1 := newFakeChatModel(func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { n := atomic.AddInt32(&m1Calls, 1) if n == 1 { return nil, errors.New("transient error") } return schema.AssistantMessage("ok on retry", nil), nil - }, - stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { - return nil, errors.New("unused") - }, - } - - retryCfg := &ModelRetryConfig{ - MaxRetries: 2, - IsRetryAble: func(_ context.Context, err error) bool { return true }, - BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, - } - - failoverCfg := &ModelFailoverConfig{ - MaxRetries: 1, - ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { - atomic.AddInt32(&failoverCalled, 1) - return true - }, - GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { - t.Fatal("GetFailoverModel should not be called when retry succeeds") - return nil, nil, nil - }, - } - - wrapped := buildModelWrappers(m1, &modelWrapperConfig{ - retryConfig: retryCfg, - failoverConfig: failoverCfg, + }, nil) + + retryCfg := &ModelRetryConfig{ + MaxRetries: 2, + IsRetryAble: func(_ context.Context, err error) bool { return true }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + atomic.AddInt32(&failoverCalled, 1) + return true + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + t.Fatal("GetFailoverModel should not be called when retry succeeds") + return nil, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + msg, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.Equal(t, "ok on retry", msg.Content) + + // 2 calls: first fails, second succeeds via retry + require.Equal(t, int32(2), atomic.LoadInt32(&m1Calls)) + // ShouldFailover should never be called + require.Equal(t, int32(0), atomic.LoadInt32(&failoverCalled)) }) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ - failoverLastSuccessModel: m1, - }) - msg, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) - require.NoError(t, err) - require.Equal(t, "ok on retry", msg.Content) - - // 2 calls: first fails, second succeeds via retry - require.Equal(t, int32(2), atomic.LoadInt32(&m1Calls)) - // ShouldFailover should never be called - require.Equal(t, int32(0), atomic.LoadInt32(&failoverCalled)) -} + t.Run("Generate_NonRetryableErrorTriggersFailover", func(t *testing.T) { + nonRetryableErr := errors.New("non-retryable") + var m1Calls int32 + var m2Calls int32 -// TestRetryThenFailover_Generate_NonRetryableErrorTriggersFailover tests that a non-retryable -// error skips retry and directly triggers failover. -func TestRetryThenFailover_Generate_NonRetryableErrorTriggersFailover(t *testing.T) { - nonRetryableErr := errors.New("non-retryable") - var m1Calls int32 - var m2Calls int32 - - m1 := &fakeChatModel{ - callbacksEnabled: true, - generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + m1 := newFakeChatModel(func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { atomic.AddInt32(&m1Calls, 1) return nil, nonRetryableErr - }, - stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { - return nil, errors.New("unused") - }, - } - m2 := &fakeChatModel{ - callbacksEnabled: true, - generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + }, nil) + m2 := newFakeChatModel(func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { atomic.AddInt32(&m2Calls, 1) return schema.AssistantMessage("ok from m2", nil), nil - }, - stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { - return nil, errors.New("unused") - }, - } - - retryCfg := &ModelRetryConfig{ - MaxRetries: 3, - IsRetryAble: func(_ context.Context, err error) bool { - // Only non-retryable errors - return !errors.Is(err, nonRetryableErr) - }, - BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, - } - - failoverCfg := &ModelFailoverConfig{ - MaxRetries: 1, - ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { - return err != nil - }, - GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { - return m2, nil, nil - }, - } - - wrapped := buildModelWrappers(m1, &modelWrapperConfig{ - retryConfig: retryCfg, - failoverConfig: failoverCfg, + }, nil) + + retryCfg := &ModelRetryConfig{ + MaxRetries: 3, + IsRetryAble: func(_ context.Context, err error) bool { + // Only non-retryable errors + return !errors.Is(err, nonRetryableErr) + }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + return err != nil + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + msg, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.Equal(t, "ok from m2", msg.Content) + + // m1 called only once — non-retryable error skips retry + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) }) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ - failoverLastSuccessModel: m1, - }) - msg, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) - require.NoError(t, err) - require.Equal(t, "ok from m2", msg.Content) + t.Run("Stream_RetryExhaustedTriggersFailover", func(t *testing.T) { + streamErr := errors.New("stream mid error") + var m1Calls int32 + var m2Calls int32 - // m1 called only once — non-retryable error skips retry - require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) - require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) -} + m1 := newFakeChatModel(nil, func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return streamWithMidError([]*schema.Message{ + schema.AssistantMessage("partial", nil), + }, streamErr), nil + }) + m2 := newFakeChatModel(nil, func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m2Calls, 1) + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("ok from m2", nil)}), nil + }) + + retryCfg := &ModelRetryConfig{ + MaxRetries: 1, + IsRetryAble: func(_ context.Context, err error) bool { return true }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + return err != nil + }, + GetFailoverModel: func(_ context.Context, fc *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.NotNil(t, fc.LastErr) + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + msgs, err := drainMessageStream(sr) + require.NoError(t, err) + require.Len(t, msgs, 1) + require.Equal(t, "ok from m2", msgs[0].Content) + + // m1: 1 initial + 1 retry = 2 calls on lastSuccess attempt + require.Equal(t, int32(2), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) + }) -// TestRetryThenFailover_Stream_AllExhausted tests stream path when both retry and failover are exhausted. -func TestRetryThenFailover_Stream_AllExhausted(t *testing.T) { - streamErr := errors.New("always fails mid-stream") - var m1Calls int32 - var m2Calls int32 + t.Run("Stream_AllExhausted", func(t *testing.T) { + streamErr := errors.New("always fails mid-stream") + var m1Calls int32 + var m2Calls int32 - m1 := &fakeChatModel{ - callbacksEnabled: true, - generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { - return nil, errors.New("unused") - }, - stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + m1 := newFakeChatModel(nil, func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { atomic.AddInt32(&m1Calls, 1) return streamWithMidError([]*schema.Message{ schema.AssistantMessage("p", nil), }, streamErr), nil - }, - } - m2 := &fakeChatModel{ - callbacksEnabled: true, - generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { - return nil, errors.New("unused") - }, - stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + }) + m2 := newFakeChatModel(nil, func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { atomic.AddInt32(&m2Calls, 1) return streamWithMidError([]*schema.Message{ schema.AssistantMessage("p", nil), }, streamErr), nil - }, - } + }) + + retryCfg := &ModelRetryConfig{ + MaxRetries: 1, + IsRetryAble: func(_ context.Context, err error) bool { return true }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + return err != nil + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + _, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.Error(t, err) + + var retryErr *RetryExhaustedError + require.True(t, errors.As(err, &retryErr)) + + // m1: 1 initial + 1 retry = 2 calls + require.Equal(t, int32(2), atomic.LoadInt32(&m1Calls)) + // m2: 1 initial + 1 retry = 2 calls + require.Equal(t, int32(2), atomic.LoadInt32(&m2Calls)) + }) - retryCfg := &ModelRetryConfig{ - MaxRetries: 1, - IsRetryAble: func(_ context.Context, err error) bool { return true }, - BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, - } + t.Run("ShouldRetry_Stream_TriggersFailover", func(t *testing.T) { + var m1Calls int32 + var m2Calls int32 - failoverCfg := &ModelFailoverConfig{ - MaxRetries: 1, - ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { - return err != nil - }, - GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { - return m2, nil, nil - }, - } + m1 := newFakeChatModel(nil, func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("bad from m1", nil)}), nil + }) + m2 := newFakeChatModel(nil, func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m2Calls, 1) + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("good from m2", nil)}), nil + }) + + retryCfg := &ModelRetryConfig{ + MaxRetries: 1, + ShouldRetry: func(_ context.Context, retryCtx *RetryContext) *RetryDecision { + if retryCtx.OutputMessage != nil && retryCtx.OutputMessage.Content == "bad from m1" { + return &RetryDecision{Retry: true} + } + return &RetryDecision{Retry: false} + }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + return err != nil + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + msgs, err := drainMessageStream(sr) + require.NoError(t, err) + require.Len(t, msgs, 1) + require.Equal(t, "good from m2", msgs[0].Content) + require.Equal(t, int32(2), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) + }) - wrapped := buildModelWrappers(m1, &modelWrapperConfig{ - retryConfig: retryCfg, - failoverConfig: failoverCfg, + t.Run("ShouldRetry_Generate_TriggersFailover", func(t *testing.T) { + var m1Calls int32 + var m2Calls int32 + + m1 := newFakeChatModel(func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m1Calls, 1) + return schema.AssistantMessage("bad from m1", nil), nil + }, nil) + m2 := newFakeChatModel(func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m2Calls, 1) + return schema.AssistantMessage("good from m2", nil), nil + }, nil) + + retryCfg := &ModelRetryConfig{ + MaxRetries: 1, + ShouldRetry: func(_ context.Context, retryCtx *RetryContext) *RetryDecision { + if retryCtx.OutputMessage != nil && retryCtx.OutputMessage.Content == "bad from m1" { + return &RetryDecision{Retry: true} + } + return &RetryDecision{Retry: false} + }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + return err != nil + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + msg, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.Equal(t, "good from m2", msg.Content) + require.Equal(t, int32(2), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) }) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ - failoverLastSuccessModel: m1, + t.Run("Stream_GetFailoverModelReturnsNilModel", func(t *testing.T) { + streamErr := errors.New("m1 always fails") + var m1Calls int32 + + m1 := newFakeChatModel(nil, func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return nil, streamErr + }) + + retryCfg := &ModelRetryConfig{ + MaxRetries: 0, + IsRetryAble: func(_ context.Context, err error) bool { return false }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + return err != nil + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return nil, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + _, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.Error(t, err) + require.Contains(t, err.Error(), "returned nil model at attempt") + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) }) - _, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) - require.Error(t, err) - var retryErr *RetryExhaustedError - require.True(t, errors.As(err, &retryErr)) + t.Run("Stream_ContextCanceledDuringFailover", func(t *testing.T) { + streamErr := errors.New("m1 fails") + var m1Calls int32 + var failoverModelCalled int32 - // m1: 1 initial + 1 retry = 2 calls - require.Equal(t, int32(2), atomic.LoadInt32(&m1Calls)) - // m2: 1 initial + 1 retry = 2 calls - require.Equal(t, int32(2), atomic.LoadInt32(&m2Calls)) + m1 := newFakeChatModel(nil, func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return nil, streamErr + }) + + ctx, cancel := context.WithCancel(context.Background()) + + retryCfg := &ModelRetryConfig{ + MaxRetries: 0, + IsRetryAble: func(_ context.Context, err error) bool { return false }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 3, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + cancel() + return err != nil + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + atomic.AddInt32(&failoverModelCalled, 1) + return nil, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx = withChatModelAgentExecCtx(ctx, &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + _, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(0), atomic.LoadInt32(&failoverModelCalled)) + }) +} + +func TestErrStreamCanceled_Failover(t *testing.T) { + t.Run("Stream_NeverFailedOver", func(t *testing.T) { + var m1Calls int32 + var failoverCalled int32 + + m1 := newFakeChatModel(nil, func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return streamWithMidError([]*schema.Message{ + schema.AssistantMessage("partial", nil), + }, ErrStreamCanceled), nil + }) + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 2, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + atomic.AddInt32(&failoverCalled, 1) + return true + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + t.Fatal("GetFailoverModel should not be called for ErrStreamCanceled") + return nil, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + _, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.Error(t, err) + require.True(t, errors.Is(err, ErrStreamCanceled)) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(0), atomic.LoadInt32(&failoverCalled)) + }) + + t.Run("Generate_NeverFailedOver", func(t *testing.T) { + var m1Calls int32 + var failoverCalled int32 + + m1 := newFakeChatModel(func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m1Calls, 1) + return nil, ErrStreamCanceled + }, nil) + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 2, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + atomic.AddInt32(&failoverCalled, 1) + return true + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + t.Fatal("GetFailoverModel should not be called for ErrStreamCanceled") + return nil, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + _, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.Error(t, err) + require.True(t, errors.Is(err, ErrStreamCanceled)) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(0), atomic.LoadInt32(&failoverCalled)) + }) } diff --git a/schema/stream.go b/schema/stream.go index 67b855b27..5625efe56 100644 --- a/schema/stream.go +++ b/schema/stream.go @@ -599,6 +599,8 @@ type streamReaderWithConvert[T any] struct { convert func(any) (T, error) errWrapper func(error) error + onEOF func() (T, error) + eofDone bool } func newStreamReaderWithConvert[T any](origin iStreamReader, convert func(any) (T, error), opts ...ConvertOption) *StreamReader[T] { @@ -613,6 +615,22 @@ func newStreamReaderWithConvert[T any](origin iStreamReader, convert func(any) ( errWrapper: opt.ErrWrapper, } + if opt.OnEOF != nil { + typedOnEOF := opt.OnEOF + srw.onEOF = func() (T, error) { + v, err := typedOnEOF() + if err != nil { + var t T + return t, err + } + if v == nil { + var t T + return t, nil + } + return v.(T), nil + } + } + return &StreamReader[T]{ typ: readerTypeWithConvert, srw: srw, @@ -621,6 +639,7 @@ func newStreamReaderWithConvert[T any](origin iStreamReader, convert func(any) ( type convertOptions struct { ErrWrapper func(error) error + OnEOF func() (any, error) } type ConvertOption func(*convertOptions) @@ -637,6 +656,17 @@ func WithErrWrapper(wrapper func(error) error) ConvertOption { } } +// WithOnEOF registers a callback that fires once when the stream reaches EOF. +// The callback can inject an error or a value before the final io.EOF is returned. +// If the callback returns (nil, io.EOF), the stream ends normally. +// If it returns a non-EOF error, that error is delivered first, then subsequent Recv returns io.EOF. +// If it returns a non-nil value with nil error, that value is delivered first, then io.EOF. +func WithOnEOF(fn func() (any, error)) ConvertOption { + return func(o *convertOptions) { + o.OnEOF = fn + } +} + // StreamReaderWithConvert returns a new StreamReader[D] that wraps sr and // applies convert to every element. The original reader sr must not be used // after calling this function. @@ -673,7 +703,14 @@ func (srw *streamReaderWithConvert[T]) recv() (T, error) { if err != nil { var t T if err == io.EOF { - return t, err + if srw.onEOF != nil && !srw.eofDone { + srw.eofDone = true + val, onEOFErr := srw.onEOF() + if onEOFErr != io.EOF { + return val, onEOFErr + } + } + return t, io.EOF } if srw.errWrapper != nil { err = srw.errWrapper(err) diff --git a/schema/stream_oneof_test.go b/schema/stream_oneof_test.go new file mode 100644 index 000000000..740836de1 --- /dev/null +++ b/schema/stream_oneof_test.go @@ -0,0 +1,324 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package schema_test + +import ( + "errors" + "io" + "testing" + "time" + + "github.com/cloudwego/eino/schema" +) + +func recvAll(t *testing.T, sr *schema.StreamReader[string]) ([]string, []error) { + t.Helper() + var vals []string + var errs []error + for { + v, err := sr.Recv() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + errs = append(errs, err) + continue + } + vals = append(vals, v) + } + return vals, errs +} + +func makeStream(items []string, opts ...schema.ConvertOption) *schema.StreamReader[string] { + return schema.StreamReaderWithConvert( + schema.StreamReaderFromArray(items), + func(s string) (string, error) { return s, nil }, + opts..., + ) +} + +func TestWithOnEOF_PassThroughEOF(t *testing.T) { + items := []string{"a", "b", "c", "d"} + sr := makeStream(items, schema.WithOnEOF(func() (any, error) { + return nil, io.EOF + })) + defer sr.Close() + + vals, errs := recvAll(t, sr) + if len(errs) != 0 { + t.Fatalf("expected no errors, got %v", errs) + } + if len(vals) != 4 { + t.Fatalf("expected 4 values, got %d: %v", len(vals), vals) + } + for i, want := range items { + if vals[i] != want { + t.Errorf("vals[%d] = %q, want %q", i, vals[i], want) + } + } +} + +func TestWithOnEOF_InjectError(t *testing.T) { + items := []string{"a", "b", "c", "d"} + customErr := errors.New("validation failed") + sr := makeStream(items, schema.WithOnEOF(func() (any, error) { + return nil, customErr + })) + defer sr.Close() + + var vals []string + var gotCustomErr bool + for { + v, err := sr.Recv() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + if errors.Is(err, customErr) { + gotCustomErr = true + continue + } + t.Fatalf("unexpected error: %v", err) + } + vals = append(vals, v) + } + + if len(vals) != 4 { + t.Fatalf("expected 4 values, got %d: %v", len(vals), vals) + } + if !gotCustomErr { + t.Fatalf("expected custom error from onEOF, did not receive it") + } +} + +func TestWithOnEOF_InjectValue(t *testing.T) { + items := []string{"a", "b", "c", "d"} + sr := makeStream(items, schema.WithOnEOF(func() (any, error) { + return "extra", nil + })) + defer sr.Close() + + var vals []string + for { + v, err := sr.Recv() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + vals = append(vals, v) + } + + if len(vals) != 5 { + t.Fatalf("expected 5 values, got %d: %v", len(vals), vals) + } + if vals[4] != "extra" { + t.Errorf("vals[4] = %q, want %q", vals[4], "extra") + } +} + +func TestWithOnEOF_BlockingCallback(t *testing.T) { + sr, sw := schema.Pipe[string](0) + + unblock := make(chan struct{}) + converted := schema.StreamReaderWithConvert(sr, + func(s string) (string, error) { return s, nil }, + schema.WithOnEOF(func() (any, error) { + <-unblock + return "after-block", nil + }), + ) + defer converted.Close() + + go func() { + sw.Send("x", nil) + sw.Close() + }() + + v, err := converted.Recv() + if err != nil { + t.Fatalf("first Recv error: %v", err) + } + if v != "x" { + t.Fatalf("first Recv = %q, want %q", v, "x") + } + + done := make(chan struct{}) + var recvVal string + var recvErr error + go func() { + recvVal, recvErr = converted.Recv() + close(done) + }() + + select { + case <-done: + t.Fatal("Recv returned before unblock signal") + case <-time.After(50 * time.Millisecond): + } + + close(unblock) + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("Recv did not return after unblock signal") + } + + if recvErr != nil { + t.Fatalf("second Recv error: %v", recvErr) + } + if recvVal != "after-block" { + t.Errorf("second Recv = %q, want %q", recvVal, "after-block") + } + + v3, err3 := converted.Recv() + if !errors.Is(err3, io.EOF) { + t.Fatalf("third Recv: got (%q, %v), want EOF", v3, err3) + } +} + +func TestWithOnEOF_EmptyStream(t *testing.T) { + customErr := errors.New("empty stream error") + sr := makeStream(nil, schema.WithOnEOF(func() (any, error) { + return nil, customErr + })) + defer sr.Close() + + v, err := sr.Recv() + if !errors.Is(err, customErr) { + t.Fatalf("first Recv: got (%q, %v), want customErr", v, err) + } + + v2, err2 := sr.Recv() + if !errors.Is(err2, io.EOF) { + t.Fatalf("second Recv: got (%q, %v), want EOF", v2, err2) + } +} + +func TestWithOnEOF_WithErrWrapper_ErrorPath(t *testing.T) { + sr, sw := schema.Pipe[string](0) + + streamErr := errors.New("stream error") + onEOFCalled := false + + converted := schema.StreamReaderWithConvert(sr, + func(s string) (string, error) { return s, nil }, + schema.WithErrWrapper(func(err error) error { + return err + }), + schema.WithOnEOF(func() (any, error) { + onEOFCalled = true + return nil, errors.New("should not happen") + }), + ) + defer converted.Close() + + go func() { + sw.Send("a", nil) + sw.Send("", streamErr) + sw.Close() + }() + + v, err := converted.Recv() + if err != nil { + t.Fatalf("first Recv error: %v", err) + } + if v != "a" { + t.Fatalf("first Recv = %q, want %q", v, "a") + } + + _, err = converted.Recv() + if !errors.Is(err, streamErr) { + t.Fatalf("second Recv: got %v, want streamErr", err) + } + + if onEOFCalled { + t.Fatal("onEOF should not have been called when stream errored") + } +} + +func TestWithOnEOF_WithErrWrapper_EOFPath(t *testing.T) { + items := []string{"a", "b", "c"} + errWrapperCalled := false + + sr := schema.StreamReaderWithConvert( + schema.StreamReaderFromArray(items), + func(s string) (string, error) { return s, nil }, + schema.WithErrWrapper(func(err error) error { + errWrapperCalled = true + return err + }), + schema.WithOnEOF(func() (any, error) { + return "oneof-val", nil + }), + ) + defer sr.Close() + + var vals []string + for { + v, err := sr.Recv() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + vals = append(vals, v) + } + + if len(vals) != 4 { + t.Fatalf("expected 4 values, got %d: %v", len(vals), vals) + } + if vals[3] != "oneof-val" { + t.Errorf("vals[3] = %q, want %q", vals[3], "oneof-val") + } + if errWrapperCalled { + t.Fatal("errWrapper should not have been called for clean stream") + } +} + +func TestWithOnEOF_MultipleRecvAfterEOF(t *testing.T) { + items := []string{"a"} + customErr := errors.New("oneof error") + + sr := makeStream(items, schema.WithOnEOF(func() (any, error) { + return nil, customErr + })) + defer sr.Close() + + v, err := sr.Recv() + if err != nil { + t.Fatalf("first Recv error: %v", err) + } + if v != "a" { + t.Fatalf("first Recv = %q, want %q", v, "a") + } + + _, err = sr.Recv() + if !errors.Is(err, customErr) { + t.Fatalf("second Recv: got %v, want customErr", err) + } + + for i := 0; i < 5; i++ { + _, err = sr.Recv() + if !errors.Is(err, io.EOF) { + t.Fatalf("Recv #%d after onEOF: got %v, want io.EOF", i+3, err) + } + } +} From bef5008cd25c08c019c62a10abe3b033b91146d9 Mon Sep 17 00:00:00 2001 From: shentongmartin Date: Thu, 16 Apr 2026 14:46:38 +0800 Subject: [PATCH 56/65] docs(adk): add NOT RECOMMENDED advisory to agent transfer and workflow APIs (#970) --- adk/call_option.go | 4 +++ adk/chatmodel.go | 23 ++++++++++++++++ adk/deterministic_transfer.go | 4 +++ adk/flow.go | 24 ++++++++++++++++- adk/interface.go | 34 ++++++++++++++++++++---- adk/prebuilt/supervisor/supervisor.go | 9 +++++++ adk/utils.go | 4 +++ adk/workflow.go | 38 ++++++++++++++++++++++++++- compose/checkpoint_test.go | 28 -------------------- compose/graph_run.go | 36 +++++++++---------------- compose/interrupt.go | 4 --- 11 files changed, 146 insertions(+), 62 deletions(-) diff --git a/adk/call_option.go b/adk/call_option.go index ead6ae636..7a1cc1b65 100644 --- a/adk/call_option.go +++ b/adk/call_option.go @@ -56,6 +56,10 @@ func WithSessionValues(v map[string]any) AgentRunOption { } // WithSkipTransferMessages disables forwarding transfer messages during execution. +// +// NOT RECOMMENDED: Agent transfer with full context sharing between agents has not proven +// to be more effective empirically. Consider using ChatModelAgent with AgentTool +// or DeepAgent instead for most multi-agent scenarios. func WithSkipTransferMessages() AgentRunOption { return WrapImplSpecificOptFn(func(t *options) { t.skipTransferMessages = true diff --git a/adk/chatmodel.go b/adk/chatmodel.go index 0f1f9f0a8..f1155a1cf 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -238,10 +238,18 @@ type ChatModelAgentConfig struct { // Exit defines the tool used to terminate the agent process. // Optional. If nil, no Exit Action will be generated. // You can use the provided 'ExitTool' implementation directly. + // + // NOT RECOMMENDED: Agent transfer with full context sharing between agents has not proven + // to be more effective empirically. Consider using ChatModelAgent with AgentTool + // or DeepAgent instead for most multi-agent scenarios. Exit tool.BaseTool // OutputKey stores the agent's response in the session. // Optional. When set, stores output via AddSessionValue(ctx, outputKey, msg.Content). + // + // NOT RECOMMENDED: Agent transfer with full context sharing between agents has not proven + // to be more effective empirically. Consider using ChatModelAgent with AgentTool + // or DeepAgent instead for most multi-agent scenarios. OutputKey string // MaxIterations defines the upper limit of ChatModel generation cycles. @@ -584,6 +592,11 @@ func (a *ChatModelAgent) GetType() string { return "ChatModel" } +// OnSetSubAgents implements OnSubAgents. +// +// NOT RECOMMENDED: Agent transfer with full context sharing between agents has not proven +// to be more effective empirically. Consider using ChatModelAgent with AgentTool +// or DeepAgent instead for most multi-agent scenarios. func (a *ChatModelAgent) OnSetSubAgents(_ context.Context, subAgents []Agent) error { if atomic.LoadUint32(&a.frozen) == 1 { return errors.New("agent has been frozen after run") @@ -597,6 +610,11 @@ func (a *ChatModelAgent) OnSetSubAgents(_ context.Context, subAgents []Agent) er return nil } +// OnSetAsSubAgent implements OnSubAgents. +// +// NOT RECOMMENDED: Agent transfer with full context sharing between agents has not proven +// to be more effective empirically. Consider using ChatModelAgent with AgentTool +// or DeepAgent instead for most multi-agent scenarios. func (a *ChatModelAgent) OnSetAsSubAgent(_ context.Context, parent Agent) error { if atomic.LoadUint32(&a.frozen) == 1 { return errors.New("agent has been frozen after run") @@ -610,6 +628,11 @@ func (a *ChatModelAgent) OnSetAsSubAgent(_ context.Context, parent Agent) error return nil } +// OnDisallowTransferToParent implements OnSubAgents. +// +// NOT RECOMMENDED: Agent transfer with full context sharing between agents has not proven +// to be more effective empirically. Consider using ChatModelAgent with AgentTool +// or DeepAgent instead for most multi-agent scenarios. func (a *ChatModelAgent) OnDisallowTransferToParent(_ context.Context) error { if atomic.LoadUint32(&a.frozen) == 1 { return errors.New("agent has been frozen after run") diff --git a/adk/deterministic_transfer.go b/adk/deterministic_transfer.go index e9c9f4ef8..dc677a007 100644 --- a/adk/deterministic_transfer.go +++ b/adk/deterministic_transfer.go @@ -36,6 +36,10 @@ type deterministicTransferState struct { } // AgentWithDeterministicTransferTo wraps an agent to transfer to given agents deterministically. +// +// NOT RECOMMENDED: Agent transfer with full context sharing between agents has not proven +// to be more effective empirically. Consider using ChatModelAgent with AgentTool +// or DeepAgent instead for most multi-agent scenarios. func AgentWithDeterministicTransferTo(_ context.Context, config *DeterministicTransferConfig) Agent { if ra, ok := config.Agent.(ResumableAgent); ok { return &resumableAgentWithDeterministicTransferTo{ diff --git a/adk/flow.go b/adk/flow.go index 52a346c74..8edc002a0 100644 --- a/adk/flow.go +++ b/adk/flow.go @@ -68,6 +68,10 @@ func (a *flowAgent) deepCopy() *flowAgent { } // SetSubAgents sets sub-agents for the given agent and returns the updated agent. +// +// NOT RECOMMENDED: Agent transfer with full context sharing between agents has not proven +// to be more effective empirically. Consider using ChatModelAgent with AgentTool +// or DeepAgent instead for most multi-agent scenarios. func SetSubAgents(ctx context.Context, agent Agent, subAgents []Agent) (ResumableAgent, error) { return setSubAgents(ctx, agent, subAgents) } @@ -75,13 +79,22 @@ func SetSubAgents(ctx context.Context, agent Agent, subAgents []Agent) (Resumabl type AgentOption func(options *flowAgent) // WithDisallowTransferToParent prevents a sub-agent from transferring to its parent. +// +// NOT RECOMMENDED: Agent transfer with full context sharing between agents has not proven +// to be more effective empirically. Consider using ChatModelAgent with AgentTool +// or DeepAgent instead for most multi-agent scenarios. func WithDisallowTransferToParent() AgentOption { return func(fa *flowAgent) { fa.disallowTransferToParent = true } } -// WithHistoryRewriter sets a rewriter to transform conversation history. +// WithHistoryRewriter sets a rewriter to transform conversation history +// during agent transfers. +// +// NOT RECOMMENDED: Agent transfer with full context sharing between agents has not proven +// to be more effective empirically. Consider using ChatModelAgent with AgentTool +// or DeepAgent instead for most multi-agent scenarios. func WithHistoryRewriter(h HistoryRewriter) AgentOption { return func(fa *flowAgent) { fa.historyRewriter = h @@ -108,6 +121,10 @@ func toFlowAgent(ctx context.Context, agent Agent, opts ...AgentOption) *flowAge } // AgentWithOptions wraps an agent with flow-specific options and returns it. +// +// NOT RECOMMENDED: Agent transfer with full context sharing between agents has not proven +// to be more effective empirically. Consider using ChatModelAgent with AgentTool +// or DeepAgent instead for most multi-agent scenarios. func AgentWithOptions(ctx context.Context, agent Agent, opts ...AgentOption) Agent { return toFlowAgent(ctx, agent, opts...) } @@ -448,6 +465,11 @@ func (a *flowAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentR return wrapIterWithCancelCtx(wrapIterWithOnEnd(ctx, innerIter), cancelCtx) } +// DeterministicTransferConfig is the configuration for AgentWithDeterministicTransferTo. +// +// NOT RECOMMENDED: Agent transfer with full context sharing between agents has not proven +// to be more effective empirically. Consider using ChatModelAgent with AgentTool +// or DeepAgent instead for most multi-agent scenarios. type DeterministicTransferConfig struct { Agent Agent ToAgentNames []string diff --git a/adk/interface.go b/adk/interface.go index 5c06843ae..e1f17eca7 100644 --- a/adk/interface.go +++ b/adk/interface.go @@ -134,6 +134,11 @@ func (mv *MessageVariant) GetMessage() (Message, error) { return message, nil } +// TransferToAgentAction represents a transfer-to-agent action. +// +// NOT RECOMMENDED: Agent transfer with full context sharing between agents has not proven +// to be more effective empirically. Consider using ChatModelAgent with AgentTool +// or DeepAgent instead for most multi-agent scenarios. type TransferToAgentAction struct { DestAgentName string } @@ -145,11 +150,19 @@ type AgentOutput struct { } // NewTransferToAgentAction creates an action to transfer to the specified agent. +// +// NOT RECOMMENDED: Agent transfer with full context sharing between agents has not proven +// to be more effective empirically. Consider using ChatModelAgent with AgentTool +// or DeepAgent instead for most multi-agent scenarios. func NewTransferToAgentAction(destAgentName string) *AgentAction { return &AgentAction{TransferToAgent: &TransferToAgentAction{DestAgentName: destAgentName}} } // NewExitAction creates an action that signals the agent to exit. +// +// NOT RECOMMENDED: Agent transfer with full context sharing between agents has not proven +// to be more effective empirically. Consider using ChatModelAgent with AgentTool +// or DeepAgent instead for most multi-agent scenarios. func NewExitAction() *AgentAction { return &AgentAction{Exit: true} } @@ -179,7 +192,12 @@ type AgentAction struct { internalInterrupted *core.InterruptSignal } -// RunStep CheckpointSchema: persisted via serialization.RunCtx (gob). +// RunStep represents a step in the agent execution path. +// CheckpointSchema: persisted via serialization.RunCtx (gob). +// +// NOT RECOMMENDED: RunStep is mainly relevant for agent transfer and workflow agents, +// which have not proven to be more effective empirically. Consider using ChatModelAgent +// with AgentTool or DeepAgent instead for most multi-agent scenarios. type RunStep struct { agentName string } @@ -225,10 +243,11 @@ type AgentEvent struct { AgentName string // RunPath represents the execution path from root agent to the current event source. - // This field is managed entirely by the eino framework and cannot be set by end-users - // because RunStep's fields are unexported. The framework sets RunPath exactly once: - // - flowAgent sets it when the event has no RunPath (len == 0) - // - agentTool prepends parent RunPath when forwarding events from nested agents + // This field is managed entirely by the framework and cannot be set by end-users. + // + // NOT RECOMMENDED: RunPath is mainly relevant for agent transfer and workflow agents, + // which have not proven to be more effective empirically. For ChatModelAgent with + // AgentTool or DeepAgent, RunPath is trivial. Consider those patterns instead. RunPath []RunStep Output *AgentOutput @@ -257,6 +276,11 @@ type Agent interface { Run(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] } +// OnSubAgents is the interface for agents that support sub-agent registration and transfer. +// +// NOT RECOMMENDED: Agent transfer with full context sharing between agents has not proven +// to be more effective empirically. Consider using ChatModelAgent with AgentTool +// or DeepAgent instead for most multi-agent scenarios. type OnSubAgents interface { OnSetSubAgents(ctx context.Context, subAgents []Agent) error OnSetAsSubAgent(ctx context.Context, parent Agent) error diff --git a/adk/prebuilt/supervisor/supervisor.go b/adk/prebuilt/supervisor/supervisor.go index e461ff190..62e6d1ddc 100644 --- a/adk/prebuilt/supervisor/supervisor.go +++ b/adk/prebuilt/supervisor/supervisor.go @@ -37,6 +37,11 @@ import ( "github.com/cloudwego/eino/adk" ) +// Config is the configuration for creating a supervisor-based multi-agent system. +// +// NOT RECOMMENDED: Supervisor is built on agent transfer with full context sharing, +// which has not proven to be more effective empirically. Consider using +// ChatModelAgent with AgentTool or DeepAgent instead for most multi-agent scenarios. type Config struct { // Supervisor specifies the agent that will act as the supervisor, coordinating and managing the sub-agents. Supervisor adk.Agent @@ -89,6 +94,10 @@ func (s *supervisorContainer) Resume(ctx context.Context, info *adk.ResumeInfo, // When used with Runner and callbacks, all agents within the supervisor structure will // share the same trace root, making it easy to observe the entire multi-agent execution // as a single logical unit. +// +// NOT RECOMMENDED: Supervisor is built on agent transfer with full context sharing, +// which has not proven to be more effective empirically. Consider using +// ChatModelAgent with AgentTool or DeepAgent instead for most multi-agent scenarios. func New(ctx context.Context, conf *Config) (adk.ResumableAgent, error) { subAgents := make([]adk.Agent, 0, len(conf.SubAgents)) supervisorName := conf.Supervisor.Name(ctx) diff --git a/adk/utils.go b/adk/utils.go index 5dd890be8..89b991324 100644 --- a/adk/utils.go +++ b/adk/utils.go @@ -89,6 +89,10 @@ func concatInstructions(instructions ...string) string { // GenTransferMessages generates assistant and tool messages to instruct a // transfer-to-agent tool call targeting the destination agent. +// +// NOT RECOMMENDED: Agent transfer with full context sharing between agents has not proven +// to be more effective empirically. Consider using ChatModelAgent with AgentTool +// or DeepAgent instead for most multi-agent scenarios. func GenTransferMessages(_ context.Context, destAgentName string) (Message, Message) { toolCallID := uuid.NewString() tooCall := schema.ToolCall{ID: toolCallID, Function: schema.FunctionCall{Name: TransferToAgentToolName, Arguments: destAgentName}} diff --git a/adk/workflow.go b/adk/workflow.go index 00411e33b..161c43497 100644 --- a/adk/workflow.go +++ b/adk/workflow.go @@ -157,7 +157,12 @@ func (a *workflowAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...Ag return iterator } -// WorkflowInterruptInfo CheckpointSchema: persisted via InterruptInfo.Data (gob). +// WorkflowInterruptInfo stores interrupt information for workflow agents. +// CheckpointSchema: persisted via InterruptInfo.Data (gob). +// +// NOT RECOMMENDED: Workflow agents are built on agent transfer with full context sharing, +// which has not proven to be more effective empirically. Consider using +// ChatModelAgent with AgentTool or DeepAgent instead for most multi-agent scenarios. type WorkflowInterruptInfo struct { OrigInput *AgentInput @@ -303,6 +308,10 @@ type BreakLoopAction struct { // NewBreakLoopAction creates a new BreakLoopAction, signaling a request // to terminate the current loop. +// +// NOT RECOMMENDED: Workflow agents are built on agent transfer with full context sharing, +// which has not proven to be more effective empirically. Consider using +// ChatModelAgent with AgentTool or DeepAgent instead for most multi-agent scenarios. func NewBreakLoopAction(agentName string) *AgentAction { return &AgentAction{BreakLoop: &BreakLoopAction{ From: agentName, @@ -608,18 +617,33 @@ func cancelAtTransition(ctx context.Context, info string, state any) *AgentEvent } } +// SequentialAgentConfig is the configuration for NewSequentialAgent. +// +// NOT RECOMMENDED: Workflow agents are built on agent transfer with full context sharing, +// which has not proven to be more effective empirically. Consider using +// ChatModelAgent with AgentTool or DeepAgent instead for most multi-agent scenarios. type SequentialAgentConfig struct { Name string Description string SubAgents []Agent } +// ParallelAgentConfig is the configuration for NewParallelAgent. +// +// NOT RECOMMENDED: Workflow agents are built on agent transfer with full context sharing, +// which has not proven to be more effective empirically. Consider using +// ChatModelAgent with AgentTool or DeepAgent instead for most multi-agent scenarios. type ParallelAgentConfig struct { Name string Description string SubAgents []Agent } +// LoopAgentConfig is the configuration for NewLoopAgent. +// +// NOT RECOMMENDED: Workflow agents are built on agent transfer with full context sharing, +// which has not proven to be more effective empirically. Consider using +// ChatModelAgent with AgentTool or DeepAgent instead for most multi-agent scenarios. type LoopAgentConfig struct { Name string Description string @@ -655,16 +679,28 @@ func newWorkflowAgent(ctx context.Context, name, desc string, } // NewSequentialAgent creates an agent that runs sub-agents sequentially. +// +// NOT RECOMMENDED: Workflow agents are built on agent transfer with full context sharing, +// which has not proven to be more effective empirically. Consider using +// ChatModelAgent with AgentTool or DeepAgent instead for most multi-agent scenarios. func NewSequentialAgent(ctx context.Context, config *SequentialAgentConfig) (ResumableAgent, error) { return newWorkflowAgent(ctx, config.Name, config.Description, config.SubAgents, workflowAgentModeSequential, 0) } // NewParallelAgent creates an agent that runs sub-agents in parallel. +// +// NOT RECOMMENDED: Workflow agents are built on agent transfer with full context sharing, +// which has not proven to be more effective empirically. Consider using +// ChatModelAgent with AgentTool or DeepAgent instead for most multi-agent scenarios. func NewParallelAgent(ctx context.Context, config *ParallelAgentConfig) (ResumableAgent, error) { return newWorkflowAgent(ctx, config.Name, config.Description, config.SubAgents, workflowAgentModeParallel, 0) } // NewLoopAgent creates an agent that loops over sub-agents with a max iteration limit. +// +// NOT RECOMMENDED: Workflow agents are built on agent transfer with full context sharing, +// which has not proven to be more effective empirically. Consider using +// ChatModelAgent with AgentTool or DeepAgent instead for most multi-agent scenarios. func NewLoopAgent(ctx context.Context, config *LoopAgentConfig) (ResumableAgent, error) { return newWorkflowAgent(ctx, config.Name, config.Description, config.SubAgents, workflowAgentModeLoop, config.MaxIterations) } diff --git a/compose/checkpoint_test.go b/compose/checkpoint_test.go index a86c02fb3..c24b6ce6f 100644 --- a/compose/checkpoint_test.go +++ b/compose/checkpoint_test.go @@ -1383,7 +1383,6 @@ func TestCancelInterrupt(t *testing.T) { info, success := ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.AfterNodes) - assert.True(t, info.FromGraphInterrupt) result, err := r.Invoke(ctx, "input", WithCheckPointID("1")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1398,7 +1397,6 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.AfterNodes) - assert.True(t, info.FromGraphInterrupt) result, err = r.Invoke(ctx, "input", WithCheckPointID("2")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1414,7 +1412,6 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.RerunNodes) - assert.True(t, info.FromGraphInterrupt) result, err = r.Invoke(ctx, "input", WithCheckPointID("3")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1444,7 +1441,6 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.AfterNodes) - assert.True(t, info.FromGraphInterrupt) result, err = r.Invoke(ctx, "input", WithCheckPointID("1")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1459,7 +1455,6 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.AfterNodes) - assert.True(t, info.FromGraphInterrupt) result, err = r.Invoke(ctx, "input", WithCheckPointID("2")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1475,7 +1470,6 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.RerunNodes) - assert.True(t, info.FromGraphInterrupt) result, err = r.Invoke(ctx, "input", WithCheckPointID("3")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1516,7 +1510,6 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, 2, len(info.AfterNodes)) - assert.True(t, info.FromGraphInterrupt) result2, err := rr.Invoke(ctx, "input", WithCheckPointID("1")) assert.NoError(t, err) assert.Equal(t, map[string]any{ @@ -1535,7 +1528,6 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, 2, len(info.RerunNodes)) - assert.True(t, info.FromGraphInterrupt) result2, err = rr.Invoke(ctx, "input", WithCheckPointID("2")) assert.NoError(t, err) assert.Equal(t, map[string]any{ @@ -1544,26 +1536,6 @@ func TestCancelInterrupt(t *testing.T) { }, result2) } -func TestBusinessInterruptFromGraphInterruptFalse(t *testing.T) { - g := NewGraph[string, string]() - _ = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { - return "", Interrupt(ctx, "biz") - })) - _ = g.AddEdge(START, "1") - _ = g.AddEdge("1", END) - - ctx := context.Background() - r, err := g.Compile(ctx, WithCheckPointStore(newInMemoryStore())) - assert.NoError(t, err) - - _, err = r.Invoke(ctx, "input", WithCheckPointID("biz")) - assert.Error(t, err) - info, existed := ExtractInterruptInfo(err) - assert.True(t, existed) - assert.False(t, info.FromGraphInterrupt) - assert.Equal(t, []string{"1"}, info.RerunNodes) -} - func TestPersistRerunInputNonStream(t *testing.T) { store := newInMemoryStore() diff --git a/compose/graph_run.go b/compose/graph_run.go index 770cf16de..02b4fca7d 100644 --- a/compose/graph_run.go +++ b/compose/graph_run.go @@ -434,7 +434,6 @@ type interruptTempInfo struct { interruptBeforeNodes []string interruptAfterNodes []string interruptRerunExtra map[string]any - fromGraphInterrupt bool signals []*core.InterruptSignal } @@ -443,7 +442,7 @@ func (ti *interruptTempInfo) collectCanceledInfo(canceled bool, canceledTasks, c if !canceled { return } - ti.fromGraphInterrupt = true + if len(canceledTasks) > 0 { for _, t := range canceledTasks { ti.interruptRerunNodes = append(ti.interruptRerunNodes, t.nodeKey) @@ -461,13 +460,6 @@ func (r *runner) resolveInterruptCompletedTasks(tempInfo *interruptTempInfo, com if info := isSubGraphInterrupt(completedTask.err); info != nil { tempInfo.subGraphInterrupts[completedTask.nodeKey] = info tempInfo.signals = append(tempInfo.signals, info.signal) - // Propagate FromGraphInterrupt from the sub-graph to the parent. - // The sub-graph's task manager may have consumed the cancel - // channel value before the parent's, so only the sub-graph - // knows the interrupt was triggered by a graph-level cancel. - if info.Info != nil && info.Info.FromGraphInterrupt { - tempInfo.fromGraphInterrupt = true - } continue } @@ -535,13 +527,12 @@ func (r *runner) handleInterrupt( } intInfo := &InterruptInfo{ - State: cp.State, - AfterNodes: tempInfo.interruptAfterNodes, - BeforeNodes: tempInfo.interruptBeforeNodes, - RerunNodes: tempInfo.interruptRerunNodes, - RerunNodesExtra: tempInfo.interruptRerunExtra, - SubGraphs: make(map[string]*InterruptInfo), - FromGraphInterrupt: tempInfo.fromGraphInterrupt, + State: cp.State, + AfterNodes: tempInfo.interruptAfterNodes, + BeforeNodes: tempInfo.interruptBeforeNodes, + RerunNodes: tempInfo.interruptRerunNodes, + RerunNodesExtra: tempInfo.interruptRerunExtra, + SubGraphs: make(map[string]*InterruptInfo), } info := cp.State @@ -668,13 +659,12 @@ func (r *runner) handleInterruptWithSubGraphAndRerunNodes( } intInfo := &InterruptInfo{ - State: cp.State, - BeforeNodes: tempInfo.interruptBeforeNodes, - AfterNodes: tempInfo.interruptAfterNodes, - RerunNodes: tempInfo.interruptRerunNodes, - RerunNodesExtra: tempInfo.interruptRerunExtra, - SubGraphs: make(map[string]*InterruptInfo), - FromGraphInterrupt: tempInfo.fromGraphInterrupt, + State: cp.State, + BeforeNodes: tempInfo.interruptBeforeNodes, + AfterNodes: tempInfo.interruptAfterNodes, + RerunNodes: tempInfo.interruptRerunNodes, + RerunNodesExtra: tempInfo.interruptRerunExtra, + SubGraphs: make(map[string]*InterruptInfo), } info := cp.State diff --git a/compose/interrupt.go b/compose/interrupt.go index cd423a1d6..98a5eeecc 100644 --- a/compose/interrupt.go +++ b/compose/interrupt.go @@ -263,10 +263,6 @@ type InterruptInfo struct { RerunNodesExtra map[string]any SubGraphs map[string]*InterruptInfo InterruptContexts []*InterruptCtx - // FromGraphInterrupt indicates whether the interrupt was triggered by a graph-level - // cancel operation (e.g., via WithGraphInterrupt) rather than business logic. - // When true, the interrupt originated from an external cancellation request. - FromGraphInterrupt bool } func init() { From 67aadb14de355ae2de72bbe590a13ab841053cb7 Mon Sep 17 00:00:00 2001 From: shentongmartin Date: Fri, 17 Apr 2026 16:43:36 +0800 Subject: [PATCH 57/65] fix(adk): preserve nil agentCancelOpts in stopSignal.check to prevent UntilIdleFor from canceling agent (#972) --- adk/attack_test.go | 449 ------------------------ adk/turn_loop.go | 37 +- adk/turn_loop_test.go | 794 ++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 802 insertions(+), 478 deletions(-) delete mode 100644 adk/attack_test.go diff --git a/adk/attack_test.go b/adk/attack_test.go deleted file mode 100644 index bfb4462ef..000000000 --- a/adk/attack_test.go +++ /dev/null @@ -1,449 +0,0 @@ -/* - * Copyright 2026 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package adk - -import ( - "context" - "errors" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/cloudwego/eino/schema" -) - -func TestAttack_UntilIdleFor_ConcurrentPushDuringIdleTimer(t *testing.T) { - turnCount := int32(0) - turnDone := make(chan struct{}, 10) - - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ - Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, - Consumed: items, - }, nil - }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { - return &turnLoopMockAgent{ - name: "test", - runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { - atomic.AddInt32(&turnCount, 1) - turnDone <- struct{}{} - return &AgentOutput{}, nil - }, - }, nil - }, - }) - - loop.Push("msg1") - <-turnDone - - loop.Stop(UntilIdleFor(200 * time.Millisecond)) - - for i := 0; i < 5; i++ { - time.Sleep(50 * time.Millisecond) - loop.Push("concurrent-" + string(rune('a'+i))) - <-turnDone - } - - done := make(chan struct{}) - go func() { - loop.Wait() - close(done) - }() - - select { - case <-done: - case <-time.After(3 * time.Second): - t.Fatal("loop did not exit after idle timeout — Push did not reset timer correctly") - } - - finalCount := atomic.LoadInt32(&turnCount) - assert.Equal(t, int32(6), finalCount, "all 6 pushes should have been processed") -} - -func TestAttack_UntilIdleFor_MultipleStopCallsFirstWins(t *testing.T) { - turnDone := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ - Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, - Consumed: items, - }, nil - }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { - return &turnLoopMockAgent{ - name: "test", - runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { - close(turnDone) - return &AgentOutput{}, nil - }, - }, nil - }, - }) - - loop.Push("msg1") - <-turnDone - - loop.Stop(UntilIdleFor(100 * time.Millisecond)) - loop.Stop(UntilIdleFor(10 * time.Minute)) - - done := make(chan struct{}) - go func() { - loop.Wait() - close(done) - }() - - select { - case <-done: - case <-time.After(2 * time.Second): - t.Fatal("second UntilIdleFor should have been ignored; loop should have exited with 100ms timer") - } -} - -func TestAttack_BareStopOverridesUntilIdleFor(t *testing.T) { - agentStarted := make(chan struct{}) - agentDone := make(chan struct{}) - - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ - Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, - Consumed: items, - }, nil - }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { - return &turnLoopMockAgent{ - name: "test", - runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { - close(agentStarted) - <-agentDone - return &AgentOutput{}, nil - }, - }, nil - }, - }) - - loop.Push("msg1") - <-agentStarted - - loop.Stop(UntilIdleFor(10 * time.Minute)) - - loop.Stop() - close(agentDone) - - done := make(chan struct{}) - go func() { - loop.Wait() - close(done) - }() - - select { - case <-done: - case <-time.After(2 * time.Second): - t.Fatal("bare Stop should override UntilIdleFor and cause immediate shutdown") - } - - exit := loop.Wait() - assert.NoError(t, exit.ExitReason, "bare Stop should exit cleanly") -} - -func TestAttack_StopSignal_NilCancelOptsDoNotDeescalate(t *testing.T) { - agentStarted := make(chan *cancelContext, 1) - probe := &turnLoopStopModeProbeAgent{ccCh: agentStarted} - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ - Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, - Consumed: items, - }, nil - }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { - return probe, nil - }, - }) - - loop.Push("msg1") - cc := <-agentStarted - - loop.Stop(WithImmediate()) - - time.Sleep(20 * time.Millisecond) - - loop.Stop() - - time.Sleep(20 * time.Millisecond) - mode := cc.getMode() - assert.Equal(t, CancelImmediate, mode, "bare Stop after WithImmediate must not de-escalate cancel mode") - - exit := loop.Wait() - var ce *CancelError - require.True(t, errors.As(exit.ExitReason, &ce)) - assert.Equal(t, CancelImmediate, ce.Info.Mode) -} - -func TestAttack_CanceledItems_EmptyWhenAgentFinishesNormally(t *testing.T) { - agentStarted := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ - Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, - Consumed: items, - }, nil - }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { - return &turnLoopMockAgent{ - name: "test", - runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { - close(agentStarted) - return &AgentOutput{}, nil - }, - }, nil - }, - }) - - loop.Push("msg1") - <-agentStarted - time.Sleep(50 * time.Millisecond) - loop.Stop() - - exit := loop.Wait() - assert.NoError(t, exit.ExitReason) - assert.Empty(t, exit.CanceledItems, "CanceledItems must be empty when agent finished normally") -} - -func TestAttack_TurnBuffer_WakeupDoesNotLoseItems(t *testing.T) { - tb := newTurnBuffer[string]() - - tb.Send("a") - tb.Send("b") - tb.Wakeup() - tb.Send("c") - - var got []string - for i := 0; i < 3; i++ { - val, ok := tb.Receive() - require.True(t, ok) - got = append(got, val) - } - - assert.Equal(t, []string{"a", "b", "c"}, got, "Wakeup must not cause items to be lost") -} - -func TestAttack_TurnBuffer_ClearWakeupPreventsSpuriousReturn(t *testing.T) { - tb := newTurnBuffer[string]() - - tb.Wakeup() - tb.ClearWakeup() - - received := make(chan string, 1) - go func() { - val, ok := tb.Receive() - if ok { - received <- val - } - }() - - time.Sleep(50 * time.Millisecond) - tb.Send("real") - - select { - case val := <-received: - assert.Equal(t, "real", val, "ClearWakeup should prevent spurious empty return") - case <-time.After(2 * time.Second): - t.Fatal("Receive blocked forever despite Send") - } -} - -func TestAttack_StopBeforeRun_UntilIdleFor_ExitsImmediately(t *testing.T) { - loop := NewTurnLoop(TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ - Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, - Consumed: items, - }, nil - }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { - return &turnLoopMockAgent{name: "test"}, nil - }, - }) - - loop.Stop(UntilIdleFor(10 * time.Minute)) - loop.Stop() - - loop.Run(context.Background()) - - done := make(chan struct{}) - go func() { - loop.Wait() - close(done) - }() - - select { - case <-done: - case <-time.After(2 * time.Second): - t.Fatal("loop should exit immediately when Stop() called before Run()") - } -} - -func TestAttack_PushAfterStop_UntilIdleFor_RoutedToLateItems(t *testing.T) { - turnDone := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ - Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, - Consumed: items, - }, nil - }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { - return &turnLoopMockAgent{ - name: "test", - runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { - close(turnDone) - return &AgentOutput{}, nil - }, - }, nil - }, - }) - - loop.Push("msg1") - <-turnDone - - loop.Stop(UntilIdleFor(50 * time.Millisecond)) - exit := loop.Wait() - assert.NoError(t, exit.ExitReason) - - ok, _ := loop.Push("after-stop") - assert.False(t, ok, "Push after loop exited should return false") - - late := exit.TakeLateItems() - assert.Equal(t, []string{"after-stop"}, late) -} - -func TestAttack_ConcurrentStopEscalation_RaceDetector(t *testing.T) { - agentStarted := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ - Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, - Consumed: items, - }, nil - }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { - return &turnLoopCancellableMockAgent{ - name: "test", - runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { - close(agentStarted) - <-ctx.Done() - return nil, ctx.Err() - }, - }, nil - }, - }) - - loop.Push("msg1") - <-agentStarted - - var wg sync.WaitGroup - for i := 0; i < 10; i++ { - wg.Add(1) - go func(i int) { - defer wg.Done() - switch i % 4 { - case 0: - loop.Stop() - case 1: - loop.Stop(WithImmediate()) - case 2: - loop.Stop(WithGracefulTimeout(100 * time.Millisecond)) - case 3: - loop.Stop(UntilIdleFor(50 * time.Millisecond)) - } - }(i) - } - - wg.Wait() - exit := loop.Wait() - t.Log("ExitReason:", exit.ExitReason) -} - -func TestAttack_StopCause_FirstNonEmptyWins_ConcurrentCallers(t *testing.T) { - turnDone := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ - Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, - Consumed: items, - }, nil - }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { - return &turnLoopMockAgent{ - name: "test", - runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { - close(turnDone) - return &AgentOutput{}, nil - }, - }, nil - }, - }) - - loop.Push("msg1") - <-turnDone - - loop.Stop(WithStopCause("first-cause")) - loop.Stop(WithStopCause("second-cause")) - - exit := loop.Wait() - assert.Equal(t, "first-cause", exit.StopCause, "first non-empty StopCause should win") -} - -func TestAttack_SkipCheckpoint_Sticky(t *testing.T) { - agentStarted := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ - Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, - Consumed: items, - }, nil - }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { - return &turnLoopCancellableMockAgent{ - name: "test", - runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { - close(agentStarted) - <-ctx.Done() - return nil, ctx.Err() - }, - }, nil - }, - Store: &turnLoopCheckpointStore{m: make(map[string][]byte)}, - CheckpointID: "test-sticky", - }) - - loop.Push("msg1") - <-agentStarted - - loop.Stop(WithSkipCheckpoint()) - loop.Stop(WithImmediate()) - - exit := loop.Wait() - assert.False(t, exit.Checkpointed, "SkipCheckpoint is sticky; checkpoint should be skipped") -} diff --git a/adk/turn_loop.go b/adk/turn_loop.go index 124f65459..67db57c47 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -58,7 +58,8 @@ type stopSignal struct { mu sync.Mutex gen uint64 // agentCancelOpts controls how the stop interacts with the running agent: - // nil → no cancel; the turn runs to completion (bare Stop) + // nil → no cancel intent; the turn runs to completion + // (bare Stop, or UntilIdleFor without cancel opts) // empty → CancelImmediate (WithImmediate) // non-empty → cancel with specific modes (WithGraceful, WithGracefulTimeout) agentCancelOpts []AgentCancelOption @@ -120,9 +121,15 @@ func (s *stopSignal) closeDone() { } // check returns the current generation and a snapshot of the cancel options. +// Returns nil opts when no cancel intent has been set (e.g. UntilIdleFor without +// WithGraceful/WithImmediate), preserving the nil vs empty-slice distinction +// that tryCancel relies on. func (s *stopSignal) check() (uint64, []AgentCancelOption) { s.mu.Lock() defer s.mu.Unlock() + if s.agentCancelOpts == nil { + return s.gen, nil + } return s.gen, append([]AgentCancelOption{}, s.agentCancelOpts...) } @@ -931,16 +938,24 @@ func WithStopCause(cause string) StopOption { // wants to shut down the loop once there has been no work for a while, without // racing with concurrent Push calls. // -// UntilIdleFor is combinable with other StopOptions in the same call. -// For example, Stop(UntilIdleFor(30*time.Second), WithGraceful()) means -// "after 30 s of idle, stop gracefully". If another Stop call is made -// without UntilIdleFor (e.g. Stop(WithImmediate())), the loop shuts down -// immediately, bypassing the idle wait. +// UntilIdleFor does not impact a running agent. It only takes effect when the +// loop is idle between turns. Cancel options (WithImmediate, WithGraceful, +// WithGracefulTimeout) in the same Stop call are silently ignored — they are +// meaningless alongside UntilIdleFor. +// +// To escalate after a prior UntilIdleFor, issue a separate Stop call: +// +// loop.Stop(UntilIdleFor(30 * time.Second)) // wait for idle +// // ... later, if you need to abort immediately: +// loop.Stop(WithImmediate()) // overrides the idle wait // // Only the first UntilIdleFor duration takes effect; subsequent calls with // a different duration are ignored. A Stop() call without UntilIdleFor always // shuts down the loop immediately regardless of any pending idle timer. // +// UntilIdleFor is combinable with non-cancel StopOptions (WithSkipCheckpoint, +// WithStopCause) in the same call. +// // duration must be positive; passing a zero or negative value panics. func UntilIdleFor(duration time.Duration) StopOption { if duration <= 0 { @@ -1281,6 +1296,14 @@ func (l *TurnLoop[T]) Stop(opts ...StopOption) { opt(cfg) } + // UntilIdleFor is incompatible with cancel options (WithImmediate, + // WithGraceful, WithGracefulTimeout) in the same call. Cancel opts only + // make sense for an immediate or escalated stop; UntilIdleFor defers the + // stop until idle, and must not impact a running agent. Drop them silently. + if cfg.idleFor > 0 { + cfg.agentCancelOpts = nil + } + l.stopSig.signal(cfg) if cfg.idleFor > 0 { @@ -1545,7 +1568,7 @@ func (l *TurnLoop[T]) watchStopSignal(done <-chan struct{}, agentCancelFunc Agen return } lastGen = gen - if opts == nil { + if opts == nil { // no cancel intent; see stopSignal.agentCancelOpts return } _, contributed := agentCancelFunc(opts...) diff --git a/adk/turn_loop_test.go b/adk/turn_loop_test.go index 4f22ca1a7..ea7c0aa93 100644 --- a/adk/turn_loop_test.go +++ b/adk/turn_loop_test.go @@ -489,7 +489,7 @@ func TestTurnLoop_Preempt_CancelsCurrentAgent(t *testing.T) { t.Fatal("second GenInput was not called after preempt") } - loop.Stop() + loop.Stop(WithImmediate()) result := loop.Wait() assert.NoError(t, result.ExitReason) assert.GreaterOrEqual(t, atomic.LoadInt32(&genInputCalls), int32(2)) @@ -604,7 +604,7 @@ func TestTurnLoop_Preempt_WithAgentCancelMode(t *testing.T) { t.Fatal("cancelFunc was not called by preempt") } - loop.Stop() + loop.Stop(WithImmediate()) result := loop.Wait() assert.NoError(t, result.ExitReason) cancelModeMu.Lock() @@ -650,7 +650,7 @@ func TestTurnLoop_PreemptAck_ClosesAfterCancelIsInitiated(t *testing.T) { close(agentFinishGate) - loop.Stop() + loop.Stop(WithImmediate()) result := loop.Wait() assert.NoError(t, result.ExitReason) } @@ -732,7 +732,7 @@ func TestTurnLoop_Preempt_EscalatesOnSecondPreempt(t *testing.T) { close(agentFinishGate) - loop.Stop() + loop.Stop(WithImmediate()) result := loop.Wait() assert.NoError(t, result.ExitReason) } @@ -792,7 +792,7 @@ func TestTurnLoop_Preempt_JoinsSafePointModesOnSecondPreempt(t *testing.T) { close(agentFinishGate) - loop.Stop() + loop.Stop(WithImmediate()) result := loop.Wait() assert.NoError(t, result.ExitReason) } @@ -1327,6 +1327,92 @@ func TestTurnLoop_StopDuringAgentExecution(t *testing.T) { assert.Empty(t, result.CanceledItems) } +// TestTurnLoop_BareStop_AgentRunsToCompletion verifies the core contract of +// bare Stop(): the running agent finishes naturally with an uncanceled context, +// the loop exits cleanly (ExitReason == nil), and no new turn starts even when +// additional items are buffered. +func TestTurnLoop_BareStop_AgentRunsToCompletion(t *testing.T) { + const agentWorkDuration = 200 * time.Millisecond + + agentStarted := make(chan struct{}) + agentCtxErr := make(chan error, 1) + agentOutput := make(chan string, 1) + + turnsExecuted := int32(0) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "worker", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + atomic.AddInt32(&turnsExecuted, 1) + close(agentStarted) + + // Simulate real work (NOT blocking on <-ctx.Done()) + time.Sleep(agentWorkDuration) + + // Record context state AFTER work completes + agentCtxErr <- ctx.Err() + agentOutput <- "work-done" + + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + // Push two items so the loop has a reason to start a second turn. + loop.Push("task1") + loop.Push("task2") + + // Wait for the agent to start processing task1. + select { + case <-agentStarted: + case <-time.After(2 * time.Second): + t.Fatal("agent did not start") + } + + // Call bare Stop() while the agent is doing work. + loop.Stop() + + result := loop.Wait() + + // 1. Agent's context was NOT canceled. + select { + case err := <-agentCtxErr: + assert.NoError(t, err, "bare Stop must not cancel the agent's context") + default: + t.Fatal("agent never reported context state") + } + + // 2. Agent completed its work. + select { + case out := <-agentOutput: + assert.Equal(t, "work-done", out) + default: + t.Fatal("agent never produced output") + } + + // 3. ExitReason is nil (clean exit, not a CancelError). + assert.NoError(t, result.ExitReason) + + // 4. CanceledItems is empty (agent was not canceled). + assert.Empty(t, result.CanceledItems) + + // 5. Only one turn executed; the second item is unhandled. + assert.Equal(t, int32(1), atomic.LoadInt32(&turnsExecuted), + "bare Stop must prevent new turns from starting after the current one completes") + assert.Equal(t, []string{"task2"}, result.UnhandledItems, + "the second item should appear in UnhandledItems") +} + func TestTurnLoop_StopCheckPointIDInCancelError(t *testing.T) { ctx := context.Background() modelStarted := make(chan struct{}, 1) @@ -1368,7 +1454,7 @@ func TestTurnLoop_StopCheckPointIDInCancelError(t *testing.T) { loop.Push("msg1") <-modelStarted - loop.Stop() + loop.Stop(WithImmediate()) result := loop.Wait() @@ -1420,7 +1506,7 @@ func TestTurnLoop_StopWithoutCheckpointIDDoesNotPersist(t *testing.T) { loop.Push("msg1") <-modelStarted - loop.Stop() + loop.Stop(WithImmediate()) result := loop.Wait() @@ -1563,7 +1649,7 @@ func TestTurnLoop_StopDuringAgentExecution_PersistAndResume(t *testing.T) { loop.Push("msg1") <-modelStarted - loop.Stop() + loop.Stop(WithImmediate()) exit := loop.Wait() store.mu.Lock() @@ -1798,7 +1884,7 @@ func TestTurnLoop_CheckpointSaveError_ReturnsError(t *testing.T) { }) loop.Push("msg1") <-modelStarted - loop.Stop() + loop.Stop(WithImmediate()) exit := loop.Wait() assert.Error(t, exit.ExitReason) assert.True(t, exit.Checkpointed) @@ -2040,7 +2126,7 @@ func TestTurnLoop_GenResumeNil_Error(t *testing.T) { }) loop1.Push("msg1") <-modelStarted - loop1.Stop() + loop1.Stop(WithImmediate()) loop1.Wait() loop2 := NewTurnLoop(TurnLoopConfig[string]{ @@ -2212,7 +2298,7 @@ func TestTurnLoop_GenResumeReturnsError(t *testing.T) { }) loop1.Push("msg1") <-modelStarted - loop1.Stop() + loop1.Stop(WithImmediate()) loop1.Wait() genResumeErr := fmt.Errorf("resume callback failed") @@ -2271,7 +2357,7 @@ func TestTurnLoop_CheckpointSaveError_MergesWithExistingError(t *testing.T) { }) loop.Push("msg1") <-modelStarted - loop.Stop() + loop.Stop(WithImmediate()) exit := loop.Wait() assert.Error(t, exit.ExitReason) var ce *CancelError @@ -2319,7 +2405,7 @@ func TestTurnLoop_ResumeWithParams(t *testing.T) { }) loop1.Push("msg1") <-modelStarted - loop1.Stop() + loop1.Stop(WithImmediate()) exit1 := loop1.Wait() var ce *CancelError assert.True(t, errors.As(exit1.ExitReason, &ce)) @@ -3032,11 +3118,11 @@ func TestTurnLoop_ConcurrentPreemptsDuringTurn(t *testing.T) { wg.Add(1) go func(i int) { defer wg.Done() - ok, ack := loop.Push(fmt.Sprintf("urgent-%d", i), WithPreemptTimeout[string](AnySafePoint, time.Millisecond)) + ok, ack := loop.Push(fmt.Sprintf("urgent-%d", i), WithPreemptTimeout[string](AnySafePoint, 10*time.Millisecond)) if ok && ack != nil { select { case <-ack: - case <-time.After(5 * time.Second): + case <-time.After(30 * time.Second): t.Error("ack channel not closed within timeout") } } @@ -3046,7 +3132,7 @@ func TestTurnLoop_ConcurrentPreemptsDuringTurn(t *testing.T) { wg.Wait() time.Sleep(200 * time.Millisecond) - loop.Stop() + loop.Stop(WithImmediate()) result := loop.Wait() assert.NoError(t, result.ExitReason) assert.True(t, atomic.LoadInt32(&genInputCount) >= 2, "should have had at least the initial turn + one preempted turn") @@ -3235,7 +3321,7 @@ func TestTurnLoop_ConcurrentPreemptAndStop(t *testing.T) { go func() { defer wg.Done() - loop.Stop() + loop.Stop(WithImmediate()) }() wg.Wait() @@ -3298,7 +3384,7 @@ func TestTurnLoop_ConcurrentPushStrategyAndStop(t *testing.T) { go func() { defer wg.Done() - loop.Stop() + loop.Stop(WithImmediate()) }() wg.Wait() @@ -3306,6 +3392,7 @@ func TestTurnLoop_ConcurrentPushStrategyAndStop(t *testing.T) { }) } } + func TestTurnLoop_TurnContext_StoppedChannel(t *testing.T) { stoppedSeen := make(chan struct{}) agentStarted := make(chan struct{}) @@ -3346,7 +3433,7 @@ func TestTurnLoop_TurnContext_StoppedChannel(t *testing.T) { loop.Push("msg1") <-agentStarted - loop.Stop() + loop.Stop(WithImmediate()) select { case <-stoppedSeen: @@ -3534,7 +3621,7 @@ func TestTurnLoop_PushStrategy_DuringTurn(t *testing.T) { t.Fatal("second GenInput was not called after preempt") } - loop.Stop() + loop.Stop(WithImmediate()) loop.Wait() assert.Equal(t, int32(1), atomic.LoadInt32(&strategyCalled)) @@ -3780,7 +3867,7 @@ func TestTurnLoop_PushStrategy_ConsumedInspection(t *testing.T) { t.Fatal("second GenInput was not called after strategy-driven preempt") } - loop.Stop() + loop.Stop(WithImmediate()) loop.Wait() } @@ -4268,7 +4355,7 @@ func TestTurnLoop_StopCause_InTurnContext(t *testing.T) { loop.Push("msg1") <-agentStarted - loop.Stop(WithStopCause(cause)) + loop.Stop(WithImmediate(), WithStopCause(cause)) select { case c := <-gotCause: @@ -4631,6 +4718,669 @@ func TestUntilIdleFor(t *testing.T) { }) } +// TestUntilIdleFor_DoesNotCancelRunningAgent verifies that Stop(UntilIdleFor) +// does NOT cancel a running agent. The notify signal from UntilIdleFor must not +// be misinterpreted as a cancel request by watchStopSignal. This is a regression +// test for a bug where stopSignal.check() converted nil agentCancelOpts to a +// non-nil empty slice, which tryCancel treated as CancelImmediate. +func TestUntilIdleFor_DoesNotCancelRunningAgent(t *testing.T) { + t.Run("BeforeRun", func(t *testing.T) { + agentStarted := make(chan struct{}) + agentCtxCanceled := int32(0) + agentDone := make(chan struct{}) + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(agentStarted) + // Block until context is canceled or a short timeout. + select { + case <-ctx.Done(): + atomic.StoreInt32(&agentCtxCanceled, 1) + case <-time.After(200 * time.Millisecond): + } + close(agentDone) + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + // Call Stop(UntilIdleFor) BEFORE Run. + loop.Stop(UntilIdleFor(50 * time.Millisecond)) + loop.Run(context.Background()) + + <-agentStarted + <-agentDone + + exit := loop.Wait() + assert.Nil(t, exit.ExitReason, "UntilIdleFor should not produce a CancelError") + assert.Equal(t, int32(0), atomic.LoadInt32(&agentCtxCanceled), + "agent context should not have been canceled by UntilIdleFor") + }) + + t.Run("DuringRun", func(t *testing.T) { + agentStarted := make(chan struct{}) + agentCtxCanceled := int32(0) + agentDone := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(agentStarted) + select { + case <-ctx.Done(): + atomic.StoreInt32(&agentCtxCanceled, 1) + case <-time.After(200 * time.Millisecond): + } + close(agentDone) + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-agentStarted + + // Call Stop(UntilIdleFor) while the agent is running. + loop.Stop(UntilIdleFor(50 * time.Millisecond)) + <-agentDone + + exit := loop.Wait() + assert.Nil(t, exit.ExitReason, "UntilIdleFor should not produce a CancelError") + assert.Equal(t, int32(0), atomic.LoadInt32(&agentCtxCanceled), + "agent context should not have been canceled by UntilIdleFor") + }) + + // Cancel opts paired with UntilIdleFor in the same call are silently + // dropped. The agent must run to completion even when WithImmediate is + // combined with UntilIdleFor. + t.Run("CancelOptsDroppedInSameCall", func(t *testing.T) { + agentStarted := make(chan struct{}) + agentCtxCanceled := int32(0) + agentDone := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(agentStarted) + select { + case <-ctx.Done(): + atomic.StoreInt32(&agentCtxCanceled, 1) + case <-time.After(200 * time.Millisecond): + } + close(agentDone) + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-agentStarted + + // WithImmediate in the same call as UntilIdleFor must be ignored. + loop.Stop(UntilIdleFor(50*time.Millisecond), WithImmediate()) + <-agentDone + + exit := loop.Wait() + assert.Nil(t, exit.ExitReason, "cancel opts should be dropped when combined with UntilIdleFor") + assert.Equal(t, int32(0), atomic.LoadInt32(&agentCtxCanceled), + "agent context should not have been canceled") + }) +} + +func TestUntilIdleFor_ContextCancelDuringIdleWait(t *testing.T) { + turnDone := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(turnDone) + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-turnDone + + // Start idle timer, then cancel the parent context while idle. + loop.Stop(UntilIdleFor(10 * time.Minute)) + time.Sleep(20 * time.Millisecond) + cancel() + + done := make(chan struct{}) + go func() { + loop.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("loop should exit when context is canceled during idle wait") + } + + exit := loop.Wait() + assert.ErrorIs(t, exit.ExitReason, context.Canceled) +} + +// TestStopSignalCheck_NilPreservedUnderConcurrentSignals hammers +// stopSignal.check() and signal() concurrently to verify that the nil guard +// in check() does not race with signal(). The race detector should catch any +// unsynchronised access. +func TestStopSignalCheck_NilPreservedUnderConcurrentSignals(t *testing.T) { + sig := newStopSignal() + + const goroutines = 20 + const iterations = 200 + + var wg sync.WaitGroup + + // Half the goroutines call signal() with UntilIdleFor-style config (nil agentCancelOpts). + for i := 0; i < goroutines/2; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < iterations; j++ { + // UntilIdleFor produces nil agentCancelOpts after Stop() forces it. + sig.signal(&stopConfig{idleFor: 100 * time.Millisecond}) + } + }() + } + + // The other half call signal() with WithImmediate-style config (non-nil empty opts). + for i := 0; i < goroutines/2; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < iterations; j++ { + sig.signal(&stopConfig{agentCancelOpts: []AgentCancelOption{}}) + } + }() + } + + // Concurrently read check() — the nil guard must be race-free. + sawNil := int32(0) + sawNonNil := int32(0) + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < iterations; j++ { + _, opts := sig.check() + if opts == nil { + atomic.AddInt32(&sawNil, 1) + } else { + atomic.AddInt32(&sawNonNil, 1) + } + } + }() + } + + wg.Wait() + + // We expect both nil and non-nil snapshots to have been observed, since + // signal() alternates between the two modes concurrently. + t.Logf("sawNil=%d sawNonNil=%d", atomic.LoadInt32(&sawNil), atomic.LoadInt32(&sawNonNil)) + // Main point: no race detector failure. The counts are non-deterministic. +} + +func TestAttack_UntilIdleFor_ConcurrentPushDuringIdleTimer(t *testing.T) { + turnCount := int32(0) + turnDone := make(chan struct{}, 10) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + atomic.AddInt32(&turnCount, 1) + turnDone <- struct{}{} + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-turnDone + + loop.Stop(UntilIdleFor(200 * time.Millisecond)) + + for i := 0; i < 5; i++ { + time.Sleep(50 * time.Millisecond) + loop.Push("concurrent-" + string(rune('a'+i))) + <-turnDone + } + + done := make(chan struct{}) + go func() { + loop.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("loop did not exit after idle timeout — Push did not reset timer correctly") + } + + finalCount := atomic.LoadInt32(&turnCount) + assert.Equal(t, int32(6), finalCount, "all 6 pushes should have been processed") +} + +func TestAttack_UntilIdleFor_MultipleStopCallsFirstWins(t *testing.T) { + turnDone := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(turnDone) + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-turnDone + + loop.Stop(UntilIdleFor(100 * time.Millisecond)) + loop.Stop(UntilIdleFor(10 * time.Minute)) + + done := make(chan struct{}) + go func() { + loop.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("second UntilIdleFor should have been ignored; loop should have exited with 100ms timer") + } +} + +func TestAttack_BareStopOverridesUntilIdleFor(t *testing.T) { + agentStarted := make(chan struct{}) + agentDone := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(agentStarted) + <-agentDone + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-agentStarted + + loop.Stop(UntilIdleFor(10 * time.Minute)) + + loop.Stop() + close(agentDone) + + done := make(chan struct{}) + go func() { + loop.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("bare Stop should override UntilIdleFor and cause immediate shutdown") + } + + exit := loop.Wait() + assert.NoError(t, exit.ExitReason, "bare Stop should exit cleanly") +} + +func TestAttack_StopSignal_NilCancelOptsDoNotDeescalate(t *testing.T) { + agentStarted := make(chan *cancelContext, 1) + probe := &turnLoopStopModeProbeAgent{ccCh: agentStarted} + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return probe, nil + }, + }) + + loop.Push("msg1") + cc := <-agentStarted + + loop.Stop(WithImmediate()) + + time.Sleep(20 * time.Millisecond) + + loop.Stop() + + time.Sleep(20 * time.Millisecond) + mode := cc.getMode() + assert.Equal(t, CancelImmediate, mode, "bare Stop after WithImmediate must not de-escalate cancel mode") + + exit := loop.Wait() + var ce *CancelError + require.True(t, errors.As(exit.ExitReason, &ce)) + assert.Equal(t, CancelImmediate, ce.Info.Mode) +} + +func TestAttack_CanceledItems_EmptyWhenAgentFinishesNormally(t *testing.T) { + agentStarted := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(agentStarted) + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-agentStarted + time.Sleep(50 * time.Millisecond) + loop.Stop() + + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + assert.Empty(t, exit.CanceledItems, "CanceledItems must be empty when agent finished normally") +} + +func TestAttack_TurnBuffer_WakeupDoesNotLoseItems(t *testing.T) { + tb := newTurnBuffer[string]() + + tb.Send("a") + tb.Send("b") + tb.Wakeup() + tb.Send("c") + + var got []string + for i := 0; i < 3; i++ { + val, ok := tb.Receive() + require.True(t, ok) + got = append(got, val) + } + + assert.Equal(t, []string{"a", "b", "c"}, got, "Wakeup must not cause items to be lost") +} + +func TestAttack_TurnBuffer_ClearWakeupPreventsSpuriousReturn(t *testing.T) { + tb := newTurnBuffer[string]() + + tb.Wakeup() + tb.ClearWakeup() + + received := make(chan string, 1) + go func() { + val, ok := tb.Receive() + if ok { + received <- val + } + }() + + time.Sleep(50 * time.Millisecond) + tb.Send("real") + + select { + case val := <-received: + assert.Equal(t, "real", val, "ClearWakeup should prevent spurious empty return") + case <-time.After(2 * time.Second): + t.Fatal("Receive blocked forever despite Send") + } +} + +func TestAttack_StopBeforeRun_UntilIdleFor_ExitsImmediately(t *testing.T) { + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Stop(UntilIdleFor(10 * time.Minute)) + loop.Stop() + + loop.Run(context.Background()) + + done := make(chan struct{}) + go func() { + loop.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("loop should exit immediately when Stop() called before Run()") + } +} + +func TestAttack_PushAfterStop_UntilIdleFor_RoutedToLateItems(t *testing.T) { + turnDone := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(turnDone) + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-turnDone + + loop.Stop(UntilIdleFor(50 * time.Millisecond)) + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + + ok, _ := loop.Push("after-stop") + assert.False(t, ok, "Push after loop exited should return false") + + late := exit.TakeLateItems() + assert.Equal(t, []string{"after-stop"}, late) +} + +func TestAttack_ConcurrentStopEscalation_RaceDetector(t *testing.T) { + agentStarted := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(agentStarted) + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + }) + + loop.Push("msg1") + <-agentStarted + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + switch i % 4 { + case 0: + loop.Stop() + case 1: + loop.Stop(WithImmediate()) + case 2: + loop.Stop(WithGracefulTimeout(100 * time.Millisecond)) + case 3: + loop.Stop(UntilIdleFor(50 * time.Millisecond)) + } + }(i) + } + + wg.Wait() + exit := loop.Wait() + t.Log("ExitReason:", exit.ExitReason) +} + +func TestAttack_StopCause_FirstNonEmptyWins_ConcurrentCallers(t *testing.T) { + turnDone := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(turnDone) + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-turnDone + + loop.Stop(WithStopCause("first-cause")) + loop.Stop(WithStopCause("second-cause")) + + exit := loop.Wait() + assert.Equal(t, "first-cause", exit.StopCause, "first non-empty StopCause should win") +} + +func TestAttack_SkipCheckpoint_Sticky(t *testing.T) { + agentStarted := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(agentStarted) + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + Store: &turnLoopCheckpointStore{m: make(map[string][]byte)}, + CheckpointID: "test-sticky", + }) + + loop.Push("msg1") + <-agentStarted + + loop.Stop(WithSkipCheckpoint()) + loop.Stop(WithImmediate()) + + exit := loop.Wait() + assert.False(t, exit.Checkpointed, "SkipCheckpoint is sticky; checkpoint should be skipped") +} + func TestUntilIdleFor_NonPositive_Panics(t *testing.T) { assert.PanicsWithValue(t, "adk: UntilIdleFor: duration must be positive", func() { UntilIdleFor(0) }) From a4460f37ce1e35870ce7537fcc5de96116fd7fb0 Mon Sep 17 00:00:00 2001 From: shentongmartin Date: Mon, 20 Apr 2026 19:56:28 +0800 Subject: [PATCH 58/65] refactor(adk): improve cancel API naming, enforce recursive teardown, and clean up internal terminology (#979) --- adk/cancel.go | 45 ++--- adk/cancel_edge_test.go | 176 ++++++++++++++++-- adk/cancel_test.go | 166 +++++++++++++++-- adk/chatmodel.go | 2 +- adk/prebuilt/planexecute/plan_execute_test.go | 2 +- adk/runner.go | 1 - adk/turn_loop.go | 108 ++++++----- adk/turn_loop_test.go | 167 +++++++++++++++-- 8 files changed, 548 insertions(+), 119 deletions(-) diff --git a/adk/cancel.go b/adk/cancel.go index 6d4aa9ad9..513b0cf43 100644 --- a/adk/cancel.go +++ b/adk/cancel.go @@ -56,8 +56,8 @@ const ( // Use WithRecursive to propagate the cancel to all descendants — whichever // ChatModel finishes first triggers the cancel. CancelAfterChatModel CancelMode = 1 << iota - // CancelAfterToolCalls cancels after the root agent's next set of concurrent - // tool calls completes. By default, only the root agent checks this safe-point. + // CancelAfterToolCalls cancels after the root agent's next set of tool calls + // completes. By default, only the root agent checks this safe-point. // Use WithRecursive to propagate to all descendants. CancelAfterToolCalls ) @@ -75,7 +75,7 @@ type CancelHandle struct { // was absorbed into CancelError while cancellation was active // - ErrCancelTimeout: the requested safe-point cancellation timed out and was // escalated to immediate cancellation -// - ErrExecutionCompleted: the execution finished before cancellation took effect, +// - ErrExecutionEnded: the execution ended before cancellation took effect, // meaning the stream drained to completion without any interrupt func (h *CancelHandle) Wait() error { return h.wait() @@ -110,9 +110,9 @@ func WithAgentCancelMode(mode CancelMode) AgentCancelOption { // WithAgentCancelTimeout sets a timeout for the cancel operation. // This only applies to safe-point modes (CancelAfterChatModel, CancelAfterToolCalls): // if the safe-point hasn't fired within this duration, the cancel escalates to -// an immediate graph interrupt. -// For CancelImmediate this timeout is ignored — the graph interrupt fires -// immediately with timeout=0. +// CancelImmediate. The escalated cancel still saves a checkpoint, so the execution +// can be resumed via Runner.Resume or Runner.ResumeWithParams. +// For CancelImmediate this timeout is ignored — the cancel fires immediately. func WithAgentCancelTimeout(timeout time.Duration) AgentCancelOption { return func(config *agentCancelConfig) { config.Timeout = &timeout @@ -126,6 +126,11 @@ func WithAgentCancelTimeout(timeout time.Duration) AgentCancelOption { // - CancelImmediate: descendants receive explicit immediate-cancel signals for // clean teardown; the root uses a grace period to collect child interrupts. // +// With recursive cancellation, each descendant agent also triggers cancellation +// and cascades its interrupt information upward. The root agent ultimately +// produces a complete checkpoint that includes descendant checkpoints, enabling +// resumption from the exact point where each descendant was interrupted. +// // Once any cancel call includes WithRecursive, the flag stays set for the // entire cancel lifecycle (monotonic escalation). func WithRecursive() AgentCancelOption { @@ -146,7 +151,7 @@ type AgentCancelInfo struct { // // Interrupt absorption: when a cancel is active (shouldCancel() == true), ANY // interrupt — whether from a cancel safe-point node or from business logic -// (e.g. compose.Interrupt in a tool) — is converted to a CancelError. The +// (e.g. tool.Interrupt in a tool) — is converted to a CancelError. The // cancel "absorbs" the business interrupt. This is intentional: // // - In concurrent execution (parallel workflows, concurrent tool calls), @@ -155,16 +160,12 @@ type AgentCancelInfo struct { // - Even in sequential execution, treating business interrupts as CancelError // during active cancel gives consistent semantics. // - The business interrupt is NOT lost — the checkpoint preserves the full -// interrupt hierarchy. On resume (Runner.Resume), the agent re-executes -// the interrupting code path and the business interrupt re-fires naturally. +// interrupt hierarchy. On resume (Runner.Resume or Runner.ResumeWithParams), +// the agent re-executes the interrupting code path and the business +// interrupt re-fires naturally. type CancelError struct { Info *AgentCancelInfo - // CheckPointID is the checkpoint ID associated with this cancel operation. - // When non-empty, the cancelled agent's state has been persisted under this ID - // and can be resumed via Runner.Resume or GenInputResult.ResumeFromCheckpointID. - CheckPointID string - // InterruptContexts provides the interrupt contexts needed for targeted // resumption via Runner.ResumeWithParams. Each context represents a step // in the agent hierarchy that was interrupted. This is a slice because @@ -185,15 +186,15 @@ var ( // ErrCancelTimeout is returned by CancelHandle.Wait when the cancel operation timed out. ErrCancelTimeout = errors.New("cancel timed out") - // ErrExecutionCompleted is returned by CancelHandle.Wait when the agent finished - // before the cancel took effect. "Finished" means the event stream was fully + // ErrExecutionEnded is returned by CancelHandle.Wait when the agent ended + // before the cancel took effect. "Ended" means the event stream was fully // drained without any interrupt — normal completion or a fatal error. // // Note: business interrupts that occur while cancel is active are absorbed // into CancelError (see CancelError doc), so they result in nil (cancel - // succeeded), NOT ErrExecutionCompleted. Only execution that completes with + // succeeded), NOT ErrExecutionEnded. Only execution that completes with // no interrupt at all produces this error. - ErrExecutionCompleted = errors.New("execution already completed") + ErrExecutionEnded = errors.New("execution already ended") // ErrStreamCanceled is the error sent through the stream when CancelImmediate aborts it. // It is a *StreamCanceledError so it can be gob-serialized during checkpoint save @@ -699,7 +700,7 @@ func (cc *cancelContext) buildCancelFunc() AgentCancelFunc { st := atomic.LoadInt32(&cc.state) switch st { case stateDone: - return ErrExecutionCompleted + return ErrExecutionEnded default: if atomic.LoadInt32(&cc.timeoutEscalated) == 1 { return ErrCancelTimeout @@ -716,7 +717,7 @@ func (cc *cancelContext) buildCancelFunc() AgentCancelFunc { case stateCancelHandled: return newHandle(func() error { return nil }), false case stateDone: - return newHandle(func() error { return ErrExecutionCompleted }), false + return newHandle(func() error { return ErrExecutionEnded }), false } var needImmediate, needTimeoutCtl bool @@ -730,7 +731,7 @@ func (cc *cancelContext) buildCancelFunc() AgentCancelFunc { return newHandle(func() error { return nil }), false case stateDone: cc.cancelMu.Unlock() - return newHandle(func() error { return ErrExecutionCompleted }), false + return newHandle(func() error { return ErrExecutionEnded }), false } curMode := cc.getMode() @@ -739,7 +740,7 @@ func (cc *cancelContext) buildCancelFunc() AgentCancelFunc { st = atomic.LoadInt32(&cc.state) cc.cancelMu.Unlock() if st == stateDone { - return newHandle(func() error { return ErrExecutionCompleted }), false + return newHandle(func() error { return ErrExecutionEnded }), false } return newHandle(waitForCompletion), true } diff --git a/adk/cancel_edge_test.go b/adk/cancel_edge_test.go index b0afbe674..248a84ee0 100644 --- a/adk/cancel_edge_test.go +++ b/adk/cancel_edge_test.go @@ -218,7 +218,7 @@ func TestWithCancel_BeforeExecutionStarts(t *testing.T) { // cancelFn must have already returned (or return quickly now that doneChan is closed). select { case cancelErr := <-cancelDone: - // Either nil (cancel handled) or ErrExecutionCompleted is acceptable + // Either nil (cancel handled) or ErrExecutionEnded is acceptable // depending on exact timing; what matters is it didn't hang. _ = cancelErr case <-time.After(3 * time.Second): @@ -229,7 +229,7 @@ func TestWithCancel_BeforeExecutionStarts(t *testing.T) { assert.Equal(t, int32(0), atomic.LoadInt32(&bt.callCount), "tool must not be called") } -// TestWithCancel_AfterCompletion verifies cancelFn returns ErrExecutionCompleted +// TestWithCancel_AfterCompletion verifies cancelFn returns ErrExecutionEnded // when called after a normal run finishes. func TestWithCancel_AfterCompletion(t *testing.T) { ctx := context.Background() @@ -254,10 +254,10 @@ func TestWithCancel_AfterCompletion(t *testing.T) { handle, _ := cancelFn() cancelErr := handle.Wait() - assert.ErrorIs(t, cancelErr, ErrExecutionCompleted) + assert.ErrorIs(t, cancelErr, ErrExecutionEnded) } -// TestWithCancel_AfterBusinessInterrupt verifies cancelFn returns ErrExecutionCompleted +// TestWithCancel_AfterBusinessInterrupt verifies cancelFn returns ErrExecutionEnded // when called after the agent has been interrupted by business logic. func TestWithCancel_AfterBusinessInterrupt(t *testing.T) { ctx := context.Background() @@ -296,10 +296,10 @@ func TestWithCancel_AfterBusinessInterrupt(t *testing.T) { handle, _ := cancelFn() cancelErr := handle.Wait() - assert.ErrorIs(t, cancelErr, ErrExecutionCompleted) + assert.ErrorIs(t, cancelErr, ErrExecutionEnded) } -// TestWithCancel_AfterError verifies cancelFn returns ErrExecutionCompleted +// TestWithCancel_AfterError verifies cancelFn returns ErrExecutionEnded // when called after the agent errors out. func TestWithCancel_AfterError(t *testing.T) { ctx := context.Background() @@ -324,7 +324,7 @@ func TestWithCancel_AfterError(t *testing.T) { handle, _ := cancelFn() cancelErr := handle.Wait() - assert.ErrorIs(t, cancelErr, ErrExecutionCompleted) + assert.ErrorIs(t, cancelErr, ErrExecutionEnded) } // TestWithCancel_TimeoutEscalation tests that WithAgentCancelTimeout causes the @@ -436,7 +436,9 @@ func TestWithCancel_AfterChatModel_WithTools(t *testing.T) { } // TestWithCancel_CancelImmediate_StreamAborted verifies that CancelImmediate -// during model streaming surfaces ErrStreamCanceled and completes quickly. +// during model execution surfaces CancelError and completes quickly. +// Uses blockingChatModel which blocks in Stream(), keeping the agent's run +// function alive so the cancel context stays in stateRunning. func TestWithCancel_CancelImmediate_StreamAborted(t *testing.T) { ctx := context.Background() @@ -471,21 +473,21 @@ func TestWithCancel_CancelImmediate_StreamAborted(t *testing.T) { elapsed := time.Since(start) assert.True(t, elapsed < 2*time.Second, "cancel should complete quickly, elapsed=%v", elapsed) - var foundStreamCanceled bool + var foundCancelError bool for { e, ok := iter.Next() if !ok { break } - if e.Err != nil && errors.Is(e.Err, ErrStreamCanceled) { - foundStreamCanceled = true + if e.Action != nil && e.Action.Interrupted != nil { + foundCancelError = true } var ce *CancelError if e.Err != nil && errors.As(e.Err, &ce) { - foundStreamCanceled = true // CancelError wraps stream abort + foundCancelError = true } } - assert.True(t, foundStreamCanceled, "expected stream-abort error during immediate cancel") + assert.True(t, foundCancelError, "expected CancelError in event stream") } // TestWithCancel_MultipleToolsConcurrent verifies that CancelAfterToolCalls @@ -630,13 +632,11 @@ func TestWithCancel_NoCheckpointStore(t *testing.T) { break } } - if assert.NotNil(t, ce, "expected CancelError even without checkpoint store") { - assert.Empty(t, ce.CheckPointID, "CheckPointID should be empty without checkpoint store") - } + assert.NotNil(t, ce, "expected CancelError even without checkpoint store") } // TestWithCancel_ModelError verifies that a model error marks the cancelCtx as -// done so that a subsequent cancelFn call returns ErrExecutionCompleted. +// done so that a subsequent cancelFn call returns ErrExecutionEnded. func TestWithCancel_ModelError(t *testing.T) { ctx := context.Background() @@ -665,7 +665,7 @@ func TestWithCancel_ModelError(t *testing.T) { handle, _ := cancelFn() cancelErr := handle.Wait() - assert.ErrorIs(t, cancelErr, ErrExecutionCompleted, "cancelFn should return ErrExecutionCompleted after model error") + assert.ErrorIs(t, cancelErr, ErrExecutionEnded, "cancelFn should return ErrExecutionEnded after model error") } // TestWithCancel_Resume_SafePoint covers CancelAfterChatModel and @@ -1266,3 +1266,143 @@ func TestWithCancel_CancelAfterChatModel_NestedAgentTool(t *testing.T) { assert.True(t, hasCancelError, "CancelError expected from nested agent tool with tools") } + +// slowStreamingTool implements StreamableTool (but NOT InvokableTool), streaming +// chunks slowly so CancelImmediate can fire mid-stream. +type slowStreamingTool struct { + name string + chunkInterval time.Duration + chunks []string + started chan struct{} +} + +func (t *slowStreamingTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "slow streaming tool", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Type: "string"}, + }), + }, nil +} + +func (t *slowStreamingTool) StreamableRun(_ context.Context, _ string, _ ...tool.Option) (*schema.StreamReader[string], error) { + r, w := schema.Pipe[string](1) + go func() { + defer w.Close() + select { + case t.started <- struct{}{}: + default: + } + for _, chunk := range t.chunks { + time.Sleep(t.chunkInterval) + if closed := w.Send(chunk, nil); closed { + return + } + } + }() + return r, nil +} + +// toolCallStreamModel returns a tool-call message on the first Stream call, +// then a plain text response on subsequent calls. +type toolCallStreamModel struct { + callCount int32 +} + +func (m *toolCallStreamModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + if atomic.AddInt32(&m.callCount, 1) == 1 { + return toolCallMsg(toolCall("c1", "slow_tool", `{"input":"x"}`)), nil + } + return schema.AssistantMessage("done", nil), nil +} + +func (m *toolCallStreamModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + msg, err := m.Generate(ctx, input, opts...) + if err != nil { + return nil, err + } + return schema.StreamReaderFromArray([]*schema.Message{msg}), nil +} + +func (m *toolCallStreamModel) BindTools(_ []*schema.ToolInfo) error { return nil } + +// TestWithCancel_CancelImmediate_StreamableToolAborted verifies that CancelImmediate +// during StreamableTool streaming surfaces ErrStreamCanceled on the tool's +// MessageStream.Recv(), just like it does for ChatModel streaming. +func TestWithCancel_CancelImmediate_StreamableToolAborted(t *testing.T) { + ctx := context.Background() + + tcm := &toolCallStreamModel{} + st := &slowStreamingTool{ + name: "slow_tool", + chunkInterval: 200 * time.Millisecond, + chunks: []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"}, + started: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: tcm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{st}}, + }, + }) + require.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: true, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hi")}, cancelOpt) + + // Wait for the tool to start streaming + select { + case <-st.started: + case <-time.After(5 * time.Second): + t.Fatal("tool did not start streaming") + } + // Let a few chunks through, then cancel mid-stream + time.Sleep(300 * time.Millisecond) + + handle, _ := cancelFn() + cancelErr := handle.Wait() + assert.NoError(t, cancelErr) + + var foundStreamCanceled bool + var foundCancelError bool + for { + e, ok := iter.Next() + if !ok { + break + } + + // ErrStreamCanceled appears on the tool's MessageStream.Recv() + if e.Output != nil && e.Output.MessageOutput != nil && e.Output.MessageOutput.IsStreaming && + e.Output.MessageOutput.Role == schema.Tool { + stream := e.Output.MessageOutput.MessageStream + for { + _, recvErr := stream.Recv() + if recvErr != nil { + if errors.Is(recvErr, ErrStreamCanceled) { + foundStreamCanceled = true + } + break + } + } + } + + if e.Action != nil && e.Action.Interrupted != nil { + foundCancelError = true + } + var ce *CancelError + if e.Err != nil && errors.As(e.Err, &ce) { + foundCancelError = true + } + } + assert.True(t, foundStreamCanceled, "expected ErrStreamCanceled on tool's MessageStream.Recv()") + assert.True(t, foundCancelError, "expected CancelError in event stream") +} diff --git a/adk/cancel_test.go b/adk/cancel_test.go index 2096a9ac3..97779827b 100644 --- a/adk/cancel_test.go +++ b/adk/cancel_test.go @@ -650,7 +650,6 @@ func TestWithCancel_WithCheckpoint(t *testing.T) { var events []*AgentEvent hasCancelError := false - var cancelErrorCheckPointID string for { event, ok := iter.Next() if !ok { @@ -659,14 +658,12 @@ func TestWithCancel_WithCheckpoint(t *testing.T) { var ce *CancelError if event.Err != nil && errors.As(event.Err, &ce) { hasCancelError = true - cancelErrorCheckPointID = ce.CheckPointID continue } events = append(events, event) } assert.True(t, hasCancelError, "Should have CancelError event after cancel") - assert.Equal(t, "cancel-1", cancelErrorCheckPointID, "CancelError should contain the checkpoint ID") }) } @@ -1132,7 +1129,7 @@ func TestWithCancel_Resume(t *testing.T) { cancelHandle, _ := resumeCancelFn() close(slowModel2.unblockCh) err = cancelHandle.Wait() - assert.True(t, err == nil || errors.Is(err, ErrExecutionCompleted), "unexpected cancel wait error: %v", err) + assert.True(t, err == nil || errors.Is(err, ErrExecutionEnded), "unexpected cancel wait error: %v", err) start := time.Now() resumeEvents := <-resumeEventsCh @@ -1148,7 +1145,7 @@ func TestWithCancel_Resume(t *testing.T) { hasCancelError = true } } - executionCompletedBeforeCancel := errors.Is(err, ErrExecutionCompleted) + executionCompletedBeforeCancel := errors.Is(err, ErrExecutionEnded) assert.True(t, hasCancelError || executionCompletedBeforeCancel, "Resume should have CancelError event after cancel, or execution completed before cancel") }) } @@ -1518,10 +1515,10 @@ func TestWithCancel_SequentialAgent(t *testing.T) { time.Sleep(50 * time.Millisecond) - // Cancel should NOT return ErrExecutionCompleted (the bug before the fix) + // Cancel should NOT return ErrExecutionEnded (the bug before the fix) handle, _ := cancelFn() err = handle.Wait() - assert.NoError(t, err, "Cancel during second agent should succeed, not return ErrExecutionCompleted") + assert.NoError(t, err, "Cancel during second agent should succeed, not return ErrExecutionEnded") drainEventsAndAssertCancelError(t, iter) }) @@ -2254,7 +2251,6 @@ func TestCancel_SequentialWorkflow_CancelAfterChatModel(t *testing.T) { assert.True(t, hasCancelError, "Should have CancelError") assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) - assert.NotEmpty(t, cancelErr.CheckPointID, "CancelError should have checkpoint ID") assert.NotNil(t, cancelErr.interruptSignal, "CancelError should have interrupt signal for checkpoint") resumeAgent1 := newCancelTestAgentWithToolsFinalAnswer(t, "seq_slow") @@ -2614,7 +2610,6 @@ func TestCancelImmediate_AgentTool_PreservesChildCheckpoint(t *testing.T) { cancelErr := drainCancelError(t, iter) assert.NotNil(t, cancelErr, "Should have CancelError from CancelImmediate through agentTool") - assert.NotEmpty(t, cancelErr.CheckPointID) assert.NotNil(t, cancelErr.interruptSignal) resumeLeaf, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ @@ -2809,7 +2804,6 @@ func TestCancelImmediate_MultiLevelNesting(t *testing.T) { elapsed := time.Since(start) assert.NotNil(t, cancelErr, "Should have CancelError from multi-level nesting") - assert.NotEmpty(t, cancelErr.CheckPointID) assert.NotNil(t, cancelErr.interruptSignal) assert.True(t, elapsed < 5*time.Second, "Should complete quickly, elapsed: %v", elapsed) @@ -3058,7 +3052,6 @@ func TestCancelAfterChatModel_Sequential_Agent1CompletesCancelBeforeAgent2Resume assert.Equal(t, int32(0), atomic.LoadInt32(&model2.callCount), "Agent2 should NOT run (cancel caught at transition after agent1)") assert.Equal(t, int32(0), atomic.LoadInt32(&model3.callCount)) - assert.NotEmpty(t, cancelErr.CheckPointID) resumeModel2 := &gatedChatModel{ response: &schema.Message{Role: schema.Assistant, Content: "resumed agent2"}, @@ -3175,7 +3168,6 @@ func TestCancelAfterToolCalls_LoopTransitionBoundary(t *testing.T) { cancelErr := drainCancelError(t, iter) assert.NotNil(t, cancelErr, "Should have CancelError from CancelAfterToolCalls in loop") assert.Equal(t, CancelAfterToolCalls, cancelErr.Info.Mode) - assert.NotEmpty(t, cancelErr.CheckPointID) } func TestCancelContext_ActiveChildren_Tracking(t *testing.T) { @@ -3312,7 +3304,6 @@ func TestCancel_ParallelWorkflow_CancelAfterChatModel(t *testing.T) { assert.True(t, hasCancelError, "Should have CancelError from parallel workflow") assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) - assert.NotEmpty(t, cancelErr.CheckPointID, "CancelError should have checkpoint ID") resumeSlow := newCancelTestAgentWithToolsFinalAnswer(t, "par_slow") resumeFast := newCancelTestAgentWithToolsFinalAnswer(t, "par_fast") @@ -3394,7 +3385,6 @@ func TestCancel_LoopWorkflow_CancelAfterChatModel(t *testing.T) { assert.True(t, hasCancelError, "Should have CancelError from loop workflow") assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) - assert.NotEmpty(t, cancelErr.CheckPointID, "CancelError should have checkpoint ID") resumeAgent := newCancelTestAgentWithToolsFinalAnswer(t, "loop_inner") @@ -3520,7 +3510,6 @@ func TestCancel_NestedWorkflow_AgentTool_CancelAfterChatModel(t *testing.T) { assert.True(t, hasCancelError, "Should have CancelError from deeply nested workflow") assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) - assert.NotEmpty(t, cancelErr.CheckPointID, "CancelError should have checkpoint ID") assert.NotNil(t, cancelErr.interruptSignal, "CancelError should carry interrupt signal through agent tree") // Phase 2: Resume from checkpoint — new instances to avoid data races @@ -3670,7 +3659,6 @@ func TestCancel_CancelAfterToolCalls_InSequentialWorkflow(t *testing.T) { assert.True(t, hasCancelError, "Should have CancelError after tool calls complete") assert.Equal(t, CancelAfterToolCalls, cancelErr.Info.Mode) - assert.NotEmpty(t, cancelErr.CheckPointID, "CancelError should have checkpoint ID") // Phase 2: Resume from checkpoint — new instances resumeTool := &slowTool{ @@ -3726,3 +3714,149 @@ func TestCancel_CancelAfterToolCalls_InSequentialWorkflow(t *testing.T) { } assert.NotEmpty(t, resumeEvents, "Resume should produce events") } + +// TestCancel_SafePointNeverFires_ErrExecutionEnded verifies the waitForCompletion +// path where a safe-point cancel is submitted while the agent is running, but +// the agent finishes without hitting the requested safe-point (e.g. +// CancelAfterToolCalls on an agent with no tool calls). The cancel CAS succeeds +// (stateRunning → stateCancelling), but the agent completes normally (markDone → +// stateDone), so waitForCompletion returns ErrExecutionEnded. +func TestCancel_SafePointNeverFires_ErrExecutionEnded(t *testing.T) { + ctx := context.Background() + + gate := make(chan struct{}) + done := make(chan struct{}, 1) + + m := &gatedChatModel{ + gateChan: gate, + doneChan: done, + response: &schema.Message{ + Role: schema.Assistant, + Content: "Final answer, no tool calls", + }, + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "NoToolAgent", + Description: "Agent with no tools", + Instruction: "You are a test assistant", + Model: m, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hello")}, cancelOpt) + + // Wait a moment for the agent to enter Generate and block on gateChan. + runtime.Gosched() + time.Sleep(50 * time.Millisecond) + + // Submit a safe-point cancel for tool calls. The agent has no tools, + // so this safe-point will never fire. + handle, submitted := cancelFn(WithAgentCancelMode(CancelAfterToolCalls)) + assert.True(t, submitted) + + // Let the model complete. The agent finishes without hitting the tool + // calls safe-point → markDone → stateDone → waitForCompletion returns + // ErrExecutionEnded. + close(gate) + + waitErr := handle.Wait() + assert.ErrorIs(t, waitErr, ErrExecutionEnded) + + for { + _, ok := iter.Next() + if !ok { + break + } + } +} + +// TestBuildCancelFunc_StateDoneUnderLock exercises the race-condition path +// in buildCancelFunc where the state transitions to stateDone between the +// lockless check and the locked check (cancel.go L732-734). +func TestBuildCancelFunc_StateDoneUnderLock(t *testing.T) { + cc := newCancelContext() + cancelFn := cc.buildCancelFunc() + + // Hold cancelMu so the cancel func blocks when it tries to acquire the lock. + cc.cancelMu.Lock() + + type result struct { + handle *CancelHandle + ok bool + } + ch := make(chan result, 1) + + go func() { + h, ok := cancelFn(WithAgentCancelMode(CancelAfterToolCalls)) + ch <- result{h, ok} + }() + + // Give the goroutine time to reach the Lock() call. + runtime.Gosched() + time.Sleep(20 * time.Millisecond) + + // Transition to stateDone while the cancel goroutine is blocked on the lock. + cc.markDone() + + // Release the lock. The cancel func resumes and finds stateDone. + cc.cancelMu.Unlock() + + r := <-ch + assert.False(t, r.ok, "cancel should not be accepted when execution already done") + assert.ErrorIs(t, r.handle.Wait(), ErrExecutionEnded) +} + +// TestBuildCancelFunc_CASFailStateDone exercises the race-condition path +// in buildCancelFunc where the CAS on stateRunning→stateCancelling fails +// because markDone transitioned stateRunning→stateDone concurrently +// (cancel.go L742-743). +func TestBuildCancelFunc_CASFailStateDone(t *testing.T) { + // Exercises cancel.go L742-743: CAS(stateRunning→stateCancelling) fails + // because markDone transitions stateRunning→stateDone concurrently. + // + // The window between the state check (L738) and CAS (L739) is extremely + // tight. We maximize the chance by having the cancel goroutine block on + // cancelMu, then racing markDone with the lock release. + hit := false + for i := 0; i < 100000 && !hit; i++ { + cc := newCancelContext() + cancelFn := cc.buildCancelFunc() + + // Hold cancelMu so the cancel goroutine blocks at L725. + cc.cancelMu.Lock() + + cancelDone := make(chan struct{}) + var h *CancelHandle + var ok bool + + go func() { + defer close(cancelDone) + h, ok = cancelFn(WithAgentCancelMode(CancelAfterToolCalls)) + }() + + // Let the cancel goroutine reach the Lock() call. + runtime.Gosched() + + // Release lock and fire markDone concurrently. The cancel goroutine + // will acquire the lock and race with markDone on the CAS. + go cc.markDone() + cc.cancelMu.Unlock() + + <-cancelDone + + if !ok && errors.Is(h.Wait(), ErrExecutionEnded) { + hit = true + } + } + if hit { + t.Log("Successfully hit CAS-fail → stateDone path") + } else { + t.Log("CAS race path not triggered (L743 remains a theoretical race edge)") + } +} diff --git a/adk/chatmodel.go b/adk/chatmodel.go index f1155a1cf..fe32cdfac 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -812,7 +812,7 @@ func (a *ChatModelAgent) handleRunFuncError( // returning false and markDone() executing, a concurrent cancel could // transition stateRunning→stateCancelling. markDone() then does // stateCancelling→stateDone, and the cancel func receives - // ErrExecutionCompleted (execution finished before cancel took effect). + // ErrExecutionEnded (execution finished before cancel took effect). if !cancelCtx.shouldCancel() { cancelCtx.markDone() } diff --git a/adk/prebuilt/planexecute/plan_execute_test.go b/adk/prebuilt/planexecute/plan_execute_test.go index 6734a16b8..ba5ba7ac2 100644 --- a/adk/prebuilt/planexecute/plan_execute_test.go +++ b/adk/prebuilt/planexecute/plan_execute_test.go @@ -1113,7 +1113,7 @@ func TestWithCancel_PlanExecute_DuringExecution(t *testing.T) { time.Sleep(50 * time.Millisecond) - // Cancel should NOT return ErrExecutionCompleted + // Cancel should NOT return ErrExecutionEnded handle, _ := cancelFn() err = handle.Wait() assert.NoError(t, err, "Cancel during executor should succeed") diff --git a/adk/runner.go b/adk/runner.go index 4881122a6..405b69e76 100644 --- a/adk/runner.go +++ b/adk/runner.go @@ -219,7 +219,6 @@ func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEven cancelCtx.markCancelHandled() } if cancelErr.interruptSignal != nil && checkPointID != nil { - cancelErr.CheckPointID = *checkPointID cancelErr.InterruptContexts = core.ToInterruptContexts(cancelErr.interruptSignal, allowedAddressSegmentTypes) err := r.saveCheckPoint(ctx, *checkPointID, &InterruptInfo{}, cancelErr.interruptSignal) if err != nil { diff --git a/adk/turn_loop.go b/adk/turn_loop.go index 67db57c47..df12ba40f 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -363,9 +363,13 @@ type TurnLoopConfig[T any] struct { // Required. GenInput func(ctx context.Context, loop *TurnLoop[T], items []T) (*GenInputResult[T], error) - // GenResume is called exactly once when the TurnLoop detects a mid-turn - // checkpoint on startup (i.e. CheckpointID is configured and the stored - // checkpoint has runner state from an interrupted agent execution). + // GenResume is called at most once during Run(). When CheckpointID is + // configured, Run() queries Store for the checkpoint: + // - If the checkpoint contains runner state (i.e. an agent was interrupted + // mid-turn), Run() calls GenResume to plan a resume turn. + // - Otherwise (no checkpoint, or between-turns checkpoint), GenResume is + // never called and the loop proceeds via GenInput. + // // It receives: // - canceledItems: the items being processed when the prior run was canceled // - unhandledItems: items buffered but not processed when the prior run exited @@ -391,16 +395,16 @@ type TurnLoopConfig[T any] struct { // - tc.Preempted / tc.Stopped: signals while processing events // // Error handling: the returned error is only used when the callback itself - // wants to abort the TurnLoop. The TurnLoop already captures CancelError - // from the event stream when the turn is stopped or preempted, so the - // callback should NOT propagate CancelError. In practice, return a non-nil - // error only for callback-internal failures that should terminate the loop; - // return nil when the current agent is canceled by an external Stop or - // Preempt (Preempt cancels the current agent but the loop continues with - // the next turn). + // wants to abort the TurnLoop. The callback should NEVER propagate + // CancelError — the framework handles it automatically: + // - Stop: the framework propagates CancelError as ExitReason, loop exits. + // - Preempt: the framework does not propagate CancelError; if the callback + // also returns nil, the loop continues with the next turn. + // In practice, return a non-nil error only for callback-internal failures + // that should terminate the loop. // - // Optional. If not provided, events are drained and errors (except CancelError - // from Stop-triggered cancellation) are returned as ExitReason. + // Optional. If not provided, events are drained and the first error + // (including CancelError from Stop) is returned as ExitReason. OnAgentEvents func(ctx context.Context, tc *TurnContext[T], events *AsyncIterator[*AgentEvent]) error // Store is the checkpoint store for persistence and resume. Optional. @@ -442,7 +446,9 @@ type GenInputResult[T any] struct { // Input is the agent input to execute Input *AgentInput - // RunOpts are the options for this agent run + // RunOpts are the options for this agent run. + // Note: do not pass WithCheckPointID here; the TurnLoop automatically + // injects the checkpointID into the Runner. RunOpts []AgentRunOption // Consumed are the items selected for this turn. @@ -464,6 +470,8 @@ type GenResumeResult[T any] struct { RunCtx context.Context // RunOpts are the options for this agent resume run. + // Note: do not pass WithCheckPointID here; the TurnLoop automatically + // injects the checkpointID into the Runner. RunOpts []AgentRunOption // ResumeParams are optional parameters for resuming an interrupted agent. @@ -581,20 +589,20 @@ type TurnLoopExitState[T any] struct { // by a cancel (Stop with WithImmediate, WithGraceful, or WithGracefulTimeout). // Only populated when ExitReason is a *CancelError — if the agent finishes // normally before the cancel takes effect, CanceledItems is empty. - // It can be used to reconstruct GenInput/PrepareAgent inputs when resuming. + // On resume, these are passed to GenResume's CanceledItems parameter. CanceledItems []T // StopCause is the business-supplied reason passed via WithStopCause. // Empty if Stop was not called or no cause was provided. StopCause string - // Checkpointed indicates whether a checkpoint save was attempted during cleanup. + // CheckpointAttempted indicates whether a checkpoint save was attempted when the loop exited. // True only when Store is configured, CheckpointID is set, Stop() was called, - // and the loop was not idle at exit time. - Checkpointed bool + // the loop was not idle at exit time, and WithSkipCheckpoint was not used. + CheckpointAttempted bool // CheckpointErr is the error from checkpoint save, if any. - // nil when Checkpointed is false (no attempt was made) or when the save succeeded. + // nil when CheckpointAttempted is false (no attempt was made) or when the save succeeded. CheckpointErr error // TakeLateItems returns items that were pushed after the loop stopped @@ -604,8 +612,8 @@ type TurnLoopExitState[T any] struct { // This function is idempotent: the first call computes and caches the result; // subsequent calls return the same slice. // - // After TakeLateItems is called, any subsequent Push() will panic. This - // seals the late buffer and prevents items from being silently lost. + // After TakeLateItems is called, any subsequent Push() will panic to + // prevent items from being silently lost. // // It is safe to call TakeLateItems from any goroutine after Wait() returns. // If TakeLateItems is never called, late items are simply garbage collected. @@ -720,10 +728,6 @@ type turnLoopCheckpoint[T any] struct { CanceledItems []T } -// ErrCheckpointStoreNil is returned when a checkpoint operation requires a Store -// but none was configured in TurnLoopConfig. -var ErrCheckpointStoreNil = errors.New("checkpoint store is nil") - func marshalTurnLoopCheckpoint[T any](c *turnLoopCheckpoint[T]) ([]byte, error) { buf := new(bytes.Buffer) if err := gob.NewEncoder(buf).Encode(c); err != nil { @@ -742,7 +746,7 @@ func unmarshalTurnLoopCheckpoint[T any](data []byte) (*turnLoopCheckpoint[T], er func (l *TurnLoop[T]) saveTurnLoopCheckpoint(ctx context.Context, checkPointID string, c *turnLoopCheckpoint[T]) error { if l.config.Store == nil { - return ErrCheckpointStoreNil + return errors.New("checkpoint store is nil") } data, err := marshalTurnLoopCheckpoint(c) if err != nil { @@ -831,14 +835,14 @@ type turnLoopPendingResume[T any] struct { type SafePoint int const ( - // AfterToolCalls allows the agent to finish the current tool-call round - // before being cancelled. - AfterToolCalls SafePoint = 1 << iota // AfterChatModel allows the agent to finish the current chat-model // call before being cancelled. - AfterChatModel - // AnySafePoint is shorthand for AfterToolCalls | AfterChatModel. - AnySafePoint = AfterToolCalls | AfterChatModel + AfterChatModel SafePoint = 1 << iota + // AfterToolCalls allows the agent to finish the current tool-call round + // before being cancelled. + AfterToolCalls + // AnySafePoint is shorthand for AfterChatModel | AfterToolCalls. + AnySafePoint = AfterChatModel | AfterToolCalls ) func (sp SafePoint) toCancelMode() CancelMode { @@ -879,14 +883,17 @@ func WithGraceful() StopOption { } // WithImmediate aborts the running agent turn as soon as possible. -// The agent's context is cancelled immediately without waiting for any -// safe point. Nested agents inside AgentTools are torn down as a side effect. +// The agent is cancelled immediately without waiting for any safe point. +// Nested agents inside AgentTools will also receive the cancel signal +// and be torn down. // // This is the most aggressive stop mode — typically used when the caller // wants to shut down the TurnLoop with no intention of resuming. func WithImmediate() StopOption { return func(cfg *stopConfig) { - cfg.agentCancelOpts = []AgentCancelOption{} + cfg.agentCancelOpts = []AgentCancelOption{ + WithRecursive(), + } } } @@ -986,8 +993,9 @@ type PushOption[T any] func(*pushConfig[T]) // returns or after all tool calls complete), no nested agent is running at // the moment of cancellation — nested agents within AgentTools have either // not started yet (AfterChatModel) or already finished (AfterToolCalls). -// If the preemption escalates to immediate via WithPreemptTimeout, any -// in-flight nested agent is torn down through Go context cancellation. +// Note: WithPreempt does NOT include WithRecursive (no escalation path exists). +// WithPreemptTimeout DOES include WithRecursive so that on timeout escalation, +// nested agents are properly torn down. // // WithPreempt and WithPreemptTimeout are mutually exclusive; if both are // passed to the same Push call, the last one wins. @@ -1007,7 +1015,8 @@ func WithPreempt[T any](safePoint SafePoint) PushOption[T] { // WithPreemptTimeout is like WithPreempt but adds a timeout. If the agent has // not reached the safe point within timeout, the preemption escalates to -// immediate cancellation. +// immediate cancellation. On escalation, nested agents inside AgentTools will +// also receive the cancel signal and be torn down. // // safePoint must not be zero; passing SafePoint(0) panics. func WithPreemptTimeout[T any](safePoint SafePoint, timeout time.Duration) PushOption[T] { @@ -1019,6 +1028,7 @@ func WithPreemptTimeout[T any](safePoint SafePoint, timeout time.Duration) PushO cfg.agentCancelOpts = []AgentCancelOption{ WithAgentCancelMode(safePoint.toCancelMode()), WithAgentCancelTimeout(timeout), + WithRecursive(), } } } @@ -1123,9 +1133,13 @@ func (l *TurnLoop[T]) Run(ctx context.Context) { // Push adds an item to the loop's buffer for processing. // This method is non-blocking and thread-safe. // Returns false if the loop has stopped, true otherwise. If a preemptive push -// succeeds, the second return value is a channel that is closed when the loop -// has acknowledged the preempt signal (by either initiating cancellation of the -// current agent run or reaching a point where no cancellation is needed). +// succeeds, the second return value is a channel that callers can wait on to +// confirm the preempt signal has been received and the cancel request submitted +// — i.e., the current turn is guaranteed to be preempted. Specifically: +// - If an agent is running: the channel closes after TurnLoop submits cancel. +// - If no agent is running (loop idle or not yet started): the channel closes +// immediately (nothing to cancel). +// // If the loop has not been started yet (Run not called), items are buffered // and will be processed once Run is called. // After Wait() returns, failed pushes can be recovered via TurnLoopExitState.TakeLateItems(). @@ -1491,7 +1505,7 @@ func (l *TurnLoop[T]) run(ctx context.Context) { func (l *TurnLoop[T]) setupBridgeStore(spec *turnRunSpec[T], runOpts []AgentRunOption) ([]AgentRunOption, *bridgeStore, error) { store := l.config.Store if store == nil && spec.isResume { - return nil, nil, fmt.Errorf("failed to resume agent: %w", ErrCheckpointStoreNil) + return nil, nil, fmt.Errorf("failed to resume agent: checkpoint store is nil") } if store == nil { return runOpts, nil, nil @@ -1764,12 +1778,12 @@ func (l *TurnLoop[T]) cleanup(ctx context.Context) { var takeLateResult []T l.result = &TurnLoopExitState[T]{ - ExitReason: l.runErr, - UnhandledItems: unhandled, - CanceledItems: l.canceledItems, - StopCause: l.stopSig.getStopCause(), - Checkpointed: checkpointed, - CheckpointErr: checkpointErr, + ExitReason: l.runErr, + UnhandledItems: unhandled, + CanceledItems: l.canceledItems, + StopCause: l.stopSig.getStopCause(), + CheckpointAttempted: checkpointed, + CheckpointErr: checkpointErr, TakeLateItems: func() []T { takeLateOnce.Do(func() { l.lateMu.Lock() diff --git a/adk/turn_loop_test.go b/adk/turn_loop_test.go index ea7c0aa93..309c84f0e 100644 --- a/adk/turn_loop_test.go +++ b/adk/turn_loop_test.go @@ -1887,7 +1887,7 @@ func TestTurnLoop_CheckpointSaveError_ReturnsError(t *testing.T) { loop.Stop(WithImmediate()) exit := loop.Wait() assert.Error(t, exit.ExitReason) - assert.True(t, exit.Checkpointed) + assert.True(t, exit.CheckpointAttempted) assert.Error(t, exit.CheckpointErr) assert.Contains(t, exit.CheckpointErr.Error(), "write failed") } @@ -2362,7 +2362,7 @@ func TestTurnLoop_CheckpointSaveError_MergesWithExistingError(t *testing.T) { assert.Error(t, exit.ExitReason) var ce *CancelError assert.True(t, errors.As(exit.ExitReason, &ce), "ExitReason should be CancelError, not merged with checkpoint error") - assert.True(t, exit.Checkpointed) + assert.True(t, exit.CheckpointAttempted) assert.Error(t, exit.CheckpointErr) assert.Contains(t, exit.CheckpointErr.Error(), "disk full") } @@ -4006,12 +4006,12 @@ func TestTurnLoop_CheckpointErr_SeparateFromExitReason(t *testing.T) { // ExitReason should be nil (clean stop), checkpoint error should be separate assert.Nil(t, result.ExitReason) - assert.True(t, result.Checkpointed) + assert.True(t, result.CheckpointAttempted) assert.Error(t, result.CheckpointErr) assert.Contains(t, result.CheckpointErr.Error(), "storage unavailable") } -func TestTurnLoop_Checkpointed_FalseWhenNoStore(t *testing.T) { +func TestTurnLoop_CheckpointAttempted_FalseWhenNoStore(t *testing.T) { ctx := context.Background() loop := NewTurnLoop(TurnLoopConfig[string]{ @@ -4027,11 +4027,11 @@ func TestTurnLoop_Checkpointed_FalseWhenNoStore(t *testing.T) { loop.Run(ctx) result := loop.Wait() - assert.False(t, result.Checkpointed) + assert.False(t, result.CheckpointAttempted) assert.Nil(t, result.CheckpointErr) } -func TestTurnLoop_Checkpointed_FalseOnErrorExit(t *testing.T) { +func TestTurnLoop_CheckpointAttempted_FalseOnErrorExit(t *testing.T) { ctx := context.Background() store := &turnLoopCheckpointStore{m: make(map[string][]byte)} genInputErr := errors.New("gen input failed") @@ -4068,7 +4068,7 @@ func TestTurnLoop_Checkpointed_FalseOnErrorExit(t *testing.T) { // Loop exited from error, not Stop() — checkpoint should not be saved assert.ErrorIs(t, result.ExitReason, genInputErr) - assert.False(t, result.Checkpointed) + assert.False(t, result.CheckpointAttempted) assert.Nil(t, result.CheckpointErr) } @@ -4128,7 +4128,7 @@ func TestTurnLoop_StopConcurrentWithCallbackError_NoCheckpoint(t *testing.T) { // checkpoint should NOT be saved. if result.ExitReason != nil && !errors.As(result.ExitReason, new(*CancelError)) { assert.ErrorIs(t, result.ExitReason, prepareErr) - assert.False(t, result.Checkpointed, "should not checkpoint when exit is caused by callback error") + assert.False(t, result.CheckpointAttempted, "should not checkpoint when exit is caused by callback error") } // If Stop won the race, that's fine — checkpoint may or may not be saved // depending on idle state. The test is about the error path. @@ -4220,7 +4220,7 @@ func TestTurnLoop_StopWithSkipCheckpoint(t *testing.T) { exit := loop.Wait() assert.NoError(t, exit.ExitReason) - assert.False(t, exit.Checkpointed, "checkpoint should be skipped when WithSkipCheckpoint is used") + assert.False(t, exit.CheckpointAttempted, "checkpoint should be skipped when WithSkipCheckpoint is used") store.mu.Lock() _, exists := store.m[cpID] @@ -4249,7 +4249,7 @@ func TestTurnLoop_StopWithSkipCheckpoint_DeletesStaleCheckpoint(t *testing.T) { loop1.Stop() loop1.Run(ctx) exit1 := loop1.Wait() - assert.True(t, exit1.Checkpointed) + assert.True(t, exit1.CheckpointAttempted) store.mu.Lock() _, exists := store.m[cpID] @@ -4270,7 +4270,7 @@ func TestTurnLoop_StopWithSkipCheckpoint_DeletesStaleCheckpoint(t *testing.T) { loop2.Stop(WithSkipCheckpoint()) loop2.Run(ctx) exit2 := loop2.Wait() - assert.False(t, exit2.Checkpointed, "second loop should skip checkpoint") + assert.False(t, exit2.CheckpointAttempted, "second loop should skip checkpoint") store.mu.Lock() deleteCalled := store.deleteCalled @@ -4503,7 +4503,7 @@ func TestTurnLoop_SkipCheckpoint_Sticky(t *testing.T) { loop.Stop() exit := loop.Wait() - assert.False(t, exit.Checkpointed, "SkipCheckpoint should be sticky across multiple Stop calls") + assert.False(t, exit.CheckpointAttempted, "SkipCheckpoint should be sticky across multiple Stop calls") store.mu.Lock() _, exists := store.m[cpID] @@ -5378,7 +5378,133 @@ func TestAttack_SkipCheckpoint_Sticky(t *testing.T) { loop.Stop(WithImmediate()) exit := loop.Wait() - assert.False(t, exit.Checkpointed, "SkipCheckpoint is sticky; checkpoint should be skipped") + assert.False(t, exit.CheckpointAttempted, "SkipCheckpoint is sticky; checkpoint should be skipped") +} + +// turnLoopNestedProbeAgent simulates an agent with a nested sub-agent +// by deriving a child cancelContext. This allows tests to verify that +// TurnLoop's Stop/Push options correctly propagate recursive cancellation. +// +// IMPORTANT: child.markDone() is NOT called by the probe. The test MUST +// call it (e.g. via t.Cleanup) after verifying propagation to avoid a +// race between markDone closing child.doneChan and the deriveChild +// goroutines propagating the cancel signal. +type turnLoopNestedProbeAgent struct { + parentCCCh chan *cancelContext + childCCCh chan *cancelContext +} + +func (a *turnLoopNestedProbeAgent) Name(_ context.Context) string { return "nested-probe" } +func (a *turnLoopNestedProbeAgent) Description(_ context.Context) string { return "nested-probe" } +func (a *turnLoopNestedProbeAgent) Run(ctx context.Context, _ *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, gen := NewAsyncIteratorPair[*AgentEvent]() + o := getCommonOptions(nil, opts...) + cc := o.cancelCtx + + child := cc.deriveChild(ctx) + a.parentCCCh <- cc + a.childCCCh <- child + + go func() { + defer gen.Close() + <-cc.cancelChan + for { + if cc.getMode() == CancelImmediate { + gen.Send(&AgentEvent{Err: cc.createCancelError()}) + return + } + time.Sleep(1 * time.Millisecond) + } + }() + return iter +} + +func TestTurnLoop_Stop_WithImmediate_RecursivePropagation(t *testing.T) { + parentCCCh := make(chan *cancelContext, 1) + childCCCh := make(chan *cancelContext, 1) + probe := &turnLoopNestedProbeAgent{parentCCCh: parentCCCh, childCCCh: childCCCh} + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return probe, nil + }, + }) + + loop.Push("msg1") + cc := <-parentCCCh + child := <-childCCCh + t.Cleanup(func() { child.markDone() }) + + loop.Stop(WithImmediate()) + + // Child should receive the cancel signal via recursive propagation. + select { + case <-child.cancelChan: + case <-time.After(2 * time.Second): + t.Fatal("child did not receive cancel via recursive propagation") + } + + // Child should also receive the immediate cancel signal. + select { + case <-child.immediateChan: + case <-time.After(2 * time.Second): + t.Fatal("child did not receive immediate cancel via recursive propagation") + } + + assert.True(t, cc.isRecursive(), "WithImmediate should set recursive on parent") + assert.True(t, child.shouldCancel(), "child should be cancelled") + assert.True(t, child.isImmediateCancelled(), "child should have received immediate cancel") + + exit := loop.Wait() + var ce *CancelError + require.True(t, errors.As(exit.ExitReason, &ce)) + assert.Equal(t, CancelImmediate, ce.Info.Mode) +} + +func TestTurnLoop_Push_WithPreemptTimeout_RecursivePropagation(t *testing.T) { + parentCCCh := make(chan *cancelContext, 2) + childCCCh := make(chan *cancelContext, 2) + probe := &turnLoopNestedProbeAgent{parentCCCh: parentCCCh, childCCCh: childCCCh} + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return probe, nil + }, + }) + + loop.Push("first") + cc := <-parentCCCh + child := <-childCCCh + t.Cleanup(func() { child.markDone() }) + + // Preempt with a very short timeout so it escalates to CancelImmediate quickly. + loop.Push("urgent", WithPreemptTimeout[string](AfterChatModel, 10*time.Millisecond)) + + // After timeout escalation, child should receive the immediate cancel + // via recursive propagation. + select { + case <-child.immediateChan: + case <-time.After(2 * time.Second): + t.Fatal("child did not receive immediate cancel after preempt timeout escalation") + } + + assert.True(t, cc.isRecursive(), "WithPreemptTimeout should set recursive on parent") + assert.True(t, child.isImmediateCancelled(), "child should have received immediate cancel") + + loop.Stop(WithImmediate()) + loop.Wait() } func TestUntilIdleFor_NonPositive_Panics(t *testing.T) { @@ -5387,3 +5513,18 @@ func TestUntilIdleFor_NonPositive_Panics(t *testing.T) { assert.PanicsWithValue(t, "adk: UntilIdleFor: duration must be positive", func() { UntilIdleFor(-1 * time.Second) }) } + +func TestSaveTurnLoopCheckpoint_NilStore(t *testing.T) { + l := &TurnLoop[string]{config: TurnLoopConfig[string]{Store: nil}} + err := l.saveTurnLoopCheckpoint(context.Background(), "cp-1", &turnLoopCheckpoint[string]{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "checkpoint store is nil") +} + +func TestSetupBridgeStore_NilStore_Resume(t *testing.T) { + l := &TurnLoop[string]{config: TurnLoopConfig[string]{Store: nil}} + spec := &turnRunSpec[string]{isResume: true} + _, _, err := l.setupBridgeStore(spec, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "checkpoint store is nil") +} From d720849f2d78d2edc360e37041f496aef5b4368f Mon Sep 17 00:00:00 2001 From: shentongmartin Date: Tue, 21 Apr 2026 14:15:26 +0800 Subject: [PATCH 59/65] feat(adk): integrate AgenticMessage into ADK (#920) --- adk/agent_tool.go | 98 +- adk/agent_tool_test.go | 93 ++ adk/agentic_callback_integration_test.go | 268 ++++ adk/agentic_integration_test.go | 665 ++++++++ adk/agentic_react_test.go | 1143 ++++++++++++++ adk/agentic_test.go | 1355 +++++++++++++++++ adk/callback.go | 81 +- adk/callback_test.go | 162 +- adk/cancel.go | 16 +- adk/cancel_edge_test.go | 4 +- adk/cancel_test.go | 2 +- adk/chatmodel.go | 516 +++++-- adk/deterministic_transfer.go | 2 +- adk/failover_chatmodel.go | 225 ++- adk/failover_chatmodel_test.go | 62 +- adk/flow.go | 206 ++- adk/handler.go | 114 +- adk/instruction.go | 2 +- adk/interface.go | 304 +++- adk/interrupt.go | 76 +- .../planexecute/{utils.go => utils_test.go} | 0 adk/react.go | 284 +++- adk/react_test.go | 28 + adk/retry_chatmodel.go | 72 +- adk/runctx.go | 214 ++- adk/runner.go | 213 ++- adk/runner_test.go | 48 + adk/turn_loop.go | 140 +- adk/turn_loop_test.go | 1148 +++++++------- adk/utils.go | 76 +- adk/workflow_test.go | 2 +- adk/wrappers.go | 313 ++-- adk/wrappers_failover_test.go | 44 +- adk/wrappers_retry_failover_test.go | 48 +- components/model/interface.go | 40 +- schema/agentic_message.go | 51 +- schema/agentic_message_test.go | 39 + schema/serialization.go | 2 + schema/tool_test.go | 38 + utils/callbacks/template.go | 43 +- utils/callbacks/template_test.go | 122 ++ 41 files changed, 6948 insertions(+), 1411 deletions(-) create mode 100644 adk/agentic_callback_integration_test.go create mode 100644 adk/agentic_integration_test.go create mode 100644 adk/agentic_react_test.go create mode 100644 adk/agentic_test.go rename adk/prebuilt/planexecute/{utils.go => utils_test.go} (100%) diff --git a/adk/agent_tool.go b/adk/agent_tool.go index fde319cb4..2c78a584c 100644 --- a/adk/agent_tool.go +++ b/adk/agent_tool.go @@ -103,14 +103,34 @@ func NewAgentTool(_ context.Context, agent Agent, options ...AgentToolOption) to } } -type agentTool struct { - agent Agent +// NewTypedAgentTool creates a new agent tool that wraps a TypedAgent as a tool.BaseTool. +func NewTypedAgentTool[M messageType](_ context.Context, agent TypedAgent[M], options ...AgentToolOption) tool.BaseTool { + opts := &AgentToolOptions{} + for _, opt := range options { + opt(opts) + } + + return &typedAgentTool[M]{ + agent: agent, + fullChatHistoryAsInput: opts.fullChatHistoryAsInput, + inputSchema: opts.agentInputSchema, + } +} + +type typedAgentTool[M messageType] struct { + agent TypedAgent[M] fullChatHistoryAsInput bool inputSchema *schema.ParamsOneOf } -func (at *agentTool) Info(ctx context.Context) (*schema.ToolInfo, error) { +type agentTool = typedAgentTool[*schema.Message] + +type agentToolRequest struct { + Request string `json:"request"` +} + +func (at *typedAgentTool[M]) Info(ctx context.Context) (*schema.ToolInfo, error) { name := at.agent.Name(ctx) if name == "" { return nil, errors.New("agent tool requires a non-empty Name") @@ -119,7 +139,6 @@ func (at *agentTool) Info(ctx context.Context) (*schema.ToolInfo, error) { if desc == "" { return nil, errors.New("agent tool requires a non-empty Description") } - param := at.inputSchema if param == nil { param = defaultAgentToolParam @@ -132,41 +151,41 @@ func (at *agentTool) Info(ctx context.Context) (*schema.ToolInfo, error) { }, nil } -func (at *agentTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { +func (at *typedAgentTool[M]) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { gen, enableStreaming := getEmitGeneratorAndEnableStreaming(opts) var ms *bridgeStore - var iter *AsyncIterator[*AgentEvent] + var iter *AsyncIterator[*TypedAgentEvent[M]] var err error wasInterrupted, hasState, state := tool.GetInterruptState[[]byte](ctx) if !wasInterrupted { ms = newBridgeStore() - var input []Message + + var input []M if at.fullChatHistoryAsInput { - input, err = getReactChatHistory(ctx, at.agent.Name(ctx)) - if err != nil { - return "", err + var zero M + if _, ok := any(zero).(*schema.Message); !ok { + return "", fmt.Errorf("fullChatHistoryAsInput is only supported for *schema.Message agents") } + msgInput, histErr := getReactChatHistory(ctx, at.agent.Name(ctx)) + if histErr != nil { + return "", histErr + } + input = any(msgInput).([]M) } else { if at.inputSchema == nil { - // default input schema - type request struct { - Request string `json:"request"` - } - - req := &request{} + req := &agentToolRequest{} err = sonic.UnmarshalString(argumentsInJSON, req) if err != nil { return "", err } argumentsInJSON = req.Request } - input = []Message{ - schema.UserMessage(argumentsInJSON), - } + input = newTypedUserMessages[M](argumentsInJSON) } - iter = newInvokableAgentToolRunner(at.agent, ms, enableStreaming).Run(ctx, input, + runner := newTypedInvokableAgentToolRunner[M](at.agent, ms, enableStreaming) + iter = runner.Run(ctx, input, append(extractAndDeriveCancelCtx(ctx, at.agent.Name(ctx), opts), WithCheckPointID(bridgeCheckpointID), withSharedParentSession())...) } else { if !hasState { @@ -178,14 +197,14 @@ func (at *agentTool) InvokableRun(ctx context.Context, argumentsInJSON string, o agentOpts := extractAndDeriveCancelCtx(ctx, at.agent.Name(ctx), opts) agentOpts = append(agentOpts, withSharedParentSession()) - iter, err = newInvokableAgentToolRunner(at.agent, ms, enableStreaming). - Resume(ctx, bridgeCheckpointID, agentOpts...) + runner := newTypedInvokableAgentToolRunner[M](at.agent, ms, enableStreaming) + iter, err = runner.Resume(ctx, bridgeCheckpointID, agentOpts...) if err != nil { return "", err } } - var lastEvent *AgentEvent + var lastEvent *TypedAgentEvent[M] for { event, ok := iter.Next() if !ok { @@ -211,9 +230,13 @@ func (at *agentTool) InvokableRun(ctx context.Context, argumentsInJSON string, o rp = append(rp, event.RunPath...) event.RunPath = rp } - tmp := copyAgentEvent(event) - gen.Send(event) - event = tmp + if msgEvent, ok := any(event).(*AgentEvent); ok { + tmp := copyTypedAgentEvent(msgEvent) + gen.Send(msgEvent) + event = any(tmp).(*TypedAgentEvent[M]) + } else { + return "", fmt.Errorf("cross-message-type agent tools are not supported: cannot use an AgenticMessage agent as a tool of a Message agent") + } } } @@ -244,7 +267,7 @@ func (at *agentTool) InvokableRun(ctx context.Context, argumentsInJSON string, o if err != nil { return "", err } - ret = msg.Content + ret = extractTextContent(msg) } } @@ -308,8 +331,11 @@ func getEmitGeneratorAndEnableStreaming(opts []tool.Option) (*AsyncGenerator[*Ag func getReactChatHistory(ctx context.Context, destAgentName string) ([]Message, error) { var messages []Message err := compose.ProcessState(ctx, func(ctx context.Context, st *State) error { + if len(st.Messages) == 0 { + return nil + } messages = make([]Message, len(st.Messages)-1) - copy(messages, st.Messages[:len(st.Messages)-1]) // remove the last assistant message, which is the tool call message + copy(messages, st.Messages[:len(st.Messages)-1]) return nil }) if err != nil { @@ -339,8 +365,20 @@ func getReactChatHistory(ctx context.Context, destAgentName string) ([]Message, return history, nil } -func newInvokableAgentToolRunner(agent Agent, store compose.CheckPointStore, enableStreaming bool) *Runner { - return &Runner{ +func newTypedUserMessages[M messageType](text string) []M { + var zero M + switch any(zero).(type) { + case *schema.Message: + return any([]Message{schema.UserMessage(text)}).([]M) + case *schema.AgenticMessage: + return any([]*schema.AgenticMessage{schema.UserAgenticMessage(text)}).([]M) + default: + return nil + } +} + +func newTypedInvokableAgentToolRunner[M messageType](agent TypedAgent[M], store compose.CheckPointStore, enableStreaming bool) *TypedRunner[M] { + return &TypedRunner[M]{ a: agent, enableStreaming: enableStreaming, store: store, diff --git a/adk/agent_tool_test.go b/adk/agent_tool_test.go index cfedb24c6..54c02ea9c 100644 --- a/adk/agent_tool_test.go +++ b/adk/agent_tool_test.go @@ -21,9 +21,11 @@ import ( "fmt" "strings" "sync" + "sync/atomic" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" @@ -31,6 +33,24 @@ import ( "github.com/cloudwego/eino/schema" ) +type mockChatModelForAttack struct { + generateFn func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) +} + +func (m *mockChatModelForAttack) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + return m.generateFn(ctx, input, opts...) +} + +func (m *mockChatModelForAttack) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + result, err := m.generateFn(ctx, input, opts...) + if err != nil { + return nil, err + } + r, w := schema.Pipe[*schema.Message](1) + go func() { defer w.Close(); w.Send(result, nil) }() + return r, nil +} + // mockAgent implements the Agent interface for testing type mockAgentForTool struct { name string @@ -1146,3 +1166,76 @@ func TestInvokableAgentTool_ErrorCases(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "", out2) } + +func TestCrossTypeAgentToolGracefulError(t *testing.T) { + ctx := context.Background() + + innerModel := &mockAgenticModel{ + generateFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) { + return agenticMsg("inner result"), nil + }, + } + + innerAgent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "AgenticInner", + Description: "An agentic agent used as a tool", + Model: innerModel, + }) + require.NoError(t, err) + + agenticAgentTool := NewTypedAgentTool(ctx, TypedAgent[*schema.AgenticMessage](innerAgent)) + + var outerCallCount int32 + outerModel := &mockChatModelForAttack{ + generateFn: func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + count := atomic.AddInt32(&outerCallCount, 1) + if count == 1 { + return &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{ + {ID: "c1", Function: schema.FunctionCall{Name: "AgenticInner", Arguments: `{"request":"test"}`}}, + }, + }, nil + } + return schema.AssistantMessage("done", nil), nil + }, + } + + outerAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "OuterMessageAgent", + Description: "A Message agent using an AgenticMessage sub-agent tool", + Model: outerModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{agenticAgentTool}, + }, + }, + }) + require.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{Agent: outerAgent, EnableStreaming: true}) + iter := runner.Query(ctx, "test cross-type") + + var capturedErr error + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil { + capturedErr = event.Err + t.Logf("Cross-type error message: %v", event.Err) + } + } + + if capturedErr == nil { + t.Log("DESIGN CONCERN: Cross-type agent tool (AgenticMessage sub-agent in Message agent) " + + "only errors at event forwarding time when streaming is enabled. " + + "The error check happens in the gen.Send path, which is only exercised " + + "when the outer agent actually calls the tool AND streaming is enabled. " + + "Without streaming, the tool result is returned as a string, so no type mismatch occurs.") + } else { + assert.Contains(t, capturedErr.Error(), "cross-message-type", + "Error should mention cross-message-type incompatibility") + } +} diff --git a/adk/agentic_callback_integration_test.go b/adk/agentic_callback_integration_test.go new file mode 100644 index 000000000..689188fc6 --- /dev/null +++ b/adk/agentic_callback_integration_test.go @@ -0,0 +1,268 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +type agenticCallbackRecorder struct { + mu sync.Mutex + onStartCalled bool + onEndCalled bool + runInfo *callbacks.RunInfo + inputReceived *TypedAgentCallbackInput[*schema.AgenticMessage] + eventsReceived []*TypedAgentEvent[*schema.AgenticMessage] + eventsDone chan struct{} + closeOnce sync.Once +} + +func (r *agenticCallbackRecorder) getOnStartCalled() bool { + r.mu.Lock() + defer r.mu.Unlock() + return r.onStartCalled +} + +func (r *agenticCallbackRecorder) getOnEndCalled() bool { + r.mu.Lock() + defer r.mu.Unlock() + return r.onEndCalled +} + +func (r *agenticCallbackRecorder) getEventsReceived() []*TypedAgentEvent[*schema.AgenticMessage] { + r.mu.Lock() + defer r.mu.Unlock() + result := make([]*TypedAgentEvent[*schema.AgenticMessage], len(r.eventsReceived)) + copy(result, r.eventsReceived) + return result +} + +func newAgenticRecordingHandler(recorder *agenticCallbackRecorder) callbacks.Handler { + recorder.eventsDone = make(chan struct{}) + return callbacks.NewHandlerBuilder(). + OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { + if info.Component != ComponentOfAgenticAgent { + return ctx + } + recorder.mu.Lock() + defer recorder.mu.Unlock() + recorder.onStartCalled = true + recorder.runInfo = info + if agentInput := ConvTypedCallbackInput[*schema.AgenticMessage](input); agentInput != nil { + recorder.inputReceived = agentInput + } + return ctx + }). + OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { + if info.Component != ComponentOfAgenticAgent { + return ctx + } + recorder.mu.Lock() + recorder.onEndCalled = true + recorder.runInfo = info + recorder.mu.Unlock() + + if agentOutput := ConvTypedCallbackOutput[*schema.AgenticMessage](output); agentOutput != nil { + if agentOutput.Events != nil { + go func() { + defer recorder.closeOnce.Do(func() { close(recorder.eventsDone) }) + for { + event, ok := agentOutput.Events.Next() + if !ok { + break + } + recorder.mu.Lock() + recorder.eventsReceived = append(recorder.eventsReceived, event) + recorder.mu.Unlock() + } + }() + return ctx + } + } + recorder.closeOnce.Do(func() { close(recorder.eventsDone) }) + return ctx + }). + Build() +} + +func TestAgenticCallback(t *testing.T) { + ctx := context.Background() + + expectedContent := "This is the test response content" + m := &mockAgenticModel{ + generateFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) { + return agenticMsg(expectedContent), nil + }, + } + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "TestChatAgent", + Description: "Test chat agent", + Instruction: "You are a test agent", + Model: m, + }) + require.NoError(t, err) + + recorder := &agenticCallbackRecorder{} + handler := newAgenticRecordingHandler(recorder) + + var agentEvents []*TypedAgentEvent[*schema.AgenticMessage] + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{Agent: agent}) + iter := runner.Query(ctx, "hello", WithCallbacks(handler)) + for { + event, ok := iter.Next() + if !ok { + break + } + agentEvents = append(agentEvents, event) + } + + <-recorder.eventsDone + assertAgenticEventRoleFields(t, agentEvents) + + t.Run("OnStart_Invocation", func(t *testing.T) { + assert.True(t, recorder.getOnStartCalled(), "OnStart should be called") + require.NotNil(t, recorder.inputReceived, "Input should be received") + require.NotNil(t, recorder.inputReceived.Input, "AgentInput should be set") + assert.Len(t, recorder.inputReceived.Input.Messages, 1) + }) + + t.Run("OnEnd_Invocation", func(t *testing.T) { + assert.True(t, recorder.getOnEndCalled(), "OnEnd should be called") + assert.Len(t, recorder.getEventsReceived(), 1) + }) + + t.Run("RunInfo_Fields", func(t *testing.T) { + require.NotNil(t, recorder.runInfo) + assert.Equal(t, "TestChatAgent", recorder.runInfo.Name) + assert.Equal(t, ComponentOfAgenticAgent, recorder.runInfo.Component) + }) + + t.Run("Events_MatchAgentOutput", func(t *testing.T) { + require.NotEmpty(t, agentEvents, "Agent should emit events") + received := recorder.getEventsReceived() + require.NotEmpty(t, received, "Callback should receive events") + + require.Len(t, received, 1, "Callback should receive exactly 1 event") + require.NotNil(t, received[0].Output) + require.NotNil(t, received[0].Output.MessageOutput) + require.NotNil(t, received[0].Output.MessageOutput.Message) + assert.Equal(t, expectedContent, agenticTextContent(received[0].Output.MessageOutput.Message)) + }) +} + +func TestAgenticCallbackMultipleHandlers(t *testing.T) { + ctx := context.Background() + + m := &mockAgenticModel{ + generateFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) { + return agenticMsg("test response"), nil + }, + } + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test agent", + Model: m, + }) + require.NoError(t, err) + + recorder1 := &agenticCallbackRecorder{} + recorder2 := &agenticCallbackRecorder{} + handler1 := newAgenticRecordingHandler(recorder1) + handler2 := newAgenticRecordingHandler(recorder2) + + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{Agent: agent}) + iter := runner.Query(ctx, "hello", WithCallbacks(handler1, handler2)) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + <-recorder1.eventsDone + <-recorder2.eventsDone + + assert.True(t, recorder1.getOnStartCalled(), "Handler1 OnStart should be called") + assert.True(t, recorder2.getOnStartCalled(), "Handler2 OnStart should be called") + assert.True(t, recorder1.getOnEndCalled(), "Handler1 OnEnd should be called") + assert.True(t, recorder2.getOnEndCalled(), "Handler2 OnEnd should be called") + + assert.NotEmpty(t, recorder1.getEventsReceived(), "Handler1 should receive events") + assert.NotEmpty(t, recorder2.getEventsReceived(), "Handler2 should receive events") +} + +func TestCoverage_WrapAgenticIterWithOnEnd(t *testing.T) { + ctx := context.Background() + + var onEndCalled bool + handler := callbacks.NewHandlerBuilder(). + OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { + return ctx + }). + OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { + if info.Component == ComponentOfAgenticAgent { + onEndCalled = true + } + return ctx + }). + Build() + + ctx = initAgenticCallbacks(ctx, "test-agent", "ChatModel", + WithCallbacks(handler)) + + cbInput := &TypedAgentCallbackInput[*schema.AgenticMessage]{ + Input: &TypedAgentInput[*schema.AgenticMessage]{ + Messages: []*schema.AgenticMessage{schema.UserAgenticMessage("Hi")}, + }, + } + ctx = callbacks.OnStart(ctx, cbInput) + + origIter, origGen := NewAsyncIteratorPair[*TypedAgentEvent[*schema.AgenticMessage]]() + go func() { + defer origGen.Close() + origGen.Send(&TypedAgentEvent[*schema.AgenticMessage]{ + Output: &TypedAgentOutput[*schema.AgenticMessage]{ + MessageOutput: &TypedMessageVariant[*schema.AgenticMessage]{ + Message: agenticMsg("done"), + }, + }, + }) + }() + + wrappedIter := wrapAgenticIterWithOnEnd(ctx, origIter) + + for { + _, ok := wrappedIter.Next() + if !ok { + break + } + } + + assert.True(t, onEndCalled, "OnEnd callback should have been called") +} diff --git a/adk/agentic_integration_test.go b/adk/agentic_integration_test.go new file mode 100644 index 000000000..eb6657991 --- /dev/null +++ b/adk/agentic_integration_test.go @@ -0,0 +1,665 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "encoding/json" + "sync/atomic" + "testing" + "time" + + "github.com/eino-contrib/jsonschema" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +func agenticMsg(text string) *schema.AgenticMessage { + return &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.AssistantGenText{Text: text}), + }, + } +} + +func agenticTextContent(msg *schema.AgenticMessage) string { + for _, b := range msg.ContentBlocks { + if b.AssistantGenText != nil { + return b.AssistantGenText.Text + } + } + return "" +} + +func TestAgenticIntegration_ChatModelSingleShot(t *testing.T) { + ctx := context.Background() + + m := &mockAgenticModel{ + generateFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) { + return agenticMsg("Handled internally with tool result: 42"), nil + }, + } + + dummyTool := newSlowTool("calculator", 0, "42") + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "ToolCallAgent", + Description: "Agent with tools for agentic model", + Instruction: "You are a calculator.", + Model: m, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{dummyTool}, + }, + }, + }) + require.NoError(t, err) + + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{ + Agent: agent, + }) + + iter := runner.Query(ctx, "What is 6*7?") + + var events []*TypedAgentEvent[*schema.AgenticMessage] + for { + event, ok := iter.Next() + if !ok { + break + } + events = append(events, event) + } + + require.Len(t, events, 1) + assertAgenticEventRoleFields(t, events) + lastEvent := events[len(events)-1] + require.Nil(t, lastEvent.Err) + require.NotNil(t, lastEvent.Output) + require.NotNil(t, lastEvent.Output.MessageOutput) + assert.Equal(t, "Handled internally with tool result: 42", + agenticTextContent(lastEvent.Output.MessageOutput.Message)) +} + +func TestAgenticIntegration_ChatModelToolsPassedViaOptions(t *testing.T) { + ctx := context.Background() + + var receivedTools []*schema.ToolInfo + m := &mockAgenticModel{ + generateFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) { + o := model.GetCommonOptions(&model.Options{}, opts...) + receivedTools = o.Tools + return agenticMsg("done"), nil + }, + } + + dummyTool := newSlowTool("my_tool", 0, "result") + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "ToolOptAgent", + Description: "Agent verifying tools are passed via options", + Model: m, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{dummyTool}, + }, + }, + }) + require.NoError(t, err) + + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{ + Agent: agent, + }) + iter := runner.Query(ctx, "test tools") + for { + _, ok := iter.Next() + if !ok { + break + } + } + + require.NotNil(t, receivedTools, "tools should be passed via model.Options") + require.Len(t, receivedTools, 1) + assert.Equal(t, "my_tool", receivedTools[0].Name) +} + +func TestAgenticIntegration_StreamingWithRunner(t *testing.T) { + ctx := context.Background() + + chunk1 := &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.AssistantGenText{Text: "Hello "}), + }, + } + chunk2 := &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.AssistantGenText{Text: "world"}), + }, + } + + m := &mockAgenticModel{ + streamFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.StreamReader[*schema.AgenticMessage], error) { + r, w := schema.Pipe[*schema.AgenticMessage](2) + go func() { + defer w.Close() + w.Send(chunk1, nil) + w.Send(chunk2, nil) + }() + return r, nil + }, + } + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "StreamRunner", + Description: "Streaming runner agent", + Model: m, + }) + require.NoError(t, err) + + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{ + Agent: agent, + EnableStreaming: true, + }) + + iter := runner.Query(ctx, "stream me") + + event, ok := iter.Next() + require.True(t, ok) + assert.Nil(t, event.Err) + require.NotNil(t, event.Output) + require.NotNil(t, event.Output.MessageOutput) + + if event.Output.MessageOutput.IsStreaming { + require.NotNil(t, event.Output.MessageOutput.MessageStream) + var chunks []*schema.AgenticMessage + for { + chunk, err := event.Output.MessageOutput.MessageStream.Recv() + if err != nil { + break + } + chunks = append(chunks, chunk) + } + assert.Equal(t, 2, len(chunks)) + } else { + assert.NotNil(t, event.Output.MessageOutput.Message) + } + + _, ok = iter.Next() + assert.False(t, ok) +} + +func TestAgenticIntegration_CancelDuringExecution(t *testing.T) { + ctx := context.Background() + + modelStarted := make(chan struct{}, 1) + modelBlocked := make(chan struct{}) + + m := &mockAgenticModel{ + generateFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) { + select { + case modelStarted <- struct{}{}: + default: + } + select { + case <-modelBlocked: + return agenticMsg("should not reach"), nil + case <-ctx.Done(): + return nil, ctx.Err() + } + }, + } + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "CancelAgent", + Description: "cancel test", + Model: m, + }) + require.NoError(t, err) + + cancelCtx, cancel := context.WithCancel(ctx) + defer cancel() + + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{ + Agent: agent, + }) + iter := runner.Run(cancelCtx, []*schema.AgenticMessage{ + schema.UserAgenticMessage("Hi"), + }) + + <-modelStarted + cancel() + + var capturedErr error + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil { + capturedErr = event.Err + } + } + require.Error(t, capturedErr, "should propagate cancel error") + assert.ErrorIs(t, capturedErr, context.Canceled) +} + +func TestAgenticIntegration_CancelWithTimeout(t *testing.T) { + ctx := context.Background() + + sa := &myAgenticAgent{ + name: "slow-agent", + runFn: func(ctx context.Context, input *TypedAgentInput[*schema.AgenticMessage], options ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[*schema.AgenticMessage]] { + iter, generator := NewAsyncIteratorPair[*TypedAgentEvent[*schema.AgenticMessage]]() + go func() { + defer generator.Close() + select { + case <-time.After(10 * time.Second): + generator.Send(&TypedAgentEvent[*schema.AgenticMessage]{ + Output: &TypedAgentOutput[*schema.AgenticMessage]{ + MessageOutput: &TypedMessageVariant[*schema.AgenticMessage]{ + Message: agenticMsg("slow response"), + }, + }, + }) + case <-ctx.Done(): + generator.Send(&TypedAgentEvent[*schema.AgenticMessage]{ + Err: ctx.Err(), + }) + } + }() + return iter + }, + } + + timeoutCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) + defer cancel() + + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{ + Agent: sa, + }) + iter := runner.Run(timeoutCtx, []*schema.AgenticMessage{ + schema.UserAgenticMessage("slow request"), + }) + + var capturedErr error + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil { + capturedErr = event.Err + } + } + + require.Error(t, capturedErr, "should get timeout/cancel error") + assert.ErrorIs(t, capturedErr, context.DeadlineExceeded) +} +func TestAgenticIntegration_AgentTool(t *testing.T) { + ctx := context.Background() + + innerModel := &mockAgenticModel{ + generateFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) { + return agenticMsg("inner tool result"), nil + }, + } + + innerAgent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "InnerAgent", + Description: "An agent used as a tool", + Model: innerModel, + }) + require.NoError(t, err) + + agentTool := NewTypedAgentTool(ctx, TypedAgent[*schema.AgenticMessage](innerAgent)) + require.NotNil(t, agentTool) + + info, err := agentTool.Info(ctx) + require.NoError(t, err) + assert.Equal(t, "InnerAgent", info.Name) + assert.Equal(t, "An agent used as a tool", info.Desc) + + outerModel := &mockAgenticModel{ + generateFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) { + return agenticMsg("outer response after inner tool"), nil + }, + } + + outerAgent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "OuterAgent", + Description: "Outer agent with agent tool", + Model: outerModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{agentTool}, + }, + }, + }) + require.NoError(t, err) + + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{ + Agent: outerAgent, + }) + iter := runner.Query(ctx, "delegate to inner") + + var events []*TypedAgentEvent[*schema.AgenticMessage] + for { + event, ok := iter.Next() + if !ok { + break + } + events = append(events, event) + } + + require.NotEmpty(t, events) + assertAgenticEventRoleFields(t, events) + lastEvent := events[len(events)-1] + assert.Nil(t, lastEvent.Err) + assert.NotNil(t, lastEvent.Output) +} +func TestAgenticIntegration_InterruptEventFormation(t *testing.T) { + ctx := context.Background() + + t.Run("simple interrupt", func(t *testing.T) { + agent := &myAgenticAgent{ + name: "int-agent", + runFn: func(ctx context.Context, input *TypedAgentInput[*schema.AgenticMessage], options ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[*schema.AgenticMessage]] { + iter, generator := NewAsyncIteratorPair[*TypedAgentEvent[*schema.AgenticMessage]]() + go func() { + defer generator.Close() + intEvent := TypedInterrupt[*schema.AgenticMessage](ctx, "approval needed") + intEvent.Action.Interrupted.Data = "approval data" + generator.Send(intEvent) + }() + return iter + }, + } + + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{ + Agent: agent, + }) + iter := runner.Query(ctx, "interrupt test") + + var interruptEvent *TypedAgentEvent[*schema.AgenticMessage] + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Action != nil && event.Action.Interrupted != nil { + interruptEvent = event + } + } + + require.NotNil(t, interruptEvent) + assert.Equal(t, "approval data", interruptEvent.Action.Interrupted.Data) + require.NotEmpty(t, interruptEvent.Action.Interrupted.InterruptContexts) + assert.NotEmpty(t, interruptEvent.Action.Interrupted.InterruptContexts[0].ID) + assert.Equal(t, "approval needed", interruptEvent.Action.Interrupted.InterruptContexts[0].Info) + assert.True(t, interruptEvent.Action.Interrupted.InterruptContexts[0].IsRootCause) + }) + + t.Run("stateful interrupt", func(t *testing.T) { + agent := &myAgenticAgent{ + name: "st-agent", + runFn: func(ctx context.Context, input *TypedAgentInput[*schema.AgenticMessage], options ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[*schema.AgenticMessage]] { + iter, generator := NewAsyncIteratorPair[*TypedAgentEvent[*schema.AgenticMessage]]() + go func() { + defer generator.Close() + intEvent := TypedStatefulInterrupt[*schema.AgenticMessage](ctx, "state interrupt", "my-state") + intEvent.Action.Interrupted.Data = "stateful data" + generator.Send(intEvent) + }() + return iter + }, + } + + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{ + Agent: agent, + }) + iter := runner.Query(ctx, "stateful test") + + var interruptEvent *TypedAgentEvent[*schema.AgenticMessage] + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Action != nil && event.Action.Interrupted != nil { + interruptEvent = event + } + } + + require.NotNil(t, interruptEvent) + assert.Equal(t, "stateful data", interruptEvent.Action.Interrupted.Data) + require.NotEmpty(t, interruptEvent.Action.Interrupted.InterruptContexts) + assert.Equal(t, "state interrupt", interruptEvent.Action.Interrupted.InterruptContexts[0].Info) + }) +} +func TestAgenticIntegration_CheckpointInterruptResume(t *testing.T) { + ctx := context.Background() + + var resumeCalled int32 + agent := &myAgenticAgent{ + name: "ckpt-agent", + runFn: func(ctx context.Context, input *TypedAgentInput[*schema.AgenticMessage], options ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[*schema.AgenticMessage]] { + iter, generator := NewAsyncIteratorPair[*TypedAgentEvent[*schema.AgenticMessage]]() + go func() { + defer generator.Close() + generator.Send(&TypedAgentEvent[*schema.AgenticMessage]{ + AgentName: "ckpt-agent", + Output: &TypedAgentOutput[*schema.AgenticMessage]{ + MessageOutput: &TypedMessageVariant[*schema.AgenticMessage]{ + Message: agenticMsg("before interrupt"), + }, + }, + }) + intEvent := TypedInterrupt[*schema.AgenticMessage](ctx, "need approval") + intEvent.Action.Interrupted.Data = "approval data" + generator.Send(intEvent) + }() + return iter + }, + resumeFn: func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[*schema.AgenticMessage]] { + atomic.StoreInt32(&resumeCalled, 1) + iter, generator := NewAsyncIteratorPair[*TypedAgentEvent[*schema.AgenticMessage]]() + go func() { + defer generator.Close() + generator.Send(&TypedAgentEvent[*schema.AgenticMessage]{ + AgentName: "ckpt-agent", + Output: &TypedAgentOutput[*schema.AgenticMessage]{ + MessageOutput: &TypedMessageVariant[*schema.AgenticMessage]{ + Message: agenticMsg("after resume"), + }, + }, + }) + }() + return iter + }, + } + + store := newMyStore() + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{ + Agent: agent, + CheckPointStore: store, + }) + + iter := runner.Query(ctx, "checkpoint test", WithCheckPointID("ckpt-1")) + + var interruptEvent *TypedAgentEvent[*schema.AgenticMessage] + var preInterruptOutputs []string + for { + event, ok := iter.Next() + if !ok { + break + } + require.Nil(t, event.Err) + if event.Action != nil && event.Action.Interrupted != nil { + interruptEvent = event + } + if event.Output != nil && event.Output.MessageOutput != nil && event.Output.MessageOutput.Message != nil { + preInterruptOutputs = append(preInterruptOutputs, agenticTextContent(event.Output.MessageOutput.Message)) + } + } + + require.NotNil(t, interruptEvent, "should receive interrupt event") + assert.Contains(t, preInterruptOutputs, "before interrupt") + require.NotEmpty(t, interruptEvent.Action.Interrupted.InterruptContexts) + + interruptID := interruptEvent.Action.Interrupted.InterruptContexts[0].ID + require.NotEmpty(t, interruptID) + + resumeIter, err := runner.ResumeWithParams(ctx, "ckpt-1", &ResumeParams{ + Targets: map[string]any{ + interruptID: nil, + }, + }) + require.NoError(t, err) + + var postResumeOutputs []string + for { + event, ok := resumeIter.Next() + if !ok { + break + } + if event.Err != nil { + t.Fatalf("unexpected error during resume: %v", event.Err) + } + if event.Output != nil && event.Output.MessageOutput != nil && event.Output.MessageOutput.Message != nil { + postResumeOutputs = append(postResumeOutputs, agenticTextContent(event.Output.MessageOutput.Message)) + } + } + + assert.Equal(t, int32(1), atomic.LoadInt32(&resumeCalled), "resume function should have been called") + assert.Contains(t, postResumeOutputs, "after resume") +} + +func TestAgenticIntegration_CheckpointWithMCPListToolsResult(t *testing.T) { + ctx := context.Background() + + inputSchemaJSON := `{ + "type": "object", + "properties": { + "query": {"type": "string", "description": "search query"}, + "limit": {"type": "integer", "description": "max results"} + }, + "required": ["query"] + }` + var inputSchema jsonschema.Schema + require.NoError(t, json.Unmarshal([]byte(inputSchemaJSON), &inputSchema)) + + mcpMsg := &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{ + { + Type: schema.ContentBlockTypeMCPListToolsResult, + MCPListToolsResult: &schema.MCPListToolsResult{ + ServerLabel: "test-server", + Tools: []*schema.MCPListToolsItem{ + { + Name: "search", + Description: "search the web", + InputSchema: &inputSchema, + }, + }, + }, + }, + schema.NewContentBlock(&schema.AssistantGenText{Text: "here are tools"}), + }, + } + + var resumeCalled int32 + agent := &myAgenticAgent{ + name: "mcp-agent", + runFn: func(ctx context.Context, input *TypedAgentInput[*schema.AgenticMessage], options ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[*schema.AgenticMessage]] { + iter, gen := NewAsyncIteratorPair[*TypedAgentEvent[*schema.AgenticMessage]]() + go func() { + defer gen.Close() + gen.Send(&TypedAgentEvent[*schema.AgenticMessage]{ + AgentName: "mcp-agent", + Output: &TypedAgentOutput[*schema.AgenticMessage]{ + MessageOutput: &TypedMessageVariant[*schema.AgenticMessage]{Message: mcpMsg}, + }, + }) + gen.Send(TypedInterrupt[*schema.AgenticMessage](ctx, "approve tools")) + }() + return iter + }, + resumeFn: func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[*schema.AgenticMessage]] { + atomic.StoreInt32(&resumeCalled, 1) + iter, gen := NewAsyncIteratorPair[*TypedAgentEvent[*schema.AgenticMessage]]() + go func() { + defer gen.Close() + gen.Send(&TypedAgentEvent[*schema.AgenticMessage]{ + AgentName: "mcp-agent", + Output: &TypedAgentOutput[*schema.AgenticMessage]{ + MessageOutput: &TypedMessageVariant[*schema.AgenticMessage]{Message: agenticMsg("tools approved")}, + }, + }) + }() + return iter + }, + } + + store := newMyStore() + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{ + Agent: agent, + CheckPointStore: store, + }) + + iter := runner.Query(ctx, "list tools", WithCheckPointID("mcp-1")) + var interruptEvent *TypedAgentEvent[*schema.AgenticMessage] + for { + ev, ok := iter.Next() + if !ok { + break + } + require.Nil(t, ev.Err) + if ev.Action != nil && ev.Action.Interrupted != nil { + interruptEvent = ev + } + } + require.NotNil(t, interruptEvent) + interruptID := interruptEvent.Action.Interrupted.InterruptContexts[0].ID + + resumeIter, err := runner.ResumeWithParams(ctx, "mcp-1", &ResumeParams{ + Targets: map[string]any{interruptID: nil}, + }) + require.NoError(t, err) + + var outputs []string + for { + ev, ok := resumeIter.Next() + if !ok { + break + } + require.Nil(t, ev.Err) + if ev.Output != nil && ev.Output.MessageOutput != nil && ev.Output.MessageOutput.Message != nil { + outputs = append(outputs, agenticTextContent(ev.Output.MessageOutput.Message)) + } + } + + assert.Equal(t, int32(1), atomic.LoadInt32(&resumeCalled)) + assert.Contains(t, outputs, "tools approved") +} diff --git a/adk/agentic_react_test.go b/adk/agentic_react_test.go new file mode 100644 index 000000000..5896a65a8 --- /dev/null +++ b/adk/agentic_react_test.go @@ -0,0 +1,1143 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +type agenticAgentEvent = TypedAgentEvent[*schema.AgenticMessage] + +func agenticToolCallMsg(toolName, callID, args string) *schema.AgenticMessage { + return &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{ + { + Type: schema.ContentBlockTypeFunctionToolCall, + FunctionToolCall: &schema.FunctionToolCall{Name: toolName, CallID: callID, Arguments: args}, + }, + }, + } +} + +type sequentialAgenticModel struct { + responses []*schema.AgenticMessage + callCount int32 +} + +func (m *sequentialAgenticModel) Generate(_ context.Context, _ []*schema.AgenticMessage, _ ...model.Option) (*schema.AgenticMessage, error) { + idx := atomic.AddInt32(&m.callCount, 1) - 1 + if int(idx) >= len(m.responses) { + return nil, fmt.Errorf("sequentialAgenticModel: no more responses (call #%d)", idx) + } + return m.responses[idx], nil +} + +func (m *sequentialAgenticModel) Stream(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.StreamReader[*schema.AgenticMessage], error) { + result, err := m.Generate(ctx, input, opts...) + if err != nil { + return nil, err + } + r, w := schema.Pipe[*schema.AgenticMessage](1) + go func() { defer w.Close(); w.Send(result, nil) }() + return r, nil +} + +type agenticEchoTool struct { + name string +} + +func (t *agenticEchoTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{Name: t.name, Desc: "echoes input"}, nil +} + +func (t *agenticEchoTool) InvokableRun(_ context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) { + return "echo:" + argumentsInJSON, nil +} + +type agenticInterruptTool struct { + name string +} + +func (t *agenticInterruptTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{Name: t.name, Desc: "interrupts on first call, returns on resume"}, nil +} + +func (t *agenticInterruptTool) InvokableRun(ctx context.Context, _ string, _ ...tool.Option) (string, error) { + wasInterrupted, _, _ := tool.GetInterruptState[any](ctx) + if !wasInterrupted { + return "", tool.Interrupt(ctx, "need_approval") + } + isResume, hasData, data := tool.GetResumeContext[string](ctx) + if isResume && hasData { + return "approved:" + data, nil + } + return "resumed_no_data", nil +} + +type agenticArgCaptureTool struct { + name string + onInvoke func(args string) string +} + +func (t *agenticArgCaptureTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{Name: t.name, Desc: "captures args"}, nil +} + +func (t *agenticArgCaptureTool) InvokableRun(_ context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) { + return t.onInvoke(argumentsInJSON), nil +} + +type agenticSignalTool struct { + name string + started chan struct{} + result string + done chan struct{} + once sync.Once +} + +func (t *agenticSignalTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{Name: t.name, Desc: "blocks until finish() is called"}, nil +} + +func (t *agenticSignalTool) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) { + t.once.Do(func() { t.done = make(chan struct{}) }) + select { + case t.started <- struct{}{}: + default: + } + <-t.done + return t.result, nil +} + +func (t *agenticSignalTool) finish() { + t.once.Do(func() { t.done = make(chan struct{}) }) + close(t.done) +} + +type agenticReactTestStore struct { + m map[string][]byte +} + +func (s *agenticReactTestStore) Set(_ context.Context, key string, value []byte) error { + s.m[key] = value + return nil +} + +func (s *agenticReactTestStore) Get(_ context.Context, key string) ([]byte, bool, error) { + v, ok := s.m[key] + return v, ok, nil +} + +func newAgenticAgent(t *testing.T, ctx context.Context, mdl model.BaseModel[*schema.AgenticMessage], tools []tool.BaseTool) TypedAgent[*schema.AgenticMessage] { + t.Helper() + config := &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: t.Name(), + Description: "test agentic agent", + Model: mdl, + } + if len(tools) > 0 { + config.ToolsConfig = ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: tools, + }, + } + } + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, config) + require.NoError(t, err) + return agent +} + +func newAgenticRunner(t *testing.T, ctx context.Context, mdl model.BaseModel[*schema.AgenticMessage], tools []tool.BaseTool) *TypedRunner[*schema.AgenticMessage] { + t.Helper() + agent := newAgenticAgent(t, ctx, mdl, tools) + return NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{Agent: agent}) +} + +func newAgenticRunnerWithStore(t *testing.T, ctx context.Context, mdl model.BaseModel[*schema.AgenticMessage], tools []tool.BaseTool, store CheckPointStore) *TypedRunner[*schema.AgenticMessage] { + t.Helper() + agent := newAgenticAgent(t, ctx, mdl, tools) + return NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{ + Agent: agent, + CheckPointStore: store, + }) +} + +func drainAgenticEvents(iter *AsyncIterator[*agenticAgentEvent]) []*agenticAgentEvent { + var events []*agenticAgentEvent + for { + ev, ok := iter.Next() + if !ok { + break + } + events = append(events, ev) + } + return events +} + +func lastAgenticEvent(events []*agenticAgentEvent) *agenticAgentEvent { + if len(events) == 0 { + return nil + } + return events[len(events)-1] +} + +func findInterruptEvent(events []*agenticAgentEvent) *agenticAgentEvent { + for _, ev := range events { + if ev.Action != nil && ev.Action.Interrupted != nil { + return ev + } + } + return nil +} + +func TestAgenticReact_BasicInvoke(t *testing.T) { + ctx := context.Background() + + mdl := &sequentialAgenticModel{ + responses: []*schema.AgenticMessage{ + agenticToolCallMsg("echo", "call-1", `"hello"`), + agenticMsg("done: echo result received"), + }, + } + + runner := newAgenticRunner(t, ctx, mdl, []tool.BaseTool{&agenticEchoTool{name: "echo"}}) + events := drainAgenticEvents(runner.Query(ctx, "test input")) + last := lastAgenticEvent(events) + + require.NotNil(t, last) + require.Nil(t, last.Err) + require.NotNil(t, last.Output) + require.NotNil(t, last.Output.MessageOutput) + assert.Equal(t, "done: echo result received", agenticTextContent(last.Output.MessageOutput.Message)) + assert.Equal(t, int32(2), atomic.LoadInt32(&mdl.callCount)) +} + +func TestAgenticReact_MultiTurnToolCalling(t *testing.T) { + ctx := context.Background() + + mdl := &sequentialAgenticModel{ + responses: []*schema.AgenticMessage{ + agenticToolCallMsg("echo", "call-1", `"step1"`), + agenticToolCallMsg("echo", "call-2", `"step2"`), + agenticToolCallMsg("echo", "call-3", `"step3"`), + agenticMsg("all done"), + }, + } + + runner := newAgenticRunner(t, ctx, mdl, []tool.BaseTool{&agenticEchoTool{name: "echo"}}) + events := drainAgenticEvents(runner.Query(ctx, "do three steps")) + last := lastAgenticEvent(events) + + require.NotNil(t, last) + require.Nil(t, last.Err) + require.NotNil(t, last.Output) + require.NotNil(t, last.Output.MessageOutput) + assert.Equal(t, "all done", agenticTextContent(last.Output.MessageOutput.Message)) + assert.Equal(t, int32(4), atomic.LoadInt32(&mdl.callCount)) +} + +func TestAgenticReact_Stream(t *testing.T) { + ctx := context.Background() + + mdl := &sequentialAgenticModel{ + responses: []*schema.AgenticMessage{ + agenticToolCallMsg("echo", "call-1", `"hello"`), + agenticMsg("stream done"), + }, + } + + agent := newAgenticAgent(t, ctx, mdl, []tool.BaseTool{&agenticEchoTool{name: "echo"}}) + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{ + Agent: agent, + EnableStreaming: true, + }) + + events := drainAgenticEvents(runner.Query(ctx, "stream test")) + + var finalText string + for _, ev := range events { + if ev.Output != nil && ev.Output.MessageOutput != nil { + msg, err := ev.Output.MessageOutput.GetMessage() + if err == nil && msg != nil { + txt := agenticTextContent(msg) + if txt != "" { + finalText = txt + } + } + } + } + + assert.Equal(t, "stream done", finalText) +} + +func TestAgenticReact_MaxIterations(t *testing.T) { + ctx := context.Background() + + t.Run("within_limit", func(t *testing.T) { + mdl := &sequentialAgenticModel{ + responses: []*schema.AgenticMessage{ + agenticToolCallMsg("echo", "c1", `"1"`), + agenticToolCallMsg("echo", "c2", `"2"`), + agenticMsg("done within limit"), + }, + } + + runner := newAgenticRunner(t, ctx, mdl, []tool.BaseTool{&agenticEchoTool{name: "echo"}}) + events := drainAgenticEvents(runner.Query(ctx, "go")) + last := lastAgenticEvent(events) + + require.NotNil(t, last) + require.Nil(t, last.Err) + require.NotNil(t, last.Output) + require.NotNil(t, last.Output.MessageOutput) + assert.Equal(t, "done within limit", agenticTextContent(last.Output.MessageOutput.Message)) + }) + + t.Run("exceeded", func(t *testing.T) { + responses := make([]*schema.AgenticMessage, 25) + for i := range responses { + responses[i] = agenticToolCallMsg("echo", fmt.Sprintf("c%d", i), `"x"`) + } + + mdl := &sequentialAgenticModel{responses: responses} + config := &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "exceed-agent", + Description: "test max iterations exceeded", + Model: mdl, + MaxIterations: 3, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{&agenticEchoTool{name: "echo"}}, + }, + }, + } + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, config) + require.NoError(t, err) + + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{Agent: agent}) + events := drainAgenticEvents(runner.Query(ctx, "go")) + last := lastAgenticEvent(events) + + require.NotNil(t, last) + require.NotNil(t, last.Err) + assert.ErrorIs(t, last.Err, ErrExceedMaxIterations) + }) +} + +func TestAgenticReact_ReturnDirectly(t *testing.T) { + t.Skip("returnDirectly for agentic agents depends on typed eventSenderToolHandler; not yet supported") +} + +func TestAgenticReact_CancelAfterChatModel(t *testing.T) { + ctx := context.Background() + + toolStarted := make(chan struct{}, 1) + var modelCallCount int32 + mdl := &mockAgenticModel{ + generateFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) { + count := atomic.AddInt32(&modelCallCount, 1) + switch count { + case 1: + return agenticToolCallMsg("slow", "c1", `"hi"`), nil + case 2: + return agenticToolCallMsg("slow", "c2", `"hi2"`), nil + default: + return agenticMsg("should not reach"), nil + } + }, + } + + slowTool := &agenticSignalTool{ + name: "slow", + started: toolStarted, + result: "slow result", + } + + agent := newAgenticAgent(t, ctx, mdl, []tool.BaseTool{slowTool}) + + cancelOpt, cancelFn := WithCancel() + iter := agent.Run(ctx, &TypedAgentInput[*schema.AgenticMessage]{ + Messages: []*schema.AgenticMessage{schema.UserAgenticMessage("trigger cancel")}, + }, cancelOpt) + + <-toolStarted + + go func() { + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + _ = handle.Wait() + }() + + time.Sleep(10 * time.Millisecond) + slowTool.finish() + + var capturedErr error + for { + ev, ok := iter.Next() + if !ok { + break + } + if ev.Err != nil { + capturedErr = ev.Err + } + } + require.Error(t, capturedErr, "expected CancelError event") + var cancelErr *CancelError + require.ErrorAs(t, capturedErr, &cancelErr) +} + +func TestAgenticReact_CancelAfterToolCalls(t *testing.T) { + ctx := context.Background() + + toolStarted := make(chan struct{}, 1) + var modelCallCount int32 + mdl := &mockAgenticModel{ + generateFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) { + count := atomic.AddInt32(&modelCallCount, 1) + if count == 1 { + return agenticToolCallMsg("slow", "c1", `"hi"`), nil + } + return agenticMsg("should not reach on second call"), nil + }, + } + + slowTool := &agenticSignalTool{ + name: "slow", + started: toolStarted, + result: "slow result", + } + + agent := newAgenticAgent(t, ctx, mdl, []tool.BaseTool{slowTool}) + + cancelOpt, cancelFn := WithCancel() + iter := agent.Run(ctx, &TypedAgentInput[*schema.AgenticMessage]{ + Messages: []*schema.AgenticMessage{schema.UserAgenticMessage("trigger cancel")}, + }, cancelOpt) + + <-toolStarted + + go func() { + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterToolCalls)) + _ = handle.Wait() + }() + + time.Sleep(10 * time.Millisecond) + slowTool.finish() + + var capturedErr error + for { + ev, ok := iter.Next() + if !ok { + break + } + if ev.Err != nil { + capturedErr = ev.Err + } + } + require.Error(t, capturedErr, "expected CancelError event") + var cancelErr *CancelError + require.ErrorAs(t, capturedErr, &cancelErr) + assert.Equal(t, int32(1), atomic.LoadInt32(&modelCallCount)) +} + +func TestAgenticReact_DoubleInterruptResume(t *testing.T) { + ctx := context.Background() + + var modelCallCount int32 + mdl := &mockAgenticModel{ + generateFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) { + count := atomic.AddInt32(&modelCallCount, 1) + switch count { + case 1: + return agenticToolCallMsg("approval_tool", "c1", `"first"`), nil + case 2: + return agenticToolCallMsg("approval_tool", "c2", `"second"`), nil + case 3: + return agenticMsg("all approved"), nil + default: + return nil, fmt.Errorf("unexpected call #%d", count) + } + }, + } + + store := &agenticReactTestStore{m: map[string][]byte{}} + runner := newAgenticRunnerWithStore(t, ctx, mdl, []tool.BaseTool{&agenticInterruptTool{name: "approval_tool"}}, store) + + events1 := drainAgenticEvents(runner.Query(ctx, "approve twice", WithCheckPointID("dbl-cp"))) + int1Event := findInterruptEvent(events1) + require.NotNil(t, int1Event, "expected first interrupt") + int1ID := int1Event.Action.Interrupted.InterruptContexts[0].ID + + iter2, err := runner.ResumeWithParams(ctx, "dbl-cp", &ResumeParams{ + Targets: map[string]any{int1ID: "approved_1"}, + }) + require.NoError(t, err) + + events2 := drainAgenticEvents(iter2) + int2Event := findInterruptEvent(events2) + require.NotNil(t, int2Event, "expected second interrupt") + int2ID := int2Event.Action.Interrupted.InterruptContexts[0].ID + + iter3, err := runner.ResumeWithParams(ctx, "dbl-cp", &ResumeParams{ + Targets: map[string]any{int2ID: "approved_2"}, + }) + require.NoError(t, err) + + events3 := drainAgenticEvents(iter3) + last := lastAgenticEvent(events3) + + require.NotNil(t, last) + require.Nil(t, last.Err) + require.NotNil(t, last.Output) + require.NotNil(t, last.Output.MessageOutput) + assert.Contains(t, agenticTextContent(last.Output.MessageOutput.Message), "all approved") +} + +func TestAgenticReact_ChatModelAgent_NoTools(t *testing.T) { + ctx := context.Background() + + mdl := &mockAgenticModel{ + generateFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) { + return agenticMsg("no tools response"), nil + }, + } + + runner := newAgenticRunner(t, ctx, mdl, nil) + events := drainAgenticEvents(runner.Query(ctx, "hello")) + last := lastAgenticEvent(events) + + require.NotNil(t, last) + require.Nil(t, last.Err) + require.NotNil(t, last.Output) + require.NotNil(t, last.Output.MessageOutput) + assert.Equal(t, "no tools response", agenticTextContent(last.Output.MessageOutput.Message)) +} + +func TestAgenticReact_ChatModelAgent_ToolsReceiveArgs(t *testing.T) { + ctx := context.Background() + + var receivedArgs string + captureTool := &agenticArgCaptureTool{ + name: "capture", + onInvoke: func(args string) string { + receivedArgs = args + return "captured" + }, + } + + mdl := &sequentialAgenticModel{ + responses: []*schema.AgenticMessage{ + agenticToolCallMsg("capture", "c1", `{"foo":"bar"}`), + agenticMsg("done"), + }, + } + + runner := newAgenticRunner(t, ctx, mdl, []tool.BaseTool{captureTool}) + drainAgenticEvents(runner.Query(ctx, "call capture")) + + assert.Equal(t, `{"foo":"bar"}`, receivedArgs) +} + +func TestCoverage_AgenticReact_Streaming(t *testing.T) { + ctx := context.Background() + + m := &mockAgenticModel{ + streamFn: func(_ context.Context, input []*schema.AgenticMessage, _ ...model.Option) (*schema.StreamReader[*schema.AgenticMessage], error) { + r, w := schema.Pipe[*schema.AgenticMessage](1) + go func() { + defer w.Close() + w.Send(&schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.AssistantGenText{Text: "streamed response"}), + }, + }, nil) + }() + return r, nil + }, + } + + echoTool := &agenticEchoTool{name: "echo"} + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "stream-react", + Description: "streaming agentic react", + Model: m, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{echoTool}, + }, + }, + }) + require.NoError(t, err) + + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{ + Agent: agent, + EnableStreaming: true, + }) + + iter := runner.Query(ctx, "stream me") + + var events []*TypedAgentEvent[*schema.AgenticMessage] + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Output != nil && event.Output.MessageOutput != nil && event.Output.MessageOutput.IsStreaming { + stream := event.Output.MessageOutput.MessageStream + for { + _, sErr := stream.Recv() + if sErr != nil { + break + } + } + } + events = append(events, event) + } + + require.NotEmpty(t, events) + assertAgenticEventRoleFields(t, events) +} + +func TestCoverage_ConcatMessageStream_Agentic(t *testing.T) { + t.Run("Success", func(t *testing.T) { + r, w := schema.Pipe[*schema.AgenticMessage](2) + go func() { + defer w.Close() + w.Send(&schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.AssistantGenText{Text: "Hello "}), + }, + }, nil) + w.Send(&schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.AssistantGenText{Text: "world"}), + }, + }, nil) + }() + + result, err := concatMessageStream(r) + assert.NoError(t, err) + assert.NotNil(t, result) + }) + + t.Run("ErrorDuringRecv", func(t *testing.T) { + r, w := schema.Pipe[*schema.AgenticMessage](2) + go func() { + w.Send(nil, fmt.Errorf("recv error")) + w.Close() + }() + + _, err := concatMessageStream(r) + assert.Error(t, err) + }) +} + +func TestCoverage_AgenticReact_InterruptResume(t *testing.T) { + ctx := context.Background() + + interruptTool := &agenticInterruptTool{name: "approval"} + + var callIdx int32 + m := &mockAgenticModel{ + generateFn: func(_ context.Context, _ []*schema.AgenticMessage, _ ...model.Option) (*schema.AgenticMessage, error) { + idx := atomic.AddInt32(&callIdx, 1) + if idx == 1 { + return agenticToolCallMsg("approval", "call1", `{}`), nil + } + return agenticMsg("approved and done"), nil + }, + } + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "interrupt-agent", + Description: "tests interrupt and resume", + Model: m, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{interruptTool}, + }, + }, + }) + require.NoError(t, err) + + store := newDTTestStore() + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{ + Agent: agent, + CheckPointStore: store, + }) + + iter := runner.Run(ctx, []*schema.AgenticMessage{ + schema.UserAgenticMessage("need approval"), + }, WithCheckPointID("cp-int")) + + var interruptEvent *TypedAgentEvent[*schema.AgenticMessage] + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Action != nil && event.Action.Interrupted != nil { + interruptEvent = event + } + } + + require.NotNil(t, interruptEvent, "should have interrupt event") + + var rootCauseID string + for _, intCtx := range interruptEvent.Action.Interrupted.InterruptContexts { + if intCtx.IsRootCause { + rootCauseID = intCtx.ID + break + } + } + require.NotEmpty(t, rootCauseID) + + resumeIter, err := runner.ResumeWithParams(ctx, "cp-int", &ResumeParams{ + Targets: map[string]any{rootCauseID: "approved"}, + }) + require.NoError(t, err) + + var events []*TypedAgentEvent[*schema.AgenticMessage] + for { + event, ok := resumeIter.Next() + if !ok { + break + } + events = append(events, event) + } + require.NotEmpty(t, events) +} + +func TestCoverage_AgenticMessageHasToolCalls(t *testing.T) { + t.Run("NilMessage", func(t *testing.T) { + assert.False(t, agenticMessageHasToolCalls(nil)) + }) + + t.Run("NoToolCalls", func(t *testing.T) { + msg := agenticMsg("just text") + assert.False(t, agenticMessageHasToolCalls(msg)) + }) + + t.Run("HasToolCalls", func(t *testing.T) { + msg := agenticToolCallMsg("tool1", "id1", `{}`) + assert.True(t, agenticMessageHasToolCalls(msg)) + }) + + t.Run("NilBlock", func(t *testing.T) { + msg := &schema.AgenticMessage{ + ContentBlocks: []*schema.ContentBlock{nil}, + } + assert.False(t, agenticMessageHasToolCalls(msg)) + }) + + t.Run("ToolCallBlockNilFunctionToolCall", func(t *testing.T) { + msg := &schema.AgenticMessage{ + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeFunctionToolCall, FunctionToolCall: nil}, + }, + } + assert.False(t, agenticMessageHasToolCalls(msg)) + }) +} + +func TestCoverage_ChatModelAgent_StreamError(t *testing.T) { + ctx := context.Background() + + testErr := errors.New("stream failed") + m := &mockAgenticModel{ + streamFn: func(_ context.Context, _ []*schema.AgenticMessage, _ ...model.Option) (*schema.StreamReader[*schema.AgenticMessage], error) { + return nil, testErr + }, + } + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "stream-error-agent", + Description: "tests stream error", + Model: m, + }) + require.NoError(t, err) + + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{ + Agent: agent, + EnableStreaming: true, + }) + + iter := runner.Query(ctx, "trigger stream error") + + var capturedErr error + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil { + capturedErr = event.Err + } + } + require.Error(t, capturedErr, "should propagate stream error") +} + +func TestCoverage_AgenticReact_GobStateRoundTrip(t *testing.T) { + ctx := context.Background() + + var callIdx int32 + m := &mockAgenticModel{ + generateFn: func(_ context.Context, _ []*schema.AgenticMessage, _ ...model.Option) (*schema.AgenticMessage, error) { + idx := atomic.AddInt32(&callIdx, 1) + if idx == 1 { + return agenticToolCallMsg("interrupt_tool", "call1", `{}`), nil + } + return agenticMsg("completed"), nil + }, + } + + interruptTool := &agenticInterruptTool{name: "interrupt_tool"} + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "gob-test", + Description: "tests gob state round trip", + Model: m, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{interruptTool}, + }, + }, + }) + require.NoError(t, err) + + store := newDTTestStore() + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{ + Agent: agent, + CheckPointStore: store, + }) + + iter := runner.Run(ctx, []*schema.AgenticMessage{ + schema.UserAgenticMessage("test gob"), + }, WithCheckPointID("gob-cp")) + + var interrupted bool + var interruptEvent *TypedAgentEvent[*schema.AgenticMessage] + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Action != nil && event.Action.Interrupted != nil { + interrupted = true + interruptEvent = event + } + } + + if !interrupted || interruptEvent == nil { + t.Skip("no interrupt occurred, skipping gob round-trip test") + } + + _, exists, err := store.Get(ctx, "gob-cp") + assert.NoError(t, err) + assert.True(t, exists, "checkpoint should be saved") + + var rootCauseID string + for _, intCtx := range interruptEvent.Action.Interrupted.InterruptContexts { + if intCtx.IsRootCause { + rootCauseID = intCtx.ID + break + } + } + require.NotEmpty(t, rootCauseID) + + resumeIter, err := runner.ResumeWithParams(ctx, "gob-cp", &ResumeParams{ + Targets: map[string]any{rootCauseID: "approved"}, + }) + require.NoError(t, err) + + var resumed bool + for { + event, ok := resumeIter.Next() + if !ok { + break + } + if event.Output != nil && event.Output.MessageOutput != nil { + resumed = true + } + } + assert.True(t, resumed, "should successfully resume from gob checkpoint") +} + +func TestCoverage_GetMessageFromTypedWrappedEvent_Agentic(t *testing.T) { + t.Run("NilOutput", func(t *testing.T) { + wrapper := &typedAgentEventWrapper[*schema.AgenticMessage]{ + event: &TypedAgentEvent[*schema.AgenticMessage]{}, + } + msg, err := getMessageFromTypedWrappedEvent(wrapper) + assert.NoError(t, err) + assert.Nil(t, msg) + }) + + t.Run("NonStreaming", func(t *testing.T) { + expected := agenticMsg("hello") + wrapper := &typedAgentEventWrapper[*schema.AgenticMessage]{ + event: &TypedAgentEvent[*schema.AgenticMessage]{ + Output: &TypedAgentOutput[*schema.AgenticMessage]{ + MessageOutput: &TypedMessageVariant[*schema.AgenticMessage]{ + Message: expected, + }, + }, + }, + } + msg, err := getMessageFromTypedWrappedEvent(wrapper) + assert.NoError(t, err) + assert.Equal(t, expected, msg) + }) + + t.Run("StreamingAlreadyConcatenated", func(t *testing.T) { + expected := agenticMsg("already concatenated") + wrapper := &typedAgentEventWrapper[*schema.AgenticMessage]{ + concatenatedMessage: expected, + event: &TypedAgentEvent[*schema.AgenticMessage]{ + Output: &TypedAgentOutput[*schema.AgenticMessage]{ + MessageOutput: &TypedMessageVariant[*schema.AgenticMessage]{ + IsStreaming: true, + }, + }, + }, + } + msg, err := getMessageFromTypedWrappedEvent(wrapper) + assert.NoError(t, err) + assert.Equal(t, expected, msg) + }) + + t.Run("StreamingWithPriorError", func(t *testing.T) { + testErr := errors.New("prior stream error") + wrapper := &typedAgentEventWrapper[*schema.AgenticMessage]{ + event: &TypedAgentEvent[*schema.AgenticMessage]{ + Output: &TypedAgentOutput[*schema.AgenticMessage]{ + MessageOutput: &TypedMessageVariant[*schema.AgenticMessage]{ + IsStreaming: true, + }, + }, + }, + } + wrapper.StreamErr = testErr + msg, err := getMessageFromTypedWrappedEvent(wrapper) + assert.Equal(t, testErr, err) + assert.Nil(t, msg) + }) +} + +func TestCoverage_GetMessageFromWrappedEvent_ErrorPaths(t *testing.T) { + t.Run("NilOutput", func(t *testing.T) { + wrapper := &agentEventWrapper{ + AgentEvent: &AgentEvent{}, + } + msg, err := getMessageFromWrappedEvent(wrapper) + assert.NoError(t, err) + assert.Nil(t, msg) + }) + + t.Run("NonStreaming", func(t *testing.T) { + expected := schema.AssistantMessage("hello", nil) + wrapper := &agentEventWrapper{ + AgentEvent: &AgentEvent{ + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + Message: expected, + }, + }, + }, + } + msg, err := getMessageFromWrappedEvent(wrapper) + assert.NoError(t, err) + assert.Equal(t, expected, msg) + }) + + t.Run("AlreadyConcatenated", func(t *testing.T) { + expected := schema.AssistantMessage("concatenated", nil) + wrapper := &agentEventWrapper{ + concatenatedMessage: expected, + AgentEvent: &AgentEvent{ + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: true, + }, + }, + }, + } + msg, err := getMessageFromWrappedEvent(wrapper) + assert.NoError(t, err) + assert.Equal(t, expected, msg) + }) + + t.Run("PriorStreamError", func(t *testing.T) { + testErr := errors.New("prior error") + wrapper := &agentEventWrapper{ + AgentEvent: &AgentEvent{ + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: true, + }, + }, + }, + } + wrapper.StreamErr = testErr + msg, err := getMessageFromWrappedEvent(wrapper) + assert.Equal(t, testErr, err) + assert.Nil(t, msg) + }) +} + +func TestCoverage_ConsumeStream_ErrorDuringRecv(t *testing.T) { + testErr := errors.New("stream recv error") + r, w := schema.Pipe[*schema.Message](2) + go func() { + w.Send(schema.AssistantMessage("partial", nil), nil) + w.Send(nil, testErr) + w.Close() + }() + + wrapper := &agentEventWrapper{ + AgentEvent: &AgentEvent{ + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: true, + MessageStream: r, + }, + }, + }, + } + + wrapper.consumeStream() + + assert.NotNil(t, wrapper.StreamErr) + assert.Nil(t, wrapper.concatenatedMessage) +} + +func TestCoverage_ConsumeStream_EmptyStream(t *testing.T) { + r, w := schema.Pipe[*schema.Message](1) + go func() { w.Close() }() + + wrapper := &agentEventWrapper{ + AgentEvent: &AgentEvent{ + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: true, + MessageStream: r, + }, + }, + }, + } + + wrapper.consumeStream() + + require.NotNil(t, wrapper.StreamErr) + assert.Contains(t, wrapper.StreamErr.Error(), "no messages") +} + +func TestCoverage_ConsumeStream_MultipleMessages(t *testing.T) { + r, w := schema.Pipe[*schema.Message](3) + go func() { + defer w.Close() + w.Send(schema.AssistantMessage("chunk1", nil), nil) + w.Send(schema.AssistantMessage("chunk2", nil), nil) + w.Send(schema.AssistantMessage("chunk3", nil), nil) + }() + + wrapper := &agentEventWrapper{ + AgentEvent: &AgentEvent{ + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: true, + MessageStream: r, + }, + }, + }, + } + + wrapper.consumeStream() + + assert.Nil(t, wrapper.StreamErr) + assert.NotNil(t, wrapper.concatenatedMessage) +} + +func TestCoverage_ConsumeStream_SingleMessage(t *testing.T) { + r, w := schema.Pipe[*schema.Message](1) + go func() { + defer w.Close() + w.Send(schema.AssistantMessage("single", nil), nil) + }() + + wrapper := &agentEventWrapper{ + AgentEvent: &AgentEvent{ + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: true, + MessageStream: r, + }, + }, + }, + } + + wrapper.consumeStream() + + assert.Nil(t, wrapper.StreamErr) + require.NotNil(t, wrapper.concatenatedMessage) + assert.Equal(t, "single", wrapper.concatenatedMessage.Content) +} + +func TestCoverage_ConsumeStream_Idempotent(t *testing.T) { + r, w := schema.Pipe[*schema.Message](1) + go func() { + defer w.Close() + w.Send(schema.AssistantMessage("once", nil), nil) + }() + + wrapper := &agentEventWrapper{ + AgentEvent: &AgentEvent{ + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: true, + MessageStream: r, + }, + }, + }, + } + + wrapper.consumeStream() + msg1 := wrapper.concatenatedMessage + + wrapper.consumeStream() + msg2 := wrapper.concatenatedMessage + + assert.Equal(t, msg1, msg2, "second call should be no-op") +} diff --git a/adk/agentic_test.go b/adk/agentic_test.go new file mode 100644 index 000000000..80f729e93 --- /dev/null +++ b/adk/agentic_test.go @@ -0,0 +1,1355 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +type mockAgenticModel struct { + generateFn func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) + streamFn func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.StreamReader[*schema.AgenticMessage], error) +} + +func (m *mockAgenticModel) Generate(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) { + return m.generateFn(ctx, input, opts...) +} + +func (m *mockAgenticModel) Stream(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.StreamReader[*schema.AgenticMessage], error) { + if m.streamFn != nil { + return m.streamFn(ctx, input, opts...) + } + result, err := m.generateFn(ctx, input, opts...) + if err != nil { + return nil, err + } + r, w := schema.Pipe[*schema.AgenticMessage](1) + go func() { defer w.Close(); w.Send(result, nil) }() + return r, nil +} + +type testAgenticMiddleware struct { + *TypedBaseChatModelAgentMiddleware[*schema.AgenticMessage] + beforeFn func(context.Context, *TypedChatModelAgentState[*schema.AgenticMessage], *ModelContext) (context.Context, *TypedChatModelAgentState[*schema.AgenticMessage], error) + afterFn func(context.Context, *TypedChatModelAgentState[*schema.AgenticMessage], *ModelContext) (context.Context, *TypedChatModelAgentState[*schema.AgenticMessage], error) +} + +func (m *testAgenticMiddleware) BeforeModelRewriteState(ctx context.Context, state *TypedChatModelAgentState[*schema.AgenticMessage], mc *ModelContext) (context.Context, *TypedChatModelAgentState[*schema.AgenticMessage], error) { + if m.beforeFn != nil { + return m.beforeFn(ctx, state, mc) + } + return ctx, state, nil +} + +func (m *testAgenticMiddleware) AfterModelRewriteState(ctx context.Context, state *TypedChatModelAgentState[*schema.AgenticMessage], mc *ModelContext) (context.Context, *TypedChatModelAgentState[*schema.AgenticMessage], error) { + if m.afterFn != nil { + return m.afterFn(ctx, state, mc) + } + return ctx, state, nil +} + +func TestAgenticChatModelAgentRun_NoTools(t *testing.T) { + ctx := context.Background() + + agenticResponse := &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.AssistantGenText{Text: "Hello from agentic model"}), + }, + } + + m := &mockAgenticModel{ + generateFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) { + return agenticResponse, nil + }, + } + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "AgenticTestAgent", + Description: "Agentic test agent", + Instruction: "You are helpful.", + Model: m, + }) + assert.NoError(t, err) + assert.NotNil(t, agent) + + input := &TypedAgentInput[*schema.AgenticMessage]{ + Messages: []*schema.AgenticMessage{ + schema.UserAgenticMessage("Hi"), + }, + } + iter := agent.Run(ctx, input) + require.NotNil(t, iter) + + event, ok := iter.Next() + assert.True(t, ok) + require.NotNil(t, event) + assert.Nil(t, event.Err) + require.NotNil(t, event.Output) + require.NotNil(t, event.Output.MessageOutput) + + msg := event.Output.MessageOutput.Message + require.NotNil(t, msg) + assert.Equal(t, schema.AgenticRoleTypeAssistant, msg.Role) + assert.Len(t, msg.ContentBlocks, 1) + assert.Equal(t, "Hello from agentic model", msg.ContentBlocks[0].AssistantGenText.Text) + + _, ok = iter.Next() + assert.False(t, ok) +} + +func TestAgenticChatModelAgentRun_WithTools(t *testing.T) { + ctx := context.Background() + + agenticResponse := &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.AssistantGenText{Text: "Used tool and got result"}), + }, + } + + var receivedToolInfos []*schema.ToolInfo + m := &mockAgenticModel{ + generateFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) { + o := model.GetCommonOptions(&model.Options{}, opts...) + receivedToolInfos = o.Tools + return agenticResponse, nil + }, + } + + dummyTool := newSlowTool("dummy_tool", 0, "ok") + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "AgenticToolAgent", + Description: "Agentic agent with tools", + Instruction: "You are helpful.", + Model: m, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{dummyTool}, + }, + }, + }) + assert.NoError(t, err) + assert.NotNil(t, agent) + + input := &TypedAgentInput[*schema.AgenticMessage]{ + Messages: []*schema.AgenticMessage{ + schema.UserAgenticMessage("Call a tool"), + }, + } + iter := agent.Run(ctx, input) + + event, ok := iter.Next() + assert.True(t, ok) + assert.Nil(t, event.Err) + assert.NotNil(t, event.Output) + + _, ok = iter.Next() + assert.False(t, ok) + + require.Len(t, receivedToolInfos, 1) + assert.Equal(t, "dummy_tool", receivedToolInfos[0].Name) +} + +func TestAgenticChatModelAgentRun_Streaming(t *testing.T) { + ctx := context.Background() + + chunk1 := &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.AssistantGenText{Text: "Hello "}), + }, + } + chunk2 := &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.AssistantGenText{Text: "world"}), + }, + } + + m := &mockAgenticModel{ + streamFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.StreamReader[*schema.AgenticMessage], error) { + r, w := schema.Pipe[*schema.AgenticMessage](2) + go func() { + defer w.Close() + w.Send(chunk1, nil) + w.Send(chunk2, nil) + }() + return r, nil + }, + } + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "AgenticStreamAgent", + Description: "Agentic streaming agent", + Instruction: "You are helpful.", + Model: m, + }) + assert.NoError(t, err) + + input := &TypedAgentInput[*schema.AgenticMessage]{ + Messages: []*schema.AgenticMessage{ + schema.UserAgenticMessage("Hi"), + }, + EnableStreaming: true, + } + iter := agent.Run(ctx, input) + + event, ok := iter.Next() + assert.True(t, ok) + assert.Nil(t, event.Err) + require.NotNil(t, event.Output) + require.NotNil(t, event.Output.MessageOutput) + require.NotNil(t, event.Output.MessageOutput.MessageStream) + event.Output.MessageOutput.MessageStream.Close() + + _, ok = iter.Next() + assert.False(t, ok) +} + +func TestDefaultAgenticGenModelInput(t *testing.T) { + ctx := context.Background() + + t.Run("WithInstruction", func(t *testing.T) { + input := &TypedAgentInput[*schema.AgenticMessage]{ + Messages: []*schema.AgenticMessage{ + schema.UserAgenticMessage("Hello"), + }, + } + msgs, err := newDefaultGenModelInput[*schema.AgenticMessage]()(ctx, "Be helpful", input) + assert.NoError(t, err) + assert.Len(t, msgs, 2) + assert.Equal(t, schema.AgenticRoleTypeSystem, msgs[0].Role) + assert.Equal(t, schema.AgenticRoleTypeUser, msgs[1].Role) + }) + + t.Run("WithoutInstruction", func(t *testing.T) { + input := &TypedAgentInput[*schema.AgenticMessage]{ + Messages: []*schema.AgenticMessage{ + schema.UserAgenticMessage("Hello"), + }, + } + msgs, err := newDefaultGenModelInput[*schema.AgenticMessage]()(ctx, "", input) + assert.NoError(t, err) + assert.Len(t, msgs, 1) + assert.Equal(t, schema.AgenticRoleTypeUser, msgs[0].Role) + }) +} + +func TestAgenticRunnerQuery(t *testing.T) { + ctx := context.Background() + + agenticResponse := &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.AssistantGenText{Text: "query response"}), + }, + } + + m := &mockAgenticModel{ + generateFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) { + return agenticResponse, nil + }, + } + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "QueryAgent", + Description: "Query test agent", + Instruction: "Be helpful.", + Model: m, + }) + assert.NoError(t, err) + + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{ + Agent: agent, + }) + + iter := runner.Query(ctx, "What's up?") + + event, ok := iter.Next() + assert.True(t, ok) + assert.Nil(t, event.Err) + + _, ok = iter.Next() + assert.False(t, ok) +} + +func agenticAssistantMessage(text string) *schema.AgenticMessage { + return &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.AssistantGenText{Text: text}), + }, + } +} + +type mockAgenticRunnerAgent struct { + name string + description string + responses []*TypedAgentEvent[*schema.AgenticMessage] + callCount int + lastInput *TypedAgentInput[*schema.AgenticMessage] + enableStreaming bool +} + +func (a *mockAgenticRunnerAgent) Name(_ context.Context) string { return a.name } +func (a *mockAgenticRunnerAgent) Description(_ context.Context) string { return a.description } +func (a *mockAgenticRunnerAgent) Run(_ context.Context, input *TypedAgentInput[*schema.AgenticMessage], _ ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[*schema.AgenticMessage]] { + a.callCount++ + a.lastInput = input + a.enableStreaming = input.EnableStreaming + + iterator, generator := NewAsyncIteratorPair[*TypedAgentEvent[*schema.AgenticMessage]]() + go func() { + defer generator.Close() + for _, event := range a.responses { + generator.Send(event) + if event.Action != nil && event.Action.Exit { + break + } + } + }() + return iterator +} + +type mockAgenticAgent struct { + name string + description string + responses []*TypedAgentEvent[*schema.AgenticMessage] +} + +func (a *mockAgenticAgent) Name(_ context.Context) string { return a.name } +func (a *mockAgenticAgent) Description(_ context.Context) string { return a.description } +func (a *mockAgenticAgent) Run(_ context.Context, _ *TypedAgentInput[*schema.AgenticMessage], _ ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[*schema.AgenticMessage]] { + iterator, generator := NewAsyncIteratorPair[*TypedAgentEvent[*schema.AgenticMessage]]() + go func() { + defer generator.Close() + for _, event := range a.responses { + generator.Send(event) + if event.Action != nil && event.Action.Exit { + break + } + } + }() + return iterator +} + +type myAgenticAgent struct { + name string + runFn func(ctx context.Context, input *TypedAgentInput[*schema.AgenticMessage], options ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[*schema.AgenticMessage]] + resumeFn func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[*schema.AgenticMessage]] +} + +func (m *myAgenticAgent) Name(_ context.Context) string { + if len(m.name) > 0 { + return m.name + } + return "myAgenticAgent" +} +func (m *myAgenticAgent) Description(_ context.Context) string { return "my agentic agent description" } +func (m *myAgenticAgent) Run(ctx context.Context, input *TypedAgentInput[*schema.AgenticMessage], options ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[*schema.AgenticMessage]] { + return m.runFn(ctx, input, options...) +} +func (m *myAgenticAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[*schema.AgenticMessage]] { + return m.resumeFn(ctx, info, opts...) +} + +func TestAgenticChatModelAgentRun_WithMiddleware(t *testing.T) { + ctx := context.Background() + + agenticResponse := &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.AssistantGenText{Text: "Hello from agentic agent"}), + }, + } + + m := &mockAgenticModel{ + generateFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) { + return agenticResponse, nil + }, + } + + afterModelExecuted := false + + mw := &testAgenticMiddleware{ + beforeFn: func(ctx context.Context, state *TypedChatModelAgentState[*schema.AgenticMessage], mc *ModelContext) (context.Context, *TypedChatModelAgentState[*schema.AgenticMessage], error) { + state.Messages = append(state.Messages, schema.UserAgenticMessage("extra")) + return ctx, state, nil + }, + afterFn: func(ctx context.Context, state *TypedChatModelAgentState[*schema.AgenticMessage], mc *ModelContext) (context.Context, *TypedChatModelAgentState[*schema.AgenticMessage], error) { + assert.Len(t, state.Messages, 4) + afterModelExecuted = true + return ctx, state, nil + }, + } + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "AgenticMiddlewareAgent", + Description: "Agentic agent with middleware", + Instruction: "You are helpful.", + Model: m, + Handlers: []TypedChatModelAgentMiddleware[*schema.AgenticMessage]{mw}, + }) + assert.NoError(t, err) + + input := &TypedAgentInput[*schema.AgenticMessage]{ + Messages: []*schema.AgenticMessage{ + schema.UserAgenticMessage("Hi"), + }, + } + iter := agent.Run(ctx, input) + event, ok := iter.Next() + assert.True(t, ok) + assert.Nil(t, event.Err) + require.NotNil(t, event.Output) + require.NotNil(t, event.Output.MessageOutput) + require.NotNil(t, event.Output.MessageOutput.Message) + assert.Equal(t, schema.AgenticRoleTypeAssistant, event.Output.MessageOutput.Message.Role) + _, ok = iter.Next() + assert.False(t, ok) + assert.True(t, afterModelExecuted) +} + +func TestAgenticAfterModel_NoTools_ModifyDoesNotAffectEvent(t *testing.T) { + ctx := context.Background() + + m := &mockAgenticModel{ + generateFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) { + return agenticAssistantMessage("original content"), nil + }, + } + + var capturedMessages []*schema.AgenticMessage + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "AgenticAfterModelAgent", + Description: "Test AfterModelRewriteState", + Instruction: "You are helpful.", + Model: m, + Handlers: []TypedChatModelAgentMiddleware[*schema.AgenticMessage]{ + &testAgenticMiddleware{ + afterFn: func(ctx context.Context, state *TypedChatModelAgentState[*schema.AgenticMessage], mc *ModelContext) (context.Context, *TypedChatModelAgentState[*schema.AgenticMessage], error) { + capturedMessages = make([]*schema.AgenticMessage, len(state.Messages)) + copy(capturedMessages, state.Messages) + state.Messages = append(state.Messages, agenticAssistantMessage("appended content")) + return ctx, state, nil + }, + }, + }, + }) + assert.NoError(t, err) + + input := &TypedAgentInput[*schema.AgenticMessage]{ + Messages: []*schema.AgenticMessage{ + schema.UserAgenticMessage("Hello"), + }, + } + iterator := agent.Run(ctx, input) + + event, ok := iterator.Next() + assert.True(t, ok) + assert.Nil(t, event.Err) + require.NotNil(t, event.Output) + require.NotNil(t, event.Output.MessageOutput) + + msg := event.Output.MessageOutput.Message + require.NotNil(t, msg) + assert.Equal(t, "original content", msg.ContentBlocks[0].AssistantGenText.Text) + + _, ok = iterator.Next() + assert.False(t, ok) + + assert.Len(t, capturedMessages, 3) +} + +func TestAgenticGetComposeOptions_WithChatModelOptions(t *testing.T) { + ctx := context.Background() + + var capturedTemperature float32 + m := &mockAgenticModel{ + generateFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) { + options := model.GetCommonOptions(&model.Options{}, opts...) + if options.Temperature != nil { + capturedTemperature = *options.Temperature + } + return agenticAssistantMessage("response"), nil + }, + } + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "AgenticOptionsAgent", + Description: "Test agent", + Model: m, + }) + assert.NoError(t, err) + + temp := float32(0.7) + iter := agent.Run(ctx, &TypedAgentInput[*schema.AgenticMessage]{Messages: []*schema.AgenticMessage{schema.UserAgenticMessage("test")}}, + WithChatModelOptions([]model.Option{model.WithTemperature(temp)})) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + assert.Equal(t, temp, capturedTemperature) +} + +func TestAgenticChatModelAgent_PrepareExecContextError(t *testing.T) { + ctx := context.Background() + + expectedErr := errors.New("tool info error") + errTool := &errorTool{infoErr: expectedErr} + + m := &mockAgenticModel{ + generateFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) { + return agenticAssistantMessage("response"), nil + }, + } + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "AgenticErrToolAgent", + Description: "Test agent", + Model: m, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{errTool}, + }, + }, + }) + assert.NoError(t, err) + + iter := agent.Run(ctx, &TypedAgentInput[*schema.AgenticMessage]{Messages: []*schema.AgenticMessage{schema.UserAgenticMessage("test")}}) + + event, ok := iter.Next() + assert.True(t, ok) + assert.NotNil(t, event.Err) + assert.Contains(t, event.Err.Error(), "tool info error") + + _, ok = iter.Next() + assert.False(t, ok) +} + +func TestAgenticChatModelAgentOutputKey(t *testing.T) { + t.Run("OutputKeyStoresInSession", func(t *testing.T) { + ctx := context.Background() + + m := &mockAgenticModel{ + generateFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) { + return agenticAssistantMessage("Hello from agentic assistant."), nil + }, + } + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "AgenticOutputKeyAgent", + Description: "Test agent for output key", + Instruction: "You are helpful.", + Model: m, + OutputKey: "agent_output", + }) + assert.NoError(t, err) + + input := &TypedAgentInput[*schema.AgenticMessage]{ + Messages: []*schema.AgenticMessage{ + schema.UserAgenticMessage("Hello"), + }, + } + ctx, runCtx := initTypedRunCtx[*schema.AgenticMessage](ctx, "AgenticOutputKeyAgent", input) + require.NotNil(t, runCtx) + require.NotNil(t, runCtx.Session) + + iterator := agent.Run(ctx, input) + + event, ok := iterator.Next() + assert.True(t, ok) + assert.Nil(t, event.Err) + + msg := event.Output.MessageOutput.Message + assert.Equal(t, "Hello from agentic assistant.", msg.ContentBlocks[0].AssistantGenText.Text) + + _, ok = iterator.Next() + assert.False(t, ok) + + sessionValues := GetSessionValues(ctx) + assert.Contains(t, sessionValues, "agent_output") + assert.Equal(t, "Hello from agentic assistant.", sessionValues["agent_output"]) + }) + + t.Run("OutputKeyWithStreamingStoresInSession", func(t *testing.T) { + ctx := context.Background() + + chunk1 := agenticAssistantMessage("Hello") + chunk2 := agenticAssistantMessage(", world.") + + m := &mockAgenticModel{ + streamFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.StreamReader[*schema.AgenticMessage], error) { + r, w := schema.Pipe[*schema.AgenticMessage](2) + go func() { + defer w.Close() + w.Send(chunk1, nil) + w.Send(chunk2, nil) + }() + return r, nil + }, + } + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "AgenticStreamOutputKeyAgent", + Description: "Test agent for streaming output key", + Instruction: "You are helpful.", + Model: m, + OutputKey: "agent_output", + }) + assert.NoError(t, err) + + input := &TypedAgentInput[*schema.AgenticMessage]{ + Messages: []*schema.AgenticMessage{ + schema.UserAgenticMessage("Hello"), + }, + EnableStreaming: true, + } + ctx, runCtx := initTypedRunCtx[*schema.AgenticMessage](ctx, "AgenticStreamOutputKeyAgent", input) + require.NotNil(t, runCtx) + require.NotNil(t, runCtx.Session) + + iterator := agent.Run(ctx, input) + + event, ok := iterator.Next() + assert.True(t, ok) + assert.Nil(t, event.Err) + assert.True(t, event.Output.MessageOutput.IsStreaming) + + _, ok = iterator.Next() + assert.False(t, ok) + }) + + t.Run("SetOutputToSessionAgenticMessage", func(t *testing.T) { + ctx := context.Background() + + input := &TypedAgentInput[*schema.AgenticMessage]{ + Messages: []*schema.AgenticMessage{schema.UserAgenticMessage("test")}, + } + ctx, runCtx := initTypedRunCtx[*schema.AgenticMessage](ctx, "TestAgent", input) + require.NotNil(t, runCtx) + require.NotNil(t, runCtx.Session) + + msg := agenticAssistantMessage("Test response") + err := setOutputToSession(ctx, msg, nil, "test_output") + assert.NoError(t, err) + + sessionValues := GetSessionValues(ctx) + assert.Contains(t, sessionValues, "test_output") + assert.Equal(t, "Test response", sessionValues["test_output"]) + }) +} + +func TestAgenticRunner_Run_WithStreaming(t *testing.T) { + ctx := context.Background() + + mockAgent_ := &mockAgenticRunnerAgent{ + name: "AgenticStreamRunnerAgent", + description: "Test agent for agentic runner streaming", + responses: []*TypedAgentEvent[*schema.AgenticMessage]{ + { + AgentName: "AgenticStreamRunnerAgent", + Output: &TypedAgentOutput[*schema.AgenticMessage]{ + MessageOutput: &TypedMessageVariant[*schema.AgenticMessage]{ + IsStreaming: true, + MessageStream: schema.StreamReaderFromArray([]*schema.AgenticMessage{ + agenticAssistantMessage("Streaming response"), + }), + }, + }, + }, + }, + } + + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{EnableStreaming: true, Agent: mockAgent_}) + + msgs := []*schema.AgenticMessage{ + schema.UserAgenticMessage("Hello, agent!"), + } + + iterator := runner.Run(ctx, msgs) + + assert.Equal(t, 1, mockAgent_.callCount) + assert.Equal(t, msgs, mockAgent_.lastInput.Messages) + assert.True(t, mockAgent_.enableStreaming) + + event, ok := iterator.Next() + assert.True(t, ok) + assert.Equal(t, "AgenticStreamRunnerAgent", event.AgentName) + require.NotNil(t, event.Output) + require.NotNil(t, event.Output.MessageOutput) + assert.True(t, event.Output.MessageOutput.IsStreaming) + + _, ok = iterator.Next() + assert.False(t, ok) +} + +func TestAgenticRunner_Query_WithStreaming(t *testing.T) { + ctx := context.Background() + + mockAgent_ := &mockAgenticRunnerAgent{ + name: "AgenticStreamQueryAgent", + description: "Test agent for agentic runner query streaming", + responses: []*TypedAgentEvent[*schema.AgenticMessage]{ + { + AgentName: "AgenticStreamQueryAgent", + Output: &TypedAgentOutput[*schema.AgenticMessage]{ + MessageOutput: &TypedMessageVariant[*schema.AgenticMessage]{ + IsStreaming: true, + MessageStream: schema.StreamReaderFromArray([]*schema.AgenticMessage{ + agenticAssistantMessage("Streaming query response"), + }), + }, + }, + }, + }, + } + + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{EnableStreaming: true, Agent: mockAgent_}) + + iterator := runner.Query(ctx, "Test query") + + assert.Equal(t, 1, mockAgent_.callCount) + assert.Len(t, mockAgent_.lastInput.Messages, 1) + assert.True(t, mockAgent_.enableStreaming) + + event, ok := iterator.Next() + assert.True(t, ok) + assert.Equal(t, "AgenticStreamQueryAgent", event.AgentName) + require.NotNil(t, event.Output) + require.NotNil(t, event.Output.MessageOutput) + assert.True(t, event.Output.MessageOutput.IsStreaming) + + _, ok = iterator.Next() + assert.False(t, ok) +} + +func TestAgenticSimpleInterrupt(t *testing.T) { + data := "hello world" + agent := &myAgenticAgent{ + runFn: func(ctx context.Context, input *TypedAgentInput[*schema.AgenticMessage], options ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[*schema.AgenticMessage]] { + iter, generator := NewAsyncIteratorPair[*TypedAgentEvent[*schema.AgenticMessage]]() + generator.Send(&TypedAgentEvent[*schema.AgenticMessage]{ + Output: &TypedAgentOutput[*schema.AgenticMessage]{ + MessageOutput: &TypedMessageVariant[*schema.AgenticMessage]{ + IsStreaming: true, + MessageStream: schema.StreamReaderFromArray([]*schema.AgenticMessage{ + schema.UserAgenticMessage("hello "), + schema.UserAgenticMessage("world"), + }), + }, + }, + }) + intEvent := TypedInterrupt[*schema.AgenticMessage](ctx, data) + intEvent.Action.Interrupted.Data = data + generator.Send(intEvent) + generator.Close() + return iter + }, + resumeFn: func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[*schema.AgenticMessage]] { + assert.True(t, info.WasInterrupted) + assert.Nil(t, info.InterruptState) + assert.True(t, info.EnableStreaming) + assert.Equal(t, data, info.Data) + + assert.True(t, info.IsResumeTarget) + iter, generator := NewAsyncIteratorPair[*TypedAgentEvent[*schema.AgenticMessage]]() + generator.Close() + return iter + }, + } + store := newMyStore() + ctx := context.Background() + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{ + Agent: agent, + EnableStreaming: true, + CheckPointStore: store, + }) + iter := runner.Query(ctx, "hello world", WithCheckPointID("1")) + + var interruptEvent *TypedAgentEvent[*schema.AgenticMessage] + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Action != nil && event.Action.Interrupted != nil { + interruptEvent = event + } + } + + require.NotNil(t, interruptEvent) + assert.Equal(t, data, interruptEvent.Action.Interrupted.Data) + assert.NotEmpty(t, interruptEvent.Action.Interrupted.InterruptContexts[0].ID) + assert.True(t, interruptEvent.Action.Interrupted.InterruptContexts[0].IsRootCause) + assert.Equal(t, data, interruptEvent.Action.Interrupted.InterruptContexts[0].Info) + assert.Equal(t, Address{{Type: AddressSegmentAgent, ID: "myAgenticAgent"}}, + interruptEvent.Action.Interrupted.InterruptContexts[0].Address) +} + +func TestCascadingFrom_NewChatModelAgentFrom(t *testing.T) { + ctx := context.Background() + + agenticResponse := &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.AssistantGenText{Text: "from response"}), + }, + } + + m := &mockAgenticModel{ + generateFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) { + return agenticResponse, nil + }, + } + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "FromAgent", + Description: "Test cascading constructor", + Instruction: "Be helpful.", + Model: m, + }) + assert.NoError(t, err) + assert.Equal(t, "FromAgent", agent.Name(ctx)) + + runner := NewTypedRunner(TypedRunnerConfig[*schema.AgenticMessage]{Agent: agent}) + + iter := runner.Run(ctx, []*schema.AgenticMessage{ + schema.UserAgenticMessage("Hello"), + }) + + event, ok := iter.Next() + assert.True(t, ok) + assert.Nil(t, event.Err) + assert.NotNil(t, event.Output) + + _, ok = iter.Next() + assert.False(t, ok) +} + +func TestCascadingTyped_TypedStatefulInterrupt(t *testing.T) { + ctx := context.Background() + ctx = AppendAddressSegment(ctx, AddressSegmentAgent, "test-agent") + + type myState struct { + Count int + } + + event := TypedStatefulInterrupt[*schema.AgenticMessage](ctx, "please confirm", &myState{Count: 42}) + require.NotNil(t, event) + require.NotNil(t, event.Action) + require.NotNil(t, event.Action.Interrupted) +} + +func TestCascadingTyped_EventFromAgenticMessage(t *testing.T) { + msg := &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.AssistantGenText{Text: "hello"}), + }, + } + + event := EventFromAgenticMessage(msg, nil, schema.AgenticRoleTypeAssistant) + require.NotNil(t, event) + require.NotNil(t, event.Output) + require.NotNil(t, event.Output.MessageOutput) + assert.Equal(t, msg, event.Output.MessageOutput.Message) + assert.False(t, event.Output.MessageOutput.IsStreaming) + assert.Equal(t, schema.RoleType(""), event.Output.MessageOutput.Role) + assert.Equal(t, schema.AgenticRoleTypeAssistant, event.Output.MessageOutput.AgenticRole) + assert.Empty(t, event.Output.MessageOutput.ToolName) +} + +// assertAgenticEventRoleFields asserts that all AgenticMessage events in the +// list have zero-valued Role and ToolName fields (which are *schema.Message-only), +// and that AgenticRole is populated with a non-zero value. +func assertAgenticEventRoleFields(t *testing.T, events []*TypedAgentEvent[*schema.AgenticMessage]) { + t.Helper() + for i, event := range events { + if event.Output == nil || event.Output.MessageOutput == nil { + continue + } + mo := event.Output.MessageOutput + assert.Equal(t, schema.RoleType(""), mo.Role, "event[%d]: AgenticMessage must have zero Role", i) + assert.Empty(t, mo.ToolName, "event[%d]: AgenticMessage must have empty ToolName", i) + assert.NotEmpty(t, mo.AgenticRole, "event[%d]: AgenticMessage must have non-zero AgenticRole", i) + } +} + +func TestCoverage_FlowAgent_ResumeNotResumable(t *testing.T) { + ctx := context.Background() + + agent := &mockAgenticAgent{ + name: "non-resumable", + description: "cannot resume", + responses: []*TypedAgentEvent[*schema.AgenticMessage]{ + {Output: &TypedAgentOutput[*schema.AgenticMessage]{ + MessageOutput: &TypedMessageVariant[*schema.AgenticMessage]{ + Message: agenticMsg("done"), + }, + }}, + }, + } + + fa := toTypedFlowAgent[*schema.AgenticMessage](agent) + + info := &ResumeInfo{WasInterrupted: true} + iter := fa.Resume(ctx, info) + + var capturedErr error + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil { + capturedErr = event.Err + } + } + require.Error(t, capturedErr, "should get error for non-resumable agent") +} + +func TestCoverage_GenAgenticErrorIter(t *testing.T) { + testErr := errors.New("test agentic error") + iter := genAgenticErrorIter(testErr) + + event, ok := iter.Next() + require.True(t, ok) + assert.Equal(t, testErr, event.Err) + + _, ok = iter.Next() + assert.False(t, ok) +} + +func TestCoverage_ChatModelAgent_OnSetSubAgents_FrozenError(t *testing.T) { + ctx := context.Background() + + m := &mockAgenticModel{ + generateFn: func(_ context.Context, _ []*schema.AgenticMessage, _ ...model.Option) (*schema.AgenticMessage, error) { + return agenticMsg("done"), nil + }, + } + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "freeze-test", + Description: "frozen test agent", + Model: m, + }) + require.NoError(t, err) + + input := &TypedAgentInput[*schema.AgenticMessage]{ + Messages: []*schema.AgenticMessage{schema.UserAgenticMessage("Hi")}, + } + iter := agent.Run(ctx, input) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + err = agent.OnSetSubAgents(ctx, []TypedAgent[*schema.AgenticMessage]{ + &mockAgenticAgent{name: "late-child"}, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "frozen") +} + +func TestCoverage_ChatModelAgent_OnSetAsSubAgent_FrozenError(t *testing.T) { + ctx := context.Background() + + m := &mockAgenticModel{ + generateFn: func(_ context.Context, _ []*schema.AgenticMessage, _ ...model.Option) (*schema.AgenticMessage, error) { + return agenticMsg("done"), nil + }, + } + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "freeze-child", + Description: "frozen child agent", + Model: m, + }) + require.NoError(t, err) + + input := &TypedAgentInput[*schema.AgenticMessage]{ + Messages: []*schema.AgenticMessage{schema.UserAgenticMessage("Hi")}, + } + iter := agent.Run(ctx, input) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + err = agent.OnSetAsSubAgent(ctx, &mockAgenticAgent{name: "parent"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "frozen") +} + +func TestCoverage_ChatModelAgent_OnSetAsSubAgent_DuplicateError(t *testing.T) { + ctx := context.Background() + + m := &mockAgenticModel{ + generateFn: func(_ context.Context, _ []*schema.AgenticMessage, _ ...model.Option) (*schema.AgenticMessage, error) { + return agenticMsg("done"), nil + }, + } + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "dup-child", + Description: "duplicate child agent", + Model: m, + }) + require.NoError(t, err) + + err = agent.OnSetAsSubAgent(ctx, &mockAgenticAgent{name: "parent1"}) + assert.NoError(t, err) + + err = agent.OnSetAsSubAgent(ctx, &mockAgenticAgent{name: "parent2"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "already been set as a sub-agent") +} + +func TestCoverage_ChatModelAgent_OnDisallowTransferToParent_FrozenError(t *testing.T) { + ctx := context.Background() + + m := &mockAgenticModel{ + generateFn: func(_ context.Context, _ []*schema.AgenticMessage, _ ...model.Option) (*schema.AgenticMessage, error) { + return agenticMsg("done"), nil + }, + } + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "disallow-test", + Description: "disallow transfer test", + Model: m, + }) + require.NoError(t, err) + + input := &TypedAgentInput[*schema.AgenticMessage]{ + Messages: []*schema.AgenticMessage{schema.UserAgenticMessage("Hi")}, + } + iter := agent.Run(ctx, input) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + err = agent.OnDisallowTransferToParent(ctx) + assert.Error(t, err) + assert.Contains(t, err.Error(), "frozen") +} + +func TestCoverage_TypedGetMessage_AgenticNonStreaming(t *testing.T) { + msg := agenticMsg("hello") + event := &TypedAgentEvent[*schema.AgenticMessage]{ + Output: &TypedAgentOutput[*schema.AgenticMessage]{ + MessageOutput: &TypedMessageVariant[*schema.AgenticMessage]{ + Message: msg, + }, + }, + } + + result, retEvent, err := TypedGetMessage(event) + assert.NoError(t, err) + assert.Equal(t, msg, result) + assert.Equal(t, event, retEvent) +} + +func TestCoverage_TypedGetMessage_AgenticStreaming(t *testing.T) { + r, w := schema.Pipe[*schema.AgenticMessage](2) + go func() { + defer w.Close() + w.Send(&schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.AssistantGenText{Text: "Hello "}), + }, + }, nil) + w.Send(&schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.AssistantGenText{Text: "world"}), + }, + }, nil) + }() + + event := &TypedAgentEvent[*schema.AgenticMessage]{ + Output: &TypedAgentOutput[*schema.AgenticMessage]{ + MessageOutput: &TypedMessageVariant[*schema.AgenticMessage]{ + IsStreaming: true, + MessageStream: r, + }, + }, + } + + result, retEvent, err := TypedGetMessage(event) + assert.NoError(t, err) + assert.NotNil(t, result) + require.NotNil(t, retEvent) + assert.NotNil(t, retEvent.Output.MessageOutput.MessageStream) +} + +func TestCoverage_TypedGetMessage_NilOutput(t *testing.T) { + event := &TypedAgentEvent[*schema.AgenticMessage]{} + + result, retEvent, err := TypedGetMessage(event) + assert.NoError(t, err) + assert.Nil(t, result) + assert.Equal(t, event, retEvent) +} + +func TestCoverage_GetMessage_NonStreaming(t *testing.T) { + msg := schema.AssistantMessage("hello", nil) + event := &AgentEvent{ + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + Message: msg, + }, + }, + } + + result, retEvent, err := GetMessage(event) + assert.NoError(t, err) + assert.Equal(t, msg, result) + assert.Equal(t, event, retEvent) +} + +func TestCoverage_GetMessage_Streaming(t *testing.T) { + r, w := schema.Pipe[*schema.Message](2) + go func() { + defer w.Close() + w.Send(schema.AssistantMessage("Hello ", nil), nil) + w.Send(schema.AssistantMessage("world", nil), nil) + }() + + event := &AgentEvent{ + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: true, + MessageStream: r, + }, + }, + } + + result, retEvent, err := GetMessage(event) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.NotNil(t, retEvent) +} + +func TestCoverage_NewTypedAgentTool_Agentic(t *testing.T) { + ctx := context.Background() + + m := &mockAgenticModel{ + generateFn: func(_ context.Context, _ []*schema.AgenticMessage, _ ...model.Option) (*schema.AgenticMessage, error) { + return agenticMsg("tool response"), nil + }, + } + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "tool-agent", + Description: "agent wrapped as tool", + Model: m, + }) + require.NoError(t, err) + + agentTool := NewTypedAgentTool[*schema.AgenticMessage](ctx, agent) + + info, err := agentTool.Info(ctx) + require.NoError(t, err) + assert.Equal(t, "tool-agent", info.Name) + + result, err := agentTool.(tool.InvokableTool).InvokableRun(ctx, `{"request":"test"}`) + require.NoError(t, err) + assert.Contains(t, result, "tool response") +} +func TestCoverage_CopyAgenticEvent(t *testing.T) { + original := &TypedAgentEvent[*schema.AgenticMessage]{ + AgentName: "agent1", + RunPath: []RunStep{{agentName: "root"}, {agentName: "agent1"}}, + Output: &TypedAgentOutput[*schema.AgenticMessage]{ + MessageOutput: &TypedMessageVariant[*schema.AgenticMessage]{ + Message: agenticMsg("hello"), + }, + }, + Action: &AgentAction{ + TransferToAgent: &TransferToAgentAction{DestAgentName: "agent2"}, + }, + } + + copied := copyTypedAgentEvent(original) + assert.Equal(t, original.AgentName, copied.AgentName) + assert.Equal(t, len(original.RunPath), len(copied.RunPath)) + assert.Equal(t, original.Action, copied.Action) + + copied.RunPath[0].agentName = "mutated" + assert.NotEqual(t, original.RunPath[0].agentName, copied.RunPath[0].agentName) +} + +func TestCoverage_ChatModelAgent_ModelGenerateError(t *testing.T) { + ctx := context.Background() + + testErr := errors.New("model generate failed") + m := &mockAgenticModel{ + generateFn: func(_ context.Context, _ []*schema.AgenticMessage, _ ...model.Option) (*schema.AgenticMessage, error) { + return nil, testErr + }, + } + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "error-model-agent", + Description: "tests model generate error", + Model: m, + }) + require.NoError(t, err) + + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{ + Agent: agent, + }) + + iter := runner.Query(ctx, "trigger error") + + var capturedErr error + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil { + capturedErr = event.Err + } + } + require.Error(t, capturedErr, "should propagate model error") +} + +func TestCoverage_NewTypedUserMessages(t *testing.T) { + t.Run("Message", func(t *testing.T) { + msgs := newTypedUserMessages[*schema.Message]("hello") + require.Len(t, msgs, 1) + assert.Equal(t, schema.User, msgs[0].Role) + assert.Equal(t, "hello", msgs[0].Content) + }) + + t.Run("AgenticMessage", func(t *testing.T) { + msgs := newTypedUserMessages[*schema.AgenticMessage]("hello") + require.Len(t, msgs, 1) + assert.Equal(t, schema.AgenticRoleTypeUser, msgs[0].Role) + }) +} + +func TestCoverage_TypedEndpointModel_NilEndpoints(t *testing.T) { + ctx := context.Background() + + m := &typedEndpointModel[*schema.AgenticMessage]{} + + _, err := m.Generate(ctx, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "generate endpoint not set") + + _, err = m.Stream(ctx, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "stream endpoint not set") +} + +func TestCoverage_TypedEndpointModel_WithEndpoints(t *testing.T) { + ctx := context.Background() + + expected := agenticMsg("generated") + m := &typedEndpointModel[*schema.AgenticMessage]{ + generate: func(_ context.Context, _ []*schema.AgenticMessage, _ ...model.Option) (*schema.AgenticMessage, error) { + return expected, nil + }, + stream: func(_ context.Context, _ []*schema.AgenticMessage, _ ...model.Option) (*schema.StreamReader[*schema.AgenticMessage], error) { + r, w := schema.Pipe[*schema.AgenticMessage](1) + go func() { + defer w.Close() + w.Send(expected, nil) + }() + return r, nil + }, + } + + result, err := m.Generate(ctx, nil) + assert.NoError(t, err) + assert.Equal(t, expected, result) + + stream, err := m.Stream(ctx, nil) + assert.NoError(t, err) + require.NotNil(t, stream) + msg, err := stream.Recv() + assert.NoError(t, err) + assert.Equal(t, expected, msg) + _, err = stream.Recv() + assert.Equal(t, io.EOF, err) +} + +func TestCoverage_SetAutomaticClose(t *testing.T) { + r, w := schema.Pipe[*schema.AgenticMessage](1) + go func() { + defer w.Close() + w.Send(agenticMsg("data"), nil) + }() + + event := &TypedAgentEvent[*schema.AgenticMessage]{ + Output: &TypedAgentOutput[*schema.AgenticMessage]{ + MessageOutput: &TypedMessageVariant[*schema.AgenticMessage]{ + IsStreaming: true, + MessageStream: r, + }, + }, + } + + typedSetAutomaticClose(event) +} + +func TestConcatMessageStream_AgenticClosesStream(t *testing.T) { + r, w := schema.Pipe[*schema.AgenticMessage](2) + go func() { + defer w.Close() + w.Send(agenticMsg("a"), nil) + w.Send(agenticMsg("b"), nil) + }() + + result, err := concatMessageStream(r) + require.NoError(t, err) + require.NotNil(t, result) + + _, recvErr := r.Recv() + assert.Error(t, recvErr, + "stream should be closed after concatMessageStream returns") +} diff --git a/adk/callback.go b/adk/callback.go index 19afbfc7e..0b5cac879 100644 --- a/adk/callback.go +++ b/adk/callback.go @@ -43,18 +43,18 @@ type AgentCallbackOutput struct { Events *AsyncIterator[*AgentEvent] } -func copyEventIterator(iter *AsyncIterator[*AgentEvent], n int) []*AsyncIterator[*AgentEvent] { +func copyTypedEventIterator[M messageType](iter *AsyncIterator[*TypedAgentEvent[M]], n int) []*AsyncIterator[*TypedAgentEvent[M]] { if n <= 0 { return nil } if n == 1 { - return []*AsyncIterator[*AgentEvent]{iter} + return []*AsyncIterator[*TypedAgentEvent[M]]{iter} } - iterators := make([]*AsyncIterator[*AgentEvent], n) - generators := make([]*AsyncGenerator[*AgentEvent], n) + iterators := make([]*AsyncIterator[*TypedAgentEvent[M]], n) + generators := make([]*AsyncGenerator[*TypedAgentEvent[M]], n) for i := 0; i < n; i++ { - iterators[i], generators[i] = NewAsyncIteratorPair[*AgentEvent]() + iterators[i], generators[i] = NewAsyncIteratorPair[*TypedAgentEvent[M]]() } go func() { @@ -70,7 +70,7 @@ func copyEventIterator(iter *AsyncIterator[*AgentEvent], n int) []*AsyncIterator break } for i := 0; i < n-1; i++ { - generators[i].Send(copyAgentEvent(event)) + generators[i].Send(copyTypedAgentEvent(event)) } generators[n-1].Send(event) } @@ -87,7 +87,7 @@ func copyAgentCallbackOutput(out *AgentCallbackOutput, n int) []*AgentCallbackOu } return result } - iters := copyEventIterator(out.Events, n) + iters := copyTypedEventIterator(out.Events, n) result := make([]*AgentCallbackOutput, n) for i, iter := range iters { result[i] = &AgentCallbackOutput{Events: iter} @@ -133,3 +133,70 @@ func getAgentType(agent Agent) string { } return "" } + +// TypedAgentCallbackInput represents the input passed to typed agent callbacks during OnStart. +// Use ConvTypedCallbackInput to safely convert from callbacks.CallbackInput. +type TypedAgentCallbackInput[M messageType] struct { + // Input contains the agent input for a new run. Nil when resuming. + Input *TypedAgentInput[M] + // ResumeInfo contains resume information when resuming from an interrupt. Nil for new runs. + ResumeInfo *ResumeInfo +} + +// TypedAgentCallbackOutput represents the output passed to typed agent callbacks during OnEnd. +// Use ConvTypedCallbackOutput to safely convert from callbacks.CallbackOutput. +// +// Important: The Events iterator should be consumed asynchronously to avoid blocking +// the agent execution. Each callback handler receives an independent copy of the iterator. +type TypedAgentCallbackOutput[M messageType] struct { + // Events provides a stream of agent events. Each handler receives its own copy. + Events *AsyncIterator[*TypedAgentEvent[M]] +} + +// ConvTypedCallbackInput converts a callbacks.CallbackInput to *TypedAgentCallbackInput[M]. +// Returns nil if the input is not of the expected type. +func ConvTypedCallbackInput[M messageType](input callbacks.CallbackInput) *TypedAgentCallbackInput[M] { + if v, ok := input.(*TypedAgentCallbackInput[M]); ok { + return v + } + return nil +} + +// ConvTypedCallbackOutput converts a callbacks.CallbackOutput to *TypedAgentCallbackOutput[M]. +// Returns nil if the output is not of the expected type. +func ConvTypedCallbackOutput[M messageType](output callbacks.CallbackOutput) *TypedAgentCallbackOutput[M] { + if v, ok := output.(*TypedAgentCallbackOutput[M]); ok { + return v + } + return nil +} + +func copyTypedCallbackOutput[M messageType](out *TypedAgentCallbackOutput[M], n int) []*TypedAgentCallbackOutput[M] { + if out == nil || out.Events == nil { + result := make([]*TypedAgentCallbackOutput[M], n) + for i := 0; i < n; i++ { + result[i] = out + } + return result + } + iters := copyTypedEventIterator(out.Events, n) + result := make([]*TypedAgentCallbackOutput[M], n) + for i, iter := range iters { + result[i] = &TypedAgentCallbackOutput[M]{Events: iter} + } + return result +} + +func initAgenticCallbacks(ctx context.Context, agentName, agentType string, opts ...AgentRunOption) context.Context { + ri := &callbacks.RunInfo{ + Name: agentName, + Type: agentType, + Component: ComponentOfAgenticAgent, + } + + o := getCommonOptions(nil, opts...) + if len(o.handlers) == 0 { + return icb.ReuseHandlers(ctx, ri) + } + return icb.AppendHandlers(ctx, ri, o.handlers...) +} diff --git a/adk/callback_test.go b/adk/callback_test.go index b54ea7ee5..efd66f562 100644 --- a/adk/callback_test.go +++ b/adk/callback_test.go @@ -22,12 +22,13 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/schema" ) -func TestCopyEventIterator(t *testing.T) { +func TestCopyTypedEventIterator(t *testing.T) { t.Run("n=0 returns nil", func(t *testing.T) { iter, gen := NewAsyncIteratorPair[*AgentEvent]() go func() { @@ -35,7 +36,7 @@ func TestCopyEventIterator(t *testing.T) { gen.Close() }() - result := copyEventIterator(iter, 0) + result := copyTypedEventIterator(iter, 0) assert.Nil(t, result) }) @@ -46,7 +47,7 @@ func TestCopyEventIterator(t *testing.T) { gen.Close() }() - result := copyEventIterator(iter, 1) + result := copyTypedEventIterator(iter, 1) assert.Len(t, result, 1) assert.Equal(t, iter, result[0]) }) @@ -66,7 +67,7 @@ func TestCopyEventIterator(t *testing.T) { }() n := 3 - copies := copyEventIterator(iter, n) + copies := copyTypedEventIterator(iter, n) assert.Len(t, copies, n) var wg sync.WaitGroup @@ -127,7 +128,7 @@ func TestCopyAgentCallbackOutput(t *testing.T) { assert.Len(t, result, 2) for i, r := range result { - assert.NotNil(t, r, "result[%d] should not be nil", i) + require.NotNil(t, r, "result[%d] should not be nil", i) assert.NotNil(t, r.Events, "result[%d].Events should not be nil", i) } }) @@ -234,3 +235,154 @@ func TestWithMultipleCallbacksOption(t *testing.T) { assert.Len(t, opts.handlers, 2) } + +func TestCopyTypedEventIteratorAgentic(t *testing.T) { + t.Run("n=0 returns nil", func(t *testing.T) { + iter, gen := NewAsyncIteratorPair[*TypedAgentEvent[*schema.AgenticMessage]]() + go func() { + gen.Send(&TypedAgentEvent[*schema.AgenticMessage]{AgentName: "test"}) + gen.Close() + }() + + result := copyTypedEventIterator(iter, 0) + assert.Nil(t, result) + }) + + t.Run("n=1 returns original iterator", func(t *testing.T) { + iter, gen := NewAsyncIteratorPair[*TypedAgentEvent[*schema.AgenticMessage]]() + go func() { + gen.Send(&TypedAgentEvent[*schema.AgenticMessage]{AgentName: "test"}) + gen.Close() + }() + + result := copyTypedEventIterator(iter, 1) + assert.Len(t, result, 1) + assert.Equal(t, iter, result[0]) + }) + + t.Run("n>1 creates n independent copies", func(t *testing.T) { + iter, gen := NewAsyncIteratorPair[*TypedAgentEvent[*schema.AgenticMessage]]() + events := []*TypedAgentEvent[*schema.AgenticMessage]{ + {AgentName: "agent1", Output: &TypedAgentOutput[*schema.AgenticMessage]{ + MessageOutput: &TypedMessageVariant[*schema.AgenticMessage]{Message: agenticMsg("msg1")}, + }}, + {AgentName: "agent2", Output: &TypedAgentOutput[*schema.AgenticMessage]{ + MessageOutput: &TypedMessageVariant[*schema.AgenticMessage]{Message: agenticMsg("msg2")}, + }}, + } + + go func() { + for _, e := range events { + gen.Send(e) + } + gen.Close() + }() + + n := 3 + copies := copyTypedEventIterator(iter, n) + assert.Len(t, copies, n) + + var wg sync.WaitGroup + receivedEvents := make([][]*TypedAgentEvent[*schema.AgenticMessage], n) + + for i := 0; i < n; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + for { + event, ok := copies[idx].Next() + if !ok { + break + } + receivedEvents[idx] = append(receivedEvents[idx], event) + } + }(i) + } + + wg.Wait() + + for i := 0; i < n; i++ { + assert.Len(t, receivedEvents[i], len(events), "iterator %d should receive all events", i) + for j, e := range receivedEvents[i] { + assert.Equal(t, events[j].AgentName, e.AgentName) + } + } + }) +} + +func TestCopyTypedCallbackOutput(t *testing.T) { + t.Run("nil output", func(t *testing.T) { + result := copyTypedCallbackOutput[*schema.AgenticMessage](nil, 3) + assert.Len(t, result, 3) + for _, r := range result { + assert.Nil(t, r) + } + }) + + t.Run("output with nil Events", func(t *testing.T) { + out := &TypedAgentCallbackOutput[*schema.AgenticMessage]{Events: nil} + result := copyTypedCallbackOutput(out, 3) + assert.Len(t, result, 3) + for _, r := range result { + assert.Equal(t, out, r) + } + }) + + t.Run("valid output with events", func(t *testing.T) { + iter, gen := NewAsyncIteratorPair[*TypedAgentEvent[*schema.AgenticMessage]]() + go func() { + gen.Send(&TypedAgentEvent[*schema.AgenticMessage]{AgentName: "test"}) + gen.Close() + }() + + out := &TypedAgentCallbackOutput[*schema.AgenticMessage]{Events: iter} + result := copyTypedCallbackOutput(out, 2) + assert.Len(t, result, 2) + + for i, r := range result { + require.NotNil(t, r, "result[%d] should not be nil", i) + assert.NotNil(t, r.Events, "result[%d].Events should not be nil", i) + } + }) +} + +func TestConvTypedCallbackInput(t *testing.T) { + t.Run("valid TypedAgentCallbackInput", func(t *testing.T) { + input := &TypedAgentCallbackInput[*schema.AgenticMessage]{ + Input: &TypedAgentInput[*schema.AgenticMessage]{ + Messages: []*schema.AgenticMessage{schema.UserAgenticMessage("test")}, + }, + } + result := ConvTypedCallbackInput[*schema.AgenticMessage](input) + assert.Equal(t, input, result) + }) + + t.Run("invalid type returns nil", func(t *testing.T) { + result := ConvTypedCallbackInput[*schema.AgenticMessage]("invalid") + assert.Nil(t, result) + }) + + t.Run("nil returns nil", func(t *testing.T) { + result := ConvTypedCallbackInput[*schema.AgenticMessage](nil) + assert.Nil(t, result) + }) +} + +func TestConvTypedCallbackOutput(t *testing.T) { + t.Run("valid TypedAgentCallbackOutput", func(t *testing.T) { + iter, _ := NewAsyncIteratorPair[*TypedAgentEvent[*schema.AgenticMessage]]() + output := &TypedAgentCallbackOutput[*schema.AgenticMessage]{Events: iter} + result := ConvTypedCallbackOutput[*schema.AgenticMessage](output) + assert.Equal(t, output, result) + }) + + t.Run("invalid type returns nil", func(t *testing.T) { + result := ConvTypedCallbackOutput[*schema.AgenticMessage]("invalid") + assert.Nil(t, result) + }) + + t.Run("nil returns nil", func(t *testing.T) { + result := ConvTypedCallbackOutput[*schema.AgenticMessage](nil) + assert.Nil(t, result) + }) +} diff --git a/adk/cancel.go b/adk/cancel.go index 513b0cf43..72f3e109f 100644 --- a/adk/cancel.go +++ b/adk/cancel.go @@ -812,11 +812,11 @@ func (cc *cancelContext) buildCancelFunc() AgentCancelFunc { // were passed through unconverted, markDone would transition stateCancelling→stateDone // before the Runner goroutine could call createAndMarkCancelHandled, causing it // to fail the CAS. -func wrapIterWithCancelCtx(iter *AsyncIterator[*AgentEvent], cancelCtx *cancelContext) *AsyncIterator[*AgentEvent] { +func wrapIterWithCancelCtx[M messageType](iter *AsyncIterator[*TypedAgentEvent[M]], cancelCtx *cancelContext) *AsyncIterator[*TypedAgentEvent[M]] { if cancelCtx == nil { return iter } - it, gen := NewAsyncIteratorPair[*AgentEvent]() + it, gen := NewAsyncIteratorPair[*TypedAgentEvent[M]]() go func() { defer cancelCtx.markDone() defer gen.Close() @@ -831,7 +831,7 @@ func wrapIterWithCancelCtx(iter *AsyncIterator[*AgentEvent], cancelCtx *cancelCo cancelErr, ok := cancelCtx.createAndMarkCancelHandled() if ok { cancelErr.interruptSignal = event.Action.internalInterrupted - gen.Send(&AgentEvent{Err: cancelErr}) + gen.Send(&TypedAgentEvent[M]{Err: cancelErr}) } return } @@ -843,13 +843,13 @@ func wrapIterWithCancelCtx(iter *AsyncIterator[*AgentEvent], cancelCtx *cancelCo return it } -// cancelMonitoredModel wraps a model with cancel monitoring. +// typedCancelMonitoredModel wraps a model with cancel monitoring. // Generate: pure delegate to the inner model (CancelAfterChatModel is handled // by a dedicated node after the ChatModel in the compose graph). // Stream: pipes chunks through a goroutine that selects on immediateChan for // CancelImmediate abort. -type cancelMonitoredModel struct { - inner model.BaseChatModel +type typedCancelMonitoredModel[M messageType] struct { + inner model.BaseModel[M] cancelContext *cancelContext } @@ -858,11 +858,11 @@ type recvResult[T any] struct { err error } -func (m *cancelMonitoredModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { +func (m *typedCancelMonitoredModel[M]) Generate(ctx context.Context, input []M, opts ...model.Option) (M, error) { return m.inner.Generate(ctx, input, opts...) } -func (m *cancelMonitoredModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { +func (m *typedCancelMonitoredModel[M]) Stream(ctx context.Context, input []M, opts ...model.Option) (*schema.StreamReader[M], error) { stream, err := m.inner.Stream(ctx, input, opts...) if err != nil { return nil, err diff --git a/adk/cancel_edge_test.go b/adk/cancel_edge_test.go index 248a84ee0..946cd6008 100644 --- a/adk/cancel_edge_test.go +++ b/adk/cancel_edge_test.go @@ -1336,7 +1336,7 @@ func TestWithCancel_CancelImmediate_StreamableToolAborted(t *testing.T) { tcm := &toolCallStreamModel{} st := &slowStreamingTool{ name: "slow_tool", - chunkInterval: 200 * time.Millisecond, + chunkInterval: 100 * time.Millisecond, chunks: []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"}, started: make(chan struct{}, 1), } @@ -1366,7 +1366,7 @@ func TestWithCancel_CancelImmediate_StreamableToolAborted(t *testing.T) { t.Fatal("tool did not start streaming") } // Let a few chunks through, then cancel mid-stream - time.Sleep(300 * time.Millisecond) + time.Sleep(500 * time.Millisecond) handle, _ := cancelFn() cancelErr := handle.Wait() diff --git a/adk/cancel_test.go b/adk/cancel_test.go index 97779827b..e08a0f585 100644 --- a/adk/cancel_test.go +++ b/adk/cancel_test.go @@ -2368,7 +2368,7 @@ func TestCancelImmediate_OrphanedToolGoroutine_NoPanic(t *testing.T) { cancelCtx: cc, } - ctx := withChatModelAgentExecCtx(context.Background(), execCtx) + ctx := withTypedChatModelAgentExecCtx[*schema.Message](context.Background(), execCtx) assert.NotPanics(t, func() { err := SendEvent(ctx, &AgentEvent{AgentName: "test"}) diff --git a/adk/chatmodel.go b/adk/chatmodel.go index fe32cdfac..096435dfd 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -24,6 +24,7 @@ import ( "fmt" "math" "runtime/debug" + "strings" "sync" "sync/atomic" @@ -38,14 +39,15 @@ import ( "github.com/cloudwego/eino/schema" ) -var _ ResumableAgent = &ChatModelAgent{} +var _ ResumableAgent = &TypedChatModelAgent[*schema.Message]{} +var _ TypedResumableAgent[*schema.AgenticMessage] = &TypedChatModelAgent[*schema.AgenticMessage]{} -type chatModelAgentExecCtx struct { +type typedChatModelAgentExecCtx[M messageType] struct { runtimeReturnDirectly map[string]bool - generator *AsyncGenerator[*AgentEvent] + generator *AsyncGenerator[*TypedAgentEvent[M]] cancelCtx *cancelContext - failoverLastSuccessModel model.BaseChatModel + failoverLastSuccessModel model.BaseModel[M] // suppressEventSend prevents eventSenderModel from emitting AgentEvents for the current // Generate call. Set to true before each rejected retry attempt and reset to false after. @@ -54,7 +56,7 @@ type chatModelAgentExecCtx struct { retryVerdictSignal *retryVerdictSignal } -func (e *chatModelAgentExecCtx) send(event *AgentEvent) { +func (e *typedChatModelAgentExecCtx[M]) send(event *TypedAgentEvent[M]) { if e == nil || e.generator == nil { return } @@ -64,15 +66,17 @@ func (e *chatModelAgentExecCtx) send(event *AgentEvent) { e.generator.trySend(event) } -type chatModelAgentExecCtxKey struct{} +type chatModelAgentExecCtx = typedChatModelAgentExecCtx[*schema.Message] -func withChatModelAgentExecCtx(ctx context.Context, execCtx *chatModelAgentExecCtx) context.Context { - return context.WithValue(ctx, chatModelAgentExecCtxKey{}, execCtx) +type typedChatModelAgentExecCtxKey[M messageType] struct{} + +func withTypedChatModelAgentExecCtx[M messageType](ctx context.Context, execCtx *typedChatModelAgentExecCtx[M]) context.Context { + return context.WithValue(ctx, typedChatModelAgentExecCtxKey[M]{}, execCtx) } -func getChatModelAgentExecCtx(ctx context.Context) *chatModelAgentExecCtx { - if v := ctx.Value(chatModelAgentExecCtxKey{}); v != nil { - return v.(*chatModelAgentExecCtx) +func getTypedChatModelAgentExecCtx[M messageType](ctx context.Context) *typedChatModelAgentExecCtx[M] { + if v := ctx.Value(typedChatModelAgentExecCtxKey[M]{}); v != nil { + return v.(*typedChatModelAgentExecCtx[M]) } return nil } @@ -137,8 +141,14 @@ type ToolsConfig struct { EmitInternalEvents bool } +// TypedGenModelInput transforms the agent's system instruction and user input into model input +// messages ([]M). This is the primary customization point for controlling what the model sees. +// The default implementation prepends a system message (if instruction is non-empty), +// followed by the user's input messages. +type TypedGenModelInput[M messageType] func(ctx context.Context, instruction string, input *TypedAgentInput[M]) ([]M, error) + // GenModelInput transforms agent instructions and input into a format suitable for the model. -type GenModelInput func(ctx context.Context, instruction string, input *AgentInput) ([]Message, error) +type GenModelInput = TypedGenModelInput[*schema.Message] func defaultGenModelInput(ctx context.Context, instruction string, input *AgentInput) ([]Message, error) { msgs := make([]Message, 0, len(input.Messages)+1) @@ -168,13 +178,35 @@ func defaultGenModelInput(ctx context.Context, instruction string, input *AgentI return msgs, nil } -// ChatModelAgentState represents the state of a chat model agent during conversation. -// This is the primary state type for both ChatModelAgentMiddleware and AgentMiddleware callbacks. -type ChatModelAgentState struct { +func newDefaultGenModelInput[M messageType]() TypedGenModelInput[M] { + var zero M + switch any(zero).(type) { + case *schema.Message: + return any(GenModelInput(defaultGenModelInput)).(TypedGenModelInput[M]) + case *schema.AgenticMessage: + return any(TypedGenModelInput[*schema.AgenticMessage](func(_ context.Context, instruction string, input *TypedAgentInput[*schema.AgenticMessage]) ([]*schema.AgenticMessage, error) { + msgs := make([]*schema.AgenticMessage, 0, len(input.Messages)+1) + if instruction != "" { + msgs = append(msgs, schema.SystemAgenticMessage(instruction)) + } + msgs = append(msgs, input.Messages...) + return msgs, nil + })).(TypedGenModelInput[M]) + default: + panic("unreachable: unknown messageType") + } +} + +// TypedChatModelAgentState represents the state of a chat model agent during conversation. +// This is the primary state type for both TypedChatModelAgentMiddleware and AgentMiddleware callbacks. +type TypedChatModelAgentState[M messageType] struct { // Messages contains all messages in the current conversation session. - Messages []Message + Messages []M } +// ChatModelAgentState is the default state type using *schema.Message. +type ChatModelAgentState = TypedChatModelAgentState[*schema.Message] + // AgentMiddleware provides hooks to customize agent behavior at various stages of execution. // // Limitations of AgentMiddleware (struct-based): @@ -207,7 +239,8 @@ type AgentMiddleware struct { WrapToolCall compose.ToolMiddleware } -type ChatModelAgentConfig struct { +// TypedChatModelAgentConfig is the generic configuration for ChatModelAgent. +type TypedChatModelAgentConfig[M messageType] struct { // Name of the agent. Better be unique across all agents. // Optional. If empty, the agent can still run standalone but cannot be used as // a sub-agent tool via NewAgentTool (which requires a non-empty Name). @@ -227,13 +260,13 @@ type ChatModelAgentConfig struct { // Model is the chat model used by the agent. // If your ChatModelAgent uses any tools, this model must support the model.WithTools // call option, as that's how ChatModelAgent configures the model with tool information. - Model model.BaseChatModel + Model model.BaseModel[M] ToolsConfig ToolsConfig // GenModelInput transforms instructions and input messages into the model's input format. // Optional. Defaults to defaultGenModelInput which combines instruction and messages. - GenModelInput GenModelInput + GenModelInput TypedGenModelInput[M] // Exit defines the tool used to terminate the agent process. // Optional. If nil, no Exit Action will be generated. @@ -347,7 +380,7 @@ type ChatModelAgentConfig struct { // passed to ChatModel, NOT the actual tools available for execution. Use this for // dynamic tool filtering/selection based on conversation context. The modification // is scoped to this model request only. - Handlers []ChatModelAgentMiddleware + Handlers []TypedChatModelAgentMiddleware[M] // ModelRetryConfig configures retry behavior for the ChatModel. // When set, the agent will automatically retry failed ChatModel calls @@ -363,42 +396,52 @@ type ChatModelAgentConfig struct { ModelFailoverConfig *ModelFailoverConfig } -type ChatModelAgent struct { +type ChatModelAgentConfig = TypedChatModelAgentConfig[*schema.Message] + +// TypedChatModelAgent is a chat model-backed agent parameterized by message type. +// +// For M = *schema.Message, the full ReAct loop (model → tool calls → model) is used. +// For M = *schema.AgenticMessage, a single-shot chain is used since agentic models +// handle tool calling internally. Cancel monitoring and retry on the model stream +// are not yet supported for agentic models. +type TypedChatModelAgent[M messageType] struct { name string description string instruction string - model model.BaseChatModel + model model.BaseModel[M] toolsConfig ToolsConfig - genModelInput GenModelInput + genModelInput TypedGenModelInput[M] outputKey string maxIterations int - subAgents []Agent - parentAgent Agent + subAgents []TypedAgent[M] + parentAgent TypedAgent[M] disallowTransferToParent bool exit tool.BaseTool - handlers []ChatModelAgentMiddleware + handlers []TypedChatModelAgentMiddleware[M] middlewares []AgentMiddleware modelRetryConfig *ModelRetryConfig modelFailoverConfig *ModelFailoverConfig once sync.Once - run runFunc + run typedRunFunc[M] frozen uint32 exeCtx *execContext } -// runParams holds the parameters for a runFunc invocation. -type runParams struct { - input *AgentInput - generator *AsyncGenerator[*AgentEvent] +type ChatModelAgent = TypedChatModelAgent[*schema.Message] + +// typedRunParams holds the parameters for a typedRunFunc invocation. +type typedRunParams[M messageType] struct { + input *TypedAgentInput[M] + generator *AsyncGenerator[*TypedAgentEvent[M]] store *bridgeStore instruction string returnDirectly map[string]bool @@ -407,10 +450,15 @@ type runParams struct { composeOpts []compose.Option } -type runFunc func(ctx context.Context, p *runParams) +type typedRunFunc[M messageType] func(ctx context.Context, p *typedRunParams[M]) -// NewChatModelAgent constructs a chat model-backed agent with the provided config. +// NewChatModelAgent creates a new ChatModelAgent with the given config. func NewChatModelAgent(ctx context.Context, config *ChatModelAgentConfig) (*ChatModelAgent, error) { + return NewTypedChatModelAgent[*schema.Message](ctx, config) +} + +// NewTypedChatModelAgent creates a new TypedChatModelAgent with the given config. +func NewTypedChatModelAgent[M messageType](ctx context.Context, config *TypedChatModelAgentConfig[M]) (*TypedChatModelAgent[M], error) { if config.ModelFailoverConfig != nil { if config.ModelFailoverConfig.GetFailoverModel == nil { return nil, errors.New("ModelFailoverConfig.GetFailoverModel is required when ModelFailoverConfig is set") @@ -426,9 +474,11 @@ func NewChatModelAgent(ctx context.Context, config *ChatModelAgentConfig) (*Chat return nil, errors.New("agent 'Model' is required") } - genInput := defaultGenModelInput + var genInput TypedGenModelInput[M] if config.GenModelInput != nil { genInput = config.GenModelInput + } else { + genInput = newDefaultGenModelInput[M]() } tc := config.ToolsConfig @@ -455,7 +505,7 @@ func NewChatModelAgent(ctx context.Context, config *ChatModelAgentConfig) (*Chat EnhancedStreamable: cancelToolHandler.WrapEnhancedStreamableToolCall, }) - return &ChatModelAgent{ + return &TypedChatModelAgent[M]{ name: config.Name, description: config.Description, instruction: config.Instruction, @@ -580,15 +630,15 @@ func (tta transferToAgent) InvokableRun(ctx context.Context, argumentsInJSON str return transferToAgentToolOutput(params.AgentName), nil } -func (a *ChatModelAgent) Name(_ context.Context) string { +func (a *TypedChatModelAgent[M]) Name(_ context.Context) string { return a.name } -func (a *ChatModelAgent) Description(_ context.Context) string { +func (a *TypedChatModelAgent[M]) Description(_ context.Context) string { return a.description } -func (a *ChatModelAgent) GetType() string { +func (a *TypedChatModelAgent[M]) GetType() string { return "ChatModel" } @@ -597,7 +647,7 @@ func (a *ChatModelAgent) GetType() string { // NOT RECOMMENDED: Agent transfer with full context sharing between agents has not proven // to be more effective empirically. Consider using ChatModelAgent with AgentTool // or DeepAgent instead for most multi-agent scenarios. -func (a *ChatModelAgent) OnSetSubAgents(_ context.Context, subAgents []Agent) error { +func (a *TypedChatModelAgent[M]) OnSetSubAgents(_ context.Context, subAgents []TypedAgent[M]) error { if atomic.LoadUint32(&a.frozen) == 1 { return errors.New("agent has been frozen after run") } @@ -615,7 +665,7 @@ func (a *ChatModelAgent) OnSetSubAgents(_ context.Context, subAgents []Agent) er // NOT RECOMMENDED: Agent transfer with full context sharing between agents has not proven // to be more effective empirically. Consider using ChatModelAgent with AgentTool // or DeepAgent instead for most multi-agent scenarios. -func (a *ChatModelAgent) OnSetAsSubAgent(_ context.Context, parent Agent) error { +func (a *TypedChatModelAgent[M]) OnSetAsSubAgent(_ context.Context, parent TypedAgent[M]) error { if atomic.LoadUint32(&a.frozen) == 1 { return errors.New("agent has been frozen after run") } @@ -633,7 +683,7 @@ func (a *ChatModelAgent) OnSetAsSubAgent(_ context.Context, parent Agent) error // NOT RECOMMENDED: Agent transfer with full context sharing between agents has not proven // to be more effective empirically. Consider using ChatModelAgent with AgentTool // or DeepAgent instead for most multi-agent scenarios. -func (a *ChatModelAgent) OnDisallowTransferToParent(_ context.Context) error { +func (a *TypedChatModelAgent[M]) OnDisallowTransferToParent(_ context.Context) error { if atomic.LoadUint32(&a.frozen) == 1 { return errors.New("agent has been frozen after run") } @@ -652,24 +702,41 @@ func init() { schema.RegisterName[*ChatModelAgentInterruptInfo]("_eino_adk_chat_model_agent_interrupt_info") } -func setOutputToSession(ctx context.Context, msg Message, msgStream MessageStream, outputKey string) error { - if msg != nil { - AddSessionValue(ctx, outputKey, msg.Content) +func extractTextContent[M messageType](msg M) string { + switch v := any(msg).(type) { + case *schema.Message: + return v.Content + case *schema.AgenticMessage: + var texts []string + for _, block := range v.ContentBlocks { + if block != nil && block.Type == schema.ContentBlockTypeAssistantGenText && block.AssistantGenText != nil { + texts = append(texts, block.AssistantGenText.Text) + } + } + return strings.Join(texts, "\n") + default: + return "" + } +} + +func setOutputToSession[M messageType](ctx context.Context, msg M, msgStream *schema.StreamReader[M], outputKey string) error { + if !isNilMessage(msg) { + AddSessionValue(ctx, outputKey, extractTextContent(msg)) return nil } - concatenated, err := schema.ConcatMessageStream(msgStream) + concatenated, err := concatMessageStream(msgStream) if err != nil { return err } - AddSessionValue(ctx, outputKey, concatenated.Content) + AddSessionValue(ctx, outputKey, extractTextContent(concatenated)) return nil } -func errFunc(err error) runFunc { - return func(ctx context.Context, p *runParams) { - p.generator.Send(&AgentEvent{Err: err}) +func typedErrFunc[M messageType](err error) typedRunFunc[M] { + return func(ctx context.Context, p *typedRunParams[M]) { + p.generator.Send(&TypedAgentEvent[M]{Err: err}) } } @@ -693,7 +760,7 @@ type execContext struct { toolUpdated bool // whether needs to pass a compose.WithToolList option to ToolsNode due to tool list change } -func (a *ChatModelAgent) applyBeforeAgent(ctx context.Context, ec *execContext) (context.Context, *execContext, error) { +func (a *TypedChatModelAgent[M]) applyBeforeAgent(ctx context.Context, ec *execContext) (context.Context, *execContext, error) { runCtx := &ChatModelAgentContext{ Instruction: ec.instruction, Tools: cloneSlice(ec.unwrappedTools), @@ -734,7 +801,7 @@ func (a *ChatModelAgent) applyBeforeAgent(ctx context.Context, ec *execContext) return ctx, runtimeEC, nil } -func (a *ChatModelAgent) prepareExecContext(ctx context.Context) (*execContext, error) { +func (a *TypedChatModelAgent[M]) prepareExecContext(ctx context.Context) (*execContext, error) { instruction := a.instruction toolsNodeConf := compose.ToolsNodeConfig{ Tools: cloneSlice(a.toolsConfig.Tools), @@ -797,39 +864,39 @@ func (a *ChatModelAgent) prepareExecContext(ctx context.Context) (*execContext, // handleRunFuncError is the common error handler for buildNoToolsRunFunc and buildReActRunFunc. // It handles compose interrupts (both cancel-triggered and business) // and generic errors, sending the appropriate event to the generator. -func (a *ChatModelAgent) handleRunFuncError( +func (a *TypedChatModelAgent[M]) handleRunFuncError( ctx context.Context, err error, cancelCtx *cancelContext, cancelCtxOwned bool, store *bridgeStore, - generator *AsyncGenerator[*AgentEvent], + generator *AsyncGenerator[*TypedAgentEvent[M]], ) { info, ok := compose.ExtractInterruptInfo(err) if ok { if cancelCtx != nil { - // Note: there is a benign TOCTOU window here. Between shouldCancel() - // returning false and markDone() executing, a concurrent cancel could - // transition stateRunning→stateCancelling. markDone() then does - // stateCancelling→stateDone, and the cancel func receives - // ErrExecutionEnded (execution finished before cancel took effect). if !cancelCtx.shouldCancel() { + // Note: there is a benign TOCTOU window here. Between shouldCancel() + // returning false and markDone() executing, a concurrent cancel could + // transition stateRunning→stateCancelling. markDone() then does + // stateCancelling→stateDone, and the cancel func receives + // ErrExecutionEnded (execution finished before cancel took effect). cancelCtx.markDone() } } data, existed, sErr := store.Get(ctx, bridgeCheckpointID) if sErr != nil { - generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("failed to get interrupt info: %w", sErr)}) + generator.Send(&TypedAgentEvent[M]{AgentName: a.name, Err: fmt.Errorf("failed to get interrupt info: %w", sErr)}) return } if !existed { - generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("interrupt occurred but checkpoint data is missing")}) + generator.Send(&TypedAgentEvent[M]{AgentName: a.name, Err: fmt.Errorf("interrupt occurred but checkpoint data is missing")}) return } is := FromInterruptContexts(info.InterruptContexts) - event := CompositeInterrupt(ctx, info, data, is) + event := TypedCompositeInterrupt[M](ctx, info, data, is) event.Action.Interrupted.Data = &ChatModelAgentInterruptInfo{ Info: info, Data: data, @@ -842,20 +909,30 @@ func (a *ChatModelAgent) handleRunFuncError( if cancelCtxOwned && cancelCtx != nil { cancelCtx.markDone() } - generator.Send(&AgentEvent{Err: err}) + generator.Send(&TypedAgentEvent[M]{Err: err}) } -func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { - type noToolsInput struct { - input *AgentInput - instruction string +type typedNoToolsInput[M messageType] struct { + input *TypedAgentInput[M] + instruction string +} + +func appendModelToChain[I, O any, M messageType](chain *compose.Chain[I, O], m model.BaseModel[M]) { + var zero M + switch any(zero).(type) { + case *schema.Message: + chain.AppendChatModel(any(m).(model.BaseChatModel)) + case *schema.AgenticMessage: + chain.AppendAgenticModel(any(m).(model.AgenticModel)) } +} - return func(ctx context.Context, p *runParams) { +func (a *TypedChatModelAgent[M]) buildNoToolsRunFunc(_ context.Context) (typedRunFunc[M], error) { + return func(ctx context.Context, p *typedRunParams[M]) { cancelCtx := p.cancelCtx ctx = withCancelContext(ctx, cancelCtx) - wrappedModel := buildModelWrappers(a.model, &modelWrapperConfig{ + wrappedModel := buildModelWrappers(a.model, &typedModelWrapperConfig[M]{ handlers: a.handlers, middlewares: a.middlewares, retryConfig: a.modelRetryConfig, @@ -863,22 +940,26 @@ func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { cancelContext: cancelCtx, }) - chain := compose.NewChain[noToolsInput, Message]( - compose.WithGenLocalState(func(ctx context.Context) (state *State) { - return &State{} - })). - AppendLambda(compose.InvokableLambda(func(ctx context.Context, in noToolsInput) ([]Message, error) { - messages, err := a.genModelInput(ctx, in.instruction, in.input) - if err != nil { - return nil, err - } - _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { - st.Messages = append(st.Messages, messages...) - return nil - }) - return messages, nil - })). - AppendChatModel(wrappedModel) + chain := compose.NewChain[typedNoToolsInput[M], M]( + compose.WithGenLocalState(func(ctx context.Context) (state *typedState[M]) { + return &typedState[M]{} + })) + + chain.AppendLambda(compose.InvokableLambda(func(ctx context.Context, in typedNoToolsInput[M]) ([]M, error) { + messages, err := a.genModelInput(ctx, in.instruction, in.input) + if err != nil { + return nil, err + } + if err := compose.ProcessState(ctx, func(_ context.Context, st *typedState[M]) error { + st.Messages = append(st.Messages, messages...) + return nil + }); err != nil { + return nil, err + } + return messages, nil + })) + + appendModelToChain(chain, wrappedModel) var compileOptions []compose.GraphCompileOption compileOptions = append(compileOptions, @@ -894,11 +975,11 @@ func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { r, err := chain.Compile(ctx, compileOptions...) if err != nil { - p.generator.Send(&AgentEvent{Err: err}) + p.generator.Send(&TypedAgentEvent[M]{Err: err}) return } - ctx = withChatModelAgentExecCtx(ctx, &chatModelAgentExecCtx{ + ctx = withTypedChatModelAgentExecCtx(ctx, &typedChatModelAgentExecCtx[M]{ generator: p.generator, cancelCtx: cancelCtx, failoverLastSuccessModel: a.model, @@ -911,15 +992,15 @@ func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { if !ok { return } - p.generator.Send(&AgentEvent{Err: cancelErr}) + p.generator.Send(&TypedAgentEvent[M]{Err: cancelErr}) return } } - in := noToolsInput{input: p.input, instruction: p.instruction} + in := typedNoToolsInput[M]{input: p.input, instruction: p.instruction} - var msg Message - var msgStream MessageStream + var msg M + var msgStream *schema.StreamReader[M] if p.input.EnableStreaming { msgStream, err = r.Stream(ctx, in, p.composeOpts...) } else { @@ -930,7 +1011,7 @@ func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { if a.outputKey != "" { err = setOutputToSession(ctx, msg, msgStream, a.outputKey) if err != nil { - p.generator.Send(&AgentEvent{Err: err}) + p.generator.Send(&TypedAgentEvent[M]{Err: err}) } } else if msgStream != nil { msgStream.Close() @@ -939,15 +1020,37 @@ func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { } a.handleRunFuncError(ctx, err, cancelCtx, p.cancelCtxOwned, p.store, p.generator) + }, nil +} + +func (a *TypedChatModelAgent[M]) buildReActRunFunc(ctx context.Context, bc *execContext) (typedRunFunc[M], error) { + var zero M + switch any(zero).(type) { + case *schema.Message: + return a.buildMessageReActRunFunc(ctx, bc) + case *schema.AgenticMessage: + // single-shot: agentic models handle tool calling internally + return a.buildAgenticReActRunFunc(ctx, bc) + default: + return nil, fmt.Errorf("unsupported message type %T for ReAct run mode", zero) } } -func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) (runFunc, error) { - conf := &reactConfig{ - model: a.model, +type reactRunInput struct { + input *AgentInput + instruction string +} + +func (a *TypedChatModelAgent[M]) buildMessageReActRunFunc(ctx context.Context, bc *execContext) (typedRunFunc[M], error) { + // safe: only called when M = *schema.Message (guarded by type switch in buildReActRunFunc) + msgModel := any(a.model).(model.BaseChatModel) + msgHandlers := any(a.handlers).([]ChatModelAgentMiddleware) + genModelInputFn := any(a.genModelInput).(GenModelInput) + msgConf := &reactConfig{ + model: msgModel, toolsConfig: &bc.toolsNodeConf, modelWrapperConf: &modelWrapperConfig{ - handlers: a.handlers, + handlers: msgHandlers, middlewares: a.middlewares, retryConfig: a.modelRetryConfig, failoverConfig: a.modelFailoverConfig, @@ -958,29 +1061,25 @@ func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) ( maxIterations: a.maxIterations, } - type reactRunInput struct { - input *AgentInput - instruction string - } - - return func(ctx context.Context, p *runParams) { - cancelCtx := p.cancelCtx - conf.cancelCtx = cancelCtx - if conf.modelWrapperConf != nil { - conf.modelWrapperConf.cancelContext = cancelCtx + return func(ctx context.Context, p *typedRunParams[M]) { + mp := any(p).(*typedRunParams[*schema.Message]) + cancelCtx := mp.cancelCtx + msgConf.cancelCtx = cancelCtx + if msgConf.modelWrapperConf != nil { + msgConf.modelWrapperConf.cancelContext = cancelCtx } ctx = withCancelContext(ctx, cancelCtx) - g, err := newReact(ctx, conf) + g, err := newReact(ctx, msgConf) if err != nil { - p.generator.Send(&AgentEvent{Err: err}) + mp.generator.Send(&AgentEvent{Err: err}) return } chain := compose.NewChain[reactRunInput, Message](). AppendLambda( compose.InvokableLambda(func(ctx context.Context, in reactRunInput) (*reactInput, error) { - messages, genErr := a.genModelInput(ctx, in.instruction, in.input) + messages, genErr := genModelInputFn(ctx, in.instruction, in.input) if genErr != nil { return nil, genErr } @@ -994,7 +1093,7 @@ func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) ( var compileOptions []compose.GraphCompileOption compileOptions = append(compileOptions, compose.WithGraphName(a.name), - compose.WithCheckPointStore(p.store), + compose.WithCheckPointStore(mp.store), compose.WithSerializer(&gobSerializer{}), compose.WithMaxRunSteps(math.MaxInt)) @@ -1006,15 +1105,15 @@ func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) ( runnable, err_ := chain.Compile(ctx, compileOptions...) if err_ != nil { - p.generator.Send(&AgentEvent{Err: err_}) + mp.generator.Send(&AgentEvent{Err: err_}) return } - ctx = withChatModelAgentExecCtx(ctx, &chatModelAgentExecCtx{ - runtimeReturnDirectly: p.returnDirectly, - generator: p.generator, + ctx = withTypedChatModelAgentExecCtx[*schema.Message](ctx, &chatModelAgentExecCtx{ + runtimeReturnDirectly: mp.returnDirectly, + generator: mp.generator, cancelCtx: cancelCtx, - failoverLastSuccessModel: a.model, + failoverLastSuccessModel: msgModel, }) // Pre-execution cancel check @@ -1024,28 +1123,149 @@ func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) ( if !ok { return } - p.generator.Send(&AgentEvent{Err: cancelErr}) + mp.generator.Send(&AgentEvent{Err: cancelErr}) return } } in := reactRunInput{ - input: p.input, - instruction: p.instruction, + input: mp.input, + instruction: mp.instruction, } var runOpts []compose.Option - runOpts = append(runOpts, p.composeOpts...) + runOpts = append(runOpts, mp.composeOpts...) if a.toolsConfig.EmitInternalEvents { - runOpts = append(runOpts, compose.WithToolsNodeOption(compose.WithToolOption(withAgentToolEventGenerator(p.generator)))) + runOpts = append(runOpts, compose.WithToolsNodeOption(compose.WithToolOption(withAgentToolEventGenerator(mp.generator)))) } - if p.input.EnableStreaming { + if mp.input.EnableStreaming { runOpts = append(runOpts, compose.WithToolsNodeOption(compose.WithToolOption(withAgentToolEnableStreaming(true)))) } var msg Message var msgStream MessageStream - if p.input.EnableStreaming { + if mp.input.EnableStreaming { + msgStream, err_ = runnable.Stream(ctx, in, runOpts...) + } else { + msg, err_ = runnable.Invoke(ctx, in, runOpts...) + } + + if err_ == nil { + if a.outputKey != "" { + err_ = setOutputToSession[*schema.Message](ctx, msg, msgStream, a.outputKey) + if err_ != nil { + mp.generator.Send(&AgentEvent{Err: err_}) + } + } else if msgStream != nil { + msgStream.Close() + } + + return + } + + a.handleRunFuncError(ctx, err_, cancelCtx, mp.cancelCtxOwned, mp.store, p.generator) + }, nil +} + +type agenticReactRunInput struct { + input *TypedAgentInput[*schema.AgenticMessage] + instruction string +} + +func (a *TypedChatModelAgent[M]) buildAgenticReActRunFunc(ctx context.Context, bc *execContext) (typedRunFunc[M], error) { + agenticModel := any(a.model).(model.AgenticModel) + agenticHandlers := any(a.handlers).([]TypedChatModelAgentMiddleware[*schema.AgenticMessage]) + genModelInputFn := any(a.genModelInput).(TypedGenModelInput[*schema.AgenticMessage]) + agenticConf := &agenticReactConfig{ + model: agenticModel, + toolsConfig: &bc.toolsNodeConf, + modelWrapperConf: &typedModelWrapperConfig[*schema.AgenticMessage]{ + handlers: agenticHandlers, + middlewares: a.middlewares, + retryConfig: a.modelRetryConfig, + toolInfos: bc.toolInfos, + }, + toolsReturnDirectly: bc.returnDirectly, + agentName: a.name, + maxIterations: a.maxIterations, + } + + return func(ctx context.Context, p *typedRunParams[M]) { + ap := any(p).(*typedRunParams[*schema.AgenticMessage]) + cancelCtx := ap.cancelCtx + agenticConf.cancelCtx = cancelCtx + if agenticConf.modelWrapperConf != nil { + agenticConf.modelWrapperConf.cancelContext = cancelCtx + } + ctx = withCancelContext(ctx, cancelCtx) + + g, err := newAgenticReact(ctx, agenticConf) + if err != nil { + ap.generator.Send(&TypedAgentEvent[*schema.AgenticMessage]{Err: err}) + return + } + + chain := compose.NewChain[agenticReactRunInput, *schema.AgenticMessage](). + AppendLambda( + compose.InvokableLambda(func(ctx context.Context, in agenticReactRunInput) (*agenticReactInput, error) { + messages, genErr := genModelInputFn(ctx, in.instruction, in.input) + if genErr != nil { + return nil, genErr + } + return &agenticReactInput{ + Messages: messages, + }, nil + }), + ). + AppendGraph(g, compose.WithNodeName("ReAct"), compose.WithGraphCompileOptions(compose.WithMaxRunSteps(math.MaxInt))) + + var compileOptions []compose.GraphCompileOption + compileOptions = append(compileOptions, + compose.WithGraphName(a.name), + compose.WithCheckPointStore(ap.store), + compose.WithSerializer(&gobSerializer{}), + compose.WithMaxRunSteps(math.MaxInt)) + + if cancelCtx != nil { + var interrupt func(...compose.GraphInterruptOption) + ctx, interrupt = compose.WithGraphInterrupt(ctx) + cancelCtx.setGraphInterruptFunc(cancelCtx.wrapGraphInterruptWithGracePeriod(interrupt)) + } + + runnable, err_ := chain.Compile(ctx, compileOptions...) + if err_ != nil { + ap.generator.Send(&TypedAgentEvent[*schema.AgenticMessage]{Err: err_}) + return + } + + ctx = withTypedChatModelAgentExecCtx(ctx, &typedChatModelAgentExecCtx[*schema.AgenticMessage]{ + runtimeReturnDirectly: ap.returnDirectly, + generator: ap.generator, + }) + + // Pre-execution cancel check + if cancelCtx != nil && cancelCtx.shouldCancel() { + if cancelCtx.getMode() == CancelImmediate || atomic.LoadInt32(&cancelCtx.escalated) == 1 { + cancelErr, ok := cancelCtx.createAndMarkCancelHandled() + if !ok { + return + } + ap.generator.Send(&TypedAgentEvent[*schema.AgenticMessage]{Err: cancelErr}) + return + } + } + + in := agenticReactRunInput{input: ap.input, instruction: ap.instruction} + + var runOpts []compose.Option + runOpts = append(runOpts, ap.composeOpts...) + if ap.input.EnableStreaming { + runOpts = append(runOpts, compose.WithToolsNodeOption(compose.WithToolOption(withAgentToolEnableStreaming(true)))) + } + + var msg *schema.AgenticMessage + var msgStream *schema.StreamReader[*schema.AgenticMessage] + if ap.input.EnableStreaming { msgStream, err_ = runnable.Stream(ctx, in, runOpts...) } else { msg, err_ = runnable.Invoke(ctx, in, runOpts...) @@ -1055,7 +1275,7 @@ func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) ( if a.outputKey != "" { err_ = setOutputToSession(ctx, msg, msgStream, a.outputKey) if err_ != nil { - p.generator.Send(&AgentEvent{Err: err_}) + ap.generator.Send(&TypedAgentEvent[*schema.AgenticMessage]{Err: err_}) } } else if msgStream != nil { msgStream.Close() @@ -1064,28 +1284,35 @@ func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) ( return } - a.handleRunFuncError(ctx, err_, cancelCtx, p.cancelCtxOwned, p.store, p.generator) + a.handleRunFuncError(ctx, err_, cancelCtx, ap.cancelCtxOwned, ap.store, p.generator) }, nil } -func (a *ChatModelAgent) buildRunFunc(ctx context.Context) runFunc { +func (a *TypedChatModelAgent[M]) buildRunFunc(ctx context.Context) typedRunFunc[M] { a.once.Do(func() { ec, err := a.prepareExecContext(ctx) if err != nil { - a.run = errFunc(err) + a.run = typedErrFunc[M](err) return } a.exeCtx = ec if len(ec.toolsNodeConf.Tools) == 0 { - a.run = a.buildNoToolsRunFunc(ctx) + var run typedRunFunc[M] + run, err = a.buildNoToolsRunFunc(ctx) + if err != nil { + a.run = typedErrFunc[M](err) + return + } + a.run = run return } - run, err := a.buildReActRunFunc(ctx, ec) + var run typedRunFunc[M] + run, err = a.buildReActRunFunc(ctx, ec) if err != nil { - a.run = errFunc(err) + a.run = typedErrFunc[M](err) return } a.run = run @@ -1096,7 +1323,7 @@ func (a *ChatModelAgent) buildRunFunc(ctx context.Context) runFunc { return a.run } -func (a *ChatModelAgent) getRunFunc(ctx context.Context) (context.Context, runFunc, *execContext, error) { +func (a *TypedChatModelAgent[M]) getRunFunc(ctx context.Context) (context.Context, typedRunFunc[M], *execContext, error) { defaultRun := a.buildRunFunc(ctx) bc := a.exeCtx @@ -1123,9 +1350,12 @@ func (a *ChatModelAgent) getRunFunc(ctx context.Context) (context.Context, runFu return ctx, defaultRun, runtimeBC, nil } - var tempRun runFunc + var tempRun typedRunFunc[M] if len(runtimeBC.toolsNodeConf.Tools) == 0 { - tempRun = a.buildNoToolsRunFunc(ctx) + tempRun, err = a.buildNoToolsRunFunc(ctx) + if err != nil { + return ctx, nil, nil, err + } } else { tempRun, err = a.buildReActRunFunc(ctx, runtimeBC) if err != nil { @@ -1136,8 +1366,8 @@ func (a *ChatModelAgent) getRunFunc(ctx context.Context) (context.Context, runFu return ctx, tempRun, runtimeBC, nil } -func (a *ChatModelAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { - iterator, generator := NewAsyncIteratorPair[*AgentEvent]() +func (a *TypedChatModelAgent[M]) Run(ctx context.Context, input *TypedAgentInput[M], opts ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[M]] { + iterator, generator := NewAsyncIteratorPair[*TypedAgentEvent[M]]() o := getCommonOptions(nil, opts...) cancelCtx := o.cancelCtx @@ -1152,7 +1382,7 @@ func (a *ChatModelAgent) Run(ctx context.Context, input *AgentInput, opts ...Age if cancelCtxOwned && cancelCtx != nil { defer cancelCtx.markDone() } - generator.Send(&AgentEvent{Err: fmt.Errorf("ChatModelAgent getRunFunc error: %w", err)}) + generator.Send(&TypedAgentEvent[M]{Err: fmt.Errorf("ChatModelAgent getRunFunc error: %w", err)}) generator.Close() }() return iterator @@ -1173,7 +1403,7 @@ func (a *ChatModelAgent) Run(ctx context.Context, input *AgentInput, opts ...Age panicErr := recover() if panicErr != nil { e := safe.NewPanicErr(panicErr, debug.Stack()) - generator.Send(&AgentEvent{Err: e}) + generator.Send(&TypedAgentEvent[M]{Err: e}) } generator.Close() @@ -1189,7 +1419,7 @@ func (a *ChatModelAgent) Run(ctx context.Context, input *AgentInput, opts ...Age returnDirectly = bc.returnDirectly } - run(ctx, &runParams{ + run(ctx, &typedRunParams[M]{ input: input, generator: generator, store: newBridgeStore(), @@ -1207,8 +1437,8 @@ func (a *ChatModelAgent) Run(ctx context.Context, input *AgentInput, opts ...Age return iterator } -func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { - iterator, generator := NewAsyncIteratorPair[*AgentEvent]() +func (a *TypedChatModelAgent[M]) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[M]] { + iterator, generator := NewAsyncIteratorPair[*TypedAgentEvent[M]]() o := getCommonOptions(nil, opts...) cancelCtx := o.cancelCtx @@ -1223,7 +1453,7 @@ func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...A if cancelCtxOwned && cancelCtx != nil { defer cancelCtx.markDone() } - generator.Send(&AgentEvent{Err: fmt.Errorf("ChatModelAgent getRunFunc error: %w", err)}) + generator.Send(&TypedAgentEvent[M]{Err: fmt.Errorf("ChatModelAgent getRunFunc error: %w", err)}) generator.Close() }() return iterator @@ -1262,7 +1492,7 @@ func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...A stateByte, err = preprocessComposeCheckpoint(stateByte) if err != nil { go func() { - generator.Send(&AgentEvent{Err: err}) + generator.Send(&TypedAgentEvent[M]{Err: err}) generator.Close() }() return iterator @@ -1294,7 +1524,7 @@ func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...A panicErr := recover() if panicErr != nil { e := safe.NewPanicErr(panicErr, debug.Stack()) - generator.Send(&AgentEvent{Err: e}) + generator.Send(&TypedAgentEvent[M]{Err: e}) } generator.Close() @@ -1310,8 +1540,8 @@ func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...A returnDirectly = bc.returnDirectly } - run(ctx, &runParams{ - input: &AgentInput{EnableStreaming: info.EnableStreaming}, + run(ctx, &typedRunParams[M]{ + input: &TypedAgentInput[M]{EnableStreaming: info.EnableStreaming}, generator: generator, store: newResumeBridgeStore(bridgeCheckpointID, stateByte), instruction: instruction, diff --git a/adk/deterministic_transfer.go b/adk/deterministic_transfer.go index dc677a007..ce5b20093 100644 --- a/adk/deterministic_transfer.go +++ b/adk/deterministic_transfer.go @@ -250,7 +250,7 @@ func handleFlowAgentEvents(ctx context.Context, iter *AsyncIterator[*AgentEvent] } if parentSession != nil && (event.Action == nil || event.Action.Interrupted == nil) { - copied := copyAgentEvent(event) + copied := copyTypedAgentEvent(event) setAutomaticClose(copied) setAutomaticClose(event) parentSession.addEvent(copied) diff --git a/adk/failover_chatmodel.go b/adk/failover_chatmodel.go index 898aedd7c..a0f60ea85 100644 --- a/adk/failover_chatmodel.go +++ b/adk/failover_chatmodel.go @@ -31,21 +31,13 @@ import ( type failoverCurrentModelKey struct{} -type failoverCurrentModel struct { - model model.BaseChatModel +func typedSetFailoverCurrentModel[M messageType](ctx context.Context, currentModel model.BaseModel[M]) context.Context { + return context.WithValue(ctx, failoverCurrentModelKey{}, currentModel) } -func setFailoverCurrentModel(ctx context.Context, currentModel model.BaseChatModel) context.Context { - return context.WithValue(ctx, failoverCurrentModelKey{}, &failoverCurrentModel{ - model: currentModel, - }) -} - -func getFailoverCurrentModel(ctx context.Context) *failoverCurrentModel { - if fm, ok := ctx.Value(failoverCurrentModelKey{}).(*failoverCurrentModel); ok { - return fm - } - return nil +func typedGetFailoverCurrentModel[M messageType](ctx context.Context) (model.BaseModel[M], bool) { + m, ok := ctx.Value(failoverCurrentModelKey{}).(model.BaseModel[M]) + return m, ok } type failoverHasMoreAttemptsKey struct{} @@ -64,30 +56,30 @@ func getFailoverHasMoreAttempts(ctx context.Context) bool { return v } -type failoverProxyModel struct { +type typedFailoverProxyModel[M messageType] struct { } -func (m *failoverProxyModel) prepareCallbacks(ctx context.Context) (context.Context, model.BaseChatModel, error) { - current := getFailoverCurrentModel(ctx) - if current == nil || current.model == nil { +func (m *typedFailoverProxyModel[M]) prepareCallbacks(ctx context.Context) (context.Context, model.BaseModel[M], error) { + target, ok := typedGetFailoverCurrentModel[M](ctx) + if !ok { return nil, nil, errors.New("failover current model not found in context") } - typ, _ := components.GetType(current.model) + typ, _ := components.GetType(target) ctx = callbacks.EnsureRunInfo(ctx, typ, components.ComponentOfChatModel) - target := current.model if !components.IsCallbacksEnabled(target) { - target = (&callbackInjectionModelWrapper{}).WrapModel(target) + target = typedCallbackInjectionModelWrapper[M]{}.wrapModel(target) } return ctx, target, nil } -func (m *failoverProxyModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { +func (m *typedFailoverProxyModel[M]) Generate(ctx context.Context, input []M, opts ...model.Option) (M, error) { nCtx, target, err := m.prepareCallbacks(ctx) if err != nil { - return nil, err + var zero M + return zero, err } ctx = callbacks.OnStart(ctx, input) @@ -103,7 +95,7 @@ func (m *failoverProxyModel) Generate(ctx context.Context, input []*schema.Messa return result, nil } -func (m *failoverProxyModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { +func (m *typedFailoverProxyModel[M]) Stream(ctx context.Context, input []M, opts ...model.Option) (*schema.StreamReader[M], error) { nCtx, target, err := m.prepareCallbacks(ctx) if err != nil { return nil, err @@ -121,14 +113,16 @@ func (m *failoverProxyModel) Stream(ctx context.Context, input []*schema.Message return wrappedStream, nil } -func (m *failoverProxyModel) IsCallbacksEnabled() bool { +func (m *typedFailoverProxyModel[M]) IsCallbacksEnabled() bool { return true } -func (m *failoverProxyModel) GetType() string { +func (m *typedFailoverProxyModel[M]) GetType() string { return "FailoverProxyModel" } +type failoverProxyModel = typedFailoverProxyModel[*schema.Message] + // FailoverContext contains context information during failover process. type FailoverContext struct { // FailoverAttempt is the current failover attempt number, starting from 1. @@ -199,32 +193,35 @@ type ModelFailoverConfig struct { failoverModel model.BaseChatModel, failoverModelInputMessages []*schema.Message, failoverErr error) } -func getLastSuccessModel(ctx context.Context) model.BaseChatModel { - if execCtx := getChatModelAgentExecCtx(ctx); execCtx != nil { - return execCtx.failoverLastSuccessModel +func typedGetFailoverLastSuccessModel[M messageType](ctx context.Context) model.BaseModel[M] { + execCtx := getTypedChatModelAgentExecCtx[M](ctx) + if execCtx == nil { + return nil } - return nil + return execCtx.failoverLastSuccessModel } -func setLastSuccessModel(ctx context.Context, m model.BaseChatModel) { - if execCtx := getChatModelAgentExecCtx(ctx); execCtx != nil { +func typedSetFailoverLastSuccessModel[M messageType](ctx context.Context, m model.BaseModel[M]) { + if execCtx := getTypedChatModelAgentExecCtx[M](ctx); execCtx != nil { execCtx.failoverLastSuccessModel = m } } -type failoverModelWrapper struct { +type typedFailoverModelWrapper[M messageType] struct { config *ModelFailoverConfig - inner model.BaseChatModel + inner model.BaseModel[M] } -func newFailoverModelWrapper(inner model.BaseChatModel, config *ModelFailoverConfig) *failoverModelWrapper { - return &failoverModelWrapper{ +type failoverModelWrapper = typedFailoverModelWrapper[*schema.Message] + +func newTypedFailoverModelWrapper[M messageType](inner model.BaseModel[M], config *ModelFailoverConfig) *typedFailoverModelWrapper[M] { + return &typedFailoverModelWrapper[M]{ config: config, inner: inner, } } -func (f *failoverModelWrapper) needFailover(ctx context.Context, outputMessage *schema.Message, outputErr error) bool { +func (f *typedFailoverModelWrapper[M]) needFailover(ctx context.Context, outputMessage M, outputErr error) bool { if ctx.Err() != nil { return false } @@ -236,25 +233,51 @@ func (f *failoverModelWrapper) needFailover(ctx context.Context, outputMessage * } // ShouldFailover is validated at agent construction; nil here indicates a programmer error. - return f.config.ShouldFailover(ctx, outputMessage, outputErr) + schemaMsg, _ := any(outputMessage).(*schema.Message) + return f.config.ShouldFailover(ctx, schemaMsg, outputErr) +} + +func (f *typedFailoverModelWrapper[M]) getFailoverModel(ctx context.Context, failoverCtx *FailoverContext) (model.BaseModel[M], []M, error) { + chatModel, msgs, err := f.config.GetFailoverModel(ctx, failoverCtx) + if err != nil { + return nil, nil, err + } + if chatModel == nil { + return nil, nil, nil + } + + typedModel, ok := any(chatModel).(model.BaseModel[M]) + if !ok { + return nil, nil, fmt.Errorf("failover GetFailoverModel returned model of type %T, expected model.BaseModel[%T]", chatModel, *new(M)) + } + + var typedMsgs []M + if msgs != nil { + if m, ok := any(msgs).([]M); ok { + typedMsgs = m + } + } + + return typedModel, typedMsgs, nil } -func (f *failoverModelWrapper) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { +func (f *typedFailoverModelWrapper[M]) Generate(ctx context.Context, input []M, opts ...model.Option) (M, error) { // Defensive: GetFailoverModel is validated non-nil at agent construction. if f.config.GetFailoverModel == nil { return f.inner.Generate(ctx, input, opts...) } - var lastOutputMessage *schema.Message + var lastOutputMessage M var lastErr error // Try lastSuccessModel first if available. - if lastSuccess := getLastSuccessModel(ctx); lastSuccess != nil { + if lastSuccess := typedGetFailoverLastSuccessModel[M](ctx); lastSuccess != nil { if err := ctx.Err(); err != nil { - return nil, err + var zero M + return zero, err } - modelCtx := setFailoverCurrentModel(ctx, lastSuccess) + modelCtx := typedSetFailoverCurrentModel(ctx, lastSuccess) modelCtx = withFailoverHasMoreAttempts(modelCtx, f.config.MaxRetries > 0) result, err := f.inner.Generate(modelCtx, input, opts...) if err == nil { @@ -273,36 +296,41 @@ func (f *failoverModelWrapper) Generate(ctx context.Context, input []*schema.Mes for attempt := uint(1); attempt <= f.config.MaxRetries; attempt++ { if err := ctx.Err(); err != nil { - return nil, err + var zero M + return zero, err } + inputMsgs, _ := any(input).([]*schema.Message) + lastOutputMsg, _ := any(lastOutputMessage).(*schema.Message) failoverCtx := &FailoverContext{ FailoverAttempt: attempt, - InputMessages: input, - LastOutputMessage: lastOutputMessage, + InputMessages: inputMsgs, + LastOutputMessage: lastOutputMsg, LastErr: lastErr, } - currentModel, currentInput, err := f.config.GetFailoverModel(ctx, failoverCtx) + currentModel, currentInput, err := f.getFailoverModel(ctx, failoverCtx) if err != nil { - return nil, err + var zero M + return zero, err } if currentModel == nil { - return nil, fmt.Errorf("failover GetFailoverModel returned nil model at attempt %d", attempt) + var zero M + return zero, fmt.Errorf("failover GetFailoverModel returned nil model at attempt %d", attempt) } if currentInput == nil { currentInput = input } - modelCtx := setFailoverCurrentModel(ctx, currentModel) + modelCtx := typedSetFailoverCurrentModel(ctx, currentModel) modelCtx = withFailoverHasMoreAttempts(modelCtx, attempt < f.config.MaxRetries) result, err := f.inner.Generate(modelCtx, currentInput, opts...) lastOutputMessage = result lastErr = err if err == nil { - setLastSuccessModel(ctx, currentModel) + typedSetFailoverLastSuccessModel[M](ctx, currentModel) return result, nil } @@ -318,28 +346,29 @@ func (f *failoverModelWrapper) Generate(ctx context.Context, input []*schema.Mes return lastOutputMessage, lastErr } -func (f *failoverModelWrapper) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) ( - *schema.StreamReader[*schema.Message], error) { +func (f *typedFailoverModelWrapper[M]) Stream(ctx context.Context, input []M, opts ...model.Option) ( + *schema.StreamReader[M], error) { // Defensive: GetFailoverModel is validated non-nil at agent construction. if f.config.GetFailoverModel == nil { return f.inner.Stream(ctx, input, opts...) } - var lastOutputMessage *schema.Message + var lastOutputMessage M var lastErr error // Try lastSuccessModel first if available. - if lastSuccess := getLastSuccessModel(ctx); lastSuccess != nil { + if lastSuccess := typedGetFailoverLastSuccessModel[M](ctx); lastSuccess != nil { if err := ctx.Err(); err != nil { return nil, err } - modelCtx := setFailoverCurrentModel(ctx, lastSuccess) + modelCtx := typedSetFailoverCurrentModel(ctx, lastSuccess) modelCtx = withFailoverHasMoreAttempts(modelCtx, f.config.MaxRetries > 0) stream, err := f.inner.Stream(modelCtx, input, opts...) if err != nil { lastErr = err - if !f.needFailover(ctx, nil, err) { + var zero M + if !f.needFailover(ctx, zero, err) { return nil, err } log.Printf("failover ChatModel.Stream lastSuccessModel failed: %v", err) @@ -348,7 +377,7 @@ func (f *failoverModelWrapper) Stream(ctx context.Context, input []*schema.Messa checkCopy := copies[0] returnCopy := copies[1] - outMsg, streamErr := consumeStream(checkCopy) + outMsg, streamErr := typedConsumeStream(checkCopy) if streamErr != nil { lastOutputMessage = outMsg lastErr = streamErr @@ -369,14 +398,16 @@ func (f *failoverModelWrapper) Stream(ctx context.Context, input []*schema.Messa return nil, err } + inputMsgs2, _ := any(input).([]*schema.Message) + lastOutputMsg2, _ := any(lastOutputMessage).(*schema.Message) failoverCtx := &FailoverContext{ FailoverAttempt: attempt, - InputMessages: input, - LastOutputMessage: lastOutputMessage, + InputMessages: inputMsgs2, + LastOutputMessage: lastOutputMsg2, LastErr: lastErr, } - currentModel, currentInput, err := f.config.GetFailoverModel(ctx, failoverCtx) + currentModel, currentInput, err := f.getFailoverModel(ctx, failoverCtx) if err != nil { return nil, err } @@ -388,14 +419,15 @@ func (f *failoverModelWrapper) Stream(ctx context.Context, input []*schema.Messa currentInput = input } - modelCtx := setFailoverCurrentModel(ctx, currentModel) + modelCtx := typedSetFailoverCurrentModel(ctx, currentModel) modelCtx = withFailoverHasMoreAttempts(modelCtx, attempt < f.config.MaxRetries) stream, err := f.inner.Stream(modelCtx, currentInput, opts...) if err != nil { lastErr = err - lastOutputMessage = nil + var zero M + lastOutputMessage = zero - if !f.needFailover(ctx, nil, err) { + if !f.needFailover(ctx, zero, err) { return nil, err } @@ -425,7 +457,7 @@ func (f *failoverModelWrapper) Stream(ctx context.Context, input []*schema.Messa checkCopy := copies[0] returnCopy := copies[1] - outMsg, streamErr := consumeStream(checkCopy) + outMsg, streamErr := typedConsumeStream(checkCopy) if streamErr != nil { lastOutputMessage = outMsg lastErr = streamErr @@ -441,32 +473,61 @@ func (f *failoverModelWrapper) Stream(ctx context.Context, input []*schema.Messa continue } - setLastSuccessModel(ctx, currentModel) + typedSetFailoverLastSuccessModel[M](ctx, currentModel) return returnCopy, nil } return nil, lastErr } -func consumeStream(stream *schema.StreamReader[*schema.Message]) (*schema.Message, error) { +func typedConsumeStream[M messageType](stream *schema.StreamReader[M]) (M, error) { + var zero M defer stream.Close() - chunks := make([]*schema.Message, 0) - for { - chunk, err := stream.Recv() - if err == io.EOF { - break + + switch s := any(stream).(type) { + case *schema.StreamReader[*schema.Message]: + chunks := make([]*schema.Message, 0) + for { + chunk, err := s.Recv() + if err == io.EOF { + break + } + if err != nil { + msg, _ := schema.ConcatMessages(chunks) + if msg != nil { + return any(msg).(M), err + } + return zero, err + } + chunks = append(chunks, chunk) } - if err != nil { - // ignore concat error - msg, _ := schema.ConcatMessages(chunks) - return msg, err + msg, _ := schema.ConcatMessages(chunks) + if msg != nil { + return any(msg).(M), nil } - - chunks = append(chunks, chunk) + return zero, nil + case *schema.StreamReader[*schema.AgenticMessage]: + chunks := make([]*schema.AgenticMessage, 0) + for { + chunk, err := s.Recv() + if err == io.EOF { + break + } + if err != nil { + msg, _ := schema.ConcatAgenticMessages(chunks) + if msg != nil { + return any(msg).(M), err + } + return zero, err + } + chunks = append(chunks, chunk) + } + msg, _ := schema.ConcatAgenticMessages(chunks) + if msg != nil { + return any(msg).(M), nil + } + return zero, nil + default: + panic("unreachable: unknown messageType") } - - // Stream completed successfully (EOF). ConcatMessages error is not a stream error, - // so ignore it to avoid incorrectly triggering failover. - msg, _ := schema.ConcatMessages(chunks) - return msg, nil } diff --git a/adk/failover_chatmodel_test.go b/adk/failover_chatmodel_test.go index 82866e994..75f87df36 100644 --- a/adk/failover_chatmodel_test.go +++ b/adk/failover_chatmodel_test.go @@ -104,19 +104,21 @@ func TestFailoverCurrentModelContext(t *testing.T) { return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("ok", nil)}), nil }, } - ctx = setFailoverCurrentModel(ctx, m) - got := getFailoverCurrentModel(ctx) - require.NotNil(t, got) - require.Same(t, m, got.model) + ctx = typedSetFailoverCurrentModel[*schema.Message](ctx, m) + got, ok := typedGetFailoverCurrentModel[*schema.Message](ctx) + require.True(t, ok) + require.Same(t, m, got) }) t.Run("wrong type", func(t *testing.T) { ctx := context.WithValue(context.Background(), failoverCurrentModelKey{}, "bad") - require.Nil(t, getFailoverCurrentModel(ctx)) + _, ok := typedGetFailoverCurrentModel[*schema.Message](ctx) + require.False(t, ok) }) t.Run("missing", func(t *testing.T) { - require.Nil(t, getFailoverCurrentModel(context.Background())) + _, ok := typedGetFailoverCurrentModel[*schema.Message](context.Background()) + require.False(t, ok) }) } @@ -145,7 +147,7 @@ func TestFailoverProxyModel(t *testing.T) { return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("routed", nil)}), nil }, } - ctx := setFailoverCurrentModel(context.Background(), target) + ctx := typedSetFailoverCurrentModel[*schema.Message](context.Background(), target) p := &failoverProxyModel{} msg, err := p.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) require.NoError(t, err) @@ -167,7 +169,7 @@ func TestFailoverModelWrapper_Generate(t *testing.T) { return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("inner", nil)}), nil }, } - w := newFailoverModelWrapper(inner, &ModelFailoverConfig{ + w := newTypedFailoverModelWrapper[*schema.Message](inner, &ModelFailoverConfig{ MaxRetries: 2, ShouldFailover: func(context.Context, *schema.Message, error) bool { return true }, GetFailoverModel: nil, @@ -217,8 +219,8 @@ func TestFailoverModelWrapper_Generate(t *testing.T) { }, } - w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + w := newTypedFailoverModelWrapper[*schema.Message](&failoverProxyModel{}, cfg) + ctx := withTypedChatModelAgentExecCtx[*schema.Message](context.Background(), &chatModelAgentExecCtx{ failoverLastSuccessModel: m1, }) msg, err := w.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) @@ -253,8 +255,8 @@ func TestFailoverModelWrapper_Generate(t *testing.T) { }, } - w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + w := newTypedFailoverModelWrapper[*schema.Message](&failoverProxyModel{}, cfg) + ctx := withTypedChatModelAgentExecCtx[*schema.Message](context.Background(), &chatModelAgentExecCtx{ failoverLastSuccessModel: m1, }) _, err := w.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) @@ -285,7 +287,7 @@ func TestFailoverModelWrapper_Generate(t *testing.T) { }, } - w := newFailoverModelWrapper(inner, cfg) + w := newTypedFailoverModelWrapper[*schema.Message](inner, cfg) _, err := w.Generate(context.Background(), []*schema.Message{schema.UserMessage("hi")}) require.ErrorIs(t, err, wantErr) require.Equal(t, int32(0), atomic.LoadInt32(&called)) @@ -300,7 +302,7 @@ func TestFailoverModelWrapper_Generate(t *testing.T) { }, } - w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + w := newTypedFailoverModelWrapper[*schema.Message](&failoverProxyModel{}, cfg) msg, err := w.Generate(context.Background(), []*schema.Message{schema.UserMessage("hi")}) require.Nil(t, msg) require.Error(t, err) @@ -339,8 +341,8 @@ func TestFailoverModelWrapper_Stream(t *testing.T) { }, } - w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + w := newTypedFailoverModelWrapper[*schema.Message](&failoverProxyModel{}, cfg) + ctx := withTypedChatModelAgentExecCtx[*schema.Message](context.Background(), &chatModelAgentExecCtx{ failoverLastSuccessModel: m1, }) sr, err := w.Stream(ctx, []*schema.Message{in}) @@ -392,8 +394,8 @@ func TestFailoverModelWrapper_Stream(t *testing.T) { }, } - w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + w := newTypedFailoverModelWrapper[*schema.Message](&failoverProxyModel{}, cfg) + ctx := withTypedChatModelAgentExecCtx[*schema.Message](context.Background(), &chatModelAgentExecCtx{ failoverLastSuccessModel: m1, }) sr, err := w.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) @@ -453,8 +455,8 @@ func TestFailoverModelWrapper_Stream(t *testing.T) { }, } - w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + w := newTypedFailoverModelWrapper[*schema.Message](&failoverProxyModel{}, cfg) + ctx := withTypedChatModelAgentExecCtx[*schema.Message](context.Background(), &chatModelAgentExecCtx{ failoverLastSuccessModel: m1, }) sr, err := w.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) @@ -491,8 +493,8 @@ func TestFailoverModelWrapper_Stream(t *testing.T) { }, } - w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + w := newTypedFailoverModelWrapper[*schema.Message](&failoverProxyModel{}, cfg) + ctx := withTypedChatModelAgentExecCtx[*schema.Message](context.Background(), &chatModelAgentExecCtx{ failoverLastSuccessModel: m1, }) sr, err := w.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) @@ -524,8 +526,8 @@ func TestFailoverModelWrapper_Stream(t *testing.T) { }, } - w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + w := newTypedFailoverModelWrapper[*schema.Message](&failoverProxyModel{}, cfg) + ctx := withTypedChatModelAgentExecCtx[*schema.Message](context.Background(), &chatModelAgentExecCtx{ failoverLastSuccessModel: m1, }) sr, err := w.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) @@ -563,8 +565,8 @@ func TestFailoverModelWrapper_Stream(t *testing.T) { }, } - w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + w := newTypedFailoverModelWrapper[*schema.Message](&failoverProxyModel{}, cfg) + ctx := withTypedChatModelAgentExecCtx[*schema.Message](context.Background(), &chatModelAgentExecCtx{ failoverLastSuccessModel: m1, }) sr, err := w.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) @@ -583,7 +585,7 @@ func TestFailoverModelWrapper_Stream(t *testing.T) { }, } - w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + w := newTypedFailoverModelWrapper[*schema.Message](&failoverProxyModel{}, cfg) sr, err := w.Stream(context.Background(), []*schema.Message{schema.UserMessage("hi")}) require.Nil(t, sr) require.Error(t, err) @@ -612,7 +614,7 @@ func TestFailoverModelWrapper_Stream(t *testing.T) { }, } - w := newFailoverModelWrapper(inner, cfg) + w := newTypedFailoverModelWrapper[*schema.Message](inner, cfg) sr, err := w.Stream(context.Background(), []*schema.Message{schema.UserMessage("hi")}) require.Nil(t, sr) require.ErrorIs(t, err, wantErr) @@ -665,8 +667,8 @@ func TestFailoverModelWrapper_Stream(t *testing.T) { }, } - w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) - baseCtx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + w := newTypedFailoverModelWrapper[*schema.Message](&failoverProxyModel{}, cfg) + baseCtx := withTypedChatModelAgentExecCtx[*schema.Message](context.Background(), &chatModelAgentExecCtx{ failoverLastSuccessModel: m1, }) ctx, cancel := context.WithCancel(baseCtx) diff --git a/adk/flow.go b/adk/flow.go index 8edc002a0..7011fa81a 100644 --- a/adk/flow.go +++ b/adk/flow.go @@ -261,7 +261,7 @@ func genMsg(entry *HistoryEntry, agentName string) (Message, error) { return msg, nil } -func (ai *AgentInput) deepCopy() *AgentInput { +func deepCopyAgentInput(ai *AgentInput) *AgentInput { copied := &AgentInput{ Messages: make([]Message, len(ai.Messages)), EnableStreaming: ai.EnableStreaming, @@ -273,7 +273,7 @@ func (ai *AgentInput) deepCopy() *AgentInput { } func (a *flowAgent) genAgentInput(ctx context.Context, runCtx *runContext, skipTransferMessages bool) (*AgentInput, error) { - input := runCtx.RootInput.deepCopy() + input := deepCopyAgentInput(runCtx.RootInput) events := runCtx.Session.getEvents() historyEntries := make([]*HistoryEntry, 0) @@ -521,7 +521,7 @@ func (a *flowAgent) run( // copy before adding to session because once added to session it's stream could be consumed by genAgentInput at any time // interrupt action are not added to session, because ALL information contained in it // is either presented to end-user, or made available to agents through other means - copied := copyAgentEvent(event) + copied := copyTypedAgentEvent(event) setAutomaticClose(copied) setAutomaticClose(event) runCtx.Session.addEvent(copied) @@ -532,7 +532,7 @@ func (a *flowAgent) run( if exactRunPathMatch(runCtx.RunPath, event.RunPath) { lastAction = event.Action } - copied := copyAgentEvent(event) + copied := copyTypedAgentEvent(event) setAutomaticClose(copied) setAutomaticClose(event) cbGen.Send(copied) @@ -604,10 +604,206 @@ func wrapIterWithOnEnd(ctx context.Context, iter *AsyncIterator[*AgentEvent]) *A if !ok { break } - copied := copyAgentEvent(event) + copied := copyTypedAgentEvent(event) cbGen.Send(copied) outGen.Send(event) } }() return outIter } + +// --------------------------------------------------------------------------- +// Typed wrapper for the agentic path (TypedAgent[*schema.AgenticMessage]). +// +// typedFlowAgent is a minimal wrapper used exclusively by TypedRunner and +// AgentTool to execute a TypedAgent[*schema.AgenticMessage]. It handles +// callbacks, event recording, and run-path tracking. Transfer, sub-agent +// orchestration, and history rewriting are handled solely by the concrete +// flowAgent (the *schema.Message path). +// --------------------------------------------------------------------------- + +type typedFlowAgent[M messageType] struct { + TypedAgent[M] + + checkPointStore compose.CheckPointStore +} + +func toTypedFlowAgent[M messageType](agent TypedAgent[M]) *typedFlowAgent[M] { + if fa, ok := agent.(*typedFlowAgent[M]); ok { + return fa + } + return &typedFlowAgent[M]{TypedAgent: agent} +} + +func getTypedAgentType[M messageType](agent TypedAgent[M]) string { + if msgAgent, ok := any(agent).(Agent); ok { + return getAgentType(msgAgent) + } + if typer, ok := any(agent).(interface{ GetType() string }); ok { + return typer.GetType() + } + return "" +} + +func (a *typedFlowAgent[M]) Run(ctx context.Context, input *TypedAgentInput[M], opts ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[M]] { + agentName := a.Name(ctx) + + var runCtx *runContext + ctx, runCtx = initTypedRunCtx(ctx, agentName, input) + ctx = AppendAddressSegment(ctx, AddressSegmentAgent, agentName) + + o := getCommonOptions(nil, opts...) + cancelCtx := o.cancelCtx + + ctxForSubAgents := ctx + + agentType := getTypedAgentType(a.TypedAgent) + ctx = initAgenticCallbacks(ctx, agentName, agentType, filterOptions(agentName, opts)...) + cbInput := &TypedAgentCallbackInput[*schema.AgenticMessage]{Input: any(input).(*TypedAgentInput[*schema.AgenticMessage])} + ctx = callbacks.OnStart(ctx, cbInput) + + aIter := a.TypedAgent.Run(withCancelContext(ctx, cancelCtx), input, filterOptions(agentName, opts)...) + + iterator, generator := NewAsyncIteratorPair[*TypedAgentEvent[M]]() + + go a.run(withCancelContext(ctx, cancelCtx), withCancelContext(ctxForSubAgents, cancelCtx), runCtx, aIter, generator, filterCancelOption(opts)...) + + return wrapIterWithCancelCtx(iterator, cancelCtx) +} + +func (a *typedFlowAgent[M]) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[M]] { + agentName := a.Name(ctx) + + ctx, info = buildResumeInfo(ctx, agentName, info) + + ctxForSubAgents := ctx + + o := getCommonOptions(nil, opts...) + cancelCtx := o.cancelCtx + + agentType := getTypedAgentType(a.TypedAgent) + ctx = initAgenticCallbacks(ctx, agentName, agentType, filterOptions(agentName, opts)...) + cbInput := &TypedAgentCallbackInput[*schema.AgenticMessage]{ResumeInfo: info} + ctx = callbacks.OnStart(ctx, cbInput) + + if info.WasInterrupted { + if ra, ok := a.TypedAgent.(TypedResumableAgent[M]); ok { + aIter := ra.Resume(withCancelContext(ctx, cancelCtx), info, opts...) + + iterator, generator := NewAsyncIteratorPair[*TypedAgentEvent[M]]() + go a.run(withCancelContext(ctx, cancelCtx), withCancelContext(ctxForSubAgents, cancelCtx), getRunCtx(ctxForSubAgents), aIter, generator, filterCancelOption(opts)...) + return wrapIterWithCancelCtx(iterator, cancelCtx) + } + + if cancelCtx != nil { + cancelCtx.markDone() + } + return typedErrorIterWithOnEnd[M](ctx, fmt.Errorf("failed to resume agent: agent '%s' is an interrupt point "+ + "but is not a ResumableAgent", agentName)) + } + + _, err := getNextResumeAgent(ctx, info) + if err != nil { + if cancelCtx != nil { + cancelCtx.markDone() + } + return typedErrorIterWithOnEnd[M](ctx, err) + } + + if ra, ok := a.TypedAgent.(TypedResumableAgent[M]); ok { + ctx = withCancelContext(ctx, cancelCtx) + innerIter := ra.Resume(ctx, info, filterCancelOption(opts)...) + return wrapIterWithCancelCtx(typedWrapIterWithOnEnd[M](ctx, innerIter), cancelCtx) + } + return typedErrorIterWithOnEnd[M](ctx, fmt.Errorf( + "failed to resume agent: agent '%s' (type %T) does not implement ResumableAgent interface. "+ + "To support resume, your custom agent wrapper must implement the ResumableAgent interface", agentName, a.TypedAgent)) +} + +func (a *typedFlowAgent[M]) run( + ctx context.Context, + _ context.Context, + runCtx *runContext, + aIter *AsyncIterator[*TypedAgentEvent[M]], + generator *AsyncGenerator[*TypedAgentEvent[M]], + _ ...AgentRunOption) { + + agenticCbIter, agenticCbGen := NewAsyncIteratorPair[*TypedAgentEvent[*schema.AgenticMessage]]() + cbOutput := &TypedAgentCallbackOutput[*schema.AgenticMessage]{Events: agenticCbIter} + icb.On(ctx, cbOutput, icb.BuildOnEndHandleWithCopy(copyTypedCallbackOutput[*schema.AgenticMessage]), callbacks.TimingOnEnd, false) + + defer func() { + panicErr := recover() + if panicErr != nil { + e := safe.NewPanicErr(panicErr, debug.Stack()) + generator.Send(&TypedAgentEvent[M]{Err: e}) + } + + agenticCbGen.Close() + generator.Close() + }() + + for { + event, ok := aIter.Next() + if !ok { + break + } + + if len(event.RunPath) == 0 { + event.AgentName = a.Name(ctx) + event.RunPath = runCtx.RunPath + } + if (event.Action == nil || event.Action.Interrupted == nil) && exactRunPathMatch(runCtx.RunPath, event.RunPath) { + copied := copyTypedAgentEvent(event) + typedSetAutomaticClose(copied) + typedSetAutomaticClose(event) + addTypedEvent(runCtx.Session, copied) + } + + agenticCopied := copyTypedAgentEvent(event) + typedSetAutomaticClose(agenticCopied) + typedSetAutomaticClose(event) + agenticCbGen.Send(any(agenticCopied).(*TypedAgentEvent[*schema.AgenticMessage])) + generator.Send(event) + } +} + +func wrapAgenticIterWithOnEnd(ctx context.Context, iter *AsyncIterator[*TypedAgentEvent[*schema.AgenticMessage]]) *AsyncIterator[*TypedAgentEvent[*schema.AgenticMessage]] { + cbIter, cbGen := NewAsyncIteratorPair[*TypedAgentEvent[*schema.AgenticMessage]]() + cbOutput := &TypedAgentCallbackOutput[*schema.AgenticMessage]{Events: cbIter} + icb.On(ctx, cbOutput, icb.BuildOnEndHandleWithCopy(copyTypedCallbackOutput[*schema.AgenticMessage]), callbacks.TimingOnEnd, false) + + outIter, outGen := NewAsyncIteratorPair[*TypedAgentEvent[*schema.AgenticMessage]]() + go func() { + defer func() { + cbGen.Close() + outGen.Close() + }() + for { + event, ok := iter.Next() + if !ok { + break + } + copied := copyTypedAgentEvent(event) + cbGen.Send(copied) + outGen.Send(event) + } + }() + return outIter +} + +func genAgenticErrorIter(err error) *AsyncIterator[*TypedAgentEvent[*schema.AgenticMessage]] { + iter, gen := NewAsyncIteratorPair[*TypedAgentEvent[*schema.AgenticMessage]]() + gen.Send(&TypedAgentEvent[*schema.AgenticMessage]{Err: err}) + gen.Close() + return iter +} + +func typedWrapIterWithOnEnd[M messageType](ctx context.Context, iter *AsyncIterator[*TypedAgentEvent[M]]) *AsyncIterator[*TypedAgentEvent[M]] { + agenticIter := any(iter).(*AsyncIterator[*TypedAgentEvent[*schema.AgenticMessage]]) + return any(wrapAgenticIterWithOnEnd(ctx, agenticIter)).(*AsyncIterator[*TypedAgentEvent[M]]) +} + +func typedErrorIterWithOnEnd[M messageType](ctx context.Context, err error) *AsyncIterator[*TypedAgentEvent[M]] { + return typedWrapIterWithOnEnd[M](ctx, typedErrorIter[M](err)) +} diff --git a/adk/handler.go b/adk/handler.go index d18abc965..255294dd0 100644 --- a/adk/handler.go +++ b/adk/handler.go @@ -96,12 +96,12 @@ type ChatModelAgentContext struct { ReturnDirectly map[string]bool } -// ChatModelAgentMiddleware defines the interface for customizing ChatModelAgent behavior. +// TypedChatModelAgentMiddleware defines the interface for customizing TypedChatModelAgent behavior. // -// IMPORTANT: This interface is specifically designed for ChatModelAgent and agents built +// IMPORTANT: This interface is specifically designed for TypedChatModelAgent and agents built // on top of it (e.g., DeepAgent). // -// Why ChatModelAgentMiddleware instead of AgentMiddleware? +// Why TypedChatModelAgentMiddleware instead of AgentMiddleware? // // AgentMiddleware is a struct type, which has inherent limitations: // - Struct types are closed: users cannot add new methods to extend functionality @@ -110,22 +110,22 @@ type ChatModelAgentContext struct { // call those methods (config.Middlewares is []AgentMiddleware, not a user type) // - Callbacks in AgentMiddleware only return error, cannot return modified context // -// ChatModelAgentMiddleware is an interface type, which is open for extension: +// TypedChatModelAgentMiddleware is an interface type, which is open for extension: // - Users can implement custom handlers with arbitrary internal state and methods // - Hook methods return (context.Context, ..., error) for direct context propagation // - Wrapper methods (WrapToolCall, WrapModel) enable context propagation through the // wrapped endpoint chain: wrappers can pass modified context to the next wrapper // - Configuration is centralized in struct fields rather than scattered in closures // -// ChatModelAgentMiddleware vs AgentMiddleware: +// TypedChatModelAgentMiddleware vs AgentMiddleware: // - Use AgentMiddleware for simple, static additions (extra instruction/tools) -// - Use ChatModelAgentMiddleware for dynamic behavior, context modification, or call wrapping +// - Use TypedChatModelAgentMiddleware for dynamic behavior, context modification, or call wrapping // - AgentMiddleware is kept for backward compatibility with existing users // - Both can be used together; see AgentMiddleware documentation for execution order // -// Use *BaseChatModelAgentMiddleware as an embedded struct to provide default no-op +// Use *TypedBaseChatModelAgentMiddleware as an embedded struct to provide default no-op // implementations for all methods. -type ChatModelAgentMiddleware interface { +type TypedChatModelAgentMiddleware[M messageType] interface { // BeforeAgent is called before each agent run, allowing modification of // the agent's instruction and tools configuration. BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) @@ -139,7 +139,7 @@ type ChatModelAgentMiddleware interface { // // The ModelContext struct provides read-only access to: // - Tools: the current tool list that will be sent to the model - BeforeModelRewriteState(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) + BeforeModelRewriteState(ctx context.Context, state *TypedChatModelAgentState[M], mc *ModelContext) (context.Context, *TypedChatModelAgentState[M], error) // AfterModelRewriteState is called after each model invocation. // The input state includes the model's response as the last message. @@ -150,7 +150,7 @@ type ChatModelAgentMiddleware interface { // // The ModelContext struct provides read-only access to: // - Tools: the current tool list that was sent to the model - AfterModelRewriteState(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) + AfterModelRewriteState(ctx context.Context, state *TypedChatModelAgentState[M], mc *ModelContext) (context.Context, *TypedChatModelAgentState[M], error) // AfterToolCallsRewriteState is called after all concurrent tool calls in an iteration complete. // The input state includes all messages up to and including the tool call results. @@ -158,7 +158,7 @@ type ChatModelAgentMiddleware interface { // // The ToolCallsContext provides metadata about the tool calls that just completed, // derived from the assistant message's ToolCalls field. - AfterToolCallsRewriteState(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) + AfterToolCallsRewriteState(ctx context.Context, state *TypedChatModelAgentState[M], tc *ToolCallsContext) (context.Context, *TypedChatModelAgentState[M], error) // WrapInvokableToolCall wraps a tool's synchronous execution with custom behavior. // Return the input endpoint unchanged and nil error if no wrapping is needed. @@ -212,15 +212,21 @@ type ChatModelAgentMiddleware interface { // Return the input model unchanged and nil error if no wrapping is needed. // // This method is called at request time when the model is about to be invoked. - // Note: The parameter is BaseChatModel (not ToolCallingChatModel) because wrappers + // Note: The parameter is model.BaseModel[M] (not ToolCallingChatModel) because wrappers // only need to intercept Generate/Stream calls. Tool binding (WithTools) is handled // separately by the framework and does not flow through user wrappers. // // The mc parameter contains the current tool configuration: // - Tools: The tool infos that will be sent to the model - WrapModel(ctx context.Context, m model.BaseChatModel, mc *ModelContext) (model.BaseChatModel, error) + WrapModel(ctx context.Context, m model.BaseModel[M], mc *ModelContext) (model.BaseModel[M], error) } +// ChatModelAgentMiddleware is the default middleware type using *schema.Message. +// See TypedChatModelAgentMiddleware for full documentation. +type ChatModelAgentMiddleware = TypedChatModelAgentMiddleware[*schema.Message] + +type TypedBaseChatModelAgentMiddleware[M messageType] struct{} + // BaseChatModelAgentMiddleware provides default no-op implementations for ChatModelAgentMiddleware. // Embed *BaseChatModelAgentMiddleware in custom handlers to only override the methods you need. // @@ -235,44 +241,58 @@ type ChatModelAgentMiddleware interface { // // custom logic // return ctx, state, nil // } -type BaseChatModelAgentMiddleware struct{} +type BaseChatModelAgentMiddleware = TypedBaseChatModelAgentMiddleware[*schema.Message] -func (b *BaseChatModelAgentMiddleware) WrapInvokableToolCall(_ context.Context, endpoint InvokableToolCallEndpoint, _ *ToolContext) (InvokableToolCallEndpoint, error) { +func (b *TypedBaseChatModelAgentMiddleware[M]) WrapInvokableToolCall(_ context.Context, endpoint InvokableToolCallEndpoint, _ *ToolContext) (InvokableToolCallEndpoint, error) { return endpoint, nil } -func (b *BaseChatModelAgentMiddleware) WrapStreamableToolCall(_ context.Context, endpoint StreamableToolCallEndpoint, _ *ToolContext) (StreamableToolCallEndpoint, error) { +func (b *TypedBaseChatModelAgentMiddleware[M]) WrapStreamableToolCall(_ context.Context, endpoint StreamableToolCallEndpoint, _ *ToolContext) (StreamableToolCallEndpoint, error) { return endpoint, nil } -func (b *BaseChatModelAgentMiddleware) WrapEnhancedInvokableToolCall(_ context.Context, endpoint EnhancedInvokableToolCallEndpoint, _ *ToolContext) (EnhancedInvokableToolCallEndpoint, error) { +func (b *TypedBaseChatModelAgentMiddleware[M]) WrapEnhancedInvokableToolCall(_ context.Context, endpoint EnhancedInvokableToolCallEndpoint, _ *ToolContext) (EnhancedInvokableToolCallEndpoint, error) { return endpoint, nil } -func (b *BaseChatModelAgentMiddleware) WrapEnhancedStreamableToolCall(_ context.Context, endpoint EnhancedStreamableToolCallEndpoint, _ *ToolContext) (EnhancedStreamableToolCallEndpoint, error) { +func (b *TypedBaseChatModelAgentMiddleware[M]) WrapEnhancedStreamableToolCall(_ context.Context, endpoint EnhancedStreamableToolCallEndpoint, _ *ToolContext) (EnhancedStreamableToolCallEndpoint, error) { return endpoint, nil } -func (b *BaseChatModelAgentMiddleware) WrapModel(_ context.Context, m model.BaseChatModel, _ *ModelContext) (model.BaseChatModel, error) { +func (b *TypedBaseChatModelAgentMiddleware[M]) WrapModel(_ context.Context, m model.BaseModel[M], _ *ModelContext) (model.BaseModel[M], error) { return m, nil } -func (b *BaseChatModelAgentMiddleware) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) { +func (b *TypedBaseChatModelAgentMiddleware[M]) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) { return ctx, runCtx, nil } -func (b *BaseChatModelAgentMiddleware) BeforeModelRewriteState(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) { +func (b *TypedBaseChatModelAgentMiddleware[M]) BeforeModelRewriteState(ctx context.Context, state *TypedChatModelAgentState[M], mc *ModelContext) (context.Context, *TypedChatModelAgentState[M], error) { return ctx, state, nil } -func (b *BaseChatModelAgentMiddleware) AfterModelRewriteState(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) { +func (b *TypedBaseChatModelAgentMiddleware[M]) AfterModelRewriteState(ctx context.Context, state *TypedChatModelAgentState[M], mc *ModelContext) (context.Context, *TypedChatModelAgentState[M], error) { return ctx, state, nil } -func (b *BaseChatModelAgentMiddleware) AfterToolCallsRewriteState(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) { +func (b *TypedBaseChatModelAgentMiddleware[M]) AfterToolCallsRewriteState(ctx context.Context, state *TypedChatModelAgentState[M], tc *ToolCallsContext) (context.Context, *TypedChatModelAgentState[M], error) { return ctx, state, nil } +func processTypedState(ctx context.Context, fn func(extra map[string]any) map[string]any) error { + runCtx := getRunCtx(ctx) + if runCtx != nil && runCtx.AgenticRootInput != nil { + return compose.ProcessState(ctx, func(_ context.Context, st *typedState[*schema.AgenticMessage]) error { + st.Extra = fn(st.Extra) + return nil + }) + } + return compose.ProcessState(ctx, func(_ context.Context, st *typedState[*schema.Message]) error { + st.Extra = fn(st.Extra) + return nil + }) +} + // SetRunLocalValue sets a key-value pair that persists for the duration of the current agent Run() invocation. // The value is scoped to this specific execution and is not shared across different Run() calls or agent instances. // @@ -287,12 +307,12 @@ func SetRunLocalValue(ctx context.Context, key string, value any) error { return err } - err := compose.ProcessState(ctx, func(_ context.Context, st *State) error { - if st.Extra == nil { - st.Extra = make(map[string]any) + err := processTypedState(ctx, func(extra map[string]any) map[string]any { + if extra == nil { + extra = make(map[string]any) } - st.Extra[key] = value - return nil + extra[key] = value + return extra }) if err != nil { return fmt.Errorf("SetRunLocalValue failed: must be called within a ChatModelAgent Run() or Resume() execution context: %w", err) @@ -313,11 +333,11 @@ func SetRunLocalValue(ctx context.Context, key string, value any) error { func GetRunLocalValue(ctx context.Context, key string) (any, bool, error) { var val any var found bool - err := compose.ProcessState(ctx, func(_ context.Context, st *State) error { - if st.Extra != nil { - val, found = st.Extra[key] + err := processTypedState(ctx, func(extra map[string]any) map[string]any { + if extra != nil { + val, found = extra[key] } - return nil + return extra }) if err != nil { return nil, false, fmt.Errorf("GetRunLocalValue failed: must be called within a ChatModelAgent Run() or Resume() execution context: %w", err) @@ -330,11 +350,11 @@ func GetRunLocalValue(ctx context.Context, key string) (any, bool, error) { // This function can only be called from within a ChatModelAgentMiddleware during agent execution. // Returns an error if called outside of an agent execution context. func DeleteRunLocalValue(ctx context.Context, key string) error { - err := compose.ProcessState(ctx, func(_ context.Context, st *State) error { - if st.Extra != nil { - delete(st.Extra, key) + err := processTypedState(ctx, func(extra map[string]any) map[string]any { + if extra != nil { + delete(extra, key) } - return nil + return extra }) if err != nil { return fmt.Errorf("DeleteRunLocalValue failed: must be called within a ChatModelAgent Run() or Resume() execution context: %w", err) @@ -342,21 +362,31 @@ func DeleteRunLocalValue(ctx context.Context, key string) error { return nil } -// SendEvent sends a custom AgentEvent to the event stream during agent execution. -// This allows ChatModelAgentMiddleware implementations to emit custom events that will be +// TypedSendEvent sends a custom TypedAgentEvent to the event stream during agent execution. +// This allows TypedChatModelAgentMiddleware implementations to emit custom events that will be // received by the caller iterating over the agent's event stream. // -// This function can only be called from within a ChatModelAgentMiddleware during agent execution. +// This function can only be called from within a TypedChatModelAgentMiddleware during agent execution. // Returns an error if called outside of an agent execution context. -func SendEvent(ctx context.Context, event *AgentEvent) error { - execCtx := getChatModelAgentExecCtx(ctx) +func TypedSendEvent[M messageType](ctx context.Context, event *TypedAgentEvent[M]) error { + execCtx := getTypedChatModelAgentExecCtx[M](ctx) if execCtx == nil || execCtx.generator == nil { - return fmt.Errorf("SendEvent failed: must be called within a ChatModelAgent Run() or Resume() execution context") + return fmt.Errorf("TypedSendEvent failed: must be called within a ChatModelAgent Run() or Resume() execution context") } execCtx.send(event) return nil } +// SendEvent sends a custom AgentEvent to the event stream during agent execution. +// This allows ChatModelAgentMiddleware implementations to emit custom events that will be +// received by the caller iterating over the agent's event stream. +// +// This function can only be called from within a ChatModelAgentMiddleware during agent execution. +// Returns an error if called outside of an agent execution context. +func SendEvent(ctx context.Context, event *AgentEvent) error { + return TypedSendEvent[*schema.Message](ctx, event) +} + // checkGobEncodability probes whether the value can be gob-encoded as part of // a map[string]any, which is exactly how State.Extra is serialized during // checkpoint. This catches unregistered types early at Set time, rather than diff --git a/adk/instruction.go b/adk/instruction.go index f02888ed2..8794aff59 100644 --- a/adk/instruction.go +++ b/adk/instruction.go @@ -45,7 +45,7 @@ When transferring: OUTPUT ONLY THE FUNCTION CALL` agentDescriptionTplChinese = "\n- Agent 名字: %s\n Agent 描述: %s" ) -func genTransferToAgentInstruction(ctx context.Context, agents []Agent) string { +func genTransferToAgentInstruction[M messageType](ctx context.Context, agents []TypedAgent[M]) string { tpl := internal.SelectPrompt(internal.I18nPrompts{ English: agentDescriptionTpl, Chinese: agentDescriptionTplChinese, diff --git a/adk/interface.go b/adk/interface.go index e1f17eca7..0a4c0bc5c 100644 --- a/adk/interface.go +++ b/adk/interface.go @@ -32,36 +32,80 @@ import ( // Use this to filter callback events to only agent-related events. const ComponentOfAgent components.Component = "Agent" +// ComponentOfAgenticAgent is the component type identifier for ADK agents +// that use *schema.AgenticMessage in callbacks. +const ComponentOfAgenticAgent components.Component = "AgenticAgent" + +// messageType is the sealed type constraint for message types used in ADK. +// Only *schema.Message and *schema.AgenticMessage satisfy this constraint. +// External packages cannot add new types to this union; all generic functions +// in ADK use exhaustive type switches on these two types. +type messageType interface { + *schema.Message | *schema.AgenticMessage +} + type Message = *schema.Message type MessageStream = *schema.StreamReader[Message] -type MessageVariant struct { +type AgenticMessage = *schema.AgenticMessage +type AgenticMessageStream = *schema.StreamReader[AgenticMessage] + +// isNilMessage checks whether a generic message value is nil. +// Direct `msg == nil` does not compile for generic pointer types in Go; +// the canonical workaround is to compare through the `any` interface. +func isNilMessage[M messageType](msg M) bool { + var zero M + return any(msg) == any(zero) +} + +// TypedMessageVariant represents a message output from an agent event. +// It carries either a complete message or a streaming reader, along with +// metadata describing the event's origin. +// +// Role and ToolName are only meaningful for *schema.Message events. For +// *schema.AgenticMessage events (created via EventFromAgenticMessage), these +// fields are always zero-valued because AgenticMessage carries tool results as +// ContentBlocks within the message itself and does not support agent transfer. +// +// For *schema.Message events, Role and ToolName exist independently of the inner +// Message because in streaming mode (IsStreaming=true, Message=nil), the message +// has not materialized yet and the consumer needs metadata without consuming the stream. +type TypedMessageVariant[M messageType] struct { IsStreaming bool - Message Message - MessageStream MessageStream - // message role: Assistant or Tool + Message M + MessageStream *schema.StreamReader[M] + + // Role indicates the origin of this event within the agent's ReAct loop. + // Only meaningful for *schema.Message events: + // - schema.Assistant: the event carries model output (generation or stream). + // - schema.Tool: the event carries a tool execution result. + // Always zero-valued for *schema.AgenticMessage events; use AgenticRole instead. Role schema.RoleType - // only used when Role is Tool + + // AgenticRole indicates the role of the agentic message (assistant, user, system). + // Only meaningful for *schema.AgenticMessage events. + // In streaming mode, this is available before consuming the stream. + // Always zero-valued for *schema.Message events; use Role instead. + AgenticRole schema.AgenticRoleType + + // ToolName is the name of the tool that produced this event. + // Only meaningful for *schema.Message events: non-empty when Role == schema.Tool. + // In streaming mode, this is the only way to identify the source tool before + // the stream is consumed. + // Always empty for *schema.AgenticMessage events. ToolName string } -// EventFromMessage wraps a message or stream into an AgentEvent with role metadata. -func EventFromMessage(msg Message, msgStream MessageStream, - role schema.RoleType, toolName string) *AgentEvent { - return &AgentEvent{ - Output: &AgentOutput{ - MessageOutput: &MessageVariant{ - IsStreaming: msgStream != nil, - Message: msg, - MessageStream: msgStream, - Role: role, - ToolName: toolName, - }, - }, +func (mv *TypedMessageVariant[M]) GetMessage() (M, error) { + if mv.IsStreaming { + return concatMessageStream(mv.MessageStream) } + return mv.Message, nil } +type MessageVariant = TypedMessageVariant[*schema.Message] + type messageVariantSerialization struct { IsStreaming bool Message Message @@ -70,7 +114,36 @@ type messageVariantSerialization struct { ToolName string } -func (mv *MessageVariant) GobEncode() ([]byte, error) { +type agenticMessageVariantSerialization struct { + IsStreaming bool + Message *schema.AgenticMessage + MessageStream *schema.AgenticMessage + Role schema.RoleType + AgenticRole schema.AgenticRoleType + ToolName string +} + +func (mv *TypedMessageVariant[M]) GobEncode() ([]byte, error) { + if mvMsg, ok := any(mv).(*TypedMessageVariant[*schema.Message]); ok { + return gobEncodeMessageVariant(mvMsg) + } + if mvAgentic, ok := any(mv).(*TypedMessageVariant[*schema.AgenticMessage]); ok { + return gobEncodeAgenticMessageVariant(mvAgentic) + } + return nil, fmt.Errorf("gob encoding not supported for this message type") +} + +func (mv *TypedMessageVariant[M]) GobDecode(b []byte) error { + if mvMsg, ok := any(mv).(*TypedMessageVariant[*schema.Message]); ok { + return gobDecodeMessageVariant(mvMsg, b) + } + if mvAgentic, ok := any(mv).(*TypedMessageVariant[*schema.AgenticMessage]); ok { + return gobDecodeAgenticMessageVariant(mvAgentic, b) + } + return fmt.Errorf("gob decoding not supported for this message type") +} + +func gobEncodeMessageVariant(mv *TypedMessageVariant[*schema.Message]) ([]byte, error) { s := &messageVariantSerialization{ IsStreaming: mv.IsStreaming, Message: mv.Message, @@ -103,7 +176,7 @@ func (mv *MessageVariant) GobEncode() ([]byte, error) { return buf.Bytes(), nil } -func (mv *MessageVariant) GobDecode(b []byte) error { +func gobDecodeMessageVariant(mv *TypedMessageVariant[*schema.Message], b []byte) error { s := &messageVariantSerialization{} err := gob.NewDecoder(bytes.NewReader(b)).Decode(s) if err != nil { @@ -119,19 +192,120 @@ func (mv *MessageVariant) GobDecode(b []byte) error { return nil } -func (mv *MessageVariant) GetMessage() (Message, error) { - var message Message +func gobEncodeAgenticMessageVariant(mv *TypedMessageVariant[*schema.AgenticMessage]) ([]byte, error) { + s := &agenticMessageVariantSerialization{ + IsStreaming: mv.IsStreaming, + Message: mv.Message, + Role: mv.Role, + AgenticRole: mv.AgenticRole, + ToolName: mv.ToolName, + } if mv.IsStreaming { - var err error - message, err = schema.ConcatMessageStream(mv.MessageStream) + var messages []*schema.AgenticMessage + for { + frame, err := mv.MessageStream.Recv() + if err == io.EOF { + break + } + if err != nil { + return nil, fmt.Errorf("error receiving agentic message stream: %w", err) + } + messages = append(messages, frame) + } + m, err := schema.ConcatAgenticMessages(messages) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to encode agentic message: cannot concat message stream: %w", err) } + s.MessageStream = m + } + buf := &bytes.Buffer{} + err := gob.NewEncoder(buf).Encode(s) + if err != nil { + return nil, fmt.Errorf("failed to gob encode agentic message variant: %w", err) + } + return buf.Bytes(), nil +} + +func gobDecodeAgenticMessageVariant(mv *TypedMessageVariant[*schema.AgenticMessage], b []byte) error { + s := &agenticMessageVariantSerialization{} + err := gob.NewDecoder(bytes.NewReader(b)).Decode(s) + if err != nil { + return fmt.Errorf("failed to decode agentic message variant: %w", err) + } + mv.IsStreaming = s.IsStreaming + mv.Message = s.Message + mv.Role = s.Role + mv.AgenticRole = s.AgenticRole + mv.ToolName = s.ToolName + if s.MessageStream != nil { + mv.MessageStream = schema.StreamReaderFromArray([]*schema.AgenticMessage{s.MessageStream}) + } + return nil +} + +// typedEventFromMessage creates a TypedAgentEvent containing the given message and optional stream. +func typedEventFromMessage[M messageType](msg M, msgStream *schema.StreamReader[M], + role schema.RoleType, toolName string) *TypedAgentEvent[M] { + return &TypedAgentEvent[M]{ + Output: &TypedAgentOutput[M]{ + MessageOutput: &TypedMessageVariant[M]{ + IsStreaming: msgStream != nil, + Message: msg, + MessageStream: msgStream, + Role: role, + ToolName: toolName, + }, + }, + } +} + +// typedModelOutputEvent creates a model-output event for the generic path. +// For *schema.Message, Role is set to schema.Assistant. +// For *schema.AgenticMessage, AgenticRole is set to schema.AgenticRoleTypeAssistant. +func typedModelOutputEvent[M messageType](msg M, msgStream *schema.StreamReader[M]) *TypedAgentEvent[M] { + var role schema.RoleType + var agenticRole schema.AgenticRoleType + var zero M + if _, ok := any(zero).(*schema.Message); ok { + role = schema.Assistant } else { - message = mv.Message + agenticRole = schema.AgenticRoleTypeAssistant } + event := typedEventFromMessage(msg, msgStream, role, "") + event.Output.MessageOutput.AgenticRole = agenticRole + return event +} + +// EventFromMessage creates an AgentEvent containing the given message and optional stream. +// +// role identifies the origin of this event: +// - schema.Assistant: model output (generation or stream). +// - schema.Tool: tool execution result; toolName must be non-empty. +// +// For *schema.AgenticMessage events, use EventFromAgenticMessage instead. +func EventFromMessage(msg Message, msgStream *schema.StreamReader[Message], + role schema.RoleType, toolName string) *AgentEvent { + return typedEventFromMessage(msg, msgStream, role, toolName) +} - return message, nil +// EventFromAgenticMessage creates a TypedAgentEvent for the AgenticMessage path. +// Unlike EventFromMessage, it does not require role or toolName parameters because +// AgenticMessage carries tool results as ContentBlocks within the message itself, +// and does not support agent transfer. +// +// agenticRole identifies the role of the message (e.g. schema.AgenticRoleTypeAssistant). +// In streaming mode, the role is available on the event before consuming the stream. +func EventFromAgenticMessage(msg AgenticMessage, msgStream AgenticMessageStream, agenticRole schema.AgenticRoleType) *TypedAgentEvent[AgenticMessage] { + return &TypedAgentEvent[AgenticMessage]{ + Output: &TypedAgentOutput[AgenticMessage]{ + MessageOutput: &TypedMessageVariant[AgenticMessage]{ + IsStreaming: msgStream != nil, + Message: msg, + MessageStream: msgStream, + AgenticRole: agenticRole, + }, + }, + } } // TransferToAgentAction represents a transfer-to-agent action. @@ -143,12 +317,14 @@ type TransferToAgentAction struct { DestAgentName string } -type AgentOutput struct { - MessageOutput *MessageVariant +type TypedAgentOutput[M messageType] struct { + MessageOutput *TypedMessageVariant[M] CustomizedOutput any } +type AgentOutput = TypedAgentOutput[*schema.Message] + // NewTransferToAgentAction creates an action to transfer to the specified agent. // // NOT RECOMMENDED: Agent transfer with full context sharing between agents has not proven @@ -238,8 +414,9 @@ type runStepSerialization struct { AgentName string } -// AgentEvent CheckpointSchema: persisted via serialization.RunCtx (gob). -type AgentEvent struct { +// TypedAgentEvent represents a single event emitted during agent execution. +// CheckpointSchema: persisted via serialization.RunCtx (gob). +type TypedAgentEvent[M messageType] struct { AgentName string // RunPath represents the execution path from root agent to the current event source. @@ -250,20 +427,30 @@ type AgentEvent struct { // AgentTool or DeepAgent, RunPath is trivial. Consider those patterns instead. RunPath []RunStep - Output *AgentOutput + Output *TypedAgentOutput[M] Action *AgentAction Err error } -type AgentInput struct { - Messages []Message +// AgentEvent is the default event type using *schema.Message. +type AgentEvent = TypedAgentEvent[*schema.Message] + +type TypedAgentInput[M messageType] struct { + Messages []M EnableStreaming bool } -//go:generate mockgen -destination ../internal/mock/adk/Agent_mock.go --package adk -source interface.go -type Agent interface { +type AgentInput = TypedAgentInput[*schema.Message] + +// TypedAgent is the base agent interface parameterized by message type. +// +// For M = *schema.Message, the full ADK feature set is supported (multi-agent +// orchestration, cancel monitoring, retry, flowAgent). +// For M = *schema.AgenticMessage, single-agent execution works but cancel +// monitoring on the model stream and retry are not yet wired. +type TypedAgent[M messageType] interface { Name(ctx context.Context) string Description(ctx context.Context) string @@ -273,9 +460,12 @@ type Agent interface { // the MessageStream MUST be exclusive and safe to be received directly. // NOTE: it's recommended to use SetAutomaticClose() on the MessageStream of AgentEvents emitted by AsyncIterator, // so that even the events are not processed, the MessageStream can still be closed. - Run(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] + Run(ctx context.Context, input *TypedAgentInput[M], options ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[M]] } +//go:generate mockgen -destination ../internal/mock/adk/Agent_mock.go --package adk github.com/cloudwego/eino/adk Agent,ResumableAgent +type Agent = TypedAgent[*schema.Message] + // OnSubAgents is the interface for agents that support sub-agent registration and transfer. // // NOT RECOMMENDED: Agent transfer with full context sharing between agents has not proven @@ -288,8 +478,42 @@ type OnSubAgents interface { OnDisallowTransferToParent(ctx context.Context) error } -type ResumableAgent interface { - Agent +type TypedResumableAgent[M messageType] interface { + TypedAgent[M] + + Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[M]] +} + +type ResumableAgent = TypedResumableAgent[*schema.Message] - Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] +func concatMessageStream[M messageType](stream *schema.StreamReader[M]) (M, error) { + var zero M + switch s := any(stream).(type) { + case *schema.StreamReader[*schema.Message]: + result, err := schema.ConcatMessageStream(s) + if err != nil { + return zero, err + } + return any(result).(M), nil + case *schema.StreamReader[*schema.AgenticMessage]: + defer s.Close() + var msgs []*schema.AgenticMessage + for { + frame, err := s.Recv() + if err == io.EOF { + break + } + if err != nil { + return zero, err + } + msgs = append(msgs, frame) + } + result, err := schema.ConcatAgenticMessages(msgs) + if err != nil { + return zero, err + } + return any(result).(M), nil + default: + panic("unreachable: unknown messageType") + } } diff --git a/adk/interrupt.go b/adk/interrupt.go index fce09d4cf..35bfb5198 100644 --- a/adk/interrupt.go +++ b/adk/interrupt.go @@ -54,11 +54,9 @@ type InterruptInfo struct { InterruptContexts []*InterruptCtx } -// Interrupt creates a basic interrupt action. -// This is used when an agent needs to pause its execution to request external input or intervention, -// but does not need to save any internal state to be restored upon resumption. -// The `info` parameter is user-facing data that describes the reason for the interrupt. -func Interrupt(ctx context.Context, info any) *AgentEvent { +// TypedInterrupt creates a typed interrupt event that pauses execution to request external input. +// It is the generic counterpart of Interrupt; see Interrupt for full documentation. +func TypedInterrupt[M messageType](ctx context.Context, info any) *TypedAgentEvent[M] { var rp []RunStep rCtx := getRunCtx(ctx) if rCtx != nil { @@ -68,12 +66,12 @@ func Interrupt(ctx context.Context, info any) *AgentEvent { is, err := core.Interrupt(ctx, info, nil, nil, core.WithLayerPayload(rp)) if err != nil { - return &AgentEvent{Err: err} + return &TypedAgentEvent[M]{Err: err} } contexts := core.ToInterruptContexts(is, allowedAddressSegmentTypes) - return &AgentEvent{ + return &TypedAgentEvent[M]{ Action: &AgentAction{ Interrupted: &InterruptInfo{ InterruptContexts: contexts, @@ -83,11 +81,17 @@ func Interrupt(ctx context.Context, info any) *AgentEvent { } } -// StatefulInterrupt creates an interrupt action that also saves the agent's internal state. -// This is used when an agent has internal state that must be restored for it to continue correctly. -// The `info` parameter is user-facing data describing the interrupt. -// The `state` parameter is the agent's internal state object, which will be serialized and stored. -func StatefulInterrupt(ctx context.Context, info any, state any) *AgentEvent { +// Interrupt creates a basic interrupt action. +// This is used when an agent needs to pause its execution to request external input or intervention, +// but does not need to save any internal state to be restored upon resumption. +// The `info` parameter is user-facing data that describes the reason for the interrupt. +func Interrupt(ctx context.Context, info any) *AgentEvent { + return TypedInterrupt[*schema.Message](ctx, info) +} + +// TypedStatefulInterrupt creates a typed interrupt event that also saves the agent's internal state. +// It is the generic counterpart of StatefulInterrupt; see StatefulInterrupt for full documentation. +func TypedStatefulInterrupt[M messageType](ctx context.Context, info any, state any) *TypedAgentEvent[M] { var rp []RunStep rCtx := getRunCtx(ctx) if rCtx != nil { @@ -97,12 +101,12 @@ func StatefulInterrupt(ctx context.Context, info any, state any) *AgentEvent { is, err := core.Interrupt(ctx, info, state, nil, core.WithLayerPayload(rp)) if err != nil { - return &AgentEvent{Err: err} + return &TypedAgentEvent[M]{Err: err} } contexts := core.ToInterruptContexts(is, allowedAddressSegmentTypes) - return &AgentEvent{ + return &TypedAgentEvent[M]{ Action: &AgentAction{ Interrupted: &InterruptInfo{ InterruptContexts: contexts, @@ -112,14 +116,18 @@ func StatefulInterrupt(ctx context.Context, info any, state any) *AgentEvent { } } -// CompositeInterrupt creates an interrupt action for a workflow agent. -// It combines the interrupts from one or more of its sub-agents into a single, cohesive interrupt. -// This is used by workflow agents (like Sequential, Parallel, or Loop) to propagate interrupts from their children. -// The `info` parameter is user-facing data describing the workflow's own reason for interrupting. -// The `state` parameter is the workflow agent's own state (e.g., the index of the sub-agent that was interrupted). -// The `subInterruptSignals` is a variadic list of the InterruptSignal objects from the interrupted sub-agents. -func CompositeInterrupt(ctx context.Context, info any, state any, - subInterruptSignals ...*InterruptSignal) *AgentEvent { +// StatefulInterrupt creates an interrupt action that also saves the agent's internal state. +// This is used when an agent has internal state that must be restored for it to continue correctly. +// The `info` parameter is user-facing data describing the interrupt. +// The `state` parameter is the agent's internal state object, which will be serialized and stored. +func StatefulInterrupt(ctx context.Context, info any, state any) *AgentEvent { + return TypedStatefulInterrupt[*schema.Message](ctx, info, state) +} + +// TypedCompositeInterrupt creates a typed interrupt event that aggregates sub-interrupt signals. +// It is the generic counterpart of CompositeInterrupt; see CompositeInterrupt for full documentation. +func TypedCompositeInterrupt[M messageType](ctx context.Context, info any, state any, + subInterruptSignals ...*InterruptSignal) *TypedAgentEvent[M] { var rp []RunStep rCtx := getRunCtx(ctx) if rCtx != nil { @@ -129,12 +137,12 @@ func CompositeInterrupt(ctx context.Context, info any, state any, is, err := core.Interrupt(ctx, info, state, subInterruptSignals, core.WithLayerPayload(rp)) if err != nil { - return &AgentEvent{Err: err} + return &TypedAgentEvent[M]{Err: err} } contexts := core.ToInterruptContexts(is, allowedAddressSegmentTypes) - return &AgentEvent{ + return &TypedAgentEvent[M]{ Action: &AgentAction{ Interrupted: &InterruptInfo{ InterruptContexts: contexts, @@ -144,6 +152,12 @@ func CompositeInterrupt(ctx context.Context, info any, state any, } } +// CompositeInterrupt creates an interrupt event that aggregates sub-interrupt signals. +func CompositeInterrupt(ctx context.Context, info any, state any, + subInterruptSignals ...*InterruptSignal) *AgentEvent { + return TypedCompositeInterrupt[*schema.Message](ctx, info, state, subInterruptSignals...) +} + // Address represents the unique, hierarchical address of a component within an execution. // It is a slice of AddressSegments, where each segment represents one level of nesting. // This is a type alias for core.Address. See the core package for more details. @@ -202,9 +216,9 @@ type serialization struct { InterruptID2State map[string]core.InterruptState } -func (r *Runner) loadCheckPoint(ctx context.Context, checkpointID string) ( +func runnerLoadCheckPointImpl(store CheckPointStore, ctx context.Context, checkpointID string) ( context.Context, *runContext, *ResumeInfo, error) { - data, existed, err := r.store.Get(ctx, checkpointID) + data, existed, err := store.Get(ctx, checkpointID) if err != nil { return nil, nil, nil, fmt.Errorf("failed to get checkpoint from store: %w", err) } @@ -266,13 +280,15 @@ func preprocessADKCheckpoint(data []byte) []byte { []byte(lenPrefixedCompatName)) } -func (r *Runner) saveCheckPoint( +func runnerSaveCheckPointImpl( + enableStreaming bool, + store CheckPointStore, ctx context.Context, key string, info *InterruptInfo, is *core.InterruptSignal, ) error { - if r.store == nil { + if store == nil { return nil } @@ -286,12 +302,12 @@ func (r *Runner) saveCheckPoint( Info: info, InterruptID2Address: id2Addr, InterruptID2State: id2State, - EnableStreaming: r.enableStreaming, + EnableStreaming: enableStreaming, }) if err != nil { return fmt.Errorf("failed to encode checkpoint: %w", err) } - return r.store.Set(ctx, key, buf.Bytes()) + return store.Set(ctx, key, buf.Bytes()) } const bridgeCheckpointID = "adk_react_mock_key" diff --git a/adk/prebuilt/planexecute/utils.go b/adk/prebuilt/planexecute/utils_test.go similarity index 100% rename from adk/prebuilt/planexecute/utils.go rename to adk/prebuilt/planexecute/utils_test.go diff --git a/adk/react.go b/adk/react.go index 07fdbde9a..fdd224f74 100644 --- a/adk/react.go +++ b/adk/react.go @@ -31,14 +31,8 @@ import ( // ErrExceedMaxIterations indicates the agent reached the maximum iterations limit. var ErrExceedMaxIterations = errors.New("exceeds max iterations") -// State holds agent runtime state including messages and user-extensible storage. -// -// Deprecated: This type will be unexported in v1.0.0. Use ChatModelAgentState -// in HandlerMiddleware and AgentMiddleware callbacks instead. Direct use of -// compose.ProcessState[*State] is discouraged and will stop working in v1.0.0; -// use the handler APIs instead. -type State struct { - Messages []Message +type typedState[M messageType] struct { + Messages []M Extra map[string]any // Internal fields below - do not access directly. @@ -48,10 +42,18 @@ type State struct { ToolGenActions map[string]*AgentAction AgentName string RemainingIterations int - ReturnDirectlyEvent *AgentEvent + ReturnDirectlyEvent *TypedAgentEvent[M] RetryAttempt int } +// State is the internal state of the ChatModelAgent. +// +// Deprecated: State is exported only for checkpoint backward compatibility. +// Do not use it directly. +type State = typedState[*schema.Message] + +type agenticState = typedState[*schema.AgenticMessage] + const ( stateGobNameV07 = "_eino_adk_react_state" @@ -77,50 +79,57 @@ func init() { schema.RegisterName[*State](stateGobNameV07) schema.RegisterName[*stateV080](stateGobNameV080) - // the following two lines of registration mainly for backward compatibility - // when decoding checkpoints created by v0.8.0 - v0.8.3 + schema.RegisterName[*typedState[*schema.AgenticMessage]]("_eino_adk_agentic_state") + schema.RegisterName[*TypedAgentEvent[*schema.AgenticMessage]]("_eino_adk_agentic_event") + + // backward compatibility when decoding checkpoints created by v0.8.0 - v0.8.3 gob.Register(&AgentEvent{}) gob.Register(int(0)) + + schema.RegisterName[*TypedAgentInput[*schema.AgenticMessage]]("_eino_adk_agentic_agent_input") + schema.RegisterName[*typedAgentEventWrapper[*schema.AgenticMessage]]("_eino_adk_agentic_event_wrapper") + schema.RegisterName[*[]*typedAgentEventWrapper[*schema.AgenticMessage]]("_eino_adk_agentic_event_wrapper_slice") schema.RegisterName[*reactInput]("_eino_adk_react_input") + schema.RegisterName[*agenticReactInput]("_eino_adk_agentic_react_input") } -func (s *State) getReturnDirectlyEvent() *AgentEvent { +func (s *typedState[M]) getReturnDirectlyEvent() *TypedAgentEvent[M] { return s.ReturnDirectlyEvent } -func (s *State) setReturnDirectlyEvent(event *AgentEvent) { +func (s *typedState[M]) setReturnDirectlyEvent(event *TypedAgentEvent[M]) { s.ReturnDirectlyEvent = event } -func (s *State) getRetryAttempt() int { +func (s *typedState[M]) getRetryAttempt() int { return s.RetryAttempt } -func (s *State) setRetryAttempt(attempt int) { +func (s *typedState[M]) setRetryAttempt(attempt int) { s.RetryAttempt = attempt } -func (s *State) getReturnDirectlyToolCallID() string { +func (s *typedState[M]) getReturnDirectlyToolCallID() string { return s.ReturnDirectlyToolCallID } -func (s *State) setReturnDirectlyToolCallID(id string) { +func (s *typedState[M]) setReturnDirectlyToolCallID(id string) { s.ReturnDirectlyToolCallID = id s.HasReturnDirectly = id != "" } -func (s *State) getToolGenActions() map[string]*AgentAction { +func (s *typedState[M]) getToolGenActions() map[string]*AgentAction { return s.ToolGenActions } -func (s *State) setToolGenAction(key string, action *AgentAction) { +func (s *typedState[M]) setToolGenAction(key string, action *AgentAction) { if s.ToolGenActions == nil { s.ToolGenActions = make(map[string]*AgentAction) } s.ToolGenActions[key] = action } -func (s *State) popToolGenAction(key string) *AgentAction { +func (s *typedState[M]) popToolGenAction(key string) *AgentAction { if s.ToolGenActions == nil { return nil } @@ -129,15 +138,15 @@ func (s *State) popToolGenAction(key string) *AgentAction { return action } -func (s *State) getRemainingIterations() int { +func (s *typedState[M]) getRemainingIterations() int { return s.RemainingIterations } -func (s *State) setRemainingIterations(iterations int) { +func (s *typedState[M]) setRemainingIterations(iterations int) { s.RemainingIterations = iterations } -func (s *State) decrementRemainingIterations() { +func (s *typedState[M]) decrementRemainingIterations() { current := s.getRemainingIterations() s.RemainingIterations = current - 1 } @@ -241,13 +250,11 @@ type reactInput struct { Messages []Message } -type reactConfig struct { - // model is the chat model used by the react graph. - // Tools are configured via model.WithTools call option, not the WithTools method. - model model.BaseChatModel +type typedReactConfig[M messageType] struct { + model model.BaseModel[M] toolsConfig *compose.ToolsNodeConfig - modelWrapperConf *modelWrapperConfig + modelWrapperConf *typedModelWrapperConfig[M] toolsReturnDirectly map[string]bool @@ -258,6 +265,8 @@ type reactConfig struct { cancelCtx *cancelContext } +type reactConfig = typedReactConfig[*schema.Message] + func genToolInfos(ctx context.Context, config *compose.ToolsNodeConfig) ([]*schema.ToolInfo, error) { toolInfos := make([]*schema.ToolInfo, 0, len(config.Tools)) for _, t := range config.Tools { @@ -360,7 +369,7 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { toolPreHandle := func(ctx context.Context, _ Message, st *State) (Message, error) { input := st.Messages[len(st.Messages)-1] returnDirectly := config.toolsReturnDirectly - if execCtx := getChatModelAgentExecCtx(ctx); execCtx != nil && len(execCtx.runtimeReturnDirectly) > 0 { + if execCtx := getTypedChatModelAgentExecCtx[*schema.Message](ctx); execCtx != nil && len(execCtx.runtimeReturnDirectly) > 0 { returnDirectly = execCtx.runtimeReturnDirectly } if len(returnDirectly) > 0 { @@ -375,7 +384,7 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { } toolPostHandle := func(ctx context.Context, out *schema.StreamReader[[]*schema.Message], st *State) (*schema.StreamReader[[]*schema.Message], error) { if event := st.getReturnDirectlyEvent(); event != nil { - getChatModelAgentExecCtx(ctx).send(event) + getTypedChatModelAgentExecCtx[*schema.Message](ctx).send(event) st.setReturnDirectlyEvent(nil) } return out, nil @@ -501,3 +510,218 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { return g, nil } + +type agenticReactInput struct { + Messages []*schema.AgenticMessage +} + +type agenticReactConfig = typedReactConfig[*schema.AgenticMessage] + +type agenticReactGraph = *compose.Graph[*agenticReactInput, *schema.AgenticMessage] + +func getAgenticReturnDirectlyToolCallID(ctx context.Context) (string, bool) { + var toolCallID string + _ = compose.ProcessState(ctx, func(_ context.Context, st *agenticState) error { + toolCallID = st.getReturnDirectlyToolCallID() + return nil + }) + return toolCallID, toolCallID != "" +} + +func genAgenticReactState(config *agenticReactConfig) func(ctx context.Context) *agenticState { + return func(ctx context.Context) *agenticState { + st := &agenticState{ + AgentName: config.agentName, + } + maxIter := 20 + if config.maxIterations > 0 { + maxIter = config.maxIterations + } + st.setRemainingIterations(maxIter) + return st + } +} + +func agenticMessageHasToolCalls(msg *schema.AgenticMessage) bool { + if msg == nil { + return false + } + for _, block := range msg.ContentBlocks { + if block != nil && block.Type == schema.ContentBlockTypeFunctionToolCall && block.FunctionToolCall != nil { + return true + } + } + return false +} + +func newAgenticReact(ctx context.Context, config *agenticReactConfig) (agenticReactGraph, error) { + const ( + initNode_ = "Init" + chatModel_ = "ChatModel" + cancelCheckNode_ = "CancelCheck" + toolNode_ = "ToolNode" + afterToolCallsNode_ = "AfterToolCalls" + afterToolCallsCancelCheckNode_ = "AfterToolCallsCancelCheck" + ) + + cancelCtx := config.cancelCtx + g := compose.NewGraph[*agenticReactInput, *schema.AgenticMessage]( + compose.WithGenLocalState(genAgenticReactState(config))) + _ = g.AddLambdaNode(initNode_, compose.InvokableLambda(func(ctx context.Context, input *agenticReactInput) ([]*schema.AgenticMessage, error) { + _ = compose.ProcessState(ctx, func(_ context.Context, st *agenticState) error { + st.Messages = append(st.Messages, input.Messages...) + return nil + }) + return input.Messages, nil + }), compose.WithNodeName(initNode_)) + + var wrappedModel model.AgenticModel = config.model + if config.modelWrapperConf != nil { + wrappedModel = buildModelWrappers(config.model, config.modelWrapperConf) + } + + toolsNode, err := compose.NewAgenticToolsNode(ctx, config.toolsConfig) + if err != nil { + return nil, err + } + + _ = g.AddAgenticModelNode(chatModel_, wrappedModel, compose.WithStatePreHandler( + func(ctx context.Context, input []*schema.AgenticMessage, st *agenticState) ([]*schema.AgenticMessage, error) { + if st.getRemainingIterations() <= 0 { + return nil, ErrExceedMaxIterations + } + st.decrementRemainingIterations() + return input, nil + }), compose.WithNodeName(chatModel_)) + + _ = g.AddLambdaNode(cancelCheckNode_, compose.InvokableLambda(func(ctx context.Context, msg *schema.AgenticMessage) (*schema.AgenticMessage, error) { + if cancelCtx != nil && cancelCtx.shouldCancel() { + if cancelCtx.getMode()&CancelAfterChatModel != 0 { + return nil, compose.StatefulInterrupt(ctx, "CancelAfterChatModel", msg) + } + } + wasInterrupted, hasState, state := compose.GetInterruptState[*schema.AgenticMessage](ctx) + if wasInterrupted && hasState { + msg = state + } + return msg, nil + }), compose.WithNodeName(cancelCheckNode_)) + + toolPreHandle := func(ctx context.Context, _ *schema.AgenticMessage, st *agenticState) (*schema.AgenticMessage, error) { + input := st.Messages[len(st.Messages)-1] + returnDirectly := config.toolsReturnDirectly + if execCtx := getTypedChatModelAgentExecCtx[*schema.AgenticMessage](ctx); execCtx != nil && len(execCtx.runtimeReturnDirectly) > 0 { + returnDirectly = execCtx.runtimeReturnDirectly + } + if len(returnDirectly) > 0 { + for _, block := range input.ContentBlocks { + if block == nil || block.Type != schema.ContentBlockTypeFunctionToolCall || block.FunctionToolCall == nil { + continue + } + if _, ok := returnDirectly[block.FunctionToolCall.Name]; ok { + st.setReturnDirectlyToolCallID(block.FunctionToolCall.CallID) + } + } + } + return input, nil + } + toolPostHandle := func(ctx context.Context, out *schema.StreamReader[[]*schema.AgenticMessage], st *agenticState) (*schema.StreamReader[[]*schema.AgenticMessage], error) { + if event := st.getReturnDirectlyEvent(); event != nil { + getTypedChatModelAgentExecCtx[*schema.AgenticMessage](ctx).send(event) + st.setReturnDirectlyEvent(nil) + } + return out, nil + } + _ = g.AddAgenticToolsNode(toolNode_, toolsNode, + compose.WithStatePreHandler(toolPreHandle), + compose.WithStreamStatePostHandler(toolPostHandle), + compose.WithNodeName(toolNode_)) + + afterToolCalls := func(ctx context.Context, toolResults []*schema.AgenticMessage) ([]*schema.AgenticMessage, error) { + _ = compose.ProcessState(ctx, func(_ context.Context, st *agenticState) error { + st.Messages = append(st.Messages, toolResults...) + return nil + }) + return toolResults, nil + } + _ = g.AddLambdaNode(afterToolCallsNode_, compose.InvokableLambda(afterToolCalls), + compose.WithNodeName(afterToolCallsNode_)) + + afterToolCallsCancelCheck := func(ctx context.Context, toolResults []*schema.AgenticMessage) ([]*schema.AgenticMessage, error) { + if cancelCtx != nil && cancelCtx.shouldCancel() { + if cancelCtx.getMode()&CancelAfterToolCalls != 0 { + return nil, compose.Interrupt(ctx, "CancelAfterToolCalls") + } + } + return toolResults, nil + } + _ = g.AddLambdaNode(afterToolCallsCancelCheckNode_, compose.InvokableLambda(afterToolCallsCancelCheck), + compose.WithNodeName(afterToolCallsCancelCheckNode_)) + + _ = g.AddEdge(compose.START, initNode_) + _ = g.AddEdge(initNode_, chatModel_) + + toolCallCheck := func(ctx context.Context, sMsg *schema.StreamReader[*schema.AgenticMessage]) (string, error) { + defer sMsg.Close() + for { + chunk, err_ := sMsg.Recv() + if err_ != nil { + if err_ == io.EOF { + return compose.END, nil + } + return "", err_ + } + if agenticMessageHasToolCalls(chunk) { + return cancelCheckNode_, nil + } + } + } + branch := compose.NewStreamGraphBranch(toolCallCheck, map[string]bool{compose.END: true, cancelCheckNode_: true}) + _ = g.AddBranch(chatModel_, branch) + + _ = g.AddEdge(cancelCheckNode_, toolNode_) + _ = g.AddEdge(toolNode_, afterToolCallsNode_) + _ = g.AddEdge(afterToolCallsNode_, afterToolCallsCancelCheckNode_) + + if len(config.toolsReturnDirectly) > 0 { + const ( + toolNodeToEndConverter = "ToolNodeToEndConverter" + ) + + cvt := func(ctx context.Context, toolResults []*schema.AgenticMessage) (*schema.AgenticMessage, error) { + id, _ := getAgenticReturnDirectlyToolCallID(ctx) + for _, msg := range toolResults { + if msg == nil { + continue + } + for _, block := range msg.ContentBlocks { + if block != nil && block.Type == schema.ContentBlockTypeFunctionToolResult && + block.FunctionToolResult != nil && block.FunctionToolResult.CallID == id { + return msg, nil + } + } + } + return nil, errors.New("return directly tool call result not found") + } + + _ = g.AddLambdaNode(toolNodeToEndConverter, compose.InvokableLambda(cvt), + compose.WithNodeName(toolNodeToEndConverter)) + _ = g.AddEdge(toolNodeToEndConverter, compose.END) + + checkReturnDirect := func(ctx context.Context, toolResults []*schema.AgenticMessage) (string, error) { + _, ok := getAgenticReturnDirectlyToolCallID(ctx) + if ok { + return toolNodeToEndConverter, nil + } + return chatModel_, nil + } + + returnDirectBranch := compose.NewGraphBranch(checkReturnDirect, + map[string]bool{toolNodeToEndConverter: true, chatModel_: true}) + _ = g.AddBranch(afterToolCallsCancelCheckNode_, returnDirectBranch) + } else { + _ = g.AddEdge(afterToolCallsCancelCheckNode_, chatModel_) + } + + return g, nil +} diff --git a/adk/react_test.go b/adk/react_test.go index b0a6c3985..1ac0ff5ee 100644 --- a/adk/react_test.go +++ b/adk/react_test.go @@ -29,6 +29,7 @@ import ( "github.com/bytedance/sonic" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "github.com/cloudwego/eino/components/model" @@ -642,3 +643,30 @@ func randStrForTest() string { } return string(b) } + +func TestReactHistory_EmptyMessages(t *testing.T) { + g := compose.NewGraph[string, []Message](compose.WithGenLocalState(func(ctx context.Context) (state *State) { + return &State{ + Messages: []Message{}, + } + })) + require.NoError(t, g.AddLambdaNode("1", compose.InvokableLambda(func(ctx context.Context, input string) (output []Message, err error) { + return getReactChatHistory(ctx, "DestAgent") + }))) + require.NoError(t, g.AddEdge(compose.START, "1")) + require.NoError(t, g.AddEdge("1", compose.END)) + + ctx := context.Background() + ctx, _ = initRunCtx(ctx, "MyAgent", nil) + runner, err := g.Compile(ctx) + require.NoError(t, err) + + require.NotPanics(t, func() { + result, err := runner.Invoke(ctx, "") + if err != nil { + t.Logf("Got error (acceptable): %v", err) + return + } + t.Logf("Got %d messages", len(result)) + }, "BUG: getReactChatHistory should not panic with empty Messages slice") +} diff --git a/adk/retry_chatmodel.go b/adk/retry_chatmodel.go index 304e8b9b3..df6fcad93 100644 --- a/adk/retry_chatmodel.go +++ b/adk/retry_chatmodel.go @@ -264,7 +264,7 @@ func genErrWrapper(ctx context.Context, maxRetries, attempt int, isRetryAbleFunc } } -func consumeStreamForError(stream *schema.StreamReader[*schema.Message]) error { +func consumeStreamForError[M any](stream *schema.StreamReader[M]) error { defer stream.Close() for { _, err := stream.Recv() @@ -292,23 +292,27 @@ type retryVerdict struct { // This is used inside the model wrapper chain, positioned between eventSenderModelWrapper // and stateModelWrapper, so that retry only affects the inner chain (event sending, user wrappers, // callback injection) without re-running state management (BeforeModelRewriteState/AfterModelRewriteState). -type retryModelWrapper struct { - inner model.BaseChatModel +type typedRetryModelWrapper[M messageType] struct { + inner model.BaseModel[M] config *ModelRetryConfig } -func newRetryModelWrapper(inner model.BaseChatModel, config *ModelRetryConfig) *retryModelWrapper { - return &retryModelWrapper{inner: inner, config: config} +func newTypedRetryModelWrapper[M messageType](inner model.BaseModel[M], config *ModelRetryConfig) *typedRetryModelWrapper[M] { + return &typedRetryModelWrapper[M]{inner: inner, config: config} } -func (r *retryModelWrapper) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { +func (r *typedRetryModelWrapper[M]) Generate(ctx context.Context, input []M, opts ...model.Option) (M, error) { if r.config.ShouldRetry != nil { - return r.generateWithShouldRetry(ctx, input, opts...) + // ShouldRetry is *schema.Message-specific (RetryContext.OutputMessage is *schema.Message). + msgR, _ := any(r).(*typedRetryModelWrapper[*schema.Message]) + msgInput, _ := any(input).([]Message) + out, err := generateWithShouldRetry(msgR, ctx, msgInput, opts...) + return any(out).(M), err } return r.generateLegacy(ctx, input, opts...) } -func (r *retryModelWrapper) generateLegacy(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { +func (r *typedRetryModelWrapper[M]) generateLegacy(ctx context.Context, input []M, opts ...model.Option) (zero M, _ error) { isRetryAble := r.config.IsRetryAble if isRetryAble == nil { isRetryAble = defaultIsRetryAble @@ -325,37 +329,36 @@ func (r *retryModelWrapper) generateLegacy(ctx context.Context, input []*schema. return out, nil } - // Never retry interrupt errors (e.g. cancel safe-point interrupts). if _, ok := compose.ExtractInterruptInfo(err); ok { - return nil, err + return zero, err } if errors.Is(err, ErrStreamCanceled) { - return nil, err + return zero, err } if !isRetryAble(ctx, err) { - return nil, err + return zero, err } lastErr = err if attempt < r.config.MaxRetries { if err := r.contextAwareSleep(ctx, backoffFunc(ctx, attempt+1)); err != nil { - return nil, err + return zero, err } } } - return nil, &RetryExhaustedError{LastErr: lastErr, TotalRetries: r.config.MaxRetries} + return zero, &RetryExhaustedError{LastErr: lastErr, TotalRetries: r.config.MaxRetries} } -func (r *retryModelWrapper) generateWithShouldRetry(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { +func generateWithShouldRetry(r *typedRetryModelWrapper[*schema.Message], ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { backoffFunc := r.config.BackoffFunc if backoffFunc == nil { backoffFunc = defaultBackoff } - execCtx := getChatModelAgentExecCtx(ctx) + execCtx := getTypedChatModelAgentExecCtx[*schema.Message](ctx) currentInput := input currentOpts := opts @@ -431,7 +434,7 @@ func (r *retryModelWrapper) generateWithShouldRetry(ctx context.Context, input [ break } - r.applyDecisionForRetry(¤tInput, ¤tOpts, ctx, decision) + applyDecisionForRetry(¤tInput, ¤tOpts, ctx, decision) delay := decision.Backoff if delay == 0 { @@ -446,7 +449,7 @@ func (r *retryModelWrapper) generateWithShouldRetry(ctx context.Context, input [ return nil, &RetryExhaustedError{LastErr: lastErr, TotalRetries: r.config.MaxRetries} } -func (r *retryModelWrapper) contextAwareSleep(ctx context.Context, delay time.Duration) error { +func (r *typedRetryModelWrapper[M]) contextAwareSleep(ctx context.Context, delay time.Duration) error { if delay <= 0 { return nil } @@ -481,7 +484,7 @@ func consumeStreamForMessage(stream *schema.StreamReader[*schema.Message]) (*sch } } -func (r *retryModelWrapper) streamWithShouldRetry(ctx context.Context, input []*schema.Message, opts ...model.Option) ( +func streamWithShouldRetry(r *typedRetryModelWrapper[*schema.Message], ctx context.Context, input []*schema.Message, opts ...model.Option) ( *schema.StreamReader[*schema.Message], error) { backoffFunc := r.config.BackoffFunc @@ -496,7 +499,7 @@ func (r *retryModelWrapper) streamWithShouldRetry(ctx context.Context, input []* }) }() - execCtx := getChatModelAgentExecCtx(ctx) + execCtx := getTypedChatModelAgentExecCtx[*schema.Message](ctx) currentInput := input currentOpts := opts @@ -568,7 +571,7 @@ func (r *retryModelWrapper) streamWithShouldRetry(ctx context.Context, input []* lastErr = err if attempt < r.config.MaxRetries { - r.applyDecisionForRetry(¤tInput, ¤tOpts, ctx, decision) + applyDecisionForRetry(¤tInput, ¤tOpts, ctx, decision) delay := decision.Backoff if delay == 0 { delay = backoffFunc(ctx, attempt+1) @@ -638,7 +641,7 @@ func (r *retryModelWrapper) streamWithShouldRetry(ctx context.Context, input []* lastErr = verdictErr if attempt < r.config.MaxRetries { - r.applyDecisionForRetry(¤tInput, ¤tOpts, ctx, decision) + applyDecisionForRetry(¤tInput, ¤tOpts, ctx, decision) delay := decision.Backoff if delay == 0 { delay = backoffFunc(ctx, attempt+1) @@ -652,7 +655,7 @@ func (r *retryModelWrapper) streamWithShouldRetry(ctx context.Context, input []* return nil, &RetryExhaustedError{LastErr: lastErr, TotalRetries: r.config.MaxRetries} } -func (r *retryModelWrapper) applyDecisionForRetry(currentInput *[]*schema.Message, currentOpts *[]model.Option, ctx context.Context, decision *RetryDecision) { +func applyDecisionForRetry(currentInput *[]*schema.Message, currentOpts *[]model.Option, ctx context.Context, decision *RetryDecision) { if decision.ModifiedInputMessages != nil { *currentInput = decision.ModifiedInputMessages if decision.PersistModifiedInputMessages { @@ -671,17 +674,24 @@ func (r *retryModelWrapper) applyDecisionForRetry(currentInput *[]*schema.Messag } } -func (r *retryModelWrapper) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) ( - *schema.StreamReader[*schema.Message], error) { +func (r *typedRetryModelWrapper[M]) Stream(ctx context.Context, input []M, opts ...model.Option) ( + *schema.StreamReader[M], error) { if r.config.ShouldRetry != nil { - return r.streamWithShouldRetry(ctx, input, opts...) + // ShouldRetry is *schema.Message-specific (RetryContext.OutputMessage is *schema.Message). + msgR, _ := any(r).(*typedRetryModelWrapper[*schema.Message]) + msgInput, _ := any(input).([]Message) + sr, err := streamWithShouldRetry(msgR, ctx, msgInput, opts...) + if err != nil { + return nil, err + } + return any(sr).(*schema.StreamReader[M]), nil } return r.streamLegacy(ctx, input, opts...) } -func (r *retryModelWrapper) streamLegacy(ctx context.Context, input []*schema.Message, opts ...model.Option) ( - *schema.StreamReader[*schema.Message], error) { +func (r *typedRetryModelWrapper[M]) streamLegacy(ctx context.Context, input []M, opts ...model.Option) ( + *schema.StreamReader[M], error) { isRetryAble := r.config.IsRetryAble if isRetryAble == nil { @@ -693,7 +703,7 @@ func (r *retryModelWrapper) streamLegacy(ctx context.Context, input []*schema.Me } defer func() { - _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + _ = compose.ProcessState(ctx, func(_ context.Context, st *typedState[M]) error { st.setRetryAttempt(0) return nil }) @@ -701,7 +711,7 @@ func (r *retryModelWrapper) streamLegacy(ctx context.Context, input []*schema.Me var lastErr error for attempt := 0; attempt <= r.config.MaxRetries; attempt++ { - _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + _ = compose.ProcessState(ctx, func(_ context.Context, st *typedState[M]) error { st.setRetryAttempt(attempt) return nil }) @@ -730,7 +740,7 @@ func (r *retryModelWrapper) streamLegacy(ctx context.Context, input []*schema.Me checkCopy := copies[0] returnCopy := copies[1] - streamErr := consumeStreamForError(checkCopy) + streamErr := consumeStreamForError[M](checkCopy) if streamErr == nil { return returnCopy, nil } diff --git a/adk/runctx.go b/adk/runctx.go index 1a32f1760..3c0316018 100644 --- a/adk/runctx.go +++ b/adk/runctx.go @@ -20,10 +20,14 @@ import ( "bytes" "context" "encoding/gob" + "errors" "fmt" + "io" "sort" "sync" "time" + + "github.com/cloudwego/eino/schema" ) // runSession CheckpointSchema: persisted via serialization.RunCtx (gob). @@ -34,6 +38,11 @@ type runSession struct { Events []*agentEventWrapper LaneEvents *laneEvents mtx sync.Mutex + + // TypedEvents stores *[]*typedAgentEventWrapper[M] for M != *schema.Message. + // For M = *schema.Message, the existing Events field is used instead. + // The any type is required because Go does not support generic fields in non-generic structs. + TypedEvents any } // laneEvents CheckpointSchema: persisted via serialization.RunCtx (gob). @@ -60,6 +69,105 @@ type agentEventWrapper struct { StreamErr error } +type typedAgentEventWrapper[M messageType] struct { + event *TypedAgentEvent[M] + mu sync.Mutex + concatenatedMessage M + TS int64 + StreamErr error +} + +// typedAgentEventWrapperForGob is a gob-serializable representation of typedAgentEventWrapper. +// We encode the event and TS separately to avoid the sync.Mutex and non-exported fields. +type typedAgentEventWrapperForGob[M messageType] struct { + Event *TypedAgentEvent[M] + TS int64 +} + +func (e *typedAgentEventWrapper[M]) GobEncode() ([]byte, error) { + if e.event != nil && e.event.Output != nil && e.event.Output.MessageOutput != nil && e.event.Output.MessageOutput.IsStreaming { + // Materialize the stream before encoding. + if isNilMessage(e.concatenatedMessage) && e.StreamErr == nil { + e.consumeStream() + } + } + + buf := &bytes.Buffer{} + err := gob.NewEncoder(buf).Encode(&typedAgentEventWrapperForGob[M]{ + Event: e.event, + TS: e.TS, + }) + if err != nil { + return nil, fmt.Errorf("failed to gob encode generic agent event wrapper: %w", err) + } + return buf.Bytes(), nil +} + +func (e *typedAgentEventWrapper[M]) GobDecode(b []byte) error { + g := &typedAgentEventWrapperForGob[M]{} + if err := gob.NewDecoder(bytes.NewReader(b)).Decode(g); err != nil { + return fmt.Errorf("failed to gob decode generic agent event wrapper: %w", err) + } + e.event = g.Event + e.TS = g.TS + return nil +} + +// consumeStream drains the typed message stream, setting concatenatedMessage on success +// or StreamErr on failure. The stream is replaced with a materialized version safe for +// gob encoding. +// +// NOTE: This method parallels agentEventWrapper.consumeStream in utils.go. The two +// implementations exist because agentEventWrapper is non-generic (uses *schema.Message +// directly) while typedAgentEventWrapper[M] is generic. They cannot be unified without +// making the non-generic wrapper generic, which would cascade through the entire +// non-generic event storage layer. +func (e *typedAgentEventWrapper[M]) consumeStream() { + e.mu.Lock() + defer e.mu.Unlock() + + if !isNilMessage(e.concatenatedMessage) { + return + } + + s := e.event.Output.MessageOutput.MessageStream + var msgs []M + + defer s.Close() + for { + msg, err := s.Recv() + if err != nil { + if err == io.EOF { + break + } + e.StreamErr = err + e.event.Output.MessageOutput.MessageStream = schema.StreamReaderFromArray(msgs) + return + } + msgs = append(msgs, msg) + } + + if len(msgs) == 0 { + e.StreamErr = errors.New("no messages in typedAgentEventWrapper.MessageStream") + e.event.Output.MessageOutput.MessageStream = schema.StreamReaderFromArray(msgs) + return + } + + if len(msgs) == 1 { + e.concatenatedMessage = msgs[0] + } else { + var err error + e.concatenatedMessage, err = concatMessageStream(schema.StreamReaderFromArray(msgs)) + if err != nil { + e.StreamErr = err + e.event.Output.MessageOutput.MessageStream = schema.StreamReaderFromArray(msgs) + return + } + } + + e.event.Output.MessageOutput.MessageStream = schema.StreamReaderFromArray([]M{e.concatenatedMessage}) +} + type otherAgentEventWrapperForEncode agentEventWrapper func (a *agentEventWrapper) GobEncode() ([]byte, error) { @@ -184,6 +292,71 @@ func (rs *runSession) getEvents() []*agentEventWrapper { return finalEvents } +func addTypedEvent[M messageType](session *runSession, event *TypedAgentEvent[M]) { + var zero M + if _, ok := any(zero).(*schema.Message); ok { + session.addEvent(any(event).(*AgentEvent)) + return + } + session.mtx.Lock() + defer session.mtx.Unlock() + wrapper := &typedAgentEventWrapper[M]{event: event, TS: time.Now().UnixNano()} + store, _ := session.TypedEvents.(*[]*typedAgentEventWrapper[M]) + if store == nil { + s := make([]*typedAgentEventWrapper[M], 0) + store = &s + session.TypedEvents = store + } + *store = append(*store, wrapper) +} + +func getTypedEvents[M messageType](session *runSession) []*typedAgentEventWrapper[M] { + var zero M + if _, ok := any(zero).(*schema.Message); ok { + events := session.getEvents() + result := make([]*typedAgentEventWrapper[M], 0, len(events)) + for _, e := range events { + w := &typedAgentEventWrapper[M]{ + event: any(e.AgentEvent).(*TypedAgentEvent[M]), + TS: e.TS, + StreamErr: e.StreamErr, + } + if e.concatenatedMessage != nil { + w.concatenatedMessage = any(e.concatenatedMessage).(M) + } + result = append(result, w) + } + return result + } + + session.mtx.Lock() + defer session.mtx.Unlock() + + store, _ := session.TypedEvents.(*[]*typedAgentEventWrapper[M]) + if store == nil { + if len(session.Events) == 0 { + return nil + } + result := make([]*typedAgentEventWrapper[M], 0, len(session.Events)) + for _, e := range session.Events { + w := &typedAgentEventWrapper[M]{ + event: any(e.AgentEvent).(*TypedAgentEvent[M]), + TS: e.TS, + StreamErr: e.StreamErr, + } + if e.concatenatedMessage != nil { + w.concatenatedMessage = any(e.concatenatedMessage).(M) + } + result = append(result, w) + } + return result + } + + result := make([]*typedAgentEventWrapper[M], len(*store)) + copy(result, *store) + return result +} + func (rs *runSession) getValues() map[string]any { rs.valuesMtx.Lock() values := make(map[string]any, len(rs.Values)) @@ -221,6 +394,8 @@ type runContext struct { RootInput *AgentInput RunPath []RunStep + AgenticRootInput any + Session *runSession } @@ -230,9 +405,10 @@ func (rc *runContext) isRoot() bool { func (rc *runContext) deepCopy() *runContext { copied := &runContext{ - RootInput: rc.RootInput, - RunPath: make([]RunStep, len(rc.RunPath)), - Session: rc.Session, + RootInput: rc.RootInput, + AgenticRootInput: rc.AgenticRootInput, + RunPath: make([]RunStep, len(rc.RunPath)), + Session: rc.Session, } copy(copied.RunPath, rc.RunPath) @@ -270,6 +446,27 @@ func initRunCtx(ctx context.Context, agentName string, input *AgentInput) (conte return setRunCtx(ctx, runCtx), runCtx } +func initTypedRunCtx[M messageType](ctx context.Context, agentName string, input *TypedAgentInput[M]) (context.Context, *runContext) { + runCtx := getRunCtx(ctx) + if runCtx != nil { + runCtx = runCtx.deepCopy() + } else { + runCtx = &runContext{Session: newRunSession()} + } + + runCtx.RunPath = append(runCtx.RunPath, RunStep{agentName: agentName}) + if runCtx.isRoot() && input != nil { + var zero M + if _, ok := any(zero).(*schema.Message); ok { + runCtx.RootInput = any(input).(*AgentInput) + } else { + runCtx.AgenticRootInput = input + } + } + + return setRunCtx(ctx, runCtx), runCtx +} + func joinRunCtxs(parentCtx context.Context, childCtxs ...context.Context) { switch len(childCtxs) { case 0: @@ -384,7 +581,7 @@ func ClearRunCtx(ctx context.Context) context.Context { return context.WithValue(ctx, runCtxKey{}, nil) } -func ctxWithNewRunCtx(ctx context.Context, input *AgentInput, sharedParentSession bool) context.Context { +func ctxWithNewTypedRunCtx[M messageType](ctx context.Context, input *TypedAgentInput[M], sharedParentSession bool) context.Context { var session *runSession if sharedParentSession { if parentSession := getSession(ctx); parentSession != nil { @@ -397,7 +594,14 @@ func ctxWithNewRunCtx(ctx context.Context, input *AgentInput, sharedParentSessio if session == nil { session = newRunSession() } - return setRunCtx(ctx, &runContext{Session: session, RootInput: input}) + var zero M + rc := &runContext{Session: session} + if _, ok := any(zero).(*schema.Message); ok { + rc.RootInput = any(input).(*AgentInput) + } else { + rc.AgenticRootInput = input + } + return setRunCtx(ctx, rc) } func getSession(ctx context.Context) *runSession { diff --git a/adk/runner.go b/adk/runner.go index 405b69e76..6caac130c 100644 --- a/adk/runner.go +++ b/adk/runner.go @@ -28,29 +28,53 @@ import ( "github.com/cloudwego/eino/schema" ) -// Runner is the primary entry point for executing an Agent. +func errorIterator[M messageType](err error) *AsyncIterator[*TypedAgentEvent[M]] { + iter, gen := NewAsyncIteratorPair[*TypedAgentEvent[M]]() + gen.Send(&TypedAgentEvent[M]{Err: err}) + gen.Close() + return iter +} + +func newUserMessage[M messageType](query string) (M, error) { + var zero M + switch any(zero).(type) { + case *schema.Message: + return any(schema.UserMessage(query)).(M), nil + case *schema.AgenticMessage: + return any(schema.UserAgenticMessage(query)).(M), nil + default: + return zero, fmt.Errorf("unsupported message type %T", zero) + } +} + +// TypedRunner is the primary entry point for executing an Agent. // It manages the agent's lifecycle, including starting, resuming, and checkpointing. -type Runner struct { - // a is the agent to be executed. - a Agent - // enableStreaming dictates whether the execution should be in streaming mode. +// +// Execution always goes through the flowAgent pipeline, which handles +// multi-agent orchestration, callbacks, agent naming, run paths, and cancellation. +type TypedRunner[M messageType] struct { + a TypedAgent[M] enableStreaming bool - // store is the checkpoint store used to persist agent state upon interruption. - // If nil, checkpointing is disabled. - store CheckPointStore + store CheckPointStore } +// Runner is the default runner type using *schema.Message. +type Runner = TypedRunner[*schema.Message] + type CheckPointStore = core.CheckPointStore type CheckPointDeleter = core.CheckPointDeleter -type RunnerConfig struct { - Agent Agent +type TypedRunnerConfig[M messageType] struct { + Agent TypedAgent[M] EnableStreaming bool CheckPointStore CheckPointStore } +// RunnerConfig is the default runner config type using *schema.Message. +type RunnerConfig = TypedRunnerConfig[*schema.Message] + // ResumeParams contains all parameters needed to resume an execution. // This struct provides an extensible way to pass resume parameters without // requiring breaking changes to method signatures. @@ -61,52 +85,33 @@ type ResumeParams struct { // Future extensible fields can be added here without breaking changes } -// NewRunner creates a Runner that executes an Agent with optional streaming -// and checkpoint persistence. +// NewRunner creates a new Runner with the given config. func NewRunner(_ context.Context, conf RunnerConfig) *Runner { - return &Runner{ + return NewTypedRunner[*schema.Message](conf) +} + +// NewTypedRunner creates a new TypedRunner with the given config. +func NewTypedRunner[M messageType](conf TypedRunnerConfig[M]) *TypedRunner[M] { + return &TypedRunner[M]{ enableStreaming: conf.EnableStreaming, a: conf.Agent, store: conf.CheckPointStore, } } -// Run starts a new execution of the agent with a given set of messages. -// It returns an iterator that yields agent events as they occur. -// If the Runner was configured with a CheckPointStore, it will automatically save the agent's state -// upon interruption. -func (r *Runner) Run(ctx context.Context, messages []Message, - opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { - o := getCommonOptions(nil, opts...) - - fa := toFlowAgent(ctx, r.a) - - input := &AgentInput{ - Messages: messages, - EnableStreaming: r.enableStreaming, - } - - ctx = ctxWithNewRunCtx(ctx, input, o.sharedParentSession) - - AddSessionValues(ctx, o.sessionValues) - - iter := fa.Run(ctx, input, opts...) - - if r.store == nil && o.cancelCtx == nil { - return iter - } - - niter, gen := NewAsyncIteratorPair[*AgentEvent]() - - go r.handleIter(ctx, iter, gen, o.checkPointID, o.cancelCtx) - return niter +func (r *TypedRunner[M]) Run(ctx context.Context, messages []M, + opts ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[M]] { + return typedRunnerRunImpl(r.a, r.enableStreaming, r.store, ctx, messages, opts...) } // Query is a convenience method that starts a new execution with a single user query string. -func (r *Runner) Query(ctx context.Context, - query string, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { - - return r.Run(ctx, []Message{schema.UserMessage(query)}, opts...) +func (r *TypedRunner[M]) Query(ctx context.Context, + query string, opts ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[M]] { + msgs, err := newUserMessage[M](query) + if err != nil { + return errorIterator[M](err) + } + return r.Run(ctx, []M{msgs}, opts...) } // Resume continues an interrupted execution from a checkpoint, using an "Implicit Resume All" strategy. @@ -116,8 +121,8 @@ func (r *Runner) Query(ctx context.Context, // When using this method, all interrupted agents will receive `isResumeFlow = false` when they // call `GetResumeContext`, as no specific agent was targeted. This is suitable for the "Simple Confirmation" // pattern where an agent only needs to know `wasInterrupted` is true to continue. -func (r *Runner) Resume(ctx context.Context, checkPointID string, opts ...AgentRunOption) ( - *AsyncIterator[*AgentEvent], error) { +func (r *TypedRunner[M]) Resume(ctx context.Context, checkPointID string, opts ...AgentRunOption) ( + *AsyncIterator[*TypedAgentEvent[M]], error) { return r.resumeInternal(ctx, checkPointID, nil, opts...) } @@ -139,17 +144,71 @@ func (r *Runner) Resume(ctx context.Context, checkPointID string, opts ...AgentR // execution. They act as conduits, allowing the resume signal to flow to their children. They will // naturally re-interrupt if one of their interrupted children re-interrupts, as they receive the // new `CompositeInterrupt` signal from them. -func (r *Runner) ResumeWithParams(ctx context.Context, checkPointID string, params *ResumeParams, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], error) { +func (r *TypedRunner[M]) ResumeWithParams(ctx context.Context, checkPointID string, params *ResumeParams, opts ...AgentRunOption) (*AsyncIterator[*TypedAgentEvent[M]], error) { return r.resumeInternal(ctx, checkPointID, params.Targets, opts...) } -func (r *Runner) resumeInternal(ctx context.Context, checkPointID string, resumeData map[string]any, - opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], error) { - if r.store == nil { +func (r *TypedRunner[M]) resumeInternal(ctx context.Context, checkPointID string, resumeData map[string]any, + opts ...AgentRunOption) (*AsyncIterator[*TypedAgentEvent[M]], error) { + return typedRunnerResumeInternalImpl(r.a, r.enableStreaming, r.store, ctx, checkPointID, resumeData, opts...) +} + +func typedRunnerRunImpl[M messageType](a TypedAgent[M], enableStreaming bool, store CheckPointStore, ctx context.Context, messages []M, opts ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[M]] { + o := getCommonOptions(nil, opts...) + + input := &TypedAgentInput[M]{ + Messages: messages, + EnableStreaming: enableStreaming, + } + + var zero M + if _, ok := any(zero).(*schema.Message); ok { + concreteAgent, _ := any(a).(Agent) + fa := toFlowAgent(ctx, concreteAgent) + if store != nil { + fa.checkPointStore = store + } + concreteInput := any(input).(*AgentInput) + ctx = ctxWithNewTypedRunCtx(ctx, input, o.sharedParentSession) + AddSessionValues(ctx, o.sessionValues) + + iter := fa.Run(ctx, concreteInput, opts...) + + if store == nil && o.cancelCtx == nil { + return any(iter).(*AsyncIterator[*TypedAgentEvent[M]]) + } + + niter, gen := NewAsyncIteratorPair[*TypedAgentEvent[M]]() + go typedRunnerHandleIterImpl(enableStreaming, store, ctx, any(iter).(*AsyncIterator[*TypedAgentEvent[M]]), gen, o.checkPointID, o.cancelCtx) + return niter + } + + fa := toTypedFlowAgent(a) + if store != nil { + fa.checkPointStore = store + } + + ctx = ctxWithNewTypedRunCtx(ctx, input, o.sharedParentSession) + AddSessionValues(ctx, o.sessionValues) + + iter := fa.Run(ctx, input, opts...) + + if store == nil && o.cancelCtx == nil { + return iter + } + + niter, gen := NewAsyncIteratorPair[*TypedAgentEvent[M]]() + go typedRunnerHandleIterImpl(enableStreaming, store, ctx, iter, gen, o.checkPointID, o.cancelCtx) + return niter +} + +func typedRunnerResumeInternalImpl[M messageType](a TypedAgent[M], enableStreaming bool, store CheckPointStore, ctx context.Context, checkPointID string, resumeData map[string]any, //nolint:revive // argument-limit + opts ...AgentRunOption) (*AsyncIterator[*TypedAgentEvent[M]], error) { + if store == nil { return nil, fmt.Errorf("failed to resume: store is nil") } - ctx, runCtx, resumeInfo, err := r.loadCheckPoint(ctx, checkPointID) + ctx, runCtx, resumeInfo, err := runnerLoadCheckPointImpl(store, ctx, checkPointID) if err != nil { return nil, fmt.Errorf("failed to load from checkpoint: %w", err) } @@ -170,34 +229,46 @@ func (r *Runner) resumeInternal(ctx context.Context, checkPointID string, resume } ctx = setRunCtx(ctx, runCtx) - AddSessionValues(ctx, o.sessionValues) if len(resumeData) > 0 { ctx = core.BatchResumeWithData(ctx, resumeData) } - fa := toFlowAgent(ctx, r.a) - - aIter := fa.Resume(ctx, resumeInfo, opts...) + var zero M + if _, ok := any(zero).(*schema.Message); ok { + concreteAgent, _ := any(a).(Agent) + fa := toFlowAgent(ctx, concreteAgent) + ra, ok := Agent(fa).(ResumableAgent) + if !ok { + return nil, fmt.Errorf("agent %T does not support resume", a) + } + aIter := ra.Resume(ctx, resumeInfo, opts...) - if r.store == nil && o.cancelCtx == nil { - return aIter, nil + niter, gen := NewAsyncIteratorPair[*TypedAgentEvent[M]]() + go typedRunnerHandleIterImpl(enableStreaming, store, ctx, any(aIter).(*AsyncIterator[*TypedAgentEvent[M]]), gen, &checkPointID, o.cancelCtx) + return niter, nil } - niter, gen := NewAsyncIteratorPair[*AgentEvent]() + fa := toTypedFlowAgent(a) + ra, ok := TypedAgent[M](fa).(TypedResumableAgent[M]) + if !ok { + return nil, fmt.Errorf("agent %T does not support resume", a) + } + aIter := ra.Resume(ctx, resumeInfo, opts...) - go r.handleIter(ctx, aIter, gen, &checkPointID, o.cancelCtx) + niter, gen := NewAsyncIteratorPair[*TypedAgentEvent[M]]() + go typedRunnerHandleIterImpl(enableStreaming, store, ctx, aIter, gen, &checkPointID, o.cancelCtx) return niter, nil } -func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEvent], - gen *AsyncGenerator[*AgentEvent], checkPointID *string, cancelCtx *cancelContext) { +func typedRunnerHandleIterImpl[M messageType](enableStreaming bool, store CheckPointStore, ctx context.Context, aIter *AsyncIterator[*TypedAgentEvent[M]], //nolint:revive // argument-limit + gen *AsyncGenerator[*TypedAgentEvent[M]], checkPointID *string, cancelCtx *cancelContext) { defer func() { panicErr := recover() if panicErr != nil { e := safe.NewPanicErr(panicErr, debug.Stack()) - gen.Send(&AgentEvent{Err: e}) + gen.Send(&TypedAgentEvent[M]{Err: e}) } gen.Close() @@ -220,9 +291,9 @@ func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEven } if cancelErr.interruptSignal != nil && checkPointID != nil { cancelErr.InterruptContexts = core.ToInterruptContexts(cancelErr.interruptSignal, allowedAddressSegmentTypes) - err := r.saveCheckPoint(ctx, *checkPointID, &InterruptInfo{}, cancelErr.interruptSignal) + err := runnerSaveCheckPointImpl(enableStreaming, store, ctx, *checkPointID, &InterruptInfo{}, cancelErr.interruptSignal) if err != nil { - gen.Send(&AgentEvent{Err: fmt.Errorf("failed to save checkpoint on cancel: %w", err)}) + gen.Send(&TypedAgentEvent[M]{Err: fmt.Errorf("failed to save checkpoint on cancel: %w", err)}) } } gen.Send(event) @@ -232,14 +303,11 @@ func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEven if event.Action != nil && event.Action.internalInterrupted != nil { if interruptSignal != nil { - // even if multiple interrupt happens, they should be merged into one - // action by CompositeInterrupt, so here in Runner we must assume at most - // one interrupt action happens panic("multiple interrupt actions should not happen in Runner") } interruptSignal = event.Action.internalInterrupted interruptContexts := core.ToInterruptContexts(interruptSignal, allowedAddressSegmentTypes) - event = &AgentEvent{ + event = &TypedAgentEvent[M]{ AgentName: event.AgentName, RunPath: event.RunPath, Output: event.Output, @@ -254,12 +322,11 @@ func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEven legacyData = event.Action.Interrupted.Data if checkPointID != nil { - // save checkpoint first before sending interrupt event, so when end-user receives interrupt event, they can resume from this checkpoint - err := r.saveCheckPoint(ctx, *checkPointID, &InterruptInfo{ + err := runnerSaveCheckPointImpl(enableStreaming, store, ctx, *checkPointID, &InterruptInfo{ Data: legacyData, }, interruptSignal) if err != nil { - gen.Send(&AgentEvent{Err: fmt.Errorf("failed to save checkpoint: %w", err)}) + gen.Send(&TypedAgentEvent[M]{Err: fmt.Errorf("failed to save checkpoint: %w", err)}) } } } diff --git a/adk/runner_test.go b/adk/runner_test.go index 6ab3f128e..0eb797c8e 100644 --- a/adk/runner_test.go +++ b/adk/runner_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/cloudwego/eino/schema" ) @@ -261,3 +262,50 @@ func TestRunner_Query_WithStreaming(t *testing.T) { _, ok = iterator.Next() assert.False(t, ok) } + +func TestResumeWithMissingCheckpoint(t *testing.T) { + ctx := context.Background() + + agent := &myAgenticAgent{ + name: "resume-agent", + runFn: func(ctx context.Context, input *TypedAgentInput[*schema.AgenticMessage], options ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[*schema.AgenticMessage]] { + iter, gen := NewAsyncIteratorPair[*TypedAgentEvent[*schema.AgenticMessage]]() + go func() { + defer gen.Close() + gen.Send(&TypedAgentEvent[*schema.AgenticMessage]{ + Output: &TypedAgentOutput[*schema.AgenticMessage]{ + MessageOutput: &TypedMessageVariant[*schema.AgenticMessage]{ + Message: agenticMsg("ok"), + }, + }, + }) + }() + return iter + }, + } + + store := newMyStore() + runner := NewTypedRunner[*schema.AgenticMessage](TypedRunnerConfig[*schema.AgenticMessage]{ + Agent: agent, + CheckPointStore: store, + }) + + require.NotPanics(t, func() { + iter, err := runner.ResumeWithParams(ctx, "nonexistent-checkpoint", &ResumeParams{ + Targets: map[string]any{"fake-id": nil}, + }) + if err != nil { + t.Logf("Got expected error: %v", err) + return + } + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil { + t.Logf("Got error event: %v", event.Err) + } + } + }, "ResumeWithParams with nonexistent checkpoint should not panic") +} diff --git a/adk/turn_loop.go b/adk/turn_loop.go index df12ba40f..46504c64c 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -356,12 +356,12 @@ func (s *preemptSignal) drainAll() { } // TurnLoopConfig is the configuration for creating a TurnLoop. -type TurnLoopConfig[T any] struct { +type TurnLoopConfig[T any, M messageType] struct { // GenInput receives the TurnLoop instance and all buffered items, and decides what to process. // It returns which items to consume now vs keep for later turns. // The loop parameter allows calling Push() or Stop() directly from within the callback. // Required. - GenInput func(ctx context.Context, loop *TurnLoop[T], items []T) (*GenInputResult[T], error) + GenInput func(ctx context.Context, loop *TurnLoop[T, M], items []T) (*GenInputResult[T, M], error) // GenResume is called at most once during Run(). When CheckpointID is // configured, Run() queries Store for the checkpoint: @@ -378,7 +378,7 @@ type TurnLoopConfig[T any] struct { // It returns a GenResumeResult describing how to resume the interrupted agent // turn (optional ResumeParams) and how to manipulate the buffer // (Consumed/Remaining) before continuing. - GenResume func(ctx context.Context, loop *TurnLoop[T], canceledItems, unhandledItems, newItems []T) (*GenResumeResult[T], error) + GenResume func(ctx context.Context, loop *TurnLoop[T, M], canceledItems, unhandledItems, newItems []T) (*GenResumeResult[T, M], error) // PrepareAgent returns an Agent configured to handle the consumed items. // This callback should set up the agent with appropriate system prompt, @@ -386,7 +386,7 @@ type TurnLoopConfig[T any] struct { // Called once per turn with the items that GenInput decided to consume. // The loop parameter allows calling Push() or Stop() directly from within the callback. // Required. - PrepareAgent func(ctx context.Context, loop *TurnLoop[T], consumed []T) (Agent, error) + PrepareAgent func(ctx context.Context, loop *TurnLoop[T, M], consumed []T) (TypedAgent[M], error) // OnAgentEvents is called to handle events emitted by the agent. // The TurnContext provides per-turn info and control: @@ -405,7 +405,7 @@ type TurnLoopConfig[T any] struct { // // Optional. If not provided, events are drained and the first error // (including CancelError from Stop) is returned as ExitReason. - OnAgentEvents func(ctx context.Context, tc *TurnContext[T], events *AsyncIterator[*AgentEvent]) error + OnAgentEvents func(ctx context.Context, tc *TurnContext[T, M], events *AsyncIterator[*TypedAgentEvent[M]]) error // Store is the checkpoint store for persistence and resume. Optional. // When set together with CheckpointID, enables automatic checkpoint-based resume. @@ -430,7 +430,7 @@ type TurnLoopConfig[T any] struct { } // GenInputResult contains the result of GenInput processing. -type GenInputResult[T any] struct { +type GenInputResult[T any, M messageType] struct { // RunCtx, if non-nil, overrides the context for this turn's execution // (PrepareAgent, agent run, OnAgentEvents). // @@ -444,7 +444,7 @@ type GenInputResult[T any] struct { RunCtx context.Context // Input is the agent input to execute - Input *AgentInput + Input *TypedAgentInput[M] // RunOpts are the options for this agent run. // Note: do not pass WithCheckPointID here; the TurnLoop automatically @@ -464,7 +464,7 @@ type GenInputResult[T any] struct { } // GenResumeResult contains the result of GenResume processing. -type GenResumeResult[T any] struct { +type GenResumeResult[T any, M messageType] struct { // RunCtx, if non-nil, overrides the context for this resumed turn's execution // (PrepareAgent, agent resume, OnAgentEvents). RunCtx context.Context @@ -489,9 +489,9 @@ type GenResumeResult[T any] struct { Remaining []T } -type turnRunSpec[T any] struct { +type turnRunSpec[T any, M messageType] struct { runCtx context.Context - input *AgentInput + input *TypedAgentInput[M] runOpts []AgentRunOption resumeParams *ResumeParams isResume bool @@ -499,18 +499,18 @@ type turnRunSpec[T any] struct { resumeBytes []byte } -type turnPlan[T any] struct { +type turnPlan[T any, M messageType] struct { turnCtx context.Context remaining []T - spec *turnRunSpec[T] + spec *turnRunSpec[T, M] } -func (l *TurnLoop[T]) planTurn( +func (l *TurnLoop[T, M]) planTurn( ctx context.Context, isResume bool, items []T, pr *turnLoopPendingResume[T], -) (*turnPlan[T], error) { +) (*turnPlan[T, M], error) { if !isResume { result, err := l.config.GenInput(ctx, l, items) if err != nil { @@ -526,10 +526,10 @@ func (l *TurnLoop[T]) planTurn( if result.RunCtx != nil { turnCtx = result.RunCtx } - return &turnPlan[T]{ + return &turnPlan[T, M]{ turnCtx: turnCtx, remaining: result.Remaining, - spec: &turnRunSpec[T]{ + spec: &turnRunSpec[T, M]{ runCtx: result.RunCtx, input: result.Input, runOpts: result.RunOpts, @@ -554,10 +554,10 @@ func (l *TurnLoop[T]) planTurn( if resumeResult.RunCtx != nil { turnCtx = resumeResult.RunCtx } - return &turnPlan[T]{ + return &turnPlan[T, M]{ turnCtx: turnCtx, remaining: resumeResult.Remaining, - spec: &turnRunSpec[T]{ + spec: &turnRunSpec[T, M]{ runCtx: resumeResult.RunCtx, runOpts: resumeResult.RunOpts, resumeParams: resumeResult.ResumeParams, @@ -570,7 +570,7 @@ func (l *TurnLoop[T]) planTurn( // TurnLoopExitState is returned when TurnLoop exits, containing the exit reason // and any items that were not processed. -type TurnLoopExitState[T any] struct { +type TurnLoopExitState[T any, M messageType] struct { // ExitReason indicates why the loop exited. // nil means clean exit (Stop() was called without cancel options, or the // agent completed normally before Stop took effect). @@ -621,9 +621,9 @@ type TurnLoopExitState[T any] struct { } // TurnContext provides per-turn context to the OnAgentEvents callback. -type TurnContext[T any] struct { +type TurnContext[T any, M messageType] struct { // Loop is the TurnLoop instance, allowing Push() or Stop() calls. - Loop *TurnLoop[T] + Loop *TurnLoop[T, M] // Consumed contains items that triggered this agent execution. Consumed []T @@ -672,8 +672,8 @@ type TurnContext[T any] struct { // - Wait: blocks until Run is called AND the loop exits. If Run is never // called, Wait blocks forever (this is a programming error, analogous // to reading from a channel that nobody writes to). -type TurnLoop[T any] struct { - config TurnLoopConfig[T] +type TurnLoop[T any, M messageType] struct { + config TurnLoopConfig[T, M] buffer *turnBuffer[T] @@ -682,7 +682,7 @@ type TurnLoop[T any] struct { done chan struct{} - result *TurnLoopExitState[T] + result *TurnLoopExitState[T, M] stopOnce sync.Once @@ -702,14 +702,14 @@ type TurnLoop[T any] struct { loadCheckpointID string - onAgentEvents func(ctx context.Context, tc *TurnContext[T], events *AsyncIterator[*AgentEvent]) error + onAgentEvents func(ctx context.Context, tc *TurnContext[T, M], events *AsyncIterator[*TypedAgentEvent[M]]) error lateMu sync.Mutex lateItems []T lateSealed bool } -func (l *TurnLoop[T]) appendLate(item T) { +func (l *TurnLoop[T, M]) appendLate(item T) { l.lateMu.Lock() defer l.lateMu.Unlock() if l.lateSealed { @@ -744,7 +744,7 @@ func unmarshalTurnLoopCheckpoint[T any](data []byte) (*turnLoopCheckpoint[T], er return &c, nil } -func (l *TurnLoop[T]) saveTurnLoopCheckpoint(ctx context.Context, checkPointID string, c *turnLoopCheckpoint[T]) error { +func (l *TurnLoop[T, M]) saveTurnLoopCheckpoint(ctx context.Context, checkPointID string, c *turnLoopCheckpoint[T]) error { if l.config.Store == nil { return errors.New("checkpoint store is nil") } @@ -755,7 +755,7 @@ func (l *TurnLoop[T]) saveTurnLoopCheckpoint(ctx context.Context, checkPointID s return l.config.Store.Set(ctx, checkPointID, data) } -func (l *TurnLoop[T]) deleteTurnLoopCheckpoint(ctx context.Context, checkPointID string) error { +func (l *TurnLoop[T, M]) deleteTurnLoopCheckpoint(ctx context.Context, checkPointID string) error { if l.config.Store == nil { return nil } @@ -765,7 +765,7 @@ func (l *TurnLoop[T]) deleteTurnLoopCheckpoint(ctx context.Context, checkPointID return nil } -func (l *TurnLoop[T]) tryLoadCheckpoint(ctx context.Context) error { +func (l *TurnLoop[T, M]) tryLoadCheckpoint(ctx context.Context) error { checkPointID := l.config.CheckpointID if checkPointID == "" || l.config.Store == nil { return nil @@ -973,15 +973,15 @@ func UntilIdleFor(duration time.Duration) StopOption { } } -type pushConfig[T any] struct { +type pushConfig[T any, M messageType] struct { preempt bool preemptDelay time.Duration agentCancelOpts []AgentCancelOption - pushStrategy func(context.Context, *TurnContext[T]) []PushOption[T] + pushStrategy func(context.Context, *TurnContext[T, M]) []PushOption[T, M] } // PushOption is an option for Push(). -type PushOption[T any] func(*pushConfig[T]) +type PushOption[T any, M messageType] func(*pushConfig[T, M]) // WithPreempt signals that the current agent turn should be cancelled at the // specified safePoint after pushing the new item. The loop cancels the current @@ -1001,11 +1001,11 @@ type PushOption[T any] func(*pushConfig[T]) // passed to the same Push call, the last one wins. // // safePoint must not be zero; passing SafePoint(0) panics. -func WithPreempt[T any](safePoint SafePoint) PushOption[T] { +func WithPreempt[T any, M messageType](safePoint SafePoint) PushOption[T, M] { if safePoint == 0 { panic("adk: SafePoint must not be zero; use AfterToolCalls, AfterChatModel, or AnySafePoint") } - return func(cfg *pushConfig[T]) { + return func(cfg *pushConfig[T, M]) { cfg.preempt = true cfg.agentCancelOpts = []AgentCancelOption{ WithAgentCancelMode(safePoint.toCancelMode()), @@ -1019,11 +1019,11 @@ func WithPreempt[T any](safePoint SafePoint) PushOption[T] { // also receive the cancel signal and be torn down. // // safePoint must not be zero; passing SafePoint(0) panics. -func WithPreemptTimeout[T any](safePoint SafePoint, timeout time.Duration) PushOption[T] { +func WithPreemptTimeout[T any, M messageType](safePoint SafePoint, timeout time.Duration) PushOption[T, M] { if safePoint == 0 { panic("adk: SafePoint must not be zero; use AfterToolCalls, AfterChatModel, or AnySafePoint") } - return func(cfg *pushConfig[T]) { + return func(cfg *pushConfig[T, M]) { cfg.preempt = true cfg.agentCancelOpts = []AgentCancelOption{ WithAgentCancelMode(safePoint.toCancelMode()), @@ -1038,8 +1038,8 @@ func WithPreemptTimeout[T any](safePoint SafePoint, timeout time.Duration) PushO // immediately, but the preemption signal will be delayed by the specified // duration. This allows the current agent to continue processing for a grace // period before being preempted. -func WithPreemptDelay[T any](delay time.Duration) PushOption[T] { - return func(cfg *pushConfig[T]) { +func WithPreemptDelay[T any, M messageType](delay time.Duration) PushOption[T, M] { + return func(cfg *pushConfig[T, M]) { cfg.preemptDelay = delay } } @@ -1054,22 +1054,22 @@ func WithPreemptDelay[T any](delay time.Duration) PushOption[T] { // // Example: preempt only if the current turn is processing low-priority items: // -// loop.Push(urgentItem, WithPushStrategy(func(ctx context.Context, tc *TurnContext[MyItem]) []PushOption[MyItem] { +// loop.Push(urgentItem, WithPushStrategy(func(ctx context.Context, tc *TurnContext[MyItem, *schema.Message]) []PushOption[MyItem, *schema.Message] { // if tc == nil { // return nil // between turns, plain push // } // if isLowPriority(tc.Consumed) { -// return []PushOption[MyItem]{WithPreempt[MyItem](AnySafePoint)} +// return []PushOption[MyItem, *schema.Message]{WithPreempt[MyItem, *schema.Message](AnySafePoint)} // } // return nil // don't preempt high-priority work // })) -func WithPushStrategy[T any](fn func(ctx context.Context, tc *TurnContext[T]) []PushOption[T]) PushOption[T] { - return func(cfg *pushConfig[T]) { +func WithPushStrategy[T any, M messageType](fn func(ctx context.Context, tc *TurnContext[T, M]) []PushOption[T, M]) PushOption[T, M] { + return func(cfg *pushConfig[T, M]) { cfg.pushStrategy = fn } } -func defaultTurnLoopOnAgentEvents[T any](_ context.Context, _ *TurnContext[T], events *AsyncIterator[*AgentEvent]) error { +func defaultTurnLoopOnAgentEvents[T any, M messageType](_ context.Context, _ *TurnContext[T, M], events *AsyncIterator[*TypedAgentEvent[M]]) error { for { event, ok := events.Next() if !ok { @@ -1088,7 +1088,7 @@ func defaultTurnLoopOnAgentEvents[T any](_ context.Context, _ *TurnContext[T], e // Call Run to start the processing goroutine. // // NewTurnLoop panics if GenInput or PrepareAgent is nil. -func NewTurnLoop[T any](cfg TurnLoopConfig[T]) *TurnLoop[T] { +func NewTurnLoop[T any, M messageType](cfg TurnLoopConfig[T, M]) *TurnLoop[T, M] { if cfg.GenInput == nil { panic("adk: NewTurnLoop: GenInput is required") } @@ -1096,7 +1096,7 @@ func NewTurnLoop[T any](cfg TurnLoopConfig[T]) *TurnLoop[T] { panic("adk: NewTurnLoop: PrepareAgent is required") } - l := &TurnLoop[T]{ + l := &TurnLoop[T, M]{ config: cfg, buffer: newTurnBuffer[T](), done: make(chan struct{}), @@ -1106,12 +1106,12 @@ func NewTurnLoop[T any](cfg TurnLoopConfig[T]) *TurnLoop[T] { if cfg.OnAgentEvents != nil { l.onAgentEvents = cfg.OnAgentEvents } else { - l.onAgentEvents = defaultTurnLoopOnAgentEvents[T] + l.onAgentEvents = defaultTurnLoopOnAgentEvents[T, M] } return l } -func (l *TurnLoop[T]) start(ctx context.Context) { +func (l *TurnLoop[T, M]) start(ctx context.Context) { l.runOnce.Do(func() { atomic.StoreInt32(&l.started, 1) go l.run(ctx) @@ -1126,7 +1126,7 @@ func (l *TurnLoop[T]) start(ctx context.Context) { // Otherwise it starts fresh with whatever items were Push()-ed. // // Calling Run more than once is a no-op: only the first call starts the loop. -func (l *TurnLoop[T]) Run(ctx context.Context) { +func (l *TurnLoop[T, M]) Run(ctx context.Context) { l.start(ctx) } @@ -1156,8 +1156,8 @@ func (l *TurnLoop[T]) Run(ctx context.Context) { // the preemption signal. // Push returns immediately after the item is buffered, and a goroutine is spawned // to signal preemption after the delay. -func (l *TurnLoop[T]) Push(item T, opts ...PushOption[T]) (bool, <-chan struct{}) { - cfg := &pushConfig[T]{} +func (l *TurnLoop[T, M]) Push(item T, opts ...PushOption[T, M]) (bool, <-chan struct{}) { + cfg := &pushConfig[T, M]{} for _, opt := range opts { opt(cfg) } @@ -1173,19 +1173,19 @@ func (l *TurnLoop[T]) Push(item T, opts ...PushOption[T]) (bool, <-chan struct{} // then calls the strategy callback with a guaranteed-stable TurnContext. If the // strategy returns preempt options, the hold is kept and a preempt is requested; // otherwise the hold is released and the item is buffered as a plain push. -func (l *TurnLoop[T]) pushWithStrategy(item T, cfg *pushConfig[T]) (bool, <-chan struct{}) { +func (l *TurnLoop[T, M]) pushWithStrategy(item T, cfg *pushConfig[T, M]) (bool, <-chan struct{}) { strategy := cfg.pushStrategy runCtx, tcAny := l.preemptSig.holdAndGetTurn() if runCtx == nil { runCtx = context.Background() } - var tc *TurnContext[T] + var tc *TurnContext[T, M] if tcAny != nil { - tc = tcAny.(*TurnContext[T]) + tc = tcAny.(*TurnContext[T, M]) } realOpts := strategy(runCtx, tc) - cfg = &pushConfig[T]{} + cfg = &pushConfig[T, M]{} for _, opt := range realOpts { opt(cfg) } @@ -1235,7 +1235,7 @@ func (l *TurnLoop[T]) pushWithStrategy(item T, cfg *pushConfig[T]) (bool, <-chan return true, ack } -func (l *TurnLoop[T]) pushWithConfig(item T, cfg *pushConfig[T]) (bool, <-chan struct{}) { +func (l *TurnLoop[T, M]) pushWithConfig(item T, cfg *pushConfig[T, M]) (bool, <-chan struct{}) { if atomic.LoadInt32(&l.stopped) != 0 { l.appendLate(item) return false, nil @@ -1304,7 +1304,7 @@ func (l *TurnLoop[T]) pushWithConfig(item T, cfg *pushConfig[T]) (bool, <-chan s // all cancel-related options (WithImmediate, WithGraceful, WithGracefulTimeout) // degrade to "exit the loop on entering the next iteration" — the current // agent turn runs to completion before the loop exits. -func (l *TurnLoop[T]) Stop(opts ...StopOption) { +func (l *TurnLoop[T, M]) Stop(opts ...StopOption) { cfg := &stopConfig{} for _, opt := range opts { opt(cfg) @@ -1327,7 +1327,7 @@ func (l *TurnLoop[T]) Stop(opts ...StopOption) { l.commitStop() } -func (l *TurnLoop[T]) commitStop() { +func (l *TurnLoop[T, M]) commitStop() { l.stopOnce.Do(func() { l.stopSig.closeDone() atomic.StoreInt32(&l.stopped, 1) @@ -1341,12 +1341,12 @@ func (l *TurnLoop[T]) commitStop() { // // Wait blocks until Run is called AND the loop exits. If Run is // never called, Wait blocks forever. -func (l *TurnLoop[T]) Wait() *TurnLoopExitState[T] { +func (l *TurnLoop[T, M]) Wait() *TurnLoopExitState[T, M] { <-l.done return l.result } -func (l *TurnLoop[T]) run(ctx context.Context) { +func (l *TurnLoop[T, M]) run(ctx context.Context) { defer l.cleanup(ctx) if err := l.tryLoadCheckpoint(ctx); err != nil { @@ -1502,7 +1502,7 @@ func (l *TurnLoop[T]) run(ctx context.Context) { } } -func (l *TurnLoop[T]) setupBridgeStore(spec *turnRunSpec[T], runOpts []AgentRunOption) ([]AgentRunOption, *bridgeStore, error) { +func (l *TurnLoop[T, M]) setupBridgeStore(spec *turnRunSpec[T, M], runOpts []AgentRunOption) ([]AgentRunOption, *bridgeStore, error) { store := l.config.Store if store == nil && spec.isResume { return nil, nil, fmt.Errorf("failed to resume agent: checkpoint store is nil") @@ -1530,7 +1530,7 @@ func (l *TurnLoop[T]) setupBridgeStore(spec *turnRunSpec[T], runOpts []AgentRunO // On the first preempt whose cancel actually contributed (i.e. the cancel options // were accepted before the CancelError was finalized), preemptDone is closed to // wake runAgentAndHandleEvents's select. -func (l *TurnLoop[T]) watchPreemptSignal(done <-chan struct{}, agentCancelFunc AgentCancelFunc, preemptDone chan struct{}) { +func (l *TurnLoop[T, M]) watchPreemptSignal(done <-chan struct{}, agentCancelFunc AgentCancelFunc, preemptDone chan struct{}) { var lastGen uint64 for { select { @@ -1573,7 +1573,7 @@ func (l *TurnLoop[T]) watchPreemptSignal(done <-chan struct{}, agentCancelFunc A // On the first cancel that actually contributed (i.e. the cancel was accepted // before the CancelError was finalized), stoppedDone is closed to wake // runAgentAndHandleEvents's select. -func (l *TurnLoop[T]) watchStopSignal(done <-chan struct{}, agentCancelFunc AgentCancelFunc, stoppedDone chan struct{}) { +func (l *TurnLoop[T, M]) watchStopSignal(done <-chan struct{}, agentCancelFunc AgentCancelFunc, stoppedDone chan struct{}) { var lastGen uint64 stoppedClosed := false @@ -1612,12 +1612,12 @@ func (l *TurnLoop[T]) watchStopSignal(done <-chan struct{}, agentCancelFunc Agen } } -func (l *TurnLoop[T]) runAgentAndHandleEvents( +func (l *TurnLoop[T, M]) runAgentAndHandleEvents( ctx context.Context, - agent Agent, - spec *turnRunSpec[T], + agent TypedAgent[M], + spec *turnRunSpec[T, M], ) error { - var iter *AsyncIterator[*AgentEvent] + var iter *AsyncIterator[*TypedAgentEvent[M]] runOpts, ms, err := l.setupBridgeStore(spec, spec.runOpts) if err != nil { @@ -1631,7 +1631,7 @@ func (l *TurnLoop[T]) runAgentAndHandleEvents( if spec.input != nil { enableStreaming = spec.input.EnableStreaming } - runner := NewRunner(ctx, RunnerConfig{ + runner := NewTypedRunner[M](TypedRunnerConfig[M]{ EnableStreaming: enableStreaming, Agent: agent, CheckPointStore: ms, @@ -1640,7 +1640,7 @@ func (l *TurnLoop[T]) runAgentAndHandleEvents( preemptDone := make(chan struct{}) stoppedDone := make(chan struct{}) - tc := &TurnContext[T]{ + tc := &TurnContext[T, M]{ Loop: l, Consumed: spec.consumed, Preempted: preemptDone, @@ -1743,7 +1743,7 @@ func (l *TurnLoop[T]) runAgentAndHandleEvents( } } -func (l *TurnLoop[T]) cleanup(ctx context.Context) { +func (l *TurnLoop[T, M]) cleanup(ctx context.Context) { atomic.StoreInt32(&l.stopped, 1) unhandled := l.buffer.TakeAll() @@ -1777,7 +1777,7 @@ func (l *TurnLoop[T]) cleanup(ctx context.Context) { var takeLateOnce sync.Once var takeLateResult []T - l.result = &TurnLoopExitState[T]{ + l.result = &TurnLoopExitState[T, M]{ ExitReason: l.runErr, UnhandledItems: unhandled, CanceledItems: l.canceledItems, diff --git a/adk/turn_loop_test.go b/adk/turn_loop_test.go index 309c84f0e..8a65a6c2f 100644 --- a/adk/turn_loop_test.go +++ b/adk/turn_loop_test.go @@ -161,13 +161,13 @@ func (a *turnLoopStopModeProbeAgent) Run(ctx context.Context, input *AgentInput, return iter } -func newAndRunTurnLoop[T any](ctx context.Context, cfg TurnLoopConfig[T]) *TurnLoop[T] { - l := NewTurnLoop(cfg) +func newAndRunTurnLoop[T any, M messageType](ctx context.Context, cfg TurnLoopConfig[T, M]) *TurnLoop[T, M] { + l := NewTurnLoop[T, M](cfg) l.Run(ctx) return l } -func newPreemptTestLoop(t *testing.T, agent *turnLoopCancellableMockAgent) *TurnLoop[string] { +func newPreemptTestLoop(t *testing.T, agent *turnLoopCancellableMockAgent) *TurnLoop[string, *schema.Message] { t.Helper() agentStarted := make(chan struct{}) @@ -179,12 +179,12 @@ func newPreemptTestLoop(t *testing.T, agent *turnLoopCancellableMockAgent) *Turn return originalRunFunc(ctx, input) } - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return agent, nil }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: []string{items[0]}, Remaining: items[1:], @@ -207,17 +207,17 @@ func TestTurnLoop_RunAndPush(t *testing.T) { processedItems := make([]string, 0) var mu sync.Mutex - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { mu.Lock() processedItems = append(processedItems, items...) mu.Unlock() - return &GenInputResult[string]{ + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -238,14 +238,14 @@ func TestTurnLoop_RunAndPush(t *testing.T) { } func TestTurnLoop_PushReturnsErrorAfterStop(t *testing.T) { - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -257,11 +257,11 @@ func TestTurnLoop_PushReturnsErrorAfterStop(t *testing.T) { } func TestTurnLoop_StopIsIdempotent(t *testing.T) { - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -275,11 +275,11 @@ func TestTurnLoop_StopIsIdempotent(t *testing.T) { } func TestTurnLoop_WaitMultipleGoroutines(t *testing.T) { - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -287,7 +287,7 @@ func TestTurnLoop_WaitMultipleGoroutines(t *testing.T) { loop.Stop() var wg sync.WaitGroup - results := make([]*TurnLoopExitState[string], 3) + results := make([]*TurnLoopExitState[string, *schema.Message], 3) for i := 0; i < 3; i++ { i := i @@ -308,17 +308,17 @@ func TestTurnLoop_UnhandledItemsOnStop(t *testing.T) { started := make(chan struct{}) blocked := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { close(started) <-blocked - return &GenInputResult[string]{ + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{}, Consumed: items[:1], Remaining: items[1:], }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -339,11 +339,11 @@ func TestTurnLoop_UnhandledItemsOnStop(t *testing.T) { func TestTurnLoop_GenInputError(t *testing.T) { genErr := errors.New("gen input error") - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { return nil, genErr }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -357,11 +357,11 @@ func TestTurnLoop_GenInputError(t *testing.T) { func TestTurnLoop_GetAgentError(t *testing.T) { agentErr := errors.New("get agent error") - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return nil, agentErr }, }) @@ -376,19 +376,19 @@ func TestTurnLoop_BatchProcessing(t *testing.T) { var batches [][]string var mu sync.Mutex - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { mu.Lock() batches = append(batches, items) mu.Unlock() - return &GenInputResult[string]{ + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{}, Consumed: items[:1], Remaining: items[1:], }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -409,11 +409,11 @@ func TestTurnLoop_BatchProcessing(t *testing.T) { } func TestTurnLoop_StopWithMode(t *testing.T) { - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -448,18 +448,18 @@ func TestTurnLoop_Preempt_CancelsCurrentAgent(t *testing.T) { secondGenInputCalled := make(chan struct{}) secondGenInputOnce := sync.Once{} - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return agent, nil }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { count := atomic.AddInt32(&genInputCalls, 1) if count >= 2 { secondGenInputOnce.Do(func() { close(secondGenInputCalled) }) } - return &GenInputResult[string]{ + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: []string{items[0]}, Remaining: items[1:], @@ -475,7 +475,7 @@ func TestTurnLoop_Preempt_CancelsCurrentAgent(t *testing.T) { t.Fatal("agent did not start") } - loop.Push("urgent", WithPreempt[string](AnySafePoint)) + loop.Push("urgent", WithPreempt[string, *schema.Message](AnySafePoint)) select { case <-agentCancelled: @@ -528,16 +528,16 @@ func TestTurnLoop_Preempt_DiscardsConsumedItems(t *testing.T) { }, } - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return agent, nil }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { mu.Lock() genInputResults = append(genInputResults, items) mu.Unlock() - return &GenInputResult[string]{ + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: []string{items[0]}, Remaining: items[1:], @@ -553,7 +553,7 @@ func TestTurnLoop_Preempt_DiscardsConsumedItems(t *testing.T) { t.Fatal("agent did not start") } - loop.Push("urgent", WithPreempt[string](AnySafePoint)) + loop.Push("urgent", WithPreempt[string, *schema.Message](AnySafePoint)) select { case <-agentDone: @@ -596,7 +596,7 @@ func TestTurnLoop_Preempt_WithAgentCancelMode(t *testing.T) { loop := newPreemptTestLoop(t, agent) - loop.Push("urgent", WithPreempt[string](AfterToolCalls)) + loop.Push("urgent", WithPreempt[string, *schema.Message](AfterToolCalls)) select { case <-cancelFuncCalled: @@ -632,7 +632,7 @@ func TestTurnLoop_PreemptAck_ClosesAfterCancelIsInitiated(t *testing.T) { loop := newPreemptTestLoop(t, agent) - ok, ack := loop.Push("urgent", WithPreempt[string](AfterToolCalls)) + ok, ack := loop.Push("urgent", WithPreempt[string, *schema.Message](AfterToolCalls)) assert.True(t, ok) assert.NotNil(t, ack) @@ -656,16 +656,16 @@ func TestTurnLoop_PreemptAck_ClosesAfterCancelIsInitiated(t *testing.T) { } func TestTurnLoop_PreemptAck_ClosesImmediatelyIfLoopNotStarted(t *testing.T) { - loop := NewTurnLoop(TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) - ok, ack := loop.Push("urgent", WithPreempt[string](AnySafePoint)) + ok, ack := loop.Push("urgent", WithPreempt[string, *schema.Message](AnySafePoint)) assert.True(t, ok) assert.NotNil(t, ack) @@ -698,14 +698,14 @@ func TestTurnLoop_Preempt_EscalatesOnSecondPreempt(t *testing.T) { loop := newPreemptTestLoop(t, agent) - loop.Push("urgent1", WithPreempt[string](AfterChatModel)) + loop.Push("urgent1", WithPreempt[string, *schema.Message](AfterChatModel)) select { case <-firstCancelSeen: case <-time.After(1 * time.Second): t.Fatal("first preempt did not trigger cancel") } - loop.Push("urgent2", WithPreemptTimeout[string](AnySafePoint, time.Millisecond)) + loop.Push("urgent2", WithPreemptTimeout[string, *schema.Message](AnySafePoint, time.Millisecond)) wantMode := CancelAfterChatModel | CancelAfterToolCalls deadline := time.Now().Add(1 * time.Second) @@ -759,14 +759,14 @@ func TestTurnLoop_Preempt_JoinsSafePointModesOnSecondPreempt(t *testing.T) { loop := newPreemptTestLoop(t, agent) - loop.Push("urgent1", WithPreempt[string](AfterChatModel)) + loop.Push("urgent1", WithPreempt[string, *schema.Message](AfterChatModel)) select { case <-firstCancelSeen: case <-time.After(1 * time.Second): t.Fatal("first preempt did not trigger cancel") } - loop.Push("urgent2", WithPreempt[string](AfterToolCalls)) + loop.Push("urgent2", WithPreempt[string, *schema.Message](AfterToolCalls)) want := CancelAfterChatModel | CancelAfterToolCalls deadline := time.Now().Add(1 * time.Second) @@ -815,12 +815,12 @@ func TestTurnLoop_Push_WithoutPreempt_DoesNotCancel(t *testing.T) { }, } - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return agent, nil }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: []string{items[0]}, Remaining: items[1:], @@ -879,12 +879,12 @@ func TestTurnLoop_PreemptDelay_NoMispreemptOnNaturalCompletion(t *testing.T) { }, } - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return agent, nil }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: []string{items[0]}, Remaining: items[1:], @@ -900,7 +900,7 @@ func TestTurnLoop_PreemptDelay_NoMispreemptOnNaturalCompletion(t *testing.T) { t.Fatal("agent1 did not start") } - loop.Push("second", WithPreempt[string](AnySafePoint), WithPreemptDelay[string](500*time.Millisecond)) + loop.Push("second", WithPreempt[string, *schema.Message](AnySafePoint), WithPreemptDelay[string, *schema.Message](500*time.Millisecond)) select { case <-agent1Done: @@ -929,12 +929,12 @@ func TestTurnLoop_PreemptDelay_NoMispreemptOnNaturalCompletion(t *testing.T) { func TestTurnLoop_ConcurrentPush(t *testing.T) { var count int32 - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { atomic.AddInt32(&count, int32(len(items))) - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -967,14 +967,14 @@ func TestTurnLoop_StopAfterReceive_RecoverItem(t *testing.T) { receiveStarted := make(chan struct{}) cancelDone := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { close(receiveStarted) <-cancelDone time.Sleep(50 * time.Millisecond) - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -992,17 +992,17 @@ func TestTurnLoop_StopAfterReceive_RecoverItem(t *testing.T) { func TestTurnLoop_StopAfterGenInput_RecoverConsumed(t *testing.T) { genInputDone := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { close(genInputDone) time.Sleep(50 * time.Millisecond) - return &GenInputResult[string]{ + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{}, Consumed: items[:1], Remaining: items[1:], }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { time.Sleep(100 * time.Millisecond) return &turnLoopMockAgent{name: "test"}, nil }, @@ -1023,15 +1023,15 @@ func TestTurnLoop_StopAfterGenInput_RecoverConsumed(t *testing.T) { func TestTurnLoop_GetAgentError_RecoverConsumed(t *testing.T) { agentErr := errors.New("get agent error") - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{}, Consumed: items[:1], Remaining: items[1:], }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], c []string) (Agent, error) { return nil, agentErr }, }) @@ -1047,11 +1047,11 @@ func TestTurnLoop_GetAgentError_RecoverConsumed(t *testing.T) { func TestTurnLoop_GenInputError_RecoverItems(t *testing.T) { genErr := errors.New("gen input error") - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { return nil, genErr }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -1069,8 +1069,8 @@ func TestTurnLoop_GenInputError_RecoverItems(t *testing.T) { func TestTurnLoop_PrepareAgentError_RecoverItemsInOrder(t *testing.T) { agentErr := errors.New("prepare agent error") - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { var urgent string remaining := make([]string, 0, len(items)) for _, item := range items { @@ -1081,19 +1081,19 @@ func TestTurnLoop_PrepareAgentError_RecoverItemsInOrder(t *testing.T) { } } if urgent != "" { - return &GenInputResult[string]{ + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{}, Consumed: []string{urgent}, Remaining: remaining, }, nil } - return &GenInputResult[string]{ + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{}, Consumed: items[:1], Remaining: items[1:], }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return nil, agentErr }, }) @@ -1119,16 +1119,16 @@ func TestTurnLoop_ContextCancel(t *testing.T) { genInputStarted := make(chan struct{}) genInputDone := make(chan struct{}) - loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { close(genInputStarted) <-genInputDone if err := ctx.Err(); err != nil { return nil, err } - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], c []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -1147,16 +1147,16 @@ func TestTurnLoop_ContextDeadlineExceeded(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) defer cancel() - loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { select { case <-time.After(100 * time.Millisecond): - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil case <-ctx.Done(): return nil, ctx.Err() } }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], c []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -1171,11 +1171,11 @@ func TestTurnLoop_ContextCancelBeforeReceive(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - loop := NewTurnLoop(TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], c []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -1195,11 +1195,11 @@ func TestTurnLoop_ContextCancelDuringBlockingReceive(t *testing.T) { // the context monitoring goroutine closes the buffer, which unblocks Receive(). ctx, cancel := context.WithCancel(context.Background()) - loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], c []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -1216,19 +1216,19 @@ func TestTurnLoop_ContextCancelAfterGenInput_RecoverItems(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) genInputCount := 0 - loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { genInputCount++ if genInputCount == 1 { cancel() } - return &GenInputResult[string]{ + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{}, Consumed: items[:1], Remaining: items[1:], }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], c []string) (Agent, error) { if err := ctx.Err(); err != nil { return nil, err } @@ -1249,17 +1249,17 @@ func TestTurnLoop_OnAgentEventsReceivesEvents(t *testing.T) { var receivedConsumed []string var mu sync.Mutex - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { mu.Lock() receivedConsumed = append(receivedConsumed, tc.Consumed...) mu.Unlock() @@ -1294,17 +1294,17 @@ func TestTurnLoop_OnAgentEventsReceivesEvents(t *testing.T) { func TestTurnLoop_StopDuringAgentExecution(t *testing.T) { agentStarted := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { close(agentStarted) time.Sleep(200 * time.Millisecond) for { @@ -1340,15 +1340,15 @@ func TestTurnLoop_BareStop_AgentRunsToCompletion(t *testing.T) { turnsExecuted := int32(0) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: []string{items[0]}, Remaining: items[1:], }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{ name: "worker", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -1437,16 +1437,16 @@ func TestTurnLoop_StopCheckPointIDInCancelError(t *testing.T) { }) assert.NoError(t, err) - loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: checkpointID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return agent, nil }, }) @@ -1490,15 +1490,15 @@ func TestTurnLoop_StopWithoutCheckpointIDDoesNotPersist(t *testing.T) { }) assert.NoError(t, err) - loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string, *schema.Message]{ Store: store, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return agent, nil }, }) @@ -1525,13 +1525,13 @@ func TestTurnLoop_StopWhileIdle_SkipsCheckpoint(t *testing.T) { } cpID := "idle-session" - loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -1551,13 +1551,13 @@ func TestTurnLoop_StopBetweenTurnsAndResume(t *testing.T) { store := &turnLoopCheckpointStore{m: make(map[string][]byte)} cpID := "between-turns-session" - loop := NewTurnLoop(TurnLoopConfig[string]{ + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -1572,22 +1572,22 @@ func TestTurnLoop_StopBetweenTurnsAndResume(t *testing.T) { var seen []string var mu sync.Mutex - loop2 := NewTurnLoop(TurnLoopConfig[string]{ + loop2 := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { mu.Lock() seen = append([]string{}, items...) mu.Unlock() - return &GenInputResult[string]{ + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { for { _, ok := events.Next() if !ok { @@ -1633,16 +1633,16 @@ func TestTurnLoop_StopDuringAgentExecution_PersistAndResume(t *testing.T) { }) assert.NoError(t, err) - loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return agent, nil }, }) @@ -1663,25 +1663,25 @@ func TestTurnLoop_StopDuringAgentExecution_PersistAndResume(t *testing.T) { var consumed2 []string var genResumeCalled bool var genInputCalled bool - loop2 := NewTurnLoop(TurnLoopConfig[string]{ + loop2 := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenResume: func(ctx context.Context, _ *TurnLoop[string], canceledItems []string, unhandledItems []string, newItems []string) (*GenResumeResult[string], error) { + GenResume: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], canceledItems []string, unhandledItems []string, newItems []string) (*GenResumeResult[string, *schema.Message], error) { genResumeCalled = true - return &GenResumeResult[string]{ + return &GenResumeResult[string, *schema.Message]{ Consumed: canceledItems, Remaining: append(append([]string{}, unhandledItems...), newItems...), }, nil }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { genInputCalled = true - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { consumed2 = append([]string{}, consumed...) return agent, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { for { _, ok := events.Next() if !ok { @@ -1704,16 +1704,16 @@ func TestTurnLoop_StopDuringAgentExecution_PersistAndResume(t *testing.T) { func TestTurnLoop_CheckpointIDWithoutStore_FreshStart(t *testing.T) { ctx := context.Background() var genInputCalled bool - loop := NewTurnLoop(TurnLoopConfig[string]{ + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ CheckpointID: "some-id", - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { genInputCalled = true - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { for { if _, ok := events.Next(); !ok { break @@ -1734,17 +1734,17 @@ func TestTurnLoop_CheckpointNotFound_FreshStart(t *testing.T) { ctx := context.Background() store := &turnLoopCheckpointStore{m: make(map[string][]byte)} var genInputCalled bool - loop := NewTurnLoop(TurnLoopConfig[string]{ + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: "nonexistent-id", - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { genInputCalled = true - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { for { if _, ok := events.Next(); !ok { break @@ -1767,17 +1767,17 @@ func TestTurnLoop_CheckpointEmptyData_TreatedAsNoCheckpoint(t *testing.T) { store.m["cp-empty"] = nil var genInputCalled bool - loop := NewTurnLoop(TurnLoopConfig[string]{ + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: "cp-empty", - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { genInputCalled = true - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { for { if _, ok := events.Next(); !ok { break @@ -1810,13 +1810,13 @@ func (s *errorCheckpointStore) Set(_ context.Context, _ string, _ []byte) error func TestTurnLoop_CheckpointLoadError_ReturnsError(t *testing.T) { ctx := context.Background() store := &errorCheckpointStore{getErr: fmt.Errorf("store unavailable")} - loop := NewTurnLoop(TurnLoopConfig[string]{ + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: "cp-1", - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -1831,13 +1831,13 @@ func TestTurnLoop_CheckpointCorruptData_ReturnsError(t *testing.T) { ctx := context.Background() store := &turnLoopCheckpointStore{m: make(map[string][]byte)} store.m["cp-corrupt"] = []byte("not-valid-gob-data") - loop := NewTurnLoop(TurnLoopConfig[string]{ + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: "cp-corrupt", - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -1869,16 +1869,16 @@ func TestTurnLoop_CheckpointSaveError_ReturnsError(t *testing.T) { }) assert.NoError(t, err) - loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string, *schema.Message]{ Store: saveStore, CheckpointID: "cp-1", - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return agent, nil }, }) @@ -1897,13 +1897,13 @@ func TestTurnLoop_StaleCheckpointDeletion_OnCleanResume(t *testing.T) { store := &turnLoopCheckpointStore{m: make(map[string][]byte)} cpID := "stale-session" - loop1 := NewTurnLoop(TurnLoopConfig[string]{ + loop1 := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -1917,19 +1917,19 @@ func TestTurnLoop_StaleCheckpointDeletion_OnCleanResume(t *testing.T) { store.mu.Unlock() assert.True(t, exists, "checkpoint should exist after first loop saves it") - loop2 := NewTurnLoop(TurnLoopConfig[string]{ + loop2 := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { for { if _, ok := events.Next(); !ok { break @@ -1955,13 +1955,13 @@ func TestTurnLoop_StaleCheckpointDeletion_ContextCancel(t *testing.T) { store := &deletableCheckpointStore{turnLoopCheckpointStore: turnLoopCheckpointStore{m: make(map[string][]byte)}} cpID := "delete-on-cancel" - loop1 := NewTurnLoop(TurnLoopConfig[string]{ + loop1 := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -1976,19 +1976,19 @@ func TestTurnLoop_StaleCheckpointDeletion_ContextCancel(t *testing.T) { assert.True(t, exists, "checkpoint saved after loop1") ctx2, cancel2 := context.WithCancel(ctx) - loop2 := NewTurnLoop(TurnLoopConfig[string]{ + loop2 := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { for { if _, ok := events.Next(); !ok { break @@ -2032,13 +2032,13 @@ func TestTurnLoop_CheckpointDeleter_CalledOnContextCancel(t *testing.T) { } cpID := "deleter-session" - loop1 := NewTurnLoop(TurnLoopConfig[string]{ + loop1 := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -2053,19 +2053,19 @@ func TestTurnLoop_CheckpointDeleter_CalledOnContextCancel(t *testing.T) { assert.True(t, exists, "checkpoint saved after loop1") ctx2, cancel2 := context.WithCancel(ctx) - loop2 := NewTurnLoop(TurnLoopConfig[string]{ + loop2 := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { for { if _, ok := events.Next(); !ok { break @@ -2111,16 +2111,16 @@ func TestTurnLoop_GenResumeNil_Error(t *testing.T) { }) assert.NoError(t, err) - loop1 := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + loop1 := newAndRunTurnLoop(ctx, TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return agent, nil }, }) @@ -2129,13 +2129,13 @@ func TestTurnLoop_GenResumeNil_Error(t *testing.T) { loop1.Stop(WithImmediate()) loop1.Wait() - loop2 := NewTurnLoop(TurnLoopConfig[string]{ + loop2 := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -2150,13 +2150,13 @@ func TestTurnLoop_SameCheckpointID_OverwritePattern(t *testing.T) { store := &turnLoopCheckpointStore{m: make(map[string][]byte)} cpID := "overwrite-session" - loop1 := NewTurnLoop(TurnLoopConfig[string]{ + loop1 := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -2171,13 +2171,13 @@ func TestTurnLoop_SameCheckpointID_OverwritePattern(t *testing.T) { store.mu.Unlock() assert.NotEmpty(t, data1) - loop2 := NewTurnLoop(TurnLoopConfig[string]{ + loop2 := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -2194,22 +2194,22 @@ func TestTurnLoop_SameCheckpointID_OverwritePattern(t *testing.T) { var seen []string var mu sync.Mutex - loop3 := NewTurnLoop(TurnLoopConfig[string]{ + loop3 := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { mu.Lock() seen = append([]string{}, items...) mu.Unlock() - return &GenInputResult[string]{ + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { for { if _, ok := events.Next(); !ok { break @@ -2243,13 +2243,13 @@ func TestTurnLoop_CheckpointHasRunnerStateButEmptyBytes(t *testing.T) { assert.NoError(t, err) store.m[cpID] = data - loop := NewTurnLoop(TurnLoopConfig[string]{ + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -2283,16 +2283,16 @@ func TestTurnLoop_GenResumeReturnsError(t *testing.T) { }) assert.NoError(t, err) - loop1 := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + loop1 := newAndRunTurnLoop(ctx, TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return agent, nil }, }) @@ -2302,16 +2302,16 @@ func TestTurnLoop_GenResumeReturnsError(t *testing.T) { loop1.Wait() genResumeErr := fmt.Errorf("resume callback failed") - loop2 := NewTurnLoop(TurnLoopConfig[string]{ + loop2 := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - GenResume: func(ctx context.Context, _ *TurnLoop[string], canceled, unhandled, newItems []string) (*GenResumeResult[string], error) { + GenResume: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], canceled, unhandled, newItems []string) (*GenResumeResult[string, *schema.Message], error) { return nil, genResumeErr }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -2342,16 +2342,16 @@ func TestTurnLoop_CheckpointSaveError_MergesWithExistingError(t *testing.T) { }) assert.NoError(t, err) - loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string, *schema.Message]{ Store: saveStore, CheckpointID: "cp-merge-err", - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return agent, nil }, }) @@ -2390,16 +2390,16 @@ func TestTurnLoop_ResumeWithParams(t *testing.T) { }) assert.NoError(t, err) - loop1 := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + loop1 := newAndRunTurnLoop(ctx, TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return agent, nil }, }) @@ -2411,26 +2411,26 @@ func TestTurnLoop_ResumeWithParams(t *testing.T) { assert.True(t, errors.As(exit1.ExitReason, &ce)) var resumeParamsUsed *ResumeParams - loop2 := NewTurnLoop(TurnLoopConfig[string]{ + loop2 := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - GenResume: func(ctx context.Context, _ *TurnLoop[string], canceled, unhandled, newItems []string) (*GenResumeResult[string], error) { + GenResume: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], canceled, unhandled, newItems []string) (*GenResumeResult[string, *schema.Message], error) { params := &ResumeParams{ Targets: map[string]any{"some-address": "user-data"}, } resumeParamsUsed = params - return &GenResumeResult[string]{ + return &GenResumeResult[string, *schema.Message]{ ResumeParams: params, Consumed: append(append(canceled, unhandled...), newItems...), }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return agent, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { for { if _, ok := events.Next(); !ok { break @@ -2451,14 +2451,14 @@ func TestTurnLoop_Stop_EscalatesCancelMode(t *testing.T) { ctx := context.Background() agentStarted := make(chan *cancelContext, 1) probe := &turnLoopStopModeProbeAgent{ccCh: agentStarted} - loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return probe, nil }, }) @@ -2491,14 +2491,14 @@ func TestTurnLoop_Stop_EscalatesCancelMode(t *testing.T) { func TestTurnLoop_DefaultOnAgentEvents_ErrorPropagation(t *testing.T) { agentErr := errors.New("agent execution error") - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -2519,17 +2519,17 @@ func TestTurnLoop_DefaultOnAgentEvents_ErrorPropagation(t *testing.T) { func TestTurnLoop_OnAgentEventsError(t *testing.T) { handlerErr := errors.New("event handler error") - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { // Drain events then return error for { _, ok := events.Next() @@ -2549,12 +2549,12 @@ func TestTurnLoop_OnAgentEventsError(t *testing.T) { func TestTurnLoop_StopCallFromGenInput(t *testing.T) { // Test that calling Stop() from within GenInput works correctly - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, loop *TurnLoop[string], items []string) (*GenInputResult[string], error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, loop *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { loop.Stop() - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -2569,18 +2569,18 @@ func TestTurnLoop_PushFromOnAgentEvents(t *testing.T) { // Test that calling Push() from within OnAgentEvents works pushCount := int32(0) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: []string{items[0]}, Remaining: items[1:], }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { for { _, ok := events.Next() if !ok { @@ -2613,17 +2613,17 @@ func TestNewTurnLoop_PushBeforeRun(t *testing.T) { var processedItems []string var mu sync.Mutex - loop := NewTurnLoop(TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { mu.Lock() processedItems = append(processedItems, items...) mu.Unlock() - return &GenInputResult[string]{ + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -2652,12 +2652,12 @@ func TestNewTurnLoop_PushBeforeRun(t *testing.T) { func TestNewTurnLoop_StopBeforeRun(t *testing.T) { // Stop before Run sets the stopped flag. When Run is called, the loop // exits immediately and buffered items appear as UnhandledItems. - loop := NewTurnLoop(TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { t.Fatal("GenInput should not be called") return nil, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { t.Fatal("PrepareAgent should not be called") return nil, nil }, @@ -2680,16 +2680,16 @@ func TestNewTurnLoop_StopBeforeRun(t *testing.T) { func TestNewTurnLoop_WaitBeforeRun(t *testing.T) { // Wait blocks until Run is called AND the loop exits. - loop := NewTurnLoop(TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) - waitDone := make(chan *TurnLoopExitState[string], 1) + waitDone := make(chan *TurnLoopExitState[string, *schema.Message], 1) go func() { waitDone <- loop.Wait() }() @@ -2718,12 +2718,12 @@ func TestNewTurnLoop_WaitBeforeRun(t *testing.T) { func TestNewTurnLoop_RunIsIdempotent(t *testing.T) { var genInputCalls int32 - loop := NewTurnLoop(TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { atomic.AddInt32(&genInputCalls, 1) - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -2744,12 +2744,12 @@ func TestNewTurnLoop_RunIsIdempotent(t *testing.T) { func TestNewTurnLoop_StopBeforeRun_ThenWait(t *testing.T) { // Demonstrates the full sequence: create, push, stop, run, wait. - loop := NewTurnLoop(TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { t.Fatal("GenInput should not be called after Stop") return nil, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { t.Fatal("PrepareAgent should not be called after Stop") return nil, nil }, @@ -2773,12 +2773,12 @@ func TestNewTurnLoop_ConcurrentPushAndRun(t *testing.T) { for i := 0; i < 100; i++ { var count int32 - loop := NewTurnLoop(TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { atomic.AddInt32(&count, int32(len(items))) - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -2820,17 +2820,17 @@ func TestTurnLoop_RunCtx_Propagation(t *testing.T) { const traceVal = "trace-123" var prepareCtxVal, agentCtxVal, eventsCtxVal string - cfg := TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, loop *TurnLoop[string], items []string) (*GenInputResult[string], error) { + cfg := TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, loop *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { // Derive a new context with per-item trace data runCtx := context.WithValue(ctx, turnCtxKey{}, traceVal) - return &GenInputResult[string]{ + return &GenInputResult[string, *schema.Message]{ RunCtx: runCtx, Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, loop *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, loop *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { if v, ok := ctx.Value(turnCtxKey{}).(string); ok { prepareCtxVal = v } @@ -2844,7 +2844,7 @@ func TestTurnLoop_RunCtx_Propagation(t *testing.T) { }, }, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { if v, ok := ctx.Value(turnCtxKey{}).(string); ok { eventsCtxVal = v } @@ -2873,14 +2873,14 @@ func TestTurnLoop_TurnContext_PreemptedChannel(t *testing.T) { preemptedSeen := make(chan struct{}) agentStarted := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopCancellableMockAgent{ name: "slow", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -2889,7 +2889,7 @@ func TestTurnLoop_TurnContext_PreemptedChannel(t *testing.T) { }, }, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { close(agentStarted) select { case <-tc.Preempted: @@ -2909,7 +2909,7 @@ func TestTurnLoop_TurnContext_PreemptedChannel(t *testing.T) { loop.Push("msg1") <-agentStarted - loop.Push("msg2", WithPreemptTimeout[string](AnySafePoint, time.Millisecond)) + loop.Push("msg2", WithPreemptTimeout[string, *schema.Message](AnySafePoint, time.Millisecond)) select { case <-preemptedSeen: @@ -3092,13 +3092,13 @@ func TestTurnLoop_ConcurrentPreemptsDuringTurn(t *testing.T) { var genInputCount int32 - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return agent, nil }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { atomic.AddInt32(&genInputCount, 1) - return &GenInputResult[string]{ + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{}, Consumed: items, }, nil @@ -3118,7 +3118,7 @@ func TestTurnLoop_ConcurrentPreemptsDuringTurn(t *testing.T) { wg.Add(1) go func(i int) { defer wg.Done() - ok, ack := loop.Push(fmt.Sprintf("urgent-%d", i), WithPreemptTimeout[string](AnySafePoint, 10*time.Millisecond)) + ok, ack := loop.Push(fmt.Sprintf("urgent-%d", i), WithPreemptTimeout[string, *schema.Message](AnySafePoint, 10*time.Millisecond)) if ok && ack != nil { select { case <-ack: @@ -3143,18 +3143,18 @@ func TestTurnLoop_PreemptDuringTurnTransition(t *testing.T) { firstTurnDone := make(chan struct{}) firstTurnOnce := sync.Once{} - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "fast"}, nil }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { count := atomic.AddInt32(&turnCount, 1) if count == 1 { firstTurnOnce.Do(func() { close(firstTurnDone) }) } - return &GenInputResult[string]{ + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{}, Consumed: items, }, nil @@ -3171,7 +3171,7 @@ func TestTurnLoop_PreemptDuringTurnTransition(t *testing.T) { time.Sleep(50 * time.Millisecond) - ok, ack := loop.Push("transitional", WithPreempt[string](AnySafePoint)) + ok, ack := loop.Push("transitional", WithPreempt[string, *schema.Message](AnySafePoint)) assert.True(t, ok, "push should succeed") if ack != nil { select { @@ -3213,18 +3213,18 @@ func TestTurnLoop_PushStrategy_DuringTurnTransition(t *testing.T) { secondTurnDone := make(chan struct{}) secondTurnOnce := sync.Once{} - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return agent, nil }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { count := atomic.AddInt32(&genInputCount, 1) if count >= 2 { secondTurnOnce.Do(func() { close(secondTurnDone) }) } - return &GenInputResult[string]{ + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{}, Consumed: items, }, nil @@ -3243,12 +3243,12 @@ func TestTurnLoop_PushStrategy_DuringTurnTransition(t *testing.T) { var strategyTCNotNil int32 go func() { - loop.Push("strategic-item", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + loop.Push("strategic-item", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string, *schema.Message]) []PushOption[string, *schema.Message] { if tc != nil { atomic.StoreInt32(&strategyTCNotNil, 1) } <-strategyBlocker - return []PushOption[string]{WithPreempt[string](AnySafePoint)} + return []PushOption[string, *schema.Message]{WithPreempt[string, *schema.Message](AnySafePoint)} })) }() @@ -3288,12 +3288,12 @@ func TestTurnLoop_ConcurrentPreemptAndStop(t *testing.T) { }, } - loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string, *schema.Message]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return agent, nil }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{}, Consumed: items, }, nil @@ -3313,7 +3313,7 @@ func TestTurnLoop_ConcurrentPreemptAndStop(t *testing.T) { go func() { defer wg.Done() - _, ack := loop.Push("preempt-item", WithPreempt[string](AnySafePoint)) + _, ack := loop.Push("preempt-item", WithPreempt[string, *schema.Message](AnySafePoint)) if ack != nil { <-ack } @@ -3349,12 +3349,12 @@ func TestTurnLoop_ConcurrentPushStrategyAndStop(t *testing.T) { }, } - loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string, *schema.Message]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return agent, nil }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{}, Consumed: items, }, nil @@ -3374,8 +3374,8 @@ func TestTurnLoop_ConcurrentPushStrategyAndStop(t *testing.T) { go func() { defer wg.Done() - _, ack := loop.Push("strategic-item", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { - return []PushOption[string]{WithPreempt[string](AnySafePoint)} + _, ack := loop.Push("strategic-item", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string, *schema.Message]) []PushOption[string, *schema.Message] { + return []PushOption[string, *schema.Message]{WithPreempt[string, *schema.Message](AnySafePoint)} })) if ack != nil { <-ack @@ -3397,14 +3397,14 @@ func TestTurnLoop_TurnContext_StoppedChannel(t *testing.T) { stoppedSeen := make(chan struct{}) agentStarted := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopCancellableMockAgent{ name: "slow", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -3413,7 +3413,7 @@ func TestTurnLoop_TurnContext_StoppedChannel(t *testing.T) { }, }, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { close(agentStarted) select { case <-tc.Stopped: @@ -3450,14 +3450,14 @@ func TestTurnLoop_TurnContext_BothPreemptedAndStopped(t *testing.T) { preemptedSeen := make(chan struct{}) agentStarted := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopCancellableMockAgent{ name: "slow", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -3466,7 +3466,7 @@ func TestTurnLoop_TurnContext_BothPreemptedAndStopped(t *testing.T) { }, }, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*TypedAgentEvent[*schema.Message]]) error { close(agentStarted) select { case <-tc.Preempted: @@ -3485,7 +3485,7 @@ func TestTurnLoop_TurnContext_BothPreemptedAndStopped(t *testing.T) { loop.Push("msg1") <-agentStarted - loop.Push("msg2", WithPreemptTimeout[string](AnySafePoint, time.Millisecond)) + loop.Push("msg2", WithPreemptTimeout[string, *schema.Message](AnySafePoint, time.Millisecond)) select { case <-preemptedSeen: @@ -3501,14 +3501,14 @@ func TestTurnLoop_TurnContext_BothPreemptedAndStopped(t *testing.T) { stoppedSeen := make(chan struct{}) agentStarted := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopCancellableMockAgent{ name: "slow", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -3517,7 +3517,7 @@ func TestTurnLoop_TurnContext_BothPreemptedAndStopped(t *testing.T) { }, }, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*TypedAgentEvent[*schema.Message]]) error { close(agentStarted) select { case <-tc.Stopped: @@ -3544,7 +3544,7 @@ func TestTurnLoop_TurnContext_BothPreemptedAndStopped(t *testing.T) { t.Fatal("Stopped channel was never closed") } - loop.Push("msg2", WithPreemptTimeout[string](AnySafePoint, time.Millisecond)) + loop.Push("msg2", WithPreemptTimeout[string, *schema.Message](AnySafePoint, time.Millisecond)) loop.Wait() }) } @@ -3573,18 +3573,18 @@ func TestTurnLoop_PushStrategy_DuringTurn(t *testing.T) { secondGenInputCalled := make(chan struct{}) secondGenInputOnce := sync.Once{} - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return agent, nil }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { count := atomic.AddInt32(&genInputCalls, 1) if count >= 2 { secondGenInputOnce.Do(func() { close(secondGenInputCalled) }) } - return &GenInputResult[string]{ + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: []string{items[0]}, Remaining: items[1:], @@ -3602,11 +3602,11 @@ func TestTurnLoop_PushStrategy_DuringTurn(t *testing.T) { // Strategy inspects TurnContext during a running turn and decides to preempt. var strategyCalled int32 - var strategyTC *TurnContext[string] - loop.Push("urgent", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + var strategyTC *TurnContext[string, *schema.Message] + loop.Push("urgent", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string, *schema.Message]) []PushOption[string, *schema.Message] { atomic.AddInt32(&strategyCalled, 1) strategyTC = tc - return []PushOption[string]{WithPreempt[string](AnySafePoint)} + return []PushOption[string, *schema.Message]{WithPreempt[string, *schema.Message](AnySafePoint)} })) select { @@ -3644,18 +3644,18 @@ func TestTurnLoop_PushStrategy_BetweenTurns(t *testing.T) { agentDone := make(chan struct{}) agentDoneOnce := sync.Once{} - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return agent, nil }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, Remaining: nil, }, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { for { _, ok := events.Next() if !ok { @@ -3670,7 +3670,7 @@ func TestTurnLoop_PushStrategy_BetweenTurns(t *testing.T) { }) // Push with strategy — no turn is active yet, so tc should be nil. - loop.Push("item", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + loop.Push("item", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string, *schema.Message]) []PushOption[string, *schema.Message] { atomic.AddInt32(&strategyCalled, 1) strategyTCWasNil = (tc == nil) return nil // plain push, no preempt @@ -3701,18 +3701,18 @@ func TestTurnLoop_PushStrategy_OverridesOtherOptions(t *testing.T) { agentDone := make(chan struct{}) agentDoneOnce := sync.Once{} - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return agent, nil }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, Remaining: nil, }, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { for { _, ok := events.Next() if !ok { @@ -3728,7 +3728,7 @@ func TestTurnLoop_PushStrategy_OverridesOtherOptions(t *testing.T) { // Strategy returns nil (no preempt), even though WithPreempt is also passed. // The strategy should override — so the agent should NOT be preempted. - ok, ack := loop.Push("item", WithPreempt[string](AnySafePoint), WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + ok, ack := loop.Push("item", WithPreempt[string, *schema.Message](AnySafePoint), WithPushStrategy(func(ctx context.Context, tc *TurnContext[string, *schema.Message]) []PushOption[string, *schema.Message] { return nil // no preempt })) assert.True(t, ok) @@ -3755,18 +3755,18 @@ func TestTurnLoop_PushStrategy_NestedStrategyStripped(t *testing.T) { agentDone := make(chan struct{}) agentDoneOnce := sync.Once{} - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return agent, nil }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, Remaining: nil, }, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { for { _, ok := events.Next() if !ok { @@ -3782,11 +3782,11 @@ func TestTurnLoop_PushStrategy_NestedStrategyStripped(t *testing.T) { // Strategy returns another WithPushStrategy — the nested one should be stripped. innerCalled := int32(0) - ok, ack := loop.Push("item", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { - return []PushOption[string]{ - WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + ok, ack := loop.Push("item", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string, *schema.Message]) []PushOption[string, *schema.Message] { + return []PushOption[string, *schema.Message]{ + WithPushStrategy(func(ctx context.Context, tc *TurnContext[string, *schema.Message]) []PushOption[string, *schema.Message] { atomic.AddInt32(&innerCalled, 1) - return []PushOption[string]{WithPreempt[string](AnySafePoint)} + return []PushOption[string, *schema.Message]{WithPreempt[string, *schema.Message](AnySafePoint)} }), } })) @@ -3824,11 +3824,11 @@ func TestTurnLoop_PushStrategy_ConsumedInspection(t *testing.T) { genInputCalls := int32(0) secondGenInputItems := make(chan []string, 1) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return agent, nil }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { count := atomic.AddInt32(&genInputCalls, 1) if count >= 2 { select { @@ -3836,7 +3836,7 @@ func TestTurnLoop_PushStrategy_ConsumedInspection(t *testing.T) { default: } } - return &GenInputResult[string]{ + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: []string{items[0]}, Remaining: items[1:], @@ -3853,9 +3853,9 @@ func TestTurnLoop_PushStrategy_ConsumedInspection(t *testing.T) { } // Strategy checks Consumed and preempts because current turn has "low-priority" items. - loop.Push("urgent-task", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + loop.Push("urgent-task", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string, *schema.Message]) []PushOption[string, *schema.Message] { if tc != nil && len(tc.Consumed) > 0 && tc.Consumed[0] == "low-priority-task" { - return []PushOption[string]{WithPreempt[string](AnySafePoint)} + return []PushOption[string, *schema.Message]{WithPreempt[string, *schema.Message](AnySafePoint)} } return nil })) @@ -3875,17 +3875,17 @@ func TestTurnLoop_PushAfterStop_BufferedAsLateItems(t *testing.T) { ctx := context.Background() processed := make(chan string, 10) - loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { for { if _, ok := events.Next(); !ok { break @@ -3916,11 +3916,11 @@ func TestTurnLoop_PushAfterStop_BufferedAsLateItems(t *testing.T) { func TestTurnLoop_TakeLateItems_Idempotent(t *testing.T) { ctx := context.Background() - loop := NewTurnLoop(TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -3943,11 +3943,11 @@ func TestTurnLoop_TakeLateItems_Idempotent(t *testing.T) { func TestTurnLoop_PushAfterTakeLateItems_Panics(t *testing.T) { ctx := context.Background() - loop := NewTurnLoop(TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -3966,11 +3966,11 @@ func TestTurnLoop_PushAfterTakeLateItems_Panics(t *testing.T) { func TestTurnLoop_TakeLateItems_NeverCalled_NoImpact(t *testing.T) { ctx := context.Background() - loop := NewTurnLoop(TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -3989,13 +3989,13 @@ func TestTurnLoop_CheckpointErr_SeparateFromExitReason(t *testing.T) { ctx := context.Background() saveStore := &errorCheckpointStore{setErr: fmt.Errorf("storage unavailable")} - loop := NewTurnLoop(TurnLoopConfig[string]{ + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: saveStore, CheckpointID: "cp-separate-err", - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -4014,11 +4014,11 @@ func TestTurnLoop_CheckpointErr_SeparateFromExitReason(t *testing.T) { func TestTurnLoop_CheckpointAttempted_FalseWhenNoStore(t *testing.T) { ctx := context.Background() - loop := NewTurnLoop(TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -4038,20 +4038,20 @@ func TestTurnLoop_CheckpointAttempted_FalseOnErrorExit(t *testing.T) { firstTurnDone := make(chan struct{}) var callCount int32 - loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: "cp-err-exit", - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { n := atomic.AddInt32(&callCount, 1) if n > 1 { return nil, genInputErr } - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { for { if _, ok := events.Next(); !ok { break @@ -4082,16 +4082,16 @@ func TestTurnLoop_StopConcurrentWithCallbackError_NoCheckpoint(t *testing.T) { stopCalled := make(chan struct{}) var prepareCount int32 - loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { n := atomic.AddInt32(&prepareCount, 1) if n > 1 { // Wait until Stop() has been called so stopSig.isStopped() is true @@ -4100,7 +4100,7 @@ func TestTurnLoop_StopConcurrentWithCallbackError_NoCheckpoint(t *testing.T) { } return &turnLoopMockAgent{name: "test"}, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { for { if _, ok := events.Next(); !ok { break @@ -4140,13 +4140,13 @@ func TestTurnLoop_DeleteWithoutCheckPointDeleter_NoOp(t *testing.T) { cpID := "no-deleter" // First loop: save a checkpoint - loop1 := NewTurnLoop(TurnLoopConfig[string]{ + loop1 := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -4163,19 +4163,19 @@ func TestTurnLoop_DeleteWithoutCheckPointDeleter_NoOp(t *testing.T) { // Second loop: exit via context cancel — should try to delete but store // doesn't implement CheckPointDeleter, so checkpoint persists (no-op) ctx2, cancel2 := context.WithCancel(ctx) - loop2 := NewTurnLoop(TurnLoopConfig[string]{ + loop2 := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { for { if _, ok := events.Next(); !ok { break @@ -4202,13 +4202,13 @@ func TestTurnLoop_StopWithSkipCheckpoint(t *testing.T) { store := &turnLoopCheckpointStore{m: make(map[string][]byte)} cpID := "skip-cp-session" - loop := NewTurnLoop(TurnLoopConfig[string]{ + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -4235,13 +4235,13 @@ func TestTurnLoop_StopWithSkipCheckpoint_DeletesStaleCheckpoint(t *testing.T) { } cpID := "skip-stale-session" - loop1 := NewTurnLoop(TurnLoopConfig[string]{ + loop1 := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -4256,13 +4256,13 @@ func TestTurnLoop_StopWithSkipCheckpoint_DeletesStaleCheckpoint(t *testing.T) { store.mu.Unlock() assert.True(t, exists, "first loop should save checkpoint") - loop2 := NewTurnLoop(TurnLoopConfig[string]{ + loop2 := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -4282,11 +4282,11 @@ func TestTurnLoop_StopWithStopCause(t *testing.T) { ctx := context.Background() cause := "user session timeout" - loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -4301,11 +4301,11 @@ func TestTurnLoop_StopWithStopCause(t *testing.T) { func TestTurnLoop_StopCause_EmptyWhenNoStop(t *testing.T) { ctx := context.Background() - loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{Input: &AgentInput{}, Consumed: items}, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -4320,14 +4320,14 @@ func TestTurnLoop_StopCause_InTurnContext(t *testing.T) { gotCause := make(chan string, 1) agentStarted := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopCancellableMockAgent{ name: "slow", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -4336,7 +4336,7 @@ func TestTurnLoop_StopCause_InTurnContext(t *testing.T) { }, }, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { close(agentStarted) select { case <-tc.Stopped: @@ -4371,14 +4371,14 @@ func TestTurnLoop_StopCause_InTurnContext(t *testing.T) { func TestTurnLoop_StopCause_FirstNonEmptyWins(t *testing.T) { agentStarted := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopCancellableMockAgent{ name: "slow", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -4387,7 +4387,7 @@ func TestTurnLoop_StopCause_FirstNonEmptyWins(t *testing.T) { }, }, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { close(agentStarted) for { if _, ok := events.Next(); !ok { @@ -4408,12 +4408,12 @@ func TestTurnLoop_StopCause_FirstNonEmptyWins(t *testing.T) { } func TestTurnLoop_StopBeforeRun_PushThenStop(t *testing.T) { - loop := NewTurnLoop(TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { t.Fatal("GenInput should not be called when Stop is called before Run") return nil, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { t.Fatal("PrepareAgent should not be called when Stop is called before Run") return nil, nil }, @@ -4435,12 +4435,12 @@ func TestTurnLoop_StopBeforeRun_PushThenStop(t *testing.T) { } func TestTurnLoop_StopBeforeRun_StopThenPush(t *testing.T) { - loop := NewTurnLoop(TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { t.Fatal("GenInput should not be called when Stop is called before Run") return nil, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { t.Fatal("PrepareAgent should not be called when Stop is called before Run") return nil, nil }, @@ -4468,16 +4468,16 @@ func TestTurnLoop_SkipCheckpoint_Sticky(t *testing.T) { store := &turnLoopCheckpointStore{m: make(map[string][]byte)} cpID := "sticky-skip-session" - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ Store: store, CheckpointID: cpID, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopCancellableMockAgent{ name: "slow", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -4486,7 +4486,7 @@ func TestTurnLoop_SkipCheckpoint_Sticky(t *testing.T) { }, }, nil }, - OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string, *schema.Message], events *AsyncIterator[*AgentEvent]) error { close(agentStarted) for { if _, ok := events.Next(); !ok { @@ -4520,12 +4520,12 @@ func TestWithGracefulTimeout_NonPositive_Panics(t *testing.T) { func TestWithPreempt_ZeroSafePoint_Panics(t *testing.T) { assert.PanicsWithValue(t, "adk: SafePoint must not be zero; use AfterToolCalls, AfterChatModel, or AnySafePoint", - func() { WithPreempt[string](SafePoint(0)) }) + func() { WithPreempt[string, *schema.Message](SafePoint(0)) }) } func TestWithPreemptTimeout_ZeroSafePoint_Panics(t *testing.T) { assert.PanicsWithValue(t, "adk: SafePoint must not be zero; use AfterToolCalls, AfterChatModel, or AnySafePoint", - func() { WithPreemptTimeout[string](SafePoint(0), time.Second) }) + func() { WithPreemptTimeout[string, *schema.Message](SafePoint(0), time.Second) }) } func TestSafePoint_ToCancelMode(t *testing.T) { @@ -4536,13 +4536,15 @@ func TestSafePoint_ToCancelMode(t *testing.T) { func TestNewTurnLoop_NilGenInput_Panics(t *testing.T) { assert.PanicsWithValue(t, "adk: NewTurnLoop: GenInput is required", func() { - NewTurnLoop(TurnLoopConfig[string]{PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { return nil, nil }}) + NewTurnLoop(TurnLoopConfig[string, *schema.Message]{PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { + return nil, nil + }}) }) } func TestNewTurnLoop_NilPrepareAgent_Panics(t *testing.T) { assert.PanicsWithValue(t, "adk: NewTurnLoop: PrepareAgent is required", func() { - NewTurnLoop(TurnLoopConfig[string]{GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + NewTurnLoop(TurnLoopConfig[string, *schema.Message]{GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { return nil, nil }}) }) @@ -4556,14 +4558,14 @@ func TestDeriveChild_NilParent_ReturnsNil(t *testing.T) { func TestUntilIdleFor(t *testing.T) { t.Run("FiresAfterIdleDuration", func(t *testing.T) { turnDone := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -4595,14 +4597,14 @@ func TestUntilIdleFor(t *testing.T) { t.Run("ResetsOnPush", func(t *testing.T) { turnCount := int32(0) turnDone := make(chan struct{}, 10) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -4641,14 +4643,14 @@ func TestUntilIdleFor(t *testing.T) { t.Run("EscalatedByStopWithImmediate", func(t *testing.T) { agentStarted := make(chan *cancelContext, 1) probe := &turnLoopStopModeProbeAgent{ccCh: agentStarted} - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return probe, nil }, }) @@ -4681,14 +4683,14 @@ func TestUntilIdleFor(t *testing.T) { t.Run("EscalatedByStopWithGraceful", func(t *testing.T) { agentStarted := make(chan struct{}) agentDone := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopCancellableMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -4729,14 +4731,14 @@ func TestUntilIdleFor_DoesNotCancelRunningAgent(t *testing.T) { agentCtxCanceled := int32(0) agentDone := make(chan struct{}) - loop := NewTurnLoop(TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopCancellableMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -4773,14 +4775,14 @@ func TestUntilIdleFor_DoesNotCancelRunningAgent(t *testing.T) { agentCtxCanceled := int32(0) agentDone := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopCancellableMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -4818,14 +4820,14 @@ func TestUntilIdleFor_DoesNotCancelRunningAgent(t *testing.T) { agentCtxCanceled := int32(0) agentDone := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopCancellableMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -4860,14 +4862,14 @@ func TestUntilIdleFor_ContextCancelDuringIdleWait(t *testing.T) { turnDone := make(chan struct{}) ctx, cancel := context.WithCancel(context.Background()) - loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -4967,14 +4969,14 @@ func TestAttack_UntilIdleFor_ConcurrentPushDuringIdleTimer(t *testing.T) { turnCount := int32(0) turnDone := make(chan struct{}, 10) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -5015,14 +5017,14 @@ func TestAttack_UntilIdleFor_ConcurrentPushDuringIdleTimer(t *testing.T) { func TestAttack_UntilIdleFor_MultipleStopCallsFirstWins(t *testing.T) { turnDone := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -5056,14 +5058,14 @@ func TestAttack_BareStopOverridesUntilIdleFor(t *testing.T) { agentStarted := make(chan struct{}) agentDone := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -5102,14 +5104,14 @@ func TestAttack_BareStopOverridesUntilIdleFor(t *testing.T) { func TestAttack_StopSignal_NilCancelOptsDoNotDeescalate(t *testing.T) { agentStarted := make(chan *cancelContext, 1) probe := &turnLoopStopModeProbeAgent{ccCh: agentStarted} - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return probe, nil }, }) @@ -5135,14 +5137,14 @@ func TestAttack_StopSignal_NilCancelOptsDoNotDeescalate(t *testing.T) { func TestAttack_CanceledItems_EmptyWhenAgentFinishesNormally(t *testing.T) { agentStarted := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -5207,14 +5209,14 @@ func TestAttack_TurnBuffer_ClearWakeupPreventsSpuriousReturn(t *testing.T) { } func TestAttack_StopBeforeRun_UntilIdleFor_ExitsImmediately(t *testing.T) { - loop := NewTurnLoop(TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := NewTurnLoop(TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{name: "test"}, nil }, }) @@ -5239,14 +5241,14 @@ func TestAttack_StopBeforeRun_UntilIdleFor_ExitsImmediately(t *testing.T) { func TestAttack_PushAfterStop_UntilIdleFor_RoutedToLateItems(t *testing.T) { turnDone := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -5273,14 +5275,14 @@ func TestAttack_PushAfterStop_UntilIdleFor_RoutedToLateItems(t *testing.T) { func TestAttack_ConcurrentStopEscalation_RaceDetector(t *testing.T) { agentStarted := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopCancellableMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -5320,14 +5322,14 @@ func TestAttack_ConcurrentStopEscalation_RaceDetector(t *testing.T) { func TestAttack_StopCause_FirstNonEmptyWins_ConcurrentCallers(t *testing.T) { turnDone := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -5350,14 +5352,14 @@ func TestAttack_StopCause_FirstNonEmptyWins_ConcurrentCallers(t *testing.T) { func TestAttack_SkipCheckpoint_Sticky(t *testing.T) { agentStarted := make(chan struct{}) - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return &turnLoopCancellableMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { @@ -5424,14 +5426,14 @@ func TestTurnLoop_Stop_WithImmediate_RecursivePropagation(t *testing.T) { childCCCh := make(chan *cancelContext, 1) probe := &turnLoopNestedProbeAgent{parentCCCh: parentCCCh, childCCCh: childCCCh} - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return probe, nil }, }) @@ -5472,14 +5474,14 @@ func TestTurnLoop_Push_WithPreemptTimeout_RecursivePropagation(t *testing.T) { childCCCh := make(chan *cancelContext, 2) probe := &turnLoopNestedProbeAgent{parentCCCh: parentCCCh, childCCCh: childCCCh} - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string, *schema.Message]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], items []string) (*GenInputResult[string, *schema.Message], error) { + return &GenInputResult[string, *schema.Message]{ Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, }, nil }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string, *schema.Message], consumed []string) (Agent, error) { return probe, nil }, }) @@ -5490,7 +5492,7 @@ func TestTurnLoop_Push_WithPreemptTimeout_RecursivePropagation(t *testing.T) { t.Cleanup(func() { child.markDone() }) // Preempt with a very short timeout so it escalates to CancelImmediate quickly. - loop.Push("urgent", WithPreemptTimeout[string](AfterChatModel, 10*time.Millisecond)) + loop.Push("urgent", WithPreemptTimeout[string, *schema.Message](AfterChatModel, 10*time.Millisecond)) // After timeout escalation, child should receive the immediate cancel // via recursive propagation. @@ -5515,15 +5517,15 @@ func TestUntilIdleFor_NonPositive_Panics(t *testing.T) { } func TestSaveTurnLoopCheckpoint_NilStore(t *testing.T) { - l := &TurnLoop[string]{config: TurnLoopConfig[string]{Store: nil}} + l := &TurnLoop[string, *schema.Message]{config: TurnLoopConfig[string, *schema.Message]{Store: nil}} err := l.saveTurnLoopCheckpoint(context.Background(), "cp-1", &turnLoopCheckpoint[string]{}) assert.Error(t, err) assert.Contains(t, err.Error(), "checkpoint store is nil") } func TestSetupBridgeStore_NilStore_Resume(t *testing.T) { - l := &TurnLoop[string]{config: TurnLoopConfig[string]{Store: nil}} - spec := &turnRunSpec[string]{isResume: true} + l := &TurnLoop[string, *schema.Message]{config: TurnLoopConfig[string, *schema.Message]{Store: nil}} + spec := &turnRunSpec[string, *schema.Message]{isResume: true} _, _, err := l.setupBridgeStore(spec, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "checkpoint store is nil") diff --git a/adk/utils.go b/adk/utils.go index 89b991324..ec804f728 100644 --- a/adk/utils.go +++ b/adk/utils.go @@ -102,8 +102,7 @@ func GenTransferMessages(_ context.Context, destAgentName string) (Message, Mess return assistantMessage, toolMessage } -// set automatic close for event's message stream -func setAutomaticClose(e *AgentEvent) { +func typedSetAutomaticClose[M messageType](e *TypedAgentEvent[M]) { if e.Output == nil || e.Output.MessageOutput == nil || !e.Output.MessageOutput.IsStreaming { return } @@ -111,10 +110,41 @@ func setAutomaticClose(e *AgentEvent) { e.Output.MessageOutput.MessageStream.SetAutomaticClose() } +// set automatic close for event's message stream +func setAutomaticClose(e *AgentEvent) { + typedSetAutomaticClose(e) +} + // getMessageFromWrappedEvent extracts the message from an AgentEvent. // If the stream contains an error chunk, this function returns (nil, err) and // sets StreamErr to prevent re-consumption. The nil message ensures that // failed stream responses are not included in subsequent agents' context windows. +func getMessageFromTypedWrappedEvent[M messageType](e *typedAgentEventWrapper[M]) (M, error) { + var zero M + if e.event.Output == nil || e.event.Output.MessageOutput == nil { + return zero, nil + } + + if !e.event.Output.MessageOutput.IsStreaming { + return e.event.Output.MessageOutput.Message, nil + } + + if e.StreamErr != nil { + return zero, e.StreamErr + } + + if !isNilMessage(e.concatenatedMessage) { + return e.concatenatedMessage, nil + } + + e.consumeStream() + + if e.StreamErr != nil { + return zero, e.StreamErr + } + return e.concatenatedMessage, nil +} + func getMessageFromWrappedEvent(e *agentEventWrapper) (Message, error) { if e.AgentEvent.Output == nil || e.AgentEvent.Output.MessageOutput == nil { return nil, nil @@ -194,21 +224,21 @@ func (e *agentEventWrapper) consumeStream() { e.AgentEvent.Output.MessageOutput.MessageStream = schema.StreamReaderFromArray([]Message{e.concatenatedMessage}) } -// copyAgentEvent copies an AgentEvent. +// copyTypedAgentEvent copies a TypedAgentEvent. // If the MessageVariant is streaming, the MessageStream will be copied. // RunPath will be deep copied. -// The result of Copy will be a new AgentEvent that is: -// - safe to set fields of AgentEvent +// The result of Copy will be a new TypedAgentEvent that is: +// - safe to set fields of TypedAgentEvent // - safe to extend RunPath // - safe to receive from MessageStream -// NOTE: even if the AgentEvent is copied, it's still not recommended to modify +// NOTE: even if the event is copied, it's still not recommended to modify // the Message itself or Chunks of the MessageStream, as they are not copied. // NOTE: if you have CustomizedOutput or CustomizedAction, they are NOT copied. -func copyAgentEvent(ae *AgentEvent) *AgentEvent { +func copyTypedAgentEvent[M messageType](ae *TypedAgentEvent[M]) *TypedAgentEvent[M] { rp := make([]RunStep, len(ae.RunPath)) copy(rp, ae.RunPath) - copied := &AgentEvent{ + copied := &TypedAgentEvent[M]{ AgentName: ae.AgentName, RunPath: rp, Action: ae.Action, @@ -219,7 +249,7 @@ func copyAgentEvent(ae *AgentEvent) *AgentEvent { return copied } - copied.Output = &AgentOutput{ + copied.Output = &TypedAgentOutput[M]{ CustomizedOutput: ae.Output.CustomizedOutput, } @@ -228,7 +258,7 @@ func copyAgentEvent(ae *AgentEvent) *AgentEvent { return copied } - copied.Output.MessageOutput = &MessageVariant{ + copied.Output.MessageOutput = &TypedMessageVariant[M]{ IsStreaming: mv.IsStreaming, Role: mv.Role, ToolName: mv.ToolName, @@ -244,11 +274,11 @@ func copyAgentEvent(ae *AgentEvent) *AgentEvent { return copied } -// GetMessage extracts the Message from an AgentEvent. For streaming output, -// it duplicates the stream and concatenates it into a single Message. -func GetMessage(e *AgentEvent) (Message, *AgentEvent, error) { +// TypedGetMessage extracts the message from a TypedAgentEvent, concatenating a stream if present. +func TypedGetMessage[M messageType](e *TypedAgentEvent[M]) (M, *TypedAgentEvent[M], error) { + var zero M if e.Output == nil || e.Output.MessageOutput == nil { - return nil, e, nil + return zero, e, nil } msgOutput := e.Output.MessageOutput @@ -256,7 +286,7 @@ func GetMessage(e *AgentEvent) (Message, *AgentEvent, error) { ss := msgOutput.MessageStream.Copy(2) e.Output.MessageOutput.MessageStream = ss[0] - msg, err := schema.ConcatMessageStream(ss[1]) + msg, err := concatMessageStream(ss[1]) return msg, e, err } @@ -264,9 +294,19 @@ func GetMessage(e *AgentEvent) (Message, *AgentEvent, error) { return msgOutput.Message, e, nil } -func genErrorIter(err error) *AsyncIterator[*AgentEvent] { - iterator, generator := NewAsyncIteratorPair[*AgentEvent]() - generator.Send(&AgentEvent{Err: err}) +// GetMessage extracts the Message from an AgentEvent. For streaming output, +// it duplicates the stream and concatenates it into a single Message. +func GetMessage(e *AgentEvent) (Message, *AgentEvent, error) { + return TypedGetMessage(e) +} + +func typedErrorIter[M messageType](err error) *AsyncIterator[*TypedAgentEvent[M]] { + iterator, generator := NewAsyncIteratorPair[*TypedAgentEvent[M]]() + generator.Send(&TypedAgentEvent[M]{Err: err}) generator.Close() return iterator } + +func genErrorIter(err error) *AsyncIterator[*AgentEvent] { + return typedErrorIter[*schema.Message](err) +} diff --git a/adk/workflow_test.go b/adk/workflow_test.go index 298bef5c7..3392187a6 100644 --- a/adk/workflow_test.go +++ b/adk/workflow_test.go @@ -1021,7 +1021,7 @@ func TestWorkflowAgentUnsupportedMode(t *testing.T) { name: "UnsupportedModeAgent", description: "Agent with unsupported mode", subAgents: []*flowAgent{}, - mode: workflowAgentMode(999), // Invalid mode + mode: workflowAgentMode(999), } // Run the agent and expect error diff --git a/adk/wrappers.go b/adk/wrappers.go index ce50d7baa..a464d84f7 100644 --- a/adk/wrappers.go +++ b/adk/wrappers.go @@ -32,11 +32,11 @@ import ( "github.com/cloudwego/eino/schema" ) -type generateEndpoint func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) -type streamEndpoint func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) +type typedGenerateEndpoint[M messageType] func(ctx context.Context, input []M, opts ...model.Option) (M, error) +type typedStreamEndpoint[M messageType] func(ctx context.Context, input []M, opts ...model.Option) (*schema.StreamReader[M], error) -type modelWrapperConfig struct { - handlers []ChatModelAgentMiddleware +type typedModelWrapperConfig[M messageType] struct { + handlers []TypedChatModelAgentMiddleware[M] middlewares []AgentMiddleware retryConfig *ModelRetryConfig failoverConfig *ModelFailoverConfig @@ -44,19 +44,24 @@ type modelWrapperConfig struct { cancelContext *cancelContext } -func buildModelWrappers(m model.BaseChatModel, config *modelWrapperConfig) model.BaseChatModel { - var wrapped model.BaseChatModel = m +type modelWrapperConfig = typedModelWrapperConfig[*schema.Message] + +func buildModelWrappers[M messageType](m model.BaseModel[M], config *typedModelWrapperConfig[M]) model.BaseModel[M] { + return buildModelWrappersImpl(m, config) +} + +func buildModelWrappersImpl[M messageType](m model.BaseModel[M], config *typedModelWrapperConfig[M]) model.BaseModel[M] { + var wrapped model.BaseModel[M] = m - // failoverProxyModel must be the innermost wrapper to read the selected failover model from context. if config.failoverConfig != nil { - wrapped = &failoverProxyModel{} + wrapped = &typedFailoverProxyModel[M]{} } if !components.IsCallbacksEnabled(wrapped) { - wrapped = (&callbackInjectionModelWrapper{}).WrapModel(wrapped) + wrapped = typedCallbackInjectionModelWrapper[M]{}.wrapModel(wrapped) } - wrapped = &stateModelWrapper{ + wrapped = &typedStateModelWrapper[M]{ inner: wrapped, original: m, handlers: config.handlers, @@ -70,28 +75,29 @@ func buildModelWrappers(m model.BaseChatModel, config *modelWrapperConfig) model return wrapped } -type callbackInjectionModelWrapper struct{} +type typedCallbackInjectionModelWrapper[M messageType] struct{} -func (w *callbackInjectionModelWrapper) WrapModel(m model.BaseChatModel) model.BaseChatModel { - return &callbackInjectedModel{inner: m} +func (w typedCallbackInjectionModelWrapper[M]) wrapModel(m model.BaseModel[M]) model.BaseModel[M] { + return &typedCallbackInjectedModel[M]{inner: m} } -type callbackInjectedModel struct { - inner model.BaseChatModel +type typedCallbackInjectedModel[M messageType] struct { + inner model.BaseModel[M] } -func (m *callbackInjectedModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { +func (m *typedCallbackInjectedModel[M]) Generate(ctx context.Context, input []M, opts ...model.Option) (M, error) { ctx = callbacks.OnStart(ctx, input) result, err := m.inner.Generate(ctx, input, opts...) if err != nil { callbacks.OnError(ctx, err) - return nil, err + var zero M + return zero, err } callbacks.OnEnd(ctx, result) return result, nil } -func (m *callbackInjectedModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { +func (m *typedCallbackInjectedModel[M]) Stream(ctx context.Context, input []M, opts ...model.Option) (*schema.StreamReader[M], error) { ctx = callbacks.OnStart(ctx, input) result, err := m.inner.Stream(ctx, input, opts...) if err != nil { @@ -102,7 +108,7 @@ func (m *callbackInjectedModel) Stream(ctx context.Context, input []*schema.Mess return wrappedStream, nil } -func handlersToToolMiddlewares(handlers []ChatModelAgentMiddleware) []compose.ToolMiddleware { +func handlersToToolMiddlewares[M messageType](handlers []TypedChatModelAgentMiddleware[M]) []compose.ToolMiddleware { var middlewares []compose.ToolMiddleware for i := len(handlers) - 1; i >= 0; i-- { handler := handlers[i] @@ -247,25 +253,21 @@ func handlersToToolMiddlewares(handlers []ChatModelAgentMiddleware) []compose.To return middlewares } -type eventSenderModelWrapper struct { - *BaseChatModelAgentMiddleware +type typedEventSenderModelWrapper[M messageType] struct { + *TypedBaseChatModelAgentMiddleware[M] } -// NewEventSenderModelWrapper returns a ChatModelAgentMiddleware that sends model response events. -// By default, the framework applies this wrapper after all user middlewares, so events contain -// modified messages. To send events with original (unmodified) output, pass this as a Handler -// after the modifying middleware (placing it innermost in the wrapper chain). -// When detected in Handlers, the framework skips the default event sender to avoid duplicates. +// NewEventSenderModelWrapper creates a ChatModelAgentMiddleware that sends model output as agent events. func NewEventSenderModelWrapper() ChatModelAgentMiddleware { - return &eventSenderModelWrapper{ - BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, + return &typedEventSenderModelWrapper[*schema.Message]{ + TypedBaseChatModelAgentMiddleware: &TypedBaseChatModelAgentMiddleware[*schema.Message]{}, } } -func (w *eventSenderModelWrapper) WrapModel(_ context.Context, m model.BaseChatModel, mc *ModelContext) (model.BaseChatModel, error) { +func (w *typedEventSenderModelWrapper[M]) WrapModel(_ context.Context, m model.BaseModel[M], mc *ModelContext) (model.BaseModel[M], error) { inner := m if mc != nil && mc.cancelContext != nil { - inner = &cancelMonitoredModel{ + inner = &typedCancelMonitoredModel[M]{ inner: inner, cancelContext: mc.cancelContext, } @@ -278,43 +280,44 @@ func (w *eventSenderModelWrapper) WrapModel(_ context.Context, m model.BaseChatM if mc != nil { failoverConfig = mc.ModelFailoverConfig } - return &eventSenderModel{inner: inner, modelRetryConfig: retryConfig, modelFailoverConfig: failoverConfig}, nil + return &typedEventSenderModel[M]{inner: inner, modelRetryConfig: retryConfig, modelFailoverConfig: failoverConfig}, nil } -type eventSenderModel struct { - inner model.BaseChatModel +type typedEventSenderModel[M messageType] struct { + inner model.BaseModel[M] modelRetryConfig *ModelRetryConfig modelFailoverConfig *ModelFailoverConfig } -func (m *eventSenderModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { +func (m *typedEventSenderModel[M]) Generate(ctx context.Context, input []M, opts ...model.Option) (M, error) { result, err := m.inner.Generate(ctx, input, opts...) if err != nil { - return nil, err + var zero M + return zero, err } - execCtx := getChatModelAgentExecCtx(ctx) + execCtx := getTypedChatModelAgentExecCtx[M](ctx) if execCtx != nil && execCtx.suppressEventSend { return result, nil } if execCtx == nil || execCtx.generator == nil { - return nil, errors.New("generator is nil when sending event in Generate: ensure agent state is properly initialized") + var zero M + return zero, errors.New("generator is nil when sending event in Generate: ensure agent state is properly initialized") } - msgCopy := *result - event := EventFromMessage(&msgCopy, nil, schema.Assistant, "") + event := typedModelOutputEvent(copyMessage(result), nil) execCtx.send(event) return result, nil } -func (m *eventSenderModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { +func (m *typedEventSenderModel[M]) Stream(ctx context.Context, input []M, opts ...model.Option) (*schema.StreamReader[M], error) { result, err := m.inner.Stream(ctx, input, opts...) if err != nil { return nil, err } - execCtx := getChatModelAgentExecCtx(ctx) + execCtx := getTypedChatModelAgentExecCtx[M](ctx) if execCtx == nil || execCtx.generator == nil { result.Close() return nil, errors.New("generator is nil when sending event in Stream: ensure agent state is properly initialized") @@ -325,11 +328,12 @@ func (m *eventSenderModel) Stream(ctx context.Context, input []*schema.Message, eventStream := streams[0] if convertOpts := m.buildStreamConvertOptions(ctx); len(convertOpts) > 0 { eventStream = schema.StreamReaderWithConvert(streams[0], - func(msg *schema.Message) (*schema.Message, error) { return msg, nil }, + func(msg M) (M, error) { return msg, nil }, convertOpts...) } - event := EventFromMessage(nil, eventStream, schema.Assistant, "") + var zero M + event := typedModelOutputEvent[M](zero, eventStream) execCtx.send(event) return streams[1], nil @@ -354,9 +358,9 @@ func (m *eventSenderModel) Stream(ctx context.Context, input []*schema.Message, // This prevents a goroutine leak when a mid-stream error is followed by EOF: errWrapper fires // first (caching the verdict), and onEOF reuses the cached value instead of blocking on a // drained channel. -func (m *eventSenderModel) buildStreamConvertOptions(ctx context.Context) []schema.ConvertOption { +func (m *typedEventSenderModel[M]) buildStreamConvertOptions(ctx context.Context) []schema.ConvertOption { var retryAttempt int - _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + _ = compose.ProcessState(ctx, func(_ context.Context, st *typedState[M]) error { retryAttempt = st.getRetryAttempt() return nil }) @@ -375,7 +379,7 @@ func (m *eventSenderModel) buildStreamConvertOptions(ctx context.Context) []sche var retryWrapper func(error) error if m.modelRetryConfig != nil { if m.modelRetryConfig.ShouldRetry != nil { - execCtx := getChatModelAgentExecCtx(ctx) + execCtx := getTypedChatModelAgentExecCtx[*schema.Message](ctx) signal := (*retryVerdictSignal)(nil) if execCtx != nil { signal = execCtx.retryVerdictSignal @@ -460,11 +464,24 @@ func (m *eventSenderModel) buildStreamConvertOptions(ctx context.Context) []sche return opts } -func popToolGenAction(ctx context.Context, toolName string) *AgentAction { +func copyMessage[M messageType](msg M) M { + switch v := any(msg).(type) { + case *schema.Message: + cp := *v + return any(&cp).(M) + case *schema.AgenticMessage: + cp := *v + return any(&cp).(M) + default: + return msg + } +} + +func typedPopToolGenAction[M messageType](ctx context.Context, toolName string) *AgentAction { toolCallID := compose.GetToolCallID(ctx) var action *AgentAction - _ = compose.ProcessState(ctx, func(ctx context.Context, st *State) error { + _ = compose.ProcessState(ctx, func(ctx context.Context, st *typedState[M]) error { if len(toolCallID) > 0 { if a := st.popToolGenAction(toolCallID); a != nil { action = a @@ -482,10 +499,23 @@ func popToolGenAction(ctx context.Context, toolName string) *AgentAction { return action } +func popToolGenAction(ctx context.Context, toolName string) *AgentAction { + return typedPopToolGenAction[*schema.Message](ctx, toolName) +} + type eventSenderToolWrapper struct { *BaseChatModelAgentMiddleware } +func (*eventSenderToolWrapper) isEventSenderToolWrapper() {} + +// eventSenderToolWrapperMarker enables cross-type detection of eventSenderToolWrapper +// in generic contexts. hasUserEventSenderToolWrapper[M] receives +// []TypedChatModelAgentMiddleware[M], so when M is *schema.AgenticMessage, a direct +// type assertion to *eventSenderToolWrapper (which implements the *schema.Message alias) +// would fail. The marker interface bridges this gap. +type eventSenderToolWrapperMarker interface{ isEventSenderToolWrapper() } + // NewEventSenderToolWrapper returns a ChatModelAgentMiddleware that sends tool result events. // By default, the framework places this before all user middlewares (outermost), so events // reflect the fully processed tool output. To control exactly where events are emitted, @@ -514,7 +544,7 @@ func (w *eventSenderToolWrapper) WrapInvokableToolCall(_ context.Context, endpoi event.Action = prePopAction } - execCtx := getChatModelAgentExecCtx(ctx) + execCtx := getTypedChatModelAgentExecCtx[*schema.Message](ctx) _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { if st.getReturnDirectlyToolCallID() == callID { st.setReturnDirectlyEvent(event) @@ -548,7 +578,7 @@ func (w *eventSenderToolWrapper) WrapStreamableToolCall(_ context.Context, endpo event := EventFromMessage(nil, msgStream, schema.Tool, toolName) event.Action = prePopAction - execCtx := getChatModelAgentExecCtx(ctx) + execCtx := getTypedChatModelAgentExecCtx[*schema.Message](ctx) _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { if st.getReturnDirectlyToolCallID() == callID { st.setReturnDirectlyEvent(event) @@ -583,7 +613,7 @@ func (w *eventSenderToolWrapper) WrapEnhancedInvokableToolCall(_ context.Context event.Action = prePopAction } - execCtx := getChatModelAgentExecCtx(ctx) + execCtx := getTypedChatModelAgentExecCtx[*schema.Message](ctx) _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { if st.getReturnDirectlyToolCallID() == callID { st.setReturnDirectlyEvent(event) @@ -623,7 +653,7 @@ func (w *eventSenderToolWrapper) WrapEnhancedStreamableToolCall(_ context.Contex event := EventFromMessage(nil, msgStream, schema.Tool, toolName) event.Action = prePopAction - execCtx := getChatModelAgentExecCtx(ctx) + execCtx := getTypedChatModelAgentExecCtx[*schema.Message](ctx) _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { if st.getReturnDirectlyToolCallID() == callID { st.setReturnDirectlyEvent(event) @@ -637,19 +667,19 @@ func (w *eventSenderToolWrapper) WrapEnhancedStreamableToolCall(_ context.Contex }, nil } -func hasUserEventSenderToolWrapper(handlers []ChatModelAgentMiddleware) bool { +func hasUserEventSenderToolWrapper[M messageType](handlers []TypedChatModelAgentMiddleware[M]) bool { for _, handler := range handlers { - if _, ok := handler.(*eventSenderToolWrapper); ok { + if _, ok := any(handler).(eventSenderToolWrapperMarker); ok { return true } } return false } -type stateModelWrapper struct { - inner model.BaseChatModel - original model.BaseChatModel - handlers []ChatModelAgentMiddleware +type typedStateModelWrapper[M messageType] struct { + inner model.BaseModel[M] + original model.BaseModel[M] + handlers []TypedChatModelAgentMiddleware[M] middlewares []AgentMiddleware toolInfos []*schema.ToolInfo modelRetryConfig *ModelRetryConfig @@ -657,27 +687,29 @@ type stateModelWrapper struct { cancelContext *cancelContext } -func (w *stateModelWrapper) IsCallbacksEnabled() bool { +type stateModelWrapper = typedStateModelWrapper[*schema.Message] + +func (w *typedStateModelWrapper[M]) IsCallbacksEnabled() bool { return true } -func (w *stateModelWrapper) GetType() string { - if typer, ok := w.original.(components.Typer); ok { +func (w *typedStateModelWrapper[M]) GetType() string { + if typer, ok := any(w.original).(components.Typer); ok { return typer.GetType() } return generic.ParseTypeName(reflect.ValueOf(w.original)) } -func (w *stateModelWrapper) hasUserEventSender() bool { +func (w *typedStateModelWrapper[M]) hasUserEventSender() bool { for _, handler := range w.handlers { - if _, ok := handler.(*eventSenderModelWrapper); ok { + if _, ok := any(handler).(*typedEventSenderModelWrapper[M]); ok { return true } } return false } -func (w *stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) generateEndpoint { +func (w *typedStateModelWrapper[M]) wrapGenerateEndpoint(endpoint typedGenerateEndpoint[M]) typedGenerateEndpoint[M] { hasUserEventSender := w.hasUserEventSender() retryConfig := w.modelRetryConfig failoverConfig := w.modelFailoverConfig @@ -687,13 +719,14 @@ func (w *stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) gene handler := w.handlers[i] innerEndpoint := endpoint baseToolInfos := w.toolInfos - endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + endpoint = func(ctx context.Context, input []M, opts ...model.Option) (M, error) { baseOpts := &model.Options{Tools: baseToolInfos} commonOpts := model.GetCommonOptions(baseOpts, opts...) mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: retryConfig, cancelContext: cc} - wrappedModel, err := handler.WrapModel(ctx, &endpointModel{generate: innerEndpoint}, mc) + wrappedModel, err := handler.WrapModel(ctx, &typedEndpointModel[M]{generate: innerEndpoint}, mc) if err != nil { - return nil, err + var zero M + return zero, err } return wrappedModel.Generate(ctx, input, opts...) } @@ -701,16 +734,19 @@ func (w *stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) gene if !hasUserEventSender { innerEndpoint := endpoint - eventSender := NewEventSenderModelWrapper() - endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { - execCtx := getChatModelAgentExecCtx(ctx) + eventSender := &typedEventSenderModelWrapper[M]{ + TypedBaseChatModelAgentMiddleware: &TypedBaseChatModelAgentMiddleware[M]{}, + } + endpoint = func(ctx context.Context, input []M, opts ...model.Option) (M, error) { + execCtx := getTypedChatModelAgentExecCtx[M](ctx) if execCtx == nil || execCtx.generator == nil { return innerEndpoint(ctx, input, opts...) } mc := &ModelContext{ModelRetryConfig: retryConfig, ModelFailoverConfig: failoverConfig, cancelContext: cc} - wrappedModel, err := eventSender.WrapModel(ctx, &endpointModel{generate: innerEndpoint}, mc) + wrappedModel, err := eventSender.WrapModel(ctx, &typedEndpointModel[M]{generate: innerEndpoint}, mc) if err != nil { - return nil, err + var zero M + return zero, err } return wrappedModel.Generate(ctx, input, opts...) } @@ -718,18 +754,17 @@ func (w *stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) gene if w.modelRetryConfig != nil { innerEndpoint := endpoint - endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { - retryWrapper := newRetryModelWrapper(&endpointModel{generate: innerEndpoint}, w.modelRetryConfig) + endpoint = func(ctx context.Context, input []M, opts ...model.Option) (M, error) { + retryWrapper := newTypedRetryModelWrapper[M](&typedEndpointModel[M]{generate: innerEndpoint}, w.modelRetryConfig) return retryWrapper.Generate(ctx, input, opts...) } } - // Needs to handle failoverWrapper after retryWrapper if w.modelFailoverConfig != nil { config := w.modelFailoverConfig innerEndpoint := endpoint - endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { - failoverWrapper := newFailoverModelWrapper(&endpointModel{generate: innerEndpoint}, config) + endpoint = func(ctx context.Context, input []M, opts ...model.Option) (M, error) { + failoverWrapper := newTypedFailoverModelWrapper[M](&typedEndpointModel[M]{generate: innerEndpoint}, config) return failoverWrapper.Generate(ctx, input, opts...) } } @@ -737,7 +772,7 @@ func (w *stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) gene return endpoint } -func (w *stateModelWrapper) wrapStreamEndpoint(endpoint streamEndpoint) streamEndpoint { +func (w *typedStateModelWrapper[M]) wrapStreamEndpoint(endpoint typedStreamEndpoint[M]) typedStreamEndpoint[M] { hasUserEventSender := w.hasUserEventSender() retryConfig := w.modelRetryConfig failoverConfig := w.modelFailoverConfig @@ -747,11 +782,11 @@ func (w *stateModelWrapper) wrapStreamEndpoint(endpoint streamEndpoint) streamEn handler := w.handlers[i] innerEndpoint := endpoint baseToolInfos := w.toolInfos - endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + endpoint = func(ctx context.Context, input []M, opts ...model.Option) (*schema.StreamReader[M], error) { baseOpts := &model.Options{Tools: baseToolInfos} commonOpts := model.GetCommonOptions(baseOpts, opts...) mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: retryConfig, cancelContext: cc} - wrappedModel, err := handler.WrapModel(ctx, &endpointModel{stream: innerEndpoint}, mc) + wrappedModel, err := handler.WrapModel(ctx, &typedEndpointModel[M]{stream: innerEndpoint}, mc) if err != nil { return nil, err } @@ -761,14 +796,16 @@ func (w *stateModelWrapper) wrapStreamEndpoint(endpoint streamEndpoint) streamEn if !hasUserEventSender { innerEndpoint := endpoint - eventSender := NewEventSenderModelWrapper() - endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { - execCtx := getChatModelAgentExecCtx(ctx) + eventSender := &typedEventSenderModelWrapper[M]{ + TypedBaseChatModelAgentMiddleware: &TypedBaseChatModelAgentMiddleware[M]{}, + } + endpoint = func(ctx context.Context, input []M, opts ...model.Option) (*schema.StreamReader[M], error) { + execCtx := getTypedChatModelAgentExecCtx[M](ctx) if execCtx == nil || execCtx.generator == nil { return innerEndpoint(ctx, input, opts...) } mc := &ModelContext{ModelRetryConfig: retryConfig, ModelFailoverConfig: failoverConfig, cancelContext: cc} - wrappedModel, err := eventSender.WrapModel(ctx, &endpointModel{stream: innerEndpoint}, mc) + wrappedModel, err := eventSender.WrapModel(ctx, &typedEndpointModel[M]{stream: innerEndpoint}, mc) if err != nil { return nil, err } @@ -778,18 +815,17 @@ func (w *stateModelWrapper) wrapStreamEndpoint(endpoint streamEndpoint) streamEn if w.modelRetryConfig != nil { innerEndpoint := endpoint - endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { - retryWrapper := newRetryModelWrapper(&endpointModel{stream: innerEndpoint}, w.modelRetryConfig) + endpoint = func(ctx context.Context, input []M, opts ...model.Option) (*schema.StreamReader[M], error) { + retryWrapper := newTypedRetryModelWrapper[M](&typedEndpointModel[M]{stream: innerEndpoint}, w.modelRetryConfig) return retryWrapper.Stream(ctx, input, opts...) } } - // Needs to handle failoverWrapper after retryWrapper if w.modelFailoverConfig != nil { config := w.modelFailoverConfig innerEndpoint := endpoint - endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { - failoverWrapper := newFailoverModelWrapper(&endpointModel{stream: innerEndpoint}, config) + endpoint = func(ctx context.Context, input []M, opts ...model.Option) (*schema.StreamReader[M], error) { + failoverWrapper := newTypedFailoverModelWrapper[M](&typedEndpointModel[M]{stream: innerEndpoint}, config) return failoverWrapper.Stream(ctx, input, opts...) } } @@ -797,19 +833,22 @@ func (w *stateModelWrapper) wrapStreamEndpoint(endpoint streamEndpoint) streamEn return endpoint } -func (w *stateModelWrapper) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { - var stateMessages []Message - _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { +func (w *typedStateModelWrapper[M]) Generate(ctx context.Context, input []M, opts ...model.Option) (M, error) { + var stateMessages []M + _ = compose.ProcessState(ctx, func(_ context.Context, st *typedState[M]) error { stateMessages = st.Messages return nil }) - state := &ChatModelAgentState{Messages: stateMessages} + state := &TypedChatModelAgentState[M]{Messages: stateMessages} - for _, m := range w.middlewares { - if m.BeforeChatModel != nil { - if err := m.BeforeChatModel(ctx, state); err != nil { - return nil, err + if msgState, ok := any(state).(*ChatModelAgentState); ok { + for _, m := range w.middlewares { + if m.BeforeChatModel != nil { + if err := m.BeforeChatModel(ctx, msgState); err != nil { + var zero M + return zero, err + } } } } @@ -821,11 +860,12 @@ func (w *stateModelWrapper) Generate(ctx context.Context, input []*schema.Messag var err error ctx, state, err = handler.BeforeModelRewriteState(ctx, state, mc) if err != nil { - return nil, err + var zero M + return zero, err } } - _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + _ = compose.ProcessState(ctx, func(_ context.Context, st *typedState[M]) error { st.Messages = state.Messages return nil }) @@ -833,14 +873,15 @@ func (w *stateModelWrapper) Generate(ctx context.Context, input []*schema.Messag wrappedEndpoint := w.wrapGenerateEndpoint(w.inner.Generate) result, err := wrappedEndpoint(ctx, state.Messages, opts...) if err != nil { - return nil, err + var zero M + return zero, err } // Re-read State.Messages after Generate completes: when ShouldRetry uses // PersistModifiedInputMessages, applyDecisionForRetry writes modified messages to State. // We must pick up those changes before appending the model result. if w.modelRetryConfig != nil && w.modelRetryConfig.ShouldRetry != nil { - _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + _ = compose.ProcessState(ctx, func(_ context.Context, st *typedState[M]) error { state.Messages = st.Messages return nil }) @@ -851,42 +892,49 @@ func (w *stateModelWrapper) Generate(ctx context.Context, input []*schema.Messag for _, handler := range w.handlers { ctx, state, err = handler.AfterModelRewriteState(ctx, state, mc) if err != nil { - return nil, err + var zero M + return zero, err } } - for _, m := range w.middlewares { - if m.AfterChatModel != nil { - if err := m.AfterChatModel(ctx, state); err != nil { - return nil, err + if msgState, ok := any(state).(*ChatModelAgentState); ok { + for _, m := range w.middlewares { + if m.AfterChatModel != nil { + if err := m.AfterChatModel(ctx, msgState); err != nil { + var zero M + return zero, err + } } } } - _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + _ = compose.ProcessState(ctx, func(_ context.Context, st *typedState[M]) error { st.Messages = state.Messages return nil }) if len(state.Messages) == 0 { - return nil, errors.New("no messages left in state after model call") + var zero M + return zero, errors.New("no messages left in state after model call") } return state.Messages[len(state.Messages)-1], nil } -func (w *stateModelWrapper) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { - var stateMessages []Message - _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { +func (w *typedStateModelWrapper[M]) Stream(ctx context.Context, input []M, opts ...model.Option) (*schema.StreamReader[M], error) { + var stateMessages []M + _ = compose.ProcessState(ctx, func(_ context.Context, st *typedState[M]) error { stateMessages = st.Messages return nil }) - state := &ChatModelAgentState{Messages: stateMessages} + state := &TypedChatModelAgentState[M]{Messages: stateMessages} - for _, m := range w.middlewares { - if m.BeforeChatModel != nil { - if err := m.BeforeChatModel(ctx, state); err != nil { - return nil, err + if msgState, ok := any(state).(*ChatModelAgentState); ok { + for _, m := range w.middlewares { + if m.BeforeChatModel != nil { + if err := m.BeforeChatModel(ctx, msgState); err != nil { + return nil, err + } } } } @@ -902,7 +950,7 @@ func (w *stateModelWrapper) Stream(ctx context.Context, input []*schema.Message, } } - _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + _ = compose.ProcessState(ctx, func(_ context.Context, st *typedState[M]) error { st.Messages = state.Messages return nil }) @@ -912,14 +960,14 @@ func (w *stateModelWrapper) Stream(ctx context.Context, input []*schema.Message, if err != nil { return nil, err } - result, err := schema.ConcatMessageStream(stream) + result, err := concatMessageStream(stream) if err != nil { return nil, err } // Re-read State.Messages after Stream completes: same rationale as in Generate above. if w.modelRetryConfig != nil && w.modelRetryConfig.ShouldRetry != nil { - _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + _ = compose.ProcessState(ctx, func(_ context.Context, st *typedState[M]) error { state.Messages = st.Messages return nil }) @@ -934,15 +982,17 @@ func (w *stateModelWrapper) Stream(ctx context.Context, input []*schema.Message, } } - for _, m := range w.middlewares { - if m.AfterChatModel != nil { - if err := m.AfterChatModel(ctx, state); err != nil { - return nil, err + if msgState, ok := any(state).(*ChatModelAgentState); ok { + for _, m := range w.middlewares { + if m.AfterChatModel != nil { + if err := m.AfterChatModel(ctx, msgState); err != nil { + return nil, err + } } } } - _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + _ = compose.ProcessState(ctx, func(_ context.Context, st *typedState[M]) error { st.Messages = state.Messages return nil }) @@ -950,22 +1000,23 @@ func (w *stateModelWrapper) Stream(ctx context.Context, input []*schema.Message, if len(state.Messages) == 0 { return nil, errors.New("no messages left in state after model call") } - return schema.StreamReaderFromArray([]*schema.Message{state.Messages[len(state.Messages)-1]}), nil + return schema.StreamReaderFromArray([]M{state.Messages[len(state.Messages)-1]}), nil } -type endpointModel struct { - generate generateEndpoint - stream streamEndpoint +type typedEndpointModel[M messageType] struct { + generate typedGenerateEndpoint[M] + stream typedStreamEndpoint[M] } -func (m *endpointModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { +func (m *typedEndpointModel[M]) Generate(ctx context.Context, input []M, opts ...model.Option) (M, error) { if m.generate != nil { return m.generate(ctx, input, opts...) } - return nil, errors.New("generate endpoint not set") + var zero M + return zero, errors.New("generate endpoint not set") } -func (m *endpointModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { +func (m *typedEndpointModel[M]) Stream(ctx context.Context, input []M, opts ...model.Option) (*schema.StreamReader[M], error) { if m.stream != nil { return m.stream(ctx, input, opts...) } diff --git a/adk/wrappers_failover_test.go b/adk/wrappers_failover_test.go index 8b14463e1..92a68fe9b 100644 --- a/adk/wrappers_failover_test.go +++ b/adk/wrappers_failover_test.go @@ -22,6 +22,7 @@ import ( "sync/atomic" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/cloudwego/eino/components/model" @@ -47,7 +48,7 @@ func TestBuildModelWrappers_FailoverProxyInner(t *testing.T) { }, } - wrapped := buildModelWrappers(base, &modelWrapperConfig{ + wrapped := buildModelWrappers[*schema.Message](base, &modelWrapperConfig{ failoverConfig: failoverCfg, }) @@ -101,11 +102,11 @@ func TestStateModelWrapper_Generate_WithFailover(t *testing.T) { }, } - wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + wrapped := buildModelWrappers[*schema.Message](m1, &modelWrapperConfig{ failoverConfig: failoverCfg, }) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + ctx := withTypedChatModelAgentExecCtx[*schema.Message](context.Background(), &chatModelAgentExecCtx{ failoverLastSuccessModel: m1, }) got, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) @@ -162,11 +163,11 @@ func TestStateModelWrapper_Stream_WithFailover(t *testing.T) { }, } - wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + wrapped := buildModelWrappers[*schema.Message](m1, &modelWrapperConfig{ failoverConfig: failoverCfg, }) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + ctx := withTypedChatModelAgentExecCtx[*schema.Message](context.Background(), &chatModelAgentExecCtx{ failoverLastSuccessModel: m1, }) sr, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) @@ -179,3 +180,36 @@ func TestStateModelWrapper_Stream_WithFailover(t *testing.T) { require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) } + +func TestFailoverAcceptsAgenticAgent(t *testing.T) { + ctx := context.Background() + + m := &mockAgenticModel{ + generateFn: func(ctx context.Context, input []*schema.AgenticMessage, opts ...model.Option) (*schema.AgenticMessage, error) { + return agenticMsg("ok"), nil + }, + } + + fallbackModel := &mockChatModelForAttack{ + generateFn: func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + return schema.AssistantMessage("fallback", nil), nil + }, + } + + agent, err := NewTypedChatModelAgent[*schema.AgenticMessage](ctx, &TypedChatModelAgentConfig[*schema.AgenticMessage]{ + Name: "FailoverAgent", + Description: "Agent with failover config", + Model: m, + ModelFailoverConfig: &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(ctx context.Context, outputMessage *schema.Message, outputErr error) bool { + return true + }, + GetFailoverModel: func(ctx context.Context, failoverCtx *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return fallbackModel, nil, nil + }, + }, + }) + require.NoError(t, err) + assert.NotNil(t, agent) +} diff --git a/adk/wrappers_retry_failover_test.go b/adk/wrappers_retry_failover_test.go index 29c4b495a..4e8f05d76 100644 --- a/adk/wrappers_retry_failover_test.go +++ b/adk/wrappers_retry_failover_test.go @@ -78,12 +78,12 @@ func TestRetryThenFailover(t *testing.T) { }, } - wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + wrapped := buildModelWrappers[*schema.Message](m1, &modelWrapperConfig{ retryConfig: retryCfg, failoverConfig: failoverCfg, }) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + ctx := withTypedChatModelAgentExecCtx[*schema.Message](context.Background(), &chatModelAgentExecCtx{ failoverLastSuccessModel: m1, }) msg, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) @@ -126,12 +126,12 @@ func TestRetryThenFailover(t *testing.T) { }, } - wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + wrapped := buildModelWrappers[*schema.Message](m1, &modelWrapperConfig{ retryConfig: retryCfg, failoverConfig: failoverCfg, }) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + ctx := withTypedChatModelAgentExecCtx[*schema.Message](context.Background(), &chatModelAgentExecCtx{ failoverLastSuccessModel: m1, }) _, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) @@ -177,12 +177,12 @@ func TestRetryThenFailover(t *testing.T) { }, } - wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + wrapped := buildModelWrappers[*schema.Message](m1, &modelWrapperConfig{ retryConfig: retryCfg, failoverConfig: failoverCfg, }) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + ctx := withTypedChatModelAgentExecCtx[*schema.Message](context.Background(), &chatModelAgentExecCtx{ failoverLastSuccessModel: m1, }) msg, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) @@ -228,12 +228,12 @@ func TestRetryThenFailover(t *testing.T) { }, } - wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + wrapped := buildModelWrappers[*schema.Message](m1, &modelWrapperConfig{ retryConfig: retryCfg, failoverConfig: failoverCfg, }) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + ctx := withTypedChatModelAgentExecCtx[*schema.Message](context.Background(), &chatModelAgentExecCtx{ failoverLastSuccessModel: m1, }) msg, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) @@ -278,12 +278,12 @@ func TestRetryThenFailover(t *testing.T) { }, } - wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + wrapped := buildModelWrappers[*schema.Message](m1, &modelWrapperConfig{ retryConfig: retryCfg, failoverConfig: failoverCfg, }) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + ctx := withTypedChatModelAgentExecCtx[*schema.Message](context.Background(), &chatModelAgentExecCtx{ failoverLastSuccessModel: m1, }) sr, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) @@ -332,12 +332,12 @@ func TestRetryThenFailover(t *testing.T) { }, } - wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + wrapped := buildModelWrappers[*schema.Message](m1, &modelWrapperConfig{ retryConfig: retryCfg, failoverConfig: failoverCfg, }) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + ctx := withTypedChatModelAgentExecCtx[*schema.Message](context.Background(), &chatModelAgentExecCtx{ failoverLastSuccessModel: m1, }) _, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) @@ -386,12 +386,12 @@ func TestRetryThenFailover(t *testing.T) { }, } - wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + wrapped := buildModelWrappers[*schema.Message](m1, &modelWrapperConfig{ retryConfig: retryCfg, failoverConfig: failoverCfg, }) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + ctx := withTypedChatModelAgentExecCtx[*schema.Message](context.Background(), &chatModelAgentExecCtx{ failoverLastSuccessModel: m1, }) sr, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) @@ -438,12 +438,12 @@ func TestRetryThenFailover(t *testing.T) { }, } - wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + wrapped := buildModelWrappers[*schema.Message](m1, &modelWrapperConfig{ retryConfig: retryCfg, failoverConfig: failoverCfg, }) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + ctx := withTypedChatModelAgentExecCtx[*schema.Message](context.Background(), &chatModelAgentExecCtx{ failoverLastSuccessModel: m1, }) msg, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) @@ -478,12 +478,12 @@ func TestRetryThenFailover(t *testing.T) { }, } - wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + wrapped := buildModelWrappers[*schema.Message](m1, &modelWrapperConfig{ retryConfig: retryCfg, failoverConfig: failoverCfg, }) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + ctx := withTypedChatModelAgentExecCtx[*schema.Message](context.Background(), &chatModelAgentExecCtx{ failoverLastSuccessModel: m1, }) _, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) @@ -522,12 +522,12 @@ func TestRetryThenFailover(t *testing.T) { }, } - wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + wrapped := buildModelWrappers[*schema.Message](m1, &modelWrapperConfig{ retryConfig: retryCfg, failoverConfig: failoverCfg, }) - ctx = withChatModelAgentExecCtx(ctx, &chatModelAgentExecCtx{ + ctx = withTypedChatModelAgentExecCtx[*schema.Message](ctx, &chatModelAgentExecCtx{ failoverLastSuccessModel: m1, }) _, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) @@ -562,11 +562,11 @@ func TestErrStreamCanceled_Failover(t *testing.T) { }, } - wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + wrapped := buildModelWrappers[*schema.Message](m1, &modelWrapperConfig{ failoverConfig: failoverCfg, }) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + ctx := withTypedChatModelAgentExecCtx[*schema.Message](context.Background(), &chatModelAgentExecCtx{ failoverLastSuccessModel: m1, }) _, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) @@ -597,11 +597,11 @@ func TestErrStreamCanceled_Failover(t *testing.T) { }, } - wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + wrapped := buildModelWrappers[*schema.Message](m1, &modelWrapperConfig{ failoverConfig: failoverCfg, }) - ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + ctx := withTypedChatModelAgentExecCtx[*schema.Message](context.Background(), &chatModelAgentExecCtx{ failoverLastSuccessModel: m1, }) _, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) diff --git a/components/model/interface.go b/components/model/interface.go index cf79785bc..78eadaf28 100644 --- a/components/model/interface.go +++ b/components/model/interface.go @@ -22,7 +22,19 @@ import ( "github.com/cloudwego/eino/schema" ) -// BaseChatModel defines the core interface for all chat model implementations. +// BaseModel is the generic base model interface parameterized by message type M. +// It exposes two modes of interaction: +// - [BaseModel.Generate]: blocks until the model returns a complete response. +// - [BaseModel.Stream]: returns a [schema.StreamReader] that yields message +// chunks incrementally as the model generates them. +type BaseModel[M any] interface { + Generate(ctx context.Context, input []M, opts ...Option) (M, error) + Stream(ctx context.Context, input []M, opts ...Option) (*schema.StreamReader[M], error) +} + +// BaseChatModel is a backward-compatible type alias for BaseModel specialized +// with *schema.Message. All existing code using model.BaseChatModel continues +// to work without modification. // // It exposes two modes of interaction: // - [BaseChatModel.Generate]: blocks until the model returns a complete response. @@ -49,12 +61,8 @@ import ( // Note: a [schema.StreamReader] can only be read once. If multiple consumers // need the stream, it must be copied before reading. // -//go:generate mockgen -destination ../../internal/mock/components/model/ChatModel_mock.go --package model -source interface.go -type BaseChatModel interface { - Generate(ctx context.Context, input []*schema.Message, opts ...Option) (*schema.Message, error) - Stream(ctx context.Context, input []*schema.Message, opts ...Option) ( - *schema.StreamReader[*schema.Message], error) -} +//go:generate mockgen -destination ../../internal/mock/components/model/ChatModel_mock.go --package model github.com/cloudwego/eino/components/model BaseChatModel,ChatModel,ToolCallingChatModel +type BaseChatModel = BaseModel[*schema.Message] // Deprecated: Use [ToolCallingChatModel] instead. // @@ -85,19 +93,11 @@ type ChatModel interface { type ToolCallingChatModel interface { BaseChatModel - // WithTools returns a new ToolCallingChatModel instance with the specified tools bound. - // This method does not modify the current instance, making it safer for concurrent use. WithTools(tools []*schema.ToolInfo) (ToolCallingChatModel, error) } -// AgenticModel defines the interface for agentic models that support AgenticMessage. -// It provides methods for generating complete and streaming outputs, and supports -// tool calling via the WithTools method. -type AgenticModel interface { - Generate(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.AgenticMessage, error) - Stream(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.StreamReader[*schema.AgenticMessage], error) - - // WithTools returns a new Model instance with the specified tools bound. - // This method does not modify the current instance, making it safer for concurrent use. - WithTools(tools []*schema.ToolInfo) (AgenticModel, error) -} +// AgenticModel is a type alias for BaseModel specialized with +// *schema.AgenticMessage. Unlike ToolCallingChatModel, agentic models do NOT +// expose a WithTools method; tools are passed at request time via the +// model.WithTools option, consistent with how ChatModelAgent binds tools. +type AgenticModel = BaseModel[*schema.AgenticMessage] diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 43376c146..01a24c4a1 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -17,13 +17,16 @@ package schema import ( + "bytes" "context" + "encoding/gob" "encoding/json" "fmt" "reflect" "sort" "strings" + "github.com/bytedance/sonic" "github.com/eino-contrib/jsonschema" "github.com/cloudwego/eino/internal" @@ -420,6 +423,47 @@ type MCPListToolsItem struct { InputSchema *jsonschema.Schema `json:"input_schema,omitempty"` } +type mcpListToolsItemGob struct { + Name string + Description string + InputSchemaJSON []byte +} + +func (m *MCPListToolsItem) GobEncode() ([]byte, error) { + g := mcpListToolsItemGob{ + Name: m.Name, + Description: m.Description, + } + if m.InputSchema != nil { + b, err := json.Marshal(m.InputSchema) + if err != nil { + return nil, fmt.Errorf("failed to marshal MCPListToolsItem.InputSchema: %w", err) + } + g.InputSchemaJSON = b + } + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode(&g); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func (m *MCPListToolsItem) GobDecode(data []byte) error { + var g mcpListToolsItemGob + if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&g); err != nil { + return err + } + m.Name = g.Name + m.Description = g.Description + if len(g.InputSchemaJSON) > 0 { + m.InputSchema = &jsonschema.Schema{} + if err := sonic.Unmarshal(g.InputSchemaJSON, m.InputSchema); err != nil { + return fmt.Errorf("failed to unmarshal MCPListToolsItem.InputSchema: %w", err) + } + } + return nil +} + type MCPToolApprovalRequest struct { // ID is the approval request ID. ID string `json:"id,omitempty"` @@ -1335,7 +1379,6 @@ func concatAssistantGenTexts(texts []*AssistantGenText) (ret *AssistantGenText, if err != nil { return nil, err } - ret.Extension = extensions.Interface() } if len(openaiExtensions) > 0 { @@ -2029,7 +2072,11 @@ func (m *MCPToolResult) String() string { sb.WriteString(fmt.Sprintf(" name: %s\n", m.Name)) sb.WriteString(fmt.Sprintf(" result: %s\n", m.Result)) if m.Error != nil { - sb.WriteString(fmt.Sprintf(" error: [%d] %s\n", *m.Error.Code, m.Error.Message)) + if m.Error.Code != nil { + sb.WriteString(fmt.Sprintf(" error: [%d] %s\n", *m.Error.Code, m.Error.Message)) + } else { + sb.WriteString(fmt.Sprintf(" error: %s\n", m.Error.Message)) + } } return sb.String() } diff --git a/schema/agentic_message_test.go b/schema/agentic_message_test.go index 10639f738..aea4252d7 100644 --- a/schema/agentic_message_test.go +++ b/schema/agentic_message_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestConcatAgenticMessages(t *testing.T) { @@ -1639,3 +1640,41 @@ func TestNewContentBlock(t *testing.T) { }) } } + +func TestNewContentBlockChunk_NilMeta(t *testing.T) { + require.NotPanics(t, func() { + block := NewContentBlockChunk(&AssistantGenText{Text: "test"}, nil) + require.NotNil(t, block) + assert.Nil(t, block.StreamingMeta) + }, "NewContentBlockChunk should handle nil meta without panic") +} + +func TestConcatAssistantGenTexts_ExtensionOverwrite(t *testing.T) { + type testExtension struct { + Value string + } + + texts := []*AssistantGenText{ + {Text: "Hello ", Extension: &testExtension{Value: "ext1"}}, + {Text: "world", Extension: &testExtension{Value: "ext2"}}, + } + + result, err := concatAssistantGenTexts(texts) + if err != nil { + t.Logf("Concat error (may be expected if ConcatSliceValue doesn't handle this type): %v", err) + t.Skip("Skipping: ConcatSliceValue doesn't support test type") + } + require.NotNil(t, result) + + assert.Equal(t, "Hello world", result.Text) + + if result.Extension != nil { + t.Logf("Extension type: %T, value: %v", result.Extension, result.Extension) + _, isSlice := result.Extension.([]*testExtension) + if isSlice { + t.Log("WARNING: Extension is a raw slice instead of a concatenated value. " + + "Line 1381 in agentic_message.go overwrites the ConcatSliceValue result " + + "with extensions.Interface(), discarding the concatenation.") + } + } +} diff --git a/schema/serialization.go b/schema/serialization.go index 22fa16ade..169bf9ee9 100644 --- a/schema/serialization.go +++ b/schema/serialization.go @@ -27,6 +27,8 @@ import ( func init() { RegisterName[*Message]("_eino_message") RegisterName[[]*Message]("_eino_message_slice") + RegisterName[*AgenticMessage]("_eino_agentic_message") + RegisterName[[]*AgenticMessage]("_eino_agentic_message_slice") RegisterName[Document]("_eino_document") RegisterName[RoleType]("_eino_role_type") RegisterName[ToolCall]("_eino_tool_call") diff --git a/schema/tool_test.go b/schema/tool_test.go index e8f95c364..8966cde54 100644 --- a/schema/tool_test.go +++ b/schema/tool_test.go @@ -25,6 +25,7 @@ import ( "github.com/eino-contrib/jsonschema" "github.com/smartystreets/goconvey/convey" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestParamsOneOfToJSONSchema(t *testing.T) { @@ -181,3 +182,40 @@ func TestToolInfoSerialization(t *testing.T) { assert.NoError(t, err) assert.Equal(t, ti2, result) } + +func TestMCPToolResult_NilErrorCode(t *testing.T) { + result := &MCPToolResult{ + CallID: "test-call", + Name: "test-tool", + Result: "some result", + Error: &MCPToolCallError{ + Code: nil, + Message: "something went wrong", + }, + } + + require.NotPanics(t, func() { + s := result.String() + t.Logf("String output: %s", s) + assert.Contains(t, s, "something went wrong") + }, "BUG: MCPToolResult.String() should not panic when Error.Code is nil") +} + +func TestMCPToolResult_WithErrorCode(t *testing.T) { + code := int64(500) + result := &MCPToolResult{ + CallID: "test-call", + Name: "test-tool", + Result: "", + Error: &MCPToolCallError{ + Code: &code, + Message: "internal server error", + }, + } + + require.NotPanics(t, func() { + s := result.String() + assert.Contains(t, s, "500") + assert.Contains(t, s, "internal server error") + }) +} diff --git a/utils/callbacks/template.go b/utils/callbacks/template.go index f01a849b6..850e3011c 100644 --- a/utils/callbacks/template.go +++ b/utils/callbacks/template.go @@ -65,6 +65,7 @@ type HandlerHelper struct { toolHandler *ToolCallbackHandler toolsNodeHandler *ToolsNodeCallbackHandlers agentHandler *AgentCallbackHandler + agenticAgentHandler *AgenticAgentCallbackHandler agenticPromptHandler *AgenticPromptCallbackHandler agenticModelHandler *AgenticModelCallbackHandler agenticToolsNodeHandler *AgenticToolsNodeCallbackHandlers @@ -154,6 +155,12 @@ func (c *HandlerHelper) Agent(handler *AgentCallbackHandler) *HandlerHelper { return c } +// AgenticAgent sets the agentic agent callback handler for the handler helper, which will be called when an agentic agent is executed. +func (c *HandlerHelper) AgenticAgent(handler *AgenticAgentCallbackHandler) *HandlerHelper { + c.agenticAgentHandler = handler + return c +} + // Graph sets the graph handler for the handler helper, which will be called when the graph is executed. func (c *HandlerHelper) Graph(handler callbacks.Handler) *HandlerHelper { c.composeTemplates[compose.ComponentOfGraph] = handler @@ -206,6 +213,8 @@ func (c *handlerTemplate) OnStart(ctx context.Context, info *callbacks.RunInfo, return c.agenticToolsNodeHandler.OnStart(ctx, info, convAgenticToolsNodeCallbackInput(input)) case adk.ComponentOfAgent: return c.agentHandler.OnStart(ctx, info, adk.ConvAgentCallbackInput(input)) + case adk.ComponentOfAgenticAgent: + return c.agenticAgentHandler.OnStart(ctx, info, adk.ConvTypedCallbackInput[*schema.AgenticMessage](input)) case compose.ComponentOfGraph, compose.ComponentOfChain, compose.ComponentOfLambda: @@ -245,6 +254,8 @@ func (c *handlerTemplate) OnEnd(ctx context.Context, info *callbacks.RunInfo, ou return c.agenticToolsNodeHandler.OnEnd(ctx, info, convAgenticToolsNodeCallbackOutput(output)) case adk.ComponentOfAgent: return c.agentHandler.OnEnd(ctx, info, adk.ConvAgentCallbackOutput(output)) + case adk.ComponentOfAgenticAgent: + return c.agenticAgentHandler.OnEnd(ctx, info, adk.ConvTypedCallbackOutput[*schema.AgenticMessage](output)) case compose.ComponentOfGraph, compose.ComponentOfChain, compose.ComponentOfLambda: @@ -404,6 +415,10 @@ func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, t if c.agentHandler != nil && c.agentHandler.Needed(ctx, info, timing) { return true } + case adk.ComponentOfAgenticAgent: + if c.agenticAgentHandler != nil && c.agenticAgentHandler.Needed(ctx, info, timing) { + return true + } case compose.ComponentOfGraph, compose.ComponentOfChain, compose.ComponentOfLambda: @@ -644,9 +659,14 @@ func convToolsNodeCallbackOutput(src callbacks.CallbackInput) []*schema.Message } } +// AgentCallbackHandler handles callbacks for agents using *schema.Message. +// Use ComponentOfAgent to filter callback events to agent-related events. type AgentCallbackHandler struct { + // OnStart is called when an agent run begins. Return a modified context to propagate values. OnStart func(ctx context.Context, info *callbacks.RunInfo, input *adk.AgentCallbackInput) context.Context - OnEnd func(ctx context.Context, info *callbacks.RunInfo, output *adk.AgentCallbackOutput) context.Context + // OnEnd is called when an agent run completes. The output's Events iterator should be + // consumed asynchronously to avoid blocking. + OnEnd func(ctx context.Context, info *callbacks.RunInfo, output *adk.AgentCallbackOutput) context.Context } func (ch *AgentCallbackHandler) Needed(ctx context.Context, info *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { @@ -660,6 +680,27 @@ func (ch *AgentCallbackHandler) Needed(ctx context.Context, info *callbacks.RunI } } +// AgenticAgentCallbackHandler handles callbacks for agentic agents using *schema.AgenticMessage. +// Use ComponentOfAgenticAgent to filter callback events to agentic-agent-related events. +type AgenticAgentCallbackHandler struct { + // OnStart is called when an agentic agent run begins. Return a modified context to propagate values. + OnStart func(ctx context.Context, info *callbacks.RunInfo, input *adk.TypedAgentCallbackInput[*schema.AgenticMessage]) context.Context + // OnEnd is called when an agentic agent run completes. The output's Events iterator should be + // consumed asynchronously to avoid blocking. + OnEnd func(ctx context.Context, info *callbacks.RunInfo, output *adk.TypedAgentCallbackOutput[*schema.AgenticMessage]) context.Context +} + +func (ch *AgenticAgentCallbackHandler) Needed(ctx context.Context, info *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + default: + return false + } +} + // AgenticPromptCallbackHandler is the handler for the agentic prompt callback. type AgenticPromptCallbackHandler struct { // OnStart is the callback function for the start of the agentic prompt. diff --git a/utils/callbacks/template_test.go b/utils/callbacks/template_test.go index dcc0e5c7f..79be157f3 100644 --- a/utils/callbacks/template_test.go +++ b/utils/callbacks/template_test.go @@ -683,3 +683,125 @@ func TestHandlerTemplateWithAgentComponent(t *testing.T) { assert.True(t, checker.Needed(ctx, info, callbacks.TimingOnStart)) }) } + +func TestAgenticAgentCallbackHandler(t *testing.T) { + t.Run("Needed returns correct values", func(t *testing.T) { + handler := &AgenticAgentCallbackHandler{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *adk.TypedAgentCallbackInput[*schema.AgenticMessage]) context.Context { + return ctx + }, + } + + ctx := context.Background() + info := &callbacks.RunInfo{Component: adk.ComponentOfAgenticAgent} + + assert.True(t, handler.Needed(ctx, info, callbacks.TimingOnStart)) + assert.False(t, handler.Needed(ctx, info, callbacks.TimingOnEnd)) + }) + + t.Run("Needed with OnEnd set", func(t *testing.T) { + handler := &AgenticAgentCallbackHandler{ + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, output *adk.TypedAgentCallbackOutput[*schema.AgenticMessage]) context.Context { + return ctx + }, + } + + ctx := context.Background() + info := &callbacks.RunInfo{Component: adk.ComponentOfAgenticAgent} + + assert.False(t, handler.Needed(ctx, info, callbacks.TimingOnStart)) + assert.True(t, handler.Needed(ctx, info, callbacks.TimingOnEnd)) + }) + + t.Run("Needed with nil handlers", func(t *testing.T) { + handler := &AgenticAgentCallbackHandler{} + + ctx := context.Background() + info := &callbacks.RunInfo{Component: adk.ComponentOfAgenticAgent} + + assert.False(t, handler.Needed(ctx, info, callbacks.TimingOnStart)) + assert.False(t, handler.Needed(ctx, info, callbacks.TimingOnEnd)) + }) +} + +func TestHandlerHelperWithAgenticAgent(t *testing.T) { + t.Run("AgenticAgent method sets handler correctly", func(t *testing.T) { + cnt := 0 + tpl := NewHandlerHelper() + tpl.AgenticAgent(&AgenticAgentCallbackHandler{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *adk.TypedAgentCallbackInput[*schema.AgenticMessage]) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, output *adk.TypedAgentCallbackOutput[*schema.AgenticMessage]) context.Context { + cnt++ + return ctx + }, + }) + + handler := tpl.Handler() + ctx := context.Background() + ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{Component: adk.ComponentOfAgenticAgent}, handler) + + ctx = callbacks.OnStart[any](ctx, nil) + assert.Equal(t, 1, cnt) + + callbacks.OnEnd[any](ctx, nil) + assert.Equal(t, 2, cnt) + }) +} + +func TestHandlerTemplateWithAgenticAgentComponent(t *testing.T) { + t.Run("OnStart routes to agentic agent handler", func(t *testing.T) { + called := false + tpl := NewHandlerHelper() + tpl.AgenticAgent(&AgenticAgentCallbackHandler{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *adk.TypedAgentCallbackInput[*schema.AgenticMessage]) context.Context { + called = true + return ctx + }, + }) + + handler := tpl.Handler() + ctx := context.Background() + info := &callbacks.RunInfo{Component: adk.ComponentOfAgenticAgent, Name: "TestAgenticAgent"} + + handler.OnStart(ctx, info, &adk.TypedAgentCallbackInput[*schema.AgenticMessage]{}) + assert.True(t, called) + }) + + t.Run("OnEnd routes to agentic agent handler", func(t *testing.T) { + called := false + tpl := NewHandlerHelper() + tpl.AgenticAgent(&AgenticAgentCallbackHandler{ + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, output *adk.TypedAgentCallbackOutput[*schema.AgenticMessage]) context.Context { + called = true + return ctx + }, + }) + + handler := tpl.Handler() + ctx := context.Background() + info := &callbacks.RunInfo{Component: adk.ComponentOfAgenticAgent, Name: "TestAgenticAgent"} + + handler.OnEnd(ctx, info, &adk.TypedAgentCallbackOutput[*schema.AgenticMessage]{}) + assert.True(t, called) + }) + + t.Run("Needed returns true for agentic agent component", func(t *testing.T) { + tpl := NewHandlerHelper() + tpl.AgenticAgent(&AgenticAgentCallbackHandler{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *adk.TypedAgentCallbackInput[*schema.AgenticMessage]) context.Context { + return ctx + }, + }) + + handler := tpl.Handler() + ctx := context.Background() + info := &callbacks.RunInfo{Component: adk.ComponentOfAgenticAgent} + + checker, ok := handler.(callbacks.TimingChecker) + assert.True(t, ok, "handler should implement TimingChecker") + assert.True(t, checker.Needed(ctx, info, callbacks.TimingOnStart)) + }) +} From 28a9142a077402c3837b7b6e69ef27a4262dccf6 Mon Sep 17 00:00:00 2001 From: Born Date: Tue, 21 Apr 2026 20:23:14 +0800 Subject: [PATCH 60/65] feat(adk): add EnhancedRead with custom FileContentPart types (#973) feat(adk): add MultiModalRead with custom FileContentPart types - Define FileContentPartType, FileContentPart in filesystem package to replace direct schema.ToolOutputPart dependency, supporting only Image (bytes) and File (bytes) types - Add MultiModalReader interface and MultiModalReadRequest with Pages field - Add multiModalReadFileArgs extending readFileArgs with PDF pages param - Convert FileContentPart to schema.ToolOutputPart with base64 encoding in middleware layer - Guard against nil FileContent returned from Backend.Read and MultiModalRead; return human-readable fallback instead of panicking - Reuse base64 encoding buffer across multimodal parts via base64Encoder - Add tests for image, file, unsupported type, pages passthrough, schema fields, custom desc, empty data error, nil result, and routing --- adk/filesystem/backend.go | 64 ++++ adk/middlewares/filesystem/backend.go | 1 + adk/middlewares/filesystem/filesystem.go | 182 +++++++++- adk/middlewares/filesystem/filesystem_test.go | 341 +++++++++++++++++- adk/middlewares/filesystem/prompt.go | 9 + 5 files changed, 585 insertions(+), 12 deletions(-) diff --git a/adk/filesystem/backend.go b/adk/filesystem/backend.go index 44f604927..62ebee870 100644 --- a/adk/filesystem/backend.go +++ b/adk/filesystem/backend.go @@ -75,6 +75,15 @@ type ReadRequest struct { Limit int } +// MultiModalReadRequest extends ReadRequest with parameters only applicable +// to MultiModalReader implementations (e.g. PDF page ranges). +type MultiModalReadRequest struct { + ReadRequest + + // Pages specifies the page range for PDF files (e.g. "1-5", "3", "10-20"). + Pages string +} + // GrepRequest contains parameters for searching file content. type GrepRequest struct { // ===== Search Parameters ===== @@ -168,10 +177,65 @@ type EditRequest struct { ReplaceAll bool } +// FileContentPartType defines the type of a multimodal file content part. +type FileContentPartType string + +const ( + // FileContentPartTypeImage represents an image part (e.g. PNG, JPG). + FileContentPartTypeImage FileContentPartType = "image" + // FileContentPartTypePDF represents a file part (e.g. PDF). + FileContentPartTypePDF FileContentPartType = "pdf" +) + +// FileContentPart represents a multimodal part of file content. +// Data holds raw bytes; encoding (e.g. base64) is handled by the consumer. +type FileContentPart struct { + // Type is the kind of content this part represents. + // Required. + Type FileContentPartType + + // MIMEType is the MIME type of the content (e.g. "image/png", "application/pdf"). + // Required. + MIMEType string + + // Data is the raw binary content. + // Required. + Data []byte +} + +// FileContent holds the result of a Read operation. type FileContent struct { + // Content holds the plain text content of the file. Content string } +// MultiFileContent holds the result of a MultiModalRead operation. +// +// FileContent and Parts are mutually exclusive (one-of): +// - Set FileContent for plain text results (same as a normal Read). +// - Set Parts for multimodal results (images, PDFs, etc.). +// +// When Parts is non-empty, FileContent is ignored. +type MultiFileContent struct { + *FileContent + + // Parts holds multimodal output parts (e.g. image, PDF). + Parts []FileContentPart +} + +// MultiModalReader is an optional extension interface for Backend. +// Backends that implement this interface support multimodal file reading, +// returning structured parts (images, PDFs) instead of plain text. +// +// For large file handling, there are two approaches to control output size: +// - Implement size control within MultiModalRead (e.g. reject files exceeding a threshold, +// downsample images, or limit PDF page counts at the backend level). +// - Use ToolMiddleware's EnhancedInvokable to customize result transformation, +// or use the built-in reduction middleware with configurable policies. +type MultiModalReader interface { + MultiModalRead(ctx context.Context, req *MultiModalReadRequest) (*MultiFileContent, error) +} + // Backend is a pluggable, unified file backend protocol interface. // // All methods use struct-based parameters to allow future extensibility diff --git a/adk/middlewares/filesystem/backend.go b/adk/middlewares/filesystem/backend.go index c5935066e..eec62f162 100644 --- a/adk/middlewares/filesystem/backend.go +++ b/adk/middlewares/filesystem/backend.go @@ -25,6 +25,7 @@ type FileInfo = filesystem.FileInfo type GrepMatch = filesystem.GrepMatch type LsInfoRequest = filesystem.LsInfoRequest type ReadRequest = filesystem.ReadRequest +type MultiModalReadRequest = filesystem.MultiModalReadRequest type GrepRequest = filesystem.GrepRequest type GlobInfoRequest = filesystem.GlobInfoRequest type WriteRequest = filesystem.WriteRequest diff --git a/adk/middlewares/filesystem/filesystem.go b/adk/middlewares/filesystem/filesystem.go index ba43d82ad..143d1d6c5 100644 --- a/adk/middlewares/filesystem/filesystem.go +++ b/adk/middlewares/filesystem/filesystem.go @@ -18,6 +18,7 @@ package filesystem import ( "context" + "encoding/base64" "errors" "fmt" "io" @@ -92,7 +93,9 @@ type Config struct { // LsToolConfig configures the ls tool // optional LsToolConfig *ToolConfig - // ReadFileToolConfig configures the read_file tool + // ReadFileToolConfig configures the read_file tool. + // This config applies to both the standard read_file tool (InvokableTool) and + // the multimodal read_file tool (EnhancedInvokableTool) when UseMultiModalRead is true. // optional ReadFileToolConfig *ToolConfig // WriteFileToolConfig configures the write_file tool @@ -233,7 +236,9 @@ type MiddlewareConfig struct { // LsToolConfig configures the ls tool // optional LsToolConfig *ToolConfig - // ReadFileToolConfig configures the read_file tool + // ReadFileToolConfig configures the read_file tool. + // This config applies to both the standard read_file tool (InvokableTool) and + // the multimodal read_file tool (EnhancedInvokableTool) when UseMultiModalRead is true. // optional ReadFileToolConfig *ToolConfig // WriteFileToolConfig configures the write_file tool @@ -249,6 +254,24 @@ type MiddlewareConfig struct { // optional GrepToolConfig *ToolConfig + // UseMultiModalRead enables multimodal read_file tool (EnhancedInvokableTool). + // When true, read_file returns results via schema.ToolResult.Parts instead of plain text string. + // + // Requires Backend to implement filesystem.MultiModalReader interface. + // The default implementation supports reading image files (PNG, JPG, etc.) + // and PDF files with page range selection. + // + // If you provide a custom MultiModalReader, you may need to override + // ReadFileToolConfig.Desc to accurately describe your implementation's capabilities. + // The default description is composed of ReadFileToolDesc + EnhancedReadFileDescSuffix. + // + // Note: When enabled, the read_file tool becomes an EnhancedInvokableTool. + // If you use ChatModelAgentMiddleware, you must implement ChatModelAgentMiddleware.WrapEnhancedInvokableToolCall + // for the middleware to take effect on the read_file tool. + // + // Default false, preserving backward compatibility. + UseMultiModalRead bool + // CustomSystemPrompt overrides the default ToolsSystemPrompt appended to agent instruction // optional, ToolsSystemPrompt by default CustomSystemPrompt *string @@ -406,6 +429,9 @@ func getFilesystemTools(_ context.Context, middlewareConfig *MiddlewareConfig) ( legacyDesc: middlewareConfig.CustomReadFileToolDesc, createFunc: func(name, desc string) (tool.BaseTool, error) { if middlewareConfig.Backend != nil { + if middlewareConfig.UseMultiModalRead { + return newMultiModalReadFileTool(middlewareConfig.Backend, name, desc) + } return newReadFileTool(middlewareConfig.Backend, name, desc) } return nil, nil @@ -554,6 +580,14 @@ type readFileArgs struct { Limit int `json:"limit" jsonschema:"description=The number of lines to read. Only provide if the file is too large to read at once."` } +// multiModalReadFileArgs extends readFileArgs with PDF-specific parameters for MultiModalReadFileTool. +type multiModalReadFileArgs struct { + readFileArgs + + // Pages is the page range for PDF files. + Pages string `json:"pages,omitempty" jsonschema:"description=Page range for PDF files (e.g.\\, \"1-5\"\\, \"3\"\\, \"10-20\"). Only applicable to PDF files. Maximum 20 pages per request."` +} + func newReadFileTool(fs filesystem.Backend, name string, desc string) (tool.BaseTool, error) { toolName := selectToolName(name, ToolNameReadFile) d, err := selectToolDesc(desc, ReadFileToolDesc, ReadFileToolDescChinese) @@ -576,19 +610,129 @@ func newReadFileTool(fs filesystem.Backend, name string, desc string) (tool.Base if err != nil { return "", err } + if fileCt == nil { + return fmt.Sprintf("No content found at path: %s", input.FilePath), nil + } - startLine := input.Offset - lines := strings.Split(fileCt.Content, "\n") - var b strings.Builder - for i, line := range lines { - if i < len(lines)-1 { - fmt.Fprintf(&b, "%6d\t%s\n", startLine+i, line) - } else { - fmt.Fprintf(&b, "%6d\t%s", startLine+i, line) + return formatLineNumbers(fileCt.Content, input.Offset), nil + }) +} + +// formatLineNumbers prefixes each line of content with a 1-based line number +// starting at startLine (e.g. " 1\tfoo"). startLine corresponds to the +// line number of the first line in content (usually ReadRequest.Offset). +func formatLineNumbers(content string, startLine int) string { + lines := strings.Split(content, "\n") + var b strings.Builder + for i, line := range lines { + if i < len(lines)-1 { + fmt.Fprintf(&b, "%6d\t%s\n", startLine+i, line) + } else { + fmt.Fprintf(&b, "%6d\t%s", startLine+i, line) + } + } + return b.String() +} + +func newMultiModalReadFileTool(fs filesystem.Backend, name string, desc string) (tool.BaseTool, error) { + er, ok := fs.(filesystem.MultiModalReader) + if !ok { + return nil, fmt.Errorf("UseMultiModalRead is enabled, but backend (type %T) does not implement filesystem.MultiModalReader interface. "+ + "Either implement the MultiModalReader interface on your backend, or set UseMultiModalRead to false", fs) + } + + toolName := selectToolName(name, ToolNameReadFile) + d, err := selectToolDesc(desc, ReadFileToolDesc, ReadFileToolDescChinese) + if err != nil { + return nil, err + } + // Only append the multimodal suffix when falling back to the built-in desc. + // A custom desc is expected to describe its own capabilities, so appending + // would produce duplicated or contradictory descriptions. + if desc == "" { + d += internal.SelectPrompt(internal.I18nPrompts{ + English: EnhancedReadFileDescSuffix, + Chinese: EnhancedReadFileDescSuffixChinese, + }) + } + + return utils.InferEnhancedTool(toolName, d, func(ctx context.Context, input multiModalReadFileArgs) (*schema.ToolResult, error) { + if input.Offset <= 0 { + input.Offset = 1 + } + if input.Limit <= 0 { + input.Limit = 2000 + } + + fileCt, err := er.MultiModalRead(ctx, &filesystem.MultiModalReadRequest{ + ReadRequest: filesystem.ReadRequest{ + FilePath: input.FilePath, + Offset: input.Offset, + Limit: input.Limit, + }, + Pages: input.Pages, + }) + if err != nil { + return nil, err + } + + if fileCt == nil { + return &schema.ToolResult{ + Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: fmt.Sprintf("No content found at path: %s", input.FilePath)}}, + }, nil + } + + // Multimodal result: convert FileContentPart to ToolOutputPart + if len(fileCt.Parts) > 0 { + parts := make([]schema.ToolOutputPart, 0, len(fileCt.Parts)) + enc := base64Encoder{} + for _, p := range fileCt.Parts { + if len(p.Data) == 0 { + return nil, fmt.Errorf("FileContentPart.Data is empty for type %s", p.Type) + } + if p.MIMEType == "" { + return nil, fmt.Errorf("FileContentPart.MIMEType is empty for type %s", p.Type) + } + b64 := enc.encode(p.Data) + switch p.Type { + case filesystem.FileContentPartTypeImage: + parts = append(parts, schema.ToolOutputPart{ + Type: schema.ToolPartTypeImage, + Image: &schema.ToolOutputImage{ + MessagePartCommon: schema.MessagePartCommon{ + MIMEType: p.MIMEType, + Base64Data: &b64, + }, + }, + }) + case filesystem.FileContentPartTypePDF: + parts = append(parts, schema.ToolOutputPart{ + Type: schema.ToolPartTypeFile, + File: &schema.ToolOutputFile{ + MessagePartCommon: schema.MessagePartCommon{ + MIMEType: p.MIMEType, + Base64Data: &b64, + }, + }, + }) + default: + // FileContentPartType is defined by Backend implementations. + // Unrecognized types are unlikely but should fail explicitly rather than silently. + return nil, fmt.Errorf("unsupported FileContentPartType: %s", p.Type) + } } + return &schema.ToolResult{Parts: parts}, nil + } + if fileCt.FileContent == nil { + return &schema.ToolResult{ + Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: fmt.Sprintf("No content found at path: %s", input.FilePath)}}, + }, nil } - return b.String(), nil + + return &schema.ToolResult{ + Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: formatLineNumbers(fileCt.Content, input.Offset)}}, + }, nil }) } @@ -920,6 +1064,22 @@ func valueOrDefault[T any](ptr *T, defaultValue T) T { return defaultValue } +// base64Encoder reuses a buffer across multiple base64 encoding calls to reduce allocations. +type base64Encoder struct { + buf []byte +} + +func (e *base64Encoder) encode(data []byte) string { + n := base64.StdEncoding.EncodedLen(len(data)) + if cap(e.buf) < n { + e.buf = make([]byte, n) + } else { + e.buf = e.buf[:n] + } + base64.StdEncoding.Encode(e.buf, data) + return string(e.buf) +} + func applyPagination[T any](items []T, offset, headLimit int) []T { if offset < 0 { offset = 0 diff --git a/adk/middlewares/filesystem/filesystem_test.go b/adk/middlewares/filesystem/filesystem_test.go index 54c6d440f..1185c2ded 100644 --- a/adk/middlewares/filesystem/filesystem_test.go +++ b/adk/middlewares/filesystem/filesystem_test.go @@ -18,6 +18,7 @@ package filesystem import ( "context" + "encoding/base64" "errors" "fmt" "io" @@ -289,7 +290,7 @@ func TestWriteFileTool(t *testing.T) { t.Fatalf("Failed to read written file: %v", err) } if content.Content != "new content" { - t.Errorf("Expected written content to be 'new content', got %q", content) + t.Errorf("Expected written content to be 'new content', got %q", content.Content) } } @@ -2273,3 +2274,341 @@ type mockShellBackendWithError struct{} func (m *mockShellBackendWithError) Execute(ctx context.Context, req *filesystem.ExecuteRequest) (*filesystem.ExecuteResponse, error) { return nil, errors.New("shell execution error") } + +// multiModalBackend wraps InMemoryBackend and implements MultiModalReader for testing. +type multiModalBackend struct { + *filesystem.InMemoryBackend + multiModalReadFunc func(ctx context.Context, req *filesystem.MultiModalReadRequest) (*filesystem.MultiFileContent, error) +} + +func (b *multiModalBackend) MultiModalRead(ctx context.Context, req *filesystem.MultiModalReadRequest) (*filesystem.MultiFileContent, error) { + return b.multiModalReadFunc(ctx, req) +} + +func TestMultiModalReadFileTool_TextOnly(t *testing.T) { + base := setupTestBackend() + eb := &multiModalBackend{ + InMemoryBackend: base, + multiModalReadFunc: func(ctx context.Context, req *filesystem.MultiModalReadRequest) (*filesystem.MultiFileContent, error) { + ct, err := base.Read(ctx, &req.ReadRequest) + if err != nil { + return nil, err + } + return &filesystem.MultiFileContent{ + FileContent: ct, + }, nil + }, + } + + mmTool, err := newMultiModalReadFileTool(eb, "", "") + assert.NoError(t, err) + + result, err := mmTool.(tool.EnhancedInvokableTool).InvokableRun( + context.Background(), &schema.ToolArgument{Text: `{"file_path": "/file1.txt", "offset": 0, "limit": 100}`}) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Len(t, result.Parts, 1) + assert.Equal(t, schema.ToolPartTypeText, result.Parts[0].Type) + assert.Contains(t, result.Parts[0].Text, "line1") + assert.Contains(t, result.Parts[0].Text, "line5") +} + +func TestMultiModalReadFileTool_Multimodal(t *testing.T) { + base := setupTestBackend() + imgData := []byte("rawimagedata") + eb := &multiModalBackend{ + InMemoryBackend: base, + multiModalReadFunc: func(ctx context.Context, req *filesystem.MultiModalReadRequest) (*filesystem.MultiFileContent, error) { + return &filesystem.MultiFileContent{ + Parts: []filesystem.FileContentPart{ + { + Type: filesystem.FileContentPartTypeImage, + MIMEType: "image/png", + Data: imgData, + }, + }, + }, nil + }, + } + + mmTool, err := newMultiModalReadFileTool(eb, "", "") + assert.NoError(t, err) + + result, err := mmTool.(tool.EnhancedInvokableTool).InvokableRun( + context.Background(), &schema.ToolArgument{Text: `{"file_path": "/image.png", "offset": 0, "limit": 100}`}) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Len(t, result.Parts, 1) + assert.Equal(t, schema.ToolPartTypeImage, result.Parts[0].Type) + + // Verify base64 encoding correctness + assert.NotNil(t, result.Parts[0].Image) + assert.Equal(t, "image/png", result.Parts[0].Image.MIMEType) + assert.Equal(t, base64.StdEncoding.EncodeToString(imgData), *result.Parts[0].Image.Base64Data) +} + +func TestMultiModalReadFileTool_FileType(t *testing.T) { + base := setupTestBackend() + pdfData := []byte("fakepdfcontent") + eb := &multiModalBackend{ + InMemoryBackend: base, + multiModalReadFunc: func(ctx context.Context, req *filesystem.MultiModalReadRequest) (*filesystem.MultiFileContent, error) { + return &filesystem.MultiFileContent{ + Parts: []filesystem.FileContentPart{ + { + Type: filesystem.FileContentPartTypePDF, + MIMEType: "application/pdf", + Data: pdfData, + }, + }, + }, nil + }, + } + + mmTool, err := newMultiModalReadFileTool(eb, "", "") + assert.NoError(t, err) + + result, err := mmTool.(tool.EnhancedInvokableTool).InvokableRun( + context.Background(), &schema.ToolArgument{Text: `{"file_path": "/doc.pdf", "offset": 0, "limit": 100}`}) + assert.NoError(t, err) + assert.Len(t, result.Parts, 1) + assert.Equal(t, schema.ToolPartTypeFile, result.Parts[0].Type) + assert.NotNil(t, result.Parts[0].File) + assert.Equal(t, "application/pdf", result.Parts[0].File.MIMEType) + assert.Equal(t, base64.StdEncoding.EncodeToString(pdfData), *result.Parts[0].File.Base64Data) +} + +func TestMultiModalReadFileTool_UnsupportedPartType(t *testing.T) { + base := setupTestBackend() + eb := &multiModalBackend{ + InMemoryBackend: base, + multiModalReadFunc: func(ctx context.Context, req *filesystem.MultiModalReadRequest) (*filesystem.MultiFileContent, error) { + return &filesystem.MultiFileContent{ + Parts: []filesystem.FileContentPart{ + { + Type: filesystem.FileContentPartType("unknown"), + MIMEType: "application/octet-stream", + Data: []byte("data"), + }, + }, + }, nil + }, + } + + mmTool, err := newMultiModalReadFileTool(eb, "", "") + assert.NoError(t, err) + + _, err = mmTool.(tool.EnhancedInvokableTool).InvokableRun( + context.Background(), &schema.ToolArgument{Text: `{"file_path": "/file.bin", "offset": 0, "limit": 100}`}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported FileContentPartType") +} + +func TestMultiModalReadFileTool_PagesPassThrough(t *testing.T) { + base := setupTestBackend() + var capturedPages string + eb := &multiModalBackend{ + InMemoryBackend: base, + multiModalReadFunc: func(ctx context.Context, req *filesystem.MultiModalReadRequest) (*filesystem.MultiFileContent, error) { + capturedPages = req.Pages + return &filesystem.MultiFileContent{FileContent: &filesystem.FileContent{Content: "page content"}}, nil + }, + } + + mmTool, err := newMultiModalReadFileTool(eb, "", "") + assert.NoError(t, err) + + _, err = mmTool.(tool.EnhancedInvokableTool).InvokableRun( + context.Background(), &schema.ToolArgument{Text: `{"file_path": "/doc.pdf", "pages": "1-5"}`}) + assert.NoError(t, err) + assert.Equal(t, "1-5", capturedPages) +} + +func TestMultiModalReadFileTool_BackendNotMultiModalReader(t *testing.T) { + base := setupTestBackend() + _, err := newMultiModalReadFileTool(base, "", "") + assert.Error(t, err) + assert.Contains(t, err.Error(), "MultiModalReader") +} + +func TestUseMultiModalRead_Routing(t *testing.T) { + base := setupTestBackend() + eb := &multiModalBackend{ + InMemoryBackend: base, + multiModalReadFunc: func(ctx context.Context, req *filesystem.MultiModalReadRequest) (*filesystem.MultiFileContent, error) { + ct, err := base.Read(ctx, &req.ReadRequest) + if err != nil { + return nil, err + } + return &filesystem.MultiFileContent{FileContent: ct}, nil + }, + } + + // UseMultiModalRead=false should create standard tool + tools, err := getFilesystemTools(context.Background(), &MiddlewareConfig{ + Backend: base, + UseMultiModalRead: false, + }) + assert.NoError(t, err) + for _, tl := range tools { + info, _ := tl.Info(context.Background()) + if info != nil && info.Name == ToolNameReadFile { + _, isEnhanced := tl.(tool.EnhancedInvokableTool) + assert.False(t, isEnhanced, "should be standard InvokableTool when UseMultiModalRead=false") + } + } + + // UseMultiModalRead=true with enhanced backend should create enhanced tool + tools2, err := getFilesystemTools(context.Background(), &MiddlewareConfig{ + Backend: eb, + UseMultiModalRead: true, + }) + assert.NoError(t, err) + for _, tl := range tools2 { + info, _ := tl.Info(context.Background()) + if info != nil && info.Name == ToolNameReadFile { + _, isEnhanced := tl.(tool.EnhancedInvokableTool) + assert.True(t, isEnhanced, "should be EnhancedInvokableTool when UseMultiModalRead=true") + } + } +} + +// TestMultiModalReadFileTool_SchemaContainsAllFields verifies that the JSON schema +// exposed to the LLM includes both the embedded readFileArgs fields (file_path, +// offset, limit) and the enhanced-only "pages" field. Guards against the +// jsonschema library failing to flatten an unexported anonymous embedded struct. +func TestMultiModalReadFileTool_SchemaContainsAllFields(t *testing.T) { + base := setupTestBackend() + eb := &multiModalBackend{ + InMemoryBackend: base, + multiModalReadFunc: func(ctx context.Context, req *filesystem.MultiModalReadRequest) (*filesystem.MultiFileContent, error) { + ct, err := base.Read(ctx, &req.ReadRequest) + if err != nil { + return nil, err + } + return &filesystem.MultiFileContent{FileContent: ct}, nil + }, + } + + mmTool, err := newMultiModalReadFileTool(eb, "", "") + assert.NoError(t, err) + + info, err := mmTool.Info(context.Background()) + assert.NoError(t, err) + assert.NotNil(t, info) + + js, err := info.ParamsOneOf.ToJSONSchema() + assert.NoError(t, err) + assert.NotNil(t, js) + assert.NotNil(t, js.Properties, "schema should have properties") + + for _, field := range []string{"file_path", "offset", "limit", "pages"} { + _, ok := js.Properties.Get(field) + assert.True(t, ok, "expected JSON schema to contain field %q, schema=%+v", field, js.Properties) + } +} + +// TestMultiModalReadFileTool_CustomDescNoSuffix verifies that when a custom desc is +// provided, the multimodal suffix is NOT appended (user's desc replaces default). +func TestMultiModalReadFileTool_CustomDescNoSuffix(t *testing.T) { + base := setupTestBackend() + eb := &multiModalBackend{ + InMemoryBackend: base, + multiModalReadFunc: func(ctx context.Context, req *filesystem.MultiModalReadRequest) (*filesystem.MultiFileContent, error) { + ct, err := base.Read(ctx, &req.ReadRequest) + if err != nil { + return nil, err + } + return &filesystem.MultiFileContent{FileContent: ct}, nil + }, + } + + customDesc := "my custom read tool description" + mmTool, err := newMultiModalReadFileTool(eb, "", customDesc) + assert.NoError(t, err) + + info, err := mmTool.Info(context.Background()) + assert.NoError(t, err) + assert.Equal(t, customDesc, info.Desc, "custom desc should not be augmented with multimodal suffix") + + // With empty desc (fallback to default), suffix should be appended. + defaultTool, err := newMultiModalReadFileTool(eb, "", "") + assert.NoError(t, err) + defaultInfo, err := defaultTool.Info(context.Background()) + assert.NoError(t, err) + assert.Contains(t, defaultInfo.Desc, "multimodal", "default desc should include multimodal suffix") +} + +// TestMultiModalReadFileTool_EmptyPartDataError verifies that a FileContentPart +// with empty Data fails explicitly rather than silently encoding to an empty +// base64 string. +func TestMultiModalReadFileTool_EmptyPartDataError(t *testing.T) { + base := setupTestBackend() + eb := &multiModalBackend{ + InMemoryBackend: base, + multiModalReadFunc: func(ctx context.Context, req *filesystem.MultiModalReadRequest) (*filesystem.MultiFileContent, error) { + return &filesystem.MultiFileContent{ + Parts: []filesystem.FileContentPart{ + {Type: filesystem.FileContentPartTypeImage, MIMEType: "image/png", Data: nil}, + }, + }, nil + }, + } + + mmTool, err := newMultiModalReadFileTool(eb, "", "") + assert.NoError(t, err) + + _, err = mmTool.(tool.EnhancedInvokableTool).InvokableRun( + context.Background(), &schema.ToolArgument{Text: `{"file_path": "/x"}`}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "empty") +} + +// nilReadBackend wraps InMemoryBackend but returns nil, nil from Read. +type nilReadBackend struct { + *filesystem.InMemoryBackend +} + +func (b *nilReadBackend) Read(_ context.Context, _ *filesystem.ReadRequest) (*filesystem.FileContent, error) { + return nil, nil +} + +// TestReadFileTool_NilResult verifies that newReadFileTool does not panic when +// Backend.Read returns nil, and emits a human-readable fallback message instead. +func TestReadFileTool_NilResult(t *testing.T) { + base := setupTestBackend() + backend := &nilReadBackend{InMemoryBackend: base} + + readTool, err := newReadFileTool(backend, "", "") + assert.NoError(t, err) + + out, err := invokeTool(t, readTool, `{"file_path": "/missing.txt"}`) + assert.NoError(t, err) + assert.Contains(t, out, "No content found at path") + assert.Contains(t, out, "/missing.txt") +} + +// TestMultiModalReadFileTool_NilResult verifies that newMultiModalReadFileTool +// does not panic when MultiModalRead returns nil, and returns a text part with +// a human-readable fallback message. +func TestMultiModalReadFileTool_NilResult(t *testing.T) { + base := setupTestBackend() + eb := &multiModalBackend{ + InMemoryBackend: base, + multiModalReadFunc: func(ctx context.Context, req *filesystem.MultiModalReadRequest) (*filesystem.MultiFileContent, error) { + return nil, nil + }, + } + + mmTool, err := newMultiModalReadFileTool(eb, "", "") + assert.NoError(t, err) + + result, err := mmTool.(tool.EnhancedInvokableTool).InvokableRun( + context.Background(), &schema.ToolArgument{Text: `{"file_path": "/missing.txt"}`}) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Len(t, result.Parts, 1) + assert.Equal(t, schema.ToolPartTypeText, result.Parts[0].Type) + assert.Contains(t, result.Parts[0].Text, "No content found at path") + assert.Contains(t, result.Parts[0].Text, "/missing.txt") +} diff --git a/adk/middlewares/filesystem/prompt.go b/adk/middlewares/filesystem/prompt.go index 55bba056b..a20d6d7d8 100644 --- a/adk/middlewares/filesystem/prompt.go +++ b/adk/middlewares/filesystem/prompt.go @@ -89,6 +89,15 @@ Usage: - 如果你读取的文件存在但内容为空,你将收到系统提醒警告而不是文件内容 - 在编辑文件之前,你应该始终确保已读取该文件` + // EnhancedReadFileDescSuffix is appended to ReadFileToolDesc when using MultiModalReadFileTool. + EnhancedReadFileDescSuffix = ` +- This tool supports reading image files (e.g., PNG, JPG, etc.). When reading an image file, the contents are presented visually, as the underlying model is a multimodal LLM. +- This tool can read PDF files (.pdf). For large PDFs (more than 10 pages), you MUST provide the pages parameter to read specific page ranges (e.g., pages: "1-5"). Reading a large PDF without the pages parameter will fail. Maximum 20 pages per request.` + + EnhancedReadFileDescSuffixChinese = ` +- 此工具支持读取图片文件(如 PNG、JPG 等)。读取图片文件时,内容将以视觉方式呈现,因为底层模型是多模态 LLM。 +- 此工具可以读取 PDF 文件(.pdf)。对于大型 PDF(超过 10 页),你必须提供 pages 参数来指定页面范围(例如 pages: "1-5")。不提供 pages 参数读取大型 PDF 将会失败。每次请求最多 20 页。` + EditFileToolDesc = `Performs exact string replacements in files. Usage: From 8683f147bb9cca067ad17b5c2606cbc861dd048f Mon Sep 17 00:00:00 2001 From: "shentong.martin" Date: Mon, 30 Mar 2026 21:33:27 +0800 Subject: [PATCH 61/65] feat(adk): add BeforeFinalAnswer hook in ChatModelAgentMiddleware Change-Id: I743af2360a8742c809712c5a2134d67dcafc66ac --- adk/before_final_answer_test.go | 389 ++++++++++++++++++++++++++++++++ adk/chatmodel.go | 12 +- adk/handler.go | 23 ++ adk/react.go | 156 ++++++++----- 4 files changed, 518 insertions(+), 62 deletions(-) create mode 100644 adk/before_final_answer_test.go diff --git a/adk/before_final_answer_test.go b/adk/before_final_answer_test.go new file mode 100644 index 000000000..a85548b8d --- /dev/null +++ b/adk/before_final_answer_test.go @@ -0,0 +1,389 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +type bfaDelegateModel struct { + generateFn func(input []*schema.Message) (*schema.Message, error) + streamFn func(input []*schema.Message) (*schema.StreamReader[*schema.Message], error) +} + +func (m *bfaDelegateModel) Generate(_ context.Context, input []*schema.Message, _ ...model.Option) (*schema.Message, error) { + if m.generateFn != nil { + return m.generateFn(input) + } + return schema.AssistantMessage("default", nil), nil +} + +func (m *bfaDelegateModel) Stream(_ context.Context, input []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + if m.streamFn != nil { + return m.streamFn(input) + } + return schema.StreamReaderFromArray([]*schema.Message{ + schema.AssistantMessage("default", nil), + }), nil +} + +type beforeFinalAnswerHandler struct { + BaseChatModelAgentMiddleware + fn func(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) +} + +func (h *beforeFinalAnswerHandler) BeforeFinalAnswer(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) { + return h.fn(ctx, state) +} + +func drainBFAIterator(iter *AsyncIterator[*AgentEvent]) []*AgentEvent { + var events []*AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + events = append(events, event) + } + return events +} + +func drainBFAStreamEvents(events []*AgentEvent) { + for _, event := range events { + if event.Output != nil && event.Output.MessageOutput != nil && event.Output.MessageOutput.IsStreaming { + sr := event.Output.MessageOutput.MessageStream + for { + _, err := sr.Recv() + if err != nil { + break + } + } + } + } +} + +func TestBeforeFinalAnswer(t *testing.T) { + t.Run("no-tools invoke: accept on first call exits immediately", func(t *testing.T) { + var callCount int32 + m := &bfaDelegateModel{ + generateFn: func(_ []*schema.Message) (*schema.Message, error) { + atomic.AddInt32(&callCount, 1) + return schema.AssistantMessage("answer", nil), nil + }, + } + + var hookCalls int32 + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: "test", Description: "d", Instruction: "i", Model: m, + MaxIterations: 5, + Handlers: []ChatModelAgentMiddleware{ + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) { + atomic.AddInt32(&hookCalls, 1) + return ctx, true, state, nil + }}, + }, + }) + assert.NoError(t, err) + + events := drainBFAIterator(agent.Run(context.Background(), &AgentInput{Messages: []Message{schema.UserMessage("hi")}})) + assert.Equal(t, 1, len(events)) + assert.Nil(t, events[0].Err) + assert.Equal(t, "answer", events[0].Output.MessageOutput.Message.Content) + assert.Equal(t, int32(1), atomic.LoadInt32(&callCount)) + assert.Equal(t, int32(1), atomic.LoadInt32(&hookCalls)) + }) + + t.Run("no-tools invoke: reject causes re-iteration with modified messages", func(t *testing.T) { + var callCount int32 + m := &bfaDelegateModel{ + generateFn: func(input []*schema.Message) (*schema.Message, error) { + count := atomic.AddInt32(&callCount, 1) + if count == 1 { + msg := schema.AssistantMessage("partial", nil) + msg.ResponseMeta = &schema.ResponseMeta{FinishReason: "length"} + return msg, nil + } + return schema.AssistantMessage("complete answer", nil), nil + }, + } + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: "test", Description: "d", Instruction: "i", Model: m, + MaxIterations: 5, + Handlers: []ChatModelAgentMiddleware{ + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) { + lastMsg := state.Messages[len(state.Messages)-1] + if lastMsg.ResponseMeta != nil && lastMsg.ResponseMeta.FinishReason == "length" { + state.Messages = append(state.Messages, schema.UserMessage("Please continue.")) + return ctx, false, state, nil + } + return ctx, true, state, nil + }}, + }, + }) + assert.NoError(t, err) + + events := drainBFAIterator(agent.Run(context.Background(), &AgentInput{Messages: []Message{schema.UserMessage("hi")}})) + lastEvent := events[len(events)-1] + assert.Nil(t, lastEvent.Err) + assert.Equal(t, "complete answer", lastEvent.Output.MessageOutput.Message.Content) + assert.Equal(t, int32(2), atomic.LoadInt32(&callCount)) + }) + + t.Run("no-tools stream: reject causes re-iteration", func(t *testing.T) { + var callCount int32 + m := &bfaDelegateModel{ + streamFn: func(_ []*schema.Message) (*schema.StreamReader[*schema.Message], error) { + count := atomic.AddInt32(&callCount, 1) + if count == 1 { + msg := schema.AssistantMessage("partial stream", nil) + msg.ResponseMeta = &schema.ResponseMeta{FinishReason: "length"} + return schema.StreamReaderFromArray([]*schema.Message{msg}), nil + } + return schema.StreamReaderFromArray([]*schema.Message{ + schema.AssistantMessage("complete stream", nil), + }), nil + }, + } + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: "test", Description: "d", Instruction: "i", Model: m, + MaxIterations: 5, + Handlers: []ChatModelAgentMiddleware{ + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) { + lastMsg := state.Messages[len(state.Messages)-1] + if lastMsg.ResponseMeta != nil && lastMsg.ResponseMeta.FinishReason == "length" { + return ctx, false, state, nil + } + return ctx, true, state, nil + }}, + }, + }) + assert.NoError(t, err) + + events := drainBFAIterator(agent.Run(context.Background(), &AgentInput{ + Messages: []Message{schema.UserMessage("hi")}, EnableStreaming: true, + })) + drainBFAStreamEvents(events) + assert.Equal(t, int32(2), atomic.LoadInt32(&callCount)) + }) + + t.Run("no-tools invoke: rejected answers count toward MaxIterations", func(t *testing.T) { + var callCount int32 + m := &bfaDelegateModel{ + generateFn: func(_ []*schema.Message) (*schema.Message, error) { + atomic.AddInt32(&callCount, 1) + return schema.AssistantMessage("bad answer", nil), nil + }, + } + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: "test", Description: "d", Instruction: "i", Model: m, + MaxIterations: 3, + Handlers: []ChatModelAgentMiddleware{ + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) { + return ctx, false, state, nil + }}, + }, + }) + assert.NoError(t, err) + + events := drainBFAIterator(agent.Run(context.Background(), &AgentInput{Messages: []Message{schema.UserMessage("hi")}})) + lastEvent := events[len(events)-1] + assert.True(t, errors.Is(lastEvent.Err, ErrExceedMaxIterations)) + assert.Equal(t, int32(3), atomic.LoadInt32(&callCount)) + }) + + t.Run("no-tools invoke: hook error propagates immediately", func(t *testing.T) { + hookErr := errors.New("hook failed") + m := &bfaDelegateModel{ + generateFn: func(_ []*schema.Message) (*schema.Message, error) { + return schema.AssistantMessage("answer", nil), nil + }, + } + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: "test", Description: "d", Instruction: "i", Model: m, + MaxIterations: 5, + Handlers: []ChatModelAgentMiddleware{ + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) { + return ctx, false, state, hookErr + }}, + }, + }) + assert.NoError(t, err) + + events := drainBFAIterator(agent.Run(context.Background(), &AgentInput{Messages: []Message{schema.UserMessage("hi")}})) + lastEvent := events[len(events)-1] + assert.True(t, errors.Is(lastEvent.Err, hookErr)) + }) + + t.Run("with-tools invoke: hook only runs on final answers, not tool calls", func(t *testing.T) { + var generateCount int32 + var hookCalls int32 + + m := &bfaDelegateModel{ + generateFn: func(_ []*schema.Message) (*schema.Message, error) { + count := atomic.AddInt32(&generateCount, 1) + if count == 1 { + return schema.AssistantMessage("", []schema.ToolCall{{ + ID: "call-1", + Function: schema.FunctionCall{Name: "test_tool", Arguments: `{"name":"test"}`}, + }}), nil + } + return schema.AssistantMessage("final answer", nil), nil + }, + } + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: "test", Description: "d", Instruction: "i", Model: m, + MaxIterations: 10, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{&fakeToolForTest{tarCount: 0}}}, + }, + Handlers: []ChatModelAgentMiddleware{ + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) { + atomic.AddInt32(&hookCalls, 1) + return ctx, true, state, nil + }}, + }, + }) + assert.NoError(t, err) + + events := drainBFAIterator(agent.Run(context.Background(), &AgentInput{Messages: []Message{schema.UserMessage("hi")}})) + drainBFAStreamEvents(events) + + var foundFinalAnswer bool + for _, event := range events { + if event.Err == nil && event.Output != nil && event.Output.MessageOutput != nil { + if event.Output.MessageOutput.Message != nil && event.Output.MessageOutput.Message.Content == "final answer" { + foundFinalAnswer = true + } + } + } + assert.True(t, foundFinalAnswer) + assert.Equal(t, int32(1), atomic.LoadInt32(&hookCalls)) + }) + + t.Run("with-tools stream: reject loops back through ChatModel", func(t *testing.T) { + var streamCount int32 + + m := &bfaDelegateModel{ + streamFn: func(_ []*schema.Message) (*schema.StreamReader[*schema.Message], error) { + count := atomic.AddInt32(&streamCount, 1) + if count == 1 { + msg := schema.AssistantMessage("bad answer", nil) + msg.ResponseMeta = &schema.ResponseMeta{FinishReason: "length"} + return schema.StreamReaderFromArray([]*schema.Message{msg}), nil + } + return schema.StreamReaderFromArray([]*schema.Message{ + schema.AssistantMessage("good answer", nil), + }), nil + }, + } + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: "test", Description: "d", Instruction: "i", Model: m, + MaxIterations: 5, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{&fakeToolForTest{tarCount: 0}}}, + }, + Handlers: []ChatModelAgentMiddleware{ + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) { + lastMsg := state.Messages[len(state.Messages)-1] + if lastMsg.ResponseMeta != nil && lastMsg.ResponseMeta.FinishReason == "length" { + state.Messages = append(state.Messages, schema.UserMessage("Continue please.")) + return ctx, false, state, nil + } + return ctx, true, state, nil + }}, + }, + }) + assert.NoError(t, err) + + events := drainBFAIterator(agent.Run(context.Background(), &AgentInput{ + Messages: []Message{schema.UserMessage("hi")}, EnableStreaming: true, + })) + drainBFAStreamEvents(events) + assert.Equal(t, int32(2), atomic.LoadInt32(&streamCount)) + }) + + t.Run("no handlers means default accept behavior", func(t *testing.T) { + var callCount int32 + m := &bfaDelegateModel{ + generateFn: func(_ []*schema.Message) (*schema.Message, error) { + atomic.AddInt32(&callCount, 1) + return schema.AssistantMessage("answer", nil), nil + }, + } + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: "test", Description: "d", Instruction: "i", Model: m, + MaxIterations: 5, + }) + assert.NoError(t, err) + + events := drainBFAIterator(agent.Run(context.Background(), &AgentInput{Messages: []Message{schema.UserMessage("hi")}})) + assert.Equal(t, 1, len(events)) + assert.Nil(t, events[0].Err) + assert.Equal(t, int32(1), atomic.LoadInt32(&callCount)) + }) + + t.Run("no-tools invoke: reject on empty content, accept on non-empty", func(t *testing.T) { + var callCount int32 + m := &bfaDelegateModel{ + generateFn: func(_ []*schema.Message) (*schema.Message, error) { + count := atomic.AddInt32(&callCount, 1) + if count == 1 { + return schema.AssistantMessage("", nil), nil + } + return schema.AssistantMessage("real content", nil), nil + }, + } + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: "test", Description: "d", Instruction: "i", Model: m, + MaxIterations: 5, + Handlers: []ChatModelAgentMiddleware{ + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) { + lastMsg := state.Messages[len(state.Messages)-1] + if lastMsg.Content == "" && len(lastMsg.ToolCalls) == 0 { + return ctx, false, state, nil + } + return ctx, true, state, nil + }}, + }, + }) + assert.NoError(t, err) + + events := drainBFAIterator(agent.Run(context.Background(), &AgentInput{Messages: []Message{schema.UserMessage("hi")}})) + lastEvent := events[len(events)-1] + assert.Nil(t, lastEvent.Err) + assert.Equal(t, "real content", lastEvent.Output.MessageOutput.Message.Content) + assert.Equal(t, int32(2), atomic.LoadInt32(&callCount)) + }) +} diff --git a/adk/chatmodel.go b/adk/chatmodel.go index 096435dfd..5370d38a8 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -756,8 +756,7 @@ type execContext struct { toolInfos []*schema.ToolInfo unwrappedTools []tool.BaseTool - rebuildGraph bool // whether needs to instantiate a new graph because of topology changes due to tool modifications - toolUpdated bool // whether needs to pass a compose.WithToolList option to ToolsNode due to tool list change + toolUpdated bool // whether needs to pass a compose.WithToolList option to ToolsNode due to tool list change } func (a *TypedChatModelAgent[M]) applyBeforeAgent(ctx context.Context, ec *execContext) (context.Context, *execContext, error) { @@ -787,8 +786,6 @@ func (a *TypedChatModelAgent[M]) applyBeforeAgent(ctx context.Context, ec *execC }, returnDirectly: runCtx.ReturnDirectly, toolUpdated: true, - rebuildGraph: (len(ec.toolsNodeConf.Tools) == 0 && len(runCtx.Tools) > 0) || - (len(ec.returnDirectly) == 0 && len(runCtx.ReturnDirectly) > 0), } toolInfos, err := genToolInfos(ctx, &runtimeEC.toolsNodeConf) @@ -861,7 +858,7 @@ func (a *TypedChatModelAgent[M]) prepareExecContext(ctx context.Context) (*execC }, nil } -// handleRunFuncError is the common error handler for buildNoToolsRunFunc and buildReActRunFunc. +// handleRunFuncError is the common error handler for buildReActRunFunc. // It handles compose interrupts (both cancel-triggered and business) // and generic errors, sending the appropriate event to the generator. func (a *TypedChatModelAgent[M]) handleRunFuncError( @@ -1346,9 +1343,8 @@ func (a *TypedChatModelAgent[M]) getRunFunc(ctx context.Context) (context.Contex return ctx, nil, nil, err } - if !runtimeBC.rebuildGraph { - return ctx, defaultRun, runtimeBC, nil - } + return ctx, defaultRun, runtimeBC, nil +} var tempRun typedRunFunc[M] if len(runtimeBC.toolsNodeConf.Tools) == 0 { diff --git a/adk/handler.go b/adk/handler.go index 255294dd0..e061d2469 100644 --- a/adk/handler.go +++ b/adk/handler.go @@ -219,6 +219,25 @@ type TypedChatModelAgentMiddleware[M messageType] interface { // The mc parameter contains the current tool configuration: // - Tools: The tool infos that will be sent to the model WrapModel(ctx context.Context, m model.BaseModel[M], mc *ModelContext) (model.BaseModel[M], error) + + // BeforeFinalAnswer is called when the model produces a response with no tool calls + // (a "final answer") before it is accepted and the agent exits. + // + // The state contains all messages including the model's final answer as the last message. + // The hook can inspect the response (e.g., FinishReason, content) and decide whether + // to accept or reject it. + // + // Returns: + // - ctx: the (possibly modified) context + // - accept: if true, the final answer is accepted and the agent exits normally. + // If false, the agent loops back to the ChatModel for another iteration. + // The handler may modify state.Messages before returning false (e.g., append a + // "please continue" user message after a truncated response). + // - state: the (possibly modified) agent state + // - error: if non-nil, the agent exits with this error + // + // Rejected answers count toward MaxIterations, providing a natural cap on runaway loops. + BeforeFinalAnswer(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) } // ChatModelAgentMiddleware is the default middleware type using *schema.Message. @@ -279,6 +298,10 @@ func (b *TypedBaseChatModelAgentMiddleware[M]) AfterToolCallsRewriteState(ctx co return ctx, state, nil } +func (b *BaseChatModelAgentMiddleware) BeforeFinalAnswer(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) { + return ctx, true, state, nil +} + func processTypedState(ctx context.Context, fn func(extra map[string]any) map[string]any) error { runCtx := getRunCtx(ctx) if runCtx != nil && runCtx.AgenticRootInput != nil { diff --git a/adk/react.go b/adk/react.go index fdd224f74..e37712d74 100644 --- a/adk/react.go +++ b/adk/react.go @@ -334,9 +334,7 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { wrappedModel = buildModelWrappers(config.model, config.modelWrapperConf) } - toolsConfig := config.toolsConfig - - toolsNode, err := compose.NewToolNode(ctx, toolsConfig) + toolsNode, err := compose.NewToolNode(ctx, config.toolsConfig) if err != nil { return nil, err } @@ -350,9 +348,6 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { return input, nil }), compose.WithNodeName(chatModel_)) - // CancelAfterChatModel safe-point: on the tool-calls path, after the branch - // has confirmed that the model response contains tool calls (i.e. not a final - // answer). Skipped entirely when the model produces a final answer. _ = g.AddLambdaNode(cancelCheckNode_, compose.InvokableLambda(func(ctx context.Context, msg Message) (Message, error) { if cancelCtx != nil && cancelCtx.shouldCancel() { if cancelCtx.getMode()&CancelAfterChatModel != 0 { @@ -394,8 +389,6 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { compose.WithStreamStatePostHandler(toolPostHandle), compose.WithNodeName(toolNode_)) - // AfterToolCalls node: calls AfterToolCallsRewriteState handlers after all tool calls complete. - // The graph auto-materializes the ToolsNode stream into []Message before this node. afterToolCalls := func(ctx context.Context, toolResults []Message) ([]Message, error) { var stateMessages []Message _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { @@ -431,7 +424,6 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { _ = g.AddLambdaNode(afterToolCallsNode_, compose.InvokableLambda(afterToolCalls), compose.WithNodeName(afterToolCallsNode_)) - // AfterToolCallsCancelCheck: CancelAfterToolCalls safe-point, separated from toolPostHandle. afterToolCallsCancelCheck := func(ctx context.Context, toolResults []Message) ([]Message, error) { if cancelCtx != nil && cancelCtx.shouldCancel() { if cancelCtx.getMode()&CancelAfterToolCalls != 0 { @@ -446,71 +438,127 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { _ = g.AddEdge(compose.START, initNode_) _ = g.AddEdge(initNode_, chatModel_) - toolCallCheck := func(ctx context.Context, sMsg MessageStream) (string, error) { - defer sMsg.Close() - for { - chunk, err_ := sMsg.Recv() - if err_ != nil { - if err_ == io.EOF { - return compose.END, nil - } - - return "", err_ - } - - if len(chunk.ToolCalls) > 0 { - return cancelCheckNode_, nil - } - } - } - branch := compose.NewStreamGraphBranch(toolCallCheck, map[string]bool{compose.END: true, cancelCheckNode_: true}) - _ = g.AddBranch(chatModel_, branch) + addFinalAnswerBranch(g, chatModel_, cancelCheckNode_, config.modelWrapperConf) _ = g.AddEdge(cancelCheckNode_, toolNode_) _ = g.AddEdge(toolNode_, afterToolCallsNode_) _ = g.AddEdge(afterToolCallsNode_, afterToolCallsCancelCheckNode_) - if len(config.toolsReturnDirectly) > 0 { - const ( - toolNodeToEndConverter = "ToolNodeToEndConverter" - ) + const ( + toolNodeToEndConverter = "ToolNodeToEndConverter" + ) - cvt := func(ctx context.Context, toolResults []Message) (Message, error) { - id, _ := getReturnDirectlyToolCallID(ctx) + cvt := func(ctx context.Context, toolResults []Message) (Message, error) { + id, _ := getReturnDirectlyToolCallID(ctx) - for _, msg := range toolResults { - if msg != nil && msg.ToolCallID == id { - return msg, nil - } + for _, msg := range toolResults { + if msg != nil && msg.ToolCallID == id { + return msg, nil } - - return nil, errors.New("return directly tool call result not found") } - _ = g.AddLambdaNode(toolNodeToEndConverter, compose.InvokableLambda(cvt), - compose.WithNodeName(toolNodeToEndConverter)) - _ = g.AddEdge(toolNodeToEndConverter, compose.END) + return nil, errors.New("return directly tool call result not found") + } - checkReturnDirect := func(ctx context.Context, toolResults []Message) (string, error) { - _, ok := getReturnDirectlyToolCallID(ctx) + _ = g.AddLambdaNode(toolNodeToEndConverter, compose.InvokableLambda(cvt), + compose.WithNodeName(toolNodeToEndConverter)) + _ = g.AddEdge(toolNodeToEndConverter, compose.END) - if ok { - return toolNodeToEndConverter, nil - } + checkReturnDirect := func(ctx context.Context, toolResults []Message) (string, error) { + _, ok := getReturnDirectlyToolCallID(ctx) - return chatModel_, nil + if ok { + return toolNodeToEndConverter, nil } - returnDirectBranch := compose.NewGraphBranch(checkReturnDirect, - map[string]bool{toolNodeToEndConverter: true, chatModel_: true}) - _ = g.AddBranch(afterToolCallsCancelCheckNode_, returnDirectBranch) - } else { - _ = g.AddEdge(afterToolCallsCancelCheckNode_, chatModel_) + return chatModel_, nil } + returnDirectBranch := compose.NewGraphBranch(checkReturnDirect, + map[string]bool{toolNodeToEndConverter: true, chatModel_: true}) + _ = g.AddBranch(afterToolCallsCancelCheckNode_, returnDirectBranch) + return g, nil } +func runBeforeFinalAnswer(ctx context.Context, mwConf *modelWrapperConfig) (bool, error) { + if mwConf == nil || len(mwConf.handlers) == 0 { + return true, nil + } + + var stateMessages []Message + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + stateMessages = st.Messages + return nil + }) + + state := &ChatModelAgentState{Messages: stateMessages} + accepted := true + + for _, handler := range mwConf.handlers { + var accept bool + var newState *ChatModelAgentState + var err error + ctx, accept, newState, err = handler.BeforeFinalAnswer(ctx, state) + if err != nil { + return false, err + } + state = newState + if !accept { + accepted = false + } + } + + if !accepted { + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + st.Messages = state.Messages + return nil + }) + } + + return accepted, nil +} + +func addFinalAnswerBranch(g *compose.Graph[*reactInput, Message], chatModelNode, cancelCheckNode string, mwConf *modelWrapperConfig) { + const finalAnswerRejectionNode_ = "FinalAnswerRejection" + _ = g.AddLambdaNode(finalAnswerRejectionNode_, compose.InvokableLambda(func(ctx context.Context, _ Message) ([]Message, error) { + var msgs []Message + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + msgs = st.Messages + return nil + }) + return msgs, nil + }), compose.WithNodeName(finalAnswerRejectionNode_)) + _ = g.AddEdge(finalAnswerRejectionNode_, chatModelNode) + + toolCallCheck := func(ctx context.Context, sMsg MessageStream) (string, error) { + defer sMsg.Close() + for { + chunk, err_ := sMsg.Recv() + if err_ != nil { + if err_ == io.EOF { + accepted, err := runBeforeFinalAnswer(ctx, mwConf) + if err != nil { + return "", err + } + if accepted { + return compose.END, nil + } + return finalAnswerRejectionNode_, nil + } + + return "", err_ + } + + if len(chunk.ToolCalls) > 0 { + return cancelCheckNode, nil + } + } + } + branch := compose.NewStreamGraphBranch(toolCallCheck, map[string]bool{compose.END: true, finalAnswerRejectionNode_: true, cancelCheckNode: true}) + _ = g.AddBranch(chatModelNode, branch) +} + type agenticReactInput struct { Messages []*schema.AgenticMessage } From e8a6e4d132f4ccbeba1ce96e8441519863aae38e Mon Sep 17 00:00:00 2001 From: "shentong.martin" Date: Thu, 16 Apr 2026 12:03:52 +0800 Subject: [PATCH 62/65] refactor(adk): replace bool with FinalAnswerDecision, fix nil state panic, improve docs - Introduce FinalAnswerDecision type (AcceptFinalAnswer/RejectFinalAnswer) to replace bare bool in BeforeFinalAnswer, making call sites self-documenting - Add nil guard for newState in runBeforeFinalAnswer to prevent nil pointer dereference when a handler returns nil state with RejectFinalAnswer - Document ctx propagation limitation with SetRunLocalValue workaround, nil state preservation, and multi-handler veto semantics in godoc - Add protocol comment to toolCallCheck explaining stream drain / state flow - Add 3 regression tests: nil state accept, nil state reject, multi-handler veto Change-Id: I16148180557d8bfe76b99af007ec13799ccd9c00 --- adk/before_final_answer_test.go | 138 +++++++++++++++++++++++++++----- adk/handler.go | 42 +++++++--- adk/react.go | 17 +++- 3 files changed, 161 insertions(+), 36 deletions(-) diff --git a/adk/before_final_answer_test.go b/adk/before_final_answer_test.go index a85548b8d..26edd8d4b 100644 --- a/adk/before_final_answer_test.go +++ b/adk/before_final_answer_test.go @@ -23,6 +23,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" @@ -53,10 +54,10 @@ func (m *bfaDelegateModel) Stream(_ context.Context, input []*schema.Message, _ type beforeFinalAnswerHandler struct { BaseChatModelAgentMiddleware - fn func(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) + fn func(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) } -func (h *beforeFinalAnswerHandler) BeforeFinalAnswer(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) { +func (h *beforeFinalAnswerHandler) BeforeFinalAnswer(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { return h.fn(ctx, state) } @@ -101,9 +102,9 @@ func TestBeforeFinalAnswer(t *testing.T) { Name: "test", Description: "d", Instruction: "i", Model: m, MaxIterations: 5, Handlers: []ChatModelAgentMiddleware{ - &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) { + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { atomic.AddInt32(&hookCalls, 1) - return ctx, true, state, nil + return ctx, AcceptFinalAnswer, state, nil }}, }, }) @@ -135,13 +136,13 @@ func TestBeforeFinalAnswer(t *testing.T) { Name: "test", Description: "d", Instruction: "i", Model: m, MaxIterations: 5, Handlers: []ChatModelAgentMiddleware{ - &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) { + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { lastMsg := state.Messages[len(state.Messages)-1] if lastMsg.ResponseMeta != nil && lastMsg.ResponseMeta.FinishReason == "length" { state.Messages = append(state.Messages, schema.UserMessage("Please continue.")) - return ctx, false, state, nil + return ctx, RejectFinalAnswer, state, nil } - return ctx, true, state, nil + return ctx, AcceptFinalAnswer, state, nil }}, }, }) @@ -174,12 +175,12 @@ func TestBeforeFinalAnswer(t *testing.T) { Name: "test", Description: "d", Instruction: "i", Model: m, MaxIterations: 5, Handlers: []ChatModelAgentMiddleware{ - &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) { + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { lastMsg := state.Messages[len(state.Messages)-1] if lastMsg.ResponseMeta != nil && lastMsg.ResponseMeta.FinishReason == "length" { - return ctx, false, state, nil + return ctx, RejectFinalAnswer, state, nil } - return ctx, true, state, nil + return ctx, AcceptFinalAnswer, state, nil }}, }, }) @@ -205,8 +206,8 @@ func TestBeforeFinalAnswer(t *testing.T) { Name: "test", Description: "d", Instruction: "i", Model: m, MaxIterations: 3, Handlers: []ChatModelAgentMiddleware{ - &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) { - return ctx, false, state, nil + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { + return ctx, RejectFinalAnswer, state, nil }}, }, }) @@ -230,8 +231,8 @@ func TestBeforeFinalAnswer(t *testing.T) { Name: "test", Description: "d", Instruction: "i", Model: m, MaxIterations: 5, Handlers: []ChatModelAgentMiddleware{ - &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) { - return ctx, false, state, hookErr + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { + return ctx, RejectFinalAnswer, state, hookErr }}, }, }) @@ -266,9 +267,9 @@ func TestBeforeFinalAnswer(t *testing.T) { ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{&fakeToolForTest{tarCount: 0}}}, }, Handlers: []ChatModelAgentMiddleware{ - &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) { + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { atomic.AddInt32(&hookCalls, 1) - return ctx, true, state, nil + return ctx, AcceptFinalAnswer, state, nil }}, }, }) @@ -313,13 +314,13 @@ func TestBeforeFinalAnswer(t *testing.T) { ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{&fakeToolForTest{tarCount: 0}}}, }, Handlers: []ChatModelAgentMiddleware{ - &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) { + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { lastMsg := state.Messages[len(state.Messages)-1] if lastMsg.ResponseMeta != nil && lastMsg.ResponseMeta.FinishReason == "length" { state.Messages = append(state.Messages, schema.UserMessage("Continue please.")) - return ctx, false, state, nil + return ctx, RejectFinalAnswer, state, nil } - return ctx, true, state, nil + return ctx, AcceptFinalAnswer, state, nil }}, }, }) @@ -369,12 +370,12 @@ func TestBeforeFinalAnswer(t *testing.T) { Name: "test", Description: "d", Instruction: "i", Model: m, MaxIterations: 5, Handlers: []ChatModelAgentMiddleware{ - &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) { + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { lastMsg := state.Messages[len(state.Messages)-1] if lastMsg.Content == "" && len(lastMsg.ToolCalls) == 0 { - return ctx, false, state, nil + return ctx, RejectFinalAnswer, state, nil } - return ctx, true, state, nil + return ctx, AcceptFinalAnswer, state, nil }}, }, }) @@ -386,4 +387,97 @@ func TestBeforeFinalAnswer(t *testing.T) { assert.Equal(t, "real content", lastEvent.Output.MessageOutput.Message.Content) assert.Equal(t, int32(2), atomic.LoadInt32(&callCount)) }) + + t.Run("nil state return with accept preserves previous state", func(t *testing.T) { + m := &bfaDelegateModel{ + generateFn: func(_ []*schema.Message) (*schema.Message, error) { + return schema.AssistantMessage("answer", nil), nil + }, + } + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: "test", Description: "d", Instruction: "i", Model: m, + MaxIterations: 5, + Handlers: []ChatModelAgentMiddleware{ + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { + return ctx, AcceptFinalAnswer, nil, nil + }}, + }, + }) + require.NoError(t, err) + + events := drainBFAIterator(agent.Run(context.Background(), &AgentInput{Messages: []Message{schema.UserMessage("hi")}})) + require.NotEmpty(t, events) + lastEvent := events[len(events)-1] + assert.Nil(t, lastEvent.Err) + assert.Equal(t, "answer", lastEvent.Output.MessageOutput.Message.Content) + }) + + t.Run("nil state return with reject loops without panic", func(t *testing.T) { + m := &bfaDelegateModel{ + generateFn: func(_ []*schema.Message) (*schema.Message, error) { + return schema.AssistantMessage("answer", nil), nil + }, + } + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: "test", Description: "d", Instruction: "i", Model: m, + MaxIterations: 3, + Handlers: []ChatModelAgentMiddleware{ + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { + return ctx, RejectFinalAnswer, nil, nil + }}, + }, + }) + require.NoError(t, err) + + events := drainBFAIterator(agent.Run(context.Background(), &AgentInput{Messages: []Message{schema.UserMessage("hi")}})) + drainBFAStreamEvents(events) + require.NotEmpty(t, events) + lastEvent := events[len(events)-1] + assert.ErrorIs(t, lastEvent.Err, ErrExceedMaxIterations) + }) + + t.Run("multiple handlers: single reject vetoes, all handlers execute", func(t *testing.T) { + var callCount int32 + m := &bfaDelegateModel{ + generateFn: func(_ []*schema.Message) (*schema.Message, error) { + count := atomic.AddInt32(&callCount, 1) + if count == 1 { + return schema.AssistantMessage("bad", nil), nil + } + return schema.AssistantMessage("good", nil), nil + }, + } + + var h1Calls, h2Calls int32 + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: "test", Description: "d", Instruction: "i", Model: m, + MaxIterations: 5, + Handlers: []ChatModelAgentMiddleware{ + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { + atomic.AddInt32(&h1Calls, 1) + if state.Messages[len(state.Messages)-1].Content == "bad" { + return ctx, RejectFinalAnswer, state, nil + } + return ctx, AcceptFinalAnswer, state, nil + }}, + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { + atomic.AddInt32(&h2Calls, 1) + return ctx, AcceptFinalAnswer, state, nil + }}, + }, + }) + require.NoError(t, err) + + events := drainBFAIterator(agent.Run(context.Background(), &AgentInput{Messages: []Message{schema.UserMessage("hi")}})) + require.NotEmpty(t, events) + lastEvent := events[len(events)-1] + assert.Nil(t, lastEvent.Err) + assert.Equal(t, "good", lastEvent.Output.MessageOutput.Message.Content) + + assert.Equal(t, int32(2), atomic.LoadInt32(&h1Calls)) + assert.Equal(t, int32(2), atomic.LoadInt32(&h2Calls)) + assert.Equal(t, int32(2), atomic.LoadInt32(&callCount)) + }) } diff --git a/adk/handler.go b/adk/handler.go index e061d2469..5d7cd2a79 100644 --- a/adk/handler.go +++ b/adk/handler.go @@ -123,6 +123,8 @@ type ChatModelAgentContext struct { // - AgentMiddleware is kept for backward compatibility with existing users // - Both can be used together; see AgentMiddleware documentation for execution order // +// ChatModelAgentMiddleware defines the interface for customizing ChatModelAgent behavior. +// // Use *TypedBaseChatModelAgentMiddleware as an embedded struct to provide default no-op // implementations for all methods. type TypedChatModelAgentMiddleware[M messageType] interface { @@ -221,25 +223,45 @@ type TypedChatModelAgentMiddleware[M messageType] interface { WrapModel(ctx context.Context, m model.BaseModel[M], mc *ModelContext) (model.BaseModel[M], error) // BeforeFinalAnswer is called when the model produces a response with no tool calls - // (a "final answer") before it is accepted and the agent exits. + // (a "final answer"). It acts as a quality gate: inspect the response and decide + // whether to accept or reject it. // // The state contains all messages including the model's final answer as the last message. // The hook can inspect the response (e.g., FinishReason, content) and decide whether // to accept or reject it. // // Returns: - // - ctx: the (possibly modified) context - // - accept: if true, the final answer is accepted and the agent exits normally. - // If false, the agent loops back to the ChatModel for another iteration. - // The handler may modify state.Messages before returning false (e.g., append a - // "please continue" user message after a truncated response). - // - state: the (possibly modified) agent state + // - ctx: the (possibly modified) context. Note: context modifications are propagated + // to subsequent handlers in the chain and used for state write-back, but are NOT + // propagated to downstream graph nodes due to a compose framework constraint. + // To pass data across iterations, use SetRunLocalValue/GetRunLocalValue instead. + // - decision: AcceptFinalAnswer to exit the agent normally, or RejectFinalAnswer to + // loop back to the ChatModel for another iteration. The handler may modify + // state.Messages before rejecting (e.g., append a "please continue" user message + // after a truncated response). + // - state: the (possibly modified) agent state. If nil, the previous state is preserved. // - error: if non-nil, the agent exits with this error // // Rejected answers count toward MaxIterations, providing a natural cap on runaway loops. - BeforeFinalAnswer(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) + // + // When multiple handlers are registered, all handlers execute in order even if an earlier + // handler rejects. A single RejectFinalAnswer from any handler vetoes the final answer. + // Handlers see state modifications from prior handlers in the chain. + BeforeFinalAnswer(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) } +// FinalAnswerDecision represents the decision made by a BeforeFinalAnswer hook. +type FinalAnswerDecision int + +const ( + // AcceptFinalAnswer indicates the final answer should be accepted and the agent exits normally. + AcceptFinalAnswer FinalAnswerDecision = iota + + // RejectFinalAnswer indicates the final answer should be rejected and the agent loops + // back to the ChatModel for another iteration. + RejectFinalAnswer +) + // ChatModelAgentMiddleware is the default middleware type using *schema.Message. // See TypedChatModelAgentMiddleware for full documentation. type ChatModelAgentMiddleware = TypedChatModelAgentMiddleware[*schema.Message] @@ -298,8 +320,8 @@ func (b *TypedBaseChatModelAgentMiddleware[M]) AfterToolCallsRewriteState(ctx co return ctx, state, nil } -func (b *BaseChatModelAgentMiddleware) BeforeFinalAnswer(ctx context.Context, state *ChatModelAgentState) (context.Context, bool, *ChatModelAgentState, error) { - return ctx, true, state, nil +func (b *BaseChatModelAgentMiddleware) BeforeFinalAnswer(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { + return ctx, AcceptFinalAnswer, state, nil } func processTypedState(ctx context.Context, fn func(extra map[string]any) map[string]any) error { diff --git a/adk/react.go b/adk/react.go index e37712d74..2816c3ae0 100644 --- a/adk/react.go +++ b/adk/react.go @@ -496,15 +496,17 @@ func runBeforeFinalAnswer(ctx context.Context, mwConf *modelWrapperConfig) (bool accepted := true for _, handler := range mwConf.handlers { - var accept bool + var decision FinalAnswerDecision var newState *ChatModelAgentState var err error - ctx, accept, newState, err = handler.BeforeFinalAnswer(ctx, state) + ctx, decision, newState, err = handler.BeforeFinalAnswer(ctx, state) if err != nil { return false, err } - state = newState - if !accept { + if newState != nil { + state = newState + } + if decision == RejectFinalAnswer { accepted = false } } @@ -531,6 +533,13 @@ func addFinalAnswerBranch(g *compose.Graph[*reactInput, Message], chatModelNode, }), compose.WithNodeName(finalAnswerRejectionNode_)) _ = g.AddEdge(finalAnswerRejectionNode_, chatModelNode) + // toolCallCheck drains the model's output stream to determine routing: + // - If any chunk contains tool calls → route to tool execution (cancelCheckNode) + // - If stream ends with no tool calls → this is a final answer. + // The model wrapper has already written the complete response into + // State.Messages via ProcessState, so BeforeFinalAnswer hooks see + // the full response. If rejected, modified state is written back and + // we route to FinalAnswerRejection to loop back to ChatModel. toolCallCheck := func(ctx context.Context, sMsg MessageStream) (string, error) { defer sMsg.Close() for { From c0c1a80479556ec495b7d9191770ff454e248d38 Mon Sep 17 00:00:00 2001 From: "shentong.martin" Date: Tue, 21 Apr 2026 16:22:33 +0800 Subject: [PATCH 63/65] refactor(adk): short-circuit BeforeFinalAnswer on first reject, wire ReturnDirectly through gate Change-Id: I5e03169f8411ce0a3dc17b6a3d47d901aece3607 --- adk/before_final_answer_test.go | 244 +++++++++++++++++++++++++++++++- adk/handler.go | 39 ++--- adk/react.go | 71 ++++++---- 3 files changed, 311 insertions(+), 43 deletions(-) diff --git a/adk/before_final_answer_test.go b/adk/before_final_answer_test.go index 26edd8d4b..f52395be1 100644 --- a/adk/before_final_answer_test.go +++ b/adk/before_final_answer_test.go @@ -438,7 +438,7 @@ func TestBeforeFinalAnswer(t *testing.T) { assert.ErrorIs(t, lastEvent.Err, ErrExceedMaxIterations) }) - t.Run("multiple handlers: single reject vetoes, all handlers execute", func(t *testing.T) { + t.Run("multiple handlers: first reject short-circuits, skips remaining handlers", func(t *testing.T) { var callCount int32 m := &bfaDelegateModel{ generateFn: func(_ []*schema.Message) (*schema.Message, error) { @@ -476,8 +476,248 @@ func TestBeforeFinalAnswer(t *testing.T) { assert.Nil(t, lastEvent.Err) assert.Equal(t, "good", lastEvent.Output.MessageOutput.Message.Content) + // h1 runs both times (reject on "bad", accept on "good") assert.Equal(t, int32(2), atomic.LoadInt32(&h1Calls)) - assert.Equal(t, int32(2), atomic.LoadInt32(&h2Calls)) + // h2 is skipped on the first call (h1 rejected), but runs on the second (h1 accepted) + assert.Equal(t, int32(1), atomic.LoadInt32(&h2Calls)) assert.Equal(t, int32(2), atomic.LoadInt32(&callCount)) }) + + t.Run("multiple handlers: all accept, all execute", func(t *testing.T) { + m := &bfaDelegateModel{ + generateFn: func(_ []*schema.Message) (*schema.Message, error) { + return schema.AssistantMessage("answer", nil), nil + }, + } + + var h1Calls, h2Calls, h3Calls int32 + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: "test", Description: "d", Instruction: "i", Model: m, + MaxIterations: 5, + Handlers: []ChatModelAgentMiddleware{ + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { + atomic.AddInt32(&h1Calls, 1) + return ctx, AcceptFinalAnswer, state, nil + }}, + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { + atomic.AddInt32(&h2Calls, 1) + return ctx, AcceptFinalAnswer, state, nil + }}, + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { + atomic.AddInt32(&h3Calls, 1) + return ctx, AcceptFinalAnswer, state, nil + }}, + }, + }) + require.NoError(t, err) + + events := drainBFAIterator(agent.Run(context.Background(), &AgentInput{Messages: []Message{schema.UserMessage("hi")}})) + require.NotEmpty(t, events) + lastEvent := events[len(events)-1] + assert.Nil(t, lastEvent.Err) + assert.Equal(t, "answer", lastEvent.Output.MessageOutput.Message.Content) + assert.Equal(t, int32(1), atomic.LoadInt32(&h1Calls)) + assert.Equal(t, int32(1), atomic.LoadInt32(&h2Calls)) + assert.Equal(t, int32(1), atomic.LoadInt32(&h3Calls)) + }) + + t.Run("multiple handlers: first handler rejects, all subsequent skipped", func(t *testing.T) { + m := &bfaDelegateModel{ + generateFn: func(_ []*schema.Message) (*schema.Message, error) { + return schema.AssistantMessage("bad", nil), nil + }, + } + + var h1Calls, h2Calls, h3Calls int32 + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: "test", Description: "d", Instruction: "i", Model: m, + MaxIterations: 3, + Handlers: []ChatModelAgentMiddleware{ + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { + atomic.AddInt32(&h1Calls, 1) + return ctx, RejectFinalAnswer, state, nil + }}, + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { + atomic.AddInt32(&h2Calls, 1) + return ctx, AcceptFinalAnswer, state, nil + }}, + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { + atomic.AddInt32(&h3Calls, 1) + return ctx, AcceptFinalAnswer, state, nil + }}, + }, + }) + require.NoError(t, err) + + events := drainBFAIterator(agent.Run(context.Background(), &AgentInput{Messages: []Message{schema.UserMessage("hi")}})) + require.NotEmpty(t, events) + lastEvent := events[len(events)-1] + assert.ErrorIs(t, lastEvent.Err, ErrExceedMaxIterations) + // h1 runs every iteration, h2 and h3 never run (always short-circuited) + assert.Equal(t, int32(3), atomic.LoadInt32(&h1Calls)) + assert.Equal(t, int32(0), atomic.LoadInt32(&h2Calls)) + assert.Equal(t, int32(0), atomic.LoadInt32(&h3Calls)) + }) + + t.Run("accept path: handler state modifications are written back", func(t *testing.T) { + // This verifies the bug fix: accept-path state modifications were + // previously silently dropped because state was only written back on reject. + var callCount int32 + m := &bfaDelegateModel{ + generateFn: func(input []*schema.Message) (*schema.Message, error) { + count := atomic.AddInt32(&callCount, 1) + if count == 1 { + return schema.AssistantMessage("first answer", nil), nil + } + // On the second call, verify the handler's state modification persisted. + // The handler should have appended a marker message on the first accept. + lastUserMsg := "" + for i := len(input) - 1; i >= 0; i-- { + if input[i].Role == schema.User { + lastUserMsg = input[i].Content + break + } + } + return schema.AssistantMessage("saw:"+lastUserMsg, nil), nil + }, + } + + firstCall := true + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: "test", Description: "d", Instruction: "i", Model: m, + MaxIterations: 5, + Handlers: []ChatModelAgentMiddleware{ + // First handler: on first call, modifies state AND accepts. + // Then rejects on second call to trigger another iteration to verify state. + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { + if firstCall { + firstCall = false + // Modify state (e.g., strip thinking tokens, add metadata) + state.Messages = append(state.Messages, schema.UserMessage("marker-from-accept-handler")) + return ctx, RejectFinalAnswer, state, nil + } + return ctx, AcceptFinalAnswer, state, nil + }}, + }, + }) + require.NoError(t, err) + + events := drainBFAIterator(agent.Run(context.Background(), &AgentInput{Messages: []Message{schema.UserMessage("hi")}})) + require.NotEmpty(t, events) + lastEvent := events[len(events)-1] + assert.Nil(t, lastEvent.Err) + // The model's second call should see the marker message appended by the handler + assert.Equal(t, "saw:marker-from-accept-handler", lastEvent.Output.MessageOutput.Message.Content) + }) + + t.Run("ReturnDirectly: accepted tool result exits normally through BeforeFinalAnswer", func(t *testing.T) { + var hookCalls int32 + + m := &bfaDelegateModel{ + generateFn: func(_ []*schema.Message) (*schema.Message, error) { + return schema.AssistantMessage("", []schema.ToolCall{{ + ID: "call-1", + Function: schema.FunctionCall{Name: "rd_tool", Arguments: `{}`}, + }}), nil + }, + } + + rdTool := &bfaReturnDirectlyTool{name: "rd_tool", result: "tool output"} + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: "test", Description: "d", Instruction: "i", Model: m, + MaxIterations: 5, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{rdTool}}, + ReturnDirectly: map[string]bool{"rd_tool": true}, + }, + Handlers: []ChatModelAgentMiddleware{ + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { + atomic.AddInt32(&hookCalls, 1) + return ctx, AcceptFinalAnswer, state, nil + }}, + }, + }) + require.NoError(t, err) + + events := drainBFAIterator(agent.Run(context.Background(), &AgentInput{Messages: []Message{schema.UserMessage("hi")}})) + drainBFAStreamEvents(events) + // Hook must have been called (via the ReturnDirectly path) + assert.Equal(t, int32(1), atomic.LoadInt32(&hookCalls)) + }) + + t.Run("ReturnDirectly: rejected tool result loops back to ChatModel", func(t *testing.T) { + var generateCount int32 + var hookCalls int32 + + m := &bfaDelegateModel{ + generateFn: func(_ []*schema.Message) (*schema.Message, error) { + count := atomic.AddInt32(&generateCount, 1) + if count == 1 { + // First call: invoke the ReturnDirectly tool + return schema.AssistantMessage("", []schema.ToolCall{{ + ID: "call-1", + Function: schema.FunctionCall{Name: "rd_tool", Arguments: `{}`}, + }}), nil + } + // Second call: after rejection, model produces a normal final answer + return schema.AssistantMessage("final answer after rejection", nil), nil + }, + } + + rdTool := &bfaReturnDirectlyTool{name: "rd_tool", result: "insufficient result"} + firstCall := true + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: "test", Description: "d", Instruction: "i", Model: m, + MaxIterations: 5, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{rdTool}}, + ReturnDirectly: map[string]bool{"rd_tool": true}, + }, + Handlers: []ChatModelAgentMiddleware{ + &beforeFinalAnswerHandler{fn: func(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { + atomic.AddInt32(&hookCalls, 1) + if firstCall { + firstCall = false + // Reject the ReturnDirectly result and append feedback + state.Messages = append(state.Messages, schema.UserMessage("Tool result insufficient, try a different approach.")) + return ctx, RejectFinalAnswer, state, nil + } + // Accept the model's second answer + return ctx, AcceptFinalAnswer, state, nil + }}, + }, + }) + require.NoError(t, err) + + events := drainBFAIterator(agent.Run(context.Background(), &AgentInput{Messages: []Message{schema.UserMessage("hi")}})) + drainBFAStreamEvents(events) + require.NotEmpty(t, events) + + lastEvent := events[len(events)-1] + assert.Nil(t, lastEvent.Err) + assert.Equal(t, "final answer after rejection", lastEvent.Output.MessageOutput.Message.Content) + // Hook called twice: once for ReturnDirectly (reject), once for model answer (accept) + assert.Equal(t, int32(2), atomic.LoadInt32(&hookCalls)) + // Model called twice: first produced tool call, second produced final answer + assert.Equal(t, int32(2), atomic.LoadInt32(&generateCount)) + }) +} + +// bfaReturnDirectlyTool is a simple invokable tool for BeforeFinalAnswer tests. +type bfaReturnDirectlyTool struct { + name string + result string +} + +func (t *bfaReturnDirectlyTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "A tool that returns directly for BFA testing", + }, nil +} + +func (t *bfaReturnDirectlyTool) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) { + return t.result, nil } diff --git a/adk/handler.go b/adk/handler.go index 5d7cd2a79..e6c2c0e0e 100644 --- a/adk/handler.go +++ b/adk/handler.go @@ -230,23 +230,27 @@ type TypedChatModelAgentMiddleware[M messageType] interface { // The hook can inspect the response (e.g., FinishReason, content) and decide whether // to accept or reject it. // + // Short-circuit semantics: Handlers execute in registration order. The first + // RejectFinalAnswer immediately stops the chain — remaining handlers are skipped. + // This allows users to control execution policy via registration order: + // - Observers that must always run (e.g., memory extraction): register FIRST + // - Quality gates that may reject (e.g., CI checks): register in the MIDDLE + // - Observers that should only run on accepted answers: register LAST + // // Returns: - // - ctx: the (possibly modified) context. Note: context modifications are propagated - // to subsequent handlers in the chain and used for state write-back, but are NOT - // propagated to downstream graph nodes due to a compose framework constraint. - // To pass data across iterations, use SetRunLocalValue/GetRunLocalValue instead. - // - decision: AcceptFinalAnswer to exit the agent normally, or RejectFinalAnswer to - // loop back to the ChatModel for another iteration. The handler may modify - // state.Messages before rejecting (e.g., append a "please continue" user message - // after a truncated response). + // - ctx: the (possibly modified) context. Context modifications are propagated + // to subsequent handlers in the chain (if not short-circuited) and used for + // state write-back. Note: context is NOT propagated to downstream graph nodes + // due to a compose framework constraint. To pass data across iterations, use + // SetRunLocalValue/GetRunLocalValue instead. + // - decision: AcceptFinalAnswer to continue the chain, or RejectFinalAnswer to + // immediately stop the chain, write back state, and loop back to the ChatModel. + // The handler may modify state.Messages before rejecting (e.g., append a + // "please continue" user message after a truncated response). // - state: the (possibly modified) agent state. If nil, the previous state is preserved. - // - error: if non-nil, the agent exits with this error + // - error: if non-nil, the agent exits with this error (regardless of decision). // // Rejected answers count toward MaxIterations, providing a natural cap on runaway loops. - // - // When multiple handlers are registered, all handlers execute in order even if an earlier - // handler rejects. A single RejectFinalAnswer from any handler vetoes the final answer. - // Handlers see state modifications from prior handlers in the chain. BeforeFinalAnswer(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) } @@ -254,11 +258,14 @@ type TypedChatModelAgentMiddleware[M messageType] interface { type FinalAnswerDecision int const ( - // AcceptFinalAnswer indicates the final answer should be accepted and the agent exits normally. + // AcceptFinalAnswer indicates this handler accepts the final answer. + // The chain continues to the next handler. If all handlers accept, + // the agent exits normally. AcceptFinalAnswer FinalAnswerDecision = iota - // RejectFinalAnswer indicates the final answer should be rejected and the agent loops - // back to the ChatModel for another iteration. + // RejectFinalAnswer indicates this handler rejects the final answer. + // The chain is immediately short-circuited: remaining handlers are skipped, + // state is written back, and the agent loops back to the ChatModel. RejectFinalAnswer ) diff --git a/adk/react.go b/adk/react.go index 2816c3ae0..e05e528a3 100644 --- a/adk/react.go +++ b/adk/react.go @@ -435,10 +435,24 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { _ = g.AddLambdaNode(afterToolCallsCancelCheckNode_, compose.InvokableLambda(afterToolCallsCancelCheck), compose.WithNodeName(afterToolCallsCancelCheckNode_)) + // FinalAnswerRejection reads state.Messages (possibly modified by BeforeFinalAnswer hooks) + // and feeds them back to ChatModel for another iteration. Shared by both the model's + // final-answer path and the ReturnDirectly path. + const finalAnswerRejectionNode_ = "FinalAnswerRejection" + _ = g.AddLambdaNode(finalAnswerRejectionNode_, compose.InvokableLambda(func(ctx context.Context, _ Message) ([]Message, error) { + var msgs []Message + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + msgs = st.Messages + return nil + }) + return msgs, nil + }), compose.WithNodeName(finalAnswerRejectionNode_)) + _ = g.AddEdge(finalAnswerRejectionNode_, chatModel_) + _ = g.AddEdge(compose.START, initNode_) _ = g.AddEdge(initNode_, chatModel_) - addFinalAnswerBranch(g, chatModel_, cancelCheckNode_, config.modelWrapperConf) + addFinalAnswerBranch(g, chatModel_, cancelCheckNode_, finalAnswerRejectionNode_, config.modelWrapperConf) _ = g.AddEdge(cancelCheckNode_, toolNode_) _ = g.AddEdge(toolNode_, afterToolCallsNode_) @@ -462,7 +476,21 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { _ = g.AddLambdaNode(toolNodeToEndConverter, compose.InvokableLambda(cvt), compose.WithNodeName(toolNodeToEndConverter)) - _ = g.AddEdge(toolNodeToEndConverter, compose.END) + + // ReturnDirectly results also go through BeforeFinalAnswer hooks. + // If rejected, route to FinalAnswerRejection to loop back to ChatModel. + returnDirectFinalAnswerCheck := func(ctx context.Context, _ Message) (string, error) { + accepted, err := runBeforeFinalAnswer(ctx, config.modelWrapperConf) + if err != nil { + return "", err + } + if accepted { + return compose.END, nil + } + return finalAnswerRejectionNode_, nil + } + _ = g.AddBranch(toolNodeToEndConverter, compose.NewGraphBranch(returnDirectFinalAnswerCheck, + map[string]bool{compose.END: true, finalAnswerRejectionNode_: true})) checkReturnDirect := func(ctx context.Context, toolResults []Message) (string, error) { _, ok := getReturnDirectlyToolCallID(ctx) @@ -493,7 +521,6 @@ func runBeforeFinalAnswer(ctx context.Context, mwConf *modelWrapperConfig) (bool }) state := &ChatModelAgentState{Messages: stateMessages} - accepted := true for _, handler := range mwConf.handlers { var decision FinalAnswerDecision @@ -507,32 +534,26 @@ func runBeforeFinalAnswer(ctx context.Context, mwConf *modelWrapperConfig) (bool state = newState } if decision == RejectFinalAnswer { - accepted = false + // Short-circuit: write back state immediately and skip remaining handlers. + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + st.Messages = state.Messages + return nil + }) + return false, nil } } - if !accepted { - _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { - st.Messages = state.Messages - return nil - }) - } + // All handlers accepted — write back state in case any handler modified it + // (e.g., stripping thinking tokens, adding metadata). + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + st.Messages = state.Messages + return nil + }) - return accepted, nil + return true, nil } -func addFinalAnswerBranch(g *compose.Graph[*reactInput, Message], chatModelNode, cancelCheckNode string, mwConf *modelWrapperConfig) { - const finalAnswerRejectionNode_ = "FinalAnswerRejection" - _ = g.AddLambdaNode(finalAnswerRejectionNode_, compose.InvokableLambda(func(ctx context.Context, _ Message) ([]Message, error) { - var msgs []Message - _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { - msgs = st.Messages - return nil - }) - return msgs, nil - }), compose.WithNodeName(finalAnswerRejectionNode_)) - _ = g.AddEdge(finalAnswerRejectionNode_, chatModelNode) - +func addFinalAnswerBranch(g *compose.Graph[*reactInput, Message], chatModelNode, cancelCheckNode, finalAnswerRejectionNode string, mwConf *modelWrapperConfig) { // toolCallCheck drains the model's output stream to determine routing: // - If any chunk contains tool calls → route to tool execution (cancelCheckNode) // - If stream ends with no tool calls → this is a final answer. @@ -553,7 +574,7 @@ func addFinalAnswerBranch(g *compose.Graph[*reactInput, Message], chatModelNode, if accepted { return compose.END, nil } - return finalAnswerRejectionNode_, nil + return finalAnswerRejectionNode, nil } return "", err_ @@ -564,7 +585,7 @@ func addFinalAnswerBranch(g *compose.Graph[*reactInput, Message], chatModelNode, } } } - branch := compose.NewStreamGraphBranch(toolCallCheck, map[string]bool{compose.END: true, finalAnswerRejectionNode_: true, cancelCheckNode: true}) + branch := compose.NewStreamGraphBranch(toolCallCheck, map[string]bool{compose.END: true, finalAnswerRejectionNode: true, cancelCheckNode: true}) _ = g.AddBranch(chatModelNode, branch) } From 5843ac1cb9bb2fc0dcb24c6cec680f6ec91ece85 Mon Sep 17 00:00:00 2001 From: "shentong.martin" Date: Tue, 21 Apr 2026 17:08:33 +0800 Subject: [PATCH 64/65] fix(adk): resolve rebase conflicts with alpha/09 AgenticMessage integration - Remove orphaned code block in getRunFunc left from conflict resolution - Change BeforeFinalAnswer receiver to generic TypedBaseChatModelAgentMiddleware[M] to avoid Go's "cannot define methods on instantiated type" error - Guard no-tools chain optimization with len(a.handlers) == 0 so agents with BeforeFinalAnswer handlers still route through the react graph Change-Id: Ic32282313888c4433ecf9e8bbb31e50a0e386cca --- adk/chatmodel.go | 18 +----------------- adk/handler.go | 2 +- 2 files changed, 2 insertions(+), 18 deletions(-) diff --git a/adk/chatmodel.go b/adk/chatmodel.go index 5370d38a8..173d71e20 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -1295,7 +1295,7 @@ func (a *TypedChatModelAgent[M]) buildRunFunc(ctx context.Context) typedRunFunc[ a.exeCtx = ec - if len(ec.toolsNodeConf.Tools) == 0 { + if len(ec.toolsNodeConf.Tools) == 0 && len(a.handlers) == 0 { var run typedRunFunc[M] run, err = a.buildNoToolsRunFunc(ctx) if err != nil { @@ -1346,22 +1346,6 @@ func (a *TypedChatModelAgent[M]) getRunFunc(ctx context.Context) (context.Contex return ctx, defaultRun, runtimeBC, nil } - var tempRun typedRunFunc[M] - if len(runtimeBC.toolsNodeConf.Tools) == 0 { - tempRun, err = a.buildNoToolsRunFunc(ctx) - if err != nil { - return ctx, nil, nil, err - } - } else { - tempRun, err = a.buildReActRunFunc(ctx, runtimeBC) - if err != nil { - return ctx, nil, nil, err - } - } - - return ctx, tempRun, runtimeBC, nil -} - func (a *TypedChatModelAgent[M]) Run(ctx context.Context, input *TypedAgentInput[M], opts ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[M]] { iterator, generator := NewAsyncIteratorPair[*TypedAgentEvent[M]]() diff --git a/adk/handler.go b/adk/handler.go index e6c2c0e0e..f58381823 100644 --- a/adk/handler.go +++ b/adk/handler.go @@ -327,7 +327,7 @@ func (b *TypedBaseChatModelAgentMiddleware[M]) AfterToolCallsRewriteState(ctx co return ctx, state, nil } -func (b *BaseChatModelAgentMiddleware) BeforeFinalAnswer(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { +func (b *TypedBaseChatModelAgentMiddleware[M]) BeforeFinalAnswer(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { return ctx, AcceptFinalAnswer, state, nil } From 25702596030687b464fab7cf436b26871fcd938b Mon Sep 17 00:00:00 2001 From: "shentong.martin" Date: Tue, 21 Apr 2026 21:19:32 +0800 Subject: [PATCH 65/65] refactor(adk): make BeforeFinalAnswer generic and wire into agentic graph - Change BeforeFinalAnswer signature to use TypedChatModelAgentState[M] instead of hardcoded ChatModelAgentState, consistent with all other state-touching hooks - Rename runBeforeFinalAnswer to runTypedBeforeFinalAnswer[M] so it works with both *schema.Message and *schema.AgenticMessage state types - Wire BeforeFinalAnswer into newAgenticReact: add FinalAnswerRejection node, gate model final-answer path, and gate ReturnDirectly path Change-Id: I273b3f9bb08c889adcbc8e1068844f95625be2be --- adk/handler.go | 4 ++-- adk/react.go | 59 ++++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 49 insertions(+), 14 deletions(-) diff --git a/adk/handler.go b/adk/handler.go index f58381823..52f8c181d 100644 --- a/adk/handler.go +++ b/adk/handler.go @@ -251,7 +251,7 @@ type TypedChatModelAgentMiddleware[M messageType] interface { // - error: if non-nil, the agent exits with this error (regardless of decision). // // Rejected answers count toward MaxIterations, providing a natural cap on runaway loops. - BeforeFinalAnswer(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) + BeforeFinalAnswer(ctx context.Context, state *TypedChatModelAgentState[M]) (context.Context, FinalAnswerDecision, *TypedChatModelAgentState[M], error) } // FinalAnswerDecision represents the decision made by a BeforeFinalAnswer hook. @@ -327,7 +327,7 @@ func (b *TypedBaseChatModelAgentMiddleware[M]) AfterToolCallsRewriteState(ctx co return ctx, state, nil } -func (b *TypedBaseChatModelAgentMiddleware[M]) BeforeFinalAnswer(ctx context.Context, state *ChatModelAgentState) (context.Context, FinalAnswerDecision, *ChatModelAgentState, error) { +func (b *TypedBaseChatModelAgentMiddleware[M]) BeforeFinalAnswer(ctx context.Context, state *TypedChatModelAgentState[M]) (context.Context, FinalAnswerDecision, *TypedChatModelAgentState[M], error) { return ctx, AcceptFinalAnswer, state, nil } diff --git a/adk/react.go b/adk/react.go index e05e528a3..f4ee9071e 100644 --- a/adk/react.go +++ b/adk/react.go @@ -480,7 +480,7 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { // ReturnDirectly results also go through BeforeFinalAnswer hooks. // If rejected, route to FinalAnswerRejection to loop back to ChatModel. returnDirectFinalAnswerCheck := func(ctx context.Context, _ Message) (string, error) { - accepted, err := runBeforeFinalAnswer(ctx, config.modelWrapperConf) + accepted, err := runTypedBeforeFinalAnswer(ctx, config.modelWrapperConf) if err != nil { return "", err } @@ -509,22 +509,22 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { return g, nil } -func runBeforeFinalAnswer(ctx context.Context, mwConf *modelWrapperConfig) (bool, error) { +func runTypedBeforeFinalAnswer[M messageType](ctx context.Context, mwConf *typedModelWrapperConfig[M]) (bool, error) { if mwConf == nil || len(mwConf.handlers) == 0 { return true, nil } - var stateMessages []Message - _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + var stateMessages []M + _ = compose.ProcessState(ctx, func(_ context.Context, st *typedState[M]) error { stateMessages = st.Messages return nil }) - state := &ChatModelAgentState{Messages: stateMessages} + state := &TypedChatModelAgentState[M]{Messages: stateMessages} for _, handler := range mwConf.handlers { var decision FinalAnswerDecision - var newState *ChatModelAgentState + var newState *TypedChatModelAgentState[M] var err error ctx, decision, newState, err = handler.BeforeFinalAnswer(ctx, state) if err != nil { @@ -535,7 +535,7 @@ func runBeforeFinalAnswer(ctx context.Context, mwConf *modelWrapperConfig) (bool } if decision == RejectFinalAnswer { // Short-circuit: write back state immediately and skip remaining handlers. - _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + _ = compose.ProcessState(ctx, func(_ context.Context, st *typedState[M]) error { st.Messages = state.Messages return nil }) @@ -545,7 +545,7 @@ func runBeforeFinalAnswer(ctx context.Context, mwConf *modelWrapperConfig) (bool // All handlers accepted — write back state in case any handler modified it // (e.g., stripping thinking tokens, adding metadata). - _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + _ = compose.ProcessState(ctx, func(_ context.Context, st *typedState[M]) error { st.Messages = state.Messages return nil }) @@ -567,7 +567,7 @@ func addFinalAnswerBranch(g *compose.Graph[*reactInput, Message], chatModelNode, chunk, err_ := sMsg.Recv() if err_ != nil { if err_ == io.EOF { - accepted, err := runBeforeFinalAnswer(ctx, mwConf) + accepted, err := runTypedBeforeFinalAnswer(ctx, mwConf) if err != nil { return "", err } @@ -736,6 +736,20 @@ func newAgenticReact(ctx context.Context, config *agenticReactConfig) (agenticRe _ = g.AddLambdaNode(afterToolCallsCancelCheckNode_, compose.InvokableLambda(afterToolCallsCancelCheck), compose.WithNodeName(afterToolCallsCancelCheckNode_)) + // FinalAnswerRejection reads state.Messages (possibly modified by BeforeFinalAnswer hooks) + // and feeds them back to ChatModel for another iteration. Shared by both the model's + // final-answer path and the ReturnDirectly path. + const finalAnswerRejectionNode_ = "FinalAnswerRejection" + _ = g.AddLambdaNode(finalAnswerRejectionNode_, compose.InvokableLambda(func(ctx context.Context, _ *schema.AgenticMessage) ([]*schema.AgenticMessage, error) { + var msgs []*schema.AgenticMessage + _ = compose.ProcessState(ctx, func(_ context.Context, st *agenticState) error { + msgs = st.Messages + return nil + }) + return msgs, nil + }), compose.WithNodeName(finalAnswerRejectionNode_)) + _ = g.AddEdge(finalAnswerRejectionNode_, chatModel_) + _ = g.AddEdge(compose.START, initNode_) _ = g.AddEdge(initNode_, chatModel_) @@ -745,7 +759,14 @@ func newAgenticReact(ctx context.Context, config *agenticReactConfig) (agenticRe chunk, err_ := sMsg.Recv() if err_ != nil { if err_ == io.EOF { - return compose.END, nil + accepted, err := runTypedBeforeFinalAnswer(ctx, config.modelWrapperConf) + if err != nil { + return "", err + } + if accepted { + return compose.END, nil + } + return finalAnswerRejectionNode_, nil } return "", err_ } @@ -754,7 +775,7 @@ func newAgenticReact(ctx context.Context, config *agenticReactConfig) (agenticRe } } } - branch := compose.NewStreamGraphBranch(toolCallCheck, map[string]bool{compose.END: true, cancelCheckNode_: true}) + branch := compose.NewStreamGraphBranch(toolCallCheck, map[string]bool{compose.END: true, cancelCheckNode_: true, finalAnswerRejectionNode_: true}) _ = g.AddBranch(chatModel_, branch) _ = g.AddEdge(cancelCheckNode_, toolNode_) @@ -784,7 +805,21 @@ func newAgenticReact(ctx context.Context, config *agenticReactConfig) (agenticRe _ = g.AddLambdaNode(toolNodeToEndConverter, compose.InvokableLambda(cvt), compose.WithNodeName(toolNodeToEndConverter)) - _ = g.AddEdge(toolNodeToEndConverter, compose.END) + + // ReturnDirectly results also go through BeforeFinalAnswer hooks. + // If rejected, route to FinalAnswerRejection to loop back to ChatModel. + returnDirectFinalAnswerCheck := func(ctx context.Context, _ *schema.AgenticMessage) (string, error) { + accepted, err := runTypedBeforeFinalAnswer(ctx, config.modelWrapperConf) + if err != nil { + return "", err + } + if accepted { + return compose.END, nil + } + return finalAnswerRejectionNode_, nil + } + _ = g.AddBranch(toolNodeToEndConverter, compose.NewGraphBranch(returnDirectFinalAnswerCheck, + map[string]bool{compose.END: true, finalAnswerRejectionNode_: true})) checkReturnDirect := func(ctx context.Context, toolResults []*schema.AgenticMessage) (string, error) { _, ok := getAgenticReturnDirectlyToolCallID(ctx)