35 changes: 31 additions & 4 deletions src/liger_kernel/chunked_loss/fused_linear_ppo.py
@@ -200,6 +200,7 @@ def forward(
vespo_lambda_pos=3.0,
vespo_k_neg=3.0,
vespo_lambda_neg=2.0,
num_items_in_batch=None,
):
"""Chunked forward pass for PPO loss computation.

@@ -254,6 +255,7 @@ def forward(
vespo_lambda_pos=vespo_lambda_pos,
vespo_k_neg=vespo_k_neg,
vespo_lambda_neg=vespo_lambda_neg,
num_items_in_batch=num_items_in_batch,
)
compiled_compute_loss = torch.compile(compute_loss) if compiled else compute_loss

@@ -433,16 +435,38 @@ def accumulate_chunk(
return loss_acc, tuple(final_metrics)

@staticmethod
def _compute_dapo_normalizer(attention_mask):
"""Global active tokens averaged per process."""
normalizer = attention_mask.to(torch.float32).sum()
def _compute_dapo_normalizer(attention_mask, num_items_in_batch=None):
"""Per-process normalizer for DAPO/CISPO/VESPO.

When ``num_items_in_batch`` is provided it is used directly, matching
TRL's ``num_items_in_batch / num_processes`` — the total active tokens
across the entire generation batch (all grad-accum micro-batches × all
processes). Falling back to the current micro-batch's mask biases the
per-token weight by micro-batch size when grad-accum steps have
unequal completion lengths.
"""
world_size = 1
if torch.distributed.is_available() and torch.distributed.is_initialized():
import torch.distributed as dist

world_size = dist.get_world_size()

if num_items_in_batch is not None:
if isinstance(num_items_in_batch, torch.Tensor):
normalizer = num_items_in_batch.to(device=attention_mask.device, dtype=torch.float32)
else:
normalizer = torch.as_tensor(
float(num_items_in_batch), device=attention_mask.device, dtype=torch.float32
)
normalizer = normalizer / world_size
return torch.clamp(normalizer, min=1.0)

normalizer = attention_mask.to(torch.float32).sum()
if torch.distributed.is_available() and torch.distributed.is_initialized():
import torch.distributed as dist

normalizer = normalizer.clone()
dist.all_reduce(normalizer, op=dist.ReduceOp.SUM)
world_size = dist.get_world_size()

normalizer = normalizer / world_size
return torch.clamp(normalizer, min=1.0)
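
For reference, a minimal single-process sketch of the two normalizer paths described in the docstring above; the mask and the token count are invented for illustration and are not part of this diff.

import torch

# Toy micro-batch: two completions with 3 and 5 active tokens (values invented).
mask = torch.zeros(2, 6)
mask[0, :3] = 1.0
mask[1, :5] = 1.0

# Fallback path: only the tokens visible in this micro-batch are counted.
fallback = mask.to(torch.float32).sum().clamp(min=1.0)        # tensor(8.)

# TRL-compatible path: the caller supplies the active-token count of the whole
# generation batch; _compute_dapo_normalizer divides it by world_size (1 here).
num_items_in_batch = 64                                        # assumed global count, single process
global_norm = torch.as_tensor(float(num_items_in_batch)).clamp(min=1.0)  # tensor(64.)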
@@ -471,6 +495,7 @@ def _compute_loss_from_logps(
vespo_lambda_pos=3.0,
vespo_k_neg=3.0,
vespo_lambda_neg=2.0,
num_items_in_batch=None,
):
"""Compute loss from pre-computed logprobs. This is the torch.compile-friendly part."""
chunk_loss, chunk_metrics = ppo_loss_fn(
@@ -495,6 +520,7 @@ def _compute_loss_from_logps(
vespo_lambda_pos=vespo_lambda_pos,
vespo_k_neg=vespo_k_neg,
vespo_lambda_neg=vespo_lambda_neg,
num_items_in_batch=num_items_in_batch,
)
return chunk_loss, chunk_metrics

@@ -556,4 +582,5 @@ def backward(ctx, grad_output, *grad_metrics):
None, # grad_vllm_is_ratio
None, # grad_delta
None, # grad_use_bias_correction_kl
None, # grad_num_items_in_batch
)
18 changes: 14 additions & 4 deletions src/liger_kernel/chunked_loss/grpo_loss.py
@@ -120,6 +120,7 @@ def ppo_loss_fn(
vespo_lambda_pos=3.0, # VESPO gamma rate lambda for non-negative advantages
vespo_k_neg=3.0, # VESPO gamma shape k for negative advantages
vespo_lambda_neg=2.0, # VESPO gamma rate lambda for negative advantages
num_items_in_batch=None, # Total active tokens across the entire generation batch (TRL-compat)
**kwargs,
):
"""GRPO Loss Function matching GRPOTrainer implementation."""
@@ -209,9 +210,11 @@ def ppo_loss_fn(
# Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
if use_bias_correction_kl:
# Importance-sampling-corrected KL (DeepSeek-V3.2): kl *= token-level coef_1
token_coef_1 = torch.exp(per_token_logps - old_per_token_logps)
kl_div = kl_div * token_coef_1
# Importance-sampling-corrected KL (DeepSeek-V3.2): kl *= coef_1.
# Use exp(log_importance_weights) so the ratio's shape matches
# importance_sampling_level (token: (B, T); sequence: (B, 1)),
# mirroring TRL's ``per_token_kl * coef_1`` (un-clamped, before delta).
kl_div = kl_div * torch.exp(log_importance_weights)
# Combine losses
per_token_loss = per_token_loss + beta * kl_div
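
For context on the shape remark in the new comment, a minimal sketch assuming a token-level ratio of shape (B, T) and a sequence-level ratio of shape (B, 1); the sequence-level formula below is an illustrative masked mean, not code taken from this file.

import torch

B, T = 2, 4
per_token_logps = torch.randn(B, T)
old_per_token_logps = torch.randn(B, T)
mask = torch.ones(B, T)
kl_div = torch.rand(B, T)          # stand-in for the k3 KL estimate

# importance_sampling_level == "token": one ratio per token, shape (B, T).
log_iw_token = per_token_logps - old_per_token_logps

# importance_sampling_level == "sequence": one ratio per sequence, shape (B, 1)
# (illustrative masked-mean form).
log_iw_seq = ((per_token_logps - old_per_token_logps) * mask).sum(-1, keepdim=True) / mask.sum(-1, keepdim=True)

# Either shape broadcasts against the (B, T) KL term, which is why the new code
# multiplies by exp(log_importance_weights) instead of recomputing a token-level ratio.
assert (kl_div * torch.exp(log_iw_token)).shape == (B, T)
assert (kl_div * torch.exp(log_iw_seq)).shape == (B, T)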

@@ -233,7 +236,9 @@
raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
elif loss_type in ("dapo", "cispo", "vespo"):
loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask)
loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(
full_attention_mask, num_items_in_batch=num_items_in_batch
)
loss = (per_token_loss * attention_mask).sum() / loss_normalizer
elif loss_type == "luspo":
# Match TRL exactly: loss = (per_token_loss * mask.sum(1, keepdim=True)).mean()
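
To see why the global normalizer matters for the dapo/cispo/vespo branch above, here is a small single-process sketch with invented numbers (two grad-accumulation micro-batches of 10 and 90 active tokens); it ignores the per-process division and any extra averaging the trainer applies.

import torch

torch.manual_seed(0)
per_token_loss = torch.randn(100)                              # pretend per-token losses for one generation batch
micro_a, micro_b = per_token_loss[:10], per_token_loss[10:]    # micro-batches of 10 and 90 tokens

# Per-micro-batch normalization (old fallback): tokens in the short micro-batch get
# a 1/10 weight while the rest get 1/90 once the two losses are accumulated.
biased = micro_a.sum() / 10 + micro_b.sum() / 90

# Global normalizer (num_items_in_batch = 100): every token is weighted 1/100, and the
# accumulated loss equals the full-batch token mean.
unbiased = micro_a.sum() / 100 + micro_b.sum() / 100
assert torch.allclose(unbiased, per_token_loss.mean())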
@@ -296,6 +301,7 @@ def forward(
vespo_lambda_pos=3.0,
vespo_k_neg=3.0,
vespo_lambda_neg=2.0,
num_items_in_batch=None,
):
"""
Fused linear layer with GRPO loss.
@@ -366,6 +372,7 @@ def forward(
vespo_lambda_pos=vespo_lambda_pos,
vespo_k_neg=vespo_k_neg,
vespo_lambda_neg=vespo_lambda_neg,
num_items_in_batch=num_items_in_batch,
)

@staticmethod
@@ -405,6 +412,7 @@ def backward(ctx, grad_output, *grad_metrics):
None, # grad_vespo_lambda_pos
None, # grad_vespo_k_neg
None, # grad_vespo_lambda_neg
None, # grad_num_items_in_batch
)


@@ -492,6 +500,7 @@ def forward(
ref_weight=None,
ref_bias=None,
vllm_is_ratio=None,
num_items_in_batch=None,
):
return LigerFusedLinearGRPOFunction.apply(
_input,
@@ -524,4 +533,5 @@ def forward(
self.vespo_lambda_pos,
self.vespo_k_neg,
self.vespo_lambda_neg,
num_items_in_batch,
)
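
A hedged caller-side sketch of where the value passed as num_items_in_batch could come from; variable names are invented, and the cross-process reduction mirrors the docstring's description of TRL's behaviour rather than code in this repository.

import torch
import torch.distributed as dist

# Active completion tokens over the whole generation batch on this process
# (all grad-accum micro-batches); a toy mask stands in for the real one.
completion_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])
num_items_in_batch = completion_mask.sum()

# Aggregate across processes so the value is the global token count;
# _compute_dapo_normalizer later divides it by world_size.
if dist.is_available() and dist.is_initialized():
    num_items_in_batch = num_items_in_batch.clone()
    dist.all_reduce(num_items_in_batch, op=dist.ReduceOp.SUM)

# The same value would then be passed to every micro-batch's loss call, e.g.
# LigerFusedLinearGRPOLoss.forward(..., num_items_in_batch=num_items_in_batch).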