This tutorial walks you through how to compile LLM models from Huggingface using Torch-TensorRT. We also introduce KV caching in Torch-TensorRT, which can greatly improve the performance of LLM inference. The code is available in the tools/llm directory. We use the run_llm.py script to compile the model, generate outputs, and measure the performance.
Note
This is an experimental release and APIs may change in future versions.
Note
The compilation scripts and tutorials for the Llama-2-7b-chat-hf and gpt2 models have been consolidated into the unified run_llm.py script located in the tools/llm directory.
The tools/llm directory provides the following tools to compile LLM models from Huggingface:
run_llm.py: Main entry point for model compilation, generating outputs, and benchmarking
Static Cache Utilities: static_cache_v1.py and static_cache_v2.py for KV cache optimization
SDPA Attention: sdpa_converter.py and register_sdpa.py for registering the scaled dot-product attention converter and lowering pass
Testing Components: Model-specific test files for validation
Utility Functions: utils.py and cache_utils.py for common operations
We have officially verified support for the following LLM families:
Model Series | HuggingFace Model Card | Precision | KV Cache Support
GPT-2 | gpt2 | FP16, FP32 | Yes
LLaMA 2 | meta-llama/Llama-2-7b-chat-hf | FP16, FP32 | Yes
LLaMA 3.1 | meta-llama/Llama-3.1-8B-Instruct | FP16, FP32 | Yes
LLaMA 3.2 | meta-llama/Llama-3.2-1B-Instruct, meta-llama/Llama-3.2-3B-Instruct | FP16, FP32 | Yes
Qwen 2.5 | Qwen/Qwen2.5-0.5B-Instruct, Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-3B-Instruct, Qwen/Qwen2.5-7B-Instruct | FP16, FP32 | Yes
Getting Started with run_llm.py¶
The main entry point is run_llm.py, which provides a complete workflow for model compilation and benchmarking.
python tools/llm/run_llm.py \
  --model meta-llama/Llama-3.2-1B-Instruct \
  --prompt "What is parallel programming?" \
  --precision FP16 \
  --num_tokens 128 \
  --cache static_v2 \
  --benchmark

Key Arguments¶
--model: Name or path of the HuggingFace LLM
--tokenizer: (Optional) Tokenizer name; defaults to model name
--prompt: Input prompt for text generation
--precision: Precision mode (FP16, FP32)
--num_tokens: Number of output tokens to generate
--cache: KV cache type (static_v1, static_v2, or empty for no KV caching)
--benchmark: Enable benchmarking mode for performance comparison
--enable_pytorch_run: Also run and compare PyTorch baseline
# Compare the performance of different models
python tools/llm/run_llm.py --model gpt2 --benchmark --enable_pytorch_run
python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --benchmark --enable_pytorch_run

# Generate outputs (disable benchmarking) by specifying the number of tokens to generate. Default = 128
python tools/llm/run_llm.py --model gpt2 --prompt "What is parallel programming?" --num_tokens 128
python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is parallel programming?" --num_tokens 128

# Test different caching approaches
python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --cache static_v1
python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --cache static_v2

# Compare FP16 vs FP32 performance
python tools/llm/run_llm.py --model Qwen/Qwen2.5-1.5B-Instruct --precision FP16 --benchmark
python tools/llm/run_llm.py --model Qwen/Qwen2.5-1.5B-Instruct --precision FP32 --benchmark

KV Caching in Torch-TensorRT¶
We provide two versions of static KV caching: static_cache_v1 and static_cache_v2. In both implementations, we add static KV cache tensors as model inputs/outputs without storing them as external memory. The length of the KV cache = input sequence length + output sequence length (specified by --num_tokens). The number of heads and the head dimension are determined by the model config.
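As a rough illustration of the resulting cache sizes, the snippet below builds zero-initialized key/value cache tensors for a made-up configuration; the batch size, number of KV heads, and head dimension shown here are assumptions for illustration, and in practice they come from the HuggingFace model config.

import torch

# Illustrative values only; the real values come from the model config.
batch_size = 1
num_kv_heads = 8
head_dim = 64
input_seq_len = 32            # prompt length
output_seq_len = 128          # --num_tokens
max_cache_len = input_seq_len + output_seq_len

# Key/value caches use the same [batch, heads, seq, head_dim] layout as the SDPA inputs.
key_cache = torch.zeros(batch_size, num_kv_heads, max_cache_len, head_dim)
value_cache = torch.zeros(batch_size, num_kv_heads, max_cache_len, head_dim)
print(key_cache.shape)        # torch.Size([1, 8, 160, 64])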
static_cache_v1.py implements the KV cache in the model graph as follows:
import torch
import torch.nn as nn

class StaticCacheV1Model(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True):
        # Concatenate new key/value pairs with existing cache
        new_key_cache = torch.cat((key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2)
        new_value_cache = torch.cat((value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2)

        # Compute attention using the updated cache
        attn_output = torch._C._nn.scaled_dot_product_attention(
            q,
            new_key_cache[:, :, :end_idx, :],
            new_value_cache[:, :, :end_idx, :],
            dropout_p=0.0,
            is_causal=is_causal,
        )

        return attn_output, new_key_cache, new_value_cache
In the above code, we concatenate the new key/value pairs with the existing cache and update it. To compute the attention, we use the updated cache and gather the corresponding keys/values from the cache up to and including the current token index. The above code is implemented as an FX graph transformation pass. We register it as a Torch-TensorRT lowering pass using the @_aten_lowering_pass decorator when we import the static_cache_v1.py module.
Note
The start_idx and end_idx are the start and end indices of the current token in the cache. For the prefill phase, start_idx is 0 and end_idx is the input sequence length. For the decode phase, start_idx begins at the input sequence length and end_idx equals start_idx + 1. The start_idx is incremented by 1 until the end of the sequence or until we reach the maximum number of tokens to generate.
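For concreteness, here is a small illustrative trace of how (start_idx, end_idx) advance for a 4-token prompt with 3 generated tokens; the lengths are made-up values chosen only to show the pattern.

# Illustrative only: how (start_idx, end_idx) move through prefill and decode.
prompt_len, num_new_tokens = 4, 3
steps = [(0, prompt_len)]                    # prefill: process the whole prompt at once
for i in range(num_new_tokens):
    start = prompt_len + i                   # decode: one new token per step
    steps.append((start, start + 1))
print(steps)                                 # [(0, 4), (4, 5), (5, 6), (6, 7)]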
static_cache_v2.py is similar to static_cache_v1.py but uses fewer slice operations. It implements the KV cache in the model graph as follows:
class StaticCacheV2Model(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True):
        # Append the new key/value pairs to the valid prefix of the cache
        concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)
        concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)

        # Rebuild the full-length caches by re-attaching the untouched tail
        new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
        new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2)

        # Compute attention directly on the concatenated keys/values
        attn_output = torch._C._nn.scaled_dot_product_attention(
            q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal
        )

        return attn_output, new_key_cache, new_value_cache
In the above code, we concatenate the existing key/value cache with the current key/value of the token. We use this to directly compute the attention and update the key/value cache by inserting the current key/value. The above code is implemented as an FX graph transformation pass. We register it as a Torch-TensorRT lowering pass using the @_aten_lowering_pass decorator when we import the static_cache_v2.py module. The definitions of start_idx and end_idx are the same as in static_cache_v1.py.
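As a quick sanity check (not part of the repository), the following sketch runs both update schemes on random tensors for a single prefill step and confirms they produce the same attention output. It uses the public torch.nn.functional.scaled_dot_product_attention rather than the internal op shown in the graphs above, and all tensor sizes are made-up illustrative values.

import torch
import torch.nn.functional as F

B, H, D = 1, 8, 64                        # batch, heads, head dim (illustrative)
max_cache_len, prefill_len = 32, 16

q = torch.randn(B, H, prefill_len, D)
k = torch.randn(B, H, prefill_len, D)
v = torch.randn(B, H, prefill_len, D)
key_cache = torch.zeros(B, H, max_cache_len, D)
value_cache = torch.zeros(B, H, max_cache_len, D)
start_idx, end_idx = 0, prefill_len

# v1-style: rebuild the full cache, then slice the valid prefix for attention
kc_v1 = torch.cat((key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2)
vc_v1 = torch.cat((value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2)
out_v1 = F.scaled_dot_product_attention(q, kc_v1[:, :, :end_idx, :], vc_v1[:, :, :end_idx, :], is_causal=True)

# v2-style: build only the valid prefix and reuse it directly for attention
concat_k = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)
concat_v = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)
out_v2 = F.scaled_dot_product_attention(q, concat_k, concat_v, is_causal=True)

print(torch.allclose(out_v1, out_v2, atol=1e-6))   # expected: True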
After the model is compiled with static KV cache, the input signature of the model changes. The new input signature is (input_ids, position_ids, key_cache_0, value_cache_0, ..., start_idx, end_idx). There is one key cache and one value cache tensor per attention layer in the model.
We use a custom generate function to generate the outputs when KV caching is disabled; it performs standard autoregressive decoding. There is also a generate_with_static_cache function that performs autoregressive decoding with KV caching.
The generate_with_static_cache function takes care of preparing the inputs for the model compiled with static KV cache. The model inputs are input_ids, position_ids, key_cache_0, value_cache_0, ..., start_idx, end_idx. We initialize the key/value cache tensors with zeros, and for every generated token, the updated key/value cache tensors are returned as model outputs.
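To make the data flow concrete, below is a simplified, hypothetical sketch of such a decode loop. The argument ordering follows the input signature described above, but the greedy sampling, the output ordering (logits first, then the updated caches), and all variable names are assumptions for illustration; the actual logic lives in the generate_with_static_cache utility.

import torch

def generate_with_static_cache_sketch(compiled_model, input_ids, kv_caches, num_output_tokens):
    # kv_caches: flat list of zero-initialized key/value cache tensors,
    # in the same order as the compiled model's input signature (assumption).
    seq_len = input_ids.shape[1]
    start_idx, end_idx = 0, seq_len                        # prefill covers the whole prompt
    generated = input_ids

    for _ in range(num_output_tokens):
        current_input = generated[:, start_idx:end_idx]
        position_ids = torch.arange(start_idx, end_idx).unsqueeze(0)

        outputs = compiled_model(current_input, position_ids, *kv_caches, start_idx, end_idx)
        logits, kv_caches = outputs[0], list(outputs[1:])  # updated caches come back as outputs

        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy decoding for simplicity
        generated = torch.cat([generated, next_token], dim=1)

        start_idx, end_idx = end_idx, end_idx + 1          # decode: one new token per step
    return generated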
SDPA Converter (sdpa_converter.py)¶
Converts the scaled dot-product attention operation using the TensorRT Python API.
Supports causal and standard self-attention.

SDPA Registration (register_sdpa.py)¶
This is a Torch-TensorRT lowering pass that replaces variants of SDPA with torch.nn.functional.scaled_dot_product_attention.
Registers the SDPA converter which is used for converting the torch.nn.functional.scaled_dot_product_attention operation.
Limitations and Known Issues¶
Sliding window attention (used in Gemma3 and Qwen 3 models) is not yet supported.
Some model architectures (e.g. Phi-4) have issues with exporting the torch model.

Requirements¶
Torch-TensorRT 2.8.0 or later
Transformers v4.52.3