diff --git a/.gitignore b/.gitignore index d0cdf6ca42..5b5e5bf583 100644 --- a/.gitignore +++ b/.gitignore @@ -21,4 +21,11 @@ deep_gemm/include/cutlass stubs/ # Symlinks to compiled extensions -deep_gemm/*.so \ No newline at end of file +deep_gemm/*.so +deep_gemm/*.pyd + +# Claude Code workspace +.claude/ + +# Windows null device artifact +nul diff --git a/CMakeLists.txt b/CMakeLists.txt index 79f1964dad..556599e32e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,8 +3,16 @@ cmake_minimum_required(VERSION 3.10) project(deep_gemm LANGUAGES CXX CUDA) set(CMAKE_VERBOSE_MAKEFILE ON) -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC -Wno-psabi") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi") +# Platform-specific compiler flags +if(MSVC) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /O2 /EHsc") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /O2 /EHsc /std:c++17 /Zc:__cplusplus /permissive-") + add_definitions(-D_CRT_SECURE_NO_WARNINGS) +else() + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC -Wno-psabi") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi") +endif() + set(CUDA_SEPARABLE_COMPILATION ON) list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG") list(APPEND CUDA_NVCC_FLAGS "-O3") @@ -22,8 +30,15 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CUDA_STANDARD 17) include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include third-party/fmt/include) -include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS}) -link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs) + +# Platform-specific include and library directories +if(WIN32) + include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS}) + link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64) +else() + include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS}) + link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs) +endif() # The main Python API entrance pybind11_add_module(_C csrc/python_api.cpp) diff --git a/README.md b/README.md index 04a289dbb5..7b238b6c92 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ python tests/test_attention.py python tests/test_core.py ``` -### Installation +### Installation (Linux) ```bash cat install.sh @@ -81,6 +81,63 @@ cat install.sh Then, import `deep_gemm` in your Python project, and enjoy! +### Installation (Windows) + +DeepGEMM supports building on Windows with MSVC. Follow these steps: + +#### Prerequisites + +- **Visual Studio 2022** (or 2019) with C++ desktop development workload +- **CUDA Toolkit 12.3+** (12.9+ recommended for best performance) +- **Python 3.8+** with pip +- **PyTorch 2.1+** with CUDA support + +#### Build Steps + +```cmd +:: Clone the repository +git clone --recursive https://github.com/deepseek-ai/DeepGEMM.git +cd DeepGEMM + +:: Create and activate virtual environment (using uv or venv) +uv venv --python 3.10 +.venv\Scripts\activate + +:: Install dependencies +uv pip install packaging wheel setuptools ninja numpy==1.26.4 pip build psutil +uv pip install torch==2.10.0 --index-url https://download.pytorch.org/whl/cu128 + +:: Clear environment variables and initialize VS developer environment +set INCLUDE= +set LIB= +set LIBPATH= +call "C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvarsall.bat" x64 + +:: Configure CUDA +set CUDA_HOME=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8 +set PATH=%CUDA_HOME%\bin;%PATH% +set DISTUTILS_USE_SDK=1 + +:: Set target architectures (adjust based on your GPU) +set TORCH_CUDA_ARCH_LIST=8.0;8.6;8.9;9.0 + +:: Initialize submodules +git submodule sync +git submodule update --init --recursive + +:: Build the wheel +set DG_USE_LOCAL_VERSION=0 +python setup.py bdist_wheel +``` + +The built wheel will be located in the `dist/` directory. + +#### Troubleshooting + +- **Linker errors (LNK1104)**: Ensure `vcvarsall.bat` was executed in the same terminal session +- **CUDA not found**: Verify `CUDA_HOME` points to your CUDA installation directory +- **Missing cuBLASLt symbols**: The build system should automatically link cuBLASLt; ensure CUDA Toolkit is properly installed + ## Interfaces #### Notices diff --git a/csrc/apis/gemm.hpp b/csrc/apis/gemm.hpp index 6770cf9285..774e88a944 100644 --- a/csrc/apis/gemm.hpp +++ b/csrc/apis/gemm.hpp @@ -69,9 +69,10 @@ static void fp8_fp4_gemm_nt(const std::pair& a, // Type and shape checks const auto arch_major = device_runtime->get_arch_major(); - const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major); - const auto [n , k_] = check_ab_fp8_fp4(b.first, major_b, arch_major); - const auto [m_, n_] = get_shape<2>(d); + int m, k, n, k_, m_, n_; + std::tie(m, k) = check_ab_fp8_fp4(a.first, major_a, arch_major); + std::tie(n, k_) = check_ab_fp8_fp4(b.first, major_b, arch_major); + std::tie(m_, n_) = get_shape<2>(d); DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); @@ -80,7 +81,9 @@ static void fp8_fp4_gemm_nt(const std::pair& a, return; // Transform SFA and SFB into compute-required layout - const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout( + torch::Tensor sfa, sfb; + int gran_k_a, gran_k_b; + std::tie(sfa, sfb, gran_k_a, gran_k_b) = layout::transform_sf_pair_into_required_layout( a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, std::nullopt, std::nullopt, disable_ue8m0_cast); // Dispatch into different implements @@ -161,9 +164,10 @@ static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pairget_arch_major(); - const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major); - const auto [num_groups, n, k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major); - const auto [m_, n_] = get_shape<2>(d); + int m, k, num_groups, n, k_, m_, n_; + std::tie(m, k) = check_ab_fp8_fp4(a.first, major_a, arch_major); + std::tie(num_groups, n, k_) = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major); + std::tie(m_, n_) = get_shape<2>(d); DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0); DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); @@ -171,10 +175,12 @@ static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pair(grouped_layout); + int num_groups_; + std::tie(num_groups_) = get_shape<1>(grouped_layout); DG_HOST_ASSERT(num_groups == num_groups_); } else { - const auto& [m__] = get_shape<1>(grouped_layout); + int m__; + std::tie(m__) = get_shape<1>(grouped_layout); DG_HOST_ASSERT(m == m__); DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value()); } @@ -187,7 +193,9 @@ static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pairget_arch_major(); - const auto [num_groups , m , k ] = check_grouped_ab_fp8_fp4(a.first, major_a, arch_major); - const auto [num_groups_ , n , k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major); - const auto [num_groups__, m_, n_] = get_shape<3>(d); + int num_groups, m, k, num_groups_, n, k_, num_groups__, m_, n_; + std::tie(num_groups, m, k) = check_grouped_ab_fp8_fp4(a.first, major_a, arch_major); + std::tie(num_groups_, n, k_) = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major); + std::tie(num_groups__, m_, n_) = get_shape<3>(d); const auto num_groups___ = static_cast(masked_m.numel()); DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___); DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); @@ -251,7 +260,9 @@ static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair(d); - const auto& [sum_k_ , m_] = get_shape<2>(a.first); - const auto& [sum_k__, n_] = get_shape<2>(b.first); + int num_groups, m, n, sum_k_, m_, sum_k__, n_; + std::tie(num_groups, m, n) = get_shape<3>(d); + std::tie(sum_k_, m_) = get_shape<2>(a.first); + std::tie(sum_k__, n_) = get_shape<2>(b.first); const int sum_k = std::accumulate(ks.begin(), ks.end(), 0); DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__); @@ -322,7 +334,8 @@ static void k_grouped_fp8_gemm_nt_contiguous(const std::pair(d); + int num_groups, m, n; + std::tie(num_groups, m, n) = get_shape<3>(d); const auto& sum_mk = a.first.numel(); const auto& sum_nk = b.first.numel(); const int sum_k = std::accumulate(ks.begin(), ks.end(), 0); @@ -374,9 +387,10 @@ static void bf16_gemm_nt(const torch::Tensor& a, check_major_type_cd(d); // Type and shape checks - const auto& [m , k ] = get_shape<2>(a); - const auto& [n , k_] = get_shape<2>(b); - const auto& [m_, n_] = get_shape<2>(d); + int m, k, n, k_, m_, n_; + std::tie(m, k) = get_shape<2>(a); + std::tie(n, k_) = get_shape<2>(b); + std::tie(m_, n_) = get_shape<2>(d); DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16); @@ -433,9 +447,10 @@ static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torc DG_HOST_ASSERT(grouped_layout.is_contiguous()); // Type and shape checks - const auto& [m, k] = get_shape<2>(a); - const auto& [num_groups, n, k_] = get_shape<3>(b); - const auto& [m_, n_] = get_shape<2>(d); + int m, k, num_groups, n, k_, m_, n_; + std::tie(m, k) = get_shape<2>(a); + std::tie(num_groups, n, k_) = get_shape<3>(b); + std::tie(m_, n_) = get_shape<2>(d); DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0); DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); @@ -445,10 +460,12 @@ static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torc // Layout checks if (use_psum_layout) { - const auto& [num_groups_] = get_shape<1>(grouped_layout); + int num_groups_; + std::tie(num_groups_) = get_shape<1>(grouped_layout); DG_HOST_ASSERT(num_groups == num_groups_); } else { - const auto& [m__] = get_shape<1>(grouped_layout); + int m__; + std::tie(m__) = get_shape<1>(grouped_layout); DG_HOST_ASSERT(m == m__); DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value()); } @@ -493,9 +510,10 @@ static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::T DG_HOST_ASSERT(masked_m.is_contiguous()); // Type and shape checks - const auto& [num_groups, m, k] = get_shape<3>(a); - const auto& [num_groups_, n, k_] = get_shape<3>(b); - const auto& [num_groups__, m_, n_] = get_shape<3>(d); + int num_groups, m, k, num_groups_, n, k_, num_groups__, m_, n_; + std::tie(num_groups, m, k) = get_shape<3>(a); + std::tie(num_groups_, n, k_) = get_shape<3>(b); + std::tie(num_groups__, m_, n_) = get_shape<3>(d); const auto& num_groups___ = static_cast(masked_m.numel()); DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___); DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); @@ -529,9 +547,10 @@ static void k_grouped_bf16_gemm_tn_contiguous(const torch::Tensor& a, const std::optional& c, const std::string& compiled_dims) { // Shape checks - const auto& [num_groups, m, n] = get_shape<3>(d); - const auto& [sum_k_ , m_] = get_shape<2>(a); - const auto& [sum_k__, n_] = get_shape<2>(b); + int num_groups, m, n, sum_k_, m_, sum_k__, n_; + std::tie(num_groups, m, n) = get_shape<3>(d); + std::tie(sum_k_, m_) = get_shape<2>(a); + std::tie(sum_k__, n_) = get_shape<2>(b); const int sum_k = std::accumulate(ks.begin(), ks.end(), 0); DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__); @@ -566,9 +585,10 @@ static void cublaslt_gemm_nt(const torch::Tensor& a, const torch::Tensor& b, const auto& major_b = get_major_type_ab(b); // Type and shape checks - const auto& [m , k ] = get_shape<2>(a); - const auto& [n , k_] = get_shape<2>(b); - const auto& [m_, n_] = get_shape<2>(d); + int m, k, n, k_, m_, n_; + std::tie(m, k) = get_shape<2>(a); + std::tie(n, k_) = get_shape<2>(b); + std::tie(m_, n_) = get_shape<2>(d); DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); // Early return for trivial cases diff --git a/csrc/jit/cache.hpp b/csrc/jit/cache.hpp index 1e8659fd32..0c6c8bd482 100644 --- a/csrc/jit/cache.hpp +++ b/csrc/jit/cache.hpp @@ -17,11 +17,12 @@ class KernelRuntimeCache { std::shared_ptr get(const std::filesystem::path& dir_path) { // Hit the runtime cache - if (const auto& iterator = cache.find(dir_path); iterator != cache.end()) + const auto& dir_path_str = dir_path.string(); + if (const auto& iterator = cache.find(dir_path_str); iterator != cache.end()) return iterator->second; if (KernelRuntime::check_validity(dir_path)) - return cache[dir_path] = std::make_shared(dir_path); + return cache[dir_path_str] = std::make_shared(dir_path); return nullptr; } }; diff --git a/csrc/jit/compiler.hpp b/csrc/jit/compiler.hpp index 3dc0cfbfff..21753f9043 100644 --- a/csrc/jit/compiler.hpp +++ b/csrc/jit/compiler.hpp @@ -46,7 +46,11 @@ class Compiler { Compiler::library_include_path = Compiler::library_root_path / "include"; Compiler::cuda_home = cuda_home_path_by_python; Compiler::library_version = get_library_version(); +#if DG_MSVC + Compiler::cuobjdump_path = Compiler::cuda_home / "bin" / "cuobjdump.exe"; +#else Compiler::cuobjdump_path = Compiler::cuda_home / "bin" / "cuobjdump"; +#endif } std::string signature, flags; @@ -61,7 +65,15 @@ class Compiler { DG_HOST_ASSERT(not cuobjdump_path.empty()); // Cache settings +#if DG_MSVC + // On Windows, use LOCALAPPDATA or USERPROFILE for cache directory + auto home_path = get_env("LOCALAPPDATA"); + if (home_path.empty()) + home_path = get_env("USERPROFILE"); + cache_dir_path = std::filesystem::path(home_path) / ".deep_gemm"; +#else cache_dir_path = std::filesystem::path(get_env("HOME")) / ".deep_gemm"; +#endif if (const auto& env_cache_dir_path = get_env("DG_JIT_CACHE_DIR"); not env_cache_dir_path.empty()) cache_dir_path = env_cache_dir_path; @@ -144,10 +156,12 @@ class Compiler { static void disassemble(const std::filesystem::path &cubin_path, const std::filesystem::path &sass_path) { // Disassemble the CUBIN file to SASS - const auto command = fmt::format("{} --dump-sass {} > {}", cuobjdump_path.c_str(), cubin_path.c_str(), sass_path.c_str()); + const auto command = fmt::format("{} --dump-sass {} > {}", cuobjdump_path.string(), cubin_path.string(), sass_path.string()); if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) printf("Running cuobjdump command: %s\n", command.c_str()); - const auto [return_code, output] = call_external_command(command); + int return_code; + std::string output; + std::tie(return_code, output) = call_external_command(command); if (return_code != 0) { printf("cuobjdump failed: %s\n", output.c_str()); DG_HOST_ASSERT(false and "cuobjdump failed"); @@ -170,8 +184,10 @@ class NVCCCompiler final: public Compiler { DG_HOST_ASSERT(std::filesystem::exists(nvcc_path)); // Call the version command - const auto& command = std::string(nvcc_path) + " --version"; - const auto& [return_code, output] = call_external_command(command); + const auto& command = nvcc_path.string() + " --version"; + int return_code; + std::string output; + std::tie(return_code, output) = call_external_command(command); DG_HOST_ASSERT(return_code == 0); // The version should be at least 12.3, for the best performance with 12.9 @@ -188,19 +204,31 @@ class NVCCCompiler final: public Compiler { public: NVCCCompiler() { // Override the compiler signature +#if DG_MSVC + nvcc_path = cuda_home / "bin" / "nvcc.exe"; +#else nvcc_path = cuda_home / "bin" / "nvcc"; +#endif if (const auto& env_nvcc_path = get_env("DG_JIT_NVCC_COMPILER"); not env_nvcc_path.empty()) nvcc_path = env_nvcc_path; - const auto& [nvcc_major, nvcc_minor] = get_nvcc_version(); + int nvcc_major, nvcc_minor; + std::tie(nvcc_major, nvcc_minor) = get_nvcc_version(); signature = fmt::format("NVCC{}.{}", nvcc_major, nvcc_minor); // The override the compiler flags // Only NVCC >= 12.9 supports arch-specific family suffix const auto& arch = device_runtime->get_arch(false, nvcc_major > 12 or nvcc_minor >= 9); +#if DG_MSVC + flags = fmt::format("{} -I{} --gpu-architecture=sm_{} " + "--compiler-options=/O2,/EHsc,/Zc:__cplusplus " + "-O3 --expt-relaxed-constexpr --expt-extended-lambda", + flags, library_include_path.string(), arch); +#else flags = fmt::format("{} -I{} --gpu-architecture=sm_{} " "--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi " "-O3 --expt-relaxed-constexpr --expt-extended-lambda", - flags, library_include_path.c_str(), arch); + flags, library_include_path.string(), arch); +#endif } void compile(const std::string &code, const std::filesystem::path& dir_path, @@ -211,10 +239,12 @@ class NVCCCompiler final: public Compiler { put(code_path, code); // Compile - const auto& command = fmt::format("{} {} -cubin -o {} {}", nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags); + const auto& command = fmt::format("{} {} -cubin -o {} {}", nvcc_path.string(), code_path.string(), cubin_path.string(), flags); if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) printf("Running NVCC command: %s\n", command.c_str()); - const auto& [return_code, output] = call_external_command(command); + int return_code; + std::string output; + std::tie(return_code, output) = call_external_command(command); if (return_code != 0) { printf("NVCC compilation failed: %s\n", output.c_str()); DG_HOST_ASSERT(false and "NVCC compilation failed"); @@ -222,10 +252,12 @@ class NVCCCompiler final: public Compiler { // Compile to PTX if needed if (ptx_path.has_value()) { - const auto ptx_command = fmt::format("{} {} -ptx -o {} {}", nvcc_path.c_str(), code_path.c_str(), ptx_path->c_str(), flags); + const auto ptx_command = fmt::format("{} {} -ptx -o {} {}", nvcc_path.string(), code_path.string(), ptx_path->string(), flags); if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) printf("Running NVCC PTX command: %s\n", ptx_command.c_str()); - const auto [ptx_return_code, ptx_output] = call_external_command(ptx_command); + int ptx_return_code; + std::string ptx_output; + std::tie(ptx_return_code, ptx_output) = call_external_command(ptx_command); if (ptx_return_code != 0) { printf("NVCC PTX compilation failed: %s\n", ptx_output.c_str()); DG_HOST_ASSERT(false and "NVCC PTX compilation failed"); diff --git a/csrc/jit/handle.hpp b/csrc/jit/handle.hpp index 34447f91fd..40fd9b60b1 100644 --- a/csrc/jit/handle.hpp +++ b/csrc/jit/handle.hpp @@ -2,15 +2,44 @@ #include #include -#include #include #include "../utils/exception.hpp" #include "../utils/compatibility.hpp" +#include "../utils/msvc_compat.hpp" + +#if DG_MSVC + #include +#else + #include +#endif namespace deep_gemm { // Lazy loading all driver symbols +#if DG_MSVC +static HMODULE get_driver_handle() { + static HMODULE handle = nullptr; + if (handle == nullptr) { + handle = LoadLibraryA("nvcuda.dll"); + DG_HOST_ASSERT(handle != nullptr and "Failed to load CUDA driver `nvcuda.dll`"); + } + return handle; +} + +// Macro to define wrapper functions named `lazy_cu{API name}` (Windows version) +#define DECL_LAZY_CUDA_DRIVER_FUNCTION(name) \ +template \ +static auto lazy_##name(Args&&... args) -> decltype(name(args...)) { \ + using FuncType = decltype(&name); \ + static FuncType func = nullptr; \ + if (func == nullptr) { \ + func = reinterpret_cast(GetProcAddress(get_driver_handle(), #name)); \ + DG_HOST_ASSERT(func != nullptr and "Failed to load CUDA driver API"); \ + } \ + return func(std::forward(args)...); \ +} +#else static void* get_driver_handle() { static void* handle = nullptr; if (handle == nullptr) { @@ -20,7 +49,7 @@ static void* get_driver_handle() { return handle; } -// Macro to define wrapper functions named `lazy_cu{API name}` +// Macro to define wrapper functions named `lazy_cu{API name}` (POSIX version) #define DECL_LAZY_CUDA_DRIVER_FUNCTION(name) \ template \ static auto lazy_##name(Args&&... args) -> decltype(name(args...)) { \ @@ -32,6 +61,7 @@ static auto lazy_##name(Args&&... args) -> decltype(name(args...)) { \ } \ return func(std::forward(args)...); \ } +#endif DECL_LAZY_CUDA_DRIVER_FUNCTION(cuGetErrorName); DECL_LAZY_CUDA_DRIVER_FUNCTION(cuGetErrorString); @@ -56,7 +86,7 @@ static KernelHandle load_kernel(const std::filesystem::path& cubin_path, const s LibraryHandle *library_opt = nullptr) { LibraryHandle library; KernelHandle kernel{}; - DG_CUDA_RUNTIME_CHECK(cudaLibraryLoadFromFile(&library, cubin_path.c_str(), nullptr, nullptr, 0, nullptr, nullptr, 0)); + DG_CUDA_RUNTIME_CHECK(cudaLibraryLoadFromFile(&library, cubin_path.string().c_str(), nullptr, nullptr, 0, nullptr, nullptr, 0)); DG_CUDA_RUNTIME_CHECK(cudaLibraryGetKernel(&kernel, library, func_name.c_str())); if (library_opt != nullptr) @@ -114,7 +144,7 @@ static KernelHandle load_kernel(const std::filesystem::path& cubin_path, const s LibraryHandle *library_opt = nullptr) { LibraryHandle library; KernelHandle kernel; - DG_CUDA_DRIVER_CHECK(lazy_cuModuleLoad(&library, cubin_path.c_str())); + DG_CUDA_DRIVER_CHECK(lazy_cuModuleLoad(&library, cubin_path.string().c_str())); DG_CUDA_DRIVER_CHECK(lazy_cuModuleGetFunction(&kernel, library, func_name.c_str())); if (library_opt != nullptr) diff --git a/csrc/jit/kernel_runtime.hpp b/csrc/jit/kernel_runtime.hpp index ba66eeb88a..adf31cd379 100644 --- a/csrc/jit/kernel_runtime.hpp +++ b/csrc/jit/kernel_runtime.hpp @@ -9,18 +9,31 @@ namespace deep_gemm { struct LaunchArgs { - std::pair grid_dim; - int num_threads; - int smem_size; - int cluster_dim; - - LaunchArgs(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1): - grid_dim({grid_dim_x, 1}), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {} - - LaunchArgs(const std::pair& grid_dim, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1): - grid_dim(grid_dim), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {} + std::pair grid_dim = {0, 0}; + int num_threads = 0; + int smem_size = 0; + int cluster_dim = 1; }; +// Helper functions to create LaunchArgs (for backward compatibility) +inline LaunchArgs make_launch_args(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1) { + LaunchArgs args; + args.grid_dim = {grid_dim_x, 1}; + args.num_threads = num_threads; + args.smem_size = smem_size; + args.cluster_dim = cluster_dim; + return args; +} + +inline LaunchArgs make_launch_args(const std::pair& grid_dim, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1) { + LaunchArgs args; + args.grid_dim = grid_dim; + args.num_threads = num_threads; + args.smem_size = smem_size; + args.cluster_dim = cluster_dim; + return args; +} + class KernelRuntime final { public: static std::filesystem::path cuda_home; @@ -33,15 +46,19 @@ class KernelRuntime final { DG_HOST_ASSERT(not cuda_home.empty()); // NOLINT(*-pro-type-member-init) +#if DG_MSVC + const auto& cuobjdump_path = cuda_home / "bin" / "cuobjdump.exe"; +#else const auto& cuobjdump_path = cuda_home / "bin" / "cuobjdump"; +#endif const auto& cubin_path = dir_path / "kernel.cubin"; if (get_env("DG_JIT_DEBUG")) - printf("Loading CUBIN: %s\n", cubin_path.c_str()); + printf("Loading CUBIN: %s\n", cubin_path.string().c_str()); // Find the only symbol // TODO: use kernel enumeration for newer drivers const std::vector illegal_names = {"vprintf", "__instantiate_kernel", "__internal", "__assertfail"}; - const auto& [exit_code, symbols] = call_external_command(fmt::format("{} -symbols {}", cuobjdump_path.c_str(), cubin_path.c_str())); + const auto& [exit_code, symbols] = call_external_command(fmt::format("{} -symbols {}", cuobjdump_path.string(), cubin_path.string())); DG_HOST_ASSERT(exit_code == 0); std::istringstream iss(symbols); std::vector symbol_names; diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp index a49584f421..cb8195892e 100644 --- a/csrc/jit_kernels/heuristics/common.hpp +++ b/csrc/jit_kernels/heuristics/common.hpp @@ -9,15 +9,19 @@ namespace deep_gemm { struct MulticastConfig { - int num_multicast; - bool is_multicast_on_a; - - MulticastConfig(const int& num_multicast, const bool& is_multicast_on_a): - num_multicast(num_multicast), is_multicast_on_a(is_multicast_on_a) { - DG_HOST_ASSERT(1 <= num_multicast and num_multicast <= 2); - } + int num_multicast = 1; + bool is_multicast_on_a = false; }; +// Helper function to create MulticastConfig with validation +inline MulticastConfig make_multicast_config(const int& num_multicast, const bool& is_multicast_on_a) { + DG_HOST_ASSERT(1 <= num_multicast and num_multicast <= 2); + MulticastConfig config; + config.num_multicast = num_multicast; + config.is_multicast_on_a = is_multicast_on_a; + return config; +} + struct SharedMemoryConfig { int smem_size; int swizzle_a_mode; @@ -119,7 +123,8 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne const int& smem_b_per_stage = load_block_n * block_k * ab_elem_size; // SF shared memory - const auto& [smem_sfa_per_stage, smem_sfb_per_stage] = + int smem_sfa_per_stage, smem_sfb_per_stage; + std::tie(smem_sfa_per_stage, smem_sfb_per_stage) = ArchSpec::get_sf_smem_size_per_stage(kernel_type, block_m, block_n, block_k, mma_kind, cd_dtype); const int& smem_extra_sfb = ArchSpec::get_extra_sfb_smem_size(m, n, k, block_m, block_n, block_k); @@ -233,7 +238,8 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k // Decide the number of TMA multicasts and whether broadcast on A MulticastConfig best_multicast_config = {1, false}; - auto [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality( + bool is_legal_on_a, is_legal_on_b; + std::tie(is_legal_on_a, is_legal_on_b) = ArchSpec::get_multicast_legality( gemm_type, num_groups, m, n, best_block_m, best_block_n, num_sms); // NOTES: TMA copy .b4x16_p64 only supports Swizzle 128B diff --git a/csrc/jit_kernels/heuristics/sm100.hpp b/csrc/jit_kernels/heuristics/sm100.hpp index dd1e6024ff..1467b948db 100644 --- a/csrc/jit_kernels/heuristics/sm100.hpp +++ b/csrc/jit_kernels/heuristics/sm100.hpp @@ -82,8 +82,7 @@ struct SM100ArchSpec { // Check tensor memory validity int sf_block_m = 0, sf_block_n = 0; if (kernel_type == KernelType::Kernel1D1D) { - const auto& [sf_block_m_, sf_block_n_] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, mma_kind); - sf_block_m = sf_block_m_, sf_block_n = sf_block_n_; + std::tie(sf_block_m, sf_block_n) = get_sf_uttcp_aligned_block_sizes(block_m, block_n, mma_kind); } if (((2 * block_n) + (sf_block_m / 32) + (sf_block_n / 32)) > 512) return false; @@ -132,7 +131,8 @@ struct SM100ArchSpec { int smem_sfa_per_stage = 0; int smem_sfb_per_stage = 0; if (kernel_type == KernelType::Kernel1D1D) { - const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, mma_kind); + int sf_block_m, sf_block_n; + std::tie(sf_block_m, sf_block_n) = get_sf_uttcp_aligned_block_sizes(block_m, block_n, mma_kind); smem_sfa_per_stage = sf_block_m * 4; smem_sfb_per_stage = sf_block_n * 4; } else { diff --git a/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp b/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp index 95f7272987..24b1a0c8d0 100644 --- a/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp +++ b/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp @@ -17,7 +17,7 @@ class SM100BF16GemmRuntime final: public LaunchRuntime { public: struct Args { int m, n, k, num_groups; - const std::string& compiled_dims; + std::string compiled_dims; GemmConfig gemm_config; LaunchArgs launch_args; @@ -103,18 +103,18 @@ static void sm100_bf16_gemm(const torch::Tensor& a, config.smem_config.swizzle_cd_mode); // Launch - const SM100BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = k, - .num_groups = 1, - .compiled_dims = compiled_dims, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + SM100BF16GemmRuntime::Args args = { + m, n, k, + 1, + compiled_dims, + config, + make_launch_args(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .grouped_layout = nullptr, - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_cd = tensor_map_cd + nullptr, + tensor_map_a, + tensor_map_b, + tensor_map_cd }; const auto& code = SM100BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm100_bf16_gemm", code); @@ -162,18 +162,18 @@ static void sm100_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, config.smem_config.swizzle_cd_mode); // Launch - const SM100BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = k, - .num_groups = num_groups, - .compiled_dims = compiled_dims, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + SM100BF16GemmRuntime::Args args = { + m, n, k, + num_groups, + compiled_dims, + config, + make_launch_args(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .grouped_layout = grouped_layout.data_ptr(), - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_cd = tensor_map_cd + grouped_layout.data_ptr(), + tensor_map_a, + tensor_map_b, + tensor_map_cd }; const auto& code = SM100BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm100_bf16_m_grouped_gemm_contiguous", code); @@ -212,18 +212,18 @@ static void sm100_m_grouped_bf16_gemm_masked(const torch::Tensor& a, config.smem_config.swizzle_cd_mode); // Launch - const SM100BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = k, - .num_groups = num_groups, - .compiled_dims = compiled_dims, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + SM100BF16GemmRuntime::Args args = { + m, n, k, + num_groups, + compiled_dims, + config, + make_launch_args(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .grouped_layout = masked_m.data_ptr(), - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_cd = tensor_map_cd + masked_m.data_ptr(), + tensor_map_a, + tensor_map_b, + tensor_map_cd }; const auto& code = SM100BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm100_bf16_m_grouped_gemm_masked", code); @@ -274,18 +274,18 @@ static void sm100_bf16_k_grouped_gemm(const torch::Tensor& a, config.smem_config.swizzle_cd_mode); // Launch kernel - const SM100BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = sum_k, - .num_groups = num_groups, - .compiled_dims = compiled_dims, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + SM100BF16GemmRuntime::Args args = { + m, n, sum_k, + num_groups, + compiled_dims, + config, + make_launch_args(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .grouped_layout = ks_tensor.data_ptr(), - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_cd = tensor_map_cd + ks_tensor.data_ptr(), + tensor_map_a, + tensor_map_b, + tensor_map_cd }; const auto& code = SM100BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm100_bf16_k_grouped_gemm", code); @@ -322,18 +322,18 @@ static void sm100_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a, config.smem_config.swizzle_cd_mode); // Launch - const SM100BF16GemmRuntime::Args& args = { - .m = b, .n = d, .k = r, - .num_groups = h, - .compiled_dims = compiled_dims, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + SM100BF16GemmRuntime::Args args = { + b, d, r, + h, + compiled_dims, + config, + make_launch_args(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .grouped_layout = nullptr, - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_cd = tensor_map_cd + nullptr, + tensor_map_a, + tensor_map_b, + tensor_map_cd }; const auto& code = SM100BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm100_bf16_bhr_hdr_bhd", code); @@ -370,18 +370,18 @@ static void sm100_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a, config.smem_config.swizzle_cd_mode); // Launch - const SM100BF16GemmRuntime::Args& args = { - .m = b, .n = r, .k = d, - .num_groups = h, - .compiled_dims = compiled_dims, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + SM100BF16GemmRuntime::Args args = { + b, r, d, + h, + compiled_dims, + config, + make_launch_args(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .grouped_layout = nullptr, - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_cd = tensor_map_cd + nullptr, + tensor_map_a, + tensor_map_b, + tensor_map_cd }; const auto& code = SM100BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm100_bf16_bhd_hdr_bhr", code); diff --git a/csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp b/csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp index 45c0bc8aca..1d771c890a 100644 --- a/csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp +++ b/csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp @@ -116,18 +116,17 @@ static void sm100_bmn_bnk_mn_gemm(const torch::Tensor &a, const auto& tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode); const auto& tensor_map_d = make_tma_2d_desc(d, n, m, block_n, block_m, n, swizzle_cd_mode); - const SM100BmkBnkMnRuntime::Args& args = { - .s = s, .m = m, .n = n, .k = k, - .block_m = block_m, .block_n = block_n, .block_k = block_k, - .split_factor = split_factor, - .swizzle_ab_mode = swizzle_ab_mode, - .swizzle_cd_mode = swizzle_cd_mode, - .num_stages = num_stages, - .num_threads = num_threads, - .launch_args = LaunchArgs(num_mn_blocks * ceil_div(num_sk_blocks, split_factor), num_threads, smem_size), - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_d = tensor_map_d + SM100BmkBnkMnRuntime::Args args = { + s, m, n, k, + block_m, block_n, block_k, + split_factor, + swizzle_ab_mode, swizzle_cd_mode, + num_stages, + num_threads, + make_launch_args(num_mn_blocks * ceil_div(num_sk_blocks, split_factor), num_threads, smem_size), + tensor_map_a, + tensor_map_b, + tensor_map_d }; const auto& code = SM100BmkBnkMnRuntime::generate(args); const auto& runtime = compiler->build("sm100_bmn_bnk_mn_gemm", code); diff --git a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp index 07a977d73e..d54bb3d38b 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp @@ -20,8 +20,8 @@ class SM100FP8FP4Gemm1D1DRuntime final: public LaunchRuntime& epilogue_type; + std::string compiled_dims; + std::optional epilogue_type; GemmConfig gemm_config; LaunchArgs launch_args; @@ -121,23 +121,22 @@ static void sm100_fp8_fp4_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& config.block_n, gran_k_b, 1, 0); // Launch - const SM100FP8FP4Gemm1D1DRuntime::Args& args = { - .m = m, .n = n, .k = k, - .num_groups = 1, - .gran_k_a = gran_k_a, - .gran_k_b = gran_k_b, - .compiled_dims = compiled_dims, - .epilogue_type = epilogue_type, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + SM100FP8FP4Gemm1D1DRuntime::Args args = { + m, n, k, + 1, + gran_k_a, gran_k_b, + compiled_dims, + epilogue_type, + config, + make_launch_args(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .grouped_layout = nullptr, - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_sfa = tensor_map_sfa, - .tensor_map_sfb = tensor_map_sfb, - .tensor_map_cd = tensor_map_cd + nullptr, + tensor_map_a, + tensor_map_b, + tensor_map_sfa, + tensor_map_sfb, + tensor_map_cd }; const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_fp8_fp4_gemm_1d1d", code); @@ -190,23 +189,22 @@ static void sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(const torch::Tensor& a, config.block_n, gran_k_b, num_groups, 0); // Launch kernel - const SM100FP8FP4Gemm1D1DRuntime::Args& args = { - .m = m, .n = n, .k = k, - .num_groups = num_groups, - .gran_k_a = gran_k_a, - .gran_k_b = gran_k_b, - .compiled_dims = compiled_dims, - .epilogue_type = std::nullopt, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + SM100FP8FP4Gemm1D1DRuntime::Args args = { + m, n, k, + num_groups, + gran_k_a, gran_k_b, + compiled_dims, + std::nullopt, + config, + make_launch_args(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .grouped_layout = grouped_layout.data_ptr(), - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_sfa = tensor_map_sfa, - .tensor_map_sfb = tensor_map_sfb, - .tensor_map_cd = tensor_map_cd + grouped_layout.data_ptr(), + tensor_map_a, + tensor_map_b, + tensor_map_sfa, + tensor_map_sfb, + tensor_map_cd }; const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d", code); @@ -251,23 +249,22 @@ static void sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(const torch::Tensor& a, con config.block_n, gran_k_b, num_groups, 0); // Launch kernel - const SM100FP8FP4Gemm1D1DRuntime::Args& args = { - .m = m, .n = n, .k = k, - .num_groups = num_groups, - .gran_k_a = gran_k_a, - .gran_k_b = gran_k_b, - .compiled_dims = compiled_dims, - .epilogue_type = std::nullopt, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + SM100FP8FP4Gemm1D1DRuntime::Args args = { + m, n, k, + num_groups, + gran_k_a, gran_k_b, + compiled_dims, + std::nullopt, + config, + make_launch_args(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .grouped_layout = masked_m.data_ptr(), - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_sfa = tensor_map_sfa, - .tensor_map_sfb = tensor_map_sfb, - .tensor_map_cd = tensor_map_cd + masked_m.data_ptr(), + tensor_map_a, + tensor_map_b, + tensor_map_sfa, + tensor_map_sfb, + tensor_map_cd }; const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_masked_1d1d", code); @@ -322,23 +319,22 @@ static void sm100_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::T config.block_n, config.block_k, 1, 0); // Launch kernel - const SM100FP8FP4Gemm1D1DRuntime::Args& args = { - .m = m, .n = n, .k = sum_k, - .num_groups = num_groups, - .gran_k_a = 128, - .gran_k_b = 128, - .compiled_dims = compiled_dims, - .epilogue_type = std::nullopt, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + SM100FP8FP4Gemm1D1DRuntime::Args args = { + m, n, sum_k, + num_groups, + 128, 128, + compiled_dims, + std::nullopt, + config, + make_launch_args(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .grouped_layout = ks_tensor.data_ptr(), - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_sfa = tensor_map_sfa, - .tensor_map_sfb = tensor_map_sfb, - .tensor_map_cd = tensor_map_cd + ks_tensor.data_ptr(), + tensor_map_a, + tensor_map_b, + tensor_map_sfa, + tensor_map_sfb, + tensor_map_cd }; const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_k_grouped_fp8_gemm_1d1d", code); @@ -390,23 +386,22 @@ static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, config.block_n, config.block_k, batch_size, 0); // Launch - const SM100FP8FP4Gemm1D1DRuntime::Args& args = { - .m = m, .n = n, .k = k, - .num_groups = batch_size, - .gran_k_a = 128, - .gran_k_b = 128, - .compiled_dims = compiled_dims, - .epilogue_type = std::nullopt, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + SM100FP8FP4Gemm1D1DRuntime::Args args = { + m, n, k, + batch_size, + 128, 128, + compiled_dims, + std::nullopt, + config, + make_launch_args(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .grouped_layout = nullptr, - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_sfa = tensor_map_sfa, - .tensor_map_sfb = tensor_map_sfb, - .tensor_map_cd = tensor_map_cd + nullptr, + tensor_map_a, + tensor_map_b, + tensor_map_sfa, + tensor_map_sfb, + tensor_map_cd }; const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code); diff --git a/csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp b/csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp index 4f3ce5b1c9..86fc331a3f 100644 --- a/csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp +++ b/csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp @@ -127,19 +127,18 @@ static void sm100_tf32_hc_prenorm_gemm(const torch::Tensor& a, } // Launch - const SM100BF16HCPrenormGemmRuntime::Args& args = { - .m = m, .n = n, .k = k, - .block_m = block_m, .block_n = block_n, .block_k = block_k, - .num_splits = num_splits, - .swizzle_cd_mode = swizzle_cd_mode, - .num_stages = num_stages, - .num_mma_threads = num_mma_threads, - .num_cast_and_reduce_threads = num_cast_and_reduce_threads, - .launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_mma_threads + num_cast_and_reduce_threads, smem_size, 1), - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_d = tensor_map_d, - .sqr_sum = sqr_sum.data_ptr() + SM100BF16HCPrenormGemmRuntime::Args args = { + m, n, k, + block_m, block_n, block_k, + num_splits, + swizzle_cd_mode, + num_stages, + num_mma_threads, num_cast_and_reduce_threads, + make_launch_args(num_splits * ceil_div(m, block_m), num_mma_threads + num_cast_and_reduce_threads, smem_size, 1), + tensor_map_a, + tensor_map_b, + tensor_map_d, + sqr_sum.data_ptr() }; const auto& code = SM100BF16HCPrenormGemmRuntime::generate(args); const auto& runtime = compiler->build("sm100_tf32_hc_prenorm_gemm", code); diff --git a/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp b/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp index 32003f882d..573cf776d8 100644 --- a/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp +++ b/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp @@ -15,7 +15,7 @@ class SM90BF16GemmRuntime final: public LaunchRuntime { public: struct Args { int m, n, k, num_groups; - const std::string& compiled_dims; + std::string compiled_dims; GemmConfig gemm_config; LaunchArgs launch_args; @@ -104,18 +104,18 @@ static void sm90_bf16_gemm(const torch::Tensor& a, config.smem_config.swizzle_cd_mode); // Launch - const SM90BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = k, - .num_groups = 1, - .compiled_dims = compiled_dims, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + SM90BF16GemmRuntime::Args args = { + m, n, k, + 1, + compiled_dims, + config, + make_launch_args(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .grouped_layout = nullptr, - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_cd = tensor_map_cd, + nullptr, + tensor_map_a, + tensor_map_b, + tensor_map_cd }; const auto& code = SM90BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm90_bf16_gemm", code); @@ -158,18 +158,18 @@ static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, config.smem_config.swizzle_cd_mode); // Launch - const SM90BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = k, - .num_groups = num_groups, - .compiled_dims = compiled_dims, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + SM90BF16GemmRuntime::Args args = { + m, n, k, + num_groups, + compiled_dims, + config, + make_launch_args(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .grouped_layout = m_indices.data_ptr(), - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_cd = tensor_map_cd, + m_indices.data_ptr(), + tensor_map_a, + tensor_map_b, + tensor_map_cd }; const auto& code = SM90BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm90_m_grouped_bf16_gemm_contiguous", code); @@ -213,18 +213,18 @@ static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a, config.smem_config.swizzle_cd_mode); // Launch - const SM90BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = k, - .num_groups = num_groups, - .compiled_dims = compiled_dims, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + SM90BF16GemmRuntime::Args args = { + m, n, k, + num_groups, + compiled_dims, + config, + make_launch_args(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .grouped_layout = masked_m.data_ptr(), - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_cd = tensor_map_cd, + masked_m.data_ptr(), + tensor_map_a, + tensor_map_b, + tensor_map_cd }; const auto& code = SM90BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm90_bf16_m_grouped_gemm_masked", code); @@ -275,18 +275,18 @@ static void sm90_bf16_k_grouped_gemm(const torch::Tensor& a, config.smem_config.swizzle_cd_mode); // Launch kernel - const SM90BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = sum_k, - .num_groups = num_groups, - .compiled_dims = compiled_dims, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + SM90BF16GemmRuntime::Args args = { + m, n, sum_k, + num_groups, + compiled_dims, + config, + make_launch_args(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .grouped_layout = ks_tensor.data_ptr(), - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_cd = tensor_map_cd, + ks_tensor.data_ptr(), + tensor_map_a, + tensor_map_b, + tensor_map_cd }; const auto& code = SM90BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm90_bf16_k_grouped_gemm", code); @@ -322,18 +322,18 @@ static void sm90_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a, tensor_d.stride(0), tensor_d.stride(1), config.smem_config.swizzle_cd_mode); // Launch - const SM90BF16GemmRuntime::Args& args = { - .m = b, .n = d, .k = r, - .num_groups = h, - .compiled_dims = compiled_dims, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + SM90BF16GemmRuntime::Args args = { + b, d, r, + h, + compiled_dims, + config, + make_launch_args(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .grouped_layout = nullptr, - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_cd = tensor_map_cd, + nullptr, + tensor_map_a, + tensor_map_b, + tensor_map_cd }; const auto& code = SM90BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm90_bf16_bhr_hdr_bhd", code); @@ -369,18 +369,18 @@ static void sm90_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a, tensor_d.stride(0), tensor_d.stride(1), config.smem_config.swizzle_cd_mode); // Launch - const SM90BF16GemmRuntime::Args& args = { - .m = b, .n = r, .k = d, - .num_groups = h, - .compiled_dims = compiled_dims, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + SM90BF16GemmRuntime::Args args = { + b, r, d, + h, + compiled_dims, + config, + make_launch_args(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .grouped_layout = nullptr, - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_cd = tensor_map_cd, + nullptr, + tensor_map_a, + tensor_map_b, + tensor_map_cd }; const auto& code = SM90BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm90_bf16_bhd_hdr_bhr", code); diff --git a/csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp b/csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp index ccaea7f27a..29732134fd 100644 --- a/csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp +++ b/csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp @@ -111,17 +111,16 @@ static void sm90_bmn_bnk_mn_gemm(const torch::Tensor &a, const auto& tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode); const auto& tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode); - const SM90BmkBnkMnRuntime::Args& args = { - .s = s, .m = m, .n = n, .k = k, - .block_m = block_m, .block_n = block_n, .block_k = block_k, - .split_factor = split_factor, - .num_stages = num_stages, - .num_tma_threads = num_tma_threads, - .num_math_threads = num_math_threads, - .launch_args = LaunchArgs(num_mn_blocks * ceil_div(num_sk_blocks, split_factor), num_tma_threads + num_math_threads, smem_size), - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .d = d.data_ptr() + SM90BmkBnkMnRuntime::Args args = { + s, m, n, k, + block_m, block_n, block_k, + split_factor, + num_stages, + num_tma_threads, num_math_threads, + make_launch_args(num_mn_blocks * ceil_div(num_sk_blocks, split_factor), num_tma_threads + num_math_threads, smem_size), + tensor_map_a, + tensor_map_b, + d.data_ptr() }; const auto& code = SM90BmkBnkMnRuntime::generate(args); const auto& runtime = compiler->build("sm90_bmn_bnk_mn_gemm", code); diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp index e61841b347..aeec6de642 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp @@ -16,7 +16,7 @@ class SM90FP8Gemm1D1DRuntime final: public LaunchRuntime public: struct Args { int m, n, k, num_groups; - const std::string& compiled_dims; + std::string compiled_dims; GemmConfig gemm_config; LaunchArgs launch_args; @@ -115,23 +115,23 @@ static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, 0); // Launch - const SM90FP8Gemm1D1DRuntime::Args& args = { - .m = m, .n = n, .k = k, - .num_groups = 1, - .compiled_dims = compiled_dims, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + SM90FP8Gemm1D1DRuntime::Args args = { + m, n, k, + 1, + compiled_dims, + config, + make_launch_args(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .gmem_a_ptr = nullptr, - .gmem_b_ptr = nullptr, - .grouped_layout = nullptr, - .tensor_map_buffer = nullptr, - .tensor_map_a_base = tensor_map_a, - .tensor_map_b_base = tensor_map_b, - .tensor_map_sfa = tensor_map_sfa, - .tensor_map_sfb = tensor_map_sfb, - .tensor_map_cd = tensor_map_cd, + nullptr, + nullptr, + nullptr, + nullptr, + tensor_map_a, + tensor_map_b, + tensor_map_sfa, + tensor_map_sfb, + tensor_map_cd }; const auto& code = SM90FP8Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code); @@ -191,23 +191,23 @@ static void sm90_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Te config.smem_config.swizzle_cd_mode); // Launch - const SM90FP8Gemm1D1DRuntime::Args& args = { - .m = m, .n = n, .k = sum_k, - .num_groups = num_groups, - .compiled_dims = compiled_dims, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + SM90FP8Gemm1D1DRuntime::Args args = { + m, n, sum_k, + num_groups, + compiled_dims, + config, + make_launch_args(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .gmem_a_ptr = a.data_ptr(), - .gmem_b_ptr = b.data_ptr(), - .grouped_layout = ks_tensor.data_ptr(), - .tensor_map_buffer = tensor_map_buffer.data_ptr(), - .tensor_map_a_base = tensor_map_a_base, - .tensor_map_b_base = tensor_map_b_base, - .tensor_map_sfa = tensor_map_sfa, - .tensor_map_sfb = tensor_map_sfb, - .tensor_map_cd = tensor_map_cd, + a.data_ptr(), + b.data_ptr(), + ks_tensor.data_ptr(), + tensor_map_buffer.data_ptr(), + tensor_map_a_base, + tensor_map_b_base, + tensor_map_sfa, + tensor_map_sfb, + tensor_map_cd }; const auto& code = SM90FP8Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code); diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp index 2696b5a071..db224918cb 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -19,8 +19,8 @@ class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime struct Args { cute::UMMA::Major major_sfb; int m, n, k, num_groups; - const std::string& compiled_dims; - const std::optional& epilogue_type; + std::string compiled_dims; + std::optional epilogue_type; GemmConfig gemm_config; LaunchArgs launch_args; @@ -116,22 +116,22 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, config.block_m, config.block_k, 1, 0); // Launch - const SM90FP8Gemm1D2DRuntime::Args& args = { - .major_sfb = major_sfb, - .m = m, .n = n, .k = k, - .num_groups = 1, - .compiled_dims = compiled_dims, - .epilogue_type = epilogue_type, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + SM90FP8Gemm1D2DRuntime::Args args = { + major_sfb, + m, n, k, + 1, + compiled_dims, + epilogue_type, + config, + make_launch_args(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .sfb = sfb.data_ptr(), - .grouped_layout = nullptr, - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_d = tensor_map_d, - .tensor_map_sfa = tensor_map_sfa, + sfb.data_ptr(), + nullptr, + tensor_map_a, + tensor_map_b, + tensor_map_d, + tensor_map_sfa }; const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); const auto& runtime = compiler->build("sm90_fp8_gemm_1d2d", code); @@ -177,22 +177,22 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons config.block_m, config.block_k, 1, 0); // Launch - const SM90FP8Gemm1D2DRuntime::Args& args = { - .major_sfb = major_sfb, - .m = m, .n = n, .k = k, - .num_groups = num_groups, - .compiled_dims = compiled_dims, - .epilogue_type = std::nullopt, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + SM90FP8Gemm1D2DRuntime::Args args = { + major_sfb, + m, n, k, + num_groups, + compiled_dims, + std::nullopt, + config, + make_launch_args(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .sfb = sfb.data_ptr(), - .grouped_layout = m_indices.data_ptr(), - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_d = tensor_map_d, - .tensor_map_sfa = tensor_map_sfa, + sfb.data_ptr(), + m_indices.data_ptr(), + tensor_map_a, + tensor_map_b, + tensor_map_d, + tensor_map_sfa }; const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); const auto& runtime = compiler->build("sm90_m_grouped_fp8_gemm_contiguous_1d2d", code); @@ -239,22 +239,22 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to config.block_m, config.block_k, num_groups, 0); // Launch - const SM90FP8Gemm1D2DRuntime::Args& args = { - .major_sfb = major_sfb, - .m = m, .n = n, .k = k, - .num_groups = num_groups, - .compiled_dims = compiled_dims, - .epilogue_type = std::nullopt, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + SM90FP8Gemm1D2DRuntime::Args args = { + major_sfb, + m, n, k, + num_groups, + compiled_dims, + std::nullopt, + config, + make_launch_args(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .sfb = sfb.data_ptr(), - .grouped_layout = masked_m.data_ptr(), - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_d = tensor_map_d, - .tensor_map_sfa = tensor_map_sfa, + sfb.data_ptr(), + masked_m.data_ptr(), + tensor_map_a, + tensor_map_b, + tensor_map_d, + tensor_map_sfa }; const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); const auto& runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code); @@ -306,22 +306,22 @@ static void sm90_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, config.block_m, config.block_k, batch_size, 0); // Launch - const SM90FP8Gemm1D2DRuntime::Args& args = { - .major_sfb = major_sfb, - .m = m, .n = n, .k = k, - .num_groups = batch_size, - .compiled_dims = compiled_dims, - .epilogue_type = std::nullopt, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + SM90FP8Gemm1D2DRuntime::Args args = { + major_sfb, + m, n, k, + batch_size, + compiled_dims, + std::nullopt, + config, + make_launch_args(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .sfb = sfb.data_ptr(), - .grouped_layout = nullptr, - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_d = tensor_map_d, - .tensor_map_sfa = tensor_map_sfa, + sfb.data_ptr(), + nullptr, + tensor_map_a, + tensor_map_b, + tensor_map_d, + tensor_map_sfa }; const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); const auto& runtime = compiler->build("sm90_fp8_gemm_1d2d", code); diff --git a/csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp b/csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp index aeea262311..86f9995062 100644 --- a/csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp +++ b/csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp @@ -130,19 +130,18 @@ static void sm90_tf32_hc_prenorm_gemm(const torch::Tensor& a, smem_size = SM90ArchSpec::smem_capacity; // Launch - const SM90BF16HCPrenormGemmRuntime::Args& args = { - .m = m, .n = n, .k = k, - .block_m = block_m, .block_n = block_n, .block_k = block_k, - .num_splits = num_splits, - .swizzle_cd_mode = swizzle_cd_mode, - .num_stages = num_stages, - .num_math_threads = num_math_threads, - .num_tma_threads = num_tma_threads, - .launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_threads, smem_size, 1), - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_d = tensor_map_d, - .sqr_sum = sqr_sum.data_ptr() + SM90BF16HCPrenormGemmRuntime::Args args = { + m, n, k, + block_m, block_n, block_k, + num_splits, + swizzle_cd_mode, + num_stages, + num_math_threads, num_tma_threads, + make_launch_args(num_splits * ceil_div(m, block_m), num_threads, smem_size, 1), + tensor_map_a, + tensor_map_b, + tensor_map_d, + sqr_sum.data_ptr() }; const auto& code = SM90BF16HCPrenormGemmRuntime::generate(args); const auto& runtime = compiler->build("sm90_tf32_hc_prenorm_gemm", code); diff --git a/csrc/jit_kernels/impls/smxx_clean_logits.hpp b/csrc/jit_kernels/impls/smxx_clean_logits.hpp index fdb91a03bb..03a1eed8d2 100644 --- a/csrc/jit_kernels/impls/smxx_clean_logits.hpp +++ b/csrc/jit_kernels/impls/smxx_clean_logits.hpp @@ -58,17 +58,17 @@ static void smxx_clean_logits(const torch::Tensor& logits, const int smem_size = block_kv * sizeof(float); // Launch - const SMXXCleanLogitsRuntime::Args& args = { - .next_n = next_n, - .seq_len = seq_len, - .seq_len_kv = seq_len_kv, - .stride_logits = stride_logits, - .cu_seq_len_k_start = cu_seq_len_k_start.has_value() ? cu_seq_len_k_start.value().data_ptr() : nullptr, - .cu_seq_len_k_end = cu_seq_len_k_end.data_ptr(), - .logits = logits.data_ptr(), - .block_kv = block_kv, - .num_warps = num_warps, - .launch_args = LaunchArgs(device_runtime->get_num_sms(), + SMXXCleanLogitsRuntime::Args args = { + next_n, + seq_len, + seq_len_kv, + stride_logits, + cu_seq_len_k_start.has_value() ? cu_seq_len_k_start.value().data_ptr() : nullptr, + cu_seq_len_k_end.data_ptr(), + logits.data_ptr(), + block_kv, + num_warps, + make_launch_args(device_runtime->get_num_sms(), num_warps * 32, smem_size) }; const auto& code = SMXXCleanLogitsRuntime::generate(args); diff --git a/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp b/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp index f3b82e3de9..14b2332afc 100644 --- a/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp +++ b/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp @@ -132,27 +132,28 @@ static void smxx_fp8_mqa_logits(const torch::Tensor& q, DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); // Launch - const SMXXFP8MQALogitsRuntime::Args& args = { - .seq_len = seq_len, - .seq_len_kv = seq_len_kv, - .max_seqlen_k = max_seqlen_k, - .stride_logits = stride_logits, - .num_heads = num_heads, .head_dim = head_dim, - .is_compressed_logits = is_compressed_logits, - .num_q_stages = num_q_stages, - .num_kv_stages = num_kv_stages, - .block_q = block_q, - .block_kv = block_kv, - .cu_seq_len_k_start = cu_seq_len_k_start.data_ptr(), - .cu_seq_len_k_end = cu_seq_len_k_end.data_ptr(), - .logits = logits.data_ptr(), - .tensor_map_q = tensor_map_q, - .tensor_map_kv = tensor_map_kv, - .tensor_map_kv_scales = tensor_map_kv_scales, - .tensor_map_weights = tensor_map_weights, - .num_specialized_threads = num_specialized_threads, - .num_math_threads = num_math_threads, - .launch_args = LaunchArgs(device_runtime->get_num_sms(), + SMXXFP8MQALogitsRuntime::Args args = { + seq_len, + seq_len_kv, + max_seqlen_k, + stride_logits, + num_heads, head_dim, + is_compressed_logits, + num_q_stages, + num_kv_stages, + block_q, + block_kv, + cu_seq_len_k_start.data_ptr(), + cu_seq_len_k_end.data_ptr(), + logits.data_ptr(), + 0.0f, // softmax_scale (unused) + tensor_map_q, + tensor_map_kv, + tensor_map_kv_scales, + tensor_map_weights, + num_specialized_threads, + num_math_threads, + make_launch_args(device_runtime->get_num_sms(), num_specialized_threads + num_math_threads, smem_size) }; diff --git a/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp b/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp index 1240aad84a..4200de0160 100644 --- a/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp +++ b/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp @@ -67,16 +67,16 @@ static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens, DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); // Launch - const SMXXPagedMQALogitsMetadataRuntime::Args& args = { - .aligned_batch_size = aligned_batch_size, - .split_kv = split_kv, - .num_sms = num_sms, - .batch_size = batch_size, - .next_n = next_n, - .is_context_lens_2d = is_context_lens_2d, - .context_lens = context_lens.data_ptr(), - .schedule_metadata = schedule_metadata.data_ptr(), - .launch_args = LaunchArgs(1, num_threads, smem_size) + SMXXPagedMQALogitsMetadataRuntime::Args args = { + aligned_batch_size, + split_kv, + num_sms, + batch_size, + next_n, + is_context_lens_2d, + context_lens.data_ptr(), + schedule_metadata.data_ptr(), + make_launch_args(1, num_threads, smem_size) }; const auto& code = SMXXPagedMQALogitsMetadataRuntime::generate(args); const auto& runtime = compiler->build("smxx_paged_mqa_logits_metadata", code); @@ -231,29 +231,29 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q, } // Launch - const SMXXFP8PagedMQALogitsRuntime::Args& args = { - .batch_size = batch_size, - .next_n = next_n, - .num_heads = num_heads, - .head_dim = head_dim, - .block_kv = block_kv, - .is_context_lens_2d = is_context_lens_2d, - .block_table_stride = block_table_stride, - .logits_stride = logits_stride, - .num_q_stages = num_q_stages, - .num_kv_stages = num_kv_stages, - .split_kv = split_kv, - .context_lens = context_lens.data_ptr(), - .logits = logits.data_ptr(), - .block_table = block_table.data_ptr(), - .schedule_meta = schedule_meta.data_ptr(), - .tensor_map_q = tensor_map_q, - .tensor_map_kv = tensor_map_kv, - .tensor_map_kv_scales = tensor_map_kv_scales, - .tensor_map_weights = tensor_map_weights, - .num_specialized_threads = num_specialized_threads, - .num_math_threads = num_math_threads, - .launch_args = LaunchArgs(num_sms, + SMXXFP8PagedMQALogitsRuntime::Args args = { + batch_size, + next_n, + num_heads, + head_dim, + block_kv, + is_context_lens_2d, + block_table_stride, + logits_stride, + num_q_stages, + num_kv_stages, + split_kv, + context_lens.data_ptr(), + logits.data_ptr(), + block_table.data_ptr(), + schedule_meta.data_ptr(), + tensor_map_q, + tensor_map_kv, + tensor_map_kv_scales, + tensor_map_weights, + num_specialized_threads, + num_math_threads, + make_launch_args(num_sms, num_specialized_threads + num_math_threads, smem_size) }; diff --git a/csrc/jit_kernels/impls/smxx_layout.hpp b/csrc/jit_kernels/impls/smxx_layout.hpp index 3d113498d1..f32897a701 100644 --- a/csrc/jit_kernels/impls/smxx_layout.hpp +++ b/csrc/jit_kernels/impls/smxx_layout.hpp @@ -130,13 +130,12 @@ static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) { constexpr int block_mn = 64; constexpr int num_threads = 512; const auto& smem_size = block_mn * (sf_k + (1 - (sf_k % 2))) * static_cast(sizeof(float)); - const TransposeFP32Runtime::Args& args = { - .mn = mn, - .sf_k = sf_k, - .block_mn = block_mn, - .sf = batched_sf.data_ptr(), - .out = out.data_ptr(), - .launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, smem_size) + TransposeFP32Runtime::Args args = { + mn, sf_k, + block_mn, + batched_sf.data_ptr(), + out.data_ptr(), + make_launch_args({ceil_div(mn, block_mn), num_groups}, num_threads, smem_size) }; const auto& code = TransposeFP32Runtime::generate(args); @@ -184,13 +183,12 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::T constexpr int block_mn = 48; constexpr int num_threads = 512; - const TransposeAndPackFP32IntoUE8M0Runtime::Args& args = { - .mn = mn, - .sf_k = sf_k, - .block_mn = block_mn, - .sf = batched_sf.data_ptr(), - .out = out.data_ptr(), - .launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, block_mn * sf_k * 4) + TransposeAndPackFP32IntoUE8M0Runtime::Args args = { + mn, sf_k, + block_mn, + batched_sf.data_ptr(), + out.data_ptr(), + make_launch_args({ceil_div(mn, block_mn), num_groups}, num_threads, block_mn * sf_k * 4) }; const auto& code = TransposeAndPackFP32IntoUE8M0Runtime::generate(args); @@ -204,17 +202,14 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::T constexpr int block_mn = 128; constexpr int block_packed_sf_k = 16; constexpr int num_threads = 512; - const PackFP32IntoUE8M0Runtime::Args& args = { - .num_groups = 1, - .mn = mn, - .sf_k = sf_k, - .packed_sf_k = packed_sf_k, - .block_mn = block_mn, - .block_packed_sf_k = block_packed_sf_k, - .sf = batched_sf.data_ptr(), - .out = out.data_ptr(), - .ks = nullptr, - .launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads) + PackFP32IntoUE8M0Runtime::Args args = { + 1, + mn, sf_k, packed_sf_k, + block_mn, block_packed_sf_k, + batched_sf.data_ptr(), + out.data_ptr(), + nullptr, + make_launch_args({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads) }; const auto& code = PackFP32IntoUE8M0Runtime::generate(args); @@ -242,17 +237,14 @@ static torch::Tensor get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(cons constexpr int block_mn = 128; constexpr int block_packed_sf_k = 16; constexpr int num_threads = 512; - const PackFP32IntoUE8M0Runtime::Args& args = { - .num_groups = num_groups, - .mn = mn, - .sf_k = sf_k, - .packed_sf_k = packed_sf_k, - .block_mn = block_mn, - .block_packed_sf_k = block_packed_sf_k, - .sf = sf.data_ptr(), - .out = out.data_ptr(), - .ks = ks_tensor.data_ptr(), - .launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads) + PackFP32IntoUE8M0Runtime::Args args = { + num_groups, + mn, sf_k, packed_sf_k, + block_mn, block_packed_sf_k, + sf.data_ptr(), + out.data_ptr(), + ks_tensor.data_ptr(), + make_launch_args({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads) }; const auto& code = PackFP32IntoUE8M0Runtime::generate(args); diff --git a/csrc/utils/exception.hpp b/csrc/utils/exception.hpp index 2aa270661a..da3d7305e2 100644 --- a/csrc/utils/exception.hpp +++ b/csrc/utils/exception.hpp @@ -6,6 +6,7 @@ #include #include "compatibility.hpp" +#include "msvc_compat.hpp" namespace deep_gemm { diff --git a/csrc/utils/layout.hpp b/csrc/utils/layout.hpp index d67cfcfb69..8755b6eb6e 100644 --- a/csrc/utils/layout.hpp +++ b/csrc/utils/layout.hpp @@ -43,7 +43,8 @@ static auto get_shape(const torch::Tensor& t) { } static std::tuple check_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major) { - auto [mn, k] = get_shape<2>(ab); + int mn, k; + std::tie(mn, k) = get_shape<2>(ab); if (ab.scalar_type() != torch::kFloat8_e4m3fn) { DG_HOST_ASSERT(ab.scalar_type() == kPackedFP4 and arch_major == 10); major == cute::UMMA::Major::K ? (k *= 2) : (mn *= 2); @@ -52,7 +53,8 @@ static std::tuple check_ab_fp8_fp4(const torch::Tensor& ab, const cute } static std::tuple check_grouped_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major) { - auto [num_groups, mn, k] = get_shape<3>(ab); + int num_groups, mn, k; + std::tie(num_groups, mn, k) = get_shape<3>(ab); if (ab.scalar_type() != torch::kFloat8_e4m3fn) { DG_HOST_ASSERT(ab.scalar_type() == kPackedFP4 and arch_major == 10); major == cute::UMMA::Major::K ? (k *= 2) : (mn *= 2); diff --git a/csrc/utils/msvc_compat.hpp b/csrc/utils/msvc_compat.hpp new file mode 100644 index 0000000000..9d3d66a0e3 --- /dev/null +++ b/csrc/utils/msvc_compat.hpp @@ -0,0 +1,40 @@ +#pragma once + +// MSVC compatibility layer for DeepGEMM +// This header provides POSIX-compatible macros and includes for Windows/MSVC builds + +#if defined(_MSC_VER) + #define DG_MSVC 1 +#else + #define DG_MSVC 0 +#endif + +#if DG_MSVC + // POSIX compatibility for Windows + #include + #include + #include + + // Map POSIX functions to Windows equivalents + #define popen _popen + #define pclose _pclose + #define getpid _getpid + + // WEXITSTATUS: On Windows, pclose returns the exit code directly + #define WEXITSTATUS(x) (x) + + // Include ciso646 for and/or/not keywords (C++ alternative tokens) + // In C++17 and later, these are built-in, but MSVC requires this header + // unless /permissive- is used consistently + #include + + // Disable some MSVC warnings that are noisy for this codebase + #pragma warning(disable: 4244) // conversion from 'type1' to 'type2', possible loss of data + #pragma warning(disable: 4267) // conversion from 'size_t' to 'type', possible loss of data + #pragma warning(disable: 4996) // deprecated functions + +#else + // POSIX systems + #include + #include +#endif diff --git a/csrc/utils/system.hpp b/csrc/utils/system.hpp index b0e28ba938..b4c1d6b116 100644 --- a/csrc/utils/system.hpp +++ b/csrc/utils/system.hpp @@ -6,8 +6,8 @@ #include #include #include -#include +#include "msvc_compat.hpp" #include "exception.hpp" #include "format.hpp" @@ -71,10 +71,10 @@ static std::filesystem::path make_dirs(const std::filesystem::path& path) { const bool& created = std::filesystem::create_directories(path, capture); if (not (created or capture.value() == 0)) { DG_HOST_UNREACHABLE(fmt::format("Failed to make directory: {}, created: {}, value: {}", - path.c_str(), created, capture.value())); + path.string(), created, capture.value())); } if (created and get_env("DG_JIT_DEBUG")) - printf("Create directory: %s\n", path.c_str()); + printf("Create directory: %s\n", path.string().c_str()); return path; } diff --git a/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh b/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh index bea7000276..01e4f243e0 100644 --- a/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh +++ b/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh @@ -66,8 +66,8 @@ __global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, con const auto num_uint4 = num_values / 4; #pragma unroll for (uint32_t i = threadIdx.x; i < num_uint4; i += kNumThreads) { - const auto& [x, y, z, w] = __ldg(reinterpret_cast(local_sf) + i); - st_shared(reinterpret_cast(smem_buffer) + i, x, y, z, w); + const uint4 loaded = __ldg(reinterpret_cast(local_sf) + i); + st_shared(reinterpret_cast(smem_buffer) + i, loaded.x, loaded.y, loaded.z, loaded.w); } // Fill unaligned values as well diff --git a/setup.py b/setup.py index 6199d7c35f..7c3a05641b 100644 --- a/setup.py +++ b/setup.py @@ -24,11 +24,18 @@ DG_USE_LOCAL_VERSION = int(os.getenv('DG_USE_LOCAL_VERSION', '1')) == 1 DG_JIT_USE_RUNTIME_API = int(os.environ.get('DG_JIT_USE_RUNTIME_API', '0')) == 1 -# Compiler flags -cxx_flags = ['-std=c++17', '-O3', '-fPIC', '-Wno-psabi', '-Wno-deprecated-declarations', - f'-D_GLIBCXX_USE_CXX11_ABI={int(torch.compiled_with_cxx11_abi())}'] +# Platform detection +IS_WINDOWS = sys.platform.startswith('win') + +# Compiler flags (platform-specific) +if IS_WINDOWS: + cxx_flags = ['/std:c++latest', '/O2', '/EHsc', '/Zc:__cplusplus', '/permissive-', '/utf-8', + '/DNOMINMAX', '/DWIN32_LEAN_AND_MEAN'] +else: + cxx_flags = ['-std=c++17', '-O3', '-fPIC', '-Wno-psabi', '-Wno-deprecated-declarations', + f'-D_GLIBCXX_USE_CXX11_ABI={int(torch.compiled_with_cxx11_abi())}'] if DG_JIT_USE_RUNTIME_API: - cxx_flags.append('-DDG_JIT_USE_RUNTIME_API') + cxx_flags.append('-DDG_JIT_USE_RUNTIME_API' if not IS_WINDOWS else '/DDG_JIT_USE_RUNTIME_API') # Sources current_dir = os.path.dirname(os.path.realpath(__file__)) @@ -40,8 +47,11 @@ 'third-party/cutlass/include', 'third-party/fmt/include', ] -build_libraries = ['cudart', 'nvrtc'] -build_library_dirs = [f'{CUDA_HOME}/lib64'] +build_libraries = ['cudart', 'nvrtc', 'cublasLt', 'cublas'] +if IS_WINDOWS: + build_library_dirs = [f'{CUDA_HOME}/lib/x64'] +else: + build_library_dirs = [f'{CUDA_HOME}/lib64'] third_party_include_dirs = [ 'third-party/cutlass/include/cute', 'third-party/cutlass/include/cutlass', @@ -76,6 +86,9 @@ def get_package_version(): def get_platform(): if sys.platform.startswith('linux'): return f'linux_{platform.uname().machine}' + elif sys.platform.startswith('win'): + # Wheel platform tag must be lowercase (e.g., win_amd64, not win_AMD64) + return f'win_{platform.uname().machine.lower()}' else: raise ValueError('Unsupported platform: {}'.format(sys.platform)) @@ -165,10 +178,42 @@ def prepare_includes(self): shutil.copytree(src_dir, dst_dir) +def get_detailed_wheel_name(): + """Generate a detailed wheel filename with version info. + + Wheel filename format: {name}-{version}-{python}-{abi}-{platform}.whl + We encode CUDA/torch/ABI info in the version's local segment. + """ + torch_version = parse(torch.__version__) + torch_version_str = f'{torch_version.major}.{torch_version.minor}' + python_version = f'cp{sys.version_info.major}{sys.version_info.minor}' + platform_name = get_platform() + deep_gemm_version = get_package_version() + + # Get CUDA version from torch + cuda_version = parse(torch.version.cuda) + cuda_version_str = f'{cuda_version.major}{cuda_version.minor}' + + # CXX11 ABI flag (Linux only, encoded in version) + if IS_WINDOWS: + abi_str = '' + else: + cxx11_abi = int(torch._C._GLIBCXX_USE_CXX11_ABI) + abi_str = f'.cxx11abi{cxx11_abi}' + + # Version format: 2.3.0+cu128.torch2.10.cxx11abi1 (local version uses dots, not dashes) + full_version = f'{deep_gemm_version}+cu{cuda_version_str}.torch{torch_version_str}{abi_str}' + + # Standard wheel format: {name}-{version}-{python}-{abi}-{platform}.whl + return f'deep_gemm-{full_version}-{python_version}-{python_version}-{platform_name}.whl' + + class CachedWheelsCommand(_bdist_wheel): def run(self): if DG_FORCE_BUILD or DG_USE_LOCAL_VERSION: - return super().run() + super().run() + self._rename_wheel() + return wheel_url, wheel_filename = get_wheel_url() print(f'Try to download wheel from URL: {wheel_url}') @@ -189,6 +234,26 @@ def run(self): print('Precompiled wheel not found. Building from source...') # If the wheel could not be downloaded, build from source super().run() + self._rename_wheel() + + def _rename_wheel(self): + """Rename the wheel to include detailed version info.""" + if not os.path.exists(self.dist_dir): + return + + # Find the generated wheel + for filename in os.listdir(self.dist_dir): + if filename.startswith('deep_gemm') and filename.endswith('.whl'): + old_path = os.path.join(self.dist_dir, filename) + new_filename = get_detailed_wheel_name() + new_path = os.path.join(self.dist_dir, new_filename) + + if old_path != new_path: + if os.path.exists(new_path): + os.remove(new_path) + os.rename(old_path, new_path) + print(f'Renamed wheel: {filename} -> {new_filename}') + break if __name__ == '__main__': diff --git a/third-party/cutlass b/third-party/cutlass index f3fde58372..a4eb0e05f6 160000 --- a/third-party/cutlass +++ b/third-party/cutlass @@ -1 +1 @@ -Subproject commit f3fde58372d33e9a5650ba7b80fc48b3b49d40c8 +Subproject commit a4eb0e05f6dd0403f94087b495393bdca75bf0ad