This repository contains the optimized CUDA kernel implementation for InfLLM V2's Two-Stage Sparse Attention Mechanism. Our implementation provides high-performance kernels for both Stage 1 (Top-K Context Selection) and Stage 2 (Sparse Attention Computation), enabling Large Language Models (LLMs) to efficiently process long contexts with trainable sparse patterns.
InfLLM V2 introduces a novel two-stage approach for efficient long-context processing: Stage 1 selects the Top-K most relevant context blocks for each query, and Stage 2 computes attention only over the selected blocks. This CUDA kernel implementation covers both stages.
Built upon FlashAttention, our kernels use efficient memory access patterns and provide optimized implementations for both stages.
### Stage 1: Top-K Context Selection

The Top-K selection stage involves three sequential steps:

1. Score computation: calculate similarity scores between query tokens and compressed key representations
2. Score aggregation: aggregate the scores across the query group dimension
3. Top-K selection: pick the highest-scoring context blocks

Note: The `infllmv2_attn_stage1` kernel handles steps 1 and 2 (score computation and aggregation); only step 3 (Top-K selection) is performed outside the kernel, as sketched below.
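Since step 3 lives outside the kernel, a typical host-side implementation looks like the following. This is a minimal sketch: it assumes the aggregated scores have the block dimension last (e.g. `(n_heads, seqlen_q, n_blocks)`), and the helper name `select_topk_blocks` is ours, not part of the package.

```python
import torch

def select_topk_blocks(aggregated_scores: torch.Tensor, top_k: int) -> torch.Tensor:
    """Step 3 of Stage 1: pick the top_k highest-scoring context blocks.

    Assumes aggregated_scores has the block dimension last, e.g.
    (n_heads, seqlen_q, n_blocks); this layout is an assumption, not a
    documented contract of infllmv2_attn_stage1.
    """
    # torch.topk returns (values, indices); Stage 2 only needs the indices.
    _, topk_idx = torch.topk(aggregated_scores, k=top_k, dim=-1)
    # Cast to int32, a common expectation for index tensors in CUDA kernels
    # (assumption; check the Stage 2 kernel's dtype requirements).
    return topk_idx.to(torch.int32)
```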
### Stage 2: Sparse Attention Computation

The sparse attention stage performs standard attention computation, but only on the blocks selected in Stage 1, as the reference sketch below illustrates.
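To make the Stage 2 semantics concrete, here is a minimal dense PyTorch reference of block-sparse attention for a single head. This is our illustrative sketch, not the CUDA kernel: `block_size`, the tensor shapes, and the function name are assumptions.

```python
import torch
import torch.nn.functional as F

def block_sparse_attention_reference(q, k, v, topk_idx, block_size):
    # q: (seqlen_q, head_dim); k, v: (seqlen_k, head_dim)
    # topk_idx: (n_q_blocks, top_k) key-block indices chosen in Stage 1
    seqlen_q, head_dim = q.shape
    seqlen_k = k.shape[0]
    n_q_blocks = (seqlen_q + block_size - 1) // block_size
    n_k_blocks = (seqlen_k + block_size - 1) // block_size
    # Block-level mask: True where a query block may attend to a key block.
    block_mask = torch.zeros(n_q_blocks, n_k_blocks, dtype=torch.bool)
    block_mask.scatter_(1, topk_idx.long(), True)
    # Expand the block mask to token granularity and crop to real lengths.
    token_mask = block_mask.repeat_interleave(block_size, dim=0)[:seqlen_q]
    token_mask = token_mask.repeat_interleave(block_size, dim=1)[:, :seqlen_k]
    # Standard attention, restricted to the selected blocks. (A fully masked
    # row would yield NaNs here; the real kernel also handles causal masking.)
    scores = (q @ k.T) / head_dim ** 0.5
    scores = scores.masked_fill(~token_mask, float("-inf"))
    return F.softmax(scores, dim=-1) @ v
```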
- `infllmv2_attn_stage1`: Calculates similarity scores between query tokens and compressed key representations
- `infllmv2_sparse_attn_fwd`: Forward pass kernel for sparse attention
- `infllmv2_sparse_attn_bwd`: Backward pass kernel for training

### For Training (main branch)

```bash
# Clone the repository and use main branch for training
git clone https://github.com/OpenBMB/infllm_v2_cuda.git --recursive
cd infllm_v2_cuda
git checkout main

# Install with CUDA kernel compilation
pip install -e .
```

### For Hugging Face Inference (feature_infer branch)
```bash
# Clone the repository and use feature_infer branch for inference
git clone https://github.com/OpenBMB/infllm_v2_cuda.git --recursive
cd infllm_v2_cuda
git checkout feature_infer

# Install with CUDA kernel compilation
pip install -e .
```
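Regardless of branch, a quick import check confirms the CUDA extension compiled and installed. This is a minimal sketch; the module name `infllm_v2` is taken from the usage examples below.

```python
# Sanity check: the package should import without errors once the
# CUDA kernels have compiled successfully.
import infllm_v2
print("InfLLM V2 installed at:", infllm_v2.__file__)
```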
The InfLLM V2 CUDA kernel provides the following interfaces for the two-stage sparse attention:
### Stage 1: Attention Score Computation and Aggregation (feature_infer branch)

```python
from infllm_v2 import infllmv2_attn_stage1

# Stage 1: Compute and aggregate relevance scores between queries and semantic kernels
# This kernel performs:
# 1. LSE approximation using compressed keys
# 2. Full attention score computation
# 3. Score aggregation across query group dimension (hdim16_reduce)
# Top-K selection must be performed separately on the aggregated scores
#
# Inputs:
# - q: Query tensor (batch_size * n_heads, seqlen_q, head_dim)
# - k: Compressed key tensor representing semantic kernels
# - v: Placeholder tensor (not used in score computation)
# - cu_seqlens_q, cu_seqlens_k: Cumulative sequence lengths
# - max_seqlen_q, max_seqlen_k: Maximum sequence lengths
# Returns aggregated attention scores for subsequent Top-K selection
aggregated_scores = infllmv2_attn_stage1(
    q, k, v,
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=max_seqlen_q,
    max_seqlen_k=max_seqlen_k,
    causal=True,             # Apply causal masking
    return_attn_probs=True,  # Return attention scores
)

# Top-K selection should be performed on the returned aggregated scores
# (This step is not part of the kernel)
```
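To bridge the two stages, the returned scores feed a host-side Top-K step, for example via the `select_topk_blocks` helper sketched earlier (a hypothetical helper; `top_k` is a configuration choice, not fixed by the kernel):

```python
# Derive the block indices that Stage 2 consumes from the Stage 1 output.
topk_idx = select_topk_blocks(aggregated_scores, top_k=16)
```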
### Stage 2: Sparse Attention Computation

```python
from infllm_v2 import infllmv2_sparse_attn_func

# Stage 2: Sparse Attention Computation Kernel
# Inputs:
# - q_unpad: Queries tensor (token-level)
# - k_unpad, v_unpad: Keys and Values tensors (block-level)
# - cu_seqlens_q, cu_seqlens_k: Cumulative sequence lengths
# - topk_idx: Selected block indices from Stage 1
# - max_seqlen_q, max_seqlen_k: Maximum sequence lengths
# - block_window_size: Optional local attention window size
out_unpad = infllmv2_sparse_attn_func(
    q_unpad, k_unpad, v_unpad,
    cu_seqlens_q, cu_seqlens_k,
    topk_idx,            # Block indices selected in Stage 1
    max_seqlen_q, max_seqlen_k,
    block_window_size=0, # Additional local window for attention
)
```
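Because a dedicated backward kernel (`infllmv2_sparse_attn_bwd`) is provided for training, the Stage 2 function can presumably participate in a standard autograd graph. The following is a hedged sketch under that assumption; the loss is a placeholder for illustration only.

```python
# Assumed training usage: if infllmv2_sparse_attn_func is autograd-enabled,
# gradients flow back to q/k/v through the backward kernel.
q_unpad.requires_grad_(True)
out_unpad = infllmv2_sparse_attn_func(
    q_unpad, k_unpad, v_unpad,
    cu_seqlens_q, cu_seqlens_k,
    topk_idx,
    max_seqlen_q, max_seqlen_k,
    block_window_size=0,
)
loss = out_unpad.float().pow(2).mean()  # placeholder loss for illustration
loss.backward()  # assumed to dispatch to infllmv2_sparse_attn_bwd
```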
All benchmarks were conducted with the following configuration:
If you use the InfLLM V2 CUDA kernels in your research, please cite:
```bibtex
@article{minicpm4,
  title={MiniCPM4: Ultra-Efficient LLMs on End Devices},
  author={MiniCPM},
  year={2025}
}
```