FFPA(Split-D): Yet another Faster Flash Prefill Attention with Split-D strategy, achieving O(1) SRAM complexity and O(d/4) or O(1) register complexity for large headdim (D > 256), almost 1.8x~3x faster than SDPA EA with or without MMA Acc F32 on many devices: L20 ~1.9x↑, A30 ~1.8x↑, 3080 ~2.9x↑, 4090 ~2.1x↑. FFPA Algo: fine-grained tiling for large headdim; FA-2 Algo: coarse-grained tiling for small headdim.
```bibtex
@misc{ffpa-attn@2025,
  title={FFPA: Yet another Faster Flash Prefill Attention for large headdim.},
  url={https://github.com/xlite-dev/ffpa-attn.git},
  note={Open-source software available at https://github.com/xlite-dev/ffpa-attn.git},
  author={xlite-dev etc},
  year={2025}
}
```
We have extended FlashAttention for large headdim (D > 256) by implementing fine-grained tiling at the MMA level (GEMM style) for the Q@K^T and P@V matmuls. This approach results in constant SRAM usage of Br * 16 or Bc * 16 (Br = Bc) for Q, K, and V, leading to an overall SRAM complexity of O(2 * Br * 16) ≈ O(1) and a register complexity of O(d/4) or O(1). Consequently, this method allows us to extend headdim beyond 256 and achieve faster performance than SDPA with or without MMA Accumulation F32 (1.8x~3x faster than SDPA EA).
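For readers who want to see the tiling structure without reading the CUDA, here is a minimal NumPy sketch (our illustration, not code from this repo) of the Split-D idea: both Q@K^T and P@V walk the head dimension in chunks of 16, matching the MMA atom K/N size, so only [Br, 16] / [Bc, 16] slices of Q, K, and V need to be resident at any time, regardless of how large d is. It assumes N is a multiple of Br = Bc and d is a multiple of 16.

```python
import numpy as np

def ffpa_split_d_reference(Q, K, V, Br=64, Bc=64, dk=16):
    # Reference for one (batch, head); assumes N % Br == N % Bc == 0 and d % dk == 0.
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d), dtype=np.float32)
    for i in range(0, N, Br):
        q = Q[i:i + Br].astype(np.float32)          # [Br, d]
        m = np.full(Br, -np.inf)                    # running row-wise max
        l = np.zeros(Br)                            # running softmax denominator
        acc = np.zeros((Br, d))                     # O accumulator (registers in the kernel)
        for j in range(0, N, Bc):
            k_blk = K[j:j + Bc].astype(np.float32)  # [Bc, d]
            v_blk = V[j:j + Bc].astype(np.float32)  # [Bc, d]
            # Q@K^T, tiled along headdim in chunks of dk=16: only [Br,16]/[Bc,16]
            # slices would live in SRAM at a time in the CUDA kernel.
            S = np.zeros((Br, Bc))
            for t in range(0, d, dk):
                S += q[:, t:t + dk] @ k_blk[:, t:t + dk].T
            S *= scale
            # Online softmax rescaling (FlashAttention-2 style).
            m_new = np.maximum(m, S.max(axis=1))
            P = np.exp(S - m_new[:, None])
            alpha = np.exp(m - m_new)
            l = alpha * l + P.sum(axis=1)
            acc *= alpha[:, None]
            # P@V, also tiled along headdim in chunks of dk=16.
            for t in range(0, d, dk):
                acc[:, t:t + dk] += P @ v_blk[:, t:t + dk]
            m = m_new
        O[i:i + Br] = acc / l[:, None]
    return O
```

The CUDA kernels implement the same math with MMA PTX, shared QKV SMEM and multi-stage pipelines; the sketch only mirrors the loop structure and the online softmax.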
We have named this new attention tiling technique FFPA: Faster Flash Prefill Attention. We have designed three levels (L1~L3) of FFPA based on SRAM and register complexity considerations. None of the levels introduces any additional VRAM requirements, so the HBM memory complexity remains the same as FlashAttention.
By leveraging this approach, we can achieve better performance than SDPA EA for very large headdim (D > 256), which FA-2 does not support. Approximate SRAM and register complexity for the FFPA L1~L3 levels (d = headdim; C, Br, Bc are constants; Br = Bc; let O(C) ≈ O(1)): the SRAM complexity is O(2 * Br * 16) ≈ O(1) and the register complexity is O(d/4) or O(1), depending on the level.
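As a back-of-the-envelope illustration of the O(2 * Br * 16) ≈ O(1) claim, the sketch below (our numbers, assuming fp16 storage, Br = Bc = 64, a single stage, and ignoring padding/swizzle and multi-stage buffers) compares the shared-memory footprint of coarse-grained FA-2 style tiles against FFPA's fine-grained MMA-level tiles:

```python
# Shared-memory footprint sketch: coarse-grained (FA-2 style) vs fine-grained
# (FFPA style) tiling. Assumes fp16 (2 bytes), Br = Bc = 64, a single stage,
# and ignores padding/swizzle and multi-stage buffers.
Br = Bc = 64
bytes_per_elem = 2  # half

def smem_coarse(d):
    # Whole Q[Br, d] + K[Bc, d] + V[Bc, d] tiles resident in SRAM.
    return (Br * d + 2 * Bc * d) * bytes_per_elem

def smem_fine(d_tile=16):
    # Only [Br, 16] and [Bc, 16] slices resident at a time: O(2 * Br * 16) ≈ O(1).
    return (Br * d_tile + Bc * d_tile) * bytes_per_elem

for d in (64, 256, 512, 1024):
    print(f"d={d:4d}: coarse ≈ {smem_coarse(d) / 1024:7.1f} KB, "
          f"fine ≈ {smem_fine() / 1024:5.1f} KB")
# d= 1024: coarse needs hundreds of KB (beyond on-chip SRAM), fine stays at ~4 KB.
```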
Core Features: I have implemented FFPA L1~L3 using pure MMA PTX instructions, which support many features such as Split-Q, SMEM Swizzle/Padding, QKV Multi-Stages (1~4), Tile MMAs/Warps, Mixed MMA F32/F16 Acc (Q@K^T MMA Acc F32 + P@V MMA Acc F16), Fully Shared QKV SMEM, Prefetch QKV g2s, Persist Q s2r/g2s, Fully QKV Fine-grained Tiling (GEMM style), Collective Store, etc.
| Feature | Feature | Feature | Feature |
|---|---|---|---|
| ✔️ Tensor Cores | ✔️ MMA (m16n8k16) | ✔️ Tile Block (Br, Bc) | ✔️ Tile MMA/Warp |
| ✔️ Split Q (FA-2) | ✔️ Pack LDST (128 bits) | ✔️ SMEM Swizzle/Pad | ✔️ Copy Async |
| ✔️ Reg Double Buffers | ✔️ QKV Multi-Stages (1~4) | ✔️ Collective Store (Shfl) | ✔️ Prefetch QKV g2s |
| ✔️ QKV Fine-grained Tiling | ✔️ Shared QKV SMEM | ✔️ Mixed MMA Acc | ✔️ Persist Q s2r/g2s |
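To give a feel for the Mixed MMA Acc feature (Q@K^T Acc F32 + P@V Acc F16), the NumPy sketch below (our illustration; `simulate_f16_acc` is a hypothetical helper, not part of the repo) keeps the raw attention scores in fp32, where softmax is sensitive to their range, and emulates an fp16 accumulator only for P@V, whose rows are convex combinations with weights in [0, 1]:

```python
import numpy as np

def simulate_f16_acc(A, B):
    # Hypothetical helper: emulate an fp16 accumulator by rounding the running
    # sum to float16 after every k-step, mimicking an MMA with a half accumulator.
    acc = np.zeros((A.shape[0], B.shape[1]), dtype=np.float16)
    for k in range(A.shape[1]):
        acc = (acc + np.outer(A[:, k], B[k]).astype(np.float16)).astype(np.float16)
    return acc.astype(np.float32)

rng = np.random.default_rng(0)
Br, Bc, d = 64, 64, 320
q = rng.standard_normal((Br, d)).astype(np.float32)
k = rng.standard_normal((Bc, d)).astype(np.float32)
v = rng.standard_normal((Bc, d)).astype(np.float32)

S = (q @ k.T) / np.sqrt(d)                       # Q@K^T accumulated in fp32
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)                # rows of P are in [0, 1], sum to 1

o_ref = P @ v                                    # P@V accumulated in fp32
o_mix = simulate_f16_acc(P.astype(np.float16), v.astype(np.float16))
print("max abs error with fp16 P@V accumulation:", np.abs(o_ref - o_mix).max())
```

Because each row of P sums to 1, the fp16-accumulated P@V stays well within half range and close to the fp32 result; this is only an illustration of the precision trade-off, while the actual kernels apply it at the MMA accumulator level.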
L1 kernel template signature (ffpa_attn_templates_L1.cuh):

```cuda
template<
  const int kHeadDim,            // Headdim, 32~1024
  const int kMmaAtomM,           // MMA Atom M, 16
  const int kMmaAtomN,           // MMA Atom N, 8
  const int kMmaAtomK,           // MMA Atom K, 16
  const int kMmaTileSeqLenQ,     // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
  const int kMmaTileSeqLenK,     // 1, more MMA(warp), N=8*1 =8,  Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
  const int kMmaTileSeqLenP,     // 4, more MMA(warp), M=16*4=64, P@V  =[Br(M), Bc(K)]@[Bc(K), d(N)]
  const int kMmaTileHeadDimV,    // 1, more MMA(warp), N=8*1 =8,  P@V  =[Br(M), Bc(K)]@[Bc(K), d(N)]
  const int kWarpTileSeqLenQ,    // 1, more values, M, Br=64*1=64, matmul M
  const int kWarpTileSeqLenK,    // 8, more values, N, Bc=8*8 =64, matmul N
  const int kWarpTileSeqLenP,    // 1, more values, M, Br=64*1=64, matmul M
  const int kWarpTileHeadDimV,   // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|...
  const int kMmaAccFloat32QK,    // 0/1, Q@K^T, 0 MMA Acc with fp16, 1 MMA Acc with fp32.
  const int kMmaAccFloat32PV,    // 0/1, P@V,   0 MMA Acc with fp16, 1 MMA Acc with fp32.
  const int kOStorageAccFloat32, // 0/1, MMA Acc is always fp32/fp16, but O storage can be fp32 or half.
  const int kPrefetchQK,         // Prefetch QK at the appropriate time point.
  const int kPrefetchPV,         // Prefetch V at the appropriate time point.
  const int kShareSmemQKV,       // QKV share the same shared memory, reuse QK smem for V.
  const int kPersistQs2r,        // Persist load Q s2r for headdim < 512, more registers, but still keeps O(1) SRAM.
  const int kPersistQg2s,        // Persist load Q g2s for headdim <= 320, more SRAM, but still keeps register usage.
  const int kRegPipeKV,          // Register ping-pong double buffers for ldmatrix s2r & MMA computation overlapping.
  const int kStageQK,            // <= 4, may apply different multi-stage policies for QK and V (<=4)
  const int kStagePV,            // <= 4, may apply different multi-stage policies for QK and V (<=4)
  const int kPadQ,               // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
  const int kPadK,               // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
  const int kPadV                // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
>
__global__ void
// Q, K, V, O -> [B, H, N, D]
// FFPA Attention Algo: Fine-grained tiling at the MMA level for large headdim (d>=256),
// which can achieve 1.8x~3x faster performance than SDPA EA with or without MMA Acc F32.
ffpa_mma_stages_split_q_L1_large_d_template(half* Q, half* K, half* V, half* O, ...);

// FA-2 Attention Algo: Coarse-grained tiling at the Attention level for small headdim (d<256),
// which can achieve 95%~105% of SDPA FA-2 BE performance with MMA Acc F32 for N<=4096,
// and almost 1.2x~1.4x faster than SDPA FA-2 via Mixed MMA Acc (Q@K^T F32 +
// P@V F16) for all range N on an NVIDIA RTX 4090 device.
ffpa_mma_stages_split_q_L1_small_d_template(half* Q, half* K, half* V, half* O, ...);
```
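To make the tile-shape comments above concrete, this small sketch (our arithmetic, reusing the template parameter names) reproduces how Br and Bc fall out of the MMA atom shape and the MMA-tile / warp-tile repetition factors:

```python
# Reproduce the Br/Bc arithmetic from the template comments above
# (the assignments are illustrative, not code from ffpa-attn).
kMmaAtomM, kMmaAtomN, kMmaAtomK = 16, 8, 16          # mma.m16n8k16 atom shape
kMmaTileSeqLenQ, kMmaTileSeqLenK = 4, 1              # warps along M / N for Q@K^T
kWarpTileSeqLenQ, kWarpTileSeqLenK = 1, 8            # per-warp repeats along M / N

Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ  # 16 * 4 * 1 = 64
Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK  # 8 * 1 * 8  = 64
print(f"Br={Br}, Bc={Bc}")                           # Br=64, Bc=64
```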
The FFPA kernels implemented in this repo can be installed as a Python library, namely ffpa-attn (optional).
```bash
git clone https://github.com/xlite-dev/ffpa-attn.git
# clone, then run `bash .dev/install.sh` directly, or run the commands:
python3 setup.py bdist_wheel && cd dist && python3 -m pip install *.whl
# pip uninstall ffpa-attn -y
```
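After installation the kernels can be driven from Python. The snippet below is only a hypothetical usage sketch: it assumes a top-level `ffpa` entry point taking `[B, H, N, D]` half tensors (as in the kernel signature above), which may not match the actual API; see `tests/test_ffpa_attn.py` in the repo for the authoritative calling convention.

```python
import torch
from ffpa_attn import ffpa  # assumption: top-level `ffpa` entry point

B, H, N, D = 1, 48, 8192, 320  # large headdim case (FA-2 not supported)
q = torch.randn(B, H, N, D, device="cuda", dtype=torch.half)
k = torch.randn_like(q)
v = torch.randn_like(q)
o = torch.zeros_like(q)

ffpa(q, k, v, o)  # hypothetical call; keyword options (stages, acc dtype, ...)
                  # may differ -- check the repo's tests for the real signature
print(o.shape)    # torch.Size([1, 48, 8192, 320])
```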
FFPA L1 (Level 1): Benchmark

L1 (level 1): O(2 * Br * 16) ≈ O(1) SRAM complexity, O(d/4) register complexity, and the same GPU HBM memory complexity as FlashAttention. B=1, H=48, N=8192, D=320~1024 (FA-2 not supported). (Notes: * = MMA Acc F32, ^ = MMA Acc F16, the Softmax Acc dtype is always F32, T = TFLOPS.)
[Benchmark plots: FFPA L1 vs SDPA EA on NVIDIA L20, A30, RTX 3080, and RTX 4090. Plot legends: * = MMA Acc F32 or Mixed MMA Acc (Q@K^T F32 + P@V F16), ^ = MMA Acc F16, T = TFLOPS. Observed speedups: ~1.8x~1.9x↑ (L20), ~1.8x~1.9x↑ (A30), ~2.5x~2.9x↑ (3080), ~1.8x~2.1x↑ (4090).]

You can test many custom FFPA kernels via Python and compare their performance. The --gen-bench and --plot options help you generate a Markdown-style benchmark table and speedup bar plots on your device. Contributions of your benchmark tables and plots are welcome via a PR.
Case: B=1, H=48, N=8192, D=320 (FA-2 not supported)

```bash
# You can test on many devices, such as Volta, Ampere, Ada, Hopper, ...
cd tests && python3 test_ffpa_attn.py --B 1 --H 48 --N 8192 --show-all --D 320
---------------------------------------B=1, H=48, N=8192, D=320, Warmup: 1, Iters: 5--------------------
                   (sdpa): ['-0.02380371'], time:73.66518ms, TFLOPS:56.19 (+0.00 %)(~1.00x)
 (ffpa+acc+f32+L1+stage1): ['-0.02378845'], time:52.87361ms, TFLOPS:78.28 (+39.32%)(~1.39x)
 (ffpa+acc+f32+L1+stage2): ['-0.02378845'], time:40.84062ms, TFLOPS:101.35(+29.46%)(~1.80x)
 (ffpa+acc+f32+L1+stage3): ['-0.02378845'], time:40.49534ms, TFLOPS:102.21(+0.85 %)(~1.82x)
 (ffpa+acc+f32+L1+stage4): ['-0.02378845'], time:40.88177ms, TFLOPS:101.25(+0.00 %)(~1.80x)
 (ffpa+acc+f16+L1+stage1): ['-0.02378845'], time:53.43298ms, TFLOPS:77.46 (+0.00 %)(~1.38x)
 (ffpa+acc+f16+L1+stage2): ['-0.02378845'], time:39.76068ms, TFLOPS:104.10(+1.85 %)(~1.85x)
 (ffpa+acc+f16+L1+stage3): ['-0.02378845'], time:39.54901ms, TFLOPS:104.66(+0.54 %)(~1.86x)
 (ffpa+acc+f16+L1+stage4): ['-0.02378845'], time:41.06554ms, TFLOPS:100.79(+0.00 %)(~1.79x)
--------------------------------------------------------------------------------------------------------
```
```bash
cd tests && pip install matplotlib && python3 test_ffpa_attn.py --gen-bench --show-all --plot
```
```bash
# Enable the ffpa-attn small-d kernel, which uses the coarse-grained tiling method.
export ENABLE_FFPA_PERSIST_Q_G2S=1 && export ENABLE_FFPA_PERSIST_KV_G2S=1
cd tests && python3 test_ffpa_attn.py --B 1 --H 32 --N 1024 --check --show-all --D 64 # NVIDIA L20
---------------------------------------B=1, H=32, N=1024, D=64, Warmup: 1, Iters: 5--------------------
                   (sdpa): ['0.00802612'], time:0.148057ms, TFLOPS:59.14 (+0.00 %)(~1.00x)
 (ffpa+acc+f32+L1+stage1): ['0.00803375'], time:0.103807ms, TFLOPS:84.34 (+42.63%)(~1.43x)
 (ffpa+acc+f32+L1+stage2): ['0.00803375'], time:0.102233ms, TFLOPS:85.64 (+1.54 %)(~1.45x)
 (ffpa+acc+f32+L1+stage3): ['0.00803375'], time:0.102519ms, TFLOPS:85.40 (+0.00 %)(~1.44x)
 (ffpa+acc+f32+L1+stage4): ['0.00803375'], time:0.102043ms, TFLOPS:85.80 (+0.19 %)(~1.45x)
 (ffpa+acc+f16+L1+stage1): ['0.00795746'], time:0.104713ms, TFLOPS:83.61 (+0.00 %)(~1.41x)
 (ffpa+acc+f16+L1+stage2): ['0.00795746'], time:0.102949ms, TFLOPS:85.05 (+0.00 %)(~1.44x)
 (ffpa+acc+f16+L1+stage3): ['0.00795746'], time:0.108957ms, TFLOPS:80.36 (+0.00 %)(~1.36x)
 (ffpa+acc+f16+L1+stage4): ['0.00795746'], time:0.103282ms, TFLOPS:84.77 (+0.00 %)(~1.43x)
--------------------------------------------------------------------------------------------------------

cd tests && python3 test_ffpa_attn.py --B 1 --H 32 --N 4096 --check --show-all --D 64 # NVIDIA L20
-------------------------B=1, H=32, N=4096, D=64, Warmup: 1, Iters: 5-----------------------------------
                   (sdpa): ['0.01959229'], time:1.397752ms, TFLOPS:100.24(+0.00 %)(~1.00x)
 (ffpa+acc+f32+L1+stage1): ['0.01959229'], time:1.368856ms, TFLOPS:102.36(+2.11 %)(~1.02x)
 (ffpa+acc+f32+L1+stage2): ['0.01959229'], time:1.367807ms, TFLOPS:102.44(+0.08 %)(~1.02x)
 (ffpa+acc+f32+L1+stage3): ['0.01959229'], time:1.367855ms, TFLOPS:102.43(+0.00 %)(~1.02x)
 (ffpa+acc+f32+L1+stage4): ['0.01959229'], time:1.368045ms, TFLOPS:102.42(+0.00 %)(~1.02x)
 (ffpa+acc+f16+L1+stage1): ['0.01957703'], time:1.389312ms, TFLOPS:100.85(+0.00 %)(~1.01x)
 (ffpa+acc+f16+L1+stage2): ['0.01957703'], time:1.388311ms, TFLOPS:100.92(+0.00 %)(~1.01x)
 (ffpa+acc+f16+L1+stage3): ['0.01957703'], time:1.386976ms, TFLOPS:101.02(+0.00 %)(~1.01x)
 (ffpa+acc+f16+L1+stage4): ['0.01957703'], time:1.387834ms, TFLOPS:100.96(+0.00 %)(~1.01x)
--------------------------------------------------------------------------------------------------------
```
NOTE: Please check all configurable environment variables in env.py.
License: GNU General Public License v3.0
How to contribute? Welcome to star⭐️ this repo to support me ~