Compute the singular value decomposition.
JAX implementation of numpy.linalg.svd(), implemented in terms of jax.lax.linalg.svd().
The SVD of a matrix \(A\) is given by
\[A = U\Sigma V^H\]
where
\(U\) contains the left singular vectors and satisfies \(U^HU=I\),
\(V\) contains the right singular vectors and satisfies \(V^HV=I\), and
\(\Sigma\) is a diagonal matrix of singular values.
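As a quick numerical check of the definition above, the following sketch decomposes a small complex matrix and tests each property. The matrix entries and the tolerance are arbitrary choices for illustration. The returned vh already holds \(V^H\), so no extra transpose is needed when reconstructing.

```python
import jax.numpy as jnp

# Arbitrary complex 3x2 matrix, used only for illustration.
a = jnp.array([[1.0 + 2.0j, 0.5 - 1.0j],
               [3.0 - 1.0j, 2.0 + 0.0j],
               [0.0 + 1.0j, 1.0 + 1.0j]])

u, s, vh = jnp.linalg.svd(a, full_matrices=False)

# U^H U = I (the conjugate transpose matters for complex input).
u_ok = bool(jnp.allclose(u.conj().T @ u, jnp.eye(2), atol=1e-4))

# vh holds V^H, so V^H V is vh @ vh^H.
v_ok = bool(jnp.allclose(vh @ vh.conj().T, jnp.eye(2), atol=1e-4))

# A = U Sigma V^H; scaling the columns of u by s is equivalent to u @ diag(s).
a_ok = bool(jnp.allclose((u * s) @ vh, a, atol=1e-4))

print(u_ok, v_ok, a_ok)
```

Using a plain transpose (u.T) instead of u.conj().T would generally fail the first check for complex input.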
Parameters
a (ArrayLike) – input array, of shape (..., N, M).
full_matrices (bool) – if True (default), compute the full matrices, i.e. u and vh have shape (..., N, N) and (..., M, M). If False, the shapes are (..., N, K) and (..., K, M) with K = min(N, M).
compute_uv (bool) – if True (default), return the full SVD (u, s, vh). If False, return only the singular values s.
hermitian (bool) – if True, assume the matrix is Hermitian, which allows for a more efficient implementation (default: False).
subset_by_index (tuple[int, int] | None) – (TPU-only) optional 2-tuple [start, end] indicating the range of indices of singular values to compute. For example, if this is [n-2, n], then svd computes the two largest singular values and their singular vectors. Only compatible with full_matrices=False.
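The effect of full_matrices, compute_uv and hermitian can be sketched on small inputs as follows. The arrays are arbitrary illustrative choices, and the tolerances are assumptions suited to float32. subset_by_index is left out because it is described above as TPU-only.

```python
import jax.numpy as jnp

a = jnp.arange(10.0).reshape(2, 5)  # N=2, M=5, so K = min(N, M) = 2

# full_matrices=True (the default): square u and vh.
u_full, s_full, vh_full = jnp.linalg.svd(a)
print(u_full.shape, s_full.shape, vh_full.shape)  # (2, 2) (2,) (5, 5)

# full_matrices=False: reduced factors.
u_red, s_red, vh_red = jnp.linalg.svd(a, full_matrices=False)
print(u_red.shape, s_red.shape, vh_red.shape)  # (2, 2) (2,) (2, 5)

# compute_uv=False: only the singular values are returned.
s_only = jnp.linalg.svd(a, compute_uv=False)
same_values = bool(jnp.allclose(s_only, s_red, atol=1e-4))

# hermitian=True: the caller asserts that the input is Hermitian.
# Here h is real symmetric, so the assumption holds.
h = jnp.array([[2.0, 1.0, 0.0],
               [1.0, -3.0, 1.0],
               [0.0, 1.0, 1.0]])
s_herm = jnp.linalg.svd(h, compute_uv=False, hermitian=True)
s_general = jnp.linalg.svd(h, compute_uv=False)
herm_matches = bool(jnp.allclose(s_herm, s_general, atol=1e-4))

print(same_values, herm_matches)
```

Passing hermitian=True for a matrix that is not actually Hermitian is not checked in this sketch and should not be expected to give a valid decomposition.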
Returns
A tuple of arrays (u, s, vh) if compute_uv is True, otherwise the array s.
u: left singular vectors of shape (..., N, N) if full_matrices is True or (..., N, K) otherwise.
s: singular values of shape (..., K).
vh: conjugate-transposed right singular vectors of shape (..., M, M) if full_matrices is True or (..., K, M) otherwise.
Here K = min(N, M).
Return type
Array | SVDResult
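The leading ... in each shape above stands for batch dimensions. A minimal sketch, assuming an arbitrary random batch of four 3x2 matrices (the PRNG seed and sizes are illustrative choices), shows the resulting shapes and a batched reconstruction:

```python
import jax
import jax.numpy as jnp

# Batch of 4 matrices with N=3, M=2, so K = min(N, M) = 2.
a = jax.random.normal(jax.random.PRNGKey(0), (4, 3, 2))

u, s, vh = jnp.linalg.svd(a, full_matrices=False)
print(u.shape, s.shape, vh.shape)  # (4, 3, 2) (4, 2) (4, 2, 2)

# s[..., None, :] broadcasts each batch's singular values across the columns
# of u, which matches u @ diag(s) without building the diagonal matrices.
a_rec = (u * s[..., None, :]) @ vh
rec_ok = bool(jnp.allclose(a_rec, a, atol=1e-4))
print(rec_ok)
```

Scaling the columns of u by broadcasting is a design choice for the sketch; jnp.diag is not used here because it does not build one diagonal matrix per batch element from a 2D input.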
Examples
Consider the SVD of a small real-valued array:
>>> x = jnp.array([[1., 2., 3.],
...                [6., 5., 4.]])
>>> u, s, vt = jnp.linalg.svd(x, full_matrices=False)
>>> s
Array([9.361919 , 1.8315067], dtype=float32)
The singular vectors are in the columns of u and v = vt.T. These vectors are orthonormal, which can be demonstrated by comparing the matrix product with the identity matrix:
>>> jnp.allclose(u.T @ u, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)
>>> v = vt.T
>>> jnp.allclose(v.T @ v, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)
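One subtlety with full_matrices=False is worth spelling out. For this 2x3 input, v has shape (3, 2), so its columns are orthonormal but v is not square, and v @ v.T is a rank-2 projection rather than the 3x3 identity. The sketch below illustrates this under the same input. It recomputes the factors so that it runs on its own, and the tolerance is an assumption suited to float32.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [6., 5., 4.]])

# Reduced SVD: v has shape (3, 2).
_, _, vt = jnp.linalg.svd(x, full_matrices=False)
v = vt.T
cols_orthonormal = bool(jnp.allclose(v.T @ v, jnp.eye(2), atol=1e-4))
# v @ v.T has rank 2, so it cannot equal the 3x3 identity.
is_identity_3x3 = bool(jnp.allclose(v @ v.T, jnp.eye(3), atol=1e-4))

# Full SVD: v is a square (3, 3) orthogonal matrix, so both products
# give the identity.
_, _, vt_full = jnp.linalg.svd(x, full_matrices=True)
v_full = vt_full.T
full_left = bool(jnp.allclose(v_full.T @ v_full, jnp.eye(3), atol=1e-4))
full_right = bool(jnp.allclose(v_full @ v_full.T, jnp.eye(3), atol=1e-4))

print(cols_orthonormal, is_identity_3x3, full_left, full_right)
```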
Given the SVD, x can be reconstructed via matrix multiplication:
>>> x_reconstructed = u @ jnp.diag(s) @ vt
>>> jnp.allclose(x_reconstructed, x)
Array(True, dtype=bool)
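With full_matrices=True the factors no longer multiply directly with jnp.diag(s), because u is (N, N), vh is (M, M) and jnp.diag(s) is (K, K). A minimal sketch, assuming the same 2x3 array, first embeds the singular values in an (N, M) matrix. Trimming the factors to their first K columns and rows is an equivalent alternative.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [6., 5., 4.]])
n, m = x.shape
k = min(n, m)

u, s, vt = jnp.linalg.svd(x, full_matrices=True)  # u: (2, 2), vt: (3, 3)

# Place the K singular values on the main diagonal of an (N, M) zero matrix.
idx = jnp.arange(k)
sigma = jnp.zeros((n, m)).at[idx, idx].set(s)
x_padded = u @ sigma @ vt

# Equivalent: keep only the first K columns of u and the first K rows of vt.
x_trimmed = u[:, :k] @ jnp.diag(s) @ vt[:k, :]

padded_ok = bool(jnp.allclose(x_padded, x, atol=1e-4))
trimmed_ok = bool(jnp.allclose(x_trimmed, x, atol=1e-4))
print(padded_ok, trimmed_ok)
```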