Splits a tensor into multiple sub-tensors, all of which are views of input, along dimension dim according to the indices or number of sections specified by indices_or_sections. This function is based on NumPy's numpy.array_split().
>>> x = torch.arange(8)
>>> torch.tensor_split(x, 3)
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7]))

>>> x = torch.arange(7)
>>> torch.tensor_split(x, 3)
(tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))
>>> torch.tensor_split(x, (1, 6))
(tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6]))

>>> x = torch.arange(14).reshape(2, 7)
>>> x
tensor([[ 0,  1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12, 13]])
>>> torch.tensor_split(x, 3, dim=1)
(tensor([[0, 1, 2],
        [7, 8, 9]]),
 tensor([[ 3,  4],
        [10, 11]]),
 tensor([[ 5,  6],
        [12, 13]]))
>>> torch.tensor_split(x, (1, 6), dim=1)
(tensor([[0],
        [7]]),
 tensor([[ 1,  2,  3,  4,  5],
        [ 8,  9, 10, 11, 12]]),
 tensor([[ 6],
        [13]]))
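The piece sizes in the examples above follow a simple rule, sketched here in plain Python so it runs without PyTorch. The helper names `section_sizes` and `index_bounds` are hypothetical and are not part of the torch API. The sketch assumes the `numpy.array_split()` convention: when the length is not divisible by the number of sections, the leading pieces each get one extra element. It also assumes the indices are non-negative, increasing, and within the dimension's length, and it does not model out-of-range indices.

```python
def section_sizes(length, sections):
    """Sizes of the pieces when `length` items are split into `sections` parts."""
    base, extra = divmod(length, sections)
    # The first `extra` pieces each receive one additional element.
    return [base + 1] * extra + [base] * (sections - extra)


def index_bounds(length, indices):
    """(start, stop) pairs obtained by cutting at each index in `indices`."""
    edges = [0, *indices, length]
    return list(zip(edges[:-1], edges[1:]))


print(section_sizes(8, 3))      # [3, 3, 2]
print(section_sizes(7, 3))      # [3, 2, 2]
print(index_bounds(7, (1, 6)))  # [(0, 1), (1, 6), (6, 7)]
```

These results match the shapes shown in the examples: `torch.arange(7)` split into 3 sections gives pieces of length 3, 2, and 2, and splitting at indices `(1, 6)` gives the slices `x[0:1]`, `x[1:6]`, and `x[6:7]`.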