Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 153 additions & 0 deletions csrc/apis/mega.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
#include <functional>
#include <string>
#include <pybind11/functional.h>
#include <pybind11/stl.h>

#if DG_TENSORMAP_COMPATIBLE
#include "../jit/compiler.hpp"
#endif
#include "../jit/device_runtime.hpp"
#include "../jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp"
#include "../jit_kernels/impls/sm100_fp8_fp4_mega_moe_split.hpp"

namespace deep_gemm::mega {

Expand Down Expand Up @@ -132,6 +134,123 @@ get_symm_buffer_size_for_mega_moe(
return {reinterpret_cast<int64_t>(combine_token_buffer.get_end_ptr()), slice_input_buffers};
}

// Sizes the symmetric buffer used by the SPLIT-kernel MegaMoE pipeline.
//
// The region order matches the fused sizing routine, with one difference: the leading
// bookkeeping region is a route-based `layout::SplitWorkspace` (split K1 dispatch pull
// metadata). It lives in its own function so the fused MegaMoE buffer layout stays untouched.
//
// Every region is laid out starting from a null base pointer, so each `base` / end pointer is
// really a byte offset into the final allocation.
//
// Returns a pair of:
//   - the end offset of the last region (the combine token buffer), and
//   - a callable that, given the allocated buffer tensor, returns `from_blob` views
//     (x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf) at the
//     corresponding offsets.
//
// Asserts that `num_experts` is divisible by `num_ranks`, that `use_fp8_dispatch` is set, that
// both hidden sizes are multiples of 128, and that the padded SF pool row count is a multiple
// of 4. `activation` is accepted but not read by this function.
static std::tuple<int64_t, std::function<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(const torch::Tensor&)>>
get_symm_buffer_size_for_mega_moe_split(
    const int& num_ranks, const int& num_experts,
    const int& num_max_tokens_per_rank, const int& num_topk,
    const int& hidden, const int& intermediate_hidden,
    const bool& use_fp8_dispatch, const std::string& activation) {
    DG_HOST_ASSERT(num_experts % num_ranks == 0);
    DG_HOST_ASSERT(use_fp8_dispatch);

    // Route-based split workspace, placed at offset zero
    const auto split_workspace = layout::SplitWorkspace(
        nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk);

    // Per-row byte layouts
    const auto token_fp8 = layout::Data(hidden);
    const auto token_bf16 = layout::Data(hidden * 2);
    const auto inter_token_fp8 = layout::Data(intermediate_hidden);
    const auto token_sf = layout::Data(hidden / 32);
    const auto inter_token_sf = layout::Data(intermediate_hidden / 32);
    const auto topk_idx_row = layout::Data(num_topk * sizeof(int64_t), false);
    const auto topk_weights_row = layout::Data(num_topk * sizeof(float), false);
    const auto pool_weight_row = layout::Data(sizeof(float), false);

    // Caller-facing input regions, chained one after another behind the workspace
    const auto in_tokens = layout::Buffer(
        token_fp8, 1, num_max_tokens_per_rank, split_workspace.get_end_ptr());
    const auto in_sf = layout::Buffer(
        token_sf, 1, num_max_tokens_per_rank, in_tokens.get_end_ptr());
    const auto in_topk_idx = layout::Buffer(
        topk_idx_row, 1, num_max_tokens_per_rank, in_sf.get_end_ptr());
    const auto in_topk_weights = layout::Buffer(
        topk_weights_row, 1, num_max_tokens_per_rank, in_topk_idx.get_end_ptr());

    // Pool capacity, plus the largest padded SF row count over all candidate block-M values
    const auto pool_rows = static_cast<int>(split_workspace.num_max_pool_tokens);
    int padded_sf_rows = 0;
    for (int candidate_block_m: layout::kCandidateBlockM) {
        const int padded = layout::get_num_padded_sf_pool_tokens(pool_rows, candidate_block_m);
        if (padded_sf_rows < padded)
            padded_sf_rows = padded;
    }

    // L1 input regions
    const auto l1_tokens = layout::Buffer(
        token_fp8, 1, pool_rows, in_topk_weights.get_end_ptr());
    const auto l1_sf = layout::Buffer(
        token_sf, 1, padded_sf_rows, l1_tokens.get_end_ptr());
    const auto l1_weights = layout::Buffer(
        pool_weight_row, 1, pool_rows, l1_sf.get_end_ptr());

    // L2 input regions
    const auto l2_tokens = layout::Buffer(
        inter_token_fp8, 1, pool_rows, l1_weights.get_end_ptr());
    const auto l2_sf = layout::Buffer(
        inter_token_sf, 1, padded_sf_rows, l2_tokens.get_end_ptr());

    // Combine input region: BF16 tokens for cross-rank combine
    const auto combine_tokens = layout::Buffer(
        token_bf16, num_topk, num_max_tokens_per_rank, l2_sf.get_end_ptr());

    // SF regions are exposed below as int32 tensors with `hidden / 128` columns
    DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0);
    DG_HOST_ASSERT(padded_sf_rows % 4 == 0);

    // Everything is captured by value: the slicer outlives this call
    auto make_views = [=](const torch::Tensor& buffer) {
        const auto device = buffer.device();
        const auto fp8_options = torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(device);
        const auto int32_options = torch::TensorOptions().dtype(torch::kInt).device(device);
        const auto int64_options = torch::TensorOptions().dtype(torch::kInt64).device(device);
        const auto fp32_options = torch::TensorOptions().dtype(torch::kFloat32).device(device);

        // Translate a region's stored offset into an address inside `buffer`
        const auto address_of = [&buffer](const auto& region) {
            return math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(region.base));
        };

        auto x = torch::from_blob(
            address_of(in_tokens), {num_max_tokens_per_rank, hidden}, fp8_options);
        auto x_sf = torch::from_blob(
            address_of(in_sf), {num_max_tokens_per_rank, hidden / 128}, int32_options);
        auto topk_idx = torch::from_blob(
            address_of(in_topk_idx), {num_max_tokens_per_rank, num_topk}, int64_options);
        auto topk_weights = torch::from_blob(
            address_of(in_topk_weights), {num_max_tokens_per_rank, num_topk}, fp32_options);
        auto l1_acts = torch::from_blob(
            address_of(l1_tokens), {pool_rows, hidden}, fp8_options);
        // Pool SF views use strides {1, rows}: unit stride along the token dimension
        auto l1_acts_sf = torch::from_blob(
            address_of(l1_sf),
            {padded_sf_rows, hidden / 128},
            {1, padded_sf_rows},
            int32_options);
        auto l2_acts = torch::from_blob(
            address_of(l2_tokens), {pool_rows, intermediate_hidden}, fp8_options);
        auto l2_acts_sf = torch::from_blob(
            address_of(l2_sf),
            {padded_sf_rows, intermediate_hidden / 128},
            {1, padded_sf_rows},
            int32_options);
        return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf);
    };

    const auto end_offset = reinterpret_cast<int64_t>(combine_tokens.get_end_ptr());
    return {end_offset, make_views};
}

static void fp8_fp4_mega_moe(
const torch::Tensor& y,
const std::tuple<torch::Tensor, torch::Tensor>& l1_weights_tuple,
Expand Down Expand Up @@ -230,7 +349,41 @@ static void register_apis(pybind11::module_& m) {
#if DG_TENSORMAP_COMPATIBLE
m.def("get_token_alignment_for_mega_moe", &get_token_alignment_for_mega_moe);
m.def("get_symm_buffer_size_for_mega_moe", &get_symm_buffer_size_for_mega_moe);
m.def("get_symm_buffer_size_for_mega_moe_split", &get_symm_buffer_size_for_mega_moe_split);
m.def("fp8_fp4_mega_moe", &fp8_fp4_mega_moe);
// Split-kernel MegaMoE: dispatch_l1_swiglu (K1) + l2_combine (K2) run concurrently on
// green-context SM partitions, combine_reduce (K3) reduces after both. The graph class
// exists on any CUDA (a stub throws below 13.1); see sm100_fp8_fp4_mega_moe_split.hpp.
pybind11::class_<SM100FP8FP4MegaMoESplitGraph>(
m, "SM100FP8FP4MegaMoESplitGraph")
.def(pybind11::init<
std::vector<torch::Tensor>,
std::vector<torch::Tensor>,
std::vector<torch::Tensor>,
std::vector<std::vector<int64_t>>,
std::vector<torch::Tensor>,
std::vector<torch::Tensor>,
std::vector<torch::Tensor>,
std::vector<torch::Tensor>,
std::vector<torch::Tensor>,
const int&,
const int&,
const int&,
const int&,
const int&,
const int&,
const int&,
const float&,
const bool&,
const int&,
const int&,
const int&,
const int&,
const int&>())
.def("replay", &SM100FP8FP4MegaMoESplitGraph::replay,
pybind11::call_guard<pybind11::gil_scoped_release>())
.def("get_green_context_ids",
&SM100FP8FP4MegaMoESplitGraph::get_green_context_ids);
#endif
}

Expand Down
57 changes: 57 additions & 0 deletions csrc/jit/handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include <dlfcn.h>
#include <cstdlib>
#include <filesystem>
#include <string>

#include "../utils/exception.hpp"
#include "../utils/compatibility.hpp"
Expand All @@ -20,6 +22,22 @@ static void* get_driver_handle() {
return handle;
}

// Returns a process-wide `dlopen` handle to the CUDA runtime library, opening it on first use.
//
// Lookup order:
//   1. `$CUDA_HOME/lib64/libcudart.so.13` (only when `CUDA_HOME` is set and non-empty)
//   2. `libcudart.so.13` via the default loader search
//   3. `libcudart.so` via the default loader search
//
// The handle is cached in a function-local static; a failed lookup trips the host assertion
// and leaves the cache empty, so a later call retries.
static void* get_runtime_handle() {
    static void* handle = nullptr;
    if (handle != nullptr)
        return handle;

    const int open_flags = RTLD_LAZY | RTLD_LOCAL;

    // Prefer the toolkit the user pointed us at
    const char* cuda_home = std::getenv("CUDA_HOME");
    if (cuda_home != nullptr and cuda_home[0] != '\0') {
        const std::string toolkit_runtime = std::string(cuda_home) + "/lib64/libcudart.so.13";
        handle = dlopen(toolkit_runtime.c_str(), open_flags);
    }

    // Otherwise fall back to whatever the dynamic loader can find, most specific name first
    const char* const fallback_names[] = {"libcudart.so.13", "libcudart.so"};
    for (const char* soname: fallback_names) {
        if (handle != nullptr)
            break;
        handle = dlopen(soname, open_flags);
    }

    DG_HOST_ASSERT(handle != nullptr and "Failed to load CUDA runtime `libcudart.so`");
    return handle;
}

// Macro to define wrapper functions named `lazy_cu{API name}`
#define DECL_LAZY_CUDA_DRIVER_FUNCTION(name) \
template <typename... Args> \
Expand All @@ -33,6 +51,18 @@ static auto lazy_##name(Args&&... args) -> decltype(name(args...)) { \
return func(std::forward<decltype(args)>(args)...); \
}

// Macro to define wrapper functions named `lazy_{API name}` for CUDA runtime APIs.
//
// `DECL_LAZY_CUDA_RUNTIME_FUNCTION(foo)` expands to a variadic template `lazy_foo(args...)` that:
//   - derives the function-pointer type from the declaration of `foo` (`decltype(&(foo))`), so
//     `foo` must be declared and must name a single, non-overloaded function;
//   - on first call, resolves the symbol "foo" with `dlsym` from `get_runtime_handle()` and
//     caches it in a function-local static (one cache per template instantiation);
//   - asserts if the symbol cannot be resolved;
//   - perfectly forwards the arguments and returns whatever `foo` returns.
// NOTE: no comments may be placed inside the macro body because of the line continuations.
#define DECL_LAZY_CUDA_RUNTIME_FUNCTION(name) \
template <typename... Args> \
static auto lazy_##name(Args&&... args) -> decltype(name(args...)) { \
using FuncType = decltype(&(name)); \
static FuncType func = nullptr; \
if (func == nullptr) { \
func = reinterpret_cast<FuncType>(dlsym(get_runtime_handle(), #name)); \
DG_HOST_ASSERT(func != nullptr and "Failed to load CUDA runtime API"); \
} \
return func(std::forward<decltype(args)>(args)...); \
}

DECL_LAZY_CUDA_DRIVER_FUNCTION(cuGetErrorName);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuGetErrorString);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuFuncSetAttribute);
Expand All @@ -45,6 +75,33 @@ DECL_LAZY_CUDA_DRIVER_FUNCTION(cuKernelGetFunction);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLaunchKernelEx);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuTensorMapEncodeTiled);

#if CUDART_VERSION >= 13010
// Lazily-resolved CUDA runtime entry points; each line defines a `lazy_<name>` wrapper that
// looks the symbol up in the runtime library on first use. Grouped as: execution-context and
// device-resource queries/splitting, green-context creation/teardown, then graph node
// construction, launch and destruction.
DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaDeviceGetExecutionCtx);
DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaDeviceGetDevResource);
DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaDevSmResourceSplit);
DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaDevResourceGenerateDesc);
DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaGreenCtxCreate);
DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaExecutionCtxGetId);
DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaExecutionCtxDestroy);
DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaGraphAddNode);
DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaGraphKernelNodeSetAttribute);
DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaGraphLaunch);
DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaGraphExecDestroy);
DECL_LAZY_CUDA_RUNTIME_FUNCTION(cudaGraphDestroy);

// Hand-written lazy wrapper for the three-argument `cudaGraphInstantiate(exec*, graph, flags)`.
//
// The function-pointer type is spelled out explicitly instead of being derived from the
// declaration (presumably because the name is overloaded in the runtime headers, which the
// generic lazy-wrapper macro cannot handle -- TODO confirm). The symbol is resolved from the
// runtime library on first call and cached; resolution failure trips the host assertion.
static cudaError_t lazy_cudaGraphInstantiate(cudaGraphExec_t *phGraphExec,
                                             cudaGraph_t graph,
                                             unsigned long long flags) {
    using InstantiateFn = cudaError_t (CUDARTAPI *)(cudaGraphExec_t*, cudaGraph_t, unsigned long long);
    static InstantiateFn resolved = nullptr;

    // Fast path: symbol already looked up by an earlier call
    if (resolved != nullptr)
        return resolved(phGraphExec, graph, flags);

    resolved = reinterpret_cast<InstantiateFn>(dlsym(get_runtime_handle(), "cudaGraphInstantiate"));
    DG_HOST_ASSERT(resolved != nullptr and "Failed to load CUDA runtime API");
    return resolved(phGraphExec, graph, flags);
}
#endif

#if CUDART_VERSION >= 12080 and defined(DG_JIT_USE_RUNTIME_API)

// Use CUDA runtime API
Expand Down
Loading