Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/user-guide/cli-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ Sections mirror the launch-script argument groups.
| `--use-tis` | flag | off | Truncated Importance Sampling. |
| `--use-routing-replay` | flag | off | Forward/backward routing consistency. |
| `--use-rollout-routing-replay` | flag | off | R3 — capture inference-side expert routing and replay it during training. |
| `--calculate-per-token-loss` | flag | off | Per-token loss reduction. |
| `--calculate-per-token-loss` | flag | off | Legacy alias for `--loss-aggregation token_mean` (global per-token mean); kept for backward compatibility. Prefer `--loss-aggregation token_mean`. |
| `--loss-aggregation` | enum | `sample_mean` | How pg_loss is aggregated (pg_loss only): `sample_mean` (GRPO), `prompt_mean` (DAPO), `token_mean` (= legacy `--calculate-per-token-loss`), `constant` (Dr.GRPO). See [customization](/user-guide/customization). |
| `--loss-aggregation-divisor` | float | unset | Constant divisor `L` for `--loss-aggregation constant`; required and validated `> 0`. Combines with the standard `/ global_batch_size` step average for an effective `/(L * global_batch_size)`; incompatible with `--calculate-per-token-loss`. |
| `--no-check-for-nan-in-loss-and-grad` | flag | off | Skip NaN/Inf guard (Megatron flag, debug only). |
| `--true-on-policy-mode` | flag | off | Strict on-policy: reject samples from a prior policy. |

Expand Down
61 changes: 60 additions & 1 deletion docs/user-guide/customization.md
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,65 @@ def get_pg_loss_reducer(
...
```

Use case: Dr.GRPO divides by a constant instead of effective token count.
Use case: a normalization not covered by `--loss-aggregation` below. The four
standard modes are available first class; reach for this hook only for a custom
reducer. When set, it takes precedence over `--loss-aggregation`.
**Reference:** [`examples/DrGRPO/custom_reducer.py`](https://github.com/radixark/miles/blob/main/examples/DrGRPO/custom_reducer.py).

### `--loss-aggregation`

`--loss-aggregation {sample_mean,prompt_mean,token_mean,constant}` selects how
pg_loss is aggregated across a training step (pg_loss only; every other metric -
`pg_clipfrac`, `ppo_kl`, `entropy_loss`, `kl_loss` - keeps the default sample-mean
reducer, the same scope as the custom hook above). Modes follow the ScaleRL
taxonomy ([arXiv:2510.13786](https://arxiv.org/abs/2510.13786) section 3.2):

| Mode | Paper | pg_loss denominator |
| :--- | :--- | :--- |
| `sample_mean` (default) | GRPO sample average | Per-rollout token-weighted mean; each rollout contributes equally regardless of fan-out. Same behavior as the prior default. |
| `prompt_mean` | DAPO prompt average | Per-prompt-group token-weighted mean (all rollouts sharing a `Sample.group_index` share one denominator). ScaleRL's recommended default for new recipes. |
| `token_mean` | token average | Global per-token mean. This is the same objective as the legacy `--calculate-per-token-loss` flag (see below); prefer `--loss-aggregation token_mean`. |
| `constant` | Dr.GRPO ([arXiv:2503.20783](https://arxiv.org/abs/2503.20783)) | `sum(token_loss * loss_mask) / L`, where `L = --loss-aggregation-divisor` (e.g. the max context length). |

`--calculate-per-token-loss` is the legacy spelling of `--loss-aggregation
token_mean`: both select the global per-token mean. It is kept for backward
compatibility (existing Megatron-style recipes), but new recipes should prefer
`--loss-aggregation token_mean`. The two spellings are reconciled onto one axis at
startup: `--calculate-per-token-loss` alone reports `loss_aggregation=token_mean`,
and either spelling alone is accepted. They
may not be combined with a *different* mode: `prompt_mean` or `constant` together
with `--calculate-per-token-loss` is rejected at startup (per-token loss would
renormalize by the token count and undo that mode's denominator).

`--loss-aggregation-divisor L` is required (validated `> 0` at startup) only for
`constant`; every other mode rejects a configured divisor at startup.

Every mode shares the same outer structure: each step's pg_loss sum is divided by
`global_batch_size` (the standard step average). The per-mode denominator above is
the *inner* per-sample scale. For `constant`, the effective denominator is therefore
`L * global_batch_size`: `L` sets the data-independent per-token scale (so loss is
length-unbiased, Dr.GRPO's point) and the `/ global_batch_size` step average is
identical to every other mode. Pick `L` on the order of the max response length to
keep the loss magnitude comparable to the data-dependent modes.

`prompt_mean` weights every prompt group equally: each group's token-weighted
mean contributes once, and the final scalar is the mean over prompt groups. It
requires `global_batch_size` to be a multiple of `n_samples_per_prompt`, so a
contiguous train step keeps each prompt group whole instead of normalizing
against a partially-present group total.
Prompt groups also stay whole when splitting train data across data-parallel
ranks; if the number of prompt groups in a train step is not divisible by the
data-parallel size, partitioning fails before training instead of creating a
partial prompt group on a rank.

For custom train-data converters, `prompt_mean` should mirror the built-in
converter and emit both `prompt_mask_sums` and `prompt_group_indices`.
`prompt_mask_sums` is required for the standard full-step denominator, and
`prompt_group_indices` lets the reducer rebuild denominators from the current
pg_loss masks when TIS/RS changes the active mask. A custom
`--custom-convert-samples-to-train-data-path` that omits the required
`prompt_mean` fields will fail before pg_loss reduction.

### `--custom-convert-samples-to-train-data-path`

```python
Expand All @@ -211,6 +267,9 @@ def convert_samples_to_train_data(args, samples) -> dict:
"metadata": [...],
"multimodal_train_inputs": [...],
"teacher_log_probs": [...],
# required when args.loss_aggregation == "prompt_mean"
"prompt_group_indices": [...],
"prompt_mask_sums": [...],
}
```

Expand Down
2 changes: 2 additions & 0 deletions miles/backends/experimental/fsdp_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,8 @@ def _train_core(self, rollout_id: int, rollout_data) -> None:
"returns",
"ref_log_probs",
"rollout_log_probs",
"prompt_group_indices",
"prompt_mask_sums",
],
self.args.data_pad_size_multiplier,
self.args.qkv_format,
Expand Down
2 changes: 2 additions & 0 deletions miles/backends/megatron_utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,8 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p
"rollout_log_probs",
"max_seq_lens",
"opd_reverse_kl",
"prompt_group_indices",
"prompt_mask_sums",
],
args.data_pad_size_multiplier,
args.qkv_format,
Expand Down
35 changes: 27 additions & 8 deletions miles/backends/training_utils/cp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,19 +96,31 @@ def get_sum_of_sample_mean(
calculate_per_token_loss: bool = False,
qkv_format: str = "thd",
max_seq_lens: list[int] | None = None,
sample_denoms: list[torch.Tensor] | None = None,
constant_divisor: float | None = None,
) -> Callable[[torch.Tensor], torch.Tensor]:
"""
Calculate correct sample mean for CP
"""
"""Build a CP-aware reducer for masked token losses."""
if constant_divisor is not None and calculate_per_token_loss:
raise ValueError(
"constant_divisor (loss-aggregation=constant) and calculate_per_token_loss "
"(loss-aggregation=token_mean) are mutually exclusive aggregation modes."
)
if constant_divisor is not None and constant_divisor <= 0:
raise ValueError("constant_divisor (loss-aggregation=constant) must be positive.")
if sample_denoms is None and constant_divisor is None and not calculate_per_token_loss:
Comment thread
coderabbitai[bot] marked this conversation as resolved.
sample_denoms = [loss_mask.sum() for loss_mask in loss_masks]

parallel_state = get_parallel_state()
cp_size = parallel_state.cp.size
if cp_size == 1:

def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor:
return sum(
[
(x_i * loss_mask_i).sum() / torch.clamp_min(loss_mask_i.sum(), 1)
for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=True)
(x_i * loss_mask_i).sum() / torch.clamp_min(denom, 1)
for x_i, loss_mask_i, denom in zip(
x.split(response_lengths, dim=0), loss_masks, sample_denoms, strict=True
)
]
)

Expand All @@ -135,9 +147,9 @@ def sum_of_token(x: torch.Tensor) -> torch.Tensor:
def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor:
return sum(
[
(x_i * chunked_loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1)
for x_i, chunked_loss_mask, loss_mask in zip(
x.split(cp_chunk_lengths, dim=0), chunked_loss_masks, loss_masks, strict=True
(x_i * chunked_loss_mask).sum() / torch.clamp_min(denom, 1)
for x_i, chunked_loss_mask, denom in zip(
x.split(cp_chunk_lengths, dim=0), chunked_loss_masks, sample_denoms, strict=True
)
]
)
Expand All @@ -152,6 +164,13 @@ def sum_of_token(x: torch.Tensor) -> torch.Tensor:
]
)

if constant_divisor is not None:

def sum_of_constant(x: torch.Tensor) -> torch.Tensor:
return sum_of_token(x) / constant_divisor

return sum_of_constant

return sum_of_sample_mean if not calculate_per_token_loss else sum_of_token


Expand Down
4 changes: 4 additions & 0 deletions miles/backends/training_utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ def get_rollout_data(args: Namespace, rollout_data_ref: Box) -> RolloutBatch:
rollout_data["loss_masks"] = [
torch.tensor(t, dtype=torch.int, device=torch.cuda.current_device()) for t in rollout_data["loss_masks"]
]
if "prompt_mask_sums" in rollout_data:
rollout_data["prompt_mask_sums"] = list(
torch.tensor(rollout_data["prompt_mask_sums"], dtype=torch.float32, device=torch.cuda.current_device())
)
if "multimodal_train_inputs" in rollout_data:
# Move multimodal training tensors to GPU in advance
rollout_data["multimodal_train_inputs"] = [
Expand Down
36 changes: 33 additions & 3 deletions miles/backends/training_utils/log_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,23 +388,53 @@ def aggregate_train_losses(
return {}

keys = losses_reduced[0]["keys"]
has_normalizers = "normalizers" in losses_reduced[0]

values = None
normalizers = None
for log_dict in losses_reduced:
if log_dict["keys"] != keys:
raise ValueError(
"All train loss log_dict entries must have the same keys in the same order; "
f"got {log_dict['keys']} after {keys}."
)
if ("normalizers" in log_dict) != has_normalizers:
raise ValueError(
"Cannot aggregate a mix of train loss log_dict entries with and without explicit normalizers."
)
if values is None:
values = log_dict["values"].clone()
else:
values += log_dict["values"]
if "normalizers" in log_dict:
if normalizers is None:
normalizers = log_dict["normalizers"].clone()
else:
normalizers += log_dict["normalizers"]

assert len(keys) + 1 == values.numel(), f"Expected {len(keys) + 1} values, got {values.numel()}"
if normalizers is None:
if len(keys) + 1 != values.numel():
raise ValueError(f"Expected {len(keys) + 1} values, got {values.numel()}")
else:
if len(keys) != values.numel():
raise ValueError(f"Expected {len(keys)} values, got {values.numel()}")
if len(keys) != normalizers.numel():
raise ValueError(f"Expected {len(keys)} normalizers, got {normalizers.numel()}")

dist.all_reduce(values, op=dist.ReduceOp.SUM, group=parallel_state.intra_dp_cp.group)
if normalizers is not None:
dist.all_reduce(normalizers, op=dist.ReduceOp.SUM, group=parallel_state.intra_dp_cp.group)

loss_reduced = {}
values = values.tolist()
num_samples_or_tokens = values[0]
if normalizers is not None:
normalizers = normalizers.tolist()
for key, value, normalizer in zip(keys, values, normalizers, strict=True):
loss_reduced[key] = value * parallel_state.cp.size / normalizer
Comment thread
coderabbitai[bot] marked this conversation as resolved.
return loss_reduced

for key, value in zip(keys, values[1:], strict=False):
num_samples_or_tokens = values[0]
for key, value in zip(keys, values[1:], strict=True):
loss_reduced[key] = value * parallel_state.cp.size / num_samples_or_tokens

return loss_reduced
Expand Down
101 changes: 87 additions & 14 deletions miles/backends/training_utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,69 @@
from miles.backends.training_utils.parallel import get_parallel_state
from miles.utils.types import RolloutBatch

TOKEN_NORMALIZED_TRAIN_KEYS = frozenset({"loss", "pg_loss", "ess_ratio", "value_loss"})


def _as_scalar_tensor(value, *, device: torch.device) -> torch.Tensor:
if isinstance(value, torch.Tensor):
return value.to(device=device).reshape(())
return torch.tensor(value, device=device)


def _build_train_log_dict(
log: dict[str, torch.Tensor],
*,
num_samples: int,
num_tokens: torch.Tensor,
device: torch.device,
calculate_per_token_loss: bool,
) -> dict[str, list[str] | torch.Tensor]:
keys = list(log.keys())
values = torch.stack([_as_scalar_tensor(value, device=device) for value in log.values()])
if not calculate_per_token_loss:
return {
"keys": keys,
"values": torch.cat([torch.tensor([num_samples], device=device), values]),
}

normalizers = torch.stack(
[
(
num_tokens.to(device=device).reshape(())
if key in TOKEN_NORMALIZED_TRAIN_KEYS
else torch.tensor(num_samples, device=device)
)
for key in keys
]
)
return {
"keys": keys,
"values": values,
"normalizers": normalizers,
}


def _validate_loss_aggregation_contract(args: Namespace) -> None:
if getattr(args, "loss_type", None) != "policy_loss":
return
mode = getattr(args, "loss_aggregation", None)
token_mean = mode == "token_mean" or (mode is None and getattr(args, "calculate_per_token_loss", False))
if not token_mean:
return

mixed_terms = []
if getattr(args, "entropy_coef", 0) != 0:
mixed_terms.append("--entropy-coef")
if getattr(args, "use_kl_loss", False) and getattr(args, "kl_loss_coef", 0) != 0:
mixed_terms.append("--kl-loss-coef")
if mixed_terms:
raise ValueError(
"--loss-aggregation token_mean cannot be combined with "
f"{', '.join(mixed_terms)} because the policy-gradient term is token-normalized "
"while auxiliary policy-loss terms are sample-normalized. Use sample_mean, "
"set the auxiliary coefficient to 0, or add explicit per-term loss normalizers."
)


def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) -> None:
"""Compute advantages and returns in-place based on `args.advantage_estimator`.
Expand Down Expand Up @@ -122,14 +185,21 @@ def loss_function(
parallel_state = get_parallel_state()
num_tokens = sum([torch.clamp_min(loss_mask.sum(), 1) for loss_mask in batch["loss_masks"]])
num_samples = len(batch["response_lengths"])
_validate_loss_aggregation_contract(args)

# Policy loss selects pg_loss aggregation separately; this shared reducer is
# for metrics and auxiliary terms. Non-policy losses use the legacy reducer
# axis directly.
reducer_per_token_loss = (
args.calculate_per_token_loss and getattr(args, "loss_type", "policy_loss") != "policy_loss"
)
sum_of_sample_mean = get_sum_of_sample_mean(
batch["total_lengths"],
batch["response_lengths"],
batch["loss_masks"],
args.calculate_per_token_loss,
args.qkv_format,
batch.get("max_seq_lens", None),
calculate_per_token_loss=reducer_per_token_loss,
qkv_format=args.qkv_format,
max_seq_lens=batch.get("max_seq_lens", None),
Comment thread
coderabbitai[bot] marked this conversation as resolved.
)

func = get_loss_function(args)
Expand Down Expand Up @@ -166,17 +236,20 @@ def loss_function(
if apply_megatron_loss_scaling:
loss = loss * parallel_state.cp.size

log_dict = _build_train_log_dict(
log,
num_samples=num_samples,
num_tokens=num_tokens,
device=logits.device,
calculate_per_token_loss=args.calculate_per_token_loss,
)

return (
loss,
torch.tensor(num_tokens if args.calculate_per_token_loss else 1, device=logits.device),
{
"keys": list(log.keys()),
"values": torch.tensor(
[
num_samples if not args.calculate_per_token_loss else num_tokens,
]
+ list(log.values()),
device=logits.device,
),
},
(
num_tokens.to(device=logits.device)
if args.calculate_per_token_loss
else torch.tensor(1, device=logits.device)
),
log_dict,
)
Loading
Loading