Draft: Add green-context split-kernel MegaMoE features #357

Open
RayWang96 wants to merge 1 commit into deepseek-ai:nv_dev from RayWang96:split_mega_moe

RayWang96 (Collaborator) commented Jun 12, 2026

Add green-context split-kernel MegaMoE (dispatch_l1_swiglu / l2_combine / combine_reduce)

Summary

This PR adds a split-kernel implementation of the FP8/FP4 MegaMoE alongside the existing fused
megakernel. Instead of one monolithic kernel that does everything, the MoE forward is decomposed
into three focused kernels that are wired into a single native CUDA graph:

| Kernel | SMs (default) | Responsibility |
|---|---|---|
| K1 dispatch_l1_swiglu | 96 | NVLink dispatch (gather routed tokens) → Linear1 GEMM → SwiGLU + per-token FP8 quant → write activated tokens to a shared HBM pool, stamping a per-block arrival mask. |
| K2 l2_combine | 52 | Consume pool blocks as K1 produces them (spin-wait on the HBM arrival mask) → Linear2 GEMM → NVLink combine-scatter of the partials. |
| K3 combine_reduce | 148 | After K1+K2 finish, reduce the per-expert combine partials with the top-k weights into the final output. |

K1 and K2 run concurrently on disjoint SM partitions carved with CUDA green contexts
(cudaGreenCtxCreate / cudaDevSmResourceSplit), coupled only through the HBM pool + arrival
masks. K3 runs on all SMs after the barrier. Everything replays as one CUDA-graph launch, so the
external API matches the fused path.
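
The K1 → K2 → K3 hand-off described above can be mimicked on the host with two threads and a flag array. This is a minimal sketch under stated assumptions, not the PR's CUDA code: the helper name run_split_pipeline, the scalar "tokens", and the arithmetic placeholders for each stage are hypothetical. Only the ordering follows the description: write the pool block, then stamp the arrival mask; spin-wait before consuming; reduce after both kernels finish.

```python
import threading
import time


def run_split_pipeline(tokens, weights, top_k):
    """Host-side toy model of the K1 -> K2 -> K3 hand-off (hypothetical helper)."""
    num_blocks = len(tokens) * top_k
    pool = [None] * num_blocks      # stand-in for the shared HBM pool
    arrival = [0] * num_blocks      # stand-in for the per-block arrival mask
    partials = [None] * num_blocks  # stand-in for the combine partials

    def k1_producer():
        for b in range(num_blocks):
            t, k = divmod(b, top_k)
            # Placeholder for dispatch + Linear1 + SwiGLU + FP8 quant.
            pool[b] = tokens[t] * 2.0 + k
            # Stamp the mask only after the block's data is in the pool.
            # On a GPU this step would also need suitable memory ordering;
            # the PR text does not say how that is done.
            arrival[b] = 1

    def k2_consumer():
        for b in range(num_blocks):
            # Spin-wait until K1 has stamped this block.
            while arrival[b] == 0:
                time.sleep(0)  # yield so the producer thread can run
            # Placeholder for Linear2 + combine-scatter.
            partials[b] = pool[b] + 1.0

    # Start the consumer first so the spin-wait path is actually exercised.
    consumer = threading.Thread(target=k2_consumer)
    producer = threading.Thread(target=k1_producer)
    consumer.start()
    producer.start()
    producer.join()
    consumer.join()  # plays the role of the barrier before K3

    # K3: weighted reduction of each token's top-k partials.
    return [
        sum(weights[t][k] * partials[t * top_k + k] for k in range(top_k))
        for t in range(len(tokens))
    ]


if __name__ == "__main__":
    out = run_split_pipeline(
        [1.0, 2.0, 3.0],
        [[0.5, 0.5], [0.25, 0.75], [1.0, 0.0]],
        top_k=2,
    )
    print(out)
```

The consumer never reads a pool slot before its flag is set, which is the only coupling between the two stages in this toy model; the two threads otherwise proceed at their own pace, as K1 and K2 do on their separate SM partitions.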

The fused fp8_fp4_mega_moe kernel is left completely untouched; the split path is a second,
parallel implementation. Its output is bitwise-identical to the fused output, and it is up to
~11% faster at large batch (ratio 0.887 at 8192 tokens, procs=8, intermediate=2048).

Why split it?

1. Maintainability

The fused megakernel interleaves dispatch, both GEMMs, activation, combine, and reduce inside one
persistent-CTA kernel that shares a single register / shared-memory / TMEM budget and a single
launch. That is fast but monolithic: tuning one stage (a GEMM tile, an occupancy target, the SM
count for L2) is entangled with every other stage. With three separate kernels, each stage has its
own resource budget and launch configuration and can be tuned in isolation.

2. Performance

Running K1 ‖ K2 concurrently overlaps the dispatch + Linear1 compute with Linear2 + combine,
which the monolithic kernel serializes inside its internal pipeline. At 4096 and 8192 tokens this
overlap plus the per-kernel occupancy headroom yields a net speedup in every swept configuration
(ratios 0.887 to 0.999) while remaining numerically identical. At 128 to 1024 tokens the split
path is 1.5% to 6.3% slower.

3. Reference for other architectures

The split design may also serve as a reference implementation for architectures such as Hopper.

Correctness

tests/test_mega_moe_split.py builds the same inputs for both paths, runs the fused megakernel as
the reference, replays the split graph, and asserts torch.equal(y_split, y_fused).
All 32 swept configurations are bitwise-identical (max_abs = 0).
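
The shape of that check can be sketched without a GPU. The snippet below is a stand-in only: the actual test calls torch.equal on tensors, whereas the hypothetical helper compare_outputs here works on flat Python lists and simply reports an exact-match flag together with max_abs.

```python
def compare_outputs(y_split, y_fused):
    """Return (exact_match, max_abs) for two flat lists of floats.

    Hypothetical stand-in for the tensor comparison in the real test.
    """
    if len(y_split) != len(y_fused):
        return False, float("inf")
    diffs = [abs(a - b) for a, b in zip(y_split, y_fused)]
    max_abs = max(diffs, default=0.0)
    # Exact element-wise equality: no tolerance is applied.
    exact = all(a == b for a, b in zip(y_split, y_fused))
    return exact, max_abs


if __name__ == "__main__":
    print(compare_outputs([1.0, 2.0], [1.0, 2.0]))
    print(compare_outputs([1.0, 2.0], [1.0, 2.0000001]))
```

An exact check like this rejects even a tiny perturbation that a tolerance-based comparison would accept, which is why max_abs = 0 is a stronger statement than "within tolerance".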

Benchmark results on B200

Sweep over token count × intermediate size × ranks. Fixed: hidden=7168, top-6 of 384 experts,
SM split K1=96 / K2=52 / K3=148, fast_math=1, clamp=10. Timing: wall-clock best-of-20 (3 warmup
iterations), with a 2 GiB L2 flush and buffer reset each iteration; fused and split are measured
by the same harness. Metrics use the existing test_mega_moe.py formulas, with H = hidden and
I = intermediate (TFLOPS = 2·recv·H·I·3; HBM = FP4 weights + acts + out; NVLink = recv·H·3).
ratio = split_us / fused_us (< 1 = split faster). F/S columns give fused / split values.
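
As a worked example of the TFLOPS, NVLink, and ratio formulas, the sketch below recomputes them for one row (8192 tokens, recv=48980, procs=8, intermediate=3072). The helper names derive_metrics and ratio are hypothetical, and the NVLink quantity recv·H·3 is assumed to be in bytes; under that assumption the results land close to the reported 2294 / 2377 TFLOPS and 373 / 387 GB/s. The HBM formula is not spelled out in enough detail here to reproduce, so it is omitted.

```python
from dataclasses import dataclass

HIDDEN = 7168  # fixed hidden size used by the sweep


@dataclass
class Metrics:
    tflops: float
    nvlink_gbs: float


def derive_metrics(recv, intermediate, time_us, hidden=HIDDEN):
    """Apply TFLOPS = 2*recv*H*I*3 and NVLink = recv*H*3 to one measurement."""
    seconds = time_us * 1e-6
    flops = 2 * recv * hidden * intermediate * 3
    nvlink_bytes = recv * hidden * 3  # assumption: the formula counts bytes
    return Metrics(
        tflops=flops / seconds / 1e12,
        nvlink_gbs=nvlink_bytes / seconds / 1e9,
    )


def ratio(split_us, fused_us):
    """ratio = split_us / fused_us; values below 1 mean the split path is faster."""
    return split_us / fused_us


if __name__ == "__main__":
    fused = derive_metrics(recv=48980, intermediate=3072, time_us=2820.4)
    split = derive_metrics(recv=48980, intermediate=3072, time_us=2722.0)
    print(f"TFLOPS  fused/split: {fused.tflops:.0f} / {split.tflops:.0f}")
    print(f"NVLink  fused/split: {fused.nvlink_gbs:.0f} / {split.nvlink_gbs:.0f} GB/s")
    print(f"ratio: {ratio(2722.0, 2820.4):.3f}")
```

Because both paths process the same recv for a given row, the TFLOPS and NVLink columns are just the ratio column restated in throughput units.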

procs=8, intermediate=3072 (48 experts/rank)

| tok | recv | Fused µs | Split µs | ratio | F/S TFLOPS | F/S HBM GB/s | F/S NVL GB/s |
|---|---|---|---|---|---|---|---|
| 16 | 102 | 304.3 | 305.2 | 1.003 | 44 / 44 | 4785 / 4772 | 7 / 7 |
| 32 | 184 | 320.9 | 325.8 | 1.015 | 76 / 75 | 4854 / 4781 | 12 / 12 |
| 64 | 375 | 326.4 | 334.1 | 1.024 | 152 / 148 | 4889 / 4776 | 25 / 24 |
| 128 | 764 | 341.0 | 357.7 | 1.049 | 296 / 282 | 4711 / 4491 | 48 / 46 |
| 512 | 3043 | 392.7 | 408.6 | 1.041 | 1024 / 984 | 4252 / 4086 | 167 / 160 |
| 1024 | 6022 | 492.8 | 508.7 | 1.032 | 1615 / 1564 | 3555 / 3444 | 263 / 255 |
| 4096 | 24818 | 1476.9 | 1454.5 | 0.985 | 2220 / 2254 | 1538 / 1562 | 361 / 367 |
| 8192 | 48980 | 2820.4 | 2722.0 | 0.965 | 2294 / 2377 | 1042 / 1080 | 373 / 387 |

procs=8, intermediate=2048 (48 experts/rank)

| tok | recv | Fused µs | Split µs | ratio | F/S TFLOPS | F/S HBM GB/s | F/S NVL GB/s |
|---|---|---|---|---|---|---|---|
| 16 | 94 | 229.4 | 226.8 | 0.989 | 36 / 37 | 3849 / 3893 | 9 / 9 |
| 32 | 182 | 245.7 | 241.7 | 0.984 | 65 / 66 | 4321 / 4392 | 16 / 16 |
| 64 | 364 | 246.8 | 245.4 | 0.994 | 130 / 131 | 4320 / 4344 | 32 / 32 |
| 128 | 769 | 269.5 | 278.5 | 1.034 | 251 / 243 | 3995 / 3866 | 61 / 59 |
| 512 | 3103 | 298.4 | 304.4 | 1.020 | 916 / 898 | 3808 / 3733 | 224 / 219 |
| 1024 | 6137 | 403.0 | 416.4 | 1.033 | 1341 / 1298 | 3013 / 2916 | 328 / 317 |
| 4096 | 24630 | 1199.4 | 1101.6 | 0.919 | 1809 / 1969 | 1407 / 1532 | 442 / 481 |
| 8192 | 48559 | 2307.8 | 2046.7 | 0.887 | 1853 / 2090 | 997 / 1124 | 453 / 510 |

procs=4, intermediate=3072 (96 experts/rank)

| tok | recv | Fused µs | Split µs | ratio | F/S TFLOPS | F/S HBM GB/s | F/S NVL GB/s |
|---|---|---|---|---|---|---|---|
| 16 | 101 | 405.8 | 403.0 | 0.993 | 33 / 33 | 5054 / 5088 | 5 / 5 |
| 32 | 192 | 519.8 | 522.2 | 1.005 | 49 / 49 | 5411 / 5386 | 8 / 8 |
| 64 | 393 | 565.1 | 576.1 | 1.019 | 92 / 90 | 5572 / 5466 | 15 / 15 |
| 128 | 765 | 573.4 | 587.4 | 1.025 | 176 / 172 | 5567 / 5434 | 29 / 28 |
| 512 | 3078 | 633.4 | 673.3 | 1.063 | 642 / 604 | 5141 / 4836 | 105 / 98 |
| 1024 | 6015 | 726.0 | 744.8 | 1.026 | 1095 / 1067 | 4597 / 4481 | 178 / 174 |
| 4096 | 24442 | 1726.4 | 1724.4 | 0.999 | 1871 / 1873 | 2228 / 2231 | 304 / 305 |
| 8192 | 49059 | 2827.8 | 2813.9 | 0.995 | 2292 / 2304 | 1601 / 1609 | 373 / 375 |

procs=4, intermediate=2048 (96 experts/rank)

| tok | recv | Fused µs | Split µs | ratio | F/S TFLOPS | F/S HBM GB/s | F/S NVL GB/s |
|---|---|---|---|---|---|---|---|
| 16 | 103 | 301.7 | 304.6 | 1.010 | 30 / 30 | 4680 / 4635 | 7 / 7 |
| 32 | 196 | 369.7 | 372.8 | 1.009 | 47 / 46 | 5136 / 5093 | 11 / 11 |
| 64 | 388 | 409.5 | 408.9 | 0.998 | 83 / 84 | 5186 / 5195 | 20 / 20 |
| 128 | 788 | 414.2 | 421.2 | 1.017 | 168 / 165 | 5152 / 5067 | 41 / 40 |
| 512 | 3126 | 463.6 | 484.8 | 1.046 | 594 / 568 | 4732 / 4526 | 145 / 139 |
| 1024 | 6143 | 526.1 | 534.2 | 1.015 | 1028 / 1013 | 4317 / 4252 | 251 / 247 |
| 4096 | 24647 | 1281.6 | 1261.9 | 0.985 | 1694 / 1720 | 2142 / 2175 | 414 / 420 |
| 8192 | 49388 | 2194.1 | 2133.1 | 0.972 | 1983 / 2039 | 1540 / 1584 | 484 / 498 |

Note: Unless CUDA_HOME points at a CUDA 13.1+ toolkit, the split graph constructor throws
"Failed to load CUDA runtime API".
