Skip to content

Commit 756421c

Browse files
authored
fix(mcp-tool): using the async invocation for MCP tools (#840)
1 parent ee02b9f commit 756421c

2 files changed

Lines changed: 18 additions & 17 deletions

File tree

src/graph/nodes.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,10 +1125,11 @@ async def _execute_agent_step(
11251125
)
11261126

11271127
try:
1128-
# Use stream from the start to capture messages in real-time
1128+
# Use astream (async) from the start to capture messages in real-time
11291129
# This allows us to retrieve accumulated messages even if recursion limit is hit
1130+
# NOTE: astream is required for MCP tools which only support async invocation
11301131
accumulated_messages = []
1131-
for chunk in agent.stream(
1132+
async for chunk in agent.astream(
11321133
input=agent_input,
11331134
config={"recursion_limit": recursion_limit},
11341135
stream_mode="values",

tests/integration/test_nodes.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,12 +1107,12 @@ async def ainvoke(input, config):
11071107
# Simulate agent returning a message list
11081108
return {"messages": [MagicMock(content="result content")]}
11091109

1110-
def stream(input, config, stream_mode):
1111-
# Simulate agent.stream() yielding messages
1110+
async def astream(input, config, stream_mode):
1111+
# Simulate agent.astream() yielding messages (async generator)
11121112
yield {"messages": [MagicMock(content="result content")]}
11131113

11141114
agent.ainvoke = ainvoke
1115-
agent.stream = stream
1115+
agent.astream = astream
11161116
return agent
11171117

11181118

@@ -1177,12 +1177,12 @@ async def ainvoke(input, config):
11771177
assert any("DO NOT include inline citations" in m.content for m in messages)
11781178
return {"messages": [MagicMock(content="resource result")]}
11791179

1180-
def stream(input, config, stream_mode):
1181-
# Simulate agent.stream() yielding messages
1180+
async def astream(input, config, stream_mode):
1181+
# Simulate agent.astream() yielding messages (async generator)
11821182
yield {"messages": [MagicMock(content="resource result")]}
11831183

11841184
agent.ainvoke = ainvoke
1185-
agent.stream = stream
1185+
agent.astream = astream
11861186
with patch(
11871187
"src.graph.nodes.HumanMessage",
11881188
side_effect=lambda content, name=None: MagicMock(content=content, name=name),
@@ -2424,8 +2424,8 @@ async def mock_ainvoke(input, config):
24242424
]
24252425
return {"messages": messages}
24262426

2427-
def stream(input, config, stream_mode):
2428-
# Simulate agent.stream() yielding the final messages
2427+
async def astream(input, config, stream_mode):
2428+
# Simulate agent.astream() yielding the final messages (async generator)
24292429
messages = [
24302430
AIMessage(
24312431
content="I'll search for information about this topic.",
@@ -2460,7 +2460,7 @@ def stream(input, config, stream_mode):
24602460
yield {"messages": messages}
24612461

24622462
agent.ainvoke = mock_ainvoke
2463-
agent.stream = stream
2463+
agent.astream = astream
24642464

24652465
# Execute the agent step
24662466
with patch(
@@ -2556,8 +2556,8 @@ async def mock_ainvoke(input, config):
25562556
]
25572557
return {"messages": messages}
25582558

2559-
def stream(input, config, stream_mode):
2560-
# Simulate agent.stream() yielding the messages
2559+
async def astream(input, config, stream_mode):
2560+
# Simulate agent.astream() yielding the messages (async generator)
25612561
messages = [
25622562
AIMessage(
25632563
content="I'll search for information.",
@@ -2579,7 +2579,7 @@ def stream(input, config, stream_mode):
25792579
yield {"messages": messages}
25802580

25812581
agent.ainvoke = mock_ainvoke
2582-
agent.stream = stream
2582+
agent.astream = astream
25832583

25842584
with patch(
25852585
"src.graph.nodes.HumanMessage",
@@ -2639,8 +2639,8 @@ async def mock_ainvoke(input, config):
26392639
]
26402640
return {"messages": messages}
26412641

2642-
def stream(input, config, stream_mode):
2643-
# Simulate agent.stream() yielding messages without tool calls
2642+
async def astream(input, config, stream_mode):
2643+
# Simulate agent.astream() yielding messages without tool calls (async generator)
26442644
messages = [
26452645
AIMessage(
26462646
content="Based on my knowledge, here is the answer without needing to search."
@@ -2649,7 +2649,7 @@ def stream(input, config, stream_mode):
26492649
yield {"messages": messages}
26502650

26512651
agent.ainvoke = mock_ainvoke
2652-
agent.stream = stream
2652+
agent.astream = astream
26532653

26542654
with patch(
26552655
"src.graph.nodes.HumanMessage",

0 commit comments

Comments
 (0)