Return the indices of the lower triangle of an array of size (n, m).
JAX implementation of numpy.tril_indices().
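Because this function mirrors numpy.tril_indices(), one quick way to sanity-check its behavior is to compare both libraries on the same inputs. The sketch below assumes both NumPy and JAX are installed; `matches_numpy` is a hypothetical helper written for illustration, not part of either library.

```python
import numpy as np
import jax.numpy as jnp

def matches_numpy(n, k=0, m=None):
    """Hypothetical helper: True if jnp.tril_indices and np.tril_indices agree."""
    j_rows, j_cols = jnp.tril_indices(n, k=k, m=m)
    n_rows, n_cols = np.tril_indices(n, k=k, m=m)
    # Compare values only: JAX returns int32 by default, NumPy usually int64.
    return (np.array_equal(np.asarray(j_rows), n_rows)
            and np.array_equal(np.asarray(j_cols), n_cols))

# Shapes and offsets mirroring the examples below, plus a wide (2, 4) case.
for n, k, m in [(3, 0, None), (3, 0, 2), (3, 1, None), (3, -1, None), (2, 0, 4)]:
    print(n, k, m, matches_numpy(n, k=k, m=m))
```

Only values are compared, since the default integer dtypes of the two libraries can differ.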
n (int) – Number of rows of the array for which the indices are returned.
k (int) – optional, int, default=0. Specifies the diagonal on and below which the indices of the lower triangle are returned. k=0 refers to the main diagonal, k<0 refers to a sub-diagonal below the main diagonal, and k>0 refers to a diagonal above the main diagonal.
m (int | None) – optional, int. Number of columns of the array for which the indices are returned. If not specified, then m = n.
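Taken together, the parameters select every position (i, j) of an (n, m) array with j <= i + k. A minimal sketch of that rule, assuming it is expressed through `jnp.tri` (which is 1 on and below the k-th diagonal) followed by `jnp.nonzero`; `tril_indices_from_mask` is a hypothetical name used only for illustration and is not how the library is claimed to implement it.

```python
import jax.numpy as jnp

def tril_indices_from_mask(n, k=0, m=None):
    """Hypothetical helper: lower-triangle indices recovered from a 0/1 mask.

    jnp.tri(n, m, k) holds 1 wherever col <= row + k and 0 elsewhere,
    so its nonzero positions are exactly the requested entries.
    """
    mask = jnp.tri(n, m, k)
    # nonzero scans in row-major (C) order, the same ordering tril_indices uses.
    # Note: jnp.nonzero without a `size` argument cannot be used under jit.
    return jnp.nonzero(mask)

rows, cols = tril_indices_from_mask(3, k=-1)
print(rows, cols)
```

The mask form makes the role of each parameter visible: n and m fix the mask shape, and k shifts the boundary diagonal.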
A tuple of two arrays containing the indices of the lower triangle, one along each axis.
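The returned tuple is meant to be used directly as an advanced index. The sketch below (an illustration, with a hypothetical 3x3 input) gathers the lower-triangle entries and then writes to them; since JAX arrays are immutable, the write goes through the `.at[...].set()` update syntax, which returns a new array.

```python
import jax.numpy as jnp

x = jnp.arange(9).reshape(3, 3)   # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
rows, cols = jnp.tril_indices(3)

# Gather: the row and column arrays together pick out the lower triangle.
lower_values = x[rows, cols]       # [0, 3, 4, 6, 7, 8]

# Scatter: build a new array with the lower triangle (incl. diagonal) zeroed.
upper_only = x.at[rows, cols].set(0)
print(lower_values)
print(upper_only)
```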
Examples
If only n is provided, the indices of the lower triangle of an (n, n) array are returned.
>>> jnp.tril_indices(3)
(Array([0, 1, 1, 2, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1, 2], dtype=int32))
If both n and m are provided, the indices of the lower triangle of an (n, m) array are returned.
>>> jnp.tril_indices(3, m=2)
(Array([0, 1, 1, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1], dtype=int32))
If k = 1, the indices on and below the first diagonal above the main diagonal are returned.
>>> jnp.tril_indices(3, k=1)
(Array([0, 0, 1, 1, 1, 2, 2, 2], dtype=int32), Array([0, 1, 0, 1, 2, 0, 1, 2], dtype=int32))
If k = -1, the indices on and below the first sub-diagonal below the main diagonal are returned.
>>> jnp.tril_indices(3, k=-1)
(Array([1, 2, 2], dtype=int32), Array([0, 0, 1], dtype=int32))