Return an array with ones on and below the diagonal and zeros elsewhere.
JAX implementation of numpy.tri()
N (int) – Number of rows in the returned array.
M (int | None) – optional, int. Number of columns in the returned array. If not specified, M = N.
k (int) – optional, int, default=0. Selects the diagonal on and below which the array is filled with ones. k=0 refers to the main diagonal, k<0 to a diagonal below the main diagonal, and k>0 to a diagonal above it.
dtype (DTypeLike | None) – optional. Data type of the returned array. If not specified, the default floating-point type is used.
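The effect of k can be stated as an index rule: element (i, j) is one exactly when j <= i + k. The sketch below rebuilds the array from that rule with broadcasting and checks it against jnp.tri for the same arguments used in the examples. The helper name `tri_from_rule` is hypothetical and exists only for illustration; it is not part of the JAX API and is not how JAX implements the function internally.

```python
import jax.numpy as jnp

def tri_from_rule(N, M=None, k=0, dtype=jnp.float32):
    """Hypothetical helper: build tri(N, M, k) from the rule j <= i + k."""
    M = N if M is None else M
    rows = jnp.arange(N)[:, None]  # row indices i, shape (N, 1)
    cols = jnp.arange(M)[None, :]  # column indices j, shape (1, M)
    # Broadcasting the comparison gives an (N, M) boolean mask,
    # which is then cast to the requested dtype.
    return (cols <= rows + k).astype(dtype)

# Compare values against jnp.tri for several shapes and offsets.
for N, M, k in [(3, None, 0), (3, 4, 0), (3, None, 1), (3, 4, -1)]:
    assert jnp.array_equal(tri_from_rule(N, M, k), jnp.tri(N, M, k))
```

Comparing j against i + k in a single broadcasted expression avoids any Python loop over elements, so the same rule works for any N, M, and k.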
The result is an array of shape (N, M) in which the elements on and below the diagonal selected by k are set to one and all other elements are zero.
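Because the result is a 0/1 array, it can serve as a mask for the lower triangle of another matrix. The sketch below, assuming an arbitrary 4×4 example matrix `x`, shows that multiplying by jnp.tri keeps the same entries as jnp.tril, and that passing dtype=bool gives a mask usable with jnp.where. Filling the masked-out entries with -inf (as in a causal attention mask) is only an illustrative application, not something jnp.tri does itself.

```python
import jax.numpy as jnp

x = jnp.arange(1.0, 17.0).reshape(4, 4)  # example 4x4 matrix with entries 1..16

# Multiplying by the 0/1 array zeroes every entry above the main diagonal.
lower = x * jnp.tri(4, 4, k=0, dtype=x.dtype)
assert jnp.array_equal(lower, jnp.tril(x, k=0))

# A boolean mask works directly with jnp.where: entries above the
# diagonal are replaced with -inf, everything else is kept.
mask = jnp.tri(4, dtype=bool)
masked = jnp.where(mask, x, -jnp.inf)
```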
Examples
>>> import jax.numpy as jnp
>>> jnp.tri(3)
Array([[1., 0., 0.],
       [1., 1., 0.],
       [1., 1., 1.]], dtype=float32)
When M is not equal to N:
>>> jnp.tri(3, 4)
Array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.]], dtype=float32)
When k>0:
>>> jnp.tri(3, k=1)
Array([[1., 1., 0.],
       [1., 1., 1.],
       [1., 1., 1.]], dtype=float32)
When k<0:
>>> jnp.tri(3, 4, k=-1)
Array([[0., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 1., 0., 0.]], dtype=float32)