Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 134 additions & 1 deletion tests/realtime/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
from pydantic import BaseModel, ConfigDict

from agents.exceptions import UserError
from agents.exceptions import ToolTimeoutError, UserError
from agents.guardrail import GuardrailFunctionOutput, OutputGuardrail
from agents.handoffs import Handoff
from agents.realtime.agent import RealtimeAgent
Expand Down Expand Up @@ -60,6 +60,7 @@
RealtimeModelSendUserInput,
)
from agents.realtime.session import REJECTION_MESSAGE, RealtimeSession, _serialize_tool_output
from agents.run_context import RunContextWrapper
from agents.tool import FunctionTool
from agents.tool_context import ToolContext

Expand Down Expand Up @@ -1058,6 +1059,138 @@ async def invoke_slow_tool(_ctx: ToolContext[Any], _arguments: str) -> str:
assert start_response is True
assert "timed out" in sent_output.lower()

@pytest.mark.asyncio
async def test_function_tool_timeout_raise_exception_propagates(self, mock_model, mock_agent):
    """Check that a ``raise_exception`` timeout escapes ``_handle_tool_call``.

    The tool sleeps longer than its configured timeout, so the direct call is
    expected to raise ``ToolTimeoutError``. No tool output should be sent to
    the model, and the only queued event should be the tool-start event.
    """

    async def invoke_slow_tool(_ctx: ToolContext[Any], _arguments: str) -> str:
        # Sleeps well past the 0.01s timeout configured on the tool below.
        await asyncio.sleep(0.2)
        return "done"

    slow_tool = FunctionTool(
        on_invoke_tool=invoke_slow_tool,
        name="slow_tool",
        description="slow",
        params_json_schema={"type": "object", "properties": {}},
        timeout_behavior="raise_exception",
        timeout_seconds=0.01,
    )
    mock_agent.get_all_tools.return_value = [slow_tool]

    call_event = RealtimeModelToolCallEvent(
        call_id="call_timeout_raise",
        name="slow_tool",
        arguments="{}",
    )
    realtime_session = RealtimeSession(mock_model, mock_agent, None)

    with pytest.raises(ToolTimeoutError, match="timed out"):
        await realtime_session._handle_tool_call(call_event)

    # Nothing was reported back to the model; only the start event was queued.
    pending_events = realtime_session._event_queue
    assert len(mock_model.sent_tool_outputs) == 0
    assert pending_events.qsize() == 1

    started = await pending_events.get()
    assert isinstance(started, RealtimeToolStart)
    assert started.tool == slow_tool
    assert started.arguments == "{}"

@pytest.mark.asyncio
async def test_function_tool_timeout_uses_async_error_function_result(
    self, mock_model, mock_agent
):
    """Check that an async ``timeout_error_function`` supplies the tool output.

    The tool sleeps longer than its timeout. The string returned by the async
    formatter is expected both as the output sent to the model and as the
    output carried by the queued tool-end event.
    """
    expected_output = "async-timeout:slow_tool:0.01"

    async def invoke_slow_tool(_ctx: ToolContext[Any], _arguments: str) -> str:
        # Sleeps well past the 0.01s timeout configured on the tool below.
        await asyncio.sleep(0.2)
        return "done"

    async def format_timeout_error(ctx: RunContextWrapper[Any], error: Exception) -> str:
        # The formatter must receive the timeout error plus the tool context
        # of the call that timed out.
        assert isinstance(error, ToolTimeoutError)
        assert isinstance(ctx, ToolContext)
        assert ctx.tool_name == "slow_tool"
        assert ctx.tool_call_id == "call_timeout_custom"
        return f"async-timeout:{error.tool_name}:{error.timeout_seconds:g}"

    slow_tool = FunctionTool(
        on_invoke_tool=invoke_slow_tool,
        name="slow_tool",
        description="slow",
        params_json_schema={"type": "object", "properties": {}},
        timeout_error_function=format_timeout_error,
        timeout_seconds=0.01,
    )
    mock_agent.get_all_tools.return_value = [slow_tool]

    call_event = RealtimeModelToolCallEvent(
        call_id="call_timeout_custom",
        name="slow_tool",
        arguments="{}",
    )
    realtime_session = RealtimeSession(mock_model, mock_agent, None)

    await realtime_session._handle_tool_call(call_event)

    # Exactly one output was sent to the model, carrying the formatted text.
    assert len(mock_model.sent_tool_outputs) == 1
    recorded = mock_model.sent_tool_outputs[0]
    sent_call, sent_output, start_response = recorded
    assert sent_call == call_event
    assert sent_output == expected_output
    assert start_response is True

    # Two events are queued; the first is discarded and the second checked.
    pending_events = realtime_session._event_queue
    assert pending_events.qsize() == 2
    _ = await pending_events.get()
    finished = await pending_events.get()
    assert isinstance(finished, RealtimeToolEnd)
    assert finished.output == expected_output

@pytest.mark.asyncio
async def test_function_call_event_timeout_raise_exception_enqueues_error(
    self, mock_model, mock_agent
):
    """Check the ``on_event`` path for a ``raise_exception`` tool timeout.

    After the tool-call task finishes, the session is expected to hold the
    ``ToolTimeoutError`` in ``_stored_exception``, send no tool output, and
    queue a ``RealtimeError`` event whose message mentions the failure.
    """

    async def invoke_slow_tool(_ctx: ToolContext[Any], _arguments: str) -> str:
        # Sleeps well past the 0.01s timeout configured on the tool below.
        await asyncio.sleep(0.2)
        return "done"

    slow_tool = FunctionTool(
        on_invoke_tool=invoke_slow_tool,
        name="slow_tool",
        description="slow",
        params_json_schema={"type": "object", "properties": {}},
        timeout_behavior="raise_exception",
        timeout_seconds=0.01,
    )
    mock_agent.get_all_tools.return_value = [slow_tool]

    call_event = RealtimeModelToolCallEvent(
        call_id="call_timeout_async",
        name="slow_tool",
        arguments="{}",
    )
    realtime_session = RealtimeSession(mock_model, mock_agent, None)

    await realtime_session.on_event(call_event)

    # Snapshot the task set before awaiting it. Exceptions are collected
    # rather than raised so the assertions below can inspect the outcome.
    running_tasks = list(realtime_session._tool_call_tasks)
    assert len(running_tasks) == 1
    await asyncio.gather(*running_tasks, return_exceptions=True)

    stored = realtime_session._stored_exception
    assert isinstance(stored, ToolTimeoutError)
    assert realtime_session._stored_exception.tool_name == "slow_tool"
    assert len(mock_model.sent_tool_outputs) == 0

    # Drain the queue up to and including the first RealtimeError. Each read
    # is bounded by a timeout so a missing error event fails the test.
    drained = []
    while not (drained and isinstance(drained[-1], RealtimeError)):
        drained.append(
            await asyncio.wait_for(realtime_session._event_queue.get(), timeout=1)
        )

    saw_raw_call = any(
        isinstance(item, RealtimeRawModelEvent) and item.data == call_event
        for item in drained
    )
    assert saw_raw_call
    assert any(isinstance(item, RealtimeToolStart) for item in drained)

    failure = next(item for item in drained if isinstance(item, RealtimeError))
    failure_message = failure.error["message"]
    assert "Tool call task failed" in failure_message
    assert "timed out" in failure_message

@pytest.mark.asyncio
async def test_function_tool_with_multiple_tools_available(self, mock_model, mock_agent):
"""Test function tool execution when multiple tools are available"""
Expand Down