Return lower triangle of an array.
JAX implementation of numpy.tril()
Parameters
m (ArrayLike) – input array. Must have m.ndim >= 2.
k (int) – optional, default=0. Specifies the diagonal above which the elements of the array are set to zero. k=0 refers to the main diagonal, k<0 refers to a diagonal below the main diagonal, and k>0 refers to a diagonal above the main diagonal.
Returns
An array with the same shape as the input, containing the lower triangle of the given array; elements above the diagonal specified by k are set to zero.
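The rule for k above amounts to a single comparison: an element at position (i, j) of the last two axes is kept when j - i <= k and set to zero otherwise. The sketch below illustrates that rule with broadcasted index arrays and checks it against jnp.tril. tril_via_mask is a hypothetical helper written for illustration only; it is not how JAX itself implements tril.

```python
import jax.numpy as jnp


def tril_via_mask(m, k=0):
    """Hypothetical helper: zero out entries of m strictly above diagonal k.

    Element (i, j) of the trailing two axes is kept when j - i <= k.
    """
    m = jnp.asarray(m)
    rows = jnp.arange(m.shape[-2])[:, None]  # row index i, shape (N, 1)
    cols = jnp.arange(m.shape[-1])[None, :]  # column index j, shape (1, M)
    keep = (cols - rows) <= k                # boolean mask, shape (N, M)
    # Broadcasting applies the same (N, M) mask to every leading (batch) index.
    return jnp.where(keep, m, jnp.zeros_like(m))


# Same 3x4 input as in the Examples section.
x = jnp.arange(1, 13).reshape(3, 4)
for k in (-1, 0, 1):
    assert (tril_via_mask(x, k) == jnp.tril(x, k=k)).all()
```

A large negative k (here k=-3 for three rows) zeroes every element, and a large positive k (here k=3 for four columns) leaves the input unchanged.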
See also
jax.numpy.triu(): Returns an upper triangle of an array.
jax.numpy.tri(): Returns an array with ones on and below the diagonal and zeros elsewhere.
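Both functions listed above relate to jnp.tril through simple identities. Using only public jax.numpy calls, the sketch below checks that a boolean jnp.tri mask selects exactly the entries that jnp.tril keeps. It also checks that tril(x, k) and triu(x, k + 1) split x into two non-overlapping parts. This is an illustration of the relationship, not a description of JAX internals.

```python
import jax.numpy as jnp

x = jnp.arange(1, 13).reshape(3, 4)
k = 0

# jnp.tri(N, M, k) is True on and below diagonal k and False elsewhere, so
# using it as a boolean mask selects the same entries that jnp.tril keeps.
mask = jnp.tri(x.shape[0], x.shape[1], k=k, dtype=bool)
assert (jnp.where(mask, x, 0) == jnp.tril(x, k=k)).all()

# tril(x, k) keeps entries with j - i <= k, and triu(x, k + 1) keeps entries
# with j - i >= k + 1, so the two parts do not overlap and sum back to x.
assert (jnp.tril(x, k=k) + jnp.triu(x, k=k + 1) == x).all()
```

Selecting with jnp.where rather than multiplying by a 0/1 mask keeps inf or nan entries above the diagonal from turning into nan, since inf * 0 is nan.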
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3, 4],
...                [5, 6, 7, 8],
...                [9, 10, 11, 12]])
>>> jnp.tril(x)
Array([[ 1,  0,  0,  0],
       [ 5,  6,  0,  0],
       [ 9, 10, 11,  0]], dtype=int32)
>>> jnp.tril(x, k=1)
Array([[ 1,  2,  0,  0],
       [ 5,  6,  7,  0],
       [ 9, 10, 11, 12]], dtype=int32)
>>> jnp.tril(x, k=-1)
Array([[ 0,  0,  0,  0],
       [ 5,  0,  0,  0],
       [ 9, 10,  0,  0]], dtype=int32)
When m.ndim > 2, jnp.tril operates batch-wise on the trailing two axes.
>>> x1 = jnp.array([[[1, 2],
...                  [3, 4]],
...                 [[5, 6],
...                  [7, 8]]])
>>> jnp.tril(x1)
Array([[[1, 0],
        [3, 4]],
<BLANKLINE>
       [[5, 0],
        [7, 8]]], dtype=int32)
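Batch-wise here means that each matrix formed by the last two axes is processed independently with the same k. The sketch below makes that reading concrete by comparing jnp.tril on the stacked array with a Python loop over the leading axis and with jax.vmap. The variable names batched, looped and mapped are illustrative only.

```python
import jax
import jax.numpy as jnp

x1 = jnp.array([[[1, 2],
                 [3, 4]],
                [[5, 6],
                 [7, 8]]])

# jnp.tril applied to the whole stacked array...
batched = jnp.tril(x1)
# ...matches applying it to each 2-D matrix along the leading axis,
looped = jnp.stack([jnp.tril(x1[b]) for b in range(x1.shape[0])])
# ...and matches mapping it over the leading axis with jax.vmap.
mapped = jax.vmap(jnp.tril)(x1)
assert (batched == looped).all() and (batched == mapped).all()
```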