diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 2d6e8ce5b2..6a7c4268b5 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -372,7 +372,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "utils/test_megatron_server_arguments.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_logprob_response_spans.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_cispo_loss.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}] + info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "utils/test_megatron_server_arguments.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_logprob_response_spans.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_cispo_loss.py"}, {"num_gpus": 0, "test_file": "test_reinforce.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}] defaults: run: working-directory: ${{ github.workspace }} diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 30cf386421..0b6b5f5c60 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -72,6 +72,7 @@ {'test_file': 'test_logprob_response_spans.py', 'num_gpus': 0}, {'test_file': 'test_value_temperature.py', 'num_gpus': 0}, {'test_file': 'test_cispo_loss.py', 'num_gpus': 0}, + {'test_file': 'test_reinforce.py', 'num_gpus': 0}, {'test_file': 'test_rm_f1.py', 'num_gpus': 0}, {'test_file': 'test_rm_gpqa.py', 'num_gpus': 0}, {'test_file': 'test_rm_math.py', 'num_gpus': 0}, diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 72afdfa66c..7f72be18f5 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -17,6 +17,7 @@ compute_gspo_kl, compute_opsm_mask, compute_policy_loss, + compute_reinforce_loss, get_advantages_and_returns_batch, get_grpo_returns, get_reinforce_plus_plus_baseline_advantages, @@ -713,7 +714,7 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) custom_adv_fn(args, rollout_data) advantages, returns = rollout_data["advantages"], rollout_data["returns"] - elif args.advantage_estimator in ["grpo", "gspo", "cispo"]: + elif args.advantage_estimator in ["grpo", "gspo", "cispo", "reinforce"]: rewards = torch.tensor(rewards, dtype=torch.float32, device=kl[0].device) returns = get_grpo_returns(rewards, kl) # TODO: is the copy necessary? @@ -973,6 +974,8 @@ def policy_loss_function( if args.advantage_estimator == "cispo": pg_loss, pg_clipfrac = compute_cispo_loss(ppo_kl, log_probs, advantages, args.eps_clip, args.eps_clip_high) + elif args.advantage_estimator == "reinforce": + pg_loss, pg_clipfrac = compute_reinforce_loss(advantages, log_probs) else: pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, args.eps_clip, args.eps_clip_high) diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index ff571101af..76cbcb6e33 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -689,7 +689,7 @@ def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]): raw_rewards = [sample.get_reward_value(self.args) for sample in samples] if ( - self.args.advantage_estimator in ["grpo", "gspo", "cispo", "reinforce_plus_plus_baseline"] + self.args.advantage_estimator in ["grpo", "gspo", "cispo", "reinforce", "reinforce_plus_plus_baseline"] and self.args.rewards_normalization ): # group norm @@ -702,7 +702,10 @@ def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]): mean = rewards.mean(dim=-1, keepdim=True) rewards = rewards - mean - if self.args.advantage_estimator in ["grpo", "gspo", "cispo"] and self.args.grpo_std_normalization: + if ( + self.args.advantage_estimator in ["grpo", "gspo", "cispo", "reinforce"] + and self.args.grpo_std_normalization + ): std = rewards.std(dim=-1, keepdim=True) rewards = rewards / (std + 1e-6) diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index d5cac9d44b..59dc3afb64 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -911,14 +911,17 @@ def add_algo_arguments(parser): "grpo", "gspo", "cispo", + "reinforce", "reinforce_plus_plus", "reinforce_plus_plus_baseline", "ppo", ], default="grpo", help=( - "Advantage estimator to use. Note: on-policy distillation (OPD) is now orthogonal " - "to the advantage estimator. Use --opd-kl-coef > 0 to enable OPD on top of any estimator." + "Advantage estimator to use. 'reinforce' uses GRPO-style group-normalized " + "advantages with the plain additive surrogate (no PPO/IS ratio, no clipping). " + "Note: on-policy distillation (OPD) is now orthogonal to the advantage estimator. " + "Use --opd-kl-coef > 0 to enable OPD on top of any estimator." ), ) parser.add_argument( diff --git a/slime/utils/ppo_utils.py b/slime/utils/ppo_utils.py index a4dd0c0181..a40fa27210 100644 --- a/slime/utils/ppo_utils.py +++ b/slime/utils/ppo_utils.py @@ -171,6 +171,20 @@ def compute_cispo_loss( return pg_losses, clipfrac +@torch.compile(dynamic=True) +def compute_reinforce_loss( + advantages: torch.Tensor, + log_probs: torch.Tensor, +): + """REINFORCE surrogate ``-A * log pi_theta`` (no IS ratio, no clipping); gradient + flows only through ``log_probs``. Same ``(per_token_loss, clipfrac)`` contract as + :func:`compute_policy_loss`, with ``clipfrac`` identically zero (nothing is clipped). + """ + pg_losses = -advantages * log_probs + clipfrac = torch.zeros_like(pg_losses) + return pg_losses, clipfrac + + def compute_log_probs( logits: torch.Tensor, tokens: torch.Tensor, diff --git a/tests/test_reinforce.py b/tests/test_reinforce.py new file mode 100644 index 0000000000..28526bab22 --- /dev/null +++ b/tests/test_reinforce.py @@ -0,0 +1,35 @@ +"""CPU tests for compute_reinforce_loss (plain ``-A * log pi_theta`` surrogate).""" + +import pytest +import torch + +from slime.utils.ppo_utils import compute_reinforce_loss + +NUM_GPUS = 0 + + +@pytest.mark.unit +def test_reinforce_loss_matches_closed_form(): + advantages = torch.tensor([2.0, -1.0, 0.5]) + log_probs = torch.tensor([-0.1, -0.2, -0.3]) + + pg_loss, clipfrac = compute_reinforce_loss(advantages, log_probs) + + assert torch.allclose(pg_loss, -advantages * log_probs) + assert torch.allclose(clipfrac, torch.zeros(3)) + + +@pytest.mark.unit +def test_reinforce_gradient_flows_only_through_log_probs(): + advantages = torch.tensor([2.0, -1.0, 0.5]) + log_probs = torch.tensor([-0.1, -0.2, -0.3], requires_grad=True) + + pg_loss, _ = compute_reinforce_loss(advantages, log_probs) + pg_loss.sum().backward() + + # d/d log_probs [ -A * log_probs ] = -A + assert torch.allclose(log_probs.grad, -advantages) + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__]))