Compute the matrix square root
JAX implementation of scipy.linalg.sqrtm()
Parameters
A (ArrayLike) – array of shape (N, N)
blocksize (int) – Not supported in JAX; JAX always uses blocksize=1
Returns
An array of shape (N, N) containing the matrix square root of A.
Examples
>>> import jax.numpy as jnp
>>> import jax.scipy.linalg
>>> a = jnp.array([[1., 2., 3.],
...                [2., 4., 2.],
...                [3., 2., 1.]])
>>> sqrt_a = jax.scipy.linalg.sqrtm(a)
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(sqrt_a)
[[0.92+0.71j 0.54+0.j   0.92-0.71j]
 [0.54+0.j   1.85+0.j   0.54-0.j  ]
 [0.92-0.71j 0.54-0.j   0.92+0.71j]]
By definition, multiplying the matrix square root by itself should recover the input, up to floating-point error:
>>> jnp.allclose(a, sqrt_a @ sqrt_a)
Array(True, dtype=bool)
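For a symmetric positive-definite input, the principal square root is real and symmetric, so the Schur-based result can be cross-checked against one built from an eigendecomposition, V diag(√w) Vᵀ. This is a minimal sketch, assuming float32 on CPU and an arbitrarily chosen random test matrix; the tolerance 1e-4 is an assumption suited to single precision. Because the complex Schur method works in complex arithmetic, the result is expected to carry a near-zero imaginary part, so the comparison uses its real part.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg

# Build a random symmetric positive-definite matrix (all eigenvalues >= 4).
key = jax.random.PRNGKey(0)
B = jax.random.normal(key, (4, 4))
A = B @ B.T + 4.0 * jnp.eye(4)

# Schur-based principal square root (expected to be complex-valued).
sqrt_schur = jax.scipy.linalg.sqrtm(A)

# Eigendecomposition-based square root: V @ diag(sqrt(w)) @ V.T.
# Broadcasting V * sqrt(w) scales column k of V by sqrt(w[k]).
w, V = jnp.linalg.eigh(A)
sqrt_eigh = (V * jnp.sqrt(w)) @ V.T

print(jnp.allclose(sqrt_schur.real, sqrt_eigh, atol=1e-4))
print(jnp.max(jnp.abs(sqrt_schur.imag)))
```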
Notes
This function implements the complex Schur method described in [1]. It does not use recursive blocking to speed up the computation, because a Sylvester equation solver is not yet available in JAX.
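The unblocked (blocksize=1) form of the complex Schur method can be sketched as follows. This is a minimal illustration, not JAX's source code: the helper names `triu_sqrt` and `sqrtm_schur` are hypothetical, and the plain Python loops favor readability over speed. The idea is to factor A = Z T Zᴴ with T upper triangular, set R[i, i] = sqrt(T[i, i]), fill each column of R upward from the diagonal using R[i, j] = (T[i, j] - Σ_{i<k<j} R[i, k] R[k, j]) / (R[i, i] + R[j, j]), and return Z R Zᴴ.

```python
import jax.numpy as jnp
import jax.scipy.linalg


def triu_sqrt(T):
    """Square root of an upper-triangular matrix T (hypothetical helper).

    Diagonal: R[i, i] = sqrt(T[i, i]). Off-diagonal entries are filled one
    column at a time, moving upward from the diagonal, via
        R[i, j] = (T[i, j] - sum_{i<k<j} R[i, k] R[k, j]) / (R[i, i] + R[j, j]).
    """
    n = T.shape[0]
    R = jnp.diag(jnp.sqrt(jnp.diag(T)))
    for j in range(1, n):
        for i in range(j - 1, -1, -1):
            # R[i, i+1:j] lies in earlier columns; R[i+1:j, j] was filled
            # earlier in this column. The sum is 0 when the slices are empty.
            s = jnp.sum(R[i, i + 1:j] * R[i + 1:j, j])
            R = R.at[i, j].set((T[i, j] - s) / (R[i, i] + R[j, j]))
    return R


def sqrtm_schur(A):
    """Matrix square root via the complex Schur form A = Z T Z^H (sketch)."""
    T, Z = jax.scipy.linalg.schur(A, output="complex")
    R = triu_sqrt(T)
    return Z @ R @ Z.conj().T


a = jnp.array([[1., 2., 3.],
               [2., 4., 2.],
               [3., 2., 1.]])
r = sqrtm_schur(a)
print(jnp.allclose(r @ r, a, atol=1e-4))
```

In a blocked variant, the scalar division by R[i, i] + R[j, j] becomes a Sylvester equation R_ii X + X R_jj = C for each off-diagonal block, which is why blocking depends on having a Sylvester solver.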
References