On this page we list our PyTorch Autograd-compatible operations. These operations come with performance knobs (configurations), some of which are specific to certain backends.
Changing those knobs is completely optional, and NATTEN will remain functionally correct in all cases. However, to squeeze out the maximum achievable performance, we highly recommend looking at backends, or simply using our profiler toolkit and its dry run feature to navigate the available backends and their valid configurations for your specific use case and GPU architecture. You can also use the profiler's optimize feature to search for the best configuration.
Neighborhood Attention
natten.na1d
na1d(
query,
key,
value,
kernel_size,
stride=1,
dilation=1,
is_causal=False,
scale=None,
additional_keys=None,
additional_values=None,
attention_kwargs=None,
backend=None,
q_tile_shape=None,
kv_tile_shape=None,
backward_q_tile_shape=None,
backward_kv_tile_shape=None,
backward_kv_splits=None,
backward_use_pt_reduction=False,
run_persistent_kernel=True,
kernel_schedule=None,
torch_compile=False,
try_fuse_additional_kv=False,
)
Computes 1-D neighborhood attention.
Parameters:

query (Tensor): 4-D query tensor, with the heads last layout ([batch, seqlen, heads, head_dim]).

key (Tensor): 4-D key tensor, with the heads last layout ([batch, seqlen, heads, head_dim]).

value (Tensor): 4-D value tensor, with the heads last layout ([batch, seqlen, heads, head_dim]).

kernel_size (Tuple[int] | int): Neighborhood window (kernel) size.
    Note: kernel_size must be smaller than or equal to seqlen.

stride (Tuple[int] | int): Sliding window step size. Defaults to 1 (standard sliding window).
    Note: stride must be smaller than or equal to kernel_size. When stride == kernel_size, there will be no overlap between sliding windows, which is equivalent to blocked attention (a.k.a. window self attention).

dilation (Tuple[int] | int): Dilation step size. Defaults to 1 (standard sliding window).
    Note: the product of dilation and kernel_size must be smaller than or equal to seqlen.

is_causal (Tuple[bool] | bool): Toggle causal masking. Defaults to False (bi-directional).

scale (float): Attention scale. Defaults to head_dim ** -0.5.

additional_keys (Optional[Tensor]): None or 4-D key tensor, with the heads last layout ([batch, seqlen_kv, heads, head_dim]), corresponding to key tokens from some additional context. Used when performing neighborhood cross-attention, where query tokens attend to their neighborhood, as well as some fixed additional set of tokens. Defaults to None.

additional_values (Optional[Tensor]): None or 4-D value tensor, with the heads last layout ([batch, seqlen_kv, heads, head_dim]), corresponding to value tokens from some additional context. Used when performing neighborhood cross-attention, where query tokens attend to their neighborhood, as well as some fixed additional set of tokens. Defaults to None.
    Note: additional_keys and additional_values must both be Tensors or both be None, and must match in shape.
Other Parameters:

backend (str): Backend implementation to run with. Choices are: None (pick the best available one), "cutlass-fna", "hopper-fna", "blackwell-fna", "flex-fna". Refer to backends for more information.

q_tile_shape (Tuple[int]): 1-D tile shape for the query token layout in the forward pass kernel. You can use the profiler to find valid choices for your use case, and search for the best combination.

kv_tile_shape (Tuple[int]): 1-D tile shape for the key/value token layout in the forward pass kernel. You can use the profiler to find valid choices for your use case, and search for the best combination.

backward_q_tile_shape (Tuple[int]): 1-D tile shape for the query token layout in the backward pass kernel. This is only respected by the "cutlass-fna" backend. You can use the profiler to find valid choices for your use case, and search for the best combination.

backward_kv_tile_shape (Tuple[int]): 1-D tile shape for the key/value token layout in the backward pass kernel. This is only respected by the "cutlass-fna" backend. You can use the profiler to find valid choices for your use case, and search for the best combination.

backward_kv_splits (Tuple[int]): Number of key/value tiles allowed to work in parallel in the backward pass kernel. Like tile shapes, this is a tuple and not an integer for neighborhood attention operations, and the size of the tuple corresponds to the number of dimensions (rank) of the token layout. This is only respected by the "cutlass-fna" backend, and only when KV parallelism is enabled.

backward_use_pt_reduction (bool): Whether to use PyTorch eager mode for computing the dO * O product required by the backward pass, instead of the CUTLASS kernel. This only applies to the "cutlass-fna" backend.

run_persistent_kernel (bool): Whether to use persistent tile scheduling in the forward pass kernel. This only applies to the "blackwell-fna" backend.

kernel_schedule (Optional[str]): Kernel type (Hopper architecture only). Choices are: None (pick the default), "non" (non-persistent), "coop" (warp-specialized cooperative), or "pp" (warp-specialized ping-ponging). Refer to the Hopper FMHA/FNA backend for more information.

torch_compile (bool): Applies only to the "flex-fna" backend. Whether or not to JIT-compile the attention kernel. Because this is an experimental feature in PyTorch, we do not recommend it, and it is guarded by context flags. Read more in Flex Attention + torch.compile.

attention_kwargs (Optional[Dict]): Arguments to the attention operator, if it is used to implement neighborhood cross-attention, or self attention as a fast path for neighborhood attention.
    If additional_{keys,values} are specified, NATTEN usually performs a separate cross-attention using our attention operator, and merges the results.
    If for a given use case the neighborhood attention problem is equivalent to self attention (not causal, and kernel_size == seqlen), NATTEN will also attempt to directly use attention.
    You can override arguments to attention by passing a dictionary here.
    Example:
        out = na1d(
            q, k, v, kernel_size=kernel_size,
            ...,
            attention_kwargs={
                "backend": "blackwell-fmha",
                "run_persistent_kernel": True,
            }
        )

try_fuse_additional_kv (bool): Some backends may support fusing cross-attention (additional KV) into the FNA kernel, instead of having to run a separate attention and then merge. For now, this is only supported in backends that use Token Permutation, which means that when there is dilation, this fusion can incur additional memory operations and memory usage. Currently, only the "blackwell-fna" backend supports this. We recommend using the profiler to see whether this option is suitable for your use case before trying it.
Returns:

output (Tensor): 4-D output tensor, with the heads last layout ([batch, seqlen, heads, head_dim]).
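Example (a minimal usage sketch; the shapes, dtype, device, and window size below are illustrative assumptions, not requirements):

    import torch
    from natten import na1d

    # Illustrative shapes: batch=2, seqlen=1024, heads=8, head_dim=64 (assumed, not required).
    q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)

    # Causal 1-D neighborhood attention with a 128-token window.
    out = na1d(q, k, v, kernel_size=128, is_causal=True)
    # out keeps the heads last layout: [2, 1024, 8, 64]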
natten.na2d
na2d(
query,
key,
value,
kernel_size,
stride=1,
dilation=1,
is_causal=False,
scale=None,
additional_keys=None,
additional_values=None,
attention_kwargs=None,
backend=None,
q_tile_shape=None,
kv_tile_shape=None,
backward_q_tile_shape=None,
backward_kv_tile_shape=None,
backward_kv_splits=None,
backward_use_pt_reduction=False,
run_persistent_kernel=True,
kernel_schedule=None,
torch_compile=False,
try_fuse_additional_kv=False,
)
Computes 2-D neighborhood attention.
Parameters:

query (Tensor): 5-D query tensor, with the heads last layout ([batch, X, Y, heads, head_dim]), where the token layout shape (feature map shape) is (X, Y).

key (Tensor): 5-D key tensor, with the heads last layout ([batch, X, Y, heads, head_dim]), where the token layout shape (feature map shape) is (X, Y).

value (Tensor): 5-D value tensor, with the heads last layout ([batch, X, Y, heads, head_dim]), where the token layout shape (feature map shape) is (X, Y).

kernel_size (Tuple[int, int] | int): Neighborhood window (kernel) size/shape. If an integer, it will be repeated for both dimensions. For example, kernel_size=3 is reinterpreted as kernel_size=(3, 3).
    Note: kernel_size must be smaller than or equal to the token layout shape ((X, Y)) along every dimension.

stride (Tuple[int, int] | int): Sliding window step size/shape. Defaults to 1 (standard sliding window). If an integer, it will be repeated for both dimensions. For example, stride=2 is reinterpreted as stride=(2, 2).
    Note: stride must be smaller than or equal to kernel_size along every dimension. When stride == kernel_size, there will be no overlap between sliding windows, which is equivalent to blocked attention (a.k.a. window self attention).

dilation (Tuple[int, int] | int): Dilation step size/shape. Defaults to 1 (standard sliding window). If an integer, it will be repeated for both dimensions. For example, dilation=4 is reinterpreted as dilation=(4, 4).
    Note: the product of dilation and kernel_size must be smaller than or equal to the token layout shape ((X, Y)) along every dimension.

is_causal (Tuple[bool, bool] | bool): Toggle causal masking. Defaults to False (bi-directional). If a boolean, it will be repeated for both dimensions. For example, is_causal=True is reinterpreted as is_causal=(True, True).

scale (float): Attention scale. Defaults to head_dim ** -0.5.

additional_keys (Optional[Tensor]): None or 4-D key tensor, with the heads last layout ([batch, seqlen_kv, heads, head_dim]), corresponding to key tokens from some additional context. Used when performing neighborhood cross-attention, where query tokens attend to their neighborhood, as well as some fixed additional set of tokens. Defaults to None.

additional_values (Optional[Tensor]): None or 4-D value tensor, with the heads last layout ([batch, seqlen_kv, heads, head_dim]), corresponding to value tokens from some additional context. Used when performing neighborhood cross-attention, where query tokens attend to their neighborhood, as well as some fixed additional set of tokens. Defaults to None.
    Note: additional_keys and additional_values must both be Tensors or both be None, and must match in shape.
Other Parameters:

backend (str): Backend implementation to run with. Choices are: None (pick the best available one), "cutlass-fna", "hopper-fna", "blackwell-fna", "flex-fna". Refer to backends for more information.

q_tile_shape (Tuple[int, int]): 2-D tile shape for the query token layout in the forward pass kernel. You can use the profiler to find valid choices for your use case, and search for the best combination.

kv_tile_shape (Tuple[int, int]): 2-D tile shape for the key/value token layout in the forward pass kernel. You can use the profiler to find valid choices for your use case, and search for the best combination.

backward_q_tile_shape (Tuple[int, int]): 2-D tile shape for the query token layout in the backward pass kernel. This is only respected by the "cutlass-fna" backend. You can use the profiler to find valid choices for your use case, and search for the best combination.

backward_kv_tile_shape (Tuple[int, int]): 2-D tile shape for the key/value token layout in the backward pass kernel. This is only respected by the "cutlass-fna" backend. You can use the profiler to find valid choices for your use case, and search for the best combination.

backward_kv_splits (Tuple[int, int]): Number of key/value tiles allowed to work in parallel in the backward pass kernel. Like tile shapes, this is a tuple and not an integer for neighborhood attention operations, and the size of the tuple corresponds to the number of dimensions (rank) of the token layout. This is only respected by the "cutlass-fna" backend, and only when KV parallelism is enabled.

backward_use_pt_reduction (bool): Whether to use PyTorch eager mode for computing the dO * O product required by the backward pass, instead of the CUTLASS kernel. This only applies to the "cutlass-fna" backend.

run_persistent_kernel (bool): Whether to use persistent tile scheduling in the forward pass kernel. This only applies to the "blackwell-fna" backend.

kernel_schedule (Optional[str]): Kernel type (Hopper architecture only). Choices are: None (pick the default), "non" (non-persistent), "coop" (warp-specialized cooperative), or "pp" (warp-specialized ping-ponging). Refer to the Hopper FMHA/FNA backend for more information.

torch_compile (bool): Applies only to the "flex-fna" backend. Whether or not to JIT-compile the attention kernel. Because this is an experimental feature in PyTorch, we do not recommend it, and it is guarded by context flags. Read more in Flex Attention + torch.compile.

attention_kwargs (Optional[Dict]): Arguments to the attention operator, if it is used to implement neighborhood cross-attention, or self attention as a fast path for neighborhood attention.
    If additional_{keys,values} are specified, NATTEN usually performs a separate cross-attention using our attention operator, and merges the results.
    If for a given use case the neighborhood attention problem is equivalent to self attention (not causal along any dims, and kernel_size == (X, Y)), NATTEN will also attempt to directly use attention.
    You can override arguments to attention by passing a dictionary here.
    Example:
        out = na2d(
            q, k, v, kernel_size=kernel_size,
            ...,
            attention_kwargs={
                "backend": "blackwell-fmha",
                "run_persistent_kernel": True,
            }
        )

try_fuse_additional_kv (bool): Some backends may support fusing cross-attention (additional KV) into the FNA kernel, instead of having to run a separate attention and then merge. For now, this is only supported in backends that use Token Permutation, which means that when there is dilation, this fusion can incur additional memory operations and memory usage. Currently, only the "blackwell-fna" backend supports this. We recommend using the profiler to see whether this option is suitable for your use case before trying it.
Returns:

output (Tensor): 5-D output tensor, with the heads last layout ([batch, X, Y, heads, head_dim]).
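Example (a minimal usage sketch; the shapes, dtype, device, and window size below are illustrative assumptions, not requirements):

    import torch
    from natten import na2d

    # Illustrative shapes: batch=1, feature map (X, Y) = (64, 64), heads=4, head_dim=64.
    q = torch.randn(1, 64, 64, 4, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(1, 64, 64, 4, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(1, 64, 64, 4, 64, device="cuda", dtype=torch.float16)

    # 7x7 neighborhood window; stride and dilation keep their defaults of 1.
    out = na2d(q, k, v, kernel_size=(7, 7))
    # out keeps the heads last layout: [1, 64, 64, 4, 64]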
natten.na3d
na3d(
query,
key,
value,
kernel_size,
stride=1,
dilation=1,
is_causal=False,
scale=None,
additional_keys=None,
additional_values=None,
attention_kwargs=None,
backend=None,
q_tile_shape=None,
kv_tile_shape=None,
backward_q_tile_shape=None,
backward_kv_tile_shape=None,
backward_kv_splits=None,
backward_use_pt_reduction=False,
run_persistent_kernel=True,
kernel_schedule=None,
torch_compile=False,
try_fuse_additional_kv=False,
)
Computes 3-D neighborhood attention.
Parameters:

query (Tensor): 6-D query tensor, with the heads last layout ([batch, X, Y, Z, heads, head_dim]), where the token layout shape (feature map shape) is (X, Y, Z).

key (Tensor): 6-D key tensor, with the heads last layout ([batch, X, Y, Z, heads, head_dim]), where the token layout shape (feature map shape) is (X, Y, Z).

value (Tensor): 6-D value tensor, with the heads last layout ([batch, X, Y, Z, heads, head_dim]), where the token layout shape (feature map shape) is (X, Y, Z).

kernel_size (Tuple[int, int, int] | int): Neighborhood window (kernel) size/shape. If an integer, it will be repeated for all 3 dimensions. For example, kernel_size=3 is reinterpreted as kernel_size=(3, 3, 3).
    Note: kernel_size must be smaller than or equal to the token layout shape ((X, Y, Z)) along every dimension.

stride (Tuple[int, int, int] | int): Sliding window step size/shape. Defaults to 1 (standard sliding window). If an integer, it will be repeated for all 3 dimensions. For example, stride=2 is reinterpreted as stride=(2, 2, 2).
    Note: stride must be smaller than or equal to kernel_size along every dimension. When stride == kernel_size, there will be no overlap between sliding windows, which is equivalent to blocked attention (a.k.a. window self attention).

dilation (Tuple[int, int, int] | int): Dilation step size/shape. Defaults to 1 (standard sliding window). If an integer, it will be repeated for all 3 dimensions. For example, dilation=4 is reinterpreted as dilation=(4, 4, 4).
    Note: the product of dilation and kernel_size must be smaller than or equal to the token layout shape ((X, Y, Z)) along every dimension.

is_causal (Tuple[bool, bool, bool] | bool): Toggle causal masking. Defaults to False (bi-directional). If a boolean, it will be repeated for all 3 dimensions. For example, is_causal=True is reinterpreted as is_causal=(True, True, True).

scale (float): Attention scale. Defaults to head_dim ** -0.5.

additional_keys (Optional[Tensor]): None or 4-D key tensor, with the heads last layout ([batch, seqlen_kv, heads, head_dim]), corresponding to key tokens from some additional context. Used when performing neighborhood cross-attention, where query tokens attend to their neighborhood, as well as some fixed additional set of tokens. Defaults to None.

additional_values (Optional[Tensor]): None or 4-D value tensor, with the heads last layout ([batch, seqlen_kv, heads, head_dim]), corresponding to value tokens from some additional context. Used when performing neighborhood cross-attention, where query tokens attend to their neighborhood, as well as some fixed additional set of tokens. Defaults to None.
    Note: additional_keys and additional_values must both be Tensors or both be None, and must match in shape.
Other Parameters:

backend (str): Backend implementation to run with. Choices are: None (pick the best available one), "cutlass-fna", "hopper-fna", "blackwell-fna", "flex-fna". Refer to backends for more information.

q_tile_shape (Tuple[int, int, int]): 3-D tile shape for the query token layout in the forward pass kernel. You can use the profiler to find valid choices for your use case, and search for the best combination.

kv_tile_shape (Tuple[int, int, int]): 3-D tile shape for the key/value token layout in the forward pass kernel. You can use the profiler to find valid choices for your use case, and search for the best combination.

backward_q_tile_shape (Tuple[int, int, int]): 3-D tile shape for the query token layout in the backward pass kernel. This is only respected by the "cutlass-fna" backend. You can use the profiler to find valid choices for your use case, and search for the best combination.

backward_kv_tile_shape (Tuple[int, int, int]): 3-D tile shape for the key/value token layout in the backward pass kernel. This is only respected by the "cutlass-fna" backend. You can use the profiler to find valid choices for your use case, and search for the best combination.

backward_kv_splits (Tuple[int, int, int]): Number of key/value tiles allowed to work in parallel in the backward pass kernel. Like tile shapes, this is a tuple and not an integer for neighborhood attention operations, and the size of the tuple corresponds to the number of dimensions (rank) of the token layout. This is only respected by the "cutlass-fna" backend, and only when KV parallelism is enabled.

backward_use_pt_reduction (bool): Whether to use PyTorch eager mode for computing the dO * O product required by the backward pass, instead of the CUTLASS kernel. This only applies to the "cutlass-fna" backend.

run_persistent_kernel (bool): Whether to use persistent tile scheduling in the forward pass kernel. This only applies to the "blackwell-fna" backend.

kernel_schedule (Optional[str]): Kernel type (Hopper architecture only). Choices are: None (pick the default), "non" (non-persistent), "coop" (warp-specialized cooperative), or "pp" (warp-specialized ping-ponging). Refer to the Hopper FMHA/FNA backend for more information.

torch_compile (bool): Applies only to the "flex-fna" backend. Whether or not to JIT-compile the attention kernel. Because this is an experimental feature in PyTorch, we do not recommend it, and it is guarded by context flags. Read more in Flex Attention + torch.compile.

attention_kwargs (Optional[Dict]): Arguments to the attention operator, if it is used to implement neighborhood cross-attention, or self attention as a fast path for neighborhood attention.
    If additional_{keys,values} are specified, NATTEN usually performs a separate cross-attention using our attention operator, and merges the results.
    If for a given use case the neighborhood attention problem is equivalent to self attention (not causal along any dims, and kernel_size == (X, Y, Z)), NATTEN will also attempt to directly use attention.
    You can override arguments to attention by passing a dictionary here.
    Example:
        out = na3d(
            q, k, v, kernel_size=kernel_size,
            ...,
            attention_kwargs={
                "backend": "blackwell-fmha",
                "run_persistent_kernel": True,
            }
        )

try_fuse_additional_kv (bool): Some backends may support fusing cross-attention (additional KV) into the FNA kernel, instead of having to run a separate attention and then merge. For now, this is only supported in backends that use Token Permutation, which means that when there is dilation, this fusion can incur additional memory operations and memory usage. Currently, only the "blackwell-fna" backend supports this. We recommend using the profiler to see whether this option is suitable for your use case before trying it.
Returns:

output (Tensor): 6-D output tensor, with the heads last layout ([batch, X, Y, Z, heads, head_dim]).
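Example (a minimal usage sketch; the shapes, dtype, device, and per-dimension window below are illustrative assumptions, not requirements):

    import torch
    from natten import na3d

    # Illustrative shapes: batch=1, token layout (X, Y, Z) = (8, 32, 32), heads=4, head_dim=64.
    q = torch.randn(1, 8, 32, 32, 4, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(1, 8, 32, 32, 4, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(1, 8, 32, 32, 4, 64, device="cuda", dtype=torch.float16)

    # Window of 4 along the first dimension and 9x9 along the last two, with dilation 2
    # along the last two (4 * 1 <= 8 and 9 * 2 <= 32, satisfying the constraints above).
    out = na3d(q, k, v, kernel_size=(4, 9, 9), dilation=(1, 2, 2))
    # out keeps the heads last layout: [1, 8, 32, 32, 4, 64]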
natten.attention
attention(
query,
key,
value,
scale=None,
backend=None,
q_tile_size=None,
kv_tile_size=None,
backward_q_tile_size=None,
backward_kv_tile_size=None,
backward_kv_splits=None,
backward_use_pt_reduction=False,
run_persistent_kernel=True,
kernel_schedule=None,
torch_compile=False,
return_lse=False,
)
Runs standard dot product attention.
This operation is used to implement neighborhood cross-attention, in which every token is allowed to interact with some additional context (the additional_keys and additional_values tensors in na1d, na2d, and na3d). It is also used as a fast path for cases where neighborhood attention is equivalent to self attention (not causal along any dims, and kernel_size is equal to the number of input tokens).
This operation does not call into PyTorch's SDPA, and only runs one of the NATTEN backends (cutlass-fmha, hopper-fmha, blackwell-fmha, flex-fmha). This allows controlling performance-related arguments, returning logsumexp, and more. For more information refer to backends.
Parameters:

query (Tensor): 4-D query tensor, with the heads last layout ([batch, seqlen, heads, head_dim]).

key (Tensor): 4-D key tensor, with the heads last layout ([batch, seqlen_kv, heads, head_dim]).

value (Tensor): 4-D value tensor, with the heads last layout ([batch, seqlen_kv, heads, head_dim_v]).

scale (float): Attention scale. Defaults to head_dim ** -0.5.
Other Parameters:

backend (str): Backend implementation to run with. Choices are: None (pick the best available one), "cutlass-fmha", "hopper-fmha", "blackwell-fmha", "flex-fmha". Refer to backends for more information.

q_tile_size (int): Tile size along the query sequence length in the forward pass kernel. You can use the profiler to find valid choices for your use case.

kv_tile_size (int): Tile size along the key/value sequence length in the forward pass kernel. You can use the profiler to find valid choices for your use case.

backward_q_tile_size (int): Tile size along the query sequence length in the backward pass kernel. This is only respected by the "cutlass-fmha" backend. You can use the profiler to find valid choices for your use case.

backward_kv_tile_size (int): Tile size along the key/value sequence length in the backward pass kernel. This is only respected by the "cutlass-fmha" backend. You can use the profiler to find valid choices for your use case.

backward_kv_splits (int): Number of key/value tiles allowed to work in parallel in the backward pass kernel. This is only respected by the "cutlass-fmha" backend, and only when KV parallelism is enabled.

backward_use_pt_reduction (bool): Whether to use PyTorch eager mode for computing the dO * O product required by the backward pass, instead of the CUTLASS kernel. This only applies to the "cutlass-fmha" backend.

run_persistent_kernel (bool): Whether to use persistent tile scheduling in the forward pass kernel. This only applies to the "blackwell-fmha" backend.

kernel_schedule (Optional[str]): Kernel type (Hopper architecture only). Choices are: None (pick the default), "non" (non-persistent), "coop" (warp-specialized cooperative), or "pp" (warp-specialized ping-ponging). Refer to the Hopper FMHA/FNA backend for more information.

torch_compile (bool): Applies only to the "flex-fmha" backend. Whether or not to JIT-compile the attention kernel. Because this is an experimental feature in PyTorch, we do not recommend it, and it is guarded by context flags. Read more in Flex Attention + torch.compile.

return_lse (bool): Whether or not to return the logsumexp tensor. logsumexp can be used in the backward pass, and for attention merging.
Returns:

output (Tensor): 4-D output tensor, with the heads last layout ([batch, seqlen, heads, head_dim_v]).

logsumexp (Tensor): Only returned when return_lse=True. 3-D logsumexp tensor, with the heads last layout ([batch, seqlen, heads]).
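Example (a minimal usage sketch; the shapes, dtype, and device below are illustrative assumptions, not requirements):

    import torch
    from natten import attention

    # Illustrative shapes: 1024 query tokens attending to 2048 key/value tokens.
    q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 2048, 8, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 2048, 8, 64, device="cuda", dtype=torch.float16)

    # Standard dot product attention; also return logsumexp for later attention merging.
    out, lse = attention(q, k, v, return_lse=True)
    # out: [2, 1024, 8, 64], lse: [2, 1024, 8]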
natten.merge_attentions
merge_attentions(outputs, lse_tensors, torch_compile=True)
Takes multiple attention outputs originating from the same query tensor, and their corresponding logsumexps, and merges them as if their contexts (key/value pairs) had been concatenated.
This operation is used to implement neighborhood cross-attention, and can also be used in distributed setups, such as context parallelism.
It also attempts to use torch.compile to fuse the elementwise operations. This can be disabled by passing torch_compile=False.
Parameters:

outputs (List[Tensor]): List of 4-D attention output tensors, with the heads last layout ([batch, seqlen, heads, head_dim]).

lse_tensors (List[Tensor]): List of 3-D logsumexp tensors, with the heads last layout ([batch, seqlen, heads]).

torch_compile (bool): Attempt to use torch.compile to fuse the underlying elementwise operations. Defaults to True.
Returns:

output (Tensor): Merged attention output.
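Example (a minimal sketch of merging two attention outputs over the same queries; the shapes, dtype, and device below are illustrative assumptions, not requirements):

    import torch
    from natten import attention, merge_attentions

    q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
    # Two separate key/value contexts for the same queries (illustrative lengths).
    k1 = torch.randn(2, 4096, 8, 64, device="cuda", dtype=torch.float16)
    v1 = torch.randn(2, 4096, 8, 64, device="cuda", dtype=torch.float16)
    k2 = torch.randn(2, 512, 8, 64, device="cuda", dtype=torch.float16)
    v2 = torch.randn(2, 512, 8, 64, device="cuda", dtype=torch.float16)

    # Attend to each context separately, keeping logsumexp ...
    out1, lse1 = attention(q, k1, v1, return_lse=True)
    out2, lse2 = attention(q, k2, v2, return_lse=True)

    # ... then merge, as if q had attended to the concatenation of the two contexts.
    out = merge_attentions([out1, out2], [lse1, lse2])
    # out: [2, 1024, 8, 64]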