Return the index of the maximum value of an array, ignoring NaNs.
JAX implementation of numpy.nanargmax().
a (ArrayLike) – input array
axis (int | None) – optional integer specifying the axis along which to find the maximum value. If axis is not specified, a will be flattened.
out (None) – unused by JAX
keepdims (bool | None) – if True, then return an array with the same number of dimensions as a.
Returns an array containing the index of the maximum value along the specified axis.
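As a minimal sketch of the flattening behavior described for axis: when axis is omitted, the returned index refers to the flattened array, and jnp.unravel_index can map it back to a position in the original shape. The array values and the variable names (flat_idx, row, col) are illustrative assumptions, not part of the original documentation.

```python
import jax.numpy as jnp

# Illustrative input: each row ends in a NaN.
x = jnp.array([[1.0, 3.0, jnp.nan],
               [5.0, 4.0, jnp.nan]])

# With axis omitted, the index refers to the flattened array.
flat_idx = jnp.nanargmax(x)

# Convert the flat index back to (row, column) coordinates.
row, col = jnp.unravel_index(flat_idx, x.shape)
```

Here the largest non-NaN value is 5.0, so the flat index and the unraveled coordinates both point at the first element of the second row.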
Note
In the case of an axis with all-NaN values, the returned index will be -1. This differs from the behavior of numpy.nanargmax(), which raises an error.
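Because -1 is also a valid (wrap-around) index, callers may want to check for it before gathering values. The following is one possible way to do that, not a recommended or official pattern; the names valid, safe_idx, and row_max are hypothetical.

```python
import jax.numpy as jnp

# Illustrative input: the second row contains only NaNs.
x = jnp.array([[1.0, jnp.nan, 3.0],
               [jnp.nan, jnp.nan, jnp.nan]])

idx = jnp.nanargmax(x, axis=1)

# Rows where nanargmax returned the -1 sentinel had no non-NaN entry.
valid = idx >= 0

# Substitute a harmless index for the sentinel, gather, then mask the result.
safe_idx = jnp.where(valid, idx, 0)
row_max = jnp.where(valid, x[jnp.arange(x.shape[0]), safe_idx], jnp.nan)
```

In this sketch the all-NaN row keeps NaN as its maximum instead of silently picking up the last element of the row.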
Examples
>>> x = jnp.array([1, 3, 5, 4, jnp.nan])
Using a standard argmax() leads to potentially unexpected results:
>>> jnp.argmax(x)
Array(4, dtype=int32)
Using nanargmax returns the index of the maximum non-NaN value.
>>> jnp.nanargmax(x)
Array(2, dtype=int32)
>>> x = jnp.array([[1, 3, jnp.nan],
...                [5, 4, jnp.nan]])
>>> jnp.nanargmax(x, axis=1)
Array([1, 0], dtype=int32)
>>> jnp.nanargmax(x, axis=1, keepdims=True)
Array([[1],
       [0]], dtype=int32)
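One reason keepdims=True can be convenient is that the result has the shape expected by jnp.take_along_axis, so the indices can be used directly to look up the corresponding values. This is a small sketch under that assumption; row_max is a hypothetical name.

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 3.0, jnp.nan],
               [5.0, 4.0, jnp.nan]])

# keepdims=True keeps the reduced axis, giving indices of shape (2, 1).
idx = jnp.nanargmax(x, axis=1, keepdims=True)

# Gather the value at each index along axis 1; the result also has shape (2, 1).
row_max = jnp.take_along_axis(x, idx, axis=1)
```

For inputs without all-NaN rows, row_max should match jnp.nanmax(x, axis=1, keepdims=True).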