
Source: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.nanargmax.html

jax.numpy.nanargmax — JAX documentation

jax.numpy.nanargmax
jax.numpy.nanargmax(a, axis=None, out=None, keepdims=None)

Return the index of the maximum value of an array, ignoring NaNs.

JAX implementation of numpy.nanargmax().

Parameters:
  • a (ArrayLike) – input array

  • axis (int | None) – optional integer specifying the axis along which to find the maximum value. If axis is not specified, a will be flattened.

  • out (None) – unused by JAX

  • keepdims (bool | None) – if True, the reduced axis is kept as a dimension of size 1, so the returned array has the same number of dimensions as a.
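When axis is left as None, the result is an index into the flattened array rather than a coordinate tuple. A minimal sketch, assuming a 2D floating-point input, of recovering the (row, column) position with jnp.unravel_index:

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 3.0, jnp.nan],
               [5.0, 4.0, jnp.nan]])

# With axis=None (the default), x is searched as if flattened to
# [1, 3, nan, 5, 4, nan], so the result is a flat index (here 3).
flat_idx = jnp.nanargmax(x)

# jnp.unravel_index converts the flat index back into (row, column) coordinates.
row, col = jnp.unravel_index(flat_idx, x.shape)
print(int(flat_idx), int(row), int(col))
```

The same pattern applies to any input shape, since unravel_index only needs the original array's shape.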

Returns:

An array containing the index of the maximum non-NaN value along the specified axis (or within the flattened array when axis is None).

Return type:

Array
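Because keepdims=True preserves the reduced axis with size 1, the returned indices line up with the input and can be passed to jnp.take_along_axis to gather the corresponding values. A short sketch, assuming a floating-point input with no all-NaN rows; such a row would produce the index -1, so it should be masked before relying on the gathered value:

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 3.0, jnp.nan],
               [5.0, 4.0, jnp.nan]])

# keepdims=True keeps axis 1 with size 1, so idx has shape (2, 1).
idx = jnp.nanargmax(x, axis=1, keepdims=True)

# Gather the element at each row's index; vals also has shape (2, 1)
# and holds the per-row maxima 3.0 and 5.0.
vals = jnp.take_along_axis(x, idx, axis=1)
```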

Note

In the case of an axis with all-NaN values, the returned index will be -1. This differs from the behavior of numpy.nanargmax(), which raises an error.
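The -1 sentinel can be observed directly, and because no exception is involved, the call also works inside jax.jit. The helper nanargmax_with_mask below is hypothetical and not part of JAX; it is a sketch that pairs each index with an explicit all-NaN flag computed with jnp.isnan and jnp.all, so callers do not have to test for -1 themselves:

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1.0, jnp.nan, 2.0],
               [jnp.nan, jnp.nan, jnp.nan]])

# Row 0 has its maximum (2.0) at index 2; row 1 is all NaN, so the
# result for that row is -1 rather than an error.
idx = jnp.nanargmax(x, axis=1)

# Hypothetical helper: return the indices together with a boolean mask
# marking the rows that contained only NaN values.
@jax.jit
def nanargmax_with_mask(a):
    indices = jnp.nanargmax(a, axis=1)
    all_nan = jnp.all(jnp.isnan(a), axis=1)
    return indices, all_nan

idx2, all_nan = nanargmax_with_mask(x)
```

Returning a mask alongside the indices is one design choice among several; it keeps the result usable in traced code, where raising a value-dependent error the way numpy.nanargmax() does is not possible.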

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([1, 3, 5, 4, jnp.nan])

The standard argmax() returns 4, the index of the NaN entry, which may be unexpected:

>>> jnp.argmax(x)
Array(4, dtype=int32)

nanargmax() instead ignores the NaN and returns the index of the maximum non-NaN value:

>>> jnp.nanargmax(x)
Array(2, dtype=int32)

With a 2D input, the search can be restricted to a single axis:

>>> x = jnp.array([[1, 3, jnp.nan],
...                [5, 4, jnp.nan]])
>>> jnp.nanargmax(x, axis=1)
Array([1, 0], dtype=int32)
>>> jnp.nanargmax(x, axis=1, keepdims=True)
Array([[1],
       [0]], dtype=int32)
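The results above can be approximated by masking NaNs before a plain argmax(). The function nanargmax_sketch below is a hypothetical, illustrative re-implementation, not JAX's actual source; it assumes floating-point input and omits the out and keepdims parameters:

```python
import jax.numpy as jnp

def nanargmax_sketch(a, axis=None):
    """Illustrative approximation of jnp.nanargmax (not JAX's implementation)."""
    a = jnp.asarray(a)
    nan_mask = jnp.isnan(a)
    # Replace NaNs with -inf so they can never win the argmax comparison.
    filled = jnp.where(nan_mask, -jnp.inf, a)
    idx = jnp.argmax(filled, axis=axis)
    # Slices made up entirely of NaNs get the sentinel value -1.
    all_nan = jnp.all(nan_mask, axis=axis)
    return jnp.where(all_nan, -1, idx)

x = jnp.array([[1.0, 3.0, jnp.nan],
               [5.0, 4.0, jnp.nan],
               [jnp.nan, jnp.nan, jnp.nan]])
print(nanargmax_sketch(x, axis=1))
print(jnp.nanargmax(x, axis=1))
```

Substituting -inf guarantees that any non-NaN value is preferred over a former NaN, and the final jnp.where restores the -1 convention described in the Note for slices that had no non-NaN values.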
