From 0a034ba6f990cf4dcacc58c95d000bb2dc7566ff Mon Sep 17 00:00:00 2001 From: EazyReal Date: Sun, 21 Jun 2026 18:42:22 +0000 Subject: [PATCH] fix(ppo): stop corrupting the logged rollout/kl metric compute_advantages_and_returns stores the approximate KL in rollout_data["kl"], which log_rollout_data reduces and logs as rollout/kl. The ppo estimator builds its reward signal as -kl_coef * kl (plus the scalar reward at the last token) in place (`k *= kl_coef; k[-1] += reward`), and `k` aliases the tensors in rollout_data["kl"]. So after the ppo branch the logged KL is overwritten with the reward. Every other estimator (grpo/gspo/cispo/reinforce++) treats kl as read-only. Build the reward out-of-place so the stored KL stays intact. Co-Authored-By: Claude Opus 4.8 --- .github/workflows/pr-test.yml | 2 +- .github/workflows/pr-test.yml.j2 | 1 + slime/backends/megatron_utils/loss.py | 4 +- tests/test_ppo_kl_metric.py | 76 +++++++++++++++++++++++++++ 4 files changed, 81 insertions(+), 2 deletions(-) create mode 100644 tests/test_ppo_kl_metric.py diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 2d6e8ce5b2..1f86639f4b 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -372,7 +372,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "utils/test_megatron_server_arguments.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_logprob_response_spans.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_cispo_loss.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}] + info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "utils/test_megatron_server_arguments.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_logprob_response_spans.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_ppo_kl_metric.py"}, {"num_gpus": 0, "test_file": "test_cispo_loss.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}] defaults: run: working-directory: ${{ github.workspace }} diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 30cf386421..e798e92cde 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -71,6 +71,7 @@ {'test_file': 'test_loss_cp_invariance.py', 'num_gpus': 0}, {'test_file': 'test_logprob_response_spans.py', 'num_gpus': 0}, {'test_file': 'test_value_temperature.py', 'num_gpus': 0}, + {'test_file': 'test_ppo_kl_metric.py', 'num_gpus': 0}, {'test_file': 'test_cispo_loss.py', 'num_gpus': 0}, {'test_file': 'test_rm_f1.py', 'num_gpus': 0}, {'test_file': 'test_rm_gpqa.py', 'num_gpus': 0}, diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 72afdfa66c..af1ec16d04 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -725,7 +725,9 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) kl_coef = -args.kl_coef cp_rank = mpu.get_context_parallel_rank() for reward, k in zip(old_rewards, kl, strict=False): - k *= kl_coef + # Build rewards out-of-place: kl aliases rollout_data["kl"], which is + # the source for the logged rollout/kl metric. + k = k * kl_coef if cp_rank == 0: k[-1] += reward rewards.append(k) diff --git a/tests/test_ppo_kl_metric.py b/tests/test_ppo_kl_metric.py new file mode 100644 index 0000000000..c084afee9d --- /dev/null +++ b/tests/test_ppo_kl_metric.py @@ -0,0 +1,76 @@ +import sys +import types +from argparse import Namespace + +import pytest +import torch + +NUM_GPUS = 0 + + +@pytest.mark.unit +def test_ppo_estimator_does_not_corrupt_logged_kl(monkeypatch): + """The ppo branch turns kl into a reward (-kl_coef * kl, plus the scalar + reward at the last token). It must do so out-of-place: rollout_data["kl"] is + the source for the logged rollout/kl metric, and every other estimator treats + kl as read-only. An in-place update overwrites the logged KL with the reward. + """ + previous_loss = sys.modules.pop("slime.backends.megatron_utils.loss", None) + previous_cp_utils = sys.modules.pop("slime.backends.megatron_utils.cp_utils", None) + + mpu_stub = types.SimpleNamespace( + get_context_parallel_world_size=lambda: 1, + get_context_parallel_rank=lambda: 0, + is_pipeline_last_stage=lambda: True, + ) + megatron_mod = types.ModuleType("megatron") + core_mod = types.ModuleType("megatron.core") + core_mod.mpu = mpu_stub + monkeypatch.setitem(sys.modules, "megatron", megatron_mod) + monkeypatch.setitem(sys.modules, "megatron.core", core_mod) + + try: + from slime.backends.megatron_utils.loss import compute_advantages_and_returns + from slime.utils.ppo_utils import compute_approx_kl + + log_probs = [torch.tensor([0.5, 0.7, 0.9])] + ref_log_probs = [torch.tensor([0.4, 0.5, 0.6])] + expected_kl = compute_approx_kl(log_probs[0], ref_log_probs[0], kl_loss_type="k1") + + rollout_data = { + "log_probs": log_probs, + "ref_log_probs": ref_log_probs, + "rewards": [1.0], + "values": [torch.zeros(3)], + "response_lengths": [3], + "total_lengths": [5], + "loss_masks": [torch.ones(3)], + } + args = Namespace( + advantage_estimator="ppo", + kl_coef=0.05, + kl_loss_type="k1", + use_rollout_logprobs=False, + custom_advantage_function_path=None, + normalize_advantages=False, + use_opd=False, + gamma=1.0, + lambd=1.0, + ) + + compute_advantages_and_returns(args, rollout_data) + + torch.testing.assert_close(rollout_data["kl"][0], expected_kl) + finally: + if previous_loss is None: + sys.modules.pop("slime.backends.megatron_utils.loss", None) + else: + sys.modules["slime.backends.megatron_utils.loss"] = previous_loss + if previous_cp_utils is None: + sys.modules.pop("slime.backends.megatron_utils.cp_utils", None) + else: + sys.modules["slime.backends.megatron_utils.cp_utils"] = previous_cp_utils + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__]))