Reduces all values from the src tensor into out at the indices specified in the index tensor along a given axis dim. For each value in src, its output index is specified by its index in src for dimensions outside of dim and by the corresponding value in index for dimension dim. The applied reduction is defined via the reduce argument.
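The following is a small pure-Python sketch of this index mapping for the two-dimensional case. It is not torch_scatter's implementation. scatter_2d is a hypothetical helper that works on nested lists, and filling output cells that receive no values with 0 is an assumption of this sketch. Each source value keeps its coordinate outside dim and takes its coordinate along dim from index. Values that land in the same output cell are then combined with the chosen reduction.

```python
import math
from statistics import fmean


def scatter_2d(src, index, dim, dim_size, reduce="sum"):
    """Hypothetical reference scatter-reduce over 2-D nested lists (a sketch, not the library code).

    src and index must have the same shape. The output has the same shape as
    src except along `dim`, where its length is `dim_size`.
    """
    reducers = {"sum": sum, "mul": math.prod, "mean": fmean, "min": min, "max": max}
    rows, cols = len(src), len(src[0])
    out_rows, out_cols = (dim_size, cols) if dim == 0 else (rows, dim_size)

    # Group every source value by the output cell it is sent to.
    buckets = {}
    for r in range(rows):
        for c in range(cols):
            # Keep the coordinate outside `dim`; take the one along `dim` from index.
            target = (index[r][c], c) if dim == 0 else (r, index[r][c])
            buckets.setdefault(target, []).append(src[r][c])

    # Reduce each group. Cells that receive no values are filled with 0 (an assumption of this sketch).
    return [
        [reducers[reduce](buckets[(i, j)]) if (i, j) in buckets else 0 for j in range(out_cols)]
        for i in range(out_rows)
    ]


src = [[1, 2, 3, 4], [5, 6, 7, 8]]
index = [[0, 1, 0, 1], [1, 1, 0, 0]]
print(scatter_2d(src, index, dim=1, dim_size=2, reduce="sum"))  # [[4, 6], [15, 11]]
print(scatter_2d(src, index, dim=1, dim_size=2, reduce="max"))  # [[3, 4], [8, 6]]
```

In the first row, positions 0 and 2 both map to output column 0, so "sum" gives 1 + 3 = 4 and "max" gives 3.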
Formally, if src and index are \(n\)-dimensional tensors with size \((x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})\) and dim = \(i\), then out must be an \(n\)-dimensional tensor with size \((x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})\). Moreover, the values of index must be between \(0\) and \(y - 1\), although no specific ordering of indices is required. The index tensor supports broadcasting in case its dimensions do not match those of src.
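The shape rule can be checked with a short sketch on plain shape tuples. scatter_out_shape is a hypothetical helper. It takes \(y\) from dim_size when that is given and otherwise uses index.max() + 1, as described for the dim_size parameter. The explicit range check is this sketch's own addition, since the text above only states the requirement. For the src and index of the example at the end of this section, it gives (10, 3, 64).

```python
def scatter_out_shape(src_shape, index_values, dim=-1, dim_size=None):
    """Hypothetical sketch: out matches src except along `dim`, where its size is y."""
    n = len(src_shape)
    dim = dim % n  # normalize a negative axis such as the default -1
    if dim_size is None:
        # Minimal size that can hold every index, i.e. index.max() + 1.
        dim_size = max(index_values) + 1
    # Every index must lie in [0, y - 1]; raising here is this sketch's choice.
    bad = [i for i in index_values if not 0 <= i < dim_size]
    if bad:
        raise ValueError(f"indices out of range [0, {dim_size - 1}]: {bad}")
    shape = list(src_shape)
    shape[dim] = dim_size
    return tuple(shape)


print(scatter_out_shape((10, 6, 64), [0, 1, 0, 1, 2, 1], dim=1))               # (10, 3, 64)
print(scatter_out_shape((10, 6, 64), [0, 1, 0, 1, 2, 1], dim=1, dim_size=5))   # (10, 5, 64)
```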
For one-dimensional tensors with reduce="sum", the operation computes
\[\mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j\]
where \(\sum_j\) is over \(j\) such that \(\mathrm{index}_j = i\).
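Because the formula adds into the existing \(\mathrm{out}_i\), values already present in out are kept and accumulated. The hypothetical scatter_sum_1d below is a plain-Python sketch of this one-dimensional formula. It returns a new list instead of writing into out in place, which is a simplification of this sketch.

```python
def scatter_sum_1d(src, index, out):
    """Hypothetical sketch of out_i = out_i + sum over j with index_j == i of src_j."""
    result = list(out)  # copy so the caller's list is left unchanged
    for j, i in enumerate(index):
        result[i] += src[j]  # src_j is added to the slot named by index_j
    return result


print(scatter_sum_1d([1, 2, 3, 4], [0, 1, 0, 2], [0, 0, 0]))   # [4, 2, 4]
print(scatter_sum_1d([1, 2, 3, 4], [0, 1, 0, 2], [10, 0, 0]))  # [14, 2, 4]
```

Slot 0 receives src_0 and src_2 (1 + 3 = 4). With a starting value of 10 in out_0, the result becomes 14.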
Note
This operation is implemented via atomic operations on the GPU and is therefore non-deterministic, since the order in which parallel operations update the same output value is undetermined. For floating-point values, this introduces run-to-run variance in the result.
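The variance arises because floating-point addition is not associative: summing the same values in a different order can round to a different result. This pure-Python sketch does not touch the GPU. sums_in_every_order is a hypothetical helper that adds the same floats in every possible order, standing in for atomic updates whose order is undetermined.

```python
from itertools import permutations


def sums_in_every_order(values):
    """Hypothetical helper: collect the float sum of `values` for every addition order."""
    results = set()
    for order in permutations(values):
        total = 0.0
        for v in order:
            total += v  # each step rounds to the nearest representable float
        results.add(total)
    return results


print(sums_in_every_order([0.1, 0.2, 0.3]))  # more than one distinct value
print(sums_in_every_order([1.0, 2.0, 4.0]))  # {7.0}
```

Small integer-valued floats such as 1.0, 2.0 and 4.0 add exactly in any order. It is the rounding of inexact values like 0.1 that makes the order matter.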
src – The source tensor.
index – The indices of elements to scatter.
dim – The axis along which to index. (default: -1)
out – The destination tensor.
dim_size – If out is not given, an output tensor with size dim_size at dimension dim is created automatically. If dim_size is not given either, a minimally sized output tensor with size index.max() + 1 at dimension dim is returned.
reduce – The reduce operation ("sum", "mul", "mean", "min" or "max"). (default: "sum")
Return type: Tensor
```python
import torch
from torch_scatter import scatter

src = torch.randn(10, 6, 64)
index = torch.tensor([0, 1, 0, 1, 2, 1])

# Broadcasting in the first and last dim.
out = scatter(src, index, dim=1, reduce="sum")

print(out.size())
# torch.Size([10, 3, 64])
```