Return the index of the maximum value of an array.
JAX implementation of numpy.argmax().
Returns: an array containing the index of the maximum value along the specified axis.
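To make the returned index concrete, here is a minimal pure-Python sketch of the same semantics, assuming 1-D and row-wise (axis=1) inputs only. The helpers `argmax_1d` and `argmax_rows` are hypothetical and are not how JAX implements `argmax`. Ties resolve to the first occurrence, which is NumPy's documented convention. NaN handling and arbitrary axes are deliberately ignored.

```python
from typing import List, Sequence


def argmax_1d(values: Sequence[float]) -> int:
    """Hypothetical helper: index of the largest element, first occurrence on ties."""
    if len(values) == 0:
        raise ValueError("argmax of an empty sequence is undefined")
    best = 0
    for i in range(1, len(values)):
        # Strict '>' means a later equal value never replaces an earlier maximum.
        if values[i] > values[best]:
            best = i
    return best


def argmax_rows(matrix: Sequence[Sequence[float]]) -> List[int]:
    """Hypothetical helper: row-wise argmax, mirroring axis=1 on a 2-D array."""
    return [argmax_1d(row) for row in matrix]


print(argmax_1d([1, 3, 5, 4, 2]))           # 2
print(argmax_rows([[1, 3, 2], [5, 4, 1]]))  # [1, 0]
print(argmax_1d([2, 7, 7]))                 # 1 (first of the tied maxima)
```

These results match the `jnp.argmax` outputs in the examples that follow.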
Examples
>>> x = jnp.array([1, 3, 5, 4, 2])
>>> jnp.argmax(x)
Array(2, dtype=int32)

>>> x = jnp.array([[1, 3, 2],
...                [5, 4, 1]])
>>> jnp.argmax(x, axis=1)
Array([1, 0], dtype=int32)

>>> jnp.argmax(x, axis=1, keepdims=True)
Array([[1],
       [0]], dtype=int32)
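Two common follow-ups to the examples above are sketched below, assuming the same 2-D `x`. With no `axis` argument, `jnp.argmax` searches the flattened array, so the result is a flat index that `jnp.unravel_index` can map back to a (row, column) pair. With `keepdims=True`, the length-1 axis lets the indices be passed straight to `jnp.take_along_axis` to gather each row's maximum. This is an illustrative usage pattern, not the only way to do it.

```python
import jax.numpy as jnp

x = jnp.array([[1, 3, 2],
               [5, 4, 1]])

# No axis given: argmax searches the flattened array [1, 3, 2, 5, 4, 1].
flat_idx = jnp.argmax(x)                          # 3
row, col = jnp.unravel_index(flat_idx, x.shape)   # (1, 0)

# keepdims=True returns shape (2, 1), which take_along_axis accepts directly.
idx = jnp.argmax(x, axis=1, keepdims=True)        # [[1], [0]]
row_max = jnp.take_along_axis(x, idx, axis=1)     # [[3], [5]]
```

Without `keepdims=True`, `idx` would have shape `(2,)` and would need an explicit `[:, None]` before being used with `take_along_axis`.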