Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions .github/workflows/check_api_docstrings.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Check Docstrings
name: Check Dataset Metadata

on:
push:
Expand All @@ -9,7 +9,7 @@ on:
- main

jobs:
check_docstrings:
check_dataset_metadata:
runs-on: ubuntu-latest

steps:
Expand All @@ -21,10 +21,12 @@ jobs:
with:
python-version: '3.12'

- name: Install ruff
- name: Install dependencies
run: |
python -m pip install ruff
python -m pip install --upgrade pip
pip install pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi

- name: Check for docstrings
- name: Check dataset metadata
run: |
ruff check datasets/*/api.py --select D --ignore D100,D203,D213
pytest tests/test_dataset_structure.py
187 changes: 162 additions & 25 deletions agents/openhands_sdk/condensation_sft.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import argparse
import asyncio
import json
import os
import sys
import tempfile
from collections.abc import Sequence
from collections.abc import Iterator, Sequence
from typing import Any

os.environ.setdefault("OPENHANDS_SUPPRESS_BANNER", "1")
Expand All @@ -15,6 +16,8 @@
from openhands.sdk.context.condenser import LLMSummarizingCondenser
from openhands.sdk.context.condenser.utils import get_total_token_count
from openhands.sdk.context.view import View
from openhands.sdk.event import LLMConvertibleEvent as SDKEvent
from openhands.sdk.event import MessageEvent, SystemPromptEvent
from openhands.sdk.event.condenser import Condensation
from openhands.sdk.llm.llm_response import LLMResponse
from openhands.sdk.tool import ToolDefinition
Expand Down Expand Up @@ -98,10 +101,31 @@ def format_messages(llm: LLM, messages: list[Message]) -> list[dict[str, Any]]:
return normalize_message_content(llm.format_messages_for_llm(messages))


class TrackingSDKEventBuilder(SDKEventBuilder):
def __init__(
self,
conversation: Conversation,
metadata: Any,
event_history: list[SDKEvent],
) -> None:
super().__init__(conversation, metadata)
self.event_history = event_history

def append(self, event: SDKEvent) -> None:
self.event_history.append(event)
super().append(event)


def token_count(view: View, llm: LLM) -> int:
return get_total_token_count(view.events, llm)


def formatted_token_count(events: Sequence[SDKEvent], llm: LLM) -> int:
view = View.from_events(events)
messages = LLMConvertibleEvent.events_to_messages(view.events)
return llm.get_token_count(messages)


def make_condensation_prompt_record(
*,
trajectory_id: str,
Expand Down Expand Up @@ -168,7 +192,7 @@ def make_trajectory_record_from_conversation(

def condensation_prompt_record_if_needed(
*,
conversation: Conversation,
events: list[SDKEvent],
condenser: LLMSummarizingCondenser,
agent_llm: LLM,
condenser_llm: PromptCapturingLLM,
Expand All @@ -177,7 +201,7 @@ def condensation_prompt_record_if_needed(
max_tokens: int,
condensation_index: int,
) -> tuple[Condensation, dict[str, Any]] | None:
view = View.from_events(conversation.state.events)
view = View.from_events(events)
prompt_token_count = token_count(view, condenser.llm)
before_prompt_count = len(condenser_llm.captured_messages)
condensation_result = condenser.condense(view, agent_llm=agent_llm)
Expand Down Expand Up @@ -213,7 +237,28 @@ def append_standardized_events_with_condensation(
include_trajectories: bool,
) -> list[dict[str, Any]]:
metadata = load_dataset_metadata(dataset_name, required=True)
builder = SDKEventBuilder(conversation, metadata)
event_history: list[SDKEvent] = [
SystemPromptEvent(
system_prompt=TextContent(text=conversation.agent.static_system_message),
tools=list(conversation.agent.tools_map.values()),
)
]
builder = TrackingSDKEventBuilder(conversation, metadata, event_history)
first_event = trajectory.content[0]
if not isinstance(first_event, TextObservation) or first_event.source != "user":
raise ValueError(
"OpenHands SDK condensation conversion expects the first event to be a "
"user TextObservation"
)
builder.append(
MessageEvent(
source="user",
llm_message=Message(
role="user",
content=[TextContent(text=first_event.content)],
),
)
)
condenser_llm = PromptCapturingLLM(
usage_id="openhands-sdk-condensation-sft-condenser",
model=model,
Expand All @@ -231,18 +276,17 @@ def append_standardized_events_with_condensation(
condensation_index = 1
index = start_index
batch_number = 0
last_safe_events = list(conversation.state.events)
last_safe_events = list(event_history)

def update_last_safe_events() -> None:
nonlocal last_safe_events
view = View.from_events(conversation.state.events)
if token_count(view, conversation.agent.llm) <= max_tokens:
last_safe_events = list(conversation.state.events)
if formatted_token_count(event_history, conversation.agent.llm) <= max_tokens:
last_safe_events = list(event_history)

def emit_condensation_boundary_if_needed() -> None:
nonlocal segment_index, condensation_index, last_safe_events
result = condensation_prompt_record_if_needed(
conversation=conversation,
events=event_history,
condenser=condenser,
agent_llm=conversation.agent.llm,
condenser_llm=condenser_llm,
Expand All @@ -266,8 +310,9 @@ def emit_condensation_boundary_if_needed() -> None:
)
segment_index += 1
records.append(prompt_record)
event_history.append(condensation)
conversation.state.events.append(condensation)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Suggestion: conversation.state.events is no longer read anywhere in this function after the refactor — event_history is the canonical source for all condensation logic. Yet TrackingSDKEventBuilder.append still calls super().append() (line 115), which writes every event to conversation.state.events; and here the condensation is also written there explicitly to keep the two lists in sync.

This dual-write pattern obscures what the true source of truth is. If conversation.state.events is being kept in sync intentionally (e.g. as a guard against unknown SDK side effects that read it internally), add a brief comment saying so. If it is not needed, remove the super().append() call in TrackingSDKEventBuilder and this explicit append, which would make the migration complete and the code self-documenting.

last_safe_events = list(conversation.state.events)
last_safe_events = list(event_history)
condensation_index += 1

while index < len(trajectory.content):
Expand Down Expand Up @@ -342,7 +387,7 @@ def process_row(
with tempfile.TemporaryDirectory(prefix="openhands-sdk-condensation-sft-") as tmpdir:
conversation = Conversation(agent=agent, workspace=tmpdir, visualizer=None)
try:
conversation.send_message(first_event.content)
conversation._ensure_agent_ready()
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟠 Important: _ensure_agent_ready() is a private (underscore-prefixed) SDK method. Coupling to internal implementation details is fragile — the SDK can rename or remove it without a major version bump and this will break silently. If there is no public API for "initialize without sending a message", that gap should be raised with the SDK team. At minimum, add a comment explaining why this private method is called here and what it guards against.

return append_standardized_events_with_condensation(
conversation=conversation,
trajectory=trajectory,
Expand All @@ -358,6 +403,85 @@ def process_row(
conversation.close()


def iter_input_chunks(chunk_size: int) -> Iterator[list[str]]:
chunk: list[str] = []
for line in sys.stdin:
line = line.strip()
if not line:
continue
chunk.append(line)
if len(chunk) >= chunk_size:
yield chunk
chunk = []
if chunk:
yield chunk


async def process_line(
line: str,
*,
args: argparse.Namespace,
semaphore: asyncio.Semaphore,
) -> list[dict[str, Any]]:
try:
async with semaphore:
return await asyncio.to_thread(
process_row,
line,
max_tokens=args.max_tokens,
model=args.model,
include_trajectories=args.include_trajectories == "yes",
max_size=args.max_size,
keep_first=args.keep_first,
)
except Exception as exc:
if not args.continue_on_error:
raise
row_id = None
try:
row_id = json.loads(line).get("id")
except Exception:
pass
print(
json.dumps(
{
"id": row_id,
"error_type": type(exc).__name__,
"error": str(exc),
},
ensure_ascii=False,
),
file=sys.stderr,
flush=True,
)
return []


async def process_stream(args: argparse.Namespace) -> None:
from tqdm import tqdm

semaphore = asyncio.Semaphore(args.concurrency)
progress = tqdm(
desc="condensation_sft",
unit="row",
dynamic_ncols=True,
disable=args.no_progress,
)
try:
for chunk in iter_input_chunks(args.chunk_size):
tasks = [
asyncio.create_task(process_line(line, args=args, semaphore=semaphore))
for line in chunk
]
for task in asyncio.as_completed(tasks):
records = await task
for record in records:
print(json.dumps(record, ensure_ascii=False), flush=True)
progress.update(1)
finally:
progress.close()


def main() -> None:
parser = argparse.ArgumentParser(
description=(
Expand All @@ -375,21 +499,34 @@ def main() -> None:
default="yes",
help="Whether to emit the original OpenHands SDK trajectory record before summaries.",
)
parser.add_argument(
"--concurrency",
type=int,
default=1,
help="Number of input trajectories to process concurrently.",
)
parser.add_argument(
"--chunk-size",
type=int,
default=100,
help="Number of input rows to schedule per async batch.",
)
parser.add_argument(
"--no-progress",
action="store_true",
help="Disable tqdm progress output on stderr.",
)
parser.add_argument(
"--continue-on-error",
action="store_true",
help="Log per-row conversion errors to stderr and continue processing remaining rows.",
)
args = parser.parse_args()
for line in sys.stdin:
line = line.strip()
if not line:
continue
records = process_row(
line,
max_tokens=args.max_tokens,
model=args.model,
include_trajectories=args.include_trajectories == "yes",
max_size=args.max_size,
keep_first=args.keep_first,
)
for record in records:
print(json.dumps(record, ensure_ascii=False))
if args.concurrency < 1:
raise ValueError("--concurrency must be at least 1")
if args.chunk_size < 1:
raise ValueError("--chunk-size must be at least 1")
asyncio.run(process_stream(args))


if __name__ == "__main__":
Expand Down
Loading
Loading