Compute the norm of a matrix or vector.
JAX implementation of numpy.linalg.norm().
Parameters
x (ArrayLike) – N-dimensional array for which the norm will be computed.
ord (int | str | None) – specifies the kind of norm to compute. Defaults to the Frobenius norm for matrices and the 2-norm for vectors. For other options, see Notes below.
axis (None | tuple[int, ...] | int) – integer or sequence of integers specifying the axes over which the norm will be computed. For a single axis, compute a vector norm. For two axes, compute a matrix norm. Defaults to all axes of x.
keepdims (bool) – if True, the output array will have the same number of dimensions as the input, with the size of reduced axes replaced by 1 (default: False).
Returns
array containing the specified norm of x.
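As a minimal sketch of how axis and keepdims affect the output shape, the snippet below reduces an arbitrary 3-dimensional array in three ways. The input values and the variable names (vec, mat, kept) are illustrative assumptions, not part of the documented API.

```python
import jax.numpy as jnp

# A batch of 2 matrices, each 3x4 (arbitrary illustration data).
x = jnp.arange(24.0).reshape(2, 3, 4)

# Single axis: a vector norm along the last axis, one value per row.
vec = jnp.linalg.norm(x, axis=-1)
print(vec.shape)  # (2, 3)

# Two axes: a matrix (Frobenius) norm of each 3x4 slice.
mat = jnp.linalg.norm(x, axis=(1, 2))
print(mat.shape)  # (2,)

# keepdims=True keeps the reduced axes, each with size 1.
kept = jnp.linalg.norm(x, axis=(1, 2), keepdims=True)
print(kept.shape)  # (2, 1, 1)
```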
Notes
The flavor of norm computed depends on the value of ord and the number of axes being reduced.
For vector norms (i.e. a single axis reduction):
ord=None (default) computes the 2-norm
ord=inf computes max(abs(x))
ord=-inf computes min(abs(x))
ord=0 computes sum(x!=0)
for other numerical values, computes sum(abs(x) ** ord)**(1/ord)
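The vector-norm rules above can be checked by writing each formula out by hand and comparing it with the library result. This is only a sketch: the sample vector, the choice of ord=3 for the general case, and the names checks, got, and expected are assumptions made for illustration.

```python
import jax.numpy as jnp

x = jnp.array([3., -4., 0., 12.])
inf = jnp.inf

# Map each ord value to the corresponding formula, written out by hand.
checks = {
    None: jnp.sqrt(jnp.sum(jnp.abs(x) ** 2)),   # 2-norm
    inf: jnp.max(jnp.abs(x)),
    -inf: jnp.min(jnp.abs(x)),
    0: jnp.sum(x != 0),                          # count of nonzero entries
    3: jnp.sum(jnp.abs(x) ** 3) ** (1 / 3),      # general numeric ord
}

for ord_, expected in checks.items():
    got = jnp.linalg.norm(x, ord=ord_)
    print(ord_, float(got), float(expected))
```

Each printed pair should agree up to float32 rounding.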
For matrix norms (i.e. two axes reductions):
ord='fro' or ord=None (default) computes the Frobenius norm
ord='nuc' computes the nuclear norm, or the sum of the singular values
ord=1 computes max(abs(x).sum(0))
ord=-1 computes min(abs(x).sum(0))
ord=2 computes the 2-norm, i.e. the largest singular value
ord=-2 computes the smallest singular value
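The matrix-norm rules above can be verified in the same spirit, using the singular values from jnp.linalg.svd for the 'nuc', 2, and -2 cases. The matrix and the names s, by_hand, got, and expected are illustrative assumptions; this is a sketch of the relationships, not a description of how the library computes them internally.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 7.]])

# Singular values only; compute_uv=False skips the U and V factors.
s = jnp.linalg.svd(x, compute_uv=False)

# Map each ord value to the corresponding formula, written out by hand.
by_hand = {
    'fro': jnp.sqrt(jnp.sum(jnp.abs(x) ** 2)),
    'nuc': jnp.sum(s),
    1: jnp.max(jnp.abs(x).sum(0)),    # largest absolute column sum
    -1: jnp.min(jnp.abs(x).sum(0)),   # smallest absolute column sum
    2: jnp.max(s),
    -2: jnp.min(s),
}

for ord_, expected in by_hand.items():
    got = jnp.linalg.norm(x, ord=ord_)
    print(ord_, float(got), float(expected))
```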
In the special case of ord=None and axis=None, this function accepts an array of any dimension and computes the vector 2-norm of the flattened array.
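A short sketch of this special case: with both arguments left at their defaults, a 3-dimensional input should give the same scalar as the 2-norm of its flattened data. The input values and the names full and flat are assumptions for illustration.

```python
import jax.numpy as jnp

x = jnp.arange(24.0).reshape(2, 3, 4)

# ord=None and axis=None on a 3-D array: treated as one long vector.
full = jnp.linalg.norm(x)

# Explicitly flatten first, then take the vector 2-norm.
flat = jnp.linalg.norm(x.ravel())

print(full.shape, float(full), float(flat))
```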
Examples
Vector norms:
>>> x = jnp.array([3., 4., 12.])
>>> jnp.linalg.norm(x)
Array(13., dtype=float32)
>>> jnp.linalg.norm(x, ord=1)
Array(19., dtype=float32)
>>> jnp.linalg.norm(x, ord=0)
Array(3., dtype=float32)
Matrix norms:
>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 7.]])
>>> jnp.linalg.norm(x)  # Frobenius norm
Array(10.198039, dtype=float32)
>>> jnp.linalg.norm(x, ord='nuc')  # nuclear norm
Array(10.762535, dtype=float32)
>>> jnp.linalg.norm(x, ord=1)  # 1-norm
Array(10., dtype=float32)
Batched vector norm:
>>> jnp.linalg.norm(x, axis=1)
Array([3.7416575, 9.486833 ], dtype=float32)