"""Rollout collector for GRPO training.
Collects groups of N rollouts using the openadapt-evals RLEnvironment,
which wraps a live WAA server. Each rollout produces a trajectory of
(observation, action, reward) tuples and a terminal reward from the
WAA evaluator.
Currently sequential (single VM). Parallel VM support via
openadapt-evals PoolManager is future work.
"""
from __future__ import annotations
import logging
import random
from dataclasses import dataclass, field
from typing import Any, Callable
from openadapt_ml.training.grpo.config import GRPOConfig
from openadapt_ml.training.grpo.reward import binary_task_success
logger = logging.getLogger(__name__)
# Deferred imports for openadapt-evals dependencies (optional at install time)
try:
from openadapt_evals.adapters import (
RLEnvironment,
RolloutStep,
WAALiveAdapter,
WAALiveConfig,
)
from openadapt_evals.adapters.rl_env import ResetConfig
except ImportError:
RLEnvironment = None # type: ignore[assignment, misc]
RolloutStep = None # type: ignore[assignment, misc]
WAALiveAdapter = None # type: ignore[assignment, misc]
WAALiveConfig = None # type: ignore[assignment, misc]
ResetConfig = None # type: ignore[assignment, misc]
@dataclass
class Rollout:
"""Complete episode rollout with reward.
Attributes:
task_id: The WAA task that was executed.
steps: List of RolloutStep objects from the RLEnvironment.
reward: Binary reward (0.0 or 1.0) from the evaluator.
num_steps: Number of steps taken in the episode.
instruction: Task instruction text for prompt reconstruction
during loss computation. Populated from the environment's
current task after rollout collection.
"""
task_id: str
steps: list[Any] = field(default_factory=list) # list[RolloutStep]
reward: float = 0.0
num_steps: int = 0
instruction: str = ""
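

# Illustrative sketch only (not part of this module's API): a group of
# Rollout objects typically feeds GRPO by centering each rollout's reward
# on the group mean to form advantages. The exact normalization (e.g.
# dividing by the group std) depends on the loss implementation elsewhere.
#
#   rewards = [r.reward for r in rollouts]
#   baseline = sum(rewards) / len(rewards)
#   advantages = [r - baseline for r in rewards]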


class GRPORolloutCollector:
    """Collects groups of rollouts using openadapt-evals RLEnvironment.

    Creates a WAALiveAdapter and RLEnvironment from the config, then
    provides methods to collect groups of N rollouts for GRPO training.
    Currently sequential (single VM); parallel VM support is future work.

    Args:
        config: GRPO training configuration.
        task_configs: Optional dict mapping task_id -> TaskConfig. When
            provided, task configs are loaded into the RLEnvironment for
            milestone-based dense reward evaluation.

    Raises:
        ImportError: If openadapt-evals is not installed.
    """

    def __init__(
        self,
        config: GRPOConfig,
        task_configs: dict[str, Any] | None = None,
    ) -> None:
        if RLEnvironment is None:
            raise ImportError(
                "openadapt-evals is required for rollout collection. "
                "Install it with: uv add openadapt-evals"
            )
        self._config = config
        self._task_configs = task_configs or {}
        self._adapter = WAALiveAdapter(
            WAALiveConfig(
                server_url=config.server_url,
                evaluate_url=config.evaluate_url,
            )
        )
        self._env = RLEnvironment(self._adapter)

    @property
    def env(self) -> Any:
        """The underlying RLEnvironment instance."""
        return self._env

    def collect_group(
        self,
        agent_fn: Callable,
        task_id: str | None = None,
    ) -> list[Rollout]:
        """Collect N rollouts for one GRPO gradient step.

        Runs the agent N times on the same task (or a random task from
        config.task_ids if task_id is not specified). Each rollout resets
        the environment, runs the agent, and evaluates the result.

        Currently sequential (single VM). Parallel VM support via
        openadapt-evals PoolManager is future work.

        Args:
            agent_fn: Callable that takes a BenchmarkObservation and returns
                a BenchmarkAction. This is the model's predict function.
            task_id: Specific task ID, or None to pick from config.task_ids.

        Returns:
            List of N Rollout objects with binary rewards.
        """
        if task_id is None:
            if not self._config.task_ids:
                raise ValueError("No task_id provided and config.task_ids is empty.")
            task_id = random.choice(self._config.task_ids)

        rollouts: list[Rollout] = []

        # Load task config into the environment for dense milestone rewards
        if task_id in self._task_configs:
            tc = self._task_configs[task_id]
            self._env.load_task_config(tc)

        for i in range(self._config.num_rollouts_per_step):
            logger.info(
                "Collecting rollout %d/%d for task %s",
                i + 1,
                self._config.num_rollouts_per_step,
                task_id,
            )
            # collect_rollout resets the environment internally with the
            # given task_id before running the agent
            steps = self._env.collect_rollout(
                agent_fn=agent_fn,
                max_steps=self._config.max_steps_per_episode,
                stuck_window=self._config.stuck_window,
                task_id=task_id,
            )

            # Extract terminal score from the last step's reward
            raw_score = steps[-1].reward if steps else 0.0
            reward = binary_task_success(raw_score)

            # CR-01: Extract task instruction from the environment's
            # current task (set during reset inside collect_rollout).
            instruction = ""
            task = getattr(self._env, "_current_task", None)
            if task is not None:
                instruction = getattr(task, "instruction", "") or ""

            rollout = Rollout(
                task_id=task_id,
                steps=steps,
                reward=reward,
                num_steps=len(steps),
                instruction=instruction,
            )
            rollouts.append(rollout)
            logger.info(
                "Rollout %d: %d steps, raw_score=%.2f, reward=%.1f",
                i + 1,
                len(steps),
                raw_score,
                reward,
            )

        return rollouts
    def close(self) -> None:
        """Clean up adapter resources."""
        if hasattr(self._adapter, "close"):
            self._adapter.close()
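

# Minimal usage sketch. Assumptions (not defined in this file): the
# GRPOConfig field values below are illustrative placeholders, and
# `model.predict` stands for any callable mapping a BenchmarkObservation
# to a BenchmarkAction.
#
#   config = GRPOConfig(
#       server_url="http://localhost:5000",
#       evaluate_url="http://localhost:5000/evaluate",
#       task_ids=["example_task"],
#       num_rollouts_per_step=4,
#   )
#   collector = GRPORolloutCollector(config)
#   try:
#       group = collector.collect_group(agent_fn=model.predict)
#       logger.info("group rewards: %s", [r.reward for r in group])
#   finally:
#       collector.close()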