Source: https://optax.readthedocs.io/en/latest/api/utilities.html

Utilities — Optax documentation

Utilities

General

Scale gradient
optax.scale_gradient(inputs: chex.ArrayTree, scale: float) -> chex.ArrayTree

Scales gradients for the backward pass; the forward pass returns inputs unchanged.

Parameters:
  • inputs – A nested array.

  • scale – The scale factor for the gradient on the backwards pass.

Returns:

An array of the same structure as inputs, with scaled backward gradient.

Value and grad from state
optax.value_and_grad_from_state(value_fn: Callable[..., Array | float]) -> Callable[..., tuple[float | Array, optax.Updates]]

Alternative to jax.value_and_grad that fetches the value and gradient from the optimizer state.

Line-search methods such as optax.scale_by_backtracking_linesearch() need to compute the objective value and gradient at the candidate iterate. This utility function lets the next iteration reuse that value and gradient instead of recomputing them.

Parameters:

value_fn – function returning a scalar (a float or a 0-dimensional array) that can be differentiated with jax.value_and_grad().

Returns:

A callable akin to jax.value_and_grad() that fetches the value and gradient from the state if present. If no value or gradient is found, or if multiple values or gradients are found, this function raises an error. If a value is found but is infinite or NaN, the value and gradient are computed using jax.value_and_grad(). If the gradient found in the state is None, an error is raised.

Examples

>>> import optax
>>> import jax.numpy as jnp
>>> def fn(x): return jnp.sum(x ** 2)
>>> solver = optax.chain(
...     optax.sgd(learning_rate=1.),
...     optax.scale_by_backtracking_linesearch(
...         max_backtracking_steps=15, store_grad=True
...     )
... )
>>> value_and_grad = optax.value_and_grad_from_state(fn)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 1.40E+01
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   value, grad = value_and_grad(params, state=opt_state)
...   updates, opt_state = solver.update(
...       grad, opt_state, params, value=value, grad=grad, value_fn=fn
...   )
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 5.04E+00
Objective function: 1.81E+00
Objective function: 6.53E-01
Objective function: 2.35E-01
Objective function: 8.47E-02

Numerical Stability

Safe increment
optax.safe_increment(count: chex.Numeric) -> chex.Numeric

Increments counter by one while avoiding overflow.

Let max_val and min_val denote the maximum and minimum values representable by the dtype of count. Normally, max_val + 1 would overflow to min_val. This function ensures that once max_val is reached, the counter stays at max_val.

Parameters:

count – a counter to be incremented.

Returns:

A counter incremented by 1, or max_val if the maximum value is reached.

Examples

>>> import jax.numpy as jnp
>>> import optax
>>> optax.safe_increment(jnp.asarray(1, dtype=jnp.int32))
Array(2, dtype=int32)
>>> optax.safe_increment(jnp.asarray(2147483647, dtype=jnp.int32))
Array(2147483647, dtype=int32)

Added in version 0.2.4.

Safe norm
optax.safe_norm(x: chex.Array, min_norm: chex.Numeric, ord: int | float | str | None = None, axis: None | tuple[int, ...] | int = None, keepdims: bool = False) -> chex.Array

Returns jnp.maximum(jnp.linalg.norm(x), min_norm) with correct gradients.

The gradient of jnp.maximum(jnp.linalg.norm(x), min_norm) at x = 0 is NaN: the gradient of the norm itself is undefined at zero, and JAX differentiates through both branches of jnp.maximum, so that undefined value propagates. This function instead returns the correct gradient of 0.0 in that setting.

Parameters:
  • x – jax array.

  • min_norm – lower bound for the returned norm.

  • ord – {non-zero int, inf, -inf, ‘fro’, ‘nuc’}, optional. Order of the norm. inf means numpy’s inf object. The default is None.

  • axis – {None, int, 2-tuple of ints}, optional. If axis is an integer, it specifies the axis of x along which to compute the vector norms. If axis is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix norms of these matrices are computed. If axis is None then either a vector norm (when x is 1-D) or a matrix norm (when x is 2-D) is returned. The default is None.

  • keepdims – bool, optional. If this is set to True, the axes which are normed over are left in the result as dimensions with size one. With this option the result will broadcast correctly against the original x.

Returns:

The safe norm of the input vector, accounting for correct gradient.

Safe root mean squares
optax.safe_root_mean_squares(x: chex.Array, min_rms: chex.Numeric) -> chex.Array

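Returns maximum(sqrt(mean(abs_sq(x))), min_rms) with correct gradients, where abs_sq(x) is the elementwise squared magnitude of x.

The gradient of maximum(sqrt(mean(abs_sq(x))), min_rms) at x = 0 is NaN: the gradient of the square root is undefined at zero, and JAX differentiates through both branches of jnp.maximum, so that undefined value propagates. This function instead returns the correct gradient of 0.0 in that setting.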

Parameters:
  • x – jax array.

  • min_rms – lower bound for the returned RMS value.

Returns:

The safe RMS of the input vector, accounting for correct gradient.

Linear Algebra Operators

Matrix inverse pth root
optax.matrix_inverse_pth_root(matrix: chex.Array, p: int, num_iters: int = 100, ridge_epsilon: float = 1e-06, error_tolerance: float = 1e-06, precision: Precision = Precision.HIGHEST)

Computes matrix^(-1/p), where p is a positive integer.

This function uses coupled Newton iterations to compute the inverse pth root of a matrix.

Parameters:
  • matrix – the symmetric PSD matrix whose power is to be computed.

  • p – exponent, for p a positive integer.

  • num_iters – Maximum number of iterations.

  • ridge_epsilon – Ridge epsilon added to make the matrix positive definite.

  • error_tolerance – Error indicator, useful for early termination.

  • precision – XLA precision flag. The available options are: a) lax.Precision.DEFAULT (faster, but least precise); b) lax.Precision.HIGH (more precise, slower); c) lax.Precision.HIGHEST (most precise, slowest).

Returns:

matrix^(-1/p)

References

Nicholas J. Higham, Functions of Matrices: Theory and Computation, p. 184, Eq. 7.18. https://epubs.siam.org/doi/book/10.1137/1.9780898717778
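For small matrices, the quantity matrix^(-1/p) can be cross-checked against a direct eigendecomposition. The sketch below is not the coupled Newton iteration used by optax.matrix_inverse_pth_root; it is a swapped-in, O(n^3) reference computation based on jnp.linalg.eigh. The helper name inverse_pth_root_eigh and its absolute ridge_epsilon shift are hypothetical and may not match how optax applies its ridge term.

```python
import jax.numpy as jnp


def inverse_pth_root_eigh(matrix, p, ridge_epsilon=0.0):
    """Reference matrix^(-1/p) for a symmetric PSD matrix (hypothetical helper)."""
    n = matrix.shape[0]
    # Absolute ridge shift; optax may scale its ridge term differently.
    damped = matrix + ridge_epsilon * jnp.eye(n, dtype=matrix.dtype)
    eigvals, eigvecs = jnp.linalg.eigh(damped)
    # Q diag(lambda^(-1/p)) Q^T, written as a column scaling of Q.
    return (eigvecs * eigvals ** (-1.0 / p)) @ eigvecs.T


m = jnp.array([[4.0, 0.0], [0.0, 9.0]])
root = inverse_pth_root_eigh(m, p=2)  # approximately diag(1/2, 1/3)
print(root)
```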

Power iteration
optax.power_iteration(matrix: chex.Array | Callable[[chex.ArrayTree], chex.ArrayTree], *, v0: chex.ArrayTree | None = None, num_iters: int = 100, error_tolerance: float = 1e-06, precision: Precision = Precision.HIGHEST, key: Array | None = None) -> tuple[chex.Numeric, chex.ArrayTree]

Power iteration algorithm.

This algorithm computes the dominant eigenvalue (the eigenvalue of largest magnitude, whose absolute value is the spectral radius) and its associated eigenvector of a diagonalizable matrix. The matrix can be given as an array or as a callable implementing a matrix-vector product.

Parameters:
  • matrix – a square matrix, either as an array or a callable implementing a matrix-vector product.

  • v0 – initial vector approximating the dominant eigenvector. If matrix is an array of size (n, n), v0 must be a vector of size (n,). If instead matrix is a callable, then v0 must be a tree with the same structure as the input of this callable. If this argument is None and matrix is an array, then a random vector sampled from a uniform distribution in [-1, 1] is used as the initial vector.

  • num_iters – Number of power iterations.

  • error_tolerance – Iterative exit condition. The procedure stops when the relative error of the estimate of the dominant eigenvalue is below this threshold.

  • precision – XLA precision flag. The available options are: a) lax.Precision.DEFAULT (faster, but least precise); b) lax.Precision.HIGH (more precise, slower); c) lax.Precision.HIGHEST (most precise, slowest).

  • key – random key for the initialization of v0 when not given explicitly. When this argument is None, jax.random.PRNGKey(0) is used.

Returns:

A pair (eigenvalue, eigenvector), where eigenvalue is the dominant eigenvalue of matrix and eigenvector is its associated eigenvector.

References

Wikipedia contributors. Power iteration.

Note

If the matrix is not diagonalizable or the dominant eigenvalue is not unique, the algorithm may not converge.

Changed in version 0.2.2: matrix can be a callable. Reversed the order of the return parameters, from (eigenvector, eigenvalue) to (eigenvalue, eigenvector).

Non-negative least squares
optax.nnls(A: Array, b: Array, iters: int, unroll: int | bool = 1, L: Array | float | None = None) -> Array

Solves the non-negative least squares problem.

Minimizes \(\|A x - b\|_2\) subject to \(x \geq 0\).

Uses the fast projected gradient (FPG) algorithm of Polyak 2015.

Parameters:
  • A – Input matrix of shape (M, N).

  • b – Input vector of shape (M,) or matrix of shape (M, K).

  • iters – Number of iterations to run the algorithm for.

  • unroll – Unroll parameter passed to lax.scan.

  • L – An upper bound on the spectral radius of A.mT @ A (optional).

Returns:

A solution vector of shape (N,) or matrix of shape (N, K).

Examples

>>> from jax import numpy as jnp
>>> import optax
>>> A = jnp.array([[1., 2.], [3., 4.]])
>>> b = jnp.array([5., 6.])
>>> x = optax.nnls(A, b, 10**3)
>>> print(f"{x[0]:.2f}")
0.00
>>> print(f"{x[1]:.2f}")
1.70

References

Roman A. Polyak, Projected gradient method for non-negative least square, 2015

Second Order Optimization

Fisher diagonal
optax.second_order.fisher_diag(negative_log_likelihood: LossFn, params: Any, inputs: Array, targets: Array) -> Array

Computes the diagonal of the (observed) Fisher information matrix.

Parameters:
  • negative_log_likelihood – the negative log likelihood function with expected signature loss = fn(params, inputs, targets).

  • params – model parameters.

  • inputs – inputs at which negative_log_likelihood is evaluated.

  • targets – targets at which negative_log_likelihood is evaluated.

Returns:

An Array containing the diagonal of the observed Fisher information matrix, that is, the elementwise square of the flattened gradient of negative_log_likelihood evaluated at (params, inputs, targets).
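Assuming the observed Fisher matrix is taken as g g^T for the gradient g of the negative log-likelihood, its diagonal is simply g squared elementwise. The sketch below computes that quantity directly with jax.grad and jax.flatten_util.ravel_pytree rather than calling optax.second_order; the Gaussian nll and the helper name observed_fisher_diag are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def nll(params, inputs, targets):
    # Hypothetical Gaussian negative log-likelihood (unit variance, constant dropped).
    preds = inputs @ params["w"] + params["b"]
    return 0.5 * jnp.sum((preds - targets) ** 2)


def observed_fisher_diag(nll_fn, params, inputs, targets):
    # diag(g g^T) = g ** 2, with g the flattened gradient of the NLL.
    grads = jax.grad(nll_fn)(params, inputs, targets)
    flat_grads, _ = ravel_pytree(grads)
    return jnp.square(flat_grads)


params = {"w": jnp.zeros(2), "b": jnp.array(0.0)}
inputs = jnp.array([[1.0, 2.0], [3.0, 4.0]])
targets = jnp.array([1.0, 1.0])
fisher = observed_fisher_diag(nll, params, inputs, targets)
print(fisher)  # flattened in ravel_pytree order: "b" first, then "w"
```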

Hessian diagonal
optax.second_order.hessian_diag(loss: LossFn, params: Any, inputs: Array, targets: Array) -> Array

Computes the diagonal of the Hessian of loss with respect to params, evaluated at (inputs, targets).

Parameters:
  • loss – the loss function.

  • params – model parameters.

  • inputs – inputs at which loss is evaluated.

  • targets – targets at which loss is evaluated.

Returns:

An Array containing the diagonal of the Hessian of loss with respect to the (flattened) params, evaluated at (params, inputs, targets).
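A sketch of one way to obtain a Hessian diagonal without forming the full Hessian: compute one Hessian-vector product per unit vector and keep the matching entry. It assumes a loss of a single flat parameter array. The helper hessian_diag_via_hvp and the loss are hypothetical, and the approach costs one HVP per parameter, so it is practical only for small models.

```python
import jax
import jax.numpy as jnp


def loss(w):
    # Hypothetical loss with a known Hessian diagonal [2, 4, 6];
    # the w[0] * w[1] term only adds off-diagonal entries.
    c = jnp.array([1.0, 2.0, 3.0])
    return jnp.sum(c * w ** 2) + w[0] * w[1]


def hessian_diag_via_hvp(f, w):
    # Forward-over-reverse Hessian-vector product.
    hvp = lambda v: jax.jvp(jax.grad(f), (w,), (v,))[1]
    basis = jnp.eye(w.size, dtype=w.dtype)
    # Entry i is e_i^T H e_i; this costs one HVP per parameter.
    return jax.vmap(lambda e: jnp.vdot(e, hvp(e)))(basis)


w = jnp.array([0.5, -1.0, 2.0])
diag = hessian_diag_via_hvp(loss, w)
print(diag)
```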

Hessian vector product
optax.second_order.hvp(loss: LossFn, v: Array, params: Any, inputs: Array, targets: Array) -> Array

Efficiently computes the product of a vector with the Hessian of loss.

Parameters:
  • loss – the loss function.

  • v – a vector with as many entries as ravel(params).

  • params – model parameters.

  • inputs – inputs at which loss is evaluated.

  • targets – targets at which loss is evaluated.

Returns:

An Array corresponding to the product of v and the Hessian of loss evaluated at (params, inputs, targets).
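A minimal sketch of the forward-over-reverse technique that makes Hessian-vector products cheap: differentiating the gradient along the direction v yields H v without materializing H. The local function hvp below is a hypothetical stand-in, not optax.second_order.hvp, and it takes a flat parameter array; the least-squares loss is chosen because its Hessian, inputs.T @ inputs, is easy to verify.

```python
import jax
import jax.numpy as jnp


def loss(w, inputs, targets):
    # Hypothetical least-squares loss whose Hessian is inputs.T @ inputs.
    return 0.5 * jnp.sum((inputs @ w - targets) ** 2)


def hvp(loss_fn, v, w, inputs, targets):
    # Forward-over-reverse: the directional derivative of grad along v is H v.
    grad_fn = lambda p: jax.grad(loss_fn)(p, inputs, targets)
    return jax.jvp(grad_fn, (w,), (v,))[1]


X = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 0.0])
w = jnp.zeros(2)
v = jnp.array([1.0, -1.0])
hv = hvp(loss, v, w, X, y)
print(hv)  # equals X.T @ X @ v
```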

Tree

NamedTupleKey
class optax.tree_utils.NamedTupleKey(tuple_name: str, name: str)

KeyType for a NamedTuple in a tree.

When a filtering function filtering(path: KeyPath, value: Any) -> bool: ... is passed to optax.tree_utils.tree_get_all_with_path(), optax.tree_utils.tree_get(), or optax.tree_utils.tree_set(), it can check whether a KeyEntry in the path is a NamedTupleKey and, if so, whether the name of that named tuple is the one being searched for.

tuple_name

name of the tuple containing the key.

Type:

str

name

name of the key.

Type:

str

Added in version 0.2.2.

Tree add
optax.tree_utils.tree_add(tree_x: Any, tree_y: Any, *other_trees: Any) -> Any

Add two (or more) pytrees.

Parameters:
  • tree_x – first pytree.

  • tree_y – second pytree.

  • *other_trees – optional other trees to add

Returns:

the sum of the two (or more) pytrees.

Changed in version 0.2.1: Added optional *other_trees argument.

Tree add and scalar multiply
optax.tree_utils.tree_add_scale(tree_x: Any, scalar: float | Array, tree_y: Any) -> Any

Add two trees, where the second tree is scaled by a scalar.

In infix notation, the function performs out = tree_x + scalar * tree_y.

Parameters:
  • tree_x – first pytree.

  • scalar – scalar value.

  • tree_y – second pytree.

Returns:

a pytree with the same structure as tree_x and tree_y.

Tree all close
optax.tree_utils.tree_allclose(a: Any, b: Any, rtol: Array | ndarray | bool | number | bool | int | float | complex = 1e-05, atol: Array | ndarray | bool | number | bool | int | float | complex = 1e-08, equal_nan: bool = False)

Check whether two trees are element-wise approximately equal within a tolerance.

See jax.numpy.allclose() for the equivalent on arrays.

Parameters:
  • a – a tree

  • b – a tree

  • rtol – relative tolerance used for approximate equality

  • atol – absolute tolerance used for approximate equality

  • equal_nan – boolean indicating whether NaNs are treated as equal

Returns:

a boolean value.

Tree batch reshaping
optax.tree_utils.tree_batch_shape(tree: Any, shape: tuple[int, ...] = ())

Add leading batch dimensions to each leaf of a pytree.

Parameters:
  • tree – a pytree.

  • shape – a shape indicating what leading batch dimensions to add.

Returns:

a pytree with the leading batch dimensions added.

Tree cast
optax.tree_utils.tree_cast(tree: chex.ArrayTree, dtype: str | type[Any] | dtype | SupportsDType | None) -> chex.ArrayTree

Cast the leaves of tree to the given dtype; if dtype is None, return tree unchanged.

Parameters:
  • tree – the tree to cast.

  • dtype – the dtype to cast to, or None to skip.

Returns:

the tree, with leaves cast to dtype.

Examples

>>> import jax.numpy as jnp
>>> import optax
>>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)},
...         'c': jnp.array(2.0, dtype=jnp.float32)}
>>> optax.tree_utils.tree_cast(tree, dtype=jnp.bfloat16)
{'a': {'b': Array(1, dtype=bfloat16)}, 'c': Array(2, dtype=bfloat16)}

Tree cast like
optax.tree_utils.tree_cast_like(tree: T, other_tree: chex.ArrayTree) -> T

Cast tree to dtypes of other_tree.

Parameters:
  • tree – the tree to cast.

  • other_tree – reference tree whose leaf dtypes the leaves of tree are cast to.

Returns:

the tree, with each leaf cast to the dtype of the corresponding leaf in other_tree.

Examples

>>> import jax.numpy as jnp
>>> import optax
>>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)},
...         'c': jnp.array(2.0, dtype=jnp.float32)}
>>> other_tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)},
...               'c': jnp.array(2.0, dtype=jnp.bfloat16)}
>>> optax.tree_utils.tree_cast_like(tree, other_tree)
{'a': {'b': Array(1., dtype=float32)}, 'c': Array(2, dtype=bfloat16)}

Tree clip
optax.tree_utils.tree_clip(tree: Any, min_value: Array | ndarray | bool | number | bool | int | float | complex | None = None, max_value: Array | ndarray | bool | number | bool | int | float | complex | None = None) -> Any

Creates an identical tree where all tensors are clipped to [min_value, max_value].

Parameters:
  • tree – pytree.

  • min_value – optional minimal value to clip all tensors to. If None (default) then result will not be clipped to any minimum value.

  • max_value – optional maximal value to clip all tensors to. If None (default) then result will not be clipped to any maximum value.

Returns:

a tree with the same structure as tree.

Added in version 0.2.3.

Tree conjugate
optax.tree_utils.tree_conj(tree: Any) -> Any

Compute the conjugate of a pytree.

Parameters:

tree – pytree.

Returns:

a pytree with the same structure as tree.

Tree data type
optax.tree_utils.tree_dtype(tree: chex.ArrayTree, mixed_dtype_handler: str | None = None) -> str | type[Any] | dtype | SupportsDType

Fetch dtype of tree.

If the tree is empty, returns the default dtype of JAX arrays.

Parameters:
  • tree – the tree to fetch the dtype of.

  • mixed_dtype_handler – how to handle mixed dtypes in the tree.
      - If mixed_dtype_handler=None, returns the common dtype of the leaves of the tree if it exists, otherwise raises an error.
      - If mixed_dtype_handler='promote', promotes the dtypes of the leaves of the tree to a common promoted dtype using jax.numpy.promote_types().
      - If mixed_dtype_handler='highest' or mixed_dtype_handler='lowest', returns the highest/lowest dtype of the leaves of the tree. We consider a partial ordering of dtypes as dtype1 <= dtype2 if dtype1 is promoted to dtype2, that is, if jax.numpy.promote_types(dtype1, dtype2) == dtype2. Since some dtypes cannot be promoted to one another, this is not a total ordering, and the 'highest' or 'lowest' options may not be applicable. These options will throw an error if the dtypes of the leaves of the tree cannot be promoted to one another.

Returns:

the dtype of the tree.

Raises:
  • ValueError – If mixed_dtype_handler is set to None and multiple dtypes are found in the tree.

  • ValueError – If mixed_dtype_handler is set to 'highest' or 'lowest' and some leaves’ dtypes in the tree cannot be promoted to one another.

Examples

>>> import jax.numpy as jnp
>>> import optax
>>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)},
...         'c': jnp.array(2.0, dtype=jnp.float32)}
>>> optax.tree_utils.tree_dtype(tree)
dtype('float32')
>>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float16)},
...         'c': jnp.array(2.0, dtype=jnp.float32)}
>>> optax.tree_utils.tree_dtype(tree, 'lowest')
dtype('float16')
>>> optax.tree_utils.tree_dtype(tree, 'highest')
dtype('float32')
>>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.int32)},
...         'c': jnp.array(2.0, dtype=jnp.uint32)}
>>> # optax.tree_utils.tree_dtype(tree, 'highest')
>>> # -> will throw an error because int32 and uint32
>>> # cannot be promoted to one another.
>>> optax.tree_utils.tree_dtype(tree, 'promote')
dtype('int64')

Added in version 0.2.4.

Tree full like
optax.tree_utils.tree_full_like(tree: Any, fill_value: Array | ndarray | bool | number | bool | int | float | complex, dtype: str | type[Any] | dtype | SupportsDType | None = None) -> Any

Creates an identical tree where all tensors are filled with fill_value.

Parameters:
  • tree – pytree.

  • fill_value – the fill value for all tensors in the tree.

  • dtype – optional dtype to use for the tensors in the tree.

Returns:

a tree with the same structure as tree.

Tree divide
optax.tree_utils.tree_div(tree_x: Any, tree_y: Any) -> Any

Divide two pytrees.

Parameters:
  • tree_x – first pytree.

  • tree_y – second pytree.

Returns:

the quotient of the two pytrees.

Fetch a single value that matches a given key
optax.tree_utils.tree_get(tree: optax.PyTree, key: Any, default: Any | None = None, filtering: Callable[[Tuple[DictKey | FlattenedIndexKey | GetAttrKey | SequenceKey | NamedTupleKey, ...], Any], bool] | None = None) -> Any

Extract a value from a pytree matching a given key.

Search in the tree for a specific key (which can be a key from a dictionary, a field from a NamedTuple or the name of a NamedTuple).

If the tree does not contain key, returns default.

Raises a KeyError if multiple values of key are found in tree.

Generally, you may first get all pairs (path_to_value, value) for a given key using optax.tree_utils.tree_get_all_with_path(). You may then define a filtering operation filtering(path: Key_Path, value: Any) -> bool: ... that selects the specific values you want to fetch by looking at the type of the value or at the path to that value. Note that, unlike the paths returned by jax.tree_util.tree_leaves_with_path(), the paths analyzed by the filtering operation in optax.tree_utils.tree_get_all_with_path(), optax.tree_utils.tree_get(), or optax.tree_utils.tree_set() include the names of the named tuples along the path. Concretely, if the value considered is in the attribute key of a named tuple called MyNamedTuple, the last element of the path will be an optax.tree_utils.NamedTupleKey containing both name=key and tuple_name='MyNamedTuple'. That way you can distinguish between same-named fields in different named tuples (which arise, for example, when chaining transformations in optax). See the last example below.

Parameters:
  • tree – tree to search in.

  • key – keyword or field to search in tree for.

  • default – default value to return if key is not found in tree.

  • filtering – optional callable to further filter values in tree that match the key. filtering(path: Key_Path, value: Any) -> bool: ... takes as arguments both the path to the value (as returned by optax.tree_utils.tree_get_all_with_path()) and the value that matches the given key.

Returns:

value – the value in tree matching the given key. If none is found, returns default; if multiple are found, raises an error.

Raises:

KeyError – If multiple values of key are found in tree.

Examples

Basic usage

>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.adam(learning_rate=1.)
>>> state = opt.init(params)
>>> count = optax.tree_utils.tree_get(state, 'count')
>>> print(count)
0

Usage with a filtering operation

>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.inject_hyperparams(optax.sgd)(
...   learning_rate=lambda count: 1/(count+1)
... )
>>> state = opt.init(params)
>>> filtering = lambda path, value: isinstance(value, jnp.ndarray)
>>> lr = optax.tree_utils.tree_get(
...   state, 'learning_rate', filtering=filtering
... )
>>> print(lr)
1.0

Extracting a named tuple by its name

>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.chain(
...     optax.add_noise(1.0, 0.9, key=0),
...     optax.scale_by_adam()
... )
>>> state = opt.init(params)
>>> noise_state = optax.tree_utils.tree_get(state, 'AddNoiseState')
>>> print(noise_state)
AddNoiseState(count=Array(0, dtype=int32), rng_key=Array((), dtype=key<fry>) overlaying:
[0 0])

Differentiating between two values by the name of their named tuples.

>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.chain(
...   optax.add_noise(1.0, 0.9, key=0),
...   optax.scale_by_adam()
... )

Added in version 0.2.2.

Fetch all values that match a given key
optax.tree_utils.tree_get_all_with_path(tree: optax.PyTree, key: Any, filtering: Callable[[Tuple[DictKey | FlattenedIndexKey | GetAttrKey | SequenceKey | NamedTupleKey, ...], Any], bool] | None = None) -> list[tuple[Tuple[DictKey | FlattenedIndexKey | GetAttrKey | SequenceKey | NamedTupleKey, ...], Any]]

Extract values of a pytree matching a given key.

Search in a pytree tree for a specific key (which can be a key from a dictionary, a field from a NamedTuple or the name of a NamedTuple).

The key or field key may appear more than once in tree, so this function returns a list of all values corresponding to key, each paired with the path to that value. The path is a sequence of KeyEntry objects that can be converted to a readable string with jax.tree_util.keystr(); see the examples below.

Parameters:
  • tree – tree to search in.

  • key – keyword or field to search in tree for.

  • filtering – optional callable to further filter values in tree that match the key. filtering(path: Key_Path, value: Any) -> bool: ... takes as arguments both the path to the value (as returned by optax.tree_utils.tree_get_all_with_path()) and the value that matches the given key.

Returns:

values_with_path – list of tuples where each tuple is of the form (path_to_value, value). Here value is one entry of the tree that corresponds to the key, and path_to_value is a tuple of KeyEntry objects, each of which is a jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey, jax.tree_util.GetAttrKey, jax.tree_util.SequenceKey, or optax.tree_utils.NamedTupleKey.

Examples

Basic usage

>>> import jax
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> solver = optax.inject_hyperparams(optax.sgd)(
...   learning_rate=lambda count: 1/(count+1)
... )
>>> state = solver.init(params)
>>> found_values_with_path = optax.tree_utils.tree_get_all_with_path(
...   state, 'learning_rate'
... )
>>> print(
... *[(jax.tree_util.keystr(p), v) for p, v in found_values_with_path],
... sep="\n",
... )
("InjectStatefulHyperparamsState.hyperparams['learning_rate']", Array(1., dtype=float32))
("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", WrappedScheduleState(count=Array(0, dtype=int32)))

Usage with a filtering operation

>>> import jax
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> solver = optax.inject_hyperparams(optax.sgd)(
...   learning_rate=lambda count: 1/(count+1)
... )
>>> state = solver.init(params)
>>> filtering = lambda path, value: isinstance(value, tuple)
>>> found_values_with_path = optax.tree_utils.tree_get_all_with_path(
...   state, 'learning_rate', filtering
... )
>>> print(
... *[(jax.tree_util.keystr(p), v) for p, v in found_values_with_path],
... sep="\n",
... )
("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", WrappedScheduleState(count=Array(0, dtype=int32)))

Added in version 0.2.2.

Tree norm
optax.tree_utils.tree_norm(tree: Any, ord: int | str | float | None = None, squared: bool = False) -> Array

Compute the vector norm of the given order over all elements of a pytree, treating its leaves as one flattened vector.

Parameters:
  • tree – pytree.

  • ord – the order of the vector norm to compute (None, 1, 2, or inf).

  • squared – whether the norm should be returned squared or not.

Returns:

a scalar value.

Tree map parameters
optax.tree_utils.tree_map_params(initable: Callable[[optax.Params], optax.OptState] | Initable, f: Callable[..., Any], state: optax.OptState, /, *rest: Any, transform_non_params: Callable[..., Any] | None = None, is_leaf: Callable[[optax.Params], bool] | None = None) -> optax.OptState

Apply a callable over all params in the given optimizer state.

This function exists to help construct partition specs over optimizer states, in the case that a partition spec is already known for the parameters.

For example, the following replaces every parameter-shaped tree in the optimizer state with a copy of the given partition spec. The transform_non_params argument can be used to replace any remaining fields as required; here, those fields are replaced by None.

>>> import jax.numpy as jnp
>>> import optax
>>> params, specs = jnp.array(0.), jnp.array(0.)  # Trees with the same shape
>>> opt = optax.sgd(1e-3)
>>> state = opt.init(params)
>>> opt_specs = optax.tree_map_params(
...     opt,
...     lambda _, spec: spec,
...     state,
...     specs,
...     transform_non_params=lambda _: None,
... )

Parameters:
  • initable – A callable taking parameters and returning an optimizer state, or an object whose init attribute is such a callable.

  • f – A callable that will be applied for all copies of the parameter tree within this optimizer state.

  • state – The optimizer state to map over.

  • *rest – Additional arguments, having the same shape as the parameter tree, that will be passed to f.

  • transform_non_params – An optional function that will be called on all non-parameter fields within the optimizer state.

  • is_leaf – Passed through to jax.tree.map. This makes it possible to ignore parts of the parameter tree e.g. when the gradient transformations modify the shape of the original pytree, such as for optax.masked.

Returns:

The result of applying the function f on all trees in the optimizer’s state that have the same shape as the parameter tree, along with the given optional extra arguments.

Tree max
optax.tree_utils.tree_max(tree: Any) -> chex.Numeric

Compute the max of all the elements in a pytree.

Parameters:

tree – pytree.

Returns:

a scalar value.

Tree min
optax.tree_utils.tree_min(tree: Any) -> chex.Numeric

Compute the min of all the elements in a pytree.

Parameters:

tree – pytree.

Returns:

a scalar value.

Tree multiply
optax.tree_utils.tree_mul(tree_x: Any, tree_y: Any) -> Any

Multiply two pytrees.

Parameters:
  • tree_x – first pytree.

  • tree_y – second pytree.

Returns:

the product of the two pytrees.

Tree ones like
optax.tree_utils.tree_ones_like(tree: Any, dtype: str | type[Any] | dtype | SupportsDType | None = None) -> Any

Creates an all-ones tree with the same structure.

Parameters:
  • tree – pytree.

  • dtype – optional dtype to use for the tree of ones.

Returns:

an all-ones tree with the same structure as tree.

Split a key according to the structure of a tree
optax.tree_utils.tree_split_key_like(rng_key: Array, target_tree: chex.ArrayTree) -> chex.ArrayTree

Split rng_key into a tree of keys matching the structure of target_tree.

Parameters:
  • rng_key – the key to split.

  • target_tree – the tree whose structure to match.

Returns:

a tree of rng keys.

Tree with random values
optax.tree_utils.tree_random_like(rng_key: jax.Array, target_tree: chex.ArrayTree, sampler: Callable[[jax.Array, Sequence[int | Any], str | type[Any] | numpy.dtype | SupportsDType], chex.Array] | Callable[[jax.Array, Sequence[int | Any], str | type[Any] | numpy.dtype | SupportsDType, Sharding], chex.Array] = jax.random.normal, dtype: str | type[Any] | numpy.dtype | SupportsDType | None = None) -> chex.ArrayTree

Create tree with random entries of the same shape as target tree.

Parameters:
  • rng_key – the key for the random number generator.

  • target_tree – the tree whose structure to match. Leaves must be arrays.

  • sampler – the noise sampling function, by default jax.random.normal.

  • dtype – the desired dtype for the random numbers, passed to sampler. If None, the dtype of the target tree is used if possible.

Returns:

a random tree with the same structure as target_tree, whose leaves are drawn from sampler.

Warning

The possible dtypes may be limited by the sampler. For example, jax.random.rademacher only supports integer dtypes and will raise an error if the dtype used (dtype, or the dtype of the target tree when dtype is None) is not an integer type.

Added in version 0.2.1.

Tree real part
optax.tree_utils.tree_real(tree: Any) -> Any

Compute the real part of a pytree.

Parameters:

tree – pytree.

Returns:

a pytree with the same structure as tree.

Tree scalar multiply
optax.tree_utils.tree_scale(scalar: float | Array, tree: Any) -> Any

Multiply a tree by a scalar.

In infix notation, the function performs out = scalar * tree.

Parameters:
  • scalar – scalar value.

  • tree – pytree.

Returns:

a pytree with the same structure as tree.

Set values in a tree
optax.tree_utils.tree_set(tree: optax.PyTree, filtering: Callable[[Tuple[DictKey | FlattenedIndexKey | GetAttrKey | SequenceKey | NamedTupleKey, ...], Any], bool] | None = None, /, **kwargs: Any) -> optax.PyTree

Creates a copy of tree with some values replaced as specified by kwargs.

Search in the tree for keys in **kwargs (which can be a key from a dictionary, a field from a NamedTuple or the name of a NamedTuple). If such a key is found, replace the corresponding value with the one given in **kwargs.

Raises a KeyError if some keys in **kwargs are not present in the tree.

Parameters:
  • tree – pytree whose values are to be replaced.

  • filtering – optional callable to further filter values in tree that match the keys to replace. filtering(path: Key_Path, value: Any) -> bool: ... takes as arguments both the path to the value (as returned by optax.tree_utils.tree_get_all_with_path()) and the value that matches a given key.

  • **kwargs – dictionary of keys with values to replace in tree.

Returns:
new_tree

a new pytree with the same structure as tree. Each element of tree whose key or field matches a key in **kwargs takes the corresponding value from **kwargs.

Raises:

KeyError – If, for some key in **kwargs, no matching value is found in tree, or none of the matching values satisfy the filtering operation.

Examples

Basic usage

>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.adam(learning_rate=1.)
>>> state = opt.init(params)
>>> print(state)
(ScaleByAdamState(count=Array(0, dtype=int32), mu=Array([0., 0., 0.], dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState())
>>> new_state = optax.tree_utils.tree_set(state, count=2.)
>>> print(new_state)
(ScaleByAdamState(count=2.0, mu=Array([0., 0., 0.], dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState())

Usage with a filtering operation

>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.inject_hyperparams(optax.sgd)(
...     learning_rate=lambda count: 1/(count+1)
...  )
>>> state = opt.init(params)
>>> print(state)
InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(1., dtype=float32)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState()))
>>> filtering = lambda path, value: isinstance(value, jnp.ndarray)
>>> new_state = optax.tree_utils.tree_set(
...   state, filtering, learning_rate=jnp.asarray(0.1)
... )
>>> print(new_state)
InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(0.1, dtype=float32, weak_type=True)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState()))

Note

The recommended way to inject hyperparameter schedules is through optax.inject_hyperparams(); this function is a helper for other purposes.

Added in version 0.2.2.

Tree size
optax.tree_utils.tree_size(tree: Any) → int

Total size of a pytree, i.e. the total number of array elements across all of its leaves.

Parameters:

tree – pytree.

Returns:

the total number of elements in the pytree.

Tree subtract
optax.tree_utils.tree_sub(tree_x: Any, tree_y: Any) → Any

Subtract two pytrees.

Parameters:
  • tree_x – first pytree.

  • tree_y – second pytree.

Returns:

the leafwise difference tree_x - tree_y, a pytree with the same structure as tree_x.

Tree sum
optax.tree_utils.tree_sum(tree: Any) → chex.Numeric

Compute the sum of all the elements in a pytree.

Parameters:

tree – pytree.

Returns:

a scalar value.

Tree inner product
optax.tree_utils.tree_vdot(tree_x: Any, tree_y: Any) → chex.Numeric

Compute the inner product between two pytrees.

Parameters:
  • tree_x – first pytree to use.

  • tree_y – second pytree to use.

Returns:

inner product between tree_x and tree_y, a scalar value.

Examples

>>> import jax.numpy as jnp
>>> import optax
>>> optax.tree_utils.tree_vdot(
...   {'a': jnp.array([1, 2]), 'b': jnp.array([1, 2])},
...   {'a': jnp.array([-1, -1]), 'b': jnp.array([1, 1])},
... )
Array(0, dtype=int32)

Note

We upcast the values to the highest precision to avoid numerical issues.

Tree where
optax.tree_utils.tree_where(condition, tree_x, tree_y)

Select tree_x values if condition is true else tree_y values.

Parameters:
  • condition – boolean specifying whether to select the values of tree_x or of tree_y.

  • tree_x – pytree chosen if condition is True.

  • tree_y – pytree chosen if condition is False.

Returns:

tree_x or tree_y depending on condition.

Tree zeros like
optax.tree_utils.tree_zeros_like(tree: Any, dtype: str | type[Any] | dtype | SupportsDType | None = None) → Any

Creates an all-zeros tree with the same structure.

Parameters:
  • tree – pytree.

  • dtype – optional dtype to use for the tree of zeros.

Returns:

an all-zeros tree with the same structure as tree.

