diff --git a/csrc/apis/mega.hpp b/csrc/apis/mega.hpp index aa47e40bf8..d55b479d45 100644 --- a/csrc/apis/mega.hpp +++ b/csrc/apis/mega.hpp @@ -3,12 +3,14 @@ #include #include #include +#include #if DG_TENSORMAP_COMPATIBLE #include "../jit/compiler.hpp" #endif #include "../jit/device_runtime.hpp" #include "../jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp" +#include "../jit_kernels/impls/sm100_fp8_fp4_mega_moe_split.hpp" namespace deep_gemm::mega { @@ -132,6 +134,123 @@ get_symm_buffer_size_for_mega_moe( return {reinterpret_cast(combine_token_buffer.get_end_ptr()), slice_input_buffers}; } +// Symmetric-buffer sizing for the SPLIT-kernel pipeline. Identical to the fused layout above +// except it reserves the route-based `SplitWorkspace` bookkeeping region (the split K1 dispatch +// pull metadata). Kept separate so the fused MegaMoE buffer layout is untouched. +static std::tuple(const torch::Tensor&)>> +get_symm_buffer_size_for_mega_moe_split( + const int& num_ranks, const int& num_experts, + const int& num_max_tokens_per_rank, const int& num_topk, + const int& hidden, const int& intermediate_hidden, + const bool& use_fp8_dispatch, const std::string& activation) { + DG_HOST_ASSERT(num_experts % num_ranks == 0); + DG_HOST_ASSERT(use_fp8_dispatch); + + // Workspace bytes (route-based split layout) + const auto workspace = layout::SplitWorkspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk); + + // Layouts + const auto fp8_token_layout = layout::Data(hidden); + const auto bf16_token_layout = layout::Data(hidden * 2); + const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden); + const auto fp8_sf_layout = layout::Data(hidden / 32); + const auto fp8_intermediate_sf_layout = layout::Data(intermediate_hidden / 32); + const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false); + const auto input_topk_weights_layout = layout::Data(num_topk * sizeof(float), false); + const auto l1_topk_weights_layout = layout::Data(sizeof(float), false); + + // Input buffers + const auto input_token_buffer = layout::Buffer( + fp8_token_layout, 1, num_max_tokens_per_rank, + workspace.get_end_ptr()); + const auto input_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, num_max_tokens_per_rank, + input_token_buffer.get_end_ptr()); + const auto input_topk_idx_buffer = layout::Buffer( + input_topk_idx_layout, 1, num_max_tokens_per_rank, + input_sf_buffer.get_end_ptr()); + const auto input_topk_weights_buffer = layout::Buffer( + input_topk_weights_layout, 1, num_max_tokens_per_rank, + input_topk_idx_buffer.get_end_ptr()); + + // Buffer configs + const auto num_max_pool_tokens = static_cast(workspace.num_max_pool_tokens); + int num_max_padded_sf_pool_tokens = 0; + for (int block_m: layout::kCandidateBlockM) { + num_max_padded_sf_pool_tokens = std::max( + num_max_padded_sf_pool_tokens, + layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m) + ); + } + + // L1 input buffer + const auto l1_token_buffer = layout::Buffer( + fp8_token_layout, 1, num_max_pool_tokens, + input_topk_weights_buffer.get_end_ptr()); + const auto l1_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, num_max_padded_sf_pool_tokens, + l1_token_buffer.get_end_ptr()); + const auto l1_topk_weights_buffer = layout::Buffer( + l1_topk_weights_layout, 1, num_max_pool_tokens, + l1_sf_buffer.get_end_ptr()); + + // L2 input buffer + const auto l2_token_buffer = layout::Buffer( + fp8_intermediate_token_layout, 1, num_max_pool_tokens, + l1_topk_weights_buffer.get_end_ptr()); + const auto l2_sf_buffer = layout::Buffer( + fp8_intermediate_sf_layout, 1, num_max_padded_sf_pool_tokens, + l2_token_buffer.get_end_ptr()); + + // Combine input buffer: BF16 tokens for cross-rank combine + const auto combine_token_buffer = layout::Buffer( + bf16_token_layout, num_topk, num_max_tokens_per_rank, + l2_sf_buffer.get_end_ptr()); + + // Check SF buffer requirements + DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0); + DG_HOST_ASSERT(num_max_padded_sf_pool_tokens % 4 == 0); + + auto slice_input_buffers = [=](const torch::Tensor& buffer) { + auto x = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_token_buffer.base)), + {num_max_tokens_per_rank, hidden}, + torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device())); + auto x_sf = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_sf_buffer.base)), + {num_max_tokens_per_rank, hidden / 128}, + torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); + auto topk_idx = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_topk_idx_buffer.base)), + {num_max_tokens_per_rank, num_topk}, + torch::TensorOptions().dtype(torch::kInt64).device(buffer.device())); + auto topk_weights = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_topk_weights_buffer.base)), + {num_max_tokens_per_rank, num_topk}, + torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device())); + auto l1_acts = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l1_token_buffer.base)), + {num_max_pool_tokens, hidden}, + torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device())); + auto l1_acts_sf = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l1_sf_buffer.base)), + {num_max_padded_sf_pool_tokens, hidden / 128}, + {1, num_max_padded_sf_pool_tokens}, + torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); + auto l2_acts = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_token_buffer.base)), + {num_max_pool_tokens, intermediate_hidden}, + torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device())); + auto l2_acts_sf = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_sf_buffer.base)), + {num_max_padded_sf_pool_tokens, intermediate_hidden / 128}, + {1, num_max_padded_sf_pool_tokens}, + torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); + return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf); + }; + return {reinterpret_cast(combine_token_buffer.get_end_ptr()), slice_input_buffers}; +} + static void fp8_fp4_mega_moe( const torch::Tensor& y, const std::tuple& l1_weights_tuple, @@ -230,7 +349,41 @@ static void register_apis(pybind11::module_& m) { #if DG_TENSORMAP_COMPATIBLE m.def("get_token_alignment_for_mega_moe", &get_token_alignment_for_mega_moe); m.def("get_symm_buffer_size_for_mega_moe", &get_symm_buffer_size_for_mega_moe); + m.def("get_symm_buffer_size_for_mega_moe_split", &get_symm_buffer_size_for_mega_moe_split); m.def("fp8_fp4_mega_moe", &fp8_fp4_mega_moe); + // Split-kernel MegaMoE: dispatch_l1_swiglu (K1) + l2_combine (K2) run concurrently on + // green-context SM partitions, combine_reduce (K3) reduces after both. The graph class + // exists on any CUDA (a stub throws below 13.1); see sm100_fp8_fp4_mega_moe_split.hpp. + pybind11::class_( + m, "SM100FP8FP4MegaMoESplitGraph") + .def(pybind11::init< + std::vector, + std::vector, + std::vector, + std::vector>, + std::vector, + std::vector, + std::vector, + std::vector, + std::vector, + const int&, + const int&, + const int&, + const int&, + const int&, + const int&, + const int&, + const float&, + const bool&, + const int&, + const int&, + const int&, + const int&, + const int&>()) + .def("replay", &SM100FP8FP4MegaMoESplitGraph::replay, + pybind11::call_guard()) + .def("get_green_context_ids", + &SM100FP8FP4MegaMoESplitGraph::get_green_context_ids); #endif } diff --git a/csrc/jit/handle.hpp b/csrc/jit/handle.hpp index be3bc31c07..c9ad787b45 100644 --- a/csrc/jit/handle.hpp +++ b/csrc/jit/handle.hpp @@ -3,7 +3,9 @@ #include #include #include +#include #include +#include #include "../utils/exception.hpp" #include "../utils/compatibility.hpp" @@ -20,6 +22,22 @@ static void* get_driver_handle() { return handle; } +static void* get_runtime_handle() { + static void* handle = nullptr; + if (handle == nullptr) { + if (const auto cuda_home = std::getenv("CUDA_HOME"); cuda_home != nullptr and cuda_home[0] != '\0') { + const auto runtime_path = std::string(cuda_home) + "/lib64/libcudart.so.13"; + handle = dlopen(runtime_path.c_str(), RTLD_LAZY | RTLD_LOCAL); + } + if (handle == nullptr) + handle = dlopen("libcudart.so.13", RTLD_LAZY | RTLD_LOCAL); + if (handle == nullptr) + handle = dlopen("libcudart.so", RTLD_LAZY | RTLD_LOCAL); + DG_HOST_ASSERT(handle != nullptr and "Failed to load CUDA runtime `libcudart.so`"); + } + return handle; +} + // Macro to define wrapper functions named `lazy_cu{API name}` #define DECL_LAZY_CUDA_DRIVER_FUNCTION(name) \ template \ @@ -33,6 +51,18 @@ static auto lazy_##name(Args&&... args) -> decltype(name(args...)) { \ return func(std::forward(args)...); \ } +#define DECL_LAZY_CUDA_RUNTIME_FUNCTION(name) \ +template \ +static auto lazy_##name(Args&&... args) -> decltype(name(args...)) { \ + using FuncType = decltype(&(name)); \ + static FuncType func = nullptr; \ + if (func == nullptr) { \ + func = reinterpret_cast(dlsym(get_runtime_handle(), #name)); \ + DG_HOST_ASSERT(func != nullptr and "Failed to load CUDA runtime API"); \ + } \ + return func(std::forward(args)...); \ +} + DECL_LAZY_CUDA_DRIVER_FUNCTION(cuGetErrorName); DECL_LAZY_CUDA_DRIVER_FUNCTION(cuGetErrorString); DECL_LAZY_CUDA_DRIVER_FUNCTION(cuFuncSetAttribute); @@ -45,6 +75,33 @@ DECL_LAZY_CUDA_DRIVER_FUNCTION(cuKernelGetFunction); DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLaunchKernelEx); DECL_LAZY_CUDA_DRIVER_FUNCTION(cuTensorMapEncodeTiled); +#if CUDART_VERSION >= 13010 +DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaDeviceGetExecutionCtx); +DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaDeviceGetDevResource); +DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaDevSmResourceSplit); +DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaDevResourceGenerateDesc); +DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaGreenCtxCreate); +DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaExecutionCtxGetId); +DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaExecutionCtxDestroy); +DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaGraphAddNode); +DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaGraphKernelNodeSetAttribute); +DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaGraphLaunch); +DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaGraphExecDestroy); +DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaGraphDestroy); + +static cudaError_t lazy_cudaGraphInstantiate(cudaGraphExec_t *phGraphExec, + cudaGraph_t graph, + unsigned long long flags) { + using FuncType = cudaError_t (CUDARTAPI *)(cudaGraphExec_t*, cudaGraph_t, unsigned long long); + static FuncType func = nullptr; + if (func == nullptr) { + func = reinterpret_cast(dlsym(get_runtime_handle(), "cudaGraphInstantiate")); + DG_HOST_ASSERT(func != nullptr and "Failed to load CUDA runtime API"); + } + return func(phGraphExec, graph, flags); +} +#endif + #if CUDART_VERSION >= 12080 and defined(DG_JIT_USE_RUNTIME_API) // Use CUDA runtime API diff --git a/csrc/jit_kernels/impls/sm100_fp8_fp4_mega_moe_split.hpp b/csrc/jit_kernels/impls/sm100_fp8_fp4_mega_moe_split.hpp new file mode 100644 index 0000000000..fc53dc0fce --- /dev/null +++ b/csrc/jit_kernels/impls/sm100_fp8_fp4_mega_moe_split.hpp @@ -0,0 +1,1184 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "runtime_utils.hpp" + +#include +#include +#include + +#include "../heuristics/mega_moe.hpp" + +namespace deep_gemm { + +class SM100FP8FP4MegaMoESplitDispatchL1SwigluRuntime final : + public LaunchRuntime { +public: + struct Args { + int num_max_tokens_per_rank; + int hidden, intermediate_hidden; + int num_experts, num_topk; + int num_ranks; + int num_sms; + float activation_clamp; + bool fast_math; + bool local_only; + MegaMoEConfig config; + + int* cumulative_local_expert_recv_stats; + int num_tokens; + layout::SymBuffer<> sym_buffer_ptrs; + + CUtensorMap tensor_map_l1_acts; + CUtensorMap tensor_map_l1_acts_sf; + CUtensorMap tensor_map_l1_weights; + CUtensorMap tensor_map_l1_weights_sf; + CUtensorMap tensor_map_l1_output; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast( + &mega_moe_split::sm100_fp8_fp4_mega_moe_split_dispatch_l1_swiglu_impl< + {}, + {}, {}, + {}, {}, + {}, + {}, {}, {}, + {}, + {}, {}, + {}, + {}, + {}, + {}, {}, {}, + {}, {}, + {}, + {}, + {} + >); +}}; +)", args.num_max_tokens_per_rank, + args.hidden, args.intermediate_hidden, + args.num_experts, args.num_topk, + args.config.num_experts_per_wave, + args.config.block_m, args.config.block_n, args.config.block_k, + args.config.store_block_m, + args.config.sf_block_m, args.config.sf_block_n, + args.config.num_max_pool_tokens, + args.config.num_padded_sf_pool_tokens, + args.config.num_stages, + args.config.num_dispatch_threads, args.config.num_non_epilogue_threads, args.config.num_epilogue_threads, + args.num_sms, args.num_ranks, + to_string(args.activation_clamp), + args.fast_math ? "true" : "false", + args.local_only ? "true" : "false"); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.cumulative_local_expert_recv_stats, + args.num_tokens, + args.sym_buffer_ptrs, + args.tensor_map_l1_acts, + args.tensor_map_l1_acts_sf, + args.tensor_map_l1_weights, + args.tensor_map_l1_weights_sf, + args.tensor_map_l1_output + )); + } +}; + +class SM100FP8FP4MegaMoESplitL2CombineRuntime final : + public LaunchRuntime { +public: + struct Args { + int num_max_tokens_per_rank; + int hidden, intermediate_hidden; + int num_experts, num_topk; + int num_ranks; + int kernel1_sms, kernel2_sms; + MegaMoEConfig config; + + void* state; + uint32_t num_work_iters; + layout::SymBuffer<> sym_buffer_ptrs; + CUtensorMap tensor_map_l2_acts; + CUtensorMap tensor_map_l2_acts_sf; + CUtensorMap tensor_map_l2_weights; + CUtensorMap tensor_map_l2_weights_sf; + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast( + &mega_moe_split::sm100_fp8_fp4_mega_moe_split_l2_combine_impl< + {}, + {}, {}, + {}, {}, + {}, {}, {}, + {}, + {}, {}, + {}, + {}, + {}, + {}, + {}, + {}, {}, {} + >); +}}; +)", args.num_max_tokens_per_rank, + args.hidden, args.intermediate_hidden, + args.num_experts, args.num_topk, + args.config.block_m, args.config.block_n, args.config.block_k, + args.config.store_block_m, + args.config.sf_block_m, args.config.sf_block_n, + args.config.num_max_pool_tokens, + args.config.num_padded_sf_pool_tokens, + args.config.num_stages, + args.config.num_non_epilogue_threads, + args.config.num_epilogue_threads, + args.kernel1_sms, args.kernel2_sms, args.num_ranks); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.state, + args.num_work_iters, + args.sym_buffer_ptrs, + args.tensor_map_l2_acts, + args.tensor_map_l2_acts_sf, + args.tensor_map_l2_weights, + args.tensor_map_l2_weights_sf + )); + } +}; + +class SM100FP8FP4MegaMoESplitCombineReduceRuntime final : + public LaunchRuntime { +public: + struct Args { + int num_max_tokens_per_rank; + int hidden, intermediate_hidden; + int num_experts, num_topk; + int num_ranks; + MegaMoEConfig config; + + void* y; + void* state; + uint32_t num_tokens; + layout::SymBuffer<> sym_buffer_ptrs; + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast( + &mega_moe_split::sm100_fp8_fp4_mega_moe_split_combine_reduce_impl< + {}, + {}, {}, + {}, {}, + {}, + {}, + 512 + >); +}}; +)", args.num_max_tokens_per_rank, + args.hidden, args.intermediate_hidden, + args.num_experts, args.num_topk, + args.config.num_padded_sf_pool_tokens, + args.num_ranks); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.y, + args.state, + args.num_tokens, + args.sym_buffer_ptrs + )); + } +}; + +static void check_split_skeleton_tensor(const torch::Tensor& tensor, const at::ScalarType& dtype) { + DG_HOST_ASSERT(tensor.is_cuda()); + DG_HOST_ASSERT(tensor.is_contiguous()); + DG_HOST_ASSERT(tensor.scalar_type() == dtype); +} + +// SMEM layout / pipeline depth for the split dispatch_l1_swiglu (K1) kernel. K1 uses a +// route-based dispatch with a per-dispatch-warp token send/pull buffer in SMEM, so its SMEM +// formula differs from the fused megamoe heuristic (which sizes a pull buffer instead). This +// mirrors the K1 kernel's own SMEM partitioning exactly. +static std::pair get_mega_moe_split_kernel1_pipeline( + const int& smem_capacity, + const int& num_experts, const int& hidden, + const int& block_m, const int& block_n, const int& block_k, const int& store_block_m, + const int& sf_block_m, const int& sf_block_n, const int& gran_k, + const int& num_dispatch_warps, const int& num_epilogue_warps) { + constexpr int kSmemAlignment = 1024; + constexpr int kNumEpilogueStages = 2; + constexpr int kNumTMAStoreStages = 2; + + const int load_block_m = block_m / 2; + + // Dispatch region: expert counts + per-dispatch-warp token send/pull buffers + const int smem_expert_count_size = align( + num_experts * static_cast(sizeof(uint32_t)), kSmemAlignment); + const int smem_send_buffers_size = align( + static_cast(layout::Buffer(layout::Data(hidden), num_dispatch_warps, 1).get_num_bytes()), + kSmemAlignment); + const int smem_dispatch_size = smem_expert_count_size + smem_send_buffers_size; + + // C/D output region: max of L1 FP8 (2 TMA stages, BLOCK_N/2 post-SwiGLU) and L2 BF16 (1 stage) + const auto num_epilogue_warpgroups = num_epilogue_warps / 4; + const int smem_cd_l1 = num_epilogue_warpgroups * store_block_m * (block_n / 2) * kNumTMAStoreStages; + const int smem_cd_l2 = num_epilogue_warpgroups * store_block_m * block_n * static_cast(sizeof(nv_bfloat16)); + const int smem_cd = std::max(smem_cd_l1, smem_cd_l2); + + // Barriers: dispatch + tensor-memory full/empty + combine (2 per epilogue warp) + const int smem_barriers = (num_dispatch_warps + kNumEpilogueStages * 2 + num_epilogue_warps * 2) * 8; + + // Amax reduction + tensor-memory pointer + const int smem_amax_reduction = store_block_m * num_epilogue_warps * static_cast(sizeof(float)); + const int smem_tmem_ptr = 4; + + const int smem_sfa_per_stage = sf_block_m * (block_k / gran_k); + const int smem_sfb_per_stage = sf_block_n * (block_k / gran_k); + const int smem_per_stage = load_block_m * block_k + block_n * block_k + smem_sfa_per_stage + smem_sfb_per_stage + 2 * 8; + const int smem_fixed = smem_dispatch_size + smem_cd + smem_amax_reduction + smem_barriers + smem_tmem_ptr; + + const int num_stages = (smem_capacity - smem_fixed) / smem_per_stage; + DG_HOST_ASSERT(num_stages >= 2); + return {num_stages, smem_fixed + num_stages * smem_per_stage}; +} + +static MegaMoEConfig get_mega_moe_split_kernel1_config( + const int& num_ranks, + const int& num_experts, + const int& num_experts_per_rank, + const int& num_max_tokens_per_rank, + const int& num_tokens, + const int& num_topk, + const int& hidden, + const int& intermediate_hidden, + const int& num_padded_sf_pool_tokens, + const int& num_sms +) { + auto config = get_mega_moe_config( + num_ranks, num_experts, num_experts_per_rank, + num_max_tokens_per_rank, num_tokens, num_topk, + hidden, intermediate_hidden, num_padded_sf_pool_tokens); + config.num_experts_per_wave = get_num_experts_per_wave_for_mega_moe( + num_experts_per_rank, num_tokens, num_topk, + intermediate_hidden, config.block_m, config.block_n, num_sms); + // Override the pipeline depth / SMEM size with the split-K1 layout (the fused heuristic + // sizes a different dispatch region, which under-allocates K1's send buffers). + constexpr int kGranK = 32; + const auto [num_stages, smem_size] = get_mega_moe_split_kernel1_pipeline( + SM100ArchSpec::smem_capacity, + num_experts, hidden, + config.block_m, config.block_n, config.block_k, config.store_block_m, + config.sf_block_m, config.sf_block_n, kGranK, + config.num_dispatch_threads / 32, config.num_epilogue_threads / 32); + config.num_stages = num_stages; + config.smem_size = smem_size; + return config; +} + +static void print_mega_moe_split_kernel1_config(const SM100FP8FP4MegaMoESplitDispatchL1SwigluRuntime::Args& args) { + if (not (get_env("DG_PRINT_SPLIT_K1_CONFIG") or + get_env("DG_PRINT_CONFIGS") or + get_env("DG_JIT_DEBUG"))) { + return; + } + + const auto& config = args.config; + const auto key = fmt::format( + "split_k1:num_max_tokens_per_rank={},num_tokens={},hidden={},intermediate_hidden={}," + "num_experts={},num_topk={},num_ranks={},num_sms={},activation_clamp={},fast_math={}," + "local_only={},block_m={},block_n={},block_k={},store_block_m={},num_stages={}", + args.num_max_tokens_per_rank, args.num_tokens, args.hidden, args.intermediate_hidden, + args.num_experts, args.num_topk, args.num_ranks, args.num_sms, + to_string(args.activation_clamp), args.fast_math, args.local_only, + config.block_m, config.block_n, config.block_k, config.store_block_m, config.num_stages); + + static std::unordered_set printed; + if (printed.count(key) > 0) + return; + printed.insert(key); + + const auto num_threads = config.num_dispatch_threads + + config.num_non_epilogue_threads + config.num_epilogue_threads; + std::cout + << "\n" + << "SM100 FP8/FP4 MegaMoE Split K1 config\n" + << " > kernel: sm100_fp8_fp4_mega_moe_split_dispatch_l1_swiglu_impl\n" + << " > problem: num_tokens=" << args.num_tokens + << ", num_max_tokens_per_rank=" << args.num_max_tokens_per_rank + << ", hidden=" << args.hidden + << ", intermediate_hidden=" << args.intermediate_hidden + << ", num_experts=" << args.num_experts + << ", num_topk=" << args.num_topk + << ", num_ranks=" << args.num_ranks << "\n" + << " > runtime: num_sms=" << args.num_sms + << ", activation_clamp=" << args.activation_clamp + << ", fast_math=" << args.fast_math + << ", local_only=" << args.local_only << "\n" + << " > launch: grid=(" << args.launch_args.grid_dim.first + << ", " << args.launch_args.grid_dim.second + << "), threads=" << args.launch_args.num_threads + << ", smem=" << args.launch_args.smem_size + << " bytes, cluster_dim=" << args.launch_args.cluster_dim + << ", enable_pdl=" << args.launch_args.enable_pdl << "\n" + << " > thread layout: dispatch=" << config.num_dispatch_threads + << ", non_epilogue=" << config.num_non_epilogue_threads + << ", epilogue=" << config.num_epilogue_threads + << ", total=" << num_threads << "\n" + << " > tiles: block_m=" << config.block_m + << ", block_n=" << config.block_n + << ", block_k=" << config.block_k + << ", load_block_m=" << config.load_block_m + << ", load_block_n=" << config.load_block_n + << ", store_block_m=" << config.store_block_m << "\n" + << " > scale-factor tiles: sf_block_m=" << config.sf_block_m + << ", sf_block_n=" << config.sf_block_n + << ", num_padded_sf_pool_tokens=" << config.num_padded_sf_pool_tokens << "\n" + << " > pool/waves: num_max_pool_tokens=" << config.num_max_pool_tokens + << ", num_experts_per_wave=" << config.num_experts_per_wave << "\n" + << " > swizzle: acts=" << config.swizzle_acts_mode + << ", weights=" << config.swizzle_weights_mode << "\n" + << " > pipeline: num_stages=" << config.num_stages + << ", smem_size=" << config.smem_size << " bytes\n" + << " > template compact: <" + << args.num_max_tokens_per_rank << ", " + << args.hidden << ", " << args.intermediate_hidden << ", " + << args.num_experts << ", " << args.num_topk << ", " + << config.num_experts_per_wave << ", " + << config.block_m << ", " << config.block_n << ", " << config.block_k << ", " + << config.store_block_m << ", " + << config.sf_block_m << ", " << config.sf_block_n << ", " + << config.num_max_pool_tokens << ", " + << config.num_padded_sf_pool_tokens << ", " + << config.num_stages << ", " + << config.num_dispatch_threads << ", " << config.num_non_epilogue_threads << ", " + << config.num_epilogue_threads << ", " + << args.num_sms << ", " << args.num_ranks << ", " + << to_string(args.activation_clamp) << ", " + << (args.fast_math ? "true" : "false") << ", " + << (args.local_only ? "true" : "false") << ">\n" + << " > template parameters:\n" + << " [00] kNumMaxTokensPerRank = " << args.num_max_tokens_per_rank << "\n" + << " [01] kHidden = " << args.hidden << "\n" + << " [02] kIntermediateHidden = " << args.intermediate_hidden << "\n" + << " [03] kNumExperts = " << args.num_experts << "\n" + << " [04] kNumTopk = " << args.num_topk << "\n" + << " [05] kNumExpertsPerWave = " << config.num_experts_per_wave << "\n" + << " [06] BLOCK_M = " << config.block_m << "\n" + << " [07] BLOCK_N = " << config.block_n << "\n" + << " [08] BLOCK_K = " << config.block_k << "\n" + << " [09] STORE_BLOCK_M = " << config.store_block_m << "\n" + << " [10] SF_BLOCK_M = " << config.sf_block_m << "\n" + << " [11] SF_BLOCK_N = " << config.sf_block_n << "\n" + << " [12] kNumMaxPoolTokens = " << config.num_max_pool_tokens << "\n" + << " [13] kNumPaddedSFPoolTokens = " << config.num_padded_sf_pool_tokens << "\n" + << " [14] kNumStages = " << config.num_stages << "\n" + << " [15] kNumDispatchThreads = " << config.num_dispatch_threads << "\n" + << " [16] kNumNonEpilogueThreads = " << config.num_non_epilogue_threads << "\n" + << " [17] kNumEpilogueThreads = " << config.num_epilogue_threads << "\n" + << " [18] kNumSMs = " << args.num_sms << "\n" + << " [19] kNumRanks = " << args.num_ranks << "\n" + << " [20] kActivationClamp = " << to_string(args.activation_clamp) + << " (" << args.activation_clamp << ")\n" + << " [21] kFastMath = " << (args.fast_math ? "true" : "false") << "\n" + << " [22] kLocalOnly = " << (args.local_only ? "true" : "false") << "\n" + << " > derived template defaults:\n" + << " L1_SHAPE_N = " << args.intermediate_hidden * 2 + << " (kIntermediateHidden * 2)\n" + << " L1_SHAPE_K = " << args.hidden + << " (kHidden)\n" + << " L2_SHAPE_N = " << args.hidden + << " (kHidden)\n" + << " L2_SHAPE_K = " << args.intermediate_hidden + << " (kIntermediateHidden)\n" + << " kNumDispatchWarps = " << config.num_dispatch_threads / 32 << "\n" + << " kNumMMANonEpilogueWarps = " << config.num_non_epilogue_threads / 32 << "\n" + << " kNumEpilogueWarps = " << config.num_epilogue_threads / 32 << "\n" + << " kNumEpilogueWarpgroups = " << config.num_epilogue_threads / 128 << "\n" + << " kNumThreads = " << num_threads << "\n" + << " kNumTokensPerWarp = " << 32 / args.num_topk << "\n" + << " kNumExpertsPerRank = " << args.num_experts / args.num_ranks << "\n" + << std::endl; +} + +static MegaMoEConfig get_mega_moe_split_kernel2_config( + MegaMoEConfig config +) { + constexpr int kNumEpilogueStages = 2; + constexpr int kNumMaxStages = 32; + constexpr int kGranK = 32; + + const int load_block_m = config.block_m / 2; + const int num_epilogue_warpgroups = config.num_epilogue_threads / 128; + const int smem_cd_l2 = num_epilogue_warpgroups * config.store_block_m * config.block_n * + static_cast(sizeof(nv_bfloat16)); + const int smem_sfa_per_stage = config.sf_block_m * (config.block_k / kGranK); + const int smem_sfb_per_stage = config.sf_block_n * (config.block_k / kGranK); + const int smem_per_stage = + load_block_m * config.block_k + + config.block_n * config.block_k + + smem_sfa_per_stage + smem_sfb_per_stage + + 2 * static_cast(sizeof(uint64_t)); + const int smem_fixed = + smem_cd_l2 + + kNumEpilogueStages * 2 * static_cast(sizeof(uint64_t)) + + static_cast(sizeof(uint32_t)); + const int num_stages = std::min( + (SM100ArchSpec::smem_capacity - smem_fixed) / smem_per_stage, + kNumMaxStages); + DG_HOST_ASSERT(num_stages >= 2); + + config.num_stages = num_stages; + config.smem_size = smem_fixed + num_stages * smem_per_stage; + return config; +} + +#if CUDART_VERSION >= 13010 + +class SM100FP8FP4MegaMoESplitGraph final { +private: + struct BufferViews { + torch::Tensor l1_acts; + torch::Tensor l1_acts_sf; + torch::Tensor l2_acts; + torch::Tensor l2_acts_sf; + }; + + struct Kernel1NodeArgs { + int* cumulative_local_expert_recv_stats; + uint32_t num_tokens; + layout::SymBuffer<> sym_buffer_ptrs; + CUtensorMap tensor_map_l1_acts; + CUtensorMap tensor_map_l1_acts_sf; + CUtensorMap tensor_map_l1_weights; + CUtensorMap tensor_map_l1_weights_sf; + CUtensorMap tensor_map_l1_output; + + Kernel1NodeArgs() = default; + }; + + struct Kernel2NodeArgs { + void* state; + uint32_t num_work_iters; + layout::SymBuffer<> sym_buffer_ptrs; + CUtensorMap tensor_map_l2_acts; + CUtensorMap tensor_map_l2_acts_sf; + CUtensorMap tensor_map_l2_weights; + CUtensorMap tensor_map_l2_weights_sf; + + Kernel2NodeArgs() = default; + }; + + struct Kernel3NodeArgs { + void* y; + void* state; + uint32_t num_tokens; + layout::SymBuffer<> sym_buffer_ptrs; + + Kernel3NodeArgs() = default; + + Kernel3NodeArgs( + void* y, + void* state, + const uint32_t& num_tokens, + const layout::SymBuffer<>& sym_buffer_ptrs + ) : y(y), + state(state), + num_tokens(num_tokens), + sym_buffer_ptrs(sym_buffer_ptrs) {} + }; + + std::vector states_; + std::vector ys_; + std::vector sym_buffers_; + std::vector> sym_buffer_ptrs_; + std::vector l1_weights_; + std::vector l1_weights_sf_; + std::vector l2_weights_; + std::vector l2_weights_sf_; + std::vector stats_; + + uint32_t rank_idx_ = 0; + uint32_t num_max_tokens_per_rank_ = 0; + uint32_t num_experts_ = 0; + uint32_t num_experts_per_rank_ = 0; + uint32_t num_topk_ = 0; + uint32_t num_tokens_ = 0; + uint32_t hidden_ = 0; + uint32_t intermediate_hidden_ = 0; + uint32_t kernel1_sms_ = 0; + uint32_t kernel2_sms_ = 0; + uint32_t reduce_sms_ = 0; + uint32_t kernel2_work_iters_ = 0; + uint32_t reduce_work_iters_ = 0; + float activation_clamp_ = 0.0f; + bool fast_math_ = true; + MegaMoEConfig config_; + MegaMoEConfig kernel2_config_; + + std::shared_ptr kernel1_runtime_; + std::shared_ptr kernel2_runtime_; + std::shared_ptr kernel3_runtime_; + KernelHandle kernel1_graph_kernel_ = nullptr; + KernelHandle kernel2_graph_kernel_ = nullptr; + KernelHandle kernel3_graph_kernel_ = nullptr; + + int device_idx_ = 0; + cudaExecutionContext_t primary_context_ = nullptr; + std::array green_contexts_ = {nullptr, nullptr}; + std::array green_context_ids_ = {0, 0}; + + cudaGraph_t graph_ = nullptr; + cudaGraphExec_t graph_exec_ = nullptr; + std::vector kernel1_nodes_; + std::vector kernel2_nodes_; + std::vector reduce_nodes_; + + std::vector kernel1_args_; + std::vector kernel2_args_; + std::vector kernel3_args_; + std::vector> kernel1_params_; + std::vector> kernel2_params_; + std::vector> kernel3_params_; + + static uint32_t checked_nonnegative_u32(const int& value) { + DG_HOST_ASSERT(value >= 0); + return static_cast(value); + } + + static uint32_t checked_positive_u32(const int& value) { + DG_HOST_ASSERT(value > 0); + return static_cast(value); + } + + static int get_num_padded_sf_pool_tokens( + const int& num_max_pool_tokens + ) { + int num_padded_sf_pool_tokens = 0; + for (int block_m: layout::kCandidateBlockM) { + num_padded_sf_pool_tokens = std::max( + num_padded_sf_pool_tokens, + layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m)); + } + return num_padded_sf_pool_tokens; + } + + BufferViews slice_buffer(const torch::Tensor& buffer) const { + const auto workspace = layout::SplitWorkspace( + nullptr, num_ranks(), num_experts_, num_max_tokens_per_rank_, num_topk_); + const auto fp8_token_layout = layout::Data(hidden_); + const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden_); + const auto fp8_sf_layout = layout::Data(hidden_ / 32); + const auto fp8_intermediate_sf_layout = layout::Data(intermediate_hidden_ / 32); + const auto input_topk_idx_layout = layout::Data(num_topk_ * sizeof(int64_t), false); + const auto input_topk_weights_layout = layout::Data(num_topk_ * sizeof(float), false); + const auto l1_topk_weights_layout = layout::Data(sizeof(float), false); + + const auto input_token_buffer = layout::Buffer( + fp8_token_layout, 1, num_max_tokens_per_rank_, + workspace.get_end_ptr()); + const auto input_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, num_max_tokens_per_rank_, + input_token_buffer.get_end_ptr()); + const auto input_topk_idx_buffer = layout::Buffer( + input_topk_idx_layout, 1, num_max_tokens_per_rank_, + input_sf_buffer.get_end_ptr()); + const auto input_topk_weights_buffer = layout::Buffer( + input_topk_weights_layout, 1, num_max_tokens_per_rank_, + input_topk_idx_buffer.get_end_ptr()); + const auto l1_token_buffer = layout::Buffer( + fp8_token_layout, 1, config_.num_max_pool_tokens, + input_topk_weights_buffer.get_end_ptr()); + const auto l1_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, config_.num_padded_sf_pool_tokens, + l1_token_buffer.get_end_ptr()); + const auto l1_topk_weights_buffer = layout::Buffer( + l1_topk_weights_layout, 1, config_.num_max_pool_tokens, + l1_sf_buffer.get_end_ptr()); + const auto l2_token_buffer = layout::Buffer( + fp8_intermediate_token_layout, 1, config_.num_max_pool_tokens, + l1_topk_weights_buffer.get_end_ptr()); + const auto l2_sf_buffer = layout::Buffer( + fp8_intermediate_sf_layout, 1, config_.num_padded_sf_pool_tokens, + l2_token_buffer.get_end_ptr()); + + return { + torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l1_token_buffer.base)), + {config_.num_max_pool_tokens, hidden_}, + torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device())), + torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l1_sf_buffer.base)), + {config_.num_padded_sf_pool_tokens, hidden_ / 128}, + {1, config_.num_padded_sf_pool_tokens}, + torch::TensorOptions().dtype(torch::kInt).device(buffer.device())), + torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_token_buffer.base)), + {config_.num_max_pool_tokens, intermediate_hidden_}, + torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device())), + torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_sf_buffer.base)), + {config_.num_padded_sf_pool_tokens, intermediate_hidden_ / 128}, + {1, config_.num_padded_sf_pool_tokens}, + torch::TensorOptions().dtype(torch::kInt).device(buffer.device())) + }; + } + + uint32_t num_ranks() const { + DG_HOST_ASSERT(not sym_buffer_ptrs_.empty()); + return static_cast(sym_buffer_ptrs_[0].size()); + } + + void check_inputs() { + DG_HOST_ASSERT(not states_.empty()); + DG_HOST_ASSERT(states_.size() == ys_.size()); + DG_HOST_ASSERT(states_.size() == sym_buffers_.size()); + DG_HOST_ASSERT(states_.size() == sym_buffer_ptrs_.size()); + DG_HOST_ASSERT(states_.size() == l1_weights_.size()); + DG_HOST_ASSERT(states_.size() == l1_weights_sf_.size()); + DG_HOST_ASSERT(states_.size() == l2_weights_.size()); + DG_HOST_ASSERT(states_.size() == l2_weights_sf_.size()); + DG_HOST_ASSERT(states_.size() == stats_.size()); + DG_HOST_ASSERT(num_experts_ % num_ranks() == 0); + DG_HOST_ASSERT(num_experts_per_rank_ == num_experts_ / num_ranks()); + DG_HOST_ASSERT(num_tokens_ <= num_max_tokens_per_rank_); + DG_HOST_ASSERT(hidden_ % 128 == 0 and intermediate_hidden_ % 128 == 0); + DG_HOST_ASSERT(kernel1_sms_ % 2 == 0 and kernel2_sms_ % 2 == 0); + + for (uint32_t buffer_idx = 0; buffer_idx < states_.size(); ++buffer_idx) { + check_split_skeleton_tensor(states_[buffer_idx], torch::kInt); + DG_HOST_ASSERT(states_[buffer_idx].numel() >= 7); + DG_HOST_ASSERT(ys_[buffer_idx].is_cuda()); + DG_HOST_ASSERT(ys_[buffer_idx].is_contiguous()); + DG_HOST_ASSERT(ys_[buffer_idx].nbytes() % 4 == 0); + DG_HOST_ASSERT(sym_buffers_[buffer_idx].is_cuda()); + DG_HOST_ASSERT(sym_buffers_[buffer_idx].is_contiguous()); + DG_HOST_ASSERT(l1_weights_[buffer_idx].is_cuda()); + DG_HOST_ASSERT(l1_weights_[buffer_idx].is_contiguous()); + DG_HOST_ASSERT(l1_weights_sf_[buffer_idx].is_cuda()); + DG_HOST_ASSERT(l2_weights_[buffer_idx].is_cuda()); + DG_HOST_ASSERT(l2_weights_[buffer_idx].is_contiguous()); + DG_HOST_ASSERT(l2_weights_sf_[buffer_idx].is_cuda()); + DG_HOST_ASSERT(stats_[buffer_idx].is_cuda()); + DG_HOST_ASSERT(stats_[buffer_idx].is_contiguous()); + DG_HOST_ASSERT(stats_[buffer_idx].scalar_type() == torch::kInt); + DG_HOST_ASSERT(stats_[buffer_idx].numel() == num_experts_per_rank_); + } + } + + void build_kernel_runtimes() { + const SM100FP8FP4MegaMoESplitDispatchL1SwigluRuntime::Args kernel1_args = { + .num_max_tokens_per_rank = static_cast(num_max_tokens_per_rank_), + .hidden = static_cast(hidden_), + .intermediate_hidden = static_cast(intermediate_hidden_), + .num_experts = static_cast(num_experts_), + .num_topk = static_cast(num_topk_), + .num_ranks = static_cast(num_ranks()), + .num_sms = static_cast(kernel1_sms_), + .activation_clamp = activation_clamp_, + .fast_math = fast_math_, + .local_only = false, + .config = config_, + .cumulative_local_expert_recv_stats = nullptr, + .num_tokens = static_cast(num_tokens_), + .sym_buffer_ptrs = layout::SymBuffer<>(), + .tensor_map_l1_acts = {}, + .tensor_map_l1_acts_sf = {}, + .tensor_map_l1_weights = {}, + .tensor_map_l1_weights_sf = {}, + .tensor_map_l1_output = {}, + .launch_args = LaunchArgs(static_cast(kernel1_sms_), + config_.num_dispatch_threads + config_.num_non_epilogue_threads + config_.num_epilogue_threads, + config_.smem_size, 2) + }; + const SM100FP8FP4MegaMoESplitL2CombineRuntime::Args real_kernel2_args = { + .num_max_tokens_per_rank = static_cast(num_max_tokens_per_rank_), + .hidden = static_cast(hidden_), + .intermediate_hidden = static_cast(intermediate_hidden_), + .num_experts = static_cast(num_experts_), + .num_topk = static_cast(num_topk_), + .num_ranks = static_cast(num_ranks()), + .kernel1_sms = static_cast(kernel1_sms_), + .kernel2_sms = static_cast(kernel2_sms_), + .config = kernel2_config_, + .state = nullptr, + .num_work_iters = kernel2_work_iters_, + .sym_buffer_ptrs = layout::SymBuffer<>(), + .tensor_map_l2_acts = {}, + .tensor_map_l2_acts_sf = {}, + .tensor_map_l2_weights = {}, + .tensor_map_l2_weights_sf = {}, + .launch_args = LaunchArgs( + static_cast(kernel2_sms_), + kernel2_config_.num_non_epilogue_threads + kernel2_config_.num_epilogue_threads, + kernel2_config_.smem_size, 2) + }; + const SM100FP8FP4MegaMoESplitCombineReduceRuntime::Args kernel3_args = { + .num_max_tokens_per_rank = static_cast(num_max_tokens_per_rank_), + .hidden = static_cast(hidden_), + .intermediate_hidden = static_cast(intermediate_hidden_), + .num_experts = static_cast(num_experts_), + .num_topk = static_cast(num_topk_), + .num_ranks = static_cast(num_ranks()), + .config = config_, + .y = nullptr, + .state = nullptr, + .num_tokens = static_cast(num_tokens_), + .sym_buffer_ptrs = layout::SymBuffer<>(), + .launch_args = LaunchArgs(static_cast(num_tokens_), 512, 0, 1, false) + }; + + print_mega_moe_split_kernel1_config(kernel1_args); + const auto kernel1_code = SM100FP8FP4MegaMoESplitDispatchL1SwigluRuntime::generate(kernel1_args); + const auto kernel2_code = SM100FP8FP4MegaMoESplitL2CombineRuntime::generate(real_kernel2_args); + const auto kernel3_code = SM100FP8FP4MegaMoESplitCombineReduceRuntime::generate(kernel3_args); + kernel1_runtime_ = compiler->build("sm100_fp8_fp4_mega_moe_split_dispatch_l1_swiglu", kernel1_code); + kernel2_runtime_ = compiler->build("sm100_fp8_fp4_mega_moe_split_l2_combine", kernel2_code); + kernel3_runtime_ = compiler->build("sm100_fp8_fp4_mega_moe_split_combine_reduce", kernel3_code); + + kernel1_graph_kernel_ = kernel1_runtime_->kernel; + kernel2_graph_kernel_ = kernel2_runtime_->kernel; + kernel3_graph_kernel_ = kernel3_runtime_->kernel; + } + + void create_green_contexts() { + DG_CUDA_RUNTIME_CHECK(cudaGetDevice(&device_idx_)); + DG_CUDA_RUNTIME_CHECK(cudaSetDevice(device_idx_)); + DG_CUDA_RUNTIME_CHECK(lazy_cudaDeviceGetExecutionCtx(&primary_context_, device_idx_)); + DG_HOST_ASSERT(primary_context_ != nullptr); + + cudaDevResource sm_resource = {}; + DG_CUDA_RUNTIME_CHECK(lazy_cudaDeviceGetDevResource(device_idx_, &sm_resource, cudaDevResourceTypeSm)); + DG_HOST_ASSERT(sm_resource.type == cudaDevResourceTypeSm); + DG_HOST_ASSERT(kernel1_sms_ + kernel2_sms_ <= sm_resource.sm.smCount); + + cudaDevResource workqueue_resource = {}; + DG_CUDA_RUNTIME_CHECK(lazy_cudaDeviceGetDevResource( + device_idx_, &workqueue_resource, cudaDevResourceTypeWorkqueueConfig)); + DG_HOST_ASSERT(workqueue_resource.type == cudaDevResourceTypeWorkqueueConfig); + workqueue_resource.wqConfig.sharingScope = cudaDevWorkqueueConfigScopeGreenCtxBalanced; + + std::array split_resources = {}; + cudaDevResource remainder = {}; + std::array group_params = {}; + group_params[0].smCount = kernel1_sms_; + group_params[0].coscheduledSmCount = 2; + group_params[1].smCount = kernel2_sms_; + group_params[1].coscheduledSmCount = 2; + DG_CUDA_RUNTIME_CHECK(lazy_cudaDevSmResourceSplit( + split_resources.data(), static_cast(split_resources.size()), + &sm_resource, &remainder, 0, group_params.data())); + + for (uint32_t context_idx = 0; context_idx < split_resources.size(); ++context_idx) { + DG_HOST_ASSERT(split_resources[context_idx].type == cudaDevResourceTypeSm); + std::array context_resources = { + split_resources[context_idx], + workqueue_resource + }; + cudaDevResourceDesc_t resource_desc = nullptr; + DG_CUDA_RUNTIME_CHECK(lazy_cudaDevResourceGenerateDesc( + &resource_desc, context_resources.data(), context_resources.size())); + DG_CUDA_RUNTIME_CHECK(lazy_cudaGreenCtxCreate( + &green_contexts_[context_idx], resource_desc, device_idx_, 0)); + DG_CUDA_RUNTIME_CHECK(lazy_cudaExecutionCtxGetId( + green_contexts_[context_idx], &green_context_ids_[context_idx])); + } + } + + void prepare_kernel_params() { + const auto num_buffers = states_.size(); + kernel1_args_.resize(num_buffers); + kernel2_args_.resize(num_buffers); + kernel3_args_.resize(num_buffers); + kernel1_params_.resize(num_buffers); + kernel2_params_.resize(num_buffers); + kernel3_params_.resize(num_buffers); + + constexpr int kGranK = 32; + const int sf_smem_outer_dim = config_.block_k / (kGranK * 4); + for (uint32_t buffer_idx = 0; buffer_idx < num_buffers; ++buffer_idx) { + const auto views = slice_buffer(sym_buffers_[buffer_idx]); + kernel1_args_[buffer_idx].cumulative_local_expert_recv_stats = stats_[buffer_idx].data_ptr(); + kernel1_args_[buffer_idx].num_tokens = num_tokens_; + kernel1_args_[buffer_idx].sym_buffer_ptrs = layout::SymBuffer<>(sym_buffer_ptrs_[buffer_idx], rank_idx_); + kernel1_args_[buffer_idx].tensor_map_l1_acts = make_tma_2d_desc( + views.l1_acts, + hidden_, config_.num_max_pool_tokens, + config_.block_k, config_.load_block_m, + static_cast(views.l1_acts.stride(-2)), + config_.swizzle_acts_mode); + kernel1_args_[buffer_idx].tensor_map_l1_acts_sf = make_tma_sf_desc( + cute::UMMA::Major::MN, views.l1_acts_sf, + config_.num_padded_sf_pool_tokens, hidden_, + config_.sf_block_m, kGranK, + 1, 0, 0, false, + sf_smem_outer_dim); + kernel1_args_[buffer_idx].tensor_map_l1_weights = make_tma_2d_desc( + l1_weights_[buffer_idx], + hidden_, num_experts_per_rank_ * intermediate_hidden_ * 2, + config_.block_k, config_.load_block_n, + static_cast(l1_weights_[buffer_idx].stride(-2)), + config_.swizzle_weights_mode); + kernel1_args_[buffer_idx].tensor_map_l1_weights_sf = make_tma_sf_desc( + cute::UMMA::Major::MN, l1_weights_sf_[buffer_idx], + intermediate_hidden_ * 2, hidden_, + config_.block_n, kGranK, + num_experts_per_rank_, 0, 0, false, + sf_smem_outer_dim); + kernel1_args_[buffer_idx].tensor_map_l1_output = make_tma_2d_desc( + views.l2_acts, + intermediate_hidden_, config_.num_max_pool_tokens, + config_.block_n / 2, config_.store_block_m, + static_cast(views.l2_acts.stride(-2)), + config_.swizzle_acts_mode / 2); + + kernel2_args_[buffer_idx].state = states_[buffer_idx].data_ptr(); + kernel2_args_[buffer_idx].num_work_iters = kernel2_work_iters_; + kernel2_args_[buffer_idx].sym_buffer_ptrs = layout::SymBuffer<>(sym_buffer_ptrs_[buffer_idx], rank_idx_); + kernel2_args_[buffer_idx].tensor_map_l2_acts = make_tma_2d_desc( + views.l2_acts, + intermediate_hidden_, config_.num_max_pool_tokens, + config_.block_k, config_.load_block_m, + static_cast(views.l2_acts.stride(-2)), + config_.swizzle_acts_mode); + kernel2_args_[buffer_idx].tensor_map_l2_acts_sf = make_tma_sf_desc( + cute::UMMA::Major::MN, views.l2_acts_sf, + config_.num_padded_sf_pool_tokens, intermediate_hidden_, + config_.sf_block_m, kGranK, + 1, 0, 0, false, + sf_smem_outer_dim); + kernel2_args_[buffer_idx].tensor_map_l2_weights = make_tma_2d_desc( + l2_weights_[buffer_idx], + intermediate_hidden_, num_experts_per_rank_ * hidden_, + config_.block_k, config_.load_block_n, + static_cast(l2_weights_[buffer_idx].stride(-2)), + config_.swizzle_weights_mode); + kernel2_args_[buffer_idx].tensor_map_l2_weights_sf = make_tma_sf_desc( + cute::UMMA::Major::MN, l2_weights_sf_[buffer_idx], + hidden_, intermediate_hidden_, + config_.block_n, kGranK, + num_experts_per_rank_, 0, 0, false, + sf_smem_outer_dim); + kernel3_args_[buffer_idx] = Kernel3NodeArgs( + ys_[buffer_idx].data_ptr(), + states_[buffer_idx].data_ptr(), + num_tokens_, + layout::SymBuffer<>(sym_buffer_ptrs_[buffer_idx], rank_idx_)); + + kernel1_params_[buffer_idx] = { + &kernel1_args_[buffer_idx].cumulative_local_expert_recv_stats, + &kernel1_args_[buffer_idx].num_tokens, + &kernel1_args_[buffer_idx].sym_buffer_ptrs, + &kernel1_args_[buffer_idx].tensor_map_l1_acts, + &kernel1_args_[buffer_idx].tensor_map_l1_acts_sf, + &kernel1_args_[buffer_idx].tensor_map_l1_weights, + &kernel1_args_[buffer_idx].tensor_map_l1_weights_sf, + &kernel1_args_[buffer_idx].tensor_map_l1_output + }; + kernel2_params_[buffer_idx] = { + &kernel2_args_[buffer_idx].state, + &kernel2_args_[buffer_idx].num_work_iters, + &kernel2_args_[buffer_idx].sym_buffer_ptrs, + &kernel2_args_[buffer_idx].tensor_map_l2_acts, + &kernel2_args_[buffer_idx].tensor_map_l2_acts_sf, + &kernel2_args_[buffer_idx].tensor_map_l2_weights, + &kernel2_args_[buffer_idx].tensor_map_l2_weights_sf + }; + kernel3_params_[buffer_idx] = { + &kernel3_args_[buffer_idx].y, + &kernel3_args_[buffer_idx].state, + &kernel3_args_[buffer_idx].num_tokens, + &kernel3_args_[buffer_idx].sym_buffer_ptrs + }; + } + } + + cudaGraphNode_t add_kernel_node( + const cudaExecutionContext_t& context, + const KernelHandle& kernel, + const uint32_t& num_blocks, + const uint32_t& block_dim, + const uint32_t& shared_mem_bytes, + const uint32_t& cluster_dim, + void** kernel_params, + const std::vector& dependencies + ) { + if (shared_mem_bytes > 0) { + #if defined(DG_JIT_USE_RUNTIME_API) + DG_CUDA_RUNTIME_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_bytes)); + #else + DG_CUDA_DRIVER_CHECK(lazy_cuFuncSetAttribute( + kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, static_cast(shared_mem_bytes))); + #endif + } + + cudaGraphNode_t node = nullptr; + cudaGraphNodeParams node_params = {}; + node_params.type = cudaGraphNodeTypeKernel; + #if defined(DG_JIT_USE_RUNTIME_API) + node_params.kernel.kern = kernel; + node_params.kernel.functionType = cudaKernelFunctionTypeKernel; + #else + node_params.kernel.cuFunc = kernel; + node_params.kernel.functionType = cudaKernelFunctionTypeFunction; + #endif + node_params.kernel.gridDim = dim3(num_blocks, 1, 1); + node_params.kernel.blockDim = dim3(block_dim, 1, 1); + node_params.kernel.sharedMemBytes = shared_mem_bytes; + node_params.kernel.kernelParams = kernel_params; + node_params.kernel.extra = nullptr; + node_params.kernel.ctx = context; + + const auto dependency_ptr = dependencies.empty() ? nullptr : dependencies.data(); + DG_CUDA_RUNTIME_CHECK(lazy_cudaGraphAddNode( + &node, graph_, dependency_ptr, nullptr, dependencies.size(), &node_params)); + + if (cluster_dim > 1) { + cudaKernelNodeAttrValue attr = {}; + attr.clusterDim = {cluster_dim, 1, 1}; + DG_CUDA_RUNTIME_CHECK(lazy_cudaGraphKernelNodeSetAttribute( + node, cudaKernelNodeAttributeClusterDimension, &attr)); + } + return node; + } + + void build_graph() { + prepare_kernel_params(); + DG_CUDA_RUNTIME_CHECK(cudaGraphCreate(&graph_, 0)); + + const auto num_buffers = states_.size(); + kernel1_nodes_.reserve(num_buffers); + kernel2_nodes_.reserve(num_buffers); + reduce_nodes_.reserve(num_buffers); + + const uint32_t kernel1_block_dim = static_cast( + config_.num_dispatch_threads + config_.num_non_epilogue_threads + config_.num_epilogue_threads); + for (uint32_t buffer_idx = 0; buffer_idx < num_buffers; ++buffer_idx) { + std::vector dependencies; + if (not kernel1_nodes_.empty()) + dependencies.push_back(kernel1_nodes_.back()); + kernel1_nodes_.push_back(add_kernel_node( + green_contexts_[0], kernel1_graph_kernel_, kernel1_sms_, + kernel1_block_dim, static_cast(config_.smem_size), 2, + kernel1_params_[buffer_idx].data(), dependencies)); + } + + for (uint32_t buffer_idx = 0; buffer_idx < num_buffers; ++buffer_idx) { + std::vector dependencies; + if (not kernel2_nodes_.empty()) + dependencies.push_back(kernel2_nodes_.back()); + const uint32_t kernel2_block_dim = static_cast(kernel2_config_.num_non_epilogue_threads + kernel2_config_.num_epilogue_threads); + const uint32_t kernel2_shared_mem_bytes = static_cast(kernel2_config_.smem_size); + const uint32_t kernel2_cluster_dim = 2u; + kernel2_nodes_.push_back(add_kernel_node( + green_contexts_[1], kernel2_graph_kernel_, kernel2_sms_, + kernel2_block_dim, kernel2_shared_mem_bytes, kernel2_cluster_dim, + kernel2_params_[buffer_idx].data(), dependencies)); + } + + for (uint32_t buffer_idx = 0; buffer_idx < num_buffers; ++buffer_idx) { + std::vector dependencies; + if (reduce_nodes_.empty()) { + dependencies.push_back(kernel1_nodes_.back()); + dependencies.push_back(kernel2_nodes_.back()); + } else { + dependencies.push_back(reduce_nodes_.back()); + } + reduce_nodes_.push_back(add_kernel_node( + primary_context_, kernel3_graph_kernel_, num_tokens_, + 512, 0, 1, + kernel3_params_[buffer_idx].data(), dependencies)); + } + + DG_CUDA_RUNTIME_CHECK(lazy_cudaGraphInstantiate(&graph_exec_, graph_, 0)); + } + + void destroy_noexcept() noexcept { + if (graph_exec_ != nullptr) { + (void) lazy_cudaGraphExecDestroy(graph_exec_); + graph_exec_ = nullptr; + } + if (graph_ != nullptr) { + (void) lazy_cudaGraphDestroy(graph_); + graph_ = nullptr; + } + for (uint32_t context_idx = 0; context_idx < green_contexts_.size(); ++context_idx) { + if (green_contexts_[context_idx] != nullptr) { + (void) lazy_cudaExecutionCtxDestroy(green_contexts_[context_idx]); + green_contexts_[context_idx] = nullptr; + } + } + } + +public: + SM100FP8FP4MegaMoESplitGraph( + std::vector states, + std::vector ys, + std::vector sym_buffers, + std::vector> sym_buffer_ptrs, + std::vector l1_weights, + std::vector l1_weights_sf, + std::vector l2_weights, + std::vector l2_weights_sf, + std::vector stats, + const int& rank_idx, + const int& num_max_tokens_per_rank, + const int& num_experts, + const int& num_topk, + const int& num_tokens, + const int& hidden, + const int& intermediate_hidden, + const float& activation_clamp, + const bool& fast_math, + const int& kernel1_sms, + const int& kernel2_sms, + const int& reduce_sms, + const int& kernel2_work_iters, + const int& reduce_work_iters + ) : states_(std::move(states)), + ys_(std::move(ys)), + sym_buffers_(std::move(sym_buffers)), + sym_buffer_ptrs_(std::move(sym_buffer_ptrs)), + l1_weights_(std::move(l1_weights)), + l1_weights_sf_(std::move(l1_weights_sf)), + l2_weights_(std::move(l2_weights)), + l2_weights_sf_(std::move(l2_weights_sf)), + stats_(std::move(stats)), + rank_idx_(checked_nonnegative_u32(rank_idx)), + num_max_tokens_per_rank_(checked_positive_u32(num_max_tokens_per_rank)), + num_experts_(checked_positive_u32(num_experts)), + num_experts_per_rank_(0), + num_topk_(checked_positive_u32(num_topk)), + num_tokens_(checked_nonnegative_u32(num_tokens)), + hidden_(checked_positive_u32(hidden)), + intermediate_hidden_(checked_positive_u32(intermediate_hidden)), + kernel1_sms_(checked_positive_u32(kernel1_sms)), + kernel2_sms_(checked_positive_u32(kernel2_sms)), + reduce_sms_(checked_positive_u32(reduce_sms)), + kernel2_work_iters_(checked_nonnegative_u32(kernel2_work_iters)), + reduce_work_iters_(checked_nonnegative_u32(reduce_work_iters)), + activation_clamp_(activation_clamp), + fast_math_(fast_math) { + DG_HOST_ASSERT(not sym_buffer_ptrs_.empty()); + num_experts_per_rank_ = num_experts_ / num_ranks(); + const int num_max_pool_tokens = layout::get_num_max_pool_tokens( + static_cast(num_ranks()), static_cast(num_max_tokens_per_rank_), + static_cast(num_topk_), static_cast(num_experts_per_rank_)); + const int num_padded_sf_pool_tokens = get_num_padded_sf_pool_tokens(num_max_pool_tokens); + config_ = get_mega_moe_split_kernel1_config( + static_cast(num_ranks()), static_cast(num_experts_), static_cast(num_experts_per_rank_), + static_cast(num_max_tokens_per_rank_), static_cast(num_tokens_), static_cast(num_topk_), + static_cast(hidden_), static_cast(intermediate_hidden_), + num_padded_sf_pool_tokens, static_cast(kernel1_sms_)); + kernel2_config_ = get_mega_moe_split_kernel2_config(config_); + check_inputs(); + build_kernel_runtimes(); + create_green_contexts(); + build_graph(); + } + + SM100FP8FP4MegaMoESplitGraph(const SM100FP8FP4MegaMoESplitGraph&) = delete; + SM100FP8FP4MegaMoESplitGraph& operator=(const SM100FP8FP4MegaMoESplitGraph&) = delete; + + ~SM100FP8FP4MegaMoESplitGraph() { + destroy_noexcept(); + } + + void replay() { + DG_HOST_ASSERT(graph_exec_ != nullptr); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + DG_CUDA_RUNTIME_CHECK(lazy_cudaGraphLaunch(graph_exec_, stream)); + } + + std::tuple get_green_context_ids() const { + return {green_context_ids_[0], green_context_ids_[1]}; + } +}; + +#else + +class SM100FP8FP4MegaMoESplitGraph final { +public: + SM100FP8FP4MegaMoESplitGraph( + std::vector, + std::vector, + std::vector, + std::vector>, + std::vector, + std::vector, + std::vector, + std::vector, + std::vector, + const int&, + const int&, + const int&, + const int&, + const int&, + const int&, + const int&, + const float&, + const bool&, + const int&, + const int&, + const int&, + const int&, + const int& + ) { + DG_HOST_UNREACHABLE( + "Native real K1 fake K2 graph requires CUDA Runtime 13.1+"); + } + + void replay() { + DG_HOST_UNREACHABLE( + "Native real K1 fake K2 graph requires CUDA Runtime 13.1+"); + } + + std::tuple get_green_context_ids() const { + DG_HOST_UNREACHABLE( + "Native real K1 fake K2 graph requires CUDA Runtime 13.1+"); + } +}; + +#endif + +} // namespace deep_gemm diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index a9542e2f44..122ee577d2 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -84,8 +84,10 @@ from .mega import ( SymmBuffer, get_symm_buffer_for_mega_moe, + get_symm_buffer_for_mega_moe_split, transform_weights_for_mega_moe, fp8_fp4_mega_moe, + SM100FP8FP4MegaMoESplitGraph, ) # Some utils diff --git a/deep_gemm/include/deep_gemm/common/utils.cuh b/deep_gemm/include/deep_gemm/common/utils.cuh index 42c2ca7232..8352fa2cf0 100644 --- a/deep_gemm/include/deep_gemm/common/utils.cuh +++ b/deep_gemm/include/deep_gemm/common/utils.cuh @@ -65,4 +65,108 @@ __device__ __forceinline__ T shfl_sync(unsigned mask, T var, int srcLane, int wi return result; } + +// --- Round-robin "peel" indexing over per-rank token counts (shared by the split-kernel +// dispatch pull). Maps a flat pool slot <-> (rank, token-in-rank) for the contiguous +// expert token pool. --- +template +struct PeelIter { + static constexpr uint32_t kNumRanks = kNumRanks_; + static constexpr uint32_t kNumRanksPerLane = math::constexpr_ceil_div(kNumRanks, 32u); + + uint32_t remaining[kNumRanksPerLane]; + uint32_t slot_base; + uint32_t row_base; + uint32_t num_active_ranks; + uint32_t round_depth; + + CUTLASS_DEVICE explicit PeelIter(const uint32_t (&counts)[kNumRanksPerLane]) + : slot_base(0), row_base(0) + { + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) + remaining[i] = counts[i]; + compute_round(); + } + + CUTLASS_DEVICE void compute_round() { + uint32_t in_lane_actives = 0; + uint32_t in_lane_min = 0xffffffffu; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + in_lane_actives += remaining[i] > 0; + if (remaining[i] > 0) + in_lane_min = remaining[i] < in_lane_min ? remaining[i] : in_lane_min; + } + num_active_ranks = __reduce_add_sync(0xffffffffu, in_lane_actives); + round_depth = __reduce_min_sync(0xffffffffu, in_lane_min); + } + + CUTLASS_DEVICE void advance() { + slot_base += round_depth * num_active_ranks; + row_base += round_depth; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) + remaining[i] -= remaining[i] < round_depth ? remaining[i] : round_depth; + compute_round(); + } + + CUTLASS_DEVICE uint32_t select_active_rank(const uint32_t& active_rank_idx) const { + uint32_t rank = 0; + uint32_t seen = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + const uint32_t mask = __ballot_sync(0xffffffffu, remaining[i] > 0); + const uint32_t active_lanes = __popc(mask); + if (active_rank_idx >= seen and active_rank_idx < seen + active_lanes) + rank = i * 32 + __fns(mask, 0, active_rank_idx - seen + 1); + seen += active_lanes; + } + return rank; + } + + CUTLASS_DEVICE uint32_t rank_order_of(const uint32_t& target_rank) const { + uint32_t order = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + const uint32_t mask = __ballot_sync(0xffffffffu, remaining[i] > 0); + if (target_rank >= i * 32 + 32) { + order += __popc(mask); + } else if (target_rank >= i * 32) { + order += __popc(mask & ((1u << (target_rank - i * 32)) - 1u)); + } + } + return order; + } +}; + +template +CUTLASS_DEVICE void peel_forward( + const uint32_t (&counts)[math::constexpr_ceil_div(kNumRanks, 32u)], + const uint32_t& slot_idx, + uint32_t& out_rank, + uint32_t& out_token_idx_in_rank) +{ + PeelIter it(counts); + while (slot_idx >= it.slot_base + it.round_depth * it.num_active_ranks) + it.advance(); + const uint32_t local_slot_idx = slot_idx - it.slot_base; + out_rank = it.select_active_rank(local_slot_idx % it.num_active_ranks); + out_token_idx_in_rank = it.row_base + local_slot_idx / it.num_active_ranks; +} + +template +CUTLASS_DEVICE uint32_t peel_inverse( + const uint32_t (&counts)[math::constexpr_ceil_div(kNumRanks, 32u)], + const uint32_t& target_rank, + const uint32_t& target_token_idx_in_rank) +{ + PeelIter it(counts); + while (target_token_idx_in_rank >= it.row_base + it.round_depth) + it.advance(); + return it.slot_base + + (target_token_idx_in_rank - it.row_base) * it.num_active_ranks + + it.rank_order_of(target_rank); +} + } // namespace deep_gemm::utils diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe_split/combine_reduce.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe_split/combine_reduce.cuh new file mode 100644 index 0000000000..32f54ce461 --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe_split/combine_reduce.cuh @@ -0,0 +1,142 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm::mega_moe_split { + +template < + uint32_t kNumMaxTokensPerRank, + uint32_t kHidden, + uint32_t kIntermediateHidden, + uint32_t kNumExperts, + uint32_t kNumTopk, + uint32_t kNumPaddedSFPoolTokens, + uint32_t kNumRanks, + uint32_t kNumThreads = 256 +> +CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void +sm100_fp8_fp4_mega_moe_split_combine_reduce_impl( + void* y, + uint32_t* state, + const uint32_t num_tokens, + const __grid_constant__ layout::SymBuffer sym_buffer +) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + DG_STATIC_ASSERT(kNumThreads % 32 == 0, "K3 thread count must be warp-aligned"); + DG_STATIC_ASSERT(kHidden % 8 == 0, "Hidden must be divisible by one uint4 of BF16"); + DG_STATIC_ASSERT(kNumTopk <= 32, "Top-k must fit in one warp"); + DG_STATIC_ASSERT(kNumExperts % kNumRanks == 0, "Invalid expert/rank shape"); + + constexpr uint32_t kNumHiddenBytes = kHidden * sizeof(nv_bfloat16); + constexpr uint32_t kNumHiddenVec = kNumHiddenBytes / sizeof(ptx::longlong4_t); + constexpr uint32_t kNumElemsPerVec = sizeof(ptx::longlong4_t) / sizeof(nv_bfloat162); + DG_STATIC_ASSERT(kNumHiddenBytes % sizeof(ptx::longlong4_t) == 0, + "Hidden must be divisible by one 256-bit vector of BF16"); + + const uint32_t thread_idx = threadIdx.x; + const uint32_t token_idx = blockIdx.x; + if (token_idx >= num_tokens) + return; + + const auto workspace = layout::SplitWorkspace( + sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk); + + const auto fp8_token_layout = layout::Data(kHidden); + const auto fp8_sf_layout = layout::Data(kHidden / 32); + const auto input_topk_idx_layout = layout::Data(kNumTopk * sizeof(int64_t), false); + const auto input_topk_weights_layout = layout::Data(kNumTopk * sizeof(float), false); + const auto l1_topk_weights_layout = layout::Data(sizeof(float), false); + constexpr uint32_t kNumExpertsPerRank = kNumExperts / kNumRanks; + constexpr uint32_t kNumMaxPoolTokens = layout::get_num_max_pool_tokens( + kNumRanks, kNumMaxTokensPerRank, kNumTopk, kNumExpertsPerRank); + const auto fp8_intermediate_token_layout = layout::Data(kIntermediateHidden); + const auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 32); + const auto bf16_token_layout = layout::Data(kNumHiddenBytes); + + const auto input_token_buffer = layout::Buffer( + fp8_token_layout, 1, kNumMaxTokensPerRank, + workspace.get_end_ptr()); + const auto input_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, kNumMaxTokensPerRank, + input_token_buffer.get_end_ptr()); + const auto input_topk_idx_buffer = layout::Buffer( + input_topk_idx_layout, 1, kNumMaxTokensPerRank, + input_sf_buffer.get_end_ptr()); + const auto input_topk_weights_buffer = layout::Buffer( + input_topk_weights_layout, 1, kNumMaxTokensPerRank, + input_topk_idx_buffer.get_end_ptr()); + const auto l1_token_buffer = layout::Buffer( + fp8_token_layout, 1, kNumMaxPoolTokens, + input_topk_weights_buffer.get_end_ptr()); + const auto l1_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, kNumPaddedSFPoolTokens, + l1_token_buffer.get_end_ptr()); + const auto l1_topk_weights_buffer = layout::Buffer( + l1_topk_weights_layout, 1, kNumMaxPoolTokens, + l1_sf_buffer.get_end_ptr()); + const auto l2_token_buffer = layout::Buffer( + fp8_intermediate_token_layout, 1, kNumMaxPoolTokens, + l1_topk_weights_buffer.get_end_ptr()); + const auto l2_sf_buffer = layout::Buffer( + fp8_intermediate_sf_layout, 1, kNumPaddedSFPoolTokens, + l2_token_buffer.get_end_ptr()); + const auto combine_token_buffer = layout::Buffer( + bf16_token_layout, kNumTopk, kNumMaxTokensPerRank, + l2_sf_buffer.get_end_ptr()); + + __shared__ int32_t topk_idx_smem[kNumTopk]; + const auto topk_idx_ptr = input_topk_idx_buffer.get_data_buffer(token_idx) + .get_base_ptr(); + if (thread_idx < kNumTopk) + topk_idx_smem[thread_idx] = static_cast(topk_idx_ptr[thread_idx]); + __syncthreads(); + + for (uint32_t vec_idx = thread_idx; vec_idx < kNumHiddenVec; vec_idx += kNumThreads) { + float2 reduced[kNumElemsPerVec]; + #pragma unroll + for (uint32_t elem_idx = 0; elem_idx < kNumElemsPerVec; ++elem_idx) + reduced[elem_idx] = make_float2(0.0f, 0.0f); + + #pragma unroll + for (uint32_t topk_slot_idx = 0; topk_slot_idx < kNumTopk; ++topk_slot_idx) { + const auto valid_topk = static_cast(topk_idx_smem[topk_slot_idx]); + const auto src_ptr = combine_token_buffer.get_rank_buffer(topk_slot_idx) + .get_data_buffer(token_idx) + .get_base_ptr(); + const auto vec_values = ptx::ld_gez_pred(src_ptr + vec_idx, valid_topk); + const auto bf16_values = reinterpret_cast(&vec_values); + #pragma unroll + for (uint32_t elem_idx = 0; elem_idx < kNumElemsPerVec; ++elem_idx) + ptx::accumulate(reduced[elem_idx], bf16_values[elem_idx]); + } + + ptx::longlong4_t casted; + const auto casted_bf16 = reinterpret_cast(&casted); + #pragma unroll + for (uint32_t elem_idx = 0; elem_idx < kNumElemsPerVec; ++elem_idx) + casted_bf16[elem_idx] = __float22bfloat162_rn(reduced[elem_idx]); + + const auto dst_ptr = math::advance_ptr( + y, static_cast(token_idx) * kNumHiddenBytes); + const auto casted_uint4 = reinterpret_cast(&casted); + ptx::st_global_v4_u32(dst_ptr + vec_idx * 2u, casted_uint4[0]); + ptx::st_global_v4_u32(dst_ptr + vec_idx * 2u + 1u, casted_uint4[1]); + } + + if (thread_idx == 0) + atomicAdd(state + get_state_offset(SplitStateOffset::K3DoneElements), 1u); +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only supports sm_100f"); +#endif +} + +} // namespace deep_gemm::mega_moe_split diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe_split/common.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe_split/common.cuh new file mode 100644 index 0000000000..7b6a9a6cfc --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe_split/common.cuh @@ -0,0 +1,27 @@ +#pragma once + +#include + +#include + +namespace deep_gemm::mega_moe_split { + +// Shared `state` tensor layout for the split-kernel pipeline: small device counters used to +// couple the kernels through the CUDA-graph dependency edges and to report progress to the +// host (consumed L2 blocks, launched CTAs, reduced tokens). +enum class SplitStateOffset : uint32_t { + K1ReadyTasks = 0, + K1DoneBlocks = 1, + K2ClaimCounter = 2, + K2DoneTasks = 3, + K2DoneBlocks = 4, + K3DoneElements = 5, + K2Checksum = 6, + NumOffsets = 7, +}; + +constexpr CUTLASS_HOST_DEVICE uint32_t get_state_offset(const SplitStateOffset offset) { + return static_cast(offset); +} + +} // namespace deep_gemm::mega_moe_split diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe_split/dispatch_l1_swiglu.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe_split/dispatch_l1_swiglu.cuh new file mode 100644 index 0000000000..81cb518f90 --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe_split/dispatch_l1_swiglu.cuh @@ -0,0 +1,1013 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm::mega_moe_split { + +template < + uint32_t kNumMaxTokensPerRank, + uint32_t kHidden, uint32_t kIntermediateHidden, + uint32_t kNumExperts, uint32_t kNumTopk, + uint32_t kNumExpertsPerWave, + uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, + uint32_t STORE_BLOCK_M, + uint32_t SF_BLOCK_M, uint32_t SF_BLOCK_N, + uint32_t kNumMaxPoolTokens, + uint32_t kNumPaddedSFPoolTokens, + uint32_t kNumStages, + uint32_t kNumDispatchThreads, uint32_t kNumNonEpilogueThreads, + uint32_t kNumEpilogueThreads, + uint32_t kNumSMs, uint32_t kNumRanks, + float kActivationClamp, + bool kFastMath, + bool kLocalOnly, + uint32_t L1_SHAPE_N = kIntermediateHidden * 2, + uint32_t L1_SHAPE_K = kHidden, + uint32_t L2_SHAPE_N = kHidden, + uint32_t L2_SHAPE_K = kIntermediateHidden, + uint32_t kNumDispatchWarps = kNumDispatchThreads / 32, + uint32_t kNumMMANonEpilogueWarps = kNumNonEpilogueThreads / 32, + uint32_t kNumEpilogueWarps = kNumEpilogueThreads / 32, + uint32_t kNumEpilogueWarpgroups = kNumEpilogueWarps / 4, + uint32_t kNumThreads = kNumDispatchThreads + kNumNonEpilogueThreads + kNumEpilogueThreads, + uint32_t kNumTokensPerWarp = 32 / kNumTopk, + uint32_t kNumExpertsPerRank = kNumExperts / kNumRanks +> +CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void +sm100_fp8_fp4_mega_moe_split_dispatch_l1_swiglu_impl( + int* cumulative_local_expert_recv_stats, + const uint32_t num_tokens, + const __grid_constant__ layout::SymBuffer sym_buffer, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_output +) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::TMEM::Allocator2Sm; + + DG_STATIC_ASSERT(kNumDispatchThreads % 128 == 0, "Invalid number of dispatch threads"); + DG_STATIC_ASSERT(kNumNonEpilogueThreads == 128, "Invalid number of MMA non-epilogue threads"); + DG_STATIC_ASSERT(kNumEpilogueThreads % 128 == 0, "Invalid number of MMA epilogue threads"); + DG_STATIC_ASSERT(kNumExperts % kNumRanks == 0, "Invalid number of experts or ranks"); + + const bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const uint32_t sm_idx = blockIdx.x; + const uint32_t thread_idx = threadIdx.x; + const uint32_t warp_idx = cutlass::canonical_warp_idx_sync(); + const uint32_t lane_idx = ptx::get_lane_idx(); + + if (warp_idx == 0) { + cute::prefetch_tma_descriptor(&tensor_map_l1_acts); + cute::prefetch_tma_descriptor(&tensor_map_l1_acts_sf); + cute::prefetch_tma_descriptor(&tensor_map_l1_weights); + cute::prefetch_tma_descriptor(&tensor_map_l1_weights_sf); + cute::prefetch_tma_descriptor(&tensor_map_l1_output); + } + + const auto workspace = layout::SplitWorkspace( + sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk); + + constexpr auto fp8_token_layout = layout::Data(kHidden); + constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16)); + constexpr auto fp8_intermediate_token_layout = layout::Data(kIntermediateHidden); + constexpr auto fp8_sf_layout = layout::Data(kHidden / 32); + constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 32); + constexpr auto input_topk_idx_layout = layout::Data(kNumTopk * sizeof(int64_t), false); + constexpr auto input_topk_weights_layout = layout::Data(kNumTopk * sizeof(float), false); + constexpr auto l1_topk_weights_layout = layout::Data(sizeof(float), false); + + const auto input_token_buffer = layout::Buffer( + fp8_token_layout, 1, kNumMaxTokensPerRank, + workspace.get_end_ptr()); + const auto input_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, kNumMaxTokensPerRank, + input_token_buffer.get_end_ptr()); + const auto input_topk_idx_buffer = layout::Buffer( + input_topk_idx_layout, 1, kNumMaxTokensPerRank, + input_sf_buffer.get_end_ptr()); + const auto input_topk_weights_buffer = layout::Buffer( + input_topk_weights_layout, 1, kNumMaxTokensPerRank, + input_topk_idx_buffer.get_end_ptr()); + + constexpr uint32_t kGranK = 32; + constexpr uint32_t kNumUTCCPAlignedElems = 128; + DG_STATIC_ASSERT(SF_BLOCK_M == math::constexpr_align(BLOCK_M, kNumUTCCPAlignedElems), "Invalid SF_BLOCK_M"); + DG_STATIC_ASSERT(SF_BLOCK_N == BLOCK_N, "Invalid SF_BLOCK_N"); + + const auto transform_sf_token_idx = [](const uint32_t& token_idx_in_expert) { + const uint32_t idx = token_idx_in_expert % BLOCK_M; + return token_idx_in_expert / BLOCK_M * SF_BLOCK_M + + (idx & ~127u) + (idx & 31u) * 4 + ((idx >> 5) & 3u); + }; + + const auto l1_token_buffer = layout::Buffer( + fp8_token_layout, 1, kNumMaxPoolTokens, + input_topk_weights_buffer.get_end_ptr()); + const auto l1_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, kNumPaddedSFPoolTokens, + l1_token_buffer.get_end_ptr()); + const auto l1_topk_weights_buffer = layout::Buffer( + l1_topk_weights_layout, 1, kNumMaxPoolTokens, + l1_sf_buffer.get_end_ptr()); + const auto l2_token_buffer = layout::Buffer( + fp8_intermediate_token_layout, 1, kNumMaxPoolTokens, + l1_topk_weights_buffer.get_end_ptr()); + const auto l2_sf_buffer = layout::Buffer( + fp8_intermediate_sf_layout, 1, kNumPaddedSFPoolTokens, + l2_token_buffer.get_end_ptr()); + const auto combine_token_buffer = layout::Buffer( + bf16_token_layout, kNumTopk, kNumMaxTokensPerRank, + l2_sf_buffer.get_end_ptr()); + (void) combine_token_buffer; + + using a_dtype_t = cutlass::float_e4m3_t; + using b_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t; + + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t UMMA_M = LAYOUT_AD_M * 2; + constexpr uint32_t UMMA_N = BLOCK_M; + constexpr uint32_t UMMA_BLOCK_K = 128; + constexpr uint32_t UMMA_K = 32; + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / 2; + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N; + DG_STATIC_ASSERT(BLOCK_M % 16 == 0, "Invalid block M"); + DG_STATIC_ASSERT(BLOCK_N == LAYOUT_AD_M, "Invalid block N"); + + constexpr uint32_t kSwizzleAMode = 128; + constexpr uint32_t kSwizzleBMode = 128; + constexpr uint32_t kSwizzleCDMode = 128; + DG_STATIC_ASSERT(BLOCK_N % kSwizzleCDMode == 0, "Invalid block N"); + + constexpr uint32_t kNumEpilogueStages = 2; + constexpr uint32_t kNumTMAStoreStages = 2; + + constexpr uint32_t kSharedMemoryAlignment = 1024; + extern __shared__ __align__(kSharedMemoryAlignment) uint8_t smem_buffer[]; + + constexpr uint32_t L1_OUT_BLOCK_N = BLOCK_N / 2; + constexpr uint32_t SMEM_EXPERT_COUNT_SIZE = + math::constexpr_align(kNumExperts * sizeof(uint32_t), kSharedMemoryAlignment); + constexpr uint32_t SMEM_SEND_BUFFER_SIZE = + math::constexpr_align(fp8_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment); + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t) * (BLOCK_K / 128); + constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t) * (BLOCK_K / 128); + constexpr uint32_t SMEM_CD_L1_SIZE = + kNumEpilogueWarpgroups * STORE_BLOCK_M * L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t) * kNumTMAStoreStages; + constexpr uint32_t SMEM_CD_L2_SIZE = + kNumEpilogueWarpgroups * STORE_BLOCK_M * BLOCK_N * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_L1_SIZE > SMEM_CD_L2_SIZE ? SMEM_CD_L1_SIZE : SMEM_CD_L2_SIZE; + constexpr uint32_t SMEM_CD_L1_SIZE_PER_STAGE = SMEM_CD_L1_SIZE / kNumTMAStoreStages; + DG_STATIC_ASSERT(SMEM_CD_SIZE % kSharedMemoryAlignment == 0 and + SMEM_A_SIZE_PER_STAGE % kSharedMemoryAlignment == 0 and + SMEM_B_SIZE_PER_STAGE % kSharedMemoryAlignment == 0, + "Shared memory of CD/A/B must be aligned to 1024 bytes"); + + constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages; + constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32; + constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32; + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + const auto smem_expert_count = reinterpret_cast(smem_buffer); + const auto smem_send_buffers = layout::Buffer( + fp8_token_layout, kNumDispatchWarps, 1, + math::advance_ptr(smem_buffer, SMEM_EXPERT_COUNT_SIZE)); + + auto smem_gemm_base = math::advance_ptr( + smem_buffer, SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE); + auto smem_cd = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr(smem_gemm_base, i * SMEM_CD_L1_SIZE_PER_STAGE); + }); + auto smem_a = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr(smem_gemm_base, SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr( + smem_gemm_base, SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + auto sf_start_ptr = math::advance_ptr( + smem_gemm_base, SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto smem_sfa = utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE); + }); + auto smem_amax_reduction = reinterpret_cast(smem_sfb[kNumStages]); + auto barrier_start_ptr = reinterpret_cast(smem_amax_reduction + STORE_BLOCK_M * kNumEpilogueWarps / 2); + auto dispatch_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + i; }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + kNumDispatchWarps + i; }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + kNumDispatchWarps + kNumStages + i; }); + auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { + return barrier_start_ptr + kNumDispatchWarps + kNumStages * 2 + i; + }); + auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { + return barrier_start_ptr + kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages + i; + }); + auto tmem_ptr_in_smem = reinterpret_cast( + barrier_start_ptr + kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages * 2 + kNumEpilogueWarps * 2); + + comm::cluster_sync_with_relaxed_arrive(); + + if (warp_idx == 0) { + if (cute::elect_one_sync()) + ptx::st_shared_bulk(smem_expert_count, kNumExperts * sizeof(uint32_t)); + } else if (warp_idx == 1) { + #pragma unroll + for (uint32_t i = lane_idx; i < kNumDispatchWarps; i += 32) + dispatch_barriers[i]->init(1); + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++i) { + full_barriers[i]->init(2 * 2); + empty_barriers[i]->init(1); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++i) { + tmem_full_barriers[i]->init(1); + tmem_empty_barriers[i]->init(2 * kNumEpilogueThreads); + } + } + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 3) { + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + comm::cluster_sync_with_relaxed_arrive(); + + auto scheduler = sched::MegaMoEScheduler< + BLOCK_M, BLOCK_N, BLOCK_K, + L1_SHAPE_N, L1_SHAPE_K, + L2_SHAPE_N, L2_SHAPE_K, + kNumExpertsPerRank, + kNumExpertsPerWave, + kNumSMs, kNumRanks>(workspace); + + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++k_block_idx; + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + constexpr uint32_t kDispatchBarrierIdx = 0; + constexpr uint32_t kDispatchWithEpilogueBarrierIdx = 1; + constexpr uint32_t kEpilogueFullBarrierIdx = 2; + constexpr uint32_t kEpilogueWGBarrierStartIdx = 3; + + constexpr uint32_t kBeforeDispatchPullBarrierTag = 1; + constexpr uint32_t kDispatchGridSyncIndex = 0; + + constexpr uint32_t kNumDispatchRegisters = 48; + constexpr uint32_t kNumNonEpilogueRegisters = 40; + constexpr uint32_t kNumEpilogueRegisters = 208; + DG_STATIC_ASSERT(kNumDispatchRegisters * kNumDispatchThreads + + kNumNonEpilogueRegisters * kNumNonEpilogueThreads + + kNumEpilogueRegisters * kNumEpilogueThreads <= 64512, + "Too many registers"); + + DG_STATIC_ASSERT(kNumTopk <= 64 and kNumExpertsPerRank <= 1024 and kNumMaxTokensPerRank <= 65536, + "Route entry bit-packing constraints"); + const auto pull_meta_base = workspace.get_src_token_topk_idx_ptr(0, 0, 0); + constexpr uint32_t kRouteCountOffset = kNumExpertsPerRank * kNumRanks * kNumMaxTokensPerRank; + constexpr uint32_t kRouteEntriesOffset = kRouteCountOffset + kNumRanks * kNumMaxTokensPerRank; + + if (warp_idx < kNumDispatchWarps) { + cutlass::arch::warpgroup_reg_dealloc(); + + DG_STATIC_ASSERT(kNumTopk <= 32, "Invalid number of topk"); + constexpr uint32_t kNumActivateLanes = kNumTokensPerWarp * kNumTopk; + const auto read_topk_idx = [&](const auto& process) { + #pragma unroll + for (uint32_t i = (sm_idx * kNumDispatchWarps + warp_idx) * kNumTokensPerWarp; + i < num_tokens; + i += kNumSMs * kNumDispatchWarps * kNumTokensPerWarp) { + int32_t expert_idx = -1; + if (i + lane_idx / kNumTopk < num_tokens and lane_idx < kNumActivateLanes) { + expert_idx = static_cast( + __ldg(input_topk_idx_buffer.get_base_ptr() + i * kNumTopk + lane_idx)); + if (expert_idx >= 0) + process(i * kNumTopk + lane_idx, expert_idx); + } + __syncwarp(); + } + }; + + read_topk_idx([&](const uint32_t& token_topk_idx, const int32_t& expert_idx) { + (void) token_topk_idx; + atomicAdd_block(smem_expert_count + static_cast(expert_idx), 1); + }); + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) { + const uint64_t send_value = (1ull << 32) | static_cast(smem_expert_count[i]); + smem_expert_count[i] = static_cast( + ptx::atomic_add(workspace.get_expert_send_count_ptr(i), send_value)); + } + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + #pragma unroll + for (uint32_t i = (sm_idx * kNumDispatchWarps + warp_idx) * kNumTokensPerWarp; + i < num_tokens; + i += kNumSMs * kNumDispatchWarps * kNumTokensPerWarp) { + const uint32_t src_token_idx = i + lane_idx / kNumTopk; + const bool in_range = src_token_idx < num_tokens and lane_idx < kNumActivateLanes; + int32_t expert_idx = -1; + if (in_range) + expert_idx = static_cast( + __ldg(input_topk_idx_buffer.get_base_ptr() + i * kNumTopk + lane_idx)); + const bool routed = in_range and expert_idx >= 0; + const int32_t dst_rank_idx = routed ? expert_idx / static_cast(kNumExpertsPerRank) : 0; + const uint32_t local_expert_idx = routed ? static_cast(expert_idx) % kNumExpertsPerRank : 0u; + + uint32_t dst_slot_idx = 0; + if (routed) { + dst_slot_idx = atomicAdd_block(smem_expert_count + static_cast(expert_idx), 1); + *sym_buffer.map(pull_meta_base + (local_expert_idx * kNumRanks + sym_buffer.rank_idx) + * kNumMaxTokensPerRank + dst_slot_idx, static_cast(dst_rank_idx)) = + i * kNumTopk + lane_idx; + } + + #pragma unroll + for (uint32_t dst_idx = 0; dst_idx < kNumRanks; ++dst_idx) { + const uint32_t targets_dst = __ballot_sync(0xffffffffu, routed and dst_rank_idx == int32_t(dst_idx)); + const uint32_t group_base = (lane_idx / kNumTopk) * kNumTopk; + const uint32_t our_routes = targets_dst & (((1u << kNumTopk) - 1u) << group_base); + const uint32_t route_count = __popc(our_routes); + if (in_range and lane_idx == group_base) + *sym_buffer.map(pull_meta_base + kRouteCountOffset + + sym_buffer.rank_idx * kNumMaxTokensPerRank + src_token_idx, dst_idx) = + route_count; + if (routed and dst_rank_idx == int32_t(dst_idx) and route_count > 1) { + const uint32_t packed = (dst_slot_idx << 16) | (local_expert_idx << 6) | (lane_idx - group_base); + const uint32_t slot_idx = __popc(our_routes & ((1u << lane_idx) - 1u)); + *sym_buffer.map(pull_meta_base + kRouteEntriesOffset + + (sym_buffer.rank_idx * kNumMaxTokensPerRank + src_token_idx) * kNumTopk + slot_idx, + dst_idx) = packed; + } + } + __syncwarp(); + } + + comm::grid_sync( + workspace, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }); + + if (sm_idx == 0) { + if constexpr (kLocalOnly) { + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExpertsPerRank; i += kNumDispatchThreads) { + const auto expert_status = *workspace.get_expert_send_count_ptr(i); + const auto num_recv_tokens = static_cast(expert_status); + #pragma unroll + for (uint32_t rank_idx = 0; rank_idx < kNumRanks; ++rank_idx) { + *workspace.get_expert_recv_count_ptr(rank_idx, i) = + rank_idx == 0 ? num_recv_tokens : 0u; + } + *workspace.get_expert_recv_count_sum_ptr(i) = + (static_cast(kNumSMs) * kNumRanks << 32) | num_recv_tokens; + } + } else { + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) { + const auto dst_rank_idx = i / kNumExpertsPerRank; + const auto dst_local_expert_idx = i % kNumExpertsPerRank; + const auto expert_status = *workspace.get_expert_send_count_ptr(i); + if constexpr (kNumRanks == 1) { + *workspace.get_expert_recv_count_ptr(sym_buffer.rank_idx, dst_local_expert_idx) = + expert_status & 0xffffffff; + ptx::atomic_add( + workspace.get_expert_recv_count_sum_ptr(dst_local_expert_idx), + expert_status); + } else { + *sym_buffer.map( + workspace.get_expert_recv_count_ptr(sym_buffer.rank_idx, dst_local_expert_idx), + dst_rank_idx) = expert_status & 0xffffffff; + ptx::atomic_add_sys( + sym_buffer.map(workspace.get_expert_recv_count_sum_ptr(dst_local_expert_idx), dst_rank_idx), + expert_status); + } + } + } + } + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + if constexpr (kLocalOnly or kNumRanks == 1) { + comm::grid_sync( + workspace, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }); + } else { + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }, + false, + true); + } + + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + uint32_t pull_mbarrier_phase = 0; + const auto pull_buffer = smem_send_buffers.get_rank_buffer(warp_idx).get_data_buffer(0); + const auto pull_mbarrier = dispatch_barriers[warp_idx]; + + scheduler.fetch_expert_recv_count(); + + const auto smem_per_rank_count = smem_expert_count; + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExpertsPerRank * kNumRanks; i += kNumDispatchThreads) + smem_per_rank_count[i] = static_cast( + *workspace.get_expert_recv_count_ptr(i % kNumRanks, i / kNumRanks)); + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + constexpr uint32_t kNumRanksPerLane = math::constexpr_ceil_div(kNumRanks, 32u); + int32_t current_expert_idx = -1; + uint32_t stored_rank_count[kNumRanksPerLane] = {}; + uint32_t expert_start_idx = 0, expert_end_idx = 0; + uint32_t expert_pool_block_offset = 0; + + constexpr uint32_t kNumGlobalWarps = kNumSMs * kNumDispatchWarps; + for (uint32_t token_idx = sm_idx * kNumDispatchWarps + warp_idx; ; token_idx += kNumGlobalWarps) { + int32_t old_expert_idx = current_expert_idx; + while (token_idx >= expert_end_idx) { + if (++current_expert_idx >= int32_t(kNumExpertsPerRank)) + break; + + expert_pool_block_offset += math::ceil_div(expert_end_idx - expert_start_idx, BLOCK_M); + expert_start_idx = expert_end_idx; + expert_end_idx += scheduler.get_num_tokens(static_cast(current_expert_idx)); + } + + if (current_expert_idx >= int32_t(kNumExpertsPerRank)) + break; + + if (old_expert_idx != current_expert_idx) { + old_expert_idx = current_expert_idx; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++i) { + const uint32_t j = i * 32 + lane_idx; + stored_rank_count[i] = j < kNumRanks + ? smem_per_rank_count[static_cast(current_expert_idx) * kNumRanks + j] : 0u; + } + } + + uint32_t current_rank_in_expert_idx, token_idx_in_rank; + utils::peel_forward(stored_rank_count, + token_idx - expert_start_idx, + current_rank_in_expert_idx, + token_idx_in_rank); + + const uint32_t src_token_topk_idx = *(pull_meta_base + + (static_cast(current_expert_idx) * kNumRanks + current_rank_in_expert_idx) + * kNumMaxTokensPerRank + + token_idx_in_rank); + const uint32_t src_token_idx = src_token_topk_idx / kNumTopk; + + uint32_t route_count = 0; + if (lane_idx == 0) { + const auto route_count_ptr = pull_meta_base + kRouteCountOffset + + current_rank_in_expert_idx * kNumMaxTokensPerRank + src_token_idx; + route_count = *route_count_ptr; + if (route_count > 1) + route_count = atomicExch(route_count_ptr, 0u); + } + route_count = __shfl_sync(0xffffffffu, route_count, 0); + if (route_count == 0) + continue; + + if (cute::elect_one_sync()) { + ptx::tma_load_1d( + pull_buffer.get_base_ptr(), + sym_buffer.map(input_token_buffer.get_data_buffer(src_token_idx).get_base_ptr(), + current_rank_in_expert_idx), + pull_mbarrier, kHidden); + } + __syncwarp(); + + constexpr uint32_t kNumSFUint32 = kHidden / 128; + constexpr uint32_t kNumSFUint32PerLane = math::constexpr_ceil_div(kNumSFUint32, 32u); + DG_STATIC_ASSERT(kNumSFUint32 > 0 and kHidden % 128 == 0, "Invalid SF"); + const auto remote_sf_ptr = sym_buffer.map( + input_sf_buffer.get_data_buffer(src_token_idx).get_base_ptr(), + current_rank_in_expert_idx); + uint32_t stored_sf[kNumSFUint32PerLane] = {}; + #pragma unroll + for (uint32_t i = 0; i < kNumSFUint32PerLane; ++i) { + const uint32_t j = i * 32 + lane_idx; + if (j < kNumSFUint32) + stored_sf[i] = remote_sf_ptr[j]; + } + float stored_weight = 0.0f; + if (lane_idx < kNumTopk) + stored_weight = *sym_buffer.map( + input_topk_weights_buffer.get_base_ptr() + src_token_idx * kNumTopk + lane_idx, + current_rank_in_expert_idx); + + if (cute::elect_one_sync()) { + ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden); + ptx::mbarrier_wait_and_flip_phase(pull_mbarrier, pull_mbarrier_phase); + } + __syncwarp(); + + const auto local_sf_ptr = l1_sf_buffer.get_base_ptr(); + if (route_count == 1) { + const uint32_t route_topk_slot_idx = src_token_topk_idx - src_token_idx * kNumTopk; + const uint32_t route_token_idx_in_expert = token_idx - expert_start_idx; + const uint32_t pool_token_idx = expert_pool_block_offset * BLOCK_M + route_token_idx_in_expert; + const uint32_t sf_pool_token_idx = expert_pool_block_offset * SF_BLOCK_M + + transform_sf_token_idx(route_token_idx_in_expert); + const uint32_t stored_pool_block_idx = expert_pool_block_offset + route_token_idx_in_expert / BLOCK_M; + const float weight = __shfl_sync(0xffffffffu, stored_weight, route_topk_slot_idx); + __syncwarp(); + if (cute::elect_one_sync()) { + ptx::tma_store_1d( + l1_token_buffer.get_data_buffer(pool_token_idx).get_base_ptr(), + pull_buffer.get_base_ptr(), pull_buffer.get_num_bytes()); + *l1_topk_weights_buffer.get_data_buffer(pool_token_idx).get_base_ptr() = weight; + *workspace.get_token_src_metadata_ptr(pool_token_idx) = + {current_rank_in_expert_idx, src_token_idx, route_topk_slot_idx}; + cute::tma_store_arrive(); + } + + #pragma unroll + for (uint32_t i = 0; i < kNumSFUint32PerLane; ++i) { + const uint32_t j = i * 32 + lane_idx; + if (j < kNumSFUint32) + local_sf_ptr[j * kNumPaddedSFPoolTokens + sf_pool_token_idx] = stored_sf[i]; + } + __syncwarp(); + + if (cute::elect_one_sync()) { + ptx::tma_store_wait<0>(); + ptx::red_add_rel(workspace.get_l1_arrival_count_ptr(stored_pool_block_idx), 1); + } + __syncwarp(); + continue; + } + + uint32_t stored_pool_blocks[kNumTopk] = {}; + #pragma unroll + for (uint32_t route_idx = 0; route_idx < route_count; ++route_idx) { + const uint32_t packed = *(pull_meta_base + kRouteEntriesOffset + + (current_rank_in_expert_idx * kNumMaxTokensPerRank + src_token_idx) * kNumTopk + route_idx); + const uint32_t route_topk_slot_idx = packed & 0x3Fu; + const uint32_t route_local_expert_idx = (packed >> 6) & 0x3FFu; + const uint32_t route_token_idx_in_rank = packed >> 16; + + uint32_t route_token_idx_in_expert = token_idx - expert_start_idx; + if (route_local_expert_idx != static_cast(current_expert_idx) + or route_token_idx_in_rank != token_idx_in_rank) { + uint32_t route_rank_counts[kNumRanksPerLane]; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++i) { + const uint32_t j = i * 32 + lane_idx; + route_rank_counts[i] = j < kNumRanks + ? smem_per_rank_count[route_local_expert_idx * kNumRanks + j] : 0u; + } + route_token_idx_in_expert = utils::peel_inverse( + route_rank_counts, current_rank_in_expert_idx, route_token_idx_in_rank); + } + + const uint32_t route_pool_block_offset = scheduler.get_pool_block_offset(route_local_expert_idx); + const uint32_t pool_token_idx = route_pool_block_offset * BLOCK_M + route_token_idx_in_expert; + const uint32_t sf_pool_token_idx = route_pool_block_offset * SF_BLOCK_M + + transform_sf_token_idx(route_token_idx_in_expert); + stored_pool_blocks[route_idx] = route_pool_block_offset + route_token_idx_in_expert / BLOCK_M; + const float weight = __shfl_sync(0xffffffffu, stored_weight, route_topk_slot_idx); + __syncwarp(); + if (cute::elect_one_sync()) { + ptx::tma_store_1d( + l1_token_buffer.get_data_buffer(pool_token_idx).get_base_ptr(), + pull_buffer.get_base_ptr(), pull_buffer.get_num_bytes()); + *l1_topk_weights_buffer.get_data_buffer(pool_token_idx).get_base_ptr() = weight; + *workspace.get_token_src_metadata_ptr(pool_token_idx) = + {current_rank_in_expert_idx, src_token_idx, route_topk_slot_idx}; + cute::tma_store_arrive(); + } + + #pragma unroll + for (uint32_t i = 0; i < kNumSFUint32PerLane; ++i) { + const uint32_t j = i * 32 + lane_idx; + if (j < kNumSFUint32) + local_sf_ptr[j * kNumPaddedSFPoolTokens + sf_pool_token_idx] = stored_sf[i]; + } + __syncwarp(); + + if (cute::elect_one_sync()) { + ptx::tma_store_wait<0>(); + ptx::red_add_rel(workspace.get_l1_arrival_count_ptr(stored_pool_blocks[route_idx]), 1); + } + } + + __syncwarp(); + } + + if (sm_idx == 0 and cumulative_local_expert_recv_stats != nullptr) { + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExpertsPerRank; i += kNumDispatchThreads) { + const auto num_recv_tokens = static_cast( + *workspace.get_expert_recv_count_sum_ptr(i)); + ptx::red_add(cumulative_local_expert_recv_stats + i, static_cast(num_recv_tokens)); + } + } + } else if (warp_idx == kNumDispatchWarps) { + cutlass::arch::warpgroup_reg_dealloc(); + + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + (void) local_expert_idx; + (void) n_block_idx; + if (block_phase != sched::BlockPhase::Linear1) + return; + + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx); + const auto expected = scheduler.template get_valid_m(); + while (ptx::ld_acq(ptr) != expected); + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + empty_barriers[stage_idx]->wait(phase ^ 1); + + uint32_t m_idx = pool_block_idx * BLOCK_M; + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t sfa_m_idx = pool_block_idx * SF_BLOCK_M; + uint32_t sfa_k_idx = k_block_idx * (BLOCK_K / 128); + if (not is_leader_cta) + m_idx += scheduler.template get_valid_m() / 2; + + if (cute::elect_one_sync()) { + tma::copy( + &tensor_map_l1_acts, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx, 2); + tma::copy( + &tensor_map_l1_acts_sf, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx, 2); + if (is_leader_cta) { + full_barriers[stage_idx]->arrive_and_expect_tx( + SMEM_A_SIZE_PER_STAGE * 2 + SMEM_SFA_SIZE_PER_STAGE * 2); + } else { + full_barriers[stage_idx]->arrive(0u); + } + } + __syncwarp(); + } + }); + } else if (warp_idx == kNumDispatchWarps + 1) { + cutlass::arch::warpgroup_reg_dealloc(); + + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + (void) m_block_idx; + if (block_phase != sched::BlockPhase::Linear1) + return; + + const auto shape_k = L1_SHAPE_K; + const auto shape_n = L1_SHAPE_N; + const auto shape_sfb_k = math::ceil_div(shape_k, kGranK * 4u); + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + empty_barriers[stage_idx]->wait(phase ^ 1); + + uint32_t n_idx = local_expert_idx * shape_n + n_block_idx * BLOCK_N; + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t sfb_n_idx = n_block_idx * BLOCK_N; + uint32_t sfb_k_idx = local_expert_idx * shape_sfb_k + k_block_idx * (BLOCK_K / 128); + + if (cute::elect_one_sync()) { + tma::copy( + &tensor_map_l1_weights, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx, 2); + tma::copy( + &tensor_map_l1_weights_sf, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx, 2); + if (is_leader_cta) { + full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_B_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE * 2); + } else { + full_barriers[stage_idx]->arrive(0u); + } + } + __syncwarp(); + } + }); + } else if (warp_idx == kNumDispatchWarps + 2) { + cutlass::arch::warpgroup_reg_dealloc(); + + if (is_leader_cta) { + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled< + b_dtype_t, a_dtype_t, float, cutlass::float_ue8m0_t, + UMMA_M, UMMA_N, + cute::UMMA::Major::K, cute::UMMA::Major::K>(); + auto sf_desc = mma::sm100::make_sf_desc(nullptr); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + uint32_t current_iter_idx = 0; + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + (void) local_expert_idx; + (void) m_block_idx; + (void) n_block_idx; + if (block_phase != sched::BlockPhase::Linear1) + return; + + mma::sm100::update_instr_desc_with_umma_n(instr_desc, scheduler.template get_valid_m()); + + const auto accum_stage_idx = current_iter_idx % kNumEpilogueStages; + const auto accum_phase = (current_iter_idx++ / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase ^ 1); + ptx::tcgen05_after_thread_sync(); + + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + constexpr uint16_t kCTAMask = (1 << 2) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + }; + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + __syncwarp(); + }; + + #pragma unroll 2 + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + full_barriers[stage_idx]->wait(phase); + ptx::tcgen05_after_thread_sync(); + + const auto a_desc_base_lo = ptx::exchange(a_desc_lo, stage_idx); + const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx); + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t umma_k_block_idx = 0; umma_k_block_idx < BLOCK_K / UMMA_BLOCK_K; ++umma_k_block_idx) { + using cute_utccp_t = cute::SM100_UTCCP_4x32dp128bit_2cta; + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++i) { + auto smem_ptr = smem_sfa[stage_idx] + umma_k_block_idx * SF_BLOCK_M + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); + } + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++i) { + auto smem_ptr = smem_sfb[stage_idx] + umma_k_block_idx * SF_BLOCK_N + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); + } + + #pragma unroll + for (uint32_t k = 0; k < UMMA_BLOCK_K / UMMA_K; ++k) { + const auto runtime_instr_desc = + mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k, k); + a_desc.lo = mma::sm100::advance_umma_desc_lo< + cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>( + a_desc_base_lo, umma_k_block_idx * UMMA_BLOCK_K * LOAD_BLOCK_M * sizeof(a_dtype_t), k * UMMA_K); + b_desc.lo = mma::sm100::advance_umma_desc_lo< + cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>( + b_desc_base_lo, umma_k_block_idx * UMMA_BLOCK_K * LOAD_BLOCK_N * sizeof(b_dtype_t), k * UMMA_K); + ptx::SM100_MMA_MXF8F6F4_2x1SM_SS::fma( + b_desc, a_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or umma_k_block_idx > 0 or k > 0, runtime_instr_desc, + kTmemStartColOfSFB, kTmemStartColOfSFA); + } + } + } + __syncwarp(); + + empty_barrier_arrive(k_block_idx == num_k_blocks - 1); + } + }); + + if (current_iter_idx > 0) { + const auto accum_phase_idx = ((current_iter_idx - 1) / kNumEpilogueStages) & 1; + tmem_empty_barriers[(current_iter_idx - 1) % kNumEpilogueStages]->wait(accum_phase_idx); + } + } + } else if (warp_idx == kNumDispatchWarps + 3) { + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx >= kNumDispatchWarps + kNumMMANonEpilogueWarps) { + cutlass::arch::warpgroup_reg_alloc(); + + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + const auto epilogue_warp_idx = warp_idx - (kNumDispatchWarps + kNumMMANonEpilogueWarps); + const auto epilogue_wg_idx = epilogue_warp_idx / 4; + const auto warp_idx_in_wg = epilogue_warp_idx % 4; + DG_STATIC_ASSERT((kNumDispatchWarps + kNumMMANonEpilogueWarps) % 4 == 0 and + kNumEpilogueWarps % 4 == 0, "Invalid epilogue warps"); + + constexpr uint32_t WG_BLOCK_M = BLOCK_M / kNumEpilogueWarpgroups; + constexpr uint32_t ATOM_M = 8; + constexpr uint32_t kNumBankGroupBytes = 16u; + constexpr uint32_t kNumAtomsPerStore = STORE_BLOCK_M / ATOM_M; + DG_STATIC_ASSERT(BLOCK_M % kNumEpilogueWarpgroups == 0, "Invalid block M"); + DG_STATIC_ASSERT(WG_BLOCK_M % STORE_BLOCK_M == 0, "Invalid warpgroup block M"); + DG_STATIC_ASSERT(STORE_BLOCK_M % ATOM_M == 0, "Invalid store block M"); + DG_STATIC_ASSERT(BLOCK_N == 128, "Invalid block N"); + + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + uint32_t current_iter_idx = 0; + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + (void) local_expert_idx; + (void) num_k_blocks; + if (block_phase != sched::BlockPhase::Linear1) + return; + + const auto accum_stage_idx = current_iter_idx % kNumEpilogueStages; + const auto accum_phase = (current_iter_idx++ / kNumEpilogueStages) & 1; + tmem_full_barriers[accum_stage_idx]->wait(accum_phase); + ptx::tcgen05_after_thread_sync(); + + const uint32_t valid_m = ptx::exchange(scheduler.template get_valid_m(), 0); + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + uint32_t m_idx = pool_block_idx * BLOCK_M; + + float stored_cached_weight = 0; + #pragma unroll + for (uint32_t s = 0; s < WG_BLOCK_M / STORE_BLOCK_M; ++s) { + if (epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M >= valid_m) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + break; + } + + float2 swiglu_values[kNumAtomsPerStore * 2]; + float2 amax_values[kNumAtomsPerStore]; + #pragma unroll + for (uint32_t i = 0; i < kNumAtomsPerStore; ++i) { + const uint32_t j = s * kNumAtomsPerStore + i; + + DG_STATIC_ASSERT(32 % ATOM_M == 0, "Invalid block size"); + if ((j * ATOM_M) % 32 == 0 and (WG_BLOCK_M % 32 == 0 or j * ATOM_M + lane_idx < WG_BLOCK_M)) { + stored_cached_weight = *l1_topk_weights_buffer + .get_data_buffer(m_idx + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M + lane_idx) + .get_base_ptr(); + } + + const float2 weights = { + ptx::exchange(stored_cached_weight, (j * ATOM_M) % 32 + (lane_idx % 4) * 2 + 0), + ptx::exchange(stored_cached_weight, (j * ATOM_M) % 32 + (lane_idx % 4) * 2 + 1) + }; + + uint32_t tmem_addr = accum_stage_idx * UMMA_N + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M; + uint32_t values[ATOM_M]; + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000, + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + + if (j == WG_BLOCK_M / ATOM_M - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + + auto fp32_values = reinterpret_cast(values); + #pragma unroll + for (uint32_t k = 0; k < 2; ++k) { + auto bf16_gate = __float22bfloat162_rn(make_float2(fp32_values[k * 4], fp32_values[k * 4 + 1])); + auto bf16_up = __float22bfloat162_rn(make_float2(fp32_values[k * 4 + 2], fp32_values[k * 4 + 3])); + + if constexpr (kActivationClamp != cute::numeric_limits::infinity()) { + bf16_gate = __hmin2(bf16_gate, {kActivationClamp, kActivationClamp}); + bf16_up = __hmax2(bf16_up, {-kActivationClamp, -kActivationClamp}); + bf16_up = __hmin2(bf16_up, {kActivationClamp, kActivationClamp}); + } + + auto gate = __bfloat1622float2(bf16_gate); + auto neg_gate_exp = make_float2( + kFastMath ? __expf(-gate.x) : expf(-gate.x), + kFastMath ? __expf(-gate.y) : expf(-gate.y)); + const auto denom = __fadd2_rn({1.0f, 1.0f}, neg_gate_exp); + if constexpr (kFastMath) { + gate = __fmul2_rn(gate, {math::fast_rcp(denom.x), math::fast_rcp(denom.y)}); + } else { + gate = {gate.x / denom.x, gate.y / denom.y}; + } + const auto up = __bfloat1622float2(bf16_up); + swiglu_values[i * 2 + k] = __fmul2_rn(__fmul2_rn(gate, up), weights); + } + + amax_values[i].x = math::warp_reduce<4, true>( + cute::max(cute::abs(swiglu_values[i * 2 + 0].x), cute::abs(swiglu_values[i * 2 + 1].x)), + math::ReduceMax()); + amax_values[i].y = math::warp_reduce<4, true>( + cute::max(cute::abs(swiglu_values[i * 2 + 0].y), cute::abs(swiglu_values[i * 2 + 1].y)), + math::ReduceMax()); + if (lane_idx < 4) + smem_amax_reduction[epilogue_warp_idx * (STORE_BLOCK_M / 2) + i * (ATOM_M / 2) + lane_idx] = amax_values[i]; + __syncwarp(); + } + + const uint32_t tma_stage_idx = s % kNumTMAStoreStages; + ptx::tma_store_wait(); + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + #pragma unroll + for (uint32_t i = 0; i < kNumAtomsPerStore; ++i) { + const float2 wp_amax = + smem_amax_reduction[(epilogue_warp_idx ^ 1) * (STORE_BLOCK_M / 2) + i * (ATOM_M / 2) + lane_idx % 4]; + amax_values[i].x = cute::max(amax_values[i].x, wp_amax.x); + amax_values[i].y = cute::max(amax_values[i].y, wp_amax.y); + + float2 sf, sf_inv; + math::get_e4m3_sf_and_sf_inv(amax_values[i], sf, sf_inv); + + const float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv); + const float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv); + const auto fp8x4_values = __nv_fp8x4_e4m3(make_float4(upper.x, upper.y, lower.x, lower.y)); + + uint32_t row = lane_idx; + uint32_t col = warp_idx_in_wg; + const auto smem_ptr = smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N + + i * ATOM_M * L1_OUT_BLOCK_N + + row * L1_OUT_BLOCK_N + + (col ^ (row / 2)) * kNumBankGroupBytes; + ptx::SM100_U8x4_STSM_T<__nv_fp8x4_e4m3>::copy(fp8x4_values, smem_ptr); + + if (warp_idx_in_wg % 2 == 0 and lane_idx < 4) { + const uint32_t k_idx = n_block_idx * 2 + warp_idx_in_wg / 2; + const uint32_t k_uint_idx = k_idx / 4, byte_idx = k_idx % 4; + const uint32_t mn_stride = kNumPaddedSFPoolTokens * sizeof(uint32_t); + const auto sf_base_ptr = l2_sf_buffer.get_base_ptr(); + const uint32_t token_base_idx = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M; + __builtin_assume(token_base_idx < BLOCK_M); + const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M + + m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4; + const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * static_cast(sizeof(uint32_t)) + byte_idx; + sf_base_ptr[sf_addr] = + (*reinterpret_cast(&sf.x) >> 23); + sf_base_ptr[sf_addr + 4 * static_cast(sizeof(uint32_t))] = + (*reinterpret_cast(&sf.y) >> 23); + } + __syncwarp(); + } + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + if (warp_idx_in_wg == 0 and cute::elect_one_sync()) { + uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; + cute::tma_store_fence(); + cute::SM90_TMA_STORE_2D::copy( + &tensor_map_l1_output, + smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N, + out_n_idx, + m_idx + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M); + cute::tma_store_arrive(); + } + __syncwarp(); + } + + ptx::tma_store_wait<0>(); + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + DG_STATIC_ASSERT(L2_SHAPE_K <= 64 * L1_OUT_BLOCK_N, "L2 shape K is too large"); + ptx::red_or_rel_gpu( + workspace.get_l2_arrival_mask_ptr(pool_block_idx), + 1ull << n_block_idx); + } + __syncwarp(); + }); + + if (epilogue_warp_idx == 0) + Allocator().free(0, kNumTmemCols); + } +#endif +} + +} // namespace deep_gemm::mega_moe_split diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe_split/l2_combine.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe_split/l2_combine.cuh new file mode 100644 index 0000000000..27193b6305 --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe_split/l2_combine.cuh @@ -0,0 +1,660 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm::mega_moe_split { + +template < + uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, + uint32_t L2_SHAPE_N, uint32_t L2_SHAPE_K, + uint32_t kNumExpertsPerRank, + uint32_t kKernel1SMs, uint32_t kKernel2SMs, uint32_t kNumRanks, + uint32_t kNumExpertsPerLane = math::constexpr_ceil_div(kNumExpertsPerRank, 32u), + uint32_t kNumL2BlockNs = L2_SHAPE_N / BLOCK_N, + uint32_t kNumL2BlockKs = L2_SHAPE_K / BLOCK_K +> +struct Kernel2L2Scheduler { + DG_STATIC_ASSERT(L2_SHAPE_N % BLOCK_N == 0, "Invalid L2 N shape"); + DG_STATIC_ASSERT(L2_SHAPE_K % BLOCK_K == 0, "Invalid L2 K shape"); + DG_STATIC_ASSERT(kKernel1SMs % 2 == 0 and kKernel2SMs % 2 == 0, "Invalid SM split"); + DG_STATIC_ASSERT(kNumL2BlockNs % 2 == 0, "L2 N block count must be even for 2-CTA cluster"); + + const layout::SplitWorkspace& workspace; + uint32_t block_idx = 0; + uint32_t current_local_expert_idx = 0; + uint32_t current_num_tokens = 0; + uint32_t current_pool_block_offset = 0; + uint32_t m_block_idx = 0; + uint32_t n_block_idx = 0; + uint32_t stored_num_tokens_per_expert[kNumExpertsPerLane] = {}; + + CUTLASS_DEVICE explicit Kernel2L2Scheduler(const layout::SplitWorkspace& workspace): workspace(workspace) { + block_idx = blockIdx.x; + } + + CUTLASS_DEVICE uint32_t get_num_tokens(const uint32_t& expert_idx) const { + uint32_t valid_value = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumExpertsPerLane; ++i) { + valid_value = (expert_idx == i * 32 + ptx::get_lane_idx()) ? + stored_num_tokens_per_expert[i] : valid_value; + } + return ptx::exchange(valid_value, expert_idx % 32); + } + + CUTLASS_DEVICE uint32_t get_pool_block_offset(const uint32_t& expert_idx) { + uint32_t num_blocks = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumExpertsPerLane; ++i) { + if (i * 32 + ptx::get_lane_idx() < expert_idx) + num_blocks += math::ceil_div(stored_num_tokens_per_expert[i], BLOCK_M); + } + return __reduce_add_sync(0xffffffffu, num_blocks); + } + + CUTLASS_DEVICE uint32_t get_current_pool_block_offset() const { + return current_pool_block_offset; + } + + CUTLASS_DEVICE uint32_t get_current_num_m_blocks() const { + return math::ceil_div(current_num_tokens, BLOCK_M); + } + + template + CUTLASS_DEVICE uint32_t get_valid_m() const { + const auto m = cute::min(current_num_tokens - m_block_idx * BLOCK_M, BLOCK_M); + return kDoUMMAAligned ? math::align(m, 16u) : m; + } + + CUTLASS_DEVICE void fetch_expert_recv_count() { + #pragma unroll + for (uint32_t i = 0; i < kNumExpertsPerLane; ++i) { + const auto expert_idx = i * 32 + ptx::get_lane_idx(); + uint64_t value = 0; + if (expert_idx < kNumExpertsPerRank) { + do { + value = ptx::ld_volatile(workspace.get_expert_recv_count_sum_ptr(expert_idx)); + } while (static_cast(value >> 32) != kKernel1SMs * kNumRanks); + } + stored_num_tokens_per_expert[i] = static_cast(value); + } + __syncwarp(); + } + + CUTLASS_DEVICE void set_expert_idx(const uint32_t& expert_idx) { + current_local_expert_idx = expert_idx; + current_num_tokens = get_num_tokens(expert_idx); + current_pool_block_offset = get_pool_block_offset(expert_idx); + } + + CUTLASS_DEVICE void advance_expert_idx() { + current_pool_block_offset += get_current_num_m_blocks(); + current_local_expert_idx += 1; + current_num_tokens = get_num_tokens(current_local_expert_idx); + } + + CUTLASS_DEVICE bool fetch_next_block() { + while (current_local_expert_idx < kNumExpertsPerRank) { + const auto num_m_blocks = get_current_num_m_blocks(); + const auto num_expert_blocks = num_m_blocks * kNumL2BlockNs; + if (block_idx < num_expert_blocks) { + m_block_idx = block_idx / kNumL2BlockNs; + n_block_idx = block_idx - m_block_idx * kNumL2BlockNs; + block_idx += kKernel2SMs; + return true; + } + block_idx -= num_expert_blocks; + advance_expert_idx(); + } + return false; + } + + template + CUTLASS_DEVICE void for_each_block(Func&& func) { + fetch_expert_recv_count(); + set_expert_idx(0); + while (fetch_next_block()) { + func(current_local_expert_idx, kNumL2BlockKs, m_block_idx, n_block_idx); + } + } +}; + +template < + uint32_t kNumMaxTokensPerRank, + uint32_t kHidden, uint32_t kIntermediateHidden, + uint32_t kNumExperts, uint32_t kNumTopk, + uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, + uint32_t STORE_BLOCK_M, + uint32_t SF_BLOCK_M, uint32_t SF_BLOCK_N, + uint32_t kNumMaxPoolTokens, + uint32_t kNumPaddedSFPoolTokens, + uint32_t kNumStages, + uint32_t kNumNonEpilogueThreads, + uint32_t kNumEpilogueThreads, + uint32_t kKernel1SMs, uint32_t kKernel2SMs, uint32_t kNumRanks, + uint32_t L2_SHAPE_N = kHidden, + uint32_t L2_SHAPE_K = kIntermediateHidden, + uint32_t kNumMMANonEpilogueWarps = kNumNonEpilogueThreads / 32, + uint32_t kNumEpilogueWarps = kNumEpilogueThreads / 32, + uint32_t kNumEpilogueWarpgroups = kNumEpilogueWarps / 4, + uint32_t kNumThreads = kNumNonEpilogueThreads + kNumEpilogueThreads, + uint32_t kNumExpertsPerRank = kNumExperts / kNumRanks +> +CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void +sm100_fp8_fp4_mega_moe_split_l2_combine_impl( + uint32_t* state, + const uint32_t num_work_iters, + const __grid_constant__ layout::SymBuffer sym_buffer, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights_sf +) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::TMEM::Allocator2Sm; + + DG_STATIC_ASSERT(kNumNonEpilogueThreads == 128, "Invalid number of MMA non-epilogue threads"); + DG_STATIC_ASSERT(kNumEpilogueThreads % 128 == 0, "Invalid number of epilogue threads"); + DG_STATIC_ASSERT(kNumExperts % kNumRanks == 0, "Invalid number of experts or ranks"); + DG_STATIC_ASSERT(kNumMMANonEpilogueWarps == 4, "K2 expects four non-epilogue MMA warps"); + + const bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const uint32_t sm_idx = blockIdx.x; + const uint32_t thread_idx = threadIdx.x; + const uint32_t warp_idx = cutlass::canonical_warp_idx_sync(); + const uint32_t lane_idx = ptx::get_lane_idx(); + (void) num_work_iters; + + if (warp_idx == 0) { + cute::prefetch_tma_descriptor(&tensor_map_l2_acts); + cute::prefetch_tma_descriptor(&tensor_map_l2_acts_sf); + cute::prefetch_tma_descriptor(&tensor_map_l2_weights); + cute::prefetch_tma_descriptor(&tensor_map_l2_weights_sf); + } + + const auto workspace = layout::SplitWorkspace( + sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk); + + const auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16)); + const auto fp8_intermediate_token_layout = layout::Data(kIntermediateHidden); + const auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 32); + const auto fp8_token_layout = layout::Data(kHidden); + const auto fp8_sf_layout = layout::Data(kHidden / 32); + const auto input_topk_idx_layout = layout::Data(kNumTopk * sizeof(int64_t), false); + const auto input_topk_weights_layout = layout::Data(kNumTopk * sizeof(float), false); + const auto l1_topk_weights_layout = layout::Data(sizeof(float), false); + + const auto input_token_buffer = layout::Buffer( + fp8_token_layout, 1, kNumMaxTokensPerRank, + workspace.get_end_ptr()); + const auto input_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, kNumMaxTokensPerRank, + input_token_buffer.get_end_ptr()); + const auto input_topk_idx_buffer = layout::Buffer( + input_topk_idx_layout, 1, kNumMaxTokensPerRank, + input_sf_buffer.get_end_ptr()); + const auto input_topk_weights_buffer = layout::Buffer( + input_topk_weights_layout, 1, kNumMaxTokensPerRank, + input_topk_idx_buffer.get_end_ptr()); + const auto l1_token_buffer = layout::Buffer( + fp8_token_layout, 1, kNumMaxPoolTokens, + input_topk_weights_buffer.get_end_ptr()); + const auto l1_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, kNumPaddedSFPoolTokens, + l1_token_buffer.get_end_ptr()); + const auto l1_topk_weights_buffer = layout::Buffer( + l1_topk_weights_layout, 1, kNumMaxPoolTokens, + l1_sf_buffer.get_end_ptr()); + const auto l2_token_buffer = layout::Buffer( + fp8_intermediate_token_layout, 1, kNumMaxPoolTokens, + l1_topk_weights_buffer.get_end_ptr()); + const auto l2_sf_buffer = layout::Buffer( + fp8_intermediate_sf_layout, 1, kNumPaddedSFPoolTokens, + l2_token_buffer.get_end_ptr()); + const auto combine_token_buffer = layout::Buffer( + bf16_token_layout, kNumTopk, kNumMaxTokensPerRank, + l2_sf_buffer.get_end_ptr()); + + using a_dtype_t = cutlass::float_e4m3_t; + using b_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t; + + constexpr uint32_t kGranK = 32; + constexpr uint32_t kNumUTCCPAlignedElems = 128; + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t UMMA_M = LAYOUT_AD_M * 2; + constexpr uint32_t UMMA_N = BLOCK_M; + constexpr uint32_t UMMA_BLOCK_K = 128; + constexpr uint32_t UMMA_K = 32; + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / 2; + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N; + constexpr uint32_t kSwizzleAMode = 128; + constexpr uint32_t kSwizzleBMode = 128; + constexpr uint32_t kSwizzleCDMode = 128; + constexpr uint32_t kNumEpilogueStages = 2; + constexpr uint32_t kSharedMemoryAlignment = 1024; + DG_STATIC_ASSERT(BLOCK_M % 16 == 0, "Invalid block M"); + DG_STATIC_ASSERT(BLOCK_N == LAYOUT_AD_M, "Invalid block N"); + DG_STATIC_ASSERT(BLOCK_N % kSwizzleCDMode == 0, "Invalid block N"); + DG_STATIC_ASSERT(SF_BLOCK_M == math::constexpr_align(BLOCK_M, kNumUTCCPAlignedElems), "Invalid SF_BLOCK_M"); + DG_STATIC_ASSERT(SF_BLOCK_N == BLOCK_N, "Invalid SF_BLOCK_N"); + + extern __shared__ __align__(kSharedMemoryAlignment) uint8_t smem_buffer[]; + constexpr uint32_t SMEM_CD_L2_SIZE = + kNumEpilogueWarpgroups * STORE_BLOCK_M * BLOCK_N * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t) * (BLOCK_K / 128); + constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t) * (BLOCK_K / 128); + DG_STATIC_ASSERT(SMEM_CD_L2_SIZE % kSharedMemoryAlignment == 0 and + SMEM_A_SIZE_PER_STAGE % kSharedMemoryAlignment == 0 and + SMEM_B_SIZE_PER_STAGE % kSharedMemoryAlignment == 0, + "Shared memory of CD/A/B must be aligned to 1024 bytes"); + + constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages; + constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32; + constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32; + constexpr uint32_t kNumTmemCols = + utils::get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + auto smem_gemm_base = smem_buffer; + auto smem_cd_l2 = smem_gemm_base; + auto smem_a = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr(smem_gemm_base, SMEM_CD_L2_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr( + smem_gemm_base, SMEM_CD_L2_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + auto sf_start_ptr = math::advance_ptr( + smem_gemm_base, SMEM_CD_L2_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto smem_sfa = utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE); + }); + + auto barrier_start_ptr = reinterpret_cast(smem_sfb[kNumStages]); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + i; }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; }); + auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { + return barrier_start_ptr + kNumStages * 2 + i; + }); + auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { + return barrier_start_ptr + kNumStages * 2 + kNumEpilogueStages + i; + }); + auto tmem_ptr_in_smem = reinterpret_cast( + barrier_start_ptr + kNumStages * 2 + kNumEpilogueStages * 2); + + comm::cluster_sync_with_relaxed_arrive(); + if (warp_idx == 0) { + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++i) { + full_barriers[i]->init(2 * 2); + empty_barriers[i]->init(1); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++i) { + tmem_full_barriers[i]->init(1); + tmem_empty_barriers[i]->init(2 * kNumEpilogueThreads); + } + } + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 3) { + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + comm::cluster_sync_with_relaxed_arrive(); + + auto scheduler = Kernel2L2Scheduler< + BLOCK_M, BLOCK_N, BLOCK_K, + L2_SHAPE_N, L2_SHAPE_K, + kNumExpertsPerRank, + kKernel1SMs, kKernel2SMs, kNumRanks>(workspace); + + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++k_block_idx; + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + constexpr uint32_t kEpilogueFullBarrierIdx = 0; + constexpr uint32_t kEpilogueWGBarrierStartIdx = 1; + constexpr uint32_t kBeforeCombineReduceBarrierTag = 2; + constexpr uint32_t kEpilogueGridSyncIndex = 1; + + constexpr uint32_t kNumNonEpilogueRegisters = 40; + constexpr uint32_t kNumEpilogueRegisters = 208; + DG_STATIC_ASSERT(kNumNonEpilogueRegisters * kNumNonEpilogueThreads + + kNumEpilogueRegisters * kNumEpilogueThreads <= 64512, + "Too many registers"); + + if (warp_idx < kNumMMANonEpilogueWarps) { + cutlass::arch::warpgroup_reg_dealloc(); + } + + if (warp_idx == 0) { + scheduler.for_each_block([&](const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + (void) local_expert_idx; + (void) n_block_idx; + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + + DG_STATIC_ASSERT(BLOCK_K % BLOCK_N == 0, "Invalid block sizes"); + constexpr uint32_t kShiftAmount = (L2_SHAPE_K / BLOCK_N) * 2; + DG_STATIC_ASSERT(kShiftAmount <= 64, "Too many L1 output blocks for mask"); + constexpr uint64_t kExpectedMask = kShiftAmount == 64 + ? 0xffffffffffffffffull + : ((1ull << kShiftAmount) - 1ull); + while (ptx::ld_acq_gpu(workspace.get_l2_arrival_mask_ptr(pool_block_idx)) != kExpectedMask) + __nanosleep(64u); + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + empty_barriers[stage_idx]->wait(phase ^ 1); + + uint32_t m_idx = pool_block_idx * BLOCK_M; + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t sfa_m_idx = pool_block_idx * SF_BLOCK_M; + uint32_t sfa_k_idx = k_block_idx * (BLOCK_K / 128); + if (not is_leader_cta) + m_idx += scheduler.template get_valid_m() / 2; + + if (cute::elect_one_sync()) { + tma::copy( + &tensor_map_l2_acts, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx, 2); + tma::copy( + &tensor_map_l2_acts_sf, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx, 2); + if (is_leader_cta) { + full_barriers[stage_idx]->arrive_and_expect_tx( + SMEM_A_SIZE_PER_STAGE * 2 + SMEM_SFA_SIZE_PER_STAGE * 2); + } else { + full_barriers[stage_idx]->arrive(0u); + } + } + __syncwarp(); + } + }); + } else if (warp_idx == 1) { + scheduler.for_each_block([&](const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + (void) m_block_idx; + constexpr uint32_t shape_k = L2_SHAPE_K; + constexpr uint32_t shape_n = L2_SHAPE_N; + constexpr uint32_t shape_sfb_k = math::constexpr_ceil_div(shape_k, kGranK * 4u); + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + empty_barriers[stage_idx]->wait(phase ^ 1); + + uint32_t n_idx = local_expert_idx * shape_n + n_block_idx * BLOCK_N; + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t sfb_n_idx = n_block_idx * BLOCK_N; + uint32_t sfb_k_idx = local_expert_idx * shape_sfb_k + k_block_idx * (BLOCK_K / 128); + + if (cute::elect_one_sync()) { + tma::copy( + &tensor_map_l2_weights, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx, 2); + tma::copy( + &tensor_map_l2_weights_sf, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx, 2); + if (is_leader_cta) { + full_barriers[stage_idx]->arrive_and_expect_tx( + SMEM_B_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE * 2); + } else { + full_barriers[stage_idx]->arrive(0u); + } + } + __syncwarp(); + } + }); + } else if (warp_idx == 2) { + if (is_leader_cta) { + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled< + b_dtype_t, a_dtype_t, float, cutlass::float_ue8m0_t, + UMMA_M, UMMA_N, + cute::UMMA::Major::K, cute::UMMA::Major::K>(); + auto sf_desc = mma::sm100::make_sf_desc(nullptr); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = mma::sm100::make_umma_desc< + cute::UMMA::Major::K, LOAD_BLOCK_M, UMMA_BLOCK_K, kSwizzleAMode>(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc< + cute::UMMA::Major::K, LOAD_BLOCK_N, UMMA_BLOCK_K, kSwizzleBMode>(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + uint32_t current_iter_idx = 0; + scheduler.for_each_block([&](const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + (void) local_expert_idx; + (void) m_block_idx; + (void) n_block_idx; + mma::sm100::update_instr_desc_with_umma_n(instr_desc, scheduler.template get_valid_m()); + + const auto accum_stage_idx = current_iter_idx % kNumEpilogueStages; + const auto accum_phase = (current_iter_idx++ / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase ^ 1); + ptx::tcgen05_after_thread_sync(); + + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + constexpr uint16_t kCTAMask = (1 << 2) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + }; + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + __syncwarp(); + }; + + #pragma unroll 2 + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + full_barriers[stage_idx]->wait(phase); + ptx::tcgen05_after_thread_sync(); + + const auto a_desc_base_lo = ptx::exchange(a_desc_lo, stage_idx); + const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx); + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t umma_k_block_idx = 0; umma_k_block_idx < BLOCK_K / UMMA_BLOCK_K; ++umma_k_block_idx) { + using cute_utccp_t = cute::SM100_UTCCP_4x32dp128bit_2cta; + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++i) { + auto smem_ptr = smem_sfa[stage_idx] + umma_k_block_idx * SF_BLOCK_M + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); + } + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++i) { + auto smem_ptr = smem_sfb[stage_idx] + umma_k_block_idx * SF_BLOCK_N + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); + } + + #pragma unroll + for (uint32_t k = 0; k < UMMA_BLOCK_K / UMMA_K; ++k) { + const auto runtime_instr_desc = + mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k, k); + a_desc.lo = mma::sm100::advance_umma_desc_lo< + cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>( + a_desc_base_lo, umma_k_block_idx * UMMA_BLOCK_K * LOAD_BLOCK_M * sizeof(a_dtype_t), k * UMMA_K); + b_desc.lo = mma::sm100::advance_umma_desc_lo< + cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>( + b_desc_base_lo, umma_k_block_idx * UMMA_BLOCK_K * LOAD_BLOCK_N * sizeof(b_dtype_t), k * UMMA_K); + ptx::SM100_MMA_MXF8F6F4_2x1SM_SS::fma( + b_desc, a_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or umma_k_block_idx > 0 or k > 0, runtime_instr_desc, + kTmemStartColOfSFB, kTmemStartColOfSFA); + } + } + } + __syncwarp(); + empty_barrier_arrive(k_block_idx == num_k_blocks - 1); + } + }); + + if (current_iter_idx > 0) { + const auto accum_phase_idx = ((current_iter_idx - 1) / kNumEpilogueStages) & 1; + tmem_empty_barriers[(current_iter_idx - 1) % kNumEpilogueStages]->wait(accum_phase_idx); + } + } + } else if (warp_idx >= kNumMMANonEpilogueWarps) { + cutlass::arch::warpgroup_reg_alloc(); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + const auto epilogue_warp_idx = warp_idx - kNumMMANonEpilogueWarps; + const auto epilogue_wg_idx = epilogue_warp_idx / 4; + const auto epilogue_thread_idx = epilogue_warp_idx * 32 + lane_idx; + const auto warp_idx_in_wg = epilogue_warp_idx % 4; + DG_STATIC_ASSERT(kNumMMANonEpilogueWarps % 4 == 0 and + kNumEpilogueWarps % 4 == 0, "Invalid epilogue warps"); + + constexpr uint32_t WG_BLOCK_M = BLOCK_M / kNumEpilogueWarpgroups; + constexpr uint32_t ATOM_M = 8; + constexpr uint32_t kNumBankGroupBytes = 16u; + DG_STATIC_ASSERT(BLOCK_M % kNumEpilogueWarpgroups == 0, "Invalid block M"); + DG_STATIC_ASSERT(WG_BLOCK_M % STORE_BLOCK_M == 0, "Invalid warpgroup block M"); + DG_STATIC_ASSERT(STORE_BLOCK_M % ATOM_M == 0, "Invalid store block M"); + DG_STATIC_ASSERT(BLOCK_N == 128, "Invalid block N"); + DG_STATIC_ASSERT(STORE_BLOCK_M % 8 == 0, "Invalid store M"); + constexpr uint32_t kNumRowsPerWarp = STORE_BLOCK_M / 8; + + uint32_t current_iter_idx = 0; + scheduler.for_each_block([&](const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + (void) local_expert_idx; + (void) num_k_blocks; + const auto accum_stage_idx = current_iter_idx % kNumEpilogueStages; + const auto accum_phase = (current_iter_idx++ / kNumEpilogueStages) & 1; + tmem_full_barriers[accum_stage_idx]->wait(accum_phase); + ptx::tcgen05_after_thread_sync(); + + const uint32_t valid_m = ptx::exchange(scheduler.template get_valid_m(), 0); + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + uint32_t m_idx = pool_block_idx * BLOCK_M; + uint32_t n_idx = n_block_idx * BLOCK_N; + + #pragma unroll + for (uint32_t s = 0; s < WG_BLOCK_M / STORE_BLOCK_M; ++s) { + if (epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M >= valid_m) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + break; + } + + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_M / ATOM_M; ++i) { + uint32_t tmem_addr = + accum_stage_idx * UMMA_N + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M; + uint32_t values[ATOM_M]; + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000, + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + + if (i == 0 and s > 0) + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + if (s == WG_BLOCK_M / STORE_BLOCK_M - 1 and i == STORE_BLOCK_M / ATOM_M - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + + uint32_t row = lane_idx % 8; + uint32_t col = (epilogue_warp_idx % 2) * 4 + lane_idx / 8; + const auto smem_ptr = smem_cd_l2 + + epilogue_wg_idx * STORE_BLOCK_M * BLOCK_N * static_cast(sizeof(nv_bfloat16)) + + (warp_idx_in_wg / 2) * STORE_BLOCK_M * kSwizzleCDMode + + i * ATOM_M * kSwizzleCDMode + + row * (kNumBankGroupBytes * 8) + + (col ^ row) * kNumBankGroupBytes; + ptx::SM90_U32x4_STSM_T::copy( + math::cast_into_bf16_and_pack(values[0], values[1]), + math::cast_into_bf16_and_pack(values[2], values[3]), + math::cast_into_bf16_and_pack(values[4], values[5]), + math::cast_into_bf16_and_pack(values[6], values[7]), + smem_ptr); + } + + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + const uint32_t row_in_atom = (warp_idx_in_wg * 2 + lane_idx / 16) % ATOM_M; + const uint32_t bank_group_idx = lane_idx % 8; + + #pragma unroll + for (uint32_t j = 0; j < kNumRowsPerWarp; ++j) { + const uint32_t row_in_store = j * 8 + warp_idx_in_wg * 2 + lane_idx / 16; + const uint32_t m_idx_in_block = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + row_in_store; + if (m_idx_in_block >= valid_m) + break; + + const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + m_idx_in_block); + const uint32_t dst_rank_idx = src_metadata.rank_idx; + const uint32_t dst_token_idx = src_metadata.token_idx; + const uint32_t dst_topk_idx = src_metadata.topk_idx; + + const auto smem_ptr = smem_cd_l2 + + epilogue_wg_idx * STORE_BLOCK_M * BLOCK_N * static_cast(sizeof(nv_bfloat16)) + + (lane_idx % 16 / 8) * STORE_BLOCK_M * kSwizzleCDMode + + row_in_store * kSwizzleCDMode + + (bank_group_idx ^ row_in_atom) * kNumBankGroupBytes; + const auto packed = ptx::ld_shared(reinterpret_cast(smem_ptr)); + + const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx) + .get_data_buffer(dst_token_idx); + const auto dst_ptr = math::advance_ptr( + dst_token.get_base_ptr(), + n_idx * static_cast(sizeof(nv_bfloat16)) + + (lane_idx % 16) * static_cast(sizeof(float4))); + *sym_buffer.map(dst_ptr, dst_rank_idx) = packed; + } + } + + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) + atomicAdd(state + get_state_offset(SplitStateOffset::K2DoneTasks), 1u); + }); + + if (epilogue_warp_idx == 0) + Allocator().free(0, kNumTmemCols); + + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, epilogue_thread_idx, + [&]() { ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); }); + } + + if (thread_idx == 0) + atomicAdd(state + get_state_offset(SplitStateOffset::K2DoneBlocks), 1u); +#endif +} + +} // namespace deep_gemm::mega_moe_split diff --git a/deep_gemm/include/deep_gemm/layout/mega_moe_split.cuh b/deep_gemm/include/deep_gemm/layout/mega_moe_split.cuh new file mode 100644 index 0000000000..3ee7b9e84a --- /dev/null +++ b/deep_gemm/include/deep_gemm/layout/mega_moe_split.cuh @@ -0,0 +1,109 @@ +#pragma once + +#include + +#include +#include +#include + +namespace deep_gemm::layout { + +// Split-kernel MegaMoE workspace. +// +// The split pipeline (dispatch_l1_swiglu / l2_combine / combine_reduce) shares the fused-megamoe +// `Workspace` bookkeeping region (barriers, expert counts, L1/L2 arrival masks) but needs a +// route-based ("token pull shared") dispatch-metadata sub-layout for the K1 dispatch pull: a +// per-(expert, rank) source-token-topk slot sized by `num_max_tokens_per_rank`, plus per-token +// route-count and multi-route-entry regions. +// +// It derives from `Workspace` so it binds to the `const layout::Workspace&` parameters of the +// shared scheduler / comm helpers (which only touch the identical region), while the split +// kernels call the overridden dispatch accessors on the `SplitWorkspace` object directly. Only +// the dispatch-region accessors plus `get_num_bytes` / `get_end_ptr` (the buffer size, hence the +// pool base) are overridden; everything else (and the fused `Workspace`) is left untouched. +struct SplitWorkspace : public Workspace { + uint32_t num_topk; + + CUTLASS_HOST_DEVICE + SplitWorkspace(void* base, + const uint32_t& num_ranks, + const uint32_t& num_experts, + const uint32_t& num_max_tokens_per_rank, + const uint32_t& num_topk): + Workspace(base, num_ranks, num_experts, num_max_tokens_per_rank, num_topk), + num_topk(num_topk) {} + + CUTLASS_HOST_DEVICE + uint64_t get_num_bytes() const { + uint64_t num_bytes = 0; + + // Barrier + num_bytes += kNumBarrierSignalBytes; + + // Expert send/recv count + num_bytes += num_experts * sizeof(uint64_t) * 2; + + // Expert recv count sum + num_bytes += num_experts_per_rank * sizeof(uint64_t); + + // L1 arrival count (padded to even entry count for `uint64_t` alignment of L2 mask) + num_bytes += math::align(num_max_pool_blocks, 2u) * sizeof(uint32_t); + + // L2 block arrival mask + num_bytes += num_max_pool_blocks * sizeof(uint64_t); + + // Dispatch pulling source token-topk + num_bytes += num_experts_per_rank * num_ranks * num_max_tokens_per_rank * sizeof(int); + + // Dispatch pulling per-token route counts + num_bytes += num_ranks * num_max_tokens_per_rank * sizeof(int); + + // Dispatch pulling multi-route entries + num_bytes += num_ranks * num_max_tokens_per_rank * num_topk * sizeof(int); + + // Combine push source indices + num_bytes += num_max_pool_tokens * sizeof(TokenSrcMetadata); + + // Align to TMA descriptor requirements + num_bytes = math::align(num_bytes, 16); + return num_bytes; + } + + CUTLASS_HOST_DEVICE + void* get_end_ptr() const { + return math::advance_ptr(base, get_num_bytes()); + } + + // For dispatch pulling + CUTLASS_DEVICE + uint32_t* get_src_token_topk_idx_ptr( + const uint32_t& expert_idx = 0, const uint32_t& rank_idx = 0, const uint32_t& token_idx = 0) const { + const auto base = get_l2_arrival_mask_ptr(num_max_pool_blocks); + return reinterpret_cast(base) + + (expert_idx * num_ranks + rank_idx) * num_max_tokens_per_rank + token_idx; + } + + CUTLASS_DEVICE + uint32_t* get_src_route_count_ptr( + const uint32_t& rank_idx = 0, const uint32_t& token_idx = 0) const { + return get_src_token_topk_idx_ptr(num_experts_per_rank) + + rank_idx * num_max_tokens_per_rank + token_idx; + } + + CUTLASS_DEVICE + uint32_t* get_src_route_entry_ptr( + const uint32_t& rank_idx = 0, const uint32_t& token_idx = 0, const uint32_t& topk_idx = 0) const { + return get_src_route_count_ptr(num_ranks) + + (rank_idx * num_max_tokens_per_rank + token_idx) * num_topk + topk_idx; + } + + // For combine usages + CUTLASS_DEVICE + TokenSrcMetadata* get_token_src_metadata_ptr(const uint32_t& pool_token_idx = 0) const { + const auto base = reinterpret_cast( + get_src_route_entry_ptr(num_ranks, 0, 0)); + return base + pool_token_idx; + } +}; + +} // namespace deep_gemm::layout diff --git a/deep_gemm/include/deep_gemm/ptx/ld_st.cuh b/deep_gemm/include/deep_gemm/ptx/ld_st.cuh index b9bca55de9..00edf41198 100644 --- a/deep_gemm/include/deep_gemm/ptx/ld_st.cuh +++ b/deep_gemm/include/deep_gemm/ptx/ld_st.cuh @@ -166,6 +166,12 @@ CUTLASS_DEVICE uint64_t ld_volatile(const uint64_t* ptr) { return ret; } +CUTLASS_DEVICE void st_global_v4_u32(uint4* ptr, const uint4& value) { + asm volatile("st.global.v4.u32 [%0], {%1, %2, %3, %4};" + :: "l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w) + : "memory"); +} + CUTLASS_DEVICE uint32_t ld_acq(const uint32_t* ptr) { uint32_t ret; asm volatile("ld.acquire.gpu.global.b32 %0, [%1];" : "=r"(ret) : "l"(ptr)); diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index 703435d618..b55452715b 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -1,6 +1,6 @@ import torch import types -from typing import Tuple, Optional +from typing import List, Tuple, Optional from ..utils.math import align # noinspection PyBroadException @@ -20,7 +20,8 @@ def __init__(self, group: dist.ProcessGroup, num_max_tokens_per_rank: int, num_topk: int, hidden: int, intermediate_hidden: int, use_fp8_dispatch: bool = True, - activation: str = 'swiglu'): + activation: str = 'swiglu', + split: bool = False): self.group = group self.num_experts = num_experts self.num_max_tokens_per_rank = num_max_tokens_per_rank @@ -28,8 +29,12 @@ def __init__(self, group: dist.ProcessGroup, self.hidden = hidden self.intermediate_hidden = intermediate_hidden - # Allocate a symmetric buffer - num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_mega_moe( + # Allocate a symmetric buffer. The split-kernel pipeline reserves a slightly larger + # bookkeeping region (route-based dispatch metadata), so it has its own sizing; the + # input/pool/combine layout is identical, so inputs/outputs stay comparable to fused. + size_fn = (_C.get_symm_buffer_size_for_mega_moe_split if split + else _C.get_symm_buffer_size_for_mega_moe) + num_bytes, slice_input_buffers = size_fn( group.size(), num_experts, num_max_tokens_per_rank, num_topk, hidden, intermediate_hidden, @@ -130,3 +135,83 @@ def fp8_fp4_mega_moe(y: torch.Tensor, activation, activation_clamp, fast_math ) + + +def get_symm_buffer_for_mega_moe_split(group: dist.ProcessGroup, + num_experts: int, + num_max_tokens_per_rank: int, num_topk: int, + hidden: int, intermediate_hidden: int, + use_fp8_dispatch: bool = True, + activation: str = 'swiglu') -> SymmBuffer: + # Symmetric buffer for the split-kernel pipeline (route-based dispatch bookkeeping region). + num_max_tokens_per_rank = align(num_max_tokens_per_rank, _C.get_token_alignment_for_mega_moe()) + return SymmBuffer( + group, num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden, + use_fp8_dispatch, activation, + split=True + ) + + +class SM100FP8FP4MegaMoESplitGraph: + """CUDA graph for the split-kernel MoE forward. + + Wires three kernels into one graph using green contexts: dispatch_l1_swiglu (K1, gather + + Linear1 + SwiGLU) and l2_combine (K2, Linear2 + NVLink combine) run concurrently on disjoint + SM partitions (`kernel1_sms` / `kernel2_sms`), coupled through an in-HBM arrival mask; + combine_reduce (K3, top-k reduce) runs on `reduce_sms` after both. Requires CUDA Runtime 13.1+. + """ + def __init__(self, + states: List[torch.Tensor], + ys: List[torch.Tensor], + sym_buffers: List[SymmBuffer], + l1_weights: List[Tuple[torch.Tensor, torch.Tensor]], + l2_weights: List[Tuple[torch.Tensor, torch.Tensor]], + stats: List[torch.Tensor], + num_tokens: int, + activation_clamp: float, + fast_math: bool, + kernel1_sms: int, + kernel2_sms: int, + reduce_sms: int, + kernel2_work_iters: int, + reduce_work_iters: int): + assert len(states) == len(ys) == len(sym_buffers) == len(l1_weights) == len(l2_weights) == len(stats) + raw_sym_buffers = [buffer.buffer for buffer in sym_buffers] + sym_buffer_ptrs = [buffer.handle.buffer_ptrs for buffer in sym_buffers] + l1_weight_tensors = [weight_pair[0] for weight_pair in l1_weights] + l1_weight_sf_tensors = [weight_pair[1] for weight_pair in l1_weights] + l2_weight_tensors = [weight_pair[0] for weight_pair in l2_weights] + l2_weight_sf_tensors = [weight_pair[1] for weight_pair in l2_weights] + first_buffer = sym_buffers[0] + self._impl = _C.SM100FP8FP4MegaMoESplitGraph( + states, + ys, + raw_sym_buffers, + sym_buffer_ptrs, + l1_weight_tensors, + l1_weight_sf_tensors, + l2_weight_tensors, + l2_weight_sf_tensors, + stats, + first_buffer.group.rank(), + first_buffer.num_max_tokens_per_rank, + first_buffer.num_experts, + first_buffer.num_topk, + num_tokens, + first_buffer.hidden, + first_buffer.intermediate_hidden, + activation_clamp, + fast_math, + kernel1_sms, + kernel2_sms, + reduce_sms, + kernel2_work_iters, + reduce_work_iters) + + def replay(self): + self._impl.replay() + + def get_green_context_ids(self): + return self._impl.get_green_context_ids() diff --git a/tests/test_mega_moe_split.py b/tests/test_mega_moe_split.py new file mode 100644 index 0000000000..0f0fa5a247 --- /dev/null +++ b/tests/test_mega_moe_split.py @@ -0,0 +1,257 @@ +"""Split-kernel MegaMoE test: correctness (bitwise vs the fused megakernel) + performance. + +The split pipeline runs three kernels wired into one CUDA graph via green contexts: + * dispatch_l1_swiglu (K1): gather routed tokens + Linear1 + SwiGLU + FP8-quant to the pool. + * l2_combine (K2): Linear2 + NVLink combine-scatter; runs CONCURRENTLY with K1 on a + disjoint SM partition, consuming K1's pool blocks via an arrival mask. + * combine_reduce (K3): reduce the top-k combine partials into the final output. + +It reproduces the fused `fp8_fp4_mega_moe` arithmetic exactly (same MMA order), so the output is +expected to be bitwise identical. Performance is usually slightly better than the fused kernel. +""" +import argparse +import random +import torch +import torch.distributed as dist +from typing import Callable, List, Tuple + +import deep_gemm +from deep_gemm.utils import per_token_cast_to_fp4, per_token_cast_to_fp8 +from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather + + +def _bench_replay(replay: Callable, reset: Callable, barrier: Callable, + num_warmups: int, num_tests: int, flush_mb: int) -> float: + """Best-of-N wall-clock (seconds) of `replay`, resetting state and flushing L2 each iter.""" + flush = torch.empty(max(1, flush_mb * 1024 * 1024 // 4), dtype=torch.int, device='cuda') + + def prepare(): + reset() + flush.zero_() + torch.cuda.synchronize() + barrier() + + for _ in range(num_warmups): + prepare() + replay() + torch.cuda.synchronize() + + start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + best = float('inf') + for _ in range(num_tests): + prepare() + start.record() + replay() + end.record() + torch.cuda.synchronize() + best = min(best, start.elapsed_time(end) / 1e3) + return best + + +def _capture(fn: Callable) -> torch.cuda.CUDAGraph: + """Warm up `fn` on a side stream, then capture it into a replayable CUDA graph.""" + side = torch.cuda.Stream() + side.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(side): + fn() + torch.cuda.current_stream().wait_stream(side) + torch.cuda.synchronize() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + fn() + torch.cuda.synchronize() + return graph + + +def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): + rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks) + torch.manual_seed(rank_idx) + random.seed(rank_idx) + + num_max_tokens_per_rank = args.num_max_tokens_per_rank + num_tokens = args.num_tokens + hidden, intermediate_hidden = args.hidden, args.intermediate_hidden + num_experts, num_topk = args.num_experts, args.num_topk + num_experts_per_rank = num_experts // num_ranks + assert num_tokens <= num_max_tokens_per_rank + assert num_experts % num_ranks == 0 + + prop = torch.cuda.get_device_properties(torch.cuda.current_device()) + kernel1_sms = args.kernel1_sms + kernel2_sms = args.kernel2_sms + reduce_sms = args.reduce_sms + assert kernel1_sms + kernel2_sms <= prop.multi_processor_count + + # Symmetric buffers: fused (reference) and split (route-based dispatch layout). + fused_buffer = deep_gemm.get_symm_buffer_for_mega_moe( + group, num_experts, num_max_tokens_per_rank, num_topk, hidden, intermediate_hidden) + split_buffer = deep_gemm.get_symm_buffer_for_mega_moe_split( + group, num_experts, num_max_tokens_per_rank, num_topk, hidden, intermediate_hidden) + + def cast_weights_to_fp4(bf16_weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + num_groups, n, k = bf16_weights.shape + w = torch.empty((num_groups, n, k // 2), device='cuda', dtype=torch.int8) + w_sf = torch.empty((num_groups, n, k // 32), device='cuda', dtype=torch.float) + for i in range(num_groups): + w[i], w_sf[i] = per_token_cast_to_fp4(bf16_weights[i], use_ue8m0=True, gran_k=32) + w_sf = deep_gemm.transform_sf_into_required_layout(w_sf, n, k, (1, 32), num_groups) + return w, w_sf + + # Inputs (identical for fused and split) + x_bf16 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + l1_weights = torch.randn((num_experts_per_rank, intermediate_hidden * 2, hidden), dtype=torch.bfloat16, device='cuda') + l2_weights = torch.randn((num_experts_per_rank, hidden, intermediate_hidden), dtype=torch.bfloat16, device='cuda') + scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device='cuda') + topk_weights, topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False) + assert hidden % 128 == 0 and intermediate_hidden % 128 == 0 + x = per_token_cast_to_fp8(x_bf16, use_ue8m0=True, gran_k=32, use_packed_ue8m0=True) + l1_weights = cast_weights_to_fp4(l1_weights) + l2_weights = cast_weights_to_fp4(l2_weights) + l1w, l2w = deep_gemm.transform_weights_for_mega_moe(l1_weights, l2_weights) + + # Routing stats for this rank's local experts (identical for fused and split: same topk_idx). + # Mirrors test_mega_moe.py so TFLOPS/HBM/NVLink are computed the same way. + gathered_topk_idx = uneven_all_gather(topk_idx, group=group) + gathered_topk_idx[(gathered_topk_idx < rank_idx * num_experts_per_rank) | + (gathered_topk_idx >= (rank_idx + 1) * num_experts_per_rank)] = -1 + num_recv_tokens = (gathered_topk_idx != -1).sum().item() + num_touched_experts = torch.unique(gathered_topk_idx[gathered_topk_idx >= 0]).numel() + + def fill(buffer): + buffer.x[:num_tokens].copy_(x[0]) + buffer.x_sf[:num_tokens].copy_(x[1]) + buffer.topk_idx[:num_tokens].copy_(topk_idx) + buffer.topk_weights[:num_tokens].copy_(topk_weights) + + fused_stats = torch.randint(0, 100, (num_experts_per_rank,), dtype=torch.int, device='cuda') + split_stats = fused_stats.clone() + y_fused = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + y_split = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + state = torch.zeros((16,), dtype=torch.int, device='cuda') + + def reset_fused(): + fused_buffer.buffer.zero_() + fused_stats.zero_() + fill(fused_buffer) + + def reset_split(): + split_buffer.buffer.zero_() + split_stats.zero_() + state.zero_() + fill(split_buffer) + + def run_fused(): + deep_gemm.fp8_fp4_mega_moe( + y_fused, l1w, l2w, fused_buffer, + cumulative_local_expert_recv_stats=fused_stats, + activation_clamp=args.activation_clamp, fast_math=bool(args.fast_math)) + + split_graph = deep_gemm.SM100FP8FP4MegaMoESplitGraph( + [state], [y_split], [split_buffer], [l1w], [l2w], [split_stats], + num_tokens, args.activation_clamp, bool(args.fast_math), + kernel1_sms, kernel2_sms, reduce_sms, 0, 0) + + dist_print('Config:', once_in_node=True) + dist_print(f' > Tokens: {num_tokens}/{num_max_tokens_per_rank}', once_in_node=True) + dist_print(f' > Hidden: {hidden}, Intermediate: {intermediate_hidden}', once_in_node=True) + dist_print(f' > Experts: {num_topk}/{num_experts}, GPU SMs: {prop.multi_processor_count}', once_in_node=True) + dist_print(f' > Split SMs: K1(dispatch_l1_swiglu)={kernel1_sms}, K2(l2_combine)={kernel2_sms}, ' + f'K3(combine_reduce)={reduce_sms}', once_in_node=True) + dist_print(f' > Split green context ids: {split_graph.get_green_context_ids()}', once_in_node=True) + dist_print(once_in_node=True) + + # Correctness: split output must be bitwise identical to the fused megakernel. + reset_fused() + run_fused() + torch.cuda.synchronize() + dist.barrier() + reset_split() + torch.cuda.synchronize() + dist.barrier() # all ranks must finish zeroing before any K1 NVLink dispatch reads a peer buffer + split_graph.replay() + torch.cuda.synchronize() + dist.barrier() + + diff = (y_split.float() - y_fused.float()).abs() + max_abs = diff.max().item() + max_rel = (diff / y_fused.float().abs().clamp_min(1e-6)).max().item() + is_bitwise = torch.equal(y_split, y_fused) + dist_print(f'Correctness: bitwise={is_bitwise} (max_abs={max_abs:.6g}, max_rel={max_rel:.6g})', + once_in_node=True) + assert is_bitwise, f'split output is NOT bitwise identical to fused (max_abs={max_abs:.6g})' + + # Performance: wall-clock best-of-N for the fused megakernel vs the split graph. + fused_graph = _capture(run_fused) + t_fused = _bench_replay(fused_graph.replay, reset_fused, lambda: dist.barrier(), + args.num_warmups, args.num_tests, args.flush_l2_mb) + t_split = _bench_replay(split_graph.replay, reset_split, lambda: dist.barrier(), + args.num_warmups, args.num_tests, args.flush_l2_mb) + safe_div = lambda a, b: float('nan') if b == 0 else a / b + + def perf_metrics(t): + # 3 matmuls (L1 left, L1 right, L2), each 2 * M * N * K + tflops = safe_div(2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t) + # HBM bytes: weights (FP4 = 0.5 B) + activations + output + num_hbm_bytes = ( + num_touched_experts * intermediate_hidden * 2 * hidden * 0.5 # L1 weights + + num_touched_experts * hidden * intermediate_hidden * 0.5 # L2 weights + + num_recv_tokens * hidden # L1 acts read + + num_recv_tokens * intermediate_hidden # L1 output write + + num_recv_tokens * intermediate_hidden # L2 acts read + + num_recv_tokens * hidden * 2 # L2 output write (BF16) + ) + hbm_gbs = safe_div(num_hbm_bytes / 1e9, t) + # NVLink bytes: dispatch pull + combine write-back + nvlink_gbs = safe_div(num_recv_tokens * hidden * 3 / 1e9, t) + return tflops, hbm_gbs, nvlink_gbs + + tf_f, hbm_f, nvl_f = perf_metrics(t_fused) + tf_s, hbm_s, nvl_s = perf_metrics(t_split) + dist_print(f'Routing: recv_tokens={num_recv_tokens}, touched_experts={num_touched_experts} (rank 0)', + once_in_node=True) + dist_print('Performance:', once_in_node=True) + dist_print(f' > Fused megakernel : {t_fused * 1e6:7.1f} us | {tf_f:5.0f} TFLOPS | ' + f'HBM {hbm_f:5.0f} GB/s | NVL {nvl_f:4.0f} GB/s', once_in_node=True) + dist_print(f' > Split pipeline : {t_split * 1e6:7.1f} us | {tf_s:5.0f} TFLOPS | ' + f'HBM {hbm_s:5.0f} GB/s | NVL {nvl_s:4.0f} GB/s', once_in_node=True) + dist_print(f' > Split / Fused : {safe_div(t_split, t_fused):.3f}x ' + f'({"faster" if t_split < t_fused else "slower"})', once_in_node=True) + # Machine-parseable summary line for sweep collection. + dist_print(f'SWEEP_RESULT tokens={num_tokens} hidden={hidden} inter={intermediate_hidden} ' + f'procs={num_ranks} recv={num_recv_tokens} experts={num_touched_experts} ' + f'fused_us={t_fused * 1e6:.1f} split_us={t_split * 1e6:.1f} ' + f'ratio={safe_div(t_split, t_fused):.4f} ' + f'fused_tflops={tf_f:.1f} split_tflops={tf_s:.1f} ' + f'fused_hbm={hbm_f:.1f} split_hbm={hbm_s:.1f} ' + f'fused_nvl={nvl_f:.1f} split_nvl={nvl_s:.1f}', once_in_node=True) + + dist.barrier() + fused_buffer.destroy() + split_buffer.destroy() + dist.destroy_process_group() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Split-kernel MegaMoE: bitwise vs fused + perf') + parser.add_argument('--num-processes', type=int, default=8) + parser.add_argument('--num-max-tokens-per-rank', type=int, default=8192) + parser.add_argument('--num-tokens', type=int, default=8192) + parser.add_argument('--hidden', type=int, default=7168) + parser.add_argument('--intermediate-hidden', type=int, default=3072) + parser.add_argument('--activation-clamp', type=float, default=10) + parser.add_argument('--num-experts', type=int, default=384) + parser.add_argument('--num-topk', type=int, default=6) + parser.add_argument('--fast-math', type=int, default=1) + parser.add_argument('--kernel1-sms', type=int, default=96) + parser.add_argument('--kernel2-sms', type=int, default=52) + parser.add_argument('--reduce-sms', type=int, default=148) + parser.add_argument('--num-warmups', type=int, default=3) + parser.add_argument('--num-tests', type=int, default=20) + parser.add_argument('--flush-l2-mb', type=int, default=2048) + parser.add_argument('--local-rank-idx', type=int, default=None) + args = parser.parse_args() + + if args.local_rank_idx is not None: + test(args.local_rank_idx, args.num_processes, args) + else: + torch.multiprocessing.spawn(test, args=(args.num_processes, args), nprocs=args.num_processes)