-
Notifications
You must be signed in to change notification settings - Fork 16
Use in-memory event history for condenser replay #252
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
45281ae
1c1b7d1
1b3d43f
b483744
ea2cbfb
473a1cb
df710f9
2ec045a
d0f45a4
e1e1ebf
2cbd497
825e914
a3f9251
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,11 +1,12 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import argparse | ||
| import asyncio | ||
| import json | ||
| import os | ||
| import sys | ||
| import tempfile | ||
| from collections.abc import Sequence | ||
| from collections.abc import Iterator, Sequence | ||
| from typing import Any | ||
|
|
||
| os.environ.setdefault("OPENHANDS_SUPPRESS_BANNER", "1") | ||
|
|
@@ -15,6 +16,8 @@ | |
| from openhands.sdk.context.condenser import LLMSummarizingCondenser | ||
| from openhands.sdk.context.condenser.utils import get_total_token_count | ||
| from openhands.sdk.context.view import View | ||
| from openhands.sdk.event import LLMConvertibleEvent as SDKEvent | ||
| from openhands.sdk.event import MessageEvent, SystemPromptEvent | ||
| from openhands.sdk.event.condenser import Condensation | ||
| from openhands.sdk.llm.llm_response import LLMResponse | ||
| from openhands.sdk.tool import ToolDefinition | ||
|
|
@@ -98,10 +101,31 @@ def format_messages(llm: LLM, messages: list[Message]) -> list[dict[str, Any]]: | |
| return normalize_message_content(llm.format_messages_for_llm(messages)) | ||
|
|
||
|
|
||
| class TrackingSDKEventBuilder(SDKEventBuilder): | ||
| def __init__( | ||
| self, | ||
| conversation: Conversation, | ||
| metadata: Any, | ||
| event_history: list[SDKEvent], | ||
| ) -> None: | ||
| super().__init__(conversation, metadata) | ||
| self.event_history = event_history | ||
|
|
||
| def append(self, event: SDKEvent) -> None: | ||
| self.event_history.append(event) | ||
| super().append(event) | ||
|
|
||
|
|
||
| def token_count(view: View, llm: LLM) -> int: | ||
| return get_total_token_count(view.events, llm) | ||
|
|
||
|
|
||
| def formatted_token_count(events: Sequence[SDKEvent], llm: LLM) -> int: | ||
| view = View.from_events(events) | ||
| messages = LLMConvertibleEvent.events_to_messages(view.events) | ||
| return llm.get_token_count(messages) | ||
|
|
||
|
|
||
| def make_condensation_prompt_record( | ||
| *, | ||
| trajectory_id: str, | ||
|
|
@@ -168,7 +192,7 @@ def make_trajectory_record_from_conversation( | |
|
|
||
| def condensation_prompt_record_if_needed( | ||
| *, | ||
| conversation: Conversation, | ||
| events: list[SDKEvent], | ||
| condenser: LLMSummarizingCondenser, | ||
| agent_llm: LLM, | ||
| condenser_llm: PromptCapturingLLM, | ||
|
|
@@ -177,7 +201,7 @@ def condensation_prompt_record_if_needed( | |
| max_tokens: int, | ||
| condensation_index: int, | ||
| ) -> tuple[Condensation, dict[str, Any]] | None: | ||
| view = View.from_events(conversation.state.events) | ||
| view = View.from_events(events) | ||
| prompt_token_count = token_count(view, condenser.llm) | ||
| before_prompt_count = len(condenser_llm.captured_messages) | ||
| condensation_result = condenser.condense(view, agent_llm=agent_llm) | ||
|
|
@@ -213,7 +237,28 @@ def append_standardized_events_with_condensation( | |
| include_trajectories: bool, | ||
| ) -> list[dict[str, Any]]: | ||
| metadata = load_dataset_metadata(dataset_name, required=True) | ||
| builder = SDKEventBuilder(conversation, metadata) | ||
| event_history: list[SDKEvent] = [ | ||
| SystemPromptEvent( | ||
| system_prompt=TextContent(text=conversation.agent.static_system_message), | ||
| tools=list(conversation.agent.tools_map.values()), | ||
| ) | ||
| ] | ||
| builder = TrackingSDKEventBuilder(conversation, metadata, event_history) | ||
| first_event = trajectory.content[0] | ||
| if not isinstance(first_event, TextObservation) or first_event.source != "user": | ||
| raise ValueError( | ||
| "OpenHands SDK condensation conversion expects the first event to be a " | ||
| "user TextObservation" | ||
| ) | ||
| builder.append( | ||
| MessageEvent( | ||
| source="user", | ||
| llm_message=Message( | ||
| role="user", | ||
| content=[TextContent(text=first_event.content)], | ||
| ), | ||
| ) | ||
| ) | ||
| condenser_llm = PromptCapturingLLM( | ||
| usage_id="openhands-sdk-condensation-sft-condenser", | ||
| model=model, | ||
|
|
@@ -231,18 +276,17 @@ def append_standardized_events_with_condensation( | |
| condensation_index = 1 | ||
| index = start_index | ||
| batch_number = 0 | ||
| last_safe_events = list(conversation.state.events) | ||
| last_safe_events = list(event_history) | ||
|
|
||
| def update_last_safe_events() -> None: | ||
| nonlocal last_safe_events | ||
| view = View.from_events(conversation.state.events) | ||
| if token_count(view, conversation.agent.llm) <= max_tokens: | ||
| last_safe_events = list(conversation.state.events) | ||
| if formatted_token_count(event_history, conversation.agent.llm) <= max_tokens: | ||
| last_safe_events = list(event_history) | ||
|
|
||
| def emit_condensation_boundary_if_needed() -> None: | ||
| nonlocal segment_index, condensation_index, last_safe_events | ||
| result = condensation_prompt_record_if_needed( | ||
| conversation=conversation, | ||
| events=event_history, | ||
| condenser=condenser, | ||
| agent_llm=conversation.agent.llm, | ||
| condenser_llm=condenser_llm, | ||
|
|
@@ -266,8 +310,9 @@ def emit_condensation_boundary_if_needed() -> None: | |
| ) | ||
| segment_index += 1 | ||
| records.append(prompt_record) | ||
| event_history.append(condensation) | ||
| conversation.state.events.append(condensation) | ||
| last_safe_events = list(conversation.state.events) | ||
| last_safe_events = list(event_history) | ||
| condensation_index += 1 | ||
|
|
||
| while index < len(trajectory.content): | ||
|
|
@@ -342,7 +387,7 @@ def process_row( | |
| with tempfile.TemporaryDirectory(prefix="openhands-sdk-condensation-sft-") as tmpdir: | ||
| conversation = Conversation(agent=agent, workspace=tmpdir, visualizer=None) | ||
| try: | ||
| conversation.send_message(first_event.content) | ||
| conversation._ensure_agent_ready() | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟠 Important: |
||
| return append_standardized_events_with_condensation( | ||
| conversation=conversation, | ||
| trajectory=trajectory, | ||
|
|
@@ -358,6 +403,85 @@ def process_row( | |
| conversation.close() | ||
|
|
||
|
|
||
| def iter_input_chunks(chunk_size: int) -> Iterator[list[str]]: | ||
| chunk: list[str] = [] | ||
| for line in sys.stdin: | ||
| line = line.strip() | ||
| if not line: | ||
| continue | ||
| chunk.append(line) | ||
| if len(chunk) >= chunk_size: | ||
| yield chunk | ||
| chunk = [] | ||
| if chunk: | ||
| yield chunk | ||
|
|
||
|
|
||
| async def process_line( | ||
| line: str, | ||
| *, | ||
| args: argparse.Namespace, | ||
| semaphore: asyncio.Semaphore, | ||
| ) -> list[dict[str, Any]]: | ||
| try: | ||
| async with semaphore: | ||
| return await asyncio.to_thread( | ||
| process_row, | ||
| line, | ||
| max_tokens=args.max_tokens, | ||
| model=args.model, | ||
| include_trajectories=args.include_trajectories == "yes", | ||
| max_size=args.max_size, | ||
| keep_first=args.keep_first, | ||
| ) | ||
| except Exception as exc: | ||
| if not args.continue_on_error: | ||
| raise | ||
| row_id = None | ||
| try: | ||
| row_id = json.loads(line).get("id") | ||
| except Exception: | ||
| pass | ||
| print( | ||
| json.dumps( | ||
| { | ||
| "id": row_id, | ||
| "error_type": type(exc).__name__, | ||
| "error": str(exc), | ||
| }, | ||
| ensure_ascii=False, | ||
| ), | ||
| file=sys.stderr, | ||
| flush=True, | ||
| ) | ||
| return [] | ||
|
|
||
|
|
||
| async def process_stream(args: argparse.Namespace) -> None: | ||
| from tqdm import tqdm | ||
|
|
||
| semaphore = asyncio.Semaphore(args.concurrency) | ||
| progress = tqdm( | ||
| desc="condensation_sft", | ||
| unit="row", | ||
| dynamic_ncols=True, | ||
| disable=args.no_progress, | ||
| ) | ||
| try: | ||
| for chunk in iter_input_chunks(args.chunk_size): | ||
| tasks = [ | ||
| asyncio.create_task(process_line(line, args=args, semaphore=semaphore)) | ||
| for line in chunk | ||
| ] | ||
| for task in asyncio.as_completed(tasks): | ||
| records = await task | ||
| for record in records: | ||
| print(json.dumps(record, ensure_ascii=False), flush=True) | ||
| progress.update(1) | ||
| finally: | ||
| progress.close() | ||
|
|
||
|
|
||
| def main() -> None: | ||
| parser = argparse.ArgumentParser( | ||
| description=( | ||
|
|
@@ -375,21 +499,34 @@ def main() -> None: | |
| default="yes", | ||
| help="Whether to emit the original OpenHands SDK trajectory record before summaries.", | ||
| ) | ||
| parser.add_argument( | ||
| "--concurrency", | ||
| type=int, | ||
| default=1, | ||
| help="Number of input trajectories to process concurrently.", | ||
| ) | ||
| parser.add_argument( | ||
| "--chunk-size", | ||
| type=int, | ||
| default=100, | ||
| help="Number of input rows to schedule per async batch.", | ||
| ) | ||
| parser.add_argument( | ||
| "--no-progress", | ||
| action="store_true", | ||
| help="Disable tqdm progress output on stderr.", | ||
| ) | ||
| parser.add_argument( | ||
| "--continue-on-error", | ||
| action="store_true", | ||
| help="Log per-row conversion errors to stderr and continue processing remaining rows.", | ||
| ) | ||
| args = parser.parse_args() | ||
| for line in sys.stdin: | ||
| line = line.strip() | ||
| if not line: | ||
| continue | ||
| records = process_row( | ||
| line, | ||
| max_tokens=args.max_tokens, | ||
| model=args.model, | ||
| include_trajectories=args.include_trajectories == "yes", | ||
| max_size=args.max_size, | ||
| keep_first=args.keep_first, | ||
| ) | ||
| for record in records: | ||
| print(json.dumps(record, ensure_ascii=False)) | ||
| if args.concurrency < 1: | ||
| raise ValueError("--concurrency must be at least 1") | ||
| if args.chunk_size < 1: | ||
| raise ValueError("--chunk-size must be at least 1") | ||
| asyncio.run(process_stream(args)) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🟡 Suggestion:
conversation.state.eventsis no longer read anywhere in this function after the refactor —event_historyis the canonical source for all condensation logic. YetTrackingSDKEventBuilder.appendstill callssuper().append()(line 115), which writes every event toconversation.state.events; and here the condensation is also written there explicitly to keep the two lists in sync.This dual-write pattern obscures what the true source of truth is. If
conversation.state.eventsis being kept in sync intentionally (e.g. as a guard against unknown SDK side effects that read it internally), add a brief comment saying so. If it is not needed, remove thesuper().append()call inTrackingSDKEventBuilderand this explicit append, which would make the migration complete and the code self-documenting.