diff --git a/src/liger_kernel/chunked_loss/fused_linear_ppo.py b/src/liger_kernel/chunked_loss/fused_linear_ppo.py index a1851cfa4..28bd5f11c 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_ppo.py +++ b/src/liger_kernel/chunked_loss/fused_linear_ppo.py @@ -200,6 +200,7 @@ def forward( vespo_lambda_pos=3.0, vespo_k_neg=3.0, vespo_lambda_neg=2.0, + num_items_in_batch=None, ): """Chunked forward pass for PPO loss computation. @@ -254,6 +255,7 @@ def forward( vespo_lambda_pos=vespo_lambda_pos, vespo_k_neg=vespo_k_neg, vespo_lambda_neg=vespo_lambda_neg, + num_items_in_batch=num_items_in_batch, ) compiled_compute_loss = torch.compile(compute_loss) if compiled else compute_loss @@ -433,16 +435,38 @@ def accumulate_chunk( return loss_acc, tuple(final_metrics) @staticmethod - def _compute_dapo_normalizer(attention_mask): - """Global active tokens averaged per process.""" - normalizer = attention_mask.to(torch.float32).sum() + def _compute_dapo_normalizer(attention_mask, num_items_in_batch=None): + """Per-process normalizer for DAPO/CISPO/VESPO. + + When ``num_items_in_batch`` is provided it is used directly, matching + TRL's ``num_items_in_batch / num_processes`` — the total active tokens + across the entire generation batch (all grad-accum micro-batches × all + processes). Falling back to the current micro-batch's mask biases the + per-token weight by micro-batch size when grad-accum steps have + unequal completion lengths. 
+ """ world_size = 1 if torch.distributed.is_available() and torch.distributed.is_initialized(): import torch.distributed as dist + world_size = dist.get_world_size() + + if num_items_in_batch is not None: + if isinstance(num_items_in_batch, torch.Tensor): + normalizer = num_items_in_batch.to(device=attention_mask.device, dtype=torch.float32) + else: + normalizer = torch.as_tensor( + float(num_items_in_batch), device=attention_mask.device, dtype=torch.float32 + ) + normalizer = normalizer / world_size + return torch.clamp(normalizer, min=1.0) + + normalizer = attention_mask.to(torch.float32).sum() + if torch.distributed.is_available() and torch.distributed.is_initialized(): + import torch.distributed as dist + normalizer = normalizer.clone() dist.all_reduce(normalizer, op=dist.ReduceOp.SUM) - world_size = dist.get_world_size() normalizer = normalizer / world_size return torch.clamp(normalizer, min=1.0) @@ -471,6 +495,7 @@ def _compute_loss_from_logps( vespo_lambda_pos=3.0, vespo_k_neg=3.0, vespo_lambda_neg=2.0, + num_items_in_batch=None, ): """Compute loss from pre-computed logprobs. 
This is the torch.compile-friendly part.""" chunk_loss, chunk_metrics = ppo_loss_fn( @@ -495,6 +520,7 @@ def _compute_loss_from_logps( vespo_lambda_pos=vespo_lambda_pos, vespo_k_neg=vespo_k_neg, vespo_lambda_neg=vespo_lambda_neg, + num_items_in_batch=num_items_in_batch, ) return chunk_loss, chunk_metrics @@ -556,4 +582,5 @@ def backward(ctx, grad_output, *grad_metrics): None, # grad_vllm_is_ratio None, # grad_delta None, # grad_use_bias_correction_kl + None, # grad_num_items_in_batch ) diff --git a/src/liger_kernel/chunked_loss/grpo_loss.py b/src/liger_kernel/chunked_loss/grpo_loss.py index 400808536..5e84d65f1 100644 --- a/src/liger_kernel/chunked_loss/grpo_loss.py +++ b/src/liger_kernel/chunked_loss/grpo_loss.py @@ -120,6 +120,7 @@ def ppo_loss_fn( vespo_lambda_pos=3.0, # VESPO gamma rate lambda for non-negative advantages vespo_k_neg=3.0, # VESPO gamma shape k for negative advantages vespo_lambda_neg=2.0, # VESPO gamma rate lambda for negative advantages + num_items_in_batch=None, # Total active tokens across the entire generation batch (TRL-compat) **kwargs, ): """GRPO Loss Function matching GRPOTrainer implementation.""" @@ -209,9 +210,11 @@ def ppo_loss_fn( # Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps]) kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps) if use_bias_correction_kl: - # Importance-sampling-corrected KL (DeepSeek-V3.2): kl *= token-level coef_1 - token_coef_1 = torch.exp(per_token_logps - old_per_token_logps) - kl_div = kl_div * token_coef_1 + # Importance-sampling-corrected KL (DeepSeek-V3.2): kl *= coef_1. + # Use exp(log_importance_weights) so the ratio's shape matches + # importance_sampling_level (token: (B, T); sequence: (B, 1)), + # mirroring TRL's ``per_token_kl * coef_1`` (un-clamped, before delta). 
+ kl_div = kl_div * torch.exp(log_importance_weights) # Combine losses per_token_loss = per_token_loss + beta * kl_div @@ -233,7 +236,9 @@ def ppo_loss_fn( raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'") loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length) elif loss_type in ("dapo", "cispo", "vespo"): - loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask) + loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer( + full_attention_mask, num_items_in_batch=num_items_in_batch + ) loss = (per_token_loss * attention_mask).sum() / loss_normalizer elif loss_type == "luspo": # Match TRL exactly: loss = (per_token_loss * mask.sum(1, keepdim=True)).mean() @@ -296,6 +301,7 @@ def forward( vespo_lambda_pos=3.0, vespo_k_neg=3.0, vespo_lambda_neg=2.0, + num_items_in_batch=None, ): """ Fused linear layer with GRPO loss. @@ -366,6 +372,7 @@ def forward( vespo_lambda_pos=vespo_lambda_pos, vespo_k_neg=vespo_k_neg, vespo_lambda_neg=vespo_lambda_neg, + num_items_in_batch=num_items_in_batch, ) @staticmethod @@ -405,6 +412,7 @@ def backward(ctx, grad_output, *grad_metrics): None, # grad_vespo_lambda_pos None, # grad_vespo_k_neg None, # grad_vespo_lambda_neg + None, # grad_num_items_in_batch ) @@ -492,6 +500,7 @@ def forward( ref_weight=None, ref_bias=None, vllm_is_ratio=None, + num_items_in_batch=None, ): return LigerFusedLinearGRPOFunction.apply( _input, @@ -524,4 +533,5 @@ def forward( self.vespo_lambda_pos, self.vespo_k_neg, self.vespo_lambda_neg, + num_items_in_batch, ) diff --git a/src/liger_kernel/ops/grpo_loss.py b/src/liger_kernel/ops/grpo_loss.py index bcd1262bc..034193bb7 100644 --- a/src/liger_kernel/ops/grpo_loss.py +++ b/src/liger_kernel/ops/grpo_loss.py @@ -7,6 +7,7 @@ _LOSS_TYPE_GRPO: tl.constexpr = tl.constexpr(0) _LOSS_TYPE_CISPO: tl.constexpr = tl.constexpr(1) _LOSS_TYPE_SAPO: tl.constexpr = tl.constexpr(2) +_LOSS_TYPE_VESPO: tl.constexpr 
= tl.constexpr(3) _str_to_loss_type = { "grpo": _LOSS_TYPE_GRPO.value, @@ -16,6 +17,7 @@ "luspo": _LOSS_TYPE_GRPO.value, "cispo": _LOSS_TYPE_CISPO.value, "sapo": _LOSS_TYPE_SAPO.value, + "vespo": _LOSS_TYPE_VESPO.value, } @@ -93,6 +95,7 @@ def _grpo_loss_fwd_kernel( ADVANTAGES, VLLM_IS_RATIO, VLLM_IS_RATIO_STRIDE, + PHI_SEQ, # VESPO gamma weight per sequence (B,) or None LOSS, LSE, KL, @@ -177,8 +180,16 @@ def _grpo_loss_fwd_kernel( per_token_loss = -sapo_coef * advantage is_clipped = 0.0 # SAPO has no clipping concept + elif LOSS_TYPE == 3: # VESPO: detached gamma weighting phi(w) on per-token logp + # Reference: TRL grpo_trainer.get_gamma_weights / chunked_loss/grpo_loss.py + # phi_seq is precomputed per-sequence, vllm correction is folded into phi_seq. + phi_seq = tl.load(PHI_SEQ + off_b).to(tl.float32) + per_token_loss = -phi_seq * advantage * logp + is_clipped = 0.0 # VESPO has no clipping concept + # Apply vLLM importance sampling correction BEFORE adding KL penalty - if VLLM_IS_RATIO is not None: + # VESPO folds this into phi_seq (in log space), so we skip it here. 
+ if VLLM_IS_RATIO is not None and LOSS_TYPE != 3: # Use modulo to support both (B, L) per-token and (B, 1) per-sequence shapes vllm_is_ratio = tl.load(VLLM_IS_RATIO + off_b * VLLM_IS_RATIO_STRIDE + off_l % VLLM_IS_RATIO_STRIDE).to( tl.float32 @@ -210,7 +221,8 @@ def _grpo_loss_fwd_kernel_seq( INPUT_IDS, COMPLETION_MASK, ADVANTAGES, - COEF_1, # Pre-computed sequence-level importance weight (B,) + COEF_1, # Pre-computed sequence-level importance weight, post delta-clamp (B,) + COEF_1_RAW, # Pre-computed sequence-level importance weight, pre delta-clamp (B,) COEF_2, # Pre-computed clipped coef (B,) IS_CLIPPED_SEQ, # Pre-computed clipping indicator (B,) VLLM_IS_RATIO, # vLLM importance sampling ratio (B, L) or (B, 1) or None @@ -239,6 +251,7 @@ def _grpo_loss_fwd_kernel_seq( INPUT_IDS += off_b * L + off_l ADVANTAGES += off_b COEF_1 += off_b + COEF_1_RAW += off_b COEF_2 += off_b IS_CLIPPED_SEQ += off_b LOSS += off_b * L + off_l @@ -284,12 +297,10 @@ def _grpo_loss_fwd_kernel_seq( ref_logp = tl.load(REF_LOGP).to(tl.float32) kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1 if USE_BIAS_CORRECTION_KL: - # Importance-sampling-corrected KL (DeepSeek-V3.2): kl *= token-level coef_1 - if OLD_LOGP is None: - old_logp = logp - else: - old_logp = tl.load(OLD_LOGP + off_b * L + off_l).to(tl.float32) - kl = kl * tl.exp(logp - old_logp) + # Importance-sampling-corrected KL (DeepSeek-V3.2): kl *= coef_1. + # Use the pre-clamp sequence-level coef_1 to match TRL — the same coef_1 + # that was passed to ``per_token_kl * coef_1`` upstream of the delta clamp. 
+ kl = kl * tl.load(COEF_1_RAW).to(tl.float32) per_token_loss += BETA * kl tl.store(KL, kl) @@ -377,13 +388,10 @@ def _grpo_loss_bwd_kernel_seq( REF_LOGP += off_b * L + off_l ref_logp = tl.load(REF_LOGP).to(tl.float32) if USE_BIAS_CORRECTION_KL: - # d(kl * coef_1)/d(logp) = coef_1 * (logp - ref_logp), where coef_1 = exp(logp - old_logp) - if OLD_LOGP is None: - old_logp = logp - else: - old_logp = tl.load(OLD_LOGP + off_b * L + off_l).to(tl.float32) - token_coef_1 = tl.exp(logp - old_logp) - dlogp += BETA * token_coef_1 * (logp - ref_logp) * dloss + # d(kl * coef_1)/d(logp) ≈ coef_1 * (logp - ref_logp), with coef_1 detached. + # Use the loaded sequence-level coef_1 (pre delta-clamp) to match TRL's + # ``per_token_kl * coef_1`` for importance_sampling_level == "sequence". + dlogp += BETA * coef_1 * (logp - ref_logp) * dloss else: dlogp += BETA * (1 - tl.exp(ref_logp - logp)) * dloss @@ -410,6 +418,7 @@ def _grpo_loss_bwd_kernel( LSE, VLLM_IS_RATIO, VLLM_IS_RATIO_STRIDE, + PHI_SEQ, # VESPO gamma weight per sequence (B,) or None TEMPERATURE, BETA: tl.constexpr, EPS_LOW, @@ -492,8 +501,14 @@ def _grpo_loss_bwd_kernel( d_sapo_d_coef1 = 4.0 * sigmoid_val * (1.0 - sigmoid_val) dlogp = -advantage * d_sapo_d_coef1 * coef_1 + elif LOSS_TYPE == 3: # VESPO: detached gamma weighting on per-token logp + # loss = -phi_seq * advantage * logp; phi_seq is detached → ∂loss/∂logp = -phi_seq * advantage + phi_seq = tl.load(PHI_SEQ + off_b).to(tl.float32) + dlogp = -phi_seq * advantage + # Apply vLLM IS ratio to PPO gradient (before KL gradient) - if VLLM_IS_RATIO is not None: + # VESPO folds vllm correction into phi_seq, so we skip it here. 
+ if VLLM_IS_RATIO is not None and LOSS_TYPE != 3: # Use modulo to support both (B, L) per-token and (B, 1) per-sequence shapes vllm_is_ratio = tl.load(VLLM_IS_RATIO + off_b * VLLM_IS_RATIO_STRIDE + off_l % VLLM_IS_RATIO_STRIDE).to( tl.float32 @@ -519,19 +534,37 @@ def _grpo_loss_bwd_kernel( tl.store(DLOGITS + cols, dlogits, mask=cols < N) -def _compute_dapo_normalizer(completion_mask): - """Global active tokens averaged per process (for distributed DAPO loss).""" - normalizer = completion_mask.to(torch.float32).sum() +def _compute_dapo_normalizer(completion_mask, num_items_in_batch=None): + """Per-process normalizer for DAPO/CISPO/VESPO. + + When ``num_items_in_batch`` is provided it is used directly, matching TRL's + ``num_items_in_batch / num_processes`` (total active tokens across the entire + generation batch including all gradient-accumulation micro-batches × all + processes). Falling back to the current micro-batch's mask biases per-token + weights by micro-batch size when grad-accum micro-batches have unequal + completion lengths. 
+ """ world_size = 1 + if torch.distributed.is_available() and torch.distributed.is_initialized(): + world_size = torch.distributed.get_world_size() + + if num_items_in_batch is not None: + if isinstance(num_items_in_batch, torch.Tensor): + normalizer = num_items_in_batch.to(device=completion_mask.device, dtype=torch.float32) + else: + normalizer = torch.as_tensor(float(num_items_in_batch), device=completion_mask.device, dtype=torch.float32) + normalizer = normalizer / world_size + return torch.clamp(normalizer, min=1.0) + + normalizer = completion_mask.to(torch.float32).sum() if torch.distributed.is_available() and torch.distributed.is_initialized(): normalizer = normalizer.clone() torch.distributed.all_reduce(normalizer, op=torch.distributed.ReduceOp.SUM) - world_size = torch.distributed.get_world_size() normalizer = normalizer / world_size return torch.clamp(normalizer, min=1.0) -def _reduce_loss(per_token_loss, mask, loss_type, max_completion_length, B, L): +def _reduce_loss(per_token_loss, mask, loss_type, max_completion_length, B, L, num_items_in_batch=None): """Apply loss reduction based on loss_type.""" if loss_type == "grpo" or loss_type == "sapo": return ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean() @@ -540,11 +573,13 @@ def _reduce_loss(per_token_loss, mask, loss_type, max_completion_length, B, L): elif loss_type == "dr_grpo": max_len = max_completion_length if max_completion_length is not None else L return (per_token_loss * mask).sum() / (B * max_len) - elif loss_type == "dapo" or loss_type == "cispo": - return (per_token_loss * mask).sum() / _compute_dapo_normalizer(mask) + elif loss_type == "dapo" or loss_type == "cispo" or loss_type == "vespo": + return (per_token_loss * mask).sum() / _compute_dapo_normalizer(mask, num_items_in_batch=num_items_in_batch) elif loss_type == "luspo": return (per_token_loss * mask.sum(-1, keepdim=True)).mean() - raise ValueError(f"Unknown loss_type: {loss_type}. 
Expected one of: grpo, bnpo, dr_grpo, dapo, cispo, sapo, luspo") + raise ValueError( + f"Unknown loss_type: {loss_type}. Expected one of: grpo, bnpo, dr_grpo, dapo, cispo, sapo, luspo, vespo" + ) class GrpoLossFunction(torch.autograd.Function): @@ -571,6 +606,8 @@ def forward( vllm_is_ratio=None, delta=None, use_bias_correction_kl=False, + num_items_in_batch=None, + phi_seq=None, ): assert logits.is_contiguous() and completion_ids.is_contiguous() assert old_logp is None or old_logp.is_contiguous() @@ -584,14 +621,14 @@ def forward( raise ValueError(f"Unknown loss_type '{loss_type}'. Supported types: {list(_str_to_loss_type.keys())}") # Validate delta + loss_type combinations - if delta is not None and loss_type in ("cispo", "sapo"): + if delta is not None and loss_type in ("cispo", "sapo", "vespo"): raise ValueError(f"delta (two-sided clipping) is not supported for loss_type='{loss_type}'.") # Map delta to float for Triton (Triton can't handle None) delta_val = 0.0 if delta is None else float(delta) # Validate sequence-level + loss_type combinations - if importance_sampling_level == "sequence" and loss_type in ("cispo", "sapo"): + if importance_sampling_level == "sequence" and loss_type in ("cispo", "sapo", "vespo"): raise ValueError( f"Sequence-level importance sampling is not supported for loss_type='{loss_type}'. " f"Use importance_sampling_level='token' instead." @@ -610,6 +647,15 @@ def forward( B, L_ADD_1, N = logits.shape L = L_ADD_1 - 1 + # VESPO requires phi_seq pre-computed by the caller (uses get_gamma_weights). 
+ if loss_type == "vespo": + if phi_seq is None: + raise ValueError("loss_type='vespo' requires phi_seq pre-computed (use the triton_grpo_loss wrapper).") + assert phi_seq.shape in ((B,), (B, 1)), f"phi_seq must be (B,) or (B, 1), got {tuple(phi_seq.shape)}" + phi_seq = phi_seq.reshape(-1).contiguous() + else: + phi_seq = None + if completion_mask is not None: assert completion_mask.is_contiguous() @@ -676,6 +722,7 @@ def forward( completion_mask, advantages, coef_1_for_loss.contiguous(), + coef_1.contiguous(), # COEF_1_RAW: pre delta-clamp, for bias-corrected KL coef_2.contiguous(), is_clipped_seq.contiguous(), vllm_is_ratio_ptr, @@ -718,6 +765,7 @@ def forward( advantages, vllm_is_ratio_ptr, vllm_is_ratio_stride, + phi_seq, # PHI_SEQ (B,) for VESPO, None otherwise loss, lse, kl, @@ -736,7 +784,16 @@ def forward( **kwargs, ) ctx.save_for_backward( - logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse, mask, vllm_is_ratio_ptr + logits, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + lse, + mask, + vllm_is_ratio_ptr, + phi_seq, ) ctx.infos = ( @@ -757,6 +814,7 @@ def forward( reduce, delta_val, use_bias_correction_kl, + num_items_in_batch, ) # Compute metrics before reduction @@ -770,7 +828,9 @@ def forward( is_clipped_out = is_clipped * mask return loss_out, kl_out, is_clipped_out - reduced_loss = _reduce_loss(loss, mask, loss_type, max_completion_length, B, L) + reduced_loss = _reduce_loss( + loss, mask, loss_type, max_completion_length, B, L, num_items_in_batch=num_items_in_batch + ) return reduced_loss, kl_mean, clip_ratio @staticmethod @@ -795,6 +855,7 @@ def backward(ctx, *args): reduce, delta_val, use_bias_correction_kl, + num_items_in_batch, ) = ctx.infos if importance_sampling_level == "sequence": @@ -811,10 +872,20 @@ def backward(ctx, *args): seq_lens, vllm_is_ratio, ) = saved_tensors + phi_seq = None else: - (logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse, mask, 
vllm_is_ratio) = ( - saved_tensors - ) + ( + logits, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + lse, + mask, + vllm_is_ratio, + phi_seq, + ) = saved_tensors _, L_ADD_1, N = logits.shape @@ -829,8 +900,8 @@ def backward(ctx, *args): elif loss_type == "dr_grpo": max_len = max_completion_length if max_completion_length is not None else L dloss = dloss_input * mask / (B * max_len) - elif loss_type == "dapo" or loss_type == "cispo": - dloss = dloss_input * mask / _compute_dapo_normalizer(mask) + elif loss_type == "dapo" or loss_type == "cispo" or loss_type == "vespo": + dloss = dloss_input * mask / _compute_dapo_normalizer(mask, num_items_in_batch=num_items_in_batch) elif loss_type == "luspo": # loss = mean(per_token_loss * seq_lens), mean divides by B*L seq_lens_bwd = mask.sum(-1, keepdim=True).clamp(min=1.0) @@ -889,6 +960,7 @@ def backward(ctx, *args): lse, vllm_is_ratio, vllm_is_ratio_stride, + phi_seq, temperature, beta, eps_low, @@ -905,7 +977,7 @@ def backward(ctx, *args): ) dlogits[:, -1, :] = 0 - # Return gradients for all forward inputs: dlogits + 19 None for non-differentiable params + # Return gradients for all forward inputs: dlogits + 21 None for non-differentiable params return ( dlogits, None, @@ -927,4 +999,6 @@ def backward(ctx, *args): None, None, None, + None, # num_items_in_batch + None, # phi_seq ) diff --git a/src/liger_kernel/transformers/grpo_loss.py b/src/liger_kernel/transformers/grpo_loss.py index caa053bd6..2af36158a 100644 --- a/src/liger_kernel/transformers/grpo_loss.py +++ b/src/liger_kernel/transformers/grpo_loss.py @@ -1,7 +1,9 @@ import torch from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase +from liger_kernel.chunked_loss.grpo_loss import get_gamma_weights from liger_kernel.ops import GrpoLossFunction +from liger_kernel.ops.grpo_loss import fused_selective_log_softmax def triton_grpo_loss( @@ -25,6 +27,11 @@ def triton_grpo_loss( vllm_is_ratio=None, delta=None, 
use_bias_correction_kl=False, + num_items_in_batch=None, + vespo_k_pos=2.0, + vespo_lambda_pos=3.0, + vespo_k_neg=3.0, + vespo_lambda_neg=2.0, ): """ Triton-optimized GRPO loss function. @@ -53,6 +60,14 @@ def triton_grpo_loss( types (grpo, bnpo, dr_grpo, dapo, luspo). None means disabled. use_bias_correction_kl: If True, multiply KL divergence by coef_1 (importance sampling ratio) for bias-corrected KL estimation (DeepSeek-V3.2). Default False. + num_items_in_batch: Optional total active tokens across the entire generation batch + (all gradient-accumulation micro-batches × all processes). When provided, dapo / + cispo / vespo normalization uses ``num_items_in_batch / num_processes`` to match + TRL's ``compute_loss``. When None, falls back to the current micro-batch's mask + sum. + vespo_k_pos, vespo_lambda_pos, vespo_k_neg, vespo_lambda_neg: VESPO gamma weighting + hyperparameters (k for shape, lambda for rate; ``_pos`` for non-negative + advantages, ``_neg`` for negative). Only used when ``loss_type='vespo'``. Returns: If reduce=True: (loss, metrics) where metrics = [kl_mean, clip_ratio] or [clip_ratio] @@ -65,6 +80,45 @@ def triton_grpo_loss( f"importance_sampling_level must be 'token' or 'sequence', got {importance_sampling_level}" ) + # VESPO: pre-compute phi_seq (detached, sequence-level gamma weighting). The vllm + # importance-sampling correction is folded into phi_seq via log_is_ratio rather than + # multiplied onto per_token_loss, so we drop vllm_is_ratio for the kernel call. + phi_seq = None + if loss_type == "vespo": + if importance_sampling_level == "sequence": + raise ValueError("loss_type='vespo' requires importance_sampling_level='token'.") + # Need per-token logp for log_ratio. fused_selective_log_softmax is no-grad — + # phi_seq is detached anyway, so this is fine. 
+ per_token_logps = fused_selective_log_softmax(logits, completion_ids, temperature, completion_mask) + if old_logp is None: + log_ratio = torch.zeros_like(per_token_logps) + else: + log_ratio = per_token_logps - old_logp + mask = ( + completion_mask + if completion_mask is not None + else torch.ones_like(per_token_logps, dtype=per_token_logps.dtype) + ) + # Normalize vllm_is_ratio shape to (B, T) for get_gamma_weights' sum-over-time. + vllm_for_phi = vllm_is_ratio + if vllm_for_phi is not None: + if vllm_for_phi.dim() == 1: + vllm_for_phi = vllm_for_phi.unsqueeze(-1).expand_as(per_token_logps) + elif vllm_for_phi.dim() == 2 and vllm_for_phi.shape[1] == 1: + vllm_for_phi = vllm_for_phi.expand_as(per_token_logps) + phi_seq = get_gamma_weights( + advantages=advantages, + log_ratio_per_token=log_ratio, + mask=mask, + importance_sampling_ratio=vllm_for_phi, + k_pos=vespo_k_pos, + lambda_pos=vespo_lambda_pos, + k_neg=vespo_k_neg, + lambda_neg=vespo_lambda_neg, + ) # (B, 1) + # vllm correction is folded into phi_seq; do not pass it to the kernel separately. 
+ vllm_is_ratio = None + result = GrpoLossFunction.apply( logits, old_logp, @@ -86,6 +140,8 @@ def triton_grpo_loss( vllm_is_ratio, delta, use_bias_correction_kl, + num_items_in_batch, + phi_seq, ) if not reduce: @@ -101,7 +157,7 @@ def triton_grpo_loss( return reduced_loss, metrics -def _reduce_grpo_loss(per_token_loss, completion_mask, loss_type, max_completion_length): +def _reduce_grpo_loss(per_token_loss, completion_mask, loss_type, max_completion_length, num_items_in_batch=None): mask = completion_mask if mask is None: mask = torch.ones_like(per_token_loss, dtype=per_token_loss.dtype, device=per_token_loss.device) @@ -117,9 +173,9 @@ def _reduce_grpo_loss(per_token_loss, completion_mask, loss_type, max_completion batch = per_token_loss.shape[0] max_len = max_completion_length if max_completion_length is not None else per_token_loss.shape[1] return (per_token_loss * mask).sum() / (batch * max_len) - if loss_type == "dapo" or loss_type == "cispo": - # CISPO uses the same normalization as DAPO - normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(mask) + if loss_type == "dapo" or loss_type == "cispo" or loss_type == "vespo": + # CISPO and VESPO use the same normalization as DAPO + normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(mask, num_items_in_batch=num_items_in_batch) return (per_token_loss * mask).sum() / normalizer if loss_type == "luspo": # LUSPO: scale each sequence's loss by its valid token count, then average across sequences diff --git a/test/chunked_loss/test_grpo_loss.py b/test/chunked_loss/test_grpo_loss.py index e447b85ad..9da0aedf8 100644 --- a/test/chunked_loss/test_grpo_loss.py +++ b/test/chunked_loss/test_grpo_loss.py @@ -187,8 +187,8 @@ def compute_per_token_components( ref_per_token_logps = ref_per_token_logps.float() kl_div = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1.0 if use_bias_correction_kl: - token_coef_1 = torch.exp(per_token_logps - old_per_token_logps) - 
kl_div = kl_div * token_coef_1 + # TRL: per_token_kl *= coef_1 with coef_1 reflecting importance_sampling_level. + kl_div = kl_div * torch.exp(log_importance_weights) per_token_loss = per_token_loss + beta * kl_div # Adjust clipping metric calculation based on importance sampling level @@ -347,6 +347,7 @@ def forward( old_per_token_logps=None, ref_input=None, vllm_is_ratio=None, + num_items_in_batch=None, ): return self.grpo_loss( x, # _input @@ -361,6 +362,7 @@ def forward( self.ref_lin.weight, # ref_weight self.ref_lin.bias, # ref_bias vllm_is_ratio=vllm_is_ratio, + num_items_in_batch=num_items_in_batch, ) @@ -646,14 +648,23 @@ def test_correctness( @pytest.mark.parametrize("loss_type", ["grpo", "dapo"]) +@pytest.mark.parametrize("importance_sampling_level", ["token", "sequence"]) @pytest.mark.parametrize( "dtype, atol, rtol", [ (torch.float32, 1e-5, 5e-4), ], ) -def test_correctness_with_bias_correction_kl(loss_type, dtype, atol, rtol): - """Test use_bias_correction_kl (importance-sampling-corrected KL from DeepSeek-V3.2).""" +def test_correctness_with_bias_correction_kl(loss_type, importance_sampling_level, dtype, atol, rtol): + """Test use_bias_correction_kl (importance-sampling-corrected KL from DeepSeek-V3.2). + + Covers both ``importance_sampling_level`` values: TRL multiplies ``per_token_kl`` + by ``coef_1`` whose shape mirrors the importance-sampling level (token: (B, T); + sequence: (B, 1)). Liger must do the same — historically it always recomputed a + token-level ratio, which silently miscomputed the bias-corrected KL when + sequence-level importance sampling was selected. 
+ """ + set_seed() B, T, H, V = 3, 47, 31, 123 beta = 0.1 # Must be non-zero for KL to matter torch.compiler.reset() @@ -664,6 +675,7 @@ def test_correctness_with_bias_correction_kl(loss_type, dtype, atol, rtol): dtype=dtype, beta=beta, loss_type=loss_type, + importance_sampling_level=importance_sampling_level, use_bias_correction_kl=True, ) liger_lm_head_grpo = LigerLMHeadGRPO( @@ -672,6 +684,7 @@ def test_correctness_with_bias_correction_kl(loss_type, dtype, atol, rtol): dtype=dtype, beta=beta, loss_type=loss_type, + importance_sampling_level=importance_sampling_level, use_bias_correction_kl=True, ) @@ -791,6 +804,128 @@ def test_correctness_with_vllm_is_ratio(loss_type, beta): assert_verbose_allclose(input3.grad, input4.grad, atol=1e-5, rtol=1e-5) +@pytest.mark.parametrize("loss_type", ["dapo", "cispo", "vespo"]) +def test_num_items_in_batch_normalizer(loss_type): + """``num_items_in_batch`` overrides the dapo/cispo/vespo normalizer. + + TRL's ``compute_loss`` for these loss types divides by ``num_items_in_batch / + num_processes`` — the total active tokens across the entire generation batch + (all gradient-accumulation micro-batches × all processes). The Liger default + falls back to the current micro-batch's mask, which biases per-token weights + by micro-batch size when grad-accum micro-batches have unequal lengths. + + This test verifies, in single-process world: + 1. Passing ``num_items_in_batch=mask.sum()`` matches the default normalizer. + 2. Doubling ``num_items_in_batch`` halves both loss and input gradients + (linear in the normalizer, no other dependence). 
+ """ + set_seed() + torch.compiler.reset() + B, T, H, V = 3, 47, 31, 123 + dtype = torch.float32 + + _weight = torch.randn(V, H, device=device, dtype=dtype) + _input = torch.randn(B, T, H, device=device, dtype=dtype) + selected_token_ids = torch.randint(0, V, (B, T), device=device) + attention_mask = torch.ones(B, T, device=device) + attention_mask[:, -5:] = 0 + advantages = torch.randn(B, device=device, dtype=dtype) + advantages[0] = -advantages[0].abs() + advantages[1] = advantages[1].abs() + + mask_sum = attention_mask.sum().item() + + def _run(num_items_in_batch): + liger = LigerLMHeadGRPO(H=H, V=V, dtype=dtype, beta=0.0, loss_type=loss_type, use_ref_model=False) + liger.lin.weight.data = _weight.clone() + inp = _input.detach().clone().requires_grad_(True) + loss, _ = liger( + inp, + selected_token_ids, + attention_mask, + advantages, + num_items_in_batch=num_items_in_batch, + ) + loss.backward() + return loss.detach(), inp.grad.detach().clone() + + loss_default, grad_default = _run(num_items_in_batch=None) + loss_match, grad_match = _run(num_items_in_batch=mask_sum) + loss_double, grad_double = _run(num_items_in_batch=mask_sum * 2) + + assert_verbose_allclose(loss_default, loss_match, atol=1e-5, rtol=1e-5) + assert_verbose_allclose(grad_default, grad_match, atol=1e-5, rtol=1e-5) + + assert_verbose_allclose(loss_double * 2, loss_default, atol=1e-5, rtol=1e-5) + assert_verbose_allclose(grad_double * 2, grad_default, atol=1e-5, rtol=1e-5) + + +@pytest.mark.parametrize("loss_type", ["dapo", "cispo", "vespo"]) +def test_num_items_in_batch_matches_trl_formula(loss_type): + """Liger with ``num_items_in_batch=N`` matches TRL's ``sum / (N / num_processes)``. + + Reproduces TRL's exact formula in single-process world (num_processes=1): + ``loss = (per_token_loss * mask).sum() / num_items_in_batch``. 
+ """ + set_seed() + torch.compiler.reset() + B, T, H, V = 3, 47, 31, 123 + dtype = torch.float32 + + _weight = torch.randn(V, H, device=device, dtype=dtype) + _input = torch.randn(B, T, H, device=device, dtype=dtype) + selected_token_ids = torch.randint(0, V, (B, T), device=device) + attention_mask = torch.ones(B, T, device=device) + attention_mask[:, -5:] = 0 + advantages = torch.randn(B, device=device, dtype=dtype) + advantages[0] = -advantages[0].abs() + advantages[1] = advantages[1].abs() + + # Pick num_items_in_batch != mask.sum() to exercise the new path. + num_items_in_batch = float(attention_mask.sum().item()) * 1.7 + + # Liger: pass num_items_in_batch through the new param. + liger = LigerLMHeadGRPO(H=H, V=V, dtype=dtype, beta=0.0, loss_type=loss_type, use_ref_model=False) + liger.lin.weight.data = _weight.clone() + input1 = _input.detach().clone().requires_grad_(True) + loss_liger, _ = liger( + input1, + selected_token_ids, + attention_mask, + advantages, + num_items_in_batch=num_items_in_batch, + ) + + # Torch reference: use num_items_in_batch directly as the normalizer. + # We bypass TorchLMHeadGRPO's reduction branch and apply TRL's exact formula to its per-token loss. 
+ torch_lm = TorchLMHeadGRPO(H=H, V=V, dtype=dtype, beta=0.0, loss_type=loss_type, use_ref_model=False) + torch_lm.lin.weight.data = _weight.clone() + input2 = _input.detach().clone().requires_grad_(True) + + logits = input2 @ torch_lm.lin.weight.t() + log_probs = F.log_softmax(logits.float(), dim=-1) + per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(-1) + per_token_loss, _, _ = TorchLMHeadGRPO.compute_per_token_components( + per_token_logps, + attention_mask, + advantages, + old_per_token_logps=None, + ref_per_token_logps=None, + epsilon_low=0.2, + epsilon_high=0.2, + beta=0.0, + importance_sampling_level="token", + loss_type=loss_type, + ) + loss_ref = (per_token_loss * attention_mask).sum() / num_items_in_batch + + assert_verbose_allclose(loss_liger, loss_ref, atol=1e-5, rtol=1e-4) + + loss_liger.backward() + loss_ref.backward() + assert_verbose_allclose(input1.grad, input2.grad, atol=1e-5, rtol=1e-4) + + @pytest.mark.parametrize( "B, T, H, V", [ diff --git a/test/transformers/test_grpo_loss.py b/test/transformers/test_grpo_loss.py index d139e393d..18c72ded8 100644 --- a/test/transformers/test_grpo_loss.py +++ b/test/transformers/test_grpo_loss.py @@ -496,8 +496,14 @@ def trl_reference_grpo_loss( importance_sampling_level, delta=None, use_bias_correction_kl=False, + vespo_k_pos=2.0, + vespo_lambda_pos=3.0, + vespo_k_neg=3.0, + vespo_lambda_neg=2.0, ): """TRL reference implementation from grpo_trainer.py""" + from liger_kernel.chunked_loss.grpo_loss import get_gamma_weights + B, L_ADD_1, V = logits.shape L = L_ADD_1 - 1 @@ -517,21 +523,38 @@ def trl_reference_grpo_loss( log_importance_weights = log_importance_weights.unsqueeze(-1) coef_1 = torch.exp(log_importance_weights) - coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high) - if delta is not None: - coef_1 = torch.clamp(coef_1, max=delta) - per_token_loss1 = coef_1 * advantages.unsqueeze(-1) - per_token_loss2 = coef_2 * advantages.unsqueeze(-1) - per_token_loss = 
-torch.min(per_token_loss1, per_token_loss2) + if loss_type == "vespo": + # VESPO: detached gamma weighting on per-token logp, no clipping. + # phi_seq replaces the (coef_1, coef_2) clipping pair. + phi_seq = get_gamma_weights( + advantages=advantages, + log_ratio_per_token=log_ratio, + mask=completion_mask, + k_pos=vespo_k_pos, + lambda_pos=vespo_lambda_pos, + k_neg=vespo_k_neg, + lambda_neg=vespo_lambda_neg, + ) # (B, 1) + per_token_loss = -phi_seq * advantages.unsqueeze(-1) * per_token_logps + else: + coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high) + if delta is not None: + coef_1 = torch.clamp(coef_1, max=delta) - if importance_sampling_level == "sequence": - per_token_loss = per_token_loss.expand(B, L) + per_token_loss1 = coef_1 * advantages.unsqueeze(-1) + per_token_loss2 = coef_2 * advantages.unsqueeze(-1) + per_token_loss = -torch.min(per_token_loss1, per_token_loss2) + + if importance_sampling_level == "sequence": + per_token_loss = per_token_loss.expand(B, L) if beta != 0.0: kl = torch.exp(ref_logp - per_token_logps) - (ref_logp - per_token_logps) - 1.0 if use_bias_correction_kl: - kl = kl * torch.exp(per_token_logps - old_logp) + # TRL: kl *= coef_1 with shape matching importance_sampling_level + # (token: (B, T); sequence: (B, 1)). 
+ kl = kl * torch.exp(log_importance_weights) per_token_loss = per_token_loss + beta * kl # Loss reduction @@ -541,7 +564,7 @@ def trl_reference_grpo_loss( loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) elif loss_type == "dr_grpo": loss = (per_token_loss * completion_mask).sum() / (B * L) - elif loss_type == "dapo": + elif loss_type == "dapo" or loss_type == "vespo": loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) elif loss_type == "luspo": loss = (per_token_loss * completion_mask.sum(-1, keepdim=True)).mean() @@ -551,18 +574,22 @@ def trl_reference_grpo_loss( @pytest.mark.parametrize("delta", [None, 1.5]) @pytest.mark.parametrize("importance_sampling_level", ["token", "sequence"]) -@pytest.mark.parametrize("loss_type", ["grpo", "bnpo", "dr_grpo", "dapo", "luspo"]) -@pytest.mark.parametrize("beta", [0.0, 0.04]) +@pytest.mark.parametrize("loss_type", ["grpo", "bnpo", "dr_grpo", "dapo", "luspo", "vespo"]) +@pytest.mark.parametrize("beta,use_bias_correction_kl", [(0.0, False), (0.04, False), (0.04, True)]) @pytest.mark.parametrize( "B, T, V", [ (2, 128, 1000), ], ) -def test_grpo_loss_vs_trl(B, T, V, beta, loss_type, importance_sampling_level, delta): +def test_grpo_loss_vs_trl(B, T, V, beta, use_bias_correction_kl, loss_type, importance_sampling_level, delta): """Test that triton_grpo_loss matches TRL's exact implementation.""" if importance_sampling_level == "token" and loss_type == "luspo": pytest.skip("Token-level importance sampling is not supported for loss_type='luspo'") + if importance_sampling_level == "sequence" and loss_type == "vespo": + pytest.skip("Sequence-level importance sampling is not supported for loss_type='vespo'") + if delta is not None and loss_type == "vespo": + pytest.skip("delta (two-sided clipping) is not supported for loss_type='vespo'") torch.manual_seed(42) logits = torch.randn(B, T + 1, V, device=device, dtype=torch.float32) @@ -595,6 +622,7 @@ def 
test_grpo_loss_vs_trl(B, T, V, beta, loss_type, importance_sampling_level, d loss_type, importance_sampling_level, delta=delta, + use_bias_correction_kl=use_bias_correction_kl, ) # Triton implementation @@ -615,6 +643,7 @@ def test_grpo_loss_vs_trl(B, T, V, beta, loss_type, importance_sampling_level, d max_completion_length=T, reduce=True, delta=delta, + use_bias_correction_kl=use_bias_correction_kl, ) # Verify forward match @@ -626,6 +655,61 @@ def test_grpo_loss_vs_trl(B, T, V, beta, loss_type, importance_sampling_level, d assert not torch.isnan(logits_triton.grad).any() +@pytest.mark.parametrize("loss_type", ["dapo", "cispo"]) +def test_triton_num_items_in_batch_normalizer(loss_type): + """``num_items_in_batch`` overrides the dapo/cispo normalizer in the triton path. + + Mirrors the chunked-loss test: in single-process world, passing + ``num_items_in_batch=mask.sum()`` matches the default normalizer; doubling + the value halves both the loss and the input gradient. + """ + torch.manual_seed(0) + B, T, V = 2, 64, 256 + + completion_ids = torch.randint(0, V, (B, T), device=device) + completion_mask = torch.ones(B, T, device=device, dtype=torch.float32) + completion_mask[:, -8:] = 0 + advantages = torch.randn(B, device=device, dtype=torch.float32) + ref_logp = torch.randn(B, T, device=device, dtype=torch.float32) + old_logp = torch.randn(B, T, device=device, dtype=torch.float32) + + eps_low, eps_high = (0.2, 0.4) if loss_type == "dapo" else (0.0, 5.0) + mask_sum = completion_mask.sum().item() + base_logits = torch.randn(B, T + 1, V, device=device, dtype=torch.float32) + + def _run(num_items_in_batch): + logits = base_logits.clone().requires_grad_(True) + loss, _ = triton_grpo_loss( + logits, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + temperature=0.9, + beta=0.04, + eps_low=eps_low, + eps_high=eps_high, + inplace=False, + loss_type=loss_type, + max_completion_length=T, + reduce=True, + num_items_in_batch=num_items_in_batch, + ) + 
loss.backward() + return loss.detach(), logits.grad.detach().clone() + + loss_default, grad_default = _run(num_items_in_batch=None) + loss_match, grad_match = _run(num_items_in_batch=mask_sum) + loss_double, grad_double = _run(num_items_in_batch=mask_sum * 2) + + assert_verbose_allclose(loss_default, loss_match, atol=1e-5, rtol=1e-5) + assert_verbose_allclose(grad_default, grad_match, atol=1e-5, rtol=1e-5) + + assert_verbose_allclose(loss_double * 2, loss_default, atol=1e-5, rtol=1e-5) + assert_verbose_allclose(grad_double * 2, grad_default, atol=1e-5, rtol=1e-5) + + def trl_reference_grpo_loss_with_vllm_is( logits, old_logp, @@ -681,7 +765,9 @@ def trl_reference_grpo_loss_with_vllm_is( if beta != 0.0: kl = torch.exp(ref_logp - per_token_logps) - (ref_logp - per_token_logps) - 1.0 if use_bias_correction_kl: - kl = kl * torch.exp(per_token_logps - old_logp) + # TRL: kl *= coef_1 with shape matching importance_sampling_level + # (token: (B, T); sequence: (B, 1)). + kl = kl * torch.exp(log_importance_weights) per_token_loss = per_token_loss + beta * kl # Loss reduction