diff --git a/csrc/apis/gemm.hpp b/csrc/apis/gemm.hpp index 42622df7d8..31a6f5bc99 100644 --- a/csrc/apis/gemm.hpp +++ b/csrc/apis/gemm.hpp @@ -7,6 +7,7 @@ #include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp" #include "../jit_kernels/impls/sm90_bf16_gemm.hpp" #include "../jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp" +#include "../jit_kernels/impls/sm100_fp4_gemm_1d1d.hpp" #include "../jit_kernels/impls/sm100_bf16_gemm.hpp" #endif @@ -73,7 +74,12 @@ static void fp8_fp4_gemm_nt(const std::pair& a, const auto [n , k_] = check_ab_fp8_fp4(b.first, major_b, arch_major); const auto [m_, n_] = get_shape<2>(d); DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); - DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); + // FP4xFP4 path has no bf16 epilogue store: require fp32 D. + if (a.first.scalar_type() == kPackedFP4 and b.first.scalar_type() == kPackedFP4) { + DG_HOST_ASSERT(d.scalar_type() == torch::kFloat); + } else { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); + } // Early return for trivial cases if (early_return(m, n, k, d, c)) @@ -93,8 +99,17 @@ static void fp8_fp4_gemm_nt(const std::pair& a, sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, major_sfb, compiled_dims); } } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { - sm100_fp8_fp4_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, gran_k_a, gran_k_b, + if (a.first.scalar_type() == kPackedFP4 and b.first.scalar_type() == kPackedFP4) { + // FP4xFP4 path reinterprets int8 packed tensors as int32 for TMA; requires + // contiguous K-major layout and packed-int32-aligned K. + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(k % 8 == 0); + sm100_fp4_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); + } else { + sm100_fp8_fp4_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, gran_k_a, gran_k_b, + major_a, major_b, compiled_dims); + } } else { DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); } @@ -166,7 +181,12 @@ static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pair(d); DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0); - DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + // FP4xFP4 path has no bf16 epilogue store: require fp32 D. + if (a.first.scalar_type() == kPackedFP4 and b.first.scalar_type() == kPackedFP4) { + DG_HOST_ASSERT(d.scalar_type() == torch::kFloat); + } else { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + } DG_HOST_ASSERT(grouped_layout.scalar_type() == torch::kInt); // Layout checks @@ -197,9 +217,22 @@ static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pair 0 ? m / num_groups : 0; + sm100_m_grouped_fp4_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, grouped_layout, + num_groups, m, n, k, expected_m_per_group, + major_a, major_b, compiled_dims); + } else { + sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, grouped_layout, + num_groups, m, n, k, gran_k_a, gran_k_b, major_a, major_b, + compiled_dims, use_psum_layout, expected_m_for_psum_layout); + } } else { DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); } @@ -244,7 +277,12 @@ static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair 0 and m > 0 and n > 0 and k > 0 and num_groups > 0); - DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + // FP4xFP4 path has no bf16 epilogue store: require fp32 D. + if (a.first.scalar_type() == kPackedFP4 and b.first.scalar_type() == kPackedFP4) { + DG_HOST_ASSERT(d.scalar_type() == torch::kFloat); + } else { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + } DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt); // D must be N-major @@ -260,9 +298,17 @@ static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/common.hpp" +#include "../heuristics/sm100.hpp" + +#include "runtime_utils.hpp" + +namespace deep_gemm { + +// FP4xFP4 (MXFP4) GEMM via SM100_MMA_MXF4_SS. Distinct from the MXF8F6F4 +// path in sm100_fp8_fp4_gemm_1d1d.hpp; used when both A and B are kPackedFP4. +// GemmDesc carries logical FP4 K; this kernel template takes K and BLOCK_K in +// packed int32 units (8 FP4 / int32, 4 bytes). The wrapper converts at JIT +// instantiation and runtime launch. +class SM100FP4Gemm1D1DRuntime final: public LaunchRuntime { +public: + struct Args { + GemmDesc gemm_desc; + GemmConfig gemm_config; + LaunchArgs launch_args; + + void* grouped_layout; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_sfa; + CUtensorMap tensor_map_sfb; + CUtensorMap tensor_map_d; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_fp4_gemm_1d1d_impl< + {}, {}, + {}, {}, {}, + {}, {}, {}, + {}, + {}, {}, {}, + {}, + {}, {}, + {}, {}, + {}, + {}, + {}, {}, {} + >); +}}; +)", + to_string(args.gemm_desc.major_a), to_string(args.gemm_desc.major_b), + get_compiled_dim(args.gemm_desc.m, 'm', args.gemm_desc.compiled_dims), + get_compiled_dim(args.gemm_desc.n, 'n', args.gemm_desc.compiled_dims), + // SHAPE_K in packed int32 (= FP4_count / 8) + get_compiled_dim(args.gemm_desc.k / 8, 'k', args.gemm_desc.compiled_dims), + args.gemm_config.layout.block_m, args.gemm_config.layout.block_n, + // BLOCK_K in packed int32 (Layout.block_k is in bytes for int8-packed FP4) + args.gemm_config.layout.block_k / 4, + args.gemm_desc.num_groups, + args.gemm_config.storage_config.swizzle_a_mode, args.gemm_config.storage_config.swizzle_b_mode, args.gemm_config.storage_config.swizzle_cd_mode, + args.gemm_config.pipeline_config.num_stages, + args.gemm_config.launch_config.num_non_epilogue_threads, args.gemm_config.launch_config.num_epilogue_threads, + args.gemm_config.layout.get_cluster_size(), args.gemm_config.layout.cluster_n > 1, + args.gemm_config.launch_config.num_sms, + args.gemm_config.layout.swap_ab, + to_string(args.gemm_desc.gemm_type), args.gemm_desc.with_accumulation, + to_string(args.gemm_desc.cd_dtype)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // shape_k in packed int32 (must be a non-const local — launch_kernel takes + // &args internally and a const& cannot bind to void*). + int shape_k_int32 = args.gemm_desc.k / 8; + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.grouped_layout, args.gemm_desc.m, args.gemm_desc.n, shape_k_int32, + args.tensor_map_a, args.tensor_map_b, + args.tensor_map_sfa, args.tensor_map_sfb, + args.tensor_map_d)); + } +}; + +// Cap num_stages so SM100's 232448-byte smem capacity fits the FP4 SF smem +// footprint (sf_packed_k_per_stage = block_k_bytes / 64), which the shared +// SM100ArchSpec heuristic underestimates by 2x. +static std::pair recompute_stages_for_fp4(const GemmConfig& config, int block_m, int block_n, int k_fp4) { + constexpr int smem_capacity = 232448; + constexpr int sf_block_align = 128; + const int sf_block_m = (block_m + sf_block_align - 1) / sf_block_align * sf_block_align; + const int sf_block_n = (block_n + sf_block_align - 1) / sf_block_align * sf_block_align; + + // Per-stage smem footprint (matches the kernel's actual allocation): + // A: load_block_m * block_k_bytes, B: load_block_n * block_k_bytes + // SFA/B: sf_block_mn * sf_packed_k_per_stage * 4 + // sf_packed_k_per_stage = block_k_bytes / 64 (BLOCK_K_FP4 / VS / 4). + const int sf_packed_k_per_stage = config.layout.block_k / 64; + const int per_stage = config.storage_config.load_block_m * config.layout.block_k + + config.storage_config.load_block_n * config.layout.block_k + + sf_block_m * sf_packed_k_per_stage * 4 + + sf_block_n * sf_packed_k_per_stage * 4; + + // CD smem (matches kernel's SMEM_CD_SIZE_PER_STAGE × kNumTMAStoreStages=2); + // FP4xFP4 epilogue is fp32-only (4-byte cd elements). + int cd_size; + if (config.layout.swap_ab) { + constexpr int store_block_m_swap = 16; + cd_size = store_block_m_swap * block_n * 4 * 2; + } else { + cd_size = config.storage_config.store_block_m * config.storage_config.swizzle_cd_mode * 2; + } + const int barriers = 12 * 8 * 4 + 4 * 8 * 2 + 8; + const int tmem_ptr = 4; + const int fixed_extras = cd_size + barriers + tmem_ptr; + + const int max_stages_smem = (smem_capacity - fixed_extras) / per_stage; + const int num_k_blocks = ceil_div(k_fp4, config.layout.block_k * 2); + const int new_num_stages = std::min({12, max_stages_smem, num_k_blocks}); + const int new_smem_size = fixed_extras + new_num_stages * per_stage; + return {new_num_stages, new_smem_size}; +} + +// FP4xFP4 Layout selector: +// block_m = 128 (UMMA_M for MXF4); block_k = 128 bytes (32 packed int32 = 256 FP4). +// block_n: pick by wave count + composite score = est_stages^2 * bn, with a +// 2-epi-stage tiebreak when waves >= 2 (TMEM double-buffer needs >= 2 tiles/SM). +// cluster_m=2, cluster_n=1 (B-multicast) when M >= 512 for Normal/KGroupedContiguous. +// swap_ab: enabled for MGroupedContiguous when expected_m_per_group < BLOCK_M +// (sparse MoE); forces block_n=128 + cluster=1 per kernel asserts. +static Layout pick_fp4_layout(const GemmType& gemm_type, + const int& m, const int& n, const int& k, + const int& num_groups, const int& num_sms, + const int& expected_m_per_group = INT_MAX) { + constexpr int block_m = 128; + constexpr int block_k_bytes = 128; // 32 packed int32 = 256 FP4 elements per K block + constexpr int sf_pk = 2; // block_k_int32 / 16 = 32/16 + constexpr int sf_block_m_cols = (128 / 32) * sf_pk; // 8 + constexpr int smem_capacity = 232448; + + // BLOCK_N legality: only TMEM column budget; B-tensor OOB is handled by TMA + // store-drop, and SFB smem padding is zero-filled before warp transpose. + auto is_legal = [&](int bn) { + if (bn % 16 != 0 || bn < 16 || bn > 256) return false; + const int sf_block_n = (bn + 127) / 128 * 128; + const int sf_block_n_cols = (sf_block_n / 32) * sf_pk; + // TMEM minimum: 1 epi stage must fit + if ((1 * bn + sf_block_m_cols + sf_block_n_cols) > 512) return false; + return true; + }; + auto fits_2_epi = [&](int bn) { + const int sf_block_n = (bn + 127) / 128 * 128; + const int sf_block_n_cols = (sf_block_n / 32) * sf_pk; + return (2 * bn + sf_block_m_cols + sf_block_n_cols) <= 512; + }; + + auto num_blocks = [&](int bn) { return ceil_div(m, block_m) * ceil_div(n, bn) * num_groups; }; + auto num_waves = [&](int bn) { return ceil_div(num_blocks(bn), num_sms); }; + + int best_bn = 0, best_waves = 0, best_score = 0; + bool best_fits_2epi = false; + for (int bn = 16; bn <= 256; bn += 16) { + if (!is_legal(bn)) continue; + const int waves = num_waves(bn); + + // est_stages (conservative, no multicast): per_stage in bytes + const int sf_block_n = (bn + 127) / 128 * 128; + const int per_stage = block_m * block_k_bytes + bn * block_k_bytes + + 128 * sf_pk * 4 + sf_block_n * sf_pk * 4; + const int avail = smem_capacity - 32768 - 200; + const int est_stages = std::min(12, std::max(1, avail / per_stage)); + const int score = est_stages * est_stages * bn; + const bool cur_fits_2epi = fits_2_epi(bn); + const bool consider_epi = waves >= 2; // 2-epi benefit only when SM processes >= 2 tiles + + bool success = false; + if (best_bn == 0 || waves < best_waves) { + success = true; + } else if (waves == best_waves && bn <= n) { + if (consider_epi && cur_fits_2epi && !best_fits_2epi) success = true; + else if ((!consider_epi || cur_fits_2epi == best_fits_2epi) && score > best_score) success = true; + } + if (success) { + best_bn = bn; best_waves = waves; best_score = score; best_fits_2epi = cur_fits_2epi; + } + } + DG_HOST_ASSERT(best_bn > 0); + + // Env override (for benchmarking) + if (const auto env_bn = get_env("DG_FP4_BLOCK_N"); env_bn > 0) { + DG_HOST_ASSERT(env_bn % 16 == 0 && env_bn <= 256); + best_bn = env_bn; + } + + // Multicast: B-multicast (cluster_m=2, cluster_n=1) when M >= 512 for Normal / KGrouped. + // m-grouped types use cluster=1: multi-CTA M-distribution is incompatible + // with m_indices iteration in the kernel scheduler. + int cluster_m = 1, cluster_n = 1; + const bool can_multicast = (m >= 512) + && (ceil_div(m, block_m) % 2 == 0) + && (num_sms % 2 == 0) + && (gemm_type == GemmType::Normal || gemm_type == GemmType::KGroupedContiguous); + if (can_multicast) { + cluster_m = 2; + } + + // swap_ab for sparse m-grouped contiguous (forces block_n=128 + cluster=1) + bool swap_ab = false; + if (gemm_type == GemmType::MGroupedContiguous && expected_m_per_group < block_m) { + swap_ab = true; + best_bn = 128; + cluster_m = cluster_n = 1; + } + + return Layout{swap_ab, block_m, best_bn, block_k_bytes, cluster_m, cluster_n}; +} + +// Build a GemmDesc forcing a_dtype=b_dtype=kPackedFP4 (since this is the FP4xFP4 wrapper). +static GemmDesc make_fp4_desc(GemmType gemm_type, int m, int n, int k, int num_groups, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const at::ScalarType& cd_dtype, bool with_accumulation, + const std::string& compiled_dims, + int expected_m = 0, int expected_num_groups = 1) { + return GemmDesc { + .gemm_type = gemm_type, + .kernel_type = KernelType::Kernel1D1D, + .m = m, .n = n, .k = k, .num_groups = num_groups, + .a_dtype = kPackedFP4, .b_dtype = kPackedFP4, + .cd_dtype = cd_dtype, + .major_a = major_a, .major_b = major_b, + .with_accumulation = with_accumulation, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), + .compiled_dims = compiled_dims, + .expected_m = expected_m > 0 ? expected_m : m, + .expected_n = n, .expected_k = k, + .expected_num_groups = expected_num_groups + }; +} + +static void sm100_fp4_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + constexpr int gran_k = 32; + const auto desc = make_fp4_desc(GemmType::Normal, m, n, k, 1, + major_a, major_b, d.scalar_type(), + c.has_value(), compiled_dims); + const auto layout = pick_fp4_layout(GemmType::Normal, m, n, k, 1, + device_runtime->get_num_sms()); + auto config = GemmConfig{ + .layout = layout, + .storage_config = SM100ArchSpec::get_storage_config(desc, layout), + .pipeline_config = {}, + .launch_config = SM100ArchSpec::get_launch_config(desc, layout), + }; + config.pipeline_config = SM100ArchSpec::get_pipeline_config(desc, layout, config.storage_config); + const auto [new_stages, new_smem] = recompute_stages_for_fp4(config, layout.block_m, layout.block_n, k); + config.pipeline_config.num_stages = new_stages; + config.pipeline_config.smem_size = new_smem; + + // View FP4-packed-int8 tensors as int32 (8 FP4 per int32). Same memory + // layout, different dtype tag — the TMA descriptor then uses INT32 (not + // 16U4_ALIGN16B unpacked-smem), matching the kernel's int32-packed smem. + const auto a_int32 = a.view(torch::kInt); + const auto b_int32 = b.view(torch::kInt); + const int k_int32 = k / 8; + const int block_k_int32 = config.layout.block_k / 4; + + const auto tensor_map_a = make_tma_a_desc(major_a, a_int32, m, k_int32, + config.storage_config.load_block_m, + block_k_int32, + static_cast(a_int32.stride(get_non_contiguous_dim(major_a))), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b_int32, n, k_int32, + config.storage_config.load_block_n, + block_k_int32, + static_cast(b_int32.stride(get_non_contiguous_dim(major_b))), 1, + config.storage_config.swizzle_b_mode); + const auto tensor_map_d = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), 1, + config.storage_config.swizzle_cd_mode); + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.layout.block_m, gran_k, 1, 0); + const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.layout.block_n, gran_k, 1, 0); + + // C is merged into D by the API's early_return() pre-launch; the kernel + // has no separate C load path. + const SM100FP4Gemm1D1DRuntime::Args args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_d = tensor_map_d + }; + const auto code = SM100FP4Gemm1D1DRuntime::generate(args); + const auto runtime = compiler->build("sm100_fp4_gemm_1d1d", code); + SM100FP4Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm100_m_grouped_fp4_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& grouped_layout, + const int& num_groups, const int& m, const int& n, const int& k, + const int& expected_m_per_group, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + constexpr int gran_k = 32; + const auto desc = make_fp4_desc(GemmType::MGroupedContiguous, m, n, k, num_groups, + major_a, major_b, d.scalar_type(), + false, compiled_dims); + // swap_ab is gated on expected_m_per_group. The API passes m / num_groups + // (a host-side estimate that avoids inspecting m_indices). Accurate for + // dense layouts; over-estimates for sparse MoE with -1 padding, which + // conservatively disables swap_ab. Acceptable trade-off vs a device sync. + const auto layout = pick_fp4_layout(GemmType::MGroupedContiguous, m, n, k, num_groups, + device_runtime->get_num_sms(), expected_m_per_group); + auto config = GemmConfig{ + .layout = layout, + .storage_config = SM100ArchSpec::get_storage_config(desc, layout), + .pipeline_config = {}, + .launch_config = SM100ArchSpec::get_launch_config(desc, layout), + }; + config.pipeline_config = SM100ArchSpec::get_pipeline_config(desc, layout, config.storage_config); + const auto [new_stages_g, new_smem_g] = recompute_stages_for_fp4(config, layout.block_m, layout.block_n, k); + config.pipeline_config.num_stages = new_stages_g; + config.pipeline_config.smem_size = new_smem_g; + + const auto a_int32 = a.view(torch::kInt); + const auto b_int32 = b.view(torch::kInt); + const int k_int32 = k / 8; + const int block_k_int32 = config.layout.block_k / 4; + + const auto tensor_map_a = make_tma_a_desc(major_a, a_int32, m, k_int32, + config.storage_config.load_block_m, + block_k_int32, + static_cast(a_int32.stride(get_non_contiguous_dim(major_a))), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b_int32, n, k_int32, + config.storage_config.load_block_n, + block_k_int32, + static_cast(b_int32.stride(get_non_contiguous_dim(major_b))), num_groups, + config.storage_config.swizzle_b_mode); + const auto tensor_map_d = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), 1, + config.storage_config.swizzle_cd_mode); + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.layout.block_m, gran_k, 1, 0); + const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.layout.block_n, gran_k, num_groups, 0); + + const SM100FP4Gemm1D1DRuntime::Args args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .grouped_layout = grouped_layout.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_d = tensor_map_d + }; + const auto code = SM100FP4Gemm1D1DRuntime::generate(args); + const auto runtime = compiler->build("sm100_m_grouped_fp4_gemm_contiguous_1d1d", code); + SM100FP4Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm100_m_grouped_fp4_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& num_groups, const int& m, const int& n, const int& k, + const int& expected_m, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + constexpr int gran_k = 32; + const auto desc = make_fp4_desc(GemmType::MGroupedMasked, m, n, k, num_groups, + major_a, major_b, d.scalar_type(), + false, compiled_dims, + /*expected_m=*/expected_m, /*expected_num_groups=*/num_groups); + // m-grouped masked: pass expected_m as per-group hint; swap_ab gating + // inside pick_fp4_layout only activates for MGroupedContiguous. + const auto layout = pick_fp4_layout(GemmType::MGroupedMasked, m, n, k, num_groups, + device_runtime->get_num_sms(), expected_m); + auto config = GemmConfig{ + .layout = layout, + .storage_config = SM100ArchSpec::get_storage_config(desc, layout), + .pipeline_config = {}, + .launch_config = SM100ArchSpec::get_launch_config(desc, layout), + }; + config.pipeline_config = SM100ArchSpec::get_pipeline_config(desc, layout, config.storage_config); + const auto [new_stages_mk, new_smem_mk] = recompute_stages_for_fp4(config, layout.block_m, layout.block_n, k); + config.pipeline_config.num_stages = new_stages_mk; + config.pipeline_config.smem_size = new_smem_mk; + + const auto a_int32 = a.view(torch::kInt); + const auto b_int32 = b.view(torch::kInt); + const int k_int32 = k / 8; + const int block_k_int32 = config.layout.block_k / 4; + + const auto tensor_map_a = make_tma_a_desc(major_a, a_int32, m, k_int32, + config.storage_config.load_block_m, + block_k_int32, + static_cast(a_int32.stride(get_non_contiguous_dim(major_a))), num_groups, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b_int32, n, k_int32, + config.storage_config.load_block_n, + block_k_int32, + static_cast(b_int32.stride(get_non_contiguous_dim(major_b))), num_groups, + config.storage_config.swizzle_b_mode); + const auto tensor_map_d = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), num_groups, + config.storage_config.swizzle_cd_mode); + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.layout.block_m, gran_k, num_groups, 0); + const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.layout.block_n, gran_k, num_groups, 0); + + const SM100FP4Gemm1D1DRuntime::Args args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .grouped_layout = masked_m.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_d = tensor_map_d + }; + const auto code = SM100FP4Gemm1D1DRuntime::generate(args); + const auto runtime = compiler->build("sm100_m_grouped_fp4_gemm_masked_1d1d", code); + SM100FP4Gemm1D1DRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp4_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp4_gemm_1d1d.cuh new file mode 100644 index 0000000000..742ef4d40d --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp4_gemm_1d1d.cuh @@ -0,0 +1,735 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::mma::sm100; +using namespace deep_gemm::math; +using namespace deep_gemm::ptx; +using namespace deep_gemm::utils; + +template +__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +sm100_fp4_gemm_1d1d_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, + const __grid_constant__ cute::TmaDescriptor tensor_map_d) { + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::conditional_t; + + if constexpr (kWithAccumulation) + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + + // MMA configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t kNumMWaves = BLOCK_M / LAYOUT_AD_M; + constexpr uint32_t kNumTMAStoreStages = 2; + + // MXFP4 / SF configs + constexpr uint32_t kNumSFAStagesPerLoad = 1; + constexpr uint32_t kNumSFBStagesPerLoad = 1; + constexpr uint32_t kNumUTCCPAlignedElems = 128; + constexpr uint32_t FP4_ELEMS_PER_INT32 = 8; + constexpr uint32_t MXF4_VS = 32; + constexpr uint32_t BLOCK_K_FP4 = BLOCK_K * FP4_ELEMS_PER_INT32; + constexpr uint32_t UMMA_K_FP4 = 64; + constexpr uint32_t SF_K_PER_STAGE = BLOCK_K_FP4 / MXF4_VS; + constexpr uint32_t SF_PACKED_K_PER_STAGE = SF_K_PER_STAGE / 4; + + DG_STATIC_ASSERT(BLOCK_M % LAYOUT_AD_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); + DG_STATIC_ASSERT(BLOCK_K == 16 or BLOCK_K == 32, "FP4 BLOCK_K must be 16 or 32 int32"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + const uint32_t total_scales_k = ceil_div(shape_k * FP4_ELEMS_PER_INT32, MXF4_VS); + const uint32_t total_packed_k = ceil_div(total_scales_k, uint32_t(4)); + const uint32_t shape_sfa_k = total_packed_k; + const uint32_t shape_sfb_k = total_packed_k; + + // Utils + bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // Load/store block sizes + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); + // Swap-AB: STORE_BLOCK_M=16 (fine-grained M slices to skip padding rows), STORE_BLOCK_N=BLOCK_N. + // Non-swap: STORE_BLOCK_M=BLOCK_M, STORE_BLOCK_N derived from swizzle. + constexpr uint32_t STORE_BLOCK_M = kSwapAB ? 16u : cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwapAB ? BLOCK_N : kSwizzleCDMode / sizeof(cd_dtype_t); + + DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "FP4 only supports B-multicast (2CTA along M)"); + DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M and BLOCK_M % LAYOUT_AD_M == 0, "Only support tensor memory layout A/D"); + DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + // Swap-AB requires BLOCK_N = LAYOUT_AD_M (= 128) so UMMA_M after swap stays = 128. + DG_STATIC_ASSERT(not kSwapAB or BLOCK_N == LAYOUT_AD_M, "kSwapAB requires BLOCK_N = LAYOUT_AD_M"); + DG_STATIC_ASSERT(not kSwapAB or kNumMulticast == 1, "kSwapAB initial impl: no multicast"); + + // Shared memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = kSwapAB + ? STORE_BLOCK_M * STORE_BLOCK_N * sizeof(cd_dtype_t) + : STORE_BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_PACKED_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(uint32_t); + constexpr uint32_t SMEM_B_PACKED_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(uint32_t); + constexpr uint32_t SF_BLOCK_M = constexpr_align(BLOCK_M, kNumUTCCPAlignedElems); + constexpr uint32_t SF_BLOCK_N = constexpr_align(BLOCK_N, kNumUTCCPAlignedElems); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * SF_PACKED_K_PER_STAGE * sizeof(uint32_t); + constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * SF_PACKED_K_PER_STAGE * sizeof(uint32_t); + + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory must be aligned to 1024 bytes"); + DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); + + // Tensor memory size and offsets + constexpr uint32_t kNumSFATmemCols = (SF_BLOCK_M / 32) * SF_PACKED_K_PER_STAGE; + constexpr uint32_t kNumSFBTmemCols = (SF_BLOCK_N / 32) * SF_PACKED_K_PER_STAGE; + constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N + kNumSFATmemCols + kNumSFBTmemCols) > 512 ? 1 : 2; + constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N; + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; + + // Prefetch TMA descriptors at the very beginning + if (threadIdx.x == 0) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_sfb); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + + // D/A/B shared memory + cd_dtype_t* smem_cd[kNumTMAStoreStages]; + uint32_t* smem_sfa[kNumStages]; + uint32_t* smem_sfb[kNumStages]; + uint32_t* smem_a_packed[kNumStages]; + uint32_t* smem_b_packed[kNumStages]; + + #pragma unroll + for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i) + smem_cd[i] = reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); + + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + smem_a_packed[i] = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_PACKED_SIZE_PER_STAGE); + smem_b_packed[i] = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_PACKED_SIZE_PER_STAGE + i * SMEM_B_PACKED_SIZE_PER_STAGE); + } + + // SFA/SFB shared memory + auto sf_start_ptr = smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_PACKED_SIZE_PER_STAGE + SMEM_B_PACKED_SIZE_PER_STAGE); + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + smem_sfa[i] = reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + smem_sfb[i] = reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE); + } + + // Barriers and tensor memory pointer + auto barrier_start_ptr = reinterpret_cast(smem_buffer + + SMEM_CD_SIZE + + kNumStages * (SMEM_A_PACKED_SIZE_PER_STAGE + SMEM_B_PACKED_SIZE_PER_STAGE) + + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto with_sf_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); + auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); }); + + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + + if (threadIdx.x == 0) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(1); + with_sf_full_barriers[i]->init(kNumMulticast * 32); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + tmem_full_barriers[i]->init(1); + tmem_empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads); + } + cutlass::arch::fence_view_async_shared(); + cutlass::arch::fence_barrier_init(); + } else if (threadIdx.x >= 32 and threadIdx.x < 64) { + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = sched::Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // K-loop driver + struct DivisibleK {}; + struct NotDivisibleK {}; + uint32_t phase = 0; + + auto launch_k_iterations = [&](const auto& func) { + const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k); + const uint32_t num_iterations = ceil_div(current_shape_k, kNumStages * BLOCK_K); + const uint32_t num_last_stages = ceil_div(current_shape_k, BLOCK_K) % kNumStages; + + if (num_last_stages == 0) { + for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter, phase ^= 1) + func(k_iter, DivisibleK{}, k_iter == num_iterations - 1, num_last_stages); + } else { + for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter, phase ^= 1) + func(k_iter, DivisibleK{}, false, num_last_stages); + func(num_iterations - 1, NotDivisibleK{}, true, num_last_stages), phase ^= 1; + } + }; + + auto dispatch_accum_stage_idx = [&](uint32_t accum_stage_idx, const auto& func) { + DG_STATIC_ASSERT(1 <= kNumEpilogueStages and kNumEpilogueStages <= 2, "Too many epilogue stages"); + accum_stage_idx == 0 ? func(0) : func(1); + }; + + // Dispatch warps into different roles: + // warp 0 : TMA load producer + // warp 1 : MMA consumer + UTCCP SF copy to TMEM + // warp 2 : SF SMEM warp transpose for UTCCP + // warp 3+ : Epilogue + if (warp_idx == 0) { + // TMA load warp + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) { + constexpr bool kHasDivisibleStages = cute::is_same_v; + const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages; + + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + empty_barriers[s]->wait(phase ^ 1); + + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx); + + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_block_idx = k_iter * kNumStages + s; + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K>(shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K>(shape_k, BLOCK_K, k_block_idx, m_block_idx); + + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; + n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); + } + + if (cute::elect_one_sync()) { + if constexpr (kMajorA == cute::UMMA::Major::K) + tma_copy(&tensor_map_a, full_barriers[s], smem_a_packed[s], k_a_idx, m_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma_copy(&tensor_map_a, full_barriers[s], smem_a_packed[s], m_idx, k_a_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma_copy(&tensor_map_b, full_barriers[s], smem_b_packed[s], k_b_idx, n_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma_copy(&tensor_map_b, full_barriers[s], smem_b_packed[s], n_idx, k_b_idx); + } + auto num_arrival_bytes = SMEM_A_PACKED_SIZE_PER_STAGE + SMEM_B_PACKED_SIZE_PER_STAGE; + + const uint32_t sfa_tma_stage = (k_iter * kNumStages + s) % kNumSFAStagesPerLoad; + if (sfa_tma_stage == 0 and cute::elect_one_sync()) { + uint32_t sf_k_base = k_block_idx / kNumSFAStagesPerLoad * SF_PACKED_K_PER_STAGE; + #pragma unroll + for (uint32_t pk = 0; pk < SF_PACKED_K_PER_STAGE; ++ pk) { + tma_copy(&tensor_map_sfa, full_barriers[s], smem_sfa[s] + pk * SF_BLOCK_M, m_block_idx * BLOCK_M, + scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), sched::IndexType::SF_K>(shape_sfa_k, 1, sf_k_base + pk)); + } + num_arrival_bytes += BLOCK_M * SF_PACKED_K_PER_STAGE * sizeof(uint32_t); + } + const uint32_t sfb_tma_stage = (k_iter * kNumStages + s) % kNumSFBStagesPerLoad; + if (sfb_tma_stage == 0 and cute::elect_one_sync()) { + uint32_t sf_k_base = k_block_idx / kNumSFBStagesPerLoad * SF_PACKED_K_PER_STAGE; + #pragma unroll + for (uint32_t pk = 0; pk < SF_PACKED_K_PER_STAGE; ++ pk) { + tma_copy(&tensor_map_sfb, full_barriers[s], smem_sfb[s] + pk * SF_BLOCK_N, n_block_idx * BLOCK_N, + scheduler.template get_global_idx(shape_sfb_k, 1, sf_k_base + pk, m_block_idx)); + } + num_arrival_bytes += BLOCK_N * SF_PACKED_K_PER_STAGE * sizeof(uint32_t); + } + + if (cute::elect_one_sync()) + full_barriers[s]->arrive_and_expect_tx(num_arrival_bytes); + } + + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + empty_barriers[s]->wait(phase ^ 1); + if (cute::elect_one_sync()) + full_barriers[s]->arrive(); + } + }); + } + } else if (warp_idx == 1 and is_leader_cta) { + // MMA + UTCCP SF copy warp + constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast); + // Swap-AB: UMMA_N becomes BLOCK_M (the original M dim now plays N role in MMA). + constexpr uint32_t UMMA_N = kSwapAB ? BLOCK_M : BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1); + constexpr uint32_t UMMA_K_INT32 = UMMA_K_FP4 / FP4_ELEMS_PER_INT32; + constexpr uint32_t NUM_K_ITERS_PER_STAGE = BLOCK_K / UMMA_K_INT32; + // After swap, the "N iters" walk over BLOCK_M (now the MMA-N axis); without swap, over BLOCK_N. + constexpr uint32_t NUM_N_ITERS = (kSwapAB ? BLOCK_M : BLOCK_N) / UMMA_N; + + // Swap-AB: pass (b_dtype, a_dtype, ..., kMajorB, kMajorA) to MMA so B occupies MMA-A slot + // and A occupies MMA-B slot. Output in TMEM is D^T (N rows, M cols). + auto instr_desc_mxf4 = kSwapAB + ? cute::UMMA::make_instr_desc_block_scaled< + cutlass::float_e2m1_t, cutlass::float_e2m1_t, float, cutlass::float_ue8m0_t, + UMMA_M, UMMA_N, kMajorB, kMajorA>() + : cute::UMMA::make_instr_desc_block_scaled< + cutlass::float_e2m1_t, cutlass::float_e2m1_t, float, cutlass::float_ue8m0_t, + UMMA_M, UMMA_N, kMajorA, kMajorB>(); + + using cute_mma_mxf4_noswap_t = cute::conditional_t, + cute::SM100_MMA_MXF4_2x1SM_SS>; + using cute_mma_mxf4_swap_t = cute::conditional_t, + cute::SM100_MMA_MXF4_2x1SM_SS>; + using cute_mma_mxf4_t = cute::conditional_t; + + DG_STATIC_ASSERT(UMMA_M == 128 or UMMA_M == 256, "MXF4 supports M=128 (1CTA) or M=256 (2CTA)"); + DG_STATIC_ASSERT((UMMA_N % 8 == 0) and (8 <= UMMA_N) and (UMMA_N <= 256), "Invalid MXF4 N-mode size"); + + using cute_utccp_t = cute::conditional_t; + auto sf_desc = make_sf_desc(nullptr); + + constexpr uint32_t SMEM_A_SIZE_PER_STAGE_PACKED = LOAD_BLOCK_M * BLOCK_K * sizeof(uint32_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE_PACKED = LOAD_BLOCK_N * BLOCK_K * sizeof(uint32_t); + auto a_desc_base = make_umma_desc(smem_a_packed[0], 0, 0); + auto b_desc_base = make_umma_desc(smem_b_packed[0], 0, 0); + + // MXF4 SF addressing constants + constexpr uint32_t kGroupsPerUmmaStep = UMMA_K_FP4 / MXF4_VS; // 2 + constexpr uint32_t kGroupsPerPacked = 4; + constexpr uint32_t kUmmaStepsPerPacked = kGroupsPerPacked / kGroupsPerUmmaStep; // 2 + constexpr uint32_t kSfaColsPerPackedGroup = SF_BLOCK_M / 32; + constexpr uint32_t kSfbColsPerPackedGroup = SF_BLOCK_N / 32; + + // Pre-compute runtime descriptors (sf_id alternates 0, 2, 0, 2) + const auto runtime_desc_sf0 = make_runtime_instr_desc_with_sf_id(instr_desc_mxf4, 0, 0); + const auto runtime_desc_sf2 = make_runtime_instr_desc_with_sf_id(instr_desc_mxf4, kGroupsPerUmmaStep, kGroupsPerUmmaStep); + + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) { + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + tcgen05_after_thread_sync(); + + auto empty_barrier_arrive = [&](uint32_t s, bool do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + umma_arrive(reinterpret_cast(empty_barriers[s])); + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + }; + + launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) { + constexpr bool kHasDivisibleStages = cute::is_same_v; + const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages; + + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + with_sf_full_barriers[s]->wait(phase); + tcgen05_after_thread_sync(); + + // UTCCP: copy SF from SMEM → TMEM + // Must stay on warp 1: SF TMEM cols are reused across stages, + // so UTCCP must be serialized with MMA on the same warp. + const uint32_t sfa_copy_stage = (k_iter * kNumStages + s) % kNumSFAStagesPerLoad; + if (sfa_copy_stage == 0 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t pk = 0; pk < SF_PACKED_K_PER_STAGE; ++ pk) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { + replace_smem_desc_addr(sf_desc, smem_sfa[s] + pk * SF_BLOCK_M + i * kNumUTCCPAlignedElems); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + pk * (SF_BLOCK_M / 32) + i * 4); + } + } + #pragma unroll + for (uint32_t pk = 0; pk < SF_PACKED_K_PER_STAGE; ++ pk) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { + replace_smem_desc_addr(sf_desc, smem_sfb[s] + pk * SF_BLOCK_N + i * kNumUTCCPAlignedElems); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + pk * (SF_BLOCK_N / 32) + i * 4); + } + } + } + __syncwarp(); + + // MMA loop + uint32_t a_desc_stage_lo = a_desc_base.lo + s * (SMEM_A_SIZE_PER_STAGE_PACKED / 16); + uint32_t b_desc_stage_lo = b_desc_base.lo + s * (SMEM_B_SIZE_PER_STAGE_PACKED / 16); + + #pragma unroll + for (uint32_t k = 0; k < NUM_K_ITERS_PER_STAGE; ++k) { + uint32_t packed_group = k / kUmmaStepsPerPacked; + uint32_t tmem_sfa_base = kTmemStartColOfSFA + packed_group * kSfaColsPerPackedGroup; + uint32_t tmem_sfb_k = kTmemStartColOfSFB + packed_group * kSfbColsPerPackedGroup; + const auto& runtime_desc_k = (k % kUmmaStepsPerPacked == 0) ? runtime_desc_sf0 : runtime_desc_sf2; + + #pragma unroll + for (uint32_t n = 0; n < NUM_N_ITERS; ++n) { + auto b_desc = b_desc_base; + b_desc.lo = advance_umma_desc_lo( + b_desc_stage_lo, n * UMMA_N * BLOCK_K, k * UMMA_K_INT32); + + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++w) { + auto a_desc = a_desc_base; + a_desc.lo = advance_umma_desc_lo( + a_desc_stage_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K_INT32); + + uint32_t tmem_col = accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N + n * UMMA_N; + if constexpr (kSwapAB) { + // Swap-AB: B goes into MMA-A slot, A into MMA-B slot. + // SF column args also swap (SFB first, SFA second). + cute_mma_mxf4_t::fma(b_desc, a_desc, tmem_col, + k_iter > 0 or s > 0 or k > 0, + runtime_desc_k, + tmem_sfb_k, + tmem_sfa_base + w * (kNumUTCCPAlignedElems / 32)); + } else { + cute_mma_mxf4_t::fma(a_desc, b_desc, tmem_col, + k_iter > 0 or s > 0 or k > 0, + runtime_desc_k, + tmem_sfa_base + w * (kNumUTCCPAlignedElems / 32), + tmem_sfb_k); + } + } + } + } + + empty_barrier_arrive(s, is_last_iter and s == kNumInnerStages - 1); + } + + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + with_sf_full_barriers[s]->wait(phase); + empty_barrier_arrive(s, false); + } + }); + }); + } + } else if (warp_idx == 2) { + // SF transpose warp + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements"); + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + values[i] = ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + auto fill_sfb_missing_k_groups = [&](uint32_t* smem_ptr) { + if constexpr (BLOCK_N < kNumUTCCPAlignedElems) { + // Zero-fill [BLOCK_N, SF_BLOCK_N) before warp-transpose: XOR pattern + // reads all 128 elements, so uninitialized positions corrupt valid data. + #pragma unroll + for (uint32_t pos = lane_idx; pos < kNumUTCCPAlignedElems; pos += 32) { + if (pos >= BLOCK_N) + st_shared(smem_ptr + pos, 0u); + } + __syncwarp(); + } + }; + + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) { + constexpr bool kHasDivisibleStages = cute::is_same_v; + const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages; + + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + full_barriers[s]->wait(phase); + + const uint32_t sfa_ut_stage = (k_iter * kNumStages + s) % kNumSFAStagesPerLoad; + if (sfa_ut_stage == 0) { + #pragma unroll + for (uint32_t pk = 0; pk < SF_PACKED_K_PER_STAGE; ++ pk) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfa[s] + pk * SF_BLOCK_M + i * kNumUTCCPAlignedElems); + } + cutlass::arch::fence_view_async_shared(); + } + + const uint32_t sfb_ut_stage = (k_iter * kNumStages + s) % kNumSFBStagesPerLoad; + if (sfb_ut_stage == 0) { + #pragma unroll + for (uint32_t pk = 0; pk < SF_PACKED_K_PER_STAGE; ++ pk) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { + fill_sfb_missing_k_groups(smem_sfb[s] + pk * SF_BLOCK_N + i * kNumUTCCPAlignedElems); + utccp_required_smem_warp_transpose(smem_sfb[s] + pk * SF_BLOCK_N + i * kNumUTCCPAlignedElems); + } + } + cutlass::arch::fence_view_async_shared(); + } + + with_sf_full_barriers[s]->arrive(0u); + } + + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait(phase); + with_sf_full_barriers[s]->arrive(0u); + } + }); + } + } else if (warp_idx >= kNumNonEpilogueThreads / 32) { + // Epilogue warps + const auto epilogue_thread_idx = threadIdx.x - kNumNonEpilogueThreads; + const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); + + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); + + if constexpr (kSwapAB) { + // Swap-AB epilogue: TMEM holds D^T (BLOCK_N rows x BLOCK_M cols); each + // accum stage covers BLOCK_M TMEM cols. STORE_BLOCK_M=16 lets the loop + // skip 16-row M slices that are entirely padding (effective_m / 16), + // while STORE_BLOCK_N=BLOCK_N stores the full N tile per iteration. + constexpr uint32_t kNumSwizzleAtomRows = 8; + constexpr uint32_t STORE_BLOCK_N_ATOM = kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumWarpsPerAtom = STORE_BLOCK_N_ATOM / 32; + DG_STATIC_ASSERT(STORE_BLOCK_M % kNumSwizzleAtomRows == 0, "Invalid swap-AB store_block_m"); + DG_STATIC_ASSERT(STORE_BLOCK_N % STORE_BLOCK_N_ATOM == 0, "Invalid swap-AB store_block_n"); + DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Swap-AB requires full warpgroup"); + + uint32_t tma_stage_idx = 0; + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) { + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + + tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + tcgen05_after_thread_sync(); + + // Effective M (aligned up to STORE_BLOCK_M=16): how many M-cols of D^T are valid. + const auto effective_m = scheduler.get_aligned_effective_m_in_block(m_block_idx); + const uint32_t num_stores = effective_m / STORE_BLOCK_M; + + // TMEM col where this accum stage's tile starts. + const auto tmem_base_addr = accum_stage_idx * BLOCK_M; + const auto base_m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx); + const auto base_n_idx = n_block_idx * BLOCK_N; + + #pragma unroll 1 + for (uint32_t s = 0; s < num_stores; ++ s) { + // Wait if TMA store pipeline full + if (s >= kNumTMAStoreStages) { + if (epilogue_thread_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); + } + + // SMEM store: 4 warps cooperatively write STORE_BLOCK_M × STORE_BLOCK_N tile. + // Each warp owns STORE_BLOCK_M rows × STORE_BLOCK_N_ATOM cols. + // Within each warp, loop covers STORE_BLOCK_M / kNumSwizzleAtomRows = 2 sub-blocks of 8 rows. + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_M / kNumSwizzleAtomRows; ++ i) { + uint32_t tmem_addr = tmem_base_addr + + s * STORE_BLOCK_M + // M-slice (cols of TMEM) + i * kNumSwizzleAtomRows; // Sub-block within slice + + uint32_t outer_atom_offset = (epilogue_warp_idx / kNumWarpsPerAtom) * STORE_BLOCK_M * kSwizzleCDMode; + uint32_t inner_atom_offset = i * kNumSwizzleAtomRows * kSwizzleCDMode; + auto smem_base_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + + outer_atom_offset + inner_atom_offset; + + if constexpr (cute::is_same_v) { + uint32_t values[kNumSwizzleAtomRows]; + // Load 32dp × 8 cols of TMEM per warp (= 8 M-cols × 32 N-rows of D). + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, + values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + uint32_t col = lane_idx / 4; + #pragma unroll + for (uint32_t row = 0; row < kNumSwizzleAtomRows; ++ row) { + auto smem_ptr = smem_base_ptr + row * (kNumBankGroupBytes * 8) + + (col ^ row) * kNumBankGroupBytes + + (lane_idx % 4) * sizeof(float); + st_shared(reinterpret_cast(smem_ptr), values[row]); + } + } + } + + // Notify TMEM empty on last store + if (s == num_stores - 1) { + tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + __syncwarp(); + + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); + + if (epilogue_thread_idx == 0) { + uint32_t m_idx = base_m_idx + s * STORE_BLOCK_M; + #pragma unroll + for (uint32_t ai = 0; ai < STORE_BLOCK_N / STORE_BLOCK_N_ATOM; ++ ai) { + auto smem_ptr = smem_cd[tma_stage_idx] + ai * STORE_BLOCK_M * STORE_BLOCK_N_ATOM; + uint32_t n_idx = base_n_idx + ai * STORE_BLOCK_N_ATOM; + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_d, smem_ptr, n_idx, m_idx); + } + cute::tma_store_arrive(); + } + + tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; + } + + // If entire tile is padding (effective_m=0, hence num_stores=0): still arrive at empty barrier + // so MMA pipeline can advance. + if (num_stores == 0) { + tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + }); + } + } else { + + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) { + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + + if (epilogue_thread_idx == 0) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); + + tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + tcgen05_after_thread_sync(); + + DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s) { + const uint32_t iter_idx = w * kNumStores + s; + if (iter_idx >= kNumTMAStoreStages) { + if (epilogue_thread_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); + } + + const auto tma_stage_idx = iter_idx % kNumTMAStoreStages; + const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M; + const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; + + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N + s * STORE_BLOCK_N + + i * kNumElemsPerBankGroup; + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + + epilogue_warp_idx * 32 * kSwizzleCDMode + + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; + + uint32_t values[kNumElemsPerBankGroup]; + if constexpr (cute::is_same_v) { + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } + } + + if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { + tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + __syncwarp(); + + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); + if (epilogue_thread_idx == 0) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx); + cute::tma_store_arrive(); + } + } + } + }); + } + } // end of if constexpr (kSwapAB) else (non-swap) + + if (epilogue_thread_idx == 0) + cute::tma_store_wait<0>(); + + if (epilogue_warp_idx == 1) + Allocator().free(0, kNumTmemCols); + } + + if constexpr (kNumMulticast > 1) + cute::cluster_sync(); +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/tests/test_fp4.py b/tests/test_fp4.py new file mode 100644 index 0000000000..cb3d136a4a --- /dev/null +++ b/tests/test_fp4.py @@ -0,0 +1,134 @@ +import random +import torch + +import deep_gemm +from deep_gemm.testing import bench_kineto, calc_diff, count_bytes + +from generators import ( + KernelType, MajorTypeAB, QuantConfig, + generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, +) + + +# FP4xFP4 (MXFP4): both operands packed FP4 (E2M1) with VS=32 UE8M0 scales. +# Dispatches to the SM100_MMA_MXF4_SS path; the API rejects bf16 D so we +# always allocate fp32 D below. +FP4_FP4 = QuantConfig((32, 32, True, True)) + + +def test_gemm() -> None: + print('Testing GEMM:') + nk_list = [(2112, 7168), (576, 7168), (24576, 1536), (32768, 512), + (7168, 16384), (4096, 7168), (7168, 2048)] + m_list = [128, 4096] + kernel_type = KernelType.Kernel1D1D + out_dtype = torch.float + recipe, recipe_a, recipe_b = FP4_FP4.get_recipes() + + for m in m_list: + for n, k in nk_list: + a, b, c, d, ref_d = generate_normal( + m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, + accumulate=False, out_dtype=out_dtype, kernel_type=kernel_type, + use_ue8m0=True, quant_config=FP4_FP4) + deep_gemm.fp8_fp4_gemm_nt( + a, b, d, c=c, disable_ue8m0_cast=False, + recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) + diff = calc_diff(d, ref_d) + assert diff < FP4_FP4.max_diff(), \ + f'{m=}, {n=}, {k=}, {diff:.5f}' + + t = bench_kineto( + lambda: deep_gemm.fp8_fp4_gemm_nt( + a, b, d, c=c, disable_ue8m0_cast=False, + recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b), + 'sm100_fp4_gemm', suppress_kineto_output=True) + print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, 1D1D, layout=NT, FP32): ' + f'{t * 1e6:6.1f} us | {2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s') + + +def test_m_grouped_gemm_contiguous() -> None: + print('Testing m-grouped contiguous GEMM:') + m_group_list = [(4, 8192), (8, 4096)] + n_k_list = [(4096, 7168), (7168, 2048), (24576, 1536), (32768, 512)] + out_dtype = torch.float + recipe, recipe_a, recipe_b = FP4_FP4.get_recipes() + + for num_groups, expected_m_per_group in m_group_list: + for n, k in n_k_list: + m, a, b, m_indices, d_bf16, _ref_bf16 = generate_m_grouped_contiguous( + num_groups, expected_m_per_group, n, k, + MajorTypeAB.KMajor, MajorTypeAB.KMajor, + use_ue8m0=True, quant_config=FP4_FP4) + # generate_m_grouped_contiguous returns bf16 d/ref; allocate fp32 d for + # the FP4xFP4 API and cast the reference for calc_diff. + d = torch.empty_like(d_bf16, dtype=out_dtype) + ref_d = _ref_bf16.to(out_dtype) + deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous( + a, b, d, m_indices, + recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b, + disable_ue8m0_cast=False, use_psum_layout=False) + diff = calc_diff(d, ref_d) + assert diff < FP4_FP4.max_diff(), \ + f'{num_groups=}, {m=}, {n=}, {k=}, {diff:.5f}' + + t = bench_kineto( + lambda: deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous( + a, b, d, m_indices, + recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b, + disable_ue8m0_cast=False, use_psum_layout=False), + 'sm100_fp4_gemm', suppress_kineto_output=True) + print(f' > Perf (num_groups={num_groups:2}, expected_m_per_group={expected_m_per_group:5}, ' + f'n={n:5}, k={k:5}, 1D1D, FP32): ' + f'{t * 1e6:6.1f} us | {2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s') + + +def test_m_grouped_gemm_masked() -> None: + print('Testing m-grouped masked GEMM:') + max_m = 4096 + m_group_list = [(1, 1024), (2, 512), (4, 256)] + n_k_list = [(4096, 7168), (7168, 2048)] + out_dtype = torch.float + recipe, recipe_a, recipe_b = FP4_FP4.get_recipes() + + for num_groups, expected_m_per_group in m_group_list: + for n, k in n_k_list: + a, b, masked_m, _psum_m, d_bf16, _ref_bf16 = generate_m_grouped_masked( + num_groups, max_m, expected_m_per_group, n, k, + use_ue8m0=True, quant_config=FP4_FP4) + d = torch.empty_like(d_bf16, dtype=out_dtype) + ref_d = _ref_bf16.to(out_dtype) + deep_gemm.m_grouped_fp8_fp4_gemm_nt_masked( + a, b, d, masked_m, expected_m_per_group, + recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b, + disable_ue8m0_cast=False) + # Diff over the valid (non-padding) rows per group only. + for g in range(num_groups): + mm = int(masked_m[g].item()) + diff = calc_diff(d[g, :mm], ref_d[g, :mm]) + assert diff < FP4_FP4.max_diff(), \ + f'{num_groups=}, group={g}, masked_m={mm}, {n=}, {k=}, {diff:.5f}' + + t = bench_kineto( + lambda: deep_gemm.m_grouped_fp8_fp4_gemm_nt_masked( + a, b, d, masked_m, expected_m_per_group, + recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b, + disable_ue8m0_cast=False), + 'sm100_fp4_gemm', suppress_kineto_output=True) + print(f' > Perf (num_groups={num_groups}, max_m={max_m}, expected_m_per_group={expected_m_per_group:5}, ' + f'n={n:5}, k={k:5}, 1D1D, FP32): ' + f'{t * 1e6:6.1f} us | {2 * num_groups * expected_m_per_group * n * k / t / 1e12:4.0f} TFLOPS') + + +if __name__ == '__main__': + torch.manual_seed(0) + random.seed(0) + + print('Library path:') + print(f' > {deep_gemm.__path__}\n') + + test_gemm() + test_m_grouped_gemm_contiguous() + test_m_grouped_gemm_masked()