feat(loss): support pg_loss aggregation modes #1498

Open
EazyReal wants to merge 2 commits into radixark:main from EazyReal:upstream-pr/loss-aggregation-modes-v2
@EazyReal EazyReal commented Jun 27, 2026

Description

Adds built-in pg_loss aggregation modes to Miles. Related implementations for comparison:

This adds --loss-aggregation {sample_mean,prompt_mean,token_mean,constant} while keeping sample_mean as the default. --calculate-per-token-loss remains the legacy spelling for token_mean, and custom pg-loss reducers still take precedence.

| Mode | Denominator |
| --- | --- |
| sample_mean | per-sample active-token count |
| prompt_mean | per-prompt-group active-token count, then mean over prompt groups |
| token_mean | global active-token count |
| constant | fixed --loss-aggregation-divisor |
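
For readers comparing the four denominators, here is a minimal pure-Python sketch of how each mode could reduce per-token losses to a scalar. It is an illustration under stated assumptions, not the code in this PR: the function name `aggregate_pg_loss`, the `prompt_ids` argument, and the final averaging over samples in `sample_mean` are hypothetical choices made for the example.

```python
from __future__ import annotations

from collections import defaultdict
from typing import Sequence


def aggregate_pg_loss(
    token_losses: Sequence[Sequence[float]],
    loss_masks: Sequence[Sequence[int]],
    prompt_ids: Sequence[int],
    mode: str = "sample_mean",
    divisor: float | None = None,
) -> float:
    """Reduce per-token losses to one scalar (hypothetical helper, not the Miles API)."""
    # Masked loss sum and active-token count for each sample.
    sums = [sum(l * m for l, m in zip(ls, ms)) for ls, ms in zip(token_losses, loss_masks)]
    counts = [sum(ms) for ms in loss_masks]

    if mode == "sample_mean":
        # Divide each sample by its own active-token count, then average over samples.
        return sum(s / max(c, 1) for s, c in zip(sums, counts)) / len(sums)
    if mode == "token_mean":
        # One global denominator: every active token carries equal weight.
        return sum(sums) / max(sum(counts), 1)
    if mode == "constant":
        if divisor is None:
            raise ValueError("constant mode needs a divisor")
        return sum(sums) / divisor
    if mode == "prompt_mean":
        # Pool tokens within each prompt group, then average over groups.
        group_sum: dict[int, float] = defaultdict(float)
        group_count: dict[int, int] = defaultdict(int)
        for pid, s, c in zip(prompt_ids, sums, counts):
            group_sum[pid] += s
            group_count[pid] += c
        return sum(group_sum[p] / max(group_count[p], 1) for p in group_sum) / len(group_sum)
    raise ValueError(f"unknown mode: {mode}")


# Three samples; the first two share prompt 0, and one token is masked out.
losses = [[1.0, 3.0, 100.0], [4.0, 4.0, 4.0, 4.0], [6.0]]
masks = [[1, 1, 0], [1, 1, 1, 1], [1]]
prompts = [0, 0, 1]

for m in ("sample_mean", "prompt_mean", "token_mean"):
    print(m, aggregate_pg_loss(losses, masks, prompts, mode=m))
print("constant", aggregate_pg_loss(losses, masks, prompts, mode="constant", divisor=13.0))
```

On this toy batch the modes give different values because long samples and large prompt groups are weighted differently, which is the practical reason to make the denominator an explicit choice.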

The implementation keeps denominator ownership explicit:

  • token_mean policy loss rejects nonzero entropy or KL-loss coefficients that would mix token-normalized pg_loss with sample-normalized auxiliary loss terms.
  • Train logging carries per-key normalizers when needed, and aggregation rejects mixed legacy/normalized log dictionaries, key-order mismatches, and malformed value/normalizer lengths.
  • SFT/value losses use the token reducer when the legacy per-token path is active.
  • TIS/RS pg_loss uses the modified mask for loss reduction while mismatch metrics stay on the original mask.
  • prompt_mean DP splitting keeps prompt groups whole on each DP shard and rejects train steps whose prompt-group count cannot be distributed evenly.
  • Custom-config overrides rederive and revalidate global_batch_size before checking aggregation constraints.
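
The prompt_mean DP-splitting rule in the list above can be sketched as follows. This is a simplified illustration, not the PR's implementation: `split_prompt_groups` is a hypothetical helper, and assigning groups to ranks in contiguous blocks is an assumption made for the example.

```python
from __future__ import annotations


def split_prompt_groups(prompt_ids: list[int], dp_size: int) -> list[list[int]]:
    """Return the sample indices owned by each DP rank, never splitting a prompt group.

    Hypothetical helper for illustration only.
    """
    # Collect sample indices per prompt group, preserving first-seen order.
    groups: dict[int, list[int]] = {}
    for idx, pid in enumerate(prompt_ids):
        groups.setdefault(pid, []).append(idx)

    # Reject steps whose group count cannot be distributed evenly.
    if len(groups) % dp_size != 0:
        raise ValueError(
            f"{len(groups)} prompt groups cannot be split evenly across {dp_size} DP ranks"
        )

    per_rank = len(groups) // dp_size
    ordered = list(groups.values())
    shards: list[list[int]] = []
    for rank in range(dp_size):
        shard: list[int] = []
        for group in ordered[rank * per_rank : (rank + 1) * per_rank]:
            shard.extend(group)  # the whole group lands on one rank
        shards.append(shard)
    return shards


# Four prompts with two samples each, split over two DP ranks.
shards = split_prompt_groups([0, 0, 1, 1, 2, 2, 3, 3], dp_size=2)
print(shards)
```

Keeping groups whole means each rank can compute its per-prompt denominators locally; rejecting uneven splits avoids ranks averaging over different numbers of groups.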

Validation

  • uv run --with pytest --with torch --with numpy --with httpx --with pyyaml --with ray --with huggingface_hub --with transformers --with pydantic --with psutil pytest --confcutdir=tests/fast/backends/training_utils tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py -q
  • uv run --with ruff ruff check miles/backends/training_utils/cp_utils.py miles/backends/training_utils/data.py miles/backends/training_utils/log_utils.py miles/backends/training_utils/loss.py miles/backends/training_utils/loss_hub/losses.py miles/ray/rollout/train_data_conversion.py miles/utils/arguments.py miles/backends/megatron_utils/model.py miles/backends/experimental/fsdp_utils/actor.py tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py
  • uv run --with black black --check miles/backends/training_utils/cp_utils.py miles/backends/training_utils/data.py miles/backends/training_utils/log_utils.py miles/backends/training_utils/loss.py miles/backends/training_utils/loss_hub/losses.py miles/ray/rollout/train_data_conversion.py miles/utils/arguments.py miles/backends/megatron_utils/model.py miles/backends/experimental/fsdp_utils/actor.py tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py
  • git diff --check upstream/main

python3 train.py --help is not runnable in my local environment because sglang is not installed; the parser path for the new flags is covered by the focused tests.

EazyReal commented Jun 27, 2026

Author

Hi @yueming-yuan, reopening this fresh PR for visibility after syncing the loss aggregation behavior with THUDM/slime#2090.

This adds pg_loss aggregation modes in Miles and keeps prompt_mean aligned with slime's current implementation: prompt_mask_sums, n_samples_per_prompt reducer scaling, and global_batch_size % n_samples_per_prompt == 0 validation.
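
As a rough sketch of the two prompt_mean pieces named above, the snippet below shows the divisibility check and one plausible way to derive prompt_mask_sums. Both functions are hypothetical, and the example assumes that the samples of one prompt are stored consecutively and that each sample carries the active-token total of its whole prompt group; neither detail is confirmed by this thread.

```python
from __future__ import annotations


def validate_prompt_mean(global_batch_size: int, n_samples_per_prompt: int) -> int:
    """Return the number of whole prompt groups per step, or raise (illustrative only)."""
    if n_samples_per_prompt <= 0:
        raise ValueError("n_samples_per_prompt must be positive")
    if global_batch_size % n_samples_per_prompt != 0:
        raise ValueError(
            "prompt_mean requires global_batch_size % n_samples_per_prompt == 0, "
            f"got {global_batch_size} and {n_samples_per_prompt}"
        )
    return global_batch_size // n_samples_per_prompt


def compute_prompt_mask_sums(
    loss_masks: list[list[int]], n_samples_per_prompt: int
) -> list[int]:
    """Give every sample the active-token count of its prompt group.

    Assumes consecutive samples belong to the same prompt.
    """
    validate_prompt_mean(len(loss_masks), n_samples_per_prompt)
    sums: list[int] = []
    for start in range(0, len(loss_masks), n_samples_per_prompt):
        group = loss_masks[start : start + n_samples_per_prompt]
        group_total = sum(sum(mask) for mask in group)
        sums.extend([group_total] * len(group))
    return sums


# Two prompts, two samples each.
example_masks = [[1, 1], [1, 0], [1, 1, 1], [0, 1]]
print(compute_prompt_mask_sums(example_masks, n_samples_per_prompt=2))
```

Under these assumptions, dividing each sample's masked loss sum by its entry and summing over a group yields that group's token mean, which is why a per-sample reducer would then need rescaling by n_samples_per_prompt.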

Please review when you have a chance.

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces the --loss-aggregation command-line option to support multiple aggregation modes for pg_loss (sample_mean, prompt_mean, token_mean, and constant), along with validation logic, documentation, and a comprehensive test suite. The review feedback suggests appending new parameters to the end of get_sum_of_sample_mean to preserve backward compatibility for positional arguments, avoiding redundant computation of sample_denoms when constant_divisor is active, and optimizing the GPU tensor conversion of prompt_mask_sums by avoiding a loop over individual scalar tensors.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +96 to +100
    sample_denoms: list[torch.Tensor] | None = None,
    calculate_per_token_loss: bool = False,
    qkv_format: str = "thd",
    max_seq_lens: list[int] | None = None,
    constant_divisor: float | None = None,

high

Inserting the new parameter sample_denoms in the middle of the existing parameter list breaks backward compatibility for any callers (including potential custom user scripts or external plugins) that invoke get_sum_of_sample_mean using positional arguments.

To maintain backward compatibility and avoid breaking existing positional calls, new parameters with default values should be appended to the end of the parameter list.

Suggested change
-     sample_denoms: list[torch.Tensor] | None = None,
-     calculate_per_token_loss: bool = False,
-     qkv_format: str = "thd",
-     max_seq_lens: list[int] | None = None,
-     constant_divisor: float | None = None,
+     calculate_per_token_loss: bool = False,
+     qkv_format: str = "thd",
+     max_seq_lens: list[int] | None = None,
+     sample_denoms: list[torch.Tensor] | None = None,
+     constant_divisor: float | None = None,
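
A toy illustration of the positional-argument concern raised in this comment. The three functions are hypothetical stand-ins with a shortened parameter list, not the real get_sum_of_sample_mean signature; the keyword-only marker (`*`) in the last variant is an extra option shown for comparison and was not part of the reviewer's suggestion.

```python
from __future__ import annotations


def reducer_old(loss_masks, calculate_per_token_loss=False, qkv_format="thd"):
    """Stand-in for the signature before the change."""
    return {"calculate_per_token_loss": calculate_per_token_loss, "qkv_format": qkv_format}


def reducer_inserted(loss_masks, sample_denoms=None, calculate_per_token_loss=False, qkv_format="thd"):
    """New parameter inserted in the middle: later positional slots shift by one."""
    return {
        "sample_denoms": sample_denoms,
        "calculate_per_token_loss": calculate_per_token_loss,
        "qkv_format": qkv_format,
    }


def reducer_appended(
    loss_masks,
    calculate_per_token_loss=False,
    qkv_format="thd",
    *,  # everything after this must be passed by keyword
    sample_denoms=None,
    constant_divisor=None,
):
    """New parameters appended (and keyword-only): old positional calls keep their meaning."""
    return {
        "sample_denoms": sample_denoms,
        "calculate_per_token_loss": calculate_per_token_loss,
        "qkv_format": qkv_format,
    }


masks = [[1, 1, 0]]

# An existing caller that passes calculate_per_token_loss positionally.
before = reducer_old(masks, True)
broken = reducer_inserted(masks, True)   # True silently binds to sample_denoms
safe = reducer_appended(masks, True)     # True still binds to calculate_per_token_loss

print(before["calculate_per_token_loss"], broken["calculate_per_token_loss"], safe["calculate_per_token_loss"])
```

The inserted variant raises no error; it just reinterprets the caller's argument, which is what makes this kind of break hard to notice.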

Comment on lines +108 to +109
    if sample_denoms is None:
        sample_denoms = [loss_mask.sum() for loss_mask in loss_masks]

medium

When constant_divisor is provided (i.e., in constant mode), sample_denoms is completely unused because sum_of_constant uses sum_of_token which does not reference sample_denoms. Computing sample_denoms via list comprehension over GPU tensors is redundant and can be avoided.

We should only compute sample_denoms if it is None and constant_divisor is also None.

Suggested change
-     if sample_denoms is None:
-         sample_denoms = [loss_mask.sum() for loss_mask in loss_masks]
+     if sample_denoms is None and constant_divisor is None:
+         sample_denoms = [loss_mask.sum() for loss_mask in loss_masks]

Comment thread miles/backends/training_utils/data.py Outdated
Comment on lines +48 to +52
    if "prompt_mask_sums" in rollout_data:
        rollout_data["prompt_mask_sums"] = [
            torch.tensor(d, dtype=torch.float32, device=torch.cuda.current_device())
            for d in rollout_data["prompt_mask_sums"]
        ]

medium

Using a list comprehension to create individual 0-D tensors on the GPU (torch.tensor(d, ...)) performs one small allocation and host-to-device copy per element, so the Python-loop and CUDA overhead grows with batch size.

Since prompt_mask_sums is a list of scalars, we can convert the entire list to a single 1-D GPU tensor in a single operation, and then convert it back to a list of 0-D views using list(). This is much more efficient.

    if "prompt_mask_sums" in rollout_data:
        rollout_data["prompt_mask_sums"] = list(
            torch.tensor(rollout_data["prompt_mask_sums"], dtype=torch.float32, device=torch.cuda.current_device())
        )

@EazyReal EazyReal force-pushed the upstream-pr/loss-aggregation-modes-v2 branch from d0be4cc to 5bba0f9 Compare June 27, 2026 17:17
@EazyReal EazyReal changed the title from "feat(loss): add --loss-aggregation pg_loss modes" to "feat(loss): make pg_loss aggregation explicit" Jun 27, 2026
@EazyReal EazyReal force-pushed the upstream-pr/loss-aggregation-modes-v2 branch 2 times, most recently from fd42cd9 to c64cbaa Compare June 27, 2026 17:29
@EazyReal EazyReal changed the title from "feat(loss): make pg_loss aggregation explicit" to "feat(loss): support pg_loss aggregation modes" Jun 27, 2026
@EazyReal EazyReal force-pushed the upstream-pr/loss-aggregation-modes-v2 branch 2 times, most recently from 98588bd to 4f1a9af Compare June 27, 2026 20:12
@EazyReal EazyReal force-pushed the upstream-pr/loss-aggregation-modes-v2 branch from 67e0503 to fe7d2bf Compare June 27, 2026 22:48