diff --git a/.gitignore b/.gitignore index 8ef36de95..04542d49a 100644 --- a/.gitignore +++ b/.gitignore @@ -47,10 +47,14 @@ output/* # Reports (generated analysis files) reports/ +/todos .DS_Store -*.log +*.log* +.claude CLAUDE.md +*.jsonl +*.txt # Specs directories */specs diff --git a/adk/agent_tool.go b/adk/agent_tool.go index 9472dab1f..fde319cb4 100644 --- a/adk/agent_tool.go +++ b/adk/agent_tool.go @@ -167,16 +167,19 @@ func (at *agentTool) InvokableRun(ctx context.Context, argumentsInJSON string, o } iter = newInvokableAgentToolRunner(at.agent, ms, enableStreaming).Run(ctx, input, - append(getOptionsByAgentName(at.agent.Name(ctx), opts), WithCheckPointID(bridgeCheckpointID), withSharedParentSession())...) + append(extractAndDeriveCancelCtx(ctx, at.agent.Name(ctx), opts), WithCheckPointID(bridgeCheckpointID), withSharedParentSession())...) } else { if !hasState { return "", fmt.Errorf("agent tool '%s' interrupt has happened, but cannot find interrupt state", at.agent.Name(ctx)) } - ms = newResumeBridgeStore(state) + ms = newResumeBridgeStore(bridgeCheckpointID, state) + + agentOpts := extractAndDeriveCancelCtx(ctx, at.agent.Name(ctx), opts) + agentOpts = append(agentOpts, withSharedParentSession()) iter, err = newInvokableAgentToolRunner(at.agent, ms, enableStreaming). - Resume(ctx, bridgeCheckpointID, append(getOptionsByAgentName(at.agent.Name(ctx), opts), withSharedParentSession())...) + Resume(ctx, bridgeCheckpointID, agentOpts...) if err != nil { return "", err } @@ -281,6 +284,18 @@ func getOptionsByAgentName(agentName string, opts []tool.Option) []AgentRunOptio return ret } +func extractAndDeriveCancelCtx(ctx context.Context, agentName string, opts []tool.Option) []AgentRunOption { + agentOpts := getOptionsByAgentName(agentName, opts) + baseOpts := getCommonOptions(nil, agentOpts...) 
+ if baseOpts.cancelCtx != nil { + childCtx := baseOpts.cancelCtx.deriveChild(ctx) + agentOpts = append(agentOpts, WrapImplSpecificOptFn(func(o *options) { + o.cancelCtx = childCtx + })) + } + return agentOpts +} + func getEmitGeneratorAndEnableStreaming(opts []tool.Option) (*AsyncGenerator[*AgentEvent], bool) { o := tool.GetImplSpecificOptions[agentToolOptions](nil, opts...) if o == nil { diff --git a/adk/attack_test.go b/adk/attack_test.go new file mode 100644 index 000000000..bfb4462ef --- /dev/null +++ b/adk/attack_test.go @@ -0,0 +1,449 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package adk + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/schema" +) + +func TestAttack_UntilIdleFor_ConcurrentPushDuringIdleTimer(t *testing.T) { + turnCount := int32(0) + turnDone := make(chan struct{}, 10) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + atomic.AddInt32(&turnCount, 1) + turnDone <- struct{}{} + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-turnDone + + loop.Stop(UntilIdleFor(200 * time.Millisecond)) + + for i := 0; i < 5; i++ { + time.Sleep(50 * time.Millisecond) + loop.Push("concurrent-" + string(rune('a'+i))) + <-turnDone + } + + done := make(chan struct{}) + go func() { + loop.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("loop did not exit after idle timeout — Push did not reset timer correctly") + } + + finalCount := atomic.LoadInt32(&turnCount) + assert.Equal(t, int32(6), finalCount, "all 6 pushes should have been processed") +} + +func TestAttack_UntilIdleFor_MultipleStopCallsFirstWins(t *testing.T) { + turnDone := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: 
[]Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(turnDone) + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-turnDone + + loop.Stop(UntilIdleFor(100 * time.Millisecond)) + loop.Stop(UntilIdleFor(10 * time.Minute)) + + done := make(chan struct{}) + go func() { + loop.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("second UntilIdleFor should have been ignored; loop should have exited with 100ms timer") + } +} + +func TestAttack_BareStopOverridesUntilIdleFor(t *testing.T) { + agentStarted := make(chan struct{}) + agentDone := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(agentStarted) + <-agentDone + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-agentStarted + + loop.Stop(UntilIdleFor(10 * time.Minute)) + + loop.Stop() + close(agentDone) + + done := make(chan struct{}) + go func() { + loop.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("bare Stop should override UntilIdleFor and cause immediate shutdown") + } + + exit := loop.Wait() + assert.NoError(t, exit.ExitReason, "bare Stop should exit cleanly") +} + +func 
TestAttack_StopSignal_NilCancelOptsDoNotDeescalate(t *testing.T) { + agentStarted := make(chan *cancelContext, 1) + probe := &turnLoopStopModeProbeAgent{ccCh: agentStarted} + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return probe, nil + }, + }) + + loop.Push("msg1") + cc := <-agentStarted + + loop.Stop(WithImmediate()) + + time.Sleep(20 * time.Millisecond) + + loop.Stop() + + time.Sleep(20 * time.Millisecond) + mode := cc.getMode() + assert.Equal(t, CancelImmediate, mode, "bare Stop after WithImmediate must not de-escalate cancel mode") + + exit := loop.Wait() + var ce *CancelError + require.True(t, errors.As(exit.ExitReason, &ce)) + assert.Equal(t, CancelImmediate, ce.Info.Mode) +} + +func TestAttack_CanceledItems_EmptyWhenAgentFinishesNormally(t *testing.T) { + agentStarted := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(agentStarted) + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-agentStarted + time.Sleep(50 * time.Millisecond) + loop.Stop() + + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + assert.Empty(t, 
exit.CanceledItems, "CanceledItems must be empty when agent finished normally") +} + +func TestAttack_TurnBuffer_WakeupDoesNotLoseItems(t *testing.T) { + tb := newTurnBuffer[string]() + + tb.Send("a") + tb.Send("b") + tb.Wakeup() + tb.Send("c") + + var got []string + for i := 0; i < 3; i++ { + val, ok := tb.Receive() + require.True(t, ok) + got = append(got, val) + } + + assert.Equal(t, []string{"a", "b", "c"}, got, "Wakeup must not cause items to be lost") +} + +func TestAttack_TurnBuffer_ClearWakeupPreventsSpuriousReturn(t *testing.T) { + tb := newTurnBuffer[string]() + + tb.Wakeup() + tb.ClearWakeup() + + received := make(chan string, 1) + go func() { + val, ok := tb.Receive() + if ok { + received <- val + } + }() + + time.Sleep(50 * time.Millisecond) + tb.Send("real") + + select { + case val := <-received: + assert.Equal(t, "real", val, "ClearWakeup should prevent spurious empty return") + case <-time.After(2 * time.Second): + t.Fatal("Receive blocked forever despite Send") + } +} + +func TestAttack_StopBeforeRun_UntilIdleFor_ExitsImmediately(t *testing.T) { + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Stop(UntilIdleFor(10 * time.Minute)) + loop.Stop() + + loop.Run(context.Background()) + + done := make(chan struct{}) + go func() { + loop.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("loop should exit immediately when Stop() called before Run()") + } +} + +func TestAttack_PushAfterStop_UntilIdleFor_RoutedToLateItems(t *testing.T) { + turnDone := make(chan struct{}) + loop := 
newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(turnDone) + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-turnDone + + loop.Stop(UntilIdleFor(50 * time.Millisecond)) + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + + ok, _ := loop.Push("after-stop") + assert.False(t, ok, "Push after loop exited should return false") + + late := exit.TakeLateItems() + assert.Equal(t, []string{"after-stop"}, late) +} + +func TestAttack_ConcurrentStopEscalation_RaceDetector(t *testing.T) { + agentStarted := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(agentStarted) + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + }) + + loop.Push("msg1") + <-agentStarted + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + switch i % 4 { + case 0: + loop.Stop() + case 1: + loop.Stop(WithImmediate()) + case 2: + loop.Stop(WithGracefulTimeout(100 * time.Millisecond)) + case 3: + 
loop.Stop(UntilIdleFor(50 * time.Millisecond)) + } + }(i) + } + + wg.Wait() + exit := loop.Wait() + t.Log("ExitReason:", exit.ExitReason) +} + +func TestAttack_StopCause_FirstNonEmptyWins_ConcurrentCallers(t *testing.T) { + turnDone := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(turnDone) + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-turnDone + + loop.Stop(WithStopCause("first-cause")) + loop.Stop(WithStopCause("second-cause")) + + exit := loop.Wait() + assert.Equal(t, "first-cause", exit.StopCause, "first non-empty StopCause should win") +} + +func TestAttack_SkipCheckpoint_Sticky(t *testing.T) { + agentStarted := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(agentStarted) + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + Store: &turnLoopCheckpointStore{m: make(map[string][]byte)}, + CheckpointID: "test-sticky", + }) + + loop.Push("msg1") + <-agentStarted + + 
loop.Stop(WithSkipCheckpoint()) + loop.Stop(WithImmediate()) + + exit := loop.Wait() + assert.False(t, exit.Checkpointed, "SkipCheckpoint is sticky; checkpoint should be skipped") +} diff --git a/adk/call_option.go b/adk/call_option.go index 55e57fd32..ead6ae636 100644 --- a/adk/call_option.go +++ b/adk/call_option.go @@ -24,6 +24,7 @@ type options struct { checkPointID *string skipTransferMessages bool handlers []callbacks.Handler + cancelCtx *cancelContext } // AgentRunOption is the call option for adk Agent. @@ -157,6 +158,33 @@ func filterCallbackHandlersForNestedAgents(currentAgentName string, opts []Agent return filteredOpts } +// filterCancelOption removes any AgentRunOption that sets a cancelCtx on *options. +// This prevents inner (nested) agents from receiving the cancel option when the +// outer flowAgent owns the cancel lifecycle. Inner agents access the cancelContext +// via the Go context (getCancelContext) instead. +func filterCancelOption(opts []AgentRunOption) []AgentRunOption { + if len(opts) == 0 { + return nil + } + var filteredOpts []AgentRunOption + for i := range opts { + opt := opts[i] + if opt.implSpecificOptFn == nil { + filteredOpts = append(filteredOpts, opt) + continue + } + if _, isCommonOpt := opt.implSpecificOptFn.(func(*options)); isCommonOpt { + testOpt := &options{} + opt.implSpecificOptFn.(func(*options))(testOpt) + if testOpt.cancelCtx != nil { + continue + } + } + filteredOpts = append(filteredOpts, opt) + } + return filteredOpts +} + func filterOptions(agentName string, opts []AgentRunOption) []AgentRunOption { if len(opts) == 0 { return nil diff --git a/adk/cancel.go b/adk/cancel.go new file mode 100644 index 000000000..6d4aa9ad9 --- /dev/null +++ b/adk/cancel.go @@ -0,0 +1,983 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "fmt" + "io" + "sync" + "sync/atomic" + "time" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +func init() { + schema.RegisterName[*CancelError]("_eino_adk_cancel_error") + schema.RegisterName[*AgentCancelInfo]("_eino_adk_agent_cancel_info") + schema.RegisterName[*StreamCanceledError]("_eino_adk_stream_cancelled_error") +} + +// CancelMode specifies when an agent should be canceled. +// Modes can be combined with bitwise OR to cancel at multiple safe-points. +// For example, CancelAfterChatModel | CancelAfterToolCalls cancels the agent +// after whichever safe-point is reached first. +type CancelMode int + +const ( + // CancelImmediate cancels the agent as soon as the signal is received, + // without waiting for a ChatModel or ToolCalls safe-point. + // By default, only the root agent is interrupted; descendant agents inside + // AgentTools are torn down via context cancellation as a side effect. + // Use WithRecursive to propagate explicit immediate-cancel signals to + // descendants for clean teardown with grace period. + CancelImmediate CancelMode = 0 + // CancelAfterChatModel cancels after the root agent's next chat model call + // completes. By default, only the root agent checks this safe-point; + // nested sub-agents inside AgentTools are unaware of the cancel. + // Use WithRecursive to propagate the cancel to all descendants — whichever + // ChatModel finishes first triggers the cancel. 
+ CancelAfterChatModel CancelMode = 1 << iota + // CancelAfterToolCalls cancels after the root agent's next set of concurrent + // tool calls completes. By default, only the root agent checks this safe-point. + // Use WithRecursive to propagate to all descendants. + CancelAfterToolCalls +) + +// CancelHandle represents a cancel operation that can be waited on. +type CancelHandle struct { + wait func() error +} + +// Wait blocks until the cancel request reaches a terminal outcome. +// +// It reports the result of the cancel operation itself, not the agent's final +// business error: +// - nil: cancellation succeeded, including the case where a business interrupt +// was absorbed into CancelError while cancellation was active +// - ErrCancelTimeout: the requested safe-point cancellation timed out and was +// escalated to immediate cancellation +// - ErrExecutionCompleted: the execution finished before cancellation took effect, +// meaning the stream drained to completion without any interrupt +func (h *CancelHandle) Wait() error { + return h.wait() +} + +// AgentCancelFunc is called to request cancellation of a running agent. +// It returns after the cancel request is committed; use the returned handle's +// Wait to block for completion and outcome. +// +// The returned bool reports whether this call contributed to the CancelError +// for the current execution. "Contributed" means this call's cancel options +// were included before cancellation was finalized. It is false when cancellation +// was already finalized (handled or execution completed). +type AgentCancelFunc func(...AgentCancelOption) (*CancelHandle, bool) + +type agentCancelConfig struct { + Mode CancelMode + Recursive bool + Timeout *time.Duration +} + +// AgentCancelOption configures cancel behavior. +type AgentCancelOption func(*agentCancelConfig) + +// WithAgentCancelMode sets the cancel mode for the agent cancel operation. 
+func WithAgentCancelMode(mode CancelMode) AgentCancelOption { + return func(config *agentCancelConfig) { + config.Mode = mode + } +} + +// WithAgentCancelTimeout sets a timeout for the cancel operation. +// This only applies to safe-point modes (CancelAfterChatModel, CancelAfterToolCalls): +// if the safe-point hasn't fired within this duration, the cancel escalates to +// an immediate graph interrupt. +// For CancelImmediate this timeout is ignored — the graph interrupt fires +// immediately with timeout=0. +func WithAgentCancelTimeout(timeout time.Duration) AgentCancelOption { + return func(config *agentCancelConfig) { + config.Timeout = &timeout + } +} + +// WithRecursive opts into recursive cancel propagation. By default, cancel +// modes only affect the root agent; descendant agents inside AgentTools are +// not notified. WithRecursive makes the cancel propagate to all descendants: +// - CancelAfterChatModel / CancelAfterToolCalls: descendants check their own safe-points. +// - CancelImmediate: descendants receive explicit immediate-cancel signals for +// clean teardown; the root uses a grace period to collect child interrupts. +// +// Once any cancel call includes WithRecursive, the flag stays set for the +// entire cancel lifecycle (monotonic escalation). +func WithRecursive() AgentCancelOption { + return func(config *agentCancelConfig) { + config.Recursive = true + } +} + +// AgentCancelInfo contains information about a cancel operation. +type AgentCancelInfo struct { + Mode CancelMode + Escalated bool + Timeout bool +} + +// CancelError is sent via AgentEvent.Err when an agent is canceled. +// Use errors.As to match and extract *CancelError from event errors. +// +// Interrupt absorption: when a cancel is active (shouldCancel() == true), ANY +// interrupt — whether from a cancel safe-point node or from business logic +// (e.g. compose.Interrupt in a tool) — is converted to a CancelError. The +// cancel "absorbs" the business interrupt. 
This is intentional: +// +// - In concurrent execution (parallel workflows, concurrent tool calls), +// cancel-induced and business interrupts can arrive as a single composite +// signal that cannot be split apart. +// - Even in sequential execution, treating business interrupts as CancelError +// during active cancel gives consistent semantics. +// - The business interrupt is NOT lost — the checkpoint preserves the full +// interrupt hierarchy. On resume (Runner.Resume), the agent re-executes +// the interrupting code path and the business interrupt re-fires naturally. +type CancelError struct { + Info *AgentCancelInfo + + // CheckPointID is the checkpoint ID associated with this cancel operation. + // When non-empty, the cancelled agent's state has been persisted under this ID + // and can be resumed via Runner.Resume or GenInputResult.ResumeFromCheckpointID. + CheckPointID string + + // InterruptContexts provides the interrupt contexts needed for targeted + // resumption via Runner.ResumeWithParams. Each context represents a step + // in the agent hierarchy that was interrupted. This is a slice because + // composite agents (e.g. parallel workflows) may interrupt at multiple + // points simultaneously, matching the shape of AgentAction.Interrupted.InterruptContexts. + // Use each InterruptCtx.ID as a key in ResumeParams.Targets. + InterruptContexts []*InterruptCtx + + interruptSignal *InterruptSignal // unexported — only Runner needs it for checkpoint +} + +func (e *CancelError) Error() string { + return fmt.Sprintf("agent canceled: mode=%v, escalated=%v", e.Info.Mode, e.Info.Escalated) +} + +// Sentinel errors for cancel outcomes. +var ( + // ErrCancelTimeout is returned by CancelHandle.Wait when the cancel operation timed out. + ErrCancelTimeout = errors.New("cancel timed out") + + // ErrExecutionCompleted is returned by CancelHandle.Wait when the agent finished + // before the cancel took effect. 
"Finished" means the event stream was fully + // drained without any interrupt — normal completion or a fatal error. + // + // Note: business interrupts that occur while cancel is active are absorbed + // into CancelError (see CancelError doc), so they result in nil (cancel + // succeeded), NOT ErrExecutionCompleted. Only execution that completes with + // no interrupt at all produces this error. + ErrExecutionCompleted = errors.New("execution already completed") + + // ErrStreamCanceled is the error sent through the stream when CancelImmediate aborts it. + // It is a *StreamCanceledError so it can be gob-serialized during checkpoint save + // (when stored as agentEventWrapper.StreamErr). + ErrStreamCanceled error = &StreamCanceledError{} +) + +// StreamCanceledError is the concrete error type for ErrStreamCanceled. +// It is exported so that gob can serialize it during checkpoint save when the error +// is stored in agentEventWrapper.StreamErr. +type StreamCanceledError struct{} + +func (e *StreamCanceledError) Error() string { + return "stream canceled" +} + +// WithCancel creates an AgentRunOption that enables cancellation for an agent run. +// It returns the option to pass to Run/Resume and a cancel function. +// Cancel options (mode, timeout) are passed to the returned AgentCancelFunc at call time. +func WithCancel() (AgentRunOption, AgentCancelFunc) { + cc := newCancelContext() + opt := WrapImplSpecificOptFn(func(o *options) { + o.cancelCtx = cc + }) + cancelFn := cc.buildCancelFunc() + return opt, cancelFn +} + +// cancelContext state constants (for int32 CAS). 
//
// State transition rules:
//
//	stateRunning    -> stateCancelling    (cancel requested by AgentCancelFunc)
//	stateRunning    -> stateDone          (execution finished without interrupt)
//	stateCancelling -> stateCancelHandled (ANY interrupt absorbed as CancelError)
//	stateCancelling -> stateDone          (execution finished without interrupt while cancel pending)
//
// Terminal states: stateDone, stateCancelHandled.
//
// "Completed" and "errored" are deliberately not separate terminal states:
// end-users read the actual outcome from AgentEvent, and the AgentCancelFunc
// return value only needs the cancel / non-cancel distinction. This keeps the
// state machine minimal.
//
// Business interrupts: while the state is stateCancelling, any interrupt —
// cancel-induced OR business — is absorbed by wrapIterWithCancelCtx as a
// CancelError and the state moves to stateCancelHandled. The business
// interrupt data is preserved in the checkpoint for re-emission on resume.
//
// NOTE(review): the table above says stateRunning -> stateDone happens
// "without interrupt", while the stateDone comment below also lists a business
// interrupt as a non-cancel path — confirm which wording is intended.
const (
	// stateRunning is the initial state: the agent is executing normally.
	stateRunning int32 = 0
	// stateCancelling: AgentCancelFunc has been called and cancelChan is
	// closed, but the runFunc has not handled the cancel yet.
	stateCancelling int32 = 1
	// stateDone: execution finished through any non-cancel path (normal
	// completion, business interrupt, or error). The specific outcome travels
	// in AgentEvent, not in this state machine.
	stateDone int32 = 2
	// stateCancelHandled: the runFunc processed the cancel and a CancelError
	// was emitted through the event stream. This is the success terminal
	// state for cancellation.
	// NOTE(review): values 3 and 4 are unused here; the gap looks deliberate
	// but nothing in this file explains it.
	stateCancelHandled int32 = 5
)

// interruptSent constants (for int32 CAS).
+// +// Transition rules: +// +// interruptNotSent -> interruptImmediate (CancelImmediate or escalation) +const ( + // interruptNotSent means no compose graph interrupt has been sent. + interruptNotSent int32 = 0 + // interruptImmediate means an immediate graph interrupt was sent with + // timeout=0, forcing the graph to stop as soon as possible. + interruptImmediate int32 = 1 +) + +// defaultCancelImmediateGracePeriod is the time a parent's graph interrupt +// waits when the cancelContext has active children (via deriveChild). This +// gives child agents time to propagate their interrupt signal back through +// the agentTool as a CompositeInterrupt. If this proves insufficient for +// deeply nested structures or too slow for latency-sensitive use cases, +// consider making it configurable via an AgentCancelOption. +const defaultCancelImmediateGracePeriod = 1 * time.Second + +type cancelContextKey struct{} + +// withCancelContext stores a cancelContext in the Go context. +func withCancelContext(ctx context.Context, cc *cancelContext) context.Context { + if cc == nil { + return ctx + } + return context.WithValue(ctx, cancelContextKey{}, cc) +} + +// getCancelContext retrieves the cancelContext from the Go context, or nil. 
+func getCancelContext(ctx context.Context) *cancelContext {
+	// The value stored under cancelContextKey{} is asserted without a check:
+	// a stored value of any other type panics here; an absent value yields nil.
+	if v := ctx.Value(cancelContextKey{}); v != nil {
+		return v.(*cancelContext)
+	}
+	return nil
+}
+
+// cancelContext carries the cancellation state of one agent execution: the
+// requested mode, the channels that signal cancel / immediate interrupt /
+// completion, the lifecycle state, parent/child bookkeeping, and the
+// registered graph interrupt funcs.
+type cancelContext struct {
+	mode int32 // atomic, CancelMode
+
+	cancelChan    chan struct{} // closed when cancel is requested (all modes, not just safe-point)
+	immediateChan chan struct{} // closed when an immediate graph interrupt fires
+	doneChan      chan struct{} // closed when execution completes (by any mark* method)
+	doneOnce      sync.Once     // ensures doneChan is closed exactly once
+
+	state            int32 // stateRunning, stateCancelling, stateDone, stateCancelHandled
+	interruptSent    int32 // interruptNotSent, interruptImmediate
+	escalated        int32 // 1 if escalated from safe-point to immediate
+	timeoutEscalated int32 // 1 if escalation was triggered by timeout
+	startedMode      int32 // atomic, mode when state transitioned to cancelling
+	deadlineUnixNano int64 // atomic, 0 means no deadline
+
+	recursive     int32         // atomic; 1 if cancel should propagate to descendant agents via deriveChild
+	recursiveChan chan struct{} // closed when recursive transitions from 0 to 1
+
+	root   bool           // true for the original cancelContext created by WithCancel(); false for derived children
+	parent *cancelContext // non-nil for derived children; used to decrement parent's activeChildren on markDone
+
+	activeChildren    int32 // atomic; number of derived children that haven't called markDone() yet
+	decrementedParent int32 // atomic CAS guard; ensures parent.activeChildren is decremented at most once
+
+	cancelMu      sync.Mutex    // held by the cancel func and by createAndMarkCancelHandled
+	timeoutOnce   sync.Once     // starts the timeout controller goroutine at most once
+	timeoutNotify chan struct{} // capacity 1; wakes the timeout controller after a deadline change
+
+	mu                  sync.Mutex // guards graphInterruptFuncs and the interrupt check-and-fire sequence
+	graphInterruptFuncs []func(...compose.GraphInterruptOption)
+}
+
+// newCancelContext returns a root cancelContext with every signal channel
+// allocated. timeoutNotify is buffered (capacity 1) so wake-ups never block.
+func newCancelContext() *cancelContext {
+	return &cancelContext{
+		cancelChan:    make(chan struct{}),
+		immediateChan: make(chan struct{}),
+		doneChan:      make(chan struct{}),
+		timeoutNotify: make(chan struct{}, 1),
+		recursiveChan: make(chan struct{}),
+		root:          true,
+	}
+}
+
+// isRoot reports whether cc is non-nil and was not produced by deriveChild.
+func (cc *cancelContext) isRoot() bool {
+	return cc != nil && cc.root
+}
+
+// isRecursive reports whether cc is non-nil and its recursive flag is set.
+func (cc *cancelContext) isRecursive() bool {
+	return cc != nil && atomic.LoadInt32(&cc.recursive) == 1
+}
+
+// setRecursive(false) is a no-op; recursive is monotonically escalating:
+// once set to true, it cannot be reverted.
+func (cc *cancelContext) setRecursive(v bool) {
+	// The CAS guarantees recursiveChan is closed exactly once.
+	if v && atomic.CompareAndSwapInt32(&cc.recursive, 0, 1) {
+		close(cc.recursiveChan)
+	}
+}
+
+// deriveChild creates a child cancelContext that receives cancel propagation
+// from the parent. The caller MUST ensure the child's markDone() is eventually
+// called (e.g., via wrapIterWithCancelCtx's defer) or that ctx is canceled;
+// otherwise the two propagation goroutines will leak.
+func (cc *cancelContext) deriveChild(ctx context.Context) *cancelContext {
+	if cc == nil {
+		return nil
+	}
+	child := newCancelContext()
+	child.root = false
+	child.parent = cc
+	atomic.AddInt32(&cc.activeChildren, 1)
+
+	// Each goroutine below propagates one signal class (cancel / immediate) to
+	// the child. The pattern is a two-phase select:
+	//   Phase 1: wait for the parent signal (or child/ctx completion).
+	//   Phase 2: if the signal fired but recursive mode is not active yet,
+	//            enter a second select waiting for either recursive escalation
+	//            (recursiveChan) or child/ctx completion. This ensures
+	//            non-recursive cancels leave children unaware, while a late
+	//            escalation to recursive still propagates.
+	go func() {
+		select {
+		case <-cc.cancelChan:
+			if cc.isRecursive() {
+				child.setRecursive(true)
+				child.triggerCancel(cc.getMode())
+				return
+			}
+			select {
+			case <-cc.recursiveChan:
+				child.setRecursive(true)
+				child.triggerCancel(cc.getMode())
+			case <-child.doneChan:
+			case <-ctx.Done():
+			}
+		case <-child.doneChan:
+		case <-ctx.Done():
+		}
+	}()
+
+	go func() {
+		select {
+		case <-cc.immediateChan:
+			if cc.isRecursive() {
+				child.setRecursive(true)
+				child.triggerImmediateCancel()
+				return
+			}
+			select {
+			case <-cc.recursiveChan:
+				child.setRecursive(true)
+				child.triggerImmediateCancel()
+			case <-child.doneChan:
+			case <-ctx.Done():
+			}
+		case <-child.doneChan:
+		case <-ctx.Done():
+		}
+	}()
+
+	return child
+}
+
+// triggerCancel records mode and, if the state is still stateRunning, moves it
+// to stateCancelling and closes cancelChan. The mode is stored even when the
+// state transition does not happen.
+func (cc *cancelContext) triggerCancel(mode CancelMode) {
+	cc.setMode(mode)
+	if atomic.CompareAndSwapInt32(&cc.state, stateRunning, stateCancelling) {
+		close(cc.cancelChan)
+	}
+}
+
+// triggerImmediateCancel sets the escalated flag, switches the mode to
+// CancelImmediate, performs the same running→cancelling transition as
+// triggerCancel, and then calls sendImmediateInterrupt.
+func (cc *cancelContext) triggerImmediateCancel() {
+	atomic.StoreInt32(&cc.escalated, 1)
+	cc.setMode(CancelImmediate)
+	if atomic.CompareAndSwapInt32(&cc.state, stateRunning, stateCancelling) {
+		close(cc.cancelChan)
+	}
+	cc.sendImmediateInterrupt()
+}
+
+// getMode atomically loads the current mode. A nil receiver reports
+// CancelImmediate.
+func (cc *cancelContext) getMode() CancelMode {
+	if cc == nil {
+		return CancelImmediate
+	}
+	return CancelMode(atomic.LoadInt32(&cc.mode))
+}
+
+// setMode atomically stores mode.
+func (cc *cancelContext) setMode(mode CancelMode) {
+	atomic.StoreInt32(&cc.mode, int32(mode))
+}
+
+// getDeadlineUnixNano atomically loads the deadline (Unix nanoseconds; 0 means
+// no deadline).
+func (cc *cancelContext) getDeadlineUnixNano() int64 {
+	return atomic.LoadInt64(&cc.deadlineUnixNano)
+}
+
+// setDeadlineUnixNano atomically stores the deadline (Unix nanoseconds; 0
+// clears it).
+func (cc *cancelContext) setDeadlineUnixNano(v int64) {
+	atomic.StoreInt64(&cc.deadlineUnixNano, v)
+}
+
+// wakeTimeoutController performs a non-blocking send on timeoutNotify; when a
+// wake-up is already pending the call is a no-op.
+func (cc *cancelContext) wakeTimeoutController() {
+	select {
+	case cc.timeoutNotify <- struct{}{}:
+	default:
+	}
+}
+
+// shouldCancel returns true if a cancel has been requested (cancelChan is closed).
+func (cc *cancelContext) shouldCancel() bool {
+	if cc == nil {
+		return false
+	}
+	select {
+	case <-cc.cancelChan:
+		return true
+	default:
+		return false
+	}
+}
+
+// isImmediateCancelled returns true if an immediate graph interrupt has been
+// fired (CancelImmediate or timeout escalation). This is stronger than
+// shouldCancel: it means the compose graph is being torn down right now and
+// orphaned goroutines should not attempt to send events.
+func (cc *cancelContext) isImmediateCancelled() bool {
+	if cc == nil {
+		return false
+	}
+	select {
+	case <-cc.immediateChan:
+		return true
+	default:
+		return false
+	}
+}
+
+// sendImmediateInterrupt sends the compose graph interrupt signal via graphInterruptFuncs.
+// Also closes immediateChan (used by cancelMonitoredModel to abort an in-progress stream).
+// Returns false if an interrupt was already sent or if no graphInterruptFuncs have been
+// registered yet (the deferred fire in setGraphInterruptFunc will handle that case).
+//
+// cc.mu is released via defer so that a panicking interrupt func cannot leave
+// the mutex locked and deadlock later callers.
+func (cc *cancelContext) sendImmediateInterrupt() bool {
+	cc.mu.Lock()
+	defer cc.mu.Unlock()
+
+	// Only the first caller wins the CAS; everyone else reports "not sent".
+	if !atomic.CompareAndSwapInt32(&cc.interruptSent, interruptNotSent, interruptImmediate) {
+		return false
+	}
+
+	close(cc.immediateChan)
+
+	if len(cc.graphInterruptFuncs) == 0 {
+		return false
+	}
+
+	// The lock is held for the whole loop, so the slice cannot change while
+	// it is being iterated and no defensive copy is needed.
+	for _, fn := range cc.graphInterruptFuncs {
+		fn(compose.WithGraphInterruptTimeout(0))
+	}
+	return true
+}
+
+// setGraphInterruptFunc appends a graph interrupt function to the list.
+// If an immediate cancel was already requested, fires it retroactively.
+// Multiple functions can be registered (e.g. one per parallel sub-agent).
+//
+// Both this method and sendImmediateInterrupt hold cc.mu across the entire
+// check-and-fire sequence, ensuring each interrupt function is called exactly
+// once (compose.WithGraphInterrupt returns a non-idempotent closure that panics
+// on double-call). The lock is released via defer so a panic inside interrupt
+// does not leave cc.mu held.
+func (cc *cancelContext) setGraphInterruptFunc(interrupt func(...compose.GraphInterruptOption)) {
+	cc.mu.Lock()
+	defer cc.mu.Unlock()
+
+	cc.graphInterruptFuncs = append(cc.graphInterruptFuncs, interrupt)
+
+	if atomic.LoadInt32(&cc.interruptSent) == interruptImmediate {
+		interrupt(compose.WithGraphInterruptTimeout(0))
+	}
+}
+
+// markDone marks the execution as finished through any non-cancel path
+// (normal completion, business interrupt, or error).
+// This is safe to call even if a cancel is in progress — it allows the
+// cancel func to detect that execution finished before cancel took effect.
+func (cc *cancelContext) markDone() {
+	if cc == nil {
+		return
+	}
+	if atomic.CompareAndSwapInt32(&cc.state, stateRunning, stateDone) ||
+		atomic.CompareAndSwapInt32(&cc.state, stateCancelling, stateDone) {
+		cc.doneOnce.Do(func() { close(cc.doneChan) })
+		cc.detachFromParent()
+	}
+}
+
+// detachFromParent decrements the parent's activeChildren counter. The CAS on
+// decrementedParent ensures the decrement happens at most once per child.
+func (cc *cancelContext) detachFromParent() {
+	if cc.parent != nil && atomic.CompareAndSwapInt32(&cc.decrementedParent, 0, 1) {
+		atomic.AddInt32(&cc.parent.activeChildren, -1)
+	}
+}
+
+// hasActiveChildren reports whether cc is non-nil and at least one derived
+// child has not yet detached.
+func (cc *cancelContext) hasActiveChildren() bool {
+	return cc != nil && atomic.LoadInt32(&cc.activeChildren) > 0
+}
+
+// wrapGraphInterruptWithGracePeriod returns a func that forwards to interrupt,
+// appending compose.WithGraphInterruptTimeout(defaultCancelImmediateGracePeriod)
+// after the caller's options when cc is recursive and still has active children.
+func (cc *cancelContext) wrapGraphInterruptWithGracePeriod(interrupt func(...compose.GraphInterruptOption)) func(...compose.GraphInterruptOption) {
+	return func(opts ...compose.GraphInterruptOption) {
+		// Grace period only applies in recursive mode: in shallow mode,
+		// children are unaware of the cancel and don't need time to propagate
+		// their interrupt signals back.
+		if cc.isRecursive() && cc.hasActiveChildren() {
+			// Build a fresh slice so the caller's backing array is never mutated.
+			newOpts := make([]compose.GraphInterruptOption, len(opts)+1)
+			copy(newOpts, opts)
+			newOpts[len(opts)] = compose.WithGraphInterruptTimeout(defaultCancelImmediateGracePeriod)
+			opts = newOpts
+		}
+		interrupt(opts...)
+	}
+}
+
+// markCancelHandled signals that the cancel path in the runFunc has created
+// and sent a CancelError. Transitions state to stateCancelHandled so that:
+//  1. The deferred markDone() becomes a no-op (CAS from cancelling fails).
+//  2. buildCancelFunc sees stateCancelHandled and returns nil (cancel succeeded).
+//
+// Returns true if the transition succeeded, false if cancel was already handled
+// (e.g., by a sub-agent). This prevents duplicate CancelError emission.
+func (cc *cancelContext) markCancelHandled() bool {
+	if cc == nil {
+		return false
+	}
+	if atomic.CompareAndSwapInt32(&cc.state, stateCancelling, stateCancelHandled) {
+		cc.doneOnce.Do(func() { close(cc.doneChan) })
+		cc.detachFromParent()
+		return true
+	}
+	return false
+}
+
+// createCancelError creates a CancelError based on the current cancel state.
+func (cc *cancelContext) createCancelError() *CancelError {
+	info := &AgentCancelInfo{}
+	info.Mode = cc.getMode()
+	if atomic.LoadInt32(&cc.escalated) == 1 {
+		info.Escalated = true
+		info.Timeout = atomic.LoadInt32(&cc.timeoutEscalated) == 1
+	}
+	return &CancelError{
+		Info: info,
+	}
+}
+
+// createAndMarkCancelHandled builds the CancelError and attempts the
+// stateCancelling→stateCancelHandled transition while holding cancelMu. The
+// error is always returned; the bool reports whether the transition succeeded.
+func (cc *cancelContext) createAndMarkCancelHandled() (*CancelError, bool) {
+	cc.cancelMu.Lock()
+	defer cc.cancelMu.Unlock()
+	cancelErr := cc.createCancelError()
+	ok := cc.markCancelHandled()
+	return cancelErr, ok
+}
+
+// buildCancelFunc builds the AgentCancelFunc for external use.
+func (cc *cancelContext) buildCancelFunc() AgentCancelFunc {
+	// joinMode merges two requested modes. CancelImmediate dominates;
+	// otherwise the safe-point flags are OR-ed together.
+	joinMode := func(a, b CancelMode) CancelMode {
+		if a == CancelImmediate || b == CancelImmediate {
+			return CancelImmediate
+		}
+		return a | b
+	}
+
+	// parseReq applies the call options on top of the default (CancelImmediate).
+	parseReq := func(callOpts ...AgentCancelOption) *agentCancelConfig {
+		cfg := &agentCancelConfig{Mode: CancelImmediate}
+		for _, opt := range callOpts {
+			opt(cfg)
+		}
+		return cfg
+	}
+
+	// escalateOnTimeout records a timeout-driven escalation and fires the
+	// immediate interrupt. If an explicit CancelImmediate request has already
+	// taken over, the escalation is not attributable to the timeout and
+	// nothing is recorded. The mode check and the flag stores happen under
+	// cancelMu, the same lock under which the cancel func changes the mode.
+	escalateOnTimeout := func() {
+		cc.cancelMu.Lock()
+		if cc.getMode() == CancelImmediate {
+			cc.cancelMu.Unlock()
+			return
+		}
+		atomic.StoreInt32(&cc.escalated, 1)
+		atomic.StoreInt32(&cc.timeoutEscalated, 1)
+		cc.cancelMu.Unlock()
+
+		cc.sendImmediateInterrupt()
+	}
+
+	// startTimeoutController launches (at most once) the goroutine that turns
+	// an expired deadline into an immediate interrupt. It exits when execution
+	// is done, when the mode becomes CancelImmediate, or after escalating.
+	startTimeoutController := func() {
+		cc.timeoutOnce.Do(func() {
+			go func() {
+				for {
+					select {
+					case <-cc.doneChan:
+						return
+					default:
+					}
+
+					if cc.getMode() == CancelImmediate {
+						return
+					}
+
+					deadline := cc.getDeadlineUnixNano()
+					if deadline == 0 {
+						// No deadline yet: sleep until one is set or we finish.
+						select {
+						case <-cc.timeoutNotify:
+							continue
+						case <-cc.doneChan:
+							return
+						}
+					}
+
+					wait := time.Duration(deadline - time.Now().UnixNano())
+					if wait <= 0 {
+						escalateOnTimeout()
+						return
+					}
+
+					timer := time.NewTimer(wait)
+					select {
+					case <-timer.C:
+						escalateOnTimeout()
+						return
+					case <-cc.timeoutNotify:
+						// Deadline or mode changed: re-evaluate from the top.
+						timer.Stop()
+						continue
+					case <-cc.doneChan:
+						timer.Stop()
+						return
+					}
+				}
+			}()
+		})
+	}
+
+	newHandle := func(wait func() error) *CancelHandle {
+		return &CancelHandle{wait: wait}
+	}
+
+	// waitForCompletion blocks until doneChan is closed and maps the final
+	// state to the handle's result: ErrExecutionCompleted when execution
+	// finished on its own, ErrCancelTimeout when the timeout escalated,
+	// nil otherwise.
+	waitForCompletion := func() error {
+		<-cc.doneChan
+
+		st := atomic.LoadInt32(&cc.state)
+		switch st {
+		case stateDone:
+			return ErrExecutionCompleted
+		default:
+			if atomic.LoadInt32(&cc.timeoutEscalated) == 1 {
+				return ErrCancelTimeout
+			}
+			return nil
+		}
+	}
+
+	// The returned func reports false as its second result when the execution
+	// was already finished or the cancel was already handled at call time.
+	return func(callOpts ...AgentCancelOption) (*CancelHandle, bool) {
+		req := parseReq(callOpts...)
+
+		// Lock-free fast path for terminal states.
+		st := atomic.LoadInt32(&cc.state)
+		switch st {
+		case stateCancelHandled:
+			return newHandle(func() error { return nil }), false
+		case stateDone:
+			return newHandle(func() error { return ErrExecutionCompleted }), false
+		}
+
+		var needImmediate, needTimeoutCtl bool
+
+		cc.cancelMu.Lock()
+
+		// Re-check under the lock: the state may have moved meanwhile.
+		st = atomic.LoadInt32(&cc.state)
+		switch st {
+		case stateCancelHandled:
+			cc.cancelMu.Unlock()
+			return newHandle(func() error { return nil }), false
+		case stateDone:
+			cc.cancelMu.Unlock()
+			return newHandle(func() error { return ErrExecutionCompleted }), false
+		}
+
+		curMode := cc.getMode()
+		if st == stateRunning {
+			if !atomic.CompareAndSwapInt32(&cc.state, stateRunning, stateCancelling) {
+				// markDone does not take cancelMu, so the CAS can lose to it.
+				st = atomic.LoadInt32(&cc.state)
+				cc.cancelMu.Unlock()
+				if st == stateDone {
+					return newHandle(func() error { return ErrExecutionCompleted }), false
+				}
+				return newHandle(waitForCompletion), true
+			}
+
+			curMode = req.Mode
+			cc.setMode(curMode)
+			atomic.StoreInt32(&cc.startedMode, int32(curMode))
+			cc.setRecursive(req.Recursive)
+			close(cc.cancelChan)
+		} else {
+			// Recursive is monotonic: once set, cannot be unset. The first
+			// cancel call uses the bool directly; subsequent calls only
+			// escalate (false → true) — setRecursive(false) is a no-op.
+			curMode = joinMode(curMode, req.Mode)
+			cc.setMode(curMode)
+			if req.Recursive {
+				cc.setRecursive(true)
+			}
+		}
+
+		if curMode == CancelImmediate {
+			cc.setDeadlineUnixNano(0)
+			// Wake a timeout controller that may be sleeping on an earlier
+			// deadline so it observes CancelImmediate and exits instead of
+			// later flagging this explicit escalation as a timeout.
+			cc.wakeTimeoutController()
+			needImmediate = true
+		} else if req.Timeout != nil && *req.Timeout > 0 {
+			// Only ever move the deadline earlier.
+			proposed := time.Now().Add(*req.Timeout).UnixNano()
+			existing := cc.getDeadlineUnixNano()
+			if existing == 0 || proposed < existing {
+				cc.setDeadlineUnixNano(proposed)
+				cc.wakeTimeoutController()
+			}
+			needTimeoutCtl = cc.getDeadlineUnixNano() != 0
+		}
+
+		cc.cancelMu.Unlock()
+
+		if needImmediate {
+			if atomic.LoadInt32(&cc.startedMode) != int32(CancelImmediate) {
+				atomic.StoreInt32(&cc.escalated, 1)
+			}
+			cc.sendImmediateInterrupt()
+		}
+		if needTimeoutCtl {
+			startTimeoutController()
+		}
+
+		return newHandle(waitForCompletion), true
+	}
+}
+
+// wrapIterWithCancelCtx wraps an iterator with cancel lifecycle management.
+// It calls markDone when the inner iterator is fully drained, ensuring the
+// cancelContext's doneChan is closed and propagation goroutines can exit.
+//
+// For root cancelContexts (created by WithCancel, not deriveChild), it also
+// converts interrupt ACTION events to CancelError when cancel is active.
+// This is the single point of interrupt-to-CancelError conversion in the
+// system — Runner.handleIter only enriches the resulting CancelError with
+// checkpoint metadata.
+//
+// Interrupt absorption: ALL interrupts are converted when cancel is active,
+// including business interrupts (compose.Interrupt from user code). Cancel and
+// business interrupts cannot be reliably distinguished in concurrent execution
+// (parallel workflows, concurrent tool calls) where they merge into a single
+// composite signal. The business interrupt data is preserved in the checkpoint
+// and re-fires naturally on resume.
+// +// This conversion MUST happen in this wrapper (not deferred to Runner.handleIter) +// because markDone runs as a defer in this goroutine — if the interrupt event +// were passed through unconverted, markDone would transition stateCancelling→stateDone +// before the Runner goroutine could call createAndMarkCancelHandled, causing it +// to fail the CAS. +func wrapIterWithCancelCtx(iter *AsyncIterator[*AgentEvent], cancelCtx *cancelContext) *AsyncIterator[*AgentEvent] { + if cancelCtx == nil { + return iter + } + it, gen := NewAsyncIteratorPair[*AgentEvent]() + go func() { + defer cancelCtx.markDone() + defer gen.Close() + for { + event, ok := iter.Next() + if !ok { + break + } + + if cancelCtx.isRoot() && event.Action != nil && event.Action.internalInterrupted != nil { + if cancelCtx.shouldCancel() { + cancelErr, ok := cancelCtx.createAndMarkCancelHandled() + if ok { + cancelErr.interruptSignal = event.Action.internalInterrupted + gen.Send(&AgentEvent{Err: cancelErr}) + } + return + } + } + + gen.Send(event) + } + }() + return it +} + +// cancelMonitoredModel wraps a model with cancel monitoring. +// Generate: pure delegate to the inner model (CancelAfterChatModel is handled +// by a dedicated node after the ChatModel in the compose graph). +// Stream: pipes chunks through a goroutine that selects on immediateChan for +// CancelImmediate abort. +type cancelMonitoredModel struct { + inner model.BaseChatModel + cancelContext *cancelContext +} + +type recvResult[T any] struct { + data T + err error +} + +func (m *cancelMonitoredModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + return m.inner.Generate(ctx, input, opts...) +} + +func (m *cancelMonitoredModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + stream, err := m.inner.Stream(ctx, input, opts...) 
+ if err != nil { + return nil, err + } + wrapped := wrapStreamWithCancelMonitoring(stream, m.cancelContext) + return wrapped, nil +} + +// wrapStreamWithCancelMonitoring wraps a stream with cancel monitoring. +// When immediateChan fires (CancelImmediate or timeout escalation), the output +// stream is terminated with ErrStreamCanceled. +func wrapStreamWithCancelMonitoring[T any](stream *schema.StreamReader[T], cc *cancelContext) *schema.StreamReader[T] { + if cc == nil { + return stream + } + + // Already canceled — terminate immediately + select { + case <-cc.immediateChan: + stream.Close() + r, w := schema.Pipe[T](1) + var zero T + w.Send(zero, ErrStreamCanceled) + w.Close() + return r + default: + } + + reader, writer := schema.Pipe[T](1) + + go func() { + done := make(chan struct{}) + defer close(done) + defer writer.Close() + defer stream.Close() + + ch := make(chan recvResult[T]) + go func() { + defer close(ch) + for { + chunk, recvErr := stream.Recv() + select { + case ch <- recvResult[T]{chunk, recvErr}: + case <-done: + return + } + if recvErr != nil { + return + } + } + }() + + for { + select { + case <-cc.immediateChan: + var zero T + writer.Send(zero, ErrStreamCanceled) + return + + case r, ok := <-ch: + if !ok { + return + } + if r.err != nil { + if r.err == io.EOF { + return + } + var zero T + writer.Send(zero, r.err) + return + } + if closed := writer.Send(r.data, nil); closed { + return + } + } + } + }() + + return reader +} + +// cancelMonitoredToolHandler wraps streamable tool calls with cancel monitoring. +// When CancelImmediate fires, the tool output stream is terminated with ErrStreamCanceled. +// This handler reads the cancelContext from the Go context via getCancelContext. 
+type cancelMonitoredToolHandler struct{}
+
+// WrapStreamableToolCall returns an endpoint that calls next and, when ctx
+// carries a cancelContext, wraps the resulting stream with cancel monitoring.
+// Errors from next are returned unchanged.
+func (h *cancelMonitoredToolHandler) WrapStreamableToolCall(next compose.StreamableToolEndpoint) compose.StreamableToolEndpoint {
+	return func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) {
+		out, err := next(ctx, input)
+		if err != nil {
+			return nil, err
+		}
+
+		if cc := getCancelContext(ctx); cc != nil {
+			monitored := wrapStreamWithCancelMonitoring(out.Result, cc)
+			return &compose.StreamToolOutput{Result: monitored}, nil
+		}
+		return out, nil
+	}
+}
+
+// WrapEnhancedStreamableToolCall is the enhanced-endpoint counterpart of
+// WrapStreamableToolCall: same flow, different output type.
+func (h *cancelMonitoredToolHandler) WrapEnhancedStreamableToolCall(next compose.EnhancedStreamableToolEndpoint) compose.EnhancedStreamableToolEndpoint {
+	return func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) {
+		out, err := next(ctx, input)
+		if err != nil {
+			return nil, err
+		}
+
+		if cc := getCancelContext(ctx); cc != nil {
+			monitored := wrapStreamWithCancelMonitoring(out.Result, cc)
+			return &compose.EnhancedStreamableToolOutput{Result: monitored}, nil
+		}
+		return out, nil
+	}
+}
diff --git a/adk/cancel_edge_test.go b/adk/cancel_edge_test.go
new file mode 100644
index 000000000..b0afbe674
--- /dev/null
+++ b/adk/cancel_edge_test.go
@@ -0,0 +1,1268 @@
+/*
+ * Copyright 2026 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package adk + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +// --- helpers shared across edge-case tests --- + +// blockingChatModel blocks until unblockCh is closed, then returns a fixed response. +type blockingChatModel struct { + unblockCh chan struct{} + response *schema.Message + started chan struct{} + callCount int32 +} + +func newBlockingChatModel(response *schema.Message) *blockingChatModel { + return &blockingChatModel{ + unblockCh: make(chan struct{}), + response: response, + started: make(chan struct{}, 1), + } +} + +func (m *blockingChatModel) Generate(ctx context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m.callCount, 1) + select { + case m.started <- struct{}{}: + default: + } + <-m.unblockCh + return m.response, nil +} + +func (m *blockingChatModel) Stream(ctx context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m.callCount, 1) + select { + case m.started <- struct{}{}: + default: + } + <-m.unblockCh + return schema.StreamReaderFromArray([]*schema.Message{m.response}), nil +} + +func (m *blockingChatModel) BindTools(_ []*schema.ToolInfo) error { return nil } + +// errorChatModel returns an error from Generate/Stream. 
+type errorChatModel struct { + err error + started chan struct{} +} + +func (m *errorChatModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + if m.started != nil { + select { + case m.started <- struct{}{}: + default: + } + } + return nil, m.err +} + +func (m *errorChatModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, m.err +} + +func (m *errorChatModel) BindTools(_ []*schema.ToolInfo) error { return nil } + +// plainResponseModel returns immediately with a fixed text response (no tool calls). +type plainResponseModel struct { + text string +} + +func (m *plainResponseModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return schema.AssistantMessage(m.text, nil), nil +} + +func (m *plainResponseModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage(m.text, nil)}), nil +} + +func (m *plainResponseModel) BindTools(_ []*schema.ToolInfo) error { return nil } + +// blockingTool blocks until unblockCh is closed. 
+type blockingTool struct { + name string + unblockCh chan struct{} + started chan struct{} + callCount int32 +} + +func newBlockingTool(name string) *blockingTool { + return &blockingTool{ + name: name, + unblockCh: make(chan struct{}), + started: make(chan struct{}, 4), + } +} + +func (t *blockingTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "blocking tool", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Type: "string"}, + }), + }, nil +} + +func (t *blockingTool) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) { + atomic.AddInt32(&t.callCount, 1) + select { + case t.started <- struct{}{}: + default: + } + <-t.unblockCh + return "result", nil +} + +func toolCallMsg(calls ...schema.ToolCall) *schema.Message { + return &schema.Message{Role: schema.Assistant, ToolCalls: calls} +} + +func toolCall(id, name, args string) schema.ToolCall { + return schema.ToolCall{ID: id, Type: "function", Function: schema.FunctionCall{Name: name, Arguments: args}} +} + +func drainEvents(iter *AsyncIterator[*AgentEvent]) ([]*AgentEvent, bool) { + var events []*AgentEvent + hasCancelError := false + for { + e, ok := iter.Next() + if !ok { + break + } + events = append(events, e) + var ce *CancelError + if e.Err != nil && errors.As(e.Err, &ce) { + hasCancelError = true + } + } + return events, hasCancelError +} + +// --- tests --- + +// TestWithCancel_BeforeExecutionStarts verifies that a cancel issued before +// the graph begins executing still produces a CancelError without invoking +// the model or tools. 
+func TestWithCancel_BeforeExecutionStarts(t *testing.T) {
+	ctx := context.Background()
+
+	blk := newBlockingChatModel(toolCallMsg(toolCall("c1", "bt", `{"input":"x"}`)))
+	bt := newBlockingTool("bt")
+
+	agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+		Name:        "TestAgent",
+		Description: "test",
+		Model:       blk,
+		ToolsConfig: ToolsConfig{
+			ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{bt}},
+		},
+	})
+	// require, not assert: continuing with a nil agent would only panic below.
+	require.NoError(t, err)
+
+	cancelOpt, cancelFn := WithCancel()
+
+	// Extract the cancelContext so we can wait for cancelChan to close,
+	// ensuring the cancel is fully registered before Run starts.
+	cc := getCommonOptions(nil, cancelOpt).cancelCtx
+
+	// Call cancel BEFORE calling agent.Run.
+	// The cancelFunc must succeed (not hang) even though execution hasn't started.
+	cancelDone := make(chan error, 1)
+	go func() {
+		handle, _ := cancelFn()
+		cancelDone <- handle.Wait()
+	}()
+
+	// Wait for cancelChan to close so the pre-execution check in runFunc
+	// deterministically sees shouldCancel()=true (eliminates goroutine scheduling race).
+	<-cc.cancelChan
+
+	// Now start the run — it should see shouldCancel()=true and emit CancelError immediately.
+	iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hi")}}, cancelOpt)
+
+	_, hasCancelError := drainEvents(iter)
+	assert.True(t, hasCancelError, "expected CancelError when cancel precedes execution")
+
+	// cancelFn must have already returned (or return quickly now that doneChan is closed).
+	select {
+	case <-cancelDone:
+		// Either nil (cancel handled) or ErrExecutionCompleted is acceptable
+		// depending on exact timing; what matters is it didn't hang.
+	case <-time.After(3 * time.Second):
+		t.Fatal("cancelFn blocked indefinitely after pre-start cancel")
+	}
+
+	// Model and tool must not have been invoked.
+	assert.Equal(t, int32(0), atomic.LoadInt32(&blk.callCount), "model must not be called")
+	assert.Equal(t, int32(0), atomic.LoadInt32(&bt.callCount), "tool must not be called")
+}
+
+// TestWithCancel_AfterCompletion verifies cancelFn returns ErrExecutionCompleted
+// when called after a normal run finishes.
+func TestWithCancel_AfterCompletion(t *testing.T) {
+	ctx := context.Background()
+
+	agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+		Name:        "TestAgent",
+		Description: "test",
+		Model:       &plainResponseModel{text: "done"},
+	})
+	require.NoError(t, err)
+
+	cancelOpt, cancelFn := WithCancel()
+	iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hi")}}, cancelOpt)
+
+	// Drain all events so the run completes.
+	for {
+		_, ok := iter.Next()
+		if !ok {
+			break
+		}
+	}
+
+	handle, _ := cancelFn()
+	cancelErr := handle.Wait()
+	assert.ErrorIs(t, cancelErr, ErrExecutionCompleted)
+}
+
+// TestWithCancel_AfterBusinessInterrupt verifies cancelFn returns ErrExecutionCompleted
+// when called after the agent has been interrupted by business logic.
+func TestWithCancel_AfterBusinessInterrupt(t *testing.T) {
+	ctx := context.Background()
+
+	// Use a model that triggers a compose.Interrupt so the agent stops with an interrupt.
+	interruptModel := &interruptingChatModel{}
+
+	agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+		Name:        "TestAgent",
+		Description: "test",
+		Model:       interruptModel,
+	})
+	require.NoError(t, err)
+
+	store := newCancelTestStore()
+	runner := NewRunner(ctx, RunnerConfig{
+		Agent:           agent,
+		CheckPointStore: store,
+	})
+
+	cancelOpt, cancelFn := WithCancel()
+	iter := runner.Run(ctx, []Message{schema.UserMessage("hi")}, cancelOpt, WithCheckPointID("biz-interrupt-1"))
+
+	// Drain — expect an interrupt action event, no cancel error.
+	var gotInterrupt bool
+	for {
+		e, ok := iter.Next()
+		if !ok {
+			break
+		}
+		if e.Action != nil && e.Action.Interrupted != nil {
+			gotInterrupt = true
+		}
+	}
+	assert.True(t, gotInterrupt, "expected business interrupt event")
+
+	handle, _ := cancelFn()
+	cancelErr := handle.Wait()
+	assert.ErrorIs(t, cancelErr, ErrExecutionCompleted)
+}
+
+// TestWithCancel_AfterError verifies cancelFn returns ErrExecutionCompleted
+// when called after the agent errors out.
+func TestWithCancel_AfterError(t *testing.T) {
+	ctx := context.Background()
+
+	modelErr := errors.New("model exploded")
+	agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+		Name:        "TestAgent",
+		Description: "test",
+		Model:       &errorChatModel{err: modelErr},
+	})
+	require.NoError(t, err)
+
+	cancelOpt, cancelFn := WithCancel()
+	iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hi")}}, cancelOpt)
+
+	// Drain all events so the failed run is fully finished before cancelling.
+	for {
+		_, ok := iter.Next()
+		if !ok {
+			break
+		}
+	}
+
+	handle, _ := cancelFn()
+	cancelErr := handle.Wait()
+	assert.ErrorIs(t, cancelErr, ErrExecutionCompleted)
+}
+
+// TestWithCancel_TimeoutEscalation tests that WithAgentCancelTimeout causes the
+// cancel to escalate to immediate when the safe-point hasn't fired yet, and
+// that the resulting CancelError has Escalated=true.
+//
+// Strategy: use CancelAfterChatModel mode. The model blocks (never completes),
+// so the safe-point can't fire naturally. After the timeout, escalateToImmediate
+// closes immediateChan which aborts the model stream via cancelMonitoredModel
+// and causes a CancelError — no compose graph-interrupt races involved.
+func TestWithCancel_TimeoutEscalation(t *testing.T) { + ctx := context.Background() + + blk := newBlockingChatModel(schema.AssistantMessage("hello", nil)) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: blk, + }) + require.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: true, // use streaming so cancelMonitoredModel.Stream is exercised + }) + + timeout := 300 * time.Millisecond + // CancelAfterChatModel + timeout: safe-point can't fire (model never finishes), + // so after 300ms the timeout goroutine escalates to immediate. + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("go")}, cancelOpt) + + select { + case <-blk.started: + case <-time.After(5 * time.Second): + t.Fatal("model did not start") + } + + // Fire cancelFn; it will wait for escalation to complete. + start := time.Now() + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel), WithAgentCancelTimeout(timeout)) + cancelErr := handle.Wait() + elapsed := time.Since(start) + + assert.ErrorIs(t, cancelErr, ErrCancelTimeout, "cancel should return ErrCancelTimeout after timeout escalation") + assert.True(t, elapsed >= timeout, "should wait at least the timeout duration, elapsed=%v", elapsed) + assert.True(t, elapsed < 3*time.Second, "should complete shortly after timeout, elapsed=%v", elapsed) + + var cancelError *CancelError + for { + e, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if e.Err != nil && errors.As(e.Err, &ce) { + cancelError = ce + } + } + if assert.NotNil(t, cancelError, "expected CancelError after timeout escalation") { + assert.True(t, cancelError.Info.Escalated, "CancelError should report Escalated=true") + assert.True(t, cancelError.Info.Timeout, "CancelError should report Timeout=true") + } +} + +// TestWithCancel_AfterChatModel_WithTools verifies CancelAfterChatModel fires +// when the model returns tool 
calls (the safe-point is on the tool-calls path). +func TestWithCancel_AfterChatModel_WithTools(t *testing.T) { + ctx := context.Background() + + blk := newBlockingChatModel(toolCallMsg(toolCall("c1", "bt", `{"input":"x"}`))) + bt := newBlockingTool("bt") + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: blk, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{bt}}, + }, + }) + require.NoError(t, err) + + cancelOpt, cancelFn := WithCancel() + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hi")}}, cancelOpt) + + select { + case <-blk.started: + case <-time.After(5 * time.Second): + t.Fatal("model did not start") + } + + cancelDone := make(chan error, 1) + go func() { + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + cancelDone <- handle.Wait() + }() + + time.Sleep(20 * time.Millisecond) + + close(blk.unblockCh) + + cancelErr := <-cancelDone + assert.NoError(t, cancelErr) + + _, hasCancelError := drainEvents(iter) + assert.True(t, hasCancelError, "CancelError expected after model returns tool calls") +} + +// TestWithCancel_CancelImmediate_StreamAborted verifies that CancelImmediate +// during model streaming surfaces ErrStreamCanceled and completes quickly. 
+func TestWithCancel_CancelImmediate_StreamAborted(t *testing.T) {
+	ctx := context.Background()
+
+	blk := newBlockingChatModel(schema.AssistantMessage("hello", nil))
+
+	agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+		Name:        "TestAgent",
+		Description: "test",
+		Model:       blk,
+	})
+	require.NoError(t, err)
+
+	runner := NewRunner(ctx, RunnerConfig{
+		Agent:           agent,
+		EnableStreaming: true,
+	})
+
+	cancelOpt, cancelFn := WithCancel()
+	iter := runner.Run(ctx, []Message{schema.UserMessage("hi")}, cancelOpt)
+
+	select { // wait (at most 5s) until the model reports that it has started
+	case <-blk.started:
+	case <-time.After(5 * time.Second):
+		t.Fatal("model did not start")
+	}
+	time.Sleep(50 * time.Millisecond)
+
+	start := time.Now()
+	handle, _ := cancelFn() // no mode option: default cancel mode; error return is ignored
+	cancelErr := handle.Wait()
+	assert.NoError(t, cancelErr)
+	elapsed := time.Since(start) // measures cancelFn + Wait only, not the event drain below
+	assert.True(t, elapsed < 2*time.Second, "cancel should complete quickly, elapsed=%v", elapsed)
+
+	var foundStreamCanceled bool
+	for {
+		e, ok := iter.Next()
+		if !ok {
+			break
+		}
+		if e.Err != nil && errors.Is(e.Err, ErrStreamCanceled) {
+			foundStreamCanceled = true
+		}
+		var ce *CancelError
+		if e.Err != nil && errors.As(e.Err, &ce) {
+			foundStreamCanceled = true // CancelError wraps stream abort
+		}
+	}
+	assert.True(t, foundStreamCanceled, "expected stream-abort error during immediate cancel")
+}
+
+// TestWithCancel_MultipleToolsConcurrent verifies that CancelAfterToolCalls
+// waits for ALL concurrent tool calls to complete before cancelling.
+func TestWithCancel_MultipleToolsConcurrent(t *testing.T) {
+	ctx := context.Background()
+
+	bt1 := newBlockingTool("tool1")
+	bt2 := newBlockingTool("tool2")
+
+	// Model calls both tools in one response.
+	modelResp := toolCallMsg(
+		toolCall("c1", "tool1", `{"input":"a"}`),
+		toolCall("c2", "tool2", `{"input":"b"}`),
+	)
+	modelWithTools := &simpleChatModel{response: modelResp}
+
+	agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+		Name:        "TestAgent",
+		Description: "test",
+		Model:       modelWithTools,
+		ToolsConfig: ToolsConfig{
+			ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{bt1, bt2}},
+		},
+	})
+	require.NoError(t, err) // stop here on failure: agent is used immediately below
+
+	cancelOpt, cancelFn := WithCancel()
+	iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("go")}}, cancelOpt)
+
+	// Wait for both tools to start. NOTE(review): if started is signalled via close(), one tool can satisfy both iterations — confirm against newBlockingTool.
+	for i := 0; i < 2; i++ {
+		select {
+		case <-bt1.started:
+		case <-bt2.started:
+		case <-time.After(5 * time.Second):
+			t.Fatal("tools did not start")
+		}
+	}
+
+	// Request cancel after tool calls while both are still blocking.
+	cancelDone := make(chan error, 1)
+	go func() {
+		handle, _ := cancelFn(WithAgentCancelMode(CancelAfterToolCalls)) // error return of cancelFn is ignored
+		cancelDone <- handle.Wait()
+	}()
+
+	// Unblock both tools — cancel should fire only after both complete.
+	time.Sleep(50 * time.Millisecond)
+	close(bt1.unblockCh)
+	time.Sleep(50 * time.Millisecond) // tool2 is released noticeably later than tool1
+	close(bt2.unblockCh)
+
+	cancelErr := <-cancelDone
+	assert.NoError(t, cancelErr)
+
+	assert.Equal(t, int32(1), atomic.LoadInt32(&bt1.callCount), "tool1 should complete")
+	assert.Equal(t, int32(1), atomic.LoadInt32(&bt2.callCount), "tool2 should complete")
+
+	_, hasCancelError := drainEvents(iter)
+	assert.True(t, hasCancelError, "expected CancelError after concurrent tools completed")
+}
+
+// TestWithCancel_GraphInterruptRaceBeforeSet verifies that a CancelImmediate
+// issued before setGraphInterruptFunc is called still results in cancellation.
+// This exercises the retroactive-fire path in setGraphInterruptFunc.
+func TestWithCancel_GraphInterruptRaceBeforeSet(t *testing.T) {
+	ctx := context.Background()
+
+	blk := newBlockingChatModel(schema.AssistantMessage("hi", nil))
+
+	agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+		Name:        "TestAgent",
+		Description: "test",
+		Model:       blk,
+	})
+	require.NoError(t, err)
+
+	cancelOpt, cancelFn := WithCancel()
+
+	// Issue the cancel from a goroutine launched before Run; it races with run start-up.
+	go func() {
+		handle, _ := cancelFn() // no mode option: default cancel mode; error return is ignored
+		_ = handle.Wait()       // result deliberately discarded; only completion of the run is asserted
+	}()
+
+	iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hi")}}, cancelOpt)
+
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		drainEvents(iter)
+	}()
+
+	select { // the run must finish on its own; blk.unblockCh is never closed in this test
+	case <-done:
+	case <-time.After(5 * time.Second):
+		t.Fatal("iteration did not complete after pre-start CancelImmediate")
+	}
+}
+
+// TestWithCancel_NoCheckpointStore verifies cancel completes and does not panic
+// when no checkpoint store is configured.
+func TestWithCancel_NoCheckpointStore(t *testing.T) {
+	ctx := context.Background()
+
+	blk := newBlockingChatModel(schema.AssistantMessage("hi", nil))
+
+	agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+		Name:        "TestAgent",
+		Description: "test",
+		Model:       blk,
+	})
+	require.NoError(t, err)
+
+	runner := NewRunner(ctx, RunnerConfig{
+		Agent: agent,
+		// No CheckPointStore set.
+	})
+
+	cancelOpt, cancelFn := WithCancel()
+	iter := runner.Run(ctx, []Message{schema.UserMessage("hi")}, cancelOpt)
+
+	select { // wait (at most 5s) until the model reports that it has started
+	case <-blk.started:
+	case <-time.After(5 * time.Second):
+		t.Fatal("model did not start")
+	}
+	time.Sleep(30 * time.Millisecond)
+
+	handle, _ := cancelFn() // error return of cancelFn is ignored
+	cancelErr := handle.Wait()
+	assert.NoError(t, cancelErr)
+
+	var ce *CancelError
+	for {
+		e, ok := iter.Next()
+		if !ok {
+			break
+		}
+		if e.Err != nil && errors.As(e.Err, &ce) {
+			break // stop at the first CancelError; later events are not read
+		}
+	}
+	if assert.NotNil(t, ce, "expected CancelError even without checkpoint store") {
+		assert.Empty(t, ce.CheckPointID, "CheckPointID should be empty without checkpoint store")
+	}
+}
+
+// TestWithCancel_ModelError verifies that after a model error the handle from a
+// subsequent cancelFn call reports ErrExecutionCompleted from Wait.
+func TestWithCancel_ModelError(t *testing.T) {
+	ctx := context.Background()
+
+	modelErr := errors.New("model failed")
+	agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+		Name:        "TestAgent",
+		Description: "test",
+		Model:       &errorChatModel{err: modelErr},
+	})
+	require.NoError(t, err)
+
+	cancelOpt, cancelFn := WithCancel()
+	iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hi")}}, cancelOpt)
+
+	var gotModelErr bool
+	for {
+		e, ok := iter.Next()
+		if !ok {
+			break
+		}
+		if e.Err != nil && !errors.As(e.Err, new(*CancelError)) { // any error that is not a CancelError counts
+			gotModelErr = true
+		}
+	}
+	assert.True(t, gotModelErr, "expected non-cancel error event from model failure")
+
+	handle, _ := cancelFn() // called only after the iterator is exhausted; error return is ignored
+	cancelErr := handle.Wait()
+	assert.ErrorIs(t, cancelErr, ErrExecutionCompleted, "cancelFn should return ErrExecutionCompleted after model error")
+}
+
+// TestWithCancel_Resume_SafePoint covers a CancelAfterChatModel safe-point
+// cancel on a Resume path (CancelAfterToolCalls is not exercised here).
+func TestWithCancel_Resume_SafePoint(t *testing.T) {
+	ctx := context.Background()
+
+	// --- phase 1: run to get a checkpoint via CancelImmediate ---
+	blk := newBlockingChatModel(toolCallMsg(toolCall("c1", "bt", `{"input":"x"}`)))
+	bt := newSlowTool("bt", 50*time.Millisecond, "result")
+
+	agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+		Name:        "TestAgent",
+		Description: "test",
+		Model:       blk,
+		ToolsConfig: ToolsConfig{
+			ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{bt}},
+		},
+	})
+	require.NoError(t, err) // stop here on failure: agent1 is handed to the runner below
+
+	store := newCancelTestStore()
+	runner1 := NewRunner(ctx, RunnerConfig{
+		Agent:           agent1,
+		CheckPointStore: store,
+	})
+
+	cancelOpt1, cancelFn1 := WithCancel()
+	iter1 := runner1.Run(ctx, []Message{schema.UserMessage("hi")}, cancelOpt1, WithCheckPointID("resume-sp-1"))
+
+	select { // wait (at most 5s) until the phase-1 model reports that it has started
+	case <-blk.started:
+	case <-time.After(5 * time.Second):
+		t.Fatal("model did not start in phase 1")
+	}
+	_, _ = cancelFn1() // handle and error are discarded; draining iter1 below waits for the run to end
+	drainEvents(iter1)
+
+	// --- phase 2: resume, cancel after chat model ---
+	resumeModel := newBlockingChatModel(toolCallMsg(toolCall("c1", "bt", `{"input":"x"}`)))
+
+	bt2 := newSlowTool("bt", 50*time.Millisecond, "result")
+	agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+		Name:        "TestAgent",
+		Description: "test",
+		Model:       resumeModel,
+		ToolsConfig: ToolsConfig{
+			ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{bt2}},
+		},
+	})
+	require.NoError(t, err) // stop here on failure: agent2 is handed to the runner below
+
+	runner2 := NewRunner(ctx, RunnerConfig{
+		Agent:           agent2,
+		CheckPointStore: store, // same store as phase 1, so the checkpoint ID can be resumed
+	})
+
+	cancelOpt2, cancelFn2 := WithCancel()
+	resumeIter, err := runner2.Resume(ctx, "resume-sp-1", cancelOpt2)
+	require.NoError(t, err)
+
+	select { // wait (at most 5s) until the phase-2 model reports that it has started
+	case <-resumeModel.started:
+	case <-time.After(5 * time.Second):
+		t.Fatal("model did not start in phase 2")
+	}
+
+	cancelDone := make(chan error, 1)
+	go func() {
+		handle, _ := cancelFn2(WithAgentCancelMode(CancelAfterChatModel)) // error return of cancelFn2 is ignored
+		cancelDone <- handle.Wait()
+	}()
+
+	time.Sleep(50 * time.Millisecond) // presumably lets the cancel request register before the model is released
+
+	close(resumeModel.unblockCh)
+
+	cancelErr := <-cancelDone
+	assert.NoError(t, cancelErr)
+
+	_, hasCancelError := drainEvents(resumeIter)
+	assert.True(t, hasCancelError, "CancelError expected after resumed model returns tool calls")
+}
+
+// callbackTool is a tool that calls onCall when invoked.
+type callbackTool struct {
+	name   string
+	onCall func() // optional; InvokableRun skips it when nil
+}
+
+func (t *callbackTool) Info(_ context.Context) (*schema.ToolInfo, error) { // describes a tool with one string parameter "input"
+	return &schema.ToolInfo{
+		Name: t.name,
+		Desc: "callback tool",
+		ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
+			"input": {Type: "string"},
+		}),
+	}, nil
+}
+
+func (t *callbackTool) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) { // arguments are ignored; always returns "ok"
+	if t.onCall != nil {
+		t.onCall()
+	}
+	return "ok", nil
+}
+
+// interruptingChatModel returns a compose.Interrupt error to simulate a
+// business interrupt during execution.
+type interruptingChatModel struct{}
+
+func (m *interruptingChatModel) Generate(ctx context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) {
+	return nil, compose.Interrupt(ctx, "test interrupt")
+}
+
+func (m *interruptingChatModel) Stream(ctx context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) {
+	return nil, compose.Interrupt(ctx, "test interrupt")
+}
+
+func (m *interruptingChatModel) BindTools(_ []*schema.ToolInfo) error { return nil }
+
+// TestWithCancel_TargetedResume_CancelImmediate cancels an agent via CancelImmediate,
+// extracts InterruptContexts from the resulting CancelError, and uses them
+// for targeted resumption via Runner.ResumeWithParams.
+func TestWithCancel_TargetedResume_CancelImmediate(t *testing.T) {
+	ctx := context.Background()
+
+	blk := newBlockingChatModel(toolCallMsg(toolCall("c1", "st", `{"input":"x"}`)))
+	st := newSlowTool("st", 50*time.Millisecond, "result")
+
+	agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+		Name:        "TestAgent",
+		Description: "test",
+		Model:       blk,
+		ToolsConfig: ToolsConfig{
+			ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{st}},
+		},
+	})
+	require.NoError(t, err)
+
+	store := newCancelTestStore()
+	runner := NewRunner(ctx, RunnerConfig{
+		Agent:           agent,
+		CheckPointStore: store,
+	})
+
+	cancelOpt, cancelFn := WithCancel()
+	iter := runner.Run(ctx, []Message{schema.UserMessage("go")}, cancelOpt, WithCheckPointID("targeted-imm-1"))
+
+	select { // wait (at most 5s) until the model reports that it has started
+	case <-blk.started:
+	case <-time.After(5 * time.Second):
+		t.Fatal("model did not start")
+	}
+
+	handle, _ := cancelFn() // CancelImmediate (default)
+	cancelErr := handle.Wait()
+	assert.NoError(t, cancelErr)
+
+	var cancelError *CancelError
+	for {
+		e, ok := iter.Next()
+		if !ok {
+			break
+		}
+		var ce *CancelError
+		if e.Err != nil && errors.As(e.Err, &ce) {
+			cancelError = ce // keeps the last CancelError seen; the iterator is drained fully
+		}
+	}
+
+	require.NotNil(t, cancelError, "expected CancelError")
+	require.NotEmpty(t, cancelError.InterruptContexts, "CancelError should have InterruptContexts for targeted resume")
+
+	// --- resume with targeted params ---
+	targets := make(map[string]any)
+	for _, ic := range cancelError.InterruptContexts {
+		targets[ic.ID] = nil // every interrupt ID is targeted, with a nil payload
+	}
+
+	resumeModel := &plainResponseModel{text: "resumed"}
+	agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+		Name:        "TestAgent",
+		Description: "test",
+		Model:       resumeModel,
+		ToolsConfig: ToolsConfig{
+			ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{st}},
+		},
+	})
+	require.NoError(t, err)
+
+	runner2 := NewRunner(ctx, RunnerConfig{
+		Agent:           agent2,
+		CheckPointStore: store, // same store as the cancelled run
+	})
+
+	resumeIter, err := runner2.ResumeWithParams(ctx, "targeted-imm-1", &ResumeParams{Targets: targets})
+	require.NoError(t, err)
+
+	var gotOutput bool
+	for {
+		e, ok := resumeIter.Next()
+		if !ok {
+			break
+		}
+		if e.Err != nil {
+			t.Fatalf("unexpected error during targeted resume: %v", e.Err)
+		}
+		if e.Output != nil && e.Output.MessageOutput != nil {
+			gotOutput = true
+		}
+	}
+	assert.True(t, gotOutput, "targeted resume should produce output")
+}
+
+// TestWithCancel_TargetedResume_SafePoint cancels an agent via CancelAfterChatModel
+// (safe-point) and verifies that InterruptContexts are populated on the CancelError
+// and that targeted resume via ResumeWithParams succeeds.
+// Since safe-point cancels now use compose.Interrupt, compose saves checkpoint data,
+// making the cancel fully resumable.
+func TestWithCancel_TargetedResume_SafePoint(t *testing.T) {
+	ctx := context.Background()
+
+	// The model returns a tool call so the react graph routes to toolPreHandle,
+	// which detects CancelAfterChatModel and fires compose.Interrupt.
+	blk := newBlockingChatModel(toolCallMsg(toolCall("c1", "st", `{"input":"x"}`)))
+	st := newSlowTool("st", 0, "result")
+
+	agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+		Name:        "TestAgent",
+		Description: "test",
+		Model:       blk,
+		ToolsConfig: ToolsConfig{
+			ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{st}},
+		},
+	})
+	require.NoError(t, err)
+
+	store := newCancelTestStore()
+	runner := NewRunner(ctx, RunnerConfig{
+		Agent:           agent,
+		CheckPointStore: store,
+	})
+
+	cancelOpt, cancelFn := WithCancel()
+	iter := runner.Run(ctx, []Message{schema.UserMessage("go")}, cancelOpt, WithCheckPointID("targeted-sp-1"))
+
+	select { // wait (at most 5s) until the model reports that it has started
+	case <-blk.started:
+	case <-time.After(5 * time.Second):
+		t.Fatal("model did not start")
+	}
+
+	// Start cancelFn in background so the CAS happens before the model unblocks.
+	cancelDone := make(chan error, 1)
+	go func() {
+		handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) // error return of cancelFn is ignored
+		cancelDone <- handle.Wait()
+	}()
+	time.Sleep(50 * time.Millisecond)
+	close(blk.unblockCh)
+
+	cancelErr := <-cancelDone
+	assert.NoError(t, cancelErr)
+
+	var cancelError *CancelError
+	for {
+		e, ok := iter.Next()
+		if !ok {
+			break
+		}
+		var ce *CancelError
+		if e.Err != nil && errors.As(e.Err, &ce) {
+			cancelError = ce // keeps the last CancelError seen; the iterator is drained fully
+		}
+	}
+
+	require.NotNil(t, cancelError, "expected CancelError")
+	require.NotEmpty(t, cancelError.InterruptContexts, "CancelError should have InterruptContexts for targeted resume")
+
+	// --- resume with targeted params ---
+	targets := make(map[string]any)
+	for _, ic := range cancelError.InterruptContexts {
+		targets[ic.ID] = nil // every interrupt ID is targeted, with a nil payload
+	}
+
+	resumeModel := &plainResponseModel{text: "resumed"}
+	agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+		Name:        "TestAgent",
+		Description: "test",
+		Model:       resumeModel,
+		ToolsConfig: ToolsConfig{
+			ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{st}},
+		},
+	})
+	require.NoError(t, err)
+
+	runner2 := NewRunner(ctx, RunnerConfig{
+		Agent:           agent2,
+		CheckPointStore: store, // same store as the cancelled run
+	})
+
+	resumeIter, err := runner2.ResumeWithParams(ctx, "targeted-sp-1", &ResumeParams{Targets: targets})
+	require.NoError(t, err)
+
+	var gotOutput bool
+	for {
+		e, ok := resumeIter.Next()
+		if !ok {
+			break
+		}
+		if e.Err != nil {
+			t.Fatalf("unexpected error during targeted resume: %v", e.Err)
+		}
+		if e.Output != nil && e.Output.MessageOutput != nil {
+			gotOutput = true
+		}
+	}
+	assert.True(t, gotOutput, "targeted resume should produce output")
+}
+
+// TestWithCancel_Resume_CancelAfterChatModel_MessagePreserved tests both the
+// ReAct (with-tools) and noTools paths to ensure that when a
+// CancelAfterChatModel safe-point fires and the run is later resumed, the
+// original Message returned by the chat model is preserved through the
+// StatefulInterrupt checkpoint.
+//
+// For the ReAct path: the model returns a tool-call message. On resume the
+// cancelCheck node must return that same message so the branch routes to the
+// ToolNode and the tool actually executes.
+//
+// For the noTools path the cancel-check lambda must likewise return the model's
+// plain-text message. NOTE(review): only the ReAct subtest is implemented below.
+func TestWithCancel_Resume_CancelAfterChatModel_MessagePreserved(t *testing.T) {
+	t.Run("react_path_tool_call_preserved", func(t *testing.T) {
+		ctx := context.Background()
+
+		// Phase-2 model returns no tool calls so the graph ends.
+		// We track whether the tool actually executes on resume.
+		toolExecuted := make(chan struct{}, 1) // NOTE(review): never checked after phase 1, so a phase-1 tool run would also satisfy the final select
+		st := &callbackTool{
+			name: "my_tool",
+			onCall: func() {
+				select { // non-blocking send: at most one signal is recorded
+				case toolExecuted <- struct{}{}:
+				default:
+				}
+			},
+		}
+
+		// Phase-1 model returns a tool call.
+		blk := newBlockingChatModel(toolCallMsg(toolCall("c1", "my_tool", `{"input":"x"}`)))
+
+		agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+			Name:        "TestAgent",
+			Description: "test",
+			Model:       blk,
+			ToolsConfig: ToolsConfig{
+				ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{st}},
+			},
+		})
+		require.NoError(t, err)
+
+		store := newCancelTestStore()
+		runner1 := NewRunner(ctx, RunnerConfig{
+			Agent:           agent1,
+			CheckPointStore: store,
+		})
+
+		cancelOpt1, cancelFn1 := WithCancel()
+		iter1 := runner1.Run(ctx, []Message{schema.UserMessage("hi")},
+			cancelOpt1, WithCheckPointID("react-msg-preserved-1"))
+
+		select { // wait (at most 5s) until the phase-1 model reports that it has started
+		case <-blk.started:
+		case <-time.After(5 * time.Second):
+			t.Fatal("model did not start in phase 1")
+		}
+
+		cancelDone := make(chan error, 1)
+		go func() {
+			handle, _ := cancelFn1(WithAgentCancelMode(CancelAfterChatModel)) // error return of cancelFn1 is ignored
+			cancelDone <- handle.Wait()
+		}()
+		time.Sleep(50 * time.Millisecond) // presumably lets the cancel request register before the model is released
+		close(blk.unblockCh)
+
+		cancelErr := <-cancelDone
+		assert.NoError(t, cancelErr)
+
+		_, hasCancelError := drainEvents(iter1)
+		assert.True(t, hasCancelError, "expected CancelError from phase 1")
+
+		// Phase 2: resume. The model for phase-2 returns plain text (no tool
+		// calls) so the react graph ends after one iteration. But first the
+		// tool from the checkpoint must execute.
+		resumeModel := &plainResponseModel{text: "done"}
+		agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+			Name:        "TestAgent",
+			Description: "test",
+			Model:       resumeModel,
+			ToolsConfig: ToolsConfig{
+				ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{st}},
+			},
+		})
+		require.NoError(t, err)
+
+		runner2 := NewRunner(ctx, RunnerConfig{
+			Agent:           agent2,
+			CheckPointStore: store, // same store as phase 1
+		})
+
+		resumeIter, err := runner2.Resume(ctx, "react-msg-preserved-1") // resumed without a cancel option
+		require.NoError(t, err)
+
+		for {
+			e, ok := resumeIter.Next()
+			if !ok {
+				break
+			}
+			if e.Err != nil {
+				t.Fatalf("unexpected error during resume: %v", e.Err)
+			}
+		}
+
+		// The key assertion: the tool must have been called during resume,
+		// which can only happen if the tool-call message was preserved.
+		select {
+		case <-toolExecuted:
+			// success
+		default:
+			t.Fatal("tool was not executed on resume — the tool-call message was lost")
+		}
+	})
+
+}
+
+// TestHandleRunFuncError_AlreadyHandled_NoDuplicate verifies that when
+// markCancelHandled() was already claimed by a sub-agent's handleRunFuncError,
+// the sequential workflow's checkCancel does not emit a second CancelError.
+//
+// Setup: sequential[cma1, cma2] with CancelAfterToolCalls. cma1 has tools,
+// cancel fires while tool is running. After tool completes, the safe-point
+// fires in cma1's handleRunFuncError (claiming markCancelHandled). The
+// sequential workflow's checkCancel at the transition point should find
+// markCancelHandled returns false and skip — producing exactly 1 CancelError.
+func TestHandleRunFuncError_AlreadyHandled_NoDuplicate(t *testing.T) {
+	ctx := context.Background()
+
+	bt := newBlockingTool("bt")
+
+	// cma1: model returns a tool call immediately, tool blocks until unblocked
+	cma1Model := newBlockingChatModel(toolCallMsg(toolCall("c1", "bt", `{"input":"x"}`)))
+	close(cma1Model.unblockCh) // model returns immediately
+
+	agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+		Name: "agent1", Description: "first", Instruction: "test",
+		Model: cma1Model,
+		ToolsConfig: ToolsConfig{
+			ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{bt}},
+		},
+	})
+	require.NoError(t, err)
+
+	agent2Model := &plainResponseModel{text: "agent2-response"}
+	agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+		Name: "agent2", Description: "second", Instruction: "test",
+		Model: agent2Model,
+	})
+	require.NoError(t, err)
+
+	seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{
+		Name: "seq", Description: "sequential", SubAgents: []Agent{agent1, agent2},
+	})
+	require.NoError(t, err)
+
+	runner := NewRunner(ctx, RunnerConfig{
+		Agent: seqAgent, EnableStreaming: false,
+	})
+
+	cancelOpt, cancelFn := WithCancel()
+	iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt)
+
+	// Wait for tool to start
+	select {
+	case <-bt.started:
+	case <-time.After(5 * time.Second):
+		t.Fatal("Tool did not start")
+	}
+
+	// Cancel while tool is still running (in a goroutine because handle.Wait()
+	// presumably blocks until the cancel is handled), then unblock tool so safe-point fires
+	go func() {
+		handle, _ := cancelFn(WithAgentCancelMode(CancelAfterToolCalls)) // error return of cancelFn is ignored
+		_ = handle.Wait()                                                // result discarded; only the event stream is asserted
+	}()
+
+	// Give cancel time to register, then unblock tool
+	time.Sleep(50 * time.Millisecond)
+	close(bt.unblockCh)
+
+	cancelCount := 0
+	for {
+		event, ok := iter.Next()
+		if !ok {
+			break
+		}
+		var ce *CancelError
+		if event.Err != nil && errors.As(event.Err, &ce) {
+			cancelCount++ // counts every event whose error chain contains a CancelError
+		}
+	}
+
+	assert.Equal(t, 1, cancelCount, "Should have exactly one CancelError, no duplicate from handleRunFuncError + checkCancel")
+}
+
+func TestWithCancel_CancelAfterChatModel_NestedAgentTool(t *testing.T) { // recursive CancelAfterChatModel while a transferred-to sub-agent's model is blocked must yield a CancelError
+	ctx := context.Background()
+
+	subAgentModel := newBlockingChatModel(toolCallMsg(toolCall("c1", "sub_tool", `{"input":"x"}`)))
+	subAgentModelStarted := subAgentModel.started
+	subTool := newBlockingTool("sub_tool") // never unblocked in this test
+
+	subAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+		Name:        "sub_agent",
+		Description: "test sub agent",
+		Instruction: "you are a sub agent",
+		Model:       subAgentModel,
+		ToolsConfig: ToolsConfig{
+			ToolsNodeConfig: compose.ToolsNodeConfig{Tools: []tool.BaseTool{subTool}},
+		},
+	})
+	require.NoError(t, err)
+
+	supervisorModel := &simpleChatModel{ // always answers with a transfer-to-agent tool call naming sub_agent
+		response: &schema.Message{
+			Role: schema.Assistant,
+			ToolCalls: []schema.ToolCall{{
+				ID: "call_1", Type: "function",
+				Function: schema.FunctionCall{
+					Name:      TransferToAgentToolName,
+					Arguments: `{"agent_name": "sub_agent"}`,
+				},
+			}},
+		},
+	}
+
+	supervisorAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+		Name:        "supervisor",
+		Description: "supervisor agent (equivalent to DeepAgent)",
+		Instruction: "you are a supervisor",
+		Model:       supervisorModel,
+	})
+	require.NoError(t, err)
+
+	agentWithSubAgents, err := SetSubAgents(ctx, supervisorAgent, []Agent{subAgent})
+	require.NoError(t, err)
+
+	runner := NewRunner(ctx, RunnerConfig{
+		Agent:           agentWithSubAgents,
+		EnableStreaming: false,
+	})
+
+	cancelOpt, cancelFn := WithCancel()
+	iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt)
+
+	select { // wait (at most 10s) until the sub-agent's model reports that it has started
+	case <-subAgentModelStarted:
+	case <-time.After(10 * time.Second):
+		t.Fatal("Sub-agent model did not start")
+	}
+
+	time.Sleep(50 * time.Millisecond)
+
+	cancelDone := make(chan error, 1)
+	go func() {
+		handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel), WithRecursive()) // error return of cancelFn is ignored
+		cancelDone <- handle.Wait()
+	}()
+
+	time.Sleep(100 * time.Millisecond) // presumably lets the recursive cancel register before the sub-agent model is released
+	close(subAgentModel.unblockCh)
+
+	cancelErr := <-cancelDone
+	assert.NoError(t, cancelErr)
+
+	hasCancelError := false
+	for {
+		event, ok := iter.Next()
+		if !ok {
+			break
+		}
+		var ce *CancelError
+		if event.Err != nil && errors.As(event.Err, &ce) {
+			hasCancelError = true
+		}
+	}
+
+	assert.True(t, hasCancelError, "CancelError expected from nested agent tool with tools")
+}
diff --git a/adk/cancel_multicall_test.go b/adk/cancel_multicall_test.go
new file mode 100644
index 000000000..790d14fb3
--- /dev/null
+++ b/adk/cancel_multicall_test.go
@@ -0,0 +1,125 @@
+/*
+ * Copyright 2026 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package adk
+
+import (
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+
+	"github.com/cloudwego/eino/compose"
+)
+
+func TestAgentCancelFunc_MultiCall_EscalateToImmediate(t *testing.T) { // a CancelImmediate call after a safe-point call escalates the pending cancel
+	cc := newCancelContext()
+	var interruptCalls int32
+	cc.setGraphInterruptFunc(func(opts ...compose.GraphInterruptOption) {
+		atomic.AddInt32(&interruptCalls, 1)
+	})
+	cancelFn := cc.buildCancelFunc()
+
+	handle1, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel))
+	handle2, _ := cancelFn(WithAgentCancelMode(CancelImmediate))
+	assert.Equal(t, int32(1), atomic.LoadInt32(&interruptCalls)) // graph interrupt func is expected to run exactly once
+
+	cancelErr := cc.createCancelError()
+	assert.Equal(t, CancelImmediate, cancelErr.Info.Mode)
+	assert.True(t, cancelErr.Info.Escalated)
+	assert.False(t, cancelErr.Info.Timeout)
+
+	assert.True(t, cc.markCancelHandled()) // called before Wait: the handles are only waited on once the cancel is marked handled
+	assert.NoError(t, handle1.Wait())
+	assert.NoError(t, handle2.Wait())
+}
+
+func TestAgentCancelFunc_MultiCall_JoinSafePointModes(t *testing.T) { // two safe-point calls are expected to combine their modes
+	cc := newCancelContext()
+	cancelFn := cc.buildCancelFunc()
+
+	handle1, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel))
+	handle2, _ := cancelFn(WithAgentCancelMode(CancelAfterToolCalls))
+
+	want := CancelAfterChatModel | CancelAfterToolCalls // bitwise OR of both requested modes
+	assert.Equal(t, want, cc.getMode())
+
+	assert.True(t, cc.markCancelHandled())
+	assert.NoError(t, handle1.Wait())
+	assert.NoError(t, handle2.Wait())
+}
+
+func TestAgentCancelFunc_MultiCall_TimeoutDeadlineJoinUsesAbsoluteTime(t *testing.T) { // a later call with an earlier absolute deadline must lower the stored deadline
+	cc := newCancelContext()
+	cancelFn := cc.buildCancelFunc()
+
+	handle1, _ := cancelFn(
+		WithAgentCancelMode(CancelAfterChatModel),
+		WithAgentCancelTimeout(200*time.Millisecond),
+	)
+
+	firstDeadline := cc.getDeadlineUnixNano()
+	assert.NotZero(t, firstDeadline)
+
+	time.Sleep(50 * time.Millisecond) // second call starts ~50ms later, so its 60ms timeout ends before the first 200ms one
+
+	handle2, _ := cancelFn(
+		WithAgentCancelMode(CancelAfterToolCalls),
+		WithAgentCancelTimeout(60*time.Millisecond),
+	)
+
+	secondDeadline := cc.getDeadlineUnixNano()
+	assert.NotZero(t, secondDeadline)
+	assert.Less(t, secondDeadline, firstDeadline)
+
+	assert.True(t, cc.markCancelHandled())
+	assert.NoError(t, handle1.Wait())
+	assert.NoError(t, handle2.Wait())
+}
+
+func TestAgentCancelFunc_MultiCall_TimeoutEscalationReturnsErrCancelTimeout(t *testing.T) { // an expired cancel timeout must fire the graph interrupt and surface ErrCancelTimeout
+	cc := newCancelContext()
+	var interruptCalls int32
+	interruptCh := make(chan struct{}, 1)
+	cc.setGraphInterruptFunc(func(opts ...compose.GraphInterruptOption) {
+		atomic.AddInt32(&interruptCalls, 1)
+		select { // non-blocking notify so repeated interrupts cannot block the callback
+		case interruptCh <- struct{}{}:
+		default:
+		}
+	})
+	cancelFn := cc.buildCancelFunc()
+	handle, _ := cancelFn(
+		WithAgentCancelMode(CancelAfterChatModel),
+		WithAgentCancelTimeout(30*time.Millisecond),
+	)
+
+	select { // the interrupt func must be invoked within 1s of the 30ms timeout being set
+	case <-interruptCh:
+	case <-time.After(1 * time.Second):
+		t.Fatal("timeout escalation did not interrupt")
+	}
+	assert.Equal(t, int32(1), atomic.LoadInt32(&interruptCalls))
+
+	cancelErr := cc.createCancelError()
+	assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) // mode stays the requested safe-point mode; escalation is reported via the flags
+	assert.True(t, cancelErr.Info.Escalated)
+	assert.True(t, cancelErr.Info.Timeout)
+
+	assert.True(t, cc.markCancelHandled())
+	assert.ErrorIs(t, handle.Wait(), ErrCancelTimeout) // sentinel compared with errors.Is semantics, so a wrapped ErrCancelTimeout also passes
+}
diff --git a/adk/cancel_recursive_test.go b/adk/cancel_recursive_test.go
new file mode 100644
index 000000000..9f13f55d2
--- /dev/null
+++ b/adk/cancel_recursive_test.go
@@ -0,0 +1,409 @@
+/*
+ * Copyright 2026 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package adk
+
+import (
+	"context"
+	"runtime"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+
+	"github.com/cloudwego/eino/compose"
+)
+
+func assertNotClosedWithin(t *testing.T, ch <-chan struct{}, d time.Duration) { // fails the test if ch yields a value (or is closed) before d elapses
+	t.Helper()
+	select {
+	case <-ch:
+		t.Fatal("channel was closed but should not have been")
+	case <-time.After(d):
+	}
+}
+
+func setupParentChild(t *testing.T) (parent, child *cancelContext, cleanup func()) { // builds a parent context plus one derived child and registers their teardown
+	parent = newCancelContext()
+	ctx, cancel := context.WithCancel(context.Background())
+	child = parent.deriveChild(ctx)
+	cleanup = func() {
+		child.markDone()
+		cancel()
+	}
+	t.Cleanup(cleanup) // cleanup is registered here AND returned; callers in this file discard the returned value
+	return parent, child, cleanup
+}
+
+func TestDeriveChild(t *testing.T) { // groups: Shallow (recursive flag never set), Recursive (set before/with cancel), Escalation (set after cancel)
+	t.Run("Shallow", func(t *testing.T) {
+		t.Run("DoesNotPropagateSafePoint", func(t *testing.T) {
+			parent, child, _ := setupParentChild(t)
+
+			parent.triggerCancel(CancelAfterChatModel)
+
+			assertNotClosedWithin(t, child.cancelChan, 50*time.Millisecond)
+		})
+
+		t.Run("ImmediateDoesNotPropagate", func(t *testing.T) {
+			parent, child, _ := setupParentChild(t)
+
+			parent.triggerImmediateCancel()
+
+			assertNotClosedWithin(t, child.immediateChan, 50*time.Millisecond)
+		})
+
+		t.Run("GrandchildNoPropagation", func(t *testing.T) {
+			a := newCancelContext()
+			ctx, cancel := context.WithCancel(context.Background())
+
+			b := a.deriveChild(ctx)
+			c := b.deriveChild(ctx) // three-level chain a -> b -> c sharing one ctx
+			t.Cleanup(func() {
+				c.markDone()
+				b.markDone()
+				cancel()
+			})
+
+			a.triggerCancel(CancelAfterChatModel)
+
+			assertNotClosedWithin(t, b.cancelChan, 50*time.Millisecond)
+			assertNotClosedWithin(t, c.cancelChan, 50*time.Millisecond)
+		})
+
+		t.Run("NeverRecursive_GoroutineCleanup", func(t *testing.T) {
+			runtime.GC()
+			time.Sleep(50 * time.Millisecond)
+			before := runtime.NumGoroutine() // baseline goroutine count for the leak check below
+
+			parent := newCancelContext()
+			ctx, cancel := context.WithCancel(context.Background())
+
+			child := parent.deriveChild(ctx)
+
+			parent.triggerCancel(CancelAfterChatModel)
+			time.Sleep(100 * time.Millisecond)
+
+			child.markDone()
+			cancel()
+
+			time.Sleep(200 * time.Millisecond)
+			runtime.GC()
+			time.Sleep(50 * time.Millisecond)
+			after := runtime.NumGoroutine()
+
+			assert.InDelta(t, before, after, 5, "goroutine leak detected: before=%d after=%d", before, after) // tolerates a difference of up to 5 goroutines
+		})
+	})
+
+	t.Run("Recursive", func(t *testing.T) {
+		t.Run("PropagatesSafePoint", func(t *testing.T) {
+			parent, child, _ := setupParentChild(t)
+
+			parent.setRecursive(true)
+			parent.triggerCancel(CancelAfterChatModel)
+
+			select {
+			case <-child.cancelChan:
+			case <-time.After(1 * time.Second):
+				t.Fatal("child did not receive cancel within 1s")
+			}
+			assert.True(t, child.shouldCancel())
+		})
+
+		t.Run("ImmediatePropagates", func(t *testing.T) {
+			parent, child, _ := setupParentChild(t)
+
+			parent.setRecursive(true)
+			parent.triggerImmediateCancel()
+
+			select {
+			case <-child.immediateChan:
+			case <-time.After(1 * time.Second):
+				t.Fatal("child did not receive immediate cancel within 1s")
+			}
+			assert.True(t, child.isImmediateCancelled())
+		})
+
+		t.Run("GrandchildPropagation", func(t *testing.T) {
+			a := newCancelContext()
+			ctx, cancel := context.WithCancel(context.Background())
+
+			b := a.deriveChild(ctx)
+			c := b.deriveChild(ctx) // three-level chain a -> b -> c sharing one ctx
+			t.Cleanup(func() {
+				c.markDone()
+				b.markDone()
+				cancel()
+			})
+
+			a.setRecursive(true) // only the root is marked recursive; c must still be reached through b
+			a.triggerCancel(CancelAfterChatModel)
+
+			select {
+			case <-b.cancelChan:
+			case <-time.After(1 * time.Second):
+				t.Fatal("B did not receive cancel within 1s")
+			}
+
+			select {
+			case <-c.cancelChan:
+			case <-time.After(1 * time.Second):
+				t.Fatal("C did not receive cancel within 1s")
+			}
+
+			assert.True(t, b.shouldCancel())
+			assert.True(t, c.shouldCancel())
+		})
+
+		t.Run("SetBeforeCancel", func(t *testing.T) { // NOTE(review): same call sequence and assertions as PropagatesSafePoint above
+			parent, child, _ := setupParentChild(t)
+
+			parent.setRecursive(true)
+
+			parent.triggerCancel(CancelAfterChatModel)
+
+			select {
+			case <-child.cancelChan:
+			case <-time.After(1 * time.Second):
+				t.Fatal("child did not receive cancel within 1s")
+			}
+			assert.True(t, child.shouldCancel())
+		})
+
+		t.Run("AfterRecursiveAndCancelAlreadySet", func(t *testing.T) {
+			parent := newCancelContext()
+			ctx, cancel := context.WithCancel(context.Background())
+
+			parent.setRecursive(true)
+			parent.triggerCancel(CancelAfterChatModel)
+
+			child := parent.deriveChild(ctx) // child is derived only after the parent is already recursive and cancelled
+			t.Cleanup(func() {
+				child.markDone()
+				cancel()
+			})
+
+			select {
+			case <-child.cancelChan:
+			case <-time.After(1 * time.Second):
+				t.Fatal("child did not immediately receive cancel")
+			}
+			assert.True(t, child.shouldCancel())
+		})
+	})
+
+	t.Run("Escalation", func(t *testing.T) {
+		t.Run("EscalateFromNonRecursive", func(t *testing.T) {
+			parent, child, _ := setupParentChild(t)
+
+			parent.triggerCancel(CancelAfterChatModel)
+
+			assertNotClosedWithin(t, child.cancelChan, 50*time.Millisecond) // not propagated while the parent is non-recursive
+
+			parent.setRecursive(true) // turning recursion on afterwards must deliver the pending cancel
+
+			select {
+			case <-child.cancelChan:
+			case <-time.After(1 * time.Second):
+				t.Fatal("child did not receive cancel after escalation within 1s")
+			}
+			assert.True(t, child.shouldCancel())
+		})
+
+		t.Run("EscalateImmediate", func(t *testing.T) {
+			parent, child, _ := setupParentChild(t)
+
+			parent.triggerImmediateCancel()
+
+			assertNotClosedWithin(t, child.immediateChan, 50*time.Millisecond) // not propagated while the parent is non-recursive
+
+			parent.setRecursive(true) // turning recursion on afterwards must deliver the pending immediate cancel
+
+			select {
+			case <-child.immediateChan:
+			case <-time.After(1 * time.Second):
+				t.Fatal("child did not receive immediate cancel after escalation within 1s")
+			}
+			assert.True(t, child.isImmediateCancelled())
+		})
+	})
+}
+
+func TestDeriveChild_Race(t *testing.T) { // concurrency-oriented scenarios around setRecursive, triggerCancel, markDone and ctx cancellation
+	t.Run("SetRecursiveConcurrentWithCancelChan", func(t *testing.T) {
+		for i := 0; i < 100; i++ { // repeated 100 times to exercise both orderings of the two goroutines
+			parent := newCancelContext()
+			ctx, cancel := context.WithCancel(context.Background())
+
+			child := parent.deriveChild(ctx)
+
+			var wg sync.WaitGroup
+			wg.Add(2)
+
+			go func() {
+				defer wg.Done()
+				parent.setRecursive(true)
+			}()
+
+			go func() {
+				defer wg.Done()
+				parent.triggerCancel(CancelAfterChatModel)
+			}()
+
+			wg.Wait()
+
+			select { // whichever call ran first, the child must end up cancelled
+			case <-child.cancelChan:
+			case <-time.After(1 * time.Second):
+				t.Fatalf("iteration %d: child did not receive cancel within 1s", i)
+			}
+
+			assert.True(t, child.shouldCancel())
+			child.markDone()
+			cancel()
+		}
+	})
+
+	t.Run("ChildCompletesBeforeEscalation", func(t *testing.T) {
+		parent := newCancelContext()
+		ctx, cancel := context.WithCancel(context.Background())
+		defer cancel()
+
+		child := parent.deriveChild(ctx)
+
+		parent.triggerCancel(CancelAfterChatModel)
+		time.Sleep(50 * time.Millisecond)
+
+		child.markDone() // child finishes before the parent becomes recursive
+		time.Sleep(50 * time.Millisecond)
+
+		parent.setRecursive(true)
+
+		assertNotClosedWithin(t, child.cancelChan, 50*time.Millisecond) // a completed child must not be cancelled retroactively
+	})
+
+	t.Run("MultipleChildren_PartialCompletion", func(t *testing.T) {
+		parent := newCancelContext()
+		ctx, cancel := context.WithCancel(context.Background())
+		defer cancel()
+
+		child1 := parent.deriveChild(ctx)
+		child2 := parent.deriveChild(ctx)
+
+		parent.triggerCancel(CancelAfterChatModel)
+		time.Sleep(50 * time.Millisecond)
+
+		child1.markDone() // only child1 completes before escalation
+		time.Sleep(50 * time.Millisecond)
+
+		parent.setRecursive(true)
+
+		select {
+		case <-child2.cancelChan:
+		case <-time.After(1 * time.Second):
+			t.Fatal("running child did not receive cancel within 1s")
+		}
+
+		assert.True(t, child2.shouldCancel())
+		assert.False(t, child1.shouldCancel())
+		child2.markDone()
+	})
+
+	t.Run("ContextCancelConcurrentWithRecursive", func(t *testing.T) {
+		done := make(chan struct{})
+		go func() { // whole scenario runs in a goroutine so a hang is caught by the timeout below
+			defer close(done)
+
+			parent := newCancelContext()
+			ctx, cancel := context.WithCancel(context.Background())
+
+			child := parent.deriveChild(ctx)
+
+			parent.triggerCancel(CancelAfterChatModel)
+
+			var wg sync.WaitGroup
+			wg.Add(2)
+
+			go func() {
+				defer wg.Done()
+				cancel()
+			}()
+
+			go func() {
+				defer wg.Done()
+				parent.setRecursive(true)
+			}()
+
+			wg.Wait()
+			child.markDone()
+		}()
+
+		select {
+		case <-done:
+		case <-time.After(1 * time.Second):
+			t.Fatal("deadlock detected")
+		}
+	})
+
+	t.Run("ConcurrentSetRecursive", func(t *testing.T) {
+		parent := newCancelContext()
+
+		var wg sync.WaitGroup
+		for i := 0; i < 10; i++ {
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				parent.setRecursive(true)
+			}()
+		}
+
+		done := make(chan struct{})
+		go func() {
+			wg.Wait()
+			close(done)
+		}()
+
+		select {
+		case <-done:
+		case <-time.After(1 * time.Second):
+			t.Fatal("deadlock or panic in concurrent setRecursive")
+		}
+
+		assert.True(t, parent.isRecursive())
+	})
+}
+
+func TestGracePeriod_OnlyWhenRecursive(t *testing.T) { // the wrapped interrupt func must receive 0 options when non-recursive and 1 when recursive
+	parent, _, _ := setupParentChild(t)
+
+	var nonRecursiveOptCount int
+	wrappedNonRecursive := parent.wrapGraphInterruptWithGracePeriod(func(opts ...compose.GraphInterruptOption) {
+		nonRecursiveOptCount = len(opts)
+	})
+	wrappedNonRecursive()
+	assert.Equal(t, 0, nonRecursiveOptCount)
+
+	parent.setRecursive(true)
+
+	var recursiveOptCount int
+	wrappedRecursive := parent.wrapGraphInterruptWithGracePeriod(func(opts ...compose.GraphInterruptOption) {
+		recursiveOptCount = len(opts)
+	})
+	wrappedRecursive()
+	assert.Equal(t, 1, recursiveOptCount) // presumably the single added option is the grace period — confirm in wrapGraphInterruptWithGracePeriod
+}
diff --git a/adk/cancel_test.go b/adk/cancel_test.go
new file mode 100644
index 000000000..2096a9ac3
--- /dev/null
+++ b/adk/cancel_test.go
@@ -0,0 +1,3728 @@
+/*
+ * Copyright 2026 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package adk
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"runtime"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+
+	"github.com/cloudwego/eino/components/model"
+	"github.com/cloudwego/eino/components/tool"
+	"github.com/cloudwego/eino/compose"
+	"github.com/cloudwego/eino/schema"
+)
+
+// cancelTestChatModel is a fake chat model for cancellation tests. Each call
+// signals startedChan, waits for the configured delay, signals doneChan and
+// then returns the canned response.
+type cancelTestChatModel struct {
+	// delayNs is the per-call delay in nanoseconds; it is read and written
+	// atomically through getDelay/setDelay.
+	delayNs int64
+	// response is returned (or streamed as a single chunk) by every call.
+	response *schema.Message
+	// startedChan receives a best-effort signal when a call begins.
+	startedChan chan struct{}
+	// doneChan receives a best-effort signal once the delay has elapsed.
+	doneChan chan struct{}
+}
+
+// getDelay atomically loads the configured delay.
+func (m *cancelTestChatModel) getDelay() time.Duration {
+	return time.Duration(atomic.LoadInt64(&m.delayNs))
+}
+
+// setDelay atomically stores a new delay.
+func (m *cancelTestChatModel) setDelay(d time.Duration) {
+	atomic.StoreInt64(&m.delayNs, int64(d))
+}
+
+// Generate signals startedChan, waits for the delay (returning ctx.Err() if
+// ctx is done first), signals doneChan and returns the canned response.
+// Both signals are non-blocking and are dropped when nobody can receive them.
+func (m *cancelTestChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
+	select {
+	case m.startedChan <- struct{}{}:
+	default:
+	}
+	select {
+	case <-time.After(m.getDelay()):
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	}
+	select {
+	case m.doneChan <- struct{}{}:
+	default:
+	}
+	return m.response, nil
+}
+
+// Stream is the streaming counterpart of Generate: it signals startedChan,
+// sleeps for the delay, signals doneChan and returns a one-element stream
+// holding the canned response.
+func (m *cancelTestChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
+	// Best-effort signal, same as Generate. A blocking send would hang the
+	// call forever once the channel buffer is full and nobody is receiving
+	// (e.g. on a second model call within one run).
+	select {
+	case m.startedChan <- struct{}{}:
+	default:
+	}
+	// NOTE(review): unlike Generate, this delay does not observe ctx.
+	// Confirm whether that is intentional before making it cancellable.
+	time.Sleep(m.getDelay())
+	select {
+	case m.doneChan <- struct{}{}:
+	default:
+	}
+	return schema.StreamReaderFromArray([]*schema.Message{m.response}), nil
+}
+
+// BindTools is a no-op; the fake model ignores tool definitions.
+func (m *cancelTestChatModel) BindTools(tools []*schema.ToolInfo) error {
+	return nil
+}
+
+// slowTool is a fake tool that takes a fixed time to run and counts its
+// invocations.
+type slowTool struct {
+	name   string
+	delay  time.Duration
+	result string
+	// callCount is incremented atomically on every InvokableRun call.
+	callCount int32
+	// startedChan receives a best-effort signal when a call begins.
+	startedChan chan struct{}
+}
+
+// newSlowTool builds a slowTool whose startedChan is buffered for 10 signals.
+func newSlowTool(name string, delay time.Duration, result string) *slowTool {
+	return &slowTool{
+		name:        name,
+		delay:       delay,
+		result:      result,
+		startedChan: make(chan struct{}, 10),
+	}
+}
+
+// Info describes the tool as taking a single string parameter named "input".
+func (t *slowTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
+	return &schema.ToolInfo{
+		Name: t.name,
+		Desc: "A slow tool for testing",
+		ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
+			"input": {Type: "string", Desc: "Input parameter"},
+		}),
+	}, nil
+}
+
+// InvokableRun counts the call, signals startedChan without blocking, then
+// waits for the delay. It returns ctx.Err() if ctx is done before the delay
+// elapses, otherwise the configured result.
+func (t *slowTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
+	atomic.AddInt32(&t.callCount, 1)
+	select {
+	case t.startedChan <- struct{}{}:
+	default:
+	}
+	select {
+	case <-time.After(t.delay):
+	case <-ctx.Done():
+		return "", ctx.Err()
+	}
+	return t.result, nil
+}
+
+// cancelTestStore is a mutex-guarded in-memory key/value store; the tests in
+// this file pass it as RunnerConfig.CheckPointStore.
+type cancelTestStore struct {
+	m  map[string][]byte
+	mu sync.Mutex
+}
+
+// newCancelTestStore returns an empty store.
+func newCancelTestStore() *cancelTestStore {
+	return &cancelTestStore{m: make(map[string][]byte)}
+}
+
+// Set stores value under key. It never fails.
+func (s *cancelTestStore) Set(_ context.Context, key string, value []byte) error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.m[key] = value
+	return nil
+}
+
+// Get returns the value stored under key and whether it was present.
+// It never fails.
+func (s *cancelTestStore) Get(_ context.Context, key string) ([]byte, bool, error) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	v, ok := s.m[key]
+	return v, ok, nil
+}
+
+// assertHasCancelError fails the test unless at least one of the collected
+// events carries an error that unwraps to *CancelError.
+func assertHasCancelError(t *testing.T, events []*AgentEvent) {
+	t.Helper()
+	for _, e := range events {
+		var ce *CancelError
+		if e.Err != nil && errors.As(e.Err, &ce) {
+			return
+		}
+	}
+	t.Fatal("expected CancelError in events")
+}
+
+// drainAndAssertCancelError consumes iter until it is exhausted and fails the
+// test unless at least one event carried an error that unwraps to
+// *CancelError.
+func drainAndAssertCancelError(t *testing.T, iter *AsyncIterator[*AgentEvent]) {
+	t.Helper()
+	found := false
+	for {
+		event, ok := iter.Next()
+		if !ok {
+			break
+		}
+		var ce *CancelError
+		if event.Err != nil && errors.As(event.Err, &ce) {
+			// Keep consuming so the iterator is fully drained, as the
+			// function name promises and as drainEventsAndAssertCancelError
+			// already does.
+			found = true
+		}
+	}
+	if !found {
+		t.Fatal("expected CancelError in event stream")
+	}
+}
+
+// drainEventsAndAssertCancelError consumes iter until it is exhausted,
+// asserts (non-fatally) that some event carried a *CancelError, and returns
+// every event that was read, including the cancel event itself.
+func drainEventsAndAssertCancelError(t *testing.T, iter *AsyncIterator[*AgentEvent]) []*AgentEvent {
+	t.Helper()
+	var events []*AgentEvent
+	hasCancelError := false
+	for {
+		event, ok := iter.Next()
+		if !ok {
+			break
+		}
+		var ce *CancelError
+		if event.Err != nil && errors.As(event.Err, &ce) {
+			hasCancelError = true
+		}
+		events = append(events, event)
+	}
+	assert.True(t, hasCancelError, "expected CancelError in event stream")
+	return events
+}
+
+func TestCancelContext(t
*testing.T) { + t.Run("BasicCancelContext", func(t *testing.T) { + cc := newCancelContext() + assert.False(t, cc.shouldCancel(), "Should not be cancelled initially") + + cc.setMode(CancelImmediate) + close(cc.cancelChan) + + assert.True(t, cc.shouldCancel(), "Should be cancelled after close(cancelChan)") + assert.Equal(t, CancelImmediate, cc.getMode()) + }) +} + +func TestWithCancel_WithTools(t *testing.T) { + ctx := context.Background() + + t.Run("CancelImmediate_DuringModelCall", func(t *testing.T) { + modelStarted := make(chan struct{}, 1) + st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") + + slowModel := &cancelTestChatModel{ + delayNs: int64(2 * time.Second), + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: slowModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt) + assert.NotNil(t, iter) + assert.NotNil(t, cancelFn) + + eventsCh := make(chan []*AgentEvent, 1) + go func() { + var events []*AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + events = append(events, event) + } + eventsCh <- events + }() + + select { + case <-modelStarted: + case <-time.After(5 * time.Second): + t.Fatal("Model did not start within 5 seconds") + } + + time.Sleep(100 * time.Millisecond) + + handle, _ := 
cancelFn() + err = handle.Wait() + assert.NoError(t, err) + + var events []*AgentEvent + select { + case events = <-eventsCh: + case <-time.After(5 * time.Second): + t.Fatal("Timed out waiting for events") + } + + assert.NotEmpty(t, events) + + assertHasCancelError(t, events) + }) + + t.Run("CancelAfterChatModel_DuringToolCall", func(t *testing.T) { + toolStarted := make(chan struct{}, 1) + st := &slowToolWithSignal{ + name: "slow_tool", + delay: 2 * time.Second, + result: "tool result", + startedChan: toolStarted, + } + + modelWithToolCall := &simpleChatModel{ + delay: 1 * time.Second, + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: modelWithToolCall, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + cancelOpt, cancelFn := WithCancel() + iter := agent.Run(ctx, &AgentInput{ + Messages: []Message{schema.UserMessage("Use the tool")}, + }, cancelOpt) + assert.NotNil(t, iter) + assert.NotNil(t, cancelFn) + + <-toolStarted + + time.Sleep(100 * time.Millisecond) + + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + err = handle.Wait() + assert.NoError(t, err) + + var events []*AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + continue + } + events = append(events, event) + } + + assert.NotEmpty(t, events) + assert.True(t, atomic.LoadInt32(&st.callCount) >= 1, "Tool should have been called") + }) + + t.Run("CancelAfterToolCalls_CompletesToolExecution", func(t *testing.T) { + toolStarted := 
make(chan struct{}, 1) + st := &slowToolWithSignal{ + name: "slow_tool", + delay: 500 * time.Millisecond, + result: "tool result", + startedChan: toolStarted, + } + + modelWithToolCall := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: modelWithToolCall, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + cancelOpt, cancelFn := WithCancel() + iter := agent.Run(ctx, &AgentInput{ + Messages: []Message{schema.UserMessage("Use the tool")}, + }, cancelOpt) + assert.NotNil(t, iter) + assert.NotNil(t, cancelFn) + + <-toolStarted + + time.Sleep(100 * time.Millisecond) + + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterToolCalls)) + err = handle.Wait() + assert.NoError(t, err) + + var events []*AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + continue + } + events = append(events, event) + } + + assert.NotEmpty(t, events) + assert.True(t, atomic.LoadInt32(&st.callCount) >= 1, "Tool should have been called") + }) + + t.Run("NestedCancelPropagation", func(t *testing.T) { + cc := newCancelContext() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + child := cc.deriveChild(ctx) + assert.NotNil(t, child) + + cc.setRecursive(true) + cc.setMode(CancelImmediate) + + if atomic.CompareAndSwapInt32(&cc.state, stateRunning, stateCancelling) { + close(cc.cancelChan) + } + + select { + case <-child.cancelChan: + case <-time.After(1 * time.Second): + t.Fatal("Child did not 
receive cancel signal") + } + + assert.True(t, child.shouldCancel()) + assert.Equal(t, CancelImmediate, child.getMode()) + }) + + t.Run("DeepAgentIntegrationCancel", func(t *testing.T) { + ctx := context.Background() + modelStarted := make(chan struct{}, 1) + + leafModel := &cancelTestChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "Leaf result", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + leafModel.setDelay(500 * time.Millisecond) + leafAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "LeafAgent", + Description: "desc", + Model: leafModel, + }) + assert.NoError(t, err) + + rootModel := &cancelTestChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "LeafAgent", + Arguments: `{}`, + }, + }, + }, + }, + startedChan: make(chan struct{}, 1), + doneChan: make(chan struct{}, 1), + } + rootAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "RootAgent", + Description: "desc", + Model: rootModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{NewAgentTool(ctx, leafAgent)}, + }, + }, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: rootAgent, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Run leaf")}, cancelOpt) + + <-modelStarted + + handle, _ := cancelFn(WithAgentCancelMode(CancelAfterChatModel), WithRecursive()) + err = handle.Wait() + assert.NoError(t, err) + + hasCancelError := false + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil { + var ce *CancelError + if errors.As(event.Err, &ce) { + hasCancelError = true + assert.NotNil(t, ce.interruptSignal, "CancelError should carry interrupt signal") + } + } + } + assert.True(t, hasCancelError, "Should have received 
CancelError")
+	})
+}
+
+// slowToolWithSignal is a fake tool whose delay cannot be interrupted: once
+// invoked it always sleeps for the full delay before returning its result.
+type slowToolWithSignal struct {
+	name   string
+	delay  time.Duration
+	result string
+	// callCount is incremented atomically on every InvokableRun call.
+	callCount int32
+	// startedChan receives a best-effort signal when a call begins.
+	startedChan chan struct{}
+}
+
+// Info describes the tool as taking a single string parameter named "input".
+func (t *slowToolWithSignal) Info(ctx context.Context) (*schema.ToolInfo, error) {
+	return &schema.ToolInfo{
+		Name: t.name,
+		Desc: "A slow tool for testing",
+		ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
+			"input": {Type: "string", Desc: "Input parameter"},
+		}),
+	}, nil
+}
+
+// InvokableRun counts the call, signals startedChan, sleeps for the full
+// delay regardless of ctx, and returns the configured result.
+func (t *slowToolWithSignal) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
+	atomic.AddInt32(&t.callCount, 1)
+	// Best-effort signal, matching slowTool.InvokableRun. A blocking send
+	// would hang the tool call once the channel buffer is full, which can
+	// happen when the model keeps requesting the tool and the test only
+	// receives the first signal.
+	select {
+	case t.startedChan <- struct{}{}:
+	default:
+	}
+	// Deliberately not ctx-aware: the delay always runs to completion.
+	time.Sleep(t.delay)
+	return t.result, nil
+}
+
+// simpleChatModel is a fake chat model that optionally waits for delay and
+// then returns a canned response, without any start/done signalling.
+type simpleChatModel struct {
+	delay    time.Duration
+	response *schema.Message
+}
+
+// Generate waits for delay when it is positive (returning ctx.Err() if ctx is
+// done first) and then returns the canned response.
+func (m *simpleChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
+	if m.delay > 0 {
+		select {
+		case <-time.After(m.delay):
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		}
+	}
+	return m.response, nil
+}
+
+// Stream behaves like Generate but wraps the canned response in a
+// one-element stream.
+func (m *simpleChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
+	if m.delay > 0 {
+		select {
+		case <-time.After(m.delay):
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		}
+	}
+	return schema.StreamReaderFromArray([]*schema.Message{m.response}), nil
+}
+
+// BindTools is a no-op; the fake model ignores tool definitions.
+func (m *simpleChatModel) BindTools(tools []*schema.ToolInfo) error {
+	return nil
+}
+
+func TestWithCancel_WithCheckpoint(t *testing.T) {
+	ctx := context.Background()
+
+	t.Run("CancelWithCheckpoint", func(t *testing.T) {
+		modelStarted := make(chan struct{}, 1)
+		st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result")
+
+		slowModel := &cancelTestChatModel{
+			delayNs: int64(1 * time.Second),
+			response: &schema.Message{
+				Role:    schema.Assistant,
+				Content: "",
+				ToolCalls: []schema.ToolCall{
+					{
+						ID:   "call_1",
+						Type: "function",
+						Function: 
schema.FunctionCall{
+							Name:      "slow_tool",
+							Arguments: `{"input": "test"}`,
+						},
+					},
+				},
+			},
+			startedChan: modelStarted,
+			doneChan:    make(chan struct{}, 1),
+		}
+
+		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+			Name:        "TestAgent",
+			Description: "Test agent with tool",
+			Instruction: "You are a test assistant",
+			Model:       slowModel,
+			ToolsConfig: ToolsConfig{
+				ToolsNodeConfig: compose.ToolsNodeConfig{
+					Tools: []tool.BaseTool{st},
+				},
+			},
+		})
+		assert.NoError(t, err)
+
+		store := newCancelTestStore()
+		runner := NewRunner(ctx, RunnerConfig{
+			Agent:           agent,
+			EnableStreaming: false,
+			CheckPointStore: store,
+		})
+
+		cancelOpt, cancelFn := WithCancel()
+		iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt, WithCheckPointID("cancel-1"))
+
+		// Cancel only once the model call is known to be in flight.
+		<-modelStarted
+
+		handle, _ := cancelFn()
+		err = handle.Wait()
+		assert.NoError(t, err)
+
+		// NOTE(review): events is collected but never asserted on.
+		var events []*AgentEvent
+		hasCancelError := false
+		var cancelErrorCheckPointID string
+		for {
+			event, ok := iter.Next()
+			if !ok {
+				break
+			}
+			var ce *CancelError
+			if event.Err != nil && errors.As(event.Err, &ce) {
+				hasCancelError = true
+				cancelErrorCheckPointID = ce.CheckPointID
+				continue
+			}
+			events = append(events, event)
+		}
+
+		assert.True(t, hasCancelError, "Should have CancelError event after cancel")
+		assert.Equal(t, "cancel-1", cancelErrorCheckPointID, "CancelError should contain the checkpoint ID")
+	})
+}
+
+// TestAgentCancelFuncMultipleCalls cancels a run while its model call is in
+// flight, checks that waiting on the cancel handle succeeds, and drains the
+// event iterator.
+//
+// NOTE(review): despite the test and subtest names, cancelFn is invoked only
+// once here, so the behavior of a second call is not exercised. Confirm the
+// intended second-call contract and add the missing call and assertion.
+func TestAgentCancelFuncMultipleCalls(t *testing.T) {
+	ctx := context.Background()
+
+	t.Run("SecondCancelReturnsErrAgentFinished", func(t *testing.T) {
+		modelStarted := make(chan struct{}, 1)
+		st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result")
+
+		slowModel := &cancelTestChatModel{
+			delayNs: int64(1 * time.Second),
+			response: &schema.Message{
+				Role:    schema.Assistant,
+				Content: "",
+				ToolCalls: []schema.ToolCall{
+					{
+						ID:   "call_1",
+						Type: "function",
+						Function: schema.FunctionCall{
+							Name:      "slow_tool",
+							Arguments: `{"input": "test"}`,
+						},
+					},
+				},
+			},
+			startedChan: modelStarted,
+			doneChan:    make(chan struct{}, 1),
+		}
+
+		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+			Name:        "TestAgent",
+			Description: "Test agent with tool",
+			Instruction: "You are a test assistant",
+			Model:       slowModel,
+			ToolsConfig: ToolsConfig{
+				ToolsNodeConfig: compose.ToolsNodeConfig{
+					Tools: []tool.BaseTool{st},
+				},
+			},
+		})
+		assert.NoError(t, err)
+
+		runner := NewRunner(ctx, RunnerConfig{
+			Agent:           agent,
+			EnableStreaming: false,
+		})
+
+		cancelOpt, cancelFn := WithCancel()
+		iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt)
+
+		// Cancel only once the model call is known to be in flight.
+		<-modelStarted
+
+		handle, _ := cancelFn()
+		cancelErr := handle.Wait()
+		assert.NoError(t, cancelErr)
+
+		// Drain the remaining events until the iterator reports it is done.
+		for {
+			_, ok := iter.Next()
+			if !ok {
+				break
+			}
+		}
+	})
+}
+
+// TestWithCancel_Streaming runs cancellation scenarios against a runner
+// created with EnableStreaming set to true.
+func TestWithCancel_Streaming(t *testing.T) {
+	ctx := context.Background()
+
+	t.Run("CancelImmediate_DuringModelStream", func(t *testing.T) {
+		modelStarted := make(chan struct{}, 1)
+		st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result")
+
+		slowModel := &cancelTestChatModel{
+			delayNs: int64(2 * time.Second),
+			response: &schema.Message{
+				Role:    schema.Assistant,
+				Content: "",
+				ToolCalls: []schema.ToolCall{
+					{
+						ID:   "call_1",
+						Type: "function",
+						Function: schema.FunctionCall{
+							Name:      "slow_tool",
+							Arguments: `{"input": "test"}`,
+						},
+					},
+				},
+			},
+			startedChan: modelStarted,
+			doneChan:    make(chan struct{}, 1),
+		}
+
+		agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+			Name:        "TestAgent",
+			Description: "Test agent with tool",
+			Instruction: "You are a test assistant",
+			Model:       slowModel,
+			ToolsConfig: ToolsConfig{
+				ToolsNodeConfig: compose.ToolsNodeConfig{
+					Tools: []tool.BaseTool{st},
+				},
+			},
+		})
+		assert.NoError(t, err)
+
+		runner := NewRunner(ctx, RunnerConfig{
+			Agent:           agent,
+			EnableStreaming: true,
+		})
+
+		cancelOpt, cancelFn := WithCancel()
+		iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt)
+ assert.NotNil(t, iter) + assert.NotNil(t, cancelFn) + + eventsCh := make(chan []*AgentEvent, 1) + go func() { + var events []*AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + events = append(events, event) + } + eventsCh <- events + }() + + select { + case <-modelStarted: + case <-time.After(5 * time.Second): + t.Fatal("Model did not start within 5 seconds") + } + + time.Sleep(100 * time.Millisecond) + + handle, _ := cancelFn() + cancelErr := handle.Wait() + assert.NoError(t, cancelErr) + + var events []*AgentEvent + select { + case events = <-eventsCh: + case <-time.After(5 * time.Second): + t.Fatal("Timed out waiting for events") + } + + assert.NotEmpty(t, events) + + assertHasCancelError(t, events) + }) + + t.Run("CancelAfterToolCalls_Streaming", func(t *testing.T) { + toolStarted := make(chan struct{}, 1) + st := &slowToolWithSignal{ + name: "slow_tool", + delay: 500 * time.Millisecond, + result: "tool result", + startedChan: toolStarted, + } + + modelWithToolCall := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: modelWithToolCall, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: true, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt) + assert.NotNil(t, iter) + assert.NotNil(t, cancelFn) + + <-toolStarted + + time.Sleep(100 * time.Millisecond) + + handle, _ := 
cancelFn(WithAgentCancelMode(CancelAfterToolCalls)) + cancelErr := handle.Wait() + assert.NoError(t, cancelErr) + + var events []*AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + continue + } + events = append(events, event) + } + + assert.NotEmpty(t, events) + assert.True(t, atomic.LoadInt32(&st.callCount) >= 1, "Tool should have been called") + }) +} + +// TestWithCancel_Resume tests the workflow of Cancel followed by Resume. +// +// To avoid data races, we create new agent and runner instances for the Resume phase +// instead of reusing and modifying the original model instance. +func TestWithCancel_Resume(t *testing.T) { + ctx := context.Background() + + t.Run("Cancel_ThenResume", func(t *testing.T) { + modelStarted := make(chan struct{}, 1) + modelCallCount := int32(0) + st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") + + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: slowModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + checkpointID := "resume-cancel-test-1" + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: false, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt, 
WithCheckPointID(checkpointID)) + + <-modelStarted + atomic.AddInt32(&modelCallCount, 1) + + handle, _ := cancelFn() + cancelErr := handle.Wait() + assert.NoError(t, cancelErr) + + var events []*AgentEvent + hasCancelErr := false + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil { + var ce *CancelError + if errors.As(event.Err, &ce) { + hasCancelErr = true + continue + } + t.Fatalf("unexpected error: %v", event.Err) + } + events = append(events, event) + } + assert.True(t, hasCancelErr, "Should have CancelError event after cancel") + + newModelStarted := make(chan struct{}, 1) + slowModel2 := &cancelTestChatModel{ + delayNs: int64(100 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Final response after resume", + }, + startedChan: newModelStarted, + doneChan: make(chan struct{}, 1), + } + + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: slowModel2, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: agent2, + EnableStreaming: false, + CheckPointStore: store, + }) + + resumeCancelOpt, _ := WithCancel() + resumeIter, err := runner2.Resume(ctx, checkpointID, resumeCancelOpt) + assert.NoError(t, err) + assert.NotNil(t, resumeIter) + + var resumeEvents []*AgentEvent + for { + event, ok := resumeIter.Next() + if !ok { + break + } + assert.Nil(t, event.Err, "Should not have error event during resume") + resumeEvents = append(resumeEvents, event) + } + + assert.NotEmpty(t, resumeEvents, "Resume should produce events") + }) + + t.Run("Resume_ThenCancel", func(t *testing.T) { + firstModelStarted := make(chan struct{}, 1) + modelCallCount := int32(0) + st := newSlowTool("slow_tool", 100*time.Millisecond, "tool result") + + slowModel := 
&cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + startedChan: firstModelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: slowModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + checkpointID := "resume-then-cancel-test-1" + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: false, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt, WithCheckPointID(checkpointID)) + + <-firstModelStarted + atomic.AddInt32(&modelCallCount, 1) + + handle, _ := cancelFn() + cancelErr := handle.Wait() + assert.NoError(t, cancelErr) + + for { + _, ok := iter.Next() + if !ok { + break + } + } + + slowModel2 := newBlockingChatModel(toolCallMsg(toolCall("call_1", "slow_tool", `{"input": "test"}`))) + + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with tool", + Instruction: "You are a test assistant", + Model: slowModel2, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: agent2, + EnableStreaming: false, + CheckPointStore: store, + }) + + resumeCancelOpt, resumeCancelFn := WithCancel() + resumeIter, err := runner2.Resume(ctx, checkpointID, resumeCancelOpt) + assert.NoError(t, err) + + resumeEventsCh 
:= make(chan []*AgentEvent, 1) + go func() { + var events []*AgentEvent + for { + event, ok := resumeIter.Next() + if !ok { + break + } + events = append(events, event) + } + resumeEventsCh <- events + }() + + <-slowModel2.started + atomic.AddInt32(&modelCallCount, 1) + + cancelHandle, _ := resumeCancelFn() + close(slowModel2.unblockCh) + err = cancelHandle.Wait() + assert.True(t, err == nil || errors.Is(err, ErrExecutionCompleted), "unexpected cancel wait error: %v", err) + + start := time.Now() + resumeEvents := <-resumeEventsCh + elapsed := time.Since(start) + + assert.True(t, elapsed < 1*time.Second, "Resume should return quickly after cancel, elapsed: %v", elapsed) + assert.NotEmpty(t, resumeEvents) + + hasCancelError := false + for _, e := range resumeEvents { + var ce *CancelError + if e.Err != nil && errors.As(e.Err, &ce) { + hasCancelError = true + } + } + executionCompletedBeforeCancel := errors.Is(err, ErrExecutionCompleted) + assert.True(t, hasCancelError || executionCompletedBeforeCancel, "Resume should have CancelError event after cancel, or execution completed before cancel") + }) +} + +func TestCancelMonitoredToolHandler_StreamableToolCall(t *testing.T) { + t.Run("NoCancelContext_PassesThrough", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + + // Create a stream with some data + r, w := schema.Pipe[string](1) + go func() { + w.Send("chunk1", nil) + w.Send("chunk2", nil) + w.Close() + }() + + next := func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { + return &compose.StreamToolOutput{Result: r}, nil + } + + wrapped := handler.WrapStreamableToolCall(next) + // No cancelContext in the Go context + output, err := wrapped(context.Background(), &compose.ToolInput{Name: "test"}) + assert.NoError(t, err) + + // Should get the original stream unchanged + chunk1, err := output.Result.Recv() + assert.NoError(t, err) + assert.Equal(t, "chunk1", chunk1) + + chunk2, err := output.Result.Recv() + 
assert.NoError(t, err) + assert.Equal(t, "chunk2", chunk2) + + _, err = output.Result.Recv() + assert.ErrorIs(t, err, io.EOF) + }) + + t.Run("WithCancelContext_NoCancel_StreamsNormally", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + cc := newCancelContext() + + r, w := schema.Pipe[string](1) + go func() { + w.Send("data1", nil) + w.Send("data2", nil) + w.Close() + }() + + next := func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { + return &compose.StreamToolOutput{Result: r}, nil + } + + wrapped := handler.WrapStreamableToolCall(next) + ctx := withCancelContext(context.Background(), cc) + output, err := wrapped(ctx, &compose.ToolInput{Name: "test"}) + assert.NoError(t, err) + + chunk1, err := output.Result.Recv() + assert.NoError(t, err) + assert.Equal(t, "data1", chunk1) + + chunk2, err := output.Result.Recv() + assert.NoError(t, err) + assert.Equal(t, "data2", chunk2) + + _, err = output.Result.Recv() + assert.ErrorIs(t, err, io.EOF) + }) + + t.Run("WithCancelContext_ImmediateCancel_TerminatesStream", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + cc := newCancelContext() + + // Create a slow stream that we'll cancel mid-way + r, w := schema.Pipe[string](1) + go func() { + defer w.Close() + w.Send("chunk1", nil) + time.Sleep(200 * time.Millisecond) + w.Send("chunk2", nil) + }() + + next := func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { + return &compose.StreamToolOutput{Result: r}, nil + } + + wrapped := handler.WrapStreamableToolCall(next) + ctx := withCancelContext(context.Background(), cc) + output, err := wrapped(ctx, &compose.ToolInput{Name: "test"}) + assert.NoError(t, err) + + // Read first chunk + chunk1, err := output.Result.Recv() + assert.NoError(t, err) + assert.Equal(t, "chunk1", chunk1) + + // Fire immediate cancel + close(cc.immediateChan) + + // Next recv should get ErrStreamCanceled + _, err = output.Result.Recv() + 
assert.ErrorIs(t, err, ErrStreamCanceled) + }) + + t.Run("WithCancelContext_AlreadyCancelled_TerminatesImmediately", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + cc := newCancelContext() + close(cc.immediateChan) // Already canceled + + r, w := schema.Pipe[string](1) + go func() { + w.Send("should-not-see", nil) + w.Close() + }() + + next := func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { + return &compose.StreamToolOutput{Result: r}, nil + } + + wrapped := handler.WrapStreamableToolCall(next) + ctx := withCancelContext(context.Background(), cc) + output, err := wrapped(ctx, &compose.ToolInput{Name: "test"}) + assert.NoError(t, err) + + _, err = output.Result.Recv() + assert.ErrorIs(t, err, ErrStreamCanceled) + }) + + t.Run("NextReturnsError_PropagatesError", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + cc := newCancelContext() + + nextErr := errors.New("tool execution failed") + next := func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { + return nil, nextErr + } + + wrapped := handler.WrapStreamableToolCall(next) + ctx := withCancelContext(context.Background(), cc) + _, err := wrapped(ctx, &compose.ToolInput{Name: "test"}) + assert.ErrorIs(t, err, nextErr) + }) +} + +func TestCancelMonitoredToolHandler_EnhancedStreamableToolCall(t *testing.T) { + t.Run("NoCancelContext_PassesThrough", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + + tr1 := &schema.ToolResult{Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: "chunk1"}}} + r, w := schema.Pipe[*schema.ToolResult](1) + go func() { + w.Send(tr1, nil) + w.Close() + }() + + next := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { + return &compose.EnhancedStreamableToolOutput{Result: r}, nil + } + + wrapped := handler.WrapEnhancedStreamableToolCall(next) + output, err := wrapped(context.Background(), 
&compose.ToolInput{Name: "test"}) + assert.NoError(t, err) + + result, err := output.Result.Recv() + assert.NoError(t, err) + assert.Equal(t, tr1, result) + + _, err = output.Result.Recv() + assert.ErrorIs(t, err, io.EOF) + }) + + t.Run("WithCancelContext_ImmediateCancel_TerminatesStream", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + cc := newCancelContext() + + tr1 := &schema.ToolResult{Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: "chunk1"}}} + tr2 := &schema.ToolResult{Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: "chunk2"}}} + r, w := schema.Pipe[*schema.ToolResult](1) + go func() { + defer w.Close() + w.Send(tr1, nil) + time.Sleep(200 * time.Millisecond) + w.Send(tr2, nil) + }() + + next := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { + return &compose.EnhancedStreamableToolOutput{Result: r}, nil + } + + wrapped := handler.WrapEnhancedStreamableToolCall(next) + ctx := withCancelContext(context.Background(), cc) + output, err := wrapped(ctx, &compose.ToolInput{Name: "test"}) + assert.NoError(t, err) + + result, err := output.Result.Recv() + assert.NoError(t, err) + assert.Equal(t, tr1, result) + + close(cc.immediateChan) + + _, err = output.Result.Recv() + assert.ErrorIs(t, err, ErrStreamCanceled) + }) + + t.Run("NextReturnsError_PropagatesError", func(t *testing.T) { + handler := &cancelMonitoredToolHandler{} + cc := newCancelContext() + + nextErr := errors.New("enhanced tool failed") + next := func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { + return nil, nextErr + } + + wrapped := handler.WrapEnhancedStreamableToolCall(next) + ctx := withCancelContext(context.Background(), cc) + _, err := wrapped(ctx, &compose.ToolInput{Name: "test"}) + assert.ErrorIs(t, err, nextErr) + }) +} + +func TestCancelContextKey(t *testing.T) { + t.Run("WithAndGet_RoundTrips", func(t *testing.T) { + cc := 
newCancelContext() + ctx := withCancelContext(context.Background(), cc) + got := getCancelContext(ctx) + assert.Equal(t, cc, got) + }) + + t.Run("Get_NoValue_ReturnsNil", func(t *testing.T) { + got := getCancelContext(context.Background()) + assert.Nil(t, got) + }) + + t.Run("With_NilCancelContext_ReturnsOriginalCtx", func(t *testing.T) { + ctx := context.Background() + result := withCancelContext(ctx, nil) + assert.Equal(t, ctx, result) + }) +} + +// -- Tests for cancel support across all agent types -- + +// cancelTestAgent is a ChatModelAgent-based agent where the model blocks until +// signalled, allowing tests to control exactly when to issue a cancel. +func newCancelTestAgent(t *testing.T, name string, modelDelay time.Duration, modelStarted chan struct{}) *ChatModelAgent { + t.Helper() + slowModel := &cancelTestChatModel{ + delayNs: int64(modelDelay), + response: &schema.Message{ + Role: schema.Assistant, + Content: "response from " + name, + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: name, + Description: "Test agent " + name, + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + return agent +} + +func newCancelTestAgentWithTools(t *testing.T, name string, modelDelay time.Duration, modelStarted chan struct{}) *ChatModelAgent { + t.Helper() + toolName := name + "_tool" + slowModel := &cancelTestChatModel{ + delayNs: int64(modelDelay), + response: &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{{ + ID: "call_1", Type: "function", + Function: schema.FunctionCall{ + Name: toolName, + Arguments: `{"input": "test"}`, + }, + }}, + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + st := newSlowTool(toolName, 10*time.Millisecond, "tool result") + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: name, + Description: "Test 
agent " + name, + Instruction: "You are a test assistant", + Model: slowModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + return agent +} + +func newCancelTestAgentWithToolsFinalAnswer(t *testing.T, name string) *ChatModelAgent { + t.Helper() + toolName := name + "_tool" + finalModel := &cancelTestChatModel{ + delayNs: int64(10 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "final response from " + name, + }, + startedChan: make(chan struct{}, 1), + doneChan: make(chan struct{}, 1), + } + + st := newSlowTool(toolName, 10*time.Millisecond, "tool result") + + agent, err := NewChatModelAgent(context.Background(), &ChatModelAgentConfig{ + Name: name, + Description: "Test agent " + name, + Instruction: "You are a test assistant", + Model: finalModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + return agent +} + +func TestWithCancel_SequentialAgent(t *testing.T) { + ctx := context.Background() + + t.Run("CancelImmediate_DuringSecondAgent", func(t *testing.T) { + // The first agent completes quickly. The second agent takes a long time. + // Cancel during the second agent's model call. 
+ agent1Started := make(chan struct{}, 1) + agent2Started := make(chan struct{}, 1) + + agent1 := newCancelTestAgent(t, "fast_agent", 50*time.Millisecond, agent1Started) + agent2 := newCancelTestAgent(t, "slow_agent", 5*time.Second, agent2Started) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq_agent", + Description: "Sequential test", + SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, + EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + // Wait for second agent to start + select { + case <-agent2Started: + case <-time.After(10 * time.Second): + t.Fatal("Second agent did not start") + } + + time.Sleep(50 * time.Millisecond) + + // Cancel should NOT return ErrExecutionCompleted (the bug before the fix) + handle, _ := cancelFn() + err = handle.Wait() + assert.NoError(t, err, "Cancel during second agent should succeed, not return ErrExecutionCompleted") + + drainEventsAndAssertCancelError(t, iter) + }) +} + +func TestWithCancel_LoopAgent(t *testing.T) { + ctx := context.Background() + + t.Run("CancelImmediate_DuringIteration", func(t *testing.T) { + // Agent in a loop. Cancel during second iteration's model call. 
+ modelStarted := make(chan struct{}, 10) + + slowModel := &cancelTestChatModel{ + delayNs: int64(3 * time.Second), + response: &schema.Message{ + Role: schema.Assistant, + Content: "loop response", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 10), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "loop_inner", + Description: "Inner loop agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loopAgent, err := NewLoopAgent(ctx, &LoopAgentConfig{ + Name: "loop_agent", + Description: "Loop test", + SubAgents: []Agent{agent}, + MaxIterations: 10, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: loopAgent, + EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + // Wait for first iteration's model call to start + select { + case <-modelStarted: + case <-time.After(10 * time.Second): + t.Fatal("Model did not start") + } + + time.Sleep(50 * time.Millisecond) + + // Cancel should succeed + handle, _ := cancelFn() + err = handle.Wait() + assert.NoError(t, err, "Cancel during loop iteration should succeed") + + drainAndAssertCancelError(t, iter) + }) +} + +func TestWithCancel_ParallelAgent(t *testing.T) { + ctx := context.Background() + + t.Run("CancelImmediate_InterruptsAllBranches", func(t *testing.T) { + agent1Started := make(chan struct{}, 1) + agent2Started := make(chan struct{}, 1) + + // Both agents have long delays, so cancel should interrupt both. 
+ agent1 := newCancelTestAgent(t, "par_agent1", 5*time.Second, agent1Started) + agent2 := newCancelTestAgent(t, "par_agent2", 5*time.Second, agent2Started) + + parAgent, err := NewParallelAgent(ctx, &ParallelAgentConfig{ + Name: "par_agent", + Description: "Parallel test", + SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: parAgent, + EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + // Wait for both agents to start + for i := 0; i < 2; i++ { + select { + case <-agent1Started: + case <-agent2Started: + case <-time.After(10 * time.Second): + t.Fatal("Parallel agents did not start") + } + } + + time.Sleep(50 * time.Millisecond) + + start := time.Now() + handle, _ := cancelFn() + err = handle.Wait() + assert.NoError(t, err, "Cancel during parallel agents should succeed") + + events := drainEventsAndAssertCancelError(t, iter) + elapsed := time.Since(start) + + _ = events + assert.True(t, elapsed < 3*time.Second, "Should complete quickly after cancel, elapsed: %v", elapsed) + }) +} + +func TestWithCancel_SupervisorAgent(t *testing.T) { + ctx := context.Background() + + t.Run("CancelImmediate_DuringSubAgent", func(t *testing.T) { + // Supervisor delegates to a slow sub-agent via transfer. + // Cancel during the sub-agent's model call. 
+ supervisorModelStarted := make(chan struct{}, 1) + subAgentModelStarted := make(chan struct{}, 1) + + // The supervisor model returns a transfer_to_agent tool call + supervisorModel := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: TransferToAgentToolName, + Arguments: `{"agent_name": "slow_sub"}`, + }, + }, + }, + }, + } + + supervisorAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "supervisor", + Description: "Supervisor agent", + Instruction: "You are a supervisor", + Model: supervisorModel, + }) + assert.NoError(t, err) + + subAgent := newCancelTestAgent(t, "slow_sub", 5*time.Second, subAgentModelStarted) + + agentWithSubAgents, err := SetSubAgents(ctx, supervisorAgent, []Agent{subAgent}) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agentWithSubAgents, + EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + // Ignore the supervisor model start, wait for the sub-agent model + // The supervisor model is fast (simpleChatModel), so it will start and finish quickly + _ = supervisorModelStarted + select { + case <-subAgentModelStarted: + case <-time.After(10 * time.Second): + t.Fatal("Sub-agent model did not start") + } + + time.Sleep(50 * time.Millisecond) + + start := time.Now() + handle, _ := cancelFn() + err = handle.Wait() + assert.NoError(t, err, "Cancel during sub-agent should succeed") + + drainAndAssertCancelError(t, iter) + elapsed := time.Since(start) + + assert.True(t, elapsed < 3*time.Second, "Should complete quickly after cancel, elapsed: %v", elapsed) + }) +} + +func TestFilterCancelOption(t *testing.T) { + t.Run("RemovesCancelOption", func(t *testing.T) { + cancelOpt, _ := WithCancel() + sessionOpt := WithSessionValues(map[string]any{"key": "value"}) + 
opts := []AgentRunOption{cancelOpt, sessionOpt} + + filtered := filterCancelOption(opts) + assert.Len(t, filtered, 1, "Should have removed the cancel option") + + // Verify the remaining option is the session option + testOpt := &options{} + filtered[0].implSpecificOptFn.(func(*options))(testOpt) + assert.NotNil(t, testOpt.sessionValues) + assert.Nil(t, testOpt.cancelCtx) + }) + + t.Run("KeepsNonCancelOptions", func(t *testing.T) { + sessionOpt := WithSessionValues(map[string]any{"key": "value"}) + callbackOpt := WithCallbacks() + opts := []AgentRunOption{sessionOpt, callbackOpt} + + filtered := filterCancelOption(opts) + assert.Len(t, filtered, 2, "Should keep all non-cancel options") + }) + + t.Run("EmptyInput", func(t *testing.T) { + filtered := filterCancelOption(nil) + assert.Nil(t, filtered) + }) +} + +func wrapIterWithMarkDone(iter *AsyncIterator[*AgentEvent], cc *cancelContext) *AsyncIterator[*AgentEvent] { + if cc == nil { + return iter + } + outIter, outGen := NewAsyncIteratorPair[*AgentEvent]() + go func() { + defer cc.markDone() + defer outGen.Close() + for { + event, ok := iter.Next() + if !ok { + return + } + outGen.Send(event) + } + }() + return outIter +} + +func TestWrapIterWithMarkDone(t *testing.T) { + t.Run("MarksDoneAfterDrain", func(t *testing.T) { + cc := newCancelContext() + iter, gen := NewAsyncIteratorPair[*AgentEvent]() + + go func() { + gen.Send(&AgentEvent{AgentName: "test"}) + gen.Close() + }() + + wrapped := wrapIterWithMarkDone(iter, cc) + + event, ok := wrapped.Next() + assert.True(t, ok) + assert.Equal(t, "test", event.AgentName) + + _, ok = wrapped.Next() + assert.False(t, ok) + + // markDone should have been called, so doneChan should be closed + select { + case <-cc.doneChan: + // good + case <-time.After(time.Second): + t.Fatal("doneChan was not closed after drain") + } + }) + + t.Run("NilCancelContext_PassesThrough", func(t *testing.T) { + iter, gen := NewAsyncIteratorPair[*AgentEvent]() + go func() { + 
gen.Send(&AgentEvent{AgentName: "test"}) + gen.Close() + }() + + wrapped := wrapIterWithMarkDone(iter, nil) + assert.Equal(t, iter, wrapped, "Should return same iter when cc is nil") + }) +} + +func TestGraphInterruptFuncs_Parallel(t *testing.T) { + t.Run("MultipleGraphInterruptFuncsAllCalled", func(t *testing.T) { + cc := newCancelContext() + + var called1, called2 int32 + cc.setGraphInterruptFunc(func(opts ...compose.GraphInterruptOption) { + atomic.AddInt32(&called1, 1) + }) + cc.setGraphInterruptFunc(func(opts ...compose.GraphInterruptOption) { + atomic.AddInt32(&called2, 1) + }) + + // Simulate immediate cancel + cc.setMode(CancelImmediate) + atomic.CompareAndSwapInt32(&cc.state, stateRunning, stateCancelling) + close(cc.cancelChan) + cc.sendImmediateInterrupt() + + assert.Equal(t, int32(1), atomic.LoadInt32(&called1), "First graph interrupt func should be called") + assert.Equal(t, int32(1), atomic.LoadInt32(&called2), "Second graph interrupt func should be called") + }) + + t.Run("RetroactiveFire_OnSetAfterCancel", func(t *testing.T) { + cc := newCancelContext() + + // First set up cancel state with immediate interrupt + cc.setMode(CancelImmediate) + atomic.CompareAndSwapInt32(&cc.state, stateRunning, stateCancelling) + close(cc.cancelChan) + close(cc.immediateChan) + atomic.StoreInt32(&cc.interruptSent, interruptImmediate) + + // Now register a new function - it should be retroactively fired + var called int32 + cc.setGraphInterruptFunc(func(opts ...compose.GraphInterruptOption) { + atomic.AddInt32(&called, 1) + }) + + assert.Equal(t, int32(1), atomic.LoadInt32(&called), "setGraphInterruptFunc should retroactively fire new func") + }) +} + +// -- Tests for transition-point cancel (cancel between sub-agents) -- + +// gatedChatModel is a model that: +// - Signals doneChan when Generate completes +// - Optionally blocks on gateChan before returning (nil gateChan = no blocking) +// - Tracks call count via callCount +type gatedChatModel struct { + response 
*schema.Message + gateChan chan struct{} // if non-nil, blocks until closed before returning + doneChan chan struct{} // signalled after Generate completes + callCount int32 +} + +func (m *gatedChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m.callCount, 1) + if m.gateChan != nil { + select { + case <-m.gateChan: + case <-ctx.Done(): + return nil, ctx.Err() + } + } + select { + case m.doneChan <- struct{}{}: + default: + } + return m.response, nil +} + +func (m *gatedChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + msg, err := m.Generate(ctx, input, opts...) + if err != nil { + return nil, err + } + return schema.StreamReaderFromArray([]*schema.Message{msg}), nil +} + +func (m *gatedChatModel) BindTools(tools []*schema.ToolInfo) error { + return nil +} + +func TestCheckCancel_Sequential_BetweenSubAgents(t *testing.T) { + ctx := context.Background() + + // CancelAfterToolCalls fires at transition boundaries between sub-agents. + // At a transition boundary, the completed sub-agent's entire execution + // (including any tool calls) is done, satisfying the CancelAfterToolCalls + // contract — even if this particular sub-agent had no tools. 
+ model1 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent1 done"}, + gateChan: make(chan struct{}), + doneChan: make(chan struct{}, 1), + } + model2 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent2 done"}, + doneChan: make(chan struct{}, 1), + } + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent1", Description: "first", Instruction: "test", Model: model1, + }) + assert.NoError(t, err) + + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent2", Description: "second", Instruction: "test", Model: model2, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq", Description: "sequential test", SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hello")}, cancelOpt) + + for atomic.LoadInt32(&model1.callCount) == 0 { + runtime.Gosched() + } + + cancelCalled, result := cancelAsync(cancelFn, WithAgentCancelMode(CancelAfterToolCalls)) + waitForChan(t, cancelCalled, "cancelFn was not called") + close(model1.gateChan) + + assert.NoError(t, result.waitDone(t), "CancelAfterToolCalls should succeed at transition boundary") + + for { + _, ok := iter.Next() + if !ok { + break + } + } + + assert.Equal(t, int32(1), atomic.LoadInt32(&model1.callCount), "Agent1 model should be invoked") + assert.Equal(t, int32(0), atomic.LoadInt32(&model2.callCount), + "Agent2 model should NOT be invoked (CancelAfterToolCalls caught at transition)") +} + +func TestCheckCancel_Loop_BetweenIterations(t *testing.T) { + ctx := context.Background() + + // CancelAfterToolCalls fires at loop iteration boundaries. 
+ // After the first iteration completes, any tool calls it made are done, + // satisfying the CancelAfterToolCalls contract. + mdl := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "loop iter"}, + gateChan: make(chan struct{}), + doneChan: make(chan struct{}, 10), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "loop_inner", Description: "inner", Instruction: "test", Model: mdl, + }) + assert.NoError(t, err) + + loopAgent, err := NewLoopAgent(ctx, &LoopAgentConfig{ + Name: "loop", Description: "loop test", SubAgents: []Agent{agent}, MaxIterations: 3, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: loopAgent, EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hello")}, cancelOpt) + + for atomic.LoadInt32(&mdl.callCount) == 0 { + runtime.Gosched() + } + + cancelCalled, result := cancelAsync(cancelFn, WithAgentCancelMode(CancelAfterToolCalls)) + waitForChan(t, cancelCalled, "cancelFn was not called") + close(mdl.gateChan) + + assert.NoError(t, result.waitDone(t), "CancelAfterToolCalls should succeed at loop transition boundary") + + for { + _, ok := iter.Next() + if !ok { + break + } + } + + assert.Equal(t, int32(1), atomic.LoadInt32(&mdl.callCount), + "Model should be called once; second iteration caught at transition") +} + +func TestCheckCancel_Parallel_PreSpawn(t *testing.T) { + ctx := context.Background() + + // Cancel fires before Run is called. Neither model should be invoked. 
+ model1 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "par1"}, + doneChan: make(chan struct{}, 1), + } + model2 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "par2"}, + doneChan: make(chan struct{}, 1), + } + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "par1", Description: "first", Instruction: "test", Model: model1, + }) + assert.NoError(t, err) + + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "par2", Description: "second", Instruction: "test", Model: model2, + }) + assert.NoError(t, err) + + parAgent, err := NewParallelAgent(ctx, &ParallelAgentConfig{ + Name: "par", Description: "parallel test", SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + // Fire cancel in goroutine (cancelFn blocks until handled) + cancelOpt, cancelFn := WithCancel() + cancelDone := make(chan error, 1) + go func() { + handle, _ := cancelFn() + cancelDone <- handle.Wait() + }() + // Wait for cancelChan to be closed (happens synchronously before the blocking doneChan wait) + time.Sleep(20 * time.Millisecond) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: parAgent, EnableStreaming: false, + }) + + iter := runner.Run(ctx, []Message{schema.UserMessage("hello")}, cancelOpt) + + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + cancelErr = ce + } + } + + // cancelFn should have completed + select { + case err = <-cancelDone: + assert.NoError(t, err) + case <-time.After(5 * time.Second): + t.Fatal("cancelFn did not return") + } + + assert.NotNil(t, cancelErr, "Should have CancelError") + assert.Equal(t, int32(0), atomic.LoadInt32(&model1.callCount), "First model should never be invoked") + assert.Equal(t, int32(0), atomic.LoadInt32(&model2.callCount), "Second model should never be invoked") +} + +func TestCheckCancel_Transfer_BeforeTarget(t 
*testing.T) { + ctx := context.Background() + + // Supervisor CMA returns a transfer action (instantly). + // Cancel fires after transfer action but before target runs. + // Target model should never be invoked. + supervisorModel := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{{ + ID: "call_1", Type: "function", + Function: schema.FunctionCall{ + Name: TransferToAgentToolName, + Arguments: `{"agent_name": "target"}`, + }, + }}, + }, + } + targetModel := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "target done"}, + doneChan: make(chan struct{}, 1), + } + + supervisorAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "supervisor", Description: "supervisor", Instruction: "test", Model: supervisorModel, + }) + assert.NoError(t, err) + + targetAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "target", Description: "target", Instruction: "test", Model: targetModel, + }) + assert.NoError(t, err) + + agentWithSub, err := SetSubAgents(ctx, supervisorAgent, []Agent{targetAgent}) + assert.NoError(t, err) + + // Fire cancel in goroutine (cancelFn blocks until handled) + cancelOpt, cancelFn := WithCancel() + cancelDone := make(chan error, 1) + go func() { + handle, _ := cancelFn() + cancelDone <- handle.Wait() + }() + time.Sleep(20 * time.Millisecond) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: agentWithSub, EnableStreaming: false, + }) + + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + cancelErr = ce + } + } + + select { + case err = <-cancelDone: + assert.NoError(t, err) + case <-time.After(5 * time.Second): + t.Fatal("cancelFn did not return") + } + + assert.NotNil(t, cancelErr, "Should have CancelError") + assert.Equal(t, int32(0), 
atomic.LoadInt32(&targetModel.callCount), "Target model should never be invoked") +} + +func TestCheckCancel_AlreadyHandled_NoDuplicate(t *testing.T) { + ctx := context.Background() + + // In a sequential agent, if the first CMA handles the cancel (graph interrupt), + // the workflow's transition check should NOT emit a duplicate CancelError. + // Use a slow model so cancel fires during its execution (handled by CMA). + modelStarted := make(chan struct{}, 1) + model1 := &cancelTestChatModel{ + delayNs: int64(2 * time.Second), + response: &schema.Message{Role: schema.Assistant, Content: "agent1"}, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + model2 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent2"}, + doneChan: make(chan struct{}, 1), + } + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent1", Description: "first", Instruction: "test", Model: model1, + }) + assert.NoError(t, err) + + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent2", Description: "second", Instruction: "test", Model: model2, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq", Description: "sequential", SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + // Wait for model to start, then cancel during model execution + select { + case <-modelStarted: + case <-time.After(5 * time.Second): + t.Fatal("Model did not start") + } + time.Sleep(50 * time.Millisecond) + handle, _ := cancelFn() + err = handle.Wait() + assert.NoError(t, err) + + cancelCount := 0 + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + cancelCount++ + } 
+ } + + assert.Equal(t, 1, cancelCount, "Should have exactly one CancelError, no duplicate from workflow transition") + assert.Equal(t, int32(0), atomic.LoadInt32(&model2.callCount), "Second agent should not run") +} + +// Tests for CancelAfterChatModel/CancelAfterToolCalls in nested workflow structures. +// These verify that safe-point cancel modes propagate through the entire agent hierarchy +// and fire at whichever nested level reaches the safe-point first. + +func TestCancel_SequentialWorkflow_CancelAfterChatModel(t *testing.T) { + ctx := context.Background() + agent1Started := make(chan struct{}, 1) + + agent1 := newCancelTestAgentWithTools(t, "seq_slow", 500*time.Millisecond, agent1Started) + agent2 := newCancelTestAgentWithTools(t, "seq_fast", 50*time.Millisecond, make(chan struct{}, 1)) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq_agent", + Description: "Sequential workflow", + SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt, WithCheckPointID("seq-cancel-1")) + + select { + case <-agent1Started: + case <-time.After(10 * time.Second): + t.Fatal("First agent did not start") + } + + handle, contributed := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + assert.True(t, contributed, "Cancel should contribute") + err = handle.Wait() + assert.NoError(t, err) + + hasCancelError := false + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil && errors.As(event.Err, &cancelErr) { + hasCancelError = true + } + } + + assert.True(t, hasCancelError, "Should have CancelError") + assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) + assert.NotEmpty(t, cancelErr.CheckPointID, "CancelError should have checkpoint ID") + 
assert.NotNil(t, cancelErr.interruptSignal, "CancelError should have interrupt signal for checkpoint") + + resumeAgent1 := newCancelTestAgentWithToolsFinalAnswer(t, "seq_slow") + resumeAgent2 := newCancelTestAgentWithToolsFinalAnswer(t, "seq_fast") + + resumeSeq, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq_agent", + Description: "Sequential workflow", + SubAgents: []Agent{resumeAgent1, resumeAgent2}, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: resumeSeq, + CheckPointStore: store, + }) + + resumeIter, err := runner2.Resume(ctx, "seq-cancel-1") + assert.NoError(t, err) + assert.NotNil(t, resumeIter) + + var resumeEvents []*AgentEvent + for { + event, ok := resumeIter.Next() + if !ok { + break + } + assert.Nil(t, event.Err, "Should not have error during resume") + resumeEvents = append(resumeEvents, event) + } + assert.NotEmpty(t, resumeEvents, "Resume should produce events") +} + +func TestCancelImmediate_OrphanedToolGoroutine_NoPanic(t *testing.T) { + t.Run("unit_send_after_close", func(t *testing.T) { + _, gen := NewAsyncIteratorPair[*AgentEvent]() + + cc := newCancelContext() + cc.setMode(CancelImmediate) + close(cc.cancelChan) + close(cc.immediateChan) + + gen.Close() + + execCtx := &chatModelAgentExecCtx{ + generator: gen, + cancelCtx: cc, + } + + assert.NotPanics(t, func() { + execCtx.send(&AgentEvent{AgentName: "test"}) + }, "send after generator.Close must not panic") + }) + + t.Run("unit_send_after_close_without_cancel_ctx", func(t *testing.T) { + _, gen := NewAsyncIteratorPair[*AgentEvent]() + gen.Close() + + execCtx := &chatModelAgentExecCtx{ + generator: gen, + } + + assert.NotPanics(t, func() { + execCtx.send(&AgentEvent{AgentName: "test"}) + }, "send after generator.Close must not panic even without cancelCtx (trySend safety net)") + }) + + t.Run("unit_send_nil_execCtx", func(t *testing.T) { + var execCtx *chatModelAgentExecCtx + assert.NotPanics(t, func() { + 
execCtx.send(&AgentEvent{AgentName: "test"}) + }, "send on nil execCtx must not panic") + }) + + t.Run("unit_send_nil_generator", func(t *testing.T) { + execCtx := &chatModelAgentExecCtx{} + assert.NotPanics(t, func() { + execCtx.send(&AgentEvent{AgentName: "test"}) + }, "send with nil generator must not panic") + }) + + t.Run("unit_isImmediateCancelled_nil_cancelContext", func(t *testing.T) { + var cc *cancelContext + assert.False(t, cc.isImmediateCancelled(), "nil cancelContext should return false") + }) + + t.Run("unit_trySend_race_window", func(t *testing.T) { + _, gen := NewAsyncIteratorPair[*AgentEvent]() + cc := newCancelContext() + + gen.Close() + + execCtx := &chatModelAgentExecCtx{ + generator: gen, + cancelCtx: cc, + } + + assert.NotPanics(t, func() { + execCtx.send(&AgentEvent{AgentName: "test"}) + }, "trySend must handle the case where isImmediateCancelled is false but generator is closed") + }) + + t.Run("unit_SendEvent_after_close", func(t *testing.T) { + _, gen := NewAsyncIteratorPair[*AgentEvent]() + + cc := newCancelContext() + cc.setMode(CancelImmediate) + close(cc.cancelChan) + close(cc.immediateChan) + + gen.Close() + + execCtx := &chatModelAgentExecCtx{ + generator: gen, + cancelCtx: cc, + } + + ctx := withChatModelAgentExecCtx(context.Background(), execCtx) + + assert.NotPanics(t, func() { + err := SendEvent(ctx, &AgentEvent{AgentName: "test"}) + assert.NoError(t, err) + }, "SendEvent after generator.Close must not panic") + }) + + t.Run("unit_SendEvent_no_execCtx", func(t *testing.T) { + err := SendEvent(context.Background(), &AgentEvent{AgentName: "test"}) + assert.Error(t, err, "SendEvent without execCtx should return error") + }) + + t.Run("integration_cancel_escalation_orphans_tool", func(t *testing.T) { + ctx := context.Background() + + toolStarted := make(chan struct{}, 1) + toolDone := make(chan struct{}, 1) + st := &slowToolWithSignal{ + name: "orphan_tool", + delay: 2 * time.Second, + result: "tool result", + startedChan: 
toolStarted, + } + + mdl := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_orphan_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "orphan_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "OrphanTestAgent", + Description: "Test agent for orphaned tool goroutine panic", + Model: mdl, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + cancelOpt, cancelFn := WithCancel() + iter := agent.Run(ctx, &AgentInput{ + Messages: []Message{schema.UserMessage("Use the tool")}, + }, cancelOpt) + assert.NotNil(t, iter) + + select { + case <-toolStarted: + case <-time.After(10 * time.Second): + t.Fatal("Tool did not start") + } + + timeout := 50 * time.Millisecond + handle, contributed := cancelFn( + WithAgentCancelMode(CancelAfterChatModel), + WithAgentCancelTimeout(timeout), + ) + assert.True(t, contributed, "Cancel should contribute") + + err = handle.Wait() + assert.True(t, err == nil || errors.Is(err, ErrCancelTimeout), + "handle.Wait should return nil or ErrCancelTimeout, got: %v", err) + + for { + _, ok := iter.Next() + if !ok { + break + } + } + + go func() { + time.Sleep(3 * time.Second) + select { + case toolDone <- struct{}{}: + default: + } + }() + + runtime.Gosched() + time.Sleep(3 * time.Second) + + select { + case <-toolDone: + default: + } + }) +} + +// -- Tests for CancelImmediate in nested agent structures -- + +func newTestChatModel(response *schema.Message, delay time.Duration) *cancelTestChatModel { + m := &cancelTestChatModel{ + response: response, + startedChan: make(chan struct{}, 1), + doneChan: make(chan struct{}, 1), + } + if delay > 0 { + m.setDelay(delay) + } + return m +} + +func newToolCallResponse(toolName string) *schema.Message { + return &schema.Message{ + Role: 
schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + {ID: "call_1", Type: "function", Function: schema.FunctionCall{Name: toolName, Arguments: `{}`}}, + }, + } +} + +func newAgentWithTool(t *testing.T, ctx context.Context, name string, mdl model.BaseChatModel, subAgent Agent) (Agent, error) { + t.Helper() + return NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: name, + Description: name, + Model: mdl, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{NewAgentTool(ctx, subAgent)}, + }, + }, + }) +} + +func waitForChan(t *testing.T, ch <-chan struct{}, msg string) { + t.Helper() + select { + case <-ch: + case <-time.After(10 * time.Second): + t.Fatal(msg) + } +} + +func drainCancelError(t *testing.T, iter *AsyncIterator[*AgentEvent]) *CancelError { + t.Helper() + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil { + errors.As(event.Err, &cancelErr) + } + } + return cancelErr +} + +func drainResumeErrors(t *testing.T, iter *AsyncIterator[*AgentEvent]) []error { + t.Helper() + var errs []error + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil { + errs = append(errs, event.Err) + } + } + return errs +} + +type cancelResult struct { + err error + contributed bool + done chan struct{} +} + +func cancelAsync(cancelFn AgentCancelFunc, opts ...AgentCancelOption) (cancelCalled chan struct{}, result *cancelResult) { + cancelCalled = make(chan struct{}) + result = &cancelResult{done: make(chan struct{})} + go func() { + handle, contributed := cancelFn(opts...) 
+ result.contributed = contributed + close(cancelCalled) + result.err = handle.Wait() + close(result.done) + }() + return +} + +func (r *cancelResult) waitDone(t *testing.T) error { + t.Helper() + select { + case <-r.done: + return r.err + case <-time.After(10 * time.Second): + t.Fatal("cancel did not complete") + return nil + } +} + +func TestCancelImmediate_AgentTool_PreservesChildCheckpoint(t *testing.T) { + ctx := context.Background() + + leafModel := newTestChatModel( + &schema.Message{Role: schema.Assistant, Content: "leaf response"}, 2*time.Second) + leafAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "leaf_agent", Description: "Leaf agent in agentTool", Model: leafModel, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "inner_seq", Description: "Inner sequential workflow", SubAgents: []Agent{leafAgent}, + }) + assert.NoError(t, err) + + rootModel := newTestChatModel(newToolCallResponse("inner_seq"), 0) + rootAgent, err := newAgentWithTool(t, ctx, "root_agent", rootModel, seqAgent) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{Agent: rootAgent, CheckPointStore: store}) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt, WithCheckPointID("immediate-agent-tool-1")) + + waitForChan(t, leafModel.startedChan, "Leaf agent model did not start") + + handle, contributed := cancelFn(WithRecursive()) + assert.True(t, contributed) + assert.NoError(t, handle.Wait()) + + cancelErr := drainCancelError(t, iter) + assert.NotNil(t, cancelErr, "Should have CancelError from CancelImmediate through agentTool") + assert.NotEmpty(t, cancelErr.CheckPointID) + assert.NotNil(t, cancelErr.interruptSignal) + + resumeLeaf, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "leaf_agent", Description: "Leaf agent in agentTool", + Model: newTestChatModel(&schema.Message{Role: 
schema.Assistant, Content: "resumed leaf"}, 0), + }) + assert.NoError(t, err) + resumeSeq, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "inner_seq", Description: "Inner sequential workflow", SubAgents: []Agent{resumeLeaf}, + }) + assert.NoError(t, err) + resumeRoot, err := newAgentWithTool(t, ctx, "root_agent", + newTestChatModel(&schema.Message{Role: schema.Assistant, Content: "resumed root"}, 0), resumeSeq) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{Agent: resumeRoot, CheckPointStore: store}) + resumeIter, err := runner2.Resume(ctx, "immediate-agent-tool-1") + assert.NoError(t, err) + assert.Empty(t, drainResumeErrors(t, resumeIter), "Resume should complete without errors") +} + +func TestCancelImmediate_ParallelWorkflow_WithAgentTool(t *testing.T) { + ctx := context.Background() + + leafModel := newTestChatModel( + &schema.Message{Role: schema.Assistant, Content: "leaf response"}, 2*time.Second) + leafAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "leaf_agent", Description: "Leaf agent in agentTool", Model: leafModel, + }) + assert.NoError(t, err) + + agentWithTool, err := newAgentWithTool(t, ctx, "agent_with_tool", + newTestChatModel(newToolCallResponse("leaf_agent"), 0), leafAgent) + assert.NoError(t, err) + + simpleStarted := make(chan struct{}, 1) + simpleAgent := newCancelTestAgent(t, "simple_agent", 2*time.Second, simpleStarted) + + parAgent, err := NewParallelAgent(ctx, &ParallelAgentConfig{ + Name: "par_agent", Description: "Parallel with agentTool and simple agent", + SubAgents: []Agent{agentWithTool, simpleAgent}, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{Agent: parAgent, EnableStreaming: false}) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + waitForChan(t, leafModel.startedChan, "Leaf agent did not start") + waitForChan(t, simpleStarted, "Simple agent did not start") + + start := 
time.Now() + handle, _ := cancelFn() + assert.NoError(t, handle.Wait()) + + cancelErr := drainCancelError(t, iter) + elapsed := time.Since(start) + + assert.NotNil(t, cancelErr, "Should have CancelError from parallel with agentTool") + assert.True(t, elapsed < 5*time.Second, "Should complete quickly after cancel, elapsed: %v", elapsed) +} + +type cancelUnawareAgent struct { + name string + desc string + delay time.Duration + response string +} + +type multiResponseGatedModel struct { + responses []*schema.Message + gateChan chan struct{} + gateOnce bool + gated int32 + doneChan chan struct{} + callCount int32 +} + +func (m *multiResponseGatedModel) Generate(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { + idx := atomic.AddInt32(&m.callCount, 1) + if m.gateChan != nil && (!m.gateOnce || atomic.CompareAndSwapInt32(&m.gated, 0, 1)) { + select { + case <-m.gateChan: + case <-ctx.Done(): + return nil, ctx.Err() + } + } + if len(m.responses) == 0 { + return nil, fmt.Errorf("multiResponseGatedModel: no responses configured") + } + resp := m.responses[(int(idx)-1)%len(m.responses)] + if m.doneChan != nil { + select { + case m.doneChan <- struct{}{}: + default: + } + } + return resp, nil +} + +func (m *multiResponseGatedModel) Stream(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + resp, err := m.Generate(ctx, msgs, opts...) 
+ if err != nil { + return nil, err + } + return schema.StreamReaderFromArray([]*schema.Message{resp}), nil +} + +func (m *multiResponseGatedModel) BindTools(tools []*schema.ToolInfo) error { return nil } + +func (a *cancelUnawareAgent) Name(_ context.Context) string { return a.name } +func (a *cancelUnawareAgent) Description(_ context.Context) string { return a.desc } + +func (a *cancelUnawareAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, gen := NewAsyncIteratorPair[*AgentEvent]() + go func() { + defer gen.Close() + // Intentionally ignores ctx.Done() — simulates a custom agent that + // does not participate in the cancel protocol at all. + // Delay is kept short (relative to grace period) to avoid goroutine + // leak lasting long after the test completes. + time.Sleep(a.delay) + }() + return iter +} + +func TestCancelImmediate_CustomAgent_GracePeriodFallback(t *testing.T) { + ctx := context.Background() + + customAgent := &cancelUnawareAgent{ + name: "custom_slow", desc: "A custom agent that ignores cancel", + delay: 5 * time.Second, response: "custom response", + } + + rootModel := newTestChatModel(newToolCallResponse("custom_slow"), 0) + rootAgent, err := newAgentWithTool(t, ctx, "root_agent", rootModel, customAgent) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{Agent: rootAgent, EnableStreaming: false}) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt) + + waitForChan(t, rootModel.startedChan, "Root model did not start") + waitForChan(t, rootModel.doneChan, "Root model did not finish") + + start := time.Now() + handle, _ := cancelFn() + assert.NoError(t, handle.Wait()) + + cancelErr := drainCancelError(t, iter) + elapsed := time.Since(start) + + assert.NotNil(t, cancelErr, "Should have CancelError (from grace period fallback)") + assert.True(t, elapsed < 5*time.Second, + "Should complete within grace period + overhead, 
elapsed: %v", elapsed) +} + +func TestCancelImmediate_MultiLevelNesting(t *testing.T) { + ctx := context.Background() + + innerLeafModel := newTestChatModel( + &schema.Message{Role: schema.Assistant, Content: "inner leaf response"}, 2*time.Second) + innerLeafAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "inner_leaf", Description: "Innermost leaf agent", Model: innerLeafModel, + }) + assert.NoError(t, err) + + middleAgent, err := newAgentWithTool(t, ctx, "middle_agent", + newTestChatModel(newToolCallResponse("inner_leaf"), 0), innerLeafAgent) + assert.NoError(t, err) + + rootAgent, err := newAgentWithTool(t, ctx, "root_agent", + newTestChatModel(newToolCallResponse("middle_agent"), 0), middleAgent) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{Agent: rootAgent, CheckPointStore: store}) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt, WithCheckPointID("multi-level-1")) + + waitForChan(t, innerLeafModel.startedChan, "Inner leaf model did not start") + + start := time.Now() + handle, contributed := cancelFn() + assert.True(t, contributed) + assert.NoError(t, handle.Wait()) + + cancelErr := drainCancelError(t, iter) + elapsed := time.Since(start) + + assert.NotNil(t, cancelErr, "Should have CancelError from multi-level nesting") + assert.NotEmpty(t, cancelErr.CheckPointID) + assert.NotNil(t, cancelErr.interruptSignal) + assert.True(t, elapsed < 5*time.Second, "Should complete quickly, elapsed: %v", elapsed) + + resumeInnerLeaf, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "inner_leaf", Description: "Innermost leaf agent", + Model: newTestChatModel(&schema.Message{Role: schema.Assistant, Content: "resumed inner leaf"}, 0), + }) + assert.NoError(t, err) + resumeMiddle, err := newAgentWithTool(t, ctx, "middle_agent", + newTestChatModel(&schema.Message{Role: schema.Assistant, Content: "resumed middle"}, 0), 
resumeInnerLeaf) + assert.NoError(t, err) + resumeRoot, err := newAgentWithTool(t, ctx, "root_agent", + newTestChatModel(&schema.Message{Role: schema.Assistant, Content: "resumed root"}, 0), resumeMiddle) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{Agent: resumeRoot, CheckPointStore: store}) + resumeIter, err := runner2.Resume(ctx, "multi-level-1") + assert.NoError(t, err) + assert.Empty(t, drainResumeErrors(t, resumeIter), "Resume should complete without errors") +} + +func TestCancelImmediate_SequentialTransitionBoundary(t *testing.T) { + ctx := context.Background() + + model1 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent1 done"}, + gateChan: make(chan struct{}), + doneChan: make(chan struct{}, 1), + } + model2 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent2 done"}, + doneChan: make(chan struct{}, 1), + } + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent1", Description: "first", Instruction: "test", Model: model1, + }) + assert.NoError(t, err) + + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent2", Description: "second", Instruction: "test", Model: model2, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq", Description: "sequential test", SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hello")}, cancelOpt) + + for atomic.LoadInt32(&model1.callCount) == 0 { + runtime.Gosched() + } + + cancelCalled, result := cancelAsync(cancelFn) + waitForChan(t, cancelCalled, "cancelFn was not called") + close(model1.gateChan) + + assert.NoError(t, result.waitDone(t), "CancelImmediate should succeed at transition") + + cancelErr := drainCancelError(t, iter) 
+ + assert.NotNil(t, cancelErr, "Should have CancelError at transition boundary") + assert.Equal(t, int32(1), atomic.LoadInt32(&model1.callCount), "Agent1 model should be invoked") + assert.Equal(t, int32(0), atomic.LoadInt32(&model2.callCount), "Agent2 model should NOT be invoked (caught at transition)") +} + +func TestCancelImmediate_LoopTransitionBoundary(t *testing.T) { + ctx := context.Background() + + mdl := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "loop iter"}, + gateChan: make(chan struct{}), + doneChan: make(chan struct{}, 10), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "loop_inner", Description: "inner", Instruction: "test", Model: mdl, + }) + assert.NoError(t, err) + + loopAgent, err := NewLoopAgent(ctx, &LoopAgentConfig{ + Name: "loop", Description: "loop test", SubAgents: []Agent{agent}, MaxIterations: 5, + }) + assert.NoError(t, err) + + runner := NewRunner(ctx, RunnerConfig{ + Agent: loopAgent, EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hello")}, cancelOpt) + + for atomic.LoadInt32(&mdl.callCount) == 0 { + runtime.Gosched() + } + + cancelCalled, result := cancelAsync(cancelFn) + waitForChan(t, cancelCalled, "cancelFn was not called") + close(mdl.gateChan) + + assert.NoError(t, result.waitDone(t), "CancelImmediate should succeed at loop transition") + + for { + _, ok := iter.Next() + if !ok { + break + } + } + + assert.Equal(t, int32(1), atomic.LoadInt32(&mdl.callCount), + "Model should be called once; second iteration caught at transition") +} + +func TestCancelAfterChatModel_SequentialTransitionBoundary(t *testing.T) { + ctx := context.Background() + + model1 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent1 done"}, + gateChan: make(chan struct{}), + doneChan: make(chan struct{}, 1), + } + model2 := &gatedChatModel{ + response: &schema.Message{Role: 
schema.Assistant, Content: "agent2 done"}, + doneChan: make(chan struct{}, 1), + } + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent1", Description: "first", Instruction: "test", Model: model1, + }) + assert.NoError(t, err) + + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent2", Description: "second", Instruction: "test", Model: model2, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq", Description: "sequential test", SubAgents: []Agent{agent1, agent2}, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, + EnableStreaming: false, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hello")}, cancelOpt, WithCheckPointID("chatmodel-transition-1")) + + for atomic.LoadInt32(&model1.callCount) == 0 { + runtime.Gosched() + } + + cancelCalled, result := cancelAsync(cancelFn, WithAgentCancelMode(CancelAfterChatModel)) + waitForChan(t, cancelCalled, "cancelFn was not called") + close(model1.gateChan) + + assert.NoError(t, result.waitDone(t), "CancelAfterChatModel should succeed at transition boundary") + + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + cancelErr = ce + } + } + + assert.NotNil(t, cancelErr, "Should have CancelError at transition boundary") + assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) + assert.Equal(t, int32(1), atomic.LoadInt32(&model1.callCount), "Agent1 model should be invoked") + assert.Equal(t, int32(0), atomic.LoadInt32(&model2.callCount), + "Agent2 model should NOT be invoked (CancelAfterChatModel caught at transition)") +} + +func TestCancelAfterChatModel_Sequential_Agent1CompletesCancelBeforeAgent2Resume(t *testing.T) { + ctx := context.Background() + + model1 
:= &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent1 done"}, + gateChan: make(chan struct{}), + doneChan: make(chan struct{}, 1), + } + model2 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent2 done"}, + doneChan: make(chan struct{}, 1), + } + model3 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "agent3 done"}, + doneChan: make(chan struct{}, 1), + } + + agent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent1", Description: "first", Instruction: "test", Model: model1, + }) + assert.NoError(t, err) + agent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent2", Description: "second", Instruction: "test", Model: model2, + }) + assert.NoError(t, err) + agent3, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent3", Description: "third", Instruction: "test", Model: model3, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq", Description: "3-agent sequential", SubAgents: []Agent{agent1, agent2, agent3}, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, CheckPointStore: store, EnableStreaming: false, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("hello")}, cancelOpt, + WithCheckPointID("seq-transition-resume-1")) + + for atomic.LoadInt32(&model1.callCount) == 0 { + runtime.Gosched() + } + + cancelCalled, result := cancelAsync(cancelFn, WithAgentCancelMode(CancelAfterChatModel)) + waitForChan(t, cancelCalled, "cancelFn was not called") + close(model1.gateChan) + + assert.NoError(t, result.waitDone(t)) + + cancelErr := drainCancelError(t, iter) + assert.NotNil(t, cancelErr, "Should have CancelError") + assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) + assert.Equal(t, int32(1), atomic.LoadInt32(&model1.callCount)) + 
assert.Equal(t, int32(0), atomic.LoadInt32(&model2.callCount), + "Agent2 should NOT run (cancel caught at transition after agent1)") + assert.Equal(t, int32(0), atomic.LoadInt32(&model3.callCount)) + assert.NotEmpty(t, cancelErr.CheckPointID) + + resumeModel2 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "resumed agent2"}, + doneChan: make(chan struct{}, 1), + } + resumeModel3 := &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "resumed agent3"}, + doneChan: make(chan struct{}, 1), + } + + resumeAgent1, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent1", Description: "first", Instruction: "test", + Model: &gatedChatModel{ + response: &schema.Message{Role: schema.Assistant, Content: "should not run"}, + doneChan: make(chan struct{}, 1), + }, + }) + assert.NoError(t, err) + resumeAgent2, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent2", Description: "second", Instruction: "test", Model: resumeModel2, + }) + assert.NoError(t, err) + resumeAgent3, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent3", Description: "third", Instruction: "test", Model: resumeModel3, + }) + assert.NoError(t, err) + + resumeSeq, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq", Description: "3-agent sequential", + SubAgents: []Agent{resumeAgent1, resumeAgent2, resumeAgent3}, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: resumeSeq, CheckPointStore: store, EnableStreaming: false, + }) + resumeIter, err := runner2.Resume(ctx, "seq-transition-resume-1") + assert.NoError(t, err) + assert.Empty(t, drainResumeErrors(t, resumeIter), "Resume should complete without errors") + + assert.Equal(t, int32(1), atomic.LoadInt32(&resumeModel2.callCount), + "Agent2 should run on resume") + assert.Equal(t, int32(1), atomic.LoadInt32(&resumeModel3.callCount), + "Agent3 should run on resume") +} + +func 
TestCancelAfterToolCalls_LoopTransitionBoundary(t *testing.T) { + ctx := context.Background() + + // Model that returns tool calls on odd calls and no tools on even calls. + // This completes one ReAct cycle per pair of calls: + // call 1 (gated): returns tool call → tool runs → call 2: returns no tools → END + // The gate only blocks the very first call. After that, all calls proceed instantly. + mdl := &multiResponseGatedModel{ + responses: []*schema.Message{ + {Role: schema.Assistant, ToolCalls: []schema.ToolCall{{ + ID: "call_1", Type: "function", + Function: schema.FunctionCall{Name: "loop_tool", Arguments: `{"input": "test"}`}, + }}}, + {Role: schema.Assistant, Content: "iteration done"}, + }, + gateChan: make(chan struct{}), + gateOnce: true, + doneChan: make(chan struct{}, 10), + } + + st := &slowTool{ + name: "loop_tool", + delay: 10 * time.Millisecond, + result: "tool done", + startedChan: make(chan struct{}, 10), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "loop_inner", Description: "inner", Instruction: "test", Model: mdl, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + loopAgent, err := NewLoopAgent(ctx, &LoopAgentConfig{ + Name: "loop", Description: "loop test", SubAgents: []Agent{agent}, MaxIterations: 10, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{Agent: loopAgent, CheckPointStore: store}) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt, WithCheckPointID("toolcalls-loop-1")) + + // Wait for the model to be entered (blocked on gate) + for atomic.LoadInt32(&mdl.callCount) == 0 { + runtime.Gosched() + } + + // Fire cancel, wait for it to be registered, then release the gate + cancelCalled, result := cancelAsync(cancelFn, WithAgentCancelMode(CancelAfterToolCalls)) + waitForChan(t, cancelCalled, 
"cancelFn was not called") + close(mdl.gateChan) + + // Iteration 1 completes fully (model→tool→model-no-tools→END). + // The CancelAfterToolCalls safe-point inside ReAct fires after tool calls, + // OR the transition boundary catches it before iteration 2. + // Note: this test doesn't deterministically distinguish which path fires — + // both are semantically correct for CancelAfterToolCalls. The transition- + // boundary code path for CancelAfterToolCalls in loops is not definitively + // covered here because the ReAct safe-point may handle it first. + assert.NoError(t, result.waitDone(t)) + + cancelErr := drainCancelError(t, iter) + assert.NotNil(t, cancelErr, "Should have CancelError from CancelAfterToolCalls in loop") + assert.Equal(t, CancelAfterToolCalls, cancelErr.Info.Mode) + assert.NotEmpty(t, cancelErr.CheckPointID) +} + +func TestCancelContext_ActiveChildren_Tracking(t *testing.T) { + t.Run("DeriveChild_IncrementsActiveChildren", func(t *testing.T) { + parent := newCancelContext() + assert.False(t, parent.hasActiveChildren()) + + ctx := context.Background() + child := parent.deriveChild(ctx) + assert.True(t, parent.hasActiveChildren()) + assert.Equal(t, int32(1), atomic.LoadInt32(&parent.activeChildren)) + + child.markDone() + time.Sleep(10 * time.Millisecond) + assert.False(t, parent.hasActiveChildren()) + assert.Equal(t, int32(0), atomic.LoadInt32(&parent.activeChildren)) + }) + + t.Run("MultipleChildren_AllTracked", func(t *testing.T) { + parent := newCancelContext() + ctx := context.Background() + + child1 := parent.deriveChild(ctx) + child2 := parent.deriveChild(ctx) + assert.Equal(t, int32(2), atomic.LoadInt32(&parent.activeChildren)) + + child1.markDone() + time.Sleep(10 * time.Millisecond) + assert.Equal(t, int32(1), atomic.LoadInt32(&parent.activeChildren)) + assert.True(t, parent.hasActiveChildren()) + + child2.markDone() + time.Sleep(10 * time.Millisecond) + assert.False(t, parent.hasActiveChildren()) + }) + + 
t.Run("MarkCancelHandled_AlsoDecrementsParent", func(t *testing.T) { + parent := newCancelContext() + ctx := context.Background() + + child := parent.deriveChild(ctx) + assert.True(t, parent.hasActiveChildren()) + + child.triggerCancel(CancelImmediate) + child.markCancelHandled() + time.Sleep(10 * time.Millisecond) + assert.False(t, parent.hasActiveChildren()) + }) + + t.Run("GracePeriodWrapper_AppliesWhenChildrenActive", func(t *testing.T) { + parent := newCancelContext() + ctx := context.Background() + + var receivedOpts []compose.GraphInterruptOption + mockInterrupt := func(opts ...compose.GraphInterruptOption) { + receivedOpts = opts + } + + wrapped := parent.wrapGraphInterruptWithGracePeriod(mockInterrupt) + + receivedOpts = nil + wrapped() + assert.Empty(t, receivedOpts, "Should pass no extra options when no children") + + _ = parent.deriveChild(ctx) + + receivedOpts = nil + wrapped() + assert.Empty(t, receivedOpts, "Should pass no extra options when children are active but not recursive") + + parent.setRecursive(true) + + receivedOpts = nil + wrapped() + assert.Len(t, receivedOpts, 1, "Should add exactly one timeout option when children are active and recursive") + + receivedOpts = nil + callerOpt := compose.WithGraphInterruptTimeout(0) + wrapped(callerOpt) + assert.Len(t, receivedOpts, 2, + "Should append timeout option after caller-provided options when children are active and recursive") + // Note: verifying the exact timeout value (defaultCancelImmediateGracePeriod) + // requires access to unexported compose.graphInterruptOptions. The integration + // tests (TestCancelImmediate_AgentTool_PreservesChildCheckpoint) verify the + // actual behavioral effect — child interrupts propagate within the grace period. 
+ }) +} + +func TestCancel_ParallelWorkflow_CancelAfterChatModel(t *testing.T) { + ctx := context.Background() + slowStarted := make(chan struct{}, 1) + + slowAgent := newCancelTestAgentWithTools(t, "par_slow", 1*time.Second, slowStarted) + fastAgent := newCancelTestAgentWithTools(t, "par_fast", 50*time.Millisecond, make(chan struct{}, 1)) + + parAgent, err := NewParallelAgent(ctx, &ParallelAgentConfig{ + Name: "par_agent", + Description: "Parallel workflow", + SubAgents: []Agent{slowAgent, fastAgent}, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: parAgent, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt, WithCheckPointID("par-cancel-1")) + + select { + case <-slowStarted: + case <-time.After(10 * time.Second): + t.Fatal("Slow agent did not start") + } + + handle, contributed := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + assert.True(t, contributed, "Cancel should contribute") + err = handle.Wait() + assert.NoError(t, err) + + hasCancelError := false + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil && errors.As(event.Err, &cancelErr) { + hasCancelError = true + } + } + + assert.True(t, hasCancelError, "Should have CancelError from parallel workflow") + assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) + assert.NotEmpty(t, cancelErr.CheckPointID, "CancelError should have checkpoint ID") + + resumeSlow := newCancelTestAgentWithToolsFinalAnswer(t, "par_slow") + resumeFast := newCancelTestAgentWithToolsFinalAnswer(t, "par_fast") + + resumePar, err := NewParallelAgent(ctx, &ParallelAgentConfig{ + Name: "par_agent", + Description: "Parallel workflow", + SubAgents: []Agent{resumeSlow, resumeFast}, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: resumePar, + CheckPointStore: store, + }) + + 
resumeIter, err := runner2.Resume(ctx, "par-cancel-1") + assert.NoError(t, err) + assert.NotNil(t, resumeIter) + + var resumeErrors []error + for { + event, ok := resumeIter.Next() + if !ok { + break + } + if event.Err != nil { + resumeErrors = append(resumeErrors, event.Err) + } + } + assert.Empty(t, resumeErrors, "Resume should complete without errors") +} + +func TestCancel_LoopWorkflow_CancelAfterChatModel(t *testing.T) { + ctx := context.Background() + modelStarted := make(chan struct{}, 10) + + agent := newCancelTestAgentWithTools(t, "loop_inner", 500*time.Millisecond, modelStarted) + + loopAgent, err := NewLoopAgent(ctx, &LoopAgentConfig{ + Name: "loop_agent", + Description: "Loop workflow", + SubAgents: []Agent{agent}, + MaxIterations: 10, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: loopAgent, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt, WithCheckPointID("loop-cancel-1")) + + select { + case <-modelStarted: + case <-time.After(10 * time.Second): + t.Fatal("Model did not start") + } + + handle, contributed := cancelFn(WithAgentCancelMode(CancelAfterChatModel)) + assert.True(t, contributed, "Cancel should contribute") + err = handle.Wait() + assert.NoError(t, err) + + hasCancelError := false + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil && errors.As(event.Err, &cancelErr) { + hasCancelError = true + } + } + + assert.True(t, hasCancelError, "Should have CancelError from loop workflow") + assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) + assert.NotEmpty(t, cancelErr.CheckPointID, "CancelError should have checkpoint ID") + + resumeAgent := newCancelTestAgentWithToolsFinalAnswer(t, "loop_inner") + + resumeLoop, err := NewLoopAgent(ctx, &LoopAgentConfig{ + Name: "loop_agent", + Description: "Loop workflow", + SubAgents: 
[]Agent{resumeAgent}, + MaxIterations: 10, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: resumeLoop, + CheckPointStore: store, + }) + + resumeIter, err := runner2.Resume(ctx, "loop-cancel-1") + assert.NoError(t, err) + assert.NotNil(t, resumeIter) + + var resumeEvents []*AgentEvent + for { + event, ok := resumeIter.Next() + if !ok { + break + } + assert.Nil(t, event.Err, "Should not have error during resume") + resumeEvents = append(resumeEvents, event) + } + assert.NotEmpty(t, resumeEvents, "Resume should produce events") +} + +func TestCancel_NestedWorkflow_AgentTool_CancelAfterChatModel(t *testing.T) { + // Structure: Runner -> RootCMA (with tools) -> agentTool -> flowAgent -> seqWorkflow -> LeafCMA + ctx := context.Background() + leafStarted := make(chan struct{}, 1) + + leafModel := &cancelTestChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "leaf response", + }, + startedChan: leafStarted, + doneChan: make(chan struct{}, 1), + } + leafModel.setDelay(500 * time.Millisecond) + + leafAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "leaf_agent", + Description: "Leaf agent in workflow", + Model: leafModel, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "inner_seq", + Description: "Inner sequential workflow", + SubAgents: []Agent{leafAgent}, + }) + assert.NoError(t, err) + + rootModel := &cancelTestChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "inner_seq", + Arguments: `{}`, + }, + }, + }, + }, + startedChan: make(chan struct{}, 1), + doneChan: make(chan struct{}, 1), + } + rootAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "root_agent", + Description: "Root agent", + Model: rootModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ 
+ Tools: []tool.BaseTool{NewAgentTool(ctx, seqAgent)}, + }, + }, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: rootAgent, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("test")}, cancelOpt, WithCheckPointID("nested-cancel-1")) + + select { + case <-leafStarted: + case <-time.After(10 * time.Second): + t.Fatal("Leaf agent model did not start") + } + + handle, contributed := cancelFn(WithAgentCancelMode(CancelAfterChatModel), WithRecursive()) + assert.True(t, contributed, "Cancel should contribute") + err = handle.Wait() + assert.NoError(t, err) + + hasCancelError := false + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil && errors.As(event.Err, &cancelErr) { + hasCancelError = true + } + } + + assert.True(t, hasCancelError, "Should have CancelError from deeply nested workflow") + assert.Equal(t, CancelAfterChatModel, cancelErr.Info.Mode) + assert.NotEmpty(t, cancelErr.CheckPointID, "CancelError should have checkpoint ID") + assert.NotNil(t, cancelErr.interruptSignal, "CancelError should carry interrupt signal through agent tree") + + // Phase 2: Resume from checkpoint — new instances to avoid data races + resumeLeafModel := &cancelTestChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "resumed leaf response", + }, + startedChan: make(chan struct{}, 1), + doneChan: make(chan struct{}, 1), + } + resumeLeaf, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "leaf_agent", + Description: "Leaf agent in workflow", + Model: resumeLeafModel, + }) + assert.NoError(t, err) + + resumeSeq, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "inner_seq", + Description: "Inner sequential workflow", + SubAgents: []Agent{resumeLeaf}, + }) + assert.NoError(t, err) + + resumeRootModel := &cancelTestChatModel{ + response: 
&schema.Message{ + Role: schema.Assistant, + Content: "resumed root response", + }, + startedChan: make(chan struct{}, 1), + doneChan: make(chan struct{}, 1), + } + resumeRoot, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "root_agent", + Description: "Root agent", + Model: resumeRootModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{NewAgentTool(ctx, resumeSeq)}, + }, + }, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: resumeRoot, + CheckPointStore: store, + }) + + resumeIter, err := runner2.Resume(ctx, "nested-cancel-1") + assert.NoError(t, err) + assert.NotNil(t, resumeIter) + + var resumeErrors []error + for { + event, ok := resumeIter.Next() + if !ok { + break + } + if event.Err != nil { + resumeErrors = append(resumeErrors, event.Err) + } + } + assert.Empty(t, resumeErrors, "Resume should complete without errors") +} + +func TestCancel_CancelAfterToolCalls_InSequentialWorkflow(t *testing.T) { + ctx := context.Background() + toolStarted := make(chan struct{}, 1) + + st := &slowTool{ + name: "slow_tool", + delay: 200 * time.Millisecond, + result: "tool done", + startedChan: toolStarted, + } + + modelWithToolCall := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{ + Name: "slow_tool", + Arguments: `{"input": "test"}`, + }, + }, + }, + }, + } + + agentWithTools, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent_with_tools", + Description: "Agent with slow tool", + Model: modelWithToolCall, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{st}, + }, + }, + }) + assert.NoError(t, err) + + seqAgent, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq_agent", + Description: "Sequential workflow with tool agent", + SubAgents: 
[]Agent{agentWithTools}, + }) + assert.NoError(t, err) + + store := newCancelTestStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: seqAgent, + CheckPointStore: store, + }) + + cancelOpt, cancelFn := WithCancel() + iter := runner.Run(ctx, []Message{schema.UserMessage("Use the tool")}, cancelOpt, WithCheckPointID("tool-cancel-1")) + + select { + case <-toolStarted: + case <-time.After(10 * time.Second): + t.Fatal("Tool did not start") + } + + // Cancel after tool calls — should wait for the tool to finish, then cancel + handle, contributed := cancelFn(WithAgentCancelMode(CancelAfterToolCalls)) + assert.True(t, contributed, "Cancel should contribute") + err = handle.Wait() + assert.NoError(t, err) + + hasCancelError := false + var cancelErr *CancelError + for { + event, ok := iter.Next() + if !ok { + break + } + if event.Err != nil && errors.As(event.Err, &cancelErr) { + hasCancelError = true + } + } + + assert.True(t, hasCancelError, "Should have CancelError after tool calls complete") + assert.Equal(t, CancelAfterToolCalls, cancelErr.Info.Mode) + assert.NotEmpty(t, cancelErr.CheckPointID, "CancelError should have checkpoint ID") + + // Phase 2: Resume from checkpoint — new instances + resumeTool := &slowTool{ + name: "slow_tool", + delay: 50 * time.Millisecond, + result: "resumed tool done", + startedChan: make(chan struct{}, 1), + } + + resumeModel := &simpleChatModel{ + response: &schema.Message{ + Role: schema.Assistant, + Content: "resumed response after tool", + }, + } + + resumeAgentWithTools, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "agent_with_tools", + Description: "Agent with slow tool", + Model: resumeModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{resumeTool}, + }, + }, + }) + assert.NoError(t, err) + + resumeSeq, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "seq_agent", + Description: "Sequential workflow with tool agent", + SubAgents: 
[]Agent{resumeAgentWithTools}, + }) + assert.NoError(t, err) + + runner2 := NewRunner(ctx, RunnerConfig{ + Agent: resumeSeq, + CheckPointStore: store, + }) + + resumeIter, err := runner2.Resume(ctx, "tool-cancel-1") + assert.NoError(t, err) + assert.NotNil(t, resumeIter) + + var resumeEvents []*AgentEvent + for { + event, ok := resumeIter.Next() + if !ok { + break + } + assert.Nil(t, event.Err, "Should not have error during resume") + resumeEvents = append(resumeEvents, event) + } + assert.NotEmpty(t, resumeEvents, "Resume should produce events") +} diff --git a/adk/chatmodel.go b/adk/chatmodel.go index 56993a7b2..abfc55fa0 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -38,15 +38,25 @@ import ( "github.com/cloudwego/eino/schema" ) +var _ ResumableAgent = &ChatModelAgent{} + type chatModelAgentExecCtx struct { runtimeReturnDirectly map[string]bool generator *AsyncGenerator[*AgentEvent] + cancelCtx *cancelContext + + // failoverLastSuccessModel is the last success model only used in failover middleware. + failoverLastSuccessModel model.BaseChatModel } func (e *chatModelAgentExecCtx) send(event *AgentEvent) { - if e != nil && e.generator != nil { - e.generator.Send(event) + if e == nil || e.generator == nil { + return } + if e.cancelCtx != nil && e.cancelCtx.isImmediateCancelled() { + return + } + e.generator.trySend(event) } type chatModelAgentExecCtxKey struct{} @@ -253,13 +263,14 @@ type ChatModelAgentConfig struct { // Model call lifecycle (outermost to innermost wrapper chain): // 1. AgentMiddleware.BeforeChatModel (hook, runs before model call) // 2. ChatModelAgentMiddleware.BeforeModelRewriteState (hook, can modify state before model call) - // 3. retryModelWrapper (internal - retries on failure, if configured) - // 4. eventSenderModelWrapper (internal - sends model response events) - // 5. ChatModelAgentMiddleware.WrapModel (wrapper, first registered is outermost) - // 6. 
callbackInjectionModelWrapper (internal - injects callbacks if not enabled) - // 7. Model.Generate/Stream - // 8. ChatModelAgentMiddleware.AfterModelRewriteState (hook, can modify state after model call) - // 9. AgentMiddleware.AfterChatModel (hook, runs after model call) + // 3. failoverModelWrapper (internal - failover between models, if configured) + // 4. retryModelWrapper (internal - retries on failure, if configured) + // 5. eventSenderModelWrapper (internal - sends model response events) + // 6. ChatModelAgentMiddleware.WrapModel (wrapper, first registered is outermost) + // 7. callbackInjectionModelWrapper (internal - injects callbacks if not enabled; when failover is enabled, this is handled per-model inside failoverProxyModel instead) + // 8. failoverProxyModel (internal - dispatches to selected failover model, if configured) / Model.Generate/Stream + // 9. ChatModelAgentMiddleware.AfterModelRewriteState (hook, can modify state after model call) + // 10. AgentMiddleware.AfterChatModel (hook, runs after model call) // // Custom Event Sender Position: // By default, events are sent after all user middlewares (WrapModel) have processed the output, @@ -281,13 +292,35 @@ type ChatModelAgentConfig struct { // the default event sender to avoid duplicate events. // // Tool call lifecycle (outermost to innermost): - // 1. eventSenderToolHandler (internal ToolMiddleware - sends tool result events after all processing) + // 1. eventSenderToolWrapper (internal ToolMiddleware - sends tool result events after all processing) // 2. ToolsConfig.ToolCallMiddlewares (ToolMiddleware) // 3. AgentMiddleware.WrapToolCall (ToolMiddleware) // 4. ChatModelAgentMiddleware.WrapToolCall (wrapper, first registered is outermost) // 5. callbackInjectedToolCall (internal - injects callbacks if tool doesn't handle them) // 6. 
Tool.InvokableRun/StreamableRun // + // Custom Tool Event Sender Position: + // By default, tool result events are emitted by an internal event sender placed before + // all user middlewares (outermost), so events reflect the fully processed tool output. + // To control exactly where in the handler chain tool events are emitted, pass + // NewEventSenderToolWrapper() as one of the Handlers. Its position determines which + // middlewares' effects are visible in the emitted event: + // + // agent, _ := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ + // Handlers: []adk.ChatModelAgentMiddleware{ + // loggingHandler, // Outermost: sees event-sender output + // adk.NewEventSenderToolWrapper(), // Events reflect output from handlers below + // sanitizationHandler, // Innermost: runs first, modifies tool output + // }, + // }) + // + // Handler order: first registered is outermost. So [A, B, C] wraps as A(B(C(tool))). + // The event sender captures tool output in post-processing, so its position controls + // which handlers' modifications are included in the emitted events. + // + // When NewEventSenderToolWrapper is detected in Handlers, the framework skips + // the default event sender to avoid duplicate events. + // // Tool List Modification: // // There are two ways to modify the tool list: @@ -308,6 +341,13 @@ type ChatModelAgentConfig struct { // based on the configured policy. // Optional. If nil, no retry will be performed. ModelRetryConfig *ModelRetryConfig + + // ModelFailoverConfig configures failover behavior for the ChatModel. + // When set, the agent will first try the last successful model (initially the configured Model), + // and on failure, call GetFailoverModel to select alternate models. + // Model field is still required as it serves as the initial model. + // Optional. If nil, no failover will be performed. 
+ ModelFailoverConfig *ModelFailoverConfig } type ChatModelAgent struct { @@ -333,7 +373,8 @@ type ChatModelAgent struct { handlers []ChatModelAgentMiddleware middlewares []AgentMiddleware - modelRetryConfig *ModelRetryConfig + modelRetryConfig *ModelRetryConfig + modelFailoverConfig *ModelFailoverConfig once sync.Once run runFunc @@ -341,10 +382,33 @@ type ChatModelAgent struct { exeCtx *execContext } -type runFunc func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, instruction string, returnDirectly map[string]bool, opts ...compose.Option) +// runParams holds the parameters for a runFunc invocation. +type runParams struct { + input *AgentInput + generator *AsyncGenerator[*AgentEvent] + store *bridgeStore + instruction string + returnDirectly map[string]bool + cancelCtx *cancelContext + cancelCtxOwned bool + composeOpts []compose.Option +} + +type runFunc func(ctx context.Context, p *runParams) // NewChatModelAgent constructs a chat model-backed agent with the provided config. func NewChatModelAgent(ctx context.Context, config *ChatModelAgentConfig) (*ChatModelAgent, error) { + if config.ModelFailoverConfig != nil { + if config.ModelFailoverConfig.GetFailoverModel == nil { + return nil, errors.New("ModelFailoverConfig.GetFailoverModel is required when ModelFailoverConfig is set") + } + + // ShouldFailover is required when ModelFailoverConfig is set + if config.ModelFailoverConfig.ShouldFailover == nil { + return nil, errors.New("ModelFailoverConfig.ShouldFailover is required when ModelFailoverConfig is set") + } + } + if config.Model == nil { return nil, errors.New("agent 'Model' is required") } @@ -357,35 +421,41 @@ func NewChatModelAgent(ctx context.Context, config *ChatModelAgentConfig) (*Chat tc := config.ToolsConfig // Tool call middleware execution order (outermost to innermost): - // 1. eventSenderToolHandler (internal - sends tool result events after all modifications) + // 1. 
eventSenderToolWrapper (internal - sends tool result events after all modifications) // 2. User-provided ToolsConfig.ToolCallMiddlewares (original order preserved) // 3. Middlewares' WrapToolCall (in registration order) // 4. ChatModelAgentMiddleware.WrapToolCall (in registration order) // 5. callbackInjectedToolCall (internal - injects callbacks if tool doesn't handle them) - eventSender := &eventSenderToolHandler{} - tc.ToolCallMiddlewares = append( - []compose.ToolMiddleware{{Invokable: eventSender.WrapInvokableToolCall, - Streamable: eventSender.WrapStreamableToolCall, - EnhancedInvokable: eventSender.WrapEnhancedInvokableToolCall, - EnhancedStreamable: eventSender.WrapEnhancedStreamableToolCall, - }}, - tc.ToolCallMiddlewares..., - ) + if !hasUserEventSenderToolWrapper(config.Handlers) { + defaultToolEventSender := handlersToToolMiddlewares([]ChatModelAgentMiddleware{NewEventSenderToolWrapper()}) + tc.ToolCallMiddlewares = append(defaultToolEventSender, tc.ToolCallMiddlewares...) + } tc.ToolCallMiddlewares = append(tc.ToolCallMiddlewares, collectToolMiddlewaresFromMiddlewares(config.Middlewares)...) + // Cancel monitoring middleware (innermost — close to the tool endpoint). + // This allows early abort of the raw tool result stream when immediateChan fires + // (CancelImmediate or timeout escalation), while requiring outer wrappers to + // propagate stream errors such as ErrStreamCanceled without swallowing them. 
+ cancelToolHandler := &cancelMonitoredToolHandler{} + tc.ToolCallMiddlewares = append(tc.ToolCallMiddlewares, compose.ToolMiddleware{ + Streamable: cancelToolHandler.WrapStreamableToolCall, + EnhancedStreamable: cancelToolHandler.WrapEnhancedStreamableToolCall, + }) + return &ChatModelAgent{ - name: config.Name, - description: config.Description, - instruction: config.Instruction, - model: config.Model, - toolsConfig: tc, - genModelInput: genInput, - exit: config.Exit, - outputKey: config.OutputKey, - maxIterations: config.MaxIterations, - handlers: config.Handlers, - middlewares: config.Middlewares, - modelRetryConfig: config.ModelRetryConfig, + name: config.Name, + description: config.Description, + instruction: config.Instruction, + model: config.Model, + toolsConfig: tc, + genModelInput: genInput, + exit: config.Exit, + outputKey: config.OutputKey, + maxIterations: config.MaxIterations, + handlers: config.Handlers, + middlewares: config.Middlewares, + modelRetryConfig: config.ModelRetryConfig, + modelFailoverConfig: config.ModelFailoverConfig, }, nil } @@ -570,8 +640,8 @@ func setOutputToSession(ctx context.Context, msg Message, msgStream MessageStrea } func errFunc(err error) runFunc { - return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, _ string, _ map[string]bool, _ ...compose.Option) { - generator.Send(&AgentEvent{Err: err}) + return func(ctx context.Context, p *runParams) { + p.generator.Send(&AgentEvent{Err: err}) } } @@ -613,8 +683,12 @@ func (a *ChatModelAgent) applyBeforeAgent(ctx context.Context, ec *execContext) runtimeEC := &execContext{ instruction: runCtx.Instruction, toolsNodeConf: compose.ToolsNodeConfig{ - Tools: runCtx.Tools, - ToolCallMiddlewares: cloneSlice(ec.toolsNodeConf.ToolCallMiddlewares), + Tools: runCtx.Tools, + ToolCallMiddlewares: cloneSlice(ec.toolsNodeConf.ToolCallMiddlewares), + UnknownToolsHandler: ec.toolsNodeConf.UnknownToolsHandler, + ExecuteSequentially: 
ec.toolsNodeConf.ExecuteSequentially, + ToolArgumentsHandler: ec.toolsNodeConf.ToolArgumentsHandler, + ToolAliases: ec.toolsNodeConf.ToolAliases, }, returnDirectly: runCtx.ReturnDirectly, toolUpdated: true, @@ -640,6 +714,7 @@ func (a *ChatModelAgent) prepareExecContext(ctx context.Context) (*execContext, UnknownToolsHandler: a.toolsConfig.UnknownToolsHandler, ExecuteSequentially: a.toolsConfig.ExecuteSequentially, ToolArgumentsHandler: a.toolsConfig.ToolArgumentsHandler, + ToolAliases: a.toolsConfig.ToolAliases, } returnDirectly := copyMap(a.toolsConfig.ReturnDirectly) @@ -691,20 +766,74 @@ func (a *ChatModelAgent) prepareExecContext(ctx context.Context) (*execContext, }, nil } -func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { - wrappedModel := buildModelWrappers(a.model, &modelWrapperConfig{ - handlers: a.handlers, - middlewares: a.middlewares, - retryConfig: a.modelRetryConfig, - }) +// handleRunFuncError is the common error handler for buildNoToolsRunFunc and buildReActRunFunc. +// It handles compose interrupts (both cancel-triggered and business) +// and generic errors, sending the appropriate event to the generator. +func (a *ChatModelAgent) handleRunFuncError( + ctx context.Context, + err error, + cancelCtx *cancelContext, + cancelCtxOwned bool, + store *bridgeStore, + generator *AsyncGenerator[*AgentEvent], +) { + info, ok := compose.ExtractInterruptInfo(err) + if ok { + if cancelCtx != nil { + // Note: there is a benign TOCTOU window here. Between shouldCancel() + // returning false and markDone() executing, a concurrent cancel could + // transition stateRunning→stateCancelling. markDone() then does + // stateCancelling→stateDone, and the cancel func receives + // ErrExecutionCompleted (execution finished before cancel took effect). 
+ if !cancelCtx.shouldCancel() { + cancelCtx.markDone() + } + } + + data, existed, sErr := store.Get(ctx, bridgeCheckpointID) + if sErr != nil { + generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("failed to get interrupt info: %w", sErr)}) + return + } + if !existed { + generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("interrupt occurred but checkpoint data is missing")}) + return + } + + is := FromInterruptContexts(info.InterruptContexts) + event := CompositeInterrupt(ctx, info, data, is) + event.Action.Interrupted.Data = &ChatModelAgentInterruptInfo{ + Info: info, + Data: data, + } + event.AgentName = a.name + generator.Send(event) + return + } + + if cancelCtxOwned && cancelCtx != nil { + cancelCtx.markDone() + } + generator.Send(&AgentEvent{Err: err}) +} +func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { type noToolsInput struct { input *AgentInput instruction string } - return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], - store *bridgeStore, instruction string, _ map[string]bool, opts ...compose.Option) { + return func(ctx context.Context, p *runParams) { + cancelCtx := p.cancelCtx + ctx = withCancelContext(ctx, cancelCtx) + + wrappedModel := buildModelWrappers(a.model, &modelWrapperConfig{ + handlers: a.handlers, + middlewares: a.middlewares, + retryConfig: a.modelRetryConfig, + failoverConfig: a.modelFailoverConfig, + cancelContext: cancelCtx, + }) chain := compose.NewChain[noToolsInput, Message]( compose.WithGenLocalState(func(ctx context.Context) (state *State) { @@ -715,56 +844,86 @@ func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc { if err != nil { return nil, err } + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + st.Messages = append(st.Messages, messages...) + return nil + }) return messages, nil })). 
AppendChatModel(wrappedModel) - r, err := chain.Compile(ctx, compose.WithGraphName(a.name), - compose.WithCheckPointStore(store), + var compileOptions []compose.GraphCompileOption + compileOptions = append(compileOptions, + compose.WithGraphName(a.name), + compose.WithCheckPointStore(p.store), compose.WithSerializer(&gobSerializer{})) + + if cancelCtx != nil { + var interrupt func(...compose.GraphInterruptOption) + ctx, interrupt = compose.WithGraphInterrupt(ctx) + cancelCtx.setGraphInterruptFunc(cancelCtx.wrapGraphInterruptWithGracePeriod(interrupt)) + } + + r, err := chain.Compile(ctx, compileOptions...) if err != nil { - generator.Send(&AgentEvent{Err: err}) + p.generator.Send(&AgentEvent{Err: err}) return } ctx = withChatModelAgentExecCtx(ctx, &chatModelAgentExecCtx{ - generator: generator, + generator: p.generator, + cancelCtx: cancelCtx, + failoverLastSuccessModel: a.model, }) - in := noToolsInput{input: input, instruction: instruction} + // Pre-execution cancel check + if cancelCtx != nil && cancelCtx.shouldCancel() { + if cancelCtx.getMode() == CancelImmediate || atomic.LoadInt32(&cancelCtx.escalated) == 1 { + cancelErr, ok := cancelCtx.createAndMarkCancelHandled() + if !ok { + return + } + p.generator.Send(&AgentEvent{Err: cancelErr}) + return + } + } + + in := noToolsInput{input: p.input, instruction: p.instruction} var msg Message var msgStream MessageStream - if input.EnableStreaming { - msgStream, err = r.Stream(ctx, in, opts...) + if p.input.EnableStreaming { + msgStream, err = r.Stream(ctx, in, p.composeOpts...) } else { - msg, err = r.Invoke(ctx, in, opts...) + msg, err = r.Invoke(ctx, in, p.composeOpts...) 
} if err == nil { if a.outputKey != "" { err = setOutputToSession(ctx, msg, msgStream, a.outputKey) if err != nil { - generator.Send(&AgentEvent{Err: err}) + p.generator.Send(&AgentEvent{Err: err}) } } else if msgStream != nil { msgStream.Close() } - } else { - generator.Send(&AgentEvent{Err: err}) + return } + + a.handleRunFuncError(ctx, err, cancelCtx, p.cancelCtxOwned, p.store, p.generator) } } -func (a *ChatModelAgent) buildReactRunFunc(ctx context.Context, bc *execContext) (runFunc, error) { +func (a *ChatModelAgent) buildReActRunFunc(_ context.Context, bc *execContext) (runFunc, error) { conf := &reactConfig{ model: a.model, toolsConfig: &bc.toolsNodeConf, modelWrapperConf: &modelWrapperConfig{ - handlers: a.handlers, - middlewares: a.middlewares, - retryConfig: a.modelRetryConfig, - toolInfos: bc.toolInfos, + handlers: a.handlers, + middlewares: a.middlewares, + retryConfig: a.modelRetryConfig, + failoverConfig: a.modelFailoverConfig, + toolInfos: bc.toolInfos, }, toolsReturnDirectly: bc.returnDirectly, agentName: a.name, @@ -776,11 +935,17 @@ func (a *ChatModelAgent) buildReactRunFunc(ctx context.Context, bc *execContext) instruction string } - return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, - instruction string, returnDirectly map[string]bool, opts ...compose.Option) { + return func(ctx context.Context, p *runParams) { + cancelCtx := p.cancelCtx + conf.cancelCtx = cancelCtx + if conf.modelWrapperConf != nil { + conf.modelWrapperConf.cancelContext = cancelCtx + } + ctx = withCancelContext(ctx, cancelCtx) + g, err := newReact(ctx, conf) if err != nil { - generator.Send(&AgentEvent{Err: err}) + p.generator.Send(&AgentEvent{Err: err}) return } @@ -792,7 +957,7 @@ func (a *ChatModelAgent) buildReactRunFunc(ctx context.Context, bc *execContext) return nil, genErr } return &reactInput{ - messages: messages, + Messages: messages, }, nil }), ). 
@@ -801,38 +966,58 @@ func (a *ChatModelAgent) buildReactRunFunc(ctx context.Context, bc *execContext) var compileOptions []compose.GraphCompileOption compileOptions = append(compileOptions, compose.WithGraphName(a.name), - compose.WithCheckPointStore(store), + compose.WithCheckPointStore(p.store), compose.WithSerializer(&gobSerializer{}), compose.WithMaxRunSteps(math.MaxInt)) + if cancelCtx != nil { + var interrupt func(...compose.GraphInterruptOption) + ctx, interrupt = compose.WithGraphInterrupt(ctx) + cancelCtx.setGraphInterruptFunc(cancelCtx.wrapGraphInterruptWithGracePeriod(interrupt)) + } + runnable, err_ := chain.Compile(ctx, compileOptions...) if err_ != nil { - generator.Send(&AgentEvent{Err: err_}) + p.generator.Send(&AgentEvent{Err: err_}) return } ctx = withChatModelAgentExecCtx(ctx, &chatModelAgentExecCtx{ - runtimeReturnDirectly: returnDirectly, - generator: generator, + runtimeReturnDirectly: p.returnDirectly, + generator: p.generator, + cancelCtx: cancelCtx, + failoverLastSuccessModel: a.model, }) + // Pre-execution cancel check + if cancelCtx != nil && cancelCtx.shouldCancel() { + if cancelCtx.getMode() == CancelImmediate || atomic.LoadInt32(&cancelCtx.escalated) == 1 { + cancelErr, ok := cancelCtx.createAndMarkCancelHandled() + if !ok { + return + } + p.generator.Send(&AgentEvent{Err: cancelErr}) + return + } + } + in := reactRunInput{ - input: input, - instruction: instruction, + input: p.input, + instruction: p.instruction, } var runOpts []compose.Option - runOpts = append(runOpts, opts...) + runOpts = append(runOpts, p.composeOpts...) 
if a.toolsConfig.EmitInternalEvents { - runOpts = append(runOpts, compose.WithToolsNodeOption(compose.WithToolOption(withAgentToolEventGenerator(generator)))) + runOpts = append(runOpts, compose.WithToolsNodeOption(compose.WithToolOption(withAgentToolEventGenerator(p.generator)))) } - if input.EnableStreaming { + if p.input.EnableStreaming { runOpts = append(runOpts, compose.WithToolsNodeOption(compose.WithToolOption(withAgentToolEnableStreaming(true)))) } var msg Message var msgStream MessageStream - if input.EnableStreaming { + if p.input.EnableStreaming { msgStream, err_ = runnable.Stream(ctx, in, runOpts...) } else { msg, err_ = runnable.Invoke(ctx, in, runOpts...) @@ -842,7 +1027,7 @@ func (a *ChatModelAgent) buildReactRunFunc(ctx context.Context, bc *execContext) if a.outputKey != "" { err_ = setOutputToSession(ctx, msg, msgStream, a.outputKey) if err_ != nil { - generator.Send(&AgentEvent{Err: err_}) + p.generator.Send(&AgentEvent{Err: err_}) } } else if msgStream != nil { msgStream.Close() @@ -851,31 +1036,7 @@ func (a *ChatModelAgent) buildReactRunFunc(ctx context.Context, bc *execContext) return } - info, ok := compose.ExtractInterruptInfo(err_) - if !ok { - generator.Send(&AgentEvent{Err: err_}) - return - } - - data, existed, err := store.Get(ctx, bridgeCheckpointID) - if err != nil { - generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("failed to get interrupt info: %w", err)}) - return - } - if !existed { - generator.Send(&AgentEvent{AgentName: a.name, Err: fmt.Errorf("interrupt occurred but checkpoint data is missing")}) - return - } - - is := FromInterruptContexts(info.InterruptContexts) - - event := CompositeInterrupt(ctx, info, data, is) - event.Action.Interrupted.Data = &ChatModelAgentInterruptInfo{ - Info: info, - Data: data, - } - event.AgentName = a.name - generator.Send(event) + a.handleRunFuncError(ctx, err_, cancelCtx, p.cancelCtxOwned, p.store, p.generator) }, nil } @@ -894,7 +1055,7 @@ func (a *ChatModelAgent) buildRunFunc(ctx 
context.Context) runFunc { return } - run, err := a.buildReactRunFunc(ctx, ec) + run, err := a.buildReActRunFunc(ctx, ec) if err != nil { a.run = errFunc(err) return @@ -938,7 +1099,7 @@ func (a *ChatModelAgent) getRunFunc(ctx context.Context) (context.Context, runFu if len(runtimeBC.toolsNodeConf.Tools) == 0 { tempRun = a.buildNoToolsRunFunc(ctx) } else { - tempRun, err = a.buildReactRunFunc(ctx, runtimeBC) + tempRun, err = a.buildReActRunFunc(ctx, runtimeBC) if err != nil { return ctx, nil, nil, err } @@ -950,10 +1111,20 @@ func (a *ChatModelAgent) getRunFunc(ctx context.Context) (context.Context, runFu func (a *ChatModelAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + o := getCommonOptions(nil, opts...) + cancelCtx := o.cancelCtx + cancelCtxOwned := cancelCtx != nil && getCancelContext(ctx) == nil + if cancelCtx == nil { + cancelCtx = getCancelContext(ctx) + } + ctx, run, bc, err := a.getRunFunc(ctx) if err != nil { go func() { - generator.Send(&AgentEvent{Err: err}) + if cancelCtxOwned && cancelCtx != nil { + defer cancelCtx.markDone() + } + generator.Send(&AgentEvent{Err: fmt.Errorf("ChatModelAgent getRunFunc error: %w", err)}) generator.Close() }() return iterator @@ -990,19 +1161,41 @@ func (a *ChatModelAgent) Run(ctx context.Context, input *AgentInput, opts ...Age returnDirectly = bc.returnDirectly } - run(ctx, input, generator, newBridgeStore(), instruction, returnDirectly, co...) 
+ run(ctx, &runParams{ + input: input, + generator: generator, + store: newBridgeStore(), + instruction: instruction, + returnDirectly: returnDirectly, + cancelCtx: cancelCtx, + cancelCtxOwned: cancelCtxOwned, + composeOpts: co, + }) }() + if cancelCtxOwned { + return wrapIterWithCancelCtx(iterator, cancelCtx) + } return iterator } func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + o := getCommonOptions(nil, opts...) + cancelCtx := o.cancelCtx + cancelCtxOwned := cancelCtx != nil && getCancelContext(ctx) == nil + if cancelCtx == nil { + cancelCtx = getCancelContext(ctx) + } + ctx, run, bc, err := a.getRunFunc(ctx) if err != nil { go func() { - generator.Send(&AgentEvent{Err: err}) + if cancelCtxOwned && cancelCtx != nil { + defer cancelCtx.markDone() + } + generator.Send(&AgentEvent{Err: fmt.Errorf("ChatModelAgent getRunFunc error: %w", err)}) generator.Close() }() return iterator @@ -1018,6 +1211,10 @@ func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...A } } + if info == nil { + panic(fmt.Sprintf("ChatModelAgent.Resume: agent '%s' was asked to resume but info is nil", a.Name(ctx))) + } + if info.InterruptState == nil { panic(fmt.Sprintf("ChatModelAgent.Resume: agent '%s' was asked to resume but has no state", a.Name(ctx))) } @@ -1085,10 +1282,21 @@ func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...A returnDirectly = bc.returnDirectly } - run(ctx, &AgentInput{EnableStreaming: info.EnableStreaming}, generator, - newResumeBridgeStore(stateByte), instruction, returnDirectly, co...) 
+ run(ctx, &runParams{ + input: &AgentInput{EnableStreaming: info.EnableStreaming}, + generator: generator, + store: newResumeBridgeStore(bridgeCheckpointID, stateByte), + instruction: instruction, + returnDirectly: returnDirectly, + cancelCtx: cancelCtx, + cancelCtxOwned: cancelCtxOwned, + composeOpts: co, + }) }() + if cancelCtxOwned { + return wrapIterWithCancelCtx(iterator, cancelCtx) + } return iterator } diff --git a/adk/chatmodel_retry_test.go b/adk/chatmodel_retry_test.go index 00c89b352..0cb2a87bd 100644 --- a/adk/chatmodel_retry_test.go +++ b/adk/chatmodel_retry_test.go @@ -1046,3 +1046,148 @@ func TestSequentialWorkflow_NoRetryConfig_StreamError_StopsFlow(t *testing.T) { assert.Equal(t, 0, len(capturingModel.capturedInputs), "Agent B should NOT be called due to error") assert.Equal(t, int32(1), atomic.LoadInt32(&noRetryModel.callCount), "Model should only be called once (no retry)") } + +// failThenToolCallStreamModel is a ChatModel that: +// - First Stream() call: yields a partial chunk then fails with a retryable error mid-stream. +// - Second Stream() call (retry): yields a tool-call message (success). +// - Third Generate() call (after tool result): yields a final assistant message. +// +// This exercises the path where the eventSenderModel copies the first stream, +// wraps its error as WillRetryError, and sends it as an event to the session. +// The retryModelWrapper then retries, gets a clean stream with a tool call, +// the tool interrupts, and checkpoint save needs to gob-encode the session +// (which still contains the unconsumed WillRetryError event stream). 
+type failThenToolCallStreamModel struct { + streamCallCount int32 + genCallCount int32 +} + +func (m *failThenToolCallStreamModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m.genCallCount, 1) + return schema.AssistantMessage("final answer", nil), nil +} + +func (m *failThenToolCallStreamModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + count := atomic.AddInt32(&m.streamCallCount, 1) + + sr, sw := schema.Pipe[*schema.Message](10) + go func() { + defer sw.Close() + if count == 1 { + // First call: yield a partial chunk then fail. + sw.Send(schema.AssistantMessage("partial", nil), nil) + sw.Send(nil, errRetryAble) + return + } + // Second call (retry): yield a tool-call message. + sw.Send(schema.AssistantMessage("", []schema.ToolCall{{ + ID: "call-1", + Function: schema.FunctionCall{ + Name: "interrupt_tool", + Arguments: `{}`, + }, + }}), nil) + }() + return sr, nil +} + +func (m *failThenToolCallStreamModel) WithTools(_ []*schema.ToolInfo) (model.ToolCallingChatModel, error) { + return m, nil +} + +// interruptToolForRetryTest is a tool that always interrupts. +type interruptToolForRetryTest struct{} + +func (t *interruptToolForRetryTest) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: "interrupt_tool", + Desc: "tool that interrupts", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Type: "string"}, + }), + }, nil +} + +func (t *interruptToolForRetryTest) InvokableRun(ctx context.Context, _ string, _ ...tool.Option) (string, error) { + return "", tool.Interrupt(ctx, "interrupted by tool") +} + +// TestCheckpointSave_WillRetryError_StreamNotConsumed verifies that checkpoint +// saving succeeds when the session contains an event with an unconsumed stream +// that ends with WillRetryError. +// +// Scenario: +// 1. 
ChatModelAgent with retry (MaxRetries=1) and a tool that always interrupts +// 2. Model.Stream() #1 yields "partial" then errRetryAble mid-stream +// → eventSenderModel copies the stream, wraps the error as WillRetryError, +// sends the event to the session (stream NOT consumed by anyone yet) +// → retryModelWrapper detects error on its copy, retries +// 3. Model.Stream() #2 succeeds with a tool-call message +// 4. Tool executes → interrupts +// 5. Runner.handleIter sees the interrupt → saveCheckPoint → gob encodes runSession +// 6. The session has the WillRetryError event with an unconsumed stream +// → agentEventWrapper.GobEncode proactively consumes the stream via +// getMessageFromWrappedEvent, so MessageVariant.GobEncode sees an error-free +// array and succeeds +func TestCheckpointSave_WillRetryError_StreamNotConsumed(t *testing.T) { + ctx := context.Background() + + mdl := &failThenToolCallStreamModel{} + itool := &interruptToolForRetryTest{} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Agent for checkpoint stream error test", + Instruction: "You are a test agent.", + Model: mdl, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{itool}, + }, + }, + ModelRetryConfig: &ModelRetryConfig{ + MaxRetries: 1, + IsRetryAble: func(_ context.Context, err error) bool { + return errors.Is(err, errRetryAble) + }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { + return time.Millisecond // fast retry for test + }, + }, + }) + assert.NoError(t, err) + + store := newMyStore() + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: true, + CheckPointStore: store, + }) + + iter := runner.Run(ctx, + []Message{schema.UserMessage("hello")}, + WithCheckPointID("ckpt-1"), + ) + + var events []*AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + events = append(events, event) + + if event.Err != nil { + t.Logf("event error: %v", 
event.Err) + } + } + + // Verify the checkpoint was saved successfully. + _, exists, _ := store.Get(ctx, "ckpt-1") + assert.True(t, exists, "checkpoint should be saved successfully; "+ + "if this fails, the WillRetryError stream in the session caused gob encoding to fail") + + // Sanity: the model should have been called twice for Stream (fail + retry). + assert.Equal(t, int32(2), atomic.LoadInt32(&mdl.streamCallCount), + "model should be called twice: first fail, then retry success") +} diff --git a/adk/chatmodel_test.go b/adk/chatmodel_test.go index 3a2f920dd..0edab5a2d 100644 --- a/adk/chatmodel_test.go +++ b/adk/chatmodel_test.go @@ -18,11 +18,13 @@ package adk import ( "context" + "encoding/json" "errors" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "github.com/cloudwego/eino/components/model" @@ -2057,3 +2059,359 @@ func TestPreprocessComposeCheckpoint_MigrateErrorIsReturned(t *testing.T) { _, err := preprocessComposeCheckpoint(in) assert.Error(t, err) } + +func TestNewChatModelAgent_FailoverConfigValidation(t *testing.T) { + ctx := context.Background() + cm := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return schema.AssistantMessage("ok", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("ok", nil)}), nil + }, + } + + t.Run("missing GetFailoverModel", func(t *testing.T) { + _, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: cm, + ModelFailoverConfig: &ModelFailoverConfig{ + ShouldFailover: func(context.Context, *schema.Message, error) bool { return true }, + }, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "ModelFailoverConfig.GetFailoverModel") 
+ }) + + t.Run("missing ShouldFailover", func(t *testing.T) { + _, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "test", + Model: cm, + ModelFailoverConfig: &ModelFailoverConfig{ + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return cm, nil, nil + }, + }, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "ModelFailoverConfig.ShouldFailover") + }) +} + +// aliasCaptureTool captures the raw arguments JSON received by the tool. +type aliasCaptureTool struct { + name string + params map[string]*schema.ParameterInfo + receivedArgs string +} + +func (t *aliasCaptureTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: t.name + " tool", + ParamsOneOf: schema.NewParamsOneOfByParams(t.params), + }, nil +} + +func (t *aliasCaptureTool) InvokableRun(_ context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) { + t.receivedArgs = argumentsInJSON + return "ok", nil +} + +func TestToolAliasesPropagation(t *testing.T) { + t.Run("prepareExecContext_propagates_ToolAliases", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + + captureTool := &aliasCaptureTool{ + name: "grep", + params: map[string]*schema.ParameterInfo{ + "pattern": {Type: schema.String, Desc: "regex pattern"}, + "path": {Type: schema.String, Desc: "search path"}, + }, + } + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + generateCount := 0 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { + generateCount++ + if generateCount == 1 { + return &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "grep", + Arguments: `{"grep_content": "TODO", "path": "/src"}`, + }, + }, + }, + }, nil + } + return schema.AssistantMessage("done", nil), nil + }).AnyTimes() + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "test", + Instruction: "test", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{captureTool}, + ToolAliases: map[string]compose.ToolAliasConfig{ + "grep": { + ArgumentsAliases: map[string][]string{ + "pattern": {"grep_content"}, + }, + }, + }, + }, + }, + }) + require.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("search for TODOs")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + require.NotEmpty(t, captureTool.receivedArgs, "tool should have been called") + var args map[string]any + err = json.Unmarshal([]byte(captureTool.receivedArgs), &args) + require.NoError(t, err) + assert.Equal(t, "TODO", args["pattern"], "alias 'grep_content' should be remapped to 'pattern'") + assert.NotContains(t, args, "grep_content", "alias key should not be present after remapping") + assert.Equal(t, "/src", args["path"]) + }) + + t.Run("applyBeforeAgent_preserves_ToolAliases_when_handler_modifies_tools", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + + captureTool := &aliasCaptureTool{ + name: "grep", + params: map[string]*schema.ParameterInfo{ + "pattern": {Type: schema.String, Desc: "regex pattern"}, + }, + } + + extraTool := &aliasCaptureTool{ + name: "extra_tool", + params: map[string]*schema.ParameterInfo{ + "input": {Type: 
schema.String}, + }, + } + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + generateCount := 0 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { + generateCount++ + if generateCount == 1 { + return &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "grep", + Arguments: `{"grep_content": "FIXME"}`, + }, + }, + }, + }, nil + } + return schema.AssistantMessage("done", nil), nil + }).AnyTimes() + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + handler := &testToolsHandler{ + BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, + tools: []tool.BaseTool{extraTool}, + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "test", + Instruction: "test", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{captureTool}, + ToolAliases: map[string]compose.ToolAliasConfig{ + "grep": { + ArgumentsAliases: map[string][]string{ + "pattern": {"grep_content"}, + }, + }, + }, + }, + }, + Handlers: []ChatModelAgentMiddleware{handler}, + }) + require.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("search for FIXMEs")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + require.NotEmpty(t, captureTool.receivedArgs, "tool should have been called") + var args map[string]any + err = json.Unmarshal([]byte(captureTool.receivedArgs), &args) + require.NoError(t, err) + assert.Equal(t, "FIXME", args["pattern"], "alias 'grep_content' should be remapped to 'pattern' even after handler rebuild") + assert.NotContains(t, args, "grep_content", "alias key should not be present after remapping") + }) + + t.Run("name_alias_propagated_through_prepareExecContext", func(t *testing.T) { + ctx := context.Background() + ctrl := 
gomock.NewController(t) + + captureTool := &aliasCaptureTool{ + name: "grep", + params: map[string]*schema.ParameterInfo{ + "pattern": {Type: schema.String}, + }, + } + + cm := mockModel.NewMockToolCallingChatModel(ctrl) + generateCount := 0 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { + generateCount++ + if generateCount == 1 { + return &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "search_content", + Arguments: `{"pattern": "TODO"}`, + }, + }, + }, + }, nil + } + return schema.AssistantMessage("done", nil), nil + }).AnyTimes() + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "test", + Instruction: "test", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{captureTool}, + ToolAliases: map[string]compose.ToolAliasConfig{ + "grep": { + NameAliases: []string{"search_content"}, + }, + }, + }, + }, + }) + require.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("search")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + require.NotEmpty(t, captureTool.receivedArgs, "tool should have been called via name alias 'search_content'") + var args map[string]any + err = json.Unmarshal([]byte(captureTool.receivedArgs), &args) + require.NoError(t, err) + assert.Equal(t, "TODO", args["pattern"]) + }) + + t.Run("handler_adds_tool_matching_preexisting_ToolAliases_with_no_initial_tools", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + + captureTool := &aliasCaptureTool{ + name: "grep", + params: map[string]*schema.ParameterInfo{ + "pattern": {Type: schema.String, Desc: "regex pattern"}, + }, + } + + cm := 
mockModel.NewMockToolCallingChatModel(ctrl) + generateCount := 0 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...model.Option) (*schema.Message, error) { + generateCount++ + if generateCount == 1 { + return &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "grep", + Arguments: `{"grep_content": "BUG"}`, + }, + }, + }, + }, nil + } + return schema.AssistantMessage("done", nil), nil + }).AnyTimes() + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + handler := &testToolsHandler{ + BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, + tools: []tool.BaseTool{captureTool}, + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "test", + Instruction: "test", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + ToolAliases: map[string]compose.ToolAliasConfig{ + "grep": { + ArgumentsAliases: map[string][]string{ + "pattern": {"grep_content"}, + }, + }, + }, + }, + }, + Handlers: []ChatModelAgentMiddleware{handler}, + }) + require.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("find bugs")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + require.NotEmpty(t, captureTool.receivedArgs, "tool added by handler should have been called") + var args map[string]any + err = json.Unmarshal([]byte(captureTool.receivedArgs), &args) + require.NoError(t, err) + assert.Equal(t, "BUG", args["pattern"], "alias 'grep_content' should be remapped to 'pattern' for handler-added tool") + assert.NotContains(t, args, "grep_content") + }) +} diff --git a/adk/failover_chatmodel.go b/adk/failover_chatmodel.go new file mode 100644 index 000000000..2a467ed76 --- /dev/null +++ b/adk/failover_chatmodel.go @@ -0,0 +1,466 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the 
Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "fmt" + "io" + "log" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +type failoverCurrentModelKey struct{} + +type failoverCurrentModel struct { + model model.BaseChatModel +} + +func setFailoverCurrentModel(ctx context.Context, currentModel model.BaseChatModel) context.Context { + return context.WithValue(ctx, failoverCurrentModelKey{}, &failoverCurrentModel{ + model: currentModel, + }) +} + +func getFailoverCurrentModel(ctx context.Context) *failoverCurrentModel { + if fm, ok := ctx.Value(failoverCurrentModelKey{}).(*failoverCurrentModel); ok { + return fm + } + return nil +} + +type failoverHasMoreAttemptsKey struct{} + +// withFailoverHasMoreAttempts sets a flag in context indicating whether additional failover +// attempts remain after the current one. This is read by buildErrWrapper to decide whether +// stream errors should be wrapped as WillRetryError. +func withFailoverHasMoreAttempts(ctx context.Context, hasMore bool) context.Context { + return context.WithValue(ctx, failoverHasMoreAttemptsKey{}, hasMore) +} + +// getFailoverHasMoreAttempts returns true if the current failover attempt has more attempts +// after it, false otherwise (including when no failover context is present). 
+func getFailoverHasMoreAttempts(ctx context.Context) bool { + v, _ := ctx.Value(failoverHasMoreAttemptsKey{}).(bool) + return v +} + +type failoverProxyModel struct { +} + +func (m *failoverProxyModel) prepareCallbacks(ctx context.Context) (context.Context, model.BaseChatModel, error) { + current := getFailoverCurrentModel(ctx) + if current == nil || current.model == nil { + return nil, nil, errors.New("failover current model not found in context") + } + + typ, _ := components.GetType(current.model) + ctx = callbacks.EnsureRunInfo(ctx, typ, components.ComponentOfChatModel) + + target := current.model + if !components.IsCallbacksEnabled(target) { + target = (&callbackInjectionModelWrapper{}).WrapModel(target) + } + + return ctx, target, nil +} + +func (m *failoverProxyModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + nCtx, target, err := m.prepareCallbacks(ctx) + if err != nil { + return nil, err + } + + ctx = callbacks.OnStart(ctx, input) + + result, err := target.Generate(nCtx, input, opts...) + if err != nil { + callbacks.OnError(ctx, err) + return result, err + } + + callbacks.OnEnd(ctx, result) + + return result, nil +} + +func (m *failoverProxyModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + nCtx, target, err := m.prepareCallbacks(ctx) + if err != nil { + return nil, err + } + + ctx = callbacks.OnStart(ctx, input) + + result, err := target.Stream(nCtx, input, opts...) + if err != nil { + callbacks.OnError(ctx, err) + return nil, err + } + + _, wrappedStream := callbacks.OnEndWithStreamOutput(ctx, result) + return wrappedStream, nil +} + +func (m *failoverProxyModel) IsCallbacksEnabled() bool { + return true +} + +func (m *failoverProxyModel) GetType() string { + return "FailoverProxyModel" +} + +// FailoverContext contains context information during failover process. 
+type FailoverContext struct { + // FailoverAttempt is the current failover attempt number, starting from 1. + FailoverAttempt uint + + // InputMessages is the original input messages before any transformation. + InputMessages []*schema.Message + + // LastOutputMessage is the output message from the last failed attempt. + // May be nil if no output was produced. For streaming, this may be a partial message + // already received before the stream error. + LastOutputMessage *schema.Message + + // LastErr is the error from the last failed attempt that triggered this failover. + // + // Note: When ModelRetryConfig is also configured, LastErr will be a *RetryExhaustedError + // (if retries were exhausted) rather than the original model error. The original error + // can be retrieved via RetryExhaustedError.LastErr. + LastErr error +} + +// ModelFailoverConfig configures failover behavior for ChatModel. +// When configured, each ChatModel call first tries the last successful model (initially the configured Model), +// and if that fails, calls GetFailoverModel to select alternate models. +type ModelFailoverConfig struct { + // MaxRetries specifies the maximum number of failover attempts. + // + // When failover is triggered, GetFailoverModel will be called up to MaxRetries times + // (FailoverAttempt starts from 1). If GetFailoverModel returns an error, failover + // stops immediately and that error is returned. + // + // A value of 0 means no failover (GetFailoverModel will not be called). + // A value of 1 means GetFailoverModel may be called once. + // + // Note: if lastSuccessModel is set (from a previous successful call), it will be tried + // first before calling GetFailoverModel. + MaxRetries uint + + // ShouldFailover determines whether to fail over to the next model when an error occurs. + // It receives the output message (may be nil if no output is available) and the error (non-nil on failure). 
+ // For streaming errors, outputMessage can carry a partial message accumulated before the error. + // + // Note: When ModelRetryConfig is also configured, outputErr will be a *RetryExhaustedError + // (if retries were exhausted) rather than the original model error. Use errors.As to extract + // the RetryExhaustedError and access RetryExhaustedError.LastErr for the original error: + // + // var retryErr *adk.RetryExhaustedError + // if errors.As(outputErr, &retryErr) { + // // retryErr.LastErr contains the original model error + // } + // + // Note: When the context itself is cancelled (ctx.Err() != nil), failover will stop immediately + // regardless of this function. However, if the model returns context.Canceled or context.DeadlineExceeded + // as an error while the context is still active, this function will still be called. + // Should not be nil when ModelFailoverConfig is set. + // Return true to fail over to the next model, false to stop and return the current result/error. + ShouldFailover func(ctx context.Context, outputMessage *schema.Message, outputErr error) bool + + // GetFailoverModel is called when a model call fails and ShouldFailover returns true. + // It selects the next model to use for the failover attempt and optionally transforms input messages. + // It receives the failover context containing attempt number (starting from 1), original input, and last result. + // Return values: + // - failoverModel: The model to use for this failover attempt. + // - failoverModelInputMessages: The transformed input messages for the failover model. If nil, will use original input. + // - failoverErr: If non-nil, failover stops and this error is returned. + // Should not be nil when ModelFailoverConfig is set via ChatModelAgentConfig. 
+ GetFailoverModel func(ctx context.Context, failoverCtx *FailoverContext) ( + failoverModel model.BaseChatModel, failoverModelInputMessages []*schema.Message, failoverErr error) +} + +func getLastSuccessModel(ctx context.Context) model.BaseChatModel { + if execCtx := getChatModelAgentExecCtx(ctx); execCtx != nil { + return execCtx.failoverLastSuccessModel + } + return nil +} + +func setLastSuccessModel(ctx context.Context, m model.BaseChatModel) { + if execCtx := getChatModelAgentExecCtx(ctx); execCtx != nil { + execCtx.failoverLastSuccessModel = m + } +} + +type failoverModelWrapper struct { + config *ModelFailoverConfig + inner model.BaseChatModel +} + +func newFailoverModelWrapper(inner model.BaseChatModel, config *ModelFailoverConfig) *failoverModelWrapper { + return &failoverModelWrapper{ + config: config, + inner: inner, + } +} + +func (f *failoverModelWrapper) needFailover(ctx context.Context, outputMessage *schema.Message, outputErr error) bool { + if ctx.Err() != nil { + return false + } + + // ShouldFailover is validated at agent construction; nil here indicates a programmer error. + return f.config.ShouldFailover(ctx, outputMessage, outputErr) +} + +func (f *failoverModelWrapper) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + // Defensive: GetFailoverModel is validated non-nil at agent construction. + if f.config.GetFailoverModel == nil { + return f.inner.Generate(ctx, input, opts...) + } + + var lastOutputMessage *schema.Message + var lastErr error + + // Try lastSuccessModel first if available. + if lastSuccess := getLastSuccessModel(ctx); lastSuccess != nil { + if err := ctx.Err(); err != nil { + return nil, err + } + + modelCtx := setFailoverCurrentModel(ctx, lastSuccess) + modelCtx = withFailoverHasMoreAttempts(modelCtx, f.config.MaxRetries > 0) + result, err := f.inner.Generate(modelCtx, input, opts...) 
+ if err == nil { + return result, nil + } + + lastOutputMessage = result + lastErr = err + + if !f.needFailover(ctx, result, err) { + return result, err + } + + log.Printf("failover ChatModel.Generate lastSuccessModel failed: %v", err) + } + + for attempt := uint(1); attempt <= f.config.MaxRetries; attempt++ { + if err := ctx.Err(); err != nil { + return nil, err + } + + failoverCtx := &FailoverContext{ + FailoverAttempt: attempt, + InputMessages: input, + LastOutputMessage: lastOutputMessage, + LastErr: lastErr, + } + + currentModel, currentInput, err := f.config.GetFailoverModel(ctx, failoverCtx) + if err != nil { + return nil, err + } + if currentModel == nil { + return nil, fmt.Errorf("failover GetFailoverModel returned nil model at attempt %d", attempt) + } + + if currentInput == nil { + currentInput = input + } + + modelCtx := setFailoverCurrentModel(ctx, currentModel) + modelCtx = withFailoverHasMoreAttempts(modelCtx, attempt < f.config.MaxRetries) + result, err := f.inner.Generate(modelCtx, currentInput, opts...) + lastOutputMessage = result + lastErr = err + + if err == nil { + setLastSuccessModel(ctx, currentModel) + return result, nil + } + + if !f.needFailover(ctx, result, err) { + return result, err + } + + if attempt < f.config.MaxRetries { + log.Printf("failover ChatModel.Generate attempt %d failed: %v", attempt, err) + } + } + + return lastOutputMessage, lastErr +} + +func (f *failoverModelWrapper) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) ( + *schema.StreamReader[*schema.Message], error) { + // Defensive: GetFailoverModel is validated non-nil at agent construction. + if f.config.GetFailoverModel == nil { + return f.inner.Stream(ctx, input, opts...) + } + + var lastOutputMessage *schema.Message + var lastErr error + + // Try lastSuccessModel first if available. 
+ if lastSuccess := getLastSuccessModel(ctx); lastSuccess != nil { + if err := ctx.Err(); err != nil { + return nil, err + } + + modelCtx := setFailoverCurrentModel(ctx, lastSuccess) + modelCtx = withFailoverHasMoreAttempts(modelCtx, f.config.MaxRetries > 0) + stream, err := f.inner.Stream(modelCtx, input, opts...) + if err != nil { + lastErr = err + if !f.needFailover(ctx, nil, err) { + return nil, err + } + log.Printf("failover ChatModel.Stream lastSuccessModel failed: %v", err) + } else { + copies := stream.Copy(2) + checkCopy := copies[0] + returnCopy := copies[1] + + outMsg, streamErr := consumeStream(checkCopy) + if streamErr != nil { + lastOutputMessage = outMsg + lastErr = streamErr + returnCopy.Close() + + if !f.needFailover(ctx, outMsg, streamErr) { + return nil, streamErr + } + log.Printf("failover ChatModel.Stream lastSuccessModel failed: %v", streamErr) + } else { + return returnCopy, nil + } + } + } + + for attempt := uint(1); attempt <= f.config.MaxRetries; attempt++ { + if err := ctx.Err(); err != nil { + return nil, err + } + + failoverCtx := &FailoverContext{ + FailoverAttempt: attempt, + InputMessages: input, + LastOutputMessage: lastOutputMessage, + LastErr: lastErr, + } + + currentModel, currentInput, err := f.config.GetFailoverModel(ctx, failoverCtx) + if err != nil { + return nil, err + } + if currentModel == nil { + return nil, fmt.Errorf("failover GetFailoverModel returned nil model at attempt %d", attempt) + } + + if currentInput == nil { + currentInput = input + } + + modelCtx := setFailoverCurrentModel(ctx, currentModel) + modelCtx = withFailoverHasMoreAttempts(modelCtx, attempt < f.config.MaxRetries) + stream, err := f.inner.Stream(modelCtx, currentInput, opts...) 
+ if err != nil { + lastErr = err + lastOutputMessage = nil + + if !f.needFailover(ctx, nil, err) { + return nil, err + } + + if attempt < f.config.MaxRetries { + log.Printf("failover ChatModel.Stream attempt %d failed: %v", attempt, err) + } + continue + } + + // The stream returned by f.inner.Stream is already Copy'd by the inner eventSender layer: one + // copy is forwarded to the client in real time via events. Therefore consuming a copy here does + // NOT block client-side streaming. + // + // We Copy the stream into two readers: + // - checkCopy: consumed synchronously to surface mid-stream errors and decide whether to fail over. + // - returnCopy: returned to the caller (stateModelWrapper), which also consumes synchronously to + // build state (AfterModelRewriteState), so waiting here adds no extra latency. + // + // If checkCopy errors and failover is allowed, we close returnCopy and retry with the next model. + // Otherwise we return returnCopy. + // + // NOTE on duplicate events during failover: when a retry happens, events from the failed attempt + // may already have been emitted to the client, and the retry will emit a new stream. Client-side + // handlers are expected to handle multiple rounds (e.g., reset on retry or deduplicate by attempt + // metadata). 
+ copies := stream.Copy(2) + checkCopy := copies[0] + returnCopy := copies[1] + + outMsg, streamErr := consumeStream(checkCopy) + if streamErr != nil { + lastOutputMessage = outMsg + lastErr = streamErr + returnCopy.Close() + + if !f.needFailover(ctx, outMsg, streamErr) { + return nil, streamErr + } + + if attempt < f.config.MaxRetries { + log.Printf("failover ChatModel.Stream attempt %d failed: %v", attempt, streamErr) + } + continue + } + + setLastSuccessModel(ctx, currentModel) + return returnCopy, nil + } + + return nil, lastErr +} + +func consumeStream(stream *schema.StreamReader[*schema.Message]) (*schema.Message, error) { + defer stream.Close() + chunks := make([]*schema.Message, 0) + for { + chunk, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + // ignore concat error + msg, _ := schema.ConcatMessages(chunks) + return msg, err + } + + chunks = append(chunks, chunk) + } + + // Stream completed successfully (EOF). ConcatMessages error is not a stream error, + // so ignore it to avoid incorrectly triggering failover. + msg, _ := schema.ConcatMessages(chunks) + return msg, nil +} diff --git a/adk/failover_chatmodel_test.go b/adk/failover_chatmodel_test.go new file mode 100644 index 000000000..82866e994 --- /dev/null +++ b/adk/failover_chatmodel_test.go @@ -0,0 +1,697 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package adk + +import ( + "context" + "errors" + "io" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +type fakeChatModel struct { + generate func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) + stream func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) + callbacksEnabled bool +} + +func (m *fakeChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + return m.generate(ctx, input, opts...) +} + +func (m *fakeChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return m.stream(ctx, input, opts...) +} + +func (m *fakeChatModel) IsCallbacksEnabled() bool { + return m.callbacksEnabled +} + +func drainMessageStream(sr *schema.StreamReader[*schema.Message]) ([]*schema.Message, error) { + defer sr.Close() + var out []*schema.Message + for { + chunk, err := sr.Recv() + if err == io.EOF { + return out, nil + } + if err != nil { + return out, err + } + out = append(out, chunk) + } +} + +func streamWithMidError(chunks []*schema.Message, err error) *schema.StreamReader[*schema.Message] { + sr, sw := schema.Pipe[*schema.Message](2) + go func() { + defer sw.Close() + for _, c := range chunks { + sw.Send(c, nil) + } + sw.Send(nil, err) + }() + return sr +} + +func streamWithMidErrorControlled(chunks []*schema.Message, err error, firstSent chan struct{}, release chan struct{}) *schema.StreamReader[*schema.Message] { + sr, sw := schema.Pipe[*schema.Message](2) + go func() { + defer sw.Close() + for i, c := range chunks { + sw.Send(c, nil) + if i == 0 && firstSent != nil { + close(firstSent) + if release != nil { + <-release + } + } + } + sw.Send(nil, err) + }() + return sr +} + +func 
TestFailoverCurrentModelContext(t *testing.T) { + t.Run("set and get", func(t *testing.T) { + ctx := context.Background() + m := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return schema.AssistantMessage("ok", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("ok", nil)}), nil + }, + } + ctx = setFailoverCurrentModel(ctx, m) + got := getFailoverCurrentModel(ctx) + require.NotNil(t, got) + require.Same(t, m, got.model) + }) + + t.Run("wrong type", func(t *testing.T) { + ctx := context.WithValue(context.Background(), failoverCurrentModelKey{}, "bad") + require.Nil(t, getFailoverCurrentModel(ctx)) + }) + + t.Run("missing", func(t *testing.T) { + require.Nil(t, getFailoverCurrentModel(context.Background())) + }) +} + +func TestFailoverProxyModel(t *testing.T) { + t.Run("generate missing context", func(t *testing.T) { + p := &failoverProxyModel{} + _, err := p.Generate(context.Background(), []*schema.Message{schema.UserMessage("hi")}) + require.Error(t, err) + }) + + t.Run("stream missing context", func(t *testing.T) { + p := &failoverProxyModel{} + _, err := p.Stream(context.Background(), []*schema.Message{schema.UserMessage("hi")}) + require.Error(t, err) + }) + + t.Run("generate routes to current model", func(t *testing.T) { + var called int32 + target := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&called, 1) + return schema.AssistantMessage("routed", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("routed", nil)}), 
nil + }, + } + ctx := setFailoverCurrentModel(context.Background(), target) + p := &failoverProxyModel{} + msg, err := p.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.Equal(t, "routed", msg.Content) + require.Equal(t, int32(1), atomic.LoadInt32(&called)) + }) +} + +func TestFailoverModelWrapper_Generate(t *testing.T) { + t.Run("delegates when GetFailoverModel nil", func(t *testing.T) { + var called int32 + inner := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&called, 1) + return schema.AssistantMessage("inner", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("inner", nil)}), nil + }, + } + w := newFailoverModelWrapper(inner, &ModelFailoverConfig{ + MaxRetries: 2, + ShouldFailover: func(context.Context, *schema.Message, error) bool { return true }, + GetFailoverModel: nil, + }) + msg, err := w.Generate(context.Background(), []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.Equal(t, "inner", msg.Content) + require.Equal(t, int32(1), atomic.LoadInt32(&called)) + }) + + t.Run("failover to second model", func(t *testing.T) { + wantErr := errors.New("first failed") + var shouldCalls int32 + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m1Calls, 1) + return nil, wantErr + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ 
...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m2Calls, 1) + return schema.AssistantMessage("ok", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + return errors.Is(err, wantErr) + }, + GetFailoverModel: func(_ context.Context, failoverCtx *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.Equal(t, uint(1), failoverCtx.FailoverAttempt) + return m2, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + msg, err := w.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.Equal(t, "ok", msg.Content) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) + }) + + t.Run("canceled error delegates to ShouldFailover", func(t *testing.T) { + var shouldCalls int32 + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, context.Canceled + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 5, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + // User decides to stop on canceled error + return !errors.Is(err, context.Canceled) + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) 
(model.BaseChatModel, []*schema.Message, error) { + return m1, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + _, err := w.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.ErrorIs(t, err, context.Canceled) + // ShouldFailover is called once and returns false, stopping failover + require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) + }) + + t.Run("stops when GetFailoverModel returns error", func(t *testing.T) { + wantErr := errors.New("get model failed") + var called int32 + inner := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&called, 1) + return schema.AssistantMessage("unused", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 3, + ShouldFailover: func(context.Context, *schema.Message, error) bool { return true }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return nil, nil, wantErr + }, + } + + w := newFailoverModelWrapper(inner, cfg) + _, err := w.Generate(context.Background(), []*schema.Message{schema.UserMessage("hi")}) + require.ErrorIs(t, err, wantErr) + require.Equal(t, int32(0), atomic.LoadInt32(&called)) + }) + + t.Run("stops when GetFailoverModel returns nil model", func(t *testing.T) { + cfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(context.Context, *schema.Message, error) bool { return true }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return nil, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, 
cfg) + msg, err := w.Generate(context.Background(), []*schema.Message{schema.UserMessage("hi")}) + require.Nil(t, msg) + require.Error(t, err) + require.ErrorContains(t, err, "GetFailoverModel returned nil model") + }) +} + +func TestFailoverModelWrapper_Stream(t *testing.T) { + t.Run("returns stream when first attempt succeeds", func(t *testing.T) { + var shouldCalls int32 + in := schema.UserMessage("hi") + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, input []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + require.Len(t, input, 1) + require.Same(t, in, input[0]) + return schema.StreamReaderFromArray([]*schema.Message{ + schema.AssistantMessage("a", nil), + schema.AssistantMessage("b", nil), + }), nil + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 0, + ShouldFailover: func(context.Context, *schema.Message, error) bool { + atomic.AddInt32(&shouldCalls, 1) + return false + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m1, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := w.Stream(ctx, []*schema.Message{in}) + require.NoError(t, err) + msgs, err := drainMessageStream(sr) + require.NoError(t, err) + require.Len(t, msgs, 2) + require.Equal(t, "a", msgs[0].Content) + require.Equal(t, "b", msgs[1].Content) + require.Equal(t, int32(0), atomic.LoadInt32(&shouldCalls)) + }) + + t.Run("failover when Stream returns error immediately", func(t *testing.T) { + wantErr := errors.New("stream init failed") + var shouldCalls int32 + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + 
generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return nil, wantErr + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m2Calls, 1) + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("ok", nil)}), nil + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + return errors.Is(err, wantErr) + }, + GetFailoverModel: func(_ context.Context, failoverCtx *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.Equal(t, uint(1), failoverCtx.FailoverAttempt) + return m2, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := w.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + msgs, err := drainMessageStream(sr) + require.NoError(t, err) + require.Len(t, msgs, 1) + require.Equal(t, "ok", msgs[0].Content) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) + }) + + t.Run("failover when stream errors mid-way", func(t *testing.T) { + streamErr := errors.New("mid error") + var shouldCalls int32 + var seenOutput atomic.Value + var m1Calls int32 + var m2Calls int32 + + m1 := 
&fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return streamWithMidError([]*schema.Message{ + schema.AssistantMessage("p1", nil), + schema.AssistantMessage("p2", nil), + }, streamErr), nil + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m2Calls, 1) + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("final", nil)}), nil + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, out *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + if errors.Is(err, streamErr) && out != nil { + seenOutput.Store(out.Content) + } + return errors.Is(err, streamErr) + }, + GetFailoverModel: func(_ context.Context, failoverCtx *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.Equal(t, uint(1), failoverCtx.FailoverAttempt) + return m2, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := w.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + msgs, err := drainMessageStream(sr) + require.NoError(t, err) + require.Len(t, msgs, 1) + require.Equal(t, "final", msgs[0].Content) + require.Equal(t, "p1p2", seenOutput.Load()) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), 
atomic.LoadInt32(&m2Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) + }) + + t.Run("stop when ShouldFailover returns false for mid-way error", func(t *testing.T) { + streamErr := errors.New("mid error") + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return streamWithMidError([]*schema.Message{schema.AssistantMessage("p", nil)}, streamErr), nil + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 3, + ShouldFailover: func(context.Context, *schema.Message, error) bool { + return false + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m1, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := w.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.Nil(t, sr) + require.ErrorIs(t, err, streamErr) + }) + + t.Run("canceled mid-way error delegates to ShouldFailover", func(t *testing.T) { + var shouldCalls int32 + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return streamWithMidError([]*schema.Message{schema.AssistantMessage("p", nil)}, context.Canceled), nil + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 3, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + // User decides to stop on canceled error + 
return !errors.Is(err, context.Canceled) + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m1, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := w.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.Nil(t, sr) + require.ErrorIs(t, err, context.Canceled) + // ShouldFailover is called once and returns false, stopping failover + require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) + }) + + t.Run("stop when Stream returns error immediately and ShouldFailover returns false", func(t *testing.T) { + wantErr := errors.New("stream init failed") + var shouldCalls int32 + var m1Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return nil, wantErr + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 3, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + require.ErrorIs(t, err, wantErr) + return false + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m1, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := w.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.Nil(t, sr) + require.ErrorIs(t, err, wantErr) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), 
atomic.LoadInt32(&shouldCalls)) + }) + + t.Run("stops when GetFailoverModel returns nil model", func(t *testing.T) { + cfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(context.Context, *schema.Message, error) bool { return true }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return nil, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + sr, err := w.Stream(context.Background(), []*schema.Message{schema.UserMessage("hi")}) + require.Nil(t, sr) + require.Error(t, err) + require.ErrorContains(t, err, "GetFailoverModel returned nil model") + }) + + t.Run("stops when GetFailoverModel returns error", func(t *testing.T) { + wantErr := errors.New("get model failed") + var called int32 + inner := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&called, 1) + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("unused", nil)}), nil + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 3, + ShouldFailover: func(context.Context, *schema.Message, error) bool { return true }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return nil, nil, wantErr + }, + } + + w := newFailoverModelWrapper(inner, cfg) + sr, err := w.Stream(context.Background(), []*schema.Message{schema.UserMessage("hi")}) + require.Nil(t, sr) + require.ErrorIs(t, err, wantErr) + require.Equal(t, int32(0), atomic.LoadInt32(&called)) + }) + + t.Run("stops when ctx canceled during mid-way error handling", func(t *testing.T) { + midErr := errors.New("mid error") + var shouldCalls int32 + var m1Calls int32 + var m2Calls int32 + 
firstSent := make(chan struct{}) + release := make(chan struct{}) + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return streamWithMidErrorControlled( + []*schema.Message{schema.AssistantMessage("p", nil)}, + midErr, + firstSent, + release, + ), nil + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m2Calls, 1) + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("unused", nil)}), nil + }, + } + + cfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(context.Context, *schema.Message, error) bool { + atomic.AddInt32(&shouldCalls, 1) + return true + }, + GetFailoverModel: func(_ context.Context, failoverCtx *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.Equal(t, uint(1), failoverCtx.FailoverAttempt) + return m2, nil, nil + }, + } + + w := newFailoverModelWrapper(&failoverProxyModel{}, cfg) + baseCtx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + ctx, cancel := context.WithCancel(baseCtx) + type result struct { + sr *schema.StreamReader[*schema.Message] + err error + } + ch := make(chan result, 1) + go func() { + sr, err := w.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + ch <- result{sr: sr, err: err} + }() + + <-firstSent + cancel() + close(release) + + res := <-ch + if res.sr != nil { + res.sr.Close() + } + require.Nil(t, res.sr) + 
require.ErrorIs(t, res.err, midErr) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(0), atomic.LoadInt32(&m2Calls)) + require.Equal(t, int32(0), atomic.LoadInt32(&shouldCalls)) + }) +} diff --git a/adk/flow.go b/adk/flow.go index ee4dec96c..52a346c74 100644 --- a/adk/flow.go +++ b/adk/flow.go @@ -340,9 +340,13 @@ func (a *flowAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRun ctx = AppendAddressSegment(ctx, AddressSegmentAgent, agentName) o := getCommonOptions(nil, opts...) + cancelCtx := o.cancelCtx processedInput, err := a.genAgentInput(ctx, runCtx, o.skipTransferMessages) if err != nil { + if cancelCtx != nil { + cancelCtx.markDone() + } cbInput := &AgentCallbackInput{Input: input} ctx = callbacks.OnStart(ctx, cbInput) return wrapIterWithOnEnd(ctx, genErrorIter(err)) @@ -358,16 +362,20 @@ func (a *flowAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRun input = processedInput if wf, ok := a.Agent.(*workflowAgent); ok { - return wrapIterWithOnEnd(ctx, wf.Run(ctx, input, filterCallbackHandlersForNestedAgents(agentName, opts)...)) + ctx = withCancelContext(ctx, cancelCtx) + filteredOpts := filterCancelOption(filterCallbackHandlersForNestedAgents(agentName, opts)) + iter := wf.Run(ctx, input, filteredOpts...) + iter = wrapIterWithCancelCtx(iter, cancelCtx) + return wrapIterWithOnEnd(ctx, iter) } - aIter := a.Agent.Run(ctx, input, filterOptions(agentName, opts)...) + aIter := a.Agent.Run(withCancelContext(ctx, cancelCtx), input, filterOptions(agentName, opts)...) iterator, generator := NewAsyncIteratorPair[*AgentEvent]() - go a.run(ctx, ctxForSubAgents, runCtx, aIter, generator, opts...) + go a.run(withCancelContext(ctx, cancelCtx), withCancelContext(ctxForSubAgents, cancelCtx), runCtx, aIter, generator, filterCancelOption(opts)...) 
- return iterator + return wrapIterWithCancelCtx(iterator, cancelCtx) } func (a *flowAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { @@ -377,57 +385,67 @@ func (a *flowAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentR ctxForSubAgents := ctx + o := getCommonOptions(nil, opts...) + cancelCtx := o.cancelCtx + agentType := getAgentType(a.Agent) ctx = initAgentCallbacks(ctx, agentName, agentType, filterOptions(agentName, opts)...) cbInput := &AgentCallbackInput{ResumeInfo: info} ctx = callbacks.OnStart(ctx, cbInput) if info.WasInterrupted { - ra, ok := a.Agent.(ResumableAgent) - if !ok { - return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf("failed to resume agent: agent '%s' is an interrupt point "+ - "but is not a ResumableAgent", agentName))) + if ra, ok := a.Agent.(ResumableAgent); ok { + if _, ok := ra.(*workflowAgent); ok { + ctx = withCancelContext(ctx, cancelCtx) + filteredOpts := filterCancelOption(filterCallbackHandlersForNestedAgents(agentName, opts)) + aIter := ra.Resume(ctx, info, filteredOpts...) + aIter = wrapIterWithCancelCtx(aIter, cancelCtx) + return wrapIterWithOnEnd(ctx, aIter) + } + + aIter := ra.Resume(withCancelContext(ctx, cancelCtx), info, opts...) + + iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + go a.run(withCancelContext(ctx, cancelCtx), withCancelContext(ctxForSubAgents, cancelCtx), getRunCtx(ctxForSubAgents), aIter, generator, filterCancelOption(opts)...) + return wrapIterWithCancelCtx(iterator, cancelCtx) } - iterator, generator := NewAsyncIteratorPair[*AgentEvent]() - if _, ok := ra.(*workflowAgent); ok { - filteredOpts := filterCallbackHandlersForNestedAgents(agentName, opts) - aIter := ra.Resume(ctx, info, filteredOpts...) - return wrapIterWithOnEnd(ctx, aIter) + if cancelCtx != nil { + cancelCtx.markDone() } - aIter := ra.Resume(ctx, info, opts...) - go a.run(ctx, ctxForSubAgents, getRunCtx(ctxForSubAgents), aIter, generator, opts...) 
- return iterator + return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf("failed to resume agent: agent '%s' is an interrupt point "+ + "but is not a ResumableAgent", agentName))) } nextAgentName, err := getNextResumeAgent(ctx, info) if err != nil { + if cancelCtx != nil { + cancelCtx.markDone() + } return wrapIterWithOnEnd(ctx, genErrorIter(err)) } subAgent := a.getAgent(ctxForSubAgents, nextAgentName) if subAgent == nil { - // the inner agent wrapped by flowAgent may be ANY agent, including flowAgent, - // AgentWithDeterministicTransferTo, or any other custom agent user defined, - // or any combinations of the above in any order, - // that ultimately wraps the flowAgent with sub-agents - // We need to go through these wrappers to reach the flowAgent with sub-agents. if len(a.subAgents) == 0 { if ra, ok := a.Agent.(ResumableAgent); ok { - // Use ctx (callback-enriched) instead of ctxForSubAgents here. - // This is the inner agent that flowAgent wraps (e.g., supervisorContainer), - // not a sub-agent. The callback context from OnStart should be propagated - // to ensure unified tracing for container patterns. - return wrapIterWithOnEnd(ctx, ra.Resume(ctx, info, opts...)) + ctx = withCancelContext(ctx, cancelCtx) + innerIter := ra.Resume(ctx, info, filterCancelOption(opts)...) + return wrapIterWithCancelCtx(wrapIterWithOnEnd(ctx, innerIter), cancelCtx) } return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf( "failed to resume agent: agent '%s' (type %T) has no sub-agents and does not implement ResumableAgent interface. 
"+ "To support resume, your custom agent wrapper must implement the ResumableAgent interface", agentName, a.Agent))) } + if cancelCtx != nil { + cancelCtx.markDone() + } return wrapIterWithOnEnd(ctx, genErrorIter(fmt.Errorf("failed to resume agent: sub-agent '%s' not found in agent '%s'", nextAgentName, agentName))) } - return wrapIterWithOnEnd(ctx, subAgent.Resume(ctxForSubAgents, info, opts...)) + ctxForSubAgents = withCancelContext(ctxForSubAgents, cancelCtx) + innerIter := subAgent.Resume(ctxForSubAgents, info, filterCancelOption(opts)...) + return wrapIterWithCancelCtx(wrapIterWithOnEnd(ctx, innerIter), cancelCtx) } type DeterministicTransferConfig struct { diff --git a/adk/handler.go b/adk/handler.go index 7c7ebba71..d18abc965 100644 --- a/adk/handler.go +++ b/adk/handler.go @@ -47,6 +47,12 @@ type ToolContext struct { CallID string } +// ToolCallsContext contains metadata about the tool calls that just completed. +type ToolCallsContext struct { + // ToolCalls contains the tool call metadata from the model's response. + ToolCalls []ToolContext +} + // ModelContext contains context information passed to WrapModel. type ModelContext struct { // Tools contains the current tool list configured for the agent. @@ -57,6 +63,14 @@ type ModelContext struct { // This is populated at request time from the agent's ModelRetryConfig. // Used by EventSenderModelWrapper to wrap stream errors appropriately. ModelRetryConfig *ModelRetryConfig + + // ModelFailoverConfig contains the failover configuration for the model. + // This is populated at request time from the agent's ModelFailoverConfig. + // Used by EventSenderModelWrapper to wrap stream errors so that failed failover + // attempts are skipped (not treated as fatal) by the flow event processor. + ModelFailoverConfig *ModelFailoverConfig + + cancelContext *cancelContext } // ChatModelAgentContext contains runtime information passed to handlers before each ChatModelAgent run. 
@@ -138,6 +152,14 @@ type ChatModelAgentMiddleware interface { // - Tools: the current tool list that was sent to the model AfterModelRewriteState(ctx context.Context, state *ChatModelAgentState, mc *ModelContext) (context.Context, *ChatModelAgentState, error) + // AfterToolCallsRewriteState is called after all concurrent tool calls in an iteration complete. + // The input state includes all messages up to and including the tool call results. + // The returned state is persisted to the agent's internal state. + // + // The ToolCallsContext provides metadata about the tool calls that just completed, + // derived from the assistant message's ToolCalls field. + AfterToolCallsRewriteState(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) + // WrapInvokableToolCall wraps a tool's synchronous execution with custom behavior. // Return the input endpoint unchanged and nil error if no wrapping is needed. // @@ -247,6 +269,10 @@ func (b *BaseChatModelAgentMiddleware) AfterModelRewriteState(ctx context.Contex return ctx, state, nil } +func (b *BaseChatModelAgentMiddleware) AfterToolCallsRewriteState(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) { + return ctx, state, nil +} + // SetRunLocalValue sets a key-value pair that persists for the duration of the current agent Run() invocation. // The value is scoped to this specific execution and is not shared across different Run() calls or agent instances. 
// @@ -327,7 +353,7 @@ func SendEvent(ctx context.Context, event *AgentEvent) error { if execCtx == nil || execCtx.generator == nil { return fmt.Errorf("SendEvent failed: must be called within a ChatModelAgent Run() or Resume() execution context") } - execCtx.generator.Send(event) + execCtx.send(event) return nil } diff --git a/adk/handler_test.go b/adk/handler_test.go index e56da3842..abdb0ecab 100644 --- a/adk/handler_test.go +++ b/adk/handler_test.go @@ -111,6 +111,15 @@ func (h *testAfterModelRewriteStateHandler) AfterModelRewriteState(ctx context.C return h.fn(ctx, state, mc) } +type testAfterToolCallsHandler struct { + *BaseChatModelAgentMiddleware + fn func(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) +} + +func (h *testAfterToolCallsHandler) AfterToolCallsRewriteState(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) { + return h.fn(ctx, state, tc) +} + type testToolWrapperHandler struct { *BaseChatModelAgentMiddleware wrapInvokableFn func(context.Context, InvokableToolCallEndpoint, *ToolContext) InvokableToolCallEndpoint @@ -1820,3 +1829,312 @@ func TestToolContextInWrappers(t *testing.T) { assert.Equal(t, "test_call_id_123", capturedCallID, "ToolContext should have correct call ID") }) } + +func TestAfterToolCallsRewriteState(t *testing.T) { + t.Run("ReceivesCorrectToolCallsContext", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + tool1 := &namedTool{name: "tool_alpha"} + tool2 := &namedTool{name: "tool_beta"} + + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + // First call: model returns two tool calls + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
+ Return(schema.AssistantMessage("calling tools", []schema.ToolCall{ + {ID: "call_1", Function: schema.FunctionCall{Name: "tool_alpha", Arguments: "{}"}}, + {ID: "call_2", Function: schema.FunctionCall{Name: "tool_beta", Arguments: "{}"}}, + }), nil).Times(1) + + // Second call: model returns final response + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("done", nil), nil).Times(1) + + var mu sync.Mutex + var capturedTC *ToolCallsContext + var capturedState *ChatModelAgentState + callCount := 0 + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1, tool2}, + }, + }, + Handlers: []ChatModelAgentMiddleware{ + &testAfterToolCallsHandler{fn: func(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) { + mu.Lock() + callCount++ + capturedTC = tc + capturedState = &ChatModelAgentState{Messages: append([]Message{}, state.Messages...)} + mu.Unlock() + return ctx, state, nil + }}, + }, + }) + assert.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + mu.Lock() + defer mu.Unlock() + + // Should be called exactly once (one iteration with tool calls) + assert.Equal(t, 1, callCount) + + // ToolCallsContext should have the two tool calls + assert.NotNil(t, capturedTC) + assert.Len(t, capturedTC.ToolCalls, 2) + assert.Equal(t, "tool_alpha", capturedTC.ToolCalls[0].Name) + assert.Equal(t, "call_1", capturedTC.ToolCalls[0].CallID) + assert.Equal(t, "tool_beta", capturedTC.ToolCalls[1].Name) + assert.Equal(t, "call_2", capturedTC.ToolCalls[1].CallID) + + // State should contain: system msg + user msg + assistant msg + 2 tool results + assert.NotNil(t, capturedState) + assert.True(t, 
len(capturedState.Messages) >= 4, "expected at least 4 messages, got %d", len(capturedState.Messages)) + + // Check tool results are in state + toolResultCount := 0 + for _, msg := range capturedState.Messages { + if msg.Role == schema.Tool { + toolResultCount++ + } + } + assert.Equal(t, 2, toolResultCount) + }) + + t.Run("NotCalledWithoutToolCalls", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + // Model returns a direct response with no tool calls + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("direct response", nil), nil).Times(1) + + callCount := 0 + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: cm, + Handlers: []ChatModelAgentMiddleware{ + &testAfterToolCallsHandler{fn: func(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) { + callCount++ + return ctx, state, nil + }}, + }, + }) + assert.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + assert.Equal(t, 0, callCount, "AfterToolCallsRewriteState should not be called when no tool calls happen") + }) + + t.Run("CanModifyStatePersistsToNextIteration", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + tool1 := &namedTool{name: "mytool"} + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + // First call: model returns a tool call + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
+ Return(schema.AssistantMessage("calling", []schema.ToolCall{ + {ID: "c1", Function: schema.FunctionCall{Name: "mytool", Arguments: "{}"}}, + }), nil).Times(1) + + // Second call: capture messages to verify the injected message is present + var capturedMsgs []*schema.Message + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...interface{}) (*schema.Message, error) { + capturedMsgs = append([]*schema.Message{}, msgs...) + return schema.AssistantMessage("final", nil), nil + }).Times(1) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1}, + }, + }, + Handlers: []ChatModelAgentMiddleware{ + &testAfterToolCallsHandler{fn: func(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) { + // Inject a user message into state + state.Messages = append(state.Messages, schema.UserMessage("injected_by_middleware")) + return ctx, state, nil + }}, + }, + }) + assert.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("original")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + // The injected message should be visible in the second model call + assert.NotNil(t, capturedMsgs) + found := false + for _, msg := range capturedMsgs { + if msg.Content == "injected_by_middleware" { + found = true + break + } + } + assert.True(t, found, "Injected message should persist to the next model call") + }) +} + +func TestToolResultNotDuplicated(t *testing.T) { + t.Run("SecondModelCallHasNoToolResultDuplication", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + tool1 := &namedTool{name: "mytool"} + 
cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("calling", []schema.ToolCall{ + {ID: "c1", Function: schema.FunctionCall{Name: "mytool", Arguments: "{}"}}, + }), nil).Times(1) + + var capturedMsgs []*schema.Message + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...interface{}) (*schema.Message, error) { + capturedMsgs = append([]*schema.Message{}, msgs...) + return schema.AssistantMessage("final", nil), nil + }).Times(1) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are helpful.", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1}, + }, + }, + }) + assert.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hello")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + assert.NotNil(t, capturedMsgs) + assert.Equal(t, 4, len(capturedMsgs), + "expected [system, user, assistant, tool_result], got %d messages", len(capturedMsgs)) + assert.Equal(t, schema.System, capturedMsgs[0].Role) + assert.Equal(t, schema.User, capturedMsgs[1].Role) + assert.Equal(t, schema.Assistant, capturedMsgs[2].Role) + assert.Equal(t, schema.Tool, capturedMsgs[3].Role) + + toolResultCount := 0 + for _, msg := range capturedMsgs { + if msg.Role == schema.Tool { + toolResultCount++ + } + } + assert.Equal(t, 1, toolResultCount, + "tool result should appear exactly once, got %d", toolResultCount) + }) + + t.Run("HandlerInjectedMessagePresentWithoutDuplication", func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + tool1 := &namedTool{name: "mytool"} + cm.EXPECT().WithTools(gomock.Any()).Return(cm, 
nil).AnyTimes() + + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("calling", []schema.ToolCall{ + {ID: "c1", Function: schema.FunctionCall{Name: "mytool", Arguments: "{}"}}, + }), nil).Times(1) + + var capturedMsgs []*schema.Message + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, msgs []*schema.Message, opts ...interface{}) (*schema.Message, error) { + capturedMsgs = append([]*schema.Message{}, msgs...) + return schema.AssistantMessage("final", nil), nil + }).Times(1) + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are helpful.", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1}, + }, + }, + Handlers: []ChatModelAgentMiddleware{ + &testAfterToolCallsHandler{fn: func(ctx context.Context, state *ChatModelAgentState, tc *ToolCallsContext) (context.Context, *ChatModelAgentState, error) { + state.Messages = append(state.Messages, schema.UserMessage("injected")) + return ctx, state, nil + }}, + }, + }) + assert.NoError(t, err) + + iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("hello")}}) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + assert.NotNil(t, capturedMsgs) + assert.Equal(t, 5, len(capturedMsgs), + "expected [system, user, assistant, tool_result, injected], got %d messages", len(capturedMsgs)) + assert.Equal(t, schema.System, capturedMsgs[0].Role) + assert.Equal(t, schema.User, capturedMsgs[1].Role) + assert.Equal(t, schema.Assistant, capturedMsgs[2].Role) + assert.Equal(t, schema.Tool, capturedMsgs[3].Role) + assert.Equal(t, "injected", capturedMsgs[4].Content) + + toolResultCount := 0 + for _, msg := range capturedMsgs { + if msg.Role == schema.Tool { + toolResultCount++ + } + } + assert.Equal(t, 1, toolResultCount, + "tool result should appear exactly once, got 
%d", toolResultCount) + }) +} diff --git a/adk/interrupt.go b/adk/interrupt.go index 5941d0724..fce09d4cf 100644 --- a/adk/interrupt.go +++ b/adk/interrupt.go @@ -22,6 +22,7 @@ import ( "encoding/gob" "errors" "fmt" + "sync" "github.com/cloudwego/eino/internal/core" "github.com/cloudwego/eino/schema" @@ -183,6 +184,11 @@ func WithCheckPointID(id string) AgentRunOption { func init() { schema.RegisterName[*serialization]("_eino_adk_serialization") schema.RegisterName[*WorkflowInterruptInfo]("_eino_adk_workflow_interrupt_info") + // Register []byte for gob: the cancel refactor routes bridge store checkpoint + // bytes ([]byte) through InterruptState.State (type any) inside the outer + // serialization struct. Gob requires concrete types behind interface fields + // to be registered. + gob.Register([]byte{}) } // serialization CheckpointSchema: root checkpoint payload (gob). @@ -266,6 +272,10 @@ func (r *Runner) saveCheckPoint( info *InterruptInfo, is *core.InterruptSignal, ) error { + if r.store == nil { + return nil + } + runCtx := getRunCtx(ctx) id2Addr, id2State := core.SignalToPersistenceMaps(is) @@ -287,31 +297,36 @@ func (r *Runner) saveCheckPoint( const bridgeCheckpointID = "adk_react_mock_key" func newBridgeStore() *bridgeStore { - return &bridgeStore{} + return &bridgeStore{data: make(map[string][]byte)} } -func newResumeBridgeStore(data []byte) *bridgeStore { +func newResumeBridgeStore(checkPointID string, data []byte) *bridgeStore { return &bridgeStore{ - Data: data, - Valid: true, + data: map[string][]byte{checkPointID: data}, } } type bridgeStore struct { - Data []byte - Valid bool + mu sync.Mutex + data map[string][]byte } -func (m *bridgeStore) Get(_ context.Context, _ string) ([]byte, bool, error) { - if m.Valid { - return m.Data, true, nil +func (m *bridgeStore) Get(_ context.Context, key string) ([]byte, bool, error) { + m.mu.Lock() + defer m.mu.Unlock() + if v, ok := m.data[key]; ok { + return v, true, nil } return nil, false, nil } -func (m 
*bridgeStore) Set(_ context.Context, _ string, checkPoint []byte) error { - m.Data = checkPoint - m.Valid = true +func (m *bridgeStore) Set(_ context.Context, key string, checkPoint []byte) error { + m.mu.Lock() + defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[string][]byte) + } + m.data[key] = checkPoint return nil } diff --git a/adk/middlewares/agentsmd/agentsmd.go b/adk/middlewares/agentsmd/agentsmd.go new file mode 100644 index 000000000..7d29896a7 --- /dev/null +++ b/adk/middlewares/agentsmd/agentsmd.go @@ -0,0 +1,183 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package agentsmd provides a middleware that automatically injects Agents.md +// file contents into model input messages. The injection is transient — content +// is prepended at model call time and never persisted to conversation state, +// so it is naturally excluded from summarization / compression. +package agentsmd + +import ( + "context" + "fmt" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +// Config defines the configuration for the agentsmd middleware. +type Config struct { + // Backend provides file access for loading Agents.md files. + // Implementations can use local filesystem, remote storage, or any other backend. + // Required. + Backend Backend + + // AgentsMDFiles specifies the ordered list of Agents.md file paths to load. 
+ // Files are loaded and injected in the given order. + // Supports @import syntax inside files for recursive inclusion (max depth 5). + AgentsMDFiles []string + + // AllAgentsMDMaxBytes limits the total byte size of all loaded Agents.md content. + // Files are loaded in order; once the cumulative size exceeds this limit, + // remaining files are skipped. Each individual file is always loaded in full. + // 0 means no limit. + AllAgentsMDMaxBytes int + + // OnLoadWarning is an optional callback invoked when a non-fatal error occurs + // during Agents.md file loading (e.g. file not found, circular @import, depth + // exceeded). If nil, warnings are logged via log.Printf. + // + // Note: Backend.Read errors other than os.ErrNotExist (e.g. permission denied, + // I/O errors) are NOT treated as warnings and will abort the loading process. + OnLoadWarning func(filePath string, err error) +} + +// New creates an agentsmd middleware that injects Agents.md content into every +// model call. The content is loaded from the configured file paths via Backend +// on each model invocation. +// +// Recommended: place this middleware AFTER the summarization middleware, so that +// Agents.md content is excluded from summarization/compression. +func New(_ context.Context, cfg *Config) (adk.ChatModelAgentMiddleware, error) { + if err := cfg.validate(); err != nil { + return nil, err + } + + return &middleware{ + BaseChatModelAgentMiddleware: &adk.BaseChatModelAgentMiddleware{}, + loader: newLoaderConfig(cfg.Backend, cfg.AgentsMDFiles, cfg.AllAgentsMDMaxBytes, cfg.OnLoadWarning), + }, nil +} + +type middleware struct { + *adk.BaseChatModelAgentMiddleware + loader *loaderConfig +} + +// WrapModel returns a proxy model that prepends Agents.md content to the input +// messages on every Generate/Stream call. The injected message is never written +// back to ChatModelAgentState, so summarization and reduction middlewares are +// unaffected. 
+func (m *middleware) WrapModel(_ context.Context, cm model.BaseChatModel, _ *adk.ModelContext) (model.BaseChatModel, error) { + return &agentMDModel{ + inner: cm, + loader: m.loader, + }, nil +} + +// agentMDModel wraps a BaseChatModel to prepend Agents.md content to input. +type agentMDModel struct { + inner model.BaseChatModel + loader *loaderConfig +} + +func (m *agentMDModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + messages, err := m.prependAgentMD(ctx, input) + if err != nil { + return nil, err + } + return m.inner.Generate(ctx, messages, opts...) +} + +func (m *agentMDModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + messages, err := m.prependAgentMD(ctx, input) + if err != nil { + return nil, err + } + return m.inner.Stream(ctx, messages, opts...) +} + +const agentsMDCacheKey = "__agentsmd_content_cache__" + +// prependAgentMD loads the current Agents.md content and inserts it before the +// first User role message. If all configured agent files are empty (or skipped), +// the original input is returned unchanged. +// The loaded content is cached in RunLocalValue for the duration of the agent Run(). +func (m *agentMDModel) prependAgentMD(ctx context.Context, input []*schema.Message) ([]*schema.Message, error) { + var content string + + // Try to get cached content from RunLocalValue. + if cached, found, err := adk.GetRunLocalValue(ctx, agentsMDCacheKey); err == nil && found { + if s, ok := cached.(string); ok { + content = s + } + } + + if content == "" { + var err error + content, err = m.loader.load(ctx) + if err != nil { + return nil, fmt.Errorf("[agentsmd]: failed to load agent files: %w", err) + } + // Cache the loaded content for subsequent model calls in this Run(). 
+ if content != "" { + _ = adk.SetRunLocalValue(ctx, agentsMDCacheKey, content) + } + } + if content == "" { + return input, nil + } + + agentMDMsg := &schema.Message{ + Role: schema.User, + Content: content, + } + + // Insert agentMDMsg before the first User role message. + messages := make([]*schema.Message, 0, len(input)+1) + inserted := false + for i, msg := range input { + if !inserted && msg.Role == schema.User { + messages = append(messages, agentMDMsg) + messages = append(messages, input[i:]...) + inserted = true + break + } + messages = append(messages, msg) + } + if !inserted { + // No User message found; append at the end as fallback. + messages = append(messages, agentMDMsg) + } + return messages, nil +} + +func (c *Config) validate() error { + if c == nil { + return fmt.Errorf("[agentsmd]: config is required") + } + if c.Backend == nil { + return fmt.Errorf("[agentsmd]: backend is required") + } + if len(c.AgentsMDFiles) == 0 { + return fmt.Errorf("[agentsmd]: at least one agent file path is required") + } + if c.AllAgentsMDMaxBytes < 0 { + return fmt.Errorf("[agentsmd]: AllAgentsMDMaxBytes must be non-negative") + } + return nil +} diff --git a/adk/middlewares/agentsmd/agentsmd_test.go b/adk/middlewares/agentsmd/agentsmd_test.go new file mode 100644 index 000000000..e3d7e00e9 --- /dev/null +++ b/adk/middlewares/agentsmd/agentsmd_test.go @@ -0,0 +1,1420 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package agentsmd + +import ( + "context" + "fmt" + "os" + "strings" + "testing" + + "github.com/cloudwego/eino/adk/filesystem" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +// --- test helpers --- + +type mockModel struct { + lastInput []*schema.Message +} + +func (m *mockModel) Generate(_ context.Context, input []*schema.Message, _ ...model.Option) (*schema.Message, error) { + m.lastInput = input + return &schema.Message{Role: schema.Assistant, Content: "ok"}, nil +} + +func (m *mockModel) Stream(_ context.Context, input []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + m.lastInput = input + return nil, nil +} + +type memBackend struct { + files map[string]string +} + +func newMemBackend() *memBackend { + return &memBackend{files: make(map[string]string)} +} + +func (b *memBackend) set(path string, content string) { + b.files[path] = content +} + +func (b *memBackend) Read(_ context.Context, req *ReadRequest) (*filesystem.FileContent, error) { + content, ok := b.files[req.FilePath] + if !ok { + return nil, fmt.Errorf("file not found: %s: %w", req.FilePath, os.ErrNotExist) + } + return &filesystem.FileContent{Content: content}, nil +} + +// errBackend always returns a non-ErrNotExist error on Read, simulating I/O failures. +type errBackend struct{} + +func (b *errBackend) Read(_ context.Context, req *ReadRequest) (*filesystem.FileContent, error) { + return nil, fmt.Errorf("permission denied: %s", req.FilePath) +} + +// partialErrBackend returns content for known files and I/O error for others. 
+type partialErrBackend struct { + files map[string]string +} + +func (b *partialErrBackend) Read(_ context.Context, req *ReadRequest) (*filesystem.FileContent, error) { + content, ok := b.files[req.FilePath] + if !ok { + return nil, fmt.Errorf("I/O error reading %s", req.FilePath) + } + return &filesystem.FileContent{Content: content}, nil +} + +// --- tests --- + +func TestNew_Validation(t *testing.T) { + ctx := context.Background() + b := newMemBackend() + + _, err := New(ctx, nil) + if err == nil { + t.Fatal("expected error for nil config") + } + + _, err = New(ctx, &Config{}) + if err == nil { + t.Fatal("expected error for empty config") + } + + _, err = New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/test.md"}, AllAgentsMDMaxBytes: -1}) + if err == nil { + t.Fatal("expected error for negative max bytes") + } + + _, err = New(ctx, &Config{AgentsMDFiles: []string{"/test.md"}}) + if err == nil { + t.Fatal("expected error for nil backend") + } +} + +func TestMiddleware_BasicInjection(t *testing.T) { + b := newMemBackend() + b.set("/agent.md", "You are a helpful assistant.") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + userMsg := &schema.Message{Role: schema.User, Content: "hello"} + if _, err = wrapped.Generate(ctx, []*schema.Message{userMsg}); err != nil { + t.Fatal(err) + } + + if len(mock.lastInput) != 2 { + t.Fatalf("expected 2 messages, got %d", len(mock.lastInput)) + } + if mock.lastInput[0].Role != schema.User { + t.Fatalf("expected first message role User, got %s", mock.lastInput[0].Role) + } + if !strings.Contains(mock.lastInput[0].Content, "You are a helpful assistant.") { + t.Fatalf("expected agent.md content in first message, got %q", mock.lastInput[0].Content) + } + if !strings.Contains(mock.lastInput[0].Content, "<system-reminder>") { + 
t.Fatalf("expected system-reminder tag, got %q", mock.lastInput[0].Content) + } + if mock.lastInput[1].Content != "hello" { + t.Fatalf("expected original message preserved, got %q", mock.lastInput[1].Content) + } +} + +func TestMiddleware_MultipleFiles(t *testing.T) { + b := newMemBackend() + b.set("/a.md", "instruction A") + b.set("/b.md", "instruction B") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/a.md", "/b.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + if _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}); err != nil { + t.Fatal(err) + } + + content := mock.lastInput[0].Content + idxA := strings.Index(content, "instruction A") + idxB := strings.Index(content, "instruction B") + if idxA < 0 || idxB < 0 { + t.Fatalf("both files should be included, content: %q", content) + } + if idxA >= idxB { + t.Fatal("file A should appear before file B") + } +} + +func TestMiddleware_ImportResolution(t *testing.T) { + b := newMemBackend() + b.set("/project/agent.md", "main instructions\n@sub/rules.md\nend") + b.set("/project/sub/rules.md", "imported rule") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/project/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + if _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}); err != nil { + t.Fatal(err) + } + + content := mock.lastInput[0].Content + // Original text should be preserved with @path intact. 
+ if !strings.Contains(content, "main instructions") { + t.Fatalf("should contain original text, got %q", content) + } + if !strings.Contains(content, "@sub/rules.md") { + t.Fatalf("@import reference should be preserved in original text, got %q", content) + } + if !strings.Contains(content, "end") { + t.Fatalf("should contain original trailing text, got %q", content) + } + // Imported file should appear as a separate section. + if !strings.Contains(content, "Contents of /project/sub/rules.md") { + t.Fatalf("imported file should have its own section, got %q", content) + } + if !strings.Contains(content, "imported rule") { + t.Fatalf("imported file content should be present, got %q", content) + } +} + +func TestMiddleware_RecursiveImport(t *testing.T) { + b := newMemBackend() + b.set("/a.md", "top\n@/b.md") + b.set("/b.md", "middle\n@/c.md") + b.set("/c.md", "leaf content") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/a.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + if _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}); err != nil { + t.Fatal(err) + } + + content := mock.lastInput[0].Content + // All three files should appear as separate sections. + for _, section := range []string{"Contents of /a.md", "Contents of /b.md", "Contents of /c.md"} { + if !strings.Contains(content, section) { + t.Fatalf("expected section %q in content, got %q", section, content) + } + } + for _, text := range []string{"top", "middle", "leaf content"} { + if !strings.Contains(content, text) { + t.Fatalf("expected %q in content, got %q", text, content) + } + } + // Sections should appear in order: a, b, c. 
+ idxA := strings.Index(content, "Contents of /a.md") + idxB := strings.Index(content, "Contents of /b.md") + idxC := strings.Index(content, "Contents of /c.md") + if !(idxA < idxB && idxB < idxC) { + t.Fatalf("sections should appear in order a < b < c, got a=%d b=%d c=%d", idxA, idxB, idxC) + } +} + +func TestMiddleware_MaxImportDepth(t *testing.T) { + b := newMemBackend() + for i := 0; i < 7; i++ { + var content string + if i < 6 { + content = fmt.Sprintf("level %d\n@/level%d.md", i, i+1) + } else { + content = fmt.Sprintf("level %d", i) + } + b.set(fmt.Sprintf("/level%d.md", i), content) + } + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/level0.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + // Import failure at depth > 5 is logged, not returned as error. + _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}) + if err != nil { + t.Fatalf("expected no error (depth exceeded is logged), got %v", err) + } + // Levels 0-5 should be present as sections; level 6 fails silently. 
+ content := mock.lastInput[0].Content + for i := 0; i <= 5; i++ { + want := fmt.Sprintf("Contents of /level%d.md", i) + if !strings.Contains(content, want) { + t.Fatalf("expected %q in content, got %q", want, content) + } + } + if strings.Contains(content, "Contents of /level6.md") { + t.Fatalf("level6 should not be present (depth exceeded), got %q", content) + } +} + +func TestMiddleware_CircularImport(t *testing.T) { + b := newMemBackend() + b.set("/a.md", "@/b.md") + b.set("/b.md", "@/a.md") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/a.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + // Circular import failure is logged, not returned as error. + _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}) + if err != nil { + t.Fatalf("expected no error (circular import is logged), got %v", err) + } + // /a.md and /b.md should both be present; the circular ref from b->a is skipped. 
+ content := mock.lastInput[0].Content + if !strings.Contains(content, "Contents of /a.md") { + t.Fatalf("expected /a.md section, got %q", content) + } + if !strings.Contains(content, "Contents of /b.md") { + t.Fatalf("expected /b.md section, got %q", content) + } +} + +func TestMiddleware_MaxBytesLimit(t *testing.T) { + b := newMemBackend() + b.set("/a.md", "AAAA") // 4 bytes + b.set("/b.md", "BBBB") // 4 bytes + + ctx := context.Background() + mw, err := New(ctx, &Config{ + Backend: b, + AgentsMDFiles: []string{"/a.md", "/b.md"}, + AllAgentsMDMaxBytes: 5, // file a (4) fits, file b (4) would exceed + }) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + if _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}); err != nil { + t.Fatal(err) + } + + content := mock.lastInput[0].Content + if !strings.Contains(content, "AAAA") { + t.Fatal("first file should be included") + } + if strings.Contains(content, "BBBB") { + t.Fatal("second file should be excluded due to max bytes") + } +} + +func TestMiddleware_NotPersistedInState(t *testing.T) { + b := newMemBackend() + b.set("/agent.md", "agent instructions") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + originalMsgs := []*schema.Message{{Role: schema.User, Content: "hello"}} + if _, err = wrapped.Generate(ctx, originalMsgs); err != nil { + t.Fatal(err) + } + + if len(originalMsgs) != 1 { + t.Fatalf("original messages should not be modified, got %d messages", len(originalMsgs)) + } + if originalMsgs[0].Content != "hello" { + t.Fatalf("original message should be unchanged, got %q", originalMsgs[0].Content) + } + if len(mock.lastInput) != 2 { + t.Fatalf("model should receive 2 
messages, got %d", len(mock.lastInput)) + } +} + +func TestMiddleware_AbsoluteImportPath(t *testing.T) { + b := newMemBackend() + b.set("/project/main.md", "start\n@/shared/imported.md\nend") + b.set("/shared/imported.md", "absolute import content") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/project/main.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + if _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}); err != nil { + t.Fatal(err) + } + + content := mock.lastInput[0].Content + // @path preserved in original text. + if !strings.Contains(content, "@/shared/imported.md") { + t.Fatalf("@import reference should be preserved, got %q", content) + } + // Imported content in separate section. + if !strings.Contains(content, "Contents of /shared/imported.md") { + t.Fatalf("expected separate section for imported file, got %q", content) + } + if !strings.Contains(content, "absolute import content") { + t.Fatalf("expected absolute import content, got %q", content) + } +} + +func TestMiddleware_ImportAsSeparateSection(t *testing.T) { + b := newMemBackend() + b.set("/project/agent.md", "Please read @sub/rules.md and also @sub/style.md for guidance.") + b.set("/project/sub/rules.md", "RULE_CONTENT") + b.set("/project/sub/style.md", "STYLE_CONTENT") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/project/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + if _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}); err != nil { + t.Fatal(err) + } + + content := mock.lastInput[0].Content + // Original text preserved with @paths intact. 
+ if !strings.Contains(content, "Please read @sub/rules.md and also @sub/style.md for guidance.") { + t.Fatalf("original text with @paths should be preserved, got %q", content) + } + // Imported files appear as separate sections. + if !strings.Contains(content, "Contents of /project/sub/rules.md") { + t.Fatalf("expected rules.md section, got %q", content) + } + if !strings.Contains(content, "RULE_CONTENT") { + t.Fatalf("expected imported rule content, got %q", content) + } + if !strings.Contains(content, "Contents of /project/sub/style.md") { + t.Fatalf("expected style.md section, got %q", content) + } + if !strings.Contains(content, "STYLE_CONTENT") { + t.Fatalf("expected imported style content, got %q", content) + } + + // Sections should be ordered: agent.md, rules.md, style.md. + idxAgent := strings.Index(content, "Contents of /project/agent.md") + idxRules := strings.Index(content, "Contents of /project/sub/rules.md") + idxStyle := strings.Index(content, "Contents of /project/sub/style.md") + if !(idxAgent < idxRules && idxRules < idxStyle) { + t.Fatalf("sections should appear in order agent < rules < style, got agent=%d rules=%d style=%d", idxAgent, idxRules, idxStyle) + } +} + +// --- loader-specific tests --- + +func TestLoader_NoImportsPassthrough(t *testing.T) { + // Content without any @path should be returned as-is in its section. + b := newMemBackend() + b.set("/agent.md", "plain text without imports\nline two") + + l := newLoaderConfig(b, []string{"/agent.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(content, "plain text without imports") { + t.Fatalf("expected plain content, got %q", content) + } + if !strings.Contains(content, "line two") { + t.Fatalf("expected second line, got %q", content) + } +} + +func TestLoader_ImportAsSeparateSection(t *testing.T) { + // @path in the middle of a sentence should be preserved; imported file is a separate section. 
+ b := newMemBackend() + b.set("/doc.md", "before @/snippet.md after") + b.set("/snippet.md", "INJECTED") + + l := newLoaderConfig(b, []string{"/doc.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + // Original text preserved. + if !strings.Contains(content, "before @/snippet.md after") { + t.Fatalf("original text should be preserved with @path, got %q", content) + } + // Imported file in separate section. + if !strings.Contains(content, "Contents of /snippet.md") { + t.Fatalf("expected separate section for snippet.md, got %q", content) + } + if !strings.Contains(content, "INJECTED") { + t.Fatalf("expected imported content, got %q", content) + } +} + +func TestLoader_MultipleImportsSameLine(t *testing.T) { + // Multiple @path on one line should each get a separate section. + b := newMemBackend() + b.set("/doc.md", "see @/a.txt and @/b.txt here") + b.set("/a.txt", "AAA") + b.set("/b.txt", "BBB") + + l := newLoaderConfig(b, []string{"/doc.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + // Original text preserved. + if !strings.Contains(content, "see @/a.txt and @/b.txt here") { + t.Fatalf("original text should be preserved, got %q", content) + } + // Each imported file has its own section. + if !strings.Contains(content, "Contents of /a.txt") { + t.Fatalf("expected section for a.txt, got %q", content) + } + if !strings.Contains(content, "AAA") { + t.Fatalf("expected a.txt content, got %q", content) + } + if !strings.Contains(content, "Contents of /b.txt") { + t.Fatalf("expected section for b.txt, got %q", content) + } + if !strings.Contains(content, "BBB") { + t.Fatalf("expected b.txt content, got %q", content) + } +} + +func TestLoader_SameFileTwiceOnSameLine(t *testing.T) { + // The same file referenced twice should appear only once as a section (deduped). 
+ b := newMemBackend() + b.set("/doc.md", "@/shared.md and @/shared.md again") + b.set("/shared.md", "SHARED") + + l := newLoaderConfig(b, []string{"/doc.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + // Original text preserved. + if !strings.Contains(content, "@/shared.md and @/shared.md again") { + t.Fatalf("original text should be preserved, got %q", content) + } + // shared.md content should appear only once (deduped). + count := strings.Count(content, "Contents of /shared.md") + if count != 1 { + t.Fatalf("expected shared.md section to appear once (deduped), got %d in %q", count, content) + } +} + +func TestLoader_ImportFileNotFound(t *testing.T) { + b := newMemBackend() + b.set("/doc.md", "load @/missing.md please") + + l := newLoaderConfig(b, []string{"/doc.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error (missing import is logged), got %v", err) + } + // Original text preserved; missing file simply has no section. + if !strings.Contains(content, "load @/missing.md please") { + t.Fatalf("expected original text preserved, got %q", content) + } + if strings.Contains(content, "Contents of /missing.md") { + t.Fatalf("missing file should not have a section, got %q", content) + } +} + +func TestLoader_RelativePathResolution(t *testing.T) { + // Relative path should resolve relative to the host file's directory. + b := newMemBackend() + b.set("/a/b/host.md", "ref @../c/target.md done") + b.set("/a/c/target.md", "TARGET") + + l := newLoaderConfig(b, []string{"/a/b/host.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + // Original text preserved. + if !strings.Contains(content, "ref @../c/target.md done") { + t.Fatalf("original text should be preserved, got %q", content) + } + // Imported file as separate section. 
+ if !strings.Contains(content, "Contents of /a/c/target.md") { + t.Fatalf("expected section for target.md, got %q", content) + } + if !strings.Contains(content, "TARGET") { + t.Fatalf("expected imported content, got %q", content) + } +} + +func TestLoader_RelativeTopLevelPath(t *testing.T) { + // Top-level file uses relative path; imports with ./ resolve correctly. + b := newMemBackend() + b.set("sub/agents.md", "start @./other.md end") + b.set("sub/other.md", "OTHER CONTENT") + + l := newLoaderConfig(b, []string{"sub/agents.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(content, "start @./other.md end") { + t.Fatalf("expected original text preserved, got %q", content) + } + if !strings.Contains(content, "OTHER CONTENT") { + t.Fatalf("expected imported content, got %q", content) + } +} + +func TestLoader_RelativeTopLevelWithDotDotImport(t *testing.T) { + // Top-level file uses relative path; import with ../ resolves correctly. + b := newMemBackend() + b.set("sub/agents.md", "see @../shared/x.md here") + b.set("shared/x.md", "SHARED X") + + l := newLoaderConfig(b, []string{"sub/agents.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(content, "SHARED X") { + t.Fatalf("expected imported content, got %q", content) + } + // filepath.Clean should normalize "sub/../shared/x.md" to "shared/x.md" + if !strings.Contains(content, "Contents of shared/x.md") { + t.Fatalf("expected normalized path in section header, got %q", content) + } +} + +func TestLoader_RelativeTopLevelDedup(t *testing.T) { + // Two top-level relative paths that resolve to the same file via filepath.Clean + // should be deduped (loaded only once). 
+ b := newMemBackend() + b.set("sub/a.md", "CONTENT A") + + l := newLoaderConfig(b, []string{"sub/a.md", "./sub/a.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + count := strings.Count(content, "CONTENT A") + if count != 1 { + t.Fatalf("expected file loaded once (deduped), got %d occurrences in %q", count, content) + } +} + +func TestLoader_AbsoluteTopLevelWithRelativeImport(t *testing.T) { + // Absolute top-level path with relative @import resolves correctly. + b := newMemBackend() + b.set("/project/agents.md", "ref @./lib/helper.md done") + b.set("/project/lib/helper.md", "HELPER") + + l := newLoaderConfig(b, []string{"/project/agents.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(content, "HELPER") { + t.Fatalf("expected imported content, got %q", content) + } + if !strings.Contains(content, "Contents of /project/lib/helper.md") { + t.Fatalf("expected section header, got %q", content) + } +} + +func TestLoader_AbsoluteTopLevelWithDotDotImport(t *testing.T) { + // Absolute top-level path; @import with ../ resolves and normalizes. + b := newMemBackend() + b.set("/project/sub/agents.md", "load @../shared/x.md here") + b.set("/project/shared/x.md", "SHARED") + + l := newLoaderConfig(b, []string{"/project/sub/agents.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(content, "SHARED") { + t.Fatalf("expected imported content, got %q", content) + } + // filepath.Clean normalizes "/project/sub/../shared/x.md" to "/project/shared/x.md" + if !strings.Contains(content, "Contents of /project/shared/x.md") { + t.Fatalf("expected normalized path in section header, got %q", content) + } +} + +func TestLoader_RelativeImportDedup(t *testing.T) { + // Two different relative @import paths that resolve to the same file + // should be deduped via filepath.Clean. 
+ b := newMemBackend() + b.set("/a/main.md", "first @/a/b/shared.md second @../a/b/shared.md end") + b.set("/a/b/shared.md", "SHARED ONCE") + + l := newLoaderConfig(b, []string{"/a/main.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + count := strings.Count(content, "SHARED ONCE") + if count != 1 { + t.Fatalf("expected shared file loaded once (deduped), got %d in %q", count, content) + } +} + +func TestLoader_NestedRelativeImport(t *testing.T) { + // File A imports B via relative path, B imports C via relative path. + // All three should appear as separate sections. + b := newMemBackend() + b.set("/root/main.md", "start @sub/mid.md end") + b.set("/root/sub/mid.md", "mid @deep/leaf.md mid_end") + b.set("/root/sub/deep/leaf.md", "LEAF") + + l := newLoaderConfig(b, []string{"/root/main.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + for _, section := range []string{"Contents of /root/main.md", "Contents of /root/sub/mid.md", "Contents of /root/sub/deep/leaf.md"} { + if !strings.Contains(content, section) { + t.Fatalf("expected section %q, got %q", section, content) + } + } + if !strings.Contains(content, "LEAF") { + t.Fatalf("expected leaf content, got %q", content) + } +} + +func TestLoader_TransitiveImport(t *testing.T) { + // Imported file itself contains @imports; all should appear as separate sections. 
+ b := newMemBackend() + b.set("/main.md", "header @/mid.md footer") + b.set("/mid.md", "mid-start @/leaf.md mid-end") + b.set("/leaf.md", "LEAF_VALUE") + + l := newLoaderConfig(b, []string{"/main.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + for _, section := range []string{"Contents of /main.md", "Contents of /mid.md", "Contents of /leaf.md"} { + if !strings.Contains(content, section) { + t.Fatalf("expected section %q, got %q", section, content) + } + } + if !strings.Contains(content, "LEAF_VALUE") { + t.Fatalf("expected leaf value, got %q", content) + } +} + +func TestLoader_EmptyFile(t *testing.T) { + b := newMemBackend() + b.set("/empty.md", "") + + l := newLoaderConfig(b, []string{"/empty.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + // Empty file is treated as non-existent, so output should be empty. + if content != "" { + t.Fatalf("expected empty output for empty file, got %q", content) + } +} + +func TestLoader_MaxBytesFirstFileFull(t *testing.T) { + // Even if the first file alone exceeds maxBytes, it should still be loaded in full. + b := newMemBackend() + b.set("/big.md", "ABCDEFGHIJ") // 10 bytes + + l := newLoaderConfig(b, []string{"/big.md"}, 3, nil) + content, err := l.load(context.Background()) // maxBytes=3, but first file always loads + if err != nil { + t.Fatal(err) + } + if !strings.Contains(content, "ABCDEFGHIJ") { + t.Fatalf("first file should always load in full, got %q", content) + } +} + +func TestLoader_CircularImportInline(t *testing.T) { + // Circular reference via @import should be detected, logged, and skipped. 
+ b := newMemBackend() + b.set("/a.md", "text @/b.md more") + b.set("/b.md", "ref @/a.md back") + + l := newLoaderConfig(b, []string{"/a.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error (circular import is logged), got %v", err) + } + // Both a and b should have sections; circular back-reference a from b is skipped. + if !strings.Contains(content, "Contents of /a.md") { + t.Fatalf("expected /a.md section, got %q", content) + } + if !strings.Contains(content, "Contents of /b.md") { + t.Fatalf("expected /b.md section, got %q", content) + } +} + +func TestLoader_MaxDepthInline(t *testing.T) { + // Deep chain via @import should be logged at depth > 5, not returned as error. + b := newMemBackend() + for i := 0; i < 7; i++ { + var content string + if i < 6 { + content = fmt.Sprintf("level%d @/level%d.md tail", i, i+1) + } else { + content = fmt.Sprintf("level%d", i) + } + b.set(fmt.Sprintf("/level%d.md", i), content) + } + + l := newLoaderConfig(b, []string{"/level0.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error (depth exceeded is logged), got %v", err) + } + // Levels 0-5 should have sections. + for i := 0; i <= 5; i++ { + want := fmt.Sprintf("Contents of /level%d.md", i) + if !strings.Contains(content, want) { + t.Fatalf("expected %q in content, got %q", want, content) + } + } + // Level 6 should not be present. + if strings.Contains(content, "Contents of /level6.md") { + t.Fatalf("level6 should not be present (depth exceeded), got %q", content) + } +} + +func TestLoader_DiamondDependency(t *testing.T) { + // A imports B and D; B imports C; D also imports C. + // C should appear only once (deduped across the whole load). 
+ b := newMemBackend() + b.set("/a.md", "start @/b.md middle @/d.md end") + b.set("/b.md", "B(@/c.md)") + b.set("/d.md", "D(@/c.md)") + b.set("/c.md", "SHARED") + + l := newLoaderConfig(b, []string{"/a.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("diamond dependency should not be circular, got error: %v", err) + } + + // C should appear only once as a section (deduped). + count := strings.Count(content, "Contents of /c.md") + if count != 1 { + t.Fatalf("expected /c.md section once (deduped), got %d in %q", count, content) + } + // All files should have sections. + for _, section := range []string{"Contents of /a.md", "Contents of /b.md", "Contents of /c.md", "Contents of /d.md"} { + if !strings.Contains(content, section) { + t.Fatalf("expected section %q, got %q", section, content) + } + } +} + +func TestLoader_AtSignInNormalText(t *testing.T) { + // Bare @word without "/" or file extension should not trigger import. + // Email-like patterns (@example.com) with non-allowed extensions should also be ignored. 
+ b := newMemBackend() + b.set("/agent.md", "contact me @ anytime or @ spaces and @someone mentioned and user@example.com and @company.org") + + l := newLoaderConfig(b, []string{"/agent.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(content, "contact me @ anytime") { + t.Fatalf("bare @ should not trigger import, got %q", content) + } + if !strings.Contains(content, "@someone mentioned") { + t.Fatalf("@someone without / or extension should not trigger import, got %q", content) + } + if !strings.Contains(content, "@example.com") { + t.Fatalf("email-like @example.com should not trigger import, got %q", content) + } + if !strings.Contains(content, "@company.org") { + t.Fatalf("email-like @company.org should not trigger import, got %q", content) + } +} + +func TestMiddleware_Stream(t *testing.T) { + b := newMemBackend() + b.set("/agent.md", "stream test") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + _, _ = wrapped.Stream(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}) + + if len(mock.lastInput) != 2 { + t.Fatalf("expected 2 messages for stream, got %d", len(mock.lastInput)) + } + if !strings.Contains(mock.lastInput[0].Content, "stream test") { + t.Fatalf("expected agent.md content in stream input, got %q", mock.lastInput[0].Content) + } +} + +func TestLoader_MaxBytesWithImports(t *testing.T) { + // Two top-level files that both import the same shared file. + // Budget should account for imported file bytes. 
+ b := newMemBackend() + b.set("/a.md", "A(@/shared.md)") + b.set("/b.md", "B(@/shared.md)") + b.set("/shared.md", strings.Repeat("X", 100)) // 100 bytes + + l := newLoaderConfig(b, []string{"/a.md", "/b.md"}, 120, nil) + // /a.md = 14 bytes + /shared.md = 100 bytes => 114 total after /a.md. + // Budget = 120: /b.md (14 bytes) would push to 128, exceeding budget. + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("load failed: %v", err) + } + + // /a.md and its import should be included. + if !strings.Contains(content, strings.Repeat("X", 100)) { + t.Fatal("expected /a.md with shared content to be included") + } + + // /b.md should be excluded because totalBytes exceeded budget after loading /a.md. + if strings.Contains(content, "B(") { + t.Fatalf("expected /b.md to be excluded due to budget, got %q", content) + } +} + +func TestNew_Validation_EmptyAgentFiles(t *testing.T) { + ctx := context.Background() + b := newMemBackend() + + _, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{}}) + if err == nil { + t.Fatal("expected error for empty agent files") + } + if !strings.Contains(err.Error(), "at least one agent file path is required") { + t.Fatalf("unexpected error message: %v", err) + } +} + +func TestMiddleware_GenerateError(t *testing.T) { + // Non-ErrNotExist errors (e.g. permission denied) should propagate. 
+ b := &errBackend{} + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/file.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + _, err = wrapped.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}) + if err == nil { + t.Fatal("expected error when backend read fails with non-ErrNotExist") + } + if !strings.Contains(err.Error(), "failed to load agent files") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestMiddleware_StreamError(t *testing.T) { + // Non-ErrNotExist errors (e.g. permission denied) should propagate. + b := &errBackend{} + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/file.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + _, err = wrapped.Stream(ctx, []*schema.Message{{Role: schema.User, Content: "hi"}}) + if err == nil { + t.Fatal("expected error when backend read fails with non-ErrNotExist for stream") + } + if !strings.Contains(err.Error(), "failed to load agent files") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestLoader_DuplicateTopLevelFiles(t *testing.T) { + // Same file listed twice in AgentFiles; second should be deduped via seen map. + b := newMemBackend() + b.set("/agent.md", "unique content") + + l := newLoaderConfig(b, []string{"/agent.md", "/agent.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + + count := strings.Count(content, "Contents of /agent.md") + if count != 1 { + t.Fatalf("expected /agent.md section once (deduped), got %d", count) + } +} + +func TestLoader_LoadFileError(t *testing.T) { + // Missing file (ErrNotExist) is silently skipped. 
+ b := newMemBackend() + l := newLoaderConfig(b, []string{"/missing.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected missing file to be skipped, got error: %v", err) + } + if content != "" { + t.Fatalf("expected empty output, got %q", content) + } +} + +func TestLoader_MaxBytesStopsImports(t *testing.T) { + // When budget is exhausted, further imports in collectImports should be skipped. + b := newMemBackend() + b.set("/main.md", "@/big.md @/small.md") + b.set("/big.md", strings.Repeat("B", 200)) + b.set("/small.md", "SMALL") + + l := newLoaderConfig(b, []string{"/main.md"}, 50, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + + // main.md itself is loaded (always), big.md pushes over budget, + // small.md should be skipped. + if !strings.Contains(content, "Contents of /main.md") { + t.Fatal("main.md should be present") + } + if strings.Contains(content, "SMALL") { + t.Fatal("small.md should be skipped after budget exhausted") + } +} + +func TestFormatContent_Empty(t *testing.T) { + // formatContent with nil/empty slice should return empty string. + if got := formatContent(nil); got != "" { + t.Fatalf("expected empty string for nil, got %q", got) + } + if got := formatContent([]loadedFile{}); got != "" { + t.Fatalf("expected empty string for empty slice, got %q", got) + } +} + +func TestMiddleware_AllFilesEmpty(t *testing.T) { + // When all agent files have empty content, loader returns "" and + // prependAgentMD returns the original input unchanged. 
+ b := newMemBackend() + b.set("/agent.md", "") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + userMsg := []*schema.Message{{Role: schema.User, Content: "hello"}} + if _, err = wrapped.Generate(ctx, userMsg); err != nil { + t.Fatal(err) + } + // Empty file produces no agentmd content, so original messages pass through unchanged. + if len(mock.lastInput) != 1 { + t.Fatalf("expected 1 message (no agentmd prepended), got %d", len(mock.lastInput)) + } + if mock.lastInput[0].Content != "hello" { + t.Fatalf("expected original message unchanged, got %q", mock.lastInput[0].Content) + } +} + +func TestLoader_ExactOutput(t *testing.T) { + // Verify the exact output format matches the expected structure: + // each file (top-level and imported) gets its own "Contents of ..." section, + // @path references are preserved in the original text. + b := newMemBackend() + b.set("/project/CLAUDE.md", "this is project claude.md\n\n- git workflow @git/git-instructions.md") + b.set("/project/git/git-instructions.md", "this is git-instructions.md") + + l := newLoaderConfig(b, []string{"/project/CLAUDE.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatal(err) + } + + expected := ` +As you answer the user's questions, you can use the following context: +Codebase and user instructions are shown below. Be sure to adhere to these instructions. IMPORTANT: These instructions OVERRIDE any default behavior and you MUST follow them exactly as written. + +Contents of /project/CLAUDE.md (instructions): + +this is project claude.md + +- git workflow @git/git-instructions.md + +Contents of /project/git/git-instructions.md (instructions): + +this is git-instructions.md +IMPORTANT: this context may or may not be relevant to your tasks. 
You should not respond to this context unless it is highly relevant to your task. +` + + if content != expected { + t.Fatalf("output mismatch.\n\ngot:\n%s\n\nexpected:\n%s", content, expected) + } +} + +func TestLoader_MissingFileSkipped(t *testing.T) { + b := newMemBackend() + b.set("/good.md", "GOOD CONTENT") + // /missing.md is not set + + l := newLoaderConfig(b, []string{"/missing.md", "/good.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error for missing file, got %v", err) + } + if !strings.Contains(content, "GOOD CONTENT") { + t.Fatal("expected good.md content in output") + } +} + +func TestLoader_AllMissingFilesSkipped(t *testing.T) { + b := newMemBackend() + + l := newLoaderConfig(b, []string{"/missing1.md", "/missing2.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error for missing files, got %v", err) + } + if content != "" { + t.Fatalf("expected empty output when all files missing, got %q", content) + } +} + +func TestLoader_CircularImportSkipped(t *testing.T) { + b := newMemBackend() + b.set("/a.md", "A content @/b.md") + b.set("/b.md", "B content @/a.md") + + // Circular import in collectImports is logged via onWarning and skipped. 
+ l := newLoaderConfig(b, []string{"/a.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !strings.Contains(content, "A content") { + t.Fatal("expected a.md content") + } + if !strings.Contains(content, "B content") { + t.Fatal("expected b.md content") + } +} + +func TestLoader_DepthExceededSkipped(t *testing.T) { + b := newMemBackend() + // Create a chain that exceeds maxImportDepth (5) + b.set("/l0.md", "@/l1.md") + b.set("/l1.md", "@/l2.md") + b.set("/l2.md", "@/l3.md") + b.set("/l3.md", "@/l4.md") + b.set("/l4.md", "@/l5.md") + b.set("/l5.md", "@/l6.md") + b.set("/l6.md", "DEEP") + + l := newLoaderConfig(b, []string{"/l0.md"}, 0, nil) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error for depth exceeded, got %v", err) + } + // Should have content up to the depth limit, deep file skipped. + if !strings.Contains(content, "/l0.md") { + t.Fatal("expected l0.md in output") + } +} + +func TestLoader_OnLoadWarningCallback(t *testing.T) { + b := newMemBackend() + b.set("/good.md", "GOOD CONTENT") + + var warnings []error + onWarning := func(filePath string, err error) { + warnings = append(warnings, fmt.Errorf("%s: %w", filePath, err)) + } + + l := newLoaderConfig(b, []string{"/missing.md", "/good.md"}, 0, onWarning) + content, err := l.load(context.Background()) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !strings.Contains(content, "GOOD CONTENT") { + t.Fatal("expected good.md content in output") + } + if len(warnings) == 0 { + t.Fatal("expected at least one warning for missing file") + } + if !strings.Contains(warnings[0].Error(), "file not found") { + t.Fatalf("expected file not found warning, got %v", warnings[0]) + } +} + +func TestMiddleware_MissingFile_Generate(t *testing.T) { + b := newMemBackend() + // /missing.md not set — will fail to read + + ctx := context.Background() + mw, err := New(ctx, &Config{ + 
Backend: b, + AgentsMDFiles: []string{"/missing.md"}, + }) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + userMsg := []*schema.Message{{Role: schema.User, Content: "hello"}} + _, err = wrapped.Generate(ctx, userMsg) + if err != nil { + t.Fatalf("expected no error for missing file, got %v", err) + } + // No agent.md content, so original messages should be passed through unchanged. + if len(mock.lastInput) != 1 { + t.Fatalf("expected 1 message (no agentmd prepended), got %d", len(mock.lastInput)) + } +} + +func TestMiddleware_MissingFile_Stream(t *testing.T) { + b := newMemBackend() + + ctx := context.Background() + mw, err := New(ctx, &Config{ + Backend: b, + AgentsMDFiles: []string{"/missing.md"}, + }) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + userMsg := []*schema.Message{{Role: schema.User, Content: "hello"}} + _, err = wrapped.Stream(ctx, userMsg) + if err != nil { + t.Fatalf("expected no error for missing file, got %v", err) + } + if len(mock.lastInput) != 1 { + t.Fatalf("expected 1 message (no agentmd prepended), got %d", len(mock.lastInput)) + } +} + +func TestMiddleware_InsertBeforeFirstUserMessage(t *testing.T) { + b := newMemBackend() + b.set("/agent.md", "agent instructions") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + // Input has a System message before the User message. 
+ input := []*schema.Message{ + {Role: schema.System, Content: "system prompt"}, + {Role: schema.User, Content: "hello"}, + } + if _, err = wrapped.Generate(ctx, input); err != nil { + t.Fatal(err) + } + + if len(mock.lastInput) != 3 { + t.Fatalf("expected 3 messages, got %d", len(mock.lastInput)) + } + if mock.lastInput[0].Role != schema.System { + t.Fatalf("expected first message role System, got %s", mock.lastInput[0].Role) + } + if mock.lastInput[0].Content != "system prompt" { + t.Fatalf("expected system prompt preserved, got %q", mock.lastInput[0].Content) + } + if mock.lastInput[1].Role != schema.User || !strings.Contains(mock.lastInput[1].Content, "agent instructions") { + t.Fatalf("expected agentmd message before user message, got role=%s content=%q", mock.lastInput[1].Role, mock.lastInput[1].Content) + } + if mock.lastInput[2].Role != schema.User || mock.lastInput[2].Content != "hello" { + t.Fatalf("expected original user message at index 2, got role=%s content=%q", mock.lastInput[2].Role, mock.lastInput[2].Content) + } +} + +func TestMiddleware_InsertWithNoUserMessage(t *testing.T) { + b := newMemBackend() + b.set("/agent.md", "agent instructions") + + ctx := context.Background() + mw, err := New(ctx, &Config{Backend: b, AgentsMDFiles: []string{"/agent.md"}}) + if err != nil { + t.Fatal(err) + } + + mock := &mockModel{} + wrapped, err := mw.WrapModel(ctx, mock, nil) + if err != nil { + t.Fatal(err) + } + + // Input has no User message at all. 
+ input := []*schema.Message{ + {Role: schema.System, Content: "system prompt"}, + {Role: schema.Assistant, Content: "assistant reply"}, + } + if _, err = wrapped.Generate(ctx, input); err != nil { + t.Fatal(err) + } + + if len(mock.lastInput) != 3 { + t.Fatalf("expected 3 messages, got %d", len(mock.lastInput)) + } + if mock.lastInput[0].Role != schema.System { + t.Fatalf("expected System at index 0, got %s", mock.lastInput[0].Role) + } + if mock.lastInput[1].Role != schema.Assistant { + t.Fatalf("expected Assistant at index 1, got %s", mock.lastInput[1].Role) + } + if mock.lastInput[2].Role != schema.User || !strings.Contains(mock.lastInput[2].Content, "agent instructions") { + t.Fatalf("expected agentmd appended at end, got role=%s content=%q", mock.lastInput[2].Role, mock.lastInput[2].Content) + } +} + +func TestLoader_ImportIOError(t *testing.T) { + // When an imported file returns a non-ErrNotExist error (e.g. I/O error), + // the load should propagate the error (covers collectImports and loadFile error paths). + b := &partialErrBackend{ + files: map[string]string{ + "/main.md": "content @/broken.md", + }, + // /broken.md is NOT in the map, so Read returns I/O error (not ErrNotExist) + } + + l := newLoaderConfig(b, []string{"/main.md"}, 0, nil) + _, err := l.load(context.Background()) + if err == nil { + t.Fatal("expected error from I/O failure on imported file") + } + if !strings.Contains(err.Error(), "I/O error") { + t.Fatalf("expected I/O error, got: %v", err) + } +} diff --git a/adk/middlewares/agentsmd/loader.go b/adk/middlewares/agentsmd/loader.go new file mode 100644 index 000000000..db733383b --- /dev/null +++ b/adk/middlewares/agentsmd/loader.go @@ -0,0 +1,299 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package agentsmd + +import ( + "context" + "errors" + "fmt" + "log" + "os" + "path/filepath" + "regexp" + "strings" + + "github.com/cloudwego/eino/adk/filesystem" + "github.com/cloudwego/eino/adk/internal" +) + +// importRegex matches @path/to/file anywhere in text. +// The path must start with a letter, digit, dot, underscore, slash, or tilde, followed by +// path characters (letters, digits, dots, slashes, hyphens, underscores). +// A post-match filter further requires the path to contain "/" or end with +// an allowed extension (see allowedImportExts), so bare words like @someone +// and email-like patterns like @example.com are ignored. +var importRegex = regexp.MustCompile(`@([a-zA-Z0-9_.~/][a-zA-Z0-9_.~/\-]*)`) + +// allowedImportExts is the set of file extensions recognised as @import targets. +// Paths without "/" must end with one of these extensions to be treated as imports; +// this avoids false positives on email addresses (@example.com) and mentions (@foo.bar). +var allowedImportExts = map[string]bool{ + ".md": true, + ".txt": true, + ".mdx": true, + ".yaml": true, + ".yml": true, + ".json": true, + ".toml": true, +} + +const maxImportDepth = 5 + +// ReadRequest is an alias for filesystem.ReadRequest. +type ReadRequest = filesystem.ReadRequest +type FileContent = filesystem.FileContent + +// Backend defines the file access interface for loading Agents.md files. +// Implementations can use local filesystem, remote storage, or any other backend. +type Backend interface { + // Read reads the content of a file. 
+ // If the file does not exist, implementations should return an error wrapping + // os.ErrNotExist (so that errors.Is(err, os.ErrNotExist) returns true). This allows the loader + // to silently skip missing files and notify via OnLoadWarning callback. + // Other errors (e.g. permission denied, I/O errors) will abort the loading process. + Read(ctx context.Context, req *ReadRequest) (*FileContent, error) +} + +// loaderConfig holds the immutable configuration for creating loaders. +// It is safe for concurrent use by multiple goroutines. +type loaderConfig struct { + backend Backend + files []string // ordered file paths from config + maxBytes int // cumulative read budget; 0 means unlimited + onWarning func(filePath string, err error) // callback for non-fatal loading warnings +} + +func newLoaderConfig(backend Backend, files []string, maxBytes int, onWarning func(filePath string, err error)) *loaderConfig { + if onWarning == nil { + onWarning = func(filePath string, err error) { + log.Printf("[agentsmd] warning: %s: %v", filePath, err) + } + } + return &loaderConfig{ + backend: backend, + files: files, + maxBytes: maxBytes, + onWarning: onWarning, + } +} + +// loader handles loading and @import resolution for agents.md files. +// A new loader is created for each load() call to avoid sharing mutable state +// (totalBytes) across concurrent invocations. +type loader struct { + *loaderConfig + totalBytes int // accumulated bytes during this load call +} + +func (cfg *loaderConfig) newLoader() *loader { + return &loader{loaderConfig: cfg} +} + +// load reads all agents.md files and returns the formatted content. +// Each top-level file and its @imported files appear as separate sections. 
+func (cfg *loaderConfig) load(ctx context.Context) (string, error) {
+	// A fresh loader per call keeps the running byte count private to this invocation.
+	ld := cfg.newLoader()
+
+	// loaded is handed to every loadFile call of this invocation, so a path
+	// that was already read is not read (or counted) a second time.
+	loaded := map[string]bool{}
+	var sections []loadedFile
+
+	for idx, top := range ld.files {
+		// Every top-level file starts with its own, empty ancestor chain.
+		ancestors := map[string]bool{}
+
+		entries, err := ld.loadFile(ctx, top, 0, ancestors, loaded)
+		if err != nil {
+			return "", fmt.Errorf("failed to load %q: %w", top, err)
+		}
+
+		// The budget is evaluated after the file (and its imports) were read.
+		// The file at index 0 is exempt and is always kept; any later file
+		// that pushes the total over the limit is dropped and loading stops.
+		overBudget := ld.maxBytes > 0 && ld.totalBytes > ld.maxBytes
+		if overBudget && idx > 0 {
+			ld.onWarning(top, fmt.Errorf("skipped: cumulative size %d exceeds max bytes %d", ld.totalBytes, ld.maxBytes))
+			break
+		}
+
+		sections = append(sections, entries...)
+	}
+
+	return formatContent(sections), nil
+}
+
+// loadFile reads a file via Backend and collects @imported files as separate entries.
+// Returns a slice where the first element is this file itself, followed by all
+// transitively imported files (in encounter order, preserving @path in original text).
+// visited tracks the current ancestor chain to detect circular imports.
+// seen tracks globally loaded files to avoid duplicate reads and byte counting.
+func (l *loader) loadFile(ctx context.Context, filePath string, depth int, visited map[string]bool, seen map[string]bool) ([]loadedFile, error) {
+	// Normalise first so the visited/seen lookups below use one spelling per file.
+	filePath = filepath.Clean(filePath)
+
+	// None of these conditions is fatal: the file is skipped and the caller
+	// simply receives no entries for it.
+	switch {
+	case depth > maxImportDepth:
+		l.onWarning(filePath, fmt.Errorf("@import depth exceeds maximum of %d", maxImportDepth))
+		return nil, nil
+	case visited[filePath]:
+		l.onWarning(filePath, fmt.Errorf("circular @import detected"))
+		return nil, nil
+	case seen[filePath]:
+		// Already loaded earlier in this load call; nothing to add.
+		return nil, nil
+	}
+
+	// Mark this file as an ancestor only while its imports are being resolved;
+	// the entry is removed again when this call returns.
+	visited[filePath] = true
+	defer delete(visited, filePath)
+
+	// NOTE(review): Offset 1 presumably means "from the first line" — confirm
+	// against filesystem.ReadRequest.
+	resp, readErr := l.backend.Read(ctx, &ReadRequest{FilePath: filePath, Offset: 1})
+	if readErr != nil {
+		if !errors.Is(readErr, os.ErrNotExist) {
+			// Anything other than "not found" aborts the load.
+			return nil, readErr
+		}
+		l.onWarning(filePath, fmt.Errorf("file not found, skipping"))
+		return nil, nil
+	}
+
+	// A nil response is treated the same as an empty file.
+	var body string
+	if resp != nil {
+		body = resp.Content
+	}
+
+	seen[filePath] = true
+	l.totalBytes += len(body)
+
+	// Empty files are recorded as seen but produce no section and are not
+	// scanned for imports.
+	if body == "" {
+		return nil, nil
+	}
+
+	// Imports become their own sections; body itself is left untouched.
+	nested, err := l.collectImports(ctx, filePath, body, depth, visited, seen)
+	if err != nil {
+		return nil, err
+	}
+
+	// This file leads, its imports follow in encounter order.
+	out := make([]loadedFile, 0, 1+len(nested))
+	out = append(out, loadedFile{path: filePath, content: body})
+	return append(out, nested...), nil
+}
+
+// collectImports scans content for @path/to/file references and loads each
+// imported file (plus its transitive imports). The original content is NOT modified.
+// Returns the list of imported loadedFile entries in encounter order.
+// seen is shared across the entire load call to avoid duplicate reads.
+// Non-fatal errors (file not found, depth exceeded, circular import) are reported
+// via onWarning and skipped. Fatal errors (e.g. I/O) are returned.
+func (l *loader) collectImports(ctx context.Context, hostPath, content string, depth int, visited map[string]bool, seen map[string]bool) ([]loadedFile, error) { + dir := filepath.Dir(hostPath) + var imports []loadedFile + + matches := importRegex.FindAllStringSubmatch(content, -1) + for _, match := range matches { + rawPath := match[1] + + // Only treat as import if path contains "/" or ends with an allowed extension. + // This avoids false positives on email addresses and social mentions. + if !strings.Contains(rawPath, "/") && !allowedImportExts[filepath.Ext(rawPath)] { + continue + } + + // If budget is exhausted, skip further imports. + if l.maxBytes > 0 && l.totalBytes > l.maxBytes { + break + } + + importPath := rawPath + if !filepath.IsAbs(importPath) { + importPath = filepath.Join(dir, importPath) + } + + if seen[importPath] { + continue + } + + files, err := l.loadFile(ctx, importPath, depth+1, visited, seen) + if err != nil { + return nil, fmt.Errorf("failed to import %q from %q: %w", rawPath, hostPath, err) + } + + imports = append(imports, files...) + } + + return imports, nil +} + +type loadedFile struct { + path string + content string +} + +const formatHeaderEn = ` +As you answer the user's questions, you can use the following context: +Codebase and user instructions are shown below. Be sure to adhere to these instructions. IMPORTANT: These instructions OVERRIDE any default behavior and you MUST follow them exactly as written. +` + +const formatHeaderCn = ` +在回答用户问题时,你可以使用以下上下文: +代码库和用户指令如下。请务必遵守这些指令。重要提示:这些指令会覆盖任何默认行为,你必须严格按照要求执行。 +` + +const formatFileHeaderEn = "\nContents of " + +const formatFileHeaderCn = "\n文件内容:" + +const formatFileLabelEn = " (instructions):\n\n" + +const formatFileLabelCn = "(指令):\n\n" + +const formatFooterEn = `IMPORTANT: this context may or may not be relevant to your tasks. You should not respond to this context unless it is highly relevant to your task. 
+` + +const formatFooterCn = `重要提示:此上下文可能与你的任务相关,也可能不相关。除非此上下文与你的任务高度相关,否则不要响应此上下文。 +` + +func formatContent(files []loadedFile) string { + if len(files) == 0 { + return "" + } + + header := internal.SelectPrompt(internal.I18nPrompts{ + English: formatHeaderEn, + Chinese: formatHeaderCn, + }) + fileHeader := internal.SelectPrompt(internal.I18nPrompts{ + English: formatFileHeaderEn, + Chinese: formatFileHeaderCn, + }) + fileLabel := internal.SelectPrompt(internal.I18nPrompts{ + English: formatFileLabelEn, + Chinese: formatFileLabelCn, + }) + footer := internal.SelectPrompt(internal.I18nPrompts{ + English: formatFooterEn, + Chinese: formatFooterCn, + }) + + var sb strings.Builder + sb.WriteString(header) + + for _, f := range files { + sb.WriteString(fileHeader) + sb.WriteString(f.path) + sb.WriteString(fileLabel) + sb.WriteString(f.content) + sb.WriteString("\n") + } + sb.WriteString(footer) + return sb.String() +} diff --git a/adk/middlewares/dynamictool/toolsearch/prompt.go b/adk/middlewares/dynamictool/toolsearch/prompt.go new file mode 100644 index 000000000..5aaa56ad1 --- /dev/null +++ b/adk/middlewares/dynamictool/toolsearch/prompt.go @@ -0,0 +1,162 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package toolsearch + +const ( + toolDescription = `Search for or select deferred tools to make them available for use. 
+ +MANDATORY PREREQUISITE - THIS IS A HARD REQUIREMENT + +You MUST use this tool to load deferred tools BEFORE calling them directly. + +This is a BLOCKING REQUIREMENT - deferred tools are NOT available until you load them using this tool. Look for messages in the conversation for the list of tools you can discover. Both query modes (keyword search and direct selection) load the returned tools — once a tool appears in the results, it is immediately available to call. + +Why this is non-negotiable: +- Deferred tools are not loaded until discovered via this tool +- Calling a deferred tool without first loading it will fail +Query modes: + +1. Keyword search - Use keywords when you're unsure which tool to use or need to discover multiple tools at once: + - "list directory" - find tools for listing directories + - "notebook jupyter" - find notebook editing tools + - "slack message" - find slack messaging tools + - Returns up to 5 matching tools ranked by relevance + - All returned tools are immediately available to call — no further selection step needed +2. Direct selection - Use select: when you know the exact tool name: + - "select:mcp__slack__read_channel" + - "select:NotebookEdit" + - "select:Read,Edit,Grep" - load multiple tools at once with comma separation + - Returns the named tool(s) if they exist +IMPORTANT: Both modes load tools equally. Do NOT follow up a keyword search with select: calls for tools already returned — they are already loaded. + +3. Required keyword - Prefix with + to require a match: + - "+linear create issue" - only tools from "linear", ranked by "create"/"issue" + - "+slack send" - only "slack" tools, ranked by "send" + - Useful when you know the service name but not the exact tool +CORRECT Usage Patterns: + + +User: I need to work with slack somehow +Assistant: Let me search for slack tools. +[Calls tool_search with query: "slack"] +Assistant: Found several options including mcp__slack__read_channel. 
+[Calls mcp__slack__read_channel directly — it was loaded by the keyword search] + + + +User: Edit the Jupyter notebook +Assistant: Let me load the notebook editing tool. +[Calls tool_search with query: "select:NotebookEdit"] +[Calls NotebookEdit] + + + +User: List files in the src directory +Assistant: I can see mcp__filesystem__list_directory in the available tools. Let me select it. +[Calls tool_search with query: "select:mcp__filesystem__list_directory"] +[Calls the tool] + + +INCORRECT Usage Patterns - NEVER DO THESE: + + +User: Read my slack messages +Assistant: [Directly calls mcp__slack__read_channel without loading it first] +WRONG - You must load the tool FIRST using this tool + + + +Assistant: [Calls tool_search with query: "slack", gets back mcp__slack__read_channel] +Assistant: [Calls tool_search with query: "select:mcp__slack__read_channel"] +WRONG - The keyword search already loaded the tool. The select call is redundant. +` + + toolDescriptionChinese = `搜索或选择延迟加载(deferred)的工具,使其可供调用。 + +强制前提条件(MANDATORY PREREQUISITE)— 硬性要求 + +在直接调用任何 延迟加载工具(deferred tools) 之前,你 必须先使用此工具将其加载。 + +这是一个 阻塞性要求(BLOCKING REQUIREMENT) — 延迟加载工具在被加载之前是 不可用的。你需要在对话中查找 消息,以获取可以发现的工具列表。无论使用哪种查询方式(关键字搜索 或 直接选择),只要工具出现在返回结果中,它们就会自动被加载并立即可调用。 + +为什么这是不可协商的规则: +- 延迟加载工具在被发现之前不会被加载 +- 如果你在加载之前直接调用延迟工具,调用将会失败 +查询模式: + +1. 关键字搜索(Keyword search)- 当你不确定具体需要哪个工具,或希望一次发现多个工具时使用关键字搜索: +- "list directory" — 查找用于列出目录的工具 +- "notebook jupyter" — 查找 Jupyter Notebook 编辑工具 +- "slack message" — 查找 Slack 消息相关工具 +- 返回最多 5 个最相关的工具 +- 所有返回的工具都会立即加载并可直接调用 — 不需要额外执行 select 步骤 + +2. 直接选择(Direct selection)— 当你已经知道工具的确切名称时使用 select:: +- "select:mcp__slack__read_channel" +- "select:NotebookEdit" +- "select:Read,Edit,Grep" — 一次加载多个工具 +- 如果工具存在,将被加载并返回 +重要说明:两种模式的加载效果完全相同。不要在关键词搜索之后,对返回的工具再次进行 select: 选择 — 它们已经加载好了。 + +3. 
必须匹配关键字(Required keyword)— 在关键字前添加 + 可以 强制匹配特定服务或来源。 +- "+linear create issue" — 仅返回名字中包含 "linear" 的工具,按 "create" / "issue" 排序 +- "+slack send" — 仅返回名字中包含 "slack" 的工具,按 "send" 排序 +- 适用于你知道服务名称但不知道具体工具名称 + +正确使用示例: + + +User: 我需要处理 Slack 相关的事情 +Assistant: 让我搜索 Slack 工具。 +[调用 tool_search,query: "slack"] +Assistant: 找到多个选项,包括 mcp__slack__read_channel。 +[直接调用 mcp__slack__read_channel — 关键字搜索已经加载了该工具] + + + +User: 编辑这个 Jupyter Notebook +Assistant: 让我加载 Notebook 编辑工具。 +[调用 tool_search,query: "select:NotebookEdit"] +[调用 NotebookEdit] + + + +User: 列出 src 目录中的文件 +Assistant: 我看到可用工具中有 mcp__filesystem__list_directory,让我加载它。 +[调用 tool_search,query: "select:mcp__filesystem__list_directory"] +[调用该工具] + + +错误用法(严禁) + + +User: 读取我的 Slack 消息 +Assistant: [不调用 tool_search 工具加载,直接调用 mcp__slack__read_channel] +错误 — 在调用工具之前没有先使用 tool_search 加载该工具。 + + + +Assistant:[调用 tool_search,query: "slack",返回 mcp__slack__read_channel] +Assistant:[再次调用 tool_search,query: "select:mcp__slack__read_channel"] +错误 — 关键字搜索 已经加载了该工具,再次 select 是冗余操作。` + + systemReminderTpl = ` +{{- range .Tools }} +{{ . }} +{{- end }} +` +) diff --git a/adk/middlewares/dynamictool/toolsearch/toolsearch.go b/adk/middlewares/dynamictool/toolsearch/toolsearch.go index 4ee4c216b..55883e914 100644 --- a/adk/middlewares/dynamictool/toolsearch/toolsearch.go +++ b/adk/middlewares/dynamictool/toolsearch/toolsearch.go @@ -18,12 +18,17 @@ package toolsearch import ( + "bytes" "context" "encoding/json" "fmt" - "regexp" + "sort" + "strings" + "text/template" + "unicode" "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/adk/internal" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/schema" @@ -33,6 +38,16 @@ import ( type Config struct { // DynamicTools is a list of tools that can be dynamically searched and loaded by the agent. DynamicTools []tool.BaseTool + + // UseModelToolSearch indicates whether the ChatModel natively supports tool search. 
+ // + // When true, the middleware delegates tool search to the model's native capability. + // + // When false (default), the middleware manages tool visibility by filtering the tool list + // based on tool_search results before each model call. Note that this approach may + // invalidate the model's KV-cache (as the tool list changes between calls), and effectiveness + // depends on the model's ability to work with a dynamically changing tool set. + UseModelToolSearch bool } // New constructs and returns the tool search middleware. @@ -41,7 +56,7 @@ type Config struct { // Instead of passing all tools to the model at once (which can overwhelm context limits), // this middleware: // -// 1. Adds a "tool_search" meta-tool that accepts a regex pattern to search tool names +// 1. Adds a "tool_search" meta-tool that accepts keyword queries to search tools // 2. Initially hides all dynamic tools from the model's tool list // 3. When the model calls tool_search, matching tools become available for subsequent calls // @@ -62,14 +77,55 @@ func New(ctx context.Context, config *Config) (adk.ChatModelAgentMiddleware, err return nil, fmt.Errorf("tools is required") } + tpl, err := template.New("").Parse(systemReminderTpl) + if err != nil { + return nil, err + } + + dynamicToolInfos := make([]*schema.ToolInfo, 0, len(config.DynamicTools)) + mapOfDynamicTools := make(map[string]*schema.ToolInfo, len(config.DynamicTools)) + toolNames := make([]string, 0, len(config.DynamicTools)) + for _, t := range config.DynamicTools { + info, infoErr := t.Info(ctx) + if infoErr != nil { + return nil, fmt.Errorf("failed to get dynamic tool info: %w", infoErr) + } + + if _, ok := mapOfDynamicTools[info.Name]; ok { + return nil, fmt.Errorf("duplicate dynamic tool name: %s", info.Name) + } + + toolNames = append(toolNames, info.Name) + mapOfDynamicTools[info.Name] = info + dynamicToolInfos = append(dynamicToolInfos, info) + } + + buf := &bytes.Buffer{} + err = tpl.Execute(buf, 
systemReminder{Tools: toolNames}) + if err != nil { + return nil, fmt.Errorf("failed to format system reminder template: %w", err) + } + return &middleware{ - dynamicTools: config.DynamicTools, + dynamicTools: config.DynamicTools, + mapOfDynamicTools: mapOfDynamicTools, + dynamicToolInfos: dynamicToolInfos, + useModelToolSearch: config.UseModelToolSearch, + sr: buf.String(), }, nil } +type systemReminder struct { + Tools []string +} + type middleware struct { adk.BaseChatModelAgentMiddleware - dynamicTools []tool.BaseTool + dynamicTools []tool.BaseTool + mapOfDynamicTools map[string]*schema.ToolInfo + dynamicToolInfos []*schema.ToolInfo + useModelToolSearch bool + sr string } func (m *middleware) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgentContext) (context.Context, *adk.ChatModelAgentContext, error) { @@ -78,123 +134,384 @@ func (m *middleware) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgent } nRunCtx := *runCtx - toolNames, err := getToolNames(ctx, m.dynamicTools) - if err != nil { - return ctx, nil, fmt.Errorf("failed to get tool names: %w", err) - } - nRunCtx.Tools = append(nRunCtx.Tools, newToolSearchTool(toolNames)) + nRunCtx.Tools = make([]tool.BaseTool, len(runCtx.Tools), len(runCtx.Tools)+1+len(m.dynamicTools)) + copy(nRunCtx.Tools, runCtx.Tools) + nRunCtx.Tools = append(nRunCtx.Tools, newToolSearchTool(m.mapOfDynamicTools, m.useModelToolSearch)) nRunCtx.Tools = append(nRunCtx.Tools, m.dynamicTools...) 
return ctx, &nRunCtx, nil } func (m *middleware) WrapModel(_ context.Context, cm model.BaseChatModel, mc *adk.ModelContext) (model.BaseChatModel, error) { - return &wrapper{allTools: mc.Tools, cm: cm, dynamicTools: m.dynamicTools}, nil + return &wrapper{ + allTools: mc.Tools, + cm: cm, + dynamicToolInfos: m.dynamicToolInfos, + reminder: m.sr, + useModelToolSearch: m.useModelToolSearch, + }, nil } type wrapper struct { - allTools []*schema.ToolInfo - dynamicTools []tool.BaseTool + allTools []*schema.ToolInfo + dynamicToolInfos []*schema.ToolInfo + reminder string + useModelToolSearch bool cm model.BaseChatModel } func (w *wrapper) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { - tools, err := removeTools(ctx, w.allTools, w.dynamicTools, input) + toolsOpts, err := w.resolveTools(ctx, input) if err != nil { return nil, fmt.Errorf("failed to load dynamic tools: %w", err) } - return w.cm.Generate(ctx, input, append(opts, model.WithTools(tools))...) + return w.cm.Generate(ctx, w.insertReminder(input), append(opts, toolsOpts...)...) } func (w *wrapper) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { - tools, err := removeTools(ctx, w.allTools, w.dynamicTools, input) + toolsOpts, err := w.resolveTools(ctx, input) if err != nil { return nil, fmt.Errorf("failed to load dynamic tools: %w", err) } - return w.cm.Stream(ctx, input, append(opts, model.WithTools(tools))...) + return w.cm.Stream(ctx, w.insertReminder(input), append(opts, toolsOpts...)...) +} + +func (w *wrapper) resolveTools(ctx context.Context, input []*schema.Message) ([]model.Option, error) { + if w.useModelToolSearch { + // Model handles tool search natively; remove all dynamic tools from the list. 
+ return calculateTools(ctx, w.allTools, w.dynamicToolInfos, nil, w.useModelToolSearch) + } + return calculateTools(ctx, w.allTools, w.dynamicToolInfos, input, w.useModelToolSearch) +} + +func (w *wrapper) insertReminder(input []*schema.Message) []*schema.Message { + inserted := false + ret := make([]*schema.Message, 0, len(input)+1) + for _, m := range input { + if m.Role != schema.System && !inserted { + inserted = true + ret = append(ret, schema.UserMessage(w.reminder)) + } + ret = append(ret, m) + } + if !inserted { + ret = append(ret, schema.UserMessage(w.reminder)) + } + return ret +} + +func newToolSearchTool(tools map[string]*schema.ToolInfo, useModelToolSearch bool) tool.BaseTool { + if useModelToolSearch { + return &modelToolSearchTool{tools: tools} + } + return &toolSearchTool{tools: tools} } -func newToolSearchTool(toolNames []string) *toolSearchTool { - return &toolSearchTool{toolNames: toolNames} +type toolSearchArgs struct { + Query string `json:"query"` + MaxResults *int `json:"max_results,omitempty"` +} + +type toolSearchResult struct { + Matches []string `json:"matches"` } type toolSearchTool struct { - toolNames []string + tools map[string]*schema.ToolInfo +} + +func (t *toolSearchTool) Info(ctx context.Context) (*schema.ToolInfo, error) { + return getToolSearchToolInfo(), nil +} + +func (t *toolSearchTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { + matches, err := search(argumentsInJSON, t.tools) + if err != nil { + return "", err + } + result := &toolSearchResult{} + for _, m := range matches { + result.Matches = append(result.Matches, m.Name) + } + b, err := json.Marshal(result) + if err != nil { + return "", fmt.Errorf("failed to marshal tool search result: %w", err) + } + return string(b), nil +} + +type modelToolSearchTool struct { + tools map[string]*schema.ToolInfo +} + +func (t *modelToolSearchTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return 
getToolSearchToolInfo(), nil +} + +func (t *modelToolSearchTool) InvokableRun(_ context.Context, argumentsInJSON *schema.ToolArgument, _ ...tool.Option) (*schema.ToolResult, error) { + ret, err := search(argumentsInJSON.Text, t.tools) + if err != nil { + return nil, err + } + + return &schema.ToolResult{Parts: []schema.ToolOutputPart{ + { + Type: schema.ToolPartTypeToolSearchResult, + ToolSearchResult: &schema.ToolSearchResult{ + Tools: ret, + }, + }, + }}, nil } const ( toolSearchToolName = "tool_search" + defaultMaxResults = 5 ) -func (t *toolSearchTool) Info(ctx context.Context) (*schema.ToolInfo, error) { +func getToolSearchToolInfo() *schema.ToolInfo { return &schema.ToolInfo{ - Name: "tool_search", - Desc: "Search for tools using a regex pattern that matches tool names. Returns a list of matching tool names. Use this when you need a tool but don't have it available yet.", + Name: toolSearchToolName, + Desc: internal.SelectPrompt(internal.I18nPrompts{ + English: toolDescription, + Chinese: toolDescriptionChinese, + }), ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ - "regex_pattern": { + "query": { Type: schema.String, - Desc: "A regex pattern to match tool names against.", + Desc: "Query to find deferred tools. 
Use \"select:\" for direct selection, or keywords to search.", Required: true, }, + "max_results": { + Type: schema.Integer, + Desc: "Maximum number of results to return (default: 5)", + Required: false, + }, }), - }, nil + } } -type toolSearchArgs struct { - RegexPattern string `json:"regex_pattern"` +func search(argumentsInJSON string, tools map[string]*schema.ToolInfo) ([]*schema.ToolInfo, error) { + var args toolSearchArgs + if err := json.Unmarshal([]byte(argumentsInJSON), &args); err != nil { + return nil, fmt.Errorf("failed to unmarshal tool search arguments: %w", err) + } + + query := strings.TrimSpace(args.Query) + if query == "" { + return nil, fmt.Errorf("query is required") + } + + maxResults := defaultMaxResults + if args.MaxResults != nil && *args.MaxResults > 0 { + maxResults = *args.MaxResults + } + + var matches []string + + // Direct selection mode: select:tool1,tool2 + // max_results is intentionally not applied here because the model has + // already specified the exact tools it wants by name. 
+ if strings.HasPrefix(query, "select:") { + names := strings.Split(strings.TrimPrefix(query, "select:"), ",") + toolSet := make(map[string]bool, len(tools)) + for name := range tools { + toolSet[name] = true + } + for _, name := range names { + name = strings.TrimSpace(name) + if name != "" && toolSet[name] { + matches = append(matches, name) + } + } + } else { + matches = keywordSearch(query, maxResults, tools) + } + + ret := make([]*schema.ToolInfo, 0, len(matches)) + for _, name := range matches { + ti, ok := tools[name] + if !ok { + continue + } + ret = append(ret, ti) + } + return ret, nil } -type toolSearchResult struct { - SelectedTools []string `json:"selectedTools"` +func intMax(a, b int) int { + if a > b { + return a + } + return b } -func (t *toolSearchTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { - var args toolSearchArgs - if err := json.Unmarshal([]byte(argumentsInJSON), &args); err != nil { - return "", fmt.Errorf("failed to unmarshal tool search arguments: %w", err) +func intMin(a, b int) int { + if a < b { + return a } + return b +} + +// scoredTool pairs a tool name with its search score. +type scoredTool struct { + name string + score int +} - if args.RegexPattern == "" { - return "", fmt.Errorf("regex_pattern is required") +// keywordSearch scores all tools against the query keywords and returns the top N. 
+func keywordSearch(query string, maxResults int, tools map[string]*schema.ToolInfo) []string { + keywords := parseKeywords(query) + if len(keywords) == 0 { + return nil } - re, err := regexp.Compile(args.RegexPattern) - if err != nil { - return "", fmt.Errorf("invalid regex pattern: %w", err) + var scored []scoredTool + + for name, tm := range tools { + nameParts := splitToolName(name) + nameLower := strings.ToLower(name) + descLower := strings.ToLower(tm.Desc) + + totalScore := 0 + allRequiredFound := true + + for _, kw := range keywords { + kwLower := strings.ToLower(kw.word) + kwScore := 0 + + // Score against name parts + for _, part := range nameParts { + partLower := strings.ToLower(part) + if partLower == kwLower { + kwScore = intMax(kwScore, 10) + } else if strings.Contains(partLower, kwLower) { + kwScore = intMax(kwScore, 5) + } + } + + // Score against full name + if strings.Contains(nameLower, kwLower) { + kwScore = intMax(kwScore, 3) + } + + // Score against description (substring match) + if descLower != "" && strings.Contains(descLower, kwLower) { + kwScore = intMax(kwScore, 2) + } + + if kw.required && kwScore == 0 { + allRequiredFound = false + break + } + + totalScore += kwScore + } + + if !allRequiredFound { + continue + } + + if totalScore > 0 { + scored = append(scored, scoredTool{name: name, score: totalScore}) + } } - var matchedTools []string - for _, name := range t.toolNames { - if re.MatchString(name) { - matchedTools = append(matchedTools, name) + // Sort by score descending, then by name for stability + sort.Slice(scored, func(i, j int) bool { + if scored[i].score != scored[j].score { + return scored[i].score > scored[j].score } + return scored[i].name < scored[j].name + }) + + results := make([]string, 0, intMin(maxResults, len(scored))) + for i := 0; i < len(scored) && i < maxResults; i++ { + results = append(results, scored[i].name) } + return results +} + +// keyword represents a parsed search keyword. 
+type keyword struct { + word string + required bool +} - result := toolSearchResult{ - SelectedTools: matchedTools, +// parseKeywords splits a query string into keywords, handling the '+' required prefix. +func parseKeywords(query string) (keywords []keyword) { + parts := strings.Fields(query) + for _, p := range parts { + if strings.HasPrefix(p, "+") { + word := strings.TrimPrefix(p, "+") + if word != "" { + keywords = append(keywords, keyword{word: word, required: true}) + } + } else if p != "" { + keywords = append(keywords, keyword{word: p, required: false}) + } } + return +} - output, err := json.Marshal(result) - if err != nil { - return "", fmt.Errorf("failed to marshal result: %w", err) +// splitToolName splits a tool name into parts by underscores, double underscores (MCP separator), +// and camelCase boundaries. +func splitToolName(name string) []string { + // First split by double underscore (MCP server__tool separator) + segments := strings.Split(name, "__") + + var parts []string + for _, seg := range segments { + // Split each segment by single underscore + underscoreParts := strings.Split(seg, "_") + for _, up := range underscoreParts { + if up == "" { + continue + } + // Further split by camelCase + camelParts := splitCamelCase(up) + parts = append(parts, camelParts...) + } + } + return parts +} + +// splitCamelCase splits a camelCase or PascalCase string into its constituent words. 
+func splitCamelCase(s string) []string { + if s == "" { + return nil } - return string(output), nil + var parts []string + runes := []rune(s) + start := 0 + + for i := 1; i < len(runes); i++ { + if unicode.IsUpper(runes[i]) { + if unicode.IsLower(runes[i-1]) { + parts = append(parts, string(runes[start:i])) + start = i + } else if i+1 < len(runes) && unicode.IsLower(runes[i+1]) { + parts = append(parts, string(runes[start:i])) + start = i + } + } + } + parts = append(parts, string(runes[start:])) + + return parts } -func getToolNames(ctx context.Context, tools []tool.BaseTool) ([]string, error) { +// getToolNames extracts just tool names from a slice of BaseTools (used by calculateTools). +func getToolNames(tools []*schema.ToolInfo) []string { ret := make([]string, 0, len(tools)) for _, t := range tools { - info, err := t.Info(ctx) - if err != nil { - return nil, err - } - ret = append(ret, info.Name) + ret = append(ret, t.Name) } - return ret, nil + return ret } -func extractSelectedTools(ctx context.Context, messages []*schema.Message) ([]string, error) { +func extractSelectedTools(_ context.Context, messages []*schema.Message) ([]string, error) { var selectedTools []string for _, message := range messages { if message.Role != schema.Tool || message.ToolName != toolSearchToolName { @@ -206,7 +523,7 @@ func extractSelectedTools(ctx context.Context, messages []*schema.Message) ([]st if err != nil { return nil, fmt.Errorf("failed to unmarshal tool search tool result: %w", err) } - selectedTools = append(selectedTools, result.SelectedTools...) + selectedTools = append(selectedTools, result.Matches...) 
} return selectedTools, nil } @@ -226,22 +543,30 @@ func invertSelect[T comparable](all []T, selected []T) map[T]struct{} { return result } -func removeTools(ctx context.Context, all []*schema.ToolInfo, dynamicTools []tool.BaseTool, messages []*schema.Message) ([]*schema.ToolInfo, error) { - selectedToolNames, err := extractSelectedTools(ctx, messages) - if err != nil { - return nil, err +func calculateTools(ctx context.Context, all []*schema.ToolInfo, dynamicTools []*schema.ToolInfo, messages []*schema.Message, useModelToolSearch bool) ([]model.Option, error) { + var err error + var ret []model.Option + var selectedToolNames []string + if !useModelToolSearch { + selectedToolNames, err = extractSelectedTools(ctx, messages) + if err != nil { + return nil, err + } } - dynamicToolNames, err := getToolNames(ctx, dynamicTools) - if err != nil { - return nil, err + dynamicToolNames := getToolNames(dynamicTools) + if useModelToolSearch { + // if useModelToolSearch, register tool search tool by WithToolSearchTool + dynamicToolNames = append(dynamicToolNames, toolSearchToolName) + ret = append(ret, model.WithToolSearchTool(getToolSearchToolInfo())) } removeMap := invertSelect(dynamicToolNames, selectedToolNames) - ret := make([]*schema.ToolInfo, 0, len(all)-len(dynamicTools)) + tools := make([]*schema.ToolInfo, 0, len(all)-len(dynamicTools)) for _, info := range all { if _, ok := removeMap[info.Name]; ok { continue } - ret = append(ret, info) + tools = append(tools, info) } + ret = append(ret, model.WithTools(tools)) return ret, nil } diff --git a/adk/middlewares/dynamictool/toolsearch/toolsearch_test.go b/adk/middlewares/dynamictool/toolsearch/toolsearch_test.go index 4b249b9be..20cee35da 100644 --- a/adk/middlewares/dynamictool/toolsearch/toolsearch_test.go +++ b/adk/middlewares/dynamictool/toolsearch/toolsearch_test.go @@ -19,6 +19,10 @@ package toolsearch import ( "context" "encoding/json" + "fmt" + "sort" + "strings" + "sync" "testing" 
"github.com/stretchr/testify/assert" @@ -27,464 +31,569 @@ import ( "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" ) -type mockTool struct { - name string - desc string -} +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- -func (m *mockTool) Info(ctx context.Context) (*schema.ToolInfo, error) { - return &schema.ToolInfo{ - Name: m.name, - Desc: m.desc, - }, nil +func makeToolMap(tools ...*schema.ToolInfo) map[string]*schema.ToolInfo { + m := make(map[string]*schema.ToolInfo, len(tools)) + for _, t := range tools { + m[t.Name] = t + } + return m } -func newMockTool(name, desc string) *mockTool { - return &mockTool{name: name, desc: desc} +func ti(name, desc string) *schema.ToolInfo { + return &schema.ToolInfo{Name: name, Desc: desc} } -func TestNew(t *testing.T) { - ctx := context.Background() +func toolNames(infos []*schema.ToolInfo) []string { + names := make([]string, len(infos)) + for i, info := range infos { + names[i] = info.Name + } + sort.Strings(names) + return names +} - t.Run("nil config returns error", func(t *testing.T) { - m, err := New(ctx, nil) - assert.Nil(t, m) - assert.Error(t, err) - assert.Contains(t, err.Error(), "config is required") - }) +func searchJSON(query string, maxResults *int) string { + args := toolSearchArgs{Query: query, MaxResults: maxResults} + b, _ := json.Marshal(args) + return string(b) +} - t.Run("empty tools returns error", func(t *testing.T) { - m, err := New(ctx, &Config{DynamicTools: []tool.BaseTool{}}) - assert.Nil(t, m) - assert.Error(t, err) - assert.Contains(t, err.Error(), "tools is required") - }) +func intPtr(v int) *int { return &v } + +// --------------------------------------------------------------------------- +// TestSearch — unit tests for the 
search() function +// --------------------------------------------------------------------------- + +func TestSearch(t *testing.T) { + tools := makeToolMap( + ti("get_weather", "Get current weather for a city"), + ti("search_flights", "Search available flights"), + ti("mcp__slack__send_message", "Send a message to Slack channel"), + ti("mcp__slack__read_channel", "Read messages from Slack channel"), + ti("create_calendar_event", "Create a new calendar event"), + ti("NotebookEdit", "Edit Jupyter notebook cells"), + ) + + tests := []struct { + name string + json string + wantNames []string // sorted; nil means expect empty + wantErr bool + }{ + { + name: "keyword exact name part match", + json: searchJSON("weather", nil), + wantNames: []string{"get_weather"}, + }, + { + name: "keyword matches multiple tools", + json: searchJSON("slack", nil), + wantNames: []string{"mcp__slack__read_channel", "mcp__slack__send_message"}, + }, + { + name: "multi-word ranking - send_message ranked first", + json: searchJSON("send message", nil), + wantNames: []string{"mcp__slack__send_message"}, // check first element only + }, + { + name: "required keyword filters to slack only", + json: searchJSON("+slack send", nil), + wantNames: []string{"mcp__slack__read_channel", "mcp__slack__send_message"}, + }, + { + name: "required keyword no match", + json: searchJSON("+github send", nil), + wantNames: nil, + }, + { + name: "direct select single", + json: searchJSON("select:get_weather", nil), + wantNames: []string{"get_weather"}, + }, + { + name: "direct select multiple", + json: searchJSON("select:get_weather,NotebookEdit", nil), + wantNames: []string{"NotebookEdit", "get_weather"}, + }, + { + name: "direct select nonexistent", + json: searchJSON("select:nonexistent", nil), + wantNames: nil, + }, + { + name: "max_results limits output", + json: searchJSON("slack", intPtr(1)), + wantNames: []string{"mcp__slack__read_channel"}, // just check length below + }, + { + name: "camelCase split 
matches notebook", + json: searchJSON("notebook", nil), + wantNames: []string{"NotebookEdit"}, + }, + { + name: "empty query returns error", + json: searchJSON("", nil), + wantErr: true, + }, + { + name: "description match - jupyter", + json: searchJSON("jupyter", nil), + wantNames: []string{"NotebookEdit"}, + }, + } - t.Run("valid config returns middleware", func(t *testing.T) { - tools := []tool.BaseTool{ - newMockTool("tool1", "desc1"), - newMockTool("tool2", "desc2"), - } - m, err := New(ctx, &Config{DynamicTools: tools}) - assert.NoError(t, err) - assert.NotNil(t, m) - }) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := search(tt.json, tools) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + + // special case: max_results limit + if tt.name == "max_results limits output" { + assert.Len(t, got, 1) + return + } + + // special case: ranking — just check first element + if tt.name == "multi-word ranking - send_message ranked first" { + require.NotEmpty(t, got) + assert.Equal(t, "mcp__slack__send_message", got[0].Name) + return + } + + gotNames := toolNames(got) + if tt.wantNames == nil { + assert.Empty(t, gotNames) + } else { + assert.Equal(t, tt.wantNames, gotNames) + } + }) + } } -func TestMiddleware_BeforeAgent(t *testing.T) { - ctx := context.Background() +// --------------------------------------------------------------------------- +// TestMiddlewareFlow — integration test for UseModelToolSearch=false +// --------------------------------------------------------------------------- - t.Run("nil runCtx returns nil", func(t *testing.T) { - tools := []tool.BaseTool{newMockTool("tool1", "desc1")} - m, err := New(ctx, &Config{DynamicTools: tools}) - require.NoError(t, err) +// simpleTool is a minimal InvokableTool for testing. 
+type simpleTool struct { + name string + desc string + called bool + mu sync.Mutex +} - newCtx, newRunCtx, err := m.BeforeAgent(ctx, nil) - assert.NoError(t, err) - assert.Equal(t, ctx, newCtx) - assert.Nil(t, newRunCtx) - }) +func (s *simpleTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: s.name, + Desc: s.desc, + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Type: schema.String, Desc: "input", Required: true}, + }), + }, nil +} - t.Run("adds tool_search and dynamic tools", func(t *testing.T) { - tools := []tool.BaseTool{ - newMockTool("tool1", "desc1"), - newMockTool("tool2", "desc2"), - } - m, err := New(ctx, &Config{DynamicTools: tools}) - require.NoError(t, err) +func (s *simpleTool) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) { + s.mu.Lock() + s.called = true + s.mu.Unlock() + return `{"result":"ok"}`, nil +} - middleware := m.(*middleware) - runCtx := &adk.ChatModelAgentContext{ - Tools: []tool.BaseTool{}, - } +func (s *simpleTool) wasCalled() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.called +} - _, newRunCtx, err := middleware.BeforeAgent(ctx, runCtx) - assert.NoError(t, err) - assert.NotNil(t, newRunCtx) - assert.Len(t, newRunCtx.Tools, 3) - }) +// mockChatModel implements model.ToolCallingChatModel. +// It drives a 3-turn conversation: +// +// Turn 1: call tool_search with select:dynamic_tool_a +// Turn 2: call dynamic_tool_a +// Turn 3: return final text +type mockChatModel struct { + mu sync.Mutex + generateCall int + // toolsPerCall records the tool names passed via model.WithTools for each Generate call. 
+ toolsPerCall [][]string } -func TestToolSearchTool_Info(t *testing.T) { - ctx := context.Background() - toolNames := []string{"tool1", "tool2", "tool3"} - tst := newToolSearchTool(toolNames) - - info, err := tst.Info(ctx) - assert.NoError(t, err) - assert.Equal(t, "tool_search", info.Name) - assert.Contains(t, info.Desc, "regex pattern") - assert.NotNil(t, info.ParamsOneOf) +func (m *mockChatModel) Generate(_ context.Context, _ []*schema.Message, opts ...model.Option) (*schema.Message, error) { + options := model.GetCommonOptions(nil, opts...) + var names []string + for _, t := range options.Tools { + names = append(names, t.Name) + } + sort.Strings(names) + + m.mu.Lock() + m.generateCall++ + call := m.generateCall + m.toolsPerCall = append(m.toolsPerCall, names) + m.mu.Unlock() + + switch call { + case 1: + // Ask tool_search to select dynamic_tool_a + return schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "tc1", + Function: schema.FunctionCall{ + Name: toolSearchToolName, + Arguments: `{"query":"select:dynamic_tool_a","max_results":5}`, + }, + }, + }), nil + case 2: + // Call dynamic_tool_a + return schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "tc2", + Function: schema.FunctionCall{ + Name: "dynamic_tool_a", + Arguments: `{"input":"hello"}`, + }, + }, + }), nil + default: + // Final response + return schema.AssistantMessage("done", nil), nil + } } -func TestToolSearchTool_InvokableRun(t *testing.T) { - ctx := context.Background() - toolNames := []string{"get_weather", "get_time", "search_web", "calculate_sum"} - tst := newToolSearchTool(toolNames) +func (m *mockChatModel) Stream(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, fmt.Errorf("not implemented") +} - t.Run("empty regex pattern returns error", func(t *testing.T) { - args := `{"regex_pattern": ""}` - result, err := tst.InvokableRun(ctx, args) - assert.Error(t, err) - assert.Contains(t, err.Error(), 
"regex_pattern is required") - assert.Empty(t, result) - }) +func (m *mockChatModel) WithTools(_ []*schema.ToolInfo) (model.ToolCallingChatModel, error) { + return m, nil +} - t.Run("invalid json returns error", func(t *testing.T) { - args := `{invalid json}` - result, err := tst.InvokableRun(ctx, args) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to unmarshal") - assert.Empty(t, result) - }) +func (m *mockChatModel) getToolsPerCall() [][]string { + m.mu.Lock() + defer m.mu.Unlock() + ret := make([][]string, len(m.toolsPerCall)) + copy(ret, m.toolsPerCall) + return ret +} - t.Run("invalid regex returns error", func(t *testing.T) { - args := `{"regex_pattern": "[invalid"}` - result, err := tst.InvokableRun(ctx, args) - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid regex pattern") - assert.Empty(t, result) - }) +func TestMiddlewareFlow(t *testing.T) { + ctx := context.Background() - t.Run("matches tools with prefix pattern", func(t *testing.T) { - args := `{"regex_pattern": "^get_"}` - result, err := tst.InvokableRun(ctx, args) - assert.NoError(t, err) + dynamicA := &simpleTool{name: "dynamic_tool_a", desc: "Dynamic tool A"} + dynamicB := &simpleTool{name: "dynamic_tool_b", desc: "Dynamic tool B"} + staticTool := &simpleTool{name: "static_tool", desc: "Static tool"} - var res toolSearchResult - err = json.Unmarshal([]byte(result), &res) - assert.NoError(t, err) - assert.ElementsMatch(t, []string{"get_weather", "get_time"}, res.SelectedTools) + mw, err := New(ctx, &Config{ + DynamicTools: []tool.BaseTool{dynamicA, dynamicB}, + UseModelToolSearch: false, }) - - t.Run("matches tools with suffix pattern", func(t *testing.T) { - args := `{"regex_pattern": "_sum$"}` - result, err := tst.InvokableRun(ctx, args) - assert.NoError(t, err) - - var res toolSearchResult - err = json.Unmarshal([]byte(result), &res) - assert.NoError(t, err) - assert.ElementsMatch(t, []string{"calculate_sum"}, res.SelectedTools) + require.NoError(t, err) + + cm 
:= &mockChatModel{} + + agent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ + Name: "test_agent", + Description: "test", + Instruction: "you are a test agent", + Model: cm, + ToolsConfig: adk.ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{staticTool}, + }, + }, + Handlers: []adk.ChatModelAgentMiddleware{mw}, }) + require.NoError(t, err) - t.Run("matches all tools with wildcard", func(t *testing.T) { - args := `{"regex_pattern": ".*"}` - result, err := tst.InvokableRun(ctx, args) - assert.NoError(t, err) + input := &adk.AgentInput{ + Messages: []adk.Message{schema.UserMessage("test")}, + } + iter := agent.Run(ctx, input) - var res toolSearchResult - err = json.Unmarshal([]byte(result), &res) - assert.NoError(t, err) - assert.ElementsMatch(t, toolNames, res.SelectedTools) - }) + var events []*adk.AgentEvent + for { + ev, ok := iter.Next() + if !ok { + break + } + events = append(events, ev) + } - t.Run("no matches returns empty list", func(t *testing.T) { - args := `{"regex_pattern": "^nonexistent_"}` - result, err := tst.InvokableRun(ctx, args) - assert.NoError(t, err) + // Verify no error event. + for _, ev := range events { + if ev.Err != nil { + t.Fatalf("unexpected error event: %v", ev.Err) + } + } - var res toolSearchResult - err = json.Unmarshal([]byte(result), &res) - assert.NoError(t, err) - assert.Empty(t, res.SelectedTools) - }) + // Verify final output is "done". + lastEvent := events[len(events)-1] + require.NotNil(t, lastEvent.Output) + require.NotNil(t, lastEvent.Output.MessageOutput) + assert.Equal(t, "done", lastEvent.Output.MessageOutput.Message.Content) + + // Verify dynamic_tool_a was actually called. + assert.True(t, dynamicA.wasCalled(), "dynamic_tool_a should have been called") + assert.False(t, dynamicB.wasCalled(), "dynamic_tool_b should not have been called") + + // Verify tool lists per Generate call. 
+ toolsPerCall := cm.getToolsPerCall() + require.Len(t, toolsPerCall, 3, "expected 3 Generate calls") + + // Call 1: tool_search + static_tool; dynamic tools are hidden. + assert.Contains(t, toolsPerCall[0], "tool_search") + assert.Contains(t, toolsPerCall[0], "static_tool") + assert.NotContains(t, toolsPerCall[0], "dynamic_tool_a") + assert.NotContains(t, toolsPerCall[0], "dynamic_tool_b") + + // Call 2: after selecting dynamic_tool_a, it becomes visible. + assert.Contains(t, toolsPerCall[1], "tool_search") + assert.Contains(t, toolsPerCall[1], "static_tool") + assert.Contains(t, toolsPerCall[1], "dynamic_tool_a") + assert.NotContains(t, toolsPerCall[1], "dynamic_tool_b") + + // Call 3: same as call 2. + assert.Contains(t, toolsPerCall[2], "tool_search") + assert.Contains(t, toolsPerCall[2], "static_tool") + assert.Contains(t, toolsPerCall[2], "dynamic_tool_a") + assert.NotContains(t, toolsPerCall[2], "dynamic_tool_b") + + // Verify reminder is present in messages (checked via tool list — the wrapper inserts it). + // The model received messages, and the reminder contains "". + // We indirectly verify this by checking that the middleware ran without error and the + // 3-turn flow completed successfully, which requires the tool_search tool to work. + + // Additional: verify that the reminder contains the dynamic tool names. 
+ mwImpl := mw.(*middleware) + assert.True(t, strings.Contains(mwImpl.sr, "dynamic_tool_a")) + assert.True(t, strings.Contains(mwImpl.sr, "dynamic_tool_b")) + assert.True(t, strings.Contains(mwImpl.sr, "")) } -func TestGetToolNames(t *testing.T) { +// --------------------------------------------------------------------------- +// TestNew — error paths for New() +// --------------------------------------------------------------------------- + +func TestNew(t *testing.T) { ctx := context.Background() - t.Run("returns tool names", func(t *testing.T) { - tools := []tool.BaseTool{ - newMockTool("tool1", "desc1"), - newMockTool("tool2", "desc2"), - newMockTool("tool3", "desc3"), - } - names, err := getToolNames(ctx, tools) - assert.NoError(t, err) - assert.Equal(t, []string{"tool1", "tool2", "tool3"}, names) + t.Run("nil config", func(t *testing.T) { + _, err := New(ctx, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "config is required") }) - t.Run("empty tools returns empty slice", func(t *testing.T) { - names, err := getToolNames(ctx, []tool.BaseTool{}) - assert.NoError(t, err) - assert.Empty(t, names) + t.Run("empty DynamicTools", func(t *testing.T) { + _, err := New(ctx, &Config{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "tools is required") }) -} - -func TestExtractSelectedTools(t *testing.T) { - ctx := context.Background() - - t.Run("extracts selected tools from messages", func(t *testing.T) { - result := toolSearchResult{SelectedTools: []string{"tool1", "tool2"}} - resultJSON, _ := json.Marshal(result) - messages := []*schema.Message{ - schema.UserMessage("hello"), - {Role: schema.Tool, ToolName: toolSearchToolName, Content: string(resultJSON)}, - } - - selected, err := extractSelectedTools(ctx, messages) - assert.NoError(t, err) - assert.ElementsMatch(t, []string{"tool1", "tool2"}, selected) + t.Run("success", func(t *testing.T) { + st := &simpleTool{name: "t1", desc: "tool 1"} + mw, err := New(ctx, &Config{DynamicTools: 
[]tool.BaseTool{st}}) + require.NoError(t, err) + assert.NotNil(t, mw) }) +} - t.Run("handles multiple tool_search results", func(t *testing.T) { - result1 := toolSearchResult{SelectedTools: []string{"tool1"}} - result1JSON, _ := json.Marshal(result1) - result2 := toolSearchResult{SelectedTools: []string{"tool2", "tool3"}} - result2JSON, _ := json.Marshal(result2) +// --------------------------------------------------------------------------- +// TestSplitCamelCase +// --------------------------------------------------------------------------- + +func TestSplitCamelCase(t *testing.T) { + tests := []struct { + input string + want []string + }{ + {"", nil}, + {"hello", []string{"hello"}}, + {"NotebookEdit", []string{"Notebook", "Edit"}}, + {"camelCase", []string{"camel", "Case"}}, + {"HTMLParser", []string{"HTML", "Parser"}}, + {"getURL", []string{"get", "URL"}}, + {"A", []string{"A"}}, + {"AB", []string{"AB"}}, + {"HTTP", []string{"HTTP"}}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := splitCamelCase(tt.input) + assert.Equal(t, tt.want, got) + }) + } +} - messages := []*schema.Message{ - {Role: schema.Tool, ToolName: toolSearchToolName, Content: string(result1JSON)}, - schema.UserMessage("continue"), - {Role: schema.Tool, ToolName: toolSearchToolName, Content: string(result2JSON)}, - } +// --------------------------------------------------------------------------- +// TestInsertReminder +// --------------------------------------------------------------------------- - selected, err := extractSelectedTools(ctx, messages) - assert.NoError(t, err) - assert.ElementsMatch(t, []string{"tool1", "tool2", "tool3"}, selected) - }) +func TestInsertReminder(t *testing.T) { + w := &wrapper{reminder: ""} - t.Run("ignores non-tool_search messages", func(t *testing.T) { - messages := []*schema.Message{ - schema.UserMessage("hello"), - {Role: schema.Tool, ToolName: "other_tool", Content: "some content"}, - {Role: schema.Assistant, Content: 
"response"}, + t.Run("normal: system then user", func(t *testing.T) { + input := []*schema.Message{ + {Role: schema.System, Content: "sys"}, + {Role: schema.User, Content: "hi"}, } - - selected, err := extractSelectedTools(ctx, messages) - assert.NoError(t, err) - assert.Empty(t, selected) + got := w.insertReminder(input) + require.Len(t, got, 3) + assert.Equal(t, schema.System, got[0].Role) + assert.Equal(t, schema.User, got[1].Role) + assert.Equal(t, "", got[1].Content) + assert.Equal(t, schema.User, got[2].Role) + assert.Equal(t, "hi", got[2].Content) }) - t.Run("returns error for invalid json", func(t *testing.T) { - messages := []*schema.Message{ - {Role: schema.Tool, ToolName: toolSearchToolName, Content: "invalid json"}, + t.Run("all system messages", func(t *testing.T) { + input := []*schema.Message{ + {Role: schema.System, Content: "sys1"}, + {Role: schema.System, Content: "sys2"}, } - - selected, err := extractSelectedTools(ctx, messages) - assert.Error(t, err) - assert.Nil(t, selected) + got := w.insertReminder(input) + require.Len(t, got, 3) + // Reminder appended at the end since no non-system message found during iteration. 
+ assert.Equal(t, schema.System, got[0].Role) + assert.Equal(t, schema.System, got[1].Role) + assert.Equal(t, "", got[2].Content) }) -} -func TestInvertSelect(t *testing.T) { - t.Run("returns items not in selected", func(t *testing.T) { - all := []string{"a", "b", "c", "d"} - selected := []string{"b", "d"} - - result := invertSelect(all, selected) - assert.Len(t, result, 2) - _, hasA := result["a"] - _, hasC := result["c"] - assert.True(t, hasA) - assert.True(t, hasC) + t.Run("empty input", func(t *testing.T) { + got := w.insertReminder(nil) + require.Len(t, got, 1) + assert.Equal(t, "", got[0].Content) }) - t.Run("empty selected returns all", func(t *testing.T) { - all := []string{"a", "b", "c"} - selected := []string{} - - result := invertSelect(all, selected) - assert.Len(t, result, 3) - }) - - t.Run("all selected returns empty", func(t *testing.T) { - all := []string{"a", "b"} - selected := []string{"a", "b"} - - result := invertSelect(all, selected) - assert.Empty(t, result) - }) - - t.Run("works with integers", func(t *testing.T) { - all := []int{1, 2, 3, 4, 5} - selected := []int{2, 4} - - result := invertSelect(all, selected) - assert.Len(t, result, 3) - _, has1 := result[1] - _, has3 := result[3] - _, has5 := result[5] - assert.True(t, has1) - assert.True(t, has3) - assert.True(t, has5) + t.Run("no system messages", func(t *testing.T) { + input := []*schema.Message{ + {Role: schema.User, Content: "hi"}, + {Role: schema.Assistant, Content: "hello"}, + } + got := w.insertReminder(input) + require.Len(t, got, 3) + // Reminder inserted before the first non-system message. 
+ assert.Equal(t, "", got[0].Content) + assert.Equal(t, "hi", got[1].Content) + assert.Equal(t, "hello", got[2].Content) }) } -func TestRemoveTools(t *testing.T) { - ctx := context.Background() - - t.Run("removes unselected dynamic tools", func(t *testing.T) { - allTools := []*schema.ToolInfo{ - {Name: "static_tool"}, - {Name: "dynamic_tool1"}, - {Name: "dynamic_tool2"}, - {Name: "dynamic_tool3"}, - } +// --------------------------------------------------------------------------- +// TestExtractSelectedTools +// --------------------------------------------------------------------------- - dynamicTools := []tool.BaseTool{ - newMockTool("dynamic_tool1", ""), - newMockTool("dynamic_tool2", ""), - newMockTool("dynamic_tool3", ""), - } +func TestExtractSelectedTools(t *testing.T) { + ctx := context.Background() - result := toolSearchResult{SelectedTools: []string{"dynamic_tool1"}} - resultJSON, _ := json.Marshal(result) + t.Run("accumulates from multiple tool_search results", func(t *testing.T) { messages := []*schema.Message{ - {Role: schema.Tool, ToolName: toolSearchToolName, Content: string(resultJSON)}, - } - - tools, err := removeTools(ctx, allTools, dynamicTools, messages) - assert.NoError(t, err) - assert.Len(t, tools, 2) - - toolNames := make([]string, len(tools)) - for i, t := range tools { - toolNames[i] = t.Name + {Role: schema.Tool, ToolName: toolSearchToolName, Content: `{"matches":["tool_a"]}`}, + {Role: schema.Tool, ToolName: toolSearchToolName, Content: `{"matches":["tool_b","tool_c"]}`}, } - assert.ElementsMatch(t, []string{"static_tool", "dynamic_tool1"}, toolNames) + got, err := extractSelectedTools(ctx, messages) + require.NoError(t, err) + assert.Equal(t, []string{"tool_a", "tool_b", "tool_c"}, got) }) - t.Run("remove all dynamic tools when no tool_search result", func(t *testing.T) { - allTools := []*schema.ToolInfo{ - {Name: "static_tool"}, - {Name: "dynamic_tool1"}, - } - - dynamicTools := []tool.BaseTool{ - newMockTool("dynamic_tool1", ""), - } 
- + t.Run("ignores non tool_search messages", func(t *testing.T) { messages := []*schema.Message{ - schema.UserMessage("hello"), + {Role: schema.User, Content: "hello"}, + {Role: schema.Tool, ToolName: "other_tool", Content: `{"matches":["should_ignore"]}`}, + {Role: schema.Assistant, Content: "world"}, + {Role: schema.Tool, ToolName: toolSearchToolName, Content: `{"matches":["tool_a"]}`}, } - - tools, err := removeTools(ctx, allTools, dynamicTools, messages) - assert.NoError(t, err) - assert.Len(t, tools, 1) - assert.Equal(t, "static_tool", tools[0].Name) + got, err := extractSelectedTools(ctx, messages) + require.NoError(t, err) + assert.Equal(t, []string{"tool_a"}, got) }) - t.Run("handles empty dynamic tools", func(t *testing.T) { - allTools := []*schema.ToolInfo{ - {Name: "static_tool1"}, - {Name: "static_tool2"}, + t.Run("malformed JSON returns error", func(t *testing.T) { + messages := []*schema.Message{ + {Role: schema.Tool, ToolName: toolSearchToolName, Content: `not json`}, } - - dynamicTools := []tool.BaseTool{} - messages := []*schema.Message{} - - tools, err := removeTools(ctx, allTools, dynamicTools, messages) - assert.NoError(t, err) - assert.Len(t, tools, 2) + _, err := extractSelectedTools(ctx, messages) + assert.Error(t, err) }) -} - -type mockChatModel struct { - generateFunc func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) - streamFunc func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) -} -func (m *mockChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { - if m.generateFunc != nil { - return m.generateFunc(ctx, input, opts...) 
- } - return &schema.Message{Role: schema.Assistant, Content: "response"}, nil + t.Run("nil messages returns nil", func(t *testing.T) { + got, err := extractSelectedTools(ctx, nil) + require.NoError(t, err) + assert.Nil(t, got) + }) } -func (m *mockChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { - if m.streamFunc != nil { - return m.streamFunc(ctx, input, opts...) - } - return nil, nil -} +// --------------------------------------------------------------------------- +// TestCalculateTools +// --------------------------------------------------------------------------- -func TestWrapper_Generate(t *testing.T) { +func TestCalculateTools(t *testing.T) { ctx := context.Background() - t.Run("filters tools based on tool_search result", func(t *testing.T) { - allTools := []*schema.ToolInfo{ - {Name: "static_tool"}, - {Name: "dynamic_tool1"}, - {Name: "dynamic_tool2"}, - } - - dynamicTools := []tool.BaseTool{ - newMockTool("dynamic_tool1", ""), - newMockTool("dynamic_tool2", ""), - } + staticTool := ti("static_tool", "static") + toolSearchInfo := getToolSearchToolInfo() + dynA := ti("dynamic_a", "A") + dynB := ti("dynamic_b", "B") - result := toolSearchResult{SelectedTools: []string{"dynamic_tool1"}} - resultJSON, _ := json.Marshal(result) + allTools := []*schema.ToolInfo{staticTool, toolSearchInfo, dynA, dynB} + dynamicTools := []*schema.ToolInfo{dynA, dynB} + t.Run("no selection: dynamic tools hidden", func(t *testing.T) { messages := []*schema.Message{ - schema.UserMessage("hello"), - {Role: schema.Tool, ToolName: toolSearchToolName, Content: string(resultJSON)}, + {Role: schema.User, Content: "hello"}, } + opts, err := calculateTools(ctx, allTools, dynamicTools, messages, false) + require.NoError(t, err) - w := &wrapper{ - allTools: allTools, - dynamicTools: dynamicTools, - cm: &mockChatModel{ - generateFunc: func(ctx context.Context, input []*schema.Message, opts ...model.Option) 
(*schema.Message, error) { - options := model.GetCommonOptions(nil, opts...) - assert.Len(t, options.Tools, 2) - assert.Equal(t, "static_tool", options.Tools[0].Name) - assert.Equal(t, "dynamic_tool1", options.Tools[1].Name) - return nil, nil - }, - }, + options := model.GetCommonOptions(nil, opts...) + var names []string + for _, t := range options.Tools { + names = append(names, t.Name) } - - _, err := w.Generate(ctx, messages) - assert.NoError(t, err) + sort.Strings(names) + assert.Equal(t, []string{"static_tool", "tool_search"}, names) }) -} - -func TestWrapper_Stream(t *testing.T) { - ctx := context.Background() - t.Run("filters tools based on tool_search result", func(t *testing.T) { - allTools := []*schema.ToolInfo{ - {Name: "static_tool"}, - {Name: "dynamic_tool1"}, - {Name: "dynamic_tool2"}, + t.Run("partial selection: selected tool visible", func(t *testing.T) { + messages := []*schema.Message{ + {Role: schema.Tool, ToolName: toolSearchToolName, Content: `{"matches":["dynamic_a"]}`}, } + opts, err := calculateTools(ctx, allTools, dynamicTools, messages, false) + require.NoError(t, err) - dynamicTools := []tool.BaseTool{ - newMockTool("dynamic_tool1", ""), - newMockTool("dynamic_tool2", ""), + options := model.GetCommonOptions(nil, opts...) 
+ var names []string + for _, t := range options.Tools { + names = append(names, t.Name) } + sort.Strings(names) + assert.Equal(t, []string{"dynamic_a", "static_tool", "tool_search"}, names) + }) - result := toolSearchResult{SelectedTools: []string{"dynamic_tool1"}} - resultJSON, _ := json.Marshal(result) - - messages := []*schema.Message{ - schema.UserMessage("hello"), - {Role: schema.Tool, ToolName: toolSearchToolName, Content: string(resultJSON)}, - } + t.Run("useModelToolSearch: dynamic tools and tool_search removed from WithTools", func(t *testing.T) { + opts, err := calculateTools(ctx, allTools, dynamicTools, nil, true) + require.NoError(t, err) - w := &wrapper{ - allTools: allTools, - dynamicTools: dynamicTools, - cm: &mockChatModel{ - streamFunc: func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { - options := model.GetCommonOptions(nil, opts...) - assert.Len(t, options.Tools, 2) - assert.Equal(t, "static_tool", options.Tools[0].Name) - assert.Equal(t, "dynamic_tool1", options.Tools[1].Name) - return nil, nil - }, - }, + options := model.GetCommonOptions(nil, opts...) + var names []string + for _, t := range options.Tools { + names = append(names, t.Name) } - - stream, err := w.Stream(ctx, messages) - assert.NoError(t, err) - assert.Nil(t, stream) + assert.Equal(t, []string{"static_tool"}, names) + // ToolSearchTool should be set. 
+ assert.NotNil(t, options.ToolSearchTool) + assert.Equal(t, toolSearchToolName, options.ToolSearchTool.Name) }) } diff --git a/adk/middlewares/patchtoolcalls/patchtoolcalls.go b/adk/middlewares/patchtoolcalls/patchtoolcalls.go index 75fb5fcbf..833ca3794 100644 --- a/adk/middlewares/patchtoolcalls/patchtoolcalls.go +++ b/adk/middlewares/patchtoolcalls/patchtoolcalls.go @@ -121,6 +121,6 @@ func (m *middleware) createPatchedToolMessage(ctx context.Context, tc schema.Too } const ( - defaultPatchedToolMessageTemplate = "Tool call %s with id %s was cancelled - another message came in before it could be completed." + defaultPatchedToolMessageTemplate = "Tool call %s with id %s was canceled - another message came in before it could be completed." defaultPatchedToolMessageTemplateChinese = "工具调用 %s(ID 为 %s)已被取消——在其完成之前收到了另一条消息。" ) diff --git a/adk/middlewares/permission/permission.go b/adk/middlewares/permission/permission.go new file mode 100644 index 000000000..728606a55 --- /dev/null +++ b/adk/middlewares/permission/permission.go @@ -0,0 +1,257 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package permission provides a ChatModelAgentMiddleware that gates tool execution +// behind a user-defined permission check (Checker). It supports three decisions: +// Allow (execute the tool), Deny (return a deny message as tool result), and Ask +// (interrupt the agent loop via StatefulInterrupt for external approval). 
+package permission + +import ( + "context" + "fmt" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/adk/internal" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/schema" +) + +func init() { + schema.RegisterName[*AskInfo]("_eino_adk_permission_ask_info") + schema.RegisterName[*AskState]("_eino_adk_permission_ask_state") +} + +type Decision string + +const ( + Allow Decision = "allow" + Deny Decision = "deny" + Ask Decision = "ask" +) + +type ToolCallDecision struct { + Decision Decision + Message string + UpdatedInput string + Reason string +} + +// Checker is the user-provided evaluation function invoked before each tool call. +// It receives the full ToolContext (including tool name and call ID) along with +// the tool arguments as a *schema.ToolArgument, and returns a ToolCallDecision +// that determines whether the call is allowed, denied, or requires interactive +// approval. Using *schema.ToolArgument instead of a raw string ensures +// forward-compatibility when the struct gains additional fields (e.g. multimodal +// content). Returning an error signals an infrastructure failure and aborts the +// agent loop; permission denials should use Decision: Deny instead. +type Checker func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) + +type AskInfo struct { + ToolName string + CallID string + Arguments string + Message string +} + +type AskState struct { + Info *AskInfo +} + +type ResumeResponse struct { + Approved bool + UpdatedInput string + DenyMessage string +} + +type Middleware struct { + *adk.BaseChatModelAgentMiddleware + checker Checker +} + +// New creates a permission Middleware with the given Checker evaluator. 
+func New(checker Checker) *Middleware { + return &Middleware{ + BaseChatModelAgentMiddleware: &adk.BaseChatModelAgentMiddleware{}, + checker: checker, + } +} + +type gateResult struct { + allowed bool + denyResult string + updatedInput string +} + +func (m *Middleware) permissionGate( + ctx context.Context, + tCtx *adk.ToolContext, + argumentsInJSON string, +) (*gateResult, error) { + wasInterrupted, _, savedState := tool.GetInterruptState[*AskState](ctx) + isTarget, hasData, response := tool.GetResumeContext[*ResumeResponse](ctx) + + if wasInterrupted && !isTarget { + return nil, tool.StatefulInterrupt(ctx, savedState.Info, savedState) + } + + if isTarget && hasData { + if !response.Approved { + return &gateResult{denyResult: formatDenyResult(tCtx.Name, response.DenyMessage)}, nil + } + input := argumentsInJSON + if response.UpdatedInput != "" { + input = response.UpdatedInput + } + return &gateResult{allowed: true, updatedInput: input}, nil + } + + if isTarget && !hasData { + return nil, fmt.Errorf( + "permission: tool %q (call_id=%s) was targeted for resume but received nil "+ + "or type-mismatched ResumeResponse; the caller must supply a *permission.ResumeResponse "+ + "via ResumeWithParams", tCtx.Name, tCtx.CallID) + } + + decision, err := m.checker(ctx, tCtx, &schema.ToolArgument{Text: argumentsInJSON}) + if err != nil { + return nil, fmt.Errorf( + "permission: checker error for tool %q (call_id=%s, args=%s): %w", + tCtx.Name, tCtx.CallID, argumentsInJSON, err) + } + if decision == nil { + return nil, fmt.Errorf( + "permission: checker returned nil ToolCallDecision for tool %q (call_id=%s); "+ + "return a valid *ToolCallDecision with Decision set to Allow, Deny, or Ask", + tCtx.Name, tCtx.CallID) + } + + switch decision.Decision { + case Allow: + input := argumentsInJSON + if decision.UpdatedInput != "" { + input = decision.UpdatedInput + } + return &gateResult{allowed: true, updatedInput: input}, nil + + case Deny: + return &gateResult{denyResult: 
formatDenyResult(tCtx.Name, decision.Message)}, nil + + case Ask: + info := &AskInfo{ + ToolName: tCtx.Name, + CallID: tCtx.CallID, + Arguments: argumentsInJSON, + Message: decision.Message, + } + state := &AskState{Info: info} + return nil, tool.StatefulInterrupt(ctx, info, state) + + default: + return &gateResult{denyResult: formatDenyResult(tCtx.Name, + fmt.Sprintf("unknown permission decision %q; expected allow, deny, or ask", decision.Decision))}, nil + } +} + +func (m *Middleware) WrapInvokableToolCall( + ctx context.Context, + endpoint adk.InvokableToolCallEndpoint, + tCtx *adk.ToolContext, +) (adk.InvokableToolCallEndpoint, error) { + return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { + result, err := m.permissionGate(ctx, tCtx, argumentsInJSON) + if err != nil { + return "", err + } + if !result.allowed { + return result.denyResult, nil + } + return endpoint(ctx, result.updatedInput, opts...) + }, nil +} + +func (m *Middleware) WrapStreamableToolCall( + ctx context.Context, + endpoint adk.StreamableToolCallEndpoint, + tCtx *adk.ToolContext, +) (adk.StreamableToolCallEndpoint, error) { + return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) { + result, err := m.permissionGate(ctx, tCtx, argumentsInJSON) + if err != nil { + return nil, err + } + if !result.allowed { + return schema.StreamReaderFromArray([]string{result.denyResult}), nil + } + return endpoint(ctx, result.updatedInput, opts...) 
+ }, nil +} + +func (m *Middleware) WrapEnhancedInvokableToolCall( + ctx context.Context, + endpoint adk.EnhancedInvokableToolCallEndpoint, + tCtx *adk.ToolContext, +) (adk.EnhancedInvokableToolCallEndpoint, error) { + return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { + result, err := m.permissionGate(ctx, tCtx, toolArgument.Text) + if err != nil { + return nil, err + } + if !result.allowed { + return denyToolResult(result.denyResult), nil + } + if result.updatedInput != toolArgument.Text { + toolArgument = &schema.ToolArgument{Text: result.updatedInput} + } + return endpoint(ctx, toolArgument, opts...) + }, nil +} + +func (m *Middleware) WrapEnhancedStreamableToolCall( + ctx context.Context, + endpoint adk.EnhancedStreamableToolCallEndpoint, + tCtx *adk.ToolContext, +) (adk.EnhancedStreamableToolCallEndpoint, error) { + return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { + result, err := m.permissionGate(ctx, tCtx, toolArgument.Text) + if err != nil { + return nil, err + } + if !result.allowed { + return schema.StreamReaderFromArray([]*schema.ToolResult{denyToolResult(result.denyResult)}), nil + } + if result.updatedInput != toolArgument.Text { + toolArgument = &schema.ToolArgument{Text: result.updatedInput} + } + return endpoint(ctx, toolArgument, opts...) 
+ }, nil +} + +func denyToolResult(denyMsg string) *schema.ToolResult { + return &schema.ToolResult{ + Parts: []schema.ToolOutputPart{ + {Type: schema.ToolPartTypeText, Text: denyMsg}, + }, + } +} + +func formatDenyResult(toolName, message string) string { + tpl := internal.SelectPrompt(internal.I18nPrompts{ + English: "Permission denied for tool %s: %s", + Chinese: "工具 %s 权限被拒绝: %s", + }) + return fmt.Sprintf(tpl, toolName, message) +} diff --git a/adk/middlewares/permission/permission_test.go b/adk/middlewares/permission/permission_test.go new file mode 100644 index 000000000..f2d97c925 --- /dev/null +++ b/adk/middlewares/permission/permission_test.go @@ -0,0 +1,876 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package permission + +import ( + "bytes" + "context" + "encoding/gob" + "errors" + "fmt" + "io" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/internal/core" + "github.com/cloudwego/eino/schema" +) + +func TestNew(t *testing.T) { + called := false + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + called = true + return &ToolCallDecision{Decision: Allow}, nil + }) + assert.NotNil(t, m) + assert.NotNil(t, m.checker) + assert.NotNil(t, m.BaseChatModelAgentMiddleware) + assert.False(t, called) +} + +func TestFormatDenyResult(t *testing.T) { + result := formatDenyResult("WriteFile", "destructive operation blocked") + assert.Equal(t, "Permission denied for tool WriteFile: destructive operation blocked", result) +} + +func makeCtxWithAddr() context.Context { + ctx := context.Background() + return core.AppendAddressSegment(ctx, "agent", "test-agent", "") +} + +func TestPermissionGate_Allow(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + assert.Equal(t, "ReadFile", tCtx.Name) + assert.Equal(t, `{"path":"/tmp/x"}`, args.Text) + return &ToolCallDecision{Decision: Allow, Reason: "read-only"}, nil + }) + + tCtx := &adk.ToolContext{Name: "ReadFile", CallID: "call_1"} + result, err := m.permissionGate(context.Background(), tCtx, `{"path":"/tmp/x"}`) + require.NoError(t, err) + assert.True(t, result.allowed) + assert.Equal(t, `{"path":"/tmp/x"}`, result.updatedInput) +} + +func TestPermissionGate_AllowWithUpdatedInput(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{ + Decision: Allow, + UpdatedInput: `{"path":"/tmp/safe"}`, + }, nil + }) + + tCtx := 
&adk.ToolContext{Name: "ReadFile", CallID: "call_1"} + result, err := m.permissionGate(context.Background(), tCtx, `{"path":"/tmp/danger"}`) + require.NoError(t, err) + assert.True(t, result.allowed) + assert.Equal(t, `{"path":"/tmp/safe"}`, result.updatedInput) +} + +func TestPermissionGate_Deny(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{ + Decision: Deny, + Message: "operation not allowed", + Reason: "policy", + }, nil + }) + + tCtx := &adk.ToolContext{Name: "DeleteFile", CallID: "call_2"} + result, err := m.permissionGate(context.Background(), tCtx, `{}`) + require.NoError(t, err) + assert.False(t, result.allowed) + assert.Equal(t, "Permission denied for tool DeleteFile: operation not allowed", result.denyResult) +} + +func TestPermissionGate_Ask(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{ + Decision: Ask, + Message: "requires approval", + }, nil + }) + + tCtx := &adk.ToolContext{Name: "Execute", CallID: "call_3"} + ctx := makeCtxWithAddr() + result, err := m.permissionGate(ctx, tCtx, `{"cmd":"rm -rf /"}`) + assert.Nil(t, result) + require.Error(t, err) + + var is *core.InterruptSignal + assert.True(t, errors.As(err, &is)) +} + +func TestPermissionGate_UnknownDecision(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: "maybe"}, nil + }) + + tCtx := &adk.ToolContext{Name: "SomeTool", CallID: "call_4"} + result, err := m.permissionGate(context.Background(), tCtx, `{}`) + require.NoError(t, err) + assert.False(t, result.allowed) + assert.Contains(t, result.denyResult, "unknown permission decision") + assert.Contains(t, result.denyResult, "maybe") +} + +func TestPermissionGate_NilDecision(t *testing.T) { + m := 
New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return nil, nil + }) + + tCtx := &adk.ToolContext{Name: "SomeTool", CallID: "call_5"} + result, err := m.permissionGate(context.Background(), tCtx, `{}`) + assert.Nil(t, result) + require.Error(t, err) + assert.Contains(t, err.Error(), "nil ToolCallDecision") +} + +func TestPermissionGate_BeforeToolCallError(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return nil, fmt.Errorf("rule store unreachable") + }) + + tCtx := &adk.ToolContext{Name: "SomeTool", CallID: "call_6"} + result, err := m.permissionGate(context.Background(), tCtx, `{}`) + assert.Nil(t, result) + require.Error(t, err) + assert.Contains(t, err.Error(), "permission: checker error") + assert.Contains(t, err.Error(), "rule store unreachable") +} + +func TestWrapInvokableToolCall_Allow(t *testing.T) { + endpointCalled := false + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Allow}, nil + }) + + originalEndpoint := adk.InvokableToolCallEndpoint(func(ctx context.Context, args string, opts ...tool.Option) (string, error) { + endpointCalled = true + return "tool result: " + args, nil + }) + + tCtx := &adk.ToolContext{Name: "MyTool", CallID: "call_1"} + wrapped, err := m.WrapInvokableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + result, err := wrapped(context.Background(), `{"key":"value"}`) + require.NoError(t, err) + assert.True(t, endpointCalled) + assert.Equal(t, `tool result: {"key":"value"}`, result) +} + +func TestWrapInvokableToolCall_AllowWithUpdatedInput(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Allow, UpdatedInput: `{"sanitized":true}`}, 
nil + }) + + var receivedArgs string + originalEndpoint := adk.InvokableToolCallEndpoint(func(ctx context.Context, args string, opts ...tool.Option) (string, error) { + receivedArgs = args + return "ok", nil + }) + + tCtx := &adk.ToolContext{Name: "MyTool", CallID: "call_1"} + wrapped, err := m.WrapInvokableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + result, err := wrapped(context.Background(), `{"original":true}`) + require.NoError(t, err) + assert.Equal(t, `{"sanitized":true}`, receivedArgs) + assert.Equal(t, "ok", result) +} + +func TestWrapInvokableToolCall_Deny(t *testing.T) { + endpointCalled := false + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Deny, Message: "blocked"}, nil + }) + + originalEndpoint := adk.InvokableToolCallEndpoint(func(ctx context.Context, args string, opts ...tool.Option) (string, error) { + endpointCalled = true + return "should not reach", nil + }) + + tCtx := &adk.ToolContext{Name: "MyTool", CallID: "call_1"} + wrapped, err := m.WrapInvokableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + result, err := wrapped(context.Background(), `{}`) + require.NoError(t, err) + assert.False(t, endpointCalled) + assert.Equal(t, "Permission denied for tool MyTool: blocked", result) +} + +func TestWrapInvokableToolCall_Ask(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Ask, Message: "need approval"}, nil + }) + + originalEndpoint := adk.InvokableToolCallEndpoint(func(ctx context.Context, args string, opts ...tool.Option) (string, error) { + return "should not reach", nil + }) + + tCtx := &adk.ToolContext{Name: "DangerTool", CallID: "call_1"} + wrapped, err := m.WrapInvokableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, 
err) + + ctx := makeCtxWithAddr() + result, err := wrapped(ctx, `{"danger":true}`) + assert.Equal(t, "", result) + require.Error(t, err) + + var is *core.InterruptSignal + assert.True(t, errors.As(err, &is)) +} + +func TestWrapStreamableToolCall_Allow(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Allow}, nil + }) + + originalEndpoint := adk.StreamableToolCallEndpoint(func(ctx context.Context, args string, opts ...tool.Option) (*schema.StreamReader[string], error) { + sr, sw := schema.Pipe[string](1) + sw.Send("stream chunk: "+args, nil) + sw.Close() + return sr, nil + }) + + tCtx := &adk.ToolContext{Name: "StreamTool", CallID: "call_1"} + wrapped, err := m.WrapStreamableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + sr, err := wrapped(context.Background(), `{"key":"val"}`) + require.NoError(t, err) + require.NotNil(t, sr) + + chunk, err := sr.Recv() + require.NoError(t, err) + assert.Equal(t, `stream chunk: {"key":"val"}`, chunk) + + _, err = sr.Recv() + assert.Equal(t, io.EOF, err) +} + +func TestWrapStreamableToolCall_Deny(t *testing.T) { + endpointCalled := false + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Deny, Message: "stream blocked"}, nil + }) + + originalEndpoint := adk.StreamableToolCallEndpoint(func(ctx context.Context, args string, opts ...tool.Option) (*schema.StreamReader[string], error) { + endpointCalled = true + return nil, nil + }) + + tCtx := &adk.ToolContext{Name: "StreamTool", CallID: "call_1"} + wrapped, err := m.WrapStreamableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + sr, err := wrapped(context.Background(), `{}`) + require.NoError(t, err) + assert.False(t, endpointCalled) + require.NotNil(t, sr) + + chunk, err := sr.Recv() + 
require.NoError(t, err) + assert.Equal(t, "Permission denied for tool StreamTool: stream blocked", chunk) + + _, err = sr.Recv() + assert.Equal(t, io.EOF, err) +} + +func TestWrapStreamableToolCall_Ask(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Ask, Message: "stream needs approval"}, nil + }) + + originalEndpoint := adk.StreamableToolCallEndpoint(func(ctx context.Context, args string, opts ...tool.Option) (*schema.StreamReader[string], error) { + return nil, nil + }) + + tCtx := &adk.ToolContext{Name: "StreamTool", CallID: "call_1"} + wrapped, err := m.WrapStreamableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + ctx := makeCtxWithAddr() + sr, err := wrapped(ctx, `{}`) + assert.Nil(t, sr) + require.Error(t, err) + + var is *core.InterruptSignal + assert.True(t, errors.As(err, &is)) +} + +func TestWrapInvokableToolCall_BeforeToolCallError(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return nil, fmt.Errorf("infra failure") + }) + + originalEndpoint := adk.InvokableToolCallEndpoint(func(ctx context.Context, args string, opts ...tool.Option) (string, error) { + return "should not reach", nil + }) + + tCtx := &adk.ToolContext{Name: "MyTool", CallID: "call_1"} + wrapped, err := m.WrapInvokableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + result, err := wrapped(context.Background(), `{}`) + assert.Equal(t, "", result) + require.Error(t, err) + assert.Contains(t, err.Error(), "permission: checker error") +} + +func TestMiddleware_ImplementsInterface(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Allow}, nil + }) + var _ adk.ChatModelAgentMiddleware = m +} + +func 
TestDecisionConstants(t *testing.T) { + assert.Equal(t, Decision("allow"), Allow) + assert.Equal(t, Decision("deny"), Deny) + assert.Equal(t, Decision("ask"), Ask) +} + +func TestAskInfoFields(t *testing.T) { + info := &AskInfo{ + ToolName: "MyTool", + CallID: "call_1", + Arguments: `{"key":"value"}`, + Message: "requires approval", + } + assert.Equal(t, "MyTool", info.ToolName) + assert.Equal(t, "call_1", info.CallID) + assert.Equal(t, `{"key":"value"}`, info.Arguments) + assert.Equal(t, "requires approval", info.Message) +} + +func TestResumeResponse_Approved(t *testing.T) { + resp := &ResumeResponse{ + Approved: true, + UpdatedInput: `{"modified":true}`, + } + assert.True(t, resp.Approved) + assert.Equal(t, `{"modified":true}`, resp.UpdatedInput) +} + +func TestResumeResponse_Denied(t *testing.T) { + resp := &ResumeResponse{ + Approved: false, + DenyMessage: "user rejected", + } + assert.False(t, resp.Approved) + assert.Equal(t, "user rejected", resp.DenyMessage) +} + +func TestAttack_NilBeforeToolCall(t *testing.T) { + m := New(nil) + require.NotNil(t, m) + + tCtx := &adk.ToolContext{Name: "SomeTool", CallID: "call_nil"} + assert.Panics(t, func() { + _, _ = m.permissionGate(context.Background(), tCtx, `{}`) + }) +} + +func TestAttack_EmptyDenyMessage(t *testing.T) { + result := formatDenyResult("WriteTool", "") + assert.Equal(t, "Permission denied for tool WriteTool: ", result) + + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Deny, Message: ""}, nil + }) + + tCtx := &adk.ToolContext{Name: "WriteTool", CallID: "call_empty_deny"} + gr, err := m.permissionGate(context.Background(), tCtx, `{}`) + require.NoError(t, err) + assert.False(t, gr.allowed) + assert.Equal(t, "Permission denied for tool WriteTool: ", gr.denyResult) +} + +func TestAttack_DenyWithEmptyToolName(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args 
*schema.ToolArgument) (*ToolCallDecision, error) { + assert.Equal(t, "", tCtx.Name) + return &ToolCallDecision{Decision: Deny, Message: "no name"}, nil + }) + + tCtx := &adk.ToolContext{Name: "", CallID: "call_empty_name"} + gr, err := m.permissionGate(context.Background(), tCtx, `{"x":1}`) + require.NoError(t, err) + assert.False(t, gr.allowed) + assert.Equal(t, "Permission denied for tool : no name", gr.denyResult) +} + +func TestAttack_AllowUpdatedInputEmpty(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Allow, UpdatedInput: ""}, nil + }) + + originalArgs := `{"important":"data"}` + tCtx := &adk.ToolContext{Name: "MyTool", CallID: "call_empty_update"} + gr, err := m.permissionGate(context.Background(), tCtx, originalArgs) + require.NoError(t, err) + assert.True(t, gr.allowed) + assert.Equal(t, originalArgs, gr.updatedInput) + + var receivedArgs string + endpoint := adk.InvokableToolCallEndpoint(func(ctx context.Context, args string, opts ...tool.Option) (string, error) { + receivedArgs = args + return "ok", nil + }) + + wrapped, err := m.WrapInvokableToolCall(context.Background(), endpoint, tCtx) + require.NoError(t, err) + + result, err := wrapped(context.Background(), originalArgs) + require.NoError(t, err) + assert.Equal(t, "ok", result) + assert.Equal(t, originalArgs, receivedArgs) +} + +func TestAttack_AskInfoGobSerializable(t *testing.T) { + info := &AskInfo{ + ToolName: "DangerTool", + CallID: "call_gob", + Arguments: `{"rm":"-rf /"}`, + Message: "are you sure?", + } + state := &AskState{Info: info} + + var buf bytes.Buffer + require.NoError(t, gob.NewEncoder(&buf).Encode(info)) + + var decodedInfo AskInfo + require.NoError(t, gob.NewDecoder(&buf).Decode(&decodedInfo)) + assert.Equal(t, info.ToolName, decodedInfo.ToolName) + assert.Equal(t, info.CallID, decodedInfo.CallID) + assert.Equal(t, info.Arguments, decodedInfo.Arguments) + 
assert.Equal(t, info.Message, decodedInfo.Message) + + buf.Reset() + require.NoError(t, gob.NewEncoder(&buf).Encode(state)) + + var decodedState AskState + require.NoError(t, gob.NewDecoder(&buf).Decode(&decodedState)) + require.NotNil(t, decodedState.Info) + assert.Equal(t, info.ToolName, decodedState.Info.ToolName) + assert.Equal(t, info.CallID, decodedState.Info.CallID) + assert.Equal(t, info.Arguments, decodedState.Info.Arguments) + assert.Equal(t, info.Message, decodedState.Info.Message) +} + +func TestAttack_ResumeResponseEmptyUpdatedInput(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Ask, Message: "confirm?"}, nil + }) + + originalArgs := `{"critical":"payload"}` + tCtx := &adk.ToolContext{Name: "CriticalTool", CallID: "call_resume_empty"} + + ctx := makeCtxWithAddr() + gr, err := m.permissionGate(ctx, tCtx, originalArgs) + assert.Nil(t, gr) + require.Error(t, err) + var is *core.InterruptSignal + require.True(t, errors.As(err, &is)) + + resp := &ResumeResponse{Approved: true, UpdatedInput: ""} + assert.True(t, resp.Approved) + assert.Equal(t, "", resp.UpdatedInput) +} + +func TestAttack_ConcurrentBeforeToolCall(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Allow}, nil + }) + + var receivedMu sync.Mutex + received := make(map[string]string) + + endpoint := adk.InvokableToolCallEndpoint(func(ctx context.Context, args string, opts ...tool.Option) (string, error) { + receivedMu.Lock() + received[args] = "done" + receivedMu.Unlock() + return "result:" + args, nil + }) + + tCtx := &adk.ToolContext{Name: "ConcurrentTool", CallID: "call_concurrent"} + wrapped, err := m.WrapInvokableToolCall(context.Background(), endpoint, tCtx) + require.NoError(t, err) + + const goroutines = 50 + var wg sync.WaitGroup + wg.Add(goroutines) 
+ errs := make([]error, goroutines) + results := make([]string, goroutines) + + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + args := fmt.Sprintf(`{"id":%d}`, idx) + results[idx], errs[idx] = wrapped(context.Background(), args) + }(i) + } + wg.Wait() + + for i := 0; i < goroutines; i++ { + assert.NoError(t, errs[i], "goroutine %d returned error", i) + expected := fmt.Sprintf(`result:{"id":%d}`, i) + assert.Equal(t, expected, results[i], "goroutine %d result mismatch", i) + } + + receivedMu.Lock() + assert.Len(t, received, goroutines) + receivedMu.Unlock() +} + +func buildResumeCtx( + t *testing.T, + signal *core.InterruptSignal, + resumeData map[string]any, +) context.Context { + t.Helper() + id2Addr, id2State := core.SignalToPersistenceMaps(signal) + ctx := context.Background() + ctx = core.PopulateInterruptState(ctx, id2Addr, id2State) + ctx = core.BatchResumeWithData(ctx, resumeData) + ctx = core.AppendAddressSegment(ctx, "agent", "test-agent", "") + return ctx +} + +func TestE2E_AskThenResumeApproved(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Ask, Message: "approve rm?"}, nil + }) + + tCtx := &adk.ToolContext{Name: "ShellExec", CallID: "call_e2e_1"} + originalArgs := `{"cmd":"rm -rf /"}` + + ctx := makeCtxWithAddr() + result, err := m.permissionGate(ctx, tCtx, originalArgs) + assert.Nil(t, result) + require.Error(t, err) + + var signal *core.InterruptSignal + require.True(t, errors.As(err, &signal)) + assert.Equal(t, originalArgs, signal.InterruptState.State.(*AskState).Info.Arguments) + + resumeCtx := buildResumeCtx(t, signal, map[string]any{ + signal.ID: &ResumeResponse{Approved: true}, + }) + + result, err = m.permissionGate(resumeCtx, tCtx, originalArgs) + require.NoError(t, err) + assert.True(t, result.allowed) + assert.Equal(t, originalArgs, result.updatedInput) +} + +func TestE2E_AskThenResumeDenied(t 
*testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Ask, Message: "dangerous"}, nil + }) + + tCtx := &adk.ToolContext{Name: "DeleteDB", CallID: "call_e2e_deny"} + ctx := makeCtxWithAddr() + + _, err := m.permissionGate(ctx, tCtx, `{}`) + require.Error(t, err) + + var signal *core.InterruptSignal + require.True(t, errors.As(err, &signal)) + + resumeCtx := buildResumeCtx(t, signal, map[string]any{ + signal.ID: &ResumeResponse{Approved: false, DenyMessage: "user said no"}, + }) + + result, err := m.permissionGate(resumeCtx, tCtx, `{}`) + require.NoError(t, err) + assert.False(t, result.allowed) + assert.Contains(t, result.denyResult, "user said no") +} + +func TestE2E_ReInterruptNonTargetTool(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Ask, Message: "confirm"}, nil + }) + + tCtx := &adk.ToolContext{Name: "ToolA", CallID: "call_a"} + ctx := makeCtxWithAddr() + + _, err := m.permissionGate(ctx, tCtx, `{}`) + require.Error(t, err) + + var signal *core.InterruptSignal + require.True(t, errors.As(err, &signal)) + + id2Addr, id2State := core.SignalToPersistenceMaps(signal) + resumeCtx := context.Background() + resumeCtx = core.PopulateInterruptState(resumeCtx, id2Addr, id2State) + resumeCtx = core.BatchResumeWithData(resumeCtx, map[string]any{ + "some_other_id": &ResumeResponse{Approved: true}, + }) + resumeCtx = core.AppendAddressSegment(resumeCtx, "agent", "test-agent", "") + + result, err := m.permissionGate(resumeCtx, tCtx, `{}`) + assert.Nil(t, result) + require.Error(t, err) + + var reSignal *core.InterruptSignal + require.True(t, errors.As(err, &reSignal)) + assert.Equal(t, "ToolA", reSignal.InterruptState.State.(*AskState).Info.ToolName) +} + +func TestE2E_ResumeWithUpdatedInput(t *testing.T) { + m := New(func(ctx context.Context, 
tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Ask, Message: "sanitize?"}, nil + }) + + tCtx := &adk.ToolContext{Name: "WriteFile", CallID: "call_e2e_update"} + originalArgs := `{"path":"/etc/passwd","content":"hacked"}` + + ctx := makeCtxWithAddr() + _, err := m.permissionGate(ctx, tCtx, originalArgs) + require.Error(t, err) + + var signal *core.InterruptSignal + require.True(t, errors.As(err, &signal)) + + sanitizedArgs := `{"path":"/tmp/safe.txt","content":"ok"}` + resumeCtx := buildResumeCtx(t, signal, map[string]any{ + signal.ID: &ResumeResponse{Approved: true, UpdatedInput: sanitizedArgs}, + }) + + result, err := m.permissionGate(resumeCtx, tCtx, originalArgs) + require.NoError(t, err) + assert.True(t, result.allowed) + assert.Equal(t, sanitizedArgs, result.updatedInput) +} + +// --- Enhanced Tool Call Endpoint Tests --- + +func TestWrapEnhancedInvokableToolCall_Allow(t *testing.T) { + endpointCalled := false + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + assert.Equal(t, "EnhancedTool", tCtx.Name) + assert.Equal(t, `{"key":"val"}`, args.Text) + return &ToolCallDecision{Decision: Allow}, nil + }) + + originalEndpoint := adk.EnhancedInvokableToolCallEndpoint( + func(ctx context.Context, arg *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { + endpointCalled = true + return &schema.ToolResult{Parts: []schema.ToolOutputPart{ + {Type: schema.ToolPartTypeText, Text: "enhanced:" + arg.Text}, + }}, nil + }) + + tCtx := &adk.ToolContext{Name: "EnhancedTool", CallID: "call_ei_1"} + wrapped, err := m.WrapEnhancedInvokableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + result, err := wrapped(context.Background(), &schema.ToolArgument{Text: `{"key":"val"}`}) + require.NoError(t, err) + assert.True(t, endpointCalled) + require.Len(t, result.Parts, 1) + assert.Equal(t, 
`enhanced:{"key":"val"}`, result.Parts[0].Text) +} + +func TestWrapEnhancedInvokableToolCall_AllowWithUpdatedInput(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Allow, UpdatedInput: `{"sanitized":true}`}, nil + }) + + var receivedText string + originalEndpoint := adk.EnhancedInvokableToolCallEndpoint( + func(ctx context.Context, arg *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { + receivedText = arg.Text + return &schema.ToolResult{Parts: []schema.ToolOutputPart{ + {Type: schema.ToolPartTypeText, Text: "ok"}, + }}, nil + }) + + tCtx := &adk.ToolContext{Name: "EnhancedTool", CallID: "call_ei_2"} + wrapped, err := m.WrapEnhancedInvokableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + _, err = wrapped(context.Background(), &schema.ToolArgument{Text: `{"original":true}`}) + require.NoError(t, err) + assert.Equal(t, `{"sanitized":true}`, receivedText) +} + +func TestWrapEnhancedInvokableToolCall_Deny(t *testing.T) { + endpointCalled := false + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Deny, Message: "enhanced blocked"}, nil + }) + + originalEndpoint := adk.EnhancedInvokableToolCallEndpoint( + func(ctx context.Context, arg *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { + endpointCalled = true + return nil, nil + }) + + tCtx := &adk.ToolContext{Name: "EnhancedTool", CallID: "call_ei_3"} + wrapped, err := m.WrapEnhancedInvokableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + result, err := wrapped(context.Background(), &schema.ToolArgument{Text: `{}`}) + require.NoError(t, err) + assert.False(t, endpointCalled) + require.Len(t, result.Parts, 1) + assert.Equal(t, schema.ToolPartTypeText, result.Parts[0].Type) + 
assert.Equal(t, "Permission denied for tool EnhancedTool: enhanced blocked", result.Parts[0].Text) +} + +func TestWrapEnhancedInvokableToolCall_Ask(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Ask, Message: "enhanced needs approval"}, nil + }) + + originalEndpoint := adk.EnhancedInvokableToolCallEndpoint( + func(ctx context.Context, arg *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { + return nil, nil + }) + + tCtx := &adk.ToolContext{Name: "EnhancedTool", CallID: "call_ei_4"} + wrapped, err := m.WrapEnhancedInvokableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + ctx := makeCtxWithAddr() + result, err := wrapped(ctx, &schema.ToolArgument{Text: `{}`}) + assert.Nil(t, result) + require.Error(t, err) + + var is *core.InterruptSignal + assert.True(t, errors.As(err, &is)) +} + +func TestWrapEnhancedStreamableToolCall_Allow(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Allow}, nil + }) + + originalEndpoint := adk.EnhancedStreamableToolCallEndpoint( + func(ctx context.Context, arg *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { + tr := &schema.ToolResult{Parts: []schema.ToolOutputPart{ + {Type: schema.ToolPartTypeText, Text: "streamed:" + arg.Text}, + }} + return schema.StreamReaderFromArray([]*schema.ToolResult{tr}), nil + }) + + tCtx := &adk.ToolContext{Name: "EnhancedStreamTool", CallID: "call_es_1"} + wrapped, err := m.WrapEnhancedStreamableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + sr, err := wrapped(context.Background(), &schema.ToolArgument{Text: `{"k":"v"}`}) + require.NoError(t, err) + require.NotNil(t, sr) + + chunk, err := sr.Recv() + require.NoError(t, err) + 
require.Len(t, chunk.Parts, 1) + assert.Equal(t, `streamed:{"k":"v"}`, chunk.Parts[0].Text) + + _, err = sr.Recv() + assert.Equal(t, io.EOF, err) +} + +func TestWrapEnhancedStreamableToolCall_Deny(t *testing.T) { + endpointCalled := false + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Deny, Message: "stream enhanced blocked"}, nil + }) + + originalEndpoint := adk.EnhancedStreamableToolCallEndpoint( + func(ctx context.Context, arg *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { + endpointCalled = true + return nil, nil + }) + + tCtx := &adk.ToolContext{Name: "EnhancedStreamTool", CallID: "call_es_2"} + wrapped, err := m.WrapEnhancedStreamableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + sr, err := wrapped(context.Background(), &schema.ToolArgument{Text: `{}`}) + require.NoError(t, err) + assert.False(t, endpointCalled) + require.NotNil(t, sr) + + chunk, err := sr.Recv() + require.NoError(t, err) + require.Len(t, chunk.Parts, 1) + assert.Equal(t, schema.ToolPartTypeText, chunk.Parts[0].Type) + assert.Equal(t, "Permission denied for tool EnhancedStreamTool: stream enhanced blocked", chunk.Parts[0].Text) + + _, err = sr.Recv() + assert.Equal(t, io.EOF, err) +} + +func TestWrapEnhancedStreamableToolCall_Ask(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Ask, Message: "enhanced stream needs approval"}, nil + }) + + originalEndpoint := adk.EnhancedStreamableToolCallEndpoint( + func(ctx context.Context, arg *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { + return nil, nil + }) + + tCtx := &adk.ToolContext{Name: "EnhancedStreamTool", CallID: "call_es_3"} + wrapped, err := 
m.WrapEnhancedStreamableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + ctx := makeCtxWithAddr() + sr, err := wrapped(ctx, &schema.ToolArgument{Text: `{}`}) + assert.Nil(t, sr) + require.Error(t, err) + + var is *core.InterruptSignal + assert.True(t, errors.As(err, &is)) +} + +func TestWrapEnhancedStreamableToolCall_AllowWithUpdatedInput(t *testing.T) { + m := New(func(ctx context.Context, tCtx *adk.ToolContext, args *schema.ToolArgument) (*ToolCallDecision, error) { + return &ToolCallDecision{Decision: Allow, UpdatedInput: `{"safe":true}`}, nil + }) + + var receivedText string + originalEndpoint := adk.EnhancedStreamableToolCallEndpoint( + func(ctx context.Context, arg *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { + receivedText = arg.Text + tr := &schema.ToolResult{Parts: []schema.ToolOutputPart{ + {Type: schema.ToolPartTypeText, Text: "ok"}, + }} + return schema.StreamReaderFromArray([]*schema.ToolResult{tr}), nil + }) + + tCtx := &adk.ToolContext{Name: "EnhancedStreamTool", CallID: "call_es_4"} + wrapped, err := m.WrapEnhancedStreamableToolCall(context.Background(), originalEndpoint, tCtx) + require.NoError(t, err) + + sr, err := wrapped(context.Background(), &schema.ToolArgument{Text: `{"dangerous":true}`}) + require.NoError(t, err) + require.NotNil(t, sr) + assert.Equal(t, `{"safe":true}`, receivedText) +} diff --git a/adk/prebuilt/deep/deep.go b/adk/prebuilt/deep/deep.go index 48b5349a6..3918d47e4 100644 --- a/adk/prebuilt/deep/deep.go +++ b/adk/prebuilt/deep/deep.go @@ -93,6 +93,10 @@ type Config struct { Handlers []adk.ChatModelAgentMiddleware ModelRetryConfig *adk.ModelRetryConfig + // ModelFailoverConfig configures failover behavior for the ChatModel. + // When set, the agent will automatically fail over to alternative models on errors. + // This config is also propagated to the general sub-agent. 
+ ModelFailoverConfig *adk.ModelFailoverConfig // OutputKey stores the agent's response in the session. // Optional. When set, stores output via AddSessionValue(ctx, outputKey, msg.Content). @@ -129,6 +133,7 @@ func New(ctx context.Context, cfg *Config) (adk.ResumableAgent, error) { cfg.MaxIteration, cfg.Middlewares, append(handlers, cfg.Handlers...), + cfg.ModelFailoverConfig, ) if err != nil { return nil, fmt.Errorf("failed to new task tool: %w", err) @@ -146,9 +151,10 @@ func New(ctx context.Context, cfg *Config) (adk.ResumableAgent, error) { Middlewares: cfg.Middlewares, Handlers: append(handlers, cfg.Handlers...), - GenModelInput: genModelInput, - ModelRetryConfig: cfg.ModelRetryConfig, - OutputKey: cfg.OutputKey, + GenModelInput: genModelInput, + ModelRetryConfig: cfg.ModelRetryConfig, + ModelFailoverConfig: cfg.ModelFailoverConfig, + OutputKey: cfg.OutputKey, }) } diff --git a/adk/prebuilt/deep/task_tool.go b/adk/prebuilt/deep/task_tool.go index 6235021bd..e6fcedeb3 100644 --- a/adk/prebuilt/deep/task_tool.go +++ b/adk/prebuilt/deep/task_tool.go @@ -45,8 +45,9 @@ func newTaskToolMiddleware( maxIteration int, middlewares []adk.AgentMiddleware, handlers []adk.ChatModelAgentMiddleware, + modelFailoverConfig *adk.ModelFailoverConfig, ) (adk.ChatModelAgentMiddleware, error) { - t, err := newTaskTool(ctx, taskToolDescriptionGenerator, subAgents, withoutGeneralSubAgent, cm, instruction, toolsConfig, maxIteration, middlewares, handlers) + t, err := newTaskTool(ctx, taskToolDescriptionGenerator, subAgents, withoutGeneralSubAgent, cm, instruction, toolsConfig, maxIteration, middlewares, handlers, modelFailoverConfig) if err != nil { return nil, err } @@ -71,6 +72,7 @@ func newTaskTool( MaxIteration int, middlewares []adk.AgentMiddleware, handlers []adk.ChatModelAgentMiddleware, + modelFailoverConfig *adk.ModelFailoverConfig, ) (tool.InvokableTool, error) { t := &taskTool{ subAgents: map[string]tool.InvokableTool{}, @@ -88,15 +90,16 @@ func newTaskTool( Chinese: 
generalAgentDescriptionChinese, }) generalAgent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ - Name: generalAgentName, - Description: agentDesc, - Instruction: Instruction, - Model: Model, - ToolsConfig: ToolsConfig, - MaxIterations: MaxIteration, - Middlewares: middlewares, - Handlers: handlers, - GenModelInput: genModelInput, + Name: generalAgentName, + Description: agentDesc, + Instruction: Instruction, + Model: Model, + ToolsConfig: ToolsConfig, + MaxIterations: MaxIteration, + Middlewares: middlewares, + Handlers: handlers, + GenModelInput: genModelInput, + ModelFailoverConfig: modelFailoverConfig, }) if err != nil { return nil, err diff --git a/adk/prebuilt/deep/task_tool_test.go b/adk/prebuilt/deep/task_tool_test.go index 91c3a7784..8d60eb452 100644 --- a/adk/prebuilt/deep/task_tool_test.go +++ b/adk/prebuilt/deep/task_tool_test.go @@ -41,6 +41,7 @@ func TestTaskTool(t *testing.T) { 10, nil, nil, + nil, ) assert.NoError(t, err) diff --git a/adk/prebuilt/planexecute/plan_execute_test.go b/adk/prebuilt/planexecute/plan_execute_test.go index fb7360357..6734a16b8 100644 --- a/adk/prebuilt/planexecute/plan_execute_test.go +++ b/adk/prebuilt/planexecute/plan_execute_test.go @@ -18,9 +18,12 @@ package planexecute import ( "context" + "errors" "fmt" "strings" + "sync" "testing" + "time" "github.com/bytedance/sonic" "github.com/stretchr/testify/assert" @@ -1002,3 +1005,232 @@ func TestPlanExecuteAgentInterruptResume(t *testing.T) { assert.True(t, hasAssistantCompletion, "Should have assistant completion message") assert.True(t, hasBreakLoop, "Should have break loop action indicating completion") } + +// slowChatModel is a ChatModel that blocks for a configurable duration. 
+type slowChatModel struct { + delay time.Duration + response *schema.Message + startedChan chan struct{} + startedOnce sync.Once +} + +func (m *slowChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + m.startedOnce.Do(func() { + close(m.startedChan) + }) + + select { + case <-time.After(m.delay): + return m.response, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (m *slowChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + msg, err := m.Generate(ctx, input, opts...) + if err != nil { + return nil, err + } + sr, sw := schema.Pipe[*schema.Message](1) + sw.Send(msg, nil) + sw.Close() + return sr, nil +} + +func (m *slowChatModel) BindTools(tools []*schema.ToolInfo) error { return nil } +func (m *slowChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) { + return m, nil +} + +// TestWithCancel_PlanExecute_DuringExecution verifies that cancel works +// during the executor (ChatModelAgent) phase of the PlanExecute agent. 
+func TestWithCancel_PlanExecute_DuringExecution(t *testing.T) { + ctx := context.Background() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Planner: returns a plan quickly + mockPlanner := mockAdk.NewMockAgent(ctrl) + mockPlanner.EXPECT().Name(gomock.Any()).Return("planner").AnyTimes() + mockPlanner.EXPECT().Description(gomock.Any()).Return("a planner agent").AnyTimes() + + plan := &defaultPlan{Steps: []string{"Step 1", "Step 2"}} + userInput := []adk.Message{schema.UserMessage("test task")} + + mockPlanner.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, input *adk.AgentInput, opts ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { + iterator, generator := adk.NewAsyncIteratorPair[*adk.AgentEvent]() + adk.AddSessionValue(ctx, PlanSessionKey, plan) + adk.AddSessionValue(ctx, UserInputSessionKey, userInput) + planJSON, _ := sonic.MarshalString(plan) + msg := schema.AssistantMessage(planJSON, nil) + generator.Send(adk.EventFromMessage(msg, nil, schema.Assistant, "")) + generator.Close() + return iterator + }, + ).Times(1) + + // Executor: uses a slow model that we can cancel + executorStarted := make(chan struct{}) + slowModel := &slowChatModel{ + delay: 5 * time.Second, + response: schema.AssistantMessage("step result", nil), + startedChan: executorStarted, + } + + executor, err := NewExecutor(ctx, &ExecutorConfig{ + Model: slowModel, + MaxIterations: 5, + }) + assert.NoError(t, err) + + // Replanner: should not be reached since we cancel during executor + mockReplanner := mockAdk.NewMockAgent(ctrl) + mockReplanner.EXPECT().Name(gomock.Any()).Return("replanner").AnyTimes() + mockReplanner.EXPECT().Description(gomock.Any()).Return("a replanner agent").AnyTimes() + + agent, err := New(ctx, &Config{ + Planner: mockPlanner, + Executor: executor, + Replanner: mockReplanner, + MaxIterations: 5, + }) + assert.NoError(t, err) + + runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: agent}) + + 
cancelOpt, cancelFn := adk.WithCancel() + iter := runner.Run(ctx, userInput, cancelOpt) + + // Wait for the executor's model to start + select { + case <-executorStarted: + case <-time.After(10 * time.Second): + t.Fatal("Executor model did not start") + } + + time.Sleep(50 * time.Millisecond) + + // Cancel should NOT return ErrExecutionCompleted + handle, _ := cancelFn() + err = handle.Wait() + assert.NoError(t, err, "Cancel during executor should succeed") + + hasCancelError := false + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *adk.CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + hasCancelError = true + } + } + + assert.True(t, hasCancelError, "Should have CancelError event") +} + +// TestWithCancel_PlanExecute_BetweenTransitions verifies that cancel works +// when fired between agent transitions (e.g., after planner, before executor starts). +func TestWithCancel_PlanExecute_BetweenTransitions(t *testing.T) { + ctx := context.Background() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + plannerDone := make(chan struct{}) + + // Planner: signals when done + mockPlanner := mockAdk.NewMockAgent(ctrl) + mockPlanner.EXPECT().Name(gomock.Any()).Return("planner").AnyTimes() + mockPlanner.EXPECT().Description(gomock.Any()).Return("a planner agent").AnyTimes() + + plan := &defaultPlan{Steps: []string{"Step 1"}} + userInput := []adk.Message{schema.UserMessage("test task")} + + mockPlanner.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, input *adk.AgentInput, opts ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { + iterator, generator := adk.NewAsyncIteratorPair[*adk.AgentEvent]() + go func() { + defer generator.Close() + adk.AddSessionValue(ctx, PlanSessionKey, plan) + adk.AddSessionValue(ctx, UserInputSessionKey, userInput) + planJSON, _ := sonic.MarshalString(plan) + msg := schema.AssistantMessage(planJSON, nil) + generator.Send(adk.EventFromMessage(msg, nil, 
schema.Assistant, "")) + close(plannerDone) + }() + return iterator + }, + ).Times(1) + + // Executor: slow model to give time to observe cancel + executorModelStarted := make(chan struct{}) + slowExecModel := &slowChatModel{ + delay: 5 * time.Second, + response: schema.AssistantMessage("step result", nil), + startedChan: executorModelStarted, + } + + executor, err := NewExecutor(ctx, &ExecutorConfig{ + Model: slowExecModel, + MaxIterations: 5, + }) + assert.NoError(t, err) + + mockReplanner := mockAdk.NewMockAgent(ctrl) + mockReplanner.EXPECT().Name(gomock.Any()).Return("replanner").AnyTimes() + mockReplanner.EXPECT().Description(gomock.Any()).Return("a replanner agent").AnyTimes() + + agent, err := New(ctx, &Config{ + Planner: mockPlanner, + Executor: executor, + Replanner: mockReplanner, + MaxIterations: 5, + }) + assert.NoError(t, err) + + runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: agent}) + + cancelOpt, cancelFn := adk.WithCancel() + iter := runner.Run(ctx, userInput, cancelOpt) + + // Wait for planner to finish, then cancel before executor has a chance to produce output + select { + case <-plannerDone: + case <-time.After(10 * time.Second): + t.Fatal("Planner did not finish") + } + + // Cancel after planner, during executor phase + // The executor is a ChatModelAgent which will handle the cancel + select { + case <-executorModelStarted: + case <-time.After(10 * time.Second): + t.Fatal("Executor model did not start") + } + + start := time.Now() + handle, _ := cancelFn() + err = handle.Wait() + assert.NoError(t, err, "Cancel between transitions should succeed") + + hasCancelError := false + for { + event, ok := iter.Next() + if !ok { + break + } + var ce *adk.CancelError + if event.Err != nil && errors.As(event.Err, &ce) { + hasCancelError = true + } + } + elapsed := time.Since(start) + + assert.True(t, hasCancelError, "Should have CancelError event") + assert.True(t, elapsed < 3*time.Second, "Should complete quickly after cancel, elapsed: %v", 
elapsed) +} diff --git a/adk/react.go b/adk/react.go index 2bf6dd462..07fdbde9a 100644 --- a/adk/react.go +++ b/adk/react.go @@ -81,6 +81,7 @@ func init() { // when decoding checkpoints created by v0.8.0 - v0.8.3 gob.Register(&AgentEvent{}) gob.Register(int(0)) + schema.RegisterName[*reactInput]("_eino_adk_react_input") } func (s *State) getReturnDirectlyEvent() *AgentEvent { @@ -237,7 +238,7 @@ func SendToolGenAction(ctx context.Context, toolName string, action *AgentAction } type reactInput struct { - messages []Message + Messages []Message } type reactConfig struct { @@ -253,6 +254,8 @@ type reactConfig struct { agentName string maxIterations int + + cancelCtx *cancelContext } func genToolInfos(ctx context.Context, config *compose.ToolsNodeConfig) ([]*schema.ToolInfo, error) { @@ -270,8 +273,6 @@ func genToolInfos(ctx context.Context, config *compose.ToolsNodeConfig) ([]*sche } type reactGraph = *compose.Graph[*reactInput, Message] -type sToolNodeOutput = *schema.StreamReader[[]Message] -type sGraphOutput = MessageStream func getReturnDirectlyToolCallID(ctx context.Context) (string, bool) { var toolCallID string @@ -301,46 +302,67 @@ func genReactState(config *reactConfig) func(ctx context.Context) *State { func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { const ( - initNode_ = "Init" - chatModel_ = "ChatModel" - toolNode_ = "ToolNode" + initNode_ = "Init" + chatModel_ = "ChatModel" + cancelCheckNode_ = "CancelCheck" + toolNode_ = "ToolNode" + afterToolCallsNode_ = "AfterToolCalls" + afterToolCallsCancelCheckNode_ = "AfterToolCallsCancelCheck" ) + cancelCtx := config.cancelCtx g := compose.NewGraph[*reactInput, Message](compose.WithGenLocalState(genReactState(config))) - - initLambda := func(ctx context.Context, input *reactInput) ([]Message, error) { - return input.messages, nil - } - _ = g.AddLambdaNode(initNode_, compose.InvokableLambda(initLambda), compose.WithNodeName(initNode_)) + _ = g.AddLambdaNode(initNode_, 
compose.InvokableLambda(func(ctx context.Context, input *reactInput) ([]Message, error) { + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + st.Messages = append(st.Messages, input.Messages...) + return nil + }) + return input.Messages, nil + }), compose.WithNodeName(initNode_)) var wrappedModel model.BaseChatModel = config.model if config.modelWrapperConf != nil { wrappedModel = buildModelWrappers(config.model, config.modelWrapperConf) } - toolsNode, err := compose.NewToolNode(ctx, config.toolsConfig) + toolsConfig := config.toolsConfig + + toolsNode, err := compose.NewToolNode(ctx, toolsConfig) if err != nil { return nil, err } - modelPreHandle := func(ctx context.Context, input []Message, st *State) ([]Message, error) { - if st.getRemainingIterations() <= 0 { - return nil, ErrExceedMaxIterations + _ = g.AddChatModelNode(chatModel_, wrappedModel, compose.WithStatePreHandler( + func(ctx context.Context, input []Message, st *State) ([]Message, error) { + if st.getRemainingIterations() <= 0 { + return nil, ErrExceedMaxIterations + } + st.decrementRemainingIterations() + return input, nil + }), compose.WithNodeName(chatModel_)) + + // CancelAfterChatModel safe-point: on the tool-calls path, after the branch + // has confirmed that the model response contains tool calls (i.e. not a final + // answer). Skipped entirely when the model produces a final answer. 
+ _ = g.AddLambdaNode(cancelCheckNode_, compose.InvokableLambda(func(ctx context.Context, msg Message) (Message, error) { + if cancelCtx != nil && cancelCtx.shouldCancel() { + if cancelCtx.getMode()&CancelAfterChatModel != 0 { + return nil, compose.StatefulInterrupt(ctx, "CancelAfterChatModel", msg) + } } - st.decrementRemainingIterations() - return input, nil - } - _ = g.AddChatModelNode(chatModel_, wrappedModel, - compose.WithStatePreHandler(modelPreHandle), compose.WithNodeName(chatModel_)) + wasInterrupted, hasState, state := compose.GetInterruptState[Message](ctx) + if wasInterrupted && hasState { + msg = state + } + return msg, nil + }), compose.WithNodeName(cancelCheckNode_)) toolPreHandle := func(ctx context.Context, _ Message, st *State) (Message, error) { input := st.Messages[len(st.Messages)-1] - returnDirectly := config.toolsReturnDirectly if execCtx := getChatModelAgentExecCtx(ctx); execCtx != nil && len(execCtx.runtimeReturnDirectly) > 0 { returnDirectly = execCtx.runtimeReturnDirectly } - if len(returnDirectly) > 0 { for i := range input.ToolCalls { toolName := input.ToolCalls[i].Function.Name @@ -349,10 +371,8 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { } } } - return input, nil } - toolPostHandle := func(ctx context.Context, out *schema.StreamReader[[]*schema.Message], st *State) (*schema.StreamReader[[]*schema.Message], error) { if event := st.getReturnDirectlyEvent(); event != nil { getChatModelAgentExecCtx(ctx).send(event) @@ -360,12 +380,60 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { } return out, nil } - _ = g.AddToolsNode(toolNode_, toolsNode, compose.WithStatePreHandler(toolPreHandle), compose.WithStreamStatePostHandler(toolPostHandle), compose.WithNodeName(toolNode_)) + // AfterToolCalls node: calls AfterToolCallsRewriteState handlers after all tool calls complete. + // The graph auto-materializes the ToolsNode stream into []Message before this node. 
+ afterToolCalls := func(ctx context.Context, toolResults []Message) ([]Message, error) { + var stateMessages []Message + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + stateMessages = st.Messages + return nil + }) + + state := &ChatModelAgentState{Messages: append(stateMessages, toolResults...)} + + if config.modelWrapperConf != nil { + assistantMsg := stateMessages[len(stateMessages)-1] + tc := &ToolCallsContext{} + for _, toolCall := range assistantMsg.ToolCalls { + tc.ToolCalls = append(tc.ToolCalls, ToolContext{Name: toolCall.Function.Name, CallID: toolCall.ID}) + } + + for _, handler := range config.modelWrapperConf.handlers { + var err error + ctx, state, err = handler.AfterToolCallsRewriteState(ctx, state, tc) + if err != nil { + return nil, err + } + } + } + + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + st.Messages = state.Messages + return nil + }) + + return toolResults, nil + } + _ = g.AddLambdaNode(afterToolCallsNode_, compose.InvokableLambda(afterToolCalls), + compose.WithNodeName(afterToolCallsNode_)) + + // AfterToolCallsCancelCheck: CancelAfterToolCalls safe-point, separated from toolPostHandle. 
+ afterToolCallsCancelCheck := func(ctx context.Context, toolResults []Message) ([]Message, error) { + if cancelCtx != nil && cancelCtx.shouldCancel() { + if cancelCtx.getMode()&CancelAfterToolCalls != 0 { + return nil, compose.Interrupt(ctx, "CancelAfterToolCalls") + } + } + return toolResults, nil + } + _ = g.AddLambdaNode(afterToolCallsCancelCheckNode_, compose.InvokableLambda(afterToolCallsCancelCheck), + compose.WithNodeName(afterToolCallsCancelCheckNode_)) + _ = g.AddEdge(compose.START, initNode_) _ = g.AddEdge(initNode_, chatModel_) @@ -382,41 +450,39 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { } if len(chunk.ToolCalls) > 0 { - return toolNode_, nil + return cancelCheckNode_, nil } } } - branch := compose.NewStreamGraphBranch(toolCallCheck, map[string]bool{compose.END: true, toolNode_: true}) + branch := compose.NewStreamGraphBranch(toolCallCheck, map[string]bool{compose.END: true, cancelCheckNode_: true}) _ = g.AddBranch(chatModel_, branch) + _ = g.AddEdge(cancelCheckNode_, toolNode_) + _ = g.AddEdge(toolNode_, afterToolCallsNode_) + _ = g.AddEdge(afterToolCallsNode_, afterToolCallsCancelCheckNode_) + if len(config.toolsReturnDirectly) > 0 { const ( toolNodeToEndConverter = "ToolNodeToEndConverter" ) - cvt := func(ctx context.Context, sToolCallMessages sToolNodeOutput) (sGraphOutput, error) { + cvt := func(ctx context.Context, toolResults []Message) (Message, error) { id, _ := getReturnDirectlyToolCallID(ctx) - return schema.StreamReaderWithConvert(sToolCallMessages, - func(in []Message) (Message, error) { - - for _, chunk := range in { - if chunk != nil && chunk.ToolCallID == id { - return chunk, nil - } - } + for _, msg := range toolResults { + if msg != nil && msg.ToolCallID == id { + return msg, nil + } + } - return nil, schema.ErrNoValue - }), nil + return nil, errors.New("return directly tool call result not found") } - _ = g.AddLambdaNode(toolNodeToEndConverter, compose.TransformableLambda(cvt), + _ = 
g.AddLambdaNode(toolNodeToEndConverter, compose.InvokableLambda(cvt), compose.WithNodeName(toolNodeToEndConverter)) _ = g.AddEdge(toolNodeToEndConverter, compose.END) - checkReturnDirect := func(ctx context.Context, - sToolCallMessages sToolNodeOutput) (string, error) { - + checkReturnDirect := func(ctx context.Context, toolResults []Message) (string, error) { _, ok := getReturnDirectlyToolCallID(ctx) if ok { @@ -426,11 +492,11 @@ func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { return chatModel_, nil } - branch = compose.NewStreamGraphBranch(checkReturnDirect, + returnDirectBranch := compose.NewGraphBranch(checkReturnDirect, map[string]bool{toolNodeToEndConverter: true, chatModel_: true}) - _ = g.AddBranch(toolNode_, branch) + _ = g.AddBranch(afterToolCallsCancelCheckNode_, returnDirectBranch) } else { - _ = g.AddEdge(toolNode_, chatModel_) + _ = g.AddEdge(afterToolCallsCancelCheckNode_, chatModel_) } return g, nil diff --git a/adk/react_test.go b/adk/react_test.go index 5364f0912..b0a6c3985 100644 --- a/adk/react_test.go +++ b/adk/react_test.go @@ -23,6 +23,7 @@ import ( "errors" "fmt" "io" + "math" "math/rand" "testing" @@ -148,12 +149,12 @@ func TestReact(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, graph) - compiled, err := graph.Compile(ctx) + compiled, err := graph.Compile(ctx, compose.WithMaxRunSteps(math.MaxInt)) assert.NoError(t, err) assert.NotNil(t, compiled) // Test with a user message - result, err := compiled.Invoke(ctx, &reactInput{messages: []Message{ + result, err := compiled.Invoke(ctx, &reactInput{Messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", @@ -215,12 +216,12 @@ func TestReact(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, graph) - compiled, err := graph.Compile(ctx) + compiled, err := graph.Compile(ctx, compose.WithMaxRunSteps(math.MaxInt)) assert.NoError(t, err) assert.NotNil(t, compiled) // Test with a user message when tool returns directly - result, 
err := compiled.Invoke(ctx, &reactInput{messages: []Message{ + result, err := compiled.Invoke(ctx, &reactInput{Messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", @@ -307,12 +308,12 @@ func TestReact(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, graph) - compiled, err := graph.Compile(ctx) + compiled, err := graph.Compile(ctx, compose.WithMaxRunSteps(math.MaxInt)) assert.NoError(t, err) assert.NotNil(t, compiled) // Test streaming with a user message - outStream, err := compiled.Stream(ctx, &reactInput{messages: []Message{ + outStream, err := compiled.Stream(ctx, &reactInput{Messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", @@ -417,7 +418,7 @@ func TestReact(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, graph) - compiled, err := graph.Compile(ctx) + compiled, err := graph.Compile(ctx, compose.WithMaxRunSteps(math.MaxInt)) assert.NoError(t, err) assert.NotNil(t, compiled) @@ -425,7 +426,7 @@ func TestReact(t *testing.T) { times = 0 // Test streaming with a user message when tool returns directly - outStream, err := compiled.Stream(ctx, &reactInput{messages: []Message{ + outStream, err := compiled.Stream(ctx, &reactInput{Messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", @@ -506,12 +507,12 @@ func TestReact(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, graph) - compiled, err := graph.Compile(ctx) + compiled, err := graph.Compile(ctx, compose.WithMaxRunSteps(math.MaxInt)) assert.NoError(t, err) assert.NotNil(t, compiled) // Test with a user message - result, err := compiled.Invoke(ctx, &reactInput{messages: []Message{ + result, err := compiled.Invoke(ctx, &reactInput{Messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", @@ -536,12 +537,12 @@ func TestReact(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, graph) - compiled, err = graph.Compile(ctx) + compiled, err = graph.Compile(ctx, 
compose.WithMaxRunSteps(math.MaxInt)) assert.NoError(t, err) assert.NotNil(t, compiled) // Test with a user message - result, err = compiled.Invoke(ctx, &reactInput{messages: []Message{ + result, err = compiled.Invoke(ctx, &reactInput{Messages: []Message{ { Role: schema.User, Content: "Use the test tool to say hello", diff --git a/adk/retry_chatmodel.go b/adk/retry_chatmodel.go index 8ae4e2aac..bac955033 100644 --- a/adk/retry_chatmodel.go +++ b/adk/retry_chatmodel.go @@ -196,6 +196,11 @@ func (r *retryModelWrapper) Generate(ctx context.Context, input []*schema.Messag return out, nil } + // Never retry interrupt errors (e.g. cancel safe-point interrupts). + if _, ok := compose.ExtractInterruptInfo(err); ok { + return nil, err + } + if !isRetryAble(ctx, err) { return nil, err } @@ -238,6 +243,10 @@ func (r *retryModelWrapper) Stream(ctx context.Context, input []*schema.Message, stream, err := r.inner.Stream(ctx, input, opts...) if err != nil { + // Never retry interrupt errors (e.g. cancel safe-point interrupts). + if _, ok := compose.ExtractInterruptInfo(err); ok { + return nil, err + } if !isRetryAble(ctx, err) { return nil, err } diff --git a/adk/runctx.go b/adk/runctx.go index 6e2a6cfbe..1a32f1760 100644 --- a/adk/runctx.go +++ b/adk/runctx.go @@ -24,8 +24,6 @@ import ( "sort" "sync" "time" - - "github.com/cloudwego/eino/schema" ) // runSession CheckpointSchema: persisted via serialization.RunCtx (gob). @@ -65,8 +63,14 @@ type agentEventWrapper struct { type otherAgentEventWrapperForEncode agentEventWrapper func (a *agentEventWrapper) GobEncode() ([]byte, error) { - if a.concatenatedMessage != nil && a.Output != nil && a.Output.MessageOutput != nil && a.Output.MessageOutput.IsStreaming { - a.Output.MessageOutput.MessageStream = schema.StreamReaderFromArray([]Message{a.concatenatedMessage}) + if a.Output != nil && a.Output.MessageOutput != nil && a.Output.MessageOutput.IsStreaming { + // Materialize the stream before encoding. 
An unconsumed stream that + // ends with a non-EOF error (WillRetryError, ErrStreamCanceled) would + // cause MessageVariant.GobEncode to fail. consumeStream replaces the + // stream with an error-free, materialized version. + if a.concatenatedMessage == nil && a.StreamErr == nil { + a.consumeStream() + } } buf := &bytes.Buffer{} diff --git a/adk/runctx_test.go b/adk/runctx_test.go index 7f164b3e2..bef1f44eb 100644 --- a/adk/runctx_test.go +++ b/adk/runctx_test.go @@ -17,7 +17,10 @@ package adk import ( + "bytes" "context" + "encoding/gob" + "errors" "testing" "time" @@ -423,3 +426,209 @@ func TestForkJoinRunCtx(t *testing.T) { mainRunCtx.Session.addEvent(eventF) assert.Equal(t, []string{"A", "B", "C1", "D", "E", "F"}, getEventNames(mainRunCtx.Session.getEvents()), "After F") } + +// makeStreamingEventWrapper creates an agentEventWrapper with a streaming MessageOutput +// whose stream yields the given message then terminates with streamErr (or io.EOF if nil). +func makeStreamingEventWrapper(msg Message, streamErr error) *agentEventWrapper { + r, w := schema.Pipe[Message](2) + w.Send(msg, nil) + if streamErr != nil { + w.Send(nil, streamErr) + } + w.Close() + + return &agentEventWrapper{ + AgentEvent: &AgentEvent{ + AgentName: "test-agent", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: true, + MessageStream: r, + Role: schema.Assistant, + }, + }, + }, + } +} + +func TestGobEncodeStreamErrors(t *testing.T) { + t.Run("WillRetryError_unconsumed_stream_fails_GobEncode", func(t *testing.T) { + // An agentEventWrapper whose stream yields a message then WillRetryError. + // Without pre-consuming (no getMessageFromWrappedEvent call), GobEncode + // reaches MessageVariant.GobEncode which treats non-EOF errors as fatal. 
+ wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("partial", nil), + &WillRetryError{ErrStr: "model error", RetryAttempt: 1}, + ) + + _, err := wrapper.GobEncode() + assert.NoError(t, err, "GobEncode should handle WillRetryError streams gracefully") + }) + + t.Run("ErrStreamCanceled_unconsumed_stream_fails_GobEncode", func(t *testing.T) { + // Same scenario but with ErrStreamCanceled (*errors.errorString). + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("partial", nil), + ErrStreamCanceled, + ) + + _, err := wrapper.GobEncode() + assert.NoError(t, err, "GobEncode should handle ErrStreamCanceled streams gracefully") + }) + + t.Run("successful_stream_GobEncode_succeeds", func(t *testing.T) { + // Control: a clean stream (no error) should encode fine. + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("hello", nil), + nil, // no stream error + ) + + data, err := wrapper.GobEncode() + assert.NoError(t, err) + assert.NotEmpty(t, data) + + // Verify round-trip decode works. + decoded := &agentEventWrapper{AgentEvent: &AgentEvent{}} + err = decoded.GobDecode(data) + assert.NoError(t, err) + assert.Equal(t, "test-agent", decoded.AgentName) + }) + + t.Run("preconsumed_WillRetryError_GobEncode_succeeds", func(t *testing.T) { + // When getMessageFromWrappedEvent is called first, WillRetryError is + // cached in StreamErr and the stream is replaced with an error-free array. + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("partial", nil), + &WillRetryError{ErrStr: "model error", RetryAttempt: 1}, + ) + + _, consumeErr := getMessageFromWrappedEvent(wrapper) + assert.Error(t, consumeErr) + + data, err := wrapper.GobEncode() + assert.NoError(t, err, "GobEncode should succeed after pre-consuming WillRetryError stream") + assert.NotEmpty(t, data) + }) + + t.Run("preconsumed_ErrStreamCanceled_GobEncode_succeeds", func(t *testing.T) { + // ErrStreamCanceled is a *StreamCanceledError which IS gob-registered. 
+ // After getMessageFromWrappedEvent, StreamErr = ErrStreamCanceled. + // Since it's registered, gob encoding succeeds. + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("partial", nil), + ErrStreamCanceled, + ) + + _, consumeErr := getMessageFromWrappedEvent(wrapper) + assert.Error(t, consumeErr) + + data, err := wrapper.GobEncode() + assert.NoError(t, err, "GobEncode should succeed; ErrStreamCanceled is gob-registered") + assert.NotEmpty(t, data) + }) + + t.Run("GobEncode_roundtrip_preserves_content", func(t *testing.T) { + // Verify that after GobEncode with a WillRetryError stream, + // the decoded wrapper has the partial message content and StreamErr intact. + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("partial response", nil), + &WillRetryError{ErrStr: "err", RetryAttempt: 1}, + ) + + data, err := wrapper.GobEncode() + assert.NoError(t, err) + + decoded := &agentEventWrapper{AgentEvent: &AgentEvent{}} + err = decoded.GobDecode(data) + assert.NoError(t, err) + assert.Equal(t, "test-agent", decoded.AgentName) + assert.True(t, decoded.Output.MessageOutput.IsStreaming) + // The stream should be consumable and yield the partial message. + msg, recvErr := decoded.Output.MessageOutput.MessageStream.Recv() + assert.NoError(t, recvErr) + assert.Contains(t, msg.Content, "partial response") + // StreamErr should be preserved for end-user visibility. + var willRetryErr *WillRetryError + assert.True(t, errors.As(decoded.StreamErr, &willRetryErr)) + assert.Equal(t, "err", willRetryErr.ErrStr) + }) + + t.Run("GobEncode_roundtrip_preserves_ErrStreamCanceled", func(t *testing.T) { + // ErrStreamCanceled (*StreamCanceledError) is gob-registered, so + // StreamErr should survive encoding/decoding. 
+ wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("partial", nil), + ErrStreamCanceled, + ) + + data, err := wrapper.GobEncode() + assert.NoError(t, err) + + decoded := &agentEventWrapper{AgentEvent: &AgentEvent{}} + err = decoded.GobDecode(data) + assert.NoError(t, err) + var streamCanceledErr *StreamCanceledError + assert.ErrorAs(t, decoded.StreamErr, &streamCanceledErr) + }) + + t.Run("GobEncode_idempotent", func(t *testing.T) { + // Calling GobEncode twice should succeed both times (stream replaced on first call). + wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("hello", nil), + &WillRetryError{ErrStr: "err", RetryAttempt: 1}, + ) + + data1, err := wrapper.GobEncode() + assert.NoError(t, err) + + data2, err := wrapper.GobEncode() + assert.NoError(t, err) + + // Both should decode to equivalent content. + d1, d2 := &agentEventWrapper{AgentEvent: &AgentEvent{}}, &agentEventWrapper{AgentEvent: &AgentEvent{}} + assert.NoError(t, d1.GobDecode(data1)) + assert.NoError(t, d2.GobDecode(data2)) + assert.Equal(t, d1.AgentName, d2.AgentName) + }) + + t.Run("GobEncode_non_streaming_unaffected", func(t *testing.T) { + // Non-streaming events should encode/decode as before. + wrapper := &agentEventWrapper{ + AgentEvent: &AgentEvent{ + AgentName: "non-stream-agent", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: false, + Message: schema.AssistantMessage("direct", nil), + Role: schema.Assistant, + }, + }, + }, + } + + data, err := wrapper.GobEncode() + assert.NoError(t, err) + + decoded := &agentEventWrapper{AgentEvent: &AgentEvent{}} + assert.NoError(t, decoded.GobDecode(data)) + assert.Equal(t, "non-stream-agent", decoded.AgentName) + assert.False(t, decoded.Output.MessageOutput.IsStreaming) + }) + + t.Run("GobEncode_within_runSession", func(t *testing.T) { + // Simulate the real scenario: a runSession with a streaming event containing + // WillRetryError is gob-encoded (as happens during checkpoint save). 
+ wrapper := makeStreamingEventWrapper( + schema.AssistantMessage("checkpoint content", nil), + &WillRetryError{ErrStr: "retry", RetryAttempt: 1}, + ) + + session := newRunSession() + session.Events = []*agentEventWrapper{wrapper} + + // Encode the entire session (the checkpoint path). + var buf bytes.Buffer + err := gob.NewEncoder(&buf).Encode(session) + assert.NoError(t, err, "encoding runSession with WillRetryError stream should succeed") + }) +} diff --git a/adk/runner.go b/adk/runner.go index 07a931ac2..4881122a6 100644 --- a/adk/runner.go +++ b/adk/runner.go @@ -18,6 +18,7 @@ package adk import ( "context" + "errors" "fmt" "runtime/debug" "sync" @@ -41,6 +42,8 @@ type Runner struct { type CheckPointStore = core.CheckPointStore +type CheckPointDeleter = core.CheckPointDeleter + type RunnerConfig struct { Agent Agent EnableStreaming bool @@ -88,13 +91,14 @@ func (r *Runner) Run(ctx context.Context, messages []Message, AddSessionValues(ctx, o.sessionValues) iter := fa.Run(ctx, input, opts...) - if r.store == nil { + + if r.store == nil && o.cancelCtx == nil { return iter } niter, gen := NewAsyncIteratorPair[*AgentEvent]() - go r.handleIter(ctx, iter, gen, o.checkPointID) + go r.handleIter(ctx, iter, gen, o.checkPointID, o.cancelCtx) return niter } @@ -114,7 +118,7 @@ func (r *Runner) Query(ctx context.Context, // pattern where an agent only needs to know `wasInterrupted` is true to continue. func (r *Runner) Resume(ctx context.Context, checkPointID string, opts ...AgentRunOption) ( *AsyncIterator[*AgentEvent], error) { - return r.resume(ctx, checkPointID, nil, opts...) + return r.resumeInternal(ctx, checkPointID, nil, opts...) } // ResumeWithParams continues an interrupted execution from a checkpoint with specific parameters. 
@@ -136,11 +140,10 @@ func (r *Runner) Resume(ctx context.Context, checkPointID string, opts ...AgentR // naturally re-interrupt if one of their interrupted children re-interrupts, as they receive the // new `CompositeInterrupt` signal from them. func (r *Runner) ResumeWithParams(ctx context.Context, checkPointID string, params *ResumeParams, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], error) { - return r.resume(ctx, checkPointID, params.Targets, opts...) + return r.resumeInternal(ctx, checkPointID, params.Targets, opts...) } -// resume is the internal implementation for both Resume and ResumeWithParams. -func (r *Runner) resume(ctx context.Context, checkPointID string, resumeData map[string]any, +func (r *Runner) resumeInternal(ctx context.Context, checkPointID string, resumeData map[string]any, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], error) { if r.store == nil { return nil, fmt.Errorf("failed to resume: store is nil") @@ -175,19 +178,21 @@ func (r *Runner) resume(ctx context.Context, checkPointID string, resumeData map } fa := toFlowAgent(ctx, r.a) + aIter := fa.Resume(ctx, resumeInfo, opts...) 
- if r.store == nil { + + if r.store == nil && o.cancelCtx == nil { return aIter, nil } niter, gen := NewAsyncIteratorPair[*AgentEvent]() - go r.handleIter(ctx, aIter, gen, &checkPointID) + go r.handleIter(ctx, aIter, gen, &checkPointID, o.cancelCtx) return niter, nil } func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEvent], - gen *AsyncGenerator[*AgentEvent], checkPointID *string) { + gen *AsyncGenerator[*AgentEvent], checkPointID *string, cancelCtx *cancelContext) { defer func() { panicErr := recover() if panicErr != nil { @@ -207,6 +212,25 @@ func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEven break } + if event.Err != nil { + var cancelErr *CancelError + if errors.As(event.Err, &cancelErr) { + if cancelCtx != nil && cancelCtx.isRoot() && cancelCtx.shouldCancel() { + cancelCtx.markCancelHandled() + } + if cancelErr.interruptSignal != nil && checkPointID != nil { + cancelErr.CheckPointID = *checkPointID + cancelErr.InterruptContexts = core.ToInterruptContexts(cancelErr.interruptSignal, allowedAddressSegmentTypes) + err := r.saveCheckPoint(ctx, *checkPointID, &InterruptInfo{}, cancelErr.interruptSignal) + if err != nil { + gen.Send(&AgentEvent{Err: fmt.Errorf("failed to save checkpoint on cancel: %w", err)}) + } + } + gen.Send(event) + break + } + } + if event.Action != nil && event.Action.internalInterrupted != nil { if interruptSignal != nil { // even if multiple interrupt happens, they should be merged into one @@ -231,8 +255,7 @@ func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEven legacyData = event.Action.Interrupted.Data if checkPointID != nil { - // save checkpoint first before sending interrupt event, - // so when end-user receives interrupt event, they can resume from this checkpoint + // save checkpoint first before sending interrupt event, so when end-user receives interrupt event, they can resume from this checkpoint err := r.saveCheckPoint(ctx, *checkPointID, 
&InterruptInfo{ Data: legacyData, }, interruptSignal) diff --git a/adk/turn_buffer.go b/adk/turn_buffer.go new file mode 100644 index 000000000..b154587c9 --- /dev/null +++ b/adk/turn_buffer.go @@ -0,0 +1,128 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import "sync" + +type turnBuffer[T any] struct { + buf []T + mu sync.Mutex + notEmpty *sync.Cond + closed bool + woken bool +} + +func newTurnBuffer[T any]() *turnBuffer[T] { + tb := &turnBuffer[T]{} + tb.notEmpty = sync.NewCond(&tb.mu) + return tb +} + +func (tb *turnBuffer[T]) Send(value T) { + tb.mu.Lock() + defer tb.mu.Unlock() + + if tb.closed { + panic("turnBuffer: send on closed buffer") + } + + tb.buf = append(tb.buf, value) + tb.notEmpty.Signal() +} + +func (tb *turnBuffer[T]) TrySend(value T) bool { + tb.mu.Lock() + defer tb.mu.Unlock() + + if tb.closed { + return false + } + + tb.buf = append(tb.buf, value) + tb.notEmpty.Signal() + return true +} + +func (tb *turnBuffer[T]) Receive() (T, bool) { + tb.mu.Lock() + defer tb.mu.Unlock() + + for len(tb.buf) == 0 && !tb.closed && !tb.woken { + tb.notEmpty.Wait() + } + + tb.woken = false + + if len(tb.buf) == 0 { + var zero T + return zero, false + } + + val := tb.buf[0] + tb.buf = tb.buf[1:] + return val, true +} + +func (tb *turnBuffer[T]) Close() { + tb.mu.Lock() + defer tb.mu.Unlock() + + if !tb.closed { + tb.closed = true + tb.notEmpty.Broadcast() + } +} + +func (tb 
*turnBuffer[T]) TakeAll() []T { + tb.mu.Lock() + defer tb.mu.Unlock() + + if len(tb.buf) == 0 { + return nil + } + + values := tb.buf + tb.buf = nil + return values +} + +func (tb *turnBuffer[T]) PushFront(values []T) { + if len(values) == 0 { + return + } + + tb.mu.Lock() + defer tb.mu.Unlock() + + tb.buf = append(append([]T{}, values...), tb.buf...) + tb.notEmpty.Signal() +} + +func (tb *turnBuffer[T]) Wakeup() { + tb.mu.Lock() + defer tb.mu.Unlock() + + tb.woken = true + tb.notEmpty.Broadcast() +} + +func (tb *turnBuffer[T]) ClearWakeup() { + tb.mu.Lock() + defer tb.mu.Unlock() + + tb.woken = false +} diff --git a/adk/turn_loop.go b/adk/turn_loop.go new file mode 100644 index 000000000..124f65459 --- /dev/null +++ b/adk/turn_loop.go @@ -0,0 +1,1764 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "bytes" + "context" + "encoding/gob" + "errors" + "fmt" + "runtime/debug" + "sync" + "sync/atomic" + "time" + + "github.com/cloudwego/eino/internal/safe" +) + +// stopSignal coordinates the Stop() call with per-turn watcher goroutines. +// +// Lifecycle overview: +// +// 1. SIGNAL — Stop() calls signal() which bumps the generation counter, +// stores the AgentCancelOptions, and deposits a one-shot notification +// in the buffered notify channel. +// +// 2. DONE — Stop() calls closeDone() which permanently closes the done +// channel. 
This acts as a durable "stopped" flag: any current or future +// select on done fires immediately, ensuring that every watcher — +// including watchers in turns that start after Stop() but before the +// run loop observes isStopped() — can reliably detect the stop. +// +// 3. RECEIVE — The per-turn watchStopSignal goroutine selects on the done +// channel (the durable flag) and the notify channel (to detect mode +// escalation from a second Stop call). On either signal, it calls +// agentCancelFunc to cancel the running agent. +// +// The generation counter (gen) de-duplicates wakes so that the watcher only +// acts when a new Stop() call has been made, supporting mode escalation +// (e.g. CancelAfterToolCalls followed by CancelImmediate). +type stopSignal struct { + done chan struct{} + + mu sync.Mutex + gen uint64 + // agentCancelOpts controls how the stop interacts with the running agent: + // nil → no cancel; the turn runs to completion (bare Stop) + // empty → CancelImmediate (WithImmediate) + // non-empty → cancel with specific modes (WithGraceful, WithGracefulTimeout) + agentCancelOpts []AgentCancelOption + skipCheckpoint bool + stopCause string + idleFor time.Duration + notify chan struct{} +} + +func newStopSignal() *stopSignal { + return &stopSignal{ + done: make(chan struct{}), + notify: make(chan struct{}, 1), + } +} + +// signal records a stop request and wakes the current turn's watcher (if any). +// The non-blocking send means the notification is silently coalesced when the +// buffer is already full — this is safe because gen de-duplicates in the watcher. +func (s *stopSignal) signal(cfg *stopConfig) { + s.mu.Lock() + s.gen++ + // Only overwrite when the caller explicitly provides cancel options. + // A bare Stop() leaves cfg.agentCancelOpts nil (no cancel intent), which + // must not de-escalate a previously set cancel policy. 
+ if cfg.agentCancelOpts != nil { + s.agentCancelOpts = cfg.agentCancelOpts + } + if cfg.skipCheckpoint { + s.skipCheckpoint = true + } + if cfg.stopCause != "" && s.stopCause == "" { + s.stopCause = cfg.stopCause + } + if cfg.idleFor > 0 && s.idleFor == 0 { + s.idleFor = cfg.idleFor + } + s.mu.Unlock() + select { + case s.notify <- struct{}{}: + default: + } +} + +// isStopped returns true if closeDone() has been called. +func (s *stopSignal) isStopped() bool { + select { + case <-s.done: + return true + default: + return false + } +} + +// closeDone permanently marks the stop as committed. All current and future +// selects on s.done will fire immediately after this call. +func (s *stopSignal) closeDone() { + close(s.done) +} + +// check returns the current generation and a snapshot of the cancel options. +func (s *stopSignal) check() (uint64, []AgentCancelOption) { + s.mu.Lock() + defer s.mu.Unlock() + return s.gen, append([]AgentCancelOption{}, s.agentCancelOpts...) +} + +func (s *stopSignal) isSkipCheckpoint() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.skipCheckpoint +} + +func (s *stopSignal) getStopCause() string { + s.mu.Lock() + defer s.mu.Unlock() + return s.stopCause +} + +func (s *stopSignal) getIdleFor() time.Duration { + s.mu.Lock() + defer s.mu.Unlock() + return s.idleFor +} + +// preemptSignal coordinates preemption between Push callers and the run loop. +// +// Lifecycle overview: +// +// 1. HOLD — A Push caller (or the run loop itself) calls holdRunLoop() to +// increment holdCount. While holdCount > 0 the run loop blocks at +// waitForPreemptOrUnhold(), preventing it from starting a new turn. +// +// 2. REQUEST — The Push caller calls requestPreempt() which sets +// preemptRequested=true, bumps preemptGen, stores cancelOpts/acks, and +// wakes both the run-loop (via cond) and the in-turn watcher goroutine +// (via notify channel). +// +// 3. 
RECEIVE — The per-turn watchPreemptSignal goroutine calls +// receivePreempt(), obtains the cancel opts and ack channels, invokes +// agentCancelFunc to cancel the running agent, and closes the ack +// channels to notify Push callers. +// +// 4. UNHOLD — After the turn finishes (or if the Push caller decides not +// to preempt), unholdRunLoop() / endTurnAndUnhold() decrements +// holdCount. When holdCount reaches 0, all signal state is reset. +// +// The run loop brackets every turn with holdRunLoop() / endTurnAndUnhold() +// so that a concurrent Push caller's hold keeps holdCount > 0 even after +// the turn ends, preventing the loop from racing into a new turn before +// the Push caller's preempt request is delivered. +// +// Fields currentTC and currentRunCtx are stored here (rather than on +// TurnLoop) so that holdAndGetTurn() can atomically snapshot the turn +// state and increment holdCount under the same mu lock, eliminating the +// TOCTOU race between reading the turn and holding the loop. +type preemptSignal struct { + mu sync.Mutex + cond *sync.Cond + holdCount int + preemptRequested bool + preemptGen uint64 + agentCancelOpts []AgentCancelOption + pendingAckList []chan struct{} + notify chan struct{} + + currentTC any + currentRunCtx context.Context +} + +func newPreemptSignal() *preemptSignal { + s := &preemptSignal{notify: make(chan struct{}, 1)} + s.cond = sync.NewCond(&s.mu) + return s +} + +func (s *preemptSignal) holdRunLoop() { + s.mu.Lock() + s.holdCount++ + s.mu.Unlock() +} + +func (s *preemptSignal) setTurn(ctx context.Context, tc any) { + s.mu.Lock() + s.currentRunCtx = ctx + s.currentTC = tc + s.mu.Unlock() +} + +func (s *preemptSignal) holdAndGetTurn() (context.Context, any) { + s.mu.Lock() + defer s.mu.Unlock() + s.holdCount++ + return s.currentRunCtx, s.currentTC +} + +// requestPreempt records a preempt request and wakes both waiters. +// If holdCount is 0, no one is listening — close the ack immediately as a no-op. 
+func (s *preemptSignal) requestPreempt(ack chan struct{}, opts ...AgentCancelOption) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.holdCount <= 0 { + if ack != nil { + close(ack) + } + return + } + + s.preemptRequested = true + s.preemptGen++ + s.agentCancelOpts = opts + if ack != nil { + s.pendingAckList = append(s.pendingAckList, ack) + } + select { + case s.notify <- struct{}{}: + default: + } + + s.cond.Broadcast() +} + +// receivePreempt is called by the per-turn watcher goroutine to consume a +// pending preempt. It drains pendingAckList (so the watcher can close them +// after invoking agentCancelFunc) but intentionally preserves preemptRequested +// and preemptGen — these are needed by waitForPreemptOrUnhold on the run loop. +func (s *preemptSignal) receivePreempt() (bool, uint64, []AgentCancelOption, []chan struct{}) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.preemptRequested { + ackList := s.pendingAckList + s.pendingAckList = nil + return true, s.preemptGen, s.agentCancelOpts, ackList + } + return false, 0, nil, nil +} + +// waitForPreemptOrUnhold blocks the run loop between turns. It returns early +// (preempted=false) when holdCount is 0 (no Push caller is holding). Otherwise +// it blocks until either a preempt is requested or all holders release. +func (s *preemptSignal) waitForPreemptOrUnhold() (preempted bool, opts []AgentCancelOption, ackList []chan struct{}) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.holdCount <= 0 { + return false, nil, nil + } + + for s.holdCount > 0 && !s.preemptRequested { + s.cond.Wait() + } + + if s.preemptRequested { + ackList = s.pendingAckList + s.pendingAckList = nil + return true, s.agentCancelOpts, ackList + } + return false, nil, nil +} + +// resetLocked clears all signal state and closes pending ack channels so the +// next cycle starts clean and blocked Push callers are unblocked. Must be +// called with s.mu held. 
Does NOT touch holdCount, currentTC, or currentRunCtx +// — callers are responsible for those. +func (s *preemptSignal) resetLocked() { + s.preemptRequested = false + s.preemptGen = 0 + s.agentCancelOpts = nil + for _, ack := range s.pendingAckList { + close(ack) + } + s.pendingAckList = nil + select { + case <-s.notify: + default: + } +} + +// unholdRunLoop drops one hold. When holdCount reaches 0, all signal state is +// reset so the next cycle starts clean. +func (s *preemptSignal) unholdRunLoop() { + s.mu.Lock() + defer s.mu.Unlock() + + s.holdCount-- + if s.holdCount < 0 { + s.holdCount = 0 + } + if s.holdCount == 0 { + s.resetLocked() + } + s.cond.Broadcast() +} + +// endTurnAndUnhold is called by the run loop after runAgentAndHandleEvents +// returns. It clears the current turn context and drops the run loop's hold. +func (s *preemptSignal) endTurnAndUnhold() { + s.mu.Lock() + defer s.mu.Unlock() + + s.currentTC = nil + s.currentRunCtx = nil + s.holdCount-- + if s.holdCount < 0 { + s.holdCount = 0 + } + if s.holdCount == 0 { + s.resetLocked() + } + s.cond.Broadcast() +} + +// drainAll forcefully resets all preemptSignal state and closes any pending +// ack channels. Called during TurnLoop cleanup to prevent ack channels from +// leaking when the run loop exits (e.g. due to Stop) while a Push caller +// still holds a reference. +func (s *preemptSignal) drainAll() { + s.mu.Lock() + defer s.mu.Unlock() + + s.holdCount = 0 + s.currentTC = nil + s.currentRunCtx = nil + s.resetLocked() + s.cond.Broadcast() +} + +// TurnLoopConfig is the configuration for creating a TurnLoop. +type TurnLoopConfig[T any] struct { + // GenInput receives the TurnLoop instance and all buffered items, and decides what to process. + // It returns which items to consume now vs keep for later turns. + // The loop parameter allows calling Push() or Stop() directly from within the callback. + // Required. 
+ GenInput func(ctx context.Context, loop *TurnLoop[T], items []T) (*GenInputResult[T], error) + + // GenResume is called exactly once when the TurnLoop detects a mid-turn + // checkpoint on startup (i.e. CheckpointID is configured and the stored + // checkpoint has runner state from an interrupted agent execution). + // It receives: + // - canceledItems: the items being processed when the prior run was canceled + // - unhandledItems: items buffered but not processed when the prior run exited + // - newItems: items that were Push()-ed before Run() was called + // + // It returns a GenResumeResult describing how to resume the interrupted agent + // turn (optional ResumeParams) and how to manipulate the buffer + // (Consumed/Remaining) before continuing. + GenResume func(ctx context.Context, loop *TurnLoop[T], canceledItems, unhandledItems, newItems []T) (*GenResumeResult[T], error) + + // PrepareAgent returns an Agent configured to handle the consumed items. + // This callback should set up the agent with appropriate system prompt, + // tools, and middlewares based on what items are being processed. + // Called once per turn with the items that GenInput decided to consume. + // The loop parameter allows calling Push() or Stop() directly from within the callback. + // Required. + PrepareAgent func(ctx context.Context, loop *TurnLoop[T], consumed []T) (Agent, error) + + // OnAgentEvents is called to handle events emitted by the agent. + // The TurnContext provides per-turn info and control: + // - tc.Consumed: items that triggered this agent execution + // - tc.Loop: allows calling Push() or Stop() directly from within the callback + // - tc.Preempted / tc.Stopped: signals while processing events + // + // Error handling: the returned error is only used when the callback itself + // wants to abort the TurnLoop. 
The TurnLoop already captures CancelError + // from the event stream when the turn is stopped or preempted, so the + // callback should NOT propagate CancelError. In practice, return a non-nil + // error only for callback-internal failures that should terminate the loop; + // return nil when the current agent is canceled by an external Stop or + // Preempt (Preempt cancels the current agent but the loop continues with + // the next turn). + // + // Optional. If not provided, events are drained and errors (except CancelError + // from Stop-triggered cancellation) are returned as ExitReason. + OnAgentEvents func(ctx context.Context, tc *TurnContext[T], events *AsyncIterator[*AgentEvent]) error + + // Store is the checkpoint store for persistence and resume. Optional. + // When set together with CheckpointID, enables automatic checkpoint-based resume. + // The TurnLoop always persists both runner checkpoint bytes and item bookkeeping + // (CanceledItems, UnhandledItems) via gob encoding, so T must be gob-encodable + // when Store is used. + Store CheckPointStore + + // CheckpointID, when set together with Store, enables automatic + // checkpoint-based resume. On Run(), the TurnLoop queries Store for this ID: + // - If a checkpoint exists with runner state (mid-turn interrupt), + // GenResume is called to plan the resume turn. + // - If a checkpoint exists without runner state (between-turns), + // the stored unhandled items are buffered and the loop proceeds + // normally via GenInput. + // - If no checkpoint exists, the loop starts fresh. + // + // On exit, if the TurnLoop saved a new checkpoint, it is saved under this + // same CheckpointID. On clean exit (no checkpoint saved), the existing + // checkpoint under CheckpointID is deleted to prevent stale resumption. + CheckpointID string +} + +// GenInputResult contains the result of GenInput processing. 
+type GenInputResult[T any] struct { + // RunCtx, if non-nil, overrides the context for this turn's execution + // (PrepareAgent, agent run, OnAgentEvents). + // + // Must be derived from the ctx passed to GenInput to preserve the + // TurnLoop's cancellation semantics and inherited values. For example: + // + // runCtx := context.WithValue(ctx, traceKey{}, extractTraceID(items)) + // return &GenInputResult[T]{RunCtx: runCtx, ...}, nil + // + // If nil, the TurnLoop's context is used unchanged. + RunCtx context.Context + + // Input is the agent input to execute + Input *AgentInput + + // RunOpts are the options for this agent run + RunOpts []AgentRunOption + + // Consumed are the items selected for this turn. + // They are removed from the buffer and passed to PrepareAgent. + Consumed []T + + // Remaining are the items to keep in the buffer for a future turn. + // TurnLoop pushes Remaining back into the buffer before running the agent. + // + // Items from the GenInput input slice that are in neither Consumed nor Remaining + // are dropped by the loop. + Remaining []T +} + +// GenResumeResult contains the result of GenResume processing. +type GenResumeResult[T any] struct { + // RunCtx, if non-nil, overrides the context for this resumed turn's execution + // (PrepareAgent, agent resume, OnAgentEvents). + RunCtx context.Context + + // RunOpts are the options for this agent resume run. + RunOpts []AgentRunOption + + // ResumeParams are optional parameters for resuming an interrupted agent. + ResumeParams *ResumeParams + + // Consumed are the items selected for this resumed turn. + // They are removed from the buffer and passed to PrepareAgent. + Consumed []T + + // Remaining are the items to keep in the buffer for a future turn. + // TurnLoop pushes Remaining back into the buffer before resuming the agent. + // + // Items from (canceledItems, unhandledItems, newItems) that are in neither Consumed + // nor Remaining are dropped by the loop. 
+ Remaining []T +} + +type turnRunSpec[T any] struct { + runCtx context.Context + input *AgentInput + runOpts []AgentRunOption + resumeParams *ResumeParams + isResume bool + consumed []T + resumeBytes []byte +} + +type turnPlan[T any] struct { + turnCtx context.Context + remaining []T + spec *turnRunSpec[T] +} + +func (l *TurnLoop[T]) planTurn( + ctx context.Context, + isResume bool, + items []T, + pr *turnLoopPendingResume[T], +) (*turnPlan[T], error) { + if !isResume { + result, err := l.config.GenInput(ctx, l, items) + if err != nil { + return nil, err + } + if result == nil { + return nil, errors.New("GenInputResult is nil") + } + if result.Input == nil { + return nil, errors.New("agent input is nil") + } + turnCtx := ctx + if result.RunCtx != nil { + turnCtx = result.RunCtx + } + return &turnPlan[T]{ + turnCtx: turnCtx, + remaining: result.Remaining, + spec: &turnRunSpec[T]{ + runCtx: result.RunCtx, + input: result.Input, + runOpts: result.RunOpts, + consumed: result.Consumed, + }, + }, nil + } + if pr == nil { + return nil, errors.New("resume payload is nil") + } + if l.config.GenResume == nil { + return nil, errors.New("GenResume is required for resume") + } + resumeResult, err := l.config.GenResume(ctx, l, pr.canceled, pr.unhandled, pr.newItems) + if err != nil { + return nil, err + } + if resumeResult == nil { + return nil, errors.New("GenResumeResult is nil") + } + turnCtx := ctx + if resumeResult.RunCtx != nil { + turnCtx = resumeResult.RunCtx + } + return &turnPlan[T]{ + turnCtx: turnCtx, + remaining: resumeResult.Remaining, + spec: &turnRunSpec[T]{ + runCtx: resumeResult.RunCtx, + runOpts: resumeResult.RunOpts, + resumeParams: resumeResult.ResumeParams, + isResume: true, + consumed: resumeResult.Consumed, + resumeBytes: pr.resumeBytes, + }, + }, nil +} + +// TurnLoopExitState is returned when TurnLoop exits, containing the exit reason +// and any items that were not processed. 
+type TurnLoopExitState[T any] struct { + // ExitReason indicates why the loop exited. + // nil means clean exit (Stop() was called without cancel options, or the + // agent completed normally before Stop took effect). + // Non-nil values include context errors, callback errors, *CancelError, etc. + // When Stop(WithImmediate()) or Stop(WithGraceful()) cancels a running + // agent, ExitReason will be a *CancelError. + // This never contains checkpoint errors — see CheckpointErr for those. + ExitReason error + + // UnhandledItems contains items that were buffered but not processed. + // These are items for which Push returned true but were never consumed by a turn. + // This is always valid regardless of ExitReason. + UnhandledItems []T + + // CanceledItems contains the items whose turn was actually interrupted + // by a cancel (Stop with WithImmediate, WithGraceful, or WithGracefulTimeout). + // Only populated when ExitReason is a *CancelError — if the agent finishes + // normally before the cancel takes effect, CanceledItems is empty. + // It can be used to reconstruct GenInput/PrepareAgent inputs when resuming. + CanceledItems []T + + // StopCause is the business-supplied reason passed via WithStopCause. + // Empty if Stop was not called or no cause was provided. + StopCause string + + // Checkpointed indicates whether a checkpoint save was attempted during cleanup. + // True only when Store is configured, CheckpointID is set, Stop() was called, + // and the loop was not idle at exit time. + Checkpointed bool + + // CheckpointErr is the error from checkpoint save, if any. + // nil when Checkpointed is false (no attempt was made) or when the save succeeded. + CheckpointErr error + + // TakeLateItems returns items that were pushed after the loop stopped + // (i.e., Push returned false for these items). These items are NOT included + // in the checkpoint. 
+ // + // This function is idempotent: the first call computes and caches the result; + // subsequent calls return the same slice. + // + // After TakeLateItems is called, any subsequent Push() will panic. This + // seals the late buffer and prevents items from being silently lost. + // + // It is safe to call TakeLateItems from any goroutine after Wait() returns. + // If TakeLateItems is never called, late items are simply garbage collected. + TakeLateItems func() []T +} + +// TurnContext provides per-turn context to the OnAgentEvents callback. +type TurnContext[T any] struct { + // Loop is the TurnLoop instance, allowing Push() or Stop() calls. + Loop *TurnLoop[T] + + // Consumed contains items that triggered this agent execution. + Consumed []T + + // Preempted is closed when a preempt signal fires for the current turn + // (via Push with WithPreempt/WithPreemptTimeout) and at least one + // preemptive Push contributed to the CancelError for the current turn. + // "Contributed" means the preempt's cancel options were included in the + // CancelError before it was finalized. Remains open if no preempt contributed. + // Use in a select to detect preemption while processing events. + // + // Both Preempted and Stopped may be closed within the same turn if both + // signals arrive while the agent is still being cancelled. Whichever + // arrives after the cancel is fully handled will not contribute. + Preempted <-chan struct{} + + // Stopped is closed when a Stop() call contributed to the CancelError for the + // current turn. + // "Contributed" means Stop's cancel options were included in the CancelError + // before it was finalized. Remains open if Stop did not contribute. + // Use in a select to detect stop while processing events. + // + // See Preempted for the relationship between the two channels. + Stopped <-chan struct{} + + // StopCause returns the business-supplied reason from WithStopCause. 
+ // This value is only meaningful after the Stopped channel is closed. + // Before that, it returns an empty string. + StopCause func() string +} + +// TurnLoop is a push-based event loop for agent execution. +// Users push items via Push() and the loop processes them through the agent. +// +// Create with NewTurnLoop, then start with Run: +// +// loop := NewTurnLoop(cfg) +// // pass loop to other components, push initial items, etc. +// loop.Run(ctx) +// +// # Permissive API +// +// All methods are valid on a not-yet-running loop: +// - Push: items are buffered and will be processed once Run is called. +// - Stop: sets the stopped flag; a subsequent Run will exit immediately. +// - Wait: blocks until Run is called AND the loop exits. If Run is never +// called, Wait blocks forever (this is a programming error, analogous +// to reading from a channel that nobody writes to). +type TurnLoop[T any] struct { + config TurnLoopConfig[T] + + buffer *turnBuffer[T] + + stopped int32 + started int32 + + done chan struct{} + + result *TurnLoopExitState[T] + + stopOnce sync.Once + + runOnce sync.Once + + stopSig *stopSignal + + preemptSig *preemptSignal + + runErr error + + canceledItems []T + + checkPointRunnerBytes []byte + + pendingResume *turnLoopPendingResume[T] + + loadCheckpointID string + + onAgentEvents func(ctx context.Context, tc *TurnContext[T], events *AsyncIterator[*AgentEvent]) error + + lateMu sync.Mutex + lateItems []T + lateSealed bool +} + +func (l *TurnLoop[T]) appendLate(item T) { + l.lateMu.Lock() + defer l.lateMu.Unlock() + if l.lateSealed { + panic("TurnLoop: Push called after TakeLateItems") + } + l.lateItems = append(l.lateItems, item) +} + +type turnLoopCheckpoint[T any] struct { + RunnerCheckpoint []byte + // HasRunnerState reports whether RunnerCheckpoint contains resumable runner state. + // It is false for "between turns" checkpoints where no agent execution was + // interrupted (e.g. Stop() before the first turn or between turns). 
+ HasRunnerState bool + UnhandledItems []T + CanceledItems []T +} + +// ErrCheckpointStoreNil is returned when a checkpoint operation requires a Store +// but none was configured in TurnLoopConfig. +var ErrCheckpointStoreNil = errors.New("checkpoint store is nil") + +func marshalTurnLoopCheckpoint[T any](c *turnLoopCheckpoint[T]) ([]byte, error) { + buf := new(bytes.Buffer) + if err := gob.NewEncoder(buf).Encode(c); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func unmarshalTurnLoopCheckpoint[T any](data []byte) (*turnLoopCheckpoint[T], error) { + var c turnLoopCheckpoint[T] + if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&c); err != nil { + return nil, err + } + return &c, nil +} + +func (l *TurnLoop[T]) saveTurnLoopCheckpoint(ctx context.Context, checkPointID string, c *turnLoopCheckpoint[T]) error { + if l.config.Store == nil { + return ErrCheckpointStoreNil + } + data, err := marshalTurnLoopCheckpoint(c) + if err != nil { + return err + } + return l.config.Store.Set(ctx, checkPointID, data) +} + +func (l *TurnLoop[T]) deleteTurnLoopCheckpoint(ctx context.Context, checkPointID string) error { + if l.config.Store == nil { + return nil + } + if deleter, ok := l.config.Store.(CheckPointDeleter); ok { + return deleter.Delete(ctx, checkPointID) + } + return nil +} + +func (l *TurnLoop[T]) tryLoadCheckpoint(ctx context.Context) error { + checkPointID := l.config.CheckpointID + if checkPointID == "" || l.config.Store == nil { + return nil + } + + l.loadCheckpointID = checkPointID + + data, existed, err := l.config.Store.Get(ctx, checkPointID) + if err != nil { + return fmt.Errorf("failed to load checkpoint[%s]: %w", checkPointID, err) + } + if !existed { + return nil + } + + var cp *turnLoopCheckpoint[T] + if len(data) == 0 { + return nil + } + cp, err = unmarshalTurnLoopCheckpoint[T](data) + if err != nil { + return fmt.Errorf("failed to unmarshal checkpoint[%s]: %w", checkPointID, err) + } + + newItems := l.buffer.TakeAll() + + if 
cp.HasRunnerState { + if len(cp.RunnerCheckpoint) == 0 { + l.buffer.PushFront(newItems) + return fmt.Errorf("checkpoint[%s] has runner state but bytes are empty", checkPointID) + } + l.pendingResume = &turnLoopPendingResume[T]{ + canceled: append([]T{}, cp.CanceledItems...), + unhandled: append([]T{}, cp.UnhandledItems...), + newItems: append([]T{}, newItems...), + resumeBytes: append([]byte{}, cp.RunnerCheckpoint...), + } + } else { + items := make([]T, 0, len(cp.UnhandledItems)+len(newItems)) + items = append(items, cp.UnhandledItems...) + items = append(items, newItems...) + l.buffer.PushFront(items) + } + + return nil +} + +type turnLoopPendingResume[T any] struct { + canceled []T + unhandled []T + newItems []T + resumeBytes []byte +} + +// SafePoint describes at which boundary the agent may be cancelled. +// It is a bitmask: values can be combined with bitwise OR to accept multiple +// safe points (e.g. AfterToolCalls | AfterChatModel). Internally, SafePoint +// is translated to CancelMode via toCancelMode(). +// +// SafePoint is used only in the preemption API (WithPreempt/WithPreemptTimeout). +// A key design constraint: preemption always targets a safe point — the user's +// intent is to cancel at a well-defined boundary, never to abort immediately. +// Immediate cancellation is only reachable as an automatic timeout escalation +// (via WithPreemptTimeout), not as a direct user choice. This is why SafePoint +// has no "immediate" value and why WithPreempt requires a non-zero SafePoint +// (panics otherwise). +type SafePoint int + +const ( + // AfterToolCalls allows the agent to finish the current tool-call round + // before being cancelled. + AfterToolCalls SafePoint = 1 << iota + // AfterChatModel allows the agent to finish the current chat-model + // call before being cancelled. + AfterChatModel + // AnySafePoint is shorthand for AfterToolCalls | AfterChatModel. 
+ AnySafePoint = AfterToolCalls | AfterChatModel +) + +func (sp SafePoint) toCancelMode() CancelMode { + var mode CancelMode + if sp&AfterToolCalls != 0 { + mode |= CancelAfterToolCalls + } + if sp&AfterChatModel != 0 { + mode |= CancelAfterChatModel + } + return mode +} + +type stopConfig struct { + agentCancelOpts []AgentCancelOption + skipCheckpoint bool + stopCause string + idleFor time.Duration +} + +// StopOption is an option for Stop(). +type StopOption func(*stopConfig) + +// WithGraceful requests a graceful stop that waits at the nearest safe point +// (after tool calls or after a chat-model call) and propagates recursively to +// nested agents. It does not impose a time limit; use WithGracefulTimeout to +// add a grace period after which the stop escalates to immediate cancellation. +// +// WithGraceful and WithGracefulTimeout are mutually exclusive; if both are +// passed to the same Stop call, the last one wins. +func WithGraceful() StopOption { + return func(cfg *stopConfig) { + cfg.agentCancelOpts = []AgentCancelOption{ + WithAgentCancelMode(CancelAfterChatModel | CancelAfterToolCalls), + WithRecursive(), + } + } +} + +// WithImmediate aborts the running agent turn as soon as possible. +// The agent's context is cancelled immediately without waiting for any +// safe point. Nested agents inside AgentTools are torn down as a side effect. +// +// This is the most aggressive stop mode — typically used when the caller +// wants to shut down the TurnLoop with no intention of resuming. +func WithImmediate() StopOption { + return func(cfg *stopConfig) { + cfg.agentCancelOpts = []AgentCancelOption{} + } +} + +// WithGracefulTimeout is like WithGraceful but adds a grace period. +// If the agent has not reached a safe point within gracePeriod, the stop +// escalates to immediate cancellation. +// +// gracePeriod must be positive; passing a zero or negative duration panics. 
+// +// WithGraceful and WithGracefulTimeout are mutually exclusive; if both are +// passed to the same Stop call, the last one wins. +func WithGracefulTimeout(gracePeriod time.Duration) StopOption { + if gracePeriod <= 0 { + panic("adk: WithGracefulTimeout: gracePeriod must be positive") + } + return func(cfg *stopConfig) { + cfg.agentCancelOpts = []AgentCancelOption{ + WithAgentCancelMode(CancelAfterChatModel | CancelAfterToolCalls), + WithRecursive(), + WithAgentCancelTimeout(gracePeriod), + } + } +} + +// WithSkipCheckpoint tells the TurnLoop not to persist a checkpoint for this +// Stop call. Use this when the caller does not intend to resume in the future. +// The flag is sticky: once any Stop() call sets it, subsequent calls cannot undo it. +func WithSkipCheckpoint() StopOption { + return func(cfg *stopConfig) { + cfg.skipCheckpoint = true + } +} + +// WithStopCause attaches a business-supplied reason string to this Stop call. +// The cause is surfaced in TurnLoopExitState.StopCause and, after the Stopped +// channel closes, via TurnContext.StopCause(). +// If multiple Stop() calls provide a cause, the first non-empty value wins. +func WithStopCause(cause string) StopOption { + return func(cfg *stopConfig) { + cfg.stopCause = cause + } +} + +// UntilIdleFor defers the stop until the TurnLoop has been continuously idle +// (blocked between turns with no pending items) for at least the given +// duration. Each time a new item arrives the timer resets from zero. +// +// This is useful when business code monitors agent activity externally and +// wants to shut down the loop once there has been no work for a while, without +// racing with concurrent Push calls. +// +// UntilIdleFor is combinable with other StopOptions in the same call. +// For example, Stop(UntilIdleFor(30*time.Second), WithGraceful()) means +// "after 30 s of idle, stop gracefully". If another Stop call is made +// without UntilIdleFor (e.g. 
Stop(WithImmediate())), the loop shuts down +// immediately, bypassing the idle wait. +// +// Only the first UntilIdleFor duration takes effect; subsequent calls with +// a different duration are ignored. A Stop() call without UntilIdleFor always +// shuts down the loop immediately regardless of any pending idle timer. +// +// duration must be positive; passing a zero or negative value panics. +func UntilIdleFor(duration time.Duration) StopOption { + if duration <= 0 { + panic("adk: UntilIdleFor: duration must be positive") + } + return func(cfg *stopConfig) { + cfg.idleFor = duration + } +} + +type pushConfig[T any] struct { + preempt bool + preemptDelay time.Duration + agentCancelOpts []AgentCancelOption + pushStrategy func(context.Context, *TurnContext[T]) []PushOption[T] +} + +// PushOption is an option for Push(). +type PushOption[T any] func(*pushConfig[T]) + +// WithPreempt signals that the current agent turn should be cancelled at the +// specified safePoint after pushing the new item. The loop cancels the current +// turn and starts a new one, where GenInput will see all buffered items +// including the newly pushed one. +// Use WithPreemptTimeout to add a timeout that escalates to immediate abort. +// +// Because safe points fire at turn-level boundaries (after the chat model +// returns or after all tool calls complete), no nested agent is running at +// the moment of cancellation — nested agents within AgentTools have either +// not started yet (AfterChatModel) or already finished (AfterToolCalls). +// If the preemption escalates to immediate via WithPreemptTimeout, any +// in-flight nested agent is torn down through Go context cancellation. +// +// WithPreempt and WithPreemptTimeout are mutually exclusive; if both are +// passed to the same Push call, the last one wins. +// +// safePoint must not be zero; passing SafePoint(0) panics. 
+func WithPreempt[T any](safePoint SafePoint) PushOption[T] { + if safePoint == 0 { + panic("adk: SafePoint must not be zero; use AfterToolCalls, AfterChatModel, or AnySafePoint") + } + return func(cfg *pushConfig[T]) { + cfg.preempt = true + cfg.agentCancelOpts = []AgentCancelOption{ + WithAgentCancelMode(safePoint.toCancelMode()), + } + } +} + +// WithPreemptTimeout is like WithPreempt but adds a timeout. If the agent has +// not reached the safe point within timeout, the preemption escalates to +// immediate cancellation. +// +// safePoint must not be zero; passing SafePoint(0) panics. +func WithPreemptTimeout[T any](safePoint SafePoint, timeout time.Duration) PushOption[T] { + if safePoint == 0 { + panic("adk: SafePoint must not be zero; use AfterToolCalls, AfterChatModel, or AnySafePoint") + } + return func(cfg *pushConfig[T]) { + cfg.preempt = true + cfg.agentCancelOpts = []AgentCancelOption{ + WithAgentCancelMode(safePoint.toCancelMode()), + WithAgentCancelTimeout(timeout), + } + } +} + +// WithPreemptDelay sets a delay duration before preemption takes effect. +// When used with WithPreempt or WithPreemptTimeout, the push will succeed +// immediately, but the preemption signal will be delayed by the specified +// duration. This allows the current agent to continue processing for a grace +// period before being preempted. +func WithPreemptDelay[T any](delay time.Duration) PushOption[T] { + return func(cfg *pushConfig[T]) { + cfg.preemptDelay = delay + } +} + +// WithPushStrategy provides dynamic push option resolution based on the current turn state. +// The callback receives the current turn's context and TurnContext (nil if no turn is active) +// and returns the actual PushOptions to apply. When WithPushStrategy is used, all other +// PushOptions passed to the same Push call are ignored. +// +// The returned options must not contain another WithPushStrategy; any nested +// strategy is silently stripped. 
+// +// Example: preempt only if the current turn is processing low-priority items: +// +// loop.Push(urgentItem, WithPushStrategy(func(ctx context.Context, tc *TurnContext[MyItem]) []PushOption[MyItem] { +// if tc == nil { +// return nil // between turns, plain push +// } +// if isLowPriority(tc.Consumed) { +// return []PushOption[MyItem]{WithPreempt[MyItem](AnySafePoint)} +// } +// return nil // don't preempt high-priority work +// })) +func WithPushStrategy[T any](fn func(ctx context.Context, tc *TurnContext[T]) []PushOption[T]) PushOption[T] { + return func(cfg *pushConfig[T]) { + cfg.pushStrategy = fn + } +} + +func defaultTurnLoopOnAgentEvents[T any](_ context.Context, _ *TurnContext[T], events *AsyncIterator[*AgentEvent]) error { + for { + event, ok := events.Next() + if !ok { + break + } + if event.Err != nil { + return event.Err + } + } + return nil +} + +// NewTurnLoop creates a new TurnLoop without starting it. +// The returned loop accepts Push and Stop calls immediately; pushed items +// are buffered until Run is called. +// Call Run to start the processing goroutine. +// +// NewTurnLoop panics if GenInput or PrepareAgent is nil. +func NewTurnLoop[T any](cfg TurnLoopConfig[T]) *TurnLoop[T] { + if cfg.GenInput == nil { + panic("adk: NewTurnLoop: GenInput is required") + } + if cfg.PrepareAgent == nil { + panic("adk: NewTurnLoop: PrepareAgent is required") + } + + l := &TurnLoop[T]{ + config: cfg, + buffer: newTurnBuffer[T](), + done: make(chan struct{}), + stopSig: newStopSignal(), + preemptSig: newPreemptSignal(), + } + if cfg.OnAgentEvents != nil { + l.onAgentEvents = cfg.OnAgentEvents + } else { + l.onAgentEvents = defaultTurnLoopOnAgentEvents[T] + } + return l +} + +func (l *TurnLoop[T]) start(ctx context.Context) { + l.runOnce.Do(func() { + atomic.StoreInt32(&l.started, 1) + go l.run(ctx) + }) +} + +// Run starts the loop's processing goroutine. It is non-blocking: the loop +// runs in the background and results are obtained via Wait. 
//
// If CheckpointID is configured in TurnLoopConfig and a matching checkpoint
// exists in Store, the loop automatically resumes from that checkpoint.
// Otherwise it starts fresh with whatever items were Push()-ed.
//
// ctx is handed to the background goroutine that executes the loop.
//
// Calling Run more than once is a no-op: only the first call starts the loop.
func (l *TurnLoop[T]) Run(ctx context.Context) {
	l.start(ctx)
}

// Push adds an item to the loop's buffer for processing.
// This method is non-blocking and thread-safe.
// Returns false if the loop has stopped, true otherwise. If a preemptive push
// succeeds, the second return value is a channel that is closed when the loop
// has acknowledged the preempt signal (by either initiating cancellation of the
// current agent run or reaching a point where no cancellation is needed).
// The channel is nil for a non-preemptive push and whenever Push returns false.
// If the loop has not been started yet (Run not called), items are buffered
// and will be processed once Run is called; a preemptive push in that state
// returns an already-closed channel.
// After Wait() returns, failed pushes can be recovered via TurnLoopExitState.TakeLateItems().
// Once TakeLateItems() has been called, any subsequent push that would become a
// late item will panic instead of being silently dropped.
//
// Use WithPreempt() or WithPreemptTimeout() to atomically push an item and signal
// preemption of the current agent. This is useful for urgent items that should
// interrupt the current processing.
// The returned channel may be waited on if the caller needs to ensure the preempt
// signal has been observed.
//
// Use WithPreemptDelay() together with WithPreempt()/WithPreemptTimeout() to delay
// the preemption signal.
// Push returns immediately after the item is buffered, and a goroutine is spawned
// to signal preemption after the delay. If the loop exits before the delay has
// elapsed, the channel is closed without a preempt being requested.
+func (l *TurnLoop[T]) Push(item T, opts ...PushOption[T]) (bool, <-chan struct{}) { + cfg := &pushConfig[T]{} + for _, opt := range opts { + opt(cfg) + } + + if cfg.pushStrategy != nil { + return l.pushWithStrategy(item, cfg) + } + + return l.pushWithConfig(item, cfg) +} + +// pushWithStrategy atomically holds the run loop and snapshots the current turn, +// then calls the strategy callback with a guaranteed-stable TurnContext. If the +// strategy returns preempt options, the hold is kept and a preempt is requested; +// otherwise the hold is released and the item is buffered as a plain push. +func (l *TurnLoop[T]) pushWithStrategy(item T, cfg *pushConfig[T]) (bool, <-chan struct{}) { + strategy := cfg.pushStrategy + + runCtx, tcAny := l.preemptSig.holdAndGetTurn() + if runCtx == nil { + runCtx = context.Background() + } + var tc *TurnContext[T] + if tcAny != nil { + tc = tcAny.(*TurnContext[T]) + } + realOpts := strategy(runCtx, tc) + cfg = &pushConfig[T]{} + for _, opt := range realOpts { + opt(cfg) + } + cfg.pushStrategy = nil + + if !cfg.preempt { + l.preemptSig.unholdRunLoop() + if !l.buffer.TrySend(item) { + l.appendLate(item) + return false, nil + } + return true, nil + } + + if atomic.LoadInt32(&l.stopped) != 0 { + l.preemptSig.unholdRunLoop() + l.appendLate(item) + return false, nil + } + + if !l.buffer.TrySend(item) { + l.preemptSig.unholdRunLoop() + l.appendLate(item) + return false, nil + } + + ack := make(chan struct{}) + if atomic.LoadInt32(&l.started) == 0 { + l.preemptSig.unholdRunLoop() + close(ack) + return true, ack + } + + if cfg.preemptDelay > 0 { + go func() { + select { + case <-time.After(cfg.preemptDelay): + l.preemptSig.requestPreempt(ack, cfg.agentCancelOpts...) + case <-l.done: + l.preemptSig.unholdRunLoop() + close(ack) + } + }() + } else { + l.preemptSig.requestPreempt(ack, cfg.agentCancelOpts...) 
+ } + return true, ack +} + +func (l *TurnLoop[T]) pushWithConfig(item T, cfg *pushConfig[T]) (bool, <-chan struct{}) { + if atomic.LoadInt32(&l.stopped) != 0 { + l.appendLate(item) + return false, nil + } + + if cfg.preempt { + l.preemptSig.holdRunLoop() + + if !l.buffer.TrySend(item) { + l.preemptSig.unholdRunLoop() + l.appendLate(item) + return false, nil + } + + ack := make(chan struct{}) + if atomic.LoadInt32(&l.started) == 0 { + l.preemptSig.unholdRunLoop() + close(ack) + return true, ack + } + + if cfg.preemptDelay > 0 { + go func() { + select { + case <-time.After(cfg.preemptDelay): + l.preemptSig.requestPreempt(ack, cfg.agentCancelOpts...) + case <-l.done: + l.preemptSig.unholdRunLoop() + close(ack) + } + }() + } else { + l.preemptSig.requestPreempt(ack, cfg.agentCancelOpts...) + } + return true, ack + } + + if !l.buffer.TrySend(item) { + l.appendLate(item) + return false, nil + } + return true, nil +} + +// Stop signals the loop to stop and returns immediately (non-blocking). +// Without options, the current agent turn runs to completion and the loop +// exits at the turn boundary without starting a new turn. ExitReason is nil. +// +// Use WithImmediate() to abort the running agent turn immediately. +// Use WithGraceful() to cancel at the nearest safe point with recursive +// propagation to nested agents. +// Use WithGracefulTimeout() for safe-point cancel with an escalation deadline. +// Use UntilIdleFor() to defer the stop until the loop has been continuously +// idle for a given duration; the loop shuts down automatically once the idle +// timer fires. +// +// This method may be called multiple times; subsequent calls update cancel options. +// A Stop() call without UntilIdleFor shuts down the loop immediately, even if +// a prior UntilIdleFor is still waiting. +// Call Wait() to block until the loop has fully exited and get the result. +// +// Stop may be called before Run. 
In that case, the stopped flag is set and +// a subsequent Run will exit the loop immediately. +// +// If the running agent does not support the WithCancel AgentRunOption, +// all cancel-related options (WithImmediate, WithGraceful, WithGracefulTimeout) +// degrade to "exit the loop on entering the next iteration" — the current +// agent turn runs to completion before the loop exits. +func (l *TurnLoop[T]) Stop(opts ...StopOption) { + cfg := &stopConfig{} + for _, opt := range opts { + opt(cfg) + } + + l.stopSig.signal(cfg) + + if cfg.idleFor > 0 { + l.buffer.Wakeup() + return + } + l.commitStop() +} + +func (l *TurnLoop[T]) commitStop() { + l.stopOnce.Do(func() { + l.stopSig.closeDone() + atomic.StoreInt32(&l.stopped, 1) + l.buffer.Close() + }) +} + +// Wait blocks until the loop exits and returns the result. +// This method is safe to call from multiple goroutines. +// All callers will receive the same result. +// +// Wait blocks until Run is called AND the loop exits. If Run is +// never called, Wait blocks forever. +func (l *TurnLoop[T]) Wait() *TurnLoopExitState[T] { + <-l.done + return l.result +} + +func (l *TurnLoop[T]) run(ctx context.Context) { + defer l.cleanup(ctx) + + if err := l.tryLoadCheckpoint(ctx); err != nil { + l.runErr = err + return + } + + // Monitor context cancellation: close the buffer so that a blocking + // Receive() unblocks. The loop will then check ctx.Err() and exit. + go func() { + select { + case <-ctx.Done(): + l.buffer.Close() + case <-l.done: + } + }() + + for { + if l.stopSig.isStopped() { + return + } + + isResume := false + var pr *turnLoopPendingResume[T] + var items []T + var pushBack []T + + if l.pendingResume != nil { + isResume = true + pr = l.pendingResume + l.pendingResume = nil + + pushBack = make([]T, 0, len(pr.canceled)+len(pr.unhandled)+len(pr.newItems)) + pushBack = append(pushBack, pr.canceled...) + pushBack = append(pushBack, pr.unhandled...) + pushBack = append(pushBack, pr.newItems...) 
// run is the body of the turn loop. Each iteration gathers the items for one
// turn (from a pending resume or from the buffer), plans the turn, prepares
// the agent and runs it. It returns when the loop is stopped, the context is
// done, the buffer is closed, or any step fails; the failure (if any) is left
// in l.runErr. cleanup always runs on exit.
func (l *TurnLoop[T]) run(ctx context.Context) {
	defer l.cleanup(ctx)

	if err := l.tryLoadCheckpoint(ctx); err != nil {
		l.runErr = err
		return
	}

	// Monitor context cancellation: close the buffer so that a blocking
	// Receive() unblocks. The loop will then check ctx.Err() and exit.
	// The goroutine ends when either ctx is done or the loop has exited.
	go func() {
		select {
		case <-ctx.Done():
			l.buffer.Close()
		case <-l.done:
		}
	}()

	for {
		if l.stopSig.isStopped() {
			return
		}

		isResume := false
		var pr *turnLoopPendingResume[T]
		var items []T
		// pushBack holds the items that must be returned to the front of the
		// buffer if this iteration bails out before the agent runs.
		var pushBack []T

		if l.pendingResume != nil {
			// Resume path: consume the pending resume exactly once.
			isResume = true
			pr = l.pendingResume
			l.pendingResume = nil

			pushBack = make([]T, 0, len(pr.canceled)+len(pr.unhandled)+len(pr.newItems))
			pushBack = append(pushBack, pr.canceled...)
			pushBack = append(pushBack, pr.unhandled...)
			pushBack = append(pushBack, pr.newItems...)
		} else {
			var first T
			var ok bool

			if idleFor := l.stopSig.getIdleFor(); idleFor > 0 {
				// A Stop(UntilIdleFor) is pending: wait for the next item, but
				// commit the stop if nothing arrives within idleFor.
				l.buffer.ClearWakeup()
				idleTimer := time.NewTimer(idleFor)
				cancelIdle := make(chan struct{})
				// When the idle timer fires, commitStop closes the buffer via
				// buffer.Close(), which broadcasts to unblock the pending
				// Receive() call below.
				go func() {
					select {
					case <-idleTimer.C:
						l.commitStop()
					case <-cancelIdle:
					}
				}()

				first, ok = l.buffer.Receive()

				// Receive returned: retire the timer and its goroutine.
				idleTimer.Stop()
				close(cancelIdle)
			} else {
				first, ok = l.buffer.Receive()
				// Woken up by Stop(UntilIdleFor); re-enter loop to start the idle timer.
				if !ok && l.stopSig.getIdleFor() > 0 {
					continue
				}
			}

			if !ok {
				// Receive reported no item; record the context error if that
				// is why we are exiting.
				if err := ctx.Err(); err != nil {
					l.runErr = err
				}
				return
			}

			if err := ctx.Err(); err != nil {
				// Put the item back so it is reported as unhandled.
				l.buffer.PushFront([]T{first})
				l.runErr = err
				return
			}

			if l.stopSig.isStopped() {
				l.buffer.PushFront([]T{first})
				return
			}

			// Batch everything currently buffered into this turn's input.
			rest := l.buffer.TakeAll()
			items = append([]T{first}, rest...)
			pushBack = items
		}

		// Drain any pending preempt that arrived between turns. A Push caller
		// may have called holdRunLoop + requestPreempt while the loop was
		// between iterations; acknowledge and release before planning the
		// next turn. Use drainAll to release all pusher holds at once —
		// multiple concurrent Push(WithPreempt) callers each hold a ref.
		if preempted, _, ackList := l.preemptSig.waitForPreemptOrUnhold(); preempted {
			for _, ack := range ackList {
				close(ack)
			}
			l.preemptSig.drainAll()
		}

		plan, err := l.planTurn(ctx, isResume, items, pr)
		if err != nil {
			if len(pushBack) > 0 {
				l.buffer.PushFront(pushBack)
			}
			l.runErr = err
			return
		}

		// Stop may have arrived while planning; re-check before going further.
		if l.stopSig.isStopped() {
			if len(pushBack) > 0 {
				l.buffer.PushFront(pushBack)
			}
			return
		}

		agent, err := l.config.PrepareAgent(plan.turnCtx, l, plan.spec.consumed)
		if err != nil {
			if len(pushBack) > 0 {
				l.buffer.PushFront(pushBack)
			}
			l.runErr = err
			return
		}

		// Same re-check after the PrepareAgent callback.
		if l.stopSig.isStopped() {
			if len(pushBack) > 0 {
				l.buffer.PushFront(pushBack)
			}
			return
		}

		// The turn is going ahead: only the items the plan did not consume go
		// back to the buffer.
		l.buffer.PushFront(plan.remaining)

		// Bracket the turn with holdRunLoop / endTurnAndUnhold. The run loop's
		// own hold ensures that if a Push caller also holds mid-turn, the total
		// holdCount stays > 0 after endTurnAndUnhold, blocking the loop at
		// waitForPreemptOrUnhold until the Push caller's preempt is resolved.
		l.preemptSig.holdRunLoop()
		runErr := l.runAgentAndHandleEvents(plan.turnCtx, agent, plan.spec)

		l.preemptSig.endTurnAndUnhold()

		if runErr != nil {
			// On a cancel, remember this turn's consumed items as canceled
			// unless canceled items were already recorded.
			if errors.As(runErr, new(*CancelError)) && len(l.canceledItems) == 0 {
				l.canceledItems = append([]T{}, plan.spec.consumed...)
			}
			l.runErr = runErr
			return
		}
	}
}

// setupBridgeStore decides which checkpoint bridge store (if any) a turn runs
// with and extends runOpts accordingly.
//
//   - No Store configured and the turn is a resume: error wrapping
//     ErrCheckpointStoreNil.
//   - No Store configured otherwise: runOpts is returned unchanged with a nil
//     bridge store.
//   - Store configured: WithCheckPointID(bridgeCheckpointID) is appended. A
//     resume gets a bridge store seeded with spec.resumeBytes (an empty
//     payload is an error); a fresh turn gets a new bridge store.
func (l *TurnLoop[T]) setupBridgeStore(spec *turnRunSpec[T], runOpts []AgentRunOption) ([]AgentRunOption, *bridgeStore, error) {
	store := l.config.Store
	if store == nil && spec.isResume {
		return nil, nil, fmt.Errorf("failed to resume agent: %w", ErrCheckpointStoreNil)
	}
	if store == nil {
		return runOpts, nil, nil
	}
	runOpts = append(runOpts, WithCheckPointID(bridgeCheckpointID))
	if spec.isResume {
		if len(spec.resumeBytes) == 0 {
			return nil, nil, fmt.Errorf("resume checkpoint is empty")
		}
		return runOpts, newResumeBridgeStore(bridgeCheckpointID, spec.resumeBytes), nil
	}
	return runOpts, newBridgeStore(), nil
}
+ if firstPreempt && contributed { + close(preemptDone) + } + for _, ack := range ackList { + close(ack) + } + } + } + } + } +} + +// watchStopSignal runs for the lifetime of a single turn. It selects on two +// channels from stopSignal: +// +// - done (permanently closed after Stop): the durable stop flag. Fires +// immediately for any watcher, even those in turns started after +// Stop() but before the run loop observed isStopped(). This eliminates +// the race where a previous turn's watcher consumed the one-shot notify, +// leaving the current turn unable to detect the stop. +// +// - notify (one-shot, buffered 1): fires when a new Stop() call is made, +// enabling cancel-mode escalation (e.g. CancelAfterToolCalls → CancelImmediate). +// The generation counter de-duplicates wakes, analogous to preemptGen in +// watchPreemptSignal. +// +// On the first cancel that actually contributed (i.e. the cancel was accepted +// before the CancelError was finalized), stoppedDone is closed to wake +// runAgentAndHandleEvents's select. +func (l *TurnLoop[T]) watchStopSignal(done <-chan struct{}, agentCancelFunc AgentCancelFunc, stoppedDone chan struct{}) { + var lastGen uint64 + stoppedClosed := false + + tryCancel := func(gen uint64, opts []AgentCancelOption) { + if gen == lastGen { + return + } + lastGen = gen + if opts == nil { + return + } + _, contributed := agentCancelFunc(opts...) 
+ if contributed && !stoppedClosed { + close(stoppedDone) + stoppedClosed = true + } + } + + for { + select { + case <-done: + return + case <-l.stopSig.notify: + tryCancel(l.stopSig.check()) + case <-l.stopSig.done: + tryCancel(l.stopSig.check()) + for { + select { + case <-done: + return + case <-l.stopSig.notify: + tryCancel(l.stopSig.check()) + } + } + } + } +} + +func (l *TurnLoop[T]) runAgentAndHandleEvents( + ctx context.Context, + agent Agent, + spec *turnRunSpec[T], +) error { + var iter *AsyncIterator[*AgentEvent] + + runOpts, ms, err := l.setupBridgeStore(spec, spec.runOpts) + if err != nil { + return err + } + store := l.config.Store + cancelOpt, agentCancelFunc := WithCancel() + runOpts = append(runOpts, cancelOpt) + + enableStreaming := false + if spec.input != nil { + enableStreaming = spec.input.EnableStreaming + } + runner := NewRunner(ctx, RunnerConfig{ + EnableStreaming: enableStreaming, + Agent: agent, + CheckPointStore: ms, + }) + + preemptDone := make(chan struct{}) + stoppedDone := make(chan struct{}) + + tc := &TurnContext[T]{ + Loop: l, + Consumed: spec.consumed, + Preempted: preemptDone, + Stopped: stoppedDone, + StopCause: l.stopSig.getStopCause, + } + l.preemptSig.setTurn(ctx, tc) + + if spec.isResume { + var err error + if spec.resumeParams != nil { + iter, err = runner.ResumeWithParams(ctx, bridgeCheckpointID, spec.resumeParams, runOpts...) + } else { + iter, err = runner.Resume(ctx, bridgeCheckpointID, runOpts...) + } + if err != nil { + return fmt.Errorf("failed to resume agent: %w", err) + } + } else { + iter = runner.Run(ctx, spec.input.Messages, runOpts...) 
+ } + + handleEvents := func() error { + return l.onAgentEvents(ctx, tc, iter) + } + + done := make(chan struct{}) + var handleErr error + + go func() { + defer func() { + panicErr := recover() + if panicErr != nil { + handleErr = safe.NewPanicErr(panicErr, debug.Stack()) + } + close(done) + }() + handleErr = handleEvents() + }() + go l.watchPreemptSignal(done, agentCancelFunc, preemptDone) + go l.watchStopSignal(done, agentCancelFunc, stoppedDone) + + finalizeCheckpoint := func() error { + if store != nil && ms != nil { + data, ok, err := ms.Get(ctx, bridgeCheckpointID) + if err != nil { + return fmt.Errorf("failed to read runner checkpoint: %w", err) + } + if ok { + l.checkPointRunnerBytes = append([]byte{}, data...) + } + } + return nil + } + + // Wait for the turn to end. Three outcomes: + // + // done: Events fully handled (normal or error). If Stop() was + // called, save checkpoint so the caller can resume later. + // Also handle the select race: if preemptDone is closed + // too, treat as a preempt (return nil) instead of leaking + // the CancelError. + // + // preemptDone: A preemptive Push successfully cancelled the agent. + // Wait for the handleEvents goroutine to drain, then + // return nil — the run loop will start a new turn. + // + // stoppedDone: Stop() cancelled the agent. Save checkpoint so the + // caller can resume later. 
+ select { + case <-done: + select { + case <-preemptDone: + return nil + default: + } + if l.stopSig.isStopped() { + if err := finalizeCheckpoint(); err != nil { + if handleErr != nil { + handleErr = fmt.Errorf("%w; checkpoint error: %v", handleErr, err) + } else { + handleErr = err + } + } + } + return handleErr + case <-preemptDone: + <-done + return nil + case <-stoppedDone: + <-done + if err := finalizeCheckpoint(); err != nil { + if handleErr != nil { + handleErr = fmt.Errorf("%w; checkpoint error: %v", handleErr, err) + } else { + handleErr = err + } + } + return handleErr + } +} + +func (l *TurnLoop[T]) cleanup(ctx context.Context) { + atomic.StoreInt32(&l.stopped, 1) + + unhandled := l.buffer.TakeAll() + checkpointID := l.config.CheckpointID + isIdle := len(l.checkPointRunnerBytes) == 0 && len(unhandled) == 0 && len(l.canceledItems) == 0 + + // Only save checkpoint when the loop exited due to an explicit Stop(). + // If Stop() was called but a callback error happened concurrently, + // the state may be inconsistent — don't checkpoint in that case. + // We consider the exit Stop-caused if runErr is nil (clean stop between + // turns) or a *CancelError (Stop canceled a running agent). 
// cleanup is deferred by run and finalizes the loop: it marks the loop
// stopped, collects the items still in the buffer, saves or deletes the
// turn-loop checkpoint, publishes the exit state in l.result, and finally
// closes l.done, which releases Wait().
func (l *TurnLoop[T]) cleanup(ctx context.Context) {
	atomic.StoreInt32(&l.stopped, 1)

	unhandled := l.buffer.TakeAll()
	checkpointID := l.config.CheckpointID
	// Idle means there is nothing worth persisting: no runner state, no
	// unhandled items and no canceled items.
	isIdle := len(l.checkPointRunnerBytes) == 0 && len(unhandled) == 0 && len(l.canceledItems) == 0

	// Only save checkpoint when the loop exited due to an explicit Stop().
	// If Stop() was called but a callback error happened concurrently,
	// the state may be inconsistent — don't checkpoint in that case.
	// We consider the exit Stop-caused if runErr is nil (clean stop between
	// turns) or a *CancelError (Stop canceled a running agent).
	exitCausedByStop := l.runErr == nil || errors.As(l.runErr, new(*CancelError))
	shouldSaveCheckpoint := l.config.Store != nil && checkpointID != "" && l.stopSig.isStopped() && exitCausedByStop && !isIdle && !l.stopSig.isSkipCheckpoint()

	var checkpointed bool
	var checkpointErr error

	if shouldSaveCheckpoint {
		cp := &turnLoopCheckpoint[T]{
			RunnerCheckpoint: l.checkPointRunnerBytes,
			HasRunnerState:   len(l.checkPointRunnerBytes) > 0,
			UnhandledItems:   unhandled,
			CanceledItems:    l.canceledItems,
		}
		// checkpointed records that a save was attempted; whether it
		// succeeded is reported separately through checkpointErr.
		checkpointed = true
		checkpointErr = l.saveTurnLoopCheckpoint(ctx, checkpointID, cp)
	} else if l.loadCheckpointID != "" {
		// No new checkpoint is written, so the one this loop was loaded from
		// is removed. The delete error is deliberately discarded.
		_ = l.deleteTurnLoopCheckpoint(ctx, l.loadCheckpointID)
	}

	var takeLateOnce sync.Once
	var takeLateResult []T

	l.result = &TurnLoopExitState[T]{
		ExitReason:     l.runErr,
		UnhandledItems: unhandled,
		CanceledItems:  l.canceledItems,
		StopCause:      l.stopSig.getStopCause(),
		Checkpointed:   checkpointed,
		CheckpointErr:  checkpointErr,
		// TakeLateItems snapshots the late items on its first call and marks
		// the late list as sealed; every later call returns that same
		// snapshot.
		TakeLateItems: func() []T {
			takeLateOnce.Do(func() {
				l.lateMu.Lock()
				takeLateResult = append([]T{}, l.lateItems...)
				l.lateSealed = true
				l.lateMu.Unlock()
			})
			return takeLateResult
		},
	}

	// l.result is assigned before l.done is closed, so a Wait() caller
	// released by the close below always sees the populated result.
	l.preemptSig.drainAll()
	l.buffer.Close()
	close(l.done)
}
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/schema" +) + +type turnLoopMockAgent struct { + name string + events []*AgentEvent + runFunc func(ctx context.Context, input *AgentInput) (*AgentOutput, error) + cancelFunc func(opts ...AgentCancelOption) error +} + +func (a *turnLoopMockAgent) Name(_ context.Context) string { return a.name } +func (a *turnLoopMockAgent) Description(_ context.Context) string { return "mock agent" } +func (a *turnLoopMockAgent) Run(ctx context.Context, input *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, gen := NewAsyncIteratorPair[*AgentEvent]() + + if a.runFunc != nil { + go func() { + defer gen.Close() + output, err := a.runFunc(ctx, input) + if err != nil { + gen.Send(&AgentEvent{Err: err}) + return + } + gen.Send(&AgentEvent{Output: output}) + }() + return iter + } + + go func() { + defer gen.Close() + for _, e := range a.events { + gen.Send(e) + } + }() + return iter +} + +type turnLoopCheckpointStore struct { + m map[string][]byte + mu sync.Mutex +} + +func (s *turnLoopCheckpointStore) Set(_ context.Context, key string, value []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + s.m[key] = value + return nil +} + +func (s *turnLoopCheckpointStore) Get(_ context.Context, key string) ([]byte, bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + v, ok := s.m[key] + return v, ok, nil +} + +type turnLoopCancellableMockAgent struct { + name string + runFunc func(ctx context.Context, input *AgentInput) (*AgentOutput, error) + onCancel func(cc *cancelContext) + cancel context.CancelFunc + mu sync.Mutex +} + +func (a *turnLoopCancellableMockAgent) Name(_ context.Context) string { return a.name } +func (a *turnLoopCancellableMockAgent) 
// turnLoopCancellableMockAgent is a test Agent whose runFunc receives a
// context that is cancelled once the cancel context taken from the run
// options signals on cancelChan. onCancel, if set, is called with that cancel
// context just before the runFunc context is cancelled.
type turnLoopCancellableMockAgent struct {
	name     string
	runFunc  func(ctx context.Context, input *AgentInput) (*AgentOutput, error)
	onCancel func(cc *cancelContext)
	cancel   context.CancelFunc // cancel func of the most recent Run; guarded by mu
	mu       sync.Mutex
}

// Name returns the configured agent name.
func (a *turnLoopCancellableMockAgent) Name(_ context.Context) string { return a.name }

// Description returns a fixed description.
func (a *turnLoopCancellableMockAgent) Description(_ context.Context) string { return "mock agent" }

// Run invokes runFunc on a goroutine with a cancellable child of ctx and emits
// one event with its output or error. runFunc must be non-nil.
func (a *turnLoopCancellableMockAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] {
	iter, gen := NewAsyncIteratorPair[*AgentEvent]()

	o := getCommonOptions(nil, opts...)
	cc := o.cancelCtx

	a.mu.Lock()
	var cancelCtx context.Context
	cancelCtx, a.cancel = context.WithCancel(ctx)
	a.mu.Unlock()

	go func() {
		defer gen.Close()
		if cc != nil {
			// Cancel watcher. It waits on cc.cancelChan with no other exit
			// condition, so it stays blocked if no cancel is ever signalled.
			go func() {
				<-cc.cancelChan
				// CRITICAL: call onCancel BEFORE cancel() to avoid race condition.
				// If cancel() fires first, the runFunc returns immediately,
				// flowAgent's defer calls markDone(), and doneChan closes
				// before onCancel can read cc.config.
				if a.onCancel != nil {
					a.onCancel(cc)
				}
				a.mu.Lock()
				if a.cancel != nil {
					a.cancel()
				}
				a.mu.Unlock()
			}()
		}

		output, err := a.runFunc(cancelCtx, input)
		if err != nil {
			gen.Send(&AgentEvent{Err: err})
			return
		}
		gen.Send(&AgentEvent{Output: output})
	}()
	return iter
}
// turnLoopStopModeProbeAgent is a test Agent that hands the cancel context it
// was run with to the test through ccCh, and only finishes once that cancel
// context reports CancelImmediate.
type turnLoopStopModeProbeAgent struct {
	ccCh chan *cancelContext
}

// Name returns the fixed name "probe".
func (a *turnLoopStopModeProbeAgent) Name(_ context.Context) string { return "probe" }

// Description returns the fixed description "probe".
func (a *turnLoopStopModeProbeAgent) Description(_ context.Context) string { return "probe" }

// Run publishes the cancel context on ccCh, then waits for a cancel signal and
// polls the cancel mode every millisecond until it equals CancelImmediate, at
// which point it emits the cancel error and closes the stream.
func (a *turnLoopStopModeProbeAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] {
	iter, gen := NewAsyncIteratorPair[*AgentEvent]()
	o := getCommonOptions(nil, opts...)
	cc := o.cancelCtx
	// This send happens on the caller's goroutine, before Run returns.
	a.ccCh <- cc
	go func() {
		defer gen.Close()
		<-cc.cancelChan
		for {
			if cc.getMode() == CancelImmediate {
				gen.Send(&AgentEvent{Err: cc.createCancelError()})
				return
			}
			time.Sleep(1 * time.Millisecond)
		}
	}()
	return iter
}

// newAndRunTurnLoop builds a TurnLoop from cfg and starts it right away.
func newAndRunTurnLoop[T any](ctx context.Context, cfg TurnLoopConfig[T]) *TurnLoop[T] {
	l := NewTurnLoop(cfg)
	l.Run(ctx)
	return l
}

// newPreemptTestLoop starts a string loop that always runs agent, consuming
// one item per turn. It pushes "first" and returns only after the agent's
// runFunc has been entered (failing the test after one second otherwise), so
// callers can preempt a turn that is known to be in flight.
func newPreemptTestLoop(t *testing.T, agent *turnLoopCancellableMockAgent) *TurnLoop[string] {
	t.Helper()

	agentStarted := make(chan struct{})
	agentStartedOnce := sync.Once{}

	// Wrap the agent's runFunc so the first invocation signals agentStarted.
	originalRunFunc := agent.runFunc
	agent.runFunc = func(ctx context.Context, input *AgentInput) (*AgentOutput, error) {
		agentStartedOnce.Do(func() { close(agentStarted) })
		return originalRunFunc(ctx, input)
	}

	loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{
		PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) {
			return agent, nil
		},
		GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) {
			return &GenInputResult[string]{
				Input:     &AgentInput{Messages: []Message{schema.UserMessage(items[0])}},
				Consumed:  []string{items[0]},
				Remaining: items[1:],
			}, nil
		},
	})

	loop.Push("first")

	select {
	case <-agentStarted:
	case <-time.After(1 * time.Second):
		t.Fatal("agent did not start")
	}

	return loop
}
// TestTurnLoop_RunAndPush pushes two items into a running loop and checks that
// GenInput saw at least one of them and that a plain Stop exits without error.
func TestTurnLoop_RunAndPush(t *testing.T) {
	processedItems := make([]string, 0)
	var mu sync.Mutex

	loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{
		GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) {
			mu.Lock()
			processedItems = append(processedItems, items...)
			mu.Unlock()
			return &GenInputResult[string]{
				Input:    &AgentInput{Messages: []Message{schema.UserMessage(items[0])}},
				Consumed: items,
			}, nil
		},
		PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) {
			return &turnLoopMockAgent{name: "test"}, nil
		},
	})

	loop.Push("msg1")
	loop.Push("msg2")

	// Give the loop time to pick the items up before stopping it.
	time.Sleep(100 * time.Millisecond)

	loop.Stop()
	result := loop.Wait()

	mu.Lock()
	defer mu.Unlock()

	assert.NoError(t, result.ExitReason)
	assert.NotEmpty(t, processedItems, "should have processed at least one item")
}

// TestTurnLoop_PushReturnsErrorAfterStop checks that Push reports false once
// the loop has been stopped.
func TestTurnLoop_PushReturnsErrorAfterStop(t *testing.T) {
	loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{
		GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) {
			return &GenInputResult[string]{
				Input:    &AgentInput{},
				Consumed: items,
			}, nil
		},
		PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) {
			return &turnLoopMockAgent{name: "test"}, nil
		},
	})

	loop.Stop()

	ok, _ := loop.Push("msg1")
	assert.False(t, ok)
}

// TestTurnLoop_StopIsIdempotent checks that calling Stop several times is
// harmless and the loop still exits without error.
func TestTurnLoop_StopIsIdempotent(t *testing.T) {
	loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{
		GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) {
			return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil
		},
		PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) {
			return &turnLoopMockAgent{name: "test"}, nil
		},
	})

	loop.Stop()
	loop.Stop()
	loop.Stop()

	result := loop.Wait()
	assert.NoError(t, result.ExitReason)
}
// TestTurnLoop_WaitMultipleGoroutines checks that concurrent Wait callers all
// receive the same exit state.
func TestTurnLoop_WaitMultipleGoroutines(t *testing.T) {
	loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{
		GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) {
			return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil
		},
		PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) {
			return &turnLoopMockAgent{name: "test"}, nil
		},
	})

	loop.Stop()

	var wg sync.WaitGroup
	results := make([]*TurnLoopExitState[string], 3)

	for i := 0; i < 3; i++ {
		// Per-iteration copy so each goroutine writes to its own slot.
		i := i
		wg.Add(1)
		go func() {
			defer wg.Done()
			results[i] = loop.Wait()
		}()
	}

	wg.Wait()

	assert.Equal(t, results[0], results[1])
	assert.Equal(t, results[1], results[2])
}

// TestTurnLoop_UnhandledItemsOnStop stops the loop while GenInput is blocked
// and checks that the exit state reports unhandled items.
func TestTurnLoop_UnhandledItemsOnStop(t *testing.T) {
	started := make(chan struct{})
	blocked := make(chan struct{})

	loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{
		GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) {
			// started is closed without a guard, so this callback must run at
			// most once in this test.
			close(started)
			<-blocked
			return &GenInputResult[string]{
				Input:     &AgentInput{},
				Consumed:  items[:1],
				Remaining: items[1:],
			}, nil
		},
		PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) {
			return &turnLoopMockAgent{name: "test"}, nil
		},
	})

	loop.Push("msg1")
	loop.Push("msg2")
	loop.Push("msg3")

	<-started

	// Stop first, then release GenInput, so the loop sees the stop as soon as
	// planning returns.
	loop.Stop()
	close(blocked)

	result := loop.Wait()
	assert.NotEmpty(t, result.UnhandledItems, "should return unhandled items")
}

// TestTurnLoop_GenInputError checks that an error returned by GenInput becomes
// the loop's ExitReason.
func TestTurnLoop_GenInputError(t *testing.T) {
	genErr := errors.New("gen input error")

	loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{
		GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) {
			return nil, genErr
		},
		PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) {
			return &turnLoopMockAgent{name: "test"}, nil
		},
	})

	loop.Push("msg1")

	result := loop.Wait()
	assert.ErrorIs(t, result.ExitReason, genErr)
}
// TestTurnLoop_GetAgentError checks that an error returned by PrepareAgent
// becomes the loop's ExitReason.
func TestTurnLoop_GetAgentError(t *testing.T) {
	agentErr := errors.New("get agent error")

	loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{
		GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) {
			return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil
		},
		PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) {
			return nil, agentErr
		},
	})

	loop.Push("msg1")

	result := loop.Wait()
	assert.ErrorIs(t, result.ExitReason, agentErr)
}

// TestTurnLoop_BatchProcessing pushes three items while GenInput consumes one
// item per turn, and checks that at least one batch reached GenInput.
func TestTurnLoop_BatchProcessing(t *testing.T) {
	var batches [][]string
	var mu sync.Mutex

	loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{
		GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) {
			mu.Lock()
			batches = append(batches, items)
			mu.Unlock()

			return &GenInputResult[string]{
				Input:     &AgentInput{},
				Consumed:  items[:1],
				Remaining: items[1:],
			}, nil
		},
		PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) {
			return &turnLoopMockAgent{name: "test"}, nil
		},
	})

	loop.Push("msg1")
	loop.Push("msg2")
	loop.Push("msg3")

	// Give the loop time to work through the items before stopping it.
	time.Sleep(200 * time.Millisecond)

	loop.Stop()
	loop.Wait()

	mu.Lock()
	defer mu.Unlock()

	assert.NotEmpty(t, batches, "should have processed at least one batch")
}

// TestTurnLoop_StopWithMode checks that Stop(WithGraceful()) on a loop with no
// work exits without error.
func TestTurnLoop_StopWithMode(t *testing.T) {
	loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{
		GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) {
			return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil
		},
		PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) {
			return &turnLoopMockAgent{name: "test"}, nil
		},
	})

	loop.Stop(WithGraceful())

	result := loop.Wait()
	assert.NoError(t, result.ExitReason)
}
// TestTurnLoop_Preempt_CancelsCurrentAgent checks that a preemptive Push
// cancels the in-flight agent run and that the loop then plans a second turn
// (GenInput is called again).
func TestTurnLoop_Preempt_CancelsCurrentAgent(t *testing.T) {
	agentStarted := make(chan struct{})
	agentCancelled := make(chan struct{})
	agentStartedOnce := sync.Once{}
	agentCancelledOnce := sync.Once{}

	// The agent blocks until its context is cancelled; the Once guards keep
	// the channel closes safe across repeated runs.
	agent := &turnLoopCancellableMockAgent{
		name: "test",
		runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) {
			agentStartedOnce.Do(func() {
				close(agentStarted)
			})
			<-ctx.Done()
			agentCancelledOnce.Do(func() {
				close(agentCancelled)
			})
			return &AgentOutput{}, nil
		},
	}

	genInputCalls := int32(0)
	secondGenInputCalled := make(chan struct{})
	secondGenInputOnce := sync.Once{}

	loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{
		PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) {
			return agent, nil
		},
		GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) {
			count := atomic.AddInt32(&genInputCalls, 1)
			if count >= 2 {
				secondGenInputOnce.Do(func() {
					close(secondGenInputCalled)
				})
			}
			return &GenInputResult[string]{
				Input:     &AgentInput{Messages: []Message{schema.UserMessage(items[0])}},
				Consumed:  []string{items[0]},
				Remaining: items[1:],
			}, nil
		},
	})

	loop.Push("first")

	select {
	case <-agentStarted:
	case <-time.After(1 * time.Second):
		t.Fatal("agent did not start")
	}

	loop.Push("urgent", WithPreempt[string](AnySafePoint))

	select {
	case <-agentCancelled:
	case <-time.After(1 * time.Second):
		t.Fatal("agent was not cancelled by preempt")
	}

	select {
	case <-secondGenInputCalled:
	case <-time.After(1 * time.Second):
		t.Fatal("second GenInput was not called after preempt")
	}

	loop.Stop()
	result := loop.Wait()
	assert.NoError(t, result.ExitReason)
	assert.GreaterOrEqual(t, atomic.LoadInt32(&genInputCalls), int32(2))
}
// TestTurnLoop_Preempt_DiscardsConsumedItems checks that after a preempt the
// item consumed by the cancelled turn ("first") is not offered to GenInput
// again, while the preempting item ("urgent") is.
func TestTurnLoop_Preempt_DiscardsConsumedItems(t *testing.T) {
	agentStarted := make(chan struct{})
	agentDone := make(chan struct{})
	agentStartedOnce := sync.Once{}
	agentDoneOnce := sync.Once{}
	firstAgentRun := true
	var firstRunMu sync.Mutex

	genInputResults := make([][]string, 0)
	var mu sync.Mutex

	// The first run blocks until cancelled; any later run returns at once and
	// signals agentDone.
	agent := &turnLoopCancellableMockAgent{
		name: "test",
		runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) {
			firstRunMu.Lock()
			isFirst := firstAgentRun
			firstAgentRun = false
			firstRunMu.Unlock()

			if isFirst {
				agentStartedOnce.Do(func() {
					close(agentStarted)
				})
				<-ctx.Done()
			} else {
				agentDoneOnce.Do(func() {
					close(agentDone)
				})
			}
			return &AgentOutput{}, nil
		},
	}

	loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{
		PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) {
			return agent, nil
		},
		GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) {
			// Record every batch so the assertions can inspect the second one.
			mu.Lock()
			genInputResults = append(genInputResults, items)
			mu.Unlock()

			return &GenInputResult[string]{
				Input:     &AgentInput{Messages: []Message{schema.UserMessage(items[0])}},
				Consumed:  []string{items[0]},
				Remaining: items[1:],
			}, nil
		},
	})

	loop.Push("first")

	select {
	case <-agentStarted:
	case <-time.After(1 * time.Second):
		t.Fatal("agent did not start")
	}

	loop.Push("urgent", WithPreempt[string](AnySafePoint))

	select {
	case <-agentDone:
	case <-time.After(1 * time.Second):
		t.Fatal("second agent run did not complete")
	}

	loop.Stop()
	result := loop.Wait()
	assert.NoError(t, result.ExitReason)

	mu.Lock()
	defer mu.Unlock()
	require.GreaterOrEqual(t, len(genInputResults), 2)
	assert.NotContains(t, genInputResults[1], "first")
	assert.Contains(t, genInputResults[1], "urgent")
}
// TestTurnLoop_Preempt_WithAgentCancelMode checks that a preempt pushed with
// AfterToolCalls reaches the agent's cancel context as CancelAfterToolCalls.
func TestTurnLoop_Preempt_WithAgentCancelMode(t *testing.T) {
	cancelFuncCalled := make(chan struct{})
	cancelFuncCalledOnce := sync.Once{}
	// Starts as a different mode so the assertion fails if onCancel never
	// records one.
	firstCancelModeUsed := CancelImmediate
	var cancelModeMu sync.Mutex

	agent := &turnLoopCancellableMockAgent{
		name: "test",
		runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) {
			<-ctx.Done()
			return &AgentOutput{}, nil
		},
		onCancel: func(cc *cancelContext) {
			cancelModeMu.Lock()
			cancelFuncCalledOnce.Do(func() {
				firstCancelModeUsed = cc.getMode()
				close(cancelFuncCalled)
			})
			cancelModeMu.Unlock()
		},
	}

	loop := newPreemptTestLoop(t, agent)

	loop.Push("urgent", WithPreempt[string](AfterToolCalls))

	select {
	case <-cancelFuncCalled:
	case <-time.After(1 * time.Second):
		t.Fatal("cancelFunc was not called by preempt")
	}

	loop.Stop()
	result := loop.Wait()
	assert.NoError(t, result.ExitReason)
	cancelModeMu.Lock()
	actualMode := firstCancelModeUsed
	cancelModeMu.Unlock()
	assert.Equal(t, CancelAfterToolCalls, actualMode)
}

// TestTurnLoop_PreemptAck_ClosesAfterCancelIsInitiated checks that the ack
// channel returned by a preemptive Push is closed and that the agent observes
// the cancel, even though the agent itself is held open by agentFinishGate.
func TestTurnLoop_PreemptAck_ClosesAfterCancelIsInitiated(t *testing.T) {
	cancelObserved := make(chan struct{})
	agentFinishGate := make(chan struct{})
	cancelObservedOnce := sync.Once{}

	agent := &turnLoopCancellableMockAgent{
		name: "test",
		runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) {
			<-ctx.Done()
			// Keep the turn alive after cancellation until the test opens
			// the gate.
			<-agentFinishGate
			return &AgentOutput{}, nil
		},
		onCancel: func(cc *cancelContext) {
			cancelObservedOnce.Do(func() { close(cancelObserved) })
		},
	}

	loop := newPreemptTestLoop(t, agent)

	ok, ack := loop.Push("urgent", WithPreempt[string](AfterToolCalls))
	assert.True(t, ok)
	assert.NotNil(t, ack)

	select {
	case <-ack:
	case <-time.After(1 * time.Second):
		t.Fatal("preempt ack was not closed")
	}

	select {
	case <-cancelObserved:
	case <-time.After(1 * time.Second):
		t.Fatal("cancel was not initiated")
	}

	close(agentFinishGate)

	loop.Stop()
	result := loop.Wait()
	assert.NoError(t, result.ExitReason)
}
context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + ok, ack := loop.Push("urgent", WithPreempt[string](AnySafePoint)) + assert.True(t, ok) + assert.NotNil(t, ack) + + select { + case <-ack: + case <-time.After(1 * time.Second): + t.Fatal("preempt ack was not closed") + } +} + +func TestTurnLoop_Preempt_EscalatesOnSecondPreempt(t *testing.T) { + firstCancelSeen := make(chan struct{}) + agentFinishGate := make(chan struct{}) + firstCancelOnce := sync.Once{} + + var ccPtr atomic.Value + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + <-ctx.Done() + <-agentFinishGate + return &AgentOutput{}, nil + }, + onCancel: func(cc *cancelContext) { + ccPtr.Store(cc) + firstCancelOnce.Do(func() { close(firstCancelSeen) }) + }, + } + + loop := newPreemptTestLoop(t, agent) + + loop.Push("urgent1", WithPreempt[string](AfterChatModel)) + select { + case <-firstCancelSeen: + case <-time.After(1 * time.Second): + t.Fatal("first preempt did not trigger cancel") + } + + loop.Push("urgent2", WithPreemptTimeout[string](AnySafePoint, time.Millisecond)) + + wantMode := CancelAfterChatModel | CancelAfterToolCalls + deadline := time.Now().Add(1 * time.Second) + for time.Now().Before(deadline) { + v := ccPtr.Load() + if v == nil { + time.Sleep(5 * time.Millisecond) + continue + } + cc := v.(*cancelContext) + if cc.getMode() == wantMode && atomic.LoadInt32(&cc.escalated) == 1 { + break + } + time.Sleep(5 * time.Millisecond) + } + + v := ccPtr.Load() + if v == nil { + t.Fatal("cancel context was not captured") + } + cc := v.(*cancelContext) + assert.Equal(t, wantMode, cc.getMode()) + assert.Equal(t, int32(1), 
atomic.LoadInt32(&cc.escalated)) + + close(agentFinishGate) + + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) +} + +func TestTurnLoop_Preempt_JoinsSafePointModesOnSecondPreempt(t *testing.T) { + firstCancelSeen := make(chan struct{}) + agentFinishGate := make(chan struct{}) + firstCancelOnce := sync.Once{} + + var ccPtr atomic.Value + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + <-ctx.Done() + <-agentFinishGate + return &AgentOutput{}, nil + }, + onCancel: func(cc *cancelContext) { + ccPtr.Store(cc) + firstCancelOnce.Do(func() { close(firstCancelSeen) }) + }, + } + + loop := newPreemptTestLoop(t, agent) + + loop.Push("urgent1", WithPreempt[string](AfterChatModel)) + select { + case <-firstCancelSeen: + case <-time.After(1 * time.Second): + t.Fatal("first preempt did not trigger cancel") + } + + loop.Push("urgent2", WithPreempt[string](AfterToolCalls)) + + want := CancelAfterChatModel | CancelAfterToolCalls + deadline := time.Now().Add(1 * time.Second) + for time.Now().Before(deadline) { + v := ccPtr.Load() + if v == nil { + time.Sleep(5 * time.Millisecond) + continue + } + cc := v.(*cancelContext) + if cc.getMode() == want { + break + } + time.Sleep(5 * time.Millisecond) + } + + v := ccPtr.Load() + if v == nil { + t.Fatal("cancel context was not captured") + } + cc := v.(*cancelContext) + assert.Equal(t, want, cc.getMode()) + + close(agentFinishGate) + + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) +} + +func TestTurnLoop_Push_WithoutPreempt_DoesNotCancel(t *testing.T) { + agentRunCount := 0 + agentDone := make(chan struct{}) + + agent := &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentRunCount++ + if agentRunCount == 1 { + time.Sleep(100 * time.Millisecond) + } + if agentRunCount == 2 { + close(agentDone) + } + return &AgentOutput{}, nil + 
}, + } + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil + }, + }) + + loop.Push("first") + time.Sleep(20 * time.Millisecond) + loop.Push("second") + + select { + case <-agentDone: + case <-time.After(1 * time.Second): + t.Fatal("second agent run did not complete") + } + + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.Equal(t, 2, agentRunCount) +} + +func TestTurnLoop_PreemptDelay_NoMispreemptOnNaturalCompletion(t *testing.T) { + agent1Started := make(chan struct{}) + agent1Done := make(chan struct{}) + agent2Started := make(chan struct{}) + agent2Done := make(chan struct{}) + agent1StartedOnce := sync.Once{} + agent1DoneOnce := sync.Once{} + agent2StartedOnce := sync.Once{} + agent2DoneOnce := sync.Once{} + + var agentRunCount int32 + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + count := atomic.AddInt32(&agentRunCount, 1) + if count == 1 { + agent1StartedOnce.Do(func() { close(agent1Started) }) + time.Sleep(50 * time.Millisecond) + agent1DoneOnce.Do(func() { close(agent1Done) }) + } else if count == 2 { + agent2StartedOnce.Do(func() { close(agent2Started) }) + time.Sleep(100 * time.Millisecond) + select { + case <-ctx.Done(): + t.Error("Agent2 was unexpectedly cancelled") + return nil, ctx.Err() + default: + } + agent2DoneOnce.Do(func() { close(agent2Done) }) + } + return &AgentOutput{}, nil + }, + } + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx 
context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil + }, + }) + + loop.Push("first") + + select { + case <-agent1Started: + case <-time.After(1 * time.Second): + t.Fatal("agent1 did not start") + } + + loop.Push("second", WithPreempt[string](AnySafePoint), WithPreemptDelay[string](500*time.Millisecond)) + + select { + case <-agent1Done: + case <-time.After(1 * time.Second): + t.Fatal("agent1 did not complete naturally") + } + + select { + case <-agent2Started: + case <-time.After(1 * time.Second): + t.Fatal("agent2 did not start") + } + + select { + case <-agent2Done: + case <-time.After(1 * time.Second): + t.Fatal("agent2 did not complete - may have been incorrectly preempted") + } + + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.Equal(t, int32(2), atomic.LoadInt32(&agentRunCount)) +} + +func TestTurnLoop_ConcurrentPush(t *testing.T) { + var count int32 + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + atomic.AddInt32(&count, int32(len(items))) + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + for j := 0; j < 10; j++ { + _, _ = loop.Push(fmt.Sprintf("msg-%d-%d", i, j)) + } + }(i) + } + + wg.Wait() + time.Sleep(200 * time.Millisecond) + + loop.Stop() + result := loop.Wait() + + 
processed := atomic.LoadInt32(&count) + unhandled := len(result.UnhandledItems) + + assert.True(t, processed > 0, "should have processed some items") + assert.True(t, int(processed)+unhandled <= 100, "total should not exceed pushed amount") +} + +func TestTurnLoop_StopAfterReceive_RecoverItem(t *testing.T) { + receiveStarted := make(chan struct{}) + cancelDone := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + close(receiveStarted) + <-cancelDone + time.Sleep(50 * time.Millisecond) + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + <-receiveStarted + + loop.Stop() + close(cancelDone) + + result := loop.Wait() + assert.NoError(t, result.ExitReason) +} + +func TestTurnLoop_StopAfterGenInput_RecoverConsumed(t *testing.T) { + genInputDone := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + close(genInputDone) + time.Sleep(50 * time.Millisecond) + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items[:1], + Remaining: items[1:], + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + time.Sleep(100 * time.Millisecond) + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + loop.Push("msg2") + + <-genInputDone + + time.Sleep(60 * time.Millisecond) + loop.Stop() + + result := loop.Wait() + assert.NoError(t, result.ExitReason) +} + +func TestTurnLoop_GetAgentError_RecoverConsumed(t *testing.T) { + agentErr := errors.New("get agent error") + + loop := 
newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items[:1], + Remaining: items[1:], + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + return nil, agentErr + }, + }) + + loop.Push("msg1") + loop.Push("msg2") + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, agentErr) + assert.NotEmpty(t, result.UnhandledItems, "should recover at least the consumed item and remaining") +} + +func TestTurnLoop_GenInputError_RecoverItems(t *testing.T) { + genErr := errors.New("gen input error") + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return nil, genErr + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + loop.Push("msg2") + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, genErr) + assert.Len(t, result.UnhandledItems, 2, "should recover all items when GenInput fails") + assert.Contains(t, result.UnhandledItems, "msg1") + assert.Contains(t, result.UnhandledItems, "msg2") +} + +func TestTurnLoop_PrepareAgentError_RecoverItemsInOrder(t *testing.T) { + agentErr := errors.New("prepare agent error") + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + var urgent string + remaining := make([]string, 0, len(items)) + for _, item := range items { + if item == "urgent" { + urgent = item + } else { + remaining = append(remaining, item) + } + } + if urgent != "" { + return &GenInputResult[string]{ + Input: 
&AgentInput{}, + Consumed: []string{urgent}, + Remaining: remaining, + }, nil + } + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items[:1], + Remaining: items[1:], + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return nil, agentErr + }, + }) + + loop.Push("msg1") + loop.Push("urgent") + loop.Push("msg2") + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, agentErr) + assert.Len(t, result.UnhandledItems, 3, "should recover all items") + assert.Equal(t, []string{"msg1", "urgent", "msg2"}, result.UnhandledItems, + "should preserve original push order even when GenInput selects non-prefix items") +} + +// Context cancel tests: the TurnLoop monitors context cancellation by closing +// the internal buffer when ctx.Done() fires, which unblocks the blocking +// Receive() call. The loop then checks ctx.Err() and exits with the context error. + +func TestTurnLoop_ContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + genInputStarted := make(chan struct{}) + genInputDone := make(chan struct{}) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + close(genInputStarted) + <-genInputDone + if err := ctx.Err(); err != nil { + return nil, err + } + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + + <-genInputStarted + cancel() + close(genInputDone) + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, context.Canceled) +} + +func TestTurnLoop_ContextDeadlineExceeded(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + loop := newAndRunTurnLoop(ctx, 
TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + select { + case <-time.After(100 * time.Millisecond): + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + case <-ctx.Done(): + return nil, ctx.Err() + } + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, context.DeadlineExceeded) +} + +func TestTurnLoop_ContextCancelBeforeReceive(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + // Push before Run to guarantee the item is buffered before the + // context-monitoring goroutine can close the buffer. + _, _ = loop.Push("msg1") + loop.Run(ctx) + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, context.Canceled) + assert.Len(t, result.UnhandledItems, 1) +} + +func TestTurnLoop_ContextCancelDuringBlockingReceive(t *testing.T) { + // When context is cancelled while Receive() is blocking (no items in buffer), + // the context monitoring goroutine closes the buffer, which unblocks Receive(). 
+ ctx, cancel := context.WithCancel(context.Background()) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + // Don't push any items — let Receive() block + time.Sleep(50 * time.Millisecond) + cancel() + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, context.Canceled) +} + +func TestTurnLoop_ContextCancelAfterGenInput_RecoverItems(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + genInputCount := 0 + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + genInputCount++ + if genInputCount == 1 { + cancel() + } + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items[:1], + Remaining: items[1:], + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], c []string) (Agent, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + loop.Push("msg2") + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, context.Canceled) + assert.NotEmpty(t, result.UnhandledItems, "should recover consumed and remaining items") +} + +func TestTurnLoop_OnAgentEventsReceivesEvents(t *testing.T) { + var receivedEvents []*AgentEvent + var receivedConsumed []string + var mu sync.Mutex + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: 
[]Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + mu.Lock() + receivedConsumed = append(receivedConsumed, tc.Consumed...) + mu.Unlock() + + for { + event, ok := events.Next() + if !ok { + break + } + mu.Lock() + receivedEvents = append(receivedEvents, event) + mu.Unlock() + } + return nil + }, + }) + + loop.Push("msg1") + + time.Sleep(100 * time.Millisecond) + + loop.Stop() + result := loop.Wait() + + assert.NoError(t, result.ExitReason) + + mu.Lock() + defer mu.Unlock() + assert.NotEmpty(t, receivedConsumed, "should have received consumed items") +} + +func TestTurnLoop_StopDuringAgentExecution(t *testing.T) { + agentStarted := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + time.Sleep(200 * time.Millisecond) + for { + _, ok := events.Next() + if !ok { + break + } + } + return nil + }, + }) + + loop.Push("msg1") + + <-agentStarted + loop.Stop() + + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.Empty(t, result.CanceledItems) +} + +func TestTurnLoop_StopCheckPointIDInCancelError(t *testing.T) { + ctx := context.Background() + modelStarted := make(chan struct{}, 1) + checkpointID := 
"turn-loop-cancel-ckpt-1" + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: checkpointID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + + loop.Push("msg1") + + <-modelStarted + loop.Stop() + + result := loop.Wait() + + var cancelErr *CancelError + assert.True(t, errors.As(result.ExitReason, &cancelErr), "ExitReason should be a *CancelError") + + store.mu.Lock() + defer store.mu.Unlock() + _, ok := store.m[checkpointID] + assert.True(t, ok, "checkpoint should be saved under the configured CheckpointID") +} + +func TestTurnLoop_StopWithoutCheckpointIDDoesNotPersist(t *testing.T) { + ctx := context.Background() + modelStarted := make(chan struct{}, 1) + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: 
slowModel, + }) + assert.NoError(t, err) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + + loop.Push("msg1") + + <-modelStarted + loop.Stop() + + result := loop.Wait() + + var cancelErr *CancelError + assert.True(t, errors.As(result.ExitReason, &cancelErr), "ExitReason should be a *CancelError") + + store.mu.Lock() + defer store.mu.Unlock() + assert.Empty(t, store.m, "no checkpoint should be saved when CheckpointID is not configured") +} + +func TestTurnLoop_StopWhileIdle_SkipsCheckpoint(t *testing.T) { + ctx := context.Background() + store := &deletableCheckpointStore{ + turnLoopCheckpointStore: turnLoopCheckpointStore{m: make(map[string][]byte)}, + } + cpID := "idle-session" + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Stop() + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + + store.mu.Lock() + defer store.mu.Unlock() + _, exists := store.m[cpID] + assert.False(t, exists, "no checkpoint should be saved when TurnLoop is idle") +} + +func TestTurnLoop_StopBetweenTurnsAndResume(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "between-turns-session" + + loop := 
NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("a") + loop.Push("b") + loop.Stop() + loop.Run(ctx) + + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + + var seen []string + var mu sync.Mutex + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + mu.Lock() + seen = append([]string{}, items...) + mu.Unlock() + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + _, ok := events.Next() + if !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + + loop2.Push("c") + loop2.Run(ctx) + exit2 := loop2.Wait() + assert.NoError(t, exit2.ExitReason) + + mu.Lock() + defer mu.Unlock() + assert.Equal(t, []string{"a", "b", "c"}, seen) +} + +func TestTurnLoop_StopDuringAgentExecution_PersistAndResume(t *testing.T) { + ctx := context.Background() + modelStarted := make(chan struct{}, 1) + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "mid-turn-session" + + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan 
struct{}, 1), + } + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + + loop.Push("msg1") + <-modelStarted + loop.Stop() + exit := loop.Wait() + + store.mu.Lock() + _, ok := store.m[cpID] + store.mu.Unlock() + assert.True(t, ok) + _ = exit + + slowModel.setDelay(10 * time.Millisecond) + + var consumed2 []string + var genResumeCalled bool + var genInputCalled bool + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenResume: func(ctx context.Context, _ *TurnLoop[string], canceledItems []string, unhandledItems []string, newItems []string) (*GenResumeResult[string], error) { + genResumeCalled = true + return &GenResumeResult[string]{ + Consumed: canceledItems, + Remaining: append(append([]string{}, unhandledItems...), newItems...), + }, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + genInputCalled = true + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + consumed2 = append([]string{}, consumed...) 
+ return agent, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + _, ok := events.Next() + if !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + + loop2.Run(ctx) + exit2 := loop2.Wait() + assert.NoError(t, exit2.ExitReason) + assert.Equal(t, []string{"msg1"}, consumed2) + assert.True(t, genResumeCalled) + assert.False(t, genInputCalled) +} + +func TestTurnLoop_CheckpointIDWithoutStore_FreshStart(t *testing.T) { + ctx := context.Background() + var genInputCalled bool + loop := NewTurnLoop(TurnLoopConfig[string]{ + CheckpointID: "some-id", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + genInputCalled = true + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + loop.Push("a") + loop.Run(ctx) + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + assert.True(t, genInputCalled) +} + +func TestTurnLoop_CheckpointNotFound_FreshStart(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + var genInputCalled bool + loop := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: "nonexistent-id", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + genInputCalled = true + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + 
OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + loop.Push("a") + loop.Run(ctx) + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + assert.True(t, genInputCalled) +} + +func TestTurnLoop_CheckpointEmptyData_TreatedAsNoCheckpoint(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + store.m["cp-empty"] = nil + + var genInputCalled bool + loop := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: "cp-empty", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + genInputCalled = true + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + loop.Push("a") + loop.Run(ctx) + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + assert.True(t, genInputCalled) +} + +type errorCheckpointStore struct { + getErr error + setErr error +} + +func (s *errorCheckpointStore) Get(_ context.Context, _ string) ([]byte, bool, error) { + return nil, false, s.getErr +} + +func (s *errorCheckpointStore) Set(_ context.Context, _ string, _ []byte) error { + return s.setErr +} + +func TestTurnLoop_CheckpointLoadError_ReturnsError(t *testing.T) { + ctx := context.Background() + store := &errorCheckpointStore{getErr: fmt.Errorf("store unavailable")} + loop := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: "cp-1", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) 
(*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Run(ctx) + exit := loop.Wait() + assert.Error(t, exit.ExitReason) + assert.Contains(t, exit.ExitReason.Error(), "store unavailable") +} + +func TestTurnLoop_CheckpointCorruptData_ReturnsError(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + store.m["cp-corrupt"] = []byte("not-valid-gob-data") + loop := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: "cp-corrupt", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Run(ctx) + exit := loop.Wait() + assert.Error(t, exit.ExitReason) + assert.Contains(t, exit.ExitReason.Error(), "failed to unmarshal checkpoint") +} + +func TestTurnLoop_CheckpointSaveError_ReturnsError(t *testing.T) { + ctx := context.Background() + modelStarted := make(chan struct{}, 1) + saveStore := &errorCheckpointStore{setErr: fmt.Errorf("write failed")} + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: saveStore, + 
CheckpointID: "cp-1", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + loop.Push("msg1") + <-modelStarted + loop.Stop() + exit := loop.Wait() + assert.Error(t, exit.ExitReason) + assert.True(t, exit.Checkpointed) + assert.Error(t, exit.CheckpointErr) + assert.Contains(t, exit.CheckpointErr.Error(), "write failed") +} + +func TestTurnLoop_StaleCheckpointDeletion_OnCleanResume(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "stale-session" + + loop1 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop1.Push("a") + loop1.Stop() + loop1.Run(ctx) + loop1.Wait() + + store.mu.Lock() + _, exists := store.m[cpID] + store.mu.Unlock() + assert.True(t, exists, "checkpoint should exist after first loop saves it") + + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: 
func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + loop2.Push("b") + loop2.Run(ctx) + exit2 := loop2.Wait() + assert.NoError(t, exit2.ExitReason) + + store.mu.Lock() + _, exists = store.m[cpID] + store.mu.Unlock() + assert.True(t, exists, "checkpoint should still exist because loop2 was stopped and saved a new one") +} + +func TestTurnLoop_StaleCheckpointDeletion_ContextCancel(t *testing.T) { + ctx := context.Background() + store := &deletableCheckpointStore{turnLoopCheckpointStore: turnLoopCheckpointStore{m: make(map[string][]byte)}} + cpID := "delete-on-cancel" + + loop1 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop1.Push("a") + loop1.Stop() + loop1.Run(ctx) + loop1.Wait() + + store.mu.Lock() + _, exists := store.m[cpID] + store.mu.Unlock() + assert.True(t, exists, "checkpoint saved after loop1") + + ctx2, cancel2 := context.WithCancel(ctx) + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, 
ok := events.Next(); !ok { + break + } + } + cancel2() + return nil + }, + }) + loop2.Push("b") + loop2.Run(ctx2) + exit2 := loop2.Wait() + assert.ErrorIs(t, exit2.ExitReason, context.Canceled) + + store.mu.Lock() + _, exists = store.m[cpID] + deleteCalled := store.deleteCalled + store.mu.Unlock() + assert.True(t, deleteCalled && !exists, "stale checkpoint should be deleted when loop exits via context cancellation") +} + +type deletableCheckpointStore struct { + turnLoopCheckpointStore + deleteCalled bool + deletedKey string +} + +func (s *deletableCheckpointStore) Delete(_ context.Context, key string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.deleteCalled = true + s.deletedKey = key + delete(s.m, key) + return nil +} + +func TestTurnLoop_CheckpointDeleter_CalledOnContextCancel(t *testing.T) { + ctx := context.Background() + store := &deletableCheckpointStore{ + turnLoopCheckpointStore: turnLoopCheckpointStore{m: make(map[string][]byte)}, + } + cpID := "deleter-session" + + loop1 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop1.Push("a") + loop1.Stop() + loop1.Run(ctx) + loop1.Wait() + + store.mu.Lock() + _, exists := store.m[cpID] + store.mu.Unlock() + assert.True(t, exists, "checkpoint saved after loop1") + + ctx2, cancel2 := context.WithCancel(ctx) + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + 
PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + cancel2() + return nil + }, + }) + loop2.Push("b") + loop2.Run(ctx2) + exit2 := loop2.Wait() + assert.ErrorIs(t, exit2.ExitReason, context.Canceled) + + store.mu.Lock() + defer store.mu.Unlock() + assert.True(t, store.deleteCalled, "CheckPointDeleter.Delete should be called") + assert.Equal(t, cpID, store.deletedKey) + _, exists = store.m[cpID] + assert.False(t, exists, "checkpoint should be removed from store") +} + +func TestTurnLoop_GenResumeNil_Error(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "resume-nil-session" + modelStarted := make(chan struct{}, 1) + + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loop1 := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + loop1.Push("msg1") + <-modelStarted + loop1.Stop() + loop1.Wait() + + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + 
Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop2.Run(ctx) + exit2 := loop2.Wait() + assert.Error(t, exit2.ExitReason) + assert.Contains(t, exit2.ExitReason.Error(), "GenResume is required") +} + +func TestTurnLoop_SameCheckpointID_OverwritePattern(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "overwrite-session" + + loop1 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop1.Push("a") + loop1.Push("b") + loop1.Stop() + loop1.Run(ctx) + loop1.Wait() + + store.mu.Lock() + data1 := append([]byte{}, store.m[cpID]...) + store.mu.Unlock() + assert.NotEmpty(t, data1) + + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop2.Push("c") + loop2.Stop() + loop2.Run(ctx) + loop2.Wait() + + store.mu.Lock() + data2 := append([]byte{}, store.m[cpID]...) 
+ store.mu.Unlock() + assert.NotEmpty(t, data2) + assert.NotEqual(t, data1, data2, "checkpoint data should change because items are different") + + var seen []string + var mu sync.Mutex + loop3 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + mu.Lock() + seen = append([]string{}, items...) + mu.Unlock() + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + loop3.Push("d") + loop3.Run(ctx) + exit3 := loop3.Wait() + assert.NoError(t, exit3.ExitReason) + + mu.Lock() + defer mu.Unlock() + assert.Equal(t, []string{"a", "b", "c", "d"}, seen, "should see loop2's unhandled items (a,b,c from loop2's checkpoint) plus new d") +} + +func TestTurnLoop_CheckpointHasRunnerStateButEmptyBytes(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "empty-runner-bytes" + + cp := &turnLoopCheckpoint[string]{ + HasRunnerState: true, + RunnerCheckpoint: nil, + UnhandledItems: []string{"x"}, + } + data, err := marshalTurnLoopCheckpoint(cp) + assert.NoError(t, err) + store.m[cpID] = data + + loop := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) 
(Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Run(ctx) + exit := loop.Wait() + assert.Error(t, exit.ExitReason) + assert.Contains(t, exit.ExitReason.Error(), "has runner state but bytes are empty") +} + +func TestTurnLoop_GenResumeReturnsError(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "resume-err-session" + modelStarted := make(chan struct{}, 1) + + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loop1 := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + loop1.Push("msg1") + <-modelStarted + loop1.Stop() + loop1.Wait() + + genResumeErr := fmt.Errorf("resume callback failed") + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + GenResume: func(ctx context.Context, _ *TurnLoop[string], canceled, unhandled, newItems []string) (*GenResumeResult[string], error) { + return nil, genResumeErr + }, + PrepareAgent: func(ctx 
context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop2.Run(ctx) + exit2 := loop2.Wait() + assert.Error(t, exit2.ExitReason) + assert.ErrorIs(t, exit2.ExitReason, genResumeErr) +} + +func TestTurnLoop_CheckpointSaveError_MergesWithExistingError(t *testing.T) { + ctx := context.Background() + modelStarted := make(chan struct{}, 1) + saveStore := &errorCheckpointStore{setErr: fmt.Errorf("disk full")} + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: saveStore, + CheckpointID: "cp-merge-err", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + loop.Push("msg1") + <-modelStarted + loop.Stop() + exit := loop.Wait() + assert.Error(t, exit.ExitReason) + var ce *CancelError + assert.True(t, errors.As(exit.ExitReason, &ce), "ExitReason should be CancelError, not merged with checkpoint error") + assert.True(t, exit.Checkpointed) + assert.Error(t, exit.CheckpointErr) + assert.Contains(t, exit.CheckpointErr.Error(), "disk full") +} + +func TestTurnLoop_ResumeWithParams(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "resume-params-session" + modelStarted := 
make(chan struct{}, 1) + + slowModel := &cancelTestChatModel{ + delayNs: int64(500 * time.Millisecond), + response: &schema.Message{ + Role: schema.Assistant, + Content: "Hello", + }, + startedChan: modelStarted, + doneChan: make(chan struct{}, 1), + } + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Instruction: "You are a test assistant", + Model: slowModel, + }) + assert.NoError(t, err) + + loop1 := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + }) + loop1.Push("msg1") + <-modelStarted + loop1.Stop() + exit1 := loop1.Wait() + var ce *CancelError + assert.True(t, errors.As(exit1.ExitReason, &ce)) + + var resumeParamsUsed *ResumeParams + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + GenResume: func(ctx context.Context, _ *TurnLoop[string], canceled, unhandled, newItems []string) (*GenResumeResult[string], error) { + params := &ResumeParams{ + Targets: map[string]any{"some-address": "user-data"}, + } + resumeParamsUsed = params + return &GenResumeResult[string]{ + ResumeParams: params, + Consumed: append(append(canceled, unhandled...), newItems...), + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events 
*AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + }) + loop2.Run(ctx) + exit2 := loop2.Wait() + assert.NotNil(t, resumeParamsUsed, "GenResume should have been called with ResumeParams") + assert.Contains(t, resumeParamsUsed.Targets, "some-address") + _ = exit2 +} + +func TestTurnLoop_Stop_EscalatesCancelMode(t *testing.T) { + ctx := context.Background() + agentStarted := make(chan *cancelContext, 1) + probe := &turnLoopStopModeProbeAgent{ccCh: agentStarted} + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return probe, nil + }, + }) + + loop.Push("msg1") + cc := <-agentStarted + + loop.Stop(WithGracefulTimeout(10 * time.Second)) + loop.Stop(WithImmediate()) + + deadline := time.After(1 * time.Second) + for { + if cc.getMode() == CancelImmediate { + break + } + select { + case <-deadline: + t.Fatal("cancel mode did not escalate to CancelImmediate") + default: + } + time.Sleep(1 * time.Millisecond) + } + + exit := loop.Wait() + var ce *CancelError + require.True(t, errors.As(exit.ExitReason, &ce)) + assert.Equal(t, CancelImmediate, ce.Info.Mode) +} + +func TestTurnLoop_DefaultOnAgentEvents_ErrorPropagation(t *testing.T) { + agentErr := errors.New("agent execution error") + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ 
*TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + return nil, agentErr + }, + }, nil + }, + // No OnAgentEvents — use default handler + }) + + loop.Push("msg1") + + result := loop.Wait() + // The default handler should propagate the agent error as ExitReason + assert.Error(t, result.ExitReason) +} + +func TestTurnLoop_OnAgentEventsError(t *testing.T) { + handlerErr := errors.New("event handler error") + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + // Drain events then return error + for { + _, ok := events.Next() + if !ok { + break + } + } + return handlerErr + }, + }) + + loop.Push("msg1") + + result := loop.Wait() + assert.ErrorIs(t, result.ExitReason, handlerErr) +} + +func TestTurnLoop_StopCallFromGenInput(t *testing.T) { + // Test that calling Stop() from within GenInput works correctly + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, loop *TurnLoop[string], items []string) (*GenInputResult[string], error) { + loop.Stop() + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + + result := loop.Wait() + assert.NoError(t, 
result.ExitReason) +} + +func TestTurnLoop_PushFromOnAgentEvents(t *testing.T) { + // Test that calling Push() from within OnAgentEvents works + pushCount := int32(0) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + _, ok := events.Next() + if !ok { + break + } + } + count := atomic.AddInt32(&pushCount, 1) + if count == 1 { + // Push a follow-up item from the callback + _, _ = tc.Loop.Push("follow-up") + } else { + tc.Loop.Stop() + } + return nil + }, + }) + + loop.Push("initial") + + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.Equal(t, int32(2), atomic.LoadInt32(&pushCount)) +} + +// Tests for NewTurnLoop: the permissive API where Push, Stop, and Wait are +// all valid on a not-yet-running loop. + +func TestNewTurnLoop_PushBeforeRun(t *testing.T) { + // Items pushed before Run are buffered and processed after Run starts. + var processedItems []string + var mu sync.Mutex + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + mu.Lock() + processedItems = append(processedItems, items...) 
+ mu.Unlock() + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + // Push before Run — items should be buffered. + ok, _ := loop.Push("msg1") + assert.True(t, ok) + ok, _ = loop.Push("msg2") + assert.True(t, ok) + + loop.Run(context.Background()) + + time.Sleep(100 * time.Millisecond) + + loop.Stop() + result := loop.Wait() + + mu.Lock() + defer mu.Unlock() + + assert.NoError(t, result.ExitReason) + assert.Contains(t, processedItems, "msg1") + assert.Contains(t, processedItems, "msg2") +} + +func TestNewTurnLoop_StopBeforeRun(t *testing.T) { + // Stop before Run sets the stopped flag. When Run is called, the loop + // exits immediately and buffered items appear as UnhandledItems. + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + t.Fatal("GenInput should not be called") + return nil, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + t.Fatal("PrepareAgent should not be called") + return nil, nil + }, + }) + + loop.Push("msg1") + loop.Push("msg2") + loop.Stop() + + // Push after Stop returns false. + ok, _ := loop.Push("msg3") + assert.False(t, ok) + + loop.Run(context.Background()) + result := loop.Wait() + + assert.NoError(t, result.ExitReason) + assert.Equal(t, []string{"msg1", "msg2"}, result.UnhandledItems) +} + +func TestNewTurnLoop_WaitBeforeRun(t *testing.T) { + // Wait blocks until Run is called AND the loop exits. 
+ loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + waitDone := make(chan *TurnLoopExitState[string], 1) + go func() { + waitDone <- loop.Wait() + }() + + // Wait should not return yet since Run hasn't been called. + select { + case <-waitDone: + t.Fatal("Wait returned before Run was called") + case <-time.After(50 * time.Millisecond): + // expected + } + + loop.Push("msg1") + loop.Stop() + loop.Run(context.Background()) + + select { + case result := <-waitDone: + assert.NoError(t, result.ExitReason) + assert.Equal(t, []string{"msg1"}, result.UnhandledItems) + case <-time.After(1 * time.Second): + t.Fatal("Wait did not return after Run + Stop") + } +} + +func TestNewTurnLoop_RunIsIdempotent(t *testing.T) { + var genInputCalls int32 + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + atomic.AddInt32(&genInputCalls, 1) + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("msg1") + loop.Run(context.Background()) + loop.Run(context.Background()) + loop.Run(context.Background()) + + time.Sleep(100 * time.Millisecond) + + loop.Stop() + result := loop.Wait() + + assert.NoError(t, result.ExitReason) + assert.True(t, atomic.LoadInt32(&genInputCalls) >= 1) +} + +func TestNewTurnLoop_StopBeforeRun_ThenWait(t *testing.T) { + // Demonstrates the full sequence: create, push, stop, run, wait. 
+ loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + t.Fatal("GenInput should not be called after Stop") + return nil, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + t.Fatal("PrepareAgent should not be called after Stop") + return nil, nil + }, + }) + + loop.Push("a") + loop.Push("b") + loop.Push("c") + loop.Stop() + + // Run after Stop: the loop goroutine starts but exits immediately. + loop.Run(context.Background()) + + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.Equal(t, []string{"a", "b", "c"}, result.UnhandledItems) +} + +func TestNewTurnLoop_ConcurrentPushAndRun(t *testing.T) { + // Concurrent Push and Run should not race. + for i := 0; i < 100; i++ { + var count int32 + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + atomic.AddInt32(&count, int32(len(items))) + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + _, _ = loop.Push("item") + }() + + go func() { + defer wg.Done() + loop.Run(context.Background()) + }() + + wg.Wait() + + time.Sleep(50 * time.Millisecond) + + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + + processed := atomic.LoadInt32(&count) + unhandled := len(result.UnhandledItems) + assert.True(t, int(processed)+unhandled <= 1, + "total should not exceed pushed amount") + } +} + +type turnCtxKey struct{} + +func TestTurnLoop_RunCtx_Propagation(t *testing.T) { + // Verify that GenInputResult.RunCtx is propagated to PrepareAgent, + // the agent run, and 
OnAgentEvents. + + const traceVal = "trace-123" + var prepareCtxVal, agentCtxVal, eventsCtxVal string + + cfg := TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, loop *TurnLoop[string], items []string) (*GenInputResult[string], error) { + // Derive a new context with per-item trace data + runCtx := context.WithValue(ctx, turnCtxKey{}, traceVal) + return &GenInputResult[string]{ + RunCtx: runCtx, + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, loop *TurnLoop[string], consumed []string) (Agent, error) { + if v, ok := ctx.Value(turnCtxKey{}).(string); ok { + prepareCtxVal = v + } + return &turnLoopMockAgent{ + name: "trace-agent", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + if v, ok := ctx.Value(turnCtxKey{}).(string); ok { + agentCtxVal = v + } + return &AgentOutput{}, nil + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + if v, ok := ctx.Value(turnCtxKey{}).(string); ok { + eventsCtxVal = v + } + for { + if _, ok := events.Next(); !ok { + break + } + } + tc.Loop.Stop() + return nil + }, + } + + loop := NewTurnLoop(cfg) + loop.Push("hello") + loop.Run(context.Background()) + result := loop.Wait() + + assert.Nil(t, result.ExitReason) + assert.Equal(t, traceVal, prepareCtxVal, "PrepareAgent should receive RunCtx") + assert.Equal(t, traceVal, agentCtxVal, "Agent run should receive RunCtx") + assert.Equal(t, traceVal, eventsCtxVal, "OnAgentEvents should receive RunCtx") +} + +func TestTurnLoop_TurnContext_PreemptedChannel(t *testing.T) { + preemptedSeen := make(chan struct{}) + agentStarted := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: 
&AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "slow", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + select { + case <-tc.Preempted: + close(preemptedSeen) + case <-time.After(5 * time.Second): + t.Error("timed out waiting for Preempted channel") + } + // Drain events + for { + if _, ok := events.Next(); !ok { + break + } + } + return nil + }, + }) + + loop.Push("msg1") + <-agentStarted + loop.Push("msg2", WithPreemptTimeout[string](AnySafePoint, time.Millisecond)) + + select { + case <-preemptedSeen: + // success + case <-time.After(5 * time.Second): + t.Fatal("preempted channel was never observed in OnAgentEvents") + } + + loop.Stop() + loop.Wait() +} + +// ============================================================================= +// preemptSignal unit tests (direct testing of the hold/preempt/unhold mechanism) +// ============================================================================= + +func TestPreemptSignal_HoldCountLifecycle(t *testing.T) { + s := newPreemptSignal() + + s.holdRunLoop() + s.holdRunLoop() + + done := make(chan bool) + go func() { + preempted, _, _ := s.waitForPreemptOrUnhold() + done <- preempted + }() + + select { + case <-done: + t.Fatal("waitForPreemptOrUnhold should block while holdCount > 0") + case <-time.After(50 * time.Millisecond): + } + + s.unholdRunLoop() + + select { + case <-done: + t.Fatal("waitForPreemptOrUnhold should still block (holdCount=1)") + case <-time.After(50 * time.Millisecond): + } + + s.unholdRunLoop() + + select { + case preempted := <-done: + assert.False(t, preempted, "should return 
not-preempted when all holds released") + case <-time.After(1 * time.Second): + t.Fatal("waitForPreemptOrUnhold should unblock when holdCount reaches 0") + } +} + +func TestPreemptSignal_RequestPreemptWithNoHold(t *testing.T) { + s := newPreemptSignal() + + ack := make(chan struct{}) + s.requestPreempt(ack) + + select { + case <-ack: + case <-time.After(100 * time.Millisecond): + t.Fatal("ack should be closed immediately when holdCount is 0") + } +} + +func TestPreemptSignal_RequestPreemptWakesWaiter(t *testing.T) { + s := newPreemptSignal() + s.holdRunLoop() + + done := make(chan struct { + preempted bool + ackList []chan struct{} + }) + go func() { + preempted, _, ackList := s.waitForPreemptOrUnhold() + done <- struct { + preempted bool + ackList []chan struct{} + }{preempted, ackList} + }() + + ack := make(chan struct{}) + s.requestPreempt(ack) + + select { + case result := <-done: + assert.True(t, result.preempted) + assert.Len(t, result.ackList, 1) + close(result.ackList[0]) + case <-time.After(1 * time.Second): + t.Fatal("waitForPreemptOrUnhold should wake on requestPreempt") + } +} + +func TestPreemptSignal_HoldAndGetTurn(t *testing.T) { + s := newPreemptSignal() + s.setTurn(context.Background(), "turn-A") + + ctx, tc := s.holdAndGetTurn() + assert.NotNil(t, ctx) + assert.Equal(t, "turn-A", tc) + + s.endTurnAndUnhold() + + _, tc2 := s.holdAndGetTurn() + assert.Nil(t, tc2, "TC should be nil after endTurnAndUnhold") + s.unholdRunLoop() +} + +func TestPreemptSignal_EndTurnPreservesSignalWhenHoldRemains(t *testing.T) { + s := newPreemptSignal() + + s.holdRunLoop() + s.holdRunLoop() + + ack := make(chan struct{}) + s.requestPreempt(ack) + + s.endTurnAndUnhold() + + done := make(chan bool) + go func() { + preempted, _, ackList := s.waitForPreemptOrUnhold() + for _, a := range ackList { + close(a) + } + done <- preempted + }() + + select { + case preempted := <-done: + assert.True(t, preempted, "signal state should be preserved when holdCount > 0 after 
endTurnAndUnhold") + case <-time.After(1 * time.Second): + t.Fatal("waiter should see the preserved preempt signal") + } + + select { + case <-ack: + case <-time.After(100 * time.Millisecond): + t.Fatal("ack should have been closed") + } +} + +func TestPreemptSignal_ConcurrentHoldRequestUnhold(t *testing.T) { + s := newPreemptSignal() + + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + s.holdRunLoop() + ack := make(chan struct{}) + s.requestPreempt(ack) + s.unholdRunLoop() + <-ack + }() + } + wg.Wait() +} + +// ============================================================================= +// Integration tests for race-prone preempt scenarios +// ============================================================================= + +func TestTurnLoop_ConcurrentPreemptsDuringTurn(t *testing.T) { + agentStarted := make(chan struct{}) + agentStartedOnce := sync.Once{} + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + <-ctx.Done() + return &AgentOutput{}, nil + }, + } + + var genInputCount int32 + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + atomic.AddInt32(&genInputCount, 1) + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items, + }, nil + }, + }) + + loop.Push("seed") + + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + ok, ack := loop.Push(fmt.Sprintf("urgent-%d", i), WithPreemptTimeout[string](AnySafePoint, time.Millisecond)) + if ok && 
ack != nil { + select { + case <-ack: + case <-time.After(5 * time.Second): + t.Error("ack channel not closed within timeout") + } + } + }(i) + } + + wg.Wait() + time.Sleep(200 * time.Millisecond) + + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.True(t, atomic.LoadInt32(&genInputCount) >= 2, "should have had at least the initial turn + one preempted turn") +} + +func TestTurnLoop_PreemptDuringTurnTransition(t *testing.T) { + turnCount := int32(0) + firstTurnDone := make(chan struct{}) + firstTurnOnce := sync.Once{} + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "fast"}, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + count := atomic.AddInt32(&turnCount, 1) + if count == 1 { + firstTurnOnce.Do(func() { + close(firstTurnDone) + }) + } + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items, + }, nil + }, + }) + + loop.Push("first") + + select { + case <-firstTurnDone: + case <-time.After(1 * time.Second): + t.Fatal("first turn did not start") + } + + time.Sleep(50 * time.Millisecond) + + ok, ack := loop.Push("transitional", WithPreempt[string](AnySafePoint)) + assert.True(t, ok, "push should succeed") + if ack != nil { + select { + case <-ack: + case <-time.After(2 * time.Second): + t.Fatal("ack should be closed even if preempt arrived during/after turn transition") + } + } + + time.Sleep(100 * time.Millisecond) + + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.True(t, atomic.LoadInt32(&turnCount) >= 2, "transitional item should have been processed") +} + +func TestTurnLoop_PushStrategy_DuringTurnTransition(t *testing.T) { + agentStarted := make(chan struct{}) + agentStartedOnce := sync.Once{} + allowFinish := make(chan struct{}) + + agent 
:= &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + select { + case <-allowFinish: + return &AgentOutput{}, nil + case <-ctx.Done(): + return &AgentOutput{}, nil + } + }, + } + + var genInputCount int32 + secondTurnDone := make(chan struct{}) + secondTurnOnce := sync.Once{} + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + count := atomic.AddInt32(&genInputCount, 1) + if count >= 2 { + secondTurnOnce.Do(func() { + close(secondTurnDone) + }) + } + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items, + }, nil + }, + }) + + loop.Push("first") + + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } + + strategyBlocker := make(chan struct{}) + var strategyTCNotNil int32 + + go func() { + loop.Push("strategic-item", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + if tc != nil { + atomic.StoreInt32(&strategyTCNotNil, 1) + } + <-strategyBlocker + return []PushOption[string]{WithPreempt[string](AnySafePoint)} + })) + }() + + time.Sleep(50 * time.Millisecond) + close(allowFinish) + time.Sleep(50 * time.Millisecond) + close(strategyBlocker) + + select { + case <-secondTurnDone: + case <-time.After(3 * time.Second): + t.Fatal("second turn should eventually run after strategy resolves") + } + + loop.Stop() + result := loop.Wait() + assert.NoError(t, result.ExitReason) + assert.True(t, atomic.LoadInt32(&genInputCount) >= 2) +} + +func TestTurnLoop_ConcurrentPreemptAndStop(t *testing.T) { + for iter := 0; iter < 20; iter++ { + t.Run(fmt.Sprintf("iter_%d", iter), 
func(t *testing.T) { + ctx := context.Background() + + agentStarted := make(chan struct{}) + agentStartedOnce := sync.Once{} + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + <-ctx.Done() + return &AgentOutput{}, nil + }, + } + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items, + }, nil + }, + }) + + loop.Push("seed") + + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + _, ack := loop.Push("preempt-item", WithPreempt[string](AnySafePoint)) + if ack != nil { + <-ack + } + }() + + go func() { + defer wg.Done() + loop.Stop() + }() + + wg.Wait() + loop.Wait() + }) + } +} + +func TestTurnLoop_ConcurrentPushStrategyAndStop(t *testing.T) { + for iter := 0; iter < 20; iter++ { + t.Run(fmt.Sprintf("iter_%d", iter), func(t *testing.T) { + ctx := context.Background() + + agentStarted := make(chan struct{}) + agentStartedOnce := sync.Once{} + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + <-ctx.Done() + return &AgentOutput{}, nil + }, + } + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return 
&GenInputResult[string]{ + Input: &AgentInput{}, + Consumed: items, + }, nil + }, + }) + + loop.Push("seed") + + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + _, ack := loop.Push("strategic-item", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + return []PushOption[string]{WithPreempt[string](AnySafePoint)} + })) + if ack != nil { + <-ack + } + }() + + go func() { + defer wg.Done() + loop.Stop() + }() + + wg.Wait() + loop.Wait() + }) + } +} +func TestTurnLoop_TurnContext_StoppedChannel(t *testing.T) { + stoppedSeen := make(chan struct{}) + agentStarted := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "slow", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + select { + case <-tc.Stopped: + close(stoppedSeen) + case <-time.After(5 * time.Second): + t.Error("timed out waiting for Stopped channel") + } + // Drain events + for { + if _, ok := events.Next(); !ok { + break + } + } + return nil + }, + }) + + loop.Push("msg1") + <-agentStarted + loop.Stop() + + select { + case <-stoppedSeen: + // success + case <-time.After(5 * time.Second): + t.Fatal("stopped channel was never observed in OnAgentEvents") + } + + loop.Wait() +} + +func 
TestTurnLoop_TurnContext_BothPreemptedAndStopped(t *testing.T) { + t.Run("PreemptThenStop_OnlyPreemptContributes", func(t *testing.T) { + preemptedSeen := make(chan struct{}) + agentStarted := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "slow", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + select { + case <-tc.Preempted: + close(preemptedSeen) + case <-time.After(5 * time.Second): + t.Error("timed out waiting for Preempted") + } + for { + if _, ok := events.Next(); !ok { + break + } + } + return nil + }, + }) + + loop.Push("msg1") + <-agentStarted + loop.Push("msg2", WithPreemptTimeout[string](AnySafePoint, time.Millisecond)) + + select { + case <-preemptedSeen: + case <-time.After(5 * time.Second): + t.Fatal("Preempted channel was never closed") + } + + loop.Stop(WithImmediate()) + loop.Wait() + }) + + t.Run("StopThenPreempt_OnlyStopContributes", func(t *testing.T) { + stoppedSeen := make(chan struct{}) + agentStarted := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx 
context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "slow", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + select { + case <-tc.Stopped: + close(stoppedSeen) + case <-time.After(5 * time.Second): + t.Error("timed out waiting for Stopped") + } + for { + if _, ok := events.Next(); !ok { + break + } + } + return nil + }, + }) + + loop.Push("msg1") + <-agentStarted + loop.Stop(WithImmediate()) + + select { + case <-stoppedSeen: + case <-time.After(5 * time.Second): + t.Fatal("Stopped channel was never closed") + } + + loop.Push("msg2", WithPreemptTimeout[string](AnySafePoint, time.Millisecond)) + loop.Wait() + }) +} + +func TestTurnLoop_PushStrategy_DuringTurn(t *testing.T) { + agentStarted := make(chan struct{}) + agentStartedOnce := sync.Once{} + agentCancelled := make(chan struct{}) + agentCancelledOnce := sync.Once{} + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + <-ctx.Done() + agentCancelledOnce.Do(func() { + close(agentCancelled) + }) + return &AgentOutput{}, nil + }, + } + + genInputCalls := int32(0) + secondGenInputCalled := make(chan struct{}) + secondGenInputOnce := sync.Once{} + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + count := atomic.AddInt32(&genInputCalls, 1) + if count >= 2 { + secondGenInputOnce.Do(func() { + close(secondGenInputCalled) + }) + } + 
return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil + }, + }) + + loop.Push("first") + + select { + case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } + + // Strategy inspects TurnContext during a running turn and decides to preempt. + var strategyCalled int32 + var strategyTC *TurnContext[string] + loop.Push("urgent", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + atomic.AddInt32(&strategyCalled, 1) + strategyTC = tc + return []PushOption[string]{WithPreempt[string](AnySafePoint)} + })) + + select { + case <-agentCancelled: + case <-time.After(1 * time.Second): + t.Fatal("agent was not cancelled by strategy-returned preempt") + } + + select { + case <-secondGenInputCalled: + case <-time.After(1 * time.Second): + t.Fatal("second GenInput was not called after preempt") + } + + loop.Stop() + loop.Wait() + + assert.Equal(t, int32(1), atomic.LoadInt32(&strategyCalled)) + assert.NotNil(t, strategyTC, "strategy should receive non-nil TurnContext during a turn") + assert.Equal(t, []string{"first"}, strategyTC.Consumed) +} + +func TestTurnLoop_PushStrategy_BetweenTurns(t *testing.T) { + // Push with strategy before Run() — TurnContext should be nil. 
+ var strategyCalled int32 + var strategyTCWasNil bool + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + return &AgentOutput{}, nil + }, + } + + agentDone := make(chan struct{}) + agentDoneOnce := sync.Once{} + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + Remaining: nil, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + _, ok := events.Next() + if !ok { + break + } + } + agentDoneOnce.Do(func() { + close(agentDone) + }) + return nil + }, + }) + + // Push with strategy — no turn is active yet, so tc should be nil. + loop.Push("item", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + atomic.AddInt32(&strategyCalled, 1) + strategyTCWasNil = (tc == nil) + return nil // plain push, no preempt + })) + + select { + case <-agentDone: + case <-time.After(2 * time.Second): + t.Fatal("agent did not complete") + } + + loop.Stop() + loop.Wait() + + assert.Equal(t, int32(1), atomic.LoadInt32(&strategyCalled)) + assert.True(t, strategyTCWasNil, "strategy should receive nil TurnContext between turns") +} + +func TestTurnLoop_PushStrategy_OverridesOtherOptions(t *testing.T) { + // Push with both WithPreempt and WithPushStrategy — only strategy's result applies. 
+ agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + return &AgentOutput{}, nil + }, + } + + agentDone := make(chan struct{}) + agentDoneOnce := sync.Once{} + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + Remaining: nil, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + _, ok := events.Next() + if !ok { + break + } + } + agentDoneOnce.Do(func() { + close(agentDone) + }) + return nil + }, + }) + + // Strategy returns nil (no preempt), even though WithPreempt is also passed. + // The strategy should override — so the agent should NOT be preempted. 
+ ok, ack := loop.Push("item", WithPreempt[string](AnySafePoint), WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + return nil // no preempt + })) + assert.True(t, ok) + assert.Nil(t, ack, "ack should be nil since strategy returned no preempt") + + select { + case <-agentDone: + case <-time.After(2 * time.Second): + t.Fatal("agent did not complete normally") + } + + loop.Stop() + loop.Wait() +} + +func TestTurnLoop_PushStrategy_NestedStrategyStripped(t *testing.T) { + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + return &AgentOutput{}, nil + }, + } + + agentDone := make(chan struct{}) + agentDoneOnce := sync.Once{} + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + Remaining: nil, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + _, ok := events.Next() + if !ok { + break + } + } + agentDoneOnce.Do(func() { + close(agentDone) + }) + return nil + }, + }) + + // Strategy returns another WithPushStrategy — the nested one should be stripped. 
+ innerCalled := int32(0) + ok, ack := loop.Push("item", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + return []PushOption[string]{ + WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + atomic.AddInt32(&innerCalled, 1) + return []PushOption[string]{WithPreempt[string](AnySafePoint)} + }), + } + })) + assert.True(t, ok) + assert.Nil(t, ack, "ack should be nil since nested strategy was stripped (no preempt)") + + select { + case <-agentDone: + case <-time.After(2 * time.Second): + t.Fatal("agent did not complete normally") + } + + loop.Stop() + loop.Wait() + + assert.Equal(t, int32(0), atomic.LoadInt32(&innerCalled), "nested strategy should not be called") +} + +func TestTurnLoop_PushStrategy_ConsumedInspection(t *testing.T) { + // Strategy preempts only when current turn is processing "low-priority" items. + agentStarted := make(chan struct{}) + agentStartedOnce := sync.Once{} + + agent := &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + agentStartedOnce.Do(func() { + close(agentStarted) + }) + <-ctx.Done() + return &AgentOutput{}, nil + }, + } + + genInputCalls := int32(0) + secondGenInputItems := make(chan []string, 1) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return agent, nil + }, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + count := atomic.AddInt32(&genInputCalls, 1) + if count >= 2 { + select { + case secondGenInputItems <- append([]string{}, items...): + default: + } + } + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: []string{items[0]}, + Remaining: items[1:], + }, nil + }, + }) + + loop.Push("low-priority-task") + + select { 
+ case <-agentStarted: + case <-time.After(1 * time.Second): + t.Fatal("agent did not start") + } + + // Strategy checks Consumed and preempts because current turn has "low-priority" items. + loop.Push("urgent-task", WithPushStrategy(func(ctx context.Context, tc *TurnContext[string]) []PushOption[string] { + if tc != nil && len(tc.Consumed) > 0 && tc.Consumed[0] == "low-priority-task" { + return []PushOption[string]{WithPreempt[string](AnySafePoint)} + } + return nil + })) + + select { + case items := <-secondGenInputItems: + assert.Contains(t, items, "urgent-task") + case <-time.After(2 * time.Second): + t.Fatal("second GenInput was not called after strategy-driven preempt") + } + + loop.Stop() + loop.Wait() +} + +func TestTurnLoop_PushAfterStop_BufferedAsLateItems(t *testing.T) { + ctx := context.Background() + processed := make(chan string, 10) + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + processed <- tc.Consumed[0] + return nil + }, + }) + + loop.Push("msg1") + <-processed + loop.Stop() + result := loop.Wait() + + // Push after stop — should be buffered as late items + ok1, _ := loop.Push("late1") + ok2, _ := loop.Push("late2") + ok3, _ := loop.Push("late3") + assert.False(t, ok1) + assert.False(t, ok2) + assert.False(t, ok3) + + late := result.TakeLateItems() + assert.Equal(t, []string{"late1", "late2", "late3"}, late) +} + +func TestTurnLoop_TakeLateItems_Idempotent(t *testing.T) { 
+ ctx := context.Background() + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Stop() + loop.Run(ctx) + result := loop.Wait() + + loop.Push("late1") + + first := result.TakeLateItems() + second := result.TakeLateItems() + third := result.TakeLateItems() + + assert.Equal(t, []string{"late1"}, first) + assert.Equal(t, first, second, "subsequent calls should return the same slice") + assert.Equal(t, first, third, "subsequent calls should return the same slice") +} + +func TestTurnLoop_PushAfterTakeLateItems_Panics(t *testing.T) { + ctx := context.Background() + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Stop() + loop.Run(ctx) + result := loop.Wait() + + result.TakeLateItems() + + assert.PanicsWithValue(t, "TurnLoop: Push called after TakeLateItems", func() { + loop.Push("too-late") + }) +} + +func TestTurnLoop_TakeLateItems_NeverCalled_NoImpact(t *testing.T) { + ctx := context.Background() + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return 
&turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Push("b") + loop.Stop() + loop.Run(ctx) + result := loop.Wait() + + // Don't call TakeLateItems — verify UnhandledItems works normally + assert.Contains(t, result.UnhandledItems, "b") + assert.Nil(t, result.ExitReason) +} + +func TestTurnLoop_CheckpointErr_SeparateFromExitReason(t *testing.T) { + ctx := context.Background() + saveStore := &errorCheckpointStore{setErr: fmt.Errorf("storage unavailable")} + + loop := NewTurnLoop(TurnLoopConfig[string]{ + Store: saveStore, + CheckpointID: "cp-separate-err", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Stop() + loop.Run(ctx) + result := loop.Wait() + + // ExitReason should be nil (clean stop), checkpoint error should be separate + assert.Nil(t, result.ExitReason) + assert.True(t, result.Checkpointed) + assert.Error(t, result.CheckpointErr) + assert.Contains(t, result.CheckpointErr.Error(), "storage unavailable") +} + +func TestTurnLoop_Checkpointed_FalseWhenNoStore(t *testing.T) { + ctx := context.Background() + + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop.Push("a") + loop.Stop() + loop.Run(ctx) + result := loop.Wait() + + assert.False(t, result.Checkpointed) + assert.Nil(t, result.CheckpointErr) +} + +func TestTurnLoop_Checkpointed_FalseOnErrorExit(t *testing.T) { + ctx := 
context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + genInputErr := errors.New("gen input failed") + + firstTurnDone := make(chan struct{}) + var callCount int32 + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: "cp-err-exit", + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + n := atomic.AddInt32(&callCount, 1) + if n > 1 { + return nil, genInputErr + } + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + close(firstTurnDone) + return nil + }, + }) + loop.Push("msg1") + <-firstTurnDone + loop.Push("msg2") + result := loop.Wait() + + // Loop exited from error, not Stop() — checkpoint should not be saved + assert.ErrorIs(t, result.ExitReason, genInputErr) + assert.False(t, result.Checkpointed) + assert.Nil(t, result.CheckpointErr) +} + +func TestTurnLoop_StopConcurrentWithCallbackError_NoCheckpoint(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "stop-concurrent-err" + + prepareErr := errors.New("prepare agent failed") + firstTurnDone := make(chan struct{}) + stopCalled := make(chan struct{}) + var prepareCount int32 + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ 
*TurnLoop[string], consumed []string) (Agent, error) { + n := atomic.AddInt32(&prepareCount, 1) + if n > 1 { + // Wait until Stop() has been called so stopSig.isStopped() is true + <-stopCalled + return nil, prepareErr + } + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + close(firstTurnDone) + return nil + }, + }) + + loop.Push("msg1") + <-firstTurnDone + loop.Push("msg2") + + // Call Stop() and signal PrepareAgent to proceed with error + go func() { + loop.Stop() + close(stopCalled) + }() + + result := loop.Wait() + + // The loop may exit via Stop (clean) or via PrepareAgent error. + // If it exited via PrepareAgent error with Stop also called: + // checkpoint should NOT be saved. + if result.ExitReason != nil && !errors.As(result.ExitReason, new(*CancelError)) { + assert.ErrorIs(t, result.ExitReason, prepareErr) + assert.False(t, result.Checkpointed, "should not checkpoint when exit is caused by callback error") + } + // If Stop won the race, that's fine — checkpoint may or may not be saved + // depending on idle state. The test is about the error path. 
+} + +func TestTurnLoop_DeleteWithoutCheckPointDeleter_NoOp(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "no-deleter" + + // First loop: save a checkpoint + loop1 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop1.Push("a") + loop1.Stop() + loop1.Run(ctx) + loop1.Wait() + + store.mu.Lock() + _, exists := store.m[cpID] + store.mu.Unlock() + assert.True(t, exists, "checkpoint should be saved") + + // Second loop: exit via context cancel — should try to delete but store + // doesn't implement CheckPointDeleter, so checkpoint persists (no-op) + ctx2, cancel2 := context.WithCancel(ctx) + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + for { + if _, ok := events.Next(); !ok { + break + } + } + cancel2() + return nil + }, + }) + loop2.Push("b") + loop2.Run(ctx2) + loop2.Wait() + + // Without CheckPointDeleter, the stale checkpoint should NOT be deleted + store.mu.Lock() + v, exists := store.m[cpID] + store.mu.Unlock() + assert.True(t, exists, "checkpoint should still exist without 
CheckPointDeleter") + assert.NotNil(t, v, "checkpoint should not be set to nil") +} + +func TestTurnLoop_StopWithSkipCheckpoint(t *testing.T) { + ctx := context.Background() + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "skip-cp-session" + + loop := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("a") + loop.Push("b") + loop.Stop(WithSkipCheckpoint()) + loop.Run(ctx) + + exit := loop.Wait() + assert.NoError(t, exit.ExitReason) + assert.False(t, exit.Checkpointed, "checkpoint should be skipped when WithSkipCheckpoint is used") + + store.mu.Lock() + _, exists := store.m[cpID] + store.mu.Unlock() + assert.False(t, exists, "no checkpoint should be saved when WithSkipCheckpoint is used") +} + +func TestTurnLoop_StopWithSkipCheckpoint_DeletesStaleCheckpoint(t *testing.T) { + ctx := context.Background() + store := &deletableCheckpointStore{ + turnLoopCheckpointStore: turnLoopCheckpointStore{m: make(map[string][]byte)}, + } + cpID := "skip-stale-session" + + loop1 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop1.Push("a") + loop1.Stop() + loop1.Run(ctx) + exit1 := loop1.Wait() + assert.True(t, exit1.Checkpointed) + + store.mu.Lock() + _, exists := store.m[cpID] + store.mu.Unlock() + 
assert.True(t, exists, "first loop should save checkpoint") + + loop2 := NewTurnLoop(TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + loop2.Push("b") + loop2.Stop(WithSkipCheckpoint()) + loop2.Run(ctx) + exit2 := loop2.Wait() + assert.False(t, exit2.Checkpointed, "second loop should skip checkpoint") + + store.mu.Lock() + deleteCalled := store.deleteCalled + store.mu.Unlock() + assert.True(t, deleteCalled, "stale checkpoint should be deleted when SkipCheckpoint is used") +} + +func TestTurnLoop_StopWithStopCause(t *testing.T) { + ctx := context.Background() + cause := "user session timeout" + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + loop.Push("a") + loop.Stop(WithStopCause(cause)) + + exit := loop.Wait() + assert.Equal(t, cause, exit.StopCause) +} + +func TestTurnLoop_StopCause_EmptyWhenNoStop(t *testing.T) { + ctx := context.Background() + + loop := newAndRunTurnLoop(ctx, TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{name: "test"}, nil + }, + }) + + 
loop.Stop() + exit := loop.Wait() + assert.Empty(t, exit.StopCause) +} + +func TestTurnLoop_StopCause_InTurnContext(t *testing.T) { + cause := "business shutdown" + gotCause := make(chan string, 1) + agentStarted := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "slow", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + select { + case <-tc.Stopped: + gotCause <- tc.StopCause() + case <-time.After(5 * time.Second): + t.Error("timed out waiting for Stopped channel") + } + for { + if _, ok := events.Next(); !ok { + break + } + } + return nil + }, + }) + + loop.Push("msg1") + <-agentStarted + loop.Stop(WithStopCause(cause)) + + select { + case c := <-gotCause: + assert.Equal(t, cause, c) + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for StopCause in TurnContext") + } + + exit := loop.Wait() + assert.Equal(t, cause, exit.StopCause) +} + +func TestTurnLoop_StopCause_FirstNonEmptyWins(t *testing.T) { + agentStarted := make(chan struct{}) + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx 
context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "slow", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + for { + if _, ok := events.Next(); !ok { + break + } + } + return nil + }, + }) + + loop.Push("msg1") + <-agentStarted + loop.Stop(WithGraceful(), WithStopCause("first cause")) + loop.Stop(WithStopCause("second cause")) + + exit := loop.Wait() + assert.Equal(t, "first cause", exit.StopCause, "first non-empty StopCause should win") +} + +func TestTurnLoop_StopBeforeRun_PushThenStop(t *testing.T) { + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + t.Fatal("GenInput should not be called when Stop is called before Run") + return nil, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + t.Fatal("PrepareAgent should not be called when Stop is called before Run") + return nil, nil + }, + }) + + ok, _ := loop.Push("item1") + assert.True(t, ok) + ok, _ = loop.Push("item2") + assert.True(t, ok) + + loop.Stop() + loop.Run(context.Background()) + result := loop.Wait() + + assert.NoError(t, result.ExitReason) + assert.Equal(t, []string{"item1", "item2"}, result.UnhandledItems) + assert.Empty(t, result.CanceledItems) + assert.Empty(t, result.TakeLateItems()) +} + +func TestTurnLoop_StopBeforeRun_StopThenPush(t *testing.T) { + loop := NewTurnLoop(TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + t.Fatal("GenInput should not be called when Stop is called before Run") + return nil, nil + }, + PrepareAgent: func(ctx context.Context, _ 
*TurnLoop[string], consumed []string) (Agent, error) { + t.Fatal("PrepareAgent should not be called when Stop is called before Run") + return nil, nil + }, + }) + + loop.Stop() + + ok, _ := loop.Push("item1") + assert.False(t, ok) + ok, _ = loop.Push("item2") + assert.False(t, ok) + + loop.Run(context.Background()) + result := loop.Wait() + + assert.NoError(t, result.ExitReason) + assert.Empty(t, result.UnhandledItems) + assert.Empty(t, result.CanceledItems) + assert.Equal(t, []string{"item1", "item2"}, result.TakeLateItems()) +} + +func TestTurnLoop_SkipCheckpoint_Sticky(t *testing.T) { + agentStarted := make(chan struct{}) + + store := &turnLoopCheckpointStore{m: make(map[string][]byte)} + cpID := "sticky-skip-session" + + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + Store: store, + CheckpointID: cpID, + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "slow", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + }, nil + }, + OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error { + close(agentStarted) + for { + if _, ok := events.Next(); !ok { + break + } + } + return nil + }, + }) + + loop.Push("msg1") + <-agentStarted + loop.Stop(WithGraceful(), WithSkipCheckpoint()) + loop.Stop() + + exit := loop.Wait() + assert.False(t, exit.Checkpointed, "SkipCheckpoint should be sticky across multiple Stop calls") + + store.mu.Lock() + _, exists := store.m[cpID] + store.mu.Unlock() + assert.False(t, exists, "no checkpoint should be saved when SkipCheckpoint was set in any Stop 
call") +} + +func TestWithGracefulTimeout_NonPositive_Panics(t *testing.T) { + assert.PanicsWithValue(t, "adk: WithGracefulTimeout: gracePeriod must be positive", + func() { WithGracefulTimeout(0) }) + assert.PanicsWithValue(t, "adk: WithGracefulTimeout: gracePeriod must be positive", + func() { WithGracefulTimeout(-1 * time.Second) }) +} + +func TestWithPreempt_ZeroSafePoint_Panics(t *testing.T) { + assert.PanicsWithValue(t, "adk: SafePoint must not be zero; use AfterToolCalls, AfterChatModel, or AnySafePoint", + func() { WithPreempt[string](SafePoint(0)) }) +} + +func TestWithPreemptTimeout_ZeroSafePoint_Panics(t *testing.T) { + assert.PanicsWithValue(t, "adk: SafePoint must not be zero; use AfterToolCalls, AfterChatModel, or AnySafePoint", + func() { WithPreemptTimeout[string](SafePoint(0), time.Second) }) +} + +func TestSafePoint_ToCancelMode(t *testing.T) { + assert.Equal(t, CancelAfterToolCalls, AfterToolCalls.toCancelMode()) + assert.Equal(t, CancelAfterChatModel, AfterChatModel.toCancelMode()) + assert.Equal(t, CancelAfterToolCalls|CancelAfterChatModel, AnySafePoint.toCancelMode()) +} + +func TestNewTurnLoop_NilGenInput_Panics(t *testing.T) { + assert.PanicsWithValue(t, "adk: NewTurnLoop: GenInput is required", func() { + NewTurnLoop(TurnLoopConfig[string]{PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { return nil, nil }}) + }) +} + +func TestNewTurnLoop_NilPrepareAgent_Panics(t *testing.T) { + assert.PanicsWithValue(t, "adk: NewTurnLoop: PrepareAgent is required", func() { + NewTurnLoop(TurnLoopConfig[string]{GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return nil, nil + }}) + }) +} + +func TestDeriveChild_NilParent_ReturnsNil(t *testing.T) { + var cc *cancelContext + assert.Nil(t, cc.deriveChild(context.Background())) +} + +func TestUntilIdleFor(t *testing.T) { + t.Run("FiresAfterIdleDuration", func(t *testing.T) { + turnDone := make(chan 
struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(turnDone) + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-turnDone + + loop.Stop(UntilIdleFor(50 * time.Millisecond)) + + done := make(chan struct{}) + go func() { + loop.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("loop did not exit after idle timeout") + } + }) + + t.Run("ResetsOnPush", func(t *testing.T) { + turnCount := int32(0) + turnDone := make(chan struct{}, 10) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + atomic.AddInt32(&turnCount, 1) + turnDone <- struct{}{} + return &AgentOutput{}, nil + }, + }, nil + }, + }) + + loop.Push("msg1") + <-turnDone + + loop.Stop(UntilIdleFor(200 * time.Millisecond)) + + time.Sleep(100 * time.Millisecond) + loop.Push("msg2") + <-turnDone + + done := make(chan struct{}) + go func() { + loop.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + 
t.Fatal("loop did not exit after idle timeout") + } + + assert.Equal(t, int32(2), atomic.LoadInt32(&turnCount)) + }) + + t.Run("EscalatedByStopWithImmediate", func(t *testing.T) { + agentStarted := make(chan *cancelContext, 1) + probe := &turnLoopStopModeProbeAgent{ccCh: agentStarted} + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return probe, nil + }, + }) + + loop.Push("msg1") + cc := <-agentStarted + + loop.Stop(UntilIdleFor(10 * time.Minute)) + loop.Stop(WithImmediate()) + + deadline := time.After(2 * time.Second) + for { + if cc.getMode() == CancelImmediate { + break + } + select { + case <-deadline: + t.Fatal("cancel mode did not escalate to CancelImmediate") + default: + } + time.Sleep(1 * time.Millisecond) + } + + exit := loop.Wait() + var ce *CancelError + require.True(t, errors.As(exit.ExitReason, &ce)) + assert.Equal(t, CancelImmediate, ce.Info.Mode) + }) + + t.Run("EscalatedByStopWithGraceful", func(t *testing.T) { + agentStarted := make(chan struct{}) + agentDone := make(chan struct{}) + loop := newAndRunTurnLoop(context.Background(), TurnLoopConfig[string]{ + GenInput: func(ctx context.Context, _ *TurnLoop[string], items []string) (*GenInputResult[string], error) { + return &GenInputResult[string]{ + Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, + Consumed: items, + }, nil + }, + PrepareAgent: func(ctx context.Context, _ *TurnLoop[string], consumed []string) (Agent, error) { + return &turnLoopCancellableMockAgent{ + name: "test", + runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) { + close(agentStarted) + <-ctx.Done() + 
close(agentDone) + return nil, ctx.Err() + }, + }, nil + }, + }) + + loop.Push("msg1") + <-agentStarted + + loop.Stop(UntilIdleFor(10 * time.Minute)) + loop.Stop(WithGracefulTimeout(50 * time.Millisecond)) + + select { + case <-agentDone: + case <-time.After(2 * time.Second): + t.Fatal("agent was not cancelled") + } + + exit := loop.Wait() + assert.Error(t, exit.ExitReason) + }) +} + +func TestUntilIdleFor_NonPositive_Panics(t *testing.T) { + assert.PanicsWithValue(t, "adk: UntilIdleFor: duration must be positive", + func() { UntilIdleFor(0) }) + assert.PanicsWithValue(t, "adk: UntilIdleFor: duration must be positive", + func() { UntilIdleFor(-1 * time.Second) }) +} diff --git a/adk/utils.go b/adk/utils.go index ee061f71f..24eba904b 100644 --- a/adk/utils.go +++ b/adk/utils.go @@ -44,6 +44,10 @@ func (ag *AsyncGenerator[T]) Send(v T) { ag.ch.Send(v) } +func (ag *AsyncGenerator[T]) trySend(v T) bool { + return ag.ch.TrySend(v) +} + func (ag *AsyncGenerator[T]) Close() { ag.ch.Close() } @@ -124,16 +128,28 @@ func getMessageFromWrappedEvent(e *agentEventWrapper) (Message, error) { return nil, e.StreamErr } + e.consumeStream() + + if e.StreamErr != nil { + return nil, e.StreamErr + } + return e.concatenatedMessage, nil +} + +// consumeStream drains the message stream, setting concatenatedMessage on +// success or StreamErr on failure. The stream is always replaced with an +// error-free, materialized version safe for gob encoding. +// Must be called at most once (guarded by callers checking concatenatedMessage/StreamErr). 
+func (e *agentEventWrapper) consumeStream() { e.mu.Lock() defer e.mu.Unlock() + if e.concatenatedMessage != nil { - return e.concatenatedMessage, nil + return } - var ( - msgs []Message - s = e.AgentEvent.Output.MessageOutput.MessageStream - ) + s := e.AgentEvent.Output.MessageOutput.MessageStream + var msgs []Message defer s.Close() for { @@ -143,19 +159,16 @@ func getMessageFromWrappedEvent(e *agentEventWrapper) (Message, error) { break } e.StreamErr = err - // Replace the stream with successfully received messages only (no error at the end). - // The error is preserved in StreamErr for users to check. - // We intentionally exclude the error from the new stream to ensure gob encoding - // compatibility, as the stream may be consumed during serialization. e.AgentEvent.Output.MessageOutput.MessageStream = schema.StreamReaderFromArray(msgs) - return nil, err + return } - msgs = append(msgs, msg) } if len(msgs) == 0 { - return nil, errors.New("no messages in MessageVariant.MessageStream") + e.StreamErr = errors.New("no messages in MessageVariant.MessageStream") + e.AgentEvent.Output.MessageOutput.MessageStream = schema.StreamReaderFromArray(msgs) + return } if len(msgs) == 1 { @@ -166,11 +179,11 @@ func getMessageFromWrappedEvent(e *agentEventWrapper) (Message, error) { if err != nil { e.StreamErr = err e.AgentEvent.Output.MessageOutput.MessageStream = schema.StreamReaderFromArray(msgs) - return nil, err + return } } - return e.concatenatedMessage, nil + e.AgentEvent.Output.MessageOutput.MessageStream = schema.StreamReaderFromArray([]Message{e.concatenatedMessage}) } // copyAgentEvent copies an AgentEvent. diff --git a/adk/workflow.go b/adk/workflow.go index 9d63d7347..00411e33b 100644 --- a/adk/workflow.go +++ b/adk/workflow.go @@ -175,7 +175,6 @@ func (a *workflowAgent) runSequential(ctx context.Context, startIdx := 0 - // seqCtx tracks the accumulated RunPath across the sequence. 
seqCtx := ctx // If we are resuming, find which sub-agent to start from and prepare its context. @@ -193,12 +192,28 @@ func (a *workflowAgent) runSequential(ctx context.Context, for i := startIdx; i < len(a.subAgents); i++ { subAgent := a.subAgents[i] + // Cancel check at transition boundary between sub-agents. + // Transition boundaries are always safe to cancel at — no sub-agent + // work is in progress, so any cancel mode is honoured. + if cancelCtx := getCancelContext(ctx); cancelCtx != nil && cancelCtx.shouldCancel() { + state := &sequentialWorkflowState{InterruptIndex: i} + event := cancelAtTransition(ctx, "Sequential workflow cancel at transition", state) + generator.Send(event) + return nil + } + var subIterator *AsyncIterator[*AgentEvent] if seqState != nil { - subIterator = subAgent.Resume(seqCtx, &ResumeInfo{ - EnableStreaming: info.EnableStreaming, - InterruptInfo: info.Data.(*WorkflowInterruptInfo).SequentialInterruptInfo, - }, opts...) + wfInfo, _ := info.Data.(*WorkflowInterruptInfo) + if wfInfo != nil && wfInfo.SequentialInterruptInfo != nil { + // Sub-agent was interrupted — resume it. + subIterator = subAgent.Resume(seqCtx, &ResumeInfo{ + EnableStreaming: info.EnableStreaming, + InterruptInfo: wfInfo.SequentialInterruptInfo, + }, opts...) + } else { + subIterator = subAgent.Run(seqCtx, nil, opts...) + } seqState = nil } else { subIterator = subAgent.Run(seqCtx, nil, opts...) @@ -304,7 +319,6 @@ func (a *workflowAgent) runLoop(ctx context.Context, generator *AsyncGenerator[* startIter := 0 startIdx := 0 - // loopCtx tracks the accumulated RunPath across the full sequence within a single iteration. 
loopCtx := ctx if loopState != nil { @@ -329,13 +343,25 @@ func (a *workflowAgent) runLoop(ctx context.Context, generator *AsyncGenerator[* for j := startIdx; j < len(a.subAgents); j++ { subAgent := a.subAgents[j] + if cancelCtx := getCancelContext(ctx); cancelCtx != nil && cancelCtx.shouldCancel() { + state := &loopWorkflowState{LoopIterations: i, SubAgentIndex: j} + event := cancelAtTransition(ctx, "Loop workflow cancel at transition", state) + generator.Send(event) + return nil + } + var subIterator *AsyncIterator[*AgentEvent] if loopState != nil { - // This is the agent we need to resume. - subIterator = subAgent.Resume(loopCtx, &ResumeInfo{ - EnableStreaming: resumeInfo.EnableStreaming, - InterruptInfo: resumeInfo.Data.(*WorkflowInterruptInfo).SequentialInterruptInfo, - }, opts...) + wfInfo, _ := resumeInfo.Data.(*WorkflowInterruptInfo) + if wfInfo != nil && wfInfo.SequentialInterruptInfo != nil { + // Sub-agent was interrupted — resume it. + subIterator = subAgent.Resume(loopCtx, &ResumeInfo{ + EnableStreaming: resumeInfo.EnableStreaming, + InterruptInfo: wfInfo.SequentialInterruptInfo, + }, opts...) + } else { + subIterator = subAgent.Run(loopCtx, nil, opts...) + } loopState = nil // Only resume the first time. } else { subIterator = subAgent.Run(loopCtx, nil, opts...) @@ -468,6 +494,15 @@ func (a *workflowAgent) runParallel(ctx context.Context, generator *AsyncGenerat } } + // Cancel check before spawning parallel goroutines. No sub-agent work + // is in progress, so any cancel mode is honoured at this boundary. 
+ if cancelCtx := getCancelContext(ctx); cancelCtx != nil && cancelCtx.shouldCancel() { + state := &parallelWorkflowState{} + event := cancelAtTransition(ctx, "Parallel workflow cancel before spawn", state) + generator.Send(event) + return nil + } + for i := range a.subAgents { wg.Add(1) go func(idx int, agent *flowAgent) { @@ -483,11 +518,13 @@ func (a *workflowAgent) runParallel(ctx context.Context, generator *AsyncGenerat var iterator *AsyncIterator[*AgentEvent] if _, ok := agentNames[agent.Name(ctx)]; ok { - // This branch was interrupted and needs to be resumed. - iterator = agent.Resume(childContexts[idx], &ResumeInfo{ + childResumeInfo := &ResumeInfo{ EnableStreaming: resumeInfo.EnableStreaming, - InterruptInfo: resumeInfo.Data.(*WorkflowInterruptInfo).ParallelInterruptInfo[idx], - }, opts...) + } + if wfInfo, ok := resumeInfo.Data.(*WorkflowInterruptInfo); ok && wfInfo != nil { + childResumeInfo.InterruptInfo = wfInfo.ParallelInterruptInfo[idx] + } + iterator = agent.Resume(childContexts[idx], childResumeInfo, opts...) } else if parState != nil { // We are resuming, but this child is not in the next points map. // This means it finished successfully, so we don't run it. @@ -550,6 +587,27 @@ func (a *workflowAgent) runParallel(ctx context.Context, generator *AsyncGenerat return nil } +func cancelAtTransition(ctx context.Context, info string, state any) *AgentEvent { + // state is the workflow checkpoint state (e.g. sequentialWorkflowState); + // nil for subContexts because this is a leaf interrupt with no child signals. 
+ is, err := core.Interrupt(ctx, info, state, nil, + core.WithLayerPayload(getRunCtx(ctx).RunPath)) + if err != nil { + return &AgentEvent{Err: err} + } + + contexts := core.ToInterruptContexts(is, allowedAddressSegmentTypes) + + return &AgentEvent{ + Action: &AgentAction{ + Interrupted: &InterruptInfo{ + InterruptContexts: contexts, + }, + internalInterrupted: is, + }, + } +} + type SequentialAgentConfig struct { Name string Description string diff --git a/adk/wrappers.go b/adk/wrappers.go index 5061f5be8..b4e16d298 100644 --- a/adk/wrappers.go +++ b/adk/wrappers.go @@ -34,26 +34,35 @@ type generateEndpoint func(ctx context.Context, input []*schema.Message, opts .. type streamEndpoint func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) type modelWrapperConfig struct { - handlers []ChatModelAgentMiddleware - middlewares []AgentMiddleware - retryConfig *ModelRetryConfig - toolInfos []*schema.ToolInfo + handlers []ChatModelAgentMiddleware + middlewares []AgentMiddleware + retryConfig *ModelRetryConfig + failoverConfig *ModelFailoverConfig + toolInfos []*schema.ToolInfo + cancelContext *cancelContext } func buildModelWrappers(m model.BaseChatModel, config *modelWrapperConfig) model.BaseChatModel { var wrapped model.BaseChatModel = m - if !components.IsCallbacksEnabled(m) { + // failoverProxyModel must be the innermost wrapper to read the selected failover model from context. 
+ if config.failoverConfig != nil { + wrapped = &failoverProxyModel{} + } + + if !components.IsCallbacksEnabled(wrapped) { wrapped = (&callbackInjectionModelWrapper{}).WrapModel(wrapped) } wrapped = &stateModelWrapper{ - inner: wrapped, - original: m, - handlers: config.handlers, - middlewares: config.middlewares, - toolInfos: config.toolInfos, - modelRetryConfig: config.retryConfig, + inner: wrapped, + original: m, + handlers: config.handlers, + middlewares: config.middlewares, + toolInfos: config.toolInfos, + modelRetryConfig: config.retryConfig, + modelFailoverConfig: config.failoverConfig, + cancelContext: config.cancelContext, } return wrapped @@ -252,16 +261,28 @@ func NewEventSenderModelWrapper() ChatModelAgentMiddleware { } func (w *eventSenderModelWrapper) WrapModel(_ context.Context, m model.BaseChatModel, mc *ModelContext) (model.BaseChatModel, error) { + inner := m + if mc != nil && mc.cancelContext != nil { + inner = &cancelMonitoredModel{ + inner: inner, + cancelContext: mc.cancelContext, + } + } var retryConfig *ModelRetryConfig if mc != nil { retryConfig = mc.ModelRetryConfig } - return &eventSenderModel{inner: m, modelRetryConfig: retryConfig}, nil + var failoverConfig *ModelFailoverConfig + if mc != nil { + failoverConfig = mc.ModelFailoverConfig + } + return &eventSenderModel{inner: inner, modelRetryConfig: retryConfig, modelFailoverConfig: failoverConfig}, nil } type eventSenderModel struct { - inner model.BaseChatModel - modelRetryConfig *ModelRetryConfig + inner model.BaseChatModel + modelRetryConfig *ModelRetryConfig + modelFailoverConfig *ModelFailoverConfig } func (m *eventSenderModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { @@ -294,19 +315,12 @@ func (m *eventSenderModel) Stream(ctx context.Context, input []*schema.Message, return nil, errors.New("generator is nil when sending event in Stream: ensure agent state is properly initialized") } - var retryAttempt int - _ = 
compose.ProcessState(ctx, func(_ context.Context, st *State) error { - retryAttempt = st.getRetryAttempt() - return nil - }) - streams := result.Copy(2) eventStream := streams[0] - if m.modelRetryConfig != nil { + if errWrapper := m.buildErrWrapper(ctx); errWrapper != nil { convertOpts := []schema.ConvertOption{ - schema.WithErrWrapper(genErrWrapper(ctx, m.modelRetryConfig.MaxRetries, - retryAttempt, m.modelRetryConfig.IsRetryAble)), + schema.WithErrWrapper(errWrapper), } eventStream = schema.StreamReaderWithConvert(streams[0], func(msg *schema.Message) (*schema.Message, error) { return msg, nil }, @@ -319,6 +333,51 @@ func (m *eventSenderModel) Stream(ctx context.Context, input []*schema.Message, return streams[1], nil } +// buildErrWrapper constructs an error wrapper function for event streams. +// It wraps stream errors as WillRetryError when retry or failover is configured, +// so that flow.go:genAgentInput() can skip events from failed attempts instead of +// treating them as fatal errors. +func (m *eventSenderModel) buildErrWrapper(ctx context.Context) func(error) error { + var retryAttempt int + _ = compose.ProcessState(ctx, func(_ context.Context, st *State) error { + retryAttempt = st.getRetryAttempt() + return nil + }) + + var retryWrapper func(error) error + if m.modelRetryConfig != nil { + retryWrapper = genErrWrapper(ctx, m.modelRetryConfig.MaxRetries, retryAttempt, m.modelRetryConfig.IsRetryAble) + } + + hasFailover := m.modelFailoverConfig != nil + // failoverHasMoreAttempts is set by failoverModelWrapper before each inner call. + // It is true when additional failover attempts remain after the current one, + // meaning stream errors should be wrapped as WillRetryError so the flow layer + // skips them. On the final attempt it is false, so the error propagates normally. 
+ failoverHasMore := getFailoverHasMoreAttempts(ctx) + + if retryWrapper == nil && !(hasFailover && failoverHasMore) { + return nil + } + + return func(err error) error { + // If retry is configured and will retry this error, use the retry wrapper's WillRetryError. + if retryWrapper != nil { + wrapped := retryWrapper(err) + if _, ok := wrapped.(*WillRetryError); ok { + return wrapped + } + } + // Retry won't handle this error (either exhausted or not configured), but + // failover still has more attempts remaining. Wrap it as WillRetryError so + // the flow layer skips this event from the failed attempt. + if hasFailover && failoverHasMore { + return &WillRetryError{ErrStr: err.Error(), err: err} + } + return err + } +} + func popToolGenAction(ctx context.Context, toolName string) *AgentAction { toolCallID := compose.GetToolCallID(ctx) @@ -341,20 +400,33 @@ func popToolGenAction(ctx context.Context, toolName string) *AgentAction { return action } -type eventSenderToolHandler struct{} +type eventSenderToolWrapper struct { + *BaseChatModelAgentMiddleware +} -func (h *eventSenderToolHandler) WrapInvokableToolCall(next compose.InvokableToolEndpoint) compose.InvokableToolEndpoint { - return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { - output, err := next(ctx, input) +// NewEventSenderToolWrapper returns a ChatModelAgentMiddleware that sends tool result events. +// By default, the framework places this before all user middlewares (outermost), so events +// reflect the fully processed tool output. To control exactly where events are emitted, +// include this in ChatModelAgentConfig.Handlers at the desired position. +// When detected in Handlers, the framework skips the default event sender to avoid duplicates. 
+func NewEventSenderToolWrapper() ChatModelAgentMiddleware { + return &eventSenderToolWrapper{ + BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, + } +} + +func (w *eventSenderToolWrapper) WrapInvokableToolCall(_ context.Context, endpoint InvokableToolCallEndpoint, tCtx *ToolContext) (InvokableToolCallEndpoint, error) { + return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { + result, err := endpoint(ctx, argumentsInJSON, opts...) if err != nil { - return nil, err + return "", err } - toolName := input.Name - callID := input.CallID + toolName := tCtx.Name + callID := tCtx.CallID prePopAction := popToolGenAction(ctx, toolName) - msg := schema.ToolMessage(output.Result, callID, schema.WithToolName(toolName)) + msg := schema.ToolMessage(result, callID, schema.WithToolName(toolName)) event := EventFromMessage(msg, nil, schema.Tool, toolName) if prePopAction != nil { event.Action = prePopAction @@ -370,22 +442,22 @@ func (h *eventSenderToolHandler) WrapInvokableToolCall(next compose.InvokableToo return nil }) - return output, nil - } + return result, nil + }, nil } -func (h *eventSenderToolHandler) WrapStreamableToolCall(next compose.StreamableToolEndpoint) compose.StreamableToolEndpoint { - return func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) { - output, err := next(ctx, input) +func (w *eventSenderToolWrapper) WrapStreamableToolCall(_ context.Context, endpoint StreamableToolCallEndpoint, tCtx *ToolContext) (StreamableToolCallEndpoint, error) { + return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) { + result, err := endpoint(ctx, argumentsInJSON, opts...) 
if err != nil { return nil, err } - toolName := input.Name - callID := input.CallID + toolName := tCtx.Name + callID := tCtx.CallID prePopAction := popToolGenAction(ctx, toolName) - streams := output.Result.Copy(2) + streams := result.Copy(2) cvt := func(in string) (Message, error) { return schema.ToolMessage(in, callID, schema.WithToolName(toolName)), nil @@ -404,23 +476,23 @@ func (h *eventSenderToolHandler) WrapStreamableToolCall(next compose.StreamableT return nil }) - return &compose.StreamToolOutput{Result: streams[1]}, nil - } + return streams[1], nil + }, nil } -func (h *eventSenderToolHandler) WrapEnhancedInvokableToolCall(next compose.EnhancedInvokableToolEndpoint) compose.EnhancedInvokableToolEndpoint { - return func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedInvokableToolOutput, error) { - output, err := next(ctx, input) +func (w *eventSenderToolWrapper) WrapEnhancedInvokableToolCall(_ context.Context, endpoint EnhancedInvokableToolCallEndpoint, tCtx *ToolContext) (EnhancedInvokableToolCallEndpoint, error) { + return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { + result, err := endpoint(ctx, toolArgument, opts...) 
if err != nil { return nil, err } - toolName := input.Name - callID := input.CallID + toolName := tCtx.Name + callID := tCtx.CallID prePopAction := popToolGenAction(ctx, toolName) msg := schema.ToolMessage("", callID, schema.WithToolName(toolName)) - msg.UserInputMultiContent, err = output.Result.ToMessageInputParts() + msg.UserInputMultiContent, err = result.ToMessageInputParts() if err != nil { return nil, err } @@ -439,22 +511,22 @@ func (h *eventSenderToolHandler) WrapEnhancedInvokableToolCall(next compose.Enha return nil }) - return output, nil - } + return result, nil + }, nil } -func (h *eventSenderToolHandler) WrapEnhancedStreamableToolCall(next compose.EnhancedStreamableToolEndpoint) compose.EnhancedStreamableToolEndpoint { - return func(ctx context.Context, input *compose.ToolInput) (*compose.EnhancedStreamableToolOutput, error) { - output, err := next(ctx, input) +func (w *eventSenderToolWrapper) WrapEnhancedStreamableToolCall(_ context.Context, endpoint EnhancedStreamableToolCallEndpoint, tCtx *ToolContext) (EnhancedStreamableToolCallEndpoint, error) { + return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { + result, err := endpoint(ctx, toolArgument, opts...) 
if err != nil { return nil, err } - toolName := input.Name - callID := input.CallID + toolName := tCtx.Name + callID := tCtx.CallID prePopAction := popToolGenAction(ctx, toolName) - streams := output.Result.Copy(2) + streams := result.Copy(2) cvt := func(in *schema.ToolResult) (Message, error) { msg := schema.ToolMessage("", callID, schema.WithToolName(toolName)) @@ -479,17 +551,28 @@ func (h *eventSenderToolHandler) WrapEnhancedStreamableToolCall(next compose.Enh return nil }) - return &compose.EnhancedStreamableToolOutput{Result: streams[1]}, nil + return streams[1], nil + }, nil +} + +func hasUserEventSenderToolWrapper(handlers []ChatModelAgentMiddleware) bool { + for _, handler := range handlers { + if _, ok := handler.(*eventSenderToolWrapper); ok { + return true + } } + return false } type stateModelWrapper struct { - inner model.BaseChatModel - original model.BaseChatModel - handlers []ChatModelAgentMiddleware - middlewares []AgentMiddleware - toolInfos []*schema.ToolInfo - modelRetryConfig *ModelRetryConfig + inner model.BaseChatModel + original model.BaseChatModel + handlers []ChatModelAgentMiddleware + middlewares []AgentMiddleware + toolInfos []*schema.ToolInfo + modelRetryConfig *ModelRetryConfig + modelFailoverConfig *ModelFailoverConfig + cancelContext *cancelContext } func (w *stateModelWrapper) IsCallbacksEnabled() bool { @@ -515,6 +598,8 @@ func (w *stateModelWrapper) hasUserEventSender() bool { func (w *stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) generateEndpoint { hasUserEventSender := w.hasUserEventSender() retryConfig := w.modelRetryConfig + failoverConfig := w.modelFailoverConfig + cc := w.cancelContext for i := len(w.handlers) - 1; i >= 0; i-- { handler := w.handlers[i] @@ -523,7 +608,7 @@ func (w *stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) gene endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { baseOpts := &model.Options{Tools: 
baseToolInfos} commonOpts := model.GetCommonOptions(baseOpts, opts...) - mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: retryConfig} + mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: retryConfig, cancelContext: cc} wrappedModel, err := handler.WrapModel(ctx, &endpointModel{generate: innerEndpoint}, mc) if err != nil { return nil, err @@ -540,7 +625,7 @@ func (w *stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) gene if execCtx == nil || execCtx.generator == nil { return innerEndpoint(ctx, input, opts...) } - mc := &ModelContext{ModelRetryConfig: retryConfig} + mc := &ModelContext{ModelRetryConfig: retryConfig, ModelFailoverConfig: failoverConfig, cancelContext: cc} wrappedModel, err := eventSender.WrapModel(ctx, &endpointModel{generate: innerEndpoint}, mc) if err != nil { return nil, err @@ -557,12 +642,24 @@ func (w *stateModelWrapper) wrapGenerateEndpoint(endpoint generateEndpoint) gene } } + // Needs to handle failoverWrapper after retryWrapper + if w.modelFailoverConfig != nil { + config := w.modelFailoverConfig + innerEndpoint := endpoint + endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + failoverWrapper := newFailoverModelWrapper(&endpointModel{generate: innerEndpoint}, config) + return failoverWrapper.Generate(ctx, input, opts...) 
+ } + } + return endpoint } func (w *stateModelWrapper) wrapStreamEndpoint(endpoint streamEndpoint) streamEndpoint { hasUserEventSender := w.hasUserEventSender() retryConfig := w.modelRetryConfig + failoverConfig := w.modelFailoverConfig + cc := w.cancelContext for i := len(w.handlers) - 1; i >= 0; i-- { handler := w.handlers[i] @@ -571,7 +668,7 @@ func (w *stateModelWrapper) wrapStreamEndpoint(endpoint streamEndpoint) streamEn endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { baseOpts := &model.Options{Tools: baseToolInfos} commonOpts := model.GetCommonOptions(baseOpts, opts...) - mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: retryConfig} + mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: retryConfig, cancelContext: cc} wrappedModel, err := handler.WrapModel(ctx, &endpointModel{stream: innerEndpoint}, mc) if err != nil { return nil, err @@ -588,7 +685,7 @@ func (w *stateModelWrapper) wrapStreamEndpoint(endpoint streamEndpoint) streamEn if execCtx == nil || execCtx.generator == nil { return innerEndpoint(ctx, input, opts...) } - mc := &ModelContext{ModelRetryConfig: retryConfig} + mc := &ModelContext{ModelRetryConfig: retryConfig, ModelFailoverConfig: failoverConfig, cancelContext: cc} wrappedModel, err := eventSender.WrapModel(ctx, &endpointModel{stream: innerEndpoint}, mc) if err != nil { return nil, err @@ -605,6 +702,16 @@ func (w *stateModelWrapper) wrapStreamEndpoint(endpoint streamEndpoint) streamEn } } + // Needs to handle failoverWrapper after retryWrapper + if w.modelFailoverConfig != nil { + config := w.modelFailoverConfig + innerEndpoint := endpoint + endpoint = func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + failoverWrapper := newFailoverModelWrapper(&endpointModel{stream: innerEndpoint}, config) + return failoverWrapper.Stream(ctx, input, opts...) 
+ } + } + return endpoint } @@ -615,7 +722,7 @@ func (w *stateModelWrapper) Generate(ctx context.Context, input []*schema.Messag return nil }) - state := &ChatModelAgentState{Messages: append(stateMessages, input...)} + state := &ChatModelAgentState{Messages: stateMessages} for _, m := range w.middlewares { if m.BeforeChatModel != nil { @@ -627,7 +734,7 @@ func (w *stateModelWrapper) Generate(ctx context.Context, input []*schema.Messag baseOpts := &model.Options{Tools: w.toolInfos} commonOpts := model.GetCommonOptions(baseOpts, opts...) - mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: w.modelRetryConfig} + mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: w.modelRetryConfig, cancelContext: w.cancelContext} for _, handler := range w.handlers { var err error ctx, state, err = handler.BeforeModelRewriteState(ctx, state, mc) @@ -681,7 +788,7 @@ func (w *stateModelWrapper) Stream(ctx context.Context, input []*schema.Message, return nil }) - state := &ChatModelAgentState{Messages: append(stateMessages, input...)} + state := &ChatModelAgentState{Messages: stateMessages} for _, m := range w.middlewares { if m.BeforeChatModel != nil { @@ -693,7 +800,7 @@ func (w *stateModelWrapper) Stream(ctx context.Context, input []*schema.Message, baseOpts := &model.Options{Tools: w.toolInfos} commonOpts := model.GetCommonOptions(baseOpts, opts...) 
- mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: w.modelRetryConfig} + mc := &ModelContext{Tools: commonOpts.Tools, ModelRetryConfig: w.modelRetryConfig, cancelContext: w.cancelContext} for _, handler := range w.handlers { var err error ctx, state, err = handler.BeforeModelRewriteState(ctx, state, mc) diff --git a/adk/wrappers_failover_test.go b/adk/wrappers_failover_test.go new file mode 100644 index 000000000..8b14463e1 --- /dev/null +++ b/adk/wrappers_failover_test.go @@ -0,0 +1,181 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package adk + +import ( + "context" + "errors" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +func TestBuildModelWrappers_FailoverProxyInner(t *testing.T) { + base := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return schema.AssistantMessage("ok", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("ok", nil)}), nil + }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 0, + ShouldFailover: func(context.Context, *schema.Message, error) bool { return false }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return base, nil, nil + }, + } + + wrapped := buildModelWrappers(base, &modelWrapperConfig{ + failoverConfig: failoverCfg, + }) + + smw, ok := wrapped.(*stateModelWrapper) + require.True(t, ok) + _, ok = smw.inner.(*failoverProxyModel) + require.True(t, ok) + require.Same(t, base, smw.original) + require.Same(t, failoverCfg, smw.modelFailoverConfig) +} + +func TestStateModelWrapper_Generate_WithFailover(t *testing.T) { + wantErr := errors.New("first failed") + var shouldCalls int32 + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m1Calls, 1) + return schema.AssistantMessage("partial", nil), wantErr + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ 
context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m2Calls, 1) + return schema.AssistantMessage("ok", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, out *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + require.ErrorIs(t, err, wantErr) + require.NotNil(t, out) + require.Equal(t, "partial", out.Content) + return true + }, + GetFailoverModel: func(_ context.Context, failoverCtx *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.Equal(t, uint(1), failoverCtx.FailoverAttempt) + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + got, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, "ok", got.Content) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) +} + +func TestStateModelWrapper_Stream_WithFailover(t *testing.T) { + streamErr := errors.New("mid error") + var shouldCalls int32 + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return streamWithMidError([]*schema.Message{ + 
schema.AssistantMessage("p1", nil), + schema.AssistantMessage("p2", nil), + }, streamErr), nil + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m2Calls, 1) + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("final", nil)}), nil + }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, out *schema.Message, err error) bool { + atomic.AddInt32(&shouldCalls, 1) + require.ErrorIs(t, err, streamErr) + require.NotNil(t, out) + require.Equal(t, "p1p2", out.Content) + return true + }, + GetFailoverModel: func(_ context.Context, failoverCtx *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.Equal(t, uint(1), failoverCtx.FailoverAttempt) + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + msgs, err := drainMessageStream(sr) + require.NoError(t, err) + require.Len(t, msgs, 1) + require.Equal(t, "final", msgs[0].Content) + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&shouldCalls)) +} diff --git a/adk/wrappers_retry_failover_test.go b/adk/wrappers_retry_failover_test.go new file mode 100644 index 000000000..98db172e9 --- /dev/null +++ b/adk/wrappers_retry_failover_test.go @@ -0,0 +1,411 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, 
Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +// TestRetryThenFailover_Generate_RetryExhaustedTriggersFailover tests the combined +// retry + failover path for Generate: m1 always fails, retry exhausted, failover to m2 which succeeds. +func TestRetryThenFailover_Generate_RetryExhaustedTriggersFailover(t *testing.T) { + modelErr := errors.New("model error") + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m1Calls, 1) + return nil, modelErr + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m2Calls, 1) + return schema.AssistantMessage("ok from m2", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + retryCfg := &ModelRetryConfig{ + MaxRetries: 2, + IsRetryAble: func(_ context.Context, err error) 
bool { return true }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + return err != nil + }, + GetFailoverModel: func(_ context.Context, fc *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.NotNil(t, fc.LastErr) + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + msg, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.Equal(t, "ok from m2", msg.Content) + + // m1: 1 (lastSuccess) + 2 retries = 3 calls on lastSuccess attempt, + // then failover to m2 which also goes through retry wrapper: 1 call succeeds. + require.Equal(t, int32(3), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) +} + +// TestRetryThenFailover_Generate_AllExhausted tests: m1 retry exhausted → failover to m2 → m2 retry exhausted → final error. 
+func TestRetryThenFailover_Generate_AllExhausted(t *testing.T) { + modelErr := errors.New("always fails") + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m1Calls, 1) + return nil, modelErr + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m2Calls, 1) + return nil, modelErr + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + retryCfg := &ModelRetryConfig{ + MaxRetries: 1, + IsRetryAble: func(_ context.Context, err error) bool { return true }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + return err != nil + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + _, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.Error(t, err) + + // Should be RetryExhaustedError from m2's retry wrapper + var retryErr *RetryExhaustedError + require.True(t, errors.As(err, &retryErr)) + + // m1: 1 initial + 1 retry = 2 calls + require.Equal(t, int32(2), atomic.LoadInt32(&m1Calls)) 
+ // m2: 1 initial + 1 retry = 2 calls + require.Equal(t, int32(2), atomic.LoadInt32(&m2Calls)) +} + +// TestRetryThenFailover_Stream_RetryExhaustedTriggersFailover tests stream path: +// m1 stream always errors mid-way, retry exhausted, failover to m2 which succeeds. +func TestRetryThenFailover_Stream_RetryExhaustedTriggersFailover(t *testing.T) { + streamErr := errors.New("stream mid error") + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return streamWithMidError([]*schema.Message{ + schema.AssistantMessage("partial", nil), + }, streamErr), nil + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m2Calls, 1) + return schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("ok from m2", nil)}), nil + }, + } + + retryCfg := &ModelRetryConfig{ + MaxRetries: 1, + IsRetryAble: func(_ context.Context, err error) bool { return true }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + return err != nil + }, + GetFailoverModel: func(_ context.Context, fc *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + require.NotNil(t, fc.LastErr) + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + 
failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + sr, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + msgs, err := drainMessageStream(sr) + require.NoError(t, err) + require.Len(t, msgs, 1) + require.Equal(t, "ok from m2", msgs[0].Content) + + // m1: 1 initial + 1 retry = 2 calls on lastSuccess attempt + require.Equal(t, int32(2), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) +} + +// TestRetryThenFailover_Generate_RetrySucceedsNoFailover tests that when retry +// succeeds on the first model, failover is never triggered. +func TestRetryThenFailover_Generate_RetrySucceedsNoFailover(t *testing.T) { + var m1Calls int32 + var failoverCalled int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + n := atomic.AddInt32(&m1Calls, 1) + if n == 1 { + return nil, errors.New("transient error") + } + return schema.AssistantMessage("ok on retry", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + retryCfg := &ModelRetryConfig{ + MaxRetries: 2, + IsRetryAble: func(_ context.Context, err error) bool { return true }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + atomic.AddInt32(&failoverCalled, 1) + return true + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + t.Fatal("GetFailoverModel should not be called when retry succeeds") + return nil, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, 
&modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + msg, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.Equal(t, "ok on retry", msg.Content) + + // 2 calls: first fails, second succeeds via retry + require.Equal(t, int32(2), atomic.LoadInt32(&m1Calls)) + // ShouldFailover should never be called + require.Equal(t, int32(0), atomic.LoadInt32(&failoverCalled)) +} + +// TestRetryThenFailover_Generate_NonRetryableErrorTriggersFailover tests that a non-retryable +// error skips retry and directly triggers failover. +func TestRetryThenFailover_Generate_NonRetryableErrorTriggersFailover(t *testing.T) { + nonRetryableErr := errors.New("non-retryable") + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m1Calls, 1) + return nil, nonRetryableErr + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m2Calls, 1) + return schema.AssistantMessage("ok from m2", nil), nil + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + return nil, errors.New("unused") + }, + } + + retryCfg := &ModelRetryConfig{ + MaxRetries: 3, + IsRetryAble: func(_ context.Context, err error) bool { + // Only non-retryable errors + return !errors.Is(err, nonRetryableErr) + }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := 
&ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + return err != nil + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + msg, err := wrapped.Generate(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.NoError(t, err) + require.Equal(t, "ok from m2", msg.Content) + + // m1 called only once — non-retryable error skips retry + require.Equal(t, int32(1), atomic.LoadInt32(&m1Calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&m2Calls)) +} + +// TestRetryThenFailover_Stream_AllExhausted tests stream path when both retry and failover are exhausted. +func TestRetryThenFailover_Stream_AllExhausted(t *testing.T) { + streamErr := errors.New("always fails mid-stream") + var m1Calls int32 + var m2Calls int32 + + m1 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m1Calls, 1) + return streamWithMidError([]*schema.Message{ + schema.AssistantMessage("p", nil), + }, streamErr), nil + }, + } + m2 := &fakeChatModel{ + callbacksEnabled: true, + generate: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return nil, errors.New("unused") + }, + stream: func(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + atomic.AddInt32(&m2Calls, 1) + return streamWithMidError([]*schema.Message{ + 
schema.AssistantMessage("p", nil), + }, streamErr), nil + }, + } + + retryCfg := &ModelRetryConfig{ + MaxRetries: 1, + IsRetryAble: func(_ context.Context, err error) bool { return true }, + BackoffFunc: func(_ context.Context, _ int) time.Duration { return 0 }, + } + + failoverCfg := &ModelFailoverConfig{ + MaxRetries: 1, + ShouldFailover: func(_ context.Context, _ *schema.Message, err error) bool { + return err != nil + }, + GetFailoverModel: func(_ context.Context, _ *FailoverContext) (model.BaseChatModel, []*schema.Message, error) { + return m2, nil, nil + }, + } + + wrapped := buildModelWrappers(m1, &modelWrapperConfig{ + retryConfig: retryCfg, + failoverConfig: failoverCfg, + }) + + ctx := withChatModelAgentExecCtx(context.Background(), &chatModelAgentExecCtx{ + failoverLastSuccessModel: m1, + }) + _, err := wrapped.Stream(ctx, []*schema.Message{schema.UserMessage("hi")}) + require.Error(t, err) + + var retryErr *RetryExhaustedError + require.True(t, errors.As(err, &retryErr)) + + // m1: 1 initial + 1 retry = 2 calls + require.Equal(t, int32(2), atomic.LoadInt32(&m1Calls)) + // m2: 1 initial + 1 retry = 2 calls + require.Equal(t, int32(2), atomic.LoadInt32(&m2Calls)) +} diff --git a/adk/wrappers_test.go b/adk/wrappers_test.go index 5fd8acef5..f231e3c07 100644 --- a/adk/wrappers_test.go +++ b/adk/wrappers_test.go @@ -1085,3 +1085,600 @@ func (m *contentModifyingModelWrapper) Stream(ctx context.Context, input []*sche result.Content = m.newContent return schema.StreamReaderFromArray([]*schema.Message{result}), nil } + +type mockToolCallingModel struct { + mu sync.Mutex + generateCalls int + toolCallName string +} + +func (m *mockToolCallingModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + m.mu.Lock() + m.generateCalls++ + calls := m.generateCalls + m.mu.Unlock() + if calls == 1 { + return schema.AssistantMessage("calling tool", []schema.ToolCall{ + {ID: "tc-1", Function: schema.FunctionCall{Name: 
m.toolCallName, Arguments: `{"input":"test"}`}}, + }), nil + } + return schema.AssistantMessage("done", nil), nil +} + +func (m *mockToolCallingModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + msg, err := m.Generate(ctx, input, opts...) + if err != nil { + return nil, err + } + return schema.StreamReaderFromArray([]*schema.Message{msg}), nil +} + +func (m *mockToolCallingModel) WithTools(_ []*schema.ToolInfo) (model.ToolCallingChatModel, error) { + return m, nil +} + +type invokableTestTool struct { + name string + result string +} + +func (t *invokableTestTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "test tool", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Desc: "input", Required: true, Type: schema.String}, + }), + }, nil +} + +func (t *invokableTestTool) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) { + return t.result, nil +} + +type streamableTestTool struct { + name string + result string +} + +func (t *streamableTestTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "test tool", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Desc: "input", Required: true, Type: schema.String}, + }), + }, nil +} + +func (t *streamableTestTool) StreamableRun(_ context.Context, _ string, _ ...tool.Option) (*schema.StreamReader[string], error) { + return schema.StreamReaderFromArray([]string{t.result}), nil +} + +type enhancedInvokableTestTool struct { + name string + result string +} + +func (t *enhancedInvokableTestTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "test tool", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Desc: "input", Required: true, 
Type: schema.String}, + }), + }, nil +} + +func (t *enhancedInvokableTestTool) InvokableRun(_ context.Context, _ *schema.ToolArgument, _ ...tool.Option) (*schema.ToolResult, error) { + return &schema.ToolResult{ + Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: t.result}}, + }, nil +} + +type enhancedStreamableTestTool struct { + name string + result string +} + +func (t *enhancedStreamableTestTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: t.name, + Desc: "test tool", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "input": {Desc: "input", Required: true, Type: schema.String}, + }), + }, nil +} + +func (t *enhancedStreamableTestTool) StreamableRun(_ context.Context, _ *schema.ToolArgument, _ ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { + return schema.StreamReaderFromArray([]*schema.ToolResult{ + {Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: t.result}}}, + }), nil +} + +type invokableResultModifier struct { + *BaseChatModelAgentMiddleware + modifiedResult string +} + +func (h *invokableResultModifier) WrapInvokableToolCall(_ context.Context, endpoint InvokableToolCallEndpoint, _ *ToolContext) (InvokableToolCallEndpoint, error) { + return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { + _, err := endpoint(ctx, argumentsInJSON, opts...) + if err != nil { + return "", err + } + return h.modifiedResult, nil + }, nil +} + +type streamableResultModifier struct { + *BaseChatModelAgentMiddleware + modifiedResult string +} + +func (h *streamableResultModifier) WrapStreamableToolCall(_ context.Context, endpoint StreamableToolCallEndpoint, _ *ToolContext) (StreamableToolCallEndpoint, error) { + return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) { + sr, err := endpoint(ctx, argumentsInJSON, opts...) 
+ if err != nil { + return nil, err + } + sr.Close() + return schema.StreamReaderFromArray([]string{h.modifiedResult}), nil + }, nil +} + +type enhancedInvokableResultModifier struct { + *BaseChatModelAgentMiddleware + modifiedResult string +} + +func (h *enhancedInvokableResultModifier) WrapEnhancedInvokableToolCall(_ context.Context, endpoint EnhancedInvokableToolCallEndpoint, _ *ToolContext) (EnhancedInvokableToolCallEndpoint, error) { + return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.ToolResult, error) { + _, err := endpoint(ctx, toolArgument, opts...) + if err != nil { + return nil, err + } + return &schema.ToolResult{ + Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: h.modifiedResult}}, + }, nil + }, nil +} + +type enhancedStreamableResultModifier struct { + *BaseChatModelAgentMiddleware + modifiedResult string +} + +func (h *enhancedStreamableResultModifier) WrapEnhancedStreamableToolCall(_ context.Context, endpoint EnhancedStreamableToolCallEndpoint, _ *ToolContext) (EnhancedStreamableToolCallEndpoint, error) { + return func(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) { + sr, err := endpoint(ctx, toolArgument, opts...) 
+ if err != nil { + return nil, err + } + sr.Close() + return schema.StreamReaderFromArray([]*schema.ToolResult{ + {Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: h.modifiedResult}}}, + }), nil + }, nil +} + +func collectToolEvents(it *AsyncIterator[*AgentEvent]) []*AgentEvent { + var toolEvents []*AgentEvent + for { + ev, ok := it.Next() + if !ok { + break + } + if ev.Output == nil || ev.Output.MessageOutput == nil { + continue + } + mo := ev.Output.MessageOutput + if mo.Message != nil && mo.Message.Role == schema.Tool { + toolEvents = append(toolEvents, ev) + continue + } + if mo.IsStreaming && mo.Role == schema.Tool && mo.MessageStream != nil { + toolEvents = append(toolEvents, ev) + } + } + return toolEvents +} + +func collectToolContent(events []*AgentEvent) []string { + var contents []string + for _, ev := range events { + mo := ev.Output.MessageOutput + if !mo.IsStreaming && mo.Message != nil { + if mo.Message.Content != "" { + contents = append(contents, mo.Message.Content) + } else if len(mo.Message.UserInputMultiContent) > 0 { + for _, part := range mo.Message.UserInputMultiContent { + if part.Text != "" { + contents = append(contents, part.Text) + } + } + } + continue + } + if mo.IsStreaming && mo.MessageStream != nil { + var msgs []*schema.Message + for { + msg, err := mo.MessageStream.Recv() + if err != nil { + break + } + msgs = append(msgs, msg) + } + if len(msgs) > 0 { + concated, err := schema.ConcatMessages(msgs) + if err == nil { + if concated.Content != "" { + contents = append(contents, concated.Content) + } else if len(concated.UserInputMultiContent) > 0 { + for _, part := range concated.UserInputMultiContent { + if part.Text != "" { + contents = append(contents, part.Text) + } + } + } + } + } + } + } + return contents +} + +func TestEventSenderToolHandler(t *testing.T) { + t.Run("Invokable", func(t *testing.T) { + t.Run("DefaultSendsEvent", func(t *testing.T) { + ctx := context.Background() + testTool := 
&invokableTestTool{name: "test_tool", result: "invokable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: false}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, "invokable_output") + }) + + t.Run("UserConfiguredSkipsDefault", func(t *testing.T) { + ctx := context.Background() + testTool := &invokableTestTool{name: "test_tool", result: "invokable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{NewEventSenderToolWrapper()}, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: false}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + }) + + t.Run("InnermostGetsOriginalOutput", func(t *testing.T) { + ctx := context.Background() + originalResult := "original_invokable_output" + modifiedResult := "modified_invokable_output" + testTool := &invokableTestTool{name: "test_tool", result: originalResult} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + 
ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{ + NewEventSenderToolWrapper(), + &invokableResultModifier{ + BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, + modifiedResult: modifiedResult, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: false}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.GreaterOrEqual(t, len(toolEvents), 1) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, originalResult) + }) + }) + + t.Run("Streamable", func(t *testing.T) { + t.Run("DefaultSendsEvent", func(t *testing.T) { + ctx := context.Background() + testTool := &streamableTestTool{name: "test_tool", result: "streamable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: true}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, "streamable_output") + }) + + t.Run("UserConfiguredSkipsDefault", func(t *testing.T) { + ctx := context.Background() + testTool := &streamableTestTool{name: "test_tool", result: "streamable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: 
[]tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{NewEventSenderToolWrapper()}, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: true}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + }) + + t.Run("InnermostGetsOriginalOutput", func(t *testing.T) { + ctx := context.Background() + originalResult := "original_streamable_output" + modifiedResult := "modified_streamable_output" + testTool := &streamableTestTool{name: "test_tool", result: originalResult} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{ + NewEventSenderToolWrapper(), + &streamableResultModifier{ + BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, + modifiedResult: modifiedResult, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: true}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.GreaterOrEqual(t, len(toolEvents), 1) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, originalResult) + }) + }) + + t.Run("EnhancedInvokable", func(t *testing.T) { + t.Run("DefaultSendsEvent", func(t *testing.T) { + ctx := context.Background() + testTool := &enhancedInvokableTestTool{name: "test_tool", result: "enhanced_invokable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + 
Tools: []tool.BaseTool{testTool}, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: false}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, "enhanced_invokable_output") + }) + + t.Run("UserConfiguredSkipsDefault", func(t *testing.T) { + ctx := context.Background() + testTool := &enhancedInvokableTestTool{name: "test_tool", result: "enhanced_invokable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{NewEventSenderToolWrapper()}, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: false}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + }) + + t.Run("InnermostGetsOriginalOutput", func(t *testing.T) { + ctx := context.Background() + originalResult := "original_enhanced_invokable_output" + modifiedResult := "modified_enhanced_invokable_output" + testTool := &enhancedInvokableTestTool{name: "test_tool", result: originalResult} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{ + NewEventSenderToolWrapper(), + &enhancedInvokableResultModifier{ + BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, + 
modifiedResult: modifiedResult, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: false}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.GreaterOrEqual(t, len(toolEvents), 1) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, originalResult) + }) + }) + + t.Run("EnhancedStreamable", func(t *testing.T) { + t.Run("DefaultSendsEvent", func(t *testing.T) { + ctx := context.Background() + testTool := &enhancedStreamableTestTool{name: "test_tool", result: "enhanced_streamable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: true}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, "enhanced_streamable_output") + }) + + t.Run("UserConfiguredSkipsDefault", func(t *testing.T) { + ctx := context.Background() + testTool := &enhancedStreamableTestTool{name: "test_tool", result: "enhanced_streamable_output"} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{NewEventSenderToolWrapper()}, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: true}) + it := r.Run(ctx, 
[]Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.Equal(t, 1, len(toolEvents)) + }) + + t.Run("InnermostGetsOriginalOutput", func(t *testing.T) { + ctx := context.Background() + originalResult := "original_enhanced_streamable_output" + modifiedResult := "modified_enhanced_streamable_output" + testTool := &enhancedStreamableTestTool{name: "test_tool", result: originalResult} + mockModel := &mockToolCallingModel{toolCallName: "test_tool"} + + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent", + Model: mockModel, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{testTool}, + }, + }, + Handlers: []ChatModelAgentMiddleware{ + NewEventSenderToolWrapper(), + &enhancedStreamableResultModifier{ + BaseChatModelAgentMiddleware: &BaseChatModelAgentMiddleware{}, + modifiedResult: modifiedResult, + }, + }, + }) + assert.NoError(t, err) + + r := NewRunner(ctx, RunnerConfig{Agent: agent, EnableStreaming: true}) + it := r.Run(ctx, []Message{schema.UserMessage("test")}) + + toolEvents := collectToolEvents(it) + assert.GreaterOrEqual(t, len(toolEvents), 1) + contents := collectToolContent(toolEvents) + assert.Contains(t, contents, originalResult) + }) + }) +} diff --git a/components/model/agentic_callback_extra.go b/components/model/agentic_callback_extra.go new file mode 100644 index 000000000..9a769cf7e --- /dev/null +++ b/components/model/agentic_callback_extra.go @@ -0,0 +1,94 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package model + +import ( + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" +) + +// AgenticConfig is the config for the agentic model. +type AgenticConfig struct { + // Model is the model name. + Model string + // MaxTokens is the max number of output tokens, if reached the max tokens, the model will stop generating. + MaxTokens int + // Temperature is the temperature, which controls the randomness of the agentic model. + Temperature float32 + // TopP is the top p, which controls the diversity of the agentic model. + TopP float32 +} + +// AgenticCallbackInput is the input for the agentic model callback. +type AgenticCallbackInput struct { + // Messages is the agentic messages to be sent to the agentic model. + Messages []*schema.AgenticMessage + // Tools is the tools to be used in the agentic model. + Tools []*schema.ToolInfo + // Config is the config for the agentic model. + Config *AgenticConfig + // Extra is the extra information for the callback. + Extra map[string]any +} + +// AgenticCallbackOutput is the output for the agentic model callback. +type AgenticCallbackOutput struct { + // Message is the agentic message generated by the agentic model. + Message *schema.AgenticMessage + // Config is the config for the agentic model. + Config *AgenticConfig + // TokenUsage is the token usage of this request. + TokenUsage *TokenUsage + // Extra is the extra information for the callback. + Extra map[string]any +} + +// ConvAgenticCallbackInput converts the callback input to the agentic model callback input. 
+func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput { + switch t := src.(type) { + case *AgenticCallbackInput: + // when callback is triggered within component implementation, + // the input is usually already a typed *model.AgenticCallbackInput + return t + case []*schema.AgenticMessage: + // when callback is injected by graph node, not the component implementation itself, + // the input is the input of Agentic Model interface, which is []*schema.AgenticMessage + return &AgenticCallbackInput{ + Messages: t, + } + default: + return nil + } +} + +// ConvAgenticCallbackOutput converts the callback output to the agentic model callback output. +func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOutput { + switch t := src.(type) { + case *AgenticCallbackOutput: + // when callback is triggered within component implementation, + // the output is usually already a typed *model.AgenticCallbackOutput + return t + case *schema.AgenticMessage: + // when callback is injected by graph node, not the component implementation itself, + // the output is the output of Agentic Model interface, which is *schema.AgenticMessage + return &AgenticCallbackOutput{ + Message: t, + } + default: + return nil + } +} diff --git a/components/model/agentic_callback_extra_test.go b/components/model/agentic_callback_extra_test.go new file mode 100644 index 000000000..937367477 --- /dev/null +++ b/components/model/agentic_callback_extra_test.go @@ -0,0 +1,35 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package model + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestConvAgenticModel(t *testing.T) { + assert.NotNil(t, ConvAgenticCallbackInput(&AgenticCallbackInput{})) + assert.NotNil(t, ConvAgenticCallbackInput([]*schema.AgenticMessage{})) + assert.Nil(t, ConvAgenticCallbackInput("asd")) + + assert.NotNil(t, ConvAgenticCallbackOutput(&AgenticCallbackOutput{})) + assert.NotNil(t, ConvAgenticCallbackOutput(&schema.AgenticMessage{})) + assert.Nil(t, ConvAgenticCallbackOutput("asd")) +} diff --git a/components/model/interface.go b/components/model/interface.go index deb7b56dd..cf79785bc 100644 --- a/components/model/interface.go +++ b/components/model/interface.go @@ -89,3 +89,15 @@ type ToolCallingChatModel interface { // This method does not modify the current instance, making it safer for concurrent use. WithTools(tools []*schema.ToolInfo) (ToolCallingChatModel, error) } + +// AgenticModel defines the interface for agentic models that support AgenticMessage. +// It provides methods for generating complete and streaming outputs, and supports +// tool calling via the WithTools method. +type AgenticModel interface { + Generate(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.AgenticMessage, error) + Stream(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.StreamReader[*schema.AgenticMessage], error) + + // WithTools returns a new Model instance with the specified tools bound. 
+ // This method does not modify the current instance, making it safer for concurrent use. + WithTools(tools []*schema.ToolInfo) (AgenticModel, error) +} diff --git a/components/model/option.go b/components/model/option.go index 9fd96116c..2222e14a1 100644 --- a/components/model/option.go +++ b/components/model/option.go @@ -22,21 +22,39 @@ import "github.com/cloudwego/eino/schema" type Options struct { // Temperature is the temperature for the model, which controls the randomness of the model. Temperature *float32 - // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return an finish reason of "length". - MaxTokens *int // Model is the model name. Model *string // TopP is the top p for the model, which controls the diversity of the model. TopP *float32 - // Stop is the stop words for the model, which controls the stopping condition of the model. - Stop []string // Tools is a list of tools the model may call. Tools []*schema.ToolInfo + // DeferredTools is a list of tools to be registered with defer_loading=true + // for the model's built-in (server-side) tool search capability. + // These tools are sent to the model API but not loaded into context upfront — + // only their names and descriptions are visible to the model. The model's + // built-in tool search tool searches through them and loads matching ones + // on demand. + DeferredTools []*schema.ToolInfo + + ToolSearchTool *schema.ToolInfo + + // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return a finish reason of "length". + MaxTokens *int + // Stop is the stop words for the model, which controls the stopping condition of the model. + Stop []string + + // Options only available for chat model. + // ToolChoice controls which tool is called by the model. ToolChoice *schema.ToolChoice // AllowedToolNames specifies a list of tool names that the model is allowed to call. 
// This allows for constraining the model to a specific subset of the available tools. AllowedToolNames []string + + // Options only available for agentic model. + + // AgenticToolChoice controls how the agentic model calls tools. + AgenticToolChoice *schema.AgenticToolChoice } // Option is a call-time option for a ChatModel. Options are immutable and @@ -106,8 +124,36 @@ func WithTools(tools []*schema.ToolInfo) Option { } } +// WithToolSearchTool is the option to register a tool search tool with the model. +// When set, the model uses this tool to discover and load deferred tools on demand. +// Note: The tool search tool should NOT be included in WithTools. +func WithToolSearchTool(tool *schema.ToolInfo) Option { + return Option{ + apply: func(opts *Options) { + opts.ToolSearchTool = tool + }, + } +} + +// WithDeferredTools is the option to set deferred tools for the model's +// built-in (server-side) tool search. These tools are registered with +// defer_loading=true so the model can discover and load them on demand +// via its native tool search capability. +// Note: Deferred tools should NOT be included in WithTools. +func WithDeferredTools(tools []*schema.ToolInfo) Option { + if tools == nil { + tools = []*schema.ToolInfo{} + } + return Option{ + apply: func(opts *Options) { + opts.DeferredTools = tools + }, + } +} + // WithToolChoice sets the tool choice for the model. It also allows for providing a list of // tool names to constrain the model to a specific subset of the available tools. +// Only available for ChatModel. func WithToolChoice(toolChoice schema.ToolChoice, allowedToolNames ...string) Option { return Option{ apply: func(opts *Options) { @@ -117,6 +163,17 @@ func WithToolChoice(toolChoice schema.ToolChoice, allowedToolNames ...string) Op } } +// WithAgenticToolChoice is the option to set tool choice for the agentic model. +// Only available for AgenticModel. 
+func WithAgenticToolChoice(toolChoice *schema.AgenticToolChoice) Option { + return Option{ + apply: func(opts *Options) { + opts.AgenticToolChoice = toolChoice + }, + } +} + +// WrapImplSpecificOptFn is the option to wrap the implementation specific option function. // WrapImplSpecificOptFn wraps an implementation-specific option function into // an [Option] so it can be passed alongside standard options. // diff --git a/components/model/option_test.go b/components/model/option_test.go index 36872c30e..c836933b7 100644 --- a/components/model/option_test.go +++ b/components/model/option_test.go @@ -82,6 +82,29 @@ func TestOptions(t *testing.T) { convey.So(opts.Tools, convey.ShouldNotBeNil) convey.So(len(opts.Tools), convey.ShouldEqual, 0) }) + + convey.Convey("test agentic tool choice option", t, func() { + var ( + toolChoice = schema.ToolChoiceForced + allowedTools = []*schema.AllowedTool{ + {FunctionName: "agentic_tool"}, + } + ) + opts := GetCommonOptions( + nil, + WithAgenticToolChoice(&schema.AgenticToolChoice{ + Type: toolChoice, + Forced: &schema.AgenticForcedToolChoice{ + Tools: allowedTools, + }, + }), + ) + + convey.So(opts.AgenticToolChoice, convey.ShouldNotBeNil) + convey.So(opts.AgenticToolChoice.Type, convey.ShouldEqual, toolChoice) + convey.So(opts.AgenticToolChoice.Forced, convey.ShouldNotBeNil) + convey.So(opts.AgenticToolChoice.Forced.Tools, convey.ShouldResemble, allowedTools) + }) } type implOption struct { diff --git a/components/prompt/agentic_callback_extra.go b/components/prompt/agentic_callback_extra.go new file mode 100644 index 000000000..315d5a4da --- /dev/null +++ b/components/prompt/agentic_callback_extra.go @@ -0,0 +1,70 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" +) + +// AgenticCallbackInput is the input for the callback. +type AgenticCallbackInput struct { + // Variables is the variables for the callback. + Variables map[string]any + // Templates is the agentic templates for the callback. + Templates []schema.AgenticMessagesTemplate + // Extra is the extra information for the callback. + Extra map[string]any +} + +// AgenticCallbackOutput is the output for the callback. +type AgenticCallbackOutput struct { + // Result is the agentic result for the callback. + Result []*schema.AgenticMessage + // Templates is the agentic templates for the callback. + Templates []schema.AgenticMessagesTemplate + // Extra is the extra information for the callback. + Extra map[string]any +} + +// ConvAgenticCallbackInput converts the callback input to the agentic prompt callback input. +func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput { + switch t := src.(type) { + case *AgenticCallbackInput: + return t + case map[string]any: + return &AgenticCallbackInput{ + Variables: t, + } + default: + return nil + } +} + +// ConvAgenticCallbackOutput converts the callback output to the agentic prompt callback output. 
+func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOutput { + switch t := src.(type) { + case *AgenticCallbackOutput: + return t + case []*schema.AgenticMessage: + return &AgenticCallbackOutput{ + Result: t, + } + default: + return nil + } +} diff --git a/components/prompt/agentic_callback_extra_test.go b/components/prompt/agentic_callback_extra_test.go new file mode 100644 index 000000000..67982be80 --- /dev/null +++ b/components/prompt/agentic_callback_extra_test.go @@ -0,0 +1,46 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package prompt + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestConvAgenticPrompt(t *testing.T) { + assert.NotNil(t, ConvAgenticCallbackInput(&AgenticCallbackInput{ + Variables: map[string]any{}, + Templates: []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{}, + }, + })) + assert.NotNil(t, ConvAgenticCallbackInput(map[string]any{})) + assert.Nil(t, ConvAgenticCallbackInput("asd")) + + assert.NotNil(t, ConvAgenticCallbackOutput(&AgenticCallbackOutput{ + Result: []*schema.AgenticMessage{ + {}, + }, + Templates: []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{}, + }, + })) + assert.NotNil(t, ConvAgenticCallbackOutput([]*schema.AgenticMessage{})) +} diff --git a/components/prompt/agentic_chat_template.go b/components/prompt/agentic_chat_template.go new file mode 100644 index 000000000..41d291065 --- /dev/null +++ b/components/prompt/agentic_chat_template.go @@ -0,0 +1,84 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "context" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/schema" +) + +// FromAgenticMessages creates a new DefaultAgenticChatTemplate from the given templates and format type. +// eg. 
+// +// template := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// // in chain, or graph +// chain := compose.NewChain[map[string]any, []*schema.AgenticMessage]() +// chain.AppendAgenticChatTemplate(template) +func FromAgenticMessages(formatType schema.FormatType, templates ...schema.AgenticMessagesTemplate) *DefaultAgenticChatTemplate { + return &DefaultAgenticChatTemplate{ + templates: templates, + formatType: formatType, + } +} + +type DefaultAgenticChatTemplate struct { + templates []schema.AgenticMessagesTemplate + formatType schema.FormatType +} + +func (t *DefaultAgenticChatTemplate) Format(ctx context.Context, vs map[string]any, opts ...Option) (result []*schema.AgenticMessage, err error) { + ctx = callbacks.EnsureRunInfo(ctx, t.GetType(), components.ComponentOfAgenticPrompt) + ctx = callbacks.OnStart(ctx, &AgenticCallbackInput{ + Variables: vs, + Templates: t.templates, + }) + defer func() { + if err != nil { + _ = callbacks.OnError(ctx, err) + } + }() + + result = make([]*schema.AgenticMessage, 0, len(t.templates)) + for _, template := range t.templates { + msgs, err := template.Format(ctx, vs, t.formatType) + if err != nil { + return nil, err + } + + result = append(result, msgs...) + } + + _ = callbacks.OnEnd(ctx, &AgenticCallbackOutput{ + Result: result, + Templates: t.templates, + }) + + return result, nil +} + +// GetType returns the type of the agentic chat template ("Default"). +func (t *DefaultAgenticChatTemplate) GetType() string { + return "Default" +} + +// IsCallbacksEnabled checks if the callbacks are enabled for the chat template.
+func (t *DefaultAgenticChatTemplate) IsCallbacksEnabled() bool { + return true +} diff --git a/components/prompt/agentic_chat_template_test.go b/components/prompt/agentic_chat_template_test.go new file mode 100644 index 000000000..f47020a2c --- /dev/null +++ b/components/prompt/agentic_chat_template_test.go @@ -0,0 +1,125 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" +) + +type mockAgenticTemplate struct { + err error +} + +func (m *mockAgenticTemplate) Format(ctx context.Context, vs map[string]any, formatType schema.FormatType) ([]*schema.AgenticMessage, error) { + if m.err != nil { + return nil, m.err + } + return []*schema.AgenticMessage{schema.UserAgenticMessage("mocked")}, nil +} + +func TestFromAgenticMessages(t *testing.T) { + t.Run("create template", func(t *testing.T) { + tpl := schema.UserAgenticMessage("hello") + ft := schema.FString + at := FromAgenticMessages(ft, tpl) + + assert.NotNil(t, at) + assert.Equal(t, ft, at.formatType) + assert.Len(t, at.templates, 1) + assert.Same(t, tpl, at.templates[0]) + }) +} + +func TestDefaultAgenticTemplate_GetType(t *testing.T) { + t.Run("get type", func(t *testing.T) { + at := &DefaultAgenticChatTemplate{} + assert.Equal(t, "Default", at.GetType()) + }) +} + +func 
TestDefaultAgenticTemplate_IsCallbacksEnabled(t *testing.T) { + t.Run("callbacks enabled", func(t *testing.T) { + at := &DefaultAgenticChatTemplate{} + assert.True(t, at.IsCallbacksEnabled()) + }) +} + +func TestDefaultAgenticTemplate_Format(t *testing.T) { + t.Run("success", func(t *testing.T) { + // Mock callback handler + cb := callbacks.NewHandlerBuilder(). + OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { + assert.Equal(t, "Default", info.Type) + return ctx + }). + OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { + assert.Equal(t, "Default", info.Type) + return ctx + }). + OnErrorFn(func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + assert.Fail(t, "unexpected error callback") + return ctx + }). + Build() + + tpl := schema.UserAgenticMessage("hello {val}") + at := FromAgenticMessages(schema.FString, tpl) + + ctx := context.Background() + ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{ + Type: "Default", + Component: "agentic_prompt", + }, cb) + + res, err := at.Format(ctx, map[string]any{"val": "world"}) + assert.NoError(t, err) + assert.Len(t, res, 1) + assert.Equal(t, "hello world", res[0].ContentBlocks[0].UserInputText.Text) + }) + + t.Run("template format error", func(t *testing.T) { + mockErr := errors.New("mock error") + mockTpl := &mockAgenticTemplate{err: mockErr} + at := FromAgenticMessages(schema.FString, mockTpl) + + // Mock callback handler to verify OnError + cb := callbacks.NewHandlerBuilder(). + OnErrorFn(func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + assert.Equal(t, mockErr, err) + return ctx + }). 
+ Build() + + ctx := context.Background() + ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{ + Type: "Default", + Component: "agentic_prompt", + }, cb) + + res, err := at.Format(ctx, map[string]any{}) + assert.Error(t, err) + assert.Nil(t, res) + assert.Equal(t, mockErr, err) + }) +} diff --git a/components/prompt/callback_extra_test.go b/components/prompt/callback_extra_test.go index 456297e29..ad8a3c0c2 100644 --- a/components/prompt/callback_extra_test.go +++ b/components/prompt/callback_extra_test.go @@ -25,11 +25,21 @@ import ( ) func TestConvPrompt(t *testing.T) { - assert.NotNil(t, ConvCallbackInput(&CallbackInput{})) + assert.NotNil(t, ConvCallbackInput(&CallbackInput{ + Templates: []schema.MessagesTemplate{ + &schema.Message{}, + }, + })) assert.NotNil(t, ConvCallbackInput(map[string]any{})) assert.Nil(t, ConvCallbackInput("asd")) - assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{})) + assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{ + Result: []*schema.Message{ + {}, + }, + Templates: []schema.MessagesTemplate{ + &schema.Message{}, + }, + })) assert.NotNil(t, ConvCallbackOutput([]*schema.Message{})) - assert.Nil(t, ConvCallbackOutput("asd")) } diff --git a/components/prompt/interface.go b/components/prompt/interface.go index eac695eda..2d5a2cbed 100644 --- a/components/prompt/interface.go +++ b/components/prompt/interface.go @@ -23,6 +23,7 @@ import ( ) var _ ChatTemplate = &DefaultChatTemplate{} +var _ AgenticChatTemplate = &DefaultAgenticChatTemplate{} // ChatTemplate formats a variables map into a list of messages for a ChatModel. // @@ -42,3 +43,8 @@ var _ ChatTemplate = &DefaultChatTemplate{} type ChatTemplate interface { Format(ctx context.Context, vs map[string]any, opts ...Option) ([]*schema.Message, error) } + +// AgenticChatTemplate formats variables into a list of agentic messages according to a prompt schema. 
+type AgenticChatTemplate interface { + Format(ctx context.Context, vs map[string]any, opts ...Option) ([]*schema.AgenticMessage, error) +} diff --git a/components/types.go b/components/types.go index a546ae59f..2b0ad8f0e 100644 --- a/components/types.go +++ b/components/types.go @@ -66,8 +66,12 @@ type Component string const ( // ComponentOfPrompt identifies chat template components. ComponentOfPrompt Component = "ChatTemplate" + // ComponentOfAgenticPrompt identifies agentic template components. + ComponentOfAgenticPrompt Component = "AgenticChatTemplate" // ComponentOfChatModel identifies chat model components. ComponentOfChatModel Component = "ChatModel" + // ComponentOfAgenticModel identifies agentic model components. + ComponentOfAgenticModel Component = "AgenticModel" // ComponentOfEmbedding identifies embedding components. ComponentOfEmbedding Component = "Embedding" // ComponentOfIndexer identifies indexer components. diff --git a/compose/agentic_tools_node.go b/compose/agentic_tools_node.go new file mode 100644 index 000000000..96aef7b72 --- /dev/null +++ b/compose/agentic_tools_node.go @@ -0,0 +1,126 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + + "github.com/cloudwego/eino/schema" +) + +// NewAgenticToolsNode creates a new AgenticToolsNode. +// e.g. 
+// +// conf := &ToolsNodeConfig{ +// Tools: []tool.BaseTool{invokableTool1, streamableTool2}, +// } +// toolsNode, err := NewAgenticToolsNode(ctx, conf) +func NewAgenticToolsNode(ctx context.Context, conf *ToolsNodeConfig) (*AgenticToolsNode, error) { + tn, err := NewToolNode(ctx, conf) + if err != nil { + return nil, err + } + return &AgenticToolsNode{inner: tn}, nil +} + +type AgenticToolsNode struct { + inner *ToolsNode +} + +func (a *AgenticToolsNode) Invoke(ctx context.Context, input *schema.AgenticMessage, opts ...ToolsNodeOption) ([]*schema.AgenticMessage, error) { + result, err := a.inner.Invoke(ctx, agenticMessageToToolCallMessage(input), opts...) + if err != nil { + return nil, err + } + return toolMessageToAgenticMessage(result), nil +} + +func (a *AgenticToolsNode) Stream(ctx context.Context, input *schema.AgenticMessage, + opts ...ToolsNodeOption) (*schema.StreamReader[[]*schema.AgenticMessage], error) { + result, err := a.inner.Stream(ctx, agenticMessageToToolCallMessage(input), opts...) 
+ if err != nil { + return nil, err + } + return streamToolMessageToAgenticMessage(result), nil +} + +func agenticMessageToToolCallMessage(input *schema.AgenticMessage) *schema.Message { + var tc []schema.ToolCall + for _, block := range input.ContentBlocks { + if block.Type != schema.ContentBlockTypeFunctionToolCall || block.FunctionToolCall == nil { + continue + } + tc = append(tc, schema.ToolCall{ + ID: block.FunctionToolCall.CallID, + Function: schema.FunctionCall{ + Name: block.FunctionToolCall.Name, + Arguments: block.FunctionToolCall.Arguments, + }, + Extra: block.Extra, + }) + } + return &schema.Message{ + Role: schema.Assistant, + ToolCalls: tc, + } +} + +func toolMessageToAgenticMessage(input []*schema.Message) []*schema.AgenticMessage { + var results []*schema.ContentBlock + for _, m := range input { + results = append(results, &schema.ContentBlock{ + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: m.ToolCallID, + Name: m.ToolName, + Result: m.Content, + }, + Extra: m.Extra, + }) + } + return []*schema.AgenticMessage{{ + Role: schema.AgenticRoleTypeUser, + ContentBlocks: results, + }} +} + +func streamToolMessageToAgenticMessage(input *schema.StreamReader[[]*schema.Message]) *schema.StreamReader[[]*schema.AgenticMessage] { + return schema.StreamReaderWithConvert(input, func(t []*schema.Message) ([]*schema.AgenticMessage, error) { + var results []*schema.ContentBlock + for i, m := range t { + if m == nil { + continue + } + results = append(results, &schema.ContentBlock{ + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: m.ToolCallID, + Name: m.ToolName, + Result: m.Content, + }, + StreamingMeta: &schema.StreamingMeta{Index: i}, + Extra: m.Extra, + }) + } + return []*schema.AgenticMessage{{ + Role: schema.AgenticRoleTypeUser, + ContentBlocks: results, + }}, nil + }) +} + +func (a *AgenticToolsNode) GetType() string { return "" } diff 
--git a/compose/agentic_tools_node_test.go b/compose/agentic_tools_node_test.go new file mode 100644 index 000000000..4641dd8ae --- /dev/null +++ b/compose/agentic_tools_node_test.go @@ -0,0 +1,239 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "io" + "testing" + + "github.com/bytedance/sonic" + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestAgenticMessageToToolCallMessage(t *testing.T) { + input := &schema.AgenticMessage{ + ContentBlocks: []*schema.ContentBlock{ + { + Type: schema.ContentBlockTypeFunctionToolCall, + FunctionToolCall: &schema.FunctionToolCall{ + CallID: "1", + Name: "name1", + Arguments: "arg1", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolCall, + FunctionToolCall: &schema.FunctionToolCall{ + CallID: "2", + Name: "name2", + Arguments: "arg2", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolCall, + FunctionToolCall: &schema.FunctionToolCall{ + CallID: "3", + Name: "name3", + Arguments: "arg3", + }, + }, + }, + } + ret := agenticMessageToToolCallMessage(input) + assert.Equal(t, schema.Assistant, ret.Role) + assert.Equal(t, []schema.ToolCall{ + { + ID: "1", + Function: schema.FunctionCall{ + Name: "name1", + Arguments: "arg1", + }, + }, + { + ID: "2", + Function: schema.FunctionCall{ + Name: "name2", + Arguments: "arg2", + }, + }, + { + ID: "3", + Function: schema.FunctionCall{ + Name: "name3", + 
Arguments: "arg3", + }, + }, + }, ret.ToolCalls) +} + +func TestToolMessageToAgenticMessage(t *testing.T) { + input := []*schema.Message{ + { + Role: schema.Tool, + Content: "content1", + ToolCallID: "1", + ToolName: "name1", + }, + { + Role: schema.Tool, + Content: "content2", + ToolCallID: "2", + ToolName: "name2", + }, + { + Role: schema.Tool, + Content: "content3", + ToolCallID: "3", + ToolName: "name3", + }, + } + ret := toolMessageToAgenticMessage(input) + assert.Equal(t, 1, len(ret)) + assert.Equal(t, schema.AgenticRoleTypeUser, ret[0].Role) + assert.Equal(t, []*schema.ContentBlock{ + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "1", + Name: "name1", + Result: "content1", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "2", + Name: "name2", + Result: "content2", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "3", + Name: "name3", + Result: "content3", + }, + }, + }, ret[0].ContentBlocks) +} + +func TestStreamToolMessageToAgenticMessage(t *testing.T) { + input := schema.StreamReaderFromArray([][]*schema.Message{ + { + { + Role: schema.Tool, + Content: "content1-1", + ToolName: "name1", + ToolCallID: "1", + }, + nil, nil, + }, + { + nil, + { + Role: schema.Tool, + Content: "content2-1", + ToolName: "name2", + ToolCallID: "2", + }, + nil, + }, + { + nil, + { + Role: schema.Tool, + Content: "content2-2", + ToolName: "name2", + ToolCallID: "2", + }, + nil, + }, + { + nil, nil, + { + Role: schema.Tool, + Content: "content3-1", + ToolName: "name3", + ToolCallID: "3", + }, + }, + { + nil, nil, + { + Role: schema.Tool, + Content: "content3-2", + ToolName: "name3", + ToolCallID: "3", + }, + }, + }) + ret := streamToolMessageToAgenticMessage(input) + var chunks [][]*schema.AgenticMessage + for { + chunk, err := ret.Recv() + if err == io.EOF { + break 
+ } + assert.NoError(t, err) + chunks = append(chunks, chunk) + } + result, err := schema.ConcatAgenticMessagesArray(chunks) + assert.NoError(t, err) + + actualStr, err := sonic.MarshalString(result) + assert.NoError(t, err) + + expected := []*schema.AgenticMessage{ + { + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "1", + Name: "name1", + Result: "content1-1", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "2", + Name: "name2", + Result: "content2-1content2-2", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "3", + Name: "name3", + Result: "content3-1content3-2", + }, + }, + }, + }, + } + + expectedStr, err := sonic.MarshalString(expected) + assert.NoError(t, err) + + assert.Equal(t, expectedStr, actualStr) +} diff --git a/compose/chain.go b/compose/chain.go index 5e4a8e1c0..abfa6bf1d 100644 --- a/compose/chain.go +++ b/compose/chain.go @@ -174,6 +174,18 @@ func (c *Chain[I, O]) AppendChatModel(node model.BaseChatModel, opts ...GraphAdd return c } +// AppendAgenticModel add a agentic.Model node to the chain. +// e.g. +// +// model, err := openai.NewAgenticModel(ctx, config) +// if err != nil {...} +// chain.AppendAgenticModel(model) +func (c *Chain[I, O]) AppendAgenticModel(node model.AgenticModel, opts ...GraphAddNodeOpt) *Chain[I, O] { + gNode, options := toAgenticModelNode(node, opts...) + c.addNode(gNode, options) + return c +} + // AppendChatTemplate add a ChatTemplate node to the chain. // eg. // @@ -189,11 +201,23 @@ func (c *Chain[I, O]) AppendChatTemplate(node prompt.ChatTemplate, opts ...Graph return c } +// AppendAgenticChatTemplate add a prompt.AgenticChatTemplate node to the chain. +// eg. 
+// +// chatTemplate := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// chain.AppendAgenticChatTemplate(chatTemplate) +func (c *Chain[I, O]) AppendAgenticChatTemplate(node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) *Chain[I, O] { + gNode, options := toAgenticChatTemplateNode(node, opts...) + c.addNode(gNode, options) + return c +} + // AppendToolsNode add a ToolsNode node to the chain. // e.g. // -// toolsNode, err := tools.NewToolNode(ctx, &tools.ToolsNodeConfig{ -// Tools: []tools.Tool{...}, +// toolsNode, err := compose.NewToolNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tool.BaseTool{...}, // }) // // chain.AppendToolsNode(toolsNode) @@ -203,6 +227,20 @@ func (c *Chain[I, O]) AppendToolsNode(node *ToolsNode, opts ...GraphAddNodeOpt) return c } +// AppendAgenticToolsNode adds an AgenticToolsNode node to the chain. +// e.g. +// +// toolsNode, err := compose.NewAgenticToolsNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tool.BaseTool{...}, +// }) +// +// chain.AppendAgenticToolsNode(toolsNode) +func (c *Chain[I, O]) AppendAgenticToolsNode(node *AgenticToolsNode, opts ...GraphAddNodeOpt) *Chain[I, O] { + gNode, options := toAgenticToolsNode(node, opts...) + c.addNode(gNode, options) + return c +} + // AppendDocumentTransformer add a DocumentTransformer node to the chain. // e.g. // diff --git a/compose/chain_branch.go b/compose/chain_branch.go index ec3a433af..84fb11048 100644 --- a/compose/chain_branch.go +++ b/compose/chain_branch.go @@ -146,6 +146,22 @@ func (cb *ChainBranch) AddChatModel(key string, node model.BaseChatModel, opts . return cb.addNode(key, gNode, options) } +// AddAgenticModel adds a model.AgenticModel node to the branch. +// eg.
+// +// model1, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o", +// }) +// model2, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o-mini", +// }) +// cb.AddAgenticModel("agentic_model_key_1", model1) +// cb.AddAgenticModel("agentic_model_key_2", model2) +func (cb *ChainBranch) AddAgenticModel(key string, node model.AgenticModel, opts ...GraphAddNodeOpt) *ChainBranch { + gNode, options := toAgenticModelNode(node, opts...) + return cb.addNode(key, gNode, options) +} + // AddChatTemplate adds a ChatTemplate node to the branch. // eg. // @@ -167,11 +183,26 @@ func (cb *ChainBranch) AddChatTemplate(key string, node prompt.ChatTemplate, opt return cb.addNode(key, gNode, options) } +// AddAgenticChatTemplate adds a prompt.AgenticChatTemplate node to the branch. +// eg. +// +// chatTemplate := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// cb.AddAgenticChatTemplate("chat_template_key_01", chatTemplate) +// +// chatTemplate2 := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// cb.AddAgenticChatTemplate("chat_template_key_02", chatTemplate2) +func (cb *ChainBranch) AddAgenticChatTemplate(key string, node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) *ChainBranch { + gNode, options := toAgenticChatTemplateNode(node, opts...) + return cb.addNode(key, gNode, options) +} + // AddToolsNode adds a ToolsNode to the branch. // eg. // -// toolsNode, err := tools.NewToolNode(ctx, &tools.ToolsNodeConfig{ -// Tools: []tools.Tool{...}, +// toolsNode, err := compose.NewToolNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tool.BaseTool{...}, // }) // // cb.AddToolsNode("tools_node_key", toolsNode) @@ -180,6 +211,19 @@ func (cb *ChainBranch) AddToolsNode(key string, node *ToolsNode, opts ...GraphAd return cb.addNode(key, gNode, options) } +// AddAgenticToolsNode adds an AgenticToolsNode to the branch. +// eg.
+// +// toolsNode, err := compose.NewAgenticToolsNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tool.BaseTool{...}, +// }) +// +// cb.AddAgenticToolsNode("tools_node_key", toolsNode) +func (cb *ChainBranch) AddAgenticToolsNode(key string, node *AgenticToolsNode, opts ...GraphAddNodeOpt) *ChainBranch { + gNode, options := toAgenticToolsNode(node, opts...) + return cb.addNode(key, gNode, options) +} + // AddLambda adds a Lambda node to the branch. // eg. // diff --git a/compose/chain_parallel.go b/compose/chain_parallel.go index 64cdf2db1..463140be2 100644 --- a/compose/chain_parallel.go +++ b/compose/chain_parallel.go @@ -70,6 +70,24 @@ func (p *Parallel) AddChatModel(outputKey string, node model.BaseChatModel, opts return p.addNode(outputKey, gNode, options) } +// AddAgenticModel adds a model.AgenticModel to the parallel. +// eg. +// +// model1, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o", +// }) +// +// model2, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o", +// }) +// +// p.AddAgenticModel("output_key1", model1) +// p.AddAgenticModel("output_key2", model2) +func (p *Parallel) AddAgenticModel(outputKey string, node model.AgenticModel, opts ...GraphAddNodeOpt) *Parallel { + gNode, options := toAgenticModelNode(node, append(opts, WithOutputKey(outputKey))...) + return p.addNode(outputKey, gNode, options) +} + // AddChatTemplate adds a chat template to the parallel. // eg. // @@ -84,6 +102,17 @@ func (p *Parallel) AddChatTemplate(outputKey string, node prompt.ChatTemplate, o return p.addNode(outputKey, gNode, options) } +// AddAgenticChatTemplate adds a prompt.AgenticChatTemplate to the parallel. +// eg.
+// +// chatTemplate01 := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// p.AddAgenticChatTemplate("output_key01", chatTemplate01) +func (p *Parallel) AddAgenticChatTemplate(outputKey string, node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) *Parallel { + gNode, options := toAgenticChatTemplateNode(node, append(opts, WithOutputKey(outputKey))...) + return p.addNode(outputKey, gNode, options) +} + // AddToolsNode adds a tools node to the parallel. // eg. // @@ -97,6 +126,19 @@ func (p *Parallel) AddToolsNode(outputKey string, node *ToolsNode, opts ...Graph return p.addNode(outputKey, gNode, options) } +// AddAgenticToolsNode adds a tools node to the parallel. +// eg. +// +// toolsNode, err := compose.NewAgenticToolsNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tool.BaseTool{...}, +// }) +// +// p.AddAgenticToolsNode("output_key01", toolsNode) +func (p *Parallel) AddAgenticToolsNode(outputKey string, node *AgenticToolsNode, opts ...GraphAddNodeOpt) *Parallel { + gNode, options := toAgenticToolsNode(node, append(opts, WithOutputKey(outputKey))...) + return p.addNode(outputKey, gNode, options) +} + // AddLambda adds a lambda node to the parallel. // eg.
// diff --git a/compose/checkpoint_test.go b/compose/checkpoint_test.go index c24b6ce6f..a86c02fb3 100644 --- a/compose/checkpoint_test.go +++ b/compose/checkpoint_test.go @@ -1383,6 +1383,7 @@ func TestCancelInterrupt(t *testing.T) { info, success := ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.AfterNodes) + assert.True(t, info.FromGraphInterrupt) result, err := r.Invoke(ctx, "input", WithCheckPointID("1")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1397,6 +1398,7 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.AfterNodes) + assert.True(t, info.FromGraphInterrupt) result, err = r.Invoke(ctx, "input", WithCheckPointID("2")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1412,6 +1414,7 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.RerunNodes) + assert.True(t, info.FromGraphInterrupt) result, err = r.Invoke(ctx, "input", WithCheckPointID("3")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1441,6 +1444,7 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.AfterNodes) + assert.True(t, info.FromGraphInterrupt) result, err = r.Invoke(ctx, "input", WithCheckPointID("1")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1455,6 +1459,7 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, []string{"1"}, info.AfterNodes) + assert.True(t, info.FromGraphInterrupt) result, err = r.Invoke(ctx, "input", WithCheckPointID("2")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1470,6 +1475,7 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) 
assert.Equal(t, []string{"1"}, info.RerunNodes) + assert.True(t, info.FromGraphInterrupt) result, err = r.Invoke(ctx, "input", WithCheckPointID("3")) assert.NoError(t, err) assert.Equal(t, "input12", result) @@ -1510,6 +1516,7 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, 2, len(info.AfterNodes)) + assert.True(t, info.FromGraphInterrupt) result2, err := rr.Invoke(ctx, "input", WithCheckPointID("1")) assert.NoError(t, err) assert.Equal(t, map[string]any{ @@ -1528,6 +1535,7 @@ func TestCancelInterrupt(t *testing.T) { info, success = ExtractInterruptInfo(err) assert.True(t, success) assert.Equal(t, 2, len(info.RerunNodes)) + assert.True(t, info.FromGraphInterrupt) result2, err = rr.Invoke(ctx, "input", WithCheckPointID("2")) assert.NoError(t, err) assert.Equal(t, map[string]any{ @@ -1536,6 +1544,26 @@ func TestCancelInterrupt(t *testing.T) { }, result2) } +func TestBusinessInterruptFromGraphInterruptFalse(t *testing.T) { + g := NewGraph[string, string]() + _ = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return "", Interrupt(ctx, "biz") + })) + _ = g.AddEdge(START, "1") + _ = g.AddEdge("1", END) + + ctx := context.Background() + r, err := g.Compile(ctx, WithCheckPointStore(newInMemoryStore())) + assert.NoError(t, err) + + _, err = r.Invoke(ctx, "input", WithCheckPointID("biz")) + assert.Error(t, err) + info, existed := ExtractInterruptInfo(err) + assert.True(t, existed) + assert.False(t, info.FromGraphInterrupt) + assert.Equal(t, []string{"1"}, info.RerunNodes) +} + func TestPersistRerunInputNonStream(t *testing.T) { store := newInMemoryStore() diff --git a/compose/component_to_graph_node.go b/compose/component_to_graph_node.go index ab4694f1a..4bd27fe34 100644 --- a/compose/component_to_graph_node.go +++ b/compose/component_to_graph_node.go @@ -101,6 +101,17 @@ func toChatModelNode(node model.BaseChatModel, opts 
...GraphAddNodeOpt) (*graphN opts...) } +func toAgenticModelNode(node model.AgenticModel, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { + return toComponentNode( + node, + components.ComponentOfAgenticModel, + node.Generate, + node.Stream, + nil, nil, + opts..., + ) +} + func toChatTemplateNode(node prompt.ChatTemplate, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { return toComponentNode( node, @@ -112,6 +123,16 @@ func toChatTemplateNode(node prompt.ChatTemplate, opts ...GraphAddNodeOpt) (*gra opts...) } +func toAgenticChatTemplateNode(node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { + return toComponentNode( + node, + components.ComponentOfAgenticPrompt, + node.Format, + nil, nil, nil, + opts..., + ) +} + func toDocumentTransformerNode(node document.Transformer, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { return toComponentNode( node, @@ -134,6 +155,17 @@ func toToolsNode(node *ToolsNode, opts ...GraphAddNodeOpt) (*graphNode, *graphAd opts...) } +func toAgenticToolsNode(node *AgenticToolsNode, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { + return toComponentNode( + node, + ComponentOfAgenticToolsNode, + node.Invoke, + node.Stream, + nil, nil, + opts..., + ) +} + func toLambdaNode(node *Lambda, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { info, options := getNodeInfo(opts...) diff --git a/compose/graph.go b/compose/graph.go index 9370665f0..bcf5ae423 100644 --- a/compose/graph.go +++ b/compose/graph.go @@ -352,6 +352,19 @@ func (g *graph) AddChatModelNode(key string, node model.BaseChatModel, opts ...G return g.addNode(key, gNode, options) } +// AddAgenticModelNode add node that implements agentic.Model. +// e.g. 
+// +// model, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o", +// }) +// +// graph.AddAgenticModelNode("agentic_model_node_key", model) +func (g *graph) AddAgenticModelNode(key string, node model.AgenticModel, opts ...GraphAddNodeOpt) error { + gNode, options := toAgenticModelNode(node, opts...) + return g.addNode(key, gNode, options) +} + // AddChatTemplateNode add node that implements prompt.ChatTemplate. // e.g. // @@ -366,10 +379,21 @@ func (g *graph) AddChatTemplateNode(key string, node prompt.ChatTemplate, opts . return g.addNode(key, gNode, options) } -// AddToolsNode adds a node that implements tools.ToolsNode. +// AddAgenticChatTemplateNode add node that implements prompt.AgenticChatTemplate. +// e.g. +// +// chatTemplate, err := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// graph.AddAgenticChatTemplateNode("chat_template_node_key", chatTemplate) +func (g *graph) AddAgenticChatTemplateNode(key string, node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) error { + gNode, options := toAgenticChatTemplateNode(node, opts...) + return g.addNode(key, gNode, options) +} + +// AddToolsNode adds a node that implements ToolsNode. // e.g. // -// toolsNode, err := tools.NewToolNode(ctx, &tools.ToolsNodeConfig{}) +// toolsNode, err := compose.NewToolNode(ctx, &compose.ToolsNodeConfig{}) // // graph.AddToolsNode("tools_node_key", toolsNode) func (g *graph) AddToolsNode(key string, node *ToolsNode, opts ...GraphAddNodeOpt) error { @@ -377,6 +401,17 @@ func (g *graph) AddToolsNode(key string, node *ToolsNode, opts ...GraphAddNodeOp return g.addNode(key, gNode, options) } +// AddAgenticToolsNode adds a node that implements AgenticToolsNode. +// e.g. 
+// +// toolsNode, err := compose.NewAgenticToolsNode(ctx, &compose.ToolsNodeConfig{}) +// +// graph.AddAgenticToolsNode("tools_node_key", toolsNode) +func (g *graph) AddAgenticToolsNode(key string, node *AgenticToolsNode, opts ...GraphAddNodeOpt) error { + gNode, options := toAgenticToolsNode(node, opts...) + return g.addNode(key, gNode, options) +} + // AddDocumentTransformerNode adds a node that implements document.Transformer. // e.g. // diff --git a/compose/graph_manager.go b/compose/graph_manager.go index 944a0cf0a..46df3488e 100644 --- a/compose/graph_manager.go +++ b/compose/graph_manager.go @@ -496,12 +496,15 @@ func receiveWithListening(recv func() (*task, bool), cancel chan *time.Duration) return p.ta, p.closed, false, false, nil case timeout, ok := <-cancel: if !ok { - // unreachable - break + // The cancel channel has been closed — this means a previous call to + // receiveWithListening already consumed the cancel signal (task completed + // at the same time as cancel, and select picked the task result). Since + // cancel was already issued, treat this as an immediate cancel rather than + // blocking forever on resultCh. 
+ return nil, false, true, true, nil } canceled = true if timeout == nil { - // canceled without timeout break } timeoutCh = time.After(*timeout) diff --git a/compose/graph_run.go b/compose/graph_run.go index a3e81ecf1..770cf16de 100644 --- a/compose/graph_run.go +++ b/compose/graph_run.go @@ -434,6 +434,7 @@ type interruptTempInfo struct { interruptBeforeNodes []string interruptAfterNodes []string interruptRerunExtra map[string]any + fromGraphInterrupt bool signals []*core.InterruptSignal } @@ -442,6 +443,7 @@ func (ti *interruptTempInfo) collectCanceledInfo(canceled bool, canceledTasks, c if !canceled { return } + ti.fromGraphInterrupt = true if len(canceledTasks) > 0 { for _, t := range canceledTasks { ti.interruptRerunNodes = append(ti.interruptRerunNodes, t.nodeKey) @@ -459,6 +461,13 @@ func (r *runner) resolveInterruptCompletedTasks(tempInfo *interruptTempInfo, com if info := isSubGraphInterrupt(completedTask.err); info != nil { tempInfo.subGraphInterrupts[completedTask.nodeKey] = info tempInfo.signals = append(tempInfo.signals, info.signal) + // Propagate FromGraphInterrupt from the sub-graph to the parent. + // The sub-graph's task manager may have consumed the cancel + // channel value before the parent's, so only the sub-graph + // knows the interrupt was triggered by a graph-level cancel. 
+ if info.Info != nil && info.Info.FromGraphInterrupt { + tempInfo.fromGraphInterrupt = true + } continue } @@ -515,27 +524,27 @@ func (r *runner) handleInterrupt( if r.runCtx != nil { // current graph has enable state if state, ok := ctx.Value(stateKey{}).(*internalState); ok { - cp.State = state.state + state.mu.Lock() + copiedState, err := deepCopyState(state.state) + state.mu.Unlock() + if err != nil { + return fmt.Errorf("failed to copy state: %w", err) + } + cp.State = copiedState } } intInfo := &InterruptInfo{ - State: cp.State, - AfterNodes: tempInfo.interruptAfterNodes, - BeforeNodes: tempInfo.interruptBeforeNodes, - RerunNodes: tempInfo.interruptRerunNodes, - RerunNodesExtra: tempInfo.interruptRerunExtra, - SubGraphs: make(map[string]*InterruptInfo), + State: cp.State, + AfterNodes: tempInfo.interruptAfterNodes, + BeforeNodes: tempInfo.interruptBeforeNodes, + RerunNodes: tempInfo.interruptRerunNodes, + RerunNodesExtra: tempInfo.interruptRerunExtra, + SubGraphs: make(map[string]*InterruptInfo), + FromGraphInterrupt: tempInfo.fromGraphInterrupt, } - var info any - if cp.State != nil { - copiedState, err := deepCopyState(cp.State) - if err != nil { - return fmt.Errorf("failed to copy state: %w", err) - } - info = copiedState - } + info := cp.State is, err := core.Interrupt(ctx, info, nil, tempInfo.signals) if err != nil { @@ -581,15 +590,18 @@ func deepCopyState(state any) (any, error) { // Create new instance of the same type stateType := reflect.TypeOf(state) - if stateType.Kind() == reflect.Ptr { + isPtr := stateType.Kind() == reflect.Ptr + if isPtr { stateType = stateType.Elem() } - newState := reflect.New(stateType).Interface() - - if err := serializer.Unmarshal(data, newState); err != nil { + newStatePtr := reflect.New(stateType).Interface() + if err := serializer.Unmarshal(data, newStatePtr); err != nil { return nil, fmt.Errorf("failed to unmarshal state: %w", err) } - return newState, nil + if isPtr { + return newStatePtr, nil + } + return 
reflect.ValueOf(newStatePtr).Elem().Interface(), nil } func (r *runner) handleInterruptWithSubGraphAndRerunNodes( @@ -645,27 +657,27 @@ func (r *runner) handleInterruptWithSubGraphAndRerunNodes( if r.runCtx != nil { // current graph has enable state if state, ok := ctx.Value(stateKey{}).(*internalState); ok { - cp.State = state.state + state.mu.Lock() + copiedState, err_ := deepCopyState(state.state) + state.mu.Unlock() + if err_ != nil { + return fmt.Errorf("failed to copy state: %w", err_) + } + cp.State = copiedState } } intInfo := &InterruptInfo{ - State: cp.State, - BeforeNodes: tempInfo.interruptBeforeNodes, - AfterNodes: tempInfo.interruptAfterNodes, - RerunNodes: tempInfo.interruptRerunNodes, - RerunNodesExtra: tempInfo.interruptRerunExtra, - SubGraphs: make(map[string]*InterruptInfo), + State: cp.State, + BeforeNodes: tempInfo.interruptBeforeNodes, + AfterNodes: tempInfo.interruptAfterNodes, + RerunNodes: tempInfo.interruptRerunNodes, + RerunNodesExtra: tempInfo.interruptRerunExtra, + SubGraphs: make(map[string]*InterruptInfo), + FromGraphInterrupt: tempInfo.fromGraphInterrupt, } - var info any - if cp.State != nil { - copiedState, err_ := deepCopyState(cp.State) - if err_ != nil { - return fmt.Errorf("failed to copy state: %w", err_) - } - info = copiedState - } + info := cp.State is, err := core.Interrupt(ctx, info, nil, tempInfo.signals) if err != nil { diff --git a/compose/interrupt.go b/compose/interrupt.go index 98a5eeecc..cd423a1d6 100644 --- a/compose/interrupt.go +++ b/compose/interrupt.go @@ -263,6 +263,10 @@ type InterruptInfo struct { RerunNodesExtra map[string]any SubGraphs map[string]*InterruptInfo InterruptContexts []*InterruptCtx + // FromGraphInterrupt indicates whether the interrupt was triggered by a graph-level + // cancel operation (e.g., via WithGraphInterrupt) rather than business logic. + // When true, the interrupt originated from an external cancellation request. 
+ FromGraphInterrupt bool } func init() { diff --git a/compose/tool_alias_test.go b/compose/tool_alias_test.go new file mode 100644 index 000000000..487132cbe --- /dev/null +++ b/compose/tool_alias_test.go @@ -0,0 +1,1178 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/schema" +) + +type searchArgs struct { + Query string `json:"query"` +} + +func TestToolNameAliases(t *testing.T) { + ctx := context.Background() + + // Create test tool + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search for information", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string", Desc: "Search query"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + return "search result", nil + }) + + // Configure aliases + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"search_v1", "query", "find"}, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Test calling tool with alias + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: 
schema.FunctionCall{ + Name: "search_v1", // Using alias + Arguments: `{"query": "test"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Equal(t, "call_1", output[0].ToolCallID) + assert.Contains(t, output[0].Content, "search result") +} + +type searchArgsWithLimit struct { + Query string `json:"query"` + Limit int `json:"limit"` +} + +func TestArgumentsAliases(t *testing.T) { + ctx := context.Background() + + receivedArgs := "" + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search for information", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + "limit": {Type: "integer"}, + }), + }, func(ctx context.Context, args *searchArgsWithLimit) (string, error) { + b, _ := json.Marshal(args) + receivedArgs = string(b) + return "result", nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + ArgumentsAliases: map[string][]string{ + "query": {"q", "search_term"}, + "limit": {"max_results", "count"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Use alias parameters + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "search", + Arguments: `{"q": "test", "max_results": 10}`, // Using aliases + }, + }, + }) + + _, err = node.Invoke(ctx, input) + require.NoError(t, err) + + // Verify tool received canonical parameter names + var args map[string]any + err = json.Unmarshal([]byte(receivedArgs), &args) + require.NoError(t, err) + assert.Equal(t, "test", args["query"]) + assert.Equal(t, float64(10), args["limit"]) + assert.NotContains(t, args, "q") + assert.NotContains(t, args, "max_results") +} + +type emptyArgs struct{} + +func TestAliasConflict(t *testing.T) { + ctx := context.Background() + + tool1 := 
newTool(&schema.ToolInfo{Name: "search", Desc: "Search"}, func(ctx context.Context, args *emptyArgs) (string, error) { + return "result", nil + }) + tool2 := newTool(&schema.ToolInfo{Name: "query", Desc: "Query"}, func(ctx context.Context, args *emptyArgs) (string, error) { + return "result", nil + }) + + t.Run("tool name alias conflict", func(t *testing.T) { + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1, tool2}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + }, + "query": { + NameAliases: []string{"find"}, // Conflict: find already used by search + }, + }, + } + + _, err := NewToolNode(ctx, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "conflicts with an alias already registered for") + }) + + t.Run("tool name alias conflicts with canonical name", func(t *testing.T) { + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1, tool2}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"query"}, // Conflict: "query" is tool2's canonical name + }, + }, + } + + _, err := NewToolNode(ctx, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "conflicts with existing tool's canonical name") + }) + + t.Run("argument alias conflict", func(t *testing.T) { + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + "limit": {"q"}, // Conflict: q maps to multiple parameters + }, + }, + }, + } + + _, err := NewToolNode(ctx, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "conflicting arg alias") + }) + + t.Run("arg alias conflicts with existing schema property", func(t *testing.T) { + searchWithParams := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + "limit": {Type: "integer"}, + }), + }, 
func(ctx context.Context, args *emptyArgs) (string, error) { + return "result", nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchWithParams}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + ArgumentsAliases: map[string][]string{ + "limit": {"query"}, // "query" is already a schema property + }, + }, + }, + } + + _, err := NewToolNode(ctx, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "conflicts with existing schema property") + }) +} + +func TestArgumentsAliasesWithHandler(t *testing.T) { + ctx := context.Background() + + executionOrder := []string{} + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + executionOrder = append(executionOrder, "tool_invoke") + return "result", nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + }, + }, + }, + ToolArgumentsHandler: func(ctx context.Context, name, args string) (string, error) { + executionOrder = append(executionOrder, "args_handler") + // Handler receives the original model-returned name (alias) + assert.Equal(t, "search", name) + // Verify alias remapping has already been done + var m map[string]any + err := json.Unmarshal([]byte(args), &m) + require.NoError(t, err) + assert.Contains(t, m, "query") + assert.NotContains(t, m, "q") + return args, nil + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Call with alias name "find" and alias arg "q" + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"q": "test"}`, + }, + }, + }) + + _, err = node.Invoke(ctx, input) + require.NoError(t, 
err) + + // Verify execution order: alias remapping → ToolArgumentsHandler → tool execution + assert.Equal(t, []string{"args_handler", "tool_invoke"}, executionOrder) +} + +func TestNonExistentToolInAliasConfig(t *testing.T) { + ctx := context.Background() + + tool1 := newTool(&schema.ToolInfo{Name: "search", Desc: "Search"}, func(ctx context.Context, args *emptyArgs) (string, error) { + return "result", nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{tool1}, + ToolAliases: map[string]ToolAliasConfig{ + "non_existent_tool": { // Non-existent tool + NameAliases: []string{"alias1"}, + }, + }, + } + + // Should not error — non-existent tool alias configs are silently skipped + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // The existing tool should still work normally + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "search", + Arguments: `{}`, + }, + }, + }) + output, err := node.Invoke(ctx, input) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "result") +} + +type weatherArgs struct { + Location string `json:"location"` +} + +func TestToolAliasesE2E(t *testing.T) { + ctx := context.Background() + + // Create multiple tools + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search for information", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + "limit": {Type: "integer"}, + }), + }, func(ctx context.Context, args *searchArgsWithLimit) (string, error) { + return "search result", nil + }) + + weatherTool := newTool(&schema.ToolInfo{ + Name: "weather", + Desc: "Get weather information", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "location": {Type: "string"}, + }), + }, func(ctx context.Context, args *weatherArgs) (string, error) { + return "weather result", nil + }) + + // Configure aliases for 
multiple tools + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool, weatherTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"search_v1", "query"}, + ArgumentsAliases: map[string][]string{ + "query": {"q", "search_term"}, + "limit": {"max_results"}, + }, + }, + "weather": { + NameAliases: []string{"get_weather"}, + ArgumentsAliases: map[string][]string{ + "location": {"loc", "city"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Construct message with multiple tool calls using different aliases + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "search_v1", // Tool name alias + Arguments: `{"q": "test", "max_results": 5}`, // Parameter aliases + }, + }, + { + ID: "call_2", + Function: schema.FunctionCall{ + Name: "get_weather", // Tool name alias + Arguments: `{"city": "Beijing"}`, // Parameter alias + }, + }, + }) + + output, err := node.Invoke(ctx, input) + require.NoError(t, err) + require.Len(t, output, 2) + + // Verify both tools executed successfully + assert.Equal(t, "call_1", output[0].ToolCallID) + assert.Equal(t, "call_2", output[1].ToolCallID) + assert.Contains(t, output[0].Content, "search result") + assert.Contains(t, output[1].Content, "weather result") +} + +func TestRemapArgsEdgeCases(t *testing.T) { + aliasMap := map[string]string{"q": "query"} + + t.Run("empty string", func(t *testing.T) { + result, err := remapArgs("", aliasMap) + assert.NoError(t, err) + assert.Equal(t, "", result) + }) + + t.Run("whitespace only", func(t *testing.T) { + result, err := remapArgs(" ", aliasMap) + assert.NoError(t, err) + assert.Equal(t, " ", result) + }) + + t.Run("non-object JSON", func(t *testing.T) { + result, err := remapArgs(`"hello"`, aliasMap) + assert.NoError(t, err) + assert.Equal(t, `"hello"`, result) + }) + + t.Run("JSON array", func(t *testing.T) { + result, err := remapArgs(`[1,2,3]`, 
aliasMap) + assert.NoError(t, err) + assert.Equal(t, `[1,2,3]`, result) + }) + + t.Run("invalid JSON", func(t *testing.T) { + result, err := remapArgs(`{invalid`, aliasMap) + assert.NoError(t, err) + assert.Equal(t, `{invalid`, result) + }) + + t.Run("alias and canonical both present", func(t *testing.T) { + // When both alias "q" and canonical "query" exist, alias is kept as-is (not deleted, not overwritten) + result, err := remapArgs(`{"q": "alias_val", "query": "canonical_val"}`, aliasMap) + assert.NoError(t, err) + var m map[string]any + require.NoError(t, json.Unmarshal([]byte(result), &m)) + assert.Equal(t, "canonical_val", m["query"]) + assert.Equal(t, "alias_val", m["q"]) + }) + + t.Run("unknown fields preserved", func(t *testing.T) { + result, err := remapArgs(`{"q": "test", "unknown_field": 42}`, aliasMap) + assert.NoError(t, err) + var m map[string]any + require.NoError(t, json.Unmarshal([]byte(result), &m)) + assert.Equal(t, "test", m["query"]) + assert.NotContains(t, m, "q") + assert.Equal(t, float64(42), m["unknown_field"]) + }) +} + +func TestCanonicalNameCallWithAliasConfigured(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + return "result: " + args.Query, nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Call with canonical name and canonical arg — should work normally + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "search", + Arguments: `{"query": 
"hello"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "result: hello") +} + +func TestEmptyAliasValidation(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{Name: "search", Desc: "Search"}, func(ctx context.Context, args *emptyArgs) (string, error) { + return "result", nil + }) + + t.Run("empty name alias", func(t *testing.T) { + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{""}, + }, + }, + } + _, err := NewToolNode(ctx, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "empty name alias") + }) + + t.Run("empty arg alias", func(t *testing.T) { + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + ArgumentsAliases: map[string][]string{ + "query": {""}, + }, + }, + }, + } + _, err := NewToolNode(ctx, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "empty argument alias") + }) + + t.Run("empty canonical arg key", func(t *testing.T) { + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + ArgumentsAliases: map[string][]string{ + "": {"q"}, + }, + }, + }, + } + _, err := NewToolNode(ctx, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "empty canonical argument key") + }) +} + +func TestNameAliasSameAsCanonical(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{Name: "search", Desc: "Search"}, func(ctx context.Context, args *emptyArgs) (string, error) { + return "result", nil + }) + + // Alias same as canonical name — should be tolerated (skip, no error) + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: 
[]string{"search", "find"}, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Both canonical and alias should work + for _, name := range []string{"search", "find"} { + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: name, + Arguments: `{}`, + }, + }, + }) + output, err := node.Invoke(ctx, input) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "result") + } +} + +func TestToolAliasesWithDynamicToolList(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + return "search result: " + args.Query, nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Use dynamic ToolList via option — alias should still work + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"q": "dynamic"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input, WithToolList(searchTool)) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "search result: dynamic") +} + +func TestToolNameAliasesStream(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search for information", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) 
(string, error) { + return "stream result: " + args.Query, nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"q": "hello"}`, + }, + }, + }) + + reader, err := node.Stream(ctx, input) + require.NoError(t, err) + + var chunks [][]*schema.Message + for { + chunk, err := reader.Recv() + if err != nil { + break + } + chunks = append(chunks, chunk) + } + + msgs, err := schema.ConcatMessageArray(chunks) + require.NoError(t, err) + require.Len(t, msgs, 1) + assert.Equal(t, "call_1", msgs[0].ToolCallID) + assert.Contains(t, msgs[0].Content, "stream result: hello") +} + +func TestEnhancedToolWithAliases(t *testing.T) { + ctx := context.Background() + + enhancedTool := &enhancedInvokableTool{ + info: &schema.ToolInfo{ + Name: "search", + Desc: "Enhanced search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, + fn: func(ctx context.Context, input *schema.ToolArgument) (*schema.ToolResult, error) { + return &schema.ToolResult{ + Parts: []schema.ToolOutputPart{ + {Type: schema.ToolPartTypeText, Text: "enhanced: " + input.Text}, + }, + }, nil + }, + } + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{enhancedTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Call with alias name and alias arg + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + 
Name: "find", + Arguments: `{"q": "test"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Equal(t, "call_1", output[0].ToolCallID) + // Verify arg alias was remapped: "q" → "query" in the JSON passed to enhanced tool + assert.Contains(t, output[0].UserInputMultiContent[0].Text, "enhanced:") +} + +func TestDynamicToolListAliasRemoved(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + return "search result", nil + }) + + weatherTool := newTool(&schema.ToolInfo{ + Name: "weather", + Desc: "Weather", + }, func(ctx context.Context, args *emptyArgs) (string, error) { + return "weather result", nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool, weatherTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + // Dynamic tool list only contains weatherTool — "search" and its alias "find" should not be available + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{}`, + }, + }, + }) + + _, err = node.Invoke(ctx, input, WithToolList(weatherTool)) + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +func TestToolAliasesOptionOverridesGlobal(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + return "search result: " + args.Query, nil + }) 
+ + weatherTool := newTool(&schema.ToolInfo{ + Name: "weather", + Desc: "Weather", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "location": {Type: "string"}, + }), + }, func(ctx context.Context, args *weatherArgs) (string, error) { + return "weather result: " + args.Location, nil + }) + + // Global aliases: search has alias "find" + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool, weatherTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + t.Run("opt ToolAliases overrides global in Invoke", func(t *testing.T) { + // opt.ToolAliases defines "lookup" as alias for search (not "find") + optAliases := map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"lookup"}, + ArgumentsAliases: map[string][]string{ + "query": {"keyword"}, + }, + }, + } + + // "lookup" should work with opt aliases + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "lookup", + Arguments: `{"keyword": "test"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input, WithToolList(searchTool), WithToolAliases(optAliases)) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "search result: test") + + // "find" (global alias) should NOT work when opt.ToolAliases is set + input2 := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_2", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"q": "test"}`, + }, + }, + }) + + _, err = node.Invoke(ctx, input2, WithToolList(searchTool), WithToolAliases(optAliases)) + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) + + t.Run("opt ToolAliases overrides global in Stream", func(t *testing.T) { + optAliases := map[string]ToolAliasConfig{ + "search": { + 
NameAliases: []string{"lookup"}, + ArgumentsAliases: map[string][]string{ + "query": {"keyword"}, + }, + }, + } + + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "lookup", + Arguments: `{"keyword": "stream_test"}`, + }, + }, + }) + + reader, err := node.Stream(ctx, input, WithToolList(searchTool), WithToolAliases(optAliases)) + require.NoError(t, err) + + var chunks [][]*schema.Message + for { + chunk, err := reader.Recv() + if err != nil { + break + } + chunks = append(chunks, chunk) + } + + msgs, err := schema.ConcatMessageArray(chunks) + require.NoError(t, err) + require.Len(t, msgs, 1) + assert.Contains(t, msgs[0].Content, "search result: stream_test") + }) + + t.Run("nil opt ToolAliases falls back to global filtered", func(t *testing.T) { + // No WithToolAliases — should use global "find" alias, filtered by ToolList + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"q": "fallback"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input, WithToolList(searchTool)) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "search result: fallback") + }) + + t.Run("opt ToolAliases only without ToolList replaces global", func(t *testing.T) { + // Only WithToolAliases, no WithToolList — should use global tools with opt aliases + optAliases := map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"lookup"}, + ArgumentsAliases: map[string][]string{ + "query": {"keyword"}, + }, + }, + } + + // "lookup" (opt alias) should work + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "lookup", + Arguments: `{"keyword": "only_alias"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input, WithToolAliases(optAliases)) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, 
output[0].Content, "search result: only_alias") + + // "find" (global alias) should NOT work when opt.ToolAliases replaces global + input2 := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_2", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"q": "test"}`, + }, + }, + }) + + _, err = node.Invoke(ctx, input2, WithToolAliases(optAliases)) + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) + + t.Run("opt ToolAliases only without ToolList in Stream", func(t *testing.T) { + optAliases := map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"lookup"}, + }, + } + + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "lookup", + Arguments: `{"query": "stream_only_alias"}`, + }, + }, + }) + + reader, err := node.Stream(ctx, input, WithToolAliases(optAliases)) + require.NoError(t, err) + + var chunks [][]*schema.Message + for { + chunk, err := reader.Recv() + if err != nil { + break + } + chunks = append(chunks, chunk) + } + + msgs, err := schema.ConcatMessageArray(chunks) + require.NoError(t, err) + require.Len(t, msgs, 1) + assert.Contains(t, msgs[0].Content, "search result: stream_only_alias") + }) +} + +func TestAliasConfigForToolAddedViaOption(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + return "search result: " + args.Query, nil + }) + + weatherTool := newTool(&schema.ToolInfo{ + Name: "weather", + Desc: "Weather", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "location": {Type: "string"}, + }), + }, func(ctx context.Context, args *weatherArgs) (string, error) { + return "weather result: " + args.Location, nil + }) + + // New with only 
searchTool, but alias config includes weather tool + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + ArgumentsAliases: map[string][]string{ + "query": {"q"}, + }, + }, + "weather": { + NameAliases: []string{"forecast"}, + ArgumentsAliases: map[string][]string{ + "location": {"loc"}, + }, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + t.Run("weather alias works when tool passed via option", func(t *testing.T) { + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "forecast", + Arguments: `{"loc": "Beijing"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input, WithToolList(searchTool, weatherTool)) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "weather result: Beijing") + }) + + t.Run("search alias still works with option tool list", func(t *testing.T) { + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"q": "test"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input, WithToolList(searchTool, weatherTool)) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "search result: test") + }) +} + +func TestOptionWithToolListAndToolAliases(t *testing.T) { + ctx := context.Background() + + searchTool := newTool(&schema.ToolInfo{ + Name: "search", + Desc: "Search", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "query": {Type: "string"}, + }), + }, func(ctx context.Context, args *searchArgs) (string, error) { + return "search result: " + args.Query, nil + }) + + weatherTool := newTool(&schema.ToolInfo{ + Name: "weather", + Desc: "Weather", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "location": {Type: 
"string"}, + }), + }, func(ctx context.Context, args *weatherArgs) (string, error) { + return "weather result: " + args.Location, nil + }) + + config := &ToolsNodeConfig{ + Tools: []tool.BaseTool{searchTool}, + ToolAliases: map[string]ToolAliasConfig{ + "search": { + NameAliases: []string{"find"}, + }, + }, + } + + node, err := NewToolNode(ctx, config) + require.NoError(t, err) + + t.Run("opt aliases override global when both tool list and aliases provided", func(t *testing.T) { + optAliases := map[string]ToolAliasConfig{ + "weather": { + NameAliases: []string{"forecast"}, + ArgumentsAliases: map[string][]string{ + "location": {"loc"}, + }, + }, + } + + // "forecast" should work via opt aliases + input := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "forecast", + Arguments: `{"loc": "Shanghai"}`, + }, + }, + }) + + output, err := node.Invoke(ctx, input, WithToolList(searchTool, weatherTool), WithToolAliases(optAliases)) + require.NoError(t, err) + require.Len(t, output, 1) + assert.Contains(t, output[0].Content, "weather result: Shanghai") + + // "find" (global alias) should NOT work when opt aliases override + input2 := schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call_2", + Function: schema.FunctionCall{ + Name: "find", + Arguments: `{"query": "test"}`, + }, + }, + }) + + _, err = node.Invoke(ctx, input2, WithToolList(searchTool, weatherTool), WithToolAliases(optAliases)) + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) +} diff --git a/compose/tool_node.go b/compose/tool_node.go index a8f98a866..f65037e90 100644 --- a/compose/tool_node.go +++ b/compose/tool_node.go @@ -18,11 +18,16 @@ package compose import ( "context" + "encoding/json" "errors" "fmt" "runtime/debug" + "sort" + "strings" "sync" + "github.com/bytedance/sonic" + "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/components/tool" @@ -33,6 +38,8 
@@ import ( type toolsNodeOptions struct { ToolOptions []tool.Option ToolList []tool.BaseTool + + ToolAliases map[string]ToolAliasConfig } // ToolsNodeOption is the option func type for ToolsNode. @@ -52,6 +59,15 @@ func WithToolList(tool ...tool.BaseTool) ToolsNodeOption { } } +// WithToolAliases sets the tool aliases for the ToolsNode call option. +// When used with WithToolList, it overrides the global alias configuration for the dynamic tool list. +// When used alone (without WithToolList), it replaces the global alias configuration while keeping the original tool list. +func WithToolAliases(toolAliases map[string]ToolAliasConfig) ToolsNodeOption { + return func(o *toolsNodeOptions) { + o.ToolAliases = toolAliases + } +} + // ToolsNode represents a node capable of executing tools within a graph. // The Graph Node interface is defined as follows: // @@ -62,6 +78,7 @@ func WithToolList(tool ...tool.BaseTool) ToolsNodeOption { // Output: An array of ToolMessage where the order of elements corresponds to the order of ToolCalls in the input type ToolsNode struct { tuple *toolsTuple + tools []tool.BaseTool unknownToolHandler func(ctx context.Context, name, input string) (string, error) executeSequentially bool toolArgumentsHandler func(ctx context.Context, name, input string) (string, error) @@ -69,6 +86,7 @@ type ToolsNode struct { streamToolCallMiddlewares []StreamableToolMiddleware enhancedToolCallMiddlewares []EnhancedInvokableToolMiddleware enhancedStreamToolCallMiddlewares []EnhancedStreamableToolMiddleware + toolAliasConfigs map[string]ToolAliasConfig } // ToolInput represents the input parameters for a tool call execution. @@ -150,11 +168,30 @@ type ToolMiddleware struct { EnhancedStreamable EnhancedStreamableToolMiddleware } +// ToolAliasConfig configures name and argument aliases for a single tool. +type ToolAliasConfig struct { + // NameAliases are alternative names for this tool. 
+ // If the model returns any of these names, it will be resolved to the canonical tool name. + NameAliases []string + + // ArgumentsAliases maps canonical argument keys to their alias lists. + // key=canonical, value=[]alias. Applied to top-level JSON keys before tool execution. + // Example: {"query": ["q", "search_term"], "limit": ["max_results", "count"]} + ArgumentsAliases map[string][]string +} + // ToolsNodeConfig is the config for ToolsNode. type ToolsNodeConfig struct { // Tools specify the list of tools can be called which are BaseTool but must implement InvokableTool or StreamableTool. Tools []tool.BaseTool + // ToolAliases configures name and argument aliases for tools. + // Key is the canonical tool name, value defines its aliases. + // This field is optional. When provided, tool name aliases will be resolved during tool dispatch, + // and argument aliases will be remapped before ToolArgumentsHandler (if configured) and tool execution. + // Execution order: ArgumentsAliases remapping → ToolArgumentsHandler → tool execution + ToolAliases map[string]ToolAliasConfig + // UnknownToolsHandler handles tool calls for non-existent tools when LLM hallucinates. // This field is optional. When not set, calling a non-existent tool will result in an error. 
// When provided, if the LLM attempts to call a tool that doesn't exist in the Tools list, @@ -219,13 +256,22 @@ func NewToolNode(ctx context.Context, conf *ToolsNodeConfig) (*ToolsNode, error) } } - tuple, err := convTools(ctx, conf.Tools, middlewares, streamMiddlewares, enhancedInvokableMiddlewares, enhancedStreamableMiddlewares) + params := convToolsParams{ + tools: conf.Tools, + aliasConfigs: conf.ToolAliases, + } + params.middlewares.invokable = middlewares + params.middlewares.streamable = streamMiddlewares + params.middlewares.enhancedInvokable = enhancedInvokableMiddlewares + params.middlewares.enhancedStreamable = enhancedStreamableMiddlewares + tuple, err := convTools(ctx, params) if err != nil { return nil, err } return &ToolsNode{ tuple: tuple, + tools: conf.Tools, unknownToolHandler: conf.UnknownToolsHandler, executeSequentially: conf.ExecuteSequentially, toolArgumentsHandler: conf.ToolArgumentsHandler, @@ -233,6 +279,7 @@ func NewToolNode(ctx context.Context, conf *ToolsNodeConfig) (*ToolsNode, error) streamToolCallMiddlewares: streamMiddlewares, enhancedToolCallMiddlewares: enhancedInvokableMiddlewares, enhancedStreamToolCallMiddlewares: enhancedStreamableMiddlewares, + toolAliasConfigs: conf.ToolAliases, }, nil } @@ -273,19 +320,184 @@ type toolsTuple struct { streamEndpoints []StreamableToolEndpoint enhancedInvokableEndpoints []EnhancedInvokableToolEndpoint enhancedStreamableEndpoints []EnhancedStreamableToolEndpoint + // argsAliasMap stores reverse argument alias mappings for each tool. + // key: canonical tool name, value: map[aliasKey]canonicalKey (alias → canonical direction) + argsAliasMap map[string]map[string]string + // canonicalNames stores the canonical name for each tool index + canonicalNames []string + // toolInfos stores the ToolInfo for each tool index, used for alias validation + toolInfos []*schema.ToolInfo +} + +// remapArgs replaces alias keys in the JSON arguments string with canonical keys. 
+// aliasMap: alias → canonical mapping +func remapArgs(args string, aliasMap map[string]string) (string, error) { + if len(aliasMap) == 0 { + return args, nil + } + + trimmed := strings.TrimSpace(args) + if trimmed == "" || trimmed[0] != '{' { + return args, nil + } + + var m map[string]json.RawMessage + if err := sonic.Unmarshal([]byte(args), &m); err != nil { + return args, nil + } + + changed := false + for alias, canonical := range aliasMap { + if v, ok := m[alias]; ok { + // Only replace if canonical key doesn't exist. + // If both alias and canonical are present (e.g. {"q":"a","query":"b"}), + // the alias key is kept as-is and passed through as an unknown field. + if _, exists := m[canonical]; !exists { + m[canonical] = v + delete(m, alias) + changed = true + } + } + } + + if !changed { + return args, nil + } + + b, err := sonic.Marshal(m) + return string(b), err +} + +type convToolsParams struct { + tools []tool.BaseTool + middlewares struct { + invokable []InvokableToolMiddleware + streamable []StreamableToolMiddleware + enhancedInvokable []EnhancedInvokableToolMiddleware + enhancedStreamable []EnhancedStreamableToolMiddleware + } + aliasConfigs map[string]ToolAliasConfig +} + +func (t *toolsTuple) applyAliasConfigs(aliasConfigs map[string]ToolAliasConfig) error { + t.argsAliasMap = make(map[string]map[string]string) + + sortedToolNames := make([]string, 0, len(aliasConfigs)) + for toolName := range aliasConfigs { + sortedToolNames = append(sortedToolNames, toolName) + } + sort.Strings(sortedToolNames) + + for _, toolName := range sortedToolNames { + aliasConfig := aliasConfigs[toolName] + var ( + toolIdx int + exists bool + ) + if toolIdx, exists = t.indexes[toolName]; !exists { + continue + } + + if err := t.applyNameAliases(toolName, toolIdx, aliasConfig.NameAliases); err != nil { + return err + } + + if err := t.applyArgsAliases(toolName, toolIdx, aliasConfig.ArgumentsAliases); err != nil { + return err + } + } + + return nil +} + +// applyNameAliases 
validates and registers name aliases for a single tool into the indexes map. +func (t *toolsTuple) applyNameAliases(toolName string, toolIdx int, nameAliases []string) error { + for _, alias := range nameAliases { + if strings.TrimSpace(alias) == "" { + return fmt.Errorf("tool '%s' has empty name alias", toolName) + } + if existingIdx, conflict := t.indexes[alias]; conflict { + if existingIdx != toolIdx { + conflictToolName := t.canonicalNames[existingIdx] + if alias == conflictToolName { + return fmt.Errorf("tool '%s': name alias '%s' conflicts with existing tool's canonical name", toolName, alias) + } + return fmt.Errorf("tool '%s': name alias '%s' conflicts with an alias already registered for tool '%s'", toolName, alias, conflictToolName) + } + continue + } + t.indexes[alias] = toolIdx + } + return nil } -func convTools(ctx context.Context, tools []tool.BaseTool, ms []InvokableToolMiddleware, sms []StreamableToolMiddleware, - ems []EnhancedInvokableToolMiddleware, esms []EnhancedStreamableToolMiddleware) (*toolsTuple, error) { +// applyArgsAliases validates argument aliases against the tool schema and builds a reverse alias map for a single tool. 
+func (t *toolsTuple) applyArgsAliases(toolName string, toolIdx int, argumentsAliases map[string][]string) error { + if len(argumentsAliases) == 0 { + return nil + } + + schemaKeys := make(map[string]bool) + if info := t.toolInfos[toolIdx]; info != nil && info.ParamsOneOf != nil { + js, err := info.ParamsOneOf.ToJSONSchema() + if err != nil { + return fmt.Errorf("tool '%s': failed to parse JSON schema for alias validation: %w", toolName, err) + } + if js != nil && js.Properties != nil { + for pair := js.Properties.Oldest(); pair != nil; pair = pair.Next() { + schemaKeys[pair.Key] = true + } + } + } + + reverseMap := make(map[string]string) + sortedCanonicals := make([]string, 0, len(argumentsAliases)) + for canonical := range argumentsAliases { + sortedCanonicals = append(sortedCanonicals, canonical) + } + sort.Strings(sortedCanonicals) + + for _, canonical := range sortedCanonicals { + aliases := argumentsAliases[canonical] + if strings.TrimSpace(canonical) == "" { + return fmt.Errorf("tool '%s' has empty canonical argument key", toolName) + } + if strings.Contains(canonical, ".") { + return fmt.Errorf("tool '%s' has unsupported '.' 
in canonical argument key '%s': nested field matching is not yet supported", + toolName, canonical) + } + for _, alias := range aliases { + if strings.TrimSpace(alias) == "" { + return fmt.Errorf("tool '%s' has empty argument alias for canonical key '%s'", toolName, canonical) + } + if schemaKeys[alias] { + return fmt.Errorf("tool '%s' has arg alias '%s' that conflicts with existing schema property '%s'", + toolName, alias, alias) + } + if existingCanonical, conflict := reverseMap[alias]; conflict { + return fmt.Errorf("tool '%s' has conflicting arg alias '%s' mapped to both '%s' and '%s'", + toolName, alias, existingCanonical, canonical) + } + reverseMap[alias] = canonical + } + } + t.argsAliasMap[toolName] = reverseMap + + return nil +} + +func convTools(ctx context.Context, params convToolsParams) (*toolsTuple, error) { ret := &toolsTuple{ indexes: make(map[string]int), - meta: make([]*executorMeta, len(tools)), - endpoints: make([]InvokableToolEndpoint, len(tools)), - streamEndpoints: make([]StreamableToolEndpoint, len(tools)), - enhancedInvokableEndpoints: make([]EnhancedInvokableToolEndpoint, len(tools)), - enhancedStreamableEndpoints: make([]EnhancedStreamableToolEndpoint, len(tools)), + meta: make([]*executorMeta, len(params.tools)), + endpoints: make([]InvokableToolEndpoint, len(params.tools)), + streamEndpoints: make([]StreamableToolEndpoint, len(params.tools)), + enhancedInvokableEndpoints: make([]EnhancedInvokableToolEndpoint, len(params.tools)), + enhancedStreamableEndpoints: make([]EnhancedStreamableToolEndpoint, len(params.tools)), + canonicalNames: make([]string, len(params.tools)), + toolInfos: make([]*schema.ToolInfo, len(params.tools)), } - for idx, bt := range tools { + for idx, bt := range params.tools { tl, err := bt.Info(ctx) if err != nil { return nil, fmt.Errorf("(NewToolNode) failed to get tool info at idx= %d: %w", idx, err) @@ -310,19 +522,19 @@ func convTools(ctx context.Context, tools []tool.BaseTool, ms []InvokableToolMid meta = 
parseExecutorInfoFromComponent(components.ComponentOfTool, bt) if st, ok = bt.(tool.StreamableTool); ok { - streamable = wrapStreamToolCall(st, sms, !meta.isComponentCallbackEnabled) + streamable = wrapStreamToolCall(st, params.middlewares.streamable, !meta.isComponentCallbackEnabled) } if it, ok = bt.(tool.InvokableTool); ok { - invokable = wrapToolCall(it, ms, !meta.isComponentCallbackEnabled) + invokable = wrapToolCall(it, params.middlewares.invokable, !meta.isComponentCallbackEnabled) } if eiTool, ok = bt.(tool.EnhancedInvokableTool); ok { - enhancedInvokable = wrapEnhancedInvokableToolCall(eiTool, ems, !meta.isComponentCallbackEnabled) + enhancedInvokable = wrapEnhancedInvokableToolCall(eiTool, params.middlewares.enhancedInvokable, !meta.isComponentCallbackEnabled) } if esTool, ok = bt.(tool.EnhancedStreamableTool); ok { - enhancedStreamable = wrapEnhancedStreamableToolCall(esTool, esms, !meta.isComponentCallbackEnabled) + enhancedStreamable = wrapEnhancedStreamableToolCall(esTool, params.middlewares.enhancedStreamable, !meta.isComponentCallbackEnabled) } if st == nil && it == nil && eiTool == nil && esTool == nil { @@ -348,7 +560,16 @@ func convTools(ctx context.Context, tools []tool.BaseTool, ms []InvokableToolMid ret.streamEndpoints[idx] = streamable ret.enhancedInvokableEndpoints[idx] = enhancedInvokable ret.enhancedStreamableEndpoints[idx] = enhancedStreamable + ret.canonicalNames[idx] = toolName + ret.toolInfos[idx] = tl } + + if len(params.aliasConfigs) > 0 { + if err := ret.applyAliasConfigs(params.aliasConfigs); err != nil { + return nil, err + } + } + return ret, nil } @@ -616,14 +837,27 @@ func (tn *ToolsNode) genToolCallTasks(ctx context.Context, tuple *toolsTuple, toolCallTasks[i].useEnhanced = false } + // Get canonical tool name for looking up argument aliases + canonicalToolName := tuple.canonicalNames[index] + + // Process argument aliases remapping + args := toolCall.Function.Arguments + if aliasMap, hasAliases := 
tuple.argsAliasMap[canonicalToolName]; hasAliases { + remappedArgs, err := remapArgs(args, aliasMap) + if err != nil { + return nil, fmt.Errorf("failed to remap args for tool[name:%s]: %w", canonicalToolName, err) + } + args = remappedArgs + } + if tn.toolArgumentsHandler != nil { - arg, err := tn.toolArgumentsHandler(ctx, toolCall.Function.Name, toolCall.Function.Arguments) + arg, err := tn.toolArgumentsHandler(ctx, canonicalToolName, args) if err != nil { - return nil, fmt.Errorf("failed to executed tool[name:%s arguments:%s] arguments handler: %w", toolCall.Function.Name, toolCall.Function.Arguments, err) + return nil, fmt.Errorf("failed to executed tool[name:%s arguments:%s] arguments handler: %w", toolCall.Function.Name, args, err) } toolCallTasks[i].arg = arg } else { - toolCallTasks[i].arg = toolCall.Function.Arguments + toolCallTasks[i].arg = args } } } @@ -782,6 +1016,31 @@ func parallelRunToolCall(ctx context.Context, wg.Wait() } +// buildTupleFromOpts rebuilds a toolsTuple when call options override tools or aliases. +func (tn *ToolsNode) buildTupleFromOpts(ctx context.Context, opt *toolsNodeOptions) (*toolsTuple, error) { + tools := opt.ToolList + if tools == nil { + tools = tn.tools + } + aliasConfigs := opt.ToolAliases + if aliasConfigs == nil { + aliasConfigs = tn.toolAliasConfigs + } + p := convToolsParams{ + tools: tools, + aliasConfigs: aliasConfigs, + } + p.middlewares.invokable = tn.toolCallMiddlewares + p.middlewares.streamable = tn.streamToolCallMiddlewares + p.middlewares.enhancedInvokable = tn.enhancedToolCallMiddlewares + p.middlewares.enhancedStreamable = tn.enhancedStreamToolCallMiddlewares + tuple, err := convTools(ctx, p) + if err != nil { + return nil, fmt.Errorf("failed to convert tool list from call option: %w", err) + } + return tuple, nil +} + // Invoke calls the tools and collects the results of invokable tools. // it's parallel if there are multiple tool calls in the input message. 
func (tn *ToolsNode) Invoke(ctx context.Context, input *schema.Message, @@ -789,11 +1048,11 @@ func (tn *ToolsNode) Invoke(ctx context.Context, input *schema.Message, opt := getToolsNodeOptions(opts...) tuple := tn.tuple - if opt.ToolList != nil { + if opt.ToolList != nil || opt.ToolAliases != nil { var err error - tuple, err = convTools(ctx, opt.ToolList, tn.toolCallMiddlewares, tn.streamToolCallMiddlewares, tn.enhancedToolCallMiddlewares, tn.enhancedStreamToolCallMiddlewares) + tuple, err = tn.buildTupleFromOpts(ctx, opt) if err != nil { - return nil, fmt.Errorf("failed to convert tool list from call option: %w", err) + return nil, err } } @@ -891,11 +1150,11 @@ func (tn *ToolsNode) Stream(ctx context.Context, input *schema.Message, opt := getToolsNodeOptions(opts...) tuple := tn.tuple - if opt.ToolList != nil { + if opt.ToolList != nil || opt.ToolAliases != nil { var err error - tuple, err = convTools(ctx, opt.ToolList, tn.toolCallMiddlewares, tn.streamToolCallMiddlewares, tn.enhancedToolCallMiddlewares, tn.enhancedStreamToolCallMiddlewares) + tuple, err = tn.buildTupleFromOpts(ctx, opt) if err != nil { - return nil, fmt.Errorf("failed to convert tool list from call option: %w", err) + return nil, err } } diff --git a/compose/types.go b/compose/types.go index 13d925df2..54f8e2be3 100644 --- a/compose/types.go +++ b/compose/types.go @@ -25,13 +25,14 @@ type component = components.Component // built-in component types in graph node. // it represents the type of the most primitive executable object provided by the user. 
const ( - ComponentOfUnknown component = "Unknown" - ComponentOfGraph component = "Graph" - ComponentOfWorkflow component = "Workflow" - ComponentOfChain component = "Chain" - ComponentOfPassthrough component = "Passthrough" - ComponentOfToolsNode component = "ToolsNode" - ComponentOfLambda component = "Lambda" + ComponentOfUnknown component = "Unknown" + ComponentOfGraph component = "Graph" + ComponentOfWorkflow component = "Workflow" + ComponentOfChain component = "Chain" + ComponentOfPassthrough component = "Passthrough" + ComponentOfToolsNode component = "ToolsNode" + ComponentOfAgenticToolsNode component = "AgenticToolsNode" + ComponentOfLambda component = "Lambda" ) // NodeTriggerMode controls the triggering mode of graph nodes. diff --git a/compose/workflow.go b/compose/workflow.go index c3e4331a3..6b50962bb 100644 --- a/compose/workflow.go +++ b/compose/workflow.go @@ -89,18 +89,36 @@ func (wf *Workflow[I, O]) AddChatModelNode(key string, chatModel model.BaseChatM return wf.initNode(key) } +// AddAgenticModelNode adds an agentic model node and returns it. +func (wf *Workflow[I, O]) AddAgenticModelNode(key string, agenticModel model.AgenticModel, opts ...GraphAddNodeOpt) *WorkflowNode { + _ = wf.g.AddAgenticModelNode(key, agenticModel, opts...) + return wf.initNode(key) +} + // AddChatTemplateNode adds a chat template node and returns it. func (wf *Workflow[I, O]) AddChatTemplateNode(key string, chatTemplate prompt.ChatTemplate, opts ...GraphAddNodeOpt) *WorkflowNode { _ = wf.g.AddChatTemplateNode(key, chatTemplate, opts...) return wf.initNode(key) } +// AddAgenticChatTemplateNode adds an agentic chat template node and returns it. +func (wf *Workflow[I, O]) AddAgenticChatTemplateNode(key string, chatTemplate prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) *WorkflowNode { + _ = wf.g.AddAgenticChatTemplateNode(key, chatTemplate, opts...) + return wf.initNode(key) +} + // AddToolsNode adds a tools node and returns it. 
func (wf *Workflow[I, O]) AddToolsNode(key string, tools *ToolsNode, opts ...GraphAddNodeOpt) *WorkflowNode { _ = wf.g.AddToolsNode(key, tools, opts...) return wf.initNode(key) } +// AddAgenticToolsNode adds an agentic tools node and returns it. +func (wf *Workflow[I, O]) AddAgenticToolsNode(key string, tools *AgenticToolsNode, opts ...GraphAddNodeOpt) *WorkflowNode { + _ = wf.g.AddAgenticToolsNode(key, tools, opts...) + return wf.initNode(key) +} + // AddRetrieverNode adds a retriever node and returns it. func (wf *Workflow[I, O]) AddRetrieverNode(key string, retriever retriever.Retriever, opts ...GraphAddNodeOpt) *WorkflowNode { _ = wf.g.AddRetrieverNode(key, retriever, opts...) diff --git a/examples b/examples new file mode 160000 index 000000000..4afd5a3f2 --- /dev/null +++ b/examples @@ -0,0 +1 @@ +Subproject commit 4afd5a3f26a4db4833088505b9f7a0f631e9f231 diff --git a/ext b/ext new file mode 160000 index 000000000..f061db7e8 --- /dev/null +++ b/ext @@ -0,0 +1 @@ +Subproject commit f061db7e84191705db6c48f0085938de84f90742 diff --git a/go.mod b/go.mod index cfa6957cc..0b87a6cab 100644 --- a/go.mod +++ b/go.mod @@ -41,6 +41,7 @@ require ( github.com/yargevad/filepathx v1.0.0 // indirect golang.org/x/arch v0.11.0 // indirect golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect - golang.org/x/sys v0.26.0 // indirect + golang.org/x/sys v0.29.0 // indirect + golang.org/x/term v0.28.0 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect ) diff --git a/go.sum b/go.sum index a80d6399b..5813766b2 100644 --- a/go.sum +++ b/go.sum @@ -117,9 +117,10 @@ golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= +golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= diff --git a/internal/channel.go b/internal/channel.go index 2351c87e9..fa4215359 100644 --- a/internal/channel.go +++ b/internal/channel.go @@ -46,17 +46,33 @@ func (ch *UnboundedChan[T]) Send(value T) { ch.notEmpty.Signal() // Wake up one goroutine waiting to receive } -// Receive gets an item from the channel (blocks if empty) +// TrySend attempts to put an item into the channel. +// Returns false if the channel is closed, true otherwise. +func (ch *UnboundedChan[T]) TrySend(value T) bool { + ch.mutex.Lock() + defer ch.mutex.Unlock() + + if ch.closed { + return false + } + + ch.buffer = append(ch.buffer, value) + ch.notEmpty.Signal() + return true +} + +// Receive gets an item from the channel (blocks if empty). +// Returns (value, true) if an item was received. +// Returns (zero, false) if the channel was closed with no data remaining. 
func (ch *UnboundedChan[T]) Receive() (T, bool) { ch.mutex.Lock() defer ch.mutex.Unlock() for len(ch.buffer) == 0 && !ch.closed { - ch.notEmpty.Wait() // Wait until data is available + ch.notEmpty.Wait() } if len(ch.buffer) == 0 { - // Channel is closed and empty var zero T return zero, false } @@ -73,6 +89,6 @@ func (ch *UnboundedChan[T]) Close() { if !ch.closed { ch.closed = true - ch.notEmpty.Broadcast() // Wake up all waiting goroutines + ch.notEmpty.Broadcast() } } diff --git a/internal/concat.go b/internal/concat.go index 2681322ab..fd9b8abc5 100644 --- a/internal/concat.go +++ b/internal/concat.go @@ -99,7 +99,7 @@ func ConcatItems[T any](items []T) (T, error) { if typ.Kind() == reflect.Map { cv, err = concatMaps(v) } else { - cv, err = concatSliceValue(v) + cv, err = ConcatSliceValue(v) } if err != nil { @@ -158,7 +158,7 @@ func concatMaps(ms reflect.Value) (reflect.Value, error) { if v.Type().Elem().Kind() == reflect.Map { cv, err = concatMaps(v) } else { - cv, err = concatSliceValue(v) + cv, err = ConcatSliceValue(v) } if err != nil { @@ -171,7 +171,7 @@ func concatMaps(ms reflect.Value) (reflect.Value, error) { return ret, nil } -func concatSliceValue(val reflect.Value) (reflect.Value, error) { +func ConcatSliceValue(val reflect.Value) (reflect.Value, error) { elmType := val.Type().Elem() if val.Len() == 1 { diff --git a/internal/core/address.go b/internal/core/address.go index 8efabf943..bb2400a92 100644 --- a/internal/core/address.go +++ b/internal/core/address.go @@ -88,7 +88,7 @@ type addrCtx struct { type globalResumeInfoKey struct{} type globalResumeInfo struct { - mu sync.Mutex + mu sync.RWMutex id2ResumeData map[string]any id2ResumeDataUsed map[string]bool id2State map[string]InterruptState @@ -147,24 +147,21 @@ func AppendAddressSegment(ctx context.Context, segType AddressSegmentType, segID return context.WithValue(ctx, addrCtxKey{}, runCtx) } + rInfo.mu.Lock() + defer rInfo.mu.Unlock() + var id string for id_, addr := range rInfo.id2Addr { if 
addr.Equals(currentAddress) { - rInfo.mu.Lock() if used, ok := rInfo.id2StateUsed[id_]; !ok || !used { runCtx.interruptState = generic.PtrOf(rInfo.id2State[id_]) rInfo.id2StateUsed[id_] = true id = id_ - rInfo.mu.Unlock() break } - rInfo.mu.Unlock() } } - // take from globalResumeInfo the data for the new address if there is any - rInfo.mu.Lock() - defer rInfo.mu.Unlock() used := rInfo.id2ResumeDataUsed[id] if !used { rData, existed := rInfo.id2ResumeData[id] @@ -175,10 +172,6 @@ func AppendAddressSegment(ctx context.Context, segType AddressSegmentType, segID } } - // Also mark as resume target if any descendant address is a resume target. - // This allows composite components (e.g., a tool containing a nested graph) to know - // they should execute their children to reach the actual resume target. - // We only consider descendants whose resume data has not yet been consumed. if !runCtx.isResumeTarget { for id_, addr := range rInfo.id2Addr { if len(addr) > len(currentAddress) && addr[:len(currentAddress)].Equals(currentAddress) { @@ -202,6 +195,9 @@ func GetNextResumptionPoints(ctx context.Context) (map[string]bool, error) { return nil, fmt.Errorf("GetNextResumptionPoints: failed to get resume info from context") } + rInfo.mu.RLock() + defer rInfo.mu.RUnlock() + nextPoints := make(map[string]bool) parentAddrLen := len(parentAddr) @@ -276,13 +272,21 @@ func PopulateInterruptState(ctx context.Context, id2Addr map[string]Address, id2State map[string]InterruptState) context.Context { rInfo, ok := ctx.Value(globalResumeInfoKey{}).(*globalResumeInfo) if ok { + rInfo.mu.Lock() + defer rInfo.mu.Unlock() + if rInfo.id2Addr == nil { rInfo.id2Addr = make(map[string]Address) } for id, addr := range id2Addr { rInfo.id2Addr[id] = addr } - rInfo.id2State = id2State + if rInfo.id2State == nil { + rInfo.id2State = make(map[string]InterruptState) + } + for id, state := range id2State { + rInfo.id2State[id] = state + } } else { rInfo = &globalResumeInfo{ id2Addr: id2Addr, @@ -299,17 
+303,13 @@ func PopulateInterruptState(ctx context.Context, id2Addr map[string]Address, if addr.Equals(runCtx.addr) { if used, ok := rInfo.id2StateUsed[id_]; !ok || !used { runCtx.interruptState = generic.PtrOf(rInfo.id2State[id_]) - rInfo.mu.Lock() rInfo.id2StateUsed[id_] = true - rInfo.mu.Unlock() } if used, ok := rInfo.id2ResumeDataUsed[id_]; !ok || !used { runCtx.isResumeTarget = true runCtx.resumeData = rInfo.id2ResumeData[id_] - rInfo.mu.Lock() rInfo.id2ResumeDataUsed[id_] = true - rInfo.mu.Unlock() } break diff --git a/internal/core/interrupt.go b/internal/core/interrupt.go index d7a934a3d..38ddbdae0 100644 --- a/internal/core/interrupt.go +++ b/internal/core/interrupt.go @@ -29,6 +29,17 @@ type CheckPointStore interface { Set(ctx context.Context, checkPointID string, checkPoint []byte) error } +// CheckPointDeleter is an optional interface that CheckPointStore implementations +// can implement to support explicit checkpoint deletion. +// +// If the Store does not implement this interface, stale checkpoints will NOT be +// automatically cleaned up. The store owner is responsible for managing checkpoint +// lifecycle in that case (e.g., via TTL, external cleanup, or implementing this +// interface). +type CheckPointDeleter interface { + Delete(ctx context.Context, checkPointID string) error +} + type InterruptSignal struct { ID string Address diff --git a/schema/agentic_message.go b/schema/agentic_message.go new file mode 100644 index 000000000..43376c146 --- /dev/null +++ b/schema/agentic_message.go @@ -0,0 +1,2139 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package schema + +import ( + "context" + "encoding/json" + "fmt" + "reflect" + "sort" + "strings" + + "github.com/eino-contrib/jsonschema" + + "github.com/cloudwego/eino/internal" + "github.com/cloudwego/eino/schema/claude" + "github.com/cloudwego/eino/schema/gemini" + "github.com/cloudwego/eino/schema/openai" +) + +type ContentBlockType string + +const ( + ContentBlockTypeReasoning ContentBlockType = "reasoning" + ContentBlockTypeUserInputText ContentBlockType = "user_input_text" + ContentBlockTypeUserInputImage ContentBlockType = "user_input_image" + ContentBlockTypeUserInputAudio ContentBlockType = "user_input_audio" + ContentBlockTypeUserInputVideo ContentBlockType = "user_input_video" + ContentBlockTypeUserInputFile ContentBlockType = "user_input_file" + ContentBlockTypeToolSearchResult ContentBlockType = "tool_search_result" + ContentBlockTypeAssistantGenText ContentBlockType = "assistant_gen_text" + ContentBlockTypeAssistantGenImage ContentBlockType = "assistant_gen_image" + ContentBlockTypeAssistantGenAudio ContentBlockType = "assistant_gen_audio" + ContentBlockTypeAssistantGenVideo ContentBlockType = "assistant_gen_video" + ContentBlockTypeFunctionToolCall ContentBlockType = "function_tool_call" + ContentBlockTypeFunctionToolResult ContentBlockType = "function_tool_result" + ContentBlockTypeServerToolCall ContentBlockType = "server_tool_call" + ContentBlockTypeServerToolResult ContentBlockType = "server_tool_result" + ContentBlockTypeMCPToolCall ContentBlockType = "mcp_tool_call" + ContentBlockTypeMCPToolResult ContentBlockType = 
"mcp_tool_result" + ContentBlockTypeMCPListToolsResult ContentBlockType = "mcp_list_tools_result" + ContentBlockTypeMCPToolApprovalRequest ContentBlockType = "mcp_tool_approval_request" + ContentBlockTypeMCPToolApprovalResponse ContentBlockType = "mcp_tool_approval_response" +) + +type AgenticRoleType string + +const ( + AgenticRoleTypeSystem AgenticRoleType = "system" + AgenticRoleTypeUser AgenticRoleType = "user" + AgenticRoleTypeAssistant AgenticRoleType = "assistant" +) + +type AgenticMessage struct { + // Role is the message role. + Role AgenticRoleType `json:"role"` + + // ContentBlocks is the list of content blocks. + ContentBlocks []*ContentBlock `json:"content_blocks,omitempty"` + + // ResponseMeta is the response metadata. + ResponseMeta *AgenticResponseMeta `json:"response_meta,omitempty"` + + // Extra is the additional information. + Extra map[string]any `json:"extra,omitempty"` +} + +type AgenticResponseMeta struct { + // TokenUsage is the token usage. + TokenUsage *TokenUsage `json:"token_usage,omitempty"` + + // OpenAIExtension is the extension for OpenAI. + OpenAIExtension *openai.ResponseMetaExtension `json:"openai_extension,omitempty"` + + // GeminiExtension is the extension for Gemini. + GeminiExtension *gemini.ResponseMetaExtension `json:"gemini_extension,omitempty"` + + // ClaudeExtension is the extension for Claude. + ClaudeExtension *claude.ResponseMetaExtension `json:"claude_extension,omitempty"` + + // Extension is the extension for other models, supplied by the component implementer. + Extension any `json:"extension,omitempty"` +} + +type ContentBlock struct { + Type ContentBlockType `json:"type"` + + // Reasoning contains the reasoning content generated by the model. + Reasoning *Reasoning `json:"reasoning,omitempty"` + + // UserInputText contains the text content provided by the user. + UserInputText *UserInputText `json:"user_input_text,omitempty"` + + // UserInputImage contains the image content provided by the user. 
+ UserInputImage *UserInputImage `json:"user_input_image,omitempty"` + + // UserInputAudio contains the audio content provided by the user. + UserInputAudio *UserInputAudio `json:"user_input_audio,omitempty"` + + // UserInputVideo contains the video content provided by the user. + UserInputVideo *UserInputVideo `json:"user_input_video,omitempty"` + + // UserInputFile contains the file content provided by the user. + UserInputFile *UserInputFile `json:"user_input_file,omitempty"` + + // AssistantGenText contains the text content generated by the model. + AssistantGenText *AssistantGenText `json:"assistant_gen_text,omitempty"` + + // AssistantGenImage contains the image content generated by the model. + AssistantGenImage *AssistantGenImage `json:"assistant_gen_image,omitempty"` + + // AssistantGenAudio contains the audio content generated by the model. + AssistantGenAudio *AssistantGenAudio `json:"assistant_gen_audio,omitempty"` + + // AssistantGenVideo contains the video content generated by the model. + AssistantGenVideo *AssistantGenVideo `json:"assistant_gen_video,omitempty"` + + // FunctionToolCall contains the invocation details for a user-defined tool. + FunctionToolCall *FunctionToolCall `json:"function_tool_call,omitempty"` + + // FunctionToolResult contains the result returned from a user-defined tool call. + FunctionToolResult *FunctionToolResult `json:"function_tool_result,omitempty"` + + // ToolSearchFunctionToolResult contains the result of a client-side custom tool search tool call. + // It carries the full definitions of newly discovered tools so that the model can + // recognize which tools have been added and are now available for invocation. + ToolSearchFunctionToolResult *ToolSearchFunctionToolResult `json:"tool_search_function_tool_result,omitempty"` + + // ServerToolCall contains the invocation details for a provider built-in tool executed on the model server. 
+ ServerToolCall *ServerToolCall `json:"server_tool_call,omitempty"` + + // ServerToolResult contains the result returned from a provider built-in tool executed on the model server. + ServerToolResult *ServerToolResult `json:"server_tool_result,omitempty"` + + // MCPToolCall contains the invocation details for an MCP tool managed by the model server. + MCPToolCall *MCPToolCall `json:"mcp_tool_call,omitempty"` + + // MCPToolResult contains the result returned from an MCP tool managed by the model server. + MCPToolResult *MCPToolResult `json:"mcp_tool_result,omitempty"` + + // MCPListToolsResult contains the list of available MCP tools reported by the model server. + MCPListToolsResult *MCPListToolsResult `json:"mcp_list_tools_result,omitempty"` + + // MCPToolApprovalRequest contains the user approval request for an MCP tool call when required. + MCPToolApprovalRequest *MCPToolApprovalRequest `json:"mcp_tool_approval_request,omitempty"` + + // MCPToolApprovalResponse contains the user's approval decision for an MCP tool call. + MCPToolApprovalResponse *MCPToolApprovalResponse `json:"mcp_tool_approval_response,omitempty"` + + // StreamingMeta contains metadata for streaming responses. + StreamingMeta *StreamingMeta `json:"streaming_meta,omitempty"` + + // Extra contains additional information for the content block. + Extra map[string]any `json:"extra,omitempty"` +} + +type StreamingMeta struct { + // Index specifies the index position of this block in the final response. + Index int `json:"index"` +} + +type UserInputText struct { + // Text is the text content. + Text string `json:"text,omitempty"` +} + +type UserInputImage struct { + // URL is the HTTP/HTTPS link. + URL string `json:"url,omitempty"` + + // Base64Data is the binary data in Base64 encoded string format. + Base64Data string `json:"base64_data,omitempty"` + + // MIMEType is the mime type, e.g. "image/png". + MIMEType string `json:"mime_type,omitempty"` + + // Detail is the quality of the image url. 
+ Detail ImageURLDetail `json:"detail,omitempty"` +} + +type UserInputAudio struct { + // URL is the HTTP/HTTPS link. + URL string `json:"url,omitempty"` + + // Base64Data is the binary data in Base64 encoded string format. + Base64Data string `json:"base64_data,omitempty"` + + // MIMEType is the mime type, e.g. "audio/wav". + MIMEType string `json:"mime_type,omitempty"` +} + +type UserInputVideo struct { + // URL is the HTTP/HTTPS link. + URL string `json:"url,omitempty"` + + // Base64Data is the binary data in Base64 encoded string format. + Base64Data string `json:"base64_data,omitempty"` + + // MIMEType is the mime type, e.g. "video/mp4". + MIMEType string `json:"mime_type,omitempty"` +} + +type UserInputFile struct { + // URL is the HTTP/HTTPS link. + URL string `json:"url,omitempty"` + + // Name is the filename. + Name string `json:"name,omitempty"` + + // Base64Data is the binary data in Base64 encoded string format. + Base64Data string `json:"base64_data,omitempty"` + + // MIMEType is the mime type, e.g. "application/pdf". + MIMEType string `json:"mime_type,omitempty"` +} + +type AssistantGenText struct { + // Text is the generated text. + Text string `json:"text,omitempty"` + + // OpenAIExtension is the extension for OpenAI. + OpenAIExtension *openai.AssistantGenTextExtension `json:"openai_extension,omitempty"` + + // ClaudeExtension is the extension for Claude. + ClaudeExtension *claude.AssistantGenTextExtension `json:"claude_extension,omitempty"` + + // Extension is the extension for other models, supplied by the component implementer. + Extension any `json:"extension,omitempty"` +} + +type AssistantGenImage struct { + // URL is the HTTP/HTTPS link. + URL string `json:"url,omitempty"` + + // Base64Data is the binary data in Base64 encoded string format. + Base64Data string `json:"base64_data,omitempty"` + + // MIMEType is the mime type, e.g. "image/png". 
+ MIMEType string `json:"mime_type,omitempty"` +} + +type AssistantGenAudio struct { + // URL is the HTTP/HTTPS link. + URL string `json:"url,omitempty"` + + // Base64Data is the binary data in Base64 encoded string format. + Base64Data string `json:"base64_data,omitempty"` + + // MIMEType is the mime type, e.g. "audio/wav". + MIMEType string `json:"mime_type,omitempty"` +} + +type AssistantGenVideo struct { + // URL is the HTTP/HTTPS link. + URL string `json:"url,omitempty"` + + // Base64Data is the binary data in Base64 encoded string format. + Base64Data string `json:"base64_data,omitempty"` + + // MIMEType is the mime type, e.g. "video/mp4". + MIMEType string `json:"mime_type,omitempty"` +} + +type Reasoning struct { + // Text is either the thought summary or the raw reasoning text itself. + Text string `json:"text,omitempty"` + + // Signature contains encrypted reasoning tokens. + // Required by some models when passing reasoning text back. + Signature string `json:"signature,omitempty"` +} + +type FunctionToolCall struct { + // CallID is the unique identifier for the tool call. + CallID string `json:"call_id,omitempty"` + + // Name specifies the function tool invoked. + Name string `json:"name"` + + // Arguments is the JSON string arguments for the function tool call. + Arguments string `json:"arguments,omitempty"` +} + +type FunctionToolResult struct { + // CallID is the unique identifier for the tool call. + CallID string `json:"call_id,omitempty"` + + // Name specifies the function tool invoked. + Name string `json:"name"` + + // Result is the function tool result returned by the user + Result string `json:"result,omitempty"` +} + +// ToolSearchFunctionToolResult represents the result of a client-side custom tool search +// function tool call. 
Unlike a regular FunctionToolResult, this carries a ToolSearchResult +// containing the full definitions of newly discovered tools, so the model can recognize +// which tools have been added and are now available for invocation. +type ToolSearchFunctionToolResult struct { + // CallID is the unique identifier for the tool call. + CallID string `json:"call_id,omitempty"` + + // Name specifies the function tool invoked. + Name string `json:"name"` + + // Result is the function tool result returned by the user + Result *ToolSearchResult `json:"result,omitempty"` +} + +func (t *ToolSearchFunctionToolResult) String() string { + if t.Result != nil { + return t.Result.String() + } + return "" +} + +type ServerToolCall struct { + // Name specifies the server-side tool invoked. + // Supplied by the model server (e.g., `web_search` for OpenAI, `googleSearch` for Gemini). + Name string `json:"name"` + + // CallID is the unique identifier for the tool call. + // Empty if not provided by the model server. + CallID string `json:"call_id,omitempty"` + + // Arguments are the raw inputs to the server-side tool, + // supplied by the component implementer. + Arguments any `json:"arguments,omitempty"` +} + +type ServerToolResult struct { + // Name specifies the server-side tool invoked. + // Supplied by the model server (e.g., `web_search` for OpenAI, `googleSearch` for Gemini). + Name string `json:"name"` + + // CallID is the unique identifier for the tool call. + // Empty if not provided by the model server. + CallID string `json:"call_id,omitempty"` + + // Result refers to the raw output generated by the server-side tool, + // supplied by the component implementer. + Result any `json:"result,omitempty"` +} + +type MCPToolCall struct { + // ServerLabel is the MCP server label used to identify it in tool calls + ServerLabel string `json:"server_label,omitempty"` + + // ApprovalRequestID is the approval request ID. 
+ ApprovalRequestID string `json:"approval_request_id,omitempty"` + + // CallID is the unique ID of the tool call. + CallID string `json:"call_id,omitempty"` + + // Name is the name of the tool to run. + Name string `json:"name"` + + // Arguments is the JSON string arguments for the tool call. + Arguments string `json:"arguments,omitempty"` +} + +type MCPToolResult struct { + // ServerLabel is the MCP server label used to identify it in tool calls + ServerLabel string `json:"server_label,omitempty"` + + // CallID is the unique ID of the tool call. + CallID string `json:"call_id,omitempty"` + + // Name is the name of the tool to run. + Name string `json:"name"` + + // Result is the JSON string with the tool result. + Result string `json:"result,omitempty"` + + // Error returned when the server fails to run the tool. + Error *MCPToolCallError `json:"error,omitempty"` +} + +type MCPToolCallError struct { + // Code is the error code. + Code *int64 `json:"code,omitempty"` + + // Message is the error message. + Message string `json:"message,omitempty"` +} + +type MCPListToolsResult struct { + // ServerLabel is the MCP server label used to identify it in tool calls. + ServerLabel string `json:"server_label,omitempty"` + + // Tools is the list of tools available on the server. + Tools []*MCPListToolsItem `json:"tools,omitempty"` + + // Error returned when the server fails to list tools. + Error string `json:"error,omitempty"` +} + +type MCPListToolsItem struct { + // Name is the name of the tool. + Name string `json:"name"` + + // Description is the description of the tool. + Description string `json:"description"` + + // InputSchema is the JSON schema that describes the tool input parameters. + InputSchema *jsonschema.Schema `json:"input_schema,omitempty"` +} + +type MCPToolApprovalRequest struct { + // ID is the approval request ID. + ID string `json:"id,omitempty"` + + // Name is the name of the tool to run. 
+ Name string `json:"name"` + + // Arguments is the JSON string arguments for the tool call. + Arguments string `json:"arguments,omitempty"` + + // ServerLabel is the MCP server label used to identify it in tool calls. + ServerLabel string `json:"server_label,omitempty"` +} + +type MCPToolApprovalResponse struct { + // ApprovalRequestID is the approval request ID being responded to. + ApprovalRequestID string `json:"approval_request_id,omitempty"` + + // Approve indicates whether the request is approved. + Approve bool `json:"approve"` + + // Reason is the rationale for the decision. + // Optional. + Reason string `json:"reason,omitempty"` +} + +// SystemAgenticMessage represents a message with AgenticRoleType "system". +func SystemAgenticMessage(text string) *AgenticMessage { + return &AgenticMessage{ + Role: AgenticRoleTypeSystem, + ContentBlocks: []*ContentBlock{NewContentBlock(&UserInputText{Text: text})}, + } +} + +// UserAgenticMessage represents a message with AgenticRoleType "user". +func UserAgenticMessage(text string) *AgenticMessage { + return &AgenticMessage{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{NewContentBlock(&UserInputText{Text: text})}, + } +} + +// FunctionToolResultAgenticMessage represents a function tool result message with AgenticRoleType "user". 
+func FunctionToolResultAgenticMessage(callID, name, result string) *AgenticMessage { + return &AgenticMessage{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + NewContentBlock(&FunctionToolResult{ + CallID: callID, + Name: name, + Result: result, + }), + }, + } +} + +type contentBlockVariant interface { + Reasoning | userInputVariant | assistantGenVariant | functionToolCallVariant | serverToolCallVariant | mcpToolCallVariant +} + +type userInputVariant interface { + UserInputText | UserInputImage | UserInputAudio | UserInputVideo | UserInputFile +} + +type assistantGenVariant interface { + AssistantGenText | AssistantGenImage | AssistantGenAudio | AssistantGenVideo +} + +type functionToolCallVariant interface { + FunctionToolCall | FunctionToolResult | ToolSearchFunctionToolResult +} + +type serverToolCallVariant interface { + ServerToolCall | ServerToolResult +} + +type mcpToolCallVariant interface { + MCPToolCall | MCPToolResult | MCPListToolsResult | MCPToolApprovalRequest | MCPToolApprovalResponse +} + +// NewContentBlock creates a new ContentBlock with the given content. 
+func NewContentBlock[T contentBlockVariant](content *T) *ContentBlock { + switch b := any(content).(type) { + case *Reasoning: + return &ContentBlock{Type: ContentBlockTypeReasoning, Reasoning: b} + case *UserInputText: + return &ContentBlock{Type: ContentBlockTypeUserInputText, UserInputText: b} + case *UserInputImage: + return &ContentBlock{Type: ContentBlockTypeUserInputImage, UserInputImage: b} + case *UserInputAudio: + return &ContentBlock{Type: ContentBlockTypeUserInputAudio, UserInputAudio: b} + case *UserInputVideo: + return &ContentBlock{Type: ContentBlockTypeUserInputVideo, UserInputVideo: b} + case *UserInputFile: + return &ContentBlock{Type: ContentBlockTypeUserInputFile, UserInputFile: b} + case *ToolSearchFunctionToolResult: + return &ContentBlock{Type: ContentBlockTypeToolSearchResult, ToolSearchFunctionToolResult: b} + case *AssistantGenText: + return &ContentBlock{Type: ContentBlockTypeAssistantGenText, AssistantGenText: b} + case *AssistantGenImage: + return &ContentBlock{Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: b} + case *AssistantGenAudio: + return &ContentBlock{Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: b} + case *AssistantGenVideo: + return &ContentBlock{Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: b} + case *FunctionToolCall: + return &ContentBlock{Type: ContentBlockTypeFunctionToolCall, FunctionToolCall: b} + case *FunctionToolResult: + return &ContentBlock{Type: ContentBlockTypeFunctionToolResult, FunctionToolResult: b} + case *ServerToolCall: + return &ContentBlock{Type: ContentBlockTypeServerToolCall, ServerToolCall: b} + case *ServerToolResult: + return &ContentBlock{Type: ContentBlockTypeServerToolResult, ServerToolResult: b} + case *MCPToolCall: + return &ContentBlock{Type: ContentBlockTypeMCPToolCall, MCPToolCall: b} + case *MCPToolResult: + return &ContentBlock{Type: ContentBlockTypeMCPToolResult, MCPToolResult: b} + case *MCPListToolsResult: + return &ContentBlock{Type: 
ContentBlockTypeMCPListToolsResult, MCPListToolsResult: b} + case *MCPToolApprovalRequest: + return &ContentBlock{Type: ContentBlockTypeMCPToolApprovalRequest, MCPToolApprovalRequest: b} + case *MCPToolApprovalResponse: + return &ContentBlock{Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: b} + default: + return nil + } +} + +// NewContentBlockChunk creates a new ContentBlock with the given content and streaming metadata. +func NewContentBlockChunk[T contentBlockVariant](content *T, meta *StreamingMeta) *ContentBlock { + block := NewContentBlock(content) + block.StreamingMeta = meta + return block +} + +// AgenticMessagesTemplate is the interface for agentic messages template. +// It's used to render a template to a list of agentic messages. +// e.g. +// +// chatTemplate := prompt.FromAgenticMessages( +// &schema.AgenticMessage{ +// Role: schema.AgenticRoleTypeSystem, +// ContentBlocks: []*schema.ContentBlock{ +// {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "you are an eino helper"}}, +// }, +// }, +// schema.AgenticMessagesPlaceholder("history", false), // <= this will use the value of "history" in params +// ) +// msgs, err := chatTemplate.Format(ctx, params) +type AgenticMessagesTemplate interface { + Format(ctx context.Context, vs map[string]any, formatType FormatType) ([]*AgenticMessage, error) +} + +var _ AgenticMessagesTemplate = &AgenticMessage{} +var _ AgenticMessagesTemplate = AgenticMessagesPlaceholder("", false) + +type agenticMessagesPlaceholder struct { + key string + optional bool +} + +// AgenticMessagesPlaceholder can render a placeholder to a list of agentic messages in params. +// e.g. 
+// +// placeholder := AgenticMessagesPlaceholder("history", false) +// params := map[string]any{ +// "history": []*schema.AgenticMessage{ +// &schema.AgenticMessage{ +// Role: schema.AgenticRoleTypeSystem, +// ContentBlocks: []*schema.ContentBlock{ +// {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "you are an eino helper"}}, +// }, +// }, +// }, +// } +// chatTemplate := prompt.FromAgenticMessages( +// schema.AgenticMessagesPlaceholder("history", false), // <= this will use the value of "history" in params +// ) +// msgs, err := chatTemplate.Format(ctx, params) +func AgenticMessagesPlaceholder(key string, optional bool) AgenticMessagesTemplate { + return &agenticMessagesPlaceholder{ + key: key, + optional: optional, + } +} + +func (p *agenticMessagesPlaceholder) Format(_ context.Context, vs map[string]any, _ FormatType) ([]*AgenticMessage, error) { + v, ok := vs[p.key] + if !ok { + if p.optional { + return []*AgenticMessage{}, nil + } + + return nil, fmt.Errorf("message placeholder format: %s not found", p.key) + } + + msgs, ok := v.([]*AgenticMessage) + if !ok { + return nil, fmt.Errorf("only agentic messages can be used to format message placeholder, key: %v, actual type: %v", p.key, reflect.TypeOf(v)) + } + + return msgs, nil +} + +// Format returns the agentic messages after rendering by the given formatType. +// It formats only the user input fields (UserInputText, UserInputImage, UserInputAudio, UserInputVideo, UserInputFile). +// e.g.
+// +// msg := &schema.AgenticMessage{ +// Role: schema.AgenticRoleTypeUser, +// ContentBlocks: []*schema.ContentBlock{ +// {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "hello {name}"}}, +// }, +// } +// msgs, err := msg.Format(ctx, map[string]any{"name": "eino"}, schema.FString) +// // msgs[0].ContentBlocks[0].UserInputText.Text will be "hello eino" +func (m *AgenticMessage) Format(_ context.Context, vs map[string]any, formatType FormatType) ([]*AgenticMessage, error) { + copied := *m + + if len(m.ContentBlocks) > 0 { + copiedBlocks := make([]*ContentBlock, len(m.ContentBlocks)) + for i, block := range m.ContentBlocks { + if block == nil { + copiedBlocks[i] = nil + continue + } + + copiedBlock := *block + var err error + + switch block.Type { + case ContentBlockTypeUserInputText: + if block.UserInputText != nil { + copiedBlock.UserInputText, err = formatUserInputText(block.UserInputText, vs, formatType) + if err != nil { + return nil, err + } + } + case ContentBlockTypeUserInputImage: + if block.UserInputImage != nil { + copiedBlock.UserInputImage, err = formatUserInputImage(block.UserInputImage, vs, formatType) + if err != nil { + return nil, err + } + } + case ContentBlockTypeUserInputAudio: + if block.UserInputAudio != nil { + copiedBlock.UserInputAudio, err = formatUserInputAudio(block.UserInputAudio, vs, formatType) + if err != nil { + return nil, err + } + } + case ContentBlockTypeUserInputVideo: + if block.UserInputVideo != nil { + copiedBlock.UserInputVideo, err = formatUserInputVideo(block.UserInputVideo, vs, formatType) + if err != nil { + return nil, err + } + } + case ContentBlockTypeUserInputFile: + if block.UserInputFile != nil { + copiedBlock.UserInputFile, err = formatUserInputFile(block.UserInputFile, vs, formatType) + if err != nil { + return nil, err + } + } + } + + copiedBlocks[i] = &copiedBlock + } + copied.ContentBlocks = copiedBlocks + } + + return []*AgenticMessage{&copied}, nil +} + +func 
formatUserInputText(uit *UserInputText, vs map[string]any, formatType FormatType) (*UserInputText, error) { + text, err := formatContent(uit.Text, vs, formatType) + if err != nil { + return nil, err + } + copied := *uit + copied.Text = text + return &copied, nil +} + +func formatUserInputImage(uii *UserInputImage, vs map[string]any, formatType FormatType) (*UserInputImage, error) { + copied := *uii + if uii.URL != "" { + url, err := formatContent(uii.URL, vs, formatType) + if err != nil { + return nil, err + } + copied.URL = url + } + if uii.Base64Data != "" { + base64data, err := formatContent(uii.Base64Data, vs, formatType) + if err != nil { + return nil, err + } + copied.Base64Data = base64data + } + return &copied, nil +} + +func formatUserInputAudio(uia *UserInputAudio, vs map[string]any, formatType FormatType) (*UserInputAudio, error) { + copied := *uia + if uia.URL != "" { + url, err := formatContent(uia.URL, vs, formatType) + if err != nil { + return nil, err + } + copied.URL = url + } + if uia.Base64Data != "" { + base64data, err := formatContent(uia.Base64Data, vs, formatType) + if err != nil { + return nil, err + } + copied.Base64Data = base64data + } + return &copied, nil +} + +func formatUserInputVideo(uiv *UserInputVideo, vs map[string]any, formatType FormatType) (*UserInputVideo, error) { + copied := *uiv + if uiv.URL != "" { + url, err := formatContent(uiv.URL, vs, formatType) + if err != nil { + return nil, err + } + copied.URL = url + } + if uiv.Base64Data != "" { + base64data, err := formatContent(uiv.Base64Data, vs, formatType) + if err != nil { + return nil, err + } + copied.Base64Data = base64data + } + return &copied, nil +} + +func formatUserInputFile(uif *UserInputFile, vs map[string]any, formatType FormatType) (*UserInputFile, error) { + copied := *uif + if uif.URL != "" { + url, err := formatContent(uif.URL, vs, formatType) + if err != nil { + return nil, err + } + copied.URL = url + } + if uif.Name != "" { + name, err := 
formatContent(uif.Name, vs, formatType) + if err != nil { + return nil, err + } + copied.Name = name + } + if uif.Base64Data != "" { + base64data, err := formatContent(uif.Base64Data, vs, formatType) + if err != nil { + return nil, err + } + copied.Base64Data = base64data + } + return &copied, nil +} + +// ConcatAgenticMessagesArray concatenates multiple streams of AgenticMessage into a single slice of AgenticMessage. +func ConcatAgenticMessagesArray(mas [][]*AgenticMessage) ([]*AgenticMessage, error) { + return buildConcatGenericArray[AgenticMessage](ConcatAgenticMessages)(mas) +} + +// ConcatAgenticMessages concatenates a list of AgenticMessage chunks into a single AgenticMessage. +func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { + var ( + role AgenticRoleType + blocks []*ContentBlock + metas []*AgenticResponseMeta + extra map[string]any + blockIndices []int + indexToBlocks = map[int][]*ContentBlock{} + extraList = make([]map[string]any, 0, len(msgs)) + ) + + if len(msgs) == 1 { + return msgs[0], nil + } + + for idx, msg := range msgs { + if msg == nil { + return nil, fmt.Errorf("message at index %d is nil", idx) + } + + if msg.Role != "" { + if role == "" { + role = msg.Role + } else if role != msg.Role { + return nil, fmt.Errorf("cannot concat messages with different roles: got '%s' and '%s'", role, msg.Role) + } + } + + for _, block := range msg.ContentBlocks { + if block == nil { + continue + } + if block.StreamingMeta == nil { + // Non-streaming block + if len(blockIndices) > 0 { + // Cannot mix streaming and non-streaming blocks + return nil, fmt.Errorf("found non-streaming block after streaming blocks") + } + // Collect non-streaming block + blocks = append(blocks, block) + } else { + // Streaming block + if len(blocks) > 0 { + // Cannot mix non-streaming and streaming blocks + return nil, fmt.Errorf("found streaming block after non-streaming blocks") + } + // Collect streaming block by index + if blocks_, ok := 
indexToBlocks[block.StreamingMeta.Index]; ok { + indexToBlocks[block.StreamingMeta.Index] = append(blocks_, block) + } else { + blockIndices = append(blockIndices, block.StreamingMeta.Index) + indexToBlocks[block.StreamingMeta.Index] = []*ContentBlock{block} + } + } + } + + if msg.ResponseMeta != nil { + metas = append(metas, msg.ResponseMeta) + } + + if msg.Extra != nil { + extraList = append(extraList, msg.Extra) + } + } + + meta, err := concatAgenticResponseMeta(metas) + if err != nil { + return nil, fmt.Errorf("failed to concat agentic response meta: %w", err) + } + + if len(blockIndices) > 0 { + // All blocks are streaming, concat each group by index + indexToBlock := map[int]*ContentBlock{} + for idx, bs := range indexToBlocks { + var b *ContentBlock + b, err = concatChunksOfSameContentBlock(bs) + if err != nil { + return nil, err + } + indexToBlock[idx] = b + } + blocks = make([]*ContentBlock, 0, len(blockIndices)) + sort.Slice(blockIndices, func(i, j int) bool { + return blockIndices[i] < blockIndices[j] + }) + for _, idx := range blockIndices { + blocks = append(blocks, indexToBlock[idx]) + } + } + + if len(extraList) > 0 { + extra, err = concatExtra(extraList) + if err != nil { + return nil, err + } + } + + return &AgenticMessage{ + Role: role, + ResponseMeta: meta, + ContentBlocks: blocks, + Extra: extra, + }, nil +} + +func concatAgenticResponseMeta(metas []*AgenticResponseMeta) (ret *AgenticResponseMeta, err error) { + if len(metas) == 0 { + return nil, nil + } + + openaiExtensions := make([]*openai.ResponseMetaExtension, 0, len(metas)) + claudeExtensions := make([]*claude.ResponseMetaExtension, 0, len(metas)) + geminiExtensions := make([]*gemini.ResponseMetaExtension, 0, len(metas)) + tokenUsages := make([]*TokenUsage, 0, len(metas)) + + var ( + extType reflect.Type + extensions reflect.Value + ) + + for _, meta := range metas { + if meta.TokenUsage != nil { + tokenUsages = append(tokenUsages, meta.TokenUsage) + } + + var isConsistent bool + + if 
meta.Extension != nil { + extType, isConsistent = validateExtensionType(extType, meta.Extension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.Extension)) + } + if !extensions.IsValid() { + extensions = reflect.MakeSlice(reflect.SliceOf(extType), 0, len(metas)) + } + extensions = reflect.Append(extensions, reflect.ValueOf(meta.Extension)) + } + + if meta.OpenAIExtension != nil { + extType, isConsistent = validateExtensionType(extType, meta.OpenAIExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.OpenAIExtension)) + } + openaiExtensions = append(openaiExtensions, meta.OpenAIExtension) + } + + if meta.ClaudeExtension != nil { + extType, isConsistent = validateExtensionType(extType, meta.ClaudeExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.ClaudeExtension)) + } + claudeExtensions = append(claudeExtensions, meta.ClaudeExtension) + } + + if meta.GeminiExtension != nil { + extType, isConsistent = validateExtensionType(extType, meta.GeminiExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.GeminiExtension)) + } + geminiExtensions = append(geminiExtensions, meta.GeminiExtension) + } + } + + ret = &AgenticResponseMeta{ + TokenUsage: concatTokenUsage(tokenUsages), + } + + if extensions.IsValid() && !extensions.IsZero() { + var extension reflect.Value + extension, err = internal.ConcatSliceValue(extensions) + if err != nil { + return nil, fmt.Errorf("failed to concat extensions: %w", err) + } + ret.Extension = extension.Interface() + } + + if len(openaiExtensions) > 0 { + ret.OpenAIExtension, err = openai.ConcatResponseMetaExtensions(openaiExtensions) + if 
err != nil { + return nil, fmt.Errorf("failed to concat openai extensions: %w", err) + } + } + + if len(claudeExtensions) > 0 { + ret.ClaudeExtension, err = claude.ConcatResponseMetaExtensions(claudeExtensions) + if err != nil { + return nil, fmt.Errorf("failed to concat claude extensions: %w", err) + } + } + + if len(geminiExtensions) > 0 { + ret.GeminiExtension, err = gemini.ConcatResponseMetaExtensions(geminiExtensions) + if err != nil { + return nil, fmt.Errorf("failed to concat gemini extensions: %w", err) + } + } + + return ret, nil +} + +func concatTokenUsage(usages []*TokenUsage) *TokenUsage { + if len(usages) == 0 { + return nil + } + + ret := &TokenUsage{} + + for _, usage := range usages { + if usage == nil { + continue + } + ret.CompletionTokens += usage.CompletionTokens + ret.CompletionTokensDetails.ReasoningTokens += usage.CompletionTokensDetails.ReasoningTokens + ret.PromptTokens += usage.PromptTokens + ret.PromptTokenDetails.CachedTokens += usage.PromptTokenDetails.CachedTokens + ret.TotalTokens += usage.TotalTokens + } + + return ret +} + +func concatChunksOfSameContentBlock(blocks []*ContentBlock) (*ContentBlock, error) { + if len(blocks) == 0 { + return nil, fmt.Errorf("no content blocks to concat") + } + + blockType := blocks[0].Type + + switch blockType { + case ContentBlockTypeReasoning: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *Reasoning { return b.Reasoning }, + concatReasoning) + + case ContentBlockTypeUserInputText: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *UserInputText { return b.UserInputText }, + concatUserInputTexts) + + case ContentBlockTypeUserInputImage: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *UserInputImage { return b.UserInputImage }, + concatUserInputImages) + + case ContentBlockTypeUserInputAudio: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *UserInputAudio { return b.UserInputAudio }, + 
concatUserInputAudios) + + case ContentBlockTypeUserInputVideo: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *UserInputVideo { return b.UserInputVideo }, + concatUserInputVideos) + + case ContentBlockTypeUserInputFile: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *UserInputFile { return b.UserInputFile }, + concatUserInputFiles) + + case ContentBlockTypeToolSearchResult: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *ToolSearchFunctionToolResult { return b.ToolSearchFunctionToolResult }, + concatToolSearchFunctionToolResult) + + case ContentBlockTypeAssistantGenText: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *AssistantGenText { return b.AssistantGenText }, + concatAssistantGenTexts) + + case ContentBlockTypeAssistantGenImage: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *AssistantGenImage { return b.AssistantGenImage }, + concatAssistantGenImages) + + case ContentBlockTypeAssistantGenAudio: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *AssistantGenAudio { return b.AssistantGenAudio }, + concatAssistantGenAudios) + + case ContentBlockTypeAssistantGenVideo: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *AssistantGenVideo { return b.AssistantGenVideo }, + concatAssistantGenVideos) + + case ContentBlockTypeFunctionToolCall: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *FunctionToolCall { return b.FunctionToolCall }, + concatFunctionToolCalls) + + case ContentBlockTypeFunctionToolResult: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *FunctionToolResult { return b.FunctionToolResult }, + concatFunctionToolResults) + + case ContentBlockTypeServerToolCall: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *ServerToolCall { return b.ServerToolCall }, + 
concatServerToolCalls) + + case ContentBlockTypeServerToolResult: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *ServerToolResult { return b.ServerToolResult }, + concatServerToolResults) + + case ContentBlockTypeMCPToolCall: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *MCPToolCall { return b.MCPToolCall }, + concatMCPToolCalls) + + case ContentBlockTypeMCPToolResult: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *MCPToolResult { return b.MCPToolResult }, + concatMCPToolResults) + + case ContentBlockTypeMCPListToolsResult: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *MCPListToolsResult { return b.MCPListToolsResult }, + concatMCPListToolsResults) + + case ContentBlockTypeMCPToolApprovalRequest: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *MCPToolApprovalRequest { return b.MCPToolApprovalRequest }, + concatMCPToolApprovalRequests) + + case ContentBlockTypeMCPToolApprovalResponse: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *MCPToolApprovalResponse { return b.MCPToolApprovalResponse }, + concatMCPToolApprovalResponses) + + default: + return nil, fmt.Errorf("unknown content block type: %s", blockType) + } +} + +// concatContentBlockHelper is a generic helper function that reduces code duplication +// for concatenating content blocks of a specific type. 
+func concatContentBlockHelper[T contentBlockVariant]( + blocks []*ContentBlock, + expectedType ContentBlockType, + getter func(*ContentBlock) *T, + concatFunc func([]*T) (*T, error), +) (*ContentBlock, error) { + items, err := genericGetTFromContentBlocks(blocks, func(block *ContentBlock) (*T, error) { + if block.Type != expectedType { + return nil, fmt.Errorf("content block type mismatch: expected '%s', but got '%s'", expectedType, block.Type) + } + item := getter(block) + if item == nil { + return nil, fmt.Errorf("'%s' content is nil", expectedType) + } + return item, nil + }) + if err != nil { + return nil, err + } + + concatenated, err := concatFunc(items) + if err != nil { + return nil, fmt.Errorf("failed to concat '%s' content blocks: %w", expectedType, err) + } + + extras := make([]map[string]any, 0, len(blocks)) + for _, block := range blocks { + if len(block.Extra) > 0 { + extras = append(extras, block.Extra) + } + } + + var extra map[string]any + if len(extras) > 0 { + extra, err = internal.ConcatItems(extras) + if err != nil { + return nil, fmt.Errorf("failed to concat content block extras: %w", err) + } + } + + block := NewContentBlock(concatenated) + block.Extra = extra + + return block, nil +} + +func genericGetTFromContentBlocks[T any](blocks []*ContentBlock, checkAndGetter func(block *ContentBlock) (T, error)) ([]T, error) { + ret := make([]T, 0, len(blocks)) + for _, block := range blocks { + t, err := checkAndGetter(block) + if err != nil { + return nil, err + } + ret = append(ret, t) + } + return ret, nil +} + +func concatReasoning(reasons []*Reasoning) (*Reasoning, error) { + if len(reasons) == 0 { + return nil, fmt.Errorf("no reasoning found") + } + + ret := &Reasoning{} + + for _, r := range reasons { + if r.Text != "" { + ret.Text += r.Text + } + if r.Signature != "" { + ret.Signature += r.Signature + } + } + + return ret, nil +} + +func concatUserInputTexts(texts []*UserInputText) (*UserInputText, error) { + if len(texts) == 0 { + return 
nil, fmt.Errorf("no user input text found") + } + if len(texts) == 1 { + return texts[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input texts") +} + +func concatUserInputImages(images []*UserInputImage) (*UserInputImage, error) { + if len(images) == 0 { + return nil, fmt.Errorf("no user input image found") + } + if len(images) == 1 { + return images[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input images") +} + +func concatUserInputAudios(audios []*UserInputAudio) (*UserInputAudio, error) { + if len(audios) == 0 { + return nil, fmt.Errorf("no user input audio found") + } + if len(audios) == 1 { + return audios[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input audios") +} + +func concatUserInputVideos(videos []*UserInputVideo) (*UserInputVideo, error) { + if len(videos) == 0 { + return nil, fmt.Errorf("no user input video found") + } + if len(videos) == 1 { + return videos[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input videos") +} + +func concatUserInputFiles(files []*UserInputFile) (*UserInputFile, error) { + if len(files) == 0 { + return nil, fmt.Errorf("no user input file found") + } + if len(files) == 1 { + return files[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input files") +} + +func concatToolSearchFunctionToolResult(results []*ToolSearchFunctionToolResult) (*ToolSearchFunctionToolResult, error) { + if len(results) == 0 { + return nil, fmt.Errorf("no tool search results found") + } + if len(results) == 1 { + return results[0], nil + } + return nil, fmt.Errorf("cannot concat multiple tool search results") +} + +func concatAssistantGenTexts(texts []*AssistantGenText) (ret *AssistantGenText, err error) { + if len(texts) == 0 { + return nil, fmt.Errorf("no assistant generated text found") + } + if len(texts) == 1 { + return texts[0], nil + } + + ret = &AssistantGenText{} + + openaiExtensions := make([]*openai.AssistantGenTextExtension, 0, 
len(texts)) + claudeExtensions := make([]*claude.AssistantGenTextExtension, 0, len(texts)) + + var ( + extType reflect.Type + extensions reflect.Value + ) + + for _, t := range texts { + if t == nil { + continue + } + + ret.Text += t.Text + + var isConsistent bool + + if t.Extension != nil { + extType, isConsistent = validateExtensionType(extType, t.Extension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in assistant generated text chunks: '%s' vs '%s'", + extType, reflect.TypeOf(t.Extension)) + } + if !extensions.IsValid() { + extensions = reflect.MakeSlice(reflect.SliceOf(extType), 0, len(texts)) + } + extensions = reflect.Append(extensions, reflect.ValueOf(t.Extension)) + } + + if t.OpenAIExtension != nil { + extType, isConsistent = validateExtensionType(extType, t.OpenAIExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in assistant generated text chunks: '%s' vs '%s'", + extType, reflect.TypeOf(t.OpenAIExtension)) + } + openaiExtensions = append(openaiExtensions, t.OpenAIExtension) + } + + if t.ClaudeExtension != nil { + extType, isConsistent = validateExtensionType(extType, t.ClaudeExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in assistant generated text chunks: '%s' vs '%s'", + extType, reflect.TypeOf(t.ClaudeExtension)) + } + claudeExtensions = append(claudeExtensions, t.ClaudeExtension) + } + } + + if extensions.IsValid() && !extensions.IsZero() { + ret.Extension, err = internal.ConcatSliceValue(extensions) + if err != nil { + return nil, err + } + ret.Extension = extensions.Interface() + } + + if len(openaiExtensions) > 0 { + ret.OpenAIExtension, err = openai.ConcatAssistantGenTextExtensions(openaiExtensions) + if err != nil { + return nil, err + } + } + + if len(claudeExtensions) > 0 { + ret.ClaudeExtension, err = claude.ConcatAssistantGenTextExtensions(claudeExtensions) + if err != nil { + return nil, err + } + } + + return ret, nil +} + +func 
concatAssistantGenImages(images []*AssistantGenImage) (*AssistantGenImage, error) { + if len(images) == 0 { + return nil, fmt.Errorf("no assistant gen image found") + } + if len(images) == 1 { + return images[0], nil + } + + ret := &AssistantGenImage{} + + for _, img := range images { + if img == nil { + continue + } + + ret.Base64Data += img.Base64Data + + if ret.URL == "" { + ret.URL = img.URL + } else if img.URL != "" && ret.URL != img.URL { + return nil, fmt.Errorf("inconsistent URLs in assistant generated image chunks: '%s' vs '%s'", ret.URL, img.URL) + } + + if ret.MIMEType == "" { + ret.MIMEType = img.MIMEType + } else if img.MIMEType != "" && ret.MIMEType != img.MIMEType { + return nil, fmt.Errorf("inconsistent MIME types in assistant generated image chunks: '%s' vs '%s'", ret.MIMEType, img.MIMEType) + } + } + + return ret, nil +} + +func concatAssistantGenAudios(audios []*AssistantGenAudio) (*AssistantGenAudio, error) { + if len(audios) == 0 { + return nil, fmt.Errorf("no assistant gen audio found") + } + if len(audios) == 1 { + return audios[0], nil + } + + ret := &AssistantGenAudio{} + + for _, audio := range audios { + if audio == nil { + continue + } + + ret.Base64Data += audio.Base64Data + + if ret.URL == "" { + ret.URL = audio.URL + } else if audio.URL != "" && ret.URL != audio.URL { + return nil, fmt.Errorf("inconsistent URLs in assistant generated audio chunks: '%s' vs '%s'", ret.URL, audio.URL) + } + + if ret.MIMEType == "" { + ret.MIMEType = audio.MIMEType + } else if audio.MIMEType != "" && ret.MIMEType != audio.MIMEType { + return nil, fmt.Errorf("inconsistent MIME types in assistant generated audio chunks: '%s' vs '%s'", ret.MIMEType, audio.MIMEType) + } + } + + return ret, nil +} + +func concatAssistantGenVideos(videos []*AssistantGenVideo) (*AssistantGenVideo, error) { + if len(videos) == 0 { + return nil, fmt.Errorf("no assistant gen video found") + } + if len(videos) == 1 { + return videos[0], nil + } + + ret := &AssistantGenVideo{} + + 
for _, video := range videos { + if video == nil { + continue + } + + ret.Base64Data += video.Base64Data + + if ret.URL == "" { + ret.URL = video.URL + } else if video.URL != "" && ret.URL != video.URL { + return nil, fmt.Errorf("inconsistent URLs in assistant generated video chunks: '%s' vs '%s'", ret.URL, video.URL) + } + + if ret.MIMEType == "" { + ret.MIMEType = video.MIMEType + } else if video.MIMEType != "" && ret.MIMEType != video.MIMEType { + return nil, fmt.Errorf("inconsistent MIME types in assistant generated video chunks: '%s' vs '%s'", ret.MIMEType, video.MIMEType) + } + } + + return ret, nil +} + +func concatFunctionToolCalls(calls []*FunctionToolCall) (*FunctionToolCall, error) { + if len(calls) == 0 { + return nil, fmt.Errorf("no function tool call found") + } + if len(calls) == 1 { + return calls[0], nil + } + + ret := &FunctionToolCall{} + + for _, c := range calls { + if c == nil { + continue + } + + if ret.CallID == "" { + ret.CallID = c.CallID + } else if c.CallID != "" && c.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for function tool call, but got '%s'", ret.CallID, c.CallID) + } + + if ret.Name == "" { + ret.Name = c.Name + } else if c.Name != "" && c.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for function tool call, but got '%s'", ret.Name, c.Name) + } + + ret.Arguments += c.Arguments + } + + return ret, nil +} + +func concatFunctionToolResults(results []*FunctionToolResult) (*FunctionToolResult, error) { + if len(results) == 0 { + return nil, fmt.Errorf("no function tool result found") + } + if len(results) == 1 { + return results[0], nil + } + + ret := &FunctionToolResult{} + + for _, r := range results { + if r == nil { + continue + } + + if ret.CallID == "" { + ret.CallID = r.CallID + } else if r.CallID != "" && r.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for function tool result, but got '%s'", ret.CallID, r.CallID) + } + + if ret.Name == "" { + ret.Name 
= r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for function tool result, but got '%s'", ret.Name, r.Name) + } + + ret.Result += r.Result + } + + return ret, nil +} + +func concatServerToolCalls(calls []*ServerToolCall) (ret *ServerToolCall, err error) { + if len(calls) == 0 { + return nil, fmt.Errorf("no server tool call found") + } + if len(calls) == 1 { + return calls[0], nil + } + + ret = &ServerToolCall{} + + var ( + argsType reflect.Type + argsChunks reflect.Value + ) + + for _, c := range calls { + if c == nil { + continue + } + + if ret.CallID == "" { + ret.CallID = c.CallID + } else if c.CallID != "" && c.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for server tool call, but got '%s'", ret.CallID, c.CallID) + } + + if ret.Name == "" { + ret.Name = c.Name + } else if c.Name != "" && c.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for server tool call, but got '%s'", ret.Name, c.Name) + } + + if c.Arguments != nil { + argsType_ := reflect.TypeOf(c.Arguments) + if argsType == nil { + argsType = argsType_ + argsChunks = reflect.MakeSlice(reflect.SliceOf(argsType), 0, len(calls)) + } else if argsType != argsType_ { + return nil, fmt.Errorf("expected type '%s' for server tool call arguments, but got '%s'", argsType, argsType_) + } + argsChunks = reflect.Append(argsChunks, reflect.ValueOf(c.Arguments)) + } + } + + if argsChunks.IsValid() && !argsChunks.IsZero() { + arguments, err := internal.ConcatSliceValue(argsChunks) + if err != nil { + return nil, err + } + ret.Arguments = arguments.Interface() + } + + return ret, nil +} + +func concatServerToolResults(results []*ServerToolResult) (ret *ServerToolResult, err error) { + if len(results) == 0 { + return nil, fmt.Errorf("no server tool result found") + } + if len(results) == 1 { + return results[0], nil + } + + ret = &ServerToolResult{} + + var ( + resType reflect.Type + resChunks reflect.Value + ) + + for 
_, r := range results { + if r == nil { + continue + } + + if ret.CallID == "" { + ret.CallID = r.CallID + } else if r.CallID != "" && r.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for server tool result, but got '%s'", ret.CallID, r.CallID) + } + + if ret.Name == "" { + ret.Name = r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for server tool result, but got '%s'", ret.Name, r.Name) + } + + if r.Result != nil { + resType_ := reflect.TypeOf(r.Result) + if resType == nil { + resType = resType_ + resChunks = reflect.MakeSlice(reflect.SliceOf(resType), 0, len(results)) + } else if resType != resType_ { + return nil, fmt.Errorf("expected type '%s' for server tool result, but got '%s'", resType, resType_) + } + resChunks = reflect.Append(resChunks, reflect.ValueOf(r.Result)) + } + } + + if resChunks.IsValid() && !resChunks.IsZero() { + result, err := internal.ConcatSliceValue(resChunks) + if err != nil { + return nil, fmt.Errorf("failed to concat server tool result: %v", err) + } + ret.Result = result.Interface() + } + + return ret, nil +} + +func concatMCPToolCalls(calls []*MCPToolCall) (*MCPToolCall, error) { + if len(calls) == 0 { + return nil, fmt.Errorf("no mcp tool call found") + } + if len(calls) == 1 { + return calls[0], nil + } + + ret := &MCPToolCall{} + + for _, c := range calls { + if c == nil { + continue + } + + ret.Arguments += c.Arguments + + if ret.ServerLabel == "" { + ret.ServerLabel = c.ServerLabel + } else if c.ServerLabel != "" && c.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp tool call, but got '%s'", ret.ServerLabel, c.ServerLabel) + } + + if ret.CallID == "" { + ret.CallID = c.CallID + } else if c.CallID != "" && c.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for mcp tool call, but got '%s'", ret.CallID, c.CallID) + } + + if ret.Name == "" { + ret.Name = c.Name + } else if c.Name != "" 
&& c.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for mcp tool call, but got '%s'", ret.Name, c.Name) + } + } + + return ret, nil +} + +func concatMCPToolResults(results []*MCPToolResult) (*MCPToolResult, error) { + if len(results) == 0 { + return nil, fmt.Errorf("no mcp tool result found") + } + if len(results) == 1 { + return results[0], nil + } + + ret := &MCPToolResult{} + + for _, r := range results { + if r == nil { + continue + } + + if r.Result != "" { + ret.Result = r.Result + } + + if ret.ServerLabel == "" { + ret.ServerLabel = r.ServerLabel + } else if r.ServerLabel != "" && r.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp tool result, but got '%s'", ret.ServerLabel, r.ServerLabel) + } + + if ret.CallID == "" { + ret.CallID = r.CallID + } else if r.CallID != "" && r.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for mcp tool result, but got '%s'", ret.CallID, r.CallID) + } + + if ret.Name == "" { + ret.Name = r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for mcp tool result, but got '%s'", ret.Name, r.Name) + } + + if r.Error != nil { + ret.Error = r.Error + } + } + + return ret, nil +} + +func concatMCPListToolsResults(results []*MCPListToolsResult) (*MCPListToolsResult, error) { + if len(results) == 0 { + return nil, fmt.Errorf("no mcp list tools result found") + } + if len(results) == 1 { + return results[0], nil + } + + ret := &MCPListToolsResult{} + + for _, r := range results { + if r == nil { + continue + } + + ret.Tools = append(ret.Tools, r.Tools...) 
+ + if r.Error != "" { + ret.Error = r.Error + } + + if ret.ServerLabel == "" { + ret.ServerLabel = r.ServerLabel + } else if r.ServerLabel != "" && r.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp list tools result, but got '%s'", ret.ServerLabel, r.ServerLabel) + } + } + + return ret, nil +} + +func concatMCPToolApprovalRequests(requests []*MCPToolApprovalRequest) (*MCPToolApprovalRequest, error) { + if len(requests) == 0 { + return nil, fmt.Errorf("no mcp tool approval request found") + } + if len(requests) == 1 { + return requests[0], nil + } + + ret := &MCPToolApprovalRequest{} + + for _, r := range requests { + if r == nil { + continue + } + + ret.Arguments += r.Arguments + + if ret.ID == "" { + ret.ID = r.ID + } else if r.ID != "" && r.ID != ret.ID { + return nil, fmt.Errorf("expected request ID '%s' for mcp tool approval request, but got '%s'", ret.ID, r.ID) + } + + if ret.Name == "" { + ret.Name = r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for mcp tool approval request, but got '%s'", ret.Name, r.Name) + } + + if ret.ServerLabel == "" { + ret.ServerLabel = r.ServerLabel + } else if r.ServerLabel != "" && r.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp tool approval request, but got '%s'", ret.ServerLabel, r.ServerLabel) + } + } + + return ret, nil +} + +func concatMCPToolApprovalResponses(responses []*MCPToolApprovalResponse) (*MCPToolApprovalResponse, error) { + if len(responses) == 0 { + return nil, fmt.Errorf("no mcp tool approval response found") + } + if len(responses) == 1 { + return responses[0], nil + } + return nil, fmt.Errorf("cannot concat multiple mcp tool approval responses") +} + +// String returns the string representation of AgenticMessage. 
+func (m *AgenticMessage) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf("role: %s\n", m.Role)) + + if len(m.ContentBlocks) > 0 { + sb.WriteString("content_blocks:\n") + for i, block := range m.ContentBlocks { + if block == nil { + continue + } + sb.WriteString(fmt.Sprintf(" [%d] %s", i, block.String())) + } + } + + if m.ResponseMeta != nil { + sb.WriteString(m.ResponseMeta.String()) + } + + return sb.String() +} + +// String returns the string representation of ContentBlock. +// nolint +func (b *ContentBlock) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf("type: %s\n", b.Type)) + + switch b.Type { + case ContentBlockTypeReasoning: + if b.Reasoning != nil { + sb.WriteString(b.Reasoning.String()) + } + case ContentBlockTypeUserInputText: + if b.UserInputText != nil { + sb.WriteString(b.UserInputText.String()) + } + case ContentBlockTypeUserInputImage: + if b.UserInputImage != nil { + sb.WriteString(b.UserInputImage.String()) + } + case ContentBlockTypeUserInputAudio: + if b.UserInputAudio != nil { + sb.WriteString(b.UserInputAudio.String()) + } + case ContentBlockTypeUserInputVideo: + if b.UserInputVideo != nil { + sb.WriteString(b.UserInputVideo.String()) + } + case ContentBlockTypeUserInputFile: + if b.UserInputFile != nil { + sb.WriteString(b.UserInputFile.String()) + } + case ContentBlockTypeToolSearchResult: + if b.ToolSearchFunctionToolResult != nil { + sb.WriteString(b.ToolSearchFunctionToolResult.String()) + } + case ContentBlockTypeAssistantGenText: + if b.AssistantGenText != nil { + sb.WriteString(b.AssistantGenText.String()) + } + case ContentBlockTypeAssistantGenImage: + if b.AssistantGenImage != nil { + sb.WriteString(b.AssistantGenImage.String()) + } + case ContentBlockTypeAssistantGenAudio: + if b.AssistantGenAudio != nil { + sb.WriteString(b.AssistantGenAudio.String()) + } + case ContentBlockTypeAssistantGenVideo: + if b.AssistantGenVideo != nil { + sb.WriteString(b.AssistantGenVideo.String()) + } 
+ case ContentBlockTypeFunctionToolCall: + if b.FunctionToolCall != nil { + sb.WriteString(b.FunctionToolCall.String()) + } + case ContentBlockTypeFunctionToolResult: + if b.FunctionToolResult != nil { + sb.WriteString(b.FunctionToolResult.String()) + } + case ContentBlockTypeServerToolCall: + if b.ServerToolCall != nil { + sb.WriteString(b.ServerToolCall.String()) + } + case ContentBlockTypeServerToolResult: + if b.ServerToolResult != nil { + sb.WriteString(b.ServerToolResult.String()) + } + case ContentBlockTypeMCPToolCall: + if b.MCPToolCall != nil { + sb.WriteString(b.MCPToolCall.String()) + } + case ContentBlockTypeMCPToolResult: + if b.MCPToolResult != nil { + sb.WriteString(b.MCPToolResult.String()) + } + case ContentBlockTypeMCPListToolsResult: + if b.MCPListToolsResult != nil { + sb.WriteString(b.MCPListToolsResult.String()) + } + case ContentBlockTypeMCPToolApprovalRequest: + if b.MCPToolApprovalRequest != nil { + sb.WriteString(b.MCPToolApprovalRequest.String()) + } + case ContentBlockTypeMCPToolApprovalResponse: + if b.MCPToolApprovalResponse != nil { + sb.WriteString(b.MCPToolApprovalResponse.String()) + } + } + + if b.StreamingMeta != nil { + sb.WriteString(fmt.Sprintf(" stream_index: %d\n", b.StreamingMeta.Index)) + } + + return sb.String() +} + +// String returns the string representation of Reasoning. +func (r *Reasoning) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" text: %s\n", r.Text)) + if r.Signature != "" { + sb.WriteString(fmt.Sprintf(" signature: %s\n", truncateString(r.Signature, 50))) + } + return sb.String() +} + +// String returns the string representation of UserInputText. +func (u *UserInputText) String() string { + return fmt.Sprintf(" text: %s\n", u.Text) +} + +// String returns the string representation of UserInputImage. 
+func (u *UserInputImage) String() string { + return formatMediaString(u.URL, u.Base64Data, u.MIMEType, string(u.Detail)) +} + +// String returns the string representation of UserInputAudio. +func (u *UserInputAudio) String() string { + return formatMediaString(u.URL, u.Base64Data, u.MIMEType, "") +} + +// String returns the string representation of UserInputVideo. +func (u *UserInputVideo) String() string { + return formatMediaString(u.URL, u.Base64Data, u.MIMEType, "") +} + +// String returns the string representation of UserInputFile. +func (u *UserInputFile) String() string { + sb := &strings.Builder{} + if u.Name != "" { + sb.WriteString(fmt.Sprintf(" name: %s\n", u.Name)) + } + sb.WriteString(formatMediaString(u.URL, u.Base64Data, u.MIMEType, "")) + return sb.String() +} + +// String returns the string representation of AssistantGenText. +func (a *AssistantGenText) String() string { + return fmt.Sprintf(" text: %s\n", a.Text) +} + +// String returns the string representation of AssistantGenImage. +func (a *AssistantGenImage) String() string { + return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") +} + +// String returns the string representation of AssistantGenAudio. +func (a *AssistantGenAudio) String() string { + return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") +} + +// String returns the string representation of AssistantGenVideo. +func (a *AssistantGenVideo) String() string { + return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") +} + +// String returns the string representation of FunctionToolCall. +func (f *FunctionToolCall) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" call_id: %s\n", f.CallID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", f.Name)) + sb.WriteString(fmt.Sprintf(" arguments: %s\n", f.Arguments)) + return sb.String() +} + +// String returns the string representation of FunctionToolResult. 
+func (f *FunctionToolResult) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" call_id: %s\n", f.CallID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", f.Name)) + sb.WriteString(fmt.Sprintf(" result: %s\n", f.Result)) + return sb.String() +} + +// String returns the string representation of ServerToolCall. +func (s *ServerToolCall) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" name: %s\n", s.Name)) + if s.CallID != "" { + sb.WriteString(fmt.Sprintf(" call_id: %s\n", s.CallID)) + } + sb.WriteString(fmt.Sprintf(" arguments: %s\n", printAny(s.Arguments))) + return sb.String() +} + +// String returns the string representation of ServerToolResult. +func (s *ServerToolResult) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" name: %s\n", s.Name)) + if s.CallID != "" { + sb.WriteString(fmt.Sprintf(" call_id: %s\n", s.CallID)) + } + sb.WriteString(fmt.Sprintf(" result: %s\n", printAny(s.Result))) + return sb.String() +} + +// String returns the string representation of MCPToolCall. +func (m *MCPToolCall) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) + sb.WriteString(fmt.Sprintf(" call_id: %s\n", m.CallID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", m.Name)) + sb.WriteString(fmt.Sprintf(" arguments: %s\n", m.Arguments)) + return sb.String() +} + +// String returns the string representation of MCPToolResult. +func (m *MCPToolResult) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" call_id: %s\n", m.CallID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", m.Name)) + sb.WriteString(fmt.Sprintf(" result: %s\n", m.Result)) + if m.Error != nil { + sb.WriteString(fmt.Sprintf(" error: [%d] %s\n", *m.Error.Code, m.Error.Message)) + } + return sb.String() +} + +// String returns the string representation of MCPListToolsResult. 
+func (m *MCPListToolsResult) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) + sb.WriteString(fmt.Sprintf(" tools: %d items\n", len(m.Tools))) + for _, tool := range m.Tools { + sb.WriteString(fmt.Sprintf(" - %s: %s\n", tool.Name, tool.Description)) + } + if m.Error != "" { + sb.WriteString(fmt.Sprintf(" error: %s\n", m.Error)) + } + return sb.String() +} + +// String returns the string representation of MCPToolApprovalRequest. +func (m *MCPToolApprovalRequest) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) + sb.WriteString(fmt.Sprintf(" id: %s\n", m.ID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", m.Name)) + sb.WriteString(fmt.Sprintf(" arguments: %s\n", m.Arguments)) + return sb.String() +} + +// String returns the string representation of MCPToolApprovalResponse. +func (m *MCPToolApprovalResponse) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" approval_request_id: %s\n", m.ApprovalRequestID)) + sb.WriteString(fmt.Sprintf(" approve: %v\n", m.Approve)) + if m.Reason != "" { + sb.WriteString(fmt.Sprintf(" reason: %s\n", m.Reason)) + } + return sb.String() +} + +// String returns the string representation of AgenticResponseMeta. +func (a *AgenticResponseMeta) String() string { + sb := &strings.Builder{} + sb.WriteString("response_meta:\n") + if a.TokenUsage != nil { + sb.WriteString(fmt.Sprintf(" token_usage: prompt=%d, completion=%d, total=%d\n", + a.TokenUsage.PromptTokens, + a.TokenUsage.CompletionTokens, + a.TokenUsage.TotalTokens)) + } + return sb.String() +} + +// truncateString truncates a string to maxLen characters, adding "..." if truncated +func truncateString(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." 
+} + +// formatMediaString formats URL, Base64Data, MIMEType and Detail for media content +func formatMediaString(url, base64Data string, mimeType string, detail string) string { + sb := &strings.Builder{} + if url != "" { + sb.WriteString(fmt.Sprintf(" url: %s\n", truncateString(url, 100))) + } + if base64Data != "" { + // Only show first few characters of base64 data + sb.WriteString(fmt.Sprintf(" base64_data: %s... (%d bytes)\n", truncateString(base64Data, 20), len(base64Data))) + } + if mimeType != "" { + sb.WriteString(fmt.Sprintf(" mime_type: %s\n", mimeType)) + } + if detail != "" { + sb.WriteString(fmt.Sprintf(" detail: %s\n", detail)) + } + return sb.String() +} + +func validateExtensionType(expected reflect.Type, actual any) (reflect.Type, bool) { + if actual == nil { + return expected, true + } + actualType := reflect.TypeOf(actual) + if expected == nil { + return actualType, true + } + if expected != actualType { + return expected, false + } + return expected, true +} + +func printAny(a any) string { + switch v := a.(type) { + case string: + return v + case fmt.Stringer: + return v.String() + default: + b, err := json.MarshalIndent(a, "", " ") + if err != nil { + return fmt.Sprintf("%v", a) + } + return string(b) + } +} diff --git a/schema/agentic_message_test.go b/schema/agentic_message_test.go new file mode 100644 index 000000000..10639f738 --- /dev/null +++ b/schema/agentic_message_test.go @@ -0,0 +1,1641 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package schema + +import ( + "context" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConcatAgenticMessages(t *testing.T) { + t.Run("single message", func(t *testing.T) { + msg := &AgenticMessage{ + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Hello", + }, + }, + }, + } + + result, err := ConcatAgenticMessages([]*AgenticMessage{msg}) + assert.NoError(t, err) + assert.Equal(t, msg, result) + }) + + t.Run("nil message in stream", func(t *testing.T) { + msgs := []*AgenticMessage{ + {Role: AgenticRoleTypeAssistant}, + nil, + {Role: AgenticRoleTypeAssistant}, + } + + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "message at index 1 is nil") + }) + + t.Run("different roles", func(t *testing.T) { + msgs := []*AgenticMessage{ + {Role: AgenticRoleTypeUser}, + {Role: AgenticRoleTypeAssistant}, + } + + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat messages with different roles") + }) + + t.Run("concat text blocks", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Hello ", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "World!", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Equal(t, AgenticRoleTypeAssistant, result.Role) + assert.Len(t, result.ContentBlocks, 1) + 
assert.Equal(t, "Hello World!", result.ContentBlocks[0].AssistantGenText.Text) + }) + + t.Run("concat reasoning with nil index", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Text: "First ", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Text: "Second", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "First Second", result.ContentBlocks[0].Reasoning.Text) + }) + + t.Run("concat reasoning with index", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Text: "Part1-", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Text: "Part3", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "Part1-Part3", result.ContentBlocks[0].Reasoning.Text) + }) + + t.Run("concat user input text", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Hello ", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: 
&AssistantGenText{ + Text: "World!", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "Hello World!", result.ContentBlocks[0].AssistantGenText.Text) + }) + + t.Run("concat assistant gen image", func(t *testing.T) { + base1 := "1" + base2 := "2" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + Base64Data: base1, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + Base64Data: base2, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "12", result.ContentBlocks[0].AssistantGenImage.Base64Data) + }) + + t.Run("concat user input audio - should error", func(t *testing.T) { + url1 := "https://example.com/audio1.mp3" + url2 := "https://example.com/audio2.mp3" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: url1, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: url2, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat multiple user input audios") + }) + + t.Run("concat user input video - should error", func(t *testing.T) { + url1 := 
"https://example.com/video1.mp4" + url2 := "https://example.com/video2.mp4" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: url1, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: url2, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat multiple user input videos") + }) + + t.Run("concat assistant gen text", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Generated ", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Text", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "Generated Text", result.ContentBlocks[0].AssistantGenText.Text) + }) + + t.Run("concat assistant gen image", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + Base64Data: "part1", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + Base64Data: "part2", + }, + 
StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "part1part2", result.ContentBlocks[0].AssistantGenImage.Base64Data) + }) + + t.Run("concat assistant gen audio", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenAudio, + AssistantGenAudio: &AssistantGenAudio{ + Base64Data: "audio1", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenAudio, + AssistantGenAudio: &AssistantGenAudio{ + Base64Data: "audio2", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "audio1audio2", result.ContentBlocks[0].AssistantGenAudio.Base64Data) + }) + + t.Run("concat assistant gen video", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenVideo, + AssistantGenVideo: &AssistantGenVideo{ + Base64Data: "video1", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenVideo, + AssistantGenVideo: &AssistantGenVideo{ + Base64Data: "video2", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "video1video2", result.ContentBlocks[0].AssistantGenVideo.Base64Data) + }) + + t.Run("concat function tool call", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: 
[]*ContentBlock{ + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + CallID: "call_123", + Name: "get_weather", + Arguments: `{"location`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + Arguments: `":"NYC"}`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "call_123", result.ContentBlocks[0].FunctionToolCall.CallID) + assert.Equal(t, "get_weather", result.ContentBlocks[0].FunctionToolCall.Name) + assert.Equal(t, `{"location":"NYC"}`, result.ContentBlocks[0].FunctionToolCall.Arguments) + }) + + t.Run("concat function tool result", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeFunctionToolResult, + FunctionToolResult: &FunctionToolResult{ + CallID: "call_123", + Name: "get_weather", + Result: `{"temp`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeFunctionToolResult, + FunctionToolResult: &FunctionToolResult{ + Result: `":72}`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "call_123", result.ContentBlocks[0].FunctionToolResult.CallID) + assert.Equal(t, "get_weather", result.ContentBlocks[0].FunctionToolResult.Name) + assert.Equal(t, `{"temp":72}`, result.ContentBlocks[0].FunctionToolResult.Result) + }) + + t.Run("concat server tool call", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: 
[]*ContentBlock{ + { + Type: ContentBlockTypeServerToolCall, + ServerToolCall: &ServerToolCall{ + CallID: "server_call_1", + Name: "server_func", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeServerToolCall, + ServerToolCall: &ServerToolCall{ + Arguments: map[string]any{"key": "value"}, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "server_call_1", result.ContentBlocks[0].ServerToolCall.CallID) + assert.Equal(t, "server_func", result.ContentBlocks[0].ServerToolCall.Name) + assert.NotNil(t, result.ContentBlocks[0].ServerToolCall.Arguments) + }) + + t.Run("concat server tool result", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeServerToolResult, + ServerToolResult: &ServerToolResult{ + CallID: "server_call_1", + Name: "server_func", + Result: "result1", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeServerToolResult, + ServerToolResult: &ServerToolResult{}, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "server_call_1", result.ContentBlocks[0].ServerToolResult.CallID) + assert.Equal(t, "server_func", result.ContentBlocks[0].ServerToolResult.Name) + assert.Equal(t, "result1", result.ContentBlocks[0].ServerToolResult.Result) + }) + + t.Run("concat mcp tool call", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: 
&MCPToolCall{ + ServerLabel: "mcp-server", + CallID: "mcp_call_1", + Name: "mcp_func", + Arguments: `{"arg`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + Arguments: `":123}`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPToolCall.ServerLabel) + assert.Equal(t, "mcp_call_1", result.ContentBlocks[0].MCPToolCall.CallID) + assert.Equal(t, "mcp_func", result.ContentBlocks[0].MCPToolCall.Name) + assert.Equal(t, `{"arg":123}`, result.ContentBlocks[0].MCPToolCall.Arguments) + }) + + t.Run("concat mcp tool result", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolResult, + MCPToolResult: &MCPToolResult{ + ServerLabel: "mcp-server", + CallID: "mcp_call_1", + Name: "mcp_func", + Result: `First`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolResult, + MCPToolResult: &MCPToolResult{ + Result: `Second`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPToolResult.ServerLabel) + assert.Equal(t, "mcp_call_1", result.ContentBlocks[0].MCPToolResult.CallID) + assert.Equal(t, "mcp_func", result.ContentBlocks[0].MCPToolResult.Name) + assert.Equal(t, `Second`, result.ContentBlocks[0].MCPToolResult.Result) + }) + + t.Run("concat mcp list tools", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: 
AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPListToolsResult, + MCPListToolsResult: &MCPListToolsResult{ + ServerLabel: "mcp-server", + Tools: []*MCPListToolsItem{ + {Name: "tool1"}, + }, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPListToolsResult, + MCPListToolsResult: &MCPListToolsResult{ + Tools: []*MCPListToolsItem{ + {Name: "tool2"}, + }, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPListToolsResult.ServerLabel) + assert.Len(t, result.ContentBlocks[0].MCPListToolsResult.Tools, 2) + }) + + t.Run("concat mcp tool approval request", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolApprovalRequest, + MCPToolApprovalRequest: &MCPToolApprovalRequest{ + ID: "approval_1", + Name: "approval_func", + ServerLabel: "mcp-server", + Arguments: `{"request`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolApprovalRequest, + MCPToolApprovalRequest: &MCPToolApprovalRequest{ + Arguments: `":1}`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "approval_1", result.ContentBlocks[0].MCPToolApprovalRequest.ID) + assert.Equal(t, "approval_func", result.ContentBlocks[0].MCPToolApprovalRequest.Name) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPToolApprovalRequest.ServerLabel) + assert.Equal(t, `{"request":1}`, 
result.ContentBlocks[0].MCPToolApprovalRequest.Arguments) + }) + + t.Run("concat mcp tool approval response - should error", func(t *testing.T) { + response1 := &MCPToolApprovalResponse{ + ApprovalRequestID: "approval_1", + Approve: false, + } + response2 := &MCPToolApprovalResponse{ + ApprovalRequestID: "approval_1", + Approve: true, + } + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolApprovalResponse, + MCPToolApprovalResponse: response1, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolApprovalResponse, + MCPToolApprovalResponse: response2, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat multiple mcp tool approval responses") + }) + + t.Run("concat response meta", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ResponseMeta: &AgenticResponseMeta{ + TokenUsage: &TokenUsage{ + PromptTokens: 10, + CompletionTokens: 5, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ResponseMeta: &AgenticResponseMeta{ + TokenUsage: &TokenUsage{ + PromptTokens: 10, + CompletionTokens: 15, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.NotNil(t, result.ResponseMeta) + assert.Equal(t, 20, result.ResponseMeta.TokenUsage.CompletionTokens) + assert.Equal(t, 20, result.ResponseMeta.TokenUsage.PromptTokens) + }) + + t.Run("mixed streaming and non-streaming blocks error", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Hello", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: 
AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "World", + }, + // No StreamingMeta - non-streaming + }, + }, + }, + } + + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "found non-streaming block after streaming blocks") + }) + + t.Run("concat MCP tool call", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + ServerLabel: "mcp-server", + CallID: "call_456", + Name: "list_files", + Arguments: `{"path`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + Arguments: `":"/tmp"}`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPToolCall.ServerLabel) + assert.Equal(t, "call_456", result.ContentBlocks[0].MCPToolCall.CallID) + assert.Equal(t, `{"path":"/tmp"}`, result.ContentBlocks[0].MCPToolCall.Arguments) + }) + + t.Run("concat user input text - should error", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{ + Text: "What is ", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{ + Text: "the weather?", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + 
assert.ErrorContains(t, err, "cannot concat multiple user input texts") + }) + + t.Run("multiple stream indexes - sparse indexes", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Index0-", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Index2-", + }, + StreamingMeta: &StreamingMeta{Index: 2}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Part2", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Part2", + }, + StreamingMeta: &StreamingMeta{Index: 2}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 2) + assert.Equal(t, "Index0-Part2", result.ContentBlocks[0].AssistantGenText.Text) + assert.Equal(t, "Index2-Part2", result.ContentBlocks[1].AssistantGenText.Text) + }) + + t.Run("multiple stream indexes - mixed content types", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Text ", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + CallID: "call_1", + Name: "func1", + Arguments: `{"a`, + }, + StreamingMeta: &StreamingMeta{Index: 1}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Content", + }, + StreamingMeta: 
&StreamingMeta{Index: 0}, + }, + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + Arguments: `":1}`, + }, + StreamingMeta: &StreamingMeta{Index: 1}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 2) + assert.Equal(t, "Text Content", result.ContentBlocks[0].AssistantGenText.Text) + assert.Equal(t, "call_1", result.ContentBlocks[1].FunctionToolCall.CallID) + assert.Equal(t, "func1", result.ContentBlocks[1].FunctionToolCall.Name) + assert.Equal(t, `{"a":1}`, result.ContentBlocks[1].FunctionToolCall.Arguments) + }) + + t.Run("multiple stream indexes - three indexes", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "A", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "B", + }, + StreamingMeta: &StreamingMeta{Index: 1}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "C", + }, + StreamingMeta: &StreamingMeta{Index: 2}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "1", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "2", + }, + StreamingMeta: &StreamingMeta{Index: 1}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "3", + }, + StreamingMeta: &StreamingMeta{Index: 2}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 3) + assert.Equal(t, "A1", result.ContentBlocks[0].AssistantGenText.Text) + 
assert.Equal(t, "B2", result.ContentBlocks[1].AssistantGenText.Text) + assert.Equal(t, "C3", result.ContentBlocks[2].AssistantGenText.Text) + }) +} + +func TestAgenticMessageFormat(t *testing.T) { + m := &AgenticMessage{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{Text: "{a}"}, + }, + { + Type: ContentBlockTypeUserInputImage, + UserInputImage: &UserInputImage{ + URL: "{b}", + Base64Data: "{c}", + }, + }, + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: "{d}", + Base64Data: "{e}", + }, + }, + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: "{f}", + Base64Data: "{g}", + }, + }, + { + Type: ContentBlockTypeUserInputFile, + UserInputFile: &UserInputFile{ + URL: "{h}", + Base64Data: "{i}", + }, + }, + }, + } + + result, err := m.Format(context.Background(), map[string]any{ + "a": "1", "b": "2", "c": "3", "d": "4", "e": "5", "f": "6", "g": "7", "h": "8", "i": "9", + }, FString) + assert.NoError(t, err) + assert.Equal(t, []*AgenticMessage{{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{Text: "1"}, + }, + { + Type: ContentBlockTypeUserInputImage, + UserInputImage: &UserInputImage{ + URL: "2", + Base64Data: "3", + }, + }, + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: "4", + Base64Data: "5", + }, + }, + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: "6", + Base64Data: "7", + }, + }, + { + Type: ContentBlockTypeUserInputFile, + UserInputFile: &UserInputFile{ + URL: "8", + Base64Data: "9", + }, + }, + }, + }}, result) +} + +func TestAgenticPlaceholderFormat(t *testing.T) { + ctx := context.Background() + ph := AgenticMessagesPlaceholder("a", false) + + result, err := ph.Format(ctx, map[string]any{ + "a": []*AgenticMessage{{Role: 
// ptrOf returns a pointer to a fresh copy of v, which makes it convenient to
// fill optional pointer fields from literals.
func ptrOf[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}
"audio/wav", + }, + }, + { + Type: ContentBlockTypeAssistantGenVideo, + AssistantGenVideo: &AssistantGenVideo{ + URL: "http://gen_video.com", + Base64Data: "gen_video_data", + MIMEType: "video/mp4", + }, + }, + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Text: "First, I need to identify the location (New York City) from the user's query.\n" + + "Then, I should call the weather API to get current conditions.\n" + + "Finally, I'll format the response in a user-friendly way with temperature and conditions.", + Signature: "encrypted_reasoning_content_that_is_very_long_and_will_be_truncated_for_display", + }, + }, + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + CallID: "call_weather_123", + Name: "get_current_weather", + Arguments: `{"location":"New York City","unit":"fahrenheit"}`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + { + Type: ContentBlockTypeFunctionToolResult, + FunctionToolResult: &FunctionToolResult{ + CallID: "call_weather_123", + Name: "get_current_weather", + Result: `{"temperature":72,"condition":"sunny","humidity":45,"wind_speed":8}`, + }, + }, + { + Type: ContentBlockTypeServerToolCall, + ServerToolCall: &ServerToolCall{ + Name: "server_tool", + CallID: "call_1", + Arguments: map[string]any{"a": 1}, + }, + }, + { + Type: ContentBlockTypeServerToolResult, + ServerToolResult: &ServerToolResult{ + Name: "server_tool", + CallID: "call_1", + Result: map[string]any{"success": true}, + }, + }, + { + Type: ContentBlockTypeMCPToolApprovalRequest, + MCPToolApprovalRequest: &MCPToolApprovalRequest{ + ID: "req_1", + Name: "mcp_tool", + ServerLabel: "mcp_server", + Arguments: "{}", + }, + }, + { + Type: ContentBlockTypeMCPToolApprovalResponse, + MCPToolApprovalResponse: &MCPToolApprovalResponse{ + ApprovalRequestID: "req_1", + Approve: true, + Reason: "looks good", + }, + }, + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + ServerLabel: "weather-mcp-server", + CallID: 
"mcp_forecast_456", + Name: "get_7day_forecast", + Arguments: `{"city":"New York","days":7}`, + }, + }, + { + Type: ContentBlockTypeMCPToolResult, + MCPToolResult: &MCPToolResult{ + CallID: "mcp_forecast_456", + Name: "get_7day_forecast", + Result: `{"status":"partial","days_available":3}`, + Error: &MCPToolCallError{ + Code: ptrOf[int64](503), + Message: "Service temporarily unavailable for full 7-day forecast", + }, + }, + }, + { + Type: ContentBlockTypeMCPListToolsResult, + MCPListToolsResult: &MCPListToolsResult{ + ServerLabel: "weather-mcp-server", + Tools: []*MCPListToolsItem{ + {Name: "get_current_weather", Description: "Get current weather conditions for a location"}, + {Name: "get_7day_forecast", Description: "Get 7-day weather forecast"}, + {Name: "get_weather_alerts", Description: "Get active weather alerts and warnings"}, + }, + }, + }, + }, + ResponseMeta: &AgenticResponseMeta{ + TokenUsage: &TokenUsage{ + PromptTokens: 250, + CompletionTokens: 180, + TotalTokens: 430, + }, + }, + } + + // Print the formatted output + output := msg.String() + + assert.Equal(t, `role: assistant +content_blocks: + [0] type: user_input_text + text: What's the weather like in New York City today? + [1] type: user_input_image + url: https://example.com/weather-map.jpg + base64_data: iVBORw0KGgoAAAANSUhE...... (96 bytes) + mime_type: image/jpeg + detail: high + [2] type: user_input_audio + url: http://audio.com + base64_data: audio_data... (10 bytes) + mime_type: audio/mp3 + [3] type: user_input_video + url: http://video.com + base64_data: video_data... (10 bytes) + mime_type: video/mp4 + [4] type: user_input_file + name: file.txt + url: http://file.com + base64_data: file_data... (9 bytes) + mime_type: text/plain + [5] type: assistant_gen_text + text: I'll check the current weather in New York City for you. + [6] type: assistant_gen_image + url: http://gen_image.com + base64_data: gen_image_data... 
(14 bytes) + mime_type: image/png + [7] type: assistant_gen_audio + url: http://gen_audio.com + base64_data: gen_audio_data... (14 bytes) + mime_type: audio/wav + [8] type: assistant_gen_video + url: http://gen_video.com + base64_data: gen_video_data... (14 bytes) + mime_type: video/mp4 + [9] type: reasoning + text: First, I need to identify the location (New York City) from the user's query. +Then, I should call the weather API to get current conditions. +Finally, I'll format the response in a user-friendly way with temperature and conditions. + signature: encrypted_reasoning_content_that_is_very_long_and_... + [10] type: function_tool_call + call_id: call_weather_123 + name: get_current_weather + arguments: {"location":"New York City","unit":"fahrenheit"} + stream_index: 0 + [11] type: function_tool_result + call_id: call_weather_123 + name: get_current_weather + result: {"temperature":72,"condition":"sunny","humidity":45,"wind_speed":8} + [12] type: server_tool_call + name: server_tool + call_id: call_1 + arguments: { + "a": 1 +} + [13] type: server_tool_result + name: server_tool + call_id: call_1 + result: { + "success": true +} + [14] type: mcp_tool_approval_request + server_label: mcp_server + id: req_1 + name: mcp_tool + arguments: {} + [15] type: mcp_tool_approval_response + approval_request_id: req_1 + approve: true + reason: looks good + [16] type: mcp_tool_call + server_label: weather-mcp-server + call_id: mcp_forecast_456 + name: get_7day_forecast + arguments: {"city":"New York","days":7} + [17] type: mcp_tool_result + call_id: mcp_forecast_456 + name: get_7day_forecast + result: {"status":"partial","days_available":3} + error: [503] Service temporarily unavailable for full 7-day forecast + [18] type: mcp_list_tools_result + server_label: weather-mcp-server + tools: 3 items + - get_current_weather: Get current weather conditions for a location + - get_7day_forecast: Get 7-day weather forecast + - get_weather_alerts: Get active weather alerts and 
warnings +response_meta: + token_usage: prompt=250, completion=180, total=430 +`, output) + + t.Run("nil/empty fields", func(t *testing.T) { + msg := &AgenticMessage{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + {Type: ContentBlockTypeUserInputAudio, UserInputAudio: &UserInputAudio{}}, // empty + {Type: ContentBlockTypeUserInputVideo, UserInputVideo: &UserInputVideo{}}, + {Type: ContentBlockTypeUserInputFile, UserInputFile: &UserInputFile{}}, + {Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: &AssistantGenImage{}}, + {Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: &AssistantGenAudio{}}, + {Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: &AssistantGenVideo{}}, + {Type: ContentBlockTypeServerToolCall, ServerToolCall: &ServerToolCall{Name: "t"}}, // No CallID + {Type: ContentBlockTypeServerToolResult, ServerToolResult: &ServerToolResult{Name: "t"}}, // No CallID + {Type: ContentBlockTypeMCPToolResult, MCPToolResult: &MCPToolResult{Name: "t"}}, // No Error + {Type: ContentBlockTypeMCPListToolsResult, MCPListToolsResult: &MCPListToolsResult{}}, // No Error + {Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: &MCPToolApprovalResponse{Approve: false}}, // No Reason + nil, // Nil block in slice + }, + } + + s := msg.String() + assert.Contains(t, s, "type: user_input_audio") + assert.NotContains(t, s, "mime_type:") + assert.Contains(t, s, "type: server_tool_call") + }) + + t.Run("nil content struct in block", func(t *testing.T) { + // Test cases where the specific content struct is nil but type is set + // This shouldn't crash and should just print type + msg := &AgenticMessage{ + ContentBlocks: []*ContentBlock{ + {Type: ContentBlockTypeReasoning, Reasoning: nil}, + {Type: ContentBlockTypeUserInputText, UserInputText: nil}, + {Type: ContentBlockTypeUserInputImage, UserInputImage: nil}, + {Type: ContentBlockTypeUserInputAudio, UserInputAudio: nil}, + {Type: ContentBlockTypeUserInputVideo, 
UserInputVideo: nil}, + {Type: ContentBlockTypeUserInputFile, UserInputFile: nil}, + {Type: ContentBlockTypeAssistantGenText, AssistantGenText: nil}, + {Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: nil}, + {Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: nil}, + {Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: nil}, + {Type: ContentBlockTypeFunctionToolCall, FunctionToolCall: nil}, + {Type: ContentBlockTypeFunctionToolResult, FunctionToolResult: nil}, + {Type: ContentBlockTypeServerToolCall, ServerToolCall: nil}, + {Type: ContentBlockTypeServerToolResult, ServerToolResult: nil}, + {Type: ContentBlockTypeMCPToolCall, MCPToolCall: nil}, + {Type: ContentBlockTypeMCPToolResult, MCPToolResult: nil}, + {Type: ContentBlockTypeMCPListToolsResult, MCPListToolsResult: nil}, + {Type: ContentBlockTypeMCPToolApprovalRequest, MCPToolApprovalRequest: nil}, + {Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: nil}, + }, + } + s := msg.String() + assert.Contains(t, s, "type: reasoning") + // ensure no panic and basic output present + }) +} + +func TestSystemAgenticMessage(t *testing.T) { + t.Run("basic", func(t *testing.T) { + msg := SystemAgenticMessage("system") + assert.Equal(t, AgenticRoleTypeSystem, msg.Role) + assert.Len(t, msg.ContentBlocks, 1) + assert.Equal(t, "system", msg.ContentBlocks[0].UserInputText.Text) + }) +} + +func TestUserAgenticMessage(t *testing.T) { + t.Run("basic", func(t *testing.T) { + msg := UserAgenticMessage("user") + assert.Equal(t, AgenticRoleTypeUser, msg.Role) + assert.Len(t, msg.ContentBlocks, 1) + assert.Equal(t, "user", msg.ContentBlocks[0].UserInputText.Text) + }) +} + +func TestFunctionToolResultAgenticMessage(t *testing.T) { + t.Run("basic", func(t *testing.T) { + msg := FunctionToolResultAgenticMessage("call_1", "tool_name", "result_str") + assert.Equal(t, AgenticRoleTypeUser, msg.Role) + assert.Len(t, msg.ContentBlocks, 1) + assert.Equal(t, ContentBlockTypeFunctionToolResult, 
msg.ContentBlocks[0].Type) + assert.Equal(t, "call_1", msg.ContentBlocks[0].FunctionToolResult.CallID) + assert.Equal(t, "tool_name", msg.ContentBlocks[0].FunctionToolResult.Name) + assert.Equal(t, "result_str", msg.ContentBlocks[0].FunctionToolResult.Result) + }) +} + +func TestNewContentBlock(t *testing.T) { + cbType := reflect.TypeOf(ContentBlock{}) + for i := 0; i < cbType.NumField(); i++ { + field := cbType.Field(i) + + // Skip non-content fields + if field.Name == "Type" || field.Name == "Extra" || field.Name == "StreamingMeta" { + continue + } + + t.Run(field.Name, func(t *testing.T) { + // Ensure field is a pointer + assert.Equal(t, reflect.Ptr, field.Type.Kind(), "Field %s should be a pointer", field.Name) + + // Create a new instance of the field's type + // field.Type is *T, so Elem() is T. reflect.New(T) returns *T. + elemType := field.Type.Elem() + inputVal := reflect.New(elemType) + input := inputVal.Interface() + + // Call NewContentBlock (generic) via type switch + var block *ContentBlock + switch v := input.(type) { + case *Reasoning: + block = NewContentBlock(v) + case *UserInputText: + block = NewContentBlock(v) + case *UserInputImage: + block = NewContentBlock(v) + case *UserInputAudio: + block = NewContentBlock(v) + case *UserInputVideo: + block = NewContentBlock(v) + case *UserInputFile: + block = NewContentBlock(v) + case *ToolSearchFunctionToolResult: + block = NewContentBlock(v) + case *AssistantGenText: + block = NewContentBlock(v) + case *AssistantGenImage: + block = NewContentBlock(v) + case *AssistantGenAudio: + block = NewContentBlock(v) + case *AssistantGenVideo: + block = NewContentBlock(v) + case *FunctionToolCall: + block = NewContentBlock(v) + case *FunctionToolResult: + block = NewContentBlock(v) + case *ServerToolCall: + block = NewContentBlock(v) + case *ServerToolResult: + block = NewContentBlock(v) + case *MCPToolCall: + block = NewContentBlock(v) + case *MCPToolResult: + block = NewContentBlock(v) + case *MCPListToolsResult: 
+ block = NewContentBlock(v) + case *MCPToolApprovalRequest: + block = NewContentBlock(v) + case *MCPToolApprovalResponse: + block = NewContentBlock(v) + default: + t.Fatalf("unsupported ContentBlock field type: %T", input) + } + + // Assertions + assert.NotNil(t, block, "NewContentBlock should return non-nil for type %T", input) + + // Check if the corresponding field in block is set equals to input + blockVal := reflect.ValueOf(block).Elem() + fieldVal := blockVal.FieldByName(field.Name) + assert.True(t, fieldVal.IsValid(), "Field %s not found in result", field.Name) + assert.Equal(t, input, fieldVal.Interface(), "Field %s should match input", field.Name) + + // Check Type is set + typeVal := blockVal.FieldByName("Type") + assert.NotEmpty(t, typeVal.String(), "Type should be set for %s", field.Name) + }) + } +} diff --git a/schema/claude/consts.go b/schema/claude/consts.go new file mode 100644 index 000000000..714b0362e --- /dev/null +++ b/schema/claude/consts.go @@ -0,0 +1,27 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package claude defines constants for claude. 
// TextCitationType names the location variant a TextCitation carries.
type TextCitationType string

// Citation location kinds. Each value equals the JSON key of the matching
// location field on TextCitation.
const (
	TextCitationTypeCharLocation            TextCitationType = "char_location"
	TextCitationTypePageLocation            TextCitationType = "page_location"
	TextCitationTypeContentBlockLocation    TextCitationType = "content_block_location"
	TextCitationTypeWebSearchResultLocation TextCitationType = "web_search_result_location"
)

// ResponseMetaExtension holds the Claude-specific response metadata: the
// message ID and the stop reason.
type ResponseMetaExtension struct {
	ID         string `json:"id,omitempty"`
	StopReason string `json:"stop_reason,omitempty"`
}

// AssistantGenTextExtension holds the citations attached to a piece of
// assistant-generated text.
type AssistantGenTextExtension struct {
	Citations []*TextCitation `json:"citations,omitempty"`
}

// TextCitation is a single citation. Type names the location kind; the
// location pointer of that kind is presumably the one populated (nothing in
// this file enforces it).
type TextCitation struct {
	Type TextCitationType `json:"type,omitempty"`

	CharLocation            *CitationCharLocation            `json:"char_location,omitempty"`
	PageLocation            *CitationPageLocation            `json:"page_location,omitempty"`
	ContentBlockLocation    *CitationContentBlockLocation    `json:"content_block_location,omitempty"`
	WebSearchResultLocation *CitationWebSearchResultLocation `json:"web_search_result_location,omitempty"`
}

// CitationCharLocation locates cited text by character indices in a document.
type CitationCharLocation struct {
	CitedText string `json:"cited_text,omitempty"`

	DocumentTitle string `json:"document_title,omitempty"`
	DocumentIndex int    `json:"document_index,omitempty"`

	StartCharIndex int `json:"start_char_index,omitempty"`
	EndCharIndex   int `json:"end_char_index,omitempty"`
}

// CitationPageLocation locates cited text by page numbers in a document.
type CitationPageLocation struct {
	CitedText string `json:"cited_text,omitempty"`

	DocumentTitle string `json:"document_title,omitempty"`
	DocumentIndex int    `json:"document_index,omitempty"`

	StartPageNumber int `json:"start_page_number,omitempty"`
	EndPageNumber   int `json:"end_page_number,omitempty"`
}

// CitationContentBlockLocation locates cited text by content block indices
// in a document.
type CitationContentBlockLocation struct {
	CitedText string `json:"cited_text,omitempty"`

	DocumentTitle string `json:"document_title,omitempty"`
	DocumentIndex int    `json:"document_index,omitempty"`

	StartBlockIndex int `json:"start_block_index,omitempty"`
	EndBlockIndex   int `json:"end_block_index,omitempty"`
}

// CitationWebSearchResultLocation identifies cited text that comes from a web
// search result.
type CitationWebSearchResultLocation struct {
	CitedText string `json:"cited_text,omitempty"`

	Title string `json:"title,omitempty"`
	URL   string `json:"url,omitempty"`

	EncryptedIndex string `json:"encrypted_index,omitempty"`
}

// ConcatAssistantGenTextExtensions concatenates multiple AssistantGenTextExtension chunks into a single one.
//
// Citations are appended in chunk order. An empty input is an error. A single
// chunk is returned as-is (same pointer, no copy). With two or more chunks,
// nil entries are skipped instead of being dereferenced, and the result always
// has a non-nil (possibly empty) Citations slice.
func ConcatAssistantGenTextExtensions(chunks []*AssistantGenTextExtension) (*AssistantGenTextExtension, error) {
	if len(chunks) == 0 {
		return nil, fmt.Errorf("no assistant generated text extension found")
	}
	if len(chunks) == 1 {
		return chunks[0], nil
	}

	// Size the result by the number of citations, not the number of chunks,
	// so the appends below never have to grow the slice.
	total := 0
	for _, ext := range chunks {
		if ext != nil {
			total += len(ext.Citations)
		}
	}

	ret := &AssistantGenTextExtension{
		Citations: make([]*TextCitation, 0, total),
	}
	for _, ext := range chunks {
		if ext == nil {
			// A chunk without an extension contributes nothing.
			continue
		}
		ret.Citations = append(ret.Citations, ext.Citations...)
	}

	return ret, nil
}

// ConcatResponseMetaExtensions concatenates multiple ResponseMetaExtension chunks into a single one.
//
// For each field the last non-empty value across the chunks wins. An empty
// input is an error. A single chunk is returned as-is (same pointer, no copy).
// With two or more chunks, nil entries are skipped instead of being
// dereferenced.
func ConcatResponseMetaExtensions(chunks []*ResponseMetaExtension) (*ResponseMetaExtension, error) {
	if len(chunks) == 0 {
		return nil, fmt.Errorf("no response meta extension found")
	}
	if len(chunks) == 1 {
		return chunks[0], nil
	}

	ret := &ResponseMetaExtension{}

	for _, ext := range chunks {
		if ext == nil {
			continue
		}
		if ext.ID != "" {
			ret.ID = ext.ID
		}
		if ext.StopReason != "" {
			ret.StopReason = ext.StopReason
		}
	}

	return ret, nil
}
+ */ + +package claude + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConcatAssistantGenTextExtensions(t *testing.T) { + t.Run("multiple extensions - concatenates all citations", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Citations: []*TextCitation{ + { + Type: "char_location", + CharLocation: &CitationCharLocation{ + CitedText: "citation 1", + DocumentIndex: 0, + }, + }, + }, + }, + { + Citations: []*TextCitation{ + { + Type: "page_location", + PageLocation: &CitationPageLocation{ + CitedText: "citation 2", + StartPageNumber: 1, + EndPageNumber: 2, + }, + }, + { + Type: "web_search_result_location", + WebSearchResultLocation: &CitationWebSearchResultLocation{ + CitedText: "citation 3", + URL: "https://example.com", + }, + }, + }, + }, + { + Citations: []*TextCitation{ + { + Type: "content_block_location", + ContentBlockLocation: &CitationContentBlockLocation{ + CitedText: "citation 4", + StartBlockIndex: 0, + EndBlockIndex: 5, + }, + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Citations, 4) + assert.Equal(t, "citation 1", result.Citations[0].CharLocation.CitedText) + assert.Equal(t, "citation 2", result.Citations[1].PageLocation.CitedText) + assert.Equal(t, "citation 3", result.Citations[2].WebSearchResultLocation.CitedText) + assert.Equal(t, "citation 4", result.Citations[3].ContentBlockLocation.CitedText) + }) + + t.Run("mixed empty and non-empty citations", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + {Citations: nil}, + { + Citations: []*TextCitation{ + { + Type: "char_location", + CharLocation: &CitationCharLocation{ + CitedText: "text1", + }, + }, + }, + }, + {Citations: []*TextCitation{}}, + { + Citations: []*TextCitation{ + { + Type: "page_location", + PageLocation: &CitationPageLocation{ + CitedText: "text2", + }, + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + 
assert.NoError(t, err) + assert.Len(t, result.Citations, 2) + assert.Equal(t, "text1", result.Citations[0].CharLocation.CitedText) + assert.Equal(t, "text2", result.Citations[1].PageLocation.CitedText) + }) + + t.Run("streaming scenario - citations arrive in chunks", func(t *testing.T) { + // Simulates streaming where citations arrive progressively + exts := []*AssistantGenTextExtension{ + { + Citations: []*TextCitation{ + {Type: "char_location", CharLocation: &CitationCharLocation{CitedText: "chunk1"}}, + }, + }, + { + Citations: []*TextCitation{ + {Type: "char_location", CharLocation: &CitationCharLocation{CitedText: "chunk2"}}, + }, + }, + { + Citations: []*TextCitation{ + {Type: "char_location", CharLocation: &CitationCharLocation{CitedText: "chunk3"}}, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Citations, 3) + assert.Equal(t, "chunk1", result.Citations[0].CharLocation.CitedText) + assert.Equal(t, "chunk2", result.Citations[1].CharLocation.CitedText) + assert.Equal(t, "chunk3", result.Citations[2].CharLocation.CitedText) + }) +} + +func TestConcatResponseMetaExtensions(t *testing.T) { + t.Run("multiple extensions - takes last non-empty values", func(t *testing.T) { + exts := []*ResponseMetaExtension{ + { + ID: "msg_1", + StopReason: "stop_1", + }, + { + ID: "msg_2", + StopReason: "", + }, + { + ID: "", + StopReason: "stop_3", + }, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "msg_2", result.ID) // Last non-empty ID + assert.Equal(t, "stop_3", result.StopReason) // Last non-empty StopReason + }) + + t.Run("all empty fields", func(t *testing.T) { + exts := []*ResponseMetaExtension{ + {ID: "", StopReason: ""}, + {ID: "", StopReason: ""}, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "", result.ID) + assert.Equal(t, "", result.StopReason) + }) + + t.Run("streaming scenario - ID in 
// ResponseMetaExtension holds the Gemini-specific response metadata: the
// response ID, the finish reason and optional grounding metadata.
type ResponseMetaExtension struct {
	ID            string             `json:"id,omitempty"`
	FinishReason  string             `json:"finish_reason,omitempty"`
	GroundingMeta *GroundingMetadata `json:"grounding_meta,omitempty"`
}

// GroundingMetadata groups the grounding chunks, grounding supports, search
// entry point and web search queries attached to a response.
type GroundingMetadata struct {
	// List of supporting references retrieved from specified grounding source.
	GroundingChunks []*GroundingChunk `json:"grounding_chunks,omitempty"`
	// Optional. List of grounding support.
	GroundingSupports []*GroundingSupport `json:"grounding_supports,omitempty"`
	// Optional. Google search entry for the following-up web searches.
	SearchEntryPoint *SearchEntryPoint `json:"search_entry_point,omitempty"`
	// Optional. Web search queries for the following-up web search.
	WebSearchQueries []string `json:"web_search_queries,omitempty"`
}

// GroundingChunk is a single supporting reference; only web chunks are
// modelled here.
type GroundingChunk struct {
	// Grounding chunk from the web.
	Web *GroundingChunkWeb `json:"web,omitempty"`
}

// GroundingChunkWeb is the chunk from the web.
type GroundingChunkWeb struct {
	// Domain of the (original) URI. This field is not supported in Gemini API.
	Domain string `json:"domain,omitempty"`
	// Title of the chunk.
	Title string `json:"title,omitempty"`
	// URI reference of the chunk.
	URI string `json:"uri,omitempty"`
}

// GroundingSupport ties a segment of the response to the grounding chunks
// that support it.
type GroundingSupport struct {
	// Confidence score of the support references. Ranges from 0 to 1. 1 is the most confident.
	// For Gemini 2.0 and before, this list must have the same size as the grounding_chunk_indices.
	// For Gemini 2.5 and after, this list will be empty and should be ignored.
	ConfidenceScores []float32 `json:"confidence_scores,omitempty"`
	// A list of indices (into 'grounding_chunk') specifying the citations associated with
	// the claim. For instance [1,3,4] means that grounding_chunk[1], grounding_chunk[3],
	// grounding_chunk[4] are the retrieved content attributed to the claim.
	GroundingChunkIndices []int `json:"grounding_chunk_indices,omitempty"`
	// Segment of the content this support belongs to.
	Segment *Segment `json:"segment,omitempty"`
}

// Segment of the content.
type Segment struct {
	// Output only. End index in the given Part, measured in bytes. Offset from the start
	// of the Part, exclusive, starting at zero.
	EndIndex int `json:"end_index,omitempty"`
	// Output only. The index of a Part object within its parent Content object.
	PartIndex int `json:"part_index,omitempty"`
	// Output only. Start index in the given Part, measured in bytes. Offset from the start
	// of the Part, inclusive, starting at zero.
	StartIndex int `json:"start_index,omitempty"`
	// Output only. The text corresponding to the segment from the response.
	Text string `json:"text,omitempty"`
}

// SearchEntryPoint is the Google search entry point.
type SearchEntryPoint struct {
	// Optional. Web content snippet that can be embedded in a web page or an app webview.
	RenderedContent string `json:"rendered_content,omitempty"`
	// Optional. Base64 encoded JSON representing array of tuple.
	SDKBlob []byte `json:"sdk_blob,omitempty"`
}

// ConcatResponseMetaExtensions concatenates multiple ResponseMetaExtension chunks into a single one.
//
// For each field the last non-empty value across the chunks wins;
// GroundingMeta is replaced wholesale by the last non-nil value, not merged.
// An empty input is an error. A single chunk is returned as-is (same pointer,
// no copy). With two or more chunks, nil entries are skipped instead of being
// dereferenced.
func ConcatResponseMetaExtensions(chunks []*ResponseMetaExtension) (*ResponseMetaExtension, error) {
	if len(chunks) == 0 {
		return nil, fmt.Errorf("no response meta extension found")
	}
	if len(chunks) == 1 {
		return chunks[0], nil
	}

	ret := &ResponseMetaExtension{}

	for _, ext := range chunks {
		if ext == nil {
			// A chunk without an extension contributes nothing.
			continue
		}
		if ext.ID != "" {
			ret.ID = ext.ID
		}
		if ext.FinishReason != "" {
			ret.FinishReason = ext.FinishReason
		}
		if ext.GroundingMeta != nil {
			ret.GroundingMeta = ext.GroundingMeta
		}
	}

	return ret, nil
}
+ */ + +package gemini + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConcatResponseMetaExtensions(t *testing.T) { + t.Run("multiple extensions - takes last non-empty values", func(t *testing.T) { + meta1 := &GroundingMetadata{WebSearchQueries: []string{"query1"}} + meta2 := &GroundingMetadata{WebSearchQueries: []string{"query2"}} + + exts := []*ResponseMetaExtension{ + { + ID: "resp_1", + FinishReason: "STOP", + GroundingMeta: meta1, + }, + { + ID: "resp_2", + FinishReason: "", + GroundingMeta: nil, + }, + { + ID: "", + FinishReason: "MAX_TOKENS", + GroundingMeta: meta2, + }, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "resp_2", result.ID) + assert.Equal(t, "MAX_TOKENS", result.FinishReason) + assert.Equal(t, meta2, result.GroundingMeta) + }) + + t.Run("streaming scenario", func(t *testing.T) { + meta := &GroundingMetadata{ + GroundingChunks: []*GroundingChunk{ + { + Web: &GroundingChunkWeb{ + Title: "Example", + URI: "https://example.com", + }, + }, + }, + } + + exts := []*ResponseMetaExtension{ + {ID: "stream_123", FinishReason: "", GroundingMeta: nil}, + {ID: "", FinishReason: "", GroundingMeta: nil}, + {ID: "", FinishReason: "STOP", GroundingMeta: meta}, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "stream_123", result.ID) + assert.Equal(t, "STOP", result.FinishReason) + assert.Equal(t, meta, result.GroundingMeta) + }) +} diff --git a/schema/message.go b/schema/message.go index 3746244bb..890af48ab 100644 --- a/schema/message.go +++ b/schema/message.go @@ -40,47 +40,56 @@ func init() { internal.RegisterStreamChunkConcatFunc(ConcatMessages) internal.RegisterStreamChunkConcatFunc(ConcatMessageArray) + internal.RegisterStreamChunkConcatFunc(ConcatAgenticMessages) + internal.RegisterStreamChunkConcatFunc(ConcatAgenticMessagesArray) + internal.RegisterStreamChunkConcatFunc(ConcatToolResults) } -// ConcatMessageArray 
merges aligned slices of messages into a single slice, -// concatenating messages at the same index across the input arrays. -func ConcatMessageArray(mas [][]*Message) ([]*Message, error) { - arrayLen := len(mas[0]) +func buildConcatGenericArray[T any](f func([]*T) (*T, error)) func([][]*T) ([]*T, error) { + return func(mas [][]*T) ([]*T, error) { + arrayLen := len(mas[0]) - ret := make([]*Message, arrayLen) - slicesToConcat := make([][]*Message, arrayLen) + ret := make([]*T, arrayLen) + slicesToConcat := make([][]*T, arrayLen) - for _, ma := range mas { - if len(ma) != arrayLen { - return nil, fmt.Errorf("unexpected array length. "+ - "Got %d, expected %d", len(ma), arrayLen) - } + for _, ma := range mas { + if len(ma) != arrayLen { + return nil, fmt.Errorf("unexpected array length. "+ + "Got %d, expected %d", len(ma), arrayLen) + } - for i := 0; i < arrayLen; i++ { - m := ma[i] - if m != nil { - slicesToConcat[i] = append(slicesToConcat[i], m) + for i := 0; i < arrayLen; i++ { + m := ma[i] + if m != nil { + slicesToConcat[i] = append(slicesToConcat[i], m) + } } } - } - for i, slice := range slicesToConcat { - if len(slice) == 0 { - ret[i] = nil - } else if len(slice) == 1 { - ret[i] = slice[0] - } else { - cm, err := ConcatMessages(slice) - if err != nil { - return nil, err - } + for i, slice := range slicesToConcat { + if len(slice) == 0 { + ret[i] = nil + } else if len(slice) == 1 { + ret[i] = slice[0] + } else { + cm, err := f(slice) + if err != nil { + return nil, err + } - ret[i] = cm + ret[i] = cm + } } + + return ret, nil } +} - return ret, nil +// ConcatMessageArray merges aligned slices of messages into a single slice, +// concatenating messages at the same index across the input arrays. 
+func ConcatMessageArray(mas [][]*Message) ([]*Message, error) { + return buildConcatGenericArray[Message](ConcatMessages)(mas) } // FormatType used by MessageTemplate.Format @@ -130,7 +139,6 @@ type ToolCall struct { Type string `json:"type"` // Function is the function call to be made. Function FunctionCall `json:"function"` - // Extra is used to store extra information for the tool call. Extra map[string]any `json:"extra,omitempty"` } @@ -213,6 +221,9 @@ type MessageInputPart struct { // File is the file input of the part, it's used when Type is "file_url". File *MessageInputFile `json:"file,omitempty"` + // ToolSearchResult holds the result of a tool search request, containing the matched tool names and their definitions. + ToolSearchResult *ToolSearchResult `json:"tool_search_result,omitempty"` + // Extra is used to store extra information. Extra map[string]any `json:"extra,omitempty"` } @@ -301,6 +312,9 @@ const ( // ToolPartTypeFile means the part is a file url. ToolPartTypeFile ToolPartType = "file" + + // ToolPartTypeToolSearchResult means the part contains tool search results. + ToolPartTypeToolSearchResult ToolPartType = "tool_search_result" ) // ToolOutputImage represents an image in tool output. @@ -327,6 +341,27 @@ type ToolOutputFile struct { MessagePartCommon } +// ToolSearchResult represents the result of a tool search operation. +// When a model issues a tool search call, the framework searches for matching tools +// and returns the results via this struct. +type ToolSearchResult struct { + // Tools contains the full definitions of matched tools that were not previously + // registered. Their complete definitions are required so that the model can + // understand their parameters and usage. 
+ Tools []*ToolInfo +} + +func (t *ToolSearchResult) String() string { + sb := new(strings.Builder) + sb.WriteString("ToolSearchResult[") + for _, tool := range t.Tools { + sb.WriteString(tool.Name) + sb.WriteString(",") + } + sb.WriteString("]") + return sb.String() +} + // ToolOutputPart represents a part of tool execution output. // It supports streaming scenarios through the Index field for chunk merging. type ToolOutputPart struct { @@ -349,6 +384,9 @@ type ToolOutputPart struct { // File is the file content, used when Type is ToolPartTypeFile. File *ToolOutputFile `json:"file,omitempty"` + // ToolSearchResult holds the tool search results, used when Type is ToolPartTypeToolSearchResult. + ToolSearchResult *ToolSearchResult `json:"tool_search_result,omitempty"` + // Extra is used to store extra information. Extra map[string]any `json:"extra,omitempty"` } @@ -413,6 +451,14 @@ func convToolOutputPartToMessageInputPart(toolPart ToolOutputPart) (MessageInput File: &MessageInputFile{MessagePartCommon: toolPart.File.MessagePartCommon}, Extra: toolPart.Extra, }, nil + case ToolPartTypeToolSearchResult: + if toolPart.ToolSearchResult == nil { + return MessageInputPart{}, fmt.Errorf("tool search result is nil for tool part type %v", toolPart.Type) + } + return MessageInputPart{ + Type: ChatMessagePartTypeToolSearchResult, + ToolSearchResult: toolPart.ToolSearchResult, + }, nil default: return MessageInputPart{}, fmt.Errorf("unknown tool part type: %v", toolPart.Type) } @@ -489,6 +535,9 @@ const ( ChatMessagePartTypeFileURL ChatMessagePartType = "file_url" // ChatMessagePartTypeReasoning means the part is a reasoning block. ChatMessagePartTypeReasoning ChatMessagePartType = "reasoning" + + // ChatMessagePartTypeToolSearchResult means the part contains tool search results. + ChatMessagePartTypeToolSearchResult ChatMessagePartType = "tool_search_result" ) // Deprecated: This struct is deprecated as the MultiContent field is deprecated. 
@@ -721,7 +770,7 @@ var _ MessagesTemplate = MessagesPlaceholder("", false) // e.g. // // chatTemplate := prompt.FromMessages( -// schema.SystemMessage("you are eino helper"), +// schema.SystemMessage("you are an eino helper"), // schema.MessagesPlaceholder("history", false), // <= this will use the value of "history" in params // ) // msgs, err := chatTemplate.Format(ctx, params) @@ -739,7 +788,7 @@ type messagesPlaceholder struct { // // placeholder := MessagesPlaceholder("history", false) // params := map[string]any{ -// "history": []*schema.Message{{Role: "user", Content: "what is eino?"}, {Role: "assistant", Content: "eino is a great freamwork to build llm apps"}}, +// "history": []*schema.Message{{Role: "user", Content: "what is eino?"}, {Role: "assistant", Content: "eino is a great framework to build llm apps"}}, // "query": "how to use eino?", // } // chatTemplate := chatTpl := prompt.FromMessages( diff --git a/schema/openai/consts.go b/schema/openai/consts.go new file mode 100644 index 000000000..5958cef40 --- /dev/null +++ b/schema/openai/consts.go @@ -0,0 +1,95 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package openai defines constants for openai. 
+package openai + +type TextAnnotationType string + +const ( + TextAnnotationTypeFileCitation TextAnnotationType = "file_citation" + TextAnnotationTypeURLCitation TextAnnotationType = "url_citation" + TextAnnotationTypeContainerFileCitation TextAnnotationType = "container_file_citation" + TextAnnotationTypeFilePath TextAnnotationType = "file_path" +) + +type ReasoningEffort string + +const ( + ReasoningEffortMinimal ReasoningEffort = "minimal" + ReasoningEffortLow ReasoningEffort = "low" + ReasoningEffortMedium ReasoningEffort = "medium" + ReasoningEffortHigh ReasoningEffort = "high" +) + +type ReasoningSummary string + +const ( + ReasoningSummaryAuto ReasoningSummary = "auto" + ReasoningSummaryConcise ReasoningSummary = "concise" + ReasoningSummaryDetailed ReasoningSummary = "detailed" +) + +type ServiceTier string + +const ( + ServiceTierAuto ServiceTier = "auto" + ServiceTierDefault ServiceTier = "default" + ServiceTierFlex ServiceTier = "flex" + ServiceTierScale ServiceTier = "scale" + ServiceTierPriority ServiceTier = "priority" +) + +type PromptCacheRetention string + +const ( + PromptCacheRetentionInMemory PromptCacheRetention = "in-memory" + PromptCacheRetention24h PromptCacheRetention = "24h" +) + +type ResponseStatus string + +const ( + ResponseStatusCompleted ResponseStatus = "completed" + ResponseStatusFailed ResponseStatus = "failed" + ResponseStatusInProgress ResponseStatus = "in_progress" + ResponseStatusCancelled ResponseStatus = "cancelled" + ResponseStatusQueued ResponseStatus = "queued" + ResponseStatusIncomplete ResponseStatus = "incomplete" +) + +type ResponseErrorCode string + +const ( + ResponseErrorCodeServerError ResponseErrorCode = "server_error" + ResponseErrorCodeRateLimitExceeded ResponseErrorCode = "rate_limit_exceeded" + ResponseErrorCodeInvalidPrompt ResponseErrorCode = "invalid_prompt" + ResponseErrorCodeVectorStoreTimeout ResponseErrorCode = "vector_store_timeout" + ResponseErrorCodeInvalidImage ResponseErrorCode = "invalid_image" 
+ ResponseErrorCodeInvalidImageFormat ResponseErrorCode = "invalid_image_format" + ResponseErrorCodeInvalidBase64Image ResponseErrorCode = "invalid_base64_image" + ResponseErrorCodeInvalidImageURL ResponseErrorCode = "invalid_image_url" + ResponseErrorCodeImageTooLarge ResponseErrorCode = "image_too_large" + ResponseErrorCodeImageTooSmall ResponseErrorCode = "image_too_small" + ResponseErrorCodeImageParseError ResponseErrorCode = "image_parse_error" + ResponseErrorCodeImageContentPolicyViolation ResponseErrorCode = "image_content_policy_violation" + ResponseErrorCodeInvalidImageMode ResponseErrorCode = "invalid_image_mode" + ResponseErrorCodeImageFileTooLarge ResponseErrorCode = "image_file_too_large" + ResponseErrorCodeUnsupportedImageMediaType ResponseErrorCode = "unsupported_image_media_type" + ResponseErrorCodeEmptyImageFile ResponseErrorCode = "empty_image_file" + ResponseErrorCodeFailedToDownloadImage ResponseErrorCode = "failed_to_download_image" + ResponseErrorCodeImageFileNotFound ResponseErrorCode = "image_file_not_found" +) diff --git a/schema/openai/extension.go b/schema/openai/extension.go new file mode 100644 index 000000000..1e10c411e --- /dev/null +++ b/schema/openai/extension.go @@ -0,0 +1,212 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package openai + +import ( + "fmt" + "sort" +) + +type ResponseMetaExtension struct { + ID string `json:"id,omitempty"` + Status ResponseStatus `json:"status,omitempty"` + Error *ResponseError `json:"error,omitempty"` + IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` + PreviousResponseID string `json:"previous_response_id,omitempty"` + Reasoning *Reasoning `json:"reasoning,omitempty"` + ServiceTier ServiceTier `json:"service_tier,omitempty"` + CreatedAt int64 `json:"created_at,omitempty"` + PromptCacheRetention PromptCacheRetention `json:"prompt_cache_retention,omitempty"` +} + +type AssistantGenTextExtension struct { + Refusal *OutputRefusal `json:"refusal,omitempty"` + Annotations []*TextAnnotation `json:"annotations,omitempty"` +} + +type ResponseError struct { + Code ResponseErrorCode `json:"code,omitempty"` + Message string `json:"message,omitempty"` +} + +type IncompleteDetails struct { + Reason string `json:"reason,omitempty"` +} + +type Reasoning struct { + Effort ReasoningEffort `json:"effort,omitempty"` + Summary ReasoningSummary `json:"summary,omitempty"` +} + +type OutputRefusal struct { + Reason string `json:"reason,omitempty"` +} + +type TextAnnotation struct { + Index int `json:"index,omitempty"` + + Type TextAnnotationType `json:"type,omitempty"` + + FileCitation *TextAnnotationFileCitation `json:"file_citation,omitempty"` + URLCitation *TextAnnotationURLCitation `json:"url_citation,omitempty"` + ContainerFileCitation *TextAnnotationContainerFileCitation `json:"container_file_citation,omitempty"` + FilePath *TextAnnotationFilePath `json:"file_path,omitempty"` +} + +type TextAnnotationFileCitation struct { + // The ID of the file. + FileID string `json:"file_id,omitempty"` + // The filename of the file cited. + Filename string `json:"filename,omitempty"` + + // The index of the file in the list of files. 
+ Index int `json:"index,omitempty"` +} + +type TextAnnotationURLCitation struct { + // The title of the web resource. + Title string `json:"title,omitempty"` + // The URL of the web resource. + URL string `json:"url,omitempty"` + + // The index of the first character of the URL citation in the message. + StartIndex int `json:"start_index,omitempty"` + // The index of the last character of the URL citation in the message. + EndIndex int `json:"end_index,omitempty"` +} + +type TextAnnotationContainerFileCitation struct { + // The ID of the container file. + ContainerID string `json:"container_id,omitempty"` + + // The ID of the file. + FileID string `json:"file_id,omitempty"` + // The filename of the container file cited. + Filename string `json:"filename,omitempty"` + + // The index of the first character of the container file citation in the message. + StartIndex int `json:"start_index,omitempty"` + // The index of the last character of the container file citation in the message. + EndIndex int `json:"end_index,omitempty"` +} + +type TextAnnotationFilePath struct { + // The ID of the file. + FileID string `json:"file_id,omitempty"` + + // The index of the file in the list of files. + Index int `json:"index,omitempty"` +} + +// ConcatAssistantGenTextExtensions concatenates multiple AssistantGenTextExtension chunks into a single one. +func ConcatAssistantGenTextExtensions(chunks []*AssistantGenTextExtension) (*AssistantGenTextExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no assistant generated text extension found") + } + + ret := &AssistantGenTextExtension{} + + var allAnnotations []*TextAnnotation + for _, ext := range chunks { + allAnnotations = append(allAnnotations, ext.Annotations...) 
+ } + + var ( + indices []int + indexToAnnotation = map[int]*TextAnnotation{} + ) + + for _, an := range allAnnotations { + if an == nil { + continue + } + if indexToAnnotation[an.Index] == nil { + indexToAnnotation[an.Index] = an + indices = append(indices, an.Index) + } else { + return nil, fmt.Errorf("duplicate annotation index %d", an.Index) + } + } + + sort.Slice(indices, func(i, j int) bool { + return indices[i] < indices[j] + }) + + ret.Annotations = make([]*TextAnnotation, 0, len(indices)) + for _, idx := range indices { + an := *indexToAnnotation[idx] + an.Index = 0 // clear index + ret.Annotations = append(ret.Annotations, &an) + } + + for _, ext := range chunks { + if ext.Refusal == nil { + continue + } + if ret.Refusal == nil { + ret.Refusal = ext.Refusal + } else { + ret.Refusal.Reason += ext.Refusal.Reason + } + } + + return ret, nil +} + +// ConcatResponseMetaExtensions concatenates multiple ResponseMetaExtension chunks into a single one. +func ConcatResponseMetaExtensions(chunks []*ResponseMetaExtension) (*ResponseMetaExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no response meta extension found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + + ret := &ResponseMetaExtension{} + + for _, ext := range chunks { + if ext.ID != "" { + ret.ID = ext.ID + } + if ext.Status != "" { + ret.Status = ext.Status + } + if ext.Error != nil { + ret.Error = ext.Error + } + if ext.IncompleteDetails != nil { + ret.IncompleteDetails = ext.IncompleteDetails + } + if ext.PreviousResponseID != "" { + ret.PreviousResponseID = ext.PreviousResponseID + } + if ext.Reasoning != nil { + ret.Reasoning = ext.Reasoning + } + if ext.ServiceTier != "" { + ret.ServiceTier = ext.ServiceTier + } + if ext.CreatedAt != 0 { + ret.CreatedAt = ext.CreatedAt + } + if ext.PromptCacheRetention != "" { + ret.PromptCacheRetention = ext.PromptCacheRetention + } + } + + return ret, nil +} diff --git a/schema/openai/extension_test.go 
b/schema/openai/extension_test.go new file mode 100644 index 000000000..640982fdf --- /dev/null +++ b/schema/openai/extension_test.go @@ -0,0 +1,193 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConcatResponseMetaExtensions(t *testing.T) { + t.Run("multiple extensions - takes last non-empty values", func(t *testing.T) { + err1 := &ResponseError{Code: "err1", Message: "msg1"} + incomplete := &IncompleteDetails{Reason: "max_tokens"} + + exts := []*ResponseMetaExtension{ + { + ID: "id_1", + Status: "in_progress", + Error: err1, + IncompleteDetails: nil, + }, + { + ID: "id_2", + Status: "", + Error: nil, + IncompleteDetails: nil, + }, + { + ID: "", + Status: "completed", + Error: nil, + IncompleteDetails: incomplete, + }, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "id_2", result.ID) + assert.Equal(t, ResponseStatus("completed"), result.Status) + assert.Equal(t, err1, result.Error) + assert.Equal(t, incomplete, result.IncompleteDetails) + }) + + t.Run("streaming scenario", func(t *testing.T) { + exts := []*ResponseMetaExtension{ + {ID: "chatcmpl_stream", Status: "", Error: nil, IncompleteDetails: nil}, + {ID: "", Status: ResponseStatus("in_progress"), Error: nil, IncompleteDetails: nil}, + {ID: "", Status: ResponseStatus("completed"), Error: nil, 
IncompleteDetails: nil}, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "chatcmpl_stream", result.ID) + assert.Equal(t, ResponseStatus("completed"), result.Status) + }) +} + +func TestConcatAssistantGenTextExtensions(t *testing.T) { + t.Run("single extension with annotations", func(t *testing.T) { + ext := &AssistantGenTextExtension{ + Annotations: []*TextAnnotation{ + { + Index: 0, + Type: "file_citation", + FileCitation: &TextAnnotationFileCitation{ + FileID: "file_123", + Filename: "doc.pdf", + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions([]*AssistantGenTextExtension{ext}) + assert.NoError(t, err) + assert.Len(t, result.Annotations, 1) + assert.Equal(t, "file_123", result.Annotations[0].FileCitation.FileID) + }) + + t.Run("multiple extensions - merges annotations by index", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Annotations: []*TextAnnotation{ + { + Index: 0, + Type: "file_citation", + FileCitation: &TextAnnotationFileCitation{ + FileID: "file_1", + }, + }, + }, + }, + { + Annotations: []*TextAnnotation{ + { + Index: 2, + Type: "url_citation", + URLCitation: &TextAnnotationURLCitation{ + URL: "https://example.com", + }, + }, + }, + }, + { + Annotations: []*TextAnnotation{ + { + Index: 1, + Type: "file_path", + FilePath: &TextAnnotationFilePath{ + FileID: "file_2", + }, + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Annotations, 3) + assert.Equal(t, "file_1", result.Annotations[0].FileCitation.FileID) + assert.Equal(t, "file_2", result.Annotations[1].FilePath.FileID) + assert.Equal(t, "https://example.com", result.Annotations[2].URLCitation.URL) + }) + + t.Run("streaming scenario - annotations arrive in chunks", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Annotations: []*TextAnnotation{ + {Index: 0, Type: "file_citation", FileCitation: 
&TextAnnotationFileCitation{FileID: "f1"}}, + }, + }, + { + Annotations: []*TextAnnotation{ + {Index: 1, Type: "url_citation", URLCitation: &TextAnnotationURLCitation{URL: "url1"}}, + }, + }, + { + Annotations: []*TextAnnotation{ + {Index: 2, Type: "file_path", FilePath: &TextAnnotationFilePath{FileID: "f2"}}, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Annotations, 3) + assert.Equal(t, "f1", result.Annotations[0].FileCitation.FileID) + assert.Equal(t, "url1", result.Annotations[1].URLCitation.URL) + assert.Equal(t, "f2", result.Annotations[2].FilePath.FileID) + }) + + t.Run("multiple extensions - concatenates refusal reason", func(t *testing.T) { + ext1 := &AssistantGenTextExtension{Refusal: &OutputRefusal{Reason: "A"}} + ext2 := &AssistantGenTextExtension{Refusal: &OutputRefusal{Reason: "B"}} + + result, err := ConcatAssistantGenTextExtensions([]*AssistantGenTextExtension{ext1, ext2}) + assert.NoError(t, err) + assert.NotNil(t, result.Refusal) + assert.Equal(t, "AB", result.Refusal.Reason) + }) + + t.Run("duplicate index - error occurrence", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Annotations: []*TextAnnotation{ + {Index: 0, Type: "file_citation", FileCitation: &TextAnnotationFileCitation{FileID: "first"}}, + }, + }, + { + Annotations: []*TextAnnotation{ + {Index: 0, Type: "url_citation", URLCitation: &TextAnnotationURLCitation{URL: "second"}}, + }, + }, + } + + _, err := ConcatAssistantGenTextExtensions(exts) + assert.Error(t, err) + }) +} diff --git a/schema/serialization.go b/schema/serialization.go index 7a719b0a8..22fa16ade 100644 --- a/schema/serialization.go +++ b/schema/serialization.go @@ -25,7 +25,7 @@ import ( ) func init() { - RegisterName[Message]("_eino_message") + RegisterName[*Message]("_eino_message") RegisterName[[]*Message]("_eino_message_slice") RegisterName[Document]("_eino_document") RegisterName[RoleType]("_eino_role_type") diff --git 
a/schema/tool.go b/schema/tool.go index ccc93b6a3..f8a0a743e 100644 --- a/schema/tool.go +++ b/schema/tool.go @@ -17,6 +17,9 @@ package schema import ( + "bytes" + "encoding/gob" + "encoding/json" "sort" "github.com/eino-contrib/jsonschema" @@ -59,6 +62,61 @@ const ( ToolChoiceForced ToolChoice = "forced" ) +type AgenticToolChoice struct { + // Type is the tool choice mode. + Type ToolChoice + + // Allowed optionally specifies the list of tools that the model is permitted to call. + // Optional. + Allowed *AgenticAllowedToolChoice + + // Forced optionally specifies the list of tools that the model is required to call. + // Optional. + Forced *AgenticForcedToolChoice +} + +// AgenticAllowedToolChoice specifies a list of allowed tools for the model. +type AgenticAllowedToolChoice struct { + // Tools is the list of allowed tools for the model to call. + // Optional. + Tools []*AllowedTool +} + +// AgenticForcedToolChoice specifies a list of tools that the model must call. +type AgenticForcedToolChoice struct { + // Tools is the list of tools that the model must call. + // Optional. + Tools []*AllowedTool +} + +// AllowedTool represents a tool that the model is allowed or forced to call. +// Exactly one of FunctionName, MCPTool, or ServerTool must be specified. +type AllowedTool struct { + // FunctionName specifies a function tool by name. + FunctionName string + + // MCPTool specifies an MCP tool. + MCPTool *AllowedMCPTool + + // ServerTool specifies a server tool. + ServerTool *AllowedServerTool +} + +// AllowedMCPTool contains the information for identifying an MCP tool. +type AllowedMCPTool struct { + // ServerLabel is the label of the MCP server. + ServerLabel string + // Name is the name of the MCP tool. + Name string +} + +// AllowedServerTool contains the information for identifying a server tool. +type AllowedServerTool struct { + // Name is the name of the server tool. + Name string +} + +// ToolInfo is the information of a tool. 
// ToolInfo describes a tool that can be passed to a ChatModel via // [ToolCallingChatModel.WithTools] or [ChatModel.BindTools]. // @@ -82,6 +140,104 @@ type ToolInfo struct { *ParamsOneOf } +type toolInfoForJSON struct { + Name string `json:"name,omitempty"` + Desc string `json:"desc,omitempty"` + Extra map[string]any `json:"extra,omitempty"` + HasParamsOneOf bool `json:"has_params_one_of,omitempty"` + Params map[string]*ParameterInfo `json:"params,omitempty"` + JSONSchema *jsonschema.Schema `json:"json_schema,omitempty"` +} + +type toolInfoForGob struct { + Name string + Desc string + Extra map[string]any + HasParamsOneOf bool + Params map[string]*ParameterInfo + JSONSchema *string +} + +func (t *ToolInfo) MarshalJSON() ([]byte, error) { + tmp := &toolInfoForJSON{ + Name: t.Name, + Desc: t.Desc, + Extra: t.Extra, + } + if t.ParamsOneOf != nil { + tmp.HasParamsOneOf = true + tmp.Params = t.ParamsOneOf.params + tmp.JSONSchema = t.ParamsOneOf.jsonschema + } + return json.Marshal(tmp) +} + +func (t *ToolInfo) UnmarshalJSON(data []byte) error { + tmp := &toolInfoForJSON{} + if err := json.Unmarshal(data, tmp); err != nil { + return err + } + t.Name = tmp.Name + t.Desc = tmp.Desc + t.Extra = tmp.Extra + if tmp.HasParamsOneOf { + t.ParamsOneOf = &ParamsOneOf{ + params: tmp.Params, + jsonschema: tmp.JSONSchema, + } + } + return nil +} + +func (t *ToolInfo) GobEncode() ([]byte, error) { + tmp := &toolInfoForGob{ + Name: t.Name, + Desc: t.Desc, + Extra: t.Extra, + } + if t.ParamsOneOf != nil { + tmp.HasParamsOneOf = true + tmp.Params = t.ParamsOneOf.params + if t.ParamsOneOf.jsonschema != nil { + b, err := json.Marshal(t.ParamsOneOf.jsonschema) + if err != nil { + return nil, err + } + str := string(b) + tmp.JSONSchema = &str + } + } + buf := new(bytes.Buffer) + if err := gob.NewEncoder(buf).Encode(tmp); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func (t *ToolInfo) GobDecode(b []byte) error { + tmp := &toolInfoForGob{} + if err := 
gob.NewDecoder(bytes.NewBuffer(b)).Decode(tmp); err != nil { + return err + } + t.Name = tmp.Name + t.Desc = tmp.Desc + t.Extra = tmp.Extra + if !tmp.HasParamsOneOf { + return nil + } + t.ParamsOneOf = &ParamsOneOf{ + params: tmp.Params, + } + if tmp.JSONSchema != nil { + s := &jsonschema.Schema{} + if err := json.Unmarshal([]byte(*tmp.JSONSchema), s); err != nil { + return err + } + t.ParamsOneOf.jsonschema = s + } + return nil +} + // ParameterInfo is the information of a parameter. // It is used to describe the parameters of a tool. type ParameterInfo struct { diff --git a/schema/tool_test.go b/schema/tool_test.go index 97af29be2..e8f95c364 100644 --- a/schema/tool_test.go +++ b/schema/tool_test.go @@ -17,6 +17,8 @@ package schema import ( + "bytes" + "encoding/gob" "encoding/json" "testing" @@ -133,3 +135,49 @@ func TestParamsOneOfToJSONSchema(t *testing.T) { }) } + +func TestToolInfoSerialization(t *testing.T) { + ti1 := &ToolInfo{ + ParamsOneOf: NewParamsOneOfByParams(map[string]*ParameterInfo{ + "a": { + Type: String, + Desc: "desc", + }, + }), + } + ti2 := &ToolInfo{ + ParamsOneOf: NewParamsOneOfByJSONSchema(&jsonschema.Schema{ + Type: "string", + }), + } + + // json + b, err := json.Marshal(ti1) + assert.NoError(t, err) + result := &ToolInfo{} + err = json.Unmarshal(b, result) + assert.NoError(t, err) + assert.Equal(t, ti1, result) + b, err = json.Marshal(ti2) + assert.NoError(t, err) + result = &ToolInfo{} + err = json.Unmarshal(b, result) + assert.NoError(t, err) + assert.Equal(t, ti2, result) + + // gob + buf := new(bytes.Buffer) + err = gob.NewEncoder(buf).Encode(ti1) + assert.NoError(t, err) + result = &ToolInfo{} + err = gob.NewDecoder(buf).Decode(result) + assert.NoError(t, err) + assert.Equal(t, ti1, result) + buf = new(bytes.Buffer) + err = gob.NewEncoder(buf).Encode(ti2) + assert.NoError(t, err) + result = &ToolInfo{} + err = gob.NewDecoder(buf).Decode(result) + assert.NoError(t, err) + assert.Equal(t, ti2, result) +} diff --git 
a/utils/callbacks/template.go b/utils/callbacks/template.go index e04bddd63..f01a849b6 100644 --- a/utils/callbacks/template.go +++ b/utils/callbacks/template.go @@ -55,17 +55,20 @@ func NewHandlerHelper() *HandlerHelper { // // then use the handler with runnable.Invoke(ctx, input, compose.WithCallbacks(handler)) type HandlerHelper struct { - promptHandler *PromptCallbackHandler - chatModelHandler *ModelCallbackHandler - embeddingHandler *EmbeddingCallbackHandler - indexerHandler *IndexerCallbackHandler - retrieverHandler *RetrieverCallbackHandler - loaderHandler *LoaderCallbackHandler - transformerHandler *TransformerCallbackHandler - toolHandler *ToolCallbackHandler - toolsNodeHandler *ToolsNodeCallbackHandlers - agentHandler *AgentCallbackHandler - composeTemplates map[components.Component]callbacks.Handler + promptHandler *PromptCallbackHandler + chatModelHandler *ModelCallbackHandler + embeddingHandler *EmbeddingCallbackHandler + indexerHandler *IndexerCallbackHandler + retrieverHandler *RetrieverCallbackHandler + loaderHandler *LoaderCallbackHandler + transformerHandler *TransformerCallbackHandler + toolHandler *ToolCallbackHandler + toolsNodeHandler *ToolsNodeCallbackHandlers + agentHandler *AgentCallbackHandler + agenticPromptHandler *AgenticPromptCallbackHandler + agenticModelHandler *AgenticModelCallbackHandler + agenticToolsNodeHandler *AgenticToolsNodeCallbackHandlers + composeTemplates map[components.Component]callbacks.Handler } // Handler returns the callbacks.Handler created by HandlerHelper. @@ -127,6 +130,24 @@ func (c *HandlerHelper) ToolsNode(handler *ToolsNodeCallbackHandlers) *HandlerHe return c } +// AgenticPrompt sets the agentic prompt handler for the handler helper, which will be called when the agentic prompt component is executed. 
+func (c *HandlerHelper) AgenticPrompt(handler *AgenticPromptCallbackHandler) *HandlerHelper { + c.agenticPromptHandler = handler + return c +} + +// AgenticModel sets the agentic chat model handler for the handler helper, which will be called when the agentic chat model component is executed. +func (c *HandlerHelper) AgenticModel(handler *AgenticModelCallbackHandler) *HandlerHelper { + c.agenticModelHandler = handler + return c +} + +// AgenticToolsNode sets the agentic tools node handler for the handler helper, which will be called when the agentic tools node is executed. +func (c *HandlerHelper) AgenticToolsNode(handler *AgenticToolsNodeCallbackHandlers) *HandlerHelper { + c.agenticToolsNodeHandler = handler + return c +} + // Agent sets the agent handler for the handler helper, which will be called when the agent is executed. func (c *HandlerHelper) Agent(handler *AgentCallbackHandler) *HandlerHelper { c.agentHandler = handler @@ -161,8 +182,12 @@ func (c *handlerTemplate) OnStart(ctx context.Context, info *callbacks.RunInfo, switch info.Component { case components.ComponentOfPrompt: return c.promptHandler.OnStart(ctx, info, prompt.ConvCallbackInput(input)) + case components.ComponentOfAgenticPrompt: + return c.agenticPromptHandler.OnStart(ctx, info, prompt.ConvCallbackInput(input)) case components.ComponentOfChatModel: return c.chatModelHandler.OnStart(ctx, info, model.ConvCallbackInput(input)) + case components.ComponentOfAgenticModel: + return c.agenticModelHandler.OnStart(ctx, info, model.ConvAgenticCallbackInput(input)) case components.ComponentOfEmbedding: return c.embeddingHandler.OnStart(ctx, info, embedding.ConvCallbackInput(input)) case components.ComponentOfIndexer: @@ -177,6 +202,8 @@ func (c *handlerTemplate) OnStart(ctx context.Context, info *callbacks.RunInfo, return c.toolHandler.OnStart(ctx, info, tool.ConvCallbackInput(input)) case compose.ComponentOfToolsNode: return c.toolsNodeHandler.OnStart(ctx, info, convToolsNodeCallbackInput(input)) + 
case compose.ComponentOfAgenticToolsNode: + return c.agenticToolsNodeHandler.OnStart(ctx, info, convAgenticToolsNodeCallbackInput(input)) case adk.ComponentOfAgent: return c.agentHandler.OnStart(ctx, info, adk.ConvAgentCallbackInput(input)) case compose.ComponentOfGraph, @@ -194,8 +221,12 @@ func (c *handlerTemplate) OnEnd(ctx context.Context, info *callbacks.RunInfo, ou switch info.Component { case components.ComponentOfPrompt: return c.promptHandler.OnEnd(ctx, info, prompt.ConvCallbackOutput(output)) + case components.ComponentOfAgenticPrompt: + return c.agenticPromptHandler.OnEnd(ctx, info, prompt.ConvCallbackOutput(output)) case components.ComponentOfChatModel: return c.chatModelHandler.OnEnd(ctx, info, model.ConvCallbackOutput(output)) + case components.ComponentOfAgenticModel: + return c.agenticModelHandler.OnEnd(ctx, info, model.ConvAgenticCallbackOutput(output)) case components.ComponentOfEmbedding: return c.embeddingHandler.OnEnd(ctx, info, embedding.ConvCallbackOutput(output)) case components.ComponentOfIndexer: @@ -210,6 +241,8 @@ func (c *handlerTemplate) OnEnd(ctx context.Context, info *callbacks.RunInfo, ou return c.toolHandler.OnEnd(ctx, info, tool.ConvCallbackOutput(output)) case compose.ComponentOfToolsNode: return c.toolsNodeHandler.OnEnd(ctx, info, convToolsNodeCallbackOutput(output)) + case compose.ComponentOfAgenticToolsNode: + return c.agenticToolsNodeHandler.OnEnd(ctx, info, convAgenticToolsNodeCallbackOutput(output)) case adk.ComponentOfAgent: return c.agentHandler.OnEnd(ctx, info, adk.ConvAgentCallbackOutput(output)) case compose.ComponentOfGraph, @@ -227,8 +260,12 @@ func (c *handlerTemplate) OnError(ctx context.Context, info *callbacks.RunInfo, switch info.Component { case components.ComponentOfPrompt: return c.promptHandler.OnError(ctx, info, err) + case components.ComponentOfAgenticPrompt: + return c.agenticPromptHandler.OnError(ctx, info, err) case components.ComponentOfChatModel: return c.chatModelHandler.OnError(ctx, info, err) + 
case components.ComponentOfAgenticModel: + return c.agenticModelHandler.OnError(ctx, info, err) case components.ComponentOfEmbedding: return c.embeddingHandler.OnError(ctx, info, err) case components.ComponentOfIndexer: @@ -243,6 +280,8 @@ func (c *handlerTemplate) OnError(ctx context.Context, info *callbacks.RunInfo, return c.toolHandler.OnError(ctx, info, err) case compose.ComponentOfToolsNode: return c.toolsNodeHandler.OnError(ctx, info, err) + case compose.ComponentOfAgenticToolsNode: + return c.agenticToolsNodeHandler.OnError(ctx, info, err) case compose.ComponentOfGraph, compose.ComponentOfChain, compose.ComponentOfLambda: @@ -275,6 +314,11 @@ func (c *handlerTemplate) OnEndWithStreamOutput(ctx context.Context, info *callb schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*model.CallbackOutput, error) { return model.ConvCallbackOutput(item), nil })) + case components.ComponentOfAgenticModel: + return c.agenticModelHandler.OnEndWithStreamOutput(ctx, info, + schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*model.AgenticCallbackOutput, error) { + return model.ConvAgenticCallbackOutput(item), nil + })) case components.ComponentOfTool: return c.toolHandler.OnEndWithStreamOutput(ctx, info, schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*tool.CallbackOutput, error) { @@ -285,6 +329,11 @@ func (c *handlerTemplate) OnEndWithStreamOutput(ctx context.Context, info *callb schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) ([]*schema.Message, error) { return convToolsNodeCallbackOutput(item), nil })) + case compose.ComponentOfAgenticToolsNode: + return c.agenticToolsNodeHandler.OnEndWithStreamOutput(ctx, info, + schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) ([]*schema.AgenticMessage, error) { + return convAgenticToolsNodeCallbackOutput(item), nil + })) case compose.ComponentOfGraph, compose.ComponentOfChain, compose.ComponentOfLambda: @@ 
-295,6 +344,8 @@ func (c *handlerTemplate) OnEndWithStreamOutput(ctx context.Context, info *callb } // Needed checks if the callback handler is needed for the given timing. +// +//nolint:cyclop func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { if info == nil { return false @@ -305,6 +356,10 @@ func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, t if c.chatModelHandler != nil && c.chatModelHandler.Needed(ctx, info, timing) { return true } + case components.ComponentOfAgenticModel: + if c.agenticModelHandler != nil && c.agenticModelHandler.Needed(ctx, info, timing) { + return true + } case components.ComponentOfEmbedding: if c.embeddingHandler != nil && c.embeddingHandler.Needed(ctx, info, timing) { return true @@ -321,6 +376,10 @@ func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, t if c.promptHandler != nil && c.promptHandler.Needed(ctx, info, timing) { return true } + case components.ComponentOfAgenticPrompt: + if c.agenticPromptHandler != nil && c.agenticPromptHandler.Needed(ctx, info, timing) { + return true + } case components.ComponentOfRetriever: if c.retrieverHandler != nil && c.retrieverHandler.Needed(ctx, info, timing) { return true @@ -337,6 +396,10 @@ func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, t if c.toolsNodeHandler != nil && c.toolsNodeHandler.Needed(ctx, info, timing) { return true } + case compose.ComponentOfAgenticToolsNode: + if c.agenticToolsNodeHandler != nil && c.agenticToolsNodeHandler.Needed(ctx, info, timing) { + return true + } case adk.ComponentOfAgent: if c.agentHandler != nil && c.agentHandler.Needed(ctx, info, timing) { return true @@ -596,3 +659,94 @@ func (ch *AgentCallbackHandler) Needed(ctx context.Context, info *callbacks.RunI return false } } + +// AgenticPromptCallbackHandler is the handler for the agentic prompt callback. 
+type AgenticPromptCallbackHandler struct { + // OnStart is the callback function for the start of the agentic prompt. + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context + // OnEnd is the callback function for the end of the agentic prompt. + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *prompt.CallbackOutput) context.Context + // OnError is the callback function for the error of the agentic prompt. + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. +func (ch *AgenticPromptCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false + } +} + +// AgenticModelCallbackHandler is the handler for the agentic chat model callback. +type AgenticModelCallbackHandler struct { + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.AgenticCallbackInput) context.Context + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.AgenticCallbackOutput) context.Context + OnEndWithStreamOutput func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.AgenticCallbackOutput]) context.Context + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. 
+func (ch *AgenticModelCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + case callbacks.TimingOnEndWithStreamOutput: + return ch.OnEndWithStreamOutput != nil + default: + return false + } +} + +// AgenticToolsNodeCallbackHandlers defines optional callbacks for the Agentic Tools node +// lifecycle events. +type AgenticToolsNodeCallbackHandlers struct { + OnStart func(ctx context.Context, info *callbacks.RunInfo, input *schema.AgenticMessage) context.Context + OnEnd func(ctx context.Context, info *callbacks.RunInfo, input []*schema.AgenticMessage) context.Context + OnEndWithStreamOutput func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[[]*schema.AgenticMessage]) context.Context + OnError func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context +} + +// Needed reports whether a handler is registered for the given timing. 
+func (ch *AgenticToolsNodeCallbackHandlers) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnEndWithStreamOutput: + return ch.OnEndWithStreamOutput != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false + } +} + +func convAgenticToolsNodeCallbackInput(src callbacks.CallbackInput) *schema.AgenticMessage { + switch t := src.(type) { + case *schema.AgenticMessage: + return t + default: + return nil + } +} + +func convAgenticToolsNodeCallbackOutput(src callbacks.CallbackInput) []*schema.AgenticMessage { + switch t := src.(type) { + case []*schema.AgenticMessage: + return t + default: + return nil + } +} diff --git a/utils/callbacks/template_test.go b/utils/callbacks/template_test.go index 84ed6dfc6..dcc0e5c7f 100644 --- a/utils/callbacks/template_test.go +++ b/utils/callbacks/template_test.go @@ -142,6 +142,58 @@ func TestNewComponentTemplate(t *testing.T) { cnt++ return ctx }).Build()). + AgenticModel(&AgenticModelCallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.AgenticCallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.AgenticCallbackOutput) context.Context { + cnt++ + return ctx + }, + OnEndWithStreamOutput: func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.AgenticCallbackOutput]) context.Context { + output.Close() + cnt++ + return ctx + }, + OnError: func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }). 
+ AgenticPrompt(&AgenticPromptCallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *prompt.CallbackOutput) context.Context { + cnt++ + return ctx + }, + OnError: func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }). + AgenticToolsNode(&AgenticToolsNodeCallbackHandlers{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *schema.AgenticMessage) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, input []*schema.AgenticMessage) context.Context { + cnt++ + return ctx + }, + OnEndWithStreamOutput: func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[[]*schema.AgenticMessage]) context.Context { + output.Close() + cnt++ + return ctx + }, + OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }). 
Handler() types := []components.Component{ @@ -151,6 +203,9 @@ func TestNewComponentTemplate(t *testing.T) { components.ComponentOfRetriever, components.ComponentOfTool, compose.ComponentOfLambda, + components.ComponentOfAgenticModel, + components.ComponentOfAgenticPrompt, + compose.ComponentOfAgenticToolsNode, } handler := tpl.Handler() @@ -169,28 +224,28 @@ func TestNewComponentTemplate(t *testing.T) { handler.OnEndWithStreamOutput(ctx, &callbacks.RunInfo{Component: typ}, sor) } - assert.Equal(t, 22, cnt) + assert.Equal(t, 33, cnt) ctx = context.Background() ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, handler) callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 22, cnt) + assert.Equal(t, 33, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfPrompt}) ctx = callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 23, cnt) + assert.Equal(t, 34, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer}) callbacks.OnEnd[any](ctx, nil) - assert.Equal(t, 23, cnt) + assert.Equal(t, 34, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfEmbedding}) callbacks.OnError(ctx, nil) - assert.Equal(t, 24, cnt) + assert.Equal(t, 35, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}) callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 24, cnt) + assert.Equal(t, 35, cnt) tpl.Transformer(&TransformerCallbackHandler{ OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *document.TransformerCallbackInput) context.Context { @@ -250,6 +305,37 @@ func TestNewComponentTemplate(t *testing.T) { } } }, + }).AgenticPrompt(&AgenticPromptCallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output 
*prompt.CallbackOutput) context.Context { + cnt++ + return ctx + }, + OnError: func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }).AgenticToolsNode(&AgenticToolsNodeCallbackHandlers{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *schema.AgenticMessage) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, input []*schema.AgenticMessage) context.Context { + cnt++ + return ctx + }, + OnEndWithStreamOutput: func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[[]*schema.AgenticMessage]) context.Context { + output.Close() + cnt++ + return ctx + }, + OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, }) handler = tpl.Handler() @@ -257,36 +343,222 @@ func TestNewComponentTemplate(t *testing.T) { ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, handler) ctx = callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 25, cnt) + assert.Equal(t, 36, cnt) callbacks.OnEnd[any](ctx, nil) - assert.Equal(t, 26, cnt) + assert.Equal(t, 37, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}) callbacks.OnEnd[any](ctx, nil) - assert.Equal(t, 27, cnt) + assert.Equal(t, 38, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: compose.ComponentOfToolsNode}) callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 28, cnt) + assert.Equal(t, 39, cnt) sr, sw := schema.Pipe[any](0) sw.Close() callbacks.OnEndWithStreamOutput[any](ctx, sr) - assert.Equal(t, 29, cnt) + assert.Equal(t, 40, cnt) sr1, sw1 := schema.Pipe[[]*schema.Message](1) sw1.Send([]*schema.Message{{}}, nil) sw1.Close() callbacks.OnEndWithStreamOutput[[]*schema.Message](ctx, sr1) - assert.Equal(t, 30, cnt) - - callbacks.OnError(ctx, nil) - assert.Equal(t, 30, cnt) + // Check AgenticModel stream + sir2, 
siw2 := schema.Pipe[callbacks.CallbackOutput](1) + siw2.Close() + handler.OnEndWithStreamOutput(ctx, &callbacks.RunInfo{Component: components.ComponentOfAgenticModel}, sir2) + assert.Equal(t, 42, cnt) + + // Check AgenticToolsNode stream + sir3, siw3 := schema.Pipe[callbacks.CallbackOutput](1) + siw3.Close() + handler.OnEndWithStreamOutput(ctx, &callbacks.RunInfo{Component: compose.ComponentOfAgenticToolsNode}, sir3) + assert.Equal(t, 43, cnt) ctx = callbacks.ReuseHandlers(ctx, nil) callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 30, cnt) + assert.Equal(t, 43, cnt) + }) + + t.Run("EdgeCases", func(t *testing.T) { + ctx := context.Background() + cnt := 0 + + // 1. Test Graph and Chain Setters and Execution + tpl := NewHandlerHelper(). + Graph(callbacks.NewHandlerBuilder(). + OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { + cnt++ + return ctx + }).Build()). + Chain(callbacks.NewHandlerBuilder(). + OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { + cnt++ + return ctx + }).Build()) + + h := tpl.Handler() + + // Trigger Graph OnStart + h.OnStart(ctx, &callbacks.RunInfo{Component: compose.ComponentOfGraph}, nil) + assert.Equal(t, 1, cnt) + + // Trigger Chain OnEnd + h.OnEnd(ctx, &callbacks.RunInfo{Component: compose.ComponentOfChain}, nil) + assert.Equal(t, 2, cnt) + + // 2. Test Needed logic for Graph/Chain when handler is present/absent + // Graph is present (OnStart) + needed := h.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: compose.ComponentOfGraph}, callbacks.TimingOnStart) + assert.True(t, needed) + + // Chain is present (OnEnd) - but we check OnStart which is not defined in the builder above? + // NewHandlerBuilder returns a handler that usually returns true for Needed if the specific func is not nil. + // Let's verify Chain OnStart is NOT needed because we only set OnEndFn. 
+ needed = h.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: compose.ComponentOfChain}, callbacks.TimingOnStart) + assert.False(t, needed) // Should be false because OnStartFn wasn't set for Chain + + // Lambda is NOT present + needed = h.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: compose.ComponentOfLambda}, callbacks.TimingOnStart) + assert.False(t, needed) + + // 3. Test Conversion Fallbacks (Default cases) + // We need a handler with ToolsNode and AgenticToolsNode to test their conversion fallbacks + tpl2 := NewHandlerHelper(). + ToolsNode(&ToolsNodeCallbackHandlers{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *schema.Message) context.Context { + if input == nil { + cnt++ + } + return ctx + }, + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, input []*schema.Message) context.Context { + if input == nil { + cnt++ + } + return ctx + }, + }). + AgenticToolsNode(&AgenticToolsNodeCallbackHandlers{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *schema.AgenticMessage) context.Context { + if input == nil { + cnt++ + } + return ctx + }, + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, input []*schema.AgenticMessage) context.Context { + if input == nil { + cnt++ + } + return ctx + }, + }) + + h2 := tpl2.Handler() + + // Pass wrong type (string) to trigger default case in convToolsNodeCallbackInput -> returns nil + h2.OnStart(ctx, &callbacks.RunInfo{Component: compose.ComponentOfToolsNode}, "wrong-input-type") + assert.Equal(t, 3, cnt) // +1 + + // Pass wrong type to trigger default case in convToolsNodeCallbackOutput -> returns nil + h2.OnEnd(ctx, &callbacks.RunInfo{Component: compose.ComponentOfToolsNode}, "wrong-output-type") + assert.Equal(t, 4, cnt) // +1 + + // Pass wrong type to trigger default case in convAgenticToolsNodeCallbackInput -> returns nil + h2.OnStart(ctx, &callbacks.RunInfo{Component: compose.ComponentOfAgenticToolsNode}, "wrong-input-type") + 
assert.Equal(t, 5, cnt) // +1 + + // Pass wrong type to trigger default case in convAgenticToolsNodeCallbackOutput -> returns nil + h2.OnEnd(ctx, &callbacks.RunInfo{Component: compose.ComponentOfAgenticToolsNode}, "wrong-output-type") + assert.Equal(t, 6, cnt) // +1 + + // 4. Test Needed for Agentic components when handlers are Set vs Unset + // tpl2 has AgenticToolsNode set + needed = h2.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: compose.ComponentOfAgenticToolsNode}, callbacks.TimingOnStart) + assert.True(t, needed) + + // tpl2 does NOT have AgenticModel set + needed = h2.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfAgenticModel}, callbacks.TimingOnStart) + assert.False(t, needed) + + // Set it now + tpl2.AgenticModel(&AgenticModelCallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.AgenticCallbackInput) context.Context { + return ctx + }, + }) + + needed = h2.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfAgenticModel}, callbacks.TimingOnStart) + assert.True(t, needed) + + // Check invalid component + needed = h2.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: "UnknownComponent"}, callbacks.TimingOnStart) + assert.False(t, needed) + + // Check RunInfo nil + needed = h2.(callbacks.TimingChecker).Needed(ctx, nil, callbacks.TimingOnStart) + assert.False(t, needed) + + // 5. Test Needed for Transformer, Loader, Indexer, etc to ensure switch coverage + tpl3 := NewHandlerHelper(). + Transformer(&TransformerCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *document.TransformerCallbackInput) context.Context { + return ctx + }}). + Loader(&LoaderCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *document.LoaderCallbackInput) context.Context { + return ctx + }}). 
+ Indexer(&IndexerCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *indexer.CallbackInput) context.Context { + return ctx + }}). + Retriever(&RetrieverCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *retriever.CallbackInput) context.Context { + return ctx + }}). + Embedding(&EmbeddingCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *embedding.CallbackInput) context.Context { + return ctx + }}). + Tool(&ToolCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *tool.CallbackInput) context.Context { + return ctx + }}) + + h3 := tpl3.Handler() + checker := h3.(callbacks.TimingChecker) + + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfRetriever}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfEmbedding}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfTool}, callbacks.TimingOnStart)) + + // Verify False paths (by using a helper without them) + emptyH := NewHandlerHelper().Handler().(callbacks.TimingChecker) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, callbacks.TimingOnStart)) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}, callbacks.TimingOnStart)) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer}, callbacks.TimingOnStart)) + assert.False(t, 
emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfRetriever}, callbacks.TimingOnStart)) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfEmbedding}, callbacks.TimingOnStart)) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfTool}, callbacks.TimingOnStart)) + + // 6. Test Needed for remaining components (ChatModel, Prompt, AgenticPrompt) + tpl4 := NewHandlerHelper(). + ChatModel(&ModelCallbackHandler{OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context { + return ctx + }}). + Prompt(&PromptCallbackHandler{OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context { + return ctx + }}). + AgenticPrompt(&AgenticPromptCallbackHandler{OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context { + return ctx + }}) + + h4 := tpl4.Handler() + checker4 := h4.(callbacks.TimingChecker) + + assert.True(t, checker4.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfChatModel}, callbacks.TimingOnStart)) + assert.True(t, checker4.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfPrompt}, callbacks.TimingOnStart)) + assert.True(t, checker4.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfAgenticPrompt}, callbacks.TimingOnStart)) }) }