
xlite-dev/ffpa-attn: 🤖FFPA: Extend FlashAttention-2 with Split-D, ~O(1) SRAM complexity for large headdim, 1.8x~3x↑🎉 vs SDPA EA.

🤖FFPA: 1.8x~3x🎉 faster vs SDPA EA with or without MMA Acc F32

🤖FFPA (Split-D): Yet another Faster Flash Prefill Attention with a Split-D strategy, achieving O(1) SRAM complexity and O(d/4) or O(1) register complexity for large headdim (D > 256), almost 1.8x~3x 🎉 faster than SDPA EA with or without MMA Acc F32 on many devices: 📈L20 ~1.9x↑🎉, 📈A30 ~1.8x↑🎉, 📈3080 ~2.9x↑🎉, 📈4090 ~2.1x↑🎉. FFPA Algo: fine-grained tiling for large headdim; FA-2 Algo: coarse-grained tiling for small headdim.

@misc{ffpa-attn2025,
  title={FFPA: Yet another Faster Flash Prefill Attention for large headdim.},
  url={https://github.com/xlite-dev/ffpa-attn.git},
  note={Open-source software available at https://github.com/xlite-dev/ffpa-attn.git},
  author={xlite-dev etc},
  year={2025}
}
📖 FFPA L1~L3: FlashAttention + QKV Fine-grained Tiling at MMA level💡

We have extended FlashAttention to support large headdim (D > 256) by implementing fine-grained tiling at the MMA level (GEMM style) for the Q@K^T and P@V matmuls. This approach keeps the SRAM usage for Q, K, and V constant at Br * 16 or Bc * 16 (Br = Bc), giving an overall SRAM complexity of O(2 * Br * 16) ≈ O(1) and a register complexity of O(d/4) or O(1). Consequently, headdim can be extended beyond 256 while achieving faster performance than SDPA EA, with or without MMA Accumulation F32 (1.8x~3x 🎉 faster than SDPA EA).

We have named this new attention tiling technique FFPA: Faster Flash Prefill Attention. We have designed three levels (L1~L3) of FFPA based on SRAM and register complexity considerations. None of the levels introduces any additional VRAM, so the HBM memory complexity remains the same as FlashAttention. 👇

By leveraging this approach, we can achieve better performance than SDPA EA for very large headdim (D > 256, which FA-2 does not support). An approximate SRAM and register complexity analysis for the FFPA L1~L3 levels is as follows (d = headdim; C, Br, Bc = constants; Br = Bc; let O(C) ≈ O(1)) 👇

|📚Complexity|📚FFPA L1|📚FFPA L2|📚FFPA L3|📚FA-2|
|:---|:---|:---|:---|:---|
|SRAM|O(2xBrx16)≈O(1)|O(2xBrx16)≈O(1)|O(2xBrx16)≈O(1)|≈O(3xBrxd), d↑|
|Register|≈O(d/4), d↑|O((Bc/16)x4+2C)≈O(1)|O((Bc/16)x4+2C)≈O(1)|≈O(d/2), d↑|
|HBM|≈FA2≈O(Nd), O|≈FA2≈O(Nd), O|≈FA2≈O(Nd), O|≈O(Nd), O|
|Extra HBM|≈FA2≈O(N), m,l|≈FA2≈O(N), m,l|≈FA2≈O(N), m,l|≈O(N), m,l|
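To make the SRAM row concrete, here is a small back-of-the-envelope Python sketch (illustration only, not part of the library) that evaluates the formulas above, assuming fp16 tiles and Br = Bc = 64; it ignores multi-stage buffers, swizzle/padding, and other real-kernel overheads.

```python
# Back-of-the-envelope evaluation of the SRAM formulas above (illustration only).
# Assumes fp16 (2-byte) elements and Br = Bc = 64; real kernels add multi-stage
# buffers, swizzle/padding and other overheads that are not modeled here.
BR = BC = 64
BYTES = 2  # fp16

def fa2_smem_elems(d):
    # FA-2 style coarse-grained tiling keeps full Q/K/V tiles resident: ~O(3 * Br * d)
    return 3 * BR * d

def ffpa_smem_elems(d):
    # FFPA fine-grained (MMA-level) tiling keeps only Br x 16 / Bc x 16 slices: ~O(2 * Br * 16)
    return 2 * BR * 16

for d in (256, 512, 768, 1024):
    fa2_kib = fa2_smem_elems(d) * BYTES / 1024
    ffpa_kib = ffpa_smem_elems(d) * BYTES / 1024
    print(f"d={d:4d}  FA-2 ~{fa2_kib:5.0f} KiB   FFPA ~{ffpa_kib:3.0f} KiB")
# d=1024 -> FA-2 ~384 KiB (far beyond the roughly 100-228 KiB of SMEM per SM), FFPA ~4 KiB
```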

📚👇Core Features🎉🎉: I have implemented FFPA L1~L3 using pure MMA PTX instructions, which support many features such as Split-Q, SMEM Swizzle/Padding, QKV Multi-Stages (1~4), Tile MMAs/Warps, Mixed MMA F32/F16 Acc (Q@K^T MMA Acc F32 + P@V MMA Acc F16), Fully Shared QKV SMEM, Prefetch QKV g2s, Persist Q s2r/g2s, Fully QKV Fine-grained Tiling (GEMM style), Collective Store, etc.

|📚Feature|📚Feature|📚Feature|📚Feature|
|:---|:---|:---|:---|
|✔️Tensor Cores|✔️MMA(m16n8k16)|✔️Tile Block(Br, Bc)|✔️Tile MMA/Warp|
|✔️Split Q(FA-2)|✔️Pack LDST(128 bits)|✔️SMEM Swizzle/Pad|✔️Copy Async|
|✔️Reg Double Buffers|✔️QKV Multi-Stages(1~4)|✔️Collective Store(Shfl)|✔️Prefetch QKV g2s|
|✔️QKV Fine-grained Tiling|✔️Shared QKV SMEM|✔️Mixed MMA Acc|✔️Persist Q s2r/g2s|
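To illustrate what Mixed MMA F32/F16 Acc means numerically, the following PyTorch sketch emulates it at the tensor level (it is not the PTX kernel): Q@K^T and the softmax are kept in fp32 while P@V is accumulated in fp16, and the result is compared against a full-fp32 reference. The shapes are arbitrary illustrative choices.

```python
import torch

# Requires a CUDA device (as the kernels in this repo do).
torch.manual_seed(0)
B, H, N, D = 1, 2, 256, 64                      # illustrative shapes
q = torch.randn(B, H, N, D, dtype=torch.half, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
scale = D ** -0.5

# Full-fp32 reference.
s32 = (q.float() @ k.float().transpose(-1, -2)) * scale
p32 = s32.softmax(dim=-1)                       # softmax accumulation stays fp32
o_ref = p32 @ v.float()

# Mixed-accumulation emulation: Q@K^T (and softmax) in fp32, P@V accumulated in fp16,
# mirroring "Q@K^T MMA Acc F32 + P@V MMA Acc F16".
o_mixed = (p32.half() @ v).float()

print("max |o_mixed - o_ref| =", (o_mixed - o_ref).abs().max().item())
```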
template<
  const int kHeadDim,              // Headdim, 32~1024     
  const int kMmaAtomM,             // MMA Atom M, 16
  const int kMmaAtomN,             // MMA Atom N, 8
  const int kMmaAtomK,             // MMA Atom K, 16
  const int kMmaTileSeqLenQ,       // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K),  Bc(N)]  
  const int kMmaTileSeqLenK,       // 1, more MMA(warp), N=8*1 =8,  Q@K^T=[Br(M), d(K)]@[d(K),  Bc(N)]    
  const int kMmaTileSeqLenP,       // 4, more MMA(warp), M=16*4=64, P@V  =[Br(M),Bc(K)]@[Bc(K), d(N) ]
  const int kMmaTileHeadDimV,      // 1, more MMA(warp), N=8*1 =8,  P@V  =[Br(M),Bc(K)]@[Bc(K), d(N) ]       
  const int kWarpTileSeqLenQ,      // 1, more values, M, Br=64*1=64, matmul M 
  const int kWarpTileSeqLenK,      // 8, more values, N, Bc=8*8 =64, matmul N
  const int kWarpTileSeqLenP,      // 1, more values, M, Br=64*1=64, matmul M
  const int kWarpTileHeadDimV,     // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|...
  const int kMmaAccFloat32QK,      // 0/1, Q@K^T, 0 MMA Acc with fp16, 1 MMA Acc with fp32.
  const int kMmaAccFloat32PV,      // 0/1, P@V, 0 MMA Acc with fp16, 1 MMA Acc with fp32.
  const int kOStorageAccFloat32,   // 0/1, MMA Acc is always f32/f16, but O storage can be fp32 or fp16.
  const int kPrefetchQK,           // Prefetch QK at the appropriate time point. 
  const int kPrefetchPV,           // Prefetch V at the appropriate time point. 
  const int kShareSmemQKV,         // QKV share the same shared memory; reuse QK smem for V.
  const int kPersistQs2r,          // Persist load Q s2r for headdim < 512: more registers, but still O(1) SRAM.
  const int kPersistQg2s,          // Persist load Q g2s for headdim <= 320: more SRAM, but no extra register usage.
  const int kRegPipeKV,            // Register ping-pong double buffers for overlapping ldmatrix s2r & MMA computation.
  const int kStageQK,              // <= 4, may apply different multi stages policy for QK and V (<=4)
  const int kStagePV,              // <= 4, may apply different multi stages policy for QK and V (<=4)
  const int kPadQ,                 // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
  const int kPadK,                 // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
  const int kPadV                  // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
> __global__ void // Q, K, V, O -> [B, H, N, D]
// FFPA Attention Algo: Fine-grained tiling at the MMA level for large headdim (d >= 256), 
// which can be 1.8x~3x🎉 faster than SDPA EA with or without MMA Acc F32.
ffpa_mma_stages_split_q_L1_large_d_template(half* Q, half* K, half* V, half* O, ...); 
// FA-2 Attention Algo: Coarse-grained tiling at the Attention level for small headdim (d < 256), 
// which achieves 95%~105%🎉 of SDPA FA-2 BE performance with MMA Acc F32 for N <= 4096, 
// and is almost 1.2x~1.4x🎉 faster than SDPA FA-2 via Mixed MMA Acc (Q@K^T F32 + 
// P@V F16) for the full range of N on an NVIDIA RTX 4090.
ffpa_mma_stages_split_q_L1_small_d_template(half* Q, half* K, half* V, half* O, ...); 
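For intuition about what these kernels compute, here is a schematic PyTorch reference of flash-style online softmax in which the Q@K^T product is additionally built from headdim tiles (the Split-D idea). It is a readability aid with assumed tile sizes (br, bc, dtile), not a model of the CUDA kernel's SMEM/register behavior.

```python
import torch

def ffpa_reference(q, k, v, br=64, bc=64, dtile=64):
    """Schematic split-D flash attention: O = softmax(Q @ K^T / sqrt(D)) @ V."""
    B, H, N, D = q.shape
    scale = D ** -0.5
    o = torch.empty_like(q, dtype=torch.float32)
    for i in range(0, N, br):                                  # tile over query rows (Br)
        qi = q[..., i:i + br, :].float()
        m = torch.full(qi.shape[:-1], float("-inf"), device=q.device)  # running row max
        l = torch.zeros_like(m)                                # running softmax denominator
        acc = torch.zeros_like(qi)                             # unnormalized output accumulator
        for j in range(0, N, bc):                              # tile over key/value rows (Bc)
            kj = k[..., j:j + bc, :].float()
            vj = v[..., j:j + bc, :].float()
            # Split-D: build S = Qi @ Kj^T from headdim tiles of width `dtile`,
            # so only Br x dtile / Bc x dtile slices of Q/K are "live" at a time.
            s = torch.zeros(*qi.shape[:-1], kj.shape[-2], device=q.device)
            for d0 in range(0, D, dtile):
                s += qi[..., d0:d0 + dtile] @ kj[..., d0:d0 + dtile].transpose(-1, -2)
            s *= scale
            m_new = torch.maximum(m, s.amax(dim=-1))           # online softmax rescaling
            p = torch.exp(s - m_new.unsqueeze(-1))
            alpha = torch.exp(m - m_new)
            l = alpha * l + p.sum(dim=-1)
            acc = alpha.unsqueeze(-1) * acc + p @ vj           # P@V could also be D-tiled
            m = m_new
        o[..., i:i + br, :] = acc / l.unsqueeze(-1)
    return o.to(q.dtype)

# Quick sanity check against PyTorch SDPA (small shapes; requires a CUDA device for fp16).
q = torch.randn(1, 2, 256, 320, dtype=torch.half, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
ref = torch.nn.functional.scaled_dot_product_attention(q.float(), k.float(), v.float())
print((ffpa_reference(q, k, v).float() - ref).abs().max().item())
```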

The FFPA kernels implemented in this repo can be installed as a Python library, namely the ffpa-attn library (optional).

git clone https://github.com/xlite-dev/ffpa-attn.git
# After cloning, either run bash .dev/install.sh directly or run the following commands:
python3 setup.py bdist_wheel && cd dist && python3 -m pip install *.whl # pip uninstall ffpa-attn -y
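Once installed, the kernels are used from the ffpa_attn Python package with [B, H, N, D] fp16 tensors, matching the kernel signature above. The entry-point name in this sketch is a placeholder for illustration; check the package's actual Python interface (e.g. ffpa_attn/__init__.py) for the exported functions and their options.

```python
import torch
import ffpa_attn  # installed from the wheel built above

# NOTE: the attention entry-point name below is a placeholder for illustration;
# consult the package's Python interface (e.g. ffpa_attn/__init__.py) for the
# real exported function and its extra options (level, MMA Acc dtype, stages).
q = torch.randn(1, 48, 8192, 320, dtype=torch.half, device="cuda")  # [B, H, N, D]
k = torch.randn_like(q)
v = torch.randn_like(q)
o = ffpa_attn.ffpa(q, k, v)  # placeholder name; output shape [B, H, N, D]
```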
📖 FFPA L1 (Level 1): Benchmark 🎉🎉

L1: level 1, O(2xBrx16)≈O(1) SRAM complexity, O(d/4) register complexity, and the same GPU HBM memory complexity as FlashAttention. B=1, H=48, N=8192, D=320~1024 (FA-2 not supported 👀). (Notes: * = MMA Acc F32, ^ = MMA Acc F16; the Softmax Acc dtype is always F32; T = TFLOPS. 👇Benchmark)

📈L20 ~1.9x↑🎉

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|
|SDPA EA|56T|63T|58T|58T|55T|56T|54T|55T|54T|55T|54T|56T|
|FFPA L1*|102T|102T|103T|104T|103T|95T|95T|95T|95T|96T|95T|94T|
|Speedup|1.82x|1.62x|1.78x|1.79x|1.87x|1.7x|1.76x|1.73x|1.76x|1.75x|1.76x|1.68x|
|FFPA L1^|104T|103T|103T|102T|104T|103T|102T|94T|94T|94T|100T|100T|
|Speedup|1.86x|1.63x|1.78x|1.76x|1.89x|1.84x|1.89x|1.71x|1.74x|1.71x|1.85x|1.79x|

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|
|SDPA EA|56T|64T|58T|58T|55T|56T|54T|55T|54T|55T|54T|56T|
|FFPA L1*|105T|102T|104T|103T|105T|95T|95T|94T|94T|94T|102T|101T|
|Speedup|1.88x|1.59x|1.79x|1.78x|1.91x|1.7x|1.76x|1.71x|1.74x|1.71x|1.89x|1.8x|
|FFPA L1^|104T|103T|103T|102T|103T|103T|102T|94T|94T|94T|100T|100T|
|Speedup|1.86x|1.61x|1.78x|1.76x|1.87x|1.84x|1.89x|1.71x|1.74x|1.71x|1.85x|1.79x|

📈A30 ~1.8x↑🎉

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|
|SDPA EA|25T|25T|24T|24T|24T|24T|23T|22T|22T|22T|22T|18T|
|FFPA L1*|45T|44T|44T|43T|43T|38T|37T|37T|37T|36T|33T|32T|
|Speedup|1.8x|1.76x|1.83x|1.79x|1.79x|1.58x|1.61x|1.68x|1.68x|1.64x|1.5x|1.78x|
|FFPA L1^|48T|46T|45T|43T|44T|44T|44T|38T|37T|36T|40T|34T|
|Speedup|1.92x|1.84x|1.88x|1.79x|1.83x|1.83x|1.91x|1.73x|1.68x|1.64x|1.82x|1.89x|

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|
|SDPA EA|25T|25T|24T|24T|24T|24T|23T|22T|22T|22T|22T|18T|
|FFPA L1*|48T|46T|46T|43T|44T|38T|38T|38T|37T|36T|40T|34T|
|Speedup|1.92x|1.84x|1.92x|1.79x|1.83x|1.58x|1.65x|1.73x|1.68x|1.64x|1.82x|1.89x|
|FFPA L1^|48T|46T|45T|43T|44T|44T|44T|38T|37T|36T|39T|34T|
|Speedup|1.92x|1.84x|1.88x|1.79x|1.83x|1.83x|1.91x|1.73x|1.68x|1.64x|1.77x|1.89x|

📈3080 ~2.9x↑🎉

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|
|SDPA EA|13T|16T|11T|16T|15T|15T|15T|15T|14T|14T|14T|14T|
|FFPA L1*|33T|31T|30T|30T|30T|27T|27T|26T|26T|26T|26T|25T|
|Speedup|2.54x|1.94x|2.73x|1.88x|2.0x|1.8x|1.8x|1.73x|1.86x|1.86x|1.86x|1.79x|
|FFPA L1^|43T|41T|39T|39T|39T|39T|39T|36T|34T|33T|31T|33T|
|Speedup|3.31x|2.56x|3.55x|2.44x|2.6x|2.6x|2.6x|2.4x|2.43x|2.36x|2.21x|2.36x|

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|
|SDPA EA|13T|15T|12T|15T|14T|15T|14T|14T|14T|14T|14T|14T|
|FFPA L1*|38T|36T|34T|35T|34T|31T|32T|31T|30T|28T|27T|27T|
|Speedup|2.92x|2.4x|2.83x|2.33x|2.43x|2.07x|2.29x|2.21x|2.14x|2.0x|1.93x|1.93x|
|FFPA L1^|44T|41T|39T|39T|38T|39T|39T|36T|34T|32T|31T|33T|
|Speedup|3.38x|2.73x|3.25x|2.6x|2.71x|2.6x|2.79x|2.57x|2.43x|2.29x|2.21x|2.36x|

📈4090 ~2.1x↑🎉

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|
|SDPA EA|81T|94T|85T|85T|79T|81T|79T|80T|79T|80T|78T|78T|
|FFPA L1*|149T|150T|150T|150T|150T|140T|140T|140T|139T|139T|137T|134T|
|Speedup|1.84x|1.6x|1.76x|1.76x|1.9x|1.73x|1.77x|1.75x|1.76x|1.74x|1.76x|1.72x|
|FFPA L1^|194T|194T|189T|191T|197T|188T|184T|180T|177T|172T|171T|171T|
|Speedup|2.4x|2.06x|2.22x|2.25x|2.49x|2.32x|2.33x|2.25x|2.24x|2.15x|2.19x|2.19x|

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|
|SDPA EA|82T|92T|85T|84T|78T|81T|79T|80T|78T|79T|77T|78T|
|FFPA L1*|176T|170T|171T|171T|171T|161T|160T|161T|160T|158T|165T|164T|
|Speedup|2.15x|1.85x|2.01x|2.04x|2.19x|1.99x|2.03x|2.01x|2.05x|2.0x|2.14x|2.1x|
|FFPA L1^|200T|191T|189T|191T|188T|188T|186T|179T|175T|173T|172T|170T|
|Speedup|2.44x|2.08x|2.22x|2.27x|2.41x|2.32x|2.35x|2.24x|2.24x|2.19x|2.23x|2.18x|

👇You can test many custom FFPA kernels via Python and compare their performance. The --gen-bench and --plot options generate a Markdown-style benchmark table and speedup bar plots on your device. Contributions of your benchmark tables and plots via a PR are welcome 🎉🎉.

# You can test on many devices, such as Volta, Ampere, Ada, Hopper, ...
cd tests && python3 test_ffpa_attn.py --B 1 --H 48 --N 8192 --show-all --D 320
---------------------------------------B=1, H=48, N=8192, D=320, Warmup: 1, Iters: 5--------------------
                   (sdpa): ['-0.02380371'], time:73.66518ms, TFLOPS:56.19 (+0.00 %)(~1.00x)
 (ffpa+acc+f32+L1+stage1): ['-0.02378845'], time:52.87361ms, TFLOPS:78.28 (+39.32%)(~1.39x)
 (ffpa+acc+f32+L1+stage2): ['-0.02378845'], time:40.84062ms, TFLOPS:101.35(+29.46%)(~1.80x)
 (ffpa+acc+f32+L1+stage3): ['-0.02378845'], time:40.49534ms, TFLOPS:102.21(+0.85 %)(~1.82x)
 (ffpa+acc+f32+L1+stage4): ['-0.02378845'], time:40.88177ms, TFLOPS:101.25(+0.00 %)(~1.80x)
 (ffpa+acc+f16+L1+stage1): ['-0.02378845'], time:53.43298ms, TFLOPS:77.46 (+0.00 %)(~1.38x)
 (ffpa+acc+f16+L1+stage2): ['-0.02378845'], time:39.76068ms, TFLOPS:104.10(+1.85 %)(~1.85x)
 (ffpa+acc+f16+L1+stage3): ['-0.02378845'], time:39.54901ms, TFLOPS:104.66(+0.54 %)(~1.86x)
 (ffpa+acc+f16+L1+stage4): ['-0.02378845'], time:41.06554ms, TFLOPS:100.79(+0.00 %)(~1.79x)
--------------------------------------------------------------------------------------------------------
cd tests && pip install matplotlib && python3 test_ffpa_attn.py --gen-bench --show-all --plot
# Enable the ffpa-attn small-d kernel, which uses the coarse-grained tiling method.
export ENABLE_FFPA_PERSIST_Q_G2S=1 && export ENABLE_FFPA_PERSIST_KV_G2S=1 
cd tests && python3 test_ffpa_attn.py --B 1 --H 32 --N 1024 --check --show-all --D 64 # NVIDIA L20
---------------------------------------B=1, H=32, N=1024, D=64, Warmup: 1, Iters: 5--------------------
                   (sdpa): ['0.00802612'], time:0.148057ms, TFLOPS:59.14 (+0.00 %)(~1.00x)
 (ffpa+acc+f32+L1+stage1): ['0.00803375'], time:0.103807ms, TFLOPS:84.34 (+42.63%)(~1.43x)
 (ffpa+acc+f32+L1+stage2): ['0.00803375'], time:0.102233ms, TFLOPS:85.64 (+1.54 %)(~1.45x)
 (ffpa+acc+f32+L1+stage3): ['0.00803375'], time:0.102519ms, TFLOPS:85.40 (+0.00 %)(~1.44x)
 (ffpa+acc+f32+L1+stage4): ['0.00803375'], time:0.102043ms, TFLOPS:85.80 (+0.19 %)(~1.45x)
 (ffpa+acc+f16+L1+stage1): ['0.00795746'], time:0.104713ms, TFLOPS:83.61 (+0.00 %)(~1.41x)
 (ffpa+acc+f16+L1+stage2): ['0.00795746'], time:0.102949ms, TFLOPS:85.05 (+0.00 %)(~1.44x)
 (ffpa+acc+f16+L1+stage3): ['0.00795746'], time:0.108957ms, TFLOPS:80.36 (+0.00 %)(~1.36x)
 (ffpa+acc+f16+L1+stage4): ['0.00795746'], time:0.103282ms, TFLOPS:84.77 (+0.00 %)(~1.43x)
--------------------------------------------------------------------------------------------------------
cd tests && python3 test_ffpa_attn.py --B 1 --H 32 --N 4096 --check --show-all --D 64 # NVIDIA L20
-------------------------B=1, H=32, N=4096, D=64, Warmup: 1, Iters: 5-----------------------------------
                   (sdpa): ['0.01959229'], time:1.397752ms, TFLOPS:100.24(+0.00 %)(~1.00x)
 (ffpa+acc+f32+L1+stage1): ['0.01959229'], time:1.368856ms, TFLOPS:102.36(+2.11 %)(~1.02x)
 (ffpa+acc+f32+L1+stage2): ['0.01959229'], time:1.367807ms, TFLOPS:102.44(+0.08 %)(~1.02x)
 (ffpa+acc+f32+L1+stage3): ['0.01959229'], time:1.367855ms, TFLOPS:102.43(+0.00 %)(~1.02x)
 (ffpa+acc+f32+L1+stage4): ['0.01959229'], time:1.368045ms, TFLOPS:102.42(+0.00 %)(~1.02x)
 (ffpa+acc+f16+L1+stage1): ['0.01957703'], time:1.389312ms, TFLOPS:100.85(+0.00 %)(~1.01x)
 (ffpa+acc+f16+L1+stage2): ['0.01957703'], time:1.388311ms, TFLOPS:100.92(+0.00 %)(~1.01x)
 (ffpa+acc+f16+L1+stage3): ['0.01957703'], time:1.386976ms, TFLOPS:101.02(+0.00 %)(~1.01x)
 (ffpa+acc+f16+L1+stage4): ['0.01957703'], time:1.387834ms, TFLOPS:100.96(+0.00 %)(~1.01x)
--------------------------------------------------------------------------------------------------------

💡NOTE: Please check all configurable environment variables in env.py.
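The environment variables used above (ENABLE_FFPA_PERSIST_Q_G2S, ENABLE_FFPA_PERSIST_KV_G2S) must be set before the test process starts. A minimal sketch of driving the same run from Python, assuming it is launched from the repository root:

```python
import os
import subprocess

# These two variable names are taken from the commands above; see env.py for the
# full list of configurable options and their meanings.
os.environ["ENABLE_FFPA_PERSIST_Q_G2S"] = "1"
os.environ["ENABLE_FFPA_PERSIST_KV_G2S"] = "1"

# Environment variables are inherited by the child process running the tests.
subprocess.run(
    ["python3", "test_ffpa_attn.py", "--B", "1", "--H", "32",
     "--N", "1024", "--check", "--show-all", "--D", "64"],
    cwd="tests", check=True,
)
```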

GNU General Public License v3.0

How to contribute? Welcome to star⭐️ this repo to support me👆🏻 ~

