Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 160 additions & 65 deletions src/agents/run_internal/session_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,58 +424,24 @@ async def rewind_session_items(
logger.debug("Rewind target %d (first 300 chars): %s", i, target[:300])

snapshot_serializations = target_serializations.copy()
rewound = await _rewind_session_tail_suffix(
session=session,
pop_item=pop_item,
expected_serializations=target_serializations,
ignore_ids_for_matching=ignore_ids_for_matching,
mismatch_warning=(
"Skipping session rewind because the current tail does not match the retry-owned suffix"
),
pop_failure_warning="Failed to rewind session item: %s",
)
if not rewound:
return

remaining = target_serializations.copy()

while remaining:
try:
result = pop_item()
if inspect.isawaitable(result):
result = await result
except Exception as exc:
logger.warning("Failed to rewind session item: %s", exc)
break
else:
if result is None:
break

popped_serialized = fingerprint_input_item(
result, ignore_ids_for_matching=ignore_ids_for_matching
)

logger.debug("Popped item type during rewind: %s", type(result).__name__)
if popped_serialized:
logger.debug("Popped serialized (first 300 chars): %s", popped_serialized[:300])
else:
logger.debug("Popped serialized: None")

logger.debug("Number of remaining targets: %d", len(remaining))
if remaining and popped_serialized:
logger.debug("First target (first 300 chars): %s", remaining[0][:300])
logger.debug("Match found: %s", popped_serialized in remaining)
if len(remaining) > 0:
first_target = remaining[0]
if abs(len(first_target) - len(popped_serialized)) < 50:
logger.debug(
"Length comparison - popped: %d, target: %d",
len(popped_serialized),
len(first_target),
)

if popped_serialized and popped_serialized in remaining:
remaining.remove(popped_serialized)

if remaining:
logger.warning(
"Unable to fully rewind session; %d items still unmatched after retry",
len(remaining),
)
else:
await wait_for_session_cleanup(
session,
snapshot_serializations,
ignore_ids_for_matching=ignore_ids_for_matching,
)
await wait_for_session_cleanup(
session,
snapshot_serializations,
ignore_ids_for_matching=ignore_ids_for_matching,
)

if session is None or server_tracker is None:
return
Expand All @@ -493,22 +459,36 @@ async def rewind_session_items(
if isinstance(latest_id, str) and latest_id in server_tracker.server_item_ids:
return

logger.debug("Stripping stray conversation items until we reach a known server item")
while True:
try:
result = pop_item()
if inspect.isawaitable(result):
result = await result
except Exception as exc:
logger.warning("Failed to strip stray session item: %s", exc)
break
try:
session_items = await session.get_items()
except Exception as exc:
logger.debug("Failed to inspect session tail while stripping stray items: %s", exc)
return

if result is None:
break
stray_serializations = _collect_retry_owned_tail_serializations(
session_items,
server_tracker=server_tracker,
ignore_ids_for_matching=ignore_ids_for_matching,
)
if not stray_serializations:
return

stripped_id = result.get("id") if isinstance(result, dict) else getattr(result, "id", None)
if isinstance(stripped_id, str) and stripped_id in server_tracker.server_item_ids:
break
logger.debug(
"Stripping %d retry-owned conversation items until the session tail reaches "
"a known server item",
len(stray_serializations),
)
await _rewind_session_tail_suffix(
session=session,
pop_item=pop_item,
expected_serializations=stray_serializations,
ignore_ids_for_matching=ignore_ids_for_matching,
mismatch_warning=(
"Skipping stray session cleanup because the current tail no longer matches "
"retry-owned conversation items"
),
pop_failure_warning="Failed to strip stray session item: %s",
)


async def wait_for_session_cleanup(
Expand Down Expand Up @@ -582,6 +562,121 @@ def _fingerprint_or_repr(item: TResponseInputItem, *, ignore_ids_for_matching: b
)


async def _rewind_session_tail_suffix(
*,
session: Session,
pop_item: Any,
expected_serializations: Sequence[str],
ignore_ids_for_matching: bool,
mismatch_warning: str,
pop_failure_warning: str,
) -> bool:
"""Remove an exact serialized suffix from the session tail, aborting when the tail diverges."""
if not expected_serializations:
return True

try:
tail_items = await session.get_items(limit=len(expected_serializations))
except Exception as exc:
logger.warning(pop_failure_warning, exc)
return False

if len(tail_items) != len(expected_serializations):
logger.warning(mismatch_warning)
return False

tail_serializations: list[str] = []
for item in tail_items:
serialized = fingerprint_input_item(item, ignore_ids_for_matching=ignore_ids_for_matching)
if not serialized:
logger.warning(mismatch_warning)
return False
tail_serializations.append(serialized)

if tail_serializations != list(expected_serializations):
logger.warning(mismatch_warning)
return False

popped_items: list[TResponseInputItem] = []
for expected in reversed(expected_serializations):
try:
result = pop_item()
if inspect.isawaitable(result):
result = await result
except Exception as exc:
await _restore_popped_session_items(session, popped_items)
logger.warning(pop_failure_warning, exc)
return False

if result is None:
await _restore_popped_session_items(session, popped_items)
logger.warning(mismatch_warning)
return False

popped_items.append(result)
popped_serialized = fingerprint_input_item(
result, ignore_ids_for_matching=ignore_ids_for_matching
)
if popped_serialized != expected:
await _restore_popped_session_items(session, popped_items)
logger.warning(mismatch_warning)
return False

return True


async def _restore_popped_session_items(
session: Session, popped_items: Sequence[TResponseInputItem]
) -> None:
"""Best-effort restoration for items popped during a failed rewind attempt."""
if not popped_items:
return

add_items = getattr(session, "add_items", None)
if not callable(add_items):
return

try:
result = add_items(list(reversed(popped_items)))
if inspect.isawaitable(result):
await result
except Exception as exc:
logger.warning("Failed to restore session items after a rewind mismatch: %s", exc)


def _collect_retry_owned_tail_serializations(
session_items: Sequence[TResponseInputItem],
*,
server_tracker: OpenAIServerConversationTracker,
ignore_ids_for_matching: bool,
) -> list[str]:
"""Return the contiguous retry-owned tail suffix that can be safely stripped."""
stray_tail: list[str] = []

for item in reversed(session_items):
item_id = item.get("id") if isinstance(item, dict) else getattr(item, "id", None)
if isinstance(item_id, str) and item_id in server_tracker.server_item_ids:
return list(reversed(stray_tail))

serialized = fingerprint_input_item(item, ignore_ids_for_matching=ignore_ids_for_matching)
if serialized and serialized in server_tracker.sent_item_fingerprints:
stray_tail.append(serialized)
continue

logger.warning(
"Skipping stray session cleanup because the current tail contains items unrelated "
"to this retry"
)
return []

if stray_tail:
logger.warning(
"Skipping stray session cleanup because no known server item was found before the "
"session boundary"
)
return []


def _session_item_key(item: Any) -> str:
"""Return a stable representation of a session item for comparison."""
try:
Expand Down
74 changes: 74 additions & 0 deletions tests/test_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
from agents.run_internal.run_loop import get_new_response
from agents.run_internal.run_steps import NextStepFinalOutput, SingleStepResult
from agents.run_internal.session_persistence import (
_collect_retry_owned_tail_serializations,
persist_session_items_for_guardrail_trip,
prepare_input_with_session,
rewind_session_items,
Expand Down Expand Up @@ -2364,6 +2365,79 @@ async def test_rewind_handles_id_stripped_sessions() -> None:
assert session.saved_items == []


@pytest.mark.asyncio
async def test_rewind_skips_mismatched_tail_suffix() -> None:
target = cast(TResponseInputItem, {"type": "message", "role": "user", "content": "target"})
unrelated = cast(
TResponseInputItem,
{"type": "message", "role": "user", "content": "unrelated tail item"},
)
session = CountingSession(history=[target, unrelated])

await rewind_session_items(session, [target])

assert session.pop_calls == 0
assert session.saved_items == [target, unrelated]


@pytest.mark.asyncio
async def test_rewind_preserves_unrelated_tail_items_when_server_tracker_cleanup_runs() -> None:
known_server_item = cast(
TResponseInputItem,
{"id": "msg_server_1", "type": "message", "role": "assistant", "content": "server item"},
)
unrelated = cast(
TResponseInputItem,
{"type": "message", "role": "user", "content": "unrelated tail item"},
)
target = cast(TResponseInputItem, {"type": "message", "role": "user", "content": "target"})
session = CountingSession(history=[known_server_item, unrelated, target])
tracker = OpenAIServerConversationTracker()
tracker.server_item_ids.add("msg_server_1")

await rewind_session_items(session, [target], tracker)

assert session.pop_calls == 1
assert session.saved_items == [known_server_item, unrelated]


@pytest.mark.asyncio
async def test_rewind_strips_only_retry_owned_tail_items_before_known_server_item() -> None:
known_server_item = cast(
TResponseInputItem,
{"id": "msg_server_1", "type": "message", "role": "assistant", "content": "server item"},
)
retry_owned_tail = cast(
TResponseInputItem,
{"type": "message", "role": "user", "content": "retry-owned local item"},
)
target = cast(TResponseInputItem, {"type": "message", "role": "user", "content": "target"})
session = CountingSession(history=[known_server_item, retry_owned_tail, target])
tracker = OpenAIServerConversationTracker()
tracker.server_item_ids.add("msg_server_1")
retry_owned_fingerprint = fingerprint_input_item(retry_owned_tail)
assert retry_owned_fingerprint is not None
tracker.sent_item_fingerprints.add(retry_owned_fingerprint)

await rewind_session_items(session, [target], tracker)

assert session.pop_calls == 2
assert session.saved_items == [known_server_item]


def test_collect_retry_owned_tail_serializations_returns_empty_for_empty_session() -> None:
tracker = OpenAIServerConversationTracker()

assert (
_collect_retry_owned_tail_serializations(
[],
server_tracker=tracker,
ignore_ids_for_matching=False,
)
== []
)


@pytest.mark.asyncio
async def test_save_result_to_session_does_not_increment_counter_when_nothing_saved() -> None:
session = SimpleListSession()
Expand Down