Compute the Pearson correlation coefficients.
JAX implementation of numpy.corrcoef().
This is a normalized version of the sample covariance computed by jax.numpy.cov(). For a sample covariance \(C_{ij}\), the correlation coefficients are
\[R_{ij} = \frac{C_{ij}}{\sqrt{C_{ii}C_{jj}}}\]
By the Cauchy-Schwarz inequality, \(|C_{ij}| \le \sqrt{C_{ii}C_{jj}}\), so the values satisfy \(-1 \le R_{ij} \le 1\).
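The normalization formula can be checked directly against jax.numpy.cov(). The following is a minimal sketch, not the library's implementation: the helper name `corrcoef_from_cov` is hypothetical, and the final clip to [-1, 1] is an assumption added here only to absorb floating-point rounding.

```python
import jax
import jax.numpy as jnp


def corrcoef_from_cov(x):
    """Hypothetical helper: R_ij = C_ij / sqrt(C_ii * C_jj)."""
    c = jnp.cov(x)                 # (M, M) sample covariance; rows of x are variables
    d = jnp.sqrt(jnp.diag(c))      # sqrt(C_ii): standard deviation of each variable
    r = c / jnp.outer(d, d)        # divide each C_ij by sqrt(C_ii) * sqrt(C_jj)
    return jnp.clip(r, -1, 1)      # guard against rounding pushing values just outside [-1, 1]


x = jax.random.normal(jax.random.key(0), shape=(3, 100))
print(jnp.allclose(corrcoef_from_cov(x), jnp.corrcoef(x), atol=1e-5))
```

The ddof used by jnp.cov() does not matter here, because the same normalization factor appears in the numerator and the denominator and cancels.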
Parameters:
x (ArrayLike) – array of shape (M, N) (if rowvar is True), or (N, M) (if rowvar is False) representing N observations of M variables. x may also be one-dimensional, representing N observations of a single variable.
y (ArrayLike | None) – optional set of additional observations, with the same form as x. If specified, then y is combined with x, i.e. for the default rowvar = True case, x becomes jnp.vstack([x, y]).
rowvar (bool) – if True (default), then each row of x represents a variable. If False, then each column represents a variable.
Returns:
A correlation matrix of shape (M, M), where M counts the variables of x together with those of y, if given.
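The rowvar flag only changes which axis is read as the variable axis. As a sketch under the assumption of randomly generated data (the key value and shape are arbitrary choices for illustration), the same observations stored as rows or as columns should give the same (M, M) result:

```python
import jax
import jax.numpy as jnp

# 3 variables with 50 observations each, variables along rows: shape (M, N) = (3, 50).
x = jax.random.normal(jax.random.key(1), shape=(3, 50))

r_rows = jnp.corrcoef(x)                  # default rowvar=True: each row is a variable
r_cols = jnp.corrcoef(x.T, rowvar=False)  # same data as (N, M): each column is a variable

print(r_rows.shape, r_cols.shape)         # both (M, M)
print(jnp.allclose(r_rows, r_cols))
```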
Examples
Consider these observations of two variables that correlate perfectly. The correlation matrix in this case is a 2x2 matrix of ones:
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([[0, 1, 2],
...                [0, 1, 2]])
>>> jnp.corrcoef(x)
Array([[1., 1.],
       [1., 1.]], dtype=float32)
Now consider these observations of two variables that are perfectly anti-correlated. The correlation matrix in this case has -1 in the off-diagonal entries:
>>> x = jnp.array([[-1, 0, 1],
...                [ 1, 0, -1]])
>>> jnp.corrcoef(x)
Array([[ 1., -1.],
       [-1.,  1.]], dtype=float32)
Equivalently, these sequences can be specified as separate arguments, in which case they are stacked before the correlation is computed.
>>> x = jnp.array([-1, 0, 1])
>>> y = jnp.array([1, 0, -1])
>>> jnp.corrcoef(x, y)
Array([[ 1., -1.],
       [-1.,  1.]], dtype=float32)
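The stacking rule described for y also holds when x already contains several variables. A minimal sketch with made-up observation values, comparing the two-argument call against explicit stacking with jnp.vstack():

```python
import jax.numpy as jnp

x = jnp.array([[0., 1., 2., 3.],
               [1., 3., 2., 5.]])  # two variables (rows), four observations each
y = jnp.array([3., 1., 0., 2.])    # one additional variable with the same four observations

r_separate = jnp.corrcoef(x, y)               # y is combined with x before computing
r_stacked = jnp.corrcoef(jnp.vstack([x, y]))  # explicit stacking for the rowvar=True case

print(r_separate.shape)  # M counts the variables of x and y together
print(jnp.allclose(r_separate, r_stacked))
```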
The entries of the correlation matrix are normalized such that they lie within the range -1 to +1, where +1 indicates perfect correlation and -1 indicates perfect anti-correlation. For example, here is the correlation matrix of 100 points drawn from a 3-dimensional standard normal distribution:
>>> key = jax.random.key(0)
>>> x = jax.random.normal(key, shape=(3, 100))
>>> with jnp.printoptions(precision=2):
...   print(jnp.corrcoef(x))
[[1.   0.03 0.12]
 [0.03 1.   0.01]
 [0.12 0.01 1.  ]]
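The structural properties stated above can be checked programmatically on the same random sample. This sketch tests three things that follow from the definition: every entry lies in [-1, 1], the diagonal is 1 (each variable correlates perfectly with itself), and the matrix is symmetric because \(C_{ij} = C_{ji}\).

```python
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.key(0), shape=(3, 100))
r = jnp.corrcoef(x)

in_range = jnp.all((r >= -1) & (r <= 1))    # every coefficient lies in [-1, 1]
unit_diag = jnp.allclose(jnp.diag(r), 1.0)  # R_ii == 1 for every variable
symmetric = jnp.allclose(r, r.T)            # R_ij == R_ji
print(bool(in_range), bool(unit_diag), bool(symmetric))
```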