Compute the standard error of the mean.
JAX implementation of scipy.stats.sem().
a (ArrayLike) – arraylike
axis (int | None) – optional integer, default=0. If None, the input array is flattened.
ddof (int) – integer, default=1. The degrees of freedom in the SEM computation.
nan_policy (str) – str, default="propagate". JAX supports only "propagate" and "omit".
keepdims (bool) – bool, default=False. If true, reduced axes are left in the result with size 1.
array
Examples
>>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jax.scipy.stats.sem(x)
Array(0.41, dtype=float32)
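As a cross-check, the value above can be reproduced from the textbook definition of the standard error: the sample standard deviation divided by the square root of the sample size. The following is a minimal sketch, assuming sem follows that definition as scipy.stats.sem does; the names manual and library are illustrative only.

```python
import jax.numpy as jnp
from jax.scipy.stats import sem

x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3], dtype=jnp.float32)

# Textbook definition: sample standard deviation (ddof=1) over sqrt(n).
manual = jnp.std(x, ddof=1) / jnp.sqrt(x.size)

# Library result for comparison, using the default ddof=1.
library = sem(x)

print(manual, library)
```

Passing a different ddof to both jnp.std and sem should keep the two results in agreement, since ddof only changes the divisor used for the variance.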
For multidimensional arrays, sem computes the standard error of the mean along axis=0 by default:
>>> x1 = jnp.array([[1, 2, 1, 3, 2, 1],
...                 [3, 1, 3, 2, 1, 3],
...                 [1, 2, 2, 3, 1, 2]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jax.scipy.stats.sem(x1)
Array([0.67, 0.33, 0.58, 0.33, 0.33, 0.58], dtype=float32)
If axis=1, the standard error of the mean is computed along axis 1.
>>> with jnp.printoptions(precision=2, suppress=True):
...   jax.scipy.stats.sem(x1, axis=1)
Array([0.33, 0.4 , 0.31], dtype=float32)
If axis=None, the standard error of the mean is computed over all axes, that is, over the flattened array.
>>> with jnp.printoptions(precision=2, suppress=True):
...   jax.scipy.stats.sem(x1, axis=None)
Array(0.2, dtype=float32)
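One way to read axis=None is as a reduction over the flattened input. The sketch below compares the two forms directly; the variable names are illustrative, and the equivalence is an expectation drawn from the parameter description rather than a documented guarantee.

```python
import jax.numpy as jnp
from jax.scipy.stats import sem

x1 = jnp.array([[1, 2, 1, 3, 2, 1],
                [3, 1, 3, 2, 1, 3],
                [1, 2, 2, 3, 1, 2]])

# Reduce over every axis at once.
over_all_axes = sem(x1, axis=None)

# Flatten explicitly, then reduce along the single remaining axis.
over_flattened = sem(x1.ravel(), axis=0)

print(over_all_axes, over_flattened)
```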
By default, sem drops the reduced axis from the result. To keep the same number of dimensions as the input array, set the keepdims argument to True.
>>> with jnp.printoptions(precision=2, suppress=True):
...   jax.scipy.stats.sem(x1, axis=1, keepdims=True)
Array([[0.33],
       [0.4 ],
       [0.31]], dtype=float32)
Since nan_policy='propagate' by default, sem propagates nan values to the result.
>>> nan = jnp.nan
>>> x2 = jnp.array([[1, 2, 3, nan, 4, 2],
...                 [4, 5, 4, 3, nan, 1],
...                 [7, nan, 8, 7, 9, nan]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jax.scipy.stats.sem(x2)
Array([1.73,  nan, 1.53,  nan,  nan,  nan], dtype=float32)
If nan_policy='omit', sem omits the nan values and computes the error for the remaining values along the specified axis.
>>> with jnp.printoptions(precision=2, suppress=True):
...   jax.scipy.stats.sem(x2, nan_policy='omit')
Array([1.73, 1.5 , 1.53, 2.  , 2.5 , 0.5 ], dtype=float32)
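The 'omit' result can be cross-checked by hand: in each column, take the nan-aware sample standard deviation and divide by the square root of the number of non-nan entries. This is a minimal sketch under the assumption that 'omit' is equivalent to that calculation; count, manual and library are illustrative names.

```python
import jax.numpy as jnp
from jax.scipy.stats import sem

nan = jnp.nan
x2 = jnp.array([[1, 2, 3, nan, 4, 2],
                [4, 5, 4, 3, nan, 1],
                [7, nan, 8, 7, 9, nan]])

# Number of non-nan entries in each column.
count = jnp.sum(~jnp.isnan(x2), axis=0)

# nan-aware sample standard deviation over sqrt of the per-column count.
manual = jnp.nanstd(x2, axis=0, ddof=1) / jnp.sqrt(count)

# Library result for comparison.
library = sem(x2, nan_policy='omit')

print(manual, library)
```

Note that each column is effectively reduced over a different sample size here, which is why columns with only two valid entries can show a noticeably larger error.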