Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pr-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ jobs:
strategy:
fail-fast: false
matrix:
info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "utils/test_megatron_server_arguments.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_logprob_response_spans.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_cispo_loss.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}]
info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "utils/test_megatron_server_arguments.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_logprob_response_spans.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_ppo_kl_metric.py"}, {"num_gpus": 0, "test_file": "test_cispo_loss.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}]
defaults:
run:
working-directory: ${{ github.workspace }}
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/pr-test.yml.j2
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
{'test_file': 'test_loss_cp_invariance.py', 'num_gpus': 0},
{'test_file': 'test_logprob_response_spans.py', 'num_gpus': 0},
{'test_file': 'test_value_temperature.py', 'num_gpus': 0},
{'test_file': 'test_ppo_kl_metric.py', 'num_gpus': 0},
{'test_file': 'test_cispo_loss.py', 'num_gpus': 0},
{'test_file': 'test_rm_f1.py', 'num_gpus': 0},
{'test_file': 'test_rm_gpqa.py', 'num_gpus': 0},
Expand Down
4 changes: 3 additions & 1 deletion slime/backends/megatron_utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,9 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch)
kl_coef = -args.kl_coef
cp_rank = mpu.get_context_parallel_rank()
for reward, k in zip(old_rewards, kl, strict=False):
k *= kl_coef
# Build rewards out-of-place: kl aliases rollout_data["kl"], which is
# the source for the logged rollout/kl metric.
k = k * kl_coef
if cp_rank == 0:
k[-1] += reward
rewards.append(k)
Expand Down
76 changes: 76 additions & 0 deletions tests/test_ppo_kl_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import sys
import types
from argparse import Namespace

import pytest
import torch

NUM_GPUS = 0


@pytest.mark.unit
def test_ppo_estimator_does_not_corrupt_logged_kl(monkeypatch):
"""The ppo branch turns kl into a reward (-kl_coef * kl, plus the scalar
reward at the last token). It must do so out-of-place: rollout_data["kl"] is
the source for the logged rollout/kl metric, and every other estimator treats
kl as read-only. An in-place update overwrites the logged KL with the reward.
"""
previous_loss = sys.modules.pop("slime.backends.megatron_utils.loss", None)
previous_cp_utils = sys.modules.pop("slime.backends.megatron_utils.cp_utils", None)

mpu_stub = types.SimpleNamespace(
get_context_parallel_world_size=lambda: 1,
get_context_parallel_rank=lambda: 0,
is_pipeline_last_stage=lambda: True,
)
megatron_mod = types.ModuleType("megatron")
core_mod = types.ModuleType("megatron.core")
core_mod.mpu = mpu_stub
monkeypatch.setitem(sys.modules, "megatron", megatron_mod)
monkeypatch.setitem(sys.modules, "megatron.core", core_mod)

try:
from slime.backends.megatron_utils.loss import compute_advantages_and_returns
from slime.utils.ppo_utils import compute_approx_kl

log_probs = [torch.tensor([0.5, 0.7, 0.9])]
ref_log_probs = [torch.tensor([0.4, 0.5, 0.6])]
expected_kl = compute_approx_kl(log_probs[0], ref_log_probs[0], kl_loss_type="k1")

rollout_data = {
"log_probs": log_probs,
"ref_log_probs": ref_log_probs,
"rewards": [1.0],
"values": [torch.zeros(3)],
"response_lengths": [3],
"total_lengths": [5],
"loss_masks": [torch.ones(3)],
}
args = Namespace(
advantage_estimator="ppo",
kl_coef=0.05,
kl_loss_type="k1",
use_rollout_logprobs=False,
custom_advantage_function_path=None,
normalize_advantages=False,
use_opd=False,
gamma=1.0,
lambd=1.0,
)

compute_advantages_and_returns(args, rollout_data)

torch.testing.assert_close(rollout_data["kl"][0], expected_kl)
finally:
if previous_loss is None:
sys.modules.pop("slime.backends.megatron_utils.loss", None)
else:
sys.modules["slime.backends.megatron_utils.loss"] = previous_loss
if previous_cp_utils is None:
sys.modules.pop("slime.backends.megatron_utils.cp_utils", None)
else:
sys.modules["slime.backends.megatron_utils.cp_utils"] = previous_cp_utils


if __name__ == "__main__":
raise SystemExit(pytest.main([__file__]))
Loading