Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ out_ij = out_ij.sum() # Scalar

For more details and the paged version `fp8_paged_mqa_logits`, please refer to `tests/test_attention.py`.

#### W4AFP8
- W4AFP8 (INT4 weights, FP8 activations) GEMM kernels for Hopper (SM90). Supports Normal GEMM, M-Grouped Contiguous GEMM, and M-Grouped Masked GEMM.

- The algorithm is compatible with https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8, but it uses a custom weight layout (see `convert_fp8_to_int4` in `tests/generators.py`).

#### Mega MoE

Mega MoE fuses and overlaps EP dispatch, linear 1 (FP8xFP4), SwiGLU, linear 2 (FP8xFP4), and EP combine into a single mega-kernel, overlapping NVLink communication and tensor core computation. It requires multi-process launch with symmetric memory. Usage:
Expand Down
120 changes: 120 additions & 0 deletions csrc/apis/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,113 @@ static void k_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torc
DG_HOST_UNREACHABLE("Unsupported architecture");
}
}

// ====================
// SM90 W4AFP8 APIs
// ====================

// Dense W4AFP8 GEMM entry point (NT layout) for SM90.
//
// Validates the operands, converts the weight scale factors into the layout
// the kernel expects, and forwards everything to `sm90_fp8_gemm_1d2d`.
//
//   a             (tensor, scale factors); `a.first` must be a K-major [m, k]
//                 tensor of dtype Float8_e4m3fn.
//   b             (tensor, scale factors); `b.first` must be a K-major
//                 [n, k / 2] tensor.
//                 NOTE(review): presumably two INT4 weights are packed into each
//                 stored element, hence the halved K extent — confirm against
//                 the weight conversion code.
//   d             [m, n] output, BF16 or FP32.
//   c             optional tensor handed to `early_return` and to the launcher.
//   compiled_dims forwarded to the launcher unchanged.
static void sm90_w4afp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
                                const std::pair<torch::Tensor, torch::Tensor>& b,
                                const torch::Tensor& d,
                                const std::optional<torch::Tensor>& c,
                                const std::string& compiled_dims) {
    // Both operands have to be K-major; the output layout is checked by the helper
    const auto major_a = get_major_type_ab(a.first);
    const auto major_b = get_major_type_ab(b.first);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
    DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
    check_major_type_cd(d);

    // Shapes: the stored K extent of the weight is half the logical K
    const auto [m , k ] = get_shape<2>(a.first);
    const auto [n , k_] = get_shape<2>(b.first);
    const auto [m_, n_] = get_shape<2>(d);
    DG_HOST_ASSERT(m == m_ and n == n_ and k == k_ * 2);

    // Data types
    DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);

    // Nothing to launch if the shared helper reports the problem as already handled
    if (early_return(m, n, k, d, c))
        return;

    // Only the weight scales go through the layout transform;
    // the activation scales (`a.second`) are passed to the launcher untouched
    const auto sf_recipe = std::make_tuple(m, 1, 128);
    const auto weight_sf = layout::transform_sf_into_required_layout(
        b.second, n, k, sf_recipe, std::nullopt, false, false);
    const auto weight_sf_major = cute::UMMA::Major::K;

    sm90_fp8_gemm_1d2d(a.first, a.second, b.first, weight_sf, c, d, m, n, k,
                       major_a, major_b, weight_sf_major, compiled_dims);
}

// M-grouped contiguous W4AFP8 GEMM entry point (NT layout) for SM90.
//
// Validates the operands, converts the weight scale factors into the layout
// the kernel expects, and forwards everything to
// `sm90_m_grouped_fp8_gemm_contiguous_1d2d`.
//
//   a             (tensor, scale factors); `a.first` must be a K-major [m, k]
//                 tensor of dtype Float8_e4m3fn.
//   b             (tensor, scale factors); `b.first` must be a K-major
//                 [num_groups, n, k / 2] tensor.
//                 NOTE(review): presumably two INT4 weights are packed into each
//                 stored element, hence the halved K extent — confirm against
//                 the weight conversion code.
//   d             [m, n] BF16 output.
//   m_indices     contiguous int32 tensor with `m` entries.
//                 NOTE(review): presumably the group index of every row — confirm
//                 against the launcher.
//   compiled_dims forwarded to the launcher unchanged.
//
// Returns without launching anything when `m == 0`.
static void sm90_m_grouped_w4afp8_gemm_nt_contiguous(
    const std::pair<torch::Tensor, torch::Tensor>& a,
    const std::pair<torch::Tensor, torch::Tensor>& b,
    const torch::Tensor& d,
    const torch::Tensor& m_indices,
    const std::string& compiled_dims) {
    // Layout checks
    const auto major_a = get_major_type_ab(a.first);
    const auto major_b = get_major_type_ab(b.first);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
    DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
    DG_HOST_ASSERT(m_indices.is_contiguous());
    check_major_type_cd(d);

    // Shape checks: the stored K extent of the weight is half the logical K
    const auto [m , k ] = get_shape<2>(a.first);
    const auto [num_groups, n, k_] = get_shape<3>(b.first);
    const auto [m_, n_] = get_shape<2>(d);
    DG_HOST_ASSERT(m == m_ and n == n_ and k == k_ * 2);
    DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0);

    // Type checks
    // The activation must be FP8 (e4m3), exactly as the dense W4AFP8 entry point requires;
    // without this check a tensor of another dtype would be handed to the kernel unchecked
    DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt);

    // One index per row of `a`
    const auto [m__] = get_shape<1>(m_indices);
    DG_HOST_ASSERT(m == m__);

    // Empty problem: nothing to compute
    if (m == 0)
        return;

    // Only the weight scales go through the layout transform;
    // the activation scales (`a.second`) are passed to the launcher untouched
    const auto recipe = std::make_tuple(m, 1, 128);
    const auto sfa = a.second;
    const auto sfb = layout::transform_sf_into_required_layout(
        b.second, n, k, recipe, num_groups, false, false);
    const auto major_sfb = cute::UMMA::Major::K;

    // NOTE(review): the meaning of the trailing `false, std::nullopt` arguments is
    // defined by the launcher's signature, which is not visible here
    sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices,
                                            num_groups, m, n, k, major_a, major_b, major_sfb,
                                            compiled_dims, false, std::nullopt);
}

// M-grouped masked W4AFP8 GEMM entry point (NT layout) for SM90.
//
// Validates the operands, converts the weight scale factors into the layout
// the kernel expects, and forwards everything to
// `sm90_m_grouped_fp8_gemm_masked_1d2d`.
//
//   a             (tensor, scale factors); `a.first` must be a K-major
//                 [num_groups, m, k] tensor of dtype Float8_e4m3fn.
//   b             (tensor, scale factors); `b.first` must be a K-major
//                 [num_groups, n, k / 2] tensor.
//                 NOTE(review): presumably two INT4 weights are packed into each
//                 stored element, hence the halved K extent — confirm against
//                 the weight conversion code.
//   d             [num_groups, m, n] BF16 output.
//   masked_m      contiguous int32 tensor with `num_groups` elements.
//                 NOTE(review): presumably the number of valid rows per group —
//                 confirm against the launcher.
//   expected_m    must be positive; forwarded to the launcher.
//   compiled_dims forwarded to the launcher unchanged.
static void sm90_m_grouped_w4afp8_gemm_nt_masked(
    const std::pair<torch::Tensor, torch::Tensor>& a,
    const std::pair<torch::Tensor, torch::Tensor>& b,
    const torch::Tensor& d,
    const torch::Tensor& masked_m,
    const int& expected_m,
    const std::string& compiled_dims) {
    // Layout checks
    const auto major_a = get_major_type_ab(a.first);
    const auto major_b = get_major_type_ab(b.first);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
    DG_HOST_ASSERT(masked_m.is_contiguous());
    check_major_type_cd(d);

    // Shape checks: every operand must agree on the group count, and the stored
    // K extent of the weight is half the logical K
    const auto [num_groups  , m , k ] = get_shape<3>(a.first);
    const auto [num_groups_ , n , k_] = get_shape<3>(b.first);
    const auto [num_groups__, m_, n_] = get_shape<3>(d);
    const auto num_groups___ = static_cast<int>(masked_m.numel());
    DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___);
    DG_HOST_ASSERT(m == m_ and n == n_ and k == k_ * 2);
    DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0);

    // Type checks
    // The activation must be FP8 (e4m3), exactly as the dense W4AFP8 entry point requires;
    // without this check a tensor of another dtype would be handed to the kernel unchecked
    DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt);

    // Only the weight scales go through the layout transform;
    // the activation scales (`a.second`) are passed to the launcher untouched
    const auto recipe = std::make_tuple(m, 1, 128);
    const auto sfa = a.second;
    const auto sfb = layout::transform_sf_into_required_layout(
        b.second, n, k, recipe, num_groups, false, false);
    const auto major_sfb = cute::UMMA::Major::K;

    sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
                                        num_groups, m, n, k, expected_m, major_a, major_b, major_sfb, compiled_dims);
}

#endif

#if DG_TENSORMAP_COMPATIBLE
Expand Down Expand Up @@ -663,6 +770,19 @@ static void register_apis(pybind11::module_& m) {
m.attr("m_grouped_fp8_gemm_nt_contiguous") = m.attr("m_grouped_fp8_fp4_gemm_nt_contiguous");
m.attr("m_grouped_fp8_gemm_nn_contiguous") = m.attr("m_grouped_fp8_fp4_gemm_nn_contiguous");
m.attr("m_grouped_fp8_gemm_nt_masked") = m.attr("m_grouped_fp8_fp4_gemm_nt_masked");

// SM90 W4AFP8 GEMMs
m.def("sm90_w4afp8_gemm_nt", &sm90_w4afp8_gemm_nt,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt,
py::arg("compiled_dims") = "nk");
m.def("sm90_m_grouped_w4afp8_gemm_nt_contiguous", &sm90_m_grouped_w4afp8_gemm_nt_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
py::arg("compiled_dims") = "nk");
m.def("sm90_m_grouped_w4afp8_gemm_nt_masked", &sm90_m_grouped_w4afp8_gemm_nt_masked,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
py::arg("expected_m"),
py::arg("compiled_dims") = "nk");
#endif

#if DG_TENSORMAP_COMPATIBLE
Expand Down
8 changes: 7 additions & 1 deletion csrc/jit_kernels/heuristics/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ struct GemmDesc {
int num_sms, tc_util;
std::string compiled_dims;

// W4AFP8 mode (INT4 weight, FP8 activation), SM90 only
// Weight encoding: true when B is INT4 packed as FP8 (requires software dequant).
// Detected from a.size(-1) != b.size(-1) at the API layer.
bool is_w4 = false;

// Shape for heuristic generation
int expected_m = 0, expected_n = 0, expected_k = 0, expected_num_groups = 0;
int get_expected_m() const { return expected_m > 0 ? expected_m : m; }
Expand Down Expand Up @@ -63,7 +68,8 @@ struct GemmDesc {
<< ", expected_m=" << desc.expected_m
<< ", expected_n=" << desc.expected_n
<< ", expected_k=" << desc.expected_k
<< ", expected_num_groups=" << desc.expected_num_groups << ")";
<< ", expected_num_groups=" << desc.expected_num_groups
<< ", is_w4=" << static_cast<int>(desc.is_w4) << ")";
return os;
}
};
Expand Down
122 changes: 87 additions & 35 deletions csrc/jit_kernels/heuristics/sm90.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,23 @@ struct SM90ArchSpec {
static std::vector<Layout> get_layout_candidates(const GemmDesc& desc) {
// Block M candidates
std::vector<int> block_m_candidates;
if (desc.gemm_type == GemmType::Normal or
if (desc.is_w4) {
// W4 uses swap_ab: STSM pattern requires BLOCK_M >= 64
// For masked GEMM, restrict block_m candidates based on expected_m to avoid
// unnecessary multi-M-block scheduling when a larger block_m fits exactly.
if (desc.gemm_type == GemmType::MGroupedContiguous)
block_m_candidates = std::vector{heuristics_runtime->get_mk_alignment_for_contiguous_layout()};
else if (desc.gemm_type == GemmType::MGroupedMasked) {
// W4 BM=128 (MMA_N=128): async WGMMA runs longer, hiding more of the ldmatrix+dequant latency.
// Validated for single-tile-per-SM cases; larger problems keep BM=64 as safe default.
const int est_tiles = ceil_div(desc.get_expected_m(), 64)
* ceil_div(desc.get_expected_n(), 256)
* desc.get_expected_num_groups();
const bool use_large_block_m = est_tiles <= desc.num_sms and desc.get_expected_m() >= 49;
block_m_candidates = use_large_block_m ? std::vector{128} : std::vector{64};
} else
block_m_candidates = {64, 128, 256};
} else if (desc.gemm_type == GemmType::Normal or
desc.gemm_type == GemmType::Batched or
desc.gemm_type == GemmType::KGroupedContiguous) {
// TODO: check 256's performance
Expand All @@ -37,24 +53,30 @@ struct SM90ArchSpec {

// Block N candidates
std::vector<int> block_n_candidates;
int step = std::lcm(16, heuristics_runtime->get_block_n_multiple_of());
int start = step;
// Avoid bank conflicts for 1D1D kernel FP32 output
if (desc.kernel_type == KernelType::Kernel1D1D and desc.cd_dtype == torch::kFloat) {
DG_HOST_ASSERT(desc.major_a == cute::UMMA::Major::K);
DG_HOST_ASSERT(desc.major_b == cute::UMMA::Major::K);
start = 24;
block_n_candidates.push_back(16);
if (desc.is_w4) {
// W4 small-m: STSM only fills MMA_N/64 of smem_d atom, restrict to {64}
const bool w4_small_m = (desc.gemm_type == GemmType::Normal) and (desc.m > 0 and desc.m <= 56);
block_n_candidates = w4_small_m ? std::vector{64} : std::vector{64, 128, 256};
} else {
int step = std::lcm(16, heuristics_runtime->get_block_n_multiple_of());
int start = step;
// Avoid bank conflicts for 1D1D kernel FP32 output
if (desc.kernel_type == KernelType::Kernel1D1D and desc.cd_dtype == torch::kFloat) {
DG_HOST_ASSERT(desc.major_a == cute::UMMA::Major::K);
DG_HOST_ASSERT(desc.major_b == cute::UMMA::Major::K);
start = 24;
block_n_candidates.push_back(16);
}
// Register spills
int end = 256;
if (desc.kernel_type == KernelType::Kernel1D2D)
end = 192;
if (desc.kernel_type == KernelType::Kernel1D1D)
end = 160;
// Enumerate
for (int i = start; i <= end; i += step)
block_n_candidates.push_back(i);
}
// Register spills
int end = 256;
if (desc.kernel_type == KernelType::Kernel1D2D)
end = 192;
if (desc.kernel_type == KernelType::Kernel1D1D)
end = 160;
// Enumerate
for (int i = start; i <= end; i += step)
block_n_candidates.push_back(i);

// Block K is always in a fixed manner
const int block_k = 128 / get_element_size(desc.get_mma_kind());
Expand Down Expand Up @@ -91,7 +113,10 @@ struct SM90ArchSpec {
continue;

// The block sizes cannot be too large (for enough registers), so at least one dim less than 128
if (block_m > 128 and block_n > 128)
// W4 RS-mode swaps layout↔compute: block_n→compute_m, block_m→compute_n
const int compute_m = desc.is_w4 ? block_n : block_m;
const int compute_n = desc.is_w4 ? block_m : block_n;
if (compute_m > 128 and compute_n > 128)
continue;

// Calculate swizzling
Expand Down Expand Up @@ -119,6 +144,7 @@ struct SM90ArchSpec {

static StorageConfig get_storage_config(const GemmDesc& desc, const Layout& layout) {
constexpr int wgmma_m = 64;
const int weight_ratio = desc.is_w4 ? 2 : 1;

// Load/store block sizes (w/o consideration of swizzling atoms, w/ consideration of loop atoms)
// TODO: support swap AB
Expand All @@ -132,8 +158,9 @@ struct SM90ArchSpec {
// Decide swizzling by the inner dim
const auto swizzle_mode_a = get_swizzle_mode(
desc.major_a == cute::UMMA::Major::K ? layout.block_k : load_block_m, c10::elementSize(desc.a_dtype));
// W4: B weight is packed INT4 (half the bytes per element), so swizzle uses block_k / weight_ratio
const auto swizzle_mode_b = get_swizzle_mode(
desc.major_b == cute::UMMA::Major::K ? layout.block_k : load_block_n, c10::elementSize(desc.b_dtype));
desc.major_b == cute::UMMA::Major::K ? (layout.block_k / weight_ratio) : load_block_n, c10::elementSize(desc.b_dtype));
// We only enable swizzling for non-FP32 outputs
const auto swizzle_mode_cd = desc.cd_dtype != torch::kFloat ?
get_swizzle_mode(store_block_n, c10::elementSize(desc.cd_dtype)) : 0;
Expand All @@ -147,6 +174,7 @@ struct SM90ArchSpec {

static PipelineConfig get_pipeline_config(const GemmDesc& desc, const Layout& layout, const StorageConfig& storage_config) {
constexpr int kNumMaxStages = 16;
const int weight_ratio = desc.is_w4 ? 2 : 1;

// TODO: consider swap AB
// C/D for TMA stores
Expand All @@ -157,17 +185,20 @@ struct SM90ArchSpec {

// Calculate A/B per stages
const int smem_a_per_stage = storage_config.load_block_m * layout.block_k * c10::elementSize(desc.a_dtype);
const int smem_b_per_stage = storage_config.load_block_n * layout.block_k * c10::elementSize(desc.b_dtype);
// W4: B weight is INT4, halved smem
const int smem_b_per_stage = storage_config.load_block_n * layout.block_k * c10::elementSize(desc.b_dtype) / weight_ratio;

// Calculate SF A/B per stages
// W4 RS-mode: SFA slot stores weight scales (compute_m = block_n), not activation scales
const int compute_m = desc.is_w4 ? layout.block_n : layout.block_m;
const int smem_sfa_per_stage = desc.kernel_type == KernelType::KernelNoSF ?
0 : align(layout.block_m * static_cast<int>(sizeof(float)), 128);
0 : align(compute_m * static_cast<int>(sizeof(float)), 128);
const int smem_sfb_per_stage = desc.kernel_type != KernelType::Kernel1D1D ?
0 : align(layout.block_n * static_cast<int>(sizeof(float)), 128);

// Extra SFB sizes for 1D2D kernels
// Extra SFB sizes for 1D2D kernels (W4 has no extra SFB — per-tensor scale)
const int use_uniform_sfb = layout.block_k % layout.block_n == 0 ? 1 : 2;
const int smem_extra_sfb = desc.kernel_type != KernelType::Kernel1D2D ?
const int smem_extra_sfb = (desc.kernel_type != KernelType::Kernel1D2D or desc.is_w4) ?
0 : align<int>(ceil_div(desc.k, layout.block_k) * static_cast<int>(sizeof(float)) * use_uniform_sfb, 8);

// Extra tensormap for 1D1D kernels
Expand All @@ -188,7 +219,9 @@ struct SM90ArchSpec {

static LaunchConfig get_launch_config(const GemmDesc& desc, const Layout& layout) {
const int num_tma_threads = 128;
const int num_math_threads = layout.block_m <= 64 ? 128 : 256;
// W4 RS-mode: block_n is compute_m after swap
const int compute_m = desc.is_w4 ? layout.block_n : layout.block_m;
const int num_math_threads = compute_m <= 64 ? 128 : 256;
return {
desc.num_sms,
layout.get_cluster_size(),
Expand All @@ -211,29 +244,48 @@ struct SM90ArchSpec {
const int l2_bandwidth_per_cycle = std::min(64. * desc.num_sms, 8e6 / (1.3e3)); // B/cycle
const int l1_bandwidth_per_cycle = 128 * desc.num_sms; // B/cycle
const int wgmma_m = 64;
const int elem_size_ab = c10::elementSize(desc.a_dtype);
const int elem_size_a = c10::elementSize(desc.a_dtype);
const int elem_size_cd = c10::elementSize(desc.cd_dtype);
DG_HOST_ASSERT(desc.a_dtype == desc.b_dtype);

// Data movement per block
// W4: B is INT4 packed — use 2x scale factor to avoid integer truncation.
// All byte counts are computed at 2x scale (bandwidth terms cancel out).
const int scale = desc.is_w4 ? 2 : 1;
const int elem_size_a_s = elem_size_a * scale;
// W4: c10::elementSize(FP8)=1, times scale=2 divided by weight_ratio=2 → 1. No truncation.
const int weight_ratio = desc.is_w4 ? 2 : 1;
const int elem_size_b_s = c10::elementSize(desc.b_dtype) * scale / weight_ratio;
const int elem_size_cd_s = elem_size_cd * scale;

// Data movement per block (at 2x scale for W4, 1x for FP8)
int64_t expected_k = desc.get_expected_k();
int64_t num_bytes_l2_ab = expected_k * (layout.block_m / layout.cluster_n + layout.block_n / layout.cluster_m) * elem_size_ab;
int64_t num_bytes_l1_ab = expected_k * (layout.block_m + layout.block_n) * elem_size_ab;
int64_t num_bytes_l1_tc = expected_k * (std::max(wgmma_m, layout.block_m) + layout.block_n) * elem_size_ab
+ layout.block_m * layout.block_n * elem_size_cd;
int64_t num_bytes_l1_l2_cd = layout.block_m * layout.block_n * elem_size_cd * (desc.with_accumulation ? 2 : 1);
int64_t num_bytes_l2_ab = expected_k * (layout.block_m / layout.cluster_n * elem_size_a_s + layout.block_n / layout.cluster_m * elem_size_b_s);
int64_t num_bytes_l1_ab = expected_k * (layout.block_m * elem_size_a_s + layout.block_n * elem_size_b_s);
int64_t num_bytes_l1_tc = expected_k * (std::max(wgmma_m, layout.block_m) * elem_size_a_s + layout.block_n * elem_size_b_s)
+ layout.block_m * layout.block_n * elem_size_cd_s;
int64_t num_bytes_l1_l2_cd = layout.block_m * layout.block_n * elem_size_cd_s * (desc.with_accumulation ? 2 : 1);

// HBM bandwidth and total compute (Tensor/CUDA cores) are constant across configs
// We only model L1/L2 cycles as they are the primary variables between configs
// Scale factor cancels in the ratio (bytes * scale) / (bandwidth * scale) when scale is uniform
int64_t num_l2_cycles = (num_bytes_l2_ab + num_bytes_l1_l2_cd) * num_blocks / l2_bandwidth_per_cycle;
int64_t num_l1_cycles = (num_bytes_l1_ab + num_bytes_l1_tc + num_bytes_l1_l2_cd) * num_blocks / l1_bandwidth_per_cycle;
float wave_efficiency = static_cast<float>(num_blocks) / (num_waves * desc.num_sms);
int64_t num_cycles = std::max(num_l1_cycles, num_l2_cycles) / wave_efficiency;

// Disable multicasting if only one wave exists
// Disable multicasting if only one wave exists: cluster sync overhead
// outweighs bandwidth savings when all SMs are already utilized.
// NOTE: For W4 masked GEMM small-m, cluster_n=2 (weight multicast) was
// tested and found NOT beneficial — L2 working set fits within 50MB cache,
// so no DRAM savings, only added cluster sync overhead.
if (layout.cluster_n * layout.cluster_m > 1 and num_waves <= 1)
num_cycles = std::numeric_limits<int64_t>::max();

// For masked GEMM: disable all multicast when expected_m ≤ block_m
// (each group has only 1 M-block, cluster sync overhead outweighs savings).
if (layout.cluster_n * layout.cluster_m > 1 and
desc.gemm_type == GemmType::MGroupedMasked and
desc.get_expected_m() <= layout.block_m)
num_cycles = std::numeric_limits<int64_t>::max();

return {num_waves, last_wave_util, num_cycles, layout};
}

Expand Down
Loading