Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,11 @@ deep_gemm/include/cutlass
stubs/

# Symlinks to compiled extensions
deep_gemm/*.so
deep_gemm/*.so
deep_gemm/*.pyd

# Claude Code workspace
.claude/

# Windows null device artifact
nul
23 changes: 19 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,16 @@ cmake_minimum_required(VERSION 3.10)
project(deep_gemm LANGUAGES CXX CUDA)
set(CMAKE_VERBOSE_MAKEFILE ON)

set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC -Wno-psabi")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi")
# Platform-specific compiler flags
if(MSVC)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /O2 /EHsc")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /O2 /EHsc /std:c++17 /Zc:__cplusplus /permissive-")
add_definitions(-D_CRT_SECURE_NO_WARNINGS)
else()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC -Wno-psabi")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi")
endif()

set(CUDA_SEPARABLE_COMPILATION ON)
list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG")
list(APPEND CUDA_NVCC_FLAGS "-O3")
Expand All @@ -22,8 +30,15 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)

include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include third-party/fmt/include)
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)

# Platform-specific include and library directories
if(WIN32)
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64)
else()
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)
endif()

# The main Python API entrance
pybind11_add_module(_C csrc/python_api.cpp)
Expand Down
59 changes: 58 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ python tests/test_attention.py
python tests/test_core.py
```

### Installation
### Installation (Linux)

```bash
cat install.sh
Expand All @@ -81,6 +81,63 @@ cat install.sh

Then, import `deep_gemm` in your Python project, and enjoy!

### Installation (Windows)

DeepGEMM supports building on Windows with MSVC. Follow these steps:

#### Prerequisites

- **Visual Studio 2022** (or 2019) with C++ desktop development workload
- **CUDA Toolkit 12.3+** (12.9+ recommended for best performance)
- **Python 3.8+** with pip
- **PyTorch 2.1+** with CUDA support

#### Build Steps

```cmd
:: Clone the repository
git clone --recursive https://github.com/deepseek-ai/DeepGEMM.git
cd DeepGEMM

:: Create and activate virtual environment (using uv or venv)
uv venv --python 3.10
.venv\Scripts\activate

:: Install dependencies
uv pip install packaging wheel setuptools ninja numpy==1.26.4 pip build psutil
uv pip install torch==2.10.0 --index-url https://download.pytorch.org/whl/cu128

:: Clear environment variables and initialize VS developer environment
set INCLUDE=
set LIB=
set LIBPATH=
call "C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvarsall.bat" x64

:: Configure CUDA
set CUDA_HOME=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8
set PATH=%CUDA_HOME%\bin;%PATH%
set DISTUTILS_USE_SDK=1

:: Set target architectures (adjust based on your GPU)
set TORCH_CUDA_ARCH_LIST=8.0;8.6;8.9;9.0

:: Initialize submodules
git submodule sync
git submodule update --init --recursive

:: Build the wheel
set DG_USE_LOCAL_VERSION=0
python setup.py bdist_wheel
```

The built wheel will be located in the `dist/` directory.

#### Troubleshooting

- **Linker errors (LNK1104)**: Ensure `vcvarsall.bat` was executed in the same terminal session
- **CUDA not found**: Verify `CUDA_HOME` points to your CUDA installation directory
- **Missing cuBLASLt symbols**: The build system should automatically link cuBLASLt; ensure CUDA Toolkit is properly installed

## Interfaces

#### Notices
Expand Down
90 changes: 55 additions & 35 deletions csrc/apis/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,10 @@ static void fp8_fp4_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,

// Type and shape checks
const auto arch_major = device_runtime->get_arch_major();
const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major);
const auto [n , k_] = check_ab_fp8_fp4(b.first, major_b, arch_major);
const auto [m_, n_] = get_shape<2>(d);
int m, k, n, k_, m_, n_;
std::tie(m, k) = check_ab_fp8_fp4(a.first, major_a, arch_major);
std::tie(n, k_) = check_ab_fp8_fp4(b.first, major_b, arch_major);
std::tie(m_, n_) = get_shape<2>(d);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);

Expand All @@ -80,7 +81,9 @@ static void fp8_fp4_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
return;

// Transform SFA and SFB into compute-required layout
const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
torch::Tensor sfa, sfb;
int gran_k_a, gran_k_b;
std::tie(sfa, sfb, gran_k_a, gran_k_b) = layout::transform_sf_pair_into_required_layout(
a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, std::nullopt, std::nullopt, disable_ue8m0_cast);

// Dispatch into different implements
Expand Down Expand Up @@ -161,20 +164,23 @@ static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pair<torch::Tensor,

// Type and shape checks
const auto arch_major = device_runtime->get_arch_major();
const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major);
const auto [num_groups, n, k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major);
const auto [m_, n_] = get_shape<2>(d);
int m, k, num_groups, n, k_, m_, n_;
std::tie(m, k) = check_ab_fp8_fp4(a.first, major_a, arch_major);
std::tie(num_groups, n, k_) = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major);
std::tie(m_, n_) = get_shape<2>(d);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0);
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(grouped_layout.scalar_type() == torch::kInt);

// Layout checks
if (use_psum_layout) {
const auto& [num_groups_] = get_shape<1>(grouped_layout);
int num_groups_;
std::tie(num_groups_) = get_shape<1>(grouped_layout);
DG_HOST_ASSERT(num_groups == num_groups_);
} else {
const auto& [m__] = get_shape<1>(grouped_layout);
int m__;
std::tie(m__) = get_shape<1>(grouped_layout);
DG_HOST_ASSERT(m == m__);
DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value());
}
Expand All @@ -187,7 +193,9 @@ static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pair<torch::Tensor,
return;

// Transform SFA and SFB into compute-required layout
const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
torch::Tensor sfa, sfb;
int gran_k_a, gran_k_b;
std::tie(sfa, sfb, gran_k_a, gran_k_b) = layout::transform_sf_pair_into_required_layout(
a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, std::nullopt, num_groups, disable_ue8m0_cast);

// Dispatch implementation
Expand Down Expand Up @@ -237,9 +245,10 @@ static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair<torch::Tensor, torc

// Type and shape checks
const auto arch_major = device_runtime->get_arch_major();
const auto [num_groups , m , k ] = check_grouped_ab_fp8_fp4(a.first, major_a, arch_major);
const auto [num_groups_ , n , k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major);
const auto [num_groups__, m_, n_] = get_shape<3>(d);
int num_groups, m, k, num_groups_, n, k_, num_groups__, m_, n_;
std::tie(num_groups, m, k) = check_grouped_ab_fp8_fp4(a.first, major_a, arch_major);
std::tie(num_groups_, n, k_) = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major);
std::tie(num_groups__, m_, n_) = get_shape<3>(d);
const auto num_groups___ = static_cast<int>(masked_m.numel());
DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
Expand All @@ -251,7 +260,9 @@ static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair<torch::Tensor, torc
check_major_type_cd(d);

// Transform scaling factors
const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
torch::Tensor sfa, sfb;
int gran_k_a, gran_k_b;
std::tie(sfa, sfb, gran_k_a, gran_k_b) = layout::transform_sf_pair_into_required_layout(
a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, num_groups, num_groups, disable_ue8m0_cast);

// Dispatch implementation
Expand Down Expand Up @@ -280,9 +291,10 @@ static void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torc
DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));

// Shape checks
const auto& [num_groups, m, n] = get_shape<3>(d);
const auto& [sum_k_ , m_] = get_shape<2>(a.first);
const auto& [sum_k__, n_] = get_shape<2>(b.first);
int num_groups, m, n, sum_k_, m_, sum_k__, n_;
std::tie(num_groups, m, n) = get_shape<3>(d);
std::tie(sum_k_, m_) = get_shape<2>(a.first);
std::tie(sum_k__, n_) = get_shape<2>(b.first);
const int sum_k = std::accumulate(ks.begin(), ks.end(), 0);
DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__);

Expand Down Expand Up @@ -322,7 +334,8 @@ static void k_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torc
DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));

// Shape checks
const auto& [num_groups, m, n] = get_shape<3>(d);
int num_groups, m, n;
std::tie(num_groups, m, n) = get_shape<3>(d);
const auto& sum_mk = a.first.numel();
const auto& sum_nk = b.first.numel();
const int sum_k = std::accumulate(ks.begin(), ks.end(), 0);
Expand Down Expand Up @@ -374,9 +387,10 @@ static void bf16_gemm_nt(const torch::Tensor& a,
check_major_type_cd(d);

// Type and shape checks
const auto& [m , k ] = get_shape<2>(a);
const auto& [n , k_] = get_shape<2>(b);
const auto& [m_, n_] = get_shape<2>(d);
int m, k, n, k_, m_, n_;
std::tie(m, k) = get_shape<2>(a);
std::tie(n, k_) = get_shape<2>(b);
std::tie(m_, n_) = get_shape<2>(d);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16);
Expand Down Expand Up @@ -433,9 +447,10 @@ static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torc
DG_HOST_ASSERT(grouped_layout.is_contiguous());

// Type and shape checks
const auto& [m, k] = get_shape<2>(a);
const auto& [num_groups, n, k_] = get_shape<3>(b);
const auto& [m_, n_] = get_shape<2>(d);
int m, k, num_groups, n, k_, m_, n_;
std::tie(m, k) = get_shape<2>(a);
std::tie(num_groups, n, k_) = get_shape<3>(b);
std::tie(m_, n_) = get_shape<2>(d);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0);
DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
Expand All @@ -445,10 +460,12 @@ static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torc

// Layout checks
if (use_psum_layout) {
const auto& [num_groups_] = get_shape<1>(grouped_layout);
int num_groups_;
std::tie(num_groups_) = get_shape<1>(grouped_layout);
DG_HOST_ASSERT(num_groups == num_groups_);
} else {
const auto& [m__] = get_shape<1>(grouped_layout);
int m__;
std::tie(m__) = get_shape<1>(grouped_layout);
DG_HOST_ASSERT(m == m__);
DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value());
}
Expand Down Expand Up @@ -493,9 +510,10 @@ static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::T
DG_HOST_ASSERT(masked_m.is_contiguous());

// Type and shape checks
const auto& [num_groups, m, k] = get_shape<3>(a);
const auto& [num_groups_, n, k_] = get_shape<3>(b);
const auto& [num_groups__, m_, n_] = get_shape<3>(d);
int num_groups, m, k, num_groups_, n, k_, num_groups__, m_, n_;
std::tie(num_groups, m, k) = get_shape<3>(a);
std::tie(num_groups_, n, k_) = get_shape<3>(b);
std::tie(num_groups__, m_, n_) = get_shape<3>(d);
const auto& num_groups___ = static_cast<int>(masked_m.numel());
DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
Expand Down Expand Up @@ -529,9 +547,10 @@ static void k_grouped_bf16_gemm_tn_contiguous(const torch::Tensor& a,
const std::optional<torch::Tensor>& c,
const std::string& compiled_dims) {
// Shape checks
const auto& [num_groups, m, n] = get_shape<3>(d);
const auto& [sum_k_ , m_] = get_shape<2>(a);
const auto& [sum_k__, n_] = get_shape<2>(b);
int num_groups, m, n, sum_k_, m_, sum_k__, n_;
std::tie(num_groups, m, n) = get_shape<3>(d);
std::tie(sum_k_, m_) = get_shape<2>(a);
std::tie(sum_k__, n_) = get_shape<2>(b);
const int sum_k = std::accumulate(ks.begin(), ks.end(), 0);
DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__);

Expand Down Expand Up @@ -566,9 +585,10 @@ static void cublaslt_gemm_nt(const torch::Tensor& a, const torch::Tensor& b,
const auto& major_b = get_major_type_ab(b);

// Type and shape checks
const auto& [m , k ] = get_shape<2>(a);
const auto& [n , k_] = get_shape<2>(b);
const auto& [m_, n_] = get_shape<2>(d);
int m, k, n, k_, m_, n_;
std::tie(m, k) = get_shape<2>(a);
std::tie(n, k_) = get_shape<2>(b);
std::tie(m_, n_) = get_shape<2>(d);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);

// Early return for trivial cases
Expand Down
5 changes: 3 additions & 2 deletions csrc/jit/cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ class KernelRuntimeCache {

std::shared_ptr<KernelRuntime> get(const std::filesystem::path& dir_path) {
// Hit the runtime cache
if (const auto& iterator = cache.find(dir_path); iterator != cache.end())
const auto& dir_path_str = dir_path.string();
if (const auto& iterator = cache.find(dir_path_str); iterator != cache.end())
return iterator->second;

if (KernelRuntime::check_validity(dir_path))
return cache[dir_path] = std::make_shared<KernelRuntime>(dir_path);
return cache[dir_path_str] = std::make_shared<KernelRuntime>(dir_path);
return nullptr;
}
};
Expand Down
Loading