feat(loss): support pg_loss aggregation modes #2
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
miles/ray/rollout/train_data_conversion.py (1)
152-178: 🎯 Functional Correctness | 🟠 Major | 🏗️ Heavy lift
Keep complete prompt groups on the same DP shard.
Copying `prompt_group_indices` and `prompt_mask_sums` is not enough here because the partitioning still stripes samples by index. With `n_samples_per_prompt=2`, a batch ordered like `[0,0,1,1]` becomes `[0,1]` on each shard for `dp_size=2`, and `get_pg_loss_reducer(..., prompt_mean)` will then raise for any modified `pg_loss_masks` because `_prompt_group_mask_sums(..., expected_group_size=2)` only sees partial groups locally. This needs group-aware partitioning (or an explicit invariant check) before slicing these fields.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@miles/ray/rollout/train_data_conversion.py` around lines 152 - 178, The current rollout slicing in train_data_conversion.py still partitions by sample index, so prompt groups can be split across DP shards even though prompt_group_indices and prompt_mask_sums are carried through. Update the partitioning logic in the conversion path that builds rollout_data so complete prompt groups stay on the same shard, or add an explicit invariant check before slicing these fields. Make sure the fix covers the loop over dp_size and the handling of prompt_group_indices, prompt_mask_sums, and related prompt data used by get_pg_loss_reducer / prompt_mean.
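The sharding problem described in this comment can be reproduced in a few lines. The sketch below is illustrative only: `stripe_by_index`, `split_by_group`, and `groups_on` are hypothetical helpers, not functions from `train_data_conversion.py`, and the striping rule (sample `i` goes to shard `i % dp_size`) is an assumption taken from the `[0,0,1,1]` example above.

```python
from collections import defaultdict


def stripe_by_index(num_samples, dp_size):
    """Index-striped split (assumed): sample i lands on shard i % dp_size."""
    return [list(range(rank, num_samples, dp_size)) for rank in range(dp_size)]


def split_by_group(group_ids, dp_size):
    """Group-aware split: whole prompt groups are dealt round-robin to shards."""
    members = defaultdict(list)
    for sample_idx, gid in enumerate(group_ids):
        members[gid].append(sample_idx)
    if len(members) % dp_size != 0:
        raise ValueError(
            f"{len(members)} prompt groups cannot be spread evenly over dp_size={dp_size}"
        )
    shards = [[] for _ in range(dp_size)]
    for slot, gid in enumerate(members):  # dicts keep first-seen order
        shards[slot % dp_size].extend(members[gid])
    return shards


def groups_on(shard, group_ids):
    """Prompt-group id of every sample index on one shard."""
    return [group_ids[i] for i in shard]


group_ids = [0, 0, 1, 1]  # n_samples_per_prompt=2, two prompts
striped = stripe_by_index(len(group_ids), dp_size=2)
grouped = split_by_group(group_ids, dp_size=2)
print([groups_on(s, group_ids) for s in striped])  # [[0, 1], [0, 1]]
print([groups_on(s, group_ids) for s in grouped])  # [[0, 0], [1, 1]]
```

With striping, every shard holds half of each group, which is the situation the reducer is said to reject. The group-aware variant also raises when the group count does not divide evenly, mirroring the "explicit invariant check" option.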
🧹 Nitpick comments (1)
tests/fast/backends/training_utils/test_loss_aggregation.py (1)
655-681: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Add a DP-splitting regression for `prompt_mean`.
The new suite checks conversion and reducer validation separately, but it never exercises `split_train_data_by_dp` with grouped samples. A small test around `[0,0,1,1]` + `dp_size=2` would catch the prompt-group sharding break before it reaches training.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/fast/backends/training_utils/test_loss_aggregation.py` around lines 655 - 681, Add a regression test for prompt_mean DP sharding in the training utils test suite. The current tests cover _convert and validation, but not split_train_data_by_dp with grouped samples; add a case using prompt_group_indices like [0, 0, 1, 1] and dp_size=2 to verify grouped samples are split correctly. Place the new test near the existing prompt_mean checks in test_loss_aggregation.py and use the existing helpers such as _make_sample, _convert, and split_train_data_by_dp.
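The shape of such a regression test might look like the sketch below. Because the signature of `split_train_data_by_dp` is not shown on this page, the test works on hand-written shard layouts instead of calling it; `groups_are_whole` is a hypothetical assertion helper, and a real test would feed it the shards produced by the project's own splitter and helpers.

```python
def groups_are_whole(shards, group_ids, expected_group_size):
    """True if every prompt group sits on exactly one shard with all its samples."""
    owner = {}
    for rank, shard in enumerate(shards):
        counts = {}
        for idx in shard:
            gid = group_ids[idx]
            counts[gid] = counts.get(gid, 0) + 1
        for gid, n in counts.items():
            # A group seen on an earlier shard, or seen only partially, is split.
            if gid in owner or n != expected_group_size:
                return False
            owner[gid] = rank
    return True


def test_prompt_groups_stay_whole():
    group_ids = [0, 0, 1, 1]
    striped = [[0, 2], [1, 3]]  # layout produced by index striping
    grouped = [[0, 1], [2, 3]]  # layout that keeps groups together
    assert not groups_are_whole(striped, group_ids, expected_group_size=2)
    assert groups_are_whole(grouped, group_ids, expected_group_size=2)


test_prompt_groups_stay_whole()
```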
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@docs/user-guide/customization.md`:
- Around line 222-223: Update the customization docs to say that
--loss-aggregation-divisor is only accepted with constant and is rejected by
_validate_loss_aggregation_args() for every other loss mode. Use the existing
loss aggregation option names and the _validate_loss_aggregation_args helper to
make the behavior explicit, so users are not told the divisor is ignored when it
actually fails startup validation.
In `@miles/backends/training_utils/cp_utils.py`:
- Around line 103-108: The reducer setup currently accepts constant_divisor
values that are zero or negative, which later leads to invalid division or
flipped loss scaling. Add an explicit validation in the helper that handles
constant_divisor and calculate_per_token_loss to reject any non-positive
constant_divisor before building the closure, and mirror the same guard in the
other reducer path referenced by the review.
In `@miles/backends/training_utils/log_utils.py`:
- Around line 415-430: The metric-shape validation in the loss reduction path is
only enforced with asserts, so malformed inputs can slip through when assertions
are disabled. In the log_utils reduction logic around the values/normalizers
handling, replace the asserts with explicit ValueError checks and make the zip
over keys, values, and normalizers use strict validation so mismatched lengths
cannot silently truncate metrics. Keep the fix localized to the reduction block
that builds loss_reduced.
In `@miles/backends/training_utils/loss.py`:
- Around line 188-197: The non-policy loss paths still use sample-mean reduction
because get_sum_of_sample_mean is hardcoded with calculate_per_token_loss
disabled. Update the reducer wiring in loss.py so value_loss_function and
sft_loss_function receive the calculate_per_token_loss setting from args, using
the same flag when calling get_sum_of_sample_mean and any downstream reducer
selection tied to _validate_loss_aggregation_contract.
In `@miles/utils/arguments.py`:
- Around line 2406-2407: The loss aggregation validation in the arguments flow
is running before `--custom-config-path` overrides are applied, so values
derived from batch-size fields can become stale. Update the arguments processing
around `_validate_loss_aggregation_args(args)` and the custom config overlay
logic so the YAML overrides are applied first, then rerun the
`num_steps_per_rollout` to `global_batch_size` derivation and
`_validate_loss_aggregation_args(args)` afterward. Use the existing arguments
handling block in `arguments.py` to ensure `loss_aggregation`,
`calculate_per_token_loss`, and `global_batch_size` are validated against the
final merged config.
---
Outside diff comments:
In `@miles/ray/rollout/train_data_conversion.py`:
- Around line 152-178: The current rollout slicing in train_data_conversion.py
still partitions by sample index, so prompt groups can be split across DP shards
even though prompt_group_indices and prompt_mask_sums are carried through.
Update the partitioning logic in the conversion path that builds rollout_data so
complete prompt groups stay on the same shard, or add an explicit invariant
check before slicing these fields. Make sure the fix covers the loop over
dp_size and the handling of prompt_group_indices, prompt_mask_sums, and related
prompt data used by get_pg_loss_reducer / prompt_mean.
---
Nitpick comments:
In `@tests/fast/backends/training_utils/test_loss_aggregation.py`:
- Around line 655-681: Add a regression test for prompt_mean DP sharding in the
training utils test suite. The current tests cover _convert and validation, but
not split_train_data_by_dp with grouped samples; add a case using
prompt_group_indices like [0, 0, 1, 1] and dp_size=2 to verify grouped samples
are split correctly. Place the new test near the existing prompt_mean checks in
test_loss_aggregation.py and use the existing helpers such as _make_sample,
_convert, and split_train_data_by_dp.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro Plus
Run ID: 5c784fb7-21c6-41da-a323-6abab962eacc
📒 Files selected for processing (13)
docs/user-guide/cli-reference.md
docs/user-guide/customization.md
miles/backends/experimental/fsdp_utils/actor.py
miles/backends/megatron_utils/model.py
miles/backends/training_utils/cp_utils.py
miles/backends/training_utils/data.py
miles/backends/training_utils/log_utils.py
miles/backends/training_utils/loss.py
miles/backends/training_utils/loss_hub/losses.py
miles/ray/rollout/train_data_conversion.py
miles/utils/arguments.py
tests/fast/backends/training_utils/loss/test_loss_snapshot.py
tests/fast/backends/training_utils/test_loss_aggregation.py
c288377 to 67e0503
67e0503 to fe7d2bf
Description
Fork review branch for the Miles `pg_loss` aggregation work before updating radixark#1498. Related implementations for comparison:

This adds `--loss-aggregation {sample_mean,prompt_mean,token_mean,constant}` while keeping `sample_mean` as the default. `--calculate-per-token-loss` remains the legacy spelling for `token_mean`, and custom pg-loss reducers still take precedence.

- `sample_mean`
- `prompt_mean`
- `token_mean`
- `constant` (`--loss-aggregation-divisor`)

The correctness contract is explicit in code:

- `token_mean` policy loss rejects nonzero entropy or KL-loss coefficients that would mix token-normalized `pg_loss` with sample-normalized auxiliary loss terms.
- `pg_loss` uses the modified mask for loss reduction while mismatch metrics stay on the original mask.
- `prompt_mean` DP splitting keeps prompt groups whole on each DP shard and rejects train steps whose prompt-group count cannot be distributed evenly.
- … `global_batch_size` before checking aggregation constraints.

Validation

- `uv run --with pytest --with torch --with numpy --with httpx --with pyyaml --with ray --with huggingface_hub --with transformers --with pydantic --with psutil pytest --confcutdir=tests/fast/backends/training_utils tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py -q`
- `uv run --with ruff ruff check miles/backends/training_utils/cp_utils.py miles/backends/training_utils/data.py miles/backends/training_utils/log_utils.py miles/backends/training_utils/loss.py miles/backends/training_utils/loss_hub/losses.py miles/ray/rollout/train_data_conversion.py miles/utils/arguments.py miles/backends/megatron_utils/model.py miles/backends/experimental/fsdp_utils/actor.py tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py`
- `uv run --with black black --check miles/backends/training_utils/cp_utils.py miles/backends/training_utils/data.py miles/backends/training_utils/log_utils.py miles/backends/training_utils/loss.py miles/backends/training_utils/loss_hub/losses.py miles/ray/rollout/train_data_conversion.py miles/utils/arguments.py miles/backends/megatron_utils/model.py miles/backends/experimental/fsdp_utils/actor.py tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py`
- `git diff --check`
- `python3 train.py --help` is not runnable in this local environment because `sglang` is not installed; the parser path for the new flags is covered by the focused tests.
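
For readers comparing the four `--loss-aggregation` modes, the sketch below shows one plausible reading of each on plain Python lists. The normalizations are assumptions inferred from the mode names and from the review's mention of prompt-group mask sums; they are not copied from `loss.py` or `cp_utils.py`, and `aggregate` is a hypothetical function.

```python
def aggregate(per_token_loss, masks, mode, group_ids=None, divisor=None):
    """Reduce per-sample token losses to a scalar under one aggregation mode."""
    # Masked loss sum and unmasked-token count for every sample.
    sums = [sum(l * m for l, m in zip(ls, ms)) for ls, ms in zip(per_token_loss, masks)]
    counts = [sum(ms) for ms in masks]

    if mode == "sample_mean":
        # Token mean inside each sample, then mean over samples.
        return sum(s / max(c, 1) for s, c in zip(sums, counts)) / len(sums)
    if mode == "token_mean":
        # One global mean over every unmasked token.
        return sum(sums) / max(sum(counts), 1)
    if mode == "prompt_mean":
        # Assumed: token mean inside each prompt group, then mean over groups.
        group_sum, group_count = {}, {}
        for s, c, g in zip(sums, counts, group_ids):
            group_sum[g] = group_sum.get(g, 0.0) + s
            group_count[g] = group_count.get(g, 0) + c
        return sum(group_sum[g] / max(group_count[g], 1) for g in group_sum) / len(group_sum)
    if mode == "constant":
        if divisor is None or divisor <= 0:
            raise ValueError("constant mode needs a positive divisor")
        return sum(sums) / divisor
    raise ValueError(f"unknown mode: {mode}")


losses = [[1.0, 1.0, 1.0, 1.0], [4.0, 0.0]]
masks = [[1, 1, 1, 1], [1, 0]]
print(aggregate(losses, masks, "sample_mean"))                    # 2.5
print(aggregate(losses, masks, "token_mean"))                     # 1.6
print(aggregate(losses, masks, "prompt_mean", group_ids=[0, 1]))  # 2.5
print(aggregate(losses, masks, "constant", divisor=8))            # 1.0
```

The short second sample dominates under `sample_mean` but is diluted under `token_mean`, which is why mixing token-normalized `pg_loss` with sample-normalized auxiliary terms would change their relative weight.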