Compute the mode (most common value) along an axis of an array.
JAX implementation of scipy.stats.mode().
Returns: A tuple of arrays, (mode, count). mode is the array of modal values, and count is the number of times each modal value appears in the input array along the reduced axis.
Return type: ModeResult
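The return type name suggests that ModeResult is a named tuple. The minimal sketch below assumes its fields are named mode and count, matching the description above, and shows both ways of reading the result. The input array is arbitrary illustration data.

```python
import jax.numpy as jnp
import jax.scipy.stats

# Illustrative input: 5 appears twice, every other value once.
x = jnp.array([0, 5, 7, 5])

# Assumption: ModeResult is a named tuple with fields `mode` and `count`.
result = jax.scipy.stats.mode(x)
print(result.mode, result.count)  # 5 2

# Positional unpacking yields the same pair, in the order (mode, count).
m, c = result
assert int(m) == int(result.mode) and int(c) == int(result.count)
```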
Examples
>>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3])
>>> mode, count = jax.scipy.stats.mode(x)
>>> mode, count
(Array(4, dtype=int32), Array(3, dtype=int32))
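To make the definition concrete, here is a hypothetical reference helper, mode_1d_reference (not part of JAX), that recomputes the 1-D result with jnp.unique. It is a cross-check sketch, not a description of how jax.scipy.stats.mode is implemented. Because jnp.unique without a size argument produces an output whose length depends on the data, this helper only runs eagerly, not under jax.jit.

```python
import jax.numpy as jnp
import jax.scipy.stats


def mode_1d_reference(x):
    """Hypothetical helper: mode of a 1-D array via jnp.unique (eager only)."""
    # Sorted unique values plus how often each one occurs.
    values, counts = jnp.unique(x, return_counts=True)
    # argmax picks the first maximal count, i.e. the smallest value among ties.
    i = jnp.argmax(counts)
    return values[i], counts[i]


x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3])
ref_mode, ref_count = mode_1d_reference(x)
mode, count = jax.scipy.stats.mode(x)

print(ref_mode, ref_count)  # 4 3
assert int(ref_mode) == int(mode) and int(ref_count) == int(count)
```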
For multi-dimensional arrays, jax.scipy.stats.mode computes the mode and the corresponding count along axis=0 by default:
>>> x1 = jnp.array([[1, 2, 1, 3, 2, 1],
...                 [3, 1, 3, 2, 1, 3],
...                 [1, 2, 2, 3, 1, 2]])
>>> mode, count = jax.scipy.stats.mode(x1)
>>> mode, count
(Array([1, 2, 1, 3, 1, 1], dtype=int32), Array([2, 2, 1, 2, 2, 1], dtype=int32))
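One way to read the axis=0 result: it matches calling mode on each column separately. The sketch below checks that equivalence with a Python loop, which is for illustration only; the single vectorized call is what you would normally use.

```python
import jax.numpy as jnp
import jax.scipy.stats

x1 = jnp.array([[1, 2, 1, 3, 2, 1],
                [3, 1, 3, 2, 1, 3],
                [1, 2, 2, 3, 1, 2]])

# Reduce over axis 0 in one call: one (mode, count) pair per column.
col_mode, col_count = jax.scipy.stats.mode(x1, axis=0)

# Same computation, one column at a time: each column is a 1-D input.
per_col = [jax.scipy.stats.mode(x1[:, j]) for j in range(x1.shape[1])]
loop_mode = jnp.stack([m for m, _ in per_col])
loop_count = jnp.stack([c for _, c in per_col])

assert (col_mode == loop_mode).all() and (col_count == loop_count).all()
print(col_mode.shape)  # (6,)
```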
If axis=1, mode and count are computed along axis 1, giving one result per row:
>>> mode, count = jax.scipy.stats.mode(x1, axis=1)
>>> mode, count
(Array([1, 3, 2], dtype=int32), Array([3, 3, 3], dtype=int32))
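Reducing along axis=1 of x1 is the same as reducing along axis=0 of its transpose. A short check of that equivalence, using the same x1 as above:

```python
import jax.numpy as jnp
import jax.scipy.stats

x1 = jnp.array([[1, 2, 1, 3, 2, 1],
                [3, 1, 3, 2, 1, 3],
                [1, 2, 2, 3, 1, 2]])

# One (mode, count) pair per row.
row_mode, row_count = jax.scipy.stats.mode(x1, axis=1)

# Transposing turns rows into columns, so axis=0 on x1.T reduces the same data.
t_mode, t_count = jax.scipy.stats.mode(x1.T, axis=0)

assert (row_mode == t_mode).all() and (row_count == t_count).all()
print(row_mode, row_count)  # [1 3 2] [3 3 3]
```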
By default, jax.scipy.stats.mode removes the reduced axis from the result. To give the result the same number of dimensions as the input array, with the reduced axis kept as a dimension of size 1, set keepdims=True:
>>> mode, count = jax.scipy.stats.mode(x1, axis=1, keepdims=True)
>>> mode, count
(Array([[1],
       [3],
       [2]], dtype=int32), Array([[3],
       [3],
       [3]], dtype=int32))
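A common reason to keep the reduced axis is to broadcast the result back against the input. In the sketch below, the (3, 1) mode array is compared elementwise with the (3, 6) input to mark the modal entries of each row. With keepdims=False the mode would have shape (3,), which does not broadcast against (3, 6) without an explicit mode[:, None].

```python
import jax.numpy as jnp
import jax.scipy.stats

x1 = jnp.array([[1, 2, 1, 3, 2, 1],
                [3, 1, 3, 2, 1, 3],
                [1, 2, 2, 3, 1, 2]])

m, c = jax.scipy.stats.mode(x1, axis=1, keepdims=True)
print(m.shape, c.shape)  # (3, 1) (3, 1)

# (3, 6) == (3, 1) broadcasts along the last axis:
# True wherever an entry equals its row's mode.
is_modal = x1 == m

# Each row contains exactly `count` copies of its mode.
assert (is_modal.sum(axis=1, keepdims=True) == c).all()
```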