
ayrnb/LeetCUDA: 📚LeetCUDA: Modern CUDA Learn Notes with PyTorch for Beginners🐑, 200+ CUDA/Tensor Cores Kernels, HGEMM, FA-2 MMA etc.🔥

📚 LeetCUDA: Modern CUDA Learn Notes with PyTorch for Beginners 🐑

📚 Modern CUDA Learn Notes with PyTorch for Beginners: it covers Tensor/CUDA Cores and TF32/F16/BF16/F8, with 📖200+ CUDA Kernels🔥🔥 (Easy -> Hard++) with PyTorch bindings, 📖100+ LLM/VLM/CV/CUDA/CuTe🔥 blogs, 📖toy-hgemm⚡️⚡️, which can achieve 98%~100% of cuBLAS performance, and 📖flash-attention-mma⚡️⚡️, which uses Tensor Cores with pure MMA PTX. If you find it useful, please 🌟👆🏻star this repo to support me. Many thanks ~ 🎉🎉

🔥🔥 PR Welcome: Add Your Kernel to LeetCUDA! Let's make it Awesome together! 🎉🎉

Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, the HGEMM (WMMA/MMA/CuTe) kernels in this repo (blue🔵) achieve 98%~100% of the performance of cuBLAS's default Tensor Cores algorithm (orange🟠). Please check the toy-hgemm library⚡️⚡️ or the hgemm-tensorcores-mma⚡️⚡️ repo for more details.
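The HGEMM kernels are built on block tiling with a loop over K. As a host-side reference only (float instead of half, no Tensor Cores, no async copies), a minimal C++ sketch of that tiling structure could look like the following. `naive_gemm`, `tiled_gemm` and the tile sizes `BM`/`BN`/`BK` are illustrative names, not this repo's API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Plain reference: C = A * B with row-major A (M x K), B (K x N), C (M x N).
std::vector<float> naive_gemm(const std::vector<float>& A,
                              const std::vector<float>& B, int M, int N, int K) {
  assert(A.size() == static_cast<std::size_t>(M) * K);
  assert(B.size() == static_cast<std::size_t>(K) * N);
  std::vector<float> C(static_cast<std::size_t>(M) * N, 0.0f);
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j) {
      float sum = 0.0f;
      for (int k = 0; k < K; ++k) sum += A[i * K + k] * B[k * N + j];
      C[i * N + j] = sum;
    }
  return C;
}

// Blocked version: each BM x BN output tile plays the role of one thread
// block, and K is consumed BK columns at a time, the way a CUDA kernel would
// stage A/B tiles through shared memory on each iteration of the K loop.
template <int BM, int BN, int BK>
std::vector<float> tiled_gemm(const std::vector<float>& A,
                              const std::vector<float>& B, int M, int N, int K) {
  assert(A.size() == static_cast<std::size_t>(M) * K);
  assert(B.size() == static_cast<std::size_t>(K) * N);
  std::vector<float> C(static_cast<std::size_t>(M) * N, 0.0f);
  for (int bm = 0; bm < M; bm += BM)
    for (int bn = 0; bn < N; bn += BN) {
      float acc[BM][BN] = {};                    // accumulators ("registers")
      for (int bk = 0; bk < K; bk += BK) {       // loop over K
        float As[BM][BK] = {}, Bs[BK][BN] = {};  // "shared memory" tiles
        for (int i = 0; i < BM; ++i)             // load A tile, zero-pad edges
          for (int k = 0; k < BK; ++k)
            if (bm + i < M && bk + k < K) As[i][k] = A[(bm + i) * K + bk + k];
        for (int k = 0; k < BK; ++k)             // load B tile, zero-pad edges
          for (int j = 0; j < BN; ++j)
            if (bk + k < K && bn + j < N) Bs[k][j] = B[(bk + k) * N + bn + j];
        for (int i = 0; i < BM; ++i)             // multiply the two tiles
          for (int j = 0; j < BN; ++j)
            for (int k = 0; k < BK; ++k) acc[i][j] += As[i][k] * Bs[k][j];
      }
      for (int i = 0; i < BM; ++i)               // write back, skipping padding
        for (int j = 0; j < BN; ++j)
          if (bm + i < M && bn + j < N) C[(bm + i) * N + bn + j] = acc[i][j];
    }
  return C;
}
```

In a real kernel the inner tile loops become per-thread or per-MMA work and `As`/`Bs` live in `__shared__` memory; the host version is mainly useful as a CPU reference for shapes that are not multiples of the tile sizes.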

📚Feature | 📚Feature | 📚Feature | 📚Feature
✔️CUDA/Tensor Cores | ✔️Loop over K | ✔️Tile Block(BMxBK) | ✔️Tile Threads(T 8x8)
✔️WMMA(m16n16k16) | ✔️MMA(m16n8k16) | ✔️Pack LDST(128 bits) | ✔️SMEM Padding
✔️Copy Async | ✔️Tile MMAs | ✔️Tile Warps | ✔️Multi Stages(2~4)
✔️Register Double Buffers | ✔️Block Swizzle | ✔️Warp Swizzle | ✔️SMEM Swizzle(CuTe/MMA)
✔️Collective Store(Shfl) | ✔️Layout NN | ✔️Layout TN | ✔️SGEMM FP32/TF32

📖 FA2-MMA Benchmark 🎉🎉

I have also implemented FlashAttention-2 using pure MMA PTX instructions, which supports features such as Multi-Stages, Tile MMA, Tile Warp, Shared KV SMEM, Fully Shared QKV SMEM, Prefetch Q s2r, Prefetch K/V g2s, QKV Fine-grained Tiling, Collective Store, etc. Please refer to flash-attention-mma⚡️⚡️ for more details.

📚Feature | 📚Feature | 📚Feature | 📚Feature
✔️Tensor Cores | ✔️Loop over N/D | ✔️Tile Block(Br, Bc) | ✔️MMA(m16n8k16)
✔️Pack LDST(128 bits) | ✔️SMEM Swizzle/Padding | ✔️Copy Async | ✔️Tile MMAs
✔️Tile Warps | ✔️Multi Stages(1/2) | ✔️Collective Store(Shfl) | ✔️Split KV/Q
✔️Shared QKV SMEM | ✔️Prefetch Q s2r | ✔️Prefetch KV g2s | ✔️QKV Fine-grained Tiling
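FlashAttention-2 avoids materializing the full N x N score matrix by visiting KV in blocks of Bc and keeping a running row max and row sum (online softmax). The host-side C++ sketch below shows that recurrence for a single query row in plain float arithmetic; it is not the MMA kernel, and `attention_row_naive` / `attention_row_online` are hypothetical names. It can serve as a CPU reference when checking kernel outputs.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// Scores s_j = (q . k_j) / sqrt(d) for keys j in [j0, j1); K is row-major [N][d].
static std::vector<float> scores(const std::vector<float>& q, const std::vector<float>& K,
                                 int d, int j0, int j1) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(d));
  std::vector<float> s(j1 - j0);
  for (int j = j0; j < j1; ++j) {
    float dot = 0.0f;
    for (int c = 0; c < d; ++c) dot += q[c] * K[j * d + c];
    s[j - j0] = dot * scale;
  }
  return s;
}

// Reference: softmax over all N scores at once, then weight the rows of V.
std::vector<float> attention_row_naive(const std::vector<float>& q, const std::vector<float>& K,
                                       const std::vector<float>& V, int N, int d) {
  std::vector<float> s = scores(q, K, d, 0, N);
  const float m = *std::max_element(s.begin(), s.end());
  float l = 0.0f;
  std::vector<float> o(d, 0.0f);
  for (int j = 0; j < N; ++j) {
    const float p = std::exp(s[j] - m);
    l += p;
    for (int c = 0; c < d; ++c) o[c] += p * V[j * d + c];
  }
  for (float& x : o) x /= l;
  return o;
}

// Online softmax: visit keys in blocks of Bc, keeping only the running max m,
// the running sum l and an unnormalized output o; rescale them when m grows.
std::vector<float> attention_row_online(const std::vector<float>& q, const std::vector<float>& K,
                                        const std::vector<float>& V, int N, int d, int Bc) {
  assert(N > 0 && Bc > 0);
  float m = -std::numeric_limits<float>::infinity();
  float l = 0.0f;
  std::vector<float> o(d, 0.0f);
  for (int j0 = 0; j0 < N; j0 += Bc) {
    const int j1 = std::min(N, j0 + Bc);
    std::vector<float> s = scores(q, K, d, j0, j1);  // one S = Q@K^T tile
    const float m_new = std::max(m, *std::max_element(s.begin(), s.end()));
    const float alpha = std::exp(m - m_new);          // exp(-inf) = 0 on the first block
    l *= alpha;
    for (float& x : o) x *= alpha;
    for (int j = j0; j < j1; ++j) {                   // P@V for this tile
      const float p = std::exp(s[j - j0] - m_new);
      l += p;
      for (int c = 0; c < d; ++c) o[c] += p * V[j * d + c];
    }
    m = m_new;
  }
  for (float& x : o) x /= l;                          // final 1/l normalization
  return o;
}
```

Deferring the division by `l` to the very end is what lets the kernel keep only O(d) state per row instead of a full row of probabilities.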

Currently, for small-scale attention (B <= 4, H <= 48, SeqLen <= 8192, D <= 64), it can run faster than FA2/SDPA on some devices. For example, on NVIDIA RTX 3080 Laptop, the 📚 Split Q + Fully Shared QKV SMEM method achieves 55 TFLOPS (D=64), ~1.5x 🎉 faster than FA2. On NVIDIA L20, the 🤖ffpa-attn-mma method achieves 104 TFLOPS (D=512), ~1.8x 🎉 faster than SDPA (EFFICIENT ATTENTION). However, a performance gap remains for large-scale attention. Stay tuned for updates ~ (Settings: MMA Acc F16/F32 and softmax Acc F32 here, vs. MMA/softmax Acc F32 in FA2; see the 👇Benchmark.)

Algorithm | (B,H,N,D) | RTX 3080 Laptop (TFLOPS) | L20 (TFLOPS) | RTX 4090 (TFLOPS)
FlashAttention-2 | (1,8,8192,64) | 37 | 100 | 145
share-qkv+stage2 | (1,8,8192,64) | 55 | 99 | 221
FlashAttention-2 | (1,48,8192,64) | 37 | 109 | 163
share-qkv+stage2 | (1,48,8192,64) | 48 | 107 | 224
SDPA(EFFICIENT ATTENTION) | (1,48,8192,512) | 16 | 58 | 85
🤖ffpa-attn-mma | (1,48,8192,512) | 39 | 104 | 200

Precision errors vs FA2/SDPA: max < ~1e-3, min ~0.0, mean < ~1e-5.
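TFLOPS numbers depend on how FLOPs are counted. A common convention for the non-causal attention forward pass counts only the two GEMMs, S = Q@K^T and O = P@V, for 4 * B * H * N^2 * D FLOPs in total; it is an assumption here that the benchmark above uses the same convention. A small C++ helper under that assumption (function names are hypothetical):

```cpp
#include <cassert>
#include <cstdint>

// FLOPs of a non-causal attention forward pass, counting only the two GEMMs
// (each 2 * N * N * D per batch and head); softmax FLOPs are ignored.
std::uint64_t attention_flops(std::uint64_t B, std::uint64_t H,
                              std::uint64_t N, std::uint64_t D) {
  return 4 * B * H * N * N * D;
}

// Achieved TFLOPS for a measured kernel time in milliseconds.
double achieved_tflops(std::uint64_t flops, double time_ms) {
  assert(time_ms > 0.0);
  return static_cast<double>(flops) / (time_ms * 1e-3) / 1e12;
}

// Kernel time in milliseconds implied by a TFLOPS figure.
double implied_time_ms(std::uint64_t flops, double tflops) {
  assert(tflops > 0.0);
  return static_cast<double>(flops) / (tflops * 1e12) * 1e3;
}
```

Under this convention, (1,8,8192,64) is 2^37 FLOPs, so 55 TFLOPS corresponds to roughly 2.5 ms per forward call, and the 55 vs 37 TFLOPS ratio (~1.49x) is the ~1.5x quoted above.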

Both Split KV and Split Q are implemented in flash-attention-mma⚡️⚡️ for performance comparison. The Split KV method, which splits all of Q, K and V across MMAs (warps), is slower than the Split Q method, which splits only Q across MMAs (warps) while every MMA (warp) keeps access to all of KV.

```cpp
#include <cuda_fp16.h>  // half

// Declarations only: "..." elides the remaining kernel parameters.

// Split QKV across MMA(Warps) using naive matmul MMA&Warp tiling policy.
// case: MMA = m16n8k16, 8 MMAs(Warps) in a 2x4 layout, each warp tiled
// kWarpTileSeqLenQ x kWarpTileSeqLenK (2x2) -> (16x2)x2=64 rows, (8x2)x4=64 cols:
// |  [64,64]  |    warp_KV 0    |    warp_KV 1    |    warp_KV 2    |    warp_KV 3    |
// | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --|
// | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --|
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_kv_kernel(half* Q, half* K, half* V, half* O, ...);

// Split Q across MMA(Warps) and let every MMA(Warp) access all of KV,
// in order to reduce communication between warps via SMEM and warp shuffle.
// case: MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps
// |   64x64   |      warp_KV 0       |
// | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
// | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
// | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_kernel(half* Q, half* K, half* V, half* O, ...);

// K and V share the same shared memory buffer, which improves block occupancy.
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, half* K, half* V, half* O, ...);

// Q, K and V fully share the same shared memory buffer, and Q is prefetched s2r
// (SMEM to registers), which improves block occupancy and reduces Q SMEM accesses.
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, half* K, half* V, half* O, ...);

// Fine-grained tiling at the MMA level for Q@K^T results in a constant SRAM usage of
// 64 * kMmaAtomK for Q and K. For V, the SRAM complexity is O(kMmaAtomK * d), leading to
// an overall SRAM complexity of O(kMmaAtomK * d). Consequently, this approach allows us to
// extend D (head dimension) up to 1024.
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q, half* K, half* V, half* O, ...);

// Fine-grained tiling at the MMA level for both Q@K^T and P@V results in a constant SRAM
// usage of Br * 16 or Bc * 16 for each of Q, K, V, leading to an overall SRAM complexity
// of O(Br * 16). Consequently, this approach can run faster than SDPA with or without
// MMA Acc F32.
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_tiling_qkv_kernel(half* Q, half* K, half* V, half* O, ...);
```
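One way to see why Split Q needs less inter-warp communication is to compute which part of the 64x64 S = Q@K^T tile each warp owns. The sketch below encodes the two layouts as read from the diagrams above, assuming the warp index equals the MMA number shown in each cell; `WarpTile`, `split_kv_tile`, `split_q_tile` and `warps_sharing_row` are hypothetical helpers. In Split KV, each row of S is spread over 4 warps, so the softmax row max and row sum must be merged across warps (e.g. through SMEM); in Split Q, one warp owns whole rows and can reduce with warp shuffles alone.

```cpp
#include <cassert>

// Rectangle of S = Q @ K^T (rows: Q positions, cols: K positions) owned by one
// warp inside a 64x64 block tile; ranges are half-open [begin, end).
struct WarpTile {
  int row_begin, row_end;
  int col_begin, col_end;
};

constexpr int kMmaM = 16, kMmaN = 8;  // m16n8k16 MMA output tile is 16x8

// Split KV: 8 warps in a 2 (warp_QP) x 4 (warp_KV) grid, each warp repeating its
// MMA 2x2 times. Per the diagram, warp w sits at warp_QP = w % 2, warp_KV = w / 2.
WarpTile split_kv_tile(int warp) {
  assert(warp >= 0 && warp < 8);
  const int kTileQ = 2, kTileK = 2;
  const int qp = warp % 2, kv = warp / 2;
  const int rows = kMmaM * kTileQ, cols = kMmaN * kTileK;  // 32 x 16 per warp
  return {qp * rows, (qp + 1) * rows, kv * cols, (kv + 1) * cols};
}

// Split Q: 4 warps stacked along Br, each repeating its MMA 8 times along Bc,
// so every warp owns 16 complete rows of S (all 64 columns).
WarpTile split_q_tile(int warp) {
  assert(warp >= 0 && warp < 4);
  const int kTileK = 8;
  return {warp * kMmaM, (warp + 1) * kMmaM, 0, kMmaN * kTileK};
}

// Number of warps holding part of a given row of S; the row-wise softmax
// statistics (max, sum) must be combined across all of them.
template <typename TileFn>
int warps_sharing_row(TileFn tile_of, int num_warps, int row) {
  int n = 0;
  for (int w = 0; w < num_warps; ++w) {
    const WarpTile t = tile_of(w);
    if (row >= t.row_begin && row < t.row_end) ++n;
  }
  return n;
}
```

For example, `warps_sharing_row(split_kv_tile, 8, 10)` counts 4 warps, while `warps_sharing_row(split_q_tile, 4, 10)` counts 1.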

💡NOTE: 📚Split Q + Fully QKV Fine-grained Tiling has been refactored into 🤖ffpa-attn-mma.
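The SRAM claims in the kernel comments can be made concrete with a rough footprint calculation. The sketch assumes half-precision tiles (2 bytes per element), Br = Bc = 64, kMmaAtomK = 16, a single pipeline stage and no padding; the real kernels add stage buffers and padding/swizzle, so these are lower-bound estimates, not measured values. The function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>

// Rough single-stage shared-memory footprints (bytes) for one thread block,
// assuming half-precision tiles and no padding or extra pipeline stages.
constexpr std::size_t kHalfBytes = 2;

// Whole-D tiles: Q[Br][d], K[Bc][d], V[Bc][d]; grows linearly with d.
std::size_t smem_full_tiles(std::size_t Br, std::size_t Bc, std::size_t d) {
  return (Br * d + Bc * d + Bc * d) * kHalfBytes;
}

// Fine-grained Q@K^T tiling: Q[Br][k] and K[Bc][k] slices of width
// k = kMmaAtomK, while V is staged as [k][d]; O(k * d) overall.
std::size_t smem_tiling_qk(std::size_t Br, std::size_t Bc, std::size_t d, std::size_t k) {
  assert(k > 0 && "kMmaAtomK must be positive");
  return (Br * k + Bc * k + k * d) * kHalfBytes;
}

// Fine-grained tiling for both Q@K^T and P@V: Q[Br][16], K[Bc][16], V[Bc][16];
// independent of d.
std::size_t smem_tiling_qkv(std::size_t Br, std::size_t Bc) {
  return (Br * 16 + Bc * 16 + Bc * 16) * kHalfBytes;
}
```

At d = 1024, whole-D tiles would need 393216 bytes (384 KiB), far above the 99 KB per-block shared memory limit of compute capability 8.6/8.9 devices such as the RTX 3080, RTX 4090 and L20, whereas the Q@K^T tiling needs about 36 KiB and the fully tiled variant stays at 6 KiB for any d.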

@misc{LeetCUDA@2024,
  title={LeetCUDA: A Modern CUDA Learn Notes with PyTorch for Beginners},
  url={https://github.com/xlite-dev/LeetCUDA},
  note={Open-source software available at https://github.com/xlite-dev/LeetCUDA},
  author={xlite-dev etc},
  year={2024}
}
📖 200+ CUDA Kernels 🔥🔥 (Easy -> Hard++)

The kernels listed here guide you step by step from easy to very challenging topics. Each topic follows the same workflow: custom CUDA kernel implementation -> PyTorch Python bindings -> run tests. 👉TIPS: * = Tensor Cores (WMMA, MMA, CuTe), otherwise CUDA Cores; / = not supported; ✔️ = supported; = TODO. The contents are summarized as follows:

The 📚 Easy and 📚 Medium sections cover operations such as element-wise, mat_trans, warp/block reduce, nms, relu, gelu, swish, layer-norm, rms-norm, online-softmax, dot-prod and embedding, as well as basic usage of FP32, FP16, BF16 and FP8. The 📚 Hard, 📚 Hard+ and 📚 Hard++ sections delve into advanced topics, primarily sgemv, sgemm, hgemv, hgemm and flash-attention. These sections also provide numerous kernels implemented using Tensor Cores with pure MMA PTX.
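Warp/block reduce kernels, and the row max/sum in online softmax, typically rely on a butterfly reduction with `__shfl_xor_sync`. The host-side C++ model below simulates the 32 lanes of a warp as an array so the data movement can be checked on the CPU; `WarpRegs`, `shfl_xor`, `warp_reduce_sum` and `warp_reduce_max` are illustrative names, not functions from this repo.

```cpp
#include <algorithm>
#include <array>
#include <cassert>

constexpr int kWarpSize = 32;
using WarpRegs = std::array<float, kWarpSize>;  // one register value per lane

// Host model of __shfl_xor_sync(0xffffffff, v, mask): at the same step, every
// lane reads the value held by lane (lane ^ mask).
WarpRegs shfl_xor(const WarpRegs& v, int mask) {
  assert(mask > 0 && mask < kWarpSize);
  WarpRegs out{};
  for (int lane = 0; lane < kWarpSize; ++lane) out[lane] = v[lane ^ mask];
  return out;
}

// Butterfly sum: after log2(32) = 5 steps every lane holds the warp total.
WarpRegs warp_reduce_sum(WarpRegs v) {
  for (int mask = kWarpSize / 2; mask >= 1; mask >>= 1) {
    const WarpRegs other = shfl_xor(v, mask);
    for (int lane = 0; lane < kWarpSize; ++lane) v[lane] += other[lane];
  }
  return v;
}

// Same butterfly with max, as used for a softmax row max.
WarpRegs warp_reduce_max(WarpRegs v) {
  for (int mask = kWarpSize / 2; mask >= 1; mask >>= 1) {
    const WarpRegs other = shfl_xor(v, mask);
    for (int lane = 0; lane < kWarpSize; ++lane) v[lane] = std::max(v[lane], other[lane]);
  }
  return v;
}
```

Because the XOR pattern is symmetric, every lane ends up with the result, so no extra broadcast is needed; a block reduce usually applies this per warp, writes one value per warp to SMEM, and reduces those values again with one warp.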

📚 Hard+ ⭐️⭐️⭐️⭐️ & Hard++ ⭐️⭐️⭐️⭐️⭐️

💡NOTE: rr: reduced register usage (for d > 128); f32: the MMA accumulates in FP32, otherwise in FP16. The softmax accumulator is always FP32 for high precision; swizzle: currently, only SMEM swizzle for MMA is supported.
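As an illustration of SMEM swizzle, the sketch below models a generic XOR swizzle for tiles whose rows hold 64 half values (128 bytes, i.e. eight 16-byte chunks accessed with 128-bit loads). It assumes an access pattern in which 8 threads read the same logical chunk from 8 consecutive rows; whether the repo's kernels use exactly this mapping is an assumption, and the function names are hypothetical. Without swizzling all 8 reads land on the same 4 banks; XOR-ing the chunk index with `row % 8` spreads them over all 32 banks.

```cpp
#include <cassert>

// A 128-byte row = 8 chunks of 16 bytes. With 32 banks x 4 bytes, a 16-byte
// chunk covers 4 banks, so a "bank group" in [0, 8) names those 4 banks.
constexpr int kChunksPerRow = 8;

// Generic XOR swizzle: physical chunk = logical chunk ^ (row % 8).
int swizzle_chunk(int row, int chunk) {
  assert(chunk >= 0 && chunk < kChunksPerRow);
  return chunk ^ (row % kChunksPerRow);
}

// Bank group touched by (row, physical chunk) when rows are 128 bytes apart.
int bank_group(int row, int physical_chunk) {
  const int byte_offset = row * 128 + physical_chunk * 16;
  return (byte_offset / 16) % kChunksPerRow;
}

// Largest number of rows in [0, 8) whose logical `chunk` hits the same bank
// group: 8 means the reads are fully serialized, 1 means conflict-free.
int conflict_degree(int chunk, bool swizzled) {
  int hits[kChunksPerRow] = {};
  int worst = 0;
  for (int row = 0; row < 8; ++row) {
    const int phys = swizzled ? swizzle_chunk(row, chunk) : chunk;
    const int g = bank_group(row, phys);
    if (++hits[g] > worst) worst = hits[g];
  }
  return worst;
}
```

Unlike padding, the XOR mapping is a permutation within each row, so it removes the conflicts without spending any extra shared memory; the trade-off is that every load and store address must apply the same mapping.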

💡NOTE: 🤖ffpa-attn-mma: 📚FFPA - Yet another Faster Flash Prefill Attention with O(1) 🎉SRAM complexity for headdim > 256, 1.8x~3x 🎉 faster than SDPA EA: 📈L20 ~1.9x↑🎉, 📈A30 ~1.8x↑🎉, 📈3080 ~2.9x↑🎉, 📈4090 ~2.1x↑🎉.

💡NOTE: This subsection collects some articles I particularly like. PRs recommending more excellent articles are welcome!

GNU General Public License v3.0

How to contribute? Star this repo or check 🌤🌤CONTRIBUTE🎉🎉.
