feat(rl): add off-policy IS correction hook (current policy vs rollout)#2084
Open
EazyReal wants to merge 1 commit into
Open
feat(rl): add off-policy IS correction hook (current policy vs rollout)#2084EazyReal wants to merge 1 commit into
EazyReal wants to merge 1 commit into
Conversation
Contributor
Author
|
@zhuzilin could you review this one? It adds the off-policy IS hook at the policy-loss boundary using current-policy logprobs against rollout logprobs, preserving gradient flow through pi_theta while keeping the correction weight detached. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What changed
policy_loss_function(slime/backends/megatron_utils/loss.py) now passes the current grad-carrying log-probs ascur_log_probsinto the TIS-correction kwargs, alongside the existing frozentrain_log_probs(pi_theta_old) androllout_log_probs(pi_rollout). The existing corrections (vanilla_tis_function,icepop_function) ignore the new kwarg via**kwargs.off_policy_is_functioninslime/utils/ppo_utils.py: a truncated-IS correction whose detached weight isclip(pi_theta / pi_rollout)against the actual rollout logprob, so one weight corrects both the train/inference mismatch and async (multi-version) staleness. The existing TIS only hadpi_theta_old / pi_rollout, which equals this only in the single-update-per-rollout limit.--use-tis --custom-tis-function-path slime.utils.ppo_utils.off_policy_is_function.--eps-clip/--eps-clip-high(the CISPO/PPO clip convention);--eps-clip 1.0gives canonical single-sided clipping. Note this reuses the policy-loss clip range rather than the--tis-clip*range used by the other TIS hooks.Why
On a REINFORCE base (
--advantage-estimator reinforce) this reproduces the CISPO surrogate (MiniMax-M1, arxiv 2506.13585) as a composable correction:L = -clip(pi_theta/pi_rollout).detach() * A * log pi, with gradient only throughlog pi. It generalizes the train/inference-mismatch correction to also absorb off-policy staleness without adding a dedicated estimator.Validation
CPU unit test
tests/test_off_policy_is.py(registered in the cpu-unittest matrix,NUM_GPUS = 0), run viapytest tests/test_off_policy_is.py:log pi);--eps-clip 1.0disables the lower bound (single-sided).The
cur_log_probswiring inloss.pyimports megatron and is exercised by the GPU CI suites.