Skip to content

feat(loss): support pg_loss aggregation modes#2

Open
EazyReal wants to merge 2 commits into
mainfrom
codex/loss-aggregation-reduction-contract
Open

feat(loss): support pg_loss aggregation modes#2
EazyReal wants to merge 2 commits into
mainfrom
codex/loss-aggregation-reduction-contract

Conversation

@EazyReal

@EazyReal EazyReal commented Jun 27, 2026

Copy link
Copy Markdown
Owner

Description

Fork review branch for the Miles pg_loss aggregation work before updating radixark#1498. Related implementations for comparison:

This adds --loss-aggregation {sample_mean,prompt_mean,token_mean,constant} while keeping sample_mean as the default. --calculate-per-token-loss remains the legacy spelling for token_mean, and custom pg-loss reducers still take precedence.

Mode Denominator
sample_mean per-sample active-token count
prompt_mean per-prompt-group active-token count, then mean over prompt groups
token_mean global active-token count
constant fixed --loss-aggregation-divisor

The correctness contract is explicit in code:

  • token_mean policy loss rejects nonzero entropy or KL-loss coefficients that would mix token-normalized pg_loss with sample-normalized auxiliary loss terms.
  • Train logging carries per-key normalizers when needed, and aggregation rejects mixed legacy/normalized log dictionaries, key-order mismatches, and malformed value/normalizer lengths.
  • SFT/value losses use the token reducer when the legacy per-token path is active.
  • TIS/RS pg_loss uses the modified mask for loss reduction while mismatch metrics stay on the original mask.
  • prompt_mean DP splitting keeps prompt groups whole on each DP shard and rejects train steps whose prompt-group count cannot be distributed evenly.
  • Custom-config overrides rederive and revalidate global_batch_size before checking aggregation constraints.

Validation

  • uv run --with pytest --with torch --with numpy --with httpx --with pyyaml --with ray --with huggingface_hub --with transformers --with pydantic --with psutil pytest --confcutdir=tests/fast/backends/training_utils tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py -q
  • uv run --with ruff ruff check miles/backends/training_utils/cp_utils.py miles/backends/training_utils/data.py miles/backends/training_utils/log_utils.py miles/backends/training_utils/loss.py miles/backends/training_utils/loss_hub/losses.py miles/ray/rollout/train_data_conversion.py miles/utils/arguments.py miles/backends/megatron_utils/model.py miles/backends/experimental/fsdp_utils/actor.py tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py
  • uv run --with black black --check miles/backends/training_utils/cp_utils.py miles/backends/training_utils/data.py miles/backends/training_utils/log_utils.py miles/backends/training_utils/loss.py miles/backends/training_utils/loss_hub/losses.py miles/ray/rollout/train_data_conversion.py miles/utils/arguments.py miles/backends/megatron_utils/model.py miles/backends/experimental/fsdp_utils/actor.py tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py
  • git diff --check

python3 train.py --help is not runnable in this local environment because sglang is not installed; the parser path for the new flags is covered by the focused tests.

@coderabbitai

coderabbitai Bot commented Jun 27, 2026

Copy link
Copy Markdown

Review Change Stack

Warning

Review limit reached

@EazyReal, we couldn't start this review because you've reached your PR review rate limit.

More reviews will be available in 10 minutes and 44 seconds. Learn how PR review limits work.

Your organization has used up its prepaid credits, and credit purchases are no longer available. Enable the review add-on in the billing tab to keep reviews running — you're only billed for reviews past your plan's rate limits ($0.25/file).

⌛ How to resolve this issue?

After more reviews become available, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

To avoid repeated limits, reduce automatic review volume by pausing incremental auto-reviews earlier, using label-based review opt-in, excluding WIP or generated PR titles, or requesting reviews manually when the PR is ready. If your team needs uninterrupted high-volume reviews, an organization admin can enable usage-based credits.

🚦 How do rate limits work?

CodeRabbit enforces per-developer PR review limits for each organization. Most developers receive the normal plan review availability.

For paid Pro and Pro+ PR reviews, CodeRabbit uses adaptive limits for sustained high-volume activity. When a developer's recent PR review activity reaches the 95th percentile or higher among CodeRabbit users, additional reviews become available more gradually as earlier reviews age out of the rolling window.

Please see our Fair Usage Limits Policy for further information.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro Plus

Run ID: 79129ced-47f4-4004-9c5b-abdfe2ce0c39

📥 Commits

Reviewing files that changed from the base of the PR and between 0bce1a8 and fe7d2bf.

📒 Files selected for processing (8)
  • docs/user-guide/customization.md
  • miles/backends/training_utils/cp_utils.py
  • miles/backends/training_utils/log_utils.py
  • miles/backends/training_utils/loss.py
  • miles/backends/training_utils/loss_hub/losses.py
  • miles/ray/rollout/train_data_conversion.py
  • miles/utils/arguments.py
  • tests/fast/backends/training_utils/test_loss_aggregation.py
📝 Walkthrough

Walkthrough

Adds a --loss-aggregation CLI flag with modes sample_mean, prompt_mean, token_mean, and constant, superseding the legacy --calculate-per-token-loss flag. New prompt_group_indices and prompt_mask_sums fields are plumbed through the data pipeline. Loss reducers, log aggregation, and argument validation are updated accordingly, with full documentation and a new test module.

Changes

Loss Aggregation Modes

Layer / File(s) Summary
CLI arguments and validation
miles/utils/arguments.py
Adds --loss-aggregation (modes: sample_mean, prompt_mean, token_mean, constant) and --loss-aggregation-divisor; _validate_loss_aggregation_args reconciles them with legacy --calculate-per-token-loss and wires into miles_validate_args.
Data pipeline for prompt_group_indices and prompt_mask_sums
miles/ray/rollout/train_data_conversion.py, miles/backends/training_utils/data.py, miles/backends/experimental/fsdp_utils/actor.py, miles/backends/megatron_utils/model.py
convert_samples_to_train_data computes prompt-group fields for prompt_mean; split_train_data_by_dp slices them; get_rollout_data converts prompt_mask_sums to CUDA float32; FSDP and Megatron actors fetch both fields from get_batch.
Loss reducers: cp_utils and get_pg_loss_reducer
miles/backends/training_utils/cp_utils.py, miles/backends/training_utils/loss_hub/losses.py
get_sum_of_sample_mean gains sample_denoms and constant_divisor parameters; _prompt_group_mask_sums and get_pg_loss_reducer dispatch pg-loss aggregation by mode; policy_loss_function uses get_pg_loss_reducer replacing the direct sum_of_sample_mean default.
loss.py helpers and log_utils normalizer aggregation
miles/backends/training_utils/loss.py, miles/backends/training_utils/log_utils.py
Adds _build_train_log_dict and _validate_loss_aggregation_contract in loss.py; loss_function uses both. aggregate_train_losses accumulates and all-reduces optional per-metric normalizers for the new token-mean path.
CLI reference and customization docs
docs/user-guide/cli-reference.md, docs/user-guide/customization.md
Updates CLI reference table and adds comprehensive --loss-aggregation section in customization.md covering mode semantics, legacy flag reconciliation, and prompt_mean output schema requirements.
Tests
tests/fast/backends/training_utils/loss/test_loss_snapshot.py, tests/fast/backends/training_utils/test_loss_aggregation.py
New test_loss_aggregation.py covers get_sum_of_sample_mean modes, get_pg_loss_reducer dispatch and prompt-mean edge cases, log aggregation normalizer contracts, CLI validation, and convert_samples_to_train_data prompt-mean output; existing snapshot test updated to keyword args.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐇 Hopping through tokens, one by one,
Now prompt_mean groups them, all under the sun.
constant_divisor set, the math is clear,
Legacy flags reconciled — no more fear!
The rabbit aggregates loss with delight~ 🎉

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 15.85% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly and accurately summarizes the main change: adding support for pg_loss aggregation modes.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch codex/loss-aggregation-reduction-contract

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
miles/ray/rollout/train_data_conversion.py (1)

152-178: 🎯 Functional Correctness | 🟠 Major | 🏗️ Heavy lift

Keep complete prompt groups on the same DP shard.

Copying prompt_group_indices and prompt_mask_sums is not enough here because the partitioning still stripes samples by index. With n_samples_per_prompt=2, a batch ordered like [0,0,1,1] becomes [0,1] on each shard for dp_size=2, and get_pg_loss_reducer(..., prompt_mean) will then raise for any modified pg_loss_masks because _prompt_group_mask_sums(..., expected_group_size=2) only sees partial groups locally. This needs group-aware partitioning (or an explicit invariant check) before slicing these fields.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@miles/ray/rollout/train_data_conversion.py` around lines 152 - 178, The
current rollout slicing in train_data_conversion.py still partitions by sample
index, so prompt groups can be split across DP shards even though
prompt_group_indices and prompt_mask_sums are carried through. Update the
partitioning logic in the conversion path that builds rollout_data so complete
prompt groups stay on the same shard, or add an explicit invariant check before
slicing these fields. Make sure the fix covers the loop over dp_size and the
handling of prompt_group_indices, prompt_mask_sums, and related prompt data used
by get_pg_loss_reducer / prompt_mean.
🧹 Nitpick comments (1)
tests/fast/backends/training_utils/test_loss_aggregation.py (1)

655-681: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Add a DP-splitting regression for prompt_mean.

The new suite checks conversion and reducer validation separately, but it never exercises split_train_data_by_dp with grouped samples. A small test around [0,0,1,1] + dp_size=2 would catch the prompt-group sharding break before it reaches training.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/fast/backends/training_utils/test_loss_aggregation.py` around lines 655
- 681, Add a regression test for prompt_mean DP sharding in the training utils
test suite. The current tests cover _convert and validation, but not
split_train_data_by_dp with grouped samples; add a case using
prompt_group_indices like [0, 0, 1, 1] and dp_size=2 to verify grouped samples
are split correctly. Place the new test near the existing prompt_mean checks in
test_loss_aggregation.py and use the existing helpers such as _make_sample,
_convert, and split_train_data_by_dp.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@docs/user-guide/customization.md`:
- Around line 222-223: Update the customization docs to say that
--loss-aggregation-divisor is only accepted with constant and is rejected by
_validate_loss_aggregation_args() for every other loss mode. Use the existing
loss aggregation option names and the _validate_loss_aggregation_args helper to
make the behavior explicit, so users are not told the divisor is ignored when it
actually fails startup validation.

In `@miles/backends/training_utils/cp_utils.py`:
- Around line 103-108: The reducer setup currently accepts constant_divisor
values that are zero or negative, which later leads to invalid division or
flipped loss scaling. Add an explicit validation in the helper that handles
constant_divisor and calculate_per_token_loss to reject any non-positive
constant_divisor before building the closure, and mirror the same guard in the
other reducer path referenced by the review.

In `@miles/backends/training_utils/log_utils.py`:
- Around line 415-430: The metric-shape validation in the loss reduction path is
only enforced with asserts, so malformed inputs can slip through when assertions
are disabled. In the log_utils reduction logic around the values/normalizers
handling, replace the asserts with explicit ValueError checks and make the zip
over keys, values, and normalizers use strict validation so mismatched lengths
cannot silently truncate metrics. Keep the fix localized to the reduction block
that builds loss_reduced.

In `@miles/backends/training_utils/loss.py`:
- Around line 188-197: The non-policy loss paths still use sample-mean reduction
because get_sum_of_sample_mean is hardcoded with calculate_per_token_loss
disabled. Update the reducer wiring in loss.py so value_loss_function and
sft_loss_function receive the calculate_per_token_loss setting from args, using
the same flag when calling get_sum_of_sample_mean and any downstream reducer
selection tied to _validate_loss_aggregation_contract.

In `@miles/utils/arguments.py`:
- Around line 2406-2407: The loss aggregation validation in the arguments flow
is running before `--custom-config-path` overrides are applied, so values
derived from batch-size fields can become stale. Update the arguments processing
around `_validate_loss_aggregation_args(args)` and the custom config overlay
logic so the YAML overrides are applied first, then rerun the
`num_steps_per_rollout` to `global_batch_size` derivation and
`_validate_loss_aggregation_args(args)` afterward. Use the existing arguments
handling block in `arguments.py` to ensure `loss_aggregation`,
`calculate_per_token_loss`, and `global_batch_size` are validated against the
final merged config.

---

Outside diff comments:
In `@miles/ray/rollout/train_data_conversion.py`:
- Around line 152-178: The current rollout slicing in train_data_conversion.py
still partitions by sample index, so prompt groups can be split across DP shards
even though prompt_group_indices and prompt_mask_sums are carried through.
Update the partitioning logic in the conversion path that builds rollout_data so
complete prompt groups stay on the same shard, or add an explicit invariant
check before slicing these fields. Make sure the fix covers the loop over
dp_size and the handling of prompt_group_indices, prompt_mask_sums, and related
prompt data used by get_pg_loss_reducer / prompt_mean.

---

Nitpick comments:
In `@tests/fast/backends/training_utils/test_loss_aggregation.py`:
- Around line 655-681: Add a regression test for prompt_mean DP sharding in the
training utils test suite. The current tests cover _convert and validation, but
not split_train_data_by_dp with grouped samples; add a case using
prompt_group_indices like [0, 0, 1, 1] and dp_size=2 to verify grouped samples
are split correctly. Place the new test near the existing prompt_mean checks in
test_loss_aggregation.py and use the existing helpers such as _make_sample,
_convert, and split_train_data_by_dp.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro Plus

Run ID: 5c784fb7-21c6-41da-a323-6abab962eacc

📥 Commits

Reviewing files that changed from the base of the PR and between 39a1580 and 0bce1a8.

📒 Files selected for processing (13)
  • docs/user-guide/cli-reference.md
  • docs/user-guide/customization.md
  • miles/backends/experimental/fsdp_utils/actor.py
  • miles/backends/megatron_utils/model.py
  • miles/backends/training_utils/cp_utils.py
  • miles/backends/training_utils/data.py
  • miles/backends/training_utils/log_utils.py
  • miles/backends/training_utils/loss.py
  • miles/backends/training_utils/loss_hub/losses.py
  • miles/ray/rollout/train_data_conversion.py
  • miles/utils/arguments.py
  • tests/fast/backends/training_utils/loss/test_loss_snapshot.py
  • tests/fast/backends/training_utils/test_loss_aggregation.py

Comment thread docs/user-guide/customization.md Outdated
Comment thread miles/backends/training_utils/cp_utils.py
Comment thread miles/backends/training_utils/log_utils.py
Comment thread miles/backends/training_utils/loss.py
Comment thread miles/utils/arguments.py
@EazyReal EazyReal force-pushed the codex/loss-aggregation-reduction-contract branch 2 times, most recently from c288377 to 67e0503 Compare June 27, 2026 22:41
@EazyReal EazyReal force-pushed the codex/loss-aggregation-reduction-contract branch from 67e0503 to fe7d2bf Compare June 27, 2026 22:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant