Compute the weighted average.
JAX implementation of numpy.average().
a (ArrayLike) – array to be averaged
axis (Axis) – an optional integer or sequence of integers specifying the axis or axes along which the average is computed. If not specified, the average is computed over all axes.
weights (ArrayLike | None) – an optional array of weights for a weighted average. Must be broadcast-compatible with a.
returned (bool) – If False (default) then return only the average. If True then return both the average and the normalization factor (i.e. the sum of weights).
keepdims (bool) – If True, reduced axes are left in the result with size 1. If False (default) then reduced axes are squeezed out.
Returns an array average, or a tuple of arrays (average, normalization) if returned is True.
Examples
Simple average:
>>> x = jnp.array([1, 2, 3, 2, 4])
>>> jnp.average(x)
Array(2.4, dtype=float32)
Weighted average:
>>> weights = jnp.array([2, 1, 3, 2, 2])
>>> jnp.average(x, weights=weights)
Array(2.5, dtype=float32)
Use returned=True to optionally return the normalization, i.e. the sum of weights:
>>> jnp.average(x, returned=True)
(Array(2.4, dtype=float32), Array(5., dtype=float32))
>>> jnp.average(x, weights=weights, returned=True)
(Array(2.5, dtype=float32), Array(10., dtype=float32))
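The relationship between the average and the normalization can be checked by hand. The following is a minimal sketch, not part of the original examples: it recomputes the weighted average as the weighted sum divided by the sum of weights and compares it with the value returned by jnp.average. The name manual is illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 2, 4])
weights = jnp.array([2, 1, 3, 2, 2])

# Ask for both the average and the normalization (the sum of weights).
avg, norm = jnp.average(x, weights=weights, returned=True)

# Manual computation: weighted sum divided by the sum of weights.
manual = jnp.sum(x * weights) / jnp.sum(weights)

print(avg, norm, manual)
```

Here the weighted sum is 25 and the weights sum to 10, so both computations should agree on 2.5.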
Weighted average along a specified axis:
>>> x = jnp.array([[8, 2, 7],
...                [3, 6, 4]])
>>> weights = jnp.array([1, 2, 3])
>>> jnp.average(x, weights=weights, axis=1)
Array([5.5, 4.5], dtype=float32)
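The keepdims parameter described above is not exercised by the examples. As a hedged sketch (the variable names squeezed, kept, and centered are illustrative), the following compares the result shapes with and without keepdims=True and shows why keeping the reduced axis can be convenient for broadcasting against the input:

```python
import jax.numpy as jnp

x = jnp.array([[8, 2, 7],
               [3, 6, 4]])
weights = jnp.array([1, 2, 3])

# Default: the reduced axis is squeezed out of the result.
squeezed = jnp.average(x, weights=weights, axis=1)

# keepdims=True: the reduced axis is kept with size 1.
kept = jnp.average(x, weights=weights, axis=1, keepdims=True)

print(squeezed.shape)  # (2,)
print(kept.shape)      # (2, 1)

# The size-1 axis lets the result broadcast directly against x,
# e.g. to subtract each row's weighted average from that row.
centered = x - kept
print(centered.shape)  # (2, 3)
```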