Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions csrc/apis/attention.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ static torch::Tensor fp8_fp4_mqa_logits(const std::tuple<torch::Tensor, std::opt
if (is_fp4 and arch_major == 10) {
sm100_fp4_mqa_logits(q_fp, q_sf.value(), kv_fp, kv_sf, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits, logits_dtype,
seq_len, seq_len_kv, max_seqlen_k, stride_logits, num_heads, head_dim, block_q, block_kv);
} else if (is_fp4 and arch_major == 9) {
sm90_fp4_mqa_logits(q_fp, q_sf.value(), kv_fp, kv_sf, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits, logits_dtype,
seq_len, seq_len_kv, max_seqlen_k, stride_logits, num_heads, head_dim, block_q, block_kv);
} else if (not is_fp4 and (arch_major == 9 or arch_major == 10)) {
smxx_fp8_mqa_logits(q_fp, kv_fp, kv_sf, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits, logits_dtype,
seq_len, seq_len_kv, max_seqlen_k, stride_logits, num_heads, head_dim, block_q, block_kv);
Expand Down
1 change: 1 addition & 0 deletions csrc/jit_kernels/impls/runtime_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType&
case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
case torch::kBFloat16: return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
case torch::kFloat8_e4m3fn: return CU_TENSOR_MAP_DATA_TYPE_UINT8;
case torch::kUInt8: return CU_TENSOR_MAP_DATA_TYPE_UINT8;
#if CUDA_VERSION >= 12080
case kPackedFP4: return fp4_unpacked_smem ? CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B
: CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B;
Expand Down
183 changes: 183 additions & 0 deletions csrc/jit_kernels/impls/smxx_fp8_fp4_mqa_logits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,4 +325,187 @@ static void sm100_fp4_mqa_logits(const torch::Tensor& q, const torch::Tensor& sf
SM100FP4MQALogitsRuntime::launch(runtime, args);
}

// JIT launch runtime for the SM90 FP4 MQA-logits kernel.
//
// `generate_impl` emits the source that instantiates `sm90_fp4_mqa_logits<...>` with the
// compile-time parameters carried by `Args`; `launch_impl` forwards the remaining runtime
// arguments to the compiled kernel.
class SM90FP4MQALogitsRuntime final: public LaunchRuntime<SM90FP4MQALogitsRuntime> {
public:
    // NOTE: keep the member declaration order stable, `sm90_fp4_mqa_logits` fills this
    // struct with designated initializers, which must follow declaration order
    struct Args {
        // Sizes and strides (the first four are passed to the kernel at launch,
        // `num_heads`/`head_dim`/`is_compressed_logits` become template arguments)
        int seq_len;
        int seq_len_kv;
        int max_seqlen_k;
        int stride_logits;
        int num_heads;
        int head_dim;
        bool is_compressed_logits;

        // Pipeline depth and tile sizes (template arguments)
        int num_q_stages;
        int num_kv_stages;
        int block_q;
        int block_kv;

        // Raw pointers passed to the kernel at launch
        int* cu_seq_len_k_start;
        int* cu_seq_len_k_end;
        void* logits;

        // TMA descriptors passed to the kernel at launch
        CUtensorMap tensor_map_q;
        CUtensorMap tensor_map_sf_q;
        CUtensorMap tensor_map_kv;
        CUtensorMap tensor_map_sf_kv;
        CUtensorMap tensor_map_weights;

        // Output element type (template argument, rendered through `to_string`)
        at::ScalarType logits_dtype;

        // Thread counts (template arguments)
        int num_tma_threads;
        int num_math_threads;

        // Grid/block/shared-memory configuration; `grid_dim.first` is also a template argument
        LaunchArgs launch_args;
    };

    // Render the kernel instantiation source for the given arguments.
    // Asserts that `num_heads` divides 128.
    static std::string generate_impl(const Args& args) {
        DG_HOST_ASSERT(128 % args.num_heads == 0);

        const auto& launch = args.launch_args;
        // Template arguments are listed in the same order as the `{}` placeholders below
        return fmt::format(R"(
#include <deep_gemm/impls/sm90_fp4_mqa_logits.cuh>

using namespace deep_gemm;

static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm90_fp4_mqa_logits<
{}, {},
{},
{}, {},
{}, {},
{},
{}, {},
{}
>);
}};
)",
            args.num_heads, args.head_dim,
            args.is_compressed_logits,
            args.block_q, args.block_kv,
            args.num_q_stages, args.num_kv_stages,
            launch.grid_dim.first,
            args.num_tma_threads, args.num_math_threads,
            to_string(args.logits_dtype));
    }

    // Launch the compiled kernel, passing the runtime arguments in the order the kernel expects:
    // sizes/strides, sequence-range pointers, output pointer, then the five TMA descriptors.
    static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
        DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
            args.seq_len, args.seq_len_kv,
            args.max_seqlen_k, args.stride_logits,
            args.cu_seq_len_k_start, args.cu_seq_len_k_end,
            args.logits,
            args.tensor_map_q, args.tensor_map_sf_q,
            args.tensor_map_kv, args.tensor_map_sf_kv,
            args.tensor_map_weights
        ));
    }
};

// Host-side entry point for the SM90 FP4 MQA-logits kernel.
//
// Builds the five TMA descriptors (Q, Q scale factors, weights, KV, KV scale factors),
// computes the dynamic shared-memory footprint, then JIT-builds and launches the kernel
// through `SM90FP4MQALogitsRuntime`.
//
// Host assertions: `head_dim == 128`, and the shared-memory footprint must not exceed
// `SM90ArchSpec::smem_capacity`.
static void sm90_fp4_mqa_logits(const torch::Tensor& q, const torch::Tensor& sf_q,
                                const torch::Tensor& kv, const torch::Tensor& sf_kv,
                                const torch::Tensor& weights,
                                const torch::Tensor& cu_seq_len_k_start,
                                const torch::Tensor& cu_seq_len_k_end,
                                const torch::Tensor& logits,
                                const at::ScalarType& logits_dtype,
                                const int& seq_len, const int& seq_len_kv,
                                const int& max_seqlen_k, const int& stride_logits,
                                const int& num_heads, const int& head_dim,
                                const int& block_q, const int& block_kv) {
    // Fixed launch configuration
    constexpr int tma_thread_count = 128;
    constexpr int math_thread_count = 512;
    constexpr int q_stage_count = 3;
    constexpr int kv_stage_count = 3;

    // A positive `max_seqlen_k` selects the compressed logits format
    const bool use_compressed_logits = max_seqlen_k > 0;

    // Only `head_dim == 128` is supported
    DG_HOST_ASSERT(head_dim == 128);

    // Packed FP4 stores two values per byte, so one row of `head_dim` values takes `head_dim / 2` bytes
    const int packed_row_bytes = head_dim / 2;
    // Q rows covered by one Q block (all heads of `block_q` tokens)
    const int q_rows_per_block = block_q * num_heads;

    // SM90 TMA has no packed-FP4 data type (`CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B`), so the packed
    // tensors are reinterpreted as bytes and described as `CU_TENSOR_MAP_DATA_TYPE_UINT8`.
    // `kPackedFP4` elements are 1 byte wide (same as uint8), so strides and dims carry over unchanged.
    auto q_bytes = q.view(torch::kByte);
    auto kv_bytes = kv.view(torch::kByte);

    // Q as raw bytes: [seq_len * num_heads, head_dim / 2], no swizzle
    const auto q_desc = make_tma_2d_desc(q_bytes,
                                         packed_row_bytes, seq_len * num_heads,
                                         packed_row_bytes, q_rows_per_block,
                                         static_cast<int>(q.stride(1)), 0);
    // Q scale factors: int32 [seq_len, num_heads], no swizzle
    const auto sf_q_desc = make_tma_2d_desc(sf_q,
                                            num_heads, seq_len,
                                            num_heads, block_q,
                                            static_cast<int>(sf_q.stride(0)), 0);
    // Weights: float [seq_len, num_heads], no swizzle
    const auto weights_desc = make_tma_2d_desc(weights,
                                               num_heads, seq_len,
                                               num_heads, block_q,
                                               static_cast<int>(weights.stride(0)), 0);
    // KV as raw bytes: [seq_len_kv, head_dim / 2], no swizzle
    const auto kv_desc = make_tma_2d_desc(kv_bytes,
                                          packed_row_bytes, seq_len_kv,
                                          packed_row_bytes, block_kv,
                                          static_cast<int>(kv.stride(0)), 0);
    // KV scale factors: one int32 per token (4 packed UE8M0 bytes), described as a single row.
    // According to the driver API, the minimal alignment is 256 bytes
    const int sf_kv_elem_bytes = static_cast<int>(sf_kv.element_size());
    const auto sf_kv_desc = make_tma_2d_desc(sf_kv,
                                             get_tma_aligned_size(seq_len_kv, sf_kv_elem_bytes), 1,
                                             block_kv, 1, 0, 0);

    // Shared memory layout, part 1: raw (non-FP8) staging buffers
    const int sf_elem_bytes = static_cast<int>(sizeof(uint32_t));
    const int weight_elem_bytes = static_cast<int>(sizeof(float));
    // Per Q stage: packed FP4 Q + Q scale factors + weights
    const int raw_q_stage_bytes = q_rows_per_block * packed_row_bytes
                                + q_rows_per_block * sf_elem_bytes
                                + q_rows_per_block * weight_elem_bytes;
    // Per KV stage: packed FP4 KV + KV scale factors
    const int raw_kv_stage_bytes = block_kv * packed_row_bytes
                                 + block_kv * sf_elem_bytes;
    const int raw_region_bytes = raw_q_stage_bytes * q_stage_count
                               + raw_kv_stage_bytes * kv_stage_count;

    // Part 2: dequantized FP8 buffers, which start at a swizzle-aligned offset
    // (`head_dim * 8` = 1024 bytes for `head_dim == 128`)
    const int fp8_region_offset = align(raw_region_bytes, head_dim * 8);
    const int fp8_q_stage_bytes = q_rows_per_block * head_dim;
    const int fp8_kv_stage_bytes = block_kv * head_dim;

    // Part 3: barriers, full/empty pairs for every Q and KV stage at 8 bytes each
    // (dequantization uses named barriers, so it needs no storage here)
    const int barrier_bytes = 2 * (q_stage_count + kv_stage_count) * 8;

    const int smem_size = fp8_region_offset
                        + fp8_q_stage_bytes * q_stage_count
                        + fp8_kv_stage_bytes * kv_stage_count
                        + barrier_bytes;
    DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);

    // One block per SM; each block holds the TMA threads plus the math threads
    const int threads_per_block = tma_thread_count + math_thread_count;

    // Assemble the arguments (designators follow the declaration order of `Args`)
    const SM90FP4MQALogitsRuntime::Args args = {
        .seq_len = seq_len,
        .seq_len_kv = seq_len_kv,
        .max_seqlen_k = max_seqlen_k,
        .stride_logits = stride_logits,
        .num_heads = num_heads,
        .head_dim = head_dim,
        .is_compressed_logits = use_compressed_logits,
        .num_q_stages = q_stage_count,
        .num_kv_stages = kv_stage_count,
        .block_q = block_q,
        .block_kv = block_kv,
        .cu_seq_len_k_start = cu_seq_len_k_start.data_ptr<int>(),
        .cu_seq_len_k_end = cu_seq_len_k_end.data_ptr<int>(),
        .logits = logits.data_ptr(),
        .tensor_map_q = q_desc,
        .tensor_map_sf_q = sf_q_desc,
        .tensor_map_kv = kv_desc,
        .tensor_map_sf_kv = sf_kv_desc,
        .tensor_map_weights = weights_desc,
        .logits_dtype = logits_dtype,
        .num_tma_threads = tma_thread_count,
        .num_math_threads = math_thread_count,
        .launch_args = LaunchArgs(device_runtime->get_num_sms(),
                                  threads_per_block,
                                  smem_size)
    };

    // Generate, JIT-build and launch
    const auto code = SM90FP4MQALogitsRuntime::generate(args);
    const auto runtime = compiler->build("sm90_fp4_mqa_logits", code);
    SM90FP4MQALogitsRuntime::launch(runtime, args);
}

} // namespace deep_gemm
Loading