diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 5621374ad3..c8cb81202f 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -68,6 +68,7 @@ {'test_file': 'test_metric_report.py', 'num_gpus': 0}, {'test_file': 'test_metric_report_dist.py', 'num_gpus': 0}, {'test_file': 'test_loss_cp_invariance.py', 'num_gpus': 0}, + {'test_file': 'test_logprob_entropy_fused.py', 'num_gpus': 0}, {'test_file': 'test_value_temperature.py', 'num_gpus': 0}, {'test_file': 'test_cispo_loss.py', 'num_gpus': 0}, {'test_file': 'test_rm_f1.py', 'num_gpus': 0}, diff --git a/examples/retool/retool_qwen3_4b_rl.sh b/examples/retool/retool_qwen3_4b_rl.sh index 32a837f394..e2f229d45d 100644 --- a/examples/retool/retool_qwen3_4b_rl.sh +++ b/examples/retool/retool_qwen3_4b_rl.sh @@ -75,6 +75,11 @@ PERF_ARGS=( # --micro-batch-size 1 --use-dynamic-batch-size --max-tokens-per-gpu 9216 + # Bound the fused cross-entropy [tokens, vocab] transient that can OOM on long retool traces. + --log-probs-chunk-size 1024 + # Gather only response tokens before the cross-entropy: retool traces have long prompts, so this + # shrinks the [T, vocab] tensor to [T_response, vocab] and stacks with the chunking above. 
+ --log-probs-response-only ) GRPO_ARGS=( @@ -153,4 +158,4 @@ ray job submit --address="http://127.0.0.1:8265" \ ${EVAL_ARGS[@]} \ ${SGLANG_ARGS[@]} \ ${MISC_ARGS[@]} \ - ${CUSTOM_ARGS[@]} \ No newline at end of file + ${CUSTOM_ARGS[@]} diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index a456939e74..ac9d899607 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -292,6 +292,81 @@ def _build_shifted_tokens( return full_tokens +def _response_keep_index( + total_lengths: list[int], + response_lengths: list[int], + qkv_format: str, + max_seq_lens: list[int] | None, + allgather_cp: bool, + device: torch.device, + T: int, +) -> torch.Tensor: + """Positions that ``_extract_per_sample`` reads, as a flat 1-D LongTensor. + + The cross-entropy in ``get_log_probs_and_entropy`` is only consumed on these + response-window positions; everything else in ``[T, V]`` is computed and then + discarded. Gathering ``logits`` down to these rows shrinks the dominant tensor + before CE; scattering the results back to full ``T`` leaves + ``_extract_per_sample`` untouched. + + The ranges below mirror ``_extract_per_sample`` branch-for-branch and in the + same order, so the two stay in lock-step (single source of truth for which + positions survive). 
+ """ + cp_size = mpu.get_context_parallel_world_size() + ranges: list[tuple[int, int]] = [] + + if cp_size > 1 and not allgather_cp: + # zigzag CP: two windows per sample + pos = 0 + for i, (total_length, response_length) in enumerate(zip(total_lengths, response_lengths, strict=False)): + max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None + chunk_size_cp, chunks_offset, logits_offset, _tokens_offset = get_logits_and_tokens_offset_with_cp( + total_length, response_length, qkv_format, max_seq_len + ) + lo0 = logits_offset[0][0] - chunks_offset[0][0] + hi0 = logits_offset[0][1] - chunks_offset[0][0] + lo1 = logits_offset[1][0] - chunks_offset[1][0] + hi1 = logits_offset[1][1] - chunks_offset[1][0] + ranges.append((pos + lo0, pos + hi0)) + ranges.append((pos + chunk_size_cp + lo1, pos + chunk_size_cp + hi1)) + pos += 2 * chunk_size_cp + + elif allgather_cp: + cp_rank = mpu.get_context_parallel_rank() + chunk_start = cp_rank * T + chunk_end = chunk_start + T + seq_start = 0 + for total_length, response_length in zip(total_lengths, response_lengths, strict=False): + prompt_length = total_length - response_length + logit_global_start = seq_start + prompt_length - 1 + logit_global_end = seq_start + total_length - 1 + s = max(logit_global_start, chunk_start) + e = min(logit_global_end, chunk_end) + if e > s: + ranges.append((s - chunk_start, e - chunk_start)) + seq_start += total_length + + else: + # cp1 + if qkv_format == "thd": + offset = 0 + for total_length, response_length in zip(total_lengths, response_lengths, strict=False): + end = offset + total_length + start = end - response_length + ranges.append((start - 1, end - 1)) + offset += total_length + else: # bshd + for i, (total_length, response_length) in enumerate(zip(total_lengths, response_lengths, strict=False)): + end = max_seq_lens[i] * i + total_length + start = end - response_length + ranges.append((start - 1, end - 1)) + + if not ranges: + return torch.zeros((0,), dtype=torch.long, 
device=device) + return torch.cat([torch.arange(s, e, device=device, dtype=torch.long) for s, e in ranges]) + + def _extract_per_sample( log_prob_full: torch.Tensor, entropy_full: torch.Tensor | None, @@ -394,6 +469,7 @@ def get_log_probs_and_entropy( with_entropy: bool = False, non_loss_data: bool = True, max_seq_lens: list[int] | None = None, + full_loss_mask: torch.Tensor | None = None, ) -> dict[str, list[torch.Tensor]]: """Compute per-token log-probabilities (and optionally entropy) on responses. @@ -401,6 +477,14 @@ def get_log_probs_and_entropy( per-sample slicing) so backward traverses ``[T, V]`` only once, then extracts per-sample response portions. + With ``--log-probs-response-only`` the CE runs only on the response-window + rows (gathered out of ``[T, V]`` before CE and scattered back after), so the + dominant tensor shrinks from ``T`` to the number of response tokens ``T'``. + With ``--log-probs-loss-mask-only`` (and a ``full_loss_mask`` aligned to the + logits layout) it shrinks further to the ``loss_mask == 1`` rows; positions + dropped this way return a log-prob/entropy of 0 and so are only valid where + the downstream loss masks them out (policy-loss path). + When ``entropy_coef == 0``, entropy is computed under ``torch.no_grad()`` to avoid retaining the computation graph and to skip cloning. 
""" @@ -432,15 +516,44 @@ def get_log_probs_and_entropy( T, device, unconcat_tokens, total_lengths, response_lengths, qkv_format, max_seq_lens, args.allgather_cp ) - # --- compute on full [T,V] logits at once via calculate_log_probs_and_entropy --- - log_prob_full, entropy_full = calculate_log_probs_and_entropy( - logits, - full_tokens, - tp_group, - with_entropy=with_entropy, - chunk_size=chunk_size, - ) - log_prob_full = log_prob_full.squeeze(-1) # [T, 1] -> [T] + # --- compute CE, optionally on a gathered subset of rows --- + if getattr(args, "log_probs_response_only", False): + # Only the response windows survive _extract_per_sample; gather them so CE + # runs on [T', V] instead of [T, V] (autograd's index_select backward + # scatters grads back to the dropped rows as zeros, which is exactly right). + keep_index = _response_keep_index( + total_lengths, response_lengths, qkv_format, max_seq_lens, args.allgather_cp, device, T + ) + if getattr(args, "log_probs_loss_mask_only", False) and full_loss_mask is not None: + mask_kept = full_loss_mask.reshape(-1).index_select(0, keep_index).to(torch.bool) + keep_index = keep_index[mask_kept] + + logits_kept = logits.index_select(0, keep_index) + tokens_kept = full_tokens.index_select(0, keep_index) + lp_kept, ent_kept = calculate_log_probs_and_entropy( + logits_kept, + tokens_kept, + tp_group, + with_entropy=with_entropy, + chunk_size=chunk_size, + ) + lp_kept = lp_kept.squeeze(-1) # [T', 1] -> [T'] + + # scatter back to full length so _extract_per_sample is unchanged + log_prob_full = lp_kept.new_zeros(T).index_copy(0, keep_index, lp_kept) + entropy_full = None + if with_entropy: + entropy_full = ent_kept.new_zeros(T).index_copy(0, keep_index, ent_kept) + else: + # --- compute on full [T,V] logits at once via calculate_log_probs_and_entropy --- + log_prob_full, entropy_full = calculate_log_probs_and_entropy( + logits, + full_tokens, + tp_group, + with_entropy=with_entropy, + chunk_size=chunk_size, + ) + log_prob_full = 
log_prob_full.squeeze(-1) # [T, 1] -> [T] # --- extract per-sample response portions --- log_probs_list, entropy_list = _extract_per_sample( @@ -481,6 +594,7 @@ def get_values( with_entropy: bool = False, non_loss_data: bool = True, max_seq_lens: list[int] | None = None, + full_loss_mask: torch.Tensor | None = None, # unused; accepted so the shared forward_only partial fits ) -> dict[str, list[torch.Tensor]]: """Extract per-token value predictions over response tokens. diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py index db6020a94d..37938ab37a 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -39,6 +39,23 @@ logger = logging.getLogger(__name__) +def _mem_probe_enabled() -> bool: + return os.environ.get("SLIME_MEM_PROBE", "0") == "1" and torch.cuda.is_available() + + +def _log_train_step_mem_probe(rollout_id: int, step_id: int, start_allocated: int) -> None: + max_allocated = torch.cuda.max_memory_allocated() + logger.info( + "SLIME_MEM_PROBE train_one_step rollout_id=%s step_id=%s " + "allocated_start=%s allocated_peak=%s allocated_peak_delta=%s", + rollout_id, + step_id, + start_allocated, + max_allocated, + max_allocated - start_allocated, + ) + + def _disable_tqdm_for_non_main_rank() -> bool: return not ( mpu.get_data_parallel_rank(with_context_parallel=True) == 0 @@ -352,6 +369,7 @@ def forward_step( response_lengths=response_lengths, with_entropy=args.use_rollout_entropy, max_seq_lens=batch.get("max_seq_lens", None), + full_loss_mask=batch["full_loss_masks"], ) # Turn on evaluation mode which disables dropout. @@ -455,6 +473,11 @@ def train_one_step( and gradient norm for logging. """ args = get_args() + mem_probe = _mem_probe_enabled() + mem_probe_start_allocated = 0 + if mem_probe: + torch.cuda.reset_peak_memory_stats() + mem_probe_start_allocated = torch.cuda.memory_allocated() # Set grad to zero. 
for model_chunk in model: @@ -601,7 +624,11 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p cp_size=mpu.get_context_parallel_world_size(), dp_with_cp_group=mpu.get_data_parallel_group(with_context_parallel=True), ) + if mem_probe: + _log_train_step_mem_probe(rollout_id, step_id, mem_probe_start_allocated) return loss_reduced, grad_norm + if mem_probe: + _log_train_step_mem_probe(rollout_id, step_id, mem_probe_start_allocated) return {}, grad_norm diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 6efe85eae7..0175efab7b 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -239,6 +239,18 @@ def add_train_arguments(parser): parser.add_argument( "--log-probs-chunk-size", type=int, default=-1, help="Chunk size to compute log probs to save memory" ) + parser.add_argument( + "--log-probs-response-only", + action="store_true", + help="Gather only the response-window rows before the log-prob/entropy cross-entropy, " + "shrinking the [T, V] logits tensor to [T', V] (T' = response tokens). Results are identical.", + ) + parser.add_argument( + "--log-probs-loss-mask-only", + action="store_true", + help="Further restrict the log-prob/entropy cross-entropy to loss_mask==1 rows. 
Requires " + "--log-probs-response-only; only valid on the policy-loss path (masked positions return 0).", + ) parser.add_argument( "--only-train-params-name-list", type=str, @@ -1851,6 +1863,9 @@ def slime_validate_args(args): assert args.use_dynamic_batch_size, "--balance-by-flops requires --use-dynamic-batch-size" args.balance_data = True + if getattr(args, "log_probs_loss_mask_only", False): + assert args.log_probs_response_only, "--log-probs-loss-mask-only requires --log-probs-response-only" + if args.eps_clip_high is None: args.eps_clip_high = args.eps_clip diff --git a/slime/utils/ppo_utils.py b/slime/utils/ppo_utils.py index 327dec2de6..f1fce5d73a 100644 --- a/slime/utils/ppo_utils.py +++ b/slime/utils/ppo_utils.py @@ -181,44 +181,95 @@ def compute_log_probs(logits: torch.Tensor, tokens: torch.Tensor, process_group: return -fused_vocab_parallel_cross_entropy(logits, tokens, process_group) -# from https://github.com/volcengine/verl/blob/0bdf7f469854815177e73dcfe9e420836c952e6e/verl/utils/megatron/tensor_parallel.py#L99 -class _VocabParallelEntropy(torch.autograd.Function): +class _VocabParallelLogProbsAndEntropy(torch.autograd.Function): @staticmethod - def forward(ctx, vocab_parallel_logits: torch.Tensor, process_group: dist.ProcessGroup) -> torch.Tensor: + def forward(ctx, vocab_parallel_logits: torch.Tensor, target: torch.Tensor, process_group): + from megatron.core.tensor_parallel.utils import VocabUtility - @torch.compile(dynamic=True) - def mul_reduce(a, b): - return (a * b).sum(dim=-1, keepdim=True) + # Pass None (not a zero-filled tensor) for an output whose grad does not flow, + # so the single-output backward paths skip a wasted full-vocab allocation. 
+ ctx.set_materialize_grads(False) - logits_max = vocab_parallel_logits.max(dim=-1, keepdim=True).values + logits_max = vocab_parallel_logits.max(dim=-1).values dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group) - normalized_vocab_parallel_logits = vocab_parallel_logits - logits_max - normalized_exp_logits = normalized_vocab_parallel_logits.exp_() - normalized_sum_exp_logits = normalized_exp_logits.sum(dim=-1, keepdim=True) - dist.all_reduce(normalized_sum_exp_logits, group=process_group) - softmax_logits = normalized_exp_logits.div_(normalized_sum_exp_logits) - sum_softmax_times_logits = mul_reduce(softmax_logits, vocab_parallel_logits) - dist.all_reduce(sum_softmax_times_logits, group=process_group) - entropy = logits_max + normalized_sum_exp_logits.log() - sum_softmax_times_logits - ctx.save_for_backward(vocab_parallel_logits, softmax_logits, sum_softmax_times_logits) - return entropy.squeeze(dim=-1) + + get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size + partition_vocab_size = vocab_parallel_logits.size(-1) + vocab_start_index, vocab_end_index = get_vocab_range( + partition_vocab_size, process_group.rank(), process_group.size() + ) + + vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target.clone() - vocab_start_index + masked_target[target_mask] = 0 + + logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size(0), device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + predicted_logits_1d = predicted_logits_1d.clone().contiguous() + predicted_logits = predicted_logits_1d.view_as(target) + predicted_logits[target_mask] = 0.0 + + torch.exp(vocab_parallel_logits, out=vocab_parallel_logits) + sum_exp_logits = vocab_parallel_logits.sum(dim=-1) + dist.all_reduce(predicted_logits, 
op=dist.ReduceOp.SUM, group=process_group) + dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group) + + log_sum_exp = sum_exp_logits.log() + log_prob = predicted_logits - log_sum_exp + softmax = vocab_parallel_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + + local_entropy = torch.zeros_like(log_sum_exp) + for softmax_chunk in softmax.chunk(max(1, softmax.size(-1) // 8192), dim=-1): + local_entropy -= torch.xlogy(softmax_chunk, softmax_chunk).sum(dim=-1) + dist.all_reduce(local_entropy, op=dist.ReduceOp.SUM, group=process_group) + + # `softmax` is reused in place as the gradient buffer in backward (no double-backward). + ctx.save_for_backward(softmax, target_mask, masked_target_1d, local_entropy) + return log_prob, local_entropy.squeeze(dim=-1) @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: - vocab_parallel_logits, softmax_logits, sum_softmax_times_logits = ctx.saved_tensors - # reuse softmax_logits as grad - vocab_parallel_logits.sub_(sum_softmax_times_logits) - softmax_logits.mul_(vocab_parallel_logits) - softmax_logits.mul_(grad_output.unsqueeze(dim=-1)) - # recover vocab_parallel_logits - vocab_parallel_logits.add_(sum_softmax_times_logits) - softmax_logits.mul_(-1) - return softmax_logits, None + def backward(ctx, grad_log_prob: torch.Tensor, grad_entropy: torch.Tensor): + # Local grad wrt logit z_j (no cross-rank reduce needed): + # g_j = softmax_j * [ -grad_log_prob - grad_entropy * (entropy + log_softmax_j) ] + # + grad_log_prob * 1{j == target} + # The saved softmax buffer is reused in place as the gradient, so the only extra + # full-vocab allocation is `log_softmax`, and only when entropy gradient flows. + softmax, target_mask, masked_target_1d, entropy = ctx.saved_tensors + partition_vocab_size = softmax.size(-1) + + if grad_entropy is not None: + # log_softmax = log(softmax); softmax underflow (==0) -> log(0)=-inf, mapped to 0 + # so the softmax==0 positions contribute nothing (0 * finite). 
+ log_softmax = softmax.log() + log_softmax.nan_to_num_(neginf=0.0) + log_softmax.add_(entropy.unsqueeze(dim=-1)) # entropy + log_softmax + log_softmax.mul_(grad_entropy.unsqueeze(dim=-1).unsqueeze(dim=-1)) + if grad_log_prob is not None: + log_softmax.add_(grad_log_prob.unsqueeze(dim=-1)) + softmax.mul_(log_softmax).neg_() # softmax * [-grad_log_prob - grad_entropy*(H+log p)] + del log_softmax + elif grad_log_prob is not None: + softmax.mul_(grad_log_prob.unsqueeze(dim=-1).neg()) # -softmax * grad_log_prob + else: + return None, None, None + + if grad_log_prob is not None: + grad_2d = softmax.view(-1, partition_vocab_size) + arange_1d = torch.arange(start=0, end=grad_2d.size(0), device=grad_2d.device) + softmax_update = 1.0 - target_mask.view(-1).to(softmax.dtype) + grad_2d[arange_1d, masked_target_1d] += grad_log_prob.reshape(-1) * softmax_update + return softmax.to(torch.bfloat16), None, None -def compute_entropy_from_logits(logits: torch.Tensor, process_group) -> torch.Tensor: - return _VocabParallelEntropy.apply(logits, process_group) + +def compute_log_probs_and_entropy(logits: torch.Tensor, tokens: torch.Tensor, process_group): + logits = logits.unsqueeze(1) + tokens = tokens.unsqueeze(1) + return _VocabParallelLogProbsAndEntropy.apply(logits, tokens, process_group) def get_grpo_returns( @@ -669,6 +720,18 @@ def chunked_gae( return advantages, returns +def _clone_if_grad_tracked(logits: torch.Tensor) -> torch.Tensor: + # Megatron-LM's fused CE mutates its float32 input in place (subtract-max, + # exp(out=...), div_). That is safe to hand over directly + # only when autograd will not observe the mutation. When grad is tracked, an + # in-place write on a (view of a) grad-requiring tensor corrupts the graph and + # raises (issue #1951). Clone exactly when grad is tracked; otherwise pass the + # tensor through so the no-grad ref/old-logprob path keeps its peak-memory win. 
+ if torch.is_grad_enabled() and logits.requires_grad: + return logits.clone() + return logits + + def calculate_log_probs_and_entropy(logits, tokens, tp_group, with_entropy: bool = False, chunk_size: int = -1): logits = logits.contiguous() entropy = None @@ -680,22 +743,29 @@ def calculate_log_probs_and_entropy(logits, tokens, tp_group, with_entropy: bool if with_entropy: entropys = [] - for logits_chunk in logits_chunks: - entropy_input = logits_chunk.clone() - entropys.append(compute_entropy_from_logits(entropy_input, tp_group)) + log_probs = [] + for tokens_chunk, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True): + # The fused helper computes log-probs and entropy from one tensor, + # replacing the two separate destructive passes (and two clones). + log_prob, entropy_chunk = compute_log_probs_and_entropy( + _clone_if_grad_tracked(logits_chunk), tokens_chunk, tp_group + ) + log_probs.append(log_prob) + entropys.append(entropy_chunk) entropy = torch.cat(entropys, dim=0) - - log_probs = [] - for tokens_chunk, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True): - log_prob = compute_log_probs(logits_chunk.clone(), tokens_chunk, tp_group) - log_probs.append(log_prob) + else: + log_probs = [] + for tokens_chunk, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True): + log_prob = compute_log_probs(_clone_if_grad_tracked(logits_chunk), tokens_chunk, tp_group) + log_probs.append(log_prob) log_prob = torch.cat(log_probs, dim=0) else: if with_entropy: - entropy_input = logits.clone() - entropy = compute_entropy_from_logits(entropy_input, tp_group) - - log_prob = compute_log_probs(logits.clone(), tokens, tp_group) + # The fused helper computes log-probs and entropy from one tensor, + # replacing the two separate destructive passes (and two clones). 
+ log_prob, entropy = compute_log_probs_and_entropy(_clone_if_grad_tracked(logits), tokens, tp_group) + else: + log_prob = compute_log_probs(_clone_if_grad_tracked(logits), tokens, tp_group) else: log_prob = logits.new_zeros((0,)) if with_entropy: diff --git a/tests/test_logprob_entropy_fused.py b/tests/test_logprob_entropy_fused.py new file mode 100644 index 0000000000..97728c4d99 --- /dev/null +++ b/tests/test_logprob_entropy_fused.py @@ -0,0 +1,384 @@ +from __future__ import annotations + +import sys +import types +from contextlib import contextmanager +from pathlib import Path + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from _cp_dist_helpers import free_port + + +ROOT = Path(__file__).resolve().parents[1] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +NUM_GPUS = 0 + + +class _FakeSingleRankGroup: + def rank(self) -> int: + return 0 + + def size(self) -> int: + return 1 + + +def _install_megatron_stubs() -> None: + megatron = sys.modules.setdefault("megatron", types.ModuleType("megatron")) + core = sys.modules.setdefault("megatron.core", types.ModuleType("megatron.core")) + fusions = sys.modules.setdefault("megatron.core.fusions", types.ModuleType("megatron.core.fusions")) + fused = types.ModuleType("megatron.core.fusions.fused_cross_entropy") + tensor_parallel = sys.modules.setdefault( + "megatron.core.tensor_parallel", types.ModuleType("megatron.core.tensor_parallel") + ) + utils = types.ModuleType("megatron.core.tensor_parallel.utils") + + class VocabUtility: + @staticmethod + def vocab_range_from_per_partition_vocab_size(partition_vocab_size: int, rank: int, world_size: int): + assert world_size > 0 + assert 0 <= rank < world_size + start = rank * partition_vocab_size + return start, start + partition_vocab_size + + class _MockVocabParallelCrossEntropy(torch.autograd.Function): + @staticmethod + def forward(ctx, logits: torch.Tensor, target: torch.Tensor, process_group): + del 
process_group + logits_max = logits.max(dim=-1, keepdim=True).values + logits.sub_(logits_max) + predicted_logits = logits.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1) + torch.exp(logits, out=logits) + sum_exp_logits = logits.sum(dim=-1) + logits.div_(sum_exp_logits.unsqueeze(-1)) + ctx.save_for_backward(logits, target) + return sum_exp_logits.log() - predicted_logits + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + softmax, target = ctx.saved_tensors + grad_input = softmax.clone() + grad_input.scatter_add_( + dim=-1, + index=target.unsqueeze(-1), + src=-torch.ones_like(target, dtype=grad_input.dtype).unsqueeze(-1), + ) + grad_input.mul_(grad_output.unsqueeze(-1)) + return grad_input.to(torch.bfloat16), None, None + + def fused_vocab_parallel_cross_entropy(logits: torch.Tensor, target: torch.Tensor, process_group): + return _MockVocabParallelCrossEntropy.apply(logits, target, process_group) + + # mpu stub for the cp1 (no context-parallel) path used by loss.py helpers. 
+ mpu = types.ModuleType("megatron.core.mpu") + mpu.get_context_parallel_world_size = lambda: 1 + mpu.get_context_parallel_rank = lambda: 0 + mpu.get_tensor_model_parallel_group = lambda: _FakeSingleRankGroup() + + fused.fused_vocab_parallel_cross_entropy = fused_vocab_parallel_cross_entropy + utils.VocabUtility = VocabUtility + fusions.fused_cross_entropy = fused + tensor_parallel.utils = utils + core.fusions = fusions + core.tensor_parallel = tensor_parallel + core.mpu = mpu + megatron.core = core + sys.modules["megatron.core.fusions.fused_cross_entropy"] = fused + sys.modules["megatron.core.tensor_parallel.utils"] = utils + sys.modules["megatron.core.mpu"] = mpu + + +@contextmanager +def _single_rank_all_reduce(): + original_all_reduce = dist.all_reduce + + def all_reduce(tensor, op=None, group=None, async_op=False): + del tensor, op, group + if async_op: + raise NotImplementedError("async all_reduce is not needed by this test") + return None + + dist.all_reduce = all_reduce + try: + yield + finally: + dist.all_reduce = original_all_reduce + + +def _naive_log_probs_and_entropy(logits: torch.Tensor, tokens: torch.Tensor): + log_softmax = torch.log_softmax(logits.float(), dim=-1) + log_probs = log_softmax.gather(dim=-1, index=tokens.unsqueeze(-1)) + entropy = -(log_softmax.exp() * log_softmax).sum(dim=-1) + return log_probs, entropy + + +def _make_inputs(requires_grad: bool = False): + torch.manual_seed(1234) + logits = torch.randn(9, 17, dtype=torch.float32) + logits.requires_grad_(requires_grad) + tokens = torch.randint(0, logits.size(-1), (logits.size(0),), dtype=torch.long) + return logits, tokens + + +def test_fused_forward_matches_naive_reference(): + _install_megatron_stubs() + from slime.utils.ppo_utils import calculate_log_probs_and_entropy + + logits, tokens = _make_inputs() + with _single_rank_all_reduce(): + log_probs, entropy = calculate_log_probs_and_entropy( + logits.clone(), tokens, _FakeSingleRankGroup(), with_entropy=True + ) + + ref_log_probs, 
ref_entropy = _naive_log_probs_and_entropy(logits, tokens) + torch.testing.assert_close(log_probs, ref_log_probs, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(entropy, ref_entropy, atol=1e-5, rtol=1e-5) + + +@pytest.mark.parametrize("with_entropy", [False, True]) +def test_chunked_matches_unchunked(with_entropy: bool): + _install_megatron_stubs() + from slime.utils.ppo_utils import calculate_log_probs_and_entropy + + logits, tokens = _make_inputs() + with _single_rank_all_reduce(): + full_log_probs, full_entropy = calculate_log_probs_and_entropy( + logits.clone(), tokens, _FakeSingleRankGroup(), with_entropy=with_entropy, chunk_size=-1 + ) + chunk_log_probs, chunk_entropy = calculate_log_probs_and_entropy( + logits.clone(), tokens, _FakeSingleRankGroup(), with_entropy=with_entropy, chunk_size=4 + ) + + torch.testing.assert_close(chunk_log_probs, full_log_probs, atol=1e-5, rtol=1e-5) + if with_entropy: + torch.testing.assert_close(chunk_entropy, full_entropy, atol=1e-5, rtol=1e-5) + else: + assert chunk_entropy is None + + +def test_no_entropy_chunked_backward_preserves_input_grad(): + _install_megatron_stubs() + from slime.utils.ppo_utils import calculate_log_probs_and_entropy + + logits, tokens = _make_inputs(requires_grad=True) + with _single_rank_all_reduce(): + log_probs, entropy = calculate_log_probs_and_entropy( + logits, tokens, _FakeSingleRankGroup(), with_entropy=False, chunk_size=4 + ) + assert entropy is None + log_probs.float().sum().backward() + + assert logits.grad is not None + assert torch.isfinite(logits.grad).all() + assert logits.grad.abs().sum() > 0 + + +def test_fused_backward_matches_naive_reference_with_bf16_tolerance(): + _install_megatron_stubs() + from slime.utils.ppo_utils import calculate_log_probs_and_entropy + + logits, tokens = _make_inputs(requires_grad=True) + with _single_rank_all_reduce(): + log_probs, entropy = calculate_log_probs_and_entropy(logits, tokens, _FakeSingleRankGroup(), with_entropy=True) + 
(log_probs.float().sum() + 0.13 * entropy.float().sum()).backward() + + ref_logits = logits.detach().clone().requires_grad_(True) + ref_log_probs, ref_entropy = _naive_log_probs_and_entropy(ref_logits, tokens) + (ref_log_probs.float().sum() + 0.13 * ref_entropy.float().sum()).backward() + + torch.testing.assert_close(logits.grad, ref_logits.grad, atol=4e-3, rtol=4e-3) + + +def test_fused_entropy_only_backward_matches_naive_reference(): + _install_megatron_stubs() + from slime.utils.ppo_utils import calculate_log_probs_and_entropy + + logits, tokens = _make_inputs(requires_grad=True) + with _single_rank_all_reduce(): + log_probs, entropy = calculate_log_probs_and_entropy(logits, tokens, _FakeSingleRankGroup(), with_entropy=True) + # Only entropy contributes to the loss, so autograd passes + # grad_log_prob=None into the fused backward, exercising the + # entropy-only branch. + entropy.float().sum().backward() + + ref_logits = logits.detach().clone().requires_grad_(True) + _, ref_entropy = _naive_log_probs_and_entropy(ref_logits, tokens) + ref_entropy.float().sum().backward() + + assert logits.grad is not None + assert torch.isfinite(logits.grad).all() + torch.testing.assert_close(logits.grad, ref_logits.grad, atol=4e-3, rtol=4e-3) + + +def _tp2_worker(rank: int, master_port: int, result_path: str) -> None: + import os + + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(master_port) + dist.init_process_group(backend="gloo", rank=rank, world_size=2) + try: + _install_megatron_stubs() + from slime.utils.ppo_utils import compute_log_probs_and_entropy + + torch.manual_seed(2024) + batch_size = 7 + partition_vocab_size = 5 + full_vocab_size = partition_vocab_size * 2 + full_logits = torch.randn(batch_size, full_vocab_size, dtype=torch.float32) + tokens = torch.randint(0, full_vocab_size, (batch_size,), dtype=torch.long) + start = rank * partition_vocab_size + local_logits = full_logits[:, start : start + 
partition_vocab_size].contiguous().requires_grad_(True) + + log_probs, entropy = compute_log_probs_and_entropy(local_logits, tokens, dist.group.WORLD) + (log_probs.float().sum() + 0.13 * entropy.float().sum()).backward() + + gathered_grads = [torch.empty_like(local_logits.grad) for _ in range(2)] + dist.all_gather(gathered_grads, local_logits.grad) + + if rank == 0: + ref_logits = full_logits.detach().clone().requires_grad_(True) + ref_log_probs, ref_entropy = _naive_log_probs_and_entropy(ref_logits, tokens) + (ref_log_probs.float().sum() + 0.13 * ref_entropy.float().sum()).backward() + + torch.testing.assert_close(log_probs, ref_log_probs, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(entropy, ref_entropy, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(torch.cat(gathered_grads, dim=-1), ref_logits.grad, atol=4e-3, rtol=4e-3) + with open(result_path, "w") as f: + f.write("ok") + finally: + dist.destroy_process_group() + + +def test_tp2_fused_backward_matches_full_vocab_reference(tmp_path): + result_path = str(tmp_path / "tp2_result.txt") + mp.spawn(_tp2_worker, args=(free_port(), result_path), nprocs=2, join=True) + assert (tmp_path / "tp2_result.txt").read_text() == "ok" + + +# --------------------------------------------------------------------------- +# Response-only / loss-mask gather (shrinks the [T, V] logits tensor before CE) +# --------------------------------------------------------------------------- + + +def _packed_thd_inputs(requires_grad: bool = False): + """A small cp1/thd packed batch: logits [1, T, V] + per-sample token/length lists.""" + torch.manual_seed(7) + total_lengths = [6, 5, 7] + response_lengths = [3, 2, 4] + vocab = 17 + T = sum(total_lengths) + logits = torch.randn(1, T, vocab, dtype=torch.float32, requires_grad=requires_grad) + unconcat_tokens = [torch.randint(0, vocab, (tl,), dtype=torch.long) for tl in total_lengths] + return logits, unconcat_tokens, total_lengths, response_lengths + + +def _thd_args(response_only: bool, 
chunk_size: int = -1, loss_mask_only: bool = False): + import types as _t + + return _t.SimpleNamespace( + qkv_format="thd", + rollout_temperature=1.0, + log_probs_chunk_size=chunk_size, + log_probs_response_only=response_only, + log_probs_loss_mask_only=loss_mask_only, + allgather_cp=False, + ) + + +def _run_get_log_probs( + logits, unconcat_tokens, total_lengths, response_lengths, args, with_entropy, full_loss_mask=None +): + from slime.backends.megatron_utils.loss import get_log_probs_and_entropy + + _, res = get_log_probs_and_entropy( + logits, + args=args, + unconcat_tokens=unconcat_tokens, + total_lengths=total_lengths, + response_lengths=response_lengths, + with_entropy=with_entropy, + max_seq_lens=None, + full_loss_mask=full_loss_mask, + ) + return res + + +@pytest.mark.parametrize("with_entropy", [False, True]) +@pytest.mark.parametrize("chunk_size", [-1, 4]) +def test_response_only_matches_full_path(with_entropy: bool, chunk_size: int): + _install_megatron_stubs() + + logits, toks, tl, rl = _packed_thd_inputs() + with _single_rank_all_reduce(): + full = _run_get_log_probs(logits.clone(), toks, tl, rl, _thd_args(False, chunk_size), with_entropy) + gathered = _run_get_log_probs(logits.clone(), toks, tl, rl, _thd_args(True, chunk_size), with_entropy) + + assert len(full["log_probs"]) == len(gathered["log_probs"]) == len(tl) + for a, b in zip(full["log_probs"], gathered["log_probs"], strict=True): + torch.testing.assert_close(a, b, atol=1e-6, rtol=1e-6) + if with_entropy: + for a, b in zip(full["entropy"], gathered["entropy"], strict=True): + torch.testing.assert_close(a, b, atol=1e-6, rtol=1e-6) + + +def test_response_only_backward_matches_full_path(): + _install_megatron_stubs() + + def _loss_and_grad(response_only: bool): + logits, toks, tl, rl = _packed_thd_inputs(requires_grad=True) + with _single_rank_all_reduce(): + res = _run_get_log_probs(logits, toks, tl, rl, _thd_args(response_only), with_entropy=True) + loss = sum(lp.float().sum() for lp in 
def test_loss_mask_only_zeros_masked_positions():
    """Loss-mask-only gather keeps unmasked log-probs and zeroes the masked ones.

    The batch is run twice: once on the full path, and once with the
    response-only gather plus ``loss_mask_only=True`` and a random 0/1 mask
    laid out over all ``T`` logits positions. For every response position
    (located through ``_response_keep_index``) the masked run must equal the
    full run where the mask is 1 and be exactly ``0.0`` where the mask is 0.
    """
    _install_megatron_stubs()

    logits, toks, tl, rl = _packed_thd_inputs()
    seq_len = logits.size(1)

    # The mask is indexed like the logits (one entry per packed position); each
    # entry is kept with probability ~0.5 under a fixed seed.
    torch.manual_seed(0)
    full_loss_mask = torch.rand(seq_len).gt(0.5).to(torch.float32)

    with _single_rank_all_reduce():
        full = _run_get_log_probs(logits.clone(), toks, tl, rl, _thd_args(False), with_entropy=False)
        masked = _run_get_log_probs(
            logits.clone(),
            toks,
            tl,
            rl,
            _thd_args(True, loss_mask_only=True),
            with_entropy=False,
            full_loss_mask=full_loss_mask,
        )

    from slime.backends.megatron_utils.loss import _response_keep_index

    keep = _response_keep_index(tl, rl, "thd", None, False, logits.device, seq_len)

    # `keep` is flat across samples, so a running cursor maps (sample, offset)
    # back to the packed position whose mask entry decides the expectation.
    cursor = 0
    per_sample = zip(full["log_probs"], masked["log_probs"], rl, strict=True)
    for reference_lp, masked_lp, n_response in per_sample:
        for offset in range(n_response):
            packed_position = int(keep[cursor + offset])
            is_kept = full_loss_mask[packed_position] > 0
            if not is_kept:
                assert masked_lp[offset].item() == 0.0
                continue
            torch.testing.assert_close(masked_lp[offset], reference_lp[offset], atol=1e-6, rtol=1e-6)
        cursor += n_response
contextmanager + +import torch +import torch.distributed as dist + + +class _FakeSingleRankGroup: + def rank(self) -> int: + return 0 + + def size(self) -> int: + return 1 + + +def _install_mock_fused_cross_entropy() -> None: + """Install a single-rank Megatron fused CE stand-in for local repro runs.""" + megatron = sys.modules.setdefault("megatron", types.ModuleType("megatron")) + core = sys.modules.setdefault("megatron.core", types.ModuleType("megatron.core")) + fusions = sys.modules.setdefault("megatron.core.fusions", types.ModuleType("megatron.core.fusions")) + fused = types.ModuleType("megatron.core.fusions.fused_cross_entropy") + tensor_parallel = sys.modules.setdefault( + "megatron.core.tensor_parallel", types.ModuleType("megatron.core.tensor_parallel") + ) + utils = types.ModuleType("megatron.core.tensor_parallel.utils") + + class VocabUtility: + @staticmethod + def vocab_range_from_per_partition_vocab_size(partition_vocab_size: int, rank: int, world_size: int): + assert world_size == 1 + assert rank == 0 + return 0, partition_vocab_size + + class _MockVocabParallelCrossEntropy(torch.autograd.Function): + @staticmethod + def forward(ctx, logits: torch.Tensor, target: torch.Tensor, process_group): + del process_group + logits = logits.float() + logits_max = logits.max(dim=-1, keepdim=True).values + logits.sub_(logits_max) + predicted_logits = logits.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1) + torch.exp(logits, out=logits) + sum_exp_logits = logits.sum(dim=-1) + logits.div_(sum_exp_logits.unsqueeze(-1)) + ctx.save_for_backward(logits, target) + return sum_exp_logits.log() - predicted_logits + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + softmax, target = ctx.saved_tensors + grad_input = softmax.clone() + grad_input.scatter_add_( + dim=-1, + index=target.unsqueeze(-1), + src=-torch.ones_like(target, dtype=grad_input.dtype).unsqueeze(-1), + ) + grad_input.mul_(grad_output.unsqueeze(-1)) + return grad_input.to(torch.bfloat16), 
@contextmanager
def _single_rank_all_reduce():
    """Temporarily replace ``dist.all_reduce`` with a synchronous no-op.

    While the context is active, ``dist.all_reduce`` leaves its tensor
    untouched and returns ``None`` without consulting any process group; with
    a single rank the tensor already holds the reduced value. The previous
    ``dist.all_reduce`` is restored on exit, including when the body raises.

    The replacement raises ``NotImplementedError`` if called with
    ``async_op=True``.
    """
    saved_all_reduce = dist.all_reduce

    def _noop_all_reduce(tensor, op=None, group=None, async_op=False):
        # Only the blocking form is emulated; there is no work handle to return.
        if not async_op:
            return None
        raise NotImplementedError("async all_reduce is not needed by this repro")

    dist.all_reduce = _noop_all_reduce
    try:
        yield
    finally:
        dist.all_reduce = saved_all_reduce
def _fmt_bytes(value: int) -> str:
    """Render a byte count as GiB (three decimals) followed by the raw count.

    Example: ``1073741824`` becomes ``"1.000 GiB (1073741824 bytes)"``.
    """
    gib = value / (1 << 30)  # 1 GiB == 1024**3 bytes
    return "{:.3f} GiB ({} bytes)".format(gib, value)
mock_megatron={not args.use_real_megatron}" + ) + + if args.sweep: + # Emulates --log-probs-response-only: CE peak should track T' = frac * T, not T. + # frac=1.0 is the "before" (full T); smaller fracs are the gathered T'. + print(f"{'response_frac':>14} {'T_prime':>10} {'peak':>16}") + for frac in (1.0, 0.5, 0.25, 0.0625): + B_eff = max(1, int(args.batch * frac)) + try: + _, peak = _measure(B_eff, args, dtype, device) + print(f"{frac:>14} {B_eff:>10} {_fmt_bytes(peak):>16}") + except torch.cuda.OutOfMemoryError: + print(f"{frac:>14} {B_eff:>10} {'OOM':>16}") + torch.cuda.empty_cache() + return + + B_eff = max(1, int(args.batch * args.response_frac)) + allocated_after_logits, peak = _measure(B_eff, args, dtype, device) + print(f"response_frac={args.response_frac} -> T_prime={B_eff} (of T={args.batch})") + print(f"allocated_after_logits={_fmt_bytes(allocated_after_logits)}") + print(f"peak_during_call={_fmt_bytes(peak)}") + print(f"peak_delta_after_logits={_fmt_bytes(peak - allocated_after_logits)}") + + +if __name__ == "__main__": + main()