jax.numpy.linalg.norm
jax.numpy.linalg.norm(x, ord=None, axis=None, keepdims=False)

Compute the norm of a matrix or vector.

JAX implementation of numpy.linalg.norm().
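Because the function mirrors numpy.linalg.norm(), one quick way to check a given ord is to compare it against NumPy directly. The sketch below is illustrative only: the sample matrix is hypothetical, the tolerance is an assumption sized for float32 SVD round-off, and it covers only the ord values listed in the Notes below.

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical sample input; float32 so NumPy and JAX work at the same precision.
x = np.array([[1., 2., 3.],
              [4., 5., 7.]], dtype=np.float32)

# Pair each JAX result with NumPy's result for the same `ord` (matrix norms).
pairs = []
for ord_ in (None, 'fro', 'nuc', 1, -1, 2, -2, np.inf, -np.inf):
    got = float(jnp.linalg.norm(x, ord=ord_))
    want = float(np.linalg.norm(x, ord=ord_))
    pairs.append((ord_, got, want))
    print(f"ord={ord_!r}: jax={got:.6f} numpy={want:.6f}")

# Vector case: a single axis reduces each row to its 2-norm.
row_got = np.asarray(jnp.linalg.norm(x, axis=1))
row_want = np.linalg.norm(x, axis=1)
```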

Parameters:
  • x (ArrayLike) – N-dimensional array for which the norm will be computed.

  • ord (int | str | None) – the kind of norm to compute. Defaults to the Frobenius norm for matrices and the 2-norm for vectors. For other options, see Notes below.

  • axis (None | tuple[int, ...] | int) – integer or tuple of integers specifying the axes over which the norm will be computed. For a single axis, compute a vector norm. For two axes, compute a matrix norm. Defaults to all axes of x.

  • keepdims (bool) – if True, the output array will have the same number of dimensions as the input, with the size of reduced axes replaced by 1 (default: False).
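How axis and keepdims shape the result can be seen on a small stack of ones. This is a sketch with a hypothetical input of shape (2, 3, 4); the row-normalization step is only one illustrative use of keepdims=True, chosen because the kept size-1 axis lets the norm broadcast back against x.

```python
import jax.numpy as jnp

x = jnp.ones((2, 3, 4))  # hypothetical input

# A single integer axis gives vector norms over the last dimension.
v = jnp.linalg.norm(x, axis=-1)
print(v.shape)  # (2, 3)

# keepdims=True keeps the reduced axis with size 1, so the result
# broadcasts against x, e.g. to rescale every row to unit length.
v_kept = jnp.linalg.norm(x, axis=-1, keepdims=True)
print(v_kept.shape)  # (2, 3, 1)
unit_rows = x / v_kept

# A tuple of two axes gives a matrix (Frobenius, by default) norm per leading index.
m = jnp.linalg.norm(x, axis=(1, 2))
print(m.shape)  # (2,)
```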

Returns:

array containing the specified norm of x.

Return type:

Array

Notes

The flavor of norm computed depends on the value of ord and the number of axes being reduced.

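For vector norms (i.e. a single-axis reduction):

  • ord=None (default) computes the 2-norm, sqrt(sum(abs(x)**2)).

  • ord=inf computes max(abs(x)).

  • ord=-inf computes min(abs(x)).

  • ord=0 computes sum(x != 0), the number of nonzero entries.

  • for other numerical values, computes sum(abs(x)**ord)**(1/ord).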
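To make the vector-norm cases concrete, the sketch below pairs each jnp.linalg.norm call with the NumPy-style formula it is expected to reduce to. The sample vector is hypothetical and deliberately contains a zero, so ord=0 and ord=-inf give visibly different answers; ord=3 stands in for "any other numerical value".

```python
import jax.numpy as jnp

x = jnp.array([3., -4., 0., 12.])  # hypothetical sample vector
a = jnp.abs(x)

# name -> (library result, formula written out by hand)
checks = {
    "ord=None (2-norm)": (jnp.linalg.norm(x), jnp.sqrt(jnp.sum(a ** 2))),
    "ord=inf": (jnp.linalg.norm(x, ord=jnp.inf), jnp.max(a)),
    "ord=-inf": (jnp.linalg.norm(x, ord=-jnp.inf), jnp.min(a)),
    "ord=0 (nonzero count)": (jnp.linalg.norm(x, ord=0),
                              jnp.sum(x != 0).astype(x.dtype)),
    "ord=3 (general p)": (jnp.linalg.norm(x, ord=3),
                          jnp.sum(a ** 3) ** (1 / 3)),
}
for name, (got, want) in checks.items():
    print(f"{name}: {float(got):.6f} vs {float(want):.6f}")
```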

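For matrix norms (i.e. a two-axis reduction):

  • ord=None (default) or ord='fro' computes the Frobenius norm, sqrt(sum(abs(x)**2)).

  • ord='nuc' computes the nuclear norm, i.e. the sum of the singular values.

  • ord=1 computes max(abs(x).sum(0)), the largest absolute column sum.

  • ord=-1 computes min(abs(x).sum(0)), the smallest absolute column sum.

  • ord=inf computes max(abs(x).sum(1)), the largest absolute row sum.

  • ord=-inf computes min(abs(x).sum(1)), the smallest absolute row sum.

  • ord=2 computes the spectral norm, i.e. the largest singular value.

  • ord=-2 computes the smallest singular value.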
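The matrix-norm cases can be cross-checked the same way: column and row sums of abs(x) for ord=±1 and ord=±inf, and the singular values from jnp.linalg.svd(..., compute_uv=False) for 'nuc' and ord=±2. This is a sketch on a hypothetical 2×3 matrix; it only checks agreement and makes no claim about how the library computes each case internally.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 7.]])  # hypothetical sample matrix
a = jnp.abs(x)
s = jnp.linalg.svd(x, compute_uv=False)  # singular values, largest first

# ord -> (library result, formula written out by hand)
checks = {
    "fro": (jnp.linalg.norm(x, ord='fro'), jnp.sqrt(jnp.sum(a ** 2))),
    "nuc": (jnp.linalg.norm(x, ord='nuc'), jnp.sum(s)),
    "1": (jnp.linalg.norm(x, ord=1), jnp.max(jnp.sum(a, axis=0))),    # column sums
    "-1": (jnp.linalg.norm(x, ord=-1), jnp.min(jnp.sum(a, axis=0))),
    "inf": (jnp.linalg.norm(x, ord=jnp.inf), jnp.max(jnp.sum(a, axis=1))),  # row sums
    "-inf": (jnp.linalg.norm(x, ord=-jnp.inf), jnp.min(jnp.sum(a, axis=1))),
    "2": (jnp.linalg.norm(x, ord=2), s[0]),    # largest singular value
    "-2": (jnp.linalg.norm(x, ord=-2), s[-1]),  # smallest singular value
}
for name, (got, want) in checks.items():
    print(f"ord={name}: {float(got):.6f} vs {float(want):.6f}")
```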

In the special case of ord=None and axis=None, this function accepts an array of any dimension and computes the vector 2-norm of the flattened array.
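A minimal sketch of this special case, using a hypothetical rank-3 array: the default call should agree with the 2-norm of x.ravel() and with the hand-written square root of the sum of squares. For this input the sum of squares is 0² + 1² + … + 23² = 4324.

```python
import jax.numpy as jnp

# Hypothetical rank-3 input; neither ord nor axis is passed.
x = jnp.arange(24.0).reshape(2, 3, 4)

full = jnp.linalg.norm(x)           # default call on a rank-3 array
flat = jnp.linalg.norm(x.ravel())   # 2-norm of the flattened vector
manual = jnp.sqrt(jnp.sum(x ** 2))  # the same quantity written out
print(float(full), float(flat), float(manual))
```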

Examples

Vector norms:

>>> import jax.numpy as jnp
>>> x = jnp.array([3., 4., 12.])
>>> jnp.linalg.norm(x)
Array(13., dtype=float32)
>>> jnp.linalg.norm(x, ord=1)
Array(19., dtype=float32)
>>> jnp.linalg.norm(x, ord=0)
Array(3., dtype=float32)

Matrix norms:

>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 7.]])
>>> jnp.linalg.norm(x)  # Frobenius norm
Array(10.198039, dtype=float32)
>>> jnp.linalg.norm(x, ord='nuc')  # nuclear norm
Array(10.762535, dtype=float32)
>>> jnp.linalg.norm(x, ord=1)  # 1-norm
Array(10., dtype=float32)

Batched vector norm:

>>> jnp.linalg.norm(x, axis=1)
Array([3.7416575, 9.486833 ], dtype=float32)
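The same axis mechanism batches matrix norms: passing two trailing axes reduces each matrix in a stack and keeps the leading batch axis. The sketch below uses a hypothetical stack of three 2×2 matrices and, as one illustrative alternative, checks the Frobenius result against jax.vmap applied to the default call.

```python
import jax
import jax.numpy as jnp

# Hypothetical stack of three 2x2 matrices: shape (batch, rows, cols).
stack = jnp.array([[[1., 0.], [0., 1.]],
                   [[3., 0.], [0., 4.]],
                   [[1., 2.], [3., 4.]]])

# Two axes select a matrix norm; the batch axis survives.
fro = jnp.linalg.norm(stack, axis=(1, 2))           # Frobenius norm per matrix
spec = jnp.linalg.norm(stack, ord=2, axis=(1, 2))   # largest singular value per matrix

# Mapping the default (Frobenius) call over the batch axis should agree.
via_vmap = jax.vmap(jnp.linalg.norm)(stack)
print(fro, spec, via_vmap)
```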
