DeepGEMM is a library designed for clean and efficient General Matrix Multiplications (GEMMs). It supports FP8 and BF16 (work in progress) for both normal and Mixture-of-Experts (MoE) grouped scenarios. Written in CUDA, the library requires no kernel compilation during installation: all kernels are compiled at runtime using a lightweight Just-In-Time (JIT) module.
While DeepGEMM leverages some concepts from CUTLASS and CuTe, it avoids heavy reliance on their templates or algebras. Instead, the library is designed for simplicity, with only a limited number of core kernel functions. This makes it a clean and accessible resource for learning NVIDIA GPU kernel optimization techniques.
Despite its lightweight design, DeepGEMM's performance matches or exceeds expert-tuned libraries across various matrix shapes.
NVRTC can be used as a faster JIT compiler: set `DG_JIT_USE_NVRTC=1` to enable it (may have performance loss with some cases). Per-shape kernel configurations are selected by the `get_best_configs` modeling; the chosen configs can be printed with `DG_PRINT_CONFIGS` (see the environment variables below).

The `{fmt}` library is required (it could be cloned as a Git submodule).

Development:

```bash
# Submodule must be cloned
git clone --recursive git@github.com:deepseek-ai/DeepGEMM.git
cd DeepGEMM

# Link some essential includes and build the CPP JIT module
cat develop.sh
./develop.sh

# Test all GEMM implementations
python tests/test_layout.py
python tests/test_core.py
```

Installation:

```bash
cat install.sh
./install.sh
```
Then, import `deep_gemm` in your Python project, and enjoy!
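As a quick sanity check after installation (a minimal sketch; `get_num_sms` is one of the utility functions documented below):

```python
import torch
import deep_gemm

# If the import succeeds, the JIT module is ready; kernels are compiled
# at runtime by the JIT module when first used.
print(torch.cuda.get_device_name())
print(deep_gemm.get_num_sms())  # device SM count unless a smaller limit has been set
```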
This library provides optimized GEMM kernels for NVIDIA GPUs with a naming convention: `D = C + A @ B`. The input layout is NT (non-transposed A, transposed B). While the SM90 implementation supports only the NT memory layout (row-major A, column-major B), the SM100 implementation supports all memory layouts (NT, TN, NN, TT). For example, `fp8_gemm_nt` computes `D = C + A @ B.T`.
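In plain PyTorch terms, the NT convention corresponds to the following shapes and reference computation (math only, not a call into the library):

```python
import torch

m, n, k = 4096, 7168, 2048

a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)  # LHS, row-major
b = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)  # RHS, stored (n, k), used transposed
c = torch.zeros(m, n, device="cuda", dtype=torch.bfloat16)  # optional accumulator

# "nt": the RHS is transposed inside the kernel, so the reference is
d_ref = c + a @ b.t()
```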
For both architectures, the LHS scaling factors are required to have a TMA-aligned and transposed layout, and the scaling-factor data format differs between SM90 and SM100: SM90 expects FP32 scaling factors, while SM100 expects UE8M0 scaling factors packed into `torch.int`. Please note that operations like input transposition or FP8 casting must be handled separately by the user; implement them yourself or fuse them into prior kernels. While the library provides some simple PyTorch utility functions for these steps, they may be slow, as our primary focus is on optimizing the GEMM kernels themselves.
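A quantization step therefore has to happen before calling the kernels. Below is a minimal sketch of fine-grained FP8 casting with one FP32 scale per 128 contiguous channels; the helper name and the 128-element granularity are illustrative assumptions, and the resulting scales still need to be put into the required layout (e.g. via `transform_sf_into_required_layout` or the TMA-alignment utilities listed below):

```python
import torch

def per_channel_group_cast_to_fp8(x: torch.Tensor, group_size: int = 128):
    """Cast an (m, k) BF16/FP32 tensor to FP8 E4M3 with one FP32 scale per
    `group_size` contiguous elements along the last dimension.
    Illustrative only; assumes k is divisible by group_size."""
    m, k = x.shape
    x = x.view(m, k // group_size, group_size)
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4)
    scale = (amax / 448.0).float()                 # 448 is the FP8 E4M3 maximum
    x_fp8 = (x / scale).to(torch.float8_e4m3fn).view(m, k)
    return x_fp8, scale.squeeze(-1)                # scales: (m, k // group_size), FP32
```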
Normal dense GEMMs (non-grouped): To perform a basic non-grouped FP8 GEMM, call the `fp8_gemm_{nt, nn, tn, tt}` function. For more details, please refer to the function documentation.
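A usage sketch follows. The exact argument structure, the (tensor, scales) pairing, and the BF16 output dtype are assumptions based on the conventions above rather than the documented signature, so please check the `fp8_gemm_nt` docstring; the quantization helper is the sketch from the scaling-factor section:

```python
import torch
import deep_gemm

m, n, k = 4096, 7168, 2048

# Quantize BF16 inputs to FP8 with fine-grained FP32 scales
# (per_channel_group_cast_to_fp8 is the illustrative helper sketched above;
# real weight quantization may use a coarser, e.g. blockwise, granularity).
a_fp8, a_sf = per_channel_group_cast_to_fp8(torch.randn(m, k, device="cuda", dtype=torch.bfloat16))
b_fp8, b_sf = per_channel_group_cast_to_fp8(torch.randn(n, k, device="cuda", dtype=torch.bfloat16))

# The LHS scales additionally need the TMA-aligned, transposed layout described above,
# here assumed to be obtainable from the documented alignment utility.
a_sf = deep_gemm.get_mn_major_tma_aligned_tensor(a_sf)

d = torch.empty(m, n, device="cuda", dtype=torch.bfloat16)

# Assumed calling convention: (tensor, scales) pairs for LHS/RHS and a preallocated output.
deep_gemm.fp8_gemm_nt((a_fp8, a_sf), (b_fp8, b_sf), d)
```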
Unlike traditional grouped GEMMs in CUTLASS, DeepGEMM groups only the M-axis, while N and K must remain fixed. This design is tailored for scenarios where experts in an MoE model share the same shape. For training forward passes or inference prefilling, where each expert may process a varying number of tokens, we concatenate these tokens into a single tensor, referred to as the "contiguous" layout. Note that each expert segment must be aligned to the GEMM M block size (`get_mk_alignment_for_contiguous_layout()`). For more information, please refer to the `m_grouped_fp8_gemm_{nt, nn}_contiguous` function documentation.
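The alignment requirement can be queried and applied ahead of time; the sketch below shows only that padding step (the per-expert token counts are made up, and how padded rows are tagged for the kernel is described in the function documentation):

```python
import deep_gemm

align = deep_gemm.get_mk_alignment_for_contiguous_layout()

# Hypothetical per-expert token counts for 4 experts sharing the same (N, K) weight shape.
tokens_per_expert = [300, 512, 77, 1024]

# Each expert's segment in the concatenated M-axis must be padded up to the alignment.
padded_per_expert = [(t + align - 1) // align * align for t in tokens_per_expert]
total_m = sum(padded_per_expert)
print(align, padded_per_expert, total_m)
```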
We also provide a K-axis-grouped API for MoE weight backward passes (where M and N must remain fixed); please refer to `k_grouped_fp8_gemm_tn_contiguous` for more information.
During the inference decoding phase, when CUDA graphs are enabled and the CPU is unaware of the number of tokens each expert receives, we support masked grouped GEMMs. By providing a mask tensor, the kernel computes only the valid portions. Use `m_grouped_fp8_gemm_nt_masked` for this purpose and consult the relevant documentation. An example usage is to take the output of the low-latency kernels from DeepEP as input.
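Conceptually, the mask records how many rows of each expert's fixed-size slot are valid, so only those rows are computed; the sketch below illustrates that idea only, and the tensor name, dtype, and how it is passed to `m_grouped_fp8_gemm_nt_masked` are assumptions to check against the documentation:

```python
import torch

num_experts, max_m = 8, 1024   # fixed per-expert capacity, as required under CUDA graphs

# Hypothetical per-expert valid-row counts; in practice these live on the GPU
# (e.g. produced by DeepEP's low-latency dispatch), so the CPU never reads them.
masked_m = torch.randint(0, max_m + 1, (num_experts,), dtype=torch.int32, device="cuda")

# The masked grouped kernel then computes only masked_m[e] rows of each expert's slot.
```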
The library provides some utility functions besides the above kernels:

- `deep_gemm.set_num_sms`: set the maximum SM count to use
- `deep_gemm.get_num_sms`: get the current maximum SM count (returns the device SM count if not set)
- `deep_gemm.set_tc_util`: set an approximated tensor core utilization ratio
- `deep_gemm.get_tc_util`: get the current tensor core utilization ratio
- `deep_gemm.transform_sf_into_required_layout`: transform scaling factors into the required layout
- `deep_gemm.get_tma_aligned_size`: get the required TMA alignment size
- `deep_gemm.get_mk_alignment_for_contiguous_layout`: get the group-level alignment requirement for the grouped contiguous layout
- `deep_gemm.get_mn_major_tma_aligned_tensor`: get an MN-major TMA-aligned tensor
- `deep_gemm.get_mn_major_tma_aligned_packed_ue8m0_tensor`: get an MN-major TMA-aligned tensor (with FP32 packed into UE8M0)
- `deep_gemm.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor`: the packing kernel for K-grouped GEMMs
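For example (a small sketch using only the functions listed above; the SM count is arbitrary):

```python
import deep_gemm

# Cap kernels at 112 SMs, e.g. to leave SMs free for overlapped communication kernels.
deep_gemm.set_num_sms(112)
print(deep_gemm.get_num_sms())  # 112 once a limit has been set

# Group-level M alignment required by the contiguous grouped layout.
print(deep_gemm.get_mk_alignment_for_contiguous_layout())
```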
The library also provides some environment variables, which may be useful:

- `DG_JIT_DEBUG`: `0` or `1`, print more JIT debugging information, `0` by default
- `DG_JIT_CACHE_DIR`: string, the cache directory for storing compiled kernels, `$HOME/.deep_gemm` by default
- `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC; faster compilation, but may have lower performance for some cases, `0` by default
- `DG_JIT_NVCC_COMPILER`: string, the NVCC compiler path to use; found via `torch.utils.cpp_extension.CUDA_HOME` by default
- `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS compiler output, `0` by default
- `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print the NVCC compilation command, `0` by default
- `DG_PRINT_CONFIGS`: `0` or `1`, print the selected configs for each shape, `0` by default

For additional examples and details, please refer to the test code or review the corresponding Python documentation.
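For instance, enabling a couple of these variables is a quick way to see what the JIT layer is doing (a sketch; since it is not specified here whether the variables are read at import time or at first compilation, setting them before the import is the conservative choice):

```python
import os

# Enable JIT debug output and print the selected config for each GEMM shape.
os.environ["DG_JIT_DEBUG"] = "1"
os.environ["DG_PRINT_CONFIGS"] = "1"

import deep_gemm  # noqa: E402  (imported after the environment is configured)
```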
DeepGEMM is inspired by the CUTLASS project. Thanks and respect to the developers!
This code repository is released under the MIT License.