Return the median of array elements along a given axis.
JAX implementation of numpy.median().
a (ArrayLike) – input array.
axis (int | tuple[int, ...] | None) – optional, int or sequence of ints, default=None. Axis or axes along which the median is computed. If None, the median is computed for the flattened array.
keepdims (bool) – bool, default=False. If true, reduced axes are left in the result with size 1.
out (None) – Unused by JAX.
overwrite_input (bool) – Unused by JAX.
An array of the median along the given axis.
Examples
By default, the median is computed for the flattened array.
>>> x = jnp.array([[2, 4, 7, 1],
...                [3, 5, 9, 2],
...                [6, 1, 8, 3]])
>>> jnp.median(x)
Array(3.5, dtype=float32)
If axis=1, the median is computed along axis 1.
>>> jnp.median(x, axis=1)
Array([3. , 4. , 4.5], dtype=float32)
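The axis parameter also accepts other values than 1. The following is a minimal sketch, reusing the same input array, of reducing along axis 0 and along a tuple of axes; the variable names cols and both are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([[2, 4, 7, 1],
               [3, 5, 9, 2],
               [6, 1, 8, 3]])

# Median down each column (axis 0): one value per column.
cols = jnp.median(x, axis=0)

# Median over both axes at once; for a 2-D array this reduces
# over every element, like the default axis=None.
both = jnp.median(x, axis=(0, 1))

print(cols)  # values 3, 4, 8, 2
print(both)  # value 3.5
```

For a 2-D input, passing axis=(0, 1) gives the same value as the flattened default, since both reduce over all elements.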
If keepdims=True, ndim of the output is equal to that of the input.
>>> jnp.median(x, axis=1, keepdims=True)
Array([[3. ],
       [4. ],
       [4.5]], dtype=float32)
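One situation where keepdims=True may be convenient is broadcasting the result back against the input. The sketch below, which is an illustration rather than part of the original examples, subtracts each row's median from that row; the names row_median and centered are hypothetical.

```python
import jax.numpy as jnp

x = jnp.array([[2, 4, 7, 1],
               [3, 5, 9, 2],
               [6, 1, 8, 3]])

# Shape (3, 1): the reduced axis is kept with size 1.
row_median = jnp.median(x, axis=1, keepdims=True)

# (3, 4) - (3, 1) broadcasts across columns, so each row is
# shifted by its own median.
centered = x - row_median

print(row_median.shape)  # (3, 1)
print(centered.shape)    # (3, 4)
```

Without keepdims=True the result would have shape (3,), which does not line up with the rows of a (3, 4) array unless a trailing axis is added back by hand.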