Fix a race condition in contiguous k-grouped GEMM where in-flight tensormaps are updated in-place #343
dfyz wants to merge 1 commit into
I believe that, at the memory-model level, @xay5421 can help review and correct me if I missed something.
I'm totally fine with replacing
The guide also says that the behavior is "invalid or undefined" only if "A non-exited thread specified in the
I think that my usage is conceptually this:
Which seems equivalent to the example from the docs.
Thanks for pointing to the CUDA programming guide. I agree my previous wording was too strong: this is not unconditionally undefined behavior. The precise point is: this pattern is only valid if all non-elected lanes have truly exited before the elected lane reaches the full-mask

    if (... && cute::elect_one_sync()) {
        __syncwarp(0xffffffff);
    }

So this would be valid if the final control flow is really equivalent to:

    if (threadIdx.x != ELECTED_THREAD_ID)
        return;
    __syncwarp(0xffffffff);

But the source pattern above does not guarantee that. For example, even if there is no explicit code after the if, a C++ destructor / cleanup for an object declared before the if can still be emitted at scope exit. More generally, the compiler could lower the skipped path to a common exit block:

    ELECT P0
    @!P0 BRA common_exit
    WARPSYNC 0xffffffff // elected lane only
    common_exit:
    // cleanup / epilogue / setmaxnreg-related code
    EXIT

In that case, the non-elected lanes are still non-exited when the elected lane executes
Just to confirm, the next required action here is for Ivan @dfyz to update the MR based on the suggestion above?
Yeah, I think so. I'm still not 100% sure I understand the logic about
@RayWang96 I've tried using
In general, I'm not sure if any form of
What is the right tool for the job? I believe that, conceptually, we need to order the generic proxy write for the next group (
Of course,
Have you checked the SASS code already? I believe all the
Right, and then we can wait for the DS team to review and merge it. If you want to get it merged quickly, you can merge it into the nv_dev branch first.
After What we need is to prevent STG from being reordered before
Yes I did (after seeing the test fail). To clarify, I'm using the following To reproduce the desired behavior, you can run the Here, the If you do the same for commit Again, you can inspect the control-flow graph to be sure.
Fair enough, but if I take commit
I think there's no rush! The training codebase where this issue was initially found has a workaround applied internally, and it would be good to fix it and review it in a proper way upstream.
Indeed, my earlier approach was incorrect; the issue here cannot be resolved with a memory barrier. I believe warp sync is the correct fix, but I would recommend using __syncwarp(1 << lane_idx); so that it acts purely as an instruction scheduling barrier and avoids the risk of causing other side effects.
Thanks, this sounds like a good idea. I rebased my branch on top of
Do we maybe need to tag someone from the DS team to look at this PR?
The weight gradient kernel for SM90 modifies the A/B tensormaps stored in GMEM when the scheduler decides to switch groups, but it doesn't make sure that previous TMA loads have already finished reading the tensormaps at that point. A TMA load that reads a tensormap while it is being modified can see a partially updated descriptor, which corrupts the output.
A reliable way to reproduce this corruption (at least on an H200) is to run the test_k_grouped_gemm_contiguous test a couple of times with the shape added in this PR. Eventually the corruption will be large enough to make the test fail:
Even when the test doesn't fail, the corruption still happens. You can see that by printing the torch.norm() of the result here and seeing that it differs between the test runs:
The fix for this is conceptually straightforward: section 9.7.9.26.5.2 from the PTX docs says that the bulk async-group based completion mechanism can be used "for the completion of reading of the tensormap object" in (otherwise mbarrier-based) TMA loads. So running cp.async.bulk.commit_group + cp.async.bulk.wait_group.read 0 before modifying GMEM ensures all in-flight TMA loads have finished reading their tensormaps, which fixes the race condition.
The only catch is that writing to a TMA descriptor from a single thread generates multiple STG.E.128 instructions, and ptxas decides to move some of them before the DEPBAR.LE SB0, 0x0 generated by cp.async.bulk.wait_group.read 0, which re-introduces the race condition. CUTLASS does a __syncwarp() before updating tensormaps in GMEM, and inserting a warp sync indeed prevents ptxas from reordering instructions, but this seems more like a coincidence, since we only have one thread from the warp anyway and I don't see a good way to explain this in terms of the PTX memory model. In other words, this is an ugly hack, so if you have better solutions, I will gladly implement them. :)
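For readers who want to see the ordering described above in one place, here is a minimal sketch of the sequence, not the actual DeepGEMM change. It assumes a 128-byte stand-in struct instead of a real CUtensorMap, a lane-0 check instead of cute::elect_one_sync(), and the single-lane __syncwarp mask suggested in the conversation. All names (FakeTensorMap, wait_for_tensormap_readers, update_descriptor, run_update_demo) are hypothetical. Whether the warp sync actually keeps ptxas from hoisting the stores is an empirical observation from this PR, so the SASS should still be inspected.

```cuda
#include <cstdint>
#include <cuda_runtime.h>

// Hypothetical stand-in for a 128-byte TMA descriptor living in GMEM.
struct alignas(64) FakeTensorMap {
    uint64_t words[16];
};

// Waits until previously issued bulk async operations have finished reading
// their sources. Compiled to a no-op below SM90, where the instructions do not exist.
__device__ __forceinline__ void wait_for_tensormap_readers() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    asm volatile("cp.async.bulk.commit_group;" ::: "memory");
    asm volatile("cp.async.bulk.wait_group.read 0;" ::: "memory");
#endif
}

__global__ void update_descriptor(FakeTensorMap* gmem_desc, uint64_t base) {
    const unsigned lane_idx = threadIdx.x % 32;
    if (lane_idx == 0) {  // assumption: stands in for cute::elect_one_sync()
        // 1. Make sure in-flight TMA loads are done reading the descriptor.
        wait_for_tensormap_readers();
        // 2. Single-lane warp sync: the mask names only the calling lane, so it
        //    does not depend on what the other lanes are doing.
        __syncwarp(1u << lane_idx);
        // 3. Only now overwrite the descriptor in GMEM.
        for (int i = 0; i < 16; ++i) {
            gmem_desc->words[i] = base + i;
        }
    }
}

// Host-side check: returns 0 on success, 1 on wrong data,
// -1 if the CUDA runtime reports an error (for example, no device).
int run_update_demo() {
    FakeTensorMap* d_desc = nullptr;
    if (cudaMalloc(&d_desc, sizeof(FakeTensorMap)) != cudaSuccess) return -1;
    update_descriptor<<<1, 32>>>(d_desc, 100);
    FakeTensorMap h_desc{};
    cudaError_t err = cudaMemcpy(&h_desc, d_desc, sizeof(h_desc),
                                 cudaMemcpyDeviceToHost);
    cudaFree(d_desc);
    if (err != cudaSuccess) return -1;
    for (int i = 0; i < 16; ++i) {
        if (h_desc.words[i] != 100u + i) return 1;
    }
    return 0;
}
```

The demo has no TMA loads in flight, so it only illustrates the instruction order (wait, warp sync, store); it cannot reproduce or rule out the race itself. Calling run_update_demo() from a main function and checking for a zero return is enough to confirm the sketch compiles and runs.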