42 changes: 39 additions & 3 deletions docs/en/get_started/customization.md
@@ -18,7 +18,7 @@ Below is a summary of all available customization interfaces and their purposes.
| [`--rollout-data-postprocess-path`](#8-rollout-data-postprocess---rollout-data-postprocess-path) | Post-process rollout data after log probs are computed. |
| [`--custom-loss-function-path`](#9-custom-loss-function---custom-loss-function-path) | Implement custom training loss computation. |
| [`--custom-tis-function-path`](#10-custom-tisrs-function---custom-tis-function-path) | Implement custom importance sampling for off-policy correction. |
| [`--custom-pg-loss-reducer-function-path`](#11-custom-pg-loss-reducer---custom-pg-loss-reducer-function-path) | Customize pg_loss reduction (e.g., for Dr.GRPO). |
| [`--custom-pg-loss-reducer-function-path`](#11-custom-pg-loss-reducer---custom-pg-loss-reducer-function-path) | Customize pg_loss reduction not covered by `--loss-aggregation`. |
| [`--custom-reward-post-process-path`](#12-reward-post-processing---custom-reward-post-process-path) | Custom post-processing of rewards before advantage computation. |
| [`--custom-convert-samples-to-train-data-path`](#13-samples-to-train-data-conversion---custom-convert-samples-to-train-data-path) | Override the conversion of samples to training data format. |
| [`--custom-rollout-log-function-path`](#14-logging-functions) | Custom logging for training rollouts. |
@@ -295,8 +295,44 @@ def get_pg_loss_reducer(
```

**Use Cases**:
- Dr.GRPO: Divide by a constant instead of effective token count
- Custom loss normalization strategies
- Custom loss normalization strategies not covered by `--loss-aggregation`

> Use the custom reducer hook only for normalizations that the built-in
> `--loss-aggregation` modes do not express. When set, the custom reducer takes
> precedence over `--loss-aggregation`.

**Built-in modes — `--loss-aggregation`**

`--loss-aggregation {sample_mean,prompt_mean,token_mean,constant}` selects how
pg_loss is aggregated across a training step. It affects pg_loss only: every
other metric keeps the default sample-mean reducer, which is the same scope as
the custom hook above. The modes follow the ScaleRL taxonomy
([arXiv:2510.13786](https://arxiv.org/abs/2510.13786) §3.2):

| Mode | Paper term | pg_loss denominator |
| :--- | :--- | :--- |
| `sample_mean` (default) | GRPO sample average | Token-weighted mean within each rollout, then averaged over rollouts; each rollout contributes equally regardless of fan-out. |
| `prompt_mean` | DAPO prompt average | Per-prompt-group token-weighted mean (all rollouts sharing a prompt share one denominator). |
| `token_mean` | token average | Global per-token mean. `--calculate-per-token-loss` is accepted as an alias for this mode when no conflicting `--loss-aggregation` is set. |
| `constant` | Dr.GRPO ([arXiv:2503.20783](https://arxiv.org/abs/2503.20783)) | `sum(token_loss * loss_mask) / L`, where `L = --loss-aggregation-divisor` (e.g. the max context length). |

`--loss-aggregation-divisor L` is required only for `constant`, where it is
validated to be `> 0` at startup. The other modes do not use it, and passing it
with any other mode fails at startup. If `--loss-aggregation` is not set, slime
uses `sample_mean`.

The `constant` divisor `L` is a *per-token* scale, not a per-step one. In every
mode the reducer returns a step sum, which is then averaged by the usual
`/ step_global_batch_size`, so the structure is identical across all four modes.
The effective per-step normalization for `constant` is therefore
`/ (L * step_global_batch_size)`: `L` only sets the data-independent per-token
scale, and the `/ step_global_batch_size` step average is applied on top,
exactly as for the other modes.

`prompt_mean` weights every prompt group equally: each group's token-weighted
mean contributes once, and the final scalar is the mean over prompt groups.
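
As a rough illustration of how the four modes differ, the following plain-Python sketch evaluates each description above on a made-up batch of four rollouts in two prompt groups. It is not slime's implementation: the `aggregate` helper, the toy numbers, the assumption of one sample per rollout, and the divisor `L = 8.0` are all hypothetical.

```python
# Toy batch: (prompt group index, per-token loss, loss mask). Values are made up.
samples = [
    (0, [1.0, 1.0], [1, 1]),
    (0, [3.0, 3.0, 3.0, 3.0], [1, 1, 1, 1]),
    (1, [2.0], [1]),
    (1, [0.0, 4.0, 4.0], [0, 1, 1]),
]
step_global_batch_size = len(samples)  # assumption: one entry per rollout
L = 8.0  # stands in for --loss-aggregation-divisor


def masked_sum(loss, mask):
    """Sum of per-token losses over unmasked tokens."""
    return sum(x * m for x, m in zip(loss, mask))


def aggregate(mode):
    """Hypothetical helper: final pg_loss scalar for one aggregation mode."""
    total = sum(masked_sum(loss, mask) for _, loss, mask in samples)
    if mode == "token_mean":
        return total / sum(sum(mask) for _, _, mask in samples)
    if mode == "constant":
        return total / L / step_global_batch_size
    if mode == "sample_mean":
        step_sum = sum(masked_sum(loss, mask) / sum(mask) for _, loss, mask in samples)
        return step_sum / step_global_batch_size
    if mode == "prompt_mean":
        groups = {}
        for group, loss, mask in samples:
            num, den = groups.get(group, (0.0, 0))
            groups[group] = (num + masked_sum(loss, mask), den + sum(mask))
        return sum(num / den for num, den in groups.values()) / len(groups)
    raise ValueError(f"unknown mode: {mode!r}")


results = {
    mode: aggregate(mode)
    for mode in ("sample_mean", "prompt_mean", "token_mean", "constant")
}
```

On this batch the four modes give different scalars, and only `constant` is independent of how many tokens are unmasked.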

**`--calculate-per-token-loss`.** This flag is an alias for
`--loss-aggregation=token_mean`. It can be used by itself, or with
`--loss-aggregation=token_mean`. Combining it with `prompt_mean` or `constant`
fails at startup because those modes use different denominators.
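
The startup rules described above can be sketched as a single resolution function. This is a hypothetical helper, not slime's argument parser: the name `resolve_loss_aggregation` and its signature are invented for illustration, and rejecting an explicit `sample_mean` together with `--calculate-per-token-loss` is an assumption, since the text only names `prompt_mean` and `constant` as failing combinations.

```python
from typing import Optional

MODES = ("sample_mean", "prompt_mean", "token_mean", "constant")


def resolve_loss_aggregation(
    loss_aggregation: Optional[str],
    calculate_per_token_loss: bool,
    divisor: Optional[float],
) -> str:
    """Hypothetical helper: return the effective mode or raise ValueError."""
    if loss_aggregation is not None and loss_aggregation not in MODES:
        raise ValueError(f"unknown mode: {loss_aggregation!r}")

    if calculate_per_token_loss:
        # The flag is an alias for token_mean; any other explicit mode conflicts.
        # (Treating explicit sample_mean as a conflict is an assumption.)
        if loss_aggregation not in (None, "token_mean"):
            raise ValueError("--calculate-per-token-loss conflicts with " + loss_aggregation)
        mode = "token_mean"
    else:
        mode = loss_aggregation or "sample_mean"

    if mode == "constant":
        if divisor is None or divisor <= 0:
            raise ValueError("constant requires --loss-aggregation-divisor > 0")
    elif divisor is not None:
        raise ValueError("--loss-aggregation-divisor is only valid with constant")
    return mode
```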

---

40 changes: 37 additions & 3 deletions docs/zh/get_started/customization.md
@@ -18,7 +18,7 @@ slime provides extensive customization capabilities through function path arguments. These arguments
| [`--rollout-data-postprocess-path`](#8-rollout-数据后处理---rollout-data-postprocess-path) | Post-process rollout data after log probabilities are computed. |
| [`--custom-loss-function-path`](#9-自定义损失函数---custom-loss-function-path) | Implement custom training loss computation. |
| [`--custom-tis-function-path`](#10-自定义-tisrs-函数---custom-tis-function-path) | Implement custom importance sampling for off-policy correction. |
| [`--custom-pg-loss-reducer-function-path`](#11-自定义-pg-loss-reducer---custom-pg-loss-reducer-function-path) | Customize pg_loss reduction (e.g., Dr.GRPO). |
| [`--custom-pg-loss-reducer-function-path`](#11-自定义-pg-loss-reducer---custom-pg-loss-reducer-function-path) | Customize pg_loss reduction not covered by `--loss-aggregation`. |
| [`--custom-reward-post-process-path`](#12-奖励后处理---custom-reward-post-process-path) | Custom post-processing of rewards before advantage computation. |
| [`--custom-convert-samples-to-train-data-path`](#13-样本转训练数据---custom-convert-samples-to-train-data-path) | Override the conversion of samples to training data format. |
| [`--custom-rollout-log-function-path`](#14-日志函数) | Custom logging for training rollouts. |
@@ -295,8 +295,42 @@ def get_pg_loss_reducer(
```

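**Use Cases**:
- Dr.GRPO: Divide by a constant instead of the effective token count
- Custom loss normalization strategies
- Custom loss normalization strategies not covered by `--loss-aggregation`

> Use the custom reducer hook only for normalizations that the built-in
> `--loss-aggregation` modes cannot express. When set, the custom reducer takes
> precedence over `--loss-aggregation`.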

**Built-in modes: `--loss-aggregation`**

`--loss-aggregation {sample_mean,prompt_mean,token_mean,constant}` selects how
pg_loss is aggregated within a training step (it applies to pg_loss only; other
metrics still use the default sample-mean reducer, the same scope as the custom
hook above). The modes follow the ScaleRL taxonomy
([arXiv:2510.13786](https://arxiv.org/abs/2510.13786) §3.2):

| Mode | Paper semantics | pg_loss denominator |
| :--- | :--- | :--- |
| `sample_mean` (default) | GRPO sample average | Token-weighted mean within each rollout, then averaged over rollouts; every rollout has equal weight regardless of fan-out. |
| `prompt_mean` | DAPO prompt average | Token-weighted mean within each prompt group; all rollouts sharing a prompt use the same denominator. |
| `token_mean` | token average | Global per-token mean. `--calculate-per-token-loss` can serve as an alias for this mode when no conflicting `--loss-aggregation` is set. |
| `constant` | Dr.GRPO ([arXiv:2503.20783](https://arxiv.org/abs/2503.20783)) | `sum(token_loss * loss_mask) / L`, where `L = --loss-aggregation-divisor` (e.g. the max context length). |

`--loss-aggregation-divisor L` is required only in `constant` mode and is
validated `> 0` at startup; the other modes do not use it, and passing it
alongside them fails at startup. The default mode is `sample_mean`.

The `constant` denominator `L` is per-token, not per-step: every mode's reducer
returns a step sum, which then goes through the existing
`/ step_global_batch_size` step average (the structure is the same for all four
modes). The effective per-step normalization for `constant` is therefore
`/ (L * step_global_batch_size)`: `L` only controls the data-independent
per-token scale, and `/ step_global_batch_size` is still applied on top, as in
the other modes.

`prompt_mean` gives every prompt group equal weight: each group's token-weighted
mean contributes once, and the final scalar is the mean over prompt groups.

**`--calculate-per-token-loss`.** This flag is an alias for
`--loss-aggregation=token_mean`. It can be used on its own or together with
`--loss-aggregation=token_mean`. Combining it with `prompt_mean` or `constant`
fails at startup because those modes use different denominators.

---

11 changes: 5 additions & 6 deletions slime/backends/megatron_utils/actor.py
@@ -235,12 +235,11 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch:
rollout_data["loss_masks"] = [
t.to(device=device, dtype=torch.int, non_blocking=True) for t in rollout_data["loss_masks"]
]
if "rollout_mask_sums" in rollout_data:
# Promote precomputed per-rollout mask totals to GPU tensors here
# (matching loss_masks) so the loss reducer can just divide.
rollout_data["rollout_mask_sums"] = rollout_data["rollout_mask_sums"].to(
device=device, dtype=torch.float32, non_blocking=True
)
for mask_sums_key in ("rollout_mask_sums", "prompt_mask_sums"):
if mask_sums_key in rollout_data:
rollout_data[mask_sums_key] = rollout_data[mask_sums_key].to(
device=device, dtype=torch.float32, non_blocking=True
)
if "multimodal_train_inputs" in rollout_data:
# Move multimodal training tensors to GPU in advance
rollout_data["multimodal_train_inputs"] = [
22 changes: 20 additions & 2 deletions slime/backends/megatron_utils/cp_utils.py
@@ -50,6 +50,7 @@ def get_sum_of_sample_mean(
loss_masks: list[torch.Tensor],
sample_denoms: list[torch.Tensor] | torch.Tensor | None = None,
calculate_per_token_loss: bool = False,
constant_divisor: float | None = None,
) -> Callable[[torch.Tensor], torch.Tensor]:
"""
Calculate correct sample mean for CP.
@@ -63,7 +64,17 @@ def get_sum_of_sample_mean(
step level rather than per-mb is required — otherwise a rollout whose
samples land in different micro-batches would get a partial denominator
on each side.

``constant_divisor`` (Dr.GRPO, arXiv:2503.20783) divides the masked
token-sum by a fixed ``L``; being identical on every CP rank, Megatron's
gradient sum-allreduce already yields the full-batch value (no extra
all-reduce). Mutually exclusive with ``calculate_per_token_loss``.
"""
if constant_divisor is not None and calculate_per_token_loss:
raise ValueError(
"constant_divisor (loss-aggregation=constant) and calculate_per_token_loss "
"(loss-aggregation=token_mean) are mutually exclusive aggregation modes."
)
if sample_denoms is None:
sample_denoms = [m.sum() for m in loss_masks]

@@ -75,7 +86,7 @@ def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor:
[
(x_i * loss_mask_i).sum() / torch.clamp_min(denom, 1)
for x_i, loss_mask_i, denom in zip(
x.split(response_lengths, dim=0), loss_masks, sample_denoms, strict=False
x.split(response_lengths, dim=0), loss_masks, sample_denoms, strict=True
)
]
)
@@ -106,7 +117,7 @@ def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor:
[
(x_i * chunked_loss_mask).sum() / torch.clamp_min(denom, 1)
for x_i, chunked_loss_mask, denom in zip(
x.split(cp_chunk_lengths, dim=0), chunked_loss_masks, sample_denoms, strict=False
x.split(cp_chunk_lengths, dim=0), chunked_loss_masks, sample_denoms, strict=True
)
]
)
@@ -121,6 +132,13 @@ def sum_of_token(x: torch.Tensor) -> torch.Tensor:
]
)

    if constant_divisor is not None:

        def sum_of_constant(x: torch.Tensor) -> torch.Tensor:
            return sum_of_token(x) / constant_divisor

        return sum_of_constant

    return sum_of_sample_mean if not calculate_per_token_loss else sum_of_token
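
The docstring's claim that a constant divisor needs no extra all-reduce rests on linearity: dividing each rank's partial masked sum by the same `L` and then summing gives the same value as dividing the full sum by `L`. A small plain-Python sketch of that identity, where a Python `sum` over two hypothetical CP ranks stands in for the gradient sum-allreduce and all numbers are made up:

```python
# One sample whose response tokens are split across two hypothetical CP ranks.
rank_losses = [[1.0, 2.0, 3.0], [4.0]]
rank_masks = [[1, 1, 0], [1]]
L = 4.0  # stands in for constant_divisor


def masked_sum(loss, mask):
    """Sum of per-token losses over unmasked tokens."""
    return sum(x * m for x, m in zip(loss, mask))


partials = [masked_sum(loss, mask) for loss, mask in zip(rank_losses, rank_masks)]

# Constant divisor: per-rank division followed by a sum equals the full value.
constant_summed = sum(p / L for p in partials)
constant_full = sum(partials) / L

# Data-dependent denominator: dividing by the rank-local mask count does not
# reproduce the full-sequence mean, so the full denominator must be shared.
local_mean_summed = sum(p / sum(mask) for p, mask in zip(partials, rank_masks))
full_mean = sum(partials) / sum(sum(mask) for mask in rank_masks)
```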


1 change: 1 addition & 0 deletions slime/backends/megatron_utils/data.py
@@ -286,6 +286,7 @@ def log_rollout_data(
"rollout_mask_sums",
"rollout_top_p_token_ids",
"rollout_top_p_token_offsets",
"prompt_mask_sums",
"rollout_routed_experts",
"global_batch_sizes",
"num_microbatches",
61 changes: 56 additions & 5 deletions slime/backends/megatron_utils/loss.py
@@ -874,6 +874,51 @@ def icepop_function(
return pg_loss, loss_masks, metrics


def get_pg_loss_reducer(
    args: Namespace,
    batch: RolloutBatch,
    *,
    total_lengths: list[int],
    response_lengths: list[int],
    pg_loss_masks: list[torch.Tensor],
    default_reducer: Callable[[torch.Tensor], torch.Tensor],
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Return the pg_loss reducer selected by ``--loss-aggregation``."""
    mode = args.loss_aggregation
    if mode in ("sample_mean", "token_mean"):
        return default_reducer
    if mode == "prompt_mean":
        prompt_mask_sums = batch.get("prompt_mask_sums")
        if prompt_mask_sums is None:
            raise ValueError(
                "--loss-aggregation=prompt_mean requires per-prompt-group mask sums "
                "(batch['prompt_mask_sums']), but they are missing. A custom "
                "--custom-convert-samples-to-train-data-path must populate "
                "'prompt_mask_sums' (grouped by Sample.group_index)."
            )
        reducer = get_sum_of_sample_mean(
            total_lengths,
            response_lengths,
            pg_loss_masks,
            prompt_mask_sums,
        )

        # The train step divides non-token losses by response count; scale the
        # prompt-group sum so the final scalar is the exact mean over prompts.
        def prompt_mean_reducer(x: torch.Tensor) -> torch.Tensor:
            return reducer(x) * args.n_samples_per_prompt

        return prompt_mean_reducer
    if mode == "constant":
        return get_sum_of_sample_mean(
            total_lengths,
            response_lengths,
            pg_loss_masks,
            constant_divisor=args.loss_aggregation_divisor,
        )
    raise ValueError(f"Unknown --loss-aggregation mode: {mode!r}")
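
The `* args.n_samples_per_prompt` scaling in the `prompt_mean` branch can be checked with a small plain-Python sketch. It assumes, as the code comment suggests, that the train step later divides by the number of responses, and that every prompt group has exactly `n_samples_per_prompt` samples. The numbers and variable names are made up for illustration.

```python
n_samples_per_prompt = 2
# (group index, masked token-loss sum, mask sum) per sample; made-up numbers.
samples = [(0, 3.0, 3), (0, 5.0, 1), (1, 4.0, 2), (1, 0.0, 2)]

group_mask = {}
group_loss = {}
for group, loss_sum, mask_sum in samples:
    group_mask[group] = group_mask.get(group, 0) + mask_sum
    group_loss[group] = group_loss.get(group, 0.0) + loss_sum

# Reducer output: each sample divided by its group's total mask, then scaled.
reducer_out = (
    sum(loss_sum / group_mask[group] for group, loss_sum, _ in samples)
    * n_samples_per_prompt
)
# Assumed downstream step: divide by the number of responses in the step.
final = reducer_out / len(samples)

# Direct definition: mean over prompt groups of each group's token-weighted mean.
expected = sum(group_loss[g] / group_mask[g] for g in group_loss) / len(group_loss)
```

If groups had differing sample counts, the two quantities would no longer match, which is why the sketch fixes the group size.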


def policy_loss_function(
args: Namespace,
batch: RolloutBatch,
@@ -1024,16 +1069,22 @@ def policy_loss_function(
args.calculate_per_token_loss,
)

# Determine pg_loss reducer: use custom if specified, otherwise default
if getattr(args, "custom_pg_loss_reducer_function_path", None) is not None:
pg_loss_masks = modified_response_masks if (args.get_mismatch_metrics or args.use_tis) else batch["loss_masks"]

if args.custom_pg_loss_reducer_function_path is not None:
custom_pg_loss_reducer_func = load_function(args.custom_pg_loss_reducer_function_path)
# Determine which loss_masks to use for pg_loss reducer
pg_loss_masks = modified_response_masks if (args.get_mismatch_metrics or args.use_tis) else batch["loss_masks"]
pg_loss_reducer = custom_pg_loss_reducer_func(
total_lengths, response_lengths, pg_loss_masks, args.calculate_per_token_loss
)
else:
pg_loss_reducer = sum_of_sample_mean
pg_loss_reducer = get_pg_loss_reducer(
args,
batch,
total_lengths=total_lengths,
response_lengths=response_lengths,
pg_loss_masks=pg_loss_masks,
default_reducer=sum_of_sample_mean,
)

pg_loss = pg_loss_reducer(pg_loss)
pg_clipfrac = sum_of_sample_mean(pg_clipfrac)
1 change: 1 addition & 0 deletions slime/backends/megatron_utils/model.py
@@ -593,6 +593,7 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p
"rollout_log_probs",
"teacher_log_probs",
"rollout_mask_sums",
"prompt_mask_sums",
],
),
args.data_pad_size_multiplier,
25 changes: 20 additions & 5 deletions slime/ray/rollout.py
@@ -95,11 +95,12 @@ def _tensorize_rollout_data_for_training(rollout_data: dict[str, Any]) -> None:
for mm_dict in rollout_data["multimodal_train_inputs"]
]

if "rollout_mask_sums" in rollout_data:
rollout_data["rollout_mask_sums"] = _cpu_tensor(
rollout_data["rollout_mask_sums"],
dtype=torch.float32,
)
for mask_sums_key in ("rollout_mask_sums", "prompt_mask_sums"):
if mask_sums_key in rollout_data:
rollout_data[mask_sums_key] = _cpu_tensor(
rollout_data[mask_sums_key],
dtype=torch.float32,
)


@dataclasses.dataclass
@@ -777,6 +778,19 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl
rollout_total_mask[rid] = rollout_total_mask.get(rid, 0) + ms
train_data["rollout_mask_sums"] = [rollout_total_mask[rid] for rid in rollout_id_list]

if self.args.loss_aggregation == "prompt_mean":
    group_total_mask: dict[int, int] = {}
    for sample, ms in zip(samples, mask_sums_per_sample, strict=True):
        if sample.group_index is None:
            raise ValueError(
                "--loss-aggregation prompt_mean requires every Sample.group_index to be set, "
                "but a sample has group_index=None. prompt_mean divides each sample by its "
                "prompt group's total mask; a None group_index means the sample belongs to no "
                "prompt group, so its denominator is undefined."
            )
        group_total_mask[sample.group_index] = group_total_mask.get(sample.group_index, 0) + ms
    train_data["prompt_mask_sums"] = [group_total_mask[sample.group_index] for sample in samples]
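
The grouping logic above can be exercised in isolation with a small stand-in. `ToySample` and the `prompt_mask_sums` function below are hypothetical and only mimic the two fields the hunk relies on (`group_index` and a per-sample mask total); they are not slime's `Sample` class or conversion code.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToySample:  # hypothetical stand-in for slime's Sample
    group_index: Optional[int]
    loss_mask: list


def prompt_mask_sums(samples):
    """Return, for each sample, the total mask count of its prompt group."""
    totals = {}
    for s in samples:
        if s.group_index is None:
            raise ValueError("prompt_mean needs group_index on every sample")
        totals[s.group_index] = totals.get(s.group_index, 0) + sum(s.loss_mask)
    # One entry per sample, so samples of the same group share a denominator.
    return [totals[s.group_index] for s in samples]


batch = [ToySample(0, [1, 1]), ToySample(1, [1, 0, 1]), ToySample(0, [1, 1, 1])]
sums = prompt_mask_sums(batch)
```

The result has one entry per sample rather than one per group, which matches how the denominators are later zipped against per-sample loss chunks.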

# Overwrite raw_reward when available. Mixed-source batches may only
# populate this field for a subset of samples (e.g. SWE but not code).
if any(sample.metadata and "raw_reward" in sample.metadata for sample in samples):
@@ -866,6 +880,7 @@ def _split_train_data_by_dp(self, data):
"sample_indices",
"rollout_ids",
"rollout_mask_sums",
"prompt_mask_sums",
"rollout_log_probs",
"rollout_top_p_token_ids",
"rollout_top_p_token_offsets",