From 0fe5e981c0cfef3fb1e4c765ed85ff20a416822b Mon Sep 17 00:00:00 2001 From: EazyReal <8047065+EazyReal@users.noreply.github.com> Date: Wed, 17 Jun 2026 08:26:38 +0000 Subject: [PATCH 1/3] feat(loss): add --loss-aggregation for the four ScaleRL pg_loss modes --- docs/en/get_started/customization.md | 50 ++++++- slime/backends/megatron_utils/actor.py | 11 +- slime/backends/megatron_utils/cp_utils.py | 22 ++- slime/backends/megatron_utils/data.py | 1 + slime/backends/megatron_utils/loss.py | 64 ++++++++- slime/backends/megatron_utils/model.py | 1 + slime/ray/rollout.py | 36 ++++- slime/utils/arguments.py | 97 +++++++++++++- tests/test_cp_utils.py | 115 ++++++++++++++++ tests/test_megatron_argument_validation.py | 148 +++++++++++++++++++++ tests/test_rollout_validation.py | 102 ++++++++++++++ 11 files changed, 627 insertions(+), 20 deletions(-) diff --git a/docs/en/get_started/customization.md b/docs/en/get_started/customization.md index 3f31ee493d..e6d665fa63 100644 --- a/docs/en/get_started/customization.md +++ b/docs/en/get_started/customization.md @@ -295,8 +295,54 @@ def get_pg_loss_reducer( ``` **Use Cases**: -- Dr.GRPO: Divide by a constant instead of effective token count -- Custom loss normalization strategies +- Custom loss normalization strategies not covered by `--loss-aggregation` + +> The four standard loss-aggregation modes (GRPO sample average, DAPO prompt +> average, token average, Dr.GRPO constant divisor) are available first-class via +> `--loss-aggregation` (see below) — no custom reducer needed. Reach for this +> hook only for a normalization those modes do not express. When set, it takes +> precedence over `--loss-aggregation`. + +**Built-in modes — `--loss-aggregation`** + +`--loss-aggregation {sample_mean,prompt_mean,token_mean,constant}` selects how +pg_loss is aggregated across a training step (pg_loss only; every other metric +keeps the default sample-mean reducer — same scope as the custom hook above). 
+Modes follow the ScaleRL taxonomy ([arXiv:2510.13786](https://arxiv.org/abs/2510.13786) §3.2): + +| Mode | Paper | pg_loss denominator | +| :--- | :--- | :--- | +| `sample_mean` (default) | GRPO sample average | Per-rollout token-weighted mean — each rollout contributes equally regardless of fan-out. Byte-identical to slime's prior default. | +| `prompt_mean` | DAPO prompt average | Per-prompt-group token-weighted mean (all rollouts sharing a prompt share one denominator). | +| `token_mean` | token average | Global per-token mean. The legacy `--calculate-per-token-loss` flag is exactly this mode (see below); prefer `--loss-aggregation=token_mean`. | +| `constant` | Dr.GRPO ([arXiv:2503.20783](https://arxiv.org/abs/2503.20783)) | `sum(token_loss * loss_mask) / L`, where `L = --loss-aggregation-divisor` (e.g. the max context length). | + +`--loss-aggregation-divisor L` is required (validated `> 0` at startup) only for +`constant`; it is rejected for the other modes (passing it with any other mode +fails at startup). The default (`sample_mean`) leaves behavior unchanged. + +The `constant` denominator `L` is *per-token*, not per-step: each mode's reducer +returns a step sum that is then averaged by the usual `/ step_global_batch_size` +(identical structure across all four modes). So the effective per-step +normalization for `constant` is `/ (L * step_global_batch_size)` — `L` only sets +the data-independent per-token scale; the `/ step_global_batch_size` step average +is applied on top exactly as for the other modes. + +`prompt_mean` weights every prompt group equally (each group's token-weighted +mean enters the step sum once, all under the same `/ step_global_batch_size` +divisor). Its absolute scale differs from a strict `1/P` DAPO average by a +constant factor (`P / N`, prompts over rollouts), which the learning rate +absorbs; the relative per-prompt weighting is uniform. 
+ +**Legacy `--calculate-per-token-loss`.** `--calculate-per-token-loss` is not a +separate axis: it *is* the `token_mean` point on `--loss-aggregation`. The two +spellings are reconciled at startup so the reported objective is honest and there +is one knob: `--loss-aggregation=token_mean` sets `--calculate-per-token-loss`, +and `--calculate-per-token-loss` (on the default `sample_mean`) is relabeled to +`token_mean` with no behavior change. Existing `--calculate-per-token-loss` +recipes keep working unchanged; prefer `--loss-aggregation=token_mean` in new +recipes. Combining `--calculate-per-token-loss` with `prompt_mean` or `constant` +(distinct objectives) fails loud at startup. --- diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index 6bdc2dd3fe..6615dead88 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -235,12 +235,11 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch: rollout_data["loss_masks"] = [ t.to(device=device, dtype=torch.int, non_blocking=True) for t in rollout_data["loss_masks"] ] - if "rollout_mask_sums" in rollout_data: - # Promote precomputed per-rollout mask totals to GPU tensors here - # (matching loss_masks) so the loss reducer can just divide. 
- rollout_data["rollout_mask_sums"] = rollout_data["rollout_mask_sums"].to( - device=device, dtype=torch.float32, non_blocking=True - ) + for mask_sums_key in ("rollout_mask_sums", "prompt_mask_sums"): + if mask_sums_key in rollout_data: + rollout_data[mask_sums_key] = rollout_data[mask_sums_key].to( + device=device, dtype=torch.float32, non_blocking=True + ) if "multimodal_train_inputs" in rollout_data: # Move multimodal training tensors to GPU in advance rollout_data["multimodal_train_inputs"] = [ diff --git a/slime/backends/megatron_utils/cp_utils.py b/slime/backends/megatron_utils/cp_utils.py index a97c45cc42..aa58a1a756 100644 --- a/slime/backends/megatron_utils/cp_utils.py +++ b/slime/backends/megatron_utils/cp_utils.py @@ -50,6 +50,7 @@ def get_sum_of_sample_mean( loss_masks: list[torch.Tensor], sample_denoms: list[torch.Tensor] | torch.Tensor | None = None, calculate_per_token_loss: bool = False, + constant_divisor: float | None = None, ) -> Callable[[torch.Tensor], torch.Tensor]: """ Calculate correct sample mean for CP. @@ -63,7 +64,17 @@ def get_sum_of_sample_mean( step level rather than per-mb is required — otherwise a rollout whose samples land in different micro-batches would get a partial denominator on each side. + + ``constant_divisor`` (Dr.GRPO, arXiv:2503.20783) divides the masked + token-sum by a fixed ``L``; being identical on every CP rank, Megatron's + gradient sum-allreduce already yields the full-batch value (no extra + all-reduce). Mutually exclusive with ``calculate_per_token_loss``. """ + if constant_divisor is not None and calculate_per_token_loss: + raise ValueError( + "constant_divisor (loss-aggregation=constant) and calculate_per_token_loss " + "(loss-aggregation=token_mean) are mutually exclusive aggregation modes." 
+ ) if sample_denoms is None: sample_denoms = [m.sum() for m in loss_masks] @@ -75,7 +86,7 @@ def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor: [ (x_i * loss_mask_i).sum() / torch.clamp_min(denom, 1) for x_i, loss_mask_i, denom in zip( - x.split(response_lengths, dim=0), loss_masks, sample_denoms, strict=False + x.split(response_lengths, dim=0), loss_masks, sample_denoms, strict=True ) ] ) @@ -106,7 +117,7 @@ def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor: [ (x_i * chunked_loss_mask).sum() / torch.clamp_min(denom, 1) for x_i, chunked_loss_mask, denom in zip( - x.split(cp_chunk_lengths, dim=0), chunked_loss_masks, sample_denoms, strict=False + x.split(cp_chunk_lengths, dim=0), chunked_loss_masks, sample_denoms, strict=True ) ] ) @@ -121,6 +132,13 @@ def sum_of_token(x: torch.Tensor) -> torch.Tensor: ] ) + if constant_divisor is not None: + + def sum_of_constant(x: torch.Tensor) -> torch.Tensor: + return sum_of_token(x) / constant_divisor + + return sum_of_constant + return sum_of_sample_mean if not calculate_per_token_loss else sum_of_token diff --git a/slime/backends/megatron_utils/data.py b/slime/backends/megatron_utils/data.py index 00f319928f..a17c4dd20f 100644 --- a/slime/backends/megatron_utils/data.py +++ b/slime/backends/megatron_utils/data.py @@ -286,6 +286,7 @@ def log_rollout_data( "rollout_mask_sums", "rollout_top_p_token_ids", "rollout_top_p_token_offsets", + "prompt_mask_sums", "rollout_routed_experts", "global_batch_sizes", "num_microbatches", diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 72afdfa66c..50a504a50a 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -874,6 +874,54 @@ def icepop_function( return pg_loss, loss_masks, metrics +def get_pg_loss_reducer( + args: Namespace, + batch: RolloutBatch, + *, + total_lengths: list[int], + response_lengths: list[int], + pg_loss_masks: list[torch.Tensor], + default_reducer: 
Callable[[torch.Tensor], torch.Tensor], +) -> Callable[[torch.Tensor], torch.Tensor]: + """The ``--loss-aggregation`` reducer for pg_loss. ``sample_mean`` returns + ``default_reducer`` unchanged, keeping the prior default byte-identical. + """ + mode = getattr(args, "loss_aggregation", "sample_mean") + if mode in ("sample_mean", "token_mean"): + # token_mean is aliased onto --calculate-per-token-loss, so + # default_reducer is already the per-token path. + return default_reducer + if mode == "prompt_mean": + prompt_mask_sums = batch.get("prompt_mask_sums") + if prompt_mask_sums is None: + # None would silently fall back to the per-sample mean; a custom + # convert path that drops prompt_mask_sums must fail, not degrade. + raise ValueError( + "--loss-aggregation=prompt_mean requires per-prompt-group mask sums " + "(batch['prompt_mask_sums']), but they are missing. A custom " + "--custom-convert-samples-to-train-data-path must populate " + "'prompt_mask_sums' (grouped by Sample.group_index)." + ) + # Never pass calculate_per_token_loss here: prompt_mean is a per-prompt-group + # token-weighted mean (sample_denoms=prompt_mask_sums); the per-token path would + # discard those denominators. Validation already rejects the combo — this keeps the + # reducer structurally unable to return the global per-token mean even if bypassed. + return get_sum_of_sample_mean( + total_lengths, + response_lengths, + pg_loss_masks, + prompt_mask_sums, + ) + if mode == "constant": + return get_sum_of_sample_mean( + total_lengths, + response_lengths, + pg_loss_masks, + constant_divisor=args.loss_aggregation_divisor, + ) + raise ValueError(f"Unknown --loss-aggregation mode: {mode!r}") + + def policy_loss_function( args: Namespace, batch: RolloutBatch, @@ -1024,16 +1072,24 @@ def policy_loss_function( args.calculate_per_token_loss, ) - # Determine pg_loss reducer: use custom if specified, otherwise default + # Under TIS/RS rejected tokens are zeroed in modified_response_masks. 
+ pg_loss_masks = modified_response_masks if (args.get_mismatch_metrics or args.use_tis) else batch["loss_masks"] + + # Custom reducer path takes precedence over --loss-aggregation. if getattr(args, "custom_pg_loss_reducer_function_path", None) is not None: custom_pg_loss_reducer_func = load_function(args.custom_pg_loss_reducer_function_path) - # Determine which loss_masks to use for pg_loss reducer - pg_loss_masks = modified_response_masks if (args.get_mismatch_metrics or args.use_tis) else batch["loss_masks"] pg_loss_reducer = custom_pg_loss_reducer_func( total_lengths, response_lengths, pg_loss_masks, args.calculate_per_token_loss ) else: - pg_loss_reducer = sum_of_sample_mean + pg_loss_reducer = get_pg_loss_reducer( + args, + batch, + total_lengths=total_lengths, + response_lengths=response_lengths, + pg_loss_masks=pg_loss_masks, + default_reducer=sum_of_sample_mean, + ) pg_loss = pg_loss_reducer(pg_loss) pg_clipfrac = sum_of_sample_mean(pg_clipfrac) diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py index 1ad6cd7957..5c972aaf53 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -593,6 +593,7 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p "rollout_log_probs", "teacher_log_probs", "rollout_mask_sums", + "prompt_mask_sums", ], ), args.data_pad_size_multiplier, diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index ff571101af..fed0537545 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -95,11 +95,12 @@ def _tensorize_rollout_data_for_training(rollout_data: dict[str, Any]) -> None: for mm_dict in rollout_data["multimodal_train_inputs"] ] - if "rollout_mask_sums" in rollout_data: - rollout_data["rollout_mask_sums"] = _cpu_tensor( - rollout_data["rollout_mask_sums"], - dtype=torch.float32, - ) + for mask_sums_key in ("rollout_mask_sums", "prompt_mask_sums"): + if mask_sums_key in rollout_data: + 
rollout_data[mask_sums_key] = _cpu_tensor( + rollout_data[mask_sums_key], + dtype=torch.float32, + ) @dataclasses.dataclass @@ -777,6 +778,30 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl rollout_total_mask[rid] = rollout_total_mask.get(rid, 0) + ms train_data["rollout_mask_sums"] = [rollout_total_mask[rid] for rid in rollout_id_list] + # prompt_mask_sums: per-prompt-group mask total, summed here at the step + # level (every sibling is visible) and broadcast per-sample so a group + # split across micro-batches by packing still divides by its whole total. + # Built only under prompt_mean — the other modes never read it, so the + # default (sample_mean) batch stays byte-identical with no extra key. + # The broadcast per-sample full-group denom is step/mb-independent, but + # prompt_mean still requires the whole group within one step (enforced + # in slime_validate_args via global_batch_size % n_samples_per_prompt). + if getattr(self.args, "loss_aggregation", "sample_mean") == "prompt_mean": + group_total_mask: dict[int, int] = {} + for sample, ms in zip(samples, mask_sums_per_sample, strict=True): + # A None group_index would collapse unrelated prompts into one + # denominator, silently degrading prompt_mean -> sample_mean for + # that sample. The prompt-grouping invariant is violated, so fail. + if sample.group_index is None: + raise ValueError( + "--loss-aggregation prompt_mean requires every Sample.group_index to be set, " + "but a sample has group_index=None. prompt_mean divides each sample by its " + "prompt group's total mask; a None group_index means the sample belongs to no " + "prompt group, so its denominator is undefined." + ) + group_total_mask[sample.group_index] = group_total_mask.get(sample.group_index, 0) + ms + train_data["prompt_mask_sums"] = [group_total_mask[sample.group_index] for sample in samples] + # Overwrite raw_reward when available. 
Mixed-source batches may only # populate this field for a subset of samples (e.g. SWE but not code). if any(sample.metadata and "raw_reward" in sample.metadata for sample in samples): @@ -866,6 +891,7 @@ def _split_train_data_by_dp(self, data): "sample_indices", "rollout_ids", "rollout_mask_sums", + "prompt_mask_sums", "rollout_log_probs", "rollout_top_p_token_ids", "rollout_top_p_token_offsets", diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index d5cac9d44b..66c8d6bbad 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -28,6 +28,8 @@ def reset_arg(parser, name, **kwargs): if name in action.option_strings: if "default" in kwargs: action.default = kwargs["default"] + if "help" in kwargs: + action.help = kwargs["help"] break else: parser.add_argument(name, **kwargs) @@ -843,7 +845,16 @@ def add_algo_arguments(parser): ) reset_arg(parser, "--seed", type=int, default=1234) reset_arg(parser, "--clip-grad", type=float, default=1.0) - reset_arg(parser, "--calculate-per-token-loss", action="store_true") + reset_arg( + parser, + "--calculate-per-token-loss", + action="store_true", + help=( + "Legacy alias for --loss-aggregation=token_mean (the global per-token mean); " + "reconciled onto that mode at startup. Prefer --loss-aggregation=token_mean. " + "Incompatible with --loss-aggregation prompt_mean/constant." + ), + ) reset_arg(parser, "--lr", type=float, default=1e-6) parser.add_argument( @@ -1054,6 +1065,44 @@ def add_algo_arguments(parser): default=None, help="Path to a custom reducer function for pg_loss only. When set, pg_loss will use this custom reducer while other metrics (pg_clipfrac, ppo_kl, entropy_loss, etc.) 
still use the default sum_of_sample_mean.", ) + parser.add_argument( + "--loss-aggregation", + type=str, + default="sample_mean", + choices=["sample_mean", "prompt_mean", "token_mean", "constant"], + help=( + "How pg_loss is aggregated across the step (applies to pg_loss only; " + "pg_clipfrac, ppo_kl, entropy_loss, kl_loss keep the default sample-mean " + "reducer — same scope as --custom-pg-loss-reducer-function-path, which still " + "takes precedence when set). Modes follow the ScaleRL taxonomy " + "(arXiv:2510.13786 §3.2): " + "'sample_mean' (default; GRPO sample average) — each rollout's tokens are " + "averaged with the per-rollout token-weighted denominator, so every rollout " + "contributes equally regardless of fan-out (byte-identical to slime's prior " + "default); " + "'prompt_mean' (DAPO prompt average) — tokens are averaged over each prompt " + "group (all rollouts sharing a Sample.group_index share one denominator); " + "'token_mean' (token average) — global per-token mean; the legacy " + "--calculate-per-token-loss flag is exactly this mode (and is reconciled onto " + "it at startup); " + "'constant' (Dr.GRPO, arXiv:2503.20783) — masked token sum divided by a fixed " + "--loss-aggregation-divisor (e.g. the max context length)." + ), + ) + parser.add_argument( + "--loss-aggregation-divisor", + type=float, + default=None, + help=( + "Constant divisor L for --loss-aggregation=constant (Dr.GRPO). pg_loss is " + "aggregated as sum(token_loss * loss_mask) / L instead of any data-dependent " + "denominator. L is per-token: the usual / step_global_batch_size step average " + "is then applied on top (same structure as every mode), so the effective " + "per-step normalization is / (L * step_global_batch_size); L only sets the " + "data-independent per-token scale. Required and validated > 0 at startup only " + "when --loss-aggregation=constant; setting it with any other mode fails loud." 
+ ), + ) parser.add_argument( "--use-routing-replay", @@ -1825,6 +1874,41 @@ def slime_validate_args(args): assert not (args.kl_coef != 0 and args.kl_loss_coef != 0), "Only one of kl_coef and kl_loss_coef can be set" + loss_aggregation = getattr(args, "loss_aggregation", "sample_mean") + # 1. Divisor only feeds the `constant` reducer: it is required+positive there and a + # silent no-op everywhere else, so fail loud rather than mislead about normalization. + divisor = getattr(args, "loss_aggregation_divisor", None) + if loss_aggregation == "constant": + # Dr.GRPO needs a fixed, positive divisor; fail at startup, not mid-train. + if divisor is None or not (divisor > 0): + raise ValueError( + "--loss-aggregation-divisor must be set to a positive value when " + f"--loss-aggregation=constant (got {divisor!r})." + ) + elif divisor is not None: + raise ValueError( + "--loss-aggregation-divisor is only used with --loss-aggregation=constant " + f"(got --loss-aggregation={loss_aggregation})." + ) + # 2. --calculate-per-token-loss is the legacy spelling of --loss-aggregation=token_mean; + # reconcile the two so there is one coherent axis (no silent label override). + if loss_aggregation == "token_mean": + # Forward: token_mean drives the per-token reducer/normalizer path. + args.calculate_per_token_loss = True + elif getattr(args, "calculate_per_token_loss", False): + if loss_aggregation == "sample_mean": + # Backward: the legacy flag *is* token_mean (sample_mean already returned the + # per-token default reducer here); relabel so the reported objective is honest. + loss_aggregation = args.loss_aggregation = "token_mean" + else: + # prompt_mean/constant are distinct objectives from the global per-token mean; + # --calculate-per-token-loss would silently override the reducer's denominator + # (prompt_mean's per-group total / constant's L). Fail loud instead. 
+ raise ValueError( + f"--loss-aggregation={loss_aggregation} is incompatible with --calculate-per-token-loss " + "(use --loss-aggregation=token_mean for the per-token mean)." + ) + if args.advantage_estimator in ["reinforce_plus_plus", "reinforce_plus_plus_baseline"]: assert args.normalize_advantages, ( "The 'reinforce_plus_plus' and 'reinforce_plus_plus_baseline' advantage estimators " @@ -1948,6 +2032,17 @@ def slime_validate_args(args): ) args.global_batch_size = global_batch_size + if ( + getattr(args, "loss_aggregation", "sample_mean") == "prompt_mean" + and args.global_batch_size is not None + and args.global_batch_size % args.n_samples_per_prompt != 0 + ): + raise ValueError( + "--loss-aggregation prompt_mean requires global_batch_size to be a multiple of " + "n_samples_per_prompt so each prompt group stays within one training step " + f"(got global_batch_size={args.global_batch_size}, n_samples_per_prompt={args.n_samples_per_prompt})." + ) + if args.n_samples_per_prompt == 1: args.grpo_std_normalization = False logger.info("n_samples_per_prompt is set to 1, grpo_std_normalization will be set to False.") diff --git a/tests/test_cp_utils.py b/tests/test_cp_utils.py index c7d3abe9a2..8f900ccd4f 100644 --- a/tests/test_cp_utils.py +++ b/tests/test_cp_utils.py @@ -176,5 +176,120 @@ def test_cp_chunking_preserves_per_rollout_mean_report(monkeypatch): assert cp_total == pytest.approx(baseline) +@pytest.mark.unit +def test_constant_divisor_divides_masked_token_sum_by_L(): + """``constant`` (Dr.GRPO) aggregation: masked token sum / L, NOT any + data-dependent denominator.""" + total_lengths, response_lengths, loss_masks = _make_inputs([3, 3]) + L = 40.0 + reducer = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks, constant_divisor=L) + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + # sum of all masked tokens = 21; / 40 = 0.525. 
+ assert reducer(x).item() == pytest.approx(21.0 / L) + + +@pytest.mark.unit +def test_prompt_mean_denom_is_per_group_token_sum(): + """``prompt_mean`` (DAPO): two prompts × G rollouts. Each sample's + denominator is its WHOLE prompt-group's mask total (all rollouts of that + prompt), distinct from per-rollout (sample_mean) and per-token (token_mean). + + Fixture: prompt P0 has 2 rollouts of length 2 (group mask sum = 4); + prompt P1 has 2 rollouts of length 4 (group mask sum = 8). + """ + # 4 samples: [P0r0=2, P0r1=2, P1r0=4, P1r1=4]. + total_lengths, response_lengths, loss_masks = _make_inputs([2, 2, 4, 4]) + # prompt_mask_sums: P0 group total = 2+2 = 4 (both P0 samples); P1 = 4+4 = 8. + prompt_denoms = _denoms(4, 4, 8, 8) + reducer = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks, prompt_denoms) + + # x laid out per sample: P0r0=[1,1], P0r1=[1,1], P1r0=[1,1,1,1], P1r1=[1,1,1,1] + x = torch.tensor([1.0] * 2 + [1.0] * 2 + [1.0] * 4 + [1.0] * 4) + # P0 group mean: (sum of P0 tokens)/4 = 4/4 = 1. P1 group mean: 8/8 = 1. Sum = 2. + assert reducer(x).item() == pytest.approx(2.0) + + # Now make the per-prompt content uneven so prompt_mean is numerically + # distinct from both sample_mean and token_mean. + x2 = torch.tensor([2.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) + # prompt_mean: P0 sum = 2, /4 = 0.5; P1 sum = 8, /8 = 1.0 → 1.5. + assert reducer(x2).item() == pytest.approx(1.5) + # sample_mean (per-rollout denoms = own length): P0r0 2/2=1, P0r1 0/2=0, + # P1r0 4/4=1, P1r1 4/4=1 → 3.0. Distinct from prompt_mean. + sample_mean = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks) + assert sample_mean(x2).item() == pytest.approx(3.0) + assert sample_mean(x2).item() != pytest.approx(reducer(x2).item()) + # token_mean numerator (raw masked sum) = 10; distinct again. 
+ token_sum = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks, calculate_per_token_loss=True) + assert token_sum(x2).item() == pytest.approx(10.0) + assert token_sum(x2).item() != pytest.approx(reducer(x2).item()) + + +@pytest.mark.unit +def test_constant_and_per_token_loss_are_mutually_exclusive(): + """The constant divisor and per-token-loss are distinct aggregation modes; + asking for both is a configuration error, rejected eagerly.""" + total_lengths, response_lengths, loss_masks = _make_inputs([3]) + with pytest.raises(ValueError, match="mutually exclusive"): + get_sum_of_sample_mean( + total_lengths, + response_lengths, + loss_masks, + calculate_per_token_loss=True, + constant_divisor=40.0, + ) + + +@pytest.mark.unit +def test_cp_chunking_preserves_constant_divisor(monkeypatch): + """CP rank-sum invariance for the constant divisor: the divisor is identical + on every CP rank, so summing per-rank reducer outputs reproduces cp=1.""" + from megatron.core import mpu as _mpu + + total_lengths = [12, 12] + response_lengths = [8, 8] + loss_masks = [torch.ones(r, dtype=torch.float32) for r in response_lengths] + L = 40.0 + x_full = [ + torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), + torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0]), + ] + x_concat = torch.cat(x_full) + + monkeypatch.setattr(_mpu, "get_context_parallel_world_size", lambda: 1) + monkeypatch.setattr(_mpu, "get_context_parallel_rank", lambda: 0) + reducer_cp1 = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks, constant_divisor=L) + baseline = reducer_cp1(x_concat).item() + + monkeypatch.setattr(_mpu, "get_context_parallel_world_size", lambda: 2) + cp_total = 0.0 + for cp_rank in range(2): + monkeypatch.setattr(_mpu, "get_context_parallel_rank", lambda r=cp_rank: r) + x_chunks_per_sample = [] + for tl, rl, x in zip(total_lengths, response_lengths, x_full, strict=True): + prompt_length = tl - rl + _, _, _, tokens_offset = 
get_logits_and_tokens_offset_with_cp(tl, rl) + chunk_0 = x[tokens_offset[0][0] - prompt_length : tokens_offset[0][1] - prompt_length] + chunk_1 = x[tokens_offset[1][0] - prompt_length : tokens_offset[1][1] - prompt_length] + x_chunks_per_sample.append(torch.cat([chunk_0, chunk_1])) + x_for_rank = torch.cat(x_chunks_per_sample) + reducer_cp2 = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks, constant_divisor=L) + cp_total += reducer_cp2(x_for_rank).item() + + assert cp_total == pytest.approx(baseline) + + +@pytest.mark.unit +def test_sample_denoms_length_mismatch_fails_loud(): + """``sample_denoms`` is one denominator per sample by construction (built parallel + to ``loss_masks``). The reducer zips them ``strict=True``, so a caller that supplies + a mismatched-length ``sample_denoms`` fails loud instead of silently dropping samples.""" + total_lengths, response_lengths, loss_masks = _make_inputs([3, 3, 3]) + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]) + # 3 samples but only 2 denoms — a construction bug the strict zip must surface. 
+ reducer = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks, _denoms(9, 9)) + with pytest.raises(ValueError, match="zip"): + reducer(x) + + if __name__ == "__main__": raise SystemExit(pytest.main([__file__])) diff --git a/tests/test_megatron_argument_validation.py b/tests/test_megatron_argument_validation.py index db5c0bfd81..c684ae05db 100644 --- a/tests/test_megatron_argument_validation.py +++ b/tests/test_megatron_argument_validation.py @@ -303,6 +303,9 @@ def make_slime_validate_args(**overrides): update_weight_disk_dir=None, update_weight_delta_dir=None, update_weight_mode="full", + loss_aggregation="sample_mean", + loss_aggregation_divisor=None, + calculate_per_token_loss=False, ) values.update(overrides) return types.SimpleNamespace(**values) @@ -351,6 +354,151 @@ def test_slime_validate_args_preserves_zero_rollout_gpus_without_colocate(monkey assert args.offload_rollout is False +@pytest.mark.unit +@pytest.mark.parametrize("divisor", [None, 0.0, -1.0, float("nan")]) +def test_loss_aggregation_constant_rejects_nonpositive_divisor(monkeypatch, divisor): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(loss_aggregation="constant", loss_aggregation_divisor=divisor) + + with pytest.raises(ValueError, match="loss-aggregation-divisor"): + module.slime_validate_args(args) + + +@pytest.mark.unit +def test_loss_aggregation_constant_accepts_positive_divisor(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(loss_aggregation="constant", loss_aggregation_divisor=40960.0) + + module.slime_validate_args(args) + + assert args.loss_aggregation_divisor == 40960.0 + + +@pytest.mark.unit +def test_loss_aggregation_token_mean_aliases_calculate_per_token_loss(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(loss_aggregation="token_mean", calculate_per_token_loss=False) + + module.slime_validate_args(args) + + assert 
args.calculate_per_token_loss is True + + +@pytest.mark.unit +def test_calculate_per_token_loss_alone_reconciles_to_token_mean(monkeypatch): + # Legacy spelling: --calculate-per-token-loss on the default sample_mean IS token_mean. + # Reconcile the label (no behavior change) so there is one honest axis. + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(calculate_per_token_loss=True) # loss_aggregation default sample_mean + + module.slime_validate_args(args) + + assert args.loss_aggregation == "token_mean" + assert args.calculate_per_token_loss is True + + +@pytest.mark.unit +def test_loss_aggregation_default_leaves_per_token_loss_off(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args() # default sample_mean + + module.slime_validate_args(args) + + assert args.calculate_per_token_loss is False + # No divisor required for non-constant modes. + assert args.loss_aggregation == "sample_mean" + + +@pytest.mark.unit +def test_loss_aggregation_prompt_mean_rejects_calculate_per_token_loss(monkeypatch): + # prompt_mean + --calculate-per-token-loss would silently degrade to the global + # per-token mean in the reducer, so validation must fail loud. 
+ module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(loss_aggregation="prompt_mean", calculate_per_token_loss=True) + + with pytest.raises(ValueError, match="prompt_mean is incompatible with --calculate-per-token-loss"): + module.slime_validate_args(args) + + +@pytest.mark.unit +def test_loss_aggregation_prompt_mean_without_per_token_loss_passes(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(loss_aggregation="prompt_mean", calculate_per_token_loss=False) + + module.slime_validate_args(args) + + assert args.loss_aggregation == "prompt_mean" + assert args.calculate_per_token_loss is False + + +@pytest.mark.unit +def test_loss_aggregation_constant_rejects_calculate_per_token_loss(monkeypatch): + # constant + --calculate-per-token-loss is silently inconsistent (the reducer makes them + # mutually exclusive); fail loud at startup instead. + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args( + loss_aggregation="constant", loss_aggregation_divisor=40960.0, calculate_per_token_loss=True + ) + + with pytest.raises(ValueError, match="constant is incompatible with --calculate-per-token-loss"): + module.slime_validate_args(args) + + +@pytest.mark.unit +@pytest.mark.parametrize("mode", ["sample_mean", "token_mean", "prompt_mean"]) +def test_loss_aggregation_divisor_rejected_on_non_constant_modes(monkeypatch, mode): + # A stray --loss-aggregation-divisor on a non-constant mode is silently ignored by the + # reducer; fail loud so the user is not misled about the normalization. 
+ module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(loss_aggregation=mode, loss_aggregation_divisor=40960.0) + + with pytest.raises(ValueError, match="loss-aggregation-divisor is only used with --loss-aggregation=constant"): + module.slime_validate_args(args) + + +@pytest.mark.unit +def test_loss_aggregation_prompt_mean_rejects_non_multiple_global_batch_size(monkeypatch): + # prompt_mean normalizes per prompt group, but dp_schedule slices each step as a contiguous + # run of global_batch_size rollouts. When global_batch_size is not a multiple of + # n_samples_per_prompt a prompt group straddles a step boundary and its per-group + # normalization fragments across optimizer updates, so validation must fail loud. + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(loss_aggregation="prompt_mean", n_samples_per_prompt=4, global_batch_size=6) + + with pytest.raises(ValueError, match="prompt_mean requires global_batch_size to be a multiple"): + module.slime_validate_args(args) + + +@pytest.mark.unit +def test_loss_aggregation_prompt_mean_accepts_multiple_global_batch_size(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(loss_aggregation="prompt_mean", n_samples_per_prompt=4, global_batch_size=8) + + module.slime_validate_args(args) + + assert args.loss_aggregation == "prompt_mean" + assert args.global_batch_size == 8 + + +@pytest.mark.unit +@pytest.mark.parametrize("mode", ["sample_mean", "token_mean", "constant"]) +def test_loss_aggregation_non_prompt_mean_allows_non_multiple_global_batch_size(monkeypatch, mode): + # The step-straddle guard is specific to prompt_mean's per-group denominator; the other + # modes do not normalize per prompt group, so the same non-multiple gbs must NOT trip it. 
+ module = load_slime_arguments_module(monkeypatch) + divisor = 40960.0 if mode == "constant" else None + args = make_slime_validate_args( + loss_aggregation=mode, + loss_aggregation_divisor=divisor, + n_samples_per_prompt=4, + global_batch_size=6, + ) + + module.slime_validate_args(args) + + assert args.global_batch_size == 6 + + @pytest.mark.unit def test_update_weight_delta_rejects_colocate(monkeypatch): module = load_slime_arguments_module(monkeypatch) diff --git a/tests/test_rollout_validation.py b/tests/test_rollout_validation.py index fa2ff34335..dc7626e9a9 100644 --- a/tests/test_rollout_validation.py +++ b/tests/test_rollout_validation.py @@ -1,3 +1,5 @@ +import types + import pytest from slime.ray.rollout_validation import validate_server_group_gpu_indices @@ -59,5 +61,105 @@ def test_validate_server_group_gpu_indices_reports_config_context(): assert "rollout_num_gpus_per_engine=2" in message +# --------------------------------------------------------------------------- +# _convert_samples_to_train_data: prompt_mask_sums build (--loss-aggregation) +# --------------------------------------------------------------------------- +# +# prompt_mask_sums (the per-prompt-group denominator for prompt_mean) is built +# only under --loss-aggregation=prompt_mean, and that build fails loud if any +# sample is missing its prompt group (group_index is None) — a None would +# silently collapse unrelated prompts into one denominator, degrading +# prompt_mean to sample_mean for that sample. The other three modes never read +# group_index, so they must neither build the key nor consult group_index. 
+ + +def _make_convert_manager(loss_aggregation): + """A bare RolloutManager (no Ray/sglang init) wired just enough to call + ``_convert_samples_to_train_data``: no custom hooks, and reward + post-processing reduced to identity (advantage_estimator outside the + group-norm set), so the only behavior under test is the prompt_mask_sums + build + its group_index guard.""" + from slime.ray.rollout import RolloutManager + + # RolloutManager is @ray.remote-decorated (an ActorClass); unwrap to the plain + # class so a bare instance can be built to exercise the method directly. + meta = getattr(RolloutManager, "__ray_metadata__", None) + cls = meta.modified_class if meta is not None else RolloutManager + manager = cls.__new__(cls) + manager.custom_convert_samples_to_train_data_func = None + manager.custom_reward_post_process_func = None + manager.args = types.SimpleNamespace( + loss_aggregation=loss_aggregation, + reward_key=None, + advantage_estimator="reinforce", # outside the group-norm reshape path + rewards_normalization=False, + grpo_std_normalization=False, + ) + return manager + + +def _make_grouped_samples(group_indices): + """One Sample per entry; each carries a length-2 loss_mask so the + per-group mask totals are non-trivial.""" + from slime.utils.types import Sample + + samples = [] + for i, gid in enumerate(group_indices): + samples.append( + Sample( + index=i, + group_index=gid, + rollout_id=i, + tokens=[0, 1, 2, 3], + response_length=2, + reward=0.0, + loss_mask=[1, 1], + ) + ) + return samples + + +@pytest.mark.unit +def test_prompt_mean_fails_loud_on_none_group_index(): + """prompt_mean with a None group_index is a real break (the prompt-grouping + invariant is violated), so the convert step must raise — not silently + renumber the sample into its own singleton group.""" + pytest.importorskip("sglang") # RolloutManager import pulls sglang + manager = _make_convert_manager("prompt_mean") + samples = _make_grouped_samples([0, 0, None, 1]) + + with 
pytest.raises(ValueError, match="group_index"): + manager._convert_samples_to_train_data(samples) + + +@pytest.mark.unit +def test_prompt_mean_builds_per_group_mask_sums(): + """Sanity: with every group_index set, prompt_mask_sums is the per-group + mask total broadcast per sample (group 0 has two length-2 samples → 4).""" + pytest.importorskip("sglang") + manager = _make_convert_manager("prompt_mean") + samples = _make_grouped_samples([0, 0, 1]) # group 0: 2 samples, group 1: 1 + + train_data = manager._convert_samples_to_train_data(samples) + + # group 0 = 2+2 = 4 (both samples), group 1 = 2. + assert train_data["prompt_mask_sums"] == [4, 4, 2] + + +@pytest.mark.unit +@pytest.mark.parametrize("mode", ["sample_mean", "constant", "token_mean"]) +def test_non_prompt_mean_modes_ignore_none_group_index(mode): + """The other three modes never read group_index and never build + prompt_mask_sums, so a None group_index must NOT raise and the key must be + absent (keeping the default batch byte-identical — no extra key).""" + pytest.importorskip("sglang") + manager = _make_convert_manager(mode) + samples = _make_grouped_samples([0, None, 1]) + + train_data = manager._convert_samples_to_train_data(samples) + + assert "prompt_mask_sums" not in train_data + + if __name__ == "__main__": raise SystemExit(pytest.main([__file__])) From fa69ef7a88fb20aab7434e9e35eb7552121cae0a Mon Sep 17 00:00:00 2001 From: EazyReal <8047065+EazyReal@users.noreply.github.com> Date: Thu, 25 Jun 2026 17:35:00 -0700 Subject: [PATCH 2/3] fix(loss): normalize prompt_mean by prompt count --- docs/en/get_started/customization.md | 7 +-- slime/backends/megatron_utils/loss.py | 27 ++++------ slime/ray/rollout.py | 13 +---- slime/utils/arguments.py | 19 ++----- tests/test_cp_utils.py | 58 +++++++++++----------- tests/test_megatron_argument_validation.py | 19 +------ tests/test_rollout_validation.py | 34 +------------ 7 files changed, 52 insertions(+), 125 deletions(-) diff --git 
a/docs/en/get_started/customization.md b/docs/en/get_started/customization.md index e6d665fa63..4717a66595 100644 --- a/docs/en/get_started/customization.md +++ b/docs/en/get_started/customization.md @@ -328,11 +328,8 @@ normalization for `constant` is `/ (L * step_global_batch_size)` — `L` only se the data-independent per-token scale; the `/ step_global_batch_size` step average is applied on top exactly as for the other modes. -`prompt_mean` weights every prompt group equally (each group's token-weighted -mean enters the step sum once, all under the same `/ step_global_batch_size` -divisor). Its absolute scale differs from a strict `1/P` DAPO average by a -constant factor (`P / N`, prompts over rollouts), which the learning rate -absorbs; the relative per-prompt weighting is uniform. +`prompt_mean` weights every prompt group equally: each group's token-weighted +mean contributes once, and the final scalar is the mean over prompt groups. **Legacy `--calculate-per-token-loss`.** `--calculate-per-token-loss` is not a separate axis: it *is* the `token_mean` point on `--loss-aggregation`. The two diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 50a504a50a..88d760b7a7 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -883,35 +883,32 @@ def get_pg_loss_reducer( pg_loss_masks: list[torch.Tensor], default_reducer: Callable[[torch.Tensor], torch.Tensor], ) -> Callable[[torch.Tensor], torch.Tensor]: - """The ``--loss-aggregation`` reducer for pg_loss. ``sample_mean`` returns - ``default_reducer`` unchanged, keeping the prior default byte-identical. - """ - mode = getattr(args, "loss_aggregation", "sample_mean") + """Return the pg_loss reducer selected by ``--loss-aggregation``.""" + mode = args.loss_aggregation if mode in ("sample_mean", "token_mean"): - # token_mean is aliased onto --calculate-per-token-loss, so - # default_reducer is already the per-token path. 
return default_reducer if mode == "prompt_mean": prompt_mask_sums = batch.get("prompt_mask_sums") if prompt_mask_sums is None: - # None would silently fall back to the per-sample mean; a custom - # convert path that drops prompt_mask_sums must fail, not degrade. raise ValueError( "--loss-aggregation=prompt_mean requires per-prompt-group mask sums " "(batch['prompt_mask_sums']), but they are missing. A custom " "--custom-convert-samples-to-train-data-path must populate " "'prompt_mask_sums' (grouped by Sample.group_index)." ) - # Never pass calculate_per_token_loss here: prompt_mean is a per-prompt-group - # token-weighted mean (sample_denoms=prompt_mask_sums); the per-token path would - # discard those denominators. Validation already rejects the combo — this keeps the - # reducer structurally unable to return the global per-token mean even if bypassed. - return get_sum_of_sample_mean( + reducer = get_sum_of_sample_mean( total_lengths, response_lengths, pg_loss_masks, prompt_mask_sums, ) + + # The train step divides non-token losses by response count; scale the + # prompt-group sum so the final scalar is the exact mean over prompts. + def prompt_mean_reducer(x: torch.Tensor) -> torch.Tensor: + return reducer(x) * args.n_samples_per_prompt + + return prompt_mean_reducer if mode == "constant": return get_sum_of_sample_mean( total_lengths, @@ -1072,11 +1069,9 @@ def policy_loss_function( args.calculate_per_token_loss, ) - # Under TIS/RS rejected tokens are zeroed in modified_response_masks. pg_loss_masks = modified_response_masks if (args.get_mismatch_metrics or args.use_tis) else batch["loss_masks"] - # Custom reducer path takes precedence over --loss-aggregation. 
- if getattr(args, "custom_pg_loss_reducer_function_path", None) is not None: + if args.custom_pg_loss_reducer_function_path is not None: custom_pg_loss_reducer_func = load_function(args.custom_pg_loss_reducer_function_path) pg_loss_reducer = custom_pg_loss_reducer_func( total_lengths, response_lengths, pg_loss_masks, args.calculate_per_token_loss diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index fed0537545..9de8cade61 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -778,20 +778,9 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl rollout_total_mask[rid] = rollout_total_mask.get(rid, 0) + ms train_data["rollout_mask_sums"] = [rollout_total_mask[rid] for rid in rollout_id_list] - # prompt_mask_sums: per-prompt-group mask total, summed here at the step - # level (every sibling is visible) and broadcast per-sample so a group - # split across micro-batches by packing still divides by its whole total. - # Built only under prompt_mean — the other modes never read it, so the - # default (sample_mean) batch stays byte-identical with no extra key. - # The broadcast per-sample full-group denom is step/mb-independent, but - # prompt_mean still requires the whole group within one step (enforced - # in slime_validate_args via global_batch_size % n_samples_per_prompt). - if getattr(self.args, "loss_aggregation", "sample_mean") == "prompt_mean": + if self.args.loss_aggregation == "prompt_mean": group_total_mask: dict[int, int] = {} for sample, ms in zip(samples, mask_sums_per_sample, strict=True): - # A None group_index would collapse unrelated prompts into one - # denominator, silently degrading prompt_mean -> sample_mean for - # that sample. The prompt-grouping invariant is violated, so fail. 
if sample.group_index is None: raise ValueError( "--loss-aggregation prompt_mean requires every Sample.group_index to be set, " diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 66c8d6bbad..626f1910a2 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -1874,12 +1874,9 @@ def slime_validate_args(args): assert not (args.kl_coef != 0 and args.kl_loss_coef != 0), "Only one of kl_coef and kl_loss_coef can be set" - loss_aggregation = getattr(args, "loss_aggregation", "sample_mean") - # 1. Divisor only feeds the `constant` reducer: it is required+positive there and a - # silent no-op everywhere else, so fail loud rather than mislead about normalization. - divisor = getattr(args, "loss_aggregation_divisor", None) + loss_aggregation = args.loss_aggregation + divisor = args.loss_aggregation_divisor if loss_aggregation == "constant": - # Dr.GRPO needs a fixed, positive divisor; fail at startup, not mid-train. if divisor is None or not (divisor > 0): raise ValueError( "--loss-aggregation-divisor must be set to a positive value when " @@ -1890,20 +1887,12 @@ def slime_validate_args(args): "--loss-aggregation-divisor is only used with --loss-aggregation=constant " f"(got --loss-aggregation={loss_aggregation})." ) - # 2. --calculate-per-token-loss is the legacy spelling of --loss-aggregation=token_mean; - # reconcile the two so there is one coherent axis (no silent label override). if loss_aggregation == "token_mean": - # Forward: token_mean drives the per-token reducer/normalizer path. args.calculate_per_token_loss = True - elif getattr(args, "calculate_per_token_loss", False): + elif args.calculate_per_token_loss: if loss_aggregation == "sample_mean": - # Backward: the legacy flag *is* token_mean (sample_mean already returned the - # per-token default reducer here); relabel so the reported objective is honest. 
loss_aggregation = args.loss_aggregation = "token_mean" else: - # prompt_mean/constant are distinct objectives from the global per-token mean; - # --calculate-per-token-loss would silently override the reducer's denominator - # (prompt_mean's per-group total / constant's L). Fail loud instead. raise ValueError( f"--loss-aggregation={loss_aggregation} is incompatible with --calculate-per-token-loss " "(use --loss-aggregation=token_mean for the per-token mean)." @@ -2033,7 +2022,7 @@ def slime_validate_args(args): args.global_batch_size = global_batch_size if ( - getattr(args, "loss_aggregation", "sample_mean") == "prompt_mean" + args.loss_aggregation == "prompt_mean" and args.global_batch_size is not None and args.global_batch_size % args.n_samples_per_prompt != 0 ): diff --git a/tests/test_cp_utils.py b/tests/test_cp_utils.py index 8f900ccd4f..68a0df8a1c 100644 --- a/tests/test_cp_utils.py +++ b/tests/test_cp_utils.py @@ -15,6 +15,8 @@ from __future__ import annotations +from argparse import Namespace + # Import the helpers BEFORE the slime imports so the megatron stub lands # in sys.modules first. pytest's prepend importmode puts this file's # directory (``tests/``) on sys.path, which is what makes the bare-name @@ -27,6 +29,7 @@ get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean, ) +from slime.backends.megatron_utils.loss import get_pg_loss_reducer # noqa: E402 NUM_GPUS = 0 @@ -178,56 +181,61 @@ def test_cp_chunking_preserves_per_rollout_mean_report(monkeypatch): @pytest.mark.unit def test_constant_divisor_divides_masked_token_sum_by_L(): - """``constant`` (Dr.GRPO) aggregation: masked token sum / L, NOT any - data-dependent denominator.""" total_lengths, response_lengths, loss_masks = _make_inputs([3, 3]) L = 40.0 reducer = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks, constant_divisor=L) x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) - # sum of all masked tokens = 21; / 40 = 0.525. 
+ assert reducer(x).item() == pytest.approx(21.0 / L) @pytest.mark.unit def test_prompt_mean_denom_is_per_group_token_sum(): - """``prompt_mean`` (DAPO): two prompts × G rollouts. Each sample's - denominator is its WHOLE prompt-group's mask total (all rollouts of that - prompt), distinct from per-rollout (sample_mean) and per-token (token_mean). - - Fixture: prompt P0 has 2 rollouts of length 2 (group mask sum = 4); - prompt P1 has 2 rollouts of length 4 (group mask sum = 8). - """ - # 4 samples: [P0r0=2, P0r1=2, P1r0=4, P1r1=4]. total_lengths, response_lengths, loss_masks = _make_inputs([2, 2, 4, 4]) - # prompt_mask_sums: P0 group total = 2+2 = 4 (both P0 samples); P1 = 4+4 = 8. prompt_denoms = _denoms(4, 4, 8, 8) reducer = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks, prompt_denoms) - # x laid out per sample: P0r0=[1,1], P0r1=[1,1], P1r0=[1,1,1,1], P1r1=[1,1,1,1] x = torch.tensor([1.0] * 2 + [1.0] * 2 + [1.0] * 4 + [1.0] * 4) - # P0 group mean: (sum of P0 tokens)/4 = 4/4 = 1. P1 group mean: 8/8 = 1. Sum = 2. assert reducer(x).item() == pytest.approx(2.0) - # Now make the per-prompt content uneven so prompt_mean is numerically - # distinct from both sample_mean and token_mean. x2 = torch.tensor([2.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) - # prompt_mean: P0 sum = 2, /4 = 0.5; P1 sum = 8, /8 = 1.0 → 1.5. assert reducer(x2).item() == pytest.approx(1.5) - # sample_mean (per-rollout denoms = own length): P0r0 2/2=1, P0r1 0/2=0, - # P1r0 4/4=1, P1r1 4/4=1 → 3.0. Distinct from prompt_mean. + sample_mean = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks) assert sample_mean(x2).item() == pytest.approx(3.0) assert sample_mean(x2).item() != pytest.approx(reducer(x2).item()) - # token_mean numerator (raw masked sum) = 10; distinct again. 
+ token_sum = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks, calculate_per_token_loss=True) assert token_sum(x2).item() == pytest.approx(10.0) assert token_sum(x2).item() != pytest.approx(reducer(x2).item()) +@pytest.mark.unit +def test_prompt_mean_reducer_matches_final_prompt_average(): + total_lengths, response_lengths, loss_masks = _make_inputs([2, 2, 4, 4]) + batch = {"prompt_mask_sums": _denoms(4, 4, 8, 8)} + default_reducer = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks) + args = Namespace( + loss_aggregation="prompt_mean", + n_samples_per_prompt=2, + ) + reducer = get_pg_loss_reducer( + args, + batch, + total_lengths=total_lengths, + response_lengths=response_lengths, + pg_loss_masks=loss_masks, + default_reducer=default_reducer, + ) + + x = torch.tensor([2.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) + step_global_batch_size = 4 + + assert (reducer(x) / step_global_batch_size).item() == pytest.approx(0.75) + + @pytest.mark.unit def test_constant_and_per_token_loss_are_mutually_exclusive(): - """The constant divisor and per-token-loss are distinct aggregation modes; - asking for both is a configuration error, rejected eagerly.""" total_lengths, response_lengths, loss_masks = _make_inputs([3]) with pytest.raises(ValueError, match="mutually exclusive"): get_sum_of_sample_mean( @@ -241,8 +249,6 @@ def test_constant_and_per_token_loss_are_mutually_exclusive(): @pytest.mark.unit def test_cp_chunking_preserves_constant_divisor(monkeypatch): - """CP rank-sum invariance for the constant divisor: the divisor is identical - on every CP rank, so summing per-rank reducer outputs reproduces cp=1.""" from megatron.core import mpu as _mpu total_lengths = [12, 12] @@ -280,12 +286,8 @@ def test_cp_chunking_preserves_constant_divisor(monkeypatch): @pytest.mark.unit def test_sample_denoms_length_mismatch_fails_loud(): - """``sample_denoms`` is one denominator per sample by construction (built parallel - to 
``loss_masks``). The reducer zips them ``strict=True``, so a caller that supplies - a mismatched-length ``sample_denoms`` fails loud instead of silently dropping samples.""" total_lengths, response_lengths, loss_masks = _make_inputs([3, 3, 3]) x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]) - # 3 samples but only 2 denoms — a construction bug the strict zip must surface. reducer = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks, _denoms(9, 9)) with pytest.raises(ValueError, match="zip"): reducer(x) diff --git a/tests/test_megatron_argument_validation.py b/tests/test_megatron_argument_validation.py index c684ae05db..b25b23c27a 100644 --- a/tests/test_megatron_argument_validation.py +++ b/tests/test_megatron_argument_validation.py @@ -386,10 +386,8 @@ def test_loss_aggregation_token_mean_aliases_calculate_per_token_loss(monkeypatc @pytest.mark.unit def test_calculate_per_token_loss_alone_reconciles_to_token_mean(monkeypatch): - # Legacy spelling: --calculate-per-token-loss on the default sample_mean IS token_mean. - # Reconcile the label (no behavior change) so there is one honest axis. module = load_slime_arguments_module(monkeypatch) - args = make_slime_validate_args(calculate_per_token_loss=True) # loss_aggregation default sample_mean + args = make_slime_validate_args(calculate_per_token_loss=True) module.slime_validate_args(args) @@ -400,19 +398,16 @@ def test_calculate_per_token_loss_alone_reconciles_to_token_mean(monkeypatch): @pytest.mark.unit def test_loss_aggregation_default_leaves_per_token_loss_off(monkeypatch): module = load_slime_arguments_module(monkeypatch) - args = make_slime_validate_args() # default sample_mean + args = make_slime_validate_args() module.slime_validate_args(args) assert args.calculate_per_token_loss is False - # No divisor required for non-constant modes. 
assert args.loss_aggregation == "sample_mean" @pytest.mark.unit def test_loss_aggregation_prompt_mean_rejects_calculate_per_token_loss(monkeypatch): - # prompt_mean + --calculate-per-token-loss would silently degrade to the global - # per-token mean in the reducer, so validation must fail loud. module = load_slime_arguments_module(monkeypatch) args = make_slime_validate_args(loss_aggregation="prompt_mean", calculate_per_token_loss=True) @@ -433,8 +428,6 @@ def test_loss_aggregation_prompt_mean_without_per_token_loss_passes(monkeypatch) @pytest.mark.unit def test_loss_aggregation_constant_rejects_calculate_per_token_loss(monkeypatch): - # constant + --calculate-per-token-loss is silently inconsistent (the reducer makes them - # mutually exclusive); fail loud at startup instead. module = load_slime_arguments_module(monkeypatch) args = make_slime_validate_args( loss_aggregation="constant", loss_aggregation_divisor=40960.0, calculate_per_token_loss=True @@ -447,8 +440,6 @@ def test_loss_aggregation_constant_rejects_calculate_per_token_loss(monkeypatch) @pytest.mark.unit @pytest.mark.parametrize("mode", ["sample_mean", "token_mean", "prompt_mean"]) def test_loss_aggregation_divisor_rejected_on_non_constant_modes(monkeypatch, mode): - # A stray --loss-aggregation-divisor on a non-constant mode is silently ignored by the - # reducer; fail loud so the user is not misled about the normalization. module = load_slime_arguments_module(monkeypatch) args = make_slime_validate_args(loss_aggregation=mode, loss_aggregation_divisor=40960.0) @@ -458,10 +449,6 @@ def test_loss_aggregation_divisor_rejected_on_non_constant_modes(monkeypatch, mo @pytest.mark.unit def test_loss_aggregation_prompt_mean_rejects_non_multiple_global_batch_size(monkeypatch): - # prompt_mean normalizes per prompt group, but dp_schedule slices each step as a contiguous - # run of global_batch_size rollouts. 
When global_batch_size is not a multiple of - # n_samples_per_prompt a prompt group straddles a step boundary and its per-group - # normalization fragments across optimizer updates, so validation must fail loud. module = load_slime_arguments_module(monkeypatch) args = make_slime_validate_args(loss_aggregation="prompt_mean", n_samples_per_prompt=4, global_batch_size=6) @@ -483,8 +470,6 @@ def test_loss_aggregation_prompt_mean_accepts_multiple_global_batch_size(monkeyp @pytest.mark.unit @pytest.mark.parametrize("mode", ["sample_mean", "token_mean", "constant"]) def test_loss_aggregation_non_prompt_mean_allows_non_multiple_global_batch_size(monkeypatch, mode): - # The step-straddle guard is specific to prompt_mean's per-group denominator; the other - # modes do not normalize per prompt group, so the same non-multiple gbs must NOT trip it. module = load_slime_arguments_module(monkeypatch) divisor = 40960.0 if mode == "constant" else None args = make_slime_validate_args( diff --git a/tests/test_rollout_validation.py b/tests/test_rollout_validation.py index dc7626e9a9..b345d9f15d 100644 --- a/tests/test_rollout_validation.py +++ b/tests/test_rollout_validation.py @@ -61,28 +61,9 @@ def test_validate_server_group_gpu_indices_reports_config_context(): assert "rollout_num_gpus_per_engine=2" in message -# --------------------------------------------------------------------------- -# _convert_samples_to_train_data: prompt_mask_sums build (--loss-aggregation) -# --------------------------------------------------------------------------- -# -# prompt_mask_sums (the per-prompt-group denominator for prompt_mean) is built -# only under --loss-aggregation=prompt_mean, and that build fails loud if any -# sample is missing its prompt group (group_index is None) — a None would -# silently collapse unrelated prompts into one denominator, degrading -# prompt_mean to sample_mean for that sample. 
The other three modes never read -# group_index, so they must neither build the key nor consult group_index. - - def _make_convert_manager(loss_aggregation): - """A bare RolloutManager (no Ray/sglang init) wired just enough to call - ``_convert_samples_to_train_data``: no custom hooks, and reward - post-processing reduced to identity (advantage_estimator outside the - group-norm set), so the only behavior under test is the prompt_mask_sums - build + its group_index guard.""" from slime.ray.rollout import RolloutManager - # RolloutManager is @ray.remote-decorated (an ActorClass); unwrap to the plain - # class so a bare instance can be built to exercise the method directly. meta = getattr(RolloutManager, "__ray_metadata__", None) cls = meta.modified_class if meta is not None else RolloutManager manager = cls.__new__(cls) @@ -91,7 +72,7 @@ def _make_convert_manager(loss_aggregation): manager.args = types.SimpleNamespace( loss_aggregation=loss_aggregation, reward_key=None, - advantage_estimator="reinforce", # outside the group-norm reshape path + advantage_estimator="reinforce", rewards_normalization=False, grpo_std_normalization=False, ) @@ -99,8 +80,6 @@ def _make_convert_manager(loss_aggregation): def _make_grouped_samples(group_indices): - """One Sample per entry; each carries a length-2 loss_mask so the - per-group mask totals are non-trivial.""" from slime.utils.types import Sample samples = [] @@ -121,9 +100,6 @@ def _make_grouped_samples(group_indices): @pytest.mark.unit def test_prompt_mean_fails_loud_on_none_group_index(): - """prompt_mean with a None group_index is a real break (the prompt-grouping - invariant is violated), so the convert step must raise — not silently - renumber the sample into its own singleton group.""" pytest.importorskip("sglang") # RolloutManager import pulls sglang manager = _make_convert_manager("prompt_mean") samples = _make_grouped_samples([0, 0, None, 1]) @@ -134,24 +110,18 @@ def test_prompt_mean_fails_loud_on_none_group_index(): 
@pytest.mark.unit def test_prompt_mean_builds_per_group_mask_sums(): - """Sanity: with every group_index set, prompt_mask_sums is the per-group - mask total broadcast per sample (group 0 has two length-2 samples → 4).""" pytest.importorskip("sglang") manager = _make_convert_manager("prompt_mean") - samples = _make_grouped_samples([0, 0, 1]) # group 0: 2 samples, group 1: 1 + samples = _make_grouped_samples([0, 0, 1]) train_data = manager._convert_samples_to_train_data(samples) - # group 0 = 2+2 = 4 (both samples), group 1 = 2. assert train_data["prompt_mask_sums"] == [4, 4, 2] @pytest.mark.unit @pytest.mark.parametrize("mode", ["sample_mean", "constant", "token_mean"]) def test_non_prompt_mean_modes_ignore_none_group_index(mode): - """The other three modes never read group_index and never build - prompt_mask_sums, so a None group_index must NOT raise and the key must be - absent (keeping the default batch byte-identical — no extra key).""" pytest.importorskip("sglang") manager = _make_convert_manager(mode) samples = _make_grouped_samples([0, None, 1]) From 9ca9ceb8f8d591ab62a280648f2113ca2173cc4f Mon Sep 17 00:00:00 2001 From: EazyReal <8047065+EazyReal@users.noreply.github.com> Date: Thu, 25 Jun 2026 18:29:01 -0700 Subject: [PATCH 3/3] docs: add zh loss aggregation guide --- docs/en/get_started/customization.md | 27 +++++++------------ docs/zh/get_started/customization.md | 40 +++++++++++++++++++++++++--- 2 files changed, 47 insertions(+), 20 deletions(-) diff --git a/docs/en/get_started/customization.md b/docs/en/get_started/customization.md index 4717a66595..bed3b9e352 100644 --- a/docs/en/get_started/customization.md +++ b/docs/en/get_started/customization.md @@ -18,7 +18,7 @@ Below is a summary of all available customization interfaces and their purposes. | [`--rollout-data-postprocess-path`](#8-rollout-data-postprocess---rollout-data-postprocess-path) | Post-process rollout data after log probs are computed. 
| | [`--custom-loss-function-path`](#9-custom-loss-function---custom-loss-function-path) | Implement custom training loss computation. | | [`--custom-tis-function-path`](#10-custom-tisrs-function---custom-tis-function-path) | Implement custom importance sampling for off-policy correction. | -| [`--custom-pg-loss-reducer-function-path`](#11-custom-pg-loss-reducer---custom-pg-loss-reducer-function-path) | Customize pg_loss reduction (e.g., for Dr.GRPO). | +| [`--custom-pg-loss-reducer-function-path`](#11-custom-pg-loss-reducer---custom-pg-loss-reducer-function-path) | Customize pg_loss reduction not covered by `--loss-aggregation`. | | [`--custom-reward-post-process-path`](#12-reward-post-processing---custom-reward-post-process-path) | Custom post-processing of rewards before advantage computation. | | [`--custom-convert-samples-to-train-data-path`](#13-samples-to-train-data-conversion---custom-convert-samples-to-train-data-path) | Override the conversion of samples to training data format. | | [`--custom-rollout-log-function-path`](#14-logging-functions) | Custom logging for training rollouts. | @@ -297,10 +297,8 @@ def get_pg_loss_reducer( **Use Cases**: - Custom loss normalization strategies not covered by `--loss-aggregation` -> The four standard loss-aggregation modes (GRPO sample average, DAPO prompt -> average, token average, Dr.GRPO constant divisor) are available first-class via -> `--loss-aggregation` (see below) — no custom reducer needed. Reach for this -> hook only for a normalization those modes do not express. When set, it takes +> Use the custom reducer hook only for normalizations that the built-in +> `--loss-aggregation` modes do not express. When set, the custom reducer takes > precedence over `--loss-aggregation`. 
**Built-in modes — `--loss-aggregation`** @@ -312,14 +310,14 @@ Modes follow the ScaleRL taxonomy ([arXiv:2510.13786](https://arxiv.org/abs/2510 | Mode | Paper | pg_loss denominator | | :--- | :--- | :--- | -| `sample_mean` (default) | GRPO sample average | Per-rollout token-weighted mean — each rollout contributes equally regardless of fan-out. Byte-identical to slime's prior default. | +| `sample_mean` (default) | GRPO sample average | Per-rollout token-weighted mean; each rollout contributes equally regardless of fan-out. | | `prompt_mean` | DAPO prompt average | Per-prompt-group token-weighted mean (all rollouts sharing a prompt share one denominator). | -| `token_mean` | token average | Global per-token mean. The legacy `--calculate-per-token-loss` flag is exactly this mode (see below); prefer `--loss-aggregation=token_mean`. | +| `token_mean` | token average | Global per-token mean. `--calculate-per-token-loss` is accepted as an alias for this mode when no conflicting `--loss-aggregation` is set. | | `constant` | Dr.GRPO ([arXiv:2503.20783](https://arxiv.org/abs/2503.20783)) | `sum(token_loss * loss_mask) / L`, where `L = --loss-aggregation-divisor` (e.g. the max context length). | `--loss-aggregation-divisor L` is required (validated `> 0` at startup) only for `constant`; it is ignored for the other modes (passing it with any other mode -fails at startup). The default (`sample_mean`) leaves behavior unchanged. +fails at startup). If `--loss-aggregation` is not set, slime uses `sample_mean`. The `constant` denominator `L` is *per-token*, not per-step: each mode's reducer returns a step sum that is then averaged by the usual `/ step_global_batch_size` @@ -331,15 +329,10 @@ is applied on top exactly as for the other modes. `prompt_mean` weights every prompt group equally: each group's token-weighted mean contributes once, and the final scalar is the mean over prompt groups. 
-**Legacy `--calculate-per-token-loss`.** `--calculate-per-token-loss` is not a -separate axis: it *is* the `token_mean` point on `--loss-aggregation`. The two -spellings are reconciled at startup so the reported objective is honest and there -is one knob: `--loss-aggregation=token_mean` sets `--calculate-per-token-loss`, -and `--calculate-per-token-loss` (on the default `sample_mean`) is relabeled to -`token_mean` with no behavior change. Existing `--calculate-per-token-loss` -recipes keep working unchanged; prefer `--loss-aggregation=token_mean` in new -recipes. Combining `--calculate-per-token-loss` with `prompt_mean` or `constant` -(distinct objectives) fails loud at startup. +**`--calculate-per-token-loss`.** This flag is an alias for +`--loss-aggregation=token_mean`. It can be used by itself, or with +`--loss-aggregation=token_mean`. Combining it with `prompt_mean` or `constant` +fails at startup because those modes use different denominators. --- diff --git a/docs/zh/get_started/customization.md b/docs/zh/get_started/customization.md index fd067c04c9..4465f5c990 100644 --- a/docs/zh/get_started/customization.md +++ b/docs/zh/get_started/customization.md @@ -18,7 +18,7 @@ slime 通过函数路径参数提供了广泛的自定义能力。这些参数 | [`--rollout-data-postprocess-path`](#8-rollout-数据后处理---rollout-data-postprocess-path) | 在计算 log probabilities 后对 rollout 数据进行后处理。 | | [`--custom-loss-function-path`](#9-自定义损失函数---custom-loss-function-path) | 实现自定义训练损失计算。 | | [`--custom-tis-function-path`](#10-自定义-tisrs-函数---custom-tis-function-path) | 实现用于离策略(off-policy)校正的自定义重要性采样。 | -| [`--custom-pg-loss-reducer-function-path`](#11-自定义-pg-loss-reducer---custom-pg-loss-reducer-function-path) | 自定义 pg_loss 的归约方式(如 Dr.GRPO)。 | +| [`--custom-pg-loss-reducer-function-path`](#11-自定义-pg-loss-reducer---custom-pg-loss-reducer-function-path) | 自定义 `--loss-aggregation` 未覆盖的 pg_loss 归约方式。 | | [`--custom-reward-post-process-path`](#12-奖励后处理---custom-reward-post-process-path) | 在优势计算前对奖励进行自定义后处理。 | | 
[`--custom-convert-samples-to-train-data-path`](#13-样本转训练数据---custom-convert-samples-to-train-data-path) | 覆盖样本到训练数据格式的转换逻辑。 | | [`--custom-rollout-log-function-path`](#14-日志函数) | 训练 rollout 的自定义日志记录。 | @@ -295,8 +295,42 @@ def get_pg_loss_reducer( ``` **使用场景**: -- Dr.GRPO:除以常数而非有效 token 数 -- 自定义损失归一化策略 +- `--loss-aggregation` 未覆盖的自定义损失归一化策略 + +> 只有内置 `--loss-aggregation` 模式无法表达的归一化方式才需要使用自定义 +> reducer hook。设置后,自定义 reducer 会优先于 `--loss-aggregation` 生效。 + +**内置模式 — `--loss-aggregation`** + +`--loss-aggregation {sample_mean,prompt_mean,token_mean,constant}` 选择一个 +training step 内 pg_loss 的聚合方式(仅作用于 pg_loss;其他指标仍使用默认的 +sample-mean reducer,与上面的 custom hook 作用范围一致)。这些模式遵循 +ScaleRL taxonomy([arXiv:2510.13786](https://arxiv.org/abs/2510.13786) §3.2): + +| 模式 | 论文语义 | pg_loss 分母 | +| :--- | :--- | :--- | +| `sample_mean`(默认) | GRPO sample average | 每个 rollout 先做 token 加权均值,再对 rollout 求平均;每个 rollout 权重相同,与 fan-out 无关。 | +| `prompt_mean` | DAPO prompt average | 每个 prompt group 先做 token 加权均值;共享同一 prompt 的所有 rollout 使用同一个分母。 | +| `token_mean` | token average | 全局 per-token mean。当没有设置冲突的 `--loss-aggregation` 时,`--calculate-per-token-loss` 可作为该模式的 alias。 | +| `constant` | Dr.GRPO ([arXiv:2503.20783](https://arxiv.org/abs/2503.20783)) | `sum(token_loss * loss_mask) / L`,其中 `L = --loss-aggregation-divisor`(例如最大上下文长度)。 | + +`--loss-aggregation-divisor L` 只在 `constant` 模式下需要,并在启动时校验 +`> 0`;其他模式不使用它,若同时传入会在启动时报错。默认值 +为 `sample_mean`。 + +`constant` 的分母 `L` 是 per-token 的,不是 per-step 的:每种模式的 reducer +都会返回 step sum,然后再经过原有的 `/ step_global_batch_size` step average +(四种模式结构一致)。因此 `constant` 的实际 per-step 归一化是 +`/ (L * step_global_batch_size)`:`L` 只控制数据无关的 per-token scale, +`/ step_global_batch_size` 会像其他模式一样继续叠加。 + +`prompt_mean` 会让每个 prompt group 权重相同:每个 group 的 token 加权均值 +贡献一次,最终 scalar 是 prompt group 的均值。 + +**`--calculate-per-token-loss`。** 这个 flag 是 +`--loss-aggregation=token_mean` 的 alias。它可以单独使用,也可以与 +`--loss-aggregation=token_mean` 一起使用。将它与 `prompt_mean` 或 +`constant` 
组合会在启动时报错,因为这些模式使用不同的分母。 ---