Selects values from input at the 1-dimensional indices from indices along the given dim.
If dim is None, the input array is treated as if it has been flattened to 1d.
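The flattening behavior can be illustrated with a small sketch. The tensor values and the names `idx`, `picked`, and `by_hand` are illustrative assumptions, not part of the original documentation; the comparison against indexing a flattened copy is just one way to check the result.

```python
import torch

t = torch.tensor([[10, 30, 20], [60, 40, 50]])

# With dim omitted (None), the indices address the flattened view of t,
# i.e. positions in [10, 30, 20, 60, 40, 50].
idx = torch.tensor([0, 5, 3])
picked = torch.take_along_dim(t, idx)

# Indexing an explicitly flattened copy should select the same elements.
by_hand = t.flatten()[idx]

print(picked)   # values at flat positions 0, 5 and 3
print(by_hand)
```

Here both results hold the elements at flat positions 0, 5, and 3, so the two tensors should compare equal.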
Functions that return indices along a dimension, like torch.argmax() and torch.argsort(), are designed to work with this function. See the examples below.
>>> t = torch.tensor([[10, 30, 20], [60, 40, 50]])
>>> max_idx = torch.argmax(t)
>>> torch.take_along_dim(t, max_idx)
tensor([60])
>>> sorted_idx = torch.argsort(t, dim=1)
>>> torch.take_along_dim(t, sorted_idx, dim=1)
tensor([[10, 20, 30],
        [40, 50, 60]])
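As a further sketch of pairing index-returning functions with this one, the snippet below takes per-row maxima and a descending sort. It assumes `keepdim=True` is passed to torch.argmax() so the index tensor keeps the same number of dimensions as the input; the variable names are illustrative.

```python
import torch

t = torch.tensor([[10, 30, 20], [60, 40, 50]])

# Per-row argmax; keepdim=True keeps the reduced dimension, giving shape (2, 1),
# so the indices line up with t along dim=1.
max_idx = torch.argmax(t, dim=1, keepdim=True)
row_max = torch.take_along_dim(t, max_idx, dim=1)

# Descending argsort along each row, then gather the values in that order.
desc_idx = torch.argsort(t, dim=1, descending=True)
row_desc = torch.take_along_dim(t, desc_idx, dim=1)

print(row_max)    # one maximum per row, shape (2, 1)
print(row_desc)   # each row sorted from largest to smallest
```

The results should match what torch.max() and torch.sort() return as values along the same dimension; going through the indices is useful when the same ordering has to be applied to a second tensor of the same shape.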