-
Notifications
You must be signed in to change notification settings - Fork 735
[Pyt][Common] Enabling/Guarding sm120 support (non - attention) #2833
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
a0de09a
1bc3a7b
dd39661
784fc76
4ca422a
791d279
c42b2e4
a4eccee
fc1d54c
fe238a2
b4bfaa3
2acdfb6
e4c7784
d48fbe1
7923dc5
85126dc
0d2baf1
0ec85ac
9e422bc
21be3b1
4ea2377
62fb50a
d58fa66
d960e24
508cc81
8623c81
b9e9093
fa2a283
616a537
bdeeae8
42d66ee
72f1c31
b5ddaaa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,6 +30,11 @@ | |
| warnings.filterwarnings("ignore", category=FutureWarning) | ||
| warnings.filterwarnings("ignore", category=UserWarning) | ||
|
|
||
| FP8_DEFAULT_RTOL_ATOL = (0.125, 0.0625) | ||
| FP8_CS_SM120_DETERMINISTIC_RTOL_ATOL = (0.4, 0.25) | ||
| BF16_DEFAULT_RTOL_ATOL = (0.025, 0.00125) | ||
| BF16_SM120_DETERMINISTIC_OVERLAP_RTOL_ATOL = (0.05, 0.01) | ||
|
|
||
|
|
||
| class multi_module_model(torch.nn.Module): | ||
| def __init__(self, module, num_layers, *args, **kwargs): | ||
|
|
@@ -551,9 +556,33 @@ def run_fwd_bwd(model, x): | |
|
|
||
| # Now validate accuracy | ||
| if not bool(numerics_failed.item()): | ||
| is_sm120 = torch.cuda.get_device_capability() == (12, 0) | ||
| is_deterministic_mode = os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1") == "0" | ||
| for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)): | ||
| rtol = 0.125 if opts.fp8 else 0.025 | ||
| atol = 0.0625 if opts.fp8 else 0.00125 | ||
| if opts.fp8: | ||
| if ( | ||
| opts.quantization == "fp8_current_scaling" | ||
| and is_sm120 | ||
| and is_deterministic_mode | ||
| ): | ||
| # SM120 deterministic mode disables fused attn, so rt uses alternate attn backends. | ||
| # Combined with FP8 CS, this path needs the looser distributed fp8_cs tolerance policy. | ||
|
Comment on lines
+563
to
+569
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If the discrepancy is due to changes in the attention backend, we should only relax the tols with |
||
| rtol, atol = FP8_CS_SM120_DETERMINISTIC_RTOL_ATOL | ||
| else: | ||
| rtol, atol = FP8_DEFAULT_RTOL_ATOL | ||
| else: | ||
| rtol, atol = BF16_DEFAULT_RTOL_ATOL | ||
| if ( | ||
| is_sm120 | ||
| and is_deterministic_mode | ||
| and opts.layer_type == te.TransformerLayer | ||
| and opts.num_layers > 1 | ||
| and opts.overlap_rs_dgrad | ||
| ): | ||
| # SM120 + deterministic training disables fused attn. | ||
| # Rt then selects an alternate attn backend, and | ||
| # the overlap path can show tiny BF16 accumulation-order drift vs reference. | ||
| rtol, atol = BF16_SM120_DETERMINISTIC_OVERLAP_RTOL_ATOL | ||
| grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol) | ||
| dist_print(grad_info, src=WORLD_RANK, error=grad_failed) | ||
| numerics_failed[0] = int(grad_failed) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,11 +14,31 @@ | |
|
|
||
| recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) | ||
|
|
||
| SM120_SR_EQUIVALENCE_ATOL = 2e-7 | ||
|
|
||
| seed = 12345 | ||
| torch.manual_seed(seed) | ||
| torch.cuda.manual_seed(seed) | ||
|
|
||
|
|
||
| def _assert_sr_vs_rn_behavior( | ||
| me_sr: torch.Tensor, | ||
| me_rn: torch.Tensor, | ||
| me_t_sr: torch.Tensor, | ||
| me_t_rn: torch.Tensor, | ||
| ) -> None: | ||
| if torch.cuda.get_device_capability() == (12, 0): | ||
| # SM120 currently disables NVFP4 stochastic rounding in backend paths, | ||
| # so SR and RN should be numerically equivalent. | ||
|
Comment on lines
+31
to
+32
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: I'd expect a function called |
||
| torch.testing.assert_close(me_sr, me_rn, atol=SM120_SR_EQUIVALENCE_ATOL, rtol=0.0) | ||
| torch.testing.assert_close(me_t_sr, me_t_rn, atol=SM120_SR_EQUIVALENCE_ATOL, rtol=0.0) | ||
| else: | ||
| assert me_sr < me_rn, "Stochastic rounding failed - error larger than the round to nearest." | ||
| assert ( | ||
| me_t_sr < me_t_rn | ||
| ), "Stochastic rounding failed - error larger than the round to nearest." | ||
|
|
||
|
|
||
| def unpack_fp4(x: torch.Tensor) -> torch.Tensor: | ||
| repeated = x.repeat_interleave(2, dim=1) | ||
| repeated[:, 0::2] &= 0x0F | ||
|
|
@@ -247,7 +267,7 @@ def check_quantization_nvfp4_versus_reference( | |
| me_t_rn = torch.sqrt((error_t_rn * error_t_rn).mean()) | ||
| sr_result = torch.zeros_like(x).float() | ||
| sr_t_result = torch.zeros_like(x).float().t().contiguous() | ||
| for i in range(n_iters): | ||
| for _ in range(n_iters): | ||
| q_sr, s_sr, q_t_sr, s_t_sr = quantize_fp4( | ||
| x, use_stochastic_rounding=True, use_2D=use_2D, use_RHT=use_RHT | ||
| ) | ||
|
|
@@ -278,8 +298,7 @@ def check_quantization_nvfp4_versus_reference( | |
|
|
||
| print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}") | ||
| print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}") | ||
| assert me_sr < me_rn, "Stochastic rounding failed - error larger than the round to nearest." | ||
| assert me_t_sr < me_t_rn, "Stochastic rounding failed - error larger than the round to nearest." | ||
| _assert_sr_vs_rn_behavior(me_sr, me_rn, me_t_sr, me_t_rn) | ||
|
|
||
|
|
||
| def check_group_quantization_nvfp4_versus_reference( | ||
|
|
@@ -362,10 +381,7 @@ def check_group_quantization_nvfp4_versus_reference( | |
|
|
||
| print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}") | ||
| print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}") | ||
| assert me_sr < me_rn, "Stochastic rounding failed - error larger than the round to nearest." | ||
| assert ( | ||
| me_t_sr < me_t_rn | ||
| ), "Stochastic rounding failed - error larger than the round to nearest." | ||
| _assert_sr_vs_rn_behavior(me_sr, me_rn, me_t_sr, me_t_rn) | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems like a proper bug. If we run on SM 12.0, we want the test to fail rather than giving us a false pass.