Return indices of a mask of an (n, n) array.
Parameters:
n (int) – static integer array dimension.
mask_func (Callable[[ArrayLike, int], Array]) – a function that takes a shape (n, n) array and an optional offset k, and returns a shape (n, n) mask. Examples of functions with this signature are triu() and tril().
k (int) – a scalar value passed to mask_func.
size (int | None) – optional argument specifying the static size of the output arrays. This is passed to nonzero() when generating the indices from the mask.
Returns:
a tuple of index arrays (rows, columns) locating the nonzero entries of the mask produced by mask_func.
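In other words, the result matches the nonzero entries of mask_func applied to an all-ones (n, n) array with offset k. The sketch below reproduces that behavior with jnp.nonzero(). It is modeled on NumPy's documented description of numpy.mask_indices rather than copied from JAX's source, and mask_indices_sketch is a hypothetical name used only for illustration.

```python
import jax.numpy as jnp

def mask_indices_sketch(n, mask_func, k=0, size=None):
    # Hypothetical helper: apply mask_func (with offset k) to an all-ones
    # (n, n) array, then return the (row, column) indices of nonzero entries.
    mask = mask_func(jnp.ones((n, n)), k)
    # size is forwarded to nonzero() so the output length can be fixed statically.
    return jnp.nonzero(mask, size=size)

# k=1 shifts the triu() mask one diagonal up, excluding the main diagonal.
rows, cols = mask_indices_sketch(3, jnp.triu, k=1)
print(rows.tolist(), cols.tolist())  # [0, 0, 1] [1, 2, 2]
```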
Examples
Calling mask_indices on built-in masking functions:
>>> jnp.mask_indices(3, jnp.triu)
(Array([0, 0, 0, 1, 1, 2], dtype=int32), Array([0, 1, 2, 1, 2, 2], dtype=int32))
>>> jnp.mask_indices(3, jnp.tril)
(Array([0, 1, 1, 2, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1, 2], dtype=int32))
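Because the result is a (rows, columns) tuple, it can be used directly to index a 2-D array. The sketch below combines this with the k offset to pull out the strictly lower triangle of a small matrix. The variable names are illustrative only.

```python
import jax.numpy as jnp

a = jnp.arange(9).reshape(3, 3)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

# k=-1 moves the tril() mask one diagonal down, so the main diagonal is excluded.
rows, cols = jnp.mask_indices(3, jnp.tril, -1)

# Advanced indexing with the index arrays gathers the selected entries.
lower = a[rows, cols]
print(lower.tolist())  # [3, 6, 7]
```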
Calling mask_indices on a custom masking function:
>>> def mask_func(x, k=0):
...     i = jnp.arange(x.shape[0])[:, None]
...     j = jnp.arange(x.shape[1])
...     return (i + 1) % (j + 1 + k) == 0
>>> mask_func(jnp.ones((3, 3)))
Array([[ True, False, False],
       [ True,  True, False],
       [ True, False,  True]], dtype=bool)
>>> jnp.mask_indices(3, mask_func)
(Array([0, 1, 1, 2, 2], dtype=int32), Array([0, 0, 1, 0, 2], dtype=int32))