Draft: Add green-context split-kernel MegaMoE features #357
Open
RayWang96 wants to merge 1 commit into
Add green-context split-kernel MegaMoE (dispatch_l1_swiglu / l2_combine / combine_reduce)
Summary
This PR adds a split-kernel implementation of the FP8/FP4 MegaMoE alongside the existing fused
megakernel. Instead of one monolithic kernel that does everything, the MoE forward is decomposed
into three focused kernels that are wired into a single native CUDA graph:
K1: dispatch_l1_swiglu
K2: l2_combine
K3: combine_reduce
K1 and K2 run concurrently on disjoint SM partitions carved with CUDA green contexts
(cudaGreenCtxCreate / cudaDevSmResourceSplit), coupled only through the HBM pool + arrival
masks. K3 runs on all SMs after the barrier. Everything replays as one CUDA-graph launch, so the
external API matches the fused path.
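As a rough CPU-side analogy (not the kernel code), the K1/K2 coupling described above can be pictured as a producer and a consumer that share only a buffer and per-slot ready flags. The slot count, payloads, and function names below are made up for illustration, and Python threads stand in for the two SM partitions.

```python
import threading

NUM_SLOTS = 8  # hypothetical number of pool slots (tiles) in this toy

pool = [None] * NUM_SLOTS                                # stand-in for the shared HBM pool
arrived = [threading.Event() for _ in range(NUM_SLOTS)]  # stand-in for the arrival masks
partial = [None] * NUM_SLOTS                             # per-slot outputs of the consumer


def producer():
    """Plays the role of K1: fill a slot, then publish its arrival flag."""
    for slot in range(NUM_SLOTS):
        pool[slot] = 2 * slot   # write the payload first ...
        arrived[slot].set()     # ... then signal that the slot is ready


def consumer():
    """Plays the role of K2: wait on each flag, then read only that slot."""
    for slot in range(NUM_SLOTS):
        arrived[slot].wait()
        partial[slot] = pool[slot] + 1


def run_split():
    k1 = threading.Thread(target=producer)
    k2 = threading.Thread(target=consumer)
    k2.start()           # start order does not matter: the flags gate every read
    k1.start()
    k1.join()
    k2.join()            # joining both threads plays the role of the barrier
    return sum(partial)  # the final reduction plays the role of K3


result = run_split()
```

The point of the sketch is that the consumer never needs to know how far the producer has progressed overall; it only waits on the flag of the slot it is about to read, which is what lets the two stages overlap.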
The fused fp8_fp4_mega_moe kernel is left completely untouched; this is a second, parallel
implementation. Its output is bitwise-identical to the fused output, and it is up to ~11% faster
at large batch.
Why split it?
1. Maintainability
The fused megakernel interleaves dispatch, both GEMMs, activation, combine, and reduce inside one
persistent-CTA kernel that shares a single register / shared-memory / TMEM budget and a single
launch. That is fast but monolithic: tuning one stage (a GEMM tile, an occupancy target, the SM
count for L2) is entangled with every other stage.
2. Performance
Running K1 ‖ K2 concurrently overlaps the dispatch + Linear1 compute with Linear2 + combine,
which the monolithic kernel serializes inside its internal pipeline. At large batch this overlap
plus the per-kernel occupancy headroom yields a net speedup while remaining numerically identical.
3. Possible reference
The split kernels may also serve as a reference implementation for architectures such as Hopper.
Correctness
tests/test_mega_moe_split.py builds the same inputs for both paths, runs the fused megakernel as
the reference, replays the split graph, and asserts torch.equal(y_split, y_fused).
All 32 swept configurations are bitwise-identical (max_abs = 0).
Benchmark results on B200
Sweep over token count × intermediate size × ranks. Fixed: hidden=7168, 6/384 experts,
SM split K1=96 / K2=52 / K3=148, fast_math=1, clamp=10. Timing: wall-clock best-of-20 (3 warmup),
2 GiB L2 flush + buffer reset each iter; fused & split measured by the same harness. Metrics
use the existing test_mega_moe.py formulas (TFLOPS = 2·recv·H·I·3; HBM = FP4 weights + acts + out;
NVLink = recv·H·3). ratio = split_us / fused_us (< 1 = split faster).
procs=8, intermediate=3072 (48 experts/rank)
procs=8, intermediate=2048 (48 experts/rank)
procs=4, intermediate=3072 (96 experts/rank)
procs=4, intermediate=2048 (96 experts/rank)
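For readers who want to reproduce the numbers, the timing policy and metric formulas quoted in the benchmark setup can be sketched roughly as follows. This is not the harness from test_mega_moe.py: the function names are hypothetical, the NVLink term is read here as a byte count, and a real GPU measurement would also need a device synchronization inside the timed call.

```python
import time


def best_of(fn, iters=20, warmup=3, before_each=None):
    """Return the smallest wall-clock time of fn() in microseconds.

    before_each is an optional hook for per-iteration setup (for example an
    L2 flush and a buffer reset); it runs outside the timed region.
    """
    for _ in range(warmup):
        fn()
    best = float("inf")
    for _ in range(iters):
        if before_each is not None:
            before_each()
        start = time.perf_counter()
        fn()
        best = min(best, time.perf_counter() - start)
    return best * 1e6


def tflops(recv, hidden, intermediate, time_us):
    """TFLOPS = 2 * recv * H * I * 3 / time, with time given in microseconds."""
    flops = 2 * recv * hidden * intermediate * 3
    return flops / (time_us * 1e-6) / 1e12


def nvlink_gb_per_s(recv, hidden, time_us):
    """NVLink traffic recv * H * 3 (assumed to be bytes) divided by time."""
    return recv * hidden * 3 / (time_us * 1e-6) / 1e9


def ratio(split_us, fused_us):
    """Values below 1 mean the split path is faster."""
    return split_us / fused_us


# Toy usage with a CPU stand-in workload; the values carry no meaning.
elapsed_us = best_of(lambda: sum(range(1000)), iters=20, warmup=3)
```

Taking the minimum over iterations rather than the mean is what "best-of-20" implies; it reduces sensitivity to one-off stalls, at the cost of reporting an optimistic figure.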
Note: Without CUDA_HOME pointing at a 13.1+ toolkit, the split graph constructor throws
"Failed to load CUDA runtime API".
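A small pre-flight check along these lines could surface the problem before the constructor is reached. This is only a sketch: toolkit_problem and parse_version are hypothetical helpers, not part of this PR, and how the toolkit's version string is obtained is left to the caller.

```python
import os


def parse_version(text):
    """Turn a string such as '13.1' or '13.1.0' into a (major, minor) tuple."""
    parts = text.strip().split(".")
    major, minor = (parts + ["0"])[:2]  # treat a bare '13' as 13.0
    return int(major), int(minor)


def toolkit_problem(env, toolkit_version, minimum=(13, 1)):
    """Return a description of the problem, or None if the setup looks fine.

    env is a mapping such as os.environ; toolkit_version is the version
    string of the toolkit found under CUDA_HOME.
    """
    if not env.get("CUDA_HOME"):
        return "CUDA_HOME is not set"
    if parse_version(toolkit_version) < minimum:
        return (
            f"CUDA toolkit {toolkit_version} is older than "
            f"{minimum[0]}.{minimum[1]}"
        )
    return None


# Example call; the version string here is a placeholder supplied by hand.
problem = toolkit_problem(os.environ, "13.1")
```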