From f73fc1f112641fc7b2167bc4b28c61d5eaf4c7a1 Mon Sep 17 00:00:00 2001 From: EazyReal Date: Tue, 16 Jun 2026 07:52:36 +0000 Subject: [PATCH 1/2] feat(loss): support pg_loss aggregation modes --- docs/user-guide/cli-reference.md | 4 +- docs/user-guide/customization.md | 57 +- .../backends/experimental/fsdp_utils/actor.py | 2 + miles/backends/megatron_utils/model.py | 2 + miles/backends/training_utils/cp_utils.py | 33 +- miles/backends/training_utils/data.py | 4 + miles/backends/training_utils/log_utils.py | 21 +- miles/backends/training_utils/loss.py | 45 +- .../training_utils/loss_hub/losses.py | 128 +++- miles/ray/rollout/train_data_conversion.py | 13 + miles/utils/arguments.py | 72 +++ .../training_utils/loss/test_loss_snapshot.py | 6 +- .../training_utils/test_loss_aggregation.py | 592 ++++++++++++++++++ 13 files changed, 943 insertions(+), 36 deletions(-) create mode 100644 tests/fast/backends/training_utils/test_loss_aggregation.py diff --git a/docs/user-guide/cli-reference.md b/docs/user-guide/cli-reference.md index 30487b5a3d..b9e915b5e8 100644 --- a/docs/user-guide/cli-reference.md +++ b/docs/user-guide/cli-reference.md @@ -234,7 +234,9 @@ Sections mirror the launch-script argument groups. | `--use-tis` | flag | off | Truncated Importance Sampling. | | `--use-routing-replay` | flag | off | Forward/backward routing consistency. | | `--use-rollout-routing-replay` | flag | off | R3 — capture inference-side expert routing and replay it during training. | -| `--calculate-per-token-loss` | flag | off | Per-token loss reduction. | +| `--calculate-per-token-loss` | flag | off | Legacy alias for `--loss-aggregation token_mean` (global per-token mean); kept for backward compatibility. Prefer `--loss-aggregation token_mean`. | +| `--loss-aggregation` | enum | `sample_mean` | How pg_loss is aggregated (pg_loss only): `sample_mean` (GRPO), `prompt_mean` (DAPO), `token_mean` (= legacy `--calculate-per-token-loss`), `constant` (Dr.GRPO). See [customization](/user-guide/customization). | +| `--loss-aggregation-divisor` | float | unset | Constant divisor `L` for `--loss-aggregation constant`; required and validated `> 0`. Combines with the standard `/ global_batch_size` step average for an effective `/(L * global_batch_size)`; incompatible with `--calculate-per-token-loss`. | | `--no-check-for-nan-in-loss-and-grad` | flag | off | Skip NaN/Inf guard (Megatron flag, debug only). | | `--true-on-policy-mode` | flag | off | Strict on-policy: reject samples from a prior policy. | diff --git a/docs/user-guide/customization.md b/docs/user-guide/customization.md index 9968712fc1..cf5a985047 100644 --- a/docs/user-guide/customization.md +++ b/docs/user-guide/customization.md @@ -189,9 +189,61 @@ def get_pg_loss_reducer( ... ``` -Use case: Dr.GRPO divides by a constant instead of effective token count. +Use case: a normalization not covered by `--loss-aggregation` below. The four +standard modes are available first class; reach for this hook only for a custom +reducer. When set, it takes precedence over `--loss-aggregation`. **Reference:** [`examples/DrGRPO/custom_reducer.py`](https://github.com/radixark/miles/blob/main/examples/DrGRPO/custom_reducer.py). +### `--loss-aggregation` + +`--loss-aggregation {sample_mean,prompt_mean,token_mean,constant}` selects how +pg_loss is aggregated across a training step (pg_loss only; every other metric - +`pg_clipfrac`, `ppo_kl`, `entropy_loss`, `kl_loss` - keeps the default sample-mean +reducer, the same scope as the custom hook above). Modes follow the ScaleRL +taxonomy ([arXiv:2510.13786](https://arxiv.org/abs/2510.13786) section 3.2): + +| Mode | Paper | pg_loss denominator | +| :--- | :--- | :--- | +| `sample_mean` (default) | GRPO sample average | Per-rollout token-weighted mean; each rollout contributes equally regardless of fan-out. Same behavior as the prior default. | +| `prompt_mean` | DAPO prompt average | Per-prompt-group token-weighted mean (all rollouts sharing a `Sample.group_index` share one denominator). ScaleRL's recommended default for new recipes. | +| `token_mean` | token average | Global per-token mean. This is the same objective as the legacy `--calculate-per-token-loss` flag (see below); prefer `--loss-aggregation token_mean`. | +| `constant` | Dr.GRPO ([arXiv:2503.20783](https://arxiv.org/abs/2503.20783)) | `sum(token_loss * loss_mask) / L`, where `L = --loss-aggregation-divisor` (e.g. the max context length). | + +`--calculate-per-token-loss` is the legacy spelling of `--loss-aggregation +token_mean`: both select the global per-token mean. It is kept for backward +compatibility (existing Megatron-style recipes), but new recipes should prefer +`--loss-aggregation token_mean`. The two spellings are reconciled onto one axis at +startup: `--calculate-per-token-loss` alone reports `loss_aggregation=token_mean`, +and either spelling alone is accepted. They +may not be combined with a *different* mode: `prompt_mean` or `constant` together +with `--calculate-per-token-loss` is rejected at startup (per-token loss would +renormalize by the token count and undo that mode's denominator). + +`--loss-aggregation-divisor L` is required (validated `> 0` at startup) only for +`constant`; it is ignored for the other modes. + +Every mode shares the same outer structure: each step's pg_loss sum is divided by +`global_batch_size` (the standard step average). The per-mode denominator above is +the *inner* per-sample scale. For `constant`, the effective denominator is therefore +`L * global_batch_size`: `L` sets the data-independent per-token scale (so loss is +length-unbiased, Dr.GRPO's point) and the `/ global_batch_size` step average is +identical to every other mode. Pick `L` on the order of the max response length to +keep the loss magnitude comparable to the data-dependent modes. + +`prompt_mean` weights every prompt group equally: each group's token-weighted +mean contributes once, and the final scalar is the mean over prompt groups. It +requires `global_batch_size` to be a multiple of `n_samples_per_prompt`, so a +contiguous train step keeps each prompt group whole instead of normalizing +against a partially-present group total. + +For custom train-data converters, `prompt_mean` should mirror the built-in +converter and emit both `prompt_mask_sums` and `prompt_group_indices`. +`prompt_mask_sums` is required for the standard full-step denominator, and +`prompt_group_indices` lets the reducer rebuild denominators from the current +pg_loss masks when TIS/RS changes the active mask. A custom +`--custom-convert-samples-to-train-data-path` that omits the required +`prompt_mean` fields will fail before pg_loss reduction. + ### `--custom-convert-samples-to-train-data-path` ```python @@ -211,6 +263,9 @@ def convert_samples_to_train_data(args, samples) -> dict: "metadata": [...], "multimodal_train_inputs": [...], "teacher_log_probs": [...], + # required when args.loss_aggregation == "prompt_mean" + "prompt_group_indices": [...], + "prompt_mask_sums": [...], } ``` diff --git a/miles/backends/experimental/fsdp_utils/actor.py b/miles/backends/experimental/fsdp_utils/actor.py index 98b8b37f0e..064b579d92 100644 --- a/miles/backends/experimental/fsdp_utils/actor.py +++ b/miles/backends/experimental/fsdp_utils/actor.py @@ -465,6 +465,8 @@ def _train_core(self, rollout_id: int, rollout_data) -> None: "returns", "ref_log_probs", "rollout_log_probs", + "prompt_group_indices", + "prompt_mask_sums", ], self.args.data_pad_size_multiplier, self.args.qkv_format, diff --git a/miles/backends/megatron_utils/model.py b/miles/backends/megatron_utils/model.py index 44fdaeedcd..e224cbabfa 100644 --- a/miles/backends/megatron_utils/model.py +++ b/miles/backends/megatron_utils/model.py @@ -417,6 +417,8 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p "rollout_log_probs", "max_seq_lens", "opd_reverse_kl", + "prompt_group_indices", + "prompt_mask_sums", ], args.data_pad_size_multiplier, args.qkv_format, diff --git a/miles/backends/training_utils/cp_utils.py b/miles/backends/training_utils/cp_utils.py index 28b8052dc9..9fd257b88f 100644 --- a/miles/backends/training_utils/cp_utils.py +++ b/miles/backends/training_utils/cp_utils.py @@ -96,10 +96,18 @@ def get_sum_of_sample_mean( calculate_per_token_loss: bool = False, qkv_format: str = "thd", max_seq_lens: list[int] | None = None, + sample_denoms: list[torch.Tensor] | None = None, + constant_divisor: float | None = None, ) -> Callable[[torch.Tensor], torch.Tensor]: - """ - Calculate correct sample mean for CP - """ + """Build a CP-aware reducer for masked token losses.""" + if constant_divisor is not None and calculate_per_token_loss: + raise ValueError( + "constant_divisor (loss-aggregation=constant) and calculate_per_token_loss " + "(loss-aggregation=token_mean) are mutually exclusive aggregation modes." + ) + if sample_denoms is None and constant_divisor is None and not calculate_per_token_loss: + sample_denoms = [loss_mask.sum() for loss_mask in loss_masks] + parallel_state = get_parallel_state() cp_size = parallel_state.cp.size if cp_size == 1: @@ -107,8 +115,10 @@ def get_sum_of_sample_mean( def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor: return sum( [ - (x_i * loss_mask_i).sum() / torch.clamp_min(loss_mask_i.sum(), 1) - for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=True) + (x_i * loss_mask_i).sum() / torch.clamp_min(denom, 1) + for x_i, loss_mask_i, denom in zip( + x.split(response_lengths, dim=0), loss_masks, sample_denoms, strict=True + ) ] ) @@ -135,9 +145,9 @@ def sum_of_token(x: torch.Tensor) -> torch.Tensor: def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor: return sum( [ - (x_i * chunked_loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1) - for x_i, chunked_loss_mask, loss_mask in zip( - x.split(cp_chunk_lengths, dim=0), chunked_loss_masks, loss_masks, strict=True + (x_i * chunked_loss_mask).sum() / torch.clamp_min(denom, 1) + for x_i, chunked_loss_mask, denom in zip( + x.split(cp_chunk_lengths, dim=0), chunked_loss_masks, sample_denoms, strict=True ) ] ) @@ -152,6 +162,13 @@ def sum_of_token(x: torch.Tensor) -> torch.Tensor: ] ) + if constant_divisor is not None: + + def sum_of_constant(x: torch.Tensor) -> torch.Tensor: + return sum_of_token(x) / constant_divisor + + return sum_of_constant + return sum_of_sample_mean if not calculate_per_token_loss else sum_of_token diff --git a/miles/backends/training_utils/data.py b/miles/backends/training_utils/data.py index a43e0aecfb..d5c2513850 100644 --- a/miles/backends/training_utils/data.py +++ b/miles/backends/training_utils/data.py @@ -45,6 +45,10 @@ def get_rollout_data(args: Namespace, rollout_data_ref: Box) -> RolloutBatch: rollout_data["loss_masks"] = [ torch.tensor(t, dtype=torch.int, device=torch.cuda.current_device()) for t in rollout_data["loss_masks"] ] + if "prompt_mask_sums" in rollout_data: + rollout_data["prompt_mask_sums"] = list( + torch.tensor(rollout_data["prompt_mask_sums"], dtype=torch.float32, device=torch.cuda.current_device()) + ) if "multimodal_train_inputs" in rollout_data: # Move multimodal training tensors to GPU in advance rollout_data["multimodal_train_inputs"] = [ diff --git a/miles/backends/training_utils/log_utils.py b/miles/backends/training_utils/log_utils.py index f660f9026c..b6214b70cc 100644 --- a/miles/backends/training_utils/log_utils.py +++ b/miles/backends/training_utils/log_utils.py @@ -390,20 +390,37 @@ def aggregate_train_losses( keys = losses_reduced[0]["keys"] values = None + normalizers = None for log_dict in losses_reduced: if values is None: values = log_dict["values"].clone() else: values += log_dict["values"] + if "normalizers" in log_dict: + if normalizers is None: + normalizers = log_dict["normalizers"].clone() + else: + normalizers += log_dict["normalizers"] - assert len(keys) + 1 == values.numel(), f"Expected {len(keys) + 1} values, got {values.numel()}" + if normalizers is None: + assert len(keys) + 1 == values.numel(), f"Expected {len(keys) + 1} values, got {values.numel()}" + else: + assert len(keys) == values.numel(), f"Expected {len(keys)} values, got {values.numel()}" + assert len(keys) == normalizers.numel(), f"Expected {len(keys)} normalizers, got {normalizers.numel()}" dist.all_reduce(values, op=dist.ReduceOp.SUM, group=parallel_state.intra_dp_cp.group) + if normalizers is not None: + dist.all_reduce(normalizers, op=dist.ReduceOp.SUM, group=parallel_state.intra_dp_cp.group) loss_reduced = {} values = values.tolist() - num_samples_or_tokens = values[0] + if normalizers is not None: + normalizers = normalizers.tolist() + for key, value, normalizer in zip(keys, values, normalizers, strict=False): + loss_reduced[key] = value * parallel_state.cp.size / normalizer + return loss_reduced + num_samples_or_tokens = values[0] for key, value in zip(keys, values[1:], strict=False): loss_reduced[key] = value * parallel_state.cp.size / num_samples_or_tokens diff --git a/miles/backends/training_utils/loss.py b/miles/backends/training_utils/loss.py index c693febeba..7da16e0160 100644 --- a/miles/backends/training_utils/loss.py +++ b/miles/backends/training_utils/loss.py @@ -123,13 +123,14 @@ def loss_function( num_tokens = sum([torch.clamp_min(loss_mask.sum(), 1) for loss_mask in batch["loss_masks"]]) num_samples = len(batch["response_lengths"]) + # --loss-aggregation applies to pg_loss only; metrics keep this reducer. sum_of_sample_mean = get_sum_of_sample_mean( batch["total_lengths"], batch["response_lengths"], batch["loss_masks"], - args.calculate_per_token_loss, - args.qkv_format, - batch.get("max_seq_lens", None), + calculate_per_token_loss=False, + qkv_format=args.qkv_format, + max_seq_lens=batch.get("max_seq_lens", None), ) func = get_loss_function(args) @@ -166,17 +167,33 @@ def loss_function( if apply_megatron_loss_scaling: loss = loss * parallel_state.cp.size - return ( - loss, - torch.tensor(num_tokens if args.calculate_per_token_loss else 1, device=logits.device), - { - "keys": list(log.keys()), - "values": torch.tensor( - [ - num_samples if not args.calculate_per_token_loss else num_tokens, - ] - + list(log.values()), + log_keys = list(log.keys()) + log_dict = { + "keys": log_keys, + "values": torch.tensor( + [ + num_samples, + ] + + list(log.values()), + device=logits.device, + ), + } + if args.calculate_per_token_loss: + log_dict = { + "keys": log_keys, + "values": torch.tensor(list(log.values()), device=logits.device), + "normalizers": torch.tensor( + [num_tokens if key in {"loss", "pg_loss", "ess_ratio"} else num_samples for key in log_keys], device=logits.device, ), - }, + } + + return ( + loss, + ( + num_tokens.to(device=logits.device) + if args.calculate_per_token_loss + else torch.tensor(1, device=logits.device) + ), + log_dict, ) diff --git a/miles/backends/training_utils/loss_hub/losses.py b/miles/backends/training_utils/loss_hub/losses.py index cc070e5dcd..66d80edbfe 100644 --- a/miles/backends/training_utils/loss_hub/losses.py +++ b/miles/backends/training_utils/loss_hub/losses.py @@ -59,6 +59,113 @@ def __call__( ... +def _prompt_group_mask_sums( + prompt_group_indices: list[int], + pg_loss_masks: list[torch.Tensor], + *, + expected_group_size: int | None = None, +) -> list[torch.Tensor]: + if len(prompt_group_indices) != len(pg_loss_masks): + raise ValueError( + "--loss-aggregation prompt_mean requires one prompt_group_indices entry " + f"per sample; got {len(prompt_group_indices)} group ids and {len(pg_loss_masks)} masks." + ) + + group_denoms: dict[int, torch.Tensor] = {} + group_counts: dict[int, int] = {} + group_order: list[int] = [] + for group_index, loss_mask in zip(prompt_group_indices, pg_loss_masks, strict=True): + group_key = int(group_index.item()) if isinstance(group_index, torch.Tensor) else int(group_index) + group_order.append(group_key) + group_counts[group_key] = group_counts.get(group_key, 0) + 1 + group_denoms[group_key] = group_denoms.get( + group_key, loss_mask.new_zeros((), dtype=torch.float32) + ) + loss_mask.sum().to(dtype=torch.float32) + + if expected_group_size is not None: + partial_groups = {key: count for key, count in group_counts.items() if count != expected_group_size} + if partial_groups: + raise ValueError( + "--loss-aggregation prompt_mean with modified pg_loss masks requires complete prompt groups " + f"in each local batch; expected {expected_group_size} samples per group, got {partial_groups}." + ) + + return [group_denoms[group_key] for group_key in group_order] + + +def get_pg_loss_reducer( + args: Namespace, + batch: RolloutBatch, + *, + total_lengths: list[int], + response_lengths: list[int], + pg_loss_masks: list[torch.Tensor], + max_seq_lens: list[int] | None, + default_reducer: Callable[[torch.Tensor], torch.Tensor], +) -> Callable[[torch.Tensor], torch.Tensor]: + """Select the pg_loss reducer for ``--loss-aggregation``.""" + mode = getattr(args, "loss_aggregation", None) + if mode is None: + mode = "token_mean" if getattr(args, "calculate_per_token_loss", False) else "sample_mean" + if mode == "sample_mean": + return default_reducer + if mode == "token_mean": + return get_sum_of_sample_mean( + total_lengths, + response_lengths, + pg_loss_masks, + calculate_per_token_loss=True, + qkv_format=args.qkv_format, + max_seq_lens=max_seq_lens, + ) + if mode == "prompt_mean": + prompt_mask_sums = batch.get("prompt_mask_sums") + if prompt_mask_sums is None: + raise ValueError( + "--loss-aggregation prompt_mean requires per-prompt-group mask sums " + "(batch['prompt_mask_sums']), but they are missing. A custom " + "--custom-convert-samples-to-train-data-path must populate " + "'prompt_mask_sums' (grouped by Sample.group_index)." + ) + if pg_loss_masks is not batch.get("loss_masks"): + prompt_group_indices = batch.get("prompt_group_indices") + if prompt_group_indices is None: + raise ValueError( + "--loss-aggregation prompt_mean with modified pg_loss masks requires per-sample prompt " + "group ids (batch['prompt_group_indices']), but they are missing. A custom " + "--custom-convert-samples-to-train-data-path must populate 'prompt_group_indices' " + "(from Sample.group_index)." + ) + prompt_mask_sums = _prompt_group_mask_sums( + prompt_group_indices, + pg_loss_masks, + expected_group_size=args.n_samples_per_prompt, + ) + reducer = get_sum_of_sample_mean( + total_lengths, + response_lengths, + pg_loss_masks, + qkv_format=args.qkv_format, + max_seq_lens=max_seq_lens, + sample_denoms=prompt_mask_sums, + ) + + def prompt_mean_reducer(x: torch.Tensor) -> torch.Tensor: + return reducer(x) * args.n_samples_per_prompt + + return prompt_mean_reducer + if mode == "constant": + return get_sum_of_sample_mean( + total_lengths, + response_lengths, + pg_loss_masks, + qkv_format=args.qkv_format, + max_seq_lens=max_seq_lens, + constant_divisor=args.loss_aggregation_divisor, + ) + raise ValueError(f"Unknown --loss-aggregation mode: {mode!r}") + + def policy_loss_function( args: Namespace, batch: RolloutBatch, @@ -234,21 +341,28 @@ def policy_loss_function( total_lengths, response_lengths, modified_response_masks, - args.calculate_per_token_loss, - args.qkv_format, - max_seq_lens, + calculate_per_token_loss=args.calculate_per_token_loss, + qkv_format=args.qkv_format, + max_seq_lens=max_seq_lens, ) - # Determine pg_loss reducer: use custom if specified, otherwise default + pg_loss_masks = modified_response_masks if (args.get_mismatch_metrics or args.use_tis) else batch["loss_masks"] + if args.custom_pg_loss_reducer_function_path is not None: custom_pg_loss_reducer_func = load_function(args.custom_pg_loss_reducer_function_path) - # Determine which loss_masks to use for pg_loss reducer - pg_loss_masks = modified_response_masks if (args.get_mismatch_metrics or args.use_tis) else batch["loss_masks"] pg_loss_reducer = custom_pg_loss_reducer_func( total_lengths, response_lengths, pg_loss_masks, args.calculate_per_token_loss ) else: - pg_loss_reducer = sum_of_sample_mean + pg_loss_reducer = get_pg_loss_reducer( + args, + batch, + total_lengths=total_lengths, + response_lengths=response_lengths, + pg_loss_masks=pg_loss_masks, + max_seq_lens=max_seq_lens, + default_reducer=sum_of_sample_mean, + ) # ESS (Effective Sample Size) ratio from per-token IS weights # w = π_new/π_old = exp(-ppo_kl). A value of 1.0 is on-policy; near 0 diff --git a/miles/ray/rollout/train_data_conversion.py b/miles/ray/rollout/train_data_conversion.py index 65bc8d4b6d..cffdf07e32 100644 --- a/miles/ray/rollout/train_data_conversion.py +++ b/miles/ray/rollout/train_data_conversion.py @@ -55,6 +55,17 @@ def convert_samples_to_train_data( loss_masks.append(sample.loss_mask) train_data["loss_masks"] = loss_masks + if getattr(args, "loss_aggregation", "sample_mean") == "prompt_mean": + group_mask_totals: dict[int, int] = {} + prompt_group_indices: list[int] = [] + for sample, loss_mask in zip(samples, loss_masks, strict=True): + if sample.group_index is None: + raise ValueError("--loss-aggregation prompt_mean requires every Sample.group_index to be set.") + prompt_group_indices.append(sample.group_index) + group_mask_totals[sample.group_index] = group_mask_totals.get(sample.group_index, 0) + sum(loss_mask) + train_data["prompt_group_indices"] = prompt_group_indices + train_data["prompt_mask_sums"] = [group_mask_totals[group_index] for group_index in prompt_group_indices] + # overwriting the raw reward if samples[0].metadata and "raw_reward" in samples[0].metadata: train_data["raw_reward"] = [sample.metadata["raw_reward"] for sample in samples] @@ -158,6 +169,8 @@ def split_train_data_by_dp(args, data, dp_size): "teacher_log_probs", "opd_reverse_kl", "weight_versions", + "prompt_group_indices", + "prompt_mask_sums", ]: if key not in data: continue diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 344a59141d..35b858e697 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -1065,6 +1065,39 @@ def add_algo_arguments(parser): default=None, help="Path to a custom reducer function for pg_loss only. When set, pg_loss will use this custom reducer while other metrics (pg_clipfrac, ppo_kl, entropy_loss, etc.) still use the default sum_of_sample_mean. (e.g., examples/Dr.GRPO/custom_reducer.py:get_pg_loss_reducer).", ) + parser.add_argument( + "--loss-aggregation", + type=str, + default="sample_mean", + choices=["sample_mean", "prompt_mean", "token_mean", "constant"], + help=( + "How pg_loss is aggregated across the step (applies to pg_loss only; " + "pg_clipfrac, ppo_kl, entropy_loss, kl_loss keep the default sample-mean reducer " + "-- same scope as --custom-pg-loss-reducer-function-path, which still takes " + "precedence when set). Modes follow the ScaleRL taxonomy (arXiv:2510.13786 section " + "3.2): " + "'sample_mean' (default, GRPO): each sequence is averaged by its own active-token " + "count, then averaged across sequences (same behavior as the prior default). " + "'prompt_mean' (DAPO, ScaleRL's recommended default for new recipes): every token in " + "a prompt group weighs equally -- normalize by the group's total active tokens across " + "its completions, then average over prompt groups. " + "'token_mean': global per-token mean; the same objective as the legacy " + "--calculate-per-token-loss flag (prefer --loss-aggregation token_mean). " + "'constant' (Dr.GRPO, arXiv:2503.20783 Eq.2): divide the summed token loss by the " + "fixed --loss-aggregation-divisor (no data-dependent denominator)." + ), + ) + parser.add_argument( + "--loss-aggregation-divisor", + type=float, + default=None, + help=( + "Constant divisor L for --loss-aggregation constant (Dr.GRPO): the model's max " + "generation/context length (e.g. 40960). Required when --loss-aggregation=constant; " + "must be a positive number. No default is provided, so a missing value raises " + "instead of training with an arbitrary denominator." + ), + ) parser.add_argument( "--use-routing-replay", @@ -2008,6 +2041,43 @@ def _resolve_eval_datasets(args) -> list[EvalDatasetConfig]: return eval_datasets +def _validate_loss_aggregation_args(args): + """Reconcile --loss-aggregation with its legacy alias --calculate-per-token-loss.""" + if args.loss_aggregation == "constant": + if args.loss_aggregation_divisor is None or not args.loss_aggregation_divisor > 0: + raise ValueError( + "--loss-aggregation constant requires --loss-aggregation-divisor " + "(the model's max generation/context length, e.g. 40960), got " + f"{args.loss_aggregation_divisor!r}." + ) + elif args.loss_aggregation_divisor is not None: + raise ValueError( + "--loss-aggregation-divisor is only used with --loss-aggregation constant, " + f"but --loss-aggregation={args.loss_aggregation!r}." + ) + + if args.loss_aggregation == "token_mean": + args.calculate_per_token_loss = True + elif args.calculate_per_token_loss: + if args.loss_aggregation == "sample_mean": + args.loss_aggregation = "token_mean" + else: + raise ValueError( + f"--loss-aggregation {args.loss_aggregation} is incompatible with --calculate-per-token-loss " + "(use --loss-aggregation token_mean for the per-token global mean)." + ) + if ( + args.loss_aggregation == "prompt_mean" + and args.global_batch_size is not None + and args.global_batch_size % args.n_samples_per_prompt != 0 + ): + raise ValueError( + "--loss-aggregation prompt_mean requires global_batch_size to be a multiple of " + "n_samples_per_prompt so each prompt group stays within one training step " + f"(got global_batch_size={args.global_batch_size}, n_samples_per_prompt={args.n_samples_per_prompt})." + ) + + def miles_validate_args(args): args.eval_datasets = _resolve_eval_datasets(args) @@ -2333,6 +2403,8 @@ def miles_validate_args(args): ) args.global_batch_size = global_batch_size + _validate_loss_aggregation_args(args) + if args.n_samples_per_prompt == 1: args.grpo_std_normalization = False logger.info("n_samples_per_prompt is set to 1, grpo_std_normalization will be set to False.") diff --git a/tests/fast/backends/training_utils/loss/test_loss_snapshot.py b/tests/fast/backends/training_utils/loss/test_loss_snapshot.py index c4a7559c76..28399f751e 100644 --- a/tests/fast/backends/training_utils/loss/test_loss_snapshot.py +++ b/tests/fast/backends/training_utils/loss/test_loss_snapshot.py @@ -117,9 +117,9 @@ def _get_sum_of_sample_mean(batch, args, parallel_state): batch["total_lengths"], batch["response_lengths"], batch["loss_masks"], - args.calculate_per_token_loss, - args.qkv_format, - batch.get("max_seq_lens", None), + calculate_per_token_loss=args.calculate_per_token_loss, + qkv_format=args.qkv_format, + max_seq_lens=batch.get("max_seq_lens", None), ) diff --git a/tests/fast/backends/training_utils/test_loss_aggregation.py b/tests/fast/backends/training_utils/test_loss_aggregation.py new file mode 100644 index 0000000000..48f163c1d9 --- /dev/null +++ b/tests/fast/backends/training_utils/test_loss_aggregation.py @@ -0,0 +1,592 @@ +"""Tests for the ``--loss-aggregation`` pg_loss modes.""" + +import argparse +import importlib +import sys +import types +from enum import Enum +from types import SimpleNamespace + +import pytest +import torch + +_MISSING = object() +_MILES_MODULES = [ + "miles.backends.sglang_utils.arguments", + "miles.backends.training_utils.cp_utils", + "miles.backends.training_utils.log_utils", + "miles.backends.training_utils.loss", + "miles.backends.training_utils.loss_hub.losses", + "miles.backends.training_utils.parallel", + "miles.ray.rollout.train_data_conversion", + "miles.utils.arguments", + "miles.utils.types", +] + + +def _install_miles_import_stubs(monkeypatch): + for name in [ + "sglang", + "sglang.srt", + "sglang.srt.entrypoints", + "sglang.srt.entrypoints.openai", + ]: + module = types.ModuleType(name) + module.__path__ = [] + monkeypatch.setitem(sys.modules, name, module) + + protocol = types.ModuleType("sglang.srt.entrypoints.openai.protocol") + protocol.Tool = object + monkeypatch.setitem(sys.modules, "sglang.srt.entrypoints.openai.protocol", protocol) + + server_args = types.ModuleType("sglang.srt.server_args") + + class _ServerArgs: + @staticmethod + def add_cli_args(parser): + return None + + server_args.ServerArgs = _ServerArgs + monkeypatch.setitem(sys.modules, "sglang.srt.server_args", server_args) + + chat_template_utils = types.ModuleType("miles.utils.chat_template_utils") + chat_template_utils.__path__ = [] + chat_template_utils.resolve_fixed_chat_template = lambda *args, **kwargs: (None, {}) + monkeypatch.setitem(sys.modules, "miles.utils.chat_template_utils", chat_template_utils) + + tito_tokenizer = types.ModuleType("miles.utils.chat_template_utils.tito_tokenizer") + + class _TITOTokenizerType(Enum): + DEFAULT = "default" + + tito_tokenizer.TITOTokenizerType = _TITOTokenizerType + monkeypatch.setitem(sys.modules, "miles.utils.chat_template_utils.tito_tokenizer", tito_tokenizer) + + sglang_router = types.ModuleType("sglang_router") + sglang_router.__path__ = [] + monkeypatch.setitem(sys.modules, "sglang_router", sglang_router) + + launch_router = types.ModuleType("sglang_router.launch_router") + + class _RouterArgs: + @staticmethod + def add_cli_args(parser, *args, **kwargs): + return None + + launch_router.RouterArgs = _RouterArgs + monkeypatch.setitem(sys.modules, "sglang_router.launch_router", launch_router) + + +@pytest.fixture +def miles(monkeypatch): + previous_modules = {name: sys.modules.get(name, _MISSING) for name in _MILES_MODULES} + for name in _MILES_MODULES: + sys.modules.pop(name, None) + + _install_miles_import_stubs(monkeypatch) + + cp_utils = importlib.import_module("miles.backends.training_utils.cp_utils") + log_utils = importlib.import_module("miles.backends.training_utils.log_utils") + loss = importlib.import_module("miles.backends.training_utils.loss") + losses = importlib.import_module("miles.backends.training_utils.loss_hub.losses") + parallel = importlib.import_module("miles.backends.training_utils.parallel") + train_data_conversion = importlib.import_module("miles.ray.rollout.train_data_conversion") + arguments = importlib.import_module("miles.utils.arguments") + types_module = importlib.import_module("miles.utils.types") + + try: + yield SimpleNamespace( + arguments=arguments, + convert_samples_to_train_data=train_data_conversion.convert_samples_to_train_data, + cp_utils=cp_utils, + log_utils=log_utils, + loss=loss, + get_pg_loss_reducer=losses.get_pg_loss_reducer, + get_sum_of_sample_mean=cp_utils.get_sum_of_sample_mean, + GroupInfo=parallel.GroupInfo, + ParallelState=parallel.ParallelState, + Sample=types_module.Sample, + ) + finally: + for name in reversed(_MILES_MODULES): + previous = previous_modules[name] + if previous is _MISSING: + sys.modules.pop(name, None) + else: + sys.modules[name] = previous + + +def _parallel_state(miles, *, cp_size: int, cp_rank: int = 0): + return miles.ParallelState( + intra_dp=miles.GroupInfo(rank=0, size=1, group=None), + intra_dp_cp=miles.GroupInfo(rank=0, size=cp_size, group=None), + cp=miles.GroupInfo(rank=cp_rank, size=cp_size, group=None), + tp=miles.GroupInfo(rank=0, size=1, group=None), + pp=miles.GroupInfo(rank=0, size=1, group=None), + ep=miles.GroupInfo(rank=0, size=1, group=None), + etp=miles.GroupInfo(rank=0, size=1, group=None), + cp_comm_type=None, + ) + + +def _legacy_sum_of_sample_mean(response_lengths, loss_masks): + def reducer(x: torch.Tensor) -> torch.Tensor: + return sum( + (x_i * m_i).sum() / torch.clamp_min(m_i.sum(), 1) + for x_i, m_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=True) + ) + + return reducer + + +RESPONSE_LENGTHS = [3, 3, 4, 4] +TOTAL_LENGTHS = [3, 3, 4, 4] +LOSS_MASKS = [ + torch.tensor([1.0, 1.0, 0.0]), + torch.tensor([1.0, 1.0, 1.0]), + torch.tensor([1.0, 0.0, 0.0, 0.0]), + torch.tensor([1.0, 1.0, 1.0, 1.0]), +] +PROMPT_GROUP_INDICES = [0, 0, 1, 1] +X = torch.arange(1.0, 15.0) + + +def test_default_is_byte_identical_to_legacy_reducer(miles, monkeypatch): + monkeypatch.setattr(miles.cp_utils, "get_parallel_state", lambda: _parallel_state(miles, cp_size=1)) + + new = miles.get_sum_of_sample_mean(TOTAL_LENGTHS, RESPONSE_LENGTHS, LOSS_MASKS) + legacy = _legacy_sum_of_sample_mean(RESPONSE_LENGTHS, LOSS_MASKS) + + torch.testing.assert_close(new(X), legacy(X), rtol=0, atol=0) + + +def test_prompt_mean_denominator_is_group_token_total(miles, monkeypatch): + monkeypatch.setattr(miles.cp_utils, "get_parallel_state", lambda: _parallel_state(miles, cp_size=1)) + + sample_denoms = [torch.tensor(5.0)] * 4 + reducer = miles.get_sum_of_sample_mean(TOTAL_LENGTHS, RESPONSE_LENGTHS, LOSS_MASKS, sample_denoms=sample_denoms) + expected = (3 + 15) / 5 + (7 + 50) / 5 + torch.testing.assert_close(reducer(X), torch.tensor(expected)) + + +def test_constant_divides_summed_token_loss_by_L(miles, monkeypatch): + monkeypatch.setattr(miles.cp_utils, "get_parallel_state", lambda: _parallel_state(miles, cp_size=1)) + + reducer = miles.get_sum_of_sample_mean(TOTAL_LENGTHS, RESPONSE_LENGTHS, LOSS_MASKS, constant_divisor=10.0) + torch.testing.assert_close(reducer(X), torch.tensor(7.5)) + + +def test_constant_does_not_compute_sample_denoms(miles, monkeypatch): + monkeypatch.setattr(miles.cp_utils, "get_parallel_state", lambda: _parallel_state(miles, cp_size=1)) + + original_sum = torch.Tensor.sum + + def fail_sum(self, *args, **kwargs): + raise AssertionError("constant aggregation should not compute sample denominators") + + monkeypatch.setattr(torch.Tensor, "sum", fail_sum) + reducer = miles.get_sum_of_sample_mean([3], [3], [torch.ones(3)], constant_divisor=10.0) + monkeypatch.setattr(torch.Tensor, "sum", original_sum) + + torch.testing.assert_close(reducer(torch.arange(1.0, 4.0)), torch.tensor(0.6)) + + +def test_constant_and_per_token_loss_are_mutually_exclusive(miles, monkeypatch): + monkeypatch.setattr(miles.cp_utils, "get_parallel_state", lambda: _parallel_state(miles, cp_size=1)) + + with pytest.raises(ValueError, match="mutually exclusive"): + miles.get_sum_of_sample_mean( + TOTAL_LENGTHS, RESPONSE_LENGTHS, LOSS_MASKS, calculate_per_token_loss=True, constant_divisor=10.0 + ) + + +def test_token_mean_is_unnormalized_global_token_sum(miles, monkeypatch): + monkeypatch.setattr(miles.cp_utils, "get_parallel_state", lambda: _parallel_state(miles, cp_size=1)) + + reducer = miles.get_sum_of_sample_mean(TOTAL_LENGTHS, RESPONSE_LENGTHS, LOSS_MASKS, calculate_per_token_loss=True) + torch.testing.assert_close(reducer(X), torch.tensor(75.0)) + + +def test_legacy_positional_per_token_loss_call_still_works(miles, monkeypatch): + monkeypatch.setattr(miles.cp_utils, "get_parallel_state", lambda: _parallel_state(miles, cp_size=1)) + + reducer = miles.get_sum_of_sample_mean(TOTAL_LENGTHS, RESPONSE_LENGTHS, LOSS_MASKS, True) + torch.testing.assert_close(reducer(X), torch.tensor(75.0)) + + +@pytest.mark.parametrize("constant_divisor", [None, 20.0]) +def test_cp_zigzag_rank_sum_matches_single_rank(miles, monkeypatch, constant_divisor): + total_length, response_length = 10, 8 + loss_mask = torch.tensor([1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0]) + x_full = torch.arange(1.0, 9.0) + sample_denoms = None if constant_divisor is not None else [torch.tensor(20.0)] + + def build(): + return miles.get_sum_of_sample_mean( + [total_length], + [response_length], + [loss_mask], + constant_divisor=constant_divisor, + sample_denoms=sample_denoms, + ) + + monkeypatch.setattr(miles.cp_utils, "get_parallel_state", lambda: _parallel_state(miles, cp_size=1)) + ref = build()(x_full) + + total = torch.zeros(()) + for rank in range(2): + monkeypatch.setattr( + miles.cp_utils, + "get_parallel_state", + lambda r=rank: _parallel_state(miles, cp_size=2, cp_rank=r), + ) + x_local = miles.cp_utils._slice_loss_mask_for_local_cp(total_length, response_length, x_full, "thd", None) + total = total + build()(x_local) + + torch.testing.assert_close(total, ref) + + +def _args(**overrides): + base = dict( + loss_aggregation="sample_mean", + loss_aggregation_divisor=None, + calculate_per_token_loss=False, + n_samples_per_prompt=2, + qkv_format="thd", + ) + base.update(overrides) + return SimpleNamespace(**base) + + +def _default_reducer(miles, *, calculate_per_token_loss=False): + return miles.get_sum_of_sample_mean( + TOTAL_LENGTHS, + RESPONSE_LENGTHS, + LOSS_MASKS, + calculate_per_token_loss=calculate_per_token_loss, + ) + + +def _select(miles, args, batch, *, default_reducer=None, pg_loss_masks=LOSS_MASKS): + batch = dict(batch) + batch.setdefault("loss_masks", LOSS_MASKS) + if default_reducer is None: + default_reducer = _default_reducer(miles) + return miles.get_pg_loss_reducer( + args, + batch, + total_lengths=TOTAL_LENGTHS, + response_lengths=RESPONSE_LENGTHS, + pg_loss_masks=pg_loss_masks, + max_seq_lens=None, + default_reducer=default_reducer, + ) + + +@pytest.mark.parametrize( + ("args", "batch", "expected"), + [ + (_args(loss_aggregation="sample_mean"), {}, 26.0), + (_args(loss_aggregation="token_mean", calculate_per_token_loss=True), {}, 75.0), + (_args(loss_aggregation="constant", loss_aggregation_divisor=10.0), {}, 7.5), + ( + _args(loss_aggregation="prompt_mean"), + {"prompt_group_indices": PROMPT_GROUP_INDICES, "prompt_mask_sums": [torch.tensor(5.0)] * 4}, + 30.0, + ), + ], +) +def test_pg_loss_reducer_modes_compute_expected_loss(miles, monkeypatch, args, batch, expected): + monkeypatch.setattr(miles.cp_utils, "get_parallel_state", lambda: _parallel_state(miles, cp_size=1)) + + reducer = _select(miles, args, batch) + torch.testing.assert_close(reducer(X), torch.tensor(expected)) + + +def test_missing_loss_aggregation_reuses_legacy_sample_mean(miles, monkeypatch): + monkeypatch.setattr(miles.cp_utils, "get_parallel_state", lambda: _parallel_state(miles, cp_size=1)) + + args = SimpleNamespace(calculate_per_token_loss=False, qkv_format="thd", n_samples_per_prompt=2) + reducer = _select(miles, args, {}) + torch.testing.assert_close(reducer(X), torch.tensor(26.0)) + + +def test_missing_loss_aggregation_reuses_legacy_token_mean(miles, monkeypatch): + monkeypatch.setattr(miles.cp_utils, "get_parallel_state", lambda: _parallel_state(miles, cp_size=1)) + + args = SimpleNamespace(calculate_per_token_loss=True, qkv_format="thd", n_samples_per_prompt=2) + reducer = _select(miles, args, {}) + torch.testing.assert_close(reducer(X), torch.tensor(75.0)) + + +def test_prompt_mean_recomputes_denoms_from_current_pg_loss_masks(miles, monkeypatch): + monkeypatch.setattr(miles.cp_utils, "get_parallel_state", lambda: _parallel_state(miles, cp_size=1)) + + pg_loss_masks = [ + torch.tensor([1.0, 0.0, 0.0]), + torch.tensor([0.0, 0.0, 0.0]), + torch.tensor([1.0, 0.0, 0.0, 0.0]), + torch.tensor([1.0, 1.0, 0.0, 0.0]), + ] + + reducer = _select( + miles, + _args(loss_aggregation="prompt_mean"), + {"prompt_group_indices": PROMPT_GROUP_INDICES, "prompt_mask_sums": [torch.tensor(99.0)] * 4}, + pg_loss_masks=pg_loss_masks, + ) + + torch.testing.assert_close(reducer(X), torch.tensor(22.0)) + + +def test_prompt_mean_handles_zero_mask_completion_in_group(miles, monkeypatch): + monkeypatch.setattr(miles.cp_utils, "get_parallel_state", lambda: _parallel_state(miles, cp_size=1)) + + args = _args(loss_aggregation="prompt_mean") + pg_loss_masks = [torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0])] + pg_loss = torch.tensor([100.0, 200.0, 3.0, 7.0]) + + reducer = miles.get_pg_loss_reducer( + args, + { + "loss_masks": [torch.ones(2), torch.ones(2)], + "prompt_group_indices": [0, 0], + "prompt_mask_sums": [torch.tensor(99.0), torch.tensor(99.0)], + }, + total_lengths=[2, 2], + response_lengths=[2, 2], + pg_loss_masks=pg_loss_masks, + max_seq_lens=None, + default_reducer=miles.get_sum_of_sample_mean([2, 2], [2, 2], pg_loss_masks), + ) + + torch.testing.assert_close(reducer(pg_loss), torch.tensor(10.0)) + + +def test_prompt_mean_without_prompt_mask_sums_fails(miles): + with pytest.raises(ValueError, match="prompt_mask_sums"): + _select(miles, _args(loss_aggregation="prompt_mean"), {}, default_reducer=lambda x: x) + + +def test_prompt_mean_modified_masks_without_prompt_group_indices_fails(miles): + pg_loss_masks = [loss_mask.clone() for loss_mask in LOSS_MASKS] + with pytest.raises(ValueError, match="prompt_group_indices"): + _select( + miles, + _args(loss_aggregation="prompt_mean"), + {"prompt_mask_sums": [torch.tensor(5.0)] * 4}, + default_reducer=lambda x: x, + pg_loss_masks=pg_loss_masks, + ) + + +def test_prompt_mean_modified_masks_rejects_partial_prompt_groups(miles, monkeypatch): + monkeypatch.setattr(miles.cp_utils, "get_parallel_state", lambda: _parallel_state(miles, cp_size=1)) + + pg_loss_masks = [torch.tensor([1.0, 0.0]), torch.tensor([1.0, 1.0])] + with pytest.raises(ValueError, match="complete prompt groups"): + miles.get_pg_loss_reducer( + _args(loss_aggregation="prompt_mean"), + { + "loss_masks": [torch.ones(2), torch.ones(2)], + "prompt_group_indices": [0, 1], + "prompt_mask_sums": [torch.tensor(99.0), torch.tensor(99.0)], + }, + total_lengths=[2, 2], + response_lengths=[2, 2], + pg_loss_masks=pg_loss_masks, + max_seq_lens=None, + default_reducer=miles.get_sum_of_sample_mean([2, 2], [2, 2], pg_loss_masks), + ) + + +def test_token_mean_log_aggregation_keeps_metric_sample_mean(miles, monkeypatch): + state = _parallel_state(miles, cp_size=1) + monkeypatch.setattr(miles.cp_utils, "get_parallel_state", lambda: state) + monkeypatch.setattr(miles.loss, "get_parallel_state", lambda: state) + monkeypatch.setattr(miles.log_utils, "get_parallel_state", lambda: state) + monkeypatch.setattr(miles.log_utils.dist, "all_reduce", lambda *args, **kwargs: None) + + args = _args(loss_aggregation="token_mean", calculate_per_token_loss=True) + args.global_batch_size = 2 + args.use_dynamic_global_batch_size = False + args.recompute_loss_function = False + args.true_on_policy_mode = False + + batch = { + "loss_masks": [torch.ones(2), torch.ones(3)], + "total_lengths": [2, 3], + "response_lengths": [2, 3], + } + + def fake_loss_function(args, batch, logits, sum_of_sample_mean): + per_token_metric = torch.arange(1.0, 6.0, device=logits.device) + return logits.sum() * 0, { + "pg_loss": per_token_metric.sum(), + "ppo_kl": sum_of_sample_mean(per_token_metric), + } + + monkeypatch.setattr(miles.loss, "get_loss_function", lambda args: fake_loss_function) + + _, _, log_dict = miles.loss.loss_function(args, batch, 1, torch.ones((), requires_grad=True)) + loss_dict = miles.log_utils.aggregate_train_losses([log_dict]) + + torch.testing.assert_close(torch.tensor(loss_dict["pg_loss"]), torch.tensor(3.0)) + torch.testing.assert_close(torch.tensor(loss_dict["ppo_kl"]), torch.tensor(2.75)) + + +def _validate_args(**overrides): + base = dict( + loss_aggregation="sample_mean", + loss_aggregation_divisor=None, + calculate_per_token_loss=False, + global_batch_size=4, + n_samples_per_prompt=2, + ) + base.update(overrides) + return SimpleNamespace(**base) + + +@pytest.mark.parametrize("divisor", [None, 0.0, -1.0, float("nan")]) +def test_validate_constant_rejects_nonpositive_divisor(miles, divisor): + args = _validate_args(loss_aggregation="constant", loss_aggregation_divisor=divisor) + with pytest.raises(ValueError, match="loss-aggregation-divisor"): + miles.arguments._validate_loss_aggregation_args(args) + + +def test_validate_token_mean_aliases_calculate_per_token_loss(miles): + args = _validate_args(loss_aggregation="token_mean") + miles.arguments._validate_loss_aggregation_args(args) + assert args.calculate_per_token_loss is True + + +def test_validate_calculate_per_token_loss_alone_reconciles_to_token_mean(miles): + args = _validate_args(calculate_per_token_loss=True) + miles.arguments._validate_loss_aggregation_args(args) + assert args.loss_aggregation == "token_mean" + assert args.calculate_per_token_loss is True + + +def test_validate_default_leaves_per_token_loss_off(miles): + args = _validate_args() + miles.arguments._validate_loss_aggregation_args(args) + assert args.loss_aggregation == "sample_mean" + assert args.calculate_per_token_loss is False + + +def test_validate_constant_rejects_calculate_per_token_loss(miles): + args = _validate_args(loss_aggregation="constant", loss_aggregation_divisor=10.0, calculate_per_token_loss=True) + with pytest.raises(ValueError, match="incompatible with --calculate-per-token-loss"): + miles.arguments._validate_loss_aggregation_args(args) + + +def test_validate_constant_with_per_token_loss_off_passes(miles): + args = _validate_args(loss_aggregation="constant", loss_aggregation_divisor=10.0, calculate_per_token_loss=False) + miles.arguments._validate_loss_aggregation_args(args) + assert args.calculate_per_token_loss is False + + +def test_validate_prompt_mean_rejects_calculate_per_token_loss(miles): + args = _validate_args(loss_aggregation="prompt_mean", calculate_per_token_loss=True) + with pytest.raises(ValueError, match="incompatible with --calculate-per-token-loss"): + miles.arguments._validate_loss_aggregation_args(args) + + +def test_validate_prompt_mean_rejects_non_multiple_global_batch_size(miles): + args = _validate_args(loss_aggregation="prompt_mean", global_batch_size=3, n_samples_per_prompt=2) + with pytest.raises(ValueError, match="multiple of n_samples_per_prompt"): + miles.arguments._validate_loss_aggregation_args(args) + + +def test_miles_validate_args_checks_prompt_mean_after_deriving_global_batch_size(miles): + parser = argparse.ArgumentParser() + miles.arguments.get_miles_extra_args_provider()(parser) + args = parser.parse_args( + [ + "--rollout-batch-size", + "3", + "--n-samples-per-prompt", + "2", + "--num-steps-per-rollout", + "2", + "--loss-aggregation", + "prompt_mean", + ] + ) + + with pytest.raises(ValueError, match="multiple of n_samples_per_prompt"): + miles.arguments.miles_validate_args(args) + + +def test_validate_prompt_mean_accepts_multiple_global_batch_size(miles): + args = _validate_args(loss_aggregation="prompt_mean", global_batch_size=6, n_samples_per_prompt=2) + miles.arguments._validate_loss_aggregation_args(args) + assert args.loss_aggregation == "prompt_mean" + + +def test_validate_non_prompt_mean_allows_non_multiple_global_batch_size(miles): + args = _validate_args(loss_aggregation="sample_mean", global_batch_size=3, n_samples_per_prompt=2) + miles.arguments._validate_loss_aggregation_args(args) + assert args.loss_aggregation == "sample_mean" + + +def _make_sample(miles, group_index, index, response_length, loss_mask): + sample = miles.Sample(group_index=group_index, index=index, response_length=response_length, reward=1.0) + sample.tokens = [0] * response_length + sample.loss_mask = loss_mask + sample.status = miles.Sample.Status.COMPLETED + return sample + + +def _convert_args(**overrides): + base = dict( + advantage_estimator="grpo", + rewards_normalization=False, + reward_key=None, + use_dynamic_global_batch_size=False, + loss_aggregation="prompt_mean", + n_samples_per_prompt=2, + rollout_batch_size=2, + grpo_std_normalization=False, + ) + base.update(overrides) + return SimpleNamespace(**base) + + +def _convert(miles, samples, args): + return miles.convert_samples_to_train_data( + args, + samples, + metadata={}, + custom_convert_samples_to_train_data_func=None, + custom_reward_post_process_func=None, + ) + + +def test_convert_samples_computes_step_level_prompt_group_denoms(miles): + samples = [ + _make_sample(miles, 0, 0, 3, [1, 1, 0]), + _make_sample(miles, 0, 1, 3, [1, 1, 1]), + _make_sample(miles, 1, 2, 4, [1, 0, 0, 0]), + _make_sample(miles, 1, 3, 4, [1, 1, 1, 1]), + ] + + train_data = _convert(miles, samples, _convert_args()) + + assert train_data["prompt_group_indices"] == [0, 0, 1, 1] + assert train_data["prompt_mask_sums"] == [5, 5, 5, 5] + + +def test_convert_samples_prompt_mean_rejects_none_group_index(miles): + samples = [_make_sample(miles, None, 0, 2, [1, 1]), _make_sample(miles, None, 1, 2, [1, 0])] + with pytest.raises(ValueError, match="group_index"): + _convert(miles, samples, _convert_args(rollout_batch_size=1)) + + +def test_convert_samples_omits_prompt_group_fields_for_default_mode(miles): + samples = [_make_sample(miles, 0, 0, 2, [1, 1]), _make_sample(miles, 0, 1, 2, [1, 0])] + + train_data = _convert(miles, samples, _convert_args(loss_aggregation="sample_mean", rollout_batch_size=1)) + + assert "prompt_group_indices" not in train_data + assert "prompt_mask_sums" not in train_data From fe7d2bfa9524787cb6e5b1e14a6c8dad58b2d228 Mon Sep 17 00:00:00 2001 From: EazyReal Date: Sat, 27 Jun 2026 22:14:22 +0000 Subject: [PATCH 2/2] fix(loss): guard loss aggregation normalizers --- docs/user-guide/customization.md | 6 +- miles/backends/training_utils/cp_utils.py | 2 + miles/backends/training_utils/log_utils.py | 23 +- miles/backends/training_utils/loss.py | 100 ++++-- .../training_utils/loss_hub/losses.py | 13 +- miles/ray/rollout/train_data_conversion.py | 51 ++- miles/utils/arguments.py | 26 +- .../training_utils/test_loss_aggregation.py | 300 ++++++++++++++++++ 8 files changed, 472 insertions(+), 49 deletions(-) diff --git a/docs/user-guide/customization.md b/docs/user-guide/customization.md index cf5a985047..158045cf5e 100644 --- a/docs/user-guide/customization.md +++ b/docs/user-guide/customization.md @@ -220,7 +220,7 @@ with `--calculate-per-token-loss` is rejected at startup (per-token loss would renormalize by the token count and undo that mode's denominator). `--loss-aggregation-divisor L` is required (validated `> 0` at startup) only for -`constant`; it is ignored for the other modes. +`constant`; every other mode rejects a configured divisor at startup. Every mode shares the same outer structure: each step's pg_loss sum is divided by `global_batch_size` (the standard step average). The per-mode denominator above is @@ -235,6 +235,10 @@ mean contributes once, and the final scalar is the mean over prompt groups. It requires `global_batch_size` to be a multiple of `n_samples_per_prompt`, so a contiguous train step keeps each prompt group whole instead of normalizing against a partially-present group total. +Prompt groups also stay whole when splitting train data across data-parallel +ranks; if the number of prompt groups in a train step is not divisible by the +data-parallel size, partitioning fails before training instead of creating a +partial prompt group on a rank. For custom train-data converters, `prompt_mean` should mirror the built-in converter and emit both `prompt_mask_sums` and `prompt_group_indices`. diff --git a/miles/backends/training_utils/cp_utils.py b/miles/backends/training_utils/cp_utils.py index 9fd257b88f..ac630cbc81 100644 --- a/miles/backends/training_utils/cp_utils.py +++ b/miles/backends/training_utils/cp_utils.py @@ -105,6 +105,8 @@ def get_sum_of_sample_mean( "constant_divisor (loss-aggregation=constant) and calculate_per_token_loss " "(loss-aggregation=token_mean) are mutually exclusive aggregation modes." ) + if constant_divisor is not None and constant_divisor <= 0: + raise ValueError("constant_divisor (loss-aggregation=constant) must be positive.") if sample_denoms is None and constant_divisor is None and not calculate_per_token_loss: sample_denoms = [loss_mask.sum() for loss_mask in loss_masks] diff --git a/miles/backends/training_utils/log_utils.py b/miles/backends/training_utils/log_utils.py index b6214b70cc..a5a6726ad6 100644 --- a/miles/backends/training_utils/log_utils.py +++ b/miles/backends/training_utils/log_utils.py @@ -388,10 +388,20 @@ def aggregate_train_losses( return {} keys = losses_reduced[0]["keys"] + has_normalizers = "normalizers" in losses_reduced[0] values = None normalizers = None for log_dict in losses_reduced: + if log_dict["keys"] != keys: + raise ValueError( + "All train loss log_dict entries must have the same keys in the same order; " + f"got {log_dict['keys']} after {keys}." + ) + if ("normalizers" in log_dict) != has_normalizers: + raise ValueError( + "Cannot aggregate a mix of train loss log_dict entries with and without explicit normalizers." + ) if values is None: values = log_dict["values"].clone() else: @@ -403,10 +413,13 @@ def aggregate_train_losses( normalizers += log_dict["normalizers"] if normalizers is None: - assert len(keys) + 1 == values.numel(), f"Expected {len(keys) + 1} values, got {values.numel()}" + if len(keys) + 1 != values.numel(): + raise ValueError(f"Expected {len(keys) + 1} values, got {values.numel()}") else: - assert len(keys) == values.numel(), f"Expected {len(keys)} values, got {values.numel()}" - assert len(keys) == normalizers.numel(), f"Expected {len(keys)} normalizers, got {normalizers.numel()}" + if len(keys) != values.numel(): + raise ValueError(f"Expected {len(keys)} values, got {values.numel()}") + if len(keys) != normalizers.numel(): + raise ValueError(f"Expected {len(keys)} normalizers, got {normalizers.numel()}") dist.all_reduce(values, op=dist.ReduceOp.SUM, group=parallel_state.intra_dp_cp.group) if normalizers is not None: @@ -416,12 +429,12 @@ def aggregate_train_losses( values = values.tolist() if normalizers is not None: normalizers = normalizers.tolist() - for key, value, normalizer in zip(keys, values, normalizers, strict=False): + for key, value, normalizer in zip(keys, values, normalizers, strict=True): loss_reduced[key] = value * parallel_state.cp.size / normalizer return loss_reduced num_samples_or_tokens = values[0] - for key, value in zip(keys, values[1:], strict=False): + for key, value in zip(keys, values[1:], strict=True): loss_reduced[key] = value * parallel_state.cp.size / num_samples_or_tokens return loss_reduced diff --git a/miles/backends/training_utils/loss.py b/miles/backends/training_utils/loss.py index 7da16e0160..874aa07d33 100644 --- a/miles/backends/training_utils/loss.py +++ b/miles/backends/training_utils/loss.py @@ -12,6 +12,69 @@ from miles.backends.training_utils.parallel import get_parallel_state from miles.utils.types import RolloutBatch +TOKEN_NORMALIZED_TRAIN_KEYS = frozenset({"loss", "pg_loss", "ess_ratio", "value_loss"}) + + +def _as_scalar_tensor(value, *, device: torch.device) -> torch.Tensor: + if isinstance(value, torch.Tensor): + return value.to(device=device).reshape(()) + return torch.tensor(value, device=device) + + +def _build_train_log_dict( + log: dict[str, torch.Tensor], + *, + num_samples: int, + num_tokens: torch.Tensor, + device: torch.device, + calculate_per_token_loss: bool, +) -> dict[str, list[str] | torch.Tensor]: + keys = list(log.keys()) + values = torch.stack([_as_scalar_tensor(value, device=device) for value in log.values()]) + if not calculate_per_token_loss: + return { + "keys": keys, + "values": torch.cat([torch.tensor([num_samples], device=device), values]), + } + + normalizers = torch.stack( + [ + ( + num_tokens.to(device=device).reshape(()) + if key in TOKEN_NORMALIZED_TRAIN_KEYS + else torch.tensor(num_samples, device=device) + ) + for key in keys + ] + ) + return { + "keys": keys, + "values": values, + "normalizers": normalizers, + } + + +def _validate_loss_aggregation_contract(args: Namespace) -> None: + if getattr(args, "loss_type", None) != "policy_loss": + return + mode = getattr(args, "loss_aggregation", None) + token_mean = mode == "token_mean" or (mode is None and getattr(args, "calculate_per_token_loss", False)) + if not token_mean: + return + + mixed_terms = [] + if getattr(args, "entropy_coef", 0) != 0: + mixed_terms.append("--entropy-coef") + if getattr(args, "use_kl_loss", False) and getattr(args, "kl_loss_coef", 0) != 0: + mixed_terms.append("--kl-loss-coef") + if mixed_terms: + raise ValueError( + "--loss-aggregation token_mean cannot be combined with " + f"{', '.join(mixed_terms)} because the policy-gradient term is token-normalized " + "while auxiliary policy-loss terms are sample-normalized. Use sample_mean, " + "set the auxiliary coefficient to 0, or add explicit per-term loss normalizers." + ) + def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) -> None: """Compute advantages and returns in-place based on `args.advantage_estimator`. @@ -122,13 +185,19 @@ def loss_function( parallel_state = get_parallel_state() num_tokens = sum([torch.clamp_min(loss_mask.sum(), 1) for loss_mask in batch["loss_masks"]]) num_samples = len(batch["response_lengths"]) + _validate_loss_aggregation_contract(args) - # --loss-aggregation applies to pg_loss only; metrics keep this reducer. + # Policy loss selects pg_loss aggregation separately; this shared reducer is + # for metrics and auxiliary terms. Non-policy losses use the legacy reducer + # axis directly. + reducer_per_token_loss = ( + args.calculate_per_token_loss and getattr(args, "loss_type", "policy_loss") != "policy_loss" + ) sum_of_sample_mean = get_sum_of_sample_mean( batch["total_lengths"], batch["response_lengths"], batch["loss_masks"], - calculate_per_token_loss=False, + calculate_per_token_loss=reducer_per_token_loss, qkv_format=args.qkv_format, max_seq_lens=batch.get("max_seq_lens", None), ) @@ -167,26 +236,13 @@ def loss_function( if apply_megatron_loss_scaling: loss = loss * parallel_state.cp.size - log_keys = list(log.keys()) - log_dict = { - "keys": log_keys, - "values": torch.tensor( - [ - num_samples, - ] - + list(log.values()), - device=logits.device, - ), - } - if args.calculate_per_token_loss: - log_dict = { - "keys": log_keys, - "values": torch.tensor(list(log.values()), device=logits.device), - "normalizers": torch.tensor( - [num_tokens if key in {"loss", "pg_loss", "ess_ratio"} else num_samples for key in log_keys], - device=logits.device, - ), - } + log_dict = _build_train_log_dict( + log, + num_samples=num_samples, + num_tokens=num_tokens, + device=logits.device, + calculate_per_token_loss=args.calculate_per_token_loss, + ) return ( loss, diff --git a/miles/backends/training_utils/loss_hub/losses.py b/miles/backends/training_utils/loss_hub/losses.py index 66d80edbfe..22c21ab579 100644 --- a/miles/backends/training_utils/loss_hub/losses.py +++ b/miles/backends/training_utils/loss_hub/losses.py @@ -304,14 +304,7 @@ def policy_loss_function( # Apply off-policy correction using importance sampling if enabled if args.get_mismatch_metrics or args.use_tis: - # NOTE: - # `tis_func` may apply rejection-sampling style masking (RS) and return `modified_response_masks`. - # We rebuild `sum_of_sample_mean` with those masks to correct denominators for loss/backprop. - # - # However, mismatch/TIS/RS metrics (e.g., "truncate_fraction") are often defined over the - # *pre-RS* valid tokens. If we aggregate metrics with `modified_response_masks`, the rejected - # tokens are excluded from the denominator and the metric can be artificially driven to 0. - # Keep a copy of the original reducer (based on `batch["loss_masks"]`) for metric aggregation. + # TIS/RS can shrink the pg_loss mask; mismatch metrics stay on the original mask. sum_of_sample_mean_for_mismatch_metrics = sum_of_sample_mean assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS" @@ -335,13 +328,11 @@ def policy_loss_function( tis_func = vanilla_tis_function pg_loss, modified_response_masks, tis_metrics = tis_func(**tis_kwargs) - # [decouple IS and rejection] Rebuild sum_of_sample_mean with modified_response_masks for denominator correction - # modified_response_masks will be sliced with cp in get_sum_of_sample_mean sum_of_sample_mean = get_sum_of_sample_mean( total_lengths, response_lengths, modified_response_masks, - calculate_per_token_loss=args.calculate_per_token_loss, + calculate_per_token_loss=False, qkv_format=args.qkv_format, max_seq_lens=max_seq_lens, ) diff --git a/miles/ray/rollout/train_data_conversion.py b/miles/ray/rollout/train_data_conversion.py index cffdf07e32..dfc4886a65 100644 --- a/miles/ray/rollout/train_data_conversion.py +++ b/miles/ray/rollout/train_data_conversion.py @@ -132,6 +132,48 @@ def _post_process_rewards(args, samples: list[Sample] | list[list[Sample]], cust return raw_rewards, raw_rewards +def _prompt_group_partitions( + prompt_group_indices: list[int], + total_lengths: list[int], + dp_size: int, + *, + balance_data: bool, +) -> list[list[int]]: + if len(prompt_group_indices) != len(total_lengths): + raise ValueError( + "--loss-aggregation prompt_mean requires one prompt_group_indices entry per sample " + f"(got {len(prompt_group_indices)} for {len(total_lengths)} samples)." + ) + + group_to_indices: dict[int, list[int]] = {} + group_order: list[int] = [] + for sample_index, group_index in enumerate(prompt_group_indices): + group_key = int(group_index.item()) if isinstance(group_index, torch.Tensor) else int(group_index) + if group_key not in group_to_indices: + group_to_indices[group_key] = [] + group_order.append(group_key) + group_to_indices[group_key].append(sample_index) + + if len(group_order) % dp_size != 0: + raise ValueError( + "--loss-aggregation prompt_mean requires the number of prompt groups in a train step " + f"to be divisible by dp_size (got {len(group_order)} prompt groups, dp_size={dp_size})." + ) + + if balance_data: + group_lengths = [ + sum(total_lengths[index] for index in group_to_indices[group_key]) for group_key in group_order + ] + group_partitions = get_seqlen_balanced_partitions(group_lengths, dp_size, equal_size=True) + else: + group_partitions = [range(i, len(group_order), dp_size) for i in range(dp_size)] + + return [ + [sample_index for group_index in partition for sample_index in group_to_indices[group_order[group_index]]] + for partition in group_partitions + ] + + def split_train_data_by_dp(args, data, dp_size): """Split the train data by data parallel size.""" rollout_data = {} @@ -142,7 +184,14 @@ def split_train_data_by_dp(args, data, dp_size): total_lengths = [len(t) for t in data["tokens"]] data["total_lengths"] = total_lengths - if args.balance_data: + if getattr(args, "loss_aggregation", "sample_mean") == "prompt_mean" and "prompt_group_indices" in data: + partitions = _prompt_group_partitions( + data["prompt_group_indices"], + total_lengths, + dp_size, + balance_data=args.balance_data, + ) + elif args.balance_data: partitions = get_seqlen_balanced_partitions(total_lengths, dp_size, equal_size=True) else: partitions = [range(i, len(total_lengths), dp_size) for i in range(dp_size)] diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 35b858e697..f79a506f2b 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -2078,6 +2078,20 @@ def _validate_loss_aggregation_args(args): ) +def _derive_global_batch_size_from_rollout(args, *, require_existing_match: bool = True) -> None: + if args.num_steps_per_rollout is None: + return + global_batch_size = args.rollout_batch_size * args.n_samples_per_prompt // args.num_steps_per_rollout + if require_existing_match and args.global_batch_size is not None: + if args.global_batch_size != global_batch_size: + raise ValueError( + f"global_batch_size {args.global_batch_size} is not equal to " + f"rollout_batch_size {args.rollout_batch_size} * n_samples_per_prompt {args.n_samples_per_prompt} " + f"// num_steps_per_rollout {args.num_steps_per_rollout}" + ) + args.global_batch_size = global_batch_size + + def miles_validate_args(args): args.eval_datasets = _resolve_eval_datasets(args) @@ -2393,15 +2407,7 @@ def miles_validate_args(args): if args.eval_function_path is None: args.eval_function_path = args.rollout_function_path - if args.num_steps_per_rollout is not None: - global_batch_size = args.rollout_batch_size * args.n_samples_per_prompt // args.num_steps_per_rollout - if args.global_batch_size is not None: - assert args.global_batch_size == global_batch_size, ( - f"global_batch_size {args.global_batch_size} is not equal to " - f"rollout_batch_size {args.rollout_batch_size} * n_samples_per_prompt {args.n_samples_per_prompt} " - f"// num_steps_per_rollout {args.num_steps_per_rollout}" - ) - args.global_batch_size = global_batch_size + _derive_global_batch_size_from_rollout(args) _validate_loss_aggregation_args(args) @@ -2444,6 +2450,8 @@ def miles_validate_args(args): if hasattr(args, k): logger.info(f"Warning: Argument {k} is already set to {getattr(args, k)}, will override with {v}.") setattr(args, k, v) + _derive_global_batch_size_from_rollout(args, require_existing_match="global_batch_size" in data) + _validate_loss_aggregation_args(args) if args.use_rollout_indexer_replay: args.use_indexer_replay = True diff --git a/tests/fast/backends/training_utils/test_loss_aggregation.py b/tests/fast/backends/training_utils/test_loss_aggregation.py index 48f163c1d9..23b4d6e44d 100644 --- a/tests/fast/backends/training_utils/test_loss_aggregation.py +++ b/tests/fast/backends/training_utils/test_loss_aggregation.py @@ -98,6 +98,8 @@ def miles(monkeypatch): yield SimpleNamespace( arguments=arguments, convert_samples_to_train_data=train_data_conversion.convert_samples_to_train_data, + split_train_data_by_dp=train_data_conversion.split_train_data_by_dp, + train_data_conversion=train_data_conversion, cp_utils=cp_utils, log_utils=log_utils, loss=loss, @@ -246,6 +248,19 @@ def build(): torch.testing.assert_close(total, ref) +@pytest.mark.parametrize("constant_divisor", [0.0, -1.0]) +def test_constant_reducer_rejects_nonpositive_divisor(miles, monkeypatch, constant_divisor): + monkeypatch.setattr(miles.cp_utils, "get_parallel_state", lambda: _parallel_state(miles, cp_size=1)) + + with pytest.raises(ValueError, match="constant_divisor"): + miles.get_sum_of_sample_mean( + TOTAL_LENGTHS, + RESPONSE_LENGTHS, + LOSS_MASKS, + constant_divisor=constant_divisor, + ) + + def _args(**overrides): base = dict( loss_aggregation="sample_mean", @@ -435,6 +450,159 @@ def fake_loss_function(args, batch, logits, sum_of_sample_mean): torch.testing.assert_close(torch.tensor(loss_dict["ppo_kl"]), torch.tensor(2.75)) +@pytest.mark.parametrize( + "overrides", + [ + {"entropy_coef": 0.01, "use_kl_loss": False, "kl_loss_coef": 0.0}, + {"entropy_coef": 0.0, "use_kl_loss": True, "kl_loss_coef": 0.1}, + ], +) +def test_token_mean_rejects_mixed_policy_loss_normalizers(miles, overrides): + args = _args(loss_aggregation="token_mean", calculate_per_token_loss=True) + args.loss_type = "policy_loss" + args.entropy_coef = overrides["entropy_coef"] + args.use_kl_loss = overrides["use_kl_loss"] + args.kl_loss_coef = overrides["kl_loss_coef"] + + with pytest.raises(ValueError, match="auxiliary policy-loss terms are sample-normalized"): + miles.loss._validate_loss_aggregation_contract(args) + + +def test_token_mean_allows_auxiliary_metrics_without_loss_contribution(miles): + args = _args(loss_aggregation="token_mean", calculate_per_token_loss=True) + args.loss_type = "policy_loss" + args.entropy_coef = 0.0 + args.use_kl_loss = True + args.kl_loss_coef = 0.0 + + miles.loss._validate_loss_aggregation_contract(args) + + +def test_build_train_log_dict_carries_per_metric_normalizers(miles): + log_dict = miles.loss._build_train_log_dict( + { + "loss": torch.tensor(12.0), + "pg_loss": torch.tensor(15.0), + "ppo_kl": torch.tensor(5.5), + }, + num_samples=2, + num_tokens=torch.tensor(5.0), + device=torch.device("cpu"), + calculate_per_token_loss=True, + ) + + assert log_dict["keys"] == ["loss", "pg_loss", "ppo_kl"] + torch.testing.assert_close(log_dict["values"], torch.tensor([12.0, 15.0, 5.5])) + torch.testing.assert_close(log_dict["normalizers"], torch.tensor([5.0, 5.0, 2.0])) + + +def test_non_policy_per_token_loss_uses_token_reducer(miles, monkeypatch): + state = _parallel_state(miles, cp_size=1) + monkeypatch.setattr(miles.cp_utils, "get_parallel_state", lambda: state) + monkeypatch.setattr(miles.loss, "get_parallel_state", lambda: state) + + args = _args(loss_aggregation="token_mean", calculate_per_token_loss=True) + args.loss_type = "sft_loss" + args.global_batch_size = 2 + args.use_dynamic_global_batch_size = False + args.recompute_loss_function = False + args.true_on_policy_mode = False + + batch = { + "loss_masks": [torch.ones(2), torch.ones(3)], + "total_lengths": [2, 3], + "response_lengths": [2, 3], + } + + def fake_loss_function(args, batch, logits, sum_of_sample_mean): + token_values = torch.arange(1.0, 6.0, device=logits.device) + loss = sum_of_sample_mean(token_values) + return loss, {"loss": loss.detach()} + + monkeypatch.setattr(miles.loss, "get_loss_function", lambda args: fake_loss_function) + + loss, normalizer, log_dict = miles.loss.loss_function(args, batch, 1, torch.ones((), requires_grad=True)) + + torch.testing.assert_close(loss, torch.tensor(15.0)) + torch.testing.assert_close(normalizer, torch.tensor(5.0)) + torch.testing.assert_close(log_dict["values"], torch.tensor([15.0])) + torch.testing.assert_close(log_dict["normalizers"], torch.tensor([5.0])) + + +def test_aggregate_train_losses_rejects_mixed_normalizer_contracts(miles, monkeypatch): + state = _parallel_state(miles, cp_size=1) + monkeypatch.setattr(miles.log_utils, "get_parallel_state", lambda: state) + monkeypatch.setattr(miles.log_utils.dist, "all_reduce", lambda *args, **kwargs: None) + + with pytest.raises(ValueError, match="mix"): + miles.log_utils.aggregate_train_losses( + [ + { + "keys": ["pg_loss"], + "values": torch.tensor([10.0]), + "normalizers": torch.tensor([5.0]), + }, + { + "keys": ["pg_loss"], + "values": torch.tensor([2.0, 6.0]), + }, + ] + ) + + +def test_aggregate_train_losses_rejects_key_order_mismatch(miles, monkeypatch): + state = _parallel_state(miles, cp_size=1) + monkeypatch.setattr(miles.log_utils, "get_parallel_state", lambda: state) + monkeypatch.setattr(miles.log_utils.dist, "all_reduce", lambda *args, **kwargs: None) + + with pytest.raises(ValueError, match="same keys"): + miles.log_utils.aggregate_train_losses( + [ + { + "keys": ["pg_loss", "ppo_kl"], + "values": torch.tensor([10.0, 5.0]), + "normalizers": torch.tensor([5.0, 2.0]), + }, + { + "keys": ["ppo_kl", "pg_loss"], + "values": torch.tensor([5.0, 10.0]), + "normalizers": torch.tensor([2.0, 5.0]), + }, + ] + ) + + +def test_aggregate_train_losses_rejects_bad_legacy_value_count(miles, monkeypatch): + state = _parallel_state(miles, cp_size=1) + monkeypatch.setattr(miles.log_utils, "get_parallel_state", lambda: state) + + with pytest.raises(ValueError, match="Expected 2 values"): + miles.log_utils.aggregate_train_losses( + [ + { + "keys": ["pg_loss"], + "values": torch.tensor([10.0]), + }, + ] + ) + + +def test_aggregate_train_losses_rejects_bad_normalizer_count(miles, monkeypatch): + state = _parallel_state(miles, cp_size=1) + monkeypatch.setattr(miles.log_utils, "get_parallel_state", lambda: state) + + with pytest.raises(ValueError, match="Expected 2 normalizers"): + miles.log_utils.aggregate_train_losses( + [ + { + "keys": ["pg_loss", "ppo_kl"], + "values": torch.tensor([10.0, 5.0]), + "normalizers": torch.tensor([5.0]), + }, + ] + ) + + def _validate_args(**overrides): base = dict( loss_aggregation="sample_mean", @@ -518,6 +686,98 @@ def test_miles_validate_args_checks_prompt_mean_after_deriving_global_batch_size miles.arguments.miles_validate_args(args) +def test_miles_validate_args_rechecks_loss_aggregation_after_custom_config(miles, tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text( + "\n".join( + [ + "loss_aggregation: prompt_mean", + "rollout_batch_size: 3", + "n_samples_per_prompt: 2", + "num_steps_per_rollout: 2", + ] + ) + ) + parser = argparse.ArgumentParser() + miles.arguments.get_miles_extra_args_provider()(parser) + args = parser.parse_args( + [ + "--rollout-batch-size", + "4", + "--n-samples-per-prompt", + "2", + "--num-steps-per-rollout", + "2", + "--num-rollout", + "1", + "--custom-config-path", + str(config_path), + ] + ) + + with pytest.raises(ValueError, match="multiple of n_samples_per_prompt"): + miles.arguments.miles_validate_args(args) + + +def test_miles_validate_args_rejects_conflicting_custom_config_global_batch_size(miles, tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text( + "\n".join( + [ + "rollout_batch_size: 3", + "n_samples_per_prompt: 2", + "num_steps_per_rollout: 2", + "global_batch_size: 4", + ] + ) + ) + parser = argparse.ArgumentParser() + miles.arguments.get_miles_extra_args_provider()(parser) + args = parser.parse_args( + [ + "--rollout-batch-size", + "4", + "--n-samples-per-prompt", + "2", + "--num-steps-per-rollout", + "2", + "--num-rollout", + "1", + "--custom-config-path", + str(config_path), + ] + ) + + with pytest.raises(ValueError, match="global_batch_size 4 is not equal"): + miles.arguments.miles_validate_args(args) + + +def test_miles_validate_args_reconciles_token_mean_from_custom_config(miles, tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text("loss_aggregation: token_mean\n") + parser = argparse.ArgumentParser() + miles.arguments.get_miles_extra_args_provider()(parser) + args = parser.parse_args( + [ + "--rollout-batch-size", + "4", + "--n-samples-per-prompt", + "2", + "--num-steps-per-rollout", + "2", + "--num-rollout", + "1", + "--custom-config-path", + str(config_path), + ] + ) + + miles.arguments.miles_validate_args(args) + + assert args.loss_aggregation == "token_mean" + assert args.calculate_per_token_loss is True + + def test_validate_prompt_mean_accepts_multiple_global_batch_size(miles): args = _validate_args(loss_aggregation="prompt_mean", global_batch_size=6, n_samples_per_prompt=2) miles.arguments._validate_loss_aggregation_args(args) @@ -541,6 +801,7 @@ def _make_sample(miles, group_index, index, response_length, loss_mask): def _convert_args(**overrides): base = dict( advantage_estimator="grpo", + balance_data=False, rewards_normalization=False, reward_key=None, use_dynamic_global_batch_size=False, @@ -577,6 +838,45 @@ def test_convert_samples_computes_step_level_prompt_group_denoms(miles): assert train_data["prompt_mask_sums"] == [5, 5, 5, 5] +def test_split_train_data_by_dp_keeps_prompt_groups_whole(miles, monkeypatch): + monkeypatch.setattr(miles.train_data_conversion.ray, "put", lambda value: value) + samples = [ + _make_sample(miles, 0, 0, 3, [1, 1, 0]), + _make_sample(miles, 0, 1, 3, [1, 1, 1]), + _make_sample(miles, 1, 2, 4, [1, 0, 0, 0]), + _make_sample(miles, 1, 3, 4, [1, 1, 1, 1]), + ] + args = _convert_args() + train_data = _convert(miles, samples, args) + + refs = miles.split_train_data_by_dp(args, train_data, dp_size=2) + parts = [ref.inner for ref in refs] + + assert parts[0]["partition"] == [0, 1] + assert parts[1]["partition"] == [2, 3] + assert parts[0]["prompt_group_indices"] == [0, 0] + assert parts[1]["prompt_group_indices"] == [1, 1] + assert parts[0]["prompt_mask_sums"] == [5, 5] + assert parts[1]["prompt_mask_sums"] == [5, 5] + + +def test_split_train_data_by_dp_rejects_undistributable_prompt_groups(miles, monkeypatch): + monkeypatch.setattr(miles.train_data_conversion.ray, "put", lambda value: value) + samples = [ + _make_sample(miles, 0, 0, 2, [1, 1]), + _make_sample(miles, 0, 1, 2, [1, 1]), + _make_sample(miles, 1, 2, 2, [1, 1]), + _make_sample(miles, 1, 3, 2, [1, 1]), + _make_sample(miles, 2, 4, 2, [1, 1]), + _make_sample(miles, 2, 5, 2, [1, 1]), + ] + args = _convert_args() + train_data = _convert(miles, samples, args) + + with pytest.raises(ValueError, match="divisible by dp_size"): + miles.split_train_data_by_dp(args, train_data, dp_size=2) + + def test_convert_samples_prompt_mean_rejects_none_group_index(miles): samples = [_make_sample(miles, None, 0, 2, [1, 1]), _make_sample(miles, None, 1, 2, [1, 0])] with pytest.raises(ValueError, match="group_index"):