
[GRPO] fix TRL-liger missing #1202

Open
kashif wants to merge 4 commits into linkedin:main from kashif:fix_grpo_dapo_normalizer

Conversation

@kashif
Contributor

@kashif kashif commented Apr 26, 2026

Summary

Fixes missing features as flagged in #1082

Testing Done

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

kashif added 4 commits April 26, 2026 08:20
The dapo/cispo/vespo branch normalizes by per-process active tokens
computed from the current micro-batch's attention_mask. TRL's compute_loss
uses ``num_items_in_batch / num_processes`` instead — the total active
tokens across the entire generation batch (all gradient-accumulation
micro-batches × all processes). The two are equivalent only when every
micro-batch has the same token count; with variable-length completions, Liger
over-weighted tokens in shorter micro-batches.

Plumb an optional ``num_items_in_batch`` through ``LigerFusedLinearGRPOLoss``
(forward), ``LigerFusedLinearGRPOFunction``, and ``LigerFusedLinearPPOBase``
into ``_compute_dapo_normalizer``. When provided, normalize by
``num_items_in_batch / world_size``. When None, keep the existing
attention-mask + all_reduce fallback so callers that don't have the total
still get a sensible normalizer.
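The normalizer selection described above can be sketched as follows (a minimal standalone sketch, not the actual Liger code — the function name and signature here are hypothetical):

```python
import torch
import torch.distributed as dist


def dapo_normalizer(attention_mask, num_items_in_batch=None, world_size=1):
    """Pick the loss normalizer.

    When the trainer supplies the batch-global active-token count, use
    num_items_in_batch / world_size (TRL's compute_loss convention).
    Otherwise fall back to the all-reduced per-micro-batch mask sum.
    """
    if num_items_in_batch is not None:
        return num_items_in_batch / world_size
    total = attention_mask.sum().float()
    if dist.is_available() and dist.is_initialized():
        # Sum active tokens across processes so every rank divides by
        # the same number.
        dist.all_reduce(total, op=dist.ReduceOp.SUM)
    return total
```

Note that the two branches agree only when every micro-batch on every rank contributes the same token count, which is exactly the discrepancy this commit fixes.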

Tests:
- test_num_items_in_batch_normalizer: confirms passing ``mask.sum()``
  matches the default and that doubling the value halves loss/grad.
- test_num_items_in_batch_matches_trl_formula: cross-checks against TRL's
  ``(per_token_loss * mask).sum() / num_items_in_batch`` formula directly.
When ``importance_sampling_level == "sequence"`` and ``use_bias_correction_kl``
is enabled, TRL multiplies ``per_token_kl`` by the sequence-level ``coef_1``
(shape ``(B, 1)``) so every token in a sequence is scaled by the same ratio.
Liger always recomputed a token-level ratio (``exp(per_token_logps - old)``,
shape ``(B, T)``), so the bias-corrected KL was wrong for sequence-level IS
even though it matched TRL for token-level.

Use ``exp(log_importance_weights)`` directly — that variable already encodes
the importance-sampling level (token: ``(B, T)``; sequence: ``(B, 1)``) and
is unchanged by the loss-type branches and the optional ``delta`` clamp, so
this also lines up with TRL's order of operations (bias correction before
delta clamp).
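The shape argument above can be illustrated with a small sketch (the helper name is hypothetical; only the broadcasting behavior is the point):

```python
import torch


def bias_corrected_kl(per_token_kl, log_importance_weights):
    # coef_1 = exp(log_importance_weights) is (B, T) for token-level IS
    # and (B, 1) for sequence-level IS, so in the sequence case the same
    # ratio broadcasts across every token of a sequence -- no token-level
    # recomputation needed.
    return per_token_kl * torch.exp(log_importance_weights)
```

Because the multiplier is taken before any loss-type branch or `delta` clamp touches it, this matches TRL's order of operations by construction.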

Apply the same fix to the torch reference ``TorchLMHeadGRPO`` in the test
file — it had the identical bug, which is why ``test_correctness_with_bias_correction_kl``
silently passed for both ISL values.

Extend that test to parametrize over ``importance_sampling_level`` so the
sequence path is now covered.
The triton GRPO loss had the same two TRL-compat issues as the chunked path
before its fix:

1. ``dapo`` / ``cispo`` reduction normalized by per-process active tokens
   from the current micro-batch's mask. Add ``num_items_in_batch`` to
   ``triton_grpo_loss``, ``GrpoLossFunction``, ``_reduce_loss``,
   ``_reduce_grpo_loss``, and the ``_compute_dapo_normalizer`` helper in
   ``ops/grpo_loss.py``. When provided, normalize by
   ``num_items_in_batch / world_size`` to match TRL's ``compute_loss``;
   otherwise fall back to the existing all-reduced micro-batch sum.

2. Bias-corrected KL with ``importance_sampling_level == "sequence"`` used
   a freshly-recomputed token-level ratio (``exp(logp - old_logp)``) for
   ``per_token_kl *= coef_1``. TRL multiplies by the importance-sampling-
   level-aware ``coef_1`` (sequence-level: shape ``(B, 1)``, broadcasting
   the same scale to every token in a sequence).
   - Forward kernel ``_grpo_loss_fwd_kernel_seq``: take a new
     ``COEF_1_RAW`` argument carrying the pre delta-clamp sequence-level
     ``coef_1`` and use it for the bias-correction multiply (TRL applies
     bias correction before the delta clamp).
   - Backward kernel ``_grpo_loss_bwd_kernel_seq``: the kernel already
     loaded the pre-clamp sequence-level ``coef_1``; just use it instead
     of recomputing token-level inside the bias-correction branch.
   - The token-level forward kernel was already correct (its ``coef_1``
     is token-level by construction).

Tests:
- Fix the bug also encoded in the test reference ``trl_reference_grpo_loss``
  in ``test/transformers/test_grpo_loss.py`` so it now matches TRL exactly
  for both ISL values.
- Extend ``test_grpo_loss_vs_trl`` to parametrize
  ``use_bias_correction_kl`` so the sequence-level + bias-correction path
  is now covered against TRL.
- Add ``test_triton_num_items_in_batch_normalizer``: passing
  ``num_items_in_batch=mask.sum()`` matches the default; doubling halves
  loss and gradients (mirrors the chunked-loss test).
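The scaling property those tests rely on follows directly from TRL's reduction formula; a minimal sketch (helper name hypothetical):

```python
import torch


def trl_dapo_loss(per_token_loss, mask, num_items_in_batch):
    # TRL's compute_loss reduction: sum the masked per-token losses,
    # then divide by the batch-global active-token count.
    return (per_token_loss * mask).sum() / num_items_in_batch


per_token_loss = torch.rand(2, 5)
mask = torch.ones(2, 5)
n = mask.sum()  # with a full mask, this equals the default normalizer
base = trl_dapo_loss(per_token_loss, mask, n)
halved = trl_dapo_loss(per_token_loss, mask, 2 * n)
assert torch.allclose(halved, base / 2)
```

Since `num_items_in_batch` enters only as a final divisor, doubling it must halve both the loss and (by linearity) every gradient, which is what the tests check.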
VESPO (TRL grpo_trainer ``loss_type='vespo'``) was missing from the triton
path — calling ``triton_grpo_loss(loss_type='vespo')`` would fail at the
``_str_to_loss_type`` lookup. Add it to both the forward and backward
token-level kernels. Sequence-level IS is rejected upstream for VESPO (TRL
also does this), so the seq kernel doesn't need a branch.

Kernel changes (``ops/grpo_loss.py``):
- New ``_LOSS_TYPE_VESPO = 3`` constant; map ``"vespo"`` to it.
- ``_grpo_loss_fwd_kernel`` and ``_grpo_loss_bwd_kernel`` take a new
  ``PHI_SEQ`` (B,) tensor pointer (or None).
- New ``LOSS_TYPE == 3`` branch:
  - Forward: ``per_token_loss = -phi_seq * advantage * logp``;
    ``is_clipped = 0`` (VESPO has no clipping).
  - Backward: ``dlogp = -phi_seq * advantage`` (phi_seq is detached).
- The vLLM IS ratio multiplication is skipped for VESPO in both kernels —
  it's already folded into ``phi_seq`` via ``log_is_ratio`` upstream.
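The forward/backward pair of the new branch can be sketched in plain PyTorch (shapes assumed from the text: `phi_seq` and `advantage` are `(B,)`, log-probs are `(B, T)`; this is an illustrative sketch, not the kernel code):

```python
import torch


def vespo_per_token_loss(phi_seq, advantage, per_token_logps):
    # Forward branch: per_token_loss = -phi_seq * advantage * logp.
    # phi_seq is detached upstream, so the gradient w.r.t. logp is the
    # constant -phi_seq * advantage, matching the backward branch.
    return -(phi_seq * advantage).unsqueeze(-1) * per_token_logps


B, T = 2, 4
phi_seq = torch.rand(B).detach()
advantage = torch.randn(B)
logps = torch.randn(B, T, requires_grad=True)

loss = vespo_per_token_loss(phi_seq, advantage, logps)
loss.sum().backward()

# dlogp = -phi_seq * advantage, broadcast over the token dimension
expected = (-(phi_seq * advantage)).unsqueeze(-1).expand(B, T)
assert torch.allclose(logps.grad, expected)
```

The absence of any clamp in the forward also explains why `is_clipped` is constant zero for VESPO.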

Wrapper changes (``transformers/grpo_loss.py``):
- ``triton_grpo_loss`` accepts the four VESPO hyperparameters
  (``vespo_k_pos / vespo_lambda_pos / vespo_k_neg / vespo_lambda_neg``).
- For ``loss_type='vespo'``: pre-compute per-token logp via
  ``fused_selective_log_softmax`` (no_grad — phi_seq is detached anyway),
  call the existing ``get_gamma_weights`` from the chunked path to get
  phi_seq, then pass it to ``GrpoLossFunction.apply``. We also force
  ``vllm_is_ratio=None`` for the kernel call since it's folded in.
- ``_reduce_grpo_loss``: route ``vespo`` through the dapo normalizer (with
  ``num_items_in_batch`` support).

``GrpoLossFunction`` changes:
- New ``phi_seq`` argument; saved via ctx for backward; passed to both
  fwd and bwd kernels; +1 ``None`` in the gradient return.
- ``_reduce_loss`` and the dapo-normalizer dloss-scaling branch in
  ``backward`` route ``vespo`` to the dapo normalizer.
- ``delta`` and ``importance_sampling_level == 'sequence'`` validations
  extended to reject ``vespo`` (matching TRL).

Tests:
- Add a ``vespo`` branch to ``trl_reference_grpo_loss`` so the reference
  computes ``-phi_seq * advantage * logp`` and uses the dapo normalizer.
  The branch uses the same ``get_gamma_weights`` as the chunked path, so
  the reference is single-source-of-truth across chunked and triton.
- ``test_grpo_loss_vs_trl`` parametrized over ``loss_type='vespo'`` (skips
  the rejected ``vespo + sequence`` and ``vespo + delta`` combos).
- Verified end-to-end: triton VESPO matches a hand-rolled pure-pytorch
  reference within 2e-6 (loss + grads, fp32).
