The dapo/cispo/vespo branch normalized by per-process active tokens computed
from the current micro-batch's attention_mask. TRL's ``compute_loss`` uses
``num_items_in_batch / num_processes`` instead: the total active tokens across
the entire generation batch (all gradient-accumulation micro-batches × all
processes). The two are equivalent only when each micro-batch has the same
token count; with variable-length completions, Liger over-weighted tokens in
shorter micro-batches.
Plumb an optional ``num_items_in_batch`` through ``LigerFusedLinearGRPOLoss``
(forward), ``LigerFusedLinearGRPOFunction``, and ``LigerFusedLinearPPOBase``
into ``_compute_dapo_normalizer``. When provided, normalize by
``num_items_in_batch / world_size``. When None, keep the existing
attention-mask + all_reduce fallback so callers that don't have the total
still get a sensible normalizer.
Tests:
- ``test_num_items_in_batch_normalizer``: confirms passing ``mask.sum()``
  matches the default and that doubling the value halves loss/grad.
- ``test_num_items_in_batch_matches_trl_formula``: cross-checks against TRL's
  ``(per_token_loss * mask).sum() / num_items_in_batch`` formula directly.
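For illustration, a minimal PyTorch sketch of the normalizer behaviour described
above (the function name and signature are illustrative, not the library's actual
``_compute_dapo_normalizer`` API):

```python
import torch
import torch.distributed as dist


def dapo_normalizer(attention_mask, num_items_in_batch=None, world_size=1):
    # Illustrative stand-in for the normalizer behaviour described above.
    if num_items_in_batch is not None:
        # TRL-style: total active tokens over the whole generation batch,
        # divided evenly across processes.
        return num_items_in_batch / world_size
    # Fallback: active tokens of the current micro-batch, summed across processes.
    local = attention_mask.sum()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(local, op=dist.ReduceOp.SUM)
    return local


# Passing mask.sum() reproduces the single-process default; doubling it halves the loss.
mask = (torch.rand(2, 8) > 0.3).float()
per_token_loss = torch.randn(2, 8)
base = (per_token_loss * mask).sum() / dapo_normalizer(mask, num_items_in_batch=mask.sum())
halved = (per_token_loss * mask).sum() / dapo_normalizer(mask, num_items_in_batch=2 * mask.sum())
assert torch.allclose(halved, base / 2)
```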
When ``importance_sampling_level == "sequence"`` and ``use_bias_correction_kl``
is enabled, TRL multiplies ``per_token_kl`` by the sequence-level ``coef_1``
(shape ``(B, 1)``) so every token in a sequence is scaled by the same ratio.
Liger always recomputed a token-level ratio (``exp(per_token_logps - old)``,
shape ``(B, T)``), so the bias-corrected KL was wrong for sequence-level IS
even though it matched TRL for token-level.
Use ``exp(log_importance_weights)`` directly: that variable already encodes the
importance-sampling level (token: ``(B, T)``; sequence: ``(B, 1)``) and is
unchanged by the loss-type branches and the optional ``delta`` clamp, so this
also lines up with TRL's order of operations (bias correction before delta
clamp).
Apply the same fix to the torch reference ``TorchLMHeadGRPO`` in the test file;
it had the identical bug, which is why ``test_correctness_with_bias_correction_kl``
silently passed for both ISL values. Extend that test to parametrize over
``importance_sampling_level`` so the sequence path is now covered.
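A hedged PyTorch sketch of the distinction (the sequence-level weight is shown
here as the masked mean of per-token log ratios, which is an assumption about
TRL's GSPO-style definition; the surrounding loss code is omitted):

```python
import torch

B, T = 2, 5
per_token_logps = torch.randn(B, T)
old_per_token_logps = torch.randn(B, T)
per_token_kl = torch.rand(B, T)
mask = torch.ones(B, T)
importance_sampling_level = "sequence"
delta = 4.0

log_ratio = per_token_logps - old_per_token_logps
if importance_sampling_level == "token":
    log_importance_weights = log_ratio                        # (B, T)
else:  # "sequence": one weight per sequence (assumed masked mean of log ratios)
    log_importance_weights = (log_ratio * mask).sum(-1, keepdim=True) / mask.sum(-1, keepdim=True)  # (B, 1)

# coef_1 already encodes the IS level, so it is the right factor for bias correction.
coef_1 = torch.exp(log_importance_weights)
per_token_kl = per_token_kl * coef_1   # (B, 1) broadcasts over (B, T) for sequence-level IS

# The optional delta clamp happens afterwards and only affects the policy-loss term.
coef_1_clamped = torch.clamp(coef_1, max=delta)
```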
The triton GRPO loss had the same two TRL-compat issues as the chunked path
before its fix:
1. ``dapo`` / ``cispo`` reduction normalized by per-process active tokens
from the current micro-batch's mask. Add ``num_items_in_batch`` to
``triton_grpo_loss``, ``GrpoLossFunction``, ``_reduce_loss``,
``_reduce_grpo_loss``, and the ``_compute_dapo_normalizer`` helper in
``ops/grpo_loss.py``. When provided, normalize by
``num_items_in_batch / world_size`` to match TRL's ``compute_loss``;
otherwise fall back to the existing all-reduced micro-batch sum.
2. Bias-corrected KL with ``importance_sampling_level == "sequence"`` used
a freshly-recomputed token-level ratio (``exp(logp - old_logp)``) for
``per_token_kl *= coef_1``. TRL multiplies by the importance-sampling-
level-aware ``coef_1`` (sequence-level: shape ``(B, 1)``, broadcasting
the same scale to every token in a sequence).
- Forward kernel ``_grpo_loss_fwd_kernel_seq``: take a new
``COEF_1_RAW`` argument carrying the pre delta-clamp sequence-level
``coef_1`` and use it for the bias-correction multiply (TRL applies
bias correction before the delta clamp).
- Backward kernel ``_grpo_loss_bwd_kernel_seq``: the kernel already
loaded the pre-clamp sequence-level ``coef_1``; just use it instead
of recomputing token-level inside the bias-correction branch.
- The token-level forward kernel was already correct (its ``coef_1``
is token-level by construction).
Tests:
- Fix the bug also encoded in the test reference ``trl_reference_grpo_loss``
in ``test/transformers/test_grpo_loss.py`` so it now matches TRL exactly
for both ISL values.
- Extend ``test_grpo_loss_vs_trl`` to parametrize
``use_bias_correction_kl`` so the sequence-level + bias-correction path
is now covered against TRL.
- Add ``test_triton_num_items_in_batch_normalizer``: passing
``num_items_in_batch=mask.sum()`` matches the default; doubling halves
loss and gradients (mirrors the chunked-loss test).
VESPO (TRL grpo_trainer ``loss_type='vespo'``) was missing from the triton
path — calling ``triton_grpo_loss(loss_type='vespo')`` would fail at the
``_str_to_loss_type`` lookup. Add it to both the forward and backward
token-level kernels. Sequence-level IS is rejected upstream for VESPO (TRL
also does this), so the seq kernel doesn't need a branch.
Kernel changes (``ops/grpo_loss.py``):
- New ``_LOSS_TYPE_VESPO = 3`` constant; map ``"vespo"`` to it.
- ``_grpo_loss_fwd_kernel`` and ``_grpo_loss_bwd_kernel`` take a new
``PHI_SEQ`` (B,) tensor pointer (or None).
- New ``LOSS_TYPE == 3`` branch:
- Forward: ``per_token_loss = -phi_seq * advantage * logp``;
``is_clipped = 0`` (VESPO has no clipping).
- Backward: ``dlogp = -phi_seq * advantage`` (phi_seq is detached).
- The vLLM IS ratio multiplication is skipped for VESPO in both kernels —
it's already folded into ``phi_seq`` via ``log_is_ratio`` upstream.
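A self-contained PyTorch sketch of the VESPO token-level branch above
(``phi_seq`` here is a random detached placeholder; in the real path it comes
from ``get_gamma_weights``):

```python
import torch

B, T = 3, 7
logp = torch.randn(B, T, requires_grad=True)   # per-token log-probs (differentiable)
advantage = torch.randn(B, 1)                  # per-sequence advantage (broadcast over tokens)
phi_seq = torch.rand(B, 1)                     # detached VESPO weight, placeholder values

# Forward branch (LOSS_TYPE == 3): no clipping; the vLLM IS ratio is already folded into phi_seq.
per_token_loss = -phi_seq * advantage * logp

# Backward branch: d(per_token_loss)/d(logp) is simply -phi_seq * advantage.
per_token_loss.sum().backward()
assert torch.allclose(logp.grad, (-phi_seq * advantage).expand(B, T))
```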
Wrapper changes (``transformers/grpo_loss.py``):
- ``triton_grpo_loss`` accepts the four VESPO hyperparameters
(``vespo_k_pos / vespo_lambda_pos / vespo_k_neg / vespo_lambda_neg``).
- For ``loss_type='vespo'``: pre-compute per-token logp via
``fused_selective_log_softmax`` (no_grad — phi_seq is detached anyway),
call the existing ``get_gamma_weights`` from the chunked path to get
phi_seq, then pass it to ``GrpoLossFunction.apply``. We also force
``vllm_is_ratio=None`` for the kernel call since it's folded in.
- ``_reduce_grpo_loss``: route ``vespo`` through the dapo normalizer (with
``num_items_in_batch`` support).
``GrpoLossFunction`` changes:
- New ``phi_seq`` argument; saved via ctx for backward; passed to both
fwd and bwd kernels; +1 ``None`` in the gradient return.
- ``_reduce_loss`` and the dapo-normalizer dloss-scaling branch in
``backward`` route ``vespo`` to the dapo normalizer.
- ``delta`` and ``importance_sampling_level == 'sequence'`` validations
extended to reject ``vespo`` (matching TRL).
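The "+1 ``None`` in the gradient return" follows the standard
``torch.autograd.Function`` contract: ``backward`` returns one entry per
``forward`` argument, and arguments that need no gradient (like the detached
``phi_seq``) get ``None``. A stripped-down illustration, not the real
``GrpoLossFunction``:

```python
import torch


class SketchGrpoFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logp, advantage, phi_seq):
        ctx.save_for_backward(advantage, phi_seq)
        ctx.logp_shape = logp.shape
        return (-phi_seq * advantage * logp).sum()

    @staticmethod
    def backward(ctx, grad_output):
        advantage, phi_seq = ctx.saved_tensors
        grad_logp = (grad_output * -phi_seq * advantage).expand(*ctx.logp_shape)
        # One gradient slot per forward input: advantage and the detached phi_seq get None.
        return grad_logp, None, None


logp = torch.randn(3, 7, requires_grad=True)
loss = SketchGrpoFn.apply(logp, torch.randn(3, 1), torch.rand(3, 1))
loss.backward()
assert logp.grad.shape == logp.shape
```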
Tests:
- Add a ``vespo`` branch to ``trl_reference_grpo_loss`` so the reference
computes ``-phi_seq * advantage * logp`` and uses the dapo normalizer.
The branch uses the same ``get_gamma_weights`` as the chunked path, so
the reference is single-source-of-truth across chunked and triton.
- ``test_grpo_loss_vs_trl`` parametrized over ``loss_type='vespo'`` (skips
the rejected ``vespo + sequence`` and ``vespo + delta`` combos).
- Verified end-to-end: triton VESPO matches a hand-rolled pure-pytorch
reference within 2e-6 (loss + grads, fp32).
Summary
Fixes missing features as flagged in #1082
Testing Done
- make test to ensure correctness
- make checkstyle to ensure code style
- make test-convergence to ensure convergence