Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 55 additions & 9 deletions csrc/apis/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
#include "../jit_kernels/impls/sm90_bf16_gemm.hpp"
#include "../jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp"
#include "../jit_kernels/impls/sm100_fp4_gemm_1d1d.hpp"
#include "../jit_kernels/impls/sm100_bf16_gemm.hpp"
#endif

Expand Down Expand Up @@ -73,7 +74,12 @@ static void fp8_fp4_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
const auto [n , k_] = check_ab_fp8_fp4(b.first, major_b, arch_major);
const auto [m_, n_] = get_shape<2>(d);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);
// FP4xFP4 path has no bf16 epilogue store: require fp32 D.
if (a.first.scalar_type() == kPackedFP4 and b.first.scalar_type() == kPackedFP4) {
DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
} else {
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);
}

// Early return for trivial cases
if (early_return(m, n, k, d, c))
Expand All @@ -93,8 +99,17 @@ static void fp8_fp4_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, major_sfb, compiled_dims);
}
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
sm100_fp8_fp4_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, gran_k_a, gran_k_b,
if (a.first.scalar_type() == kPackedFP4 and b.first.scalar_type() == kPackedFP4) {
// FP4xFP4 path reinterprets int8 packed tensors as int32 for TMA; requires
// contiguous K-major layout and packed-int32-aligned K.
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
DG_HOST_ASSERT(k % 8 == 0);
sm100_fp4_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k,
major_a, major_b, compiled_dims);
} else {
sm100_fp8_fp4_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, gran_k_a, gran_k_b,
major_a, major_b, compiled_dims);
}
} else {
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
Expand Down Expand Up @@ -166,7 +181,12 @@ static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pair<torch::Tensor,
const auto [m_, n_] = get_shape<2>(d);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0);
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
// FP4xFP4 path has no bf16 epilogue store: require fp32 D.
if (a.first.scalar_type() == kPackedFP4 and b.first.scalar_type() == kPackedFP4) {
DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
} else {
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
}
DG_HOST_ASSERT(grouped_layout.scalar_type() == torch::kInt);

// Layout checks
Expand Down Expand Up @@ -197,9 +217,22 @@ static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pair<torch::Tensor,
num_groups, m, n, k, major_a, major_b, major_sfb,
compiled_dims, use_psum_layout, expected_m_for_psum_layout);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, grouped_layout,
num_groups, m, n, k, gran_k_a, gran_k_b, major_a, major_b,
compiled_dims, use_psum_layout, expected_m_for_psum_layout);
if (a.first.scalar_type() == kPackedFP4 and b.first.scalar_type() == kPackedFP4) {
DG_HOST_ASSERT(not use_psum_layout);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
DG_HOST_ASSERT(k % 8 == 0);
// swap_ab gating uses m/num_groups as the per-group estimate (accurate
// for dense layouts; sparse MoE without a hint conservatively skips
// swap_ab).
const int expected_m_per_group = num_groups > 0 ? m / num_groups : 0;
sm100_m_grouped_fp4_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, grouped_layout,
num_groups, m, n, k, expected_m_per_group,
major_a, major_b, compiled_dims);
} else {
sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, grouped_layout,
num_groups, m, n, k, gran_k_a, gran_k_b, major_a, major_b,
compiled_dims, use_psum_layout, expected_m_for_psum_layout);
}
} else {
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
Expand Down Expand Up @@ -244,7 +277,12 @@ static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair<torch::Tensor, torc
DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0);
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
// FP4xFP4 path has no bf16 epilogue store: require fp32 D.
if (a.first.scalar_type() == kPackedFP4 and b.first.scalar_type() == kPackedFP4) {
DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
} else {
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
}
DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt);

// D must be N-major
Expand All @@ -260,9 +298,17 @@ static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair<torch::Tensor, torc
sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, major_a, major_b, major_sfb, compiled_dims);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, gran_k_a, gran_k_b,
if (a.first.scalar_type() == kPackedFP4 and b.first.scalar_type() == kPackedFP4) {
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
DG_HOST_ASSERT(k % 8 == 0);
sm100_m_grouped_fp4_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m,
major_a, major_b, compiled_dims);
} else {
sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, gran_k_a, gran_k_b,
major_a, major_b, compiled_dims);
}
} else {
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
Expand Down
Loading