Skip to content
Open
Show file tree
Hide file tree
Changes from 53 commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
373f073
Add SM120a (RTX PRO 6000) dense FP8 GEMM support
leavelet Apr 26, 2026
6d8c50e
Add SM120a dense FP4 GEMM support and consolidate kernel to warp-spec…
leavelet Apr 28, 2026
18553d2
Add SM120a K-grouped contiguous FP8 GEMM and use custom MMA wrappers
leavelet Apr 29, 2026
8726331
Add SM120a Einsum (batched GEMM), M-grouped BF16, and HC prenorm TF32
leavelet Apr 29, 2026
c8a545e
Add SM120a FP8 MQA logits (dense + paged) for attention
leavelet Apr 30, 2026
c346b33
feat: bf16 mn gemm
leavelet Apr 30, 2026
d0a56e3
Add SM120a skip_head_mid, MN-major B for M-grouped, and FP8 stride_d fix
leavelet Apr 30, 2026
c5d92cd
Optimize MN-major B with ldmatrix.trans.x2, replacing 4 scalar loads
leavelet Apr 30, 2026
05d2818
Enable BLOCK_N=128 for MN-major B via multi-atom ldmatrix.trans
leavelet Apr 30, 2026
9996ceb
Add SM120a Einsum: BF16 bmk,bnk->mn kernel, FP8 bhd,hdr->bhr and bhd,…
leavelet Apr 30, 2026
ccaaea6
Add SM120a K-grouped FP8 TN layout and Paged MQA varlen support
leavelet May 1, 2026
9160119
Add SM120a FP4 MQA logits kernels (dense + paged)
leavelet May 1, 2026
9cbae28
Add SM120a FP8×FP4 mixed precision GEMM via native .b4x16_p64 TMA + l…
leavelet May 1, 2026
22a4e65
Enable FP8×FP4 mixed precision for K-grouped, M-grouped, and batched …
leavelet May 1, 2026
efeb709
Fix BF16 TMA store epilogue missing apply_index_n and add CUDA_ARCH g…
leavelet May 4, 2026
e34354a
SM120 FP4/FP8 perf: sub-tile TMA store epilogue + ldmatrix.x4 B load
leavelet May 4, 2026
55b8084
SM120: per-N-tile ldmatrix.x4 B load — zero MOV inner loop
leavelet May 4, 2026
2c75b45
SM120: cooperative warp layout (4M×2N) — halve per-warp B loads
leavelet May 4, 2026
a850640
SM120: tune heuristic to prefer BN=128 for cooperative warp layout
leavelet May 5, 2026
af7dd12
Remove K-grouped FP4 support (SM120-only, unused in production)
leavelet May 7, 2026
36aba91
SM120: MN-major support for FP8/FP4 GEMM (align with SM100 API)
leavelet May 8, 2026
0e8bbb9
SM120: enable block_kv=32 for FP4 paged MQA, guard FP8 to block_kv=64
leavelet May 13, 2026
ed129ae
SM120: SF-major loop with compile-time byte extraction and SwizzleCon…
leavelet May 13, 2026
e965b96
SM120: BLOCK_K=64 with 4 pipeline stages for large-M FP8/FP4 GEMM
leavelet May 13, 2026
1c50fbc
SM120: enable BLOCK_M=64 for grouped FP8/FP4 GEMM at small expected_m
lucifer1004 May 13, 2026
2f1ea74
SM120: enable BLOCK_M=32 for grouped FP8/FP4 GEMM via 2×4 warp layout
lucifer1004 May 13, 2026
dfa4e1b
Revert "SM120: enable BLOCK_M=32 for grouped FP8/FP4 GEMM via 2×4 war…
lucifer1004 May 14, 2026
319f092
SM120: AB-swap path for small-M FP8/FP4 BMM
lucifer1004 May 15, 2026
ea5f129
SM120: extend AB-swap to plain FP8/FP4 GEMM + asymmetric recipe support
lucifer1004 May 15, 2026
1383f15
SM120: wave-packing heuristic + BLOCK_M=64 + AB-swap for dense GEMM
leavelet May 25, 2026
764894d
SM120: AB-swap with cp.async B-tile loader for M=1..16 dense GEMM
leavelet May 25, 2026
aaf9740
Revert "SM120: AB-swap with cp.async B-tile loader for M=1..16 dense …
leavelet May 25, 2026
6b8a2c3
SM121: reuse SM120 family cubin on NVCC/NVRTC >= 12.9
leavelet May 25, 2026
0685768
SM120: AB-swap via TMA for M=1..16 dense GEMM + per-element epilogue
leavelet May 25, 2026
41b7aae
SM120: prefer BM=64 for AB-swap path (2x SM utilization)
leavelet May 26, 2026
fffb3f8
SM120: fix BMM swap stride, num_groups alignment, and BF16 heuristic …
leavelet May 26, 2026
e589be3
SM120: split-K for small-dimension FP8 GEMM
leavelet May 26, 2026
b9c8a84
SM120: disable split-K for AB-swap path
leavelet May 26, 2026
dcce8d4
SM120: BF16 kernel kNWarps=2 for BM=64 + runtime_align guard
leavelet May 26, 2026
5b01e23
SM120: fix rebase conflicts with nv_dev, guard paged MQA next_n <= 2
leavelet May 26, 2026
1f2f161
SM120: remove sm120-specific test files and fix SM90 test regression
leavelet May 26, 2026
c8094d8
SM120: enable split-K for AB-swap path, fix SF alignment bugs
leavelet May 27, 2026
553fe80
SM120: enable FP32 TMA store epilogue for wgrad
leavelet May 27, 2026
85b6d4c
SM120: fix batched FP8 GEMM accumulation double-count and swap exclusion
leavelet May 29, 2026
a2f90b0
SM120 tests: fix cuBLAS baseline timing for shapes using non-nvjet ke…
leavelet May 29, 2026
1845516
SM120: split-KV for dense FP8 MQA logits to fill idle SMs at small S
leavelet May 29, 2026
76e93aa
SM120: complete odd/large next_n (kPadOddN) port for paged MQA logits
leavelet May 31, 2026
8f5a8d2
Merge origin/nv_dev (245dc5d) into sm120
leavelet Jun 4, 2026
3f5f4e7
SM120: disable TMA-store epilogue for AB-swap (transposed) output
leavelet Jun 4, 2026
18d9c4e
SM120: fix SM90/SM100 MQA-logits SMEM regressions from shared-launche…
leavelet Jun 5, 2026
aced12c
SM120: clean up branch for upstream (drop dev artifacts, trim comments)
leavelet Jun 5, 2026
8130c43
SM120: skip empty (-1) tiles in MGroupedContiguous grouped GEMM
leavelet Jun 12, 2026
33a715e
SM120: fp4-A x fp8-B mixed GEMM (kAIsFP4, swapAB orientation)
leavelet Jun 12, 2026
1e22994
SM120: extract gemm dispatch helper and per-arch alignment policy
leavelet Jun 16, 2026
9ca3048
SM120: validate BF16 einsum vs FP32 truth, gate cuBLAS-diff tolerance
leavelet Jun 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions csrc/apis/attention.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp"
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
#include "../jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp"
#include "../jit_kernels/impls/sm120_fp8_fp4_gemm_1d1d.hpp"
#include "../jit_kernels/impls/smxx_fp8_fp4_mqa_logits.hpp"
#include "../jit_kernels/impls/smxx_fp8_fp4_paged_mqa_logits.hpp"
#include "../jit_kernels/impls/smxx_clean_logits.hpp"
Expand Down Expand Up @@ -68,6 +69,9 @@ static void fp8_gemm_nt_skip_head_mid(const std::pair<torch::Tensor, torch::Tens
// NOTES: Only granularity 128 and FP8 are exposed in the API
sm100_fp8_fp4_gemm_1d1d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k,
128, 128, major_a, major_b, compiled_dims, epilogue_type);
} else if (arch_major == 12 and sfa.scalar_type() == torch::kInt) {
sm120_fp8_fp4_gemm_1d1d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k,
128, 128, major_a, major_b, compiled_dims, epilogue_type);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
Expand Down Expand Up @@ -162,7 +166,7 @@ static torch::Tensor fp8_fp4_mqa_logits(const std::tuple<torch::Tensor, std::opt

// Allocate output
constexpr int block_qh = 128;
constexpr int block_kv = 256;
const int block_kv = (device_runtime->get_arch_major() == 12) ? 128 : 256;
const int block_q = block_qh / num_heads;
DG_HOST_ASSERT(block_qh % num_heads == 0);

Expand Down Expand Up @@ -192,11 +196,14 @@ static torch::Tensor fp8_fp4_mqa_logits(const std::tuple<torch::Tensor, std::opt
if (is_fp4 and arch_major == 10) {
sm100_fp4_mqa_logits(q_fp, q_sf.value(), kv_fp, kv_sf, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits, logits_dtype,
seq_len, seq_len_kv, max_seqlen_k, stride_logits, num_heads, head_dim, block_q, block_kv);
} else if (is_fp4 and arch_major == 12) {
sm120_fp4_mqa_logits(q_fp, q_sf.value(), kv_fp, kv_sf, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits, logits_dtype,
seq_len, seq_len_kv, max_seqlen_k, stride_logits, num_heads, head_dim, block_q, block_kv);
} else if (not is_fp4 and weights_is_f16) {
// FP16 weights -> FP16 MMA accumulator (Q*K score + per-head reduction in FP16); see note above
sm100_fp8_mqa_logits_f16_weights(q_fp, kv_fp, kv_sf, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits, logits_dtype,
seq_len, seq_len_kv, max_seqlen_k, stride_logits, num_heads, head_dim, block_q, block_kv);
} else if (not is_fp4 and (arch_major == 9 or arch_major == 10)) {
} else if (not is_fp4 and (arch_major == 9 or arch_major == 10 or arch_major == 12)) {
smxx_fp8_mqa_logits(q_fp, kv_fp, kv_sf, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits, logits_dtype,
seq_len, seq_len_kv, max_seqlen_k, stride_logits, num_heads, head_dim, block_q, block_kv);
} else {
Expand Down Expand Up @@ -228,15 +235,15 @@ static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_
const auto arch_major = device_runtime->get_arch_major();
if (is_varlen) {
const auto& indices_tensor = indices.value();
DG_HOST_ASSERT(arch_major == 10 and next_n == 1 and (block_kv == 64 or block_kv == 32));
DG_HOST_ASSERT((arch_major == 10 or arch_major == 12) and next_n == 1 and (block_kv == 64 or block_kv == 32));
DG_HOST_ASSERT(indices_tensor.dim() == 1 and indices_tensor.size(0) == batch_size);
DG_HOST_ASSERT(indices_tensor.is_contiguous());
DG_HOST_ASSERT(indices_tensor.scalar_type() == torch::kInt);
// Varlen runs on SM100 with next_n=1: no atomization (num_next_n_atoms=1).
smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, next_n, block_kv,
num_sms, is_context_lens_2d, /*num_next_n_atoms=*/1,
/*is_varlen=*/true, indices_tensor.data_ptr<int>());
} else if (arch_major == 9 or arch_major == 10) {
} else if (arch_major == 9 or arch_major == 10 or arch_major == 12) {
DG_HOST_ASSERT(block_kv == 32 or block_kv == 64);
// SM90 schedules in units of `kComputeBlockKV = 64` regardless of physical
// `block_kv`; pass the compute block size to the metadata kernel.
Expand All @@ -245,6 +252,7 @@ static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_
// kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1
// kNumNextNAtoms = ceil_div(kNextN, kNextNAtom)
// SM90 cluster multicast hard-codes kNumNextNAtoms = 1 (one q per cluster).
// SM100/SM120 atomize next_n in time: kNextNAtom = (next_n >= 2) ? 2 : 1.
int num_next_n_atoms;
if (arch_major == 9) {
num_next_n_atoms = 1;
Expand Down Expand Up @@ -375,7 +383,7 @@ static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tuple<torch::Tensor, st
const auto arch_major = device_runtime->get_arch_major();
const auto indices_tensor = indices.value_or(torch::Tensor());
if (is_varlen) {
DG_HOST_ASSERT(arch_major == 10 and next_n == 1);
DG_HOST_ASSERT((arch_major == 10 or arch_major == 12) and next_n == 1);
DG_HOST_ASSERT(indices_tensor.dim() == 1 and indices_tensor.size(0) == batch_size);
DG_HOST_ASSERT(indices_tensor.is_contiguous());
DG_HOST_ASSERT(indices_tensor.scalar_type() == torch::kInt);
Expand All @@ -399,7 +407,8 @@ static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tuple<torch::Tensor, st
DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt);

// Allocate output
constexpr int split_kv = 256;
// SM120a: 2 groups × 64 KV rows = 128; SM90/100: 256
const int split_kv = (arch_major == 12) ? 128 : 256;
const auto aligned_max_context_len = align(max_context_len, split_kv);
auto logits = torch::empty({batch_size * next_n, aligned_max_context_len}, q_fp.options().dtype(logits_dtype));
logits = logits.slice(-1, 0, max_context_len);
Expand All @@ -410,7 +419,11 @@ static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tuple<torch::Tensor, st
sm100_fp4_paged_mqa_logits(q_fp, q_sf.value(), kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, indices_tensor, schedule_meta,
logits_dtype, batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d,
is_varlen, aligned_max_context_len, block_table_stride, num_sms, split_kv);
} else if (not is_fp4 and (arch_major == 9 or arch_major == 10)) {
} else if (is_fp4 and arch_major == 12) {
sm120_fp4_paged_mqa_logits(q_fp, q_sf.value(), kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, indices_tensor, schedule_meta,
logits_dtype, batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d,
is_varlen, aligned_max_context_len, block_table_stride, num_sms, split_kv);
} else if (not is_fp4 and (arch_major == 9 or arch_major == 10 or arch_major == 12)) {
smxx_fp8_paged_mqa_logits(q_fp, kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, indices_tensor, schedule_meta,
logits_dtype, batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d,
is_varlen, aligned_max_context_len, block_table_stride, num_sms, split_kv);
Expand Down
86 changes: 75 additions & 11 deletions csrc/apis/einsum.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
#include "../jit_kernels/impls/sm90_bmk_bnk_mn.hpp"
#include "../jit_kernels/impls/sm100_bmk_bnk_mn.hpp"
#include "../jit_kernels/impls/sm120_bmk_bnk_mn.hpp"
#include "../jit_kernels/impls/sm90_bf16_gemm.hpp"
#include "../jit_kernels/impls/sm100_bf16_gemm.hpp"
#include "../jit_kernels/impls/sm120_bf16_gemm.hpp"
#include "../jit_kernels/impls/sm120_fp8_fp4_gemm_1d1d.hpp"
#include "../jit_kernels/impls/smxx_cublaslt.hpp"
#endif

Expand Down Expand Up @@ -51,6 +54,8 @@ static void bmk_bnk_mn(const torch::Tensor& a, const torch::Tensor& b, const tor
const auto arch_major = device_runtime->get_arch_major();
if (arch_major == 9) {
sm90_bmn_bnk_mn_gemm(a, b, d, s, m, n, k);
} else if (arch_major == 12) {
sm120_bmn_bnk_mn_gemm(a, b, d, s, m, n, k);
} else if (arch_major == 10) {
sm100_bmn_bnk_mn_gemm(a, b, d, s, m, n, k);
} else {
Expand All @@ -74,6 +79,8 @@ static void bhr_hdr_bhd(const torch::Tensor& A, const torch::Tensor& B, const to
cublaslt_bhr_hdr_bhd(A, B, D, b, h, r, d);
} else if (arch_major == 9) {
sm90_bf16_bhr_hdr_bhd(A, B, D, b, h, r, d);
} else if (arch_major == 12) {
sm120_bf16_bhr_hdr_bhd(A, B, D, b, h, r, d);
} else if (arch_major == 10) {
sm100_bf16_bhr_hdr_bhd(A, B, D, b, h, r, d);
} else {
Expand All @@ -97,6 +104,8 @@ static void bhd_hdr_bhr(const torch::Tensor& A, const torch::Tensor& B, const to
cublaslt_bhd_hdr_bhr(A, B, D, b, h, r, d);
} else if (arch_major == 9) {
sm90_bf16_bhd_hdr_bhr(A, B, D, b, h, r, d);
} else if (arch_major == 12) {
sm120_bf16_bhd_hdr_bhr(A, B, D, b, h, r, d);
} else if (arch_major == 10) {
sm100_bf16_bhd_hdr_bhr(A, B, D, b, h, r, d);
} else {
Expand Down Expand Up @@ -161,13 +170,59 @@ static void fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
if (batch_size == 0 or gemm::early_return(m, n, k, d, c))
return;

// Transform scaling factors
// AB-swap small-M decode path. Mirrors TRT-LLM's runGemmSwapAB
// (cpp/include/tensorrt_llm/deep_gemm/fp8_gemm.cuh): SM120 1d1d has
// BLOCK_M ≥ 64, so M_orig ≤ 32 wastes lanes. Swapping A↔B moves the
// small dim to N where BLOCK_N can shrink to {16, 32}. The swap runs
// BEFORE the SF layout transform so transform_sf_pair_into_required_layout
// sees operands in their post-swap roles. The kernel writes back into the
// caller's (B, M_orig, N_orig) buffer directly via runtime stride_cd_m/n
// (no temp buffer); see sm120_fp8_fp4_bmm for the stride remap.
const auto arch_major = device_runtime->get_arch_major();
constexpr int kSwapAbMMax = 32;
const bool swap_ab_eligible =
arch_major == 12 and m >= 1 and m <= kSwapAbMMax
and major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K
and d.stride(-1) == 1 // d's innermost dim must be contiguous
and not c.has_value(); // swap's strided/transposed output is incompatible with
// the batched accumulation epilogue (both the REDUCE_ADD
// TMA-store path and the direct-store path mishandle the
// batch offset + swapped strides). Mirrors the dense GEMM
// swap exclusion in gemm.hpp.

if (swap_ab_eligible) {
// Swap the recipe's gran_mn entries too: (gran_mn_a, gran_mn_b, gran_k)
// describes the original A/B roles, so after operand swap the per-tensor
// granularities must follow. Without this, asymmetric recipes
// like (1, 128, 128) trip the SF layout shape check.
const auto eff_recipe = recipe.has_value()
? recipe.value()
: get_default_recipe(sfa.scalar_type(), sfb.scalar_type());
const auto& [ga, gb, gk] = eff_recipe;
std::optional<std::tuple<int, int, int>> swap_recipe = std::nullopt;
std::optional<std::tuple<int, int>> swap_recipe_a = std::make_tuple(gb, gk);
std::optional<std::tuple<int, int>> swap_recipe_b = std::make_tuple(ga, gk);
const auto [transformed_sfa_swap, transformed_sfb_swap, gran_k_a_swap, gran_k_b_swap]
= layout::transform_sf_pair_into_required_layout(
sfb, sfa, /*m=*/n, /*n=*/m, k, swap_recipe,
swap_recipe_a, swap_recipe_b, batch_size, batch_size, false);
sm120_fp8_fp4_bmm(
b, transformed_sfa_swap, a, transformed_sfb_swap, c, d,
batch_size, /*m=*/n, /*n=*/m, k,
gran_k_a_swap, gran_k_b_swap,
major_b, major_a, compiled_dims,
/*swap_ab=*/true);
return;
}

// Transform scaling factors (non-swap path)
const auto [transformed_sfa, transformed_sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
sfa, sfb, m, n, k, recipe, std::nullopt, std::nullopt, batch_size, batch_size, false);

// Dispatch implementation
const auto arch_major = device_runtime->get_arch_major();
if (arch_major == 10) {
if (arch_major == 12) {
sm120_fp8_fp4_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, k, gran_k_a, gran_k_b, major_a, major_b, compiled_dims);
} else if (arch_major == 10) {
sm100_fp8_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, k, gran_k_a, gran_k_b, major_a, major_b, compiled_dims);
} else {
const auto major_sfb = get_major_type_ab(sfb);
Expand All @@ -192,21 +247,30 @@ static void fp8_einsum(const std::string& expr,
const auto perm_d = d.permute({1, 0, 2});
const auto perm_c = c.has_value() ? std::make_optional(c.value().permute({1, 0, 2})) : std::nullopt;
fp8_bmm(perm_a, perm_sfa, b.first, b.second, perm_d, perm_c, recipe, "nk");
} else if (expr == "bhd,hdr->bhr" and arch_major == 10) {
} else if (expr == "bhd,hdr->bhr") {
// (batch_size, m, n, k): (h, b, r, d)
const auto perm_a = a.first.permute({1, 0, 2});
const auto perm_sfa = a.second.permute({1, 0, 2});
const auto perm_b = b.first.permute({0, 2, 1});
const auto perm_sfb = b.second.permute({0, 2, 1});
auto perm_b = b.first.permute({0, 2, 1});
auto perm_sfb = b.second.permute({0, 2, 1});
// SM120: B is MN-major after permute; .contiguous() to K-major (scalar MN-major path ~3x slower).
if (arch_major == 12) {
perm_b = perm_b.contiguous();
}
const auto perm_d = d.permute({1, 0, 2});
const auto perm_c = c.has_value() ? std::make_optional(c.value().permute({1, 0, 2})) : std::nullopt;
fp8_bmm(perm_a, perm_sfa, perm_b, perm_sfb, perm_d, perm_c, recipe, "nk");
} else if (expr == "bhd,bhr->hdr" and arch_major == 10) {
} else if (expr == "bhd,bhr->hdr") {
// (batch_size, m, n, k): (h, d, r, b)
const auto perm_a = a.first.permute({1, 2, 0});
const auto perm_sfa = a.second.permute({1, 2, 0});
const auto perm_b = b.first.permute({1, 2, 0});
const auto perm_sfb = b.second.permute({1, 2, 0});
auto perm_a = a.first.permute({1, 2, 0});
auto perm_sfa = a.second.permute({1, 2, 0});
auto perm_b = b.first.permute({1, 2, 0});
auto perm_sfb = b.second.permute({1, 2, 0});
// SM120: A/B MN-major after permute; force K-major (MN-major A unsupported, scalar path ~3x slower).
if (arch_major == 12) {
perm_a = perm_a.contiguous();
perm_b = perm_b.contiguous();
}
fp8_bmm(perm_a, perm_sfa, perm_b, perm_sfb, d, c, recipe, "mn");
} else {
DG_HOST_UNREACHABLE(fmt::format("Unsupported einsum expression: {}", expr));
Expand Down
Loading