Reduces all values from the src tensor into out within the ranges specified in the indptr tensor along the last dimension of indptr. For each value in src, its output index is specified by its index in src for dimensions outside of indptr.dim() - 1 and by the corresponding range index in indptr for dimension indptr.dim() - 1. The applied reduction is defined via the reduce argument.
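The pure-Python sketch below applies each of the four reductions to a one-dimensional list, as a reference for these semantics. It is not torch_scatter's implementation. The function name segment_csr_1d is hypothetical, and returning 0.0 for an empty segment is an assumption of this sketch.

```python
from typing import List, Sequence


def segment_csr_1d(src: Sequence[float], indptr: Sequence[int],
                   reduce: str = "sum") -> List[float]:
    """Hypothetical pure-Python sketch of a 1-D CSR segment reduction.

    Segment i covers src[indptr[i]:indptr[i + 1]], so the output has
    len(indptr) - 1 entries. Empty segments yield 0.0 (assumption).
    """
    if reduce not in ("sum", "mean", "min", "max"):
        raise ValueError(f"unsupported reduce: {reduce!r}")
    out: List[float] = []
    for start, end in zip(indptr[:-1], indptr[1:]):
        segment = src[start:end]  # each segment is a contiguous slice
        if not segment:
            out.append(0.0)  # assumed fill value for an empty range
        elif reduce == "sum":
            out.append(float(sum(segment)))
        elif reduce == "mean":
            out.append(sum(segment) / len(segment))
        elif reduce == "min":
            out.append(float(min(segment)))
        else:
            out.append(float(max(segment)))
    return out


src = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
indptr = [0, 2, 5, 6]
print(segment_csr_1d(src, indptr, "sum"))   # [3.0, 12.0, 6.0]
print(segment_csr_1d(src, indptr, "mean"))  # [1.5, 4.0, 6.0]
print(segment_csr_1d(src, indptr, "max"))   # [2.0, 5.0, 6.0]
```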
Formally, if src and indptr are \(n\)-dimensional and \(m\)-dimensional tensors with size \((x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})\) and \((x_0, ..., x_{m-2}, y)\), respectively, then out must be an \(n\)-dimensional tensor with size \((x_0, ..., x_{m-2}, y - 1, x_{m}, ..., x_{n-1})\). Moreover, the values of indptr must be between \(0\) and \(x_{m-1}\) in ascending order. The indptr tensor supports broadcasting when its leading dimensions do not match those of src.
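The shape rule can be checked mechanically. The sketch below computes the expected size of out from the two input shapes. It also tests whether a row of pointers is valid for a segmented dimension of a given size. It assumes standard broadcasting, where each leading dimension of indptr is either 1 or equal to the matching dimension of src. Both helper names are hypothetical.

```python
from typing import Sequence, Tuple


def segment_csr_out_shape(src_shape: Sequence[int],
                          indptr_shape: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: expected shape of `out`.

    indptr has m dims; its last dim (size y) points into src dim m - 1,
    which collapses to y - 1 segments. Leading dims of indptr must be 1
    or match src (assumed broadcasting rule).
    """
    m, n = len(indptr_shape), len(src_shape)
    if not 1 <= m <= n:
        raise ValueError("need 1 <= indptr.dim() <= src.dim()")
    for d in range(m - 1):
        if indptr_shape[d] not in (1, src_shape[d]):
            raise ValueError(f"dim {d}: cannot broadcast "
                             f"{indptr_shape[d]} to {src_shape[d]}")
    y = indptr_shape[-1]
    return tuple(src_shape[:m - 1]) + (y - 1,) + tuple(src_shape[m:])


def check_indptr_row(row: Sequence[int], size: int) -> bool:
    """Hypothetical helper: pointers lie in [0, size] and never decrease."""
    in_range = all(0 <= p <= size for p in row)
    ascending = all(a <= b for a, b in zip(row, row[1:]))
    return in_range and ascending


print(segment_csr_out_shape((10, 6, 64), (1, 4)))  # (10, 3, 64)
print(check_indptr_row([0, 2, 5, 6], 6))            # True
```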
For one-dimensional tensors with reduce="sum", the operation computes
\[\mathrm{out}_i = \sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+1]-1}~\mathrm{src}_j.\]
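Take indptr = [0, 2, 5, 6] and a one-dimensional src of length 6. The formula then yields three outputs, one per pair of consecutive pointers:

\[\mathrm{out}_0 = \mathrm{src}_0 + \mathrm{src}_1, \quad \mathrm{out}_1 = \mathrm{src}_2 + \mathrm{src}_3 + \mathrm{src}_4, \quad \mathrm{out}_2 = \mathrm{src}_5.\]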
Due to the use of index pointers, segment_csr() is the fastest method for applying grouped reductions.
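Index pointers describe each group as one contiguous range, given by two numbers, rather than as one group id per element. The sketch below shows how a sorted per-element group index converts to indptr by counting group sizes and taking an exclusive prefix sum. The helper index_to_indptr is hypothetical and is not a torch_scatter function. It assumes the index is sorted and that its values lie in [0, num_segments).

```python
from itertools import accumulate
from typing import List, Sequence


def index_to_indptr(index: Sequence[int], num_segments: int) -> List[int]:
    """Hypothetical helper: sorted group ids -> CSR index pointers.

    The result has num_segments + 1 entries; group i spans
    [indptr[i], indptr[i + 1]) in the sorted element order.
    """
    if any(a > b for a, b in zip(index, index[1:])):
        raise ValueError("index must be sorted in ascending order")
    counts = [0] * num_segments
    for g in index:
        if not 0 <= g < num_segments:
            raise ValueError(f"group id {g} out of range")
        counts[g] += 1  # number of elements in each group
    return [0] + list(accumulate(counts))  # exclusive prefix sum


print(index_to_indptr([0, 0, 1, 1, 1, 2], 3))  # [0, 2, 5, 6]
```

Once indptr exists, segment i can be reduced by reading src[indptr[i]:indptr[i + 1]] directly, without looking up a per-element index. This illustrates the intuition behind the speed claim. It is not a benchmark.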
Parameters:
src – The source tensor.
indptr – The index pointers between elements to segment. The number of dimensions of indptr needs to be less than or equal to that of src.
out – The destination tensor.
reduce – The reduce operation ("sum", "mean", "min" or "max"). (default: "sum")
Return type: Tensor
```python
import torch
from torch_scatter import segment_csr

src = torch.randn(10, 6, 64)
indptr = torch.tensor([0, 2, 5, 6])
indptr = indptr.view(1, -1)  # Broadcasting in the first and last dim.

out = segment_csr(src, indptr, reduce="sum")

print(out.size())
# torch.Size([10, 3, 64])
```
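The broadcast example above can be reproduced in pure Python at a smaller scale. The sketch below uses nested lists, a src of shape (2, 6, 2), and the same pointers [0, 2, 5, 6]. Dimensions 0 and 2 keep their index, and dimension 1 is summed within each range, so the result has shape (2, 3, 2). The function name segment_sum_dim1 is hypothetical, and the sketch covers only reduce="sum".

```python
from typing import List, Sequence

Tensor3 = List[List[List[float]]]


def segment_sum_dim1(src: Tensor3, indptr: Sequence[int]) -> Tensor3:
    """Hypothetical sketch: out[b][i][c] = sum of src[b][r][c] for r in range i.

    Mirrors a 1-D indptr broadcast over dim 0 (batch) and dim 2 (channels).
    """
    out: Tensor3 = []
    for rows in src:  # dim 0 keeps its index
        num_channels = len(rows[0]) if rows else 0
        batch_out = []
        for start, end in zip(indptr[:-1], indptr[1:]):  # dim 1 is segmented
            batch_out.append([
                sum((rows[r][c] for r in range(start, end)), 0.0)  # dim 2 kept
                for c in range(num_channels)
            ])
        out.append(batch_out)
    return out


# src[b][r][c] = 10 * b + r, the same value in both channels c
src = [[[float(10 * b + r)] * 2 for r in range(6)] for b in range(2)]
out = segment_sum_dim1(src, [0, 2, 5, 6])
print(len(out), len(out[0]), len(out[0][0]))  # 2 3 2
print(out[0])  # [[1.0, 1.0], [9.0, 9.0], [5.0, 5.0]]
```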