Reduces all values from the src tensor into out at the indices specified in the index tensor along the last dimension of index. For each value in src, its output index is specified by its index in src for dimensions outside of index.dim() - 1 and by the corresponding value in index for dimension index.dim() - 1. The applied reduction is defined via the reduce argument.
Formally, if src and index are \(n\)-dimensional and \((m+1)\)-dimensional tensors with size \((x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})\) and \((x_0, ..., x_{m-1}, x_m)\), respectively, then out must be an \(n\)-dimensional tensor with size \((x_0, ..., x_{m-1}, y, x_{m+1}, ..., x_{n-1})\). Moreover, the values of index must be between \(0\) and \(y - 1\) and sorted in ascending order. The index tensor supports broadcasting in case its dimensions do not match those of src.
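The shape rule above can be sketched in plain Python. The helper name segment_coo_out_shape is hypothetical and not part of torch_scatter, and the compatibility check (index sizes must equal the src size or be 1 where broadcast) is an assumption about how broadcasting is applied, not something this page states.

```python
def segment_coo_out_shape(src_shape, index_shape, dim_size):
    """Hypothetical helper: shape of out for the given src and index shapes."""
    if len(index_shape) > len(src_shape):
        raise ValueError("index must not have more dimensions than src")
    # The reduced dimension is the last dimension of index: index.dim() - 1.
    dim = len(index_shape) - 1
    # Assumption: index either matches src or has size 1 where it is broadcast.
    for d, (i, s) in enumerate(zip(index_shape, src_shape)):
        if i != s and i != 1:
            raise ValueError(f"index size {i} does not fit src size {s} at dim {d}")
    out_shape = list(src_shape)
    out_shape[dim] = dim_size  # x_m is replaced by y
    return tuple(out_shape)


# Shapes taken from this page's example: src (10, 6, 64), index viewed as (1, 6),
# and three distinct segment ids.
print(segment_coo_out_shape((10, 6, 64), (1, 6), 3))  # (10, 3, 64)
```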
For one-dimensional tensors with reduce="sum", the operation computes
\[\mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j\]
where \(\sum_j\) is over \(j\) such that \(\mathrm{index}_j = i\).
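As a minimal sketch of this formula, the loop below accumulates each src value into the out slot named by its index. It uses plain Python lists rather than tensors, and segment_sum_1d is a hypothetical name used only for illustration.

```python
def segment_sum_1d(src, index, out):
    """Reference loop for out_i = out_i + sum of src_j over j with index_j == i."""
    for value, i in zip(src, index):
        out[i] += value
    return out


src = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
index = [0, 0, 1, 1, 1, 2]  # sorted in ascending order, as the operation expects
out = segment_sum_1d(src, index, [0.0, 0.0, 0.0])
print(out)  # [3.0, 12.0, 6.0]
```

Because the formula adds to the existing out values, passing a non-zero out shifts each result by its initial value.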
In contrast to scatter(), this method expects values in index to be sorted along dimension index.dim() - 1. Due to the use of sorted indices, segment_coo() is usually faster than the more general scatter() operation.
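One way to see why sorted indices help: equal index values form contiguous runs, so each output slot can be reduced in one pass and written once. The sketch below illustrates that idea in plain Python with itertools.groupby. It is not the library's kernel, and segment_sum_sorted is a hypothetical name.

```python
from itertools import groupby


def segment_sum_sorted(src, index, dim_size):
    """Sum contiguous runs of equal index values; assumes index is sorted."""
    if any(a > b for a, b in zip(index, index[1:])):
        raise ValueError("index must be sorted in ascending order")
    out = [0.0] * dim_size
    pos = 0
    for i, run in groupby(index):
        length = len(list(run))
        # The whole segment is a contiguous slice of src, written exactly once.
        out[i] = sum(src[pos:pos + length])
        pos += length
    return out


print(segment_sum_sorted([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [0, 0, 1, 1, 1, 2], 3))
# [3.0, 12.0, 6.0]
```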
Note
This operation is implemented via atomic operations on the GPU and is therefore non-deterministic since the order of parallel operations to the same value is undetermined. For floating-point variables, this results in a source of variance in the result.
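The variance mentioned in the note comes from floating-point addition not being associative: the same values summed in a different order can round differently. A small plain-Python illustration, independent of any GPU or of torch_scatter:

```python
a, b, c = 0.1, 0.2, 0.3

left = (a + b) + c   # one possible accumulation order
right = a + (b + c)  # another order over the same values

print(left == right)  # False
print(left, right)
```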
src – The source tensor.
index – The sorted indices of elements to segment. The number of dimensions of index needs to be less than or equal to src.
out – The destination tensor.
dim_size – If out is not given, automatically create output with size dim_size at dimension index.dim() - 1. If dim_size is not given, a minimal sized output tensor according to index.max() + 1 is returned.
reduce – The reduce operation ("sum", "mean", "min" or "max"). (default: "sum")
Return type: Tensor
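To make the reduce and dim_size parameters concrete, here is a one-dimensional sketch on plain Python lists. The function segment_reduce_1d is hypothetical, it ignores the out argument, and filling empty segments with 0.0 is an assumption of this sketch rather than documented behavior.

```python
def segment_reduce_1d(src, index, dim_size=None, reduce="sum"):
    """Hypothetical 1-D reference for the reduce options listed above."""
    ops = {
        "sum": sum,
        "mean": lambda group: sum(group) / len(group),
        "min": min,
        "max": max,
    }
    if reduce not in ops:
        raise ValueError(f"unknown reduce operation: {reduce!r}")
    if dim_size is None:
        dim_size = max(index) + 1  # minimal size, as with index.max() + 1
    groups = [[] for _ in range(dim_size)]
    for value, i in zip(src, index):
        groups[i].append(value)
    # Assumption: segments that receive no values are filled with 0.0.
    return [ops[reduce](group) if group else 0.0 for group in groups]


src = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
index = [0, 0, 1, 1, 1, 2]
print(segment_reduce_1d(src, index, reduce="sum"))   # [3.0, 12.0, 6.0]
print(segment_reduce_1d(src, index, reduce="mean"))  # [1.5, 4.0, 6.0]
print(segment_reduce_1d(src, index, reduce="min"))   # [1.0, 3.0, 6.0]
print(segment_reduce_1d(src, index, reduce="max"))   # [2.0, 5.0, 6.0]
print(segment_reduce_1d(src, index, dim_size=5))     # [3.0, 12.0, 6.0, 0.0, 0.0]
```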
import torch
from torch_scatter import segment_coo

src = torch.randn(10, 6, 64)
index = torch.tensor([0, 0, 1, 1, 1, 2])
index = index.view(1, -1)  # Broadcasting in the first and last dim.

out = segment_coo(src, index, reduce="sum")

print(out.size())  # torch.Size([10, 3, 64])