Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions csrc/jit_kernels/impls/smxx_fp8_fp4_paged_mqa_logits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class SMXXPagedMQALogitsMetadataRuntime final: public LaunchRuntime<SMXXPagedMQA
int aligned_batch_size;
int split_kv;
int num_sms;
int num_next_n_atoms;
bool is_varlen;

int batch_size;
Expand All @@ -34,10 +35,10 @@ using namespace deep_gemm;

static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sched::smxx_paged_mqa_logits_metadata<
{}, {}, {}, {}
{}, {}, {}, {}, {}
>);
}};
)", args.aligned_batch_size, args.split_kv, args.num_sms, args.is_varlen ? "true" : "false");
)", args.aligned_batch_size, args.split_kv, args.num_sms, args.num_next_n_atoms, args.is_varlen ? "true" : "false");
}

static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
Expand All @@ -61,6 +62,10 @@ static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens,
constexpr int split_kv = 256;
constexpr int num_threads = 32;
const int aligned_batch_size = align(batch_size, 32);
const int next_n_atom = (is_varlen or next_n >= 2) ? 2 : 1;
// SM90 pads NextN=3 as one paired atom plus one single-token tail in the kernel.
const int num_next_n_atoms = (device_runtime->get_arch_major() == 9 and next_n == 3 and not is_varlen)
? 1 : ceil_div(next_n, next_n_atom);
DG_HOST_ASSERT(split_kv % block_kv == 0);

// Shared memory: prefix_sum[kAlignedBatchSize] plus varlen atom metadata when needed.
Expand All @@ -74,6 +79,7 @@ static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens,
.aligned_batch_size = aligned_batch_size,
.split_kv = split_kv,
.num_sms = num_sms,
.num_next_n_atoms = num_next_n_atoms,
.is_varlen = is_varlen,
.batch_size = batch_size,
.next_n = next_n,
Expand Down Expand Up @@ -190,7 +196,8 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
const int mma_m = (device_runtime->get_arch_major() == 10 ? 128 : 64);
const int num_math_warp_groups = split_kv / mma_m;
const int num_math_threads = num_math_warp_groups * 128;
const int num_q_stages = 3, num_kv_stages = (device_runtime->get_arch_major() == 10 ? 4 : 3);
const int num_q_stages = (device_runtime->get_arch_major() == 9 and next_n == 3) ? 4 : 3;
const int num_kv_stages = (device_runtime->get_arch_major() == 10 ? 4 : 3);
DG_HOST_ASSERT(split_kv % mma_m == 0 and logits_stride % split_kv == 0);

// Construct TMAs
Expand All @@ -217,8 +224,8 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
if (device_runtime->get_arch_major() == 9) {
const int swizzle_alignment = head_dim * 8;

const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast<int>(q.element_size());
const int aligned_smem_weight_size_per_stage = align(next_n * num_heads * static_cast<int>(weights.element_size()), swizzle_alignment);
const int smem_q_size_per_stage = next_n_atom * num_heads * head_dim * static_cast<int>(q.element_size());
const int aligned_smem_weight_size_per_stage = align(next_n_atom * num_heads * static_cast<int>(weights.element_size()), swizzle_alignment);
const int smem_q_pipe_size = num_q_stages * (smem_q_size_per_stage + aligned_smem_weight_size_per_stage) + align(num_q_stages * 8 * 2, swizzle_alignment);

const int smem_kv_size_per_stage = block_kv * head_dim * static_cast<int>(kv_cache.element_size());
Expand All @@ -231,7 +238,7 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,

smem_size = smem_q_pipe_size + num_math_warp_groups * smem_kv_pipe_size + smem_umma_barriers + smem_tmem_ptr;
DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
DG_HOST_ASSERT(next_n == 1 or next_n == 2);
DG_HOST_ASSERT(next_n >= 1 and next_n <= 3);
} else {
const int smem_q_size_per_stage = next_n_atom * num_heads * head_dim * static_cast<int>(q.element_size());
const int smem_kv_size_per_stage = split_kv * head_dim * static_cast<int>(kv_cache.element_size());
Expand Down
11 changes: 8 additions & 3 deletions csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class SMXXPagedMQALogitsMetadataRuntime final: public LaunchRuntime<SMXXPagedMQA
int aligned_batch_size;
int split_kv;
int num_sms;
int num_next_n_atoms;

int batch_size;
int next_n;
Expand All @@ -33,11 +34,11 @@ class SMXXPagedMQALogitsMetadataRuntime final: public LaunchRuntime<SMXXPagedMQA
using namespace deep_gemm;

static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&smxx_paged_mqa_logits_metadata<
{}, {}, {}
auto ptr = reinterpret_cast<void*>(&sched::smxx_paged_mqa_logits_metadata<
{}, {}, {}, {}
>);
}};
)", arch, args.aligned_batch_size, args.split_kv, args.num_sms);
)", arch, args.aligned_batch_size, args.split_kv, args.num_sms, args.num_next_n_atoms);
}

static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
Expand All @@ -46,6 +47,7 @@ static void __instantiate_kernel() {{
args.next_n,
args.is_context_lens_2d,
args.context_lens,
nullptr,
args.schedule_metadata
));
}
Expand All @@ -60,6 +62,8 @@ static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens,
constexpr int num_threads = 32;
const int aligned_batch_size = align(batch_size, 32);
const int split_kv = block_kv * num_math_warpgroups;
const int next_n_atom = (next_n >= 2) ? 2 : 1;
const int num_next_n_atoms = ceil_div(next_n, next_n_atom);

// Calculate shared memory size
const int smem_size = aligned_batch_size * static_cast<int>(sizeof(int));
Expand All @@ -71,6 +75,7 @@ static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens,
.aligned_batch_size = aligned_batch_size,
.split_kv = split_kv,
.num_sms = num_sms,
.num_next_n_atoms = num_next_n_atoms,
.batch_size = batch_size,
.next_n = next_n,
.is_context_lens_2d = is_context_lens_2d,
Expand Down
Loading