From 0f09d14bdbedef419392d8b157cac4562b73c1f3 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Thu, 16 Oct 2025 16:36:12 +0800 Subject: [PATCH 01/59] feat: define AgenticModel component interface --- components/agency/callback_extra.go | 101 ++++++++ components/agency/interface.go | 29 +++ components/agency/option.go | 77 ++++++ components/types.go | 1 + go.mod | 3 +- go.sum | 7 +- schema/agentic_message.go | 374 ++++++++++++++++++++++++++++ schema/anthropic/citation.go | 49 ++++ schema/anthropic/types.go | 10 + schema/google/candidate_meta.go | 66 +++++ schema/message.go | 4 +- schema/openai/annotation.go | 55 ++++ schema/openai/types.go | 10 + 13 files changed, 780 insertions(+), 6 deletions(-) create mode 100644 components/agency/callback_extra.go create mode 100644 components/agency/interface.go create mode 100644 components/agency/option.go create mode 100644 schema/agentic_message.go create mode 100644 schema/anthropic/citation.go create mode 100644 schema/anthropic/types.go create mode 100644 schema/google/candidate_meta.go create mode 100644 schema/openai/annotation.go create mode 100644 schema/openai/types.go diff --git a/components/agency/callback_extra.go b/components/agency/callback_extra.go new file mode 100644 index 000000000..984756d1a --- /dev/null +++ b/components/agency/callback_extra.go @@ -0,0 +1,101 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package agency + +import ( + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" +) + +// TokenUsageMeta is the token usage for the model. +type TokenUsageMeta struct { + InputTokens int64 `json:"input_tokens"` + InputTokensDetails InputTokensUsageDetails `json:"input_tokens_details"` + OutputTokens int64 `json:"output_tokens"` + OutputTokensDetails OutputTokensUsageDetails `json:"output_tokens_details"` + TotalTokens int64 `json:"total_tokens"` +} + +type InputTokensUsageDetails struct { + CachedTokens int64 `json:"cached_tokens"` +} + +type OutputTokensUsageDetails struct { + ReasoningTokens int64 `json:"reasoning_tokens"` +} + +// Config is the config for the model. +type Config struct { + // Model is the model name. + Model string + // Temperature is the temperature, which controls the randomness of the model. + Temperature float32 + // TopP is the top p, which controls the diversity of the model. + TopP float32 +} + +// CallbackInput is the input for the model callback. +type CallbackInput struct { + // Responses is the responses to be sent to the model. + Responses []*schema.AgenticMessage + // Tools is the tools to be used in the model. + Tools []*schema.ToolInfo + // Config is the config for the model. + Config *Config + // Extra is the extra information for the callback. + Extra map[string]any +} + +// CallbackOutput is the output for the model callback. +type CallbackOutput struct { + // Response is the response generated by the model. + Response *schema.AgenticMessage + // Config is the config for the model. + Config *Config + // Usage is the token usage of this request. + Usage *TokenUsageMeta + // Extra is the extra information for the callback. + Extra map[string]any +} + +// ConvCallbackInput converts the callback input to the model callback input. +func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { + switch t := src.(type) { + case *CallbackInput: // when callback is triggered within component implementation, the input is usually already a typed *model.CallbackInput + return t + case []*schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the input is the input of Chat Model interface, which is []*schema.AgenticMessage + return &CallbackInput{ + Responses: t, + } + default: + return nil + } +} + +// ConvCallbackOutput converts the callback output to the model callback output. +func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { + switch t := src.(type) { + case *CallbackOutput: // when callback is triggered within component implementation, the output is usually already a typed *model.CallbackOutput + return t + case *schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the output is the output of Chat Model interface, which is *schema.AgenticMessage + return &CallbackOutput{ + Response: t, + } + default: + return nil + } +} diff --git a/components/agency/interface.go b/components/agency/interface.go new file mode 100644 index 000000000..e33d6a933 --- /dev/null +++ b/components/agency/interface.go @@ -0,0 +1,29 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package agency + +import ( + "context" + + "github.com/cloudwego/eino/schema" +) + +type AgenticModel interface { + Generate(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.AgenticMessage, error) + Stream(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.StreamReader[*schema.AgenticMessage], error) + WithTools(tools []*schema.ToolInfo) (AgenticModel, error) +} diff --git a/components/agency/option.go b/components/agency/option.go new file mode 100644 index 000000000..17028f6e9 --- /dev/null +++ b/components/agency/option.go @@ -0,0 +1,77 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package agency + +// Options is the common options for the model. +type Options struct { +} + +// Option is the call option for ChatModel component. +type Option struct { + apply func(opts *Options) + + implSpecificOptFn any +} + +// WrapImplSpecificOptFn is the option to wrap the implementation specific option function. +func WrapImplSpecificOptFn[T any](optFn func(*T)) Option { + return Option{ + implSpecificOptFn: optFn, + } +} + +// GetCommonOptions extract model Options from Option list, optionally providing a base Options with default values. +func GetCommonOptions(base *Options, opts ...Option) *Options { + if base == nil { + base = &Options{} + } + + for i := range opts { + opt := opts[i] + if opt.apply != nil { + opt.apply(base) + } + } + + return base +} + +// GetImplSpecificOptions extract the implementation specific options from Option list, optionally providing a base options with default values. +// e.g. +// +// myOption := &MyOption{ +// Field1: "default_value", +// } +// +// myOption := model.GetImplSpecificOptions(myOption, opts...) +func GetImplSpecificOptions[T any](base *T, opts ...Option) *T { + if base == nil { + base = new(T) + } + + for i := range opts { + opt := opts[i] + if opt.implSpecificOptFn != nil { + optFn, ok := opt.implSpecificOptFn.(func(*T)) + if ok { + optFn(base) + } + } + } + + return base +} diff --git a/components/types.go b/components/types.go index a546ae59f..a23d82a68 100644 --- a/components/types.go +++ b/components/types.go @@ -68,6 +68,7 @@ const ( ComponentOfPrompt Component = "ChatTemplate" // ComponentOfChatModel identifies chat model components. ComponentOfChatModel Component = "ChatModel" + ComponentOfAgenticModel Component = "AgenticModel" // ComponentOfEmbedding identifies embedding components. ComponentOfEmbedding Component = "Embedding" // ComponentOfIndexer identifies indexer components. diff --git a/go.mod b/go.mod index cfa6957cc..0b87a6cab 100644 --- a/go.mod +++ b/go.mod @@ -41,6 +41,7 @@ require ( github.com/yargevad/filepathx v1.0.0 // indirect golang.org/x/arch v0.11.0 // indirect golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect - golang.org/x/sys v0.26.0 // indirect + golang.org/x/sys v0.29.0 // indirect + golang.org/x/term v0.28.0 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect ) diff --git a/go.sum b/go.sum index a80d6399b..5813766b2 100644 --- a/go.sum +++ b/go.sum @@ -117,9 +117,10 @@ golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= +golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= diff --git a/schema/agentic_message.go b/schema/agentic_message.go new file mode 100644 index 000000000..e386ed044 --- /dev/null +++ b/schema/agentic_message.go @@ -0,0 +1,374 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package schema + +import ( + "github.com/cloudwego/eino/schema/anthropic" + "github.com/cloudwego/eino/schema/google" + "github.com/cloudwego/eino/schema/openai" + "github.com/eino-contrib/jsonschema" +) + +type ContentBlockType string + +const ( + ContentBlockTypeReasoning ContentBlockType = "reasoning" + ContentBlockTypeUserInputText ContentBlockType = "user_input_text" + ContentBlockTypeUserInputImage ContentBlockType = "user_input_image" + ContentBlockTypeUserInputAudio ContentBlockType = "user_input_audio" + ContentBlockTypeUserInputVideo ContentBlockType = "user_input_video" + ContentBlockTypeUserInputFile ContentBlockType = "user_input_file" + ContentBlockTypeAssistantGenText ContentBlockType = "assistant_gen_text" + ContentBlockTypeAssistantGenImage ContentBlockType = "assistant_gen_image" + ContentBlockTypeAssistantGenAudio ContentBlockType = "assistant_gen_audio" + ContentBlockTypeAssistantGenVideo ContentBlockType = "assistant_gen_video" + ContentBlockTypeFunctionToolCall ContentBlockType = "function_tool_call" + ContentBlockTypeFunctionToolResult ContentBlockType = "function_tool_result" + ContentBlockTypeServerToolCall ContentBlockType = "server_tool_call" + ContentBlockTypeServerToolResult ContentBlockType = "server_tool_result" + ContentBlockTypeMCPToolCall ContentBlockType = "mcp_tool_call" + ContentBlockTypeMCPToolResult ContentBlockType = "mcp_tool_result" + ContentBlockTypeMCPListTools ContentBlockType = "mcp_list_tools" + ContentBlockTypeMCPToolApprovalRequest ContentBlockType = "mcp_tool_approval_request" + ContentBlockTypeMCPToolApprovalResponse ContentBlockType = "mcp_tool_approval_response" +) + +type AgenticRoleType string + +const ( + AgenticRoleTypeDeveloper AgenticRoleType = "developer" + AgenticRoleTypeSystem AgenticRoleType = "system" + AgenticRoleTypeUser AgenticRoleType = "user" + AgenticRoleTypeAssistant AgenticRoleType = "assistant" +) + +type AgenticMessage struct { + ResponseMeta *AgenticResponseMeta + + Role AgenticRoleType + ContentBlocks []*ContentBlock + + Extra map[string]any +} + +type AgenticResponseMeta struct { + Status *string + FinishReason string + + TokenUsage *TokenUsage + + GoogleAdditionalMeta *google.CandidateMeta +} + +type StreamMeta struct { + // Index is used for streaming to identify the chunk of the block for concatenation. + Index *int + // Streaming phase of the content block. + Phase StreamPhase +} + +type ContentBlock struct { + Type ContentBlockType + + Reasoning *Reasoning + + UserInputText *UserInputText + UserInputImage *UserInputImage + UserInputAudio *UserInputAudio + UserInputVideo *UserInputVideo + UserInputFile *UserInputFile + + AssistantGenText *AssistantGenText + AssistantGenImage *AssistantGenImage + AssistantGenAudio *AssistantGenAudio + AssistantGenVideo *AssistantGenVideo + + // FunctionToolCall holds invocation details for a user-defined tool. + FunctionToolCall *FunctionToolCall + // FunctionToolResult is the result from a user-defined tool call. + FunctionToolResult *FunctionToolResult + // ServerToolCall holds invocation details for a provider built-in tool run on the model server. + ServerToolCall *ServerToolCall + // ServerToolResult is the result from a provider built-in tool run on the model server. + ServerToolResult *ServerToolResult + + // MCPToolCall holds invocation details for an MCP tool managed by the model server. + MCPToolCall *MCPToolCall + // MCPToolResult is the result from an MCP tool managed by the model server. + MCPToolResult *MCPToolResult + // MCPListToolsResult lists available MCP tools reported by the model server. + MCPListToolsResult *MCPListToolsResult + // MCPToolApprovalRequest requests user approval for an MCP tool call when required. + MCPToolApprovalRequest *MCPToolApprovalRequest + // MCPToolApprovalResponse records the user's approval decision for an MCP tool call. + MCPToolApprovalResponse *MCPToolApprovalResponse + + StreamMeta *StreamMeta +} + +type StreamPhase string + +const ( + StreamPhaseStart StreamPhase = "start" + StreamPhaseDelta StreamPhase = "delta" + StreamPhaseStop StreamPhase = "stop" +) + +type UserInputText struct { + Text string + + // Extra stores additional information. + Extra map[string]any +} + +type UserInputImage struct { + URL *string + Base64Data *string + MIMEType string + Detail ImageURLDetail + + // Extra stores additional information. + Extra map[string]any +} + +type UserInputAudio struct { + URL *string + Base64Data *string + MIMEType string + + // Extra stores additional information. + Extra map[string]any +} + +type UserInputVideo struct { + URL *string + Base64Data *string + MIMEType string + + // Extra stores additional information. + Extra map[string]any +} + +type UserInputFile struct { + URL *string + Name *string + Base64Data *string + MIMEType string + + // Extra stores additional information. + Extra map[string]any +} + +type AssistantGenText struct { + Text string + + OpenAIAnnotations []*openai.TextAnnotation + AnthropicCitations []*anthropic.TextCitation + + // Extra stores additional information. + Extra map[string]any +} + +type AssistantGenImage struct { + URL *string + Base64Data *string + MIMEType string + + // Extra stores additional information. + Extra map[string]any +} + +type AssistantGenAudio struct { + URL *string + Base64Data *string + MIMEType string + + // Extra stores additional information. + Extra map[string]any +} + +type AssistantGenVideo struct { + URL *string + Base64Data *string + MIMEType string + + // Extra stores additional information. + Extra map[string]any +} + +type Reasoning struct { + // Summary is the reasoning content summary. + Summary []*ReasoningSummary + + // EncryptedContent is the encrypted reasoning content. + EncryptedContent string + + // Extra stores additional information. + Extra map[string]any +} + +type ReasoningSummary struct { + // Index specifies the ReasoningSummary chunk to be concatenated during streaming. + Index *int + + Text string +} + +type FunctionToolCall struct { + // CallID is the unique identifier for the tool call. + CallID string + + // Name specifies the function tool invoked. + Name string + + // Arguments is the JSON string arguments for the function tool call. + Arguments string + + // Extra stores additional information + Extra map[string]any +} + +type FunctionToolResult struct { + // CallID is the unique identifier for the tool call. + CallID string + + // Name specifies the function tool invoked. + Name string + + // Result is the function tool result returned by the user + Result string + + // Extra stores additional information. + Extra map[string]any +} + +type ServerToolCall struct { + // Name specifies the server-side tool invoked. + // Supplied by the model server (e.g., `web_search` for OpenAI, `googleSearch` for Gemini). + Name string + + // CallID is the unique identifier for the tool call. + // Empty if not provided by the model server. + CallID string + + // Arguments are the raw inputs to the server-side tool, + // supplied by the component implementer. + Arguments any + + // Extra stores additional information. + Extra map[string]any +} + +type ServerToolResult struct { + // Name specifies the server-side tool invoked. + // Supplied by the model server (e.g., `web_search` for OpenAI, `googleSearch` for Gemini). + Name string + + // CallID is the unique identifier for the tool call. + // Empty if not provided by the model server. + CallID string + + // Result refers to the raw output generated by the server-side tool, + // supplied by the component implementer. + Result any + + // Extra stores additional information. + Extra map[string]any +} + +type MCPToolCall struct { + // ServerLabel is the MCP server label used to identify it in tool calls + ServerLabel string + // ApprovalRequestID is the unique ID of the approval request. + ApprovalRequestID string + // CallID is the unique ID of the tool call. + CallID string + // Name is the name of the tool to run. + Name string + // Arguments is the JSON string arguments for the tool call. + Arguments string + + // Extra stores additional information. + Extra map[string]any +} + +type MCPToolResult struct { + // CallID is the unique ID of the tool call. + CallID string + // Name is the name of the tool to run. + Name string + // Result is the JSON string with the tool result. + Result string + // Error returned when the server fails to run the tool. + Error *MCPToolCallError + + // Extra stores additional information. + Extra map[string]any +} + +type MCPToolCallError struct { + Code int64 + Error string +} + +type MCPListToolsResult struct { + // ServerLabel is the MCP server label used to identify it in tool calls. + ServerLabel string + // Tools is the list of tools available on the server. + Tools []MCPListToolsItem + // Error returned when the server fails to list tools. + Error string + + // Extra stores additional information. + Extra map[string]any +} + +type MCPListToolsItem struct { + // Name is the name of the tool. + Name string + // Description is the description of the tool. + Description string + // InputSchema is the JSON schema that describes the tool input. + InputSchema *jsonschema.Schema +} + +type MCPToolApprovalRequest struct { + // CallID is the unique ID of the tool call. + CallID string + // Name is the name of the tool to run. + Name string + // Arguments is the JSON string arguments for the tool call. + Arguments string + // ServerLabel is the MCP server label used to identify it in tool calls. + ServerLabel string + + // Extra stores additional information. + Extra map[string]any +} + +type MCPToolApprovalResponse struct { + // ApprovalRequestID is the approval request ID being responded to. + ApprovalRequestID string + // Approve indicates whether the request is approved. + Approve bool + // Reason is the rationale for the decision. + // Optional. + Reason string + + // Extra stores additional information. + Extra map[string]any +} diff --git a/schema/anthropic/citation.go b/schema/anthropic/citation.go new file mode 100644 index 000000000..24a4c5aa6 --- /dev/null +++ b/schema/anthropic/citation.go @@ -0,0 +1,49 @@ +package anthropic + +type TextCitation struct { + Type TextCitationType `json:"type"` + + CharLocation *CitationCharLocation `json:"char_location,omitempty"` + PageLocation *CitationPageLocation `json:"page_location,omitempty"` + ContentBlockLocation *CitationContentBlockLocation `json:"content_block_location,omitempty"` + WebSearchResultLocation *CitationWebSearchResultLocation `json:"web_search_result_location,omitempty"` +} + +type CitationCharLocation struct { + CitedText string `json:"cited_text"` + + DocumentTitle string `json:"document_title"` + DocumentIndex int64 `json:"document_index"` + + StartCharIndex int64 `json:"start_char_index"` + EndCharIndex int64 `json:"end_char_index"` +} + +type CitationPageLocation struct { + CitedText string `json:"cited_text"` + + DocumentTitle string `json:"document_title"` + DocumentIndex int64 `json:"document_index"` + + StartPageNumber int64 `json:"start_page_number"` + EndPageNumber int64 `json:"end_page_number"` +} + +type CitationContentBlockLocation struct { + CitedText string `json:"cited_text"` + + DocumentTitle string `json:"document_title"` + DocumentIndex int64 `json:"document_index"` + + StartBlockIndex int64 `json:"start_block_index"` + EndBlockIndex int64 `json:"end_block_index"` +} + +type CitationWebSearchResultLocation struct { + CitedText string `json:"cited_text"` + + Title string `json:"title"` + URL string `json:"url"` + + EncryptedIndex string `json:"encrypted_index"` +} diff --git a/schema/anthropic/types.go b/schema/anthropic/types.go new file mode 100644 index 000000000..fbc85475d --- /dev/null +++ b/schema/anthropic/types.go @@ -0,0 +1,10 @@ +package anthropic + +type TextCitationType string + +const ( + TextCitationTypeCharLocation TextCitationType = "char_location" + TextCitationTypePageLocation TextCitationType = "page_location" + TextCitationTypeContentBlockLocation TextCitationType = "content_block_location" + TextCitationTypeWebSearchResultLocation TextCitationType = "web_search_result_location" +) diff --git a/schema/google/candidate_meta.go b/schema/google/candidate_meta.go new file mode 100644 index 000000000..aead31c2e --- /dev/null +++ b/schema/google/candidate_meta.go @@ -0,0 +1,66 @@ +package google + +type CandidateMeta struct { + GroundingMeta *GroundingMetadata `json:"grounding_meta,omitempty"` +} + +type GroundingMetadata struct { + // List of supporting references retrieved from specified grounding source. + GroundingChunks []*GroundingChunk `json:"grounding_chunks,omitempty"` + // Optional. List of grounding support. + GroundingSupports []*GroundingSupport `json:"grounding_supports,omitempty"` + // Optional. Google search entry for the following-up web searches. + SearchEntryPoint *SearchEntryPoint `json:"search_entry_point,omitempty"` + // Optional. Web search queries for the following-up web search. + WebSearchQueries []string `json:"web_search_queries,omitempty"` +} + +type GroundingChunk struct { + // Grounding chunk from the web. + Web *GroundingChunkWeb `json:"web,omitempty"` +} + +// Chunk from the web. +type GroundingChunkWeb struct { + // Domain of the (original) URI. This field is not supported in Gemini API. + Domain string `json:"domain,omitempty"` + // Title of the chunk. + Title string `json:"title,omitempty"` + // URI reference of the chunk. + URI string `json:"uri,omitempty"` +} + +type GroundingSupport struct { + // Confidence score of the support references. Ranges from 0 to 1. 1 is the most confident. + // For Gemini 2.0 and before, this list must have the same size as the grounding_chunk_indices. + // For Gemini 2.5 and after, this list will be empty and should be ignored. + ConfidenceScores []float32 `json:"confidence_scores,omitempty"` + // A list of indices (into 'grounding_chunk') specifying the citations associated with + // the claim. For instance [1,3,4] means that grounding_chunk[1], grounding_chunk[3], + // grounding_chunk[4] are the retrieved content attributed to the claim. + GroundingChunkIndices []int32 `json:"grounding_chunk_indices,omitempty"` + // Segment of the content this support belongs to. + Segment *Segment `json:"segment,omitempty"` +} + +// Segment of the content. +type Segment struct { + // Output only. End index in the given Part, measured in bytes. Offset from the start + // of the Part, exclusive, starting at zero. + EndIndex int32 `json:"end_index,omitempty"` + // Output only. The index of a Part object within its parent Content object. + PartIndex int32 `json:"part_index,omitempty"` + // Output only. Start index in the given Part, measured in bytes. Offset from the start + // of the Part, inclusive, starting at zero. + StartIndex int32 `json:"start_index,omitempty"` + // Output only. The text corresponding to the segment from the response. + Text string `json:"text,omitempty"` +} + +// Google search entry point. +type SearchEntryPoint struct { + // Optional. Web content snippet that can be embedded in a web page or an app webview. + RenderedContent string `json:"rendered_content,omitempty"` + // Optional. Base64 encoded JSON representing array of tuple. + SDKBlob []byte `json:"sdk_blob,omitempty"` +} diff --git a/schema/message.go b/schema/message.go index 3746244bb..fefb2079e 100644 --- a/schema/message.go +++ b/schema/message.go @@ -694,10 +694,10 @@ type TokenUsage struct { PromptTokenDetails PromptTokenDetails `json:"prompt_token_details"` // CompletionTokens is the number of completion tokens. CompletionTokens int `json:"completion_tokens"` + // CompletionTokenDetails is a breakdown of the completion tokens. + CompletionTokenDetails CompletionTokensDetails `json:"completion_token_details"` // TotalTokens is the total number of tokens. TotalTokens int `json:"total_tokens"` - // CompletionTokensDetails is breakdown of completion tokens. - CompletionTokensDetails CompletionTokensDetails `json:"completion_token_details"` } type CompletionTokensDetails struct { diff --git a/schema/openai/annotation.go b/schema/openai/annotation.go new file mode 100644 index 000000000..a834e8072 --- /dev/null +++ b/schema/openai/annotation.go @@ -0,0 +1,55 @@ +package openai + +type TextAnnotation struct { + Type TextAnnotationType `json:"type"` + + FileCitation *TextAnnotationFileCitation `json:"file_citation,omitempty"` + URLCitation *TextAnnotationURLCitation `json:"url_citation,omitempty"` + ContainerFileCitation *TextAnnotationContainerFileCitation `json:"container_file_citation,omitempty"` + FilePath *TextAnnotationFilePath `json:"file_path,omitempty"` +} + +type TextAnnotationFileCitation struct { + // The ID of the file. + FileID string `json:"file_id"` + // The filename of the file cited. + Filename string `json:"filename"` + + // The index of the file in the list of files. + Index int64 `json:"index"` +} + +type TextAnnotationURLCitation struct { + // The title of the web resource. + Title string `json:"title"` + // The URL of the web resource. + URL string `json:"url"` + + // The index of the first character of the URL citation in the message. + StartIndex int64 `json:"start_index"` + // The index of the last character of the URL citation in the message. + EndIndex int64 `json:"end_index"` +} + +type TextAnnotationContainerFileCitation struct { + // The ID of the container file. + ContainerID string `json:"container_id"` + + // The ID of the file. + FileID string `json:"file_id"` + // The filename of the container file cited. + Filename string `json:"filename"` + + // The index of the first character of the container file citation in the message. + StartIndex int64 `json:"start_index"` + // The index of the last character of the container file citation in the message. + EndIndex int64 `json:"end_index"` +} + +type TextAnnotationFilePath struct { + // The ID of the file. + FileID string `json:"file_id"` + + // The index of the file in the list of files. + Index int64 `json:"index"` +} diff --git a/schema/openai/types.go b/schema/openai/types.go new file mode 100644 index 000000000..60cee4361 --- /dev/null +++ b/schema/openai/types.go @@ -0,0 +1,10 @@ +package openai + +type TextAnnotationType string + +const ( + TextAnnotationTypeFileCitation TextAnnotationType = "file_citation" + TextAnnotationTypeURLCitation TextAnnotationType = "url_citation" + TextAnnotationTypeContainerFileCitation TextAnnotationType = "container_file_citation" + TextAnnotationTypeFilePath TextAnnotationType = "file_path" +) From cc213db8272c8d1cc05871307d15b7233f31a7b4 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Tue, 25 Nov 2025 11:35:01 +0800 Subject: [PATCH 02/59] feat: change the index in StreamMeta to a non-pointer (#573) --- schema/agentic_message.go | 14 ++------------ schema/anthropic/citation.go | 16 ++++++++++++++++ schema/anthropic/types.go | 16 ++++++++++++++++ schema/google/candidate_meta.go | 16 ++++++++++++++++ schema/openai/annotation.go | 16 ++++++++++++++++ schema/openai/types.go | 16 ++++++++++++++++ 6 files changed, 82 insertions(+), 12 deletions(-) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index e386ed044..84a933c9e 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -75,10 +75,8 @@ type AgenticResponseMeta struct { } type StreamMeta struct { - // Index is used for streaming to identify the chunk of the block for concatenation. - Index *int - // Streaming phase of the content block. - Phase StreamPhase + // Index is the index position of this block in the final response. + Index int } type ContentBlock struct { @@ -120,14 +118,6 @@ type ContentBlock struct { StreamMeta *StreamMeta } -type StreamPhase string - -const ( - StreamPhaseStart StreamPhase = "start" - StreamPhaseDelta StreamPhase = "delta" - StreamPhaseStop StreamPhase = "stop" -) - type UserInputText struct { Text string diff --git a/schema/anthropic/citation.go b/schema/anthropic/citation.go index 24a4c5aa6..064477688 100644 --- a/schema/anthropic/citation.go +++ b/schema/anthropic/citation.go @@ -1,3 +1,19 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package anthropic type TextCitation struct { diff --git a/schema/anthropic/types.go b/schema/anthropic/types.go index fbc85475d..cc8b1f877 100644 --- a/schema/anthropic/types.go +++ b/schema/anthropic/types.go @@ -1,3 +1,19 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package anthropic type TextCitationType string diff --git a/schema/google/candidate_meta.go b/schema/google/candidate_meta.go index aead31c2e..8cd324254 100644 --- a/schema/google/candidate_meta.go +++ b/schema/google/candidate_meta.go @@ -1,3 +1,19 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package google type CandidateMeta struct { diff --git a/schema/openai/annotation.go b/schema/openai/annotation.go index a834e8072..ad4b6b91f 100644 --- a/schema/openai/annotation.go +++ b/schema/openai/annotation.go @@ -1,3 +1,19 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package openai type TextAnnotation struct { diff --git a/schema/openai/types.go b/schema/openai/types.go index 60cee4361..321ee2a9e 100644 --- a/schema/openai/types.go +++ b/schema/openai/types.go @@ -1,3 +1,19 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package openai type TextAnnotationType string From 52ad46f9e077c6ca10e5ed5ec5987bc59c1f12ad Mon Sep 17 00:00:00 2001 From: mrh997 Date: Tue, 25 Nov 2025 21:34:26 +0800 Subject: [PATCH 03/59] feat: improve AgenticResponseMeta definition (#575) --- schema/agentic_message.go | 30 +++++++---------- schema/{anthropic => claude}/citation.go | 2 +- schema/claude/messages_meta.go | 22 +++++++++++++ schema/{anthropic => claude}/types.go | 2 +- .../response_meta.go} | 6 ++-- schema/openai/response_meta.go | 33 +++++++++++++++++++ 6 files changed, 72 insertions(+), 23 deletions(-) rename schema/{anthropic => claude}/citation.go (99%) create mode 100644 schema/claude/messages_meta.go rename schema/{anthropic => claude}/types.go (98%) rename schema/{google/candidate_meta.go => gemini/response_meta.go} (95%) create mode 100644 schema/openai/response_meta.go diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 84a933c9e..0d2b01047 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -17,8 +17,8 @@ package schema import ( - "github.com/cloudwego/eino/schema/anthropic" - "github.com/cloudwego/eino/schema/google" + "github.com/cloudwego/eino/schema/claude" + "github.com/cloudwego/eino/schema/gemini" "github.com/cloudwego/eino/schema/openai" "github.com/eino-contrib/jsonschema" ) @@ -66,16 +66,16 @@ type AgenticMessage struct { } type AgenticResponseMeta struct { - Status *string - FinishReason string - TokenUsage *TokenUsage - GoogleAdditionalMeta *google.CandidateMeta + OpenAIExtensions *openai.ResponseMeta + GeminiExtensions *gemini.ResponseMeta + ClaudeExtensions *claude.MessageMeta + Extensions any } type StreamMeta struct { - // Index is the index position of this block in the final response. + // Index specifies the index position of this block in the final response. Index int } @@ -166,8 +166,8 @@ type UserInputFile struct { type AssistantGenText struct { Text string - OpenAIAnnotations []*openai.TextAnnotation - AnthropicCitations []*anthropic.TextCitation + OpenAIAnnotations []*openai.TextAnnotation + ClaudeCitations []*claude.TextCitation // Extra stores additional information. Extra map[string]any @@ -203,7 +203,6 @@ type AssistantGenVideo struct { type Reasoning struct { // Summary is the reasoning content summary. Summary []*ReasoningSummary - // EncryptedContent is the encrypted reasoning content. EncryptedContent string @@ -212,8 +211,8 @@ type Reasoning struct { } type ReasoningSummary struct { - // Index specifies the ReasoningSummary chunk to be concatenated during streaming. - Index *int + // Index specifies the index position of this summary in the final Reasoning. + Index int Text string } @@ -221,10 +220,8 @@ type ReasoningSummary struct { type FunctionToolCall struct { // CallID is the unique identifier for the tool call. CallID string - // Name specifies the function tool invoked. Name string - // Arguments is the JSON string arguments for the function tool call. Arguments string @@ -235,10 +232,8 @@ type FunctionToolCall struct { type FunctionToolResult struct { // CallID is the unique identifier for the tool call. CallID string - // Name specifies the function tool invoked. Name string - // Result is the function tool result returned by the user Result string @@ -250,15 +245,12 @@ type ServerToolCall struct { // Name specifies the server-side tool invoked. // Supplied by the model server (e.g., `web_search` for OpenAI, `googleSearch` for Gemini). Name string - // CallID is the unique identifier for the tool call. // Empty if not provided by the model server. CallID string - // Arguments are the raw inputs to the server-side tool, // supplied by the component implementer. Arguments any - // Extra stores additional information. Extra map[string]any } diff --git a/schema/anthropic/citation.go b/schema/claude/citation.go similarity index 99% rename from schema/anthropic/citation.go rename to schema/claude/citation.go index 064477688..b5092d3bc 100644 --- a/schema/anthropic/citation.go +++ b/schema/claude/citation.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package anthropic +package claude type TextCitation struct { Type TextCitationType `json:"type"` diff --git a/schema/claude/messages_meta.go b/schema/claude/messages_meta.go new file mode 100644 index 000000000..a72dded2a --- /dev/null +++ b/schema/claude/messages_meta.go @@ -0,0 +1,22 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package claude + +type MessageMeta struct { + ID string `json:"id"` + StopReason string `json:"stop_reason"` +} diff --git a/schema/anthropic/types.go b/schema/claude/types.go similarity index 98% rename from schema/anthropic/types.go rename to schema/claude/types.go index cc8b1f877..cbf8784f6 100644 --- a/schema/anthropic/types.go +++ b/schema/claude/types.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package anthropic +package claude type TextCitationType string diff --git a/schema/google/candidate_meta.go b/schema/gemini/response_meta.go similarity index 95% rename from schema/google/candidate_meta.go rename to schema/gemini/response_meta.go index 8cd324254..3bc590a72 100644 --- a/schema/google/candidate_meta.go +++ b/schema/gemini/response_meta.go @@ -14,9 +14,11 @@ * limitations under the License. */ -package google +package gemini -type CandidateMeta struct { +type ResponseMeta struct { + ID string `json:"id"` + FinishReason string `json:"finish_reason"` GroundingMeta *GroundingMetadata `json:"grounding_meta,omitempty"` } diff --git a/schema/openai/response_meta.go b/schema/openai/response_meta.go new file mode 100644 index 000000000..1b184073e --- /dev/null +++ b/schema/openai/response_meta.go @@ -0,0 +1,33 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +type ResponseMeta struct { + ID string `json:"id"` + Status string `json:"status"` + Error *ResponseError `json:"error,omitempty"` + IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` +} + +type ResponseError struct { + Code string `json:"code"` + Message string `json:"message"` +} + +type IncompleteDetails struct { + Reason string `json:"reason"` +} From f74acacb32be6c8e90dd71ea641a1b419ee4a95d Mon Sep 17 00:00:00 2001 From: mrh997 Date: Wed, 26 Nov 2025 16:08:32 +0800 Subject: [PATCH 04/59] feat: improve AssistantGenText definition (#577) --- schema/agentic_message.go | 5 +++-- schema/claude/{types.go => consts.go} | 0 schema/claude/{citation.go => content_block.go} | 4 ++++ schema/claude/{messages_meta.go => message_meta.go} | 0 schema/openai/{types.go => consts.go} | 0 schema/openai/{annotation.go => content_block.go} | 5 +++++ 6 files changed, 12 insertions(+), 2 deletions(-) rename schema/claude/{types.go => consts.go} (100%) rename schema/claude/{citation.go => content_block.go} (96%) rename schema/claude/{messages_meta.go => message_meta.go} (100%) rename schema/openai/{types.go => consts.go} (100%) rename schema/openai/{annotation.go => content_block.go} (94%) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 0d2b01047..09ab28e40 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -166,8 +166,9 @@ type UserInputFile struct { type AssistantGenText struct { Text string - OpenAIAnnotations []*openai.TextAnnotation - ClaudeCitations []*claude.TextCitation + OpenAIExtensions *openai.OutputText + ClaudeExtensions *claude.TextBlock + Extensions any // Extra stores additional information. Extra map[string]any diff --git a/schema/claude/types.go b/schema/claude/consts.go similarity index 100% rename from schema/claude/types.go rename to schema/claude/consts.go diff --git a/schema/claude/citation.go b/schema/claude/content_block.go similarity index 96% rename from schema/claude/citation.go rename to schema/claude/content_block.go index b5092d3bc..ba297126e 100644 --- a/schema/claude/citation.go +++ b/schema/claude/content_block.go @@ -16,6 +16,10 @@ package claude +type TextBlock struct { + Citations []*TextCitation `json:"citations"` +} + type TextCitation struct { Type TextCitationType `json:"type"` diff --git a/schema/claude/messages_meta.go b/schema/claude/message_meta.go similarity index 100% rename from schema/claude/messages_meta.go rename to schema/claude/message_meta.go diff --git a/schema/openai/types.go b/schema/openai/consts.go similarity index 100% rename from schema/openai/types.go rename to schema/openai/consts.go diff --git a/schema/openai/annotation.go b/schema/openai/content_block.go similarity index 94% rename from schema/openai/annotation.go rename to schema/openai/content_block.go index ad4b6b91f..5135964b7 100644 --- a/schema/openai/annotation.go +++ b/schema/openai/content_block.go @@ -16,6 +16,11 @@ package openai +type OutputText struct { + ItemID string `json:"item_id"` + Annotations []*TextAnnotation `json:"annotations"` +} + type TextAnnotation struct { Type TextAnnotationType `json:"type"` From 0186a33ccba2164a65185d2050e8389283e74ccb Mon Sep 17 00:00:00 2001 From: mrh997 Date: Wed, 26 Nov 2025 17:52:22 +0800 Subject: [PATCH 05/59] feat: improve extension type name (#578) --- schema/agentic_message.go | 14 +++++++------- schema/claude/content_block.go | 2 +- .../claude/{message_meta.go => response_meta.go} | 2 +- schema/gemini/response_meta.go | 2 +- schema/openai/content_block.go | 2 +- schema/openai/response_meta.go | 2 +- 6 files changed, 12 insertions(+), 12 deletions(-) rename schema/claude/{message_meta.go => response_meta.go} (95%) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 09ab28e40..3530038e2 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -68,10 +68,10 @@ type AgenticMessage struct { type AgenticResponseMeta struct { TokenUsage *TokenUsage - OpenAIExtensions *openai.ResponseMeta - GeminiExtensions *gemini.ResponseMeta - ClaudeExtensions *claude.MessageMeta - Extensions any + OpenAIExtension *openai.ResponseMetaExtension + GeminiExtension *gemini.ResponseMetaExtension + ClaudeExtension *claude.ResponseMetaExtension + Extension any } type StreamMeta struct { @@ -166,9 +166,9 @@ type UserInputFile struct { type AssistantGenText struct { Text string - OpenAIExtensions *openai.OutputText - ClaudeExtensions *claude.TextBlock - Extensions any + OpenAIExtension *openai.AssistantGenTextExtension + ClaudeExtension *claude.AssistantGenTextExtension + Extension any // Extra stores additional information. Extra map[string]any diff --git a/schema/claude/content_block.go b/schema/claude/content_block.go index ba297126e..4421db807 100644 --- a/schema/claude/content_block.go +++ b/schema/claude/content_block.go @@ -16,7 +16,7 @@ package claude -type TextBlock struct { +type AssistantGenTextExtension struct { Citations []*TextCitation `json:"citations"` } diff --git a/schema/claude/message_meta.go b/schema/claude/response_meta.go similarity index 95% rename from schema/claude/message_meta.go rename to schema/claude/response_meta.go index a72dded2a..7d9dbe740 100644 --- a/schema/claude/message_meta.go +++ b/schema/claude/response_meta.go @@ -16,7 +16,7 @@ package claude -type MessageMeta struct { +type ResponseMetaExtension struct { ID string `json:"id"` StopReason string `json:"stop_reason"` } diff --git a/schema/gemini/response_meta.go b/schema/gemini/response_meta.go index 3bc590a72..bb4af92c9 100644 --- a/schema/gemini/response_meta.go +++ b/schema/gemini/response_meta.go @@ -16,7 +16,7 @@ package gemini -type ResponseMeta struct { +type ResponseMetaExtension struct { ID string `json:"id"` FinishReason string `json:"finish_reason"` GroundingMeta *GroundingMetadata `json:"grounding_meta,omitempty"` diff --git a/schema/openai/content_block.go b/schema/openai/content_block.go index 5135964b7..dfa83109a 100644 --- a/schema/openai/content_block.go +++ b/schema/openai/content_block.go @@ -16,7 +16,7 @@ package openai -type OutputText struct { +type AssistantGenTextExtension struct { ItemID string `json:"item_id"` Annotations []*TextAnnotation `json:"annotations"` } diff --git a/schema/openai/response_meta.go b/schema/openai/response_meta.go index 1b184073e..809fbb1a5 100644 --- a/schema/openai/response_meta.go +++ b/schema/openai/response_meta.go @@ -16,7 +16,7 @@ package openai -type ResponseMeta struct { +type ResponseMetaExtension struct { ID string `json:"id"` Status string `json:"status"` Error *ResponseError `json:"error,omitempty"` From 2bc09665d673cd15078108b7b04e8c0b2e2b5912 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Wed, 26 Nov 2025 20:23:14 +0800 Subject: [PATCH 06/59] feat: modify package name (#579) --- components/{agency => agentic}/callback_extra.go | 2 +- components/{agency => agentic}/interface.go | 6 +++--- components/{agency => agentic}/option.go | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) rename components/{agency => agentic}/callback_extra.go (99%) rename components/{agency => agentic}/interface.go (89%) rename components/{agency => agentic}/option.go (99%) diff --git a/components/agency/callback_extra.go b/components/agentic/callback_extra.go similarity index 99% rename from components/agency/callback_extra.go rename to components/agentic/callback_extra.go index 984756d1a..f824750f9 100644 --- a/components/agency/callback_extra.go +++ b/components/agentic/callback_extra.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package agency +package agentic import ( "github.com/cloudwego/eino/callbacks" diff --git a/components/agency/interface.go b/components/agentic/interface.go similarity index 89% rename from components/agency/interface.go rename to components/agentic/interface.go index e33d6a933..e9960d332 100644 --- a/components/agency/interface.go +++ b/components/agentic/interface.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package agency +package agentic import ( "context" @@ -22,8 +22,8 @@ import ( "github.com/cloudwego/eino/schema" ) -type AgenticModel interface { +type Model interface { Generate(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.AgenticMessage, error) Stream(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.StreamReader[*schema.AgenticMessage], error) - WithTools(tools []*schema.ToolInfo) (AgenticModel, error) + WithTools(tools []*schema.ToolInfo) (Model, error) } diff --git a/components/agency/option.go b/components/agentic/option.go similarity index 99% rename from components/agency/option.go rename to components/agentic/option.go index 17028f6e9..b000d6893 100644 --- a/components/agency/option.go +++ b/components/agentic/option.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package agency +package agentic // Options is the common options for the model. type Options struct { From 428bdcf0c39de17dfa5c83f4af1e788b811b8ff0 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Wed, 26 Nov 2025 21:06:29 +0800 Subject: [PATCH 07/59] feat: remove TokenUsage definition in CallbackOutput (#580) --- components/agentic/callback_extra.go | 31 ++++++---------------------- 1 file changed, 6 insertions(+), 25 deletions(-) diff --git a/components/agentic/callback_extra.go b/components/agentic/callback_extra.go index f824750f9..f35b37779 100644 --- a/components/agentic/callback_extra.go +++ b/components/agentic/callback_extra.go @@ -21,23 +21,6 @@ import ( "github.com/cloudwego/eino/schema" ) -// TokenUsageMeta is the token usage for the model. -type TokenUsageMeta struct { - InputTokens int64 `json:"input_tokens"` - InputTokensDetails InputTokensUsageDetails `json:"input_tokens_details"` - OutputTokens int64 `json:"output_tokens"` - OutputTokensDetails OutputTokensUsageDetails `json:"output_tokens_details"` - TotalTokens int64 `json:"total_tokens"` -} - -type InputTokensUsageDetails struct { - CachedTokens int64 `json:"cached_tokens"` -} - -type OutputTokensUsageDetails struct { - ReasoningTokens int64 `json:"reasoning_tokens"` -} - // Config is the config for the model. type Config struct { // Model is the model name. @@ -50,8 +33,8 @@ type Config struct { // CallbackInput is the input for the model callback. type CallbackInput struct { - // Responses is the responses to be sent to the model. - Responses []*schema.AgenticMessage + // Messages is the messages to be sent to the model. + Messages []*schema.AgenticMessage // Tools is the tools to be used in the model. Tools []*schema.ToolInfo // Config is the config for the model. @@ -62,12 +45,10 @@ type CallbackInput struct { // CallbackOutput is the output for the model callback. type CallbackOutput struct { - // Response is the response generated by the model. - Response *schema.AgenticMessage + // Message is the message generated by the model. + Message *schema.AgenticMessage // Config is the config for the model. Config *Config - // Usage is the token usage of this request. - Usage *TokenUsageMeta // Extra is the extra information for the callback. Extra map[string]any } @@ -79,7 +60,7 @@ func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { return t case []*schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the input is the input of Chat Model interface, which is []*schema.AgenticMessage return &CallbackInput{ - Responses: t, + Messages: t, } default: return nil @@ -93,7 +74,7 @@ func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { return t case *schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the output is the output of Chat Model interface, which is *schema.AgenticMessage return &CallbackOutput{ - Response: t, + Message: t, } default: return nil From 1e7d37d94a74550bff30b4ec2b5b171d175582dd Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 1 Dec 2025 11:51:09 +0800 Subject: [PATCH 08/59] feat: add helper functions for AgenticMessage (#582) --- components/agentic/option.go | 62 +++++++++++++++++ schema/agentic_message.go | 119 ++++++++++++++++++++++++++++----- schema/openai/content_block.go | 1 - 3 files changed, 164 insertions(+), 18 deletions(-) diff --git a/components/agentic/option.go b/components/agentic/option.go index b000d6893..dae6139be 100644 --- a/components/agentic/option.go +++ b/components/agentic/option.go @@ -16,8 +16,22 @@ package agentic +import ( + "github.com/cloudwego/eino/schema" +) + // Options is the common options for the model. type Options struct { + // Temperature is the temperature for the model, which controls the randomness of the model. + Temperature *float32 + // Model is the model name. + Model *string + // TopP is the top p for the model, which controls the diversity of the model. + TopP *float32 + // Tools is a list of tools the model may call. + Tools []*schema.ToolInfo + // ToolChoice controls which tool is called by the model. + ToolChoice *schema.ToolChoice } // Option is the call option for ChatModel component. @@ -27,6 +41,54 @@ type Option struct { implSpecificOptFn any } +// WithTemperature is the option to set the temperature for the model. +func WithTemperature(temperature float32) Option { + return Option{ + apply: func(opts *Options) { + opts.Temperature = &temperature + }, + } +} + +// WithModel is the option to set the model name. +func WithModel(name string) Option { + return Option{ + apply: func(opts *Options) { + opts.Model = &name + }, + } +} + +// WithTopP is the option to set the top p for the model. +func WithTopP(topP float32) Option { + return Option{ + apply: func(opts *Options) { + opts.TopP = &topP + }, + } +} + +// WithTools is the option to set tools for the model. +func WithTools(tools []*schema.ToolInfo) Option { + if tools == nil { + tools = []*schema.ToolInfo{} + } + return Option{ + apply: func(opts *Options) { + opts.Tools = tools + }, + } +} + +// WithToolChoice is the option to set tool choice for the model. +func WithToolChoice(toolChoice schema.ToolChoice) Option { + return Option{ + apply: func(opts *Options) { + opts.ToolChoice = &toolChoice + }, + } +} + // WrapImplSpecificOptFn is the option to wrap the implementation specific option function. func WrapImplSpecificOptFn[T any](optFn func(*T)) Option { return Option{ diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 3530038e2..3953fac7c 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -126,8 +126,8 @@ type UserInputText struct { } type UserInputImage struct { - URL *string - Base64Data *string + URL string + Base64Data string MIMEType string Detail ImageURLDetail @@ -136,8 +136,8 @@ type UserInputImage struct { } type UserInputAudio struct { - URL *string - Base64Data *string + URL string + Base64Data string MIMEType string // Extra stores additional information. @@ -145,8 +145,8 @@ type UserInputAudio struct { } type UserInputVideo struct { - URL *string - Base64Data *string + URL string + Base64Data string MIMEType string // Extra stores additional information. @@ -154,9 +154,9 @@ type UserInputVideo struct { } type UserInputFile struct { - URL *string - Name *string - Base64Data *string + URL string + Name string + Base64Data string MIMEType string // Extra stores additional information. @@ -175,8 +175,8 @@ type AssistantGenText struct { } type AssistantGenImage struct { - URL *string - Base64Data *string + URL string + Base64Data string MIMEType string // Extra stores additional information. @@ -184,8 +184,8 @@ type AssistantGenImage struct { } type AssistantGenAudio struct { - URL *string - Base64Data *string + URL string + Base64Data string MIMEType string // Extra stores additional information. @@ -193,8 +193,8 @@ type AssistantGenAudio struct { } type AssistantGenVideo struct { - URL *string - Base64Data *string + URL string + Base64Data string MIMEType string // Extra stores additional information. @@ -290,6 +290,8 @@ type MCPToolCall struct { } type MCPToolResult struct { + // ServerLabel is the MCP server label used to identify it in tool calls + ServerLabel string // CallID is the unique ID of the tool call. CallID string // Name is the name of the tool to run. @@ -304,7 +306,7 @@ type MCPToolResult struct { } type MCPToolCallError struct { - Code int64 + Code *int64 Error string } @@ -312,7 +314,7 @@ type MCPListToolsResult struct { // ServerLabel is the MCP server label used to identify it in tool calls. ServerLabel string // Tools is the list of tools available on the server. - Tools []MCPListToolsItem + Tools []*MCPListToolsItem // Error returned when the server fails to list tools. Error string @@ -355,3 +357,86 @@ type MCPToolApprovalResponse struct { // Extra stores additional information. Extra map[string]any } + +// DeveloperAgenticMessage represents a message with AgenticRoleType "developer". +func DeveloperAgenticMessage(text string) *AgenticMessage { + return &AgenticMessage{ + Role: AgenticRoleTypeDeveloper, + ContentBlocks: []*ContentBlock{NewContentBlock(&UserInputText{Text: text})}, + } +} + +// SystemAgenticMessage represents a message with AgenticRoleType "system". +func SystemAgenticMessage(text string) *AgenticMessage { + return &AgenticMessage{ + Role: AgenticRoleTypeSystem, + ContentBlocks: []*ContentBlock{NewContentBlock(&UserInputText{Text: text})}, + } +} + +// UserAgenticMessage represents a message with AgenticRoleType "user". +func UserAgenticMessage(text string) *AgenticMessage { + return &AgenticMessage{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{NewContentBlock(&UserInputText{Text: text})}, + } +} + +// FunctionToolResultAgenticMessage represents a function tool result message with AgenticRoleType "user". +func FunctionToolResultAgenticMessage(callID, name, result string) *AgenticMessage { + return &AgenticMessage{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + NewContentBlock(&FunctionToolResult{ + CallID: callID, + Name: name, + Result: result, + }), + }, + } +} + +func NewContentBlock(block any) *ContentBlock { + switch b := block.(type) { + case *Reasoning: + return &ContentBlock{Type: ContentBlockTypeReasoning, Reasoning: b} + case *UserInputText: + return &ContentBlock{Type: ContentBlockTypeUserInputText, UserInputText: b} + case *UserInputImage: + return &ContentBlock{Type: ContentBlockTypeUserInputImage, UserInputImage: b} + case *UserInputAudio: + return &ContentBlock{Type: ContentBlockTypeUserInputAudio, UserInputAudio: b} + case *UserInputVideo: + return &ContentBlock{Type: ContentBlockTypeUserInputVideo, UserInputVideo: b} + case *UserInputFile: + return &ContentBlock{Type: ContentBlockTypeUserInputFile, UserInputFile: b} + case *AssistantGenText: + return &ContentBlock{Type: ContentBlockTypeAssistantGenText, AssistantGenText: b} + case *AssistantGenImage: + return &ContentBlock{Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: b} + case *AssistantGenAudio: + return &ContentBlock{Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: b} + case *AssistantGenVideo: + return &ContentBlock{Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: b} + case *FunctionToolCall: + return &ContentBlock{Type: ContentBlockTypeFunctionToolCall, FunctionToolCall: b} + case *FunctionToolResult: + return &ContentBlock{Type: ContentBlockTypeFunctionToolResult, FunctionToolResult: b} + case *ServerToolCall: + return &ContentBlock{Type: ContentBlockTypeServerToolCall, ServerToolCall: b} + case *ServerToolResult: + return &ContentBlock{Type: ContentBlockTypeServerToolResult, ServerToolResult: b} + case *MCPToolCall: + return &ContentBlock{Type: ContentBlockTypeMCPToolCall, MCPToolCall: b} + case *MCPToolResult: + return &ContentBlock{Type: ContentBlockTypeMCPToolResult, MCPToolResult: b} + case *MCPListToolsResult: + return &ContentBlock{Type: ContentBlockTypeMCPListTools, MCPListToolsResult: b} + case *MCPToolApprovalRequest: + return &ContentBlock{Type: ContentBlockTypeMCPToolApprovalRequest, MCPToolApprovalRequest: b} + case *MCPToolApprovalResponse: + return &ContentBlock{Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: b} + default: + return nil + } +} diff --git a/schema/openai/content_block.go b/schema/openai/content_block.go index dfa83109a..b0408e310 100644 --- a/schema/openai/content_block.go +++ b/schema/openai/content_block.go @@ -17,7 +17,6 @@ package openai type AssistantGenTextExtension struct { - ItemID string `json:"item_id"` Annotations []*TextAnnotation `json:"annotations"` } From 1f4004cf21d6d25e28ad603ee2978fd0e65a37f2 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 1 Dec 2025 12:08:12 +0800 Subject: [PATCH 09/59] feat: improve MCPToolCallError definition (#592) --- schema/agentic_message.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 3953fac7c..44fc37bb3 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -306,8 +306,8 @@ type MCPToolResult struct { } type MCPToolCallError struct { - Code *int64 - Error string + Code *int64 + Message string } type MCPListToolsResult struct { From bec326658f1168f50d4b68fc277b49e19f82defa Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 1 Dec 2025 14:15:54 +0800 Subject: [PATCH 10/59] feat: improve Options definition (#593) --- components/agentic/option.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/components/agentic/option.go b/components/agentic/option.go index dae6139be..ac117ddb4 100644 --- a/components/agentic/option.go +++ b/components/agentic/option.go @@ -23,11 +23,11 @@ import ( // Options is the common options for the model. type Options struct { // Temperature is the temperature for the model, which controls the randomness of the model. - Temperature *float32 + Temperature *float64 // Model is the model name. Model *string // TopP is the top p for the model, which controls the diversity of the model. - TopP *float32 + TopP *float64 // Tools is a list of tools the model may call. Tools []*schema.ToolInfo // ToolChoice controls which tool is called by the model. @@ -42,7 +42,7 @@ type Option struct { } // WithTemperature is the option to set the temperature for the model. -func WithTemperature(temperature float32) Option { +func WithTemperature(temperature float64) Option { return Option{ apply: func(opts *Options) { opts.Temperature = &temperature @@ -60,7 +60,7 @@ func WithModel(name string) Option { } // WithTopP is the option to set the top p for the model. -func WithTopP(topP float32) Option { +func WithTopP(topP float64) Option { return Option{ apply: func(opts *Options) { opts.TopP = &topP From 9c9b8bee007948a0b4de419eda8e65e0ab311b3c Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 1 Dec 2025 14:41:51 +0800 Subject: [PATCH 11/59] feat: add CallbackInput definition for CallbackInput (#594) --- components/agentic/callback_extra.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/components/agentic/callback_extra.go b/components/agentic/callback_extra.go index f35b37779..389408d33 100644 --- a/components/agentic/callback_extra.go +++ b/components/agentic/callback_extra.go @@ -37,6 +37,8 @@ type CallbackInput struct { Messages []*schema.AgenticMessage // Tools is the tools to be used in the model. Tools []*schema.ToolInfo + // ToolChoice controls which tool is called by the model. + ToolChoice *schema.ToolChoice // Config is the config for the model. Config *Config // Extra is the extra information for the callback. From df5eea08966f016fcc1bbefd0e81b89fd9839271 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 1 Dec 2025 17:01:51 +0800 Subject: [PATCH 12/59] feat: define 'omitempty' flag in json tag (#595) --- schema/agentic_message.go | 4 ++-- schema/claude/content_block.go | 42 +++++++++++++++++----------------- schema/claude/response_meta.go | 4 ++-- schema/gemini/response_meta.go | 4 ++-- schema/openai/content_block.go | 32 +++++++++++++------------- schema/openai/response_meta.go | 10 ++++---- 6 files changed, 48 insertions(+), 48 deletions(-) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 44fc37bb3..367debd97 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -76,7 +76,7 @@ type AgenticResponseMeta struct { type StreamMeta struct { // Index specifies the index position of this block in the final response. - Index int + Index int64 } type ContentBlock struct { @@ -213,7 +213,7 @@ type Reasoning struct { type ReasoningSummary struct { // Index specifies the index position of this summary in the final Reasoning. - Index int + Index int64 Text string } diff --git a/schema/claude/content_block.go b/schema/claude/content_block.go index 4421db807..0c43d1045 100644 --- a/schema/claude/content_block.go +++ b/schema/claude/content_block.go @@ -17,11 +17,11 @@ package claude type AssistantGenTextExtension struct { - Citations []*TextCitation `json:"citations"` + Citations []*TextCitation `json:"citations,omitempty"` } type TextCitation struct { - Type TextCitationType `json:"type"` + Type TextCitationType `json:"type,omitempty"` CharLocation *CitationCharLocation `json:"char_location,omitempty"` PageLocation *CitationPageLocation `json:"page_location,omitempty"` @@ -30,40 +30,40 @@ type TextCitation struct { } type CitationCharLocation struct { - CitedText string `json:"cited_text"` + CitedText string `json:"cited_text,omitempty"` - DocumentTitle string `json:"document_title"` - DocumentIndex int64 `json:"document_index"` + DocumentTitle string `json:"document_title,omitempty"` + DocumentIndex int64 `json:"document_index,omitempty"` - StartCharIndex int64 `json:"start_char_index"` - EndCharIndex int64 `json:"end_char_index"` + StartCharIndex int64 `json:"start_char_index,omitempty"` + EndCharIndex int64 `json:"end_char_index,omitempty"` } type CitationPageLocation struct { - CitedText string `json:"cited_text"` + CitedText string `json:"cited_text,omitempty"` - DocumentTitle string `json:"document_title"` - DocumentIndex int64 `json:"document_index"` + DocumentTitle string `json:"document_title,omitempty"` + DocumentIndex int64 `json:"document_index,omitempty"` - StartPageNumber int64 `json:"start_page_number"` - EndPageNumber int64 `json:"end_page_number"` + StartPageNumber int64 `json:"start_page_number,omitempty"` + EndPageNumber int64 `json:"end_page_number,omitempty"` } type CitationContentBlockLocation struct { - CitedText string `json:"cited_text"` + CitedText string `json:"cited_text,omitempty"` - DocumentTitle string `json:"document_title"` - DocumentIndex int64 `json:"document_index"` + DocumentTitle string `json:"document_title,omitempty"` + DocumentIndex int64 `json:"document_index,omitempty"` - StartBlockIndex int64 `json:"start_block_index"` - EndBlockIndex int64 `json:"end_block_index"` + StartBlockIndex int64 `json:"start_block_index,omitempty"` + EndBlockIndex int64 `json:"end_block_index,omitempty"` } type CitationWebSearchResultLocation struct { - CitedText string `json:"cited_text"` + CitedText string `json:"cited_text,omitempty"` - Title string `json:"title"` - URL string `json:"url"` + Title string `json:"title,omitempty"` + URL string `json:"url,omitempty"` - EncryptedIndex string `json:"encrypted_index"` + EncryptedIndex string `json:"encrypted_index,omitempty"` } diff --git a/schema/claude/response_meta.go b/schema/claude/response_meta.go index 7d9dbe740..9f60dd713 100644 --- a/schema/claude/response_meta.go +++ b/schema/claude/response_meta.go @@ -17,6 +17,6 @@ package claude type ResponseMetaExtension struct { - ID string `json:"id"` - StopReason string `json:"stop_reason"` + ID string `json:"id,omitempty"` + StopReason string `json:"stop_reason,omitempty"` } diff --git a/schema/gemini/response_meta.go b/schema/gemini/response_meta.go index bb4af92c9..a5b3f626c 100644 --- a/schema/gemini/response_meta.go +++ b/schema/gemini/response_meta.go @@ -17,8 +17,8 @@ package gemini type ResponseMetaExtension struct { - ID string `json:"id"` - FinishReason string `json:"finish_reason"` + ID string `json:"id,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` GroundingMeta *GroundingMetadata `json:"grounding_meta,omitempty"` } diff --git a/schema/openai/content_block.go b/schema/openai/content_block.go index b0408e310..5d92be8f7 100644 --- a/schema/openai/content_block.go +++ b/schema/openai/content_block.go @@ -17,11 +17,11 @@ package openai type AssistantGenTextExtension struct { - Annotations []*TextAnnotation `json:"annotations"` + Annotations []*TextAnnotation `json:"annotations,omitempty"` } type TextAnnotation struct { - Type TextAnnotationType `json:"type"` + Type TextAnnotationType `json:"type,omitempty"` FileCitation *TextAnnotationFileCitation `json:"file_citation,omitempty"` URLCitation *TextAnnotationURLCitation `json:"url_citation,omitempty"` @@ -31,45 +31,45 @@ type TextAnnotation struct { type TextAnnotationFileCitation struct { // The ID of the file. - FileID string `json:"file_id"` + FileID string `json:"file_id,omitempty"` // The filename of the file cited. - Filename string `json:"filename"` + Filename string `json:"filename,omitempty"` // The index of the file in the list of files. - Index int64 `json:"index"` + Index int64 `json:"index,omitempty"` } type TextAnnotationURLCitation struct { // The title of the web resource. - Title string `json:"title"` + Title string `json:"title,omitempty"` // The URL of the web resource. - URL string `json:"url"` + URL string `json:"url,omitempty"` // The index of the first character of the URL citation in the message. - StartIndex int64 `json:"start_index"` + StartIndex int64 `json:"start_index,omitempty"` // The index of the last character of the URL citation in the message. - EndIndex int64 `json:"end_index"` + EndIndex int64 `json:"end_index,omitempty"` } type TextAnnotationContainerFileCitation struct { // The ID of the container file. - ContainerID string `json:"container_id"` + ContainerID string `json:"container_id,omitempty"` // The ID of the file. - FileID string `json:"file_id"` + FileID string `json:"file_id,omitempty"` // The filename of the container file cited. - Filename string `json:"filename"` + Filename string `json:"filename,omitempty"` // The index of the first character of the container file citation in the message. - StartIndex int64 `json:"start_index"` + StartIndex int64 `json:"start_index,omitempty"` // The index of the last character of the container file citation in the message. - EndIndex int64 `json:"end_index"` + EndIndex int64 `json:"end_index,omitempty"` } type TextAnnotationFilePath struct { // The ID of the file. - FileID string `json:"file_id"` + FileID string `json:"file_id,omitempty"` // The index of the file in the list of files. - Index int64 `json:"index"` + Index int64 `json:"index,omitempty"` } diff --git a/schema/openai/response_meta.go b/schema/openai/response_meta.go index 809fbb1a5..90e884173 100644 --- a/schema/openai/response_meta.go +++ b/schema/openai/response_meta.go @@ -17,17 +17,17 @@ package openai type ResponseMetaExtension struct { - ID string `json:"id"` - Status string `json:"status"` + ID string `json:"id,omitempty"` + Status string `json:"status,omitempty"` Error *ResponseError `json:"error,omitempty"` IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` } type ResponseError struct { - Code string `json:"code"` - Message string `json:"message"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` } type IncompleteDetails struct { - Reason string `json:"reason"` + Reason string `json:"reason,omitempty"` } From edafdbcea9cc59f58054b5af1a3f0d3a5e444476 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Tue, 2 Dec 2025 19:09:01 +0800 Subject: [PATCH 13/59] fix: MCPToolApprovalRequest definition (#600) --- schema/agentic_message.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 367debd97..93dd817ca 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -42,7 +42,7 @@ const ( ContentBlockTypeServerToolResult ContentBlockType = "server_tool_result" ContentBlockTypeMCPToolCall ContentBlockType = "mcp_tool_call" ContentBlockTypeMCPToolResult ContentBlockType = "mcp_tool_result" - ContentBlockTypeMCPListTools ContentBlockType = "mcp_list_tools" + ContentBlockTypeMCPListToolsResult ContentBlockType = "mcp_list_tools_result" ContentBlockTypeMCPToolApprovalRequest ContentBlockType = "mcp_tool_approval_request" ContentBlockTypeMCPToolApprovalResponse ContentBlockType = "mcp_tool_approval_response" ) @@ -332,8 +332,8 @@ type MCPListToolsItem struct { } type MCPToolApprovalRequest struct { - // CallID is the unique ID of the tool call. - CallID string + // ID is the approval request ID. + ID string // Name is the name of the tool to run. Name string // Arguments is the JSON string arguments for the tool call. @@ -431,7 +431,7 @@ func NewContentBlock(block any) *ContentBlock { case *MCPToolResult: return &ContentBlock{Type: ContentBlockTypeMCPToolResult, MCPToolResult: b} case *MCPListToolsResult: - return &ContentBlock{Type: ContentBlockTypeMCPListTools, MCPListToolsResult: b} + return &ContentBlock{Type: ContentBlockTypeMCPListToolsResult, MCPListToolsResult: b} case *MCPToolApprovalRequest: return &ContentBlock{Type: ContentBlockTypeMCPToolApprovalRequest, MCPToolApprovalRequest: b} case *MCPToolApprovalResponse: From ed90509c79cb392ef3e54d51105eb35de693a885 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Wed, 3 Dec 2025 15:24:00 +0800 Subject: [PATCH 14/59] feat: define StreamResponseError for openai (#601) --- schema/openai/response_meta.go | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/schema/openai/response_meta.go b/schema/openai/response_meta.go index 90e884173..e1933065b 100644 --- a/schema/openai/response_meta.go +++ b/schema/openai/response_meta.go @@ -17,10 +17,11 @@ package openai type ResponseMetaExtension struct { - ID string `json:"id,omitempty"` - Status string `json:"status,omitempty"` - Error *ResponseError `json:"error,omitempty"` - IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` + ID string `json:"id,omitempty"` + Status string `json:"status,omitempty"` + Error *ResponseError `json:"error,omitempty"` + StreamError *StreamResponseError `json:"stream_error,omitempty"` + IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` } type ResponseError struct { @@ -28,6 +29,12 @@ type ResponseError struct { Message string `json:"message,omitempty"` } +type StreamResponseError struct { + Code string + Message string + Param string +} + type IncompleteDetails struct { Reason string `json:"reason,omitempty"` } From 6a017079061f22dda4bd472c7d0ebc402ab2f5b9 Mon Sep 17 00:00:00 2001 From: Megumin Date: Wed, 3 Dec 2025 17:22:51 +0800 Subject: [PATCH 15/59] feat: support agentic message concat (#576) feat(agentic_model): - format print - support agentic chat template - support to compose agentic odel&agentic tools node - support agentic tool node - support agentic message concat --- components/agentic/callback_extra_test.go | 35 + components/agentic/option_test.go | 79 + components/prompt/callback_extra.go | 38 + components/prompt/chat_template_agentic.go | 84 + .../prompt/chat_template_agentic_test.go | 111 ++ components/prompt/interface.go | 5 + components/types.go | 1 + compose/chain.go | 43 +- compose/chain_branch.go | 49 +- compose/chain_parallel.go | 43 + compose/component_to_graph_node.go | 33 + compose/graph.go | 40 +- compose/tools_node_agentic.go | 125 ++ compose/tools_node_agentic_test.go | 244 +++ compose/types.go | 15 +- internal/concat.go | 6 +- schema/agentic_message.go | 1352 ++++++++++++++++ schema/agentic_message_test.go | 1381 +++++++++++++++++ schema/message.go | 69 +- 19 files changed, 3707 insertions(+), 46 deletions(-) create mode 100644 components/agentic/callback_extra_test.go create mode 100644 components/agentic/option_test.go create mode 100644 components/prompt/chat_template_agentic.go create mode 100644 components/prompt/chat_template_agentic_test.go create mode 100644 compose/tools_node_agentic.go create mode 100644 compose/tools_node_agentic_test.go create mode 100644 schema/agentic_message_test.go diff --git a/components/agentic/callback_extra_test.go b/components/agentic/callback_extra_test.go new file mode 100644 index 000000000..a77da6cd2 --- /dev/null +++ b/components/agentic/callback_extra_test.go @@ -0,0 +1,35 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package agentic + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestConvModel(t *testing.T) { + assert.NotNil(t, ConvCallbackInput(&CallbackInput{})) + assert.NotNil(t, ConvCallbackInput([]*schema.AgenticMessage{})) + assert.Nil(t, ConvCallbackInput("asd")) + + assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{})) + assert.NotNil(t, ConvCallbackOutput(&schema.AgenticMessage{})) + assert.Nil(t, ConvCallbackOutput("asd")) +} diff --git a/components/agentic/option_test.go b/components/agentic/option_test.go new file mode 100644 index 000000000..d349f35ac --- /dev/null +++ b/components/agentic/option_test.go @@ -0,0 +1,79 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package agentic + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestCommon(t *testing.T) { + o := GetCommonOptions(nil, + WithTools([]*schema.ToolInfo{{Name: "test"}}), + WithModel("test"), + WithTemperature(0.1), + WithToolChoice(schema.ToolChoiceAllowed), + WithTopP(0.1), + ) + assert.Len(t, o.Tools, 1) + assert.Equal(t, "test", o.Tools[0].Name) + assert.Equal(t, "test", *o.Model) + assert.Equal(t, float64(0.1), *o.Temperature) + assert.Equal(t, schema.ToolChoiceAllowed, *o.ToolChoice) + assert.Equal(t, float64(0.1), *o.TopP) +} + +func TestImplSpecificOpts(t *testing.T) { + type implSpecificOptions struct { + conf string + index int + } + + withConf := func(conf string) func(o *implSpecificOptions) { + return func(o *implSpecificOptions) { + o.conf = conf + } + } + + withIndex := func(index int) func(o *implSpecificOptions) { + return func(o *implSpecificOptions) { + o.index = index + } + } + + documentOption1 := WrapImplSpecificOptFn(withConf("test_conf")) + documentOption2 := WrapImplSpecificOptFn(withIndex(1)) + + implSpecificOpts := GetImplSpecificOptions(&implSpecificOptions{}, documentOption1, documentOption2) + + assert.Equal(t, &implSpecificOptions{ + conf: "test_conf", + index: 1, + }, implSpecificOpts) + documentOption1 = WrapImplSpecificOptFn(withConf("test_conf")) + documentOption2 = WrapImplSpecificOptFn(withIndex(1)) + + implSpecificOpts = GetImplSpecificOptions(&implSpecificOptions{}, documentOption1, documentOption2) + + assert.Equal(t, &implSpecificOptions{ + conf: "test_conf", + index: 1, + }, implSpecificOpts) +} diff --git a/components/prompt/callback_extra.go b/components/prompt/callback_extra.go index 324a418f3..ff5c3a8ff 100644 --- a/components/prompt/callback_extra.go +++ b/components/prompt/callback_extra.go @@ -21,6 +21,44 @@ import ( "github.com/cloudwego/eino/schema" ) +type AgenticCallbackInput struct { + Variables map[string]any + Templates []schema.AgenticMessagesTemplate + Extra map[string]any +} + +type AgenticCallbackOutput struct { + Result []*schema.AgenticMessage + Templates []schema.AgenticMessagesTemplate + Extra map[string]any +} + +func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput { + switch t := src.(type) { + case *AgenticCallbackInput: + return t + case map[string]any: + return &AgenticCallbackInput{ + Variables: t, + } + default: + return nil + } +} + +func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOutput { + switch t := src.(type) { + case *AgenticCallbackOutput: + return t + case []*schema.AgenticMessage: + return &AgenticCallbackOutput{ + Result: t, + } + default: + return nil + } +} + // CallbackInput is the input for the callback. type CallbackInput struct { // Variables is the variables for the callback. diff --git a/components/prompt/chat_template_agentic.go b/components/prompt/chat_template_agentic.go new file mode 100644 index 000000000..937d46f26 --- /dev/null +++ b/components/prompt/chat_template_agentic.go @@ -0,0 +1,84 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "context" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/schema" +) + +// FromAgenticMessages creates a new DefaultAgenticChatTemplate from the given templates and format type. +// eg. +// +// template := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// // in chain, or graph +// chain := compose.NewChain[map[string]any, []*schema.AgenticMessage]() +// chain.AppendAgenticChatTemplate(template) +func FromAgenticMessages(formatType schema.FormatType, templates ...schema.AgenticMessagesTemplate) *DefaultAgenticChatTemplate { + return &DefaultAgenticChatTemplate{ + templates: templates, + formatType: formatType, + } +} + +type DefaultAgenticChatTemplate struct { + templates []schema.AgenticMessagesTemplate + formatType schema.FormatType +} + +func (t *DefaultAgenticChatTemplate) Format(ctx context.Context, vs map[string]any, opts ...Option) (result []*schema.AgenticMessage, err error) { + ctx = callbacks.EnsureRunInfo(ctx, t.GetType(), components.ComponentOfAgenticPrompt) + ctx = callbacks.OnStart(ctx, &AgenticCallbackInput{ + Variables: vs, + Templates: t.templates, + }) + defer func() { + if err != nil { + _ = callbacks.OnError(ctx, err) + } + }() + + result = make([]*schema.AgenticMessage, 0, len(t.templates)) + for _, template := range t.templates { + msgs, err := template.Format(ctx, vs, t.formatType) + if err != nil { + return nil, err + } + + result = append(result, msgs...) + } + + _ = callbacks.OnEnd(ctx, &AgenticCallbackOutput{ + Result: result, + Templates: t.templates, + }) + + return result, nil +} + +// GetType returns the type of the chat template (Default). +func (t *DefaultAgenticChatTemplate) GetType() string { + return "Default" +} + +// IsCallbacksEnabled checks if the callbacks are enabled for the chat template. +func (t *DefaultAgenticChatTemplate) IsCallbacksEnabled() bool { + return true +} diff --git a/components/prompt/chat_template_agentic_test.go b/components/prompt/chat_template_agentic_test.go new file mode 100644 index 000000000..aaa7d6405 --- /dev/null +++ b/components/prompt/chat_template_agentic_test.go @@ -0,0 +1,111 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestAgenticFormat(t *testing.T) { + pyFmtTestTemplate := []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "{context}"}}, + }, + }, + schema.AgenticMessagesPlaceholder("chat_history", true), + } + jinja2TestTemplate := []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "{{context}}"}}, + }, + }, + schema.AgenticMessagesPlaceholder("chat_history", true), + } + goFmtTestTemplate := []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "{{.context}}"}}, + }, + }, + schema.AgenticMessagesPlaceholder("chat_history", true), + } + testValues := map[string]any{ + "context": "it's beautiful day", + "chat_history": []*schema.AgenticMessage{ + { + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "1"}}, + }, + }, + { + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "2"}}, + }, + }, + }, + } + expected := []*schema.AgenticMessage{ + { + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "it's beautiful day"}}, + }, + }, + { + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "1"}}, + }, + }, + { + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "2"}}, + }, + }, + } + + // FString + chatTemplate := FromAgenticMessages(schema.FString, pyFmtTestTemplate...) + msgs, err := chatTemplate.Format(context.Background(), testValues) + assert.Nil(t, err) + assert.Equal(t, expected, msgs) + + // Jinja2 + chatTemplate = FromAgenticMessages(schema.Jinja2, jinja2TestTemplate...) + msgs, err = chatTemplate.Format(context.Background(), testValues) + assert.Nil(t, err) + assert.Equal(t, expected, msgs) + + // GoTemplate + chatTemplate = FromAgenticMessages(schema.GoTemplate, goFmtTestTemplate...) + msgs, err = chatTemplate.Format(context.Background(), testValues) + assert.Nil(t, err) + assert.Equal(t, expected, msgs) +} diff --git a/components/prompt/interface.go b/components/prompt/interface.go index eac695eda..7ffe7216a 100644 --- a/components/prompt/interface.go +++ b/components/prompt/interface.go @@ -23,6 +23,7 @@ import ( ) var _ ChatTemplate = &DefaultChatTemplate{} +var _ AgenticChatTemplate = &DefaultAgenticChatTemplate{} // ChatTemplate formats a variables map into a list of messages for a ChatModel. // @@ -42,3 +43,7 @@ var _ ChatTemplate = &DefaultChatTemplate{} type ChatTemplate interface { Format(ctx context.Context, vs map[string]any, opts ...Option) ([]*schema.Message, error) } + +type AgenticChatTemplate interface { + Format(ctx context.Context, vs map[string]any, opts ...Option) ([]*schema.AgenticMessage, error) +} diff --git a/components/types.go b/components/types.go index a23d82a68..2ba088e93 100644 --- a/components/types.go +++ b/components/types.go @@ -66,6 +66,7 @@ type Component string const ( // ComponentOfPrompt identifies chat template components. ComponentOfPrompt Component = "ChatTemplate" + ComponentOfAgenticPrompt Component = "AgenticChatTemplate" // ComponentOfChatModel identifies chat model components. ComponentOfChatModel Component = "ChatModel" ComponentOfAgenticModel Component = "AgenticModel" diff --git a/compose/chain.go b/compose/chain.go index 5e4a8e1c0..8484e8767 100644 --- a/compose/chain.go +++ b/compose/chain.go @@ -22,6 +22,7 @@ import ( "fmt" "reflect" + "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -174,6 +175,18 @@ func (c *Chain[I, O]) AppendChatModel(node model.BaseChatModel, opts ...GraphAdd return c } +// AppendAgenticModel add a agentic.Model node to the chain. +// e.g. +// +// model, err := openai.NewAgenticModel(ctx, config) +// if err != nil {...} +// chain.AppendAgenticModel(model) +func (c *Chain[I, O]) AppendAgenticModel(node agentic.Model, opts ...GraphAddNodeOpt) *Chain[I, O] { + gNode, options := toAgenticModelNode(node, opts...) + c.addNode(gNode, options) + return c +} + // AppendChatTemplate add a ChatTemplate node to the chain. // eg. // @@ -189,11 +202,23 @@ func (c *Chain[I, O]) AppendChatTemplate(node prompt.ChatTemplate, opts ...Graph return c } +// AppendAgenticChatTemplate add a prompt.AgenticChatTemplate node to the chain. +// eg. +// +// chatTemplate, err := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// chain.AppendAgenticChatTemplate(chatTemplate) +func (c *Chain[I, O]) AppendAgenticChatTemplate(node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) *Chain[I, O] { + gNode, options := toAgenticChatTemplateNode(node, opts...) + c.addNode(gNode, options) + return c +} + // AppendToolsNode add a ToolsNode node to the chain. // e.g. // -// toolsNode, err := tools.NewToolNode(ctx, &tools.ToolsNodeConfig{ -// Tools: []tools.Tool{...}, +// toolsNode, err := compose.NewToolNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tools.BaseTool{...}, // }) // // chain.AppendToolsNode(toolsNode) @@ -203,6 +228,20 @@ func (c *Chain[I, O]) AppendToolsNode(node *ToolsNode, opts ...GraphAddNodeOpt) return c } +// AppendAgenticToolsNode add a AgenticToolsNode node to the chain. +// e.g. +// +// toolsNode, err := compose.NewAgenticToolsNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tools.BaseTool{...}, +// }) +// +// chain.AppendAgenticToolsNode(toolsNode) +func (c *Chain[I, O]) AppendAgenticToolsNode(node *AgenticToolsNode, opts ...GraphAddNodeOpt) *Chain[I, O] { + gNode, options := toAgenticToolsNode(node, opts...) + c.addNode(gNode, options) + return c +} + // AppendDocumentTransformer add a DocumentTransformer node to the chain. // e.g. // diff --git a/compose/chain_branch.go b/compose/chain_branch.go index ec3a433af..004dbfac3 100644 --- a/compose/chain_branch.go +++ b/compose/chain_branch.go @@ -20,6 +20,7 @@ import ( "context" "fmt" + "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -146,6 +147,22 @@ func (cb *ChainBranch) AddChatModel(key string, node model.BaseChatModel, opts . return cb.addNode(key, gNode, options) } +// AddAgenticModel adds a agentic.Model node to the branch. +// eg. +// +// model1, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o", +// }) +// model2, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o-mini", +// }) +// cb.AddAgenticModel("agentic_model_key_1", model1) +// cb.AddAgenticModel("agentic_model_key_2", model2) +func (cb *ChainBranch) AddAgenticModel(key string, node agentic.Model, opts ...GraphAddNodeOpt) *ChainBranch { + gNode, options := toAgenticModelNode(node, opts...) + return cb.addNode(key, gNode, options) +} + // AddChatTemplate adds a ChatTemplate node to the branch. // eg. // @@ -167,11 +184,26 @@ func (cb *ChainBranch) AddChatTemplate(key string, node prompt.ChatTemplate, opt return cb.addNode(key, gNode, options) } +// AddAgenticChatTemplate adds a prompt.AgenticChatTemplate node to the branch. +// eg. +// +// chatTemplate, err := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// cb.AddAgenticChatTemplate("chat_template_key_01", chatTemplate) +// +// chatTemplate2, err := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// cb.AddAgenticChatTemplate("chat_template_key_02", chatTemplate2) +func (cb *ChainBranch) AddAgenticChatTemplate(key string, node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) *ChainBranch { + gNode, options := toAgenticChatTemplateNode(node, opts...) + return cb.addNode(key, gNode, options) +} + // AddToolsNode adds a ToolsNode to the branch. // eg. // -// toolsNode, err := tools.NewToolNode(ctx, &tools.ToolsNodeConfig{ -// Tools: []tools.Tool{...}, +// toolsNode, err := compose.NewToolNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tools.BaseTool{...}, // }) // // cb.AddToolsNode("tools_node_key", toolsNode) @@ -180,6 +212,19 @@ func (cb *ChainBranch) AddToolsNode(key string, node *ToolsNode, opts ...GraphAd return cb.addNode(key, gNode, options) } +// AddAgenticToolsNode adds a AgenticToolsNode to the branch. +// eg. +// +// toolsNode, err := compose.NewAgenticToolsNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tools.BaseTool{...}, +// }) +// +// cb.AddAgenticToolsNode("tools_node_key", toolsNode) +func (cb *ChainBranch) AddAgenticToolsNode(key string, node *AgenticToolsNode, opts ...GraphAddNodeOpt) *ChainBranch { + gNode, options := toAgenticToolsNode(node, opts...) + return cb.addNode(key, gNode, options) +} + // AddLambda adds a Lambda node to the branch. // eg. // diff --git a/compose/chain_parallel.go b/compose/chain_parallel.go index 64cdf2db1..128ed4a26 100644 --- a/compose/chain_parallel.go +++ b/compose/chain_parallel.go @@ -19,6 +19,7 @@ package compose import ( "fmt" + "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -70,6 +71,24 @@ func (p *Parallel) AddChatModel(outputKey string, node model.BaseChatModel, opts return p.addNode(outputKey, gNode, options) } +// AddAgenticModel adds a agentic.Model to the parallel. +// eg. +// +// model1, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o", +// }) +// +// model2, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o", +// }) +// +// p.AddAgenticModel("output_key1", model1) +// p.AddAgenticModel("output_key2", model2) +func (p *Parallel) AddAgenticModel(outputKey string, node agentic.Model, opts ...GraphAddNodeOpt) *Parallel { + gNode, options := toAgenticModelNode(node, append(opts, WithOutputKey(outputKey))...) + return p.addNode(outputKey, gNode, options) +} + // AddChatTemplate adds a chat template to the parallel. // eg. // @@ -84,6 +103,17 @@ func (p *Parallel) AddChatTemplate(outputKey string, node prompt.ChatTemplate, o return p.addNode(outputKey, gNode, options) } +// AddAgenticChatTemplate adds a prompt.AgenticChatTemplate to the parallel. +// eg. +// +// chatTemplate01, err := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// p.AddAgenticChatTemplate("output_key01", chatTemplate01) +func (p *Parallel) AddAgenticChatTemplate(outputKey string, node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) *Parallel { + gNode, options := toAgenticChatTemplateNode(node, append(opts, WithOutputKey(outputKey))...) + return p.addNode(outputKey, gNode, options) +} + // AddToolsNode adds a tools node to the parallel. // eg. // @@ -97,6 +127,19 @@ func (p *Parallel) AddToolsNode(outputKey string, node *ToolsNode, opts ...Graph return p.addNode(outputKey, gNode, options) } +// AddAgenticToolsNode adds a tools node to the parallel. +// eg. +// +// toolsNode, err := compose.NewAgenticToolsNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tool.BaseTool{...}, +// }) +// +// p.AddAgenticToolsNode("output_key01", toolsNode) +func (p *Parallel) AddAgenticToolsNode(outputKey string, node *AgenticToolsNode, opts ...GraphAddNodeOpt) *Parallel { + gNode, options := toAgenticToolsNode(node, append(opts, WithOutputKey(outputKey))...) + return p.addNode(outputKey, gNode, options) +} + // AddLambda adds a lambda node to the parallel. // eg. // diff --git a/compose/component_to_graph_node.go b/compose/component_to_graph_node.go index ab4694f1a..e64ce4f19 100644 --- a/compose/component_to_graph_node.go +++ b/compose/component_to_graph_node.go @@ -18,6 +18,7 @@ package compose import ( "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -101,6 +102,17 @@ func toChatModelNode(node model.BaseChatModel, opts ...GraphAddNodeOpt) (*graphN opts...) } +func toAgenticModelNode(node agentic.Model, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { + return toComponentNode( + node, + components.ComponentOfAgenticModel, + node.Generate, + node.Stream, + nil, nil, + opts..., + ) +} + func toChatTemplateNode(node prompt.ChatTemplate, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { return toComponentNode( node, @@ -112,6 +124,16 @@ func toChatTemplateNode(node prompt.ChatTemplate, opts ...GraphAddNodeOpt) (*gra opts...) } +func toAgenticChatTemplateNode(node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { + return toComponentNode( + node, + components.ComponentOfAgenticPrompt, + node.Format, + nil, nil, nil, + opts..., + ) +} + func toDocumentTransformerNode(node document.Transformer, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { return toComponentNode( node, @@ -134,6 +156,17 @@ func toToolsNode(node *ToolsNode, opts ...GraphAddNodeOpt) (*graphNode, *graphAd opts...) } +func toAgenticToolsNode(node *AgenticToolsNode, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { + return toComponentNode( + node, + ComponentOfAgenticToolsNode, + node.Invoke, + node.Stream, + nil, nil, + opts..., + ) +} + func toLambdaNode(node *Lambda, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { info, options := getNodeInfo(opts...) diff --git a/compose/graph.go b/compose/graph.go index 9370665f0..877b8fb42 100644 --- a/compose/graph.go +++ b/compose/graph.go @@ -23,6 +23,7 @@ import ( "reflect" "strings" + "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -352,6 +353,19 @@ func (g *graph) AddChatModelNode(key string, node model.BaseChatModel, opts ...G return g.addNode(key, gNode, options) } +// AddAgenticModelNode add node that implements agentic.Model. +// e.g. +// +// model, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o", +// }) +// +// graph.AddAgenticModelNode("agentic_model_node_key", model) +func (g *graph) AddAgenticModelNode(key string, node agentic.Model, opts ...GraphAddNodeOpt) error { + gNode, options := toAgenticModelNode(node, opts...) + return g.addNode(key, gNode, options) +} + // AddChatTemplateNode add node that implements prompt.ChatTemplate. // e.g. // @@ -366,10 +380,21 @@ func (g *graph) AddChatTemplateNode(key string, node prompt.ChatTemplate, opts . return g.addNode(key, gNode, options) } -// AddToolsNode adds a node that implements tools.ToolsNode. +// AddAgenticChatTemplateNode add node that implements prompt.AgenticChatTemplate. +// e.g. +// +// chatTemplate, err := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// graph.AddAgenticChatTemplateNode("chat_template_node_key", chatTemplate) +func (g *graph) AddAgenticChatTemplateNode(key string, node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) error { + gNode, options := toAgenticChatTemplateNode(node, opts...) + return g.addNode(key, gNode, options) +} + +// AddToolsNode adds a node that implements ToolsNode. // e.g. // -// toolsNode, err := tools.NewToolNode(ctx, &tools.ToolsNodeConfig{}) +// toolsNode, err := compose.NewToolNode(ctx, &compose.ToolsNodeConfig{}) // // graph.AddToolsNode("tools_node_key", toolsNode) func (g *graph) AddToolsNode(key string, node *ToolsNode, opts ...GraphAddNodeOpt) error { @@ -377,6 +402,17 @@ func (g *graph) AddToolsNode(key string, node *ToolsNode, opts ...GraphAddNodeOp return g.addNode(key, gNode, options) } +// AddAgenticToolsNode adds a node that implements AgenticToolsNode. +// e.g. +// +// toolsNode, err := compose.NewAgenticToolsNode(ctx, &compose.ToolsNodeConfig{}) +// +// graph.AddAgenticToolsNode("tools_node_key", toolsNode) +func (g *graph) AddAgenticToolsNode(key string, node *AgenticToolsNode, opts ...GraphAddNodeOpt) error { + gNode, options := toAgenticToolsNode(node, opts...) + return g.addNode(key, gNode, options) +} + // AddDocumentTransformerNode adds a node that implements document.Transformer. // e.g. // diff --git a/compose/tools_node_agentic.go b/compose/tools_node_agentic.go new file mode 100644 index 000000000..38c5c89de --- /dev/null +++ b/compose/tools_node_agentic.go @@ -0,0 +1,125 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + + "github.com/cloudwego/eino/schema" +) + +// NewAgenticToolsNode creates a new AgenticToolsNode. +// e.g. +// +// conf := &ToolsNodeConfig{ +// Tools: []tool.BaseTool{invokableTool1, streamableTool2}, +// } +// toolsNode, err := NewAgenticToolsNode(ctx, conf) +func NewAgenticToolsNode(ctx context.Context, conf *ToolsNodeConfig) (*AgenticToolsNode, error) { + tn, err := NewToolNode(ctx, conf) + if err != nil { + return nil, err + } + return &AgenticToolsNode{inner: tn}, nil +} + +type AgenticToolsNode struct { + inner *ToolsNode +} + +func (a *AgenticToolsNode) Invoke(ctx context.Context, input *schema.AgenticMessage, opts ...ToolsNodeOption) ([]*schema.AgenticMessage, error) { + result, err := a.inner.Invoke(ctx, agenticMessageToToolCallMessage(input), opts...) + if err != nil { + return nil, err + } + return toolMessageToAgenticMessage(result), nil +} + +func (a *AgenticToolsNode) Stream(ctx context.Context, input *schema.AgenticMessage, + opts ...ToolsNodeOption) (*schema.StreamReader[[]*schema.AgenticMessage], error) { + result, err := a.inner.Stream(ctx, agenticMessageToToolCallMessage(input), opts...) + if err != nil { + return nil, err + } + return streamToolMessageToAgenticMessage(result), nil +} + +func agenticMessageToToolCallMessage(input *schema.AgenticMessage) *schema.Message { + var tc []schema.ToolCall + for _, block := range input.ContentBlocks { + if block.Type != schema.ContentBlockTypeFunctionToolCall || block.FunctionToolCall == nil { + continue + } + tc = append(tc, schema.ToolCall{ + ID: block.FunctionToolCall.CallID, + Function: schema.FunctionCall{ + Name: block.FunctionToolCall.Name, + Arguments: block.FunctionToolCall.Arguments, + }, + }) + } + return &schema.Message{ + Role: schema.Assistant, + ToolCalls: tc, + } +} + +func toolMessageToAgenticMessage(input []*schema.Message) []*schema.AgenticMessage { + var results []*schema.ContentBlock + for _, m := range input { + results = append(results, &schema.ContentBlock{ + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: m.ToolCallID, + Name: m.ToolName, + Result: m.Content, + Extra: m.Extra, + }, + }) + } + return []*schema.AgenticMessage{{ + Role: schema.AgenticRoleTypeUser, + ContentBlocks: results, + }} +} + +func streamToolMessageToAgenticMessage(input *schema.StreamReader[[]*schema.Message]) *schema.StreamReader[[]*schema.AgenticMessage] { + return schema.StreamReaderWithConvert(input, func(t []*schema.Message) ([]*schema.AgenticMessage, error) { + var results []*schema.ContentBlock + for i, m := range t { + if m == nil { + continue + } + results = append(results, &schema.ContentBlock{ + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: m.ToolCallID, + Name: m.ToolName, + Result: m.Content, + Extra: m.Extra, + }, + StreamMeta: &schema.StreamMeta{Index: int64(i)}, + }) + } + return []*schema.AgenticMessage{{ + Role: schema.AgenticRoleTypeUser, + ContentBlocks: results, + }}, nil + }) +} + +func (a *AgenticToolsNode) GetType() string { return "" } diff --git a/compose/tools_node_agentic_test.go b/compose/tools_node_agentic_test.go new file mode 100644 index 000000000..dcd3177a9 --- /dev/null +++ b/compose/tools_node_agentic_test.go @@ -0,0 +1,244 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "io" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestAgenticMessageToToolCallMessage(t *testing.T) { + input := &schema.AgenticMessage{ + ContentBlocks: []*schema.ContentBlock{ + { + Type: schema.ContentBlockTypeFunctionToolCall, + FunctionToolCall: &schema.FunctionToolCall{ + CallID: "1", + Name: "name1", + Arguments: "arg1", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolCall, + FunctionToolCall: &schema.FunctionToolCall{ + CallID: "2", + Name: "name2", + Arguments: "arg2", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolCall, + FunctionToolCall: &schema.FunctionToolCall{ + CallID: "3", + Name: "name3", + Arguments: "arg3", + }, + }, + }, + } + ret := agenticMessageToToolCallMessage(input) + assert.Equal(t, schema.Assistant, ret.Role) + assert.Equal(t, []schema.ToolCall{ + { + ID: "1", + Function: schema.FunctionCall{ + Name: "name1", + Arguments: "arg1", + }, + }, + { + ID: "2", + Function: schema.FunctionCall{ + Name: "name2", + Arguments: "arg2", + }, + }, + { + ID: "3", + Function: schema.FunctionCall{ + Name: "name3", + Arguments: "arg3", + }, + }, + }, ret.ToolCalls) +} + +func TestToolMessageToAgenticMessage(t *testing.T) { + input := []*schema.Message{ + { + Role: schema.Tool, + Content: "content1", + ToolCallID: "1", + ToolName: "name1", + }, + { + Role: schema.Tool, + Content: "content2", + ToolCallID: "2", + ToolName: "name2", + }, + { + Role: schema.Tool, + Content: "content3", + ToolCallID: "3", + ToolName: "name3", + }, + } + ret := toolMessageToAgenticMessage(input) + assert.Equal(t, 1, len(ret)) + assert.Equal(t, schema.AgenticRoleTypeUser, ret[0].Role) + assert.Equal(t, []*schema.ContentBlock{ + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "1", + Name: "name1", + Result: "content1", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "2", + Name: "name2", + Result: "content2", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "3", + Name: "name3", + Result: "content3", + }, + }, + }, ret[0].ContentBlocks) +} + +func TestStreamToolMessageToAgenticMessage(t *testing.T) { + input := schema.StreamReaderFromArray([][]*schema.Message{ + { + { + Role: schema.Tool, + Content: "content1-1", + ToolName: "name1", + ToolCallID: "1", + }, + nil, nil, + }, + { + nil, + { + Role: schema.Tool, + Content: "content2-1", + ToolName: "name2", + ToolCallID: "2", + }, + nil, + }, + { + { + Role: schema.Tool, + Content: "content1-2", + ToolName: "name2", + ToolCallID: "2", + }, + nil, nil, + }, + { + nil, nil, + { + Role: schema.Tool, + Content: "content3-1", + ToolName: "name3", + ToolCallID: "3", + }, + }, + { + nil, + { + Role: schema.Tool, + Content: "content2-2", + ToolName: "name2", + ToolCallID: "2", + }, + nil, + }, + { + nil, nil, + { + Role: schema.Tool, + Content: "content3-2", + ToolName: "name3", + ToolCallID: "3", + }, + }, + }) + ret := streamToolMessageToAgenticMessage(input) + var chunks [][]*schema.AgenticMessage + for { + chunk, err := ret.Recv() + if err == io.EOF { + break + } + assert.NoError(t, err) + chunks = append(chunks, chunk) + } + result, err := schema.ConcatAgenticMessagesArray(chunks) + assert.NoError(t, err) + assert.Equal(t, []*schema.AgenticMessage{ + { + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "1", + Name: "name1", + Result: "content1-1content1-2", + Extra: map[string]interface{}{}, + }, + StreamMeta: &schema.StreamMeta{Index: 0}, + }, + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "2", + Name: "name2", + Result: "content2-1content2-2", + Extra: map[string]interface{}{}, + }, + StreamMeta: &schema.StreamMeta{Index: 1}, + }, + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "3", + Name: "name3", + Result: "content3-1content3-2", + Extra: map[string]interface{}{}, + }, + StreamMeta: &schema.StreamMeta{Index: 2}, + }, + }, + }, + }, result) +} diff --git a/compose/types.go b/compose/types.go index 13d925df2..54f8e2be3 100644 --- a/compose/types.go +++ b/compose/types.go @@ -25,13 +25,14 @@ type component = components.Component // built-in component types in graph node. // it represents the type of the most primitive executable object provided by the user. const ( - ComponentOfUnknown component = "Unknown" - ComponentOfGraph component = "Graph" - ComponentOfWorkflow component = "Workflow" - ComponentOfChain component = "Chain" - ComponentOfPassthrough component = "Passthrough" - ComponentOfToolsNode component = "ToolsNode" - ComponentOfLambda component = "Lambda" + ComponentOfUnknown component = "Unknown" + ComponentOfGraph component = "Graph" + ComponentOfWorkflow component = "Workflow" + ComponentOfChain component = "Chain" + ComponentOfPassthrough component = "Passthrough" + ComponentOfToolsNode component = "ToolsNode" + ComponentOfAgenticToolsNode component = "AgenticToolsNode" + ComponentOfLambda component = "Lambda" ) // NodeTriggerMode controls the triggering mode of graph nodes. diff --git a/internal/concat.go b/internal/concat.go index 2681322ab..fd9b8abc5 100644 --- a/internal/concat.go +++ b/internal/concat.go @@ -99,7 +99,7 @@ func ConcatItems[T any](items []T) (T, error) { if typ.Kind() == reflect.Map { cv, err = concatMaps(v) } else { - cv, err = concatSliceValue(v) + cv, err = ConcatSliceValue(v) } if err != nil { @@ -158,7 +158,7 @@ func concatMaps(ms reflect.Value) (reflect.Value, error) { if v.Type().Elem().Kind() == reflect.Map { cv, err = concatMaps(v) } else { - cv, err = concatSliceValue(v) + cv, err = ConcatSliceValue(v) } if err != nil { @@ -171,7 +171,7 @@ func concatMaps(ms reflect.Value) (reflect.Value, error) { return ret, nil } -func concatSliceValue(val reflect.Value) (reflect.Value, error) { +func ConcatSliceValue(val reflect.Value) (reflect.Value, error) { elmType := val.Type().Elem() if val.Len() == 1 { diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 93dd817ca..2139201ec 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -17,9 +17,17 @@ package schema import ( + "context" + "fmt" + "reflect" + "strings" + "github.com/cloudwego/eino/schema/claude" "github.com/cloudwego/eino/schema/gemini" + + "github.com/cloudwego/eino/internal" "github.com/cloudwego/eino/schema/openai" + "github.com/eino-contrib/jsonschema" ) @@ -440,3 +448,1347 @@ func NewContentBlock(block any) *ContentBlock { return nil } } + +// AgenticMessagesTemplate is the interface for messages template. +// It's used to render a template to a list of agentic messages. +// e.g. +// +// chatTemplate := prompt.FromAgenticMessages( +// &schema.AgenticMessage{ +// Role: schema.AgenticRoleTypeSystem, +// ContentBlocks: []*schema.ContentBlock{ +// {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "you are an eino helper"}}, +// }, +// }, +// schema.AgenticMessagesPlaceholder("history", false), // <= this will use the value of "history" in params +// ) +// msgs, err := chatTemplate.Format(ctx, params) +type AgenticMessagesTemplate interface { + Format(ctx context.Context, vs map[string]any, formatType FormatType) ([]*AgenticMessage, error) +} + +var _ AgenticMessagesTemplate = &AgenticMessage{} +var _ AgenticMessagesTemplate = AgenticMessagesPlaceholder("", false) + +type agenticMessagesPlaceholder struct { + key string + optional bool +} + +// AgenticMessagesPlaceholder can render a placeholder to a list of agentic messages in params. +// e.g. +// +// placeholder := AgenticMessagesPlaceholder("history", false) +// params := map[string]any{ +// "history": []*schema.AgenticMessage{ +// &schema.AgenticMessage{ +// Role: schema.AgenticRoleTypeSystem, +// ContentBlocks: []*schema.ContentBlock{ +// {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "you are an eino helper"}}, +// }, +// }, +// }, +// } +// chatTemplate := chatTpl := prompt.FromMessages( +// schema.AgenticMessagesPlaceholder("history", false), // <= this will use the value of "history" in params +// ) +// msgs, err := chatTemplate.Format(ctx, params) +func AgenticMessagesPlaceholder(key string, optional bool) AgenticMessagesTemplate { + return &agenticMessagesPlaceholder{ + key: key, + optional: optional, + } +} + +func (p *agenticMessagesPlaceholder) Format(_ context.Context, vs map[string]any, _ FormatType) ([]*AgenticMessage, error) { + v, ok := vs[p.key] + if !ok { + if p.optional { + return []*AgenticMessage{}, nil + } + + return nil, fmt.Errorf("message placeholder format: %s not found", p.key) + } + + msgs, ok := v.([]*AgenticMessage) + if !ok { + return nil, fmt.Errorf("only agentic messages can be used to format message placeholder, key: %v, actual type: %v", p.key, reflect.TypeOf(v)) + } + + return msgs, nil +} + +// Format returns the agentic messages after rendering by the given formatType. +// It formats only the user input fields (UserInputText, UserInputImage, UserInputAudio, UserInputVideo, UserInputFile). +// e.g. +// +// msg := &schema.AgenticMessage{ +// Role: schema.AgenticRoleTypeUser, +// ContentBlocks: []*schema.ContentBlock{ +// {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "hello {name}"}}, +// }, +// } +// msgs, err := msg.Format(ctx, map[string]any{"name": "eino"}, schema.FString) +// // msgs[0].ContentBlocks[0].UserInputText.Text will be "hello eino" +func (m *AgenticMessage) Format(_ context.Context, vs map[string]any, formatType FormatType) ([]*AgenticMessage, error) { + copied := *m + + if len(m.ContentBlocks) > 0 { + copiedBlocks := make([]*ContentBlock, len(m.ContentBlocks)) + for i, block := range m.ContentBlocks { + if block == nil { + copiedBlocks[i] = nil + continue + } + + copiedBlock := *block + var err error + + switch block.Type { + case ContentBlockTypeUserInputText: + if block.UserInputText != nil { + copiedBlock.UserInputText, err = formatUserInputText(block.UserInputText, vs, formatType) + if err != nil { + return nil, err + } + } + case ContentBlockTypeUserInputImage: + if block.UserInputImage != nil { + copiedBlock.UserInputImage, err = formatUserInputImage(block.UserInputImage, vs, formatType) + if err != nil { + return nil, err + } + } + case ContentBlockTypeUserInputAudio: + if block.UserInputAudio != nil { + copiedBlock.UserInputAudio, err = formatUserInputAudio(block.UserInputAudio, vs, formatType) + if err != nil { + return nil, err + } + } + case ContentBlockTypeUserInputVideo: + if block.UserInputVideo != nil { + copiedBlock.UserInputVideo, err = formatUserInputVideo(block.UserInputVideo, vs, formatType) + if err != nil { + return nil, err + } + } + case ContentBlockTypeUserInputFile: + if block.UserInputFile != nil { + copiedBlock.UserInputFile, err = formatUserInputFile(block.UserInputFile, vs, formatType) + if err != nil { + return nil, err + } + } + } + + copiedBlocks[i] = &copiedBlock + } + copied.ContentBlocks = copiedBlocks + } + + return []*AgenticMessage{&copied}, nil +} + +func formatUserInputText(uit *UserInputText, vs map[string]any, formatType FormatType) (*UserInputText, error) { + text, err := formatContent(uit.Text, vs, formatType) + if err != nil { + return nil, err + } + copied := *uit + copied.Text = text + return &copied, nil +} + +func formatUserInputImage(uii *UserInputImage, vs map[string]any, formatType FormatType) (*UserInputImage, error) { + copied := *uii + if uii.URL != "" { + url, err := formatContent(uii.URL, vs, formatType) + if err != nil { + return nil, err + } + copied.URL = url + } + if uii.Base64Data != "" { + base64data, err := formatContent(uii.Base64Data, vs, formatType) + if err != nil { + return nil, err + } + copied.Base64Data = base64data + } + return &copied, nil +} + +func formatUserInputAudio(uia *UserInputAudio, vs map[string]any, formatType FormatType) (*UserInputAudio, error) { + copied := *uia + if uia.URL != "" { + url, err := formatContent(uia.URL, vs, formatType) + if err != nil { + return nil, err + } + copied.URL = url + } + if uia.Base64Data != "" { + base64data, err := formatContent(uia.Base64Data, vs, formatType) + if err != nil { + return nil, err + } + copied.Base64Data = base64data + } + return &copied, nil +} + +func formatUserInputVideo(uiv *UserInputVideo, vs map[string]any, formatType FormatType) (*UserInputVideo, error) { + copied := *uiv + if uiv.URL != "" { + url, err := formatContent(uiv.URL, vs, formatType) + if err != nil { + return nil, err + } + copied.URL = url + } + if uiv.Base64Data != "" { + base64data, err := formatContent(uiv.Base64Data, vs, formatType) + if err != nil { + return nil, err + } + copied.Base64Data = base64data + } + return &copied, nil +} + +func formatUserInputFile(uif *UserInputFile, vs map[string]any, formatType FormatType) (*UserInputFile, error) { + copied := *uif + if uif.URL != "" { + url, err := formatContent(uif.URL, vs, formatType) + if err != nil { + return nil, err + } + copied.URL = url + } + if uif.Name != "" { + name, err := formatContent(uif.Name, vs, formatType) + if err != nil { + return nil, err + } + copied.Name = name + } + if uif.Base64Data != "" { + base64data, err := formatContent(uif.Base64Data, vs, formatType) + if err != nil { + return nil, err + } + copied.Base64Data = base64data + } + return &copied, nil +} + +func ConcatAgenticMessagesArray(mas [][]*AgenticMessage) ([]*AgenticMessage, error) { + return buildConcatGenericArray[AgenticMessage](ConcatAgenticMessages)(mas) +} + +func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { + var ( + role AgenticRoleType + blocksList [][]*ContentBlock + blocks []*ContentBlock + metas []*AgenticResponseMeta + ) + + if len(msgs) == 1 { + return msgs[0], nil + } + + for idx, msg := range msgs { + if msg == nil { + return nil, fmt.Errorf("message at index %d is nil", idx) + } + + if msg.Role != "" { + if role == "" { + role = msg.Role + } else if role != msg.Role { + return nil, fmt.Errorf("cannot concat messages with different roles: got '%s' and '%s'", role, msg.Role) + } + } + + for _, block := range msg.ContentBlocks { + if block.StreamMeta == nil { + // Non-streaming block + if len(blocksList) > 0 { + // Cannot mix streaming and non-streaming blocks + return nil, fmt.Errorf("found non-streaming block after streaming blocks") + } + // Collect non-streaming block + blocks = append(blocks, block) + } else { + // Streaming block + if len(blocks) > 0 { + // Cannot mix non-streaming and streaming blocks + return nil, fmt.Errorf("found streaming block after non-streaming blocks") + } + // Collect streaming block by index + blocksList = expandSlice(int(block.StreamMeta.Index), blocksList) + blocksList[block.StreamMeta.Index] = append(blocksList[block.StreamMeta.Index], block) + } + } + + if msg.ResponseMeta != nil { + metas = append(metas, msg.ResponseMeta) + } + } + + meta, err := concatAgenticResponseMeta(metas) + if err != nil { + return nil, fmt.Errorf("failed to concat agentic response meta: %w", err) + } + + if len(blocksList) > 0 { + // All blocks are streaming, concat each group by index + blocks = make([]*ContentBlock, len(blocksList)) + for i, bs := range blocksList { + if len(bs) == 0 { + continue + } + b, err := concatAgenticContentBlocks(bs) + if err != nil { + return nil, fmt.Errorf("failed to concat content blocks at index %d: %w", i, err) + } + blocks[i] = b + } + } + + for i := 0; i < len(blocks); i++ { + if blocks[i] == nil { + blocks = append(blocks[:i], blocks[i+1:]...) + } + } + + return &AgenticMessage{ + ResponseMeta: meta, + Role: role, + ContentBlocks: blocks, + }, nil +} + +func concatAgenticResponseMeta(metas []*AgenticResponseMeta) (*AgenticResponseMeta, error) { + if len(metas) == 0 { + return nil, nil + } + ret := &AgenticResponseMeta{ + TokenUsage: &TokenUsage{}, + OpenAIExtension: nil, + ClaudeExtension: nil, + GeminiExtension: nil, + Extension: nil, + } + for _, meta := range metas { + ret.Extension = meta.Extension + ret.OpenAIExtension = meta.OpenAIExtension + ret.ClaudeExtension = meta.ClaudeExtension + ret.GeminiExtension = meta.GeminiExtension + if meta.TokenUsage != nil { + ret.TokenUsage.CompletionTokens += meta.TokenUsage.CompletionTokens + ret.TokenUsage.CompletionTokenDetails.ReasoningTokens += meta.TokenUsage.CompletionTokenDetails.ReasoningTokens + ret.TokenUsage.PromptTokens += meta.TokenUsage.PromptTokens + ret.TokenUsage.PromptTokenDetails.CachedTokens += meta.TokenUsage.PromptTokenDetails.CachedTokens + ret.TokenUsage.TotalTokens += meta.TokenUsage.TotalTokens + } + } + return ret, nil +} + +func concatAgenticContentBlocks(blocks []*ContentBlock) (*ContentBlock, error) { + if len(blocks) == 0 { + return nil, fmt.Errorf("no content blocks to concat") + } + blockType := blocks[0].Type + index := blocks[0].StreamMeta.Index + switch blockType { + case ContentBlockTypeReasoning: + return concatContentBlockHelper(blocks, blockType, "reasoning", + func(b *ContentBlock) *Reasoning { return b.Reasoning }, + concatReasoning, + func(r *Reasoning) *ContentBlock { + return &ContentBlock{Type: blockType, Reasoning: r, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeUserInputText: + return concatContentBlockHelper(blocks, blockType, "user input text", + func(b *ContentBlock) *UserInputText { return b.UserInputText }, + concatUserInputText, + func(t *UserInputText) *ContentBlock { + return &ContentBlock{Type: blockType, UserInputText: t, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeUserInputImage: + return concatContentBlockHelper(blocks, blockType, "user input image", + func(b *ContentBlock) *UserInputImage { return b.UserInputImage }, + concatUserInputImage, + func(i *UserInputImage) *ContentBlock { + return &ContentBlock{Type: blockType, UserInputImage: i, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeUserInputAudio: + return concatContentBlockHelper(blocks, blockType, "user input audio", + func(b *ContentBlock) *UserInputAudio { return b.UserInputAudio }, + concatUserInputAudio, + func(a *UserInputAudio) *ContentBlock { + return &ContentBlock{Type: blockType, UserInputAudio: a, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeUserInputVideo: + return concatContentBlockHelper(blocks, blockType, "user input video", + func(b *ContentBlock) *UserInputVideo { return b.UserInputVideo }, + concatUserInputVideo, + func(v *UserInputVideo) *ContentBlock { + return &ContentBlock{Type: blockType, UserInputVideo: v, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeUserInputFile: + return concatContentBlockHelper(blocks, blockType, "user input file", + func(b *ContentBlock) *UserInputFile { return b.UserInputFile }, + concatUserInputFile, + func(f *UserInputFile) *ContentBlock { + return &ContentBlock{Type: blockType, UserInputFile: f, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeAssistantGenText: + return concatContentBlockHelper(blocks, blockType, "assistant gen text", + func(b *ContentBlock) *AssistantGenText { return b.AssistantGenText }, + concatAssistantGenText, + func(t *AssistantGenText) *ContentBlock { + return &ContentBlock{Type: blockType, AssistantGenText: t, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeAssistantGenImage: + return concatContentBlockHelper(blocks, blockType, "assistant gen image", + func(b *ContentBlock) *AssistantGenImage { return b.AssistantGenImage }, + concatAssistantGenImage, + func(i *AssistantGenImage) *ContentBlock { + return &ContentBlock{Type: blockType, AssistantGenImage: i, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeAssistantGenAudio: + return concatContentBlockHelper(blocks, blockType, "assistant gen audio", + func(b *ContentBlock) *AssistantGenAudio { return b.AssistantGenAudio }, + concatAssistantGenAudio, + func(a *AssistantGenAudio) *ContentBlock { + return &ContentBlock{Type: blockType, AssistantGenAudio: a, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeAssistantGenVideo: + return concatContentBlockHelper(blocks, blockType, "assistant gen video", + func(b *ContentBlock) *AssistantGenVideo { return b.AssistantGenVideo }, + concatAssistantGenVideo, + func(v *AssistantGenVideo) *ContentBlock { + return &ContentBlock{Type: blockType, AssistantGenVideo: v, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeFunctionToolCall: + return concatContentBlockHelper(blocks, blockType, "function tool call", + func(b *ContentBlock) *FunctionToolCall { return b.FunctionToolCall }, + concatFunctionToolCall, + func(c *FunctionToolCall) *ContentBlock { + return &ContentBlock{Type: blockType, FunctionToolCall: c, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeFunctionToolResult: + return concatContentBlockHelper(blocks, blockType, "function tool result", + func(b *ContentBlock) *FunctionToolResult { return b.FunctionToolResult }, + concatFunctionToolResult, + func(r *FunctionToolResult) *ContentBlock { + return &ContentBlock{Type: blockType, FunctionToolResult: r, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeServerToolCall: + return concatContentBlockHelper(blocks, blockType, "server tool call", + func(b *ContentBlock) *ServerToolCall { return b.ServerToolCall }, + concatServerToolCall, + func(c *ServerToolCall) *ContentBlock { + return &ContentBlock{Type: blockType, ServerToolCall: c, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeServerToolResult: + return concatContentBlockHelper(blocks, blockType, "server tool result", + func(b *ContentBlock) *ServerToolResult { return b.ServerToolResult }, + concatServerToolResult, + func(r *ServerToolResult) *ContentBlock { + return &ContentBlock{Type: blockType, ServerToolResult: r, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeMCPToolCall: + return concatContentBlockHelper(blocks, blockType, "MCP tool call", + func(b *ContentBlock) *MCPToolCall { return b.MCPToolCall }, + concatMCPToolCall, + func(c *MCPToolCall) *ContentBlock { + return &ContentBlock{Type: blockType, MCPToolCall: c, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeMCPToolResult: + return concatContentBlockHelper(blocks, blockType, "MCP tool result", + func(b *ContentBlock) *MCPToolResult { return b.MCPToolResult }, + concatMCPToolResult, + func(r *MCPToolResult) *ContentBlock { + return &ContentBlock{Type: blockType, MCPToolResult: r, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeMCPListToolsResult: + return concatContentBlockHelper(blocks, blockType, "MCP list tools", + func(b *ContentBlock) *MCPListToolsResult { return b.MCPListToolsResult }, + concatMCPListToolsResult, + func(r *MCPListToolsResult) *ContentBlock { + return &ContentBlock{Type: blockType, MCPListToolsResult: r, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeMCPToolApprovalRequest: + return concatContentBlockHelper(blocks, blockType, "MCP tool approval request", + func(b *ContentBlock) *MCPToolApprovalRequest { return b.MCPToolApprovalRequest }, + concatMCPToolApprovalRequest, + func(r *MCPToolApprovalRequest) *ContentBlock { + return &ContentBlock{Type: blockType, MCPToolApprovalRequest: r, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeMCPToolApprovalResponse: + return concatContentBlockHelper(blocks, blockType, "MCP tool approval response", + func(b *ContentBlock) *MCPToolApprovalResponse { return b.MCPToolApprovalResponse }, + concatMCPToolApprovalResponse, + func(r *MCPToolApprovalResponse) *ContentBlock { + return &ContentBlock{Type: blockType, MCPToolApprovalResponse: r, StreamMeta: &StreamMeta{Index: index}} + }) + + default: + return nil, fmt.Errorf("unknown content block type: %s", blockType) + } +} + +// concatContentBlockHelper is a generic helper function that reduces code duplication +// for concatenating content blocks of a specific type. +func concatContentBlockHelper[T any]( + blocks []*ContentBlock, + expectedType ContentBlockType, + typeName string, + getter func(*ContentBlock) *T, + concatFunc func([]*T) (*T, error), + constructor func(*T) *ContentBlock, +) (*ContentBlock, error) { + items, err := genericGetTFromContentBlocks(blocks, func(block *ContentBlock) (*T, error) { + if block.Type != expectedType { + return nil, fmt.Errorf("expected %s block, got %s", typeName, block.Type) + } + item := getter(block) + if item == nil { + return nil, fmt.Errorf("%s content is nil", typeName) + } + return item, nil + }) + if err != nil { + return nil, err + } + + concatenated, err := concatFunc(items) + if err != nil { + return nil, err + } + + return constructor(concatenated), nil +} + +func genericGetTFromContentBlocks[T any](blocks []*ContentBlock, checkAndGetter func(block *ContentBlock) (T, error)) ([]T, error) { + ret := make([]T, 0, len(blocks)) + for _, block := range blocks { + t, err := checkAndGetter(block) + if err != nil { + return nil, err + } + ret = append(ret, t) + } + return ret, nil +} + +// Concatenation strategies for different content block types: +// +// String concatenation (incremental streaming): +// - Reasoning: Summary texts are concatenated, grouped by Index if present +// - UserInputText: Text fields are concatenated +// - AssistantGenText: Text fields are concatenated, annotations/citations are merged +// - FunctionToolCall: Arguments (JSON strings) are concatenated incrementally +// - FunctionToolResult: Result strings are concatenated +// - ServerToolCall: Arguments are merged (last non-nil value for any type) +// - ServerToolResult: Results are merged using internal.ConcatItems +// - MCPToolCall: Arguments (JSON strings) are concatenated incrementally +// - MCPToolResult: Result strings are concatenated +// - MCPListToolsResult: Tools arrays are merged +// - MCPToolApprovalRequest: Arguments are concatenated +// +// Take last block (non-streaming content): +// - UserInputImage, UserInputAudio, UserInputVideo, UserInputFile: Return last block +// - AssistantGenImage, AssistantGenAudio, AssistantGenVideo: Return last block +// - MCPToolApprovalResponse: Return last block +// + +func concatReasoning(reasons []*Reasoning) (*Reasoning, error) { + if len(reasons) == 0 { + return nil, fmt.Errorf("no reasoning found") + } + if len(reasons) == 1 { + return reasons[0], nil + } + + ret := &Reasoning{ + Summary: make([]*ReasoningSummary, 0), + EncryptedContent: "", + Extra: make(map[string]any), + } + + // Collect all summaries from all reasons + allSummaries := make([]*ReasoningSummary, 0) + for _, r := range reasons { + if r == nil { + continue + } + allSummaries = append(allSummaries, r.Summary...) + if r.EncryptedContent != "" { + ret.EncryptedContent += r.EncryptedContent + } + for k, v := range r.Extra { + ret.Extra[k] = v + } + } + + // Group by Index and concatenate Text for same Index + // Use dynamic array that expands as needed + var summaryArray []*ReasoningSummary + for _, s := range allSummaries { + idx := s.Index + // Expand array if needed + summaryArray = expandSlice(int(idx), summaryArray) + if summaryArray[idx] == nil { + // Create new entry with a copy of Index + summaryArray[idx] = &ReasoningSummary{ + Index: idx, + Text: s.Text, + } + } else { + // Concatenate text for same index + summaryArray[idx].Text += s.Text + } + } + + // Convert array to slice, filtering out nil entries + ret.Summary = make([]*ReasoningSummary, 0, len(summaryArray)) + for _, summary := range summaryArray { + if summary != nil { + ret.Summary = append(ret.Summary, summary) + } + } + + return ret, nil +} + +func concatUserInputText(texts []*UserInputText) (*UserInputText, error) { + if len(texts) == 0 { + return nil, fmt.Errorf("no user input text found") + } + if len(texts) == 1 { + return texts[0], nil + } + + ret := &UserInputText{ + Text: "", + Extra: make(map[string]any), + } + + for _, t := range texts { + if t == nil { + continue + } + ret.Text += t.Text + for k, v := range t.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatUserInputImage(images []*UserInputImage) (*UserInputImage, error) { + if len(images) == 0 { + return nil, fmt.Errorf("no user input image found") + } + return images[len(images)-1], nil +} + +func concatUserInputAudio(audios []*UserInputAudio) (*UserInputAudio, error) { + if len(audios) == 0 { + return nil, fmt.Errorf("no user input audio found") + } + return audios[len(audios)-1], nil +} + +func concatUserInputVideo(videos []*UserInputVideo) (*UserInputVideo, error) { + if len(videos) == 0 { + return nil, fmt.Errorf("no user input video found") + } + return videos[len(videos)-1], nil +} + +func concatUserInputFile(files []*UserInputFile) (*UserInputFile, error) { + if len(files) == 0 { + return nil, fmt.Errorf("no user input file found") + } + return files[len(files)-1], nil +} + +func concatAssistantGenText(texts []*AssistantGenText) (*AssistantGenText, error) { + if len(texts) == 0 { + return nil, fmt.Errorf("no assistant gen text found") + } + if len(texts) == 1 { + return texts[0], nil + } + + ret := &AssistantGenText{ + Text: "", + OpenAIExtension: nil, + ClaudeExtension: nil, + Extra: make(map[string]any), + } + + for _, t := range texts { + if t == nil { + continue + } + ret.Text += t.Text + if t.OpenAIExtension != nil { + if ret.OpenAIExtension == nil { + ret.OpenAIExtension = &openai.AssistantGenTextExtension{} + } + ret.OpenAIExtension.Annotations = append(ret.OpenAIExtension.Annotations, t.OpenAIExtension.Annotations...) + } + if t.ClaudeExtension != nil { + if ret.ClaudeExtension == nil { + ret.ClaudeExtension = &claude.AssistantGenTextExtension{} + } + ret.ClaudeExtension.Citations = append(ret.ClaudeExtension.Citations, t.ClaudeExtension.Citations...) + } + for k, v := range t.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatAssistantGenImage(images []*AssistantGenImage) (*AssistantGenImage, error) { + if len(images) == 0 { + return nil, fmt.Errorf("no assistant gen image found") + } + return images[len(images)-1], nil +} + +func concatAssistantGenAudio(audios []*AssistantGenAudio) (*AssistantGenAudio, error) { + if len(audios) == 0 { + return nil, fmt.Errorf("no assistant gen audio found") + } + return audios[len(audios)-1], nil +} + +func concatAssistantGenVideo(videos []*AssistantGenVideo) (*AssistantGenVideo, error) { + if len(videos) == 0 { + return nil, fmt.Errorf("no assistant gen video found") + } + return videos[len(videos)-1], nil +} + +func concatFunctionToolCall(calls []*FunctionToolCall) (*FunctionToolCall, error) { + if len(calls) == 0 { + return nil, fmt.Errorf("no function tool call found") + } + if len(calls) == 1 { + return calls[0], nil + } + + // For tool calls, arguments are typically built incrementally during streaming + ret := &FunctionToolCall{ + Extra: make(map[string]any), + } + + for _, c := range calls { + if c == nil { + continue + } + if ret.CallID == "" { + ret.CallID = c.CallID + } + if ret.Name == "" { + ret.Name = c.Name + } + ret.Arguments += c.Arguments + for k, v := range c.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatFunctionToolResult(results []*FunctionToolResult) (*FunctionToolResult, error) { + if len(results) == 0 { + return nil, fmt.Errorf("no function tool result found") + } + if len(results) == 1 { + return results[0], nil + } + + ret := &FunctionToolResult{ + Extra: make(map[string]any), + } + + for _, r := range results { + if r == nil { + continue + } + if ret.CallID == "" { + ret.CallID = r.CallID + } + if ret.Name == "" { + ret.Name = r.Name + } + ret.Result += r.Result + for k, v := range r.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatServerToolCall(calls []*ServerToolCall) (*ServerToolCall, error) { + if len(calls) == 0 { + return nil, fmt.Errorf("no server tool call found") + } + if len(calls) == 1 { + return calls[0], nil + } + + // ServerToolCall Arguments is of type any; merge strategy uses the last non-nil value + ret := &ServerToolCall{ + Extra: make(map[string]any), + } + + for _, c := range calls { + if c == nil { + continue + } + if ret.Name == "" { + ret.Name = c.Name + } + if ret.CallID == "" { + ret.CallID = c.CallID + } + if c.Arguments != nil { + ret.Arguments = c.Arguments + } + for k, v := range c.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatServerToolResult(results []*ServerToolResult) (*ServerToolResult, error) { + if len(results) == 0 { + return nil, fmt.Errorf("no server tool result found") + } + if len(results) == 1 { + return results[0], nil + } + + // ServerToolResult Result is of type any; merge strategy uses the last non-nil value + ret := &ServerToolResult{ + Extra: make(map[string]any), + } + + tZeroResult := reflect.TypeOf(results[0].Result) + data := reflect.MakeSlice(reflect.SliceOf(tZeroResult), 0, 0) + for _, r := range results { + if r == nil { + continue + } + if ret.Name == "" { + ret.Name = r.Name + } + if ret.CallID == "" { + ret.CallID = r.CallID + } + if r.Result != nil { + vResult := reflect.ValueOf(r.Result) + if tZeroResult != vResult.Type() { + return nil, fmt.Errorf("tool result types are different: %v %v", tZeroResult, vResult.Type()) + } + data = reflect.Append(data, vResult) + } + for k, v := range r.Extra { + ret.Extra[k] = v + } + } + + d, err := internal.ConcatSliceValue(data) + if err != nil { + return nil, fmt.Errorf("failed to concat server tool result: %v", err) + } + ret.Result = d + + return ret, nil +} + +func concatMCPToolCall(calls []*MCPToolCall) (*MCPToolCall, error) { + if len(calls) == 0 { + return nil, fmt.Errorf("no mcp tool call found") + } + if len(calls) == 1 { + return calls[0], nil + } + + ret := &MCPToolCall{ + Extra: make(map[string]any), + } + + for _, c := range calls { + if c == nil { + continue + } + if ret.ServerLabel == "" { + ret.ServerLabel = c.ServerLabel + } + if ret.ApprovalRequestID == "" { + ret.ApprovalRequestID = c.ApprovalRequestID + } + if ret.CallID == "" { + ret.CallID = c.CallID + } + if ret.Name == "" { + ret.Name = c.Name + } + ret.Arguments += c.Arguments + for k, v := range c.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatMCPToolResult(results []*MCPToolResult) (*MCPToolResult, error) { + if len(results) == 0 { + return nil, fmt.Errorf("no mcp tool result found") + } + if len(results) == 1 { + return results[0], nil + } + + ret := &MCPToolResult{ + Extra: make(map[string]any), + } + + for _, r := range results { + if r == nil { + continue + } + if ret.CallID == "" { + ret.CallID = r.CallID + } + if ret.Name == "" { + ret.Name = r.Name + } + ret.Result += r.Result + if r.Error != nil { + ret.Error = r.Error // Use the last error + } + for k, v := range r.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatMCPListToolsResult(results []*MCPListToolsResult) (*MCPListToolsResult, error) { + if len(results) == 0 { + return nil, fmt.Errorf("no mcp list tools result found") + } + if len(results) == 1 { + return results[0], nil + } + + ret := &MCPListToolsResult{ + Tools: make([]*MCPListToolsItem, 0), + Extra: make(map[string]any), + } + + for _, r := range results { + if r == nil { + continue + } + if ret.ServerLabel == "" { + ret.ServerLabel = r.ServerLabel + } + ret.Tools = append(ret.Tools, r.Tools...) + if r.Error != "" { + ret.Error = r.Error // Use the last error + } + for k, v := range r.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatMCPToolApprovalRequest(requests []*MCPToolApprovalRequest) (*MCPToolApprovalRequest, error) { + if len(requests) == 0 { + return nil, fmt.Errorf("no mcp tool approval request found") + } + if len(requests) == 1 { + return requests[0], nil + } + + ret := &MCPToolApprovalRequest{ + Extra: make(map[string]any), + } + + for _, r := range requests { + if r == nil { + continue + } + if ret.ID == "" { + ret.ID = r.ID + } + if ret.Name == "" { + ret.Name = r.Name + } + ret.Arguments += r.Arguments + if ret.ServerLabel == "" { + ret.ServerLabel = r.ServerLabel + } + for k, v := range r.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatMCPToolApprovalResponse(responses []*MCPToolApprovalResponse) (*MCPToolApprovalResponse, error) { + if len(responses) == 0 { + return nil, fmt.Errorf("no mcp tool approval response found") + } + if len(responses) == 1 { + return responses[0], nil + } + + return responses[len(responses)-1], nil +} + +func expandSlice[T any](idx int, s []T) []T { + if len(s) > idx { + return s + } + return append(s, make([]T, idx-len(s)+1)...) +} + +func (m *AgenticMessage) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf("role: %s\n", m.Role)) + + if len(m.ContentBlocks) > 0 { + sb.WriteString("content_blocks:\n") + for i, block := range m.ContentBlocks { + if block == nil { + continue + } + sb.WriteString(fmt.Sprintf(" [%d] %s", i, block.String())) + } + } + + if m.ResponseMeta != nil { + sb.WriteString(m.ResponseMeta.String()) + } + + return sb.String() +} + +func (b *ContentBlock) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf("type: %s\n", b.Type)) + + switch b.Type { + case ContentBlockTypeReasoning: + if b.Reasoning != nil { + sb.WriteString(b.Reasoning.String()) + } + case ContentBlockTypeUserInputText: + if b.UserInputText != nil { + sb.WriteString(b.UserInputText.String()) + } + case ContentBlockTypeUserInputImage: + if b.UserInputImage != nil { + sb.WriteString(b.UserInputImage.String()) + } + case ContentBlockTypeUserInputAudio: + if b.UserInputAudio != nil { + sb.WriteString(b.UserInputAudio.String()) + } + case ContentBlockTypeUserInputVideo: + if b.UserInputVideo != nil { + sb.WriteString(b.UserInputVideo.String()) + } + case ContentBlockTypeUserInputFile: + if b.UserInputFile != nil { + sb.WriteString(b.UserInputFile.String()) + } + case ContentBlockTypeAssistantGenText: + if b.AssistantGenText != nil { + sb.WriteString(b.AssistantGenText.String()) + } + case ContentBlockTypeAssistantGenImage: + if b.AssistantGenImage != nil { + sb.WriteString(b.AssistantGenImage.String()) + } + case ContentBlockTypeAssistantGenAudio: + if b.AssistantGenAudio != nil { + sb.WriteString(b.AssistantGenAudio.String()) + } + case ContentBlockTypeAssistantGenVideo: + if b.AssistantGenVideo != nil { + sb.WriteString(b.AssistantGenVideo.String()) + } + case ContentBlockTypeFunctionToolCall: + if b.FunctionToolCall != nil { + sb.WriteString(b.FunctionToolCall.String()) + } + case ContentBlockTypeFunctionToolResult: + if b.FunctionToolResult != nil { + sb.WriteString(b.FunctionToolResult.String()) + } + case ContentBlockTypeServerToolCall: + if b.ServerToolCall != nil { + sb.WriteString(b.ServerToolCall.String()) + } + case ContentBlockTypeServerToolResult: + if b.ServerToolResult != nil { + sb.WriteString(b.ServerToolResult.String()) + } + case ContentBlockTypeMCPToolCall: + if b.MCPToolCall != nil { + sb.WriteString(b.MCPToolCall.String()) + } + case ContentBlockTypeMCPToolResult: + if b.MCPToolResult != nil { + sb.WriteString(b.MCPToolResult.String()) + } + case ContentBlockTypeMCPListToolsResult: + if b.MCPListToolsResult != nil { + sb.WriteString(b.MCPListToolsResult.String()) + } + case ContentBlockTypeMCPToolApprovalRequest: + if b.MCPToolApprovalRequest != nil { + sb.WriteString(b.MCPToolApprovalRequest.String()) + } + case ContentBlockTypeMCPToolApprovalResponse: + if b.MCPToolApprovalResponse != nil { + sb.WriteString(b.MCPToolApprovalResponse.String()) + } + } + + if b.StreamMeta != nil { + sb.WriteString(fmt.Sprintf(" stream_index: %d\n", b.StreamMeta.Index)) + } + + return sb.String() +} + +func (r *Reasoning) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" summary: %d items\n", len(r.Summary))) + for _, s := range r.Summary { + sb.WriteString(fmt.Sprintf(" [%d] %s\n", s.Index, s.Text)) + } + if r.EncryptedContent != "" { + sb.WriteString(fmt.Sprintf(" encrypted_content: %s\n", truncateString(r.EncryptedContent, 50))) + } + return sb.String() +} + +func (u *UserInputText) String() string { + return fmt.Sprintf(" text: %s\n", u.Text) +} + +func (u *UserInputImage) String() string { + return formatMediaString(u.URL, u.Base64Data, u.MIMEType, u.Detail) +} + +func (u *UserInputAudio) String() string { + return formatMediaString(u.URL, u.Base64Data, u.MIMEType, "") +} + +func (u *UserInputVideo) String() string { + return formatMediaString(u.URL, u.Base64Data, u.MIMEType, "") +} + +func (u *UserInputFile) String() string { + sb := &strings.Builder{} + if u.Name != "" { + sb.WriteString(fmt.Sprintf(" name: %s\n", u.Name)) + } + sb.WriteString(formatMediaString(u.URL, u.Base64Data, u.MIMEType, "")) + return sb.String() +} + +func (a *AssistantGenText) String() string { + return fmt.Sprintf(" text: %s\n", a.Text) +} + +func (a *AssistantGenImage) String() string { + return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") +} + +func (a *AssistantGenAudio) String() string { + return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") +} + +func (a *AssistantGenVideo) String() string { + return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") +} + +func (f *FunctionToolCall) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" call_id: %s\n", f.CallID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", f.Name)) + sb.WriteString(fmt.Sprintf(" arguments: %s\n", f.Arguments)) + return sb.String() +} + +func (f *FunctionToolResult) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" call_id: %s\n", f.CallID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", f.Name)) + sb.WriteString(fmt.Sprintf(" result: %s\n", f.Result)) + return sb.String() +} + +func (s *ServerToolCall) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" name: %s\n", s.Name)) + if s.CallID != "" { + sb.WriteString(fmt.Sprintf(" call_id: %s\n", s.CallID)) + } + sb.WriteString(fmt.Sprintf(" arguments: %v\n", s.Arguments)) + return sb.String() +} + +func (s *ServerToolResult) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" name: %s\n", s.Name)) + if s.CallID != "" { + sb.WriteString(fmt.Sprintf(" call_id: %s\n", s.CallID)) + } + sb.WriteString(fmt.Sprintf(" result: %v\n", s.Result)) + return sb.String() +} + +func (m *MCPToolCall) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) + sb.WriteString(fmt.Sprintf(" call_id: %s\n", m.CallID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", m.Name)) + sb.WriteString(fmt.Sprintf(" arguments: %s\n", m.Arguments)) + if m.ApprovalRequestID != "" { + sb.WriteString(fmt.Sprintf(" approval_request_id: %s\n", m.ApprovalRequestID)) + } + return sb.String() +} + +func (m *MCPToolResult) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" call_id: %s\n", m.CallID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", m.Name)) + sb.WriteString(fmt.Sprintf(" result: %s\n", m.Result)) + if m.Error != nil { + sb.WriteString(fmt.Sprintf(" error: [%d] %s\n", *m.Error.Code, m.Error.Message)) + } + return sb.String() +} + +func (m *MCPListToolsResult) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) + sb.WriteString(fmt.Sprintf(" tools: %d items\n", len(m.Tools))) + for _, tool := range m.Tools { + sb.WriteString(fmt.Sprintf(" - %s: %s\n", tool.Name, tool.Description)) + } + if m.Error != "" { + sb.WriteString(fmt.Sprintf(" error: %s\n", m.Error)) + } + return sb.String() +} + +func (m *MCPToolApprovalRequest) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) + sb.WriteString(fmt.Sprintf(" id: %s\n", m.ID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", m.Name)) + sb.WriteString(fmt.Sprintf(" arguments: %s\n", m.Arguments)) + return sb.String() +} + +func (m *MCPToolApprovalResponse) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" approval_request_id: %s\n", m.ApprovalRequestID)) + sb.WriteString(fmt.Sprintf(" approve: %v\n", m.Approve)) + if m.Reason != "" { + sb.WriteString(fmt.Sprintf(" reason: %s\n", m.Reason)) + } + return sb.String() +} + +func (a *AgenticResponseMeta) String() string { + sb := &strings.Builder{} + sb.WriteString("response_meta:\n") + if a.TokenUsage != nil { + sb.WriteString(fmt.Sprintf(" token_usage: prompt=%d, completion=%d, total=%d\n", + a.TokenUsage.PromptTokens, + a.TokenUsage.CompletionTokens, + a.TokenUsage.TotalTokens)) + } + return sb.String() +} + +// truncateString truncates a string to maxLen characters, adding "..." if truncated +func truncateString(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." +} + +// formatMediaString formats URL, Base64Data, MIMEType and Detail for media content +func formatMediaString(url, base64Data string, mimeType string, detail any) string { + sb := &strings.Builder{} + if url != "" { + sb.WriteString(fmt.Sprintf(" url: %s\n", truncateString(url, 100))) + } + if base64Data != "" { + // Only show first few characters of base64 data + sb.WriteString(fmt.Sprintf(" base64_data: %s... (%d bytes)\n", truncateString(base64Data, 20), len(base64Data))) + } + if mimeType != "" { + sb.WriteString(fmt.Sprintf(" mime_type: %s\n", mimeType)) + } + if detail != nil && detail != "" { + sb.WriteString(fmt.Sprintf(" detail: %v\n", detail)) + } + return sb.String() +} diff --git a/schema/agentic_message_test.go b/schema/agentic_message_test.go new file mode 100644 index 000000000..0cafcd9ff --- /dev/null +++ b/schema/agentic_message_test.go @@ -0,0 +1,1381 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package schema + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConcatAgenticMessages(t *testing.T) { + t.Run("single message", func(t *testing.T) { + msg := &AgenticMessage{ + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Hello", + }, + }, + }, + } + + result, err := ConcatAgenticMessages([]*AgenticMessage{msg}) + assert.NoError(t, err) + assert.Equal(t, msg, result) + }) + + t.Run("nil message in stream", func(t *testing.T) { + msgs := []*AgenticMessage{ + {Role: AgenticRoleTypeAssistant}, + nil, + {Role: AgenticRoleTypeAssistant}, + } + + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "message at index 1 is nil") + }) + + t.Run("different roles", func(t *testing.T) { + msgs := []*AgenticMessage{ + {Role: AgenticRoleTypeUser}, + {Role: AgenticRoleTypeAssistant}, + } + + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat messages with different roles") + }) + + t.Run("concat text blocks", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Hello ", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "World!", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Equal(t, AgenticRoleTypeAssistant, result.Role) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "Hello World!", result.ContentBlocks[0].AssistantGenText.Text) + }) + + t.Run("concat reasoning with nil index", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Summary: []*ReasoningSummary{ + {Index: 0, Text: "First "}, + }, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Summary: []*ReasoningSummary{ + {Index: 0, Text: "Second"}, + }, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Len(t, result.ContentBlocks[0].Reasoning.Summary, 1) + assert.Equal(t, "First Second", result.ContentBlocks[0].Reasoning.Summary[0].Text) + assert.Equal(t, int64(0), result.ContentBlocks[0].Reasoning.Summary[0].Index) + }) + + t.Run("concat reasoning with index", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Summary: []*ReasoningSummary{ + {Index: 0, Text: "Part1-"}, + {Index: 1, Text: "Part2-"}, + }, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Summary: []*ReasoningSummary{ + {Index: 0, Text: "Part3"}, + {Index: 1, Text: "Part4"}, + }, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Len(t, result.ContentBlocks[0].Reasoning.Summary, 2) + assert.Equal(t, "Part1-Part3", result.ContentBlocks[0].Reasoning.Summary[0].Text) + assert.Equal(t, "Part2-Part4", result.ContentBlocks[0].Reasoning.Summary[1].Text) + }) + + t.Run("concat user input text", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{ + Text: "Hello ", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{ + Text: "World!", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "Hello World!", result.ContentBlocks[0].UserInputText.Text) + }) + + t.Run("concat user input image", func(t *testing.T) { + url1 := "https://example.com/image1.jpg" + url2 := "https://example.com/image2.jpg" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputImage, + UserInputImage: &UserInputImage{ + URL: url1, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputImage, + UserInputImage: &UserInputImage{ + URL: url2, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + // Should take the last image + assert.Equal(t, url2, result.ContentBlocks[0].UserInputImage.URL) + }) + + t.Run("concat user input audio", func(t *testing.T) { + url1 := "https://example.com/audio1.mp3" + url2 := "https://example.com/audio2.mp3" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: url1, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: url2, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + // Should take the last audio + assert.Equal(t, url2, result.ContentBlocks[0].UserInputAudio.URL) + }) + + t.Run("concat user input video", func(t *testing.T) { + url1 := "https://example.com/video1.mp4" + url2 := "https://example.com/video2.mp4" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: url1, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: url2, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + // Should take the last video + assert.Equal(t, url2, result.ContentBlocks[0].UserInputVideo.URL) + }) + + t.Run("concat assistant gen text", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Generated ", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Text", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "Generated Text", result.ContentBlocks[0].AssistantGenText.Text) + }) + + t.Run("concat assistant gen image", func(t *testing.T) { + url1 := "https://example.com/gen_image1.jpg" + url2 := "https://example.com/gen_image2.jpg" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + URL: url1, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + URL: url2, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + // Should take the last image + assert.Equal(t, url2, result.ContentBlocks[0].AssistantGenImage.URL) + }) + + t.Run("concat assistant gen audio", func(t *testing.T) { + url1 := "https://example.com/gen_audio1.mp3" + url2 := "https://example.com/gen_audio2.mp3" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenAudio, + AssistantGenAudio: &AssistantGenAudio{ + URL: url1, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenAudio, + AssistantGenAudio: &AssistantGenAudio{ + URL: url2, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + // Should take the last audio + assert.Equal(t, url2, result.ContentBlocks[0].AssistantGenAudio.URL) + }) + + t.Run("concat assistant gen video", func(t *testing.T) { + url1 := "https://example.com/gen_video1.mp4" + url2 := "https://example.com/gen_video2.mp4" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenVideo, + AssistantGenVideo: &AssistantGenVideo{ + URL: url1, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenVideo, + AssistantGenVideo: &AssistantGenVideo{ + URL: url2, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + // Should take the last video + assert.Equal(t, url2, result.ContentBlocks[0].AssistantGenVideo.URL) + }) + + t.Run("concat function tool call", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + CallID: "call_123", + Name: "get_weather", + Arguments: `{"location`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + Arguments: `":"NYC"}`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "call_123", result.ContentBlocks[0].FunctionToolCall.CallID) + assert.Equal(t, "get_weather", result.ContentBlocks[0].FunctionToolCall.Name) + assert.Equal(t, `{"location":"NYC"}`, result.ContentBlocks[0].FunctionToolCall.Arguments) + }) + + t.Run("concat function tool result", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeFunctionToolResult, + FunctionToolResult: &FunctionToolResult{ + CallID: "call_123", + Name: "get_weather", + Result: `{"temp`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeFunctionToolResult, + FunctionToolResult: &FunctionToolResult{ + Result: `":72}`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "call_123", result.ContentBlocks[0].FunctionToolResult.CallID) + assert.Equal(t, "get_weather", result.ContentBlocks[0].FunctionToolResult.Name) + assert.Equal(t, `{"temp":72}`, result.ContentBlocks[0].FunctionToolResult.Result) + }) + + t.Run("concat server tool call", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeServerToolCall, + ServerToolCall: &ServerToolCall{ + CallID: "server_call_1", + Name: "server_func", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeServerToolCall, + ServerToolCall: &ServerToolCall{ + Arguments: map[string]any{"key": "value"}, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "server_call_1", result.ContentBlocks[0].ServerToolCall.CallID) + assert.Equal(t, "server_func", result.ContentBlocks[0].ServerToolCall.Name) + assert.NotNil(t, result.ContentBlocks[0].ServerToolCall.Arguments) + }) + + t.Run("concat server tool result", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeServerToolResult, + ServerToolResult: &ServerToolResult{ + CallID: "server_call_1", + Name: "server_func", + Result: "result1", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeServerToolResult, + ServerToolResult: &ServerToolResult{ + Result: "result2", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "server_call_1", result.ContentBlocks[0].ServerToolResult.CallID) + assert.Equal(t, "server_func", result.ContentBlocks[0].ServerToolResult.Name) + }) + + t.Run("concat mcp tool call", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + ServerLabel: "mcp-server", + CallID: "mcp_call_1", + Name: "mcp_func", + Arguments: `{"arg`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + Arguments: `":123}`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPToolCall.ServerLabel) + assert.Equal(t, "mcp_call_1", result.ContentBlocks[0].MCPToolCall.CallID) + assert.Equal(t, "mcp_func", result.ContentBlocks[0].MCPToolCall.Name) + assert.Equal(t, `{"arg":123}`, result.ContentBlocks[0].MCPToolCall.Arguments) + }) + + t.Run("concat mcp tool result", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolResult, + MCPToolResult: &MCPToolResult{ + CallID: "mcp_call_1", + Name: "mcp_func", + Result: `{"res`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolResult, + MCPToolResult: &MCPToolResult{ + Result: `ult":true}`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp_call_1", result.ContentBlocks[0].MCPToolResult.CallID) + assert.Equal(t, "mcp_func", result.ContentBlocks[0].MCPToolResult.Name) + assert.Equal(t, `{"result":true}`, result.ContentBlocks[0].MCPToolResult.Result) + }) + + t.Run("concat mcp list tools", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPListToolsResult, + MCPListToolsResult: &MCPListToolsResult{ + ServerLabel: "mcp-server", + Tools: []*MCPListToolsItem{ + {Name: "tool1"}, + }, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPListToolsResult, + MCPListToolsResult: &MCPListToolsResult{ + Tools: []*MCPListToolsItem{ + {Name: "tool2"}, + }, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPListToolsResult.ServerLabel) + assert.Len(t, result.ContentBlocks[0].MCPListToolsResult.Tools, 2) + }) + + t.Run("concat mcp tool approval request", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolApprovalRequest, + MCPToolApprovalRequest: &MCPToolApprovalRequest{ + ID: "approval_1", + Name: "approval_func", + ServerLabel: "mcp-server", + Arguments: `{"request`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolApprovalRequest, + MCPToolApprovalRequest: &MCPToolApprovalRequest{ + Arguments: `":1}`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "approval_1", result.ContentBlocks[0].MCPToolApprovalRequest.ID) + assert.Equal(t, "approval_func", result.ContentBlocks[0].MCPToolApprovalRequest.Name) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPToolApprovalRequest.ServerLabel) + assert.Equal(t, `{"request":1}`, result.ContentBlocks[0].MCPToolApprovalRequest.Arguments) + }) + + t.Run("concat mcp tool approval response", func(t *testing.T) { + response1 := &MCPToolApprovalResponse{ + ApprovalRequestID: "approval_1", + Approve: false, + } + response2 := &MCPToolApprovalResponse{ + ApprovalRequestID: "approval_1", + Approve: true, + } + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolApprovalResponse, + MCPToolApprovalResponse: response1, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolApprovalResponse, + MCPToolApprovalResponse: response2, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + // Should take the last response + assert.Equal(t, response2, result.ContentBlocks[0].MCPToolApprovalResponse) + }) + + t.Run("concat response meta", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ResponseMeta: &AgenticResponseMeta{ + TokenUsage: &TokenUsage{ + PromptTokens: 10, + CompletionTokens: 5, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ResponseMeta: &AgenticResponseMeta{ + TokenUsage: &TokenUsage{ + PromptTokens: 10, + CompletionTokens: 15, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.NotNil(t, result.ResponseMeta) + assert.Equal(t, 20, result.ResponseMeta.TokenUsage.CompletionTokens) + assert.Equal(t, 20, result.ResponseMeta.TokenUsage.PromptTokens) + }) + + t.Run("mixed streaming and non-streaming blocks error", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Hello", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "World", + }, + // No StreamMeta - non-streaming + }, + }, + }, + } + + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "found non-streaming block after streaming blocks") + }) + + t.Run("concat MCP tool call", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + ServerLabel: "mcp-server", + CallID: "call_456", + Name: "list_files", + Arguments: `{"path`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + Arguments: `":"/tmp"}`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPToolCall.ServerLabel) + assert.Equal(t, "call_456", result.ContentBlocks[0].MCPToolCall.CallID) + assert.Equal(t, `{"path":"/tmp"}`, result.ContentBlocks[0].MCPToolCall.Arguments) + }) + + t.Run("concat user input text", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{ + Text: "What is ", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{ + Text: "the weather?", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "What is the weather?", result.ContentBlocks[0].UserInputText.Text) + }) + + t.Run("multiple stream indexes - sparse indexes", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Index0-", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Index2-", + }, + StreamMeta: &StreamMeta{Index: 2}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Part2", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Part2", + }, + StreamMeta: &StreamMeta{Index: 2}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 2) + assert.Equal(t, "Index0-Part2", result.ContentBlocks[0].AssistantGenText.Text) + assert.Equal(t, "Index2-Part2", result.ContentBlocks[1].AssistantGenText.Text) + }) + + t.Run("multiple stream indexes - mixed content types", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Text ", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + CallID: "call_1", + Name: "func1", + Arguments: `{"a`, + }, + StreamMeta: &StreamMeta{Index: 1}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Content", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + Arguments: `":1}`, + }, + StreamMeta: &StreamMeta{Index: 1}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 2) + assert.Equal(t, "Text Content", result.ContentBlocks[0].AssistantGenText.Text) + assert.Equal(t, "call_1", result.ContentBlocks[1].FunctionToolCall.CallID) + assert.Equal(t, "func1", result.ContentBlocks[1].FunctionToolCall.Name) + assert.Equal(t, `{"a":1}`, result.ContentBlocks[1].FunctionToolCall.Arguments) + }) + + t.Run("multiple stream indexes - three indexes", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "A", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "B", + }, + StreamMeta: &StreamMeta{Index: 1}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "C", + }, + StreamMeta: &StreamMeta{Index: 2}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "1", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "2", + }, + StreamMeta: &StreamMeta{Index: 1}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "3", + }, + StreamMeta: &StreamMeta{Index: 2}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 3) + assert.Equal(t, "A1", result.ContentBlocks[0].AssistantGenText.Text) + assert.Equal(t, "B2", result.ContentBlocks[1].AssistantGenText.Text) + assert.Equal(t, "C3", result.ContentBlocks[2].AssistantGenText.Text) + }) +} + +func TestAgenticMessageFormat(t *testing.T) { + m := &AgenticMessage{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{Text: "{a}"}, + }, + { + Type: ContentBlockTypeUserInputImage, + UserInputImage: &UserInputImage{ + URL: "{b}", + Base64Data: "{c}", + }, + }, + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: "{d}", + Base64Data: "{e}", + }, + }, + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: "{f}", + Base64Data: "{g}", + }, + }, + { + Type: ContentBlockTypeUserInputFile, + UserInputFile: &UserInputFile{ + URL: "{h}", + Base64Data: "{i}", + }, + }, + }, + } + + result, err := m.Format(context.Background(), map[string]any{ + "a": "1", "b": "2", "c": "3", "d": "4", "e": "5", "f": "6", "g": "7", "h": "8", "i": "9", + }, FString) + assert.NoError(t, err) + assert.Equal(t, []*AgenticMessage{{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{Text: "1"}, + }, + { + Type: ContentBlockTypeUserInputImage, + UserInputImage: &UserInputImage{ + URL: "2", + Base64Data: "3", + }, + }, + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: "4", + Base64Data: "5", + }, + }, + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: "6", + Base64Data: "7", + }, + }, + { + Type: ContentBlockTypeUserInputFile, + UserInputFile: &UserInputFile{ + URL: "8", + Base64Data: "9", + }, + }, + }, + }}, result) +} + +func TestAgenticPlaceholderFormat(t *testing.T) { + ctx := context.Background() + ph := AgenticMessagesPlaceholder("a", false) + + result, err := ph.Format(ctx, map[string]any{ + "a": []*AgenticMessage{{Role: AgenticRoleTypeUser}, {Role: AgenticRoleTypeUser}}, + }, FString) + assert.NoError(t, err) + assert.Equal(t, 2, len(result)) + + ph = AgenticMessagesPlaceholder("a", true) + + result, err = ph.Format(ctx, map[string]any{}, FString) + assert.NoError(t, err) + assert.Equal(t, 0, len(result)) +} + +func ptrOf[T any](v T) *T { + return &v +} + +func TestAgenticMessageString(t *testing.T) { + longBase64 := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + + msg := &AgenticMessage{ + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{ + Text: "What's the weather like in New York City today?", + }, + }, + { + Type: ContentBlockTypeUserInputImage, + UserInputImage: &UserInputImage{ + URL: "https://example.com/weather-map.jpg", + Base64Data: longBase64, + MIMEType: "image/jpeg", + Detail: ImageURLDetailHigh, + }, + }, + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Summary: []*ReasoningSummary{ + {Index: 0, Text: "First, I need to identify the location (New York City) from the user's query."}, + {Index: 1, Text: "Then, I should call the weather API to get current conditions."}, + {Index: 2, Text: "Finally, I'll format the response in a user-friendly way with temperature and conditions."}, + }, + EncryptedContent: "encrypted_reasoning_content_that_is_very_long_and_will_be_truncated_for_display", + }, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "I'll check the current weather in New York City for you.", + }, + }, + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + CallID: "call_weather_123", + Name: "get_current_weather", + Arguments: `{"location":"New York City","unit":"fahrenheit"}`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + { + Type: ContentBlockTypeFunctionToolResult, + FunctionToolResult: &FunctionToolResult{ + CallID: "call_weather_123", + Name: "get_current_weather", + Result: `{"temperature":72,"condition":"sunny","humidity":45,"wind_speed":8}`, + }, + }, + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + ServerLabel: "weather-mcp-server", + CallID: "mcp_forecast_456", + Name: "get_7day_forecast", + Arguments: `{"city":"New York","days":7}`, + ApprovalRequestID: "approval_req_789", + }, + }, + { + Type: ContentBlockTypeMCPToolResult, + MCPToolResult: &MCPToolResult{ + CallID: "mcp_forecast_456", + Name: "get_7day_forecast", + Result: `{"status":"partial","days_available":3}`, + Error: &MCPToolCallError{ + Code: ptrOf[int64](503), + Message: "Service temporarily unavailable for full 7-day forecast", + }, + }, + }, + { + Type: ContentBlockTypeMCPListToolsResult, + MCPListToolsResult: &MCPListToolsResult{ + ServerLabel: "weather-mcp-server", + Tools: []*MCPListToolsItem{ + {Name: "get_current_weather", Description: "Get current weather conditions for a location"}, + {Name: "get_7day_forecast", Description: "Get 7-day weather forecast"}, + {Name: "get_weather_alerts", Description: "Get active weather alerts and warnings"}, + }, + }, + }, + }, + ResponseMeta: &AgenticResponseMeta{ + TokenUsage: &TokenUsage{ + PromptTokens: 250, + CompletionTokens: 180, + TotalTokens: 430, + }, + }, + } + + // Print the formatted output + output := msg.String() + + assert.Equal(t, `role: assistant +content_blocks: + [0] type: user_input_text + text: What's the weather like in New York City today? + [1] type: user_input_image + url: https://example.com/weather-map.jpg + base64_data: iVBORw0KGgoAAAANSUhE...... (96 bytes) + mime_type: image/jpeg + detail: high + [2] type: reasoning + summary: 3 items + [0] First, I need to identify the location (New York City) from the user's query. + [1] Then, I should call the weather API to get current conditions. + [2] Finally, I'll format the response in a user-friendly way with temperature and conditions. + encrypted_content: encrypted_reasoning_content_that_is_very_long_and_... + [3] type: assistant_gen_text + text: I'll check the current weather in New York City for you. + [4] type: function_tool_call + call_id: call_weather_123 + name: get_current_weather + arguments: {"location":"New York City","unit":"fahrenheit"} + stream_index: 0 + [5] type: function_tool_result + call_id: call_weather_123 + name: get_current_weather + result: {"temperature":72,"condition":"sunny","humidity":45,"wind_speed":8} + [6] type: mcp_tool_call + server_label: weather-mcp-server + call_id: mcp_forecast_456 + name: get_7day_forecast + arguments: {"city":"New York","days":7} + approval_request_id: approval_req_789 + [7] type: mcp_tool_result + call_id: mcp_forecast_456 + name: get_7day_forecast + result: {"status":"partial","days_available":3} + error: [503] Service temporarily unavailable for full 7-day forecast + [8] type: mcp_list_tools_result + server_label: weather-mcp-server + tools: 3 items + - get_current_weather: Get current weather conditions for a location + - get_7day_forecast: Get 7-day weather forecast + - get_weather_alerts: Get active weather alerts and warnings +response_meta: + token_usage: prompt=250, completion=180, total=430 +`, output) +} diff --git a/schema/message.go b/schema/message.go index fefb2079e..02cdefaad 100644 --- a/schema/message.go +++ b/schema/message.go @@ -40,47 +40,56 @@ func init() { internal.RegisterStreamChunkConcatFunc(ConcatMessages) internal.RegisterStreamChunkConcatFunc(ConcatMessageArray) + internal.RegisterStreamChunkConcatFunc(ConcatAgenticMessages) + internal.RegisterStreamChunkConcatFunc(ConcatAgenticMessagesArray) + internal.RegisterStreamChunkConcatFunc(ConcatToolResults) } -// ConcatMessageArray merges aligned slices of messages into a single slice, -// concatenating messages at the same index across the input arrays. -func ConcatMessageArray(mas [][]*Message) ([]*Message, error) { - arrayLen := len(mas[0]) +func buildConcatGenericArray[T any](f func([]*T) (*T, error)) func([][]*T) ([]*T, error) { + return func(mas [][]*T) ([]*T, error) { + arrayLen := len(mas[0]) - ret := make([]*Message, arrayLen) - slicesToConcat := make([][]*Message, arrayLen) + ret := make([]*T, arrayLen) + slicesToConcat := make([][]*T, arrayLen) - for _, ma := range mas { - if len(ma) != arrayLen { - return nil, fmt.Errorf("unexpected array length. "+ - "Got %d, expected %d", len(ma), arrayLen) - } + for _, ma := range mas { + if len(ma) != arrayLen { + return nil, fmt.Errorf("unexpected array length. "+ + "Got %d, expected %d", len(ma), arrayLen) + } - for i := 0; i < arrayLen; i++ { - m := ma[i] - if m != nil { - slicesToConcat[i] = append(slicesToConcat[i], m) + for i := 0; i < arrayLen; i++ { + m := ma[i] + if m != nil { + slicesToConcat[i] = append(slicesToConcat[i], m) + } } } - } - for i, slice := range slicesToConcat { - if len(slice) == 0 { - ret[i] = nil - } else if len(slice) == 1 { - ret[i] = slice[0] - } else { - cm, err := ConcatMessages(slice) - if err != nil { - return nil, err - } + for i, slice := range slicesToConcat { + if len(slice) == 0 { + ret[i] = nil + } else if len(slice) == 1 { + ret[i] = slice[0] + } else { + cm, err := f(slice) + if err != nil { + return nil, err + } - ret[i] = cm + ret[i] = cm + } } + + return ret, nil } +} - return ret, nil +// ConcatMessageArray merges aligned slices of messages into a single slice, +// concatenating messages at the same index across the input arrays. +func ConcatMessageArray(mas [][]*Message) ([]*Message, error) { + return buildConcatGenericArray[Message](ConcatMessages)(mas) } // FormatType used by MessageTemplate.Format @@ -721,7 +730,7 @@ var _ MessagesTemplate = MessagesPlaceholder("", false) // e.g. // // chatTemplate := prompt.FromMessages( -// schema.SystemMessage("you are eino helper"), +// schema.SystemMessage("you are an eino helper"), // schema.MessagesPlaceholder("history", false), // <= this will use the value of "history" in params // ) // msgs, err := chatTemplate.Format(ctx, params) @@ -739,7 +748,7 @@ type messagesPlaceholder struct { // // placeholder := MessagesPlaceholder("history", false) // params := map[string]any{ -// "history": []*schema.Message{{Role: "user", Content: "what is eino?"}, {Role: "assistant", Content: "eino is a great freamwork to build llm apps"}}, +// "history": []*schema.Message{{Role: "user", Content: "what is eino?"}, {Role: "assistant", Content: "eino is a great framework to build llm apps"}}, // "query": "how to use eino?", // } // chatTemplate := chatTpl := prompt.FromMessages( From 2eaf7afa7582b4b3e63a8cb29bb6a8566ba182ca Mon Sep 17 00:00:00 2001 From: mrh997 Date: Tue, 6 Jan 2026 16:48:56 +0800 Subject: [PATCH 16/59] fix: concat agentic messages (#604) --- components/agentic/callback_extra.go | 5 +- components/agentic/option.go | 7 +- components/agentic/option_test.go | 2 +- components/model/callback_extra.go | 6 +- components/prompt/callback_extra.go | 2 + components/types.go | 2 + compose/tools_node_agentic.go | 7 +- compose/tools_node_agentic_test.go | 37 +- schema/agentic_message.go | 1052 ++++++++++------- schema/agentic_message_test.go | 552 ++++++--- schema/claude/consts.go | 1 + .../claude/{content_block.go => extension.go} | 70 +- schema/claude/extension_test.go | 190 +++ schema/claude/response_meta.go | 22 - .../gemini/{response_meta.go => extension.go} | 43 +- schema/gemini/extension_test.go | 79 ++ schema/message.go | 4 +- schema/openai/consts.go | 69 ++ schema/openai/content_block.go | 75 -- schema/openai/extension.go | 206 ++++ schema/openai/extension_test.go | 193 +++ schema/openai/response_meta.go | 40 - schema/tool.go | 21 + 23 files changed, 1942 insertions(+), 743 deletions(-) rename schema/claude/{content_block.go => extension.go} (51%) create mode 100644 schema/claude/extension_test.go delete mode 100644 schema/claude/response_meta.go rename schema/gemini/{response_meta.go => extension.go} (76%) create mode 100644 schema/gemini/extension_test.go delete mode 100644 schema/openai/content_block.go create mode 100644 schema/openai/extension.go create mode 100644 schema/openai/extension_test.go delete mode 100644 schema/openai/response_meta.go diff --git a/components/agentic/callback_extra.go b/components/agentic/callback_extra.go index 389408d33..2c5a656fa 100644 --- a/components/agentic/callback_extra.go +++ b/components/agentic/callback_extra.go @@ -14,6 +14,7 @@ * limitations under the License. */ +// Package agentic defines callback payloads and configuration types for agentic models. package agentic import ( @@ -26,9 +27,9 @@ type Config struct { // Model is the model name. Model string // Temperature is the temperature, which controls the randomness of the model. - Temperature float32 + Temperature float64 // TopP is the top p, which controls the diversity of the model. - TopP float32 + TopP float64 } // CallbackInput is the input for the model callback. diff --git a/components/agentic/option.go b/components/agentic/option.go index ac117ddb4..d8873442a 100644 --- a/components/agentic/option.go +++ b/components/agentic/option.go @@ -30,8 +30,10 @@ type Options struct { TopP *float64 // Tools is a list of tools the model may call. Tools []*schema.ToolInfo - // ToolChoice controls which tool is called by the model. + // ToolChoice controls how the model call the tools. ToolChoice *schema.ToolChoice + // AllowedTools is a list of allowed tools the model may call. + AllowedTools []*schema.AllowedTool } // Option is the call option for ChatModel component. @@ -81,10 +83,11 @@ func WithTools(tools []*schema.ToolInfo) Option { } // WithToolChoice is the option to set tool choice for the model. -func WithToolChoice(toolChoice schema.ToolChoice) Option { +func WithToolChoice(toolChoice schema.ToolChoice, allowedTools ...*schema.AllowedTool) Option { return Option{ apply: func(opts *Options) { opts.ToolChoice = &toolChoice + opts.AllowedTools = allowedTools }, } } diff --git a/components/agentic/option_test.go b/components/agentic/option_test.go index d349f35ac..2c5bac652 100644 --- a/components/agentic/option_test.go +++ b/components/agentic/option_test.go @@ -29,7 +29,7 @@ func TestCommon(t *testing.T) { WithTools([]*schema.ToolInfo{{Name: "test"}}), WithModel("test"), WithTemperature(0.1), - WithToolChoice(schema.ToolChoiceAllowed), + WithToolChoice(schema.ToolChoiceAllowed, []*schema.AllowedTool{{FunctionToolName: "test"}}...), WithTopP(0.1), ) assert.Len(t, o.Tools, 1) diff --git a/components/model/callback_extra.go b/components/model/callback_extra.go index 8591c4373..2767e2e5e 100644 --- a/components/model/callback_extra.go +++ b/components/model/callback_extra.go @@ -29,17 +29,17 @@ type TokenUsage struct { PromptTokenDetails PromptTokenDetails // CompletionTokens is the number of completion tokens. CompletionTokens int + // CompletionTokensDetails is a breakdown of the completion tokens. + CompletionTokensDetails CompletionTokensDetails // TotalTokens is the total number of tokens. TotalTokens int - // CompletionTokensDetails is breakdown of completion tokens. - CompletionTokensDetails CompletionTokensDetails `json:"completion_token_details"` } type CompletionTokensDetails struct { // ReasoningTokens tokens generated by the model for reasoning. // This is currently supported by OpenAI, Gemini, ARK and Qwen chat models. // For other models, this field will be 0. - ReasoningTokens int `json:"reasoning_tokens,omitempty"` + ReasoningTokens int } // PromptTokenDetails provides a breakdown of prompt token usage. diff --git a/components/prompt/callback_extra.go b/components/prompt/callback_extra.go index ff5c3a8ff..3be780543 100644 --- a/components/prompt/callback_extra.go +++ b/components/prompt/callback_extra.go @@ -33,6 +33,7 @@ type AgenticCallbackOutput struct { Extra map[string]any } +// ConvAgenticCallbackInput converts the callback input to the agentic callback input. func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput { switch t := src.(type) { case *AgenticCallbackInput: @@ -46,6 +47,7 @@ func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput } } +// ConvAgenticCallbackOutput converts the callback output to the agentic callback output. func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOutput { switch t := src.(type) { case *AgenticCallbackOutput: diff --git a/components/types.go b/components/types.go index 2ba088e93..2b0ad8f0e 100644 --- a/components/types.go +++ b/components/types.go @@ -66,9 +66,11 @@ type Component string const ( // ComponentOfPrompt identifies chat template components. ComponentOfPrompt Component = "ChatTemplate" + // ComponentOfAgenticPrompt identifies agentic template components. ComponentOfAgenticPrompt Component = "AgenticChatTemplate" // ComponentOfChatModel identifies chat model components. ComponentOfChatModel Component = "ChatModel" + // ComponentOfAgenticModel identifies agentic model components. ComponentOfAgenticModel Component = "AgenticModel" // ComponentOfEmbedding identifies embedding components. ComponentOfEmbedding Component = "Embedding" diff --git a/compose/tools_node_agentic.go b/compose/tools_node_agentic.go index 38c5c89de..96aef7b72 100644 --- a/compose/tools_node_agentic.go +++ b/compose/tools_node_agentic.go @@ -70,6 +70,7 @@ func agenticMessageToToolCallMessage(input *schema.AgenticMessage) *schema.Messa Name: block.FunctionToolCall.Name, Arguments: block.FunctionToolCall.Arguments, }, + Extra: block.Extra, }) } return &schema.Message{ @@ -87,8 +88,8 @@ func toolMessageToAgenticMessage(input []*schema.Message) []*schema.AgenticMessa CallID: m.ToolCallID, Name: m.ToolName, Result: m.Content, - Extra: m.Extra, }, + Extra: m.Extra, }) } return []*schema.AgenticMessage{{ @@ -110,9 +111,9 @@ func streamToolMessageToAgenticMessage(input *schema.StreamReader[[]*schema.Mess CallID: m.ToolCallID, Name: m.ToolName, Result: m.Content, - Extra: m.Extra, }, - StreamMeta: &schema.StreamMeta{Index: int64(i)}, + StreamingMeta: &schema.StreamingMeta{Index: i}, + Extra: m.Extra, }) } return []*schema.AgenticMessage{{ diff --git a/compose/tools_node_agentic_test.go b/compose/tools_node_agentic_test.go index dcd3177a9..4641dd8ae 100644 --- a/compose/tools_node_agentic_test.go +++ b/compose/tools_node_agentic_test.go @@ -20,6 +20,7 @@ import ( "io" "testing" + "github.com/bytedance/sonic" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/schema" @@ -155,13 +156,14 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { nil, }, { + nil, { Role: schema.Tool, - Content: "content1-2", + Content: "content2-2", ToolName: "name2", ToolCallID: "2", }, - nil, nil, + nil, }, { nil, nil, @@ -172,16 +174,6 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { ToolCallID: "3", }, }, - { - nil, - { - Role: schema.Tool, - Content: "content2-2", - ToolName: "name2", - ToolCallID: "2", - }, - nil, - }, { nil, nil, { @@ -204,7 +196,11 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { } result, err := schema.ConcatAgenticMessagesArray(chunks) assert.NoError(t, err) - assert.Equal(t, []*schema.AgenticMessage{ + + actualStr, err := sonic.MarshalString(result) + assert.NoError(t, err) + + expected := []*schema.AgenticMessage{ { Role: schema.AgenticRoleTypeUser, ContentBlocks: []*schema.ContentBlock{ @@ -213,10 +209,8 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { FunctionToolResult: &schema.FunctionToolResult{ CallID: "1", Name: "name1", - Result: "content1-1content1-2", - Extra: map[string]interface{}{}, + Result: "content1-1", }, - StreamMeta: &schema.StreamMeta{Index: 0}, }, { Type: schema.ContentBlockTypeFunctionToolResult, @@ -224,9 +218,7 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { CallID: "2", Name: "name2", Result: "content2-1content2-2", - Extra: map[string]interface{}{}, }, - StreamMeta: &schema.StreamMeta{Index: 1}, }, { Type: schema.ContentBlockTypeFunctionToolResult, @@ -234,11 +226,14 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { CallID: "3", Name: "name3", Result: "content3-1content3-2", - Extra: map[string]interface{}{}, }, - StreamMeta: &schema.StreamMeta{Index: 2}, }, }, }, - }, result) + } + + expectedStr, err := sonic.MarshalString(expected) + assert.NoError(t, err) + + assert.Equal(t, expectedStr, actualStr) } diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 2139201ec..b2225b2c7 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -20,15 +20,15 @@ import ( "context" "fmt" "reflect" + "sort" "strings" - "github.com/cloudwego/eino/schema/claude" - "github.com/cloudwego/eino/schema/gemini" + "github.com/eino-contrib/jsonschema" "github.com/cloudwego/eino/internal" + "github.com/cloudwego/eino/schema/claude" + "github.com/cloudwego/eino/schema/gemini" "github.com/cloudwego/eino/schema/openai" - - "github.com/eino-contrib/jsonschema" ) type ContentBlockType string @@ -82,9 +82,9 @@ type AgenticResponseMeta struct { Extension any } -type StreamMeta struct { +type StreamingMeta struct { // Index specifies the index position of this block in the final response. - Index int64 + Index int } type ContentBlock struct { @@ -123,14 +123,12 @@ type ContentBlock struct { // MCPToolApprovalResponse records the user's approval decision for an MCP tool call. MCPToolApprovalResponse *MCPToolApprovalResponse - StreamMeta *StreamMeta + StreamingMeta *StreamingMeta + Extra map[string]any } type UserInputText struct { Text string - - // Extra stores additional information. - Extra map[string]any } type UserInputImage struct { @@ -138,27 +136,18 @@ type UserInputImage struct { Base64Data string MIMEType string Detail ImageURLDetail - - // Extra stores additional information. - Extra map[string]any } type UserInputAudio struct { URL string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type UserInputVideo struct { URL string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type UserInputFile struct { @@ -166,9 +155,6 @@ type UserInputFile struct { Name string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type AssistantGenText struct { @@ -177,51 +163,37 @@ type AssistantGenText struct { OpenAIExtension *openai.AssistantGenTextExtension ClaudeExtension *claude.AssistantGenTextExtension Extension any - - // Extra stores additional information. - Extra map[string]any } type AssistantGenImage struct { URL string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type AssistantGenAudio struct { URL string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type AssistantGenVideo struct { URL string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type Reasoning struct { // Summary is the reasoning content summary. Summary []*ReasoningSummary + // EncryptedContent is the encrypted reasoning content. EncryptedContent string - - // Extra stores additional information. - Extra map[string]any } type ReasoningSummary struct { // Index specifies the index position of this summary in the final Reasoning. - Index int64 + Index int Text string } @@ -229,39 +201,37 @@ type ReasoningSummary struct { type FunctionToolCall struct { // CallID is the unique identifier for the tool call. CallID string + // Name specifies the function tool invoked. Name string + // Arguments is the JSON string arguments for the function tool call. Arguments string - - // Extra stores additional information - Extra map[string]any } type FunctionToolResult struct { // CallID is the unique identifier for the tool call. CallID string + // Name specifies the function tool invoked. Name string + // Result is the function tool result returned by the user Result string - - // Extra stores additional information. - Extra map[string]any } type ServerToolCall struct { // Name specifies the server-side tool invoked. // Supplied by the model server (e.g., `web_search` for OpenAI, `googleSearch` for Gemini). Name string + // CallID is the unique identifier for the tool call. // Empty if not provided by the model server. CallID string + // Arguments are the raw inputs to the server-side tool, // supplied by the component implementer. Arguments any - // Extra stores additional information. - Extra map[string]any } type ServerToolResult struct { @@ -276,41 +246,40 @@ type ServerToolResult struct { // Result refers to the raw output generated by the server-side tool, // supplied by the component implementer. Result any - - // Extra stores additional information. - Extra map[string]any } type MCPToolCall struct { // ServerLabel is the MCP server label used to identify it in tool calls ServerLabel string - // ApprovalRequestID is the unique ID of the approval request. + + // ApprovalRequestID is the approval request ID. ApprovalRequestID string + // CallID is the unique ID of the tool call. CallID string + // Name is the name of the tool to run. Name string + // Arguments is the JSON string arguments for the tool call. Arguments string - - // Extra stores additional information. - Extra map[string]any } type MCPToolResult struct { // ServerLabel is the MCP server label used to identify it in tool calls ServerLabel string + // CallID is the unique ID of the tool call. CallID string + // Name is the name of the tool to run. Name string + // Result is the JSON string with the tool result. Result string + // Error returned when the server fails to run the tool. Error *MCPToolCallError - - // Extra stores additional information. - Extra map[string]any } type MCPToolCallError struct { @@ -321,49 +290,49 @@ type MCPToolCallError struct { type MCPListToolsResult struct { // ServerLabel is the MCP server label used to identify it in tool calls. ServerLabel string + // Tools is the list of tools available on the server. Tools []*MCPListToolsItem + // Error returned when the server fails to list tools. Error string - - // Extra stores additional information. - Extra map[string]any } type MCPListToolsItem struct { // Name is the name of the tool. Name string + // Description is the description of the tool. Description string - // InputSchema is the JSON schema that describes the tool input. + + // InputSchema is the JSON schema that describes the tool input parameters. InputSchema *jsonschema.Schema } type MCPToolApprovalRequest struct { // ID is the approval request ID. ID string + // Name is the name of the tool to run. Name string + // Arguments is the JSON string arguments for the tool call. Arguments string + // ServerLabel is the MCP server label used to identify it in tool calls. ServerLabel string - - // Extra stores additional information. - Extra map[string]any } type MCPToolApprovalResponse struct { // ApprovalRequestID is the approval request ID being responded to. ApprovalRequestID string + // Approve indicates whether the request is approved. Approve bool + // Reason is the rationale for the decision. // Optional. Reason string - - // Extra stores additional information. - Extra map[string]any } // DeveloperAgenticMessage represents a message with AgenticRoleType "developer". @@ -404,8 +373,33 @@ func FunctionToolResultAgenticMessage(callID, name, result string) *AgenticMessa } } -func NewContentBlock(block any) *ContentBlock { - switch b := block.(type) { +type contentBlockVariant interface { + Reasoning | userInputVariant | assistantGenVariant | functionToolCallVariant | serverToolCallVariant | mcpToolCallVariant +} + +type userInputVariant interface { + UserInputText | UserInputImage | UserInputAudio | UserInputVideo | UserInputFile +} + +type assistantGenVariant interface { + AssistantGenText | AssistantGenImage | AssistantGenAudio | AssistantGenVideo +} + +type functionToolCallVariant interface { + FunctionToolCall | FunctionToolResult +} + +type serverToolCallVariant interface { + ServerToolCall | ServerToolResult +} + +type mcpToolCallVariant interface { + MCPToolCall | MCPToolResult | MCPListToolsResult | MCPToolApprovalRequest | MCPToolApprovalResponse +} + +// NewContentBlock creates a new ContentBlock with the given content. +func NewContentBlock[T contentBlockVariant](content *T) *ContentBlock { + switch b := any(content).(type) { case *Reasoning: return &ContentBlock{Type: ContentBlockTypeReasoning, Reasoning: b} case *UserInputText: @@ -449,6 +443,13 @@ func NewContentBlock(block any) *ContentBlock { } } +// NewContentBlockChunk creates a new ContentBlock with the given content and streaming metadata. +func NewContentBlockChunk[T contentBlockVariant](content *T, meta *StreamingMeta) *ContentBlock { + block := NewContentBlock(content) + block.StreamingMeta = meta + return block +} + // AgenticMessagesTemplate is the interface for messages template. // It's used to render a template to a list of agentic messages. // e.g. @@ -683,16 +684,19 @@ func formatUserInputFile(uif *UserInputFile, vs map[string]any, formatType Forma return &copied, nil } +// ConcatAgenticMessagesArray concatenates multiple streams of AgenticMessage into a single slice of AgenticMessage. func ConcatAgenticMessagesArray(mas [][]*AgenticMessage) ([]*AgenticMessage, error) { return buildConcatGenericArray[AgenticMessage](ConcatAgenticMessages)(mas) } +// ConcatAgenticMessages concatenates a list of AgenticMessage chunks into a single AgenticMessage. func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { var ( - role AgenticRoleType - blocksList [][]*ContentBlock - blocks []*ContentBlock - metas []*AgenticResponseMeta + role AgenticRoleType + blocks []*ContentBlock + metas []*AgenticResponseMeta + blockIndices []int + indexToBlocks = map[int][]*ContentBlock{} ) if len(msgs) == 1 { @@ -713,9 +717,12 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { } for _, block := range msg.ContentBlocks { - if block.StreamMeta == nil { + if block == nil { + continue + } + if block.StreamingMeta == nil { // Non-streaming block - if len(blocksList) > 0 { + if len(blockIndices) > 0 { // Cannot mix streaming and non-streaming blocks return nil, fmt.Errorf("found non-streaming block after streaming blocks") } @@ -728,8 +735,12 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { return nil, fmt.Errorf("found streaming block after non-streaming blocks") } // Collect streaming block by index - blocksList = expandSlice(int(block.StreamMeta.Index), blocksList) - blocksList[block.StreamMeta.Index] = append(blocksList[block.StreamMeta.Index], block) + if blocks_, ok := indexToBlocks[block.StreamingMeta.Index]; ok { + indexToBlocks[block.StreamingMeta.Index] = append(blocks_, block) + } else { + blockIndices = append(blockIndices, block.StreamingMeta.Index) + indexToBlocks[block.StreamingMeta.Index] = []*ContentBlock{block} + } } } @@ -743,219 +754,254 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { return nil, fmt.Errorf("failed to concat agentic response meta: %w", err) } - if len(blocksList) > 0 { + if len(blockIndices) > 0 { // All blocks are streaming, concat each group by index - blocks = make([]*ContentBlock, len(blocksList)) - for i, bs := range blocksList { - if len(bs) == 0 { - continue - } - b, err := concatAgenticContentBlocks(bs) + indexToBlock := map[int]*ContentBlock{} + for idx, bs := range indexToBlocks { + b, err := concatChunksOfSameContentBlock(bs) if err != nil { - return nil, fmt.Errorf("failed to concat content blocks at index %d: %w", i, err) + return nil, err } - blocks[i] = b + indexToBlock[idx] = b } - } - - for i := 0; i < len(blocks); i++ { - if blocks[i] == nil { - blocks = append(blocks[:i], blocks[i+1:]...) + blocks = make([]*ContentBlock, 0, len(blockIndices)) + sort.Slice(blockIndices, func(i, j int) bool { + return blockIndices[i] < blockIndices[j] + }) + for _, idx := range blockIndices { + blocks = append(blocks, indexToBlock[idx]) } } return &AgenticMessage{ - ResponseMeta: meta, Role: role, + ResponseMeta: meta, ContentBlocks: blocks, }, nil } -func concatAgenticResponseMeta(metas []*AgenticResponseMeta) (*AgenticResponseMeta, error) { +func concatAgenticResponseMeta(metas []*AgenticResponseMeta) (ret *AgenticResponseMeta, err error) { if len(metas) == 0 { return nil, nil } - ret := &AgenticResponseMeta{ - TokenUsage: &TokenUsage{}, - OpenAIExtension: nil, - ClaudeExtension: nil, - GeminiExtension: nil, - Extension: nil, - } + + openaiExtensions := make([]*openai.ResponseMetaExtension, 0, len(metas)) + claudeExtensions := make([]*claude.ResponseMetaExtension, 0, len(metas)) + geminiExtensions := make([]*gemini.ResponseMetaExtension, 0, len(metas)) + tokenUsages := make([]*TokenUsage, 0, len(metas)) + + var ( + extType reflect.Type + extensions reflect.Value + ) + for _, meta := range metas { - ret.Extension = meta.Extension - ret.OpenAIExtension = meta.OpenAIExtension - ret.ClaudeExtension = meta.ClaudeExtension - ret.GeminiExtension = meta.GeminiExtension if meta.TokenUsage != nil { - ret.TokenUsage.CompletionTokens += meta.TokenUsage.CompletionTokens - ret.TokenUsage.CompletionTokenDetails.ReasoningTokens += meta.TokenUsage.CompletionTokenDetails.ReasoningTokens - ret.TokenUsage.PromptTokens += meta.TokenUsage.PromptTokens - ret.TokenUsage.PromptTokenDetails.CachedTokens += meta.TokenUsage.PromptTokenDetails.CachedTokens - ret.TokenUsage.TotalTokens += meta.TokenUsage.TotalTokens + tokenUsages = append(tokenUsages, meta.TokenUsage) + } + + var isConsistent bool + + if meta.Extension != nil { + extType, isConsistent = validateExtensionType(extType, meta.Extension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.Extension)) + } + if !extensions.IsValid() { + extensions = reflect.MakeSlice(reflect.SliceOf(extType), 0, len(metas)) + } + extensions = reflect.Append(extensions, reflect.ValueOf(meta.Extension)) + } + + if meta.OpenAIExtension != nil { + extType, isConsistent = validateExtensionType(extType, meta.OpenAIExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.OpenAIExtension)) + } + openaiExtensions = append(openaiExtensions, meta.OpenAIExtension) + } + + if meta.ClaudeExtension != nil { + extType, isConsistent = validateExtensionType(extType, meta.ClaudeExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.ClaudeExtension)) + } + claudeExtensions = append(claudeExtensions, meta.ClaudeExtension) + } + + if meta.GeminiExtension != nil { + extType, isConsistent = validateExtensionType(extType, meta.GeminiExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.GeminiExtension)) + } + geminiExtensions = append(geminiExtensions, meta.GeminiExtension) + } + } + + ret = &AgenticResponseMeta{ + TokenUsage: concatTokenUsage(tokenUsages), + } + + if extensions.IsValid() && !extensions.IsZero() { + var extension reflect.Value + extension, err = internal.ConcatSliceValue(extensions) + if err != nil { + return nil, fmt.Errorf("failed to concat extensions: %w", err) + } + ret.Extension = extension.Interface() + } + + if len(openaiExtensions) > 0 { + ret.OpenAIExtension, err = openai.ConcatResponseMetaExtensions(openaiExtensions) + if err != nil { + return nil, fmt.Errorf("failed to concat openai extensions: %w", err) + } + } + + if len(claudeExtensions) > 0 { + ret.ClaudeExtension, err = claude.ConcatResponseMetaExtensions(claudeExtensions) + if err != nil { + return nil, fmt.Errorf("failed to concat claude extensions: %w", err) + } + } + + if len(geminiExtensions) > 0 { + ret.GeminiExtension, err = gemini.ConcatResponseMetaExtensions(geminiExtensions) + if err != nil { + return nil, fmt.Errorf("failed to concat gemini extensions: %w", err) } } + return ret, nil } -func concatAgenticContentBlocks(blocks []*ContentBlock) (*ContentBlock, error) { +func concatTokenUsage(usages []*TokenUsage) *TokenUsage { + if len(usages) == 0 { + return nil + } + + ret := &TokenUsage{} + + for _, usage := range usages { + if usage == nil { + continue + } + ret.CompletionTokens += usage.CompletionTokens + ret.CompletionTokensDetails.ReasoningTokens += usage.CompletionTokensDetails.ReasoningTokens + ret.PromptTokens += usage.PromptTokens + ret.PromptTokenDetails.CachedTokens += usage.PromptTokenDetails.CachedTokens + ret.TotalTokens += usage.TotalTokens + } + + return ret +} + +func concatChunksOfSameContentBlock(blocks []*ContentBlock) (*ContentBlock, error) { if len(blocks) == 0 { return nil, fmt.Errorf("no content blocks to concat") } + blockType := blocks[0].Type - index := blocks[0].StreamMeta.Index + switch blockType { case ContentBlockTypeReasoning: - return concatContentBlockHelper(blocks, blockType, "reasoning", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *Reasoning { return b.Reasoning }, - concatReasoning, - func(r *Reasoning) *ContentBlock { - return &ContentBlock{Type: blockType, Reasoning: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatReasoning) case ContentBlockTypeUserInputText: - return concatContentBlockHelper(blocks, blockType, "user input text", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *UserInputText { return b.UserInputText }, - concatUserInputText, - func(t *UserInputText) *ContentBlock { - return &ContentBlock{Type: blockType, UserInputText: t, StreamMeta: &StreamMeta{Index: index}} - }) + concatUserInputTexts) case ContentBlockTypeUserInputImage: - return concatContentBlockHelper(blocks, blockType, "user input image", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *UserInputImage { return b.UserInputImage }, - concatUserInputImage, - func(i *UserInputImage) *ContentBlock { - return &ContentBlock{Type: blockType, UserInputImage: i, StreamMeta: &StreamMeta{Index: index}} - }) + concatUserInputImages) case ContentBlockTypeUserInputAudio: - return concatContentBlockHelper(blocks, blockType, "user input audio", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *UserInputAudio { return b.UserInputAudio }, - concatUserInputAudio, - func(a *UserInputAudio) *ContentBlock { - return &ContentBlock{Type: blockType, UserInputAudio: a, StreamMeta: &StreamMeta{Index: index}} - }) + concatUserInputAudios) case ContentBlockTypeUserInputVideo: - return concatContentBlockHelper(blocks, blockType, "user input video", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *UserInputVideo { return b.UserInputVideo }, - concatUserInputVideo, - func(v *UserInputVideo) *ContentBlock { - return &ContentBlock{Type: blockType, UserInputVideo: v, StreamMeta: &StreamMeta{Index: index}} - }) + concatUserInputVideos) case ContentBlockTypeUserInputFile: - return concatContentBlockHelper(blocks, blockType, "user input file", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *UserInputFile { return b.UserInputFile }, - concatUserInputFile, - func(f *UserInputFile) *ContentBlock { - return &ContentBlock{Type: blockType, UserInputFile: f, StreamMeta: &StreamMeta{Index: index}} - }) + concatUserInputFiles) case ContentBlockTypeAssistantGenText: - return concatContentBlockHelper(blocks, blockType, "assistant gen text", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *AssistantGenText { return b.AssistantGenText }, - concatAssistantGenText, - func(t *AssistantGenText) *ContentBlock { - return &ContentBlock{Type: blockType, AssistantGenText: t, StreamMeta: &StreamMeta{Index: index}} - }) + concatAssistantGenTexts) case ContentBlockTypeAssistantGenImage: - return concatContentBlockHelper(blocks, blockType, "assistant gen image", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *AssistantGenImage { return b.AssistantGenImage }, - concatAssistantGenImage, - func(i *AssistantGenImage) *ContentBlock { - return &ContentBlock{Type: blockType, AssistantGenImage: i, StreamMeta: &StreamMeta{Index: index}} - }) + concatAssistantGenImages) case ContentBlockTypeAssistantGenAudio: - return concatContentBlockHelper(blocks, blockType, "assistant gen audio", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *AssistantGenAudio { return b.AssistantGenAudio }, - concatAssistantGenAudio, - func(a *AssistantGenAudio) *ContentBlock { - return &ContentBlock{Type: blockType, AssistantGenAudio: a, StreamMeta: &StreamMeta{Index: index}} - }) + concatAssistantGenAudios) case ContentBlockTypeAssistantGenVideo: - return concatContentBlockHelper(blocks, blockType, "assistant gen video", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *AssistantGenVideo { return b.AssistantGenVideo }, - concatAssistantGenVideo, - func(v *AssistantGenVideo) *ContentBlock { - return &ContentBlock{Type: blockType, AssistantGenVideo: v, StreamMeta: &StreamMeta{Index: index}} - }) + concatAssistantGenVideos) case ContentBlockTypeFunctionToolCall: - return concatContentBlockHelper(blocks, blockType, "function tool call", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *FunctionToolCall { return b.FunctionToolCall }, - concatFunctionToolCall, - func(c *FunctionToolCall) *ContentBlock { - return &ContentBlock{Type: blockType, FunctionToolCall: c, StreamMeta: &StreamMeta{Index: index}} - }) + concatFunctionToolCalls) case ContentBlockTypeFunctionToolResult: - return concatContentBlockHelper(blocks, blockType, "function tool result", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *FunctionToolResult { return b.FunctionToolResult }, - concatFunctionToolResult, - func(r *FunctionToolResult) *ContentBlock { - return &ContentBlock{Type: blockType, FunctionToolResult: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatFunctionToolResults) case ContentBlockTypeServerToolCall: - return concatContentBlockHelper(blocks, blockType, "server tool call", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *ServerToolCall { return b.ServerToolCall }, - concatServerToolCall, - func(c *ServerToolCall) *ContentBlock { - return &ContentBlock{Type: blockType, ServerToolCall: c, StreamMeta: &StreamMeta{Index: index}} - }) + concatServerToolCalls) case ContentBlockTypeServerToolResult: - return concatContentBlockHelper(blocks, blockType, "server tool result", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *ServerToolResult { return b.ServerToolResult }, - concatServerToolResult, - func(r *ServerToolResult) *ContentBlock { - return &ContentBlock{Type: blockType, ServerToolResult: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatServerToolResults) case ContentBlockTypeMCPToolCall: - return concatContentBlockHelper(blocks, blockType, "MCP tool call", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *MCPToolCall { return b.MCPToolCall }, - concatMCPToolCall, - func(c *MCPToolCall) *ContentBlock { - return &ContentBlock{Type: blockType, MCPToolCall: c, StreamMeta: &StreamMeta{Index: index}} - }) + concatMCPToolCalls) case ContentBlockTypeMCPToolResult: - return concatContentBlockHelper(blocks, blockType, "MCP tool result", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *MCPToolResult { return b.MCPToolResult }, - concatMCPToolResult, - func(r *MCPToolResult) *ContentBlock { - return &ContentBlock{Type: blockType, MCPToolResult: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatMCPToolResults) case ContentBlockTypeMCPListToolsResult: - return concatContentBlockHelper(blocks, blockType, "MCP list tools", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *MCPListToolsResult { return b.MCPListToolsResult }, - concatMCPListToolsResult, - func(r *MCPListToolsResult) *ContentBlock { - return &ContentBlock{Type: blockType, MCPListToolsResult: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatMCPListToolsResults) case ContentBlockTypeMCPToolApprovalRequest: - return concatContentBlockHelper(blocks, blockType, "MCP tool approval request", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *MCPToolApprovalRequest { return b.MCPToolApprovalRequest }, - concatMCPToolApprovalRequest, - func(r *MCPToolApprovalRequest) *ContentBlock { - return &ContentBlock{Type: blockType, MCPToolApprovalRequest: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatMCPToolApprovalRequests) case ContentBlockTypeMCPToolApprovalResponse: - return concatContentBlockHelper(blocks, blockType, "MCP tool approval response", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *MCPToolApprovalResponse { return b.MCPToolApprovalResponse }, - concatMCPToolApprovalResponse, - func(r *MCPToolApprovalResponse) *ContentBlock { - return &ContentBlock{Type: blockType, MCPToolApprovalResponse: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatMCPToolApprovalResponses) default: return nil, fmt.Errorf("unknown content block type: %s", blockType) @@ -964,21 +1010,19 @@ func concatAgenticContentBlocks(blocks []*ContentBlock) (*ContentBlock, error) { // concatContentBlockHelper is a generic helper function that reduces code duplication // for concatenating content blocks of a specific type. -func concatContentBlockHelper[T any]( +func concatContentBlockHelper[T contentBlockVariant]( blocks []*ContentBlock, expectedType ContentBlockType, - typeName string, getter func(*ContentBlock) *T, concatFunc func([]*T) (*T, error), - constructor func(*T) *ContentBlock, ) (*ContentBlock, error) { items, err := genericGetTFromContentBlocks(blocks, func(block *ContentBlock) (*T, error) { if block.Type != expectedType { - return nil, fmt.Errorf("expected %s block, got %s", typeName, block.Type) + return nil, fmt.Errorf("content block type mismatch: expected '%s', but got '%s'", expectedType, block.Type) } item := getter(block) if item == nil { - return nil, fmt.Errorf("%s content is nil", typeName) + return nil, fmt.Errorf("'%s' content is nil", expectedType) } return item, nil }) @@ -988,10 +1032,28 @@ func concatContentBlockHelper[T any]( concatenated, err := concatFunc(items) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to concat '%s' content blocks: %w", expectedType, err) + } + + extras := make([]map[string]any, 0, len(blocks)) + for _, block := range blocks { + if len(block.Extra) > 0 { + extras = append(extras, block.Extra) + } + } + + var extra map[string]any + if len(extras) > 0 { + extra, err = internal.ConcatItems(extras) + if err != nil { + return nil, fmt.Errorf("failed to concat content block extras: %w", err) + } } - return constructor(concatenated), nil + block := NewContentBlock(concatenated) + block.Extra = extra + + return block, nil } func genericGetTFromContentBlocks[T any](blocks []*ContentBlock, checkAndGetter func(block *ContentBlock) (T, error)) ([]T, error) { @@ -1006,43 +1068,14 @@ func genericGetTFromContentBlocks[T any](blocks []*ContentBlock, checkAndGetter return ret, nil } -// Concatenation strategies for different content block types: -// -// String concatenation (incremental streaming): -// - Reasoning: Summary texts are concatenated, grouped by Index if present -// - UserInputText: Text fields are concatenated -// - AssistantGenText: Text fields are concatenated, annotations/citations are merged -// - FunctionToolCall: Arguments (JSON strings) are concatenated incrementally -// - FunctionToolResult: Result strings are concatenated -// - ServerToolCall: Arguments are merged (last non-nil value for any type) -// - ServerToolResult: Results are merged using internal.ConcatItems -// - MCPToolCall: Arguments (JSON strings) are concatenated incrementally -// - MCPToolResult: Result strings are concatenated -// - MCPListToolsResult: Tools arrays are merged -// - MCPToolApprovalRequest: Arguments are concatenated -// -// Take last block (non-streaming content): -// - UserInputImage, UserInputAudio, UserInputVideo, UserInputFile: Return last block -// - AssistantGenImage, AssistantGenAudio, AssistantGenVideo: Return last block -// - MCPToolApprovalResponse: Return last block -// - func concatReasoning(reasons []*Reasoning) (*Reasoning, error) { if len(reasons) == 0 { return nil, fmt.Errorf("no reasoning found") } - if len(reasons) == 1 { - return reasons[0], nil - } - ret := &Reasoning{ - Summary: make([]*ReasoningSummary, 0), - EncryptedContent: "", - Extra: make(map[string]any), - } + ret := &Reasoning{} - // Collect all summaries from all reasons - allSummaries := make([]*ReasoningSummary, 0) + var allSummaries []*ReasoningSummary for _, r := range reasons { if r == nil { continue @@ -1051,157 +1084,269 @@ func concatReasoning(reasons []*Reasoning) (*Reasoning, error) { if r.EncryptedContent != "" { ret.EncryptedContent += r.EncryptedContent } - for k, v := range r.Extra { - ret.Extra[k] = v - } } - // Group by Index and concatenate Text for same Index - // Use dynamic array that expands as needed - var summaryArray []*ReasoningSummary + var ( + indices []int + indexToSummary = map[int]*ReasoningSummary{} + ) + for _, s := range allSummaries { - idx := s.Index - // Expand array if needed - summaryArray = expandSlice(int(idx), summaryArray) - if summaryArray[idx] == nil { - // Create new entry with a copy of Index - summaryArray[idx] = &ReasoningSummary{ - Index: idx, - Text: s.Text, - } - } else { - // Concatenate text for same index - summaryArray[idx].Text += s.Text + if s == nil { + continue + } + if indexToSummary[s.Index] == nil { + indexToSummary[s.Index] = &ReasoningSummary{} + indices = append(indices, s.Index) } + indexToSummary[s.Index].Text += s.Text } - // Convert array to slice, filtering out nil entries - ret.Summary = make([]*ReasoningSummary, 0, len(summaryArray)) - for _, summary := range summaryArray { - if summary != nil { - ret.Summary = append(ret.Summary, summary) - } + sort.Slice(indices, func(i, j int) bool { + return indices[i] < indices[j] + }) + + ret.Summary = make([]*ReasoningSummary, 0, len(indices)) + for _, idx := range indices { + ret.Summary = append(ret.Summary, indexToSummary[idx]) } return ret, nil } -func concatUserInputText(texts []*UserInputText) (*UserInputText, error) { +func concatUserInputTexts(texts []*UserInputText) (*UserInputText, error) { if len(texts) == 0 { return nil, fmt.Errorf("no user input text found") } if len(texts) == 1 { return texts[0], nil } - - ret := &UserInputText{ - Text: "", - Extra: make(map[string]any), - } - - for _, t := range texts { - if t == nil { - continue - } - ret.Text += t.Text - for k, v := range t.Extra { - ret.Extra[k] = v - } - } - - return ret, nil + return nil, fmt.Errorf("cannot concat multiple user input texts") } -func concatUserInputImage(images []*UserInputImage) (*UserInputImage, error) { +func concatUserInputImages(images []*UserInputImage) (*UserInputImage, error) { if len(images) == 0 { return nil, fmt.Errorf("no user input image found") } - return images[len(images)-1], nil + if len(images) == 1 { + return images[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input images") } -func concatUserInputAudio(audios []*UserInputAudio) (*UserInputAudio, error) { +func concatUserInputAudios(audios []*UserInputAudio) (*UserInputAudio, error) { if len(audios) == 0 { return nil, fmt.Errorf("no user input audio found") } - return audios[len(audios)-1], nil + if len(audios) == 1 { + return audios[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input audios") } -func concatUserInputVideo(videos []*UserInputVideo) (*UserInputVideo, error) { +func concatUserInputVideos(videos []*UserInputVideo) (*UserInputVideo, error) { if len(videos) == 0 { return nil, fmt.Errorf("no user input video found") } - return videos[len(videos)-1], nil + if len(videos) == 1 { + return videos[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input videos") } -func concatUserInputFile(files []*UserInputFile) (*UserInputFile, error) { +func concatUserInputFiles(files []*UserInputFile) (*UserInputFile, error) { if len(files) == 0 { return nil, fmt.Errorf("no user input file found") } - return files[len(files)-1], nil + if len(files) == 1 { + return files[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input files") } -func concatAssistantGenText(texts []*AssistantGenText) (*AssistantGenText, error) { +func concatAssistantGenTexts(texts []*AssistantGenText) (ret *AssistantGenText, err error) { if len(texts) == 0 { - return nil, fmt.Errorf("no assistant gen text found") + return nil, fmt.Errorf("no assistant generated text found") } if len(texts) == 1 { return texts[0], nil } - ret := &AssistantGenText{ - Text: "", - OpenAIExtension: nil, - ClaudeExtension: nil, - Extra: make(map[string]any), - } + ret = &AssistantGenText{} + + openaiExtensions := make([]*openai.AssistantGenTextExtension, 0, len(texts)) + claudeExtensions := make([]*claude.AssistantGenTextExtension, 0, len(texts)) + + var ( + extType reflect.Type + extensions reflect.Value + ) for _, t := range texts { if t == nil { continue } + ret.Text += t.Text + + var isConsistent bool + + if t.Extension != nil { + extType, isConsistent = validateExtensionType(extType, t.Extension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in assistant generated text chunks: '%s' vs '%s'", + extType, reflect.TypeOf(t.Extension)) + } + if !extensions.IsValid() { + extensions = reflect.MakeSlice(reflect.SliceOf(extType), 0, len(texts)) + } + extensions = reflect.Append(extensions, reflect.ValueOf(t.Extension)) + } + if t.OpenAIExtension != nil { - if ret.OpenAIExtension == nil { - ret.OpenAIExtension = &openai.AssistantGenTextExtension{} + extType, isConsistent = validateExtensionType(extType, t.OpenAIExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in assistant generated text chunks: '%s' vs '%s'", + extType, reflect.TypeOf(t.OpenAIExtension)) } - ret.OpenAIExtension.Annotations = append(ret.OpenAIExtension.Annotations, t.OpenAIExtension.Annotations...) + openaiExtensions = append(openaiExtensions, t.OpenAIExtension) } + if t.ClaudeExtension != nil { - if ret.ClaudeExtension == nil { - ret.ClaudeExtension = &claude.AssistantGenTextExtension{} + extType, isConsistent = validateExtensionType(extType, t.ClaudeExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in assistant generated text chunks: '%s' vs '%s'", + extType, reflect.TypeOf(t.ClaudeExtension)) } - ret.ClaudeExtension.Citations = append(ret.ClaudeExtension.Citations, t.ClaudeExtension.Citations...) + claudeExtensions = append(claudeExtensions, t.ClaudeExtension) } - for k, v := range t.Extra { - ret.Extra[k] = v + } + + if extensions.IsValid() && !extensions.IsZero() { + ret.Extension, err = internal.ConcatSliceValue(extensions) + if err != nil { + return nil, err + } + ret.Extension = extensions.Interface() + } + + if len(openaiExtensions) > 0 { + ret.OpenAIExtension, err = openai.ConcatAssistantGenTextExtensions(openaiExtensions) + if err != nil { + return nil, err + } + } + + if len(claudeExtensions) > 0 { + ret.ClaudeExtension, err = claude.ConcatAssistantGenTextExtensions(claudeExtensions) + if err != nil { + return nil, err } } return ret, nil } -func concatAssistantGenImage(images []*AssistantGenImage) (*AssistantGenImage, error) { +func concatAssistantGenImages(images []*AssistantGenImage) (*AssistantGenImage, error) { if len(images) == 0 { return nil, fmt.Errorf("no assistant gen image found") } - return images[len(images)-1], nil + if len(images) == 1 { + return images[0], nil + } + + ret := &AssistantGenImage{} + + for _, img := range images { + if img == nil { + continue + } + + ret.Base64Data += img.Base64Data + + if ret.URL == "" { + ret.URL = img.URL + } else if img.URL != "" && ret.URL != img.URL { + return nil, fmt.Errorf("inconsistent URLs in assistant generated image chunks: '%s' vs '%s'", ret.URL, img.URL) + } + + if ret.MIMEType == "" { + ret.MIMEType = img.MIMEType + } else if img.MIMEType != "" && ret.MIMEType != img.MIMEType { + return nil, fmt.Errorf("inconsistent MIME types in assistant generated image chunks: '%s' vs '%s'", ret.MIMEType, img.MIMEType) + } + } + + return ret, nil } -func concatAssistantGenAudio(audios []*AssistantGenAudio) (*AssistantGenAudio, error) { +func concatAssistantGenAudios(audios []*AssistantGenAudio) (*AssistantGenAudio, error) { if len(audios) == 0 { return nil, fmt.Errorf("no assistant gen audio found") } - return audios[len(audios)-1], nil + if len(audios) == 1 { + return audios[0], nil + } + + ret := &AssistantGenAudio{} + + for _, audio := range audios { + if audio == nil { + continue + } + + ret.Base64Data += audio.Base64Data + + if ret.URL == "" { + ret.URL = audio.URL + } else if audio.URL != "" && ret.URL != audio.URL { + return nil, fmt.Errorf("inconsistent URLs in assistant generated audio chunks: '%s' vs '%s'", ret.URL, audio.URL) + } + + if ret.MIMEType == "" { + ret.MIMEType = audio.MIMEType + } else if audio.MIMEType != "" && ret.MIMEType != audio.MIMEType { + return nil, fmt.Errorf("inconsistent MIME types in assistant generated audio chunks: '%s' vs '%s'", ret.MIMEType, audio.MIMEType) + } + } + + return ret, nil } -func concatAssistantGenVideo(videos []*AssistantGenVideo) (*AssistantGenVideo, error) { +func concatAssistantGenVideos(videos []*AssistantGenVideo) (*AssistantGenVideo, error) { if len(videos) == 0 { return nil, fmt.Errorf("no assistant gen video found") } - return videos[len(videos)-1], nil + if len(videos) == 1 { + return videos[0], nil + } + + ret := &AssistantGenVideo{} + + for _, video := range videos { + if video == nil { + continue + } + + ret.Base64Data += video.Base64Data + + if ret.URL == "" { + ret.URL = video.URL + } else if video.URL != "" && ret.URL != video.URL { + return nil, fmt.Errorf("inconsistent URLs in assistant generated video chunks: '%s' vs '%s'", ret.URL, video.URL) + } + + if ret.MIMEType == "" { + ret.MIMEType = video.MIMEType + } else if video.MIMEType != "" && ret.MIMEType != video.MIMEType { + return nil, fmt.Errorf("inconsistent MIME types in assistant generated video chunks: '%s' vs '%s'", ret.MIMEType, video.MIMEType) + } + } + + return ret, nil } -func concatFunctionToolCall(calls []*FunctionToolCall) (*FunctionToolCall, error) { +func concatFunctionToolCalls(calls []*FunctionToolCall) (*FunctionToolCall, error) { if len(calls) == 0 { return nil, fmt.Errorf("no function tool call found") } @@ -1209,31 +1354,32 @@ func concatFunctionToolCall(calls []*FunctionToolCall) (*FunctionToolCall, error return calls[0], nil } - // For tool calls, arguments are typically built incrementally during streaming - ret := &FunctionToolCall{ - Extra: make(map[string]any), - } + ret := &FunctionToolCall{} for _, c := range calls { if c == nil { continue } + if ret.CallID == "" { ret.CallID = c.CallID + } else if c.CallID != "" && c.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for function tool call, but got '%s'", ret.CallID, c.CallID) } + if ret.Name == "" { ret.Name = c.Name + } else if c.Name != "" && c.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for function tool call, but got '%s'", ret.Name, c.Name) } + ret.Arguments += c.Arguments - for k, v := range c.Extra { - ret.Extra[k] = v - } } return ret, nil } -func concatFunctionToolResult(results []*FunctionToolResult) (*FunctionToolResult, error) { +func concatFunctionToolResults(results []*FunctionToolResult) (*FunctionToolResult, error) { if len(results) == 0 { return nil, fmt.Errorf("no function tool result found") } @@ -1241,30 +1387,32 @@ func concatFunctionToolResult(results []*FunctionToolResult) (*FunctionToolResul return results[0], nil } - ret := &FunctionToolResult{ - Extra: make(map[string]any), - } + ret := &FunctionToolResult{} for _, r := range results { if r == nil { continue } + if ret.CallID == "" { ret.CallID = r.CallID + } else if r.CallID != "" && r.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for function tool result, but got '%s'", ret.CallID, r.CallID) } + if ret.Name == "" { ret.Name = r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for function tool result, but got '%s'", ret.Name, r.Name) } + ret.Result += r.Result - for k, v := range r.Extra { - ret.Extra[k] = v - } } return ret, nil } -func concatServerToolCall(calls []*ServerToolCall) (*ServerToolCall, error) { +func concatServerToolCalls(calls []*ServerToolCall) (ret *ServerToolCall, err error) { if len(calls) == 0 { return nil, fmt.Errorf("no server tool call found") } @@ -1272,33 +1420,54 @@ func concatServerToolCall(calls []*ServerToolCall) (*ServerToolCall, error) { return calls[0], nil } - // ServerToolCall Arguments is of type any; merge strategy uses the last non-nil value - ret := &ServerToolCall{ - Extra: make(map[string]any), - } + ret = &ServerToolCall{} + + var ( + argsType reflect.Type + argsChunks reflect.Value + ) for _, c := range calls { if c == nil { continue } - if ret.Name == "" { - ret.Name = c.Name - } + if ret.CallID == "" { ret.CallID = c.CallID + } else if c.CallID != "" && c.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for server tool call, but got '%s'", ret.CallID, c.CallID) + } + + if ret.Name == "" { + ret.Name = c.Name + } else if c.Name != "" && c.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for server tool call, but got '%s'", ret.Name, c.Name) } + if c.Arguments != nil { - ret.Arguments = c.Arguments + argsType_ := reflect.TypeOf(c.Arguments) + if argsType == nil { + argsType = argsType_ + argsChunks = reflect.MakeSlice(reflect.SliceOf(argsType), 0, len(calls)) + } else if argsType != argsType_ { + return nil, fmt.Errorf("expected type '%s' for server tool call arguments, but got '%s'", argsType, argsType_) + } + argsChunks = reflect.Append(argsChunks, reflect.ValueOf(c.Arguments)) } - for k, v := range c.Extra { - ret.Extra[k] = v + } + + if argsChunks.IsValid() && !argsChunks.IsZero() { + arguments, err := internal.ConcatSliceValue(argsChunks) + if err != nil { + return nil, err } + ret.Arguments = arguments.Interface() } return ret, nil } -func concatServerToolResult(results []*ServerToolResult) (*ServerToolResult, error) { +func concatServerToolResults(results []*ServerToolResult) (ret *ServerToolResult, err error) { if len(results) == 0 { return nil, fmt.Errorf("no server tool result found") } @@ -1306,45 +1475,54 @@ func concatServerToolResult(results []*ServerToolResult) (*ServerToolResult, err return results[0], nil } - // ServerToolResult Result is of type any; merge strategy uses the last non-nil value - ret := &ServerToolResult{ - Extra: make(map[string]any), - } + ret = &ServerToolResult{} + + var ( + resType reflect.Type + resChunks reflect.Value + ) - tZeroResult := reflect.TypeOf(results[0].Result) - data := reflect.MakeSlice(reflect.SliceOf(tZeroResult), 0, 0) for _, r := range results { if r == nil { continue } - if ret.Name == "" { - ret.Name = r.Name - } + if ret.CallID == "" { ret.CallID = r.CallID + } else if r.CallID != "" && r.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for server tool result, but got '%s'", ret.CallID, r.CallID) } + + if ret.Name == "" { + ret.Name = r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for server tool result, but got '%s'", ret.Name, r.Name) + } + if r.Result != nil { - vResult := reflect.ValueOf(r.Result) - if tZeroResult != vResult.Type() { - return nil, fmt.Errorf("tool result types are different: %v %v", tZeroResult, vResult.Type()) + resType_ := reflect.TypeOf(r.Result) + if resType == nil { + resType = resType_ + resChunks = reflect.MakeSlice(reflect.SliceOf(resType), 0, len(results)) + } else if resType != resType_ { + return nil, fmt.Errorf("expected type '%s' for server tool result, but got '%s'", resType, resType_) } - data = reflect.Append(data, vResult) - } - for k, v := range r.Extra { - ret.Extra[k] = v + resChunks = reflect.Append(resChunks, reflect.ValueOf(r.Result)) } } - d, err := internal.ConcatSliceValue(data) - if err != nil { - return nil, fmt.Errorf("failed to concat server tool result: %v", err) + if resChunks.IsValid() && !resChunks.IsZero() { + result, err := internal.ConcatSliceValue(resChunks) + if err != nil { + return nil, fmt.Errorf("failed to concat server tool result: %v", err) + } + ret.Result = result.Interface() } - ret.Result = d return ret, nil } -func concatMCPToolCall(calls []*MCPToolCall) (*MCPToolCall, error) { +func concatMCPToolCalls(calls []*MCPToolCall) (*MCPToolCall, error) { if len(calls) == 0 { return nil, fmt.Errorf("no mcp tool call found") } @@ -1352,36 +1530,38 @@ func concatMCPToolCall(calls []*MCPToolCall) (*MCPToolCall, error) { return calls[0], nil } - ret := &MCPToolCall{ - Extra: make(map[string]any), - } + ret := &MCPToolCall{} for _, c := range calls { if c == nil { continue } + + ret.Arguments += c.Arguments + if ret.ServerLabel == "" { ret.ServerLabel = c.ServerLabel + } else if c.ServerLabel != "" && c.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp tool call, but got '%s'", ret.ServerLabel, c.ServerLabel) } - if ret.ApprovalRequestID == "" { - ret.ApprovalRequestID = c.ApprovalRequestID - } + if ret.CallID == "" { ret.CallID = c.CallID + } else if c.CallID != "" && c.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for mcp tool call, but got '%s'", ret.CallID, c.CallID) } + if ret.Name == "" { ret.Name = c.Name - } - ret.Arguments += c.Arguments - for k, v := range c.Extra { - ret.Extra[k] = v + } else if c.Name != "" && c.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for mcp tool call, but got '%s'", ret.Name, c.Name) } } return ret, nil } -func concatMCPToolResult(results []*MCPToolResult) (*MCPToolResult, error) { +func concatMCPToolResults(results []*MCPToolResult) (*MCPToolResult, error) { if len(results) == 0 { return nil, fmt.Errorf("no mcp tool result found") } @@ -1389,33 +1569,44 @@ func concatMCPToolResult(results []*MCPToolResult) (*MCPToolResult, error) { return results[0], nil } - ret := &MCPToolResult{ - Extra: make(map[string]any), - } + ret := &MCPToolResult{} for _, r := range results { if r == nil { continue } + + if r.Result != "" { + ret.Result = r.Result + } + + if ret.ServerLabel == "" { + ret.ServerLabel = r.ServerLabel + } else if r.ServerLabel != "" && r.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp tool result, but got '%s'", ret.ServerLabel, r.ServerLabel) + } + if ret.CallID == "" { ret.CallID = r.CallID + } else if r.CallID != "" && r.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for mcp tool result, but got '%s'", ret.CallID, r.CallID) } + if ret.Name == "" { ret.Name = r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for mcp tool result, but got '%s'", ret.Name, r.Name) } - ret.Result += r.Result + if r.Error != nil { - ret.Error = r.Error // Use the last error - } - for k, v := range r.Extra { - ret.Extra[k] = v + ret.Error = r.Error } } return ret, nil } -func concatMCPListToolsResult(results []*MCPListToolsResult) (*MCPListToolsResult, error) { +func concatMCPListToolsResults(results []*MCPListToolsResult) (*MCPListToolsResult, error) { if len(results) == 0 { return nil, fmt.Errorf("no mcp list tools result found") } @@ -1423,31 +1614,30 @@ func concatMCPListToolsResult(results []*MCPListToolsResult) (*MCPListToolsResul return results[0], nil } - ret := &MCPListToolsResult{ - Tools: make([]*MCPListToolsItem, 0), - Extra: make(map[string]any), - } + ret := &MCPListToolsResult{} for _, r := range results { if r == nil { continue } - if ret.ServerLabel == "" { - ret.ServerLabel = r.ServerLabel - } + ret.Tools = append(ret.Tools, r.Tools...) + if r.Error != "" { - ret.Error = r.Error // Use the last error + ret.Error = r.Error } - for k, v := range r.Extra { - ret.Extra[k] = v + + if ret.ServerLabel == "" { + ret.ServerLabel = r.ServerLabel + } else if r.ServerLabel != "" && r.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp list tools result, but got '%s'", ret.ServerLabel, r.ServerLabel) } } return ret, nil } -func concatMCPToolApprovalRequest(requests []*MCPToolApprovalRequest) (*MCPToolApprovalRequest, error) { +func concatMCPToolApprovalRequests(requests []*MCPToolApprovalRequest) (*MCPToolApprovalRequest, error) { if len(requests) == 0 { return nil, fmt.Errorf("no mcp tool approval request found") } @@ -1455,50 +1645,48 @@ func concatMCPToolApprovalRequest(requests []*MCPToolApprovalRequest) (*MCPToolA return requests[0], nil } - ret := &MCPToolApprovalRequest{ - Extra: make(map[string]any), - } + ret := &MCPToolApprovalRequest{} for _, r := range requests { if r == nil { continue } + + ret.Arguments += r.Arguments + if ret.ID == "" { ret.ID = r.ID + } else if r.ID != "" && r.ID != ret.ID { + return nil, fmt.Errorf("expected request ID '%s' for mcp tool approval request, but got '%s'", ret.ID, r.ID) } + if ret.Name == "" { ret.Name = r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for mcp tool approval request, but got '%s'", ret.Name, r.Name) } - ret.Arguments += r.Arguments + if ret.ServerLabel == "" { ret.ServerLabel = r.ServerLabel - } - for k, v := range r.Extra { - ret.Extra[k] = v + } else if r.ServerLabel != "" && r.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp tool approval request, but got '%s'", ret.ServerLabel, r.ServerLabel) } } return ret, nil } -func concatMCPToolApprovalResponse(responses []*MCPToolApprovalResponse) (*MCPToolApprovalResponse, error) { +func concatMCPToolApprovalResponses(responses []*MCPToolApprovalResponse) (*MCPToolApprovalResponse, error) { if len(responses) == 0 { return nil, fmt.Errorf("no mcp tool approval response found") } if len(responses) == 1 { return responses[0], nil } - - return responses[len(responses)-1], nil -} - -func expandSlice[T any](idx int, s []T) []T { - if len(s) > idx { - return s - } - return append(s, make([]T, idx-len(s)+1)...) + return nil, fmt.Errorf("cannot concat multiple mcp tool approval responses") } +// String returns the string representation of AgenticMessage. func (m *AgenticMessage) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf("role: %s\n", m.Role)) @@ -1520,6 +1708,7 @@ func (m *AgenticMessage) String() string { return sb.String() } +// String returns the string representation of ContentBlock. func (b *ContentBlock) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf("type: %s\n", b.Type)) @@ -1603,13 +1792,14 @@ func (b *ContentBlock) String() string { } } - if b.StreamMeta != nil { - sb.WriteString(fmt.Sprintf(" stream_index: %d\n", b.StreamMeta.Index)) + if b.StreamingMeta != nil { + sb.WriteString(fmt.Sprintf(" stream_index: %d\n", b.StreamingMeta.Index)) } return sb.String() } +// String returns the string representation of Reasoning. func (r *Reasoning) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" summary: %d items\n", len(r.Summary))) @@ -1622,22 +1812,27 @@ func (r *Reasoning) String() string { return sb.String() } +// String returns the string representation of UserInputText. func (u *UserInputText) String() string { return fmt.Sprintf(" text: %s\n", u.Text) } +// String returns the string representation of UserInputImage. func (u *UserInputImage) String() string { return formatMediaString(u.URL, u.Base64Data, u.MIMEType, u.Detail) } +// String returns the string representation of UserInputAudio. func (u *UserInputAudio) String() string { return formatMediaString(u.URL, u.Base64Data, u.MIMEType, "") } +// String returns the string representation of UserInputVideo. func (u *UserInputVideo) String() string { return formatMediaString(u.URL, u.Base64Data, u.MIMEType, "") } +// String returns the string representation of UserInputFile. func (u *UserInputFile) String() string { sb := &strings.Builder{} if u.Name != "" { @@ -1647,22 +1842,27 @@ func (u *UserInputFile) String() string { return sb.String() } +// String returns the string representation of AssistantGenText. func (a *AssistantGenText) String() string { return fmt.Sprintf(" text: %s\n", a.Text) } +// String returns the string representation of AssistantGenImage. func (a *AssistantGenImage) String() string { return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") } +// String returns the string representation of AssistantGenAudio. func (a *AssistantGenAudio) String() string { return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") } +// String returns the string representation of AssistantGenVideo. func (a *AssistantGenVideo) String() string { return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") } +// String returns the string representation of FunctionToolCall. func (f *FunctionToolCall) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" call_id: %s\n", f.CallID)) @@ -1671,6 +1871,7 @@ func (f *FunctionToolCall) String() string { return sb.String() } +// String returns the string representation of FunctionToolResult. func (f *FunctionToolResult) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" call_id: %s\n", f.CallID)) @@ -1679,6 +1880,7 @@ func (f *FunctionToolResult) String() string { return sb.String() } +// String returns the string representation of ServerToolCall. func (s *ServerToolCall) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" name: %s\n", s.Name)) @@ -1689,6 +1891,7 @@ func (s *ServerToolCall) String() string { return sb.String() } +// String returns the string representation of ServerToolResult. func (s *ServerToolResult) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" name: %s\n", s.Name)) @@ -1699,18 +1902,17 @@ func (s *ServerToolResult) String() string { return sb.String() } +// String returns the string representation of MCPToolCall. func (m *MCPToolCall) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) sb.WriteString(fmt.Sprintf(" call_id: %s\n", m.CallID)) sb.WriteString(fmt.Sprintf(" name: %s\n", m.Name)) sb.WriteString(fmt.Sprintf(" arguments: %s\n", m.Arguments)) - if m.ApprovalRequestID != "" { - sb.WriteString(fmt.Sprintf(" approval_request_id: %s\n", m.ApprovalRequestID)) - } return sb.String() } +// String returns the string representation of MCPToolResult. func (m *MCPToolResult) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" call_id: %s\n", m.CallID)) @@ -1722,6 +1924,7 @@ func (m *MCPToolResult) String() string { return sb.String() } +// String returns the string representation of MCPListToolsResult. func (m *MCPListToolsResult) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) @@ -1735,6 +1938,7 @@ func (m *MCPListToolsResult) String() string { return sb.String() } +// String returns the string representation of MCPToolApprovalRequest. func (m *MCPToolApprovalRequest) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) @@ -1744,6 +1948,7 @@ func (m *MCPToolApprovalRequest) String() string { return sb.String() } +// String returns the string representation of MCPToolApprovalResponse. func (m *MCPToolApprovalResponse) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" approval_request_id: %s\n", m.ApprovalRequestID)) @@ -1754,6 +1959,7 @@ func (m *MCPToolApprovalResponse) String() string { return sb.String() } +// String returns the string representation of AgenticResponseMeta. func (a *AgenticResponseMeta) String() string { sb := &strings.Builder{} sb.WriteString("response_meta:\n") @@ -1792,3 +1998,17 @@ func formatMediaString(url, base64Data string, mimeType string, detail any) stri } return sb.String() } + +func validateExtensionType(expected reflect.Type, actual any) (reflect.Type, bool) { + if actual == nil { + return expected, true + } + actualType := reflect.TypeOf(actual) + if expected == nil { + return actualType, true + } + if expected != actualType { + return expected, false + } + return expected, true +} diff --git a/schema/agentic_message_test.go b/schema/agentic_message_test.go index 0cafcd9ff..016aa5c4e 100644 --- a/schema/agentic_message_test.go +++ b/schema/agentic_message_test.go @@ -18,6 +18,7 @@ package schema import ( "context" + "reflect" "testing" "github.com/stretchr/testify/assert" @@ -75,7 +76,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Hello ", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -87,7 +88,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "World!", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -112,7 +113,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Index: 0, Text: "First "}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -126,7 +127,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Index: 0, Text: "Second"}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -137,7 +138,7 @@ func TestConcatAgenticMessages(t *testing.T) { assert.Len(t, result.ContentBlocks, 1) assert.Len(t, result.ContentBlocks[0].Reasoning.Summary, 1) assert.Equal(t, "First Second", result.ContentBlocks[0].Reasoning.Summary[0].Text) - assert.Equal(t, int64(0), result.ContentBlocks[0].Reasoning.Summary[0].Index) + assert.Equal(t, 0, result.ContentBlocks[0].Reasoning.Summary[0].Index) }) t.Run("concat reasoning with index", func(t *testing.T) { @@ -153,7 +154,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Index: 1, Text: "Part2-"}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -168,7 +169,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Index: 1, Text: "Part4"}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -185,26 +186,26 @@ func TestConcatAgenticMessages(t *testing.T) { t.Run("concat user input text", func(t *testing.T) { msgs := []*AgenticMessage{ { - Role: AgenticRoleTypeUser, + Role: AgenticRoleTypeAssistant, ContentBlocks: []*ContentBlock{ { - Type: ContentBlockTypeUserInputText, - UserInputText: &UserInputText{ + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ Text: "Hello ", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, { - Role: AgenticRoleTypeUser, + Role: AgenticRoleTypeAssistant, ContentBlocks: []*ContentBlock{ { - Type: ContentBlockTypeUserInputText, - UserInputText: &UserInputText{ + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ Text: "World!", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -213,35 +214,35 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - assert.Equal(t, "Hello World!", result.ContentBlocks[0].UserInputText.Text) + assert.Equal(t, "Hello World!", result.ContentBlocks[0].AssistantGenText.Text) }) - t.Run("concat user input image", func(t *testing.T) { - url1 := "https://example.com/image1.jpg" - url2 := "https://example.com/image2.jpg" + t.Run("concat assistant gen image", func(t *testing.T) { + base1 := "1" + base2 := "2" msgs := []*AgenticMessage{ { - Role: AgenticRoleTypeUser, + Role: AgenticRoleTypeAssistant, ContentBlocks: []*ContentBlock{ { - Type: ContentBlockTypeUserInputImage, - UserInputImage: &UserInputImage{ - URL: url1, + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + Base64Data: base1, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, { - Role: AgenticRoleTypeUser, + Role: AgenticRoleTypeAssistant, ContentBlocks: []*ContentBlock{ { - Type: ContentBlockTypeUserInputImage, - UserInputImage: &UserInputImage{ - URL: url2, + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + Base64Data: base2, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -250,11 +251,10 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - // Should take the last image - assert.Equal(t, url2, result.ContentBlocks[0].UserInputImage.URL) + assert.Equal(t, "12", result.ContentBlocks[0].AssistantGenImage.Base64Data) }) - t.Run("concat user input audio", func(t *testing.T) { + t.Run("concat user input audio - should error", func(t *testing.T) { url1 := "https://example.com/audio1.mp3" url2 := "https://example.com/audio2.mp3" @@ -267,7 +267,7 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputAudio: &UserInputAudio{ URL: url1, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -279,20 +279,18 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputAudio: &UserInputAudio{ URL: url2, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, } - result, err := ConcatAgenticMessages(msgs) - assert.NoError(t, err) - assert.Len(t, result.ContentBlocks, 1) - // Should take the last audio - assert.Equal(t, url2, result.ContentBlocks[0].UserInputAudio.URL) + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat multiple user input audios") }) - t.Run("concat user input video", func(t *testing.T) { + t.Run("concat user input video - should error", func(t *testing.T) { url1 := "https://example.com/video1.mp4" url2 := "https://example.com/video2.mp4" @@ -305,7 +303,7 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputVideo: &UserInputVideo{ URL: url1, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -317,17 +315,15 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputVideo: &UserInputVideo{ URL: url2, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, } - result, err := ConcatAgenticMessages(msgs) - assert.NoError(t, err) - assert.Len(t, result.ContentBlocks, 1) - // Should take the last video - assert.Equal(t, url2, result.ContentBlocks[0].UserInputVideo.URL) + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat multiple user input videos") }) t.Run("concat assistant gen text", func(t *testing.T) { @@ -340,7 +336,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Generated ", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -352,7 +348,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Text", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -365,9 +361,6 @@ func TestConcatAgenticMessages(t *testing.T) { }) t.Run("concat assistant gen image", func(t *testing.T) { - url1 := "https://example.com/gen_image1.jpg" - url2 := "https://example.com/gen_image2.jpg" - msgs := []*AgenticMessage{ { Role: AgenticRoleTypeAssistant, @@ -375,9 +368,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: &AssistantGenImage{ - URL: url1, + Base64Data: "part1", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -387,9 +380,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: &AssistantGenImage{ - URL: url2, + Base64Data: "part2", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -398,14 +391,10 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - // Should take the last image - assert.Equal(t, url2, result.ContentBlocks[0].AssistantGenImage.URL) + assert.Equal(t, "part1part2", result.ContentBlocks[0].AssistantGenImage.Base64Data) }) t.Run("concat assistant gen audio", func(t *testing.T) { - url1 := "https://example.com/gen_audio1.mp3" - url2 := "https://example.com/gen_audio2.mp3" - msgs := []*AgenticMessage{ { Role: AgenticRoleTypeAssistant, @@ -413,9 +402,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: &AssistantGenAudio{ - URL: url1, + Base64Data: "audio1", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -425,9 +414,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: &AssistantGenAudio{ - URL: url2, + Base64Data: "audio2", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -436,14 +425,10 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - // Should take the last audio - assert.Equal(t, url2, result.ContentBlocks[0].AssistantGenAudio.URL) + assert.Equal(t, "audio1audio2", result.ContentBlocks[0].AssistantGenAudio.Base64Data) }) t.Run("concat assistant gen video", func(t *testing.T) { - url1 := "https://example.com/gen_video1.mp4" - url2 := "https://example.com/gen_video2.mp4" - msgs := []*AgenticMessage{ { Role: AgenticRoleTypeAssistant, @@ -451,9 +436,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: &AssistantGenVideo{ - URL: url1, + Base64Data: "video1", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -463,9 +448,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: &AssistantGenVideo{ - URL: url2, + Base64Data: "video2", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -474,8 +459,7 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - // Should take the last video - assert.Equal(t, url2, result.ContentBlocks[0].AssistantGenVideo.URL) + assert.Equal(t, "video1video2", result.ContentBlocks[0].AssistantGenVideo.Base64Data) }) t.Run("concat function tool call", func(t *testing.T) { @@ -490,7 +474,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: "get_weather", Arguments: `{"location`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -502,7 +486,7 @@ func TestConcatAgenticMessages(t *testing.T) { FunctionToolCall: &FunctionToolCall{ Arguments: `":"NYC"}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -528,7 +512,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: "get_weather", Result: `{"temp`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -540,7 +524,7 @@ func TestConcatAgenticMessages(t *testing.T) { FunctionToolResult: &FunctionToolResult{ Result: `":72}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -565,7 +549,7 @@ func TestConcatAgenticMessages(t *testing.T) { CallID: "server_call_1", Name: "server_func", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -577,7 +561,7 @@ func TestConcatAgenticMessages(t *testing.T) { ServerToolCall: &ServerToolCall{ Arguments: map[string]any{"key": "value"}, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -603,7 +587,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: "server_func", Result: "result1", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -611,11 +595,9 @@ func TestConcatAgenticMessages(t *testing.T) { Role: AgenticRoleTypeAssistant, ContentBlocks: []*ContentBlock{ { - Type: ContentBlockTypeServerToolResult, - ServerToolResult: &ServerToolResult{ - Result: "result2", - }, - StreamMeta: &StreamMeta{Index: 0}, + Type: ContentBlockTypeServerToolResult, + ServerToolResult: &ServerToolResult{}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -626,6 +608,7 @@ func TestConcatAgenticMessages(t *testing.T) { assert.Len(t, result.ContentBlocks, 1) assert.Equal(t, "server_call_1", result.ContentBlocks[0].ServerToolResult.CallID) assert.Equal(t, "server_func", result.ContentBlocks[0].ServerToolResult.Name) + assert.Equal(t, "result1", result.ContentBlocks[0].ServerToolResult.Result) }) t.Run("concat mcp tool call", func(t *testing.T) { @@ -641,7 +624,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: "mcp_func", Arguments: `{"arg`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -653,7 +636,7 @@ func TestConcatAgenticMessages(t *testing.T) { MCPToolCall: &MCPToolCall{ Arguments: `":123}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -676,11 +659,12 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeMCPToolResult, MCPToolResult: &MCPToolResult{ - CallID: "mcp_call_1", - Name: "mcp_func", - Result: `{"res`, + ServerLabel: "mcp-server", + CallID: "mcp_call_1", + Name: "mcp_func", + Result: `First`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -690,9 +674,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeMCPToolResult, MCPToolResult: &MCPToolResult{ - Result: `ult":true}`, + Result: `Second`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -701,9 +685,10 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPToolResult.ServerLabel) assert.Equal(t, "mcp_call_1", result.ContentBlocks[0].MCPToolResult.CallID) assert.Equal(t, "mcp_func", result.ContentBlocks[0].MCPToolResult.Name) - assert.Equal(t, `{"result":true}`, result.ContentBlocks[0].MCPToolResult.Result) + assert.Equal(t, `Second`, result.ContentBlocks[0].MCPToolResult.Result) }) t.Run("concat mcp list tools", func(t *testing.T) { @@ -719,7 +704,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Name: "tool1"}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -733,7 +718,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Name: "tool2"}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -759,7 +744,7 @@ func TestConcatAgenticMessages(t *testing.T) { ServerLabel: "mcp-server", Arguments: `{"request`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -771,7 +756,7 @@ func TestConcatAgenticMessages(t *testing.T) { MCPToolApprovalRequest: &MCPToolApprovalRequest{ Arguments: `":1}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -786,7 +771,7 @@ func TestConcatAgenticMessages(t *testing.T) { assert.Equal(t, `{"request":1}`, result.ContentBlocks[0].MCPToolApprovalRequest.Arguments) }) - t.Run("concat mcp tool approval response", func(t *testing.T) { + t.Run("concat mcp tool approval response - should error", func(t *testing.T) { response1 := &MCPToolApprovalResponse{ ApprovalRequestID: "approval_1", Approve: false, @@ -803,7 +788,7 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: response1, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -813,17 +798,15 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: response2, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, } - result, err := ConcatAgenticMessages(msgs) - assert.NoError(t, err) - assert.Len(t, result.ContentBlocks, 1) - // Should take the last response - assert.Equal(t, response2, result.ContentBlocks[0].MCPToolApprovalResponse) + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat multiple mcp tool approval responses") }) t.Run("concat response meta", func(t *testing.T) { @@ -865,7 +848,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Hello", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -877,7 +860,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "World", }, - // No StreamMeta - non-streaming + // No StreamingMeta - non-streaming }, }, }, @@ -901,7 +884,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: "list_files", Arguments: `{"path`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -913,7 +896,7 @@ func TestConcatAgenticMessages(t *testing.T) { MCPToolCall: &MCPToolCall{ Arguments: `":"/tmp"}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -927,7 +910,7 @@ func TestConcatAgenticMessages(t *testing.T) { assert.Equal(t, `{"path":"/tmp"}`, result.ContentBlocks[0].MCPToolCall.Arguments) }) - t.Run("concat user input text", func(t *testing.T) { + t.Run("concat user input text - should error", func(t *testing.T) { msgs := []*AgenticMessage{ { Role: AgenticRoleTypeUser, @@ -937,7 +920,7 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputText: &UserInputText{ Text: "What is ", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -949,16 +932,15 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputText: &UserInputText{ Text: "the weather?", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, } - result, err := ConcatAgenticMessages(msgs) - assert.NoError(t, err) - assert.Len(t, result.ContentBlocks, 1) - assert.Equal(t, "What is the weather?", result.ContentBlocks[0].UserInputText.Text) + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat multiple user input texts") }) t.Run("multiple stream indexes - sparse indexes", func(t *testing.T) { @@ -971,14 +953,14 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Index0-", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ Text: "Index2-", }, - StreamMeta: &StreamMeta{Index: 2}, + StreamingMeta: &StreamingMeta{Index: 2}, }, }, }, @@ -990,14 +972,14 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Part2", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ Text: "Part2", }, - StreamMeta: &StreamMeta{Index: 2}, + StreamingMeta: &StreamingMeta{Index: 2}, }, }, }, @@ -1020,7 +1002,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Text ", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeFunctionToolCall, @@ -1029,7 +1011,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: "func1", Arguments: `{"a`, }, - StreamMeta: &StreamMeta{Index: 1}, + StreamingMeta: &StreamingMeta{Index: 1}, }, }, }, @@ -1041,14 +1023,14 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Content", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeFunctionToolCall, FunctionToolCall: &FunctionToolCall{ Arguments: `":1}`, }, - StreamMeta: &StreamMeta{Index: 1}, + StreamingMeta: &StreamingMeta{Index: 1}, }, }, }, @@ -1073,21 +1055,21 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "A", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ Text: "B", }, - StreamMeta: &StreamMeta{Index: 1}, + StreamingMeta: &StreamingMeta{Index: 1}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ Text: "C", }, - StreamMeta: &StreamMeta{Index: 2}, + StreamingMeta: &StreamingMeta{Index: 2}, }, }, }, @@ -1099,21 +1081,21 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "1", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ Text: "2", }, - StreamMeta: &StreamMeta{Index: 1}, + StreamingMeta: &StreamingMeta{Index: 1}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ Text: "3", }, - StreamMeta: &StreamMeta{Index: 2}, + StreamingMeta: &StreamingMeta{Index: 2}, }, }, }, @@ -1276,7 +1258,7 @@ func TestAgenticMessageString(t *testing.T) { Name: "get_current_weather", Arguments: `{"location":"New York City","unit":"fahrenheit"}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeFunctionToolResult, @@ -1289,11 +1271,10 @@ func TestAgenticMessageString(t *testing.T) { { Type: ContentBlockTypeMCPToolCall, MCPToolCall: &MCPToolCall{ - ServerLabel: "weather-mcp-server", - CallID: "mcp_forecast_456", - Name: "get_7day_forecast", - Arguments: `{"city":"New York","days":7}`, - ApprovalRequestID: "approval_req_789", + ServerLabel: "weather-mcp-server", + CallID: "mcp_forecast_456", + Name: "get_7day_forecast", + Arguments: `{"city":"New York","days":7}`, }, }, { @@ -1363,7 +1344,6 @@ content_blocks: call_id: mcp_forecast_456 name: get_7day_forecast arguments: {"city":"New York","days":7} - approval_request_id: approval_req_789 [7] type: mcp_tool_result call_id: mcp_forecast_456 name: get_7day_forecast @@ -1378,4 +1358,294 @@ content_blocks: response_meta: token_usage: prompt=250, completion=180, total=430 `, output) + + t.Run("full fields", func(t *testing.T) { + msg := &AgenticMessage{ + Role: AgenticRoleTypeSystem, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: "http://audio.com", + Base64Data: "audio_data", + MIMEType: "audio/mp3", + }, + }, + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: "http://video.com", + Base64Data: "video_data", + MIMEType: "video/mp4", + }, + }, + { + Type: ContentBlockTypeUserInputFile, + UserInputFile: &UserInputFile{ + URL: "http://file.com", + Name: "file.txt", + Base64Data: "file_data", + MIMEType: "text/plain", + }, + }, + { + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + URL: "http://gen_image.com", + Base64Data: "gen_image_data", + MIMEType: "image/png", + }, + }, + { + Type: ContentBlockTypeAssistantGenAudio, + AssistantGenAudio: &AssistantGenAudio{ + URL: "http://gen_audio.com", + Base64Data: "gen_audio_data", + MIMEType: "audio/wav", + }, + }, + { + Type: ContentBlockTypeAssistantGenVideo, + AssistantGenVideo: &AssistantGenVideo{ + URL: "http://gen_video.com", + Base64Data: "gen_video_data", + MIMEType: "video/mp4", + }, + }, + { + Type: ContentBlockTypeServerToolCall, + ServerToolCall: &ServerToolCall{ + Name: "server_tool", + CallID: "call_1", + Arguments: map[string]any{"a": 1}, + }, + }, + { + Type: ContentBlockTypeServerToolResult, + ServerToolResult: &ServerToolResult{ + Name: "server_tool", + CallID: "call_1", + Result: map[string]any{"success": true}, + }, + }, + { + Type: ContentBlockTypeMCPToolApprovalRequest, + MCPToolApprovalRequest: &MCPToolApprovalRequest{ + ID: "req_1", + Name: "mcp_tool", + ServerLabel: "mcp_server", + Arguments: "{}", + }, + }, + { + Type: ContentBlockTypeMCPToolApprovalResponse, + MCPToolApprovalResponse: &MCPToolApprovalResponse{ + ApprovalRequestID: "req_1", + Approve: true, + Reason: "looks good", + }, + }, + }, + } + + s := msg.String() + assert.Contains(t, s, "role: system") + assert.Contains(t, s, "type: user_input_audio") + assert.Contains(t, s, "http://audio.com") + assert.Contains(t, s, "type: user_input_video") + assert.Contains(t, s, "http://video.com") + assert.Contains(t, s, "type: user_input_file") + assert.Contains(t, s, "file.txt") + assert.Contains(t, s, "type: assistant_gen_image") + assert.Contains(t, s, "http://gen_image.com") + assert.Contains(t, s, "type: assistant_gen_audio") + assert.Contains(t, s, "http://gen_audio.com") + assert.Contains(t, s, "type: assistant_gen_video") + assert.Contains(t, s, "http://gen_video.com") + assert.Contains(t, s, "type: server_tool_call") + assert.Contains(t, s, "server_tool") + assert.Contains(t, s, "map[a:1]") + assert.Contains(t, s, "type: server_tool_result") + assert.Contains(t, s, "map[success:true]") + assert.Contains(t, s, "type: mcp_tool_approval_request") + assert.Contains(t, s, "req_1") + assert.Contains(t, s, "type: mcp_tool_approval_response") + assert.Contains(t, s, "looks good") + }) + + t.Run("nil/empty fields", func(t *testing.T) { + msg := &AgenticMessage{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + {Type: ContentBlockTypeUserInputAudio, UserInputAudio: &UserInputAudio{}}, // empty + {Type: ContentBlockTypeUserInputVideo, UserInputVideo: &UserInputVideo{}}, + {Type: ContentBlockTypeUserInputFile, UserInputFile: &UserInputFile{}}, + {Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: &AssistantGenImage{}}, + {Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: &AssistantGenAudio{}}, + {Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: &AssistantGenVideo{}}, + {Type: ContentBlockTypeServerToolCall, ServerToolCall: &ServerToolCall{Name: "t"}}, // No CallID + {Type: ContentBlockTypeServerToolResult, ServerToolResult: &ServerToolResult{Name: "t"}}, // No CallID + {Type: ContentBlockTypeMCPToolResult, MCPToolResult: &MCPToolResult{Name: "t"}}, // No Error + {Type: ContentBlockTypeMCPListToolsResult, MCPListToolsResult: &MCPListToolsResult{}}, // No Error + {Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: &MCPToolApprovalResponse{Approve: false}}, // No Reason + nil, // Nil block in slice + }, + } + + s := msg.String() + assert.Contains(t, s, "type: user_input_audio") + assert.NotContains(t, s, "mime_type:") + assert.Contains(t, s, "type: server_tool_call") + }) + + t.Run("nil content struct in block", func(t *testing.T) { + // Test cases where the specific content struct is nil but type is set + // This shouldn't crash and should just print type + msg := &AgenticMessage{ + ContentBlocks: []*ContentBlock{ + {Type: ContentBlockTypeReasoning, Reasoning: nil}, + {Type: ContentBlockTypeUserInputText, UserInputText: nil}, + {Type: ContentBlockTypeUserInputImage, UserInputImage: nil}, + {Type: ContentBlockTypeUserInputAudio, UserInputAudio: nil}, + {Type: ContentBlockTypeUserInputVideo, UserInputVideo: nil}, + {Type: ContentBlockTypeUserInputFile, UserInputFile: nil}, + {Type: ContentBlockTypeAssistantGenText, AssistantGenText: nil}, + {Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: nil}, + {Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: nil}, + {Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: nil}, + {Type: ContentBlockTypeFunctionToolCall, FunctionToolCall: nil}, + {Type: ContentBlockTypeFunctionToolResult, FunctionToolResult: nil}, + {Type: ContentBlockTypeServerToolCall, ServerToolCall: nil}, + {Type: ContentBlockTypeServerToolResult, ServerToolResult: nil}, + {Type: ContentBlockTypeMCPToolCall, MCPToolCall: nil}, + {Type: ContentBlockTypeMCPToolResult, MCPToolResult: nil}, + {Type: ContentBlockTypeMCPListToolsResult, MCPListToolsResult: nil}, + {Type: ContentBlockTypeMCPToolApprovalRequest, MCPToolApprovalRequest: nil}, + {Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: nil}, + }, + } + s := msg.String() + assert.Contains(t, s, "type: reasoning") + // ensure no panic and basic output present + }) +} + +func TestDeveloperAgenticMessage(t *testing.T) { + t.Run("basic", func(t *testing.T) { + msg := DeveloperAgenticMessage("developer") + assert.Equal(t, AgenticRoleTypeDeveloper, msg.Role) + assert.Len(t, msg.ContentBlocks, 1) + assert.Equal(t, "developer", msg.ContentBlocks[0].UserInputText.Text) + }) +} + +func TestSystemAgenticMessage(t *testing.T) { + t.Run("basic", func(t *testing.T) { + msg := SystemAgenticMessage("system") + assert.Equal(t, AgenticRoleTypeSystem, msg.Role) + assert.Len(t, msg.ContentBlocks, 1) + assert.Equal(t, "system", msg.ContentBlocks[0].UserInputText.Text) + }) +} + +func TestUserAgenticMessage(t *testing.T) { + t.Run("basic", func(t *testing.T) { + msg := UserAgenticMessage("user") + assert.Equal(t, AgenticRoleTypeUser, msg.Role) + assert.Len(t, msg.ContentBlocks, 1) + assert.Equal(t, "user", msg.ContentBlocks[0].UserInputText.Text) + }) +} + +func TestFunctionToolResultAgenticMessage(t *testing.T) { + t.Run("basic", func(t *testing.T) { + msg := FunctionToolResultAgenticMessage("call_1", "tool_name", "result_str") + assert.Equal(t, AgenticRoleTypeUser, msg.Role) + assert.Len(t, msg.ContentBlocks, 1) + assert.Equal(t, ContentBlockTypeFunctionToolResult, msg.ContentBlocks[0].Type) + assert.Equal(t, "call_1", msg.ContentBlocks[0].FunctionToolResult.CallID) + assert.Equal(t, "tool_name", msg.ContentBlocks[0].FunctionToolResult.Name) + assert.Equal(t, "result_str", msg.ContentBlocks[0].FunctionToolResult.Result) + }) +} + +func TestNewContentBlock(t *testing.T) { + cbType := reflect.TypeOf(ContentBlock{}) + for i := 0; i < cbType.NumField(); i++ { + field := cbType.Field(i) + + // Skip non-content fields + if field.Name == "Type" || field.Name == "Extra" || field.Name == "StreamingMeta" { + continue + } + + t.Run(field.Name, func(t *testing.T) { + // Ensure field is a pointer + assert.Equal(t, reflect.Ptr, field.Type.Kind(), "Field %s should be a pointer", field.Name) + + // Create a new instance of the field's type + // field.Type is *T, so Elem() is T. reflect.New(T) returns *T. + elemType := field.Type.Elem() + inputVal := reflect.New(elemType) + input := inputVal.Interface() + + // Call NewContentBlock (generic) via type switch + var block *ContentBlock + switch v := input.(type) { + case *Reasoning: + block = NewContentBlock(v) + case *UserInputText: + block = NewContentBlock(v) + case *UserInputImage: + block = NewContentBlock(v) + case *UserInputAudio: + block = NewContentBlock(v) + case *UserInputVideo: + block = NewContentBlock(v) + case *UserInputFile: + block = NewContentBlock(v) + case *AssistantGenText: + block = NewContentBlock(v) + case *AssistantGenImage: + block = NewContentBlock(v) + case *AssistantGenAudio: + block = NewContentBlock(v) + case *AssistantGenVideo: + block = NewContentBlock(v) + case *FunctionToolCall: + block = NewContentBlock(v) + case *FunctionToolResult: + block = NewContentBlock(v) + case *ServerToolCall: + block = NewContentBlock(v) + case *ServerToolResult: + block = NewContentBlock(v) + case *MCPToolCall: + block = NewContentBlock(v) + case *MCPToolResult: + block = NewContentBlock(v) + case *MCPListToolsResult: + block = NewContentBlock(v) + case *MCPToolApprovalRequest: + block = NewContentBlock(v) + case *MCPToolApprovalResponse: + block = NewContentBlock(v) + default: + t.Fatalf("unsupported ContentBlock field type: %T", input) + } + + // Assertions + assert.NotNil(t, block, "NewContentBlock should return non-nil for type %T", input) + + // Check if the corresponding field in block is set equals to input + blockVal := reflect.ValueOf(block).Elem() + fieldVal := blockVal.FieldByName(field.Name) + assert.True(t, fieldVal.IsValid(), "Field %s not found in result", field.Name) + assert.Equal(t, input, fieldVal.Interface(), "Field %s should match input", field.Name) + + // Check Type is set + typeVal := blockVal.FieldByName("Type") + assert.NotEmpty(t, typeVal.String(), "Type should be set for %s", field.Name) + }) + } } diff --git a/schema/claude/consts.go b/schema/claude/consts.go index cbf8784f6..714b0362e 100644 --- a/schema/claude/consts.go +++ b/schema/claude/consts.go @@ -14,6 +14,7 @@ * limitations under the License. */ +// Package claude defines constants for claude. package claude type TextCitationType string diff --git a/schema/claude/content_block.go b/schema/claude/extension.go similarity index 51% rename from schema/claude/content_block.go rename to schema/claude/extension.go index 0c43d1045..5df8d8907 100644 --- a/schema/claude/content_block.go +++ b/schema/claude/extension.go @@ -16,6 +16,15 @@ package claude +import ( + "fmt" +) + +type ResponseMetaExtension struct { + ID string `json:"id,omitempty"` + StopReason string `json:"stop_reason,omitempty"` +} + type AssistantGenTextExtension struct { Citations []*TextCitation `json:"citations,omitempty"` } @@ -33,30 +42,30 @@ type CitationCharLocation struct { CitedText string `json:"cited_text,omitempty"` DocumentTitle string `json:"document_title,omitempty"` - DocumentIndex int64 `json:"document_index,omitempty"` + DocumentIndex int `json:"document_index,omitempty"` - StartCharIndex int64 `json:"start_char_index,omitempty"` - EndCharIndex int64 `json:"end_char_index,omitempty"` + StartCharIndex int `json:"start_char_index,omitempty"` + EndCharIndex int `json:"end_char_index,omitempty"` } type CitationPageLocation struct { CitedText string `json:"cited_text,omitempty"` DocumentTitle string `json:"document_title,omitempty"` - DocumentIndex int64 `json:"document_index,omitempty"` + DocumentIndex int `json:"document_index,omitempty"` - StartPageNumber int64 `json:"start_page_number,omitempty"` - EndPageNumber int64 `json:"end_page_number,omitempty"` + StartPageNumber int `json:"start_page_number,omitempty"` + EndPageNumber int `json:"end_page_number,omitempty"` } type CitationContentBlockLocation struct { CitedText string `json:"cited_text,omitempty"` DocumentTitle string `json:"document_title,omitempty"` - DocumentIndex int64 `json:"document_index,omitempty"` + DocumentIndex int `json:"document_index,omitempty"` - StartBlockIndex int64 `json:"start_block_index,omitempty"` - EndBlockIndex int64 `json:"end_block_index,omitempty"` + StartBlockIndex int `json:"start_block_index,omitempty"` + EndBlockIndex int `json:"end_block_index,omitempty"` } type CitationWebSearchResultLocation struct { @@ -67,3 +76,46 @@ type CitationWebSearchResultLocation struct { EncryptedIndex string `json:"encrypted_index,omitempty"` } + +// ConcatAssistantGenTextExtensions concatenates multiple AssistantGenTextExtension chunks into a single one. +func ConcatAssistantGenTextExtensions(chunks []*AssistantGenTextExtension) (*AssistantGenTextExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no assistant generated text extension found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + + ret := &AssistantGenTextExtension{ + Citations: make([]*TextCitation, 0, len(chunks)), + } + + for _, ext := range chunks { + ret.Citations = append(ret.Citations, ext.Citations...) + } + + return ret, nil +} + +// ConcatResponseMetaExtensions concatenates multiple ResponseMetaExtension chunks into a single one. +func ConcatResponseMetaExtensions(chunks []*ResponseMetaExtension) (*ResponseMetaExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no response meta extension found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + + ret := &ResponseMetaExtension{} + + for _, ext := range chunks { + if ext.ID != "" { + ret.ID = ext.ID + } + if ext.StopReason != "" { + ret.StopReason = ext.StopReason + } + } + + return ret, nil +} diff --git a/schema/claude/extension_test.go b/schema/claude/extension_test.go new file mode 100644 index 000000000..474fe740b --- /dev/null +++ b/schema/claude/extension_test.go @@ -0,0 +1,190 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package claude + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConcatAssistantGenTextExtensions(t *testing.T) { + t.Run("multiple extensions - concatenates all citations", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Citations: []*TextCitation{ + { + Type: "char_location", + CharLocation: &CitationCharLocation{ + CitedText: "citation 1", + DocumentIndex: 0, + }, + }, + }, + }, + { + Citations: []*TextCitation{ + { + Type: "page_location", + PageLocation: &CitationPageLocation{ + CitedText: "citation 2", + StartPageNumber: 1, + EndPageNumber: 2, + }, + }, + { + Type: "web_search_result_location", + WebSearchResultLocation: &CitationWebSearchResultLocation{ + CitedText: "citation 3", + URL: "https://example.com", + }, + }, + }, + }, + { + Citations: []*TextCitation{ + { + Type: "content_block_location", + ContentBlockLocation: &CitationContentBlockLocation{ + CitedText: "citation 4", + StartBlockIndex: 0, + EndBlockIndex: 5, + }, + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Citations, 4) + assert.Equal(t, "citation 1", result.Citations[0].CharLocation.CitedText) + assert.Equal(t, "citation 2", result.Citations[1].PageLocation.CitedText) + assert.Equal(t, "citation 3", result.Citations[2].WebSearchResultLocation.CitedText) + assert.Equal(t, "citation 4", result.Citations[3].ContentBlockLocation.CitedText) + }) + + t.Run("mixed empty and non-empty citations", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + {Citations: nil}, + { + Citations: []*TextCitation{ + { + Type: "char_location", + CharLocation: &CitationCharLocation{ + CitedText: "text1", + }, + }, + }, + }, + {Citations: []*TextCitation{}}, + { + Citations: []*TextCitation{ + { + Type: "page_location", + PageLocation: &CitationPageLocation{ + CitedText: "text2", + }, + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Citations, 2) + assert.Equal(t, "text1", result.Citations[0].CharLocation.CitedText) + assert.Equal(t, "text2", result.Citations[1].PageLocation.CitedText) + }) + + t.Run("streaming scenario - citations arrive in chunks", func(t *testing.T) { + // Simulates streaming where citations arrive progressively + exts := []*AssistantGenTextExtension{ + { + Citations: []*TextCitation{ + {Type: "char_location", CharLocation: &CitationCharLocation{CitedText: "chunk1"}}, + }, + }, + { + Citations: []*TextCitation{ + {Type: "char_location", CharLocation: &CitationCharLocation{CitedText: "chunk2"}}, + }, + }, + { + Citations: []*TextCitation{ + {Type: "char_location", CharLocation: &CitationCharLocation{CitedText: "chunk3"}}, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Citations, 3) + assert.Equal(t, "chunk1", result.Citations[0].CharLocation.CitedText) + assert.Equal(t, "chunk2", result.Citations[1].CharLocation.CitedText) + assert.Equal(t, "chunk3", result.Citations[2].CharLocation.CitedText) + }) +} + +func TestConcatResponseMetaExtensions(t *testing.T) { + t.Run("multiple extensions - takes last non-empty values", func(t *testing.T) { + exts := []*ResponseMetaExtension{ + { + ID: "msg_1", + StopReason: "stop_1", + }, + { + ID: "msg_2", + StopReason: "", + }, + { + ID: "", + StopReason: "stop_3", + }, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "msg_2", result.ID) // Last non-empty ID + assert.Equal(t, "stop_3", result.StopReason) // Last non-empty StopReason + }) + + t.Run("all empty fields", func(t *testing.T) { + exts := []*ResponseMetaExtension{ + {ID: "", StopReason: ""}, + {ID: "", StopReason: ""}, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "", result.ID) + assert.Equal(t, "", result.StopReason) + }) + + t.Run("streaming scenario - ID in first chunk, StopReason in last", func(t *testing.T) { + exts := []*ResponseMetaExtension{ + {ID: "msg_stream_123", StopReason: ""}, + {ID: "", StopReason: ""}, + {ID: "", StopReason: "end_turn"}, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "msg_stream_123", result.ID) + assert.Equal(t, "end_turn", result.StopReason) + }) +} diff --git a/schema/claude/response_meta.go b/schema/claude/response_meta.go deleted file mode 100644 index 9f60dd713..000000000 --- a/schema/claude/response_meta.go +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package claude - -type ResponseMetaExtension struct { - ID string `json:"id,omitempty"` - StopReason string `json:"stop_reason,omitempty"` -} diff --git a/schema/gemini/response_meta.go b/schema/gemini/extension.go similarity index 76% rename from schema/gemini/response_meta.go rename to schema/gemini/extension.go index a5b3f626c..efbc4f4bd 100644 --- a/schema/gemini/response_meta.go +++ b/schema/gemini/extension.go @@ -14,8 +14,13 @@ * limitations under the License. */ +// Package gemini defines the extension for gemini. package gemini +import ( + "fmt" +) + type ResponseMetaExtension struct { ID string `json:"id,omitempty"` FinishReason string `json:"finish_reason,omitempty"` @@ -38,7 +43,7 @@ type GroundingChunk struct { Web *GroundingChunkWeb `json:"web,omitempty"` } -// Chunk from the web. +// GroundingChunkWeb is the chunk from the web. type GroundingChunkWeb struct { // Domain of the (original) URI. This field is not supported in Gemini API. Domain string `json:"domain,omitempty"` @@ -56,7 +61,7 @@ type GroundingSupport struct { // A list of indices (into 'grounding_chunk') specifying the citations associated with // the claim. For instance [1,3,4] means that grounding_chunk[1], grounding_chunk[3], // grounding_chunk[4] are the retrieved content attributed to the claim. - GroundingChunkIndices []int32 `json:"grounding_chunk_indices,omitempty"` + GroundingChunkIndices []int `json:"grounding_chunk_indices,omitempty"` // Segment of the content this support belongs to. Segment *Segment `json:"segment,omitempty"` } @@ -65,20 +70,46 @@ type GroundingSupport struct { type Segment struct { // Output only. End index in the given Part, measured in bytes. Offset from the start // of the Part, exclusive, starting at zero. - EndIndex int32 `json:"end_index,omitempty"` + EndIndex int `json:"end_index,omitempty"` // Output only. The index of a Part object within its parent Content object. - PartIndex int32 `json:"part_index,omitempty"` + PartIndex int `json:"part_index,omitempty"` // Output only. Start index in the given Part, measured in bytes. Offset from the start // of the Part, inclusive, starting at zero. - StartIndex int32 `json:"start_index,omitempty"` + StartIndex int `json:"start_index,omitempty"` // Output only. The text corresponding to the segment from the response. Text string `json:"text,omitempty"` } -// Google search entry point. +// SearchEntryPoint is the Google search entry point. type SearchEntryPoint struct { // Optional. Web content snippet that can be embedded in a web page or an app webview. RenderedContent string `json:"rendered_content,omitempty"` // Optional. Base64 encoded JSON representing array of tuple. SDKBlob []byte `json:"sdk_blob,omitempty"` } + +// ConcatResponseMetaExtensions concatenates multiple ResponseMetaExtension chunks into a single one. +func ConcatResponseMetaExtensions(chunks []*ResponseMetaExtension) (*ResponseMetaExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no response meta extension found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + + ret := &ResponseMetaExtension{} + + for _, ext := range chunks { + if ext.ID != "" { + ret.ID = ext.ID + } + if ext.FinishReason != "" { + ret.FinishReason = ext.FinishReason + } + if ext.GroundingMeta != nil { + ret.GroundingMeta = ext.GroundingMeta + } + } + + return ret, nil +} diff --git a/schema/gemini/extension_test.go b/schema/gemini/extension_test.go new file mode 100644 index 000000000..56f390aa8 --- /dev/null +++ b/schema/gemini/extension_test.go @@ -0,0 +1,79 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package gemini + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConcatResponseMetaExtensions(t *testing.T) { + t.Run("multiple extensions - takes last non-empty values", func(t *testing.T) { + meta1 := &GroundingMetadata{WebSearchQueries: []string{"query1"}} + meta2 := &GroundingMetadata{WebSearchQueries: []string{"query2"}} + + exts := []*ResponseMetaExtension{ + { + ID: "resp_1", + FinishReason: "STOP", + GroundingMeta: meta1, + }, + { + ID: "resp_2", + FinishReason: "", + GroundingMeta: nil, + }, + { + ID: "", + FinishReason: "MAX_TOKENS", + GroundingMeta: meta2, + }, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "resp_2", result.ID) + assert.Equal(t, "MAX_TOKENS", result.FinishReason) + assert.Equal(t, meta2, result.GroundingMeta) + }) + + t.Run("streaming scenario", func(t *testing.T) { + meta := &GroundingMetadata{ + GroundingChunks: []*GroundingChunk{ + { + Web: &GroundingChunkWeb{ + Title: "Example", + URI: "https://example.com", + }, + }, + }, + } + + exts := []*ResponseMetaExtension{ + {ID: "stream_123", FinishReason: "", GroundingMeta: nil}, + {ID: "", FinishReason: "", GroundingMeta: nil}, + {ID: "", FinishReason: "STOP", GroundingMeta: meta}, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "stream_123", result.ID) + assert.Equal(t, "STOP", result.FinishReason) + assert.Equal(t, meta, result.GroundingMeta) + }) +} diff --git a/schema/message.go b/schema/message.go index 02cdefaad..611bcedca 100644 --- a/schema/message.go +++ b/schema/message.go @@ -703,10 +703,10 @@ type TokenUsage struct { PromptTokenDetails PromptTokenDetails `json:"prompt_token_details"` // CompletionTokens is the number of completion tokens. CompletionTokens int `json:"completion_tokens"` - // CompletionTokenDetails is a breakdown of the completion tokens. - CompletionTokenDetails CompletionTokensDetails `json:"completion_token_details"` // TotalTokens is the total number of tokens. TotalTokens int `json:"total_tokens"` + // CompletionTokensDetails is breakdown of completion tokens. + CompletionTokensDetails CompletionTokensDetails `json:"completion_token_details"` } type CompletionTokensDetails struct { diff --git a/schema/openai/consts.go b/schema/openai/consts.go index 321ee2a9e..5958cef40 100644 --- a/schema/openai/consts.go +++ b/schema/openai/consts.go @@ -14,6 +14,7 @@ * limitations under the License. */ +// Package openai defines constants for openai. package openai type TextAnnotationType string @@ -24,3 +25,71 @@ const ( TextAnnotationTypeContainerFileCitation TextAnnotationType = "container_file_citation" TextAnnotationTypeFilePath TextAnnotationType = "file_path" ) + +type ReasoningEffort string + +const ( + ReasoningEffortMinimal ReasoningEffort = "minimal" + ReasoningEffortLow ReasoningEffort = "low" + ReasoningEffortMedium ReasoningEffort = "medium" + ReasoningEffortHigh ReasoningEffort = "high" +) + +type ReasoningSummary string + +const ( + ReasoningSummaryAuto ReasoningSummary = "auto" + ReasoningSummaryConcise ReasoningSummary = "concise" + ReasoningSummaryDetailed ReasoningSummary = "detailed" +) + +type ServiceTier string + +const ( + ServiceTierAuto ServiceTier = "auto" + ServiceTierDefault ServiceTier = "default" + ServiceTierFlex ServiceTier = "flex" + ServiceTierScale ServiceTier = "scale" + ServiceTierPriority ServiceTier = "priority" +) + +type PromptCacheRetention string + +const ( + PromptCacheRetentionInMemory PromptCacheRetention = "in-memory" + PromptCacheRetention24h PromptCacheRetention = "24h" +) + +type ResponseStatus string + +const ( + ResponseStatusCompleted ResponseStatus = "completed" + ResponseStatusFailed ResponseStatus = "failed" + ResponseStatusInProgress ResponseStatus = "in_progress" + ResponseStatusCancelled ResponseStatus = "cancelled" + ResponseStatusQueued ResponseStatus = "queued" + ResponseStatusIncomplete ResponseStatus = "incomplete" +) + +type ResponseErrorCode string + +const ( + ResponseErrorCodeServerError ResponseErrorCode = "server_error" + ResponseErrorCodeRateLimitExceeded ResponseErrorCode = "rate_limit_exceeded" + ResponseErrorCodeInvalidPrompt ResponseErrorCode = "invalid_prompt" + ResponseErrorCodeVectorStoreTimeout ResponseErrorCode = "vector_store_timeout" + ResponseErrorCodeInvalidImage ResponseErrorCode = "invalid_image" + ResponseErrorCodeInvalidImageFormat ResponseErrorCode = "invalid_image_format" + ResponseErrorCodeInvalidBase64Image ResponseErrorCode = "invalid_base64_image" + ResponseErrorCodeInvalidImageURL ResponseErrorCode = "invalid_image_url" + ResponseErrorCodeImageTooLarge ResponseErrorCode = "image_too_large" + ResponseErrorCodeImageTooSmall ResponseErrorCode = "image_too_small" + ResponseErrorCodeImageParseError ResponseErrorCode = "image_parse_error" + ResponseErrorCodeImageContentPolicyViolation ResponseErrorCode = "image_content_policy_violation" + ResponseErrorCodeInvalidImageMode ResponseErrorCode = "invalid_image_mode" + ResponseErrorCodeImageFileTooLarge ResponseErrorCode = "image_file_too_large" + ResponseErrorCodeUnsupportedImageMediaType ResponseErrorCode = "unsupported_image_media_type" + ResponseErrorCodeEmptyImageFile ResponseErrorCode = "empty_image_file" + ResponseErrorCodeFailedToDownloadImage ResponseErrorCode = "failed_to_download_image" + ResponseErrorCodeImageFileNotFound ResponseErrorCode = "image_file_not_found" +) diff --git a/schema/openai/content_block.go b/schema/openai/content_block.go deleted file mode 100644 index 5d92be8f7..000000000 --- a/schema/openai/content_block.go +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package openai - -type AssistantGenTextExtension struct { - Annotations []*TextAnnotation `json:"annotations,omitempty"` -} - -type TextAnnotation struct { - Type TextAnnotationType `json:"type,omitempty"` - - FileCitation *TextAnnotationFileCitation `json:"file_citation,omitempty"` - URLCitation *TextAnnotationURLCitation `json:"url_citation,omitempty"` - ContainerFileCitation *TextAnnotationContainerFileCitation `json:"container_file_citation,omitempty"` - FilePath *TextAnnotationFilePath `json:"file_path,omitempty"` -} - -type TextAnnotationFileCitation struct { - // The ID of the file. - FileID string `json:"file_id,omitempty"` - // The filename of the file cited. - Filename string `json:"filename,omitempty"` - - // The index of the file in the list of files. - Index int64 `json:"index,omitempty"` -} - -type TextAnnotationURLCitation struct { - // The title of the web resource. - Title string `json:"title,omitempty"` - // The URL of the web resource. - URL string `json:"url,omitempty"` - - // The index of the first character of the URL citation in the message. - StartIndex int64 `json:"start_index,omitempty"` - // The index of the last character of the URL citation in the message. - EndIndex int64 `json:"end_index,omitempty"` -} - -type TextAnnotationContainerFileCitation struct { - // The ID of the container file. - ContainerID string `json:"container_id,omitempty"` - - // The ID of the file. - FileID string `json:"file_id,omitempty"` - // The filename of the container file cited. - Filename string `json:"filename,omitempty"` - - // The index of the first character of the container file citation in the message. - StartIndex int64 `json:"start_index,omitempty"` - // The index of the last character of the container file citation in the message. - EndIndex int64 `json:"end_index,omitempty"` -} - -type TextAnnotationFilePath struct { - // The ID of the file. - FileID string `json:"file_id,omitempty"` - - // The index of the file in the list of files. - Index int64 `json:"index,omitempty"` -} diff --git a/schema/openai/extension.go b/schema/openai/extension.go new file mode 100644 index 000000000..c30d2d8ec --- /dev/null +++ b/schema/openai/extension.go @@ -0,0 +1,206 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "fmt" + "sort" +) + +type ResponseMetaExtension struct { + ID string `json:"id,omitempty"` + Status ResponseStatus `json:"status,omitempty"` + Error *ResponseError `json:"error,omitempty"` + IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` + PreviousResponseID string `json:"previous_response_id,omitempty"` + Reasoning *Reasoning `json:"reasoning,omitempty"` + ServiceTier ServiceTier `json:"service_tier,omitempty"` + CreatedAt int64 `json:"created_at,omitempty"` + PromptCacheRetention PromptCacheRetention `json:"prompt_cache_retention,omitempty"` +} + +type AssistantGenTextExtension struct { + Refusal *OutputRefusal `json:"refusal,omitempty"` + Annotations []*TextAnnotation `json:"annotations,omitempty"` +} + +type ResponseError struct { + Code ResponseErrorCode `json:"code,omitempty"` + Message string `json:"message,omitempty"` +} + +type IncompleteDetails struct { + Reason string `json:"reason,omitempty"` +} + +type Reasoning struct { + Effort ReasoningEffort `json:"effort,omitempty"` + Summary ReasoningSummary `json:"summary,omitempty"` +} + +type OutputRefusal struct { + Reason string `json:"reason,omitempty"` +} + +type TextAnnotation struct { + Index int `json:"index,omitempty"` + + Type TextAnnotationType `json:"type,omitempty"` + + FileCitation *TextAnnotationFileCitation `json:"file_citation,omitempty"` + URLCitation *TextAnnotationURLCitation `json:"url_citation,omitempty"` + ContainerFileCitation *TextAnnotationContainerFileCitation `json:"container_file_citation,omitempty"` + FilePath *TextAnnotationFilePath `json:"file_path,omitempty"` +} + +type TextAnnotationFileCitation struct { + // The ID of the file. + FileID string `json:"file_id,omitempty"` + // The filename of the file cited. + Filename string `json:"filename,omitempty"` + + // The index of the file in the list of files. + Index int `json:"index,omitempty"` +} + +type TextAnnotationURLCitation struct { + // The title of the web resource. + Title string `json:"title,omitempty"` + // The URL of the web resource. + URL string `json:"url,omitempty"` + + // The index of the first character of the URL citation in the message. + StartIndex int `json:"start_index,omitempty"` + // The index of the last character of the URL citation in the message. + EndIndex int `json:"end_index,omitempty"` +} + +type TextAnnotationContainerFileCitation struct { + // The ID of the container file. + ContainerID string `json:"container_id,omitempty"` + + // The ID of the file. + FileID string `json:"file_id,omitempty"` + // The filename of the container file cited. + Filename string `json:"filename,omitempty"` + + // The index of the first character of the container file citation in the message. + StartIndex int `json:"start_index,omitempty"` + // The index of the last character of the container file citation in the message. + EndIndex int `json:"end_index,omitempty"` +} + +type TextAnnotationFilePath struct { + // The ID of the file. + FileID string `json:"file_id,omitempty"` + + // The index of the file in the list of files. + Index int `json:"index,omitempty"` +} + +// ConcatAssistantGenTextExtensions concatenates multiple AssistantGenTextExtension chunks into a single one. +func ConcatAssistantGenTextExtensions(chunks []*AssistantGenTextExtension) (*AssistantGenTextExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no assistant generated text extension found") + } + + ret := &AssistantGenTextExtension{} + + var allAnnotations []*TextAnnotation + for _, ext := range chunks { + allAnnotations = append(allAnnotations, ext.Annotations...) + } + + var ( + indices []int + indexToAnnotation = map[int]*TextAnnotation{} + ) + + for _, an := range allAnnotations { + if an == nil { + continue + } + if indexToAnnotation[an.Index] == nil { + indexToAnnotation[an.Index] = an + indices = append(indices, an.Index) + } else { + return nil, fmt.Errorf("duplicate annotation index %d", an.Index) + } + } + + sort.Slice(indices, func(i, j int) bool { + return indices[i] < indices[j] + }) + + ret.Annotations = make([]*TextAnnotation, 0, len(indices)) + for _, idx := range indices { + an := *indexToAnnotation[idx] + an.Index = 0 // clear index + ret.Annotations = append(ret.Annotations, &an) + } + + for _, ext := range chunks { + if ext.Refusal == nil { + continue + } + if ret.Refusal == nil { + ret.Refusal = ext.Refusal + } else { + ret.Refusal.Reason += ext.Refusal.Reason + } + } + + return ret, nil +} + +// ConcatResponseMetaExtensions concatenates multiple ResponseMetaExtension chunks into a single one. +func ConcatResponseMetaExtensions(chunks []*ResponseMetaExtension) (*ResponseMetaExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no response meta extension found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + + ret := &ResponseMetaExtension{} + + for _, ext := range chunks { + if ext.ID != "" { + ret.ID = ext.ID + } + if ext.Status != "" { + ret.Status = ext.Status + } + if ext.Error != nil { + ret.Error = ext.Error + } + if ext.IncompleteDetails != nil { + ret.IncompleteDetails = ext.IncompleteDetails + } + if ext.PreviousResponseID != "" { + ret.PreviousResponseID = ext.PreviousResponseID + } + if ext.Reasoning != nil { + ret.Reasoning = ext.Reasoning + } + if ext.ServiceTier != "" { + ret.ServiceTier = ext.ServiceTier + } + } + + return ret, nil +} diff --git a/schema/openai/extension_test.go b/schema/openai/extension_test.go new file mode 100644 index 000000000..640982fdf --- /dev/null +++ b/schema/openai/extension_test.go @@ -0,0 +1,193 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConcatResponseMetaExtensions(t *testing.T) { + t.Run("multiple extensions - takes last non-empty values", func(t *testing.T) { + err1 := &ResponseError{Code: "err1", Message: "msg1"} + incomplete := &IncompleteDetails{Reason: "max_tokens"} + + exts := []*ResponseMetaExtension{ + { + ID: "id_1", + Status: "in_progress", + Error: err1, + IncompleteDetails: nil, + }, + { + ID: "id_2", + Status: "", + Error: nil, + IncompleteDetails: nil, + }, + { + ID: "", + Status: "completed", + Error: nil, + IncompleteDetails: incomplete, + }, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "id_2", result.ID) + assert.Equal(t, ResponseStatus("completed"), result.Status) + assert.Equal(t, err1, result.Error) + assert.Equal(t, incomplete, result.IncompleteDetails) + }) + + t.Run("streaming scenario", func(t *testing.T) { + exts := []*ResponseMetaExtension{ + {ID: "chatcmpl_stream", Status: "", Error: nil, IncompleteDetails: nil}, + {ID: "", Status: ResponseStatus("in_progress"), Error: nil, IncompleteDetails: nil}, + {ID: "", Status: ResponseStatus("completed"), Error: nil, IncompleteDetails: nil}, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "chatcmpl_stream", result.ID) + assert.Equal(t, ResponseStatus("completed"), result.Status) + }) +} + +func TestConcatAssistantGenTextExtensions(t *testing.T) { + t.Run("single extension with annotations", func(t *testing.T) { + ext := &AssistantGenTextExtension{ + Annotations: []*TextAnnotation{ + { + Index: 0, + Type: "file_citation", + FileCitation: &TextAnnotationFileCitation{ + FileID: "file_123", + Filename: "doc.pdf", + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions([]*AssistantGenTextExtension{ext}) + assert.NoError(t, err) + assert.Len(t, result.Annotations, 1) + assert.Equal(t, "file_123", result.Annotations[0].FileCitation.FileID) + }) + + t.Run("multiple extensions - merges annotations by index", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Annotations: []*TextAnnotation{ + { + Index: 0, + Type: "file_citation", + FileCitation: &TextAnnotationFileCitation{ + FileID: "file_1", + }, + }, + }, + }, + { + Annotations: []*TextAnnotation{ + { + Index: 2, + Type: "url_citation", + URLCitation: &TextAnnotationURLCitation{ + URL: "https://example.com", + }, + }, + }, + }, + { + Annotations: []*TextAnnotation{ + { + Index: 1, + Type: "file_path", + FilePath: &TextAnnotationFilePath{ + FileID: "file_2", + }, + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Annotations, 3) + assert.Equal(t, "file_1", result.Annotations[0].FileCitation.FileID) + assert.Equal(t, "file_2", result.Annotations[1].FilePath.FileID) + assert.Equal(t, "https://example.com", result.Annotations[2].URLCitation.URL) + }) + + t.Run("streaming scenario - annotations arrive in chunks", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Annotations: []*TextAnnotation{ + {Index: 0, Type: "file_citation", FileCitation: &TextAnnotationFileCitation{FileID: "f1"}}, + }, + }, + { + Annotations: []*TextAnnotation{ + {Index: 1, Type: "url_citation", URLCitation: &TextAnnotationURLCitation{URL: "url1"}}, + }, + }, + { + Annotations: []*TextAnnotation{ + {Index: 2, Type: "file_path", FilePath: &TextAnnotationFilePath{FileID: "f2"}}, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Annotations, 3) + assert.Equal(t, "f1", result.Annotations[0].FileCitation.FileID) + assert.Equal(t, "url1", result.Annotations[1].URLCitation.URL) + assert.Equal(t, "f2", result.Annotations[2].FilePath.FileID) + }) + + t.Run("multiple extensions - concatenates refusal reason", func(t *testing.T) { + ext1 := &AssistantGenTextExtension{Refusal: &OutputRefusal{Reason: "A"}} + ext2 := &AssistantGenTextExtension{Refusal: &OutputRefusal{Reason: "B"}} + + result, err := ConcatAssistantGenTextExtensions([]*AssistantGenTextExtension{ext1, ext2}) + assert.NoError(t, err) + assert.NotNil(t, result.Refusal) + assert.Equal(t, "AB", result.Refusal.Reason) + }) + + t.Run("duplicate index - error occurrence", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Annotations: []*TextAnnotation{ + {Index: 0, Type: "file_citation", FileCitation: &TextAnnotationFileCitation{FileID: "first"}}, + }, + }, + { + Annotations: []*TextAnnotation{ + {Index: 0, Type: "url_citation", URLCitation: &TextAnnotationURLCitation{URL: "second"}}, + }, + }, + } + + _, err := ConcatAssistantGenTextExtensions(exts) + assert.Error(t, err) + }) +} diff --git a/schema/openai/response_meta.go b/schema/openai/response_meta.go deleted file mode 100644 index e1933065b..000000000 --- a/schema/openai/response_meta.go +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package openai - -type ResponseMetaExtension struct { - ID string `json:"id,omitempty"` - Status string `json:"status,omitempty"` - Error *ResponseError `json:"error,omitempty"` - StreamError *StreamResponseError `json:"stream_error,omitempty"` - IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` -} - -type ResponseError struct { - Code string `json:"code,omitempty"` - Message string `json:"message,omitempty"` -} - -type StreamResponseError struct { - Code string - Message string - Param string -} - -type IncompleteDetails struct { - Reason string `json:"reason,omitempty"` -} diff --git a/schema/tool.go b/schema/tool.go index ccc93b6a3..2d6bf90db 100644 --- a/schema/tool.go +++ b/schema/tool.go @@ -59,6 +59,27 @@ const ( ToolChoiceForced ToolChoice = "forced" ) +type AllowedTool struct { + // FunctionToolName is the name of the function tool. + FunctionToolName string + + MCPTool *AllowedMCPTool + + ServerTool *AllowedServerTool +} +type AllowedMCPTool struct { + // ServerLabel is the label of the MCP server. + ServerLabel string + // The name of the MCP tool. + Name string +} + +type AllowedServerTool struct { + // The name of the server tool. + Name string +} + +// ToolInfo is the information of a tool. // ToolInfo describes a tool that can be passed to a ChatModel via // [ToolCallingChatModel.WithTools] or [ChatModel.BindTools]. // From 7a5087f9112361bfd326e227dacb8cb8a2e5c6cf Mon Sep 17 00:00:00 2001 From: mrh997 Date: Tue, 6 Jan 2026 16:48:56 +0800 Subject: [PATCH 17/59] fix: concat agentic messages (#604) --- components/model/callback_extra.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/components/model/callback_extra.go b/components/model/callback_extra.go index 2767e2e5e..afff3f0a7 100644 --- a/components/model/callback_extra.go +++ b/components/model/callback_extra.go @@ -29,10 +29,10 @@ type TokenUsage struct { PromptTokenDetails PromptTokenDetails // CompletionTokens is the number of completion tokens. CompletionTokens int - // CompletionTokensDetails is a breakdown of the completion tokens. - CompletionTokensDetails CompletionTokensDetails // TotalTokens is the total number of tokens. TotalTokens int + // CompletionTokensDetails is a breakdown of the completion tokens. + CompletionTokensDetails CompletionTokensDetails } type CompletionTokensDetails struct { From 19bd5a38e37abe6d678d3e66c68ee6076bc4b017 Mon Sep 17 00:00:00 2001 From: Megumin Date: Thu, 8 Jan 2026 15:34:53 +0800 Subject: [PATCH 18/59] fix(schema): agentic concat support extra (#670) --- schema/agentic_message.go | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index b2225b2c7..2ba02b689 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -695,8 +695,10 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { role AgenticRoleType blocks []*ContentBlock metas []*AgenticResponseMeta + extra map[string]any blockIndices []int indexToBlocks = map[int][]*ContentBlock{} + extraList = make([]map[string]any, 0, len(msgs)) ) if len(msgs) == 1 { @@ -747,6 +749,10 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { if msg.ResponseMeta != nil { metas = append(metas, msg.ResponseMeta) } + + if msg.Extra != nil { + extraList = append(extraList, msg.Extra) + } } meta, err := concatAgenticResponseMeta(metas) @@ -758,7 +764,8 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { // All blocks are streaming, concat each group by index indexToBlock := map[int]*ContentBlock{} for idx, bs := range indexToBlocks { - b, err := concatChunksOfSameContentBlock(bs) + var b *ContentBlock + b, err = concatChunksOfSameContentBlock(bs) if err != nil { return nil, err } @@ -773,10 +780,18 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { } } + if len(extraList) > 0 { + extra, err = concatExtra(extraList) + if err != nil { + return nil, err + } + } + return &AgenticMessage{ Role: role, ResponseMeta: meta, ContentBlocks: blocks, + Extra: extra, }, nil } From 9b5035f89024da3a9895af6572d551f3dac4349e Mon Sep 17 00:00:00 2001 From: Megumin Date: Thu, 8 Jan 2026 19:36:17 +0800 Subject: [PATCH 19/59] feat(schema): optimize agent message format (#671) --- schema/agentic_message.go | 28 +++- schema/agentic_message_test.go | 270 +++++++++++++++++---------------- 2 files changed, 164 insertions(+), 134 deletions(-) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 2ba02b689..5008f0c75 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -18,6 +18,7 @@ package schema import ( "context" + "encoding/json" "fmt" "reflect" "sort" @@ -1834,7 +1835,7 @@ func (u *UserInputText) String() string { // String returns the string representation of UserInputImage. func (u *UserInputImage) String() string { - return formatMediaString(u.URL, u.Base64Data, u.MIMEType, u.Detail) + return formatMediaString(u.URL, u.Base64Data, u.MIMEType, string(u.Detail)) } // String returns the string representation of UserInputAudio. @@ -1902,7 +1903,7 @@ func (s *ServerToolCall) String() string { if s.CallID != "" { sb.WriteString(fmt.Sprintf(" call_id: %s\n", s.CallID)) } - sb.WriteString(fmt.Sprintf(" arguments: %v\n", s.Arguments)) + sb.WriteString(fmt.Sprintf(" arguments: %s\n", printAny(s.Arguments))) return sb.String() } @@ -1913,7 +1914,7 @@ func (s *ServerToolResult) String() string { if s.CallID != "" { sb.WriteString(fmt.Sprintf(" call_id: %s\n", s.CallID)) } - sb.WriteString(fmt.Sprintf(" result: %v\n", s.Result)) + sb.WriteString(fmt.Sprintf(" result: %s\n", printAny(s.Result))) return sb.String() } @@ -1996,7 +1997,7 @@ func truncateString(s string, maxLen int) string { } // formatMediaString formats URL, Base64Data, MIMEType and Detail for media content -func formatMediaString(url, base64Data string, mimeType string, detail any) string { +func formatMediaString(url, base64Data string, mimeType string, detail string) string { sb := &strings.Builder{} if url != "" { sb.WriteString(fmt.Sprintf(" url: %s\n", truncateString(url, 100))) @@ -2008,8 +2009,8 @@ func formatMediaString(url, base64Data string, mimeType string, detail any) stri if mimeType != "" { sb.WriteString(fmt.Sprintf(" mime_type: %s\n", mimeType)) } - if detail != nil && detail != "" { - sb.WriteString(fmt.Sprintf(" detail: %v\n", detail)) + if detail != "" { + sb.WriteString(fmt.Sprintf(" detail: %s\n", detail)) } return sb.String() } @@ -2027,3 +2028,18 @@ func validateExtensionType(expected reflect.Type, actual any) (reflect.Type, boo } return expected, true } + +func printAny(a any) string { + switch v := a.(type) { + case string: + return v + case fmt.Stringer: + return v.String() + default: + b, err := json.MarshalIndent(a, "", " ") + if err != nil { + return fmt.Sprintf("%v", a) + } + return string(b) + } +} diff --git a/schema/agentic_message_test.go b/schema/agentic_message_test.go index 016aa5c4e..4beb74930 100644 --- a/schema/agentic_message_test.go +++ b/schema/agentic_message_test.go @@ -1234,6 +1234,61 @@ func TestAgenticMessageString(t *testing.T) { Detail: ImageURLDetailHigh, }, }, + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: "http://audio.com", + Base64Data: "audio_data", + MIMEType: "audio/mp3", + }, + }, + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: "http://video.com", + Base64Data: "video_data", + MIMEType: "video/mp4", + }, + }, + { + Type: ContentBlockTypeUserInputFile, + UserInputFile: &UserInputFile{ + URL: "http://file.com", + Name: "file.txt", + Base64Data: "file_data", + MIMEType: "text/plain", + }, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "I'll check the current weather in New York City for you.", + }, + }, + { + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + URL: "http://gen_image.com", + Base64Data: "gen_image_data", + MIMEType: "image/png", + }, + }, + { + Type: ContentBlockTypeAssistantGenAudio, + AssistantGenAudio: &AssistantGenAudio{ + URL: "http://gen_audio.com", + Base64Data: "gen_audio_data", + MIMEType: "audio/wav", + }, + }, + { + Type: ContentBlockTypeAssistantGenVideo, + AssistantGenVideo: &AssistantGenVideo{ + URL: "http://gen_video.com", + Base64Data: "gen_video_data", + MIMEType: "video/mp4", + }, + }, { Type: ContentBlockTypeReasoning, Reasoning: &Reasoning{ @@ -1245,12 +1300,6 @@ func TestAgenticMessageString(t *testing.T) { EncryptedContent: "encrypted_reasoning_content_that_is_very_long_and_will_be_truncated_for_display", }, }, - { - Type: ContentBlockTypeAssistantGenText, - AssistantGenText: &AssistantGenText{ - Text: "I'll check the current weather in New York City for you.", - }, - }, { Type: ContentBlockTypeFunctionToolCall, FunctionToolCall: &FunctionToolCall{ @@ -1268,6 +1317,39 @@ func TestAgenticMessageString(t *testing.T) { Result: `{"temperature":72,"condition":"sunny","humidity":45,"wind_speed":8}`, }, }, + { + Type: ContentBlockTypeServerToolCall, + ServerToolCall: &ServerToolCall{ + Name: "server_tool", + CallID: "call_1", + Arguments: map[string]any{"a": 1}, + }, + }, + { + Type: ContentBlockTypeServerToolResult, + ServerToolResult: &ServerToolResult{ + Name: "server_tool", + CallID: "call_1", + Result: map[string]any{"success": true}, + }, + }, + { + Type: ContentBlockTypeMCPToolApprovalRequest, + MCPToolApprovalRequest: &MCPToolApprovalRequest{ + ID: "req_1", + Name: "mcp_tool", + ServerLabel: "mcp_server", + Arguments: "{}", + }, + }, + { + Type: ContentBlockTypeMCPToolApprovalResponse, + MCPToolApprovalResponse: &MCPToolApprovalResponse{ + ApprovalRequestID: "req_1", + Approve: true, + Reason: "looks good", + }, + }, { Type: ContentBlockTypeMCPToolCall, MCPToolCall: &MCPToolCall{ @@ -1322,34 +1404,80 @@ content_blocks: base64_data: iVBORw0KGgoAAAANSUhE...... (96 bytes) mime_type: image/jpeg detail: high - [2] type: reasoning + [2] type: user_input_audio + url: http://audio.com + base64_data: audio_data... (10 bytes) + mime_type: audio/mp3 + [3] type: user_input_video + url: http://video.com + base64_data: video_data... (10 bytes) + mime_type: video/mp4 + [4] type: user_input_file + name: file.txt + url: http://file.com + base64_data: file_data... (9 bytes) + mime_type: text/plain + [5] type: assistant_gen_text + text: I'll check the current weather in New York City for you. + [6] type: assistant_gen_image + url: http://gen_image.com + base64_data: gen_image_data... (14 bytes) + mime_type: image/png + [7] type: assistant_gen_audio + url: http://gen_audio.com + base64_data: gen_audio_data... (14 bytes) + mime_type: audio/wav + [8] type: assistant_gen_video + url: http://gen_video.com + base64_data: gen_video_data... (14 bytes) + mime_type: video/mp4 + [9] type: reasoning summary: 3 items [0] First, I need to identify the location (New York City) from the user's query. [1] Then, I should call the weather API to get current conditions. [2] Finally, I'll format the response in a user-friendly way with temperature and conditions. encrypted_content: encrypted_reasoning_content_that_is_very_long_and_... - [3] type: assistant_gen_text - text: I'll check the current weather in New York City for you. - [4] type: function_tool_call + [10] type: function_tool_call call_id: call_weather_123 name: get_current_weather arguments: {"location":"New York City","unit":"fahrenheit"} stream_index: 0 - [5] type: function_tool_result + [11] type: function_tool_result call_id: call_weather_123 name: get_current_weather result: {"temperature":72,"condition":"sunny","humidity":45,"wind_speed":8} - [6] type: mcp_tool_call + [12] type: server_tool_call + name: server_tool + call_id: call_1 + arguments: { + "a": 1 +} + [13] type: server_tool_result + name: server_tool + call_id: call_1 + result: { + "success": true +} + [14] type: mcp_tool_approval_request + server_label: mcp_server + id: req_1 + name: mcp_tool + arguments: {} + [15] type: mcp_tool_approval_response + approval_request_id: req_1 + approve: true + reason: looks good + [16] type: mcp_tool_call server_label: weather-mcp-server call_id: mcp_forecast_456 name: get_7day_forecast arguments: {"city":"New York","days":7} - [7] type: mcp_tool_result + [17] type: mcp_tool_result call_id: mcp_forecast_456 name: get_7day_forecast result: {"status":"partial","days_available":3} error: [503] Service temporarily unavailable for full 7-day forecast - [8] type: mcp_list_tools_result + [18] type: mcp_list_tools_result server_label: weather-mcp-server tools: 3 items - get_current_weather: Get current weather conditions for a location @@ -1359,120 +1487,6 @@ response_meta: token_usage: prompt=250, completion=180, total=430 `, output) - t.Run("full fields", func(t *testing.T) { - msg := &AgenticMessage{ - Role: AgenticRoleTypeSystem, - ContentBlocks: []*ContentBlock{ - { - Type: ContentBlockTypeUserInputAudio, - UserInputAudio: &UserInputAudio{ - URL: "http://audio.com", - Base64Data: "audio_data", - MIMEType: "audio/mp3", - }, - }, - { - Type: ContentBlockTypeUserInputVideo, - UserInputVideo: &UserInputVideo{ - URL: "http://video.com", - Base64Data: "video_data", - MIMEType: "video/mp4", - }, - }, - { - Type: ContentBlockTypeUserInputFile, - UserInputFile: &UserInputFile{ - URL: "http://file.com", - Name: "file.txt", - Base64Data: "file_data", - MIMEType: "text/plain", - }, - }, - { - Type: ContentBlockTypeAssistantGenImage, - AssistantGenImage: &AssistantGenImage{ - URL: "http://gen_image.com", - Base64Data: "gen_image_data", - MIMEType: "image/png", - }, - }, - { - Type: ContentBlockTypeAssistantGenAudio, - AssistantGenAudio: &AssistantGenAudio{ - URL: "http://gen_audio.com", - Base64Data: "gen_audio_data", - MIMEType: "audio/wav", - }, - }, - { - Type: ContentBlockTypeAssistantGenVideo, - AssistantGenVideo: &AssistantGenVideo{ - URL: "http://gen_video.com", - Base64Data: "gen_video_data", - MIMEType: "video/mp4", - }, - }, - { - Type: ContentBlockTypeServerToolCall, - ServerToolCall: &ServerToolCall{ - Name: "server_tool", - CallID: "call_1", - Arguments: map[string]any{"a": 1}, - }, - }, - { - Type: ContentBlockTypeServerToolResult, - ServerToolResult: &ServerToolResult{ - Name: "server_tool", - CallID: "call_1", - Result: map[string]any{"success": true}, - }, - }, - { - Type: ContentBlockTypeMCPToolApprovalRequest, - MCPToolApprovalRequest: &MCPToolApprovalRequest{ - ID: "req_1", - Name: "mcp_tool", - ServerLabel: "mcp_server", - Arguments: "{}", - }, - }, - { - Type: ContentBlockTypeMCPToolApprovalResponse, - MCPToolApprovalResponse: &MCPToolApprovalResponse{ - ApprovalRequestID: "req_1", - Approve: true, - Reason: "looks good", - }, - }, - }, - } - - s := msg.String() - assert.Contains(t, s, "role: system") - assert.Contains(t, s, "type: user_input_audio") - assert.Contains(t, s, "http://audio.com") - assert.Contains(t, s, "type: user_input_video") - assert.Contains(t, s, "http://video.com") - assert.Contains(t, s, "type: user_input_file") - assert.Contains(t, s, "file.txt") - assert.Contains(t, s, "type: assistant_gen_image") - assert.Contains(t, s, "http://gen_image.com") - assert.Contains(t, s, "type: assistant_gen_audio") - assert.Contains(t, s, "http://gen_audio.com") - assert.Contains(t, s, "type: assistant_gen_video") - assert.Contains(t, s, "http://gen_video.com") - assert.Contains(t, s, "type: server_tool_call") - assert.Contains(t, s, "server_tool") - assert.Contains(t, s, "map[a:1]") - assert.Contains(t, s, "type: server_tool_result") - assert.Contains(t, s, "map[success:true]") - assert.Contains(t, s, "type: mcp_tool_approval_request") - assert.Contains(t, s, "req_1") - assert.Contains(t, s, "type: mcp_tool_approval_response") - assert.Contains(t, s, "looks good") - }) - t.Run("nil/empty fields", func(t *testing.T) { msg := &AgenticMessage{ Role: AgenticRoleTypeUser, From 86d23bce167ce2bdcd69a4cd07669a54e416820e Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 12 Jan 2026 12:08:31 +0800 Subject: [PATCH 20/59] fix: openai ConcatResponseMetaExtensions (#678) --- schema/openai/extension.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/schema/openai/extension.go b/schema/openai/extension.go index c30d2d8ec..1e10c411e 100644 --- a/schema/openai/extension.go +++ b/schema/openai/extension.go @@ -200,6 +200,12 @@ func ConcatResponseMetaExtensions(chunks []*ResponseMetaExtension) (*ResponseMet if ext.ServiceTier != "" { ret.ServiceTier = ext.ServiceTier } + if ext.CreatedAt != 0 { + ret.CreatedAt = ext.CreatedAt + } + if ext.PromptCacheRetention != "" { + ret.PromptCacheRetention = ext.PromptCacheRetention + } } return ret, nil From 8ed1bfe471f3be0fd98c9b6b25fe178bca5d958b Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 12 Jan 2026 18:04:30 +0800 Subject: [PATCH 21/59] feat: improve comment (#679) --- components/agentic/interface.go | 3 + schema/agentic_message.go | 164 +++++++++++++++++++++++++------- schema/tool.go | 1 + 3 files changed, 131 insertions(+), 37 deletions(-) diff --git a/components/agentic/interface.go b/components/agentic/interface.go index e9960d332..e62a8eeab 100644 --- a/components/agentic/interface.go +++ b/components/agentic/interface.go @@ -25,5 +25,8 @@ import ( type Model interface { Generate(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.AgenticMessage, error) Stream(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.StreamReader[*schema.AgenticMessage], error) + + // WithTools returns a new Model instance with the specified tools bound. + // This method does not modify the current instance, making it safer for concurrent use. WithTools(tools []*schema.ToolInfo) (Model, error) } diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 5008f0c75..63c78c3eb 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -66,122 +66,208 @@ const ( ) type AgenticMessage struct { + // ResponseMeta is the response metadata. ResponseMeta *AgenticResponseMeta - Role AgenticRoleType + // Role is the message role. + Role AgenticRoleType + + // ContentBlocks is the list of content blocks. ContentBlocks []*ContentBlock + // Extra is the additional information. Extra map[string]any } type AgenticResponseMeta struct { + // TokenUsage is the token usage. TokenUsage *TokenUsage + // OpenAIExtension is the extension for OpenAI. OpenAIExtension *openai.ResponseMetaExtension + + // GeminiExtension is the extension for Gemini. GeminiExtension *gemini.ResponseMetaExtension + + // ClaudeExtension is the extension for Claude. ClaudeExtension *claude.ResponseMetaExtension - Extension any -} -type StreamingMeta struct { - // Index specifies the index position of this block in the final response. - Index int + // Extension is the extension for other models, supplied by the component implementer. + Extension any } type ContentBlock struct { Type ContentBlockType + // Reasoning contains the reasoning content generated by the model. Reasoning *Reasoning - UserInputText *UserInputText + // UserInputText contains the text content provided by the user. + UserInputText *UserInputText + + // UserInputImage contains the image content provided by the user. UserInputImage *UserInputImage + + // UserInputAudio contains the audio content provided by the user. UserInputAudio *UserInputAudio + + // UserInputVideo contains the video content provided by the user. UserInputVideo *UserInputVideo - UserInputFile *UserInputFile - AssistantGenText *AssistantGenText + // UserInputFile contains the file content provided by the user. + UserInputFile *UserInputFile + + // AssistantGenText contains the text content generated by the model. + AssistantGenText *AssistantGenText + + // AssistantGenImage contains the image content generated by the model. AssistantGenImage *AssistantGenImage + + // AssistantGenAudio contains the audio content generated by the model. AssistantGenAudio *AssistantGenAudio + + // AssistantGenVideo contains the video content generated by the model. AssistantGenVideo *AssistantGenVideo - // FunctionToolCall holds invocation details for a user-defined tool. + // FunctionToolCall contains the invocation details for a user-defined tool. FunctionToolCall *FunctionToolCall - // FunctionToolResult is the result from a user-defined tool call. + + // FunctionToolResult contains the result returned from a user-defined tool call. FunctionToolResult *FunctionToolResult - // ServerToolCall holds invocation details for a provider built-in tool run on the model server. + + // ServerToolCall contains the invocation details for a provider built-in tool executed on the model server. ServerToolCall *ServerToolCall - // ServerToolResult is the result from a provider built-in tool run on the model server. + + // ServerToolResult contains the result returned from a provider built-in tool executed on the model server. ServerToolResult *ServerToolResult - // MCPToolCall holds invocation details for an MCP tool managed by the model server. + // MCPToolCall contains the invocation details for an MCP tool managed by the model server. MCPToolCall *MCPToolCall - // MCPToolResult is the result from an MCP tool managed by the model server. + + // MCPToolResult contains the result returned from an MCP tool managed by the model server. MCPToolResult *MCPToolResult - // MCPListToolsResult lists available MCP tools reported by the model server. + + // MCPListToolsResult contains the list of available MCP tools reported by the model server. MCPListToolsResult *MCPListToolsResult - // MCPToolApprovalRequest requests user approval for an MCP tool call when required. + + // MCPToolApprovalRequest contains the user approval request for an MCP tool call when required. MCPToolApprovalRequest *MCPToolApprovalRequest - // MCPToolApprovalResponse records the user's approval decision for an MCP tool call. + + // MCPToolApprovalResponse contains the user's approval decision for an MCP tool call. MCPToolApprovalResponse *MCPToolApprovalResponse + // StreamingMeta contains metadata for streaming responses. StreamingMeta *StreamingMeta - Extra map[string]any + + // Extra contains additional information for the content block. + Extra map[string]any +} + +type StreamingMeta struct { + // Index specifies the index position of this block in the final response. + Index int } type UserInputText struct { + // Text is the text content. Text string } type UserInputImage struct { - URL string + // URL is the HTTP/HTTPS link. + URL string + + // Base64Data is the binary data in Base64 encoded string format. Base64Data string - MIMEType string - Detail ImageURLDetail + + // MIMEType is the mime type, e.g. "image/png". + MIMEType string + + // Detail is the quality of the image url. + Detail ImageURLDetail } type UserInputAudio struct { - URL string + // URL is the HTTP/HTTPS link. + URL string + + // Base64Data is the binary data in Base64 encoded string format. Base64Data string - MIMEType string + + // MIMEType is the mime type, e.g. "audio/wav". + MIMEType string } type UserInputVideo struct { - URL string + // URL is the HTTP/HTTPS link. + URL string + + // Base64Data is the binary data in Base64 encoded string format. Base64Data string - MIMEType string + + // MIMEType is the mime type, e.g. "video/mp4". + MIMEType string } type UserInputFile struct { - URL string - Name string + // URL is the HTTP/HTTPS link. + URL string + + // Name is the filename. + Name string + + // Base64Data is the binary data in Base64 encoded string format. Base64Data string - MIMEType string + + // MIMEType is the mime type, e.g. "application/pdf". + MIMEType string } type AssistantGenText struct { + // Text is the generated text. Text string + // OpenAIExtension is the extension for OpenAI. OpenAIExtension *openai.AssistantGenTextExtension + + // ClaudeExtension is the extension for Claude. ClaudeExtension *claude.AssistantGenTextExtension - Extension any + + // Extension is the extension for other models, supplied by the component implementer. + Extension any } type AssistantGenImage struct { - URL string + // URL is the HTTP/HTTPS link. + URL string + + // Base64Data is the binary data in Base64 encoded string format. Base64Data string - MIMEType string + + // MIMEType is the mime type, e.g. "image/png". + MIMEType string } type AssistantGenAudio struct { - URL string + // URL is the HTTP/HTTPS link. + URL string + + // Base64Data is the binary data in Base64 encoded string format. Base64Data string - MIMEType string + + // MIMEType is the mime type, e.g. "audio/wav". + MIMEType string } type AssistantGenVideo struct { - URL string + // URL is the HTTP/HTTPS link. + URL string + + // Base64Data is the binary data in Base64 encoded string format. Base64Data string - MIMEType string + + // MIMEType is the mime type, e.g. "video/mp4". + MIMEType string } type Reasoning struct { @@ -196,6 +282,7 @@ type ReasoningSummary struct { // Index specifies the index position of this summary in the final Reasoning. Index int + // Text is the reasoning content summary. Text string } @@ -284,7 +371,10 @@ type MCPToolResult struct { } type MCPToolCallError struct { - Code *int64 + // Code is the error code. + Code *int64 + + // Message is the error message. Message string } diff --git a/schema/tool.go b/schema/tool.go index 2d6bf90db..efed6d34b 100644 --- a/schema/tool.go +++ b/schema/tool.go @@ -67,6 +67,7 @@ type AllowedTool struct { ServerTool *AllowedServerTool } + type AllowedMCPTool struct { // ServerLabel is the label of the MCP server. ServerLabel string From 238a7a794cc1ee47467a34250572185a46122f6a Mon Sep 17 00:00:00 2001 From: mrh997 Date: Tue, 13 Jan 2026 21:41:07 +0800 Subject: [PATCH 22/59] feat: add agentic callbacks template (#681) --- components/agentic/callback_extra.go | 85 ----- components/agentic/callback_extra_test.go | 35 -- components/agentic/interface.go | 32 -- components/agentic/option.go | 142 -------- components/agentic/option_test.go | 79 ----- components/model/callback_extra.go | 18 +- components/model/interface.go | 12 + components/model/option.go | 31 +- components/model/option_test.go | 16 + ...te_agentic.go => agentic_chat_template.go} | 16 +- .../prompt/agentic_chat_template_test.go | 124 +++++++ components/prompt/callback_extra.go | 50 +-- components/prompt/callback_extra_test.go | 21 +- .../prompt/chat_template_agentic_test.go | 111 ------- components/prompt/interface.go | 1 + ..._node_agentic.go => agentic_tools_node.go} | 0 ...tic_test.go => agentic_tools_node_test.go} | 0 compose/chain.go | 3 +- compose/chain_branch.go | 3 +- compose/chain_parallel.go | 3 +- compose/component_to_graph_node.go | 3 +- compose/graph.go | 3 +- schema/agentic_message.go | 8 +- utils/callbacks/template.go | 176 +++++++++- utils/callbacks/template_test.go | 304 +++++++++++++++++- 25 files changed, 694 insertions(+), 582 deletions(-) delete mode 100644 components/agentic/callback_extra.go delete mode 100644 components/agentic/callback_extra_test.go delete mode 100644 components/agentic/interface.go delete mode 100644 components/agentic/option.go delete mode 100644 components/agentic/option_test.go rename components/prompt/{chat_template_agentic.go => agentic_chat_template.go} (87%) create mode 100644 components/prompt/agentic_chat_template_test.go delete mode 100644 components/prompt/chat_template_agentic_test.go rename compose/{tools_node_agentic.go => agentic_tools_node.go} (100%) rename compose/{tools_node_agentic_test.go => agentic_tools_node_test.go} (100%) diff --git a/components/agentic/callback_extra.go b/components/agentic/callback_extra.go deleted file mode 100644 index 2c5a656fa..000000000 --- a/components/agentic/callback_extra.go +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Package agentic defines callback payloads and configuration types for agentic models. -package agentic - -import ( - "github.com/cloudwego/eino/callbacks" - "github.com/cloudwego/eino/schema" -) - -// Config is the config for the model. -type Config struct { - // Model is the model name. - Model string - // Temperature is the temperature, which controls the randomness of the model. - Temperature float64 - // TopP is the top p, which controls the diversity of the model. - TopP float64 -} - -// CallbackInput is the input for the model callback. -type CallbackInput struct { - // Messages is the messages to be sent to the model. - Messages []*schema.AgenticMessage - // Tools is the tools to be used in the model. - Tools []*schema.ToolInfo - // ToolChoice controls which tool is called by the model. - ToolChoice *schema.ToolChoice - // Config is the config for the model. - Config *Config - // Extra is the extra information for the callback. - Extra map[string]any -} - -// CallbackOutput is the output for the model callback. -type CallbackOutput struct { - // Message is the message generated by the model. - Message *schema.AgenticMessage - // Config is the config for the model. - Config *Config - // Extra is the extra information for the callback. - Extra map[string]any -} - -// ConvCallbackInput converts the callback input to the model callback input. -func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { - switch t := src.(type) { - case *CallbackInput: // when callback is triggered within component implementation, the input is usually already a typed *model.CallbackInput - return t - case []*schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the input is the input of Chat Model interface, which is []*schema.AgenticMessage - return &CallbackInput{ - Messages: t, - } - default: - return nil - } -} - -// ConvCallbackOutput converts the callback output to the model callback output. -func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { - switch t := src.(type) { - case *CallbackOutput: // when callback is triggered within component implementation, the output is usually already a typed *model.CallbackOutput - return t - case *schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the output is the output of Chat Model interface, which is *schema.AgenticMessage - return &CallbackOutput{ - Message: t, - } - default: - return nil - } -} diff --git a/components/agentic/callback_extra_test.go b/components/agentic/callback_extra_test.go deleted file mode 100644 index a77da6cd2..000000000 --- a/components/agentic/callback_extra_test.go +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package agentic - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/cloudwego/eino/schema" -) - -func TestConvModel(t *testing.T) { - assert.NotNil(t, ConvCallbackInput(&CallbackInput{})) - assert.NotNil(t, ConvCallbackInput([]*schema.AgenticMessage{})) - assert.Nil(t, ConvCallbackInput("asd")) - - assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{})) - assert.NotNil(t, ConvCallbackOutput(&schema.AgenticMessage{})) - assert.Nil(t, ConvCallbackOutput("asd")) -} diff --git a/components/agentic/interface.go b/components/agentic/interface.go deleted file mode 100644 index e62a8eeab..000000000 --- a/components/agentic/interface.go +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package agentic - -import ( - "context" - - "github.com/cloudwego/eino/schema" -) - -type Model interface { - Generate(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.AgenticMessage, error) - Stream(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.StreamReader[*schema.AgenticMessage], error) - - // WithTools returns a new Model instance with the specified tools bound. - // This method does not modify the current instance, making it safer for concurrent use. - WithTools(tools []*schema.ToolInfo) (Model, error) -} diff --git a/components/agentic/option.go b/components/agentic/option.go deleted file mode 100644 index d8873442a..000000000 --- a/components/agentic/option.go +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package agentic - -import ( - "github.com/cloudwego/eino/schema" -) - -// Options is the common options for the model. -type Options struct { - // Temperature is the temperature for the model, which controls the randomness of the model. - Temperature *float64 - // Model is the model name. - Model *string - // TopP is the top p for the model, which controls the diversity of the model. - TopP *float64 - // Tools is a list of tools the model may call. - Tools []*schema.ToolInfo - // ToolChoice controls how the model call the tools. - ToolChoice *schema.ToolChoice - // AllowedTools is a list of allowed tools the model may call. - AllowedTools []*schema.AllowedTool -} - -// Option is the call option for ChatModel component. -type Option struct { - apply func(opts *Options) - - implSpecificOptFn any -} - -// WithTemperature is the option to set the temperature for the model. -func WithTemperature(temperature float64) Option { - return Option{ - apply: func(opts *Options) { - opts.Temperature = &temperature - }, - } -} - -// WithModel is the option to set the model name. -func WithModel(name string) Option { - return Option{ - apply: func(opts *Options) { - opts.Model = &name - }, - } -} - -// WithTopP is the option to set the top p for the model. -func WithTopP(topP float64) Option { - return Option{ - apply: func(opts *Options) { - opts.TopP = &topP - }, - } -} - -// WithTools is the option to set tools for the model. -func WithTools(tools []*schema.ToolInfo) Option { - if tools == nil { - tools = []*schema.ToolInfo{} - } - return Option{ - apply: func(opts *Options) { - opts.Tools = tools - }, - } -} - -// WithToolChoice is the option to set tool choice for the model. -func WithToolChoice(toolChoice schema.ToolChoice, allowedTools ...*schema.AllowedTool) Option { - return Option{ - apply: func(opts *Options) { - opts.ToolChoice = &toolChoice - opts.AllowedTools = allowedTools - }, - } -} - -// WrapImplSpecificOptFn is the option to wrap the implementation specific option function. -func WrapImplSpecificOptFn[T any](optFn func(*T)) Option { - return Option{ - implSpecificOptFn: optFn, - } -} - -// GetCommonOptions extract model Options from Option list, optionally providing a base Options with default values. -func GetCommonOptions(base *Options, opts ...Option) *Options { - if base == nil { - base = &Options{} - } - - for i := range opts { - opt := opts[i] - if opt.apply != nil { - opt.apply(base) - } - } - - return base -} - -// GetImplSpecificOptions extract the implementation specific options from Option list, optionally providing a base options with default values. -// e.g. -// -// myOption := &MyOption{ -// Field1: "default_value", -// } -// -// myOption := model.GetImplSpecificOptions(myOption, opts...) -func GetImplSpecificOptions[T any](base *T, opts ...Option) *T { - if base == nil { - base = new(T) - } - - for i := range opts { - opt := opts[i] - if opt.implSpecificOptFn != nil { - optFn, ok := opt.implSpecificOptFn.(func(*T)) - if ok { - optFn(base) - } - } - } - - return base -} diff --git a/components/agentic/option_test.go b/components/agentic/option_test.go deleted file mode 100644 index 2c5bac652..000000000 --- a/components/agentic/option_test.go +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package agentic - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/cloudwego/eino/schema" -) - -func TestCommon(t *testing.T) { - o := GetCommonOptions(nil, - WithTools([]*schema.ToolInfo{{Name: "test"}}), - WithModel("test"), - WithTemperature(0.1), - WithToolChoice(schema.ToolChoiceAllowed, []*schema.AllowedTool{{FunctionToolName: "test"}}...), - WithTopP(0.1), - ) - assert.Len(t, o.Tools, 1) - assert.Equal(t, "test", o.Tools[0].Name) - assert.Equal(t, "test", *o.Model) - assert.Equal(t, float64(0.1), *o.Temperature) - assert.Equal(t, schema.ToolChoiceAllowed, *o.ToolChoice) - assert.Equal(t, float64(0.1), *o.TopP) -} - -func TestImplSpecificOpts(t *testing.T) { - type implSpecificOptions struct { - conf string - index int - } - - withConf := func(conf string) func(o *implSpecificOptions) { - return func(o *implSpecificOptions) { - o.conf = conf - } - } - - withIndex := func(index int) func(o *implSpecificOptions) { - return func(o *implSpecificOptions) { - o.index = index - } - } - - documentOption1 := WrapImplSpecificOptFn(withConf("test_conf")) - documentOption2 := WrapImplSpecificOptFn(withIndex(1)) - - implSpecificOpts := GetImplSpecificOptions(&implSpecificOptions{}, documentOption1, documentOption2) - - assert.Equal(t, &implSpecificOptions{ - conf: "test_conf", - index: 1, - }, implSpecificOpts) - documentOption1 = WrapImplSpecificOptFn(withConf("test_conf")) - documentOption2 = WrapImplSpecificOptFn(withIndex(1)) - - implSpecificOpts = GetImplSpecificOptions(&implSpecificOptions{}, documentOption1, documentOption2) - - assert.Equal(t, &implSpecificOptions{ - conf: "test_conf", - index: 1, - }, implSpecificOpts) -} diff --git a/components/model/callback_extra.go b/components/model/callback_extra.go index afff3f0a7..ed9096d5c 100644 --- a/components/model/callback_extra.go +++ b/components/model/callback_extra.go @@ -31,15 +31,15 @@ type TokenUsage struct { CompletionTokens int // TotalTokens is the total number of tokens. TotalTokens int - // CompletionTokensDetails is a breakdown of the completion tokens. - CompletionTokensDetails CompletionTokensDetails + // CompletionTokensDetails is breakdown of completion tokens. + CompletionTokensDetails CompletionTokensDetails `json:"completion_token_details"` } type CompletionTokensDetails struct { // ReasoningTokens tokens generated by the model for reasoning. // This is currently supported by OpenAI, Gemini, ARK and Qwen chat models. // For other models, this field will be 0. - ReasoningTokens int + ReasoningTokens int `json:"reasoning_tokens,omitempty"` } // PromptTokenDetails provides a breakdown of prompt token usage. @@ -66,6 +66,8 @@ type Config struct { type CallbackInput struct { // Messages is the messages to be sent to the model. Messages []*schema.Message + // AgenticMessages is the agentic messages to be sent to the agentic model. + AgenticMessages []*schema.AgenticMessage // Tools is the tools to be used in the model. Tools []*schema.ToolInfo // ToolChoice is the tool choice, which controls the tool to be used in the model. @@ -80,6 +82,8 @@ type CallbackInput struct { type CallbackOutput struct { // Message is the message generated by the model. Message *schema.Message + // AgenticMessage is the agentic message generated by the agentic model. + AgenticMessage *schema.AgenticMessage // Config is the config for the model. Config *Config // TokenUsage is the token usage of this request. @@ -97,6 +101,10 @@ func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { return &CallbackInput{ Messages: t, } + case []*schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the input is the input of Agentic Model interface, which is []*schema.AgenticMessage + return &CallbackInput{ + AgenticMessages: t, + } default: return nil } @@ -111,6 +119,10 @@ func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { return &CallbackOutput{ Message: t, } + case *schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the output is the output of Agentic Model interface, which is *schema.AgenticMessage + return &CallbackOutput{ + AgenticMessage: t, + } default: return nil } diff --git a/components/model/interface.go b/components/model/interface.go index deb7b56dd..cf79785bc 100644 --- a/components/model/interface.go +++ b/components/model/interface.go @@ -89,3 +89,15 @@ type ToolCallingChatModel interface { // This method does not modify the current instance, making it safer for concurrent use. WithTools(tools []*schema.ToolInfo) (ToolCallingChatModel, error) } + +// AgenticModel defines the interface for agentic models that support AgenticMessage. +// It provides methods for generating complete and streaming outputs, and supports +// tool calling via the WithTools method. +type AgenticModel interface { + Generate(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.AgenticMessage, error) + Stream(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.StreamReader[*schema.AgenticMessage], error) + + // WithTools returns a new Model instance with the specified tools bound. + // This method does not modify the current instance, making it safer for concurrent use. + WithTools(tools []*schema.ToolInfo) (AgenticModel, error) +} diff --git a/components/model/option.go b/components/model/option.go index 9fd96116c..0173d22aa 100644 --- a/components/model/option.go +++ b/components/model/option.go @@ -22,21 +22,29 @@ import "github.com/cloudwego/eino/schema" type Options struct { // Temperature is the temperature for the model, which controls the randomness of the model. Temperature *float32 - // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return an finish reason of "length". - MaxTokens *int // Model is the model name. Model *string // TopP is the top p for the model, which controls the diversity of the model. TopP *float32 - // Stop is the stop words for the model, which controls the stopping condition of the model. - Stop []string // Tools is a list of tools the model may call. Tools []*schema.ToolInfo // ToolChoice controls which tool is called by the model. ToolChoice *schema.ToolChoice + + // Options only for chat model. + + // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return an finish reason of "length". + MaxTokens *int // AllowedToolNames specifies a list of tool names that the model is allowed to call. // This allows for constraining the model to a specific subset of the available tools. AllowedToolNames []string + // Stop is the stop words for the model, which controls the stopping condition of the model. + Stop []string + + // Options only for agentic model. + + // AllowedTools is a list of allowed tools the model may call. + AllowedTools []*schema.AllowedTool } // Option is a call-time option for a ChatModel. Options are immutable and @@ -59,6 +67,7 @@ func WithTemperature(temperature float32) Option { } // WithMaxTokens is the option to set the max tokens for the model. +// Only available for ChatModel. func WithMaxTokens(maxTokens int) Option { return Option{ apply: func(opts *Options) { @@ -86,6 +95,7 @@ func WithTopP(topP float32) Option { } // WithStop is the option to set the stop words for the model. +// Only available for ChatModel. func WithStop(stop []string) Option { return Option{ apply: func(opts *Options) { @@ -108,6 +118,7 @@ func WithTools(tools []*schema.ToolInfo) Option { // WithToolChoice sets the tool choice for the model. It also allows for providing a list of // tool names to constrain the model to a specific subset of the available tools. +// Only available for ChatModel. func WithToolChoice(toolChoice schema.ToolChoice, allowedToolNames ...string) Option { return Option{ apply: func(opts *Options) { @@ -117,6 +128,18 @@ func WithToolChoice(toolChoice schema.ToolChoice, allowedToolNames ...string) Op } } +// WithAgenticToolChoice is the option to set tool choice for the agentic model. +// Only available for AgenticModel. +func WithAgenticToolChoice(toolChoice schema.ToolChoice, allowedTools ...*schema.AllowedTool) Option { + return Option{ + apply: func(opts *Options) { + opts.ToolChoice = &toolChoice + opts.AllowedTools = allowedTools + }, + } +} + +// WrapImplSpecificOptFn is the option to wrap the implementation specific option function. // WrapImplSpecificOptFn wraps an implementation-specific option function into // an [Option] so it can be passed alongside standard options. // diff --git a/components/model/option_test.go b/components/model/option_test.go index 36872c30e..bfacdd17c 100644 --- a/components/model/option_test.go +++ b/components/model/option_test.go @@ -82,6 +82,22 @@ func TestOptions(t *testing.T) { convey.So(opts.Tools, convey.ShouldNotBeNil) convey.So(len(opts.Tools), convey.ShouldEqual, 0) }) + + convey.Convey("test agentic tool choice option", t, func() { + var ( + toolChoice = schema.ToolChoiceForced + allowedTools = []*schema.AllowedTool{ + {FunctionToolName: "agentic_tool"}, + } + ) + opts := GetCommonOptions( + nil, + WithAgenticToolChoice(toolChoice, allowedTools...), + ) + + convey.So(opts.ToolChoice, convey.ShouldResemble, &toolChoice) + convey.So(opts.AllowedTools, convey.ShouldResemble, allowedTools) + }) } type implOption struct { diff --git a/components/prompt/chat_template_agentic.go b/components/prompt/agentic_chat_template.go similarity index 87% rename from components/prompt/chat_template_agentic.go rename to components/prompt/agentic_chat_template.go index 937d46f26..512a60ecd 100644 --- a/components/prompt/chat_template_agentic.go +++ b/components/prompt/agentic_chat_template.go @@ -1,5 +1,5 @@ /* - * Copyright 2025 CloudWeGo Authors + * Copyright 2026 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -45,9 +45,9 @@ type DefaultAgenticChatTemplate struct { func (t *DefaultAgenticChatTemplate) Format(ctx context.Context, vs map[string]any, opts ...Option) (result []*schema.AgenticMessage, err error) { ctx = callbacks.EnsureRunInfo(ctx, t.GetType(), components.ComponentOfAgenticPrompt) - ctx = callbacks.OnStart(ctx, &AgenticCallbackInput{ - Variables: vs, - Templates: t.templates, + ctx = callbacks.OnStart(ctx, &CallbackInput{ + Variables: vs, + AgenticTemplates: t.templates, }) defer func() { if err != nil { @@ -65,15 +65,15 @@ func (t *DefaultAgenticChatTemplate) Format(ctx context.Context, vs map[string]a result = append(result, msgs...) } - _ = callbacks.OnEnd(ctx, &AgenticCallbackOutput{ - Result: result, - Templates: t.templates, + _ = callbacks.OnEnd(ctx, &CallbackOutput{ + AgenticResult: result, + AgenticTemplates: t.templates, }) return result, nil } -// GetType returns the type of the chat template (Default). +// GetType returns the type of the agentic template (DefaultAgentic). func (t *DefaultAgenticChatTemplate) GetType() string { return "Default" } diff --git a/components/prompt/agentic_chat_template_test.go b/components/prompt/agentic_chat_template_test.go new file mode 100644 index 000000000..42d7a8630 --- /dev/null +++ b/components/prompt/agentic_chat_template_test.go @@ -0,0 +1,124 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "context" + "errors" + "testing" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" + "github.com/stretchr/testify/assert" +) + +type mockAgenticTemplate struct { + err error +} + +func (m *mockAgenticTemplate) Format(ctx context.Context, vs map[string]any, formatType schema.FormatType) ([]*schema.AgenticMessage, error) { + if m.err != nil { + return nil, m.err + } + return []*schema.AgenticMessage{schema.UserAgenticMessage("mocked")}, nil +} + +func TestFromAgenticMessages(t *testing.T) { + t.Run("create template", func(t *testing.T) { + tpl := schema.UserAgenticMessage("hello") + ft := schema.FString + at := FromAgenticMessages(ft, tpl) + + assert.NotNil(t, at) + assert.Equal(t, ft, at.formatType) + assert.Len(t, at.templates, 1) + assert.Same(t, tpl, at.templates[0]) + }) +} + +func TestDefaultAgenticTemplate_GetType(t *testing.T) { + t.Run("get type", func(t *testing.T) { + at := &DefaultAgenticChatTemplate{} + assert.Equal(t, "Default", at.GetType()) + }) +} + +func TestDefaultAgenticTemplate_IsCallbacksEnabled(t *testing.T) { + t.Run("callbacks enabled", func(t *testing.T) { + at := &DefaultAgenticChatTemplate{} + assert.True(t, at.IsCallbacksEnabled()) + }) +} + +func TestDefaultAgenticTemplate_Format(t *testing.T) { + t.Run("success", func(t *testing.T) { + // Mock callback handler + cb := callbacks.NewHandlerBuilder(). + OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { + assert.Equal(t, "Default", info.Type) + return ctx + }). + OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { + assert.Equal(t, "Default", info.Type) + return ctx + }). + OnErrorFn(func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + assert.Fail(t, "unexpected error callback") + return ctx + }). + Build() + + tpl := schema.UserAgenticMessage("hello {val}") + at := FromAgenticMessages(schema.FString, tpl) + + ctx := context.Background() + ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{ + Type: "Default", + Component: "agentic_prompt", + }, cb) + + res, err := at.Format(ctx, map[string]any{"val": "world"}) + assert.NoError(t, err) + assert.Len(t, res, 1) + assert.Equal(t, "hello world", res[0].ContentBlocks[0].UserInputText.Text) + }) + + t.Run("template format error", func(t *testing.T) { + mockErr := errors.New("mock error") + mockTpl := &mockAgenticTemplate{err: mockErr} + at := FromAgenticMessages(schema.FString, mockTpl) + + // Mock callback handler to verify OnError + cb := callbacks.NewHandlerBuilder(). + OnErrorFn(func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + assert.Equal(t, mockErr, err) + return ctx + }). + Build() + + ctx := context.Background() + ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{ + Type: "Default", + Component: "agentic_prompt", + }, cb) + + res, err := at.Format(ctx, map[string]any{}) + assert.Error(t, err) + assert.Nil(t, res) + assert.Equal(t, mockErr, err) + }) +} diff --git a/components/prompt/callback_extra.go b/components/prompt/callback_extra.go index 3be780543..4c27f37c6 100644 --- a/components/prompt/callback_extra.go +++ b/components/prompt/callback_extra.go @@ -21,52 +21,14 @@ import ( "github.com/cloudwego/eino/schema" ) -type AgenticCallbackInput struct { - Variables map[string]any - Templates []schema.AgenticMessagesTemplate - Extra map[string]any -} - -type AgenticCallbackOutput struct { - Result []*schema.AgenticMessage - Templates []schema.AgenticMessagesTemplate - Extra map[string]any -} - -// ConvAgenticCallbackInput converts the callback input to the agentic callback input. -func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput { - switch t := src.(type) { - case *AgenticCallbackInput: - return t - case map[string]any: - return &AgenticCallbackInput{ - Variables: t, - } - default: - return nil - } -} - -// ConvAgenticCallbackOutput converts the callback output to the agentic callback output. -func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOutput { - switch t := src.(type) { - case *AgenticCallbackOutput: - return t - case []*schema.AgenticMessage: - return &AgenticCallbackOutput{ - Result: t, - } - default: - return nil - } -} - // CallbackInput is the input for the callback. type CallbackInput struct { // Variables is the variables for the callback. Variables map[string]any // Templates is the templates for the callback. Templates []schema.MessagesTemplate + // AgenticTemplates is the agentic templates for the callback. + AgenticTemplates []schema.AgenticMessagesTemplate // Extra is the extra information for the callback. Extra map[string]any } @@ -75,8 +37,12 @@ type CallbackInput struct { type CallbackOutput struct { // Result is the result for the callback. Result []*schema.Message + // AgenticResult is the agentic result for the callback. + AgenticResult []*schema.AgenticMessage // Templates is the templates for the callback. Templates []schema.MessagesTemplate + // AgenticTemplates is the agentic templates for the callback. + AgenticTemplates []schema.AgenticMessagesTemplate // Extra is the extra information for the callback. Extra map[string]any } @@ -104,6 +70,10 @@ func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { return &CallbackOutput{ Result: t, } + case []*schema.AgenticMessage: + return &CallbackOutput{ + AgenticResult: t, + } default: return nil } diff --git a/components/prompt/callback_extra_test.go b/components/prompt/callback_extra_test.go index 456297e29..4b48ec114 100644 --- a/components/prompt/callback_extra_test.go +++ b/components/prompt/callback_extra_test.go @@ -25,11 +25,28 @@ import ( ) func TestConvPrompt(t *testing.T) { - assert.NotNil(t, ConvCallbackInput(&CallbackInput{})) + assert.NotNil(t, ConvCallbackInput(&CallbackInput{ + AgenticTemplates: []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{}, + }, + })) assert.NotNil(t, ConvCallbackInput(map[string]any{})) assert.Nil(t, ConvCallbackInput("asd")) - assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{})) + assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{ + AgenticResult: []*schema.AgenticMessage{ + {}, + }, + AgenticTemplates: []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{}, + }, + })) assert.NotNil(t, ConvCallbackOutput([]*schema.Message{})) + + agenticResult := []*schema.AgenticMessage{{}} + out := ConvCallbackOutput(agenticResult) + assert.NotNil(t, out) + assert.Equal(t, agenticResult, out.AgenticResult) + assert.Nil(t, ConvCallbackOutput("asd")) } diff --git a/components/prompt/chat_template_agentic_test.go b/components/prompt/chat_template_agentic_test.go deleted file mode 100644 index aaa7d6405..000000000 --- a/components/prompt/chat_template_agentic_test.go +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package prompt - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/cloudwego/eino/schema" -) - -func TestAgenticFormat(t *testing.T) { - pyFmtTestTemplate := []schema.AgenticMessagesTemplate{ - &schema.AgenticMessage{ - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "{context}"}}, - }, - }, - schema.AgenticMessagesPlaceholder("chat_history", true), - } - jinja2TestTemplate := []schema.AgenticMessagesTemplate{ - &schema.AgenticMessage{ - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "{{context}}"}}, - }, - }, - schema.AgenticMessagesPlaceholder("chat_history", true), - } - goFmtTestTemplate := []schema.AgenticMessagesTemplate{ - &schema.AgenticMessage{ - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "{{.context}}"}}, - }, - }, - schema.AgenticMessagesPlaceholder("chat_history", true), - } - testValues := map[string]any{ - "context": "it's beautiful day", - "chat_history": []*schema.AgenticMessage{ - { - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "1"}}, - }, - }, - { - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "2"}}, - }, - }, - }, - } - expected := []*schema.AgenticMessage{ - { - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "it's beautiful day"}}, - }, - }, - { - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "1"}}, - }, - }, - { - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "2"}}, - }, - }, - } - - // FString - chatTemplate := FromAgenticMessages(schema.FString, pyFmtTestTemplate...) - msgs, err := chatTemplate.Format(context.Background(), testValues) - assert.Nil(t, err) - assert.Equal(t, expected, msgs) - - // Jinja2 - chatTemplate = FromAgenticMessages(schema.Jinja2, jinja2TestTemplate...) - msgs, err = chatTemplate.Format(context.Background(), testValues) - assert.Nil(t, err) - assert.Equal(t, expected, msgs) - - // GoTemplate - chatTemplate = FromAgenticMessages(schema.GoTemplate, goFmtTestTemplate...) - msgs, err = chatTemplate.Format(context.Background(), testValues) - assert.Nil(t, err) - assert.Equal(t, expected, msgs) -} diff --git a/components/prompt/interface.go b/components/prompt/interface.go index 7ffe7216a..2d5a2cbed 100644 --- a/components/prompt/interface.go +++ b/components/prompt/interface.go @@ -44,6 +44,7 @@ type ChatTemplate interface { Format(ctx context.Context, vs map[string]any, opts ...Option) ([]*schema.Message, error) } +// AgenticChatTemplate formats variables into a list of agentic messages according to a prompt schema. type AgenticChatTemplate interface { Format(ctx context.Context, vs map[string]any, opts ...Option) ([]*schema.AgenticMessage, error) } diff --git a/compose/tools_node_agentic.go b/compose/agentic_tools_node.go similarity index 100% rename from compose/tools_node_agentic.go rename to compose/agentic_tools_node.go diff --git a/compose/tools_node_agentic_test.go b/compose/agentic_tools_node_test.go similarity index 100% rename from compose/tools_node_agentic_test.go rename to compose/agentic_tools_node_test.go diff --git a/compose/chain.go b/compose/chain.go index 8484e8767..abfa6bf1d 100644 --- a/compose/chain.go +++ b/compose/chain.go @@ -22,7 +22,6 @@ import ( "fmt" "reflect" - "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -181,7 +180,7 @@ func (c *Chain[I, O]) AppendChatModel(node model.BaseChatModel, opts ...GraphAdd // model, err := openai.NewAgenticModel(ctx, config) // if err != nil {...} // chain.AppendAgenticModel(model) -func (c *Chain[I, O]) AppendAgenticModel(node agentic.Model, opts ...GraphAddNodeOpt) *Chain[I, O] { +func (c *Chain[I, O]) AppendAgenticModel(node model.AgenticModel, opts ...GraphAddNodeOpt) *Chain[I, O] { gNode, options := toAgenticModelNode(node, opts...) c.addNode(gNode, options) return c diff --git a/compose/chain_branch.go b/compose/chain_branch.go index 004dbfac3..84fb11048 100644 --- a/compose/chain_branch.go +++ b/compose/chain_branch.go @@ -20,7 +20,6 @@ import ( "context" "fmt" - "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -158,7 +157,7 @@ func (cb *ChainBranch) AddChatModel(key string, node model.BaseChatModel, opts . // }) // cb.AddAgenticModel("agentic_model_key_1", model1) // cb.AddAgenticModel("agentic_model_key_2", model2) -func (cb *ChainBranch) AddAgenticModel(key string, node agentic.Model, opts ...GraphAddNodeOpt) *ChainBranch { +func (cb *ChainBranch) AddAgenticModel(key string, node model.AgenticModel, opts ...GraphAddNodeOpt) *ChainBranch { gNode, options := toAgenticModelNode(node, opts...) return cb.addNode(key, gNode, options) } diff --git a/compose/chain_parallel.go b/compose/chain_parallel.go index 128ed4a26..463140be2 100644 --- a/compose/chain_parallel.go +++ b/compose/chain_parallel.go @@ -19,7 +19,6 @@ package compose import ( "fmt" - "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -84,7 +83,7 @@ func (p *Parallel) AddChatModel(outputKey string, node model.BaseChatModel, opts // // p.AddAgenticModel("output_key1", model1) // p.AddAgenticModel("output_key2", model2) -func (p *Parallel) AddAgenticModel(outputKey string, node agentic.Model, opts ...GraphAddNodeOpt) *Parallel { +func (p *Parallel) AddAgenticModel(outputKey string, node model.AgenticModel, opts ...GraphAddNodeOpt) *Parallel { gNode, options := toAgenticModelNode(node, append(opts, WithOutputKey(outputKey))...) return p.addNode(outputKey, gNode, options) } diff --git a/compose/component_to_graph_node.go b/compose/component_to_graph_node.go index e64ce4f19..4bd27fe34 100644 --- a/compose/component_to_graph_node.go +++ b/compose/component_to_graph_node.go @@ -18,7 +18,6 @@ package compose import ( "github.com/cloudwego/eino/components" - "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -102,7 +101,7 @@ func toChatModelNode(node model.BaseChatModel, opts ...GraphAddNodeOpt) (*graphN opts...) } -func toAgenticModelNode(node agentic.Model, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { +func toAgenticModelNode(node model.AgenticModel, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { return toComponentNode( node, components.ComponentOfAgenticModel, diff --git a/compose/graph.go b/compose/graph.go index 877b8fb42..bcf5ae423 100644 --- a/compose/graph.go +++ b/compose/graph.go @@ -23,7 +23,6 @@ import ( "reflect" "strings" - "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -361,7 +360,7 @@ func (g *graph) AddChatModelNode(key string, node model.BaseChatModel, opts ...G // }) // // graph.AddAgenticModelNode("agentic_model_node_key", model) -func (g *graph) AddAgenticModelNode(key string, node agentic.Model, opts ...GraphAddNodeOpt) error { +func (g *graph) AddAgenticModelNode(key string, node model.AgenticModel, opts ...GraphAddNodeOpt) error { gNode, options := toAgenticModelNode(node, opts...) return g.addNode(key, gNode, options) } diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 63c78c3eb..a4554d38e 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -66,15 +66,15 @@ const ( ) type AgenticMessage struct { - // ResponseMeta is the response metadata. - ResponseMeta *AgenticResponseMeta - // Role is the message role. Role AgenticRoleType // ContentBlocks is the list of content blocks. ContentBlocks []*ContentBlock + // ResponseMeta is the response metadata. + ResponseMeta *AgenticResponseMeta + // Extra is the additional information. Extra map[string]any } @@ -541,7 +541,7 @@ func NewContentBlockChunk[T contentBlockVariant](content *T, meta *StreamingMeta return block } -// AgenticMessagesTemplate is the interface for messages template. +// AgenticMessagesTemplate is the interface for agentic messages template. // It's used to render a template to a list of agentic messages. // e.g. // diff --git a/utils/callbacks/template.go b/utils/callbacks/template.go index e04bddd63..4c73e6bbc 100644 --- a/utils/callbacks/template.go +++ b/utils/callbacks/template.go @@ -55,17 +55,20 @@ func NewHandlerHelper() *HandlerHelper { // // then use the handler with runnable.Invoke(ctx, input, compose.WithCallbacks(handler)) type HandlerHelper struct { - promptHandler *PromptCallbackHandler - chatModelHandler *ModelCallbackHandler - embeddingHandler *EmbeddingCallbackHandler - indexerHandler *IndexerCallbackHandler - retrieverHandler *RetrieverCallbackHandler - loaderHandler *LoaderCallbackHandler - transformerHandler *TransformerCallbackHandler - toolHandler *ToolCallbackHandler - toolsNodeHandler *ToolsNodeCallbackHandlers - agentHandler *AgentCallbackHandler - composeTemplates map[components.Component]callbacks.Handler + promptHandler *PromptCallbackHandler + chatModelHandler *ModelCallbackHandler + embeddingHandler *EmbeddingCallbackHandler + indexerHandler *IndexerCallbackHandler + retrieverHandler *RetrieverCallbackHandler + loaderHandler *LoaderCallbackHandler + transformerHandler *TransformerCallbackHandler + toolHandler *ToolCallbackHandler + toolsNodeHandler *ToolsNodeCallbackHandlers + agenticPromptHandler *AgenticPromptCallbackHandler + agenticModelHandler *AgenticModelCallbackHandler + agenticToolsNodeHandler *AgenticToolsNodeCallbackHandlers + agentHandler *AgentCallbackHandler + composeTemplates map[components.Component]callbacks.Handler } // Handler returns the callbacks.Handler created by HandlerHelper. @@ -127,6 +130,24 @@ func (c *HandlerHelper) ToolsNode(handler *ToolsNodeCallbackHandlers) *HandlerHe return c } +// AgenticPrompt sets the agentic prompt handler for the handler helper, which will be called when the agentic prompt component is executed. +func (c *HandlerHelper) AgenticPrompt(handler *AgenticPromptCallbackHandler) *HandlerHelper { + c.agenticPromptHandler = handler + return c +} + +// AgenticModel sets the agentic chat model handler for the handler helper, which will be called when the agentic chat model component is executed. +func (c *HandlerHelper) AgenticModel(handler *AgenticModelCallbackHandler) *HandlerHelper { + c.agenticModelHandler = handler + return c +} + +// AgenticToolsNode sets the agentic tools node handler for the handler helper, which will be called when the agentic tools node is executed. +func (c *HandlerHelper) AgenticToolsNode(handler *AgenticToolsNodeCallbackHandlers) *HandlerHelper { + c.agenticToolsNodeHandler = handler + return c +} + // Agent sets the agent handler for the handler helper, which will be called when the agent is executed. func (c *HandlerHelper) Agent(handler *AgentCallbackHandler) *HandlerHelper { c.agentHandler = handler @@ -161,8 +182,12 @@ func (c *handlerTemplate) OnStart(ctx context.Context, info *callbacks.RunInfo, switch info.Component { case components.ComponentOfPrompt: return c.promptHandler.OnStart(ctx, info, prompt.ConvCallbackInput(input)) + case components.ComponentOfAgenticPrompt: + return c.agenticPromptHandler.OnStart(ctx, info, prompt.ConvCallbackInput(input)) case components.ComponentOfChatModel: return c.chatModelHandler.OnStart(ctx, info, model.ConvCallbackInput(input)) + case components.ComponentOfAgenticModel: + return c.agenticModelHandler.OnStart(ctx, info, model.ConvCallbackInput(input)) case components.ComponentOfEmbedding: return c.embeddingHandler.OnStart(ctx, info, embedding.ConvCallbackInput(input)) case components.ComponentOfIndexer: @@ -177,6 +202,8 @@ func (c *handlerTemplate) OnStart(ctx context.Context, info *callbacks.RunInfo, return c.toolHandler.OnStart(ctx, info, tool.ConvCallbackInput(input)) case compose.ComponentOfToolsNode: return c.toolsNodeHandler.OnStart(ctx, info, convToolsNodeCallbackInput(input)) + case compose.ComponentOfAgenticToolsNode: + return c.agenticToolsNodeHandler.OnStart(ctx, info, convAgenticToolsNodeCallbackInput(input)) case adk.ComponentOfAgent: return c.agentHandler.OnStart(ctx, info, adk.ConvAgentCallbackInput(input)) case compose.ComponentOfGraph, @@ -194,8 +221,12 @@ func (c *handlerTemplate) OnEnd(ctx context.Context, info *callbacks.RunInfo, ou switch info.Component { case components.ComponentOfPrompt: return c.promptHandler.OnEnd(ctx, info, prompt.ConvCallbackOutput(output)) + case components.ComponentOfAgenticPrompt: + return c.agenticPromptHandler.OnEnd(ctx, info, prompt.ConvCallbackOutput(output)) case components.ComponentOfChatModel: return c.chatModelHandler.OnEnd(ctx, info, model.ConvCallbackOutput(output)) + case components.ComponentOfAgenticModel: + return c.agenticModelHandler.OnEnd(ctx, info, model.ConvCallbackOutput(output)) case components.ComponentOfEmbedding: return c.embeddingHandler.OnEnd(ctx, info, embedding.ConvCallbackOutput(output)) case components.ComponentOfIndexer: @@ -210,6 +241,8 @@ func (c *handlerTemplate) OnEnd(ctx context.Context, info *callbacks.RunInfo, ou return c.toolHandler.OnEnd(ctx, info, tool.ConvCallbackOutput(output)) case compose.ComponentOfToolsNode: return c.toolsNodeHandler.OnEnd(ctx, info, convToolsNodeCallbackOutput(output)) + case compose.ComponentOfAgenticToolsNode: + return c.agenticToolsNodeHandler.OnEnd(ctx, info, convAgenticToolsNodeCallbackOutput(output)) case adk.ComponentOfAgent: return c.agentHandler.OnEnd(ctx, info, adk.ConvAgentCallbackOutput(output)) case compose.ComponentOfGraph, @@ -227,8 +260,12 @@ func (c *handlerTemplate) OnError(ctx context.Context, info *callbacks.RunInfo, switch info.Component { case components.ComponentOfPrompt: return c.promptHandler.OnError(ctx, info, err) + case components.ComponentOfAgenticPrompt: + return c.agenticPromptHandler.OnError(ctx, info, err) case components.ComponentOfChatModel: return c.chatModelHandler.OnError(ctx, info, err) + case components.ComponentOfAgenticModel: + return c.agenticModelHandler.OnError(ctx, info, err) case components.ComponentOfEmbedding: return c.embeddingHandler.OnError(ctx, info, err) case components.ComponentOfIndexer: @@ -243,6 +280,8 @@ func (c *handlerTemplate) OnError(ctx context.Context, info *callbacks.RunInfo, return c.toolHandler.OnError(ctx, info, err) case compose.ComponentOfToolsNode: return c.toolsNodeHandler.OnError(ctx, info, err) + case compose.ComponentOfAgenticToolsNode: + return c.agenticToolsNodeHandler.OnError(ctx, info, err) case compose.ComponentOfGraph, compose.ComponentOfChain, compose.ComponentOfLambda: @@ -275,6 +314,11 @@ func (c *handlerTemplate) OnEndWithStreamOutput(ctx context.Context, info *callb schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*model.CallbackOutput, error) { return model.ConvCallbackOutput(item), nil })) + case components.ComponentOfAgenticModel: + return c.agenticModelHandler.OnEndWithStreamOutput(ctx, info, + schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*model.CallbackOutput, error) { + return model.ConvCallbackOutput(item), nil + })) case components.ComponentOfTool: return c.toolHandler.OnEndWithStreamOutput(ctx, info, schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*tool.CallbackOutput, error) { @@ -285,6 +329,11 @@ func (c *handlerTemplate) OnEndWithStreamOutput(ctx context.Context, info *callb schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) ([]*schema.Message, error) { return convToolsNodeCallbackOutput(item), nil })) + case compose.ComponentOfAgenticToolsNode: + return c.agenticToolsNodeHandler.OnEndWithStreamOutput(ctx, info, + schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) ([]*schema.AgenticMessage, error) { + return convAgenticToolsNodeCallbackOutput(item), nil + })) case compose.ComponentOfGraph, compose.ComponentOfChain, compose.ComponentOfLambda: @@ -295,6 +344,8 @@ func (c *handlerTemplate) OnEndWithStreamOutput(ctx context.Context, info *callb } // Needed checks if the callback handler is needed for the given timing. +// +//nolint:cyclop func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { if info == nil { return false @@ -305,6 +356,10 @@ func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, t if c.chatModelHandler != nil && c.chatModelHandler.Needed(ctx, info, timing) { return true } + case components.ComponentOfAgenticModel: + if c.agenticModelHandler != nil && c.agenticModelHandler.Needed(ctx, info, timing) { + return true + } case components.ComponentOfEmbedding: if c.embeddingHandler != nil && c.embeddingHandler.Needed(ctx, info, timing) { return true @@ -321,6 +376,10 @@ func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, t if c.promptHandler != nil && c.promptHandler.Needed(ctx, info, timing) { return true } + case components.ComponentOfAgenticPrompt: + if c.agenticPromptHandler != nil && c.agenticPromptHandler.Needed(ctx, info, timing) { + return true + } case components.ComponentOfRetriever: if c.retrieverHandler != nil && c.retrieverHandler.Needed(ctx, info, timing) { return true @@ -337,6 +396,10 @@ func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, t if c.toolsNodeHandler != nil && c.toolsNodeHandler.Needed(ctx, info, timing) { return true } + case compose.ComponentOfAgenticToolsNode: + if c.agenticToolsNodeHandler != nil && c.agenticToolsNodeHandler.Needed(ctx, info, timing) { + return true + } case adk.ComponentOfAgent: if c.agentHandler != nil && c.agentHandler.Needed(ctx, info, timing) { return true @@ -596,3 +659,94 @@ func (ch *AgentCallbackHandler) Needed(ctx context.Context, info *callbacks.RunI return false } } + +// AgenticPromptCallbackHandler is the handler for the agentic prompt callback. +type AgenticPromptCallbackHandler struct { + // OnStart is the callback function for the start of the agentic prompt. + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context + // OnEnd is the callback function for the end of the agentic prompt. + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *prompt.CallbackOutput) context.Context + // OnError is the callback function for the error of the agentic prompt. + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. +func (ch *AgenticPromptCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false + } +} + +// AgenticModelCallbackHandler is the handler for the agentic chat model callback. +type AgenticModelCallbackHandler struct { + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context + OnEndWithStreamOutput func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. +func (ch *AgenticModelCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + case callbacks.TimingOnEndWithStreamOutput: + return ch.OnEndWithStreamOutput != nil + default: + return false + } +} + +// AgenticToolsNodeCallbackHandlers defines optional callbacks for the Agentic Tools node +// lifecycle events. +type AgenticToolsNodeCallbackHandlers struct { + OnStart func(ctx context.Context, info *callbacks.RunInfo, input *schema.AgenticMessage) context.Context + OnEnd func(ctx context.Context, info *callbacks.RunInfo, input []*schema.AgenticMessage) context.Context + OnEndWithStreamOutput func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[[]*schema.AgenticMessage]) context.Context + OnError func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context +} + +// Needed reports whether a handler is registered for the given timing. +func (ch *AgenticToolsNodeCallbackHandlers) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnEndWithStreamOutput: + return ch.OnEndWithStreamOutput != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false + } +} + +func convAgenticToolsNodeCallbackInput(src callbacks.CallbackInput) *schema.AgenticMessage { + switch t := src.(type) { + case *schema.AgenticMessage: + return t + default: + return nil + } +} + +func convAgenticToolsNodeCallbackOutput(src callbacks.CallbackInput) []*schema.AgenticMessage { + switch t := src.(type) { + case []*schema.AgenticMessage: + return t + default: + return nil + } +} diff --git a/utils/callbacks/template_test.go b/utils/callbacks/template_test.go index 84ed6dfc6..f599e5300 100644 --- a/utils/callbacks/template_test.go +++ b/utils/callbacks/template_test.go @@ -142,6 +142,58 @@ func TestNewComponentTemplate(t *testing.T) { cnt++ return ctx }).Build()). + AgenticModel(&AgenticModelCallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context { + cnt++ + return ctx + }, + OnEndWithStreamOutput: func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context { + output.Close() + cnt++ + return ctx + }, + OnError: func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }). + AgenticPrompt(&AgenticPromptCallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *prompt.CallbackOutput) context.Context { + cnt++ + return ctx + }, + OnError: func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }). + AgenticToolsNode(&AgenticToolsNodeCallbackHandlers{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *schema.AgenticMessage) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, input []*schema.AgenticMessage) context.Context { + cnt++ + return ctx + }, + OnEndWithStreamOutput: func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[[]*schema.AgenticMessage]) context.Context { + output.Close() + cnt++ + return ctx + }, + OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }). Handler() types := []components.Component{ @@ -151,6 +203,9 @@ func TestNewComponentTemplate(t *testing.T) { components.ComponentOfRetriever, components.ComponentOfTool, compose.ComponentOfLambda, + components.ComponentOfAgenticModel, + components.ComponentOfAgenticPrompt, + compose.ComponentOfAgenticToolsNode, } handler := tpl.Handler() @@ -169,28 +224,28 @@ func TestNewComponentTemplate(t *testing.T) { handler.OnEndWithStreamOutput(ctx, &callbacks.RunInfo{Component: typ}, sor) } - assert.Equal(t, 22, cnt) + assert.Equal(t, 33, cnt) ctx = context.Background() ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, handler) callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 22, cnt) + assert.Equal(t, 33, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfPrompt}) ctx = callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 23, cnt) + assert.Equal(t, 34, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer}) callbacks.OnEnd[any](ctx, nil) - assert.Equal(t, 23, cnt) + assert.Equal(t, 34, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfEmbedding}) callbacks.OnError(ctx, nil) - assert.Equal(t, 24, cnt) + assert.Equal(t, 35, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}) callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 24, cnt) + assert.Equal(t, 35, cnt) tpl.Transformer(&TransformerCallbackHandler{ OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *document.TransformerCallbackInput) context.Context { @@ -250,6 +305,37 @@ func TestNewComponentTemplate(t *testing.T) { } } }, + }).AgenticPrompt(&AgenticPromptCallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *prompt.CallbackOutput) context.Context { + cnt++ + return ctx + }, + OnError: func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }).AgenticToolsNode(&AgenticToolsNodeCallbackHandlers{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *schema.AgenticMessage) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, input []*schema.AgenticMessage) context.Context { + cnt++ + return ctx + }, + OnEndWithStreamOutput: func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[[]*schema.AgenticMessage]) context.Context { + output.Close() + cnt++ + return ctx + }, + OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, }) handler = tpl.Handler() @@ -257,36 +343,222 @@ func TestNewComponentTemplate(t *testing.T) { ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, handler) ctx = callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 25, cnt) + assert.Equal(t, 36, cnt) callbacks.OnEnd[any](ctx, nil) - assert.Equal(t, 26, cnt) + assert.Equal(t, 37, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}) callbacks.OnEnd[any](ctx, nil) - assert.Equal(t, 27, cnt) + assert.Equal(t, 38, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: compose.ComponentOfToolsNode}) callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 28, cnt) + assert.Equal(t, 39, cnt) sr, sw := schema.Pipe[any](0) sw.Close() callbacks.OnEndWithStreamOutput[any](ctx, sr) - assert.Equal(t, 29, cnt) + assert.Equal(t, 40, cnt) sr1, sw1 := schema.Pipe[[]*schema.Message](1) sw1.Send([]*schema.Message{{}}, nil) sw1.Close() callbacks.OnEndWithStreamOutput[[]*schema.Message](ctx, sr1) - assert.Equal(t, 30, cnt) - - callbacks.OnError(ctx, nil) - assert.Equal(t, 30, cnt) + // Check AgenticModel stream + sir2, siw2 := schema.Pipe[callbacks.CallbackOutput](1) + siw2.Close() + handler.OnEndWithStreamOutput(ctx, &callbacks.RunInfo{Component: components.ComponentOfAgenticModel}, sir2) + assert.Equal(t, 42, cnt) + + // Check AgenticToolsNode stream + sir3, siw3 := schema.Pipe[callbacks.CallbackOutput](1) + siw3.Close() + handler.OnEndWithStreamOutput(ctx, &callbacks.RunInfo{Component: compose.ComponentOfAgenticToolsNode}, sir3) + assert.Equal(t, 43, cnt) ctx = callbacks.ReuseHandlers(ctx, nil) callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 30, cnt) + assert.Equal(t, 43, cnt) + }) + + t.Run("EdgeCases", func(t *testing.T) { + ctx := context.Background() + cnt := 0 + + // 1. Test Graph and Chain Setters and Execution + tpl := NewHandlerHelper(). + Graph(callbacks.NewHandlerBuilder(). + OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { + cnt++ + return ctx + }).Build()). + Chain(callbacks.NewHandlerBuilder(). + OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { + cnt++ + return ctx + }).Build()) + + h := tpl.Handler() + + // Trigger Graph OnStart + h.OnStart(ctx, &callbacks.RunInfo{Component: compose.ComponentOfGraph}, nil) + assert.Equal(t, 1, cnt) + + // Trigger Chain OnEnd + h.OnEnd(ctx, &callbacks.RunInfo{Component: compose.ComponentOfChain}, nil) + assert.Equal(t, 2, cnt) + + // 2. Test Needed logic for Graph/Chain when handler is present/absent + // Graph is present (OnStart) + needed := h.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: compose.ComponentOfGraph}, callbacks.TimingOnStart) + assert.True(t, needed) + + // Chain is present (OnEnd) - but we check OnStart which is not defined in the builder above? + // NewHandlerBuilder returns a handler that usually returns true for Needed if the specific func is not nil. + // Let's verify Chain OnStart is NOT needed because we only set OnEndFn. + needed = h.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: compose.ComponentOfChain}, callbacks.TimingOnStart) + assert.False(t, needed) // Should be false because OnStartFn wasn't set for Chain + + // Lambda is NOT present + needed = h.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: compose.ComponentOfLambda}, callbacks.TimingOnStart) + assert.False(t, needed) + + // 3. Test Conversion Fallbacks (Default cases) + // We need a handler with ToolsNode and AgenticToolsNode to test their conversion fallbacks + tpl2 := NewHandlerHelper(). + ToolsNode(&ToolsNodeCallbackHandlers{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *schema.Message) context.Context { + if input == nil { + cnt++ + } + return ctx + }, + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, input []*schema.Message) context.Context { + if input == nil { + cnt++ + } + return ctx + }, + }). + AgenticToolsNode(&AgenticToolsNodeCallbackHandlers{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *schema.AgenticMessage) context.Context { + if input == nil { + cnt++ + } + return ctx + }, + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, input []*schema.AgenticMessage) context.Context { + if input == nil { + cnt++ + } + return ctx + }, + }) + + h2 := tpl2.Handler() + + // Pass wrong type (string) to trigger default case in convToolsNodeCallbackInput -> returns nil + h2.OnStart(ctx, &callbacks.RunInfo{Component: compose.ComponentOfToolsNode}, "wrong-input-type") + assert.Equal(t, 3, cnt) // +1 + + // Pass wrong type to trigger default case in convToolsNodeCallbackOutput -> returns nil + h2.OnEnd(ctx, &callbacks.RunInfo{Component: compose.ComponentOfToolsNode}, "wrong-output-type") + assert.Equal(t, 4, cnt) // +1 + + // Pass wrong type to trigger default case in convAgenticToolsNodeCallbackInput -> returns nil + h2.OnStart(ctx, &callbacks.RunInfo{Component: compose.ComponentOfAgenticToolsNode}, "wrong-input-type") + assert.Equal(t, 5, cnt) // +1 + + // Pass wrong type to trigger default case in convAgenticToolsNodeCallbackOutput -> returns nil + h2.OnEnd(ctx, &callbacks.RunInfo{Component: compose.ComponentOfAgenticToolsNode}, "wrong-output-type") + assert.Equal(t, 6, cnt) // +1 + + // 4. Test Needed for Agentic components when handlers are Set vs Unset + // tpl2 has AgenticToolsNode set + needed = h2.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: compose.ComponentOfAgenticToolsNode}, callbacks.TimingOnStart) + assert.True(t, needed) + + // tpl2 does NOT have AgenticModel set + needed = h2.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfAgenticModel}, callbacks.TimingOnStart) + assert.False(t, needed) + + // Set it now + tpl2.AgenticModel(&AgenticModelCallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context { + return ctx + }, + }) + + needed = h2.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfAgenticModel}, callbacks.TimingOnStart) + assert.True(t, needed) + + // Check invalid component + needed = h2.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: "UnknownComponent"}, callbacks.TimingOnStart) + assert.False(t, needed) + + // Check RunInfo nil + needed = h2.(callbacks.TimingChecker).Needed(ctx, nil, callbacks.TimingOnStart) + assert.False(t, needed) + + // 5. Test Needed for Transformer, Loader, Indexer, etc to ensure switch coverage + tpl3 := NewHandlerHelper(). + Transformer(&TransformerCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *document.TransformerCallbackInput) context.Context { + return ctx + }}). + Loader(&LoaderCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *document.LoaderCallbackInput) context.Context { + return ctx + }}). + Indexer(&IndexerCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *indexer.CallbackInput) context.Context { + return ctx + }}). + Retriever(&RetrieverCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *retriever.CallbackInput) context.Context { + return ctx + }}). + Embedding(&EmbeddingCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *embedding.CallbackInput) context.Context { + return ctx + }}). + Tool(&ToolCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *tool.CallbackInput) context.Context { + return ctx + }}) + + h3 := tpl3.Handler() + checker := h3.(callbacks.TimingChecker) + + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfRetriever}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfEmbedding}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfTool}, callbacks.TimingOnStart)) + + // Verify False paths (by using a helper without them) + emptyH := NewHandlerHelper().Handler().(callbacks.TimingChecker) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, callbacks.TimingOnStart)) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}, callbacks.TimingOnStart)) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer}, callbacks.TimingOnStart)) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfRetriever}, callbacks.TimingOnStart)) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfEmbedding}, callbacks.TimingOnStart)) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfTool}, callbacks.TimingOnStart)) + + // 6. Test Needed for remaining components (ChatModel, Prompt, AgenticPrompt) + tpl4 := NewHandlerHelper(). + ChatModel(&ModelCallbackHandler{OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context { + return ctx + }}). + Prompt(&PromptCallbackHandler{OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context { + return ctx + }}). + AgenticPrompt(&AgenticPromptCallbackHandler{OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context { + return ctx + }}) + + h4 := tpl4.Handler() + checker4 := h4.(callbacks.TimingChecker) + + assert.True(t, checker4.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfChatModel}, callbacks.TimingOnStart)) + assert.True(t, checker4.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfPrompt}, callbacks.TimingOnStart)) + assert.True(t, checker4.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfAgenticPrompt}, callbacks.TimingOnStart)) }) } From 29b518edfec58d83ea0c5a2dac8c97633a48799c Mon Sep 17 00:00:00 2001 From: mrh997 Date: Thu, 15 Jan 2026 16:19:47 +0800 Subject: [PATCH 23/59] feat: improve AgenticToolChoice (#684) --- components/model/option.go | 17 ++++++++--------- components/model/option_test.go | 13 ++++++++++--- compose/workflow.go | 18 ++++++++++++++++++ schema/tool.go | 31 +++++++++++++++++++++++++++++-- 4 files changed, 65 insertions(+), 14 deletions(-) diff --git a/components/model/option.go b/components/model/option.go index 0173d22aa..a337b7af2 100644 --- a/components/model/option.go +++ b/components/model/option.go @@ -28,11 +28,11 @@ type Options struct { TopP *float32 // Tools is a list of tools the model may call. Tools []*schema.ToolInfo - // ToolChoice controls which tool is called by the model. - ToolChoice *schema.ToolChoice - // Options only for chat model. + // Options only available for chat model. + // ToolChoice controls which tool is called by the model. + ToolChoice *schema.ToolChoice // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return an finish reason of "length". MaxTokens *int // AllowedToolNames specifies a list of tool names that the model is allowed to call. @@ -41,10 +41,10 @@ type Options struct { // Stop is the stop words for the model, which controls the stopping condition of the model. Stop []string - // Options only for agentic model. + // Options only available for agentic model. - // AllowedTools is a list of allowed tools the model may call. - AllowedTools []*schema.AllowedTool + // AgenticToolChoice controls how the agentic model calls tools. + AgenticToolChoice *schema.AgenticToolChoice } // Option is a call-time option for a ChatModel. Options are immutable and @@ -130,11 +130,10 @@ func WithToolChoice(toolChoice schema.ToolChoice, allowedToolNames ...string) Op // WithAgenticToolChoice is the option to set tool choice for the agentic model. // Only available for AgenticModel. -func WithAgenticToolChoice(toolChoice schema.ToolChoice, allowedTools ...*schema.AllowedTool) Option { +func WithAgenticToolChoice(toolChoice *schema.AgenticToolChoice) Option { return Option{ apply: func(opts *Options) { - opts.ToolChoice = &toolChoice - opts.AllowedTools = allowedTools + opts.AgenticToolChoice = toolChoice }, } } diff --git a/components/model/option_test.go b/components/model/option_test.go index bfacdd17c..aa43e6e01 100644 --- a/components/model/option_test.go +++ b/components/model/option_test.go @@ -92,11 +92,18 @@ func TestOptions(t *testing.T) { ) opts := GetCommonOptions( nil, - WithAgenticToolChoice(toolChoice, allowedTools...), + WithAgenticToolChoice(&schema.AgenticToolChoice{ + Type: toolChoice, + Forced: &schema.AgenticForcedToolChoice{ + Tools: allowedTools, + }, + }), ) - convey.So(opts.ToolChoice, convey.ShouldResemble, &toolChoice) - convey.So(opts.AllowedTools, convey.ShouldResemble, allowedTools) + convey.So(opts.AgenticToolChoice, convey.ShouldNotBeNil) + convey.So(opts.AgenticToolChoice.Type, convey.ShouldEqual, toolChoice) + convey.So(opts.AgenticToolChoice.Forced, convey.ShouldNotBeNil) + convey.So(opts.AgenticToolChoice.Forced.Tools, convey.ShouldResemble, allowedTools) }) } diff --git a/compose/workflow.go b/compose/workflow.go index c3e4331a3..6b50962bb 100644 --- a/compose/workflow.go +++ b/compose/workflow.go @@ -89,18 +89,36 @@ func (wf *Workflow[I, O]) AddChatModelNode(key string, chatModel model.BaseChatM return wf.initNode(key) } +// AddAgenticModelNode adds an agentic model node and returns it. +func (wf *Workflow[I, O]) AddAgenticModelNode(key string, agenticModel model.AgenticModel, opts ...GraphAddNodeOpt) *WorkflowNode { + _ = wf.g.AddAgenticModelNode(key, agenticModel, opts...) + return wf.initNode(key) +} + // AddChatTemplateNode adds a chat template node and returns it. func (wf *Workflow[I, O]) AddChatTemplateNode(key string, chatTemplate prompt.ChatTemplate, opts ...GraphAddNodeOpt) *WorkflowNode { _ = wf.g.AddChatTemplateNode(key, chatTemplate, opts...) return wf.initNode(key) } +// AddAgenticChatTemplateNode adds an agentic chat template node and returns it. +func (wf *Workflow[I, O]) AddAgenticChatTemplateNode(key string, chatTemplate prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) *WorkflowNode { + _ = wf.g.AddAgenticChatTemplateNode(key, chatTemplate, opts...) + return wf.initNode(key) +} + // AddToolsNode adds a tools node and returns it. func (wf *Workflow[I, O]) AddToolsNode(key string, tools *ToolsNode, opts ...GraphAddNodeOpt) *WorkflowNode { _ = wf.g.AddToolsNode(key, tools, opts...) return wf.initNode(key) } +// AddAgenticToolsNode adds an agentic tools node and returns it. +func (wf *Workflow[I, O]) AddAgenticToolsNode(key string, tools *AgenticToolsNode, opts ...GraphAddNodeOpt) *WorkflowNode { + _ = wf.g.AddAgenticToolsNode(key, tools, opts...) + return wf.initNode(key) +} + // AddRetrieverNode adds a retriever node and returns it. func (wf *Workflow[I, O]) AddRetrieverNode(key string, retriever retriever.Retriever, opts ...GraphAddNodeOpt) *WorkflowNode { _ = wf.g.AddRetrieverNode(key, retriever, opts...) diff --git a/schema/tool.go b/schema/tool.go index efed6d34b..a067d87db 100644 --- a/schema/tool.go +++ b/schema/tool.go @@ -59,6 +59,31 @@ const ( ToolChoiceForced ToolChoice = "forced" ) +type AgenticToolChoice struct { + // Type is the tool choice mode. + Type ToolChoice + // Allowed optionally specifies the list of tools that the model is permitted to call. + Allowed *AgenticAllowedToolChoice + // Forced optionally specifies the list of tools that the model is required to call. + Forced *AgenticForcedToolChoice +} + +// AgenticAllowedToolChoice specifies a list of allowed tools for the model. +type AgenticAllowedToolChoice struct { + // Tools is the list of allowed tools for the model to call. + // Optional. + Tools []*AllowedTool +} + +// AgenticForcedToolChoice specifies a list of tools that the model must call. +type AgenticForcedToolChoice struct { + // Tools is the list of tools that the model must call. + // Optional. + Tools []*AllowedTool +} + +// AllowedTool represents a tool that the model is allowed or forced to call. +// Exactly one of FunctionToolName, MCPTool, or ServerTool must be specified. type AllowedTool struct { // FunctionToolName is the name of the function tool. FunctionToolName string @@ -68,15 +93,17 @@ type AllowedTool struct { ServerTool *AllowedServerTool } +// AllowedMCPTool contains the information for identifying an MCP tool. type AllowedMCPTool struct { // ServerLabel is the label of the MCP server. ServerLabel string - // The name of the MCP tool. + // Name is the name of the MCP tool. Name string } +// AllowedServerTool contains the information for identifying a server tool. type AllowedServerTool struct { - // The name of the server tool. + // Name is the name of the server tool. Name string } From 94799ab54f2477ab5ae8282a19ac1dad566fb855 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Thu, 15 Jan 2026 18:04:59 +0800 Subject: [PATCH 24/59] feat: define AgenticCallbackInput/Output (#689) --- components/model/agentic_callback_extra.go | 92 +++++++++++++++++++ .../model/agentic_callback_extra_test.go | 35 +++++++ components/model/callback_extra.go | 12 --- components/model/option_test.go | 2 +- components/prompt/agentic_callback_extra.go | 70 ++++++++++++++ .../prompt/agentic_callback_extra_test.go | 46 ++++++++++ components/prompt/agentic_chat_template.go | 4 +- components/prompt/callback_extra.go | 10 -- components/prompt/callback_extra_test.go | 17 +--- schema/agentic_message.go | 9 -- schema/agentic_message_test.go | 9 -- schema/tool.go | 8 +- 12 files changed, 256 insertions(+), 58 deletions(-) create mode 100644 components/model/agentic_callback_extra.go create mode 100644 components/model/agentic_callback_extra_test.go create mode 100644 components/prompt/agentic_callback_extra.go create mode 100644 components/prompt/agentic_callback_extra_test.go diff --git a/components/model/agentic_callback_extra.go b/components/model/agentic_callback_extra.go new file mode 100644 index 000000000..28dd366e6 --- /dev/null +++ b/components/model/agentic_callback_extra.go @@ -0,0 +1,92 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package model + +import ( + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" +) + +// AgenticConfig is the config for the agentic model. +type AgenticConfig struct { + // Model is the model name. + Model string + // Temperature is the temperature, which controls the randomness of the agentic model. + Temperature float32 + // TopP is the top p, which controls the diversity of the agentic model. + TopP float32 +} + +// AgenticCallbackInput is the input for the agentic model callback. +type AgenticCallbackInput struct { + // AgenticMessages is the agentic messages to be sent to the agentic model. + AgenticMessages []*schema.AgenticMessage + // Tools is the tools to be used in the agentic model. + Tools []*schema.ToolInfo + // Config is the config for the agentic model. + Config *AgenticConfig + // Extra is the extra information for the callback. + Extra map[string]any +} + +// AgenticCallbackOutput is the output for the agentic model callback. +type AgenticCallbackOutput struct { + // AgenticMessage is the agentic message generated by the agentic model. + AgenticMessage *schema.AgenticMessage + // Config is the config for the agentic model. + Config *AgenticConfig + // TokenUsage is the token usage of this request. + TokenUsage *TokenUsage + // Extra is the extra information for the callback. + Extra map[string]any +} + +// ConvAgenticCallbackInput converts the callback input to the agentic model callback input. +func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput { + switch t := src.(type) { + case *AgenticCallbackInput: + // when callback is triggered within component implementation, + // the input is usually already a typed *model.AgenticCallbackInput + return t + case []*schema.AgenticMessage: + // when callback is injected by graph node, not the component implementation itself, + // the input is the input of Agentic Model interface, which is []*schema.AgenticMessage + return &AgenticCallbackInput{ + AgenticMessages: t, + } + default: + return nil + } +} + +// ConvAgenticCallbackOutput converts the callback output to the agentic model callback output. +func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOutput { + switch t := src.(type) { + case *AgenticCallbackOutput: + // when callback is triggered within component implementation, + // the output is usually already a typed *model.AgenticCallbackOutput + return t + case *schema.AgenticMessage: + // when callback is injected by graph node, not the component implementation itself, + // the output is the output of Agentic Model interface, which is *schema.AgenticMessage + return &AgenticCallbackOutput{ + AgenticMessage: t, + } + default: + return nil + } +} diff --git a/components/model/agentic_callback_extra_test.go b/components/model/agentic_callback_extra_test.go new file mode 100644 index 000000000..937367477 --- /dev/null +++ b/components/model/agentic_callback_extra_test.go @@ -0,0 +1,35 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package model + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestConvAgenticModel(t *testing.T) { + assert.NotNil(t, ConvAgenticCallbackInput(&AgenticCallbackInput{})) + assert.NotNil(t, ConvAgenticCallbackInput([]*schema.AgenticMessage{})) + assert.Nil(t, ConvAgenticCallbackInput("asd")) + + assert.NotNil(t, ConvAgenticCallbackOutput(&AgenticCallbackOutput{})) + assert.NotNil(t, ConvAgenticCallbackOutput(&schema.AgenticMessage{})) + assert.Nil(t, ConvAgenticCallbackOutput("asd")) +} diff --git a/components/model/callback_extra.go b/components/model/callback_extra.go index ed9096d5c..8591c4373 100644 --- a/components/model/callback_extra.go +++ b/components/model/callback_extra.go @@ -66,8 +66,6 @@ type Config struct { type CallbackInput struct { // Messages is the messages to be sent to the model. Messages []*schema.Message - // AgenticMessages is the agentic messages to be sent to the agentic model. - AgenticMessages []*schema.AgenticMessage // Tools is the tools to be used in the model. Tools []*schema.ToolInfo // ToolChoice is the tool choice, which controls the tool to be used in the model. @@ -82,8 +80,6 @@ type CallbackInput struct { type CallbackOutput struct { // Message is the message generated by the model. Message *schema.Message - // AgenticMessage is the agentic message generated by the agentic model. - AgenticMessage *schema.AgenticMessage // Config is the config for the model. Config *Config // TokenUsage is the token usage of this request. @@ -101,10 +97,6 @@ func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { return &CallbackInput{ Messages: t, } - case []*schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the input is the input of Agentic Model interface, which is []*schema.AgenticMessage - return &CallbackInput{ - AgenticMessages: t, - } default: return nil } @@ -119,10 +111,6 @@ func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { return &CallbackOutput{ Message: t, } - case *schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the output is the output of Agentic Model interface, which is *schema.AgenticMessage - return &CallbackOutput{ - AgenticMessage: t, - } default: return nil } diff --git a/components/model/option_test.go b/components/model/option_test.go index aa43e6e01..c836933b7 100644 --- a/components/model/option_test.go +++ b/components/model/option_test.go @@ -87,7 +87,7 @@ func TestOptions(t *testing.T) { var ( toolChoice = schema.ToolChoiceForced allowedTools = []*schema.AllowedTool{ - {FunctionToolName: "agentic_tool"}, + {FunctionName: "agentic_tool"}, } ) opts := GetCommonOptions( diff --git a/components/prompt/agentic_callback_extra.go b/components/prompt/agentic_callback_extra.go new file mode 100644 index 000000000..1170854a1 --- /dev/null +++ b/components/prompt/agentic_callback_extra.go @@ -0,0 +1,70 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" +) + +// AgenticCallbackInput is the input for the callback. +type AgenticCallbackInput struct { + // Variables is the variables for the callback. + Variables map[string]any + // AgenticTemplates is the agentic templates for the callback. + AgenticTemplates []schema.AgenticMessagesTemplate + // Extra is the extra information for the callback. + Extra map[string]any +} + +// AgenticCallbackOutput is the output for the callback. +type AgenticCallbackOutput struct { + // AgenticResult is the agentic result for the callback. + AgenticResult []*schema.AgenticMessage + // AgenticTemplates is the agentic templates for the callback. + AgenticTemplates []schema.AgenticMessagesTemplate + // Extra is the extra information for the callback. + Extra map[string]any +} + +// ConvAgenticCallbackInput converts the callback input to the agentic prompt callback input. +func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput { + switch t := src.(type) { + case *AgenticCallbackInput: + return t + case map[string]any: + return &AgenticCallbackInput{ + Variables: t, + } + default: + return nil + } +} + +// ConvAgenticCallbackOutput converts the callback output to the agentic prompt callback output. +func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOutput { + switch t := src.(type) { + case *AgenticCallbackOutput: + return t + case []*schema.AgenticMessage: + return &AgenticCallbackOutput{ + AgenticResult: t, + } + default: + return nil + } +} diff --git a/components/prompt/agentic_callback_extra_test.go b/components/prompt/agentic_callback_extra_test.go new file mode 100644 index 000000000..6dda1a349 --- /dev/null +++ b/components/prompt/agentic_callback_extra_test.go @@ -0,0 +1,46 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestConvAgenticPrompt(t *testing.T) { + assert.NotNil(t, ConvAgenticCallbackInput(&AgenticCallbackInput{ + Variables: map[string]any{}, + AgenticTemplates: []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{}, + }, + })) + assert.NotNil(t, ConvAgenticCallbackInput(map[string]any{})) + assert.Nil(t, ConvAgenticCallbackInput("asd")) + + assert.NotNil(t, ConvAgenticCallbackOutput(&AgenticCallbackOutput{ + AgenticResult: []*schema.AgenticMessage{ + {}, + }, + AgenticTemplates: []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{}, + }, + })) + assert.NotNil(t, ConvAgenticCallbackOutput([]*schema.AgenticMessage{})) +} diff --git a/components/prompt/agentic_chat_template.go b/components/prompt/agentic_chat_template.go index 512a60ecd..c6c300d5a 100644 --- a/components/prompt/agentic_chat_template.go +++ b/components/prompt/agentic_chat_template.go @@ -45,7 +45,7 @@ type DefaultAgenticChatTemplate struct { func (t *DefaultAgenticChatTemplate) Format(ctx context.Context, vs map[string]any, opts ...Option) (result []*schema.AgenticMessage, err error) { ctx = callbacks.EnsureRunInfo(ctx, t.GetType(), components.ComponentOfAgenticPrompt) - ctx = callbacks.OnStart(ctx, &CallbackInput{ + ctx = callbacks.OnStart(ctx, &AgenticCallbackInput{ Variables: vs, AgenticTemplates: t.templates, }) @@ -65,7 +65,7 @@ func (t *DefaultAgenticChatTemplate) Format(ctx context.Context, vs map[string]a result = append(result, msgs...) } - _ = callbacks.OnEnd(ctx, &CallbackOutput{ + _ = callbacks.OnEnd(ctx, &AgenticCallbackOutput{ AgenticResult: result, AgenticTemplates: t.templates, }) diff --git a/components/prompt/callback_extra.go b/components/prompt/callback_extra.go index 4c27f37c6..324a418f3 100644 --- a/components/prompt/callback_extra.go +++ b/components/prompt/callback_extra.go @@ -27,8 +27,6 @@ type CallbackInput struct { Variables map[string]any // Templates is the templates for the callback. Templates []schema.MessagesTemplate - // AgenticTemplates is the agentic templates for the callback. - AgenticTemplates []schema.AgenticMessagesTemplate // Extra is the extra information for the callback. Extra map[string]any } @@ -37,12 +35,8 @@ type CallbackInput struct { type CallbackOutput struct { // Result is the result for the callback. Result []*schema.Message - // AgenticResult is the agentic result for the callback. - AgenticResult []*schema.AgenticMessage // Templates is the templates for the callback. Templates []schema.MessagesTemplate - // AgenticTemplates is the agentic templates for the callback. - AgenticTemplates []schema.AgenticMessagesTemplate // Extra is the extra information for the callback. Extra map[string]any } @@ -70,10 +64,6 @@ func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { return &CallbackOutput{ Result: t, } - case []*schema.AgenticMessage: - return &CallbackOutput{ - AgenticResult: t, - } default: return nil } diff --git a/components/prompt/callback_extra_test.go b/components/prompt/callback_extra_test.go index 4b48ec114..ad8a3c0c2 100644 --- a/components/prompt/callback_extra_test.go +++ b/components/prompt/callback_extra_test.go @@ -26,27 +26,20 @@ import ( func TestConvPrompt(t *testing.T) { assert.NotNil(t, ConvCallbackInput(&CallbackInput{ - AgenticTemplates: []schema.AgenticMessagesTemplate{ - &schema.AgenticMessage{}, + Templates: []schema.MessagesTemplate{ + &schema.Message{}, }, })) assert.NotNil(t, ConvCallbackInput(map[string]any{})) assert.Nil(t, ConvCallbackInput("asd")) assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{ - AgenticResult: []*schema.AgenticMessage{ + Result: []*schema.Message{ {}, }, - AgenticTemplates: []schema.AgenticMessagesTemplate{ - &schema.AgenticMessage{}, + Templates: []schema.MessagesTemplate{ + &schema.Message{}, }, })) assert.NotNil(t, ConvCallbackOutput([]*schema.Message{})) - - agenticResult := []*schema.AgenticMessage{{}} - out := ConvCallbackOutput(agenticResult) - assert.NotNil(t, out) - assert.Equal(t, agenticResult, out.AgenticResult) - - assert.Nil(t, ConvCallbackOutput("asd")) } diff --git a/schema/agentic_message.go b/schema/agentic_message.go index a4554d38e..743f67855 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -59,7 +59,6 @@ const ( type AgenticRoleType string const ( - AgenticRoleTypeDeveloper AgenticRoleType = "developer" AgenticRoleTypeSystem AgenticRoleType = "system" AgenticRoleTypeUser AgenticRoleType = "user" AgenticRoleTypeAssistant AgenticRoleType = "assistant" @@ -426,14 +425,6 @@ type MCPToolApprovalResponse struct { Reason string } -// DeveloperAgenticMessage represents a message with AgenticRoleType "developer". -func DeveloperAgenticMessage(text string) *AgenticMessage { - return &AgenticMessage{ - Role: AgenticRoleTypeDeveloper, - ContentBlocks: []*ContentBlock{NewContentBlock(&UserInputText{Text: text})}, - } -} - // SystemAgenticMessage represents a message with AgenticRoleType "system". func SystemAgenticMessage(text string) *AgenticMessage { return &AgenticMessage{ diff --git a/schema/agentic_message_test.go b/schema/agentic_message_test.go index 4beb74930..144c0077e 100644 --- a/schema/agentic_message_test.go +++ b/schema/agentic_message_test.go @@ -1544,15 +1544,6 @@ response_meta: }) } -func TestDeveloperAgenticMessage(t *testing.T) { - t.Run("basic", func(t *testing.T) { - msg := DeveloperAgenticMessage("developer") - assert.Equal(t, AgenticRoleTypeDeveloper, msg.Role) - assert.Len(t, msg.ContentBlocks, 1) - assert.Equal(t, "developer", msg.ContentBlocks[0].UserInputText.Text) - }) -} - func TestSystemAgenticMessage(t *testing.T) { t.Run("basic", func(t *testing.T) { msg := SystemAgenticMessage("system") diff --git a/schema/tool.go b/schema/tool.go index a067d87db..c195d1f9e 100644 --- a/schema/tool.go +++ b/schema/tool.go @@ -83,13 +83,15 @@ type AgenticForcedToolChoice struct { } // AllowedTool represents a tool that the model is allowed or forced to call. -// Exactly one of FunctionToolName, MCPTool, or ServerTool must be specified. +// Exactly one of FunctionName, MCPTool, or ServerTool must be specified. type AllowedTool struct { - // FunctionToolName is the name of the function tool. - FunctionToolName string + // FunctionName specifies a function tool by name. + FunctionName string + // MCPTool specifies an MCP tool. MCPTool *AllowedMCPTool + // ServerTool specifies a server tool. ServerTool *AllowedServerTool } From 1bd56ce35f455b994f28082105d67569254a115e Mon Sep 17 00:00:00 2001 From: mrh997 Date: Thu, 15 Jan 2026 21:49:22 +0800 Subject: [PATCH 25/59] feat: improve callback definition (#692) --- components/model/agentic_callback_extra.go | 12 ++++++------ components/prompt/agentic_callback_extra.go | 14 +++++++------- components/prompt/agentic_callback_extra_test.go | 6 +++--- components/prompt/agentic_chat_template.go | 8 ++++---- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/components/model/agentic_callback_extra.go b/components/model/agentic_callback_extra.go index 28dd366e6..54d49ff72 100644 --- a/components/model/agentic_callback_extra.go +++ b/components/model/agentic_callback_extra.go @@ -33,8 +33,8 @@ type AgenticConfig struct { // AgenticCallbackInput is the input for the agentic model callback. type AgenticCallbackInput struct { - // AgenticMessages is the agentic messages to be sent to the agentic model. - AgenticMessages []*schema.AgenticMessage + // Messages is the agentic messages to be sent to the agentic model. + Messages []*schema.AgenticMessage // Tools is the tools to be used in the agentic model. Tools []*schema.ToolInfo // Config is the config for the agentic model. @@ -45,8 +45,8 @@ type AgenticCallbackInput struct { // AgenticCallbackOutput is the output for the agentic model callback. type AgenticCallbackOutput struct { - // AgenticMessage is the agentic message generated by the agentic model. - AgenticMessage *schema.AgenticMessage + // Message is the agentic message generated by the agentic model. + Message *schema.AgenticMessage // Config is the config for the agentic model. Config *AgenticConfig // TokenUsage is the token usage of this request. @@ -66,7 +66,7 @@ func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput // when callback is injected by graph node, not the component implementation itself, // the input is the input of Agentic Model interface, which is []*schema.AgenticMessage return &AgenticCallbackInput{ - AgenticMessages: t, + Messages: t, } default: return nil @@ -84,7 +84,7 @@ func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOut // when callback is injected by graph node, not the component implementation itself, // the output is the output of Agentic Model interface, which is *schema.AgenticMessage return &AgenticCallbackOutput{ - AgenticMessage: t, + Message: t, } default: return nil diff --git a/components/prompt/agentic_callback_extra.go b/components/prompt/agentic_callback_extra.go index 1170854a1..315d5a4da 100644 --- a/components/prompt/agentic_callback_extra.go +++ b/components/prompt/agentic_callback_extra.go @@ -25,18 +25,18 @@ import ( type AgenticCallbackInput struct { // Variables is the variables for the callback. Variables map[string]any - // AgenticTemplates is the agentic templates for the callback. - AgenticTemplates []schema.AgenticMessagesTemplate + // Templates is the agentic templates for the callback. + Templates []schema.AgenticMessagesTemplate // Extra is the extra information for the callback. Extra map[string]any } // AgenticCallbackOutput is the output for the callback. type AgenticCallbackOutput struct { - // AgenticResult is the agentic result for the callback. - AgenticResult []*schema.AgenticMessage - // AgenticTemplates is the agentic templates for the callback. - AgenticTemplates []schema.AgenticMessagesTemplate + // Result is the agentic result for the callback. + Result []*schema.AgenticMessage + // Templates is the agentic templates for the callback. + Templates []schema.AgenticMessagesTemplate // Extra is the extra information for the callback. Extra map[string]any } @@ -62,7 +62,7 @@ func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOut return t case []*schema.AgenticMessage: return &AgenticCallbackOutput{ - AgenticResult: t, + Result: t, } default: return nil diff --git a/components/prompt/agentic_callback_extra_test.go b/components/prompt/agentic_callback_extra_test.go index 6dda1a349..67982be80 100644 --- a/components/prompt/agentic_callback_extra_test.go +++ b/components/prompt/agentic_callback_extra_test.go @@ -27,7 +27,7 @@ import ( func TestConvAgenticPrompt(t *testing.T) { assert.NotNil(t, ConvAgenticCallbackInput(&AgenticCallbackInput{ Variables: map[string]any{}, - AgenticTemplates: []schema.AgenticMessagesTemplate{ + Templates: []schema.AgenticMessagesTemplate{ &schema.AgenticMessage{}, }, })) @@ -35,10 +35,10 @@ func TestConvAgenticPrompt(t *testing.T) { assert.Nil(t, ConvAgenticCallbackInput("asd")) assert.NotNil(t, ConvAgenticCallbackOutput(&AgenticCallbackOutput{ - AgenticResult: []*schema.AgenticMessage{ + Result: []*schema.AgenticMessage{ {}, }, - AgenticTemplates: []schema.AgenticMessagesTemplate{ + Templates: []schema.AgenticMessagesTemplate{ &schema.AgenticMessage{}, }, })) diff --git a/components/prompt/agentic_chat_template.go b/components/prompt/agentic_chat_template.go index c6c300d5a..41d291065 100644 --- a/components/prompt/agentic_chat_template.go +++ b/components/prompt/agentic_chat_template.go @@ -46,8 +46,8 @@ type DefaultAgenticChatTemplate struct { func (t *DefaultAgenticChatTemplate) Format(ctx context.Context, vs map[string]any, opts ...Option) (result []*schema.AgenticMessage, err error) { ctx = callbacks.EnsureRunInfo(ctx, t.GetType(), components.ComponentOfAgenticPrompt) ctx = callbacks.OnStart(ctx, &AgenticCallbackInput{ - Variables: vs, - AgenticTemplates: t.templates, + Variables: vs, + Templates: t.templates, }) defer func() { if err != nil { @@ -66,8 +66,8 @@ func (t *DefaultAgenticChatTemplate) Format(ctx context.Context, vs map[string]a } _ = callbacks.OnEnd(ctx, &AgenticCallbackOutput{ - AgenticResult: result, - AgenticTemplates: t.templates, + Result: result, + Templates: t.templates, }) return result, nil From 896141b0853fe1241c07634d06fa8c5ab1bd66e5 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 19 Jan 2026 20:58:32 +0800 Subject: [PATCH 26/59] feat: improve callback definition (#702) --- schema/agentic_message.go | 61 +++++++------------------------- schema/agentic_message_test.go | 45 ++++++++--------------- schema/tool.go | 4 +++ utils/callbacks/template.go | 14 ++++---- utils/callbacks/template_test.go | 8 ++--- 5 files changed, 41 insertions(+), 91 deletions(-) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 743f67855..ead2d866d 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -270,19 +270,12 @@ type AssistantGenVideo struct { } type Reasoning struct { - // Summary is the reasoning content summary. - Summary []*ReasoningSummary - - // EncryptedContent is the encrypted reasoning content. - EncryptedContent string -} - -type ReasoningSummary struct { - // Index specifies the index position of this summary in the final Reasoning. - Index int - - // Text is the reasoning content summary. + // Text is either the thought summary or the raw reasoning text itself. Text string + + // Signature contains encrypted reasoning tokens. + // Required by some models when passing reasoning text back. + Signature string } type FunctionToolCall struct { @@ -1172,42 +1165,15 @@ func concatReasoning(reasons []*Reasoning) (*Reasoning, error) { ret := &Reasoning{} - var allSummaries []*ReasoningSummary for _, r := range reasons { - if r == nil { - continue + if r.Text != "" { + ret.Text += r.Text } - allSummaries = append(allSummaries, r.Summary...) - if r.EncryptedContent != "" { - ret.EncryptedContent += r.EncryptedContent + if r.Signature != "" { + ret.Signature += r.Signature } } - var ( - indices []int - indexToSummary = map[int]*ReasoningSummary{} - ) - - for _, s := range allSummaries { - if s == nil { - continue - } - if indexToSummary[s.Index] == nil { - indexToSummary[s.Index] = &ReasoningSummary{} - indices = append(indices, s.Index) - } - indexToSummary[s.Index].Text += s.Text - } - - sort.Slice(indices, func(i, j int) bool { - return indices[i] < indices[j] - }) - - ret.Summary = make([]*ReasoningSummary, 0, len(indices)) - for _, idx := range indices { - ret.Summary = append(ret.Summary, indexToSummary[idx]) - } - return ret, nil } @@ -1899,12 +1865,9 @@ func (b *ContentBlock) String() string { // String returns the string representation of Reasoning. func (r *Reasoning) String() string { sb := &strings.Builder{} - sb.WriteString(fmt.Sprintf(" summary: %d items\n", len(r.Summary))) - for _, s := range r.Summary { - sb.WriteString(fmt.Sprintf(" [%d] %s\n", s.Index, s.Text)) - } - if r.EncryptedContent != "" { - sb.WriteString(fmt.Sprintf(" encrypted_content: %s\n", truncateString(r.EncryptedContent, 50))) + sb.WriteString(fmt.Sprintf(" text: %s\n", r.Text)) + if r.Signature != "" { + sb.WriteString(fmt.Sprintf(" signature: %s\n", truncateString(r.Signature, 50))) } return sb.String() } diff --git a/schema/agentic_message_test.go b/schema/agentic_message_test.go index 144c0077e..e8a1003f5 100644 --- a/schema/agentic_message_test.go +++ b/schema/agentic_message_test.go @@ -109,9 +109,7 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeReasoning, Reasoning: &Reasoning{ - Summary: []*ReasoningSummary{ - {Index: 0, Text: "First "}, - }, + Text: "First ", }, StreamingMeta: &StreamingMeta{Index: 0}, }, @@ -123,9 +121,7 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeReasoning, Reasoning: &Reasoning{ - Summary: []*ReasoningSummary{ - {Index: 0, Text: "Second"}, - }, + Text: "Second", }, StreamingMeta: &StreamingMeta{Index: 0}, }, @@ -136,9 +132,7 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - assert.Len(t, result.ContentBlocks[0].Reasoning.Summary, 1) - assert.Equal(t, "First Second", result.ContentBlocks[0].Reasoning.Summary[0].Text) - assert.Equal(t, 0, result.ContentBlocks[0].Reasoning.Summary[0].Index) + assert.Equal(t, "First Second", result.ContentBlocks[0].Reasoning.Text) }) t.Run("concat reasoning with index", func(t *testing.T) { @@ -149,10 +143,7 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeReasoning, Reasoning: &Reasoning{ - Summary: []*ReasoningSummary{ - {Index: 0, Text: "Part1-"}, - {Index: 1, Text: "Part2-"}, - }, + Text: "Part1-", }, StreamingMeta: &StreamingMeta{Index: 0}, }, @@ -164,10 +155,7 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeReasoning, Reasoning: &Reasoning{ - Summary: []*ReasoningSummary{ - {Index: 0, Text: "Part3"}, - {Index: 1, Text: "Part4"}, - }, + Text: "Part3", }, StreamingMeta: &StreamingMeta{Index: 0}, }, @@ -178,9 +166,7 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - assert.Len(t, result.ContentBlocks[0].Reasoning.Summary, 2) - assert.Equal(t, "Part1-Part3", result.ContentBlocks[0].Reasoning.Summary[0].Text) - assert.Equal(t, "Part2-Part4", result.ContentBlocks[0].Reasoning.Summary[1].Text) + assert.Equal(t, "Part1-Part3", result.ContentBlocks[0].Reasoning.Text) }) t.Run("concat user input text", func(t *testing.T) { @@ -1292,12 +1278,10 @@ func TestAgenticMessageString(t *testing.T) { { Type: ContentBlockTypeReasoning, Reasoning: &Reasoning{ - Summary: []*ReasoningSummary{ - {Index: 0, Text: "First, I need to identify the location (New York City) from the user's query."}, - {Index: 1, Text: "Then, I should call the weather API to get current conditions."}, - {Index: 2, Text: "Finally, I'll format the response in a user-friendly way with temperature and conditions."}, - }, - EncryptedContent: "encrypted_reasoning_content_that_is_very_long_and_will_be_truncated_for_display", + Text: "First, I need to identify the location (New York City) from the user's query.\n" + + "Then, I should call the weather API to get current conditions.\n" + + "Finally, I'll format the response in a user-friendly way with temperature and conditions.", + Signature: "encrypted_reasoning_content_that_is_very_long_and_will_be_truncated_for_display", }, }, { @@ -1432,11 +1416,10 @@ content_blocks: base64_data: gen_video_data... (14 bytes) mime_type: video/mp4 [9] type: reasoning - summary: 3 items - [0] First, I need to identify the location (New York City) from the user's query. - [1] Then, I should call the weather API to get current conditions. - [2] Finally, I'll format the response in a user-friendly way with temperature and conditions. - encrypted_content: encrypted_reasoning_content_that_is_very_long_and_... + text: First, I need to identify the location (New York City) from the user's query. +Then, I should call the weather API to get current conditions. +Finally, I'll format the response in a user-friendly way with temperature and conditions. + signature: encrypted_reasoning_content_that_is_very_long_and_... [10] type: function_tool_call call_id: call_weather_123 name: get_current_weather diff --git a/schema/tool.go b/schema/tool.go index c195d1f9e..a49306047 100644 --- a/schema/tool.go +++ b/schema/tool.go @@ -62,9 +62,13 @@ const ( type AgenticToolChoice struct { // Type is the tool choice mode. Type ToolChoice + // Allowed optionally specifies the list of tools that the model is permitted to call. + // Optional. Allowed *AgenticAllowedToolChoice + // Forced optionally specifies the list of tools that the model is required to call. + // Optional. Forced *AgenticForcedToolChoice } diff --git a/utils/callbacks/template.go b/utils/callbacks/template.go index 4c73e6bbc..4c2c709da 100644 --- a/utils/callbacks/template.go +++ b/utils/callbacks/template.go @@ -187,7 +187,7 @@ func (c *handlerTemplate) OnStart(ctx context.Context, info *callbacks.RunInfo, case components.ComponentOfChatModel: return c.chatModelHandler.OnStart(ctx, info, model.ConvCallbackInput(input)) case components.ComponentOfAgenticModel: - return c.agenticModelHandler.OnStart(ctx, info, model.ConvCallbackInput(input)) + return c.agenticModelHandler.OnStart(ctx, info, model.ConvAgenticCallbackInput(input)) case components.ComponentOfEmbedding: return c.embeddingHandler.OnStart(ctx, info, embedding.ConvCallbackInput(input)) case components.ComponentOfIndexer: @@ -226,7 +226,7 @@ func (c *handlerTemplate) OnEnd(ctx context.Context, info *callbacks.RunInfo, ou case components.ComponentOfChatModel: return c.chatModelHandler.OnEnd(ctx, info, model.ConvCallbackOutput(output)) case components.ComponentOfAgenticModel: - return c.agenticModelHandler.OnEnd(ctx, info, model.ConvCallbackOutput(output)) + return c.agenticModelHandler.OnEnd(ctx, info, model.ConvAgenticCallbackOutput(output)) case components.ComponentOfEmbedding: return c.embeddingHandler.OnEnd(ctx, info, embedding.ConvCallbackOutput(output)) case components.ComponentOfIndexer: @@ -316,8 +316,8 @@ func (c *handlerTemplate) OnEndWithStreamOutput(ctx context.Context, info *callb })) case components.ComponentOfAgenticModel: return c.agenticModelHandler.OnEndWithStreamOutput(ctx, info, - schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*model.CallbackOutput, error) { - return model.ConvCallbackOutput(item), nil + schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*model.AgenticCallbackOutput, error) { + return model.ConvAgenticCallbackOutput(item), nil })) case components.ComponentOfTool: return c.toolHandler.OnEndWithStreamOutput(ctx, info, @@ -686,9 +686,9 @@ func (ch *AgenticPromptCallbackHandler) Needed(ctx context.Context, runInfo *cal // AgenticModelCallbackHandler is the handler for the agentic chat model callback. type AgenticModelCallbackHandler struct { - OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context - OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context - OnEndWithStreamOutput func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.AgenticCallbackInput) context.Context + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.AgenticCallbackOutput) context.Context + OnEndWithStreamOutput func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.AgenticCallbackOutput]) context.Context OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context } diff --git a/utils/callbacks/template_test.go b/utils/callbacks/template_test.go index f599e5300..dcc0e5c7f 100644 --- a/utils/callbacks/template_test.go +++ b/utils/callbacks/template_test.go @@ -143,15 +143,15 @@ func TestNewComponentTemplate(t *testing.T) { return ctx }).Build()). AgenticModel(&AgenticModelCallbackHandler{ - OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context { + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.AgenticCallbackInput) context.Context { cnt++ return ctx }, - OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context { + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.AgenticCallbackOutput) context.Context { cnt++ return ctx }, - OnEndWithStreamOutput: func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context { + OnEndWithStreamOutput: func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.AgenticCallbackOutput]) context.Context { output.Close() cnt++ return ctx @@ -485,7 +485,7 @@ func TestNewComponentTemplate(t *testing.T) { // Set it now tpl2.AgenticModel(&AgenticModelCallbackHandler{ - OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context { + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.AgenticCallbackInput) context.Context { return ctx }, }) From f3b404576a673e21028d14353a8f895131415258 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 19 Jan 2026 22:12:10 +0800 Subject: [PATCH 27/59] feat: agentic model support MaxTokens (#703) --- components/model/agentic_callback_extra.go | 2 ++ components/model/option.go | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/components/model/agentic_callback_extra.go b/components/model/agentic_callback_extra.go index 54d49ff72..9a769cf7e 100644 --- a/components/model/agentic_callback_extra.go +++ b/components/model/agentic_callback_extra.go @@ -25,6 +25,8 @@ import ( type AgenticConfig struct { // Model is the model name. Model string + // MaxTokens is the max number of output tokens, if reached the max tokens, the model will stop generating. + MaxTokens int // Temperature is the temperature, which controls the randomness of the agentic model. Temperature float32 // TopP is the top p, which controls the diversity of the agentic model. diff --git a/components/model/option.go b/components/model/option.go index a337b7af2..a46b71b19 100644 --- a/components/model/option.go +++ b/components/model/option.go @@ -28,13 +28,13 @@ type Options struct { TopP *float32 // Tools is a list of tools the model may call. Tools []*schema.ToolInfo + // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return a finish reason of "length". + MaxTokens *int // Options only available for chat model. // ToolChoice controls which tool is called by the model. ToolChoice *schema.ToolChoice - // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return an finish reason of "length". - MaxTokens *int // AllowedToolNames specifies a list of tool names that the model is allowed to call. // This allows for constraining the model to a specific subset of the available tools. AllowedToolNames []string From 87b4a9914719af13fc8161764b42fc146fc04134 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Tue, 20 Jan 2026 13:21:40 +0800 Subject: [PATCH 28/59] feat: agentic model support stop option --- components/model/option.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/components/model/option.go b/components/model/option.go index a46b71b19..936b0fbda 100644 --- a/components/model/option.go +++ b/components/model/option.go @@ -30,6 +30,8 @@ type Options struct { Tools []*schema.ToolInfo // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return a finish reason of "length". MaxTokens *int + // Stop is the stop words for the model, which controls the stopping condition of the model. + Stop []string // Options only available for chat model. @@ -38,8 +40,6 @@ type Options struct { // AllowedToolNames specifies a list of tool names that the model is allowed to call. // This allows for constraining the model to a specific subset of the available tools. AllowedToolNames []string - // Stop is the stop words for the model, which controls the stopping condition of the model. - Stop []string // Options only available for agentic model. @@ -67,7 +67,6 @@ func WithTemperature(temperature float32) Option { } // WithMaxTokens is the option to set the max tokens for the model. -// Only available for ChatModel. func WithMaxTokens(maxTokens int) Option { return Option{ apply: func(opts *Options) { @@ -95,7 +94,6 @@ func WithTopP(topP float32) Option { } // WithStop is the option to set the stop words for the model. -// Only available for ChatModel. func WithStop(stop []string) Option { return Option{ apply: func(opts *Options) { From 35333ffeae51e083f2959018b82f796ce7720d02 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Fri, 13 Mar 2026 10:51:24 +0800 Subject: [PATCH 29/59] feat: add json tag for agentic message (#880) --- schema/agentic_message.go | 198 ++++++++++++++++++------------------ utils/callbacks/template.go | 2 +- 2 files changed, 100 insertions(+), 100 deletions(-) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index ead2d866d..95e14c0df 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -66,356 +66,356 @@ const ( type AgenticMessage struct { // Role is the message role. - Role AgenticRoleType + Role AgenticRoleType `json:"role"` // ContentBlocks is the list of content blocks. - ContentBlocks []*ContentBlock + ContentBlocks []*ContentBlock `json:"content_blocks,omitempty"` // ResponseMeta is the response metadata. - ResponseMeta *AgenticResponseMeta + ResponseMeta *AgenticResponseMeta `json:"response_meta,omitempty"` // Extra is the additional information. - Extra map[string]any + Extra map[string]any `json:"extra,omitempty"` } type AgenticResponseMeta struct { // TokenUsage is the token usage. - TokenUsage *TokenUsage + TokenUsage *TokenUsage `json:"token_usage,omitempty"` // OpenAIExtension is the extension for OpenAI. - OpenAIExtension *openai.ResponseMetaExtension + OpenAIExtension *openai.ResponseMetaExtension `json:"openai_extension,omitempty"` // GeminiExtension is the extension for Gemini. - GeminiExtension *gemini.ResponseMetaExtension + GeminiExtension *gemini.ResponseMetaExtension `json:"gemini_extension,omitempty"` // ClaudeExtension is the extension for Claude. - ClaudeExtension *claude.ResponseMetaExtension + ClaudeExtension *claude.ResponseMetaExtension `json:"claude_extension,omitempty"` // Extension is the extension for other models, supplied by the component implementer. - Extension any + Extension any `json:"extension,omitempty"` } type ContentBlock struct { - Type ContentBlockType + Type ContentBlockType `json:"type"` // Reasoning contains the reasoning content generated by the model. - Reasoning *Reasoning + Reasoning *Reasoning `json:"reasoning,omitempty"` // UserInputText contains the text content provided by the user. - UserInputText *UserInputText + UserInputText *UserInputText `json:"user_input_text,omitempty"` // UserInputImage contains the image content provided by the user. - UserInputImage *UserInputImage + UserInputImage *UserInputImage `json:"user_input_image,omitempty"` // UserInputAudio contains the audio content provided by the user. - UserInputAudio *UserInputAudio + UserInputAudio *UserInputAudio `json:"user_input_audio,omitempty"` // UserInputVideo contains the video content provided by the user. - UserInputVideo *UserInputVideo + UserInputVideo *UserInputVideo `json:"user_input_video,omitempty"` // UserInputFile contains the file content provided by the user. - UserInputFile *UserInputFile + UserInputFile *UserInputFile `json:"user_input_file,omitempty"` // AssistantGenText contains the text content generated by the model. - AssistantGenText *AssistantGenText + AssistantGenText *AssistantGenText `json:"assistant_gen_text,omitempty"` // AssistantGenImage contains the image content generated by the model. - AssistantGenImage *AssistantGenImage + AssistantGenImage *AssistantGenImage `json:"assistant_gen_image,omitempty"` // AssistantGenAudio contains the audio content generated by the model. - AssistantGenAudio *AssistantGenAudio + AssistantGenAudio *AssistantGenAudio `json:"assistant_gen_audio,omitempty"` // AssistantGenVideo contains the video content generated by the model. - AssistantGenVideo *AssistantGenVideo + AssistantGenVideo *AssistantGenVideo `json:"assistant_gen_video,omitempty"` // FunctionToolCall contains the invocation details for a user-defined tool. - FunctionToolCall *FunctionToolCall + FunctionToolCall *FunctionToolCall `json:"function_tool_call,omitempty"` // FunctionToolResult contains the result returned from a user-defined tool call. - FunctionToolResult *FunctionToolResult + FunctionToolResult *FunctionToolResult `json:"function_tool_result,omitempty"` // ServerToolCall contains the invocation details for a provider built-in tool executed on the model server. - ServerToolCall *ServerToolCall + ServerToolCall *ServerToolCall `json:"server_tool_call,omitempty"` // ServerToolResult contains the result returned from a provider built-in tool executed on the model server. - ServerToolResult *ServerToolResult + ServerToolResult *ServerToolResult `json:"server_tool_result,omitempty"` // MCPToolCall contains the invocation details for an MCP tool managed by the model server. - MCPToolCall *MCPToolCall + MCPToolCall *MCPToolCall `json:"mcp_tool_call,omitempty"` // MCPToolResult contains the result returned from an MCP tool managed by the model server. - MCPToolResult *MCPToolResult + MCPToolResult *MCPToolResult `json:"mcp_tool_result,omitempty"` // MCPListToolsResult contains the list of available MCP tools reported by the model server. - MCPListToolsResult *MCPListToolsResult + MCPListToolsResult *MCPListToolsResult `json:"mcp_list_tools_result,omitempty"` // MCPToolApprovalRequest contains the user approval request for an MCP tool call when required. - MCPToolApprovalRequest *MCPToolApprovalRequest + MCPToolApprovalRequest *MCPToolApprovalRequest `json:"mcp_tool_approval_request,omitempty"` // MCPToolApprovalResponse contains the user's approval decision for an MCP tool call. - MCPToolApprovalResponse *MCPToolApprovalResponse + MCPToolApprovalResponse *MCPToolApprovalResponse `json:"mcp_tool_approval_response,omitempty"` // StreamingMeta contains metadata for streaming responses. - StreamingMeta *StreamingMeta + StreamingMeta *StreamingMeta `json:"streaming_meta,omitempty"` // Extra contains additional information for the content block. - Extra map[string]any + Extra map[string]any `json:"extra,omitempty"` } type StreamingMeta struct { // Index specifies the index position of this block in the final response. - Index int + Index int `json:"index"` } type UserInputText struct { // Text is the text content. - Text string + Text string `json:"text,omitempty"` } type UserInputImage struct { // URL is the HTTP/HTTPS link. - URL string + URL string `json:"url,omitempty"` // Base64Data is the binary data in Base64 encoded string format. - Base64Data string + Base64Data string `json:"base64_data,omitempty"` // MIMEType is the mime type, e.g. "image/png". - MIMEType string + MIMEType string `json:"mime_type,omitempty"` // Detail is the quality of the image url. - Detail ImageURLDetail + Detail ImageURLDetail `json:"detail,omitempty"` } type UserInputAudio struct { // URL is the HTTP/HTTPS link. - URL string + URL string `json:"url,omitempty"` // Base64Data is the binary data in Base64 encoded string format. - Base64Data string + Base64Data string `json:"base64_data,omitempty"` // MIMEType is the mime type, e.g. "audio/wav". - MIMEType string + MIMEType string `json:"mime_type,omitempty"` } type UserInputVideo struct { // URL is the HTTP/HTTPS link. - URL string + URL string `json:"url,omitempty"` // Base64Data is the binary data in Base64 encoded string format. - Base64Data string + Base64Data string `json:"base64_data,omitempty"` // MIMEType is the mime type, e.g. "video/mp4". - MIMEType string + MIMEType string `json:"mime_type,omitempty"` } type UserInputFile struct { // URL is the HTTP/HTTPS link. - URL string + URL string `json:"url,omitempty"` // Name is the filename. - Name string + Name string `json:"name,omitempty"` // Base64Data is the binary data in Base64 encoded string format. - Base64Data string + Base64Data string `json:"base64_data,omitempty"` // MIMEType is the mime type, e.g. "application/pdf". - MIMEType string + MIMEType string `json:"mime_type,omitempty"` } type AssistantGenText struct { // Text is the generated text. - Text string + Text string `json:"text,omitempty"` // OpenAIExtension is the extension for OpenAI. - OpenAIExtension *openai.AssistantGenTextExtension + OpenAIExtension *openai.AssistantGenTextExtension `json:"openai_extension,omitempty"` // ClaudeExtension is the extension for Claude. - ClaudeExtension *claude.AssistantGenTextExtension + ClaudeExtension *claude.AssistantGenTextExtension `json:"claude_extension,omitempty"` // Extension is the extension for other models, supplied by the component implementer. - Extension any + Extension any `json:"extension,omitempty"` } type AssistantGenImage struct { // URL is the HTTP/HTTPS link. - URL string + URL string `json:"url,omitempty"` // Base64Data is the binary data in Base64 encoded string format. - Base64Data string + Base64Data string `json:"base64_data,omitempty"` // MIMEType is the mime type, e.g. "image/png". - MIMEType string + MIMEType string `json:"mime_type,omitempty"` } type AssistantGenAudio struct { // URL is the HTTP/HTTPS link. - URL string + URL string `json:"url,omitempty"` // Base64Data is the binary data in Base64 encoded string format. - Base64Data string + Base64Data string `json:"base64_data,omitempty"` // MIMEType is the mime type, e.g. "audio/wav". - MIMEType string + MIMEType string `json:"mime_type,omitempty"` } type AssistantGenVideo struct { // URL is the HTTP/HTTPS link. - URL string + URL string `json:"url,omitempty"` // Base64Data is the binary data in Base64 encoded string format. - Base64Data string + Base64Data string `json:"base64_data,omitempty"` // MIMEType is the mime type, e.g. "video/mp4". - MIMEType string + MIMEType string `json:"mime_type,omitempty"` } type Reasoning struct { // Text is either the thought summary or the raw reasoning text itself. - Text string + Text string `json:"text,omitempty"` // Signature contains encrypted reasoning tokens. // Required by some models when passing reasoning text back. - Signature string + Signature string `json:"signature,omitempty"` } type FunctionToolCall struct { // CallID is the unique identifier for the tool call. - CallID string + CallID string `json:"call_id,omitempty"` // Name specifies the function tool invoked. - Name string + Name string `json:"name"` // Arguments is the JSON string arguments for the function tool call. - Arguments string + Arguments string `json:"arguments,omitempty"` } type FunctionToolResult struct { // CallID is the unique identifier for the tool call. - CallID string + CallID string `json:"call_id,omitempty"` // Name specifies the function tool invoked. - Name string + Name string `json:"name"` // Result is the function tool result returned by the user - Result string + Result string `json:"result,omitempty"` } type ServerToolCall struct { // Name specifies the server-side tool invoked. // Supplied by the model server (e.g., `web_search` for OpenAI, `googleSearch` for Gemini). - Name string + Name string `json:"name"` // CallID is the unique identifier for the tool call. // Empty if not provided by the model server. - CallID string + CallID string `json:"call_id,omitempty"` // Arguments are the raw inputs to the server-side tool, // supplied by the component implementer. - Arguments any + Arguments any `json:"arguments,omitempty"` } type ServerToolResult struct { // Name specifies the server-side tool invoked. // Supplied by the model server (e.g., `web_search` for OpenAI, `googleSearch` for Gemini). - Name string + Name string `json:"name"` // CallID is the unique identifier for the tool call. // Empty if not provided by the model server. - CallID string + CallID string `json:"call_id,omitempty"` // Result refers to the raw output generated by the server-side tool, // supplied by the component implementer. - Result any + Result any `json:"result,omitempty"` } type MCPToolCall struct { // ServerLabel is the MCP server label used to identify it in tool calls - ServerLabel string + ServerLabel string `json:"server_label,omitempty"` // ApprovalRequestID is the approval request ID. - ApprovalRequestID string + ApprovalRequestID string `json:"approval_request_id,omitempty"` // CallID is the unique ID of the tool call. - CallID string + CallID string `json:"call_id,omitempty"` // Name is the name of the tool to run. - Name string + Name string `json:"name"` // Arguments is the JSON string arguments for the tool call. - Arguments string + Arguments string `json:"arguments,omitempty"` } type MCPToolResult struct { // ServerLabel is the MCP server label used to identify it in tool calls - ServerLabel string + ServerLabel string `json:"server_label,omitempty"` // CallID is the unique ID of the tool call. - CallID string + CallID string `json:"call_id,omitempty"` // Name is the name of the tool to run. - Name string + Name string `json:"name"` // Result is the JSON string with the tool result. - Result string + Result string `json:"result,omitempty"` // Error returned when the server fails to run the tool. - Error *MCPToolCallError + Error *MCPToolCallError `json:"error,omitempty"` } type MCPToolCallError struct { // Code is the error code. - Code *int64 + Code *int64 `json:"code,omitempty"` // Message is the error message. - Message string + Message string `json:"message,omitempty"` } type MCPListToolsResult struct { // ServerLabel is the MCP server label used to identify it in tool calls. - ServerLabel string + ServerLabel string `json:"server_label,omitempty"` // Tools is the list of tools available on the server. - Tools []*MCPListToolsItem + Tools []*MCPListToolsItem `json:"tools,omitempty"` // Error returned when the server fails to list tools. - Error string + Error string `json:"error,omitempty"` } type MCPListToolsItem struct { // Name is the name of the tool. - Name string + Name string `json:"name"` // Description is the description of the tool. - Description string + Description string `json:"description"` // InputSchema is the JSON schema that describes the tool input parameters. - InputSchema *jsonschema.Schema + InputSchema *jsonschema.Schema `json:"input_schema,omitempty"` } type MCPToolApprovalRequest struct { // ID is the approval request ID. - ID string + ID string `json:"id,omitempty"` // Name is the name of the tool to run. - Name string + Name string `json:"name"` // Arguments is the JSON string arguments for the tool call. - Arguments string + Arguments string `json:"arguments,omitempty"` // ServerLabel is the MCP server label used to identify it in tool calls. - ServerLabel string + ServerLabel string `json:"server_label,omitempty"` } type MCPToolApprovalResponse struct { // ApprovalRequestID is the approval request ID being responded to. - ApprovalRequestID string + ApprovalRequestID string `json:"approval_request_id,omitempty"` // Approve indicates whether the request is approved. - Approve bool + Approve bool `json:"approve"` // Reason is the rationale for the decision. // Optional. - Reason string + Reason string `json:"reason,omitempty"` } // SystemAgenticMessage represents a message with AgenticRoleType "system". diff --git a/utils/callbacks/template.go b/utils/callbacks/template.go index 4c2c709da..f01a849b6 100644 --- a/utils/callbacks/template.go +++ b/utils/callbacks/template.go @@ -64,10 +64,10 @@ type HandlerHelper struct { transformerHandler *TransformerCallbackHandler toolHandler *ToolCallbackHandler toolsNodeHandler *ToolsNodeCallbackHandlers + agentHandler *AgentCallbackHandler agenticPromptHandler *AgenticPromptCallbackHandler agenticModelHandler *AgenticModelCallbackHandler agenticToolsNodeHandler *AgenticToolsNodeCallbackHandlers - agentHandler *AgentCallbackHandler composeTemplates map[components.Component]callbacks.Handler } From 7189d7c56ee461fa8cec7645382b867f578b3e44 Mon Sep 17 00:00:00 2001 From: Ryo Date: Fri, 13 Mar 2026 14:09:00 +0800 Subject: [PATCH 30/59] =?UTF-8?q?feat(adk):=20add=20agentmd=20middleware?= =?UTF-8?q?=20for=20auto-injecting=20Agents.md=20into=20m=E2=80=A6=20(#882?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit feat(adk): add agentmd middleware for auto-injecting Agents.md into model input Change-Id: I34add4f925a23c6d6821925c482a21f6cddfddd4 --- adk/middlewares/agentsmd/agentsmd.go | 183 +++ adk/middlewares/agentsmd/agentsmd_test.go | 1420 +++++++++++++++++++++ adk/middlewares/agentsmd/loader.go | 299 +++++ 3 files changed, 1902 insertions(+) create mode 100644 adk/middlewares/agentsmd/agentsmd.go create mode 100644 adk/middlewares/agentsmd/agentsmd_test.go create mode 100644 adk/middlewares/agentsmd/loader.go diff --git a/adk/middlewares/agentsmd/agentsmd.go b/adk/middlewares/agentsmd/agentsmd.go new file mode 100644 index 000000000..7d29896a7 --- /dev/null +++ b/adk/middlewares/agentsmd/agentsmd.go @@ -0,0 +1,183 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package agentsmd provides a middleware that automatically injects Agents.md +// file contents into model input messages. The injection is transient — content +// is prepended at model call time and never persisted to conversation state, +// so it is naturally excluded from summarization / compression. +package agentsmd + +import ( + "context" + "fmt" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +// Config defines the configuration for the agentsmd middleware. +type Config struct { + // Backend provides file access for loading Agents.md files. + // Implementations can use local filesystem, remote storage, or any other backend. + // Required. + Backend Backend + + // AgentsMDFiles specifies the ordered list of Agents.md file paths to load. + // Files are loaded and injected in the given order. + // Supports @import syntax inside files for recursive inclusion (max depth 5). + AgentsMDFiles []string + + // AllAgentsMDMaxBytes limits the total byte size of all loaded Agents.md content. + // Files are loaded in order; once the cumulative size exceeds this limit, + // remaining files are skipped. Each individual file is always loaded in full. + // 0 means no limit. + AllAgentsMDMaxBytes int + + // OnLoadWarning is an optional callback invoked when a non-fatal error occurs + // during Agents.md file loading (e.g. file not found, circular @import, depth + // exceeded). If nil, warnings are logged via log.Printf. + // + // Note: Backend.Read errors other than os.ErrNotExist (e.g. permission denied, + // I/O errors) are NOT treated as warnings and will abort the loading process. + OnLoadWarning func(filePath string, err error) +} + +// New creates an agentsmd middleware that injects Agents.md content into every +// model call. The content is loaded from the configured file paths via Backend +// on each model invocation. +// +// Recommended: place this middleware AFTER the summarization middleware, so that +// Agents.md content is excluded from summarization/compression. +func New(_ context.Context, cfg *Config) (adk.ChatModelAgentMiddleware, error) { + if err := cfg.validate(); err != nil { + return nil, err + } + + return &middleware{ + BaseChatModelAgentMiddleware: &adk.BaseChatModelAgentMiddleware{}, + loader: newLoaderConfig(cfg.Backend, cfg.AgentsMDFiles, cfg.AllAgentsMDMaxBytes, cfg.OnLoadWarning), + }, nil +} + +type middleware struct { + *adk.BaseChatModelAgentMiddleware + loader *loaderConfig +} + +// WrapModel returns a proxy model that prepends Agents.md content to the input +// messages on every Generate/Stream call. The injected message is never written +// back to ChatModelAgentState, so summarization and reduction middlewares are +// unaffected. +func (m *middleware) WrapModel(_ context.Context, cm model.BaseChatModel, _ *adk.ModelContext) (model.BaseChatModel, error) { + return &agentMDModel{ + inner: cm, + loader: m.loader, + }, nil +} + +// agentMDModel wraps a BaseChatModel to prepend Agents.md content to input. +type agentMDModel struct { + inner model.BaseChatModel + loader *loaderConfig +} + +func (m *agentMDModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + messages, err := m.prependAgentMD(ctx, input) + if err != nil { + return nil, err + } + return m.inner.Generate(ctx, messages, opts...) +} + +func (m *agentMDModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + messages, err := m.prependAgentMD(ctx, input) + if err != nil { + return nil, err + } + return m.inner.Stream(ctx, messages, opts...) +} + +const agentsMDCacheKey = "__agentsmd_content_cache__" + +// prependAgentMD loads the current Agents.md content and inserts it before the +// first User role message. If all configured agent files are empty (or skipped), +// the original input is returned unchanged. +// The loaded content is cached in RunLocalValue for the duration of the agent Run(). +func (m *agentMDModel) prependAgentMD(ctx context.Context, input []*schema.Message) ([]*schema.Message, error) { + var content string + + // Try to get cached content from RunLocalValue. + if cached, found, err := adk.GetRunLocalValue(ctx, agentsMDCacheKey); err == nil && found { + if s, ok := cached.(string); ok { + content = s + } + } + + if content == "" { + var err error + content, err = m.loader.load(ctx) + if err != nil { + return nil, fmt.Errorf("[agentsmd]: failed to load agent files: %w", err) + } + // Cache the loaded content for subsequent model calls in this Run(). + if content != "" { + _ = adk.SetRunLocalValue(ctx, agentsMDCacheKey, content) + } + } + if content == "" { + return input, nil + } + + agentMDMsg := &schema.Message{ + Role: schema.User, + Content: content, + } + + // Insert agentMDMsg before the first User role message. + messages := make([]*schema.Message, 0, len(input)+1) + inserted := false + for i, msg := range input { + if !inserted && msg.Role == schema.User { + messages = append(messages, agentMDMsg) + messages = append(messages, input[i:]...) + inserted = true + break + } + messages = append(messages, msg) + } + if !inserted { + // No User message found; append at the end as fallback. + messages = append(messages, agentMDMsg) + } + return messages, nil +} + +func (c *Config) validate() error { + if c == nil { + return fmt.Errorf("[agentsmd]: config is required") + } + if c.Backend == nil { + return fmt.Errorf("[agentsmd]: backend is required") + } + if len(c.AgentsMDFiles) == 0 { + return fmt.Errorf("[agentsmd]: at least one agent file path is required") + } + if c.AllAgentsMDMaxBytes < 0 { + return fmt.Errorf("[agentsmd]: AllAgentMDDocsMaxBytes must be non-negative") + } + return nil +} diff --git a/adk/middlewares/agentsmd/agentsmd_test.go b/adk/middlewares/agentsmd/agentsmd_test.go new file mode 100644 index 000000000..e3d7e00e9 --- /dev/null +++ b/adk/middlewares/agentsmd/agentsmd_test.go @@ -0,0 +1,1420 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package agentsmd + +import ( + "context" + "fmt" + "os" + "strings" + "testing" + + "github.com/cloudwego/eino/adk/filesystem" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +// --- test helpers --- + +type mockModel struct { + lastInput []*schema.Message +} + +func (m *mockModel) Generate(_ context.Context, input []*schema.Message, _ ...model.Option) (*schema.Message, error) { + m.lastInput = input + return &schema.Message{Role: schema.Assistant, Content: "ok"}, nil +} + +func (m *mockModel) Stream(_ context.Context, input []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + m.lastInput = input + return nil, nil +} + +type memBackend struct { + files map[string]string +} + +func newMemBackend() *memBackend { + return &memBackend{files: make(map[string]string)} +} + +func (b *memBackend) set(path string, content string) { + b.files[path] = content +} + +func (b *memBackend) Read(_ context.Context, req *ReadRequest) (*filesystem.FileContent, error) { + content, ok := b.files[req.FilePath] + if !ok { + return nil, fmt.Errorf("file not found: %s: %w", req.FilePath, os.ErrNotExist) + } + return &filesystem.FileContent{Content: content}, nil +} + +// errBackend always returns a non-ErrNotExist error on Read, simulating I/O failures. +type errBackend struct{} + +func (b *errBackend) Read(_ context.Context, req *ReadRequest) (*filesystem.FileContent, error) { + return nil, fmt.Errorf("permission denied: %s", req.FilePath) +} + +// partialErrBackend returns content for known files and I/O error for others. +type partialErrBackend struct { + files map[string]string +} + +func (b *partialErrBackend) Read(_ context.Context, req *ReadRequest) (*filesystem.FileContent, error) { + content, ok := b.files[req.FilePath] + if !ok { + return nil, fmt.Errorf("I/O error reading %s", req.FilePath) + } + return &filesystem.FileContent{Content: content}, nil +} + +// --- tests --- + +func TestNew_Validation(t *testing.T) { + ctx := context.Background() + b := newMemBackend() + + _, err := New(ctx, nil) + if err == nil { + t.Fatal("expected error for nil config") + } + + _, err = New(ctx, &Config{}) + if err == nil { + t.Fatal("expected error for empty config") + } + + _, err = New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/test.md"}, AllAgentsMDMaxBytes: -1}) + if err == nil { + t.Fatal("expected error for negative max bytes") + } + + _, err = New(ctx, &Config{AgentsMDFiles: []string{"/test.md"}}) + if err == nil { + t.Fatal("expected error for nil backend") + } +} + +func TestMiddleware_BasicInjection(t *testing.T) { + b := newMemBackend() + b.set("/agent.md", "You are a helpful assistant.") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + userMsg := &schema.Message{Role: schema.User, Content: "hello"} + if _, err = wrapped.Generate(ctx, []*schema.Message{userMsg}); err != nil { + t.Fatal(err) + } + + if len(mock.lastInput) != 2 { + t.Fatalf("expected 2 messages, got %d", len(mock.lastInput)) + } + if mock.lastInput[0].Role != schema.User { + t.Fatalf("expected first message role User, got %s", mock.lastInput[0].Role) + } + if !strings.Contains(mock.lastInput[0].Content, "You are a helpful assistant.") { + t.Fatalf("expected agent.md content in first message, got %q", mock.lastInput[0].Content) + } + if !strings.Contains(mock.lastInput[0].Content, "") { + t.Fatalf("expected system-reminder tag, got %q", mock.lastInput[0].Content) + } + if mock.lastInput[1].Content != "hello" { + t.Fatalf("expected original message preserved, got %q", mock.lastInput[1].Content) + } +} + +func TestMiddleware_MultipleFiles(t *testing.T) { + b := newMemBackend() + b.set("/a.md", "instruction A") + b.set("/b.md", "instruction B") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/a.md", "/b.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + if _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}); err != nil { + t.Fatal(err) + } + + content := mock.lastInput[0].Content + idxA := strings.Index(content, "instruction A") + idxB := strings.Index(content, "instruction B") + if idxA < 0 || idxB < 0 { + t.Fatalf("both files should be included, content: %q", content) + } + if idxA >= idxB { + t.Fatal("file A should appear before file B") + } +} + +func TestMiddleware_ImportResolution(t *testing.T) { + b := newMemBackend() + b.set("/project/agent.md", "main instructions\n@sub/rules.md\nend") + b.set("/project/sub/rules.md", "imported rule") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/project/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + if _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}); err != nil { + t.Fatal(err) + } + + content := mock.lastInput[0].Content + // Original text should be preserved with @path intact. + if !strings.Contains(content, "main instructions") { + t.Fatalf("should contain original text, got %q", content) + } + if !strings.Contains(content, "@sub/rules.md") { + t.Fatalf("@import reference should be preserved in original text, got %q", content) + } + if !strings.Contains(content, "end") { + t.Fatalf("should contain original trailing text, got %q", content) + } + // Imported file should appear as a separate section. + if !strings.Contains(content, "Contents of /project/sub/rules.md") { + t.Fatalf("imported file should have its own section, got %q", content) + } + if !strings.Contains(content, "imported rule") { + t.Fatalf("imported file content should be present, got %q", content) + } +} + +func TestMiddleware_RecursiveImport(t *testing.T) { + b := newMemBackend() + b.set("/a.md", "top\n@/b.md") + b.set("/b.md", "middle\n@/c.md") + b.set("/c.md", "leaf content") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/a.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + if _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}); err != nil { + t.Fatal(err) + } + + content := mock.lastInput[0].Content + // All three files should appear as separate sections. + for _, section := range []string{"Contents of /a.md", "Contents of /b.md", "Contents of /c.md"} { + if !strings.Contains(content, section) { + t.Fatalf("expected section %q in content, got %q", section, content) + } + } + for _, text := range []string{"top", "middle", "leaf content"} { + if !strings.Contains(content, text) { + t.Fatalf("expected %q in content, got %q", text, content) + } + } + // Sections should appear in order: a, b, c. + idxA := strings.Index(content, "Contents of /a.md") + idxB := strings.Index(content, "Contents of /b.md") + idxC := strings.Index(content, "Contents of /c.md") + if !(idxA < idxB && idxB < idxC) { + t.Fatalf("sections should appear in order a < b < c, got a=%d b=%d c=%d", idxA, idxB, idxC) + } +} + +func TestMiddleware_MaxImportDepth(t *testing.T) { + b := newMemBackend() + for i := 0; i < 7; i++ { + var content string + if i < 6 { + content = fmt.Sprintf("level %d\n@/level%d.md", i, i+1) + } else { + content = fmt.Sprintf("level %d", i) + } + b.set(fmt.Sprintf("/level%d.md", i), content) + } + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/level0.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + // Import failure at depth > 5 is logged, not returned as error. + _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}) + if err != nil { + t.Fatalf("expected no error (depth exceeded is logged), got %v", err) + } + // Levels 0-5 should be present as sections; level 6 fails silently. + content := mock.lastInput[0].Content + for i := 0; i <= 5; i++ { + want := fmt.Sprintf("Contents of /level%d.md", i) + if !strings.Contains(content, want) { + t.Fatalf("expected %q in content, got %q", want, content) + } + } + if strings.Contains(content, "Contents of /level6.md") { + t.Fatalf("level6 should not be present (depth exceeded), got %q", content) + } +} + +func TestMiddleware_CircularImport(t *testing.T) { + b := newMemBackend() + b.set("/a.md", "@/b.md") + b.set("/b.md", "@/a.md") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/a.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + // Circular import failure is logged, not returned as error. + _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}) + if err != nil { + t.Fatalf("expected no error (circular import is logged), got %v", err) + } + // /a.md and /b.md should both be present; the circular ref from b->a is skipped. + content := mock.lastInput[0].Content + if !strings.Contains(content, "Contents of /a.md") { + t.Fatalf("expected /a.md section, got %q", content) + } + if !strings.Contains(content, "Contents of /b.md") { + t.Fatalf("expected /b.md section, got %q", content) + } +} + +func TestMiddleware_MaxBytesLimit(t *testing.T) { + b := newMemBackend() + b.set("/a.md", "AAAA") // 4 bytes + b.set("/b.md", "BBBB") // 4 bytes + + ctx := context.Background() + mw, err := New(ctx, &Config{ + Backend: b, + AgentsMDFiles: []string{"/a.md", "/b.md"}, + AllAgentsMDMaxBytes: 5, // file a (4) fits, file b (4) would exceed + }) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + if _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}); err != nil { + t.Fatal(err) + } + + content := mock.lastInput[0].Content + if !strings.Contains(content, "AAAA") { + t.Fatal("first file should be included") + } + if strings.Contains(content, "BBBB") { + t.Fatal("second file should be excluded due to max bytes") + } +} + +func TestMiddleware_NotPersistedInState(t *testing.T) { + b := newMemBackend() + b.set("/agent.md", "agent instructions") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + originalMsgs := []*schema.Message{{Role: schema.User, Content: "hello"}} + if _, err = wrapped.Generate(ctx, originalMsgs); err != nil { + t.Fatal(err) + } + + if len(originalMsgs) != 1 { + t.Fatalf("original messages should not be modified, got %d messages", len(originalMsgs)) + } + if originalMsgs[0].Content != "hello" { + t.Fatalf("original message should be unchanged, got %q", originalMsgs[0].Content) + } + if len(mock.lastInput) != 2 { + t.Fatalf("model should receive 2 messages, got %d", len(mock.lastInput)) + } +} + +func TestMiddleware_AbsoluteImportPath(t *testing.T) { + b := newMemBackend() + b.set("/project/main.md", "start\n@/shared/imported.md\nend") + b.set("/shared/imported.md", "absolute import content") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/project/main.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + if _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}); err != nil { + t.Fatal(err) + } + + content := mock.lastInput[0].Content + // @path preserved in original text. + if !strings.Contains(content, "@/shared/imported.md") { + t.Fatalf("@import reference should be preserved, got %q", content) + } + // Imported content in separate section. + if !strings.Contains(content, "Contents of /shared/imported.md") { + t.Fatalf("expected separate section for imported file, got %q", content) + } + if !strings.Contains(content, "absolute import content") { + t.Fatalf("expected absolute import content, got %q", content) + } +} + +func TestMiddleware_ImportAsSeparateSection(t *testing.T) { + b := newMemBackend() + b.set("/project/agent.md", "Please read @sub/rules.md and also @sub/style.md for guidance.") + b.set("/project/sub/rules.md", "RULE_CONTENT") + b.set("/project/sub/style.md", "STYLE_CONTENT") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/project/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + if _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}); err != nil { + t.Fatal(err) + } + + content := mock.lastInput[0].Content + // Original text preserved with @paths intact. + if !strings.Contains(content, "Please read @sub/rules.md and also @sub/style.md for guidance.") { + t.Fatalf("original text with @paths should be preserved, got %q", content) + } + // Imported files appear as separate sections. + if !strings.Contains(content, "Contents of /project/sub/rules.md") { + t.Fatalf("expected rules.md section, got %q", content) + } + if !strings.Contains(content, "RULE_CONTENT") { + t.Fatalf("expected imported rule content, got %q", content) + } + if !strings.Contains(content, "Contents of /project/sub/style.md") { + t.Fatalf("expected style.md section, got %q", content) + } + if !strings.Contains(content, "STYLE_CONTENT") { + t.Fatalf("expected imported style content, got %q", content) + } + + // Sections should be ordered: agent.md, rules.md, style.md. + idxAgent := strings.Index(content, "Contents of /project/agent.md") + idxRules := strings.Index(content, "Contents of /project/sub/rules.md") + idxStyle := strings.Index(content, "Contents of /project/sub/style.md") + if !(idxAgent < idxRules && idxRules < idxStyle) { + t.Fatalf("sections should appear in order agent < rules < style, got agent=%d rules=%d style=%d", idxAgent, idxRules, idxStyle) + } +} + +// --- loader-specific tests --- + +func TestLoader_NoImportsPassthrough(t *testing.T) { + // Content without any @path should be returned as-is in its section. + b := newMemBackend() + b.set("/agent.md", "plain text without imports\nline two") + + l := newLoaderConfig(b, []string{"/agent.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(content, "plain text without imports") { + t.Fatalf("expected plain content, got %q", content) + } + if !strings.Contains(content, "line two") { + t.Fatalf("expected second line, got %q", content) + } +} + +func TestLoader_ImportAsSeparateSection(t *testing.T) { + // @path in the middle of a sentence should be preserved; imported file is a separate section. + b := newMemBackend() + b.set("/doc.md", "before @/snippet.md after") + b.set("/snippet.md", "INJECTED") + + l := newLoaderConfig(b, []string{"/doc.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + // Original text preserved. + if !strings.Contains(content, "before @/snippet.md after") { + t.Fatalf("original text should be preserved with @path, got %q", content) + } + // Imported file in separate section. + if !strings.Contains(content, "Contents of /snippet.md") { + t.Fatalf("expected separate section for snippet.md, got %q", content) + } + if !strings.Contains(content, "INJECTED") { + t.Fatalf("expected imported content, got %q", content) + } +} + +func TestLoader_MultipleImportsSameLine(t *testing.T) { + // Multiple @path on one line should each get a separate section. + b := newMemBackend() + b.set("/doc.md", "see @/a.txt and @/b.txt here") + b.set("/a.txt", "AAA") + b.set("/b.txt", "BBB") + + l := newLoaderConfig(b, []string{"/doc.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + // Original text preserved. + if !strings.Contains(content, "see @/a.txt and @/b.txt here") { + t.Fatalf("original text should be preserved, got %q", content) + } + // Each imported file has its own section. + if !strings.Contains(content, "Contents of /a.txt") { + t.Fatalf("expected section for a.txt, got %q", content) + } + if !strings.Contains(content, "AAA") { + t.Fatalf("expected a.txt content, got %q", content) + } + if !strings.Contains(content, "Contents of /b.txt") { + t.Fatalf("expected section for b.txt, got %q", content) + } + if !strings.Contains(content, "BBB") { + t.Fatalf("expected b.txt content, got %q", content) + } +} + +func TestLoader_SameFileTwiceOnSameLine(t *testing.T) { + // The same file referenced twice should appear only once as a section (deduped). + b := newMemBackend() + b.set("/doc.md", "@/shared.md and @/shared.md again") + b.set("/shared.md", "SHARED") + + l := newLoaderConfig(b, []string{"/doc.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + // Original text preserved. + if !strings.Contains(content, "@/shared.md and @/shared.md again") { + t.Fatalf("original text should be preserved, got %q", content) + } + // shared.md content should appear only once (deduped). + count := strings.Count(content, "Contents of /shared.md") + if count != 1 { + t.Fatalf("expected shared.md section to appear once (deduped), got %d in %q", count, content) + } +} + +func TestLoader_ImportFileNotFound(t *testing.T) { + b := newMemBackend() + b.set("/doc.md", "load @/missing.md please") + + l := newLoaderConfig(b, []string{"/doc.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error (missing import is logged), got %v", err) + } + // Original text preserved; missing file simply has no section. + if !strings.Contains(content, "load @/missing.md please") { + t.Fatalf("expected original text preserved, got %q", content) + } + if strings.Contains(content, "Contents of /missing.md") { + t.Fatalf("missing file should not have a section, got %q", content) + } +} + +func TestLoader_RelativePathResolution(t *testing.T) { + // Relative path should resolve relative to the host file's directory. + b := newMemBackend() + b.set("/a/b/host.md", "ref @../c/target.md done") + b.set("/a/c/target.md", "TARGET") + + l := newLoaderConfig(b, []string{"/a/b/host.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + // Original text preserved. + if !strings.Contains(content, "ref @../c/target.md done") { + t.Fatalf("original text should be preserved, got %q", content) + } + // Imported file as separate section. + if !strings.Contains(content, "Contents of /a/c/target.md") { + t.Fatalf("expected section for target.md, got %q", content) + } + if !strings.Contains(content, "TARGET") { + t.Fatalf("expected imported content, got %q", content) + } +} + +func TestLoader_RelativeTopLevelPath(t *testing.T) { + // Top-level file uses relative path; imports with ./ resolve correctly. + b := newMemBackend() + b.set("sub/agents.md", "start @./other.md end") + b.set("sub/other.md", "OTHER CONTENT") + + l := newLoaderConfig(b, []string{"sub/agents.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(content, "start @./other.md end") { + t.Fatalf("expected original text preserved, got %q", content) + } + if !strings.Contains(content, "OTHER CONTENT") { + t.Fatalf("expected imported content, got %q", content) + } +} + +func TestLoader_RelativeTopLevelWithDotDotImport(t *testing.T) { + // Top-level file uses relative path; import with ../ resolves correctly. + b := newMemBackend() + b.set("sub/agents.md", "see @../shared/x.md here") + b.set("shared/x.md", "SHARED X") + + l := newLoaderConfig(b, []string{"sub/agents.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(content, "SHARED X") { + t.Fatalf("expected imported content, got %q", content) + } + // filepath.Clean should normalize "sub/../shared/x.md" to "shared/x.md" + if !strings.Contains(content, "Contents of shared/x.md") { + t.Fatalf("expected normalized path in section header, got %q", content) + } +} + +func TestLoader_RelativeTopLevelDedup(t *testing.T) { + // Two top-level relative paths that resolve to the same file via filepath.Clean + // should be deduped (loaded only once). + b := newMemBackend() + b.set("sub/a.md", "CONTENT A") + + l := newLoaderConfig(b, []string{"sub/a.md", "./sub/a.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + count := strings.Count(content, "CONTENT A") + if count != 1 { + t.Fatalf("expected file loaded once (deduped), got %d occurrences in %q", count, content) + } +} + +func TestLoader_AbsoluteTopLevelWithRelativeImport(t *testing.T) { + // Absolute top-level path with relative @import resolves correctly. + b := newMemBackend() + b.set("/project/agents.md", "ref @./lib/helper.md done") + b.set("/project/lib/helper.md", "HELPER") + + l := newLoaderConfig(b, []string{"/project/agents.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(content, "HELPER") { + t.Fatalf("expected imported content, got %q", content) + } + if !strings.Contains(content, "Contents of /project/lib/helper.md") { + t.Fatalf("expected section header, got %q", content) + } +} + +func TestLoader_AbsoluteTopLevelWithDotDotImport(t *testing.T) { + // Absolute top-level path; @import with ../ resolves and normalizes. + b := newMemBackend() + b.set("/project/sub/agents.md", "load @../shared/x.md here") + b.set("/project/shared/x.md", "SHARED") + + l := newLoaderConfig(b, []string{"/project/sub/agents.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(content, "SHARED") { + t.Fatalf("expected imported content, got %q", content) + } + // filepath.Clean normalizes "/project/sub/../shared/x.md" to "/project/shared/x.md" + if !strings.Contains(content, "Contents of /project/shared/x.md") { + t.Fatalf("expected normalized path in section header, got %q", content) + } +} + +func TestLoader_RelativeImportDedup(t *testing.T) { + // Two different relative @import paths that resolve to the same file + // should be deduped via filepath.Clean. + b := newMemBackend() + b.set("/a/main.md", "first @/a/b/shared.md second @../a/b/shared.md end") + b.set("/a/b/shared.md", "SHARED ONCE") + + l := newLoaderConfig(b, []string{"/a/main.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + count := strings.Count(content, "SHARED ONCE") + if count != 1 { + t.Fatalf("expected shared file loaded once (deduped), got %d in %q", count, content) + } +} + +func TestLoader_NestedRelativeImport(t *testing.T) { + // File A imports B via relative path, B imports C via relative path. + // All three should appear as separate sections. + b := newMemBackend() + b.set("/root/main.md", "start @sub/mid.md end") + b.set("/root/sub/mid.md", "mid @deep/leaf.md mid_end") + b.set("/root/sub/deep/leaf.md", "LEAF") + + l := newLoaderConfig(b, []string{"/root/main.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + for _, section := range []string{"Contents of /root/main.md", "Contents of /root/sub/mid.md", "Contents of /root/sub/deep/leaf.md"} { + if !strings.Contains(content, section) { + t.Fatalf("expected section %q, got %q", section, content) + } + } + if !strings.Contains(content, "LEAF") { + t.Fatalf("expected leaf content, got %q", content) + } +} + +func TestLoader_TransitiveImport(t *testing.T) { + // Imported file itself contains @imports; all should appear as separate sections. + b := newMemBackend() + b.set("/main.md", "header @/mid.md footer") + b.set("/mid.md", "mid-start @/leaf.md mid-end") + b.set("/leaf.md", "LEAF_VALUE") + + l := newLoaderConfig(b, []string{"/main.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + for _, section := range []string{"Contents of /main.md", "Contents of /mid.md", "Contents of /leaf.md"} { + if !strings.Contains(content, section) { + t.Fatalf("expected section %q, got %q", section, content) + } + } + if !strings.Contains(content, "LEAF_VALUE") { + t.Fatalf("expected leaf value, got %q", content) + } +} + +func TestLoader_EmptyFile(t *testing.T) { + b := newMemBackend() + b.set("/empty.md", "") + + l := newLoaderConfig(b, []string{"/empty.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + // Empty file is treated as non-existent, so output should be empty. + if content != "" { + t.Fatalf("expected empty output for empty file, got %q", content) + } +} + +func TestLoader_MaxBytesFirstFileFull(t *testing.T) { + // Even if the first file alone exceeds maxBytes, it should still be loaded in full. + b := newMemBackend() + b.set("/big.md", "ABCDEFGHIJ") // 10 bytes + + l := newLoaderConfig(b, []string{"/big.md"}, 3, nil) + content, err := l.load(context.Background()) // maxBytes=3, but first file always loads + if err != nil { + t.Fatal(err) + } + if !strings.Contains(content, "ABCDEFGHIJ") { + t.Fatalf("first file should always load in full, got %q", content) + } +} + +func TestLoader_CircularImportInline(t *testing.T) { + // Circular reference via @import should be detected, logged, and skipped. + b := newMemBackend() + b.set("/a.md", "text @/b.md more") + b.set("/b.md", "ref @/a.md back") + + l := newLoaderConfig(b, []string{"/a.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error (circular import is logged), got %v", err) + } + // Both a and b should have sections; circular back-reference a from b is skipped. + if !strings.Contains(content, "Contents of /a.md") { + t.Fatalf("expected /a.md section, got %q", content) + } + if !strings.Contains(content, "Contents of /b.md") { + t.Fatalf("expected /b.md section, got %q", content) + } +} + +func TestLoader_MaxDepthInline(t *testing.T) { + // Deep chain via @import should be logged at depth > 5, not returned as error. + b := newMemBackend() + for i := 0; i < 7; i++ { + var content string + if i < 6 { + content = fmt.Sprintf("level%d @/level%d.md tail", i, i+1) + } else { + content = fmt.Sprintf("level%d", i) + } + b.set(fmt.Sprintf("/level%d.md", i), content) + } + + l := newLoaderConfig(b, []string{"/level0.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error (depth exceeded is logged), got %v", err) + } + // Levels 0-5 should have sections. + for i := 0; i <= 5; i++ { + want := fmt.Sprintf("Contents of /level%d.md", i) + if !strings.Contains(content, want) { + t.Fatalf("expected %q in content, got %q", want, content) + } + } + // Level 6 should not be present. + if strings.Contains(content, "Contents of /level6.md") { + t.Fatalf("level6 should not be present (depth exceeded), got %q", content) + } +} + +func TestLoader_DiamondDependency(t *testing.T) { + // A imports B and D; B imports C; D also imports C. + // C should appear only once (deduped across the whole load). + b := newMemBackend() + b.set("/a.md", "start @/b.md middle @/d.md end") + b.set("/b.md", "B(@/c.md)") + b.set("/d.md", "D(@/c.md)") + b.set("/c.md", "SHARED") + + l := newLoaderConfig(b, []string{"/a.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("diamond dependency should not be circular, got error: %v", err) + } + + // C should appear only once as a section (deduped). + count := strings.Count(content, "Contents of /c.md") + if count != 1 { + t.Fatalf("expected /c.md section once (deduped), got %d in %q", count, content) + } + // All files should have sections. + for _, section := range []string{"Contents of /a.md", "Contents of /b.md", "Contents of /c.md", "Contents of /d.md"} { + if !strings.Contains(content, section) { + t.Fatalf("expected section %q, got %q", section, content) + } + } +} + +func TestLoader_AtSignInNormalText(t *testing.T) { + // Bare @word without "/" or file extension should not trigger import. + // Email-like patterns (@example.com) with non-allowed extensions should also be ignored. + b := newMemBackend() + b.set("/agent.md", "contact me @ anytime or @ spaces and @someone mentioned and user@example.com and @company.org") + + l := newLoaderConfig(b, []string{"/agent.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(content, "contact me @ anytime") { + t.Fatalf("bare @ should not trigger import, got %q", content) + } + if !strings.Contains(content, "@someone mentioned") { + t.Fatalf("@someone without / or extension should not trigger import, got %q", content) + } + if !strings.Contains(content, "@example.com") { + t.Fatalf("email-like @example.com should not trigger import, got %q", content) + } + if !strings.Contains(content, "@company.org") { + t.Fatalf("email-like @company.org should not trigger import, got %q", content) + } +} + +func TestMiddleware_Stream(t *testing.T) { + b := newMemBackend() + b.set("/agent.md", "stream test") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + _, _ = wrapped.Stream(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}) + + if len(mock.lastInput) != 2 { + t.Fatalf("expected 2 messages for stream, got %d", len(mock.lastInput)) + } + if !strings.Contains(mock.lastInput[0].Content, "stream test") { + t.Fatalf("expected agent.md content in stream input, got %q", mock.lastInput[0].Content) + } +} + +func TestLoader_MaxBytesWithImports(t *testing.T) { + // Two top-level files that both import the same shared file. + // Budget should account for imported file bytes. + b := newMemBackend() + b.set("/a.md", "A(@/shared.md)") + b.set("/b.md", "B(@/shared.md)") + b.set("/shared.md", strings.Repeat("X", 100)) // 100 bytes + + l := newLoaderConfig(b, []string{"/a.md", "/b.md"}, 120, nil) + // /a.md = 14 bytes + /shared.md = 100 bytes => 114 total after /a.md. + // Budget = 120: /b.md (14 bytes) would push to 128, exceeding budget. + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("load failed: %v", err) + } + + // /a.md and its import should be included. + if !strings.Contains(content, strings.Repeat("X", 100)) { + t.Fatal("expected /a.md with shared content to be included") + } + + // /b.md should be excluded because totalBytes exceeded budget after loading /a.md. + if strings.Contains(content, "B(") { + t.Fatalf("expected /b.md to be excluded due to budget, got %q", content) + } +} + +func TestNew_Validation_EmptyAgentFiles(t *testing.T) { + ctx := context.Background() + b := newMemBackend() + + _, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{}}) + if err == nil { + t.Fatal("expected error for empty agent files") + } + if !strings.Contains(err.Error(), "at least one agent file path is required") { + t.Fatalf("unexpected error message: %v", err) + } +} + +func TestMiddleware_GenerateError(t *testing.T) { + // Non-ErrNotExist errors (e.g. permission denied) should propagate. + b := &errBackend{} + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/file.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}) + if err == nil { + t.Fatal("expected error when backend read fails with non-ErrNotExist") + } + if !strings.Contains(err.Error(), "failed to load agent files") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestMiddleware_StreamError(t *testing.T) { + // Non-ErrNotExist errors (e.g. permission denied) should propagate. + b := &errBackend{} + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/file.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + _, err = wrapped.Stream(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}) + if err == nil { + t.Fatal("expected error when backend read fails with non-ErrNotExist for stream") + } + if !strings.Contains(err.Error(), "failed to load agent files") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestLoader_DuplicateTopLevelFiles(t *testing.T) { + // Same file listed twice in AgentFiles; second should be deduped via seen map. + b := newMemBackend() + b.set("/agent.md", "unique content") + + l := newLoaderConfig(b, []string{"/agent.md", "/agent.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + + count := strings.Count(content, "Contents of /agent.md") + if count != 1 { + t.Fatalf("expected /agent.md section once (deduped), got %d", count) + } +} + +func TestLoader_LoadFileError(t *testing.T) { + // Missing file (ErrNotExist) is silently skipped. + b := newMemBackend() + l := newLoaderConfig(b, []string{"/missing.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected missing file to be skipped, got error: %v", err) + } + if content != "" { + t.Fatalf("expected empty output, got %q", content) + } +} + +func TestLoader_MaxBytesStopsImports(t *testing.T) { + // When budget is exhausted, further imports in collectImports should be skipped. + b := newMemBackend() + b.set("/main.md", "@/big.md @/small.md") + b.set("/big.md", strings.Repeat("B", 200)) + b.set("/small.md", "SMALL") + + l := newLoaderConfig(b, []string{"/main.md"}, 50, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + + // main.md itself is loaded (always), big.md pushes over budget, + // small.md should be skipped. + if !strings.Contains(content, "Contents of /main.md") { + t.Fatal("main.md should be present") + } + if strings.Contains(content, "SMALL") { + t.Fatal("small.md should be skipped after budget exhausted") + } +} + +func TestFormatContent_Empty(t *testing.T) { + // formatContent with nil/empty slice should return empty string. + if got := formatContent(nil); got != "" { + t.Fatalf("expected empty string for nil, got %q", got) + } + if got := formatContent([]loadedFile{}); got != "" { + t.Fatalf("expected empty string for empty slice, got %q", got) + } +} + +func TestMiddleware_AllFilesEmpty(t *testing.T) { + // When all agent files have empty content, loader returns "" and + // prependAgentMD returns the original input unchanged. + b := newMemBackend() + b.set("/agent.md", "") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + userMsg := []*schema.Message{{Role: schema.User, Content: "hello"}} + if _, err = wrapped.Generate(ctx, userMsg); err != nil { + t.Fatal(err) + } + // Empty file produces no agentmd content, so original messages pass through unchanged. + if len(mock.lastInput) != 1 { + t.Fatalf("expected 1 message (no agentmd prepended), got %d", len(mock.lastInput)) + } + if mock.lastInput[0].Content != "hello" { + t.Fatalf("expected original message unchanged, got %q", mock.lastInput[0].Content) + } +} + +func TestLoader_ExactOutput(t *testing.T) { + // Verify the exact output format matches the expected structure: + // each file (top-level and imported) gets its own "Contents of ..." section, + // @path references are preserved in the original text. + b := newMemBackend() + b.set("/project/CLAUDE.md", "this is project claude.md\n\n- git workflow @git/git-instructions.md") + b.set("/project/git/git-instructions.md", "this is git-instructions.md") + + l := newLoaderConfig(b, []string{"/project/CLAUDE.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + + expected := ` +As you answer the user's questions, you can use the following context: +Codebase and user instructions are shown below. Be sure to adhere to these instructions. IMPORTANT: These instructions OVERRIDE any default behavior and you MUST follow them exactly as written. + +Contents of /project/CLAUDE.md (instructions): + +this is project claude.md + +- git workflow @git/git-instructions.md + +Contents of /project/git/git-instructions.md (instructions): + +this is git-instructions.md +IMPORTANT: this context may or may not be relevant to your tasks. You should not respond to this context unless it is highly relevant to your task. +` + + if content != expected { + t.Fatalf("output mismatch.\n\ngot:\n%s\n\nexpected:\n%s", content, expected) + } +} + +func TestLoader_MissingFileSkipped(t *testing.T) { + b := newMemBackend() + b.set("/good.md", "GOOD CONTENT") + // /missing.md is not set + + l := newLoaderConfig(b, []string{"/missing.md", "/good.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error for missing file, got %v", err) + } + if !strings.Contains(content, "GOOD CONTENT") { + t.Fatal("expected good.md content in output") + } +} + +func TestLoader_AllMissingFilesSkipped(t *testing.T) { + b := newMemBackend() + + l := newLoaderConfig(b, []string{"/missing1.md", "/missing2.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error for missing files, got %v", err) + } + if content != "" { + t.Fatalf("expected empty output when all files missing, got %q", content) + } +} + +func TestLoader_CircularImportSkipped(t *testing.T) { + b := newMemBackend() + b.set("/a.md", "A content @/b.md") + b.set("/b.md", "B content @/a.md") + + // Circular import in collectImports is logged via onWarning and skipped. + l := newLoaderConfig(b, []string{"/a.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !strings.Contains(content, "A content") { + t.Fatal("expected a.md content") + } + if !strings.Contains(content, "B content") { + t.Fatal("expected b.md content") + } +} + +func TestLoader_DepthExceededSkipped(t *testing.T) { + b := newMemBackend() + // Create a chain that exceeds maxImportDepth (5) + b.set("/l0.md", "@/l1.md") + b.set("/l1.md", "@/l2.md") + b.set("/l2.md", "@/l3.md") + b.set("/l3.md", "@/l4.md") + b.set("/l4.md", "@/l5.md") + b.set("/l5.md", "@/l6.md") + b.set("/l6.md", "DEEP") + + l := newLoaderConfig(b, []string{"/l0.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error for depth exceeded, got %v", err) + } + // Should have content up to the depth limit, deep file skipped. + if !strings.Contains(content, "/l0.md") { + t.Fatal("expected l0.md in output") + } +} + +func TestLoader_OnLoadWarningCallback(t *testing.T) { + b := newMemBackend() + b.set("/good.md", "GOOD CONTENT") + + var warnings []error + onWarning := func(filePath string, err error) { + warnings = append(warnings, fmt.Errorf("%s: %w", filePath, err)) + } + + l := newLoaderConfig(b, []string{"/missing.md", "/good.md"}, 0, onWarning) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !strings.Contains(content, "GOOD CONTENT") { + t.Fatal("expected good.md content in output") + } + if len(warnings) == 0 { + t.Fatal("expected at least one warning for missing file") + } + if !strings.Contains(warnings[0].Error(), "file not found") { + t.Fatalf("expected file not found warning, got %v", warnings[0]) + } +} + +func TestMiddleware_MissingFile_Generate(t *testing.T) { + b := newMemBackend() + // /missing.md not set — will fail to read + + ctx := context.Background() + mw, err := New(ctx, &Config{ + Backend: b, + AgentsMDFiles: []string{"/missing.md"}, + }) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + userMsg := []*schema.Message{{Role: schema.User, Content: "hello"}} + _, err = wrapped.Generate(ctx, userMsg) + if err != nil { + t.Fatalf("expected no error for missing file, got %v", err) + } + // No agent.md content, so original messages should be passed through unchanged. + if len(mock.lastInput) != 1 { + t.Fatalf("expected 1 message (no agentmd prepended), got %d", len(mock.lastInput)) + } +} + +func TestMiddleware_MissingFile_Stream(t *testing.T) { + b := newMemBackend() + + ctx := context.Background() + mw, err := New(ctx, &Config{ + Backend: b, + AgentsMDFiles: []string{"/missing.md"}, + }) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + userMsg := []*schema.Message{{Role: schema.User, Content: "hello"}} + _, err = wrapped.Stream(ctx, userMsg) + if err != nil { + t.Fatalf("expected no error for missing file, got %v", err) + } + if len(mock.lastInput) != 1 { + t.Fatalf("expected 1 message (no agentmd prepended), got %d", len(mock.lastInput)) + } +} + +func TestMiddleware_InsertBeforeFirstUserMessage(t *testing.T) { + b := newMemBackend() + b.set("/agent.md", "agent instructions") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + // Input has a System message before the User message. + input := []*schema.Message{ + {Role: schema.System, Content: "system prompt"}, + {Role: schema.User, Content: "hello"}, + } + if _, err = wrapped.Generate(ctx, input); err != nil { + t.Fatal(err) + } + + if len(mock.lastInput) != 3 { + t.Fatalf("expected 3 messages, got %d", len(mock.lastInput)) + } + if mock.lastInput[0].Role != schema.System { + t.Fatalf("expected first message role System, got %s", mock.lastInput[0].Role) + } + if mock.lastInput[0].Content != "system prompt" { + t.Fatalf("expected system prompt preserved, got %q", mock.lastInput[0].Content) + } + if mock.lastInput[1].Role != schema.User || !strings.Contains(mock.lastInput[1].Content, "agent instructions") { + t.Fatalf("expected agentmd message before user message, got role=%s content=%q", mock.lastInput[1].Role, mock.lastInput[1].Content) + } + if mock.lastInput[2].Role != schema.User || mock.lastInput[2].Content != "hello" { + t.Fatalf("expected original user message at index 2, got role=%s content=%q", mock.lastInput[2].Role, mock.lastInput[2].Content) + } +} + +func TestMiddleware_InsertWithNoUserMessage(t *testing.T) { + b := newMemBackend() + b.set("/agent.md", "agent instructions") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + // Input has no User message at all. + input := []*schema.Message{ + {Role: schema.System, Content: "system prompt"}, + {Role: schema.Assistant, Content: "assistant reply"}, + } + if _, err = wrapped.Generate(ctx, input); err != nil { + t.Fatal(err) + } + + if len(mock.lastInput) != 3 { + t.Fatalf("expected 3 messages, got %d", len(mock.lastInput)) + } + if mock.lastInput[0].Role != schema.System { + t.Fatalf("expected System at index 0, got %s", mock.lastInput[0].Role) + } + if mock.lastInput[1].Role != schema.Assistant { + t.Fatalf("expected Assistant at index 1, got %s", mock.lastInput[1].Role) + } + if mock.lastInput[2].Role != schema.User || !strings.Contains(mock.lastInput[2].Content, "agent instructions") { + t.Fatalf("expected agentmd appended at end, got role=%s content=%q", mock.lastInput[2].Role, mock.lastInput[2].Content) + } +} + +func TestLoader_ImportIOError(t *testing.T) { + // When an imported file returns a non-ErrNotExist error (e.g. I/O error), + // the load should propagate the error (covers collectImports and loadFile error paths). + b := &partialErrBackend{ + files: map[string]string{ + "/main.md": "content @/broken.md", + }, + // /broken.md is NOT in the map, so Read returns I/O error (not ErrNotExist) + } + + l := newLoaderConfig(b, []string{"/main.md"}, 0, nil) + _, err := l.load(context.Background()) + if err == nil { + t.Fatal("expected error from I/O failure on imported file") + } + if !strings.Contains(err.Error(), "I/O error") { + t.Fatalf("expected I/O error, got: %v", err) + } +} diff --git a/adk/middlewares/agentsmd/loader.go b/adk/middlewares/agentsmd/loader.go new file mode 100644 index 000000000..db733383b --- /dev/null +++ b/adk/middlewares/agentsmd/loader.go @@ -0,0 +1,299 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package agentsmd + +import ( + "context" + "errors" + "fmt" + "log" + "os" + "path/filepath" + "regexp" + "strings" + + "github.com/cloudwego/eino/adk/filesystem" + "github.com/cloudwego/eino/adk/internal" +) + +// importRegex matches @path/to/file anywhere in text. +// The path must start with a letter, digit, dot, underscore, slash, or tilde, followed by +// path characters (letters, digits, dots, slashes, hyphens, underscores). +// A post-match filter further requires the path to contain "/" or end with +// an allowed extension (see allowedImportExts), so bare words like @someone +// and email-like patterns like @example.com are ignored. +var importRegex = regexp.MustCompile(`@([a-zA-Z0-9_.~/][a-zA-Z0-9_.~/\-]*)`) + +// allowedImportExts is the set of file extensions recognised as @import targets. +// Paths without "/" must end with one of these extensions to be treated as imports; +// this avoids false positives on email addresses (@example.com) and mentions (@foo.bar). +var allowedImportExts = map[string]bool{ + ".md": true, + ".txt": true, + ".mdx": true, + ".yaml": true, + ".yml": true, + ".json": true, + ".toml": true, +} + +const maxImportDepth = 5 + +// ReadRequest is an alias for filesystem.ReadRequest. +type ReadRequest = filesystem.ReadRequest +type FileContent = filesystem.FileContent + +// Backend defines the file access interface for loading Agents.md files. +// Implementations can use local filesystem, remote storage, or any other backend. +type Backend interface { + // Read reads the content of a file. + // If the file does not exist, implementations should return an error wrapping + // os.ErrNotExist (so that errors.Is(err, os.ErrNotExist) returns true). This allows the loader + // to silently skip missing files and notify via OnLoadWarning callback. + // Other errors (e.g. permission denied, I/O errors) will abort the loading process. + Read(ctx context.Context, req *ReadRequest) (*FileContent, error) +} + +// loaderConfig holds the immutable configuration for creating loaders. +// It is safe for concurrent use by multiple goroutines. +type loaderConfig struct { + backend Backend + files []string // ordered file paths from config + maxBytes int // cumulative read budget; 0 means unlimited + onWarning func(filePath string, err error) // callback for non-fatal loading warnings +} + +func newLoaderConfig(backend Backend, files []string, maxBytes int, onWarning func(filePath string, err error)) *loaderConfig { + if onWarning == nil { + onWarning = func(filePath string, err error) { + log.Printf("[agentsmd] warning: %s: %v", filePath, err) + } + } + return &loaderConfig{ + backend: backend, + files: files, + maxBytes: maxBytes, + onWarning: onWarning, + } +} + +// loader handles loading and @import resolution for agents.md files. +// A new loader is created for each load() call to avoid sharing mutable state +// (totalBytes) across concurrent invocations. +type loader struct { + *loaderConfig + totalBytes int // accumulated bytes during this load call +} + +func (cfg *loaderConfig) newLoader() *loader { + return &loader{loaderConfig: cfg} +} + +// load reads all agents.md files and returns the formatted content. +// Each top-level file and its @imported files appear as separate sections. +func (cfg *loaderConfig) load(ctx context.Context) (string, error) { + l := cfg.newLoader() + + var parts []loadedFile + seen := make(map[string]bool) // dedup across all files and imports + + for i, filePath := range l.files { + files, err := l.loadFile(ctx, filePath, 0, make(map[string]bool), seen) + if err != nil { + return "", fmt.Errorf("failed to load %q: %w", filePath, err) + } + + // If loading this file caused the budget to be exceeded, skip it + // (but always include the first file). + if i > 0 && l.maxBytes > 0 && l.totalBytes > l.maxBytes { + l.onWarning(filePath, fmt.Errorf("skipped: cumulative size %d exceeds max bytes %d", l.totalBytes, l.maxBytes)) + break + } + + parts = append(parts, files...) + } + + return formatContent(parts), nil +} + +// loadFile reads a file via Backend and collects @imported files as separate entries. +// Returns a slice where the first element is this file itself, followed by all +// transitively imported files (in encounter order, preserving @path in original text). +// visited tracks the current ancestor chain to detect circular imports. +// seen tracks globally loaded files to avoid duplicate reads and byte counting. +func (l *loader) loadFile(ctx context.Context, filePath string, depth int, visited map[string]bool, seen map[string]bool) ([]loadedFile, error) { + filePath = filepath.Clean(filePath) + + if depth > maxImportDepth { + l.onWarning(filePath, fmt.Errorf("@import depth exceeds maximum of %d", maxImportDepth)) + return nil, nil + } + + if visited[filePath] { + l.onWarning(filePath, fmt.Errorf("circular @import detected")) + return nil, nil + } + + if seen[filePath] { + return nil, nil + } + + visited[filePath] = true + defer delete(visited, filePath) + + fileContent, err := l.backend.Read(ctx, &ReadRequest{FilePath: filePath, Offset: 1}) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + l.onWarning(filePath, fmt.Errorf("file not found, skipping")) + return nil, nil + } + return nil, err + } + content := "" + if fileContent != nil { + content = fileContent.Content + } + + l.totalBytes += len(content) + seen[filePath] = true + + if content == "" { + return nil, nil + } + + // Collect imported files as separate sections (content stays untouched). + imports, err := l.collectImports(ctx, filePath, content, depth, visited, seen) + if err != nil { + return nil, err + } + + // This file first, then its imports. + result := make([]loadedFile, 0, 1+len(imports)) + result = append(result, loadedFile{path: filePath, content: content}) + result = append(result, imports...) + return result, nil +} + +// collectImports scans content for @path/to/file references and loads each +// imported file (plus its transitive imports). The original content is NOT modified. +// Returns the list of imported loadedFile entries in encounter order. +// seen is shared across the entire load call to avoid duplicate reads. +// Non-fatal errors (file not found, depth exceeded, circular import) are reported +// via onWarning and skipped. Fatal errors (e.g. I/O) are returned. +func (l *loader) collectImports(ctx context.Context, hostPath, content string, depth int, visited map[string]bool, seen map[string]bool) ([]loadedFile, error) { + dir := filepath.Dir(hostPath) + var imports []loadedFile + + matches := importRegex.FindAllStringSubmatch(content, -1) + for _, match := range matches { + rawPath := match[1] + + // Only treat as import if path contains "/" or ends with an allowed extension. + // This avoids false positives on email addresses and social mentions. + if !strings.Contains(rawPath, "/") && !allowedImportExts[filepath.Ext(rawPath)] { + continue + } + + // If budget is exhausted, skip further imports. + if l.maxBytes > 0 && l.totalBytes > l.maxBytes { + break + } + + importPath := rawPath + if !filepath.IsAbs(importPath) { + importPath = filepath.Join(dir, importPath) + } + + if seen[importPath] { + continue + } + + files, err := l.loadFile(ctx, importPath, depth+1, visited, seen) + if err != nil { + return nil, fmt.Errorf("failed to import %q from %q: %w", rawPath, hostPath, err) + } + + imports = append(imports, files...) + } + + return imports, nil +} + +type loadedFile struct { + path string + content string +} + +const formatHeaderEn = ` +As you answer the user's questions, you can use the following context: +Codebase and user instructions are shown below. Be sure to adhere to these instructions. IMPORTANT: These instructions OVERRIDE any default behavior and you MUST follow them exactly as written. +` + +const formatHeaderCn = ` +在回答用户问题时,你可以使用以下上下文: +代码库和用户指令如下。请务必遵守这些指令。重要提示:这些指令会覆盖任何默认行为,你必须严格按照要求执行。 +` + +const formatFileHeaderEn = "\nContents of " + +const formatFileHeaderCn = "\n文件内容:" + +const formatFileLabelEn = " (instructions):\n\n" + +const formatFileLabelCn = "(指令):\n\n" + +const formatFooterEn = `IMPORTANT: this context may or may not be relevant to your tasks. You should not respond to this context unless it is highly relevant to your task. +` + +const formatFooterCn = `重要提示:此上下文可能与你的任务相关,也可能不相关。除非此上下文与你的任务高度相关,否则不要响应此上下文。 +` + +func formatContent(files []loadedFile) string { + if len(files) == 0 { + return "" + } + + header := internal.SelectPrompt(internal.I18nPrompts{ + English: formatHeaderEn, + Chinese: formatHeaderCn, + }) + fileHeader := internal.SelectPrompt(internal.I18nPrompts{ + English: formatFileHeaderEn, + Chinese: formatFileHeaderCn, + }) + fileLabel := internal.SelectPrompt(internal.I18nPrompts{ + English: formatFileLabelEn, + Chinese: formatFileLabelCn, + }) + footer := internal.SelectPrompt(internal.I18nPrompts{ + English: formatFooterEn, + Chinese: formatFooterCn, + }) + + var sb strings.Builder + sb.WriteString(header) + + for _, f := range files { + sb.WriteString(fileHeader) + sb.WriteString(f.path) + sb.WriteString(fileLabel) + sb.WriteString(f.content) + sb.WriteString("\n") + } + sb.WriteString(footer) + return sb.String() +} From 355011bba8e6eb5ec9ec615dd19a2b9f3014d751 Mon Sep 17 00:00:00 2001 From: "luohuaqing.2018" Date: Thu, 12 Feb 2026 17:42:07 +0800 Subject: [PATCH 31/59] feat(adk): add TurnLoop and Cancellable interfaces Change-Id: Ifb3be9fadeabdd8f3b6985bb47d2fb38a7004beb --- adk/interface.go | 39 +++ adk/turn_loop.go | 251 +++++++++++++++++++ adk/turn_loop_test.go | 564 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 854 insertions(+) create mode 100644 adk/turn_loop.go create mode 100644 adk/turn_loop_test.go diff --git a/adk/interface.go b/adk/interface.go index 5c06843ae..fef2f98fe 100644 --- a/adk/interface.go +++ b/adk/interface.go @@ -20,6 +20,7 @@ import ( "bytes" "context" "encoding/gob" + "errors" "fmt" "io" @@ -269,3 +270,41 @@ type ResumableAgent interface { Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] } + +// CancelMode specifies when an agent should be canceled. +// Modes can be combined with bitwise OR to cancel at multiple execution points. +// For example, CancelAfterChatModel | CancelAfterToolCall cancels the agent +// after whichever execution point is reached first. +type CancelMode int + +const ( + // CancelImmediate cancels the agent immediately without waiting + // for any execution point. + CancelImmediate CancelMode = 0 + // CancelAfterChatModel cancels the agent after a chat model call completes. + CancelAfterChatModel CancelMode = 1 << iota + // CancelAfterToolCall cancels the agent after a tool call completes. + CancelAfterToolCall +) + +// ErrAgentFinished is returned by Cancel when the agent has already finished execution. +var ErrAgentFinished = errors.New("agent has already finished execution") + +// CancelOption holds options for cancelling an agent. +type CancelOption struct { + Mode CancelMode +} + +// Cancellable is an optional interface that an Agent can implement to support +// cancellation during execution. +type Cancellable interface { + // Cancel signals the agent to stop, either immediately or after reaching the + // specified execution point(s) defined by opt.Mode. + // + // The opt parameter must be non-nil. Use opt.Mode to control when the + // cancellation takes effect (e.g., CancelImmediate, CancelAfterChatModel, + // CancelAfterToolCall, or a combination via bitwise OR). + // + // If the agent has already finished execution, Cancel returns ErrAgentFinished. + Cancel(ctx context.Context, opt *CancelOption) error +} diff --git a/adk/turn_loop.go b/adk/turn_loop.go new file mode 100644 index 000000000..d6ac24071 --- /dev/null +++ b/adk/turn_loop.go @@ -0,0 +1,251 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "fmt" + "time" +) + +// ConsumeMode specifies how a received message should be consumed +// relative to the currently running agent. +type ConsumeMode int + +const ( + // ConsumeNonPreemptive processes the message after the current agent + // finishes. This is the default queued behavior. + ConsumeNonPreemptive ConsumeMode = iota + // ConsumePreemptive cancels the currently running agent (if it + // implements Cancellable) and processes the message immediately. + // If the agent does not implement Cancellable, the message is + // buffered and processed after the agent finishes. + ConsumePreemptive +) + +// ConsumeOption describes how a received message should be consumed. +// It combines ConsumeMode (preemptive vs queued) with CancelMode +// (how to cancel the running agent when preempting). +type ConsumeOption struct { + // Mode specifies whether the message should preempt the current agent + // or be queued. Default zero value is ConsumeNonPreemptive. + Mode ConsumeMode + // CancelOption specifies when and how the running agent should be canceled. + // Only meaningful when Mode is ConsumePreemptive and the agent + // implements Cancellable. Default nil value means CancelImmediate. + CancelOption *CancelOption +} + +// NonPreemptiveConsumeOption is a convenience value for the common +// non-preemptive (queued) case. +var NonPreemptiveConsumeOption = ConsumeOption{Mode: ConsumeNonPreemptive} + +// MessageSource is an interface for pulling typed messages from an external source. +// Receive blocks until a message is available or an error occurs. +// The timeout parameter specifies the maximum duration to wait for a message. +// The returned ConsumeOption indicates whether the message should preempt the +// currently running agent (and how to cancel it) or be queued for processing +// after it finishes. +type MessageSource[T any] interface { + Receive(ctx context.Context, timeout time.Duration) (T, ConsumeOption, error) +} + +// TurnLoopConfig is the configuration for creating a TurnLoop. +type TurnLoopConfig[T any] struct { + // Source provides messages to drive the loop. Required. + Source MessageSource[T] + // GenInput converts a received message into AgentInput. Required. + GenInput func(ctx context.Context, item T) (*AgentInput, error) + // GetAgent returns the Agent to run for a given message. Required. + GetAgent func(ctx context.Context, item T) (Agent, error) + // OnAgentEvent is called for each event emitted by the agent. Optional. + OnAgentEvent func(ctx context.Context, event *AgentEvent) error + // RunOptions are passed to Agent.Run on each turn. Optional. + RunOptions []AgentRunOption + // ReceiveTimeout is the timeout passed to Source.Receive on each iteration. + // Zero means no timeout. Optional. + ReceiveTimeout time.Duration +} + +// TurnLoop is a loop that pulls messages from a source, runs an Agent for +// each message, and dispatches resulting events. It supports preemptive +// cancellation when the source returns ConsumePreemptive and the current +// agent implements Cancellable. +type TurnLoop[T any] struct { + source MessageSource[T] + genInput func(ctx context.Context, item T) (*AgentInput, error) + getAgent func(ctx context.Context, item T) (Agent, error) + onAgentEvent func(ctx context.Context, event *AgentEvent) error + runOptions []AgentRunOption + receiveTimeout time.Duration +} + +// NewTurnLoop creates a new TurnLoop from the given configuration. +// Source, GenInput, and GetAgent are required fields. +func NewTurnLoop[T any](config TurnLoopConfig[T]) (*TurnLoop[T], error) { + if config.Source == nil { + return nil, fmt.Errorf("TurnLoopConfig.Source is required") + } + if config.GenInput == nil { + return nil, fmt.Errorf("TurnLoopConfig.GenInput is required") + } + if config.GetAgent == nil { + return nil, fmt.Errorf("TurnLoopConfig.GetAgent is required") + } + + return &TurnLoop[T]{ + source: config.Source, + genInput: config.GenInput, + getAgent: config.GetAgent, + onAgentEvent: config.OnAgentEvent, + runOptions: config.RunOptions, + receiveTimeout: config.ReceiveTimeout, + }, nil +} + +// recvResult holds the result of a concurrent Receive call. +type recvResult[T any] struct { + item T + option ConsumeOption + err error +} + +// iterResult holds the result of a single AsyncIterator.Next call. +type iterResult struct { + event *AgentEvent + ok bool +} + +// Run starts the blocking loop that continuously receives messages, runs +// agents, and dispatches events. While an agent is running, the next message +// is received concurrently. If that message's ConsumeOption has ConsumePreemptive +// mode and the running agent implements Cancellable, the agent is canceled +// (using the CancelMode from the option) and the new message is processed +// immediately. +func (l *TurnLoop[T]) Run(ctx context.Context) error { + // done is closed when Run returns, signaling background goroutines to exit. + done := make(chan struct{}) + defer close(done) + + // Initial blocking receive — no agent running yet, mode is irrelevant. + item, _, err := l.source.Receive(ctx, l.receiveTimeout) + if err != nil { + return err + } + + for { + input, e := l.genInput(ctx, item) + if e != nil { + return fmt.Errorf("failed to generate agent input: %w", e) + } + + agent, e := l.getAgent(ctx, item) + if e != nil { + return fmt.Errorf("failed to get agent: %w", e) + } + + // Start receiving the next message concurrently with agent execution. + // The channel is buffered so the goroutine never blocks on send. + recvCh := make(chan recvResult[T], 1) + go func() { + i, opt, e_ := l.source.Receive(ctx, l.receiveTimeout) + recvCh <- recvResult[T]{i, opt, e_} + }() + + // Run the agent and forward events through a channel so we can + // select between agent events and incoming messages. + iter := agent.Run(ctx, input, l.runOptions...) + eventCh := make(chan iterResult, 1) + go func() { + for { + event, ok := iter.Next() + select { + case eventCh <- iterResult{event, ok}: + case <-done: + return + } + if !ok { + return + } + } + }() + + var pending *recvResult[T] + var turnErr error + + eventLoop: + for { + select { + case ev := <-eventCh: + if !ev.ok { + break eventLoop + } + if ev.event.Err != nil { + turnErr = fmt.Errorf("agent run failed: %w", ev.event.Err) + break eventLoop + } + if l.onAgentEvent != nil { + if e_ := l.onAgentEvent(ctx, ev.event); e_ != nil { + turnErr = fmt.Errorf("OnAgentEvent failed: %w", e_) + break eventLoop + } + } + + case recv := <-recvCh: + recvCh = nil // nil channel never matches in select + if recv.option.Mode == ConsumePreemptive && recv.err == nil { + if ca, ok := agent.(Cancellable); ok { + if e_ := ca.Cancel(ctx, recv.option.CancelOption); e_ != nil { + return fmt.Errorf("failed to cancel agent: %w", e_) + } + // Drain remaining events after cancellation. + for { + ev := <-eventCh + if !ev.ok { + break + } + } + pending = &recvResult[T]{item: recv.item} + break eventLoop + } + } + // Non-preemptive, preemptive but not Cancellable, or + // source error: buffer and let the eventLoop finish + // processing the current agent's events first. + pending = &recvResult[T]{item: recv.item, err: recv.err} + } + } + + if turnErr != nil { + return turnErr + } + + if pending != nil { + if pending.err != nil { + return pending.err + } + item = pending.item + } else { + // Agent finished before the next message arrived; wait for it. + recv := <-recvCh + if recv.err != nil { + return recv.err + } + item = recv.item + } + } +} diff --git a/adk/turn_loop_test.go b/adk/turn_loop_test.go new file mode 100644 index 000000000..9ef2d1a90 --- /dev/null +++ b/adk/turn_loop_test.go @@ -0,0 +1,564 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/schema" +) + +// --------------------------------------------------------------------------- +// Test mocks +// --------------------------------------------------------------------------- + +// turnLoopMockSource returns items from a slice (all NonPreemptive), then an error. +type turnLoopMockSource struct { + items []string + idx int + err error +} + +func (s *turnLoopMockSource) Receive(_ context.Context, _ time.Duration) (string, ConsumeOption, error) { + if s.idx >= len(s.items) { + return "", NonPreemptiveConsumeOption, s.err + } + item := s.items[s.idx] + s.idx++ + return item, NonPreemptiveConsumeOption, nil +} + +// turnLoopFuncSource delegates Receive to a user-supplied function. +type turnLoopFuncSource[T any] struct { + fn func(ctx context.Context, timeout time.Duration) (T, ConsumeOption, error) +} + +func (s *turnLoopFuncSource[T]) Receive(ctx context.Context, timeout time.Duration) (T, ConsumeOption, error) { + return s.fn(ctx, timeout) +} + +// turnLoopMockAgent emits a fixed list of events per Run call. +type turnLoopMockAgent struct { + name string + events []*AgentEvent +} + +func (a *turnLoopMockAgent) Name(_ context.Context) string { return a.name } +func (a *turnLoopMockAgent) Description(_ context.Context) string { return "mock agent" } +func (a *turnLoopMockAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, gen := NewAsyncIteratorPair[*AgentEvent]() + go func() { + defer gen.Close() + for _, e := range a.events { + gen.Send(e) + } + }() + return iter +} + +// turnLoopCancellableAgent blocks until Cancel is called, then closes its iterator. +// It implements both Agent and Cancellable. +type turnLoopCancellableAgent struct { + name string + startedCh chan struct{} // closed when Run is entered + cancelCh chan struct{} // closed by Cancel + cancelled atomic.Bool + cancelledOpt *CancelOption // records the CancelOption passed to Cancel +} + +func (a *turnLoopCancellableAgent) Name(_ context.Context) string { return a.name } +func (a *turnLoopCancellableAgent) Description(_ context.Context) string { return "cancellable mock" } +func (a *turnLoopCancellableAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, gen := NewAsyncIteratorPair[*AgentEvent]() + close(a.startedCh) + go func() { + defer gen.Close() + <-a.cancelCh + }() + return iter +} + +func (a *turnLoopCancellableAgent) Cancel(_ context.Context, opt *CancelOption) error { + a.cancelled.Store(true) + a.cancelledOpt = opt + close(a.cancelCh) + return nil +} + +// turnLoopBlockingAgent blocks until blockCh is closed, then emits its events. +// It does NOT implement Cancellable. +type turnLoopBlockingAgent struct { + name string + startedCh chan struct{} + blockCh chan struct{} + events []*AgentEvent +} + +func (a *turnLoopBlockingAgent) Name(_ context.Context) string { return a.name } +func (a *turnLoopBlockingAgent) Description(_ context.Context) string { return "blocking mock" } +func (a *turnLoopBlockingAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, gen := NewAsyncIteratorPair[*AgentEvent]() + close(a.startedCh) + go func() { + defer gen.Close() + <-a.blockCh + for _, e := range a.events { + gen.Send(e) + } + }() + return iter +} + +// --------------------------------------------------------------------------- +// Tests — validation +// --------------------------------------------------------------------------- + +func TestNewTurnLoop_Validation(t *testing.T) { + t.Run("missing source", func(t *testing.T) { + _, err := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(_ context.Context, _ string) (*AgentInput, error) { return nil, nil }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, nil }, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "Source") + }) + + t.Run("missing GenInput", func(t *testing.T) { + _, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{}, + GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, nil }, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "GenInput") + }) + + t.Run("missing GetAgent", func(t *testing.T) { + _, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{}, + GenInput: func(_ context.Context, _ string) (*AgentInput, error) { return nil, nil }, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "GetAgent") + }) + + t.Run("valid config", func(t *testing.T) { + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{}, + GenInput: func(_ context.Context, _ string) (*AgentInput, error) { return nil, nil }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, nil }, + }) + require.NoError(t, err) + assert.NotNil(t, loop) + }) +} + +// --------------------------------------------------------------------------- +// Tests — non-preemptive (queued) behavior +// --------------------------------------------------------------------------- + +func TestTurnLoop_NormalLoop(t *testing.T) { + agent := &turnLoopMockAgent{ + name: "test-agent", + events: []*AgentEvent{ + {Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("hello", nil)}}}, + }, + } + + var receivedEvents []*AgentEvent + var receivedItems []string + + source := &turnLoopMockSource{ + items: []string{"msg1", "msg2", "msg3"}, + err: context.DeadlineExceeded, + } + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: source, + GenInput: func(_ context.Context, item string) (*AgentInput, error) { + receivedItems = append(receivedItems, item) + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + OnAgentEvent: func(_ context.Context, event *AgentEvent) error { + receivedEvents = append(receivedEvents, event) + return nil + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.Equal(t, []string{"msg1", "msg2", "msg3"}, receivedItems) + assert.Len(t, receivedEvents, 3) +} + +func TestTurnLoop_SourceError(t *testing.T) { + sourceErr := errors.New("source failure") + source := &turnLoopMockSource{ + items: nil, + err: sourceErr, + } + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: source, + GenInput: func(_ context.Context, _ string) (*AgentInput, error) { + return &AgentInput{}, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return &turnLoopMockAgent{name: "a"}, nil + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, sourceErr) +} + +func TestTurnLoop_GenInputError(t *testing.T) { + genErr := errors.New("gen input failure") + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{items: []string{"msg1"}, err: errors.New("should not reach")}, + GenInput: func(_ context.Context, _ string) (*AgentInput, error) { + return nil, genErr + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return &turnLoopMockAgent{name: "a"}, nil + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, genErr) + assert.Contains(t, err.Error(), "failed to generate agent input") +} + +func TestTurnLoop_GetAgentError(t *testing.T) { + agentErr := errors.New("get agent failure") + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{items: []string{"msg1"}, err: errors.New("should not reach")}, + GenInput: func(_ context.Context, _ string) (*AgentInput, error) { + return &AgentInput{}, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return nil, agentErr + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, agentErr) + assert.Contains(t, err.Error(), "failed to get agent") +} + +func TestTurnLoop_OnAgentEventError(t *testing.T) { + eventErr := errors.New("event handler failure") + agent := &turnLoopMockAgent{ + name: "test-agent", + events: []*AgentEvent{{Output: &AgentOutput{}}}, + } + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{items: []string{"msg1"}, err: errors.New("should not reach")}, + GenInput: func(_ context.Context, _ string) (*AgentInput, error) { + return &AgentInput{}, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + OnAgentEvent: func(_ context.Context, _ *AgentEvent) error { + return eventErr + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, eventErr) + assert.Contains(t, err.Error(), "OnAgentEvent failed") +} + +func TestTurnLoop_ContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + callCount := 0 + source := &turnLoopFuncSource[string]{fn: func(ctx context.Context, _ time.Duration) (string, ConsumeOption, error) { + callCount++ + if callCount > 1 { + cancel() + return "", NonPreemptiveConsumeOption, ctx.Err() + } + return "msg1", NonPreemptiveConsumeOption, nil + }} + + agent := &turnLoopMockAgent{ + name: "test-agent", + events: []*AgentEvent{{Output: &AgentOutput{}}}, + } + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: source, + GenInput: func(_ context.Context, item string) (*AgentInput, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + }) + require.NoError(t, err) + + err = loop.Run(ctx) + assert.ErrorIs(t, err, context.Canceled) +} + +func TestTurnLoop_MultipleEventsPerTurn(t *testing.T) { + agent := &turnLoopMockAgent{ + name: "multi-event-agent", + events: []*AgentEvent{ + {Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("event1", nil)}}}, + {Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("event2", nil)}}}, + {Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("event3", nil)}}}, + }, + } + + var eventCount int + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, + GenInput: func(_ context.Context, item string) (*AgentInput, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + OnAgentEvent: func(_ context.Context, _ *AgentEvent) error { + eventCount++ + return nil + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.Equal(t, 3, eventCount) +} + +func TestTurnLoop_NoOnAgentEvent(t *testing.T) { + agent := &turnLoopMockAgent{ + name: "test-agent", + events: []*AgentEvent{{Output: &AgentOutput{}}}, + } + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, + GenInput: func(_ context.Context, item string) (*AgentInput, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, context.DeadlineExceeded) +} + +func TestTurnLoop_AgentErrorEvent(t *testing.T) { + agentErr := errors.New("agent internal error") + agent := &turnLoopMockAgent{ + name: "error-agent", + events: []*AgentEvent{{Err: agentErr}}, + } + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, + GenInput: func(_ context.Context, item string) (*AgentInput, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, agentErr) + assert.Contains(t, err.Error(), "agent run failed") +} + +// --------------------------------------------------------------------------- +// Tests — preemptive behavior +// --------------------------------------------------------------------------- + +func TestTurnLoop_PreemptiveCancellation(t *testing.T) { + slowAgent := &turnLoopCancellableAgent{ + name: "slow-agent", + startedCh: make(chan struct{}), + cancelCh: make(chan struct{}), + } + + fastAgent := &turnLoopMockAgent{ + name: "fast-agent", + events: []*AgentEvent{{Output: &AgentOutput{}}}, + } + + var processedItems []string + callCount := 0 + source := &turnLoopFuncSource[string]{fn: func(_ context.Context, _ time.Duration) (string, ConsumeOption, error) { + callCount++ + switch callCount { + case 1: + return "slow-msg", NonPreemptiveConsumeOption, nil + case 2: + <-slowAgent.startedCh + return "preempt-msg", ConsumeOption{Mode: ConsumePreemptive, CancelOption: &CancelOption{Mode: CancelImmediate}}, nil + default: + return "", NonPreemptiveConsumeOption, context.DeadlineExceeded + } + }} + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: source, + GenInput: func(_ context.Context, item string) (*AgentInput, error) { + processedItems = append(processedItems, item) + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + }, + GetAgent: func(_ context.Context, item string) (Agent, error) { + if item == "slow-msg" { + return slowAgent, nil + } + return fastAgent, nil + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.True(t, slowAgent.cancelled.Load(), "slow agent should have been cancelled") + assert.Equal(t, CancelImmediate, slowAgent.cancelledOpt.Mode) + assert.Equal(t, []string{"slow-msg", "preempt-msg"}, processedItems) +} + +func TestTurnLoop_PreemptiveWithCancelMode(t *testing.T) { + slowAgent := &turnLoopCancellableAgent{ + name: "slow-agent", + startedCh: make(chan struct{}), + cancelCh: make(chan struct{}), + } + + fastAgent := &turnLoopMockAgent{ + name: "fast-agent", + events: []*AgentEvent{{Output: &AgentOutput{}}}, + } + + callCount := 0 + source := &turnLoopFuncSource[string]{fn: func(_ context.Context, _ time.Duration) (string, ConsumeOption, error) { + callCount++ + switch callCount { + case 1: + return "slow-msg", NonPreemptiveConsumeOption, nil + case 2: + <-slowAgent.startedCh + return "preempt-msg", ConsumeOption{ + Mode: ConsumePreemptive, + CancelOption: &CancelOption{Mode: CancelAfterChatModel | CancelAfterToolCall}, + }, nil + default: + return "", NonPreemptiveConsumeOption, context.DeadlineExceeded + } + }} + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: source, + GenInput: func(_ context.Context, item string) (*AgentInput, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + }, + GetAgent: func(_ context.Context, item string) (Agent, error) { + if item == "slow-msg" { + return slowAgent, nil + } + return fastAgent, nil + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.True(t, slowAgent.cancelled.Load()) + assert.Equal(t, CancelAfterChatModel|CancelAfterToolCall, slowAgent.cancelledOpt.Mode) +} + +func TestTurnLoop_PreemptiveNonCancellableAgent(t *testing.T) { + agentStarted := make(chan struct{}) + agentContinue := make(chan struct{}) + + blockingAgent := &turnLoopBlockingAgent{ + name: "blocking-agent", + startedCh: agentStarted, + blockCh: agentContinue, + events: []*AgentEvent{{Output: &AgentOutput{}}}, + } + fastAgent := &turnLoopMockAgent{ + name: "fast-agent", + events: []*AgentEvent{{Output: &AgentOutput{}}}, + } + + var processedItems []string + callCount := 0 + source := &turnLoopFuncSource[string]{fn: func(_ context.Context, _ time.Duration) (string, ConsumeOption, error) { + callCount++ + switch callCount { + case 1: + return "blocking-msg", NonPreemptiveConsumeOption, nil + case 2: + <-agentStarted + close(agentContinue) + return "preempt-msg", ConsumeOption{Mode: ConsumePreemptive}, nil + default: + return "", NonPreemptiveConsumeOption, context.DeadlineExceeded + } + }} + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: source, + GenInput: func(_ context.Context, item string) (*AgentInput, error) { + processedItems = append(processedItems, item) + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + }, + GetAgent: func(_ context.Context, item string) (Agent, error) { + if item == "blocking-msg" { + return blockingAgent, nil + } + return fastAgent, nil + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.Equal(t, []string{"blocking-msg", "preempt-msg"}, processedItems) +} From b485962fd4a20981222cfd0b3f65f25183f25601 Mon Sep 17 00:00:00 2001 From: "luohuaqing.2018" Date: Thu, 12 Feb 2026 18:13:32 +0800 Subject: [PATCH 32/59] refactor(adk): improve TurnLoop API signatures - Add inputItem param to OnAgentEvent callback - Move RunOptions from static config field to GenInput return value - Change CancelOption from pointer to value type Change-Id: I5bff76223fedb2b48fc5671e90d635a9479b20f6 Co-Authored-By: Claude Opus 4.6 --- adk/interface.go | 8 ++--- adk/turn_loop.go | 26 ++++++++--------- adk/turn_loop_test.go | 68 +++++++++++++++++++++---------------------- 3 files changed, 50 insertions(+), 52 deletions(-) diff --git a/adk/interface.go b/adk/interface.go index fef2f98fe..28ba857cb 100644 --- a/adk/interface.go +++ b/adk/interface.go @@ -301,10 +301,10 @@ type Cancellable interface { // Cancel signals the agent to stop, either immediately or after reaching the // specified execution point(s) defined by opt.Mode. // - // The opt parameter must be non-nil. Use opt.Mode to control when the - // cancellation takes effect (e.g., CancelImmediate, CancelAfterChatModel, - // CancelAfterToolCall, or a combination via bitwise OR). + // Use opt.Mode to control when the cancellation takes effect + // (e.g., CancelImmediate, CancelAfterChatModel, CancelAfterToolCall, + // or a combination via bitwise OR). The zero value defaults to CancelImmediate. // // If the agent has already finished execution, Cancel returns ErrAgentFinished. - Cancel(ctx context.Context, opt *CancelOption) error + Cancel(ctx context.Context, opt CancelOption) error } diff --git a/adk/turn_loop.go b/adk/turn_loop.go index d6ac24071..a8f8d4c4f 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -46,8 +46,8 @@ type ConsumeOption struct { Mode ConsumeMode // CancelOption specifies when and how the running agent should be canceled. // Only meaningful when Mode is ConsumePreemptive and the agent - // implements Cancellable. Default nil value means CancelImmediate. - CancelOption *CancelOption + // implements Cancellable. Default zero value means CancelImmediate. + CancelOption CancelOption } // NonPreemptiveConsumeOption is a convenience value for the common @@ -68,14 +68,14 @@ type MessageSource[T any] interface { type TurnLoopConfig[T any] struct { // Source provides messages to drive the loop. Required. Source MessageSource[T] - // GenInput converts a received message into AgentInput. Required. - GenInput func(ctx context.Context, item T) (*AgentInput, error) + // GenInput converts a received message into AgentInput and optional + // RunOptions for the agent. Required. + GenInput func(ctx context.Context, item T) (*AgentInput, []AgentRunOption, error) // GetAgent returns the Agent to run for a given message. Required. GetAgent func(ctx context.Context, item T) (Agent, error) // OnAgentEvent is called for each event emitted by the agent. Optional. - OnAgentEvent func(ctx context.Context, event *AgentEvent) error - // RunOptions are passed to Agent.Run on each turn. Optional. - RunOptions []AgentRunOption + // The inputItem is the message that triggered the current agent turn. + OnAgentEvent func(ctx context.Context, inputItem T, event *AgentEvent) error // ReceiveTimeout is the timeout passed to Source.Receive on each iteration. // Zero means no timeout. Optional. ReceiveTimeout time.Duration @@ -87,10 +87,9 @@ type TurnLoopConfig[T any] struct { // agent implements Cancellable. type TurnLoop[T any] struct { source MessageSource[T] - genInput func(ctx context.Context, item T) (*AgentInput, error) + genInput func(ctx context.Context, item T) (*AgentInput, []AgentRunOption, error) getAgent func(ctx context.Context, item T) (Agent, error) - onAgentEvent func(ctx context.Context, event *AgentEvent) error - runOptions []AgentRunOption + onAgentEvent func(ctx context.Context, inputItem T, event *AgentEvent) error receiveTimeout time.Duration } @@ -112,7 +111,6 @@ func NewTurnLoop[T any](config TurnLoopConfig[T]) (*TurnLoop[T], error) { genInput: config.GenInput, getAgent: config.GetAgent, onAgentEvent: config.OnAgentEvent, - runOptions: config.RunOptions, receiveTimeout: config.ReceiveTimeout, }, nil } @@ -148,7 +146,7 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { } for { - input, e := l.genInput(ctx, item) + input, runOpts, e := l.genInput(ctx, item) if e != nil { return fmt.Errorf("failed to generate agent input: %w", e) } @@ -168,7 +166,7 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { // Run the agent and forward events through a channel so we can // select between agent events and incoming messages. - iter := agent.Run(ctx, input, l.runOptions...) + iter := agent.Run(ctx, input, runOpts...) eventCh := make(chan iterResult, 1) go func() { for { @@ -199,7 +197,7 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { break eventLoop } if l.onAgentEvent != nil { - if e_ := l.onAgentEvent(ctx, ev.event); e_ != nil { + if e_ := l.onAgentEvent(ctx, item, ev.event); e_ != nil { turnErr = fmt.Errorf("OnAgentEvent failed: %w", e_) break eventLoop } diff --git a/adk/turn_loop_test.go b/adk/turn_loop_test.go index 9ef2d1a90..d27ff4fb8 100644 --- a/adk/turn_loop_test.go +++ b/adk/turn_loop_test.go @@ -84,7 +84,7 @@ type turnLoopCancellableAgent struct { startedCh chan struct{} // closed when Run is entered cancelCh chan struct{} // closed by Cancel cancelled atomic.Bool - cancelledOpt *CancelOption // records the CancelOption passed to Cancel + cancelledOpt CancelOption // records the CancelOption passed to Cancel } func (a *turnLoopCancellableAgent) Name(_ context.Context) string { return a.name } @@ -99,7 +99,7 @@ func (a *turnLoopCancellableAgent) Run(_ context.Context, _ *AgentInput, _ ...Ag return iter } -func (a *turnLoopCancellableAgent) Cancel(_ context.Context, opt *CancelOption) error { +func (a *turnLoopCancellableAgent) Cancel(_ context.Context, opt CancelOption) error { a.cancelled.Store(true) a.cancelledOpt = opt close(a.cancelCh) @@ -137,7 +137,7 @@ func (a *turnLoopBlockingAgent) Run(_ context.Context, _ *AgentInput, _ ...Agent func TestNewTurnLoop_Validation(t *testing.T) { t.Run("missing source", func(t *testing.T) { _, err := NewTurnLoop(TurnLoopConfig[string]{ - GenInput: func(_ context.Context, _ string) (*AgentInput, error) { return nil, nil }, + GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { return nil, nil, nil }, GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, nil }, }) require.Error(t, err) @@ -156,7 +156,7 @@ func TestNewTurnLoop_Validation(t *testing.T) { t.Run("missing GetAgent", func(t *testing.T) { _, err := NewTurnLoop(TurnLoopConfig[string]{ Source: &turnLoopMockSource{}, - GenInput: func(_ context.Context, _ string) (*AgentInput, error) { return nil, nil }, + GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { return nil, nil, nil }, }) require.Error(t, err) assert.Contains(t, err.Error(), "GetAgent") @@ -165,7 +165,7 @@ func TestNewTurnLoop_Validation(t *testing.T) { t.Run("valid config", func(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: &turnLoopMockSource{}, - GenInput: func(_ context.Context, _ string) (*AgentInput, error) { return nil, nil }, + GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { return nil, nil, nil }, GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, nil }, }) require.NoError(t, err) @@ -195,14 +195,14 @@ func TestTurnLoop_NormalLoop(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, error) { + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { receivedItems = append(receivedItems, item) - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil }, GetAgent: func(_ context.Context, _ string) (Agent, error) { return agent, nil }, - OnAgentEvent: func(_ context.Context, event *AgentEvent) error { + OnAgentEvent: func(_ context.Context, _ string, event *AgentEvent) error { receivedEvents = append(receivedEvents, event) return nil }, @@ -224,8 +224,8 @@ func TestTurnLoop_SourceError(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: source, - GenInput: func(_ context.Context, _ string) (*AgentInput, error) { - return &AgentInput{}, nil + GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{}, nil, nil }, GetAgent: func(_ context.Context, _ string) (Agent, error) { return &turnLoopMockAgent{name: "a"}, nil @@ -242,8 +242,8 @@ func TestTurnLoop_GenInputError(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: &turnLoopMockSource{items: []string{"msg1"}, err: errors.New("should not reach")}, - GenInput: func(_ context.Context, _ string) (*AgentInput, error) { - return nil, genErr + GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { + return nil, nil, genErr }, GetAgent: func(_ context.Context, _ string) (Agent, error) { return &turnLoopMockAgent{name: "a"}, nil @@ -261,8 +261,8 @@ func TestTurnLoop_GetAgentError(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: &turnLoopMockSource{items: []string{"msg1"}, err: errors.New("should not reach")}, - GenInput: func(_ context.Context, _ string) (*AgentInput, error) { - return &AgentInput{}, nil + GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{}, nil, nil }, GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, agentErr @@ -284,13 +284,13 @@ func TestTurnLoop_OnAgentEventError(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: &turnLoopMockSource{items: []string{"msg1"}, err: errors.New("should not reach")}, - GenInput: func(_ context.Context, _ string) (*AgentInput, error) { - return &AgentInput{}, nil + GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{}, nil, nil }, GetAgent: func(_ context.Context, _ string) (Agent, error) { return agent, nil }, - OnAgentEvent: func(_ context.Context, _ *AgentEvent) error { + OnAgentEvent: func(_ context.Context, _ string, _ *AgentEvent) error { return eventErr }, }) @@ -321,8 +321,8 @@ func TestTurnLoop_ContextCancellation(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil }, GetAgent: func(_ context.Context, _ string) (Agent, error) { return agent, nil @@ -348,13 +348,13 @@ func TestTurnLoop_MultipleEventsPerTurn(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, - GenInput: func(_ context.Context, item string) (*AgentInput, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil }, GetAgent: func(_ context.Context, _ string) (Agent, error) { return agent, nil }, - OnAgentEvent: func(_ context.Context, _ *AgentEvent) error { + OnAgentEvent: func(_ context.Context, _ string, _ *AgentEvent) error { eventCount++ return nil }, @@ -374,8 +374,8 @@ func TestTurnLoop_NoOnAgentEvent(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, - GenInput: func(_ context.Context, item string) (*AgentInput, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil }, GetAgent: func(_ context.Context, _ string) (Agent, error) { return agent, nil @@ -396,8 +396,8 @@ func TestTurnLoop_AgentErrorEvent(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, - GenInput: func(_ context.Context, item string) (*AgentInput, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil }, GetAgent: func(_ context.Context, _ string) (Agent, error) { return agent, nil @@ -435,7 +435,7 @@ func TestTurnLoop_PreemptiveCancellation(t *testing.T) { return "slow-msg", NonPreemptiveConsumeOption, nil case 2: <-slowAgent.startedCh - return "preempt-msg", ConsumeOption{Mode: ConsumePreemptive, CancelOption: &CancelOption{Mode: CancelImmediate}}, nil + return "preempt-msg", ConsumeOption{Mode: ConsumePreemptive, CancelOption: CancelOption{Mode: CancelImmediate}}, nil default: return "", NonPreemptiveConsumeOption, context.DeadlineExceeded } @@ -443,9 +443,9 @@ func TestTurnLoop_PreemptiveCancellation(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, error) { + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { processedItems = append(processedItems, item) - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil }, GetAgent: func(_ context.Context, item string) (Agent, error) { if item == "slow-msg" { @@ -485,7 +485,7 @@ func TestTurnLoop_PreemptiveWithCancelMode(t *testing.T) { <-slowAgent.startedCh return "preempt-msg", ConsumeOption{ Mode: ConsumePreemptive, - CancelOption: &CancelOption{Mode: CancelAfterChatModel | CancelAfterToolCall}, + CancelOption: CancelOption{Mode: CancelAfterChatModel | CancelAfterToolCall}, }, nil default: return "", NonPreemptiveConsumeOption, context.DeadlineExceeded @@ -494,8 +494,8 @@ func TestTurnLoop_PreemptiveWithCancelMode(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil }, GetAgent: func(_ context.Context, item string) (Agent, error) { if item == "slow-msg" { @@ -545,9 +545,9 @@ func TestTurnLoop_PreemptiveNonCancellableAgent(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, error) { + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { processedItems = append(processedItems, item) - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil }, GetAgent: func(_ context.Context, item string) (Agent, error) { if item == "blocking-msg" { From 1f5017022d1f23163f48aeee83e31a62a579c12a Mon Sep 17 00:00:00 2001 From: "luohuaqing.2018" Date: Fri, 13 Feb 2026 12:42:59 +0800 Subject: [PATCH 33/59] fix(adk): TurnLoop.Run Change-Id: Iaa0da27dcab3fdd446828a592142a4504278daf7 --- adk/turn_loop.go | 178 +++++++++++++++++++----------------------- adk/turn_loop_test.go | 52 +++--------- 2 files changed, 95 insertions(+), 135 deletions(-) diff --git a/adk/turn_loop.go b/adk/turn_loop.go index a8f8d4c4f..46100e1ee 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -19,7 +19,10 @@ package adk import ( "context" "fmt" + `runtime/debug` "time" + + `github.com/cloudwego/eino/internal/safe` ) // ConsumeMode specifies how a received message should be consumed @@ -115,34 +118,20 @@ func NewTurnLoop[T any](config TurnLoopConfig[T]) (*TurnLoop[T], error) { }, nil } -// recvResult holds the result of a concurrent Receive call. -type recvResult[T any] struct { - item T - option ConsumeOption - err error -} - -// iterResult holds the result of a single AsyncIterator.Next call. -type iterResult struct { - event *AgentEvent - ok bool -} - -// Run starts the blocking loop that continuously receives messages, runs -// agents, and dispatches events. While an agent is running, the next message -// is received concurrently. If that message's ConsumeOption has ConsumePreemptive -// mode and the running agent implements Cancellable, the agent is canceled -// (using the CancelMode from the option) and the new message is processed -// immediately. +// Run starts the blocking loop that continuously receives messages from the +// source, runs the agent returned by GetAgent for each message, and dispatches +// resulting events to OnAgentEvent. It blocks until the source returns an error +// (including context cancellation) or a callback fails. +// +// If a received message has ConsumePreemptive mode and the current agent +// implements Cancellable, the agent is canceled and the new message is processed +// immediately. If the agent does not implement Cancellable, preemptive messages +// are queued and processed after the current agent finishes. func (l *TurnLoop[T]) Run(ctx context.Context) error { - // done is closed when Run returns, signaling background goroutines to exit. - done := make(chan struct{}) - defer close(done) - - // Initial blocking receive — no agent running yet, mode is irrelevant. - item, _, err := l.source.Receive(ctx, l.receiveTimeout) + // Initial blocking receive — no agent is running yet. + item, option, err := l.source.Receive(ctx, l.receiveTimeout) if err != nil { - return err + return fmt.Errorf("failed to receive message: %w", err) } for { @@ -156,94 +145,91 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { return fmt.Errorf("failed to get agent: %w", e) } - // Start receiving the next message concurrently with agent execution. - // The channel is buffered so the goroutine never blocks on send. - recvCh := make(chan recvResult[T], 1) - go func() { - i, opt, e_ := l.source.Receive(ctx, l.receiveTimeout) - recvCh <- recvResult[T]{i, opt, e_} - }() - - // Run the agent and forward events through a channel so we can - // select between agent events and incoming messages. + ca, isAgentCancellable := agent.(Cancellable) iter := agent.Run(ctx, input, runOpts...) - eventCh := make(chan iterResult, 1) - go func() { + + // handleEvents drains the agent iterator, forwarding each event to the + // OnAgentEvent callback. It is called directly in the non-cancellable + // path and from a goroutine in the cancellable path. + handleEvents := func() error { for { event, ok := iter.Next() - select { - case eventCh <- iterResult{event, ok}: - case <-done: - return - } if !ok { - return + break } - } - }() - - var pending *recvResult[T] - var turnErr error - eventLoop: - for { - select { - case ev := <-eventCh: - if !ev.ok { - break eventLoop - } - if ev.event.Err != nil { - turnErr = fmt.Errorf("agent run failed: %w", ev.event.Err) - break eventLoop + if event.Err != nil { + return fmt.Errorf("agent run failed: %w", event.Err) } + if l.onAgentEvent != nil { - if e_ := l.onAgentEvent(ctx, item, ev.event); e_ != nil { - turnErr = fmt.Errorf("OnAgentEvent failed: %w", e_) - break eventLoop + e = l.onAgentEvent(ctx, item, event) + if e != nil { + return fmt.Errorf("OnAgentEvent callback failed: %w", e) } } + } - case recv := <-recvCh: - recvCh = nil // nil channel never matches in select - if recv.option.Mode == ConsumePreemptive && recv.err == nil { - if ca, ok := agent.(Cancellable); ok { - if e_ := ca.Cancel(ctx, recv.option.CancelOption); e_ != nil { - return fmt.Errorf("failed to cancel agent: %w", e_) - } - // Drain remaining events after cancellation. - for { - ev := <-eventCh - if !ev.ok { - break - } - } - pending = &recvResult[T]{item: recv.item} - break eventLoop + return nil + } + + var handleEventErr error + if isAgentCancellable { + // Cancellable path: consume events in a goroutine so the main + // goroutine can block on Receive concurrently. + done := make(chan struct{}) + + go func() { + defer func() { + // Recover panics from the iterator or callback so they + // don't crash the process; surface them as errors instead. + panicErr := recover() + if panicErr != nil { + handleEventErr = safe.NewPanicErr(panicErr, debug.Stack()) } - } - // Non-preemptive, preemptive but not Cancellable, or - // source error: buffer and let the eventLoop finish - // processing the current agent's events first. - pending = &recvResult[T]{item: recv.item, err: recv.err} + + close(done) + }() + + handleEventErr = handleEvents() + }() + + // Block on the next message while events are being consumed above. + item, option, err = l.source.Receive(ctx, l.receiveTimeout) + if err != nil { + <-done // wait for the event goroutine before returning + return fmt.Errorf("failed to receive message: %w", err) } - } - if turnErr != nil { - return turnErr - } + // If the new message requests preemption, cancel the running agent. + // Cancel triggers the iterator to terminate, which unblocks the + // event goroutine above. + if option.Mode == ConsumePreemptive { + err = ca.Cancel(ctx, option.CancelOption) + if err != nil { + <-done // wait for the event goroutine before returning + return fmt.Errorf("failed to cancel agent: %w", err) + } + } - if pending != nil { - if pending.err != nil { - return pending.err + // Wait for event consumption to finish (normal completion or + // post-cancel drain) before starting the next turn. + <-done + if handleEventErr != nil { + return fmt.Errorf("failed to handle events: %w", handleEventErr) } - item = pending.item } else { - // Agent finished before the next message arrived; wait for it. - recv := <-recvCh - if recv.err != nil { - return recv.err + // Non-cancellable path: consume all events sequentially, then + // block on the next message. + if handleEventErr = handleEvents(); handleEventErr != nil { + return fmt.Errorf("failed to handle events: %w", handleEventErr) + } + + item, option, err = l.source.Receive(ctx, l.receiveTimeout) + if err != nil { + return fmt.Errorf("failed to receive message: %w", err) } - item = recv.item } } } + diff --git a/adk/turn_loop_test.go b/adk/turn_loop_test.go index d27ff4fb8..f66c1298b 100644 --- a/adk/turn_loop_test.go +++ b/adk/turn_loop_test.go @@ -106,30 +106,6 @@ func (a *turnLoopCancellableAgent) Cancel(_ context.Context, opt CancelOption) e return nil } -// turnLoopBlockingAgent blocks until blockCh is closed, then emits its events. -// It does NOT implement Cancellable. -type turnLoopBlockingAgent struct { - name string - startedCh chan struct{} - blockCh chan struct{} - events []*AgentEvent -} - -func (a *turnLoopBlockingAgent) Name(_ context.Context) string { return a.name } -func (a *turnLoopBlockingAgent) Description(_ context.Context) string { return "blocking mock" } -func (a *turnLoopBlockingAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { - iter, gen := NewAsyncIteratorPair[*AgentEvent]() - close(a.startedCh) - go func() { - defer gen.Close() - <-a.blockCh - for _, e := range a.events { - gen.Send(e) - } - }() - return iter -} - // --------------------------------------------------------------------------- // Tests — validation // --------------------------------------------------------------------------- @@ -298,7 +274,7 @@ func TestTurnLoop_OnAgentEventError(t *testing.T) { err = loop.Run(context.Background()) assert.ErrorIs(t, err, eventErr) - assert.Contains(t, err.Error(), "OnAgentEvent failed") + assert.Contains(t, err.Error(), "OnAgentEvent callback failed") } func TestTurnLoop_ContextCancellation(t *testing.T) { @@ -513,14 +489,12 @@ func TestTurnLoop_PreemptiveWithCancelMode(t *testing.T) { } func TestTurnLoop_PreemptiveNonCancellableAgent(t *testing.T) { - agentStarted := make(chan struct{}) - agentContinue := make(chan struct{}) - - blockingAgent := &turnLoopBlockingAgent{ - name: "blocking-agent", - startedCh: agentStarted, - blockCh: agentContinue, - events: []*AgentEvent{{Output: &AgentOutput{}}}, + // A non-cancellable agent cannot be preempted, so the new Run processes + // events sequentially before calling Receive. The preemptive message is + // effectively queued and processed in the next turn. + nonCancellableAgent := &turnLoopMockAgent{ + name: "non-cancellable-agent", + events: []*AgentEvent{{Output: &AgentOutput{}}}, } fastAgent := &turnLoopMockAgent{ name: "fast-agent", @@ -533,10 +507,10 @@ func TestTurnLoop_PreemptiveNonCancellableAgent(t *testing.T) { callCount++ switch callCount { case 1: - return "blocking-msg", NonPreemptiveConsumeOption, nil + return "non-cancel-msg", NonPreemptiveConsumeOption, nil case 2: - <-agentStarted - close(agentContinue) + // Even though Mode is ConsumePreemptive, the agent doesn't + // implement Cancellable, so it's treated as non-preemptive. return "preempt-msg", ConsumeOption{Mode: ConsumePreemptive}, nil default: return "", NonPreemptiveConsumeOption, context.DeadlineExceeded @@ -550,8 +524,8 @@ func TestTurnLoop_PreemptiveNonCancellableAgent(t *testing.T) { return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil }, GetAgent: func(_ context.Context, item string) (Agent, error) { - if item == "blocking-msg" { - return blockingAgent, nil + if item == "non-cancel-msg" { + return nonCancellableAgent, nil } return fastAgent, nil }, @@ -560,5 +534,5 @@ func TestTurnLoop_PreemptiveNonCancellableAgent(t *testing.T) { err = loop.Run(context.Background()) assert.ErrorIs(t, err, context.DeadlineExceeded) - assert.Equal(t, []string{"blocking-msg", "preempt-msg"}, processedItems) + assert.Equal(t, []string{"non-cancel-msg", "preempt-msg"}, processedItems) } From ca1c8b17f5883339e1333605f8ef28de667c0ca6 Mon Sep 17 00:00:00 2001 From: Megumin Date: Sat, 14 Feb 2026 11:19:24 +0800 Subject: [PATCH 34/59] chore --- adk/interface.go | 41 ++++++++++++------- adk/turn_loop.go | 104 ++++++++++++++++++++++++++++++++--------------- 2 files changed, 97 insertions(+), 48 deletions(-) diff --git a/adk/interface.go b/adk/interface.go index 28ba857cb..2e7e3d09c 100644 --- a/adk/interface.go +++ b/adk/interface.go @@ -23,6 +23,7 @@ import ( "errors" "fmt" "io" + "time" "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/internal/core" @@ -290,21 +291,31 @@ const ( // ErrAgentFinished is returned by Cancel when the agent has already finished execution. var ErrAgentFinished = errors.New("agent has already finished execution") -// CancelOption holds options for cancelling an agent. -type CancelOption struct { - Mode CancelMode +type cancelConfig struct { + Mode CancelMode + Timeout *time.Duration } -// Cancellable is an optional interface that an Agent can implement to support -// cancellation during execution. -type Cancellable interface { - // Cancel signals the agent to stop, either immediately or after reaching the - // specified execution point(s) defined by opt.Mode. - // - // Use opt.Mode to control when the cancellation takes effect - // (e.g., CancelImmediate, CancelAfterChatModel, CancelAfterToolCall, - // or a combination via bitwise OR). The zero value defaults to CancelImmediate. - // - // If the agent has already finished execution, Cancel returns ErrAgentFinished. - Cancel(ctx context.Context, opt CancelOption) error +type CancelOption func(*cancelConfig) + +func WithCancelMode(mode CancelMode) CancelOption { + return func(config *cancelConfig) { + config.Mode = mode + } +} + +func WithCancelTimeout(timeout time.Duration) CancelOption { + return func(config *cancelConfig) { + config.Timeout = &timeout + } +} + +type CancelFunc func(context.Context, ...CancelOption) error + +type CancellableRun interface { + RunWithCancel(ctx context.Context, input *AgentInput, options ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) +} + +type CancellableResume interface { + ResumeWithCancel(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) } diff --git a/adk/turn_loop.go b/adk/turn_loop.go index 46100e1ee..164bb44ed 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -19,10 +19,10 @@ package adk import ( "context" "fmt" - `runtime/debug` + "runtime/debug" "time" - `github.com/cloudwego/eino/internal/safe` + "github.com/cloudwego/eino/internal/safe" ) // ConsumeMode specifies how a received message should be consumed @@ -38,33 +38,58 @@ const ( // If the agent does not implement Cancellable, the message is // buffered and processed after the agent finishes. ConsumePreemptive + + ConsumePreemptiveOnTimeout ) -// ConsumeOption describes how a received message should be consumed. -// It combines ConsumeMode (preemptive vs queued) with CancelMode -// (how to cancel the running agent when preempting). -type ConsumeOption struct { - // Mode specifies whether the message should preempt the current agent - // or be queued. Default zero value is ConsumeNonPreemptive. - Mode ConsumeMode - // CancelOption specifies when and how the running agent should be canceled. - // Only meaningful when Mode is ConsumePreemptive and the agent - // implements Cancellable. Default zero value means CancelImmediate. - CancelOption CancelOption +type consumeConfig struct { + Mode ConsumeMode + Timeout time.Duration + CancelOpts []CancelOption +} + +type ConsumeOption func(*consumeConfig) + +func WithPreemptive() ConsumeOption { + return func(config *consumeConfig) { + config.Mode = ConsumePreemptive + } +} + +func WithPreemptiveOnTimeout(timeout time.Duration) ConsumeOption { + return func(config *consumeConfig) { + config.Mode = ConsumePreemptive + config.Timeout = timeout + } +} + +func WithCancelOptions(opts ...CancelOption) ConsumeOption { + return func(config *consumeConfig) { + config.CancelOpts = append(config.CancelOpts, opts...) + } } -// NonPreemptiveConsumeOption is a convenience value for the common -// non-preemptive (queued) case. -var NonPreemptiveConsumeOption = ConsumeOption{Mode: ConsumeNonPreemptive} +type receiveConfig struct { + Timeout *time.Duration + NonBlocking bool +} + +type ReceiveOption func(*receiveConfig) + +func WithReceiveTimeout(timeout time.Duration) ReceiveOption { + return func(config *receiveConfig) { + config.Timeout = &timeout + } +} + +func WithReceiveNonBlocking() ReceiveOption { + return func(config *receiveConfig) { + config.NonBlocking = true + } +} -// MessageSource is an interface for pulling typed messages from an external source. -// Receive blocks until a message is available or an error occurs. -// The timeout parameter specifies the maximum duration to wait for a message. -// The returned ConsumeOption indicates whether the message should preempt the -// currently running agent (and how to cancel it) or be queued for processing -// after it finishes. type MessageSource[T any] interface { - Receive(ctx context.Context, timeout time.Duration) (T, ConsumeOption, error) + Receive(ctx context.Context, option ...ReceiveOption) (context.Context, T, []ConsumeOption, error) } // TurnLoopConfig is the configuration for creating a TurnLoop. @@ -129,24 +154,29 @@ func NewTurnLoop[T any](config TurnLoopConfig[T]) (*TurnLoop[T], error) { // are queued and processed after the current agent finishes. func (l *TurnLoop[T]) Run(ctx context.Context) error { // Initial blocking receive — no agent is running yet. - item, option, err := l.source.Receive(ctx, l.receiveTimeout) + nCtx, item, option, err := l.source.Receive(ctx, WithReceiveTimeout(l.receiveTimeout)) if err != nil { return fmt.Errorf("failed to receive message: %w", err) } for { - input, runOpts, e := l.genInput(ctx, item) + input, runOpts, e := l.genInput(nCtx, item) if e != nil { return fmt.Errorf("failed to generate agent input: %w", e) } - agent, e := l.getAgent(ctx, item) + agent, e := l.getAgent(nCtx, item) if e != nil { return fmt.Errorf("failed to get agent: %w", e) } - ca, isAgentCancellable := agent.(Cancellable) - iter := agent.Run(ctx, input, runOpts...) + var cancelFunc CancelFunc + var iter *AsyncIterator[*AgentEvent] + if ca, isAgentCancellable := agent.(CancellableRun); isAgentCancellable { + iter, cancelFunc = ca.RunWithCancel(nCtx, input, runOpts...) + } else { + iter = agent.Run(nCtx, input, runOpts...) + } // handleEvents drains the agent iterator, forwarding each event to the // OnAgentEvent callback. It is called directly in the non-cancellable @@ -174,7 +204,7 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { } var handleEventErr error - if isAgentCancellable { + if cancelFunc != nil { // Cancellable path: consume events in a goroutine so the main // goroutine can block on Receive concurrently. done := make(chan struct{}) @@ -195,7 +225,7 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { }() // Block on the next message while events are being consumed above. - item, option, err = l.source.Receive(ctx, l.receiveTimeout) + nCtx, item, option, err = l.source.Receive(ctx, WithReceiveTimeout(l.receiveTimeout)) if err != nil { <-done // wait for the event goroutine before returning return fmt.Errorf("failed to receive message: %w", err) @@ -204,8 +234,9 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { // If the new message requests preemption, cancel the running agent. // Cancel triggers the iterator to terminate, which unblocks the // event goroutine above. - if option.Mode == ConsumePreemptive { - err = ca.Cancel(ctx, option.CancelOption) + o := applyConsumeOptions(option) + if o.Mode == ConsumePreemptive { + err = cancelFunc(ctx, o.CancelOpts...) if err != nil { <-done // wait for the event goroutine before returning return fmt.Errorf("failed to cancel agent: %w", err) @@ -225,7 +256,7 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { return fmt.Errorf("failed to handle events: %w", handleEventErr) } - item, option, err = l.source.Receive(ctx, l.receiveTimeout) + nCtx, item, option, err = l.source.Receive(ctx, WithReceiveTimeout(l.receiveTimeout)) if err != nil { return fmt.Errorf("failed to receive message: %w", err) } @@ -233,3 +264,10 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { } } +func applyConsumeOptions(opts []ConsumeOption) *consumeConfig { + var config consumeConfig + for _, opt := range opts { + opt(&config) + } + return &config +} From 5f674e2a9314c69ed70309f5f1b054f97515dc45 Mon Sep 17 00:00:00 2001 From: Megumin Date: Sat, 14 Feb 2026 11:33:38 +0800 Subject: [PATCH 35/59] chore --- adk/turn_loop.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/adk/turn_loop.go b/adk/turn_loop.go index 164bb44ed..376cd667d 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -225,7 +225,7 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { }() // Block on the next message while events are being consumed above. - nCtx, item, option, err = l.source.Receive(ctx, WithReceiveTimeout(l.receiveTimeout)) + nCtx, item, option, err = l.source.Receive(nCtx, WithReceiveTimeout(l.receiveTimeout)) if err != nil { <-done // wait for the event goroutine before returning return fmt.Errorf("failed to receive message: %w", err) @@ -236,7 +236,7 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { // event goroutine above. o := applyConsumeOptions(option) if o.Mode == ConsumePreemptive { - err = cancelFunc(ctx, o.CancelOpts...) + err = cancelFunc(nCtx, o.CancelOpts...) if err != nil { <-done // wait for the event goroutine before returning return fmt.Errorf("failed to cancel agent: %w", err) @@ -256,7 +256,7 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { return fmt.Errorf("failed to handle events: %w", handleEventErr) } - nCtx, item, option, err = l.source.Receive(ctx, WithReceiveTimeout(l.receiveTimeout)) + nCtx, item, option, err = l.source.Receive(nCtx, WithReceiveTimeout(l.receiveTimeout)) if err != nil { return fmt.Errorf("failed to receive message: %w", err) } From b6d24c6c3120e5b7bf96cfeef9cd553d4fce63e7 Mon Sep 17 00:00:00 2001 From: Megumin Date: Sat, 14 Feb 2026 11:42:05 +0800 Subject: [PATCH 36/59] chore --- adk/turn_loop.go | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/adk/turn_loop.go b/adk/turn_loop.go index 376cd667d..673a8fcea 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -58,7 +58,7 @@ func WithPreemptive() ConsumeOption { func WithPreemptiveOnTimeout(timeout time.Duration) ConsumeOption { return func(config *consumeConfig) { - config.Mode = ConsumePreemptive + config.Mode = ConsumePreemptiveOnTimeout config.Timeout = timeout } } @@ -235,12 +235,23 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { // Cancel triggers the iterator to terminate, which unblocks the // event goroutine above. o := applyConsumeOptions(option) - if o.Mode == ConsumePreemptive { + switch o.Mode { + case ConsumePreemptive: err = cancelFunc(nCtx, o.CancelOpts...) if err != nil { <-done // wait for the event goroutine before returning return fmt.Errorf("failed to cancel agent: %w", err) } + case ConsumePreemptiveOnTimeout: + select { + case <-done: + case <-time.After(o.Timeout): + err = cancelFunc(nCtx, o.CancelOpts...) + if err != nil { + <-done // wait for the event goroutine before returning + return fmt.Errorf("failed to cancel agent: %w", err) + } + } } // Wait for event consumption to finish (normal completion or From 8282cd682b819686cdc57d3b497034b98e511e28 Mon Sep 17 00:00:00 2001 From: Megumin Date: Sat, 14 Feb 2026 16:13:04 +0800 Subject: [PATCH 37/59] chore --- adk/turn_loop.go | 33 ++++++++++++--------------------- 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/adk/turn_loop.go b/adk/turn_loop.go index 673a8fcea..ded163547 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -69,27 +69,12 @@ func WithCancelOptions(opts ...CancelOption) ConsumeOption { } } -type receiveConfig struct { - Timeout *time.Duration - NonBlocking bool -} - -type ReceiveOption func(*receiveConfig) - -func WithReceiveTimeout(timeout time.Duration) ReceiveOption { - return func(config *receiveConfig) { - config.Timeout = &timeout - } -} - -func WithReceiveNonBlocking() ReceiveOption { - return func(config *receiveConfig) { - config.NonBlocking = true - } +type ReceiveConfig struct { + Timeout time.Duration } type MessageSource[T any] interface { - Receive(ctx context.Context, option ...ReceiveOption) (context.Context, T, []ConsumeOption, error) + Receive(context.Context, ReceiveConfig) (context.Context, T, []ConsumeOption, error) } // TurnLoopConfig is the configuration for creating a TurnLoop. @@ -154,7 +139,9 @@ func NewTurnLoop[T any](config TurnLoopConfig[T]) (*TurnLoop[T], error) { // are queued and processed after the current agent finishes. func (l *TurnLoop[T]) Run(ctx context.Context) error { // Initial blocking receive — no agent is running yet. - nCtx, item, option, err := l.source.Receive(ctx, WithReceiveTimeout(l.receiveTimeout)) + nCtx, item, option, err := l.source.Receive(ctx, ReceiveConfig{ + Timeout: l.receiveTimeout, + }) if err != nil { return fmt.Errorf("failed to receive message: %w", err) } @@ -225,7 +212,9 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { }() // Block on the next message while events are being consumed above. - nCtx, item, option, err = l.source.Receive(nCtx, WithReceiveTimeout(l.receiveTimeout)) + nCtx, item, option, err = l.source.Receive(nCtx, ReceiveConfig{ + Timeout: l.receiveTimeout, + }) if err != nil { <-done // wait for the event goroutine before returning return fmt.Errorf("failed to receive message: %w", err) @@ -267,7 +256,9 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { return fmt.Errorf("failed to handle events: %w", handleEventErr) } - nCtx, item, option, err = l.source.Receive(nCtx, WithReceiveTimeout(l.receiveTimeout)) + nCtx, item, option, err = l.source.Receive(nCtx, ReceiveConfig{ + Timeout: l.receiveTimeout, + }) if err != nil { return fmt.Errorf("failed to receive message: %w", err) } From 1353c2695e39a1df65db4e52c2b226d97f57cbf0 Mon Sep 17 00:00:00 2001 From: Megumin Date: Thu, 19 Feb 2026 18:16:46 +0800 Subject: [PATCH 38/59] feat(adk): modify on agent events (#795) --- adk/turn_loop.go | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/adk/turn_loop.go b/adk/turn_loop.go index ded163547..bc4d9b69c 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -88,7 +88,7 @@ type TurnLoopConfig[T any] struct { GetAgent func(ctx context.Context, item T) (Agent, error) // OnAgentEvent is called for each event emitted by the agent. Optional. // The inputItem is the message that triggered the current agent turn. - OnAgentEvent func(ctx context.Context, inputItem T, event *AgentEvent) error + OnAgentEvents func(ctx context.Context, inputItem T, event *AsyncIterator[*AgentEvent]) error // ReceiveTimeout is the timeout passed to Source.Receive on each iteration. // Zero means no timeout. Optional. ReceiveTimeout time.Duration @@ -102,7 +102,7 @@ type TurnLoop[T any] struct { source MessageSource[T] genInput func(ctx context.Context, item T) (*AgentInput, []AgentRunOption, error) getAgent func(ctx context.Context, item T) (Agent, error) - onAgentEvent func(ctx context.Context, inputItem T, event *AgentEvent) error + onAgentEvents func(ctx context.Context, inputItem T, event *AsyncIterator[*AgentEvent]) error receiveTimeout time.Duration } @@ -123,7 +123,7 @@ func NewTurnLoop[T any](config TurnLoopConfig[T]) (*TurnLoop[T], error) { source: config.Source, genInput: config.GenInput, getAgent: config.GetAgent, - onAgentEvent: config.OnAgentEvent, + onAgentEvents: config.OnAgentEvents, receiveTimeout: config.ReceiveTimeout, }, nil } @@ -169,24 +169,10 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { // OnAgentEvent callback. It is called directly in the non-cancellable // path and from a goroutine in the cancellable path. handleEvents := func() error { - for { - event, ok := iter.Next() - if !ok { - break - } - - if event.Err != nil { - return fmt.Errorf("agent run failed: %w", event.Err) - } - - if l.onAgentEvent != nil { - e = l.onAgentEvent(ctx, item, event) - if e != nil { - return fmt.Errorf("OnAgentEvent callback failed: %w", e) - } - } + oe := l.onAgentEvents(ctx, item, iter) + if oe != nil { + return oe } - return nil } From 8b5c0206c99579a8fe1f2aefd924508ebd3f1435 Mon Sep 17 00:00:00 2001 From: Megumin Date: Sat, 21 Feb 2026 11:45:32 +0800 Subject: [PATCH 39/59] feat(adk): turn loop support front and exit loop (#796) --- adk/turn_loop.go | 42 +++++++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/adk/turn_loop.go b/adk/turn_loop.go index bc4d9b69c..e5cf1c36f 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -18,6 +18,7 @@ package adk import ( "context" + "errors" "fmt" "runtime/debug" "time" @@ -75,6 +76,7 @@ type ReceiveConfig struct { type MessageSource[T any] interface { Receive(context.Context, ReceiveConfig) (context.Context, T, []ConsumeOption, error) + Front(context.Context, ReceiveConfig) (context.Context, T, []ConsumeOption, error) } // TurnLoopConfig is the configuration for creating a TurnLoop. @@ -128,6 +130,8 @@ func NewTurnLoop[T any](config TurnLoopConfig[T]) (*TurnLoop[T], error) { }, nil } +var ErrLoopExit = errors.New("loop exit") + // Run starts the blocking loop that continuously receives messages from the // source, runs the agent returned by GetAgent for each message, and dispatches // resulting events to OnAgentEvent. It blocks until the source returns an error @@ -138,15 +142,17 @@ func NewTurnLoop[T any](config TurnLoopConfig[T]) (*TurnLoop[T], error) { // immediately. If the agent does not implement Cancellable, preemptive messages // are queued and processed after the current agent finishes. func (l *TurnLoop[T]) Run(ctx context.Context) error { - // Initial blocking receive — no agent is running yet. - nCtx, item, option, err := l.source.Receive(ctx, ReceiveConfig{ - Timeout: l.receiveTimeout, - }) - if err != nil { - return fmt.Errorf("failed to receive message: %w", err) - } - for { + nCtx, item, option, err := l.source.Receive(ctx, ReceiveConfig{ + Timeout: l.receiveTimeout, + }) + if errors.Is(err, ErrLoopExit) { + return nil + } + if err != nil { + return fmt.Errorf("failed to receive message: %w", err) + } + input, runOpts, e := l.genInput(nCtx, item) if e != nil { return fmt.Errorf("failed to generate agent input: %w", e) @@ -198,12 +204,15 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { }() // Block on the next message while events are being consumed above. - nCtx, item, option, err = l.source.Receive(nCtx, ReceiveConfig{ + _, _, option, err = l.source.Front(nCtx, ReceiveConfig{ Timeout: l.receiveTimeout, }) if err != nil { <-done // wait for the event goroutine before returning - return fmt.Errorf("failed to receive message: %w", err) + if errors.Is(err, ErrLoopExit) { + return nil + } + return fmt.Errorf("failed to front message: %w", err) } // If the new message requests preemption, cancel the running agent. @@ -233,21 +242,20 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { // post-cancel drain) before starting the next turn. <-done if handleEventErr != nil { + if errors.Is(handleEventErr, ErrLoopExit) { + return nil + } return fmt.Errorf("failed to handle events: %w", handleEventErr) } } else { // Non-cancellable path: consume all events sequentially, then // block on the next message. if handleEventErr = handleEvents(); handleEventErr != nil { + if errors.Is(handleEventErr, ErrLoopExit) { + return nil + } return fmt.Errorf("failed to handle events: %w", handleEventErr) } - - nCtx, item, option, err = l.source.Receive(nCtx, ReceiveConfig{ - Timeout: l.receiveTimeout, - }) - if err != nil { - return fmt.Errorf("failed to receive message: %w", err) - } } } } From 64bc098a304f732b1b5236481f223fba31d2518d Mon Sep 17 00:00:00 2001 From: IPender Date: Tue, 24 Feb 2026 14:29:37 +0800 Subject: [PATCH 40/59] feat(adk): implement cancel mechanism for ChatModelAgent (#797) --- .gitignore | 1 + adk/cancel_test.go | 1023 +++++++++++++++++++++++++++++++++++++++++ adk/cancel_wrapper.go | 280 +++++++++++ adk/chatmodel.go | 149 ++++-- adk/flow.go | 96 ++-- adk/interface.go | 11 +- adk/react.go | 100 +++- adk/react_test.go | 12 +- adk/runner.go | 96 +++- adk/turn_loop.go | 428 ++++++++++++++--- adk/turn_loop_test.go | 811 +++++++++++++++++++++++++++----- 11 files changed, 2747 insertions(+), 260 deletions(-) create mode 100644 adk/cancel_test.go create mode 100644 adk/cancel_wrapper.go diff --git a/.gitignore b/.gitignore index 8ef36de95..8ac1d568d 100644 --- a/.gitignore +++ b/.gitignore @@ -47,6 +47,7 @@ output/* # Reports (generated analysis files) reports/ +/todos .DS_Store *.log diff --git a/adk/cancel_test.go b/adk/cancel_test.go new file mode 100644 index 000000000..f45e344ff --- /dev/null +++ b/adk/cancel_test.go @@ -0,0 +1,1023 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +type cancelTestChatModel struct { + delay time.Duration + response *schema.Message + startedChan chan struct{} + doneChan chan struct{} +} + +func (m *cancelTestChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + select { + case m.startedChan <- struct{}{}: + default: + } + time.Sleep(m.delay) + select { + case m.doneChan <- struct{}{}: + default: + } + return m.response, nil +} + +func (m *cancelTestChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + m.startedChan <- struct{}{} + time.Sleep(m.delay) + m.doneChan <- struct{}{} + return schema.StreamReaderFromArray([]*schema.Message{m.response}), nil +} + +func (m *cancelTestChatModel) BindTools(tools []*schema.ToolInfo) error { + return nil +} + +type slowTool struct { + name string + delay time.Duration + result string + callCount int32 + startedChan chan struct{} +} + +func newSlowTool(name string, delay time.Duration, result string) *slowTool { + return &slowTool{ + name: name, + delay: delay, + result: result, + startedChan: make(chan struct{}, 10), + } +} + +func (t *slowTool) Info(ctx context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "A slow tool for testing", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Type: "string", Desc: "Input parameter"}, + }), + }, nil +} + +func (t *slowTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { + atomic.AddInt32(&t.callCount, 1) + select { + case t.startedChan <- struct{}{}: + default: + } + time.Sleep(t.delay) + return t.result, nil +} + +type cancelTestStore struct { + m map[string][]byte + mu sync.Mutex +} + +func newCancelTestStore() *cancelTestStore { + return &cancelTestStore{m: make(map[string][]byte)} +} + +func (s *cancelTestStore) Set(_ context.Context, key string, value []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + s.m[key] = value + return nil +} + +func (s *cancelTestStore) Get(_ context.Context, key string) ([]byte, bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + v, ok := s.m[key] + return v, ok, nil +} + +func TestCancelSig(t *testing.T) { + t.Run("BasicCancelSignal", func(t *testing.T) { + cs := newCancelSig() + + cfg := checkCancelSig(cs) + assert.Nil(t, cfg, "Should not be cancelled initially") + + cs.cancel(&cancelConfig{Mode: CancelImmediate}) + + cfg = checkCancelSig(cs) + assert.NotNil(t, cfg, "Should be cancelled after cancel()") + assert.Equal(t, CancelImmediate, cfg.Mode) + }) +} + +func TestRunWithCancel_WithTools(t *testing.T) { + ctx := context.Background() + + t.Run("CancelImmediate_DuringModelCall", func(t *testing.T) { + modelStarted := make(chan struct{}, 1) + st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") + + slowModel := &cancelTestChatModel{ + delay: 2 * time.Second, + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: slowModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: false, + }) + + iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}) + assert.NotNil(t, iter) + assert.NotNil(t, cancelFn) + + eventsCh := make(chan []*AgentEvent, 1) + go func() { + var events []*AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + events = append(events, event) + } + eventsCh <- events + }() + + select { + case <-modelStarted: + case <-time.After(5 * time.Second): + t.Fatal("Model did not start within 5 seconds") + } + + time.Sleep(100 * time.Millisecond) + + err = cancelFn(ctx) + assert.NoError(t, err) + + start := time.Now() + events := <-eventsCh + elapsed := time.Since(start) + + assert.True(t, elapsed < 1*time.Second, "Should return quickly after cancel, elapsed: %v", elapsed) + assert.True(t, len(events) > 0) + + hasInterrupted := false + for _, e := range events { + assert.Nil(t, e.Err, "Should not have error event after cancel") + if e.Action != nil && e.Action.Interrupted != nil { + hasInterrupted = true + } + } + assert.True(t, hasInterrupted, "Should have interrupted event after cancel") + }) + + t.Run("CancelAfterChatModel_DuringToolCall", func(t *testing.T) { + toolStarted := make(chan struct{}, 1) + st := &slowToolWithSignal{ + name: "slow_tool", + delay: 2 * time.Second, + result: "tool result", + startedChan: toolStarted, + } + + modelWithToolCall := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: modelWithToolCall, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + iter, cancelFn := agent.RunWithCancel(ctx, &AgentInput{ + Messages: []Message{schema.UserMessage("Use the tool")}, + }) + assert.NotNil(t, iter) + assert.NotNil(t, cancelFn) + + <-toolStarted + + time.Sleep(100 * time.Millisecond) + + err = cancelFn(ctx, WithCancelMode(CancelAfterChatModel)) + assert.NoError(t, err) + + var events []*AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + assert.Nil(t, event.Err, "Should not have error event after cancel") + events = append(events, event) + } + + assert.True(t, len(events) > 0) + assert.True(t, atomic.LoadInt32(&st.callCount) >= 1, "Tool should have been called") + }) + + t.Run("CancelAfterToolCall_CompletesToolExecution", func(t *testing.T) { + toolStarted := make(chan struct{}, 1) + st := &slowToolWithSignal{ + name: "slow_tool", + delay: 500 * time.Millisecond, + result: "tool result", + startedChan: toolStarted, + } + + modelWithToolCall := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: modelWithToolCall, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + iter, cancelFn := agent.RunWithCancel(ctx, &AgentInput{ + Messages: []Message{schema.UserMessage("Use the tool")}, + }) + assert.NotNil(t, iter) + assert.NotNil(t, cancelFn) + + <-toolStarted + + time.Sleep(100 * time.Millisecond) + + err = cancelFn(ctx, WithCancelMode(CancelAfterToolCall)) + assert.NoError(t, err) + + var events []*AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + assert.Nil(t, event.Err, "Should not have error event after cancel") + events = append(events, event) + } + + assert.True(t, len(events) > 0) + assert.True(t, atomic.LoadInt32(&st.callCount) >= 1, "Tool should have been called") + }) +} + +type slowToolWithSignal struct { + name string + delay time.Duration + result string + callCount int32 + startedChan chan struct{} +} + +func (t *slowToolWithSignal) Info(ctx context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "A slow tool for testing", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Type: "string", Desc: "Input parameter"}, + }), + }, nil +} + +func (t *slowToolWithSignal) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { + atomic.AddInt32(&t.callCount, 1) + t.startedChan <- struct{}{} + time.Sleep(t.delay) + return t.result, nil +} + +type simpleChatModel struct { + response *schema.Message +} + +func (m *simpleChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + return m.response, nil +} + +func (m *simpleChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{m.response}), nil +} + +func (m *simpleChatModel) BindTools(tools []*schema.ToolInfo) error { + return nil +} + +func TestRunWithCancel_WithCheckpoint(t *testing.T) { + ctx := context.Background() + + t.Run("CancelWithCheckpoint", func(t *testing.T) { + modelStarted := make(chan struct{}, 1) + st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") + + slowModel := &cancelTestChatModel{ + delay: 500 * time.Millisecond, + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: slowModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: false, + CheckPointStore: store, + }) + + iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}, WithCheckPointID("cancel-1")) + + <-modelStarted + + err = cancelFn(ctx) + assert.NoError(t, err) + + var events []*AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + assert.Nil(t, event.Err, "Should not have error event after cancel") + events = append(events, event) + } + + assert.True(t, len(events) > 0) + }) +} + +func TestCancelFuncMultipleCalls(t *testing.T) { + ctx := context.Background() + + t.Run("SecondCancelReturnsErrAgentFinished", func(t *testing.T) { + modelStarted := make(chan struct{}, 1) + st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") + + slowModel := &cancelTestChatModel{ + delay: 1 * time.Second, + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: slowModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: false, + }) + + iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}) + + <-modelStarted + + cancelErr := cancelFn(ctx) + assert.NoError(t, cancelErr) + + cancelErr = cancelFn(ctx) + assert.ErrorIs(t, cancelErr, ErrAgentFinished) + + for { + _, ok := iter.Next() + if !ok { + break + } + } + }) +} + +func TestAgentNotCancellable(t *testing.T) { + ctx := context.Background() + + nonCancellableAgent := &nonCancellableTestAgent{ + name: "NonCancellable", + } + + runner := NewRunner(ctx, RunnerConfig{ + Agent: nonCancellableAgent, + EnableStreaming: false, + }) + + t.Run("RunWithCancelReturnsNilCancelFunc", func(t *testing.T) { + iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Hello")}) + assert.NotNil(t, iter) + assert.Nil(t, cancelFn) + + for { + _, ok := iter.Next() + if !ok { + break + } + } + }) +} + +type nonCancellableTestAgent struct { + name string +} + +func (a *nonCancellableTestAgent) Name(_ context.Context) string { + return a.name +} + +func (a *nonCancellableTestAgent) Description(_ context.Context) string { + return "A non-cancellable agent" +} + +func (a *nonCancellableTestAgent) Run(_ context.Context, input *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, gen := NewAsyncIteratorPair[*AgentEvent]() + gen.Send(&AgentEvent{ + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + Message: schema.AssistantMessage("Response", nil), + }, + }, + }) + gen.Close() + return iter +} + +func TestRunWithCancel_Streaming(t *testing.T) { + ctx := context.Background() + + t.Run("CancelImmediate_DuringModelStream", func(t *testing.T) { + modelStarted := make(chan struct{}, 1) + st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") + + slowModel := &cancelTestChatModel{ + delay: 2 * time.Second, + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: slowModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: true, + }) + + iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}) + assert.NotNil(t, iter) + assert.NotNil(t, cancelFn) + + eventsCh := make(chan []*AgentEvent, 1) + go func() { + var events []*AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + events = append(events, event) + } + eventsCh <- events + }() + + select { + case <-modelStarted: + case <-time.After(5 * time.Second): + t.Fatal("Model did not start within 5 seconds") + } + + time.Sleep(100 * time.Millisecond) + + cancelErr := cancelFn(ctx) + assert.NoError(t, cancelErr) + + start := time.Now() + events := <-eventsCh + elapsed := time.Since(start) + + assert.True(t, elapsed < 1*time.Second, "Should return quickly after cancel, elapsed: %v", elapsed) + assert.True(t, len(events) > 0) + + hasInterrupted := false + for _, e := range events { + assert.Nil(t, e.Err, "Should not have error event after cancel") + if e.Action != nil && e.Action.Interrupted != nil { + hasInterrupted = true + } + } + assert.True(t, hasInterrupted, "Should have interrupted event after cancel") + }) + + t.Run("CancelAfterToolCall_Streaming", func(t *testing.T) { + toolStarted := make(chan struct{}, 1) + st := &slowToolWithSignal{ + name: "slow_tool", + delay: 500 * time.Millisecond, + result: "tool result", + startedChan: toolStarted, + } + + modelWithToolCall := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: modelWithToolCall, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: true, + }) + + iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}) + assert.NotNil(t, iter) + assert.NotNil(t, cancelFn) + + <-toolStarted + + time.Sleep(100 * time.Millisecond) + + cancelErr := cancelFn(ctx, WithCancelMode(CancelAfterToolCall)) + assert.NoError(t, cancelErr) + + var events []*AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + assert.Nil(t, event.Err, "Should not have error event after cancel") + events = append(events, event) + } + + assert.True(t, len(events) > 0) + assert.True(t, atomic.LoadInt32(&st.callCount) >= 1, "Tool should have been called") + }) +} + +// TestResumeWithCancel tests the workflow of Cancel followed by Resume. +// +// IMPORTANT: When Cancel is triggered, the cancelableChatModel.Generate/Stream +// method returns immediately with an Interrupt error, but the inner model's +// Generate/Stream call continues running in a background goroutine until completion. +// This means the original model instance's fields (e.g., delay, response) may still +// be read by the background goroutine after Cancel returns. +// +// To avoid data races, we create new agent and runner instances for the Resume phase +// instead of reusing and modifying the original model instance. +func TestResumeWithCancel(t *testing.T) { + ctx := context.Background() + + t.Run("RunWithCancel_ThenResumeWithCancel", func(t *testing.T) { + modelStarted := make(chan struct{}, 1) + modelCallCount := int32(0) + st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") + + slowModel := &cancelTestChatModel{ + delay: 500 * time.Millisecond, + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: slowModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + checkpointID := "resume-cancel-test-1" + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: false, + CheckPointStore: store, + }) + + iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}, WithCheckPointID(checkpointID)) + + <-modelStarted + atomic.AddInt32(&modelCallCount, 1) + + cancelErr := cancelFn(ctx) + assert.NoError(t, cancelErr) + + var events []*AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + assert.Nil(t, event.Err, "Should not have error event after cancel") + events = append(events, event) + } + assert.True(t, len(events) > 0) + + hasInterrupted := false + for _, e := range events { + if e.Action != nil && e.Action.Interrupted != nil { + hasInterrupted = true + break + } + } + assert.True(t, hasInterrupted, "First run should have interrupted event") + + newModelStarted := make(chan struct{}, 1) + slowModel2 := &cancelTestChatModel{ + delay: 100 * time.Millisecond, + response: &schema.Message{ + Role: schema.Assistant, + Content: "Final response after resume", + }, + startedChan: newModelStarted, + doneChan: make(chan struct{}, 1), + } + + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: slowModel2, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: agent2, + EnableStreaming: false, + CheckPointStore: store, + }) + + resumeIter, resumeCancelFn, err := runner2.ResumeWithCancel(ctx, checkpointID) + assert.NoError(t, err) + assert.NotNil(t, resumeIter) + assert.NotNil(t, resumeCancelFn) + + var resumeEvents []*AgentEvent + for { + event, ok := resumeIter.Next() + if !ok { + break + } + assert.Nil(t, event.Err, "Should not have error event during resume") + resumeEvents = append(resumeEvents, event) + } + + assert.True(t, len(resumeEvents) > 0, "Resume should produce events") + }) + + t.Run("ResumeWithCancel_ThenCancel", func(t *testing.T) { + firstModelStarted := make(chan struct{}, 1) + resumeModelStarted := make(chan struct{}, 1) + modelCallCount := int32(0) + st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") + + slowModel := &cancelTestChatModel{ + delay: 500 * time.Millisecond, + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + startedChan: firstModelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: slowModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + checkpointID := "resume-then-cancel-test-1" + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: false, + CheckPointStore: store, + }) + + iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}, WithCheckPointID(checkpointID)) + + <-firstModelStarted + atomic.AddInt32(&modelCallCount, 1) + + cancelErr := cancelFn(ctx) + assert.NoError(t, cancelErr) + + for { + _, ok := iter.Next() + if !ok { + break + } + } + + slowModel2 := &cancelTestChatModel{ + delay: 2 * time.Second, + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + startedChan: resumeModelStarted, + doneChan: make(chan struct{}, 1), + } + + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: slowModel2, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: agent2, + EnableStreaming: false, + CheckPointStore: store, + }) + + resumeIter, resumeCancelFn, err := runner2.ResumeWithCancel(ctx, checkpointID) + assert.NoError(t, err) + + resumeEventsCh := make(chan []*AgentEvent, 1) + go func() { + var events []*AgentEvent + for { + event, ok := resumeIter.Next() + if !ok { + break + } + events = append(events, event) + } + resumeEventsCh <- events + }() + + <-resumeModelStarted + atomic.AddInt32(&modelCallCount, 1) + + time.Sleep(100 * time.Millisecond) + + err = resumeCancelFn(ctx) + assert.NoError(t, err) + + start := time.Now() + resumeEvents := <-resumeEventsCh + elapsed := time.Since(start) + + assert.True(t, elapsed < 1*time.Second, "Resume should return quickly after cancel, elapsed: %v", elapsed) + assert.True(t, len(resumeEvents) > 0) + + hasInterrupted := false + for _, e := range resumeEvents { + assert.Nil(t, e.Err, "Should not have error event after resume cancel") + if e.Action != nil && e.Action.Interrupted != nil { + hasInterrupted = true + } + } + assert.True(t, hasInterrupted, "Resume should have interrupted event after cancel") + }) +} diff --git a/adk/cancel_wrapper.go b/adk/cancel_wrapper.go new file mode 100644 index 000000000..9c00df77f --- /dev/null +++ b/adk/cancel_wrapper.go @@ -0,0 +1,280 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "io" + "runtime/debug" + "time" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/internal/safe" + "github.com/cloudwego/eino/schema" +) + +type cancelWaitResult[T any] struct { + result T + err error + cancelled bool +} + +func waitWithCancel[T any](cs *cancelSig, resultCh <-chan cancelWaitResult[T]) cancelWaitResult[T] { + var timeCh <-chan time.Time + select { + case <-cs.done: + cfg := cs.config.Load().(*cancelConfig) + if cfg.Mode == CancelImmediate { + if cfg.Timeout == nil { + return cancelWaitResult[T]{cancelled: true} + } + timeCh = time.After(*cfg.Timeout) + } + case res := <-resultCh: + return res + } + select { + case <-timeCh: + return cancelWaitResult[T]{cancelled: true} + case res := <-resultCh: + return res + } +} + +type cancelableChatModel struct { + inner model.BaseChatModel + cs *cancelSig +} + +func wrapModelForCancelable(m model.BaseChatModel, cs *cancelSig) *cancelableChatModel { + return &cancelableChatModel{inner: m, cs: cs} +} + +func (c *cancelableChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + if cfg := checkCancelSig(c.cs); cfg != nil && cfg.Mode == CancelImmediate { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + + resultCh := make(chan cancelWaitResult[*schema.Message], 1) + go func() { + defer func() { + if panicErr := recover(); panicErr != nil { + resultCh <- cancelWaitResult[*schema.Message]{err: safe.NewPanicErr(panicErr, debug.Stack())} + } + }() + res, err := c.inner.Generate(ctx, input, opts...) + resultCh <- cancelWaitResult[*schema.Message]{result: res, err: err} + }() + + res := waitWithCancel(c.cs, resultCh) + if res.cancelled { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + return res.result, res.err +} + +func (c *cancelableChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + if cfg := checkCancelSig(c.cs); cfg != nil && cfg.Mode == CancelImmediate { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + + resultCh := make(chan cancelWaitResult[*schema.StreamReader[*schema.Message]], 1) + go func() { + defer func() { + if panicErr := recover(); panicErr != nil { + resultCh <- cancelWaitResult[*schema.StreamReader[*schema.Message]]{err: safe.NewPanicErr(panicErr, debug.Stack())} + } + }() + + stream, err := c.inner.Stream(ctx, input, opts...) + if err != nil { + resultCh <- cancelWaitResult[*schema.StreamReader[*schema.Message]]{err: err} + return + } + copies := stream.Copy(2) + _ = consumeStreamForError(copies[0]) + resultCh <- cancelWaitResult[*schema.StreamReader[*schema.Message]]{result: copies[1]} + }() + + res := waitWithCancel(c.cs, resultCh) + if res.cancelled { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + return res.result, res.err +} + +func cancelableToolInvokable(cs *cancelSig, endpoint compose.InvokableToolEndpoint) compose.InvokableToolEndpoint { + return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { + if cfg := checkCancelSig(cs); cfg != nil && cfg.Mode == CancelImmediate { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + + resultCh := make(chan cancelWaitResult[*compose.ToolOutput], 1) + go func() { + defer func() { + if panicErr := recover(); panicErr != nil { + resultCh <- cancelWaitResult[*compose.ToolOutput]{err: safe.NewPanicErr(panicErr, debug.Stack())} + } + }() + output, err := endpoint(ctx, input) + resultCh <- cancelWaitResult[*compose.ToolOutput]{result: output, err: err} + }() + + res := waitWithCancel(cs, resultCh) + if res.cancelled { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + return res.result, res.err + } +} + +func cancelableToolStreamable(cs *cancelSig, endpoint compose.StreamableToolEndpoint) compose.StreamableToolEndpoint { + return func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { + if cfg := checkCancelSig(cs); cfg != nil && cfg.Mode == CancelImmediate { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + + resultCh := make(chan cancelWaitResult[*schema.StreamReader[string]], 1) + go func() { + defer func() { + if panicErr := recover(); panicErr != nil { + resultCh <- cancelWaitResult[*schema.StreamReader[string]]{err: safe.NewPanicErr(panicErr, debug.Stack())} + } + }() + output, err := endpoint(ctx, input) + if err != nil { + resultCh <- cancelWaitResult[*schema.StreamReader[string]]{err: err} + return + } + copies := output.Result.Copy(2) + _ = consumeStreamForErrorString(copies[0]) + resultCh <- cancelWaitResult[*schema.StreamReader[string]]{result: copies[1]} + }() + + res := waitWithCancel(cs, resultCh) + if res.cancelled { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + if res.err != nil { + return nil, res.err + } + return &compose.StreamToolOutput{Result: res.result}, nil + } +} + +func cancelableToolEnhancedInvokable(cs *cancelSig, endpoint compose.EnhancedInvokableToolEndpoint) compose.EnhancedInvokableToolEndpoint { + return func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedInvokableToolOutput, error) { + if cfg := checkCancelSig(cs); cfg != nil && cfg.Mode == CancelImmediate { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + + resultCh := make(chan cancelWaitResult[*compose.EnhancedInvokableToolOutput], 1) + go func() { + defer func() { + if panicErr := recover(); panicErr != nil { + resultCh <- cancelWaitResult[*compose.EnhancedInvokableToolOutput]{err: safe.NewPanicErr(panicErr, debug.Stack())} + } + }() + output, err := endpoint(ctx, input) + resultCh <- cancelWaitResult[*compose.EnhancedInvokableToolOutput]{result: output, err: err} + }() + + res := waitWithCancel(cs, resultCh) + if res.cancelled { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + return res.result, res.err + } +} + +func cancelableToolEnhancedStreamable(cs *cancelSig, endpoint compose.EnhancedStreamableToolEndpoint) compose.EnhancedStreamableToolEndpoint { + return func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { + if cfg := checkCancelSig(cs); cfg != nil && cfg.Mode == CancelImmediate { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + + resultCh := make(chan cancelWaitResult[*schema.StreamReader[*schema.ToolResult]], 1) + go func() { + defer func() { + if panicErr := recover(); panicErr != nil { + resultCh <- cancelWaitResult[*schema.StreamReader[*schema.ToolResult]]{err: safe.NewPanicErr(panicErr, debug.Stack())} + } + }() + output, err := endpoint(ctx, input) + if err != nil { + resultCh <- cancelWaitResult[*schema.StreamReader[*schema.ToolResult]]{err: err} + return + } + copies := output.Result.Copy(2) + _ = consumeStreamForErrorToolResult(copies[0]) + resultCh <- cancelWaitResult[*schema.StreamReader[*schema.ToolResult]]{result: copies[1]} + }() + + res := waitWithCancel(cs, resultCh) + if res.cancelled { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + if res.err != nil { + return nil, res.err + } + return &compose.EnhancedStreamableToolOutput{Result: res.result}, nil + } +} + +func cancelableTool(cs *cancelSig) compose.ToolMiddleware { + return compose.ToolMiddleware{ + Invokable: func(endpoint compose.InvokableToolEndpoint) compose.InvokableToolEndpoint { + return cancelableToolInvokable(cs, endpoint) + }, + Streamable: func(endpoint compose.StreamableToolEndpoint) compose.StreamableToolEndpoint { + return cancelableToolStreamable(cs, endpoint) + }, + EnhancedInvokable: func(endpoint compose.EnhancedInvokableToolEndpoint) compose.EnhancedInvokableToolEndpoint { + return cancelableToolEnhancedInvokable(cs, endpoint) + }, + EnhancedStreamable: func(endpoint compose.EnhancedStreamableToolEndpoint) compose.EnhancedStreamableToolEndpoint { + return cancelableToolEnhancedStreamable(cs, endpoint) + }, + } +} + +func consumeStreamForErrorString(stream *schema.StreamReader[string]) error { + defer stream.Close() + for { + _, err := stream.Recv() + if err == io.EOF { + return nil + } + if err != nil { + return err + } + } +} + +func consumeStreamForErrorToolResult(stream *schema.StreamReader[*schema.ToolResult]) error { + defer stream.Close() + for { + _, err := stream.Recv() + if err == io.EOF { + return nil + } + if err != nil { + return err + } + } +} diff --git a/adk/chatmodel.go b/adk/chatmodel.go index 56993a7b2..effb270bd 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -38,6 +38,10 @@ import ( "github.com/cloudwego/eino/schema" ) +var _ ResumableAgent = &ChatModelAgent{} +var _ CancellableAgent = &ChatModelAgent{} +var _ CancellableResumableAgent = &ChatModelAgent{} + type chatModelAgentExecCtx struct { runtimeReturnDirectly map[string]bool generator *AsyncGenerator[*AgentEvent] @@ -341,7 +345,8 @@ type ChatModelAgent struct { exeCtx *execContext } -type runFunc func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, instruction string, returnDirectly map[string]bool, opts ...compose.Option) +type runFunc func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], + store *bridgeStore, instruction string, returnDirectly map[string]bool, cs *cancelSig, opts ...compose.Option) // NewChatModelAgent constructs a chat model-backed agent with the provided config. func NewChatModelAgent(ctx context.Context, config *ChatModelAgentConfig) (*ChatModelAgent, error) { @@ -570,7 +575,7 @@ func setOutputToSession(ctx context.Context, msg Message, msgStream MessageStrea } func errFunc(err error) runFunc { - return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, _ string, _ map[string]bool, _ ...compose.Option) { + return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, _ string, _ map[string]bool, _ *cancelSig, _ ...compose.Option) { generator.Send(&AgentEvent{Err: err}) } } @@ -692,19 +697,23 @@ func (a *ChatModelAgent) prepareExecContext(ctx context.Context) (*execContext, } func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { - wrappedModel := buildModelWrappers(a.model, &modelWrapperConfig{ - handlers: a.handlers, - middlewares: a.middlewares, - retryConfig: a.modelRetryConfig, - }) - type noToolsInput struct { input *AgentInput instruction string } return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], - store *bridgeStore, instruction string, _ map[string]bool, opts ...compose.Option) { + store *bridgeStore, instruction string, _ map[string]bool, cs *cancelSig, opts ...compose.Option) { + + wrappedModel := buildModelWrappers(a.model, &modelWrapperConfig{ + handlers: a.handlers, + middlewares: a.middlewares, + retryConfig: a.modelRetryConfig, + }) + + if cs != nil { + wrappedModel = wrapModelForCancelable(wrappedModel, cs) + } chain := compose.NewChain[noToolsInput, Message]( compose.WithGenLocalState(func(ctx context.Context) (state *State) { @@ -750,13 +759,37 @@ func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { } else if msgStream != nil { msgStream.Close() } - } else { + return + } + + info, ok := compose.ExtractInterruptInfo(err) + if !ok { generator.Send(&AgentEvent{Err: err}) + return + } + + data, existed, sErr := store.Get(ctx, bridgeCheckpointID) + if sErr != nil { + generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("failed to get interrupt info: %w", sErr)}) + return + } + if !existed { + generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("interrupt occurred but checkpoint data is missing")}) + return + } + + is := FromInterruptContexts(info.InterruptContexts) + event := CompositeInterrupt(ctx, info, data, is) + event.Action.Interrupted.Data = &ChatModelAgentInterruptInfo{ + Info: info, + Data: data, } + event.AgentName = a.name + generator.Send(event) } } -func (a *ChatModelAgent) buildReactRunFunc(ctx context.Context, bc *execContext) (runFunc, error) { +func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) (runFunc, error) { conf := &reactConfig{ model: a.model, toolsConfig: &bc.toolsNodeConf, @@ -777,8 +810,8 @@ func (a *ChatModelAgent) buildReactRunFunc(ctx context.Context, bc *execContext) } return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, - instruction string, returnDirectly map[string]bool, opts ...compose.Option) { - g, err := newReact(ctx, conf) + instruction string, returnDirectly map[string]bool, cs *cancelSig, opts ...compose.Option) { + g, err := newReact(ctx, conf, cs) if err != nil { generator.Send(&AgentEvent{Err: err}) return @@ -894,7 +927,7 @@ func (a *ChatModelAgent) buildRunFunc(ctx context.Context) runFunc { return } - run, err := a.buildReactRunFunc(ctx, ec) + run, err := a.buildReActRunFunc(ctx, ec) if err != nil { a.run = errFunc(err) return @@ -938,7 +971,7 @@ func (a *ChatModelAgent) getRunFunc(ctx context.Context) (context.Context, runFu if len(runtimeBC.toolsNodeConf.Tools) == 0 { tempRun = a.buildNoToolsRunFunc(ctx) } else { - tempRun, err = a.buildReactRunFunc(ctx, runtimeBC) + tempRun, err = a.buildReActRunFunc(ctx, runtimeBC) if err != nil { return ctx, nil, nil, err } @@ -948,6 +981,15 @@ func (a *ChatModelAgent) getRunFunc(ctx context.Context) (context.Context, runFu } func (a *ChatModelAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, _ := a.runInternal(ctx, input, false, opts...) + return iter +} + +func (a *ChatModelAgent) RunWithCancel(ctx context.Context, input *AgentInput, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { + return a.runInternal(ctx, input, true, opts...) +} + +func (a *ChatModelAgent) runInternal(ctx context.Context, input *AgentInput, withCancel bool, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { iterator, generator := NewAsyncIteratorPair[*AgentEvent]() ctx, run, bc, err := a.getRunFunc(ctx) @@ -956,7 +998,7 @@ func (a *ChatModelAgent) Run(ctx context.Context, input *AgentInput, opts ...Age generator.Send(&AgentEvent{Err: err}) generator.Close() }() - return iterator + return iterator, notCancellableFuncInternal } co := getComposeOptions(opts) @@ -969,6 +1011,13 @@ func (a *ChatModelAgent) Run(ctx context.Context, input *AgentInput, opts ...Age } } + var cs *cancelSig + var cancelFn CancelFunc = notCancellableFuncInternal + if withCancel { + cs = newCancelSig() + cancelFn = buildCancelFunc(cs) + } + go func() { defer func() { panicErr := recover() @@ -990,13 +1039,47 @@ func (a *ChatModelAgent) Run(ctx context.Context, input *AgentInput, opts ...Age returnDirectly = bc.returnDirectly } - run(ctx, input, generator, newBridgeStore(), instruction, returnDirectly, co...) + run(ctx, input, generator, newBridgeStore(), instruction, returnDirectly, cs, co...) }() - return iterator + return iterator, cancelFn +} + +func buildCancelFunc(cs *cancelSig) CancelFunc { + var once sync.Once + + return func(_ context.Context, opts ...CancelOption) error { + cfg := &cancelConfig{ + Mode: CancelImmediate, + } + for _, opt := range opts { + opt(cfg) + } + + cancelled := false + once.Do(func() { + cs.cancel(cfg) + cancelled = true + }) + + if !cancelled { + return ErrAgentFinished + } + + return nil + } } func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, _ := a.resumeInternal(ctx, info, false, opts...) + return iter +} + +func (a *ChatModelAgent) ResumeWithCancel(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { + return a.resumeInternal(ctx, info, true, opts...) +} + +func (a *ChatModelAgent) resumeInternal(ctx context.Context, info *ResumeInfo, withCancel bool, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { iterator, generator := NewAsyncIteratorPair[*AgentEvent]() ctx, run, bc, err := a.getRunFunc(ctx) @@ -1005,7 +1088,7 @@ func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...A generator.Send(&AgentEvent{Err: err}) generator.Close() }() - return iterator + return iterator, notCancellableFuncInternal } co := getComposeOptions(opts) @@ -1018,14 +1101,19 @@ func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...A } } + methodName := "Resume" + if withCancel { + methodName = "ResumeWithCancel" + } + if info.InterruptState == nil { - panic(fmt.Sprintf("ChatModelAgent.Resume: agent '%s' was asked to resume but has no state", a.Name(ctx))) + panic(fmt.Sprintf("ChatModelAgent.%s: agent '%s' was asked to resume but has no state", methodName, a.Name(ctx))) } stateByte, ok := info.InterruptState.([]byte) if !ok { - panic(fmt.Sprintf("ChatModelAgent.Resume: agent '%s' was asked to resume but has invalid interrupt state type: %T", - a.Name(ctx), info.InterruptState)) + panic(fmt.Sprintf("ChatModelAgent.%s: agent '%s' was asked to resume but has invalid interrupt state type: %T", + methodName, a.Name(ctx), info.InterruptState)) } // Migrate legacy checkpoints before resume. @@ -1040,15 +1128,15 @@ func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...A generator.Send(&AgentEvent{Err: err}) generator.Close() }() - return iterator + return iterator, nil } var historyModifier func(ctx context.Context, history []Message) []Message if info.ResumeData != nil { resumeData, ok := info.ResumeData.(*ChatModelAgentResumeData) if !ok { - panic(fmt.Sprintf("ChatModelAgent.Resume: agent '%s' was asked to resume but has invalid resume data type: %T", - a.Name(ctx), info.ResumeData)) + panic(fmt.Sprintf("ChatModelAgent.%s: agent '%s' was asked to resume but has invalid resume data type: %T", + methodName, a.Name(ctx), info.ResumeData)) } historyModifier = resumeData.HistoryModifier } @@ -1064,6 +1152,13 @@ func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...A })) } + var cs *cancelSig + var cancelFn CancelFunc = notCancellableFuncInternal + if withCancel { + cs = newCancelSig() + cancelFn = buildCancelFunc(cs) + } + go func() { defer func() { panicErr := recover() @@ -1086,10 +1181,10 @@ func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...A } run(ctx, &AgentInput{EnableStreaming: info.EnableStreaming}, generator, - newResumeBridgeStore(stateByte), instruction, returnDirectly, co...) + newResumeBridgeStore(stateByte), instruction, returnDirectly, cs, co...) }() - return iterator + return iterator, cancelFn } func getComposeOptions(opts []AgentRunOption) []compose.Option { diff --git a/adk/flow.go b/adk/flow.go index ee4dec96c..93143c641 100644 --- a/adk/flow.go +++ b/adk/flow.go @@ -333,6 +333,15 @@ func buildDefaultHistoryRewriter(agentName string) HistoryRewriter { } func (a *flowAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, _ := a.runInternal(ctx, input, false, opts...) + return iter +} + +func (a *flowAgent) RunWithCancel(ctx context.Context, input *AgentInput, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { + return a.runInternal(ctx, input, true, opts...) +} + +func (a *flowAgent) runInternal(ctx context.Context, input *AgentInput, withCancel bool, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { agentName := a.Name(ctx) var runCtx *runContext @@ -345,7 +354,7 @@ func (a *flowAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRun if err != nil { cbInput := &AgentCallbackInput{Input: input} ctx = callbacks.OnStart(ctx, cbInput) - return wrapIterWithOnEnd(ctx, genErrorIter(err)) + return wrapIterWithOnEnd(ctx, genErrorIter(err)), notCancellableFuncInternal } ctxForSubAgents := ctx @@ -358,19 +367,40 @@ func (a *flowAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRun input = processedInput if wf, ok := a.Agent.(*workflowAgent); ok { - return wrapIterWithOnEnd(ctx, wf.Run(ctx, input, filterCallbackHandlersForNestedAgents(agentName, opts)...)) + return wrapIterWithOnEnd(ctx, wf.Run(ctx, input, filterCallbackHandlersForNestedAgents(agentName, opts)...)), notCancellableFuncInternal } - aIter := a.Agent.Run(ctx, input, filterOptions(agentName, opts)...) + var aIter *AsyncIterator[*AgentEvent] + var cancelFn CancelFunc = notCancellableFuncInternal + + ca, supportCancel := a.Agent.(CancellableAgent) + if withCancel && supportCancel { + aIter, cancelFn = ca.RunWithCancel(ctx, input, filterOptions(agentName, opts)...) + } else { + aIter = a.Agent.Run(ctx, input, filterOptions(agentName, opts)...) + } iterator, generator := NewAsyncIteratorPair[*AgentEvent]() go a.run(ctx, ctxForSubAgents, runCtx, aIter, generator, opts...) - return iterator + return iterator, cancelFn +} + +func notCancellableFuncInternal(_ context.Context, _ ...CancelOption) error { + return ErrAgentNotCancellable } func (a *flowAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, _ := a.resumeInternal(ctx, info, false, opts...) + return iter +} + +func (a *flowAgent) ResumeWithCancel(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { + return a.resumeInternal(ctx, info, true, opts...) +} + +func (a *flowAgent) resumeInternal(ctx context.Context, info *ResumeInfo, withCancel bool, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { agentName := a.Name(ctx) ctx, info = buildResumeInfo(ctx, agentName, info) @@ -383,51 +413,59 @@ func (a *flowAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentR ctx = callbacks.OnStart(ctx, cbInput) if info.WasInterrupted { - ra, ok := a.Agent.(ResumableAgent) - if !ok { + var aIter *AsyncIterator[*AgentEvent] + var cancelFn CancelFunc = notCancellableFuncInternal + + ca, supportCancel := a.Agent.(CancellableResumableAgent) + if withCancel && supportCancel { + aIter, cancelFn = ca.ResumeWithCancel(ctx, info, opts...) + } else if ra, ok := a.Agent.(ResumableAgent); ok { + if _, ok := ra.(*workflowAgent); ok { + filteredOpts := filterCallbackHandlersForNestedAgents(agentName, opts) + aIter := ra.Resume(ctx, info, filteredOpts...) + return wrapIterWithOnEnd(ctx, aIter), cancelFn + } + aIter = ra.Resume(ctx, info, opts...) + } else { return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf("failed to resume agent: agent '%s' is an interrupt point "+ - "but is not a ResumableAgent", agentName))) + "but is not a ResumableAgent", agentName))), notCancellableFuncInternal } - iterator, generator := NewAsyncIteratorPair[*AgentEvent]() - if _, ok := ra.(*workflowAgent); ok { - filteredOpts := filterCallbackHandlersForNestedAgents(agentName, opts) - aIter := ra.Resume(ctx, info, filteredOpts...) - return wrapIterWithOnEnd(ctx, aIter) - } - aIter := ra.Resume(ctx, info, opts...) + iterator, generator := NewAsyncIteratorPair[*AgentEvent]() go a.run(ctx, ctxForSubAgents, getRunCtx(ctxForSubAgents), aIter, generator, opts...) - return iterator + return iterator, cancelFn } nextAgentName, err := getNextResumeAgent(ctx, info) if err != nil { - return wrapIterWithOnEnd(ctx, genErrorIter(err)) + return wrapIterWithOnEnd(ctx, genErrorIter(err)), notCancellableFuncInternal } subAgent := a.getAgent(ctxForSubAgents, nextAgentName) if subAgent == nil { - // the inner agent wrapped by flowAgent may be ANY agent, including flowAgent, - // AgentWithDeterministicTransferTo, or any other custom agent user defined, - // or any combinations of the above in any order, - // that ultimately wraps the flowAgent with sub-agents - // We need to go through these wrappers to reach the flowAgent with sub-agents. if len(a.subAgents) == 0 { + ca, supportCancel := a.Agent.(CancellableResumableAgent) + if withCancel && supportCancel { + iter, cancelFn := ca.ResumeWithCancel(ctx, info, opts...) + return wrapIterWithOnEnd(ctx, iter), cancelFn + } if ra, ok := a.Agent.(ResumableAgent); ok { - // Use ctx (callback-enriched) instead of ctxForSubAgents here. - // This is the inner agent that flowAgent wraps (e.g., supervisorContainer), - // not a sub-agent. The callback context from OnStart should be propagated - // to ensure unified tracing for container patterns. - return wrapIterWithOnEnd(ctx, ra.Resume(ctx, info, opts...)) + return wrapIterWithOnEnd(ctx, ra.Resume(ctx, info, opts...)), notCancellableFuncInternal } return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf( "failed to resume agent: agent '%s' (type %T) has no sub-agents and does not implement ResumableAgent interface. "+ - "To support resume, your custom agent wrapper must implement the ResumableAgent interface", agentName, a.Agent))) + "To support resume, your custom agent wrapper must implement the ResumableAgent interface", agentName, a.Agent))), nil } - return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf("failed to resume agent: sub-agent '%s' not found in agent '%s'", nextAgentName, agentName))) + return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf("failed to resume agent: sub-agent '%s' not found in agent '%s'", nextAgentName, agentName))), notCancellableFuncInternal + } + + ca, supportCancel := ResumableAgent(subAgent).(CancellableResumableAgent) + if withCancel && supportCancel { + iter, cancelFn := ca.ResumeWithCancel(ctxForSubAgents, info, opts...) + return wrapIterWithOnEnd(ctx, iter), cancelFn } - return wrapIterWithOnEnd(ctx, subAgent.Resume(ctxForSubAgents, info, opts...)) + return wrapIterWithOnEnd(ctx, subAgent.Resume(ctxForSubAgents, info, opts...)), notCancellableFuncInternal } type DeterministicTransferConfig struct { diff --git a/adk/interface.go b/adk/interface.go index 2e7e3d09c..1d2f5f070 100644 --- a/adk/interface.go +++ b/adk/interface.go @@ -291,6 +291,9 @@ const ( // ErrAgentFinished is returned by Cancel when the agent has already finished execution. var ErrAgentFinished = errors.New("agent has already finished execution") +// ErrAgentNotCancellable is returned by Cancel when the agent does not support cancellation. +var ErrAgentNotCancellable = errors.New("agent does not implement CancellableAgent interface") + type cancelConfig struct { Mode CancelMode Timeout *time.Duration @@ -298,12 +301,14 @@ type cancelConfig struct { type CancelOption func(*cancelConfig) +// WithCancelMode sets the cancel mode for the cancel operation. func WithCancelMode(mode CancelMode) CancelOption { return func(config *cancelConfig) { config.Mode = mode } } +// WithCancelTimeout sets a timeout duration for CancelImmediate mode. func WithCancelTimeout(timeout time.Duration) CancelOption { return func(config *cancelConfig) { config.Timeout = &timeout @@ -312,10 +317,12 @@ func WithCancelTimeout(timeout time.Duration) CancelOption { type CancelFunc func(context.Context, ...CancelOption) error -type CancellableRun interface { +type CancellableAgent interface { + Agent RunWithCancel(ctx context.Context, input *AgentInput, options ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) } -type CancellableResume interface { +type CancellableResumableAgent interface { + ResumableAgent ResumeWithCancel(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) } diff --git a/adk/react.go b/adk/react.go index 2bf6dd462..0aec5d94c 100644 --- a/adk/react.go +++ b/adk/react.go @@ -22,6 +22,7 @@ import ( "encoding/gob" "errors" "io" + "sync/atomic" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/compose" @@ -299,13 +300,31 @@ func genReactState(config *reactConfig) func(ctx context.Context) *State { } } -func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { +func newReact(ctx context.Context, config *reactConfig, cs *cancelSig) (reactGraph, error) { const ( - initNode_ = "Init" - chatModel_ = "ChatModel" - toolNode_ = "ToolNode" + initNode_ = "Init" + chatModel_ = "ChatModel" + beforeToolNode_ = "BeforeToolNode" + toolNode_ = "ToolNode" + afterToolNode_ = "AfterToolNode" ) + checkCancel := cs != nil + + nodeNameAfterModel := func() string { + if checkCancel { + return beforeToolNode_ + } + return toolNode_ + } + + nodeNameAfterTool := func() string { + if checkCancel { + return afterToolNode_ + } + return chatModel_ + } + g := compose.NewGraph[*reactInput, Message](compose.WithGenLocalState(genReactState(config))) initLambda := func(ctx context.Context, input *reactInput) ([]Message, error) { @@ -318,7 +337,18 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { wrappedModel = buildModelWrappers(config.model, config.modelWrapperConf) } - toolsNode, err := compose.NewToolNode(ctx, config.toolsConfig) + toolsConfig := config.toolsConfig + if checkCancel { + wrappedModel = wrapModelForCancelable(wrappedModel, cs) + tcMWs := make([]compose.ToolMiddleware, 0, len(toolsConfig.ToolCallMiddlewares)+1) + tcMWs = append(tcMWs, cancelableTool(cs)) + tcMWs = append(tcMWs, toolsConfig.ToolCallMiddlewares...) + toolsConfigCopy := *toolsConfig + toolsConfigCopy.ToolCallMiddlewares = tcMWs + toolsConfig = &toolsConfigCopy + } + + toolsNode, err := compose.NewToolNode(ctx, toolsConfig) if err != nil { return nil, err } @@ -369,6 +399,28 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { _ = g.AddEdge(compose.START, initNode_) _ = g.AddEdge(initNode_, chatModel_) + if checkCancel { + beforeToolNode := func(ctx context.Context, input Message) (output Message, err error) { + if sig := checkCancelSig(cs); sig != nil && sig.Mode != CancelAfterToolCall { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + + return input, nil + } + _ = g.AddLambdaNode(beforeToolNode_, compose.InvokableLambda(beforeToolNode), compose.WithNodeName(beforeToolNode_)) + g.AddEdge(beforeToolNode_, toolNode_) + + afterToolNode := func(ctx context.Context, input []Message) (output []Message, err error) { + if sig := checkCancelSig(cs); sig != nil && sig.Mode != CancelAfterChatModel { + return nil, compose.Interrupt(ctx, "cancelled externally") + } + + return input, nil + } + _ = g.AddLambdaNode(afterToolNode_, compose.InvokableLambda(afterToolNode), compose.WithNodeName(afterToolNode_)) + g.AddEdge(afterToolNode_, chatModel_) + } + toolCallCheck := func(ctx context.Context, sMsg MessageStream) (string, error) { defer sMsg.Close() for { @@ -382,11 +434,11 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { } if len(chunk.ToolCalls) > 0 { - return toolNode_, nil + return nodeNameAfterModel(), nil } } } - branch := compose.NewStreamGraphBranch(toolCallCheck, map[string]bool{compose.END: true, toolNode_: true}) + branch := compose.NewStreamGraphBranch(toolCallCheck, map[string]bool{compose.END: true, nodeNameAfterModel(): true}) _ = g.AddBranch(chatModel_, branch) if len(config.toolsReturnDirectly) > 0 { @@ -423,15 +475,43 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { return toolNodeToEndConverter, nil } - return chatModel_, nil + return nodeNameAfterTool(), nil } branch = compose.NewStreamGraphBranch(checkReturnDirect, - map[string]bool{toolNodeToEndConverter: true, chatModel_: true}) + map[string]bool{toolNodeToEndConverter: true, nodeNameAfterTool(): true}) _ = g.AddBranch(toolNode_, branch) } else { - _ = g.AddEdge(toolNode_, chatModel_) + _ = g.AddEdge(toolNode_, nodeNameAfterTool()) } return g, nil } + +type cancelSig struct { + done chan struct{} + config atomic.Value +} + +func newCancelSig() *cancelSig { + return &cancelSig{ + done: make(chan struct{}), + } +} + +func (cs *cancelSig) cancel(cfg *cancelConfig) { + cs.config.Store(cfg) + close(cs.done) +} + +func checkCancelSig(cs *cancelSig) *cancelConfig { + if cs == nil { + return nil + } + select { + case <-cs.done: + return cs.config.Load().(*cancelConfig) + default: + return nil + } +} diff --git a/adk/react_test.go b/adk/react_test.go index 5364f0912..969e73e35 100644 --- a/adk/react_test.go +++ b/adk/react_test.go @@ -144,7 +144,7 @@ func TestReact(t *testing.T) { toolsReturnDirectly: map[string]bool{}, } - graph, err := newReact(ctx, config) + graph, err := newReact(ctx, config, nil) assert.NoError(t, err) assert.NotNil(t, graph) @@ -211,7 +211,7 @@ func TestReact(t *testing.T) { toolsReturnDirectly: map[string]bool{info.Name: true}, } - graph, err := newReact(ctx, config) + graph, err := newReact(ctx, config, nil) assert.NoError(t, err) assert.NotNil(t, graph) @@ -303,7 +303,7 @@ func TestReact(t *testing.T) { toolsReturnDirectly: map[string]bool{}, } - graph, err := newReact(ctx, config) + graph, err := newReact(ctx, config, nil) assert.NoError(t, err) assert.NotNil(t, graph) @@ -413,7 +413,7 @@ func TestReact(t *testing.T) { toolsReturnDirectly: map[string]bool{streamInfo.Name: true}, } - graph, err := newReact(ctx, config) + graph, err := newReact(ctx, config, nil) assert.NoError(t, err) assert.NotNil(t, graph) @@ -502,7 +502,7 @@ func TestReact(t *testing.T) { maxIterations: 6, } - graph, err := newReact(ctx, config) + graph, err := newReact(ctx, config, nil) assert.NoError(t, err) assert.NotNil(t, graph) @@ -532,7 +532,7 @@ func TestReact(t *testing.T) { maxIterations: 5, } - graph, err = newReact(ctx, config) + graph, err = newReact(ctx, config, nil) assert.NoError(t, err) assert.NotNil(t, graph) diff --git a/adk/runner.go b/adk/runner.go index 07a931ac2..a9fd9b94d 100644 --- a/adk/runner.go +++ b/adk/runner.go @@ -74,6 +74,31 @@ func NewRunner(_ context.Context, conf RunnerConfig) *Runner { // upon interruption. func (r *Runner) Run(ctx context.Context, messages []Message, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, _, _ := r.runWithCancel(ctx, messages, false, opts...) + return iter +} + +// Query is a convenience method that starts a new execution with a single user query string. +func (r *Runner) Query(ctx context.Context, + query string, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + + return r.Run(ctx, []Message{schema.UserMessage(query)}, opts...) +} + +// RunWithCancel starts a new execution of the agent and returns both an iterator and a cancel function. +// The cancel function can be used to interrupt the running agent at specific points based on the CancelMode. +// If the Runner was configured with a CheckPointStore and WithCheckPointID option, it will automatically +// save the agent's state upon cancellation for later resumption. +// +// If the agent does not implement CancellableAgent, the returned CancelFunc will be nil. +func (r *Runner) RunWithCancel(ctx context.Context, messages []Message, + opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { + iter, cancelFn, _ := r.runWithCancel(ctx, messages, true, opts...) + return iter, cancelFn +} + +func (r *Runner) runWithCancel(ctx context.Context, messages []Message, withCancel bool, + opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc, error) { o := getCommonOptions(nil, opts...) fa := toFlowAgent(ctx, r.a) @@ -87,22 +112,46 @@ func (r *Runner) Run(ctx context.Context, messages []Message, AddSessionValues(ctx, o.sessionValues) - iter := fa.Run(ctx, input, opts...) + var iter *AsyncIterator[*AgentEvent] + var cancelFn CancelFunc + if withCancel { + if _, ok := r.a.(CancellableAgent); ok { + iter, cancelFn = fa.RunWithCancel(ctx, input, opts...) + } else { + iter = fa.Run(ctx, input, opts...) + } + } else { + iter = fa.Run(ctx, input, opts...) + } + if r.store == nil { - return iter + return iter, cancelFn, nil } niter, gen := NewAsyncIteratorPair[*AgentEvent]() go r.handleIter(ctx, iter, gen, o.checkPointID) - return niter + return niter, cancelFn, nil } -// Query is a convenience method that starts a new execution with a single user query string. -func (r *Runner) Query(ctx context.Context, - query string, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { +// ResumeWithCancel continues an interrupted execution from a checkpoint and returns both an iterator and a cancel function. +// This method uses the "Implicit Resume All" strategy where all previously interrupted points proceed without specific data. +// The cancel function can be used to interrupt the running agent again at specific points based on the CancelMode. +// +// If the agent does not implement CancellableResumableAgent, the returned CancelFunc will be nil. +func (r *Runner) ResumeWithCancel(ctx context.Context, checkPointID string, opts ...AgentRunOption) ( + *AsyncIterator[*AgentEvent], CancelFunc, error) { + return r.resumeWithCancel(ctx, checkPointID, nil, true, opts...) +} - return r.Run(ctx, []Message{schema.UserMessage(query)}, opts...) +// ResumeWithParamsAndCancel continues an interrupted execution from a checkpoint with specific parameters +// and returns both an iterator and a cancel function. +// The params.Targets map should contain the addresses of the components to be resumed as keys. +// +// If the agent does not implement CancellableResumableAgent, the returned CancelFunc will be nil. +func (r *Runner) ResumeWithParamsAndCancel(ctx context.Context, checkPointID string, params *ResumeParams, + opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc, error) { + return r.resumeWithCancel(ctx, checkPointID, params.Targets, true, opts...) } // Resume continues an interrupted execution from a checkpoint, using an "Implicit Resume All" strategy. @@ -114,7 +163,8 @@ func (r *Runner) Query(ctx context.Context, // pattern where an agent only needs to know `wasInterrupted` is true to continue. func (r *Runner) Resume(ctx context.Context, checkPointID string, opts ...AgentRunOption) ( *AsyncIterator[*AgentEvent], error) { - return r.resume(ctx, checkPointID, nil, opts...) + iter, _, err := r.resumeWithCancel(ctx, checkPointID, nil, false, opts...) + return iter, err } // ResumeWithParams continues an interrupted execution from a checkpoint with specific parameters. @@ -136,19 +186,19 @@ func (r *Runner) Resume(ctx context.Context, checkPointID string, opts ...AgentR // naturally re-interrupt if one of their interrupted children re-interrupts, as they receive the // new `CompositeInterrupt` signal from them. func (r *Runner) ResumeWithParams(ctx context.Context, checkPointID string, params *ResumeParams, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], error) { - return r.resume(ctx, checkPointID, params.Targets, opts...) + iter, _, err := r.resumeWithCancel(ctx, checkPointID, params.Targets, false, opts...) + return iter, err } -// resume is the internal implementation for both Resume and ResumeWithParams. -func (r *Runner) resume(ctx context.Context, checkPointID string, resumeData map[string]any, - opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], error) { +func (r *Runner) resumeWithCancel(ctx context.Context, checkPointID string, resumeData map[string]any, + withCancel bool, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc, error) { if r.store == nil { - return nil, fmt.Errorf("failed to resume: store is nil") + return nil, nil, fmt.Errorf("failed to resume: store is nil") } ctx, runCtx, resumeInfo, err := r.loadCheckPoint(ctx, checkPointID) if err != nil { - return nil, fmt.Errorf("failed to load from checkpoint: %w", err) + return nil, nil, fmt.Errorf("failed to load from checkpoint: %w", err) } o := getCommonOptions(nil, opts...) @@ -175,15 +225,27 @@ func (r *Runner) resume(ctx context.Context, checkPointID string, resumeData map } fa := toFlowAgent(ctx, r.a) - aIter := fa.Resume(ctx, resumeInfo, opts...) + + var aIter *AsyncIterator[*AgentEvent] + var cancelFn CancelFunc + if withCancel { + if _, ok := r.a.(CancellableResumableAgent); ok { + aIter, cancelFn = fa.ResumeWithCancel(ctx, resumeInfo, opts...) + } else { + aIter = fa.Resume(ctx, resumeInfo, opts...) + } + } else { + aIter = fa.Resume(ctx, resumeInfo, opts...) + } + if r.store == nil { - return aIter, nil + return aIter, cancelFn, nil } niter, gen := NewAsyncIteratorPair[*AgentEvent]() go r.handleIter(ctx, aIter, gen, &checkPointID) - return niter, nil + return niter, cancelFn, nil } func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEvent], diff --git a/adk/turn_loop.go b/adk/turn_loop.go index e5cf1c36f..af69ae5d3 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -21,6 +21,8 @@ import ( "errors" "fmt" "runtime/debug" + "sync" + "sync/atomic" "time" "github.com/cloudwego/eino/internal/safe" @@ -44,19 +46,24 @@ const ( ) type consumeConfig struct { - Mode ConsumeMode - Timeout time.Duration - CancelOpts []CancelOption + Mode ConsumeMode + Timeout time.Duration + CancelOpts []CancelOption + CheckPointID string } type ConsumeOption func(*consumeConfig) +// WithPreemptive sets the consume mode to preemptive, which cancels the +// currently running agent immediately. func WithPreemptive() ConsumeOption { return func(config *consumeConfig) { config.Mode = ConsumePreemptive } } +// WithPreemptiveOnTimeout sets the consume mode to preemptive with a timeout. +// If the current agent does not complete within the timeout, it will be canceled. func WithPreemptiveOnTimeout(timeout time.Duration) ConsumeOption { return func(config *consumeConfig) { config.Mode = ConsumePreemptiveOnTimeout @@ -64,12 +71,21 @@ func WithPreemptiveOnTimeout(timeout time.Duration) ConsumeOption { } } +// WithCancelOptions appends cancel options to be used when canceling the agent. func WithCancelOptions(opts ...CancelOption) ConsumeOption { return func(config *consumeConfig) { config.CancelOpts = append(config.CancelOpts, opts...) } } +// WithConsumeCheckPointID sets the checkpoint ID for the consumed message. +// When set, the checkpoint will be saved with this ID if an interrupt occurs. +func WithConsumeCheckPointID(id string) ConsumeOption { + return func(config *consumeConfig) { + config.CheckPointID = id + } +} + type ReceiveConfig struct { Timeout time.Duration } @@ -79,6 +95,24 @@ type MessageSource[T any] interface { Front(context.Context, ReceiveConfig) (context.Context, T, []ConsumeOption, error) } +type turnLoopRunConfig[T any] struct { + checkPointID string + item T +} + +// TurnLoopRunOption is an option for TurnLoop.Run. +type TurnLoopRunOption[T any] func(*turnLoopRunConfig[T]) + +// WithTurnLoopResume configures the TurnLoop to resume from a previously saved checkpoint. +// The checkPointID identifies the checkpoint to resume from, and item is the original input +// that triggered the interrupted execution. +func WithTurnLoopResume[T any](checkPointID string, item T) TurnLoopRunOption[T] { + return func(c *turnLoopRunConfig[T]) { + c.checkPointID = checkPointID + c.item = item + } +} + // TurnLoopConfig is the configuration for creating a TurnLoop. type TurnLoopConfig[T any] struct { // Source provides messages to drive the loop. Required. @@ -88,12 +122,16 @@ type TurnLoopConfig[T any] struct { GenInput func(ctx context.Context, item T) (*AgentInput, []AgentRunOption, error) // GetAgent returns the Agent to run for a given message. Required. GetAgent func(ctx context.Context, item T) (Agent, error) - // OnAgentEvent is called for each event emitted by the agent. Optional. + // OnAgentEvents is called for each event emitted by the agent. Optional. // The inputItem is the message that triggered the current agent turn. + // If not provided, the default implementation will consume all events and + // return any error event encountered. OnAgentEvents func(ctx context.Context, inputItem T, event *AsyncIterator[*AgentEvent]) error // ReceiveTimeout is the timeout passed to Source.Receive on each iteration. // Zero means no timeout. Optional. ReceiveTimeout time.Duration + + Store CheckPointStore } // TurnLoop is a loop that pulls messages from a source, runs an Agent for @@ -106,8 +144,63 @@ type TurnLoop[T any] struct { getAgent func(ctx context.Context, item T) (Agent, error) onAgentEvents func(ctx context.Context, inputItem T, event *AsyncIterator[*AgentEvent]) error receiveTimeout time.Duration + store CheckPointStore +} + +type turnLoopCancelSig struct { + done chan struct{} + config atomic.Value } +func newTurnLoopCancelSig() *turnLoopCancelSig { + return &turnLoopCancelSig{ + done: make(chan struct{}), + } +} + +func (cs *turnLoopCancelSig) cancel(cfg *cancelConfig) { + cs.config.Store(cfg) + close(cs.done) +} + +func (cs *turnLoopCancelSig) isCancelled() bool { + select { + case <-cs.done: + return true + default: + return false + } +} + +func (cs *turnLoopCancelSig) getConfig() *cancelConfig { + if v := cs.config.Load(); v != nil { + return v.(*cancelConfig) + } + return nil +} + +type turnLoopCancelSigKey struct{} + +func withTurnLoopCancelSig(ctx context.Context, cs *turnLoopCancelSig) context.Context { + return context.WithValue(ctx, turnLoopCancelSigKey{}, cs) +} + +func getTurnLoopCancelSig(ctx context.Context) *turnLoopCancelSig { + if v, ok := ctx.Value(turnLoopCancelSigKey{}).(*turnLoopCancelSig); ok { + return v + } + return nil +} + +// TurnLoopCancelFunc is the cancel function returned by WithCancel. +// Unlike Agent's CancelFunc, it does not require a context parameter +// since the context is already bound when WithCancel is called. +type TurnLoopCancelFunc func(opts ...CancelOption) error + +// ErrAgentNotCancellableInTurnLoop is returned when WithCancel context is used +// but the Agent does not implement CancellableAgent. +var ErrAgentNotCancellableInTurnLoop = errors.New("agent does not support cancel but WithCancel context was provided") + // NewTurnLoop creates a new TurnLoop from the given configuration. // Source, GenInput, and GetAgent are required fields. func NewTurnLoop[T any](config TurnLoopConfig[T]) (*TurnLoop[T], error) { @@ -121,17 +214,78 @@ func NewTurnLoop[T any](config TurnLoopConfig[T]) (*TurnLoop[T], error) { return nil, fmt.Errorf("TurnLoopConfig.GetAgent is required") } + onAgentEvents := config.OnAgentEvents + if onAgentEvents == nil { + onAgentEvents = func(_ context.Context, _ T, iter *AsyncIterator[*AgentEvent]) error { + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil { + return event.Err + } + } + return nil + } + } + return &TurnLoop[T]{ source: config.Source, genInput: config.GenInput, getAgent: config.GetAgent, - onAgentEvents: config.OnAgentEvents, + onAgentEvents: onAgentEvents, receiveTimeout: config.ReceiveTimeout, + store: config.Store, }, nil } var ErrLoopExit = errors.New("loop exit") +// WithCancel returns a new context and a cancel function that can be used to +// cancel the TurnLoop's Run method externally. Each call to WithCancel creates +// an independent cancel signal, allowing multiple concurrent Run calls with +// separate cancel controls. +// +// The returned TurnLoopCancelFunc does not require a context parameter since +// the context is already bound when WithCancel is called. +// +// Example: +// +// ctx, cancel := turnLoop.WithCancel(context.Background()) +// go func() { +// err := turnLoop.Run(ctx) +// }() +// // Later, to cancel: +// cancel(adk.WithCancelMode(adk.CancelAfterToolCall)) +func (l *TurnLoop[T]) WithCancel(ctx context.Context) (context.Context, TurnLoopCancelFunc) { + cs := newTurnLoopCancelSig() + ctx = withTurnLoopCancelSig(ctx, cs) + + var once sync.Once + cancelFn := func(opts ...CancelOption) error { + cfg := &cancelConfig{ + Mode: CancelImmediate, + } + for _, opt := range opts { + opt(cfg) + } + + cancelled := false + once.Do(func() { + cs.cancel(cfg) + cancelled = true + }) + + if !cancelled { + return ErrAgentFinished + } + return nil + } + + return ctx, cancelFn +} + // Run starts the blocking loop that continuously receives messages from the // source, runs the agent returned by GetAgent for each message, and dispatches // resulting events to OnAgentEvent. It blocks until the source returns an error @@ -141,16 +295,60 @@ var ErrLoopExit = errors.New("loop exit") // implements Cancellable, the agent is canceled and the new message is processed // immediately. If the agent does not implement Cancellable, preemptive messages // are queued and processed after the current agent finishes. -func (l *TurnLoop[T]) Run(ctx context.Context) error { +// +// To enable external cancellation, use WithCancel to create a cancellable context: +// +// ctx, cancel := turnLoop.WithCancel(context.Background()) +// go turnLoop.Run(ctx) +// // Later: cancel() +// +// To enable checkpoint-based resumption, use WithTurnLoopResume: +// +// err := turnLoop.Run(ctx, WithTurnLoopResume("session-123")) +// +//nolint:cyclop,funlen // This is a core method, splitting would make the logic harder to follow +func (l *TurnLoop[T]) Run(ctx context.Context, opts ...TurnLoopRunOption[T]) error { + var runCfg turnLoopRunConfig[T] + for _, opt := range opts { + opt(&runCfg) + } + + cs := getTurnLoopCancelSig(ctx) + toResumeFirst := false + if len(runCfg.checkPointID) > 0 { + toResumeFirst = true + } + for { - nCtx, item, option, err := l.source.Receive(ctx, ReceiveConfig{ - Timeout: l.receiveTimeout, - }) - if errors.Is(err, ErrLoopExit) { + if cs != nil && cs.isCancelled() { return nil } - if err != nil { - return fmt.Errorf("failed to receive message: %w", err) + + var nCtx context.Context + var item T + var checkPointID string + if !toResumeFirst { + var err error + var option []ConsumeOption + nCtx, item, option, err = l.source.Receive(ctx, ReceiveConfig{ + Timeout: l.receiveTimeout, + }) + if errors.Is(err, ErrLoopExit) { + return nil + } + if err != nil { + return fmt.Errorf("failed to receive message: %w", err) + } + o := applyConsumeOptions(option) + checkPointID = o.CheckPointID + } else { + nCtx = ctx + item = runCfg.item + checkPointID = runCfg.checkPointID + } + + if len(checkPointID) > 0 && l.store == nil { + return fmt.Errorf("CheckPointStore is required") } input, runOpts, e := l.genInput(nCtx, item) @@ -165,33 +363,56 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { var cancelFunc CancelFunc var iter *AsyncIterator[*AgentEvent] - if ca, isAgentCancellable := agent.(CancellableRun); isAgentCancellable { - iter, cancelFunc = ca.RunWithCancel(nCtx, input, runOpts...) + _, isAgentCancellable := agent.(CancellableAgent) + if cs != nil && !isAgentCancellable { + return fmt.Errorf("%w: agent %s", ErrAgentNotCancellableInTurnLoop, agent.Name(nCtx)) + } + + if toResumeFirst { + var err error + iter, cancelFunc, err = NewRunner(nCtx, RunnerConfig{ + EnableStreaming: input.EnableStreaming, + Agent: agent, + CheckPointStore: l.store, + }).ResumeWithCancel(nCtx, checkPointID, runOpts...) + if err != nil { + return fmt.Errorf("failed to resume agent: %w", err) + } + toResumeFirst = false + } else if isAgentCancellable { + var cps CheckPointStore + if len(checkPointID) > 0 { + cps = l.store + runOpts = append(runOpts, WithCheckPointID(checkPointID)) + } + iter, cancelFunc = NewRunner(nCtx, RunnerConfig{ + EnableStreaming: input.EnableStreaming, + Agent: agent, + CheckPointStore: cps, + }).RunWithCancel(nCtx, input.Messages, runOpts...) } else { - iter = agent.Run(nCtx, input, runOpts...) + var cps CheckPointStore + if len(checkPointID) > 0 { + cps = l.store + runOpts = append(runOpts, WithCheckPointID(checkPointID)) + } + iter = NewRunner(nCtx, RunnerConfig{ + EnableStreaming: input.EnableStreaming, + Agent: agent, + CheckPointStore: cps, + }).Run(nCtx, input.Messages, runOpts...) } - // handleEvents drains the agent iterator, forwarding each event to the - // OnAgentEvent callback. It is called directly in the non-cancellable - // path and from a goroutine in the cancellable path. handleEvents := func() error { - oe := l.onAgentEvents(ctx, item, iter) - if oe != nil { - return oe - } - return nil + return l.handleEvents(ctx, item, iter, checkPointID) } var handleEventErr error if cancelFunc != nil { - // Cancellable path: consume events in a goroutine so the main - // goroutine can block on Receive concurrently. done := make(chan struct{}) go func() { defer func() { - // Recover panics from the iterator or callback so they - // don't crash the process; surface them as errors instead. panicErr := recover() if panicErr != nil { handleEventErr = safe.NewPanicErr(panicErr, debug.Stack()) @@ -203,63 +424,150 @@ func (l *TurnLoop[T]) Run(ctx context.Context) error { handleEventErr = handleEvents() }() - // Block on the next message while events are being consumed above. - _, _, option, err = l.source.Front(nCtx, ReceiveConfig{ - Timeout: l.receiveTimeout, - }) - if err != nil { - <-done // wait for the event goroutine before returning - if errors.Is(err, ErrLoopExit) { + frontDone := make(chan struct{}) + var frontErr error + var option []ConsumeOption + go func() { + defer func() { + panicErr := recover() + if panicErr != nil { + frontErr = safe.NewPanicErr(panicErr, debug.Stack()) + } + + close(frontDone) + }() + _, _, option, frontErr = l.source.Front(nCtx, ReceiveConfig{ + Timeout: l.receiveTimeout, + }) + }() + + var externalCancelled bool + select { + case <-frontDone: + case <-done: + case <-func() <-chan struct{} { + if cs != nil { + return cs.done + } + return nil + }(): + externalCancelled = true + cfg := cs.getConfig() + err := cancelFunc(nCtx, cancelConfigToOpts(cfg)...) + if err != nil && !errors.Is(err, ErrAgentFinished) { + <-done + return fmt.Errorf("failed to cancel agent: %w", err) + } + } + + if externalCancelled { + <-done + return l.wrapHandleEventErr(handleEventErr) + } + + if frontErr != nil { + <-done + if errors.Is(frontErr, ErrLoopExit) { return nil } - return fmt.Errorf("failed to front message: %w", err) + return fmt.Errorf("failed to front message: %w", frontErr) } - // If the new message requests preemption, cancel the running agent. - // Cancel triggers the iterator to terminate, which unblocks the - // event goroutine above. o := applyConsumeOptions(option) switch o.Mode { case ConsumePreemptive: - err = cancelFunc(nCtx, o.CancelOpts...) + err := cancelFunc(nCtx, o.CancelOpts...) if err != nil { - <-done // wait for the event goroutine before returning + <-done return fmt.Errorf("failed to cancel agent: %w", err) } case ConsumePreemptiveOnTimeout: select { case <-done: case <-time.After(o.Timeout): - err = cancelFunc(nCtx, o.CancelOpts...) + err := cancelFunc(nCtx, o.CancelOpts...) if err != nil { - <-done // wait for the event goroutine before returning + <-done + return fmt.Errorf("failed to cancel agent: %w", err) + } + case <-func() <-chan struct{} { + if cs != nil { + return cs.done + } + return nil + }(): + cfg := cs.getConfig() + err := cancelFunc(nCtx, cancelConfigToOpts(cfg)...) + if err != nil && !errors.Is(err, ErrAgentFinished) { + <-done return fmt.Errorf("failed to cancel agent: %w", err) } + <-done + return l.wrapHandleEventErr(handleEventErr) } } - // Wait for event consumption to finish (normal completion or - // post-cancel drain) before starting the next turn. <-done - if handleEventErr != nil { - if errors.Is(handleEventErr, ErrLoopExit) { - return nil - } - return fmt.Errorf("failed to handle events: %w", handleEventErr) + if err := l.wrapHandleEventErr(handleEventErr); err != nil { + return err } } else { - // Non-cancellable path: consume all events sequentially, then - // block on the next message. if handleEventErr = handleEvents(); handleEventErr != nil { - if errors.Is(handleEventErr, ErrLoopExit) { - return nil + if err := l.wrapHandleEventErr(handleEventErr); err != nil { + return err } - return fmt.Errorf("failed to handle events: %w", handleEventErr) } } } } +func (l *TurnLoop[T]) wrapHandleEventErr(handleEventErr error) error { + if handleEventErr == nil { + return nil + } + if errors.Is(handleEventErr, ErrLoopExit) { + return nil + } + var interruptErr *TurnLoopInterruptError[T] + if errors.As(handleEventErr, &interruptErr) { + return interruptErr + } + return fmt.Errorf("failed to handle events: %w", handleEventErr) +} + +func (l *TurnLoop[T]) handleEvents(ctx context.Context, item T, iter *AsyncIterator[*AgentEvent], checkPointID string) error { + copies := copyEventIterator(iter, 2) + oe := l.onAgentEvents(ctx, item, copies[0]) + if oe != nil { + return oe + } + for { + e, ok := copies[1].Next() + if !ok { + break + } + if e.Action != nil && e.Action.Interrupted != nil { + return &TurnLoopInterruptError[T]{ + Item: item, + CheckpointID: checkPointID, + InterruptContexts: e.Action.Interrupted.InterruptContexts, + } + } + } + return nil +} + +func cancelConfigToOpts(cfg *cancelConfig) []CancelOption { + if cfg == nil { + return nil + } + opts := []CancelOption{WithCancelMode(cfg.Mode)} + if cfg.Timeout != nil { + opts = append(opts, WithCancelTimeout(*cfg.Timeout)) + } + return opts +} + func applyConsumeOptions(opts []ConsumeOption) *consumeConfig { var config consumeConfig for _, opt := range opts { @@ -267,3 +575,15 @@ func applyConsumeOptions(opts []ConsumeOption) *consumeConfig { } return &config } + +type TurnLoopInterruptError[T any] struct { + Item T + CheckpointID string + // InterruptContexts provides a structured, user-facing view of the interrupt chain. + // Each context represents a step in the agent hierarchy that was interrupted. + InterruptContexts []*InterruptCtx +} + +func (t *TurnLoopInterruptError[T]) Error() string { + return fmt.Sprintf("TurnLoopInterruptError[%s]: %v", t.CheckpointID, t.InterruptContexts) +} diff --git a/adk/turn_loop_test.go b/adk/turn_loop_test.go index f66c1298b..7c22f3699 100644 --- a/adk/turn_loop_test.go +++ b/adk/turn_loop_test.go @@ -19,53 +19,66 @@ package adk import ( "context" "errors" + "fmt" "sync/atomic" "testing" "time" + "unsafe" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) -// --------------------------------------------------------------------------- -// Test mocks -// --------------------------------------------------------------------------- - -// turnLoopMockSource returns items from a slice (all NonPreemptive), then an error. type turnLoopMockSource struct { items []string idx int err error } -func (s *turnLoopMockSource) Receive(_ context.Context, _ time.Duration) (string, ConsumeOption, error) { +func (s *turnLoopMockSource) Receive(ctx context.Context, cfg ReceiveConfig) (context.Context, string, []ConsumeOption, error) { if s.idx >= len(s.items) { - return "", NonPreemptiveConsumeOption, s.err + return ctx, "", nil, s.err } item := s.items[s.idx] s.idx++ - return item, NonPreemptiveConsumeOption, nil + return ctx, item, nil, nil +} + +func (s *turnLoopMockSource) Front(ctx context.Context, cfg ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + if s.idx >= len(s.items) { + return ctx, "", nil, s.err + } + return ctx, s.items[s.idx], nil, nil } -// turnLoopFuncSource delegates Receive to a user-supplied function. type turnLoopFuncSource[T any] struct { - fn func(ctx context.Context, timeout time.Duration) (T, ConsumeOption, error) + receive func(ctx context.Context, cfg ReceiveConfig) (context.Context, T, []ConsumeOption, error) + front func(ctx context.Context, cfg ReceiveConfig) (context.Context, T, []ConsumeOption, error) } -func (s *turnLoopFuncSource[T]) Receive(ctx context.Context, timeout time.Duration) (T, ConsumeOption, error) { - return s.fn(ctx, timeout) +func (s *turnLoopFuncSource[T]) Receive(ctx context.Context, cfg ReceiveConfig) (context.Context, T, []ConsumeOption, error) { + return s.receive(ctx, cfg) +} + +func (s *turnLoopFuncSource[T]) Front(ctx context.Context, cfg ReceiveConfig) (context.Context, T, []ConsumeOption, error) { + if s.front != nil { + return s.front(ctx, cfg) + } + return s.receive(ctx, cfg) } -// turnLoopMockAgent emits a fixed list of events per Run call. type turnLoopMockAgent struct { name string events []*AgentEvent } func (a *turnLoopMockAgent) Name(_ context.Context) string { return a.name } -func (a *turnLoopMockAgent) Description(_ context.Context) string { return "mock agent" } +func (a *turnLoopMockAgent) Description(_ context.Context) string { return "mock agent" } func (a *turnLoopMockAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, gen := NewAsyncIteratorPair[*AgentEvent]() go func() { @@ -77,39 +90,37 @@ func (a *turnLoopMockAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunO return iter } -// turnLoopCancellableAgent blocks until Cancel is called, then closes its iterator. -// It implements both Agent and Cancellable. type turnLoopCancellableAgent struct { - name string - startedCh chan struct{} // closed when Run is entered - cancelCh chan struct{} // closed by Cancel - cancelled atomic.Bool - cancelledOpt CancelOption // records the CancelOption passed to Cancel + name string + startedCh chan struct{} + cancelCh chan struct{} + cancelled int32 + cancelledOpt []CancelOption } func (a *turnLoopCancellableAgent) Name(_ context.Context) string { return a.name } -func (a *turnLoopCancellableAgent) Description(_ context.Context) string { return "cancellable mock" } +func (a *turnLoopCancellableAgent) Description(_ context.Context) string { return "cancellable mock" } func (a *turnLoopCancellableAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, _ := a.RunWithCancel(context.Background(), nil) + return iter +} + +func (a *turnLoopCancellableAgent) RunWithCancel(_ context.Context, _ *AgentInput, _ ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { iter, gen := NewAsyncIteratorPair[*AgentEvent]() close(a.startedCh) go func() { defer gen.Close() <-a.cancelCh }() - return iter -} - -func (a *turnLoopCancellableAgent) Cancel(_ context.Context, opt CancelOption) error { - a.cancelled.Store(true) - a.cancelledOpt = opt - close(a.cancelCh) - return nil + cancelFunc := func(_ context.Context, opts ...CancelOption) error { + atomic.StoreInt32(&a.cancelled, 1) + a.cancelledOpt = opts + close(a.cancelCh) + return nil + } + return iter, cancelFunc } -// --------------------------------------------------------------------------- -// Tests — validation -// --------------------------------------------------------------------------- - func TestNewTurnLoop_Validation(t *testing.T) { t.Run("missing source", func(t *testing.T) { _, err := NewTurnLoop(TurnLoopConfig[string]{ @@ -138,7 +149,7 @@ func TestNewTurnLoop_Validation(t *testing.T) { assert.Contains(t, err.Error(), "GetAgent") }) - t.Run("valid config", func(t *testing.T) { + t.Run("valid config without OnAgentEvents", func(t *testing.T) { loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: &turnLoopMockSource{}, GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { return nil, nil, nil }, @@ -147,11 +158,18 @@ func TestNewTurnLoop_Validation(t *testing.T) { require.NoError(t, err) assert.NotNil(t, loop) }) -} -// --------------------------------------------------------------------------- -// Tests — non-preemptive (queued) behavior -// --------------------------------------------------------------------------- + t.Run("valid config with OnAgentEvents", func(t *testing.T) { + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{}, + GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { return nil, nil, nil }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, nil }, + OnAgentEvents: func(_ context.Context, _ string, _ *AsyncIterator[*AgentEvent]) error { return nil }, + }) + require.NoError(t, err) + assert.NotNil(t, loop) + }) +} func TestTurnLoop_NormalLoop(t *testing.T) { agent := &turnLoopMockAgent{ @@ -161,8 +179,8 @@ func TestTurnLoop_NormalLoop(t *testing.T) { }, } - var receivedEvents []*AgentEvent var receivedItems []string + var eventCount int source := &turnLoopMockSource{ items: []string{"msg1", "msg2", "msg3"}, @@ -178,8 +196,14 @@ func TestTurnLoop_NormalLoop(t *testing.T) { GetAgent: func(_ context.Context, _ string) (Agent, error) { return agent, nil }, - OnAgentEvent: func(_ context.Context, _ string, event *AgentEvent) error { - receivedEvents = append(receivedEvents, event) + OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + for { + _, ok := iter.Next() + if !ok { + break + } + eventCount++ + } return nil }, }) @@ -188,7 +212,7 @@ func TestTurnLoop_NormalLoop(t *testing.T) { err = loop.Run(context.Background()) assert.ErrorIs(t, err, context.DeadlineExceeded) assert.Equal(t, []string{"msg1", "msg2", "msg3"}, receivedItems) - assert.Len(t, receivedEvents, 3) + assert.Equal(t, 3, eventCount) } func TestTurnLoop_SourceError(t *testing.T) { @@ -206,6 +230,7 @@ func TestTurnLoop_SourceError(t *testing.T) { GetAgent: func(_ context.Context, _ string) (Agent, error) { return &turnLoopMockAgent{name: "a"}, nil }, + OnAgentEvents: func(_ context.Context, _ string, _ *AsyncIterator[*AgentEvent]) error { return nil }, }) require.NoError(t, err) @@ -224,6 +249,7 @@ func TestTurnLoop_GenInputError(t *testing.T) { GetAgent: func(_ context.Context, _ string) (Agent, error) { return &turnLoopMockAgent{name: "a"}, nil }, + OnAgentEvents: func(_ context.Context, _ string, _ *AsyncIterator[*AgentEvent]) error { return nil }, }) require.NoError(t, err) @@ -243,6 +269,7 @@ func TestTurnLoop_GetAgentError(t *testing.T) { GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, agentErr }, + OnAgentEvents: func(_ context.Context, _ string, _ *AsyncIterator[*AgentEvent]) error { return nil }, }) require.NoError(t, err) @@ -251,7 +278,7 @@ func TestTurnLoop_GetAgentError(t *testing.T) { assert.Contains(t, err.Error(), "failed to get agent") } -func TestTurnLoop_OnAgentEventError(t *testing.T) { +func TestTurnLoop_OnAgentEventsError(t *testing.T) { eventErr := errors.New("event handler failure") agent := &turnLoopMockAgent{ name: "test-agent", @@ -266,7 +293,7 @@ func TestTurnLoop_OnAgentEventError(t *testing.T) { GetAgent: func(_ context.Context, _ string) (Agent, error) { return agent, nil }, - OnAgentEvent: func(_ context.Context, _ string, _ *AgentEvent) error { + OnAgentEvents: func(_ context.Context, _ string, _ *AsyncIterator[*AgentEvent]) error { return eventErr }, }) @@ -274,20 +301,20 @@ func TestTurnLoop_OnAgentEventError(t *testing.T) { err = loop.Run(context.Background()) assert.ErrorIs(t, err, eventErr) - assert.Contains(t, err.Error(), "OnAgentEvent callback failed") + assert.Contains(t, err.Error(), "failed to handle events") } func TestTurnLoop_ContextCancellation(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) callCount := 0 - source := &turnLoopFuncSource[string]{fn: func(ctx context.Context, _ time.Duration) (string, ConsumeOption, error) { + source := &turnLoopFuncSource[string]{receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { callCount++ if callCount > 1 { cancel() - return "", NonPreemptiveConsumeOption, ctx.Err() + return ctx, "", nil, ctx.Err() } - return "msg1", NonPreemptiveConsumeOption, nil + return ctx, "msg1", nil, nil }} agent := &turnLoopMockAgent{ @@ -303,6 +330,14 @@ func TestTurnLoop_ContextCancellation(t *testing.T) { GetAgent: func(_ context.Context, _ string) (Agent, error) { return agent, nil }, + OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := iter.Next(); !ok { + break + } + } + return nil + }, }) require.NoError(t, err) @@ -330,8 +365,14 @@ func TestTurnLoop_MultipleEventsPerTurn(t *testing.T) { GetAgent: func(_ context.Context, _ string) (Agent, error) { return agent, nil }, - OnAgentEvent: func(_ context.Context, _ string, _ *AgentEvent) error { - eventCount++ + OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + for { + _, ok := iter.Next() + if !ok { + break + } + eventCount++ + } return nil }, }) @@ -342,7 +383,7 @@ func TestTurnLoop_MultipleEventsPerTurn(t *testing.T) { assert.Equal(t, 3, eventCount) } -func TestTurnLoop_NoOnAgentEvent(t *testing.T) { +func TestTurnLoop_DefaultOnAgentEvents(t *testing.T) { agent := &turnLoopMockAgent{ name: "test-agent", events: []*AgentEvent{{Output: &AgentOutput{}}}, @@ -363,6 +404,28 @@ func TestTurnLoop_NoOnAgentEvent(t *testing.T) { assert.ErrorIs(t, err, context.DeadlineExceeded) } +func TestTurnLoop_DefaultOnAgentEventsWithError(t *testing.T) { + agentErr := errors.New("agent internal error") + agent := &turnLoopMockAgent{ + name: "error-agent", + events: []*AgentEvent{{Err: agentErr}}, + } + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + }) + require.NoError(t, err) + + err = loop.Run(context.Background()) + assert.ErrorIs(t, err, agentErr) +} + func TestTurnLoop_AgentErrorEvent(t *testing.T) { agentErr := errors.New("agent internal error") agent := &turnLoopMockAgent{ @@ -378,6 +441,18 @@ func TestTurnLoop_AgentErrorEvent(t *testing.T) { GetAgent: func(_ context.Context, _ string) (Agent, error) { return agent, nil }, + OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil { + return fmt.Errorf("agent run failed: %w", event.Err) + } + } + return nil + }, }) require.NoError(t, err) @@ -386,10 +461,6 @@ func TestTurnLoop_AgentErrorEvent(t *testing.T) { assert.Contains(t, err.Error(), "agent run failed") } -// --------------------------------------------------------------------------- -// Tests — preemptive behavior -// --------------------------------------------------------------------------- - func TestTurnLoop_PreemptiveCancellation(t *testing.T) { slowAgent := &turnLoopCancellableAgent{ name: "slow-agent", @@ -403,19 +474,36 @@ func TestTurnLoop_PreemptiveCancellation(t *testing.T) { } var processedItems []string - callCount := 0 - source := &turnLoopFuncSource[string]{fn: func(_ context.Context, _ time.Duration) (string, ConsumeOption, error) { - callCount++ - switch callCount { - case 1: - return "slow-msg", NonPreemptiveConsumeOption, nil - case 2: + receiveCount := 0 + msgs := []struct { + item string + opts []ConsumeOption + err error + }{ + {"slow-msg", nil, nil}, + {"preempt-msg", []ConsumeOption{WithPreemptive(), WithCancelOptions(WithCancelMode(CancelImmediate))}, nil}, + {"", nil, context.DeadlineExceeded}, + } + frontIdx := 0 + source := &turnLoopFuncSource[string]{ + receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + if receiveCount >= len(msgs) { + return ctx, "", nil, context.DeadlineExceeded + } + m := msgs[receiveCount] + receiveCount++ + frontIdx = receiveCount + return ctx, m.item, m.opts, m.err + }, + front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { <-slowAgent.startedCh - return "preempt-msg", ConsumeOption{Mode: ConsumePreemptive, CancelOption: CancelOption{Mode: CancelImmediate}}, nil - default: - return "", NonPreemptiveConsumeOption, context.DeadlineExceeded - } - }} + if frontIdx >= len(msgs) { + return ctx, "", nil, context.DeadlineExceeded + } + m := msgs[frontIdx] + return ctx, m.item, m.opts, m.err + }, + } loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: source, @@ -429,110 +517,603 @@ func TestTurnLoop_PreemptiveCancellation(t *testing.T) { } return fastAgent, nil }, + OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := iter.Next(); !ok { + break + } + } + return nil + }, }) require.NoError(t, err) err = loop.Run(context.Background()) assert.ErrorIs(t, err, context.DeadlineExceeded) - assert.True(t, slowAgent.cancelled.Load(), "slow agent should have been cancelled") - assert.Equal(t, CancelImmediate, slowAgent.cancelledOpt.Mode) + assert.True(t, atomic.LoadInt32(&slowAgent.cancelled) == 1, "slow agent should have been cancelled") assert.Equal(t, []string{"slow-msg", "preempt-msg"}, processedItems) } -func TestTurnLoop_PreemptiveWithCancelMode(t *testing.T) { - slowAgent := &turnLoopCancellableAgent{ - name: "slow-agent", - startedCh: make(chan struct{}), - cancelCh: make(chan struct{}), +func TestTurnLoop_PreemptiveNonCancellableAgent(t *testing.T) { + nonCancellableAgent := &turnLoopMockAgent{ + name: "non-cancellable-agent", + events: []*AgentEvent{{Output: &AgentOutput{}}}, } - fastAgent := &turnLoopMockAgent{ name: "fast-agent", events: []*AgentEvent{{Output: &AgentOutput{}}}, } + var processedItems []string callCount := 0 - source := &turnLoopFuncSource[string]{fn: func(_ context.Context, _ time.Duration) (string, ConsumeOption, error) { + source := &turnLoopFuncSource[string]{receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { callCount++ switch callCount { case 1: - return "slow-msg", NonPreemptiveConsumeOption, nil + return ctx, "non-cancel-msg", nil, nil case 2: - <-slowAgent.startedCh - return "preempt-msg", ConsumeOption{ - Mode: ConsumePreemptive, - CancelOption: CancelOption{Mode: CancelAfterChatModel | CancelAfterToolCall}, - }, nil + return ctx, "preempt-msg", []ConsumeOption{WithPreemptive()}, nil default: - return "", NonPreemptiveConsumeOption, context.DeadlineExceeded + return ctx, "", nil, context.DeadlineExceeded } }} loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: source, GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + processedItems = append(processedItems, item) return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil }, GetAgent: func(_ context.Context, item string) (Agent, error) { - if item == "slow-msg" { - return slowAgent, nil + if item == "non-cancel-msg" { + return nonCancellableAgent, nil } return fastAgent, nil }, + OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := iter.Next(); !ok { + break + } + } + return nil + }, }) require.NoError(t, err) err = loop.Run(context.Background()) assert.ErrorIs(t, err, context.DeadlineExceeded) - assert.True(t, slowAgent.cancelled.Load()) - assert.Equal(t, CancelAfterChatModel|CancelAfterToolCall, slowAgent.cancelledOpt.Mode) + assert.Equal(t, []string{"non-cancel-msg", "preempt-msg"}, processedItems) } -func TestTurnLoop_PreemptiveNonCancellableAgent(t *testing.T) { - // A non-cancellable agent cannot be preempted, so the new Run processes - // events sequentially before calling Receive. The preemptive message is - // effectively queued and processed in the next turn. - nonCancellableAgent := &turnLoopMockAgent{ +func TestTurnLoop_WithCancel_Basic(t *testing.T) { + agent := &turnLoopCancellableAgent{ + name: "test-agent", + startedCh: make(chan struct{}), + cancelCh: make(chan struct{}), + } + + receiveCount := int32(0) + frontBlocked := make(chan struct{}) + source := &turnLoopFuncSource[string]{ + receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + cnt := atomic.AddInt32(&receiveCount, 1) + if cnt == 1 { + return ctx, "msg1", nil, nil + } + return ctx, "", nil, context.DeadlineExceeded + }, + front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + <-frontBlocked + return ctx, "", nil, context.DeadlineExceeded + }, + } + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: source, + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := iter.Next(); !ok { + break + } + } + return nil + }, + }) + require.NoError(t, err) + + ctx, cancel := loop.WithCancel(context.Background()) + done := make(chan error) + go func() { + done <- loop.Run(ctx) + }() + + <-agent.startedCh + e := cancel() + assert.NoError(t, e) + + err = <-done + assert.NoError(t, err) +} + +func TestTurnLoop_WithCancel_DuringAgentRun(t *testing.T) { + slowAgent := &turnLoopCancellableAgent{ + name: "slow-agent", + startedCh: make(chan struct{}), + cancelCh: make(chan struct{}), + } + + receiveCount := int32(0) + frontBlocked := make(chan struct{}) + source := &turnLoopFuncSource[string]{ + receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + cnt := atomic.AddInt32(&receiveCount, 1) + if cnt == 1 { + return ctx, "msg1", nil, nil + } + return ctx, "", nil, context.DeadlineExceeded + }, + front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + <-frontBlocked + return ctx, "", nil, context.DeadlineExceeded + }, + } + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: source, + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return slowAgent, nil + }, + OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := iter.Next(); !ok { + break + } + } + return nil + }, + }) + require.NoError(t, err) + + ctx, cancel := loop.WithCancel(context.Background()) + done := make(chan error) + go func() { + done <- loop.Run(ctx) + }() + + <-slowAgent.startedCh + cancel(WithCancelMode(CancelImmediate)) + + err = <-done + assert.NoError(t, err) + assert.True(t, atomic.LoadInt32(&slowAgent.cancelled) == 1, "agent should have been cancelled") +} + +func TestTurnLoop_WithCancel_NonCancellableAgent_ReturnsError(t *testing.T) { + agent := &turnLoopMockAgent{ name: "non-cancellable-agent", events: []*AgentEvent{{Output: &AgentOutput{}}}, } - fastAgent := &turnLoopMockAgent{ - name: "fast-agent", + + source := &turnLoopFuncSource[string]{ + receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + return ctx, "msg1", nil, nil + }, + } + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: source, + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + }) + require.NoError(t, err) + + ctx, _ := loop.WithCancel(context.Background()) + err = loop.Run(ctx) + + assert.ErrorIs(t, err, ErrAgentNotCancellableInTurnLoop) + assert.Contains(t, err.Error(), "non-cancellable-agent") +} + +func TestTurnLoop_WithCancel_MultipleCalls(t *testing.T) { + agent := &turnLoopMockAgent{ + name: "test-agent", events: []*AgentEvent{{Output: &AgentOutput{}}}, } - var processedItems []string - callCount := 0 - source := &turnLoopFuncSource[string]{fn: func(_ context.Context, _ time.Duration) (string, ConsumeOption, error) { - callCount++ - switch callCount { - case 1: - return "non-cancel-msg", NonPreemptiveConsumeOption, nil - case 2: - // Even though Mode is ConsumePreemptive, the agent doesn't - // implement Cancellable, so it's treated as non-preemptive. - return "preempt-msg", ConsumeOption{Mode: ConsumePreemptive}, nil + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + }) + require.NoError(t, err) + + _, cancel := loop.WithCancel(context.Background()) + + err1 := cancel() + err2 := cancel() + + assert.NoError(t, err1) + assert.ErrorIs(t, err2, ErrAgentFinished) +} + +func TestTurnLoop_WithCancel_IndependentCancels(t *testing.T) { + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: &turnLoopMockSource{items: []string{}, err: context.DeadlineExceeded}, + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return &turnLoopMockAgent{name: "a"}, nil + }, + }) + require.NoError(t, err) + + ctx1, cancel1 := loop.WithCancel(context.Background()) + ctx2, cancel2 := loop.WithCancel(context.Background()) + + cs1 := getTurnLoopCancelSig(ctx1) + cs2 := getTurnLoopCancelSig(ctx2) + + assert.NotNil(t, cs1) + assert.NotNil(t, cs2) + assert.NotEqual(t, cs1, cs2) + + cancel1() + assert.True(t, cs1.isCancelled()) + assert.False(t, cs2.isCancelled()) + + cancel2() + assert.True(t, cs2.isCancelled()) +} + +func TestTurnLoop_WithCancel_WithCancelOptions(t *testing.T) { + slowAgent := &turnLoopCancellableAgent{ + name: "slow-agent", + startedCh: make(chan struct{}), + cancelCh: make(chan struct{}), + } + + receiveCount := int32(0) + frontBlocked := make(chan struct{}) + source := &turnLoopFuncSource[string]{ + receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + cnt := atomic.AddInt32(&receiveCount, 1) + if cnt == 1 { + return ctx, "msg1", nil, nil + } + return ctx, "", nil, context.DeadlineExceeded + }, + front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + <-frontBlocked + return ctx, "", nil, context.DeadlineExceeded + }, + } + + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: source, + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return slowAgent, nil + }, + OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := iter.Next(); !ok { + break + } + } + return nil + }, + }) + require.NoError(t, err) + + ctx, cancel := loop.WithCancel(context.Background()) + done := make(chan error) + go func() { + done <- loop.Run(ctx) + }() + + <-slowAgent.startedCh + cancel(WithCancelMode(CancelAfterToolCall)) + + <-done + assert.Len(t, slowAgent.cancelledOpt, 1) +} + +type turnLoopInMemoryStore struct { + data map[string][]byte +} + +func newTurnLoopInMemoryStore() *turnLoopInMemoryStore { + return &turnLoopInMemoryStore{data: make(map[string][]byte)} +} + +func (s *turnLoopInMemoryStore) Get(_ context.Context, key string) ([]byte, bool, error) { + v, ok := s.data[key] + return v, ok, nil +} + +func (s *turnLoopInMemoryStore) Set(_ context.Context, key string, value []byte) error { + s.data[key] = value + return nil +} + +type turnLoopTestModel struct { + messages []*schema.Message + idx int +} + +func (m *turnLoopTestModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + if m.idx >= len(m.messages) { + return nil, fmt.Errorf("no more messages") + } + msg := m.messages[m.idx] + m.idx++ + return msg, nil +} + +func (m *turnLoopTestModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + panic("not implemented") +} + +func (m *turnLoopTestModel) WithTools(_ []*schema.ToolInfo) (model.ToolCallingChatModel, error) { + return m, nil +} + +type turnLoopSlowModel struct { + delay int64 + startedCh unsafe.Pointer + doneCh chan struct{} + message *schema.Message +} + +func (m *turnLoopSlowModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + if ch := (*chan struct{})(atomic.LoadPointer(&m.startedCh)); ch != nil { + select { + case *ch <- struct{}{}: default: - return "", NonPreemptiveConsumeOption, context.DeadlineExceeded } - }} + } + if delay := atomic.LoadInt64(&m.delay); delay > 0 { + time.Sleep(time.Duration(delay)) + } + if m.doneCh != nil { + select { + case m.doneCh <- struct{}{}: + default: + } + } + return m.message, nil +} + +func (m *turnLoopSlowModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + panic("not implemented") +} + +type turnLoopInterruptTool struct { + name string +} + +func (t *turnLoopInterruptTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "A tool that interrupts", + }, nil +} + +func (t *turnLoopInterruptTool) InvokableRun(ctx context.Context, _ string, _ ...tool.Option) (string, error) { + wasInterrupted, _, _ := tool.GetInterruptState[any](ctx) + if !wasInterrupted { + return "", tool.Interrupt(ctx, "need approval") + } + isResumeTarget, hasData, data := tool.GetResumeContext[string](ctx) + if isResumeTarget && hasData { + return data, nil + } + return "approved", nil +} + +func TestTurnLoop_ExternalCancel_WithStore(t *testing.T) { + store := newTurnLoopInMemoryStore() + const checkPointID = "external-cancel-test" + + modelStarted := make(chan struct{}, 1) + testModel := &turnLoopSlowModel{ + doneCh: make(chan struct{}, 1), + message: schema.AssistantMessage("task completed", nil), + } + atomic.StoreInt64(&testModel.delay, int64(5*time.Second)) + atomic.StorePointer(&testModel.startedCh, unsafe.Pointer(&modelStarted)) + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: "test-agent", + Description: "test agent for external cancel", + Model: testModel, + }) + require.NoError(t, err) + + receiveCount := int32(0) + turnCount := int32(0) + frontBlocked := make(chan struct{}) + source := &turnLoopFuncSource[string]{ + receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + cnt := atomic.AddInt32(&receiveCount, 1) + if cnt == 1 { + return ctx, "msg1", []ConsumeOption{WithConsumeCheckPointID(checkPointID)}, nil + } + return ctx, "", nil, ErrLoopExit + }, + front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + <-frontBlocked + return ctx, "", nil, context.DeadlineExceeded + }, + } loop, err := NewTurnLoop(TurnLoopConfig[string]{ Source: source, GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - processedItems = append(processedItems, item) return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil }, - GetAgent: func(_ context.Context, item string) (Agent, error) { - if item == "non-cancel-msg" { - return nonCancellableAgent, nil + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + atomic.AddInt32(&turnCount, 1) + for { + if _, ok := iter.Next(); !ok { + break + } } - return fastAgent, nil + return nil }, + Store: store, + }) + require.NoError(t, err) + + ctx, cancel := loop.WithCancel(context.Background()) + done := make(chan error) + go func() { + done <- loop.Run(ctx) + }() + + select { + case <-modelStarted: + case err := <-done: + t.Fatalf("loop.Run returned early with error: %v", err) + case <-time.After(3 * time.Second): + t.Fatalf("timeout waiting for model to start") + } + + err = cancel(WithCancelMode(CancelImmediate)) + + var runErr error + select { + case runErr = <-done: + case <-time.After(3 * time.Second): + t.Fatalf("timeout waiting for loop.Run to return after cancel") + } + + var interruptErr *TurnLoopInterruptError[string] + require.True(t, errors.As(runErr, &interruptErr), "expected TurnLoopInterruptError, got: %v", runErr) + assert.Equal(t, checkPointID, interruptErr.CheckpointID) + assert.Equal(t, "msg1", interruptErr.Item) + assert.True(t, len(store.data) > 0, "checkpoint should be stored") + + atomic.StorePointer(&testModel.startedCh, nil) + atomic.StoreInt64(&testModel.delay, 0) + + done = make(chan error) + go func() { + done <- loop.Run(context.Background(), WithTurnLoopResume(checkPointID, "msg1")) + }() + + select { + case runErr = <-done: + case <-time.After(5 * time.Second): + t.Fatalf("timeout waiting for loop.Run (resume) to return") + } + assert.NoError(t, runErr) + assert.Equal(t, int32(2), atomic.LoadInt32(&turnCount), "should have 2 turns total (1 from first run + 1 from resume)") +} + +func TestTurnLoop_InternalToolInterrupt_WithCheckpoint_ThenResume(t *testing.T) { + store := newTurnLoopInMemoryStore() + const checkPointID = "tool-interrupt-test" + + interruptTool := &turnLoopInterruptTool{name: "approval_tool"} + + testModel := &turnLoopTestModel{ + messages: []*schema.Message{ + schema.AssistantMessage("", []schema.ToolCall{ + {ID: "call1", Function: schema.FunctionCall{Name: "approval_tool", Arguments: "{}"}}, + }), + schema.AssistantMessage("task completed", nil), + }, + } + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: "test-agent", + Description: "test agent", + Model: testModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{interruptTool}, + }, + }, + }) + require.NoError(t, err) + + receiveCount := int32(0) + frontBlocked := make(chan struct{}) + source := &turnLoopFuncSource[string]{ + receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + cnt := atomic.AddInt32(&receiveCount, 1) + if cnt == 1 { + return ctx, "msg1", []ConsumeOption{WithConsumeCheckPointID(checkPointID)}, nil + } + return ctx, "", nil, ErrLoopExit + }, + front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { + <-frontBlocked + return ctx, "", nil, context.DeadlineExceeded + }, + } + + var interruptErr *TurnLoopInterruptError[string] + loop, err := NewTurnLoop(TurnLoopConfig[string]{ + Source: source, + GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { + return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil + }, + GetAgent: func(_ context.Context, _ string) (Agent, error) { + return agent, nil + }, + OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := iter.Next(); !ok { + break + } + } + return nil + }, + Store: store, }) require.NoError(t, err) err = loop.Run(context.Background()) - assert.ErrorIs(t, err, context.DeadlineExceeded) - assert.Equal(t, []string{"non-cancel-msg", "preempt-msg"}, processedItems) + require.True(t, errors.As(err, &interruptErr), "expected TurnLoopInterruptError, got: %v", err) + assert.Equal(t, checkPointID, interruptErr.CheckpointID) + assert.Equal(t, "msg1", interruptErr.Item) + assert.Len(t, interruptErr.InterruptContexts, 1) + assert.Equal(t, "need approval", interruptErr.InterruptContexts[0].Info) + + interruptID := interruptErr.InterruptContexts[0].ID + assert.NotEmpty(t, interruptID) + assert.True(t, len(store.data) > 0, "checkpoint should be stored") + + testModel.idx = 1 + receiveCount = 0 + + resumeCtx := compose.ResumeWithData(context.Background(), interruptID, "user approved") + err = loop.Run(resumeCtx, WithTurnLoopResume(checkPointID, "msg1")) + assert.NoError(t, err) } From c2c571d94fe3c65c9c0198ff236dd598f7678e17 Mon Sep 17 00:00:00 2001 From: Megumin Date: Tue, 24 Feb 2026 16:35:57 +0800 Subject: [PATCH 41/59] feat(adk): modify cancel func (#800) --- adk/cancel_test.go | 23 ++++++++++------------- adk/chatmodel.go | 11 +---------- adk/flow.go | 2 +- adk/interface.go | 5 +---- adk/turn_loop.go | 26 +++++++------------------- adk/turn_loop_test.go | 9 +++------ 6 files changed, 23 insertions(+), 53 deletions(-) diff --git a/adk/cancel_test.go b/adk/cancel_test.go index f45e344ff..702efd3c3 100644 --- a/adk/cancel_test.go +++ b/adk/cancel_test.go @@ -207,7 +207,7 @@ func TestRunWithCancel_WithTools(t *testing.T) { time.Sleep(100 * time.Millisecond) - err = cancelFn(ctx) + err = cancelFn() assert.NoError(t, err) start := time.Now() @@ -276,7 +276,7 @@ func TestRunWithCancel_WithTools(t *testing.T) { time.Sleep(100 * time.Millisecond) - err = cancelFn(ctx, WithCancelMode(CancelAfterChatModel)) + err = cancelFn(WithCancelMode(CancelAfterChatModel)) assert.NoError(t, err) var events []*AgentEvent @@ -342,7 +342,7 @@ func TestRunWithCancel_WithTools(t *testing.T) { time.Sleep(100 * time.Millisecond) - err = cancelFn(ctx, WithCancelMode(CancelAfterToolCall)) + err = cancelFn(WithCancelMode(CancelAfterToolCall)) assert.NoError(t, err) var events []*AgentEvent @@ -452,7 +452,7 @@ func TestRunWithCancel_WithCheckpoint(t *testing.T) { <-modelStarted - err = cancelFn(ctx) + err = cancelFn() assert.NoError(t, err) var events []*AgentEvent @@ -518,12 +518,9 @@ func TestCancelFuncMultipleCalls(t *testing.T) { <-modelStarted - cancelErr := cancelFn(ctx) + cancelErr := cancelFn() assert.NoError(t, cancelErr) - cancelErr = cancelFn(ctx) - assert.ErrorIs(t, cancelErr, ErrAgentFinished) - for { _, ok := iter.Next() if !ok { @@ -654,7 +651,7 @@ func TestRunWithCancel_Streaming(t *testing.T) { time.Sleep(100 * time.Millisecond) - cancelErr := cancelFn(ctx) + cancelErr := cancelFn() assert.NoError(t, cancelErr) start := time.Now() @@ -726,7 +723,7 @@ func TestRunWithCancel_Streaming(t *testing.T) { time.Sleep(100 * time.Millisecond) - cancelErr := cancelFn(ctx, WithCancelMode(CancelAfterToolCall)) + cancelErr := cancelFn(WithCancelMode(CancelAfterToolCall)) assert.NoError(t, cancelErr) var events []*AgentEvent @@ -808,7 +805,7 @@ func TestResumeWithCancel(t *testing.T) { <-modelStarted atomic.AddInt32(&modelCallCount, 1) - cancelErr := cancelFn(ctx) + cancelErr := cancelFn() assert.NoError(t, cancelErr) var events []*AgentEvent @@ -931,7 +928,7 @@ func TestResumeWithCancel(t *testing.T) { <-firstModelStarted atomic.AddInt32(&modelCallCount, 1) - cancelErr := cancelFn(ctx) + cancelErr := cancelFn() assert.NoError(t, cancelErr) for { @@ -1001,7 +998,7 @@ func TestResumeWithCancel(t *testing.T) { time.Sleep(100 * time.Millisecond) - err = resumeCancelFn(ctx) + err = resumeCancelFn() assert.NoError(t, err) start := time.Now() diff --git a/adk/chatmodel.go b/adk/chatmodel.go index effb270bd..10faa7434 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -1047,25 +1047,16 @@ func (a *ChatModelAgent) runInternal(ctx context.Context, input *AgentInput, wit func buildCancelFunc(cs *cancelSig) CancelFunc { var once sync.Once - - return func(_ context.Context, opts ...CancelOption) error { + return func(opts ...CancelOption) error { cfg := &cancelConfig{ Mode: CancelImmediate, } for _, opt := range opts { opt(cfg) } - - cancelled := false once.Do(func() { cs.cancel(cfg) - cancelled = true }) - - if !cancelled { - return ErrAgentFinished - } - return nil } } diff --git a/adk/flow.go b/adk/flow.go index 93143c641..7acfd2137 100644 --- a/adk/flow.go +++ b/adk/flow.go @@ -387,7 +387,7 @@ func (a *flowAgent) runInternal(ctx context.Context, input *AgentInput, withCanc return iterator, cancelFn } -func notCancellableFuncInternal(_ context.Context, _ ...CancelOption) error { +func notCancellableFuncInternal(_ ...CancelOption) error { return ErrAgentNotCancellable } diff --git a/adk/interface.go b/adk/interface.go index 1d2f5f070..c73705ae3 100644 --- a/adk/interface.go +++ b/adk/interface.go @@ -288,9 +288,6 @@ const ( CancelAfterToolCall ) -// ErrAgentFinished is returned by Cancel when the agent has already finished execution. -var ErrAgentFinished = errors.New("agent has already finished execution") - // ErrAgentNotCancellable is returned by Cancel when the agent does not support cancellation. var ErrAgentNotCancellable = errors.New("agent does not implement CancellableAgent interface") @@ -315,7 +312,7 @@ func WithCancelTimeout(timeout time.Duration) CancelOption { } } -type CancelFunc func(context.Context, ...CancelOption) error +type CancelFunc func(...CancelOption) error type CancellableAgent interface { Agent diff --git a/adk/turn_loop.go b/adk/turn_loop.go index af69ae5d3..94b39fe6c 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -192,11 +192,6 @@ func getTurnLoopCancelSig(ctx context.Context) *turnLoopCancelSig { return nil } -// TurnLoopCancelFunc is the cancel function returned by WithCancel. -// Unlike Agent's CancelFunc, it does not require a context parameter -// since the context is already bound when WithCancel is called. -type TurnLoopCancelFunc func(opts ...CancelOption) error - // ErrAgentNotCancellableInTurnLoop is returned when WithCancel context is used // but the Agent does not implement CancellableAgent. var ErrAgentNotCancellableInTurnLoop = errors.New("agent does not support cancel but WithCancel context was provided") @@ -258,7 +253,7 @@ var ErrLoopExit = errors.New("loop exit") // }() // // Later, to cancel: // cancel(adk.WithCancelMode(adk.CancelAfterToolCall)) -func (l *TurnLoop[T]) WithCancel(ctx context.Context) (context.Context, TurnLoopCancelFunc) { +func (l *TurnLoop[T]) WithCancel(ctx context.Context) (context.Context, CancelFunc) { cs := newTurnLoopCancelSig() ctx = withTurnLoopCancelSig(ctx, cs) @@ -270,16 +265,9 @@ func (l *TurnLoop[T]) WithCancel(ctx context.Context) (context.Context, TurnLoop for _, opt := range opts { opt(cfg) } - - cancelled := false once.Do(func() { cs.cancel(cfg) - cancelled = true }) - - if !cancelled { - return ErrAgentFinished - } return nil } @@ -453,8 +441,8 @@ func (l *TurnLoop[T]) Run(ctx context.Context, opts ...TurnLoopRunOption[T]) err }(): externalCancelled = true cfg := cs.getConfig() - err := cancelFunc(nCtx, cancelConfigToOpts(cfg)...) - if err != nil && !errors.Is(err, ErrAgentFinished) { + err := cancelFunc(cancelConfigToOpts(cfg)...) + if err != nil { <-done return fmt.Errorf("failed to cancel agent: %w", err) } @@ -476,7 +464,7 @@ func (l *TurnLoop[T]) Run(ctx context.Context, opts ...TurnLoopRunOption[T]) err o := applyConsumeOptions(option) switch o.Mode { case ConsumePreemptive: - err := cancelFunc(nCtx, o.CancelOpts...) + err := cancelFunc(o.CancelOpts...) if err != nil { <-done return fmt.Errorf("failed to cancel agent: %w", err) @@ -485,7 +473,7 @@ func (l *TurnLoop[T]) Run(ctx context.Context, opts ...TurnLoopRunOption[T]) err select { case <-done: case <-time.After(o.Timeout): - err := cancelFunc(nCtx, o.CancelOpts...) + err := cancelFunc(o.CancelOpts...) if err != nil { <-done return fmt.Errorf("failed to cancel agent: %w", err) @@ -497,8 +485,8 @@ func (l *TurnLoop[T]) Run(ctx context.Context, opts ...TurnLoopRunOption[T]) err return nil }(): cfg := cs.getConfig() - err := cancelFunc(nCtx, cancelConfigToOpts(cfg)...) - if err != nil && !errors.Is(err, ErrAgentFinished) { + err := cancelFunc(cancelConfigToOpts(cfg)...) + if err != nil { <-done return fmt.Errorf("failed to cancel agent: %w", err) } diff --git a/adk/turn_loop_test.go b/adk/turn_loop_test.go index 7c22f3699..88e3679b3 100644 --- a/adk/turn_loop_test.go +++ b/adk/turn_loop_test.go @@ -112,7 +112,7 @@ func (a *turnLoopCancellableAgent) RunWithCancel(_ context.Context, _ *AgentInpu defer gen.Close() <-a.cancelCh }() - cancelFunc := func(_ context.Context, opts ...CancelOption) error { + cancelFunc := func(opts ...CancelOption) error { atomic.StoreInt32(&a.cancelled, 1) a.cancelledOpt = opts close(a.cancelCh) @@ -747,11 +747,8 @@ func TestTurnLoop_WithCancel_MultipleCalls(t *testing.T) { _, cancel := loop.WithCancel(context.Background()) - err1 := cancel() - err2 := cancel() - - assert.NoError(t, err1) - assert.ErrorIs(t, err2, ErrAgentFinished) + err = cancel() + assert.NoError(t, err) } func TestTurnLoop_WithCancel_IndependentCancels(t *testing.T) { From 4c55952901eaf49f1962a85b95f7da15eefc079f Mon Sep 17 00:00:00 2001 From: Megumin Date: Tue, 24 Feb 2026 21:07:31 +0800 Subject: [PATCH 42/59] fix(adk): select cancel after front without preemptive (#802) --- adk/turn_loop.go | 65 ++++++++++++++++++++++++++---------------------- 1 file changed, 35 insertions(+), 30 deletions(-) diff --git a/adk/turn_loop.go b/adk/turn_loop.go index 94b39fe6c..f7ec423c5 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -179,6 +179,13 @@ func (cs *turnLoopCancelSig) getConfig() *cancelConfig { return nil } +func (cs *turnLoopCancelSig) getDoneChan() <-chan struct{} { + if cs != nil { + return cs.done + } + return nil +} + type turnLoopCancelSigKey struct{} func withTurnLoopCancelSig(ctx context.Context, cs *turnLoopCancelSig) context.Context { @@ -395,8 +402,8 @@ func (l *TurnLoop[T]) Run(ctx context.Context, opts ...TurnLoopRunOption[T]) err return l.handleEvents(ctx, item, iter, checkPointID) } - var handleEventErr error if cancelFunc != nil { + var handleEventErr error done := make(chan struct{}) go func() { @@ -429,27 +436,14 @@ func (l *TurnLoop[T]) Run(ctx context.Context, opts ...TurnLoopRunOption[T]) err }) }() - var externalCancelled bool select { case <-frontDone: case <-done: - case <-func() <-chan struct{} { - if cs != nil { - return cs.done - } - return nil - }(): - externalCancelled = true - cfg := cs.getConfig() - err := cancelFunc(cancelConfigToOpts(cfg)...) + case <-cs.getDoneChan(): + err := cancelAndWait(cancelFunc, cs, done) if err != nil { - <-done - return fmt.Errorf("failed to cancel agent: %w", err) + return err } - } - - if externalCancelled { - <-done return l.wrapHandleEventErr(handleEventErr) } @@ -478,29 +472,29 @@ func (l *TurnLoop[T]) Run(ctx context.Context, opts ...TurnLoopRunOption[T]) err <-done return fmt.Errorf("failed to cancel agent: %w", err) } - case <-func() <-chan struct{} { - if cs != nil { - return cs.done - } - return nil - }(): - cfg := cs.getConfig() - err := cancelFunc(cancelConfigToOpts(cfg)...) + case <-cs.getDoneChan(): + err := cancelAndWait(cancelFunc, cs, done) if err != nil { - <-done - return fmt.Errorf("failed to cancel agent: %w", err) + return err } - <-done return l.wrapHandleEventErr(handleEventErr) } } - <-done + select { + case <-done: + case <-cs.getDoneChan(): + err := cancelAndWait(cancelFunc, cs, done) + if err != nil { + return err + } + return l.wrapHandleEventErr(handleEventErr) + } if err := l.wrapHandleEventErr(handleEventErr); err != nil { return err } } else { - if handleEventErr = handleEvents(); handleEventErr != nil { + if handleEventErr := handleEvents(); handleEventErr != nil { if err := l.wrapHandleEventErr(handleEventErr); err != nil { return err } @@ -545,6 +539,17 @@ func (l *TurnLoop[T]) handleEvents(ctx context.Context, item T, iter *AsyncItera return nil } +func cancelAndWait(cf CancelFunc, cs *turnLoopCancelSig, done chan struct{}) error { + cfg := cs.getConfig() + err := cf(cancelConfigToOpts(cfg)...) + if err != nil { + <-done + return fmt.Errorf("failed to cancel agent: %w", err) + } + <-done + return nil +} + func cancelConfigToOpts(cfg *cancelConfig) []CancelOption { if cfg == nil { return nil From 238722424391d70bb7337fd365516a30b4669df0 Mon Sep 17 00:00:00 2001 From: IPender Date: Mon, 2 Mar 2026 21:43:51 +0800 Subject: [PATCH 43/59] fix: implement IsCallbacksEnabled and GetType for cancelableChatModel (#826) --- adk/cancel_wrapper.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/adk/cancel_wrapper.go b/adk/cancel_wrapper.go index 9c00df77f..396d6a458 100644 --- a/adk/cancel_wrapper.go +++ b/adk/cancel_wrapper.go @@ -19,11 +19,14 @@ package adk import ( "context" "io" + "reflect" "runtime/debug" "time" + "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/internal/generic" "github.com/cloudwego/eino/internal/safe" "github.com/cloudwego/eino/schema" ) @@ -118,6 +121,18 @@ func (c *cancelableChatModel) Stream(ctx context.Context, input []*schema.Messag return res.result, res.err } +func (c *cancelableChatModel) IsCallbacksEnabled() bool { + return components.IsCallbacksEnabled(c.inner) +} + +func (c *cancelableChatModel) GetType() string { + if name, ok := components.GetType(c.inner); ok { + return name + } + + return generic.ParseTypeName(reflect.ValueOf(c.inner)) +} + func cancelableToolInvokable(cs *cancelSig, endpoint compose.InvokableToolEndpoint) compose.InvokableToolEndpoint { return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { if cfg := checkCancelSig(cs); cfg != nil && cfg.Mode == CancelImmediate { From 87eb74fc5bedd4521a67cdb819f345ba411f93fe Mon Sep 17 00:00:00 2001 From: "shentong.martin" Date: Mon, 23 Mar 2026 10:39:45 +0800 Subject: [PATCH 44/59] fix: rebase error Change-Id: If20fa78dba82a1c177c8ec47090050ea8c1354ed --- adk/chatmodel.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/adk/chatmodel.go b/adk/chatmodel.go index 10faa7434..0848a859a 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -1119,7 +1119,7 @@ func (a *ChatModelAgent) resumeInternal(ctx context.Context, info *ResumeInfo, w generator.Send(&AgentEvent{Err: err}) generator.Close() }() - return iterator, nil + return iterator, notCancellableFuncInternal } var historyModifier func(ctx context.Context, history []Message) []Message From 989db7777a40db31839c326c8f06eb6f703c9527 Mon Sep 17 00:00:00 2001 From: shentongmartin Date: Thu, 26 Mar 2026 19:44:26 +0800 Subject: [PATCH 45/59] refactor(adk): replace TurnLoop with push-based API (#835) --- .gitignore | 5 +- adk/agent_tool.go | 21 +- adk/call_option.go | 28 + adk/cancel.go | 885 ++++ adk/cancel_edge_test.go | 1268 +++++ adk/cancel_multicall_test.go | 125 + adk/cancel_test.go | 2885 ++++++++++- adk/cancel_wrapper.go | 295 -- adk/chatmodel.go | 373 +- adk/chatmodel_retry_test.go | 145 + adk/flow.go | 114 +- adk/handler.go | 20 + adk/handler_test.go | 318 ++ adk/interface.go | 54 - adk/interrupt.go | 39 +- .../patchtoolcalls/patchtoolcalls.go | 2 +- adk/prebuilt/planexecute/plan_execute_test.go | 232 + adk/react.go | 226 +- adk/react_test.go | 37 +- adk/retry_chatmodel.go | 9 + adk/runctx.go | 12 +- adk/runctx_test.go | 209 + adk/runner.go | 127 +- adk/turn_loop.go | 1657 +++++-- adk/turn_loop_test.go | 4202 +++++++++++++---- adk/utils.go | 37 +- adk/workflow.go | 88 +- adk/wrappers.go | 38 +- compose/checkpoint_test.go | 28 + compose/graph_manager.go | 9 +- compose/graph_run.go | 82 +- compose/interrupt.go | 4 + examples | 1 + ext | 1 + internal/channel.go | 45 + internal/channel_test.go | 241 + internal/core/address.go | 25 +- internal/core/interrupt.go | 7 + schema/serialization.go | 2 +- 39 files changed, 11602 insertions(+), 2294 deletions(-) create mode 100644 adk/cancel.go create mode 100644 adk/cancel_edge_test.go create mode 100644 adk/cancel_multicall_test.go delete mode 100644 adk/cancel_wrapper.go create mode 160000 examples create mode 160000 ext diff --git a/.gitignore b/.gitignore index 8ac1d568d..04542d49a 100644 --- a/.gitignore +++ b/.gitignore @@ -50,8 +50,11 @@ reports/ /todos .DS_Store -*.log +*.log* +.claude CLAUDE.md +*.jsonl +*.txt # Specs directories */specs diff --git a/adk/agent_tool.go b/adk/agent_tool.go index 9472dab1f..fde319cb4 100644 --- a/adk/agent_tool.go +++ b/adk/agent_tool.go @@ -167,16 +167,19 @@ func (at *agentTool) InvokableRun(ctx context.Context, argumentsInJSON string, o } iter = newInvokableAgentToolRunner(at.agent, ms, enableStreaming).Run(ctx, input, - append(getOptionsByAgentName(at.agent.Name(ctx), opts), WithCheckPointID(bridgeCheckpointID), withSharedParentSession())...) + append(extractAndDeriveCancelCtx(ctx, at.agent.Name(ctx), opts), WithCheckPointID(bridgeCheckpointID), withSharedParentSession())...) } else { if !hasState { return "", fmt.Errorf("agent tool '%s' interrupt has happened, but cannot find interrupt state", at.agent.Name(ctx)) } - ms = newResumeBridgeStore(state) + ms = newResumeBridgeStore(bridgeCheckpointID, state) + + agentOpts := extractAndDeriveCancelCtx(ctx, at.agent.Name(ctx), opts) + agentOpts = append(agentOpts, withSharedParentSession()) iter, err = newInvokableAgentToolRunner(at.agent, ms, enableStreaming). - Resume(ctx, bridgeCheckpointID, append(getOptionsByAgentName(at.agent.Name(ctx), opts), withSharedParentSession())...) + Resume(ctx, bridgeCheckpointID, agentOpts...) if err != nil { return "", err } @@ -281,6 +284,18 @@ func getOptionsByAgentName(agentName string, opts []tool.Option) []AgentRunOptio return ret } +func extractAndDeriveCancelCtx(ctx context.Context, agentName string, opts []tool.Option) []AgentRunOption { + agentOpts := getOptionsByAgentName(agentName, opts) + baseOpts := getCommonOptions(nil, agentOpts...) + if baseOpts.cancelCtx != nil { + childCtx := baseOpts.cancelCtx.deriveChild(ctx) + agentOpts = append(agentOpts, WrapImplSpecificOptFn(func(o *options) { + o.cancelCtx = childCtx + })) + } + return agentOpts +} + func getEmitGeneratorAndEnableStreaming(opts []tool.Option) (*AsyncGenerator[*AgentEvent], bool) { o := tool.GetImplSpecificOptions[agentToolOptions](nil, opts...) if o == nil { diff --git a/adk/call_option.go b/adk/call_option.go index 55e57fd32..ead6ae636 100644 --- a/adk/call_option.go +++ b/adk/call_option.go @@ -24,6 +24,7 @@ type options struct { checkPointID *string skipTransferMessages bool handlers []callbacks.Handler + cancelCtx *cancelContext } // AgentRunOption is the call option for adk Agent. @@ -157,6 +158,33 @@ func filterCallbackHandlersForNestedAgents(currentAgentName string, opts []Agent return filteredOpts } +// filterCancelOption removes any AgentRunOption that sets a cancelCtx on *options. +// This prevents inner (nested) agents from receiving the cancel option when the +// outer flowAgent owns the cancel lifecycle. Inner agents access the cancelContext +// via the Go context (getCancelContext) instead. +func filterCancelOption(opts []AgentRunOption) []AgentRunOption { + if len(opts) == 0 { + return nil + } + var filteredOpts []AgentRunOption + for i := range opts { + opt := opts[i] + if opt.implSpecificOptFn == nil { + filteredOpts = append(filteredOpts, opt) + continue + } + if _, isCommonOpt := opt.implSpecificOptFn.(func(*options)); isCommonOpt { + testOpt := &options{} + opt.implSpecificOptFn.(func(*options))(testOpt) + if testOpt.cancelCtx != nil { + continue + } + } + filteredOpts = append(filteredOpts, opt) + } + return filteredOpts +} + func filterOptions(agentName string, opts []AgentRunOption) []AgentRunOption { if len(opts) == 0 { return nil diff --git a/adk/cancel.go b/adk/cancel.go new file mode 100644 index 000000000..20d72bb20 --- /dev/null +++ b/adk/cancel.go @@ -0,0 +1,885 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "fmt" + "io" + "sync" + "sync/atomic" + "time" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +func init() { + schema.RegisterName[*CancelError]("_eino_adk_cancel_error") + schema.RegisterName[*AgentCancelInfo]("_eino_adk_agent_cancel_info") + schema.RegisterName[*StreamCanceledError]("_eino_adk_stream_cancelled_error") +} + +// CancelMode specifies when an agent should be canceled. +// Modes can be combined with bitwise OR to cancel at multiple safe-points. +// For example, CancelAfterChatModel | CancelAfterToolCalls cancels the agent +// after whichever safe-point is reached first. +type CancelMode int + +const ( + // CancelImmediate cancels the agent as soon as the signal is received, + // without waiting for a ChatModel or ToolCalls safe-point. Propagates + // to all descendant agents via the cancel context hierarchy, including + // agents nested inside AgentTools and workflow sub-agents. + CancelImmediate CancelMode = 0 + // CancelAfterChatModel cancels after the first chat model call that completes + // anywhere in the agent hierarchy, including nested sub-agents, agent tools, + // and workflow branches. The cancel mode propagates to all descendant agents; + // whichever ChatModel finishes first triggers the cancel. The interrupting + // agent emits an interrupt that bubbles up through the agent tree — parent + // agents do not need to reach their own ChatModel safe-point. + CancelAfterChatModel CancelMode = 1 << iota + // CancelAfterToolCalls cancels after the first set of concurrent tool calls + // that completes anywhere in the agent hierarchy. Like CancelAfterChatModel, + // this mode propagates to all descendants and fires at whichever level + // reaches the safe-point first. + CancelAfterToolCalls +) + +// CancelHandle represents a cancel operation that can be waited on. +type CancelHandle struct { + wait func() error +} + +func (h *CancelHandle) Wait() error { + return h.wait() +} + +// AgentCancelFunc is called to request cancellation of a running agent. +// It returns after the cancel request is committed; use the returned handle's +// Wait to block for completion and outcome. +// +// The returned bool reports whether this call contributed to the CancelError +// for the current execution. "Contributed" means this call's cancel options +// were included before cancellation was finalized. It is false when cancellation +// was already finalized (handled or execution completed). +type AgentCancelFunc func(...AgentCancelOption) (*CancelHandle, bool) + +type agentCancelConfig struct { + Mode CancelMode + Timeout *time.Duration +} + +// AgentCancelOption configures cancel behavior. +type AgentCancelOption func(*agentCancelConfig) + +// WithAgentCancelMode sets the cancel mode for the agent cancel operation. +func WithAgentCancelMode(mode CancelMode) AgentCancelOption { + return func(config *agentCancelConfig) { + config.Mode = mode + } +} + +// WithAgentCancelTimeout sets a timeout for the cancel operation. +// This only applies to safe-point modes (CancelAfterChatModel, CancelAfterToolCalls): +// if the safe-point hasn't fired within this duration, the cancel escalates to +// an immediate graph interrupt. +// For CancelImmediate this timeout is ignored — the graph interrupt fires +// immediately with timeout=0. +func WithAgentCancelTimeout(timeout time.Duration) AgentCancelOption { + return func(config *agentCancelConfig) { + config.Timeout = &timeout + } +} + +// AgentCancelInfo contains information about a cancel operation. +type AgentCancelInfo struct { + Mode CancelMode + Escalated bool + Timeout bool +} + +// CancelError is sent via AgentEvent.Err when an agent is canceled. +// Use errors.As to match and extract *CancelError from event errors. +// +// Interrupt absorption: when a cancel is active (shouldCancel() == true), ANY +// interrupt — whether from a cancel safe-point node or from business logic +// (e.g. compose.Interrupt in a tool) — is converted to a CancelError. The +// cancel "absorbs" the business interrupt. This is intentional: +// +// - In concurrent execution (parallel workflows, concurrent tool calls), +// cancel-induced and business interrupts can arrive as a single composite +// signal that cannot be split apart. +// - Even in sequential execution, treating business interrupts as CancelError +// during active cancel gives consistent semantics. +// - The business interrupt is NOT lost — the checkpoint preserves the full +// interrupt hierarchy. On resume (Runner.Resume), the agent re-executes +// the interrupting code path and the business interrupt re-fires naturally. +type CancelError struct { + Info *AgentCancelInfo + + // CheckPointID is the checkpoint ID associated with this cancel operation. + // When non-empty, the cancelled agent's state has been persisted under this ID + // and can be resumed via Runner.Resume or GenInputResult.ResumeFromCheckpointID. + CheckPointID string + + // InterruptContexts provides the interrupt contexts needed for targeted + // resumption via Runner.ResumeWithParams. Each context represents a step + // in the agent hierarchy that was interrupted. This is a slice because + // composite agents (e.g. parallel workflows) may interrupt at multiple + // points simultaneously, matching the shape of AgentAction.Interrupted.InterruptContexts. + // Use each InterruptCtx.ID as a key in ResumeParams.Targets. + InterruptContexts []*InterruptCtx + + interruptSignal *InterruptSignal // unexported — only Runner needs it for checkpoint +} + +func (e *CancelError) Error() string { + return fmt.Sprintf("agent canceled: mode=%v, escalated=%v", e.Info.Mode, e.Info.Escalated) +} + +// Sentinel errors for cancel outcomes. +var ( + // ErrCancelTimeout is returned by CancelHandle.Wait when the cancel operation timed out. + ErrCancelTimeout = errors.New("cancel timed out") + + // ErrExecutionCompleted is returned by CancelHandle.Wait when the agent finished + // before the cancel took effect. "Finished" means the event stream was fully + // drained without any interrupt — normal completion or a fatal error. + // + // Note: business interrupts that occur while cancel is active are absorbed + // into CancelError (see CancelError doc), so they result in nil (cancel + // succeeded), NOT ErrExecutionCompleted. Only execution that completes with + // no interrupt at all produces this error. + ErrExecutionCompleted = errors.New("execution already completed") + + // ErrStreamCanceled is the error sent through the stream when CancelImmediate aborts it. + // It is a *StreamCanceledError so it can be gob-serialized during checkpoint save + // (when stored as agentEventWrapper.StreamErr). + ErrStreamCanceled error = &StreamCanceledError{} +) + +// StreamCanceledError is the concrete error type for ErrStreamCanceled. +// It is exported so that gob can serialize it during checkpoint save when the error +// is stored in agentEventWrapper.StreamErr. +type StreamCanceledError struct{} + +func (e *StreamCanceledError) Error() string { + return "stream canceled" +} + +// WithCancel creates an AgentRunOption that enables cancellation for an agent run. +// It returns the option to pass to Run/Resume and a cancel function. +// Cancel options (mode, timeout) are passed to the returned AgentCancelFunc at call time. +func WithCancel() (AgentRunOption, AgentCancelFunc) { + cc := newCancelContext() + opt := WrapImplSpecificOptFn(func(o *options) { + o.cancelCtx = cc + }) + cancelFn := cc.buildCancelFunc() + return opt, cancelFn +} + +// cancelContext state constants (for int32 CAS). +// +// State transition rules: +// +// stateRunning -> stateCancelling (cancel requested by AgentCancelFunc) +// stateRunning -> stateDone (execution finished without interrupt) +// stateCancelling -> stateCancelHandled (ANY interrupt absorbed as CancelError) +// stateCancelling -> stateDone (execution finished without interrupt while cancel pending) +// +// Terminal states: stateDone, stateCancelHandled. +// +// Note: We intentionally do NOT distinguish between "completed" and "errored" +// terminal states. End-users get the actual outcome from AgentEvent. +// This simplification keeps the state machine minimal — only the cancel/non-cancel +// distinction matters for the AgentCancelFunc return value. +// +// Business interrupt handling: when cancel is active (stateCancelling) and any +// interrupt arrives — cancel-induced OR business — wrapIterWithCancelCtx absorbs +// it as a CancelError and transitions to stateCancelHandled. The business interrupt +// data is preserved in the checkpoint for re-emission on resume. +const ( + // stateRunning is the initial state: agent is executing normally. + stateRunning int32 = 0 + // stateCancelling means AgentCancelFunc has been called and cancelChan is + // closed, but the cancel has not yet been handled by the runFunc. + stateCancelling int32 = 1 + // stateDone means execution has finished through any non-cancel path: + // normal completion, business interrupt, or error. The specific outcome + // is conveyed through AgentEvent, not through the cancel state machine. + stateDone int32 = 2 + // stateCancelHandled means the cancel was processed by the runFunc and a + // CancelError was emitted through the event stream. This is the success + // terminal state for cancellation. + stateCancelHandled int32 = 5 +) + +// interruptSent constants (for int32 CAS). +// +// Transition rules: +// +// interruptNotSent -> interruptImmediate (CancelImmediate or escalation) +const ( + // interruptNotSent means no compose graph interrupt has been sent. + interruptNotSent int32 = 0 + // interruptImmediate means an immediate graph interrupt was sent with + // timeout=0, forcing the graph to stop as soon as possible. + interruptImmediate int32 = 1 +) + +// defaultCancelImmediateGracePeriod is the time a parent's graph interrupt +// waits when the cancelContext has active children (via deriveChild). This +// gives child agents time to propagate their interrupt signal back through +// the agentTool as a CompositeInterrupt. If this proves insufficient for +// deeply nested structures or too slow for latency-sensitive use cases, +// consider making it configurable via an AgentCancelOption. +const defaultCancelImmediateGracePeriod = 1 * time.Second + +type cancelContextKey struct{} + +// withCancelContext stores a cancelContext in the Go context. +func withCancelContext(ctx context.Context, cc *cancelContext) context.Context { + if cc == nil { + return ctx + } + return context.WithValue(ctx, cancelContextKey{}, cc) +} + +// getCancelContext retrieves the cancelContext from the Go context, or nil. +func getCancelContext(ctx context.Context) *cancelContext { + if v := ctx.Value(cancelContextKey{}); v != nil { + return v.(*cancelContext) + } + return nil +} + +type cancelContext struct { + mode int32 // atomic, CancelMode + + cancelChan chan struct{} // closed when cancel is requested (all modes, not just safe-point) + immediateChan chan struct{} // closed when an immediate graph interrupt fires + doneChan chan struct{} // closed when execution completes (by any mark* method) + doneOnce sync.Once // ensures doneChan is closed exactly once + + state int32 // stateRunning, stateCancelling, stateDone, stateCancelHandled + interruptSent int32 // interruptNotSent, interruptImmediate + escalated int32 // 1 if escalated from safe-point to immediate + timeoutEscalated int32 // 1 if escalation was triggered by timeout + startedMode int32 // atomic, mode when state transitioned to cancelling + deadlineUnixNano int64 // atomic, 0 means no deadline + + root bool // true for the original cancelContext created by WithCancel(); false for derived children + parent *cancelContext // non-nil for derived children; used to decrement parent's activeChildren on markDone + + activeChildren int32 // atomic; number of derived children that haven't called markDone() yet + decrementedParent int32 // atomic CAS guard; ensures parent.activeChildren is decremented at most once + + cancelMu sync.Mutex + timeoutOnce sync.Once + timeoutNotify chan struct{} + + mu sync.Mutex + graphInterruptFuncs []func(...compose.GraphInterruptOption) +} + +func newCancelContext() *cancelContext { + return &cancelContext{ + cancelChan: make(chan struct{}), + immediateChan: make(chan struct{}), + doneChan: make(chan struct{}), + timeoutNotify: make(chan struct{}, 1), + root: true, + } +} + +func (cc *cancelContext) isRoot() bool { + return cc != nil && cc.root +} + +// deriveChild creates a child cancelContext that receives cancel propagation +// from the parent. The caller MUST ensure the child's markDone() is eventually +// called (e.g., via wrapIterWithCancelCtx's defer) or that ctx is canceled; +// otherwise the two propagation goroutines will leak. +func (cc *cancelContext) deriveChild(ctx context.Context) *cancelContext { + if cc == nil { + return nil + } + child := newCancelContext() + child.root = false + child.parent = cc + atomic.AddInt32(&cc.activeChildren, 1) + + go func() { + select { + case <-cc.cancelChan: + child.triggerCancel(cc.getMode()) + case <-child.doneChan: + case <-ctx.Done(): + } + }() + + go func() { + select { + case <-cc.immediateChan: + child.triggerImmediateCancel() + case <-child.doneChan: + case <-ctx.Done(): + } + }() + + return child +} + +func (cc *cancelContext) triggerCancel(mode CancelMode) { + cc.setMode(mode) + if atomic.CompareAndSwapInt32(&cc.state, stateRunning, stateCancelling) { + close(cc.cancelChan) + } +} + +func (cc *cancelContext) triggerImmediateCancel() { + atomic.StoreInt32(&cc.escalated, 1) + cc.setMode(CancelImmediate) + if atomic.CompareAndSwapInt32(&cc.state, stateRunning, stateCancelling) { + close(cc.cancelChan) + } + cc.sendImmediateInterrupt() +} + +func (cc *cancelContext) getMode() CancelMode { + if cc == nil { + return CancelImmediate + } + return CancelMode(atomic.LoadInt32(&cc.mode)) +} + +func (cc *cancelContext) setMode(mode CancelMode) { + atomic.StoreInt32(&cc.mode, int32(mode)) +} + +func (cc *cancelContext) getDeadlineUnixNano() int64 { + return atomic.LoadInt64(&cc.deadlineUnixNano) +} + +func (cc *cancelContext) setDeadlineUnixNano(v int64) { + atomic.StoreInt64(&cc.deadlineUnixNano, v) +} + +func (cc *cancelContext) wakeTimeoutController() { + select { + case cc.timeoutNotify <- struct{}{}: + default: + } +} + +// shouldCancel returns true if a cancel has been requested (cancelChan is closed). +func (cc *cancelContext) shouldCancel() bool { + if cc == nil { + return false + } + select { + case <-cc.cancelChan: + return true + default: + return false + } +} + +// sendImmediateInterrupt sends the compose graph interrupt signal via graphInterruptFuncs. +// Also closes immediateChan (used by cancelMonitoredModel to abort an in-progress stream). +// Returns false if an interrupt was already sent or if no graphInterruptFuncs have been +// registered yet (the deferred fire in setGraphInterruptFunc will handle that case). +func (cc *cancelContext) sendImmediateInterrupt() bool { + cc.mu.Lock() + + if !atomic.CompareAndSwapInt32(&cc.interruptSent, interruptNotSent, interruptImmediate) { + cc.mu.Unlock() + return false + } + + close(cc.immediateChan) + + fns := make([]func(...compose.GraphInterruptOption), len(cc.graphInterruptFuncs)) + copy(fns, cc.graphInterruptFuncs) + + if len(fns) == 0 { + cc.mu.Unlock() + return false + } + + for _, fn := range fns { + fn(compose.WithGraphInterruptTimeout(0)) + } + cc.mu.Unlock() + return true +} + +// setGraphInterruptFunc appends a graph interrupt function to the list. +// If an immediate cancel was already requested, fires it retroactively. +// Multiple functions can be registered (e.g. one per parallel sub-agent). +// +// Both this method and sendImmediateInterrupt hold cc.mu across the entire +// check-and-fire sequence, ensuring each interrupt function is called exactly +// once (compose.WithGraphInterrupt returns a non-idempotent closure that panics +// on double-call). +func (cc *cancelContext) setGraphInterruptFunc(interrupt func(...compose.GraphInterruptOption)) { + cc.mu.Lock() + cc.graphInterruptFuncs = append(cc.graphInterruptFuncs, interrupt) + + shouldFire := atomic.LoadInt32(&cc.interruptSent) == interruptImmediate + if shouldFire { + interrupt(compose.WithGraphInterruptTimeout(0)) + } + cc.mu.Unlock() +} + +// markDone marks the execution as finished through any non-cancel path +// (normal completion, business interrupt, or error). +// This is safe to call even if a cancel is in progress — it allows the +// cancel func to detect that execution finished before cancel took effect. +func (cc *cancelContext) markDone() { + if cc == nil { + return + } + if atomic.CompareAndSwapInt32(&cc.state, stateRunning, stateDone) || + atomic.CompareAndSwapInt32(&cc.state, stateCancelling, stateDone) { + cc.doneOnce.Do(func() { close(cc.doneChan) }) + cc.detachFromParent() + } +} + +func (cc *cancelContext) detachFromParent() { + if cc.parent != nil && atomic.CompareAndSwapInt32(&cc.decrementedParent, 0, 1) { + atomic.AddInt32(&cc.parent.activeChildren, -1) + } +} + +func (cc *cancelContext) hasActiveChildren() bool { + return cc != nil && atomic.LoadInt32(&cc.activeChildren) > 0 +} + +func (cc *cancelContext) wrapGraphInterruptWithGracePeriod(interrupt func(...compose.GraphInterruptOption)) func(...compose.GraphInterruptOption) { + return func(opts ...compose.GraphInterruptOption) { + if cc.hasActiveChildren() { + newOpts := make([]compose.GraphInterruptOption, len(opts)+1) + copy(newOpts, opts) + newOpts[len(opts)] = compose.WithGraphInterruptTimeout(defaultCancelImmediateGracePeriod) + opts = newOpts + } + interrupt(opts...) + } +} + +// markCancelHandled signals that the cancel path in the runFunc has created +// and sent a CancelError. Transitions state to stateCancelHandled so that: +// 1. The deferred markDone() becomes a no-op (CAS from cancelling fails). +// 2. buildCancelFunc sees stateCancelHandled and returns nil (cancel succeeded). +// Returns true if the transition succeeded, false if cancel was already handled +// (e.g., by a sub-agent). This prevents duplicate CancelError emission. +func (cc *cancelContext) markCancelHandled() bool { + if cc == nil { + return false + } + if atomic.CompareAndSwapInt32(&cc.state, stateCancelling, stateCancelHandled) { + cc.doneOnce.Do(func() { close(cc.doneChan) }) + cc.detachFromParent() + return true + } + return false +} + +// createCancelError creates a CancelError based on the current cancel state. +func (cc *cancelContext) createCancelError() *CancelError { + info := &AgentCancelInfo{} + info.Mode = cc.getMode() + if atomic.LoadInt32(&cc.escalated) == 1 { + info.Escalated = true + info.Timeout = atomic.LoadInt32(&cc.timeoutEscalated) == 1 + } + return &CancelError{ + Info: info, + } +} + +func (cc *cancelContext) createAndMarkCancelHandled() (*CancelError, bool) { + cc.cancelMu.Lock() + defer cc.cancelMu.Unlock() + cancelErr := cc.createCancelError() + ok := cc.markCancelHandled() + return cancelErr, ok +} + +// buildCancelFunc builds the AgentCancelFunc for external use. +func (cc *cancelContext) buildCancelFunc() AgentCancelFunc { + joinMode := func(a, b CancelMode) CancelMode { + if a == CancelImmediate || b == CancelImmediate { + return CancelImmediate + } + return a | b + } + + parseReq := func(callOpts ...AgentCancelOption) *agentCancelConfig { + cfg := &agentCancelConfig{Mode: CancelImmediate} + for _, opt := range callOpts { + opt(cfg) + } + return cfg + } + + startTimeoutController := func() { + cc.timeoutOnce.Do(func() { + go func() { + for { + select { + case <-cc.doneChan: + return + default: + } + + mode := cc.getMode() + if mode == CancelImmediate { + return + } + + deadline := cc.getDeadlineUnixNano() + if deadline == 0 { + select { + case <-cc.timeoutNotify: + continue + case <-cc.doneChan: + return + } + } + + now := time.Now().UnixNano() + wait := time.Duration(deadline - now) + if wait <= 0 { + atomic.StoreInt32(&cc.escalated, 1) + atomic.StoreInt32(&cc.timeoutEscalated, 1) + cc.sendImmediateInterrupt() + return + } + + timer := time.NewTimer(wait) + select { + case <-timer.C: + timer.Stop() + atomic.StoreInt32(&cc.escalated, 1) + atomic.StoreInt32(&cc.timeoutEscalated, 1) + cc.sendImmediateInterrupt() + return + case <-cc.timeoutNotify: + timer.Stop() + continue + case <-cc.doneChan: + timer.Stop() + return + } + } + }() + }) + } + + newHandle := func(wait func() error) *CancelHandle { + return &CancelHandle{wait: wait} + } + + waitForCompletion := func() error { + <-cc.doneChan + + st := atomic.LoadInt32(&cc.state) + switch st { + case stateDone: + return ErrExecutionCompleted + default: + if atomic.LoadInt32(&cc.timeoutEscalated) == 1 { + return ErrCancelTimeout + } + return nil + } + } + + return func(callOpts ...AgentCancelOption) (*CancelHandle, bool) { + req := parseReq(callOpts...) + + st := atomic.LoadInt32(&cc.state) + switch st { + case stateCancelHandled: + return newHandle(func() error { return nil }), false + case stateDone: + return newHandle(func() error { return ErrExecutionCompleted }), false + } + + var needImmediate, needTimeoutCtl bool + + cc.cancelMu.Lock() + + st = atomic.LoadInt32(&cc.state) + switch st { + case stateCancelHandled: + cc.cancelMu.Unlock() + return newHandle(func() error { return nil }), false + case stateDone: + cc.cancelMu.Unlock() + return newHandle(func() error { return ErrExecutionCompleted }), false + } + + curMode := cc.getMode() + if st == stateRunning { + if !atomic.CompareAndSwapInt32(&cc.state, stateRunning, stateCancelling) { + st = atomic.LoadInt32(&cc.state) + cc.cancelMu.Unlock() + if st == stateDone { + return newHandle(func() error { return ErrExecutionCompleted }), false + } + return newHandle(waitForCompletion), true + } + + curMode = req.Mode + cc.setMode(curMode) + atomic.StoreInt32(&cc.startedMode, int32(curMode)) + close(cc.cancelChan) + } else { + curMode = joinMode(curMode, req.Mode) + cc.setMode(curMode) + } + + if curMode == CancelImmediate { + cc.setDeadlineUnixNano(0) + needImmediate = true + } else if req.Timeout != nil && *req.Timeout > 0 { + proposed := time.Now().Add(*req.Timeout).UnixNano() + existing := cc.getDeadlineUnixNano() + if existing == 0 || proposed < existing { + cc.setDeadlineUnixNano(proposed) + cc.wakeTimeoutController() + } + needTimeoutCtl = cc.getDeadlineUnixNano() != 0 + } + + cc.cancelMu.Unlock() + + if needImmediate { + if atomic.LoadInt32(&cc.startedMode) != int32(CancelImmediate) { + atomic.StoreInt32(&cc.escalated, 1) + } + cc.sendImmediateInterrupt() + } + if needTimeoutCtl { + startTimeoutController() + } + + return newHandle(waitForCompletion), true + } +} + +// wrapIterWithCancelCtx wraps an iterator with cancel lifecycle management. +// It calls markDone when the inner iterator is fully drained, ensuring the +// cancelContext's doneChan is closed and propagation goroutines can exit. +// +// For root cancelContexts (created by WithCancel, not deriveChild), it also +// converts interrupt ACTION events to CancelError when cancel is active. +// This is the single point of interrupt-to-CancelError conversion in the +// system — Runner.handleIter only enriches the resulting CancelError with +// checkpoint metadata. +// +// Interrupt absorption: ALL interrupts are converted when cancel is active, +// including business interrupts (compose.Interrupt from user code). Cancel and +// business interrupts cannot be reliably distinguished in concurrent execution +// (parallel workflows, concurrent tool calls) where they merge into a single +// composite signal. The business interrupt data is preserved in the checkpoint +// and re-fires naturally on resume. +// +// This conversion MUST happen in this wrapper (not deferred to Runner.handleIter) +// because markDone runs as a defer in this goroutine — if the interrupt event +// were passed through unconverted, markDone would transition stateCancelling→stateDone +// before the Runner goroutine could call createAndMarkCancelHandled, causing it +// to fail the CAS. +func wrapIterWithCancelCtx(iter *AsyncIterator[*AgentEvent], cancelCtx *cancelContext) *AsyncIterator[*AgentEvent] { + if cancelCtx == nil { + return iter + } + it, gen := NewAsyncIteratorPair[*AgentEvent]() + go func() { + defer cancelCtx.markDone() + defer gen.Close() + for { + event, ok := iter.Next() + if !ok { + break + } + + if cancelCtx.isRoot() && event.Action != nil && event.Action.internalInterrupted != nil { + if cancelCtx.shouldCancel() { + cancelErr, ok := cancelCtx.createAndMarkCancelHandled() + if ok { + cancelErr.interruptSignal = event.Action.internalInterrupted + gen.Send(&AgentEvent{Err: cancelErr}) + } + return + } + } + + gen.Send(event) + } + }() + return it +} + +// cancelMonitoredModel wraps a model with cancel monitoring. +// Generate: pure delegate to the inner model (CancelAfterChatModel is handled +// by a dedicated node after the ChatModel in the compose graph). +// Stream: pipes chunks through a goroutine that selects on immediateChan for +// CancelImmediate abort. +type cancelMonitoredModel struct { + inner model.BaseChatModel + cancelContext *cancelContext +} + +type recvResult[T any] struct { + data T + err error +} + +func (m *cancelMonitoredModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + return m.inner.Generate(ctx, input, opts...) +} + +func (m *cancelMonitoredModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + stream, err := m.inner.Stream(ctx, input, opts...) + if err != nil { + return nil, err + } + wrapped := wrapStreamWithCancelMonitoring(stream, m.cancelContext) + return wrapped, nil +} + +// wrapStreamWithCancelMonitoring wraps a stream with cancel monitoring. +// When immediateChan fires (CancelImmediate or timeout escalation), the output +// stream is terminated with ErrStreamCanceled. +func wrapStreamWithCancelMonitoring[T any](stream *schema.StreamReader[T], cc *cancelContext) *schema.StreamReader[T] { + if cc == nil { + return stream + } + + // Already canceled — terminate immediately + select { + case <-cc.immediateChan: + stream.Close() + r, w := schema.Pipe[T](1) + var zero T + w.Send(zero, ErrStreamCanceled) + w.Close() + return r + default: + } + + reader, writer := schema.Pipe[T](1) + + go func() { + done := make(chan struct{}) + defer close(done) + defer writer.Close() + defer stream.Close() + + ch := make(chan recvResult[T]) + go func() { + defer close(ch) + for { + chunk, recvErr := stream.Recv() + select { + case ch <- recvResult[T]{chunk, recvErr}: + case <-done: + return + } + if recvErr != nil { + return + } + } + }() + + for { + select { + case <-cc.immediateChan: + var zero T + writer.Send(zero, ErrStreamCanceled) + return + + case r, ok := <-ch: + if !ok { + return + } + if r.err != nil { + if r.err == io.EOF { + return + } + var zero T + writer.Send(zero, r.err) + return + } + if closed := writer.Send(r.data, nil); closed { + return + } + } + } + }() + + return reader +} + +// cancelMonitoredToolHandler wraps streamable tool calls with cancel monitoring. +// When CancelImmediate fires, the tool output stream is terminated with ErrStreamCanceled. +// This handler reads the cancelContext from the Go context via getCancelContext. +type cancelMonitoredToolHandler struct{} + +func (h *cancelMonitoredToolHandler) WrapStreamableToolCall(next compose.StreamableToolEndpoint) compose.StreamableToolEndpoint { + return func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { + output, err := next(ctx, input) + if err != nil { + return nil, err + } + + cc := getCancelContext(ctx) + if cc == nil { + return output, nil + } + + wrapped := wrapStreamWithCancelMonitoring(output.Result, cc) + return &compose.StreamToolOutput{Result: wrapped}, nil + } +} + +func (h *cancelMonitoredToolHandler) WrapEnhancedStreamableToolCall(next compose.EnhancedStreamableToolEndpoint) compose.EnhancedStreamableToolEndpoint { + return func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { + output, err := next(ctx, input) + if err != nil { + return nil, err + } + + cc := getCancelContext(ctx) + if cc == nil { + return output, nil + } + + wrapped := wrapStreamWithCancelMonitoring(output.Result, cc) + return &compose.EnhancedStreamableToolOutput{Result: wrapped}, nil + } +} diff --git a/adk/cancel_edge_test.go b/adk/cancel_edge_test.go new file mode 100644 index 000000000..d3fb02a1a --- /dev/null +++ b/adk/cancel_edge_test.go @@ -0,0 +1,1268 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +// --- helpers shared across edge-case tests --- + +// blockingChatModel blocks until unblockCh is closed, then returns a fixed response. +type blockingChatModel struct { + unblockCh chan struct{} + response *schema.Message + started chan struct{} + callCount int32 +} + +func newBlockingChatModel(response *schema.Message) *blockingChatModel { + return &blockingChatModel{ + unblockCh: make(chan struct{}), + response: response, + started: make(chan struct{}, 1), + } +} + +func (m *blockingChatModel) Generate(ctx context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m.callCount, 1) + select { + case m.started <- struct{}{}: + default: + } + <-m.unblockCh + return m.response, nil +} + +func (m *blockingChatModel) Stream(ctx context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m.callCount, 1) + select { + case m.started <- struct{}{}: + default: + } + <-m.unblockCh + return schema.StreamReaderFromArray([]*schema.Message{m.response}), nil +} + +func (m *blockingChatModel) BindTools(_ []*schema.ToolInfo) error { return nil } + +// errorChatModel returns an error from Generate/Stream. +type errorChatModel struct { + err error + started chan struct{} +} + +func (m *errorChatModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + if m.started != nil { + select { + case m.started <- struct{}{}: + default: + } + } + return nil, m.err +} + +func (m *errorChatModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, m.err +} + +func (m *errorChatModel) BindTools(_ []*schema.ToolInfo) error { return nil } + +// plainResponseModel returns immediately with a fixed text response (no tool calls). +type plainResponseModel struct { + text string +} + +func (m *plainResponseModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return schema.AssistantMessage(m.text, nil), nil +} + +func (m *plainResponseModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage(m.text, nil)}), nil +} + +func (m *plainResponseModel) BindTools(_ []*schema.ToolInfo) error { return nil } + +// blockingTool blocks until unblockCh is closed. +type blockingTool struct { + name string + unblockCh chan struct{} + started chan struct{} + callCount int32 +} + +func newBlockingTool(name string) *blockingTool { + return &blockingTool{ + name: name, + unblockCh: make(chan struct{}), + started: make(chan struct{}, 4), + } +} + +func (t *blockingTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "blocking tool", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Type: "string"}, + }), + }, nil +} + +func (t *blockingTool) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) { + atomic.AddInt32(&t.callCount, 1) + select { + case t.started <- struct{}{}: + default: + } + <-t.unblockCh + return "result", nil +} + +func toolCallMsg(calls ...schema.ToolCall) *schema.Message { + return &schema.Message{Role: schema.Assistant, ToolCalls: calls} +} + +func toolCall(id, name, args string) schema.ToolCall { + return schema.ToolCall{ID: id, Type: "function", Function: schema.FunctionCall{Name: name, Arguments: args}} +} + +func drainEvents(iter *AsyncIterator[*AgentEvent]) ([]*AgentEvent, bool) { + var events []*AgentEvent + hasCancelError := false + for { + e, ok := iter.Next() + if !ok { + break + } + events = append(events, e) + var ce *CancelError + if e.Err != nil && errors.As(e.Err, &ce) { + hasCancelError = true + } + } + return events, hasCancelError +} + +// --- tests --- + +// TestWithCancel_BeforeExecutionStarts verifies that a cancel issued before +// the graph begins executing still produces a CancelError without invoking +// the model or tools. +func TestWithCancel_BeforeExecutionStarts(t *testing.T) { + ctx := context.Background() + + blk := newBlockingChatModel(toolCallMsg(toolCall("c1", "bt", `{"input":"x"}`))) + bt := newBlockingTool("bt") + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: blk, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{bt}}, + }, + }) + assert.NoError(t, err) + + cancelOpt, cancelFn := WithCancel() + + // Extract the cancelContext so we can wait for cancelChan to close, + // ensuring the cancel is fully registered before Run starts. + cc := getCommonOptions(nil, cancelOpt).cancelCtx + + // Call cancel BEFORE calling agent.Run. + // The cancelFunc must succeed (not hang) even though execution hasn't started. + cancelDone := make(chan error, 1) + go func() { + handle, _ := cancelFn() + cancelDone <- handle.Wait() + }() + + // Wait for cancelChan to close so the pre-execution check in runFunc + // deterministically sees shouldCancel()=true (eliminates goroutine scheduling race). + <-cc.cancelChan + + // Now start the run — it should see shouldCancel()=true and emit CancelError immediately. + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hi")}}, cancelOpt) + + _, hasCancelError := drainEvents(iter) + assert.True(t, hasCancelError, "expected CancelError when cancel precedes execution") + + // cancelFn must have already returned (or return quickly now that doneChan is closed). + select { + case cancelErr := <-cancelDone: + // Either nil (cancel handled) or ErrExecutionCompleted is acceptable + // depending on exact timing; what matters is it didn't hang. + _ = cancelErr + case <-time.After(3 * time.Second): + t.Fatal("cancelFn blocked indefinitely after pre-start cancel") + } + + // Model and tool must not have been invoked. + assert.Equal(t, int32(0), atomic.LoadInt32(&bt.callCount), "tool must not be called") +} + +// TestWithCancel_AfterCompletion verifies cancelFn returns ErrExecutionCompleted +// when called after a normal run finishes. +func TestWithCancel_AfterCompletion(t *testing.T) { + ctx := context.Background() + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: &plainResponseModel{text: "done"}, + }) + require.NoError(t, err) + + cancelOpt, cancelFn := WithCancel() + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hi")}}, cancelOpt) + + // Drain all events so the run completes. + for { + _, ok := iter.Next() + if !ok { + break + } + } + + handle, _ := cancelFn() + cancelErr := handle.Wait() + assert.ErrorIs(t, cancelErr, ErrExecutionCompleted) +} + +// TestWithCancel_AfterBusinessInterrupt verifies cancelFn returns ErrExecutionCompleted +// when called after the agent has been interrupted by business logic. +func TestWithCancel_AfterBusinessInterrupt(t *testing.T) { + ctx := context.Background() + + // Use a model that triggers a compose.Interrupt so the agent stops with an interrupt. + interruptModel := &interruptingChatModel{} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: interruptModel, + }) + require.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hi")}, cancelOpt, WithCheckPointID("biz-interrupt-1")) + + // Drain — expect an interrupt action event, no cancel error. + var gotInterrupt bool + for { + e, ok := iter.Next() + if !ok { + break + } + if e.Action != nil && e.Action.Interrupted != nil { + gotInterrupt = true + } + } + assert.True(t, gotInterrupt, "expected business interrupt event") + + handle, _ := cancelFn() + cancelErr := handle.Wait() + assert.ErrorIs(t, cancelErr, ErrExecutionCompleted) +} + +// TestWithCancel_AfterError verifies cancelFn returns ErrExecutionCompleted +// when called after the agent errors out. +func TestWithCancel_AfterError(t *testing.T) { + ctx := context.Background() + + modelErr := errors.New("model exploded") + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: &errorChatModel{err: modelErr}, + }) + require.NoError(t, err) + + cancelOpt, cancelFn := WithCancel() + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hi")}}, cancelOpt) + + for { + _, ok := iter.Next() + if !ok { + break + } + } + + handle, _ := cancelFn() + cancelErr := handle.Wait() + assert.ErrorIs(t, cancelErr, ErrExecutionCompleted) +} + +// TestWithCancel_TimeoutEscalation tests that WithAgentCancelTimeout causes the +// cancel to escalate to immediate when the safe-point hasn't fired yet, and +// that the resulting CancelError has Escalated=true. +// +// Strategy: use CancelAfterChatModel mode. The model blocks (never completes), +// so the safe-point can't fire naturally. After the timeout, escalateToImmediate +// closes immediateChan which aborts the model stream via cancelMonitoredModel +// and causes a CancelError — no compose graph-interrupt races involved. +func TestWithCancel_TimeoutEscalation(t *testing.T) { + ctx := context.Background() + + blk := newBlockingChatModel(schema.AssistantMessage("hello", nil)) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: blk, + }) + require.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: true, // use streaming so cancelMonitoredModel.Stream is exercised + }) + + timeout := 300 * time.Millisecond + // CancelAfterChatModel + timeout: safe-point can't fire (model never finishes), + // so after 300ms the timeout goroutine escalates to immediate. + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("go")}, cancelOpt) + + select { + case <-blk.started: + case <-time.After(5 * time.Second): + t.Fatal("model did not start") + } + + // Fire cancelFn; it will wait for escalation to complete. + start := time.Now() + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel), WithAgentCancelTimeout(timeout)) + cancelErr := handle.Wait() + elapsed := time.Since(start) + + assert.ErrorIs(t, cancelErr, ErrCancelTimeout, "cancel should return ErrCancelTimeout after timeout escalation") + assert.True(t, elapsed >= timeout, "should wait at least the timeout duration, elapsed=%v", elapsed) + assert.True(t, elapsed < 3*time.Second, "should complete shortly after timeout, elapsed=%v", elapsed) + + var cancelError *CancelError + for { + e, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if e.Err != nil && errors.As(e.Err, &ce) { + cancelError = ce + } + } + if assert.NotNil(t, cancelError, "expected CancelError after timeout escalation") { + assert.True(t, cancelError.Info.Escalated, "CancelError should report Escalated=true") + assert.True(t, cancelError.Info.Timeout, "CancelError should report Timeout=true") + } +} + +// TestWithCancel_AfterChatModel_WithTools verifies CancelAfterChatModel fires +// when the model returns tool calls (the safe-point is on the tool-calls path). +func TestWithCancel_AfterChatModel_WithTools(t *testing.T) { + ctx := context.Background() + + blk := newBlockingChatModel(toolCallMsg(toolCall("c1", "bt", `{"input":"x"}`))) + bt := newBlockingTool("bt") + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: blk, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{bt}}, + }, + }) + require.NoError(t, err) + + cancelOpt, cancelFn := WithCancel() + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hi")}}, cancelOpt) + + select { + case <-blk.started: + case <-time.After(5 * time.Second): + t.Fatal("model did not start") + } + + cancelDone := make(chan error, 1) + go func() { + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + cancelDone <- handle.Wait() + }() + + time.Sleep(20 * time.Millisecond) + + close(blk.unblockCh) + + cancelErr := <-cancelDone + assert.NoError(t, cancelErr) + + _, hasCancelError := drainEvents(iter) + assert.True(t, hasCancelError, "CancelError expected after model returns tool calls") +} + +// TestWithCancel_CancelImmediate_StreamAborted verifies that CancelImmediate +// during model streaming surfaces ErrStreamCanceled and completes quickly. +func TestWithCancel_CancelImmediate_StreamAborted(t *testing.T) { + ctx := context.Background() + + blk := newBlockingChatModel(schema.AssistantMessage("hello", nil)) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: blk, + }) + require.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: true, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hi")}, cancelOpt) + + select { + case <-blk.started: + case <-time.After(5 * time.Second): + t.Fatal("model did not start") + } + time.Sleep(50 * time.Millisecond) + + start := time.Now() + handle, _ := cancelFn() + cancelErr := handle.Wait() + assert.NoError(t, cancelErr) + elapsed := time.Since(start) + assert.True(t, elapsed < 2*time.Second, "cancel should complete quickly, elapsed=%v", elapsed) + + var foundStreamCanceled bool + for { + e, ok := iter.Next() + if !ok { + break + } + if e.Err != nil && errors.Is(e.Err, ErrStreamCanceled) { + foundStreamCanceled = true + } + var ce *CancelError + if e.Err != nil && errors.As(e.Err, &ce) { + foundStreamCanceled = true // CancelError wraps stream abort + } + } + assert.True(t, foundStreamCanceled, "expected stream-abort error during immediate cancel") +} + +// TestWithCancel_MultipleToolsConcurrent verifies that CancelAfterToolCalls +// waits for ALL concurrent tool calls to complete before cancelling. +func TestWithCancel_MultipleToolsConcurrent(t *testing.T) { + ctx := context.Background() + + bt1 := newBlockingTool("tool1") + bt2 := newBlockingTool("tool2") + + // Model calls both tools in one response. + modelResp := toolCallMsg( + toolCall("c1", "tool1", `{"input":"a"}`), + toolCall("c2", "tool2", `{"input":"b"}`), + ) + modelWithTools := &simpleChatModel{response: modelResp} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: modelWithTools, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{bt1, bt2}}, + }, + }) + assert.NoError(t, err) + + cancelOpt, cancelFn := WithCancel() + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("go")}}, cancelOpt) + + // Wait for both tools to start. + for i := 0; i < 2; i++ { + select { + case <-bt1.started: + case <-bt2.started: + case <-time.After(5 * time.Second): + t.Fatal("tools did not start") + } + } + + // Request cancel after tool calls while both are still blocking. + cancelDone := make(chan error, 1) + go func() { + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterToolCalls)) + cancelDone <- handle.Wait() + }() + + // Unblock both tools — cancel should fire only after both complete. + time.Sleep(50 * time.Millisecond) + close(bt1.unblockCh) + time.Sleep(50 * time.Millisecond) + close(bt2.unblockCh) + + cancelErr := <-cancelDone + assert.NoError(t, cancelErr) + + assert.Equal(t, int32(1), atomic.LoadInt32(&bt1.callCount), "tool1 should complete") + assert.Equal(t, int32(1), atomic.LoadInt32(&bt2.callCount), "tool2 should complete") + + _, hasCancelError := drainEvents(iter) + assert.True(t, hasCancelError, "expected CancelError after concurrent tools completed") +} + +// TestWithCancel_GraphInterruptRaceBeforeSet verifies that a CancelImmediate +// issued before setGraphInterruptFunc is called still results in cancellation. +// This exercises the retroactive-fire path in setGraphInterruptFunc. +func TestWithCancel_GraphInterruptRaceBeforeSet(t *testing.T) { + ctx := context.Background() + + blk := newBlockingChatModel(schema.AssistantMessage("hi", nil)) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: blk, + }) + require.NoError(t, err) + + cancelOpt, cancelFn := WithCancel() + + // Cancel immediately before run starts. + go func() { + handle, _ := cancelFn() + _ = handle.Wait() + }() + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hi")}}, cancelOpt) + + done := make(chan struct{}) + go func() { + defer close(done) + drainEvents(iter) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("iteration did not complete after pre-start CancelImmediate") + } +} + +// TestWithCancel_NoCheckpointStore verifies cancel completes and does not panic +// when no checkpoint store is configured. +func TestWithCancel_NoCheckpointStore(t *testing.T) { + ctx := context.Background() + + blk := newBlockingChatModel(schema.AssistantMessage("hi", nil)) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: blk, + }) + require.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + // No CheckPointStore set. + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hi")}, cancelOpt) + + select { + case <-blk.started: + case <-time.After(5 * time.Second): + t.Fatal("model did not start") + } + time.Sleep(30 * time.Millisecond) + + handle, _ := cancelFn() + cancelErr := handle.Wait() + assert.NoError(t, cancelErr) + + var ce *CancelError + for { + e, ok := iter.Next() + if !ok { + break + } + if e.Err != nil && errors.As(e.Err, &ce) { + break + } + } + if assert.NotNil(t, ce, "expected CancelError even without checkpoint store") { + assert.Empty(t, ce.CheckPointID, "CheckPointID should be empty without checkpoint store") + } +} + +// TestWithCancel_ModelError verifies that a model error marks the cancelCtx as +// done so that a subsequent cancelFn call returns ErrExecutionCompleted. +func TestWithCancel_ModelError(t *testing.T) { + ctx := context.Background() + + modelErr := errors.New("model failed") + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: &errorChatModel{err: modelErr}, + }) + require.NoError(t, err) + + cancelOpt, cancelFn := WithCancel() + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hi")}}, cancelOpt) + + var gotModelErr bool + for { + e, ok := iter.Next() + if !ok { + break + } + if e.Err != nil && !errors.As(e.Err, new(*CancelError)) { + gotModelErr = true + } + } + assert.True(t, gotModelErr, "expected non-cancel error event from model failure") + + handle, _ := cancelFn() + cancelErr := handle.Wait() + assert.ErrorIs(t, cancelErr, ErrExecutionCompleted, "cancelFn should return ErrExecutionCompleted after model error") +} + +// TestWithCancel_Resume_SafePoint covers CancelAfterChatModel and +// CancelAfterToolCalls on a Resume path. +func TestWithCancel_Resume_SafePoint(t *testing.T) { + ctx := context.Background() + + // --- phase 1: run to get a checkpoint via CancelImmediate --- + blk := newBlockingChatModel(toolCallMsg(toolCall("c1", "bt", `{"input":"x"}`))) + bt := newSlowTool("bt", 50*time.Millisecond, "result") + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: blk, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{bt}}, + }, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner1 := NewRunner(ctx, RunnerConfig{ + Agent: agent1, + CheckPointStore: store, + }) + + cancelOpt1, cancelFn1 := WithCancel() + iter1 := runner1.Run(ctx, []Message{schema.UserMessage("hi")}, cancelOpt1, WithCheckPointID("resume-sp-1")) + + select { + case <-blk.started: + case <-time.After(5 * time.Second): + t.Fatal("model did not start in phase 1") + } + _, _ = cancelFn1() + drainEvents(iter1) + + // --- phase 2: resume, cancel after chat model --- + resumeModel := newBlockingChatModel(toolCallMsg(toolCall("c1", "bt", `{"input":"x"}`))) + + bt2 := newSlowTool("bt", 50*time.Millisecond, "result") + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: resumeModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{bt2}}, + }, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: agent2, + CheckPointStore: store, + }) + + cancelOpt2, cancelFn2 := WithCancel() + resumeIter, err := runner2.Resume(ctx, "resume-sp-1", cancelOpt2) + require.NoError(t, err) + + select { + case <-resumeModel.started: + case <-time.After(5 * time.Second): + t.Fatal("model did not start in phase 2") + } + + cancelDone := make(chan error, 1) + go func() { + handle, _ := cancelFn2(WithAgentCancelMode(CancelAfterChatModel)) + cancelDone <- handle.Wait() + }() + + time.Sleep(50 * time.Millisecond) + + close(resumeModel.unblockCh) + + cancelErr := <-cancelDone + assert.NoError(t, cancelErr) + + _, hasCancelError := drainEvents(resumeIter) + assert.True(t, hasCancelError, "CancelError expected after resumed model returns tool calls") +} + +// callbackTool is a tool that calls onCall when invoked. +type callbackTool struct { + name string + onCall func() +} + +func (t *callbackTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "callback tool", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Type: "string"}, + }), + }, nil +} + +func (t *callbackTool) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) { + if t.onCall != nil { + t.onCall() + } + return "ok", nil +} + +// interruptingChatModel returns a compose.Interrupt error to simulate a +// business interrupt during execution. +type interruptingChatModel struct{} + +func (m *interruptingChatModel) Generate(ctx context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, compose.Interrupt(ctx, "test interrupt") +} + +func (m *interruptingChatModel) Stream(ctx context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, compose.Interrupt(ctx, "test interrupt") +} + +func (m *interruptingChatModel) BindTools(_ []*schema.ToolInfo) error { return nil } + +// TestWithCancel_TargetedResume_CancelImmediate cancels an agent via CancelImmediate, +// extracts InterruptContexts from the resulting CancelError, and uses them +// for targeted resumption via Runner.ResumeWithParams. +func TestWithCancel_TargetedResume_CancelImmediate(t *testing.T) { + ctx := context.Background() + + blk := newBlockingChatModel(toolCallMsg(toolCall("c1", "st", `{"input":"x"}`))) + st := newSlowTool("st", 50*time.Millisecond, "result") + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: blk, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{st}}, + }, + }) + require.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("go")}, cancelOpt, WithCheckPointID("targeted-imm-1")) + + select { + case <-blk.started: + case <-time.After(5 * time.Second): + t.Fatal("model did not start") + } + + handle, _ := cancelFn() // CancelImmediate (default) + cancelErr := handle.Wait() + assert.NoError(t, cancelErr) + + var cancelError *CancelError + for { + e, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if e.Err != nil && errors.As(e.Err, &ce) { + cancelError = ce + } + } + + require.NotNil(t, cancelError, "expected CancelError") + require.NotEmpty(t, cancelError.InterruptContexts, "CancelError should have InterruptContexts for targeted resume") + + // --- resume with targeted params --- + targets := make(map[string]any) + for _, ic := range cancelError.InterruptContexts { + targets[ic.ID] = nil + } + + resumeModel := &plainResponseModel{text: "resumed"} + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: resumeModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{st}}, + }, + }) + require.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: agent2, + CheckPointStore: store, + }) + + resumeIter, err := runner2.ResumeWithParams(ctx, "targeted-imm-1", &ResumeParams{Targets: targets}) + require.NoError(t, err) + + var gotOutput bool + for { + e, ok := resumeIter.Next() + if !ok { + break + } + if e.Err != nil { + t.Fatalf("unexpected error during targeted resume: %v", e.Err) + } + if e.Output != nil && e.Output.MessageOutput != nil { + gotOutput = true + } + } + assert.True(t, gotOutput, "targeted resume should produce output") +} + +// TestWithCancel_TargetedResume_SafePoint cancels an agent via CancelAfterChatModel +// (safe-point) and verifies that InterruptContexts are populated on the CancelError +// and that targeted resume via ResumeWithParams succeeds. +// Since safe-point cancels now use compose.Interrupt, compose saves checkpoint data, +// making the cancel fully resumable. +func TestWithCancel_TargetedResume_SafePoint(t *testing.T) { + ctx := context.Background() + + // The model returns a tool call so the react graph routes to toolPreHandle, + // which detects CancelAfterChatModel and fires compose.Interrupt. + blk := newBlockingChatModel(toolCallMsg(toolCall("c1", "st", `{"input":"x"}`))) + st := newSlowTool("st", 0, "result") + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: blk, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{st}}, + }, + }) + require.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("go")}, cancelOpt, WithCheckPointID("targeted-sp-1")) + + select { + case <-blk.started: + case <-time.After(5 * time.Second): + t.Fatal("model did not start") + } + + // Start cancelFn in background so the CAS happens before the model unblocks. + cancelDone := make(chan error, 1) + go func() { + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + cancelDone <- handle.Wait() + }() + time.Sleep(50 * time.Millisecond) + close(blk.unblockCh) + + cancelErr := <-cancelDone + assert.NoError(t, cancelErr) + + var cancelError *CancelError + for { + e, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if e.Err != nil && errors.As(e.Err, &ce) { + cancelError = ce + } + } + + require.NotNil(t, cancelError, "expected CancelError") + require.NotEmpty(t, cancelError.InterruptContexts, "CancelError should have InterruptContexts for targeted resume") + + // --- resume with targeted params --- + targets := make(map[string]any) + for _, ic := range cancelError.InterruptContexts { + targets[ic.ID] = nil + } + + resumeModel := &plainResponseModel{text: "resumed"} + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: resumeModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{st}}, + }, + }) + require.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: agent2, + CheckPointStore: store, + }) + + resumeIter, err := runner2.ResumeWithParams(ctx, "targeted-sp-1", &ResumeParams{Targets: targets}) + require.NoError(t, err) + + var gotOutput bool + for { + e, ok := resumeIter.Next() + if !ok { + break + } + if e.Err != nil { + t.Fatalf("unexpected error during targeted resume: %v", e.Err) + } + if e.Output != nil && e.Output.MessageOutput != nil { + gotOutput = true + } + } + assert.True(t, gotOutput, "targeted resume should produce output") +} + +// TestWithCancel_Resume_CancelAfterChatModel_MessagePreserved tests both the +// ReAct (with-tools) and noTools paths to ensure that when a +// CancelAfterChatModel safe-point fires and the run is later resumed, the +// original Message returned by the chat model is preserved through the +// StatefulInterrupt checkpoint. +// +// For the ReAct path: the model returns a tool-call message. On resume the +// cancelCheck node must return that same message so the branch routes to the +// ToolNode and the tool actually executes. +// +// For the noTools path: the model returns a plain text message. On resume the +// cancel-check lambda must return that same message as the chain output. +func TestWithCancel_Resume_CancelAfterChatModel_MessagePreserved(t *testing.T) { + t.Run("react_path_tool_call_preserved", func(t *testing.T) { + ctx := context.Background() + + // Phase-2 model returns no tool calls so the graph ends. + // We track whether the tool actually executes on resume. + toolExecuted := make(chan struct{}, 1) + st := &callbackTool{ + name: "my_tool", + onCall: func() { + select { + case toolExecuted <- struct{}{}: + default: + } + }, + } + + // Phase-1 model returns a tool call. + blk := newBlockingChatModel(toolCallMsg(toolCall("c1", "my_tool", `{"input":"x"}`))) + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: blk, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{st}}, + }, + }) + require.NoError(t, err) + + store := newCancelTestStore() + runner1 := NewRunner(ctx, RunnerConfig{ + Agent: agent1, + CheckPointStore: store, + }) + + cancelOpt1, cancelFn1 := WithCancel() + iter1 := runner1.Run(ctx, []Message{schema.UserMessage("hi")}, + cancelOpt1, WithCheckPointID("react-msg-preserved-1")) + + select { + case <-blk.started: + case <-time.After(5 * time.Second): + t.Fatal("model did not start in phase 1") + } + + cancelDone := make(chan error, 1) + go func() { + handle, _ := cancelFn1(WithAgentCancelMode(CancelAfterChatModel)) + cancelDone <- handle.Wait() + }() + time.Sleep(50 * time.Millisecond) + close(blk.unblockCh) + + cancelErr := <-cancelDone + assert.NoError(t, cancelErr) + + _, hasCancelError := drainEvents(iter1) + assert.True(t, hasCancelError, "expected CancelError from phase 1") + + // Phase 2: resume. The model for phase-2 returns plain text (no tool + // calls) so the react graph ends after one iteration. But first the + // tool from the checkpoint must execute. + resumeModel := &plainResponseModel{text: "done"} + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: resumeModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{st}}, + }, + }) + require.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: agent2, + CheckPointStore: store, + }) + + resumeIter, err := runner2.Resume(ctx, "react-msg-preserved-1") + require.NoError(t, err) + + for { + e, ok := resumeIter.Next() + if !ok { + break + } + if e.Err != nil { + t.Fatalf("unexpected error during resume: %v", e.Err) + } + } + + // The key assertion: the tool must have been called during resume, + // which can only happen if the tool-call message was preserved. + select { + case <-toolExecuted: + // success + default: + t.Fatal("tool was not executed on resume — the tool-call message was lost") + } + }) + +} + +// TestHandleRunFuncError_AlreadyHandled_NoDuplicate verifies that when +// markCancelHandled() was already claimed by a sub-agent's handleRunFuncError, +// the sequential workflow's checkCancel does not emit a second CancelError. +// +// Setup: sequential[cma1, cma2] with CancelAfterToolCalls. cma1 has tools, +// cancel fires while tool is running. After tool completes, the safe-point +// fires in cma1's handleRunFuncError (claiming markCancelHandled). The +// sequential workflow's checkCancel at the transition point should find +// markCancelHandled returns false and skip — producing exactly 1 CancelError. +func TestHandleRunFuncError_AlreadyHandled_NoDuplicate(t *testing.T) { + ctx := context.Background() + + bt := newBlockingTool("bt") + + // cma1: model returns a tool call immediately, tool blocks until unblocked + cma1Model := newBlockingChatModel(toolCallMsg(toolCall("c1", "bt", `{"input":"x"}`))) + close(cma1Model.unblockCh) // model returns immediately + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent1", Description: "first", Instruction: "test", + Model: cma1Model, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{bt}}, + }, + }) + require.NoError(t, err) + + agent2Model := &plainResponseModel{text: "agent2-response"} + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent2", Description: "second", Instruction: "test", + Model: agent2Model, + }) + require.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq", Description: "sequential", SubAgents: []Agent{agent1, agent2}, + }) + require.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + // Wait for tool to start + select { + case <-bt.started: + case <-time.After(5 * time.Second): + t.Fatal("Tool did not start") + } + + // Cancel while tool is still running (in goroutine because cancelFn blocks + // until execution finishes), then unblock tool so safe-point fires + go func() { + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterToolCalls)) + _ = handle.Wait() + }() + + // Give cancel time to register, then unblock tool + time.Sleep(50 * time.Millisecond) + close(bt.unblockCh) + + cancelCount := 0 + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + cancelCount++ + } + } + + assert.Equal(t, 1, cancelCount, "Should have exactly one CancelError, no duplicate from handleRunFuncError + checkCancel") +} + +func TestWithCancel_CancelAfterChatModel_NestedAgentTool(t *testing.T) { + ctx := context.Background() + + subAgentModel := newBlockingChatModel(toolCallMsg(toolCall("c1", "sub_tool", `{"input":"x"}`))) + subAgentModelStarted := subAgentModel.started + subTool := newBlockingTool("sub_tool") + + subAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "sub_agent", + Description: "test sub agent", + Instruction: "you are a sub agent", + Model: subAgentModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{subTool}}, + }, + }) + require.NoError(t, err) + + supervisorModel := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{{ + ID: "call_1", Type: "function", + Function: schema.FunctionCall{ + Name: TransferToAgentToolName, + Arguments: `{"agent_name": "sub_agent"}`, + }, + }}, + }, + } + + supervisorAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "supervisor", + Description: "supervisor agent (equivalent to DeepAgent)", + Instruction: "you are a supervisor", + Model: supervisorModel, + }) + require.NoError(t, err) + + agentWithSubAgents, err := SetSubAgents(ctx, supervisorAgent, []Agent{subAgent}) + require.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agentWithSubAgents, + EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + select { + case <-subAgentModelStarted: + case <-time.After(10 * time.Second): + t.Fatal("Sub-agent model did not start") + } + + time.Sleep(50 * time.Millisecond) + + cancelDone := make(chan error, 1) + go func() { + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + cancelDone <- handle.Wait() + }() + + time.Sleep(100 * time.Millisecond) + close(subAgentModel.unblockCh) + + cancelErr := <-cancelDone + assert.NoError(t, cancelErr) + + hasCancelError := false + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + hasCancelError = true + } + } + + assert.True(t, hasCancelError, "CancelError expected from nested agent tool with tools") +} diff --git a/adk/cancel_multicall_test.go b/adk/cancel_multicall_test.go new file mode 100644 index 000000000..790d14fb3 --- /dev/null +++ b/adk/cancel_multicall_test.go @@ -0,0 +1,125 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/compose" +) + +func TestAgentCancelFunc_MultiCall_EscalateToImmediate(t *testing.T) { + cc := newCancelContext() + var interruptCalls int32 + cc.setGraphInterruptFunc(func(opts ...compose.GraphInterruptOption) { + atomic.AddInt32(&interruptCalls, 1) + }) + cancelFn := cc.buildCancelFunc() + + handle1, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + handle2, _ := cancelFn(WithAgentCancelMode(CancelImmediate)) + assert.Equal(t, int32(1), atomic.LoadInt32(&interruptCalls)) + + cancelErr := cc.createCancelError() + assert.Equal(t, CancelImmediate, cancelErr.Info.Mode) + assert.True(t, cancelErr.Info.Escalated) + assert.False(t, cancelErr.Info.Timeout) + + assert.True(t, cc.markCancelHandled()) + assert.NoError(t, handle1.Wait()) + assert.NoError(t, handle2.Wait()) +} + +func TestAgentCancelFunc_MultiCall_JoinSafePointModes(t *testing.T) { + cc := newCancelContext() + cancelFn := cc.buildCancelFunc() + + handle1, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + handle2, _ := cancelFn(WithAgentCancelMode(CancelAfterToolCalls)) + + want := CancelAfterChatModel | CancelAfterToolCalls + assert.Equal(t, want, cc.getMode()) + + assert.True(t, cc.markCancelHandled()) + assert.NoError(t, handle1.Wait()) + assert.NoError(t, handle2.Wait()) +} + +func TestAgentCancelFunc_MultiCall_TimeoutDeadlineJoinUsesAbsoluteTime(t *testing.T) { + cc := newCancelContext() + cancelFn := cc.buildCancelFunc() + + handle1, _ := cancelFn( + WithAgentCancelMode(CancelAfterChatModel), + WithAgentCancelTimeout(200*time.Millisecond), + ) + + firstDeadline := cc.getDeadlineUnixNano() + assert.NotZero(t, firstDeadline) + + time.Sleep(50 * time.Millisecond) + + handle2, _ := cancelFn( + WithAgentCancelMode(CancelAfterToolCalls), + WithAgentCancelTimeout(60*time.Millisecond), + ) + + secondDeadline := cc.getDeadlineUnixNano() + assert.NotZero(t, secondDeadline) + assert.Less(t, secondDeadline, firstDeadline) + + assert.True(t, cc.markCancelHandled()) + assert.NoError(t, handle1.Wait()) + assert.NoError(t, handle2.Wait()) +} + +func TestAgentCancelFunc_MultiCall_TimeoutEscalationReturnsErrCancelTimeout(t *testing.T) { + cc := newCancelContext() + var interruptCalls int32 + interruptCh := make(chan struct{}, 1) + cc.setGraphInterruptFunc(func(opts ...compose.GraphInterruptOption) { + atomic.AddInt32(&interruptCalls, 1) + select { + case interruptCh <- struct{}{}: + default: + } + }) + cancelFn := cc.buildCancelFunc() + handle, _ := cancelFn( + WithAgentCancelMode(CancelAfterChatModel), + WithAgentCancelTimeout(30*time.Millisecond), + ) + + select { + case <-interruptCh: + case <-time.After(1 * time.Second): + t.Fatal("timeout escalation did not interrupt") + } + assert.Equal(t, int32(1), atomic.LoadInt32(&interruptCalls)) + + cancelErr := cc.createCancelError() + assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) + assert.True(t, cancelErr.Info.Escalated) + assert.True(t, cancelErr.Info.Timeout) + + assert.True(t, cc.markCancelHandled()) + assert.Equal(t, ErrCancelTimeout, handle.Wait()) +} diff --git a/adk/cancel_test.go b/adk/cancel_test.go index 702efd3c3..0d88db8cc 100644 --- a/adk/cancel_test.go +++ b/adk/cancel_test.go @@ -18,6 +18,10 @@ package adk import ( "context" + "errors" + "fmt" + "io" + "runtime" "sync" "sync/atomic" "testing" @@ -32,18 +36,30 @@ import ( ) type cancelTestChatModel struct { - delay time.Duration + delayNs int64 response *schema.Message startedChan chan struct{} doneChan chan struct{} } +func (m *cancelTestChatModel) getDelay() time.Duration { + return time.Duration(atomic.LoadInt64(&m.delayNs)) +} + +func (m *cancelTestChatModel) setDelay(d time.Duration) { + atomic.StoreInt64(&m.delayNs, int64(d)) +} + func (m *cancelTestChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { select { case m.startedChan <- struct{}{}: default: } - time.Sleep(m.delay) + select { + case <-time.After(m.getDelay()): + case <-ctx.Done(): + return nil, ctx.Err() + } select { case m.doneChan <- struct{}{}: default: @@ -53,7 +69,7 @@ func (m *cancelTestChatModel) Generate(ctx context.Context, input []*schema.Mess func (m *cancelTestChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { m.startedChan <- struct{}{} - time.Sleep(m.delay) + time.Sleep(m.getDelay()) m.doneChan <- struct{}{} return schema.StreamReaderFromArray([]*schema.Message{m.response}), nil } @@ -95,7 +111,11 @@ func (t *slowTool) InvokableRun(ctx context.Context, argumentsInJSON string, opt case t.startedChan <- struct{}{}: default: } - time.Sleep(t.delay) + select { + case <-time.After(t.delay): + case <-ctx.Done(): + return "", ctx.Err() + } return t.result, nil } @@ -122,22 +142,20 @@ func (s *cancelTestStore) Get(_ context.Context, key string) ([]byte, bool, erro return v, ok, nil } -func TestCancelSig(t *testing.T) { - t.Run("BasicCancelSignal", func(t *testing.T) { - cs := newCancelSig() +func TestCancelContext(t *testing.T) { + t.Run("BasicCancelContext", func(t *testing.T) { + cc := newCancelContext() + assert.False(t, cc.shouldCancel(), "Should not be cancelled initially") - cfg := checkCancelSig(cs) - assert.Nil(t, cfg, "Should not be cancelled initially") + cc.setMode(CancelImmediate) + close(cc.cancelChan) - cs.cancel(&cancelConfig{Mode: CancelImmediate}) - - cfg = checkCancelSig(cs) - assert.NotNil(t, cfg, "Should be cancelled after cancel()") - assert.Equal(t, CancelImmediate, cfg.Mode) + assert.True(t, cc.shouldCancel(), "Should be cancelled after close(cancelChan)") + assert.Equal(t, CancelImmediate, cc.getMode()) }) } -func TestRunWithCancel_WithTools(t *testing.T) { +func TestWithCancel_WithTools(t *testing.T) { ctx := context.Background() t.Run("CancelImmediate_DuringModelCall", func(t *testing.T) { @@ -145,7 +163,7 @@ func TestRunWithCancel_WithTools(t *testing.T) { st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") slowModel := &cancelTestChatModel{ - delay: 2 * time.Second, + delayNs: int64(2 * time.Second), response: &schema.Message{ Role: schema.Assistant, Content: "", @@ -182,7 +200,8 @@ func TestRunWithCancel_WithTools(t *testing.T) { EnableStreaming: false, }) - iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}) + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt) assert.NotNil(t, iter) assert.NotNil(t, cancelFn) @@ -207,24 +226,27 @@ func TestRunWithCancel_WithTools(t *testing.T) { time.Sleep(100 * time.Millisecond) - err = cancelFn() + handle, _ := cancelFn() + err = handle.Wait() assert.NoError(t, err) - start := time.Now() - events := <-eventsCh - elapsed := time.Since(start) + var events []*AgentEvent + select { + case events = <-eventsCh: + case <-time.After(5 * time.Second): + t.Fatal("Timed out waiting for events") + } - assert.True(t, elapsed < 1*time.Second, "Should return quickly after cancel, elapsed: %v", elapsed) assert.True(t, len(events) > 0) - hasInterrupted := false + hasCancelError := false for _, e := range events { - assert.Nil(t, e.Err, "Should not have error event after cancel") - if e.Action != nil && e.Action.Interrupted != nil { - hasInterrupted = true + var cancelErr *CancelError + if e.Err != nil && errors.As(e.Err, &cancelErr) { + hasCancelError = true } } - assert.True(t, hasInterrupted, "Should have interrupted event after cancel") + assert.True(t, hasCancelError, "Should have CancelError event after cancel") }) t.Run("CancelAfterChatModel_DuringToolCall", func(t *testing.T) { @@ -237,6 +259,7 @@ func TestRunWithCancel_WithTools(t *testing.T) { } modelWithToolCall := &simpleChatModel{ + delay: 1 * time.Second, response: &schema.Message{ Role: schema.Assistant, Content: "", @@ -266,9 +289,10 @@ func TestRunWithCancel_WithTools(t *testing.T) { }) assert.NoError(t, err) - iter, cancelFn := agent.RunWithCancel(ctx, &AgentInput{ + cancelOpt, cancelFn := WithCancel() + iter := agent.Run(ctx, &AgentInput{ Messages: []Message{schema.UserMessage("Use the tool")}, - }) + }, cancelOpt) assert.NotNil(t, iter) assert.NotNil(t, cancelFn) @@ -276,7 +300,8 @@ func TestRunWithCancel_WithTools(t *testing.T) { time.Sleep(100 * time.Millisecond) - err = cancelFn(WithCancelMode(CancelAfterChatModel)) + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + err = handle.Wait() assert.NoError(t, err) var events []*AgentEvent @@ -285,7 +310,10 @@ func TestRunWithCancel_WithTools(t *testing.T) { if !ok { break } - assert.Nil(t, event.Err, "Should not have error event after cancel") + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + continue + } events = append(events, event) } @@ -293,7 +321,7 @@ func TestRunWithCancel_WithTools(t *testing.T) { assert.True(t, atomic.LoadInt32(&st.callCount) >= 1, "Tool should have been called") }) - t.Run("CancelAfterToolCall_CompletesToolExecution", func(t *testing.T) { + t.Run("CancelAfterToolCalls_CompletesToolExecution", func(t *testing.T) { toolStarted := make(chan struct{}, 1) st := &slowToolWithSignal{ name: "slow_tool", @@ -332,9 +360,10 @@ func TestRunWithCancel_WithTools(t *testing.T) { }) assert.NoError(t, err) - iter, cancelFn := agent.RunWithCancel(ctx, &AgentInput{ + cancelOpt, cancelFn := WithCancel() + iter := agent.Run(ctx, &AgentInput{ Messages: []Message{schema.UserMessage("Use the tool")}, - }) + }, cancelOpt) assert.NotNil(t, iter) assert.NotNil(t, cancelFn) @@ -342,7 +371,8 @@ func TestRunWithCancel_WithTools(t *testing.T) { time.Sleep(100 * time.Millisecond) - err = cancelFn(WithCancelMode(CancelAfterToolCall)) + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterToolCalls)) + err = handle.Wait() assert.NoError(t, err) var events []*AgentEvent @@ -351,13 +381,120 @@ func TestRunWithCancel_WithTools(t *testing.T) { if !ok { break } - assert.Nil(t, event.Err, "Should not have error event after cancel") + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + continue + } events = append(events, event) } assert.True(t, len(events) > 0) assert.True(t, atomic.LoadInt32(&st.callCount) >= 1, "Tool should have been called") }) + + t.Run("NestedCancelPropagation", func(t *testing.T) { + cc := newCancelContext() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + child := cc.deriveChild(ctx) + assert.NotNil(t, child) + + cc.setMode(CancelImmediate) + + if atomic.CompareAndSwapInt32(&cc.state, stateRunning, stateCancelling) { + close(cc.cancelChan) + } + + select { + case <-child.cancelChan: + case <-time.After(1 * time.Second): + t.Fatal("Child did not receive cancel signal") + } + + assert.True(t, child.shouldCancel()) + assert.Equal(t, CancelImmediate, child.getMode()) + }) + + t.Run("DeepAgentIntegrationCancel", func(t *testing.T) { + ctx := context.Background() + modelStarted := make(chan struct{}, 1) + + leafModel := &cancelTestChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "Leaf result", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + leafModel.setDelay(500 * time.Millisecond) + leafAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "LeafAgent", + Description: "desc", + Model: leafModel, + }) + assert.NoError(t, err) + + rootModel := &cancelTestChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "LeafAgent", + Arguments: `{}`, + }, + }, + }, + }, + startedChan: make(chan struct{}, 1), + doneChan: make(chan struct{}, 1), + } + rootAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "RootAgent", + Description: "desc", + Model: rootModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{NewAgentTool(ctx, leafAgent)}, + }, + }, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: rootAgent, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Run leaf")}, cancelOpt) + + <-modelStarted + + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + err = handle.Wait() + assert.NoError(t, err) + + hasCancelError := false + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil { + var ce *CancelError + if errors.As(event.Err, &ce) { + hasCancelError = true + assert.NotNil(t, ce.interruptSignal, "CancelError should carry interrupt signal") + } + } + } + assert.True(t, hasCancelError, "Should have received CancelError") + }) } type slowToolWithSignal struct { @@ -386,14 +523,29 @@ func (t *slowToolWithSignal) InvokableRun(ctx context.Context, argumentsInJSON s } type simpleChatModel struct { + delay time.Duration response *schema.Message } func (m *simpleChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + if m.delay > 0 { + select { + case <-time.After(m.delay): + case <-ctx.Done(): + return nil, ctx.Err() + } + } return m.response, nil } func (m *simpleChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + if m.delay > 0 { + select { + case <-time.After(m.delay): + case <-ctx.Done(): + return nil, ctx.Err() + } + } return schema.StreamReaderFromArray([]*schema.Message{m.response}), nil } @@ -401,7 +553,7 @@ func (m *simpleChatModel) BindTools(tools []*schema.ToolInfo) error { return nil } -func TestRunWithCancel_WithCheckpoint(t *testing.T) { +func TestWithCancel_WithCheckpoint(t *testing.T) { ctx := context.Background() t.Run("CancelWithCheckpoint", func(t *testing.T) { @@ -409,7 +561,7 @@ func TestRunWithCancel_WithCheckpoint(t *testing.T) { st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") slowModel := &cancelTestChatModel{ - delay: 500 * time.Millisecond, + delayNs: int64(1 * time.Second), response: &schema.Message{ Role: schema.Assistant, Content: "", @@ -448,28 +600,38 @@ func TestRunWithCancel_WithCheckpoint(t *testing.T) { CheckPointStore: store, }) - iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}, WithCheckPointID("cancel-1")) + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt, WithCheckPointID("cancel-1")) <-modelStarted - err = cancelFn() + handle, _ := cancelFn() + err = handle.Wait() assert.NoError(t, err) var events []*AgentEvent + hasCancelError := false + var cancelErrorCheckPointID string for { event, ok := iter.Next() if !ok { break } - assert.Nil(t, event.Err, "Should not have error event after cancel") + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + hasCancelError = true + cancelErrorCheckPointID = ce.CheckPointID + continue + } events = append(events, event) } - assert.True(t, len(events) > 0) + assert.True(t, hasCancelError, "Should have CancelError event after cancel") + assert.Equal(t, "cancel-1", cancelErrorCheckPointID, "CancelError should contain the checkpoint ID") }) } -func TestCancelFuncMultipleCalls(t *testing.T) { +func TestAgentCancelFuncMultipleCalls(t *testing.T) { ctx := context.Background() t.Run("SecondCancelReturnsErrAgentFinished", func(t *testing.T) { @@ -477,7 +639,7 @@ func TestCancelFuncMultipleCalls(t *testing.T) { st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") slowModel := &cancelTestChatModel{ - delay: 1 * time.Second, + delayNs: int64(1 * time.Second), response: &schema.Message{ Role: schema.Assistant, Content: "", @@ -514,11 +676,13 @@ func TestCancelFuncMultipleCalls(t *testing.T) { EnableStreaming: false, }) - iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}) + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt) <-modelStarted - cancelErr := cancelFn() + handle, _ := cancelFn() + cancelErr := handle.Wait() assert.NoError(t, cancelErr) for { @@ -530,58 +694,7 @@ func TestCancelFuncMultipleCalls(t *testing.T) { }) } -func TestAgentNotCancellable(t *testing.T) { - ctx := context.Background() - - nonCancellableAgent := &nonCancellableTestAgent{ - name: "NonCancellable", - } - - runner := NewRunner(ctx, RunnerConfig{ - Agent: nonCancellableAgent, - EnableStreaming: false, - }) - - t.Run("RunWithCancelReturnsNilCancelFunc", func(t *testing.T) { - iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Hello")}) - assert.NotNil(t, iter) - assert.Nil(t, cancelFn) - - for { - _, ok := iter.Next() - if !ok { - break - } - } - }) -} - -type nonCancellableTestAgent struct { - name string -} - -func (a *nonCancellableTestAgent) Name(_ context.Context) string { - return a.name -} - -func (a *nonCancellableTestAgent) Description(_ context.Context) string { - return "A non-cancellable agent" -} - -func (a *nonCancellableTestAgent) Run(_ context.Context, input *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { - iter, gen := NewAsyncIteratorPair[*AgentEvent]() - gen.Send(&AgentEvent{ - Output: &AgentOutput{ - MessageOutput: &MessageVariant{ - Message: schema.AssistantMessage("Response", nil), - }, - }, - }) - gen.Close() - return iter -} - -func TestRunWithCancel_Streaming(t *testing.T) { +func TestWithCancel_Streaming(t *testing.T) { ctx := context.Background() t.Run("CancelImmediate_DuringModelStream", func(t *testing.T) { @@ -589,7 +702,7 @@ func TestRunWithCancel_Streaming(t *testing.T) { st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") slowModel := &cancelTestChatModel{ - delay: 2 * time.Second, + delayNs: int64(2 * time.Second), response: &schema.Message{ Role: schema.Assistant, Content: "", @@ -626,7 +739,8 @@ func TestRunWithCancel_Streaming(t *testing.T) { EnableStreaming: true, }) - iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}) + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt) assert.NotNil(t, iter) assert.NotNil(t, cancelFn) @@ -651,27 +765,30 @@ func TestRunWithCancel_Streaming(t *testing.T) { time.Sleep(100 * time.Millisecond) - cancelErr := cancelFn() + handle, _ := cancelFn() + cancelErr := handle.Wait() assert.NoError(t, cancelErr) - start := time.Now() - events := <-eventsCh - elapsed := time.Since(start) + var events []*AgentEvent + select { + case events = <-eventsCh: + case <-time.After(5 * time.Second): + t.Fatal("Timed out waiting for events") + } - assert.True(t, elapsed < 1*time.Second, "Should return quickly after cancel, elapsed: %v", elapsed) assert.True(t, len(events) > 0) - hasInterrupted := false + hasCancelError := false for _, e := range events { - assert.Nil(t, e.Err, "Should not have error event after cancel") - if e.Action != nil && e.Action.Interrupted != nil { - hasInterrupted = true + var ce *CancelError + if e.Err != nil && errors.As(e.Err, &ce) { + hasCancelError = true } } - assert.True(t, hasInterrupted, "Should have interrupted event after cancel") + assert.True(t, hasCancelError, "Should have CancelError event after cancel") }) - t.Run("CancelAfterToolCall_Streaming", func(t *testing.T) { + t.Run("CancelAfterToolCalls_Streaming", func(t *testing.T) { toolStarted := make(chan struct{}, 1) st := &slowToolWithSignal{ name: "slow_tool", @@ -715,7 +832,8 @@ func TestRunWithCancel_Streaming(t *testing.T) { EnableStreaming: true, }) - iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}) + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt) assert.NotNil(t, iter) assert.NotNil(t, cancelFn) @@ -723,7 +841,8 @@ func TestRunWithCancel_Streaming(t *testing.T) { time.Sleep(100 * time.Millisecond) - cancelErr := cancelFn(WithCancelMode(CancelAfterToolCall)) + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterToolCalls)) + cancelErr := handle.Wait() assert.NoError(t, cancelErr) var events []*AgentEvent @@ -732,7 +851,10 @@ func TestRunWithCancel_Streaming(t *testing.T) { if !ok { break } - assert.Nil(t, event.Err, "Should not have error event after cancel") + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + continue + } events = append(events, event) } @@ -741,26 +863,20 @@ func TestRunWithCancel_Streaming(t *testing.T) { }) } -// TestResumeWithCancel tests the workflow of Cancel followed by Resume. -// -// IMPORTANT: When Cancel is triggered, the cancelableChatModel.Generate/Stream -// method returns immediately with an Interrupt error, but the inner model's -// Generate/Stream call continues running in a background goroutine until completion. -// This means the original model instance's fields (e.g., delay, response) may still -// be read by the background goroutine after Cancel returns. +// TestWithCancel_Resume tests the workflow of Cancel followed by Resume. // // To avoid data races, we create new agent and runner instances for the Resume phase // instead of reusing and modifying the original model instance. -func TestResumeWithCancel(t *testing.T) { +func TestWithCancel_Resume(t *testing.T) { ctx := context.Background() - t.Run("RunWithCancel_ThenResumeWithCancel", func(t *testing.T) { + t.Run("Cancel_ThenResume", func(t *testing.T) { modelStarted := make(chan struct{}, 1) modelCallCount := int32(0) st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") slowModel := &cancelTestChatModel{ - delay: 500 * time.Millisecond, + delayNs: int64(500 * time.Millisecond), response: &schema.Message{ Role: schema.Assistant, Content: "", @@ -800,37 +916,38 @@ func TestResumeWithCancel(t *testing.T) { CheckPointStore: store, }) - iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}, WithCheckPointID(checkpointID)) + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt, WithCheckPointID(checkpointID)) <-modelStarted atomic.AddInt32(&modelCallCount, 1) - cancelErr := cancelFn() + handle, _ := cancelFn() + cancelErr := handle.Wait() assert.NoError(t, cancelErr) var events []*AgentEvent + hasCancelErr := false for { event, ok := iter.Next() if !ok { break } - assert.Nil(t, event.Err, "Should not have error event after cancel") - events = append(events, event) - } - assert.True(t, len(events) > 0) - - hasInterrupted := false - for _, e := range events { - if e.Action != nil && e.Action.Interrupted != nil { - hasInterrupted = true - break + if event.Err != nil { + var ce *CancelError + if errors.As(event.Err, &ce) { + hasCancelErr = true + continue + } + t.Fatalf("unexpected error: %v", event.Err) } + events = append(events, event) } - assert.True(t, hasInterrupted, "First run should have interrupted event") + assert.True(t, hasCancelErr, "Should have CancelError event after cancel") newModelStarted := make(chan struct{}, 1) slowModel2 := &cancelTestChatModel{ - delay: 100 * time.Millisecond, + delayNs: int64(100 * time.Millisecond), response: &schema.Message{ Role: schema.Assistant, Content: "Final response after resume", @@ -858,10 +975,10 @@ func TestResumeWithCancel(t *testing.T) { CheckPointStore: store, }) - resumeIter, resumeCancelFn, err := runner2.ResumeWithCancel(ctx, checkpointID) + resumeCancelOpt, _ := WithCancel() + resumeIter, err := runner2.Resume(ctx, checkpointID, resumeCancelOpt) assert.NoError(t, err) assert.NotNil(t, resumeIter) - assert.NotNil(t, resumeCancelFn) var resumeEvents []*AgentEvent for { @@ -876,14 +993,13 @@ func TestResumeWithCancel(t *testing.T) { assert.True(t, len(resumeEvents) > 0, "Resume should produce events") }) - t.Run("ResumeWithCancel_ThenCancel", func(t *testing.T) { + t.Run("Resume_ThenCancel", func(t *testing.T) { firstModelStarted := make(chan struct{}, 1) - resumeModelStarted := make(chan struct{}, 1) modelCallCount := int32(0) st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") slowModel := &cancelTestChatModel{ - delay: 500 * time.Millisecond, + delayNs: int64(500 * time.Millisecond), response: &schema.Message{ Role: schema.Assistant, Content: "", @@ -923,12 +1039,14 @@ func TestResumeWithCancel(t *testing.T) { CheckPointStore: store, }) - iter, cancelFn := runner.RunWithCancel(ctx, []Message{schema.UserMessage("Use the tool")}, WithCheckPointID(checkpointID)) + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt, WithCheckPointID(checkpointID)) <-firstModelStarted atomic.AddInt32(&modelCallCount, 1) - cancelErr := cancelFn() + handle, _ := cancelFn() + cancelErr := handle.Wait() assert.NoError(t, cancelErr) for { @@ -938,25 +1056,7 @@ func TestResumeWithCancel(t *testing.T) { } } - slowModel2 := &cancelTestChatModel{ - delay: 2 * time.Second, - response: &schema.Message{ - Role: schema.Assistant, - Content: "", - ToolCalls: []schema.ToolCall{ - { - ID: "call_1", - Type: "function", - Function: schema.FunctionCall{ - Name: "slow_tool", - Arguments: `{"input": "test"}`, - }, - }, - }, - }, - startedChan: resumeModelStarted, - doneChan: make(chan struct{}, 1), - } + slowModel2 := newBlockingChatModel(toolCallMsg(toolCall("call_1", "slow_tool", `{"input": "test"}`))) agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ Name: "TestAgent", @@ -977,7 +1077,8 @@ func TestResumeWithCancel(t *testing.T) { CheckPointStore: store, }) - resumeIter, resumeCancelFn, err := runner2.ResumeWithCancel(ctx, checkpointID) + resumeCancelOpt, resumeCancelFn := WithCancel() + resumeIter, err := runner2.Resume(ctx, checkpointID, resumeCancelOpt) assert.NoError(t, err) resumeEventsCh := make(chan []*AgentEvent, 1) @@ -993,13 +1094,13 @@ func TestResumeWithCancel(t *testing.T) { resumeEventsCh <- events }() - <-resumeModelStarted + <-slowModel2.started atomic.AddInt32(&modelCallCount, 1) - time.Sleep(100 * time.Millisecond) - - err = resumeCancelFn() - assert.NoError(t, err) + cancelHandle, _ := resumeCancelFn() + close(slowModel2.unblockCh) + err = cancelHandle.Wait() + assert.True(t, err == nil || errors.Is(err, ErrExecutionCompleted), "unexpected cancel wait error: %v", err) start := time.Now() resumeEvents := <-resumeEventsCh @@ -1008,13 +1109,2447 @@ func TestResumeWithCancel(t *testing.T) { assert.True(t, elapsed < 1*time.Second, "Resume should return quickly after cancel, elapsed: %v", elapsed) assert.True(t, len(resumeEvents) > 0) - hasInterrupted := false + hasCancelError := false for _, e := range resumeEvents { - assert.Nil(t, e.Err, "Should not have error event after resume cancel") - if e.Action != nil && e.Action.Interrupted != nil { - hasInterrupted = true + var ce *CancelError + if e.Err != nil && errors.As(e.Err, &ce) { + hasCancelError = true } } - assert.True(t, hasInterrupted, "Resume should have interrupted event after cancel") + executionCompletedBeforeCancel := errors.Is(err, ErrExecutionCompleted) + assert.True(t, hasCancelError || executionCompletedBeforeCancel, "Resume should have CancelError event after cancel, or execution completed before cancel") + }) +} + +func TestCancelMonitoredToolHandler_StreamableToolCall(t *testing.T) { + t.Run("NoCancelContext_PassesThrough", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + + // Create a stream with some data + r, w := schema.Pipe[string](1) + go func() { + w.Send("chunk1", nil) + w.Send("chunk2", nil) + w.Close() + }() + + next := func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { + return &compose.StreamToolOutput{Result: r}, nil + } + + wrapped := handler.WrapStreamableToolCall(next) + // No cancelContext in the Go context + output, err := wrapped(context.Background(), &compose.ToolInput{Name: "test"}) + assert.NoError(t, err) + + // Should get the original stream unchanged + chunk1, err := output.Result.Recv() + assert.NoError(t, err) + assert.Equal(t, "chunk1", chunk1) + + chunk2, err := output.Result.Recv() + assert.NoError(t, err) + assert.Equal(t, "chunk2", chunk2) + + _, err = output.Result.Recv() + assert.ErrorIs(t, err, io.EOF) + }) + + t.Run("WithCancelContext_NoCancel_StreamsNormally", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + cc := newCancelContext() + + r, w := schema.Pipe[string](1) + go func() { + w.Send("data1", nil) + w.Send("data2", nil) + w.Close() + }() + + next := func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { + return &compose.StreamToolOutput{Result: r}, nil + } + + wrapped := handler.WrapStreamableToolCall(next) + ctx := withCancelContext(context.Background(), cc) + output, err := wrapped(ctx, &compose.ToolInput{Name: "test"}) + assert.NoError(t, err) + + chunk1, err := output.Result.Recv() + assert.NoError(t, err) + assert.Equal(t, "data1", chunk1) + + chunk2, err := output.Result.Recv() + assert.NoError(t, err) + assert.Equal(t, "data2", chunk2) + + _, err = output.Result.Recv() + assert.ErrorIs(t, err, io.EOF) + }) + + t.Run("WithCancelContext_ImmediateCancel_TerminatesStream", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + cc := newCancelContext() + + // Create a slow stream that we'll cancel mid-way + r, w := schema.Pipe[string](1) + go func() { + defer w.Close() + w.Send("chunk1", nil) + time.Sleep(200 * time.Millisecond) + w.Send("chunk2", nil) + }() + + next := func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { + return &compose.StreamToolOutput{Result: r}, nil + } + + wrapped := handler.WrapStreamableToolCall(next) + ctx := withCancelContext(context.Background(), cc) + output, err := wrapped(ctx, &compose.ToolInput{Name: "test"}) + assert.NoError(t, err) + + // Read first chunk + chunk1, err := output.Result.Recv() + assert.NoError(t, err) + assert.Equal(t, "chunk1", chunk1) + + // Fire immediate cancel + close(cc.immediateChan) + + // Next recv should get ErrStreamCanceled + _, err = output.Result.Recv() + assert.ErrorIs(t, err, ErrStreamCanceled) + }) + + t.Run("WithCancelContext_AlreadyCancelled_TerminatesImmediately", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + cc := newCancelContext() + close(cc.immediateChan) // Already canceled + + r, w := schema.Pipe[string](1) + go func() { + w.Send("should-not-see", nil) + w.Close() + }() + + next := func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { + return &compose.StreamToolOutput{Result: r}, nil + } + + wrapped := handler.WrapStreamableToolCall(next) + ctx := withCancelContext(context.Background(), cc) + output, err := wrapped(ctx, &compose.ToolInput{Name: "test"}) + assert.NoError(t, err) + + _, err = output.Result.Recv() + assert.ErrorIs(t, err, ErrStreamCanceled) + }) + + t.Run("NextReturnsError_PropagatesError", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + cc := newCancelContext() + + nextErr := errors.New("tool execution failed") + next := func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { + return nil, nextErr + } + + wrapped := handler.WrapStreamableToolCall(next) + ctx := withCancelContext(context.Background(), cc) + _, err := wrapped(ctx, &compose.ToolInput{Name: "test"}) + assert.ErrorIs(t, err, nextErr) + }) +} + +func TestCancelMonitoredToolHandler_EnhancedStreamableToolCall(t *testing.T) { + t.Run("NoCancelContext_PassesThrough", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + + tr1 := &schema.ToolResult{Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: "chunk1"}}} + r, w := schema.Pipe[*schema.ToolResult](1) + go func() { + w.Send(tr1, nil) + w.Close() + }() + + next := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { + return &compose.EnhancedStreamableToolOutput{Result: r}, nil + } + + wrapped := handler.WrapEnhancedStreamableToolCall(next) + output, err := wrapped(context.Background(), &compose.ToolInput{Name: "test"}) + assert.NoError(t, err) + + result, err := output.Result.Recv() + assert.NoError(t, err) + assert.Equal(t, tr1, result) + + _, err = output.Result.Recv() + assert.ErrorIs(t, err, io.EOF) + }) + + t.Run("WithCancelContext_ImmediateCancel_TerminatesStream", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + cc := newCancelContext() + + tr1 := &schema.ToolResult{Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: "chunk1"}}} + tr2 := &schema.ToolResult{Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: "chunk2"}}} + r, w := schema.Pipe[*schema.ToolResult](1) + go func() { + defer w.Close() + w.Send(tr1, nil) + time.Sleep(200 * time.Millisecond) + w.Send(tr2, nil) + }() + + next := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { + return &compose.EnhancedStreamableToolOutput{Result: r}, nil + } + + wrapped := handler.WrapEnhancedStreamableToolCall(next) + ctx := withCancelContext(context.Background(), cc) + output, err := wrapped(ctx, &compose.ToolInput{Name: "test"}) + assert.NoError(t, err) + + result, err := output.Result.Recv() + assert.NoError(t, err) + assert.Equal(t, tr1, result) + + close(cc.immediateChan) + + _, err = output.Result.Recv() + assert.ErrorIs(t, err, ErrStreamCanceled) + }) + + t.Run("NextReturnsError_PropagatesError", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + cc := newCancelContext() + + nextErr := errors.New("enhanced tool failed") + next := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { + return nil, nextErr + } + + wrapped := handler.WrapEnhancedStreamableToolCall(next) + ctx := withCancelContext(context.Background(), cc) + _, err := wrapped(ctx, &compose.ToolInput{Name: "test"}) + assert.ErrorIs(t, err, nextErr) + }) +} + +func TestCancelContextKey(t *testing.T) { + t.Run("WithAndGet_RoundTrips", func(t *testing.T) { + cc := newCancelContext() + ctx := withCancelContext(context.Background(), cc) + got := getCancelContext(ctx) + assert.Equal(t, cc, got) + }) + + t.Run("Get_NoValue_ReturnsNil", func(t *testing.T) { + got := getCancelContext(context.Background()) + assert.Nil(t, got) }) + + t.Run("With_NilCancelContext_ReturnsOriginalCtx", func(t *testing.T) { + ctx := context.Background() + result := withCancelContext(ctx, nil) + assert.Equal(t, ctx, result) + }) +} + +// -- Tests for cancel support across all agent types -- + +// cancelTestAgent is a ChatModelAgent-based agent where the model blocks until +// signalled, allowing tests to control exactly when to issue a cancel. +func newCancelTestAgent(t *testing.T, name string, modelDelay time.Duration, modelStarted chan struct{}) *ChatModelAgent { + t.Helper() + slowModel := &cancelTestChatModel{ + delayNs: int64(modelDelay), + response: &schema.Message{ + Role: schema.Assistant, + Content: "response from " + name, + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: name, + Description: "Test agent " + name, + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + return agent +} + +func newCancelTestAgentWithTools(t *testing.T, name string, modelDelay time.Duration, modelStarted chan struct{}) *ChatModelAgent { + t.Helper() + toolName := name + "_tool" + slowModel := &cancelTestChatModel{ + delayNs: int64(modelDelay), + response: &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{{ + ID: "call_1", Type: "function", + Function: schema.FunctionCall{ + Name: toolName, + Arguments: `{"input": "test"}`, + }, + }}, + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + st := newSlowTool(toolName, 10*time.Millisecond, "tool result") + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: name, + Description: "Test agent " + name, + Instruction: "You are a test assistant", + Model: slowModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + return agent +} + +func newCancelTestAgentWithToolsFinalAnswer(t *testing.T, name string) *ChatModelAgent { + t.Helper() + toolName := name + "_tool" + finalModel := &cancelTestChatModel{ + delayNs: int64(10 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "final response from " + name, + }, + startedChan: make(chan struct{}, 1), + doneChan: make(chan struct{}, 1), + } + + st := newSlowTool(toolName, 10*time.Millisecond, "tool result") + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: name, + Description: "Test agent " + name, + Instruction: "You are a test assistant", + Model: finalModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + return agent +} + +func TestWithCancel_SequentialAgent(t *testing.T) { + ctx := context.Background() + + t.Run("CancelImmediate_DuringSecondAgent", func(t *testing.T) { + // The first agent completes quickly. The second agent takes a long time. + // Cancel during the second agent's model call. + agent1Started := make(chan struct{}, 1) + agent2Started := make(chan struct{}, 1) + + agent1 := newCancelTestAgent(t, "fast_agent", 50*time.Millisecond, agent1Started) + agent2 := newCancelTestAgent(t, "slow_agent", 5*time.Second, agent2Started) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq_agent", + Description: "Sequential test", + SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, + EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + // Wait for second agent to start + select { + case <-agent2Started: + case <-time.After(10 * time.Second): + t.Fatal("Second agent did not start") + } + + time.Sleep(50 * time.Millisecond) + + // Cancel should NOT return ErrExecutionCompleted (the bug before the fix) + handle, _ := cancelFn() + err = handle.Wait() + assert.NoError(t, err, "Cancel during second agent should succeed, not return ErrExecutionCompleted") + + var events []*AgentEvent + hasCancelError := false + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + hasCancelError = true + } + events = append(events, event) + } + + assert.True(t, hasCancelError, "Should have CancelError event") + }) +} + +func TestWithCancel_LoopAgent(t *testing.T) { + ctx := context.Background() + + t.Run("CancelImmediate_DuringIteration", func(t *testing.T) { + // Agent in a loop. Cancel during second iteration's model call. + modelStarted := make(chan struct{}, 10) + + slowModel := &cancelTestChatModel{ + delayNs: int64(3 * time.Second), + response: &schema.Message{ + Role: schema.Assistant, + Content: "loop response", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 10), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "loop_inner", + Description: "Inner loop agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loopAgent, err := NewLoopAgent(ctx, &LoopAgentConfig{ + Name: "loop_agent", + Description: "Loop test", + SubAgents: []Agent{agent}, + MaxIterations: 10, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: loopAgent, + EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + // Wait for first iteration's model call to start + select { + case <-modelStarted: + case <-time.After(10 * time.Second): + t.Fatal("Model did not start") + } + + time.Sleep(50 * time.Millisecond) + + // Cancel should succeed + handle, _ := cancelFn() + err = handle.Wait() + assert.NoError(t, err, "Cancel during loop iteration should succeed") + + hasCancelError := false + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + hasCancelError = true + } + } + + assert.True(t, hasCancelError, "Should have CancelError event") + }) +} + +func TestWithCancel_ParallelAgent(t *testing.T) { + ctx := context.Background() + + t.Run("CancelImmediate_InterruptsAllBranches", func(t *testing.T) { + agent1Started := make(chan struct{}, 1) + agent2Started := make(chan struct{}, 1) + + // Both agents have long delays, so cancel should interrupt both. + agent1 := newCancelTestAgent(t, "par_agent1", 5*time.Second, agent1Started) + agent2 := newCancelTestAgent(t, "par_agent2", 5*time.Second, agent2Started) + + parAgent, err := NewParallelAgent(ctx, &ParallelAgentConfig{ + Name: "par_agent", + Description: "Parallel test", + SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: parAgent, + EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + // Wait for both agents to start + for i := 0; i < 2; i++ { + select { + case <-agent1Started: + case <-agent2Started: + case <-time.After(10 * time.Second): + t.Fatal("Parallel agents did not start") + } + } + + time.Sleep(50 * time.Millisecond) + + start := time.Now() + handle, _ := cancelFn() + err = handle.Wait() + assert.NoError(t, err, "Cancel during parallel agents should succeed") + + var events []*AgentEvent + hasCancelError := false + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + hasCancelError = true + } + events = append(events, event) + } + elapsed := time.Since(start) + + assert.True(t, hasCancelError, "Should have CancelError event") + assert.True(t, elapsed < 3*time.Second, "Should complete quickly after cancel, elapsed: %v", elapsed) + }) +} + +func TestWithCancel_SupervisorAgent(t *testing.T) { + ctx := context.Background() + + t.Run("CancelImmediate_DuringSubAgent", func(t *testing.T) { + // Supervisor delegates to a slow sub-agent via transfer. + // Cancel during the sub-agent's model call. + supervisorModelStarted := make(chan struct{}, 1) + subAgentModelStarted := make(chan struct{}, 1) + + // The supervisor model returns a transfer_to_agent tool call + supervisorModel := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: TransferToAgentToolName, + Arguments: `{"agent_name": "slow_sub"}`, + }, + }, + }, + }, + } + + supervisorAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "supervisor", + Description: "Supervisor agent", + Instruction: "You are a supervisor", + Model: supervisorModel, + }) + assert.NoError(t, err) + + subAgent := newCancelTestAgent(t, "slow_sub", 5*time.Second, subAgentModelStarted) + + agentWithSubAgents, err := SetSubAgents(ctx, supervisorAgent, []Agent{subAgent}) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agentWithSubAgents, + EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + // Ignore the supervisor model start, wait for the sub-agent model + // The supervisor model is fast (simpleChatModel), so it will start and finish quickly + _ = supervisorModelStarted + select { + case <-subAgentModelStarted: + case <-time.After(10 * time.Second): + t.Fatal("Sub-agent model did not start") + } + + time.Sleep(50 * time.Millisecond) + + start := time.Now() + handle, _ := cancelFn() + err = handle.Wait() + assert.NoError(t, err, "Cancel during sub-agent should succeed") + + hasCancelError := false + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + hasCancelError = true + } + } + elapsed := time.Since(start) + + assert.True(t, hasCancelError, "Should have CancelError event") + assert.True(t, elapsed < 3*time.Second, "Should complete quickly after cancel, elapsed: %v", elapsed) + }) +} + +func TestFilterCancelOption(t *testing.T) { + t.Run("RemovesCancelOption", func(t *testing.T) { + cancelOpt, _ := WithCancel() + sessionOpt := WithSessionValues(map[string]any{"key": "value"}) + opts := []AgentRunOption{cancelOpt, sessionOpt} + + filtered := filterCancelOption(opts) + assert.Len(t, filtered, 1, "Should have removed the cancel option") + + // Verify the remaining option is the session option + testOpt := &options{} + filtered[0].implSpecificOptFn.(func(*options))(testOpt) + assert.NotNil(t, testOpt.sessionValues) + assert.Nil(t, testOpt.cancelCtx) + }) + + t.Run("KeepsNonCancelOptions", func(t *testing.T) { + sessionOpt := WithSessionValues(map[string]any{"key": "value"}) + callbackOpt := WithCallbacks() + opts := []AgentRunOption{sessionOpt, callbackOpt} + + filtered := filterCancelOption(opts) + assert.Len(t, filtered, 2, "Should keep all non-cancel options") + }) + + t.Run("EmptyInput", func(t *testing.T) { + filtered := filterCancelOption(nil) + assert.Nil(t, filtered) + }) +} + +func wrapIterWithMarkDone(iter *AsyncIterator[*AgentEvent], cc *cancelContext) *AsyncIterator[*AgentEvent] { + if cc == nil { + return iter + } + outIter, outGen := NewAsyncIteratorPair[*AgentEvent]() + go func() { + defer cc.markDone() + defer outGen.Close() + for { + event, ok := iter.Next() + if !ok { + return + } + outGen.Send(event) + } + }() + return outIter +} + +func TestWrapIterWithMarkDone(t *testing.T) { + t.Run("MarksDoneAfterDrain", func(t *testing.T) { + cc := newCancelContext() + iter, gen := NewAsyncIteratorPair[*AgentEvent]() + + go func() { + gen.Send(&AgentEvent{AgentName: "test"}) + gen.Close() + }() + + wrapped := wrapIterWithMarkDone(iter, cc) + + event, ok := wrapped.Next() + assert.True(t, ok) + assert.Equal(t, "test", event.AgentName) + + _, ok = wrapped.Next() + assert.False(t, ok) + + // markDone should have been called, so doneChan should be closed + select { + case <-cc.doneChan: + // good + case <-time.After(time.Second): + t.Fatal("doneChan was not closed after drain") + } + }) + + t.Run("NilCancelContext_PassesThrough", func(t *testing.T) { + iter, gen := NewAsyncIteratorPair[*AgentEvent]() + go func() { + gen.Send(&AgentEvent{AgentName: "test"}) + gen.Close() + }() + + wrapped := wrapIterWithMarkDone(iter, nil) + assert.Equal(t, iter, wrapped, "Should return same iter when cc is nil") + }) +} + +func TestGraphInterruptFuncs_Parallel(t *testing.T) { + t.Run("MultipleGraphInterruptFuncsAllCalled", func(t *testing.T) { + cc := newCancelContext() + + var called1, called2 int32 + cc.setGraphInterruptFunc(func(opts ...compose.GraphInterruptOption) { + atomic.AddInt32(&called1, 1) + }) + cc.setGraphInterruptFunc(func(opts ...compose.GraphInterruptOption) { + atomic.AddInt32(&called2, 1) + }) + + // Simulate immediate cancel + cc.setMode(CancelImmediate) + atomic.CompareAndSwapInt32(&cc.state, stateRunning, stateCancelling) + close(cc.cancelChan) + cc.sendImmediateInterrupt() + + assert.Equal(t, int32(1), atomic.LoadInt32(&called1), "First graph interrupt func should be called") + assert.Equal(t, int32(1), atomic.LoadInt32(&called2), "Second graph interrupt func should be called") + }) + + t.Run("RetroactiveFire_OnSetAfterCancel", func(t *testing.T) { + cc := newCancelContext() + + // First set up cancel state with immediate interrupt + cc.setMode(CancelImmediate) + atomic.CompareAndSwapInt32(&cc.state, stateRunning, stateCancelling) + close(cc.cancelChan) + close(cc.immediateChan) + atomic.StoreInt32(&cc.interruptSent, interruptImmediate) + + // Now register a new function - it should be retroactively fired + var called int32 + cc.setGraphInterruptFunc(func(opts ...compose.GraphInterruptOption) { + atomic.AddInt32(&called, 1) + }) + + assert.Equal(t, int32(1), atomic.LoadInt32(&called), "setGraphInterruptFunc should retroactively fire new func") + }) +} + +// -- Tests for transition-point cancel (cancel between sub-agents) -- + +// gatedChatModel is a model that: +// - Signals doneChan when Generate completes +// - Optionally blocks on gateChan before returning (nil gateChan = no blocking) +// - Tracks call count via callCount +type gatedChatModel struct { + response *schema.Message + gateChan chan struct{} // if non-nil, blocks until closed before returning + doneChan chan struct{} // signalled after Generate completes + callCount int32 +} + +func (m *gatedChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m.callCount, 1) + if m.gateChan != nil { + select { + case <-m.gateChan: + case <-ctx.Done(): + return nil, ctx.Err() + } + } + select { + case m.doneChan <- struct{}{}: + default: + } + return m.response, nil +} + +func (m *gatedChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + msg, err := m.Generate(ctx, input, opts...) + if err != nil { + return nil, err + } + return schema.StreamReaderFromArray([]*schema.Message{msg}), nil +} + +func (m *gatedChatModel) BindTools(tools []*schema.ToolInfo) error { + return nil +} + +func TestCheckCancel_Sequential_BetweenSubAgents(t *testing.T) { + ctx := context.Background() + + // CancelAfterToolCalls fires at transition boundaries between sub-agents. + // At a transition boundary, the completed sub-agent's entire execution + // (including any tool calls) is done, satisfying the CancelAfterToolCalls + // contract — even if this particular sub-agent had no tools. + model1 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent1 done"}, + gateChan: make(chan struct{}), + doneChan: make(chan struct{}, 1), + } + model2 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent2 done"}, + doneChan: make(chan struct{}, 1), + } + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent1", Description: "first", Instruction: "test", Model: model1, + }) + assert.NoError(t, err) + + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent2", Description: "second", Instruction: "test", Model: model2, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq", Description: "sequential test", SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hello")}, cancelOpt) + + for atomic.LoadInt32(&model1.callCount) == 0 { + runtime.Gosched() + } + + cancelCalled, result := cancelAsync(cancelFn, WithAgentCancelMode(CancelAfterToolCalls)) + waitForChan(t, cancelCalled, "cancelFn was not called") + close(model1.gateChan) + + assert.NoError(t, result.waitDone(t), "CancelAfterToolCalls should succeed at transition boundary") + + for { + _, ok := iter.Next() + if !ok { + break + } + } + + assert.Equal(t, int32(1), atomic.LoadInt32(&model1.callCount), "Agent1 model should be invoked") + assert.Equal(t, int32(0), atomic.LoadInt32(&model2.callCount), + "Agent2 model should NOT be invoked (CancelAfterToolCalls caught at transition)") +} + +func TestCheckCancel_Loop_BetweenIterations(t *testing.T) { + ctx := context.Background() + + // CancelAfterToolCalls fires at loop iteration boundaries. + // After the first iteration completes, any tool calls it made are done, + // satisfying the CancelAfterToolCalls contract. + mdl := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "loop iter"}, + gateChan: make(chan struct{}), + doneChan: make(chan struct{}, 10), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "loop_inner", Description: "inner", Instruction: "test", Model: mdl, + }) + assert.NoError(t, err) + + loopAgent, err := NewLoopAgent(ctx, &LoopAgentConfig{ + Name: "loop", Description: "loop test", SubAgents: []Agent{agent}, MaxIterations: 3, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: loopAgent, EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hello")}, cancelOpt) + + for atomic.LoadInt32(&mdl.callCount) == 0 { + runtime.Gosched() + } + + cancelCalled, result := cancelAsync(cancelFn, WithAgentCancelMode(CancelAfterToolCalls)) + waitForChan(t, cancelCalled, "cancelFn was not called") + close(mdl.gateChan) + + assert.NoError(t, result.waitDone(t), "CancelAfterToolCalls should succeed at loop transition boundary") + + for { + _, ok := iter.Next() + if !ok { + break + } + } + + assert.Equal(t, int32(1), atomic.LoadInt32(&mdl.callCount), + "Model should be called once; second iteration caught at transition") +} + +func TestCheckCancel_Parallel_PreSpawn(t *testing.T) { + ctx := context.Background() + + // Cancel fires before Run is called. Neither model should be invoked. + model1 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "par1"}, + doneChan: make(chan struct{}, 1), + } + model2 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "par2"}, + doneChan: make(chan struct{}, 1), + } + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "par1", Description: "first", Instruction: "test", Model: model1, + }) + assert.NoError(t, err) + + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "par2", Description: "second", Instruction: "test", Model: model2, + }) + assert.NoError(t, err) + + parAgent, err := NewParallelAgent(ctx, &ParallelAgentConfig{ + Name: "par", Description: "parallel test", SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + // Fire cancel in goroutine (cancelFn blocks until handled) + cancelOpt, cancelFn := WithCancel() + cancelDone := make(chan error, 1) + go func() { + handle, _ := cancelFn() + cancelDone <- handle.Wait() + }() + // Wait for cancelChan to be closed (happens synchronously before the blocking doneChan wait) + time.Sleep(20 * time.Millisecond) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: parAgent, EnableStreaming: false, + }) + + iter := runner.Run(ctx, []Message{schema.UserMessage("hello")}, cancelOpt) + + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + cancelErr = ce + } + } + + // cancelFn should have completed + select { + case err = <-cancelDone: + assert.NoError(t, err) + case <-time.After(5 * time.Second): + t.Fatal("cancelFn did not return") + } + + assert.NotNil(t, cancelErr, "Should have CancelError") + assert.Equal(t, int32(0), atomic.LoadInt32(&model1.callCount), "First model should never be invoked") + assert.Equal(t, int32(0), atomic.LoadInt32(&model2.callCount), "Second model should never be invoked") +} + +func TestCheckCancel_Transfer_BeforeTarget(t *testing.T) { + ctx := context.Background() + + // Supervisor CMA returns a transfer action (instantly). + // Cancel fires after transfer action but before target runs. + // Target model should never be invoked. + supervisorModel := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{{ + ID: "call_1", Type: "function", + Function: schema.FunctionCall{ + Name: TransferToAgentToolName, + Arguments: `{"agent_name": "target"}`, + }, + }}, + }, + } + targetModel := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "target done"}, + doneChan: make(chan struct{}, 1), + } + + supervisorAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "supervisor", Description: "supervisor", Instruction: "test", Model: supervisorModel, + }) + assert.NoError(t, err) + + targetAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "target", Description: "target", Instruction: "test", Model: targetModel, + }) + assert.NoError(t, err) + + agentWithSub, err := SetSubAgents(ctx, supervisorAgent, []Agent{targetAgent}) + assert.NoError(t, err) + + // Fire cancel in goroutine (cancelFn blocks until handled) + cancelOpt, cancelFn := WithCancel() + cancelDone := make(chan error, 1) + go func() { + handle, _ := cancelFn() + cancelDone <- handle.Wait() + }() + time.Sleep(20 * time.Millisecond) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agentWithSub, EnableStreaming: false, + }) + + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + cancelErr = ce + } + } + + select { + case err = <-cancelDone: + assert.NoError(t, err) + case <-time.After(5 * time.Second): + t.Fatal("cancelFn did not return") + } + + assert.NotNil(t, cancelErr, "Should have CancelError") + assert.Equal(t, int32(0), atomic.LoadInt32(&targetModel.callCount), "Target model should never be invoked") +} + +func TestCheckCancel_AlreadyHandled_NoDuplicate(t *testing.T) { + ctx := context.Background() + + // In a sequential agent, if the first CMA handles the cancel (graph interrupt), + // the workflow's transition check should NOT emit a duplicate CancelError. + // Use a slow model so cancel fires during its execution (handled by CMA). + modelStarted := make(chan struct{}, 1) + model1 := &cancelTestChatModel{ + delayNs: int64(2 * time.Second), + response: &schema.Message{Role: schema.Assistant, Content: "agent1"}, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + model2 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent2"}, + doneChan: make(chan struct{}, 1), + } + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent1", Description: "first", Instruction: "test", Model: model1, + }) + assert.NoError(t, err) + + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent2", Description: "second", Instruction: "test", Model: model2, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq", Description: "sequential", SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + // Wait for model to start, then cancel during model execution + select { + case <-modelStarted: + case <-time.After(5 * time.Second): + t.Fatal("Model did not start") + } + time.Sleep(50 * time.Millisecond) + handle, _ := cancelFn() + err = handle.Wait() + assert.NoError(t, err) + + cancelCount := 0 + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + cancelCount++ + } + } + + assert.Equal(t, 1, cancelCount, "Should have exactly one CancelError, no duplicate from workflow transition") + assert.Equal(t, int32(0), atomic.LoadInt32(&model2.callCount), "Second agent should not run") +} + +// Tests for CancelAfterChatModel/CancelAfterToolCalls in nested workflow structures. +// These verify that safe-point cancel modes propagate through the entire agent hierarchy +// and fire at whichever nested level reaches the safe-point first. + +func TestCancel_SequentialWorkflow_CancelAfterChatModel(t *testing.T) { + ctx := context.Background() + agent1Started := make(chan struct{}, 1) + + agent1 := newCancelTestAgentWithTools(t, "seq_slow", 500*time.Millisecond, agent1Started) + agent2 := newCancelTestAgentWithTools(t, "seq_fast", 50*time.Millisecond, make(chan struct{}, 1)) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq_agent", + Description: "Sequential workflow", + SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt, WithCheckPointID("seq-cancel-1")) + + select { + case <-agent1Started: + case <-time.After(10 * time.Second): + t.Fatal("First agent did not start") + } + + handle, contributed := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + assert.True(t, contributed, "Cancel should contribute") + err = handle.Wait() + assert.NoError(t, err) + + hasCancelError := false + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil && errors.As(event.Err, &cancelErr) { + hasCancelError = true + } + } + + assert.True(t, hasCancelError, "Should have CancelError") + assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) + assert.NotEmpty(t, cancelErr.CheckPointID, "CancelError should have checkpoint ID") + assert.NotNil(t, cancelErr.interruptSignal, "CancelError should have interrupt signal for checkpoint") + + resumeAgent1 := newCancelTestAgentWithToolsFinalAnswer(t, "seq_slow") + resumeAgent2 := newCancelTestAgentWithToolsFinalAnswer(t, "seq_fast") + + resumeSeq, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq_agent", + Description: "Sequential workflow", + SubAgents: []Agent{resumeAgent1, resumeAgent2}, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: resumeSeq, + CheckPointStore: store, + }) + + resumeIter, err := runner2.Resume(ctx, "seq-cancel-1") + assert.NoError(t, err) + assert.NotNil(t, resumeIter) + + var resumeEvents []*AgentEvent + for { + event, ok := resumeIter.Next() + if !ok { + break + } + assert.Nil(t, event.Err, "Should not have error during resume") + resumeEvents = append(resumeEvents, event) + } + assert.True(t, len(resumeEvents) > 0, "Resume should produce events") +} + +// -- Tests for CancelImmediate in nested agent structures -- + +func newTestChatModel(response *schema.Message, delay time.Duration) *cancelTestChatModel { + m := &cancelTestChatModel{ + response: response, + startedChan: make(chan struct{}, 1), + doneChan: make(chan struct{}, 1), + } + if delay > 0 { + m.setDelay(delay) + } + return m +} + +func newToolCallResponse(toolName string) *schema.Message { + return &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + {ID: "call_1", Type: "function", Function: schema.FunctionCall{Name: toolName, Arguments: `{}`}}, + }, + } +} + +func newAgentWithTool(t *testing.T, ctx context.Context, name string, mdl model.BaseChatModel, subAgent Agent) (Agent, error) { + t.Helper() + return NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: name, + Description: name, + Model: mdl, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{NewAgentTool(ctx, subAgent)}, + }, + }, + }) +} + +func waitForChan(t *testing.T, ch <-chan struct{}, msg string) { + t.Helper() + select { + case <-ch: + case <-time.After(10 * time.Second): + t.Fatal(msg) + } +} + +func drainCancelError(t *testing.T, iter *AsyncIterator[*AgentEvent]) *CancelError { + t.Helper() + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil { + errors.As(event.Err, &cancelErr) + } + } + return cancelErr +} + +func drainResumeErrors(t *testing.T, iter *AsyncIterator[*AgentEvent]) []error { + t.Helper() + var errs []error + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil { + errs = append(errs, event.Err) + } + } + return errs +} + +type cancelResult struct { + err error + contributed bool + done chan struct{} +} + +func cancelAsync(cancelFn AgentCancelFunc, opts ...AgentCancelOption) (cancelCalled chan struct{}, result *cancelResult) { + cancelCalled = make(chan struct{}) + result = &cancelResult{done: make(chan struct{})} + go func() { + handle, contributed := cancelFn(opts...) + result.contributed = contributed + close(cancelCalled) + result.err = handle.Wait() + close(result.done) + }() + return +} + +func (r *cancelResult) waitDone(t *testing.T) error { + t.Helper() + select { + case <-r.done: + return r.err + case <-time.After(10 * time.Second): + t.Fatal("cancel did not complete") + return nil + } +} + +func TestCancelImmediate_AgentTool_PreservesChildCheckpoint(t *testing.T) { + ctx := context.Background() + + leafModel := newTestChatModel( + &schema.Message{Role: schema.Assistant, Content: "leaf response"}, 2*time.Second) + leafAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "leaf_agent", Description: "Leaf agent in agentTool", Model: leafModel, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "inner_seq", Description: "Inner sequential workflow", SubAgents: []Agent{leafAgent}, + }) + assert.NoError(t, err) + + rootModel := newTestChatModel(newToolCallResponse("inner_seq"), 0) + rootAgent, err := newAgentWithTool(t, ctx, "root_agent", rootModel, seqAgent) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{Agent: rootAgent, CheckPointStore: store}) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt, WithCheckPointID("immediate-agent-tool-1")) + + waitForChan(t, leafModel.startedChan, "Leaf agent model did not start") + + handle, contributed := cancelFn() + assert.True(t, contributed) + assert.NoError(t, handle.Wait()) + + cancelErr := drainCancelError(t, iter) + assert.NotNil(t, cancelErr, "Should have CancelError from CancelImmediate through agentTool") + assert.NotEmpty(t, cancelErr.CheckPointID) + assert.NotNil(t, cancelErr.interruptSignal) + + resumeLeaf, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "leaf_agent", Description: "Leaf agent in agentTool", + Model: newTestChatModel(&schema.Message{Role: schema.Assistant, Content: "resumed leaf"}, 0), + }) + assert.NoError(t, err) + resumeSeq, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "inner_seq", Description: "Inner sequential workflow", SubAgents: []Agent{resumeLeaf}, + }) + assert.NoError(t, err) + resumeRoot, err := newAgentWithTool(t, ctx, "root_agent", + newTestChatModel(&schema.Message{Role: schema.Assistant, Content: "resumed root"}, 0), resumeSeq) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{Agent: resumeRoot, CheckPointStore: store}) + resumeIter, err := runner2.Resume(ctx, "immediate-agent-tool-1") + assert.NoError(t, err) + assert.Empty(t, drainResumeErrors(t, resumeIter), "Resume should complete without errors") +} + +func TestCancelImmediate_ParallelWorkflow_WithAgentTool(t *testing.T) { + ctx := context.Background() + + leafModel := newTestChatModel( + &schema.Message{Role: schema.Assistant, Content: "leaf response"}, 2*time.Second) + leafAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "leaf_agent", Description: "Leaf agent in agentTool", Model: leafModel, + }) + assert.NoError(t, err) + + agentWithTool, err := newAgentWithTool(t, ctx, "agent_with_tool", + newTestChatModel(newToolCallResponse("leaf_agent"), 0), leafAgent) + assert.NoError(t, err) + + simpleStarted := make(chan struct{}, 1) + simpleAgent := newCancelTestAgent(t, "simple_agent", 2*time.Second, simpleStarted) + + parAgent, err := NewParallelAgent(ctx, &ParallelAgentConfig{ + Name: "par_agent", Description: "Parallel with agentTool and simple agent", + SubAgents: []Agent{agentWithTool, simpleAgent}, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{Agent: parAgent, EnableStreaming: false}) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + waitForChan(t, leafModel.startedChan, "Leaf agent did not start") + waitForChan(t, simpleStarted, "Simple agent did not start") + + start := time.Now() + handle, _ := cancelFn() + assert.NoError(t, handle.Wait()) + + cancelErr := drainCancelError(t, iter) + elapsed := time.Since(start) + + assert.NotNil(t, cancelErr, "Should have CancelError from parallel with agentTool") + assert.True(t, elapsed < 5*time.Second, "Should complete quickly after cancel, elapsed: %v", elapsed) +} + +type cancelUnawareAgent struct { + name string + desc string + delay time.Duration + response string +} + +type multiResponseGatedModel struct { + responses []*schema.Message + gateChan chan struct{} + gateOnce bool + gated int32 + doneChan chan struct{} + callCount int32 +} + +func (m *multiResponseGatedModel) Generate(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { + idx := atomic.AddInt32(&m.callCount, 1) + if m.gateChan != nil && (!m.gateOnce || atomic.CompareAndSwapInt32(&m.gated, 0, 1)) { + select { + case <-m.gateChan: + case <-ctx.Done(): + return nil, ctx.Err() + } + } + if len(m.responses) == 0 { + return nil, fmt.Errorf("multiResponseGatedModel: no responses configured") + } + resp := m.responses[(int(idx)-1)%len(m.responses)] + if m.doneChan != nil { + select { + case m.doneChan <- struct{}{}: + default: + } + } + return resp, nil +} + +func (m *multiResponseGatedModel) Stream(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + resp, err := m.Generate(ctx, msgs, opts...) + if err != nil { + return nil, err + } + return schema.StreamReaderFromArray([]*schema.Message{resp}), nil +} + +func (m *multiResponseGatedModel) BindTools(tools []*schema.ToolInfo) error { return nil } + +func (a *cancelUnawareAgent) Name(_ context.Context) string { return a.name } +func (a *cancelUnawareAgent) Description(_ context.Context) string { return a.desc } + +func (a *cancelUnawareAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, gen := NewAsyncIteratorPair[*AgentEvent]() + go func() { + defer gen.Close() + // Intentionally ignores ctx.Done() — simulates a custom agent that + // does not participate in the cancel protocol at all. + // Delay is kept short (relative to grace period) to avoid goroutine + // leak lasting long after the test completes. + time.Sleep(a.delay) + }() + return iter +} + +func TestCancelImmediate_CustomAgent_GracePeriodFallback(t *testing.T) { + ctx := context.Background() + + customAgent := &cancelUnawareAgent{ + name: "custom_slow", desc: "A custom agent that ignores cancel", + delay: 5 * time.Second, response: "custom response", + } + + rootModel := newTestChatModel(newToolCallResponse("custom_slow"), 0) + rootAgent, err := newAgentWithTool(t, ctx, "root_agent", rootModel, customAgent) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{Agent: rootAgent, EnableStreaming: false}) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + waitForChan(t, rootModel.startedChan, "Root model did not start") + waitForChan(t, rootModel.doneChan, "Root model did not finish") + + start := time.Now() + handle, _ := cancelFn() + assert.NoError(t, handle.Wait()) + + cancelErr := drainCancelError(t, iter) + elapsed := time.Since(start) + + assert.NotNil(t, cancelErr, "Should have CancelError (from grace period fallback)") + assert.True(t, elapsed < 5*time.Second, + "Should complete within grace period + overhead, elapsed: %v", elapsed) +} + +func TestCancelImmediate_MultiLevelNesting(t *testing.T) { + ctx := context.Background() + + innerLeafModel := newTestChatModel( + &schema.Message{Role: schema.Assistant, Content: "inner leaf response"}, 2*time.Second) + innerLeafAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "inner_leaf", Description: "Innermost leaf agent", Model: innerLeafModel, + }) + assert.NoError(t, err) + + middleAgent, err := newAgentWithTool(t, ctx, "middle_agent", + newTestChatModel(newToolCallResponse("inner_leaf"), 0), innerLeafAgent) + assert.NoError(t, err) + + rootAgent, err := newAgentWithTool(t, ctx, "root_agent", + newTestChatModel(newToolCallResponse("middle_agent"), 0), middleAgent) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{Agent: rootAgent, CheckPointStore: store}) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt, WithCheckPointID("multi-level-1")) + + waitForChan(t, innerLeafModel.startedChan, "Inner leaf model did not start") + + start := time.Now() + handle, contributed := cancelFn() + assert.True(t, contributed) + assert.NoError(t, handle.Wait()) + + cancelErr := drainCancelError(t, iter) + elapsed := time.Since(start) + + assert.NotNil(t, cancelErr, "Should have CancelError from multi-level nesting") + assert.NotEmpty(t, cancelErr.CheckPointID) + assert.NotNil(t, cancelErr.interruptSignal) + assert.True(t, elapsed < 5*time.Second, "Should complete quickly, elapsed: %v", elapsed) + + resumeInnerLeaf, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "inner_leaf", Description: "Innermost leaf agent", + Model: newTestChatModel(&schema.Message{Role: schema.Assistant, Content: "resumed inner leaf"}, 0), + }) + assert.NoError(t, err) + resumeMiddle, err := newAgentWithTool(t, ctx, "middle_agent", + newTestChatModel(&schema.Message{Role: schema.Assistant, Content: "resumed middle"}, 0), resumeInnerLeaf) + assert.NoError(t, err) + resumeRoot, err := newAgentWithTool(t, ctx, "root_agent", + newTestChatModel(&schema.Message{Role: schema.Assistant, Content: "resumed root"}, 0), resumeMiddle) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{Agent: resumeRoot, CheckPointStore: store}) + resumeIter, err := runner2.Resume(ctx, "multi-level-1") + assert.NoError(t, err) + assert.Empty(t, drainResumeErrors(t, resumeIter), "Resume should complete without errors") +} + +func TestCancelImmediate_SequentialTransitionBoundary(t *testing.T) { + ctx := context.Background() + + model1 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent1 done"}, + gateChan: make(chan struct{}), + doneChan: make(chan struct{}, 1), + } + model2 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent2 done"}, + doneChan: make(chan struct{}, 1), + } + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent1", Description: "first", Instruction: "test", Model: model1, + }) + assert.NoError(t, err) + + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent2", Description: "second", Instruction: "test", Model: model2, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq", Description: "sequential test", SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hello")}, cancelOpt) + + for atomic.LoadInt32(&model1.callCount) == 0 { + runtime.Gosched() + } + + cancelCalled, result := cancelAsync(cancelFn) + waitForChan(t, cancelCalled, "cancelFn was not called") + close(model1.gateChan) + + assert.NoError(t, result.waitDone(t), "CancelImmediate should succeed at transition") + + cancelErr := drainCancelError(t, iter) + + assert.NotNil(t, cancelErr, "Should have CancelError at transition boundary") + assert.Equal(t, int32(1), atomic.LoadInt32(&model1.callCount), "Agent1 model should be invoked") + assert.Equal(t, int32(0), atomic.LoadInt32(&model2.callCount), "Agent2 model should NOT be invoked (caught at transition)") +} + +func TestCancelImmediate_LoopTransitionBoundary(t *testing.T) { + ctx := context.Background() + + mdl := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "loop iter"}, + gateChan: make(chan struct{}), + doneChan: make(chan struct{}, 10), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "loop_inner", Description: "inner", Instruction: "test", Model: mdl, + }) + assert.NoError(t, err) + + loopAgent, err := NewLoopAgent(ctx, &LoopAgentConfig{ + Name: "loop", Description: "loop test", SubAgents: []Agent{agent}, MaxIterations: 5, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: loopAgent, EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hello")}, cancelOpt) + + for atomic.LoadInt32(&mdl.callCount) == 0 { + runtime.Gosched() + } + + cancelCalled, result := cancelAsync(cancelFn) + waitForChan(t, cancelCalled, "cancelFn was not called") + close(mdl.gateChan) + + assert.NoError(t, result.waitDone(t), "CancelImmediate should succeed at loop transition") + + for { + _, ok := iter.Next() + if !ok { + break + } + } + + assert.Equal(t, int32(1), atomic.LoadInt32(&mdl.callCount), + "Model should be called once; second iteration caught at transition") +} + +func TestCancelAfterChatModel_SequentialTransitionBoundary(t *testing.T) { + ctx := context.Background() + + model1 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent1 done"}, + gateChan: make(chan struct{}), + doneChan: make(chan struct{}, 1), + } + model2 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent2 done"}, + doneChan: make(chan struct{}, 1), + } + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent1", Description: "first", Instruction: "test", Model: model1, + }) + assert.NoError(t, err) + + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent2", Description: "second", Instruction: "test", Model: model2, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq", Description: "sequential test", SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, + EnableStreaming: false, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hello")}, cancelOpt, WithCheckPointID("chatmodel-transition-1")) + + for atomic.LoadInt32(&model1.callCount) == 0 { + runtime.Gosched() + } + + cancelCalled, result := cancelAsync(cancelFn, WithAgentCancelMode(CancelAfterChatModel)) + waitForChan(t, cancelCalled, "cancelFn was not called") + close(model1.gateChan) + + assert.NoError(t, result.waitDone(t), "CancelAfterChatModel should succeed at transition boundary") + + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + cancelErr = ce + } + } + + assert.NotNil(t, cancelErr, "Should have CancelError at transition boundary") + assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) + assert.Equal(t, int32(1), atomic.LoadInt32(&model1.callCount), "Agent1 model should be invoked") + assert.Equal(t, int32(0), atomic.LoadInt32(&model2.callCount), + "Agent2 model should NOT be invoked (CancelAfterChatModel caught at transition)") +} + +func TestCancelAfterChatModel_Sequential_Agent1CompletesCancelBeforeAgent2Resume(t *testing.T) { + ctx := context.Background() + + model1 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent1 done"}, + gateChan: make(chan struct{}), + doneChan: make(chan struct{}, 1), + } + model2 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent2 done"}, + doneChan: make(chan struct{}, 1), + } + model3 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent3 done"}, + doneChan: make(chan struct{}, 1), + } + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent1", Description: "first", Instruction: "test", Model: model1, + }) + assert.NoError(t, err) + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent2", Description: "second", Instruction: "test", Model: model2, + }) + assert.NoError(t, err) + agent3, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent3", Description: "third", Instruction: "test", Model: model3, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq", Description: "3-agent sequential", SubAgents: []Agent{agent1, agent2, agent3}, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, CheckPointStore: store, EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hello")}, cancelOpt, + WithCheckPointID("seq-transition-resume-1")) + + for atomic.LoadInt32(&model1.callCount) == 0 { + runtime.Gosched() + } + + cancelCalled, result := cancelAsync(cancelFn, WithAgentCancelMode(CancelAfterChatModel)) + waitForChan(t, cancelCalled, "cancelFn was not called") + close(model1.gateChan) + + assert.NoError(t, result.waitDone(t)) + + cancelErr := drainCancelError(t, iter) + assert.NotNil(t, cancelErr, "Should have CancelError") + assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) + assert.Equal(t, int32(1), atomic.LoadInt32(&model1.callCount)) + assert.Equal(t, int32(0), atomic.LoadInt32(&model2.callCount), + "Agent2 should NOT run (cancel caught at transition after agent1)") + assert.Equal(t, int32(0), atomic.LoadInt32(&model3.callCount)) + assert.NotEmpty(t, cancelErr.CheckPointID) + + resumeModel2 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "resumed agent2"}, + doneChan: make(chan struct{}, 1), + } + resumeModel3 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "resumed agent3"}, + doneChan: make(chan struct{}, 1), + } + + resumeAgent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent1", Description: "first", Instruction: "test", + Model: &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "should not run"}, + doneChan: make(chan struct{}, 1), + }, + }) + assert.NoError(t, err) + resumeAgent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent2", Description: "second", Instruction: "test", Model: resumeModel2, + }) + assert.NoError(t, err) + resumeAgent3, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent3", Description: "third", Instruction: "test", Model: resumeModel3, + }) + assert.NoError(t, err) + + resumeSeq, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq", Description: "3-agent sequential", + SubAgents: []Agent{resumeAgent1, resumeAgent2, resumeAgent3}, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: resumeSeq, CheckPointStore: store, EnableStreaming: false, + }) + resumeIter, err := runner2.Resume(ctx, "seq-transition-resume-1") + assert.NoError(t, err) + assert.Empty(t, drainResumeErrors(t, resumeIter), "Resume should complete without errors") + + assert.Equal(t, int32(1), atomic.LoadInt32(&resumeModel2.callCount), + "Agent2 should run on resume") + assert.Equal(t, int32(1), atomic.LoadInt32(&resumeModel3.callCount), + "Agent3 should run on resume") +} + +func TestCancelAfterToolCalls_LoopTransitionBoundary(t *testing.T) { + ctx := context.Background() + + // Model that returns tool calls on odd calls and no tools on even calls. + // This completes one ReAct cycle per pair of calls: + // call 1 (gated): returns tool call → tool runs → call 2: returns no tools → END + // The gate only blocks the very first call. After that, all calls proceed instantly. + mdl := &multiResponseGatedModel{ + responses: []*schema.Message{ + {Role: schema.Assistant, ToolCalls: []schema.ToolCall{{ + ID: "call_1", Type: "function", + Function: schema.FunctionCall{Name: "loop_tool", Arguments: `{"input": "test"}`}, + }}}, + {Role: schema.Assistant, Content: "iteration done"}, + }, + gateChan: make(chan struct{}), + gateOnce: true, + doneChan: make(chan struct{}, 10), + } + + st := &slowTool{ + name: "loop_tool", + delay: 10 * time.Millisecond, + result: "tool done", + startedChan: make(chan struct{}, 10), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "loop_inner", Description: "inner", Instruction: "test", Model: mdl, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + loopAgent, err := NewLoopAgent(ctx, &LoopAgentConfig{ + Name: "loop", Description: "loop test", SubAgents: []Agent{agent}, MaxIterations: 10, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{Agent: loopAgent, CheckPointStore: store}) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt, WithCheckPointID("toolcalls-loop-1")) + + // Wait for the model to be entered (blocked on gate) + for atomic.LoadInt32(&mdl.callCount) == 0 { + runtime.Gosched() + } + + // Fire cancel, wait for it to be registered, then release the gate + cancelCalled, result := cancelAsync(cancelFn, WithAgentCancelMode(CancelAfterToolCalls)) + waitForChan(t, cancelCalled, "cancelFn was not called") + close(mdl.gateChan) + + // Iteration 1 completes fully (model→tool→model-no-tools→END). + // The CancelAfterToolCalls safe-point inside ReAct fires after tool calls, + // OR the transition boundary catches it before iteration 2. + // Note: this test doesn't deterministically distinguish which path fires — + // both are semantically correct for CancelAfterToolCalls. The transition- + // boundary code path for CancelAfterToolCalls in loops is not definitively + // covered here because the ReAct safe-point may handle it first. + assert.NoError(t, result.waitDone(t)) + + cancelErr := drainCancelError(t, iter) + assert.NotNil(t, cancelErr, "Should have CancelError from CancelAfterToolCalls in loop") + assert.Equal(t, CancelAfterToolCalls, cancelErr.Info.Mode) + assert.NotEmpty(t, cancelErr.CheckPointID) +} + +func TestCancelContext_ActiveChildren_Tracking(t *testing.T) { + t.Run("DeriveChild_IncrementsActiveChildren", func(t *testing.T) { + parent := newCancelContext() + assert.False(t, parent.hasActiveChildren()) + + ctx := context.Background() + child := parent.deriveChild(ctx) + assert.True(t, parent.hasActiveChildren()) + assert.Equal(t, int32(1), atomic.LoadInt32(&parent.activeChildren)) + + child.markDone() + time.Sleep(10 * time.Millisecond) + assert.False(t, parent.hasActiveChildren()) + assert.Equal(t, int32(0), atomic.LoadInt32(&parent.activeChildren)) + }) + + t.Run("MultipleChildren_AllTracked", func(t *testing.T) { + parent := newCancelContext() + ctx := context.Background() + + child1 := parent.deriveChild(ctx) + child2 := parent.deriveChild(ctx) + assert.Equal(t, int32(2), atomic.LoadInt32(&parent.activeChildren)) + + child1.markDone() + time.Sleep(10 * time.Millisecond) + assert.Equal(t, int32(1), atomic.LoadInt32(&parent.activeChildren)) + assert.True(t, parent.hasActiveChildren()) + + child2.markDone() + time.Sleep(10 * time.Millisecond) + assert.False(t, parent.hasActiveChildren()) + }) + + t.Run("MarkCancelHandled_AlsoDecrementsParent", func(t *testing.T) { + parent := newCancelContext() + ctx := context.Background() + + child := parent.deriveChild(ctx) + assert.True(t, parent.hasActiveChildren()) + + child.triggerCancel(CancelImmediate) + child.markCancelHandled() + time.Sleep(10 * time.Millisecond) + assert.False(t, parent.hasActiveChildren()) + }) + + t.Run("GracePeriodWrapper_AppliesWhenChildrenActive", func(t *testing.T) { + parent := newCancelContext() + ctx := context.Background() + + var receivedOpts []compose.GraphInterruptOption + mockInterrupt := func(opts ...compose.GraphInterruptOption) { + receivedOpts = opts + } + + wrapped := parent.wrapGraphInterruptWithGracePeriod(mockInterrupt) + + // No children: no options appended + receivedOpts = nil + wrapped() + assert.Empty(t, receivedOpts, "Should pass no extra options when no children") + + // With active child: one timeout option appended + _ = parent.deriveChild(ctx) + receivedOpts = nil + wrapped() + assert.Len(t, receivedOpts, 1, "Should add exactly one timeout option when children are active") + + // Caller-provided options are preserved, grace period option appended after + receivedOpts = nil + callerOpt := compose.WithGraphInterruptTimeout(0) + wrapped(callerOpt) + assert.Len(t, receivedOpts, 2, + "Should append timeout option after caller-provided options when children are active") + // Note: verifying the exact timeout value (defaultCancelImmediateGracePeriod) + // requires access to unexported compose.graphInterruptOptions. The integration + // tests (TestCancelImmediate_AgentTool_PreservesChildCheckpoint) verify the + // actual behavioral effect — child interrupts propagate within the grace period. + }) +} + +func TestCancel_ParallelWorkflow_CancelAfterChatModel(t *testing.T) { + ctx := context.Background() + slowStarted := make(chan struct{}, 1) + + slowAgent := newCancelTestAgentWithTools(t, "par_slow", 1*time.Second, slowStarted) + fastAgent := newCancelTestAgentWithTools(t, "par_fast", 50*time.Millisecond, make(chan struct{}, 1)) + + parAgent, err := NewParallelAgent(ctx, &ParallelAgentConfig{ + Name: "par_agent", + Description: "Parallel workflow", + SubAgents: []Agent{slowAgent, fastAgent}, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: parAgent, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt, WithCheckPointID("par-cancel-1")) + + select { + case <-slowStarted: + case <-time.After(10 * time.Second): + t.Fatal("Slow agent did not start") + } + + handle, contributed := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + assert.True(t, contributed, "Cancel should contribute") + err = handle.Wait() + assert.NoError(t, err) + + hasCancelError := false + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil && errors.As(event.Err, &cancelErr) { + hasCancelError = true + } + } + + assert.True(t, hasCancelError, "Should have CancelError from parallel workflow") + assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) + assert.NotEmpty(t, cancelErr.CheckPointID, "CancelError should have checkpoint ID") + + resumeSlow := newCancelTestAgentWithToolsFinalAnswer(t, "par_slow") + resumeFast := newCancelTestAgentWithToolsFinalAnswer(t, "par_fast") + + resumePar, err := NewParallelAgent(ctx, &ParallelAgentConfig{ + Name: "par_agent", + Description: "Parallel workflow", + SubAgents: []Agent{resumeSlow, resumeFast}, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: resumePar, + CheckPointStore: store, + }) + + resumeIter, err := runner2.Resume(ctx, "par-cancel-1") + assert.NoError(t, err) + assert.NotNil(t, resumeIter) + + var resumeErrors []error + for { + event, ok := resumeIter.Next() + if !ok { + break + } + if event.Err != nil { + resumeErrors = append(resumeErrors, event.Err) + } + } + assert.Empty(t, resumeErrors, "Resume should complete without errors") +} + +func TestCancel_LoopWorkflow_CancelAfterChatModel(t *testing.T) { + ctx := context.Background() + modelStarted := make(chan struct{}, 10) + + agent := newCancelTestAgentWithTools(t, "loop_inner", 500*time.Millisecond, modelStarted) + + loopAgent, err := NewLoopAgent(ctx, &LoopAgentConfig{ + Name: "loop_agent", + Description: "Loop workflow", + SubAgents: []Agent{agent}, + MaxIterations: 10, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: loopAgent, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt, WithCheckPointID("loop-cancel-1")) + + select { + case <-modelStarted: + case <-time.After(10 * time.Second): + t.Fatal("Model did not start") + } + + handle, contributed := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + assert.True(t, contributed, "Cancel should contribute") + err = handle.Wait() + assert.NoError(t, err) + + hasCancelError := false + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil && errors.As(event.Err, &cancelErr) { + hasCancelError = true + } + } + + assert.True(t, hasCancelError, "Should have CancelError from loop workflow") + assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) + assert.NotEmpty(t, cancelErr.CheckPointID, "CancelError should have checkpoint ID") + + resumeAgent := newCancelTestAgentWithToolsFinalAnswer(t, "loop_inner") + + resumeLoop, err := NewLoopAgent(ctx, &LoopAgentConfig{ + Name: "loop_agent", + Description: "Loop workflow", + SubAgents: []Agent{resumeAgent}, + MaxIterations: 10, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: resumeLoop, + CheckPointStore: store, + }) + + resumeIter, err := runner2.Resume(ctx, "loop-cancel-1") + assert.NoError(t, err) + assert.NotNil(t, resumeIter) + + var resumeEvents []*AgentEvent + for { + event, ok := resumeIter.Next() + if !ok { + break + } + assert.Nil(t, event.Err, "Should not have error during resume") + resumeEvents = append(resumeEvents, event) + } + assert.True(t, len(resumeEvents) > 0, "Resume should produce events") +} + +func TestCancel_NestedWorkflow_AgentTool_CancelAfterChatModel(t *testing.T) { + // Structure: Runner -> RootCMA (with tools) -> agentTool -> flowAgent -> seqWorkflow -> LeafCMA + ctx := context.Background() + leafStarted := make(chan struct{}, 1) + + leafModel := &cancelTestChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "leaf response", + }, + startedChan: leafStarted, + doneChan: make(chan struct{}, 1), + } + leafModel.setDelay(500 * time.Millisecond) + + leafAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "leaf_agent", + Description: "Leaf agent in workflow", + Model: leafModel, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "inner_seq", + Description: "Inner sequential workflow", + SubAgents: []Agent{leafAgent}, + }) + assert.NoError(t, err) + + rootModel := &cancelTestChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "inner_seq", + Arguments: `{}`, + }, + }, + }, + }, + startedChan: make(chan struct{}, 1), + doneChan: make(chan struct{}, 1), + } + rootAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "root_agent", + Description: "Root agent", + Model: rootModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{NewAgentTool(ctx, seqAgent)}, + }, + }, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: rootAgent, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt, WithCheckPointID("nested-cancel-1")) + + select { + case <-leafStarted: + case <-time.After(10 * time.Second): + t.Fatal("Leaf agent model did not start") + } + + handle, contributed := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + assert.True(t, contributed, "Cancel should contribute") + err = handle.Wait() + assert.NoError(t, err) + + hasCancelError := false + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil && errors.As(event.Err, &cancelErr) { + hasCancelError = true + } + } + + assert.True(t, hasCancelError, "Should have CancelError from deeply nested workflow") + assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) + assert.NotEmpty(t, cancelErr.CheckPointID, "CancelError should have checkpoint ID") + assert.NotNil(t, cancelErr.interruptSignal, "CancelError should carry interrupt signal through agent tree") + + // Phase 2: Resume from checkpoint — new instances to avoid data races + resumeLeafModel := &cancelTestChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "resumed leaf response", + }, + startedChan: make(chan struct{}, 1), + doneChan: make(chan struct{}, 1), + } + resumeLeaf, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "leaf_agent", + Description: "Leaf agent in workflow", + Model: resumeLeafModel, + }) + assert.NoError(t, err) + + resumeSeq, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "inner_seq", + Description: "Inner sequential workflow", + SubAgents: []Agent{resumeLeaf}, + }) + assert.NoError(t, err) + + resumeRootModel := &cancelTestChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "resumed root response", + }, + startedChan: make(chan struct{}, 1), + doneChan: make(chan struct{}, 1), + } + resumeRoot, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "root_agent", + Description: "Root agent", + Model: resumeRootModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{NewAgentTool(ctx, resumeSeq)}, + }, + }, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: resumeRoot, + CheckPointStore: store, + }) + + resumeIter, err := runner2.Resume(ctx, "nested-cancel-1") + assert.NoError(t, err) + assert.NotNil(t, resumeIter) + + var resumeErrors []error + for { + event, ok := resumeIter.Next() + if !ok { + break + } + if event.Err != nil { + resumeErrors = append(resumeErrors, event.Err) + } + } + assert.Empty(t, resumeErrors, "Resume should complete without errors") +} + +func TestCancel_CancelAfterToolCalls_InSequentialWorkflow(t *testing.T) { + ctx := context.Background() + toolStarted := make(chan struct{}, 1) + + st := &slowTool{ + name: "slow_tool", + delay: 200 * time.Millisecond, + result: "tool done", + startedChan: toolStarted, + } + + modelWithToolCall := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + } + + agentWithTools, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent_with_tools", + Description: "Agent with slow tool", + Model: modelWithToolCall, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq_agent", + Description: "Sequential workflow with tool agent", + SubAgents: []Agent{agentWithTools}, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt, WithCheckPointID("tool-cancel-1")) + + select { + case <-toolStarted: + case <-time.After(10 * time.Second): + t.Fatal("Tool did not start") + } + + // Cancel after tool calls — should wait for the tool to finish, then cancel + handle, contributed := cancelFn(WithAgentCancelMode(CancelAfterToolCalls)) + assert.True(t, contributed, "Cancel should contribute") + err = handle.Wait() + assert.NoError(t, err) + + hasCancelError := false + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil && errors.As(event.Err, &cancelErr) { + hasCancelError = true + } + } + + assert.True(t, hasCancelError, "Should have CancelError after tool calls complete") + assert.Equal(t, CancelAfterToolCalls, cancelErr.Info.Mode) + assert.NotEmpty(t, cancelErr.CheckPointID, "CancelError should have checkpoint ID") + + // Phase 2: Resume from checkpoint — new instances + resumeTool := &slowTool{ + name: "slow_tool", + delay: 50 * time.Millisecond, + result: "resumed tool done", + startedChan: make(chan struct{}, 1), + } + + resumeModel := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "resumed response after tool", + }, + } + + resumeAgentWithTools, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent_with_tools", + Description: "Agent with slow tool", + Model: resumeModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{resumeTool}, + }, + }, + }) + assert.NoError(t, err) + + resumeSeq, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq_agent", + Description: "Sequential workflow with tool agent", + SubAgents: []Agent{resumeAgentWithTools}, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: resumeSeq, + CheckPointStore: store, + }) + + resumeIter, err := runner2.Resume(ctx, "tool-cancel-1") + assert.NoError(t, err) + assert.NotNil(t, resumeIter) + + var resumeEvents []*AgentEvent + for { + event, ok := resumeIter.Next() + if !ok { + break + } + assert.Nil(t, event.Err, "Should not have error during resume") + resumeEvents = append(resumeEvents, event) + } + assert.True(t, len(resumeEvents) > 0, "Resume should produce events") } diff --git a/adk/cancel_wrapper.go b/adk/cancel_wrapper.go deleted file mode 100644 index 396d6a458..000000000 --- a/adk/cancel_wrapper.go +++ /dev/null @@ -1,295 +0,0 @@ -/* - * Copyright 2026 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package adk - -import ( - "context" - "io" - "reflect" - "runtime/debug" - "time" - - "github.com/cloudwego/eino/components" - "github.com/cloudwego/eino/components/model" - "github.com/cloudwego/eino/compose" - "github.com/cloudwego/eino/internal/generic" - "github.com/cloudwego/eino/internal/safe" - "github.com/cloudwego/eino/schema" -) - -type cancelWaitResult[T any] struct { - result T - err error - cancelled bool -} - -func waitWithCancel[T any](cs *cancelSig, resultCh <-chan cancelWaitResult[T]) cancelWaitResult[T] { - var timeCh <-chan time.Time - select { - case <-cs.done: - cfg := cs.config.Load().(*cancelConfig) - if cfg.Mode == CancelImmediate { - if cfg.Timeout == nil { - return cancelWaitResult[T]{cancelled: true} - } - timeCh = time.After(*cfg.Timeout) - } - case res := <-resultCh: - return res - } - select { - case <-timeCh: - return cancelWaitResult[T]{cancelled: true} - case res := <-resultCh: - return res - } -} - -type cancelableChatModel struct { - inner model.BaseChatModel - cs *cancelSig -} - -func wrapModelForCancelable(m model.BaseChatModel, cs *cancelSig) *cancelableChatModel { - return &cancelableChatModel{inner: m, cs: cs} -} - -func (c *cancelableChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { - if cfg := checkCancelSig(c.cs); cfg != nil && cfg.Mode == CancelImmediate { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - - resultCh := make(chan cancelWaitResult[*schema.Message], 1) - go func() { - defer func() { - if panicErr := recover(); panicErr != nil { - resultCh <- cancelWaitResult[*schema.Message]{err: safe.NewPanicErr(panicErr, debug.Stack())} - } - }() - res, err := c.inner.Generate(ctx, input, opts...) - resultCh <- cancelWaitResult[*schema.Message]{result: res, err: err} - }() - - res := waitWithCancel(c.cs, resultCh) - if res.cancelled { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - return res.result, res.err -} - -func (c *cancelableChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { - if cfg := checkCancelSig(c.cs); cfg != nil && cfg.Mode == CancelImmediate { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - - resultCh := make(chan cancelWaitResult[*schema.StreamReader[*schema.Message]], 1) - go func() { - defer func() { - if panicErr := recover(); panicErr != nil { - resultCh <- cancelWaitResult[*schema.StreamReader[*schema.Message]]{err: safe.NewPanicErr(panicErr, debug.Stack())} - } - }() - - stream, err := c.inner.Stream(ctx, input, opts...) - if err != nil { - resultCh <- cancelWaitResult[*schema.StreamReader[*schema.Message]]{err: err} - return - } - copies := stream.Copy(2) - _ = consumeStreamForError(copies[0]) - resultCh <- cancelWaitResult[*schema.StreamReader[*schema.Message]]{result: copies[1]} - }() - - res := waitWithCancel(c.cs, resultCh) - if res.cancelled { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - return res.result, res.err -} - -func (c *cancelableChatModel) IsCallbacksEnabled() bool { - return components.IsCallbacksEnabled(c.inner) -} - -func (c *cancelableChatModel) GetType() string { - if name, ok := components.GetType(c.inner); ok { - return name - } - - return generic.ParseTypeName(reflect.ValueOf(c.inner)) -} - -func cancelableToolInvokable(cs *cancelSig, endpoint compose.InvokableToolEndpoint) compose.InvokableToolEndpoint { - return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { - if cfg := checkCancelSig(cs); cfg != nil && cfg.Mode == CancelImmediate { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - - resultCh := make(chan cancelWaitResult[*compose.ToolOutput], 1) - go func() { - defer func() { - if panicErr := recover(); panicErr != nil { - resultCh <- cancelWaitResult[*compose.ToolOutput]{err: safe.NewPanicErr(panicErr, debug.Stack())} - } - }() - output, err := endpoint(ctx, input) - resultCh <- cancelWaitResult[*compose.ToolOutput]{result: output, err: err} - }() - - res := waitWithCancel(cs, resultCh) - if res.cancelled { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - return res.result, res.err - } -} - -func cancelableToolStreamable(cs *cancelSig, endpoint compose.StreamableToolEndpoint) compose.StreamableToolEndpoint { - return func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { - if cfg := checkCancelSig(cs); cfg != nil && cfg.Mode == CancelImmediate { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - - resultCh := make(chan cancelWaitResult[*schema.StreamReader[string]], 1) - go func() { - defer func() { - if panicErr := recover(); panicErr != nil { - resultCh <- cancelWaitResult[*schema.StreamReader[string]]{err: safe.NewPanicErr(panicErr, debug.Stack())} - } - }() - output, err := endpoint(ctx, input) - if err != nil { - resultCh <- cancelWaitResult[*schema.StreamReader[string]]{err: err} - return - } - copies := output.Result.Copy(2) - _ = consumeStreamForErrorString(copies[0]) - resultCh <- cancelWaitResult[*schema.StreamReader[string]]{result: copies[1]} - }() - - res := waitWithCancel(cs, resultCh) - if res.cancelled { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - if res.err != nil { - return nil, res.err - } - return &compose.StreamToolOutput{Result: res.result}, nil - } -} - -func cancelableToolEnhancedInvokable(cs *cancelSig, endpoint compose.EnhancedInvokableToolEndpoint) compose.EnhancedInvokableToolEndpoint { - return func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedInvokableToolOutput, error) { - if cfg := checkCancelSig(cs); cfg != nil && cfg.Mode == CancelImmediate { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - - resultCh := make(chan cancelWaitResult[*compose.EnhancedInvokableToolOutput], 1) - go func() { - defer func() { - if panicErr := recover(); panicErr != nil { - resultCh <- cancelWaitResult[*compose.EnhancedInvokableToolOutput]{err: safe.NewPanicErr(panicErr, debug.Stack())} - } - }() - output, err := endpoint(ctx, input) - resultCh <- cancelWaitResult[*compose.EnhancedInvokableToolOutput]{result: output, err: err} - }() - - res := waitWithCancel(cs, resultCh) - if res.cancelled { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - return res.result, res.err - } -} - -func cancelableToolEnhancedStreamable(cs *cancelSig, endpoint compose.EnhancedStreamableToolEndpoint) compose.EnhancedStreamableToolEndpoint { - return func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { - if cfg := checkCancelSig(cs); cfg != nil && cfg.Mode == CancelImmediate { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - - resultCh := make(chan cancelWaitResult[*schema.StreamReader[*schema.ToolResult]], 1) - go func() { - defer func() { - if panicErr := recover(); panicErr != nil { - resultCh <- cancelWaitResult[*schema.StreamReader[*schema.ToolResult]]{err: safe.NewPanicErr(panicErr, debug.Stack())} - } - }() - output, err := endpoint(ctx, input) - if err != nil { - resultCh <- cancelWaitResult[*schema.StreamReader[*schema.ToolResult]]{err: err} - return - } - copies := output.Result.Copy(2) - _ = consumeStreamForErrorToolResult(copies[0]) - resultCh <- cancelWaitResult[*schema.StreamReader[*schema.ToolResult]]{result: copies[1]} - }() - - res := waitWithCancel(cs, resultCh) - if res.cancelled { - return nil, compose.Interrupt(ctx, "cancelled externally") - } - if res.err != nil { - return nil, res.err - } - return &compose.EnhancedStreamableToolOutput{Result: res.result}, nil - } -} - -func cancelableTool(cs *cancelSig) compose.ToolMiddleware { - return compose.ToolMiddleware{ - Invokable: func(endpoint compose.InvokableToolEndpoint) compose.InvokableToolEndpoint { - return cancelableToolInvokable(cs, endpoint) - }, - Streamable: func(endpoint compose.StreamableToolEndpoint) compose.StreamableToolEndpoint { - return cancelableToolStreamable(cs, endpoint) - }, - EnhancedInvokable: func(endpoint compose.EnhancedInvokableToolEndpoint) compose.EnhancedInvokableToolEndpoint { - return cancelableToolEnhancedInvokable(cs, endpoint) - }, - EnhancedStreamable: func(endpoint compose.EnhancedStreamableToolEndpoint) compose.EnhancedStreamableToolEndpoint { - return cancelableToolEnhancedStreamable(cs, endpoint) - }, - } -} - -func consumeStreamForErrorString(stream *schema.StreamReader[string]) error { - defer stream.Close() - for { - _, err := stream.Recv() - if err == io.EOF { - return nil - } - if err != nil { - return err - } - } -} - -func consumeStreamForErrorToolResult(stream *schema.StreamReader[*schema.ToolResult]) error { - defer stream.Close() - for { - _, err := stream.Recv() - if err == io.EOF { - return nil - } - if err != nil { - return err - } - } -} diff --git a/adk/chatmodel.go b/adk/chatmodel.go index 0848a859a..a0126455a 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -39,8 +39,6 @@ import ( ) var _ ResumableAgent = &ChatModelAgent{} -var _ CancellableAgent = &ChatModelAgent{} -var _ CancellableResumableAgent = &ChatModelAgent{} type chatModelAgentExecCtx struct { runtimeReturnDirectly map[string]bool @@ -345,8 +343,19 @@ type ChatModelAgent struct { exeCtx *execContext } -type runFunc func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], - store *bridgeStore, instruction string, returnDirectly map[string]bool, cs *cancelSig, opts ...compose.Option) +// runParams holds the parameters for a runFunc invocation. +type runParams struct { + input *AgentInput + generator *AsyncGenerator[*AgentEvent] + store *bridgeStore + instruction string + returnDirectly map[string]bool + cancelCtx *cancelContext + cancelCtxOwned bool + composeOpts []compose.Option +} + +type runFunc func(ctx context.Context, p *runParams) // NewChatModelAgent constructs a chat model-backed agent with the provided config. func NewChatModelAgent(ctx context.Context, config *ChatModelAgentConfig) (*ChatModelAgent, error) { @@ -378,6 +387,16 @@ func NewChatModelAgent(ctx context.Context, config *ChatModelAgentConfig) (*Chat ) tc.ToolCallMiddlewares = append(tc.ToolCallMiddlewares, collectToolMiddlewaresFromMiddlewares(config.Middlewares)...) + // Cancel monitoring middleware (innermost — close to the tool endpoint). + // This allows early abort of the raw tool result stream when immediateChan fires + // (CancelImmediate or timeout escalation), while requiring outer wrappers to + // propagate stream errors such as ErrStreamCanceled without swallowing them. + cancelToolHandler := &cancelMonitoredToolHandler{} + tc.ToolCallMiddlewares = append(tc.ToolCallMiddlewares, compose.ToolMiddleware{ + Streamable: cancelToolHandler.WrapStreamableToolCall, + EnhancedStreamable: cancelToolHandler.WrapEnhancedStreamableToolCall, + }) + return &ChatModelAgent{ name: config.Name, description: config.Description, @@ -575,8 +594,8 @@ func setOutputToSession(ctx context.Context, msg Message, msgStream MessageStrea } func errFunc(err error) runFunc { - return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, _ string, _ map[string]bool, _ *cancelSig, _ ...compose.Option) { - generator.Send(&AgentEvent{Err: err}) + return func(ctx context.Context, p *runParams) { + p.generator.Send(&AgentEvent{Err: err}) } } @@ -696,25 +715,74 @@ func (a *ChatModelAgent) prepareExecContext(ctx context.Context) (*execContext, }, nil } +// handleRunFuncError is the common error handler for buildNoToolsRunFunc and buildReActRunFunc. +// It handles compose interrupts (both cancel-triggered and business) +// and generic errors, sending the appropriate event to the generator. +func (a *ChatModelAgent) handleRunFuncError( + ctx context.Context, + err error, + cancelCtx *cancelContext, + cancelCtxOwned bool, + store *bridgeStore, + generator *AsyncGenerator[*AgentEvent], +) { + info, ok := compose.ExtractInterruptInfo(err) + if ok { + if cancelCtx != nil { + // Note: there is a benign TOCTOU window here. Between shouldCancel() + // returning false and markDone() executing, a concurrent cancel could + // transition stateRunning→stateCancelling. markDone() then does + // stateCancelling→stateDone, and the cancel func receives + // ErrExecutionCompleted (execution finished before cancel took effect). + if !cancelCtx.shouldCancel() { + cancelCtx.markDone() + } + } + + data, existed, sErr := store.Get(ctx, bridgeCheckpointID) + if sErr != nil { + generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("failed to get interrupt info: %w", sErr)}) + return + } + if !existed { + generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("interrupt occurred but checkpoint data is missing")}) + return + } + + is := FromInterruptContexts(info.InterruptContexts) + event := CompositeInterrupt(ctx, info, data, is) + event.Action.Interrupted.Data = &ChatModelAgentInterruptInfo{ + Info: info, + Data: data, + } + event.AgentName = a.name + generator.Send(event) + return + } + + if cancelCtxOwned && cancelCtx != nil { + cancelCtx.markDone() + } + generator.Send(&AgentEvent{Err: err}) +} + func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { type noToolsInput struct { input *AgentInput instruction string } - return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], - store *bridgeStore, instruction string, _ map[string]bool, cs *cancelSig, opts ...compose.Option) { + return func(ctx context.Context, p *runParams) { + cancelCtx := p.cancelCtx + ctx = withCancelContext(ctx, cancelCtx) wrappedModel := buildModelWrappers(a.model, &modelWrapperConfig{ - handlers: a.handlers, - middlewares: a.middlewares, - retryConfig: a.modelRetryConfig, + handlers: a.handlers, + middlewares: a.middlewares, + retryConfig: a.modelRetryConfig, + cancelContext: cancelCtx, }) - if cs != nil { - wrappedModel = wrapModelForCancelable(wrappedModel, cs) - } - chain := compose.NewChain[noToolsInput, Message]( compose.WithGenLocalState(func(ctx context.Context) (state *State) { return &State{} @@ -724,37 +792,63 @@ func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { if err != nil { return nil, err } + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + st.Messages = append(st.Messages, messages...) + return nil + }) return messages, nil })). AppendChatModel(wrappedModel) - r, err := chain.Compile(ctx, compose.WithGraphName(a.name), - compose.WithCheckPointStore(store), + var compileOptions []compose.GraphCompileOption + compileOptions = append(compileOptions, + compose.WithGraphName(a.name), + compose.WithCheckPointStore(p.store), compose.WithSerializer(&gobSerializer{})) + + if cancelCtx != nil { + var interrupt func(...compose.GraphInterruptOption) + ctx, interrupt = compose.WithGraphInterrupt(ctx) + cancelCtx.setGraphInterruptFunc(cancelCtx.wrapGraphInterruptWithGracePeriod(interrupt)) + } + + r, err := chain.Compile(ctx, compileOptions...) if err != nil { - generator.Send(&AgentEvent{Err: err}) + p.generator.Send(&AgentEvent{Err: err}) return } ctx = withChatModelAgentExecCtx(ctx, &chatModelAgentExecCtx{ - generator: generator, + generator: p.generator, }) - in := noToolsInput{input: input, instruction: instruction} + // Pre-execution cancel check + if cancelCtx != nil && cancelCtx.shouldCancel() { + if cancelCtx.getMode() == CancelImmediate || atomic.LoadInt32(&cancelCtx.escalated) == 1 { + cancelErr, ok := cancelCtx.createAndMarkCancelHandled() + if !ok { + return + } + p.generator.Send(&AgentEvent{Err: cancelErr}) + return + } + } + + in := noToolsInput{input: p.input, instruction: p.instruction} var msg Message var msgStream MessageStream - if input.EnableStreaming { - msgStream, err = r.Stream(ctx, in, opts...) + if p.input.EnableStreaming { + msgStream, err = r.Stream(ctx, in, p.composeOpts...) } else { - msg, err = r.Invoke(ctx, in, opts...) + msg, err = r.Invoke(ctx, in, p.composeOpts...) } if err == nil { if a.outputKey != "" { err = setOutputToSession(ctx, msg, msgStream, a.outputKey) if err != nil { - generator.Send(&AgentEvent{Err: err}) + p.generator.Send(&AgentEvent{Err: err}) } } else if msgStream != nil { msgStream.Close() @@ -762,30 +856,7 @@ func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { return } - info, ok := compose.ExtractInterruptInfo(err) - if !ok { - generator.Send(&AgentEvent{Err: err}) - return - } - - data, existed, sErr := store.Get(ctx, bridgeCheckpointID) - if sErr != nil { - generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("failed to get interrupt info: %w", sErr)}) - return - } - if !existed { - generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("interrupt occurred but checkpoint data is missing")}) - return - } - - is := FromInterruptContexts(info.InterruptContexts) - event := CompositeInterrupt(ctx, info, data, is) - event.Action.Interrupted.Data = &ChatModelAgentInterruptInfo{ - Info: info, - Data: data, - } - event.AgentName = a.name - generator.Send(event) + a.handleRunFuncError(ctx, err, cancelCtx, p.cancelCtxOwned, p.store, p.generator) } } @@ -809,11 +880,17 @@ func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) ( instruction string } - return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, - instruction string, returnDirectly map[string]bool, cs *cancelSig, opts ...compose.Option) { - g, err := newReact(ctx, conf, cs) + return func(ctx context.Context, p *runParams) { + cancelCtx := p.cancelCtx + conf.cancelCtx = cancelCtx + if conf.modelWrapperConf != nil { + conf.modelWrapperConf.cancelContext = cancelCtx + } + ctx = withCancelContext(ctx, cancelCtx) + + g, err := newReact(ctx, conf) if err != nil { - generator.Send(&AgentEvent{Err: err}) + p.generator.Send(&AgentEvent{Err: err}) return } @@ -825,7 +902,7 @@ func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) ( return nil, genErr } return &reactInput{ - messages: messages, + Messages: messages, }, nil }), ). @@ -834,38 +911,56 @@ func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) ( var compileOptions []compose.GraphCompileOption compileOptions = append(compileOptions, compose.WithGraphName(a.name), - compose.WithCheckPointStore(store), + compose.WithCheckPointStore(p.store), compose.WithSerializer(&gobSerializer{}), compose.WithMaxRunSteps(math.MaxInt)) + if cancelCtx != nil { + var interrupt func(...compose.GraphInterruptOption) + ctx, interrupt = compose.WithGraphInterrupt(ctx) + cancelCtx.setGraphInterruptFunc(cancelCtx.wrapGraphInterruptWithGracePeriod(interrupt)) + } + runnable, err_ := chain.Compile(ctx, compileOptions...) if err_ != nil { - generator.Send(&AgentEvent{Err: err_}) + p.generator.Send(&AgentEvent{Err: err_}) return } ctx = withChatModelAgentExecCtx(ctx, &chatModelAgentExecCtx{ - runtimeReturnDirectly: returnDirectly, - generator: generator, + runtimeReturnDirectly: p.returnDirectly, + generator: p.generator, }) + // Pre-execution cancel check + if cancelCtx != nil && cancelCtx.shouldCancel() { + if cancelCtx.getMode() == CancelImmediate || atomic.LoadInt32(&cancelCtx.escalated) == 1 { + cancelErr, ok := cancelCtx.createAndMarkCancelHandled() + if !ok { + return + } + p.generator.Send(&AgentEvent{Err: cancelErr}) + return + } + } + in := reactRunInput{ - input: input, - instruction: instruction, + input: p.input, + instruction: p.instruction, } var runOpts []compose.Option - runOpts = append(runOpts, opts...) + runOpts = append(runOpts, p.composeOpts...) if a.toolsConfig.EmitInternalEvents { - runOpts = append(runOpts, compose.WithToolsNodeOption(compose.WithToolOption(withAgentToolEventGenerator(generator)))) + runOpts = append(runOpts, compose.WithToolsNodeOption(compose.WithToolOption(withAgentToolEventGenerator(p.generator)))) } - if input.EnableStreaming { + if p.input.EnableStreaming { runOpts = append(runOpts, compose.WithToolsNodeOption(compose.WithToolOption(withAgentToolEnableStreaming(true)))) } var msg Message var msgStream MessageStream - if input.EnableStreaming { + if p.input.EnableStreaming { msgStream, err_ = runnable.Stream(ctx, in, runOpts...) } else { msg, err_ = runnable.Invoke(ctx, in, runOpts...) @@ -875,7 +970,7 @@ func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) ( if a.outputKey != "" { err_ = setOutputToSession(ctx, msg, msgStream, a.outputKey) if err_ != nil { - generator.Send(&AgentEvent{Err: err_}) + p.generator.Send(&AgentEvent{Err: err_}) } } else if msgStream != nil { msgStream.Close() @@ -884,31 +979,7 @@ func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) ( return } - info, ok := compose.ExtractInterruptInfo(err_) - if !ok { - generator.Send(&AgentEvent{Err: err_}) - return - } - - data, existed, err := store.Get(ctx, bridgeCheckpointID) - if err != nil { - generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("failed to get interrupt info: %w", err)}) - return - } - if !existed { - generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("interrupt occurred but checkpoint data is missing")}) - return - } - - is := FromInterruptContexts(info.InterruptContexts) - - event := CompositeInterrupt(ctx, info, data, is) - event.Action.Interrupted.Data = &ChatModelAgentInterruptInfo{ - Info: info, - Data: data, - } - event.AgentName = a.name - generator.Send(event) + a.handleRunFuncError(ctx, err_, cancelCtx, p.cancelCtxOwned, p.store, p.generator) }, nil } @@ -981,24 +1052,25 @@ func (a *ChatModelAgent) getRunFunc(ctx context.Context) (context.Context, runFu } func (a *ChatModelAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { - iter, _ := a.runInternal(ctx, input, false, opts...) - return iter -} - -func (a *ChatModelAgent) RunWithCancel(ctx context.Context, input *AgentInput, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { - return a.runInternal(ctx, input, true, opts...) -} - -func (a *ChatModelAgent) runInternal(ctx context.Context, input *AgentInput, withCancel bool, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + o := getCommonOptions(nil, opts...) + cancelCtx := o.cancelCtx + cancelCtxOwned := cancelCtx != nil && getCancelContext(ctx) == nil + if cancelCtx == nil { + cancelCtx = getCancelContext(ctx) + } + ctx, run, bc, err := a.getRunFunc(ctx) if err != nil { go func() { - generator.Send(&AgentEvent{Err: err}) + if cancelCtxOwned && cancelCtx != nil { + defer cancelCtx.markDone() + } + generator.Send(&AgentEvent{Err: fmt.Errorf("ChatModelAgent getRunFunc error: %w", err)}) generator.Close() }() - return iterator, notCancellableFuncInternal + return iterator } co := getComposeOptions(opts) @@ -1011,13 +1083,6 @@ func (a *ChatModelAgent) runInternal(ctx context.Context, input *AgentInput, wit } } - var cs *cancelSig - var cancelFn CancelFunc = notCancellableFuncInternal - if withCancel { - cs = newCancelSig() - cancelFn = buildCancelFunc(cs) - } - go func() { defer func() { panicErr := recover() @@ -1039,47 +1104,44 @@ func (a *ChatModelAgent) runInternal(ctx context.Context, input *AgentInput, wit returnDirectly = bc.returnDirectly } - run(ctx, input, generator, newBridgeStore(), instruction, returnDirectly, cs, co...) + run(ctx, &runParams{ + input: input, + generator: generator, + store: newBridgeStore(), + instruction: instruction, + returnDirectly: returnDirectly, + cancelCtx: cancelCtx, + cancelCtxOwned: cancelCtxOwned, + composeOpts: co, + }) }() - return iterator, cancelFn -} - -func buildCancelFunc(cs *cancelSig) CancelFunc { - var once sync.Once - return func(opts ...CancelOption) error { - cfg := &cancelConfig{ - Mode: CancelImmediate, - } - for _, opt := range opts { - opt(cfg) - } - once.Do(func() { - cs.cancel(cfg) - }) - return nil + if cancelCtxOwned { + return wrapIterWithCancelCtx(iterator, cancelCtx) } + return iterator } func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { - iter, _ := a.resumeInternal(ctx, info, false, opts...) - return iter -} - -func (a *ChatModelAgent) ResumeWithCancel(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { - return a.resumeInternal(ctx, info, true, opts...) -} - -func (a *ChatModelAgent) resumeInternal(ctx context.Context, info *ResumeInfo, withCancel bool, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + o := getCommonOptions(nil, opts...) + cancelCtx := o.cancelCtx + cancelCtxOwned := cancelCtx != nil && getCancelContext(ctx) == nil + if cancelCtx == nil { + cancelCtx = getCancelContext(ctx) + } + ctx, run, bc, err := a.getRunFunc(ctx) if err != nil { go func() { - generator.Send(&AgentEvent{Err: err}) + if cancelCtxOwned && cancelCtx != nil { + defer cancelCtx.markDone() + } + generator.Send(&AgentEvent{Err: fmt.Errorf("ChatModelAgent getRunFunc error: %w", err)}) generator.Close() }() - return iterator, notCancellableFuncInternal + return iterator } co := getComposeOptions(opts) @@ -1092,19 +1154,18 @@ func (a *ChatModelAgent) resumeInternal(ctx context.Context, info *ResumeInfo, w } } - methodName := "Resume" - if withCancel { - methodName = "ResumeWithCancel" + if info == nil { + panic(fmt.Sprintf("ChatModelAgent.Resume: agent '%s' was asked to resume but info is nil", a.Name(ctx))) } if info.InterruptState == nil { - panic(fmt.Sprintf("ChatModelAgent.%s: agent '%s' was asked to resume but has no state", methodName, a.Name(ctx))) + panic(fmt.Sprintf("ChatModelAgent.Resume: agent '%s' was asked to resume but has no state", a.Name(ctx))) } stateByte, ok := info.InterruptState.([]byte) if !ok { - panic(fmt.Sprintf("ChatModelAgent.%s: agent '%s' was asked to resume but has invalid interrupt state type: %T", - methodName, a.Name(ctx), info.InterruptState)) + panic(fmt.Sprintf("ChatModelAgent.Resume: agent '%s' was asked to resume but has invalid interrupt state type: %T", + a.Name(ctx), info.InterruptState)) } // Migrate legacy checkpoints before resume. @@ -1119,15 +1180,15 @@ func (a *ChatModelAgent) resumeInternal(ctx context.Context, info *ResumeInfo, w generator.Send(&AgentEvent{Err: err}) generator.Close() }() - return iterator, notCancellableFuncInternal + return iterator } var historyModifier func(ctx context.Context, history []Message) []Message if info.ResumeData != nil { resumeData, ok := info.ResumeData.(*ChatModelAgentResumeData) if !ok { - panic(fmt.Sprintf("ChatModelAgent.%s: agent '%s' was asked to resume but has invalid resume data type: %T", - methodName, a.Name(ctx), info.ResumeData)) + panic(fmt.Sprintf("ChatModelAgent.Resume: agent '%s' was asked to resume but has invalid resume data type: %T", + a.Name(ctx), info.ResumeData)) } historyModifier = resumeData.HistoryModifier } @@ -1143,13 +1204,6 @@ func (a *ChatModelAgent) resumeInternal(ctx context.Context, info *ResumeInfo, w })) } - var cs *cancelSig - var cancelFn CancelFunc = notCancellableFuncInternal - if withCancel { - cs = newCancelSig() - cancelFn = buildCancelFunc(cs) - } - go func() { defer func() { panicErr := recover() @@ -1171,11 +1225,22 @@ func (a *ChatModelAgent) resumeInternal(ctx context.Context, info *ResumeInfo, w returnDirectly = bc.returnDirectly } - run(ctx, &AgentInput{EnableStreaming: info.EnableStreaming}, generator, - newResumeBridgeStore(stateByte), instruction, returnDirectly, cs, co...) + run(ctx, &runParams{ + input: &AgentInput{EnableStreaming: info.EnableStreaming}, + generator: generator, + store: newResumeBridgeStore(bridgeCheckpointID, stateByte), + instruction: instruction, + returnDirectly: returnDirectly, + cancelCtx: cancelCtx, + cancelCtxOwned: cancelCtxOwned, + composeOpts: co, + }) }() - return iterator, cancelFn + if cancelCtxOwned { + return wrapIterWithCancelCtx(iterator, cancelCtx) + } + return iterator } func getComposeOptions(opts []AgentRunOption) []compose.Option { diff --git a/adk/chatmodel_retry_test.go b/adk/chatmodel_retry_test.go index 00c89b352..0cb2a87bd 100644 --- a/adk/chatmodel_retry_test.go +++ b/adk/chatmodel_retry_test.go @@ -1046,3 +1046,148 @@ func TestSequentialWorkflow_NoRetryConfig_StreamError_StopsFlow(t *testing.T) { assert.Equal(t, 0, len(capturingModel.capturedInputs), "Agent B should NOT be called due to error") assert.Equal(t, int32(1), atomic.LoadInt32(&noRetryModel.callCount), "Model should only be called once (no retry)") } + +// failThenToolCallStreamModel is a ChatModel that: +// - First Stream() call: yields a partial chunk then fails with a retryable error mid-stream. +// - Second Stream() call (retry): yields a tool-call message (success). +// - Third Generate() call (after tool result): yields a final assistant message. +// +// This exercises the path where the eventSenderModel copies the first stream, +// wraps its error as WillRetryError, and sends it as an event to the session. +// The retryModelWrapper then retries, gets a clean stream with a tool call, +// the tool interrupts, and checkpoint save needs to gob-encode the session +// (which still contains the unconsumed WillRetryError event stream). +type failThenToolCallStreamModel struct { + streamCallCount int32 + genCallCount int32 +} + +func (m *failThenToolCallStreamModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m.genCallCount, 1) + return schema.AssistantMessage("final answer", nil), nil +} + +func (m *failThenToolCallStreamModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + count := atomic.AddInt32(&m.streamCallCount, 1) + + sr, sw := schema.Pipe[*schema.Message](10) + go func() { + defer sw.Close() + if count == 1 { + // First call: yield a partial chunk then fail. + sw.Send(schema.AssistantMessage("partial", nil), nil) + sw.Send(nil, errRetryAble) + return + } + // Second call (retry): yield a tool-call message. + sw.Send(schema.AssistantMessage("", []schema.ToolCall{{ + ID: "call-1", + Function: schema.FunctionCall{ + Name: "interrupt_tool", + Arguments: `{}`, + }, + }}), nil) + }() + return sr, nil +} + +func (m *failThenToolCallStreamModel) WithTools(_ []*schema.ToolInfo) (model.ToolCallingChatModel, error) { + return m, nil +} + +// interruptToolForRetryTest is a tool that always interrupts. +type interruptToolForRetryTest struct{} + +func (t *interruptToolForRetryTest) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: "interrupt_tool", + Desc: "tool that interrupts", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Type: "string"}, + }), + }, nil +} + +func (t *interruptToolForRetryTest) InvokableRun(ctx context.Context, _ string, _ ...tool.Option) (string, error) { + return "", tool.Interrupt(ctx, "interrupted by tool") +} + +// TestCheckpointSave_WillRetryError_StreamNotConsumed verifies that checkpoint +// saving succeeds when the session contains an event with an unconsumed stream +// that ends with WillRetryError. +// +// Scenario: +// 1. ChatModelAgent with retry (MaxRetries=1) and a tool that always interrupts +// 2. Model.Stream() #1 yields "partial" then errRetryAble mid-stream +// → eventSenderModel copies the stream, wraps the error as WillRetryError, +// sends the event to the session (stream NOT consumed by anyone yet) +// → retryModelWrapper detects error on its copy, retries +// 3. Model.Stream() #2 succeeds with a tool-call message +// 4. Tool executes → interrupts +// 5. Runner.handleIter sees the interrupt → saveCheckPoint → gob encodes runSession +// 6. The session has the WillRetryError event with an unconsumed stream +// → agentEventWrapper.GobEncode proactively consumes the stream via +// getMessageFromWrappedEvent, so MessageVariant.GobEncode sees an error-free +// array and succeeds +func TestCheckpointSave_WillRetryError_StreamNotConsumed(t *testing.T) { + ctx := context.Background() + + mdl := &failThenToolCallStreamModel{} + itool := &interruptToolForRetryTest{} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Agent for checkpoint stream error test", + Instruction: "You are a test agent.", + Model: mdl, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{itool}, + }, + }, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 1, + IsRetryAble: func(_ context.Context, err error) bool { + return errors.Is(err, errRetryAble) + }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { + return time.Millisecond // fast retry for test + }, + }, + }) + assert.NoError(t, err) + + store := newMyStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: true, + CheckPointStore: store, + }) + + iter := runner.Run(ctx, + []Message{schema.UserMessage("hello")}, + WithCheckPointID("ckpt-1"), + ) + + var events []*AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + events = append(events, event) + + if event.Err != nil { + t.Logf("event error: %v", event.Err) + } + } + + // Verify the checkpoint was saved successfully. + _, exists, _ := store.Get(ctx, "ckpt-1") + assert.True(t, exists, "checkpoint should be saved successfully; "+ + "if this fails, the WillRetryError stream in the session caused gob encoding to fail") + + // Sanity: the model should have been called twice for Stream (fail + retry). + assert.Equal(t, int32(2), atomic.LoadInt32(&mdl.streamCallCount), + "model should be called twice: first fail, then retry success") +} diff --git a/adk/flow.go b/adk/flow.go index 7acfd2137..52a346c74 100644 --- a/adk/flow.go +++ b/adk/flow.go @@ -333,15 +333,6 @@ func buildDefaultHistoryRewriter(agentName string) HistoryRewriter { } func (a *flowAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { - iter, _ := a.runInternal(ctx, input, false, opts...) - return iter -} - -func (a *flowAgent) RunWithCancel(ctx context.Context, input *AgentInput, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { - return a.runInternal(ctx, input, true, opts...) -} - -func (a *flowAgent) runInternal(ctx context.Context, input *AgentInput, withCancel bool, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { agentName := a.Name(ctx) var runCtx *runContext @@ -349,12 +340,16 @@ func (a *flowAgent) runInternal(ctx context.Context, input *AgentInput, withCanc ctx = AppendAddressSegment(ctx, AddressSegmentAgent, agentName) o := getCommonOptions(nil, opts...) + cancelCtx := o.cancelCtx processedInput, err := a.genAgentInput(ctx, runCtx, o.skipTransferMessages) if err != nil { + if cancelCtx != nil { + cancelCtx.markDone() + } cbInput := &AgentCallbackInput{Input: input} ctx = callbacks.OnStart(ctx, cbInput) - return wrapIterWithOnEnd(ctx, genErrorIter(err)), notCancellableFuncInternal + return wrapIterWithOnEnd(ctx, genErrorIter(err)) } ctxForSubAgents := ctx @@ -367,105 +362,90 @@ func (a *flowAgent) runInternal(ctx context.Context, input *AgentInput, withCanc input = processedInput if wf, ok := a.Agent.(*workflowAgent); ok { - return wrapIterWithOnEnd(ctx, wf.Run(ctx, input, filterCallbackHandlersForNestedAgents(agentName, opts)...)), notCancellableFuncInternal + ctx = withCancelContext(ctx, cancelCtx) + filteredOpts := filterCancelOption(filterCallbackHandlersForNestedAgents(agentName, opts)) + iter := wf.Run(ctx, input, filteredOpts...) + iter = wrapIterWithCancelCtx(iter, cancelCtx) + return wrapIterWithOnEnd(ctx, iter) } - var aIter *AsyncIterator[*AgentEvent] - var cancelFn CancelFunc = notCancellableFuncInternal - - ca, supportCancel := a.Agent.(CancellableAgent) - if withCancel && supportCancel { - aIter, cancelFn = ca.RunWithCancel(ctx, input, filterOptions(agentName, opts)...) - } else { - aIter = a.Agent.Run(ctx, input, filterOptions(agentName, opts)...) - } + aIter := a.Agent.Run(withCancelContext(ctx, cancelCtx), input, filterOptions(agentName, opts)...) iterator, generator := NewAsyncIteratorPair[*AgentEvent]() - go a.run(ctx, ctxForSubAgents, runCtx, aIter, generator, opts...) - - return iterator, cancelFn -} + go a.run(withCancelContext(ctx, cancelCtx), withCancelContext(ctxForSubAgents, cancelCtx), runCtx, aIter, generator, filterCancelOption(opts)...) -func notCancellableFuncInternal(_ ...CancelOption) error { - return ErrAgentNotCancellable + return wrapIterWithCancelCtx(iterator, cancelCtx) } func (a *flowAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { - iter, _ := a.resumeInternal(ctx, info, false, opts...) - return iter -} - -func (a *flowAgent) ResumeWithCancel(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { - return a.resumeInternal(ctx, info, true, opts...) -} - -func (a *flowAgent) resumeInternal(ctx context.Context, info *ResumeInfo, withCancel bool, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { agentName := a.Name(ctx) ctx, info = buildResumeInfo(ctx, agentName, info) ctxForSubAgents := ctx + o := getCommonOptions(nil, opts...) + cancelCtx := o.cancelCtx + agentType := getAgentType(a.Agent) ctx = initAgentCallbacks(ctx, agentName, agentType, filterOptions(agentName, opts)...) cbInput := &AgentCallbackInput{ResumeInfo: info} ctx = callbacks.OnStart(ctx, cbInput) if info.WasInterrupted { - var aIter *AsyncIterator[*AgentEvent] - var cancelFn CancelFunc = notCancellableFuncInternal - - ca, supportCancel := a.Agent.(CancellableResumableAgent) - if withCancel && supportCancel { - aIter, cancelFn = ca.ResumeWithCancel(ctx, info, opts...) - } else if ra, ok := a.Agent.(ResumableAgent); ok { + if ra, ok := a.Agent.(ResumableAgent); ok { if _, ok := ra.(*workflowAgent); ok { - filteredOpts := filterCallbackHandlersForNestedAgents(agentName, opts) + ctx = withCancelContext(ctx, cancelCtx) + filteredOpts := filterCancelOption(filterCallbackHandlersForNestedAgents(agentName, opts)) aIter := ra.Resume(ctx, info, filteredOpts...) - return wrapIterWithOnEnd(ctx, aIter), cancelFn + aIter = wrapIterWithCancelCtx(aIter, cancelCtx) + return wrapIterWithOnEnd(ctx, aIter) } - aIter = ra.Resume(ctx, info, opts...) - } else { - return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf("failed to resume agent: agent '%s' is an interrupt point "+ - "but is not a ResumableAgent", agentName))), notCancellableFuncInternal + + aIter := ra.Resume(withCancelContext(ctx, cancelCtx), info, opts...) + + iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + go a.run(withCancelContext(ctx, cancelCtx), withCancelContext(ctxForSubAgents, cancelCtx), getRunCtx(ctxForSubAgents), aIter, generator, filterCancelOption(opts)...) + return wrapIterWithCancelCtx(iterator, cancelCtx) } - iterator, generator := NewAsyncIteratorPair[*AgentEvent]() - go a.run(ctx, ctxForSubAgents, getRunCtx(ctxForSubAgents), aIter, generator, opts...) - return iterator, cancelFn + if cancelCtx != nil { + cancelCtx.markDone() + } + return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf("failed to resume agent: agent '%s' is an interrupt point "+ + "but is not a ResumableAgent", agentName))) } nextAgentName, err := getNextResumeAgent(ctx, info) if err != nil { - return wrapIterWithOnEnd(ctx, genErrorIter(err)), notCancellableFuncInternal + if cancelCtx != nil { + cancelCtx.markDone() + } + return wrapIterWithOnEnd(ctx, genErrorIter(err)) } subAgent := a.getAgent(ctxForSubAgents, nextAgentName) if subAgent == nil { if len(a.subAgents) == 0 { - ca, supportCancel := a.Agent.(CancellableResumableAgent) - if withCancel && supportCancel { - iter, cancelFn := ca.ResumeWithCancel(ctx, info, opts...) - return wrapIterWithOnEnd(ctx, iter), cancelFn - } if ra, ok := a.Agent.(ResumableAgent); ok { - return wrapIterWithOnEnd(ctx, ra.Resume(ctx, info, opts...)), notCancellableFuncInternal + ctx = withCancelContext(ctx, cancelCtx) + innerIter := ra.Resume(ctx, info, filterCancelOption(opts)...) + return wrapIterWithCancelCtx(wrapIterWithOnEnd(ctx, innerIter), cancelCtx) } return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf( "failed to resume agent: agent '%s' (type %T) has no sub-agents and does not implement ResumableAgent interface. "+ - "To support resume, your custom agent wrapper must implement the ResumableAgent interface", agentName, a.Agent))), nil + "To support resume, your custom agent wrapper must implement the ResumableAgent interface", agentName, a.Agent))) } - return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf("failed to resume agent: sub-agent '%s' not found in agent '%s'", nextAgentName, agentName))), notCancellableFuncInternal - } - - ca, supportCancel := ResumableAgent(subAgent).(CancellableResumableAgent) - if withCancel && supportCancel { - iter, cancelFn := ca.ResumeWithCancel(ctxForSubAgents, info, opts...) - return wrapIterWithOnEnd(ctx, iter), cancelFn + if cancelCtx != nil { + cancelCtx.markDone() + } + return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf("failed to resume agent: sub-agent '%s' not found in agent '%s'", nextAgentName, agentName))) } - return wrapIterWithOnEnd(ctx, subAgent.Resume(ctxForSubAgents, info, opts...)), notCancellableFuncInternal + ctxForSubAgents = withCancelContext(ctxForSubAgents, cancelCtx) + innerIter := subAgent.Resume(ctxForSubAgents, info, filterCancelOption(opts)...) + return wrapIterWithCancelCtx(wrapIterWithOnEnd(ctx, innerIter), cancelCtx) } type DeterministicTransferConfig struct { diff --git a/adk/handler.go b/adk/handler.go index 7c7ebba71..423282a7a 100644 --- a/adk/handler.go +++ b/adk/handler.go @@ -47,6 +47,12 @@ type ToolContext struct { CallID string } +// ToolCallsContext contains metadata about the tool calls that just completed. +type ToolCallsContext struct { + // ToolCalls contains the tool call metadata from the model's response. + ToolCalls []ToolContext +} + // ModelContext contains context information passed to WrapModel. type ModelContext struct { // Tools contains the current tool list configured for the agent. @@ -57,6 +63,8 @@ type ModelContext struct { // This is populated at request time from the agent's ModelRetryConfig. // Used by EventSenderModelWrapper to wrap stream errors appropriately. ModelRetryConfig *ModelRetryConfig + + cancelContext *cancelContext } // ChatModelAgentContext contains runtime information passed to handlers before each ChatModelAgent run. @@ -138,6 +146,14 @@ type ChatModelAgentMiddleware interface { // - Tools: the current tool list that was sent to the model AfterModelRewriteState(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) + // AfterToolCallsRewriteState is called after all concurrent tool calls in an iteration complete. + // The input state includes all messages up to and including the tool call results. + // The returned state is persisted to the agent's internal state. + // + // The ToolCallsContext provides metadata about the tool calls that just completed, + // derived from the assistant message's ToolCalls field. + AfterToolCallsRewriteState(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) + // WrapInvokableToolCall wraps a tool's synchronous execution with custom behavior. // Return the input endpoint unchanged and nil error if no wrapping is needed. // @@ -247,6 +263,10 @@ func (b *BaseChatModelAgentMiddleware) AfterModelRewriteState(ctx context.Contex return ctx, state, nil } +func (b *BaseChatModelAgentMiddleware) AfterToolCallsRewriteState(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) { + return ctx, state, nil +} + // SetRunLocalValue sets a key-value pair that persists for the duration of the current agent Run() invocation. // The value is scoped to this specific execution and is not shared across different Run() calls or agent instances. // diff --git a/adk/handler_test.go b/adk/handler_test.go index e56da3842..abdb0ecab 100644 --- a/adk/handler_test.go +++ b/adk/handler_test.go @@ -111,6 +111,15 @@ func (h *testAfterModelRewriteStateHandler) AfterModelRewriteState(ctx context.C return h.fn(ctx, state, mc) } +type testAfterToolCallsHandler struct { + *BaseChatModelAgentMiddleware + fn func(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) +} + +func (h *testAfterToolCallsHandler) AfterToolCallsRewriteState(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) { + return h.fn(ctx, state, tc) +} + type testToolWrapperHandler struct { *BaseChatModelAgentMiddleware wrapInvokableFn func(context.Context, InvokableToolCallEndpoint, *ToolContext) InvokableToolCallEndpoint @@ -1820,3 +1829,312 @@ func TestToolContextInWrappers(t *testing.T) { assert.Equal(t, "test_call_id_123", capturedCallID, "ToolContext should have correct call ID") }) } + +func TestAfterToolCallsRewriteState(t *testing.T) { + t.Run("ReceivesCorrectToolCallsContext", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + tool1 := &namedTool{name: "tool_alpha"} + tool2 := &namedTool{name: "tool_beta"} + + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + // First call: model returns two tool calls + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("calling tools", []schema.ToolCall{ + {ID: "call_1", Function: schema.FunctionCall{Name: "tool_alpha", Arguments: "{}"}}, + {ID: "call_2", Function: schema.FunctionCall{Name: "tool_beta", Arguments: "{}"}}, + }), nil).Times(1) + + // Second call: model returns final response + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("done", nil), nil).Times(1) + + var mu sync.Mutex + var capturedTC *ToolCallsContext + var capturedState *ChatModelAgentState + callCount := 0 + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1, tool2}, + }, + }, + Handlers: []ChatModelAgentMiddleware{ + &testAfterToolCallsHandler{fn: func(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) { + mu.Lock() + callCount++ + capturedTC = tc + capturedState = &ChatModelAgentState{Messages: append([]Message{}, state.Messages...)} + mu.Unlock() + return ctx, state, nil + }}, + }, + }) + assert.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + mu.Lock() + defer mu.Unlock() + + // Should be called exactly once (one iteration with tool calls) + assert.Equal(t, 1, callCount) + + // ToolCallsContext should have the two tool calls + assert.NotNil(t, capturedTC) + assert.Len(t, capturedTC.ToolCalls, 2) + assert.Equal(t, "tool_alpha", capturedTC.ToolCalls[0].Name) + assert.Equal(t, "call_1", capturedTC.ToolCalls[0].CallID) + assert.Equal(t, "tool_beta", capturedTC.ToolCalls[1].Name) + assert.Equal(t, "call_2", capturedTC.ToolCalls[1].CallID) + + // State should contain: system msg + user msg + assistant msg + 2 tool results + assert.NotNil(t, capturedState) + assert.True(t, len(capturedState.Messages) >= 4, "expected at least 4 messages, got %d", len(capturedState.Messages)) + + // Check tool results are in state + toolResultCount := 0 + for _, msg := range capturedState.Messages { + if msg.Role == schema.Tool { + toolResultCount++ + } + } + assert.Equal(t, 2, toolResultCount) + }) + + t.Run("NotCalledWithoutToolCalls", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + // Model returns a direct response with no tool calls + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("direct response", nil), nil).Times(1) + + callCount := 0 + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: cm, + Handlers: []ChatModelAgentMiddleware{ + &testAfterToolCallsHandler{fn: func(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) { + callCount++ + return ctx, state, nil + }}, + }, + }) + assert.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + assert.Equal(t, 0, callCount, "AfterToolCallsRewriteState should not be called when no tool calls happen") + }) + + t.Run("CanModifyStatePersistsToNextIteration", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + tool1 := &namedTool{name: "mytool"} + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + // First call: model returns a tool call + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("calling", []schema.ToolCall{ + {ID: "c1", Function: schema.FunctionCall{Name: "mytool", Arguments: "{}"}}, + }), nil).Times(1) + + // Second call: capture messages to verify the injected message is present + var capturedMsgs []*schema.Message + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...interface{}) (*schema.Message, error) { + capturedMsgs = append([]*schema.Message{}, msgs...) + return schema.AssistantMessage("final", nil), nil + }).Times(1) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1}, + }, + }, + Handlers: []ChatModelAgentMiddleware{ + &testAfterToolCallsHandler{fn: func(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) { + // Inject a user message into state + state.Messages = append(state.Messages, schema.UserMessage("injected_by_middleware")) + return ctx, state, nil + }}, + }, + }) + assert.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("original")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + // The injected message should be visible in the second model call + assert.NotNil(t, capturedMsgs) + found := false + for _, msg := range capturedMsgs { + if msg.Content == "injected_by_middleware" { + found = true + break + } + } + assert.True(t, found, "Injected message should persist to the next model call") + }) +} + +func TestToolResultNotDuplicated(t *testing.T) { + t.Run("SecondModelCallHasNoToolResultDuplication", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + tool1 := &namedTool{name: "mytool"} + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("calling", []schema.ToolCall{ + {ID: "c1", Function: schema.FunctionCall{Name: "mytool", Arguments: "{}"}}, + }), nil).Times(1) + + var capturedMsgs []*schema.Message + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...interface{}) (*schema.Message, error) { + capturedMsgs = append([]*schema.Message{}, msgs...) + return schema.AssistantMessage("final", nil), nil + }).Times(1) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are helpful.", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1}, + }, + }, + }) + assert.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hello")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + assert.NotNil(t, capturedMsgs) + assert.Equal(t, 4, len(capturedMsgs), + "expected [system, user, assistant, tool_result], got %d messages", len(capturedMsgs)) + assert.Equal(t, schema.System, capturedMsgs[0].Role) + assert.Equal(t, schema.User, capturedMsgs[1].Role) + assert.Equal(t, schema.Assistant, capturedMsgs[2].Role) + assert.Equal(t, schema.Tool, capturedMsgs[3].Role) + + toolResultCount := 0 + for _, msg := range capturedMsgs { + if msg.Role == schema.Tool { + toolResultCount++ + } + } + assert.Equal(t, 1, toolResultCount, + "tool result should appear exactly once, got %d", toolResultCount) + }) + + t.Run("HandlerInjectedMessagePresentWithoutDuplication", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + tool1 := &namedTool{name: "mytool"} + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("calling", []schema.ToolCall{ + {ID: "c1", Function: schema.FunctionCall{Name: "mytool", Arguments: "{}"}}, + }), nil).Times(1) + + var capturedMsgs []*schema.Message + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...interface{}) (*schema.Message, error) { + capturedMsgs = append([]*schema.Message{}, msgs...) + return schema.AssistantMessage("final", nil), nil + }).Times(1) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are helpful.", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1}, + }, + }, + Handlers: []ChatModelAgentMiddleware{ + &testAfterToolCallsHandler{fn: func(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) { + state.Messages = append(state.Messages, schema.UserMessage("injected")) + return ctx, state, nil + }}, + }, + }) + assert.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hello")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + assert.NotNil(t, capturedMsgs) + assert.Equal(t, 5, len(capturedMsgs), + "expected [system, user, assistant, tool_result, injected], got %d messages", len(capturedMsgs)) + assert.Equal(t, schema.System, capturedMsgs[0].Role) + assert.Equal(t, schema.User, capturedMsgs[1].Role) + assert.Equal(t, schema.Assistant, capturedMsgs[2].Role) + assert.Equal(t, schema.Tool, capturedMsgs[3].Role) + assert.Equal(t, "injected", capturedMsgs[4].Content) + + toolResultCount := 0 + for _, msg := range capturedMsgs { + if msg.Role == schema.Tool { + toolResultCount++ + } + } + assert.Equal(t, 1, toolResultCount, + "tool result should appear exactly once, got %d", toolResultCount) + }) +} diff --git a/adk/interface.go b/adk/interface.go index c73705ae3..5c06843ae 100644 --- a/adk/interface.go +++ b/adk/interface.go @@ -20,10 +20,8 @@ import ( "bytes" "context" "encoding/gob" - "errors" "fmt" "io" - "time" "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/internal/core" @@ -271,55 +269,3 @@ type ResumableAgent interface { Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] } - -// CancelMode specifies when an agent should be canceled. -// Modes can be combined with bitwise OR to cancel at multiple execution points. -// For example, CancelAfterChatModel | CancelAfterToolCall cancels the agent -// after whichever execution point is reached first. -type CancelMode int - -const ( - // CancelImmediate cancels the agent immediately without waiting - // for any execution point. - CancelImmediate CancelMode = 0 - // CancelAfterChatModel cancels the agent after a chat model call completes. - CancelAfterChatModel CancelMode = 1 << iota - // CancelAfterToolCall cancels the agent after a tool call completes. - CancelAfterToolCall -) - -// ErrAgentNotCancellable is returned by Cancel when the agent does not support cancellation. -var ErrAgentNotCancellable = errors.New("agent does not implement CancellableAgent interface") - -type cancelConfig struct { - Mode CancelMode - Timeout *time.Duration -} - -type CancelOption func(*cancelConfig) - -// WithCancelMode sets the cancel mode for the cancel operation. -func WithCancelMode(mode CancelMode) CancelOption { - return func(config *cancelConfig) { - config.Mode = mode - } -} - -// WithCancelTimeout sets a timeout duration for CancelImmediate mode. -func WithCancelTimeout(timeout time.Duration) CancelOption { - return func(config *cancelConfig) { - config.Timeout = &timeout - } -} - -type CancelFunc func(...CancelOption) error - -type CancellableAgent interface { - Agent - RunWithCancel(ctx context.Context, input *AgentInput, options ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) -} - -type CancellableResumableAgent interface { - ResumableAgent - ResumeWithCancel(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) -} diff --git a/adk/interrupt.go b/adk/interrupt.go index 5941d0724..fce09d4cf 100644 --- a/adk/interrupt.go +++ b/adk/interrupt.go @@ -22,6 +22,7 @@ import ( "encoding/gob" "errors" "fmt" + "sync" "github.com/cloudwego/eino/internal/core" "github.com/cloudwego/eino/schema" @@ -183,6 +184,11 @@ func WithCheckPointID(id string) AgentRunOption { func init() { schema.RegisterName[*serialization]("_eino_adk_serialization") schema.RegisterName[*WorkflowInterruptInfo]("_eino_adk_workflow_interrupt_info") + // Register []byte for gob: the cancel refactor routes bridge store checkpoint + // bytes ([]byte) through InterruptState.State (type any) inside the outer + // serialization struct. Gob requires concrete types behind interface fields + // to be registered. + gob.Register([]byte{}) } // serialization CheckpointSchema: root checkpoint payload (gob). @@ -266,6 +272,10 @@ func (r *Runner) saveCheckPoint( info *InterruptInfo, is *core.InterruptSignal, ) error { + if r.store == nil { + return nil + } + runCtx := getRunCtx(ctx) id2Addr, id2State := core.SignalToPersistenceMaps(is) @@ -287,31 +297,36 @@ func (r *Runner) saveCheckPoint( const bridgeCheckpointID = "adk_react_mock_key" func newBridgeStore() *bridgeStore { - return &bridgeStore{} + return &bridgeStore{data: make(map[string][]byte)} } -func newResumeBridgeStore(data []byte) *bridgeStore { +func newResumeBridgeStore(checkPointID string, data []byte) *bridgeStore { return &bridgeStore{ - Data: data, - Valid: true, + data: map[string][]byte{checkPointID: data}, } } type bridgeStore struct { - Data []byte - Valid bool + mu sync.Mutex + data map[string][]byte } -func (m *bridgeStore) Get(_ context.Context, _ string) ([]byte, bool, error) { - if m.Valid { - return m.Data, true, nil +func (m *bridgeStore) Get(_ context.Context, key string) ([]byte, bool, error) { + m.mu.Lock() + defer m.mu.Unlock() + if v, ok := m.data[key]; ok { + return v, true, nil } return nil, false, nil } -func (m *bridgeStore) Set(_ context.Context, _ string, checkPoint []byte) error { - m.Data = checkPoint - m.Valid = true +func (m *bridgeStore) Set(_ context.Context, key string, checkPoint []byte) error { + m.mu.Lock() + defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[string][]byte) + } + m.data[key] = checkPoint return nil } diff --git a/adk/middlewares/patchtoolcalls/patchtoolcalls.go b/adk/middlewares/patchtoolcalls/patchtoolcalls.go index 75fb5fcbf..833ca3794 100644 --- a/adk/middlewares/patchtoolcalls/patchtoolcalls.go +++ b/adk/middlewares/patchtoolcalls/patchtoolcalls.go @@ -121,6 +121,6 @@ func (m *middleware) createPatchedToolMessage(ctx context.Context, tc schema.Too } const ( - defaultPatchedToolMessageTemplate = "Tool call %s with id %s was cancelled - another message came in before it could be completed." + defaultPatchedToolMessageTemplate = "Tool call %s with id %s was canceled - another message came in before it could be completed." defaultPatchedToolMessageTemplateChinese = "工具调用 %s(ID 为 %s)已被取消——在其完成之前收到了另一条消息。" ) diff --git a/adk/prebuilt/planexecute/plan_execute_test.go b/adk/prebuilt/planexecute/plan_execute_test.go index fb7360357..6734a16b8 100644 --- a/adk/prebuilt/planexecute/plan_execute_test.go +++ b/adk/prebuilt/planexecute/plan_execute_test.go @@ -18,9 +18,12 @@ package planexecute import ( "context" + "errors" "fmt" "strings" + "sync" "testing" + "time" "github.com/bytedance/sonic" "github.com/stretchr/testify/assert" @@ -1002,3 +1005,232 @@ func TestPlanExecuteAgentInterruptResume(t *testing.T) { assert.True(t, hasAssistantCompletion, "Should have assistant completion message") assert.True(t, hasBreakLoop, "Should have break loop action indicating completion") } + +// slowChatModel is a ChatModel that blocks for a configurable duration. +type slowChatModel struct { + delay time.Duration + response *schema.Message + startedChan chan struct{} + startedOnce sync.Once +} + +func (m *slowChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + m.startedOnce.Do(func() { + close(m.startedChan) + }) + + select { + case <-time.After(m.delay): + return m.response, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (m *slowChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + msg, err := m.Generate(ctx, input, opts...) + if err != nil { + return nil, err + } + sr, sw := schema.Pipe[*schema.Message](1) + sw.Send(msg, nil) + sw.Close() + return sr, nil +} + +func (m *slowChatModel) BindTools(tools []*schema.ToolInfo) error { return nil } +func (m *slowChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) { + return m, nil +} + +// TestWithCancel_PlanExecute_DuringExecution verifies that cancel works +// during the executor (ChatModelAgent) phase of the PlanExecute agent. +func TestWithCancel_PlanExecute_DuringExecution(t *testing.T) { + ctx := context.Background() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Planner: returns a plan quickly + mockPlanner := mockAdk.NewMockAgent(ctrl) + mockPlanner.EXPECT().Name(gomock.Any()).Return("planner").AnyTimes() + mockPlanner.EXPECT().Description(gomock.Any()).Return("a planner agent").AnyTimes() + + plan := &defaultPlan{Steps: []string{"Step 1", "Step 2"}} + userInput := []adk.Message{schema.UserMessage("test task")} + + mockPlanner.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, input *adk.AgentInput, opts ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { + iterator, generator := adk.NewAsyncIteratorPair[*adk.AgentEvent]() + adk.AddSessionValue(ctx, PlanSessionKey, plan) + adk.AddSessionValue(ctx, UserInputSessionKey, userInput) + planJSON, _ := sonic.MarshalString(plan) + msg := schema.AssistantMessage(planJSON, nil) + generator.Send(adk.EventFromMessage(msg, nil, schema.Assistant, "")) + generator.Close() + return iterator + }, + ).Times(1) + + // Executor: uses a slow model that we can cancel + executorStarted := make(chan struct{}) + slowModel := &slowChatModel{ + delay: 5 * time.Second, + response: schema.AssistantMessage("step result", nil), + startedChan: executorStarted, + } + + executor, err := NewExecutor(ctx, &ExecutorConfig{ + Model: slowModel, + MaxIterations: 5, + }) + assert.NoError(t, err) + + // Replanner: should not be reached since we cancel during executor + mockReplanner := mockAdk.NewMockAgent(ctrl) + mockReplanner.EXPECT().Name(gomock.Any()).Return("replanner").AnyTimes() + mockReplanner.EXPECT().Description(gomock.Any()).Return("a replanner agent").AnyTimes() + + agent, err := New(ctx, &Config{ + Planner: mockPlanner, + Executor: executor, + Replanner: mockReplanner, + MaxIterations: 5, + }) + assert.NoError(t, err) + + runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: agent}) + + cancelOpt, cancelFn := adk.WithCancel() + iter := runner.Run(ctx, userInput, cancelOpt) + + // Wait for the executor's model to start + select { + case <-executorStarted: + case <-time.After(10 * time.Second): + t.Fatal("Executor model did not start") + } + + time.Sleep(50 * time.Millisecond) + + // Cancel should NOT return ErrExecutionCompleted + handle, _ := cancelFn() + err = handle.Wait() + assert.NoError(t, err, "Cancel during executor should succeed") + + hasCancelError := false + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *adk.CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + hasCancelError = true + } + } + + assert.True(t, hasCancelError, "Should have CancelError event") +} + +// TestWithCancel_PlanExecute_BetweenTransitions verifies that cancel works +// when fired between agent transitions (e.g., after planner, before executor starts). +func TestWithCancel_PlanExecute_BetweenTransitions(t *testing.T) { + ctx := context.Background() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + plannerDone := make(chan struct{}) + + // Planner: signals when done + mockPlanner := mockAdk.NewMockAgent(ctrl) + mockPlanner.EXPECT().Name(gomock.Any()).Return("planner").AnyTimes() + mockPlanner.EXPECT().Description(gomock.Any()).Return("a planner agent").AnyTimes() + + plan := &defaultPlan{Steps: []string{"Step 1"}} + userInput := []adk.Message{schema.UserMessage("test task")} + + mockPlanner.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, input *adk.AgentInput, opts ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { + iterator, generator := adk.NewAsyncIteratorPair[*adk.AgentEvent]() + go func() { + defer generator.Close() + adk.AddSessionValue(ctx, PlanSessionKey, plan) + adk.AddSessionValue(ctx, UserInputSessionKey, userInput) + planJSON, _ := sonic.MarshalString(plan) + msg := schema.AssistantMessage(planJSON, nil) + generator.Send(adk.EventFromMessage(msg, nil, schema.Assistant, "")) + close(plannerDone) + }() + return iterator + }, + ).Times(1) + + // Executor: slow model to give time to observe cancel + executorModelStarted := make(chan struct{}) + slowExecModel := &slowChatModel{ + delay: 5 * time.Second, + response: schema.AssistantMessage("step result", nil), + startedChan: executorModelStarted, + } + + executor, err := NewExecutor(ctx, &ExecutorConfig{ + Model: slowExecModel, + MaxIterations: 5, + }) + assert.NoError(t, err) + + mockReplanner := mockAdk.NewMockAgent(ctrl) + mockReplanner.EXPECT().Name(gomock.Any()).Return("replanner").AnyTimes() + mockReplanner.EXPECT().Description(gomock.Any()).Return("a replanner agent").AnyTimes() + + agent, err := New(ctx, &Config{ + Planner: mockPlanner, + Executor: executor, + Replanner: mockReplanner, + MaxIterations: 5, + }) + assert.NoError(t, err) + + runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: agent}) + + cancelOpt, cancelFn := adk.WithCancel() + iter := runner.Run(ctx, userInput, cancelOpt) + + // Wait for planner to finish, then cancel before executor has a chance to produce output + select { + case <-plannerDone: + case <-time.After(10 * time.Second): + t.Fatal("Planner did not finish") + } + + // Cancel after planner, during executor phase + // The executor is a ChatModelAgent which will handle the cancel + select { + case <-executorModelStarted: + case <-time.After(10 * time.Second): + t.Fatal("Executor model did not start") + } + + start := time.Now() + handle, _ := cancelFn() + err = handle.Wait() + assert.NoError(t, err, "Cancel between transitions should succeed") + + hasCancelError := false + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *adk.CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + hasCancelError = true + } + } + elapsed := time.Since(start) + + assert.True(t, hasCancelError, "Should have CancelError event") + assert.True(t, elapsed < 3*time.Second, "Should complete quickly after cancel, elapsed: %v", elapsed) +} diff --git a/adk/react.go b/adk/react.go index 0aec5d94c..07fdbde9a 100644 --- a/adk/react.go +++ b/adk/react.go @@ -22,7 +22,6 @@ import ( "encoding/gob" "errors" "io" - "sync/atomic" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/compose" @@ -82,6 +81,7 @@ func init() { // when decoding checkpoints created by v0.8.0 - v0.8.3 gob.Register(&AgentEvent{}) gob.Register(int(0)) + schema.RegisterName[*reactInput]("_eino_adk_react_input") } func (s *State) getReturnDirectlyEvent() *AgentEvent { @@ -238,7 +238,7 @@ func SendToolGenAction(ctx context.Context, toolName string, action *AgentAction } type reactInput struct { - messages []Message + Messages []Message } type reactConfig struct { @@ -254,6 +254,8 @@ type reactConfig struct { agentName string maxIterations int + + cancelCtx *cancelContext } func genToolInfos(ctx context.Context, config *compose.ToolsNodeConfig) ([]*schema.ToolInfo, error) { @@ -271,8 +273,6 @@ func genToolInfos(ctx context.Context, config *compose.ToolsNodeConfig) ([]*sche } type reactGraph = *compose.Graph[*reactInput, Message] -type sToolNodeOutput = *schema.StreamReader[[]Message] -type sGraphOutput = MessageStream func getReturnDirectlyToolCallID(ctx context.Context) (string, bool) { var toolCallID string @@ -300,37 +300,25 @@ func genReactState(config *reactConfig) func(ctx context.Context) *State { } } -func newReact(ctx context.Context, config *reactConfig, cs *cancelSig) (reactGraph, error) { +func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { const ( - initNode_ = "Init" - chatModel_ = "ChatModel" - beforeToolNode_ = "BeforeToolNode" - toolNode_ = "ToolNode" - afterToolNode_ = "AfterToolNode" + initNode_ = "Init" + chatModel_ = "ChatModel" + cancelCheckNode_ = "CancelCheck" + toolNode_ = "ToolNode" + afterToolCallsNode_ = "AfterToolCalls" + afterToolCallsCancelCheckNode_ = "AfterToolCallsCancelCheck" ) - checkCancel := cs != nil - - nodeNameAfterModel := func() string { - if checkCancel { - return beforeToolNode_ - } - return toolNode_ - } - - nodeNameAfterTool := func() string { - if checkCancel { - return afterToolNode_ - } - return chatModel_ - } - + cancelCtx := config.cancelCtx g := compose.NewGraph[*reactInput, Message](compose.WithGenLocalState(genReactState(config))) - - initLambda := func(ctx context.Context, input *reactInput) ([]Message, error) { - return input.messages, nil - } - _ = g.AddLambdaNode(initNode_, compose.InvokableLambda(initLambda), compose.WithNodeName(initNode_)) + _ = g.AddLambdaNode(initNode_, compose.InvokableLambda(func(ctx context.Context, input *reactInput) ([]Message, error) { + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + st.Messages = append(st.Messages, input.Messages...) + return nil + }) + return input.Messages, nil + }), compose.WithNodeName(initNode_)) var wrappedModel model.BaseChatModel = config.model if config.modelWrapperConf != nil { @@ -338,39 +326,43 @@ func newReact(ctx context.Context, config *reactConfig, cs *cancelSig) (reactGra } toolsConfig := config.toolsConfig - if checkCancel { - wrappedModel = wrapModelForCancelable(wrappedModel, cs) - tcMWs := make([]compose.ToolMiddleware, 0, len(toolsConfig.ToolCallMiddlewares)+1) - tcMWs = append(tcMWs, cancelableTool(cs)) - tcMWs = append(tcMWs, toolsConfig.ToolCallMiddlewares...) - toolsConfigCopy := *toolsConfig - toolsConfigCopy.ToolCallMiddlewares = tcMWs - toolsConfig = &toolsConfigCopy - } toolsNode, err := compose.NewToolNode(ctx, toolsConfig) if err != nil { return nil, err } - modelPreHandle := func(ctx context.Context, input []Message, st *State) ([]Message, error) { - if st.getRemainingIterations() <= 0 { - return nil, ErrExceedMaxIterations + _ = g.AddChatModelNode(chatModel_, wrappedModel, compose.WithStatePreHandler( + func(ctx context.Context, input []Message, st *State) ([]Message, error) { + if st.getRemainingIterations() <= 0 { + return nil, ErrExceedMaxIterations + } + st.decrementRemainingIterations() + return input, nil + }), compose.WithNodeName(chatModel_)) + + // CancelAfterChatModel safe-point: on the tool-calls path, after the branch + // has confirmed that the model response contains tool calls (i.e. not a final + // answer). Skipped entirely when the model produces a final answer. + _ = g.AddLambdaNode(cancelCheckNode_, compose.InvokableLambda(func(ctx context.Context, msg Message) (Message, error) { + if cancelCtx != nil && cancelCtx.shouldCancel() { + if cancelCtx.getMode()&CancelAfterChatModel != 0 { + return nil, compose.StatefulInterrupt(ctx, "CancelAfterChatModel", msg) + } } - st.decrementRemainingIterations() - return input, nil - } - _ = g.AddChatModelNode(chatModel_, wrappedModel, - compose.WithStatePreHandler(modelPreHandle), compose.WithNodeName(chatModel_)) + wasInterrupted, hasState, state := compose.GetInterruptState[Message](ctx) + if wasInterrupted && hasState { + msg = state + } + return msg, nil + }), compose.WithNodeName(cancelCheckNode_)) toolPreHandle := func(ctx context.Context, _ Message, st *State) (Message, error) { input := st.Messages[len(st.Messages)-1] - returnDirectly := config.toolsReturnDirectly if execCtx := getChatModelAgentExecCtx(ctx); execCtx != nil && len(execCtx.runtimeReturnDirectly) > 0 { returnDirectly = execCtx.runtimeReturnDirectly } - if len(returnDirectly) > 0 { for i := range input.ToolCalls { toolName := input.ToolCalls[i].Function.Name @@ -379,10 +371,8 @@ func newReact(ctx context.Context, config *reactConfig, cs *cancelSig) (reactGra } } } - return input, nil } - toolPostHandle := func(ctx context.Context, out *schema.StreamReader[[]*schema.Message], st *State) (*schema.StreamReader[[]*schema.Message], error) { if event := st.getReturnDirectlyEvent(); event != nil { getChatModelAgentExecCtx(ctx).send(event) @@ -390,36 +380,62 @@ func newReact(ctx context.Context, config *reactConfig, cs *cancelSig) (reactGra } return out, nil } - _ = g.AddToolsNode(toolNode_, toolsNode, compose.WithStatePreHandler(toolPreHandle), compose.WithStreamStatePostHandler(toolPostHandle), compose.WithNodeName(toolNode_)) - _ = g.AddEdge(compose.START, initNode_) - _ = g.AddEdge(initNode_, chatModel_) - - if checkCancel { - beforeToolNode := func(ctx context.Context, input Message) (output Message, err error) { - if sig := checkCancelSig(cs); sig != nil && sig.Mode != CancelAfterToolCall { - return nil, compose.Interrupt(ctx, "cancelled externally") + // AfterToolCalls node: calls AfterToolCallsRewriteState handlers after all tool calls complete. + // The graph auto-materializes the ToolsNode stream into []Message before this node. + afterToolCalls := func(ctx context.Context, toolResults []Message) ([]Message, error) { + var stateMessages []Message + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + stateMessages = st.Messages + return nil + }) + + state := &ChatModelAgentState{Messages: append(stateMessages, toolResults...)} + + if config.modelWrapperConf != nil { + assistantMsg := stateMessages[len(stateMessages)-1] + tc := &ToolCallsContext{} + for _, toolCall := range assistantMsg.ToolCalls { + tc.ToolCalls = append(tc.ToolCalls, ToolContext{Name: toolCall.Function.Name, CallID: toolCall.ID}) } - return input, nil + for _, handler := range config.modelWrapperConf.handlers { + var err error + ctx, state, err = handler.AfterToolCallsRewriteState(ctx, state, tc) + if err != nil { + return nil, err + } + } } - _ = g.AddLambdaNode(beforeToolNode_, compose.InvokableLambda(beforeToolNode), compose.WithNodeName(beforeToolNode_)) - g.AddEdge(beforeToolNode_, toolNode_) - afterToolNode := func(ctx context.Context, input []Message) (output []Message, err error) { - if sig := checkCancelSig(cs); sig != nil && sig.Mode != CancelAfterChatModel { - return nil, compose.Interrupt(ctx, "cancelled externally") - } + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + st.Messages = state.Messages + return nil + }) - return input, nil + return toolResults, nil + } + _ = g.AddLambdaNode(afterToolCallsNode_, compose.InvokableLambda(afterToolCalls), + compose.WithNodeName(afterToolCallsNode_)) + + // AfterToolCallsCancelCheck: CancelAfterToolCalls safe-point, separated from toolPostHandle. + afterToolCallsCancelCheck := func(ctx context.Context, toolResults []Message) ([]Message, error) { + if cancelCtx != nil && cancelCtx.shouldCancel() { + if cancelCtx.getMode()&CancelAfterToolCalls != 0 { + return nil, compose.Interrupt(ctx, "CancelAfterToolCalls") + } } - _ = g.AddLambdaNode(afterToolNode_, compose.InvokableLambda(afterToolNode), compose.WithNodeName(afterToolNode_)) - g.AddEdge(afterToolNode_, chatModel_) + return toolResults, nil } + _ = g.AddLambdaNode(afterToolCallsCancelCheckNode_, compose.InvokableLambda(afterToolCallsCancelCheck), + compose.WithNodeName(afterToolCallsCancelCheckNode_)) + + _ = g.AddEdge(compose.START, initNode_) + _ = g.AddEdge(initNode_, chatModel_) toolCallCheck := func(ctx context.Context, sMsg MessageStream) (string, error) { defer sMsg.Close() @@ -434,84 +450,54 @@ func newReact(ctx context.Context, config *reactConfig, cs *cancelSig) (reactGra } if len(chunk.ToolCalls) > 0 { - return nodeNameAfterModel(), nil + return cancelCheckNode_, nil } } } - branch := compose.NewStreamGraphBranch(toolCallCheck, map[string]bool{compose.END: true, nodeNameAfterModel(): true}) + branch := compose.NewStreamGraphBranch(toolCallCheck, map[string]bool{compose.END: true, cancelCheckNode_: true}) _ = g.AddBranch(chatModel_, branch) + _ = g.AddEdge(cancelCheckNode_, toolNode_) + _ = g.AddEdge(toolNode_, afterToolCallsNode_) + _ = g.AddEdge(afterToolCallsNode_, afterToolCallsCancelCheckNode_) + if len(config.toolsReturnDirectly) > 0 { const ( toolNodeToEndConverter = "ToolNodeToEndConverter" ) - cvt := func(ctx context.Context, sToolCallMessages sToolNodeOutput) (sGraphOutput, error) { + cvt := func(ctx context.Context, toolResults []Message) (Message, error) { id, _ := getReturnDirectlyToolCallID(ctx) - return schema.StreamReaderWithConvert(sToolCallMessages, - func(in []Message) (Message, error) { - - for _, chunk := range in { - if chunk != nil && chunk.ToolCallID == id { - return chunk, nil - } - } + for _, msg := range toolResults { + if msg != nil && msg.ToolCallID == id { + return msg, nil + } + } - return nil, schema.ErrNoValue - }), nil + return nil, errors.New("return directly tool call result not found") } - _ = g.AddLambdaNode(toolNodeToEndConverter, compose.TransformableLambda(cvt), + _ = g.AddLambdaNode(toolNodeToEndConverter, compose.InvokableLambda(cvt), compose.WithNodeName(toolNodeToEndConverter)) _ = g.AddEdge(toolNodeToEndConverter, compose.END) - checkReturnDirect := func(ctx context.Context, - sToolCallMessages sToolNodeOutput) (string, error) { - + checkReturnDirect := func(ctx context.Context, toolResults []Message) (string, error) { _, ok := getReturnDirectlyToolCallID(ctx) if ok { return toolNodeToEndConverter, nil } - return nodeNameAfterTool(), nil + return chatModel_, nil } - branch = compose.NewStreamGraphBranch(checkReturnDirect, - map[string]bool{toolNodeToEndConverter: true, nodeNameAfterTool(): true}) - _ = g.AddBranch(toolNode_, branch) + returnDirectBranch := compose.NewGraphBranch(checkReturnDirect, + map[string]bool{toolNodeToEndConverter: true, chatModel_: true}) + _ = g.AddBranch(afterToolCallsCancelCheckNode_, returnDirectBranch) } else { - _ = g.AddEdge(toolNode_, nodeNameAfterTool()) + _ = g.AddEdge(afterToolCallsCancelCheckNode_, chatModel_) } return g, nil } - -type cancelSig struct { - done chan struct{} - config atomic.Value -} - -func newCancelSig() *cancelSig { - return &cancelSig{ - done: make(chan struct{}), - } -} - -func (cs *cancelSig) cancel(cfg *cancelConfig) { - cs.config.Store(cfg) - close(cs.done) -} - -func checkCancelSig(cs *cancelSig) *cancelConfig { - if cs == nil { - return nil - } - select { - case <-cs.done: - return cs.config.Load().(*cancelConfig) - default: - return nil - } -} diff --git a/adk/react_test.go b/adk/react_test.go index 969e73e35..b0a6c3985 100644 --- a/adk/react_test.go +++ b/adk/react_test.go @@ -23,6 +23,7 @@ import ( "errors" "fmt" "io" + "math" "math/rand" "testing" @@ -144,16 +145,16 @@ func TestReact(t *testing.T) { toolsReturnDirectly: map[string]bool{}, } - graph, err := newReact(ctx, config, nil) + graph, err := newReact(ctx, config) assert.NoError(t, err) assert.NotNil(t, graph) - compiled, err := graph.Compile(ctx) + compiled, err := graph.Compile(ctx, compose.WithMaxRunSteps(math.MaxInt)) assert.NoError(t, err) assert.NotNil(t, compiled) // Test with a user message - result, err := compiled.Invoke(ctx, &reactInput{messages: []Message{ + result, err := compiled.Invoke(ctx, &reactInput{Messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", @@ -211,16 +212,16 @@ func TestReact(t *testing.T) { toolsReturnDirectly: map[string]bool{info.Name: true}, } - graph, err := newReact(ctx, config, nil) + graph, err := newReact(ctx, config) assert.NoError(t, err) assert.NotNil(t, graph) - compiled, err := graph.Compile(ctx) + compiled, err := graph.Compile(ctx, compose.WithMaxRunSteps(math.MaxInt)) assert.NoError(t, err) assert.NotNil(t, compiled) // Test with a user message when tool returns directly - result, err := compiled.Invoke(ctx, &reactInput{messages: []Message{ + result, err := compiled.Invoke(ctx, &reactInput{Messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", @@ -303,16 +304,16 @@ func TestReact(t *testing.T) { toolsReturnDirectly: map[string]bool{}, } - graph, err := newReact(ctx, config, nil) + graph, err := newReact(ctx, config) assert.NoError(t, err) assert.NotNil(t, graph) - compiled, err := graph.Compile(ctx) + compiled, err := graph.Compile(ctx, compose.WithMaxRunSteps(math.MaxInt)) assert.NoError(t, err) assert.NotNil(t, compiled) // Test streaming with a user message - outStream, err := compiled.Stream(ctx, &reactInput{messages: []Message{ + outStream, err := compiled.Stream(ctx, &reactInput{Messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", @@ -413,11 +414,11 @@ func TestReact(t *testing.T) { toolsReturnDirectly: map[string]bool{streamInfo.Name: true}, } - graph, err := newReact(ctx, config, nil) + graph, err := newReact(ctx, config) assert.NoError(t, err) assert.NotNil(t, graph) - compiled, err := graph.Compile(ctx) + compiled, err := graph.Compile(ctx, compose.WithMaxRunSteps(math.MaxInt)) assert.NoError(t, err) assert.NotNil(t, compiled) @@ -425,7 +426,7 @@ func TestReact(t *testing.T) { times = 0 // Test streaming with a user message when tool returns directly - outStream, err := compiled.Stream(ctx, &reactInput{messages: []Message{ + outStream, err := compiled.Stream(ctx, &reactInput{Messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", @@ -502,16 +503,16 @@ func TestReact(t *testing.T) { maxIterations: 6, } - graph, err := newReact(ctx, config, nil) + graph, err := newReact(ctx, config) assert.NoError(t, err) assert.NotNil(t, graph) - compiled, err := graph.Compile(ctx) + compiled, err := graph.Compile(ctx, compose.WithMaxRunSteps(math.MaxInt)) assert.NoError(t, err) assert.NotNil(t, compiled) // Test with a user message - result, err := compiled.Invoke(ctx, &reactInput{messages: []Message{ + result, err := compiled.Invoke(ctx, &reactInput{Messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", @@ -532,16 +533,16 @@ func TestReact(t *testing.T) { maxIterations: 5, } - graph, err = newReact(ctx, config, nil) + graph, err = newReact(ctx, config) assert.NoError(t, err) assert.NotNil(t, graph) - compiled, err = graph.Compile(ctx) + compiled, err = graph.Compile(ctx, compose.WithMaxRunSteps(math.MaxInt)) assert.NoError(t, err) assert.NotNil(t, compiled) // Test with a user message - result, err = compiled.Invoke(ctx, &reactInput{messages: []Message{ + result, err = compiled.Invoke(ctx, &reactInput{Messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", diff --git a/adk/retry_chatmodel.go b/adk/retry_chatmodel.go index 8ae4e2aac..bac955033 100644 --- a/adk/retry_chatmodel.go +++ b/adk/retry_chatmodel.go @@ -196,6 +196,11 @@ func (r *retryModelWrapper) Generate(ctx context.Context, input []*schema.Messag return out, nil } + // Never retry interrupt errors (e.g. cancel safe-point interrupts). + if _, ok := compose.ExtractInterruptInfo(err); ok { + return nil, err + } + if !isRetryAble(ctx, err) { return nil, err } @@ -238,6 +243,10 @@ func (r *retryModelWrapper) Stream(ctx context.Context, input []*schema.Message, stream, err := r.inner.Stream(ctx, input, opts...) if err != nil { + // Never retry interrupt errors (e.g. cancel safe-point interrupts). + if _, ok := compose.ExtractInterruptInfo(err); ok { + return nil, err + } if !isRetryAble(ctx, err) { return nil, err } diff --git a/adk/runctx.go b/adk/runctx.go index 6e2a6cfbe..1a32f1760 100644 --- a/adk/runctx.go +++ b/adk/runctx.go @@ -24,8 +24,6 @@ import ( "sort" "sync" "time" - - "github.com/cloudwego/eino/schema" ) // runSession CheckpointSchema: persisted via serialization.RunCtx (gob). @@ -65,8 +63,14 @@ type agentEventWrapper struct { type otherAgentEventWrapperForEncode agentEventWrapper func (a *agentEventWrapper) GobEncode() ([]byte, error) { - if a.concatenatedMessage != nil && a.Output != nil && a.Output.MessageOutput != nil && a.Output.MessageOutput.IsStreaming { - a.Output.MessageOutput.MessageStream = schema.StreamReaderFromArray([]Message{a.concatenatedMessage}) + if a.Output != nil && a.Output.MessageOutput != nil && a.Output.MessageOutput.IsStreaming { + // Materialize the stream before encoding. An unconsumed stream that + // ends with a non-EOF error (WillRetryError, ErrStreamCanceled) would + // cause MessageVariant.GobEncode to fail. consumeStream replaces the + // stream with an error-free, materialized version. + if a.concatenatedMessage == nil && a.StreamErr == nil { + a.consumeStream() + } } buf := &bytes.Buffer{} diff --git a/adk/runctx_test.go b/adk/runctx_test.go index 7f164b3e2..bef1f44eb 100644 --- a/adk/runctx_test.go +++ b/adk/runctx_test.go @@ -17,7 +17,10 @@ package adk import ( + "bytes" "context" + "encoding/gob" + "errors" "testing" "time" @@ -423,3 +426,209 @@ func TestForkJoinRunCtx(t *testing.T) { mainRunCtx.Session.addEvent(eventF) assert.Equal(t, []string{"A", "B", "C1", "D", "E", "F"}, getEventNames(mainRunCtx.Session.getEvents()), "After F") } + +// makeStreamingEventWrapper creates an agentEventWrapper with a streaming MessageOutput +// whose stream yields the given message then terminates with streamErr (or io.EOF if nil). +func makeStreamingEventWrapper(msg Message, streamErr error) *agentEventWrapper { + r, w := schema.Pipe[Message](2) + w.Send(msg, nil) + if streamErr != nil { + w.Send(nil, streamErr) + } + w.Close() + + return &agentEventWrapper{ + AgentEvent: &AgentEvent{ + AgentName: "test-agent", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: true, + MessageStream: r, + Role: schema.Assistant, + }, + }, + }, + } +} + +func TestGobEncodeStreamErrors(t *testing.T) { + t.Run("WillRetryError_unconsumed_stream_fails_GobEncode", func(t *testing.T) { + // An agentEventWrapper whose stream yields a message then WillRetryError. + // Without pre-consuming (no getMessageFromWrappedEvent call), GobEncode + // reaches MessageVariant.GobEncode which treats non-EOF errors as fatal. + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("partial", nil), + &WillRetryError{ErrStr: "model error", RetryAttempt: 1}, + ) + + _, err := wrapper.GobEncode() + assert.NoError(t, err, "GobEncode should handle WillRetryError streams gracefully") + }) + + t.Run("ErrStreamCanceled_unconsumed_stream_fails_GobEncode", func(t *testing.T) { + // Same scenario but with ErrStreamCanceled (*errors.errorString). + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("partial", nil), + ErrStreamCanceled, + ) + + _, err := wrapper.GobEncode() + assert.NoError(t, err, "GobEncode should handle ErrStreamCanceled streams gracefully") + }) + + t.Run("successful_stream_GobEncode_succeeds", func(t *testing.T) { + // Control: a clean stream (no error) should encode fine. + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("hello", nil), + nil, // no stream error + ) + + data, err := wrapper.GobEncode() + assert.NoError(t, err) + assert.NotEmpty(t, data) + + // Verify round-trip decode works. + decoded := &agentEventWrapper{AgentEvent: &AgentEvent{}} + err = decoded.GobDecode(data) + assert.NoError(t, err) + assert.Equal(t, "test-agent", decoded.AgentName) + }) + + t.Run("preconsumed_WillRetryError_GobEncode_succeeds", func(t *testing.T) { + // When getMessageFromWrappedEvent is called first, WillRetryError is + // cached in StreamErr and the stream is replaced with an error-free array. + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("partial", nil), + &WillRetryError{ErrStr: "model error", RetryAttempt: 1}, + ) + + _, consumeErr := getMessageFromWrappedEvent(wrapper) + assert.Error(t, consumeErr) + + data, err := wrapper.GobEncode() + assert.NoError(t, err, "GobEncode should succeed after pre-consuming WillRetryError stream") + assert.NotEmpty(t, data) + }) + + t.Run("preconsumed_ErrStreamCanceled_GobEncode_succeeds", func(t *testing.T) { + // ErrStreamCanceled is a *StreamCanceledError which IS gob-registered. + // After getMessageFromWrappedEvent, StreamErr = ErrStreamCanceled. + // Since it's registered, gob encoding succeeds. + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("partial", nil), + ErrStreamCanceled, + ) + + _, consumeErr := getMessageFromWrappedEvent(wrapper) + assert.Error(t, consumeErr) + + data, err := wrapper.GobEncode() + assert.NoError(t, err, "GobEncode should succeed; ErrStreamCanceled is gob-registered") + assert.NotEmpty(t, data) + }) + + t.Run("GobEncode_roundtrip_preserves_content", func(t *testing.T) { + // Verify that after GobEncode with a WillRetryError stream, + // the decoded wrapper has the partial message content and StreamErr intact. + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("partial response", nil), + &WillRetryError{ErrStr: "err", RetryAttempt: 1}, + ) + + data, err := wrapper.GobEncode() + assert.NoError(t, err) + + decoded := &agentEventWrapper{AgentEvent: &AgentEvent{}} + err = decoded.GobDecode(data) + assert.NoError(t, err) + assert.Equal(t, "test-agent", decoded.AgentName) + assert.True(t, decoded.Output.MessageOutput.IsStreaming) + // The stream should be consumable and yield the partial message. + msg, recvErr := decoded.Output.MessageOutput.MessageStream.Recv() + assert.NoError(t, recvErr) + assert.Contains(t, msg.Content, "partial response") + // StreamErr should be preserved for end-user visibility. + var willRetryErr *WillRetryError + assert.True(t, errors.As(decoded.StreamErr, &willRetryErr)) + assert.Equal(t, "err", willRetryErr.ErrStr) + }) + + t.Run("GobEncode_roundtrip_preserves_ErrStreamCanceled", func(t *testing.T) { + // ErrStreamCanceled (*StreamCanceledError) is gob-registered, so + // StreamErr should survive encoding/decoding. + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("partial", nil), + ErrStreamCanceled, + ) + + data, err := wrapper.GobEncode() + assert.NoError(t, err) + + decoded := &agentEventWrapper{AgentEvent: &AgentEvent{}} + err = decoded.GobDecode(data) + assert.NoError(t, err) + var streamCanceledErr *StreamCanceledError + assert.ErrorAs(t, decoded.StreamErr, &streamCanceledErr) + }) + + t.Run("GobEncode_idempotent", func(t *testing.T) { + // Calling GobEncode twice should succeed both times (stream replaced on first call). + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("hello", nil), + &WillRetryError{ErrStr: "err", RetryAttempt: 1}, + ) + + data1, err := wrapper.GobEncode() + assert.NoError(t, err) + + data2, err := wrapper.GobEncode() + assert.NoError(t, err) + + // Both should decode to equivalent content. + d1, d2 := &agentEventWrapper{AgentEvent: &AgentEvent{}}, &agentEventWrapper{AgentEvent: &AgentEvent{}} + assert.NoError(t, d1.GobDecode(data1)) + assert.NoError(t, d2.GobDecode(data2)) + assert.Equal(t, d1.AgentName, d2.AgentName) + }) + + t.Run("GobEncode_non_streaming_unaffected", func(t *testing.T) { + // Non-streaming events should encode/decode as before. + wrapper := &agentEventWrapper{ + AgentEvent: &AgentEvent{ + AgentName: "non-stream-agent", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: false, + Message: schema.AssistantMessage("direct", nil), + Role: schema.Assistant, + }, + }, + }, + } + + data, err := wrapper.GobEncode() + assert.NoError(t, err) + + decoded := &agentEventWrapper{AgentEvent: &AgentEvent{}} + assert.NoError(t, decoded.GobDecode(data)) + assert.Equal(t, "non-stream-agent", decoded.AgentName) + assert.False(t, decoded.Output.MessageOutput.IsStreaming) + }) + + t.Run("GobEncode_within_runSession", func(t *testing.T) { + // Simulate the real scenario: a runSession with a streaming event containing + // WillRetryError is gob-encoded (as happens during checkpoint save). + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("checkpoint content", nil), + &WillRetryError{ErrStr: "retry", RetryAttempt: 1}, + ) + + session := newRunSession() + session.Events = []*agentEventWrapper{wrapper} + + // Encode the entire session (the checkpoint path). + var buf bytes.Buffer + err := gob.NewEncoder(&buf).Encode(session) + assert.NoError(t, err, "encoding runSession with WillRetryError stream should succeed") + }) +} diff --git a/adk/runner.go b/adk/runner.go index a9fd9b94d..4881122a6 100644 --- a/adk/runner.go +++ b/adk/runner.go @@ -18,6 +18,7 @@ package adk import ( "context" + "errors" "fmt" "runtime/debug" "sync" @@ -41,6 +42,8 @@ type Runner struct { type CheckPointStore = core.CheckPointStore +type CheckPointDeleter = core.CheckPointDeleter + type RunnerConfig struct { Agent Agent EnableStreaming bool @@ -74,31 +77,6 @@ func NewRunner(_ context.Context, conf RunnerConfig) *Runner { // upon interruption. func (r *Runner) Run(ctx context.Context, messages []Message, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { - iter, _, _ := r.runWithCancel(ctx, messages, false, opts...) - return iter -} - -// Query is a convenience method that starts a new execution with a single user query string. -func (r *Runner) Query(ctx context.Context, - query string, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { - - return r.Run(ctx, []Message{schema.UserMessage(query)}, opts...) -} - -// RunWithCancel starts a new execution of the agent and returns both an iterator and a cancel function. -// The cancel function can be used to interrupt the running agent at specific points based on the CancelMode. -// If the Runner was configured with a CheckPointStore and WithCheckPointID option, it will automatically -// save the agent's state upon cancellation for later resumption. -// -// If the agent does not implement CancellableAgent, the returned CancelFunc will be nil. -func (r *Runner) RunWithCancel(ctx context.Context, messages []Message, - opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { - iter, cancelFn, _ := r.runWithCancel(ctx, messages, true, opts...) - return iter, cancelFn -} - -func (r *Runner) runWithCancel(ctx context.Context, messages []Message, withCancel bool, - opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc, error) { o := getCommonOptions(nil, opts...) fa := toFlowAgent(ctx, r.a) @@ -112,46 +90,23 @@ func (r *Runner) runWithCancel(ctx context.Context, messages []Message, withCanc AddSessionValues(ctx, o.sessionValues) - var iter *AsyncIterator[*AgentEvent] - var cancelFn CancelFunc - if withCancel { - if _, ok := r.a.(CancellableAgent); ok { - iter, cancelFn = fa.RunWithCancel(ctx, input, opts...) - } else { - iter = fa.Run(ctx, input, opts...) - } - } else { - iter = fa.Run(ctx, input, opts...) - } + iter := fa.Run(ctx, input, opts...) - if r.store == nil { - return iter, cancelFn, nil + if r.store == nil && o.cancelCtx == nil { + return iter } niter, gen := NewAsyncIteratorPair[*AgentEvent]() - go r.handleIter(ctx, iter, gen, o.checkPointID) - return niter, cancelFn, nil + go r.handleIter(ctx, iter, gen, o.checkPointID, o.cancelCtx) + return niter } -// ResumeWithCancel continues an interrupted execution from a checkpoint and returns both an iterator and a cancel function. -// This method uses the "Implicit Resume All" strategy where all previously interrupted points proceed without specific data. -// The cancel function can be used to interrupt the running agent again at specific points based on the CancelMode. -// -// If the agent does not implement CancellableResumableAgent, the returned CancelFunc will be nil. -func (r *Runner) ResumeWithCancel(ctx context.Context, checkPointID string, opts ...AgentRunOption) ( - *AsyncIterator[*AgentEvent], CancelFunc, error) { - return r.resumeWithCancel(ctx, checkPointID, nil, true, opts...) -} +// Query is a convenience method that starts a new execution with a single user query string. +func (r *Runner) Query(ctx context.Context, + query string, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { -// ResumeWithParamsAndCancel continues an interrupted execution from a checkpoint with specific parameters -// and returns both an iterator and a cancel function. -// The params.Targets map should contain the addresses of the components to be resumed as keys. -// -// If the agent does not implement CancellableResumableAgent, the returned CancelFunc will be nil. -func (r *Runner) ResumeWithParamsAndCancel(ctx context.Context, checkPointID string, params *ResumeParams, - opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc, error) { - return r.resumeWithCancel(ctx, checkPointID, params.Targets, true, opts...) + return r.Run(ctx, []Message{schema.UserMessage(query)}, opts...) } // Resume continues an interrupted execution from a checkpoint, using an "Implicit Resume All" strategy. @@ -163,8 +118,7 @@ func (r *Runner) ResumeWithParamsAndCancel(ctx context.Context, checkPointID str // pattern where an agent only needs to know `wasInterrupted` is true to continue. func (r *Runner) Resume(ctx context.Context, checkPointID string, opts ...AgentRunOption) ( *AsyncIterator[*AgentEvent], error) { - iter, _, err := r.resumeWithCancel(ctx, checkPointID, nil, false, opts...) - return iter, err + return r.resumeInternal(ctx, checkPointID, nil, opts...) } // ResumeWithParams continues an interrupted execution from a checkpoint with specific parameters. @@ -186,19 +140,18 @@ func (r *Runner) Resume(ctx context.Context, checkPointID string, opts ...AgentR // naturally re-interrupt if one of their interrupted children re-interrupts, as they receive the // new `CompositeInterrupt` signal from them. func (r *Runner) ResumeWithParams(ctx context.Context, checkPointID string, params *ResumeParams, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], error) { - iter, _, err := r.resumeWithCancel(ctx, checkPointID, params.Targets, false, opts...) - return iter, err + return r.resumeInternal(ctx, checkPointID, params.Targets, opts...) } -func (r *Runner) resumeWithCancel(ctx context.Context, checkPointID string, resumeData map[string]any, - withCancel bool, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc, error) { +func (r *Runner) resumeInternal(ctx context.Context, checkPointID string, resumeData map[string]any, + opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], error) { if r.store == nil { - return nil, nil, fmt.Errorf("failed to resume: store is nil") + return nil, fmt.Errorf("failed to resume: store is nil") } ctx, runCtx, resumeInfo, err := r.loadCheckPoint(ctx, checkPointID) if err != nil { - return nil, nil, fmt.Errorf("failed to load from checkpoint: %w", err) + return nil, fmt.Errorf("failed to load from checkpoint: %w", err) } o := getCommonOptions(nil, opts...) @@ -226,30 +179,20 @@ func (r *Runner) resumeWithCancel(ctx context.Context, checkPointID string, resu fa := toFlowAgent(ctx, r.a) - var aIter *AsyncIterator[*AgentEvent] - var cancelFn CancelFunc - if withCancel { - if _, ok := r.a.(CancellableResumableAgent); ok { - aIter, cancelFn = fa.ResumeWithCancel(ctx, resumeInfo, opts...) - } else { - aIter = fa.Resume(ctx, resumeInfo, opts...) - } - } else { - aIter = fa.Resume(ctx, resumeInfo, opts...) - } + aIter := fa.Resume(ctx, resumeInfo, opts...) - if r.store == nil { - return aIter, cancelFn, nil + if r.store == nil && o.cancelCtx == nil { + return aIter, nil } niter, gen := NewAsyncIteratorPair[*AgentEvent]() - go r.handleIter(ctx, aIter, gen, &checkPointID) - return niter, cancelFn, nil + go r.handleIter(ctx, aIter, gen, &checkPointID, o.cancelCtx) + return niter, nil } func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEvent], - gen *AsyncGenerator[*AgentEvent], checkPointID *string) { + gen *AsyncGenerator[*AgentEvent], checkPointID *string, cancelCtx *cancelContext) { defer func() { panicErr := recover() if panicErr != nil { @@ -269,6 +212,25 @@ func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEven break } + if event.Err != nil { + var cancelErr *CancelError + if errors.As(event.Err, &cancelErr) { + if cancelCtx != nil && cancelCtx.isRoot() && cancelCtx.shouldCancel() { + cancelCtx.markCancelHandled() + } + if cancelErr.interruptSignal != nil && checkPointID != nil { + cancelErr.CheckPointID = *checkPointID + cancelErr.InterruptContexts = core.ToInterruptContexts(cancelErr.interruptSignal, allowedAddressSegmentTypes) + err := r.saveCheckPoint(ctx, *checkPointID, &InterruptInfo{}, cancelErr.interruptSignal) + if err != nil { + gen.Send(&AgentEvent{Err: fmt.Errorf("failed to save checkpoint on cancel: %w", err)}) + } + } + gen.Send(event) + break + } + } + if event.Action != nil && event.Action.internalInterrupted != nil { if interruptSignal != nil { // even if multiple interrupt happens, they should be merged into one @@ -293,8 +255,7 @@ func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEven legacyData = event.Action.Interrupted.Data if checkPointID != nil { - // save checkpoint first before sending interrupt event, - // so when end-user receives interrupt event, they can resume from this checkpoint + // save checkpoint first before sending interrupt event, so when end-user receives interrupt event, they can resume from this checkpoint err := r.saveCheckPoint(ctx, *checkPointID, &InterruptInfo{ Data: legacyData, }, interruptSignal) diff --git a/adk/turn_loop.go b/adk/turn_loop.go index f7ec423c5..88979876e 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -17,7 +17,9 @@ package adk import ( + "bytes" "context" + "encoding/gob" "errors" "fmt" "runtime/debug" @@ -25,558 +27,1387 @@ import ( "sync/atomic" "time" + "github.com/cloudwego/eino/internal" "github.com/cloudwego/eino/internal/safe" ) -// ConsumeMode specifies how a received message should be consumed -// relative to the currently running agent. -type ConsumeMode int - -const ( - // ConsumeNonPreemptive processes the message after the current agent - // finishes. This is the default queued behavior. - ConsumeNonPreemptive ConsumeMode = iota - // ConsumePreemptive cancels the currently running agent (if it - // implements Cancellable) and processes the message immediately. - // If the agent does not implement Cancellable, the message is - // buffered and processed after the agent finishes. - ConsumePreemptive - - ConsumePreemptiveOnTimeout -) - -type consumeConfig struct { - Mode ConsumeMode - Timeout time.Duration - CancelOpts []CancelOption - CheckPointID string +// stopSignal coordinates the Stop() call with per-turn watcher goroutines. +// +// Lifecycle overview: +// +// 1. SIGNAL — Stop() calls signal() which bumps the generation counter, +// stores the AgentCancelOptions, and deposits a one-shot notification +// in the buffered notify channel. +// +// 2. DONE — Stop() calls closeDone() which permanently closes the done +// channel. This acts as a durable "stopped" flag: any current or future +// select on done fires immediately, ensuring that every watcher — +// including watchers in turns that start after Stop() but before the +// run loop observes isStopped() — can reliably detect the stop. +// +// 3. RECEIVE — The per-turn watchStopSignal goroutine selects on the done +// channel (the durable flag) and the notify channel (to detect mode +// escalation from a second Stop call). On either signal, it calls +// agentCancelFunc to cancel the running agent. +// +// The generation counter (gen) de-duplicates wakes so that the watcher only +// acts when a new Stop() call has been made, supporting mode escalation +// (e.g. CancelAfterToolCalls followed by CancelImmediate). +type stopSignal struct { + // done is closed exactly once by closeDone(). A closed channel is + // readable forever, so it serves as a durable stop flag for all watchers. + done chan struct{} + + mu sync.Mutex + gen uint64 + agentCancelOpts []AgentCancelOption + // notify is a buffered(1) channel that wakes the current turn's watcher + // when Stop() is called. Unlike done, it supports repeated Stop() calls + // for cancel-mode escalation. + notify chan struct{} } -type ConsumeOption func(*consumeConfig) - -// WithPreemptive sets the consume mode to preemptive, which cancels the -// currently running agent immediately. -func WithPreemptive() ConsumeOption { - return func(config *consumeConfig) { - config.Mode = ConsumePreemptive +func newStopSignal() *stopSignal { + return &stopSignal{ + done: make(chan struct{}), + notify: make(chan struct{}, 1), } } -// WithPreemptiveOnTimeout sets the consume mode to preemptive with a timeout. -// If the current agent does not complete within the timeout, it will be canceled. -func WithPreemptiveOnTimeout(timeout time.Duration) ConsumeOption { - return func(config *consumeConfig) { - config.Mode = ConsumePreemptiveOnTimeout - config.Timeout = timeout +// signal records a stop request and wakes the current turn's watcher (if any). +// The non-blocking send means the notification is silently coalesced when the +// buffer is already full — this is safe because gen de-duplicates in the watcher. +func (s *stopSignal) signal(cfg *stopConfig) { + s.mu.Lock() + s.gen++ + s.agentCancelOpts = cfg.agentCancelOpts + s.mu.Unlock() + select { + case s.notify <- struct{}{}: + default: } } -// WithCancelOptions appends cancel options to be used when canceling the agent. -func WithCancelOptions(opts ...CancelOption) ConsumeOption { - return func(config *consumeConfig) { - config.CancelOpts = append(config.CancelOpts, opts...) +// isStopped returns true if closeDone() has been called. +func (s *stopSignal) isStopped() bool { + select { + case <-s.done: + return true + default: + return false } } -// WithConsumeCheckPointID sets the checkpoint ID for the consumed message. -// When set, the checkpoint will be saved with this ID if an interrupt occurs. -func WithConsumeCheckPointID(id string) ConsumeOption { - return func(config *consumeConfig) { - config.CheckPointID = id +// closeDone permanently marks the stop as committed. All current and future +// selects on s.done will fire immediately after this call. +func (s *stopSignal) closeDone() { + close(s.done) +} + +// check returns the current generation and a snapshot of the cancel options. +func (s *stopSignal) check() (uint64, []AgentCancelOption) { + s.mu.Lock() + defer s.mu.Unlock() + return s.gen, append([]AgentCancelOption{}, s.agentCancelOpts...) +} + +// preemptSignal coordinates preemption between Push callers and the run loop. +// +// Lifecycle overview: +// +// 1. HOLD — A Push caller (or the run loop itself) calls holdRunLoop() to +// increment holdCount. While holdCount > 0 the run loop blocks at +// waitForPreemptOrUnhold(), preventing it from starting a new turn. +// +// 2. REQUEST — The Push caller calls requestPreempt() which sets +// preemptRequested=true, bumps preemptGen, stores cancelOpts/acks, and +// wakes both the run-loop (via cond) and the in-turn watcher goroutine +// (via notify channel). +// +// 3. RECEIVE — The per-turn watchPreemptSignal goroutine calls +// receivePreempt(), obtains the cancel opts and ack channels, invokes +// agentCancelFunc to cancel the running agent, and closes the ack +// channels to notify Push callers. +// +// 4. UNHOLD — After the turn finishes (or if the Push caller decides not +// to preempt), unholdRunLoop() / endTurnAndUnhold() decrements +// holdCount. When holdCount reaches 0, all signal state is reset. +// +// The run loop brackets every turn with holdRunLoop() / endTurnAndUnhold() +// so that a concurrent Push caller's hold keeps holdCount > 0 even after +// the turn ends, preventing the loop from racing into a new turn before +// the Push caller's preempt request is delivered. +// +// Fields currentTC and currentRunCtx are stored here (rather than on +// TurnLoop) so that holdAndGetTurn() can atomically snapshot the turn +// state and increment holdCount under the same mu lock, eliminating the +// TOCTOU race between reading the turn and holding the loop. +type preemptSignal struct { + mu sync.Mutex + cond *sync.Cond + holdCount int + preemptRequested bool + preemptGen uint64 + agentCancelOpts []AgentCancelOption + pendingAckList []chan struct{} + notify chan struct{} + + currentTC any + currentRunCtx context.Context +} + +func newPreemptSignal() *preemptSignal { + s := &preemptSignal{notify: make(chan struct{}, 1)} + s.cond = sync.NewCond(&s.mu) + return s +} + +func (s *preemptSignal) holdRunLoop() { + s.mu.Lock() + s.holdCount++ + s.mu.Unlock() +} + +func (s *preemptSignal) setTurn(ctx context.Context, tc any) { + s.mu.Lock() + s.currentRunCtx = ctx + s.currentTC = tc + s.mu.Unlock() +} + +func (s *preemptSignal) holdAndGetTurn() (context.Context, any) { + s.mu.Lock() + defer s.mu.Unlock() + s.holdCount++ + return s.currentRunCtx, s.currentTC +} + +// requestPreempt records a preempt request and wakes both waiters. +// If holdCount is 0, no one is listening — close the ack immediately as a no-op. +func (s *preemptSignal) requestPreempt(ack chan struct{}, opts ...AgentCancelOption) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.holdCount <= 0 { + if ack != nil { + close(ack) + } + return + } + + s.preemptRequested = true + s.preemptGen++ + s.agentCancelOpts = opts + if ack != nil { + s.pendingAckList = append(s.pendingAckList, ack) } + select { + case s.notify <- struct{}{}: + default: + } + + s.cond.Broadcast() } -type ReceiveConfig struct { - Timeout time.Duration +// receivePreempt is called by the per-turn watcher goroutine to consume a +// pending preempt. It drains pendingAckList (so the watcher can close them +// after invoking agentCancelFunc) but intentionally preserves preemptRequested +// and preemptGen — these are needed by waitForPreemptOrUnhold on the run loop. +func (s *preemptSignal) receivePreempt() (bool, uint64, []AgentCancelOption, []chan struct{}) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.preemptRequested { + ackList := s.pendingAckList + s.pendingAckList = nil + return true, s.preemptGen, s.agentCancelOpts, ackList + } + return false, 0, nil, nil } -type MessageSource[T any] interface { - Receive(context.Context, ReceiveConfig) (context.Context, T, []ConsumeOption, error) - Front(context.Context, ReceiveConfig) (context.Context, T, []ConsumeOption, error) +// waitForPreemptOrUnhold blocks the run loop between turns. It returns early +// (preempted=false) when holdCount is 0 (no Push caller is holding). Otherwise +// it blocks until either a preempt is requested or all holders release. +func (s *preemptSignal) waitForPreemptOrUnhold() (preempted bool, opts []AgentCancelOption, ackList []chan struct{}) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.holdCount <= 0 { + return false, nil, nil + } + + for s.holdCount > 0 && !s.preemptRequested { + s.cond.Wait() + } + + if s.preemptRequested { + ackList = s.pendingAckList + s.pendingAckList = nil + return true, s.agentCancelOpts, ackList + } + return false, nil, nil } -type turnLoopRunConfig[T any] struct { - checkPointID string - item T +// resetLocked clears all signal state and closes pending ack channels so the +// next cycle starts clean and blocked Push callers are unblocked. Must be +// called with s.mu held. Does NOT touch holdCount, currentTC, or currentRunCtx +// — callers are responsible for those. +func (s *preemptSignal) resetLocked() { + s.preemptRequested = false + s.preemptGen = 0 + s.agentCancelOpts = nil + for _, ack := range s.pendingAckList { + close(ack) + } + s.pendingAckList = nil + select { + case <-s.notify: + default: + } } -// TurnLoopRunOption is an option for TurnLoop.Run. -type TurnLoopRunOption[T any] func(*turnLoopRunConfig[T]) +// unholdRunLoop drops one hold. When holdCount reaches 0, all signal state is +// reset so the next cycle starts clean. +func (s *preemptSignal) unholdRunLoop() { + s.mu.Lock() + defer s.mu.Unlock() + + s.holdCount-- + if s.holdCount < 0 { + s.holdCount = 0 + } + if s.holdCount == 0 { + s.resetLocked() + } + s.cond.Broadcast() +} -// WithTurnLoopResume configures the TurnLoop to resume from a previously saved checkpoint. -// The checkPointID identifies the checkpoint to resume from, and item is the original input -// that triggered the interrupted execution. -func WithTurnLoopResume[T any](checkPointID string, item T) TurnLoopRunOption[T] { - return func(c *turnLoopRunConfig[T]) { - c.checkPointID = checkPointID - c.item = item +// endTurnAndUnhold is called by the run loop after runAgentAndHandleEvents +// returns. It clears the current turn context and drops the run loop's hold. +func (s *preemptSignal) endTurnAndUnhold() { + s.mu.Lock() + defer s.mu.Unlock() + + s.currentTC = nil + s.currentRunCtx = nil + s.holdCount-- + if s.holdCount < 0 { + s.holdCount = 0 + } + if s.holdCount == 0 { + s.resetLocked() } + s.cond.Broadcast() +} + +// drainAll forcefully resets all preemptSignal state and closes any pending +// ack channels. Called during TurnLoop cleanup to prevent ack channels from +// leaking when the run loop exits (e.g. due to Stop) while a Push caller +// still holds a reference. +func (s *preemptSignal) drainAll() { + s.mu.Lock() + defer s.mu.Unlock() + + s.holdCount = 0 + s.currentTC = nil + s.currentRunCtx = nil + s.resetLocked() + s.cond.Broadcast() } // TurnLoopConfig is the configuration for creating a TurnLoop. type TurnLoopConfig[T any] struct { - // Source provides messages to drive the loop. Required. - Source MessageSource[T] - // GenInput converts a received message into AgentInput and optional - // RunOptions for the agent. Required. - GenInput func(ctx context.Context, item T) (*AgentInput, []AgentRunOption, error) - // GetAgent returns the Agent to run for a given message. Required. - GetAgent func(ctx context.Context, item T) (Agent, error) - // OnAgentEvents is called for each event emitted by the agent. Optional. - // The inputItem is the message that triggered the current agent turn. - // If not provided, the default implementation will consume all events and - // return any error event encountered. - OnAgentEvents func(ctx context.Context, inputItem T, event *AsyncIterator[*AgentEvent]) error - // ReceiveTimeout is the timeout passed to Source.Receive on each iteration. - // Zero means no timeout. Optional. - ReceiveTimeout time.Duration - + // GenInput receives the TurnLoop instance and all buffered items, and decides what to process. + // It returns which items to consume now vs keep for later turns. + // The loop parameter allows calling Push() or Stop() directly from within the callback. + // Required. + GenInput func(ctx context.Context, loop *TurnLoop[T], items []T) (*GenInputResult[T], error) + + // GenResume is called exactly once when the TurnLoop detects a mid-turn + // checkpoint on startup (i.e. CheckpointID is configured and the stored + // checkpoint has runner state from an interrupted agent execution). + // It receives: + // - canceledItems: the items being processed when the prior run was canceled + // - unhandledItems: items buffered but not processed when the prior run exited + // - newItems: items that were Push()-ed before Run() was called + // + // It returns a GenResumeResult describing how to resume the interrupted agent + // turn (optional ResumeParams) and how to manipulate the buffer + // (Consumed/Remaining) before continuing. + GenResume func(ctx context.Context, loop *TurnLoop[T], canceledItems, unhandledItems, newItems []T) (*GenResumeResult[T], error) + + // PrepareAgent returns an Agent configured to handle the consumed items. + // This callback should set up the agent with appropriate system prompt, + // tools, and middlewares based on what items are being processed. + // Called once per turn with the items that GenInput decided to consume. + // The loop parameter allows calling Push() or Stop() directly from within the callback. + // Required. + PrepareAgent func(ctx context.Context, loop *TurnLoop[T], consumed []T) (Agent, error) + + // OnAgentEvents is called to handle events emitted by the agent. + // The TurnContext provides per-turn info and control: + // - tc.Consumed: items that triggered this agent execution + // - tc.Loop: allows calling Push() or Stop() directly from within the callback + // - tc.Preempted / tc.Stopped: signals while processing events + // Optional. If not provided, events are drained and errors (except CancelError + // from Stop-triggered cancellation) are returned as ExitReason. + OnAgentEvents func(ctx context.Context, tc *TurnContext[T], events *AsyncIterator[*AgentEvent]) error + + // Store is the checkpoint store for persistence and resume. Optional. + // When set together with CheckpointID, enables automatic checkpoint-based resume. + // The TurnLoop always persists both runner checkpoint bytes and item bookkeeping + // (CanceledItems, UnhandledItems) via gob encoding, so T must be gob-encodable + // when Store is used. Store CheckPointStore + + // CheckpointID, when set together with Store, enables automatic + // checkpoint-based resume. On Run(), the TurnLoop queries Store for this ID: + // - If a checkpoint exists with runner state (mid-turn interrupt), + // GenResume is called to plan the resume turn. + // - If a checkpoint exists without runner state (between-turns), + // the stored unhandled items are buffered and the loop proceeds + // normally via GenInput. + // - If no checkpoint exists, the loop starts fresh. + // + // On exit, if the TurnLoop saved a new checkpoint, it is saved under this + // same CheckpointID. On clean exit (no checkpoint saved), the existing + // checkpoint under CheckpointID is deleted to prevent stale resumption. + CheckpointID string } -// TurnLoop is a loop that pulls messages from a source, runs an Agent for -// each message, and dispatches resulting events. It supports preemptive -// cancellation when the source returns ConsumePreemptive and the current -// agent implements Cancellable. -type TurnLoop[T any] struct { - source MessageSource[T] - genInput func(ctx context.Context, item T) (*AgentInput, []AgentRunOption, error) - getAgent func(ctx context.Context, item T) (Agent, error) - onAgentEvents func(ctx context.Context, inputItem T, event *AsyncIterator[*AgentEvent]) error - receiveTimeout time.Duration - store CheckPointStore +// GenInputResult contains the result of GenInput processing. +type GenInputResult[T any] struct { + // RunCtx, if non-nil, overrides the context for this turn's execution + // (PrepareAgent, agent run, OnAgentEvents). + // + // Must be derived from the ctx passed to GenInput to preserve the + // TurnLoop's cancellation semantics and inherited values. For example: + // + // runCtx := context.WithValue(ctx, traceKey{}, extractTraceID(items)) + // return &GenInputResult[T]{RunCtx: runCtx, ...}, nil + // + // If nil, the TurnLoop's context is used unchanged. + RunCtx context.Context + + // Input is the agent input to execute + Input *AgentInput + + // RunOpts are the options for this agent run + RunOpts []AgentRunOption + + // Consumed are the items selected for this turn. + // They are removed from the buffer and passed to PrepareAgent. + Consumed []T + + // Remaining are the items to keep in the buffer for a future turn. + // TurnLoop pushes Remaining back into the buffer before running the agent. + // + // Items from the GenInput input slice that are in neither Consumed nor Remaining + // are dropped by the loop. + Remaining []T } -type turnLoopCancelSig struct { - done chan struct{} - config atomic.Value +// GenResumeResult contains the result of GenResume processing. +type GenResumeResult[T any] struct { + // RunCtx, if non-nil, overrides the context for this resumed turn's execution + // (PrepareAgent, agent resume, OnAgentEvents). + RunCtx context.Context + + // RunOpts are the options for this agent resume run. + RunOpts []AgentRunOption + + // ResumeParams are optional parameters for resuming an interrupted agent. + ResumeParams *ResumeParams + + // Consumed are the items selected for this resumed turn. + // They are removed from the buffer and passed to PrepareAgent. + Consumed []T + + // Remaining are the items to keep in the buffer for a future turn. + // TurnLoop pushes Remaining back into the buffer before resuming the agent. + // + // Items from (canceledItems, unhandledItems, newItems) that are in neither Consumed + // nor Remaining are dropped by the loop. + Remaining []T } -func newTurnLoopCancelSig() *turnLoopCancelSig { - return &turnLoopCancelSig{ - done: make(chan struct{}), - } +type turnRunSpec[T any] struct { + runCtx context.Context + input *AgentInput + runOpts []AgentRunOption + resumeParams *ResumeParams + isResume bool + consumed []T + resumeBytes []byte } -func (cs *turnLoopCancelSig) cancel(cfg *cancelConfig) { - cs.config.Store(cfg) - close(cs.done) +type turnPlan[T any] struct { + turnCtx context.Context + remaining []T + spec *turnRunSpec[T] } -func (cs *turnLoopCancelSig) isCancelled() bool { - select { - case <-cs.done: - return true - default: - return false +func (l *TurnLoop[T]) planTurn( + ctx context.Context, + isResume bool, + items []T, + pr *turnLoopPendingResume[T], +) (*turnPlan[T], error) { + if !isResume { + result, err := l.config.GenInput(ctx, l, items) + if err != nil { + return nil, err + } + if result == nil { + return nil, errors.New("GenInputResult is nil") + } + if result.Input == nil { + return nil, errors.New("agent input is nil") + } + turnCtx := ctx + if result.RunCtx != nil { + turnCtx = result.RunCtx + } + return &turnPlan[T]{ + turnCtx: turnCtx, + remaining: result.Remaining, + spec: &turnRunSpec[T]{ + runCtx: result.RunCtx, + input: result.Input, + runOpts: result.RunOpts, + consumed: result.Consumed, + }, + }, nil + } + if pr == nil { + return nil, errors.New("resume payload is nil") + } + if l.config.GenResume == nil { + return nil, errors.New("GenResume is required for resume") + } + resumeResult, err := l.config.GenResume(ctx, l, pr.canceled, pr.unhandled, pr.newItems) + if err != nil { + return nil, err + } + if resumeResult == nil { + return nil, errors.New("GenResumeResult is nil") + } + turnCtx := ctx + if resumeResult.RunCtx != nil { + turnCtx = resumeResult.RunCtx } + return &turnPlan[T]{ + turnCtx: turnCtx, + remaining: resumeResult.Remaining, + spec: &turnRunSpec[T]{ + runCtx: resumeResult.RunCtx, + runOpts: resumeResult.RunOpts, + resumeParams: resumeResult.ResumeParams, + isResume: true, + consumed: resumeResult.Consumed, + resumeBytes: pr.resumeBytes, + }, + }, nil } -func (cs *turnLoopCancelSig) getConfig() *cancelConfig { - if v := cs.config.Load(); v != nil { - return v.(*cancelConfig) - } - return nil +// TurnLoopExitState is returned when TurnLoop exits, containing the exit reason +// and any items that were not processed. +type TurnLoopExitState[T any] struct { + // ExitReason indicates why the loop exited. + // nil means clean exit (Stop() was called and completed normally). + // Non-nil values include context errors, callback errors, *CancelError, etc. + // When Stop() cancels a running agent, ExitReason will be a *CancelError. + ExitReason error + + // UnhandledItems contains items that were buffered but not processed. + // This is always valid regardless of ExitReason. + UnhandledItems []T + + // CanceledItems contains the items whose turn was canceled by Stop(). + // This is set when Stop() is called during a running turn, even if it + // did not contribute to the final CancelError. + // It can be used to reconstruct GenInput/PrepareAgent inputs when resuming. + CanceledItems []T } -func (cs *turnLoopCancelSig) getDoneChan() <-chan struct{} { - if cs != nil { - return cs.done +// TurnContext provides per-turn context to the OnAgentEvents callback. +type TurnContext[T any] struct { + // Loop is the TurnLoop instance, allowing Push() or Stop() calls. + Loop *TurnLoop[T] + + // Consumed contains items that triggered this agent execution. + Consumed []T + + // Preempted is closed when a preempt signal fires for the current turn + // (via Push with WithPreempt) and at least one preemptive Push contributed + // to the CancelError for the current turn. + // "Contributed" means the preempt's cancel options were included in the + // CancelError before it was finalized. Remains open if no preempt contributed. + // Use in a select to detect preemption while processing events. + Preempted <-chan struct{} + + // Stopped is closed when a Stop() call contributed to the CancelError for the + // current turn. + // "Contributed" means Stop's cancel options were included in the CancelError + // before it was finalized. Remains open if Stop did not contribute. + // Use in a select to detect stop while processing events. + Stopped <-chan struct{} +} + +// TurnLoop is a push-based event loop for agent execution. +// Users push items via Push() and the loop processes them through the agent. +// +// Create with NewTurnLoop, then start with Run: +// +// loop := NewTurnLoop(cfg) +// // pass loop to other components, push initial items, etc. +// loop.Run(ctx) +// +// # Permissive API +// +// All methods are valid on a not-yet-running loop: +// - Push: items are buffered and will be processed once Run is called. +// - Stop: sets the stopped flag; a subsequent Run will exit immediately. +// - Wait: blocks until Run is called AND the loop exits. If Run is never +// called, Wait blocks forever (this is a programming error, analogous +// to reading from a channel that nobody writes to). +type TurnLoop[T any] struct { + config TurnLoopConfig[T] + + buffer *internal.UnboundedChan[T] + + stopped int32 + started int32 + + done chan struct{} + + result *TurnLoopExitState[T] + + stopOnce sync.Once + + runOnce sync.Once + + stopSig *stopSignal + + preemptSig *preemptSignal + + runErr error + + canceledItems []T + + checkPointRunnerBytes []byte + + pendingResume *turnLoopPendingResume[T] + + loadCheckpointID string + + onAgentEvents func(ctx context.Context, tc *TurnContext[T], events *AsyncIterator[*AgentEvent]) error +} + +type turnLoopCheckpoint[T any] struct { + RunnerCheckpoint []byte + // HasRunnerState reports whether RunnerCheckpoint contains resumable runner state. + // It is false for "between turns" checkpoints where no agent execution was + // interrupted (e.g. Stop() before the first turn or between turns). + HasRunnerState bool + UnhandledItems []T + CanceledItems []T +} + +// ErrCheckpointStoreNil is returned when a checkpoint operation requires a Store +// but none was configured in TurnLoopConfig. +var ErrCheckpointStoreNil = errors.New("checkpoint store is nil") + +func marshalTurnLoopCheckpoint[T any](c *turnLoopCheckpoint[T]) ([]byte, error) { + buf := new(bytes.Buffer) + if err := gob.NewEncoder(buf).Encode(c); err != nil { + return nil, err } - return nil + return buf.Bytes(), nil } -type turnLoopCancelSigKey struct{} +func unmarshalTurnLoopCheckpoint[T any](data []byte) (*turnLoopCheckpoint[T], error) { + var c turnLoopCheckpoint[T] + if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&c); err != nil { + return nil, err + } + return &c, nil +} -func withTurnLoopCancelSig(ctx context.Context, cs *turnLoopCancelSig) context.Context { - return context.WithValue(ctx, turnLoopCancelSigKey{}, cs) +func (l *TurnLoop[T]) saveTurnLoopCheckpoint(ctx context.Context, checkPointID string, c *turnLoopCheckpoint[T]) error { + if l.config.Store == nil { + return ErrCheckpointStoreNil + } + data, err := marshalTurnLoopCheckpoint(c) + if err != nil { + return err + } + return l.config.Store.Set(ctx, checkPointID, data) } -func getTurnLoopCancelSig(ctx context.Context) *turnLoopCancelSig { - if v, ok := ctx.Value(turnLoopCancelSigKey{}).(*turnLoopCancelSig); ok { - return v +func (l *TurnLoop[T]) deleteTurnLoopCheckpoint(ctx context.Context, checkPointID string) error { + if l.config.Store == nil { + return nil } - return nil + if deleter, ok := l.config.Store.(CheckPointDeleter); ok { + return deleter.Delete(ctx, checkPointID) + } + return l.config.Store.Set(ctx, checkPointID, nil) } -// ErrAgentNotCancellableInTurnLoop is returned when WithCancel context is used -// but the Agent does not implement CancellableAgent. -var ErrAgentNotCancellableInTurnLoop = errors.New("agent does not support cancel but WithCancel context was provided") +func (l *TurnLoop[T]) tryLoadCheckpoint(ctx context.Context) error { + checkPointID := l.config.CheckpointID + if checkPointID == "" || l.config.Store == nil { + return nil + } + + l.loadCheckpointID = checkPointID -// NewTurnLoop creates a new TurnLoop from the given configuration. -// Source, GenInput, and GetAgent are required fields. -func NewTurnLoop[T any](config TurnLoopConfig[T]) (*TurnLoop[T], error) { - if config.Source == nil { - return nil, fmt.Errorf("TurnLoopConfig.Source is required") + data, existed, err := l.config.Store.Get(ctx, checkPointID) + if err != nil { + return fmt.Errorf("failed to load checkpoint[%s]: %w", checkPointID, err) + } + if !existed { + return nil } - if config.GenInput == nil { - return nil, fmt.Errorf("TurnLoopConfig.GenInput is required") + + var cp *turnLoopCheckpoint[T] + if len(data) == 0 { + return nil } - if config.GetAgent == nil { - return nil, fmt.Errorf("TurnLoopConfig.GetAgent is required") + cp, err = unmarshalTurnLoopCheckpoint[T](data) + if err != nil { + return fmt.Errorf("failed to unmarshal checkpoint[%s]: %w", checkPointID, err) } - onAgentEvents := config.OnAgentEvents - if onAgentEvents == nil { - onAgentEvents = func(_ context.Context, _ T, iter *AsyncIterator[*AgentEvent]) error { - for { - event, ok := iter.Next() - if !ok { - break - } - if event.Err != nil { - return event.Err - } - } - return nil + newItems := l.buffer.TakeAll() + + if cp.HasRunnerState { + if len(cp.RunnerCheckpoint) == 0 { + l.buffer.PushFront(newItems) + return fmt.Errorf("checkpoint[%s] has runner state but bytes are empty", checkPointID) + } + l.pendingResume = &turnLoopPendingResume[T]{ + canceled: append([]T{}, cp.CanceledItems...), + unhandled: append([]T{}, cp.UnhandledItems...), + newItems: append([]T{}, newItems...), + resumeBytes: append([]byte{}, cp.RunnerCheckpoint...), } + } else { + items := make([]T, 0, len(cp.UnhandledItems)+len(newItems)) + items = append(items, cp.UnhandledItems...) + items = append(items, newItems...) + l.buffer.PushFront(items) } - return &TurnLoop[T]{ - source: config.Source, - genInput: config.GenInput, - getAgent: config.GetAgent, - onAgentEvents: onAgentEvents, - receiveTimeout: config.ReceiveTimeout, - store: config.Store, - }, nil + return nil +} + +type turnLoopPendingResume[T any] struct { + canceled []T + unhandled []T + newItems []T + resumeBytes []byte +} + +type stopConfig struct { + agentCancelOpts []AgentCancelOption +} + +// StopOption is an option for Stop(). +type StopOption func(*stopConfig) + +// WithAgentCancel sets the agent cancel options to use when stopping the loop. +// These options control how the currently running agent is cancelled. +func WithAgentCancel(opts ...AgentCancelOption) StopOption { + return func(cfg *stopConfig) { + cfg.agentCancelOpts = opts + } +} + +type pushConfig[T any] struct { + preempt bool + preemptDelay time.Duration + agentCancelOpts []AgentCancelOption + pushStrategy func(context.Context, *TurnContext[T]) []PushOption[T] +} + +// PushOption is an option for Push(). +type PushOption[T any] func(*pushConfig[T]) + +// WithPreempt signals that the current agent should be canceled after pushing. +// This enables atomic "push + preempt" to avoid race conditions between +// pushing an urgent item and triggering preemption. +// The loop will cancel the current agent turn and continue with the next turn, +// where GenInput will see all buffered items including the newly pushed one. +func WithPreempt[T any](agentCancelOpts ...AgentCancelOption) PushOption[T] { + return func(cfg *pushConfig[T]) { + cfg.preempt = true + cfg.agentCancelOpts = agentCancelOpts + } } -var ErrLoopExit = errors.New("loop exit") +// WithPreemptDelay sets a delay duration before preemption takes effect. +// When used with WithPreempt, the push will succeed immediately, but the +// preemption signal will be delayed by the specified duration. +// This allows the current agent to continue processing for a grace period +// before being preempted. +func WithPreemptDelay[T any](delay time.Duration) PushOption[T] { + return func(cfg *pushConfig[T]) { + cfg.preemptDelay = delay + } +} -// WithCancel returns a new context and a cancel function that can be used to -// cancel the TurnLoop's Run method externally. Each call to WithCancel creates -// an independent cancel signal, allowing multiple concurrent Run calls with -// separate cancel controls. +// WithPushStrategy provides dynamic push option resolution based on the current turn state. +// The callback receives the current turn's context and TurnContext (nil if no turn is active) +// and returns the actual PushOptions to apply. When WithPushStrategy is used, all other +// PushOptions passed to the same Push call are ignored. // -// The returned TurnLoopCancelFunc does not require a context parameter since -// the context is already bound when WithCancel is called. +// The returned options must not contain another WithPushStrategy; any nested +// strategy is silently stripped. // -// Example: +// Example: preempt only if the current turn is processing low-priority items: // -// ctx, cancel := turnLoop.WithCancel(context.Background()) -// go func() { -// err := turnLoop.Run(ctx) -// }() -// // Later, to cancel: -// cancel(adk.WithCancelMode(adk.CancelAfterToolCall)) -func (l *TurnLoop[T]) WithCancel(ctx context.Context) (context.Context, CancelFunc) { - cs := newTurnLoopCancelSig() - ctx = withTurnLoopCancelSig(ctx, cs) +// loop.Push(urgentItem, WithPushStrategy(func(ctx context.Context, tc *TurnContext[MyItem]) []PushOption[MyItem] { +// if tc == nil { +// return nil // between turns, plain push +// } +// if isLowPriority(tc.Consumed) { +// return []PushOption[MyItem]{WithPreempt[MyItem]()} +// } +// return nil // don't preempt high-priority work +// })) +func WithPushStrategy[T any](fn func(ctx context.Context, tc *TurnContext[T]) []PushOption[T]) PushOption[T] { + return func(cfg *pushConfig[T]) { + cfg.pushStrategy = fn + } +} - var once sync.Once - cancelFn := func(opts ...CancelOption) error { - cfg := &cancelConfig{ - Mode: CancelImmediate, +func defaultTurnLoopOnAgentEvents[T any](_ context.Context, _ *TurnContext[T], events *AsyncIterator[*AgentEvent]) error { + for { + event, ok := events.Next() + if !ok { + break } - for _, opt := range opts { - opt(cfg) + if event.Err != nil { + return event.Err } - once.Do(func() { - cs.cancel(cfg) - }) - return nil } - - return ctx, cancelFn + return nil } -// Run starts the blocking loop that continuously receives messages from the -// source, runs the agent returned by GetAgent for each message, and dispatches -// resulting events to OnAgentEvent. It blocks until the source returns an error -// (including context cancellation) or a callback fails. +// NewTurnLoop creates a new TurnLoop without starting it. +// The returned loop accepts Push and Stop calls immediately; pushed items +// are buffered until Run is called. +// Call Run to start the processing goroutine. // -// If a received message has ConsumePreemptive mode and the current agent -// implements Cancellable, the agent is canceled and the new message is processed -// immediately. If the agent does not implement Cancellable, preemptive messages -// are queued and processed after the current agent finishes. -// -// To enable external cancellation, use WithCancel to create a cancellable context: +// NewTurnLoop panics if GenInput or PrepareAgent is nil. +func NewTurnLoop[T any](cfg TurnLoopConfig[T]) *TurnLoop[T] { + if cfg.GenInput == nil { + panic("adk: NewTurnLoop: GenInput is required") + } + if cfg.PrepareAgent == nil { + panic("adk: NewTurnLoop: PrepareAgent is required") + } + + l := &TurnLoop[T]{ + config: cfg, + buffer: internal.NewUnboundedChan[T](), + done: make(chan struct{}), + stopSig: newStopSignal(), + preemptSig: newPreemptSignal(), + } + if cfg.OnAgentEvents != nil { + l.onAgentEvents = cfg.OnAgentEvents + } else { + l.onAgentEvents = defaultTurnLoopOnAgentEvents[T] + } + return l +} + +func (l *TurnLoop[T]) start(ctx context.Context) { + l.runOnce.Do(func() { + atomic.StoreInt32(&l.started, 1) + go l.run(ctx) + }) +} + +// Run starts the loop's processing goroutine. It is non-blocking: the loop +// runs in the background and results are obtained via Wait. // -// ctx, cancel := turnLoop.WithCancel(context.Background()) -// go turnLoop.Run(ctx) -// // Later: cancel() +// If CheckpointID is configured in TurnLoopConfig and a matching checkpoint +// exists in Store, the loop automatically resumes from that checkpoint. +// Otherwise it starts fresh with whatever items were Push()-ed. // -// To enable checkpoint-based resumption, use WithTurnLoopResume: +// Calling Run more than once is a no-op: only the first call starts the loop. +func (l *TurnLoop[T]) Run(ctx context.Context) { + l.start(ctx) +} + +// Push adds an item to the loop's buffer for processing. +// This method is non-blocking and thread-safe. +// Returns false if the loop has stopped, true otherwise. If a preemptive push +// succeeds, the second return value is a channel that is closed when the loop +// has acknowledged the preempt signal (by either initiating cancellation of the +// current agent run or reaching a point where no cancellation is needed). +// If the loop has not been started yet (Run not called), items are buffered +// and will be processed once Run is called. // -// err := turnLoop.Run(ctx, WithTurnLoopResume("session-123")) +// Use WithPreempt() to atomically push an item and signal preemption of the current agent. +// This is useful for urgent items that should interrupt the current processing. +// The returned channel may be waited on if the caller needs to ensure the preempt +// signal has been observed. // -//nolint:cyclop,funlen // This is a core method, splitting would make the logic harder to follow -func (l *TurnLoop[T]) Run(ctx context.Context, opts ...TurnLoopRunOption[T]) error { - var runCfg turnLoopRunConfig[T] +// Use WithPreemptDelay() together with WithPreempt() to delay the preemption signal. +// Push returns immediately after the item is buffered, and a goroutine is spawned +// to signal preemption after the delay. +func (l *TurnLoop[T]) Push(item T, opts ...PushOption[T]) (bool, <-chan struct{}) { + cfg := &pushConfig[T]{} for _, opt := range opts { - opt(&runCfg) + opt(cfg) } - cs := getTurnLoopCancelSig(ctx) - toResumeFirst := false - if len(runCfg.checkPointID) > 0 { - toResumeFirst = true + if cfg.pushStrategy != nil { + return l.pushWithStrategy(item, cfg) } - for { - if cs != nil && cs.isCancelled() { - return nil - } + return l.pushWithConfig(item, cfg) +} - var nCtx context.Context - var item T - var checkPointID string - if !toResumeFirst { - var err error - var option []ConsumeOption - nCtx, item, option, err = l.source.Receive(ctx, ReceiveConfig{ - Timeout: l.receiveTimeout, - }) - if errors.Is(err, ErrLoopExit) { - return nil - } - if err != nil { - return fmt.Errorf("failed to receive message: %w", err) +// pushWithStrategy atomically holds the run loop and snapshots the current turn, +// then calls the strategy callback with a guaranteed-stable TurnContext. If the +// strategy returns preempt options, the hold is kept and a preempt is requested; +// otherwise the hold is released and the item is buffered as a plain push. +func (l *TurnLoop[T]) pushWithStrategy(item T, cfg *pushConfig[T]) (bool, <-chan struct{}) { + strategy := cfg.pushStrategy + + runCtx, tcAny := l.preemptSig.holdAndGetTurn() + if runCtx == nil { + runCtx = context.Background() + } + var tc *TurnContext[T] + if tcAny != nil { + tc = tcAny.(*TurnContext[T]) + } + realOpts := strategy(runCtx, tc) + cfg = &pushConfig[T]{} + for _, opt := range realOpts { + opt(cfg) + } + cfg.pushStrategy = nil + + if !cfg.preempt { + l.preemptSig.unholdRunLoop() + return l.buffer.TrySend(item), nil + } + + if atomic.LoadInt32(&l.stopped) != 0 { + l.preemptSig.unholdRunLoop() + return false, nil + } + + if !l.buffer.TrySend(item) { + l.preemptSig.unholdRunLoop() + return false, nil + } + + ack := make(chan struct{}) + if atomic.LoadInt32(&l.started) == 0 { + l.preemptSig.unholdRunLoop() + close(ack) + return true, ack + } + + if cfg.preemptDelay > 0 { + go func() { + select { + case <-time.After(cfg.preemptDelay): + l.preemptSig.requestPreempt(ack, cfg.agentCancelOpts...) + case <-l.done: + l.preemptSig.unholdRunLoop() + close(ack) } - o := applyConsumeOptions(option) - checkPointID = o.CheckPointID - } else { - nCtx = ctx - item = runCfg.item - checkPointID = runCfg.checkPointID + }() + } else { + l.preemptSig.requestPreempt(ack, cfg.agentCancelOpts...) + } + return true, ack +} + +func (l *TurnLoop[T]) pushWithConfig(item T, cfg *pushConfig[T]) (bool, <-chan struct{}) { + if atomic.LoadInt32(&l.stopped) != 0 { + return false, nil + } + + if cfg.preempt { + l.preemptSig.holdRunLoop() + + if !l.buffer.TrySend(item) { + l.preemptSig.unholdRunLoop() + return false, nil } - if len(checkPointID) > 0 && l.store == nil { - return fmt.Errorf("CheckPointStore is required") + ack := make(chan struct{}) + if atomic.LoadInt32(&l.started) == 0 { + l.preemptSig.unholdRunLoop() + close(ack) + return true, ack } - input, runOpts, e := l.genInput(nCtx, item) - if e != nil { - return fmt.Errorf("failed to generate agent input: %w", e) + if cfg.preemptDelay > 0 { + go func() { + select { + case <-time.After(cfg.preemptDelay): + l.preemptSig.requestPreempt(ack, cfg.agentCancelOpts...) + case <-l.done: + l.preemptSig.unholdRunLoop() + close(ack) + } + }() + } else { + l.preemptSig.requestPreempt(ack, cfg.agentCancelOpts...) } + return true, ack + } + + return l.buffer.TrySend(item), nil +} + +// Stop signals the loop to stop and returns immediately (non-blocking). +// The loop will finish the current turn (or cancel it via WithAgentCancel options), +// then exit without starting a new turn. +// Use WithAgentCancel to control how the currently running agent is cancelled. +// This method is idempotent - multiple calls update cancel options. +// Call Wait() to block until the loop has fully exited and get the result. +// +// Stop may be called before Run. In that case, the stopped flag is set and +// a subsequent Run will exit the loop immediately. +// +// If the running agent does not support the WithCancel AgentRunOption, +// Stop degrades to "exit the loop on entering the next iteration" — the +// current agent turn runs to completion before the loop exits. +func (l *TurnLoop[T]) Stop(opts ...StopOption) { + cfg := &stopConfig{} + for _, opt := range opts { + opt(cfg) + } + + l.stopSig.signal(cfg) + + l.stopOnce.Do(func() { + l.stopSig.closeDone() + atomic.StoreInt32(&l.stopped, 1) + l.buffer.Close() + }) +} + +// Wait blocks until the loop exits and returns the result. +// This method is safe to call from multiple goroutines. +// All callers will receive the same result. +// +// Wait blocks until Run is called AND the loop exits. If Run is +// ever called, Wait blocks forever. +func (l *TurnLoop[T]) Wait() *TurnLoopExitState[T] { + <-l.done + return l.result +} + +func (l *TurnLoop[T]) run(ctx context.Context) { + defer l.cleanup(ctx) + + if err := l.tryLoadCheckpoint(ctx); err != nil { + l.runErr = err + return + } - agent, e := l.getAgent(nCtx, item) - if e != nil { - return fmt.Errorf("failed to get agent: %w", e) + // Monitor context cancellation: close the buffer so that a blocking + // Receive() unblocks. The loop will then check ctx.Err() and exit. + go func() { + select { + case <-ctx.Done(): + l.buffer.Close() + case <-l.done: } + }() - var cancelFunc CancelFunc - var iter *AsyncIterator[*AgentEvent] - _, isAgentCancellable := agent.(CancellableAgent) - if cs != nil && !isAgentCancellable { - return fmt.Errorf("%w: agent %s", ErrAgentNotCancellableInTurnLoop, agent.Name(nCtx)) + for { + if l.stopSig.isStopped() { + return } - if toResumeFirst { - var err error - iter, cancelFunc, err = NewRunner(nCtx, RunnerConfig{ - EnableStreaming: input.EnableStreaming, - Agent: agent, - CheckPointStore: l.store, - }).ResumeWithCancel(nCtx, checkPointID, runOpts...) - if err != nil { - return fmt.Errorf("failed to resume agent: %w", err) + isResume := false + var pr *turnLoopPendingResume[T] + var items []T + var pushBack []T + + if l.pendingResume != nil { + isResume = true + pr = l.pendingResume + l.pendingResume = nil + + pushBack = make([]T, 0, len(pr.canceled)+len(pr.unhandled)+len(pr.newItems)) + pushBack = append(pushBack, pr.canceled...) + pushBack = append(pushBack, pr.unhandled...) + pushBack = append(pushBack, pr.newItems...) + } else { + first, ok := l.buffer.Receive() + if !ok { + if err := ctx.Err(); err != nil { + l.runErr = err + } + return } - toResumeFirst = false - } else if isAgentCancellable { - var cps CheckPointStore - if len(checkPointID) > 0 { - cps = l.store - runOpts = append(runOpts, WithCheckPointID(checkPointID)) + + if err := ctx.Err(); err != nil { + l.buffer.PushFront([]T{first}) + l.runErr = err + return } - iter, cancelFunc = NewRunner(nCtx, RunnerConfig{ - EnableStreaming: input.EnableStreaming, - Agent: agent, - CheckPointStore: cps, - }).RunWithCancel(nCtx, input.Messages, runOpts...) - } else { - var cps CheckPointStore - if len(checkPointID) > 0 { - cps = l.store - runOpts = append(runOpts, WithCheckPointID(checkPointID)) + + if l.stopSig.isStopped() { + l.buffer.PushFront([]T{first}) + return } - iter = NewRunner(nCtx, RunnerConfig{ - EnableStreaming: input.EnableStreaming, - Agent: agent, - CheckPointStore: cps, - }).Run(nCtx, input.Messages, runOpts...) + + rest := l.buffer.TakeAll() + items = append([]T{first}, rest...) + pushBack = items } - handleEvents := func() error { - return l.handleEvents(ctx, item, iter, checkPointID) + // Drain any pending preempt that arrived between turns. A Push caller + // may have called holdRunLoop + requestPreempt while the loop was + // between iterations; acknowledge and release before planning the + // next turn. Use drainAll to release all pusher holds at once — + // multiple concurrent Push(WithPreempt) callers each hold a ref. + if preempted, _, ackList := l.preemptSig.waitForPreemptOrUnhold(); preempted { + for _, ack := range ackList { + close(ack) + } + l.preemptSig.drainAll() } - if cancelFunc != nil { - var handleEventErr error - done := make(chan struct{}) + plan, err := l.planTurn(ctx, isResume, items, pr) + if err != nil { + if len(pushBack) > 0 { + l.buffer.PushFront(pushBack) + } + l.runErr = err + return + } - go func() { - defer func() { - panicErr := recover() - if panicErr != nil { - handleEventErr = safe.NewPanicErr(panicErr, debug.Stack()) - } + if l.stopSig.isStopped() { + if len(pushBack) > 0 { + l.buffer.PushFront(pushBack) + } + return + } - close(done) - }() + agent, err := l.config.PrepareAgent(plan.turnCtx, l, plan.spec.consumed) + if err != nil { + if len(pushBack) > 0 { + l.buffer.PushFront(pushBack) + } + l.runErr = err + return + } - handleEventErr = handleEvents() - }() + if l.stopSig.isStopped() { + if len(pushBack) > 0 { + l.buffer.PushFront(pushBack) + } + return + } - frontDone := make(chan struct{}) - var frontErr error - var option []ConsumeOption - go func() { - defer func() { - panicErr := recover() - if panicErr != nil { - frontErr = safe.NewPanicErr(panicErr, debug.Stack()) - } + l.buffer.PushFront(plan.remaining) - close(frontDone) - }() - _, _, option, frontErr = l.source.Front(nCtx, ReceiveConfig{ - Timeout: l.receiveTimeout, - }) - }() + // Bracket the turn with holdRunLoop / endTurnAndUnhold. The run loop's + // own hold ensures that if a Push caller also holds mid-turn, the total + // holdCount stays > 0 after endTurnAndUnhold, blocking the loop at + // waitForPreemptOrUnhold until the Push caller's preempt is resolved. + l.preemptSig.holdRunLoop() + runErr := l.runAgentAndHandleEvents(plan.turnCtx, agent, plan.spec) - select { - case <-frontDone: - case <-done: - case <-cs.getDoneChan(): - err := cancelAndWait(cancelFunc, cs, done) - if err != nil { - return err - } - return l.wrapHandleEventErr(handleEventErr) - } + l.preemptSig.endTurnAndUnhold() - if frontErr != nil { - <-done - if errors.Is(frontErr, ErrLoopExit) { - return nil - } - return fmt.Errorf("failed to front message: %w", frontErr) - } + if runErr != nil { + l.runErr = runErr + return + } + } +} - o := applyConsumeOptions(option) - switch o.Mode { - case ConsumePreemptive: - err := cancelFunc(o.CancelOpts...) - if err != nil { - <-done - return fmt.Errorf("failed to cancel agent: %w", err) - } - case ConsumePreemptiveOnTimeout: - select { - case <-done: - case <-time.After(o.Timeout): - err := cancelFunc(o.CancelOpts...) - if err != nil { - <-done - return fmt.Errorf("failed to cancel agent: %w", err) +func (l *TurnLoop[T]) setupBridgeStore(spec *turnRunSpec[T], runOpts []AgentRunOption) ([]AgentRunOption, *bridgeStore, error) { + store := l.config.Store + if store == nil && spec.isResume { + return nil, nil, fmt.Errorf("failed to resume agent: %w", ErrCheckpointStoreNil) + } + if store == nil { + return runOpts, nil, nil + } + runOpts = append(runOpts, WithCheckPointID(bridgeCheckpointID)) + if spec.isResume { + if len(spec.resumeBytes) == 0 { + return nil, nil, fmt.Errorf("resume checkpoint is empty") + } + return runOpts, newResumeBridgeStore(bridgeCheckpointID, spec.resumeBytes), nil + } + return runOpts, newBridgeStore(), nil +} + +// watchPreemptSignal runs for the lifetime of a single turn. It listens on the +// notify channel for preempt requests and relays them to agentCancelFunc. +// +// preemptGen de-duplicates notifications: multiple notify wakes can fire for the +// same logical preempt (e.g. cond.Broadcast + channel send), so the watcher +// only acts when the generation advances. +// +// On the first preempt whose cancel actually contributed (i.e. the cancel options +// were accepted before the CancelError was finalized), preemptDone is closed to +// wake runAgentAndHandleEvents's select. +func (l *TurnLoop[T]) watchPreemptSignal(done <-chan struct{}, agentCancelFunc AgentCancelFunc, preemptDone chan struct{}) { + var lastGen uint64 + for { + select { + case <-done: + return + case <-l.preemptSig.notify: + if preempted, gen, opts, ackList := l.preemptSig.receivePreempt(); preempted { + if gen != lastGen { + firstPreempt := lastGen == 0 + lastGen = gen + // CancelHandle is intentionally not awaited here: agentCancelFunc commits the cancel signal synchronously, + // while waiting would block until the turn finishes and can deadlock this watcher against the done signal. + _, contributed := agentCancelFunc(opts...) + if firstPreempt && contributed { + close(preemptDone) } - case <-cs.getDoneChan(): - err := cancelAndWait(cancelFunc, cs, done) - if err != nil { - return err + for _, ack := range ackList { + close(ack) } - return l.wrapHandleEventErr(handleEventErr) } } + } + } +} - select { - case <-done: - case <-cs.getDoneChan(): - err := cancelAndWait(cancelFunc, cs, done) - if err != nil { - return err +// watchStopSignal runs for the lifetime of a single turn. It selects on two +// channels from stopSignal: +// +// - done (permanently closed after Stop): the durable stop flag. Fires +// immediately for any watcher, even those in turns started after +// Stop() but before the run loop observed isStopped(). This eliminates +// the race where a previous turn's watcher consumed the one-shot notify, +// leaving the current turn unable to detect the stop. +// +// - notify (one-shot, buffered 1): fires when a new Stop() call is made, +// enabling cancel-mode escalation (e.g. CancelAfterToolCalls → CancelImmediate). +// The generation counter de-duplicates wakes, analogous to preemptGen in +// watchPreemptSignal. +// +// On the first cancel that actually contributed (i.e. the cancel was accepted +// before the CancelError was finalized), stoppedDone is closed to wake +// runAgentAndHandleEvents's select. +func (l *TurnLoop[T]) watchStopSignal(done <-chan struct{}, agentCancelFunc AgentCancelFunc, stoppedDone chan struct{}) { + var lastGen uint64 + stoppedClosed := false + for { + select { + case <-done: + return + case <-l.stopSig.notify: + gen, opts := l.stopSig.check() + if gen != lastGen { + lastGen = gen + // CancelHandle is intentionally not awaited here: agentCancelFunc + // commits the cancel signal synchronously, while waiting would block + // until the turn finishes and can deadlock this watcher against done. + _, contributed := agentCancelFunc(opts...) + if contributed && !stoppedClosed { + close(stoppedDone) + stoppedClosed = true } - return l.wrapHandleEventErr(handleEventErr) - } - if err := l.wrapHandleEventErr(handleEventErr); err != nil { - return err } - } else { - if handleEventErr := handleEvents(); handleEventErr != nil { - if err := l.wrapHandleEventErr(handleEventErr); err != nil { - return err - } + case <-l.stopSig.done: + _, opts := l.stopSig.check() + _, contributed := agentCancelFunc(opts...) + if contributed && !stoppedClosed { + close(stoppedDone) + stoppedClosed = true } + <-done + return } } } -func (l *TurnLoop[T]) wrapHandleEventErr(handleEventErr error) error { - if handleEventErr == nil { - return nil +func (l *TurnLoop[T]) runAgentAndHandleEvents( + ctx context.Context, + agent Agent, + spec *turnRunSpec[T], +) error { + var iter *AsyncIterator[*AgentEvent] + defer func() { + if l.stopSig.isStopped() && len(l.canceledItems) == 0 { + l.canceledItems = append([]T{}, spec.consumed...) + } + }() + + runOpts, ms, err := l.setupBridgeStore(spec, spec.runOpts) + if err != nil { + return err } - if errors.Is(handleEventErr, ErrLoopExit) { - return nil + store := l.config.Store + cancelOpt, agentCancelFunc := WithCancel() + runOpts = append(runOpts, cancelOpt) + + enableStreaming := false + if spec.input != nil { + enableStreaming = spec.input.EnableStreaming } - var interruptErr *TurnLoopInterruptError[T] - if errors.As(handleEventErr, &interruptErr) { - return interruptErr + runner := NewRunner(ctx, RunnerConfig{ + EnableStreaming: enableStreaming, + Agent: agent, + CheckPointStore: ms, + }) + + preemptDone := make(chan struct{}) + stoppedDone := make(chan struct{}) + + tc := &TurnContext[T]{ + Loop: l, + Consumed: spec.consumed, + Preempted: preemptDone, + Stopped: stoppedDone, } - return fmt.Errorf("failed to handle events: %w", handleEventErr) -} + l.preemptSig.setTurn(ctx, tc) -func (l *TurnLoop[T]) handleEvents(ctx context.Context, item T, iter *AsyncIterator[*AgentEvent], checkPointID string) error { - copies := copyEventIterator(iter, 2) - oe := l.onAgentEvents(ctx, item, copies[0]) - if oe != nil { - return oe - } - for { - e, ok := copies[1].Next() - if !ok { - break + if spec.isResume { + var err error + if spec.resumeParams != nil { + iter, err = runner.ResumeWithParams(ctx, bridgeCheckpointID, spec.resumeParams, runOpts...) + } else { + iter, err = runner.Resume(ctx, bridgeCheckpointID, runOpts...) } - if e.Action != nil && e.Action.Interrupted != nil { - return &TurnLoopInterruptError[T]{ - Item: item, - CheckpointID: checkPointID, - InterruptContexts: e.Action.Interrupted.InterruptContexts, - } + if err != nil { + return fmt.Errorf("failed to resume agent: %w", err) } + } else { + iter = runner.Run(ctx, spec.input.Messages, runOpts...) } - return nil -} -func cancelAndWait(cf CancelFunc, cs *turnLoopCancelSig, done chan struct{}) error { - cfg := cs.getConfig() - err := cf(cancelConfigToOpts(cfg)...) - if err != nil { - <-done - return fmt.Errorf("failed to cancel agent: %w", err) + handleEvents := func() error { + return l.onAgentEvents(ctx, tc, iter) } - <-done - return nil -} -func cancelConfigToOpts(cfg *cancelConfig) []CancelOption { - if cfg == nil { + done := make(chan struct{}) + var handleErr error + + go func() { + defer func() { + panicErr := recover() + if panicErr != nil { + handleErr = safe.NewPanicErr(panicErr, debug.Stack()) + } + close(done) + }() + handleErr = handleEvents() + }() + go l.watchPreemptSignal(done, agentCancelFunc, preemptDone) + go l.watchStopSignal(done, agentCancelFunc, stoppedDone) + + finalizeCheckpoint := func() error { + if store != nil && ms != nil { + data, ok, err := ms.Get(ctx, bridgeCheckpointID) + if err != nil { + return fmt.Errorf("failed to read runner checkpoint: %w", err) + } + if ok { + l.checkPointRunnerBytes = append([]byte{}, data...) + } + } return nil } - opts := []CancelOption{WithCancelMode(cfg.Mode)} - if cfg.Timeout != nil { - opts = append(opts, WithCancelTimeout(*cfg.Timeout)) + + // Wait for the turn to end. Three outcomes: + // + // done: Events fully handled (normal or error). If Stop() was + // called, save checkpoint so the caller can resume later. + // Also handle the select race: if preemptDone is closed + // too, treat as a preempt (return nil) instead of leaking + // the CancelError. + // + // preemptDone: A preemptive Push successfully cancelled the agent. + // Wait for the handleEvents goroutine to drain, then + // return nil — the run loop will start a new turn. + // + // stoppedDone: Stop() cancelled the agent. Save checkpoint so the + // caller can resume later. + select { + case <-done: + select { + case <-preemptDone: + return nil + default: + } + if l.stopSig.isStopped() { + if err := finalizeCheckpoint(); err != nil { + if handleErr != nil { + handleErr = fmt.Errorf("%w; checkpoint error: %v", handleErr, err) + } else { + handleErr = err + } + } + } + return handleErr + case <-preemptDone: + <-done + return nil + case <-stoppedDone: + <-done + if err := finalizeCheckpoint(); err != nil { + if handleErr != nil { + handleErr = fmt.Errorf("%w; checkpoint error: %v", handleErr, err) + } else { + handleErr = err + } + } + return handleErr } - return opts } -func applyConsumeOptions(opts []ConsumeOption) *consumeConfig { - var config consumeConfig - for _, opt := range opts { - opt(&config) +func (l *TurnLoop[T]) cleanup(ctx context.Context) { + atomic.StoreInt32(&l.stopped, 1) + + unhandled := l.buffer.TakeAll() + checkpointID := l.config.CheckpointID + shouldSaveCheckpoint := l.config.Store != nil && checkpointID != "" && l.stopSig.isStopped() + if shouldSaveCheckpoint { + cp := &turnLoopCheckpoint[T]{ + RunnerCheckpoint: l.checkPointRunnerBytes, + HasRunnerState: len(l.checkPointRunnerBytes) > 0, + UnhandledItems: unhandled, + CanceledItems: l.canceledItems, + } + err := l.saveTurnLoopCheckpoint(ctx, checkpointID, cp) + if err != nil { + saveErr := fmt.Errorf("failed to save turn loop checkpoint: %w", err) + if l.runErr != nil { + l.runErr = fmt.Errorf("%w; %v", l.runErr, saveErr) + } else { + l.runErr = saveErr + } + } + } else if l.loadCheckpointID != "" { + _ = l.deleteTurnLoopCheckpoint(ctx, l.loadCheckpointID) } - return &config -} -type TurnLoopInterruptError[T any] struct { - Item T - CheckpointID string - // InterruptContexts provides a structured, user-facing view of the interrupt chain. - // Each context represents a step in the agent hierarchy that was interrupted. - InterruptContexts []*InterruptCtx -} + l.result = &TurnLoopExitState[T]{ + ExitReason: l.runErr, + UnhandledItems: unhandled, + CanceledItems: l.canceledItems, + } -func (t *TurnLoopInterruptError[T]) Error() string { - return fmt.Sprintf("TurnLoopInterruptError[%s]: %v", t.CheckpointID, t.InterruptContexts) + l.preemptSig.drainAll() + l.buffer.Close() + close(l.done) } diff --git a/adk/turn_loop_test.go b/adk/turn_loop_test.go index 88e3679b3..6e3159cfc 100644 --- a/adk/turn_loop_test.go +++ b/adk/turn_loop_test.go @@ -20,1097 +20,3703 @@ import ( "context" "errors" "fmt" + "sync" "sync/atomic" "testing" "time" - "unsafe" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/cloudwego/eino/components/model" - "github.com/cloudwego/eino/components/tool" - "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) -type turnLoopMockSource struct { - items []string - idx int - err error +type turnLoopMockAgent struct { + name string + events []*AgentEvent + runFunc func(ctx context.Context, input *AgentInput) (*AgentOutput, error) + cancelFunc func(opts ...AgentCancelOption) error } -func (s *turnLoopMockSource) Receive(ctx context.Context, cfg ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - if s.idx >= len(s.items) { - return ctx, "", nil, s.err - } - item := s.items[s.idx] - s.idx++ - return ctx, item, nil, nil -} +func (a *turnLoopMockAgent) Name(_ context.Context) string { return a.name } +func (a *turnLoopMockAgent) Description(_ context.Context) string { return "mock agent" } +func (a *turnLoopMockAgent) Run(ctx context.Context, input *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, gen := NewAsyncIteratorPair[*AgentEvent]() -func (s *turnLoopMockSource) Front(ctx context.Context, cfg ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - if s.idx >= len(s.items) { - return ctx, "", nil, s.err + if a.runFunc != nil { + go func() { + defer gen.Close() + output, err := a.runFunc(ctx, input) + if err != nil { + gen.Send(&AgentEvent{Err: err}) + return + } + gen.Send(&AgentEvent{Output: output}) + }() + return iter } - return ctx, s.items[s.idx], nil, nil + + go func() { + defer gen.Close() + for _, e := range a.events { + gen.Send(e) + } + }() + return iter } -type turnLoopFuncSource[T any] struct { - receive func(ctx context.Context, cfg ReceiveConfig) (context.Context, T, []ConsumeOption, error) - front func(ctx context.Context, cfg ReceiveConfig) (context.Context, T, []ConsumeOption, error) +type turnLoopCheckpointStore struct { + m map[string][]byte + mu sync.Mutex } -func (s *turnLoopFuncSource[T]) Receive(ctx context.Context, cfg ReceiveConfig) (context.Context, T, []ConsumeOption, error) { - return s.receive(ctx, cfg) +func (s *turnLoopCheckpointStore) Set(_ context.Context, key string, value []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + s.m[key] = value + return nil } -func (s *turnLoopFuncSource[T]) Front(ctx context.Context, cfg ReceiveConfig) (context.Context, T, []ConsumeOption, error) { - if s.front != nil { - return s.front(ctx, cfg) - } - return s.receive(ctx, cfg) +func (s *turnLoopCheckpointStore) Get(_ context.Context, key string) ([]byte, bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + v, ok := s.m[key] + return v, ok, nil } -type turnLoopMockAgent struct { - name string - events []*AgentEvent +type turnLoopCancellableMockAgent struct { + name string + runFunc func(ctx context.Context, input *AgentInput) (*AgentOutput, error) + onCancel func(cc *cancelContext) + cancel context.CancelFunc + mu sync.Mutex } -func (a *turnLoopMockAgent) Name(_ context.Context) string { return a.name } -func (a *turnLoopMockAgent) Description(_ context.Context) string { return "mock agent" } -func (a *turnLoopMockAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { +func (a *turnLoopCancellableMockAgent) Name(_ context.Context) string { return a.name } +func (a *turnLoopCancellableMockAgent) Description(_ context.Context) string { return "mock agent" } + +func (a *turnLoopCancellableMockAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, gen := NewAsyncIteratorPair[*AgentEvent]() + + o := getCommonOptions(nil, opts...) + cc := o.cancelCtx + + a.mu.Lock() + var cancelCtx context.Context + cancelCtx, a.cancel = context.WithCancel(ctx) + a.mu.Unlock() + go func() { defer gen.Close() - for _, e := range a.events { - gen.Send(e) + if cc != nil { + go func() { + <-cc.cancelChan + // CRITICAL: call onCancel BEFORE cancel() to avoid race condition. + // If cancel() fires first, the runFunc returns immediately, + // flowAgent's defer calls markDone(), and doneChan closes + // before onCancel can read cc.config. + if a.onCancel != nil { + a.onCancel(cc) + } + a.mu.Lock() + if a.cancel != nil { + a.cancel() + } + a.mu.Unlock() + }() + } + + output, err := a.runFunc(cancelCtx, input) + if err != nil { + gen.Send(&AgentEvent{Err: err}) + return } + gen.Send(&AgentEvent{Output: output}) }() return iter } -type turnLoopCancellableAgent struct { - name string - startedCh chan struct{} - cancelCh chan struct{} - cancelled int32 - cancelledOpt []CancelOption -} - -func (a *turnLoopCancellableAgent) Name(_ context.Context) string { return a.name } -func (a *turnLoopCancellableAgent) Description(_ context.Context) string { return "cancellable mock" } -func (a *turnLoopCancellableAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { - iter, _ := a.RunWithCancel(context.Background(), nil) - return iter +type turnLoopStopModeProbeAgent struct { + ccCh chan *cancelContext } -func (a *turnLoopCancellableAgent) RunWithCancel(_ context.Context, _ *AgentInput, _ ...AgentRunOption) (*AsyncIterator[*AgentEvent], CancelFunc) { +func (a *turnLoopStopModeProbeAgent) Name(_ context.Context) string { return "probe" } +func (a *turnLoopStopModeProbeAgent) Description(_ context.Context) string { return "probe" } +func (a *turnLoopStopModeProbeAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { iter, gen := NewAsyncIteratorPair[*AgentEvent]() - close(a.startedCh) + o := getCommonOptions(nil, opts...) + cc := o.cancelCtx + a.ccCh <- cc go func() { defer gen.Close() - <-a.cancelCh + <-cc.cancelChan + for { + if cc.getMode() == CancelImmediate { + gen.Send(&AgentEvent{Err: cc.createCancelError()}) + return + } + time.Sleep(1 * time.Millisecond) + } }() - cancelFunc := func(opts ...CancelOption) error { - atomic.StoreInt32(&a.cancelled, 1) - a.cancelledOpt = opts - close(a.cancelCh) - return nil - } - return iter, cancelFunc + return iter } -func TestNewTurnLoop_Validation(t *testing.T) { - t.Run("missing source", func(t *testing.T) { - _, err := NewTurnLoop(TurnLoopConfig[string]{ - GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { return nil, nil, nil }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, nil }, - }) - require.Error(t, err) - assert.Contains(t, err.Error(), "Source") - }) +func newAndRunTurnLoop[T any](ctx context.Context, cfg TurnLoopConfig[T]) *TurnLoop[T] { + l := NewTurnLoop(cfg) + l.Run(ctx) + return l +} - t.Run("missing GenInput", func(t *testing.T) { - _, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{}, - GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, nil }, - }) - require.Error(t, err) - assert.Contains(t, err.Error(), "GenInput") - }) +func TestTurnLoop_RunAndPush(t *testing.T) { + processedItems := make([]string, 0) + var mu sync.Mutex - t.Run("missing GetAgent", func(t *testing.T) { - _, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{}, - GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { return nil, nil, nil }, - }) - require.Error(t, err) - assert.Contains(t, err.Error(), "GetAgent") + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + mu.Lock() + processedItems = append(processedItems, items...) + mu.Unlock() + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, }) - t.Run("valid config without OnAgentEvents", func(t *testing.T) { - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{}, - GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { return nil, nil, nil }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, nil }, - }) - require.NoError(t, err) - assert.NotNil(t, loop) - }) + loop.Push("msg1") + loop.Push("msg2") - t.Run("valid config with OnAgentEvents", func(t *testing.T) { - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{}, - GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { return nil, nil, nil }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { return nil, nil }, - OnAgentEvents: func(_ context.Context, _ string, _ *AsyncIterator[*AgentEvent]) error { return nil }, - }) - require.NoError(t, err) - assert.NotNil(t, loop) - }) + time.Sleep(100 * time.Millisecond) + + loop.Stop() + result := loop.Wait() + + mu.Lock() + defer mu.Unlock() + + assert.NoError(t, result.ExitReason) + assert.True(t, len(processedItems) > 0, "should have processed at least one item") } -func TestTurnLoop_NormalLoop(t *testing.T) { - agent := &turnLoopMockAgent{ - name: "test-agent", - events: []*AgentEvent{ - {Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("hello", nil)}}}, +func TestTurnLoop_PushReturnsErrorAfterStop(t *testing.T) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items, + }, nil }, - } + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) - var receivedItems []string - var eventCount int + loop.Stop() - source := &turnLoopMockSource{ - items: []string{"msg1", "msg2", "msg3"}, - err: context.DeadlineExceeded, - } + ok, _ := loop.Push("msg1") + assert.False(t, ok) +} - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - receivedItems = append(receivedItems, item) - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil +func TestTurnLoop_StopIsIdempotent(t *testing.T) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { - return agent, nil - }, - OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { - for { - _, ok := iter.Next() - if !ok { - break - } - eventCount++ - } - return nil + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil }, }) - require.NoError(t, err) - err = loop.Run(context.Background()) - assert.ErrorIs(t, err, context.DeadlineExceeded) - assert.Equal(t, []string{"msg1", "msg2", "msg3"}, receivedItems) - assert.Equal(t, 3, eventCount) + loop.Stop() + loop.Stop() + loop.Stop() + + result := loop.Wait() + assert.NoError(t, result.ExitReason) } -func TestTurnLoop_SourceError(t *testing.T) { - sourceErr := errors.New("source failure") - source := &turnLoopMockSource{ - items: nil, - err: sourceErr, +func TestTurnLoop_WaitMultipleGoroutines(t *testing.T) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Stop() + + var wg sync.WaitGroup + results := make([]*TurnLoopExitState[string], 3) + + for i := 0; i < 3; i++ { + i := i + wg.Add(1) + go func() { + defer wg.Done() + results[i] = loop.Wait() + }() } - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: source, - GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{}, nil, nil + wg.Wait() + + assert.Equal(t, results[0], results[1]) + assert.Equal(t, results[1], results[2]) +} + +func TestTurnLoop_UnhandledItemsOnStop(t *testing.T) { + started := make(chan struct{}) + blocked := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + close(started) + <-blocked + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items[:1], + Remaining: items[1:], + }, nil }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { - return &turnLoopMockAgent{name: "a"}, nil + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil }, - OnAgentEvents: func(_ context.Context, _ string, _ *AsyncIterator[*AgentEvent]) error { return nil }, }) - require.NoError(t, err) - err = loop.Run(context.Background()) - assert.ErrorIs(t, err, sourceErr) + loop.Push("msg1") + loop.Push("msg2") + loop.Push("msg3") + + <-started + + loop.Stop() + close(blocked) + + result := loop.Wait() + assert.True(t, len(result.UnhandledItems) >= 0, "should return unhandled items") } func TestTurnLoop_GenInputError(t *testing.T) { - genErr := errors.New("gen input failure") + genErr := errors.New("gen input error") - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{items: []string{"msg1"}, err: errors.New("should not reach")}, - GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { - return nil, nil, genErr + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return nil, genErr }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { - return &turnLoopMockAgent{name: "a"}, nil + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil }, - OnAgentEvents: func(_ context.Context, _ string, _ *AsyncIterator[*AgentEvent]) error { return nil }, }) - require.NoError(t, err) - err = loop.Run(context.Background()) - assert.ErrorIs(t, err, genErr) - assert.Contains(t, err.Error(), "failed to generate agent input") + loop.Push("msg1") + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, genErr) } func TestTurnLoop_GetAgentError(t *testing.T) { - agentErr := errors.New("get agent failure") + agentErr := errors.New("get agent error") - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{items: []string{"msg1"}, err: errors.New("should not reach")}, - GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{}, nil, nil + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { return nil, agentErr }, - OnAgentEvents: func(_ context.Context, _ string, _ *AsyncIterator[*AgentEvent]) error { return nil }, }) - require.NoError(t, err) - err = loop.Run(context.Background()) - assert.ErrorIs(t, err, agentErr) - assert.Contains(t, err.Error(), "failed to get agent") + loop.Push("msg1") + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, agentErr) } -func TestTurnLoop_OnAgentEventsError(t *testing.T) { - eventErr := errors.New("event handler failure") - agent := &turnLoopMockAgent{ - name: "test-agent", - events: []*AgentEvent{{Output: &AgentOutput{}}}, - } +func TestTurnLoop_BatchProcessing(t *testing.T) { + var batches [][]string + var mu sync.Mutex - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{items: []string{"msg1"}, err: errors.New("should not reach")}, - GenInput: func(_ context.Context, _ string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{}, nil, nil - }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { - return agent, nil + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + mu.Lock() + batches = append(batches, items) + mu.Unlock() + + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items[:1], + Remaining: items[1:], + }, nil }, - OnAgentEvents: func(_ context.Context, _ string, _ *AsyncIterator[*AgentEvent]) error { - return eventErr + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil }, }) - require.NoError(t, err) - err = loop.Run(context.Background()) - assert.ErrorIs(t, err, eventErr) - assert.Contains(t, err.Error(), "failed to handle events") -} + loop.Push("msg1") + loop.Push("msg2") + loop.Push("msg3") -func TestTurnLoop_ContextCancellation(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + time.Sleep(200 * time.Millisecond) - callCount := 0 - source := &turnLoopFuncSource[string]{receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - callCount++ - if callCount > 1 { - cancel() - return ctx, "", nil, ctx.Err() - } - return ctx, "msg1", nil, nil - }} + loop.Stop() + loop.Wait() - agent := &turnLoopMockAgent{ - name: "test-agent", - events: []*AgentEvent{{Output: &AgentOutput{}}}, - } + mu.Lock() + defer mu.Unlock() - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil - }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { - return agent, nil + assert.True(t, len(batches) > 0, "should have processed at least one batch") +} + +func TestTurnLoop_StopWithMode(t *testing.T) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil }, - OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { - for { - if _, ok := iter.Next(); !ok { - break - } - } - return nil + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil }, }) - require.NoError(t, err) - err = loop.Run(ctx) - assert.ErrorIs(t, err, context.Canceled) + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelAfterToolCalls))) + + result := loop.Wait() + assert.NoError(t, result.ExitReason) } -func TestTurnLoop_MultipleEventsPerTurn(t *testing.T) { - agent := &turnLoopMockAgent{ - name: "multi-event-agent", - events: []*AgentEvent{ - {Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("event1", nil)}}}, - {Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("event2", nil)}}}, - {Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("event3", nil)}}}, +func TestTurnLoop_Preempt_CancelsCurrentAgent(t *testing.T) { + agentStarted := make(chan struct{}) + agentCancelled := make(chan struct{}) + agentStartedOnce := sync.Once{} + agentCancelledOnce := sync.Once{} + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + <-ctx.Done() + agentCancelledOnce.Do(func() { + close(agentCancelled) + }) + return &AgentOutput{}, nil }, } - var eventCount int + genInputCalls := int32(0) + secondGenInputCalled := make(chan struct{}) + secondGenInputOnce := sync.Once{} - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil - }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { return agent, nil }, - OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { - for { - _, ok := iter.Next() - if !ok { - break - } - eventCount++ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + count := atomic.AddInt32(&genInputCalls, 1) + if count >= 2 { + secondGenInputOnce.Do(func() { + close(secondGenInputCalled) + }) } - return nil + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil }, }) - require.NoError(t, err) - err = loop.Run(context.Background()) - assert.ErrorIs(t, err, context.DeadlineExceeded) - assert.Equal(t, 3, eventCount) -} + loop.Push("first") -func TestTurnLoop_DefaultOnAgentEvents(t *testing.T) { - agent := &turnLoopMockAgent{ - name: "test-agent", - events: []*AgentEvent{{Output: &AgentOutput{}}}, + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") } - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil - }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { - return agent, nil - }, - }) - require.NoError(t, err) - - err = loop.Run(context.Background()) - assert.ErrorIs(t, err, context.DeadlineExceeded) -} + loop.Push("urgent", WithPreempt[string]()) -func TestTurnLoop_DefaultOnAgentEventsWithError(t *testing.T) { - agentErr := errors.New("agent internal error") - agent := &turnLoopMockAgent{ - name: "error-agent", - events: []*AgentEvent{{Err: agentErr}}, + select { + case <-agentCancelled: + case <-time.After(1 * time.Second): + t.Fatal("agent was not cancelled by preempt") } - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil - }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { - return agent, nil - }, - }) - require.NoError(t, err) + select { + case <-secondGenInputCalled: + case <-time.After(1 * time.Second): + t.Fatal("second GenInput was not called after preempt") + } - err = loop.Run(context.Background()) - assert.ErrorIs(t, err, agentErr) + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.GreaterOrEqual(t, atomic.LoadInt32(&genInputCalls), int32(2)) } -func TestTurnLoop_AgentErrorEvent(t *testing.T) { - agentErr := errors.New("agent internal error") - agent := &turnLoopMockAgent{ - name: "error-agent", - events: []*AgentEvent{{Err: agentErr}}, +func TestTurnLoop_Preempt_DiscardsConsumedItems(t *testing.T) { + agentStarted := make(chan struct{}) + agentDone := make(chan struct{}) + agentStartedOnce := sync.Once{} + agentDoneOnce := sync.Once{} + firstAgentRun := true + var firstRunMu sync.Mutex + + genInputResults := make([][]string, 0) + var mu sync.Mutex + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + firstRunMu.Lock() + isFirst := firstAgentRun + firstAgentRun = false + firstRunMu.Unlock() + + if isFirst { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + <-ctx.Done() + } else { + agentDoneOnce.Do(func() { + close(agentDone) + }) + } + return &AgentOutput{}, nil + }, } - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil - }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { return agent, nil }, - OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { - for { - event, ok := iter.Next() - if !ok { - break - } - if event.Err != nil { - return fmt.Errorf("agent run failed: %w", event.Err) - } - } - return nil + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + mu.Lock() + genInputResults = append(genInputResults, items) + mu.Unlock() + + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil }, }) - require.NoError(t, err) - err = loop.Run(context.Background()) - assert.ErrorIs(t, err, agentErr) - assert.Contains(t, err.Error(), "agent run failed") -} + loop.Push("first") + + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } + + loop.Push("urgent", WithPreempt[string]()) -func TestTurnLoop_PreemptiveCancellation(t *testing.T) { - slowAgent := &turnLoopCancellableAgent{ - name: "slow-agent", - startedCh: make(chan struct{}), - cancelCh: make(chan struct{}), + select { + case <-agentDone: + case <-time.After(1 * time.Second): + t.Fatal("second agent run did not complete") } - fastAgent := &turnLoopMockAgent{ - name: "fast-agent", - events: []*AgentEvent{{Output: &AgentOutput{}}}, + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + + mu.Lock() + defer mu.Unlock() + assert.GreaterOrEqual(t, len(genInputResults), 2) + if len(genInputResults) >= 2 { + assert.NotContains(t, genInputResults[1], "first") + assert.Contains(t, genInputResults[1], "urgent") } +} - var processedItems []string - receiveCount := 0 - msgs := []struct { - item string - opts []ConsumeOption - err error - }{ - {"slow-msg", nil, nil}, - {"preempt-msg", []ConsumeOption{WithPreemptive(), WithCancelOptions(WithCancelMode(CancelImmediate))}, nil}, - {"", nil, context.DeadlineExceeded}, - } - frontIdx := 0 - source := &turnLoopFuncSource[string]{ - receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - if receiveCount >= len(msgs) { - return ctx, "", nil, context.DeadlineExceeded - } - m := msgs[receiveCount] - receiveCount++ - frontIdx = receiveCount - return ctx, m.item, m.opts, m.err - }, - front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - <-slowAgent.startedCh - if frontIdx >= len(msgs) { - return ctx, "", nil, context.DeadlineExceeded - } - m := msgs[frontIdx] - return ctx, m.item, m.opts, m.err +func TestTurnLoop_Preempt_WithAgentCancelMode(t *testing.T) { + agentStarted := make(chan struct{}) + cancelFuncCalled := make(chan struct{}) + agentStartedOnce := sync.Once{} + cancelFuncCalledOnce := sync.Once{} + firstCancelModeUsed := CancelImmediate + var cancelModeMu sync.Mutex + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + <-ctx.Done() + return &AgentOutput{}, nil + }, + onCancel: func(cc *cancelContext) { + cancelModeMu.Lock() + cancelFuncCalledOnce.Do(func() { + firstCancelModeUsed = cc.getMode() + close(cancelFuncCalled) + }) + cancelModeMu.Unlock() }, } - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - processedItems = append(processedItems, item) - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil - }, - GetAgent: func(_ context.Context, item string) (Agent, error) { - if item == "slow-msg" { - return slowAgent, nil - } - return fastAgent, nil + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil }, - OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { - for { - if _, ok := iter.Next(); !ok { - break - } - } - return nil + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil }, }) - require.NoError(t, err) - err = loop.Run(context.Background()) - assert.ErrorIs(t, err, context.DeadlineExceeded) - assert.True(t, atomic.LoadInt32(&slowAgent.cancelled) == 1, "slow agent should have been cancelled") - assert.Equal(t, []string{"slow-msg", "preempt-msg"}, processedItems) -} + loop.Push("first") -func TestTurnLoop_PreemptiveNonCancellableAgent(t *testing.T) { - nonCancellableAgent := &turnLoopMockAgent{ - name: "non-cancellable-agent", - events: []*AgentEvent{{Output: &AgentOutput{}}}, - } - fastAgent := &turnLoopMockAgent{ - name: "fast-agent", - events: []*AgentEvent{{Output: &AgentOutput{}}}, + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") } - var processedItems []string - callCount := 0 - source := &turnLoopFuncSource[string]{receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - callCount++ - switch callCount { - case 1: - return ctx, "non-cancel-msg", nil, nil - case 2: - return ctx, "preempt-msg", []ConsumeOption{WithPreemptive()}, nil - default: - return ctx, "", nil, context.DeadlineExceeded - } - }} + loop.Push("urgent", WithPreempt[string](WithAgentCancelMode(CancelAfterToolCalls))) - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - processedItems = append(processedItems, item) - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil - }, - GetAgent: func(_ context.Context, item string) (Agent, error) { - if item == "non-cancel-msg" { - return nonCancellableAgent, nil - } - return fastAgent, nil - }, - OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { - for { - if _, ok := iter.Next(); !ok { - break - } - } - return nil - }, - }) - require.NoError(t, err) + select { + case <-cancelFuncCalled: + case <-time.After(1 * time.Second): + t.Fatal("cancelFunc was not called by preempt") + } - err = loop.Run(context.Background()) - assert.ErrorIs(t, err, context.DeadlineExceeded) - assert.Equal(t, []string{"non-cancel-msg", "preempt-msg"}, processedItems) + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + cancelModeMu.Lock() + actualMode := firstCancelModeUsed + cancelModeMu.Unlock() + assert.Equal(t, CancelAfterToolCalls, actualMode) } -func TestTurnLoop_WithCancel_Basic(t *testing.T) { - agent := &turnLoopCancellableAgent{ - name: "test-agent", - startedCh: make(chan struct{}), - cancelCh: make(chan struct{}), - } +func TestTurnLoop_PreemptAck_ClosesAfterCancelIsInitiated(t *testing.T) { + agentStarted := make(chan struct{}) + cancelObserved := make(chan struct{}) + agentFinishGate := make(chan struct{}) + agentStartedOnce := sync.Once{} + cancelObservedOnce := sync.Once{} - receiveCount := int32(0) - frontBlocked := make(chan struct{}) - source := &turnLoopFuncSource[string]{ - receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - cnt := atomic.AddInt32(&receiveCount, 1) - if cnt == 1 { - return ctx, "msg1", nil, nil - } - return ctx, "", nil, context.DeadlineExceeded + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { close(agentStarted) }) + <-ctx.Done() + <-agentFinishGate + return &AgentOutput{}, nil }, - front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - <-frontBlocked - return ctx, "", nil, context.DeadlineExceeded + onCancel: func(cc *cancelContext) { + cancelObservedOnce.Do(func() { close(cancelObserved) }) }, } - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil - }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { return agent, nil }, - OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { - for { - if _, ok := iter.Next(); !ok { - break - } - } - return nil + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil }, }) - require.NoError(t, err) - ctx, cancel := loop.WithCancel(context.Background()) - done := make(chan error) - go func() { - done <- loop.Run(ctx) - }() + _, _ = loop.Push("first") - <-agent.startedCh - e := cancel() - assert.NoError(t, e) + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } - err = <-done - assert.NoError(t, err) -} + ok, ack := loop.Push("urgent", WithPreempt[string](WithAgentCancelMode(CancelAfterToolCalls))) + assert.True(t, ok) + assert.NotNil(t, ack) -func TestTurnLoop_WithCancel_DuringAgentRun(t *testing.T) { - slowAgent := &turnLoopCancellableAgent{ - name: "slow-agent", - startedCh: make(chan struct{}), - cancelCh: make(chan struct{}), + select { + case <-ack: + case <-time.After(1 * time.Second): + t.Fatal("preempt ack was not closed") } - receiveCount := int32(0) - frontBlocked := make(chan struct{}) - source := &turnLoopFuncSource[string]{ - receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - cnt := atomic.AddInt32(&receiveCount, 1) - if cnt == 1 { - return ctx, "msg1", nil, nil - } - return ctx, "", nil, context.DeadlineExceeded - }, - front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - <-frontBlocked - return ctx, "", nil, context.DeadlineExceeded - }, + select { + case <-cancelObserved: + case <-time.After(1 * time.Second): + t.Fatal("cancel was not initiated") } - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil - }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { - return slowAgent, nil + close(agentFinishGate) + + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) +} + +func TestTurnLoop_PreemptAck_ClosesImmediatelyIfLoopNotStarted(t *testing.T) { + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil }, - OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { - for { - if _, ok := iter.Next(); !ok { - break - } - } - return nil + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil }, }) - require.NoError(t, err) - - ctx, cancel := loop.WithCancel(context.Background()) - done := make(chan error) - go func() { - done <- loop.Run(ctx) - }() - <-slowAgent.startedCh - cancel(WithCancelMode(CancelImmediate)) + ok, ack := loop.Push("urgent", WithPreempt[string]()) + assert.True(t, ok) + assert.NotNil(t, ack) - err = <-done - assert.NoError(t, err) - assert.True(t, atomic.LoadInt32(&slowAgent.cancelled) == 1, "agent should have been cancelled") + select { + case <-ack: + case <-time.After(1 * time.Second): + t.Fatal("preempt ack was not closed") + } } -func TestTurnLoop_WithCancel_NonCancellableAgent_ReturnsError(t *testing.T) { - agent := &turnLoopMockAgent{ - name: "non-cancellable-agent", - events: []*AgentEvent{{Output: &AgentOutput{}}}, - } +func TestTurnLoop_Preempt_EscalatesOnSecondPreempt(t *testing.T) { + agentStarted := make(chan struct{}) + firstCancelSeen := make(chan struct{}) + agentFinishGate := make(chan struct{}) + agentStartedOnce := sync.Once{} + firstCancelOnce := sync.Once{} - source := &turnLoopFuncSource[string]{ - receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - return ctx, "msg1", nil, nil + var ccPtr atomic.Value + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { close(agentStarted) }) + <-ctx.Done() + <-agentFinishGate + return &AgentOutput{}, nil + }, + onCancel: func(cc *cancelContext) { + ccPtr.Store(cc) + firstCancelOnce.Do(func() { close(firstCancelSeen) }) }, } - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil - }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { return agent, nil }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil + }, }) - require.NoError(t, err) - ctx, _ := loop.WithCancel(context.Background()) - err = loop.Run(ctx) + loop.Push("first") + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } - assert.ErrorIs(t, err, ErrAgentNotCancellableInTurnLoop) - assert.Contains(t, err.Error(), "non-cancellable-agent") -} + loop.Push("urgent1", WithPreempt[string](WithAgentCancelMode(CancelAfterChatModel))) + select { + case <-firstCancelSeen: + case <-time.After(1 * time.Second): + t.Fatal("first preempt did not trigger cancel") + } -func TestTurnLoop_WithCancel_MultipleCalls(t *testing.T) { - agent := &turnLoopMockAgent{ - name: "test-agent", - events: []*AgentEvent{{Output: &AgentOutput{}}}, + loop.Push("urgent2", WithPreempt[string](WithAgentCancelMode(CancelImmediate))) + + deadline := time.Now().Add(1 * time.Second) + for time.Now().Before(deadline) { + v := ccPtr.Load() + if v == nil { + time.Sleep(5 * time.Millisecond) + continue + } + cc := v.(*cancelContext) + if cc.getMode() == CancelImmediate && atomic.LoadInt32(&cc.escalated) == 1 { + break + } + time.Sleep(5 * time.Millisecond) } - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{items: []string{"msg1"}, err: context.DeadlineExceeded}, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil - }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { - return agent, nil - }, - }) - require.NoError(t, err) + v := ccPtr.Load() + if v == nil { + t.Fatal("cancel context was not captured") + } + cc := v.(*cancelContext) + assert.Equal(t, CancelImmediate, cc.getMode()) + assert.Equal(t, int32(1), atomic.LoadInt32(&cc.escalated)) - _, cancel := loop.WithCancel(context.Background()) + close(agentFinishGate) - err = cancel() - assert.NoError(t, err) + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) } -func TestTurnLoop_WithCancel_IndependentCancels(t *testing.T) { - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: &turnLoopMockSource{items: []string{}, err: context.DeadlineExceeded}, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil +func TestTurnLoop_Preempt_JoinsSafePointModesOnSecondPreempt(t *testing.T) { + agentStarted := make(chan struct{}) + firstCancelSeen := make(chan struct{}) + agentFinishGate := make(chan struct{}) + agentStartedOnce := sync.Once{} + firstCancelOnce := sync.Once{} + + var ccPtr atomic.Value + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { close(agentStarted) }) + <-ctx.Done() + <-agentFinishGate + return &AgentOutput{}, nil + }, + onCancel: func(cc *cancelContext) { + ccPtr.Store(cc) + firstCancelOnce.Do(func() { close(firstCancelSeen) }) + }, + } + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { - return &turnLoopMockAgent{name: "a"}, nil + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil }, }) - require.NoError(t, err) - ctx1, cancel1 := loop.WithCancel(context.Background()) - ctx2, cancel2 := loop.WithCancel(context.Background()) + loop.Push("first") + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } + + loop.Push("urgent1", WithPreempt[string](WithAgentCancelMode(CancelAfterChatModel))) + select { + case <-firstCancelSeen: + case <-time.After(1 * time.Second): + t.Fatal("first preempt did not trigger cancel") + } + + loop.Push("urgent2", WithPreempt[string](WithAgentCancelMode(CancelAfterToolCalls))) - cs1 := getTurnLoopCancelSig(ctx1) - cs2 := getTurnLoopCancelSig(ctx2) + want := CancelAfterChatModel | CancelAfterToolCalls + deadline := time.Now().Add(1 * time.Second) + for time.Now().Before(deadline) { + v := ccPtr.Load() + if v == nil { + time.Sleep(5 * time.Millisecond) + continue + } + cc := v.(*cancelContext) + if cc.getMode() == want { + break + } + time.Sleep(5 * time.Millisecond) + } - assert.NotNil(t, cs1) - assert.NotNil(t, cs2) - assert.NotEqual(t, cs1, cs2) + v := ccPtr.Load() + if v == nil { + t.Fatal("cancel context was not captured") + } + cc := v.(*cancelContext) + assert.Equal(t, want, cc.getMode()) - cancel1() - assert.True(t, cs1.isCancelled()) - assert.False(t, cs2.isCancelled()) + close(agentFinishGate) - cancel2() - assert.True(t, cs2.isCancelled()) + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) } -func TestTurnLoop_WithCancel_WithCancelOptions(t *testing.T) { - slowAgent := &turnLoopCancellableAgent{ - name: "slow-agent", - startedCh: make(chan struct{}), - cancelCh: make(chan struct{}), - } +func TestTurnLoop_Push_WithoutPreempt_DoesNotCancel(t *testing.T) { + agentRunCount := 0 + agentDone := make(chan struct{}) - receiveCount := int32(0) - frontBlocked := make(chan struct{}) - source := &turnLoopFuncSource[string]{ - receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - cnt := atomic.AddInt32(&receiveCount, 1) - if cnt == 1 { - return ctx, "msg1", nil, nil + agent := &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentRunCount++ + if agentRunCount == 1 { + time.Sleep(100 * time.Millisecond) } - return ctx, "", nil, context.DeadlineExceeded - }, - front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - <-frontBlocked - return ctx, "", nil, context.DeadlineExceeded + if agentRunCount == 2 { + close(agentDone) + } + return &AgentOutput{}, nil }, } - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { - return slowAgent, nil + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil }, - OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { - for { - if _, ok := iter.Next(); !ok { + }) + + loop.Push("first") + time.Sleep(20 * time.Millisecond) + loop.Push("second") + + select { + case <-agentDone: + case <-time.After(1 * time.Second): + t.Fatal("second agent run did not complete") + } + + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.Equal(t, 2, agentRunCount) +} + +func TestTurnLoop_PreemptDelay_NoMispreemptOnNaturalCompletion(t *testing.T) { + agent1Started := make(chan struct{}) + agent1Done := make(chan struct{}) + agent2Started := make(chan struct{}) + agent2Done := make(chan struct{}) + agent1StartedOnce := sync.Once{} + agent1DoneOnce := sync.Once{} + agent2StartedOnce := sync.Once{} + agent2DoneOnce := sync.Once{} + + var agentRunCount int32 + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + count := atomic.AddInt32(&agentRunCount, 1) + if count == 1 { + agent1StartedOnce.Do(func() { close(agent1Started) }) + time.Sleep(50 * time.Millisecond) + agent1DoneOnce.Do(func() { close(agent1Done) }) + } else if count == 2 { + agent2StartedOnce.Do(func() { close(agent2Started) }) + time.Sleep(100 * time.Millisecond) + select { + case <-ctx.Done(): + t.Error("Agent2 was unexpectedly cancelled") + return nil, ctx.Err() + default: + } + agent2DoneOnce.Do(func() { close(agent2Done) }) + } + return &AgentOutput{}, nil + }, + } + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil + }, + }) + + loop.Push("first") + + select { + case <-agent1Started: + case <-time.After(1 * time.Second): + t.Fatal("agent1 did not start") + } + + loop.Push("second", WithPreempt[string](), WithPreemptDelay[string](500*time.Millisecond)) + + select { + case <-agent1Done: + case <-time.After(1 * time.Second): + t.Fatal("agent1 did not complete naturally") + } + + select { + case <-agent2Started: + case <-time.After(1 * time.Second): + t.Fatal("agent2 did not start") + } + + select { + case <-agent2Done: + case <-time.After(1 * time.Second): + t.Fatal("agent2 did not complete - may have been incorrectly preempted") + } + + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.Equal(t, int32(2), atomic.LoadInt32(&agentRunCount)) +} + +func TestTurnLoop_ConcurrentPush(t *testing.T) { + var count int32 + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + atomic.AddInt32(&count, int32(len(items))) + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + for j := 0; j < 10; j++ { + _, _ = loop.Push(fmt.Sprintf("msg-%d-%d", i, j)) + } + }(i) + } + + wg.Wait() + time.Sleep(200 * time.Millisecond) + + loop.Stop() + result := loop.Wait() + + processed := atomic.LoadInt32(&count) + unhandled := len(result.UnhandledItems) + + assert.True(t, processed > 0, "should have processed some items") + assert.True(t, int(processed)+unhandled <= 100, "total should not exceed pushed amount") +} + +func TestTurnLoop_StopAfterReceive_RecoverItem(t *testing.T) { + receiveStarted := make(chan struct{}) + cancelDone := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + close(receiveStarted) + <-cancelDone + time.Sleep(50 * time.Millisecond) + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + <-receiveStarted + + loop.Stop() + close(cancelDone) + + result := loop.Wait() + assert.NoError(t, result.ExitReason) +} + +func TestTurnLoop_StopAfterGenInput_RecoverConsumed(t *testing.T) { + genInputDone := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + close(genInputDone) + time.Sleep(50 * time.Millisecond) + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items[:1], + Remaining: items[1:], + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + time.Sleep(100 * time.Millisecond) + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + loop.Push("msg2") + + <-genInputDone + + time.Sleep(60 * time.Millisecond) + loop.Stop() + + result := loop.Wait() + assert.NoError(t, result.ExitReason) +} + +func TestTurnLoop_GetAgentError_RecoverConsumed(t *testing.T) { + agentErr := errors.New("get agent error") + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items[:1], + Remaining: items[1:], + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + return nil, agentErr + }, + }) + + loop.Push("msg1") + loop.Push("msg2") + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, agentErr) + assert.True(t, len(result.UnhandledItems) >= 1, "should recover at least the consumed item and remaining") +} + +func TestTurnLoop_GenInputError_RecoverItems(t *testing.T) { + genErr := errors.New("gen input error") + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return nil, genErr + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + loop.Push("msg2") + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, genErr) + assert.Len(t, result.UnhandledItems, 2, "should recover all items when GenInput fails") + assert.Contains(t, result.UnhandledItems, "msg1") + assert.Contains(t, result.UnhandledItems, "msg2") +} + +func TestTurnLoop_PrepareAgentError_RecoverItemsInOrder(t *testing.T) { + agentErr := errors.New("prepare agent error") + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + var urgent string + remaining := make([]string, 0, len(items)) + for _, item := range items { + if item == "urgent" { + urgent = item + } else { + remaining = append(remaining, item) + } + } + if urgent != "" { + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: []string{urgent}, + Remaining: remaining, + }, nil + } + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items[:1], + Remaining: items[1:], + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return nil, agentErr + }, + }) + + loop.Push("msg1") + loop.Push("urgent") + loop.Push("msg2") + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, agentErr) + assert.Len(t, result.UnhandledItems, 3, "should recover all items") + assert.Equal(t, []string{"msg1", "urgent", "msg2"}, result.UnhandledItems, + "should preserve original push order even when GenInput selects non-prefix items") +} + +// Context cancel tests: the TurnLoop monitors context cancellation by closing +// the internal buffer when ctx.Done() fires, which unblocks the blocking +// Receive() call. The loop then checks ctx.Err() and exits with the context error. + +func TestTurnLoop_ContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + genInputStarted := make(chan struct{}) + genInputDone := make(chan struct{}) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + close(genInputStarted) + <-genInputDone + if err := ctx.Err(); err != nil { + return nil, err + } + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + + <-genInputStarted + cancel() + close(genInputDone) + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, context.Canceled) +} + +func TestTurnLoop_ContextDeadlineExceeded(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + select { + case <-time.After(100 * time.Millisecond): + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + case <-ctx.Done(): + return nil, ctx.Err() + } + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, context.DeadlineExceeded) +} + +func TestTurnLoop_ContextCancelBeforeReceive(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + // Push before Run to guarantee the item is buffered before the + // context-monitoring goroutine can close the buffer. + _, _ = loop.Push("msg1") + loop.Run(ctx) + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, context.Canceled) + assert.Len(t, result.UnhandledItems, 1) +} + +func TestTurnLoop_ContextCancelDuringBlockingReceive(t *testing.T) { + // When context is cancelled while Receive() is blocking (no items in buffer), + // the context monitoring goroutine closes the buffer, which unblocks Receive(). + ctx, cancel := context.WithCancel(context.Background()) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + // Don't push any items — let Receive() block + time.Sleep(50 * time.Millisecond) + cancel() + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, context.Canceled) +} + +func TestTurnLoop_ContextCancelAfterGenInput_RecoverItems(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + genInputCount := 0 + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + genInputCount++ + if genInputCount == 1 { + cancel() + } + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items[:1], + Remaining: items[1:], + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + loop.Push("msg2") + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, context.Canceled) + assert.True(t, len(result.UnhandledItems) >= 1, "should recover consumed and remaining items") +} + +func TestTurnLoop_OnAgentEventsReceivesEvents(t *testing.T) { + var receivedEvents []*AgentEvent + var receivedConsumed []string + var mu sync.Mutex + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + mu.Lock() + receivedConsumed = append(receivedConsumed, tc.Consumed...) + mu.Unlock() + + for { + event, ok := events.Next() + if !ok { + break + } + mu.Lock() + receivedEvents = append(receivedEvents, event) + mu.Unlock() + } + return nil + }, + }) + + loop.Push("msg1") + + time.Sleep(100 * time.Millisecond) + + loop.Stop() + result := loop.Wait() + + assert.NoError(t, result.ExitReason) + + mu.Lock() + defer mu.Unlock() + assert.True(t, len(receivedConsumed) > 0, "should have received consumed items") +} + +func TestTurnLoop_StopDuringAgentExecution(t *testing.T) { + agentStarted := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + time.Sleep(200 * time.Millisecond) + for { + _, ok := events.Next() + if !ok { + break + } + } + return nil + }, + }) + + loop.Push("msg1") + + <-agentStarted + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.Equal(t, []string{"msg1"}, result.CanceledItems) +} + +func TestTurnLoop_StopCheckPointIDInCancelError(t *testing.T) { + ctx := context.Background() + modelStarted := make(chan struct{}, 1) + checkpointID := "turn-loop-cancel-ckpt-1" + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: checkpointID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + + loop.Push("msg1") + + <-modelStarted + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + + result := loop.Wait() + + var cancelErr *CancelError + assert.True(t, errors.As(result.ExitReason, &cancelErr), "ExitReason should be a *CancelError") + + store.mu.Lock() + defer store.mu.Unlock() + _, ok := store.m[checkpointID] + assert.True(t, ok, "checkpoint should be saved under the configured CheckpointID") +} + +func TestTurnLoop_StopWithoutCheckpointIDDoesNotPersist(t *testing.T) { + ctx := context.Background() + modelStarted := make(chan struct{}, 1) + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + + loop.Push("msg1") + + <-modelStarted + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + + result := loop.Wait() + + var cancelErr *CancelError + assert.True(t, errors.As(result.ExitReason, &cancelErr), "ExitReason should be a *CancelError") + + store.mu.Lock() + defer store.mu.Unlock() + assert.Empty(t, store.m, "no checkpoint should be saved when CheckpointID is not configured") +} + +func TestTurnLoop_StopBetweenTurnsAndResume(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "between-turns-session" + + loop := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("a") + loop.Push("b") + loop.Stop() + loop.Run(ctx) + + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + + var seen []string + var mu sync.Mutex + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + mu.Lock() + seen = append([]string{}, items...) + mu.Unlock() + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + _, ok := events.Next() + if !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + + loop2.Push("c") + loop2.Run(ctx) + exit2 := loop2.Wait() + assert.NoError(t, exit2.ExitReason) + + mu.Lock() + defer mu.Unlock() + assert.Equal(t, []string{"a", "b", "c"}, seen) +} + +func TestTurnLoop_StopDuringAgentExecution_PersistAndResume(t *testing.T) { + ctx := context.Background() + modelStarted := make(chan struct{}, 1) + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "mid-turn-session" + + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + + loop.Push("msg1") + <-modelStarted + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + exit := loop.Wait() + + store.mu.Lock() + _, ok := store.m[cpID] + store.mu.Unlock() + assert.True(t, ok) + _ = exit + + slowModel.setDelay(10 * time.Millisecond) + + var consumed2 []string + var genResumeCalled bool + var genInputCalled bool + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenResume: func(ctx context.Context, _ *TurnLoop[string], canceledItems []string, unhandledItems []string, newItems []string) (*GenResumeResult[string], error) { + genResumeCalled = true + return &GenResumeResult[string]{ + Consumed: canceledItems, + Remaining: append(append([]string{}, unhandledItems...), newItems...), + }, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + genInputCalled = true + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + consumed2 = append([]string{}, consumed...) + return agent, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + _, ok := events.Next() + if !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + + loop2.Run(ctx) + exit2 := loop2.Wait() + assert.NoError(t, exit2.ExitReason) + assert.Equal(t, []string{"msg1"}, consumed2) + assert.True(t, genResumeCalled) + assert.False(t, genInputCalled) +} + +func TestTurnLoop_CheckpointIDWithoutStore_FreshStart(t *testing.T) { + ctx := context.Background() + var genInputCalled bool + loop := NewTurnLoop(TurnLoopConfig[string]{ + CheckpointID: "some-id", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + genInputCalled = true + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + loop.Push("a") + loop.Run(ctx) + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + assert.True(t, genInputCalled) +} + +func TestTurnLoop_CheckpointNotFound_FreshStart(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + var genInputCalled bool + loop := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: "nonexistent-id", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + genInputCalled = true + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + loop.Push("a") + loop.Run(ctx) + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + assert.True(t, genInputCalled) +} + +func TestTurnLoop_CheckpointEmptyData_TreatedAsNoCheckpoint(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + store.m["cp-empty"] = nil + + var genInputCalled bool + loop := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: "cp-empty", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + genInputCalled = true + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + loop.Push("a") + loop.Run(ctx) + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + assert.True(t, genInputCalled) +} + +type errorCheckpointStore struct { + getErr error + setErr error +} + +func (s *errorCheckpointStore) Get(_ context.Context, _ string) ([]byte, bool, error) { + return nil, false, s.getErr +} + +func (s *errorCheckpointStore) Set(_ context.Context, _ string, _ []byte) error { + return s.setErr +} + +func TestTurnLoop_CheckpointLoadError_ReturnsError(t *testing.T) { + ctx := context.Background() + store := &errorCheckpointStore{getErr: fmt.Errorf("store unavailable")} + loop := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: "cp-1", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Run(ctx) + exit := loop.Wait() + assert.Error(t, exit.ExitReason) + assert.Contains(t, exit.ExitReason.Error(), "store unavailable") +} + +func TestTurnLoop_CheckpointCorruptData_ReturnsError(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + store.m["cp-corrupt"] = []byte("not-valid-gob-data") + loop := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: "cp-corrupt", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Run(ctx) + exit := loop.Wait() + assert.Error(t, exit.ExitReason) + assert.Contains(t, exit.ExitReason.Error(), "failed to unmarshal checkpoint") +} + +func TestTurnLoop_CheckpointSaveError_ReturnsError(t *testing.T) { + ctx := context.Background() + modelStarted := make(chan struct{}, 1) + saveStore := &errorCheckpointStore{setErr: fmt.Errorf("write failed")} + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: saveStore, + CheckpointID: "cp-1", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + loop.Push("msg1") + <-modelStarted + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + exit := loop.Wait() + assert.Error(t, exit.ExitReason) + assert.Contains(t, exit.ExitReason.Error(), "write failed") +} + +func TestTurnLoop_StaleCheckpointDeletion_OnCleanResume(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "stale-session" + + loop1 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop1.Push("a") + loop1.Stop() + loop1.Run(ctx) + loop1.Wait() + + store.mu.Lock() + _, exists := store.m[cpID] + store.mu.Unlock() + assert.True(t, exists, "checkpoint should exist after first loop saves it") + + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + loop2.Push("b") + loop2.Run(ctx) + exit2 := loop2.Wait() + assert.NoError(t, exit2.ExitReason) + + store.mu.Lock() + _, exists = store.m[cpID] + store.mu.Unlock() + assert.True(t, exists, "checkpoint should still exist because loop2 was stopped and saved a new one") +} + +func TestTurnLoop_StaleCheckpointDeletion_ContextCancel(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "delete-on-cancel" + + loop1 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop1.Push("a") + loop1.Stop() + loop1.Run(ctx) + loop1.Wait() + + store.mu.Lock() + _, exists := store.m[cpID] + store.mu.Unlock() + assert.True(t, exists, "checkpoint saved after loop1") + + ctx2, cancel2 := context.WithCancel(ctx) + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + cancel2() + return nil + }, + }) + loop2.Push("b") + loop2.Run(ctx2) + exit2 := loop2.Wait() + assert.ErrorIs(t, exit2.ExitReason, context.Canceled) + + store.mu.Lock() + v, exists := store.m[cpID] + store.mu.Unlock() + deletedViaNil := exists && v == nil + deletedViaAbsence := !exists + assert.True(t, deletedViaNil || deletedViaAbsence, "stale checkpoint should be deleted when loop exits via context cancellation") +} + +type deletableCheckpointStore struct { + turnLoopCheckpointStore + deleteCalled bool + deletedKey string +} + +func (s *deletableCheckpointStore) Delete(_ context.Context, key string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.deleteCalled = true + s.deletedKey = key + delete(s.m, key) + return nil +} + +func TestTurnLoop_CheckpointDeleter_CalledOnContextCancel(t *testing.T) { + ctx := context.Background() + store := &deletableCheckpointStore{ + turnLoopCheckpointStore: turnLoopCheckpointStore{m: make(map[string][]byte)}, + } + cpID := "deleter-session" + + loop1 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop1.Push("a") + loop1.Stop() + loop1.Run(ctx) + loop1.Wait() + + store.mu.Lock() + _, exists := store.m[cpID] + store.mu.Unlock() + assert.True(t, exists, "checkpoint saved after loop1") + + ctx2, cancel2 := context.WithCancel(ctx) + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + cancel2() + return nil + }, + }) + loop2.Push("b") + loop2.Run(ctx2) + exit2 := loop2.Wait() + assert.ErrorIs(t, exit2.ExitReason, context.Canceled) + + store.mu.Lock() + defer store.mu.Unlock() + assert.True(t, store.deleteCalled, "CheckPointDeleter.Delete should be called") + assert.Equal(t, cpID, store.deletedKey) + _, exists = store.m[cpID] + assert.False(t, exists, "checkpoint should be removed from store") +} + +func TestTurnLoop_GenResumeNil_Error(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "resume-nil-session" + modelStarted := make(chan struct{}, 1) + + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loop1 := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + loop1.Push("msg1") + <-modelStarted + loop1.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop1.Wait() + + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop2.Run(ctx) + exit2 := loop2.Wait() + assert.Error(t, exit2.ExitReason) + assert.Contains(t, exit2.ExitReason.Error(), "GenResume is required") +} + +func TestTurnLoop_SameCheckpointID_OverwritePattern(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "overwrite-session" + + loop1 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop1.Push("a") + loop1.Push("b") + loop1.Stop() + loop1.Run(ctx) + loop1.Wait() + + store.mu.Lock() + data1 := append([]byte{}, store.m[cpID]...) + store.mu.Unlock() + assert.NotEmpty(t, data1) + + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop2.Push("c") + loop2.Stop() + loop2.Run(ctx) + loop2.Wait() + + store.mu.Lock() + data2 := append([]byte{}, store.m[cpID]...) + store.mu.Unlock() + assert.NotEmpty(t, data2) + assert.NotEqual(t, data1, data2, "checkpoint data should change because items are different") + + var seen []string + var mu sync.Mutex + loop3 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + mu.Lock() + seen = append([]string{}, items...) + mu.Unlock() + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + loop3.Push("d") + loop3.Run(ctx) + exit3 := loop3.Wait() + assert.NoError(t, exit3.ExitReason) + + mu.Lock() + defer mu.Unlock() + assert.Equal(t, []string{"a", "b", "c", "d"}, seen, "should see loop2's unhandled items (a,b,c from loop2's checkpoint) plus new d") +} + +func TestTurnLoop_CheckpointHasRunnerStateButEmptyBytes(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "empty-runner-bytes" + + cp := &turnLoopCheckpoint[string]{ + HasRunnerState: true, + RunnerCheckpoint: nil, + UnhandledItems: []string{"x"}, + } + data, err := marshalTurnLoopCheckpoint(cp) + assert.NoError(t, err) + store.m[cpID] = data + + loop := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Run(ctx) + exit := loop.Wait() + assert.Error(t, exit.ExitReason) + assert.Contains(t, exit.ExitReason.Error(), "has runner state but bytes are empty") +} + +func TestTurnLoop_GenResumeReturnsError(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "resume-err-session" + modelStarted := make(chan struct{}, 1) + + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loop1 := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + loop1.Push("msg1") + <-modelStarted + loop1.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop1.Wait() + + genResumeErr := fmt.Errorf("resume callback failed") + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + GenResume: func(ctx context.Context, _ *TurnLoop[string], canceled, unhandled, newItems []string) (*GenResumeResult[string], error) { + return nil, genResumeErr + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop2.Run(ctx) + exit2 := loop2.Wait() + assert.Error(t, exit2.ExitReason) + assert.ErrorIs(t, exit2.ExitReason, genResumeErr) +} + +func TestTurnLoop_CheckpointSaveError_MergesWithExistingError(t *testing.T) { + ctx := context.Background() + modelStarted := make(chan struct{}, 1) + saveStore := &errorCheckpointStore{setErr: fmt.Errorf("disk full")} + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: saveStore, + CheckpointID: "cp-merge-err", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + loop.Push("msg1") + <-modelStarted + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + exit := loop.Wait() + assert.Error(t, exit.ExitReason) + errStr := exit.ExitReason.Error() + assert.Contains(t, errStr, "disk full") + var ce *CancelError + assert.True(t, errors.As(exit.ExitReason, &ce), "should wrap original CancelError") +} + +func TestTurnLoop_ResumeWithParams(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "resume-params-session" + modelStarted := make(chan struct{}, 1) + + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loop1 := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + loop1.Push("msg1") + <-modelStarted + loop1.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + exit1 := loop1.Wait() + var ce *CancelError + assert.True(t, errors.As(exit1.ExitReason, &ce)) + + var resumeParamsUsed *ResumeParams + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + GenResume: func(ctx context.Context, _ *TurnLoop[string], canceled, unhandled, newItems []string) (*GenResumeResult[string], error) { + params := &ResumeParams{ + Targets: map[string]any{"some-address": "user-data"}, + } + resumeParamsUsed = params + return &GenResumeResult[string]{ + ResumeParams: params, + Consumed: append(append(canceled, unhandled...), newItems...), + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + loop2.Run(ctx) + exit2 := loop2.Wait() + assert.NotNil(t, resumeParamsUsed, "GenResume should have been called with ResumeParams") + assert.Contains(t, resumeParamsUsed.Targets, "some-address") + _ = exit2 +} + +func TestTurnLoop_StopOptionsArePassed(t *testing.T) { + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelAfterToolCalls))) + + result := loop.Wait() + assert.NoError(t, result.ExitReason) +} + +func TestTurnLoop_Stop_EscalatesCancelMode(t *testing.T) { + ctx := context.Background() + agentStarted := make(chan *cancelContext, 1) + probe := &turnLoopStopModeProbeAgent{ccCh: agentStarted} + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return probe, nil + }, + }) + + loop.Push("msg1") + cc := <-agentStarted + + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelAfterToolCalls), WithAgentCancelTimeout(10*time.Second))) + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + + deadline := time.After(1 * time.Second) + for { + if cc.getMode() == CancelImmediate { + break + } + select { + case <-deadline: + t.Fatal("cancel mode did not escalate to CancelImmediate") + default: + } + time.Sleep(1 * time.Millisecond) + } + + exit := loop.Wait() + var ce *CancelError + assert.True(t, errors.As(exit.ExitReason, &ce)) + assert.Equal(t, CancelImmediate, ce.Info.Mode) +} + +func TestTurnLoop_DefaultOnAgentEvents_ErrorPropagation(t *testing.T) { + agentErr := errors.New("agent execution error") + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + return nil, agentErr + }, + }, nil + }, + // No OnAgentEvents — use default handler + }) + + loop.Push("msg1") + + result := loop.Wait() + // The default handler should propagate the agent error as ExitReason + assert.Error(t, result.ExitReason) +} + +func TestTurnLoop_OnAgentEventsError(t *testing.T) { + handlerErr := errors.New("event handler error") + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + // Drain events then return error + for { + _, ok := events.Next() + if !ok { + break + } + } + return handlerErr + }, + }) + + loop.Push("msg1") + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, handlerErr) +} + +func TestTurnLoop_StopCallFromGenInput(t *testing.T) { + // Test that calling Stop() from within GenInput works correctly + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, loop *TurnLoop[string], items []string) (*GenInputResult[string], error) { + loop.Stop() + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + + result := loop.Wait() + assert.NoError(t, result.ExitReason) +} + +func TestTurnLoop_PushFromOnAgentEvents(t *testing.T) { + // Test that calling Push() from within OnAgentEvents works + pushCount := int32(0) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + _, ok := events.Next() + if !ok { + break + } + } + count := atomic.AddInt32(&pushCount, 1) + if count == 1 { + // Push a follow-up item from the callback + _, _ = tc.Loop.Push("follow-up") + } else { + tc.Loop.Stop() + } + return nil + }, + }) + + loop.Push("initial") + + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.GreaterOrEqual(t, atomic.LoadInt32(&pushCount), int32(2)) +} + +// Tests for NewTurnLoop: the permissive API where Push, Stop, and Wait are +// all valid on a not-yet-running loop. + +func TestNewTurnLoop_PushBeforeRun(t *testing.T) { + // Items pushed before Run are buffered and processed after Run starts. + var processedItems []string + var mu sync.Mutex + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + mu.Lock() + processedItems = append(processedItems, items...) + mu.Unlock() + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + // Push before Run — items should be buffered. + ok, _ := loop.Push("msg1") + assert.True(t, ok) + ok, _ = loop.Push("msg2") + assert.True(t, ok) + + loop.Run(context.Background()) + + time.Sleep(100 * time.Millisecond) + + loop.Stop() + result := loop.Wait() + + mu.Lock() + defer mu.Unlock() + + assert.NoError(t, result.ExitReason) + assert.Contains(t, processedItems, "msg1") + assert.Contains(t, processedItems, "msg2") +} + +func TestNewTurnLoop_StopBeforeRun(t *testing.T) { + // Stop before Run sets the stopped flag. When Run is called, the loop + // exits immediately and buffered items appear as UnhandledItems. + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + t.Fatal("GenInput should not be called") + return nil, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + t.Fatal("PrepareAgent should not be called") + return nil, nil + }, + }) + + loop.Push("msg1") + loop.Push("msg2") + loop.Stop() + + // Push after Stop returns false. + ok, _ := loop.Push("msg3") + assert.False(t, ok) + + loop.Run(context.Background()) + result := loop.Wait() + + assert.NoError(t, result.ExitReason) + assert.Equal(t, []string{"msg1", "msg2"}, result.UnhandledItems) +} + +func TestNewTurnLoop_WaitBeforeRun(t *testing.T) { + // Wait blocks until Run is called AND the loop exits. + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + waitDone := make(chan *TurnLoopExitState[string], 1) + go func() { + waitDone <- loop.Wait() + }() + + // Wait should not return yet since Run hasn't been called. + select { + case <-waitDone: + t.Fatal("Wait returned before Run was called") + case <-time.After(50 * time.Millisecond): + // expected + } + + loop.Push("msg1") + loop.Stop() + loop.Run(context.Background()) + + select { + case result := <-waitDone: + assert.NoError(t, result.ExitReason) + assert.Equal(t, []string{"msg1"}, result.UnhandledItems) + case <-time.After(1 * time.Second): + t.Fatal("Wait did not return after Run + Stop") + } +} + +func TestNewTurnLoop_RunIsIdempotent(t *testing.T) { + var genInputCalls int32 + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + atomic.AddInt32(&genInputCalls, 1) + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + loop.Run(context.Background()) + loop.Run(context.Background()) + loop.Run(context.Background()) + + time.Sleep(100 * time.Millisecond) + + loop.Stop() + result := loop.Wait() + + assert.NoError(t, result.ExitReason) + assert.True(t, atomic.LoadInt32(&genInputCalls) >= 1) +} + +func TestNewTurnLoop_StopBeforeRun_ThenWait(t *testing.T) { + // Demonstrates the full sequence: create, push, stop, run, wait. + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + t.Fatal("GenInput should not be called after Stop") + return nil, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + t.Fatal("PrepareAgent should not be called after Stop") + return nil, nil + }, + }) + + loop.Push("a") + loop.Push("b") + loop.Push("c") + loop.Stop() + + // Run after Stop: the loop goroutine starts but exits immediately. + loop.Run(context.Background()) + + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.Equal(t, []string{"a", "b", "c"}, result.UnhandledItems) +} + +func TestNewTurnLoop_ConcurrentPushAndRun(t *testing.T) { + // Concurrent Push and Run should not race. + for i := 0; i < 100; i++ { + var count int32 + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + atomic.AddInt32(&count, int32(len(items))) + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + _, _ = loop.Push("item") + }() + + go func() { + defer wg.Done() + loop.Run(context.Background()) + }() + + wg.Wait() + + time.Sleep(50 * time.Millisecond) + + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + + processed := atomic.LoadInt32(&count) + unhandled := len(result.UnhandledItems) + assert.True(t, int(processed)+unhandled <= 1, + "total should not exceed pushed amount") + } +} + +type turnCtxKey struct{} + +func TestTurnLoop_RunCtx_Propagation(t *testing.T) { + // Verify that GenInputResult.RunCtx is propagated to PrepareAgent, + // the agent run, and OnAgentEvents. + + const traceVal = "trace-123" + var prepareCtxVal, agentCtxVal, eventsCtxVal string + + cfg := TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, loop *TurnLoop[string], items []string) (*GenInputResult[string], error) { + // Derive a new context with per-item trace data + runCtx := context.WithValue(ctx, turnCtxKey{}, traceVal) + return &GenInputResult[string]{ + RunCtx: runCtx, + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, loop *TurnLoop[string], consumed []string) (Agent, error) { + if v, ok := ctx.Value(turnCtxKey{}).(string); ok { + prepareCtxVal = v + } + return &turnLoopMockAgent{ + name: "trace-agent", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + if v, ok := ctx.Value(turnCtxKey{}).(string); ok { + agentCtxVal = v + } + return &AgentOutput{}, nil + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + if v, ok := ctx.Value(turnCtxKey{}).(string); ok { + eventsCtxVal = v + } + for { + if _, ok := events.Next(); !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + } + + loop := NewTurnLoop(cfg) + loop.Push("hello") + loop.Run(context.Background()) + result := loop.Wait() + + assert.Nil(t, result.ExitReason) + assert.Equal(t, traceVal, prepareCtxVal, "PrepareAgent should receive RunCtx") + assert.Equal(t, traceVal, agentCtxVal, "Agent run should receive RunCtx") + assert.Equal(t, traceVal, eventsCtxVal, "OnAgentEvents should receive RunCtx") +} + +func TestTurnLoop_TurnContext_PreemptedChannel(t *testing.T) { + preemptedSeen := make(chan struct{}) + agentStarted := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "slow", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + select { + case <-tc.Preempted: + close(preemptedSeen) + case <-time.After(5 * time.Second): + t.Error("timed out waiting for Preempted channel") + } + // Drain events + for { + if _, ok := events.Next(); !ok { break } } return nil }, }) - require.NoError(t, err) - ctx, cancel := loop.WithCancel(context.Background()) - done := make(chan error) + loop.Push("msg1") + <-agentStarted + loop.Push("msg2", WithPreempt[string](WithAgentCancelMode(CancelImmediate))) + + select { + case <-preemptedSeen: + // success + case <-time.After(5 * time.Second): + t.Fatal("preempted channel was never observed in OnAgentEvents") + } + + loop.Stop() + loop.Wait() +} + +// ============================================================================= +// preemptSignal unit tests (direct testing of the hold/preempt/unhold mechanism) +// ============================================================================= + +func TestPreemptSignal_HoldCountLifecycle(t *testing.T) { + s := newPreemptSignal() + + s.holdRunLoop() + s.holdRunLoop() + + done := make(chan bool) go func() { - done <- loop.Run(ctx) + preempted, _, _ := s.waitForPreemptOrUnhold() + done <- preempted }() - <-slowAgent.startedCh - cancel(WithCancelMode(CancelAfterToolCall)) + select { + case <-done: + t.Fatal("waitForPreemptOrUnhold should block while holdCount > 0") + case <-time.After(50 * time.Millisecond): + } - <-done - assert.Len(t, slowAgent.cancelledOpt, 1) -} + s.unholdRunLoop() -type turnLoopInMemoryStore struct { - data map[string][]byte -} + select { + case <-done: + t.Fatal("waitForPreemptOrUnhold should still block (holdCount=1)") + case <-time.After(50 * time.Millisecond): + } -func newTurnLoopInMemoryStore() *turnLoopInMemoryStore { - return &turnLoopInMemoryStore{data: make(map[string][]byte)} -} + s.unholdRunLoop() -func (s *turnLoopInMemoryStore) Get(_ context.Context, key string) ([]byte, bool, error) { - v, ok := s.data[key] - return v, ok, nil + select { + case preempted := <-done: + assert.False(t, preempted, "should return not-preempted when all holds released") + case <-time.After(1 * time.Second): + t.Fatal("waitForPreemptOrUnhold should unblock when holdCount reaches 0") + } } -func (s *turnLoopInMemoryStore) Set(_ context.Context, key string, value []byte) error { - s.data[key] = value - return nil -} +func TestPreemptSignal_RequestPreemptWithNoHold(t *testing.T) { + s := newPreemptSignal() -type turnLoopTestModel struct { - messages []*schema.Message - idx int -} + ack := make(chan struct{}) + s.requestPreempt(ack) -func (m *turnLoopTestModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { - if m.idx >= len(m.messages) { - return nil, fmt.Errorf("no more messages") + select { + case <-ack: + case <-time.After(100 * time.Millisecond): + t.Fatal("ack should be closed immediately when holdCount is 0") } - msg := m.messages[m.idx] - m.idx++ - return msg, nil } -func (m *turnLoopTestModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { - panic("not implemented") -} +func TestPreemptSignal_RequestPreemptWakesWaiter(t *testing.T) { + s := newPreemptSignal() + s.holdRunLoop() + + done := make(chan struct { + preempted bool + ackList []chan struct{} + }) + go func() { + preempted, _, ackList := s.waitForPreemptOrUnhold() + done <- struct { + preempted bool + ackList []chan struct{} + }{preempted, ackList} + }() -func (m *turnLoopTestModel) WithTools(_ []*schema.ToolInfo) (model.ToolCallingChatModel, error) { - return m, nil + ack := make(chan struct{}) + s.requestPreempt(ack) + + select { + case result := <-done: + assert.True(t, result.preempted) + assert.Len(t, result.ackList, 1) + close(result.ackList[0]) + case <-time.After(1 * time.Second): + t.Fatal("waitForPreemptOrUnhold should wake on requestPreempt") + } } -type turnLoopSlowModel struct { - delay int64 - startedCh unsafe.Pointer - doneCh chan struct{} - message *schema.Message +func TestPreemptSignal_HoldAndGetTurn(t *testing.T) { + s := newPreemptSignal() + s.setTurn(context.Background(), "turn-A") + + ctx, tc := s.holdAndGetTurn() + assert.NotNil(t, ctx) + assert.Equal(t, "turn-A", tc) + + s.endTurnAndUnhold() + + _, tc2 := s.holdAndGetTurn() + assert.Nil(t, tc2, "TC should be nil after endTurnAndUnhold") + s.unholdRunLoop() } -func (m *turnLoopSlowModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { - if ch := (*chan struct{})(atomic.LoadPointer(&m.startedCh)); ch != nil { - select { - case *ch <- struct{}{}: - default: +func TestPreemptSignal_EndTurnPreservesSignalWhenHoldRemains(t *testing.T) { + s := newPreemptSignal() + + s.holdRunLoop() + s.holdRunLoop() + + ack := make(chan struct{}) + s.requestPreempt(ack) + + s.endTurnAndUnhold() + + done := make(chan bool) + go func() { + preempted, _, ackList := s.waitForPreemptOrUnhold() + for _, a := range ackList { + close(a) } + done <- preempted + }() + + select { + case preempted := <-done: + assert.True(t, preempted, "signal state should be preserved when holdCount > 0 after endTurnAndUnhold") + case <-time.After(1 * time.Second): + t.Fatal("waiter should see the preserved preempt signal") } - if delay := atomic.LoadInt64(&m.delay); delay > 0 { - time.Sleep(time.Duration(delay)) - } - if m.doneCh != nil { - select { - case m.doneCh <- struct{}{}: - default: - } + + select { + case <-ack: + case <-time.After(100 * time.Millisecond): + t.Fatal("ack should have been closed") } - return m.message, nil } -func (m *turnLoopSlowModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { - panic("not implemented") +func TestPreemptSignal_ConcurrentHoldRequestUnhold(t *testing.T) { + s := newPreemptSignal() + + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + s.holdRunLoop() + ack := make(chan struct{}) + s.requestPreempt(ack) + s.unholdRunLoop() + <-ack + }() + } + wg.Wait() } -type turnLoopInterruptTool struct { - name string -} +// ============================================================================= +// Integration tests for race-prone preempt scenarios +// ============================================================================= -func (t *turnLoopInterruptTool) Info(_ context.Context) (*schema.ToolInfo, error) { - return &schema.ToolInfo{ - Name: t.name, - Desc: "A tool that interrupts", - }, nil -} +func TestTurnLoop_ConcurrentPreemptsDuringTurn(t *testing.T) { + agentStarted := make(chan struct{}) + agentStartedOnce := sync.Once{} -func (t *turnLoopInterruptTool) InvokableRun(ctx context.Context, _ string, _ ...tool.Option) (string, error) { - wasInterrupted, _, _ := tool.GetInterruptState[any](ctx) - if !wasInterrupted { - return "", tool.Interrupt(ctx, "need approval") - } - isResumeTarget, hasData, data := tool.GetResumeContext[string](ctx) - if isResumeTarget && hasData { - return data, nil + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + <-ctx.Done() + return &AgentOutput{}, nil + }, } - return "approved", nil -} -func TestTurnLoop_ExternalCancel_WithStore(t *testing.T) { - store := newTurnLoopInMemoryStore() - const checkPointID = "external-cancel-test" + var genInputCount int32 - modelStarted := make(chan struct{}, 1) - testModel := &turnLoopSlowModel{ - doneCh: make(chan struct{}, 1), - message: schema.AssistantMessage("task completed", nil), + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + atomic.AddInt32(&genInputCount, 1) + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items, + }, nil + }, + }) + + loop.Push("seed") + + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") } - atomic.StoreInt64(&testModel.delay, int64(5*time.Second)) - atomic.StorePointer(&testModel.startedCh, unsafe.Pointer(&modelStarted)) - agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ - Name: "test-agent", - Description: "test agent for external cancel", - Model: testModel, - }) - require.NoError(t, err) + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + ok, ack := loop.Push(fmt.Sprintf("urgent-%d", i), WithPreempt[string]()) + if ok && ack != nil { + select { + case <-ack: + case <-time.After(5 * time.Second): + t.Error("ack channel not closed within timeout") + } + } + }(i) + } + + wg.Wait() + time.Sleep(200 * time.Millisecond) + + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.True(t, atomic.LoadInt32(&genInputCount) >= 2, "should have had at least the initial turn + one preempted turn") +} - receiveCount := int32(0) +func TestTurnLoop_PreemptDuringTurnTransition(t *testing.T) { turnCount := int32(0) - frontBlocked := make(chan struct{}) - source := &turnLoopFuncSource[string]{ - receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - cnt := atomic.AddInt32(&receiveCount, 1) - if cnt == 1 { - return ctx, "msg1", []ConsumeOption{WithConsumeCheckPointID(checkPointID)}, nil + firstTurnDone := make(chan struct{}) + firstTurnOnce := sync.Once{} + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "fast"}, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + count := atomic.AddInt32(&turnCount, 1) + if count == 1 { + firstTurnOnce.Do(func() { + close(firstTurnDone) + }) } - return ctx, "", nil, ErrLoopExit - }, - front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - <-frontBlocked - return ctx, "", nil, context.DeadlineExceeded + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items, + }, nil }, + }) + + loop.Push("first") + + select { + case <-firstTurnDone: + case <-time.After(1 * time.Second): + t.Fatal("first turn did not start") + } + + time.Sleep(50 * time.Millisecond) + + ok, ack := loop.Push("transitional", WithPreempt[string]()) + assert.True(t, ok, "push should succeed") + if ack != nil { + select { + case <-ack: + case <-time.After(2 * time.Second): + t.Fatal("ack should be closed even if preempt arrived during/after turn transition") + } } - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil + time.Sleep(100 * time.Millisecond) + + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.True(t, atomic.LoadInt32(&turnCount) >= 2, "transitional item should have been processed") +} + +func TestTurnLoop_PushStrategy_DuringTurnTransition(t *testing.T) { + agentStarted := make(chan struct{}) + agentStartedOnce := sync.Once{} + allowFinish := make(chan struct{}) + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + select { + case <-allowFinish: + return &AgentOutput{}, nil + case <-ctx.Done(): + return &AgentOutput{}, nil + } }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { + } + + var genInputCount int32 + secondTurnDone := make(chan struct{}) + secondTurnOnce := sync.Once{} + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { return agent, nil }, - OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { - atomic.AddInt32(&turnCount, 1) - for { - if _, ok := iter.Next(); !ok { - break - } + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + count := atomic.AddInt32(&genInputCount, 1) + if count >= 2 { + secondTurnOnce.Do(func() { + close(secondTurnDone) + }) } - return nil + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items, + }, nil }, - Store: store, }) - require.NoError(t, err) - ctx, cancel := loop.WithCancel(context.Background()) - done := make(chan error) + loop.Push("first") + + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } + + strategyBlocker := make(chan struct{}) + var strategyTCNotNil int32 + go func() { - done <- loop.Run(ctx) + loop.Push("strategic-item", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + if tc != nil { + atomic.StoreInt32(&strategyTCNotNil, 1) + } + <-strategyBlocker + return []PushOption[string]{WithPreempt[string]()} + })) }() + time.Sleep(50 * time.Millisecond) + close(allowFinish) + time.Sleep(50 * time.Millisecond) + close(strategyBlocker) + select { - case <-modelStarted: - case err := <-done: - t.Fatalf("loop.Run returned early with error: %v", err) + case <-secondTurnDone: case <-time.After(3 * time.Second): - t.Fatalf("timeout waiting for model to start") + t.Fatal("second turn should eventually run after strategy resolves") } - err = cancel(WithCancelMode(CancelImmediate)) + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.True(t, atomic.LoadInt32(&genInputCount) >= 2) +} - var runErr error - select { - case runErr = <-done: - case <-time.After(3 * time.Second): - t.Fatalf("timeout waiting for loop.Run to return after cancel") +func TestTurnLoop_ConcurrentPreemptAndStop(t *testing.T) { + for iter := 0; iter < 20; iter++ { + t.Run(fmt.Sprintf("iter_%d", iter), func(t *testing.T) { + ctx := context.Background() + + agentStarted := make(chan struct{}) + agentStartedOnce := sync.Once{} + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + <-ctx.Done() + return &AgentOutput{}, nil + }, + } + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items, + }, nil + }, + }) + + loop.Push("seed") + + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + _, ack := loop.Push("preempt-item", WithPreempt[string]()) + if ack != nil { + <-ack + } + }() + + go func() { + defer wg.Done() + loop.Stop() + }() + + wg.Wait() + loop.Wait() + }) } +} + +func TestTurnLoop_ConcurrentPushStrategyAndStop(t *testing.T) { + for iter := 0; iter < 20; iter++ { + t.Run(fmt.Sprintf("iter_%d", iter), func(t *testing.T) { + ctx := context.Background() + + agentStarted := make(chan struct{}) + agentStartedOnce := sync.Once{} + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + <-ctx.Done() + return &AgentOutput{}, nil + }, + } - var interruptErr *TurnLoopInterruptError[string] - require.True(t, errors.As(runErr, &interruptErr), "expected TurnLoopInterruptError, got: %v", runErr) - assert.Equal(t, checkPointID, interruptErr.CheckpointID) - assert.Equal(t, "msg1", interruptErr.Item) - assert.True(t, len(store.data) > 0, "checkpoint should be stored") + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items, + }, nil + }, + }) + + loop.Push("seed") + + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } - atomic.StorePointer(&testModel.startedCh, nil) - atomic.StoreInt64(&testModel.delay, 0) + var wg sync.WaitGroup + wg.Add(2) - done = make(chan error) - go func() { - done <- loop.Run(context.Background(), WithTurnLoopResume(checkPointID, "msg1")) - }() + go func() { + defer wg.Done() + _, ack := loop.Push("strategic-item", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + return []PushOption[string]{WithPreempt[string]()} + })) + if ack != nil { + <-ack + } + }() + + go func() { + defer wg.Done() + loop.Stop() + }() + + wg.Wait() + loop.Wait() + }) + } +} +func TestTurnLoop_TurnContext_StoppedChannel(t *testing.T) { + stoppedSeen := make(chan struct{}) + agentStarted := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "slow", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + select { + case <-tc.Stopped: + close(stoppedSeen) + case <-time.After(5 * time.Second): + t.Error("timed out waiting for Stopped channel") + } + // Drain events + for { + if _, ok := events.Next(); !ok { + break + } + } + return nil + }, + }) + + loop.Push("msg1") + <-agentStarted + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) select { - case runErr = <-done: + case <-stoppedSeen: + // success case <-time.After(5 * time.Second): - t.Fatalf("timeout waiting for loop.Run (resume) to return") + t.Fatal("stopped channel was never observed in OnAgentEvents") } - assert.NoError(t, runErr) - assert.Equal(t, int32(2), atomic.LoadInt32(&turnCount), "should have 2 turns total (1 from first run + 1 from resume)") + + loop.Wait() } -func TestTurnLoop_InternalToolInterrupt_WithCheckpoint_ThenResume(t *testing.T) { - store := newTurnLoopInMemoryStore() - const checkPointID = "tool-interrupt-test" +func TestTurnLoop_PushStrategy_DuringTurn(t *testing.T) { + agentStarted := make(chan struct{}) + agentStartedOnce := sync.Once{} + agentCancelled := make(chan struct{}) + agentCancelledOnce := sync.Once{} + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + <-ctx.Done() + agentCancelledOnce.Do(func() { + close(agentCancelled) + }) + return &AgentOutput{}, nil + }, + } - interruptTool := &turnLoopInterruptTool{name: "approval_tool"} + genInputCalls := int32(0) + secondGenInputCalled := make(chan struct{}) + secondGenInputOnce := sync.Once{} - testModel := &turnLoopTestModel{ - messages: []*schema.Message{ - schema.AssistantMessage("", []schema.ToolCall{ - {ID: "call1", Function: schema.FunctionCall{Name: "approval_tool", Arguments: "{}"}}, - }), - schema.AssistantMessage("task completed", nil), + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + count := atomic.AddInt32(&genInputCalls, 1) + if count >= 2 { + secondGenInputOnce.Do(func() { + close(secondGenInputCalled) + }) + } + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil + }, + }) + + loop.Push("first") + + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") } - agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ - Name: "test-agent", - Description: "test agent", - Model: testModel, - ToolsConfig: ToolsConfig{ - ToolsNodeConfig: compose.ToolsNodeConfig{ - Tools: []tool.BaseTool{interruptTool}, - }, + // Strategy inspects TurnContext during a running turn and decides to preempt. + var strategyCalled int32 + var strategyTC *TurnContext[string] + loop.Push("urgent", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + atomic.AddInt32(&strategyCalled, 1) + strategyTC = tc + return []PushOption[string]{WithPreempt[string]()} + })) + + select { + case <-agentCancelled: + case <-time.After(1 * time.Second): + t.Fatal("agent was not cancelled by strategy-returned preempt") + } + + select { + case <-secondGenInputCalled: + case <-time.After(1 * time.Second): + t.Fatal("second GenInput was not called after preempt") + } + + loop.Stop() + loop.Wait() + + assert.Equal(t, int32(1), atomic.LoadInt32(&strategyCalled)) + assert.NotNil(t, strategyTC, "strategy should receive non-nil TurnContext during a turn") + assert.Equal(t, []string{"first"}, strategyTC.Consumed) +} + +func TestTurnLoop_PushStrategy_BetweenTurns(t *testing.T) { + // Push with strategy before Run() — TurnContext should be nil. + var strategyCalled int32 + var strategyTCWasNil bool + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + return &AgentOutput{}, nil }, - }) - require.NoError(t, err) + } - receiveCount := int32(0) - frontBlocked := make(chan struct{}) - source := &turnLoopFuncSource[string]{ - receive: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - cnt := atomic.AddInt32(&receiveCount, 1) - if cnt == 1 { - return ctx, "msg1", []ConsumeOption{WithConsumeCheckPointID(checkPointID)}, nil + agentDone := make(chan struct{}) + agentDoneOnce := sync.Once{} + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + Remaining: nil, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + _, ok := events.Next() + if !ok { + break + } } - return ctx, "", nil, ErrLoopExit + agentDoneOnce.Do(func() { + close(agentDone) + }) + return nil + }, + }) + + // Push with strategy — no turn is active yet, so tc should be nil. + loop.Push("item", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + atomic.AddInt32(&strategyCalled, 1) + strategyTCWasNil = (tc == nil) + return nil // plain push, no preempt + })) + + select { + case <-agentDone: + case <-time.After(2 * time.Second): + t.Fatal("agent did not complete") + } + + loop.Stop() + loop.Wait() + + assert.Equal(t, int32(1), atomic.LoadInt32(&strategyCalled)) + assert.True(t, strategyTCWasNil, "strategy should receive nil TurnContext between turns") +} + +func TestTurnLoop_PushStrategy_OverridesOtherOptions(t *testing.T) { + // Push with both WithPreempt and WithPushStrategy — only strategy's result applies. + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + return &AgentOutput{}, nil + }, + } + + agentDone := make(chan struct{}) + agentDoneOnce := sync.Once{} + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil }, - front: func(ctx context.Context, _ ReceiveConfig) (context.Context, string, []ConsumeOption, error) { - <-frontBlocked - return ctx, "", nil, context.DeadlineExceeded + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + Remaining: nil, + }, nil }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + _, ok := events.Next() + if !ok { + break + } + } + agentDoneOnce.Do(func() { + close(agentDone) + }) + return nil + }, + }) + + // Strategy returns nil (no preempt), even though WithPreempt is also passed. + // The strategy should override — so the agent should NOT be preempted. + ok, ack := loop.Push("item", WithPreempt[string](), WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + return nil // no preempt + })) + assert.True(t, ok) + assert.Nil(t, ack, "ack should be nil since strategy returned no preempt") + + select { + case <-agentDone: + case <-time.After(2 * time.Second): + t.Fatal("agent did not complete normally") } - var interruptErr *TurnLoopInterruptError[string] - loop, err := NewTurnLoop(TurnLoopConfig[string]{ - Source: source, - GenInput: func(_ context.Context, item string) (*AgentInput, []AgentRunOption, error) { - return &AgentInput{Messages: []Message{schema.UserMessage(item)}}, nil, nil + loop.Stop() + loop.Wait() +} + +func TestTurnLoop_PushStrategy_NestedStrategyStripped(t *testing.T) { + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + return &AgentOutput{}, nil }, - GetAgent: func(_ context.Context, _ string) (Agent, error) { + } + + agentDone := make(chan struct{}) + agentDoneOnce := sync.Once{} + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { return agent, nil }, - OnAgentEvents: func(_ context.Context, _ string, iter *AsyncIterator[*AgentEvent]) error { + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + Remaining: nil, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { for { - if _, ok := iter.Next(); !ok { + _, ok := events.Next() + if !ok { break } } + agentDoneOnce.Do(func() { + close(agentDone) + }) return nil }, - Store: store, }) - require.NoError(t, err) - err = loop.Run(context.Background()) - require.True(t, errors.As(err, &interruptErr), "expected TurnLoopInterruptError, got: %v", err) - assert.Equal(t, checkPointID, interruptErr.CheckpointID) - assert.Equal(t, "msg1", interruptErr.Item) - assert.Len(t, interruptErr.InterruptContexts, 1) - assert.Equal(t, "need approval", interruptErr.InterruptContexts[0].Info) + // Strategy returns another WithPushStrategy — the nested one should be stripped. + innerCalled := int32(0) + ok, ack := loop.Push("item", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + return []PushOption[string]{ + WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + atomic.AddInt32(&innerCalled, 1) + return []PushOption[string]{WithPreempt[string]()} + }), + } + })) + assert.True(t, ok) + assert.Nil(t, ack, "ack should be nil since nested strategy was stripped (no preempt)") - interruptID := interruptErr.InterruptContexts[0].ID - assert.NotEmpty(t, interruptID) - assert.True(t, len(store.data) > 0, "checkpoint should be stored") + select { + case <-agentDone: + case <-time.After(2 * time.Second): + t.Fatal("agent did not complete normally") + } - testModel.idx = 1 - receiveCount = 0 + loop.Stop() + loop.Wait() - resumeCtx := compose.ResumeWithData(context.Background(), interruptID, "user approved") - err = loop.Run(resumeCtx, WithTurnLoopResume(checkPointID, "msg1")) - assert.NoError(t, err) + assert.Equal(t, int32(0), atomic.LoadInt32(&innerCalled), "nested strategy should not be called") +} + +func TestTurnLoop_PushStrategy_ConsumedInspection(t *testing.T) { + // Strategy preempts only when current turn is processing "low-priority" items. + agentStarted := make(chan struct{}) + agentStartedOnce := sync.Once{} + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + <-ctx.Done() + return &AgentOutput{}, nil + }, + } + + genInputCalls := int32(0) + secondGenInputItems := make(chan []string, 1) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + count := atomic.AddInt32(&genInputCalls, 1) + if count >= 2 { + select { + case secondGenInputItems <- append([]string{}, items...): + default: + } + } + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil + }, + }) + + loop.Push("low-priority-task") + + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } + + // Strategy checks Consumed and preempts because current turn has "low-priority" items. + loop.Push("urgent-task", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + if tc != nil && len(tc.Consumed) > 0 && tc.Consumed[0] == "low-priority-task" { + return []PushOption[string]{WithPreempt[string]()} + } + return nil + })) + + select { + case items := <-secondGenInputItems: + assert.Contains(t, items, "urgent-task") + case <-time.After(2 * time.Second): + t.Fatal("second GenInput was not called after strategy-driven preempt") + } + + loop.Stop() + loop.Wait() } diff --git a/adk/utils.go b/adk/utils.go index ee061f71f..6dbdbc2b5 100644 --- a/adk/utils.go +++ b/adk/utils.go @@ -124,16 +124,28 @@ func getMessageFromWrappedEvent(e *agentEventWrapper) (Message, error) { return nil, e.StreamErr } + e.consumeStream() + + if e.StreamErr != nil { + return nil, e.StreamErr + } + return e.concatenatedMessage, nil +} + +// consumeStream drains the message stream, setting concatenatedMessage on +// success or StreamErr on failure. The stream is always replaced with an +// error-free, materialized version safe for gob encoding. +// Must be called at most once (guarded by callers checking concatenatedMessage/StreamErr). +func (e *agentEventWrapper) consumeStream() { e.mu.Lock() defer e.mu.Unlock() + if e.concatenatedMessage != nil { - return e.concatenatedMessage, nil + return } - var ( - msgs []Message - s = e.AgentEvent.Output.MessageOutput.MessageStream - ) + s := e.AgentEvent.Output.MessageOutput.MessageStream + var msgs []Message defer s.Close() for { @@ -143,19 +155,16 @@ func getMessageFromWrappedEvent(e *agentEventWrapper) (Message, error) { break } e.StreamErr = err - // Replace the stream with successfully received messages only (no error at the end). - // The error is preserved in StreamErr for users to check. - // We intentionally exclude the error from the new stream to ensure gob encoding - // compatibility, as the stream may be consumed during serialization. e.AgentEvent.Output.MessageOutput.MessageStream = schema.StreamReaderFromArray(msgs) - return nil, err + return } - msgs = append(msgs, msg) } if len(msgs) == 0 { - return nil, errors.New("no messages in MessageVariant.MessageStream") + e.StreamErr = errors.New("no messages in MessageVariant.MessageStream") + e.AgentEvent.Output.MessageOutput.MessageStream = schema.StreamReaderFromArray(msgs) + return } if len(msgs) == 1 { @@ -166,11 +175,11 @@ func getMessageFromWrappedEvent(e *agentEventWrapper) (Message, error) { if err != nil { e.StreamErr = err e.AgentEvent.Output.MessageOutput.MessageStream = schema.StreamReaderFromArray(msgs) - return nil, err + return } } - return e.concatenatedMessage, nil + e.AgentEvent.Output.MessageOutput.MessageStream = schema.StreamReaderFromArray([]Message{e.concatenatedMessage}) } // copyAgentEvent copies an AgentEvent. diff --git a/adk/workflow.go b/adk/workflow.go index 9d63d7347..00411e33b 100644 --- a/adk/workflow.go +++ b/adk/workflow.go @@ -175,7 +175,6 @@ func (a *workflowAgent) runSequential(ctx context.Context, startIdx := 0 - // seqCtx tracks the accumulated RunPath across the sequence. seqCtx := ctx // If we are resuming, find which sub-agent to start from and prepare its context. @@ -193,12 +192,28 @@ func (a *workflowAgent) runSequential(ctx context.Context, for i := startIdx; i < len(a.subAgents); i++ { subAgent := a.subAgents[i] + // Cancel check at transition boundary between sub-agents. + // Transition boundaries are always safe to cancel at — no sub-agent + // work is in progress, so any cancel mode is honoured. + if cancelCtx := getCancelContext(ctx); cancelCtx != nil && cancelCtx.shouldCancel() { + state := &sequentialWorkflowState{InterruptIndex: i} + event := cancelAtTransition(ctx, "Sequential workflow cancel at transition", state) + generator.Send(event) + return nil + } + var subIterator *AsyncIterator[*AgentEvent] if seqState != nil { - subIterator = subAgent.Resume(seqCtx, &ResumeInfo{ - EnableStreaming: info.EnableStreaming, - InterruptInfo: info.Data.(*WorkflowInterruptInfo).SequentialInterruptInfo, - }, opts...) + wfInfo, _ := info.Data.(*WorkflowInterruptInfo) + if wfInfo != nil && wfInfo.SequentialInterruptInfo != nil { + // Sub-agent was interrupted — resume it. + subIterator = subAgent.Resume(seqCtx, &ResumeInfo{ + EnableStreaming: info.EnableStreaming, + InterruptInfo: wfInfo.SequentialInterruptInfo, + }, opts...) + } else { + subIterator = subAgent.Run(seqCtx, nil, opts...) + } seqState = nil } else { subIterator = subAgent.Run(seqCtx, nil, opts...) @@ -304,7 +319,6 @@ func (a *workflowAgent) runLoop(ctx context.Context, generator *AsyncGenerator[* startIter := 0 startIdx := 0 - // loopCtx tracks the accumulated RunPath across the full sequence within a single iteration. loopCtx := ctx if loopState != nil { @@ -329,13 +343,25 @@ func (a *workflowAgent) runLoop(ctx context.Context, generator *AsyncGenerator[* for j := startIdx; j < len(a.subAgents); j++ { subAgent := a.subAgents[j] + if cancelCtx := getCancelContext(ctx); cancelCtx != nil && cancelCtx.shouldCancel() { + state := &loopWorkflowState{LoopIterations: i, SubAgentIndex: j} + event := cancelAtTransition(ctx, "Loop workflow cancel at transition", state) + generator.Send(event) + return nil + } + var subIterator *AsyncIterator[*AgentEvent] if loopState != nil { - // This is the agent we need to resume. - subIterator = subAgent.Resume(loopCtx, &ResumeInfo{ - EnableStreaming: resumeInfo.EnableStreaming, - InterruptInfo: resumeInfo.Data.(*WorkflowInterruptInfo).SequentialInterruptInfo, - }, opts...) + wfInfo, _ := resumeInfo.Data.(*WorkflowInterruptInfo) + if wfInfo != nil && wfInfo.SequentialInterruptInfo != nil { + // Sub-agent was interrupted — resume it. + subIterator = subAgent.Resume(loopCtx, &ResumeInfo{ + EnableStreaming: resumeInfo.EnableStreaming, + InterruptInfo: wfInfo.SequentialInterruptInfo, + }, opts...) + } else { + subIterator = subAgent.Run(loopCtx, nil, opts...) + } loopState = nil // Only resume the first time. } else { subIterator = subAgent.Run(loopCtx, nil, opts...) @@ -468,6 +494,15 @@ func (a *workflowAgent) runParallel(ctx context.Context, generator *AsyncGenerat } } + // Cancel check before spawning parallel goroutines. No sub-agent work + // is in progress, so any cancel mode is honoured at this boundary. + if cancelCtx := getCancelContext(ctx); cancelCtx != nil && cancelCtx.shouldCancel() { + state := ¶llelWorkflowState{} + event := cancelAtTransition(ctx, "Parallel workflow cancel before spawn", state) + generator.Send(event) + return nil + } + for i := range a.subAgents { wg.Add(1) go func(idx int, agent *flowAgent) { @@ -483,11 +518,13 @@ func (a *workflowAgent) runParallel(ctx context.Context, generator *AsyncGenerat var iterator *AsyncIterator[*AgentEvent] if _, ok := agentNames[agent.Name(ctx)]; ok { - // This branch was interrupted and needs to be resumed. - iterator = agent.Resume(childContexts[idx], &ResumeInfo{ + childResumeInfo := &ResumeInfo{ EnableStreaming: resumeInfo.EnableStreaming, - InterruptInfo: resumeInfo.Data.(*WorkflowInterruptInfo).ParallelInterruptInfo[idx], - }, opts...) + } + if wfInfo, ok := resumeInfo.Data.(*WorkflowInterruptInfo); ok && wfInfo != nil { + childResumeInfo.InterruptInfo = wfInfo.ParallelInterruptInfo[idx] + } + iterator = agent.Resume(childContexts[idx], childResumeInfo, opts...) } else if parState != nil { // We are resuming, but this child is not in the next points map. // This means it finished successfully, so we don't run it. @@ -550,6 +587,27 @@ func (a *workflowAgent) runParallel(ctx context.Context, generator *AsyncGenerat return nil } +func cancelAtTransition(ctx context.Context, info string, state any) *AgentEvent { + // state is the workflow checkpoint state (e.g. sequentialWorkflowState); + // nil for subContexts because this is a leaf interrupt with no child signals. + is, err := core.Interrupt(ctx, info, state, nil, + core.WithLayerPayload(getRunCtx(ctx).RunPath)) + if err != nil { + return &AgentEvent{Err: err} + } + + contexts := core.ToInterruptContexts(is, allowedAddressSegmentTypes) + + return &AgentEvent{ + Action: &AgentAction{ + Interrupted: &InterruptInfo{ + InterruptContexts: contexts, + }, + internalInterrupted: is, + }, + } +} + type SequentialAgentConfig struct { Name string Description string diff --git a/adk/wrappers.go b/adk/wrappers.go index 5061f5be8..eb23549fd 100644 --- a/adk/wrappers.go +++ b/adk/wrappers.go @@ -34,10 +34,11 @@ type generateEndpoint func(ctx context.Context, input []*schema.Message, opts .. type streamEndpoint func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) type modelWrapperConfig struct { - handlers []ChatModelAgentMiddleware - middlewares []AgentMiddleware - retryConfig *ModelRetryConfig - toolInfos []*schema.ToolInfo + handlers []ChatModelAgentMiddleware + middlewares []AgentMiddleware + retryConfig *ModelRetryConfig + toolInfos []*schema.ToolInfo + cancelContext *cancelContext } func buildModelWrappers(m model.BaseChatModel, config *modelWrapperConfig) model.BaseChatModel { @@ -54,6 +55,7 @@ func buildModelWrappers(m model.BaseChatModel, config *modelWrapperConfig) model middlewares: config.middlewares, toolInfos: config.toolInfos, modelRetryConfig: config.retryConfig, + cancelContext: config.cancelContext, } return wrapped @@ -252,11 +254,18 @@ func NewEventSenderModelWrapper() ChatModelAgentMiddleware { } func (w *eventSenderModelWrapper) WrapModel(_ context.Context, m model.BaseChatModel, mc *ModelContext) (model.BaseChatModel, error) { + inner := m + if mc != nil && mc.cancelContext != nil { + inner = &cancelMonitoredModel{ + inner: inner, + cancelContext: mc.cancelContext, + } + } var retryConfig *ModelRetryConfig if mc != nil { retryConfig = mc.ModelRetryConfig } - return &eventSenderModel{inner: m, modelRetryConfig: retryConfig}, nil + return &eventSenderModel{inner: inner, modelRetryConfig: retryConfig}, nil } type eventSenderModel struct { @@ -490,6 +499,7 @@ type stateModelWrapper struct { middlewares []AgentMiddleware toolInfos []*schema.ToolInfo modelRetryConfig *ModelRetryConfig + cancelContext *cancelContext } func (w *stateModelWrapper) IsCallbacksEnabled() bool { @@ -515,6 +525,7 @@ func (w *stateModelWrapper) hasUserEventSender() bool { func (w *stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) generateEndpoint { hasUserEventSender := w.hasUserEventSender() retryConfig := w.modelRetryConfig + cc := w.cancelContext for i := len(w.handlers) - 1; i >= 0; i-- { handler := w.handlers[i] @@ -523,7 +534,7 @@ func (w *stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) gene endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { baseOpts := &model.Options{Tools: baseToolInfos} commonOpts := model.GetCommonOptions(baseOpts, opts...) - mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: retryConfig} + mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: retryConfig, cancelContext: cc} wrappedModel, err := handler.WrapModel(ctx, &endpointModel{generate: innerEndpoint}, mc) if err != nil { return nil, err @@ -540,7 +551,7 @@ func (w *stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) gene if execCtx == nil || execCtx.generator == nil { return innerEndpoint(ctx, input, opts...) } - mc := &ModelContext{ModelRetryConfig: retryConfig} + mc := &ModelContext{ModelRetryConfig: retryConfig, cancelContext: cc} wrappedModel, err := eventSender.WrapModel(ctx, &endpointModel{generate: innerEndpoint}, mc) if err != nil { return nil, err @@ -563,6 +574,7 @@ func (w *stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) gene func (w *stateModelWrapper) wrapStreamEndpoint(endpoint streamEndpoint) streamEndpoint { hasUserEventSender := w.hasUserEventSender() retryConfig := w.modelRetryConfig + cc := w.cancelContext for i := len(w.handlers) - 1; i >= 0; i-- { handler := w.handlers[i] @@ -571,7 +583,7 @@ func (w *stateModelWrapper) wrapStreamEndpoint(endpoint streamEndpoint) streamEn endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { baseOpts := &model.Options{Tools: baseToolInfos} commonOpts := model.GetCommonOptions(baseOpts, opts...) - mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: retryConfig} + mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: retryConfig, cancelContext: cc} wrappedModel, err := handler.WrapModel(ctx, &endpointModel{stream: innerEndpoint}, mc) if err != nil { return nil, err @@ -588,7 +600,7 @@ func (w *stateModelWrapper) wrapStreamEndpoint(endpoint streamEndpoint) streamEn if execCtx == nil || execCtx.generator == nil { return innerEndpoint(ctx, input, opts...) } - mc := &ModelContext{ModelRetryConfig: retryConfig} + mc := &ModelContext{ModelRetryConfig: retryConfig, cancelContext: cc} wrappedModel, err := eventSender.WrapModel(ctx, &endpointModel{stream: innerEndpoint}, mc) if err != nil { return nil, err @@ -615,7 +627,7 @@ func (w *stateModelWrapper) Generate(ctx context.Context, input []*schema.Messag return nil }) - state := &ChatModelAgentState{Messages: append(stateMessages, input...)} + state := &ChatModelAgentState{Messages: stateMessages} for _, m := range w.middlewares { if m.BeforeChatModel != nil { @@ -627,7 +639,7 @@ func (w *stateModelWrapper) Generate(ctx context.Context, input []*schema.Messag baseOpts := &model.Options{Tools: w.toolInfos} commonOpts := model.GetCommonOptions(baseOpts, opts...) - mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: w.modelRetryConfig} + mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: w.modelRetryConfig, cancelContext: w.cancelContext} for _, handler := range w.handlers { var err error ctx, state, err = handler.BeforeModelRewriteState(ctx, state, mc) @@ -681,7 +693,7 @@ func (w *stateModelWrapper) Stream(ctx context.Context, input []*schema.Message, return nil }) - state := &ChatModelAgentState{Messages: append(stateMessages, input...)} + state := &ChatModelAgentState{Messages: stateMessages} for _, m := range w.middlewares { if m.BeforeChatModel != nil { @@ -693,7 +705,7 @@ func (w *stateModelWrapper) Stream(ctx context.Context, input []*schema.Message, baseOpts := &model.Options{Tools: w.toolInfos} commonOpts := model.GetCommonOptions(baseOpts, opts...) - mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: w.modelRetryConfig} + mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: w.modelRetryConfig, cancelContext: w.cancelContext} for _, handler := range w.handlers { var err error ctx, state, err = handler.BeforeModelRewriteState(ctx, state, mc) diff --git a/compose/checkpoint_test.go b/compose/checkpoint_test.go index c24b6ce6f..a86c02fb3 100644 --- a/compose/checkpoint_test.go +++ b/compose/checkpoint_test.go @@ -1383,6 +1383,7 @@ func TestCancelInterrupt(t *testing.T) { info, success := ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.AfterNodes) + assert.True(t, info.FromGraphInterrupt) result, err := r.Invoke(ctx, "input", WithCheckPointID("1")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1397,6 +1398,7 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.AfterNodes) + assert.True(t, info.FromGraphInterrupt) result, err = r.Invoke(ctx, "input", WithCheckPointID("2")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1412,6 +1414,7 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.RerunNodes) + assert.True(t, info.FromGraphInterrupt) result, err = r.Invoke(ctx, "input", WithCheckPointID("3")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1441,6 +1444,7 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.AfterNodes) + assert.True(t, info.FromGraphInterrupt) result, err = r.Invoke(ctx, "input", WithCheckPointID("1")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1455,6 +1459,7 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.AfterNodes) + assert.True(t, info.FromGraphInterrupt) result, err = r.Invoke(ctx, "input", WithCheckPointID("2")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1470,6 +1475,7 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.RerunNodes) + assert.True(t, info.FromGraphInterrupt) result, err = r.Invoke(ctx, "input", WithCheckPointID("3")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1510,6 +1516,7 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, 2, len(info.AfterNodes)) + assert.True(t, info.FromGraphInterrupt) result2, err := rr.Invoke(ctx, "input", WithCheckPointID("1")) assert.NoError(t, err) assert.Equal(t, map[string]any{ @@ -1528,6 +1535,7 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, 2, len(info.RerunNodes)) + assert.True(t, info.FromGraphInterrupt) result2, err = rr.Invoke(ctx, "input", WithCheckPointID("2")) assert.NoError(t, err) assert.Equal(t, map[string]any{ @@ -1536,6 +1544,26 @@ func TestCancelInterrupt(t *testing.T) { }, result2) } +func TestBusinessInterruptFromGraphInterruptFalse(t *testing.T) { + g := NewGraph[string, string]() + _ = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return "", Interrupt(ctx, "biz") + })) + _ = g.AddEdge(START, "1") + _ = g.AddEdge("1", END) + + ctx := context.Background() + r, err := g.Compile(ctx, WithCheckPointStore(newInMemoryStore())) + assert.NoError(t, err) + + _, err = r.Invoke(ctx, "input", WithCheckPointID("biz")) + assert.Error(t, err) + info, existed := ExtractInterruptInfo(err) + assert.True(t, existed) + assert.False(t, info.FromGraphInterrupt) + assert.Equal(t, []string{"1"}, info.RerunNodes) +} + func TestPersistRerunInputNonStream(t *testing.T) { store := newInMemoryStore() diff --git a/compose/graph_manager.go b/compose/graph_manager.go index 944a0cf0a..46df3488e 100644 --- a/compose/graph_manager.go +++ b/compose/graph_manager.go @@ -496,12 +496,15 @@ func receiveWithListening(recv func() (*task, bool), cancel chan *time.Duration) return p.ta, p.closed, false, false, nil case timeout, ok := <-cancel: if !ok { - // unreachable - break + // The cancel channel has been closed — this means a previous call to + // receiveWithListening already consumed the cancel signal (task completed + // at the same time as cancel, and select picked the task result). Since + // cancel was already issued, treat this as an immediate cancel rather than + // blocking forever on resultCh. + return nil, false, true, true, nil } canceled = true if timeout == nil { - // canceled without timeout break } timeoutCh = time.After(*timeout) diff --git a/compose/graph_run.go b/compose/graph_run.go index a3e81ecf1..770cf16de 100644 --- a/compose/graph_run.go +++ b/compose/graph_run.go @@ -434,6 +434,7 @@ type interruptTempInfo struct { interruptBeforeNodes []string interruptAfterNodes []string interruptRerunExtra map[string]any + fromGraphInterrupt bool signals []*core.InterruptSignal } @@ -442,6 +443,7 @@ func (ti *interruptTempInfo) collectCanceledInfo(canceled bool, canceledTasks, c if !canceled { return } + ti.fromGraphInterrupt = true if len(canceledTasks) > 0 { for _, t := range canceledTasks { ti.interruptRerunNodes = append(ti.interruptRerunNodes, t.nodeKey) @@ -459,6 +461,13 @@ func (r *runner) resolveInterruptCompletedTasks(tempInfo *interruptTempInfo, com if info := isSubGraphInterrupt(completedTask.err); info != nil { tempInfo.subGraphInterrupts[completedTask.nodeKey] = info tempInfo.signals = append(tempInfo.signals, info.signal) + // Propagate FromGraphInterrupt from the sub-graph to the parent. + // The sub-graph's task manager may have consumed the cancel + // channel value before the parent's, so only the sub-graph + // knows the interrupt was triggered by a graph-level cancel. + if info.Info != nil && info.Info.FromGraphInterrupt { + tempInfo.fromGraphInterrupt = true + } continue } @@ -515,27 +524,27 @@ func (r *runner) handleInterrupt( if r.runCtx != nil { // current graph has enable state if state, ok := ctx.Value(stateKey{}).(*internalState); ok { - cp.State = state.state + state.mu.Lock() + copiedState, err := deepCopyState(state.state) + state.mu.Unlock() + if err != nil { + return fmt.Errorf("failed to copy state: %w", err) + } + cp.State = copiedState } } intInfo := &InterruptInfo{ - State: cp.State, - AfterNodes: tempInfo.interruptAfterNodes, - BeforeNodes: tempInfo.interruptBeforeNodes, - RerunNodes: tempInfo.interruptRerunNodes, - RerunNodesExtra: tempInfo.interruptRerunExtra, - SubGraphs: make(map[string]*InterruptInfo), + State: cp.State, + AfterNodes: tempInfo.interruptAfterNodes, + BeforeNodes: tempInfo.interruptBeforeNodes, + RerunNodes: tempInfo.interruptRerunNodes, + RerunNodesExtra: tempInfo.interruptRerunExtra, + SubGraphs: make(map[string]*InterruptInfo), + FromGraphInterrupt: tempInfo.fromGraphInterrupt, } - var info any - if cp.State != nil { - copiedState, err := deepCopyState(cp.State) - if err != nil { - return fmt.Errorf("failed to copy state: %w", err) - } - info = copiedState - } + info := cp.State is, err := core.Interrupt(ctx, info, nil, tempInfo.signals) if err != nil { @@ -581,15 +590,18 @@ func deepCopyState(state any) (any, error) { // Create new instance of the same type stateType := reflect.TypeOf(state) - if stateType.Kind() == reflect.Ptr { + isPtr := stateType.Kind() == reflect.Ptr + if isPtr { stateType = stateType.Elem() } - newState := reflect.New(stateType).Interface() - - if err := serializer.Unmarshal(data, newState); err != nil { + newStatePtr := reflect.New(stateType).Interface() + if err := serializer.Unmarshal(data, newStatePtr); err != nil { return nil, fmt.Errorf("failed to unmarshal state: %w", err) } - return newState, nil + if isPtr { + return newStatePtr, nil + } + return reflect.ValueOf(newStatePtr).Elem().Interface(), nil } func (r *runner) handleInterruptWithSubGraphAndRerunNodes( @@ -645,27 +657,27 @@ func (r *runner) handleInterruptWithSubGraphAndRerunNodes( if r.runCtx != nil { // current graph has enable state if state, ok := ctx.Value(stateKey{}).(*internalState); ok { - cp.State = state.state + state.mu.Lock() + copiedState, err_ := deepCopyState(state.state) + state.mu.Unlock() + if err_ != nil { + return fmt.Errorf("failed to copy state: %w", err_) + } + cp.State = copiedState } } intInfo := &InterruptInfo{ - State: cp.State, - BeforeNodes: tempInfo.interruptBeforeNodes, - AfterNodes: tempInfo.interruptAfterNodes, - RerunNodes: tempInfo.interruptRerunNodes, - RerunNodesExtra: tempInfo.interruptRerunExtra, - SubGraphs: make(map[string]*InterruptInfo), + State: cp.State, + BeforeNodes: tempInfo.interruptBeforeNodes, + AfterNodes: tempInfo.interruptAfterNodes, + RerunNodes: tempInfo.interruptRerunNodes, + RerunNodesExtra: tempInfo.interruptRerunExtra, + SubGraphs: make(map[string]*InterruptInfo), + FromGraphInterrupt: tempInfo.fromGraphInterrupt, } - var info any - if cp.State != nil { - copiedState, err_ := deepCopyState(cp.State) - if err_ != nil { - return fmt.Errorf("failed to copy state: %w", err_) - } - info = copiedState - } + info := cp.State is, err := core.Interrupt(ctx, info, nil, tempInfo.signals) if err != nil { diff --git a/compose/interrupt.go b/compose/interrupt.go index 98a5eeecc..cd423a1d6 100644 --- a/compose/interrupt.go +++ b/compose/interrupt.go @@ -263,6 +263,10 @@ type InterruptInfo struct { RerunNodesExtra map[string]any SubGraphs map[string]*InterruptInfo InterruptContexts []*InterruptCtx + // FromGraphInterrupt indicates whether the interrupt was triggered by a graph-level + // cancel operation (e.g., via WithGraphInterrupt) rather than business logic. + // When true, the interrupt originated from an external cancellation request. + FromGraphInterrupt bool } func init() { diff --git a/examples b/examples new file mode 160000 index 000000000..4afd5a3f2 --- /dev/null +++ b/examples @@ -0,0 +1 @@ +Subproject commit 4afd5a3f26a4db4833088505b9f7a0f631e9f231 diff --git a/ext b/ext new file mode 160000 index 000000000..f061db7e8 --- /dev/null +++ b/ext @@ -0,0 +1 @@ +Subproject commit f061db7e84191705db6c48f0085938de84f90742 diff --git a/internal/channel.go b/internal/channel.go index 2351c87e9..8e36d8939 100644 --- a/internal/channel.go +++ b/internal/channel.go @@ -46,6 +46,21 @@ func (ch *UnboundedChan[T]) Send(value T) { ch.notEmpty.Signal() // Wake up one goroutine waiting to receive } +// TrySend attempts to put an item into the channel. +// Returns false if the channel is closed, true otherwise. +func (ch *UnboundedChan[T]) TrySend(value T) bool { + ch.mutex.Lock() + defer ch.mutex.Unlock() + + if ch.closed { + return false + } + + ch.buffer = append(ch.buffer, value) + ch.notEmpty.Signal() + return true +} + // Receive gets an item from the channel (blocks if empty) func (ch *UnboundedChan[T]) Receive() (T, bool) { ch.mutex.Lock() @@ -76,3 +91,33 @@ func (ch *UnboundedChan[T]) Close() { ch.notEmpty.Broadcast() // Wake up all waiting goroutines } } + +// TakeAll removes and returns all values from the channel atomically. +// Returns nil if the channel is empty. +func (ch *UnboundedChan[T]) TakeAll() []T { + ch.mutex.Lock() + defer ch.mutex.Unlock() + + if len(ch.buffer) == 0 { + return nil + } + + values := ch.buffer + ch.buffer = nil + return values +} + +// PushFront adds values to the front of the channel. +// This is useful for recovering values that need to be reprocessed. +// Does nothing if values is empty. +func (ch *UnboundedChan[T]) PushFront(values []T) { + if len(values) == 0 { + return + } + + ch.mutex.Lock() + defer ch.mutex.Unlock() + + ch.buffer = append(append([]T{}, values...), ch.buffer...) + ch.notEmpty.Signal() +} diff --git a/internal/channel_test.go b/internal/channel_test.go index 736a27413..bed2383f1 100644 --- a/internal/channel_test.go +++ b/internal/channel_test.go @@ -219,3 +219,244 @@ func TestUnboundedChan_BlockingReceive(t *testing.T) { t.Error("Receive should have unblocked") } } + +func TestUnboundedChan_TakeAll(t *testing.T) { + ch := NewUnboundedChan[int]() + + // Test TakeAll on empty channel + items := ch.TakeAll() + if items != nil { + t.Errorf("TakeAll on empty channel should return nil, got %v", items) + } + + // Send some values + ch.Send(1) + ch.Send(2) + ch.Send(3) + + // Test TakeAll returns all values + items = ch.TakeAll() + if len(items) != 3 { + t.Errorf("expected 3 values, got %d", len(items)) + } + if items[0] != 1 || items[1] != 2 || items[2] != 3 { + t.Errorf("unexpected values: %v", items) + } + + // Channel should be empty now + if len(ch.buffer) != 0 { + t.Errorf("channel should be empty after TakeAll, got %d values", len(ch.buffer)) + } + + // TakeAll again should return nil + items = ch.TakeAll() + if items != nil { + t.Errorf("TakeAll on empty channel should return nil, got %v", items) + } +} + +func TestUnboundedChan_TakeAll_Partial(t *testing.T) { + ch := NewUnboundedChan[int]() + + // Send values + ch.Send(1) + ch.Send(2) + ch.Send(3) + + // Receive one + val, ok := ch.Receive() + if !ok || val != 1 { + t.Errorf("expected (1, true), got (%d, %v)", val, ok) + } + + // TakeAll should return remaining values + items := ch.TakeAll() + if len(items) != 2 { + t.Errorf("expected 2 values, got %d", len(items)) + } + if items[0] != 2 || items[1] != 3 { + t.Errorf("unexpected values: %v", items) + } +} + +func TestUnboundedChan_PushFront(t *testing.T) { + ch := NewUnboundedChan[int]() + + // Test PushFront with empty values (should do nothing) + ch.PushFront(nil) + ch.PushFront([]int{}) + if len(ch.buffer) != 0 { + t.Errorf("PushFront with empty values should not add anything, got %d values", len(ch.buffer)) + } + + // Send some values + ch.Send(3) + ch.Send(4) + + // PushFront should prepend values + ch.PushFront([]int{1, 2}) + + if len(ch.buffer) != 4 { + t.Errorf("expected 4 values, got %d", len(ch.buffer)) + } + if ch.buffer[0] != 1 || ch.buffer[1] != 2 || ch.buffer[2] != 3 || ch.buffer[3] != 4 { + t.Errorf("unexpected buffer: %v", ch.buffer) + } + + // Receive should return in correct order + val, _ := ch.Receive() + if val != 1 { + t.Errorf("expected 1, got %d", val) + } + val, _ = ch.Receive() + if val != 2 { + t.Errorf("expected 2, got %d", val) + } +} + +func TestUnboundedChan_PushFront_EmptyChannel(t *testing.T) { + ch := NewUnboundedChan[int]() + + // PushFront to empty channel + ch.PushFront([]int{1, 2, 3}) + + if len(ch.buffer) != 3 { + t.Errorf("expected 3 values, got %d", len(ch.buffer)) + } + + // Receive should work + val, ok := ch.Receive() + if !ok || val != 1 { + t.Errorf("expected (1, true), got (%d, %v)", val, ok) + } +} + +func TestUnboundedChan_PushFront_UnblocksReceive(t *testing.T) { + ch := NewUnboundedChan[int]() + + // Start a blocking receive + receiveDone := make(chan int) + go func() { + val, _ := ch.Receive() + receiveDone <- val + }() + + // Ensure receive is blocked + select { + case <-receiveDone: + t.Error("Receive should block on empty channel") + case <-time.After(50 * time.Millisecond): + // This is expected + } + + // PushFront should unblock the receive + ch.PushFront([]int{42}) + + select { + case val := <-receiveDone: + if val != 42 { + t.Errorf("expected 42, got %d", val) + } + case <-time.After(50 * time.Millisecond): + t.Error("Receive should have unblocked after PushFront") + } +} + +func TestUnboundedChan_PushFront_SpareCapacity(t *testing.T) { + ch := NewUnboundedChan[int]() + + // Pre-fill the channel so PushFront has something to append + ch.Send(10) + ch.Send(20) + + // Create a slice with spare capacity: len=2, cap=10. + // Elements beyond len (index 2-9) must not be corrupted by PushFront. + src := make([]int, 3, 10) + src[0] = 1 + src[1] = 2 + src[2] = 3 // sentinel — must survive PushFront(src[:2]) + + ch.PushFront(src[:2]) + + // Verify the sentinel was NOT overwritten by the channel's existing buffer + if src[2] != 3 { + t.Errorf("PushFront corrupted caller's backing array: src[2] = %d, want 3", src[2]) + } + + // Verify channel drains correctly: [1, 2, 10, 20] + expected := []int{1, 2, 10, 20} + for i, want := range expected { + got, ok := ch.Receive() + if !ok { + t.Fatalf("Receive returned ok=false at index %d", i) + } + if got != want { + t.Errorf("index %d: got %d, want %d", i, got, want) + } + } +} + +func TestUnboundedChan_TakeAll_PushFront_Concurrent(t *testing.T) { + ch := NewUnboundedChan[int]() + const numOps = 100 + + var wg sync.WaitGroup + wg.Add(3) + + // Goroutine 1: Send values + go func() { + defer wg.Done() + for i := 0; i < numOps; i++ { + ch.Send(i) + time.Sleep(time.Microsecond) + } + }() + + // Goroutine 2: TakeAll periodically + takeAllResults := make([][]int, 0) + var mu sync.Mutex + go func() { + defer wg.Done() + for i := 0; i < numOps/10; i++ { + items := ch.TakeAll() + if items != nil { + mu.Lock() + takeAllResults = append(takeAllResults, items) + mu.Unlock() + } + time.Sleep(10 * time.Microsecond) + } + }() + + // Goroutine 3: PushFront periodically + go func() { + defer wg.Done() + for i := 0; i < numOps/10; i++ { + ch.PushFront([]int{-i}) + time.Sleep(10 * time.Microsecond) + } + }() + + wg.Wait() + ch.Close() + + // Drain remaining values + remaining := ch.TakeAll() + if remaining != nil { + mu.Lock() + takeAllResults = append(takeAllResults, remaining) + mu.Unlock() + } + + // Count total values collected + total := 0 + for _, batch := range takeAllResults { + total += len(batch) + } + + // We should have exactly numOps (from Send) + numOps/10 (from PushFront) values + expected := numOps + numOps/10 + if total != expected { + t.Errorf("expected %d values, got %d", expected, total) + } +} diff --git a/internal/core/address.go b/internal/core/address.go index 8efabf943..c1ecb17fb 100644 --- a/internal/core/address.go +++ b/internal/core/address.go @@ -88,7 +88,7 @@ type addrCtx struct { type globalResumeInfoKey struct{} type globalResumeInfo struct { - mu sync.Mutex + mu sync.RWMutex id2ResumeData map[string]any id2ResumeDataUsed map[string]bool id2State map[string]InterruptState @@ -147,24 +147,21 @@ func AppendAddressSegment(ctx context.Context, segType AddressSegmentType, segID return context.WithValue(ctx, addrCtxKey{}, runCtx) } + rInfo.mu.Lock() + defer rInfo.mu.Unlock() + var id string for id_, addr := range rInfo.id2Addr { if addr.Equals(currentAddress) { - rInfo.mu.Lock() if used, ok := rInfo.id2StateUsed[id_]; !ok || !used { runCtx.interruptState = generic.PtrOf(rInfo.id2State[id_]) rInfo.id2StateUsed[id_] = true id = id_ - rInfo.mu.Unlock() break } - rInfo.mu.Unlock() } } - // take from globalResumeInfo the data for the new address if there is any - rInfo.mu.Lock() - defer rInfo.mu.Unlock() used := rInfo.id2ResumeDataUsed[id] if !used { rData, existed := rInfo.id2ResumeData[id] @@ -175,10 +172,6 @@ func AppendAddressSegment(ctx context.Context, segType AddressSegmentType, segID } } - // Also mark as resume target if any descendant address is a resume target. - // This allows composite components (e.g., a tool containing a nested graph) to know - // they should execute their children to reach the actual resume target. - // We only consider descendants whose resume data has not yet been consumed. if !runCtx.isResumeTarget { for id_, addr := range rInfo.id2Addr { if len(addr) > len(currentAddress) && addr[:len(currentAddress)].Equals(currentAddress) { @@ -202,6 +195,9 @@ func GetNextResumptionPoints(ctx context.Context) (map[string]bool, error) { return nil, fmt.Errorf("GetNextResumptionPoints: failed to get resume info from context") } + rInfo.mu.RLock() + defer rInfo.mu.RUnlock() + nextPoints := make(map[string]bool) parentAddrLen := len(parentAddr) @@ -276,6 +272,9 @@ func PopulateInterruptState(ctx context.Context, id2Addr map[string]Address, id2State map[string]InterruptState) context.Context { rInfo, ok := ctx.Value(globalResumeInfoKey{}).(*globalResumeInfo) if ok { + rInfo.mu.Lock() + defer rInfo.mu.Unlock() + if rInfo.id2Addr == nil { rInfo.id2Addr = make(map[string]Address) } @@ -299,17 +298,13 @@ func PopulateInterruptState(ctx context.Context, id2Addr map[string]Address, if addr.Equals(runCtx.addr) { if used, ok := rInfo.id2StateUsed[id_]; !ok || !used { runCtx.interruptState = generic.PtrOf(rInfo.id2State[id_]) - rInfo.mu.Lock() rInfo.id2StateUsed[id_] = true - rInfo.mu.Unlock() } if used, ok := rInfo.id2ResumeDataUsed[id_]; !ok || !used { runCtx.isResumeTarget = true runCtx.resumeData = rInfo.id2ResumeData[id_] - rInfo.mu.Lock() rInfo.id2ResumeDataUsed[id_] = true - rInfo.mu.Unlock() } break diff --git a/internal/core/interrupt.go b/internal/core/interrupt.go index d7a934a3d..174e0c47c 100644 --- a/internal/core/interrupt.go +++ b/internal/core/interrupt.go @@ -29,6 +29,13 @@ type CheckPointStore interface { Set(ctx context.Context, checkPointID string, checkPoint []byte) error } +// CheckPointDeleter is an optional interface that CheckPointStore implementations +// can implement to support explicit checkpoint deletion. If the Store does not +// implement this interface, deletion is performed by writing an empty value via Set. +type CheckPointDeleter interface { + Delete(ctx context.Context, checkPointID string) error +} + type InterruptSignal struct { ID string Address diff --git a/schema/serialization.go b/schema/serialization.go index 7a719b0a8..22fa16ade 100644 --- a/schema/serialization.go +++ b/schema/serialization.go @@ -25,7 +25,7 @@ import ( ) func init() { - RegisterName[Message]("_eino_message") + RegisterName[*Message]("_eino_message") RegisterName[[]*Message]("_eino_message_slice") RegisterName[Document]("_eino_document") RegisterName[RoleType]("_eino_role_type") From 66aedb3d667129e0c32d67e5966d8aa907840047 Mon Sep 17 00:00:00 2001 From: shentongmartin Date: Fri, 27 Mar 2026 11:46:36 +0800 Subject: [PATCH 46/59] fix(adk): skip saving checkpoint when TurnLoop is idle (#916) * fix(adk): skip saving checkpoint when TurnLoop is idle When Stop() is called on an idle TurnLoop (no active agent run, no unhandled items, no canceled items), the resulting checkpoint contains no meaningful state. Skip saving such checkpoints to avoid unnecessary store writes. - Add isIdle check in cleanup() before checkpoint save decision - Add TestTurnLoop_StopWhileIdle_SkipsCheckpoint test Change-Id: I6aeaff5ed5833a971cb95298193fdb96d904baf8 * fix(internal): merge id2State in PopulateInterruptState instead of replacing PopulateInterruptState merged id2Addr entries one by one but replaced id2State wholesale. In a parallel workflow resume, two goroutines share the same globalResumeInfo. If one goroutine's compose graph called PopulateInterruptState (replacing id2State with compose-only entries) before the other goroutine looked up its outer-level entry, the lookup returned a zero-value InterruptState with State=nil, triggering the 'has no state' panic in ChatModelAgent.Resume. Change id2State handling to merge entry by entry, consistent with id2Addr. Change-Id: Ia21f65289bff7beb2bc383fb033926ad9c92d7e7 * fix(adk): keep watching for cancel escalation after stopSig.done When watchStopSignal entered the stopSig.done branch, it processed the initial cancel and then blocked on <-done (turn completion), never looping back to check notify. This meant a subsequent Stop() call with a higher cancel mode (e.g. CancelImmediate) was never forwarded to the agent, causing TestTurnLoop_Stop_EscalatesCancelMode to time out. Replace the blocking <-done with an inner loop that selects on both done and notify, so escalation signals are always delivered. Also apply the generation-based dedup check consistent with the notify branch. Change-Id: Ia6a04d00a2b44625ffbcb625ff0e559c12ed145f --- adk/turn_loop.go | 34 ++++++++++++++++++++++++++-------- adk/turn_loop_test.go | 28 ++++++++++++++++++++++++++++ internal/core/address.go | 7 ++++++- 3 files changed, 60 insertions(+), 9 deletions(-) diff --git a/adk/turn_loop.go b/adk/turn_loop.go index 88979876e..000842589 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -1228,14 +1228,31 @@ func (l *TurnLoop[T]) watchStopSignal(done <-chan struct{}, agentCancelFunc Agen } } case <-l.stopSig.done: - _, opts := l.stopSig.check() - _, contributed := agentCancelFunc(opts...) - if contributed && !stoppedClosed { - close(stoppedDone) - stoppedClosed = true + gen, opts := l.stopSig.check() + if gen != lastGen { + lastGen = gen + _, contributed := agentCancelFunc(opts...) + if contributed && !stoppedClosed { + close(stoppedDone) + stoppedClosed = true + } + } + for { + select { + case <-done: + return + case <-l.stopSig.notify: + gen, opts := l.stopSig.check() + if gen != lastGen { + lastGen = gen + _, contributed := agentCancelFunc(opts...) + if contributed && !stoppedClosed { + close(stoppedDone) + stoppedClosed = true + } + } + } } - <-done - return } } } @@ -1380,7 +1397,8 @@ func (l *TurnLoop[T]) cleanup(ctx context.Context) { unhandled := l.buffer.TakeAll() checkpointID := l.config.CheckpointID - shouldSaveCheckpoint := l.config.Store != nil && checkpointID != "" && l.stopSig.isStopped() + isIdle := len(l.checkPointRunnerBytes) == 0 && len(unhandled) == 0 && len(l.canceledItems) == 0 + shouldSaveCheckpoint := l.config.Store != nil && checkpointID != "" && l.stopSig.isStopped() && !isIdle if shouldSaveCheckpoint { cp := &turnLoopCheckpoint[T]{ RunnerCheckpoint: l.checkPointRunnerBytes, diff --git a/adk/turn_loop_test.go b/adk/turn_loop_test.go index 6e3159cfc..72753fc5e 100644 --- a/adk/turn_loop_test.go +++ b/adk/turn_loop_test.go @@ -1484,6 +1484,34 @@ func TestTurnLoop_StopWithoutCheckpointIDDoesNotPersist(t *testing.T) { assert.Empty(t, store.m, "no checkpoint should be saved when CheckpointID is not configured") } +func TestTurnLoop_StopWhileIdle_SkipsCheckpoint(t *testing.T) { + ctx := context.Background() + store := &deletableCheckpointStore{ + turnLoopCheckpointStore: turnLoopCheckpointStore{m: make(map[string][]byte)}, + } + cpID := "idle-session" + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Stop() + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + + store.mu.Lock() + defer store.mu.Unlock() + _, exists := store.m[cpID] + assert.False(t, exists, "no checkpoint should be saved when TurnLoop is idle") +} + func TestTurnLoop_StopBetweenTurnsAndResume(t *testing.T) { ctx := context.Background() store := &turnLoopCheckpointStore{m: make(map[string][]byte)} diff --git a/internal/core/address.go b/internal/core/address.go index c1ecb17fb..bb2400a92 100644 --- a/internal/core/address.go +++ b/internal/core/address.go @@ -281,7 +281,12 @@ func PopulateInterruptState(ctx context.Context, id2Addr map[string]Address, for id, addr := range id2Addr { rInfo.id2Addr[id] = addr } - rInfo.id2State = id2State + if rInfo.id2State == nil { + rInfo.id2State = make(map[string]InterruptState) + } + for id, state := range id2State { + rInfo.id2State[id] = state + } } else { rInfo = &globalResumeInfo{ id2Addr: id2Addr, From 0842de7a61fb948f431f368a19f9f916d9dfe1db Mon Sep 17 00:00:00 2001 From: shentongmartin Date: Wed, 1 Apr 2026 11:41:15 +0800 Subject: [PATCH 47/59] feat(adk): export NewEventSenderToolWrapper for customizable tool event position (#926) --- adk/chatmodel.go | 39 ++- adk/wrappers.go | 88 ++++--- adk/wrappers_test.go | 597 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 680 insertions(+), 44 deletions(-) diff --git a/adk/chatmodel.go b/adk/chatmodel.go index a0126455a..d7382b309 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -283,13 +283,35 @@ type ChatModelAgentConfig struct { // the default event sender to avoid duplicate events. // // Tool call lifecycle (outermost to innermost): - // 1. eventSenderToolHandler (internal ToolMiddleware - sends tool result events after all processing) + // 1. eventSenderToolWrapper (internal ToolMiddleware - sends tool result events after all processing) // 2. ToolsConfig.ToolCallMiddlewares (ToolMiddleware) // 3. AgentMiddleware.WrapToolCall (ToolMiddleware) // 4. ChatModelAgentMiddleware.WrapToolCall (wrapper, first registered is outermost) // 5. callbackInjectedToolCall (internal - injects callbacks if tool doesn't handle them) // 6. Tool.InvokableRun/StreamableRun // + // Custom Tool Event Sender Position: + // By default, tool result events are emitted by an internal event sender placed before + // all user middlewares (outermost), so events reflect the fully processed tool output. + // To control exactly where in the handler chain tool events are emitted, pass + // NewEventSenderToolWrapper() as one of the Handlers. Its position determines which + // middlewares' effects are visible in the emitted event: + // + // agent, _ := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ + // Handlers: []adk.ChatModelAgentMiddleware{ + // loggingHandler, // Outermost: sees event-sender output + // adk.NewEventSenderToolWrapper(), // Events reflect output from handlers below + // sanitizationHandler, // Innermost: runs first, modifies tool output + // }, + // }) + // + // Handler order: first registered is outermost. So [A, B, C] wraps as A(B(C(tool))). + // The event sender captures tool output in post-processing, so its position controls + // which handlers' modifications are included in the emitted events. + // + // When NewEventSenderToolWrapper is detected in Handlers, the framework skips + // the default event sender to avoid duplicate events. + // // Tool List Modification: // // There are two ways to modify the tool list: @@ -371,20 +393,15 @@ func NewChatModelAgent(ctx context.Context, config *ChatModelAgentConfig) (*Chat tc := config.ToolsConfig // Tool call middleware execution order (outermost to innermost): - // 1. eventSenderToolHandler (internal - sends tool result events after all modifications) + // 1. eventSenderToolWrapper (internal - sends tool result events after all modifications) // 2. User-provided ToolsConfig.ToolCallMiddlewares (original order preserved) // 3. Middlewares' WrapToolCall (in registration order) // 4. ChatModelAgentMiddleware.WrapToolCall (in registration order) // 5. callbackInjectedToolCall (internal - injects callbacks if tool doesn't handle them) - eventSender := &eventSenderToolHandler{} - tc.ToolCallMiddlewares = append( - []compose.ToolMiddleware{{Invokable: eventSender.WrapInvokableToolCall, - Streamable: eventSender.WrapStreamableToolCall, - EnhancedInvokable: eventSender.WrapEnhancedInvokableToolCall, - EnhancedStreamable: eventSender.WrapEnhancedStreamableToolCall, - }}, - tc.ToolCallMiddlewares..., - ) + if !hasUserEventSenderToolWrapper(config.Handlers) { + defaultToolEventSender := handlersToToolMiddlewares([]ChatModelAgentMiddleware{NewEventSenderToolWrapper()}) + tc.ToolCallMiddlewares = append(defaultToolEventSender, tc.ToolCallMiddlewares...) + } tc.ToolCallMiddlewares = append(tc.ToolCallMiddlewares, collectToolMiddlewaresFromMiddlewares(config.Middlewares)...) // Cancel monitoring middleware (innermost — close to the tool endpoint). diff --git a/adk/wrappers.go b/adk/wrappers.go index eb23549fd..e6ac617ea 100644 --- a/adk/wrappers.go +++ b/adk/wrappers.go @@ -350,20 +350,33 @@ func popToolGenAction(ctx context.Context, toolName string) *AgentAction { return action } -type eventSenderToolHandler struct{} +type eventSenderToolWrapper struct { + *BaseChatModelAgentMiddleware +} + +// NewEventSenderToolWrapper returns a ChatModelAgentMiddleware that sends tool result events. +// By default, the framework places this before all user middlewares (outermost), so events +// reflect the fully processed tool output. To control exactly where events are emitted, +// include this in ChatModelAgentConfig.Handlers at the desired position. +// When detected in Handlers, the framework skips the default event sender to avoid duplicates. +func NewEventSenderToolWrapper() ChatModelAgentMiddleware { + return &eventSenderToolWrapper{ + BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, + } +} -func (h *eventSenderToolHandler) WrapInvokableToolCall(next compose.InvokableToolEndpoint) compose.InvokableToolEndpoint { - return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { - output, err := next(ctx, input) +func (w *eventSenderToolWrapper) WrapInvokableToolCall(_ context.Context, endpoint InvokableToolCallEndpoint, tCtx *ToolContext) (InvokableToolCallEndpoint, error) { + return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { + result, err := endpoint(ctx, argumentsInJSON, opts...) if err != nil { - return nil, err + return "", err } - toolName := input.Name - callID := input.CallID + toolName := tCtx.Name + callID := tCtx.CallID prePopAction := popToolGenAction(ctx, toolName) - msg := schema.ToolMessage(output.Result, callID, schema.WithToolName(toolName)) + msg := schema.ToolMessage(result, callID, schema.WithToolName(toolName)) event := EventFromMessage(msg, nil, schema.Tool, toolName) if prePopAction != nil { event.Action = prePopAction @@ -379,22 +392,22 @@ func (h *eventSenderToolHandler) WrapInvokableToolCall(next compose.InvokableToo return nil }) - return output, nil - } + return result, nil + }, nil } -func (h *eventSenderToolHandler) WrapStreamableToolCall(next compose.StreamableToolEndpoint) compose.StreamableToolEndpoint { - return func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { - output, err := next(ctx, input) +func (w *eventSenderToolWrapper) WrapStreamableToolCall(_ context.Context, endpoint StreamableToolCallEndpoint, tCtx *ToolContext) (StreamableToolCallEndpoint, error) { + return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) { + result, err := endpoint(ctx, argumentsInJSON, opts...) if err != nil { return nil, err } - toolName := input.Name - callID := input.CallID + toolName := tCtx.Name + callID := tCtx.CallID prePopAction := popToolGenAction(ctx, toolName) - streams := output.Result.Copy(2) + streams := result.Copy(2) cvt := func(in string) (Message, error) { return schema.ToolMessage(in, callID, schema.WithToolName(toolName)), nil @@ -413,23 +426,23 @@ func (h *eventSenderToolHandler) WrapStreamableToolCall(next compose.StreamableT return nil }) - return &compose.StreamToolOutput{Result: streams[1]}, nil - } + return streams[1], nil + }, nil } -func (h *eventSenderToolHandler) WrapEnhancedInvokableToolCall(next compose.EnhancedInvokableToolEndpoint) compose.EnhancedInvokableToolEndpoint { - return func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedInvokableToolOutput, error) { - output, err := next(ctx, input) +func (w *eventSenderToolWrapper) WrapEnhancedInvokableToolCall(_ context.Context, endpoint EnhancedInvokableToolCallEndpoint, tCtx *ToolContext) (EnhancedInvokableToolCallEndpoint, error) { + return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { + result, err := endpoint(ctx, toolArgument, opts...) if err != nil { return nil, err } - toolName := input.Name - callID := input.CallID + toolName := tCtx.Name + callID := tCtx.CallID prePopAction := popToolGenAction(ctx, toolName) msg := schema.ToolMessage("", callID, schema.WithToolName(toolName)) - msg.UserInputMultiContent, err = output.Result.ToMessageInputParts() + msg.UserInputMultiContent, err = result.ToMessageInputParts() if err != nil { return nil, err } @@ -448,22 +461,22 @@ func (h *eventSenderToolHandler) WrapEnhancedInvokableToolCall(next compose.Enha return nil }) - return output, nil - } + return result, nil + }, nil } -func (h *eventSenderToolHandler) WrapEnhancedStreamableToolCall(next compose.EnhancedStreamableToolEndpoint) compose.EnhancedStreamableToolEndpoint { - return func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { - output, err := next(ctx, input) +func (w *eventSenderToolWrapper) WrapEnhancedStreamableToolCall(_ context.Context, endpoint EnhancedStreamableToolCallEndpoint, tCtx *ToolContext) (EnhancedStreamableToolCallEndpoint, error) { + return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { + result, err := endpoint(ctx, toolArgument, opts...) if err != nil { return nil, err } - toolName := input.Name - callID := input.CallID + toolName := tCtx.Name + callID := tCtx.CallID prePopAction := popToolGenAction(ctx, toolName) - streams := output.Result.Copy(2) + streams := result.Copy(2) cvt := func(in *schema.ToolResult) (Message, error) { msg := schema.ToolMessage("", callID, schema.WithToolName(toolName)) @@ -488,8 +501,17 @@ func (h *eventSenderToolHandler) WrapEnhancedStreamableToolCall(next compose.Enh return nil }) - return &compose.EnhancedStreamableToolOutput{Result: streams[1]}, nil + return streams[1], nil + }, nil +} + +func hasUserEventSenderToolWrapper(handlers []ChatModelAgentMiddleware) bool { + for _, handler := range handlers { + if _, ok := handler.(*eventSenderToolWrapper); ok { + return true + } } + return false } type stateModelWrapper struct { diff --git a/adk/wrappers_test.go b/adk/wrappers_test.go index 5fd8acef5..acb6588be 100644 --- a/adk/wrappers_test.go +++ b/adk/wrappers_test.go @@ -1085,3 +1085,600 @@ func (m *contentModifyingModelWrapper) Stream(ctx context.Context, input []*sche result.Content = m.newContent return schema.StreamReaderFromArray([]*schema.Message{result}), nil } + +type mockToolCallingModel struct { + mu sync.Mutex + generateCalls int + toolCallName string +} + +func (m *mockToolCallingModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + m.mu.Lock() + m.generateCalls++ + calls := m.generateCalls + m.mu.Unlock() + if calls == 1 { + return schema.AssistantMessage("calling tool", []schema.ToolCall{ + {ID: "tc-1", Function: schema.FunctionCall{Name: m.toolCallName, Arguments: `{"input":"test"}`}}, + }), nil + } + return schema.AssistantMessage("done", nil), nil +} + +func (m *mockToolCallingModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + msg, err := m.Generate(ctx, input, opts...) + if err != nil { + return nil, err + } + return schema.StreamReaderFromArray([]*schema.Message{msg}), nil +} + +func (m *mockToolCallingModel) WithTools(_ []*schema.ToolInfo) (model.ToolCallingChatModel, error) { + return m, nil +} + +type invokableTestTool struct { + name string + result string +} + +func (t *invokableTestTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "test tool", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Desc: "input", Required: true, Type: schema.String}, + }), + }, nil +} + +func (t *invokableTestTool) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) { + return t.result, nil +} + +type streamableTestTool struct { + name string + result string +} + +func (t *streamableTestTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "test tool", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Desc: "input", Required: true, Type: schema.String}, + }), + }, nil +} + +func (t *streamableTestTool) StreamableRun(_ context.Context, _ string, _ ...tool.Option) (*schema.StreamReader[string], error) { + return schema.StreamReaderFromArray([]string{t.result}), nil +} + +type enhancedInvokableTestTool struct { + name string + result string +} + +func (t *enhancedInvokableTestTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "test tool", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Desc: "input", Required: true, Type: schema.String}, + }), + }, nil +} + +func (t *enhancedInvokableTestTool) InvokableRun(_ context.Context, _ *schema.ToolArgument, _ ...tool.Option) (*schema.ToolResult, error) { + return &schema.ToolResult{ + Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: t.result}}, + }, nil +} + +type enhancedStreamableTestTool struct { + name string + result string +} + +func (t *enhancedStreamableTestTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "test tool", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Desc: "input", Required: true, Type: schema.String}, + }), + }, nil +} + +func (t *enhancedStreamableTestTool) StreamableRun(_ context.Context, _ *schema.ToolArgument, _ ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { + return schema.StreamReaderFromArray([]*schema.ToolResult{ + {Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: t.result}}}, + }), nil +} + +type invokableResultModifier struct { + *BaseChatModelAgentMiddleware + modifiedResult string +} + +func (h *invokableResultModifier) WrapInvokableToolCall(_ context.Context, endpoint InvokableToolCallEndpoint, _ *ToolContext) (InvokableToolCallEndpoint, error) { + return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { + _, err := endpoint(ctx, argumentsInJSON, opts...) + if err != nil { + return "", err + } + return h.modifiedResult, nil + }, nil +} + +type streamableResultModifier struct { + *BaseChatModelAgentMiddleware + modifiedResult string +} + +func (h *streamableResultModifier) WrapStreamableToolCall(_ context.Context, endpoint StreamableToolCallEndpoint, _ *ToolContext) (StreamableToolCallEndpoint, error) { + return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) { + sr, err := endpoint(ctx, argumentsInJSON, opts...) + if err != nil { + return nil, err + } + sr.Close() + return schema.StreamReaderFromArray([]string{h.modifiedResult}), nil + }, nil +} + +type enhancedInvokableResultModifier struct { + *BaseChatModelAgentMiddleware + modifiedResult string +} + +func (h *enhancedInvokableResultModifier) WrapEnhancedInvokableToolCall(_ context.Context, endpoint EnhancedInvokableToolCallEndpoint, _ *ToolContext) (EnhancedInvokableToolCallEndpoint, error) { + return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { + _, err := endpoint(ctx, toolArgument, opts...) + if err != nil { + return nil, err + } + return &schema.ToolResult{ + Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: h.modifiedResult}}, + }, nil + }, nil +} + +type enhancedStreamableResultModifier struct { + *BaseChatModelAgentMiddleware + modifiedResult string +} + +func (h *enhancedStreamableResultModifier) WrapEnhancedStreamableToolCall(_ context.Context, endpoint EnhancedStreamableToolCallEndpoint, _ *ToolContext) (EnhancedStreamableToolCallEndpoint, error) { + return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { + sr, err := endpoint(ctx, toolArgument, opts...) + if err != nil { + return nil, err + } + sr.Close() + return schema.StreamReaderFromArray([]*schema.ToolResult{ + {Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: h.modifiedResult}}}, + }), nil + }, nil +} + +func collectToolEvents(it *AsyncIterator[*AgentEvent]) []*AgentEvent { + var toolEvents []*AgentEvent + for { + ev, ok := it.Next() + if !ok { + break + } + if ev.Output == nil || ev.Output.MessageOutput == nil { + continue + } + mo := ev.Output.MessageOutput + if mo.Message != nil && mo.Message.Role == schema.Tool { + toolEvents = append(toolEvents, ev) + continue + } + if mo.IsStreaming && mo.Role == schema.Tool && mo.MessageStream != nil { + toolEvents = append(toolEvents, ev) + } + } + return toolEvents +} + +func collectToolContent(events []*AgentEvent) []string { + var contents []string + for _, ev := range events { + mo := ev.Output.MessageOutput + if !mo.IsStreaming && mo.Message != nil { + if mo.Message.Content != "" { + contents = append(contents, mo.Message.Content) + } else if len(mo.Message.UserInputMultiContent) > 0 { + for _, part := range mo.Message.UserInputMultiContent { + if part.Text != "" { + contents = append(contents, part.Text) + } + } + } + continue + } + if mo.IsStreaming && mo.MessageStream != nil { + var msgs []*schema.Message + for { + msg, err := mo.MessageStream.Recv() + if err != nil { + break + } + msgs = append(msgs, msg) + } + if len(msgs) > 0 { + concated, err := schema.ConcatMessages(msgs) + if err == nil { + if concated.Content != "" { + contents = append(contents, concated.Content) + } else if len(concated.UserInputMultiContent) > 0 { + for _, part := range concated.UserInputMultiContent { + if part.Text != "" { + contents = append(contents, part.Text) + } + } + } + } + } + } + } + return contents +} + +func TestEventSenderToolHandler(t *testing.T) { + t.Run("Invokable", func(t *testing.T) { + t.Run("DefaultSendsEvent", func(t *testing.T) { + ctx := context.Background() + testTool := &invokableTestTool{name: "test_tool", result: "invokable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: false}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, "invokable_output") + }) + + t.Run("UserConfiguredSkipsDefault", func(t *testing.T) { + ctx := context.Background() + testTool := &invokableTestTool{name: "test_tool", result: "invokable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{NewEventSenderToolWrapper()}, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: false}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + }) + + t.Run("InnermostGetsOriginalOutput", func(t *testing.T) { + ctx := context.Background() + originalResult := "original_invokable_output" + modifiedResult := "modified_invokable_output" + testTool := &invokableTestTool{name: "test_tool", result: originalResult} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{ + NewEventSenderToolWrapper(), + &invokableResultModifier{ + BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, + modifiedResult: modifiedResult, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: false}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.GreaterOrEqual(t, len(toolEvents), 1) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, originalResult) + }) + }) + + t.Run("Streamable", func(t *testing.T) { + t.Run("DefaultSendsEvent", func(t *testing.T) { + ctx := context.Background() + testTool := &streamableTestTool{name: "test_tool", result: "streamable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: true}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, "streamable_output") + }) + + t.Run("UserConfiguredSkipsDefault", func(t *testing.T) { + ctx := context.Background() + testTool := &streamableTestTool{name: "test_tool", result: "streamable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{NewEventSenderToolWrapper()}, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: true}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + }) + + t.Run("InnermostGetsOriginalOutput", func(t *testing.T) { + ctx := context.Background() + originalResult := "original_streamable_output" + modifiedResult := "modified_streamable_output" + testTool := &streamableTestTool{name: "test_tool", result: originalResult} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{ + NewEventSenderToolWrapper(), + &streamableResultModifier{ + BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, + modifiedResult: modifiedResult, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: true}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.GreaterOrEqual(t, len(toolEvents), 1) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, originalResult) + }) + }) + + t.Run("EnhancedInvokable", func(t *testing.T) { + t.Run("DefaultSendsEvent", func(t *testing.T) { + ctx := context.Background() + testTool := &enhancedInvokableTestTool{name: "test_tool", result: "enhanced_invokable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: false}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, "enhanced_invokable_output") + }) + + t.Run("UserConfiguredSkipsDefault", func(t *testing.T) { + ctx := context.Background() + testTool := &enhancedInvokableTestTool{name: "test_tool", result: "enhanced_invokable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{NewEventSenderToolWrapper()}, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: false}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + }) + + t.Run("InnermostGetsOriginalOutput", func(t *testing.T) { + ctx := context.Background() + originalResult := "original_enhanced_invokable_output" + modifiedResult := "modified_enhanced_invokable_output" + testTool := &enhancedInvokableTestTool{name: "test_tool", result: originalResult} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{ + NewEventSenderToolWrapper(), + &enhancedInvokableResultModifier{ + BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, + modifiedResult: modifiedResult, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: false}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.GreaterOrEqual(t, len(toolEvents), 1) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, originalResult) + }) + }) + + t.Run("EnhancedStreamable", func(t *testing.T) { + t.Run("DefaultSendsEvent", func(t *testing.T) { + ctx := context.Background() + testTool := &enhancedStreamableTestTool{name: "test_tool", result: "enhanced_streamable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: true}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, "enhanced_streamable_output") + }) + + t.Run("UserConfiguredSkipsDefault", func(t *testing.T) { + ctx := context.Background() + testTool := &enhancedStreamableTestTool{name: "test_tool", result: "enhanced_streamable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{NewEventSenderToolWrapper()}, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: true}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + }) + + t.Run("InnermostGetsOriginalOutput", func(t *testing.T) { + ctx := context.Background() + originalResult := "original_enhanced_streamable_output" + modifiedResult := "modified_enhanced_streamable_output" + testTool := &enhancedStreamableTestTool{name: "test_tool", result: originalResult} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{ + NewEventSenderToolWrapper(), + &enhancedStreamableResultModifier{ + BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, + modifiedResult: modifiedResult, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: true}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.GreaterOrEqual(t, len(toolEvents), 1) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, originalResult) + }) + }) +} From d807af327b7e546dc5fbb608323504039d99f6a4 Mon Sep 17 00:00:00 2001 From: shentongmartin Date: Thu, 2 Apr 2026 20:17:29 +0800 Subject: [PATCH 48/59] fix(adk): prevent panic when orphaned tool goroutine sends event after agent cancellation (#929) * fix(adk): prevent panic when orphaned tool goroutine sends event after agent cancellation When CancelAfterChatModel times out and escalates to CancelImmediate, GraphInterrupt fires with timeout=0. The compose graph returns immediately, orphaning parallel tool goroutines. When an orphaned tool completes, eventSenderToolWrapper tries to send an event via the AsyncGenerator which is already closed, causing 'send on closed channel' panic. - Add isImmediateCancelled() to cancelContext for checking immediateChan - Make chatModelAgentExecCtx.send cancel-aware: skip send when immediate cancel is active - Use trySend as safety net for the TOCTOU race window - Route SendEvent() through execCtx.send() instead of direct generator.Send() Change-Id: Ic7e0194c860e2692a3cddc559911ab379024f650 * test(adk): add test for orphaned tool goroutine panic after CancelImmediate - unit_send_after_close: directly reproduces the panic by sending to a closed generator with isImmediateCancelled=true - unit_send_after_close_without_cancel_ctx: verifies trySend safety net prevents panic even without cancelCtx - integration_cancel_escalation_orphans_tool: end-to-end test with slow tool, CancelAfterChatModel timeout escalation, and orphaned goroutine Change-Id: Ia82fa957b102ccc2ac42094d18d4b15db2a1701c * test(adk): improve coverage for orphaned tool goroutine fix Add test cases for: - nil execCtx and nil generator defensive guards - nil cancelContext in isImmediateCancelled - TOCTOU race window (isImmediateCancelled=false but generator closed) - SendEvent public API with closed generator - SendEvent without exec context Change-Id: I197c36f34675f5376cbe5f830b15db6ca873cd1f --- adk/cancel.go | 16 ++++ adk/cancel_test.go | 186 +++++++++++++++++++++++++++++++++++++++++++++ adk/chatmodel.go | 11 ++- adk/handler.go | 2 +- adk/utils.go | 4 + 5 files changed, 216 insertions(+), 3 deletions(-) diff --git a/adk/cancel.go b/adk/cancel.go index 20d72bb20..f119d79f5 100644 --- a/adk/cancel.go +++ b/adk/cancel.go @@ -403,6 +403,22 @@ func (cc *cancelContext) shouldCancel() bool { } } +// isImmediateCancelled returns true if an immediate graph interrupt has been +// fired (CancelImmediate or timeout escalation). This is stronger than +// shouldCancel: it means the compose graph is being torn down right now and +// orphaned goroutines should not attempt to send events. +func (cc *cancelContext) isImmediateCancelled() bool { + if cc == nil { + return false + } + select { + case <-cc.immediateChan: + return true + default: + return false + } +} + // sendImmediateInterrupt sends the compose graph interrupt signal via graphInterruptFuncs. // Also closes immediateChan (used by cancelMonitoredModel to abort an in-progress stream). // Returns false if an interrupt was already sent or if no graphInterruptFuncs have been diff --git a/adk/cancel_test.go b/adk/cancel_test.go index 0d88db8cc..105c9ea13 100644 --- a/adk/cancel_test.go +++ b/adk/cancel_test.go @@ -2305,6 +2305,192 @@ func TestCancel_SequentialWorkflow_CancelAfterChatModel(t *testing.T) { assert.True(t, len(resumeEvents) > 0, "Resume should produce events") } +func TestCancelImmediate_OrphanedToolGoroutine_NoPanic(t *testing.T) { + t.Run("unit_send_after_close", func(t *testing.T) { + _, gen := NewAsyncIteratorPair[*AgentEvent]() + + cc := newCancelContext() + cc.setMode(CancelImmediate) + close(cc.cancelChan) + close(cc.immediateChan) + + gen.Close() + + execCtx := &chatModelAgentExecCtx{ + generator: gen, + cancelCtx: cc, + } + + assert.NotPanics(t, func() { + execCtx.send(&AgentEvent{AgentName: "test"}) + }, "send after generator.Close must not panic") + }) + + t.Run("unit_send_after_close_without_cancel_ctx", func(t *testing.T) { + _, gen := NewAsyncIteratorPair[*AgentEvent]() + gen.Close() + + execCtx := &chatModelAgentExecCtx{ + generator: gen, + } + + assert.NotPanics(t, func() { + execCtx.send(&AgentEvent{AgentName: "test"}) + }, "send after generator.Close must not panic even without cancelCtx (trySend safety net)") + }) + + t.Run("unit_send_nil_execCtx", func(t *testing.T) { + var execCtx *chatModelAgentExecCtx + assert.NotPanics(t, func() { + execCtx.send(&AgentEvent{AgentName: "test"}) + }, "send on nil execCtx must not panic") + }) + + t.Run("unit_send_nil_generator", func(t *testing.T) { + execCtx := &chatModelAgentExecCtx{} + assert.NotPanics(t, func() { + execCtx.send(&AgentEvent{AgentName: "test"}) + }, "send with nil generator must not panic") + }) + + t.Run("unit_isImmediateCancelled_nil_cancelContext", func(t *testing.T) { + var cc *cancelContext + assert.False(t, cc.isImmediateCancelled(), "nil cancelContext should return false") + }) + + t.Run("unit_trySend_race_window", func(t *testing.T) { + _, gen := NewAsyncIteratorPair[*AgentEvent]() + cc := newCancelContext() + + gen.Close() + + execCtx := &chatModelAgentExecCtx{ + generator: gen, + cancelCtx: cc, + } + + assert.NotPanics(t, func() { + execCtx.send(&AgentEvent{AgentName: "test"}) + }, "trySend must handle the case where isImmediateCancelled is false but generator is closed") + }) + + t.Run("unit_SendEvent_after_close", func(t *testing.T) { + _, gen := NewAsyncIteratorPair[*AgentEvent]() + + cc := newCancelContext() + cc.setMode(CancelImmediate) + close(cc.cancelChan) + close(cc.immediateChan) + + gen.Close() + + execCtx := &chatModelAgentExecCtx{ + generator: gen, + cancelCtx: cc, + } + + ctx := withChatModelAgentExecCtx(context.Background(), execCtx) + + assert.NotPanics(t, func() { + err := SendEvent(ctx, &AgentEvent{AgentName: "test"}) + assert.NoError(t, err) + }, "SendEvent after generator.Close must not panic") + }) + + t.Run("unit_SendEvent_no_execCtx", func(t *testing.T) { + err := SendEvent(context.Background(), &AgentEvent{AgentName: "test"}) + assert.Error(t, err, "SendEvent without execCtx should return error") + }) + + t.Run("integration_cancel_escalation_orphans_tool", func(t *testing.T) { + ctx := context.Background() + + toolStarted := make(chan struct{}, 1) + toolDone := make(chan struct{}, 1) + st := &slowToolWithSignal{ + name: "orphan_tool", + delay: 2 * time.Second, + result: "tool result", + startedChan: toolStarted, + } + + mdl := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_orphan_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "orphan_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "OrphanTestAgent", + Description: "Test agent for orphaned tool goroutine panic", + Model: mdl, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + cancelOpt, cancelFn := WithCancel() + iter := agent.Run(ctx, &AgentInput{ + Messages: []Message{schema.UserMessage("Use the tool")}, + }, cancelOpt) + assert.NotNil(t, iter) + + select { + case <-toolStarted: + case <-time.After(10 * time.Second): + t.Fatal("Tool did not start") + } + + timeout := 50 * time.Millisecond + handle, contributed := cancelFn( + WithAgentCancelMode(CancelAfterChatModel), + WithAgentCancelTimeout(timeout), + ) + assert.True(t, contributed, "Cancel should contribute") + + err = handle.Wait() + assert.True(t, err == nil || errors.Is(err, ErrCancelTimeout), + "handle.Wait should return nil or ErrCancelTimeout, got: %v", err) + + for { + _, ok := iter.Next() + if !ok { + break + } + } + + go func() { + time.Sleep(3 * time.Second) + select { + case toolDone <- struct{}{}: + default: + } + }() + + runtime.Gosched() + time.Sleep(3 * time.Second) + + select { + case <-toolDone: + default: + } + }) +} + // -- Tests for CancelImmediate in nested agent structures -- func newTestChatModel(response *schema.Message, delay time.Duration) *cancelTestChatModel { diff --git a/adk/chatmodel.go b/adk/chatmodel.go index d7382b309..4b736d51d 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -43,12 +43,17 @@ var _ ResumableAgent = &ChatModelAgent{} type chatModelAgentExecCtx struct { runtimeReturnDirectly map[string]bool generator *AsyncGenerator[*AgentEvent] + cancelCtx *cancelContext } func (e *chatModelAgentExecCtx) send(event *AgentEvent) { - if e != nil && e.generator != nil { - e.generator.Send(event) + if e == nil || e.generator == nil { + return + } + if e.cancelCtx != nil && e.cancelCtx.isImmediateCancelled() { + return } + e.generator.trySend(event) } type chatModelAgentExecCtxKey struct{} @@ -837,6 +842,7 @@ func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { ctx = withChatModelAgentExecCtx(ctx, &chatModelAgentExecCtx{ generator: p.generator, + cancelCtx: cancelCtx, }) // Pre-execution cancel check @@ -947,6 +953,7 @@ func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) ( ctx = withChatModelAgentExecCtx(ctx, &chatModelAgentExecCtx{ runtimeReturnDirectly: p.returnDirectly, generator: p.generator, + cancelCtx: cancelCtx, }) // Pre-execution cancel check diff --git a/adk/handler.go b/adk/handler.go index 423282a7a..854063f16 100644 --- a/adk/handler.go +++ b/adk/handler.go @@ -347,7 +347,7 @@ func SendEvent(ctx context.Context, event *AgentEvent) error { if execCtx == nil || execCtx.generator == nil { return fmt.Errorf("SendEvent failed: must be called within a ChatModelAgent Run() or Resume() execution context") } - execCtx.generator.Send(event) + execCtx.send(event) return nil } diff --git a/adk/utils.go b/adk/utils.go index 6dbdbc2b5..24eba904b 100644 --- a/adk/utils.go +++ b/adk/utils.go @@ -44,6 +44,10 @@ func (ag *AsyncGenerator[T]) Send(v T) { ag.ch.Send(v) } +func (ag *AsyncGenerator[T]) trySend(v T) bool { + return ag.ch.TrySend(v) +} + func (ag *AsyncGenerator[T]) Close() { ag.ch.Close() } From 6255095597ee8d2735dbee7a0af1c8f755e997a7 Mon Sep 17 00:00:00 2001 From: shentongmartin Date: Wed, 8 Apr 2026 14:06:52 +0800 Subject: [PATCH 49/59] feat(adk): improve TurnLoop stop cleanup and add StopOption controls (#925) * fix(adk): keep late turn loop items Change-Id: Iabee0c25a83d5a25585d3592a41ca6a5fba35c2b * docs(adk): clarify cancel wait semantics Change-Id: Ia0a396b9cc2e43f15e85056d966f20b010dcd2b6 * feat(adk): add WithSkipCheckpoint and WithStopCause StopOptions Add two new StopOption variants for TurnLoop.Stop(): - WithSkipCheckpoint: prevents checkpoint persistence on stop, for cases where the caller does not intend to resume in the future. The flag is sticky across escalation calls. - WithStopCause: attaches a business-supplied reason string. Surfaced in TurnLoopExitState.StopCause and, after the Stopped channel closes, via TurnContext.StopCause(). Uses first-non-empty-wins semantics across multiple Stop() calls. Thread both fields through stopSignal with proper mutex protection. Update cleanup() to skip checkpoint save when skipCheckpoint is set. Change-Id: Ifeat-stop-options-skip-checkpoint-stop-cause --- adk/cancel.go | 10 + adk/turn_loop.go | 149 ++++++++- adk/turn_loop_test.go | 605 ++++++++++++++++++++++++++++++++++++- internal/core/interrupt.go | 8 +- 4 files changed, 748 insertions(+), 24 deletions(-) diff --git a/adk/cancel.go b/adk/cancel.go index f119d79f5..a15699f53 100644 --- a/adk/cancel.go +++ b/adk/cancel.go @@ -67,6 +67,16 @@ type CancelHandle struct { wait func() error } +// Wait blocks until the cancel request reaches a terminal outcome. +// +// It reports the result of the cancel operation itself, not the agent's final +// business error: +// - nil: cancellation succeeded, including the case where a business interrupt +// was absorbed into CancelError while cancellation was active +// - ErrCancelTimeout: the requested safe-point cancellation timed out and was +// escalated to immediate cancellation +// - ErrExecutionCompleted: the execution finished before cancellation took effect, +// meaning the stream drained to completion without any interrupt func (h *CancelHandle) Wait() error { return h.wait() } diff --git a/adk/turn_loop.go b/adk/turn_loop.go index 000842589..7d0f61a3b 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -61,6 +61,8 @@ type stopSignal struct { mu sync.Mutex gen uint64 agentCancelOpts []AgentCancelOption + skipCheckpoint bool + stopCause string // notify is a buffered(1) channel that wakes the current turn's watcher // when Stop() is called. Unlike done, it supports repeated Stop() calls // for cancel-mode escalation. @@ -81,6 +83,12 @@ func (s *stopSignal) signal(cfg *stopConfig) { s.mu.Lock() s.gen++ s.agentCancelOpts = cfg.agentCancelOpts + if cfg.skipCheckpoint { + s.skipCheckpoint = true + } + if cfg.stopCause != "" && s.stopCause == "" { + s.stopCause = cfg.stopCause + } s.mu.Unlock() select { case s.notify <- struct{}{}: @@ -111,6 +119,18 @@ func (s *stopSignal) check() (uint64, []AgentCancelOption) { return s.gen, append([]AgentCancelOption{}, s.agentCancelOpts...) } +func (s *stopSignal) isSkipCheckpoint() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.skipCheckpoint +} + +func (s *stopSignal) getStopCause() string { + s.mu.Lock() + defer s.mu.Unlock() + return s.stopCause +} + // preemptSignal coordinates preemption between Push callers and the run loop. // // Lifecycle overview: @@ -517,9 +537,11 @@ type TurnLoopExitState[T any] struct { // nil means clean exit (Stop() was called and completed normally). // Non-nil values include context errors, callback errors, *CancelError, etc. // When Stop() cancels a running agent, ExitReason will be a *CancelError. + // This never contains checkpoint errors — see CheckpointErr for those. ExitReason error // UnhandledItems contains items that were buffered but not processed. + // These are items for which Push returned true but were never consumed by a turn. // This is always valid regardless of ExitReason. UnhandledItems []T @@ -528,6 +550,33 @@ type TurnLoopExitState[T any] struct { // did not contribute to the final CancelError. // It can be used to reconstruct GenInput/PrepareAgent inputs when resuming. CanceledItems []T + + // StopCause is the business-supplied reason passed via WithStopCause. + // Empty if Stop was not called or no cause was provided. + StopCause string + + // Checkpointed indicates whether a checkpoint save was attempted during cleanup. + // True only when Store is configured, CheckpointID is set, Stop() was called, + // and the loop was not idle at exit time. + Checkpointed bool + + // CheckpointErr is the error from checkpoint save, if any. + // nil when Checkpointed is false (no attempt was made) or when the save succeeded. + CheckpointErr error + + // TakeLateItems returns items that were pushed after the loop stopped + // (i.e., Push returned false for these items). These items are NOT included + // in the checkpoint. + // + // This function is idempotent: the first call computes and caches the result; + // subsequent calls return the same slice. + // + // After TakeLateItems is called, any subsequent Push() will panic. This + // seals the late buffer and prevents items from being silently lost. + // + // It is safe to call TakeLateItems from any goroutine after Wait() returns. + // If TakeLateItems is never called, late items are simply garbage collected. + TakeLateItems func() []T } // TurnContext provides per-turn context to the OnAgentEvents callback. @@ -552,6 +601,11 @@ type TurnContext[T any] struct { // before it was finalized. Remains open if Stop did not contribute. // Use in a select to detect stop while processing events. Stopped <-chan struct{} + + // StopCause returns the business-supplied reason from WithStopCause. + // This value is only meaningful after the Stopped channel is closed. + // Before that, it returns an empty string. + StopCause func() string } // TurnLoop is a push-based event loop for agent execution. @@ -602,6 +656,19 @@ type TurnLoop[T any] struct { loadCheckpointID string onAgentEvents func(ctx context.Context, tc *TurnContext[T], events *AsyncIterator[*AgentEvent]) error + + lateMu sync.Mutex + lateItems []T + lateSealed bool +} + +func (l *TurnLoop[T]) appendLate(item T) { + l.lateMu.Lock() + defer l.lateMu.Unlock() + if l.lateSealed { + panic("TurnLoop: Push called after TakeLateItems") + } + l.lateItems = append(l.lateItems, item) } type turnLoopCheckpoint[T any] struct { @@ -652,7 +719,7 @@ func (l *TurnLoop[T]) deleteTurnLoopCheckpoint(ctx context.Context, checkPointID if deleter, ok := l.config.Store.(CheckPointDeleter); ok { return deleter.Delete(ctx, checkPointID) } - return l.config.Store.Set(ctx, checkPointID, nil) + return nil } func (l *TurnLoop[T]) tryLoadCheckpoint(ctx context.Context) error { @@ -712,6 +779,8 @@ type turnLoopPendingResume[T any] struct { type stopConfig struct { agentCancelOpts []AgentCancelOption + skipCheckpoint bool + stopCause string } // StopOption is an option for Stop(). @@ -725,6 +794,25 @@ func WithAgentCancel(opts ...AgentCancelOption) StopOption { } } +// WithSkipCheckpoint tells the TurnLoop not to persist a checkpoint for this +// Stop call. Use this when the caller does not intend to resume in the future. +// The flag is sticky: once any Stop() call sets it, subsequent calls cannot undo it. +func WithSkipCheckpoint() StopOption { + return func(cfg *stopConfig) { + cfg.skipCheckpoint = true + } +} + +// WithStopCause attaches a business-supplied reason string to this Stop call. +// The cause is surfaced in TurnLoopExitState.StopCause and, after the Stopped +// channel closes, via TurnContext.StopCause(). +// If multiple Stop() calls provide a cause, the first non-empty value wins. +func WithStopCause(cause string) StopOption { + return func(cfg *stopConfig) { + cfg.stopCause = cause + } +} + type pushConfig[T any] struct { preempt bool preemptDelay time.Duration @@ -852,6 +940,9 @@ func (l *TurnLoop[T]) Run(ctx context.Context) { // current agent run or reaching a point where no cancellation is needed). // If the loop has not been started yet (Run not called), items are buffered // and will be processed once Run is called. +// After Wait() returns, failed pushes can be recovered via TurnLoopExitState.TakeLateItems(). +// Once TakeLateItems() has been called, any subsequent push that would become a +// late item will panic instead of being silently dropped. // // Use WithPreempt() to atomically push an item and signal preemption of the current agent. // This is useful for urgent items that should interrupt the current processing. @@ -898,16 +989,22 @@ func (l *TurnLoop[T]) pushWithStrategy(item T, cfg *pushConfig[T]) (bool, <-chan if !cfg.preempt { l.preemptSig.unholdRunLoop() - return l.buffer.TrySend(item), nil + if !l.buffer.TrySend(item) { + l.appendLate(item) + return false, nil + } + return true, nil } if atomic.LoadInt32(&l.stopped) != 0 { l.preemptSig.unholdRunLoop() + l.appendLate(item) return false, nil } if !l.buffer.TrySend(item) { l.preemptSig.unholdRunLoop() + l.appendLate(item) return false, nil } @@ -936,6 +1033,7 @@ func (l *TurnLoop[T]) pushWithStrategy(item T, cfg *pushConfig[T]) (bool, <-chan func (l *TurnLoop[T]) pushWithConfig(item T, cfg *pushConfig[T]) (bool, <-chan struct{}) { if atomic.LoadInt32(&l.stopped) != 0 { + l.appendLate(item) return false, nil } @@ -944,6 +1042,7 @@ func (l *TurnLoop[T]) pushWithConfig(item T, cfg *pushConfig[T]) (bool, <-chan s if !l.buffer.TrySend(item) { l.preemptSig.unholdRunLoop() + l.appendLate(item) return false, nil } @@ -970,7 +1069,11 @@ func (l *TurnLoop[T]) pushWithConfig(item T, cfg *pushConfig[T]) (bool, <-chan s return true, ack } - return l.buffer.TrySend(item), nil + if !l.buffer.TrySend(item) { + l.appendLate(item) + return false, nil + } + return true, nil } // Stop signals the loop to stop and returns immediately (non-blocking). @@ -1295,6 +1398,7 @@ func (l *TurnLoop[T]) runAgentAndHandleEvents( Consumed: spec.consumed, Preempted: preemptDone, Stopped: stoppedDone, + StopCause: l.stopSig.getStopCause, } l.preemptSig.setTurn(ctx, tc) @@ -1398,7 +1502,18 @@ func (l *TurnLoop[T]) cleanup(ctx context.Context) { unhandled := l.buffer.TakeAll() checkpointID := l.config.CheckpointID isIdle := len(l.checkPointRunnerBytes) == 0 && len(unhandled) == 0 && len(l.canceledItems) == 0 - shouldSaveCheckpoint := l.config.Store != nil && checkpointID != "" && l.stopSig.isStopped() && !isIdle + + // Only save checkpoint when the loop exited due to an explicit Stop(). + // If Stop() was called but a callback error happened concurrently, + // the state may be inconsistent — don't checkpoint in that case. + // We consider the exit Stop-caused if runErr is nil (clean stop between + // turns) or a *CancelError (Stop canceled a running agent). + exitCausedByStop := l.runErr == nil || errors.As(l.runErr, new(*CancelError)) + shouldSaveCheckpoint := l.config.Store != nil && checkpointID != "" && l.stopSig.isStopped() && exitCausedByStop && !isIdle && !l.stopSig.isSkipCheckpoint() + + var checkpointed bool + var checkpointErr error + if shouldSaveCheckpoint { cp := &turnLoopCheckpoint[T]{ RunnerCheckpoint: l.checkPointRunnerBytes, @@ -1406,23 +1521,31 @@ func (l *TurnLoop[T]) cleanup(ctx context.Context) { UnhandledItems: unhandled, CanceledItems: l.canceledItems, } - err := l.saveTurnLoopCheckpoint(ctx, checkpointID, cp) - if err != nil { - saveErr := fmt.Errorf("failed to save turn loop checkpoint: %w", err) - if l.runErr != nil { - l.runErr = fmt.Errorf("%w; %v", l.runErr, saveErr) - } else { - l.runErr = saveErr - } - } + checkpointed = true + checkpointErr = l.saveTurnLoopCheckpoint(ctx, checkpointID, cp) } else if l.loadCheckpointID != "" { _ = l.deleteTurnLoopCheckpoint(ctx, l.loadCheckpointID) } + var takeLateOnce sync.Once + var takeLateResult []T + l.result = &TurnLoopExitState[T]{ ExitReason: l.runErr, UnhandledItems: unhandled, CanceledItems: l.canceledItems, + StopCause: l.stopSig.getStopCause(), + Checkpointed: checkpointed, + CheckpointErr: checkpointErr, + TakeLateItems: func() []T { + takeLateOnce.Do(func() { + l.lateMu.Lock() + takeLateResult = append([]T{}, l.lateItems...) + l.lateSealed = true + l.lateMu.Unlock() + }) + return takeLateResult + }, } l.preemptSig.drainAll() diff --git a/adk/turn_loop_test.go b/adk/turn_loop_test.go index 72753fc5e..1b8b2c86d 100644 --- a/adk/turn_loop_test.go +++ b/adk/turn_loop_test.go @@ -1853,7 +1853,9 @@ func TestTurnLoop_CheckpointSaveError_ReturnsError(t *testing.T) { loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) exit := loop.Wait() assert.Error(t, exit.ExitReason) - assert.Contains(t, exit.ExitReason.Error(), "write failed") + assert.True(t, exit.Checkpointed) + assert.Error(t, exit.CheckpointErr) + assert.Contains(t, exit.CheckpointErr.Error(), "write failed") } func TestTurnLoop_StaleCheckpointDeletion_OnCleanResume(t *testing.T) { @@ -1916,7 +1918,7 @@ func TestTurnLoop_StaleCheckpointDeletion_OnCleanResume(t *testing.T) { func TestTurnLoop_StaleCheckpointDeletion_ContextCancel(t *testing.T) { ctx := context.Background() - store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + store := &deletableCheckpointStore{turnLoopCheckpointStore: turnLoopCheckpointStore{m: make(map[string][]byte)}} cpID := "delete-on-cancel" loop1 := NewTurnLoop(TurnLoopConfig[string]{ @@ -1968,11 +1970,10 @@ func TestTurnLoop_StaleCheckpointDeletion_ContextCancel(t *testing.T) { assert.ErrorIs(t, exit2.ExitReason, context.Canceled) store.mu.Lock() - v, exists := store.m[cpID] + _, exists = store.m[cpID] + deleteCalled := store.deleteCalled store.mu.Unlock() - deletedViaNil := exists && v == nil - deletedViaAbsence := !exists - assert.True(t, deletedViaNil || deletedViaAbsence, "stale checkpoint should be deleted when loop exits via context cancellation") + assert.True(t, deleteCalled && !exists, "stale checkpoint should be deleted when loop exits via context cancellation") } type deletableCheckpointStore struct { @@ -2325,10 +2326,11 @@ func TestTurnLoop_CheckpointSaveError_MergesWithExistingError(t *testing.T) { loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) exit := loop.Wait() assert.Error(t, exit.ExitReason) - errStr := exit.ExitReason.Error() - assert.Contains(t, errStr, "disk full") var ce *CancelError - assert.True(t, errors.As(exit.ExitReason, &ce), "should wrap original CancelError") + assert.True(t, errors.As(exit.ExitReason, &ce), "ExitReason should be CancelError, not merged with checkpoint error") + assert.True(t, exit.Checkpointed) + assert.Error(t, exit.CheckpointErr) + assert.Contains(t, exit.CheckpointErr.Error(), "disk full") } func TestTurnLoop_ResumeWithParams(t *testing.T) { @@ -3748,3 +3750,588 @@ func TestTurnLoop_PushStrategy_ConsumedInspection(t *testing.T) { loop.Stop() loop.Wait() } + +func TestTurnLoop_PushAfterStop_BufferedAsLateItems(t *testing.T) { + ctx := context.Background() + processed := make(chan string, 10) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + processed <- tc.Consumed[0] + return nil + }, + }) + + loop.Push("msg1") + <-processed + loop.Stop() + result := loop.Wait() + + // Push after stop — should be buffered as late items + ok1, _ := loop.Push("late1") + ok2, _ := loop.Push("late2") + ok3, _ := loop.Push("late3") + assert.False(t, ok1) + assert.False(t, ok2) + assert.False(t, ok3) + + late := result.TakeLateItems() + assert.Equal(t, []string{"late1", "late2", "late3"}, late) +} + +func TestTurnLoop_TakeLateItems_Idempotent(t *testing.T) { + ctx := context.Background() + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Stop() + loop.Run(ctx) + result := loop.Wait() + + loop.Push("late1") + + first := result.TakeLateItems() + second := result.TakeLateItems() + third := result.TakeLateItems() + + assert.Equal(t, []string{"late1"}, first) + assert.Equal(t, first, second, "subsequent calls should return the same slice") + assert.Equal(t, first, third, "subsequent calls should return the same slice") +} + +func TestTurnLoop_PushAfterTakeLateItems_Panics(t *testing.T) { + ctx := context.Background() + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Stop() + loop.Run(ctx) + result := loop.Wait() + + result.TakeLateItems() + + assert.PanicsWithValue(t, "TurnLoop: Push called after TakeLateItems", func() { + loop.Push("too-late") + }) +} + +func TestTurnLoop_TakeLateItems_NeverCalled_NoImpact(t *testing.T) { + ctx := context.Background() + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Push("b") + loop.Stop() + loop.Run(ctx) + result := loop.Wait() + + // Don't call TakeLateItems — verify UnhandledItems works normally + assert.Contains(t, result.UnhandledItems, "b") + assert.Nil(t, result.ExitReason) +} + +func TestTurnLoop_CheckpointErr_SeparateFromExitReason(t *testing.T) { + ctx := context.Background() + saveStore := &errorCheckpointStore{setErr: fmt.Errorf("storage unavailable")} + + loop := NewTurnLoop(TurnLoopConfig[string]{ + Store: saveStore, + CheckpointID: "cp-separate-err", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Stop() + loop.Run(ctx) + result := loop.Wait() + + // ExitReason should be nil (clean stop), checkpoint error should be separate + assert.Nil(t, result.ExitReason) + assert.True(t, result.Checkpointed) + assert.Error(t, result.CheckpointErr) + assert.Contains(t, result.CheckpointErr.Error(), "storage unavailable") +} + +func TestTurnLoop_Checkpointed_FalseWhenNoStore(t *testing.T) { + ctx := context.Background() + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Stop() + loop.Run(ctx) + result := loop.Wait() + + assert.False(t, result.Checkpointed) + assert.Nil(t, result.CheckpointErr) +} + +func TestTurnLoop_Checkpointed_FalseOnErrorExit(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + genInputErr := errors.New("gen input failed") + + firstTurnDone := make(chan struct{}) + var callCount int32 + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: "cp-err-exit", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + n := atomic.AddInt32(&callCount, 1) + if n > 1 { + return nil, genInputErr + } + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + close(firstTurnDone) + return nil + }, + }) + loop.Push("msg1") + <-firstTurnDone + loop.Push("msg2") + result := loop.Wait() + + // Loop exited from error, not Stop() — checkpoint should not be saved + assert.ErrorIs(t, result.ExitReason, genInputErr) + assert.False(t, result.Checkpointed) + assert.Nil(t, result.CheckpointErr) +} + +func TestTurnLoop_StopConcurrentWithCallbackError_NoCheckpoint(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "stop-concurrent-err" + + prepareErr := errors.New("prepare agent failed") + firstTurnDone := make(chan struct{}) + stopCalled := make(chan struct{}) + var prepareCount int32 + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + n := atomic.AddInt32(&prepareCount, 1) + if n > 1 { + // Wait until Stop() has been called so stopSig.isStopped() is true + <-stopCalled + return nil, prepareErr + } + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + close(firstTurnDone) + return nil + }, + }) + + loop.Push("msg1") + <-firstTurnDone + loop.Push("msg2") + + // Call Stop() and signal PrepareAgent to proceed with error + go func() { + loop.Stop() + close(stopCalled) + }() + + result := loop.Wait() + + // The loop may exit via Stop (clean) or via PrepareAgent error. + // If it exited via PrepareAgent error with Stop also called: + // checkpoint should NOT be saved. + if result.ExitReason != nil && !errors.As(result.ExitReason, new(*CancelError)) { + assert.ErrorIs(t, result.ExitReason, prepareErr) + assert.False(t, result.Checkpointed, "should not checkpoint when exit is caused by callback error") + } + // If Stop won the race, that's fine — checkpoint may or may not be saved + // depending on idle state. The test is about the error path. +} + +func TestTurnLoop_DeleteWithoutCheckPointDeleter_NoOp(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "no-deleter" + + // First loop: save a checkpoint + loop1 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop1.Push("a") + loop1.Stop() + loop1.Run(ctx) + loop1.Wait() + + store.mu.Lock() + _, exists := store.m[cpID] + store.mu.Unlock() + assert.True(t, exists, "checkpoint should be saved") + + // Second loop: exit via context cancel — should try to delete but store + // doesn't implement CheckPointDeleter, so checkpoint persists (no-op) + ctx2, cancel2 := context.WithCancel(ctx) + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + cancel2() + return nil + }, + }) + loop2.Push("b") + loop2.Run(ctx2) + loop2.Wait() + + // Without CheckPointDeleter, the stale checkpoint should NOT be deleted + store.mu.Lock() + v, exists := store.m[cpID] + store.mu.Unlock() + assert.True(t, exists, "checkpoint should still exist without CheckPointDeleter") + assert.NotNil(t, v, "checkpoint should not be set to nil") +} + +func TestTurnLoop_StopWithSkipCheckpoint(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "skip-cp-session" + + loop := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("a") + loop.Push("b") + loop.Stop(WithSkipCheckpoint()) + loop.Run(ctx) + + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + assert.False(t, exit.Checkpointed, "checkpoint should be skipped when WithSkipCheckpoint is used") + + store.mu.Lock() + _, exists := store.m[cpID] + store.mu.Unlock() + assert.False(t, exists, "no checkpoint should be saved when WithSkipCheckpoint is used") +} + +func TestTurnLoop_StopWithSkipCheckpoint_DeletesStaleCheckpoint(t *testing.T) { + ctx := context.Background() + store := &deletableCheckpointStore{ + turnLoopCheckpointStore: turnLoopCheckpointStore{m: make(map[string][]byte)}, + } + cpID := "skip-stale-session" + + loop1 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop1.Push("a") + loop1.Stop() + loop1.Run(ctx) + exit1 := loop1.Wait() + assert.True(t, exit1.Checkpointed) + + store.mu.Lock() + _, exists := store.m[cpID] + store.mu.Unlock() + assert.True(t, exists, "first loop should save checkpoint") + + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop2.Push("b") + loop2.Stop(WithSkipCheckpoint()) + loop2.Run(ctx) + exit2 := loop2.Wait() + assert.False(t, exit2.Checkpointed, "second loop should skip checkpoint") + + store.mu.Lock() + deleteCalled := store.deleteCalled + store.mu.Unlock() + assert.True(t, deleteCalled, "stale checkpoint should be deleted when SkipCheckpoint is used") +} + +func TestTurnLoop_StopWithStopCause(t *testing.T) { + ctx := context.Background() + cause := "user session timeout" + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("a") + loop.Stop(WithStopCause(cause)) + + exit := loop.Wait() + assert.Equal(t, cause, exit.StopCause) +} + +func TestTurnLoop_StopCause_EmptyWhenNoStop(t *testing.T) { + ctx := context.Background() + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Stop() + exit := loop.Wait() + assert.Empty(t, exit.StopCause) +} + +func TestTurnLoop_StopCause_InTurnContext(t *testing.T) { + cause := "business shutdown" + gotCause := make(chan string, 1) + agentStarted := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "slow", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + select { + case <-tc.Stopped: + gotCause <- tc.StopCause() + case <-time.After(5 * time.Second): + t.Error("timed out waiting for Stopped channel") + } + for { + if _, ok := events.Next(); !ok { + break + } + } + return nil + }, + }) + + loop.Push("msg1") + <-agentStarted + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate)), WithStopCause(cause)) + + select { + case c := <-gotCause: + assert.Equal(t, cause, c) + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for StopCause in TurnContext") + } + + exit := loop.Wait() + assert.Equal(t, cause, exit.StopCause) +} + +func TestTurnLoop_StopCause_FirstNonEmptyWins(t *testing.T) { + agentStarted := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "slow", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + for { + if _, ok := events.Next(); !ok { + break + } + } + return nil + }, + }) + + loop.Push("msg1") + <-agentStarted + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelAfterToolCalls)), WithStopCause("first cause")) + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate)), WithStopCause("second cause")) + + exit := loop.Wait() + assert.Equal(t, "first cause", exit.StopCause, "first non-empty StopCause should win") +} + +func TestTurnLoop_SkipCheckpoint_Sticky(t *testing.T) { + agentStarted := make(chan struct{}) + + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "sticky-skip-session" + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "slow", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + for { + if _, ok := events.Next(); !ok { + break + } + } + return nil + }, + }) + + loop.Push("msg1") + <-agentStarted + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelAfterToolCalls)), WithSkipCheckpoint()) + loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + + exit := loop.Wait() + assert.False(t, exit.Checkpointed, "SkipCheckpoint should be sticky across multiple Stop calls") + + store.mu.Lock() + _, exists := store.m[cpID] + store.mu.Unlock() + assert.False(t, exists, "no checkpoint should be saved when SkipCheckpoint was set in any Stop call") +} diff --git a/internal/core/interrupt.go b/internal/core/interrupt.go index 174e0c47c..38ddbdae0 100644 --- a/internal/core/interrupt.go +++ b/internal/core/interrupt.go @@ -30,8 +30,12 @@ type CheckPointStore interface { } // CheckPointDeleter is an optional interface that CheckPointStore implementations -// can implement to support explicit checkpoint deletion. If the Store does not -// implement this interface, deletion is performed by writing an empty value via Set. +// can implement to support explicit checkpoint deletion. +// +// If the Store does not implement this interface, stale checkpoints will NOT be +// automatically cleaned up. The store owner is responsible for managing checkpoint +// lifecycle in that case (e.g., via TTL, external cleanup, or implementing this +// interface). type CheckPointDeleter interface { Delete(ctx context.Context, checkPointID string) error } From d0e13180e3b8ef8ca4bc5dec2a3beb1d4e60de4e Mon Sep 17 00:00:00 2001 From: Born Date: Wed, 8 Apr 2026 14:13:14 +0800 Subject: [PATCH 50/59] feat(compose): support tool name and argument aliases in ToolsNode (#931) --- compose/tool_alias_test.go | 1178 ++++++++++++++++++++++++++++++++++++ compose/tool_node.go | 303 +++++++++- 2 files changed, 1459 insertions(+), 22 deletions(-) create mode 100644 compose/tool_alias_test.go diff --git a/compose/tool_alias_test.go b/compose/tool_alias_test.go new file mode 100644 index 000000000..487132cbe --- /dev/null +++ b/compose/tool_alias_test.go @@ -0,0 +1,1178 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/schema" +) + +type searchArgs struct { + Query string `json:"query"` +} + +func TestToolNameAliases(t *testing.T) { + ctx := context.Background() + + // Create test tool + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search for information", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string", Desc: "Search query"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + return "search result", nil + }) + + // Configure aliases + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"search_v1", "query", "find"}, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Test calling tool with alias + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "search_v1", // Using alias + Arguments: `{"query": "test"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Equal(t, "call_1", output[0].ToolCallID) + assert.Contains(t, output[0].Content, "search result") +} + +type searchArgsWithLimit struct { + Query string `json:"query"` + Limit int `json:"limit"` +} + +func TestArgumentsAliases(t *testing.T) { + ctx := context.Background() + + receivedArgs := "" + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search for information", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + "limit": {Type: "integer"}, + }), + }, func(ctx context.Context, args *searchArgsWithLimit) (string, error) { + b, _ := json.Marshal(args) + receivedArgs = string(b) + return "result", nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + ArgumentsAliases: map[string][]string{ + "query": {"q", "search_term"}, + "limit": {"max_results", "count"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Use alias parameters + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "search", + Arguments: `{"q": "test", "max_results": 10}`, // Using aliases + }, + }, + }) + + _, err = node.Invoke(ctx, input) + require.NoError(t, err) + + // Verify tool received canonical parameter names + var args map[string]any + err = json.Unmarshal([]byte(receivedArgs), &args) + require.NoError(t, err) + assert.Equal(t, "test", args["query"]) + assert.Equal(t, float64(10), args["limit"]) + assert.NotContains(t, args, "q") + assert.NotContains(t, args, "max_results") +} + +type emptyArgs struct{} + +func TestAliasConflict(t *testing.T) { + ctx := context.Background() + + tool1 := newTool(&schema.ToolInfo{Name: "search", Desc: "Search"}, func(ctx context.Context, args *emptyArgs) (string, error) { + return "result", nil + }) + tool2 := newTool(&schema.ToolInfo{Name: "query", Desc: "Query"}, func(ctx context.Context, args *emptyArgs) (string, error) { + return "result", nil + }) + + t.Run("tool name alias conflict", func(t *testing.T) { + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1, tool2}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + }, + "query": { + NameAliases: []string{"find"}, // Conflict: find already used by search + }, + }, + } + + _, err := NewToolNode(ctx, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "conflicts with an alias already registered for") + }) + + t.Run("tool name alias conflicts with canonical name", func(t *testing.T) { + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1, tool2}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"query"}, // Conflict: "query" is tool2's canonical name + }, + }, + } + + _, err := NewToolNode(ctx, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "conflicts with existing tool's canonical name") + }) + + t.Run("argument alias conflict", func(t *testing.T) { + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + "limit": {"q"}, // Conflict: q maps to multiple parameters + }, + }, + }, + } + + _, err := NewToolNode(ctx, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "conflicting arg alias") + }) + + t.Run("arg alias conflicts with existing schema property", func(t *testing.T) { + searchWithParams := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + "limit": {Type: "integer"}, + }), + }, func(ctx context.Context, args *emptyArgs) (string, error) { + return "result", nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchWithParams}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + ArgumentsAliases: map[string][]string{ + "limit": {"query"}, // "query" is already a schema property + }, + }, + }, + } + + _, err := NewToolNode(ctx, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "conflicts with existing schema property") + }) +} + +func TestArgumentsAliasesWithHandler(t *testing.T) { + ctx := context.Background() + + executionOrder := []string{} + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + executionOrder = append(executionOrder, "tool_invoke") + return "result", nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + }, + }, + }, + ToolArgumentsHandler: func(ctx context.Context, name, args string) (string, error) { + executionOrder = append(executionOrder, "args_handler") + // Handler receives the original model-returned name (alias) + assert.Equal(t, "search", name) + // Verify alias remapping has already been done + var m map[string]any + err := json.Unmarshal([]byte(args), &m) + require.NoError(t, err) + assert.Contains(t, m, "query") + assert.NotContains(t, m, "q") + return args, nil + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Call with alias name "find" and alias arg "q" + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"q": "test"}`, + }, + }, + }) + + _, err = node.Invoke(ctx, input) + require.NoError(t, err) + + // Verify execution order: alias remapping → ToolArgumentsHandler → tool execution + assert.Equal(t, []string{"args_handler", "tool_invoke"}, executionOrder) +} + +func TestNonExistentToolInAliasConfig(t *testing.T) { + ctx := context.Background() + + tool1 := newTool(&schema.ToolInfo{Name: "search", Desc: "Search"}, func(ctx context.Context, args *emptyArgs) (string, error) { + return "result", nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1}, + ToolAliases: map[string]ToolAliasConfig{ + "non_existent_tool": { // Non-existent tool + NameAliases: []string{"alias1"}, + }, + }, + } + + // Should not error — non-existent tool alias configs are silently skipped + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // The existing tool should still work normally + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "search", + Arguments: `{}`, + }, + }, + }) + output, err := node.Invoke(ctx, input) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "result") +} + +type weatherArgs struct { + Location string `json:"location"` +} + +func TestToolAliasesE2E(t *testing.T) { + ctx := context.Background() + + // Create multiple tools + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search for information", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + "limit": {Type: "integer"}, + }), + }, func(ctx context.Context, args *searchArgsWithLimit) (string, error) { + return "search result", nil + }) + + weatherTool := newTool(&schema.ToolInfo{ + Name: "weather", + Desc: "Get weather information", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "location": {Type: "string"}, + }), + }, func(ctx context.Context, args *weatherArgs) (string, error) { + return "weather result", nil + }) + + // Configure aliases for multiple tools + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool, weatherTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"search_v1", "query"}, + ArgumentsAliases: map[string][]string{ + "query": {"q", "search_term"}, + "limit": {"max_results"}, + }, + }, + "weather": { + NameAliases: []string{"get_weather"}, + ArgumentsAliases: map[string][]string{ + "location": {"loc", "city"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Construct message with multiple tool calls using different aliases + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "search_v1", // Tool name alias + Arguments: `{"q": "test", "max_results": 5}`, // Parameter aliases + }, + }, + { + ID: "call_2", + Function: schema.FunctionCall{ + Name: "get_weather", // Tool name alias + Arguments: `{"city": "Beijing"}`, // Parameter alias + }, + }, + }) + + output, err := node.Invoke(ctx, input) + require.NoError(t, err) + require.Len(t, output, 2) + + // Verify both tools executed successfully + assert.Equal(t, "call_1", output[0].ToolCallID) + assert.Equal(t, "call_2", output[1].ToolCallID) + assert.Contains(t, output[0].Content, "search result") + assert.Contains(t, output[1].Content, "weather result") +} + +func TestRemapArgsEdgeCases(t *testing.T) { + aliasMap := map[string]string{"q": "query"} + + t.Run("empty string", func(t *testing.T) { + result, err := remapArgs("", aliasMap) + assert.NoError(t, err) + assert.Equal(t, "", result) + }) + + t.Run("whitespace only", func(t *testing.T) { + result, err := remapArgs(" ", aliasMap) + assert.NoError(t, err) + assert.Equal(t, " ", result) + }) + + t.Run("non-object JSON", func(t *testing.T) { + result, err := remapArgs(`"hello"`, aliasMap) + assert.NoError(t, err) + assert.Equal(t, `"hello"`, result) + }) + + t.Run("JSON array", func(t *testing.T) { + result, err := remapArgs(`[1,2,3]`, aliasMap) + assert.NoError(t, err) + assert.Equal(t, `[1,2,3]`, result) + }) + + t.Run("invalid JSON", func(t *testing.T) { + result, err := remapArgs(`{invalid`, aliasMap) + assert.NoError(t, err) + assert.Equal(t, `{invalid`, result) + }) + + t.Run("alias and canonical both present", func(t *testing.T) { + // When both alias "q" and canonical "query" exist, alias is kept as-is (not deleted, not overwritten) + result, err := remapArgs(`{"q": "alias_val", "query": "canonical_val"}`, aliasMap) + assert.NoError(t, err) + var m map[string]any + require.NoError(t, json.Unmarshal([]byte(result), &m)) + assert.Equal(t, "canonical_val", m["query"]) + assert.Equal(t, "alias_val", m["q"]) + }) + + t.Run("unknown fields preserved", func(t *testing.T) { + result, err := remapArgs(`{"q": "test", "unknown_field": 42}`, aliasMap) + assert.NoError(t, err) + var m map[string]any + require.NoError(t, json.Unmarshal([]byte(result), &m)) + assert.Equal(t, "test", m["query"]) + assert.NotContains(t, m, "q") + assert.Equal(t, float64(42), m["unknown_field"]) + }) +} + +func TestCanonicalNameCallWithAliasConfigured(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + return "result: " + args.Query, nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Call with canonical name and canonical arg — should work normally + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "search", + Arguments: `{"query": "hello"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "result: hello") +} + +func TestEmptyAliasValidation(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{Name: "search", Desc: "Search"}, func(ctx context.Context, args *emptyArgs) (string, error) { + return "result", nil + }) + + t.Run("empty name alias", func(t *testing.T) { + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{""}, + }, + }, + } + _, err := NewToolNode(ctx, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "empty name alias") + }) + + t.Run("empty arg alias", func(t *testing.T) { + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + ArgumentsAliases: map[string][]string{ + "query": {""}, + }, + }, + }, + } + _, err := NewToolNode(ctx, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "empty argument alias") + }) + + t.Run("empty canonical arg key", func(t *testing.T) { + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + ArgumentsAliases: map[string][]string{ + "": {"q"}, + }, + }, + }, + } + _, err := NewToolNode(ctx, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "empty canonical argument key") + }) +} + +func TestNameAliasSameAsCanonical(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{Name: "search", Desc: "Search"}, func(ctx context.Context, args *emptyArgs) (string, error) { + return "result", nil + }) + + // Alias same as canonical name — should be tolerated (skip, no error) + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"search", "find"}, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Both canonical and alias should work + for _, name := range []string{"search", "find"} { + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: name, + Arguments: `{}`, + }, + }, + }) + output, err := node.Invoke(ctx, input) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "result") + } +} + +func TestToolAliasesWithDynamicToolList(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + return "search result: " + args.Query, nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Use dynamic ToolList via option — alias should still work + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"q": "dynamic"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input, WithToolList(searchTool)) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "search result: dynamic") +} + +func TestToolNameAliasesStream(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search for information", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + return "stream result: " + args.Query, nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"q": "hello"}`, + }, + }, + }) + + reader, err := node.Stream(ctx, input) + require.NoError(t, err) + + var chunks [][]*schema.Message + for { + chunk, err := reader.Recv() + if err != nil { + break + } + chunks = append(chunks, chunk) + } + + msgs, err := schema.ConcatMessageArray(chunks) + require.NoError(t, err) + require.Len(t, msgs, 1) + assert.Equal(t, "call_1", msgs[0].ToolCallID) + assert.Contains(t, msgs[0].Content, "stream result: hello") +} + +func TestEnhancedToolWithAliases(t *testing.T) { + ctx := context.Background() + + enhancedTool := &enhancedInvokableTool{ + info: &schema.ToolInfo{ + Name: "search", + Desc: "Enhanced search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, + fn: func(ctx context.Context, input *schema.ToolArgument) (*schema.ToolResult, error) { + return &schema.ToolResult{ + Parts: []schema.ToolOutputPart{ + {Type: schema.ToolPartTypeText, Text: "enhanced: " + input.Text}, + }, + }, nil + }, + } + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{enhancedTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Call with alias name and alias arg + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"q": "test"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Equal(t, "call_1", output[0].ToolCallID) + // Verify arg alias was remapped: "q" → "query" in the JSON passed to enhanced tool + assert.Contains(t, output[0].UserInputMultiContent[0].Text, "enhanced:") +} + +func TestDynamicToolListAliasRemoved(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + return "search result", nil + }) + + weatherTool := newTool(&schema.ToolInfo{ + Name: "weather", + Desc: "Weather", + }, func(ctx context.Context, args *emptyArgs) (string, error) { + return "weather result", nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool, weatherTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Dynamic tool list only contains weatherTool — "search" and its alias "find" should not be available + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{}`, + }, + }, + }) + + _, err = node.Invoke(ctx, input, WithToolList(weatherTool)) + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +func TestToolAliasesOptionOverridesGlobal(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + return "search result: " + args.Query, nil + }) + + weatherTool := newTool(&schema.ToolInfo{ + Name: "weather", + Desc: "Weather", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "location": {Type: "string"}, + }), + }, func(ctx context.Context, args *weatherArgs) (string, error) { + return "weather result: " + args.Location, nil + }) + + // Global aliases: search has alias "find" + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool, weatherTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + t.Run("opt ToolAliases overrides global in Invoke", func(t *testing.T) { + // opt.ToolAliases defines "lookup" as alias for search (not "find") + optAliases := map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"lookup"}, + ArgumentsAliases: map[string][]string{ + "query": {"keyword"}, + }, + }, + } + + // "lookup" should work with opt aliases + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "lookup", + Arguments: `{"keyword": "test"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input, WithToolList(searchTool), WithToolAliases(optAliases)) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "search result: test") + + // "find" (global alias) should NOT work when opt.ToolAliases is set + input2 := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_2", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"q": "test"}`, + }, + }, + }) + + _, err = node.Invoke(ctx, input2, WithToolList(searchTool), WithToolAliases(optAliases)) + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) + + t.Run("opt ToolAliases overrides global in Stream", func(t *testing.T) { + optAliases := map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"lookup"}, + ArgumentsAliases: map[string][]string{ + "query": {"keyword"}, + }, + }, + } + + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "lookup", + Arguments: `{"keyword": "stream_test"}`, + }, + }, + }) + + reader, err := node.Stream(ctx, input, WithToolList(searchTool), WithToolAliases(optAliases)) + require.NoError(t, err) + + var chunks [][]*schema.Message + for { + chunk, err := reader.Recv() + if err != nil { + break + } + chunks = append(chunks, chunk) + } + + msgs, err := schema.ConcatMessageArray(chunks) + require.NoError(t, err) + require.Len(t, msgs, 1) + assert.Contains(t, msgs[0].Content, "search result: stream_test") + }) + + t.Run("nil opt ToolAliases falls back to global filtered", func(t *testing.T) { + // No WithToolAliases — should use global "find" alias, filtered by ToolList + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"q": "fallback"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input, WithToolList(searchTool)) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "search result: fallback") + }) + + t.Run("opt ToolAliases only without ToolList replaces global", func(t *testing.T) { + // Only WithToolAliases, no WithToolList — should use global tools with opt aliases + optAliases := map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"lookup"}, + ArgumentsAliases: map[string][]string{ + "query": {"keyword"}, + }, + }, + } + + // "lookup" (opt alias) should work + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "lookup", + Arguments: `{"keyword": "only_alias"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input, WithToolAliases(optAliases)) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "search result: only_alias") + + // "find" (global alias) should NOT work when opt.ToolAliases replaces global + input2 := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_2", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"q": "test"}`, + }, + }, + }) + + _, err = node.Invoke(ctx, input2, WithToolAliases(optAliases)) + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) + + t.Run("opt ToolAliases only without ToolList in Stream", func(t *testing.T) { + optAliases := map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"lookup"}, + }, + } + + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "lookup", + Arguments: `{"query": "stream_only_alias"}`, + }, + }, + }) + + reader, err := node.Stream(ctx, input, WithToolAliases(optAliases)) + require.NoError(t, err) + + var chunks [][]*schema.Message + for { + chunk, err := reader.Recv() + if err != nil { + break + } + chunks = append(chunks, chunk) + } + + msgs, err := schema.ConcatMessageArray(chunks) + require.NoError(t, err) + require.Len(t, msgs, 1) + assert.Contains(t, msgs[0].Content, "search result: stream_only_alias") + }) +} + +func TestAliasConfigForToolAddedViaOption(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + return "search result: " + args.Query, nil + }) + + weatherTool := newTool(&schema.ToolInfo{ + Name: "weather", + Desc: "Weather", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "location": {Type: "string"}, + }), + }, func(ctx context.Context, args *weatherArgs) (string, error) { + return "weather result: " + args.Location, nil + }) + + // New with only searchTool, but alias config includes weather tool + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + }, + }, + "weather": { + NameAliases: []string{"forecast"}, + ArgumentsAliases: map[string][]string{ + "location": {"loc"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + t.Run("weather alias works when tool passed via option", func(t *testing.T) { + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "forecast", + Arguments: `{"loc": "Beijing"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input, WithToolList(searchTool, weatherTool)) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "weather result: Beijing") + }) + + t.Run("search alias still works with option tool list", func(t *testing.T) { + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"q": "test"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input, WithToolList(searchTool, weatherTool)) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "search result: test") + }) +} + +func TestOptionWithToolListAndToolAliases(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + return "search result: " + args.Query, nil + }) + + weatherTool := newTool(&schema.ToolInfo{ + Name: "weather", + Desc: "Weather", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "location": {Type: "string"}, + }), + }, func(ctx context.Context, args *weatherArgs) (string, error) { + return "weather result: " + args.Location, nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + t.Run("opt aliases override global when both tool list and aliases provided", func(t *testing.T) { + optAliases := map[string]ToolAliasConfig{ + "weather": { + NameAliases: []string{"forecast"}, + ArgumentsAliases: map[string][]string{ + "location": {"loc"}, + }, + }, + } + + // "forecast" should work via opt aliases + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "forecast", + Arguments: `{"loc": "Shanghai"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input, WithToolList(searchTool, weatherTool), WithToolAliases(optAliases)) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "weather result: Shanghai") + + // "find" (global alias) should NOT work when opt aliases override + input2 := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_2", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"query": "test"}`, + }, + }, + }) + + _, err = node.Invoke(ctx, input2, WithToolList(searchTool, weatherTool), WithToolAliases(optAliases)) + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) +} diff --git a/compose/tool_node.go b/compose/tool_node.go index a8f98a866..f65037e90 100644 --- a/compose/tool_node.go +++ b/compose/tool_node.go @@ -18,11 +18,16 @@ package compose import ( "context" + "encoding/json" "errors" "fmt" "runtime/debug" + "sort" + "strings" "sync" + "github.com/bytedance/sonic" + "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/components/tool" @@ -33,6 +38,8 @@ import ( type toolsNodeOptions struct { ToolOptions []tool.Option ToolList []tool.BaseTool + + ToolAliases map[string]ToolAliasConfig } // ToolsNodeOption is the option func type for ToolsNode. @@ -52,6 +59,15 @@ func WithToolList(tool ...tool.BaseTool) ToolsNodeOption { } } +// WithToolAliases sets the tool aliases for the ToolsNode call option. +// When used with WithToolList, it overrides the global alias configuration for the dynamic tool list. +// When used alone (without WithToolList), it replaces the global alias configuration while keeping the original tool list. +func WithToolAliases(toolAliases map[string]ToolAliasConfig) ToolsNodeOption { + return func(o *toolsNodeOptions) { + o.ToolAliases = toolAliases + } +} + // ToolsNode represents a node capable of executing tools within a graph. // The Graph Node interface is defined as follows: // @@ -62,6 +78,7 @@ func WithToolList(tool ...tool.BaseTool) ToolsNodeOption { // Output: An array of ToolMessage where the order of elements corresponds to the order of ToolCalls in the input type ToolsNode struct { tuple *toolsTuple + tools []tool.BaseTool unknownToolHandler func(ctx context.Context, name, input string) (string, error) executeSequentially bool toolArgumentsHandler func(ctx context.Context, name, input string) (string, error) @@ -69,6 +86,7 @@ type ToolsNode struct { streamToolCallMiddlewares []StreamableToolMiddleware enhancedToolCallMiddlewares []EnhancedInvokableToolMiddleware enhancedStreamToolCallMiddlewares []EnhancedStreamableToolMiddleware + toolAliasConfigs map[string]ToolAliasConfig } // ToolInput represents the input parameters for a tool call execution. @@ -150,11 +168,30 @@ type ToolMiddleware struct { EnhancedStreamable EnhancedStreamableToolMiddleware } +// ToolAliasConfig configures name and argument aliases for a single tool. +type ToolAliasConfig struct { + // NameAliases are alternative names for this tool. + // If the model returns any of these names, it will be resolved to the canonical tool name. + NameAliases []string + + // ArgumentsAliases maps canonical argument keys to their alias lists. + // key=canonical, value=[]alias. Applied to top-level JSON keys before tool execution. + // Example: {"query": ["q", "search_term"], "limit": ["max_results", "count"]} + ArgumentsAliases map[string][]string +} + // ToolsNodeConfig is the config for ToolsNode. type ToolsNodeConfig struct { // Tools specify the list of tools can be called which are BaseTool but must implement InvokableTool or StreamableTool. Tools []tool.BaseTool + // ToolAliases configures name and argument aliases for tools. + // Key is the canonical tool name, value defines its aliases. + // This field is optional. When provided, tool name aliases will be resolved during tool dispatch, + // and argument aliases will be remapped before ToolArgumentsHandler (if configured) and tool execution. + // Execution order: ArgumentsAliases remapping → ToolArgumentsHandler → tool execution + ToolAliases map[string]ToolAliasConfig + // UnknownToolsHandler handles tool calls for non-existent tools when LLM hallucinates. // This field is optional. When not set, calling a non-existent tool will result in an error. // When provided, if the LLM attempts to call a tool that doesn't exist in the Tools list, @@ -219,13 +256,22 @@ func NewToolNode(ctx context.Context, conf *ToolsNodeConfig) (*ToolsNode, error) } } - tuple, err := convTools(ctx, conf.Tools, middlewares, streamMiddlewares, enhancedInvokableMiddlewares, enhancedStreamableMiddlewares) + params := convToolsParams{ + tools: conf.Tools, + aliasConfigs: conf.ToolAliases, + } + params.middlewares.invokable = middlewares + params.middlewares.streamable = streamMiddlewares + params.middlewares.enhancedInvokable = enhancedInvokableMiddlewares + params.middlewares.enhancedStreamable = enhancedStreamableMiddlewares + tuple, err := convTools(ctx, params) if err != nil { return nil, err } return &ToolsNode{ tuple: tuple, + tools: conf.Tools, unknownToolHandler: conf.UnknownToolsHandler, executeSequentially: conf.ExecuteSequentially, toolArgumentsHandler: conf.ToolArgumentsHandler, @@ -233,6 +279,7 @@ func NewToolNode(ctx context.Context, conf *ToolsNodeConfig) (*ToolsNode, error) streamToolCallMiddlewares: streamMiddlewares, enhancedToolCallMiddlewares: enhancedInvokableMiddlewares, enhancedStreamToolCallMiddlewares: enhancedStreamableMiddlewares, + toolAliasConfigs: conf.ToolAliases, }, nil } @@ -273,19 +320,184 @@ type toolsTuple struct { streamEndpoints []StreamableToolEndpoint enhancedInvokableEndpoints []EnhancedInvokableToolEndpoint enhancedStreamableEndpoints []EnhancedStreamableToolEndpoint + // argsAliasMap stores reverse argument alias mappings for each tool. + // key: canonical tool name, value: map[aliasKey]canonicalKey (alias → canonical direction) + argsAliasMap map[string]map[string]string + // canonicalNames stores the canonical name for each tool index + canonicalNames []string + // toolInfos stores the ToolInfo for each tool index, used for alias validation + toolInfos []*schema.ToolInfo +} + +// remapArgs replaces alias keys in the JSON arguments string with canonical keys. +// aliasMap: alias → canonical mapping +func remapArgs(args string, aliasMap map[string]string) (string, error) { + if len(aliasMap) == 0 { + return args, nil + } + + trimmed := strings.TrimSpace(args) + if trimmed == "" || trimmed[0] != '{' { + return args, nil + } + + var m map[string]json.RawMessage + if err := sonic.Unmarshal([]byte(args), &m); err != nil { + return args, nil + } + + changed := false + for alias, canonical := range aliasMap { + if v, ok := m[alias]; ok { + // Only replace if canonical key doesn't exist. + // If both alias and canonical are present (e.g. {"q":"a","query":"b"}), + // the alias key is kept as-is and passed through as an unknown field. + if _, exists := m[canonical]; !exists { + m[canonical] = v + delete(m, alias) + changed = true + } + } + } + + if !changed { + return args, nil + } + + b, err := sonic.Marshal(m) + return string(b), err +} + +type convToolsParams struct { + tools []tool.BaseTool + middlewares struct { + invokable []InvokableToolMiddleware + streamable []StreamableToolMiddleware + enhancedInvokable []EnhancedInvokableToolMiddleware + enhancedStreamable []EnhancedStreamableToolMiddleware + } + aliasConfigs map[string]ToolAliasConfig +} + +func (t *toolsTuple) applyAliasConfigs(aliasConfigs map[string]ToolAliasConfig) error { + t.argsAliasMap = make(map[string]map[string]string) + + sortedToolNames := make([]string, 0, len(aliasConfigs)) + for toolName := range aliasConfigs { + sortedToolNames = append(sortedToolNames, toolName) + } + sort.Strings(sortedToolNames) + + for _, toolName := range sortedToolNames { + aliasConfig := aliasConfigs[toolName] + var ( + toolIdx int + exists bool + ) + if toolIdx, exists = t.indexes[toolName]; !exists { + continue + } + + if err := t.applyNameAliases(toolName, toolIdx, aliasConfig.NameAliases); err != nil { + return err + } + + if err := t.applyArgsAliases(toolName, toolIdx, aliasConfig.ArgumentsAliases); err != nil { + return err + } + } + + return nil +} + +// applyNameAliases validates and registers name aliases for a single tool into the indexes map. +func (t *toolsTuple) applyNameAliases(toolName string, toolIdx int, nameAliases []string) error { + for _, alias := range nameAliases { + if strings.TrimSpace(alias) == "" { + return fmt.Errorf("tool '%s' has empty name alias", toolName) + } + if existingIdx, conflict := t.indexes[alias]; conflict { + if existingIdx != toolIdx { + conflictToolName := t.canonicalNames[existingIdx] + if alias == conflictToolName { + return fmt.Errorf("tool '%s': name alias '%s' conflicts with existing tool's canonical name", toolName, alias) + } + return fmt.Errorf("tool '%s': name alias '%s' conflicts with an alias already registered for tool '%s'", toolName, alias, conflictToolName) + } + continue + } + t.indexes[alias] = toolIdx + } + return nil } -func convTools(ctx context.Context, tools []tool.BaseTool, ms []InvokableToolMiddleware, sms []StreamableToolMiddleware, - ems []EnhancedInvokableToolMiddleware, esms []EnhancedStreamableToolMiddleware) (*toolsTuple, error) { +// applyArgsAliases validates argument aliases against the tool schema and builds a reverse alias map for a single tool. +func (t *toolsTuple) applyArgsAliases(toolName string, toolIdx int, argumentsAliases map[string][]string) error { + if len(argumentsAliases) == 0 { + return nil + } + + schemaKeys := make(map[string]bool) + if info := t.toolInfos[toolIdx]; info != nil && info.ParamsOneOf != nil { + js, err := info.ParamsOneOf.ToJSONSchema() + if err != nil { + return fmt.Errorf("tool '%s': failed to parse JSON schema for alias validation: %w", toolName, err) + } + if js != nil && js.Properties != nil { + for pair := js.Properties.Oldest(); pair != nil; pair = pair.Next() { + schemaKeys[pair.Key] = true + } + } + } + + reverseMap := make(map[string]string) + sortedCanonicals := make([]string, 0, len(argumentsAliases)) + for canonical := range argumentsAliases { + sortedCanonicals = append(sortedCanonicals, canonical) + } + sort.Strings(sortedCanonicals) + + for _, canonical := range sortedCanonicals { + aliases := argumentsAliases[canonical] + if strings.TrimSpace(canonical) == "" { + return fmt.Errorf("tool '%s' has empty canonical argument key", toolName) + } + if strings.Contains(canonical, ".") { + return fmt.Errorf("tool '%s' has unsupported '.' in canonical argument key '%s': nested field matching is not yet supported", + toolName, canonical) + } + for _, alias := range aliases { + if strings.TrimSpace(alias) == "" { + return fmt.Errorf("tool '%s' has empty argument alias for canonical key '%s'", toolName, canonical) + } + if schemaKeys[alias] { + return fmt.Errorf("tool '%s' has arg alias '%s' that conflicts with existing schema property '%s'", + toolName, alias, alias) + } + if existingCanonical, conflict := reverseMap[alias]; conflict { + return fmt.Errorf("tool '%s' has conflicting arg alias '%s' mapped to both '%s' and '%s'", + toolName, alias, existingCanonical, canonical) + } + reverseMap[alias] = canonical + } + } + t.argsAliasMap[toolName] = reverseMap + + return nil +} + +func convTools(ctx context.Context, params convToolsParams) (*toolsTuple, error) { ret := &toolsTuple{ indexes: make(map[string]int), - meta: make([]*executorMeta, len(tools)), - endpoints: make([]InvokableToolEndpoint, len(tools)), - streamEndpoints: make([]StreamableToolEndpoint, len(tools)), - enhancedInvokableEndpoints: make([]EnhancedInvokableToolEndpoint, len(tools)), - enhancedStreamableEndpoints: make([]EnhancedStreamableToolEndpoint, len(tools)), + meta: make([]*executorMeta, len(params.tools)), + endpoints: make([]InvokableToolEndpoint, len(params.tools)), + streamEndpoints: make([]StreamableToolEndpoint, len(params.tools)), + enhancedInvokableEndpoints: make([]EnhancedInvokableToolEndpoint, len(params.tools)), + enhancedStreamableEndpoints: make([]EnhancedStreamableToolEndpoint, len(params.tools)), + canonicalNames: make([]string, len(params.tools)), + toolInfos: make([]*schema.ToolInfo, len(params.tools)), } - for idx, bt := range tools { + for idx, bt := range params.tools { tl, err := bt.Info(ctx) if err != nil { return nil, fmt.Errorf("(NewToolNode) failed to get tool info at idx= %d: %w", idx, err) @@ -310,19 +522,19 @@ func convTools(ctx context.Context, tools []tool.BaseTool, ms []InvokableToolMid meta = parseExecutorInfoFromComponent(components.ComponentOfTool, bt) if st, ok = bt.(tool.StreamableTool); ok { - streamable = wrapStreamToolCall(st, sms, !meta.isComponentCallbackEnabled) + streamable = wrapStreamToolCall(st, params.middlewares.streamable, !meta.isComponentCallbackEnabled) } if it, ok = bt.(tool.InvokableTool); ok { - invokable = wrapToolCall(it, ms, !meta.isComponentCallbackEnabled) + invokable = wrapToolCall(it, params.middlewares.invokable, !meta.isComponentCallbackEnabled) } if eiTool, ok = bt.(tool.EnhancedInvokableTool); ok { - enhancedInvokable = wrapEnhancedInvokableToolCall(eiTool, ems, !meta.isComponentCallbackEnabled) + enhancedInvokable = wrapEnhancedInvokableToolCall(eiTool, params.middlewares.enhancedInvokable, !meta.isComponentCallbackEnabled) } if esTool, ok = bt.(tool.EnhancedStreamableTool); ok { - enhancedStreamable = wrapEnhancedStreamableToolCall(esTool, esms, !meta.isComponentCallbackEnabled) + enhancedStreamable = wrapEnhancedStreamableToolCall(esTool, params.middlewares.enhancedStreamable, !meta.isComponentCallbackEnabled) } if st == nil && it == nil && eiTool == nil && esTool == nil { @@ -348,7 +560,16 @@ func convTools(ctx context.Context, tools []tool.BaseTool, ms []InvokableToolMid ret.streamEndpoints[idx] = streamable ret.enhancedInvokableEndpoints[idx] = enhancedInvokable ret.enhancedStreamableEndpoints[idx] = enhancedStreamable + ret.canonicalNames[idx] = toolName + ret.toolInfos[idx] = tl } + + if len(params.aliasConfigs) > 0 { + if err := ret.applyAliasConfigs(params.aliasConfigs); err != nil { + return nil, err + } + } + return ret, nil } @@ -616,14 +837,27 @@ func (tn *ToolsNode) genToolCallTasks(ctx context.Context, tuple *toolsTuple, toolCallTasks[i].useEnhanced = false } + // Get canonical tool name for looking up argument aliases + canonicalToolName := tuple.canonicalNames[index] + + // Process argument aliases remapping + args := toolCall.Function.Arguments + if aliasMap, hasAliases := tuple.argsAliasMap[canonicalToolName]; hasAliases { + remappedArgs, err := remapArgs(args, aliasMap) + if err != nil { + return nil, fmt.Errorf("failed to remap args for tool[name:%s]: %w", canonicalToolName, err) + } + args = remappedArgs + } + if tn.toolArgumentsHandler != nil { - arg, err := tn.toolArgumentsHandler(ctx, toolCall.Function.Name, toolCall.Function.Arguments) + arg, err := tn.toolArgumentsHandler(ctx, canonicalToolName, args) if err != nil { - return nil, fmt.Errorf("failed to executed tool[name:%s arguments:%s] arguments handler: %w", toolCall.Function.Name, toolCall.Function.Arguments, err) + return nil, fmt.Errorf("failed to executed tool[name:%s arguments:%s] arguments handler: %w", toolCall.Function.Name, args, err) } toolCallTasks[i].arg = arg } else { - toolCallTasks[i].arg = toolCall.Function.Arguments + toolCallTasks[i].arg = args } } } @@ -782,6 +1016,31 @@ func parallelRunToolCall(ctx context.Context, wg.Wait() } +// buildTupleFromOpts rebuilds a toolsTuple when call options override tools or aliases. +func (tn *ToolsNode) buildTupleFromOpts(ctx context.Context, opt *toolsNodeOptions) (*toolsTuple, error) { + tools := opt.ToolList + if tools == nil { + tools = tn.tools + } + aliasConfigs := opt.ToolAliases + if aliasConfigs == nil { + aliasConfigs = tn.toolAliasConfigs + } + p := convToolsParams{ + tools: tools, + aliasConfigs: aliasConfigs, + } + p.middlewares.invokable = tn.toolCallMiddlewares + p.middlewares.streamable = tn.streamToolCallMiddlewares + p.middlewares.enhancedInvokable = tn.enhancedToolCallMiddlewares + p.middlewares.enhancedStreamable = tn.enhancedStreamToolCallMiddlewares + tuple, err := convTools(ctx, p) + if err != nil { + return nil, fmt.Errorf("failed to convert tool list from call option: %w", err) + } + return tuple, nil +} + // Invoke calls the tools and collects the results of invokable tools. // it's parallel if there are multiple tool calls in the input message. func (tn *ToolsNode) Invoke(ctx context.Context, input *schema.Message, @@ -789,11 +1048,11 @@ func (tn *ToolsNode) Invoke(ctx context.Context, input *schema.Message, opt := getToolsNodeOptions(opts...) tuple := tn.tuple - if opt.ToolList != nil { + if opt.ToolList != nil || opt.ToolAliases != nil { var err error - tuple, err = convTools(ctx, opt.ToolList, tn.toolCallMiddlewares, tn.streamToolCallMiddlewares, tn.enhancedToolCallMiddlewares, tn.enhancedStreamToolCallMiddlewares) + tuple, err = tn.buildTupleFromOpts(ctx, opt) if err != nil { - return nil, fmt.Errorf("failed to convert tool list from call option: %w", err) + return nil, err } } @@ -891,11 +1150,11 @@ func (tn *ToolsNode) Stream(ctx context.Context, input *schema.Message, opt := getToolsNodeOptions(opts...) tuple := tn.tuple - if opt.ToolList != nil { + if opt.ToolList != nil || opt.ToolAliases != nil { var err error - tuple, err = convTools(ctx, opt.ToolList, tn.toolCallMiddlewares, tn.streamToolCallMiddlewares, tn.enhancedToolCallMiddlewares, tn.enhancedStreamToolCallMiddlewares) + tuple, err = tn.buildTupleFromOpts(ctx, opt) if err != nil { - return nil, fmt.Errorf("failed to convert tool list from call option: %w", err) + return nil, err } } From 4a45c5bf54d5b006a9097985af3f903d0d3725e7 Mon Sep 17 00:00:00 2001 From: Ryo Date: Thu, 9 Apr 2026 11:41:14 +0800 Subject: [PATCH 51/59] feat(adk): add failover support for ChatModel (#885) * fix: rebase error Change-Id: If20fa78dba82a1c177c8ec47090050ea8c1354ed * feat(adk): add failover support for ChatModel Change-Id: Ice1b513b4b509e7b540316da9119ff3d529c9bae * feat(adk): add failover support for ChatModel Change-Id: Ice1b513b4b509e7b540316da9119ff3d529c9bae * feat(adk): add failover support for ChatModel Change-Id: Id5483447b74322f6dd495bdd3b994c001094569d * feat(adk): make Name and Description optional in ChatModelAgentConfig * feat(adk): add callback lifecycle management to failoverProxyModel - Extract prepareCallbacks method to reuse callback setup logic between Generate and Stream methods - Add callbacks.ReuseHandlers with proper RunInfo (model type + component) before each failover model invocation so handlers receive correct identity - Add explicit OnStart/OnEnd/OnError callback invocations in Generate and Stream since failoverProxyModel declares IsCallbacksEnabled() = true and the outer layer skips automatic callback injection Change-Id: I0150529024125251828cf6f77c8247aa464b1f84 * fix(adk): preserve partial result in failoverProxyModel.Generate on error Return result instead of nil when target.Generate fails, so that the outer failoverModelWrapper can pass the partial output message to ShouldFailover for inspection. Change-Id: I32d86151a6e133f1a58d5e988bccf42d831a646c * refactor(adk): use EnsureRunInfo in failoverProxyModel and separate ctx for callbacks - Replace manual RunInfo construction + ReuseHandlers with callbacks.EnsureRunInfo for cleaner RunInfo setup - Use nCtx (from EnsureRunInfo) for target model invocation and original ctx for OnStart/OnEnd/OnError callback lifecycle Change-Id: I1d5982d0e1ceeaf8f6648b9c40c229b6a2b07ab8 --------- Co-authored-by: shentong.martin --- adk/chatmodel.go | 94 ++-- adk/chatmodel_test.go | 41 ++ adk/failover_chatmodel.go | 466 +++++++++++++++++++ adk/failover_chatmodel_test.go | 697 ++++++++++++++++++++++++++++ adk/handler.go | 6 + adk/prebuilt/deep/deep.go | 12 +- adk/prebuilt/deep/task_tool.go | 23 +- adk/prebuilt/deep/task_tool_test.go | 1 + adk/wrappers.go | 141 ++++-- adk/wrappers_failover_test.go | 181 ++++++++ adk/wrappers_retry_failover_test.go | 411 ++++++++++++++++ 11 files changed, 1993 insertions(+), 80 deletions(-) create mode 100644 adk/failover_chatmodel.go create mode 100644 adk/failover_chatmodel_test.go create mode 100644 adk/wrappers_failover_test.go create mode 100644 adk/wrappers_retry_failover_test.go diff --git a/adk/chatmodel.go b/adk/chatmodel.go index 4b736d51d..83d8ffd4a 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -44,6 +44,9 @@ type chatModelAgentExecCtx struct { runtimeReturnDirectly map[string]bool generator *AsyncGenerator[*AgentEvent] cancelCtx *cancelContext + + // failoverLastSuccessModel is the last success model only used in failover middleware. + failoverLastSuccessModel model.BaseChatModel } func (e *chatModelAgentExecCtx) send(event *AgentEvent) { @@ -260,13 +263,14 @@ type ChatModelAgentConfig struct { // Model call lifecycle (outermost to innermost wrapper chain): // 1. AgentMiddleware.BeforeChatModel (hook, runs before model call) // 2. ChatModelAgentMiddleware.BeforeModelRewriteState (hook, can modify state before model call) - // 3. retryModelWrapper (internal - retries on failure, if configured) - // 4. eventSenderModelWrapper (internal - sends model response events) - // 5. ChatModelAgentMiddleware.WrapModel (wrapper, first registered is outermost) - // 6. callbackInjectionModelWrapper (internal - injects callbacks if not enabled) - // 7. Model.Generate/Stream - // 8. ChatModelAgentMiddleware.AfterModelRewriteState (hook, can modify state after model call) - // 9. AgentMiddleware.AfterChatModel (hook, runs after model call) + // 3. failoverModelWrapper (internal - failover between models, if configured) + // 4. retryModelWrapper (internal - retries on failure, if configured) + // 5. eventSenderModelWrapper (internal - sends model response events) + // 6. ChatModelAgentMiddleware.WrapModel (wrapper, first registered is outermost) + // 7. callbackInjectionModelWrapper (internal - injects callbacks if not enabled; when failover is enabled, this is handled per-model inside failoverProxyModel instead) + // 8. failoverProxyModel (internal - dispatches to selected failover model, if configured) / Model.Generate/Stream + // 9. ChatModelAgentMiddleware.AfterModelRewriteState (hook, can modify state after model call) + // 10. AgentMiddleware.AfterChatModel (hook, runs after model call) // // Custom Event Sender Position: // By default, events are sent after all user middlewares (WrapModel) have processed the output, @@ -337,6 +341,13 @@ type ChatModelAgentConfig struct { // based on the configured policy. // Optional. If nil, no retry will be performed. ModelRetryConfig *ModelRetryConfig + + // ModelFailoverConfig configures failover behavior for the ChatModel. + // When set, the agent will first try the last successful model (initially the configured Model), + // and on failure, call GetFailoverModel to select alternate models. + // Model field is still required as it serves as the initial model. + // Optional. If nil, no failover will be performed. + ModelFailoverConfig *ModelFailoverConfig } type ChatModelAgent struct { @@ -362,7 +373,8 @@ type ChatModelAgent struct { handlers []ChatModelAgentMiddleware middlewares []AgentMiddleware - modelRetryConfig *ModelRetryConfig + modelRetryConfig *ModelRetryConfig + modelFailoverConfig *ModelFailoverConfig once sync.Once run runFunc @@ -386,6 +398,17 @@ type runFunc func(ctx context.Context, p *runParams) // NewChatModelAgent constructs a chat model-backed agent with the provided config. func NewChatModelAgent(ctx context.Context, config *ChatModelAgentConfig) (*ChatModelAgent, error) { + if config.ModelFailoverConfig != nil { + if config.ModelFailoverConfig.GetFailoverModel == nil { + return nil, errors.New("ModelFailoverConfig.GetFailoverModel is required when ModelFailoverConfig is set") + } + + // ShouldFailover is required when ModelFailoverConfig is set + if config.ModelFailoverConfig.ShouldFailover == nil { + return nil, errors.New("ModelFailoverConfig.ShouldFailover is required when ModelFailoverConfig is set") + } + } + if config.Model == nil { return nil, errors.New("agent 'Model' is required") } @@ -420,18 +443,19 @@ func NewChatModelAgent(ctx context.Context, config *ChatModelAgentConfig) (*Chat }) return &ChatModelAgent{ - name: config.Name, - description: config.Description, - instruction: config.Instruction, - model: config.Model, - toolsConfig: tc, - genModelInput: genInput, - exit: config.Exit, - outputKey: config.OutputKey, - maxIterations: config.MaxIterations, - handlers: config.Handlers, - middlewares: config.Middlewares, - modelRetryConfig: config.ModelRetryConfig, + name: config.Name, + description: config.Description, + instruction: config.Instruction, + model: config.Model, + toolsConfig: tc, + genModelInput: genInput, + exit: config.Exit, + outputKey: config.OutputKey, + maxIterations: config.MaxIterations, + handlers: config.Handlers, + middlewares: config.Middlewares, + modelRetryConfig: config.ModelRetryConfig, + modelFailoverConfig: config.ModelFailoverConfig, }, nil } @@ -799,10 +823,11 @@ func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { ctx = withCancelContext(ctx, cancelCtx) wrappedModel := buildModelWrappers(a.model, &modelWrapperConfig{ - handlers: a.handlers, - middlewares: a.middlewares, - retryConfig: a.modelRetryConfig, - cancelContext: cancelCtx, + handlers: a.handlers, + middlewares: a.middlewares, + retryConfig: a.modelRetryConfig, + failoverConfig: a.modelFailoverConfig, + cancelContext: cancelCtx, }) chain := compose.NewChain[noToolsInput, Message]( @@ -841,8 +866,9 @@ func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { } ctx = withChatModelAgentExecCtx(ctx, &chatModelAgentExecCtx{ - generator: p.generator, - cancelCtx: cancelCtx, + generator: p.generator, + cancelCtx: cancelCtx, + failoverLastSuccessModel: a.model, }) // Pre-execution cancel check @@ -888,10 +914,11 @@ func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) ( model: a.model, toolsConfig: &bc.toolsNodeConf, modelWrapperConf: &modelWrapperConfig{ - handlers: a.handlers, - middlewares: a.middlewares, - retryConfig: a.modelRetryConfig, - toolInfos: bc.toolInfos, + handlers: a.handlers, + middlewares: a.middlewares, + retryConfig: a.modelRetryConfig, + failoverConfig: a.modelFailoverConfig, + toolInfos: bc.toolInfos, }, toolsReturnDirectly: bc.returnDirectly, agentName: a.name, @@ -951,9 +978,10 @@ func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) ( } ctx = withChatModelAgentExecCtx(ctx, &chatModelAgentExecCtx{ - runtimeReturnDirectly: p.returnDirectly, - generator: p.generator, - cancelCtx: cancelCtx, + runtimeReturnDirectly: p.returnDirectly, + generator: p.generator, + cancelCtx: cancelCtx, + failoverLastSuccessModel: a.model, }) // Pre-execution cancel check diff --git a/adk/chatmodel_test.go b/adk/chatmodel_test.go index 3a2f920dd..f3ff6ea05 100644 --- a/adk/chatmodel_test.go +++ b/adk/chatmodel_test.go @@ -2057,3 +2057,44 @@ func TestPreprocessComposeCheckpoint_MigrateErrorIsReturned(t *testing.T) { _, err := preprocessComposeCheckpoint(in) assert.Error(t, err) } + +func TestNewChatModelAgent_FailoverConfigValidation(t *testing.T) { + ctx := context.Background() + cm := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return schema.AssistantMessage("ok", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("ok", nil)}), nil + }, + } + + t.Run("missing GetFailoverModel", func(t *testing.T) { + _, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: cm, + ModelFailoverConfig: &ModelFailoverConfig{ + ShouldFailover: func(context.Context, *schema.Message, error) bool { return true }, + }, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "ModelFailoverConfig.GetFailoverModel") + }) + + t.Run("missing ShouldFailover", func(t *testing.T) { + _, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: cm, + ModelFailoverConfig: &ModelFailoverConfig{ + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return cm, nil, nil + }, + }, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "ModelFailoverConfig.ShouldFailover") + }) +} diff --git a/adk/failover_chatmodel.go b/adk/failover_chatmodel.go new file mode 100644 index 000000000..2a467ed76 --- /dev/null +++ b/adk/failover_chatmodel.go @@ -0,0 +1,466 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "fmt" + "io" + "log" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +type failoverCurrentModelKey struct{} + +type failoverCurrentModel struct { + model model.BaseChatModel +} + +func setFailoverCurrentModel(ctx context.Context, currentModel model.BaseChatModel) context.Context { + return context.WithValue(ctx, failoverCurrentModelKey{}, &failoverCurrentModel{ + model: currentModel, + }) +} + +func getFailoverCurrentModel(ctx context.Context) *failoverCurrentModel { + if fm, ok := ctx.Value(failoverCurrentModelKey{}).(*failoverCurrentModel); ok { + return fm + } + return nil +} + +type failoverHasMoreAttemptsKey struct{} + +// withFailoverHasMoreAttempts sets a flag in context indicating whether additional failover +// attempts remain after the current one. This is read by buildErrWrapper to decide whether +// stream errors should be wrapped as WillRetryError. +func withFailoverHasMoreAttempts(ctx context.Context, hasMore bool) context.Context { + return context.WithValue(ctx, failoverHasMoreAttemptsKey{}, hasMore) +} + +// getFailoverHasMoreAttempts returns true if the current failover attempt has more attempts +// after it, false otherwise (including when no failover context is present). +func getFailoverHasMoreAttempts(ctx context.Context) bool { + v, _ := ctx.Value(failoverHasMoreAttemptsKey{}).(bool) + return v +} + +type failoverProxyModel struct { +} + +func (m *failoverProxyModel) prepareCallbacks(ctx context.Context) (context.Context, model.BaseChatModel, error) { + current := getFailoverCurrentModel(ctx) + if current == nil || current.model == nil { + return nil, nil, errors.New("failover current model not found in context") + } + + typ, _ := components.GetType(current.model) + ctx = callbacks.EnsureRunInfo(ctx, typ, components.ComponentOfChatModel) + + target := current.model + if !components.IsCallbacksEnabled(target) { + target = (&callbackInjectionModelWrapper{}).WrapModel(target) + } + + return ctx, target, nil +} + +func (m *failoverProxyModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + nCtx, target, err := m.prepareCallbacks(ctx) + if err != nil { + return nil, err + } + + ctx = callbacks.OnStart(ctx, input) + + result, err := target.Generate(nCtx, input, opts...) + if err != nil { + callbacks.OnError(ctx, err) + return result, err + } + + callbacks.OnEnd(ctx, result) + + return result, nil +} + +func (m *failoverProxyModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + nCtx, target, err := m.prepareCallbacks(ctx) + if err != nil { + return nil, err + } + + ctx = callbacks.OnStart(ctx, input) + + result, err := target.Stream(nCtx, input, opts...) + if err != nil { + callbacks.OnError(ctx, err) + return nil, err + } + + _, wrappedStream := callbacks.OnEndWithStreamOutput(ctx, result) + return wrappedStream, nil +} + +func (m *failoverProxyModel) IsCallbacksEnabled() bool { + return true +} + +func (m *failoverProxyModel) GetType() string { + return "FailoverProxyModel" +} + +// FailoverContext contains context information during failover process. +type FailoverContext struct { + // FailoverAttempt is the current failover attempt number, starting from 1. + FailoverAttempt uint + + // InputMessages is the original input messages before any transformation. + InputMessages []*schema.Message + + // LastOutputMessage is the output message from the last failed attempt. + // May be nil if no output was produced. For streaming, this may be a partial message + // already received before the stream error. + LastOutputMessage *schema.Message + + // LastErr is the error from the last failed attempt that triggered this failover. + // + // Note: When ModelRetryConfig is also configured, LastErr will be a *RetryExhaustedError + // (if retries were exhausted) rather than the original model error. The original error + // can be retrieved via RetryExhaustedError.LastErr. + LastErr error +} + +// ModelFailoverConfig configures failover behavior for ChatModel. +// When configured, each ChatModel call first tries the last successful model (initially the configured Model), +// and if that fails, calls GetFailoverModel to select alternate models. +type ModelFailoverConfig struct { + // MaxRetries specifies the maximum number of failover attempts. + // + // When failover is triggered, GetFailoverModel will be called up to MaxRetries times + // (FailoverAttempt starts from 1). If GetFailoverModel returns an error, failover + // stops immediately and that error is returned. + // + // A value of 0 means no failover (GetFailoverModel will not be called). + // A value of 1 means GetFailoverModel may be called once. + // + // Note: if lastSuccessModel is set (from a previous successful call), it will be tried + // first before calling GetFailoverModel. + MaxRetries uint + + // ShouldFailover determines whether to fail over to the next model when an error occurs. + // It receives the output message (may be nil if no output is available) and the error (non-nil on failure). + // For streaming errors, outputMessage can carry a partial message accumulated before the error. + // + // Note: When ModelRetryConfig is also configured, outputErr will be a *RetryExhaustedError + // (if retries were exhausted) rather than the original model error. Use errors.As to extract + // the RetryExhaustedError and access RetryExhaustedError.LastErr for the original error: + // + // var retryErr *adk.RetryExhaustedError + // if errors.As(outputErr, &retryErr) { + // // retryErr.LastErr contains the original model error + // } + // + // Note: When the context itself is cancelled (ctx.Err() != nil), failover will stop immediately + // regardless of this function. However, if the model returns context.Canceled or context.DeadlineExceeded + // as an error while the context is still active, this function will still be called. + // Should not be nil when ModelFailoverConfig is set. + // Return true to fail over to the next model, false to stop and return the current result/error. + ShouldFailover func(ctx context.Context, outputMessage *schema.Message, outputErr error) bool + + // GetFailoverModel is called when a model call fails and ShouldFailover returns true. + // It selects the next model to use for the failover attempt and optionally transforms input messages. + // It receives the failover context containing attempt number (starting from 1), original input, and last result. + // Return values: + // - failoverModel: The model to use for this failover attempt. + // - failoverModelInputMessages: The transformed input messages for the failover model. If nil, will use original input. + // - failoverErr: If non-nil, failover stops and this error is returned. + // Should not be nil when ModelFailoverConfig is set via ChatModelAgentConfig. + GetFailoverModel func(ctx context.Context, failoverCtx *FailoverContext) ( + failoverModel model.BaseChatModel, failoverModelInputMessages []*schema.Message, failoverErr error) +} + +func getLastSuccessModel(ctx context.Context) model.BaseChatModel { + if execCtx := getChatModelAgentExecCtx(ctx); execCtx != nil { + return execCtx.failoverLastSuccessModel + } + return nil +} + +func setLastSuccessModel(ctx context.Context, m model.BaseChatModel) { + if execCtx := getChatModelAgentExecCtx(ctx); execCtx != nil { + execCtx.failoverLastSuccessModel = m + } +} + +type failoverModelWrapper struct { + config *ModelFailoverConfig + inner model.BaseChatModel +} + +func newFailoverModelWrapper(inner model.BaseChatModel, config *ModelFailoverConfig) *failoverModelWrapper { + return &failoverModelWrapper{ + config: config, + inner: inner, + } +} + +func (f *failoverModelWrapper) needFailover(ctx context.Context, outputMessage *schema.Message, outputErr error) bool { + if ctx.Err() != nil { + return false + } + + // ShouldFailover is validated at agent construction; nil here indicates a programmer error. + return f.config.ShouldFailover(ctx, outputMessage, outputErr) +} + +func (f *failoverModelWrapper) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + // Defensive: GetFailoverModel is validated non-nil at agent construction. + if f.config.GetFailoverModel == nil { + return f.inner.Generate(ctx, input, opts...) + } + + var lastOutputMessage *schema.Message + var lastErr error + + // Try lastSuccessModel first if available. + if lastSuccess := getLastSuccessModel(ctx); lastSuccess != nil { + if err := ctx.Err(); err != nil { + return nil, err + } + + modelCtx := setFailoverCurrentModel(ctx, lastSuccess) + modelCtx = withFailoverHasMoreAttempts(modelCtx, f.config.MaxRetries > 0) + result, err := f.inner.Generate(modelCtx, input, opts...) + if err == nil { + return result, nil + } + + lastOutputMessage = result + lastErr = err + + if !f.needFailover(ctx, result, err) { + return result, err + } + + log.Printf("failover ChatModel.Generate lastSuccessModel failed: %v", err) + } + + for attempt := uint(1); attempt <= f.config.MaxRetries; attempt++ { + if err := ctx.Err(); err != nil { + return nil, err + } + + failoverCtx := &FailoverContext{ + FailoverAttempt: attempt, + InputMessages: input, + LastOutputMessage: lastOutputMessage, + LastErr: lastErr, + } + + currentModel, currentInput, err := f.config.GetFailoverModel(ctx, failoverCtx) + if err != nil { + return nil, err + } + if currentModel == nil { + return nil, fmt.Errorf("failover GetFailoverModel returned nil model at attempt %d", attempt) + } + + if currentInput == nil { + currentInput = input + } + + modelCtx := setFailoverCurrentModel(ctx, currentModel) + modelCtx = withFailoverHasMoreAttempts(modelCtx, attempt < f.config.MaxRetries) + result, err := f.inner.Generate(modelCtx, currentInput, opts...) + lastOutputMessage = result + lastErr = err + + if err == nil { + setLastSuccessModel(ctx, currentModel) + return result, nil + } + + if !f.needFailover(ctx, result, err) { + return result, err + } + + if attempt < f.config.MaxRetries { + log.Printf("failover ChatModel.Generate attempt %d failed: %v", attempt, err) + } + } + + return lastOutputMessage, lastErr +} + +func (f *failoverModelWrapper) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) ( + *schema.StreamReader[*schema.Message], error) { + // Defensive: GetFailoverModel is validated non-nil at agent construction. + if f.config.GetFailoverModel == nil { + return f.inner.Stream(ctx, input, opts...) + } + + var lastOutputMessage *schema.Message + var lastErr error + + // Try lastSuccessModel first if available. + if lastSuccess := getLastSuccessModel(ctx); lastSuccess != nil { + if err := ctx.Err(); err != nil { + return nil, err + } + + modelCtx := setFailoverCurrentModel(ctx, lastSuccess) + modelCtx = withFailoverHasMoreAttempts(modelCtx, f.config.MaxRetries > 0) + stream, err := f.inner.Stream(modelCtx, input, opts...) + if err != nil { + lastErr = err + if !f.needFailover(ctx, nil, err) { + return nil, err + } + log.Printf("failover ChatModel.Stream lastSuccessModel failed: %v", err) + } else { + copies := stream.Copy(2) + checkCopy := copies[0] + returnCopy := copies[1] + + outMsg, streamErr := consumeStream(checkCopy) + if streamErr != nil { + lastOutputMessage = outMsg + lastErr = streamErr + returnCopy.Close() + + if !f.needFailover(ctx, outMsg, streamErr) { + return nil, streamErr + } + log.Printf("failover ChatModel.Stream lastSuccessModel failed: %v", streamErr) + } else { + return returnCopy, nil + } + } + } + + for attempt := uint(1); attempt <= f.config.MaxRetries; attempt++ { + if err := ctx.Err(); err != nil { + return nil, err + } + + failoverCtx := &FailoverContext{ + FailoverAttempt: attempt, + InputMessages: input, + LastOutputMessage: lastOutputMessage, + LastErr: lastErr, + } + + currentModel, currentInput, err := f.config.GetFailoverModel(ctx, failoverCtx) + if err != nil { + return nil, err + } + if currentModel == nil { + return nil, fmt.Errorf("failover GetFailoverModel returned nil model at attempt %d", attempt) + } + + if currentInput == nil { + currentInput = input + } + + modelCtx := setFailoverCurrentModel(ctx, currentModel) + modelCtx = withFailoverHasMoreAttempts(modelCtx, attempt < f.config.MaxRetries) + stream, err := f.inner.Stream(modelCtx, currentInput, opts...) + if err != nil { + lastErr = err + lastOutputMessage = nil + + if !f.needFailover(ctx, nil, err) { + return nil, err + } + + if attempt < f.config.MaxRetries { + log.Printf("failover ChatModel.Stream attempt %d failed: %v", attempt, err) + } + continue + } + + // The stream returned by f.inner.Stream is already Copy'd by the inner eventSender layer: one + // copy is forwarded to the client in real time via events. Therefore consuming a copy here does + // NOT block client-side streaming. + // + // We Copy the stream into two readers: + // - checkCopy: consumed synchronously to surface mid-stream errors and decide whether to fail over. + // - returnCopy: returned to the caller (stateModelWrapper), which also consumes synchronously to + // build state (AfterModelRewriteState), so waiting here adds no extra latency. + // + // If checkCopy errors and failover is allowed, we close returnCopy and retry with the next model. + // Otherwise we return returnCopy. + // + // NOTE on duplicate events during failover: when a retry happens, events from the failed attempt + // may already have been emitted to the client, and the retry will emit a new stream. Client-side + // handlers are expected to handle multiple rounds (e.g., reset on retry or deduplicate by attempt + // metadata). + copies := stream.Copy(2) + checkCopy := copies[0] + returnCopy := copies[1] + + outMsg, streamErr := consumeStream(checkCopy) + if streamErr != nil { + lastOutputMessage = outMsg + lastErr = streamErr + returnCopy.Close() + + if !f.needFailover(ctx, outMsg, streamErr) { + return nil, streamErr + } + + if attempt < f.config.MaxRetries { + log.Printf("failover ChatModel.Stream attempt %d failed: %v", attempt, streamErr) + } + continue + } + + setLastSuccessModel(ctx, currentModel) + return returnCopy, nil + } + + return nil, lastErr +} + +func consumeStream(stream *schema.StreamReader[*schema.Message]) (*schema.Message, error) { + defer stream.Close() + chunks := make([]*schema.Message, 0) + for { + chunk, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + // ignore concat error + msg, _ := schema.ConcatMessages(chunks) + return msg, err + } + + chunks = append(chunks, chunk) + } + + // Stream completed successfully (EOF). ConcatMessages error is not a stream error, + // so ignore it to avoid incorrectly triggering failover. + msg, _ := schema.ConcatMessages(chunks) + return msg, nil +} diff --git a/adk/failover_chatmodel_test.go b/adk/failover_chatmodel_test.go new file mode 100644 index 000000000..82866e994 --- /dev/null +++ b/adk/failover_chatmodel_test.go @@ -0,0 +1,697 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "io" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +type fakeChatModel struct { + generate func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) + stream func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) + callbacksEnabled bool +} + +func (m *fakeChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + return m.generate(ctx, input, opts...) +} + +func (m *fakeChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return m.stream(ctx, input, opts...) +} + +func (m *fakeChatModel) IsCallbacksEnabled() bool { + return m.callbacksEnabled +} + +func drainMessageStream(sr *schema.StreamReader[*schema.Message]) ([]*schema.Message, error) { + defer sr.Close() + var out []*schema.Message + for { + chunk, err := sr.Recv() + if err == io.EOF { + return out, nil + } + if err != nil { + return out, err + } + out = append(out, chunk) + } +} + +func streamWithMidError(chunks []*schema.Message, err error) *schema.StreamReader[*schema.Message] { + sr, sw := schema.Pipe[*schema.Message](2) + go func() { + defer sw.Close() + for _, c := range chunks { + sw.Send(c, nil) + } + sw.Send(nil, err) + }() + return sr +} + +func streamWithMidErrorControlled(chunks []*schema.Message, err error, firstSent chan struct{}, release chan struct{}) *schema.StreamReader[*schema.Message] { + sr, sw := schema.Pipe[*schema.Message](2) + go func() { + defer sw.Close() + for i, c := range chunks { + sw.Send(c, nil) + if i == 0 && firstSent != nil { + close(firstSent) + if release != nil { + <-release + } + } + } + sw.Send(nil, err) + }() + return sr +} + +func TestFailoverCurrentModelContext(t *testing.T) { + t.Run("set and get", func(t *testing.T) { + ctx := context.Background() + m := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return schema.AssistantMessage("ok", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("ok", nil)}), nil + }, + } + ctx = setFailoverCurrentModel(ctx, m) + got := getFailoverCurrentModel(ctx) + require.NotNil(t, got) + require.Same(t, m, got.model) + }) + + t.Run("wrong type", func(t *testing.T) { + ctx := context.WithValue(context.Background(), failoverCurrentModelKey{}, "bad") + require.Nil(t, getFailoverCurrentModel(ctx)) + }) + + t.Run("missing", func(t *testing.T) { + require.Nil(t, getFailoverCurrentModel(context.Background())) + }) +} + +func TestFailoverProxyModel(t *testing.T) { + t.Run("generate missing context", func(t *testing.T) { + p := &failoverProxyModel{} + _, err := p.Generate(context.Background(), []*schema.Message{schema.UserMessage("hi")}) + require.Error(t, err) + }) + + t.Run("stream missing context", func(t *testing.T) { + p := &failoverProxyModel{} + _, err := p.Stream(context.Background(), []*schema.Message{schema.UserMessage("hi")}) + require.Error(t, err) + }) + + t.Run("generate routes to current model", func(t *testing.T) { + var called int32 + target := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&called, 1) + return schema.AssistantMessage("routed", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("routed", nil)}), nil + }, + } + ctx := setFailoverCurrentModel(context.Background(), target) + p := &failoverProxyModel{} + msg, err := p.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.Equal(t, "routed", msg.Content) + require.Equal(t, int32(1), atomic.LoadInt32(&called)) + }) +} + +func TestFailoverModelWrapper_Generate(t *testing.T) { + t.Run("delegates when GetFailoverModel nil", func(t *testing.T) { + var called int32 + inner := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&called, 1) + return schema.AssistantMessage("inner", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("inner", nil)}), nil + }, + } + w := newFailoverModelWrapper(inner, &ModelFailoverConfig{ + MaxRetries: 2, + ShouldFailover: func(context.Context, *schema.Message, error) bool { return true }, + GetFailoverModel: nil, + }) + msg, err := w.Generate(context.Background(), []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.Equal(t, "inner", msg.Content) + require.Equal(t, int32(1), atomic.LoadInt32(&called)) + }) + + t.Run("failover to second model", func(t *testing.T) { + wantErr := errors.New("first failed") + var shouldCalls int32 + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m1Calls, 1) + return nil, wantErr + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m2Calls, 1) + return schema.AssistantMessage("ok", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + return errors.Is(err, wantErr) + }, + GetFailoverModel: func(_ context.Context, failoverCtx *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.Equal(t, uint(1), failoverCtx.FailoverAttempt) + return m2, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + msg, err := w.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.Equal(t, "ok", msg.Content) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) + }) + + t.Run("canceled error delegates to ShouldFailover", func(t *testing.T) { + var shouldCalls int32 + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, context.Canceled + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 5, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + // User decides to stop on canceled error + return !errors.Is(err, context.Canceled) + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m1, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + _, err := w.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.ErrorIs(t, err, context.Canceled) + // ShouldFailover is called once and returns false, stopping failover + require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) + }) + + t.Run("stops when GetFailoverModel returns error", func(t *testing.T) { + wantErr := errors.New("get model failed") + var called int32 + inner := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&called, 1) + return schema.AssistantMessage("unused", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 3, + ShouldFailover: func(context.Context, *schema.Message, error) bool { return true }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return nil, nil, wantErr + }, + } + + w := newFailoverModelWrapper(inner, cfg) + _, err := w.Generate(context.Background(), []*schema.Message{schema.UserMessage("hi")}) + require.ErrorIs(t, err, wantErr) + require.Equal(t, int32(0), atomic.LoadInt32(&called)) + }) + + t.Run("stops when GetFailoverModel returns nil model", func(t *testing.T) { + cfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(context.Context, *schema.Message, error) bool { return true }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return nil, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + msg, err := w.Generate(context.Background(), []*schema.Message{schema.UserMessage("hi")}) + require.Nil(t, msg) + require.Error(t, err) + require.ErrorContains(t, err, "GetFailoverModel returned nil model") + }) +} + +func TestFailoverModelWrapper_Stream(t *testing.T) { + t.Run("returns stream when first attempt succeeds", func(t *testing.T) { + var shouldCalls int32 + in := schema.UserMessage("hi") + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, input []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + require.Len(t, input, 1) + require.Same(t, in, input[0]) + return schema.StreamReaderFromArray([]*schema.Message{ + schema.AssistantMessage("a", nil), + schema.AssistantMessage("b", nil), + }), nil + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 0, + ShouldFailover: func(context.Context, *schema.Message, error) bool { + atomic.AddInt32(&shouldCalls, 1) + return false + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m1, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := w.Stream(ctx, []*schema.Message{in}) + require.NoError(t, err) + msgs, err := drainMessageStream(sr) + require.NoError(t, err) + require.Len(t, msgs, 2) + require.Equal(t, "a", msgs[0].Content) + require.Equal(t, "b", msgs[1].Content) + require.Equal(t, int32(0), atomic.LoadInt32(&shouldCalls)) + }) + + t.Run("failover when Stream returns error immediately", func(t *testing.T) { + wantErr := errors.New("stream init failed") + var shouldCalls int32 + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return nil, wantErr + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m2Calls, 1) + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("ok", nil)}), nil + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + return errors.Is(err, wantErr) + }, + GetFailoverModel: func(_ context.Context, failoverCtx *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.Equal(t, uint(1), failoverCtx.FailoverAttempt) + return m2, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := w.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + msgs, err := drainMessageStream(sr) + require.NoError(t, err) + require.Len(t, msgs, 1) + require.Equal(t, "ok", msgs[0].Content) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) + }) + + t.Run("failover when stream errors mid-way", func(t *testing.T) { + streamErr := errors.New("mid error") + var shouldCalls int32 + var seenOutput atomic.Value + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return streamWithMidError([]*schema.Message{ + schema.AssistantMessage("p1", nil), + schema.AssistantMessage("p2", nil), + }, streamErr), nil + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m2Calls, 1) + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("final", nil)}), nil + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, out *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + if errors.Is(err, streamErr) && out != nil { + seenOutput.Store(out.Content) + } + return errors.Is(err, streamErr) + }, + GetFailoverModel: func(_ context.Context, failoverCtx *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.Equal(t, uint(1), failoverCtx.FailoverAttempt) + return m2, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := w.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + msgs, err := drainMessageStream(sr) + require.NoError(t, err) + require.Len(t, msgs, 1) + require.Equal(t, "final", msgs[0].Content) + require.Equal(t, "p1p2", seenOutput.Load()) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) + }) + + t.Run("stop when ShouldFailover returns false for mid-way error", func(t *testing.T) { + streamErr := errors.New("mid error") + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return streamWithMidError([]*schema.Message{schema.AssistantMessage("p", nil)}, streamErr), nil + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 3, + ShouldFailover: func(context.Context, *schema.Message, error) bool { + return false + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m1, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := w.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.Nil(t, sr) + require.ErrorIs(t, err, streamErr) + }) + + t.Run("canceled mid-way error delegates to ShouldFailover", func(t *testing.T) { + var shouldCalls int32 + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return streamWithMidError([]*schema.Message{schema.AssistantMessage("p", nil)}, context.Canceled), nil + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 3, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + // User decides to stop on canceled error + return !errors.Is(err, context.Canceled) + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m1, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := w.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.Nil(t, sr) + require.ErrorIs(t, err, context.Canceled) + // ShouldFailover is called once and returns false, stopping failover + require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) + }) + + t.Run("stop when Stream returns error immediately and ShouldFailover returns false", func(t *testing.T) { + wantErr := errors.New("stream init failed") + var shouldCalls int32 + var m1Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return nil, wantErr + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 3, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + require.ErrorIs(t, err, wantErr) + return false + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m1, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := w.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.Nil(t, sr) + require.ErrorIs(t, err, wantErr) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) + }) + + t.Run("stops when GetFailoverModel returns nil model", func(t *testing.T) { + cfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(context.Context, *schema.Message, error) bool { return true }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return nil, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + sr, err := w.Stream(context.Background(), []*schema.Message{schema.UserMessage("hi")}) + require.Nil(t, sr) + require.Error(t, err) + require.ErrorContains(t, err, "GetFailoverModel returned nil model") + }) + + t.Run("stops when GetFailoverModel returns error", func(t *testing.T) { + wantErr := errors.New("get model failed") + var called int32 + inner := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&called, 1) + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("unused", nil)}), nil + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 3, + ShouldFailover: func(context.Context, *schema.Message, error) bool { return true }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return nil, nil, wantErr + }, + } + + w := newFailoverModelWrapper(inner, cfg) + sr, err := w.Stream(context.Background(), []*schema.Message{schema.UserMessage("hi")}) + require.Nil(t, sr) + require.ErrorIs(t, err, wantErr) + require.Equal(t, int32(0), atomic.LoadInt32(&called)) + }) + + t.Run("stops when ctx canceled during mid-way error handling", func(t *testing.T) { + midErr := errors.New("mid error") + var shouldCalls int32 + var m1Calls int32 + var m2Calls int32 + firstSent := make(chan struct{}) + release := make(chan struct{}) + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return streamWithMidErrorControlled( + []*schema.Message{schema.AssistantMessage("p", nil)}, + midErr, + firstSent, + release, + ), nil + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m2Calls, 1) + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("unused", nil)}), nil + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(context.Context, *schema.Message, error) bool { + atomic.AddInt32(&shouldCalls, 1) + return true + }, + GetFailoverModel: func(_ context.Context, failoverCtx *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.Equal(t, uint(1), failoverCtx.FailoverAttempt) + return m2, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + baseCtx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + ctx, cancel := context.WithCancel(baseCtx) + type result struct { + sr *schema.StreamReader[*schema.Message] + err error + } + ch := make(chan result, 1) + go func() { + sr, err := w.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + ch <- result{sr: sr, err: err} + }() + + <-firstSent + cancel() + close(release) + + res := <-ch + if res.sr != nil { + res.sr.Close() + } + require.Nil(t, res.sr) + require.ErrorIs(t, res.err, midErr) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(0), atomic.LoadInt32(&m2Calls)) + require.Equal(t, int32(0), atomic.LoadInt32(&shouldCalls)) + }) +} diff --git a/adk/handler.go b/adk/handler.go index 854063f16..d18abc965 100644 --- a/adk/handler.go +++ b/adk/handler.go @@ -64,6 +64,12 @@ type ModelContext struct { // Used by EventSenderModelWrapper to wrap stream errors appropriately. ModelRetryConfig *ModelRetryConfig + // ModelFailoverConfig contains the failover configuration for the model. + // This is populated at request time from the agent's ModelFailoverConfig. + // Used by EventSenderModelWrapper to wrap stream errors so that failed failover + // attempts are skipped (not treated as fatal) by the flow event processor. + ModelFailoverConfig *ModelFailoverConfig + cancelContext *cancelContext } diff --git a/adk/prebuilt/deep/deep.go b/adk/prebuilt/deep/deep.go index 48b5349a6..3918d47e4 100644 --- a/adk/prebuilt/deep/deep.go +++ b/adk/prebuilt/deep/deep.go @@ -93,6 +93,10 @@ type Config struct { Handlers []adk.ChatModelAgentMiddleware ModelRetryConfig *adk.ModelRetryConfig + // ModelFailoverConfig configures failover behavior for the ChatModel. + // When set, the agent will automatically fail over to alternative models on errors. + // This config is also propagated to the general sub-agent. + ModelFailoverConfig *adk.ModelFailoverConfig // OutputKey stores the agent's response in the session. // Optional. When set, stores output via AddSessionValue(ctx, outputKey, msg.Content). @@ -129,6 +133,7 @@ func New(ctx context.Context, cfg *Config) (adk.ResumableAgent, error) { cfg.MaxIteration, cfg.Middlewares, append(handlers, cfg.Handlers...), + cfg.ModelFailoverConfig, ) if err != nil { return nil, fmt.Errorf("failed to new task tool: %w", err) @@ -146,9 +151,10 @@ func New(ctx context.Context, cfg *Config) (adk.ResumableAgent, error) { Middlewares: cfg.Middlewares, Handlers: append(handlers, cfg.Handlers...), - GenModelInput: genModelInput, - ModelRetryConfig: cfg.ModelRetryConfig, - OutputKey: cfg.OutputKey, + GenModelInput: genModelInput, + ModelRetryConfig: cfg.ModelRetryConfig, + ModelFailoverConfig: cfg.ModelFailoverConfig, + OutputKey: cfg.OutputKey, }) } diff --git a/adk/prebuilt/deep/task_tool.go b/adk/prebuilt/deep/task_tool.go index 6235021bd..e6fcedeb3 100644 --- a/adk/prebuilt/deep/task_tool.go +++ b/adk/prebuilt/deep/task_tool.go @@ -45,8 +45,9 @@ func newTaskToolMiddleware( maxIteration int, middlewares []adk.AgentMiddleware, handlers []adk.ChatModelAgentMiddleware, + modelFailoverConfig *adk.ModelFailoverConfig, ) (adk.ChatModelAgentMiddleware, error) { - t, err := newTaskTool(ctx, taskToolDescriptionGenerator, subAgents, withoutGeneralSubAgent, cm, instruction, toolsConfig, maxIteration, middlewares, handlers) + t, err := newTaskTool(ctx, taskToolDescriptionGenerator, subAgents, withoutGeneralSubAgent, cm, instruction, toolsConfig, maxIteration, middlewares, handlers, modelFailoverConfig) if err != nil { return nil, err } @@ -71,6 +72,7 @@ func newTaskTool( MaxIteration int, middlewares []adk.AgentMiddleware, handlers []adk.ChatModelAgentMiddleware, + modelFailoverConfig *adk.ModelFailoverConfig, ) (tool.InvokableTool, error) { t := &taskTool{ subAgents: map[string]tool.InvokableTool{}, @@ -88,15 +90,16 @@ func newTaskTool( Chinese: generalAgentDescriptionChinese, }) generalAgent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ - Name: generalAgentName, - Description: agentDesc, - Instruction: Instruction, - Model: Model, - ToolsConfig: ToolsConfig, - MaxIterations: MaxIteration, - Middlewares: middlewares, - Handlers: handlers, - GenModelInput: genModelInput, + Name: generalAgentName, + Description: agentDesc, + Instruction: Instruction, + Model: Model, + ToolsConfig: ToolsConfig, + MaxIterations: MaxIteration, + Middlewares: middlewares, + Handlers: handlers, + GenModelInput: genModelInput, + ModelFailoverConfig: modelFailoverConfig, }) if err != nil { return nil, err diff --git a/adk/prebuilt/deep/task_tool_test.go b/adk/prebuilt/deep/task_tool_test.go index 91c3a7784..8d60eb452 100644 --- a/adk/prebuilt/deep/task_tool_test.go +++ b/adk/prebuilt/deep/task_tool_test.go @@ -41,6 +41,7 @@ func TestTaskTool(t *testing.T) { 10, nil, nil, + nil, ) assert.NoError(t, err) diff --git a/adk/wrappers.go b/adk/wrappers.go index e6ac617ea..b4e16d298 100644 --- a/adk/wrappers.go +++ b/adk/wrappers.go @@ -34,28 +34,35 @@ type generateEndpoint func(ctx context.Context, input []*schema.Message, opts .. type streamEndpoint func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) type modelWrapperConfig struct { - handlers []ChatModelAgentMiddleware - middlewares []AgentMiddleware - retryConfig *ModelRetryConfig - toolInfos []*schema.ToolInfo - cancelContext *cancelContext + handlers []ChatModelAgentMiddleware + middlewares []AgentMiddleware + retryConfig *ModelRetryConfig + failoverConfig *ModelFailoverConfig + toolInfos []*schema.ToolInfo + cancelContext *cancelContext } func buildModelWrappers(m model.BaseChatModel, config *modelWrapperConfig) model.BaseChatModel { var wrapped model.BaseChatModel = m - if !components.IsCallbacksEnabled(m) { + // failoverProxyModel must be the innermost wrapper to read the selected failover model from context. + if config.failoverConfig != nil { + wrapped = &failoverProxyModel{} + } + + if !components.IsCallbacksEnabled(wrapped) { wrapped = (&callbackInjectionModelWrapper{}).WrapModel(wrapped) } wrapped = &stateModelWrapper{ - inner: wrapped, - original: m, - handlers: config.handlers, - middlewares: config.middlewares, - toolInfos: config.toolInfos, - modelRetryConfig: config.retryConfig, - cancelContext: config.cancelContext, + inner: wrapped, + original: m, + handlers: config.handlers, + middlewares: config.middlewares, + toolInfos: config.toolInfos, + modelRetryConfig: config.retryConfig, + modelFailoverConfig: config.failoverConfig, + cancelContext: config.cancelContext, } return wrapped @@ -265,12 +272,17 @@ func (w *eventSenderModelWrapper) WrapModel(_ context.Context, m model.BaseChatM if mc != nil { retryConfig = mc.ModelRetryConfig } - return &eventSenderModel{inner: inner, modelRetryConfig: retryConfig}, nil + var failoverConfig *ModelFailoverConfig + if mc != nil { + failoverConfig = mc.ModelFailoverConfig + } + return &eventSenderModel{inner: inner, modelRetryConfig: retryConfig, modelFailoverConfig: failoverConfig}, nil } type eventSenderModel struct { - inner model.BaseChatModel - modelRetryConfig *ModelRetryConfig + inner model.BaseChatModel + modelRetryConfig *ModelRetryConfig + modelFailoverConfig *ModelFailoverConfig } func (m *eventSenderModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { @@ -303,19 +315,12 @@ func (m *eventSenderModel) Stream(ctx context.Context, input []*schema.Message, return nil, errors.New("generator is nil when sending event in Stream: ensure agent state is properly initialized") } - var retryAttempt int - _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { - retryAttempt = st.getRetryAttempt() - return nil - }) - streams := result.Copy(2) eventStream := streams[0] - if m.modelRetryConfig != nil { + if errWrapper := m.buildErrWrapper(ctx); errWrapper != nil { convertOpts := []schema.ConvertOption{ - schema.WithErrWrapper(genErrWrapper(ctx, m.modelRetryConfig.MaxRetries, - retryAttempt, m.modelRetryConfig.IsRetryAble)), + schema.WithErrWrapper(errWrapper), } eventStream = schema.StreamReaderWithConvert(streams[0], func(msg *schema.Message) (*schema.Message, error) { return msg, nil }, @@ -328,6 +333,51 @@ func (m *eventSenderModel) Stream(ctx context.Context, input []*schema.Message, return streams[1], nil } +// buildErrWrapper constructs an error wrapper function for event streams. +// It wraps stream errors as WillRetryError when retry or failover is configured, +// so that flow.go:genAgentInput() can skip events from failed attempts instead of +// treating them as fatal errors. +func (m *eventSenderModel) buildErrWrapper(ctx context.Context) func(error) error { + var retryAttempt int + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + retryAttempt = st.getRetryAttempt() + return nil + }) + + var retryWrapper func(error) error + if m.modelRetryConfig != nil { + retryWrapper = genErrWrapper(ctx, m.modelRetryConfig.MaxRetries, retryAttempt, m.modelRetryConfig.IsRetryAble) + } + + hasFailover := m.modelFailoverConfig != nil + // failoverHasMoreAttempts is set by failoverModelWrapper before each inner call. + // It is true when additional failover attempts remain after the current one, + // meaning stream errors should be wrapped as WillRetryError so the flow layer + // skips them. On the final attempt it is false, so the error propagates normally. + failoverHasMore := getFailoverHasMoreAttempts(ctx) + + if retryWrapper == nil && !(hasFailover && failoverHasMore) { + return nil + } + + return func(err error) error { + // If retry is configured and will retry this error, use the retry wrapper's WillRetryError. + if retryWrapper != nil { + wrapped := retryWrapper(err) + if _, ok := wrapped.(*WillRetryError); ok { + return wrapped + } + } + // Retry won't handle this error (either exhausted or not configured), but + // failover still has more attempts remaining. Wrap it as WillRetryError so + // the flow layer skips this event from the failed attempt. + if hasFailover && failoverHasMore { + return &WillRetryError{ErrStr: err.Error(), err: err} + } + return err + } +} + func popToolGenAction(ctx context.Context, toolName string) *AgentAction { toolCallID := compose.GetToolCallID(ctx) @@ -515,13 +565,14 @@ func hasUserEventSenderToolWrapper(handlers []ChatModelAgentMiddleware) bool { } type stateModelWrapper struct { - inner model.BaseChatModel - original model.BaseChatModel - handlers []ChatModelAgentMiddleware - middlewares []AgentMiddleware - toolInfos []*schema.ToolInfo - modelRetryConfig *ModelRetryConfig - cancelContext *cancelContext + inner model.BaseChatModel + original model.BaseChatModel + handlers []ChatModelAgentMiddleware + middlewares []AgentMiddleware + toolInfos []*schema.ToolInfo + modelRetryConfig *ModelRetryConfig + modelFailoverConfig *ModelFailoverConfig + cancelContext *cancelContext } func (w *stateModelWrapper) IsCallbacksEnabled() bool { @@ -547,6 +598,7 @@ func (w *stateModelWrapper) hasUserEventSender() bool { func (w *stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) generateEndpoint { hasUserEventSender := w.hasUserEventSender() retryConfig := w.modelRetryConfig + failoverConfig := w.modelFailoverConfig cc := w.cancelContext for i := len(w.handlers) - 1; i >= 0; i-- { @@ -573,7 +625,7 @@ func (w *stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) gene if execCtx == nil || execCtx.generator == nil { return innerEndpoint(ctx, input, opts...) } - mc := &ModelContext{ModelRetryConfig: retryConfig, cancelContext: cc} + mc := &ModelContext{ModelRetryConfig: retryConfig, ModelFailoverConfig: failoverConfig, cancelContext: cc} wrappedModel, err := eventSender.WrapModel(ctx, &endpointModel{generate: innerEndpoint}, mc) if err != nil { return nil, err @@ -590,12 +642,23 @@ func (w *stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) gene } } + // Needs to handle failoverWrapper after retryWrapper + if w.modelFailoverConfig != nil { + config := w.modelFailoverConfig + innerEndpoint := endpoint + endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + failoverWrapper := newFailoverModelWrapper(&endpointModel{generate: innerEndpoint}, config) + return failoverWrapper.Generate(ctx, input, opts...) + } + } + return endpoint } func (w *stateModelWrapper) wrapStreamEndpoint(endpoint streamEndpoint) streamEndpoint { hasUserEventSender := w.hasUserEventSender() retryConfig := w.modelRetryConfig + failoverConfig := w.modelFailoverConfig cc := w.cancelContext for i := len(w.handlers) - 1; i >= 0; i-- { @@ -622,7 +685,7 @@ func (w *stateModelWrapper) wrapStreamEndpoint(endpoint streamEndpoint) streamEn if execCtx == nil || execCtx.generator == nil { return innerEndpoint(ctx, input, opts...) } - mc := &ModelContext{ModelRetryConfig: retryConfig, cancelContext: cc} + mc := &ModelContext{ModelRetryConfig: retryConfig, ModelFailoverConfig: failoverConfig, cancelContext: cc} wrappedModel, err := eventSender.WrapModel(ctx, &endpointModel{stream: innerEndpoint}, mc) if err != nil { return nil, err @@ -639,6 +702,16 @@ func (w *stateModelWrapper) wrapStreamEndpoint(endpoint streamEndpoint) streamEn } } + // Needs to handle failoverWrapper after retryWrapper + if w.modelFailoverConfig != nil { + config := w.modelFailoverConfig + innerEndpoint := endpoint + endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + failoverWrapper := newFailoverModelWrapper(&endpointModel{stream: innerEndpoint}, config) + return failoverWrapper.Stream(ctx, input, opts...) + } + } + return endpoint } diff --git a/adk/wrappers_failover_test.go b/adk/wrappers_failover_test.go new file mode 100644 index 000000000..8b14463e1 --- /dev/null +++ b/adk/wrappers_failover_test.go @@ -0,0 +1,181 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +func TestBuildModelWrappers_FailoverProxyInner(t *testing.T) { + base := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return schema.AssistantMessage("ok", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("ok", nil)}), nil + }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 0, + ShouldFailover: func(context.Context, *schema.Message, error) bool { return false }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return base, nil, nil + }, + } + + wrapped := buildModelWrappers(base, &modelWrapperConfig{ + failoverConfig: failoverCfg, + }) + + smw, ok := wrapped.(*stateModelWrapper) + require.True(t, ok) + _, ok = smw.inner.(*failoverProxyModel) + require.True(t, ok) + require.Same(t, base, smw.original) + require.Same(t, failoverCfg, smw.modelFailoverConfig) +} + +func TestStateModelWrapper_Generate_WithFailover(t *testing.T) { + wantErr := errors.New("first failed") + var shouldCalls int32 + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m1Calls, 1) + return schema.AssistantMessage("partial", nil), wantErr + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m2Calls, 1) + return schema.AssistantMessage("ok", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, out *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + require.ErrorIs(t, err, wantErr) + require.NotNil(t, out) + require.Equal(t, "partial", out.Content) + return true + }, + GetFailoverModel: func(_ context.Context, failoverCtx *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.Equal(t, uint(1), failoverCtx.FailoverAttempt) + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + got, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, "ok", got.Content) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) +} + +func TestStateModelWrapper_Stream_WithFailover(t *testing.T) { + streamErr := errors.New("mid error") + var shouldCalls int32 + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return streamWithMidError([]*schema.Message{ + schema.AssistantMessage("p1", nil), + schema.AssistantMessage("p2", nil), + }, streamErr), nil + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m2Calls, 1) + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("final", nil)}), nil + }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, out *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + require.ErrorIs(t, err, streamErr) + require.NotNil(t, out) + require.Equal(t, "p1p2", out.Content) + return true + }, + GetFailoverModel: func(_ context.Context, failoverCtx *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.Equal(t, uint(1), failoverCtx.FailoverAttempt) + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + msgs, err := drainMessageStream(sr) + require.NoError(t, err) + require.Len(t, msgs, 1) + require.Equal(t, "final", msgs[0].Content) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) +} diff --git a/adk/wrappers_retry_failover_test.go b/adk/wrappers_retry_failover_test.go new file mode 100644 index 000000000..98db172e9 --- /dev/null +++ b/adk/wrappers_retry_failover_test.go @@ -0,0 +1,411 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +// TestRetryThenFailover_Generate_RetryExhaustedTriggersFailover tests the combined +// retry + failover path for Generate: m1 always fails, retry exhausted, failover to m2 which succeeds. +func TestRetryThenFailover_Generate_RetryExhaustedTriggersFailover(t *testing.T) { + modelErr := errors.New("model error") + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m1Calls, 1) + return nil, modelErr + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m2Calls, 1) + return schema.AssistantMessage("ok from m2", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + retryCfg := &ModelRetryConfig{ + MaxRetries: 2, + IsRetryAble: func(_ context.Context, err error) bool { return true }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + return err != nil + }, + GetFailoverModel: func(_ context.Context, fc *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.NotNil(t, fc.LastErr) + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + msg, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.Equal(t, "ok from m2", msg.Content) + + // m1: 1 (lastSuccess) + 2 retries = 3 calls on lastSuccess attempt, + // then failover to m2 which also goes through retry wrapper: 1 call succeeds. + require.Equal(t, int32(3), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) +} + +// TestRetryThenFailover_Generate_AllExhausted tests: m1 retry exhausted → failover to m2 → m2 retry exhausted → final error. +func TestRetryThenFailover_Generate_AllExhausted(t *testing.T) { + modelErr := errors.New("always fails") + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m1Calls, 1) + return nil, modelErr + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m2Calls, 1) + return nil, modelErr + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + retryCfg := &ModelRetryConfig{ + MaxRetries: 1, + IsRetryAble: func(_ context.Context, err error) bool { return true }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + return err != nil + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + _, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.Error(t, err) + + // Should be RetryExhaustedError from m2's retry wrapper + var retryErr *RetryExhaustedError + require.True(t, errors.As(err, &retryErr)) + + // m1: 1 initial + 1 retry = 2 calls + require.Equal(t, int32(2), atomic.LoadInt32(&m1Calls)) + // m2: 1 initial + 1 retry = 2 calls + require.Equal(t, int32(2), atomic.LoadInt32(&m2Calls)) +} + +// TestRetryThenFailover_Stream_RetryExhaustedTriggersFailover tests stream path: +// m1 stream always errors mid-way, retry exhausted, failover to m2 which succeeds. +func TestRetryThenFailover_Stream_RetryExhaustedTriggersFailover(t *testing.T) { + streamErr := errors.New("stream mid error") + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return streamWithMidError([]*schema.Message{ + schema.AssistantMessage("partial", nil), + }, streamErr), nil + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m2Calls, 1) + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("ok from m2", nil)}), nil + }, + } + + retryCfg := &ModelRetryConfig{ + MaxRetries: 1, + IsRetryAble: func(_ context.Context, err error) bool { return true }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + return err != nil + }, + GetFailoverModel: func(_ context.Context, fc *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.NotNil(t, fc.LastErr) + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + msgs, err := drainMessageStream(sr) + require.NoError(t, err) + require.Len(t, msgs, 1) + require.Equal(t, "ok from m2", msgs[0].Content) + + // m1: 1 initial + 1 retry = 2 calls on lastSuccess attempt + require.Equal(t, int32(2), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) +} + +// TestRetryThenFailover_Generate_RetrySucceedsNoFailover tests that when retry +// succeeds on the first model, failover is never triggered. +func TestRetryThenFailover_Generate_RetrySucceedsNoFailover(t *testing.T) { + var m1Calls int32 + var failoverCalled int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + n := atomic.AddInt32(&m1Calls, 1) + if n == 1 { + return nil, errors.New("transient error") + } + return schema.AssistantMessage("ok on retry", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + retryCfg := &ModelRetryConfig{ + MaxRetries: 2, + IsRetryAble: func(_ context.Context, err error) bool { return true }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + atomic.AddInt32(&failoverCalled, 1) + return true + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + t.Fatal("GetFailoverModel should not be called when retry succeeds") + return nil, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + msg, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.Equal(t, "ok on retry", msg.Content) + + // 2 calls: first fails, second succeeds via retry + require.Equal(t, int32(2), atomic.LoadInt32(&m1Calls)) + // ShouldFailover should never be called + require.Equal(t, int32(0), atomic.LoadInt32(&failoverCalled)) +} + +// TestRetryThenFailover_Generate_NonRetryableErrorTriggersFailover tests that a non-retryable +// error skips retry and directly triggers failover. +func TestRetryThenFailover_Generate_NonRetryableErrorTriggersFailover(t *testing.T) { + nonRetryableErr := errors.New("non-retryable") + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m1Calls, 1) + return nil, nonRetryableErr + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m2Calls, 1) + return schema.AssistantMessage("ok from m2", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + retryCfg := &ModelRetryConfig{ + MaxRetries: 3, + IsRetryAble: func(_ context.Context, err error) bool { + // Only non-retryable errors + return !errors.Is(err, nonRetryableErr) + }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + return err != nil + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + msg, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.Equal(t, "ok from m2", msg.Content) + + // m1 called only once — non-retryable error skips retry + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) +} + +// TestRetryThenFailover_Stream_AllExhausted tests stream path when both retry and failover are exhausted. +func TestRetryThenFailover_Stream_AllExhausted(t *testing.T) { + streamErr := errors.New("always fails mid-stream") + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return streamWithMidError([]*schema.Message{ + schema.AssistantMessage("p", nil), + }, streamErr), nil + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m2Calls, 1) + return streamWithMidError([]*schema.Message{ + schema.AssistantMessage("p", nil), + }, streamErr), nil + }, + } + + retryCfg := &ModelRetryConfig{ + MaxRetries: 1, + IsRetryAble: func(_ context.Context, err error) bool { return true }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + return err != nil + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + _, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.Error(t, err) + + var retryErr *RetryExhaustedError + require.True(t, errors.As(err, &retryErr)) + + // m1: 1 initial + 1 retry = 2 calls + require.Equal(t, int32(2), atomic.LoadInt32(&m1Calls)) + // m2: 1 initial + 1 retry = 2 calls + require.Equal(t, int32(2), atomic.LoadInt32(&m2Calls)) +} From 42cecff5d862f580341f171296a50f6c2da4d2df Mon Sep 17 00:00:00 2001 From: Megumin Date: Thu, 9 Apr 2026 14:57:26 +0800 Subject: [PATCH 52/59] feat: tool search (#884) feat: tool search definition --- .../dynamictool/toolsearch/prompt.go | 162 ++++ .../dynamictool/toolsearch/toolsearch.go | 457 ++++++++-- .../dynamictool/toolsearch/toolsearch_test.go | 849 ++++++++++-------- components/model/option.go | 37 + schema/agentic_message.go | 52 +- schema/agentic_message_test.go | 2 + schema/message.go | 42 +- schema/tool.go | 101 +++ schema/tool_test.go | 48 + 9 files changed, 1312 insertions(+), 438 deletions(-) create mode 100644 adk/middlewares/dynamictool/toolsearch/prompt.go diff --git a/adk/middlewares/dynamictool/toolsearch/prompt.go b/adk/middlewares/dynamictool/toolsearch/prompt.go new file mode 100644 index 000000000..5aaa56ad1 --- /dev/null +++ b/adk/middlewares/dynamictool/toolsearch/prompt.go @@ -0,0 +1,162 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package toolsearch + +const ( + toolDescription = `Search for or select deferred tools to make them available for use. + +MANDATORY PREREQUISITE - THIS IS A HARD REQUIREMENT + +You MUST use this tool to load deferred tools BEFORE calling them directly. + +This is a BLOCKING REQUIREMENT - deferred tools are NOT available until you load them using this tool. Look for messages in the conversation for the list of tools you can discover. Both query modes (keyword search and direct selection) load the returned tools — once a tool appears in the results, it is immediately available to call. + +Why this is non-negotiable: +- Deferred tools are not loaded until discovered via this tool +- Calling a deferred tool without first loading it will fail +Query modes: + +1. Keyword search - Use keywords when you're unsure which tool to use or need to discover multiple tools at once: + - "list directory" - find tools for listing directories + - "notebook jupyter" - find notebook editing tools + - "slack message" - find slack messaging tools + - Returns up to 5 matching tools ranked by relevance + - All returned tools are immediately available to call — no further selection step needed +2. Direct selection - Use select: when you know the exact tool name: + - "select:mcp__slack__read_channel" + - "select:NotebookEdit" + - "select:Read,Edit,Grep" - load multiple tools at once with comma separation + - Returns the named tool(s) if they exist +IMPORTANT: Both modes load tools equally. Do NOT follow up a keyword search with select: calls for tools already returned — they are already loaded. + +3. Required keyword - Prefix with + to require a match: + - "+linear create issue" - only tools from "linear", ranked by "create"/"issue" + - "+slack send" - only "slack" tools, ranked by "send" + - Useful when you know the service name but not the exact tool +CORRECT Usage Patterns: + + +User: I need to work with slack somehow +Assistant: Let me search for slack tools. +[Calls tool_search with query: "slack"] +Assistant: Found several options including mcp__slack__read_channel. +[Calls mcp__slack__read_channel directly — it was loaded by the keyword search] + + + +User: Edit the Jupyter notebook +Assistant: Let me load the notebook editing tool. +[Calls tool_search with query: "select:NotebookEdit"] +[Calls NotebookEdit] + + + +User: List files in the src directory +Assistant: I can see mcp__filesystem__list_directory in the available tools. Let me select it. +[Calls tool_search with query: "select:mcp__filesystem__list_directory"] +[Calls the tool] + + +INCORRECT Usage Patterns - NEVER DO THESE: + + +User: Read my slack messages +Assistant: [Directly calls mcp__slack__read_channel without loading it first] +WRONG - You must load the tool FIRST using this tool + + + +Assistant: [Calls tool_search with query: "slack", gets back mcp__slack__read_channel] +Assistant: [Calls tool_search with query: "select:mcp__slack__read_channel"] +WRONG - The keyword search already loaded the tool. The select call is redundant. +` + + toolDescriptionChinese = `搜索或选择延迟加载(deferred)的工具,使其可供调用。 + +强制前提条件(MANDATORY PREREQUISITE)— 硬性要求 + +在直接调用任何 延迟加载工具(deferred tools) 之前,你 必须先使用此工具将其加载。 + +这是一个 阻塞性要求(BLOCKING REQUIREMENT) — 延迟加载工具在被加载之前是 不可用的。你需要在对话中查找 消息,以获取可以发现的工具列表。无论使用哪种查询方式(关键字搜索 或 直接选择),只要工具出现在返回结果中,它们就会自动被加载并立即可调用。 + +为什么这是不可协商的规则: +- 延迟加载工具在被发现之前不会被加载 +- 如果你在加载之前直接调用延迟工具,调用将会失败 +查询模式: + +1. 关键字搜索(Keyword search)- 当你不确定具体需要哪个工具,或希望一次发现多个工具时使用关键字搜索: +- "list directory" — 查找用于列出目录的工具 +- "notebook jupyter" — 查找 Jupyter Notebook 编辑工具 +- "slack message" — 查找 Slack 消息相关工具 +- 返回最多 5 个最相关的工具 +- 所有返回的工具都会立即加载并可直接调用 — 不需要额外执行 select 步骤 + +2. 直接选择(Direct selection)— 当你已经知道工具的确切名称时使用 select:: +- "select:mcp__slack__read_channel" +- "select:NotebookEdit" +- "select:Read,Edit,Grep" — 一次加载多个工具 +- 如果工具存在,将被加载并返回 +重要说明:两种模式的加载效果完全相同。不要在关键词搜索之后,对返回的工具再次进行 select: 选择 — 它们已经加载好了。 + +3. 必须匹配关键字(Required keyword)— 在关键字前添加 + 可以 强制匹配特定服务或来源。 +- "+linear create issue" — 仅返回名字中包含 "linear" 的工具,按 "create" / "issue" 排序 +- "+slack send" — 仅返回名字中包含 "slack" 的工具,按 "send" 排序 +- 适用于你知道服务名称但不知道具体工具名称 + +正确使用示例: + + +User: 我需要处理 Slack 相关的事情 +Assistant: 让我搜索 Slack 工具。 +[调用 tool_search,query: "slack"] +Assistant: 找到多个选项,包括 mcp__slack__read_channel。 +[直接调用 mcp__slack__read_channel — 关键字搜索已经加载了该工具] + + + +User: 编辑这个 Jupyter Notebook +Assistant: 让我加载 Notebook 编辑工具。 +[调用 tool_search,query: "select:NotebookEdit"] +[调用 NotebookEdit] + + + +User: 列出 src 目录中的文件 +Assistant: 我看到可用工具中有 mcp__filesystem__list_directory,让我加载它。 +[调用 tool_search,query: "select:mcp__filesystem__list_directory"] +[调用该工具] + + +错误用法(严禁) + + +User: 读取我的 Slack 消息 +Assistant: [不调用 tool_search 工具加载,直接调用 mcp__slack__read_channel] +错误 — 在调用工具之前没有先使用 tool_search 加载该工具。 + + + +Assistant:[调用 tool_search,query: "slack",返回 mcp__slack__read_channel] +Assistant:[再次调用 tool_search,query: "select:mcp__slack__read_channel"] +错误 — 关键字搜索 已经加载了该工具,再次 select 是冗余操作。` + + systemReminderTpl = ` +{{- range .Tools }} +{{ . }} +{{- end }} +` +) diff --git a/adk/middlewares/dynamictool/toolsearch/toolsearch.go b/adk/middlewares/dynamictool/toolsearch/toolsearch.go index 4ee4c216b..55883e914 100644 --- a/adk/middlewares/dynamictool/toolsearch/toolsearch.go +++ b/adk/middlewares/dynamictool/toolsearch/toolsearch.go @@ -18,12 +18,17 @@ package toolsearch import ( + "bytes" "context" "encoding/json" "fmt" - "regexp" + "sort" + "strings" + "text/template" + "unicode" "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/adk/internal" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/schema" @@ -33,6 +38,16 @@ import ( type Config struct { // DynamicTools is a list of tools that can be dynamically searched and loaded by the agent. DynamicTools []tool.BaseTool + + // UseModelToolSearch indicates whether the ChatModel natively supports tool search. + // + // When true, the middleware delegates tool search to the model's native capability. + // + // When false (default), the middleware manages tool visibility by filtering the tool list + // based on tool_search results before each model call. Note that this approach may + // invalidate the model's KV-cache (as the tool list changes between calls), and effectiveness + // depends on the model's ability to work with a dynamically changing tool set. + UseModelToolSearch bool } // New constructs and returns the tool search middleware. @@ -41,7 +56,7 @@ type Config struct { // Instead of passing all tools to the model at once (which can overwhelm context limits), // this middleware: // -// 1. Adds a "tool_search" meta-tool that accepts a regex pattern to search tool names +// 1. Adds a "tool_search" meta-tool that accepts keyword queries to search tools // 2. Initially hides all dynamic tools from the model's tool list // 3. When the model calls tool_search, matching tools become available for subsequent calls // @@ -62,14 +77,55 @@ func New(ctx context.Context, config *Config) (adk.ChatModelAgentMiddleware, err return nil, fmt.Errorf("tools is required") } + tpl, err := template.New("").Parse(systemReminderTpl) + if err != nil { + return nil, err + } + + dynamicToolInfos := make([]*schema.ToolInfo, 0, len(config.DynamicTools)) + mapOfDynamicTools := make(map[string]*schema.ToolInfo, len(config.DynamicTools)) + toolNames := make([]string, 0, len(config.DynamicTools)) + for _, t := range config.DynamicTools { + info, infoErr := t.Info(ctx) + if infoErr != nil { + return nil, fmt.Errorf("failed to get dynamic tool info: %w", infoErr) + } + + if _, ok := mapOfDynamicTools[info.Name]; ok { + return nil, fmt.Errorf("duplicate dynamic tool name: %s", info.Name) + } + + toolNames = append(toolNames, info.Name) + mapOfDynamicTools[info.Name] = info + dynamicToolInfos = append(dynamicToolInfos, info) + } + + buf := &bytes.Buffer{} + err = tpl.Execute(buf, systemReminder{Tools: toolNames}) + if err != nil { + return nil, fmt.Errorf("failed to format system reminder template: %w", err) + } + return &middleware{ - dynamicTools: config.DynamicTools, + dynamicTools: config.DynamicTools, + mapOfDynamicTools: mapOfDynamicTools, + dynamicToolInfos: dynamicToolInfos, + useModelToolSearch: config.UseModelToolSearch, + sr: buf.String(), }, nil } +type systemReminder struct { + Tools []string +} + type middleware struct { adk.BaseChatModelAgentMiddleware - dynamicTools []tool.BaseTool + dynamicTools []tool.BaseTool + mapOfDynamicTools map[string]*schema.ToolInfo + dynamicToolInfos []*schema.ToolInfo + useModelToolSearch bool + sr string } func (m *middleware) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgentContext) (context.Context, *adk.ChatModelAgentContext, error) { @@ -78,123 +134,384 @@ func (m *middleware) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgent } nRunCtx := *runCtx - toolNames, err := getToolNames(ctx, m.dynamicTools) - if err != nil { - return ctx, nil, fmt.Errorf("failed to get tool names: %w", err) - } - nRunCtx.Tools = append(nRunCtx.Tools, newToolSearchTool(toolNames)) + nRunCtx.Tools = make([]tool.BaseTool, len(runCtx.Tools), len(runCtx.Tools)+1+len(m.dynamicTools)) + copy(nRunCtx.Tools, runCtx.Tools) + nRunCtx.Tools = append(nRunCtx.Tools, newToolSearchTool(m.mapOfDynamicTools, m.useModelToolSearch)) nRunCtx.Tools = append(nRunCtx.Tools, m.dynamicTools...) return ctx, &nRunCtx, nil } func (m *middleware) WrapModel(_ context.Context, cm model.BaseChatModel, mc *adk.ModelContext) (model.BaseChatModel, error) { - return &wrapper{allTools: mc.Tools, cm: cm, dynamicTools: m.dynamicTools}, nil + return &wrapper{ + allTools: mc.Tools, + cm: cm, + dynamicToolInfos: m.dynamicToolInfos, + reminder: m.sr, + useModelToolSearch: m.useModelToolSearch, + }, nil } type wrapper struct { - allTools []*schema.ToolInfo - dynamicTools []tool.BaseTool + allTools []*schema.ToolInfo + dynamicToolInfos []*schema.ToolInfo + reminder string + useModelToolSearch bool cm model.BaseChatModel } func (w *wrapper) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { - tools, err := removeTools(ctx, w.allTools, w.dynamicTools, input) + toolsOpts, err := w.resolveTools(ctx, input) if err != nil { return nil, fmt.Errorf("failed to load dynamic tools: %w", err) } - return w.cm.Generate(ctx, input, append(opts, model.WithTools(tools))...) + return w.cm.Generate(ctx, w.insertReminder(input), append(opts, toolsOpts...)...) } func (w *wrapper) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { - tools, err := removeTools(ctx, w.allTools, w.dynamicTools, input) + toolsOpts, err := w.resolveTools(ctx, input) if err != nil { return nil, fmt.Errorf("failed to load dynamic tools: %w", err) } - return w.cm.Stream(ctx, input, append(opts, model.WithTools(tools))...) + return w.cm.Stream(ctx, w.insertReminder(input), append(opts, toolsOpts...)...) +} + +func (w *wrapper) resolveTools(ctx context.Context, input []*schema.Message) ([]model.Option, error) { + if w.useModelToolSearch { + // Model handles tool search natively; remove all dynamic tools from the list. + return calculateTools(ctx, w.allTools, w.dynamicToolInfos, nil, w.useModelToolSearch) + } + return calculateTools(ctx, w.allTools, w.dynamicToolInfos, input, w.useModelToolSearch) +} + +func (w *wrapper) insertReminder(input []*schema.Message) []*schema.Message { + inserted := false + ret := make([]*schema.Message, 0, len(input)+1) + for _, m := range input { + if m.Role != schema.System && !inserted { + inserted = true + ret = append(ret, schema.UserMessage(w.reminder)) + } + ret = append(ret, m) + } + if !inserted { + ret = append(ret, schema.UserMessage(w.reminder)) + } + return ret +} + +func newToolSearchTool(tools map[string]*schema.ToolInfo, useModelToolSearch bool) tool.BaseTool { + if useModelToolSearch { + return &modelToolSearchTool{tools: tools} + } + return &toolSearchTool{tools: tools} } -func newToolSearchTool(toolNames []string) *toolSearchTool { - return &toolSearchTool{toolNames: toolNames} +type toolSearchArgs struct { + Query string `json:"query"` + MaxResults *int `json:"max_results,omitempty"` +} + +type toolSearchResult struct { + Matches []string `json:"matches"` } type toolSearchTool struct { - toolNames []string + tools map[string]*schema.ToolInfo +} + +func (t *toolSearchTool) Info(ctx context.Context) (*schema.ToolInfo, error) { + return getToolSearchToolInfo(), nil +} + +func (t *toolSearchTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { + matches, err := search(argumentsInJSON, t.tools) + if err != nil { + return "", err + } + result := &toolSearchResult{} + for _, m := range matches { + result.Matches = append(result.Matches, m.Name) + } + b, err := json.Marshal(result) + if err != nil { + return "", fmt.Errorf("failed to marshal tool search result: %w", err) + } + return string(b), nil +} + +type modelToolSearchTool struct { + tools map[string]*schema.ToolInfo +} + +func (t *modelToolSearchTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return getToolSearchToolInfo(), nil +} + +func (t *modelToolSearchTool) InvokableRun(_ context.Context, argumentsInJSON *schema.ToolArgument, _ ...tool.Option) (*schema.ToolResult, error) { + ret, err := search(argumentsInJSON.Text, t.tools) + if err != nil { + return nil, err + } + + return &schema.ToolResult{Parts: []schema.ToolOutputPart{ + { + Type: schema.ToolPartTypeToolSearchResult, + ToolSearchResult: &schema.ToolSearchResult{ + Tools: ret, + }, + }, + }}, nil } const ( toolSearchToolName = "tool_search" + defaultMaxResults = 5 ) -func (t *toolSearchTool) Info(ctx context.Context) (*schema.ToolInfo, error) { +func getToolSearchToolInfo() *schema.ToolInfo { return &schema.ToolInfo{ - Name: "tool_search", - Desc: "Search for tools using a regex pattern that matches tool names. Returns a list of matching tool names. Use this when you need a tool but don't have it available yet.", + Name: toolSearchToolName, + Desc: internal.SelectPrompt(internal.I18nPrompts{ + English: toolDescription, + Chinese: toolDescriptionChinese, + }), ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ - "regex_pattern": { + "query": { Type: schema.String, - Desc: "A regex pattern to match tool names against.", + Desc: "Query to find deferred tools. Use \"select:\" for direct selection, or keywords to search.", Required: true, }, + "max_results": { + Type: schema.Integer, + Desc: "Maximum number of results to return (default: 5)", + Required: false, + }, }), - }, nil + } } -type toolSearchArgs struct { - RegexPattern string `json:"regex_pattern"` +func search(argumentsInJSON string, tools map[string]*schema.ToolInfo) ([]*schema.ToolInfo, error) { + var args toolSearchArgs + if err := json.Unmarshal([]byte(argumentsInJSON), &args); err != nil { + return nil, fmt.Errorf("failed to unmarshal tool search arguments: %w", err) + } + + query := strings.TrimSpace(args.Query) + if query == "" { + return nil, fmt.Errorf("query is required") + } + + maxResults := defaultMaxResults + if args.MaxResults != nil && *args.MaxResults > 0 { + maxResults = *args.MaxResults + } + + var matches []string + + // Direct selection mode: select:tool1,tool2 + // max_results is intentionally not applied here because the model has + // already specified the exact tools it wants by name. + if strings.HasPrefix(query, "select:") { + names := strings.Split(strings.TrimPrefix(query, "select:"), ",") + toolSet := make(map[string]bool, len(tools)) + for name := range tools { + toolSet[name] = true + } + for _, name := range names { + name = strings.TrimSpace(name) + if name != "" && toolSet[name] { + matches = append(matches, name) + } + } + } else { + matches = keywordSearch(query, maxResults, tools) + } + + ret := make([]*schema.ToolInfo, 0, len(matches)) + for _, name := range matches { + ti, ok := tools[name] + if !ok { + continue + } + ret = append(ret, ti) + } + return ret, nil } -type toolSearchResult struct { - SelectedTools []string `json:"selectedTools"` +func intMax(a, b int) int { + if a > b { + return a + } + return b } -func (t *toolSearchTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { - var args toolSearchArgs - if err := json.Unmarshal([]byte(argumentsInJSON), &args); err != nil { - return "", fmt.Errorf("failed to unmarshal tool search arguments: %w", err) +func intMin(a, b int) int { + if a < b { + return a } + return b +} + +// scoredTool pairs a tool name with its search score. +type scoredTool struct { + name string + score int +} - if args.RegexPattern == "" { - return "", fmt.Errorf("regex_pattern is required") +// keywordSearch scores all tools against the query keywords and returns the top N. +func keywordSearch(query string, maxResults int, tools map[string]*schema.ToolInfo) []string { + keywords := parseKeywords(query) + if len(keywords) == 0 { + return nil } - re, err := regexp.Compile(args.RegexPattern) - if err != nil { - return "", fmt.Errorf("invalid regex pattern: %w", err) + var scored []scoredTool + + for name, tm := range tools { + nameParts := splitToolName(name) + nameLower := strings.ToLower(name) + descLower := strings.ToLower(tm.Desc) + + totalScore := 0 + allRequiredFound := true + + for _, kw := range keywords { + kwLower := strings.ToLower(kw.word) + kwScore := 0 + + // Score against name parts + for _, part := range nameParts { + partLower := strings.ToLower(part) + if partLower == kwLower { + kwScore = intMax(kwScore, 10) + } else if strings.Contains(partLower, kwLower) { + kwScore = intMax(kwScore, 5) + } + } + + // Score against full name + if strings.Contains(nameLower, kwLower) { + kwScore = intMax(kwScore, 3) + } + + // Score against description (substring match) + if descLower != "" && strings.Contains(descLower, kwLower) { + kwScore = intMax(kwScore, 2) + } + + if kw.required && kwScore == 0 { + allRequiredFound = false + break + } + + totalScore += kwScore + } + + if !allRequiredFound { + continue + } + + if totalScore > 0 { + scored = append(scored, scoredTool{name: name, score: totalScore}) + } } - var matchedTools []string - for _, name := range t.toolNames { - if re.MatchString(name) { - matchedTools = append(matchedTools, name) + // Sort by score descending, then by name for stability + sort.Slice(scored, func(i, j int) bool { + if scored[i].score != scored[j].score { + return scored[i].score > scored[j].score } + return scored[i].name < scored[j].name + }) + + results := make([]string, 0, intMin(maxResults, len(scored))) + for i := 0; i < len(scored) && i < maxResults; i++ { + results = append(results, scored[i].name) } + return results +} + +// keyword represents a parsed search keyword. +type keyword struct { + word string + required bool +} - result := toolSearchResult{ - SelectedTools: matchedTools, +// parseKeywords splits a query string into keywords, handling the '+' required prefix. +func parseKeywords(query string) (keywords []keyword) { + parts := strings.Fields(query) + for _, p := range parts { + if strings.HasPrefix(p, "+") { + word := strings.TrimPrefix(p, "+") + if word != "" { + keywords = append(keywords, keyword{word: word, required: true}) + } + } else if p != "" { + keywords = append(keywords, keyword{word: p, required: false}) + } } + return +} - output, err := json.Marshal(result) - if err != nil { - return "", fmt.Errorf("failed to marshal result: %w", err) +// splitToolName splits a tool name into parts by underscores, double underscores (MCP separator), +// and camelCase boundaries. +func splitToolName(name string) []string { + // First split by double underscore (MCP server__tool separator) + segments := strings.Split(name, "__") + + var parts []string + for _, seg := range segments { + // Split each segment by single underscore + underscoreParts := strings.Split(seg, "_") + for _, up := range underscoreParts { + if up == "" { + continue + } + // Further split by camelCase + camelParts := splitCamelCase(up) + parts = append(parts, camelParts...) + } + } + return parts +} + +// splitCamelCase splits a camelCase or PascalCase string into its constituent words. +func splitCamelCase(s string) []string { + if s == "" { + return nil } - return string(output), nil + var parts []string + runes := []rune(s) + start := 0 + + for i := 1; i < len(runes); i++ { + if unicode.IsUpper(runes[i]) { + if unicode.IsLower(runes[i-1]) { + parts = append(parts, string(runes[start:i])) + start = i + } else if i+1 < len(runes) && unicode.IsLower(runes[i+1]) { + parts = append(parts, string(runes[start:i])) + start = i + } + } + } + parts = append(parts, string(runes[start:])) + + return parts } -func getToolNames(ctx context.Context, tools []tool.BaseTool) ([]string, error) { +// getToolNames extracts just tool names from a slice of BaseTools (used by calculateTools). +func getToolNames(tools []*schema.ToolInfo) []string { ret := make([]string, 0, len(tools)) for _, t := range tools { - info, err := t.Info(ctx) - if err != nil { - return nil, err - } - ret = append(ret, info.Name) + ret = append(ret, t.Name) } - return ret, nil + return ret } -func extractSelectedTools(ctx context.Context, messages []*schema.Message) ([]string, error) { +func extractSelectedTools(_ context.Context, messages []*schema.Message) ([]string, error) { var selectedTools []string for _, message := range messages { if message.Role != schema.Tool || message.ToolName != toolSearchToolName { @@ -206,7 +523,7 @@ func extractSelectedTools(ctx context.Context, messages []*schema.Message) ([]st if err != nil { return nil, fmt.Errorf("failed to unmarshal tool search tool result: %w", err) } - selectedTools = append(selectedTools, result.SelectedTools...) + selectedTools = append(selectedTools, result.Matches...) } return selectedTools, nil } @@ -226,22 +543,30 @@ func invertSelect[T comparable](all []T, selected []T) map[T]struct{} { return result } -func removeTools(ctx context.Context, all []*schema.ToolInfo, dynamicTools []tool.BaseTool, messages []*schema.Message) ([]*schema.ToolInfo, error) { - selectedToolNames, err := extractSelectedTools(ctx, messages) - if err != nil { - return nil, err +func calculateTools(ctx context.Context, all []*schema.ToolInfo, dynamicTools []*schema.ToolInfo, messages []*schema.Message, useModelToolSearch bool) ([]model.Option, error) { + var err error + var ret []model.Option + var selectedToolNames []string + if !useModelToolSearch { + selectedToolNames, err = extractSelectedTools(ctx, messages) + if err != nil { + return nil, err + } } - dynamicToolNames, err := getToolNames(ctx, dynamicTools) - if err != nil { - return nil, err + dynamicToolNames := getToolNames(dynamicTools) + if useModelToolSearch { + // if useModelToolSearch, register tool search tool by WithToolSearchTool + dynamicToolNames = append(dynamicToolNames, toolSearchToolName) + ret = append(ret, model.WithToolSearchTool(getToolSearchToolInfo())) } removeMap := invertSelect(dynamicToolNames, selectedToolNames) - ret := make([]*schema.ToolInfo, 0, len(all)-len(dynamicTools)) + tools := make([]*schema.ToolInfo, 0, len(all)-len(dynamicTools)) for _, info := range all { if _, ok := removeMap[info.Name]; ok { continue } - ret = append(ret, info) + tools = append(tools, info) } + ret = append(ret, model.WithTools(tools)) return ret, nil } diff --git a/adk/middlewares/dynamictool/toolsearch/toolsearch_test.go b/adk/middlewares/dynamictool/toolsearch/toolsearch_test.go index 4b249b9be..20cee35da 100644 --- a/adk/middlewares/dynamictool/toolsearch/toolsearch_test.go +++ b/adk/middlewares/dynamictool/toolsearch/toolsearch_test.go @@ -19,6 +19,10 @@ package toolsearch import ( "context" "encoding/json" + "fmt" + "sort" + "strings" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -27,464 +31,569 @@ import ( "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) -type mockTool struct { - name string - desc string -} +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- -func (m *mockTool) Info(ctx context.Context) (*schema.ToolInfo, error) { - return &schema.ToolInfo{ - Name: m.name, - Desc: m.desc, - }, nil +func makeToolMap(tools ...*schema.ToolInfo) map[string]*schema.ToolInfo { + m := make(map[string]*schema.ToolInfo, len(tools)) + for _, t := range tools { + m[t.Name] = t + } + return m } -func newMockTool(name, desc string) *mockTool { - return &mockTool{name: name, desc: desc} +func ti(name, desc string) *schema.ToolInfo { + return &schema.ToolInfo{Name: name, Desc: desc} } -func TestNew(t *testing.T) { - ctx := context.Background() +func toolNames(infos []*schema.ToolInfo) []string { + names := make([]string, len(infos)) + for i, info := range infos { + names[i] = info.Name + } + sort.Strings(names) + return names +} - t.Run("nil config returns error", func(t *testing.T) { - m, err := New(ctx, nil) - assert.Nil(t, m) - assert.Error(t, err) - assert.Contains(t, err.Error(), "config is required") - }) +func searchJSON(query string, maxResults *int) string { + args := toolSearchArgs{Query: query, MaxResults: maxResults} + b, _ := json.Marshal(args) + return string(b) +} - t.Run("empty tools returns error", func(t *testing.T) { - m, err := New(ctx, &Config{DynamicTools: []tool.BaseTool{}}) - assert.Nil(t, m) - assert.Error(t, err) - assert.Contains(t, err.Error(), "tools is required") - }) +func intPtr(v int) *int { return &v } + +// --------------------------------------------------------------------------- +// TestSearch — unit tests for the search() function +// --------------------------------------------------------------------------- + +func TestSearch(t *testing.T) { + tools := makeToolMap( + ti("get_weather", "Get current weather for a city"), + ti("search_flights", "Search available flights"), + ti("mcp__slack__send_message", "Send a message to Slack channel"), + ti("mcp__slack__read_channel", "Read messages from Slack channel"), + ti("create_calendar_event", "Create a new calendar event"), + ti("NotebookEdit", "Edit Jupyter notebook cells"), + ) + + tests := []struct { + name string + json string + wantNames []string // sorted; nil means expect empty + wantErr bool + }{ + { + name: "keyword exact name part match", + json: searchJSON("weather", nil), + wantNames: []string{"get_weather"}, + }, + { + name: "keyword matches multiple tools", + json: searchJSON("slack", nil), + wantNames: []string{"mcp__slack__read_channel", "mcp__slack__send_message"}, + }, + { + name: "multi-word ranking - send_message ranked first", + json: searchJSON("send message", nil), + wantNames: []string{"mcp__slack__send_message"}, // check first element only + }, + { + name: "required keyword filters to slack only", + json: searchJSON("+slack send", nil), + wantNames: []string{"mcp__slack__read_channel", "mcp__slack__send_message"}, + }, + { + name: "required keyword no match", + json: searchJSON("+github send", nil), + wantNames: nil, + }, + { + name: "direct select single", + json: searchJSON("select:get_weather", nil), + wantNames: []string{"get_weather"}, + }, + { + name: "direct select multiple", + json: searchJSON("select:get_weather,NotebookEdit", nil), + wantNames: []string{"NotebookEdit", "get_weather"}, + }, + { + name: "direct select nonexistent", + json: searchJSON("select:nonexistent", nil), + wantNames: nil, + }, + { + name: "max_results limits output", + json: searchJSON("slack", intPtr(1)), + wantNames: []string{"mcp__slack__read_channel"}, // just check length below + }, + { + name: "camelCase split matches notebook", + json: searchJSON("notebook", nil), + wantNames: []string{"NotebookEdit"}, + }, + { + name: "empty query returns error", + json: searchJSON("", nil), + wantErr: true, + }, + { + name: "description match - jupyter", + json: searchJSON("jupyter", nil), + wantNames: []string{"NotebookEdit"}, + }, + } - t.Run("valid config returns middleware", func(t *testing.T) { - tools := []tool.BaseTool{ - newMockTool("tool1", "desc1"), - newMockTool("tool2", "desc2"), - } - m, err := New(ctx, &Config{DynamicTools: tools}) - assert.NoError(t, err) - assert.NotNil(t, m) - }) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := search(tt.json, tools) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + + // special case: max_results limit + if tt.name == "max_results limits output" { + assert.Len(t, got, 1) + return + } + + // special case: ranking — just check first element + if tt.name == "multi-word ranking - send_message ranked first" { + require.NotEmpty(t, got) + assert.Equal(t, "mcp__slack__send_message", got[0].Name) + return + } + + gotNames := toolNames(got) + if tt.wantNames == nil { + assert.Empty(t, gotNames) + } else { + assert.Equal(t, tt.wantNames, gotNames) + } + }) + } } -func TestMiddleware_BeforeAgent(t *testing.T) { - ctx := context.Background() +// --------------------------------------------------------------------------- +// TestMiddlewareFlow — integration test for UseModelToolSearch=false +// --------------------------------------------------------------------------- - t.Run("nil runCtx returns nil", func(t *testing.T) { - tools := []tool.BaseTool{newMockTool("tool1", "desc1")} - m, err := New(ctx, &Config{DynamicTools: tools}) - require.NoError(t, err) +// simpleTool is a minimal InvokableTool for testing. +type simpleTool struct { + name string + desc string + called bool + mu sync.Mutex +} - newCtx, newRunCtx, err := m.BeforeAgent(ctx, nil) - assert.NoError(t, err) - assert.Equal(t, ctx, newCtx) - assert.Nil(t, newRunCtx) - }) +func (s *simpleTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: s.name, + Desc: s.desc, + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Type: schema.String, Desc: "input", Required: true}, + }), + }, nil +} - t.Run("adds tool_search and dynamic tools", func(t *testing.T) { - tools := []tool.BaseTool{ - newMockTool("tool1", "desc1"), - newMockTool("tool2", "desc2"), - } - m, err := New(ctx, &Config{DynamicTools: tools}) - require.NoError(t, err) +func (s *simpleTool) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) { + s.mu.Lock() + s.called = true + s.mu.Unlock() + return `{"result":"ok"}`, nil +} - middleware := m.(*middleware) - runCtx := &adk.ChatModelAgentContext{ - Tools: []tool.BaseTool{}, - } +func (s *simpleTool) wasCalled() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.called +} - _, newRunCtx, err := middleware.BeforeAgent(ctx, runCtx) - assert.NoError(t, err) - assert.NotNil(t, newRunCtx) - assert.Len(t, newRunCtx.Tools, 3) - }) +// mockChatModel implements model.ToolCallingChatModel. +// It drives a 3-turn conversation: +// +// Turn 1: call tool_search with select:dynamic_tool_a +// Turn 2: call dynamic_tool_a +// Turn 3: return final text +type mockChatModel struct { + mu sync.Mutex + generateCall int + // toolsPerCall records the tool names passed via model.WithTools for each Generate call. + toolsPerCall [][]string } -func TestToolSearchTool_Info(t *testing.T) { - ctx := context.Background() - toolNames := []string{"tool1", "tool2", "tool3"} - tst := newToolSearchTool(toolNames) - - info, err := tst.Info(ctx) - assert.NoError(t, err) - assert.Equal(t, "tool_search", info.Name) - assert.Contains(t, info.Desc, "regex pattern") - assert.NotNil(t, info.ParamsOneOf) +func (m *mockChatModel) Generate(_ context.Context, _ []*schema.Message, opts ...model.Option) (*schema.Message, error) { + options := model.GetCommonOptions(nil, opts...) + var names []string + for _, t := range options.Tools { + names = append(names, t.Name) + } + sort.Strings(names) + + m.mu.Lock() + m.generateCall++ + call := m.generateCall + m.toolsPerCall = append(m.toolsPerCall, names) + m.mu.Unlock() + + switch call { + case 1: + // Ask tool_search to select dynamic_tool_a + return schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "tc1", + Function: schema.FunctionCall{ + Name: toolSearchToolName, + Arguments: `{"query":"select:dynamic_tool_a","max_results":5}`, + }, + }, + }), nil + case 2: + // Call dynamic_tool_a + return schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "tc2", + Function: schema.FunctionCall{ + Name: "dynamic_tool_a", + Arguments: `{"input":"hello"}`, + }, + }, + }), nil + default: + // Final response + return schema.AssistantMessage("done", nil), nil + } } -func TestToolSearchTool_InvokableRun(t *testing.T) { - ctx := context.Background() - toolNames := []string{"get_weather", "get_time", "search_web", "calculate_sum"} - tst := newToolSearchTool(toolNames) +func (m *mockChatModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, fmt.Errorf("not implemented") +} - t.Run("empty regex pattern returns error", func(t *testing.T) { - args := `{"regex_pattern": ""}` - result, err := tst.InvokableRun(ctx, args) - assert.Error(t, err) - assert.Contains(t, err.Error(), "regex_pattern is required") - assert.Empty(t, result) - }) +func (m *mockChatModel) WithTools(_ []*schema.ToolInfo) (model.ToolCallingChatModel, error) { + return m, nil +} - t.Run("invalid json returns error", func(t *testing.T) { - args := `{invalid json}` - result, err := tst.InvokableRun(ctx, args) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to unmarshal") - assert.Empty(t, result) - }) +func (m *mockChatModel) getToolsPerCall() [][]string { + m.mu.Lock() + defer m.mu.Unlock() + ret := make([][]string, len(m.toolsPerCall)) + copy(ret, m.toolsPerCall) + return ret +} - t.Run("invalid regex returns error", func(t *testing.T) { - args := `{"regex_pattern": "[invalid"}` - result, err := tst.InvokableRun(ctx, args) - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid regex pattern") - assert.Empty(t, result) - }) +func TestMiddlewareFlow(t *testing.T) { + ctx := context.Background() - t.Run("matches tools with prefix pattern", func(t *testing.T) { - args := `{"regex_pattern": "^get_"}` - result, err := tst.InvokableRun(ctx, args) - assert.NoError(t, err) + dynamicA := &simpleTool{name: "dynamic_tool_a", desc: "Dynamic tool A"} + dynamicB := &simpleTool{name: "dynamic_tool_b", desc: "Dynamic tool B"} + staticTool := &simpleTool{name: "static_tool", desc: "Static tool"} - var res toolSearchResult - err = json.Unmarshal([]byte(result), &res) - assert.NoError(t, err) - assert.ElementsMatch(t, []string{"get_weather", "get_time"}, res.SelectedTools) + mw, err := New(ctx, &Config{ + DynamicTools: []tool.BaseTool{dynamicA, dynamicB}, + UseModelToolSearch: false, }) - - t.Run("matches tools with suffix pattern", func(t *testing.T) { - args := `{"regex_pattern": "_sum$"}` - result, err := tst.InvokableRun(ctx, args) - assert.NoError(t, err) - - var res toolSearchResult - err = json.Unmarshal([]byte(result), &res) - assert.NoError(t, err) - assert.ElementsMatch(t, []string{"calculate_sum"}, res.SelectedTools) + require.NoError(t, err) + + cm := &mockChatModel{} + + agent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ + Name: "test_agent", + Description: "test", + Instruction: "you are a test agent", + Model: cm, + ToolsConfig: adk.ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{staticTool}, + }, + }, + Handlers: []adk.ChatModelAgentMiddleware{mw}, }) + require.NoError(t, err) - t.Run("matches all tools with wildcard", func(t *testing.T) { - args := `{"regex_pattern": ".*"}` - result, err := tst.InvokableRun(ctx, args) - assert.NoError(t, err) + input := &adk.AgentInput{ + Messages: []adk.Message{schema.UserMessage("test")}, + } + iter := agent.Run(ctx, input) - var res toolSearchResult - err = json.Unmarshal([]byte(result), &res) - assert.NoError(t, err) - assert.ElementsMatch(t, toolNames, res.SelectedTools) - }) + var events []*adk.AgentEvent + for { + ev, ok := iter.Next() + if !ok { + break + } + events = append(events, ev) + } - t.Run("no matches returns empty list", func(t *testing.T) { - args := `{"regex_pattern": "^nonexistent_"}` - result, err := tst.InvokableRun(ctx, args) - assert.NoError(t, err) + // Verify no error event. + for _, ev := range events { + if ev.Err != nil { + t.Fatalf("unexpected error event: %v", ev.Err) + } + } - var res toolSearchResult - err = json.Unmarshal([]byte(result), &res) - assert.NoError(t, err) - assert.Empty(t, res.SelectedTools) - }) + // Verify final output is "done". + lastEvent := events[len(events)-1] + require.NotNil(t, lastEvent.Output) + require.NotNil(t, lastEvent.Output.MessageOutput) + assert.Equal(t, "done", lastEvent.Output.MessageOutput.Message.Content) + + // Verify dynamic_tool_a was actually called. + assert.True(t, dynamicA.wasCalled(), "dynamic_tool_a should have been called") + assert.False(t, dynamicB.wasCalled(), "dynamic_tool_b should not have been called") + + // Verify tool lists per Generate call. + toolsPerCall := cm.getToolsPerCall() + require.Len(t, toolsPerCall, 3, "expected 3 Generate calls") + + // Call 1: tool_search + static_tool; dynamic tools are hidden. + assert.Contains(t, toolsPerCall[0], "tool_search") + assert.Contains(t, toolsPerCall[0], "static_tool") + assert.NotContains(t, toolsPerCall[0], "dynamic_tool_a") + assert.NotContains(t, toolsPerCall[0], "dynamic_tool_b") + + // Call 2: after selecting dynamic_tool_a, it becomes visible. + assert.Contains(t, toolsPerCall[1], "tool_search") + assert.Contains(t, toolsPerCall[1], "static_tool") + assert.Contains(t, toolsPerCall[1], "dynamic_tool_a") + assert.NotContains(t, toolsPerCall[1], "dynamic_tool_b") + + // Call 3: same as call 2. + assert.Contains(t, toolsPerCall[2], "tool_search") + assert.Contains(t, toolsPerCall[2], "static_tool") + assert.Contains(t, toolsPerCall[2], "dynamic_tool_a") + assert.NotContains(t, toolsPerCall[2], "dynamic_tool_b") + + // Verify reminder is present in messages (checked via tool list — the wrapper inserts it). + // The model received messages, and the reminder contains "". + // We indirectly verify this by checking that the middleware ran without error and the + // 3-turn flow completed successfully, which requires the tool_search tool to work. + + // Additional: verify that the reminder contains the dynamic tool names. + mwImpl := mw.(*middleware) + assert.True(t, strings.Contains(mwImpl.sr, "dynamic_tool_a")) + assert.True(t, strings.Contains(mwImpl.sr, "dynamic_tool_b")) + assert.True(t, strings.Contains(mwImpl.sr, "")) } -func TestGetToolNames(t *testing.T) { +// --------------------------------------------------------------------------- +// TestNew — error paths for New() +// --------------------------------------------------------------------------- + +func TestNew(t *testing.T) { ctx := context.Background() - t.Run("returns tool names", func(t *testing.T) { - tools := []tool.BaseTool{ - newMockTool("tool1", "desc1"), - newMockTool("tool2", "desc2"), - newMockTool("tool3", "desc3"), - } - names, err := getToolNames(ctx, tools) - assert.NoError(t, err) - assert.Equal(t, []string{"tool1", "tool2", "tool3"}, names) + t.Run("nil config", func(t *testing.T) { + _, err := New(ctx, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "config is required") }) - t.Run("empty tools returns empty slice", func(t *testing.T) { - names, err := getToolNames(ctx, []tool.BaseTool{}) - assert.NoError(t, err) - assert.Empty(t, names) + t.Run("empty DynamicTools", func(t *testing.T) { + _, err := New(ctx, &Config{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "tools is required") }) -} - -func TestExtractSelectedTools(t *testing.T) { - ctx := context.Background() - - t.Run("extracts selected tools from messages", func(t *testing.T) { - result := toolSearchResult{SelectedTools: []string{"tool1", "tool2"}} - resultJSON, _ := json.Marshal(result) - messages := []*schema.Message{ - schema.UserMessage("hello"), - {Role: schema.Tool, ToolName: toolSearchToolName, Content: string(resultJSON)}, - } - - selected, err := extractSelectedTools(ctx, messages) - assert.NoError(t, err) - assert.ElementsMatch(t, []string{"tool1", "tool2"}, selected) + t.Run("success", func(t *testing.T) { + st := &simpleTool{name: "t1", desc: "tool 1"} + mw, err := New(ctx, &Config{DynamicTools: []tool.BaseTool{st}}) + require.NoError(t, err) + assert.NotNil(t, mw) }) +} - t.Run("handles multiple tool_search results", func(t *testing.T) { - result1 := toolSearchResult{SelectedTools: []string{"tool1"}} - result1JSON, _ := json.Marshal(result1) - result2 := toolSearchResult{SelectedTools: []string{"tool2", "tool3"}} - result2JSON, _ := json.Marshal(result2) +// --------------------------------------------------------------------------- +// TestSplitCamelCase +// --------------------------------------------------------------------------- + +func TestSplitCamelCase(t *testing.T) { + tests := []struct { + input string + want []string + }{ + {"", nil}, + {"hello", []string{"hello"}}, + {"NotebookEdit", []string{"Notebook", "Edit"}}, + {"camelCase", []string{"camel", "Case"}}, + {"HTMLParser", []string{"HTML", "Parser"}}, + {"getURL", []string{"get", "URL"}}, + {"A", []string{"A"}}, + {"AB", []string{"AB"}}, + {"HTTP", []string{"HTTP"}}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := splitCamelCase(tt.input) + assert.Equal(t, tt.want, got) + }) + } +} - messages := []*schema.Message{ - {Role: schema.Tool, ToolName: toolSearchToolName, Content: string(result1JSON)}, - schema.UserMessage("continue"), - {Role: schema.Tool, ToolName: toolSearchToolName, Content: string(result2JSON)}, - } +// --------------------------------------------------------------------------- +// TestInsertReminder +// --------------------------------------------------------------------------- - selected, err := extractSelectedTools(ctx, messages) - assert.NoError(t, err) - assert.ElementsMatch(t, []string{"tool1", "tool2", "tool3"}, selected) - }) +func TestInsertReminder(t *testing.T) { + w := &wrapper{reminder: ""} - t.Run("ignores non-tool_search messages", func(t *testing.T) { - messages := []*schema.Message{ - schema.UserMessage("hello"), - {Role: schema.Tool, ToolName: "other_tool", Content: "some content"}, - {Role: schema.Assistant, Content: "response"}, + t.Run("normal: system then user", func(t *testing.T) { + input := []*schema.Message{ + {Role: schema.System, Content: "sys"}, + {Role: schema.User, Content: "hi"}, } - - selected, err := extractSelectedTools(ctx, messages) - assert.NoError(t, err) - assert.Empty(t, selected) + got := w.insertReminder(input) + require.Len(t, got, 3) + assert.Equal(t, schema.System, got[0].Role) + assert.Equal(t, schema.User, got[1].Role) + assert.Equal(t, "", got[1].Content) + assert.Equal(t, schema.User, got[2].Role) + assert.Equal(t, "hi", got[2].Content) }) - t.Run("returns error for invalid json", func(t *testing.T) { - messages := []*schema.Message{ - {Role: schema.Tool, ToolName: toolSearchToolName, Content: "invalid json"}, + t.Run("all system messages", func(t *testing.T) { + input := []*schema.Message{ + {Role: schema.System, Content: "sys1"}, + {Role: schema.System, Content: "sys2"}, } - - selected, err := extractSelectedTools(ctx, messages) - assert.Error(t, err) - assert.Nil(t, selected) + got := w.insertReminder(input) + require.Len(t, got, 3) + // Reminder appended at the end since no non-system message found during iteration. + assert.Equal(t, schema.System, got[0].Role) + assert.Equal(t, schema.System, got[1].Role) + assert.Equal(t, "", got[2].Content) }) -} -func TestInvertSelect(t *testing.T) { - t.Run("returns items not in selected", func(t *testing.T) { - all := []string{"a", "b", "c", "d"} - selected := []string{"b", "d"} - - result := invertSelect(all, selected) - assert.Len(t, result, 2) - _, hasA := result["a"] - _, hasC := result["c"] - assert.True(t, hasA) - assert.True(t, hasC) + t.Run("empty input", func(t *testing.T) { + got := w.insertReminder(nil) + require.Len(t, got, 1) + assert.Equal(t, "", got[0].Content) }) - t.Run("empty selected returns all", func(t *testing.T) { - all := []string{"a", "b", "c"} - selected := []string{} - - result := invertSelect(all, selected) - assert.Len(t, result, 3) - }) - - t.Run("all selected returns empty", func(t *testing.T) { - all := []string{"a", "b"} - selected := []string{"a", "b"} - - result := invertSelect(all, selected) - assert.Empty(t, result) - }) - - t.Run("works with integers", func(t *testing.T) { - all := []int{1, 2, 3, 4, 5} - selected := []int{2, 4} - - result := invertSelect(all, selected) - assert.Len(t, result, 3) - _, has1 := result[1] - _, has3 := result[3] - _, has5 := result[5] - assert.True(t, has1) - assert.True(t, has3) - assert.True(t, has5) + t.Run("no system messages", func(t *testing.T) { + input := []*schema.Message{ + {Role: schema.User, Content: "hi"}, + {Role: schema.Assistant, Content: "hello"}, + } + got := w.insertReminder(input) + require.Len(t, got, 3) + // Reminder inserted before the first non-system message. + assert.Equal(t, "", got[0].Content) + assert.Equal(t, "hi", got[1].Content) + assert.Equal(t, "hello", got[2].Content) }) } -func TestRemoveTools(t *testing.T) { - ctx := context.Background() - - t.Run("removes unselected dynamic tools", func(t *testing.T) { - allTools := []*schema.ToolInfo{ - {Name: "static_tool"}, - {Name: "dynamic_tool1"}, - {Name: "dynamic_tool2"}, - {Name: "dynamic_tool3"}, - } +// --------------------------------------------------------------------------- +// TestExtractSelectedTools +// --------------------------------------------------------------------------- - dynamicTools := []tool.BaseTool{ - newMockTool("dynamic_tool1", ""), - newMockTool("dynamic_tool2", ""), - newMockTool("dynamic_tool3", ""), - } +func TestExtractSelectedTools(t *testing.T) { + ctx := context.Background() - result := toolSearchResult{SelectedTools: []string{"dynamic_tool1"}} - resultJSON, _ := json.Marshal(result) + t.Run("accumulates from multiple tool_search results", func(t *testing.T) { messages := []*schema.Message{ - {Role: schema.Tool, ToolName: toolSearchToolName, Content: string(resultJSON)}, - } - - tools, err := removeTools(ctx, allTools, dynamicTools, messages) - assert.NoError(t, err) - assert.Len(t, tools, 2) - - toolNames := make([]string, len(tools)) - for i, t := range tools { - toolNames[i] = t.Name + {Role: schema.Tool, ToolName: toolSearchToolName, Content: `{"matches":["tool_a"]}`}, + {Role: schema.Tool, ToolName: toolSearchToolName, Content: `{"matches":["tool_b","tool_c"]}`}, } - assert.ElementsMatch(t, []string{"static_tool", "dynamic_tool1"}, toolNames) + got, err := extractSelectedTools(ctx, messages) + require.NoError(t, err) + assert.Equal(t, []string{"tool_a", "tool_b", "tool_c"}, got) }) - t.Run("remove all dynamic tools when no tool_search result", func(t *testing.T) { - allTools := []*schema.ToolInfo{ - {Name: "static_tool"}, - {Name: "dynamic_tool1"}, - } - - dynamicTools := []tool.BaseTool{ - newMockTool("dynamic_tool1", ""), - } - + t.Run("ignores non tool_search messages", func(t *testing.T) { messages := []*schema.Message{ - schema.UserMessage("hello"), + {Role: schema.User, Content: "hello"}, + {Role: schema.Tool, ToolName: "other_tool", Content: `{"matches":["should_ignore"]}`}, + {Role: schema.Assistant, Content: "world"}, + {Role: schema.Tool, ToolName: toolSearchToolName, Content: `{"matches":["tool_a"]}`}, } - - tools, err := removeTools(ctx, allTools, dynamicTools, messages) - assert.NoError(t, err) - assert.Len(t, tools, 1) - assert.Equal(t, "static_tool", tools[0].Name) + got, err := extractSelectedTools(ctx, messages) + require.NoError(t, err) + assert.Equal(t, []string{"tool_a"}, got) }) - t.Run("handles empty dynamic tools", func(t *testing.T) { - allTools := []*schema.ToolInfo{ - {Name: "static_tool1"}, - {Name: "static_tool2"}, + t.Run("malformed JSON returns error", func(t *testing.T) { + messages := []*schema.Message{ + {Role: schema.Tool, ToolName: toolSearchToolName, Content: `not json`}, } - - dynamicTools := []tool.BaseTool{} - messages := []*schema.Message{} - - tools, err := removeTools(ctx, allTools, dynamicTools, messages) - assert.NoError(t, err) - assert.Len(t, tools, 2) + _, err := extractSelectedTools(ctx, messages) + assert.Error(t, err) }) -} - -type mockChatModel struct { - generateFunc func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) - streamFunc func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) -} -func (m *mockChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { - if m.generateFunc != nil { - return m.generateFunc(ctx, input, opts...) - } - return &schema.Message{Role: schema.Assistant, Content: "response"}, nil + t.Run("nil messages returns nil", func(t *testing.T) { + got, err := extractSelectedTools(ctx, nil) + require.NoError(t, err) + assert.Nil(t, got) + }) } -func (m *mockChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { - if m.streamFunc != nil { - return m.streamFunc(ctx, input, opts...) - } - return nil, nil -} +// --------------------------------------------------------------------------- +// TestCalculateTools +// --------------------------------------------------------------------------- -func TestWrapper_Generate(t *testing.T) { +func TestCalculateTools(t *testing.T) { ctx := context.Background() - t.Run("filters tools based on tool_search result", func(t *testing.T) { - allTools := []*schema.ToolInfo{ - {Name: "static_tool"}, - {Name: "dynamic_tool1"}, - {Name: "dynamic_tool2"}, - } - - dynamicTools := []tool.BaseTool{ - newMockTool("dynamic_tool1", ""), - newMockTool("dynamic_tool2", ""), - } + staticTool := ti("static_tool", "static") + toolSearchInfo := getToolSearchToolInfo() + dynA := ti("dynamic_a", "A") + dynB := ti("dynamic_b", "B") - result := toolSearchResult{SelectedTools: []string{"dynamic_tool1"}} - resultJSON, _ := json.Marshal(result) + allTools := []*schema.ToolInfo{staticTool, toolSearchInfo, dynA, dynB} + dynamicTools := []*schema.ToolInfo{dynA, dynB} + t.Run("no selection: dynamic tools hidden", func(t *testing.T) { messages := []*schema.Message{ - schema.UserMessage("hello"), - {Role: schema.Tool, ToolName: toolSearchToolName, Content: string(resultJSON)}, + {Role: schema.User, Content: "hello"}, } + opts, err := calculateTools(ctx, allTools, dynamicTools, messages, false) + require.NoError(t, err) - w := &wrapper{ - allTools: allTools, - dynamicTools: dynamicTools, - cm: &mockChatModel{ - generateFunc: func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { - options := model.GetCommonOptions(nil, opts...) - assert.Len(t, options.Tools, 2) - assert.Equal(t, "static_tool", options.Tools[0].Name) - assert.Equal(t, "dynamic_tool1", options.Tools[1].Name) - return nil, nil - }, - }, + options := model.GetCommonOptions(nil, opts...) + var names []string + for _, t := range options.Tools { + names = append(names, t.Name) } - - _, err := w.Generate(ctx, messages) - assert.NoError(t, err) + sort.Strings(names) + assert.Equal(t, []string{"static_tool", "tool_search"}, names) }) -} - -func TestWrapper_Stream(t *testing.T) { - ctx := context.Background() - t.Run("filters tools based on tool_search result", func(t *testing.T) { - allTools := []*schema.ToolInfo{ - {Name: "static_tool"}, - {Name: "dynamic_tool1"}, - {Name: "dynamic_tool2"}, + t.Run("partial selection: selected tool visible", func(t *testing.T) { + messages := []*schema.Message{ + {Role: schema.Tool, ToolName: toolSearchToolName, Content: `{"matches":["dynamic_a"]}`}, } + opts, err := calculateTools(ctx, allTools, dynamicTools, messages, false) + require.NoError(t, err) - dynamicTools := []tool.BaseTool{ - newMockTool("dynamic_tool1", ""), - newMockTool("dynamic_tool2", ""), + options := model.GetCommonOptions(nil, opts...) + var names []string + for _, t := range options.Tools { + names = append(names, t.Name) } + sort.Strings(names) + assert.Equal(t, []string{"dynamic_a", "static_tool", "tool_search"}, names) + }) - result := toolSearchResult{SelectedTools: []string{"dynamic_tool1"}} - resultJSON, _ := json.Marshal(result) - - messages := []*schema.Message{ - schema.UserMessage("hello"), - {Role: schema.Tool, ToolName: toolSearchToolName, Content: string(resultJSON)}, - } + t.Run("useModelToolSearch: dynamic tools and tool_search removed from WithTools", func(t *testing.T) { + opts, err := calculateTools(ctx, allTools, dynamicTools, nil, true) + require.NoError(t, err) - w := &wrapper{ - allTools: allTools, - dynamicTools: dynamicTools, - cm: &mockChatModel{ - streamFunc: func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { - options := model.GetCommonOptions(nil, opts...) - assert.Len(t, options.Tools, 2) - assert.Equal(t, "static_tool", options.Tools[0].Name) - assert.Equal(t, "dynamic_tool1", options.Tools[1].Name) - return nil, nil - }, - }, + options := model.GetCommonOptions(nil, opts...) + var names []string + for _, t := range options.Tools { + names = append(names, t.Name) } - - stream, err := w.Stream(ctx, messages) - assert.NoError(t, err) - assert.Nil(t, stream) + assert.Equal(t, []string{"static_tool"}, names) + // ToolSearchTool should be set. + assert.NotNil(t, options.ToolSearchTool) + assert.Equal(t, toolSearchToolName, options.ToolSearchTool.Name) }) } diff --git a/components/model/option.go b/components/model/option.go index 936b0fbda..2222e14a1 100644 --- a/components/model/option.go +++ b/components/model/option.go @@ -28,6 +28,16 @@ type Options struct { TopP *float32 // Tools is a list of tools the model may call. Tools []*schema.ToolInfo + // DeferredTools is a list of tools to be registered with defer_loading=true + // for the model's built-in (server-side) tool search capability. + // These tools are sent to the model API but not loaded into context upfront — + // only their names and descriptions are visible to the model. The model's + // built-in tool search tool searches through them and loads matching ones + // on demand. + DeferredTools []*schema.ToolInfo + + ToolSearchTool *schema.ToolInfo + // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return a finish reason of "length". MaxTokens *int // Stop is the stop words for the model, which controls the stopping condition of the model. @@ -114,6 +124,33 @@ func WithTools(tools []*schema.ToolInfo) Option { } } +// WithToolSearchTool is the option to register a tool search tool with the model. +// When set, the model uses this tool to discover and load deferred tools on demand. +// Note: The tool search tool should NOT be included in WithTools. +func WithToolSearchTool(tool *schema.ToolInfo) Option { + return Option{ + apply: func(opts *Options) { + opts.ToolSearchTool = tool + }, + } +} + +// WithDeferredTools is the option to set deferred tools for the model's +// built-in (server-side) tool search. These tools are registered with +// defer_loading=true so the model can discover and load them on demand +// via its native tool search capability. +// Note: Deferred tools should NOT be included in WithTools. +func WithDeferredTools(tools []*schema.ToolInfo) Option { + if tools == nil { + tools = []*schema.ToolInfo{} + } + return Option{ + apply: func(opts *Options) { + opts.DeferredTools = tools + }, + } +} + // WithToolChoice sets the tool choice for the model. It also allows for providing a list of // tool names to constrain the model to a specific subset of the available tools. // Only available for ChatModel. diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 95e14c0df..43376c146 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -41,6 +41,7 @@ const ( ContentBlockTypeUserInputAudio ContentBlockType = "user_input_audio" ContentBlockTypeUserInputVideo ContentBlockType = "user_input_video" ContentBlockTypeUserInputFile ContentBlockType = "user_input_file" + ContentBlockTypeToolSearchResult ContentBlockType = "tool_search_result" ContentBlockTypeAssistantGenText ContentBlockType = "assistant_gen_text" ContentBlockTypeAssistantGenImage ContentBlockType = "assistant_gen_image" ContentBlockTypeAssistantGenAudio ContentBlockType = "assistant_gen_audio" @@ -134,6 +135,11 @@ type ContentBlock struct { // FunctionToolResult contains the result returned from a user-defined tool call. FunctionToolResult *FunctionToolResult `json:"function_tool_result,omitempty"` + // ToolSearchFunctionToolResult contains the result of a client-side custom tool search tool call. + // It carries the full definitions of newly discovered tools so that the model can + // recognize which tools have been added and are now available for invocation. + ToolSearchFunctionToolResult *ToolSearchFunctionToolResult `json:"tool_search_function_tool_result,omitempty"` + // ServerToolCall contains the invocation details for a provider built-in tool executed on the model server. ServerToolCall *ServerToolCall `json:"server_tool_call,omitempty"` @@ -300,6 +306,28 @@ type FunctionToolResult struct { Result string `json:"result,omitempty"` } +// ToolSearchFunctionToolResult represents the result of a client-side custom tool search +// function tool call. Unlike a regular FunctionToolResult, this carries a ToolSearchResult +// containing the full definitions of newly discovered tools, so the model can recognize +// which tools have been added and are now available for invocation. +type ToolSearchFunctionToolResult struct { + // CallID is the unique identifier for the tool call. + CallID string `json:"call_id,omitempty"` + + // Name specifies the function tool invoked. + Name string `json:"name"` + + // Result is the function tool result returned by the user + Result *ToolSearchResult `json:"result,omitempty"` +} + +func (t *ToolSearchFunctionToolResult) String() string { + if t.Result != nil { + return t.Result.String() + } + return "" +} + type ServerToolCall struct { // Name specifies the server-side tool invoked. // Supplied by the model server (e.g., `web_search` for OpenAI, `googleSearch` for Gemini). @@ -461,7 +489,7 @@ type assistantGenVariant interface { } type functionToolCallVariant interface { - FunctionToolCall | FunctionToolResult + FunctionToolCall | FunctionToolResult | ToolSearchFunctionToolResult } type serverToolCallVariant interface { @@ -487,6 +515,8 @@ func NewContentBlock[T contentBlockVariant](content *T) *ContentBlock { return &ContentBlock{Type: ContentBlockTypeUserInputVideo, UserInputVideo: b} case *UserInputFile: return &ContentBlock{Type: ContentBlockTypeUserInputFile, UserInputFile: b} + case *ToolSearchFunctionToolResult: + return &ContentBlock{Type: ContentBlockTypeToolSearchResult, ToolSearchFunctionToolResult: b} case *AssistantGenText: return &ContentBlock{Type: ContentBlockTypeAssistantGenText, AssistantGenText: b} case *AssistantGenImage: @@ -1028,6 +1058,11 @@ func concatChunksOfSameContentBlock(blocks []*ContentBlock) (*ContentBlock, erro func(b *ContentBlock) *UserInputFile { return b.UserInputFile }, concatUserInputFiles) + case ContentBlockTypeToolSearchResult: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *ToolSearchFunctionToolResult { return b.ToolSearchFunctionToolResult }, + concatToolSearchFunctionToolResult) + case ContentBlockTypeAssistantGenText: return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *AssistantGenText { return b.AssistantGenText }, @@ -1227,6 +1262,16 @@ func concatUserInputFiles(files []*UserInputFile) (*UserInputFile, error) { return nil, fmt.Errorf("cannot concat multiple user input files") } +func concatToolSearchFunctionToolResult(results []*ToolSearchFunctionToolResult) (*ToolSearchFunctionToolResult, error) { + if len(results) == 0 { + return nil, fmt.Errorf("no tool search results found") + } + if len(results) == 1 { + return results[0], nil + } + return nil, fmt.Errorf("cannot concat multiple tool search results") +} + func concatAssistantGenTexts(texts []*AssistantGenText) (ret *AssistantGenText, err error) { if len(texts) == 0 { return nil, fmt.Errorf("no assistant generated text found") @@ -1772,6 +1817,7 @@ func (m *AgenticMessage) String() string { } // String returns the string representation of ContentBlock. +// nolint func (b *ContentBlock) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf("type: %s\n", b.Type)) @@ -1801,6 +1847,10 @@ func (b *ContentBlock) String() string { if b.UserInputFile != nil { sb.WriteString(b.UserInputFile.String()) } + case ContentBlockTypeToolSearchResult: + if b.ToolSearchFunctionToolResult != nil { + sb.WriteString(b.ToolSearchFunctionToolResult.String()) + } case ContentBlockTypeAssistantGenText: if b.AssistantGenText != nil { sb.WriteString(b.AssistantGenText.String()) diff --git a/schema/agentic_message_test.go b/schema/agentic_message_test.go index e8a1003f5..10639f738 100644 --- a/schema/agentic_message_test.go +++ b/schema/agentic_message_test.go @@ -1592,6 +1592,8 @@ func TestNewContentBlock(t *testing.T) { block = NewContentBlock(v) case *UserInputFile: block = NewContentBlock(v) + case *ToolSearchFunctionToolResult: + block = NewContentBlock(v) case *AssistantGenText: block = NewContentBlock(v) case *AssistantGenImage: diff --git a/schema/message.go b/schema/message.go index 611bcedca..890af48ab 100644 --- a/schema/message.go +++ b/schema/message.go @@ -139,7 +139,6 @@ type ToolCall struct { Type string `json:"type"` // Function is the function call to be made. Function FunctionCall `json:"function"` - // Extra is used to store extra information for the tool call. Extra map[string]any `json:"extra,omitempty"` } @@ -222,6 +221,9 @@ type MessageInputPart struct { // File is the file input of the part, it's used when Type is "file_url". File *MessageInputFile `json:"file,omitempty"` + // ToolSearchResult holds the result of a tool search request, containing the matched tool names and their definitions. + ToolSearchResult *ToolSearchResult `json:"tool_search_result,omitempty"` + // Extra is used to store extra information. Extra map[string]any `json:"extra,omitempty"` } @@ -310,6 +312,9 @@ const ( // ToolPartTypeFile means the part is a file url. ToolPartTypeFile ToolPartType = "file" + + // ToolPartTypeToolSearchResult means the part contains tool search results. + ToolPartTypeToolSearchResult ToolPartType = "tool_search_result" ) // ToolOutputImage represents an image in tool output. @@ -336,6 +341,27 @@ type ToolOutputFile struct { MessagePartCommon } +// ToolSearchResult represents the result of a tool search operation. +// When a model issues a tool search call, the framework searches for matching tools +// and returns the results via this struct. +type ToolSearchResult struct { + // Tools contains the full definitions of matched tools that were not previously + // registered. Their complete definitions are required so that the model can + // understand their parameters and usage. + Tools []*ToolInfo +} + +func (t *ToolSearchResult) String() string { + sb := new(strings.Builder) + sb.WriteString("ToolSearchResult[") + for _, tool := range t.Tools { + sb.WriteString(tool.Name) + sb.WriteString(",") + } + sb.WriteString("]") + return sb.String() +} + // ToolOutputPart represents a part of tool execution output. // It supports streaming scenarios through the Index field for chunk merging. type ToolOutputPart struct { @@ -358,6 +384,9 @@ type ToolOutputPart struct { // File is the file content, used when Type is ToolPartTypeFile. File *ToolOutputFile `json:"file,omitempty"` + // ToolSearchResult holds the tool search results, used when Type is ToolPartTypeToolSearchResult. + ToolSearchResult *ToolSearchResult `json:"tool_search_result,omitempty"` + // Extra is used to store extra information. Extra map[string]any `json:"extra,omitempty"` } @@ -422,6 +451,14 @@ func convToolOutputPartToMessageInputPart(toolPart ToolOutputPart) (MessageInput File: &MessageInputFile{MessagePartCommon: toolPart.File.MessagePartCommon}, Extra: toolPart.Extra, }, nil + case ToolPartTypeToolSearchResult: + if toolPart.ToolSearchResult == nil { + return MessageInputPart{}, fmt.Errorf("tool search result is nil for tool part type %v", toolPart.Type) + } + return MessageInputPart{ + Type: ChatMessagePartTypeToolSearchResult, + ToolSearchResult: toolPart.ToolSearchResult, + }, nil default: return MessageInputPart{}, fmt.Errorf("unknown tool part type: %v", toolPart.Type) } @@ -498,6 +535,9 @@ const ( ChatMessagePartTypeFileURL ChatMessagePartType = "file_url" // ChatMessagePartTypeReasoning means the part is a reasoning block. ChatMessagePartTypeReasoning ChatMessagePartType = "reasoning" + + // ChatMessagePartTypeToolSearchResult means the part contains tool search results. + ChatMessagePartTypeToolSearchResult ChatMessagePartType = "tool_search_result" ) // Deprecated: This struct is deprecated as the MultiContent field is deprecated. diff --git a/schema/tool.go b/schema/tool.go index a49306047..f8a0a743e 100644 --- a/schema/tool.go +++ b/schema/tool.go @@ -17,6 +17,9 @@ package schema import ( + "bytes" + "encoding/gob" + "encoding/json" "sort" "github.com/eino-contrib/jsonschema" @@ -137,6 +140,104 @@ type ToolInfo struct { *ParamsOneOf } +type toolInfoForJSON struct { + Name string `json:"name,omitempty"` + Desc string `json:"desc,omitempty"` + Extra map[string]any `json:"extra,omitempty"` + HasParamsOneOf bool `json:"has_params_one_of,omitempty"` + Params map[string]*ParameterInfo `json:"params,omitempty"` + JSONSchema *jsonschema.Schema `json:"json_schema,omitempty"` +} + +type toolInfoForGob struct { + Name string + Desc string + Extra map[string]any + HasParamsOneOf bool + Params map[string]*ParameterInfo + JSONSchema *string +} + +func (t *ToolInfo) MarshalJSON() ([]byte, error) { + tmp := &toolInfoForJSON{ + Name: t.Name, + Desc: t.Desc, + Extra: t.Extra, + } + if t.ParamsOneOf != nil { + tmp.HasParamsOneOf = true + tmp.Params = t.ParamsOneOf.params + tmp.JSONSchema = t.ParamsOneOf.jsonschema + } + return json.Marshal(tmp) +} + +func (t *ToolInfo) UnmarshalJSON(data []byte) error { + tmp := &toolInfoForJSON{} + if err := json.Unmarshal(data, tmp); err != nil { + return err + } + t.Name = tmp.Name + t.Desc = tmp.Desc + t.Extra = tmp.Extra + if tmp.HasParamsOneOf { + t.ParamsOneOf = &ParamsOneOf{ + params: tmp.Params, + jsonschema: tmp.JSONSchema, + } + } + return nil +} + +func (t *ToolInfo) GobEncode() ([]byte, error) { + tmp := &toolInfoForGob{ + Name: t.Name, + Desc: t.Desc, + Extra: t.Extra, + } + if t.ParamsOneOf != nil { + tmp.HasParamsOneOf = true + tmp.Params = t.ParamsOneOf.params + if t.ParamsOneOf.jsonschema != nil { + b, err := json.Marshal(t.ParamsOneOf.jsonschema) + if err != nil { + return nil, err + } + str := string(b) + tmp.JSONSchema = &str + } + } + buf := new(bytes.Buffer) + if err := gob.NewEncoder(buf).Encode(tmp); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func (t *ToolInfo) GobDecode(b []byte) error { + tmp := &toolInfoForGob{} + if err := gob.NewDecoder(bytes.NewBuffer(b)).Decode(tmp); err != nil { + return err + } + t.Name = tmp.Name + t.Desc = tmp.Desc + t.Extra = tmp.Extra + if !tmp.HasParamsOneOf { + return nil + } + t.ParamsOneOf = &ParamsOneOf{ + params: tmp.Params, + } + if tmp.JSONSchema != nil { + s := &jsonschema.Schema{} + if err := json.Unmarshal([]byte(*tmp.JSONSchema), s); err != nil { + return err + } + t.ParamsOneOf.jsonschema = s + } + return nil +} + // ParameterInfo is the information of a parameter. // It is used to describe the parameters of a tool. type ParameterInfo struct { diff --git a/schema/tool_test.go b/schema/tool_test.go index 97af29be2..e8f95c364 100644 --- a/schema/tool_test.go +++ b/schema/tool_test.go @@ -17,6 +17,8 @@ package schema import ( + "bytes" + "encoding/gob" "encoding/json" "testing" @@ -133,3 +135,49 @@ func TestParamsOneOfToJSONSchema(t *testing.T) { }) } + +func TestToolInfoSerialization(t *testing.T) { + ti1 := &ToolInfo{ + ParamsOneOf: NewParamsOneOfByParams(map[string]*ParameterInfo{ + "a": { + Type: String, + Desc: "desc", + }, + }), + } + ti2 := &ToolInfo{ + ParamsOneOf: NewParamsOneOfByJSONSchema(&jsonschema.Schema{ + Type: "string", + }), + } + + // json + b, err := json.Marshal(ti1) + assert.NoError(t, err) + result := &ToolInfo{} + err = json.Unmarshal(b, result) + assert.NoError(t, err) + assert.Equal(t, ti1, result) + b, err = json.Marshal(ti2) + assert.NoError(t, err) + result = &ToolInfo{} + err = json.Unmarshal(b, result) + assert.NoError(t, err) + assert.Equal(t, ti2, result) + + // gob + buf := new(bytes.Buffer) + err = gob.NewEncoder(buf).Encode(ti1) + assert.NoError(t, err) + result = &ToolInfo{} + err = gob.NewDecoder(buf).Decode(result) + assert.NoError(t, err) + assert.Equal(t, ti1, result) + buf = new(bytes.Buffer) + err = gob.NewEncoder(buf).Encode(ti2) + assert.NoError(t, err) + result = &ToolInfo{} + err = gob.NewDecoder(buf).Decode(result) + assert.NoError(t, err) + assert.Equal(t, ti2, result) +} From 303e22ea943fa6b172b98ee0612d732c6dd23e79 Mon Sep 17 00:00:00 2001 From: Born Date: Fri, 10 Apr 2026 16:23:12 +0800 Subject: [PATCH 53/59] fix(adk): propagate missing ToolsNodeConfig fields in ChatModelAgent (#945) - Add ToolAliases to prepareExecContext when building ToolsNodeConfig - Add UnknownToolsHandler, ExecuteSequentially, ToolArgumentsHandler, and ToolAliases to applyBeforeAgent when rebuilding after BeforeAgent handlers modify tools - Add tests covering argument alias remapping, name alias dispatch, alias preservation after handler rebuild, and handler-only tool registration with pre-configured aliases --- adk/chatmodel.go | 9 +- adk/chatmodel_test.go | 317 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 324 insertions(+), 2 deletions(-) diff --git a/adk/chatmodel.go b/adk/chatmodel.go index 83d8ffd4a..abfc55fa0 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -683,8 +683,12 @@ func (a *ChatModelAgent) applyBeforeAgent(ctx context.Context, ec *execContext) runtimeEC := &execContext{ instruction: runCtx.Instruction, toolsNodeConf: compose.ToolsNodeConfig{ - Tools: runCtx.Tools, - ToolCallMiddlewares: cloneSlice(ec.toolsNodeConf.ToolCallMiddlewares), + Tools: runCtx.Tools, + ToolCallMiddlewares: cloneSlice(ec.toolsNodeConf.ToolCallMiddlewares), + UnknownToolsHandler: ec.toolsNodeConf.UnknownToolsHandler, + ExecuteSequentially: ec.toolsNodeConf.ExecuteSequentially, + ToolArgumentsHandler: ec.toolsNodeConf.ToolArgumentsHandler, + ToolAliases: ec.toolsNodeConf.ToolAliases, }, returnDirectly: runCtx.ReturnDirectly, toolUpdated: true, @@ -710,6 +714,7 @@ func (a *ChatModelAgent) prepareExecContext(ctx context.Context) (*execContext, UnknownToolsHandler: a.toolsConfig.UnknownToolsHandler, ExecuteSequentially: a.toolsConfig.ExecuteSequentially, ToolArgumentsHandler: a.toolsConfig.ToolArgumentsHandler, + ToolAliases: a.toolsConfig.ToolAliases, } returnDirectly := copyMap(a.toolsConfig.ReturnDirectly) diff --git a/adk/chatmodel_test.go b/adk/chatmodel_test.go index f3ff6ea05..0edab5a2d 100644 --- a/adk/chatmodel_test.go +++ b/adk/chatmodel_test.go @@ -18,11 +18,13 @@ package adk import ( "context" + "encoding/json" "errors" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "github.com/cloudwego/eino/components/model" @@ -2098,3 +2100,318 @@ func TestNewChatModelAgent_FailoverConfigValidation(t *testing.T) { assert.Contains(t, err.Error(), "ModelFailoverConfig.ShouldFailover") }) } + +// aliasCaptureTool captures the raw arguments JSON received by the tool. +type aliasCaptureTool struct { + name string + params map[string]*schema.ParameterInfo + receivedArgs string +} + +func (t *aliasCaptureTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: t.name + " tool", + ParamsOneOf: schema.NewParamsOneOfByParams(t.params), + }, nil +} + +func (t *aliasCaptureTool) InvokableRun(_ context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) { + t.receivedArgs = argumentsInJSON + return "ok", nil +} + +func TestToolAliasesPropagation(t *testing.T) { + t.Run("prepareExecContext_propagates_ToolAliases", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + + captureTool := &aliasCaptureTool{ + name: "grep", + params: map[string]*schema.ParameterInfo{ + "pattern": {Type: schema.String, Desc: "regex pattern"}, + "path": {Type: schema.String, Desc: "search path"}, + }, + } + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + generateCount := 0 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { + generateCount++ + if generateCount == 1 { + return &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "grep", + Arguments: `{"grep_content": "TODO", "path": "/src"}`, + }, + }, + }, + }, nil + } + return schema.AssistantMessage("done", nil), nil + }).AnyTimes() + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "test", + Instruction: "test", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{captureTool}, + ToolAliases: map[string]compose.ToolAliasConfig{ + "grep": { + ArgumentsAliases: map[string][]string{ + "pattern": {"grep_content"}, + }, + }, + }, + }, + }, + }) + require.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("search for TODOs")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + require.NotEmpty(t, captureTool.receivedArgs, "tool should have been called") + var args map[string]any + err = json.Unmarshal([]byte(captureTool.receivedArgs), &args) + require.NoError(t, err) + assert.Equal(t, "TODO", args["pattern"], "alias 'grep_content' should be remapped to 'pattern'") + assert.NotContains(t, args, "grep_content", "alias key should not be present after remapping") + assert.Equal(t, "/src", args["path"]) + }) + + t.Run("applyBeforeAgent_preserves_ToolAliases_when_handler_modifies_tools", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + + captureTool := &aliasCaptureTool{ + name: "grep", + params: map[string]*schema.ParameterInfo{ + "pattern": {Type: schema.String, Desc: "regex pattern"}, + }, + } + + extraTool := &aliasCaptureTool{ + name: "extra_tool", + params: map[string]*schema.ParameterInfo{ + "input": {Type: schema.String}, + }, + } + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + generateCount := 0 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { + generateCount++ + if generateCount == 1 { + return &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "grep", + Arguments: `{"grep_content": "FIXME"}`, + }, + }, + }, + }, nil + } + return schema.AssistantMessage("done", nil), nil + }).AnyTimes() + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + handler := &testToolsHandler{ + BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, + tools: []tool.BaseTool{extraTool}, + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "test", + Instruction: "test", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{captureTool}, + ToolAliases: map[string]compose.ToolAliasConfig{ + "grep": { + ArgumentsAliases: map[string][]string{ + "pattern": {"grep_content"}, + }, + }, + }, + }, + }, + Handlers: []ChatModelAgentMiddleware{handler}, + }) + require.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("search for FIXMEs")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + require.NotEmpty(t, captureTool.receivedArgs, "tool should have been called") + var args map[string]any + err = json.Unmarshal([]byte(captureTool.receivedArgs), &args) + require.NoError(t, err) + assert.Equal(t, "FIXME", args["pattern"], "alias 'grep_content' should be remapped to 'pattern' even after handler rebuild") + assert.NotContains(t, args, "grep_content", "alias key should not be present after remapping") + }) + + t.Run("name_alias_propagated_through_prepareExecContext", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + + captureTool := &aliasCaptureTool{ + name: "grep", + params: map[string]*schema.ParameterInfo{ + "pattern": {Type: schema.String}, + }, + } + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + generateCount := 0 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { + generateCount++ + if generateCount == 1 { + return &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "search_content", + Arguments: `{"pattern": "TODO"}`, + }, + }, + }, + }, nil + } + return schema.AssistantMessage("done", nil), nil + }).AnyTimes() + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "test", + Instruction: "test", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{captureTool}, + ToolAliases: map[string]compose.ToolAliasConfig{ + "grep": { + NameAliases: []string{"search_content"}, + }, + }, + }, + }, + }) + require.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("search")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + require.NotEmpty(t, captureTool.receivedArgs, "tool should have been called via name alias 'search_content'") + var args map[string]any + err = json.Unmarshal([]byte(captureTool.receivedArgs), &args) + require.NoError(t, err) + assert.Equal(t, "TODO", args["pattern"]) + }) + + t.Run("handler_adds_tool_matching_preexisting_ToolAliases_with_no_initial_tools", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + + captureTool := &aliasCaptureTool{ + name: "grep", + params: map[string]*schema.ParameterInfo{ + "pattern": {Type: schema.String, Desc: "regex pattern"}, + }, + } + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + generateCount := 0 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { + generateCount++ + if generateCount == 1 { + return &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "grep", + Arguments: `{"grep_content": "BUG"}`, + }, + }, + }, + }, nil + } + return schema.AssistantMessage("done", nil), nil + }).AnyTimes() + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + handler := &testToolsHandler{ + BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, + tools: []tool.BaseTool{captureTool}, + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "test", + Instruction: "test", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + ToolAliases: map[string]compose.ToolAliasConfig{ + "grep": { + ArgumentsAliases: map[string][]string{ + "pattern": {"grep_content"}, + }, + }, + }, + }, + }, + Handlers: []ChatModelAgentMiddleware{handler}, + }) + require.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("find bugs")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + require.NotEmpty(t, captureTool.receivedArgs, "tool added by handler should have been called") + var args map[string]any + err = json.Unmarshal([]byte(captureTool.receivedArgs), &args) + require.NoError(t, err) + assert.Equal(t, "BUG", args["pattern"], "alias 'grep_content' should be remapped to 'pattern' for handler-added tool") + assert.NotContains(t, args, "grep_content") + }) +} From d7161f78e54092216139176d32a9efc6bd56f2ef Mon Sep 17 00:00:00 2001 From: shentongmartin Date: Tue, 14 Apr 2026 14:14:41 +0800 Subject: [PATCH 54/59] refactor(adk): improve cancel propagation, encapsulate TurnLoop stop options, add UntilIdleFor (#942) --- adk/attack_test.go | 449 +++++++++++++ adk/cancel.go | 108 ++- adk/cancel_edge_test.go | 2 +- adk/cancel_recursive_test.go | 409 ++++++++++++ adk/cancel_test.go | 163 +++-- adk/turn_buffer.go | 128 ++++ adk/turn_loop.go | 372 ++++++++--- adk/turn_loop_test.go | 632 +++++++++++++----- adk/wrappers_test.go | 8 +- .../prompt/agentic_chat_template_test.go | 3 +- internal/channel.go | 39 +- internal/channel_test.go | 241 ------- 12 files changed, 1921 insertions(+), 633 deletions(-) create mode 100644 adk/attack_test.go create mode 100644 adk/cancel_recursive_test.go create mode 100644 adk/turn_buffer.go diff --git a/adk/attack_test.go b/adk/attack_test.go new file mode 100644 index 000000000..bfb4462ef --- /dev/null +++ b/adk/attack_test.go @@ -0,0 +1,449 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/schema" +) + +func TestAttack_UntilIdleFor_ConcurrentPushDuringIdleTimer(t *testing.T) { + turnCount := int32(0) + turnDone := make(chan struct{}, 10) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + atomic.AddInt32(&turnCount, 1) + turnDone <- struct{}{} + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-turnDone + + loop.Stop(UntilIdleFor(200 * time.Millisecond)) + + for i := 0; i < 5; i++ { + time.Sleep(50 * time.Millisecond) + loop.Push("concurrent-" + string(rune('a'+i))) + <-turnDone + } + + done := make(chan struct{}) + go func() { + loop.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("loop did not exit after idle timeout — Push did not reset timer correctly") + } + + finalCount := atomic.LoadInt32(&turnCount) + assert.Equal(t, int32(6), finalCount, "all 6 pushes should have been processed") +} + +func TestAttack_UntilIdleFor_MultipleStopCallsFirstWins(t *testing.T) { + turnDone := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(turnDone) + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-turnDone + + loop.Stop(UntilIdleFor(100 * time.Millisecond)) + loop.Stop(UntilIdleFor(10 * time.Minute)) + + done := make(chan struct{}) + go func() { + loop.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("second UntilIdleFor should have been ignored; loop should have exited with 100ms timer") + } +} + +func TestAttack_BareStopOverridesUntilIdleFor(t *testing.T) { + agentStarted := make(chan struct{}) + agentDone := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(agentStarted) + <-agentDone + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-agentStarted + + loop.Stop(UntilIdleFor(10 * time.Minute)) + + loop.Stop() + close(agentDone) + + done := make(chan struct{}) + go func() { + loop.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("bare Stop should override UntilIdleFor and cause immediate shutdown") + } + + exit := loop.Wait() + assert.NoError(t, exit.ExitReason, "bare Stop should exit cleanly") +} + +func TestAttack_StopSignal_NilCancelOptsDoNotDeescalate(t *testing.T) { + agentStarted := make(chan *cancelContext, 1) + probe := &turnLoopStopModeProbeAgent{ccCh: agentStarted} + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return probe, nil + }, + }) + + loop.Push("msg1") + cc := <-agentStarted + + loop.Stop(WithImmediate()) + + time.Sleep(20 * time.Millisecond) + + loop.Stop() + + time.Sleep(20 * time.Millisecond) + mode := cc.getMode() + assert.Equal(t, CancelImmediate, mode, "bare Stop after WithImmediate must not de-escalate cancel mode") + + exit := loop.Wait() + var ce *CancelError + require.True(t, errors.As(exit.ExitReason, &ce)) + assert.Equal(t, CancelImmediate, ce.Info.Mode) +} + +func TestAttack_CanceledItems_EmptyWhenAgentFinishesNormally(t *testing.T) { + agentStarted := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(agentStarted) + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-agentStarted + time.Sleep(50 * time.Millisecond) + loop.Stop() + + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + assert.Empty(t, exit.CanceledItems, "CanceledItems must be empty when agent finished normally") +} + +func TestAttack_TurnBuffer_WakeupDoesNotLoseItems(t *testing.T) { + tb := newTurnBuffer[string]() + + tb.Send("a") + tb.Send("b") + tb.Wakeup() + tb.Send("c") + + var got []string + for i := 0; i < 3; i++ { + val, ok := tb.Receive() + require.True(t, ok) + got = append(got, val) + } + + assert.Equal(t, []string{"a", "b", "c"}, got, "Wakeup must not cause items to be lost") +} + +func TestAttack_TurnBuffer_ClearWakeupPreventsSpuriousReturn(t *testing.T) { + tb := newTurnBuffer[string]() + + tb.Wakeup() + tb.ClearWakeup() + + received := make(chan string, 1) + go func() { + val, ok := tb.Receive() + if ok { + received <- val + } + }() + + time.Sleep(50 * time.Millisecond) + tb.Send("real") + + select { + case val := <-received: + assert.Equal(t, "real", val, "ClearWakeup should prevent spurious empty return") + case <-time.After(2 * time.Second): + t.Fatal("Receive blocked forever despite Send") + } +} + +func TestAttack_StopBeforeRun_UntilIdleFor_ExitsImmediately(t *testing.T) { + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Stop(UntilIdleFor(10 * time.Minute)) + loop.Stop() + + loop.Run(context.Background()) + + done := make(chan struct{}) + go func() { + loop.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("loop should exit immediately when Stop() called before Run()") + } +} + +func TestAttack_PushAfterStop_UntilIdleFor_RoutedToLateItems(t *testing.T) { + turnDone := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(turnDone) + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-turnDone + + loop.Stop(UntilIdleFor(50 * time.Millisecond)) + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + + ok, _ := loop.Push("after-stop") + assert.False(t, ok, "Push after loop exited should return false") + + late := exit.TakeLateItems() + assert.Equal(t, []string{"after-stop"}, late) +} + +func TestAttack_ConcurrentStopEscalation_RaceDetector(t *testing.T) { + agentStarted := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(agentStarted) + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + }) + + loop.Push("msg1") + <-agentStarted + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + switch i % 4 { + case 0: + loop.Stop() + case 1: + loop.Stop(WithImmediate()) + case 2: + loop.Stop(WithGracefulTimeout(100 * time.Millisecond)) + case 3: + loop.Stop(UntilIdleFor(50 * time.Millisecond)) + } + }(i) + } + + wg.Wait() + exit := loop.Wait() + t.Log("ExitReason:", exit.ExitReason) +} + +func TestAttack_StopCause_FirstNonEmptyWins_ConcurrentCallers(t *testing.T) { + turnDone := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(turnDone) + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-turnDone + + loop.Stop(WithStopCause("first-cause")) + loop.Stop(WithStopCause("second-cause")) + + exit := loop.Wait() + assert.Equal(t, "first-cause", exit.StopCause, "first non-empty StopCause should win") +} + +func TestAttack_SkipCheckpoint_Sticky(t *testing.T) { + agentStarted := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(agentStarted) + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + Store: &turnLoopCheckpointStore{m: make(map[string][]byte)}, + CheckpointID: "test-sticky", + }) + + loop.Push("msg1") + <-agentStarted + + loop.Stop(WithSkipCheckpoint()) + loop.Stop(WithImmediate()) + + exit := loop.Wait() + assert.False(t, exit.Checkpointed, "SkipCheckpoint is sticky; checkpoint should be skipped") +} diff --git a/adk/cancel.go b/adk/cancel.go index a15699f53..6d4aa9ad9 100644 --- a/adk/cancel.go +++ b/adk/cancel.go @@ -44,21 +44,21 @@ type CancelMode int const ( // CancelImmediate cancels the agent as soon as the signal is received, - // without waiting for a ChatModel or ToolCalls safe-point. Propagates - // to all descendant agents via the cancel context hierarchy, including - // agents nested inside AgentTools and workflow sub-agents. + // without waiting for a ChatModel or ToolCalls safe-point. + // By default, only the root agent is interrupted; descendant agents inside + // AgentTools are torn down via context cancellation as a side effect. + // Use WithRecursive to propagate explicit immediate-cancel signals to + // descendants for clean teardown with grace period. CancelImmediate CancelMode = 0 - // CancelAfterChatModel cancels after the first chat model call that completes - // anywhere in the agent hierarchy, including nested sub-agents, agent tools, - // and workflow branches. The cancel mode propagates to all descendant agents; - // whichever ChatModel finishes first triggers the cancel. The interrupting - // agent emits an interrupt that bubbles up through the agent tree — parent - // agents do not need to reach their own ChatModel safe-point. + // CancelAfterChatModel cancels after the root agent's next chat model call + // completes. By default, only the root agent checks this safe-point; + // nested sub-agents inside AgentTools are unaware of the cancel. + // Use WithRecursive to propagate the cancel to all descendants — whichever + // ChatModel finishes first triggers the cancel. CancelAfterChatModel CancelMode = 1 << iota - // CancelAfterToolCalls cancels after the first set of concurrent tool calls - // that completes anywhere in the agent hierarchy. Like CancelAfterChatModel, - // this mode propagates to all descendants and fires at whichever level - // reaches the safe-point first. + // CancelAfterToolCalls cancels after the root agent's next set of concurrent + // tool calls completes. By default, only the root agent checks this safe-point. + // Use WithRecursive to propagate to all descendants. CancelAfterToolCalls ) @@ -92,8 +92,9 @@ func (h *CancelHandle) Wait() error { type AgentCancelFunc func(...AgentCancelOption) (*CancelHandle, bool) type agentCancelConfig struct { - Mode CancelMode - Timeout *time.Duration + Mode CancelMode + Recursive bool + Timeout *time.Duration } // AgentCancelOption configures cancel behavior. @@ -118,6 +119,21 @@ func WithAgentCancelTimeout(timeout time.Duration) AgentCancelOption { } } +// WithRecursive opts into recursive cancel propagation. By default, cancel +// modes only affect the root agent; descendant agents inside AgentTools are +// not notified. WithRecursive makes the cancel propagate to all descendants: +// - CancelAfterChatModel / CancelAfterToolCalls: descendants check their own safe-points. +// - CancelImmediate: descendants receive explicit immediate-cancel signals for +// clean teardown; the root uses a grace period to collect child interrupts. +// +// Once any cancel call includes WithRecursive, the flag stays set for the +// entire cancel lifecycle (monotonic escalation). +func WithRecursive() AgentCancelOption { + return func(config *agentCancelConfig) { + config.Recursive = true + } +} + // AgentCancelInfo contains information about a cancel operation. type AgentCancelInfo struct { Mode CancelMode @@ -296,6 +312,9 @@ type cancelContext struct { startedMode int32 // atomic, mode when state transitioned to cancelling deadlineUnixNano int64 // atomic, 0 means no deadline + recursive int32 // atomic; 1 if cancel should propagate to descendant agents via deriveChild + recursiveChan chan struct{} // closed when recursive transitions from 0 to 1 + root bool // true for the original cancelContext created by WithCancel(); false for derived children parent *cancelContext // non-nil for derived children; used to decrement parent's activeChildren on markDone @@ -316,6 +335,7 @@ func newCancelContext() *cancelContext { immediateChan: make(chan struct{}), doneChan: make(chan struct{}), timeoutNotify: make(chan struct{}, 1), + recursiveChan: make(chan struct{}), root: true, } } @@ -324,6 +344,18 @@ func (cc *cancelContext) isRoot() bool { return cc != nil && cc.root } +func (cc *cancelContext) isRecursive() bool { + return cc != nil && atomic.LoadInt32(&cc.recursive) == 1 +} + +// setRecursive(false) is a no-op; recursive is monotonically escalating: +// once set to true, it cannot be reverted. +func (cc *cancelContext) setRecursive(v bool) { + if v && atomic.CompareAndSwapInt32(&cc.recursive, 0, 1) { + close(cc.recursiveChan) + } +} + // deriveChild creates a child cancelContext that receives cancel propagation // from the parent. The caller MUST ensure the child's markDone() is eventually // called (e.g., via wrapIterWithCancelCtx's defer) or that ctx is canceled; @@ -337,10 +369,29 @@ func (cc *cancelContext) deriveChild(ctx context.Context) *cancelContext { child.parent = cc atomic.AddInt32(&cc.activeChildren, 1) + // Each goroutine below propagates one signal class (cancel / immediate) to + // the child. The pattern is a two-phase select: + // Phase 1: wait for the parent signal (or child/ctx completion). + // Phase 2: if the signal fired but recursive mode is not active yet, + // enter a second select waiting for either recursive escalation + // (recursiveChan) or child/ctx completion. This ensures + // non-recursive cancels leave children unaware, while a late + // escalation to recursive still propagates. go func() { select { case <-cc.cancelChan: - child.triggerCancel(cc.getMode()) + if cc.isRecursive() { + child.setRecursive(true) + child.triggerCancel(cc.getMode()) + return + } + select { + case <-cc.recursiveChan: + child.setRecursive(true) + child.triggerCancel(cc.getMode()) + case <-child.doneChan: + case <-ctx.Done(): + } case <-child.doneChan: case <-ctx.Done(): } @@ -349,7 +400,18 @@ func (cc *cancelContext) deriveChild(ctx context.Context) *cancelContext { go func() { select { case <-cc.immediateChan: - child.triggerImmediateCancel() + if cc.isRecursive() { + child.setRecursive(true) + child.triggerImmediateCancel() + return + } + select { + case <-cc.recursiveChan: + child.setRecursive(true) + child.triggerImmediateCancel() + case <-child.doneChan: + case <-ctx.Done(): + } case <-child.doneChan: case <-ctx.Done(): } @@ -504,7 +566,10 @@ func (cc *cancelContext) hasActiveChildren() bool { func (cc *cancelContext) wrapGraphInterruptWithGracePeriod(interrupt func(...compose.GraphInterruptOption)) func(...compose.GraphInterruptOption) { return func(opts ...compose.GraphInterruptOption) { - if cc.hasActiveChildren() { + // Grace period only applies in recursive mode: in shallow mode, + // children are unaware of the cancel and don't need time to propagate + // their interrupt signals back. + if cc.isRecursive() && cc.hasActiveChildren() { newOpts := make([]compose.GraphInterruptOption, len(opts)+1) copy(newOpts, opts) newOpts[len(opts)] = compose.WithGraphInterruptTimeout(defaultCancelImmediateGracePeriod) @@ -682,10 +747,17 @@ func (cc *cancelContext) buildCancelFunc() AgentCancelFunc { curMode = req.Mode cc.setMode(curMode) atomic.StoreInt32(&cc.startedMode, int32(curMode)) + cc.setRecursive(req.Recursive) close(cc.cancelChan) } else { + // Recursive is monotonic: once set, cannot be unset. The first + // cancel call uses the bool directly; subsequent calls only + // escalate (false → true) — setRecursive(false) is a no-op. curMode = joinMode(curMode, req.Mode) cc.setMode(curMode) + if req.Recursive { + cc.setRecursive(true) + } } if curMode == CancelImmediate { diff --git a/adk/cancel_edge_test.go b/adk/cancel_edge_test.go index d3fb02a1a..b0afbe674 100644 --- a/adk/cancel_edge_test.go +++ b/adk/cancel_edge_test.go @@ -1242,7 +1242,7 @@ func TestWithCancel_CancelAfterChatModel_NestedAgentTool(t *testing.T) { cancelDone := make(chan error, 1) go func() { - handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel), WithRecursive()) cancelDone <- handle.Wait() }() diff --git a/adk/cancel_recursive_test.go b/adk/cancel_recursive_test.go new file mode 100644 index 000000000..9f13f55d2 --- /dev/null +++ b/adk/cancel_recursive_test.go @@ -0,0 +1,409 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "runtime" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/compose" +) + +func assertNotClosedWithin(t *testing.T, ch <-chan struct{}, d time.Duration) { + t.Helper() + select { + case <-ch: + t.Fatal("channel was closed but should not have been") + case <-time.After(d): + } +} + +func setupParentChild(t *testing.T) (parent, child *cancelContext, cleanup func()) { + parent = newCancelContext() + ctx, cancel := context.WithCancel(context.Background()) + child = parent.deriveChild(ctx) + cleanup = func() { + child.markDone() + cancel() + } + t.Cleanup(cleanup) + return parent, child, cleanup +} + +func TestDeriveChild(t *testing.T) { + t.Run("Shallow", func(t *testing.T) { + t.Run("DoesNotPropagateSafePoint", func(t *testing.T) { + parent, child, _ := setupParentChild(t) + + parent.triggerCancel(CancelAfterChatModel) + + assertNotClosedWithin(t, child.cancelChan, 50*time.Millisecond) + }) + + t.Run("ImmediateDoesNotPropagate", func(t *testing.T) { + parent, child, _ := setupParentChild(t) + + parent.triggerImmediateCancel() + + assertNotClosedWithin(t, child.immediateChan, 50*time.Millisecond) + }) + + t.Run("GrandchildNoPropagation", func(t *testing.T) { + a := newCancelContext() + ctx, cancel := context.WithCancel(context.Background()) + + b := a.deriveChild(ctx) + c := b.deriveChild(ctx) + t.Cleanup(func() { + c.markDone() + b.markDone() + cancel() + }) + + a.triggerCancel(CancelAfterChatModel) + + assertNotClosedWithin(t, b.cancelChan, 50*time.Millisecond) + assertNotClosedWithin(t, c.cancelChan, 50*time.Millisecond) + }) + + t.Run("NeverRecursive_GoroutineCleanup", func(t *testing.T) { + runtime.GC() + time.Sleep(50 * time.Millisecond) + before := runtime.NumGoroutine() + + parent := newCancelContext() + ctx, cancel := context.WithCancel(context.Background()) + + child := parent.deriveChild(ctx) + + parent.triggerCancel(CancelAfterChatModel) + time.Sleep(100 * time.Millisecond) + + child.markDone() + cancel() + + time.Sleep(200 * time.Millisecond) + runtime.GC() + time.Sleep(50 * time.Millisecond) + after := runtime.NumGoroutine() + + assert.InDelta(t, before, after, 5, "goroutine leak detected: before=%d after=%d", before, after) + }) + }) + + t.Run("Recursive", func(t *testing.T) { + t.Run("PropagatesSafePoint", func(t *testing.T) { + parent, child, _ := setupParentChild(t) + + parent.setRecursive(true) + parent.triggerCancel(CancelAfterChatModel) + + select { + case <-child.cancelChan: + case <-time.After(1 * time.Second): + t.Fatal("child did not receive cancel within 1s") + } + assert.True(t, child.shouldCancel()) + }) + + t.Run("ImmediatePropagates", func(t *testing.T) { + parent, child, _ := setupParentChild(t) + + parent.setRecursive(true) + parent.triggerImmediateCancel() + + select { + case <-child.immediateChan: + case <-time.After(1 * time.Second): + t.Fatal("child did not receive immediate cancel within 1s") + } + assert.True(t, child.isImmediateCancelled()) + }) + + t.Run("GrandchildPropagation", func(t *testing.T) { + a := newCancelContext() + ctx, cancel := context.WithCancel(context.Background()) + + b := a.deriveChild(ctx) + c := b.deriveChild(ctx) + t.Cleanup(func() { + c.markDone() + b.markDone() + cancel() + }) + + a.setRecursive(true) + a.triggerCancel(CancelAfterChatModel) + + select { + case <-b.cancelChan: + case <-time.After(1 * time.Second): + t.Fatal("B did not receive cancel within 1s") + } + + select { + case <-c.cancelChan: + case <-time.After(1 * time.Second): + t.Fatal("C did not receive cancel within 1s") + } + + assert.True(t, b.shouldCancel()) + assert.True(t, c.shouldCancel()) + }) + + t.Run("SetBeforeCancel", func(t *testing.T) { + parent, child, _ := setupParentChild(t) + + parent.setRecursive(true) + + parent.triggerCancel(CancelAfterChatModel) + + select { + case <-child.cancelChan: + case <-time.After(1 * time.Second): + t.Fatal("child did not receive cancel within 1s") + } + assert.True(t, child.shouldCancel()) + }) + + t.Run("AfterRecursiveAndCancelAlreadySet", func(t *testing.T) { + parent := newCancelContext() + ctx, cancel := context.WithCancel(context.Background()) + + parent.setRecursive(true) + parent.triggerCancel(CancelAfterChatModel) + + child := parent.deriveChild(ctx) + t.Cleanup(func() { + child.markDone() + cancel() + }) + + select { + case <-child.cancelChan: + case <-time.After(1 * time.Second): + t.Fatal("child did not immediately receive cancel") + } + assert.True(t, child.shouldCancel()) + }) + }) + + t.Run("Escalation", func(t *testing.T) { + t.Run("EscalateFromNonRecursive", func(t *testing.T) { + parent, child, _ := setupParentChild(t) + + parent.triggerCancel(CancelAfterChatModel) + + assertNotClosedWithin(t, child.cancelChan, 50*time.Millisecond) + + parent.setRecursive(true) + + select { + case <-child.cancelChan: + case <-time.After(1 * time.Second): + t.Fatal("child did not receive cancel after escalation within 1s") + } + assert.True(t, child.shouldCancel()) + }) + + t.Run("EscalateImmediate", func(t *testing.T) { + parent, child, _ := setupParentChild(t) + + parent.triggerImmediateCancel() + + assertNotClosedWithin(t, child.immediateChan, 50*time.Millisecond) + + parent.setRecursive(true) + + select { + case <-child.immediateChan: + case <-time.After(1 * time.Second): + t.Fatal("child did not receive immediate cancel after escalation within 1s") + } + assert.True(t, child.isImmediateCancelled()) + }) + }) +} + +func TestDeriveChild_Race(t *testing.T) { + t.Run("SetRecursiveConcurrentWithCancelChan", func(t *testing.T) { + for i := 0; i < 100; i++ { + parent := newCancelContext() + ctx, cancel := context.WithCancel(context.Background()) + + child := parent.deriveChild(ctx) + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + parent.setRecursive(true) + }() + + go func() { + defer wg.Done() + parent.triggerCancel(CancelAfterChatModel) + }() + + wg.Wait() + + select { + case <-child.cancelChan: + case <-time.After(1 * time.Second): + t.Fatalf("iteration %d: child did not receive cancel within 1s", i) + } + + assert.True(t, child.shouldCancel()) + child.markDone() + cancel() + } + }) + + t.Run("ChildCompletesBeforeEscalation", func(t *testing.T) { + parent := newCancelContext() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + child := parent.deriveChild(ctx) + + parent.triggerCancel(CancelAfterChatModel) + time.Sleep(50 * time.Millisecond) + + child.markDone() + time.Sleep(50 * time.Millisecond) + + parent.setRecursive(true) + + assertNotClosedWithin(t, child.cancelChan, 50*time.Millisecond) + }) + + t.Run("MultipleChildren_PartialCompletion", func(t *testing.T) { + parent := newCancelContext() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + child1 := parent.deriveChild(ctx) + child2 := parent.deriveChild(ctx) + + parent.triggerCancel(CancelAfterChatModel) + time.Sleep(50 * time.Millisecond) + + child1.markDone() + time.Sleep(50 * time.Millisecond) + + parent.setRecursive(true) + + select { + case <-child2.cancelChan: + case <-time.After(1 * time.Second): + t.Fatal("running child did not receive cancel within 1s") + } + + assert.True(t, child2.shouldCancel()) + assert.False(t, child1.shouldCancel()) + child2.markDone() + }) + + t.Run("ContextCancelConcurrentWithRecursive", func(t *testing.T) { + done := make(chan struct{}) + go func() { + defer close(done) + + parent := newCancelContext() + ctx, cancel := context.WithCancel(context.Background()) + + child := parent.deriveChild(ctx) + + parent.triggerCancel(CancelAfterChatModel) + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + cancel() + }() + + go func() { + defer wg.Done() + parent.setRecursive(true) + }() + + wg.Wait() + child.markDone() + }() + + select { + case <-done: + case <-time.After(1 * time.Second): + t.Fatal("deadlock detected") + } + }) + + t.Run("ConcurrentSetRecursive", func(t *testing.T) { + parent := newCancelContext() + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + parent.setRecursive(true) + }() + } + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(1 * time.Second): + t.Fatal("deadlock or panic in concurrent setRecursive") + } + + assert.True(t, parent.isRecursive()) + }) +} + +func TestGracePeriod_OnlyWhenRecursive(t *testing.T) { + parent, _, _ := setupParentChild(t) + + var nonRecursiveOptCount int + wrappedNonRecursive := parent.wrapGraphInterruptWithGracePeriod(func(opts ...compose.GraphInterruptOption) { + nonRecursiveOptCount = len(opts) + }) + wrappedNonRecursive() + assert.Equal(t, 0, nonRecursiveOptCount) + + parent.setRecursive(true) + + var recursiveOptCount int + wrappedRecursive := parent.wrapGraphInterruptWithGracePeriod(func(opts ...compose.GraphInterruptOption) { + recursiveOptCount = len(opts) + }) + wrappedRecursive() + assert.Equal(t, 1, recursiveOptCount) +} diff --git a/adk/cancel_test.go b/adk/cancel_test.go index 105c9ea13..2096a9ac3 100644 --- a/adk/cancel_test.go +++ b/adk/cancel_test.go @@ -142,6 +142,51 @@ func (s *cancelTestStore) Get(_ context.Context, key string) ([]byte, bool, erro return v, ok, nil } +func assertHasCancelError(t *testing.T, events []*AgentEvent) { + t.Helper() + for _, e := range events { + var ce *CancelError + if e.Err != nil && errors.As(e.Err, &ce) { + return + } + } + t.Fatal("expected CancelError in events") +} + +func drainAndAssertCancelError(t *testing.T, iter *AsyncIterator[*AgentEvent]) { + t.Helper() + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + return + } + } + t.Fatal("expected CancelError in event stream") +} + +func drainEventsAndAssertCancelError(t *testing.T, iter *AsyncIterator[*AgentEvent]) []*AgentEvent { + t.Helper() + var events []*AgentEvent + hasCancelError := false + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + hasCancelError = true + } + events = append(events, event) + } + assert.True(t, hasCancelError, "expected CancelError in event stream") + return events +} + func TestCancelContext(t *testing.T) { t.Run("BasicCancelContext", func(t *testing.T) { cc := newCancelContext() @@ -237,16 +282,9 @@ func TestWithCancel_WithTools(t *testing.T) { t.Fatal("Timed out waiting for events") } - assert.True(t, len(events) > 0) + assert.NotEmpty(t, events) - hasCancelError := false - for _, e := range events { - var cancelErr *CancelError - if e.Err != nil && errors.As(e.Err, &cancelErr) { - hasCancelError = true - } - } - assert.True(t, hasCancelError, "Should have CancelError event after cancel") + assertHasCancelError(t, events) }) t.Run("CancelAfterChatModel_DuringToolCall", func(t *testing.T) { @@ -317,7 +355,7 @@ func TestWithCancel_WithTools(t *testing.T) { events = append(events, event) } - assert.True(t, len(events) > 0) + assert.NotEmpty(t, events) assert.True(t, atomic.LoadInt32(&st.callCount) >= 1, "Tool should have been called") }) @@ -388,7 +426,7 @@ func TestWithCancel_WithTools(t *testing.T) { events = append(events, event) } - assert.True(t, len(events) > 0) + assert.NotEmpty(t, events) assert.True(t, atomic.LoadInt32(&st.callCount) >= 1, "Tool should have been called") }) @@ -400,6 +438,7 @@ func TestWithCancel_WithTools(t *testing.T) { child := cc.deriveChild(ctx) assert.NotNil(t, child) + cc.setRecursive(true) cc.setMode(CancelImmediate) if atomic.CompareAndSwapInt32(&cc.state, stateRunning, stateCancelling) { @@ -475,7 +514,7 @@ func TestWithCancel_WithTools(t *testing.T) { <-modelStarted - handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel), WithRecursive()) err = handle.Wait() assert.NoError(t, err) @@ -776,16 +815,9 @@ func TestWithCancel_Streaming(t *testing.T) { t.Fatal("Timed out waiting for events") } - assert.True(t, len(events) > 0) + assert.NotEmpty(t, events) - hasCancelError := false - for _, e := range events { - var ce *CancelError - if e.Err != nil && errors.As(e.Err, &ce) { - hasCancelError = true - } - } - assert.True(t, hasCancelError, "Should have CancelError event after cancel") + assertHasCancelError(t, events) }) t.Run("CancelAfterToolCalls_Streaming", func(t *testing.T) { @@ -858,7 +890,7 @@ func TestWithCancel_Streaming(t *testing.T) { events = append(events, event) } - assert.True(t, len(events) > 0) + assert.NotEmpty(t, events) assert.True(t, atomic.LoadInt32(&st.callCount) >= 1, "Tool should have been called") }) } @@ -990,7 +1022,7 @@ func TestWithCancel_Resume(t *testing.T) { resumeEvents = append(resumeEvents, event) } - assert.True(t, len(resumeEvents) > 0, "Resume should produce events") + assert.NotEmpty(t, resumeEvents, "Resume should produce events") }) t.Run("Resume_ThenCancel", func(t *testing.T) { @@ -1107,7 +1139,7 @@ func TestWithCancel_Resume(t *testing.T) { elapsed := time.Since(start) assert.True(t, elapsed < 1*time.Second, "Resume should return quickly after cancel, elapsed: %v", elapsed) - assert.True(t, len(resumeEvents) > 0) + assert.NotEmpty(t, resumeEvents) hasCancelError := false for _, e := range resumeEvents { @@ -1491,21 +1523,7 @@ func TestWithCancel_SequentialAgent(t *testing.T) { err = handle.Wait() assert.NoError(t, err, "Cancel during second agent should succeed, not return ErrExecutionCompleted") - var events []*AgentEvent - hasCancelError := false - for { - event, ok := iter.Next() - if !ok { - break - } - var ce *CancelError - if event.Err != nil && errors.As(event.Err, &ce) { - hasCancelError = true - } - events = append(events, event) - } - - assert.True(t, hasCancelError, "Should have CancelError event") + drainEventsAndAssertCancelError(t, iter) }) } @@ -1564,19 +1582,7 @@ func TestWithCancel_LoopAgent(t *testing.T) { err = handle.Wait() assert.NoError(t, err, "Cancel during loop iteration should succeed") - hasCancelError := false - for { - event, ok := iter.Next() - if !ok { - break - } - var ce *CancelError - if event.Err != nil && errors.As(event.Err, &ce) { - hasCancelError = true - } - } - - assert.True(t, hasCancelError, "Should have CancelError event") + drainAndAssertCancelError(t, iter) }) } @@ -1623,22 +1629,10 @@ func TestWithCancel_ParallelAgent(t *testing.T) { err = handle.Wait() assert.NoError(t, err, "Cancel during parallel agents should succeed") - var events []*AgentEvent - hasCancelError := false - for { - event, ok := iter.Next() - if !ok { - break - } - var ce *CancelError - if event.Err != nil && errors.As(event.Err, &ce) { - hasCancelError = true - } - events = append(events, event) - } + events := drainEventsAndAssertCancelError(t, iter) elapsed := time.Since(start) - assert.True(t, hasCancelError, "Should have CancelError event") + _ = events assert.True(t, elapsed < 3*time.Second, "Should complete quickly after cancel, elapsed: %v", elapsed) }) } @@ -1707,20 +1701,9 @@ func TestWithCancel_SupervisorAgent(t *testing.T) { err = handle.Wait() assert.NoError(t, err, "Cancel during sub-agent should succeed") - hasCancelError := false - for { - event, ok := iter.Next() - if !ok { - break - } - var ce *CancelError - if event.Err != nil && errors.As(event.Err, &ce) { - hasCancelError = true - } - } + drainAndAssertCancelError(t, iter) elapsed := time.Since(start) - assert.True(t, hasCancelError, "Should have CancelError event") assert.True(t, elapsed < 3*time.Second, "Should complete quickly after cancel, elapsed: %v", elapsed) }) } @@ -2302,7 +2285,7 @@ func TestCancel_SequentialWorkflow_CancelAfterChatModel(t *testing.T) { assert.Nil(t, event.Err, "Should not have error during resume") resumeEvents = append(resumeEvents, event) } - assert.True(t, len(resumeEvents) > 0, "Resume should produce events") + assert.NotEmpty(t, resumeEvents, "Resume should produce events") } func TestCancelImmediate_OrphanedToolGoroutine_NoPanic(t *testing.T) { @@ -2625,7 +2608,7 @@ func TestCancelImmediate_AgentTool_PreservesChildCheckpoint(t *testing.T) { waitForChan(t, leafModel.startedChan, "Leaf agent model did not start") - handle, contributed := cancelFn() + handle, contributed := cancelFn(WithRecursive()) assert.True(t, contributed) assert.NoError(t, handle.Wait()) @@ -3253,23 +3236,27 @@ func TestCancelContext_ActiveChildren_Tracking(t *testing.T) { wrapped := parent.wrapGraphInterruptWithGracePeriod(mockInterrupt) - // No children: no options appended receivedOpts = nil wrapped() assert.Empty(t, receivedOpts, "Should pass no extra options when no children") - // With active child: one timeout option appended _ = parent.deriveChild(ctx) + + receivedOpts = nil + wrapped() + assert.Empty(t, receivedOpts, "Should pass no extra options when children are active but not recursive") + + parent.setRecursive(true) + receivedOpts = nil wrapped() - assert.Len(t, receivedOpts, 1, "Should add exactly one timeout option when children are active") + assert.Len(t, receivedOpts, 1, "Should add exactly one timeout option when children are active and recursive") - // Caller-provided options are preserved, grace period option appended after receivedOpts = nil callerOpt := compose.WithGraphInterruptTimeout(0) wrapped(callerOpt) assert.Len(t, receivedOpts, 2, - "Should append timeout option after caller-provided options when children are active") + "Should append timeout option after caller-provided options when children are active and recursive") // Note: verifying the exact timeout value (defaultCancelImmediateGracePeriod) // requires access to unexported compose.graphInterruptOptions. The integration // tests (TestCancelImmediate_AgentTool_PreservesChildCheckpoint) verify the @@ -3437,7 +3424,7 @@ func TestCancel_LoopWorkflow_CancelAfterChatModel(t *testing.T) { assert.Nil(t, event.Err, "Should not have error during resume") resumeEvents = append(resumeEvents, event) } - assert.True(t, len(resumeEvents) > 0, "Resume should produce events") + assert.NotEmpty(t, resumeEvents, "Resume should produce events") } func TestCancel_NestedWorkflow_AgentTool_CancelAfterChatModel(t *testing.T) { @@ -3514,7 +3501,7 @@ func TestCancel_NestedWorkflow_AgentTool_CancelAfterChatModel(t *testing.T) { t.Fatal("Leaf agent model did not start") } - handle, contributed := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + handle, contributed := cancelFn(WithAgentCancelMode(CancelAfterChatModel), WithRecursive()) assert.True(t, contributed, "Cancel should contribute") err = handle.Wait() assert.NoError(t, err) @@ -3737,5 +3724,5 @@ func TestCancel_CancelAfterToolCalls_InSequentialWorkflow(t *testing.T) { assert.Nil(t, event.Err, "Should not have error during resume") resumeEvents = append(resumeEvents, event) } - assert.True(t, len(resumeEvents) > 0, "Resume should produce events") + assert.NotEmpty(t, resumeEvents, "Resume should produce events") } diff --git a/adk/turn_buffer.go b/adk/turn_buffer.go new file mode 100644 index 000000000..b154587c9 --- /dev/null +++ b/adk/turn_buffer.go @@ -0,0 +1,128 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import "sync" + +type turnBuffer[T any] struct { + buf []T + mu sync.Mutex + notEmpty *sync.Cond + closed bool + woken bool +} + +func newTurnBuffer[T any]() *turnBuffer[T] { + tb := &turnBuffer[T]{} + tb.notEmpty = sync.NewCond(&tb.mu) + return tb +} + +func (tb *turnBuffer[T]) Send(value T) { + tb.mu.Lock() + defer tb.mu.Unlock() + + if tb.closed { + panic("turnBuffer: send on closed buffer") + } + + tb.buf = append(tb.buf, value) + tb.notEmpty.Signal() +} + +func (tb *turnBuffer[T]) TrySend(value T) bool { + tb.mu.Lock() + defer tb.mu.Unlock() + + if tb.closed { + return false + } + + tb.buf = append(tb.buf, value) + tb.notEmpty.Signal() + return true +} + +func (tb *turnBuffer[T]) Receive() (T, bool) { + tb.mu.Lock() + defer tb.mu.Unlock() + + for len(tb.buf) == 0 && !tb.closed && !tb.woken { + tb.notEmpty.Wait() + } + + tb.woken = false + + if len(tb.buf) == 0 { + var zero T + return zero, false + } + + val := tb.buf[0] + tb.buf = tb.buf[1:] + return val, true +} + +func (tb *turnBuffer[T]) Close() { + tb.mu.Lock() + defer tb.mu.Unlock() + + if !tb.closed { + tb.closed = true + tb.notEmpty.Broadcast() + } +} + +func (tb *turnBuffer[T]) TakeAll() []T { + tb.mu.Lock() + defer tb.mu.Unlock() + + if len(tb.buf) == 0 { + return nil + } + + values := tb.buf + tb.buf = nil + return values +} + +func (tb *turnBuffer[T]) PushFront(values []T) { + if len(values) == 0 { + return + } + + tb.mu.Lock() + defer tb.mu.Unlock() + + tb.buf = append(append([]T{}, values...), tb.buf...) + tb.notEmpty.Signal() +} + +func (tb *turnBuffer[T]) Wakeup() { + tb.mu.Lock() + defer tb.mu.Unlock() + + tb.woken = true + tb.notEmpty.Broadcast() +} + +func (tb *turnBuffer[T]) ClearWakeup() { + tb.mu.Lock() + defer tb.mu.Unlock() + + tb.woken = false +} diff --git a/adk/turn_loop.go b/adk/turn_loop.go index 7d0f61a3b..124f65459 100644 --- a/adk/turn_loop.go +++ b/adk/turn_loop.go @@ -27,7 +27,6 @@ import ( "sync/atomic" "time" - "github.com/cloudwego/eino/internal" "github.com/cloudwego/eino/internal/safe" ) @@ -54,19 +53,19 @@ import ( // acts when a new Stop() call has been made, supporting mode escalation // (e.g. CancelAfterToolCalls followed by CancelImmediate). type stopSignal struct { - // done is closed exactly once by closeDone(). A closed channel is - // readable forever, so it serves as a durable stop flag for all watchers. done chan struct{} - mu sync.Mutex - gen uint64 + mu sync.Mutex + gen uint64 + // agentCancelOpts controls how the stop interacts with the running agent: + // nil → no cancel; the turn runs to completion (bare Stop) + // empty → CancelImmediate (WithImmediate) + // non-empty → cancel with specific modes (WithGraceful, WithGracefulTimeout) agentCancelOpts []AgentCancelOption skipCheckpoint bool stopCause string - // notify is a buffered(1) channel that wakes the current turn's watcher - // when Stop() is called. Unlike done, it supports repeated Stop() calls - // for cancel-mode escalation. - notify chan struct{} + idleFor time.Duration + notify chan struct{} } func newStopSignal() *stopSignal { @@ -82,13 +81,21 @@ func newStopSignal() *stopSignal { func (s *stopSignal) signal(cfg *stopConfig) { s.mu.Lock() s.gen++ - s.agentCancelOpts = cfg.agentCancelOpts + // Only overwrite when the caller explicitly provides cancel options. + // A bare Stop() leaves cfg.agentCancelOpts nil (no cancel intent), which + // must not de-escalate a previously set cancel policy. + if cfg.agentCancelOpts != nil { + s.agentCancelOpts = cfg.agentCancelOpts + } if cfg.skipCheckpoint { s.skipCheckpoint = true } if cfg.stopCause != "" && s.stopCause == "" { s.stopCause = cfg.stopCause } + if cfg.idleFor > 0 && s.idleFor == 0 { + s.idleFor = cfg.idleFor + } s.mu.Unlock() select { case s.notify <- struct{}{}: @@ -131,6 +138,12 @@ func (s *stopSignal) getStopCause() string { return s.stopCause } +func (s *stopSignal) getIdleFor() time.Duration { + s.mu.Lock() + defer s.mu.Unlock() + return s.idleFor +} + // preemptSignal coordinates preemption between Push callers and the run loop. // // Lifecycle overview: @@ -369,6 +382,16 @@ type TurnLoopConfig[T any] struct { // - tc.Consumed: items that triggered this agent execution // - tc.Loop: allows calling Push() or Stop() directly from within the callback // - tc.Preempted / tc.Stopped: signals while processing events + // + // Error handling: the returned error is only used when the callback itself + // wants to abort the TurnLoop. The TurnLoop already captures CancelError + // from the event stream when the turn is stopped or preempted, so the + // callback should NOT propagate CancelError. In practice, return a non-nil + // error only for callback-internal failures that should terminate the loop; + // return nil when the current agent is canceled by an external Stop or + // Preempt (Preempt cancels the current agent but the loop continues with + // the next turn). + // // Optional. If not provided, events are drained and errors (except CancelError // from Stop-triggered cancellation) are returned as ExitReason. OnAgentEvents func(ctx context.Context, tc *TurnContext[T], events *AsyncIterator[*AgentEvent]) error @@ -534,9 +557,11 @@ func (l *TurnLoop[T]) planTurn( // and any items that were not processed. type TurnLoopExitState[T any] struct { // ExitReason indicates why the loop exited. - // nil means clean exit (Stop() was called and completed normally). + // nil means clean exit (Stop() was called without cancel options, or the + // agent completed normally before Stop took effect). // Non-nil values include context errors, callback errors, *CancelError, etc. - // When Stop() cancels a running agent, ExitReason will be a *CancelError. + // When Stop(WithImmediate()) or Stop(WithGraceful()) cancels a running + // agent, ExitReason will be a *CancelError. // This never contains checkpoint errors — see CheckpointErr for those. ExitReason error @@ -545,9 +570,10 @@ type TurnLoopExitState[T any] struct { // This is always valid regardless of ExitReason. UnhandledItems []T - // CanceledItems contains the items whose turn was canceled by Stop(). - // This is set when Stop() is called during a running turn, even if it - // did not contribute to the final CancelError. + // CanceledItems contains the items whose turn was actually interrupted + // by a cancel (Stop with WithImmediate, WithGraceful, or WithGracefulTimeout). + // Only populated when ExitReason is a *CancelError — if the agent finishes + // normally before the cancel takes effect, CanceledItems is empty. // It can be used to reconstruct GenInput/PrepareAgent inputs when resuming. CanceledItems []T @@ -588,11 +614,15 @@ type TurnContext[T any] struct { Consumed []T // Preempted is closed when a preempt signal fires for the current turn - // (via Push with WithPreempt) and at least one preemptive Push contributed - // to the CancelError for the current turn. + // (via Push with WithPreempt/WithPreemptTimeout) and at least one + // preemptive Push contributed to the CancelError for the current turn. // "Contributed" means the preempt's cancel options were included in the // CancelError before it was finalized. Remains open if no preempt contributed. // Use in a select to detect preemption while processing events. + // + // Both Preempted and Stopped may be closed within the same turn if both + // signals arrive while the agent is still being cancelled. Whichever + // arrives after the cancel is fully handled will not contribute. Preempted <-chan struct{} // Stopped is closed when a Stop() call contributed to the CancelError for the @@ -600,6 +630,8 @@ type TurnContext[T any] struct { // "Contributed" means Stop's cancel options were included in the CancelError // before it was finalized. Remains open if Stop did not contribute. // Use in a select to detect stop while processing events. + // + // See Preempted for the relationship between the two channels. Stopped <-chan struct{} // StopCause returns the business-supplied reason from WithStopCause. @@ -628,7 +660,7 @@ type TurnContext[T any] struct { type TurnLoop[T any] struct { config TurnLoopConfig[T] - buffer *internal.UnboundedChan[T] + buffer *turnBuffer[T] stopped int32 started int32 @@ -777,20 +809,98 @@ type turnLoopPendingResume[T any] struct { resumeBytes []byte } +// SafePoint describes at which boundary the agent may be cancelled. +// It is a bitmask: values can be combined with bitwise OR to accept multiple +// safe points (e.g. AfterToolCalls | AfterChatModel). Internally, SafePoint +// is translated to CancelMode via toCancelMode(). +// +// SafePoint is used only in the preemption API (WithPreempt/WithPreemptTimeout). +// A key design constraint: preemption always targets a safe point — the user's +// intent is to cancel at a well-defined boundary, never to abort immediately. +// Immediate cancellation is only reachable as an automatic timeout escalation +// (via WithPreemptTimeout), not as a direct user choice. This is why SafePoint +// has no "immediate" value and why WithPreempt requires a non-zero SafePoint +// (panics otherwise). +type SafePoint int + +const ( + // AfterToolCalls allows the agent to finish the current tool-call round + // before being cancelled. + AfterToolCalls SafePoint = 1 << iota + // AfterChatModel allows the agent to finish the current chat-model + // call before being cancelled. + AfterChatModel + // AnySafePoint is shorthand for AfterToolCalls | AfterChatModel. + AnySafePoint = AfterToolCalls | AfterChatModel +) + +func (sp SafePoint) toCancelMode() CancelMode { + var mode CancelMode + if sp&AfterToolCalls != 0 { + mode |= CancelAfterToolCalls + } + if sp&AfterChatModel != 0 { + mode |= CancelAfterChatModel + } + return mode +} + type stopConfig struct { agentCancelOpts []AgentCancelOption skipCheckpoint bool stopCause string + idleFor time.Duration } // StopOption is an option for Stop(). type StopOption func(*stopConfig) -// WithAgentCancel sets the agent cancel options to use when stopping the loop. -// These options control how the currently running agent is cancelled. -func WithAgentCancel(opts ...AgentCancelOption) StopOption { +// WithGraceful requests a graceful stop that waits at the nearest safe point +// (after tool calls or after a chat-model call) and propagates recursively to +// nested agents. It does not impose a time limit; use WithGracefulTimeout to +// add a grace period after which the stop escalates to immediate cancellation. +// +// WithGraceful and WithGracefulTimeout are mutually exclusive; if both are +// passed to the same Stop call, the last one wins. +func WithGraceful() StopOption { + return func(cfg *stopConfig) { + cfg.agentCancelOpts = []AgentCancelOption{ + WithAgentCancelMode(CancelAfterChatModel | CancelAfterToolCalls), + WithRecursive(), + } + } +} + +// WithImmediate aborts the running agent turn as soon as possible. +// The agent's context is cancelled immediately without waiting for any +// safe point. Nested agents inside AgentTools are torn down as a side effect. +// +// This is the most aggressive stop mode — typically used when the caller +// wants to shut down the TurnLoop with no intention of resuming. +func WithImmediate() StopOption { return func(cfg *stopConfig) { - cfg.agentCancelOpts = opts + cfg.agentCancelOpts = []AgentCancelOption{} + } +} + +// WithGracefulTimeout is like WithGraceful but adds a grace period. +// If the agent has not reached a safe point within gracePeriod, the stop +// escalates to immediate cancellation. +// +// gracePeriod must be positive; passing a zero or negative duration panics. +// +// WithGraceful and WithGracefulTimeout are mutually exclusive; if both are +// passed to the same Stop call, the last one wins. +func WithGracefulTimeout(gracePeriod time.Duration) StopOption { + if gracePeriod <= 0 { + panic("adk: WithGracefulTimeout: gracePeriod must be positive") + } + return func(cfg *stopConfig) { + cfg.agentCancelOpts = []AgentCancelOption{ + WithAgentCancelMode(CancelAfterChatModel | CancelAfterToolCalls), + WithRecursive(), + WithAgentCancelTimeout(gracePeriod), + } } } @@ -813,6 +923,34 @@ func WithStopCause(cause string) StopOption { } } +// UntilIdleFor defers the stop until the TurnLoop has been continuously idle +// (blocked between turns with no pending items) for at least the given +// duration. Each time a new item arrives the timer resets from zero. +// +// This is useful when business code monitors agent activity externally and +// wants to shut down the loop once there has been no work for a while, without +// racing with concurrent Push calls. +// +// UntilIdleFor is combinable with other StopOptions in the same call. +// For example, Stop(UntilIdleFor(30*time.Second), WithGraceful()) means +// "after 30 s of idle, stop gracefully". If another Stop call is made +// without UntilIdleFor (e.g. Stop(WithImmediate())), the loop shuts down +// immediately, bypassing the idle wait. +// +// Only the first UntilIdleFor duration takes effect; subsequent calls with +// a different duration are ignored. A Stop() call without UntilIdleFor always +// shuts down the loop immediately regardless of any pending idle timer. +// +// duration must be positive; passing a zero or negative value panics. +func UntilIdleFor(duration time.Duration) StopOption { + if duration <= 0 { + panic("adk: UntilIdleFor: duration must be positive") + } + return func(cfg *stopConfig) { + cfg.idleFor = duration + } +} + type pushConfig[T any] struct { preempt bool preemptDelay time.Duration @@ -823,23 +961,58 @@ type pushConfig[T any] struct { // PushOption is an option for Push(). type PushOption[T any] func(*pushConfig[T]) -// WithPreempt signals that the current agent should be canceled after pushing. -// This enables atomic "push + preempt" to avoid race conditions between -// pushing an urgent item and triggering preemption. -// The loop will cancel the current agent turn and continue with the next turn, -// where GenInput will see all buffered items including the newly pushed one. -func WithPreempt[T any](agentCancelOpts ...AgentCancelOption) PushOption[T] { +// WithPreempt signals that the current agent turn should be cancelled at the +// specified safePoint after pushing the new item. The loop cancels the current +// turn and starts a new one, where GenInput will see all buffered items +// including the newly pushed one. +// Use WithPreemptTimeout to add a timeout that escalates to immediate abort. +// +// Because safe points fire at turn-level boundaries (after the chat model +// returns or after all tool calls complete), no nested agent is running at +// the moment of cancellation — nested agents within AgentTools have either +// not started yet (AfterChatModel) or already finished (AfterToolCalls). +// If the preemption escalates to immediate via WithPreemptTimeout, any +// in-flight nested agent is torn down through Go context cancellation. +// +// WithPreempt and WithPreemptTimeout are mutually exclusive; if both are +// passed to the same Push call, the last one wins. +// +// safePoint must not be zero; passing SafePoint(0) panics. +func WithPreempt[T any](safePoint SafePoint) PushOption[T] { + if safePoint == 0 { + panic("adk: SafePoint must not be zero; use AfterToolCalls, AfterChatModel, or AnySafePoint") + } return func(cfg *pushConfig[T]) { cfg.preempt = true - cfg.agentCancelOpts = agentCancelOpts + cfg.agentCancelOpts = []AgentCancelOption{ + WithAgentCancelMode(safePoint.toCancelMode()), + } + } +} + +// WithPreemptTimeout is like WithPreempt but adds a timeout. If the agent has +// not reached the safe point within timeout, the preemption escalates to +// immediate cancellation. +// +// safePoint must not be zero; passing SafePoint(0) panics. +func WithPreemptTimeout[T any](safePoint SafePoint, timeout time.Duration) PushOption[T] { + if safePoint == 0 { + panic("adk: SafePoint must not be zero; use AfterToolCalls, AfterChatModel, or AnySafePoint") + } + return func(cfg *pushConfig[T]) { + cfg.preempt = true + cfg.agentCancelOpts = []AgentCancelOption{ + WithAgentCancelMode(safePoint.toCancelMode()), + WithAgentCancelTimeout(timeout), + } } } // WithPreemptDelay sets a delay duration before preemption takes effect. -// When used with WithPreempt, the push will succeed immediately, but the -// preemption signal will be delayed by the specified duration. -// This allows the current agent to continue processing for a grace period -// before being preempted. +// When used with WithPreempt or WithPreemptTimeout, the push will succeed +// immediately, but the preemption signal will be delayed by the specified +// duration. This allows the current agent to continue processing for a grace +// period before being preempted. func WithPreemptDelay[T any](delay time.Duration) PushOption[T] { return func(cfg *pushConfig[T]) { cfg.preemptDelay = delay @@ -861,7 +1034,7 @@ func WithPreemptDelay[T any](delay time.Duration) PushOption[T] { // return nil // between turns, plain push // } // if isLowPriority(tc.Consumed) { -// return []PushOption[MyItem]{WithPreempt[MyItem]()} +// return []PushOption[MyItem]{WithPreempt[MyItem](AnySafePoint)} // } // return nil // don't preempt high-priority work // })) @@ -900,7 +1073,7 @@ func NewTurnLoop[T any](cfg TurnLoopConfig[T]) *TurnLoop[T] { l := &TurnLoop[T]{ config: cfg, - buffer: internal.NewUnboundedChan[T](), + buffer: newTurnBuffer[T](), done: make(chan struct{}), stopSig: newStopSignal(), preemptSig: newPreemptSignal(), @@ -944,12 +1117,14 @@ func (l *TurnLoop[T]) Run(ctx context.Context) { // Once TakeLateItems() has been called, any subsequent push that would become a // late item will panic instead of being silently dropped. // -// Use WithPreempt() to atomically push an item and signal preemption of the current agent. -// This is useful for urgent items that should interrupt the current processing. +// Use WithPreempt() or WithPreemptTimeout() to atomically push an item and signal +// preemption of the current agent. This is useful for urgent items that should +// interrupt the current processing. // The returned channel may be waited on if the caller needs to ensure the preempt // signal has been observed. // -// Use WithPreemptDelay() together with WithPreempt() to delay the preemption signal. +// Use WithPreemptDelay() together with WithPreempt()/WithPreemptTimeout() to delay +// the preemption signal. // Push returns immediately after the item is buffered, and a goroutine is spawned // to signal preemption after the delay. func (l *TurnLoop[T]) Push(item T, opts ...PushOption[T]) (bool, <-chan struct{}) { @@ -1077,18 +1252,29 @@ func (l *TurnLoop[T]) pushWithConfig(item T, cfg *pushConfig[T]) (bool, <-chan s } // Stop signals the loop to stop and returns immediately (non-blocking). -// The loop will finish the current turn (or cancel it via WithAgentCancel options), -// then exit without starting a new turn. -// Use WithAgentCancel to control how the currently running agent is cancelled. -// This method is idempotent - multiple calls update cancel options. +// Without options, the current agent turn runs to completion and the loop +// exits at the turn boundary without starting a new turn. ExitReason is nil. +// +// Use WithImmediate() to abort the running agent turn immediately. +// Use WithGraceful() to cancel at the nearest safe point with recursive +// propagation to nested agents. +// Use WithGracefulTimeout() for safe-point cancel with an escalation deadline. +// Use UntilIdleFor() to defer the stop until the loop has been continuously +// idle for a given duration; the loop shuts down automatically once the idle +// timer fires. +// +// This method may be called multiple times; subsequent calls update cancel options. +// A Stop() call without UntilIdleFor shuts down the loop immediately, even if +// a prior UntilIdleFor is still waiting. // Call Wait() to block until the loop has fully exited and get the result. // // Stop may be called before Run. In that case, the stopped flag is set and // a subsequent Run will exit the loop immediately. // // If the running agent does not support the WithCancel AgentRunOption, -// Stop degrades to "exit the loop on entering the next iteration" — the -// current agent turn runs to completion before the loop exits. +// all cancel-related options (WithImmediate, WithGraceful, WithGracefulTimeout) +// degrade to "exit the loop on entering the next iteration" — the current +// agent turn runs to completion before the loop exits. func (l *TurnLoop[T]) Stop(opts ...StopOption) { cfg := &stopConfig{} for _, opt := range opts { @@ -1097,6 +1283,14 @@ func (l *TurnLoop[T]) Stop(opts ...StopOption) { l.stopSig.signal(cfg) + if cfg.idleFor > 0 { + l.buffer.Wakeup() + return + } + l.commitStop() +} + +func (l *TurnLoop[T]) commitStop() { l.stopOnce.Do(func() { l.stopSig.closeDone() atomic.StoreInt32(&l.stopped, 1) @@ -1109,7 +1303,7 @@ func (l *TurnLoop[T]) Stop(opts ...StopOption) { // All callers will receive the same result. // // Wait blocks until Run is called AND the loop exits. If Run is -// ever called, Wait blocks forever. +// never called, Wait blocks forever. func (l *TurnLoop[T]) Wait() *TurnLoopExitState[T] { <-l.done return l.result @@ -1153,7 +1347,36 @@ func (l *TurnLoop[T]) run(ctx context.Context) { pushBack = append(pushBack, pr.unhandled...) pushBack = append(pushBack, pr.newItems...) } else { - first, ok := l.buffer.Receive() + var first T + var ok bool + + if idleFor := l.stopSig.getIdleFor(); idleFor > 0 { + l.buffer.ClearWakeup() + idleTimer := time.NewTimer(idleFor) + cancelIdle := make(chan struct{}) + // When the idle timer fires, commitStop closes the buffer via + // buffer.Close(), which broadcasts to unblock the pending + // Receive() call below. + go func() { + select { + case <-idleTimer.C: + l.commitStop() + case <-cancelIdle: + } + }() + + first, ok = l.buffer.Receive() + + idleTimer.Stop() + close(cancelIdle) + } else { + first, ok = l.buffer.Receive() + // Woken up by Stop(UntilIdleFor); re-enter loop to start the idle timer. + if !ok && l.stopSig.getIdleFor() > 0 { + continue + } + } + if !ok { if err := ctx.Err(); err != nil { l.runErr = err @@ -1233,6 +1456,9 @@ func (l *TurnLoop[T]) run(ctx context.Context) { l.preemptSig.endTurnAndUnhold() if runErr != nil { + if errors.As(runErr, new(*CancelError)) && len(l.canceledItems) == 0 { + l.canceledItems = append([]T{}, plan.spec.consumed...) + } l.runErr = runErr return } @@ -1313,47 +1539,36 @@ func (l *TurnLoop[T]) watchPreemptSignal(done <-chan struct{}, agentCancelFunc A func (l *TurnLoop[T]) watchStopSignal(done <-chan struct{}, agentCancelFunc AgentCancelFunc, stoppedDone chan struct{}) { var lastGen uint64 stoppedClosed := false + + tryCancel := func(gen uint64, opts []AgentCancelOption) { + if gen == lastGen { + return + } + lastGen = gen + if opts == nil { + return + } + _, contributed := agentCancelFunc(opts...) + if contributed && !stoppedClosed { + close(stoppedDone) + stoppedClosed = true + } + } + for { select { case <-done: return case <-l.stopSig.notify: - gen, opts := l.stopSig.check() - if gen != lastGen { - lastGen = gen - // CancelHandle is intentionally not awaited here: agentCancelFunc - // commits the cancel signal synchronously, while waiting would block - // until the turn finishes and can deadlock this watcher against done. - _, contributed := agentCancelFunc(opts...) - if contributed && !stoppedClosed { - close(stoppedDone) - stoppedClosed = true - } - } + tryCancel(l.stopSig.check()) case <-l.stopSig.done: - gen, opts := l.stopSig.check() - if gen != lastGen { - lastGen = gen - _, contributed := agentCancelFunc(opts...) - if contributed && !stoppedClosed { - close(stoppedDone) - stoppedClosed = true - } - } + tryCancel(l.stopSig.check()) for { select { case <-done: return case <-l.stopSig.notify: - gen, opts := l.stopSig.check() - if gen != lastGen { - lastGen = gen - _, contributed := agentCancelFunc(opts...) - if contributed && !stoppedClosed { - close(stoppedDone) - stoppedClosed = true - } - } + tryCancel(l.stopSig.check()) } } } @@ -1366,11 +1581,6 @@ func (l *TurnLoop[T]) runAgentAndHandleEvents( spec *turnRunSpec[T], ) error { var iter *AsyncIterator[*AgentEvent] - defer func() { - if l.stopSig.isStopped() && len(l.canceledItems) == 0 { - l.canceledItems = append([]T{}, spec.consumed...) - } - }() runOpts, ms, err := l.setupBridgeStore(spec, spec.runOpts) if err != nil { diff --git a/adk/turn_loop_test.go b/adk/turn_loop_test.go index 1b8b2c86d..4f22ca1a7 100644 --- a/adk/turn_loop_test.go +++ b/adk/turn_loop_test.go @@ -26,6 +26,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/cloudwego/eino/schema" ) @@ -166,6 +167,42 @@ func newAndRunTurnLoop[T any](ctx context.Context, cfg TurnLoopConfig[T]) *TurnL return l } +func newPreemptTestLoop(t *testing.T, agent *turnLoopCancellableMockAgent) *TurnLoop[string] { + t.Helper() + + agentStarted := make(chan struct{}) + agentStartedOnce := sync.Once{} + + originalRunFunc := agent.runFunc + agent.runFunc = func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { close(agentStarted) }) + return originalRunFunc(ctx, input) + } + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil + }, + }) + + loop.Push("first") + + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } + + return loop +} + func TestTurnLoop_RunAndPush(t *testing.T) { processedItems := make([]string, 0) var mu sync.Mutex @@ -197,7 +234,7 @@ func TestTurnLoop_RunAndPush(t *testing.T) { defer mu.Unlock() assert.NoError(t, result.ExitReason) - assert.True(t, len(processedItems) > 0, "should have processed at least one item") + assert.NotEmpty(t, processedItems, "should have processed at least one item") } func TestTurnLoop_PushReturnsErrorAfterStop(t *testing.T) { @@ -296,7 +333,7 @@ func TestTurnLoop_UnhandledItemsOnStop(t *testing.T) { close(blocked) result := loop.Wait() - assert.True(t, len(result.UnhandledItems) >= 0, "should return unhandled items") + assert.NotEmpty(t, result.UnhandledItems, "should return unhandled items") } func TestTurnLoop_GenInputError(t *testing.T) { @@ -368,7 +405,7 @@ func TestTurnLoop_BatchProcessing(t *testing.T) { mu.Lock() defer mu.Unlock() - assert.True(t, len(batches) > 0, "should have processed at least one batch") + assert.NotEmpty(t, batches, "should have processed at least one batch") } func TestTurnLoop_StopWithMode(t *testing.T) { @@ -381,7 +418,7 @@ func TestTurnLoop_StopWithMode(t *testing.T) { }, }) - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelAfterToolCalls))) + loop.Stop(WithGraceful()) result := loop.Wait() assert.NoError(t, result.ExitReason) @@ -438,7 +475,7 @@ func TestTurnLoop_Preempt_CancelsCurrentAgent(t *testing.T) { t.Fatal("agent did not start") } - loop.Push("urgent", WithPreempt[string]()) + loop.Push("urgent", WithPreempt[string](AnySafePoint)) select { case <-agentCancelled: @@ -516,7 +553,7 @@ func TestTurnLoop_Preempt_DiscardsConsumedItems(t *testing.T) { t.Fatal("agent did not start") } - loop.Push("urgent", WithPreempt[string]()) + loop.Push("urgent", WithPreempt[string](AnySafePoint)) select { case <-agentDone: @@ -530,17 +567,13 @@ func TestTurnLoop_Preempt_DiscardsConsumedItems(t *testing.T) { mu.Lock() defer mu.Unlock() - assert.GreaterOrEqual(t, len(genInputResults), 2) - if len(genInputResults) >= 2 { - assert.NotContains(t, genInputResults[1], "first") - assert.Contains(t, genInputResults[1], "urgent") - } + require.GreaterOrEqual(t, len(genInputResults), 2) + assert.NotContains(t, genInputResults[1], "first") + assert.Contains(t, genInputResults[1], "urgent") } func TestTurnLoop_Preempt_WithAgentCancelMode(t *testing.T) { - agentStarted := make(chan struct{}) cancelFuncCalled := make(chan struct{}) - agentStartedOnce := sync.Once{} cancelFuncCalledOnce := sync.Once{} firstCancelModeUsed := CancelImmediate var cancelModeMu sync.Mutex @@ -548,9 +581,6 @@ func TestTurnLoop_Preempt_WithAgentCancelMode(t *testing.T) { agent := &turnLoopCancellableMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { - agentStartedOnce.Do(func() { - close(agentStarted) - }) <-ctx.Done() return &AgentOutput{}, nil }, @@ -564,28 +594,9 @@ func TestTurnLoop_Preempt_WithAgentCancelMode(t *testing.T) { }, } - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { - return agent, nil - }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ - Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, - Consumed: []string{items[0]}, - Remaining: items[1:], - }, nil - }, - }) - - loop.Push("first") + loop := newPreemptTestLoop(t, agent) - select { - case <-agentStarted: - case <-time.After(1 * time.Second): - t.Fatal("agent did not start") - } - - loop.Push("urgent", WithPreempt[string](WithAgentCancelMode(CancelAfterToolCalls))) + loop.Push("urgent", WithPreempt[string](AfterToolCalls)) select { case <-cancelFuncCalled: @@ -603,16 +614,13 @@ func TestTurnLoop_Preempt_WithAgentCancelMode(t *testing.T) { } func TestTurnLoop_PreemptAck_ClosesAfterCancelIsInitiated(t *testing.T) { - agentStarted := make(chan struct{}) cancelObserved := make(chan struct{}) agentFinishGate := make(chan struct{}) - agentStartedOnce := sync.Once{} cancelObservedOnce := sync.Once{} agent := &turnLoopCancellableMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { - agentStartedOnce.Do(func() { close(agentStarted) }) <-ctx.Done() <-agentFinishGate return &AgentOutput{}, nil @@ -622,28 +630,9 @@ func TestTurnLoop_PreemptAck_ClosesAfterCancelIsInitiated(t *testing.T) { }, } - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { - return agent, nil - }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ - Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, - Consumed: []string{items[0]}, - Remaining: items[1:], - }, nil - }, - }) - - _, _ = loop.Push("first") + loop := newPreemptTestLoop(t, agent) - select { - case <-agentStarted: - case <-time.After(1 * time.Second): - t.Fatal("agent did not start") - } - - ok, ack := loop.Push("urgent", WithPreempt[string](WithAgentCancelMode(CancelAfterToolCalls))) + ok, ack := loop.Push("urgent", WithPreempt[string](AfterToolCalls)) assert.True(t, ok) assert.NotNil(t, ack) @@ -676,7 +665,7 @@ func TestTurnLoop_PreemptAck_ClosesImmediatelyIfLoopNotStarted(t *testing.T) { }, }) - ok, ack := loop.Push("urgent", WithPreempt[string]()) + ok, ack := loop.Push("urgent", WithPreempt[string](AnySafePoint)) assert.True(t, ok) assert.NotNil(t, ack) @@ -688,10 +677,8 @@ func TestTurnLoop_PreemptAck_ClosesImmediatelyIfLoopNotStarted(t *testing.T) { } func TestTurnLoop_Preempt_EscalatesOnSecondPreempt(t *testing.T) { - agentStarted := make(chan struct{}) firstCancelSeen := make(chan struct{}) agentFinishGate := make(chan struct{}) - agentStartedOnce := sync.Once{} firstCancelOnce := sync.Once{} var ccPtr atomic.Value @@ -699,7 +686,6 @@ func TestTurnLoop_Preempt_EscalatesOnSecondPreempt(t *testing.T) { agent := &turnLoopCancellableMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { - agentStartedOnce.Do(func() { close(agentStarted) }) <-ctx.Done() <-agentFinishGate return &AgentOutput{}, nil @@ -710,35 +696,18 @@ func TestTurnLoop_Preempt_EscalatesOnSecondPreempt(t *testing.T) { }, } - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { - return agent, nil - }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ - Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, - Consumed: []string{items[0]}, - Remaining: items[1:], - }, nil - }, - }) + loop := newPreemptTestLoop(t, agent) - loop.Push("first") - select { - case <-agentStarted: - case <-time.After(1 * time.Second): - t.Fatal("agent did not start") - } - - loop.Push("urgent1", WithPreempt[string](WithAgentCancelMode(CancelAfterChatModel))) + loop.Push("urgent1", WithPreempt[string](AfterChatModel)) select { case <-firstCancelSeen: case <-time.After(1 * time.Second): t.Fatal("first preempt did not trigger cancel") } - loop.Push("urgent2", WithPreempt[string](WithAgentCancelMode(CancelImmediate))) + loop.Push("urgent2", WithPreemptTimeout[string](AnySafePoint, time.Millisecond)) + wantMode := CancelAfterChatModel | CancelAfterToolCalls deadline := time.Now().Add(1 * time.Second) for time.Now().Before(deadline) { v := ccPtr.Load() @@ -747,7 +716,7 @@ func TestTurnLoop_Preempt_EscalatesOnSecondPreempt(t *testing.T) { continue } cc := v.(*cancelContext) - if cc.getMode() == CancelImmediate && atomic.LoadInt32(&cc.escalated) == 1 { + if cc.getMode() == wantMode && atomic.LoadInt32(&cc.escalated) == 1 { break } time.Sleep(5 * time.Millisecond) @@ -758,7 +727,7 @@ func TestTurnLoop_Preempt_EscalatesOnSecondPreempt(t *testing.T) { t.Fatal("cancel context was not captured") } cc := v.(*cancelContext) - assert.Equal(t, CancelImmediate, cc.getMode()) + assert.Equal(t, wantMode, cc.getMode()) assert.Equal(t, int32(1), atomic.LoadInt32(&cc.escalated)) close(agentFinishGate) @@ -769,10 +738,8 @@ func TestTurnLoop_Preempt_EscalatesOnSecondPreempt(t *testing.T) { } func TestTurnLoop_Preempt_JoinsSafePointModesOnSecondPreempt(t *testing.T) { - agentStarted := make(chan struct{}) firstCancelSeen := make(chan struct{}) agentFinishGate := make(chan struct{}) - agentStartedOnce := sync.Once{} firstCancelOnce := sync.Once{} var ccPtr atomic.Value @@ -780,7 +747,6 @@ func TestTurnLoop_Preempt_JoinsSafePointModesOnSecondPreempt(t *testing.T) { agent := &turnLoopCancellableMockAgent{ name: "test", runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { - agentStartedOnce.Do(func() { close(agentStarted) }) <-ctx.Done() <-agentFinishGate return &AgentOutput{}, nil @@ -791,34 +757,16 @@ func TestTurnLoop_Preempt_JoinsSafePointModesOnSecondPreempt(t *testing.T) { }, } - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { - return agent, nil - }, - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ - Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, - Consumed: []string{items[0]}, - Remaining: items[1:], - }, nil - }, - }) - - loop.Push("first") - select { - case <-agentStarted: - case <-time.After(1 * time.Second): - t.Fatal("agent did not start") - } + loop := newPreemptTestLoop(t, agent) - loop.Push("urgent1", WithPreempt[string](WithAgentCancelMode(CancelAfterChatModel))) + loop.Push("urgent1", WithPreempt[string](AfterChatModel)) select { case <-firstCancelSeen: case <-time.After(1 * time.Second): t.Fatal("first preempt did not trigger cancel") } - loop.Push("urgent2", WithPreempt[string](WithAgentCancelMode(CancelAfterToolCalls))) + loop.Push("urgent2", WithPreempt[string](AfterToolCalls)) want := CancelAfterChatModel | CancelAfterToolCalls deadline := time.Now().Add(1 * time.Second) @@ -952,7 +900,7 @@ func TestTurnLoop_PreemptDelay_NoMispreemptOnNaturalCompletion(t *testing.T) { t.Fatal("agent1 did not start") } - loop.Push("second", WithPreempt[string](), WithPreemptDelay[string](500*time.Millisecond)) + loop.Push("second", WithPreempt[string](AnySafePoint), WithPreemptDelay[string](500*time.Millisecond)) select { case <-agent1Done: @@ -1093,7 +1041,7 @@ func TestTurnLoop_GetAgentError_RecoverConsumed(t *testing.T) { result := loop.Wait() assert.ErrorIs(t, result.ExitReason, agentErr) - assert.True(t, len(result.UnhandledItems) >= 1, "should recover at least the consumed item and remaining") + assert.NotEmpty(t, result.UnhandledItems, "should recover at least the consumed item and remaining") } func TestTurnLoop_GenInputError_RecoverItems(t *testing.T) { @@ -1293,7 +1241,7 @@ func TestTurnLoop_ContextCancelAfterGenInput_RecoverItems(t *testing.T) { result := loop.Wait() assert.ErrorIs(t, result.ExitReason, context.Canceled) - assert.True(t, len(result.UnhandledItems) >= 1, "should recover consumed and remaining items") + assert.NotEmpty(t, result.UnhandledItems, "should recover consumed and remaining items") } func TestTurnLoop_OnAgentEventsReceivesEvents(t *testing.T) { @@ -1340,7 +1288,7 @@ func TestTurnLoop_OnAgentEventsReceivesEvents(t *testing.T) { mu.Lock() defer mu.Unlock() - assert.True(t, len(receivedConsumed) > 0, "should have received consumed items") + assert.NotEmpty(t, receivedConsumed, "should have received consumed items") } func TestTurnLoop_StopDuringAgentExecution(t *testing.T) { @@ -1372,11 +1320,11 @@ func TestTurnLoop_StopDuringAgentExecution(t *testing.T) { loop.Push("msg1") <-agentStarted - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop.Stop() result := loop.Wait() assert.NoError(t, result.ExitReason) - assert.Equal(t, []string{"msg1"}, result.CanceledItems) + assert.Empty(t, result.CanceledItems) } func TestTurnLoop_StopCheckPointIDInCancelError(t *testing.T) { @@ -1420,7 +1368,7 @@ func TestTurnLoop_StopCheckPointIDInCancelError(t *testing.T) { loop.Push("msg1") <-modelStarted - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop.Stop() result := loop.Wait() @@ -1472,7 +1420,7 @@ func TestTurnLoop_StopWithoutCheckpointIDDoesNotPersist(t *testing.T) { loop.Push("msg1") <-modelStarted - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop.Stop() result := loop.Wait() @@ -1615,7 +1563,7 @@ func TestTurnLoop_StopDuringAgentExecution_PersistAndResume(t *testing.T) { loop.Push("msg1") <-modelStarted - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop.Stop() exit := loop.Wait() store.mu.Lock() @@ -1850,7 +1798,7 @@ func TestTurnLoop_CheckpointSaveError_ReturnsError(t *testing.T) { }) loop.Push("msg1") <-modelStarted - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop.Stop() exit := loop.Wait() assert.Error(t, exit.ExitReason) assert.True(t, exit.Checkpointed) @@ -2092,7 +2040,7 @@ func TestTurnLoop_GenResumeNil_Error(t *testing.T) { }) loop1.Push("msg1") <-modelStarted - loop1.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop1.Stop() loop1.Wait() loop2 := NewTurnLoop(TurnLoopConfig[string]{ @@ -2264,7 +2212,7 @@ func TestTurnLoop_GenResumeReturnsError(t *testing.T) { }) loop1.Push("msg1") <-modelStarted - loop1.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop1.Stop() loop1.Wait() genResumeErr := fmt.Errorf("resume callback failed") @@ -2323,7 +2271,7 @@ func TestTurnLoop_CheckpointSaveError_MergesWithExistingError(t *testing.T) { }) loop.Push("msg1") <-modelStarted - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop.Stop() exit := loop.Wait() assert.Error(t, exit.ExitReason) var ce *CancelError @@ -2371,7 +2319,7 @@ func TestTurnLoop_ResumeWithParams(t *testing.T) { }) loop1.Push("msg1") <-modelStarted - loop1.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop1.Stop() exit1 := loop1.Wait() var ce *CancelError assert.True(t, errors.As(exit1.ExitReason, &ce)) @@ -2413,25 +2361,6 @@ func TestTurnLoop_ResumeWithParams(t *testing.T) { _ = exit2 } -func TestTurnLoop_StopOptionsArePassed(t *testing.T) { - loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ - GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { - return &GenInputResult[string]{ - Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, - Consumed: items, - }, nil - }, - PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { - return &turnLoopMockAgent{name: "test"}, nil - }, - }) - - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelAfterToolCalls))) - - result := loop.Wait() - assert.NoError(t, result.ExitReason) -} - func TestTurnLoop_Stop_EscalatesCancelMode(t *testing.T) { ctx := context.Background() agentStarted := make(chan *cancelContext, 1) @@ -2451,8 +2380,8 @@ func TestTurnLoop_Stop_EscalatesCancelMode(t *testing.T) { loop.Push("msg1") cc := <-agentStarted - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelAfterToolCalls), WithAgentCancelTimeout(10*time.Second))) - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop.Stop(WithGracefulTimeout(10 * time.Second)) + loop.Stop(WithImmediate()) deadline := time.After(1 * time.Second) for { @@ -2469,7 +2398,7 @@ func TestTurnLoop_Stop_EscalatesCancelMode(t *testing.T) { exit := loop.Wait() var ce *CancelError - assert.True(t, errors.As(exit.ExitReason, &ce)) + require.True(t, errors.As(exit.ExitReason, &ce)) assert.Equal(t, CancelImmediate, ce.Info.Mode) } @@ -2587,7 +2516,7 @@ func TestTurnLoop_PushFromOnAgentEvents(t *testing.T) { result := loop.Wait() assert.NoError(t, result.ExitReason) - assert.GreaterOrEqual(t, atomic.LoadInt32(&pushCount), int32(2)) + assert.Equal(t, int32(2), atomic.LoadInt32(&pushCount)) } // Tests for NewTurnLoop: the permissive API where Push, Stop, and Wait are @@ -2894,7 +2823,7 @@ func TestTurnLoop_TurnContext_PreemptedChannel(t *testing.T) { loop.Push("msg1") <-agentStarted - loop.Push("msg2", WithPreempt[string](WithAgentCancelMode(CancelImmediate))) + loop.Push("msg2", WithPreemptTimeout[string](AnySafePoint, time.Millisecond)) select { case <-preemptedSeen: @@ -3103,7 +3032,7 @@ func TestTurnLoop_ConcurrentPreemptsDuringTurn(t *testing.T) { wg.Add(1) go func(i int) { defer wg.Done() - ok, ack := loop.Push(fmt.Sprintf("urgent-%d", i), WithPreempt[string]()) + ok, ack := loop.Push(fmt.Sprintf("urgent-%d", i), WithPreemptTimeout[string](AnySafePoint, time.Millisecond)) if ok && ack != nil { select { case <-ack: @@ -3156,7 +3085,7 @@ func TestTurnLoop_PreemptDuringTurnTransition(t *testing.T) { time.Sleep(50 * time.Millisecond) - ok, ack := loop.Push("transitional", WithPreempt[string]()) + ok, ack := loop.Push("transitional", WithPreempt[string](AnySafePoint)) assert.True(t, ok, "push should succeed") if ack != nil { select { @@ -3233,7 +3162,7 @@ func TestTurnLoop_PushStrategy_DuringTurnTransition(t *testing.T) { atomic.StoreInt32(&strategyTCNotNil, 1) } <-strategyBlocker - return []PushOption[string]{WithPreempt[string]()} + return []PushOption[string]{WithPreempt[string](AnySafePoint)} })) }() @@ -3298,7 +3227,7 @@ func TestTurnLoop_ConcurrentPreemptAndStop(t *testing.T) { go func() { defer wg.Done() - _, ack := loop.Push("preempt-item", WithPreempt[string]()) + _, ack := loop.Push("preempt-item", WithPreempt[string](AnySafePoint)) if ack != nil { <-ack } @@ -3360,7 +3289,7 @@ func TestTurnLoop_ConcurrentPushStrategyAndStop(t *testing.T) { go func() { defer wg.Done() _, ack := loop.Push("strategic-item", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { - return []PushOption[string]{WithPreempt[string]()} + return []PushOption[string]{WithPreempt[string](AnySafePoint)} })) if ack != nil { <-ack @@ -3417,7 +3346,7 @@ func TestTurnLoop_TurnContext_StoppedChannel(t *testing.T) { loop.Push("msg1") <-agentStarted - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop.Stop() select { case <-stoppedSeen: @@ -3429,6 +3358,110 @@ func TestTurnLoop_TurnContext_StoppedChannel(t *testing.T) { loop.Wait() } +func TestTurnLoop_TurnContext_BothPreemptedAndStopped(t *testing.T) { + t.Run("PreemptThenStop_OnlyPreemptContributes", func(t *testing.T) { + preemptedSeen := make(chan struct{}) + agentStarted := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "slow", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + select { + case <-tc.Preempted: + close(preemptedSeen) + case <-time.After(5 * time.Second): + t.Error("timed out waiting for Preempted") + } + for { + if _, ok := events.Next(); !ok { + break + } + } + return nil + }, + }) + + loop.Push("msg1") + <-agentStarted + loop.Push("msg2", WithPreemptTimeout[string](AnySafePoint, time.Millisecond)) + + select { + case <-preemptedSeen: + case <-time.After(5 * time.Second): + t.Fatal("Preempted channel was never closed") + } + + loop.Stop(WithImmediate()) + loop.Wait() + }) + + t.Run("StopThenPreempt_OnlyStopContributes", func(t *testing.T) { + stoppedSeen := make(chan struct{}) + agentStarted := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "slow", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + select { + case <-tc.Stopped: + close(stoppedSeen) + case <-time.After(5 * time.Second): + t.Error("timed out waiting for Stopped") + } + for { + if _, ok := events.Next(); !ok { + break + } + } + return nil + }, + }) + + loop.Push("msg1") + <-agentStarted + loop.Stop(WithImmediate()) + + select { + case <-stoppedSeen: + case <-time.After(5 * time.Second): + t.Fatal("Stopped channel was never closed") + } + + loop.Push("msg2", WithPreemptTimeout[string](AnySafePoint, time.Millisecond)) + loop.Wait() + }) +} + func TestTurnLoop_PushStrategy_DuringTurn(t *testing.T) { agentStarted := make(chan struct{}) agentStartedOnce := sync.Once{} @@ -3486,7 +3519,7 @@ func TestTurnLoop_PushStrategy_DuringTurn(t *testing.T) { loop.Push("urgent", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { atomic.AddInt32(&strategyCalled, 1) strategyTC = tc - return []PushOption[string]{WithPreempt[string]()} + return []PushOption[string]{WithPreempt[string](AnySafePoint)} })) select { @@ -3608,7 +3641,7 @@ func TestTurnLoop_PushStrategy_OverridesOtherOptions(t *testing.T) { // Strategy returns nil (no preempt), even though WithPreempt is also passed. // The strategy should override — so the agent should NOT be preempted. - ok, ack := loop.Push("item", WithPreempt[string](), WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + ok, ack := loop.Push("item", WithPreempt[string](AnySafePoint), WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { return nil // no preempt })) assert.True(t, ok) @@ -3666,7 +3699,7 @@ func TestTurnLoop_PushStrategy_NestedStrategyStripped(t *testing.T) { return []PushOption[string]{ WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { atomic.AddInt32(&innerCalled, 1) - return []PushOption[string]{WithPreempt[string]()} + return []PushOption[string]{WithPreempt[string](AnySafePoint)} }), } })) @@ -3735,7 +3768,7 @@ func TestTurnLoop_PushStrategy_ConsumedInspection(t *testing.T) { // Strategy checks Consumed and preempts because current turn has "low-priority" items. loop.Push("urgent-task", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { if tc != nil && len(tc.Consumed) > 0 && tc.Consumed[0] == "low-priority-task" { - return []PushOption[string]{WithPreempt[string]()} + return []PushOption[string]{WithPreempt[string](AnySafePoint)} } return nil })) @@ -4235,7 +4268,7 @@ func TestTurnLoop_StopCause_InTurnContext(t *testing.T) { loop.Push("msg1") <-agentStarted - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate)), WithStopCause(cause)) + loop.Stop(WithStopCause(cause)) select { case c := <-gotCause: @@ -4280,13 +4313,68 @@ func TestTurnLoop_StopCause_FirstNonEmptyWins(t *testing.T) { loop.Push("msg1") <-agentStarted - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelAfterToolCalls)), WithStopCause("first cause")) - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate)), WithStopCause("second cause")) + loop.Stop(WithGraceful(), WithStopCause("first cause")) + loop.Stop(WithStopCause("second cause")) exit := loop.Wait() assert.Equal(t, "first cause", exit.StopCause, "first non-empty StopCause should win") } +func TestTurnLoop_StopBeforeRun_PushThenStop(t *testing.T) { + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + t.Fatal("GenInput should not be called when Stop is called before Run") + return nil, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + t.Fatal("PrepareAgent should not be called when Stop is called before Run") + return nil, nil + }, + }) + + ok, _ := loop.Push("item1") + assert.True(t, ok) + ok, _ = loop.Push("item2") + assert.True(t, ok) + + loop.Stop() + loop.Run(context.Background()) + result := loop.Wait() + + assert.NoError(t, result.ExitReason) + assert.Equal(t, []string{"item1", "item2"}, result.UnhandledItems) + assert.Empty(t, result.CanceledItems) + assert.Empty(t, result.TakeLateItems()) +} + +func TestTurnLoop_StopBeforeRun_StopThenPush(t *testing.T) { + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + t.Fatal("GenInput should not be called when Stop is called before Run") + return nil, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + t.Fatal("PrepareAgent should not be called when Stop is called before Run") + return nil, nil + }, + }) + + loop.Stop() + + ok, _ := loop.Push("item1") + assert.False(t, ok) + ok, _ = loop.Push("item2") + assert.False(t, ok) + + loop.Run(context.Background()) + result := loop.Wait() + + assert.NoError(t, result.ExitReason) + assert.Empty(t, result.UnhandledItems) + assert.Empty(t, result.CanceledItems) + assert.Equal(t, []string{"item1", "item2"}, result.TakeLateItems()) +} + func TestTurnLoop_SkipCheckpoint_Sticky(t *testing.T) { agentStarted := make(chan struct{}) @@ -4324,8 +4412,8 @@ func TestTurnLoop_SkipCheckpoint_Sticky(t *testing.T) { loop.Push("msg1") <-agentStarted - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelAfterToolCalls)), WithSkipCheckpoint()) - loop.Stop(WithAgentCancel(WithAgentCancelMode(CancelImmediate))) + loop.Stop(WithGraceful(), WithSkipCheckpoint()) + loop.Stop() exit := loop.Wait() assert.False(t, exit.Checkpointed, "SkipCheckpoint should be sticky across multiple Stop calls") @@ -4335,3 +4423,217 @@ func TestTurnLoop_SkipCheckpoint_Sticky(t *testing.T) { store.mu.Unlock() assert.False(t, exists, "no checkpoint should be saved when SkipCheckpoint was set in any Stop call") } + +func TestWithGracefulTimeout_NonPositive_Panics(t *testing.T) { + assert.PanicsWithValue(t, "adk: WithGracefulTimeout: gracePeriod must be positive", + func() { WithGracefulTimeout(0) }) + assert.PanicsWithValue(t, "adk: WithGracefulTimeout: gracePeriod must be positive", + func() { WithGracefulTimeout(-1 * time.Second) }) +} + +func TestWithPreempt_ZeroSafePoint_Panics(t *testing.T) { + assert.PanicsWithValue(t, "adk: SafePoint must not be zero; use AfterToolCalls, AfterChatModel, or AnySafePoint", + func() { WithPreempt[string](SafePoint(0)) }) +} + +func TestWithPreemptTimeout_ZeroSafePoint_Panics(t *testing.T) { + assert.PanicsWithValue(t, "adk: SafePoint must not be zero; use AfterToolCalls, AfterChatModel, or AnySafePoint", + func() { WithPreemptTimeout[string](SafePoint(0), time.Second) }) +} + +func TestSafePoint_ToCancelMode(t *testing.T) { + assert.Equal(t, CancelAfterToolCalls, AfterToolCalls.toCancelMode()) + assert.Equal(t, CancelAfterChatModel, AfterChatModel.toCancelMode()) + assert.Equal(t, CancelAfterToolCalls|CancelAfterChatModel, AnySafePoint.toCancelMode()) +} + +func TestNewTurnLoop_NilGenInput_Panics(t *testing.T) { + assert.PanicsWithValue(t, "adk: NewTurnLoop: GenInput is required", func() { + NewTurnLoop(TurnLoopConfig[string]{PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { return nil, nil }}) + }) +} + +func TestNewTurnLoop_NilPrepareAgent_Panics(t *testing.T) { + assert.PanicsWithValue(t, "adk: NewTurnLoop: PrepareAgent is required", func() { + NewTurnLoop(TurnLoopConfig[string]{GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return nil, nil + }}) + }) +} + +func TestDeriveChild_NilParent_ReturnsNil(t *testing.T) { + var cc *cancelContext + assert.Nil(t, cc.deriveChild(context.Background())) +} + +func TestUntilIdleFor(t *testing.T) { + t.Run("FiresAfterIdleDuration", func(t *testing.T) { + turnDone := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(turnDone) + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-turnDone + + loop.Stop(UntilIdleFor(50 * time.Millisecond)) + + done := make(chan struct{}) + go func() { + loop.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("loop did not exit after idle timeout") + } + }) + + t.Run("ResetsOnPush", func(t *testing.T) { + turnCount := int32(0) + turnDone := make(chan struct{}, 10) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + atomic.AddInt32(&turnCount, 1) + turnDone <- struct{}{} + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-turnDone + + loop.Stop(UntilIdleFor(200 * time.Millisecond)) + + time.Sleep(100 * time.Millisecond) + loop.Push("msg2") + <-turnDone + + done := make(chan struct{}) + go func() { + loop.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("loop did not exit after idle timeout") + } + + assert.Equal(t, int32(2), atomic.LoadInt32(&turnCount)) + }) + + t.Run("EscalatedByStopWithImmediate", func(t *testing.T) { + agentStarted := make(chan *cancelContext, 1) + probe := &turnLoopStopModeProbeAgent{ccCh: agentStarted} + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return probe, nil + }, + }) + + loop.Push("msg1") + cc := <-agentStarted + + loop.Stop(UntilIdleFor(10 * time.Minute)) + loop.Stop(WithImmediate()) + + deadline := time.After(2 * time.Second) + for { + if cc.getMode() == CancelImmediate { + break + } + select { + case <-deadline: + t.Fatal("cancel mode did not escalate to CancelImmediate") + default: + } + time.Sleep(1 * time.Millisecond) + } + + exit := loop.Wait() + var ce *CancelError + require.True(t, errors.As(exit.ExitReason, &ce)) + assert.Equal(t, CancelImmediate, ce.Info.Mode) + }) + + t.Run("EscalatedByStopWithGraceful", func(t *testing.T) { + agentStarted := make(chan struct{}) + agentDone := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(agentStarted) + <-ctx.Done() + close(agentDone) + return nil, ctx.Err() + }, + }, nil + }, + }) + + loop.Push("msg1") + <-agentStarted + + loop.Stop(UntilIdleFor(10 * time.Minute)) + loop.Stop(WithGracefulTimeout(50 * time.Millisecond)) + + select { + case <-agentDone: + case <-time.After(2 * time.Second): + t.Fatal("agent was not cancelled") + } + + exit := loop.Wait() + assert.Error(t, exit.ExitReason) + }) +} + +func TestUntilIdleFor_NonPositive_Panics(t *testing.T) { + assert.PanicsWithValue(t, "adk: UntilIdleFor: duration must be positive", + func() { UntilIdleFor(0) }) + assert.PanicsWithValue(t, "adk: UntilIdleFor: duration must be positive", + func() { UntilIdleFor(-1 * time.Second) }) +} diff --git a/adk/wrappers_test.go b/adk/wrappers_test.go index acb6588be..f231e3c07 100644 --- a/adk/wrappers_test.go +++ b/adk/wrappers_test.go @@ -1402,7 +1402,7 @@ func TestEventSenderToolHandler(t *testing.T) { NewEventSenderToolWrapper(), &invokableResultModifier{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, - modifiedResult: modifiedResult, + modifiedResult: modifiedResult, }, }, }) @@ -1490,7 +1490,7 @@ func TestEventSenderToolHandler(t *testing.T) { NewEventSenderToolWrapper(), &streamableResultModifier{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, - modifiedResult: modifiedResult, + modifiedResult: modifiedResult, }, }, }) @@ -1578,7 +1578,7 @@ func TestEventSenderToolHandler(t *testing.T) { NewEventSenderToolWrapper(), &enhancedInvokableResultModifier{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, - modifiedResult: modifiedResult, + modifiedResult: modifiedResult, }, }, }) @@ -1666,7 +1666,7 @@ func TestEventSenderToolHandler(t *testing.T) { NewEventSenderToolWrapper(), &enhancedStreamableResultModifier{ BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, - modifiedResult: modifiedResult, + modifiedResult: modifiedResult, }, }, }) diff --git a/components/prompt/agentic_chat_template_test.go b/components/prompt/agentic_chat_template_test.go index 42d7a8630..f47020a2c 100644 --- a/components/prompt/agentic_chat_template_test.go +++ b/components/prompt/agentic_chat_template_test.go @@ -21,9 +21,10 @@ import ( "errors" "testing" + "github.com/stretchr/testify/assert" + "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/schema" - "github.com/stretchr/testify/assert" ) type mockAgenticTemplate struct { diff --git a/internal/channel.go b/internal/channel.go index 8e36d8939..fa4215359 100644 --- a/internal/channel.go +++ b/internal/channel.go @@ -61,17 +61,18 @@ func (ch *UnboundedChan[T]) TrySend(value T) bool { return true } -// Receive gets an item from the channel (blocks if empty) +// Receive gets an item from the channel (blocks if empty). +// Returns (value, true) if an item was received. +// Returns (zero, false) if the channel was closed with no data remaining. func (ch *UnboundedChan[T]) Receive() (T, bool) { ch.mutex.Lock() defer ch.mutex.Unlock() for len(ch.buffer) == 0 && !ch.closed { - ch.notEmpty.Wait() // Wait until data is available + ch.notEmpty.Wait() } if len(ch.buffer) == 0 { - // Channel is closed and empty var zero T return zero, false } @@ -88,36 +89,6 @@ func (ch *UnboundedChan[T]) Close() { if !ch.closed { ch.closed = true - ch.notEmpty.Broadcast() // Wake up all waiting goroutines + ch.notEmpty.Broadcast() } } - -// TakeAll removes and returns all values from the channel atomically. -// Returns nil if the channel is empty. -func (ch *UnboundedChan[T]) TakeAll() []T { - ch.mutex.Lock() - defer ch.mutex.Unlock() - - if len(ch.buffer) == 0 { - return nil - } - - values := ch.buffer - ch.buffer = nil - return values -} - -// PushFront adds values to the front of the channel. -// This is useful for recovering values that need to be reprocessed. -// Does nothing if values is empty. -func (ch *UnboundedChan[T]) PushFront(values []T) { - if len(values) == 0 { - return - } - - ch.mutex.Lock() - defer ch.mutex.Unlock() - - ch.buffer = append(append([]T{}, values...), ch.buffer...) - ch.notEmpty.Signal() -} diff --git a/internal/channel_test.go b/internal/channel_test.go index bed2383f1..736a27413 100644 --- a/internal/channel_test.go +++ b/internal/channel_test.go @@ -219,244 +219,3 @@ func TestUnboundedChan_BlockingReceive(t *testing.T) { t.Error("Receive should have unblocked") } } - -func TestUnboundedChan_TakeAll(t *testing.T) { - ch := NewUnboundedChan[int]() - - // Test TakeAll on empty channel - items := ch.TakeAll() - if items != nil { - t.Errorf("TakeAll on empty channel should return nil, got %v", items) - } - - // Send some values - ch.Send(1) - ch.Send(2) - ch.Send(3) - - // Test TakeAll returns all values - items = ch.TakeAll() - if len(items) != 3 { - t.Errorf("expected 3 values, got %d", len(items)) - } - if items[0] != 1 || items[1] != 2 || items[2] != 3 { - t.Errorf("unexpected values: %v", items) - } - - // Channel should be empty now - if len(ch.buffer) != 0 { - t.Errorf("channel should be empty after TakeAll, got %d values", len(ch.buffer)) - } - - // TakeAll again should return nil - items = ch.TakeAll() - if items != nil { - t.Errorf("TakeAll on empty channel should return nil, got %v", items) - } -} - -func TestUnboundedChan_TakeAll_Partial(t *testing.T) { - ch := NewUnboundedChan[int]() - - // Send values - ch.Send(1) - ch.Send(2) - ch.Send(3) - - // Receive one - val, ok := ch.Receive() - if !ok || val != 1 { - t.Errorf("expected (1, true), got (%d, %v)", val, ok) - } - - // TakeAll should return remaining values - items := ch.TakeAll() - if len(items) != 2 { - t.Errorf("expected 2 values, got %d", len(items)) - } - if items[0] != 2 || items[1] != 3 { - t.Errorf("unexpected values: %v", items) - } -} - -func TestUnboundedChan_PushFront(t *testing.T) { - ch := NewUnboundedChan[int]() - - // Test PushFront with empty values (should do nothing) - ch.PushFront(nil) - ch.PushFront([]int{}) - if len(ch.buffer) != 0 { - t.Errorf("PushFront with empty values should not add anything, got %d values", len(ch.buffer)) - } - - // Send some values - ch.Send(3) - ch.Send(4) - - // PushFront should prepend values - ch.PushFront([]int{1, 2}) - - if len(ch.buffer) != 4 { - t.Errorf("expected 4 values, got %d", len(ch.buffer)) - } - if ch.buffer[0] != 1 || ch.buffer[1] != 2 || ch.buffer[2] != 3 || ch.buffer[3] != 4 { - t.Errorf("unexpected buffer: %v", ch.buffer) - } - - // Receive should return in correct order - val, _ := ch.Receive() - if val != 1 { - t.Errorf("expected 1, got %d", val) - } - val, _ = ch.Receive() - if val != 2 { - t.Errorf("expected 2, got %d", val) - } -} - -func TestUnboundedChan_PushFront_EmptyChannel(t *testing.T) { - ch := NewUnboundedChan[int]() - - // PushFront to empty channel - ch.PushFront([]int{1, 2, 3}) - - if len(ch.buffer) != 3 { - t.Errorf("expected 3 values, got %d", len(ch.buffer)) - } - - // Receive should work - val, ok := ch.Receive() - if !ok || val != 1 { - t.Errorf("expected (1, true), got (%d, %v)", val, ok) - } -} - -func TestUnboundedChan_PushFront_UnblocksReceive(t *testing.T) { - ch := NewUnboundedChan[int]() - - // Start a blocking receive - receiveDone := make(chan int) - go func() { - val, _ := ch.Receive() - receiveDone <- val - }() - - // Ensure receive is blocked - select { - case <-receiveDone: - t.Error("Receive should block on empty channel") - case <-time.After(50 * time.Millisecond): - // This is expected - } - - // PushFront should unblock the receive - ch.PushFront([]int{42}) - - select { - case val := <-receiveDone: - if val != 42 { - t.Errorf("expected 42, got %d", val) - } - case <-time.After(50 * time.Millisecond): - t.Error("Receive should have unblocked after PushFront") - } -} - -func TestUnboundedChan_PushFront_SpareCapacity(t *testing.T) { - ch := NewUnboundedChan[int]() - - // Pre-fill the channel so PushFront has something to append - ch.Send(10) - ch.Send(20) - - // Create a slice with spare capacity: len=2, cap=10. - // Elements beyond len (index 2-9) must not be corrupted by PushFront. - src := make([]int, 3, 10) - src[0] = 1 - src[1] = 2 - src[2] = 3 // sentinel — must survive PushFront(src[:2]) - - ch.PushFront(src[:2]) - - // Verify the sentinel was NOT overwritten by the channel's existing buffer - if src[2] != 3 { - t.Errorf("PushFront corrupted caller's backing array: src[2] = %d, want 3", src[2]) - } - - // Verify channel drains correctly: [1, 2, 10, 20] - expected := []int{1, 2, 10, 20} - for i, want := range expected { - got, ok := ch.Receive() - if !ok { - t.Fatalf("Receive returned ok=false at index %d", i) - } - if got != want { - t.Errorf("index %d: got %d, want %d", i, got, want) - } - } -} - -func TestUnboundedChan_TakeAll_PushFront_Concurrent(t *testing.T) { - ch := NewUnboundedChan[int]() - const numOps = 100 - - var wg sync.WaitGroup - wg.Add(3) - - // Goroutine 1: Send values - go func() { - defer wg.Done() - for i := 0; i < numOps; i++ { - ch.Send(i) - time.Sleep(time.Microsecond) - } - }() - - // Goroutine 2: TakeAll periodically - takeAllResults := make([][]int, 0) - var mu sync.Mutex - go func() { - defer wg.Done() - for i := 0; i < numOps/10; i++ { - items := ch.TakeAll() - if items != nil { - mu.Lock() - takeAllResults = append(takeAllResults, items) - mu.Unlock() - } - time.Sleep(10 * time.Microsecond) - } - }() - - // Goroutine 3: PushFront periodically - go func() { - defer wg.Done() - for i := 0; i < numOps/10; i++ { - ch.PushFront([]int{-i}) - time.Sleep(10 * time.Microsecond) - } - }() - - wg.Wait() - ch.Close() - - // Drain remaining values - remaining := ch.TakeAll() - if remaining != nil { - mu.Lock() - takeAllResults = append(takeAllResults, remaining) - mu.Unlock() - } - - // Count total values collected - total := 0 - for _, batch := range takeAllResults { - total += len(batch) - } - - // We should have exactly numOps (from Send) + numOps/10 (from PushFront) values - expected := numOps + numOps/10 - if total != expected { - t.Errorf("expected %d values, got %d", expected, total) - } -} From f3adfba5b6d3918a9c0eef549dfbbf163a106cce Mon Sep 17 00:00:00 2001 From: "shentong.martin" Date: Tue, 14 Apr 2026 21:30:59 +0800 Subject: [PATCH 55/59] feat(middlewares): add permission middleware for tool call gating Change-Id: I04e50d2736406cd2bed030715ec02959cb33dc68 --- adk/middlewares/permission/permission.go | 208 ++++++++++ adk/middlewares/permission/permission_test.go | 382 ++++++++++++++++++ 2 files changed, 590 insertions(+) create mode 100644 adk/middlewares/permission/permission.go create mode 100644 adk/middlewares/permission/permission_test.go diff --git a/adk/middlewares/permission/permission.go b/adk/middlewares/permission/permission.go new file mode 100644 index 000000000..05766926f --- /dev/null +++ b/adk/middlewares/permission/permission.go @@ -0,0 +1,208 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package permission provides a ChatModelAgentMiddleware that gates tool execution +// behind a user-defined permission check (BeforeToolCall). It supports three decisions: +// Allow (execute the tool), Deny (return a deny message as tool result), and Ask +// (interrupt the agent loop via StatefulInterrupt for external approval). +package permission + +import ( + "context" + "fmt" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/schema" +) + +func init() { + schema.RegisterName[*AskInfo]("_eino_adk_permission_ask_info") + schema.RegisterName[*AskState]("_eino_adk_permission_ask_state") +} + +// Decision represents the outcome of a permission check. +type Decision string + +const ( + // Allow permits the tool to execute. + Allow Decision = "allow" + // Deny blocks the tool and returns a deny message as the tool result. + Deny Decision = "deny" + // Ask interrupts the agent loop to await external approval. + Ask Decision = "ask" +) + +// ToolCallDecision is the result of a BeforeToolCall evaluation. +type ToolCallDecision struct { + Decision Decision + Message string + UpdatedInput string + Reason string +} + +// BeforeToolCall is the user-provided evaluation function invoked before each tool call. +// It returns a ToolCallDecision that determines whether the call is allowed, denied, or +// requires interactive approval. Returning an error signals an infrastructure failure +// and aborts the agent loop; permission denials should use Decision: Deny instead. +type BeforeToolCall func(ctx context.Context, toolName string, argumentsInJSON string) (*ToolCallDecision, error) + +// AskInfo is the interrupt info exposed to external consumers (UI / OnAgentEvents). +// All fields are basic types to satisfy gob serialization requirements. +type AskInfo struct { + ToolName string + CallID string + Arguments string + Message string +} + +// AskState is the interrupt state persisted via CheckPointStore (gob serialization). +type AskState struct { + Info *AskInfo +} + +// ResumeResponse is the decision injected externally via ResumeWithParams. +type ResumeResponse struct { + Approved bool + UpdatedInput string + DenyMessage string +} + +// Middleware is a ChatModelAgentMiddleware that gates tool execution behind +// a user-defined BeforeToolCall permission check. +type Middleware struct { + *adk.BaseChatModelAgentMiddleware + beforeToolCall BeforeToolCall +} + +// NewMiddleware creates a new permission Middleware with the given BeforeToolCall evaluator. +func NewMiddleware(beforeToolCall BeforeToolCall) *Middleware { + return &Middleware{ + BaseChatModelAgentMiddleware: &adk.BaseChatModelAgentMiddleware{}, + beforeToolCall: beforeToolCall, + } +} + +type gateResult struct { + allowed bool + denyResult string + updatedInput string +} + +func (m *Middleware) permissionGate( + ctx context.Context, + tCtx *adk.ToolContext, + argumentsInJSON string, +) (*gateResult, error) { + wasInterrupted, _, savedState := tool.GetInterruptState[*AskState](ctx) + isTarget, hasData, response := tool.GetResumeContext[*ResumeResponse](ctx) + + if wasInterrupted && !isTarget { + return nil, tool.StatefulInterrupt(ctx, savedState.Info, savedState) + } + + if isTarget && hasData { + if !response.Approved { + return &gateResult{denyResult: formatDenyResult(tCtx.Name, response.DenyMessage)}, nil + } + input := argumentsInJSON + if response.UpdatedInput != "" { + input = response.UpdatedInput + } + return &gateResult{allowed: true, updatedInput: input}, nil + } + + if isTarget && !hasData { + return nil, fmt.Errorf("permission: tool %s was targeted for resume but received nil or type-mismatched ResumeResponse", tCtx.Name) + } + + decision, err := m.beforeToolCall(ctx, tCtx.Name, argumentsInJSON) + if err != nil { + return nil, fmt.Errorf("permission check failed: %w", err) + } + if decision == nil { + return nil, fmt.Errorf("permission: BeforeToolCall returned nil decision for tool %s", tCtx.Name) + } + + switch decision.Decision { + case Allow: + input := argumentsInJSON + if decision.UpdatedInput != "" { + input = decision.UpdatedInput + } + return &gateResult{allowed: true, updatedInput: input}, nil + + case Deny: + return &gateResult{denyResult: formatDenyResult(tCtx.Name, decision.Message)}, nil + + case Ask: + info := &AskInfo{ + ToolName: tCtx.Name, + CallID: tCtx.CallID, + Arguments: argumentsInJSON, + Message: decision.Message, + } + state := &AskState{Info: info} + return nil, tool.StatefulInterrupt(ctx, info, state) + + default: + return &gateResult{denyResult: formatDenyResult(tCtx.Name, + fmt.Sprintf("unknown permission decision %q", decision.Decision))}, nil + } +} + +// WrapInvokableToolCall intercepts synchronous tool calls with a permission gate. +func (m *Middleware) WrapInvokableToolCall( + ctx context.Context, + endpoint adk.InvokableToolCallEndpoint, + tCtx *adk.ToolContext, +) (adk.InvokableToolCallEndpoint, error) { + return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { + result, err := m.permissionGate(ctx, tCtx, argumentsInJSON) + if err != nil { + return "", err + } + if !result.allowed { + return result.denyResult, nil + } + return endpoint(ctx, result.updatedInput, opts...) + }, nil +} + +// WrapStreamableToolCall intercepts streaming tool calls with a permission gate. +func (m *Middleware) WrapStreamableToolCall( + ctx context.Context, + endpoint adk.StreamableToolCallEndpoint, + tCtx *adk.ToolContext, +) (adk.StreamableToolCallEndpoint, error) { + return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) { + result, err := m.permissionGate(ctx, tCtx, argumentsInJSON) + if err != nil { + return nil, err + } + if !result.allowed { + sr, sw := schema.Pipe[string](1) + sw.Send(result.denyResult, nil) + sw.Close() + return sr, nil + } + return endpoint(ctx, result.updatedInput, opts...) + }, nil +} + +func formatDenyResult(toolName, message string) string { + return fmt.Sprintf("Permission denied for tool %s: %s", toolName, message) +} diff --git a/adk/middlewares/permission/permission_test.go b/adk/middlewares/permission/permission_test.go new file mode 100644 index 000000000..4a95a304a --- /dev/null +++ b/adk/middlewares/permission/permission_test.go @@ -0,0 +1,382 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package permission + +import ( + "context" + "errors" + "fmt" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/internal/core" + "github.com/cloudwego/eino/schema" +) + +func TestNewMiddleware(t *testing.T) { + called := false + m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + called = true + return &ToolCallDecision{Decision: Allow}, nil + }) + assert.NotNil(t, m) + assert.NotNil(t, m.beforeToolCall) + assert.NotNil(t, m.BaseChatModelAgentMiddleware) + assert.False(t, called) +} + +func TestFormatDenyResult(t *testing.T) { + result := formatDenyResult("WriteFile", "destructive operation blocked") + assert.Equal(t, "Permission denied for tool WriteFile: destructive operation blocked", result) +} + +func makeCtxWithAddr() context.Context { + ctx := context.Background() + return core.AppendAddressSegment(ctx, "agent", "test-agent", "") +} + +func TestPermissionGate_Allow(t *testing.T) { + m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + assert.Equal(t, "ReadFile", toolName) + assert.Equal(t, `{"path":"/tmp/x"}`, args) + return &ToolCallDecision{Decision: Allow, Reason: "read-only"}, nil + }) + + tCtx := &adk.ToolContext{Name: "ReadFile", CallID: "call_1"} + result, err := m.permissionGate(context.Background(), tCtx, `{"path":"/tmp/x"}`) + require.NoError(t, err) + assert.True(t, result.allowed) + assert.Equal(t, `{"path":"/tmp/x"}`, result.updatedInput) +} + +func TestPermissionGate_AllowWithUpdatedInput(t *testing.T) { + m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + return &ToolCallDecision{ + Decision: Allow, + UpdatedInput: `{"path":"/tmp/safe"}`, + }, nil + }) + + tCtx := &adk.ToolContext{Name: "ReadFile", CallID: "call_1"} + result, err := m.permissionGate(context.Background(), tCtx, `{"path":"/tmp/danger"}`) + require.NoError(t, err) + assert.True(t, result.allowed) + assert.Equal(t, `{"path":"/tmp/safe"}`, result.updatedInput) +} + +func TestPermissionGate_Deny(t *testing.T) { + m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + return &ToolCallDecision{ + Decision: Deny, + Message: "operation not allowed", + Reason: "policy", + }, nil + }) + + tCtx := &adk.ToolContext{Name: "DeleteFile", CallID: "call_2"} + result, err := m.permissionGate(context.Background(), tCtx, `{}`) + require.NoError(t, err) + assert.False(t, result.allowed) + assert.Equal(t, "Permission denied for tool DeleteFile: operation not allowed", result.denyResult) +} + +func TestPermissionGate_Ask(t *testing.T) { + m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + return &ToolCallDecision{ + Decision: Ask, + Message: "requires approval", + }, nil + }) + + tCtx := &adk.ToolContext{Name: "Execute", CallID: "call_3"} + ctx := makeCtxWithAddr() + result, err := m.permissionGate(ctx, tCtx, `{"cmd":"rm -rf /"}`) + assert.Nil(t, result) + require.Error(t, err) + + var is *core.InterruptSignal + assert.True(t, errors.As(err, &is)) +} + +func TestPermissionGate_UnknownDecision(t *testing.T) { + m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: "maybe"}, nil + }) + + tCtx := &adk.ToolContext{Name: "SomeTool", CallID: "call_4"} + result, err := m.permissionGate(context.Background(), tCtx, `{}`) + require.NoError(t, err) + assert.False(t, result.allowed) + assert.Contains(t, result.denyResult, "unknown permission decision") + assert.Contains(t, result.denyResult, "maybe") +} + +func TestPermissionGate_NilDecision(t *testing.T) { + m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + return nil, nil + }) + + tCtx := &adk.ToolContext{Name: "SomeTool", CallID: "call_5"} + result, err := m.permissionGate(context.Background(), tCtx, `{}`) + assert.Nil(t, result) + require.Error(t, err) + assert.Contains(t, err.Error(), "nil decision") +} + +func TestPermissionGate_BeforeToolCallError(t *testing.T) { + m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + return nil, fmt.Errorf("rule store unreachable") + }) + + tCtx := &adk.ToolContext{Name: "SomeTool", CallID: "call_6"} + result, err := m.permissionGate(context.Background(), tCtx, `{}`) + assert.Nil(t, result) + require.Error(t, err) + assert.Contains(t, err.Error(), "permission check failed") + assert.Contains(t, err.Error(), "rule store unreachable") +} + +func TestWrapInvokableToolCall_Allow(t *testing.T) { + endpointCalled := false + m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Allow}, nil + }) + + originalEndpoint := adk.InvokableToolCallEndpoint(func(ctx context.Context, args string, opts ...tool.Option) (string, error) { + endpointCalled = true + return "tool result: " + args, nil + }) + + tCtx := &adk.ToolContext{Name: "MyTool", CallID: "call_1"} + wrapped, err := m.WrapInvokableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + result, err := wrapped(context.Background(), `{"key":"value"}`) + require.NoError(t, err) + assert.True(t, endpointCalled) + assert.Equal(t, `tool result: {"key":"value"}`, result) +} + +func TestWrapInvokableToolCall_AllowWithUpdatedInput(t *testing.T) { + m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Allow, UpdatedInput: `{"sanitized":true}`}, nil + }) + + var receivedArgs string + originalEndpoint := adk.InvokableToolCallEndpoint(func(ctx context.Context, args string, opts ...tool.Option) (string, error) { + receivedArgs = args + return "ok", nil + }) + + tCtx := &adk.ToolContext{Name: "MyTool", CallID: "call_1"} + wrapped, err := m.WrapInvokableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + result, err := wrapped(context.Background(), `{"original":true}`) + require.NoError(t, err) + assert.Equal(t, `{"sanitized":true}`, receivedArgs) + assert.Equal(t, "ok", result) +} + +func TestWrapInvokableToolCall_Deny(t *testing.T) { + endpointCalled := false + m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Deny, Message: "blocked"}, nil + }) + + originalEndpoint := adk.InvokableToolCallEndpoint(func(ctx context.Context, args string, opts ...tool.Option) (string, error) { + endpointCalled = true + return "should not reach", nil + }) + + tCtx := &adk.ToolContext{Name: "MyTool", CallID: "call_1"} + wrapped, err := m.WrapInvokableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + result, err := wrapped(context.Background(), `{}`) + require.NoError(t, err) + assert.False(t, endpointCalled) + assert.Equal(t, "Permission denied for tool MyTool: blocked", result) +} + +func TestWrapInvokableToolCall_Ask(t *testing.T) { + m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Ask, Message: "need approval"}, nil + }) + + originalEndpoint := adk.InvokableToolCallEndpoint(func(ctx context.Context, args string, opts ...tool.Option) (string, error) { + return "should not reach", nil + }) + + tCtx := &adk.ToolContext{Name: "DangerTool", CallID: "call_1"} + wrapped, err := m.WrapInvokableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + ctx := makeCtxWithAddr() + result, err := wrapped(ctx, `{"danger":true}`) + assert.Equal(t, "", result) + require.Error(t, err) + + var is *core.InterruptSignal + assert.True(t, errors.As(err, &is)) +} + +func TestWrapStreamableToolCall_Allow(t *testing.T) { + m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Allow}, nil + }) + + originalEndpoint := adk.StreamableToolCallEndpoint(func(ctx context.Context, args string, opts ...tool.Option) (*schema.StreamReader[string], error) { + sr, sw := schema.Pipe[string](1) + sw.Send("stream chunk: "+args, nil) + sw.Close() + return sr, nil + }) + + tCtx := &adk.ToolContext{Name: "StreamTool", CallID: "call_1"} + wrapped, err := m.WrapStreamableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + sr, err := wrapped(context.Background(), `{"key":"val"}`) + require.NoError(t, err) + require.NotNil(t, sr) + + chunk, err := sr.Recv() + require.NoError(t, err) + assert.Equal(t, `stream chunk: {"key":"val"}`, chunk) + + _, err = sr.Recv() + assert.Equal(t, io.EOF, err) +} + +func TestWrapStreamableToolCall_Deny(t *testing.T) { + endpointCalled := false + m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Deny, Message: "stream blocked"}, nil + }) + + originalEndpoint := adk.StreamableToolCallEndpoint(func(ctx context.Context, args string, opts ...tool.Option) (*schema.StreamReader[string], error) { + endpointCalled = true + return nil, nil + }) + + tCtx := &adk.ToolContext{Name: "StreamTool", CallID: "call_1"} + wrapped, err := m.WrapStreamableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + sr, err := wrapped(context.Background(), `{}`) + require.NoError(t, err) + assert.False(t, endpointCalled) + require.NotNil(t, sr) + + chunk, err := sr.Recv() + require.NoError(t, err) + assert.Equal(t, "Permission denied for tool StreamTool: stream blocked", chunk) + + _, err = sr.Recv() + assert.Equal(t, io.EOF, err) +} + +func TestWrapStreamableToolCall_Ask(t *testing.T) { + m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Ask, Message: "stream needs approval"}, nil + }) + + originalEndpoint := adk.StreamableToolCallEndpoint(func(ctx context.Context, args string, opts ...tool.Option) (*schema.StreamReader[string], error) { + return nil, nil + }) + + tCtx := &adk.ToolContext{Name: "StreamTool", CallID: "call_1"} + wrapped, err := m.WrapStreamableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + ctx := makeCtxWithAddr() + sr, err := wrapped(ctx, `{}`) + assert.Nil(t, sr) + require.Error(t, err) + + var is *core.InterruptSignal + assert.True(t, errors.As(err, &is)) +} + +func TestWrapInvokableToolCall_BeforeToolCallError(t *testing.T) { + m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + return nil, fmt.Errorf("infra failure") + }) + + originalEndpoint := adk.InvokableToolCallEndpoint(func(ctx context.Context, args string, opts ...tool.Option) (string, error) { + return "should not reach", nil + }) + + tCtx := &adk.ToolContext{Name: "MyTool", CallID: "call_1"} + wrapped, err := m.WrapInvokableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + result, err := wrapped(context.Background(), `{}`) + assert.Equal(t, "", result) + require.Error(t, err) + assert.Contains(t, err.Error(), "permission check failed") +} + +func TestMiddleware_ImplementsInterface(t *testing.T) { + m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Allow}, nil + }) + var _ adk.ChatModelAgentMiddleware = m +} + +func TestDecisionConstants(t *testing.T) { + assert.Equal(t, Decision("allow"), Allow) + assert.Equal(t, Decision("deny"), Deny) + assert.Equal(t, Decision("ask"), Ask) +} + +func TestAskInfoFields(t *testing.T) { + info := &AskInfo{ + ToolName: "MyTool", + CallID: "call_1", + Arguments: `{"key":"value"}`, + Message: "requires approval", + } + assert.Equal(t, "MyTool", info.ToolName) + assert.Equal(t, "call_1", info.CallID) + assert.Equal(t, `{"key":"value"}`, info.Arguments) + assert.Equal(t, "requires approval", info.Message) +} + +func TestResumeResponse_Approved(t *testing.T) { + resp := &ResumeResponse{ + Approved: true, + UpdatedInput: `{"modified":true}`, + } + assert.True(t, resp.Approved) + assert.Equal(t, `{"modified":true}`, resp.UpdatedInput) +} + +func TestResumeResponse_Denied(t *testing.T) { + resp := &ResumeResponse{ + Approved: false, + DenyMessage: "user rejected", + } + assert.False(t, resp.Approved) + assert.Equal(t, "user rejected", resp.DenyMessage) +} From 1f7a29ecc406f251793c8a17e2eca05dcf8c8e7b Mon Sep 17 00:00:00 2001 From: "shentong.martin" Date: Tue, 14 Apr 2026 21:45:20 +0800 Subject: [PATCH 56/59] test(middlewares): add attack tests for permission middleware Change-Id: I8b733829302674ff05b1d775f264a5f6f0f9657a --- adk/middlewares/permission/permission_test.go | 164 ++++++++++++++++++ 1 file changed, 164 insertions(+) diff --git a/adk/middlewares/permission/permission_test.go b/adk/middlewares/permission/permission_test.go index 4a95a304a..784e4af9a 100644 --- a/adk/middlewares/permission/permission_test.go +++ b/adk/middlewares/permission/permission_test.go @@ -17,10 +17,13 @@ package permission import ( + "bytes" "context" + "encoding/gob" "errors" "fmt" "io" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -380,3 +383,164 @@ func TestResumeResponse_Denied(t *testing.T) { assert.False(t, resp.Approved) assert.Equal(t, "user rejected", resp.DenyMessage) } + +func TestAttack_NilBeforeToolCall(t *testing.T) { + m := NewMiddleware(nil) + require.NotNil(t, m) + + tCtx := &adk.ToolContext{Name: "SomeTool", CallID: "call_nil"} + assert.Panics(t, func() { + _, _ = m.permissionGate(context.Background(), tCtx, `{}`) + }) +} + +func TestAttack_EmptyDenyMessage(t *testing.T) { + result := formatDenyResult("WriteTool", "") + assert.Equal(t, "Permission denied for tool WriteTool: ", result) + + m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Deny, Message: ""}, nil + }) + + tCtx := &adk.ToolContext{Name: "WriteTool", CallID: "call_empty_deny"} + gr, err := m.permissionGate(context.Background(), tCtx, `{}`) + require.NoError(t, err) + assert.False(t, gr.allowed) + assert.Equal(t, "Permission denied for tool WriteTool: ", gr.denyResult) +} + +func TestAttack_DenyWithEmptyToolName(t *testing.T) { + m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + assert.Equal(t, "", toolName) + return &ToolCallDecision{Decision: Deny, Message: "no name"}, nil + }) + + tCtx := &adk.ToolContext{Name: "", CallID: "call_empty_name"} + gr, err := m.permissionGate(context.Background(), tCtx, `{"x":1}`) + require.NoError(t, err) + assert.False(t, gr.allowed) + assert.Equal(t, "Permission denied for tool : no name", gr.denyResult) +} + +func TestAttack_AllowUpdatedInputEmpty(t *testing.T) { + m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Allow, UpdatedInput: ""}, nil + }) + + originalArgs := `{"important":"data"}` + tCtx := &adk.ToolContext{Name: "MyTool", CallID: "call_empty_update"} + gr, err := m.permissionGate(context.Background(), tCtx, originalArgs) + require.NoError(t, err) + assert.True(t, gr.allowed) + assert.Equal(t, originalArgs, gr.updatedInput) + + var receivedArgs string + endpoint := adk.InvokableToolCallEndpoint(func(ctx context.Context, args string, opts ...tool.Option) (string, error) { + receivedArgs = args + return "ok", nil + }) + + wrapped, err := m.WrapInvokableToolCall(context.Background(), endpoint, tCtx) + require.NoError(t, err) + + result, err := wrapped(context.Background(), originalArgs) + require.NoError(t, err) + assert.Equal(t, "ok", result) + assert.Equal(t, originalArgs, receivedArgs) +} + +func TestAttack_AskInfoGobSerializable(t *testing.T) { + info := &AskInfo{ + ToolName: "DangerTool", + CallID: "call_gob", + Arguments: `{"rm":"-rf /"}`, + Message: "are you sure?", + } + state := &AskState{Info: info} + + var buf bytes.Buffer + require.NoError(t, gob.NewEncoder(&buf).Encode(info)) + + var decodedInfo AskInfo + require.NoError(t, gob.NewDecoder(&buf).Decode(&decodedInfo)) + assert.Equal(t, info.ToolName, decodedInfo.ToolName) + assert.Equal(t, info.CallID, decodedInfo.CallID) + assert.Equal(t, info.Arguments, decodedInfo.Arguments) + assert.Equal(t, info.Message, decodedInfo.Message) + + buf.Reset() + require.NoError(t, gob.NewEncoder(&buf).Encode(state)) + + var decodedState AskState + require.NoError(t, gob.NewDecoder(&buf).Decode(&decodedState)) + require.NotNil(t, decodedState.Info) + assert.Equal(t, info.ToolName, decodedState.Info.ToolName) + assert.Equal(t, info.CallID, decodedState.Info.CallID) + assert.Equal(t, info.Arguments, decodedState.Info.Arguments) + assert.Equal(t, info.Message, decodedState.Info.Message) +} + +func TestAttack_ResumeResponseEmptyUpdatedInput(t *testing.T) { + m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Ask, Message: "confirm?"}, nil + }) + + originalArgs := `{"critical":"payload"}` + tCtx := &adk.ToolContext{Name: "CriticalTool", CallID: "call_resume_empty"} + + ctx := makeCtxWithAddr() + gr, err := m.permissionGate(ctx, tCtx, originalArgs) + assert.Nil(t, gr) + require.Error(t, err) + var is *core.InterruptSignal + require.True(t, errors.As(err, &is)) + + resp := &ResumeResponse{Approved: true, UpdatedInput: ""} + assert.True(t, resp.Approved) + assert.Equal(t, "", resp.UpdatedInput) +} + +func TestAttack_ConcurrentBeforeToolCall(t *testing.T) { + m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Allow}, nil + }) + + var receivedMu sync.Mutex + received := make(map[string]string) + + endpoint := adk.InvokableToolCallEndpoint(func(ctx context.Context, args string, opts ...tool.Option) (string, error) { + receivedMu.Lock() + received[args] = "done" + receivedMu.Unlock() + return "result:" + args, nil + }) + + tCtx := &adk.ToolContext{Name: "ConcurrentTool", CallID: "call_concurrent"} + wrapped, err := m.WrapInvokableToolCall(context.Background(), endpoint, tCtx) + require.NoError(t, err) + + const goroutines = 50 + var wg sync.WaitGroup + wg.Add(goroutines) + errs := make([]error, goroutines) + results := make([]string, goroutines) + + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + args := fmt.Sprintf(`{"id":%d}`, idx) + results[idx], errs[idx] = wrapped(context.Background(), args) + }(i) + } + wg.Wait() + + for i := 0; i < goroutines; i++ { + assert.NoError(t, errs[i], "goroutine %d returned error", i) + expected := fmt.Sprintf(`result:{"id":%d}`, i) + assert.Equal(t, expected, results[i], "goroutine %d result mismatch", i) + } + + receivedMu.Lock() + assert.Len(t, received, goroutines) + receivedMu.Unlock() +} From 0054beb66e37a4bec835c5ccdeb022dadf256e1b Mon Sep 17 00:00:00 2001 From: "shentong.martin" Date: Wed, 15 Apr 2026 10:32:21 +0800 Subject: [PATCH 57/59] refactor(middlewares): address PR review comments for permission middleware MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename BeforeToolCall type to Checker; pass *adk.ToolContext instead of bare toolName (Thread #1) - Rename NewMiddleware to New to match codebase convention (Thread #4) - Improve error messages with tool name, call ID, and args for LLM consumption (Thread #5) - Replace schema.Pipe with StreamReaderFromArray for stream deny path (Thread #6) - Add 4 E2E ask→resume tests: approved, denied, re-interrupt non-target, resume with updated input (Thread #7) Change-Id: If43ac495fa3a5fc1d71a27db9614cb996d4547b5 --- adk/middlewares/permission/permission.go | 62 +++--- adk/middlewares/permission/permission_test.go | 184 +++++++++++++++--- 2 files changed, 181 insertions(+), 65 deletions(-) diff --git a/adk/middlewares/permission/permission.go b/adk/middlewares/permission/permission.go index 05766926f..6777bf893 100644 --- a/adk/middlewares/permission/permission.go +++ b/adk/middlewares/permission/permission.go @@ -15,7 +15,7 @@ */ // Package permission provides a ChatModelAgentMiddleware that gates tool execution -// behind a user-defined permission check (BeforeToolCall). It supports three decisions: +// behind a user-defined permission check (Checker). It supports three decisions: // Allow (execute the tool), Deny (return a deny message as tool result), and Ask // (interrupt the agent loop via StatefulInterrupt for external approval). package permission @@ -34,19 +34,14 @@ func init() { schema.RegisterName[*AskState]("_eino_adk_permission_ask_state") } -// Decision represents the outcome of a permission check. type Decision string const ( - // Allow permits the tool to execute. Allow Decision = "allow" - // Deny blocks the tool and returns a deny message as the tool result. - Deny Decision = "deny" - // Ask interrupts the agent loop to await external approval. - Ask Decision = "ask" + Deny Decision = "deny" + Ask Decision = "ask" ) -// ToolCallDecision is the result of a BeforeToolCall evaluation. type ToolCallDecision struct { Decision Decision Message string @@ -54,14 +49,14 @@ type ToolCallDecision struct { Reason string } -// BeforeToolCall is the user-provided evaluation function invoked before each tool call. -// It returns a ToolCallDecision that determines whether the call is allowed, denied, or -// requires interactive approval. Returning an error signals an infrastructure failure -// and aborts the agent loop; permission denials should use Decision: Deny instead. -type BeforeToolCall func(ctx context.Context, toolName string, argumentsInJSON string) (*ToolCallDecision, error) +// Checker is the user-provided evaluation function invoked before each tool call. +// It receives the full ToolContext (including tool name and call ID) along with +// the raw JSON arguments, and returns a ToolCallDecision that determines whether +// the call is allowed, denied, or requires interactive approval. +// Returning an error signals an infrastructure failure and aborts the agent loop; +// permission denials should use Decision: Deny instead. +type Checker func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) -// AskInfo is the interrupt info exposed to external consumers (UI / OnAgentEvents). -// All fields are basic types to satisfy gob serialization requirements. type AskInfo struct { ToolName string CallID string @@ -69,30 +64,26 @@ type AskInfo struct { Message string } -// AskState is the interrupt state persisted via CheckPointStore (gob serialization). type AskState struct { Info *AskInfo } -// ResumeResponse is the decision injected externally via ResumeWithParams. type ResumeResponse struct { Approved bool UpdatedInput string DenyMessage string } -// Middleware is a ChatModelAgentMiddleware that gates tool execution behind -// a user-defined BeforeToolCall permission check. type Middleware struct { *adk.BaseChatModelAgentMiddleware - beforeToolCall BeforeToolCall + checker Checker } -// NewMiddleware creates a new permission Middleware with the given BeforeToolCall evaluator. -func NewMiddleware(beforeToolCall BeforeToolCall) *Middleware { +// New creates a permission Middleware with the given Checker evaluator. +func New(checker Checker) *Middleware { return &Middleware{ BaseChatModelAgentMiddleware: &adk.BaseChatModelAgentMiddleware{}, - beforeToolCall: beforeToolCall, + checker: checker, } } @@ -126,15 +117,23 @@ func (m *Middleware) permissionGate( } if isTarget && !hasData { - return nil, fmt.Errorf("permission: tool %s was targeted for resume but received nil or type-mismatched ResumeResponse", tCtx.Name) + return nil, fmt.Errorf( + "permission: tool %q (call_id=%s) was targeted for resume but received nil "+ + "or type-mismatched ResumeResponse; the caller must supply a *permission.ResumeResponse "+ + "via ResumeWithParams", tCtx.Name, tCtx.CallID) } - decision, err := m.beforeToolCall(ctx, tCtx.Name, argumentsInJSON) + decision, err := m.checker(ctx, tCtx, argumentsInJSON) if err != nil { - return nil, fmt.Errorf("permission check failed: %w", err) + return nil, fmt.Errorf( + "permission: checker error for tool %q (call_id=%s, args=%s): %w", + tCtx.Name, tCtx.CallID, argumentsInJSON, err) } if decision == nil { - return nil, fmt.Errorf("permission: BeforeToolCall returned nil decision for tool %s", tCtx.Name) + return nil, fmt.Errorf( + "permission: checker returned nil ToolCallDecision for tool %q (call_id=%s); "+ + "return a valid *ToolCallDecision with Decision set to Allow, Deny, or Ask", + tCtx.Name, tCtx.CallID) } switch decision.Decision { @@ -160,11 +159,10 @@ func (m *Middleware) permissionGate( default: return &gateResult{denyResult: formatDenyResult(tCtx.Name, - fmt.Sprintf("unknown permission decision %q", decision.Decision))}, nil + fmt.Sprintf("unknown permission decision %q; expected allow, deny, or ask", decision.Decision))}, nil } } -// WrapInvokableToolCall intercepts synchronous tool calls with a permission gate. func (m *Middleware) WrapInvokableToolCall( ctx context.Context, endpoint adk.InvokableToolCallEndpoint, @@ -182,7 +180,6 @@ func (m *Middleware) WrapInvokableToolCall( }, nil } -// WrapStreamableToolCall intercepts streaming tool calls with a permission gate. func (m *Middleware) WrapStreamableToolCall( ctx context.Context, endpoint adk.StreamableToolCallEndpoint, @@ -194,10 +191,7 @@ func (m *Middleware) WrapStreamableToolCall( return nil, err } if !result.allowed { - sr, sw := schema.Pipe[string](1) - sw.Send(result.denyResult, nil) - sw.Close() - return sr, nil + return schema.StreamReaderFromArray([]string{result.denyResult}), nil } return endpoint(ctx, result.updatedInput, opts...) }, nil diff --git a/adk/middlewares/permission/permission_test.go b/adk/middlewares/permission/permission_test.go index 784e4af9a..b713cae1b 100644 --- a/adk/middlewares/permission/permission_test.go +++ b/adk/middlewares/permission/permission_test.go @@ -35,14 +35,14 @@ import ( "github.com/cloudwego/eino/schema" ) -func TestNewMiddleware(t *testing.T) { +func TestNew(t *testing.T) { called := false - m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { called = true return &ToolCallDecision{Decision: Allow}, nil }) assert.NotNil(t, m) - assert.NotNil(t, m.beforeToolCall) + assert.NotNil(t, m.checker) assert.NotNil(t, m.BaseChatModelAgentMiddleware) assert.False(t, called) } @@ -58,9 +58,9 @@ func makeCtxWithAddr() context.Context { } func TestPermissionGate_Allow(t *testing.T) { - m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { - assert.Equal(t, "ReadFile", toolName) - assert.Equal(t, `{"path":"/tmp/x"}`, args) + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + assert.Equal(t, "ReadFile", tCtx.Name) + assert.Equal(t, `{"path":"/tmp/x"}`, argumentsInJSON) return &ToolCallDecision{Decision: Allow, Reason: "read-only"}, nil }) @@ -72,7 +72,7 @@ func TestPermissionGate_Allow(t *testing.T) { } func TestPermissionGate_AllowWithUpdatedInput(t *testing.T) { - m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { return &ToolCallDecision{ Decision: Allow, UpdatedInput: `{"path":"/tmp/safe"}`, @@ -87,7 +87,7 @@ func TestPermissionGate_AllowWithUpdatedInput(t *testing.T) { } func TestPermissionGate_Deny(t *testing.T) { - m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { return &ToolCallDecision{ Decision: Deny, Message: "operation not allowed", @@ -103,7 +103,7 @@ func TestPermissionGate_Deny(t *testing.T) { } func TestPermissionGate_Ask(t *testing.T) { - m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { return &ToolCallDecision{ Decision: Ask, Message: "requires approval", @@ -121,7 +121,7 @@ func TestPermissionGate_Ask(t *testing.T) { } func TestPermissionGate_UnknownDecision(t *testing.T) { - m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: "maybe"}, nil }) @@ -134,7 +134,7 @@ func TestPermissionGate_UnknownDecision(t *testing.T) { } func TestPermissionGate_NilDecision(t *testing.T) { - m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { return nil, nil }) @@ -142,11 +142,11 @@ func TestPermissionGate_NilDecision(t *testing.T) { result, err := m.permissionGate(context.Background(), tCtx, `{}`) assert.Nil(t, result) require.Error(t, err) - assert.Contains(t, err.Error(), "nil decision") + assert.Contains(t, err.Error(), "nil ToolCallDecision") } func TestPermissionGate_BeforeToolCallError(t *testing.T) { - m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { return nil, fmt.Errorf("rule store unreachable") }) @@ -154,13 +154,13 @@ func TestPermissionGate_BeforeToolCallError(t *testing.T) { result, err := m.permissionGate(context.Background(), tCtx, `{}`) assert.Nil(t, result) require.Error(t, err) - assert.Contains(t, err.Error(), "permission check failed") + assert.Contains(t, err.Error(), "permission: checker error") assert.Contains(t, err.Error(), "rule store unreachable") } func TestWrapInvokableToolCall_Allow(t *testing.T) { endpointCalled := false - m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Allow}, nil }) @@ -180,7 +180,7 @@ func TestWrapInvokableToolCall_Allow(t *testing.T) { } func TestWrapInvokableToolCall_AllowWithUpdatedInput(t *testing.T) { - m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Allow, UpdatedInput: `{"sanitized":true}`}, nil }) @@ -202,7 +202,7 @@ func TestWrapInvokableToolCall_AllowWithUpdatedInput(t *testing.T) { func TestWrapInvokableToolCall_Deny(t *testing.T) { endpointCalled := false - m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Deny, Message: "blocked"}, nil }) @@ -222,7 +222,7 @@ func TestWrapInvokableToolCall_Deny(t *testing.T) { } func TestWrapInvokableToolCall_Ask(t *testing.T) { - m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Ask, Message: "need approval"}, nil }) @@ -244,7 +244,7 @@ func TestWrapInvokableToolCall_Ask(t *testing.T) { } func TestWrapStreamableToolCall_Allow(t *testing.T) { - m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Allow}, nil }) @@ -273,7 +273,7 @@ func TestWrapStreamableToolCall_Allow(t *testing.T) { func TestWrapStreamableToolCall_Deny(t *testing.T) { endpointCalled := false - m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Deny, Message: "stream blocked"}, nil }) @@ -300,7 +300,7 @@ func TestWrapStreamableToolCall_Deny(t *testing.T) { } func TestWrapStreamableToolCall_Ask(t *testing.T) { - m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Ask, Message: "stream needs approval"}, nil }) @@ -322,7 +322,7 @@ func TestWrapStreamableToolCall_Ask(t *testing.T) { } func TestWrapInvokableToolCall_BeforeToolCallError(t *testing.T) { - m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { return nil, fmt.Errorf("infra failure") }) @@ -337,11 +337,11 @@ func TestWrapInvokableToolCall_BeforeToolCallError(t *testing.T) { result, err := wrapped(context.Background(), `{}`) assert.Equal(t, "", result) require.Error(t, err) - assert.Contains(t, err.Error(), "permission check failed") + assert.Contains(t, err.Error(), "permission: checker error") } func TestMiddleware_ImplementsInterface(t *testing.T) { - m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Allow}, nil }) var _ adk.ChatModelAgentMiddleware = m @@ -385,7 +385,7 @@ func TestResumeResponse_Denied(t *testing.T) { } func TestAttack_NilBeforeToolCall(t *testing.T) { - m := NewMiddleware(nil) + m := New(nil) require.NotNil(t, m) tCtx := &adk.ToolContext{Name: "SomeTool", CallID: "call_nil"} @@ -398,7 +398,7 @@ func TestAttack_EmptyDenyMessage(t *testing.T) { result := formatDenyResult("WriteTool", "") assert.Equal(t, "Permission denied for tool WriteTool: ", result) - m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Deny, Message: ""}, nil }) @@ -410,8 +410,8 @@ func TestAttack_EmptyDenyMessage(t *testing.T) { } func TestAttack_DenyWithEmptyToolName(t *testing.T) { - m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { - assert.Equal(t, "", toolName) + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + assert.Equal(t, "", tCtx.Name) return &ToolCallDecision{Decision: Deny, Message: "no name"}, nil }) @@ -423,7 +423,7 @@ func TestAttack_DenyWithEmptyToolName(t *testing.T) { } func TestAttack_AllowUpdatedInputEmpty(t *testing.T) { - m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Allow, UpdatedInput: ""}, nil }) @@ -481,7 +481,7 @@ func TestAttack_AskInfoGobSerializable(t *testing.T) { } func TestAttack_ResumeResponseEmptyUpdatedInput(t *testing.T) { - m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Ask, Message: "confirm?"}, nil }) @@ -501,7 +501,7 @@ func TestAttack_ResumeResponseEmptyUpdatedInput(t *testing.T) { } func TestAttack_ConcurrentBeforeToolCall(t *testing.T) { - m := NewMiddleware(func(ctx context.Context, toolName, args string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Allow}, nil }) @@ -544,3 +544,125 @@ func TestAttack_ConcurrentBeforeToolCall(t *testing.T) { assert.Len(t, received, goroutines) receivedMu.Unlock() } + +func buildResumeCtx( + t *testing.T, + signal *core.InterruptSignal, + resumeData map[string]any, +) context.Context { + t.Helper() + id2Addr, id2State := core.SignalToPersistenceMaps(signal) + ctx := context.Background() + ctx = core.PopulateInterruptState(ctx, id2Addr, id2State) + ctx = core.BatchResumeWithData(ctx, resumeData) + ctx = core.AppendAddressSegment(ctx, "agent", "test-agent", "") + return ctx +} + +func TestE2E_AskThenResumeApproved(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Ask, Message: "approve rm?"}, nil + }) + + tCtx := &adk.ToolContext{Name: "ShellExec", CallID: "call_e2e_1"} + originalArgs := `{"cmd":"rm -rf /"}` + + ctx := makeCtxWithAddr() + result, err := m.permissionGate(ctx, tCtx, originalArgs) + assert.Nil(t, result) + require.Error(t, err) + + var signal *core.InterruptSignal + require.True(t, errors.As(err, &signal)) + assert.Equal(t, originalArgs, signal.InterruptState.State.(*AskState).Info.Arguments) + + resumeCtx := buildResumeCtx(t, signal, map[string]any{ + signal.ID: &ResumeResponse{Approved: true}, + }) + + result, err = m.permissionGate(resumeCtx, tCtx, originalArgs) + require.NoError(t, err) + assert.True(t, result.allowed) + assert.Equal(t, originalArgs, result.updatedInput) +} + +func TestE2E_AskThenResumeDenied(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Ask, Message: "dangerous"}, nil + }) + + tCtx := &adk.ToolContext{Name: "DeleteDB", CallID: "call_e2e_deny"} + ctx := makeCtxWithAddr() + + _, err := m.permissionGate(ctx, tCtx, `{}`) + require.Error(t, err) + + var signal *core.InterruptSignal + require.True(t, errors.As(err, &signal)) + + resumeCtx := buildResumeCtx(t, signal, map[string]any{ + signal.ID: &ResumeResponse{Approved: false, DenyMessage: "user said no"}, + }) + + result, err := m.permissionGate(resumeCtx, tCtx, `{}`) + require.NoError(t, err) + assert.False(t, result.allowed) + assert.Contains(t, result.denyResult, "user said no") +} + +func TestE2E_ReInterruptNonTargetTool(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Ask, Message: "confirm"}, nil + }) + + tCtx := &adk.ToolContext{Name: "ToolA", CallID: "call_a"} + ctx := makeCtxWithAddr() + + _, err := m.permissionGate(ctx, tCtx, `{}`) + require.Error(t, err) + + var signal *core.InterruptSignal + require.True(t, errors.As(err, &signal)) + + id2Addr, id2State := core.SignalToPersistenceMaps(signal) + resumeCtx := context.Background() + resumeCtx = core.PopulateInterruptState(resumeCtx, id2Addr, id2State) + resumeCtx = core.BatchResumeWithData(resumeCtx, map[string]any{ + "some_other_id": &ResumeResponse{Approved: true}, + }) + resumeCtx = core.AppendAddressSegment(resumeCtx, "agent", "test-agent", "") + + result, err := m.permissionGate(resumeCtx, tCtx, `{}`) + assert.Nil(t, result) + require.Error(t, err) + + var reSignal *core.InterruptSignal + require.True(t, errors.As(err, &reSignal)) + assert.Equal(t, "ToolA", reSignal.InterruptState.State.(*AskState).Info.ToolName) +} + +func TestE2E_ResumeWithUpdatedInput(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Ask, Message: "sanitize?"}, nil + }) + + tCtx := &adk.ToolContext{Name: "WriteFile", CallID: "call_e2e_update"} + originalArgs := `{"path":"/etc/passwd","content":"hacked"}` + + ctx := makeCtxWithAddr() + _, err := m.permissionGate(ctx, tCtx, originalArgs) + require.Error(t, err) + + var signal *core.InterruptSignal + require.True(t, errors.As(err, &signal)) + + sanitizedArgs := `{"path":"/tmp/safe.txt","content":"ok"}` + resumeCtx := buildResumeCtx(t, signal, map[string]any{ + signal.ID: &ResumeResponse{Approved: true, UpdatedInput: sanitizedArgs}, + }) + + result, err := m.permissionGate(resumeCtx, tCtx, originalArgs) + require.NoError(t, err) + assert.True(t, result.allowed) + assert.Equal(t, sanitizedArgs, result.updatedInput) +} From 8b1082fec71db45dfcc364a55f5b20682a2357be Mon Sep 17 00:00:00 2001 From: "shentong.martin" Date: Wed, 15 Apr 2026 10:43:12 +0800 Subject: [PATCH 58/59] refactor(middlewares): change Checker param from string to *schema.ToolArgument Forward-compatibility: when ToolArgument gains additional fields (e.g. multimodal content), existing Checker implementations will receive them without signature changes. The middleware constructs the ToolArgument from the raw argumentsInJSON before invoking the checker, and extracts .Text back for the endpoint. UpdatedInput stays string since the endpoint signature remains string-based. Change-Id: I19b8c808415e01969330466c034becbd74e25994 --- adk/middlewares/permission/permission.go | 14 ++--- adk/middlewares/permission/permission_test.go | 54 +++++++++---------- 2 files changed, 35 insertions(+), 33 deletions(-) diff --git a/adk/middlewares/permission/permission.go b/adk/middlewares/permission/permission.go index 6777bf893..b416d16c6 100644 --- a/adk/middlewares/permission/permission.go +++ b/adk/middlewares/permission/permission.go @@ -51,11 +51,13 @@ type ToolCallDecision struct { // Checker is the user-provided evaluation function invoked before each tool call. // It receives the full ToolContext (including tool name and call ID) along with -// the raw JSON arguments, and returns a ToolCallDecision that determines whether -// the call is allowed, denied, or requires interactive approval. -// Returning an error signals an infrastructure failure and aborts the agent loop; -// permission denials should use Decision: Deny instead. -type Checker func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) +// the tool arguments as a *schema.ToolArgument, and returns a ToolCallDecision +// that determines whether the call is allowed, denied, or requires interactive +// approval. Using *schema.ToolArgument instead of a raw string ensures +// forward-compatibility when the struct gains additional fields (e.g. multimodal +// content). Returning an error signals an infrastructure failure and aborts the +// agent loop; permission denials should use Decision: Deny instead. +type Checker func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) type AskInfo struct { ToolName string @@ -123,7 +125,7 @@ func (m *Middleware) permissionGate( "via ResumeWithParams", tCtx.Name, tCtx.CallID) } - decision, err := m.checker(ctx, tCtx, argumentsInJSON) + decision, err := m.checker(ctx, tCtx, &schema.ToolArgument{Text: argumentsInJSON}) if err != nil { return nil, fmt.Errorf( "permission: checker error for tool %q (call_id=%s, args=%s): %w", diff --git a/adk/middlewares/permission/permission_test.go b/adk/middlewares/permission/permission_test.go index b713cae1b..0b18ce18e 100644 --- a/adk/middlewares/permission/permission_test.go +++ b/adk/middlewares/permission/permission_test.go @@ -37,7 +37,7 @@ import ( func TestNew(t *testing.T) { called := false - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { called = true return &ToolCallDecision{Decision: Allow}, nil }) @@ -58,9 +58,9 @@ func makeCtxWithAddr() context.Context { } func TestPermissionGate_Allow(t *testing.T) { - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { assert.Equal(t, "ReadFile", tCtx.Name) - assert.Equal(t, `{"path":"/tmp/x"}`, argumentsInJSON) + assert.Equal(t, `{"path":"/tmp/x"}`, args.Text) return &ToolCallDecision{Decision: Allow, Reason: "read-only"}, nil }) @@ -72,7 +72,7 @@ func TestPermissionGate_Allow(t *testing.T) { } func TestPermissionGate_AllowWithUpdatedInput(t *testing.T) { - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { return &ToolCallDecision{ Decision: Allow, UpdatedInput: `{"path":"/tmp/safe"}`, @@ -87,7 +87,7 @@ func TestPermissionGate_AllowWithUpdatedInput(t *testing.T) { } func TestPermissionGate_Deny(t *testing.T) { - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { return &ToolCallDecision{ Decision: Deny, Message: "operation not allowed", @@ -103,7 +103,7 @@ func TestPermissionGate_Deny(t *testing.T) { } func TestPermissionGate_Ask(t *testing.T) { - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { return &ToolCallDecision{ Decision: Ask, Message: "requires approval", @@ -121,7 +121,7 @@ func TestPermissionGate_Ask(t *testing.T) { } func TestPermissionGate_UnknownDecision(t *testing.T) { - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: "maybe"}, nil }) @@ -134,7 +134,7 @@ func TestPermissionGate_UnknownDecision(t *testing.T) { } func TestPermissionGate_NilDecision(t *testing.T) { - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { return nil, nil }) @@ -146,7 +146,7 @@ func TestPermissionGate_NilDecision(t *testing.T) { } func TestPermissionGate_BeforeToolCallError(t *testing.T) { - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { return nil, fmt.Errorf("rule store unreachable") }) @@ -160,7 +160,7 @@ func TestPermissionGate_BeforeToolCallError(t *testing.T) { func TestWrapInvokableToolCall_Allow(t *testing.T) { endpointCalled := false - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Allow}, nil }) @@ -180,7 +180,7 @@ func TestWrapInvokableToolCall_Allow(t *testing.T) { } func TestWrapInvokableToolCall_AllowWithUpdatedInput(t *testing.T) { - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Allow, UpdatedInput: `{"sanitized":true}`}, nil }) @@ -202,7 +202,7 @@ func TestWrapInvokableToolCall_AllowWithUpdatedInput(t *testing.T) { func TestWrapInvokableToolCall_Deny(t *testing.T) { endpointCalled := false - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Deny, Message: "blocked"}, nil }) @@ -222,7 +222,7 @@ func TestWrapInvokableToolCall_Deny(t *testing.T) { } func TestWrapInvokableToolCall_Ask(t *testing.T) { - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Ask, Message: "need approval"}, nil }) @@ -244,7 +244,7 @@ func TestWrapInvokableToolCall_Ask(t *testing.T) { } func TestWrapStreamableToolCall_Allow(t *testing.T) { - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Allow}, nil }) @@ -273,7 +273,7 @@ func TestWrapStreamableToolCall_Allow(t *testing.T) { func TestWrapStreamableToolCall_Deny(t *testing.T) { endpointCalled := false - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Deny, Message: "stream blocked"}, nil }) @@ -300,7 +300,7 @@ func TestWrapStreamableToolCall_Deny(t *testing.T) { } func TestWrapStreamableToolCall_Ask(t *testing.T) { - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Ask, Message: "stream needs approval"}, nil }) @@ -322,7 +322,7 @@ func TestWrapStreamableToolCall_Ask(t *testing.T) { } func TestWrapInvokableToolCall_BeforeToolCallError(t *testing.T) { - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { return nil, fmt.Errorf("infra failure") }) @@ -341,7 +341,7 @@ func TestWrapInvokableToolCall_BeforeToolCallError(t *testing.T) { } func TestMiddleware_ImplementsInterface(t *testing.T) { - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Allow}, nil }) var _ adk.ChatModelAgentMiddleware = m @@ -398,7 +398,7 @@ func TestAttack_EmptyDenyMessage(t *testing.T) { result := formatDenyResult("WriteTool", "") assert.Equal(t, "Permission denied for tool WriteTool: ", result) - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Deny, Message: ""}, nil }) @@ -410,7 +410,7 @@ func TestAttack_EmptyDenyMessage(t *testing.T) { } func TestAttack_DenyWithEmptyToolName(t *testing.T) { - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { assert.Equal(t, "", tCtx.Name) return &ToolCallDecision{Decision: Deny, Message: "no name"}, nil }) @@ -423,7 +423,7 @@ func TestAttack_DenyWithEmptyToolName(t *testing.T) { } func TestAttack_AllowUpdatedInputEmpty(t *testing.T) { - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Allow, UpdatedInput: ""}, nil }) @@ -481,7 +481,7 @@ func TestAttack_AskInfoGobSerializable(t *testing.T) { } func TestAttack_ResumeResponseEmptyUpdatedInput(t *testing.T) { - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Ask, Message: "confirm?"}, nil }) @@ -501,7 +501,7 @@ func TestAttack_ResumeResponseEmptyUpdatedInput(t *testing.T) { } func TestAttack_ConcurrentBeforeToolCall(t *testing.T) { - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Allow}, nil }) @@ -560,7 +560,7 @@ func buildResumeCtx( } func TestE2E_AskThenResumeApproved(t *testing.T) { - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Ask, Message: "approve rm?"}, nil }) @@ -587,7 +587,7 @@ func TestE2E_AskThenResumeApproved(t *testing.T) { } func TestE2E_AskThenResumeDenied(t *testing.T) { - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Ask, Message: "dangerous"}, nil }) @@ -611,7 +611,7 @@ func TestE2E_AskThenResumeDenied(t *testing.T) { } func TestE2E_ReInterruptNonTargetTool(t *testing.T) { - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Ask, Message: "confirm"}, nil }) @@ -642,7 +642,7 @@ func TestE2E_ReInterruptNonTargetTool(t *testing.T) { } func TestE2E_ResumeWithUpdatedInput(t *testing.T) { - m := New(func(ctx context.Context, tCtx *adk.ToolContext, argumentsInJSON string) (*ToolCallDecision, error) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { return &ToolCallDecision{Decision: Ask, Message: "sanitize?"}, nil }) From 1d491542ba0f26684a1a58c1822d8367a0610a0f Mon Sep 17 00:00:00 2001 From: "shentong.martin" Date: Wed, 15 Apr 2026 11:28:34 +0800 Subject: [PATCH 59/59] feat(middlewares): wrap enhanced tool endpoints and add i18n deny messages Cover EnhancedInvokableToolCall and EnhancedStreamableToolCall endpoints in the permission gate to prevent bypass. Add denyToolResult helper for structured ToolResult deny responses. Use internal.SelectPrompt for English/Chinese deny message templates. Change-Id: I7c95ee55526bb2109566e5cd6a60cb5242baaf8b --- adk/middlewares/permission/permission.go | 55 ++++- adk/middlewares/permission/permission_test.go | 208 ++++++++++++++++++ 2 files changed, 262 insertions(+), 1 deletion(-) diff --git a/adk/middlewares/permission/permission.go b/adk/middlewares/permission/permission.go index b416d16c6..728606a55 100644 --- a/adk/middlewares/permission/permission.go +++ b/adk/middlewares/permission/permission.go @@ -25,6 +25,7 @@ import ( "fmt" "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/adk/internal" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/schema" ) @@ -199,6 +200,58 @@ func (m *Middleware) WrapStreamableToolCall( }, nil } +func (m *Middleware) WrapEnhancedInvokableToolCall( + ctx context.Context, + endpoint adk.EnhancedInvokableToolCallEndpoint, + tCtx *adk.ToolContext, +) (adk.EnhancedInvokableToolCallEndpoint, error) { + return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { + result, err := m.permissionGate(ctx, tCtx, toolArgument.Text) + if err != nil { + return nil, err + } + if !result.allowed { + return denyToolResult(result.denyResult), nil + } + if result.updatedInput != toolArgument.Text { + toolArgument = &schema.ToolArgument{Text: result.updatedInput} + } + return endpoint(ctx, toolArgument, opts...) + }, nil +} + +func (m *Middleware) WrapEnhancedStreamableToolCall( + ctx context.Context, + endpoint adk.EnhancedStreamableToolCallEndpoint, + tCtx *adk.ToolContext, +) (adk.EnhancedStreamableToolCallEndpoint, error) { + return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { + result, err := m.permissionGate(ctx, tCtx, toolArgument.Text) + if err != nil { + return nil, err + } + if !result.allowed { + return schema.StreamReaderFromArray([]*schema.ToolResult{denyToolResult(result.denyResult)}), nil + } + if result.updatedInput != toolArgument.Text { + toolArgument = &schema.ToolArgument{Text: result.updatedInput} + } + return endpoint(ctx, toolArgument, opts...) + }, nil +} + +func denyToolResult(denyMsg string) *schema.ToolResult { + return &schema.ToolResult{ + Parts: []schema.ToolOutputPart{ + {Type: schema.ToolPartTypeText, Text: denyMsg}, + }, + } +} + func formatDenyResult(toolName, message string) string { - return fmt.Sprintf("Permission denied for tool %s: %s", toolName, message) + tpl := internal.SelectPrompt(internal.I18nPrompts{ + English: "Permission denied for tool %s: %s", + Chinese: "工具 %s 权限被拒绝: %s", + }) + return fmt.Sprintf(tpl, toolName, message) } diff --git a/adk/middlewares/permission/permission_test.go b/adk/middlewares/permission/permission_test.go index 0b18ce18e..f2d97c925 100644 --- a/adk/middlewares/permission/permission_test.go +++ b/adk/middlewares/permission/permission_test.go @@ -666,3 +666,211 @@ func TestE2E_ResumeWithUpdatedInput(t *testing.T) { assert.True(t, result.allowed) assert.Equal(t, sanitizedArgs, result.updatedInput) } + +// --- Enhanced Tool Call Endpoint Tests --- + +func TestWrapEnhancedInvokableToolCall_Allow(t *testing.T) { + endpointCalled := false + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + assert.Equal(t, "EnhancedTool", tCtx.Name) + assert.Equal(t, `{"key":"val"}`, args.Text) + return &ToolCallDecision{Decision: Allow}, nil + }) + + originalEndpoint := adk.EnhancedInvokableToolCallEndpoint( + func(ctx context.Context, arg *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { + endpointCalled = true + return &schema.ToolResult{Parts: []schema.ToolOutputPart{ + {Type: schema.ToolPartTypeText, Text: "enhanced:" + arg.Text}, + }}, nil + }) + + tCtx := &adk.ToolContext{Name: "EnhancedTool", CallID: "call_ei_1"} + wrapped, err := m.WrapEnhancedInvokableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + result, err := wrapped(context.Background(), &schema.ToolArgument{Text: `{"key":"val"}`}) + require.NoError(t, err) + assert.True(t, endpointCalled) + require.Len(t, result.Parts, 1) + assert.Equal(t, `enhanced:{"key":"val"}`, result.Parts[0].Text) +} + +func TestWrapEnhancedInvokableToolCall_AllowWithUpdatedInput(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Allow, UpdatedInput: `{"sanitized":true}`}, nil + }) + + var receivedText string + originalEndpoint := adk.EnhancedInvokableToolCallEndpoint( + func(ctx context.Context, arg *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { + receivedText = arg.Text + return &schema.ToolResult{Parts: []schema.ToolOutputPart{ + {Type: schema.ToolPartTypeText, Text: "ok"}, + }}, nil + }) + + tCtx := &adk.ToolContext{Name: "EnhancedTool", CallID: "call_ei_2"} + wrapped, err := m.WrapEnhancedInvokableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + _, err = wrapped(context.Background(), &schema.ToolArgument{Text: `{"original":true}`}) + require.NoError(t, err) + assert.Equal(t, `{"sanitized":true}`, receivedText) +} + +func TestWrapEnhancedInvokableToolCall_Deny(t *testing.T) { + endpointCalled := false + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Deny, Message: "enhanced blocked"}, nil + }) + + originalEndpoint := adk.EnhancedInvokableToolCallEndpoint( + func(ctx context.Context, arg *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { + endpointCalled = true + return nil, nil + }) + + tCtx := &adk.ToolContext{Name: "EnhancedTool", CallID: "call_ei_3"} + wrapped, err := m.WrapEnhancedInvokableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + result, err := wrapped(context.Background(), &schema.ToolArgument{Text: `{}`}) + require.NoError(t, err) + assert.False(t, endpointCalled) + require.Len(t, result.Parts, 1) + assert.Equal(t, schema.ToolPartTypeText, result.Parts[0].Type) + assert.Equal(t, "Permission denied for tool EnhancedTool: enhanced blocked", result.Parts[0].Text) +} + +func TestWrapEnhancedInvokableToolCall_Ask(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Ask, Message: "enhanced needs approval"}, nil + }) + + originalEndpoint := adk.EnhancedInvokableToolCallEndpoint( + func(ctx context.Context, arg *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { + return nil, nil + }) + + tCtx := &adk.ToolContext{Name: "EnhancedTool", CallID: "call_ei_4"} + wrapped, err := m.WrapEnhancedInvokableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + ctx := makeCtxWithAddr() + result, err := wrapped(ctx, &schema.ToolArgument{Text: `{}`}) + assert.Nil(t, result) + require.Error(t, err) + + var is *core.InterruptSignal + assert.True(t, errors.As(err, &is)) +} + +func TestWrapEnhancedStreamableToolCall_Allow(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Allow}, nil + }) + + originalEndpoint := adk.EnhancedStreamableToolCallEndpoint( + func(ctx context.Context, arg *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { + tr := &schema.ToolResult{Parts: []schema.ToolOutputPart{ + {Type: schema.ToolPartTypeText, Text: "streamed:" + arg.Text}, + }} + return schema.StreamReaderFromArray([]*schema.ToolResult{tr}), nil + }) + + tCtx := &adk.ToolContext{Name: "EnhancedStreamTool", CallID: "call_es_1"} + wrapped, err := m.WrapEnhancedStreamableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + sr, err := wrapped(context.Background(), &schema.ToolArgument{Text: `{"k":"v"}`}) + require.NoError(t, err) + require.NotNil(t, sr) + + chunk, err := sr.Recv() + require.NoError(t, err) + require.Len(t, chunk.Parts, 1) + assert.Equal(t, `streamed:{"k":"v"}`, chunk.Parts[0].Text) + + _, err = sr.Recv() + assert.Equal(t, io.EOF, err) +} + +func TestWrapEnhancedStreamableToolCall_Deny(t *testing.T) { + endpointCalled := false + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Deny, Message: "stream enhanced blocked"}, nil + }) + + originalEndpoint := adk.EnhancedStreamableToolCallEndpoint( + func(ctx context.Context, arg *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { + endpointCalled = true + return nil, nil + }) + + tCtx := &adk.ToolContext{Name: "EnhancedStreamTool", CallID: "call_es_2"} + wrapped, err := m.WrapEnhancedStreamableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + sr, err := wrapped(context.Background(), &schema.ToolArgument{Text: `{}`}) + require.NoError(t, err) + assert.False(t, endpointCalled) + require.NotNil(t, sr) + + chunk, err := sr.Recv() + require.NoError(t, err) + require.Len(t, chunk.Parts, 1) + assert.Equal(t, schema.ToolPartTypeText, chunk.Parts[0].Type) + assert.Equal(t, "Permission denied for tool EnhancedStreamTool: stream enhanced blocked", chunk.Parts[0].Text) + + _, err = sr.Recv() + assert.Equal(t, io.EOF, err) +} + +func TestWrapEnhancedStreamableToolCall_Ask(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Ask, Message: "enhanced stream needs approval"}, nil + }) + + originalEndpoint := adk.EnhancedStreamableToolCallEndpoint( + func(ctx context.Context, arg *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { + return nil, nil + }) + + tCtx := &adk.ToolContext{Name: "EnhancedStreamTool", CallID: "call_es_3"} + wrapped, err := m.WrapEnhancedStreamableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + ctx := makeCtxWithAddr() + sr, err := wrapped(ctx, &schema.ToolArgument{Text: `{}`}) + assert.Nil(t, sr) + require.Error(t, err) + + var is *core.InterruptSignal + assert.True(t, errors.As(err, &is)) +} + +func TestWrapEnhancedStreamableToolCall_AllowWithUpdatedInput(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Allow, UpdatedInput: `{"safe":true}`}, nil + }) + + var receivedText string + originalEndpoint := adk.EnhancedStreamableToolCallEndpoint( + func(ctx context.Context, arg *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { + receivedText = arg.Text + tr := &schema.ToolResult{Parts: []schema.ToolOutputPart{ + {Type: schema.ToolPartTypeText, Text: "ok"}, + }} + return schema.StreamReaderFromArray([]*schema.ToolResult{tr}), nil + }) + + tCtx := &adk.ToolContext{Name: "EnhancedStreamTool", CallID: "call_es_4"} + wrapped, err := m.WrapEnhancedStreamableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + sr, err := wrapped(context.Background(), &schema.ToolArgument{Text: `{"dangerous":true}`}) + require.NoError(t, err) + require.NotNil(t, sr) + assert.Equal(t, `{"safe":true}`, receivedText) +}