diff --git a/docs/en/get_started/customization.md b/docs/en/get_started/customization.md index 3f31ee493d..bed3b9e352 100644 --- a/docs/en/get_started/customization.md +++ b/docs/en/get_started/customization.md @@ -18,7 +18,7 @@ Below is a summary of all available customization interfaces and their purposes. | [`--rollout-data-postprocess-path`](#8-rollout-data-postprocess---rollout-data-postprocess-path) | Post-process rollout data after log probs are computed. | | [`--custom-loss-function-path`](#9-custom-loss-function---custom-loss-function-path) | Implement custom training loss computation. | | [`--custom-tis-function-path`](#10-custom-tisrs-function---custom-tis-function-path) | Implement custom importance sampling for off-policy correction. | -| [`--custom-pg-loss-reducer-function-path`](#11-custom-pg-loss-reducer---custom-pg-loss-reducer-function-path) | Customize pg_loss reduction (e.g., for Dr.GRPO). | +| [`--custom-pg-loss-reducer-function-path`](#11-custom-pg-loss-reducer---custom-pg-loss-reducer-function-path) | Customize pg_loss reduction not covered by `--loss-aggregation`. | | [`--custom-reward-post-process-path`](#12-reward-post-processing---custom-reward-post-process-path) | Custom post-processing of rewards before advantage computation. | | [`--custom-convert-samples-to-train-data-path`](#13-samples-to-train-data-conversion---custom-convert-samples-to-train-data-path) | Override the conversion of samples to training data format. | | [`--custom-rollout-log-function-path`](#14-logging-functions) | Custom logging for training rollouts. | @@ -295,8 +295,44 @@ def get_pg_loss_reducer( ``` **Use Cases**: -- Dr.GRPO: Divide by a constant instead of effective token count -- Custom loss normalization strategies +- Custom loss normalization strategies not covered by `--loss-aggregation` + +> Use the custom reducer hook only for normalizations that the built-in +> `--loss-aggregation` modes do not express. When set, the custom reducer takes +> precedence over `--loss-aggregation`. + +**Built-in modes — `--loss-aggregation`** + +`--loss-aggregation {sample_mean,prompt_mean,token_mean,constant}` selects how +pg_loss is aggregated across a training step (pg_loss only; every other metric +keeps the default sample-mean reducer — same scope as the custom hook above). +Modes follow the ScaleRL taxonomy ([arXiv:2510.13786](https://arxiv.org/abs/2510.13786) §3.2): + +| Mode | Paper | pg_loss denominator | +| :--- | :--- | :--- | +| `sample_mean` (default) | GRPO sample average | Per-rollout token-weighted mean; each rollout contributes equally regardless of fan-out. | +| `prompt_mean` | DAPO prompt average | Per-prompt-group token-weighted mean (all rollouts sharing a prompt share one denominator). | +| `token_mean` | token average | Global per-token mean. `--calculate-per-token-loss` is accepted as an alias for this mode when no conflicting `--loss-aggregation` is set. | +| `constant` | Dr.GRPO ([arXiv:2503.20783](https://arxiv.org/abs/2503.20783)) | `sum(token_loss * loss_mask) / L`, where `L = --loss-aggregation-divisor` (e.g. the max context length). | + +`--loss-aggregation-divisor L` is required (validated `> 0` at startup) only for +`constant`; it is ignored for the other modes (passing it with any other mode +fails at startup). If `--loss-aggregation` is not set, slime uses `sample_mean`. + +The `constant` denominator `L` is *per-token*, not per-step: each mode's reducer +returns a step sum that is then averaged by the usual `/ step_global_batch_size` +(identical structure across all four modes). So the effective per-step +normalization for `constant` is `/ (L * step_global_batch_size)` — `L` only sets +the data-independent per-token scale; the `/ step_global_batch_size` step average +is applied on top exactly as for the other modes. + +`prompt_mean` weights every prompt group equally: each group's token-weighted +mean contributes once, and the final scalar is the mean over prompt groups. + +**`--calculate-per-token-loss`.** This flag is an alias for +`--loss-aggregation=token_mean`. It can be used by itself, or with +`--loss-aggregation=token_mean`. Combining it with `prompt_mean` or `constant` +fails at startup because those modes use different denominators. --- diff --git a/docs/zh/get_started/customization.md b/docs/zh/get_started/customization.md index fd067c04c9..4465f5c990 100644 --- a/docs/zh/get_started/customization.md +++ b/docs/zh/get_started/customization.md @@ -18,7 +18,7 @@ slime 通过函数路径参数提供了广泛的自定义能力。这些参数 | [`--rollout-data-postprocess-path`](#8-rollout-数据后处理---rollout-data-postprocess-path) | 在计算 log probabilities 后对 rollout 数据进行后处理。 | | [`--custom-loss-function-path`](#9-自定义损失函数---custom-loss-function-path) | 实现自定义训练损失计算。 | | [`--custom-tis-function-path`](#10-自定义-tisrs-函数---custom-tis-function-path) | 实现用于离策略(off-policy)校正的自定义重要性采样。 | -| [`--custom-pg-loss-reducer-function-path`](#11-自定义-pg-loss-reducer---custom-pg-loss-reducer-function-path) | 自定义 pg_loss 的归约方式(如 Dr.GRPO)。 | +| [`--custom-pg-loss-reducer-function-path`](#11-自定义-pg-loss-reducer---custom-pg-loss-reducer-function-path) | 自定义 `--loss-aggregation` 未覆盖的 pg_loss 归约方式。 | | [`--custom-reward-post-process-path`](#12-奖励后处理---custom-reward-post-process-path) | 在优势计算前对奖励进行自定义后处理。 | | [`--custom-convert-samples-to-train-data-path`](#13-样本转训练数据---custom-convert-samples-to-train-data-path) | 覆盖样本到训练数据格式的转换逻辑。 | | [`--custom-rollout-log-function-path`](#14-日志函数) | 训练 rollout 的自定义日志记录。 | @@ -295,8 +295,42 @@ def get_pg_loss_reducer( ``` **使用场景**: -- Dr.GRPO:除以常数而非有效 token 数 -- 自定义损失归一化策略 +- `--loss-aggregation` 未覆盖的自定义损失归一化策略 + +> 只有内置 `--loss-aggregation` 模式无法表达的归一化方式才需要使用自定义 +> reducer hook。设置后,自定义 reducer 会优先于 `--loss-aggregation` 生效。 + +**内置模式 — `--loss-aggregation`** + +`--loss-aggregation {sample_mean,prompt_mean,token_mean,constant}` 选择一个 +training step 内 pg_loss 的聚合方式(仅作用于 pg_loss;其他指标仍使用默认的 +sample-mean reducer,与上面的 custom hook 作用范围一致)。这些模式遵循 +ScaleRL taxonomy([arXiv:2510.13786](https://arxiv.org/abs/2510.13786) §3.2): + +| 模式 | 论文语义 | pg_loss 分母 | +| :--- | :--- | :--- | +| `sample_mean`(默认) | GRPO sample average | 每个 rollout 先做 token 加权均值,再对 rollout 求平均;每个 rollout 权重相同,与 fan-out 无关。 | +| `prompt_mean` | DAPO prompt average | 每个 prompt group 先做 token 加权均值;共享同一 prompt 的所有 rollout 使用同一个分母。 | +| `token_mean` | token average | 全局 per-token mean。当没有设置冲突的 `--loss-aggregation` 时,`--calculate-per-token-loss` 可作为该模式的 alias。 | +| `constant` | Dr.GRPO ([arXiv:2503.20783](https://arxiv.org/abs/2503.20783)) | `sum(token_loss * loss_mask) / L`,其中 `L = --loss-aggregation-divisor`(例如最大上下文长度)。 | + +`--loss-aggregation-divisor L` 只在 `constant` 模式下需要,并在启动时校验 +`> 0`;其他模式不使用它,若同时传入会在启动时报错。默认值 +为 `sample_mean`。 + +`constant` 的分母 `L` 是 per-token 的,不是 per-step 的:每种模式的 reducer +都会返回 step sum,然后再经过原有的 `/ step_global_batch_size` step average +(四种模式结构一致)。因此 `constant` 的实际 per-step 归一化是 +`/ (L * step_global_batch_size)`:`L` 只控制数据无关的 per-token scale, +`/ step_global_batch_size` 会像其他模式一样继续叠加。 + +`prompt_mean` 会让每个 prompt group 权重相同:每个 group 的 token 加权均值 +贡献一次,最终 scalar 是 prompt group 的均值。 + +**`--calculate-per-token-loss`。** 这个 flag 是 +`--loss-aggregation=token_mean` 的 alias。它可以单独使用,也可以与 +`--loss-aggregation=token_mean` 一起使用。将它与 `prompt_mean` 或 +`constant` 组合会在启动时报错,因为这些模式使用不同的分母。 --- diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index 6bdc2dd3fe..6615dead88 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -235,12 +235,11 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch: rollout_data["loss_masks"] = [ t.to(device=device, dtype=torch.int, non_blocking=True) for t in rollout_data["loss_masks"] ] - if "rollout_mask_sums" in rollout_data: - # Promote precomputed per-rollout mask totals to GPU tensors here - # (matching loss_masks) so the loss reducer can just divide. - rollout_data["rollout_mask_sums"] = rollout_data["rollout_mask_sums"].to( - device=device, dtype=torch.float32, non_blocking=True - ) + for mask_sums_key in ("rollout_mask_sums", "prompt_mask_sums"): + if mask_sums_key in rollout_data: + rollout_data[mask_sums_key] = rollout_data[mask_sums_key].to( + device=device, dtype=torch.float32, non_blocking=True + ) if "multimodal_train_inputs" in rollout_data: # Move multimodal training tensors to GPU in advance rollout_data["multimodal_train_inputs"] = [ diff --git a/slime/backends/megatron_utils/cp_utils.py b/slime/backends/megatron_utils/cp_utils.py index a97c45cc42..aa58a1a756 100644 --- a/slime/backends/megatron_utils/cp_utils.py +++ b/slime/backends/megatron_utils/cp_utils.py @@ -50,6 +50,7 @@ def get_sum_of_sample_mean( loss_masks: list[torch.Tensor], sample_denoms: list[torch.Tensor] | torch.Tensor | None = None, calculate_per_token_loss: bool = False, + constant_divisor: float | None = None, ) -> Callable[[torch.Tensor], torch.Tensor]: """ Calculate correct sample mean for CP. @@ -63,7 +64,17 @@ def get_sum_of_sample_mean( step level rather than per-mb is required — otherwise a rollout whose samples land in different micro-batches would get a partial denominator on each side. + + ``constant_divisor`` (Dr.GRPO, arXiv:2503.20783) divides the masked + token-sum by a fixed ``L``; being identical on every CP rank, Megatron's + gradient sum-allreduce already yields the full-batch value (no extra + all-reduce). Mutually exclusive with ``calculate_per_token_loss``. """ + if constant_divisor is not None and calculate_per_token_loss: + raise ValueError( + "constant_divisor (loss-aggregation=constant) and calculate_per_token_loss " + "(loss-aggregation=token_mean) are mutually exclusive aggregation modes." + ) if sample_denoms is None: sample_denoms = [m.sum() for m in loss_masks] @@ -75,7 +86,7 @@ def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor: [ (x_i * loss_mask_i).sum() / torch.clamp_min(denom, 1) for x_i, loss_mask_i, denom in zip( - x.split(response_lengths, dim=0), loss_masks, sample_denoms, strict=False + x.split(response_lengths, dim=0), loss_masks, sample_denoms, strict=True ) ] ) @@ -106,7 +117,7 @@ def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor: [ (x_i * chunked_loss_mask).sum() / torch.clamp_min(denom, 1) for x_i, chunked_loss_mask, denom in zip( - x.split(cp_chunk_lengths, dim=0), chunked_loss_masks, sample_denoms, strict=False + x.split(cp_chunk_lengths, dim=0), chunked_loss_masks, sample_denoms, strict=True ) ] ) @@ -121,6 +132,13 @@ def sum_of_token(x: torch.Tensor) -> torch.Tensor: ] ) + if constant_divisor is not None: + + def sum_of_constant(x: torch.Tensor) -> torch.Tensor: + return sum_of_token(x) / constant_divisor + + return sum_of_constant + return sum_of_sample_mean if not calculate_per_token_loss else sum_of_token diff --git a/slime/backends/megatron_utils/data.py b/slime/backends/megatron_utils/data.py index 00f319928f..a17c4dd20f 100644 --- a/slime/backends/megatron_utils/data.py +++ b/slime/backends/megatron_utils/data.py @@ -286,6 +286,7 @@ def log_rollout_data( "rollout_mask_sums", "rollout_top_p_token_ids", "rollout_top_p_token_offsets", + "prompt_mask_sums", "rollout_routed_experts", "global_batch_sizes", "num_microbatches", diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 72afdfa66c..88d760b7a7 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -874,6 +874,51 @@ def icepop_function( return pg_loss, loss_masks, metrics +def get_pg_loss_reducer( + args: Namespace, + batch: RolloutBatch, + *, + total_lengths: list[int], + response_lengths: list[int], + pg_loss_masks: list[torch.Tensor], + default_reducer: Callable[[torch.Tensor], torch.Tensor], +) -> Callable[[torch.Tensor], torch.Tensor]: + """Return the pg_loss reducer selected by ``--loss-aggregation``.""" + mode = args.loss_aggregation + if mode in ("sample_mean", "token_mean"): + return default_reducer + if mode == "prompt_mean": + prompt_mask_sums = batch.get("prompt_mask_sums") + if prompt_mask_sums is None: + raise ValueError( + "--loss-aggregation=prompt_mean requires per-prompt-group mask sums " + "(batch['prompt_mask_sums']), but they are missing. A custom " + "--custom-convert-samples-to-train-data-path must populate " + "'prompt_mask_sums' (grouped by Sample.group_index)." + ) + reducer = get_sum_of_sample_mean( + total_lengths, + response_lengths, + pg_loss_masks, + prompt_mask_sums, + ) + + # The train step divides non-token losses by response count; scale the + # prompt-group sum so the final scalar is the exact mean over prompts. + def prompt_mean_reducer(x: torch.Tensor) -> torch.Tensor: + return reducer(x) * args.n_samples_per_prompt + + return prompt_mean_reducer + if mode == "constant": + return get_sum_of_sample_mean( + total_lengths, + response_lengths, + pg_loss_masks, + constant_divisor=args.loss_aggregation_divisor, + ) + raise ValueError(f"Unknown --loss-aggregation mode: {mode!r}") + + def policy_loss_function( args: Namespace, batch: RolloutBatch, @@ -1024,16 +1069,22 @@ def policy_loss_function( args.calculate_per_token_loss, ) - # Determine pg_loss reducer: use custom if specified, otherwise default - if getattr(args, "custom_pg_loss_reducer_function_path", None) is not None: + pg_loss_masks = modified_response_masks if (args.get_mismatch_metrics or args.use_tis) else batch["loss_masks"] + + if args.custom_pg_loss_reducer_function_path is not None: custom_pg_loss_reducer_func = load_function(args.custom_pg_loss_reducer_function_path) - # Determine which loss_masks to use for pg_loss reducer - pg_loss_masks = modified_response_masks if (args.get_mismatch_metrics or args.use_tis) else batch["loss_masks"] pg_loss_reducer = custom_pg_loss_reducer_func( total_lengths, response_lengths, pg_loss_masks, args.calculate_per_token_loss ) else: - pg_loss_reducer = sum_of_sample_mean + pg_loss_reducer = get_pg_loss_reducer( + args, + batch, + total_lengths=total_lengths, + response_lengths=response_lengths, + pg_loss_masks=pg_loss_masks, + default_reducer=sum_of_sample_mean, + ) pg_loss = pg_loss_reducer(pg_loss) pg_clipfrac = sum_of_sample_mean(pg_clipfrac) diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py index 1ad6cd7957..5c972aaf53 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -593,6 +593,7 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p "rollout_log_probs", "teacher_log_probs", "rollout_mask_sums", + "prompt_mask_sums", ], ), args.data_pad_size_multiplier, diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index ff571101af..9de8cade61 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -95,11 +95,12 @@ def _tensorize_rollout_data_for_training(rollout_data: dict[str, Any]) -> None: for mm_dict in rollout_data["multimodal_train_inputs"] ] - if "rollout_mask_sums" in rollout_data: - rollout_data["rollout_mask_sums"] = _cpu_tensor( - rollout_data["rollout_mask_sums"], - dtype=torch.float32, - ) + for mask_sums_key in ("rollout_mask_sums", "prompt_mask_sums"): + if mask_sums_key in rollout_data: + rollout_data[mask_sums_key] = _cpu_tensor( + rollout_data[mask_sums_key], + dtype=torch.float32, + ) @dataclasses.dataclass @@ -777,6 +778,19 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl rollout_total_mask[rid] = rollout_total_mask.get(rid, 0) + ms train_data["rollout_mask_sums"] = [rollout_total_mask[rid] for rid in rollout_id_list] + if self.args.loss_aggregation == "prompt_mean": + group_total_mask: dict[int, int] = {} + for sample, ms in zip(samples, mask_sums_per_sample, strict=True): + if sample.group_index is None: + raise ValueError( + "--loss-aggregation prompt_mean requires every Sample.group_index to be set, " + "but a sample has group_index=None. prompt_mean divides each sample by its " + "prompt group's total mask; a None group_index means the sample belongs to no " + "prompt group, so its denominator is undefined." + ) + group_total_mask[sample.group_index] = group_total_mask.get(sample.group_index, 0) + ms + train_data["prompt_mask_sums"] = [group_total_mask[sample.group_index] for sample in samples] + # Overwrite raw_reward when available. Mixed-source batches may only # populate this field for a subset of samples (e.g. SWE but not code). if any(sample.metadata and "raw_reward" in sample.metadata for sample in samples): @@ -866,6 +880,7 @@ def _split_train_data_by_dp(self, data): "sample_indices", "rollout_ids", "rollout_mask_sums", + "prompt_mask_sums", "rollout_log_probs", "rollout_top_p_token_ids", "rollout_top_p_token_offsets", diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index d5cac9d44b..626f1910a2 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -28,6 +28,8 @@ def reset_arg(parser, name, **kwargs): if name in action.option_strings: if "default" in kwargs: action.default = kwargs["default"] + if "help" in kwargs: + action.help = kwargs["help"] break else: parser.add_argument(name, **kwargs) @@ -843,7 +845,16 @@ def add_algo_arguments(parser): ) reset_arg(parser, "--seed", type=int, default=1234) reset_arg(parser, "--clip-grad", type=float, default=1.0) - reset_arg(parser, "--calculate-per-token-loss", action="store_true") + reset_arg( + parser, + "--calculate-per-token-loss", + action="store_true", + help=( + "Legacy alias for --loss-aggregation=token_mean (the global per-token mean); " + "reconciled onto that mode at startup. Prefer --loss-aggregation=token_mean. " + "Incompatible with --loss-aggregation prompt_mean/constant." + ), + ) reset_arg(parser, "--lr", type=float, default=1e-6) parser.add_argument( @@ -1054,6 +1065,44 @@ def add_algo_arguments(parser): default=None, help="Path to a custom reducer function for pg_loss only. When set, pg_loss will use this custom reducer while other metrics (pg_clipfrac, ppo_kl, entropy_loss, etc.) still use the default sum_of_sample_mean.", ) + parser.add_argument( + "--loss-aggregation", + type=str, + default="sample_mean", + choices=["sample_mean", "prompt_mean", "token_mean", "constant"], + help=( + "How pg_loss is aggregated across the step (applies to pg_loss only; " + "pg_clipfrac, ppo_kl, entropy_loss, kl_loss keep the default sample-mean " + "reducer — same scope as --custom-pg-loss-reducer-function-path, which still " + "takes precedence when set). Modes follow the ScaleRL taxonomy " + "(arXiv:2510.13786 §3.2): " + "'sample_mean' (default; GRPO sample average) — each rollout's tokens are " + "averaged with the per-rollout token-weighted denominator, so every rollout " + "contributes equally regardless of fan-out (byte-identical to slime's prior " + "default); " + "'prompt_mean' (DAPO prompt average) — tokens are averaged over each prompt " + "group (all rollouts sharing a Sample.group_index share one denominator); " + "'token_mean' (token average) — global per-token mean; the legacy " + "--calculate-per-token-loss flag is exactly this mode (and is reconciled onto " + "it at startup); " + "'constant' (Dr.GRPO, arXiv:2503.20783) — masked token sum divided by a fixed " + "--loss-aggregation-divisor (e.g. the max context length)." + ), + ) + parser.add_argument( + "--loss-aggregation-divisor", + type=float, + default=None, + help=( + "Constant divisor L for --loss-aggregation=constant (Dr.GRPO). pg_loss is " + "aggregated as sum(token_loss * loss_mask) / L instead of any data-dependent " + "denominator. L is per-token: the usual / step_global_batch_size step average " + "is then applied on top (same structure as every mode), so the effective " + "per-step normalization is / (L * step_global_batch_size); L only sets the " + "data-independent per-token scale. Required and validated > 0 at startup only " + "when --loss-aggregation=constant; setting it with any other mode fails loud." + ), + ) parser.add_argument( "--use-routing-replay", @@ -1825,6 +1874,30 @@ def slime_validate_args(args): assert not (args.kl_coef != 0 and args.kl_loss_coef != 0), "Only one of kl_coef and kl_loss_coef can be set" + loss_aggregation = args.loss_aggregation + divisor = args.loss_aggregation_divisor + if loss_aggregation == "constant": + if divisor is None or not (divisor > 0): + raise ValueError( + "--loss-aggregation-divisor must be set to a positive value when " + f"--loss-aggregation=constant (got {divisor!r})." + ) + elif divisor is not None: + raise ValueError( + "--loss-aggregation-divisor is only used with --loss-aggregation=constant " + f"(got --loss-aggregation={loss_aggregation})." + ) + if loss_aggregation == "token_mean": + args.calculate_per_token_loss = True + elif args.calculate_per_token_loss: + if loss_aggregation == "sample_mean": + loss_aggregation = args.loss_aggregation = "token_mean" + else: + raise ValueError( + f"--loss-aggregation={loss_aggregation} is incompatible with --calculate-per-token-loss " + "(use --loss-aggregation=token_mean for the per-token mean)." + ) + if args.advantage_estimator in ["reinforce_plus_plus", "reinforce_plus_plus_baseline"]: assert args.normalize_advantages, ( "The 'reinforce_plus_plus' and 'reinforce_plus_plus_baseline' advantage estimators " @@ -1948,6 +2021,17 @@ def slime_validate_args(args): ) args.global_batch_size = global_batch_size + if ( + args.loss_aggregation == "prompt_mean" + and args.global_batch_size is not None + and args.global_batch_size % args.n_samples_per_prompt != 0 + ): + raise ValueError( + "--loss-aggregation prompt_mean requires global_batch_size to be a multiple of " + "n_samples_per_prompt so each prompt group stays within one training step " + f"(got global_batch_size={args.global_batch_size}, n_samples_per_prompt={args.n_samples_per_prompt})." + ) + if args.n_samples_per_prompt == 1: args.grpo_std_normalization = False logger.info("n_samples_per_prompt is set to 1, grpo_std_normalization will be set to False.") diff --git a/tests/test_cp_utils.py b/tests/test_cp_utils.py index c7d3abe9a2..68a0df8a1c 100644 --- a/tests/test_cp_utils.py +++ b/tests/test_cp_utils.py @@ -15,6 +15,8 @@ from __future__ import annotations +from argparse import Namespace + # Import the helpers BEFORE the slime imports so the megatron stub lands # in sys.modules first. pytest's prepend importmode puts this file's # directory (``tests/``) on sys.path, which is what makes the bare-name @@ -27,6 +29,7 @@ get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean, ) +from slime.backends.megatron_utils.loss import get_pg_loss_reducer # noqa: E402 NUM_GPUS = 0 @@ -176,5 +179,119 @@ def test_cp_chunking_preserves_per_rollout_mean_report(monkeypatch): assert cp_total == pytest.approx(baseline) +@pytest.mark.unit +def test_constant_divisor_divides_masked_token_sum_by_L(): + total_lengths, response_lengths, loss_masks = _make_inputs([3, 3]) + L = 40.0 + reducer = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks, constant_divisor=L) + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + + assert reducer(x).item() == pytest.approx(21.0 / L) + + +@pytest.mark.unit +def test_prompt_mean_denom_is_per_group_token_sum(): + total_lengths, response_lengths, loss_masks = _make_inputs([2, 2, 4, 4]) + prompt_denoms = _denoms(4, 4, 8, 8) + reducer = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks, prompt_denoms) + + x = torch.tensor([1.0] * 2 + [1.0] * 2 + [1.0] * 4 + [1.0] * 4) + assert reducer(x).item() == pytest.approx(2.0) + + x2 = torch.tensor([2.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) + assert reducer(x2).item() == pytest.approx(1.5) + + sample_mean = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks) + assert sample_mean(x2).item() == pytest.approx(3.0) + assert sample_mean(x2).item() != pytest.approx(reducer(x2).item()) + + token_sum = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks, calculate_per_token_loss=True) + assert token_sum(x2).item() == pytest.approx(10.0) + assert token_sum(x2).item() != pytest.approx(reducer(x2).item()) + + +@pytest.mark.unit +def test_prompt_mean_reducer_matches_final_prompt_average(): + total_lengths, response_lengths, loss_masks = _make_inputs([2, 2, 4, 4]) + batch = {"prompt_mask_sums": _denoms(4, 4, 8, 8)} + default_reducer = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks) + args = Namespace( + loss_aggregation="prompt_mean", + n_samples_per_prompt=2, + ) + reducer = get_pg_loss_reducer( + args, + batch, + total_lengths=total_lengths, + response_lengths=response_lengths, + pg_loss_masks=loss_masks, + default_reducer=default_reducer, + ) + + x = torch.tensor([2.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) + step_global_batch_size = 4 + + assert (reducer(x) / step_global_batch_size).item() == pytest.approx(0.75) + + +@pytest.mark.unit +def test_constant_and_per_token_loss_are_mutually_exclusive(): + total_lengths, response_lengths, loss_masks = _make_inputs([3]) + with pytest.raises(ValueError, match="mutually exclusive"): + get_sum_of_sample_mean( + total_lengths, + response_lengths, + loss_masks, + calculate_per_token_loss=True, + constant_divisor=40.0, + ) + + +@pytest.mark.unit +def test_cp_chunking_preserves_constant_divisor(monkeypatch): + from megatron.core import mpu as _mpu + + total_lengths = [12, 12] + response_lengths = [8, 8] + loss_masks = [torch.ones(r, dtype=torch.float32) for r in response_lengths] + L = 40.0 + x_full = [ + torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), + torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0]), + ] + x_concat = torch.cat(x_full) + + monkeypatch.setattr(_mpu, "get_context_parallel_world_size", lambda: 1) + monkeypatch.setattr(_mpu, "get_context_parallel_rank", lambda: 0) + reducer_cp1 = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks, constant_divisor=L) + baseline = reducer_cp1(x_concat).item() + + monkeypatch.setattr(_mpu, "get_context_parallel_world_size", lambda: 2) + cp_total = 0.0 + for cp_rank in range(2): + monkeypatch.setattr(_mpu, "get_context_parallel_rank", lambda r=cp_rank: r) + x_chunks_per_sample = [] + for tl, rl, x in zip(total_lengths, response_lengths, x_full, strict=True): + prompt_length = tl - rl + _, _, _, tokens_offset = get_logits_and_tokens_offset_with_cp(tl, rl) + chunk_0 = x[tokens_offset[0][0] - prompt_length : tokens_offset[0][1] - prompt_length] + chunk_1 = x[tokens_offset[1][0] - prompt_length : tokens_offset[1][1] - prompt_length] + x_chunks_per_sample.append(torch.cat([chunk_0, chunk_1])) + x_for_rank = torch.cat(x_chunks_per_sample) + reducer_cp2 = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks, constant_divisor=L) + cp_total += reducer_cp2(x_for_rank).item() + + assert cp_total == pytest.approx(baseline) + + +@pytest.mark.unit +def test_sample_denoms_length_mismatch_fails_loud(): + total_lengths, response_lengths, loss_masks = _make_inputs([3, 3, 3]) + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]) + reducer = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks, _denoms(9, 9)) + with pytest.raises(ValueError, match="zip"): + reducer(x) + + if __name__ == "__main__": raise SystemExit(pytest.main([__file__])) diff --git a/tests/test_megatron_argument_validation.py b/tests/test_megatron_argument_validation.py index db5c0bfd81..b25b23c27a 100644 --- a/tests/test_megatron_argument_validation.py +++ b/tests/test_megatron_argument_validation.py @@ -303,6 +303,9 @@ def make_slime_validate_args(**overrides): update_weight_disk_dir=None, update_weight_delta_dir=None, update_weight_mode="full", + loss_aggregation="sample_mean", + loss_aggregation_divisor=None, + calculate_per_token_loss=False, ) values.update(overrides) return types.SimpleNamespace(**values) @@ -351,6 +354,136 @@ def test_slime_validate_args_preserves_zero_rollout_gpus_without_colocate(monkey assert args.offload_rollout is False +@pytest.mark.unit +@pytest.mark.parametrize("divisor", [None, 0.0, -1.0, float("nan")]) +def test_loss_aggregation_constant_rejects_nonpositive_divisor(monkeypatch, divisor): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(loss_aggregation="constant", loss_aggregation_divisor=divisor) + + with pytest.raises(ValueError, match="loss-aggregation-divisor"): + module.slime_validate_args(args) + + +@pytest.mark.unit +def test_loss_aggregation_constant_accepts_positive_divisor(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(loss_aggregation="constant", loss_aggregation_divisor=40960.0) + + module.slime_validate_args(args) + + assert args.loss_aggregation_divisor == 40960.0 + + +@pytest.mark.unit +def test_loss_aggregation_token_mean_aliases_calculate_per_token_loss(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(loss_aggregation="token_mean", calculate_per_token_loss=False) + + module.slime_validate_args(args) + + assert args.calculate_per_token_loss is True + + +@pytest.mark.unit +def test_calculate_per_token_loss_alone_reconciles_to_token_mean(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(calculate_per_token_loss=True) + + module.slime_validate_args(args) + + assert args.loss_aggregation == "token_mean" + assert args.calculate_per_token_loss is True + + +@pytest.mark.unit +def test_loss_aggregation_default_leaves_per_token_loss_off(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args() + + module.slime_validate_args(args) + + assert args.calculate_per_token_loss is False + assert args.loss_aggregation == "sample_mean" + + +@pytest.mark.unit +def test_loss_aggregation_prompt_mean_rejects_calculate_per_token_loss(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(loss_aggregation="prompt_mean", calculate_per_token_loss=True) + + with pytest.raises(ValueError, match="prompt_mean is incompatible with --calculate-per-token-loss"): + module.slime_validate_args(args) + + +@pytest.mark.unit +def test_loss_aggregation_prompt_mean_without_per_token_loss_passes(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(loss_aggregation="prompt_mean", calculate_per_token_loss=False) + + module.slime_validate_args(args) + + assert args.loss_aggregation == "prompt_mean" + assert args.calculate_per_token_loss is False + + +@pytest.mark.unit +def test_loss_aggregation_constant_rejects_calculate_per_token_loss(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args( + loss_aggregation="constant", loss_aggregation_divisor=40960.0, calculate_per_token_loss=True + ) + + with pytest.raises(ValueError, match="constant is incompatible with --calculate-per-token-loss"): + module.slime_validate_args(args) + + +@pytest.mark.unit +@pytest.mark.parametrize("mode", ["sample_mean", "token_mean", "prompt_mean"]) +def test_loss_aggregation_divisor_rejected_on_non_constant_modes(monkeypatch, mode): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(loss_aggregation=mode, loss_aggregation_divisor=40960.0) + + with pytest.raises(ValueError, match="loss-aggregation-divisor is only used with --loss-aggregation=constant"): + module.slime_validate_args(args) + + +@pytest.mark.unit +def test_loss_aggregation_prompt_mean_rejects_non_multiple_global_batch_size(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(loss_aggregation="prompt_mean", n_samples_per_prompt=4, global_batch_size=6) + + with pytest.raises(ValueError, match="prompt_mean requires global_batch_size to be a multiple"): + module.slime_validate_args(args) + + +@pytest.mark.unit +def test_loss_aggregation_prompt_mean_accepts_multiple_global_batch_size(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(loss_aggregation="prompt_mean", n_samples_per_prompt=4, global_batch_size=8) + + module.slime_validate_args(args) + + assert args.loss_aggregation == "prompt_mean" + assert args.global_batch_size == 8 + + +@pytest.mark.unit +@pytest.mark.parametrize("mode", ["sample_mean", "token_mean", "constant"]) +def test_loss_aggregation_non_prompt_mean_allows_non_multiple_global_batch_size(monkeypatch, mode): + module = load_slime_arguments_module(monkeypatch) + divisor = 40960.0 if mode == "constant" else None + args = make_slime_validate_args( + loss_aggregation=mode, + loss_aggregation_divisor=divisor, + n_samples_per_prompt=4, + global_batch_size=6, + ) + + module.slime_validate_args(args) + + assert args.global_batch_size == 6 + + @pytest.mark.unit def test_update_weight_delta_rejects_colocate(monkeypatch): module = load_slime_arguments_module(monkeypatch) diff --git a/tests/test_rollout_validation.py b/tests/test_rollout_validation.py index fa2ff34335..b345d9f15d 100644 --- a/tests/test_rollout_validation.py +++ b/tests/test_rollout_validation.py @@ -1,3 +1,5 @@ +import types + import pytest from slime.ray.rollout_validation import validate_server_group_gpu_indices @@ -59,5 +61,75 @@ def test_validate_server_group_gpu_indices_reports_config_context(): assert "rollout_num_gpus_per_engine=2" in message +def _make_convert_manager(loss_aggregation): + from slime.ray.rollout import RolloutManager + + meta = getattr(RolloutManager, "__ray_metadata__", None) + cls = meta.modified_class if meta is not None else RolloutManager + manager = cls.__new__(cls) + manager.custom_convert_samples_to_train_data_func = None + manager.custom_reward_post_process_func = None + manager.args = types.SimpleNamespace( + loss_aggregation=loss_aggregation, + reward_key=None, + advantage_estimator="reinforce", + rewards_normalization=False, + grpo_std_normalization=False, + ) + return manager + + +def _make_grouped_samples(group_indices): + from slime.utils.types import Sample + + samples = [] + for i, gid in enumerate(group_indices): + samples.append( + Sample( + index=i, + group_index=gid, + rollout_id=i, + tokens=[0, 1, 2, 3], + response_length=2, + reward=0.0, + loss_mask=[1, 1], + ) + ) + return samples + + +@pytest.mark.unit +def test_prompt_mean_fails_loud_on_none_group_index(): + pytest.importorskip("sglang") # RolloutManager import pulls sglang + manager = _make_convert_manager("prompt_mean") + samples = _make_grouped_samples([0, 0, None, 1]) + + with pytest.raises(ValueError, match="group_index"): + manager._convert_samples_to_train_data(samples) + + +@pytest.mark.unit +def test_prompt_mean_builds_per_group_mask_sums(): + pytest.importorskip("sglang") + manager = _make_convert_manager("prompt_mean") + samples = _make_grouped_samples([0, 0, 1]) + + train_data = manager._convert_samples_to_train_data(samples) + + assert train_data["prompt_mask_sums"] == [4, 4, 2] + + +@pytest.mark.unit +@pytest.mark.parametrize("mode", ["sample_mean", "constant", "token_mean"]) +def test_non_prompt_mean_modes_ignore_none_group_index(mode): + pytest.importorskip("sglang") + manager = _make_convert_manager(mode) + samples = _make_grouped_samples([0, None, 1]) + + train_data = manager._convert_samples_to_train_data(samples) + + assert "prompt_mask_sums" not in train_data + + if __name__ == "__main__": raise SystemExit(pytest.main([__file__]))