feat(loss): support pg_loss aggregation modes #1498
Hi @yueming-yuan, reopening this fresh PR for visibility after syncing the loss aggregation behavior with THUDM/slime#2090. This adds Please review when you have a chance.
Code Review
This pull request introduces the --loss-aggregation command-line option to support multiple aggregation modes for pg_loss (sample_mean, prompt_mean, token_mean, and constant), along with validation logic, documentation, and a comprehensive test suite. The review feedback suggests appending new parameters to the end of get_sum_of_sample_mean to preserve backward compatibility for positional arguments, avoiding redundant computation of sample_denoms when constant_divisor is active, and optimizing the GPU tensor conversion of prompt_mask_sums by avoiding a loop over individual scalar tensors.
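To make the four mode names concrete, here is a minimal sketch of how such reductions are commonly defined. It uses plain Python lists instead of tensors, and the function name, the `prompt_ids` argument, and the `divisor` argument are hypothetical. The exact normalization in this PR (for example across micro-batches or DP ranks) may differ, so treat this as an illustration of the idea rather than the PR's implementation.

```python
from collections import defaultdict


def aggregate_pg_loss(token_losses, loss_masks, mode="sample_mean",
                      prompt_ids=None, divisor=None):
    """Reduce per-token losses to one scalar (illustrative, not the PR's code).

    token_losses / loss_masks: one list of floats per sample.
    prompt_ids: hypothetical per-sample group id, used only by prompt_mean.
    divisor: hypothetical fixed denominator, used only by constant.
    """
    # Masked loss sum and unmasked token count for each sample.
    sums = [sum(l * m for l, m in zip(ls, ms))
            for ls, ms in zip(token_losses, loss_masks)]
    counts = [sum(ms) for ms in loss_masks]

    if mode == "sample_mean":
        # Average tokens within each sample, then average over samples.
        return sum(s / max(c, 1) for s, c in zip(sums, counts)) / len(sums)
    if mode == "token_mean":
        # Every unmasked token in the batch gets equal weight.
        return sum(sums) / max(sum(counts), 1)
    if mode == "prompt_mean":
        # Pool tokens of samples sharing a prompt, then average over prompts.
        group_sum, group_count = defaultdict(float), defaultdict(float)
        for pid, s, c in zip(prompt_ids, sums, counts):
            group_sum[pid] += s
            group_count[pid] += c
        per_prompt = [group_sum[p] / max(group_count[p], 1) for p in group_sum]
        return sum(per_prompt) / len(per_prompt)
    if mode == "constant":
        # Token sum divided by a fixed, length-independent constant.
        return sum(sums) / divisor
    raise ValueError(f"unknown aggregation mode: {mode}")


losses = [[2.0, 2.0], [4.0, 4.0, 4.0, 4.0], [6.0]]
masks = [[1, 1], [1, 1, 1, 1], [1]]
prompts = ["p0", "p0", "p1"]
results = {
    "sample_mean": aggregate_pg_loss(losses, masks, "sample_mean"),
    "token_mean": aggregate_pg_loss(losses, masks, "token_mean"),
    "prompt_mean": aggregate_pg_loss(losses, masks, "prompt_mean", prompt_ids=prompts),
    "constant": aggregate_pg_loss(losses, masks, "constant", divisor=10.0),
}
```

On this toy batch the long second sample dominates `token_mean`, while `sample_mean` weights all three samples equally, which is the practical difference the option exposes.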
    sample_denoms: list[torch.Tensor] | None = None,
    calculate_per_token_loss: bool = False,
    qkv_format: str = "thd",
    max_seq_lens: list[int] | None = None,
    constant_divisor: float | None = None,
Inserting the new parameter sample_denoms in the middle of the existing parameter list breaks backward compatibility for any callers (including potential custom user scripts or external plugins) that invoke get_sum_of_sample_mean using positional arguments.
To maintain backward compatibility and avoid breaking existing positional calls, new parameters with default values should be appended to the end of the parameter list.
-    sample_denoms: list[torch.Tensor] | None = None,
-    calculate_per_token_loss: bool = False,
-    qkv_format: str = "thd",
-    max_seq_lens: list[int] | None = None,
-    constant_divisor: float | None = None,
+    calculate_per_token_loss: bool = False,
+    qkv_format: str = "thd",
+    max_seq_lens: list[int] | None = None,
+    sample_denoms: list[torch.Tensor] | None = None,
+    constant_divisor: float | None = None,
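The positional-argument hazard can be reproduced with a toy function. The three functions below are hypothetical stand-ins, not the real `get_sum_of_sample_mean`. The keyword-only marker (`*`) in the last variant is an extra safeguard added for illustration and is not part of the suggestion above.

```python
def reduce_old(loss, per_token=False, qkv_format="thd"):
    """Hypothetical original signature."""
    return {"per_token": per_token, "qkv_format": qkv_format}


def reduce_inserted(loss, sample_denoms=None, per_token=False, qkv_format="thd"):
    """New parameter inserted mid-list: existing positional calls shift silently."""
    return {"sample_denoms": sample_denoms, "per_token": per_token,
            "qkv_format": qkv_format}


def reduce_appended(loss, per_token=False, qkv_format="thd", *, sample_denoms=None):
    """New parameter appended and keyword-only: old positional calls still bind."""
    return {"sample_denoms": sample_denoms, "per_token": per_token,
            "qkv_format": qkv_format}


# An existing caller that passes per_token=True positionally.
old_call = reduce_old(0.5, True)
broken_call = reduce_inserted(0.5, True)    # True now lands in sample_denoms
safe_call = reduce_appended(0.5, True)      # True still lands in per_token
```

No exception is raised in the broken case, which is what makes a mid-list insertion risky: the caller keeps running with the wrong argument bound.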
    if sample_denoms is None:
        sample_denoms = [loss_mask.sum() for loss_mask in loss_masks]
When constant_divisor is provided (i.e., in constant mode), sample_denoms is completely unused because sum_of_constant uses sum_of_token which does not reference sample_denoms. Computing sample_denoms via list comprehension over GPU tensors is redundant and can be avoided.
We should only compute sample_denoms if it is None and constant_divisor is also None.
-    if sample_denoms is None:
-        sample_denoms = [loss_mask.sum() for loss_mask in loss_masks]
+    if sample_denoms is None and constant_divisor is None:
+        sample_denoms = [loss_mask.sum() for loss_mask in loss_masks]
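A minimal sketch of the guarded default follows, using plain lists in place of GPU tensors. The function `reduce_loss` and its exact return values are assumptions made for illustration; only the parameter names `sample_denoms` and `constant_divisor` come from the diff.

```python
def reduce_loss(token_losses, loss_masks, sample_denoms=None, constant_divisor=None):
    """Sum of per-sample means, or token sum over a constant (illustrative only)."""
    masked_sums = [sum(l * m for l, m in zip(ls, ms))
                   for ls, ms in zip(token_losses, loss_masks)]
    if constant_divisor is not None:
        # Constant mode never reads per-sample denominators, so skip building them.
        return sum(masked_sums) / constant_divisor
    if sample_denoms is None:
        # Default: each sample is normalized by its own unmasked token count.
        sample_denoms = [sum(ms) for ms in loss_masks]
    return sum(s / d for s, d in zip(masked_sums, sample_denoms))


losses = [[2.0, 2.0], [4.0, 4.0, 4.0, 4.0]]
masks = [[1, 1], [1, 1, 1, 1]]
default_value = reduce_loss(losses, masks)
constant_value = reduce_loss(losses, masks, constant_divisor=10.0)
# Caller-supplied denominators, e.g. the pooled token count of a prompt group.
grouped_value = reduce_loss(losses, masks, sample_denoms=[6, 6])
```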
    if "prompt_mask_sums" in rollout_data:
        rollout_data["prompt_mask_sums"] = [
            torch.tensor(d, dtype=torch.float32, device=torch.cuda.current_device())
            for d in rollout_data["prompt_mask_sums"]
        ]
Using a list comprehension to create individual 0-D tensors on GPU in a loop (torch.tensor(d, ...)) introduces significant Python loop and CUDA launch overhead, especially for larger batch sizes.
Since prompt_mask_sums is a list of scalars, we can convert the entire list to a single 1-D GPU tensor in a single operation, and then convert it back to a list of 0-D views using list(). This is much more efficient.
    if "prompt_mask_sums" in rollout_data:
        rollout_data["prompt_mask_sums"] = list(
            torch.tensor(rollout_data["prompt_mask_sums"], dtype=torch.float32, device=torch.cuda.current_device())
        )
Description
Adds built-in pg_loss aggregation modes to Miles. Related implementations for comparison:

This adds --loss-aggregation {sample_mean,prompt_mean,token_mean,constant} while keeping sample_mean as the default. --calculate-per-token-loss remains the legacy spelling for token_mean, and custom pg-loss reducers still take precedence.

Modes: sample_mean, prompt_mean, token_mean, constant (with --loss-aggregation-divisor).

The implementation keeps denominator ownership explicit:
- token_mean policy loss rejects nonzero entropy or KL-loss coefficients that would mix token-normalized pg_loss with sample-normalized auxiliary loss terms.
- pg_loss uses the modified mask for loss reduction while mismatch metrics stay on the original mask.
- prompt_mean DP splitting keeps prompt groups whole on each DP shard and rejects train steps whose prompt-group count cannot be distributed evenly.
- global_batch_size before checking aggregation constraints.

Validation
- uv run --with pytest --with torch --with numpy --with httpx --with pyyaml --with ray --with huggingface_hub --with transformers --with pydantic --with psutil pytest --confcutdir=tests/fast/backends/training_utils tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py -q
- uv run --with ruff ruff check miles/backends/training_utils/cp_utils.py miles/backends/training_utils/data.py miles/backends/training_utils/log_utils.py miles/backends/training_utils/loss.py miles/backends/training_utils/loss_hub/losses.py miles/ray/rollout/train_data_conversion.py miles/utils/arguments.py miles/backends/megatron_utils/model.py miles/backends/experimental/fsdp_utils/actor.py tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py
- uv run --with black black --check miles/backends/training_utils/cp_utils.py miles/backends/training_utils/data.py miles/backends/training_utils/log_utils.py miles/backends/training_utils/loss.py miles/backends/training_utils/loss_hub/losses.py miles/ray/rollout/train_data_conversion.py miles/utils/arguments.py miles/backends/megatron_utils/model.py miles/backends/experimental/fsdp_utils/actor.py tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py
- git diff --check upstream/main
- python3 train.py --help is not runnable in my local environment because sglang is not installed; the parser path for the new flags is covered by the focused tests.
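For reviewers who want a picture of the prompt_mean DP splitting rule from the description, the following is a rough sketch under stated assumptions. The function name `split_prompt_groups` and the contiguous group-to-shard assignment are hypothetical; the PR may order or balance groups differently.

```python
def split_prompt_groups(prompt_ids, dp_size):
    """Assign sample indices to DP shards without splitting any prompt group."""
    groups = {}
    for idx, pid in enumerate(prompt_ids):
        # dict preserves first-seen order, so groups keep their batch order.
        groups.setdefault(pid, []).append(idx)
    if len(groups) % dp_size != 0:
        # Mirror the described behavior: reject steps that cannot be split evenly.
        raise ValueError(
            f"{len(groups)} prompt groups cannot be split evenly across "
            f"{dp_size} DP shards"
        )
    per_shard = len(groups) // dp_size
    group_lists = list(groups.values())
    shards = []
    for rank in range(dp_size):
        chunk = group_lists[rank * per_shard:(rank + 1) * per_shard]
        # Flatten the whole groups owned by this rank into one index list.
        shards.append([i for group in chunk for i in group])
    return shards


shards = split_prompt_groups(["a", "a", "b", "b", "c", "c", "d", "d"], dp_size=2)
```

Keeping each group on one shard means a per-prompt denominator can be computed locally, without reconciling partial token counts across ranks.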