feat(loss): add pg_loss aggregation modes #2090

Open
EazyReal wants to merge 2 commits into THUDM:main from EazyReal:upstream-pr/loss-aggregation-modes
EazyReal (Contributor) commented Jun 16, 2026

On main, slime has two built-in pg_loss normalizations: the default sample/rollout mean and the legacy --calculate-per-token-loss path. Prompt-group normalization and fixed-divisor normalization both require a custom pg_loss reducer today. That is fragile because the reducer has to compose correctly with CP slicing, micro-batch packing, and Megatron's final train-step divisor; a wrong constant factor silently changes the effective learning rate.

This PR moves the common pg_loss aggregation choices behind --loss-aggregation. For one train step, let N be the number of rollouts, G = n_samples_per_prompt, P = N / G, M_i be the valid-token count for rollout i, and L = --loss-aggregation-divisor.

| mode | step scalar | behavior |
| --- | --- | --- |
| sample_mean | `(1 / N) * sum_i token_mean(i)` | Keeps the current default behavior. |
| token_mean | `sum_i,t loss_it * mask_it / sum_i M_i` | Uses the existing --calculate-per-token-loss path under the new unified knob. |
| prompt_mean | `(1 / P) * sum_g token_mean(prompt_group_g)` | Groups rollouts by Sample.group_index and gives each prompt group one unit of weight. |
| constant | `sum_i,t loss_it * mask_it / (L * N)` | Uses a fixed token-scale divisor before the usual step average. |
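To make the four step scalars concrete, here is a minimal pure-Python sketch that evaluates each formula on a toy batch. It is not the slime implementation: the function names `token_mean` and `aggregate` are hypothetical, there is no CP slicing or micro-batching, and it assumes token_mean(prompt_group_g) pools all valid tokens of the group before dividing.

```python
from collections import defaultdict


def token_mean(losses, mask):
    """Masked mean over the valid tokens of one rollout or one pooled group."""
    return sum(l * m for l, m in zip(losses, mask)) / sum(mask)


def aggregate(losses, masks, group_index, mode, divisor=None):
    """Reference step scalar for one train step (hypothetical helper, not slime's API)."""
    n = len(losses)  # N: number of rollouts in the step
    masked_sum = sum(l * m for ls, ms in zip(losses, masks) for l, m in zip(ls, ms))
    if mode == "sample_mean":
        return sum(token_mean(ls, ms) for ls, ms in zip(losses, masks)) / n
    if mode == "token_mean":
        return masked_sum / sum(sum(ms) for ms in masks)
    if mode == "prompt_mean":
        # Assumption: each prompt group pools its tokens, then groups are averaged.
        groups = defaultdict(lambda: ([], []))
        for ls, ms, g in zip(losses, masks, group_index):
            groups[g][0].extend(ls)
            groups[g][1].extend(ms)
        return sum(token_mean(ls, ms) for ls, ms in groups.values()) / len(groups)
    if mode == "constant":
        return masked_sum / (divisor * n)  # divisor plays the role of L
    raise ValueError(f"unknown mode: {mode}")


# Toy batch: N = 4 rollouts, G = 2 samples per prompt, so P = 2 prompt groups.
losses = [[1.0, 3.0], [5.0, 5.0, 5.0, 7.0], [2.0], [4.0, 6.0, 8.0]]
masks = [[1, 1], [1, 1, 1, 0], [1], [1, 1, 1]]  # last token of rollout 1 is masked out
group_index = [0, 0, 1, 1]

results = {
    mode: aggregate(losses, masks, group_index, mode, divisor=4.0)
    for mode in ("sample_mean", "token_mean", "prompt_mean", "constant")
}
```

On this batch the four modes give 3.75, 39/9, 4.4, and 2.4375 respectively, which shows how much the choice of denominator alone can move the step scalar.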

The implementation reuses slime's existing reducer path instead of adding a parallel loss stack. Rollout builds prompt_mask_sums only for prompt_mean; cp_utils.get_sum_of_sample_mean still owns CP-aware summation; and the train step keeps its existing final scaling. For prompt_mean, the reducer scales the prompt-group sum by n_samples_per_prompt, so users get the requested prompt mean directly instead of a constant-offset objective that would need learning-rate compensation.
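The effect of the n_samples_per_prompt scaling can be checked with a few lines of arithmetic. The sketch below assumes the train step's final scaling divides the reducer output by the rollout count N, as the sample_mean formula suggests; `step_scalar` and its parameters are hypothetical names, not slime functions.

```python
def step_scalar(group_token_means, n_rollouts, n_samples_per_prompt, rescale=True):
    """Reducer output followed by the assumed final division by N."""
    reducer_sum = sum(group_token_means)
    if rescale:
        # Multiply by G so that dividing by N = P * G leaves a plain mean over P groups.
        reducer_sum *= n_samples_per_prompt
    return reducer_sum / n_rollouts


group_means = [3.8, 5.0]  # token_mean of each of P = 2 prompt groups
with_scaling = step_scalar(group_means, n_rollouts=4, n_samples_per_prompt=2)
without_scaling = step_scalar(group_means, n_rollouts=4, n_samples_per_prompt=2, rescale=False)
```

With the scaling, the result equals the prompt mean (3.8 + 5.0) / 2 = 4.4. Without it, the result is 2.2, which is the prompt mean shrunk by a factor of G = 2, the kind of constant offset that would otherwise have to be absorbed into the learning rate.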

Startup validation rejects configurations that would silently use the wrong denominator, such as --loss-aggregation-divisor outside constant or --calculate-per-token-loss together with prompt_mean/constant. The custom pg_loss reducer hook remains available and keeps precedence for non-standard objectives.
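The two rejected combinations named above could be checked along these lines. This is a sketch only: `validate_loss_aggregation` and its keyword names are hypothetical, and the real checks live in slime's argument handling, which may cover more cases and use different error messages.

```python
def validate_loss_aggregation(mode, divisor=None, calculate_per_token_loss=False):
    """Raise ValueError for flag combinations that would use the wrong denominator."""
    if divisor is not None and mode != "constant":
        raise ValueError(
            "--loss-aggregation-divisor is only meaningful with --loss-aggregation constant"
        )
    if calculate_per_token_loss and mode in ("prompt_mean", "constant"):
        raise ValueError(
            f"--calculate-per-token-loss cannot be combined with --loss-aggregation {mode}"
        )


# Accepted: the default mode, and constant with an explicit divisor.
validate_loss_aggregation("sample_mean")
validate_loss_aggregation("constant", divisor=4096.0)
```

Failing at startup is the point of the design: a mismatched denominator would not crash training, it would only shift the loss scale.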

The focused tests cover the final prompt_mean scalar, CP/sample-denominator invariance, argument validation, and rollout metadata construction.

Tested with:

  • uv run --no-project --with pytest --with torch --with numpy --with httpx --with pyyaml python -m pytest -q tests/test_cp_utils.py tests/test_megatron_argument_validation.py tests/test_rollout_validation.py
  • uv run --no-project --with ruff ruff check --config pyproject.toml docs/en/get_started/customization.md slime/backends/megatron_utils/loss.py slime/ray/rollout.py slime/utils/arguments.py tests/test_cp_utils.py tests/test_megatron_argument_validation.py tests/test_rollout_validation.py
  • uv run --no-project --with black black --check --line-length 119 slime/backends/megatron_utils/loss.py slime/utils/arguments.py

EazyReal force-pushed the upstream-pr/loss-aggregation-modes branch 3 times, most recently from 40d955f to 3774a73, June 17, 2026 08:36
EazyReal marked this pull request as draft, June 17, 2026 08:36
EazyReal marked this pull request as ready for review, June 17, 2026 08:54
EazyReal changed the title from "Add --loss-aggregation for the four ScaleRL pg_loss aggregation modes" to "feat(loss): add --loss-aggregation for the four ScaleRL pg_loss modes", Jun 24, 2026
EazyReal force-pushed the upstream-pr/loss-aggregation-modes branch from 3774a73 to f23073b, June 24, 2026 03:18
EazyReal force-pushed the upstream-pr/loss-aggregation-modes branch from f23073b to 0fe5e98, June 24, 2026 04:24
EazyReal changed the title from "feat(loss): add --loss-aggregation for the four ScaleRL pg_loss modes" to "feat(loss): add pg_loss aggregation modes", Jun 26, 2026
EazyReal force-pushed the upstream-pr/loss-aggregation-modes branch from c3da906 to fa69ef7, June 26, 2026 00:38