|
7 | 7 | from functools import partial |
8 | 8 | from typing import Annotated, Any, Literal |
9 | 9 |
|
10 | | -from langchain_core.messages import AIMessage, HumanMessage, ToolMessage |
| 10 | +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage |
11 | 11 | from langchain_core.runnables import RunnableConfig |
12 | 12 | from langchain_core.tools import tool |
13 | 13 | from langchain_mcp_adapters.client import MultiServerMCPClient |
| 14 | +from langgraph.errors import GraphRecursionError |
14 | 15 | from langgraph.types import Command, interrupt |
15 | 16 |
|
16 | 17 | from src.agents import create_agent |
|
19 | 20 | from src.config.configuration import Configuration |
20 | 21 | from src.llms.llm import get_llm_by_type, get_llm_token_limit_by_type |
21 | 22 | from src.prompts.planner_model import Plan |
22 | | -from src.prompts.template import apply_prompt_template |
| 23 | +from src.prompts.template import apply_prompt_template, get_system_prompt_template |
23 | 24 | from src.tools import ( |
24 | 25 | crawl_tool, |
25 | 26 | get_retriever_tool, |
@@ -929,6 +930,79 @@ def validate_web_search_usage(messages: list, agent_name: str = "agent") -> bool |
929 | 930 | return web_search_used |
930 | 931 |
|
931 | 932 |
|
| 933 | +async def _handle_recursion_limit_fallback( |
| 934 | + messages: list, |
| 935 | + agent_name: str, |
| 936 | + current_step, |
| 937 | + state: State, |
| 938 | +) -> list: |
| 939 | + """Handle GraphRecursionError with graceful fallback using LLM summary. |
| 940 | +
|
| 941 | + When the agent hits the recursion limit, this function generates a final output |
| 942 | + using only the observations already gathered, without calling any tools. |
| 943 | +
|
| 944 | + Args: |
| 945 | + messages: Messages accumulated during agent execution before hitting limit |
| 946 | + agent_name: Name of the agent that hit the limit |
| 947 | + current_step: The current step being executed |
| 948 | + state: Current workflow state |
| 949 | +
|
| 950 | + Returns: |
| 951 | + list: Messages including the accumulated messages plus the fallback summary |
| 952 | +
|
| 953 | + Raises: |
| 954 | + Exception: If the fallback LLM call fails |
| 955 | + """ |
| 956 | + logger.warning( |
| 957 | + f"Recursion limit reached for {agent_name} agent. " |
| 958 | + f"Attempting graceful fallback with {len(messages)} accumulated messages." |
| 959 | + ) |
| 960 | + |
| 961 | + if len(messages) == 0: |
| 962 | + return messages |
| 963 | + |
| 964 | + cleared_messages = messages.copy() |
| 965 | + while len(cleared_messages) > 0 and cleared_messages[-1].type == "system": |
| 966 | + cleared_messages = cleared_messages[:-1] |
| 967 | + |
| 968 | + # Prepare state for prompt template |
| 969 | + fallback_state = { |
| 970 | + "locale": state.get("locale", "en-US"), |
| 971 | + } |
| 972 | + |
| 973 | + # Apply the recursion_fallback prompt template |
| 974 | + system_prompt = get_system_prompt_template(agent_name, fallback_state, None, fallback_state.get("locale", "en-US")) |
| 975 | + limit_prompt = get_system_prompt_template("recursion_fallback", fallback_state, None, fallback_state.get("locale", "en-US")) |
| 976 | + fallback_messages = cleared_messages + [ |
| 977 | + SystemMessage(content=system_prompt), |
| 978 | + SystemMessage(content=limit_prompt) |
| 979 | + ] |
| 980 | + |
| 981 | + # Get the LLM without tools (strip all tools from binding) |
| 982 | + fallback_llm = get_llm_by_type(AGENT_LLM_MAP[agent_name]) |
| 983 | + |
| 984 | + # Call the LLM with the updated messages |
| 985 | + fallback_response = fallback_llm.invoke(fallback_messages) |
| 986 | + fallback_content = fallback_response.content |
| 987 | + |
| 988 | + logger.info( |
| 989 | + f"Graceful fallback succeeded for {agent_name} agent. " |
| 990 | + f"Generated summary of {len(fallback_content)} characters." |
| 991 | + ) |
| 992 | + |
| 993 | + # Sanitize response |
| 994 | + fallback_content = sanitize_tool_response(str(fallback_content)) |
| 995 | + |
| 996 | + # Update the step with the fallback result |
| 997 | + current_step.execution_res = fallback_content |
| 998 | + |
| 999 | + # Return the accumulated messages plus the fallback response |
| 1000 | + result_messages = list(cleared_messages) |
| 1001 | + result_messages.append(AIMessage(content=fallback_content, name=agent_name)) |
| 1002 | + |
| 1003 | + return result_messages |
| 1004 | + |
| 1005 | + |
932 | 1006 | async def _execute_agent_step( |
933 | 1007 | state: State, agent, agent_name: str, config: RunnableConfig = None |
934 | 1008 | ) -> Command[Literal["research_team"]]: |
@@ -1049,11 +1123,51 @@ async def _execute_agent_step( |
1049 | 1123 | f"Context compression for {agent_name}: {len(compressed_state.get('messages', []))} messages, " |
1050 | 1124 | f"estimated tokens before: ~{token_count_before}, after: ~{token_count_after}" |
1051 | 1125 | ) |
1052 | | - |
| 1126 | + |
1053 | 1127 | try: |
1054 | | - result = await agent.ainvoke( |
1055 | | - input=agent_input, config={"recursion_limit": recursion_limit} |
1056 | | - ) |
| 1128 | + # Use stream from the start to capture messages in real-time |
| 1129 | + # This allows us to retrieve accumulated messages even if recursion limit is hit |
| 1130 | + accumulated_messages = [] |
| 1131 | + for chunk in agent.stream( |
| 1132 | + input=agent_input, |
| 1133 | + config={"recursion_limit": recursion_limit}, |
| 1134 | + stream_mode="values", |
| 1135 | + ): |
| 1136 | + if isinstance(chunk, dict) and "messages" in chunk: |
| 1137 | + accumulated_messages = chunk["messages"] |
| 1138 | + |
| 1139 | + # If we get here, execution completed successfully |
| 1140 | + result = {"messages": accumulated_messages} |
| 1141 | + except GraphRecursionError: |
| 1142 | + # Check if recursion fallback is enabled |
| 1143 | + configurable = Configuration.from_runnable_config(config) if config else Configuration() |
| 1144 | + |
| 1145 | + if configurable.enable_recursion_fallback: |
| 1146 | + try: |
| 1147 | + # Call fallback with accumulated messages (function returns list of messages) |
| 1148 | + response_messages = await _handle_recursion_limit_fallback( |
| 1149 | + messages=accumulated_messages, |
| 1150 | + agent_name=agent_name, |
| 1151 | + current_step=current_step, |
| 1152 | + state=state, |
| 1153 | + ) |
| 1154 | + |
| 1155 | + # Create result dict so the code can continue normally from line 1178 |
| 1156 | + result = {"messages": response_messages} |
| 1157 | + except Exception as fallback_error: |
| 1158 | + # If fallback fails, log and fall through to standard error handling |
| 1159 | + logger.error( |
| 1160 | + f"Recursion fallback failed for {agent_name} agent: {fallback_error}. " |
| 1161 | + "Falling back to standard error handling." |
| 1162 | + ) |
| 1163 | + raise |
| 1164 | + else: |
| 1165 | + # Fallback disabled, let error propagate to standard handler |
| 1166 | + logger.info( |
| 1167 | + f"Recursion limit reached but graceful fallback is disabled. " |
| 1168 | + "Using standard error handling." |
| 1169 | + ) |
| 1170 | + raise |
1057 | 1171 | except Exception as e: |
1058 | 1172 | import traceback |
1059 | 1173 |
|
@@ -1088,8 +1202,10 @@ async def _execute_agent_step( |
1088 | 1202 | goto="research_team", |
1089 | 1203 | ) |
1090 | 1204 |
|
| 1205 | + response_messages = result["messages"] |
| 1206 | + |
1091 | 1207 | # Process the result |
1092 | | - response_content = result["messages"][-1].content |
| 1208 | + response_content = response_messages[-1].content |
1093 | 1209 |
|
1094 | 1210 | # Sanitize response to remove extra tokens and truncate if needed |
1095 | 1211 | response_content = sanitize_tool_response(str(response_content)) |
|
0 commit comments