Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pr-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ jobs:
strategy:
fail-fast: false
matrix:
info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "utils/test_megatron_server_arguments.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_logprob_response_spans.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_cispo_loss.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}]
info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "utils/test_megatron_server_arguments.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_logprob_response_spans.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_cispo_loss.py"}, {"num_gpus": 0, "test_file": "test_reinforce.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}]
defaults:
run:
working-directory: ${{ github.workspace }}
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/pr-test.yml.j2
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
{'test_file': 'test_logprob_response_spans.py', 'num_gpus': 0},
{'test_file': 'test_value_temperature.py', 'num_gpus': 0},
{'test_file': 'test_cispo_loss.py', 'num_gpus': 0},
{'test_file': 'test_reinforce.py', 'num_gpus': 0},
{'test_file': 'test_rm_f1.py', 'num_gpus': 0},
{'test_file': 'test_rm_gpqa.py', 'num_gpus': 0},
{'test_file': 'test_rm_math.py', 'num_gpus': 0},
Expand Down
5 changes: 4 additions & 1 deletion slime/backends/megatron_utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
compute_gspo_kl,
compute_opsm_mask,
compute_policy_loss,
compute_reinforce_loss,
get_advantages_and_returns_batch,
get_grpo_returns,
get_reinforce_plus_plus_baseline_advantages,
Expand Down Expand Up @@ -713,7 +714,7 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch)
custom_adv_fn(args, rollout_data)
advantages, returns = rollout_data["advantages"], rollout_data["returns"]

elif args.advantage_estimator in ["grpo", "gspo", "cispo"]:
elif args.advantage_estimator in ["grpo", "gspo", "cispo", "reinforce"]:
rewards = torch.tensor(rewards, dtype=torch.float32, device=kl[0].device)
returns = get_grpo_returns(rewards, kl)
# TODO: is the copy necessary?
Expand Down Expand Up @@ -973,6 +974,8 @@ def policy_loss_function(

if args.advantage_estimator == "cispo":
pg_loss, pg_clipfrac = compute_cispo_loss(ppo_kl, log_probs, advantages, args.eps_clip, args.eps_clip_high)
elif args.advantage_estimator == "reinforce":
pg_loss, pg_clipfrac = compute_reinforce_loss(advantages, log_probs)
else:
pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, args.eps_clip, args.eps_clip_high)

Expand Down
7 changes: 5 additions & 2 deletions slime/ray/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,7 @@ def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]):

raw_rewards = [sample.get_reward_value(self.args) for sample in samples]
if (
self.args.advantage_estimator in ["grpo", "gspo", "cispo", "reinforce_plus_plus_baseline"]
self.args.advantage_estimator in ["grpo", "gspo", "cispo", "reinforce", "reinforce_plus_plus_baseline"]
and self.args.rewards_normalization
):
# group norm
Expand All @@ -702,7 +702,10 @@ def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]):
mean = rewards.mean(dim=-1, keepdim=True)
rewards = rewards - mean

if self.args.advantage_estimator in ["grpo", "gspo", "cispo"] and self.args.grpo_std_normalization:
if (
self.args.advantage_estimator in ["grpo", "gspo", "cispo", "reinforce"]
and self.args.grpo_std_normalization
):
std = rewards.std(dim=-1, keepdim=True)
rewards = rewards / (std + 1e-6)

Expand Down
7 changes: 5 additions & 2 deletions slime/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,14 +911,17 @@ def add_algo_arguments(parser):
"grpo",
"gspo",
"cispo",
"reinforce",
"reinforce_plus_plus",
"reinforce_plus_plus_baseline",
"ppo",
],
default="grpo",
help=(
"Advantage estimator to use. Note: on-policy distillation (OPD) is now orthogonal "
"to the advantage estimator. Use --opd-kl-coef > 0 to enable OPD on top of any estimator."
"Advantage estimator to use. 'reinforce' uses GRPO-style group-normalized "
"advantages with the plain additive surrogate (no PPO/IS ratio, no clipping). "
"Note: on-policy distillation (OPD) is now orthogonal to the advantage estimator. "
"Use --opd-kl-coef > 0 to enable OPD on top of any estimator."
),
)
parser.add_argument(
Expand Down
14 changes: 14 additions & 0 deletions slime/utils/ppo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,20 @@ def compute_cispo_loss(
return pg_losses, clipfrac


@torch.compile(dynamic=True)
def compute_reinforce_loss(
advantages: torch.Tensor,
log_probs: torch.Tensor,
):
"""REINFORCE surrogate ``-A * log pi_theta`` (no IS ratio, no clipping); gradient
flows only through ``log_probs``. Same ``(per_token_loss, clipfrac)`` contract as
:func:`compute_policy_loss`, with ``clipfrac`` identically zero (nothing is clipped).
"""
pg_losses = -advantages * log_probs
clipfrac = torch.zeros_like(pg_losses)
return pg_losses, clipfrac


def compute_log_probs(
logits: torch.Tensor,
tokens: torch.Tensor,
Expand Down
35 changes: 35 additions & 0 deletions tests/test_reinforce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""CPU tests for compute_reinforce_loss (plain ``-A * log pi_theta`` surrogate)."""

import pytest
import torch

from slime.utils.ppo_utils import compute_reinforce_loss

NUM_GPUS = 0


@pytest.mark.unit
def test_reinforce_loss_matches_closed_form():
    """The per-token loss equals ``-A * log_probs`` and ``clipfrac`` is all zeros."""
    adv = torch.tensor([2.0, -1.0, 0.5])
    logp = torch.tensor([-0.1, -0.2, -0.3])

    # Closed-form expectations, computed independently of the function under test.
    expected_loss = -adv * logp
    expected_clipfrac = torch.zeros(3)

    outputs = compute_reinforce_loss(adv, logp)
    per_token_loss, clip_fraction = outputs

    assert torch.allclose(per_token_loss, expected_loss)
    assert torch.allclose(clip_fraction, expected_clipfrac)


@pytest.mark.unit
def test_reinforce_gradient_flows_only_through_log_probs():
    """Back-propagating the summed loss yields ``-A`` as the gradient on ``log_probs``."""
    adv = torch.tensor([2.0, -1.0, 0.5])
    logp = torch.tensor([-0.1, -0.2, -0.3], requires_grad=True)

    per_token_loss = compute_reinforce_loss(adv, logp)[0]
    total = per_token_loss.sum()
    total.backward()

    # The surrogate is linear in log_probs: d/d log_probs [ -A * log_probs ] = -A
    expected_grad = -adv
    assert torch.allclose(logp.grad, expected_grad)


if __name__ == "__main__":
    # Allow running this file directly; propagate pytest's exit status to the shell.
    exit_status = pytest.main([__file__])
    raise SystemExit(exit_status)
Loading